From 23f210815e06147cf95fff3fede4641cc1c43101 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Mon, 10 Feb 2025 16:33:23 +0000 Subject: [PATCH] fix stream not sync --- src/nanotron/constants.py | 4 + src/nanotron/models/llama.py | 46 ++++---- src/nanotron/parallel/comm.py | 15 ++- src/nanotron/parallel/dependency.py | 102 ++++++++++++++++++ .../distributed_differentiable_primitives.py | 4 + src/nanotron/parallel/tensor_parallel/nn.py | 8 +- src/nanotron/trainer.py | 17 ++- 7 files changed, 172 insertions(+), 24 deletions(-) create mode 100644 src/nanotron/parallel/dependency.py diff --git a/src/nanotron/constants.py b/src/nanotron/constants.py index 78fd0bb9..3fe440a8 100644 --- a/src/nanotron/constants.py +++ b/src/nanotron/constants.py @@ -13,3 +13,7 @@ CUDA_STREAMS = {} + +CLOCK = 0 +_AUTOGRAD_RUNS = [] +_NOT_BWD_ASYNC_OPS = [] diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index acbece96..72ebf478 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -245,13 +245,15 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, - async_all_reduce=parallel_config.domino.num_input_batches > 1, + # async_all_reduce=parallel_config.domino.num_input_batches > 1, ) self.split_silu_mul = GLUActivation(config.hidden_act) def forward(self, hidden_states, handle_idx=None): # [seq_length, batch_size, hidden_dim] merged_states = self.gate_up_proj(hidden_states, async_all_reduce=True, handle_idx=handle_idx) - hidden_states, work = self.down_proj(self.split_silu_mul(merged_states), handle_idx) + hidden_states, work = self.down_proj( + self.split_silu_mul(merged_states), async_all_reduce=True, handle_idx=handle_idx + ) return {"hidden_states": hidden_states, "work": work} @@ -428,7 +430,7 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication, - async_all_reduce=async_all_reduce, + # async_all_reduce=async_all_reduce, ) self.attention = CoreAttention( @@ -699,7 +701,7 @@ def forward( attention_output = ( attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1) ) - output, work = self.o_proj(attention_output, handle_idx=handle_idx) + output, work = self.o_proj(attention_output, async_all_reduce=True, handle_idx=handle_idx) return {"hidden_states": output, "work": work, "sequence_mask": sequence_mask} @@ -876,6 +878,7 @@ def _core_forward( comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()] with torch.cuda.stream(comm_stream): attn_output0["work"].wait() + attn_output0["work"].is_completed() hidden_states0 = attn_output0["hidden_states"] + residual0 residual0 = hidden_states0 @@ -890,8 +893,16 @@ def _core_forward( handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0), ) + # attn_output1["hidden_states"], mlp_output0["hidden_states"] = depend( + # run_after=attn_output1["hidden_states"], + # run_before=mlp_output0["hidden_states"] + # ) + with torch.cuda.stream(comm_stream): attn_output1["work"].wait() + attn_output1["work"].is_completed() + + torch.cuda.current_stream().wait_stream(comm_stream) hidden_states1 = attn_output1["hidden_states"] + residual1 residual1 = hidden_states1 @@ -906,11 +917,24 @@ def _core_forward( mlp_output0["work"].wait() mlp_output1["work"].wait() + mlp_output0["work"].is_completed() + mlp_output1["work"].is_completed() + + torch.cuda.current_stream().wait_stream(comm_stream) + hidden_states0 = mlp_output0["hidden_states"] + residual0 
hidden_states1 = mlp_output1["hidden_states"] + residual1 + # hidden_states0, hidden_states1 = depend(run_after=hidden_states0, run_before=hidden_states1) + hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) assert 1 == 1 + + # assert attn_output0["work"].is_completed() + # assert attn_output1["work"].is_completed() + # assert mlp_output0["work"].is_completed() + # assert mlp_output1["work"].is_completed() + return hidden_states, orig_sequence_mask def _checkpointed_forward( @@ -1080,23 +1104,9 @@ def forward_with_hidden_states( "sequence_mask": input_mask, } - # assert 1 == 1 - # num_input_batches = self.parallel_config.domino.num_input_batches - # hidden_encoder_states["hidden_states"] = torch.chunk(hidden_encoder_states["hidden_states"], chunks=num_input_batches, dim=1) - # hidden_encoder_states["sequence_mask"] = torch.chunk(hidden_encoder_states["sequence_mask"], chunks=num_input_batches, dim=0) - - # # Combine the chunks into a list of dictionaries - # hidden_encoder_states_list = [ - # {"hidden_states": hidden_encoder_states["hidden_states"][i], "sequence_mask": hidden_encoder_states["sequence_mask"][i]} - # for i in range(num_input_batches) - # ] - for encoder_block in self.decoder: hidden_encoder_states = encoder_block(**hidden_encoder_states) - # for hidden_encoder_states in hidden_encoder_states_list: - # hidden_encoder_states = encoder_block(**hidden_encoder_states) - hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] sharded_logits = self.lm_head(x=hidden_states)["logits"] diff --git a/src/nanotron/parallel/comm.py b/src/nanotron/parallel/comm.py index b00f6e9e..789416c3 100644 --- a/src/nanotron/parallel/comm.py +++ b/src/nanotron/parallel/comm.py @@ -33,6 +33,7 @@ class AsyncCommBucket: """ _async_op: Dict[int, "dist.Work"] = {} + _copy_async_op: Dict[int, "dist.Work"] = {} @staticmethod def add(tensor_id: int, work: "dist.Work"): @@ -40,6 +41,7 @@ def add(tensor_id: int, work: "dist.Work"): tensor_id not in AsyncCommBucket._async_op ), f"tensor_id: {tensor_id}, keys: {AsyncCommBucket._async_op.keys()}" AsyncCommBucket._async_op[tensor_id] = work + AsyncCommBucket._copy_async_op[tensor_id] = work @staticmethod def get(tensor_id: int): @@ -58,6 +60,7 @@ def wait(tensor_id: int): @staticmethod def clear_all(): AsyncCommBucket._async_op.clear() + AsyncCommBucket._copy_async_op.clear() def is_async_comm(x): @@ -92,9 +95,19 @@ def backward(ctx, grad_output): assert 1 == 1 if is_async_comm(ctx.wait_handle_idx): + from nanotron.constants import _AUTOGRAD_RUNS + + _AUTOGRAD_RUNS.append(f"wait_{ctx.wait_handle_idx}") handle = AsyncCommBucket.pop(ctx.wait_handle_idx) assert handle is not None handle.wait() - # assert handle.is_completed() is True + # assert handle.is_completed() is True, f"ctx.wait_handle_idx: {ctx.wait_handle_idx}" + else: + + from nanotron import constants + + # if dist.get_rank() == 0: + # constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx) + constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx) return grad_output, None diff --git a/src/nanotron/parallel/dependency.py b/src/nanotron/parallel/dependency.py new file mode 100644 index 00000000..6a633d8a --- /dev/null +++ b/src/nanotron/parallel/dependency.py @@ -0,0 +1,102 @@ +from typing import Dict, Tuple + +import torch +from torch import Tensor + +_phonies: Dict[Tuple[torch.device, bool], Tensor] = {} + + +def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor: + """Gets a phony. Phony is tensor without space. 
It is useful to make + arbitrary dependency in a autograd graph because it doesn't require any + gradient accumulation. + + .. note:: + + Phonies for each device are cached. If an autograd function gets a phony + internally, the phony must be detached to be returned. Otherwise, the + autograd engine will mutate the cached phony in-place:: + + class Phonify(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + phony = get_phony(input.device, requires_grad=False) + return phony.detach() # detach() is necessary. + + """ + key = (device, requires_grad) + + try: + phony = _phonies[key] + except KeyError: + with torch.cuda.stream(torch.cuda.default_stream(device)): + phony = torch.empty(0, device=device, requires_grad=requires_grad) + + _phonies[key] = phony + + return phony + + +def fork(input: Tensor) -> Tuple[Tensor, Tensor]: + """Branches out from an autograd lane of the given tensor.""" + if torch.is_grad_enabled() and input.requires_grad: + input, phony = Fork.apply(input) + else: + phony = get_phony(input.device, requires_grad=False) + + return input, phony + + +class Fork(torch.autograd.Function): + @staticmethod + def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore + phony = get_phony(input.device, requires_grad=False) + return input, phony.detach() + + @staticmethod + def backward(ctx: "Fork", grad_input: Tensor, grad_grad: Tensor) -> Tensor: # type: ignore + # import pydevd + # pydevd.settrace(suspend=False, trace_only_current_thread=True) + return grad_input + + +def join(input: Tensor, phony: Tensor) -> Tensor: + """Merges two autograd lanes.""" + if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad): + input = Join.apply(input, phony) + + return input + + +class Join(torch.autograd.Function): + @staticmethod + def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor: # type: ignore + return input + + @staticmethod + def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]: # type: ignore + # import pydevd + # pydevd.settrace(suspend=False, trace_only_current_thread=True) + return grad_input, None + + +# def depend(fork_from, join_to) -> None: +# # Ensure that batches[i-1] is executed after batches[i] in +# # # backpropagation by an explicit dependency. +# # if i != 0: +# # depend(batches[i-1], batches[i]) +# # depend(run_after, run_before) +# fork_from, phony = fork(fork_from) +# join_to = join(join_to, phony) +# return fork_from, join_to + + +def depend(run_after, run_before) -> None: + # Ensure that batches[i-1] is executed after batches[i] in + # # backpropagation by an explicit dependency. + # if i != 0: + # depend(batches[i-1], batches[i]) + # depend(run_after, run_before) + run_after, phony = fork(run_after) + run_before = join(run_before, phony) + return run_after, run_before diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index c4f69c05..58275368 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -58,6 +58,10 @@ def backward(ctx, grad_output): if handle_idx is not None and "bwd." 
in handle_idx and async_all_reduce is True: assert 1 == 1 + from nanotron.constants import _AUTOGRAD_RUNS + + _AUTOGRAD_RUNS.append(handle_idx) + return DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, handle_idx), None, None, None diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 847454fd..4fea1838 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -115,7 +115,7 @@ def __init__( device=None, dtype=None, async_communication: bool = False, - async_all_reduce: bool = False, + # async_all_reduce: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, ): self.pg = pg @@ -138,7 +138,7 @@ def __init__( ) self.mode = mode self.async_communication = async_communication - self.async_all_reduce = async_all_reduce + # self.async_all_reduce = async_all_reduce if self.mode is TensorParallelLinearMode.ALL_REDUCE and self.async_communication: raise ValueError("async_communication is not supported for ALL_REDUCE mode") @@ -164,7 +164,7 @@ def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig): ) setattr(self, name, new_param) - def forward(self, x: torch.Tensor, handle_idx=None) -> torch.Tensor: + def forward(self, x: torch.Tensor, async_all_reduce, handle_idx=None) -> torch.Tensor: return row_linear( input=x, weight=self.weight, @@ -172,7 +172,7 @@ def forward(self, x: torch.Tensor, handle_idx=None) -> torch.Tensor: group=self.pg, tp_mode=self.mode, async_communication=self.async_communication, - async_all_reduce=self.async_all_reduce, + async_all_reduce=async_all_reduce, handle_idx=handle_idx, ) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 96af6b12..e58af9f5 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -564,6 +564,9 @@ def training_step( self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.optimizer ) + if dist.get_rank() == 0: + assert 1 == 1 + # Apply gradient self.optimizer.step() self.optimizer.zero_grad() @@ -580,8 +583,20 @@ def training_step( from nanotron.parallel.comm import AsyncCommBucket + # import torch.distributed as dist + + not_finished = [] + for k, v in AsyncCommBucket._copy_async_op.items(): + # assert v.is_completed(), f"AsyncCommBucket._copy_async_op: {AsyncCommBucket._copy_async_op}" + if v.is_completed() is not True: + not_finished.append((k, v)) + + # if dist.get_rank() == 0 and constants._NOT_BWD_ASYNC_OPS: + # assert 1 == 1 + + assert len(not_finished) == 0, f"AsyncCommBucket._copy_async_op: {not_finished}" assert len(AsyncCommBucket._async_op) == 0, f"AsyncCommBucket._async_op: {AsyncCommBucket._async_op}" - # AsyncCommBucket.clear_all() + AsyncCommBucket.clear_all() return outputs, loss_avg
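
Note (not part of the patch): the core fix above is the pair of `torch.cuda.current_stream().wait_stream(comm_stream)` calls added in `_core_forward`, which make the compute stream wait for the communication stream after the async all-reduce handles are waited on. Below is a minimal, self-contained sketch of that synchronization pattern. It assumes `torch.distributed` is already initialized with the NCCL backend; the names `overlapped_residual_add`, `hidden`, `residual`, and `comm_stream` are illustrative, not nanotron's.

    # Illustration of the stream-synchronization pattern enforced by this patch.
    import torch
    import torch.distributed as dist


    def overlapped_residual_add(
        hidden: torch.Tensor, residual: torch.Tensor, comm_stream: torch.cuda.Stream
    ) -> torch.Tensor:
        # Launch the all-reduce asynchronously so independent compute can keep
        # running on the default stream while NCCL works in the background.
        work = dist.all_reduce(hidden, op=dist.ReduceOp.SUM, async_op=True)

        # ... schedule other, independent kernels here ...

        # Enqueue the wait on the dedicated communication stream. This blocks
        # comm_stream (not the host) until the collective has finished.
        with torch.cuda.stream(comm_stream):
            work.wait()

        # The step this patch adds: make the compute stream wait on the
        # communication stream. Without it, the default stream may read
        # `hidden` before the reduced values are visible.
        torch.cuda.current_stream().wait_stream(comm_stream)

        return hidden + residual

In the patch itself the work handle is tracked through `AsyncCommBucket` (e.g. `attn_output1["work"]`) rather than taken directly from `dist.all_reduce`, and the new `fork`/`join`/`depend` helpers in `src/nanotron/parallel/dependency.py` provide a way to impose the analogous ordering on the backward pass, although their call sites are still commented out in this commit.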