The following three files showcase how to run distributed model training with DDP (distributed data parallelism) using the `@torchrun` and `@kubernetes` decorators on Metaflow.
- `datautils.py` is adapted from here.
- `multinode_trainer.py` is adapted from here.
- `flow.py` runs the training script via `current.torch.run(entrypoint="multinode_trainer.py")`, passing entrypoint arguments such as the total number of epochs and the batch size.
- The flow can be run with `python flow.py run`.
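Conceptually, a launcher like `current.torch.run` starts `torchrun` on every node with rendezvous settings (node count, node rank, master address) plus the script's own arguments. The helper below is a hypothetical sketch of that command construction, not Metaflow's actual implementation; the function name, parameters, and defaults are assumptions for illustration.

```python
import shlex


def build_torchrun_command(entrypoint, entrypoint_args, nnodes, node_rank,
                           master_addr, master_port=29500):
    """Sketch of the torchrun command line each node would execute.

    Hypothetical helper illustrating what a multi-node launcher does
    under the hood; argument names and defaults are assumptions.
    """
    cmd = [
        "torchrun",
        f"--nnodes={nnodes}",
        f"--node_rank={node_rank}",
        f"--master_addr={master_addr}",
        f"--master_port={master_port}",
        entrypoint,
    ]
    # Flatten entrypoint args like {"epochs": 10} into ["--epochs", "10"]
    # so they are passed through to the training script itself.
    for key, value in entrypoint_args.items():
        cmd += [f"--{key}", str(value)]
    return shlex.join(cmd)


# Example: the command node 0 of a 2-node job might run.
cmd = build_torchrun_command(
    "multinode_trainer.py",
    {"epochs": 10, "batch-size": 32},
    nnodes=2,
    node_rank=0,
    master_addr="10.0.0.1",
)
print(cmd)
```

Each participating node runs the same command with only `--node_rank` differing; `torchrun` then spawns one worker per GPU and wires up the DDP process group.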