Skip to content

Commit 597d743

Browse files
committed
feat: add metadata parameters to dist/spmd components (#1037)
1 parent f7ab10a commit 597d743

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

torchx/components/dist.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def spmd(
9292
h: str = "gpu.small",
9393
j: str = "1x1",
9494
env: Optional[Dict[str, str]] = None,
95+
metadata: Optional[Dict[str, str]] = None,
9596
max_retries: int = 0,
9697
mounts: Optional[List[str]] = None,
9798
debug: bool = False,
@@ -131,6 +132,7 @@ def spmd(
131132
h: the type of host to run on (e.g. aws_p4d.24xlarge). Must be one of the registered named resources
132133
j: {nnodes}x{nproc_per_node}. For GPU hosts omitting nproc_per_node will infer it from the GPU count on the host
133134
env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
135+
metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3)
134136
max_retries: the number of scheduler retries allowed
135137
rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous.
136138
Only takes effect when running multi-node. When running single node, this parameter
@@ -153,6 +155,7 @@ def spmd(
153155
h=h,
154156
j=str(StructuredJArgument.parse_from(h, j)),
155157
env=env,
158+
metadata=metadata,
156159
max_retries=max_retries,
157160
mounts=mounts,
158161
debug=debug,
@@ -171,6 +174,7 @@ def ddp(
171174
memMB: int = 1024,
172175
j: str = "1x2",
173176
env: Optional[Dict[str, str]] = None,
177+
metadata: Optional[Dict[str, str]] = None,
174178
max_retries: int = 0,
175179
rdzv_port: int = 29500,
176180
rdzv_backend: str = "c10d",
@@ -203,6 +207,7 @@ def ddp(
203207
h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
204208
j: [{min_nnodes}:]{nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus
205209
env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
210+
metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3)
206211
max_retries: the number of scheduler retries allowed
207212
rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous.
208213
Only takes effect when running multi-node. When running single node, this parameter
@@ -238,8 +243,8 @@ def ddp(
238243
# use $$ in the prefix to escape the '$' literal (rather than a string Template substitution argument)
239244
rdzv_endpoint = _noquote(f"$${{{macros.rank0_env}:=localhost}}:{rdzv_port}")
240245

241-
if env is None:
242-
env = {}
246+
env = env or {}
247+
metadata = metadata or {}
243248

244249
argname = StructuredNameArgument.parse_from(
245250
name=name,
@@ -292,6 +297,7 @@ def ddp(
292297
resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
293298
args=["-c", _args_join(cmd)],
294299
env=env,
300+
metadata=metadata,
295301
port_map={
296302
"c10d": rdzv_port,
297303
},

torchx/components/test/dist_test.py

+16
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,14 @@ def test_ddp_debug(self) -> None:
4040
for k, v in _TORCH_DEBUG_FLAGS.items():
4141
self.assertEqual(env[k], v)
4242

43+
def test_ddp_metadata(self) -> None:
44+
metadata = {"key": "value"}
45+
app = ddp(script="foo.py", metadata=metadata)
46+
meta = app.roles[0].metadata
47+
for k, v in metadata.items():
48+
self.assertEqual(meta[k], v)
49+
self.assertEqual(len(metadata), len(meta))
50+
4351
def test_ddp_rdzv_backend_static(self) -> None:
4452
app = ddp(script="foo.py", rdzv_backend="static")
4553
cmd = app.roles[0].args[1]
@@ -53,6 +61,14 @@ def test_validate_spmd(self) -> None:
5361

5462
self.validate(dist, "ddp")
5563

64+
def test_spmd_metadata(self) -> None:
65+
metadata = {"key": "value"}
66+
app = spmd(script="foo.py", metadata=metadata)
67+
meta = app.roles[0].metadata
68+
for k, v in metadata.items():
69+
self.assertEqual(meta[k], v)
70+
self.assertEqual(len(metadata), len(meta))
71+
5672
def test_spmd_call_by_module_or_script_no_name(self) -> None:
5773
appdef = spmd(script="foo/bar.py")
5874
self.assertEqual("bar", appdef.name)

0 commit comments

Comments
 (0)