Skip to content

Commit 5830df1

Browse files
yashk2810recml authors
authored andcommitted
Use Layout in place of DeviceLocalLayout and .layout in place of .device_local_layout.
JAX is undergoing a rename of the contents of jax.experimental.layouts in preparation for its graduation from experimental: * "Formats" are replacing "layouts", and specifically `Layout -> Format` * "Layouts" are replacing "device local layouts", and specifically `DeviceLocalLayout -> Layout` This is an incremental update carrying out #2. PiperOrigin-RevId: 774508649
1 parent 566f874 commit 5830df1

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

recml/layers/linen/sparsecore.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@
3030
from recml.core.ops import embedding_ops
3131
import tensorflow as tf
3232

33+
if jax.__version_info__ >= (0, 6, 3):
34+
DLL = layout.Layout
35+
else:
36+
DLL = layout.DeviceLocalLayout
37+
3338

3439
with epy.lazy_imports():
3540
# pylint: disable=g-import-not-at-top
@@ -382,7 +387,7 @@ class SparsecoreLayout(nn.Partitioned[A]):
382387
def get_sharding(self, _):
383388
assert self.mesh is not None
384389
return layout.Format(
385-
layout.DeviceLocalLayout(major_to_minor=(0, 1), _tiling=((8,),)),
390+
DLL(major_to_minor=(0, 1), _tiling=((8,),)),
386391
jax.sharding.NamedSharding(self.mesh, self.get_partition_spec()),
387392
)
388393

0 commit comments

Comments
 (0)