
Commit ea09a25

refactor
xrsrke committed Feb 12, 2025
1 parent aa3e973 commit ea09a25
Showing 9 changed files with 19 additions and 73 deletions.
2 changes: 1 addition & 1 deletion examples/config_llama_domino.yaml
@@ -20,7 +20,7 @@ data_stages:
 general:
   benchmark_csv_path: null
   consumed_train_samples: null
-  ignore_sanity_checks: true
+  ignore_sanity_checks: false
   project: nanotron_domino
   run: config_llama_domino
   seed: 42

4 changes: 0 additions & 4 deletions src/nanotron/optim/gradient_accumulator.py
@@ -202,10 +202,6 @@ def build_grad_buffers(
         return fp32_grad_buffers, contiguous_buffer_f32_gradients
 
     def backward(self, loss: torch.Tensor):
-        if not isinstance(loss, torch.Tensor):
-            assert 1 == 1
-            raise NotImplementedError("Not implemented yet")
-
         result = loss.backward()
 
         for name, elt in self.fp32_grad_buffers.items():

10 changes: 10 additions & 0 deletions src/nanotron/parallel/comm.py
@@ -57,6 +57,16 @@ def wait(tensor_id: int):
         work = AsyncCommBucket._async_op.pop(tensor_id)
         work.wait()
 
+    @staticmethod
+    def is_all_completed() -> bool:
+        assert len(AsyncCommBucket._async_op) == 0, "there are still some async ops that haven't executed"
+
+        not_finished = []
+        for k, v in AsyncCommBucket._copy_async_op.items():
+            if v.is_completed() is not True:
+                not_finished.append((k, v))
+        return len(not_finished) == 0
+
     @staticmethod
     def clear_all():
         AsyncCommBucket._async_op.clear()

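For orientation on the new is_all_completed() helper: below is a minimal sketch of the bucket shape these static methods assume, pieced together from the hunks in this commit (two class-level dicts keyed by op name, holding async work handles from torch.distributed). The add() registration method and the second clear() call are assumptions for illustration; only the methods visible in the diffs are taken from the repository, and is_all_completed() is simplified here.

class AsyncCommBucket:
    # op name -> async work handle (e.g. returned by dist.all_reduce(..., async_op=True));
    # entries are popped as they are waited on
    _async_op = {}
    # op name -> the same handle, kept so completion can be double-checked later
    _copy_async_op = {}

    @staticmethod
    def add(op_name, work):
        # assumed registration path, not shown in this diff
        AsyncCommBucket._async_op[op_name] = work
        AsyncCommBucket._copy_async_op[op_name] = work

    @staticmethod
    def pop(op_name):
        return AsyncCommBucket._async_op.pop(op_name)

    @staticmethod
    def wait(tensor_id):
        work = AsyncCommBucket._async_op.pop(tensor_id)
        work.wait()

    @staticmethod
    def is_all_completed() -> bool:
        assert len(AsyncCommBucket._async_op) == 0, "there are still some async ops that haven't executed"
        return all(v.is_completed() for v in AsyncCommBucket._copy_async_op.values())

    @staticmethod
    def clear_all():
        AsyncCommBucket._async_op.clear()
        AsyncCommBucket._copy_async_op.clear()  # assumed; the visible hunk is truncated here
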
3 changes: 0 additions & 3 deletions src/nanotron/parallel/pipeline_parallel/engine.py
@@ -84,9 +84,6 @@ def backward(
         if grad_accumulator is None:
             sum(activations).backward()
         else:
-            # if not isinstance(activations, torch.Tensor):
-            #     raise NotImplementedError("Only support sum of tensors for now")
-
             grad_accumulator.backward(sum(activations))
 
         # TODO @nouamane: this fixes interleaved afab but makes 1f1b hang

@@ -43,17 +43,9 @@ def forward(
 
     @staticmethod
    def backward(ctx, grad_output):
-        # import pydevd
-        # pydevd.settrace(suspend=False, trace_only_current_thread=True)
-        from nanotron.constants import _AUTOGRAD_RUNS
-
-        _AUTOGRAD_RUNS.append(ctx.op_name)
-
         group = ctx.group
 
         op_name = ctx.op_name.replace("fwd.", "bwd.") if ctx.op_name is not None else ctx.op_name
         async_all_reduce = is_async_comm(op_name) if ctx.op_name is not None else ctx.async_all_reduce
 
         return (
             DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, op_name, ctx.comm_stream),
             None,

23 changes: 1 addition & 22 deletions src/nanotron/parallel/tensor_parallel/domino.py
@@ -39,29 +39,8 @@ def forward(ctx, input, wait_handle_idx, comm_stream):
 
     @staticmethod
     def backward(ctx, grad_output):
-        # import pydevd
-        # pydevd.settrace(suspend=False, trace_only_current_thread=True)
-
-        if "bwd.layer_mlp_1_batch_0" == ctx.wait_handle_idx:
-            assert 1 == 1
-
-        if "bwd.layer_mlp_0_batch_1" == ctx.wait_handle_idx:
-            assert 1 == 1
-
         if is_async_comm(ctx.wait_handle_idx):
-            from nanotron.constants import _AUTOGRAD_RUNS
-
-            _AUTOGRAD_RUNS.append(f"wait_{ctx.wait_handle_idx}")
-            handle = AsyncCommBucket.pop(ctx.wait_handle_idx)
-            assert handle is not None
-            handle.wait()
+            AsyncCommBucket.wait(ctx.wait_handle_idx)
             torch.cuda.default_stream().wait_stream(ctx.comm_stream)
-        else:
-            from nanotron import constants
-
-            constants._NOT_BWD_ASYNC_OPS.append(ctx.wait_handle_idx)
-
-        # if "bwd.layer_mlp_0_batch_1" == ctx.wait_handle_idx:
-        #     assert AsyncCommBucket._copy_async_op.get(ctx.wait_handle_idx).is_completed() is True
 
         return grad_output, None, None

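The effect of this hunk is that the backward wait path collapses into a single bucket call plus a stream sync. A sketch of how the whole autograd op reads after the change; the class name WaitComm and the forward body are assumptions inferred from the hunk header (forward(ctx, input, wait_handle_idx, comm_stream)), not copied from the file.

import torch

from nanotron.parallel.comm import AsyncCommBucket
from nanotron.parallel.tensor_parallel.domino import is_async_comm  # assumed import path; the diff shows it is in scope in domino.py


class WaitComm(torch.autograd.Function):  # assumed name for the op this hunk edits
    @staticmethod
    def forward(ctx, input, wait_handle_idx, comm_stream):
        # assumed body: stash what backward needs and pass the activation through unchanged
        ctx.wait_handle_idx = wait_handle_idx
        ctx.comm_stream = comm_stream
        return input

    @staticmethod
    def backward(ctx, grad_output):
        if is_async_comm(ctx.wait_handle_idx):
            # wait on the bucketed async all-reduce for this op, then make the default
            # stream wait for the communication stream before the gradient flows on
            AsyncCommBucket.wait(ctx.wait_handle_idx)
            torch.cuda.default_stream().wait_stream(ctx.comm_stream)
        return grad_output, None, None
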
22 changes: 1 addition & 21 deletions src/nanotron/parallel/tensor_parallel/functional.py
@@ -19,6 +19,7 @@
 from torch.nn import functional as F
 
 import nanotron.distributed as dist
+from nanotron.parallel.comm import AsyncCommBucket
 from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import (
     differentiable_all_reduce_sum,
     differentiable_identity,
@@ -600,30 +601,9 @@ def row_linear(
     out = F.linear(input, weight, bias)
 
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
-        # out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce)
-        # id(out)
-        # NOTE: why the id(out) doesn't match the id(out) before the all_reduce?
-        if op_name == "fwd.layer_attn_0_batch_0":
-            assert 1 == 1
-
-        if op_name == "fwd.layer_mlp_0_batch_1":
-            assert 1 == 1
-
-        if op_name == "fwd.layer_attn_0_batch_0":
-            assert 1 == 1
-
         out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce, op_name=op_name)
         if async_all_reduce:
-            from nanotron.parallel.comm import AsyncCommBucket
-
-            # work = AsyncCommBucket.get(orig_out_id)
-            # work = AsyncCommBucket.pop(orig_out_id)
-            # if handle_idx == "fwd.layer_mlp_1_batch_0":
-            if op_name == "fwd.layer_attn_0_batch_0":
-                assert 1 == 1
-
             work = AsyncCommBucket.pop(op_name)
-            assert 1 == 1
         else:
             work = None
     elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:

4 changes: 4 additions & 0 deletions src/nanotron/sanity_checks.py
@@ -10,6 +10,7 @@
 from nanotron.models import NanotronModel
 from nanotron.optim.gradient_accumulator import GradientAccumulator
 from nanotron.parallel import ParallelContext
+from nanotron.parallel.comm import AsyncCommBucket
 from nanotron.parallel.tied_parameters import get_tied_id_to_param
 
 logger = get_logger(__name__)
@@ -239,6 +240,9 @@ def before_optim_step_sanity_checks(
     # SANITY CHECK: run model specific sanity checks
     unwrapped_model.before_optim_step_sanity_checks()
 
+    # SANITY CHECK: for domino
+    assert AsyncCommBucket.is_all_completed(), "There are still some async ops that haven't finished"
+
 
 def after_optim_step_sanity_checks(
     config: Config,

16 changes: 2 additions & 14 deletions src/nanotron/trainer.py
@@ -61,6 +61,7 @@
 from nanotron.models.starcoder2 import Starcoder2ForTraining
 from nanotron.optim.clip_grads import clip_grad_norm
 from nanotron.parallel import ParallelContext
+from nanotron.parallel.comm import AsyncCommBucket
 from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp
 from nanotron.parallel.parameters import NanotronParameter, sanity_check
 from nanotron.parallel.pipeline_parallel.engine import (
@@ -563,6 +564,7 @@ def training_step(
         before_optim_step_sanity_checks(
             self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.optimizer
         )
+        AsyncCommBucket.clear_all()
 
         # Apply gradient
         self.optimizer.step()
@@ -578,20 +580,6 @@
 
         self.post_train_step()
 
-        from nanotron.parallel.comm import AsyncCommBucket
-
-        not_finished = []
-        for k, v in AsyncCommBucket._copy_async_op.items():
-            if v.is_completed() is not True:
-                not_finished.append((k, v))
-
-        # if dist.get_rank() == 0 and constants._NOT_BWD_ASYNC_OPS:
-        #     assert 1 == 1
-
-        assert len(not_finished) == 0, f"len={len(not_finished)}, AsyncCommBucket._copy_async_op: {not_finished}"
-        assert len(AsyncCommBucket._async_op) == 0, f"AsyncCommBucket._async_op: {AsyncCommBucket._async_op}"
-        AsyncCommBucket.clear_all()
-
         return outputs, loss_avg
 
     def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]:
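With the inline checks removed from training_step, the per-step contract around the optimizer step becomes: run the sanity checks (which, per the sanity_checks.py hunk above, now assert AsyncCommBucket.is_all_completed()), reset the bucket, then apply gradients. A condensed sketch under the assumption of a trainer object exposing the attributes used in the diff; the wrapper name apply_optimizer_step is hypothetical.

from nanotron.parallel.comm import AsyncCommBucket
from nanotron.sanity_checks import before_optim_step_sanity_checks


def apply_optimizer_step(trainer):
    # 1. sanity checks; for domino this asserts every recorded async op has completed
    before_optim_step_sanity_checks(
        trainer.config,
        trainer.parallel_context,
        trainer.unwrapped_model,
        trainer.grad_accumulator,
        trainer.optimizer,
    )
    # 2. reset the async-comm bucket so the next iteration starts from a clean slate
    AsyncCommBucket.clear_all()
    # 3. apply gradients
    trainer.optimizer.step()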
