Skip to content

Commit a74c0f9

Browse files
committed
Merge pull request #1174 from Edwinhr716:lws-integration
GitOrigin-RevId: 6fe455e23b395d623898e702e0449bcd38fccf00
2 parents b6ab98e + 1d8119c commit a74c0f9

File tree

12 files changed

+1278
-20
lines changed

12 files changed

+1278
-20
lines changed

axlearn/cloud/gcp/job.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,13 @@
1919
from axlearn.cloud.common.utils import generate_job_name, subprocess_run
2020
from axlearn.cloud.gcp.config import default_env_id, default_project, default_zone
2121
from axlearn.cloud.gcp.jobset_utils import BaseReplicatedJob
22-
from axlearn.cloud.gcp.utils import custom_jobset_kwargs, delete_k8s_jobset
22+
from axlearn.cloud.gcp.lws_utils import BaseLeaderWorkerTemplate
23+
from axlearn.cloud.gcp.utils import (
24+
custom_jobset_kwargs,
25+
custom_leaderworkerset_kwargs,
26+
delete_k8s_jobset,
27+
delete_k8s_leaderworkerset,
28+
)
2329
from axlearn.common.config import REQUIRED, ConfigOr, Required, config_class, maybe_instantiate
2430
from axlearn.common.utils import Nested
2531

@@ -267,3 +273,103 @@ def docker_command(
267273
)
268274
logging.debug("Docker run command: %s", cmd)
269275
return cmd
276+
277+
278+
class GKELeaderWorkerSet(GCPJob):
279+
"""Base GKE LeaderWorkerSet interface"""
280+
281+
@config_class
282+
class Config(GCPJob.Config):
283+
"""Configures GKELeaderWorkerSet.
284+
Attributes:
285+
builder: A builder that returns one or more statefulset specs.
286+
namespace: The namespace to use within the k8s cluster.
287+
annotations: LeaderWorkerSet annotations.
288+
num_replicas: number of LWS replicas.
289+
"""
290+
291+
builder: Required[BaseLeaderWorkerTemplate.Config] = REQUIRED
292+
namespace: str = "default"
293+
annotations: Optional[ConfigOr[dict]] = None
294+
num_replicas: int = 1
295+
296+
@classmethod
297+
def set_defaults(cls, fv):
298+
super().set_defaults(fv)
299+
fv.set_default("max_tries", fv.max_tries or 10)
300+
fv.set_default("retry_interval", fv.retry_interval or 60)
301+
302+
@classmethod
303+
def define_flags(cls, fv: flags.FlagValues):
304+
super().define_flags(fv)
305+
common_kwargs = dict(flag_values=fv, allow_override=True)
306+
flags.DEFINE_string("name", None, "Name of the LeaderWorkerSet.", **common_kwargs)
307+
308+
@classmethod
309+
def from_flags(cls, fv: flags.FlagValues, **kwargs):
310+
cfg: GKELeaderWorkerSet.Config = super().from_flags(fv, **kwargs)
311+
cfg.num_replicas = fv.num_replicas
312+
return cfg
313+
314+
def __init__(self, cfg: Config, *, bundler: BaseDockerBundler):
315+
super().__init__(cfg)
316+
cfg: GKELeaderWorkerSet.Config = self.config
317+
self._bundler = bundler
318+
# This instantiatees a builder for constructing replicated job specs, which will be managed
319+
# together under the leaderworkerset represented by this class.
320+
# Note the distinction from bundlers, which are responsible for bundling any code assets
321+
# required to run the job.
322+
self._builder: BaseLeaderWorkerTemplate = cfg.builder.instantiate(bundler=bundler)
323+
324+
def _delete(self):
325+
cfg: GKELeaderWorkerSet.Config = self.config
326+
# Issues a delete request for the LeaderWorkerSet and proactively delete its descendants.
327+
# This is not fully blocking; after the call returns there can be a delay before
328+
# everything is deleted.
329+
delete_k8s_leaderworkerset(cfg.name, namespace=cfg.namespace)
330+
331+
def _build_leaderworkerset(self) -> Nested[Any]:
332+
"""
333+
Builds a config for a LeaderWorkerSet, which is a set for multi-host inference
334+
335+
Returns:
336+
A nested dict corresponding to a k8s LWS config
337+
"""
338+
cfg: GKELeaderWorkerSet.Config = self.config
339+
annotations = maybe_instantiate(cfg.annotations or {})
340+
341+
return dict(
342+
metadata=dict(name=cfg.name, annotations=annotations),
343+
spec=dict(
344+
replicas=cfg.num_replicas,
345+
leaderWorkerTemplate=self._builder(),
346+
),
347+
)
348+
349+
def _execute(self):
350+
cfg: GKELeaderWorkerSet.Config = self.config
351+
352+
api_kwargs = custom_leaderworkerset_kwargs()
353+
custom_object = dict(
354+
apiVersion=f"{api_kwargs['group']}/{api_kwargs['version']}",
355+
kind="LeaderWorkerSet",
356+
**self._build_leaderworkerset(),
357+
)
358+
logging.info("submitting LeaderWorkerSet: %s", custom_object)
359+
return k8s.client.CustomObjectsApi().create_namespaced_custom_object(
360+
namespace=cfg.namespace,
361+
body=custom_object,
362+
**api_kwargs,
363+
)
364+
365+
366+
def exclusive_topology_annotations_leaderworkerset() -> dict:
367+
"""Used for TPU GKELeaderWorkerSet.
368+
369+
The exclusive topology annotation will ensure that all Pods will have affinity
370+
rules added that will ensure that they are fully scheduled on the same pod-slice
371+
node-pools.
372+
"""
373+
return {
374+
"leaderworkerset.sigs.k8s.io/subgroup-exclusive-topology": "cloud.google.com/gke-nodepool"
375+
}

