62 commits
a27367a
Added tensor parallel for keras (Part 1/3)
buildwithsuhana Sep 26, 2025
488cd8f
Removed unnecessary lines
buildwithsuhana Sep 26, 2025
71ddd1a
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
bc4e4e2
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
d4200b5
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
21f89a2
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
299bd45
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
da625e1
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
c233b8c
Fixing the failing test
buildwithsuhana Sep 26, 2025
7b8d733
Fixing the failing test
buildwithsuhana Sep 26, 2025
f825cd3
Fixing test
buildwithsuhana Sep 26, 2025
3725180
Adding tests for distributed_backends
buildwithsuhana Sep 29, 2025
a6c8a96
Modifications for failing tests
buildwithsuhana Sep 29, 2025
3fabfde
Modified for failing test
buildwithsuhana Sep 29, 2025
b133752
Modified for failing test
buildwithsuhana Sep 29, 2025
83c2e3f
Modified for failing test
buildwithsuhana Sep 29, 2025
3f3be6b
added debuggers
buildwithsuhana Sep 29, 2025
be325ab
removed debuggers
buildwithsuhana Sep 29, 2025
e1282ac
Merge branch 'keras-team:master' into Tensor_parallel_keras
buildwithsuhana Sep 29, 2025
fc11aaa
Removed the tensorflow, numpy and torch backends
buildwithsuhana Sep 30, 2025
ef6e2a0
Merge branch 'Tensor_parallel_keras' of https://github.com/buildwiths…
buildwithsuhana Sep 30, 2025
bea6ffa
Refactoring the code
buildwithsuhana Sep 30, 2025
4e00245
Refactoring the code
buildwithsuhana Sep 30, 2025
2f973b0
refactoring
buildwithsuhana Sep 30, 2025
bdb2b84
Adding necessary docstrings
buildwithsuhana Sep 30, 2025
d77fa71
Merge branch 'keras-team:master' into Tensor_parallel_keras
buildwithsuhana Oct 1, 2025
b9990b0
Removing redundancies
buildwithsuhana Oct 3, 2025
0aeee6f
Merge branch 'Tensor_parallel_keras' of https://github.com/buildwiths…
buildwithsuhana Oct 3, 2025
f784956
Modifying tests
buildwithsuhana Oct 3, 2025
8895a78
Reformatting
buildwithsuhana Oct 3, 2025
fe97f3b
Reformatting the code
buildwithsuhana Oct 3, 2025
77f01aa
Fixing failing tests
buildwithsuhana Oct 3, 2025
7080328
fixes
buildwithsuhana Oct 3, 2025
af711fd
Fixing tests
buildwithsuhana Oct 3, 2025
97dde17
formatting
buildwithsuhana Oct 3, 2025
f322a97
fixing test
buildwithsuhana Oct 3, 2025
5269ac9
fixing test
buildwithsuhana Oct 3, 2025
b9f36e9
Removing redundant lines
buildwithsuhana Oct 6, 2025
555e5c9
Refactoring to remove communications.py and state_action_keras.py
buildwithsuhana Oct 12, 2025
b80d264
formatting the files
buildwithsuhana Oct 12, 2025
93b1738
fixing skip issues
buildwithsuhana Oct 12, 2025
b7b2b9b
fixing test
buildwithsuhana Oct 12, 2025
f6c1142
fixing test
buildwithsuhana Oct 12, 2025
669c799
refactoring to remove distributed backend wrapper
buildwithsuhana Oct 13, 2025
cd20b9f
fixing test
buildwithsuhana Oct 13, 2025
cd0049f
making distributed backend more jax friendly
buildwithsuhana Oct 13, 2025
d1e4c69
Fixing comments
buildwithsuhana Oct 17, 2025
86e0557
Fixing comments
buildwithsuhana Oct 17, 2025
6c3883f
Fixing comments
buildwithsuhana Oct 17, 2025
3e31e1e
fixes
buildwithsuhana Oct 17, 2025
c99601e
Refactor
buildwithsuhana Oct 18, 2025
dbae56d
refactoring to resolve comments
buildwithsuhana Oct 18, 2025
2fc0f0e
fixes
buildwithsuhana Oct 18, 2025
174093c
fixes
buildwithsuhana Oct 18, 2025
7d18b0a
fix
buildwithsuhana Oct 18, 2025
f570925
fix
buildwithsuhana Oct 18, 2025
9e7f873
removing get_best_devices
buildwithsuhana Oct 21, 2025
5136091
fixing comments
buildwithsuhana Oct 26, 2025
8f40c53
Merge branch 'master' into Tensor_parallel_keras
buildwithsuhana Oct 26, 2025
08b8abe
fixing merge conflict
buildwithsuhana Oct 26, 2025
3a408da
Merge branch 'Tensor_parallel_keras' of https://github.com/buildwiths…
buildwithsuhana Oct 26, 2025
eb796ea
modifying variable name
buildwithsuhana Oct 26, 2025
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/distribution/__init__.py
@@ -15,6 +15,9 @@
distribute_tensor as distribute_tensor,
)
from keras.src.distribution.distribution_lib import distribution as distribution
from keras.src.distribution.distribution_lib import (
get_device_count as get_device_count,
)
from keras.src.distribution.distribution_lib import initialize as initialize
from keras.src.distribution.distribution_lib import list_devices as list_devices
from keras.src.distribution.distribution_lib import (
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
@@ -140,6 +140,7 @@
from keras.src.ops.numpy import argpartition as argpartition
from keras.src.ops.numpy import argsort as argsort
from keras.src.ops.numpy import array as array
from keras.src.ops.numpy import array_split as array_split
from keras.src.ops.numpy import average as average
from keras.src.ops.numpy import bartlett as bartlett
from keras.src.ops.numpy import bincount as bincount
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/numpy/__init__.py
@@ -26,6 +26,7 @@
from keras.src.ops.numpy import argpartition as argpartition
from keras.src.ops.numpy import argsort as argsort
from keras.src.ops.numpy import array as array
from keras.src.ops.numpy import array_split as array_split
from keras.src.ops.numpy import average as average
from keras.src.ops.numpy import bartlett as bartlett
from keras.src.ops.numpy import bincount as bincount
3 changes: 3 additions & 0 deletions keras/api/distribution/__init__.py
@@ -15,6 +15,9 @@
distribute_tensor as distribute_tensor,
)
from keras.src.distribution.distribution_lib import distribution as distribution
from keras.src.distribution.distribution_lib import (
get_device_count as get_device_count,
)
from keras.src.distribution.distribution_lib import initialize as initialize
from keras.src.distribution.distribution_lib import list_devices as list_devices
from keras.src.distribution.distribution_lib import (
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
@@ -140,6 +140,7 @@
from keras.src.ops.numpy import argpartition as argpartition
from keras.src.ops.numpy import argsort as argsort
from keras.src.ops.numpy import array as array
from keras.src.ops.numpy import array_split as array_split
from keras.src.ops.numpy import average as average
from keras.src.ops.numpy import bartlett as bartlett
from keras.src.ops.numpy import bincount as bincount
1 change: 1 addition & 0 deletions keras/api/ops/numpy/__init__.py
@@ -26,6 +26,7 @@
from keras.src.ops.numpy import argpartition as argpartition
from keras.src.ops.numpy import argsort as argsort
from keras.src.ops.numpy import array as array
from keras.src.ops.numpy import array_split as array_split
from keras.src.ops.numpy import average as average
from keras.src.ops.numpy import bartlett as bartlett
from keras.src.ops.numpy import bincount as bincount
56 changes: 56 additions & 0 deletions keras/src/backend/jax/core.py
@@ -1,5 +1,6 @@
import jax
import jax.experimental.sparse as jax_sparse
import jax.lax as lax
import jax.numpy as jnp
import ml_dtypes
import numpy as np
@@ -529,6 +530,61 @@ def remat(f):
return jax.checkpoint(f)


def all_reduce(x, op="sum", axis_name="model"):
"""
Performs an all-reduce operation across all replicas in the specified
distribution axis.

The all-reduce operation computes a reduction (like sum, mean, or product)
of the input tensor `x` across all devices/replicas in the `axis_name`
group, and then broadcasts the result back to all participating devices.

Args:
x: The tensor to reduce.
op: The reduction operation to perform. Supported options are
"sum" and "mean". Defaults to "sum".
axis_name: The name of the distribution axis (e.g., "model",
"data") over which to perform the reduction. Defaults to "model".

Returns:
The result of the all-reduce operation, with the same shape as the
input `x`.
"""
if op == "sum":
return lax.psum(x, axis_name=axis_name)
elif op == "mean":
return lax.pmean(x, axis_name=axis_name)
else:
raise ValueError(
f"Unsupported reduction operation: {op}. "
"Supported options are 'sum' and 'mean'."
)


def all_gather(x, axis, axis_name="model"):
"""
Performs an all-gather operation across all replicas in the specified
distribution axis.

The all-gather operation collects the input tensor `x` from all devices
in the `axis_name` group and concatenates them along the specified `axis`.
This is often used in tensor parallelism to combine parts of a tensor
distributed across devices.

Args:
x: The tensor to gather.
axis: The dimension along which to concatenate the gathered tensors.
axis_name: The name of the distribution axis (e.g., "model",
"data") over which to perform the gather.
Defaults to "model".

Returns:
The gathered tensor, which will have a larger size along the
`axis` dimension.
"""
return lax.all_gather(x, axis_name=axis_name, axis=axis, tiled=True)


class name_scope(base_name_scope):
def __init__(self, name, **kwargs):
super().__init__(name, **kwargs)
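A minimal usage sketch for these collectives (illustration only, not part of the diff): both functions must run inside a parallel context such as `jax.pmap`, with an `axis_name` matching the one given to `pmap`.

import functools

import jax
import jax.numpy as jnp

from keras.src.backend.jax.core import all_gather
from keras.src.backend.jax.core import all_reduce

@functools.partial(jax.pmap, axis_name="model")
def step(x):
    # Sum the per-device shards, then gather the reduced row from
    # every device along axis 0.
    total = all_reduce(x, op="sum", axis_name="model")
    return all_gather(total, axis=0, axis_name="model")

n = jax.local_device_count()
out = step(jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3))
# out has shape (n, n * 3); each row is identical across devices.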
78 changes: 78 additions & 0 deletions keras/src/backend/jax/core_test.py
@@ -1,3 +1,4 @@
import functools
import os

import jax
@@ -9,6 +10,8 @@
from keras.src import backend
from keras.src import testing
from keras.src.backend.config import is_nnx_enabled
from keras.src.backend.jax.core import all_gather
from keras.src.backend.jax.core import all_reduce

if is_nnx_enabled():
from flax import nnx
@@ -66,3 +69,78 @@ def test_keras_variable_nnx_split_merge_sync(self):
state = jax.tree.map(lambda x: x + 1, state)
variable2 = nnx.merge(graphdef, state)
self.assertEqual(variable2._value, variable2.value)


@pytest.mark.skipif(
backend.backend() != "jax",
reason="JAX backend specific test for collective operations.",
)
@pytest.mark.skipif(
jax.local_device_count() < 2,
reason="Requires multiple local devices for testing.",
)
class JaxCollectiveOpsTest(testing.TestCase):
def test_all_reduce_sum(self):
"""Tests the all_reduce operation with the 'sum' reduction."""
num_devices = jax.local_device_count()
local_value = 10.0

local_inputs = jax.numpy.array([local_value] * num_devices)

@functools.partial(
jax.pmap, axis_name="all", devices=jax.devices("cpu")
)
def reduce_sum_fn(x):
return all_reduce(x, op="sum", axis_name="all")

result = reduce_sum_fn(local_inputs)
expected_sum = local_value * num_devices

self.assertTrue(np.allclose(result, expected_sum))
self.assertEqual(result.shape, (num_devices,))

def test_all_reduce_mean(self):
"""Tests the all_reduce operation with the 'mean' reduction."""
num_devices = jax.local_device_count()
local_value = 10.0

local_inputs = jax.numpy.array([local_value] * num_devices)

@functools.partial(
jax.pmap, axis_name="all", devices=jax.devices("cpu")
)
def reduce_mean_fn(x):
return all_reduce(x, op="mean", axis_name="all")

result = reduce_mean_fn(local_inputs)
expected_mean = local_value

self.assertTrue(np.allclose(result, expected_mean))
self.assertEqual(result.shape, (num_devices,))

def test_all_gather(self):
"""Tests the all_gather operation."""
num_devices = jax.local_device_count()
local_data = np.arange(5)

local_inputs = jax.numpy.stack(
[local_data + (i * 5) for i in range(num_devices)]
)

@functools.partial(
jax.pmap, axis_name="all", devices=jax.devices("cpu")
)
def gather_fn(x):
return all_gather(x, axis=0, axis_name="all")

result_array_on_devices = gather_fn(local_inputs)

expected_shape = (num_devices, num_devices * local_data.shape[0])
self.assertEqual(result_array_on_devices.shape, expected_shape)

expected_gathered_data = np.arange(num_devices * local_data.shape[0])

for i in range(num_devices):
self.assertTrue(
np.allclose(result_array_on_devices[i], expected_gathered_data)
)
15 changes: 15 additions & 0 deletions keras/src/backend/jax/distribution_lib.py
@@ -27,6 +27,21 @@ def list_devices(device_type=None):
return [f"{device.platform}:{device.id}" for device in jax_devices]


def get_device_count(device_type=None):
"""Returns the number of available JAX devices.

Args:
device_type: Optional device type to count (e.g., "cpu", "gpu", "tpu").
If `None`, it counts all available devices.

Returns:
int: The total number of JAX devices for the specified type.
"""
device_type = device_type.lower() if device_type else None
jax_devices = jax.devices(backend=device_type)
return len(jax_devices)


def distribute_variable(value, layout):
"""Create a distributed variable for JAX.

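A quick sketch of the new public API (assuming the framework-level `keras.distribution.get_device_count`, exported in the `__init__.py` diffs above, mirrors this backend signature):

import keras

# Count every visible device, or filter by type.
total = keras.distribution.get_device_count()
cpus = keras.distribution.get_device_count("cpu")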
4 changes: 2 additions & 2 deletions keras/src/backend/jax/distribution_lib_test.py
@@ -29,8 +29,8 @@


@pytest.mark.skipif(
backend.backend() != "jax",
reason="Backend specific test",
backend.backend() != "jax" or len(jax.devices()) != 8,
reason="Backend specific test and requires 8 devices",
)
class JaxDistributionLibTest(testing.TestCase):
def _create_jax_layout(self, sharding):
4 changes: 4 additions & 0 deletions keras/src/backend/jax/numpy.py
@@ -1167,6 +1167,10 @@ def split(x, indices_or_sections, axis=0):
return jnp.split(x, indices_or_sections, axis=axis)


def array_split(x, indices_or_sections, axis=0):
return jnp.array_split(x, indices_or_sections, axis=axis)


def stack(x, axis=0):
x = [convert_to_tensor(t) for t in x]
return jnp.stack(x, axis=axis)
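Unlike `split`, `array_split` tolerates an axis size that does not divide evenly; the first `total % n` sections get one extra element. The NumPy semantics all backends follow:

import numpy as np

parts = np.array_split(np.arange(10), 3)
# -> [array([0, 1, 2, 3]), array([4, 5, 6]), array([7, 8, 9])]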
5 changes: 5 additions & 0 deletions keras/src/backend/numpy/numpy.py
@@ -1107,6 +1107,11 @@ def split(x, indices_or_sections, axis=0):
return np.split(x, indices_or_sections, axis=axis)


def array_split(x, indices_or_sections, axis=0):
axis = standardize_axis_for_numpy(axis)
return np.array_split(x, indices_or_sections, axis=axis)


def stack(x, axis=0):
axis = standardize_axis_for_numpy(axis)
dtype_set = set([getattr(a, "dtype", type(a)) for a in x])
67 changes: 67 additions & 0 deletions keras/src/backend/openvino/numpy.py
@@ -2014,6 +2014,73 @@ def split(x, indices_or_sections, axis=0):
)


def array_split(x, indices_or_sections, axis=0):
x = get_ov_output(x)

if not isinstance(indices_or_sections, int):
raise TypeError(
"Argument `indices_or_sections` must be of type `int`. "
f"Received: {indices_or_sections}"
)
if indices_or_sections <= 0:
raise ValueError(
"Argument `indices_or_sections` must be a positive integer. "
f"Received: {indices_or_sections}"
)

num_splits_val = indices_or_sections
num_splits = ov_opset.constant(
np.array(num_splits_val, dtype=np.int64)
).output(0)

axis_tensor = ov_opset.constant(np.array(axis, dtype=np.int64)).output(0)

zero_scalar = ov_opset.constant(np.array(0, dtype=np.int64)).output(0)

one_scalar = ov_opset.constant(np.array(1, dtype=np.int64)).output(0)

shape_tensor = ov_opset.shape_of(x, Type.i64).output(0)
axis_i64_vec = ov_opset.constant([axis], dtype=Type.i64).output(0)

total_size_tensor_vec = ov_opset.gather(
shape_tensor, axis_i64_vec, zero_scalar
).output(0)

total_size = ov_opset.squeeze(total_size_tensor_vec, zero_scalar).output(0)

split_size = ov_opset.divide(
total_size, num_splits, auto_broadcast="NUMPY"
).output(0)

remainder = ov_opset.mod(
total_size, num_splits, auto_broadcast="NUMPY"
).output(0)

splits_shape = ov_opset.constant([num_splits_val], dtype=Type.i64).output(0)
all_splits_base = ov_opset.broadcast(split_size, splits_shape).output(0)

range_splits = ov_opset.range(
zero_scalar,
num_splits,
one_scalar,
Type.i64,
).output(0)

remainder_bcast = ov_opset.broadcast(remainder, splits_shape).output(0)

add_one_mask = ov_opset.less(range_splits, remainder_bcast).output(0)

add_one_values = ov_opset.convert(add_one_mask, Type.i64).output(0)

split_lengths = ov_opset.add(all_splits_base, add_one_values).output(0)
splits = ov_opset.variadic_split(x, axis_tensor, split_lengths)

result = []
for i in range(num_splits_val):
result.append(OpenVINOKerasTensor(splits.output(i)))
return result


def stack(x, axis=0):
if isinstance(x, tuple):
x = list(x)
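The graph above reproduces NumPy's remainder rule with OpenVINO ops: `divide` and `mod` give the base section size and remainder, and the `less` mask adds one element to each of the first `remainder` sections. The same computation in plain Python (illustration only):

def split_lengths(total_size, num_splits):
    # Base size, plus one extra for the first `remainder` sections.
    base, remainder = divmod(total_size, num_splits)
    return [base + 1 if i < remainder else base for i in range(num_splits)]

assert split_lengths(10, 3) == [4, 3, 3]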
18 changes: 18 additions & 0 deletions keras/src/backend/tensorflow/numpy.py
@@ -2494,6 +2494,24 @@ def split(x, indices_or_sections, axis=0):
return tf.split(x, num_or_size_splits, axis=axis)


def array_split(x, indices_or_sections, axis=0):
x = tf.convert_to_tensor(x)
num_splits = indices_or_sections
total_size = tf.shape(x)[axis]
avg_size = tf.math.floordiv(total_size, num_splits)
remainder = tf.math.floormod(total_size, num_splits)

sizes = tf.concat(
[
tf.fill([remainder], avg_size + 1),
tf.fill([num_splits - remainder], avg_size),
],
axis=0,
)

return tf.split(x, sizes, axis=axis)


def stack(x, axis=0):
dtype_set = set([getattr(a, "dtype", type(a)) for a in x])
if len(dtype_set) > 1:
7 changes: 7 additions & 0 deletions keras/src/backend/torch/numpy.py
@@ -1539,6 +1539,13 @@ def split(x, indices_or_sections, axis=0):
return list(out)


def array_split(x, indices_or_sections, axis=0):
x = convert_to_tensor(x)
axis_int = int(axis)
out = torch.tensor_split(x, indices_or_sections, dim=axis_int)
return list(out)


def stack(x, axis=0):
x = [convert_to_tensor(elem) for elem in x]
return torch.stack(x, dim=axis)
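Here `torch.tensor_split` already implements the NumPy `array_split` remainder rule, so the torch backend can delegate directly. A quick check (assuming torch is installed):

import torch

parts = torch.tensor_split(torch.arange(10), 3)
# -> (tensor([0, 1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9]))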
1 change: 1 addition & 0 deletions keras/src/distribution/__init__.py
@@ -6,6 +6,7 @@
from keras.src.distribution.distribution_lib import TensorLayout
from keras.src.distribution.distribution_lib import distribute_tensor
from keras.src.distribution.distribution_lib import distribution
from keras.src.distribution.distribution_lib import get_device_count
from keras.src.distribution.distribution_lib import initialize
from keras.src.distribution.distribution_lib import list_devices
from keras.src.distribution.distribution_lib import set_distribution