
Conversation


@buildwithsuhana buildwithsuhana commented Sep 26, 2025

This Pull Request introduces the foundational components for a new, backend-agnostic auto-sharding system in Keras, specifically designed for tensor parallelism. It establishes the core data structures and the JAX-specific implementation of communication primitives.

  1. Core Backend-Agnostic Abstractions
    The most significant part of this PR is the creation of a generic, backend-agnostic system for defining sharding plans. This logic resides in keras/src/distribution/tensor_parallel/tensor_layout.py.

  2. JAX-Specific Backend Implementation
    This PR provides the first backend-specific implementation of the required distributed communication primitives, targeting the JAX backend.

Design Document: Autosharding for Keras

Example usage: https://colab.research.google.com/drive/1UAINIcstDuO0aeA9lxCF5LaIj5ne5X5z?resourcekey=0-pPF4COO19KRoqS5cpWNILA&usp=sharing

The full tensor-parallel code for Keras has been divided into 4 PRs; this is the first of them.

@gemini-code-assist
Contributor

Summary of Changes

Hello @buildwithsuhana, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request lays the fundamental groundwork for introducing backend-agnostic auto-sharding and Tensor Parallelism into Keras 3.0. It establishes a modular and extensible architecture by defining core data structures, abstracting distributed backend functionalities, and providing high-level communication primitives. This initial set of changes is crucial for enabling future capabilities that will allow users to train very large models across multiple devices with significantly simplified code.

Highlights

  • Core Distributed Backend Abstraction: Introduced BaseDistributedBackend as an abstract interface for distributed operations and a get_distributed_backend factory function to provide a unified, backend-agnostic way to interact with JAX, TensorFlow, PyTorch, and NumPy distributed environments.
  • High-Level Communication Primitives: Defined AllReduceKeras, AllGatherKeras, BroadcastKeras, and ScatterKeras classes, which serve as high-level wrappers for essential collective communication operations required for tensor parallelism.
  • Tensor Sharding Actions: Implemented StateActionKeras as an abstract base class for defining how tensors are transformed for distribution. Concrete implementations like SplitKeras handle tensor sharding, while GatherKeras and SumKeras define how to reconstruct original tensors from their distributed parts.
  • Sharding Plan Configuration: Introduced the ConfigKeras dataclass to store and manage model-wide sharding rules and output configurations, including a mechanism to dynamically create collective operations based on these rules.
  • Tensor Parallel Communicator: Added TensorParallelCommunicator to orchestrate complex communication patterns for tensor parallelism, including specific methods for handling forward and backward passes in column-parallel and row-parallel operations, along with gradient slicing logic.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature | Command | Description
Code Review | /gemini review | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state.
Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help | /gemini help | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request lays a solid foundation for tensor parallelism in Keras by introducing backend-agnostic abstractions for distributed operations and core data structures for sharding. The overall design is well-structured, separating concerns between backend-specific implementations, communication primitives, and configuration. However, there are several areas that need attention, particularly regarding the correctness of some backend implementations (especially JAX), placeholder logic, API clarity, and code consistency. Addressing these points will strengthen the foundation and prevent issues in future development.


codecov-commenter commented Sep 26, 2025

Codecov Report

❌ Patch coverage is 72.72727% with 12 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.55%. Comparing base (5ae5503) to head (9e7f873).
⚠️ Report is 31 commits behind head on master.

Files with missing lines | Patch % | Lines
keras/src/backend/jax/core.py | 27.27% | 8 Missing ⚠️
keras/api/_tf_keras/keras/distribution/__init__.py | 0.00% | 1 Missing ⚠️
keras/src/backend/jax/distribution_lib.py | 50.00% | 1 Missing ⚠️
keras/src/distribution/distribution_lib.py | 66.66% | 1 Missing ⚠️
.../src/distribution/tensor_parallel/tensor_layout.py | 96.00% | 0 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21697      +/-   ##
==========================================
- Coverage   82.59%   82.55%   -0.05%     
==========================================
  Files         572      574       +2     
  Lines       58327    58932     +605     
  Branches     9131     9224      +93     
==========================================
+ Hits        48177    48651     +474     
- Misses       7818     7935     +117     
- Partials     2332     2346      +14     
Flag | Coverage Δ
keras | 82.35% <72.72%> (-0.05%) ⬇️
keras-jax | 63.12% <72.72%> (-0.19%) ⬇️
keras-numpy | 57.70% <31.81%> (+0.05%) ⬆️
keras-openvino | 34.39% <31.81%> (+0.07%) ⬆️
keras-tensorflow | 63.99% <31.81%> (-0.06%) ⬇️
keras-torch | 63.54% <31.81%> (-0.10%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.


Collaborator

@JyotinderSingh JyotinderSingh left a comment


I've added a few initial comments and questions during my first look.

To make the review more manageable, I propose we split this change up. At almost 1,800 lines, the current change is quite difficult to review properly. What do you think about limiting this PR to just the JAX backend, and introducing the others in subsequent, smaller PRs?

Collaborator

@hertschuh hertschuh left a comment


Thank you for the PR!

Some high level comments:

  • Out of context, it's really hard for me to understand why these abstractions are needed for Tensor Parallel.
    • Why do we need all these primitives?
    • Why do we need 3 layers of abstraction for the same concepts: the communications layer, the state_actions layer and the keras.distributed.get_communication_ops layer? Can we just have one?
  • These abstractions look Torch-like rather than JAX-like. On JAX you never have to manually split and do an all-gather; you simply shard. You never have to explicitly do a "collective sum": you just do a sum, and if the tensors are sharded, it will magically do all the needed collectives for you (see the sketch after this list). So it's unclear to me why any of these are needed for JAX.
  • I wouldn't export these symbols that you added to keras.distributed, I don't think they are needed. What we'll expose is the "Tensor Parallel" API.
  • For better or worse, we don't do type annotations in Keras. Unfortunately, mixing annotated and unannotated code doesn't work well, so it's better to not have any type annotations at all.
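
For illustration, a minimal single-host JAX sketch of the "just shard, then compute" model described above (the mesh axis name and array shapes are illustrative, not taken from the PR):

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("model",))

# One row per device; shard the leading axis across the "model" mesh axis.
x = jnp.arange(8.0 * devices.size).reshape(devices.size, 8)
x_sharded = jax.device_put(x, NamedSharding(mesh, P("model", None)))

# A plain reduction: no explicit all-reduce is written anywhere; XLA inserts
# the collectives needed to sum across the sharded axis.
total = jnp.sum(x_sharded)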

Comment on lines +556 to +558
sum_val = lax.psum(x, axis_name=axis_name)
axis_size = lax.psum(1, axis_name=axis_name)
return sum_val / axis_size
Collaborator


Why not use jax.lax.pmean?
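
For reference, a small sketch showing that the manual psum / psum pattern above computes the same result as jax.lax.pmean (assumes a standard JAX install; runs on however many local devices are available):

import jax
import jax.numpy as jnp
from jax import lax

def mean_manual(x):
    # Sum across devices, then divide by the number of participants.
    total = lax.psum(x, axis_name="i")
    size = lax.psum(1, axis_name="i")
    return total / size

def mean_builtin(x):
    return lax.pmean(x, axis_name="i")

n = jax.local_device_count()
x = jnp.arange(float(n))
print(jax.pmap(mean_manual, axis_name="i")(x))
print(jax.pmap(mean_builtin, axis_name="i")(x))  # same per-device result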

The gathered tensor, which will have a larger size along `axis`
dimension.
"""
return lax.all_gather(x, axis_name=axis_name, axis=axis, tiled=True)
Collaborator


I have a general question. Is autosharding supposed to work for multi-host distribution or only single host?
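
For context on the single-host vs. multi-host distinction, these standard JAX queries show what the runtime sees (multi-host runs additionally require jax.distributed.initialize() in every process):

import jax

print(jax.process_count())       # participating hosts/processes (1 on a single host)
print(jax.local_device_count())  # devices attached to this host
print(jax.device_count())        # devices across all hosts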

Comment on lines +34 to +35
int: The total number of devices configured in the current distribution
strategy.
Collaborator


The documentation line 31 is correct. The documentation here is misleading. Maybe there is no current distribution strategy configured yet. This API works regardless of whether there is a distribution strategy.

   Returns: int representing the number of JAX devices.

Comment on lines +45 to +46
Returns the total number of devices (e.g., GPUs, TPUs) available for the
current distribution strategy.
Collaborator


In the same way as list_devices, this is unrelated to any distribution strategy, it is purely wrapping the backend implementation.

Also:

  • please add device_type like list_devices for consistency (I should have caught this earlier); a possible signature is sketched after this comment.
  • format-wise, we always do this so that the documentation is formatted properly:
"""Single line documenting the function / class.

More details after a blank line.
"""

Comment on lines +48 to +49
Returns:
int: The total number of devices configured in the current distribution
Collaborator


Same comment, don't mention "distribution strategy".

Comment on lines +10 to +13
It implements sharding by slicing a tensor along one of its axes.
It handles cases where the dimension size is not perfectly divisible by the
number of workers by distributing the remainder elements one by one to the
first few workers.
Collaborator


nitpick: this class is purely a "smart split" that can handle non-divisible dimensions gracefully, but it doesn't actually shard. So mentioning "sharding" and "workers" here is a bit misleading.

Additionally, we shouldn't use the term "workers" because it's usually used for multiple processes or multiple hosts, but what we're doing here is multiple devices connected to a single host, and there are no "workers", at least not on JAX.
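
As a concrete illustration of the "distribute the remainder to the first few slices" behavior the docstring describes, NumPy's array_split does the same thing:

import numpy as np

parts = np.array_split(np.arange(10), 3)
print([len(p) for p in parts])  # [4, 3, 3]: the first slice absorbs the remainder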

elif sharding_type == "column":
self.dim = 1

def __call__(self, tensor, rank):
Collaborator


Actually, why is this a class and not just a function?
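
For illustration, a rough sketch of the function-based alternative (the name, signature, and NumPy-based body are hypothetical, not the PR's API):

import numpy as np

def split_slice(tensor, index, device_count, dim=-1):
    """Return the `index`-th of `device_count` slices of `tensor` along `dim`,
    giving the remainder to the first slices when the split is uneven."""
    return np.array_split(tensor, device_count, axis=dim)[index]

x = np.arange(20).reshape(10, 2)
print([split_slice(x, i, 3, dim=0).shape for i in range(3)])  # [(4, 2), (3, 2), (3, 2)]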

Args:
tensor: The full tensor to be sharded.
rank: The rank of the worker for which to get the shard.
Collaborator


rank is a bit confusing here, normally we only use the term for the rank of a tensor. How about index?

Same comment about the term worker. Just say index: The index of the slice to return.

worker rank, handling uneven distributions gracefully.
Args:
tensor: The full tensor to be sharded.
Collaborator


Don't use the term "sharded", the tensor is not getting sharded, only split.

tensor: the full tensor to get a slice of.

device_count: The total number of workers/shards.
dim: The dimension along which to split the tensor. If -1, the
last dimension is used.
sharding_type: If `dim` is -1, this can be 'row' (dim=0) or
Collaborator


Why have both sharding_type and dim? It makes things more complicated. Simply, dim should be one of 0, 1 or -1.

from keras.src import ops


class Split:
Collaborator


Actually, looking at it again, I thought of two things:

  1. This exists as a NumPy API: array_split, so we should just make this a normal keras op that can be used for other purposes.
  2. Support for uneven shards is limited. Intermediate values can be sharded unevenly, but the outputs of a jitted function must be evenly sharded. Can you make sure this is not a blocker for this project?
    [1] Silent data replication when using shard constraints jax-ml/jax#26946 (comment)
    [2] https://docs.jax.dev/en/latest/_autosummary/jax.lax.with_sharding_constraint.html
