DDP hanging and timing out on H100s, not other hardware #3380

Open
EvanKomp opened this issue Feb 6, 2025 · 1 comment

Comments


EvanKomp commented Feb 6, 2025

Trace:

accelerate launch pipeline/2.1_self_supervised_training.py
/kfs2/projects/metalsitenn/metal_site_modeling/equiformer/nets/layer_norm.py:89: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  @torch.cuda.amp.autocast(enabled=False)
/kfs2/projects/metalsitenn/metal_site_modeling/equiformer/nets/layer_norm.py:89: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  @torch.cuda.amp.autocast(enabled=False)
x3100c0s5b0n0:3347432:3347432 [0] NCCL INFO Bootstrap : Using hsn0:10.150.3.12<0>
x3100c0s5b0n0:3347432:3347432 [0] NCCL INFO NET/Plugin : dlerror=libnccl-net.so: cannot open shared object file: No such file or directory No plugin found (libnccl-net.so), using internal implementation
x3100c0s5b0n0:3347433:3347433 [1] NCCL INFO cudaDriverVersion 12040
x3100c0s5b0n0:3347433:3347433 [1] NCCL INFO Bootstrap : Using hsn0:10.150.3.12<0>
x3100c0s5b0n0:3347433:3347433 [1] NCCL INFO NET/Plugin : dlerror=libnccl-net.so: cannot open shared object file: No such file or directory No plugin found (libnccl-net.so), using internal implementation
x3100c0s5b0n0:3347432:3347432 [0] NCCL INFO cudaDriverVersion 12040
NCCL version 2.20.5+cuda12.4
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO NET/IB : No device found.
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO NET/IB : No device found.
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO NET/Socket : Using [0]hsn0:10.150.3.12<0> [1]hsn1:10.150.1.122<0> [2]bond0:172.23.1.3<0>
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Using non-device net plugin version 0
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Using network Socket
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO NET/Socket : Using [0]hsn0:10.150.3.12<0> [1]hsn1:10.150.1.122<0> [2]bond0:172.23.1.3<0>
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Using non-device net plugin version 0
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Using network Socket
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO comm 0xaf7b270 rank 0 nranks 2 cudaDev 0 nvmlDev 0 busId 4000 commId 0x9d8f751b9e10c9be - Init START
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO comm 0xc0884c0 rank 1 nranks 2 cudaDev 1 nvmlDev 1 busId 64000 commId 0x9d8f751b9e10c9be - Init START
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Setting affinity for GPU 1 to 01
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO comm 0xaf7b270 rank 0 nRanks 2 nNodes 1 localRanks 2 localRank 0 MNNVL 0
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 00/08 :    0   1
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 01/08 :    0   1
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 02/08 :    0   1
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 03/08 :    0   1
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 04/08 :    0   1
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 05/08 :    0   1
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 06/08 :    0   1
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 07/08 :    0   1
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1 [1] 1/-1/-1->0->-1 [2] -1/-1/-1->0->1 [3] -1/-1/-1->0->1 [4] 1/-1/-1->0->-1 [5] 1/-1/-1->0->-1 [6] -1/-1/-1->0->1 [7] -1/-1/-1->0->1
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO P2P Chunksize set to 524288
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO comm 0xc0884c0 rank 1 nRanks 2 nNodes 1 localRanks 2 localRank 1 MNNVL 0
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Trees [0] -1/-1/-1->1->0 [1] -1/-1/-1->1->0 [2] 0/-1/-1->1->-1 [3] 0/-1/-1->1->-1 [4] -1/-1/-1->1->0 [5] -1/-1/-1->1->0 [6] 0/-1/-1->1->-1 [7] 0/-1/-1->1->-1
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO P2P Chunksize set to 524288
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Channel 00/0 : 1[1] -> 0[0] via P2P/CUMEM
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Channel 01/0 : 1[1] -> 0[0] via P2P/CUMEM
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Channel 02/0 : 1[1] -> 0[0] via P2P/CUMEM
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Channel 03/0 : 1[1] -> 0[0] via P2P/CUMEM
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Channel 04/0 : 1[1] -> 0[0] via P2P/CUMEM
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Channel 05/0 : 1[1] -> 0[0] via P2P/CUMEM
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Channel 06/0 : 1[1] -> 0[0] via P2P/CUMEM
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Channel 07/0 : 1[1] -> 0[0] via P2P/CUMEM
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 00/0 : 0[0] -> 1[1] via P2P/CUMEM
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 01/0 : 0[0] -> 1[1] via P2P/CUMEM
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 02/0 : 0[0] -> 1[1] via P2P/CUMEM
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 03/0 : 0[0] -> 1[1] via P2P/CUMEM
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 04/0 : 0[0] -> 1[1] via P2P/CUMEM
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 05/0 : 0[0] -> 1[1] via P2P/CUMEM
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 06/0 : 0[0] -> 1[1] via P2P/CUMEM
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Channel 07/0 : 0[0] -> 1[1] via P2P/CUMEM
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Connected all rings
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO Connected all trees
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 512 | 512
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Connected all rings
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO Connected all trees
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO 8 coll channels, 0 collnet channels, 0 nvls channels, 8 p2p channels, 8 p2p channels per peer
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 512 | 512
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO 8 coll channels, 0 collnet channels, 0 nvls channels, 8 p2p channels, 8 p2p channels per peer
x3100c0s5b0n0:3347433:3349752 [1] NCCL INFO comm 0xc0884c0 rank 1 nranks 2 cudaDev 1 nvmlDev 1 busId 64000 commId 0x9d8f751b9e10c9be - Init COMPLETE
x3100c0s5b0n0:3347432:3349753 [0] NCCL INFO comm 0xaf7b270 rank 0 nranks 2 cudaDev 0 nvmlDev 0 busId 4000 commId 0x9d8f751b9e10c9be - Init COMPLETE
/projects/proteinml/.links/miniconda3/envs/metal2/lib/python3.10/site-packages/sklearn/manifold/_t_sne.py:1164: FutureWarning: 'n_iter' was renamed to 'max_iter' in version 1.5 and will be removed in 1.7.
  warnings.warn(
 20%|████████████████████████████▌                                                                                                                  | 4/20 [00:08<00:41,  2.60s/it]/projects/proteinml/.links/miniconda3/envs/metal2/lib/python3.10/site-packages/dvc_render/vega.py:169: UserWarning: `generate_markdown` can only be used with `LinearTemplate`
  warn("`generate_markdown` can only be used with `LinearTemplate`")  # noqa: B028
 45%|████████████████████████████████████████████████████████████████▎                                                                              | 9/20 [00:18<00:13,  1.20s/it]/projects/proteinml/.links/miniconda3/envs/metal2/lib/python3.10/site-packages/dvc_render/vega.py:169: UserWarning: `generate_markdown` can only be used with `LinearTemplate`
  warn("`generate_markdown` can only be used with `LinearTemplate`")  # noqa: B028
/projects/proteinml/.links/miniconda3/envs/metal2/lib/python3.10/site-packages/sklearn/manifold/_t_sne.py:1164: FutureWarning: 'n_iter' was renamed to 'max_iter' in version 1.5 and will be removed in 1.7.
  warnings.warn(
 50%|███████████████████████████████████████████████████████████████████████                                                                       | 10/20 [00:32<00:51,  5.10s/it]

Eventually it times out:

Rank 1] Timeout at NCCL work: 456, last enqueued NCCL work: 456, last completed NCCL work: 455.
[rank1]:[E205 14:27:18.083900057 ProcessGroupNCCL.cpp:621] [Rank 1] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[rank1]:[E205 14:27:18.083904965 ProcessGroupNCCL.cpp:627] [Rank 1] To avoid data inconsistency, we are taking the entire process down.
[rank1]:[E205 14:27:19.451747711 ProcessGroupNCCL.cpp:1515] [PG 0 (default_pg) Rank 1] Process group watchdog thread terminated with exception: [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=456, OpType=ALLREDUCE, NumelIn=65297, NumelOut=65297, Timeout(ms)=600000) ran for 600054 milliseconds before timing out.
Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:609 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f64e3b3af86 in /projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x1d2 (0x7f64e4e378d2 in /projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x233 (0x7f64e4e3e313 in /projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f64e4e406fc in /projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0xd3b65 (0x7f65325d6b65 in /kfs2/projects/proteinml/.links/miniconda3/envs/metal/bin/../lib/libstdc++.so.6)
frame #5: <unknown function> + 0x81ca (0x7f653423f1ca in /lib64/libpthread.so.0)
frame #6: clone + 0x43 (0x7f6533721e73 in /lib64/libc.so.6)

terminate called after throwing an instance of 'c10::DistBackendError'
  what():  [PG 0 (default_pg) Rank 1] Process group watchdog thread terminated with exception: [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=456, OpType=ALLREDUCE, NumelIn=65297, NumelOut=65297, Timeout(ms)=600000) ran for 600054 milliseconds before timing out.
Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:609 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f64e3b3af86 in /projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x1d2 (0x7f64e4e378d2 in /projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x233 (0x7f64e4e3e313 in /projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f64e4e406fc in /projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0xd3b65 (0x7f65325d6b65 in /kfs2/projects/proteinml/.links/miniconda3/envs/metal/bin/../lib/libstdc++.so.6)
frame #5: <unknown function> + 0x81ca (0x7f653423f1ca in /lib64/libpthread.so.0)
frame #6: clone + 0x43 (0x7f6533721e73 in /lib64/libc.so.6)

Exception raised from ncclCommWatchdog at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1521 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f64e3b3af86 in /projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xe5aa84 (0x7f64e4ac9a84 in /projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0xd3b65 (0x7f65325d6b65 in /kfs2/projects/proteinml/.links/miniconda3/envs/metal/bin/../lib/libstdc++.so.6)
frame #3: <unknown function> + 0x81ca (0x7f653423f1ca in /lib64/libpthread.so.0)
frame #4: clone + 0x43 (0x7f6533721e73 in /lib64/libc.so.6)

W0205 14:27:28.725000 140287485417280 torch/distributed/elastic/multiprocessing/api.py:858] Sending process 1075899 closing signal SIGTERM
E0205 14:27:28.964000 140287485417280 torch/distributed/elastic/multiprocessing/api.py:833] failed (exitcode: -6) local_rank: 1 (pid: 1075900) of binary: /projects/proteinml/.links/miniconda3/envs/metal/bin/python3.10
Traceback (most recent call last):
  File "/projects/proteinml/.links/miniconda3/envs/metal/bin/accelerate", line 10, in <module>
    sys.exit(main())
  File "/projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 48, in main
    args.func(args)
  File "/projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1163, in launch_command
    multi_gpu_launcher(args)
  File "/projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/accelerate/commands/launch.py", line 792, in multi_gpu_launcher
    distrib_run.run(args)
  File "/projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/distributed/run.py", line 892, in run
    elastic_launch(
  File "/projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 133, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/projects/proteinml/.links/miniconda3/envs/metal/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
========================================================
pipeline/2.1_self_supervised_training.py FAILED
--------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
--------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2025-02-05_14:27:28
  host      : x3100c0s5b0n0.head.cm.kestrel.hpc.nrel.gov
  rank      : 1 (local_rank: 1)
  exitcode  : -6 (pid: 1075900)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 1075900
========================================================
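The watchdog message above already pins down what rank 1 was waiting on: collective number 456 (`last completed NCCL work: 455`), an ALLREDUCE of 65,297 elements, which ran for 600,054 ms against a 600,000 ms (10-minute) limit. Here is a minimal sketch for pulling those fields out of such a line, which helps when comparing several runs. The regular expression and the `parse_watchdog_line` helper are hypothetical. They are based only on the message format visible in this log, not on a documented PyTorch format:

```python
import re
from datetime import timedelta

# Pattern written against the WorkNCCL(...) text in the log above (assumed format).
WORK_RE = re.compile(
    r"WorkNCCL\(SeqNum=(?P<seq>\d+), OpType=(?P<op>\w+), "
    r"NumelIn=(?P<numel_in>\d+), NumelOut=(?P<numel_out>\d+), "
    r"Timeout\(ms\)=(?P<timeout_ms>\d+)\) ran for (?P<ran_ms>\d+) milliseconds"
)


def parse_watchdog_line(line: str) -> dict:
    """Extract the collective's sequence number, op type, size and timings."""
    m = WORK_RE.search(line)
    if m is None:
        raise ValueError("no WorkNCCL(...) record found in line")
    return {
        "seq": int(m["seq"]),
        "op": m["op"],
        "numel": int(m["numel_in"]),
        "timeout": timedelta(milliseconds=int(m["timeout_ms"])),
        "ran_for": timedelta(milliseconds=int(m["ran_ms"])),
    }


line = (
    "[Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=456, "
    "OpType=ALLREDUCE, NumelIn=65297, NumelOut=65297, Timeout(ms)=600000) "
    "ran for 600054 milliseconds before timing out."
)
info = parse_watchdog_line(line)
print(info["op"], info["seq"], info["numel"])  # ALLREDUCE 456 65297
print(info["timeout"], info["ran_for"])  # 0:10:00 0:10:00.054000
```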

The trainer I wrote to run this:

@dataclass
class EarlyStoppingState:
    """Tracks early stopping state."""
    counter: int = 0
    best_metric: float = float('inf')
    best_step: int = 0
    
    def state_dict(self) -> Dict[str, Any]:
        return asdict(self)
        
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.counter = state_dict['counter']
        self.best_metric = state_dict['best_metric']
        self.best_step = state_dict['best_step']

    def step(self, metric: float, current_step: int, min_improvement: float) -> bool:
        """Returns True if should stop."""
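        # best_metric starts at inf, and (inf - metric) / inf is nan, which never compares
        # greater than min_improvement; count the first call as an improvement instead
        if self.best_metric == float('inf'):
            improvement = float('inf')
        else:
            improvement = (self.best_metric - metric) / self.best_metric
        if improvement > min_improvement:
            self.counter = 0
            bad_step = False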
        else:
            bad_step = True
            self.counter += 1
            logger.info(f"Early stopping counter triggered: {self.counter}, best metric: {self.best_metric}, current metric: {metric}, improvement: {improvement}, min improvement: {min_improvement}")
        if metric < self.best_metric:
            self.best_metric = metric
            self.best_step = current_step
        return bad_step

@dataclass
class MetalSiteTrainingArgs:
    """Arguments for training."""
    output_dir: str = field(default="./training_output")
    logging_dir: str = field(default="./logs")
    
    # Training loop
    num_epochs: int = field(default=1)
    per_device_train_batch_size: int = field(default=8) 
    per_device_eval_batch_size: int = field(default=8)
    gradient_accumulation_steps: int = field(default=1)
    dataloader_num_workers: int = field(default=0)
    
    # Optimizer
    learning_rate: float = field(default=5e-5)
    weight_decay: float = field(default=0.0)
    gradient_clipping: float = field(default=1.0)
    warmup_pct: float = field(default=0.1)
    frac_noise_loss: float = field(default=0.5)
    
    # Logging and checkpoints
    eval_steps: int = field(default=None)
    logging_steps: int = field(default=100) 
    load_best_model_at_end: bool = field(default=True)
    
    # Early stopping
    use_early_stopping: bool = field(default=False)
    early_stopping_patience: int = field(default=3)
    early_stopping_improvement_fraction: float = field(default=0.0)

    def __str__(self):
        return str(asdict(self))

class MetalSiteTrainer:
    """Trainer for metal site models with distributed training support.
    
    Args
    ----
    model: nn.Module
        Model to train
    compute_loss_fn: Callable
        Function to compute loss. Signature should be:
            compute_loss_fn(trainer: MetalSiteTrainer, input_batch: Dict[str, torch.Tensor], return_outputs: bool = False) -> Dict[str, torch.Tensor]
            Must return a dict-like object with at least a 'loss' key.
            During evaluation, this is called with return_outputs=True to return model outputs for metrics.
    args: MetalSiteTrainingArgs
        Training arguments
    train_dataset: Dataset
        Training dataset
    eval_dataset: Dataset
        Evaluation dataset
    data_collator: Callable
        Data collator
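    eval_metrics: Optional[Dict[str, Callable]]
        Metrics to compute during evaluation. This is a dict of callables, each with signature
        f(trainer, outputs, batch), where outputs is the return value of compute_loss_fn and batch is the
        input batch. Non-None return values are averaged per process and then across processes. If None, only the loss is computed.
    hard_eval_metrics: Optional[Dict[str, Callable]]
        Metrics that require additional computation and are not directly returned by compute_loss_fn. These are called separately with the trainer as the only argument.
        It is up to you to loop through whatever dataset is needed to compute them.
    quit_early: bool
        If True, skip the evaluation before training and return after the first optimizer step.
    """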
    
    def __init__(
        self,
        model,
        compute_loss_fn: Callable,
        args: MetalSiteTrainingArgs,
        train_dataset=None,
        eval_dataset=None,
        data_collator=None,
        eval_metrics: Optional[Dict[str, Callable]]=None,
        hard_eval_metrics: Optional[Dict[str, Callable]]=None,
        quit_early: bool = False
    ):
        self.args = args
        self.model = model
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.data_collator = data_collator
        self.compute_loss_fn = compute_loss_fn
        self.eval_metrics = eval_metrics or {}
        
        # Initialize early stopping
        self.early_stopping = EarlyStoppingState() if args.use_early_stopping else None
        
        # Initialize accelerator
        ipgk = InitProcessGroupKwargs(timeout=timedelta(180))  # timedelta's first positional argument is days, so this is 180 days
        self.accelerator = Accelerator(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            log_with="dvclive",
            project_dir=args.output_dir,
            kwargs_handlers=[ipgk]
        )
        if self.accelerator.is_main_process:
            logger.info(f"Accelerator params: {self.accelerator.__dict__}")
        self.accelerator.init_trackers(project_name="training", init_kwargs={
            "dvclive": {
                "dir": os.path.join(args.output_dir, "dvclive"),
                "report": 'md',
                "save_dvc_exp": False,
                "dvcyaml": None
            }
        })
        
        if self.early_stopping:
            self.accelerator.register_for_checkpointing(self.early_stopping)

        # Create dataloaders
        self.train_dataloader = self._get_train_dataloader() if train_dataset else None
        self.eval_dataloader = self._get_eval_dataloader() if eval_dataset else None

        # Set up optimizer and scheduler   
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay
        )
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=args.learning_rate,
            epochs=args.num_epochs,
            steps_per_epoch=len(self.train_dataloader),
            pct_start=args.warmup_pct
        )

        # Prepare everything with accelerator
        prepared = self.accelerator.prepare(
            self.model,
            self.optimizer, 
            self.train_dataloader,
            self.eval_dataloader,
            self.scheduler
        )
        self.model, self.optimizer, self.train_dataloader, self.eval_dataloader, self.scheduler = prepared

        self.n_warmup_steps = args.warmup_pct * args.num_epochs * len(self.train_dataloader)

        # hard eval metrics
        self.hard_eval_metrics = hard_eval_metrics or {}

        # Create the checkpoint folder if not present
        if self.accelerator.is_main_process:
            if not os.path.exists(os.path.join(args.output_dir, "checkpoints")):
                os.makedirs(os.path.join(args.output_dir, "checkpoints"))
        self.quit_early = quit_early
        os.environ["NCCL_DEBUG"] = "INFO"

    def _get_train_dataloader(self) -> DataLoader:
        """Create training dataloader."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.args.per_device_train_batch_size,
            collate_fn=self.data_collator,
            num_workers=self.args.dataloader_num_workers,
            shuffle=True
        )

    def _get_eval_dataloader(self) -> DataLoader:
        """Create evaluation dataloader."""
        return DataLoader(
            self.eval_dataset,
            batch_size=self.args.per_device_eval_batch_size,
            collate_fn=self.data_collator,
            num_workers=self.args.dataloader_num_workers
        )
    
    def save_checkpoint(self, output_dir: str):
        """Save model checkpoint with dynamic parameter handling"""
        # Initialize dynamic params before saving
        dummy_batch = next(iter(self.train_dataloader))
        with torch.no_grad():
            self.model(**dummy_batch)
        
        self.accelerator.save_state(output_dir, safe_serialization=False)

    def load_checkpoint(self, checkpoint_dir: str):
        """Load checkpoint with dynamic parameter handling"""
        # Initialize dynamic params before loading
        dummy_batch = next(iter(self.train_dataloader))
        with torch.no_grad():
            self.model(**dummy_batch)
            
        self.accelerator.load_state(checkpoint_dir)

    def _cleanup_checkpoints(self):
        """Maintain only best checkpoint and last N checkpoints where N=patience."""
        if not self.early_stopping:
            return
            
        checkpoint_dir = os.path.join(self.args.output_dir, "checkpoints")
        checkpoints = sorted([
            int(f.split('_')[-1]) 
            for f in os.listdir(checkpoint_dir) 
            if f.startswith('step_')
        ])
        
        # Always keep best checkpoint
        checkpoints_to_keep = {self.early_stopping.best_step}
        
        # Keep last patience number of checkpoints
        patience_checkpoints = checkpoints[-self.args.early_stopping_patience:]
        checkpoints_to_keep.update(patience_checkpoints)
        
        # Remove others
        for step in checkpoints:
            if step not in checkpoints_to_keep:
                checkpoint_path = os.path.join(checkpoint_dir, f'step_{step}')
                if os.path.exists(checkpoint_path):
                    import shutil
                    shutil.rmtree(checkpoint_path)

    def evaluate(self) -> float:
        """Run evaluation and compute metrics over full dataset."""
        self.model.eval()
        total_loss = 0
        num_batches = 0
        
        # Initialize metric accumulators for each process
        process_metrics = {name: [] for name in self.eval_metrics.keys()}
        
        for batch in self.eval_dataloader:
            with torch.no_grad():
                outputs = self.compute_loss_fn(self, batch, return_outputs=True)
                loss = outputs["loss"]
                total_loss += loss.detach().float()
                
                # Compute metrics on each process separately
                if self.eval_metrics:
                    for name, func in self.eval_metrics.items():
                        metric_val = func(self, outputs, batch)
                        if metric_val is not None:
                            process_metrics[name].append(metric_val)
                            
            num_batches += 1

        # Gather and average loss across processes
        total_loss = self.accelerator.gather(total_loss).mean()
        num_batches = self.accelerator.gather(torch.tensor(num_batches, device=self.accelerator.device, dtype=torch.float)).mean()
        avg_loss = total_loss / num_batches

        # Average metrics for each process then gather
        metrics = {"eval/loss": avg_loss.cpu().item()}
        if self.eval_metrics:
            for name, values in process_metrics.items():
                if values:  # Only process if we have values
                    process_avg = torch.tensor(sum(values) / len(values), device=self.accelerator.device)
                    gathered_avgs = self.accelerator.gather(process_avg)
                    metrics[f"eval/{name}"] = gathered_avgs.mean().cpu().item()
                else:
                    metrics[f"eval/{name}"] = float('nan')
                    
        self.accelerator.log(metrics, step=self.global_step)

        # Run any hard metrics
        for name, func in self.hard_eval_metrics.items():
            func(self)
        
        self.model.train()
        torch.cuda.empty_cache()
        return avg_loss.item()

    def train(self, resume_from_checkpoint: Optional[str] = None):
        # Add global step tracking
        self.global_step = 0
        if resume_from_checkpoint:
            # Assuming checkpoint contains global step
            self.global_step = int(resume_from_checkpoint.split('_')[-1])
            self.accelerator.load_state(resume_from_checkpoint)
            logger.info(f"Resumed from checkpoint: {resume_from_checkpoint}")

        if self.accelerator.is_main_process:
            logger.info(
                f"Training with {self.accelerator.num_processes} processes on {self.accelerator.device.type}\n"
                f" - output_dir: {self.args.output_dir}\n"
                f" - examples in dataset: {len(self.train_dataset)}\n"
                f" - per device batch size: {self.args.per_device_train_batch_size}\n"
                f" - gradient accumulation steps: {self.args.gradient_accumulation_steps}\n"

                f" - effective batch size: {self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.accelerator.num_processes}\n"
                f" - total epochs: {self.args.num_epochs}\n"
                f" - steps per epoch: {len(self.train_dataloader)}\n"
                f" - total steps: {self.args.num_epochs * len(self.train_dataloader)}\n"
                f" - param updates per epoch: {len(self.train_dataloader) // self.args.gradient_accumulation_steps}\n"
                f" - warmup steps: {self.n_warmup_steps}\n"
                f" - log training loss every {self.args.logging_steps} steps\n"
                f" - eval and checkpoint every {self.args.eval_steps} steps\n"
                f" - total trainable parameters: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}"
            )

        # run eval before training
        if not self.quit_early:
            self.evaluate()
        
        # Training loop
        for epoch in range(self.args.num_epochs):
            self.model.train()
            total_loss = 0
            
            progress_bar = tqdm(
                self.train_dataloader,
                disable=not self.accelerator.is_local_main_process
            )

            for batch in progress_bar:
                with self.accelerator.accumulate(self.model):
                    outputs = self.compute_loss_fn(self, batch)
                    loss = outputs["loss"]
                    
                    self.accelerator.backward(loss)
                    
                    if self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(
                            self.model.parameters(),
                            self.args.gradient_clipping
                        )
                        self.optimizer.step()
                        self.optimizer.zero_grad()
                        self.scheduler.step()

                        if self.quit_early:
                            logger.info("Quitting early")
                            return

                    total_loss += loss.detach().float()

                # Increment global step
                self.global_step += 1

                # Log training metrics
                if self.global_step > 0 and self.global_step % self.args.logging_steps == 0:
                    avg_loss = total_loss / self.args.logging_steps
                    self.accelerator.log({
                        "train/loss": avg_loss.item(),
                        "train/epoch": epoch,
                        "train/global_step": self.global_step,
                        "train/learning_rate": self.optimizer.param_groups[0]["lr"]
                    }, step=self.global_step)
                    total_loss = 0

                # Evaluate and checkpoint if needed
                if (
                    self.args.eval_steps 
                    and self.global_step > 0 
                    and self.global_step % self.args.eval_steps == 0
                ):
                    eval_loss = self.evaluate()
                    self.model.train()

                    # Save checkpoint
                    if self.accelerator.is_main_process:
                        output_dir = os.path.join(
                            self.args.output_dir,
                            "checkpoints",
                            f"step_{self.global_step}"
                        )
                        self.save_checkpoint(output_dir)
                        self._cleanup_checkpoints()

                        if self.early_stopping:
                            should_stop = self.early_stopping.step(
                                eval_loss,
                                self.global_step,
                                self.args.early_stopping_improvement_fraction
                            )
                            if (should_stop and 
                                self.early_stopping.counter >= self.args.early_stopping_patience):
                                if self.global_step > self.n_warmup_steps:
                                    logger.info("Early stopping triggered")
                                    self._finish_up()
                                    return
                        
        # Finish up
        if self.accelerator.is_main_process:
            self._finish_up()


    def _finish_up(self):
        output_dir = os.path.join(
            self.args.output_dir,
            "checkpoints",
            f"step_{self.global_step}"
        )
        self.save_checkpoint(output_dir)

        if self.args.load_best_model_at_end and self.early_stopping and self.early_stopping.best_step > 0:
            best_model_path = os.path.join(
                self.args.output_dir,
                "checkpoints",
                f"step_{self.early_stopping.best_step}"
            )
            logger.info(f"Loading best model from step {self.early_stopping.best_step}")
            self.load_checkpoint(best_model_path)
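Separately from the hang, the relative-improvement formula in `EarlyStoppingState.step` needs care while `best_metric` still holds its initial `float('inf')`. In that state, `(inf - metric) / inf` is `nan`, and any comparison with `nan` is `False`. Unless that case is handled separately, the first evaluation is counted as a bad step. Here is a quick check of the arithmetic, using a hypothetical first loss value and the default `early_stopping_improvement_fraction` of 0.0:

```python
import math

best_metric = float("inf")  # EarlyStoppingState's initial best_metric
metric = 0.85  # hypothetical first evaluation loss
min_improvement = 0.0  # default early_stopping_improvement_fraction

improvement = (best_metric - metric) / best_metric  # inf / inf
print(improvement)  # nan
print(improvement > min_improvement)  # False, so this step counts as "bad"
```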

My accelerate config:

compute_environment: LOCAL_MACHINE
debug: true
distributed_type: MULTI_GPU
downcast_bf16: 'no'
enable_cpu_affinity: false
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Some notes:

  1. The script runs on a single H100.
  2. The script runs with DDP on 2 A40s; it only fails with DDP on H100s.
  3. When I step through line by line with the debugger, the hang happens at the first accelerator.sync_gradients after evaluation. That is, the evaluation loop runs successfully at the start of training, some training steps run, and the first evaluation during training also runs successfully. The process then hangs when it tries to take the next training step.
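Note 3 puts the hang at the first gradient sync after an in-training evaluation. In the trainer above, that evaluation is followed by a block that runs only under `if self.accelerator.is_main_process:`. Inside that block, `save_checkpoint` iterates the prepared train dataloader, runs a forward pass through the prepared model and calls `accelerator.save_state`. The early-stopping decision, including its `return`, is also taken on rank 0 alone.

Under DDP, every rank has to issue the same collectives in the same order. Some of that rank-0-only code could issue one: a DDP forward pass can broadcast buffers from rank 0, and iterating a prepared dataloader may synchronize RNG state, depending on versions. If so, the other rank is left waiting in its next all-reduce until the watchdog fires. That is a hypothesis worth ruling out, not a confirmed diagnosis.

The sketch below shows the general pattern with plain `torch.distributed`. The helper names are hypothetical. Every rank calls the collective, and only the decision itself comes from rank 0. Without launcher environment variables, it falls back to a single-process gloo group so it can run standalone.

```python
import os
from datetime import timedelta

import torch
import torch.distributed as dist


def init_group() -> torch.device:
    """Join the group described by launcher env vars, or a one-process CPU group."""
    if "RANK" in os.environ:  # set by torchrun, which `accelerate launch` uses for multi-GPU
        if torch.cuda.is_available():
            local_rank = int(os.environ.get("LOCAL_RANK", "0"))
            torch.cuda.set_device(local_rank)
            dist.init_process_group("nccl", timeout=timedelta(minutes=2))
            return torch.device("cuda", local_rank)
        dist.init_process_group("gloo", timeout=timedelta(minutes=2))
        return torch.device("cpu")
    # Standalone fallback so the sketch runs without a launcher.
    dist.init_process_group("gloo", store=dist.HashStore(), rank=0, world_size=1)
    return torch.device("cpu")


def broadcast_stop_decision(should_stop: bool, device: torch.device) -> bool:
    """Hypothetical helper: share rank 0's stop decision with every rank.

    broadcast is a collective, so every rank must call it; only rank 0's
    value matters, and the other ranks' inputs are overwritten.
    """
    flag = torch.tensor([int(should_stop)], dtype=torch.int32, device=device)
    dist.broadcast(flag, src=0)
    return bool(flag.item())


def main() -> bool:
    device = init_group()
    # Rank-0-only work (logging, deciding) is fine as long as it issues no collectives.
    decision = dist.get_rank() == 0  # stand-in for rank 0's early-stopping check
    # Every rank joins the broadcast, so all ranks leave the loop at the same step.
    stop = broadcast_stop_decision(decision, device)
    dist.barrier()  # reached by every rank, so no rank is left waiting
    dist.destroy_process_group()
    return stop


if __name__ == "__main__":
    print(main())
```

Applied to the trainer, the same idea would mean running anything that might issue a collective on every rank. Only purely local work, such as logging or deleting old checkpoint folders, would stay under `is_main_process`. This is a direction to test rather than a verified fix.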

This is probably a lower-level issue than Accelerate, but I'm not sure how to dig further and determine what it is. Any help is appreciated.
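To get below Accelerate and the model, one option is a standalone all-reduce loop on the same pair of H100s. The hypothetical script below (call it `allreduce_check.py`) repeatedly all-reduces a 65,297-element tensor, the size of the ALLREDUCE that timed out above, and checks the result. It uses a 2-minute process-group timeout so a hang surfaces faster than the 10 minutes in the log. If this also hangs on the H100s, the problem sits below the training code. If it passes, the training loop becomes the more likely suspect, although a short passing test does not rule out every NCCL problem.

```python
import os
from datetime import timedelta

import torch
import torch.distributed as dist


def init_group() -> torch.device:
    """Join the group torchrun describes, or a one-process CPU group when run directly."""
    if "RANK" in os.environ and torch.cuda.is_available():
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        torch.cuda.set_device(local_rank)
        dist.init_process_group("nccl", timeout=timedelta(minutes=2))
        return torch.device("cuda", local_rank)
    if "RANK" in os.environ:  # launched, but no GPUs visible
        dist.init_process_group("gloo", timeout=timedelta(minutes=2))
        return torch.device("cpu")
    # Standalone fallback (no launcher): single process, in-memory store.
    dist.init_process_group("gloo", store=dist.HashStore(), rank=0, world_size=1)
    return torch.device("cpu")


def allreduce_check(device: torch.device, numel: int = 65297, iters: int = 50) -> bool:
    """Sum tensors of ones across ranks; every element should equal the world size."""
    world_size = dist.get_world_size()
    for _ in range(iters):
        t = torch.ones(numel, device=device)
        dist.all_reduce(t)  # default reduction op is SUM
        if not bool(torch.all(t == world_size)):
            return False
    return True


def main() -> bool:
    device = init_group()
    ok = allreduce_check(device)
    print(f"rank {dist.get_rank()}/{dist.get_world_size()} on {device}: all_reduce ok={ok}")
    dist.destroy_process_group()
    return ok


if __name__ == "__main__":
    main()
```

A possible invocation on the H100 node is `NCCL_DEBUG=INFO torchrun --nproc_per_node=2 allreduce_check.py`. If it hangs, rerunning with `NCCL_P2P_DISABLE=1` (a standard NCCL environment variable) is one way to check whether the GPU peer-to-peer path reported in the log (`via P2P/CUMEM`) is involved.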

@KeshavSingh29

This is usually an issue with how NCCL and the connection between the two nodes or GPUs are set up.
My guess is that the GPUs cannot sync gradients because of a network timeout, or because of either a slow connection (low probability) or a version mismatch (high probability).

Are both GPUs on the same node?
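On the same-node question, the NCCL log above already reports `nRanks 2 nNodes 1 localRanks 2` and `via P2P/CUMEM` channels, which indicates both ranks ran on a single node. Here is a hedged sketch for double-checking from PyTorch which GPUs a process sees and whether the driver reports peer access between each pair. `peer_access_report` is a hypothetical helper. It only queries the driver and does not exercise NCCL:

```python
import itertools

import torch


def peer_access_report() -> list[tuple[int, int, bool]]:
    """Return (src, dst, can_access) for every ordered pair of visible GPUs."""
    n = torch.cuda.device_count() if torch.cuda.is_available() else 0
    for i in range(n):
        print(f"cuda:{i} {torch.cuda.get_device_name(i)}")
    report = []
    for src, dst in itertools.permutations(range(n), 2):
        report.append((src, dst, torch.cuda.can_device_access_peer(src, dst)))
    return report


if __name__ == "__main__":
    for src, dst, ok in peer_access_report():
        print(f"cuda:{src} -> cuda:{dst}: peer access {'yes' if ok else 'no'}")
```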
