Commit 8ea5b2c

Hilly12 authored and recml authors committed
Support fetching embedding tables to host and doing host lookups.
This also updates the partitioning API and allows for unboxing of variables in the data parallel partitioner.
PiperOrigin-RevId: 775898177
1 parent a129913 commit 8ea5b2c

6 files changed: +473 -207 lines changed

recml/core/ops/embedding_ops.py

Lines changed: 4 additions & 4 deletions

@@ -38,7 +38,7 @@ class SparsecoreParams:
  """Embedding parameters."""

  feature_specs: Nested[FeatureSpec]
-  abstract_mesh: jax.sharding.AbstractMesh
+  mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh
  data_axes: Sequence[str | None]
  embedding_axes: Sequence[str | None]
  sharding_strategy: str

@@ -53,11 +53,11 @@ def sparsecore_lookup(
  return shard_map.shard_map(
      functools.partial(
          embedding.tpu_sparse_dense_matmul,
-          global_device_count=sparsecore_params.abstract_mesh.size,
+          global_device_count=sparsecore_params.mesh.size,
          feature_specs=sparsecore_params.feature_specs,
          sharding_strategy=sparsecore_params.sharding_strategy,
      ),
-      mesh=sparsecore_params.abstract_mesh,
+      mesh=sparsecore_params.mesh,
      in_specs=(
          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
          jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes),

@@ -90,7 +90,7 @@ def _emb_lookup_bwd(
          feature_specs=sparsecore_params.feature_specs,
          sharding_strategy=sparsecore_params.sharding_strategy,
      ),
-      mesh=sparsecore_params.abstract_mesh,
+      mesh=sparsecore_params.mesh,
      in_specs=(
          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
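
The substantive change above is that SparsecoreParams.mesh now accepts a concrete jax.sharding.Mesh in addition to an AbstractMesh, and whichever is supplied is forwarded directly to shard_map for the forward and backward lookups. Below is a minimal, standalone sketch of that pattern outside RecML (the add_one function and the array shapes are illustrative assumptions):

import jax
import jax.numpy as jnp
from jax.experimental import shard_map

# A concrete mesh spanning all devices along a single "data" axis, built the
# same way the updated partitioners build theirs.
mesh = jax.make_mesh((jax.device_count(),), ("data",))

def add_one(block):
  # Runs once per shard; `block` is the per-device slice of the input.
  return block + 1

# shard_map accepts the concrete Mesh directly, mirroring how sparsecore_lookup
# now forwards sparsecore_params.mesh instead of an AbstractMesh.
sharded_add_one = shard_map.shard_map(
    add_one,
    mesh=mesh,
    in_specs=jax.sharding.PartitionSpec("data"),
    out_specs=jax.sharding.PartitionSpec("data"),
)

x = jnp.zeros((4 * jax.device_count(), 8), jnp.float32)
y = sharded_add_one(x)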

recml/core/training/partitioning.py

Lines changed: 37 additions & 28 deletions

@@ -20,7 +20,6 @@

import flax.linen as nn
import jax
-from jax.experimental import mesh_utils
import numpy as np


@@ -68,7 +67,7 @@ class DataParallelPartitioner(Partitioner):
  """Data parallel partitioner."""

  def __init__(self, data_axis: str = "batch"):
-    self.mesh = jax.sharding.Mesh(jax.devices(), (data_axis,))
+    self.mesh = jax.make_mesh((jax.device_count(),), (data_axis,))
    self.data_sharding = jax.sharding.NamedSharding(
        self.mesh, jax.sharding.PartitionSpec(data_axis)
    )

@@ -109,6 +108,12 @@ def partition_init(
      self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None
  ) -> CreateStateFn:
    with jax.sharding.use_mesh(self.mesh):
+      if abstract_batch is not None:
+        abstract_state = jax.eval_shape(init_fn, abstract_batch)
+        specs = nn.get_partition_spec(abstract_state)
+        self.state_sharding = jax.tree.map(
+            lambda x: jax.sharding.NamedSharding(self.mesh, x), specs
+        )
      init_fn = jax.jit(init_fn, out_shardings=self.state_sharding)

    def _wrapped_init(batch: PyTree) -> State:
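
The new branch in partition_init is what derives shardings from the model's own partitioning annotations: the init function is shape-evaluated, nn.get_partition_spec reads the PartitionSpecs off the boxed (nn.Partitioned) variables, and those are mapped to NamedShardings for jax.jit's out_shardings. A standalone sketch of the same recipe, assuming a toy TinyModel (not part of RecML) whose kernel is annotated with nn.with_partitioning:

import flax.linen as nn
import jax
import jax.numpy as jnp


class TinyModel(nn.Module):
  """Toy module whose kernel carries partitioning metadata (illustrative)."""

  @nn.compact
  def __call__(self, x):
    dense = nn.Dense(
        # Sized so the sharded output axis divides evenly across devices.
        features=2 * jax.device_count(),
        # Box the kernel as nn.Partitioned so its PartitionSpec is recoverable.
        kernel_init=nn.with_partitioning(
            nn.initializers.lecun_normal(), (None, "batch")
        ),
    )
    return dense(x)


mesh = jax.make_mesh((jax.device_count(),), ("batch",))
model = TinyModel()


def init_fn(batch):
  return model.init(jax.random.key(0), batch)


# Same recipe as the added partition_init branch: evaluate shapes only, read
# PartitionSpecs off the boxed variables, and build NamedShardings for jit.
abstract_batch = jax.ShapeDtypeStruct((16, 4), jnp.float32)
abstract_state = jax.eval_shape(init_fn, abstract_batch)
specs = nn.get_partition_spec(abstract_state)
state_sharding = jax.tree.map(
    lambda spec: jax.sharding.NamedSharding(mesh, spec), specs
)
state = jax.jit(init_fn, out_shardings=state_sharding)(
    jnp.zeros((16, 4), jnp.float32)
)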
@@ -145,12 +150,12 @@ class ModelParallelPartitioner(Partitioner):
  This only works with multi-controller Jax, i.e. communications along the ICI
  for TPUs. For scaling beyond a single TPU slice this needs to be extended to
  support Megascale XLA or single-controller Pathways. Consider using T5X, Pax,
-  or Gemax for these use cases.
+  MaxText externally or Gemax internally for these use cases.

-  Note: This assumes that all axes of the inputs except the final one are used
-  for data parallelism while the final one is used for model parallelism.
-  This tends to work well for 2D and 3D torus topologies since network latency
-  tends to be much higher for the leading axes.
+  By default, all axes of the input are used for data parallelism. This results
+  in fully-sharded data-parallelism for ND topologies or data-parallelism for 1D
+  topologies. The range of axes can be configured using the `dp_axes` argument,
+  i.e. axes[:dp_axes] will be used for data parallelism.

  IMPORTANT: `shard_inputs` operates on a per process batch. This means that the
  input batch size on CPU must already be the per process batch size,

@@ -160,45 +165,49 @@ class ModelParallelPartitioner(Partitioner):

  def __init__(
      self,
-      axes: Sequence[tuple[str, int]],
+      axes: Sequence[tuple[str, int]] = (("batch", -1),),
+      dp_axes: int | None = None,
      rules: Mapping[str, str] | None = None,
      aot_compile: bool = False,
      options: jax.stages.CompilerOptions | None = None,
+      devices: Sequence[jax.Device] | None = None,
  ):
-    if len(axes) < 2:
+    if not axes:
+      raise ValueError("At least one axis must be specified in `axes`.")
+    if dp_axes == 0:
+      raise ValueError(
+          "Data parallelism axes range must be positive or negative."
+      )
+
+    devices = devices if devices is not None else jax.devices()
+    axis_names = [axis for axis, _ in axes]
+    axis_sizes = [dim for _, dim in axes]
+    if any(dim <= 0 for dim in axis_sizes[1:]):
      raise ValueError(
-          "`axes` cannot less than 2D, use data-parallel"
-          f" partitioner instead. Got axes: {axes}."
+          "All dimensions except the first in the axes must be positive"
+          f" integers. Got axes: {axes}."
      )
+    if axis_sizes[0] == -1:
+      axis_sizes[0] = len(devices) // math.prod(axis_sizes[1:])

-    mesh_devices = mesh_utils.create_device_mesh([dim for _, dim, in axes])
-    self.mesh = jax.sharding.Mesh(mesh_devices, [axis for axis, _ in axes])
+    self.mesh = jax.make_mesh(axis_sizes, axis_names, devices=devices)
    self.rules = rules
    self.aot_compile = aot_compile
    self.options = options

-    dp_axes, dp_dims = zip(*axes[:-1])
-    _, mp_dim = axes[-1]
-
-    if math.prod(dp_dims) % jax.process_count() != 0:
+    dp_axis_names, dp_axis_sizes = zip(*axes[:dp_axes])
+    num_processes = jax.process_count()
+    if math.prod(dp_axis_sizes) % num_processes != 0:
      raise ValueError(
          "The data parallel dimensions in the mesh must be divisible by the"
          " number of processes as we assume data parallelism across"
-          f" processes. Got process count: {jax.process_count()} and data"
-          f" parallelism dimensions: {dp_dims} for axes: {axes} and mesh"
-          f" devices: {self.mesh.devices}."
-      )
-    if jax.local_device_count() % mp_dim != 0:
-      raise ValueError(
-          "The number of local devices on each host must be divisible by the"
-          " model dimension as we assume model parallelism across local"
-          f" devices. Got local device count: {jax.local_device_count()} and"
-          f" model parallelism dimension: {mp_dim} for axes: {axes} and mesh"
+          f" processes. Got process count: {num_processes} and data"
+          f" parallelism dimensions: {dp_axis_sizes} for axes: {axes} and mesh"
          f" devices: {self.mesh.devices}."
      )

    self.data_sharding = jax.sharding.NamedSharding(
-        self.mesh, jax.sharding.PartitionSpec(dp_axes)
+        self.mesh, jax.sharding.PartitionSpec(dp_axis_names)
    )
    self.state_sharding = None
    self.abstract_batch = None
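
With these changes the constructor covers both pure data parallelism and hybrid data/model parallelism from one entry point: `axes` defaults to a single "batch" axis whose size (-1) is inferred from the device count, and `dp_axes` selects how many leading mesh axes the batch is sharded over. A hedged usage sketch (the import path is assumed from the file location, and the axis size 2 is an illustrative choice that must divide the device count):

import jax
from recml.core.training import partitioning  # path assumed from the file location

# Default: one "batch" axis spanning every device, i.e. pure data parallelism.
dp = partitioning.ModelParallelPartitioner()

# Hybrid: axes[:dp_axes] (just "data" here) shards the batch, while the
# trailing "model" axis shards parameters. The leading -1 resolves to
# device_count // 2, so 2 must divide the number of devices.
mp = partitioning.ModelParallelPartitioner(
    axes=[("data", -1), ("model", 2)], dp_axes=1
)
print(dp.mesh.shape, mp.mesh.shape)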

recml/core/training/partitioning_test.py

Lines changed: 2 additions & 2 deletions

@@ -40,7 +40,7 @@ def test_data_parallelism(
      self, partitioner_cls: type[partitioning.Partitioner]
  ):
    if partitioner_cls is partitioning.ModelParallelPartitioner:
-      kwargs = {"axes": [("data", jax.device_count()), ("model", 1)]}
+      kwargs = {"axes": [("data", -1), ("model", 1)], "dp_axes": 1}
    else:
      kwargs = {}
    partitioner = partitioner_cls(**kwargs)

@@ -113,7 +113,7 @@ def _eval_step(

  def test_model_parallelism(self):
    partitioner = partitioning.ModelParallelPartitioner(
-        axes=[("data", 1), ("model", jax.device_count())]
+        axes=[("data", 1), ("model", jax.device_count())], dp_axes=1
    )

    inputs = np.zeros((128, 16), dtype=np.float32)

recml/examples/dlrm_experiment.py

Lines changed: 20 additions & 14 deletions

@@ -99,19 +99,19 @@ class DLRMModel(nn.Module):
  dcn_layers: int
  dcn_inner_dim: int

-  # We need to track the embedder on the Flax module to ensure it is not
-  # re-created on cloning. It is not possible to create an embedder inside
-  # setup() because it is called lazily at compile time. The embedder needs
+  # We need to track the sparsecore config on the Flax module to ensure it is
+  # not re-created on cloning. It is not possible to create an config inside
+  # setup() because it is called lazily at compile time. The config needs
  # to be created before `model.init` so we can use it to create a preprocessor.
-  # A simpler pattern that works is passing `embedder` directly to the module.
-  _embedder: sparsecore.SparsecoreEmbedder | None = None
+  # A simpler pattern that works is passing the config directly to the module.
+  _sparsecore_config: sparsecore.SparsecoreConfig | None = None

  @property
-  def embedder(self) -> sparsecore.SparsecoreEmbedder:
-    if self._embedder is not None:
-      return self._embedder
+  def sparsecore_config(self) -> sparsecore.SparsecoreConfig:
+    if self._sparsecore_config is not None:
+      return self._sparsecore_config

-    embedder = sparsecore.SparsecoreEmbedder(
+    sparsecore_config = sparsecore.SparsecoreConfig(
        specs={
            f.name: sparsecore.EmbeddingSpec(
                input_dim=f.vocab_size,

@@ -123,8 +123,8 @@ def embedder(self) -> sparsecore.SparsecoreEmbedder:
        },
        optimizer=self.embedding_optimizer,
    )
-    object.__setattr__(self, '_embedder', embedder)
-    return embedder
+    object.__setattr__(self, '_sparsecore_config', sparsecore_config)
+    return sparsecore_config

  def bottom_mlp(self, inputs: Mapping[str, jt.Array]) -> jt.Array:
    x = jnp.concatenate(

@@ -174,7 +174,9 @@ def __call__(
      self, inputs: Mapping[str, jt.Array], training: bool = False
  ) -> jt.Array:
    dense_embeddings = self.bottom_mlp(inputs)
-    sparse_embeddings = self.embedder.make_sparsecore_module()(inputs)
+    sparse_embeddings = sparsecore.SparsecoreEmbed(
+        self.sparsecore_config, name='sparsecore_embed'
+    )(inputs)
    sparse_embeddings = jax.tree.flatten(sparse_embeddings)[0]
    concatenated_embeddings = jnp.concatenate(
        (dense_embeddings, *sparse_embeddings), axis=-1

@@ -239,11 +241,15 @@ def create_datasets(self) -> tuple[recml.data.Iterator, recml.data.Iterator]:
    global_batch_size = self.train_data.global_batch_size
    train_iter = recml.data.TFDatasetIterator(
        dataset=self.train_data.make(),
-        postprocessor=self.model.embedder.make_preprocessor(global_batch_size),
+        postprocessor=sparsecore.SparsecorePreprocessor(
+            self.model.sparsecore_config, global_batch_size
+        ),
    )
    eval_iter = recml.data.TFDatasetIterator(
        dataset=self.eval_data.make(),
-        postprocessor=self.model.embedder.make_preprocessor(global_batch_size),
+        postprocessor=sparsecore.SparsecorePreprocessor(
+            self.model.sparsecore_config, global_batch_size
+        ),
    )
    return train_iter, eval_iter
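
The net effect in the example is that the stateful SparsecoreEmbedder is replaced by a plain SparsecoreConfig consumed in two places: sparsecore.SparsecoreEmbed inside the Flax model and sparsecore.SparsecorePreprocessor on the input pipeline. A condensed sketch of that wiring, using only the call patterns visible in this diff (SparseTower, make_iterator, and the surrounding argument names are illustrative; the sparsecore module is assumed to be the one dlrm_experiment.py already imports):

import flax.linen as nn
import jax
import recml
# `sparsecore` below is assumed to be the module already imported by
# dlrm_experiment.py; its exact import path is omitted here.


class SparseTower(nn.Module):
  """Toy module that consumes a shared SparsecoreConfig (illustrative)."""

  sparsecore_config: sparsecore.SparsecoreConfig

  @nn.compact
  def __call__(self, inputs):
    # Model side: the config is wrapped in a SparsecoreEmbed submodule at call
    # time, mirroring DLRMModel.__call__ above.
    embeddings = sparsecore.SparsecoreEmbed(
        self.sparsecore_config, name='sparsecore_embed'
    )(inputs)
    return jax.tree.flatten(embeddings)[0]


def make_iterator(data, model, global_batch_size):
  # Pipeline side: the same config drives per-batch preprocessing, mirroring
  # create_datasets above.
  return recml.data.TFDatasetIterator(
      dataset=data.make(),
      postprocessor=sparsecore.SparsecorePreprocessor(
          model.sparsecore_config, global_batch_size
      ),
  )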
