Extended fix OOM Issue #21634 on Keras side #21755
Conversation
Summary of Changes (Gemini Code Assist): This pull request implements a fix to prevent Out-Of-Memory errors during the initialization of very large Keras models, particularly on resource-constrained TPU devices. By introducing a mechanism to distribute the initialization of token embeddings, it ensures that model weights are sharded across available devices from the outset, significantly reducing memory consumption and enabling the successful loading of models that previously failed.
Code Review
This pull request introduces a new function, _distribute_initializer, to handle the distribution of token embedding initializers in the JAX backend. This function aims to resolve OOM errors encountered during the initialization of large models on TPUs with limited HBM by sharding weights across TPU devices during instantiation. The code includes argument validation, sharding logic based on tensor layout, and application of mean/stddev for relevant distributions. The review focuses on error handling, code clarity, and adherence to the Keras API design guidelines.
Codecov Report
Additional details and impacted files:
@@            Coverage Diff             @@
##           master   #21755      +/-   ##
==========================================
- Coverage   82.63%   82.60%    -0.03%
==========================================
  Files         572      577        +5
  Lines       58555    59223      +668
  Branches     9153     9286      +133
==========================================
+ Hits        48385    48922      +537
- Misses       7843     7924       +81
- Partials     2327     2377       +50
Force-pushed from d33cb4e to b36b051.
Thank you for the investigation. This is indeed an issue. Somebody on the team is working on a fix that's generally applicable to all variables, so that you don't have to explicitly use the fix that you provided here.
@hertschuh - Thanks for the feedback and for taking the time to review, and for the context on the general solution. A few follow-up questions to help me understand the timeline:
The reason I ask: users are blocked on this today for 7B+ models on 8GB TPU devices.
Let me know if that's feasible.
Summary
Applies a distributed initialization fix to the model backbone to resolve OOM errors during initialization of 7B+ parameter models on 8GB TPU devices. This PR adds a helper function that distributes the initializers at instantiation time.
Issue
Token embedding initialization materializes the full embedding array at creation time, placing all weights on a single device.
Combined with the forward passes performed during backbone initialization, this causes a 2x to 3x memory spike and triggers OOM on TPUs with limited HBM.
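To make the spike concrete, here is a back-of-the-envelope calculation. The 7B-parameter, 8GB-HBM, 8-device figures come from this PR; the bfloat16 assumption (2 bytes per parameter) is ours:

```python
# Rough memory arithmetic for the OOM scenario described above.
# bfloat16 (2 bytes/param) is an assumption; the rest is from the PR text.
PARAMS = 7_000_000_000
BYTES_PER_PARAM = 2          # bfloat16
HBM_PER_DEVICE_GB = 8
NUM_DEVICES = 8

total_gb = PARAMS * BYTES_PER_PARAM / 1e9    # 14.0 GB if placed on one device
per_device_gb = total_gb / NUM_DEVICES       # 1.75 GB per device when sharded

# Unsharded: even before any 2x-3x spike, 14 GB already exceeds 8 GB of HBM.
fits_unsharded = total_gb <= HBM_PER_DEVICE_GB            # False
# Sharded: even a 3x spike on 1.75 GB (~5.25 GB) still fits within 8 GB.
fits_sharded = per_device_gb * 3 <= HBM_PER_DEVICE_GB     # True
```

This is why sharding must happen at instantiation: once the full array exists on one device, the OOM has already occurred.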
Solution
Implements a _distribute_initializer helper that wraps embedding initializers with an explicit TensorLayout, properly sharding weights across TPU devices during instantiation. Validated on an 8-device TPU: models that previously OOM'd during backbone initialization now load successfully.
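The essential difference between the two initialization orders can be illustrated with a small peak-allocation model (a toy illustration using numpy, not the PR's implementation; both function names are made up):

```python
# Toy model of peak allocation: materialize-then-shard vs shard-at-init.
# Element counts stand in for bytes of HBM; function names are hypothetical.
import numpy as np

def peak_elements_full_then_shard(shape, num_shards):
    """Initialize the full array first, then split it into shards."""
    full = np.zeros(shape)               # full array exists in one place
    shards = np.split(full, num_shards)  # sharding happens only afterwards
    return full.size                     # peak allocation: the full size

def peak_elements_shard_at_init(shape, num_shards):
    """Initialize only one shard, as each device would under a TensorLayout."""
    shard_shape = (shape[0] // num_shards,) + tuple(shape[1:])
    return np.zeros(shard_shape).size    # peak per device: one shard's size
```

With 8 shards, the per-device peak drops by 8x, which is what moves a 7B-parameter embedding from "cannot fit" to "fits with headroom" on 8GB devices.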
Reference
For memory profiling analysis, cache locality theory, validation logs, and alternative solutions considered, refer to: Doc
Related PR: keras-team/keras-hub#2441