How would I reparametrize nnx.Module parameters? #4546
-
In JAX, it's easy to re-parametrize a neural network using something similar to the following:
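(The original snippet wasn't captured here. As a hedged illustration of the kind of pattern being referred to, here is a minimal pure-JAX sketch: keep unconstrained raw parameters, transform them inside the apply function, and differentiate w.r.t. the raw parameters. The `tanh` transform and parameter names are placeholders of my own.)

```python
import jax
import jax.numpy as jnp

def reparam(raw_params):
  # Transform unconstrained raw parameters into the effective parameters.
  return jax.tree_util.tree_map(jnp.tanh, raw_params)

def apply(raw_params, x):
  params = reparam(raw_params)
  return x @ params['w'] + params['b']

raw_params = {'w': jnp.zeros((3, 2)), 'b': jnp.zeros((2,))}
x = jnp.ones((4, 3))

# Gradients flow through the reparametrization back to the raw parameters.
grads = jax.grad(lambda p: jnp.sum(apply(p, x)))(raw_params)
```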
How do I achieve something similar using nnx, since the params are part of the model? Of course, I can use something like:
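(The snippet this refers to isn't shown here; the following is my guess at the kind of thing meant: a sketch that pulls the parameter state out of the module, transforms it, and writes it back. The API names `nnx.state`/`nnx.update` are from current Flax NNX, but details vary across versions.)

```python
import jax
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(3, 2, rngs=nnx.Rngs(0))

# Extract the Param state, reparametrize it, and write it back into the module.
params = nnx.state(model, nnx.Param)
new_params = jax.tree_util.tree_map(jnp.tanh, params)
nnx.update(model, new_params)

y = model(jnp.ones((4, 3)))
```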
Now, if I want to get the grads w.r.t. …
-
I also looked into the LoRA implementation to see how this issue might be handled. Here's what happens:

```python
def __call__(self, x: jax.Array):
  out = x @ self.lora_a @ self.lora_b
  if self.base_module is not None:
    if not callable(self.base_module):
      raise ValueError('`self.base_module` must be callable.')
    out += self.base_module(x)
  return out
```

But the problem here is that it is essentially calculating … So essentially, my question could be simplified as follows: given a reparametrization function, …
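(Not from the original post: as a concrete reference point for the discussion, here is a minimal sketch of what a generic reparametrization wrapper could look like in nnx. The class, parameter names, and the `tanh` transform are my own illustration.)

```python
import jax
import jax.numpy as jnp
from flax import nnx

class Reparametrized(nnx.Module):
  """Illustrative wrapper: stores a raw kernel and applies `f` before use."""

  def __init__(self, din: int, dout: int, f, *, rngs: nnx.Rngs):
    key = rngs.params()
    self.raw_kernel = nnx.Param(jax.random.normal(key, (din, dout)) * 0.01)
    self.f = f  # reparametrization function, e.g. jnp.tanh

  def __call__(self, x: jax.Array):
    # The effective weight is f(raw); grads w.r.t. raw_kernel flow through f.
    return x @ self.f(self.raw_kernel.value)

layer = Reparametrized(3, 2, jnp.tanh, rngs=nnx.Rngs(0))
y = layer(jnp.ones((4, 3)))
```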
-
Hi @aniquetahir, to get a gradient w.r.t. any substate you can pass a …
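(The rest of this reply was cut off. For context, here is a hedged sketch of what a filtered gradient looks like in recent Flax NNX, using `nnx.DiffState` to select the substate to differentiate. The `LoRA` constructor arguments and the `LoRAParam` filter are my own example and may differ across Flax versions.)

```python
import jax.numpy as jnp
from flax import nnx

model = nnx.LoRA(in_features=3, lora_rank=2, out_features=2, rngs=nnx.Rngs(0))
x = jnp.ones((4, 3))

def loss_fn(model):
  return jnp.mean(model(x) ** 2)

# Differentiate only w.r.t. the substate matched by the filter (LoRAParam);
# all other state is treated as non-differentiable.
diff_state = nnx.DiffState(0, nnx.LoRAParam)
grads = nnx.grad(loss_fn, argnums=diff_state)(model)
```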
@aniquetahir can you create a separate optimizer for `new_model` (maybe call it `sampled_model`) at the beginning and then simply update it after sampling its params? E.g.:
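(The example that followed wasn't captured. Below is a hedged sketch of the suggested setup, assuming a simple `Linear` stand-in for the sampled model and a placeholder `sample_params` function; note the `nnx.Optimizer` API has changed slightly across Flax versions.)

```python
import jax
import jax.numpy as jnp
import optax
from flax import nnx

model = nnx.Linear(3, 2, rngs=nnx.Rngs(0))          # source of sampled params
sampled_model = nnx.Linear(3, 2, rngs=nnx.Rngs(1))  # gets its own optimizer
optimizer = nnx.Optimizer(sampled_model, optax.adam(1e-3))

def sample_params(params):
  # Placeholder for whatever sampling / reparametrization is applied.
  return jax.tree_util.tree_map(lambda p: p, params)

# After sampling, write the sampled params into `sampled_model` in place.
nnx.update(sampled_model, sample_params(nnx.state(model, nnx.Param)))

# Then train `sampled_model` with its own optimizer as usual.
def loss_fn(m):
  return jnp.mean(m(jnp.ones((4, 3))) ** 2)

grads = nnx.grad(loss_fn)(sampled_model)
optimizer.update(grads)
```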