
Commit

fix streams not being synchronized
xrsrke committed Feb 10, 2025
1 parent 31db05d commit 23f2108
Showing 7 changed files with 172 additions and 24 deletions.
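
The theme of the commit: asynchronous tensor-parallel all-reduces are waited on from a dedicated communication stream, and the compute stream is then made to wait on that stream before the reduced tensors are consumed. Below is a minimal sketch of that pattern, not code from the commit; it assumes an initialized NCCL process group and a CUDA device, and `sync_async_allreduce` is a hypothetical helper name.

import torch
import torch.distributed as dist


def sync_async_allreduce(x: torch.Tensor, comm_stream: torch.cuda.Stream) -> torch.Tensor:
    # Launch the all-reduce asynchronously; NCCL hands back a Work handle.
    work = dist.all_reduce(x, async_op=True)

    # ...independent compute can overlap with the collective here...

    # Block the dedicated communication stream on the handle, then make the
    # compute (current) stream wait on the communication stream. Skipping the
    # second step lets the compute stream read `x` before the collective has
    # finished, which is the race this commit closes in `_core_forward`.
    with torch.cuda.stream(comm_stream):
        work.wait()
    torch.cuda.current_stream().wait_stream(comm_stream)
    return x
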
4 changes: 4 additions & 0 deletions src/nanotron/constants.py
@@ -13,3 +13,7 @@


CUDA_STREAMS = {}

CLOCK = 0
_AUTOGRAD_RUNS = []
_NOT_BWD_ASYNC_OPS = []
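
`CUDA_STREAMS` is keyed by device index and is what `_core_forward` looks up below via `constants.CUDA_STREAMS[torch.cuda.current_device()]`; `CLOCK`, `_AUTOGRAD_RUNS`, and `_NOT_BWD_ASYNC_OPS` are debugging counters and logs. The site that populates `CUDA_STREAMS` is not shown in this diff; a plausible lazy-initialization sketch, with a hypothetical helper name:

from typing import Optional

import torch

from nanotron import constants


def get_comm_stream(device_index: Optional[int] = None) -> torch.cuda.Stream:
    # Create one communication stream per device on first use and cache it.
    if device_index is None:
        device_index = torch.cuda.current_device()
    if device_index not in constants.CUDA_STREAMS:
        constants.CUDA_STREAMS[device_index] = torch.cuda.Stream(device=device_index)
    return constants.CUDA_STREAMS[device_index]
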
46 changes: 28 additions & 18 deletions src/nanotron/models/llama.py
@@ -245,13 +245,15 @@ def __init__(
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
async_all_reduce=parallel_config.domino.num_input_batches > 1,
# async_all_reduce=parallel_config.domino.num_input_batches > 1,
)
self.split_silu_mul = GLUActivation(config.hidden_act)

def forward(self, hidden_states, handle_idx=None): # [seq_length, batch_size, hidden_dim]
merged_states = self.gate_up_proj(hidden_states, async_all_reduce=True, handle_idx=handle_idx)
hidden_states, work = self.down_proj(self.split_silu_mul(merged_states), handle_idx)
hidden_states, work = self.down_proj(
self.split_silu_mul(merged_states), async_all_reduce=True, handle_idx=handle_idx
)
return {"hidden_states": hidden_states, "work": work}


@@ -428,7 +430,7 @@ def __init__(
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication,
async_all_reduce=async_all_reduce,
# async_all_reduce=async_all_reduce,
)

self.attention = CoreAttention(
@@ -699,7 +701,7 @@ def forward(
attention_output = (
attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1)
)
output, work = self.o_proj(attention_output, handle_idx=handle_idx)
output, work = self.o_proj(attention_output, async_all_reduce=True, handle_idx=handle_idx)

return {"hidden_states": output, "work": work, "sequence_mask": sequence_mask}

@@ -876,6 +878,7 @@ def _core_forward(
comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()]
with torch.cuda.stream(comm_stream):
attn_output0["work"].wait()
attn_output0["work"].is_completed()

hidden_states0 = attn_output0["hidden_states"] + residual0
residual0 = hidden_states0
@@ -890,8 +893,16 @@
handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0),
)

# attn_output1["hidden_states"], mlp_output0["hidden_states"] = depend(
# run_after=attn_output1["hidden_states"],
# run_before=mlp_output0["hidden_states"]
# )

with torch.cuda.stream(comm_stream):
attn_output1["work"].wait()
attn_output1["work"].is_completed()

torch.cuda.current_stream().wait_stream(comm_stream)

hidden_states1 = attn_output1["hidden_states"] + residual1
residual1 = hidden_states1
@@ -906,11 +917,24 @@
mlp_output0["work"].wait()
mlp_output1["work"].wait()

mlp_output0["work"].is_completed()
mlp_output1["work"].is_completed()

torch.cuda.current_stream().wait_stream(comm_stream)

hidden_states0 = mlp_output0["hidden_states"] + residual0
hidden_states1 = mlp_output1["hidden_states"] + residual1

# hidden_states0, hidden_states1 = depend(run_after=hidden_states0, run_before=hidden_states1)

hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1)
assert 1 == 1

# assert attn_output0["work"].is_completed()
# assert attn_output1["work"].is_completed()
# assert mlp_output0["work"].is_completed()
# assert mlp_output1["work"].is_completed()

return hidden_states, orig_sequence_mask

def _checkpointed_forward(
@@ -1080,23 +1104,9 @@ def forward_with_hidden_states(
"sequence_mask": input_mask,
}

# assert 1 == 1
# num_input_batches = self.parallel_config.domino.num_input_batches
# hidden_encoder_states["hidden_states"] = torch.chunk(hidden_encoder_states["hidden_states"], chunks=num_input_batches, dim=1)
# hidden_encoder_states["sequence_mask"] = torch.chunk(hidden_encoder_states["sequence_mask"], chunks=num_input_batches, dim=0)

# # Combine the chunks into a list of dictionaries
# hidden_encoder_states_list = [
# {"hidden_states": hidden_encoder_states["hidden_states"][i], "sequence_mask": hidden_encoder_states["sequence_mask"][i]}
# for i in range(num_input_batches)
# ]

for encoder_block in self.decoder:
hidden_encoder_states = encoder_block(**hidden_encoder_states)

# for hidden_encoder_states in hidden_encoder_states_list:
# hidden_encoder_states = encoder_block(**hidden_encoder_states)

hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"]

sharded_logits = self.lm_head(x=hidden_states)["logits"]
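
The `_core_forward` changes above interleave two micro-batches so that each micro-batch's asynchronous all-reduce overlaps with compute on the other, and the newly added `wait_stream` calls are what force the compute stream to respect those collectives. A simplified, hypothetical skeleton of that schedule, not the model code; `attn` and `mlp` stand in for the real sub-modules and are assumed to return `(output, work_handle)`:

import torch


def domino_block(attn, mlp, x0, x1, comm_stream):
    # Micro-batch 0's attention issues its all-reduce asynchronously...
    attn0, work_attn0 = attn(x0)
    # ...and micro-batch 1's attention compute overlaps with that collective.
    attn1, work_attn1 = attn(x1)

    with torch.cuda.stream(comm_stream):
        work_attn0.wait()
    torch.cuda.current_stream().wait_stream(comm_stream)
    h0 = attn0 + x0
    mlp0, work_mlp0 = mlp(h0)

    with torch.cuda.stream(comm_stream):
        work_attn1.wait()
    torch.cuda.current_stream().wait_stream(comm_stream)
    h1 = attn1 + x1
    mlp1, work_mlp1 = mlp(h1)

    with torch.cuda.stream(comm_stream):
        work_mlp0.wait()
        work_mlp1.wait()
    torch.cuda.current_stream().wait_stream(comm_stream)

    # Re-join the two micro-batches along the batch dimension.
    return torch.cat([mlp0 + h0, mlp1 + h1], dim=1)
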
15 changes: 14 additions & 1 deletion src/nanotron/parallel/comm.py
@@ -33,13 +33,15 @@ class AsyncCommBucket:
"""

_async_op: Dict[int, "dist.Work"] = {}
_copy_async_op: Dict[int, "dist.Work"] = {}

@staticmethod
def add(tensor_id: int, work: "dist.Work"):
assert (
tensor_id not in AsyncCommBucket._async_op
), f"tensor_id: {tensor_id}, keys: {AsyncCommBucket._async_op.keys()}"
AsyncCommBucket._async_op[tensor_id] = work
AsyncCommBucket._copy_async_op[tensor_id] = work

@staticmethod
def get(tensor_id: int):
@@ -58,6 +60,7 @@ def wait(tensor_id: int):
@staticmethod
def clear_all():
AsyncCommBucket._async_op.clear()
AsyncCommBucket._copy_async_op.clear()


def is_async_comm(x):
@@ -92,9 +95,19 @@ def backward(ctx, grad_output):
assert 1 == 1

if is_async_comm(ctx.wait_handle_idx):
from nanotron.constants import _AUTOGRAD_RUNS

_AUTOGRAD_RUNS.append(f"wait_{ctx.wait_handle_idx}")
handle = AsyncCommBucket.pop(ctx.wait_handle_idx)
assert handle is not None
handle.wait()
# assert handle.is_completed() is True
# assert handle.is_completed() is True, f"ctx.wait_handle_idx: {ctx.wait_handle_idx}"
else:

from nanotron import constants

# if dist.get_rank() == 0:
# constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx)
constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx)

return grad_output, None
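
For orientation, a sketch of how the bucket is used (not from the commit; the key string and function names are illustrative): the producer registers the async all-reduce's `Work` handle under a key, the matching autograd backward shown above pops and waits on it, and the new `_copy_async_op` copy lets the trainer verify at the end of the step that nothing was left unfinished.

import torch.distributed as dist

from nanotron.parallel.comm import AsyncCommBucket


def register_async_allreduce(tensor, handle_idx, group=None):
    # Producer side: launch the collective and park the Work handle under a key.
    work = dist.all_reduce(tensor, group=group, async_op=True)
    AsyncCommBucket.add(handle_idx, work)
    return tensor


def consume(handle_idx):
    # Consumer side, mirroring the backward above: pop the handle and wait on it.
    handle = AsyncCommBucket.pop(handle_idx)
    handle.wait()
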
102 changes: 102 additions & 0 deletions src/nanotron/parallel/dependency.py
@@ -0,0 +1,102 @@
from typing import Dict, Tuple

import torch
from torch import Tensor

_phonies: Dict[Tuple[torch.device, bool], Tensor] = {}


def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor:
"""Gets a phony. Phony is tensor without space. It is useful to make
arbitrary dependency in a autograd graph because it doesn't require any
gradient accumulation.
.. note::
Phonies for each device are cached. If an autograd function gets a phony
internally, the phony must be detached to be returned. Otherwise, the
autograd engine will mutate the cached phony in-place::
class Phonify(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
phony = get_phony(input.device, requires_grad=False)
return phony.detach() # detach() is necessary.
"""
key = (device, requires_grad)

try:
phony = _phonies[key]
except KeyError:
with torch.cuda.stream(torch.cuda.default_stream(device)):
phony = torch.empty(0, device=device, requires_grad=requires_grad)

_phonies[key] = phony

return phony


def fork(input: Tensor) -> Tuple[Tensor, Tensor]:
"""Branches out from an autograd lane of the given tensor."""
if torch.is_grad_enabled() and input.requires_grad:
input, phony = Fork.apply(input)
else:
phony = get_phony(input.device, requires_grad=False)

return input, phony


class Fork(torch.autograd.Function):
@staticmethod
def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore
phony = get_phony(input.device, requires_grad=False)
return input, phony.detach()

@staticmethod
def backward(ctx: "Fork", grad_input: Tensor, grad_grad: Tensor) -> Tensor: # type: ignore
# import pydevd
# pydevd.settrace(suspend=False, trace_only_current_thread=True)
return grad_input


def join(input: Tensor, phony: Tensor) -> Tensor:
"""Merges two autograd lanes."""
if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad):
input = Join.apply(input, phony)

return input


class Join(torch.autograd.Function):
@staticmethod
def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor: # type: ignore
return input

@staticmethod
def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]: # type: ignore
# import pydevd
# pydevd.settrace(suspend=False, trace_only_current_thread=True)
return grad_input, None


# def depend(fork_from, join_to) -> None:
# # Ensure that batches[i-1] is executed after batches[i] in
# # # backpropagation by an explicit dependency.
# # if i != 0:
# # depend(batches[i-1], batches[i])
# # depend(run_after, run_before)
# fork_from, phony = fork(fork_from)
# join_to = join(join_to, phony)
# return fork_from, join_to


def depend(run_after, run_before) -> Tuple[Tensor, Tensor]:
# Ensure that `run_after`'s backward runs only after `run_before`'s backward
# by adding an explicit autograd dependency, e.g.:
#   if i != 0:
#       depend(batches[i - 1], batches[i])
run_after, phony = fork(run_after)
run_before = join(run_before, phony)
return run_after, run_before
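
A minimal usage sketch of the new helper (not from the commit; it assumes a CUDA device, since `get_phony` allocates on the device's default stream). After the call, autograd will not run the backward of `ya` until the backward of `yb` has run, because `yb` now depends on `ya` through a zero-sized phony tensor.

import torch

from nanotron.parallel.dependency import depend

a = torch.randn(4, device="cuda", requires_grad=True)
b = torch.randn(4, device="cuda", requires_grad=True)

ya = a * 2
yb = b * 3

# Order the backward passes: grad of `ya` is computed only after grad of `yb`.
ya, yb = depend(run_after=ya, run_before=yb)

(ya.sum() + yb.sum()).backward()
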
@@ -58,6 +58,10 @@ def backward(ctx, grad_output):
if handle_idx is not None and "bwd." in handle_idx and async_all_reduce is True:
assert 1 == 1

from nanotron.constants import _AUTOGRAD_RUNS

_AUTOGRAD_RUNS.append(handle_idx)

return DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, handle_idx), None, None, None


Expand Down
8 changes: 4 additions & 4 deletions src/nanotron/parallel/tensor_parallel/nn.py
@@ -115,7 +115,7 @@ def __init__(
device=None,
dtype=None,
async_communication: bool = False,
async_all_reduce: bool = False,
# async_all_reduce: bool = False,
contiguous_chunks: Optional[Tuple[int, ...]] = None,
):
self.pg = pg
@@ -138,7 +138,7 @@
)
self.mode = mode
self.async_communication = async_communication
self.async_all_reduce = async_all_reduce
# self.async_all_reduce = async_all_reduce
if self.mode is TensorParallelLinearMode.ALL_REDUCE and self.async_communication:
raise ValueError("async_communication is not supported for ALL_REDUCE mode")

@@ -164,15 +164,15 @@ def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig):
)
setattr(self, name, new_param)

def forward(self, x: torch.Tensor, handle_idx=None) -> torch.Tensor:
def forward(self, x: torch.Tensor, async_all_reduce, handle_idx=None) -> torch.Tensor:
return row_linear(
input=x,
weight=self.weight,
bias=self.bias,
group=self.pg,
tp_mode=self.mode,
async_communication=self.async_communication,
async_all_reduce=self.async_all_reduce,
async_all_reduce=async_all_reduce,
handle_idx=handle_idx,
)

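
With this change, `async_all_reduce` is a per-call argument to the row-parallel linear's `forward` rather than a constructor flag, so the caller decides at each call whether the reduction is asynchronous. A hedged call-site sketch; `o_proj` is an already-constructed row-parallel linear from nn.py above, and the handle string is illustrative:

import torch


def project_with_async_reduce(o_proj, attn_out: torch.Tensor, handle_idx: str) -> torch.Tensor:
    # The second return value is the communication handle the caller
    # synchronizes on later (as _core_forward does), rather than the module
    # waiting internally; its exact form for a synchronous reduce is not
    # shown in this diff.
    output, work = o_proj(attn_out, async_all_reduce=True, handle_idx=handle_idx)
    work.wait()
    return output
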
17 changes: 16 additions & 1 deletion src/nanotron/trainer.py
@@ -564,6 +564,9 @@ def training_step(
self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.optimizer
)

if dist.get_rank() == 0:
assert 1 == 1

# Apply gradient
self.optimizer.step()
self.optimizer.zero_grad()
@@ -580,8 +583,20 @@

from nanotron.parallel.comm import AsyncCommBucket

# import torch.distributed as dist

not_finished = []
for k, v in AsyncCommBucket._copy_async_op.items():
# assert v.is_completed(), f"AsyncCommBucket._copy_async_op: {AsyncCommBucket._copy_async_op}"
if v.is_completed() is not True:
not_finished.append((k, v))

# if dist.get_rank() == 0 and constants._NOT_BWD_ASYNC_OPS:
# assert 1 == 1

assert len(not_finished) == 0, f"AsyncCommBucket._copy_async_op: {not_finished}"
assert len(AsyncCommBucket._async_op) == 0, f"AsyncCommBucket._async_op: {AsyncCommBucket._async_op}"
# AsyncCommBucket.clear_all()
AsyncCommBucket.clear_all()

return outputs, loss_avg

