
Commit 7ffd13d

Author: maxtext authors
Merge pull request #1533 from AI-Hypercomputer:lance-tp-fix2
PiperOrigin-RevId: 750766633
Parents: 9f62bc4, 03ad82e

17 files changed: +466, -85 lines

MaxText/configs/base.yml

Lines changed: 5 additions & 4 deletions
@@ -63,11 +63,12 @@ checkpoint_storage_use_zarr3: True
 checkpoint_storage_concurrent_gb: 96
 ############################### END CHECKPOINTING ##################################

-
+############################### BEGIN TESTING ##################################
 reuse_example_batch: 0 # for testing TPU performance, this options repeated uses the same batch.
-
-
 metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
+disalbe_key_validation: False # for testing, if true, skip the key validation.
+############################### END TESTING ##################################
+

 # If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
 gcs_metrics: False

@@ -330,7 +331,7 @@ logical_axis_rules: [
 ]
 # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
 data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
-
+input_data_sharding_logical_axes: ['activation_embed_and_logits_batch', 'activation_norm_length']
 # sharding tolerance: float between 0.0 and 1.0 representing the allowed percentage of non-sharded parameters.
 sharding_tolerance: 0.02
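
The new input_data_sharding_logical_axes key names logical axes that are resolved to mesh axes through the existing logical_axis_rules table, instead of sharding input batches directly with the data_sharding mesh spec. Below is a minimal sketch of that resolution, mirroring the get_input_data_sharding helper this commit adds to MaxText/maxtext_utils.py; the one-axis mesh and the two-rule table are toy assumptions, not MaxText's real configuration.

import jax
import numpy as np
from flax import linen as nn
from jax.sharding import Mesh, PartitionSpec as P

# Toy 1-D mesh over whatever devices are available.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Toy logical->mesh rules; base.yml's real logical_axis_rules table is larger.
rules = (
    ("activation_embed_and_logits_batch", "data"),  # shard batch over 'data'
    ("activation_norm_length", None),               # replicate sequence length
)

# This is what get_input_data_sharding (maxtext_utils.py, below) computes:
data_pspec = P("activation_embed_and_logits_batch", "activation_norm_length")
data_sharding = nn.logical_to_mesh_sharding(data_pspec, mesh, rules)
print(data_sharding)  # NamedSharding(mesh, PartitionSpec('data', None))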

MaxText/input_pipeline/_grain_data_processing.py

Lines changed: 2 additions & 2 deletions
@@ -196,7 +196,7 @@ def make_grain_train_iterator(
         tokenize=config.tokenize_train_data,
         grain_worker_count=config.grain_worker_count,
     )
-    return multihost_dataloading.MultiHostDataLoadIterator(train_dataloader, global_mesh)
+    return multihost_dataloading.MultiHostDataLoadIterator(train_dataloader, global_mesh, config)
   else:
     get_ds_fn = functools.partial(
         get_datasets,

@@ -262,7 +262,7 @@ def make_grain_eval_iterator(
         tokenize=config.tokenize_eval_data,
         grain_worker_count=config.grain_worker_count_eval,
     )
-    return multihost_dataloading.MultiHostDataLoadIterator(eval_dataloader, global_mesh)
+    return multihost_dataloading.MultiHostDataLoadIterator(eval_dataloader, global_mesh, config)
   else:
     get_ds_fn = functools.partial(
         get_datasets,
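
All of the input pipelines in this commit now pass config as a third positional argument to MultiHostDataLoadIterator. The wrapper itself lives in MaxText/multihost_dataloading.py and is not shown in this diff, so the following is only a hypothetical sketch of the shape such a wrapper takes once it can derive input sharding from the config; the constructor body and the make_array_from_process_local_data call are assumptions, not code from the commit.

import jax
from MaxText import maxtext_utils

class MultiHostDataLoadIteratorSketch:
  """Hypothetical stand-in: wraps a per-host dataloader into global jax.Arrays."""

  def __init__(self, dataloader, global_mesh, config):
    self.dataloader = dataloader
    self.global_mesh = global_mesh
    # The new argument gives the wrapper access to config-driven sharding,
    # e.g. the input_data_sharding_logical_axes key added to base.yml above.
    self.sharding = maxtext_utils.get_input_data_sharding(
        global_mesh, config.input_data_sharding_logical_axes, config.logical_axis_rules
    )

  def __iter__(self):
    for local_batch in self.dataloader:
      # Stitch each host's shard into one globally addressable array.
      yield jax.tree_util.tree_map(
          lambda x: jax.make_array_from_process_local_data(self.sharding, x),
          local_batch,
      )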

MaxText/input_pipeline/_hf_data_processing.py

Lines changed: 5 additions & 6 deletions
@@ -185,10 +185,7 @@ def preprocessing_pipeline(
       read_options=grain.ReadOptions(num_threads=num_threads, prefetch_buffer_size=128),
   )

-  multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh)
-
-  # Return multi-host jax.Array prep iterator
-  return multihost_gen
+  return dataloader


@@ -205,7 +202,7 @@ def make_hf_train_iterator(
       streaming=True,
       token=config.hf_access_token,
   )
-  train_iter = preprocessing_pipeline(
+  train_data_loader = preprocessing_pipeline(
       dataloading_host_index=process_indices_train.index(jax.process_index()),
       dataloading_host_count=len(process_indices_train),
       global_mesh=global_mesh,

@@ -226,6 +223,7 @@
       use_sft=config.use_sft,
       sft_train_on_completion_only=config.sft_train_on_completion_only,
   )
+  train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_data_loader, global_mesh, config)
   return train_iter


@@ -247,7 +245,7 @@
     eval_generate_padding_example = True
   else:
     eval_generate_padding_example = False
-  eval_iter = preprocessing_pipeline(
+  eval_data_loader = preprocessing_pipeline(
       dataloading_host_index=process_indices_eval.index(jax.process_index()),
       dataloading_host_count=len(process_indices_eval),
       global_mesh=global_mesh,

@@ -268,4 +266,5 @@
       use_sft=config.use_sft,
       sft_train_on_completion_only=config.sft_train_on_completion_only,
   )
+  eval_iter = multihost_dataloading.MultiHostDataLoadIterator(eval_data_loader, global_mesh, config)
   return eval_iter
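
In the Hugging Face pipeline the change is structural: preprocessing_pipeline now returns the raw Grain dataloader, and the MultiHostDataLoadIterator wrapping moves up into make_hf_train_iterator and make_hf_eval_iterator, where config is in scope. A self-contained toy of that pattern, with stand-in names throughout:

# Toy illustration of the refactor pattern; every name is a stand-in.
def preprocessing_pipeline(records):
  # Was: returned a wrapped multi-host iterator. Now: returns the raw loader.
  return iter(records)

class MultiHostWrapperSketch:  # stand-in for MultiHostDataLoadIterator
  def __init__(self, loader, global_mesh, config):
    self.loader, self.global_mesh, self.config = loader, global_mesh, config

  def __iter__(self):
    return self.loader

def make_train_iterator(records, global_mesh, config):
  train_data_loader = preprocessing_pipeline(records)
  # The wrapping now happens here, where `config` is in scope.
  return MultiHostWrapperSketch(train_data_loader, global_mesh, config)

print(list(make_train_iterator([1, 2, 3], global_mesh=None, config=None)))  # [1, 2, 3]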

MaxText/input_pipeline/_tfds_data_processing.py

Lines changed: 2 additions & 2 deletions
@@ -197,7 +197,7 @@ def make_tfds_train_iterator(
         use_dpo=config.use_dpo,
         hf_access_token=config.hf_access_token,
     )
-    return multihost_dataloading.MultiHostDataLoadIterator(train_dataloader, global_mesh)
+    return multihost_dataloading.MultiHostDataLoadIterator(train_dataloader, global_mesh, config)
   else:
     get_ds_fn = functools.partial(
         get_datasets,

@@ -261,7 +261,7 @@ def make_tfds_eval_iterator(
         use_dpo=config.use_dpo,
         hf_access_token=config.hf_access_token,
     )
-    return multihost_dataloading.MultiHostDataLoadIterator(eval_dataloader, global_mesh)
+    return multihost_dataloading.MultiHostDataLoadIterator(eval_dataloader, global_mesh, config)
   else:
     get_ds_fn = functools.partial(
         get_datasets,

MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py

Lines changed: 2 additions & 2 deletions
@@ -330,7 +330,7 @@ def make_c4_mlperf_train_iterator(
       shuffle_buffer_size=128,
       data_shuffle_seed=config.data_shuffle_seed,
   )
-  train_multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(train_ds, global_mesh)
+  train_multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(train_ds, global_mesh, config)
   return train_multihost_gen


@@ -360,7 +360,7 @@ def make_c4_mlperf_eval_iterator(
       max_target_length=config.max_target_length,
   )

-  eval_multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(eval_ds, global_mesh)
+  eval_multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(eval_ds, global_mesh, config)

   # Return multi-host jax.Array prep iterator
   return eval_multihost_gen

MaxText/input_pipeline/input_pipeline_interface.py

Lines changed: 11 additions & 7 deletions
@@ -26,6 +26,7 @@
 from MaxText.input_pipeline._grain_data_processing import make_grain_train_iterator, make_grain_eval_iterator
 from MaxText.input_pipeline._tfds_data_processing_c4_mlperf import make_c4_mlperf_train_iterator, make_c4_mlperf_eval_iterator
 from MaxText.input_pipeline._hf_data_processing import make_hf_train_iterator, make_hf_eval_iterator
+from MaxText import maxtext_utils
 from MaxText import multihost_dataloading


@@ -35,8 +36,9 @@ class SyntheticDataIterator:
   def __init__(self, config, mesh):
     self.mesh = mesh
     self.config = config
-    data_pspec = P(*config.data_sharding)
-    data_pspec_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
+    data_pspec_shardings = maxtext_utils.get_input_data_sharding(
+        mesh, config.input_data_sharding_logical_axes, config.logical_axis_rules
+    )
     self.data_generator = jax.jit(
         SyntheticDataIterator.raw_generate_synthetic_data, out_shardings=data_pspec_shardings, static_argnums=0
     )

@@ -82,7 +84,7 @@ class BadSyntheticDataIterator:
   def __init__(self, config, mesh):
     self.mesh = mesh
     dataset = BadSyntheticDataIterator.get_bad_synthetic_data(config)
-    self.data_generator = multihost_dataloading.MultiHostDataLoadIterator(dataset, self.mesh)
+    self.data_generator = multihost_dataloading.MultiHostDataLoadIterator(dataset, self.mesh, config)

   def __iter__(self):
     return self.data_generator

@@ -118,8 +120,7 @@ def get_process_loading_real_data(
     data_sharding, global_batch_size_to_load, global_batch_size_to_train_on, max_target_length, mesh
 ):
   """Get list of processes loading data from GCS when expansion_factor_real_data != -1"""
-  sharding = jax.sharding.NamedSharding(mesh, P(*data_sharding))
-  devices_indices_map = sharding.devices_indices_map((global_batch_size_to_load, max_target_length))
+  devices_indices_map = data_sharding.devices_indices_map((global_batch_size_to_load, max_target_length))
   batch_cutoff = global_batch_size_to_train_on
   process_loading_real_data = set()
   for p, indices in devices_indices_map.items():

@@ -149,16 +150,19 @@ def create_data_iterator(config, mesh):
   if config.dataset_type == "synthetic":
     return SyntheticDataIterator(config, mesh), None

+  input_data_sharding = maxtext_utils.get_input_data_sharding(
+      mesh, config.input_data_sharding_logical_axes, config.logical_axis_rules
+  )
   process_indices_train = get_process_loading_real_data(
-      config.data_sharding,
+      input_data_sharding,
       config.global_batch_size_to_load,
       config.global_batch_size_to_train_on,
       config.max_target_length,
       mesh,
   )
   if config.eval_interval > 0:
     process_indices_eval = get_process_loading_real_data(
-        config.data_sharding,
+        input_data_sharding,
         config.global_batch_size_to_load_eval,
         config.global_batch_size_to_eval_on,
         config.max_target_length,
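
get_process_loading_real_data now receives a ready-made NamedSharding (built once by get_input_data_sharding in create_data_iterator) rather than constructing one from config.data_sharding. Its host selection rests on devices_indices_map, which reports the index slice of the global batch that each device owns. A minimal sketch of that primitive, with toy batch sizes and a single-axis mesh in place of MaxText's real config:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Toy stand-ins; MaxText gets these from config and get_input_data_sharding.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
data_sharding = NamedSharding(mesh, P("data", None))
global_batch_size_to_load, max_target_length = 16, 8
batch_cutoff = 8  # global_batch_size_to_train_on

# Which slice of the (batch, length) array does each device own?
indices_map = data_sharding.devices_indices_map((global_batch_size_to_load, max_target_length))

processes = set()
for device, idx in indices_map.items():
  batch_slice = idx[0]  # the device's slice of the batch dimension
  # Hosts whose devices start before the cutoff must load real data;
  # slice(None) means the dimension is fully replicated on this device.
  if batch_slice.start is None or batch_slice.start < batch_cutoff:
    processes.add(device.process_index)
print(sorted(processes))  # e.g. [0] on a single-host run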

MaxText/max_utils.py

Lines changed: 47 additions & 12 deletions
@@ -373,21 +373,27 @@ def get_coordinator_ip_address():
   return coordinator_ip_address


-def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_type):
+def get_unspecified_mesh_axes_value(parallelism_vals, target_product, parallelism_type):
   """Evaluates unspecified DCN/ICI parallelism values"""
-  if -1 in parallelism_vals:
-    assert (
-        parallelism_vals.count(-1) == 1
-    ), f"Found unspecified values (-1) for more than one {parallelism_type}\
-      parallelism axis. At most one axis can be unspecified."
+  assert (
+      parallelism_vals.count(-1) == 1
+  ), f"Found unspecified values (-1) for more than one {parallelism_type}\
+    parallelism axis. At most one axis can be unspecified."
+
+  determined_val = target_product / np.prod(parallelism_vals) * -1
+
+  assert (
+      determined_val >= 1 and determined_val.is_integer
+  ), f"Unspecified value unable to be determined with the given\
+    {parallelism_type} parallelism values"

-    determined_val = target_product / np.prod(parallelism_vals) * -1
+  return int(determined_val)

-    assert (
-        determined_val >= 1 and determined_val.is_integer
-    ), f"Unspecified value unable to be determined with the given\
-      {parallelism_type} parallelism values"

+def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_type):
+  """Evaluates unspecified DCN/ICI parallelism values"""
+  if -1 in parallelism_vals:
+    determined_val = get_unspecified_mesh_axes_value(parallelism_vals, target_product, parallelism_type)
     parallelism_vals[parallelism_vals.index(-1)] = int(determined_val)

     target_type = "slices" if parallelism_type == "DCN" else "devices per slice"

@@ -780,6 +786,35 @@ def reorder_causal_load_balanced(batch, cp_size):
   }


+def shard_reorder_causal_load_balanced(batch, cp_size):
+  """Shard the output of the reordered sequence."""
+  reordered = reorder_causal_load_balanced(batch, cp_size)
+  for _, v in batch.items():
+    if isinstance(v, jax.Array):
+      reordered = jax.lax.with_sharding_constraint(reordered, v.sharding)
+      break
+  return reordered
+
+
 def get_reorder_callable(cp_size):
   """Creates a callable that can be used with map() to reorder batches."""
-  return functools.partial(reorder_causal_load_balanced, cp_size=cp_size)
+  return functools.partial(shard_reorder_causal_load_balanced, cp_size=cp_size)
+
+
+def compute_axis_product(axis_spec, mesh_dict):
+  """Computes the product of the axis specified in axis_spec."""
+  if isinstance(axis_spec, str):
+    axis_spec = (axis_spec,)
+  elif axis_spec is None:
+    return 1
+  product = 1
+  for dim_name in axis_spec:
+    if dim_name in mesh_dict:
+      product *= mesh_dict[dim_name]
+  return product
+
+
+def construct_parallelism_name(mesh_axis: str, prefix: str) -> str:
+  if mesh_axis == "stage":
+    return f"{prefix}_pipeline_parallelism"
+  return f"{prefix}_{mesh_axis}_parallelism"

MaxText/maxtext_utils.py

Lines changed: 33 additions & 6 deletions
@@ -23,6 +23,7 @@
 from MaxText import max_utils
 from jax.sharding import PartitionSpec as P
 from jax.experimental.serialize_executable import deserialize_and_load
+from flax import linen as nn

 import pickle
 import functools

@@ -32,6 +33,7 @@
 from flax.linen import partitioning as nn_partitioning

 from MaxText import max_logging
+import ml_collections
 import numpy as np
 import jax.numpy as jnp
 from MaxText import checkpointing

@@ -50,12 +52,16 @@
 NUM_IMAGE_CHANNELS = 3


+def get_input_data_sharding(mesh, input_data_sharding_logical_axes, logical_axis_rules):
+  data_pspec = P(*input_data_sharding_logical_axes)
+  return nn.logical_to_mesh_sharding(data_pspec, mesh, logical_axis_rules)
+
+
 def get_functional_train_with_signature(train_step, mesh, state_mesh_shardings, model, config):
   """Get the shardings (both state and data) for train_step"""
   functional_train = get_functional_train_step(train_step, model, config, state_mesh_shardings)
   functional_train.__name__ = "train_step"
-  data_pspec = P(*config.data_sharding)
-  data_sharding = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
+  data_sharding = get_input_data_sharding(mesh, config.input_data_sharding_logical_axes, config.logical_axis_rules)
   in_shardings = (state_mesh_shardings, data_sharding, None)  # State, batch, rng
   out_shardings = (state_mesh_shardings, None)  # State, metrics
   static_argnums = ()  # We partial out the static argnums of model and config

@@ -71,8 +77,7 @@ def get_functional_eval_with_signature(eval_step, mesh, state_mesh_shardings, mo
   """Get the shardings (both state and data) for eval_step"""
   functional_eval = get_functional_eval_step(eval_step, model, config)
   functional_eval.__name__ = "eval_step"
-  data_pspec = P(*config.data_sharding)
-  data_sharding = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
+  data_sharding = get_input_data_sharding(mesh, config.input_data_sharding_logical_axes, config.logical_axis_rules)
   in_shardings = (state_mesh_shardings, data_sharding, None)  # State, batch, rng
   out_shardings = None  # metrics
   static_argnums = ()  # We partial out the static argnums of model, config

@@ -690,14 +695,36 @@ def add_config_to_summary_writer(config, summary_writer):
       max_utils.add_text_to_summary_writer(key, str(value), summary_writer)


-def create_device_mesh(config, devices=None):
-  """Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas"""
+def get_ici_parallelism(config, devices=None):
+  """Get the ICI parallelism for the model."""
   if devices is None:
     devices = jax.devices()
   num_devices = len(devices)
   num_slices = 1 if config.inference_benchmark_test else config.num_slices
   num_devices_per_slice = num_devices // num_slices

+  return max_utils.fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI")
+
+
+def get_dcn_parallelism(config):
+  """Get the DCN parallelism for the model."""
+  num_slices = 1 if config.inference_benchmark_test else config.num_slices
+  return max_utils.fill_unspecified_mesh_axes(config.dcn_parallelism.copy(), num_slices, "DCN")
+
+
+def get_slices_and_devices(config, devices=None):
+  if devices is None:
+    devices = jax.devices()
+  num_devices = len(devices)
+  num_slices = 1 if config.inference_benchmark_test else config.num_slices
+  num_devices_per_slice = num_devices // num_slices
+  return num_slices, num_devices_per_slice
+
+
+def create_device_mesh(config, devices=None):
+  """Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas"""
+  num_slices, num_devices_per_slice = get_slices_and_devices(config, devices)
+  num_devices = num_devices_per_slice * num_slices
   multi_slice_env = num_slices > 1

   # Find possible unspecified parallelisms
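
create_device_mesh is split so that slice and device counting (get_slices_and_devices) and parallelism resolution (get_ici_parallelism, get_dcn_parallelism) are usable on their own. A sketch with a stand-in config; the SimpleNamespace fields mirror exactly the attributes the new helpers read, while real runs pass MaxText's full config object:

import types
import jax
from MaxText import maxtext_utils

# Stand-in config: only the fields the new helpers read.
config = types.SimpleNamespace(
    num_slices=1,
    inference_benchmark_test=False,
    ici_parallelism=[1, -1, 1],  # one unspecified ICI axis, inferred from devices
    dcn_parallelism=[1, 1, 1],
)

num_slices, num_devices_per_slice = maxtext_utils.get_slices_and_devices(config)
assert num_devices_per_slice == len(jax.devices()) // num_slices

# Resolves the -1 against num_devices_per_slice via fill_unspecified_mesh_axes,
# working on a copy so the config itself is not mutated.
maxtext_utils.get_ici_parallelism(config)
# create_device_mesh itself now calls get_slices_and_devices internally.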
