Skip to content

Commit be3776a

Browse files
committed
Adding basic elastic training
Pull pathwaysutils from GitHub. Add guards to use fast-resume only when the proxy backend is in use. Add the jobset changes needed for elastic training. Temporarily change the configuration to decrease the batch size. Add a stop_trace call to cancel any ongoing traces. Checking in.
1 parent 489a775 commit be3776a

File tree

4 files changed

+42
-8
lines changed

4 files changed

+42
-8
lines changed

axlearn/cloud/gcp/pathways_utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,8 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
320320
f"--resource_manager_address=localhost:{_PATHWAYS_RESOURCE_MANAGER_PORT}",
321321
f"--server_port={_PATHWAYS_PROXY_PORT}",
322322
f"--gcs_scratch_location={staging_location}",
323+
# This should be made configurable
324+
f"--num_elastic_slices={cfg.accelerator.num_replicas}"
323325
]
324326
cmd_args.extend(xla_flags_from_options(self._xla_options).split())
325327

@@ -588,14 +590,19 @@ def _build_pathways_worker_job(
588590
annotations.update(
589591
{"alpha.jobset.sigs.k8s.io/exclusive-topology": "cloud.google.com/gke-nodepool"}
590592
)
593+
# Default value for suspend and resume.
594+
# References:
595+
# https://github.com/google/pathways-job/blob/4417de7aa23d3c2316e400a3a327512834374475/internal/controller/pathwaysjob_controller.go#L651
596+
# backoffLimit = system.vms_per_slice * 4
597+
598+
# This backoffLimit is just for verifying elastic fast-resume
599+
large_number = 1000
600+
backoffLimit = system.vms_per_slice * 4 * large_number
591601

592602
spec = dict(
593603
parallelism=system.vms_per_slice,
594604
completions=system.vms_per_slice,
595-
# Default value for suspend and resume.
596-
# References:
597-
# https://github.com/google/pathways-job/blob/4417de7aa23d3c2316e400a3a327512834374475/internal/controller/pathwaysjob_controller.go#L651
598-
backoffLimit=system.vms_per_slice * 4,
605+
backoffLimit=backoffLimit,
599606
template=self._build_pathways_worker_pod(pathways_worker_replicated_job_index),
600607
)
601608
worker_job = dict(

axlearn/common/launch_trainer.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,34 @@ def run_trainer(trainer_config: SpmdTrainer.Config) -> Any:
147147
f,
148148
)
149149

150-
trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
151-
prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed)
152-
output = trainer.run(prng_key)
150+
if FLAGS.jax_backend == "proxy":
151+
# pylint: disable-next=import-error,import-outside-toplevel
152+
from pathwaysutils.elastic import manager
153+
elastic_manager = manager.Manager()
154+
while True:
155+
try:
156+
trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
157+
prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed)
158+
output = trainer.run(prng_key)
159+
break
160+
except jax.errors.JaxRuntimeError as error:
161+
if not elastic_manager.is_error_due_to_slice_down(error):
162+
raise
163+
try:
164+
logging.info("Trying to clean up ongoing traces")
165+
jax.stop_trace()
166+
logging.info("Successfully cleaned up ongoing traces")
167+
except ValueError as e:
168+
logging.info("No ongoing traces to clean up", exc_info=True)
169+
except Exception as e:
170+
logging.exception("Error trying to clean up ongoing traces")
171+
raise
172+
ten_minutes = 10 * 60
173+
elastic_manager.wait_for_slices(timeout=ten_minutes)
174+
else:
175+
trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
176+
prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed)
177+
output = trainer.run(prng_key)
178+
153179
measurement.record_event(measurement.Event.END_JOB)
154180
return output

axlearn/experiments/text/gpt/fuji.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,7 @@ def get_trainer_kwargs(
382382
max_sequence_length=max_sequence_length,
383383
train_batch_size=train_batch_size,
384384
max_step=max_step,
385+
save_every_n_steps=1000000,
385386
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
386387
mesh_rules=(
387388
# Step time:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ tpu = [
114114
pathways-tpu = [
115115
"axlearn[gcp]",
116116
"jax==0.5.3", # must be >=0.4.19 for compat with v5p.
117-
"pathwaysutils==0.1.1",
117+
"pathwaysutils @ git+https://github.com/AI-Hypercomputer/pathways-utils",
118118
]
119119
# Vertex AI tensorboard. TODO(markblee): Merge with `gcp`.
120120
vertexai_tensorboard = [

0 commit comments

Comments (0)