I was debugging on rank 1 using `torch.distributed.breakpoint(rank=1)`, but it always hangs. It turns out to be caused by `--local-ranks-filter 0` in `run_llama_train.sh`. Not sure if we want to remind people that these two things don't work well together.
I have to debug rank 1 (instead of rank 0) because dim-0 sharding can be uneven, and only ranks 1+ can have padding.
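
To make the uneven-sharding point concrete, here is a small illustrative sketch (not from the issue; it just mimics the `torch.chunk`-style ceil-division splitting that DTensor's `Shard(0)` follows), showing that rank 0 always gets a full chunk while trailing ranks may need padding:

```python
# Illustrative only: per-rank shard sizes for torch.chunk-style dim-0 sharding.
global_len, world_size = 10, 4
full_chunk = -(-global_len // world_size)  # ceil(10 / 4) = 3

for rank in range(world_size):
    start = min(rank * full_chunk, global_len)
    end = min(start + full_chunk, global_len)
    local = end - start
    print(f"rank {rank}: local size {local}, padding {full_chunk - local}")

# rank 0: local size 3, padding 0
# rank 1: local size 3, padding 0
# rank 2: local size 3, padding 0
# rank 3: local size 1, padding 2   <- only ranks 1+ can end up padded
```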
Repro:
```diff
diff --git a/train.py b/train.py
index 7945949..e2843b2 100644
--- a/train.py
+++ b/train.py
@@ -64,6 +64,8 @@ def main(job_config: JobConfig):
     device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
     torch.cuda.set_device(device)
     utils.init_distributed(job_config)
+
+    torch.distributed.breakpoint(rank=1)
```
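
For reference, a minimal standalone sketch of the same interaction outside torchtitan (the filename and launch commands below are illustrative, not from the issue):

```python
# repro_breakpoint.py -- hypothetical filename for this sketch.
# Launch (flags as in recent torchrun versions):
#   torchrun --nproc-per-node 2 --local-ranks-filter 0 repro_breakpoint.py  # looks hung
#   torchrun --nproc-per-node 2 repro_breakpoint.py                         # pdb prompt from rank 1
import torch.distributed as dist


def main() -> None:
    # gloo keeps the sketch CPU-only; torchtitan itself uses nccl.
    dist.init_process_group(backend="gloo")

    # Rank 1 drops into pdb here while the other ranks wait. With
    # --local-ranks-filter 0, rank 1's console output (including the pdb
    # prompt) is filtered out, so the run appears to hang.
    dist.breakpoint(rank=1)

    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```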