axlearn/cloud/gcp/job_test.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from axlearn.cloud.common.bundler import Bundler
1313
from axlearn.cloud.common.utils import define_flags, from_flags
14-
from axlearn.cloud.gcp import bundler, job, jobset_utils
14+
from axlearn.cloud.gcp import bundler, job, jobset_utils, pathways_utils
1515
from axlearn.cloud.gcp.bundler import ArtifactRegistryBundler, CloudBuildBundler
1616
from axlearn.cloud.gcp.test_utils import default_mock_settings, mock_gcp_settings
1717
from axlearn.common.config import REQUIRED, Required, config_class
@@ -211,3 +211,90 @@ def test_build_jobset(
211211
self.assertNotIn("kueue.x-k8s.io/queue-name", jobset_annotations)
212212
else:
213213
self.assertEqual(jobset_annotations["kueue.x-k8s.io/queue-name"], queue)
214+
215+
216+
class TPUGKELeaderWorkerSetTest(TestCase):
217+
"""Tests GKELeaderWorkerSet with TPU."""
218+
219+
def run(self, result=None):
220+
# Run tests under mock user and settings.
221+
self._settings = default_mock_settings()
222+
with mock_gcp_settings(
223+
[jobset_utils.__name__, bundler.__name__],
224+
settings=self._settings,
225+
):
226+
return super().run(result)
227+
228+
def _job_config(
229+
self,
230+
*,
231+
command: str,
232+
bundler_cls: type[Bundler],
233+
**kwargs,
234+
) -> tuple[job.GKELeaderWorkerSet.Config, Bundler.Config]:
235+
fv = flags.FlagValues()
236+
cfg = job.GKELeaderWorkerSet.default_config().set(
237+
builder=pathways_utils.PathwaysLeaderWorkerTemplate.default_config()
238+
)
239+
define_flags(cfg, fv)
240+
for key, value in kwargs.items():
241+
if value is not None:
242+
# Use setattr rather than set_default to set flags.
243+
setattr(fv, key, value)
244+
fv.name = "fake-name"
245+
fv.output_dir = "FAKE"
246+
fv.instance_type = "tpu-v4-8"
247+
fv.mark_as_parsed()
248+
from_flags(cfg, fv, command=command)
249+
# Test that retries are configured on fv by default.
250+
self.assertIsNotNone(fv["max_tries"].default)
251+
self.assertIsNotNone(fv["retry_interval"].default)
252+
bundler_cfg = bundler_cls.from_spec([], fv=fv).set(image="test-image")
253+
return cfg, bundler_cfg
254+
255+
@parameterized.product(
256+
reservation=[None, "test"],
257+
bundler_cls=[ArtifactRegistryBundler, CloudBuildBundler],
258+
wrap_bundler=[False, True],
259+
)
260+
def test_instantiate(
261+
self,
262+
reservation,
263+
bundler_cls: type[Bundler],
264+
wrap_bundler,
265+
):
266+
class WrappedBundler(Bundler):
267+
@config_class
268+
class Config(Bundler.Config):
269+
inner: Required[Bundler.Config] = REQUIRED
270+
271+
cfg, bundler_cfg = self._job_config(
272+
command="test-command",
273+
bundler_cls=bundler_cls,
274+
reservation=reservation,
275+
num_replicas=1,
276+
)
277+
278+
self.assertIsInstance(cfg.builder, pathways_utils.PathwaysLeaderWorkerTemplate.Config)
279+
cfg.builder = cast(pathways_utils.PathwaysLeaderWorkerTemplate.Config, cfg.builder)
280+
281+
self.assertEqual(cfg.name, cfg.builder.name)
282+
self.assertEqual(cfg.project, self._settings["project"])
283+
self.assertEqual(cfg.zone, self._settings["zone"])
284+
self.assertEqual(
285+
cfg.builder.inner.reservation, reservation or self._settings["gke_reservation"]
286+
)
287+
self.assertEqual(cfg.num_replicas, 1)
288+
# Should work with wrapped bundlers.
289+
if wrap_bundler:
290+
bundler_cfg = WrappedBundler.default_config().set(inner=bundler_cfg)
291+
gke_job = cfg.instantiate(bundler=bundler_cfg.instantiate())
292+
self.assertEqual("v4-8", gke_job._builder._tpu_type)
293+
294+
def test_delete(self):
295+
patch_delete = mock.patch(f"{job.__name__}.delete_k8s_leaderworkerset")
296+
with patch_delete as mock_delete:
297+
cfg, _ = self._job_config(command="test-command", bundler_cls=CloudBuildBundler)
298+
gke_job = cfg.instantiate(bundler=mock.Mock())
299+
gke_job._delete() # pylint: disable=protected-access
300+
mock_delete.assert_called()

axlearn/cloud/gcp/jobset_utils.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -308,12 +308,12 @@ def from_flags(cls, fv: flags.FlagValues, **kwargs):
308308
return cfg
309309

310310

311-
class TPUReplicatedJob(SingleReplicatedJob):
312-
"""Builds a replicated jobspec for TPU, to be used with JobSet API."""
311+
class TPUJobBuilder(SingleReplicatedJob):
312+
"""Common base class for TPU Specs"""
313313

314314
@config_class
315315
class Config(SingleReplicatedJob.Config):
316-
"""Configures TPUReplicatedJob.
316+
"""Configures TPUJobBuilder.
317317
318318
Attributes:
319319
reservation: If specified, the TPU reservation name. This is not necessarily specific to
@@ -380,7 +380,7 @@ def define_flags(cls, fv: flags.FlagValues):
380380

381381
@classmethod
382382
def from_flags(cls, fv: flags.FlagValues, **kwargs) -> Config:
383-
cfg: TPUReplicatedJob.Config = super().from_flags(fv, **kwargs)
383+
cfg: TPUJobBuilder.Config = super().from_flags(fv, **kwargs)
384384
default_env = get_default_env(
385385
tpu_type=infer_tpu_type(fv.instance_type),
386386
num_tpu_slices=fv.num_replicas,
@@ -404,7 +404,7 @@ def from_flags(cls, fv: flags.FlagValues, **kwargs) -> Config:
404404

405405
def __init__(self, cfg: Config, *, bundler: Bundler):
406406
super().__init__(cfg, bundler=bundler)
407-
cfg: TPUReplicatedJob.Config = self.config
407+
cfg: TPUJobBuilder.Config = self.config
408408
if cfg.output_dir is None:
409409
raise ValueError("cfg.output_dir is required.")
410410
self._tpu_type = infer_tpu_type(cfg.accelerator.instance_type)
@@ -433,7 +433,7 @@ def _build_container(self) -> Nested[Any]:
433433
Returns:
434434
A nested dict corresponding to a k8s Container config.
435435
"""
436-
cfg: TPUReplicatedJob.Config = self.config
436+
cfg: TPUJobBuilder.Config = self.config
437437
system = USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS[self._tpu_type]
438438
volume_mounts = [self._output_volume_mount]
439439

@@ -503,7 +503,7 @@ def _build_uploader_container(
503503
Returns:
504504
A nested dict corresponding to a k8s Container config.
505505
"""
506-
cfg: TPUReplicatedJob.Config = self.config
506+
cfg: TPUJobBuilder.Config = self.config
507507
output_volume_mount = output_volume_mount or self._output_volume_mount
508508
dst = f"{cfg.output_dir}/output/$HOSTNAME/"
509509
interval_s = 60
@@ -538,7 +538,7 @@ def _build_pod(self) -> Nested[Any]:
538538
Returns:
539539
A nested dict corresponding to a k8s Pod template, including the pod metadata and spec.
540540
"""
541-
cfg: TPUReplicatedJob.Config = self.config
541+
cfg: TPUJobBuilder.Config = self.config
542542
system = USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS[self._tpu_type]
543543
annotations, labels, selector, volumes, tolerations = {}, {}, {}, [], []
544544

@@ -727,6 +727,12 @@ def _build_pod(self) -> Nested[Any]:
727727
spec=spec,
728728
)
729729

730+
731+
class TPUReplicatedJob(TPUJobBuilder):
732+
"""Builds a replicated job spec for a generic TPU job to be used with the JobSet API"""
733+
734+
Config = TPUJobBuilder.Config
735+
730736
def __call__(self) -> Sequence[Nested[Any]]:
731737
"""See `BaseReplicatedJob` docstring for details."""
732738
cfg: TPUReplicatedJob.Config = self.config

axlearn/cloud/gcp/lws_utils.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright © 2025 Apple Inc.
2+
3+
"""Utilities for building LeaderWorkerSet specs"""
4+
5+
from typing import Any, Optional, Sequence
6+
7+
from absl import flags
8+
9+
from axlearn.cloud.common.bundler import Bundler
10+
from axlearn.cloud.common.utils import AcceleratorConfig, FlagConfigurable, accelerator_flags
11+
from axlearn.cloud.gcp.config import gcp_settings
12+
from axlearn.cloud.gcp.jobset_utils import TPUJobBuilder
13+
from axlearn.cloud.gcp.system_characteristics import USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS
14+
from axlearn.common.config import REQUIRED, Required, config_class
15+
from axlearn.common.utils import Nested
16+
17+
18+
class BaseLeaderWorkerTemplate(FlagConfigurable):
19+
"""
20+
Common base class for LeaderWorker Templates
21+
"""
22+
23+
@config_class
24+
class Config(FlagConfigurable.Config):
25+
"""
26+
Configures BaseLeaderWorker.
27+
Attributes:
28+
name: Name of the LeaderWorkerSet
29+
command: Command to be executed.
30+
accelerator: Accelerator configuration.
31+
env_vars: Optional env vars to set.
32+
service_account: Optional service account to execute the job as.
33+
output_dir: An optional GCS path to upload LWS outputs to.
34+
"""
35+
36+
name: Required[str] = REQUIRED
37+
# TODO: Change this to be a list of str[], to support different commands
38+
# between leader and workers
39+
command: Required[str] = REQUIRED
40+
accelerator: AcceleratorConfig = AcceleratorConfig()
41+
env_vars: dict[str, str] = {}
42+
service_account: Optional[str] = None
43+
output_dir: Optional[str] = None
44+
45+
@classmethod
46+
def define_flags(cls, fv):
47+
super().define_flags(fv)
48+
common_kwargs = dict(flag_values=fv, allow_override=True)
49+
accelerator_flags(**common_kwargs)
50+
# NOTE: the parent typically sets these flags, so we leave them as None.
51+
flags.DEFINE_string("name", None, "Name of the LWS.", **common_kwargs)
52+
flags.DEFINE_string("command", None, "Command to execute.", **common_kwargs)
53+
flags.DEFINE_multi_string("env", [], "Env var in the format key:value.", **common_kwargs)
54+
flags.DEFINE_string(
55+
"service_account",
56+
None,
57+
"If specified, will run job as the service account.",
58+
**common_kwargs,
59+
)
60+
flags.DEFINE_string(
61+
"output_dir",
62+
None,
63+
"If specified, the directory to store outputs (such as logs).",
64+
**common_kwargs,
65+
)
66+
flags.DEFINE_boolean(
67+
"enable_pre_provisioner", None, "Whether to enable pre-provisioner.", **common_kwargs
68+
)
69+
70+
@classmethod
71+
def from_flags(cls, fv: flags.FlagValues, **kwargs):
72+
cfg: BaseLeaderWorkerTemplate.Config = super().from_flags(fv, **kwargs)
73+
cfg.service_account = cfg.service_account or gcp_settings(
74+
"k8s_service_account", default="default", fv=fv
75+
)
76+
cfg.accelerator.set(instance_type=fv.instance_type, num_replicas=fv.num_replicas)
77+
return cfg
78+
79+
def __init__(self, cfg: Config, *, bundler: Bundler):
80+
super().__init__(cfg)
81+
self._bundler = bundler
82+
83+
def __call__(self) -> Sequence[Nested[Any]]:
84+
"""Builds LeaderWorkerTemplate for the LWS API.
85+
86+
Returns:
87+
A nested dict corresponding to a LeaderWorkerTemplate config.
88+
"""
89+
raise NotImplementedError(type(self))
90+
91+
92+
class TPULeaderWorkerTemplate(TPUJobBuilder):
93+
"""Builds a LeaderWorkerTemplate spec for a generic TPU workload"""
94+
95+
Config = TPUJobBuilder.Config
96+
97+
def __call__(self) -> Sequence[Nested[Any]]:
98+
system = USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS[self._tpu_type]
99+
return dict(
100+
size=system.vms_per_slice,
101+
workerTemplate=self._build_pod(),
102+
)

0 commit comments

Comments
 (0)