import flax.linen as nn
import jax
-from jax.experimental import mesh_utils
import numpy as np


@@ -68,7 +67,7 @@ class DataParallelPartitioner(Partitioner):
  """Data parallel partitioner."""

  def __init__(self, data_axis: str = "batch"):
-    self.mesh = jax.sharding.Mesh(jax.devices(), (data_axis,))
+    self.mesh = jax.make_mesh((jax.device_count(),), (data_axis,))
    self.data_sharding = jax.sharding.NamedSharding(
        self.mesh, jax.sharding.PartitionSpec(data_axis)
    )
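For reference, a minimal standalone sketch of the 1-D data-parallel layout this hunk sets up, assuming a JAX release recent enough to ship `jax.make_mesh`; the batch shape is made up for illustration:

```python
import jax
import jax.numpy as jnp

# Build the same 1-D mesh and batch sharding as the patched __init__ above.
mesh = jax.make_mesh((jax.device_count(),), ("batch",))
data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("batch")
)

# Shard a host array so its leading (batch) dimension is split across devices.
batch = jnp.zeros((jax.device_count() * 2, 128))  # hypothetical batch
batch = jax.device_put(batch, data_sharding)
print(batch.sharding)  # NamedSharding with spec PartitionSpec('batch',)
```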
@@ -109,6 +108,12 @@ def partition_init(
      self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None
  ) -> CreateStateFn:
    with jax.sharding.use_mesh(self.mesh):
+      if abstract_batch is not None:
+        abstract_state = jax.eval_shape(init_fn, abstract_batch)
+        specs = nn.get_partition_spec(abstract_state)
+        self.state_sharding = jax.tree.map(
+            lambda x: jax.sharding.NamedSharding(self.mesh, x), specs
+        )
      init_fn = jax.jit(init_fn, out_shardings=self.state_sharding)

      def _wrapped_init(batch: PyTree) -> State:
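The added block is the usual Flax recipe for deriving parameter shardings without materializing weights: `jax.eval_shape` traces the init function, `nn.get_partition_spec` reads the `nn.Partitioned` annotations, and each resulting spec is paired with the mesh. A self-contained sketch of that recipe with a hypothetical toy module (not part of this codebase), assuming a recent JAX/Flax:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class TinyMLP(nn.Module):
  """Hypothetical toy module; the kernel carries a sharding annotation."""

  features: int = 16  # assumes the device count divides this (1, 2, 4, 8, 16)

  @nn.compact
  def __call__(self, x):
    kernel_init = nn.with_partitioning(
        nn.initializers.lecun_normal(), (None, "model")
    )
    return nn.Dense(self.features, kernel_init=kernel_init)(x)


model = TinyMLP()
rng = jax.random.key(0)
x = jnp.zeros((8, 32), jnp.float32)

# Trace init without allocating parameters, then read the annotations.
abstract_state = jax.eval_shape(model.init, rng, x)
specs = nn.get_partition_spec(abstract_state)

mesh = jax.make_mesh((1, jax.device_count()), ("batch", "model"))
state_sharding = jax.tree.map(
    lambda spec: jax.sharding.NamedSharding(mesh, spec),
    specs,
    # Treat PartitionSpec objects as leaves regardless of JAX version.
    is_leaf=lambda s: isinstance(s, jax.sharding.PartitionSpec),
)

# jit places each freshly initialized parameter according to its annotation.
params = jax.jit(model.init, out_shardings=state_sharding)(rng, x)
```

The `out_shardings` tree is a pytree prefix of the init output, which is what lets `jax.jit` assign a sharding to every boxed parameter.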
@@ -145,12 +150,12 @@ class ModelParallelPartitioner(Partitioner):
  This only works with multi-controller Jax, i.e. communications along the ICI
  for TPUs. For scaling beyond a single TPU slice this needs to be extended to
  support Megascale XLA or single-controller Pathways. Consider using T5X, Pax,
-  or Gemax for these use cases.
+  MaxText externally or Gemax internally for these use cases.

-  Note: This assumes that all axes of the inputs except the final one are used
-  for data parallelism while the final one is used for model parallelism.
-  This tends to work well for 2D and 3D torus topologies since network latency
-  tends to be much higher for the leading axes.
+  By default, all axes of the input are used for data parallelism. This results
+  in fully-sharded data-parallelism for ND topologies or data-parallelism for 1D
+  topologies. The range of axes can be configured using the `dp_axes` argument,
+  i.e. axes[:dp_axes] will be used for data parallelism.

  IMPORTANT: `shard_inputs` operates on a per process batch. This means that the
  input batch size on CPU must already be the per process batch size,
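To make the new `dp_axes` semantics concrete, here is a hedged, standalone illustration (it assumes exactly 8 devices and made-up axis names, and does not go through the partitioner class itself):

```python
import jax

# axes[:dp_axes] carry data parallelism; the remaining axes shard the model.
axes = (("replica", 2), ("data", 2), ("model", 2))
dp_axes = 2  # ("replica", "data") are data-parallel, "model" is not

axis_names = [name for name, _ in axes]
axis_sizes = [size for _, size in axes]
mesh = jax.make_mesh(axis_sizes, axis_names)  # needs 2 * 2 * 2 = 8 devices

dp_axis_names, _ = zip(*axes[:dp_axes])
data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec(dp_axis_names)
)
# Batches are split jointly over ("replica", "data"); parameters can use "model".
```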
@@ -160,45 +165,49 @@ class ModelParallelPartitioner(Partitioner):

  def __init__(
      self,
-      axes: Sequence[tuple[str, int]],
+      axes: Sequence[tuple[str, int]] = (("batch", -1),),
+      dp_axes: int | None = None,
      rules: Mapping[str, str] | None = None,
      aot_compile: bool = False,
      options: jax.stages.CompilerOptions | None = None,
+      devices: Sequence[jax.Device] | None = None,
  ):
-    if len(axes) < 2:
+    if not axes:
+      raise ValueError("At least one axis must be specified in `axes`.")
+    if dp_axes == 0:
+      raise ValueError(
+          "Data parallelism axes range must be positive or negative."
+      )
+
+    devices = devices if devices is not None else jax.devices()
+    axis_names = [axis for axis, _ in axes]
+    axis_sizes = [dim for _, dim in axes]
+    if any(dim <= 0 for dim in axis_sizes[1:]):
      raise ValueError(
-          "`axes` cannot less than 2D, use data-parallel"
-          f" partitioner instead. Got axes: {axes}."
+          "All dimensions except the first in the axes must be positive"
+          f" integers. Got axes: {axes}."
      )
+    if axis_sizes[0] == -1:
+      axis_sizes[0] = len(devices) // math.prod(axis_sizes[1:])

-    mesh_devices = mesh_utils.create_device_mesh([dim for _, dim in axes])
-    self.mesh = jax.sharding.Mesh(mesh_devices, [axis for axis, _ in axes])
+    self.mesh = jax.make_mesh(axis_sizes, axis_names, devices=devices)
    self.rules = rules
    self.aot_compile = aot_compile
    self.options = options

-    dp_axes, dp_dims = zip(*axes[:-1])
-    _, mp_dim = axes[-1]
-
-    if math.prod(dp_dims) % jax.process_count() != 0:
+    dp_axis_names, dp_axis_sizes = zip(*axes[:dp_axes])
+    num_processes = jax.process_count()
+    if math.prod(dp_axis_sizes) % num_processes != 0:
      raise ValueError(
          "The data parallel dimensions in the mesh must be divisible by the"
          " number of processes as we assume data parallelism across"
-          f" processes. Got process count: {jax.process_count()} and data"
-          f" parallelism dimensions: {dp_dims} for axes: {axes} and mesh"
-          f" devices: {self.mesh.devices}."
-      )
-    if jax.local_device_count() % mp_dim != 0:
-      raise ValueError(
-          "The number of local devices on each host must be divisible by the"
-          " model dimension as we assume model parallelism across local"
-          f" devices. Got local device count: {jax.local_device_count()} and"
-          f" model parallelism dimension: {mp_dim} for axes: {axes} and mesh"
+          f" processes. Got process count: {num_processes} and data"
+          f" parallelism dimensions: {dp_axis_sizes} for axes: {axes} and mesh"
          f" devices: {self.mesh.devices}."
      )

    self.data_sharding = jax.sharding.NamedSharding(
-        self.mesh, jax.sharding.PartitionSpec(dp_axes)
+        self.mesh, jax.sharding.PartitionSpec(dp_axis_names)
    )
    self.state_sharding = None
    self.abstract_batch = None
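Finally, a standalone arithmetic sketch of how the new constructor resolves the `-1` placeholder in the leading axis and applies the process-divisibility check; the device and process counts below are hypothetical:

```python
import math

# Hypothetical topology: 16 devices spread across 2 processes.
num_devices = 16
num_processes = 2

axes = (("batch", -1), ("model", 4))
dp_axes = 1  # only axes[:1], i.e. ("batch",), carries data parallelism

axis_names = [name for name, _ in axes]
axis_sizes = [size for _, size in axes]
if axis_sizes[0] == -1:
  # The leading axis absorbs whatever is left after the explicit axes.
  axis_sizes[0] = num_devices // math.prod(axis_sizes[1:])  # 16 // 4 == 4

dp_axis_names = tuple(axis_names[:dp_axes])  # ("batch",)
dp_axis_sizes = tuple(axis_sizes[:dp_axes])  # (4,)
assert math.prod(dp_axis_sizes) % num_processes == 0  # 4 % 2 == 0

print(axis_sizes, dp_axis_names)  # [4, 4] ('batch',)
```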