Core Data Structures & Communication Primitives for Tensor Parallel for Keras #21697
base: master
Conversation
Summary of Changes
Hello @buildwithsuhana, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request lays the fundamental groundwork for introducing backend-agnostic auto-sharding and Tensor Parallelism into Keras 3.0. It establishes a modular and extensible architecture by defining core data structures, abstracting distributed backend functionalities, and providing high-level communication primitives. This initial set of changes is crucial for enabling future capabilities that will allow users to train very large models across multiple devices with significantly simplified code.
Code Review
This pull request lays a solid foundation for tensor parallelism in Keras by introducing backend-agnostic abstractions for distributed operations and core data structures for sharding. The overall design is well-structured, separating concerns between backend-specific implementations, communication primitives, and configuration. However, there are several areas that need attention, particularly regarding the correctness of some backend implementations (especially JAX), placeholder logic, API clarity, and code consistency. Addressing these points will strengthen the foundation and prevent issues in future development.
Codecov Report
❌ Patch coverage is
Additional details and impacted files:
@@ Coverage Diff @@
## master #21697 +/- ##
==========================================
- Coverage 82.59% 82.55% -0.05%
==========================================
Files 572 574 +2
Lines 58327 58932 +605
Branches 9131 9224 +93
==========================================
+ Hits 48177 48651 +474
- Misses 7818 7935 +117
- Partials 2332 2346 +14
☔ View full report in Codecov by Sentry.
I've added a few initial comments and questions during my first look.
To make the review more manageable, I propose we split this change up. At almost 1,800 lines, the current change is quite difficult to review properly. What do you think about limiting this PR to just the JAX backend, and introducing the others in subsequent, smaller PRs?
Thank you for the PR!
Some high level comments:
- Out of context, it's really hard for me to understand why these abstractions are needed for Tensor Parallel.
- Why do we need all these primitives?
- Why do we need 3 layers of abstraction for the same concepts: the communications layer, the state_actions layer, and the keras.distributed.get_communication_ops layer? Can we just have one?
- These abstractions look Torch-like and not JAX-like. On JAX you never have to manually split and do an all-gather, you simply shard. You never have to explicitly do a "collective sum": you just do a sum, and if the tensors are sharded, it will magically do all the needed collectives for you (see the sketch after this list). So it's unclear to me why any of these are needed for JAX.
- I wouldn't export these symbols that you added to keras.distributed, I don't think they are needed. What we'll expose is the "Tensor Parallel" API.
- For better or worse, we don't do type annotations in Keras. And unfortunately, mixing code with type annotations and code without type annotations doesn't work well. It's better to not have any type annotations at all.
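A minimal sketch of this point (not code from the PR; the mesh axis name and array shapes are arbitrary): with jax.jit and a NamedSharding, a plain jnp.sum over a sharded array is enough, and XLA inserts the required collectives.

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # 1D mesh over all local devices; "model" is an arbitrary axis name.
    mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

    # A toy array whose leading dimension is sharded across the mesh axis.
    n = len(jax.devices())
    x = jax.device_put(
        jnp.arange(float(n * 8)).reshape(n, 8),
        NamedSharding(mesh, P("model", None)),
    )

    @jax.jit
    def total(a):
        # No explicit collective: XLA adds the cross-device reduction itself.
        return jnp.sum(a)

    print(total(x))  # Same value as the unsharded sum.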
    sum_val = lax.psum(x, axis_name=axis_name)
    axis_size = lax.psum(1, axis_name=axis_name)
    return sum_val / axis_size
Why not use jax.lax.pmean?
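For reference, a minimal sketch of the two equivalent forms, assuming the function runs under pmap (or shard_map) with the given axis_name:

    import jax
    import jax.numpy as jnp
    from jax import lax

    def mean_via_psum(x, axis_name):
        # Pattern currently in the diff: sum, then divide by the axis size.
        sum_val = lax.psum(x, axis_name=axis_name)
        axis_size = lax.psum(1, axis_name=axis_name)
        return sum_val / axis_size

    def mean_via_pmean(x, axis_name):
        # Single collective computing the same mean.
        return lax.pmean(x, axis_name=axis_name)

    # Both give identical results, e.g. under pmap:
    vals = jnp.arange(float(jax.local_device_count()))
    out = jax.pmap(lambda v: mean_via_pmean(v, "i"), axis_name="i")(vals)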
        The gathered tensor, which will have a larger size along `axis`
            dimension.
        """
        return lax.all_gather(x, axis_name=axis_name, axis=axis, tiled=True)
I have a general question. Is autosharding supposed to work for multi-host distribution or only single host?
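As an aside on the snippet quoted above (illustrative only, run under pmap): with tiled=False the per-device slices are stacked along a new axis, while tiled=True concatenates them along the existing axis.

    import jax
    import jax.numpy as jnp
    from jax import lax

    n = jax.local_device_count()
    x = jnp.arange(float(n * 2)).reshape(n, 2)  # one (2,)-shaped slice per device

    stacked = jax.pmap(
        lambda v: lax.all_gather(v, axis_name="i"), axis_name="i"
    )(x)  # per-device result has shape (n, 2): a new leading axis

    tiled = jax.pmap(
        lambda v: lax.all_gather(v, axis_name="i", axis=0, tiled=True),
        axis_name="i",
    )(x)  # per-device result has shape (2 * n,): concatenated along axis 0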
        int: The total number of devices configured in the current distribution
            strategy.
The documentation on line 31 is correct. The documentation here is misleading: maybe there is no distribution strategy configured yet. This API works regardless of whether there is a distribution strategy.
Returns: int representing the number of JAX devices.
    Returns the total number of devices (e.g., GPUs, TPUs) available for the
    current distribution strategy.
In the same way as list_devices, this is unrelated to any distribution strategy, it is purely wrapping the backend implementation.
Also:
- please add device_type like list_devices for consistency (I should have caught this earlier).
- format-wise we always do this so that the documentation is formatted properly:

    """Single line documenting the function / class.

    More details after a blank line.
    """

    Returns:
        int: The total number of devices configured in the current distribution
Same comment, don't mention "distribution strategy".
    It implements sharding by slicing a tensor along one of its axes.
    It handles cases where the dimension size is not perfectly divisible by the
    number of workers by distributing the remainder elements one by one to the
    first few workers.
nitpick: this class is purely a "smart split" that can handle non-divisible dimensions gracefully, but it doesn't actually shard. So mentioning "sharding" and "workers" here is a bit misleading.
Additionally, we shouldn't use the term "workers" because it usually refers to multiple processes or multiple hosts, but what we're doing here is multiple devices connected to a single host, and there are no "workers", at least not on JAX.
    elif sharding_type == "column":
        self.dim = 1

    def __call__(self, tensor, rank):
Actually, why is this a class and not just a function?
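A hypothetical sketch of the function form (the name split_slice and its signature are illustrative, not the PR's API):

    import numpy as np

    def split_slice(tensor, index, device_count, dim=-1):
        """Return the index-th of device_count slices of tensor along dim.

        If the dimension is not evenly divisible, the first few slices get one
        extra element each, mirroring numpy.array_split semantics.
        """
        size = tensor.shape[dim]
        base, remainder = divmod(size, device_count)
        start = index * base + min(index, remainder)
        stop = start + base + (1 if index < remainder else 0)
        slices = [slice(None)] * tensor.ndim
        slices[dim] = slice(start, stop)
        return tensor[tuple(slices)]

    # Example: 7 elements over 3 slices -> sizes 3, 2, 2.
    print([split_slice(np.arange(7), i, 3, dim=0).tolist() for i in range(3)])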
    Args:
        tensor: The full tensor to be sharded.
        rank: The rank of the worker for which to get the shard.
rank is a bit confusing here, normally we only use the term for the rank of a tensor. How about index?
Same comment about the term worker. Just say index: The index of the slice to return.
    worker rank, handling uneven distributions gracefully.

    Args:
        tensor: The full tensor to be sharded.
Don't use the term "sharded", the tensor is not getting sharded, only split.
tensor: the full tensor to get a slice of.
    device_count: The total number of workers/shards.
    dim: The dimension along which to split the tensor. If -1, the
        last dimension is used.
    sharding_type: If `dim` is -1, this can be 'row' (dim=0) or
Why have both sharding_type and dim? It makes things more complicated. Simply, dim should be one of 0, 1 or -1.
    from keras.src import ops


    class Split:
Actually, looking at it again, I thought of two things:
- This exists as a NumPy API: array_split, so we should just make this a normal keras op that can be used for other purposes.
- Support for uneven shards is limited. Intermediate values can be sharded unevenly, but the outputs of a jitted function must be evenly sharded [1][2]. Can you make sure this is not a blocker for this project?
[1] Silent data replication when using shard constraints jax-ml/jax#26946 (comment)
[2] https://docs.jax.dev/en/latest/_autosummary/jax.lax.with_sharding_constraint.html
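For reference, a quick check of the numpy.array_split behavior mentioned above; it already hands the remainder elements to the first few slices:

    import numpy as np

    parts = np.array_split(np.arange(10), 4)
    # -> [0, 1, 2], [3, 4, 5], [6, 7], [8, 9]: the remainder (10 % 4 == 2)
    # goes to the first two slices, matching the behavior described in the
    # Split docstring.
    print([p.tolist() for p in parts])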
This Pull Request introduces the foundational components for a new, backend-agnostic auto-sharding system in Keras, specifically designed for tensor parallelism. It establishes the core data structures and the JAX-specific implementation of communication primitives.
Core Backend-Agnostic Abstractions
The most significant part of this PR is the creation of a generic, backend-agnostic system for defining sharding plans. This logic resides in keras/src/distribution/tensor_parallel/tensor_layout.py.
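Purely as an illustration of what a sharding plan typically captures (the structure and names below are hypothetical, not taken from tensor_layout.py):

    # Hypothetical example: map parameter-name patterns to the axis they are
    # split along and the collective needed to reassemble layer outputs.
    sharding_plan = {
        r".*dense.*kernel": {"dim": 1, "output_collective": "all_gather"},
        r".*dense.*bias": {"dim": 0, "output_collective": "all_gather"},
    }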
JAX-Specific Backend Implementation
This PR provides the first backend-specific implementation of the required distributed communication primitives, targeting the JAX backend.
Design Document: Autosharding for Keras
Example usage: https://colab.research.google.com/drive/1UAINIcstDuO0aeA9lxCF5LaIj5ne5X5z?resourcekey=0-pPF4COO19KRoqS5cpWNILA&usp=sharing
The full code of Tensor Parallel for Keras has been divided into four PRs; this is the first of them.