Skip to content

Commit 1f41fca

Browse files
Hilly12recml authors
authored andcommitted
Make jax_tpu_embedding utils reference non global to avoid dependency.
PiperOrigin-RevId: 775915863
1 parent 8ea5b2c commit 1f41fca

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

recml/layers/linen/sparsecore.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@
4444
OptimizerSpec = Any
4545

4646

47+
def _num_sparsecores_per_device() -> int:
48+
"""Returns the number of sparsecores per tensorcore device."""
49+
return utils.num_sparsecores_per_device()
50+
51+
4752
# TODO(aahil): This should be common between Keras, Flax, NNX.
4853
@dataclasses.dataclass
4954
class EmbeddingSpec:
@@ -173,7 +178,7 @@ def __call__(self, inputs: Mapping[str, jax.Array]) -> jax.Array:
173178
)
174179
global_device_count: int = dataclasses.field(default_factory=jax.device_count)
175180
num_sc_per_device: int = dataclasses.field(
176-
default_factory=utils.num_sparsecores_per_device
181+
default_factory=_num_sparsecores_per_device
177182
)
178183

179184
_feature_specs: Mapping[str, embedding_ops.FeatureSpec] | None = (
@@ -350,8 +355,8 @@ class SparsecoreEmbed(nn.Module):
350355
Attributes:
351356
sparsecore_config: A sparsecore config specifying how to create the tables.
352357
mesh: The mesh to use for the embedding layer. If not provided, the global
353-
mesh set by `jax.sharding.use_mesh` will be used. If neither is set,
354-
an error will be raised.
358+
mesh set by `jax.sharding.use_mesh` will be used. If neither is set, an
359+
error will be raised.
355360
"""
356361

357362
sparsecore_config: SparsecoreConfig

0 commit comments

Comments
 (0)