@@ -92,6 +92,7 @@ def spmd(
92
92
h : str = "gpu.small" ,
93
93
j : str = "1x1" ,
94
94
env : Optional [Dict [str , str ]] = None ,
95
+ metadata : Optional [Dict [str , str ]] = None ,
95
96
max_retries : int = 0 ,
96
97
mounts : Optional [List [str ]] = None ,
97
98
debug : bool = False ,
@@ -131,6 +132,7 @@ def spmd(
131
132
h: the type of host to run on (e.g. aws_p4d.24xlarge). Must be one of the registered named resources
132
133
j: {nnodes}x{nproc_per_node}. For GPU hosts omitting nproc_per_node will infer it from the GPU count on the host
133
134
env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
135
+ metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3)
134
136
max_retries: the number of scheduler retries allowed
135
137
rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous.
136
138
Only takes effect when running multi-node. When running single node, this parameter
@@ -153,6 +155,7 @@ def spmd(
153
155
h = h ,
154
156
j = str (StructuredJArgument .parse_from (h , j )),
155
157
env = env ,
158
+ metadata = metadata ,
156
159
max_retries = max_retries ,
157
160
mounts = mounts ,
158
161
debug = debug ,
@@ -171,6 +174,7 @@ def ddp(
171
174
memMB : int = 1024 ,
172
175
j : str = "1x2" ,
173
176
env : Optional [Dict [str , str ]] = None ,
177
+ metadata : Optional [Dict [str , str ]] = None ,
174
178
max_retries : int = 0 ,
175
179
rdzv_port : int = 29500 ,
176
180
rdzv_backend : str = "c10d" ,
@@ -203,6 +207,7 @@ def ddp(
203
207
h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
204
208
j: [{min_nnodes}:]{nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus
205
209
env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
210
+ metadata: metadata to be passed to the scheduler (e.g. KEY1=v1,KEY2=v2,KEY3=v3)
206
211
max_retries: the number of scheduler retries allowed
207
212
rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous.
208
213
Only takes effect when running multi-node. When running single node, this parameter
@@ -238,8 +243,8 @@ def ddp(
238
243
# use $$ in the prefix to escape the '$' literal (rather than a string Template substitution argument)
239
244
rdzv_endpoint = _noquote (f"$${{{ macros .rank0_env } :=localhost}}:{ rdzv_port } " )
240
245
241
- if env is None :
242
- env = {}
246
+ env = env or {}
247
+ metadata = metadata or {}
243
248
244
249
argname = StructuredNameArgument .parse_from (
245
250
name = name ,
@@ -292,6 +297,7 @@ def ddp(
292
297
resource = specs .resource (cpu = cpu , gpu = gpu , memMB = memMB , h = h ),
293
298
args = ["-c" , _args_join (cmd )],
294
299
env = env ,
300
+ metadata = metadata ,
295
301
port_map = {
296
302
"c10d" : rdzv_port ,
297
303
},
0 commit comments