Skip to content

Commit 2c05d47

Browse files
committed
New hpset for dataparallel training
1 parent 20ae187 commit 2c05d47

File tree

3 files changed

+25
-6
lines changed

3 files changed

+25
-6
lines changed

hyper_params.py

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -57,8 +57,7 @@ def __init__(self):
5757

5858
# Data parallel
5959
self.rank = 0
60-
self.parallelism = 1
61-
self.discovery_port = 29500
60+
self.parallelism = 1 # Number of data parallel processes. Must be set explicitly when using schedule.py, otherwise runner.py will just spawn a single process.
6261

6362
# Observations
6463
self.obs_allies = 10 # Max number of allied drones returned by the env
@@ -228,6 +227,20 @@ def standard():
228227

229228
return hps
230229

230+
@staticmethod
231+
def standard_dataparallel():
232+
hps = HyperParams.standard()
233+
234+
hps.steps = 300e6
235+
236+
hps.adr_hstepsize = 2e-06
237+
hps.batches_per_update = 16
238+
hps.num_envs = 64
239+
hps.num_self_play = 32
240+
hps.lr = 0.0003
241+
hps.final_lr = 0.00003
242+
243+
return hps
231244

232245
@staticmethod
233246
def arena():

main.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1003,6 +1003,8 @@ def main():
10031003
hps = HyperParams.arena()
10041004
elif args.hpset == 'standard':
10051005
hps = HyperParams.standard()
1006+
elif args.hpset == 'standard_dataparallel':
1007+
hps = HyperParams.standard_dataparallel()
10061008
elif args.hpset == 'micro_practice':
10071009
hps = HyperParams.micro_practice()
10081010
elif args.hpset == 'scout':
@@ -1025,8 +1027,6 @@ def main():
10251027
hps.objective = envs.Objective(hps.objective)
10261028

10271029
if hps.parallelism > 1:
1028-
os.environ['MASTER_ADDR'] = 'localhost'
1029-
os.environ['MASTER_PORT'] = str(hps.discovery_port)
10301030
dist.init_process_group(backend='gloo', rank=hps.rank, world_size=hps.parallelism)
10311031

10321032
if hps.rank == 0:

runner.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,12 @@ def git(args, workdir=dir):
129129
with open(logpath, "w+") as outfile:
130130
retcode = subprocess.call(
131131
["python3", "main.py", "--out-dir", out_dir] + args,
132-
env=dict(os.environ, CUDA_VISIBLE_DEVICES=str(job.device)),
132+
env=dict(
133+
os.environ,
134+
CUDA_VISIBLE_DEVICES=str(job.device),
135+
MASTER_ADDR='localhost',
136+
MASTER_PORT=str(job.discovery_port),
137+
),
133138
stdout=outfile, stderr=outfile, cwd=dir
134139
)
135140
if retcode != 0:
@@ -197,13 +202,14 @@ def __init__(self, repo_path, revision, params, handle, parallelism):
197202
self.parallelism = parallelism
198203
self.descriptor = "-".join([revision[:6]] + [f'{k}{v}' for k, v in params.items()])
199204
self.rank = 0
205+
self.discovery_port = None
200206

201207
def set_device(self, device, rank, discovery_port):
202208
self.device = device
203209
self.rank = rank
210+
self.discovery_port = discovery_port
204211
self.params['device'] = device
205212
self.params['rank'] = rank
206-
self.params['discovery_port'] = discovery_port
207213

208214

209215
@click.command()

0 commit comments

Comments
 (0)