
Commit 8f9bed3

Adding basic elastic training

- Added guards to only use fast-resume if the proxy backend is used.
- Added the JobSet changes needed for elastic training.
- Made temporary configuration changes to decrease the batch size.
- Added a stop_trace call to cancel any ongoing traces.
- Changed the batch size to match the chip count and the checkpoint step interval to avoid writing any checkpoints during testing.
1 parent 9e739e3 commit 8f9bed3

3 files changed: 42 additions, 7 deletions


axlearn/cloud/gcp/pathways_utils.py

Lines changed: 11 additions & 4 deletions
@@ -320,6 +320,8 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
             f"--resource_manager_address=localhost:{_PATHWAYS_RESOURCE_MANAGER_PORT}",
             f"--server_port={_PATHWAYS_PROXY_PORT}",
             f"--gcs_scratch_location={staging_location}",
+            # This should be made configurable
+            f"--num_elastic_slices={cfg.accelerator.num_replicas}"
         ]
         cmd_args.extend(xla_flags_from_options(self._xla_options).split())

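The --num_elastic_slices value is taken straight from cfg.accelerator.num_replicas, and the new in-line comment flags it as something to make configurable later. As a rough illustration only (the num_elastic_slices config field below is hypothetical and not part of this commit), one way to do that could look like:

    # Hypothetical sketch, not from this commit: prefer an explicit config field
    # and fall back to the replica count when it is unset.
    num_elastic_slices = getattr(cfg, "num_elastic_slices", None) or cfg.accelerator.num_replicas
    cmd_args.append(f"--num_elastic_slices={num_elastic_slices}")
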
@@ -581,14 +583,19 @@ def _build_pathways_worker_job(
         annotations.update(
             {"alpha.jobset.sigs.k8s.io/exclusive-topology": "cloud.google.com/gke-nodepool"}
         )
+        # Default value for suspend and resume.
+        # References:
+        # https://github.com/google/pathways-job/blob/4417de7aa23d3c2316e400a3a327512834374475/internal/controller/pathwaysjob_controller.go#L651
+        # backoffLimit = system.vms_per_slice * 4
+
+        # This backoffLimit is just for verifying elastic fast-resume
+        large_number = 1000
+        backoffLimit = system.vms_per_slice * 4 * large_number

         spec = dict(
             parallelism=system.vms_per_slice,
             completions=system.vms_per_slice,
-            # Default value for suspend and resume.
-            # References:
-            # https://github.com/google/pathways-job/blob/4417de7aa23d3c2316e400a3a327512834374475/internal/controller/pathwaysjob_controller.go#L651
-            backoffLimit=system.vms_per_slice * 4,
+            backoffLimit=backoffLimit,
             template=self._build_pathways_worker_pod(pathways_worker_replicated_job_index),
         )
         worker_job = dict(
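For a sense of scale, a quick worked example (the slice size of 16 VMs is an assumption for illustration, not a value from this commit): the upstream default would give backoffLimit = 16 * 4 = 64, while the temporary multiplier used here raises it to 16 * 4 * 1000 = 64000, so the worker Job effectively never exhausts its retry budget while slice failures are being injected to verify elastic fast-resume.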

axlearn/common/launch_trainer.py

Lines changed: 29 additions & 3 deletions
@@ -147,8 +147,34 @@ def run_trainer(trainer_config: SpmdTrainer.Config) -> Any:
             f,
         )

-    trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
-    prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed)
-    output = trainer.run(prng_key)
+    if FLAGS.jax_backend == "proxy":
+        # pylint: disable-next=import-error,import-outside-toplevel
+        from pathwaysutils.elastic import manager
+        elastic_manager = manager.Manager()
+        while True:
+            try:
+                trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
+                prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed)
+                output = trainer.run(prng_key)
+                break
+            except jax.errors.JaxRuntimeError as error:
+                if not elastic_manager.is_error_due_to_slice_down(error):
+                    raise
+                try:
+                    logging.info("Trying to clean up ongoing traces")
+                    jax.stop_trace()
+                    logging.info("Successfully cleaned up ongoing traces")
+                except ValueError as e:
+                    logging.info("No ongoing traces to clean up", exc_info=True)
+                except Exception as e:
+                    logging.exception("Error trying to clean up ongoing traces")
+                    raise
+                ten_minutes = 10 * 60
+                elastic_manager.wait_for_slices(timeout=ten_minutes)
+    else:
+        trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
+        prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed)
+        output = trainer.run(prng_key)
+
     measurement.record_event(measurement.Event.END_JOB)
     return output

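Distilled from the diff above: the elastic retry loop is only installed when the proxy backend is in use; on a JaxRuntimeError the loop re-raises unless the elastic manager attributes the failure to a slice going down, then cleans up any in-flight trace, waits for the slices to return, and re-instantiates the trainer from the unchanged config. Below is a minimal standalone sketch of that shape, assuming only the pathwaysutils.elastic manager API used above and a generic run_fn in place of the SpmdTrainer; the trace-cleanup step is omitted for brevity.

    import logging

    import jax
    # Assumes the same pathwaysutils.elastic manager API used in the diff above.
    from pathwaysutils.elastic import manager


    def run_with_elastic_retries(run_fn, timeout_seconds: int = 10 * 60):
        """Re-runs run_fn until it finishes or fails for a reason other than a lost slice."""
        elastic_manager = manager.Manager()
        while True:
            try:
                return run_fn()
            except jax.errors.JaxRuntimeError as error:
                # Only retry when the manager attributes the failure to a slice going down.
                if not elastic_manager.is_error_due_to_slice_down(error):
                    raise
                logging.info("Slice went down; waiting for slices to come back before retrying")
                elastic_manager.wait_for_slices(timeout=timeout_seconds)
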
axlearn/experiments/text/gpt/fuji.py

Lines changed: 2 additions & 0 deletions
@@ -249,6 +249,7 @@ def get_trainer_kwargs(
     max_step = TOTAL_TOKENS[version][model_size] // tokens_per_batch
     max_sequence_length = MAX_SEQUENCE_LENGTH[version]
     train_batch_size = tokens_per_batch // max_sequence_length
+    train_batch_size = 16

     # Whether to use grouped query attention.
     num_kv_heads = None
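The override pins train_batch_size regardless of what the line above derives. As a rough illustration only (the concrete numbers are assumptions, not values from this commit): with tokens_per_batch = 4 * 1024 ** 2 and max_sequence_length = 4096, the derived batch size would be 1024, whereas forcing it to 16 matches the chip count mentioned in the commit message and keeps each step small for elastic fast-resume testing. max_step is still computed from tokens_per_batch above, so only the per-step batch size changes, not the step budget.
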
@@ -380,6 +381,7 @@ def get_trainer_kwargs(
             max_sequence_length=max_sequence_length,
             train_batch_size=train_batch_size,
             max_step=max_step,
+            save_every_n_steps=1000000,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
             mesh_rules=(
                 # Step time:
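Setting save_every_n_steps to 1000000 pushes the checkpoint interval far beyond the length of any short verification run, so no checkpoints are written while slice failures are being exercised; for instance, a smoke test of a few hundred steps (an assumed test length, not from this commit) never reaches the save interval.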
