Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions keras/src/backend/jax/distribution_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,120 @@ def _to_backend_layout(tensor_layout):
partition_spec = jax.sharding.PartitionSpec(*tensor_layout.axes)
jax_mesh = tensor_layout.device_mesh.backend_mesh
return jax.sharding.NamedSharding(jax_mesh, partition_spec)


def _distribute_initializer(
    init_func=None, mean=0.0, stddev=1.0, seed=None, layout=None
):
    """Distribution-aware token embedding initializer for JAX backend.

    Draws a random array with a `jax.random` function inside `jax.jit`
    and, when a layout is given, requests that layout as the output
    sharding of the jitted call.

    Args:
        init_func: A `functools.partial` object whose wrapped function is
            named `normal`, `truncated_normal` or `uniform` (the
            supported `jax.random` functions). Shape and dtype must
            already be bound via partial; it is called with the PRNG key
            as its only argument.
        mean: Mean of the distribution. Applied to `normal` and
            `truncated_normal` samples only.
        stddev: Standard deviation of the distribution. Applied to
            `normal` and `truncated_normal` samples only.
        seed: Random seed for initialization. Either an int, a
            `jax.Array` (its first element is used), or a
            `SeedGenerator`.
        layout: `TensorLayout` for the distributed tensor. When `None`, a
            warning is emitted and no output sharding is requested.

    Returns:
        A jax array holding the sample (scaled by `stddev` and shifted by
        `mean` for `normal`/`truncated_normal`).

    Raises:
        ValueError: If `init_func` is `None`, if the function wrapped by
            `init_func` is not one of the supported random functions, or
            if `seed` is not an int, JAX array or `SeedGenerator`.
        TypeError: If `init_func` is not a `functools.partial` object.
    """
    import warnings
    from functools import partial

    supported_names = ("normal", "truncated_normal", "uniform")

    # Validate `init_func` first, and in None -> type -> name order, so a
    # bad call fails with the documented exception before any
    # SeedGenerator is created.
    if init_func is None:
        raise ValueError(
            "init_func cannot be None. "
            "Supported jax.random funcs: normal, truncated_normal, uniform"
        )
    if not isinstance(init_func, partial):
        raise TypeError(
            "init_func must be functools.partial object, got "
            f"{type(init_func)}. init_func is a jax.random.* function "
            "with shape and dtype bound via partial"
        )
    # Callables without `__name__` (e.g. nested partials) are reported
    # through their repr instead of raising AttributeError.
    func_name = getattr(init_func.func, "__name__", repr(init_func.func))
    if func_name not in supported_names:
        raise ValueError(
            f"Unsupported initializer: {func_name}. "
            "Only JAX-compatible random initializers are supported. "
            "Supported jax.random funcs: normal, truncated_normal, uniform"
        )

    # Create SeedGenerator to ensure backend variable exists
    # For future state tracking for distributed keys, add
    # attributes for base/split keys and number of devices sharded.
    if isinstance(seed, jax.Array):
        seed_gen = seed_generator.SeedGenerator(seed=int(seed[0]))
    elif isinstance(seed, int):
        seed_gen = seed_generator.SeedGenerator(seed=seed)
    elif isinstance(seed, seed_generator.SeedGenerator):
        seed_gen = seed
    else:
        raise ValueError(
            f"seed must be int, JAX array, or SeedGenerator, got {type(seed)}"
        )

    # Extract the state value as JAX array
    jax_seed = seed_gen.state.value

    # Convert to JAX PRNG key format (swap counter and seed value)
    jax_compatible_seed = jax.numpy.array(
        [jax_seed[1], jax_seed[0]], dtype=jax.numpy.uint32
    )

    # Shard based on tensor layout
    if layout is None:
        warnings.warn(
            f"The layout is {layout}, sharding will default to single device",
            stacklevel=2,
        )
        sharding = None
    else:
        sharding = _to_backend_layout(layout)

    def _draw(key):
        # The key is passed straight to the bound jax.random.* function.
        return init_func(key)

    try:
        sample = jax.jit(_draw, out_shardings=sharding)(jax_compatible_seed)
    except RuntimeError as e:
        if sharding is None:
            # No sharding was requested, so there is nothing to fall back
            # to; retrying the identical call would only repeat the error.
            raise
        warnings.warn(
            f"Sharding failed due to: {e}, falling back to single device",
            stacklevel=2,
        )
        sample = jax.jit(_draw, out_shardings=None)(jax_compatible_seed)

    # Advance the generator for state tracking; the returned value is not
    # needed here because the key was already read from the state above.
    seed_gen.next()

    # Dispatch on the same name that was validated above so every
    # accepted initializer returns an array (never an implicit None).
    if func_name == "uniform":
        # Uniform doesn't use mean/stddev - warn
        if mean != 0.0 or stddev != 1.0:
            warnings.warn(
                "mean and stddev are ignored for uniform distribution",
                stacklevel=2,
            )
        return sample

    # normal / truncated_normal: apply mean and stddev to the raw sample.
    return sample * stddev + mean