Support fetching embedding tables to host and doing host lookups. #54

Merged: 1 commit, Jun 26, 2025
8 changes: 4 additions & 4 deletions recml/core/ops/embedding_ops.py
@@ -38,7 +38,7 @@ class SparsecoreParams:
"""Embedding parameters."""

feature_specs: Nested[FeatureSpec]
abstract_mesh: jax.sharding.AbstractMesh
mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh
data_axes: Sequence[str | None]
embedding_axes: Sequence[str | None]
sharding_strategy: str
@@ -53,11 +53,11 @@ def sparsecore_lookup(
return shard_map.shard_map(
functools.partial(
embedding.tpu_sparse_dense_matmul,
global_device_count=sparsecore_params.abstract_mesh.size,
global_device_count=sparsecore_params.mesh.size,
feature_specs=sparsecore_params.feature_specs,
sharding_strategy=sparsecore_params.sharding_strategy,
),
mesh=sparsecore_params.abstract_mesh,
mesh=sparsecore_params.mesh,
in_specs=(
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes),
@@ -90,7 +90,7 @@ def _emb_lookup_bwd(
feature_specs=sparsecore_params.feature_specs,
sharding_strategy=sparsecore_params.sharding_strategy,
),
mesh=sparsecore_params.abstract_mesh,
mesh=sparsecore_params.mesh,
in_specs=(
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
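The change in this file swaps the `abstract_mesh: jax.sharding.AbstractMesh` field for `mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh`, so the same lookup can be traced against an abstract mesh or bound to real devices. A minimal sketch of the two mesh flavors in plain JAX (axis names and sizes here are illustrative, not part of this PR):

```python
import jax

# Concrete mesh: bound to the physical devices visible to this process.
# `jax.make_mesh` is the constructor the PR switches to in partitioning.py.
mesh = jax.make_mesh((jax.device_count(),), ("data",))

# Abstract mesh: axis names and sizes only, no devices attached. Recent JAX
# releases expose it directly from a concrete mesh.
abstract_mesh = mesh.abstract_mesh

# Both report the same `.size`, which is what `sparsecore_lookup` now reads
# via `sparsecore_params.mesh.size` for `global_device_count`.
assert mesh.size == abstract_mesh.size
```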
65 changes: 37 additions & 28 deletions recml/core/training/partitioning.py
@@ -20,7 +20,6 @@

import flax.linen as nn
import jax
from jax.experimental import mesh_utils
import numpy as np


@@ -68,7 +67,7 @@ class DataParallelPartitioner(Partitioner):
"""Data parallel partitioner."""

def __init__(self, data_axis: str = "batch"):
self.mesh = jax.sharding.Mesh(jax.devices(), (data_axis,))
self.mesh = jax.make_mesh((jax.device_count(),), (data_axis,))
self.data_sharding = jax.sharding.NamedSharding(
self.mesh, jax.sharding.PartitionSpec(data_axis)
)
@@ -109,6 +108,12 @@ def partition_init(
self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None
) -> CreateStateFn:
with jax.sharding.use_mesh(self.mesh):
if abstract_batch is not None:
abstract_state = jax.eval_shape(init_fn, abstract_batch)
specs = nn.get_partition_spec(abstract_state)
self.state_sharding = jax.tree.map(
lambda x: jax.sharding.NamedSharding(self.mesh, x), specs
)
init_fn = jax.jit(init_fn, out_shardings=self.state_sharding)

def _wrapped_init(batch: PyTree) -> State:
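The six lines added to `partition_init` infer the state sharding from an abstract batch before jitting the initializer: `jax.eval_shape` produces abstract state, `nn.get_partition_spec` reads the `nn.Partitioned` annotations, and the resulting specs become `out_shardings` for the jitted init. A standalone sketch of the same pattern with a toy module (module, shapes, and axis names are illustrative, not recml code):

```python
import flax.linen as nn
import jax
import jax.numpy as jnp

mesh = jax.make_mesh((jax.device_count(),), ("batch",))

class Tiny(nn.Module):
  @nn.compact
  def __call__(self, x):
    # Annotate the kernel so nn.get_partition_spec can recover a PartitionSpec
    # later; (None, None) keeps it replicated on any device count.
    w = self.param(
        "w",
        nn.with_partitioning(nn.initializers.lecun_normal(), (None, None)),
        (x.shape[-1], 8),
    )
    return x @ w

def init_fn(batch):
  return Tiny().init(jax.random.PRNGKey(0), batch)

abstract_batch = jax.ShapeDtypeStruct((16, 4), jnp.float32)

# Trace initialization without allocating memory to get the abstract state.
abstract_state = jax.eval_shape(init_fn, abstract_batch)
# Pull PartitionSpecs off the nn.Partitioned leaves.
specs = nn.get_partition_spec(abstract_state)
# Turn them into concrete shardings and pin the real init to them.
state_sharding = jax.tree.map(
    lambda s: jax.sharding.NamedSharding(mesh, s), specs
)
state = jax.jit(init_fn, out_shardings=state_sharding)(
    jnp.zeros((16, 4), jnp.float32)
)
```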
@@ -145,12 +150,12 @@ class ModelParallelPartitioner(Partitioner):
This only works with multi-controller Jax, i.e. communications along the ICI
for TPUs. For scaling beyond a single TPU slice this needs to be extended to
support Megascale XLA or single-controller Pathways. Consider using T5X, Pax,
or Gemax for these use cases.
MaxText externally or Gemax internally for these use cases.

Note: This assumes that all axes of the inputs except the final one are used
for data parallelism while the final one is used for model parallelism.
This tends to work well for 2D and 3D torus topologies since network latency
tends to be much higher for the leading axes.
By default, all axes of the input are used for data parallelism. This results
in fully-sharded data-parallelism for ND topologies or data-parallelism for 1D
topologies. The range of axes can be configured using the `dp_axes` argument,
i.e. axes[:dp_axes] will be used for data parallelism.

IMPORTANT: `shard_inputs` operates on a per process batch. This means that the
input batch size on CPU must already be the per process batch size,
@@ -160,45 +165,49 @@ class ModelParallelPartitioner(Partitioner):

def __init__(
self,
axes: Sequence[tuple[str, int]],
axes: Sequence[tuple[str, int]] = (("batch", -1),),
dp_axes: int | None = None,
rules: Mapping[str, str] | None = None,
aot_compile: bool = False,
options: jax.stages.CompilerOptions | None = None,
devices: Sequence[jax.Device] | None = None,
):
if len(axes) < 2:
if not axes:
raise ValueError("At least one axis must be specified in `axes`.")
if dp_axes == 0:
raise ValueError(
"Data parallelism axes range must be positive or negative."
)

devices = devices if devices is not None else jax.devices()
axis_names = [axis for axis, _ in axes]
axis_sizes = [dim for _, dim in axes]
if any(dim <= 0 for dim in axis_sizes[1:]):
raise ValueError(
"`axes` cannot less than 2D, use data-parallel"
f" partitioner instead. Got axes: {axes}."
"All dimensions except the first in the axes must be positive"
f" integers. Got axes: {axes}."
)
if axis_sizes[0] == -1:
axis_sizes[0] = len(devices) // math.prod(axis_sizes[1:])

mesh_devices = mesh_utils.create_device_mesh([dim for _, dim, in axes])
self.mesh = jax.sharding.Mesh(mesh_devices, [axis for axis, _ in axes])
self.mesh = jax.make_mesh(axis_sizes, axis_names, devices=devices)
self.rules = rules
self.aot_compile = aot_compile
self.options = options

dp_axes, dp_dims = zip(*axes[:-1])
_, mp_dim = axes[-1]

if math.prod(dp_dims) % jax.process_count() != 0:
dp_axis_names, dp_axis_sizes = zip(*axes[:dp_axes])
num_processes = jax.process_count()
if math.prod(dp_axis_sizes) % num_processes != 0:
raise ValueError(
"The data parallel dimensions in the mesh must be divisible by the"
" number of processes as we assume data parallelism across"
f" processes. Got process count: {jax.process_count()} and data"
f" parallelism dimensions: {dp_dims} for axes: {axes} and mesh"
f" devices: {self.mesh.devices}."
)
if jax.local_device_count() % mp_dim != 0:
raise ValueError(
"The number of local devices on each host must be divisible by the"
" model dimension as we assume model parallelism across local"
f" devices. Got local device count: {jax.local_device_count()} and"
f" model parallelism dimension: {mp_dim} for axes: {axes} and mesh"
f" processes. Got process count: {num_processes} and data"
f" parallelism dimensions: {dp_axis_sizes} for axes: {axes} and mesh"
f" devices: {self.mesh.devices}."
)

self.data_sharding = jax.sharding.NamedSharding(
self.mesh, jax.sharding.PartitionSpec(dp_axes)
self.mesh, jax.sharding.PartitionSpec(dp_axis_names)
)
self.state_sharding = None
self.abstract_batch = None
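For a quick read on the new constructor semantics, this is how the partitioner is built under the updated defaults; it mirrors the test changes below and assumes the `recml.core.training.partitioning` import path from this diff:

```python
from recml.core.training import partitioning

# Default: a single ("batch", -1) axis; -1 resolves to the full device count,
# and with dp_axes=None every mesh axis is used for data parallelism.
dp = partitioning.ModelParallelPartitioner()

# Explicit 2D mesh: axes[:dp_axes] ("data") drives data parallelism and the
# remaining axis ("model") drives model parallelism. The -1 on the first axis
# resolves to len(devices) // prod(remaining axis sizes).
mp = partitioning.ModelParallelPartitioner(
    axes=[("data", -1), ("model", 1)], dp_axes=1
)
```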
4 changes: 2 additions & 2 deletions recml/core/training/partitioning_test.py
@@ -40,7 +40,7 @@ def test_data_parallelism(
self, partitioner_cls: type[partitioning.Partitioner]
):
if partitioner_cls is partitioning.ModelParallelPartitioner:
kwargs = {"axes": [("data", jax.device_count()), ("model", 1)]}
kwargs = {"axes": [("data", -1), ("model", 1)], "dp_axes": 1}
else:
kwargs = {}
partitioner = partitioner_cls(**kwargs)
@@ -113,7 +113,7 @@ def _eval_step(

def test_model_parallelism(self):
partitioner = partitioning.ModelParallelPartitioner(
axes=[("data", 1), ("model", jax.device_count())]
axes=[("data", 1), ("model", jax.device_count())], dp_axes=1
)

inputs = np.zeros((128, 16), dtype=np.float32)
34 changes: 20 additions & 14 deletions recml/examples/dlrm_experiment.py
@@ -99,19 +99,19 @@ class DLRMModel(nn.Module):
dcn_layers: int
dcn_inner_dim: int

# We need to track the embedder on the Flax module to ensure it is not
# re-created on cloning. It is not possible to create an embedder inside
# setup() because it is called lazily at compile time. The embedder needs
# We need to track the sparsecore config on the Flax module to ensure it is
  # not re-created on cloning. It is not possible to create a config inside
# setup() because it is called lazily at compile time. The config needs
# to be created before `model.init` so we can use it to create a preprocessor.
# A simpler pattern that works is passing `embedder` directly to the module.
_embedder: sparsecore.SparsecoreEmbedder | None = None
# A simpler pattern that works is passing the config directly to the module.
_sparsecore_config: sparsecore.SparsecoreConfig | None = None

@property
def embedder(self) -> sparsecore.SparsecoreEmbedder:
if self._embedder is not None:
return self._embedder
def sparsecore_config(self) -> sparsecore.SparsecoreConfig:
if self._sparsecore_config is not None:
return self._sparsecore_config

embedder = sparsecore.SparsecoreEmbedder(
sparsecore_config = sparsecore.SparsecoreConfig(
specs={
f.name: sparsecore.EmbeddingSpec(
input_dim=f.vocab_size,
@@ -123,8 +123,8 @@ def embedder(self) -> sparsecore.SparsecoreEmbedder:
},
optimizer=self.embedding_optimizer,
)
object.__setattr__(self, '_embedder', embedder)
return embedder
object.__setattr__(self, '_sparsecore_config', sparsecore_config)
return sparsecore_config

def bottom_mlp(self, inputs: Mapping[str, jt.Array]) -> jt.Array:
x = jnp.concatenate(
@@ -174,7 +174,7 @@ def __call__(
self, inputs: Mapping[str, jt.Array], training: bool = False
) -> jt.Array:
dense_embeddings = self.bottom_mlp(inputs)
sparse_embeddings = self.embedder.make_sparsecore_module()(inputs)
sparse_embeddings = sparsecore.SparsecoreEmbed(
self.sparsecore_config, name='sparsecore_embed'
)(inputs)
sparse_embeddings = jax.tree.flatten(sparse_embeddings)[0]
concatenated_embeddings = jnp.concatenate(
(dense_embeddings, *sparse_embeddings), axis=-1
@@ -239,11 +241,15 @@ def create_datasets(self) -> tuple[recml.data.Iterator, recml.data.Iterator]:
global_batch_size = self.train_data.global_batch_size
train_iter = recml.data.TFDatasetIterator(
dataset=self.train_data.make(),
postprocessor=self.model.embedder.make_preprocessor(global_batch_size),
postprocessor=sparsecore.SparsecorePreprocessor(
self.model.sparsecore_config, global_batch_size
),
)
eval_iter = recml.data.TFDatasetIterator(
dataset=self.eval_data.make(),
postprocessor=self.model.embedder.make_preprocessor(global_batch_size),
postprocessor=sparsecore.SparsecorePreprocessor(
self.model.sparsecore_config, global_batch_size
),
)
return train_iter, eval_iter

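One detail in the DLRM changes above that is easy to miss: the sparsecore config is stored as a field on the frozen Flax module and cached with `object.__setattr__`, so it is built once before `model.init` and is not re-created when Flax clones the module. A self-contained sketch of that caching pattern with a stand-in config class (names here are illustrative, not recml API):

```python
import dataclasses

import flax.linen as nn


@dataclasses.dataclass(frozen=True)
class ExpensiveConfig:
  # Stand-in for sparsecore.SparsecoreConfig; purely illustrative.
  vocab_size: int


class MyModel(nn.Module):
  vocab_size: int
  # Declared as a field (default None) so Flax copies it when cloning the
  # module instead of rebuilding it.
  _config: ExpensiveConfig | None = None

  @property
  def config(self) -> ExpensiveConfig:
    if self._config is not None:
      return self._config
    config = ExpensiveConfig(vocab_size=self.vocab_size)
    # nn.Module instances are frozen dataclasses; bypass __setattr__ to cache.
    object.__setattr__(self, "_config", config)
    return config

  @nn.compact
  def __call__(self, x):
    return x


model = MyModel(vocab_size=100)
cfg = model.config                   # built once, before `model.init`
assert model.clone()._config is cfg  # survives cloning, not re-created
```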