Hi 👋,
I'm new to JAX and I'm finding it challenging to convert my models from PyTorch to Flax NNX, particularly the LSTMs. For example, the following model processes data in batches of 64 sequences, each with 553 timesteps and 1 feature, i.e. input shape (64, 553, 1). How would the same model be implemented in Flax NNX? I think the documentation would benefit from more examples like this to help new users.