
torch.distributed.breakpoint(rank=1) hangs because of --local-ranks-filter 0 #652

Open
@weifengpy

Description


I was debugging on rank 1 using torch.distributed.breakpoint(rank=1), but it always hangs. It turns out the hang is caused by --local-ranks-filter 0 in run_llama_train.sh. Not sure if we want to remind people that these two don't work well together.

I have to debug rank 1 (instead of rank 0) because dim-0 sharding can be uneven and only ranks 1+ have padding.

repro:

diff --git a/train.py b/train.py
index 7945949..e2843b2 100644
--- a/train.py
+++ b/train.py
@@ -64,6 +64,8 @@ def main(job_config: JobConfig):
     device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
     torch.cuda.set_device(device)
     utils.init_distributed(job_config)
+
+    torch.distributed.breakpoint(rank=1)
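
For context, here is a minimal sketch (not the actual PyTorch source) of what every rank roughly does inside torch.distributed.breakpoint(rank=1); it shows why the combination above looks like a hang rather than an error:

import pdb

import torch.distributed as dist

def breakpoint_sketch(rank: int = 0) -> None:
    # Every rank calls this; only the chosen rank gets the pdb prompt.
    if dist.get_rank() == rank:
        pdb.set_trace()
    # All other ranks block on this collective until the chosen rank
    # finishes its pdb session and reaches the same barrier.
    dist.barrier()

With --local-ranks-filter 0, torchrun only surfaces rank 0's output, so rank 1's pdb prompt is never shown and the session cannot be driven; every rank then sits in the barrier, which from the console looks like a hang. A workaround while debugging (assuming run_llama_train.sh forwards a LOG_RANK environment variable to torchrun's --local-ranks-filter) is to include rank 1 in the filter, e.g. LOG_RANK=0,1 ./run_llama_train.sh, or to drop the flag temporarily.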

Labels: documentation (Improvements or additions to documentation)
