Migrate Gemma3DecoderLayer and Gemma3ScannableBlock to NNX #2439
100b132 to 2e56cb0
🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
2e56cb0 to dc4e6c6
🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This pull request migrates Gemma3DecoderLayer and Gemma3ScannableBlock from Flax Linen to NNX. The changes are well-structured and follow the imperative style of NNX. The core logic appears to be preserved correctly during the refactoring.
🔍 General Feedback
- The migration to NNX is clean and improves the code's structure by separating layer definition from execution.
- Good job on updating the necessary wrappers and configurations in decoders.py to accommodate the new NNX-based layers.
- I've added a couple of minor suggestions to clean up unused imports and improve docstring consistency.
f5ae73f to c7a12b6
c7a12b6 to b90252e
The before/after JetStream accuracy differs slightly for gemma3: https://diff.googleplex.com/#key=1I4o3eENf3aa
b90252e to 3a140b7
Thanks for the new test results! The JetStream accuracy diff could be due to a missing enable_dropout=False in the config, which is out of scope for this PR. I see you increased the profiling time from 200ms to 1000ms; this might have contributed to the HLO matching.
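As a rough illustration of why a missing enable_dropout=False could move an accuracy number, here is a minimal pure-Python sketch of inverted dropout. It is not the MaxText or JetStream implementation; the dropout helper, its signature, and the sample values are assumptions made up for this example.

```python
import random


def dropout(values, rate, enable_dropout, rng):
  """Hypothetical inverted dropout over a list of floats."""
  if not enable_dropout or rate == 0.0:
    # Disabled: activations pass through unchanged.
    return list(values)
  keep = 1.0 - rate
  # Enabled: each value is zeroed with probability `rate`, survivors are rescaled.
  return [v / keep if rng.random() < keep else 0.0 for v in values]


activations = [0.5, -1.0, 2.0, 0.25]

# With dropout disabled, decoding sees the activations exactly as computed.
eval_out = dropout(activations, rate=0.5, enable_dropout=False, rng=random.Random(0))

# With dropout left on, each activation is either zeroed or rescaled.
train_out = dropout(activations, rate=0.5, enable_dropout=True, rng=random.Random(0))
```

If dropout were active during decoding, outputs would be perturbed relative to a run with it disabled, which is one way a small accuracy diff could appear independently of the NNX migration.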
LGTM!
Description
Tests
Training tests
Decoding tests
Noticed less memory usage (~10GB) in decoding in RAMstats.
== Updated
JetStream tests
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.