backup before refactoring

xrsrke committed Feb 12, 2025
1 parent da948df commit aa3e973

Showing 5 changed files with 149 additions and 182 deletions.
98 changes: 98 additions & 0 deletions examples/config_llama_domino.yaml
@@ -0,0 +1,98 @@
checkpoints:
  checkpoint_interval: 1000
  checkpoints_path: checkpoints
  checkpoints_path_is_shared_file_system: false
  resume_checkpoint_path: null
  save_initial_state: false
data_stages:
- data:
    dataset:
      dataset_overwrite_cache: false
      dataset_processing_num_proc_per_process: 1
      hf_dataset_config_name: null
      hf_dataset_or_datasets: roneneldan/TinyStories
      hf_dataset_splits: train
      text_column_name: text
    num_loading_workers: 1
    seed: 42
  name: Stable Training Stage
  start_training_step: 1
general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: true
  project: nanotron_domino
  run: config_llama_domino
  seed: 42
  step: null
lighteval: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info
model:
  ddp_bucket_cap_mb: 25
  dtype: bfloat16
  init_method:
    std: 0.025
  make_vocab_size_divisible_by: 1
  model_config:
    bos_token_id: 128000
    eos_token_id: 128001
    hidden_act: silu
    hidden_size: 4096
    initializer_range: 0.02
    intermediate_size: 16384
    is_llama_config: true
    max_position_embeddings: 4096
    num_attention_heads: 32
    num_hidden_layers: 32
    num_key_value_heads: 8
    pad_token_id: null
    pretraining_tp: 1
    rms_norm_eps: 1.0e-05
    rope_scaling: null
    tie_word_embeddings: true
    use_cache: true
    vocab_size: 128256
optimizer:
  accumulate_grad_in_fp32: true
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 0.0003
    lr_decay_starting_step: null
    lr_decay_steps: 1000
    lr_decay_style: cosine
    lr_warmup_steps: 500
    lr_warmup_style: linear
    min_decay_lr: 1.0e-05
  optimizer_factory:
    adam_beta1: 0.9
    adam_beta2: 0.95
    adam_eps: 1.0e-08
    name: adamW
    torch_adam_is_fused: true
  weight_decay: 0.01
  zero_stage: 0
parallelism:
  dp: 1
  pp: 1
  tp: 8
  expert_parallel_size: 1
  pp_engine: 1f1b
  tp_linear_async_communication: false
  tp_mode: ALL_REDUCE
  domino:
    num_input_batches: 2
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B
  tokenizer_revision: null
tokens:
  batch_accumulation_per_replica: 1
  limit_test_batches: 0
  limit_val_batches: 0
  micro_batch_size: 2
  sequence_length: 4096
  train_steps: 1500
  val_check_interval: -1
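
Note (not part of the commit): the Domino path in this run is driven by parallelism.tp: 8 with tp_mode: ALL_REDUCE, tp_linear_async_communication: false, and parallelism.domino.num_input_batches: 2, i.e. each micro-batch is split into two halves whose compute and tensor-parallel all-reduce are overlapped. A minimal sketch of reading those knobs back, assuming PyYAML is installed and the path is resolved from the repo root:

    # Sketch only: inspect the Domino-related settings of the config above.
    import yaml

    with open("examples/config_llama_domino.yaml") as f:
        cfg = yaml.safe_load(f)

    parallelism = cfg["parallelism"]
    print(parallelism["dp"] * parallelism["pp"] * parallelism["tp"])   # GPUs needed: 1 * 1 * 8 = 8
    print(parallelism["tp_mode"])                                      # ALL_REDUCE
    print(parallelism["domino"]["num_input_batches"])                  # 2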
107 changes: 8 additions & 99 deletions src/nanotron/models/llama.py
@@ -35,7 +35,13 @@
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer
from nanotron.parallel.pipeline_parallel.p2p import P2P
from nanotron.parallel.tensor_parallel.domino import WaitComm
from nanotron.parallel.tensor_parallel.domino import (
    BWD_ATTN_HANDLE_IDX,
    BWD_MLP_HANDLE_IDX,
    FWD_ATTN_HANDLE_IDX,
    FWD_MLP_HANDLE_IDX,
    WaitComm,
)
from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy
from nanotron.parallel.tensor_parallel.nn import (
    TensorParallelColumnLinear,
@@ -51,11 +57,6 @@

DOMINO_COMM_STREAM = "domino_comm_stream_{}"

FWD_MLP_HANDLE_IDX = "fwd.layer_mlp_{}_batch_{}"
FWD_ATTN_HANDLE_IDX = "fwd.layer_attn_{}_batch_{}"
BWD_ATTN_HANDLE_IDX = "bwd.layer_attn_{}_batch_{}"
BWD_MLP_HANDLE_IDX = "bwd.layer_mlp_{}_batch_{}"


class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, end: int, theta: float = 10000.0):
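Aside, not part of the diff: the four handle-index constants deleted above are now imported from nanotron.parallel.tensor_parallel.domino instead of being defined in llama.py. They are plain format strings keyed by layer index and micro-batch index, for example:

    FWD_ATTN_HANDLE_IDX = "fwd.layer_attn_{}_batch_{}"
    BWD_MLP_HANDLE_IDX = "bwd.layer_mlp_{}_batch_{}"
    print(FWD_ATTN_HANDLE_IDX.format(3, 0))  # fwd.layer_attn_3_batch_0
    print(BWD_MLP_HANDLE_IDX.format(3, 1))   # bwd.layer_mlp_3_batch_1

These names key the async all-reduce handles that the forward and backward passes later wait on.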
@@ -743,96 +744,6 @@ def __init__(

self.layer_idx = layer_idx

# def _core_forward(
# self,
# hidden_states: Union[torch.Tensor, TensorPointer],
# sequence_mask: Union[torch.Tensor, TensorPointer],
# ) -> List[Union[torch.Tensor, TensorPointer]]:
# from nanotron import constants

# num_input_batches = self.parallel_config.domino.num_input_batches
# orig_sequence_mask = sequence_mask

# assert num_input_batches == 2
# hidden_states = torch.chunk(hidden_states, chunks=num_input_batches, dim=1)
# sequence_mask = torch.chunk(sequence_mask, chunks=num_input_batches, dim=0)

# hidden_states0, hidden_states1 = hidden_states
# sequence_mask0, sequence_mask1 = sequence_mask

# residual0 = hidden_states0
# residual1 = hidden_states1

# hidden_states0 = self.input_layernorm(hidden_states0)
# hidden_states1 = self.input_layernorm(hidden_states1)

# attn_output0 = self.attn(
# hidden_states=hidden_states0,
# sequence_mask=sequence_mask0,
# handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 0),
# )
# # attn_output0["hidden_states"] = WaitComm.apply(
# # attn_output0["hidden_states"],
# # BWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1),
# # )

# attn_output1 = self.attn(
# hidden_states=hidden_states1,
# sequence_mask=sequence_mask1,
# handle_idx=FWD_ATTN_HANDLE_IDX.format(self.layer_idx, 1),
# )
# # attn_output1["hidden_states"] = WaitComm.apply(
# # attn_output1["hidden_states"],
# # BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0),
# # )

# comm_stream = constants.CUDA_STREAMS[torch.cuda.current_device()]
# with torch.cuda.stream(comm_stream):
# attn_output0["work"].wait()

# hidden_states0 = attn_output0["hidden_states"] + residual0
# residual0 = hidden_states0
# hidden_states0 = self.post_attention_layernorm(hidden_states0)
# hidden_states0 = WaitComm.apply(
# hidden_states0,
# BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1),
# ) # new

# mlp_output0 = self.mlp(
# hidden_states=hidden_states0,
# handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 0),
# )
# # mlp_output0["hidden_states"] = WaitComm.apply(
# # mlp_output0["hidden_states"],
# # BWD_MLP_HANDLE_IDX.format(self.layer_idx, 1),
# # )

# with torch.cuda.stream(comm_stream):
# attn_output1["work"].wait()

# hidden_states1 = attn_output1["hidden_states"] + residual1
# residual1 = hidden_states1
# hidden_states1 = self.post_attention_layernorm(hidden_states1)
# hidden_states1 = WaitComm.apply(
# hidden_states1,
# BWD_MLP_HANDLE_IDX.format(self.layer_idx, 0),
# )

# mlp_output1 = self.mlp(
# hidden_states=hidden_states1,
# handle_idx=FWD_MLP_HANDLE_IDX.format(self.layer_idx, 1),
# )

# with torch.cuda.stream(comm_stream):
# mlp_output0["work"].wait()
# mlp_output1["work"].wait()

# hidden_states0 = mlp_output0["hidden_states"] + residual0
# hidden_states1 = mlp_output1["hidden_states"] + residual1

# hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1)
# return hidden_states, orig_sequence_mask

def _core_forward(
    self,
    hidden_states: Union[torch.Tensor, TensorPointer],
@@ -908,12 +819,10 @@ def _core_forward(

with torch.cuda.stream(comm_stream):
    mlp_output0["work"].wait()
    # mlp_output1["work"].wait()

mlp_output0["work"].is_completed()
# mlp_output1["work"].is_completed()

torch.cuda.current_stream().wait_stream(comm_stream)
# torch.cuda.synchronize()

hidden_states0 = mlp_output0["hidden_states"] + residual0
hidden_states1 = mlp_output1["hidden_states"] + residual1
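
For context, the surviving _core_forward above (like the commented-out backup deleted earlier in this file) follows the Domino schedule: chunk the micro-batch in two, run attention or MLP on one half while the tensor-parallel all-reduce of the other half is still in flight on a dedicated communication stream, and synchronize the streams only at the end of the block. A self-contained sketch of that pattern with plain torch.distributed, where block_fn, tp_group and comm_stream are hypothetical stand-ins rather than nanotron's actual API:

    import torch
    import torch.distributed as dist

    def domino_step(block_fn, x, tp_group, comm_stream):
        """Overlap the all-reduce of half 0 with the compute of half 1 (sketch only)."""
        x0, x1 = torch.chunk(x, chunks=2, dim=1)  # split the micro-batch in two

        y0 = block_fn(x0)
        comm_stream.wait_stream(torch.cuda.current_stream())  # comm stream must see y0
        with torch.cuda.stream(comm_stream):
            work0 = dist.all_reduce(y0, op=dist.ReduceOp.SUM, group=tp_group, async_op=True)

        y1 = block_fn(x1)  # this compute overlaps with work0
        comm_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(comm_stream):
            work1 = dist.all_reduce(y1, op=dist.ReduceOp.SUM, group=tp_group, async_op=True)
            work0.wait()
            work1.wait()

        torch.cuda.current_stream().wait_stream(comm_stream)  # compute stream waits for comm
        return torch.cat([y0, y1], dim=1)

The BWD_* handle names and WaitComm calls in the real code appear to arrange the analogous overlap for the backward pass.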
@@ -27,106 +27,65 @@ class DifferentiableIdentity(torch.autograd.Function):
"""All-reduce gradients in a differentiable fashion"""

@staticmethod
def forward(ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, op_name: str = None):
def forward(
    ctx,
    tensor,
    group: Optional[ProcessGroup],
    async_all_reduce: bool,
    op_name: str = None,
    comm_stream: torch.cuda.Stream = None,
):
ctx.group = group
ctx.async_all_reduce = async_all_reduce
ctx.op_name = op_name
ctx.group = group
ctx.comm_stream = comm_stream
return tensor

@staticmethod
def backward(ctx, grad_output):
# import pydevd
# pydevd.settrace(suspend=False, trace_only_current_thread=True)
# NOTE: lm_head is TensorParallelColumnLinear, and it doesn't do async
# assert ctx.handle_idx is not None
group = ctx.group

# if ctx.handle_idx is not None and "fwd." in ctx.handle_idx:
# handle_idx = ctx.handle_idx.replace("fwd.", "bwd.")
# # if "bwd.layer_mlp_1_batch_1" == handle_idx:
# # from nanotron.parallel.comm import is_async_comm
# # async_all_reduce = is_async_comm(handle_idx)
# # else:
# # async_all_reduce = ctx.async_all_reduce
# # from nanotron.parallel.comm import is_async_comm
# from nanotron.parallel.tensor_parallel.domino import is_async_comm

# async_all_reduce = is_async_comm(handle_idx)
# else:
# handle_idx = ctx.handle_idx
# async_all_reduce = ctx.async_all_reduce

# if handle_idx is not None and "bwd." in handle_idx and async_all_reduce is True:
# assert 1 == 1

op_name = ctx.op_name.replace("fwd.", "bwd.") if ctx.op_name is not None else ctx.op_name
async_all_reduce = is_async_comm(op_name) if ctx.op_name is not None else ctx.async_all_reduce

if op_name is not None and "layer_mlp_27_batch_1" in op_name:
    assert 1 == 1

from nanotron.constants import _AUTOGRAD_RUNS

_AUTOGRAD_RUNS.append(ctx.op_name)

return DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, op_name), None, None, None

group = ctx.group

def is_last_batch_of_attn(x):
    import re
op_name = ctx.op_name.replace("fwd.", "bwd.") if ctx.op_name is not None else ctx.op_name
async_all_reduce = is_async_comm(op_name) if ctx.op_name is not None else ctx.async_all_reduce

    pattern = r"layer_attn_\d+_batch_0"
    if re.match(pattern, x):
        return True
    return False
return (
    DifferentiableAllReduceSum.apply(grad_output, group, async_all_reduce, op_name, ctx.comm_stream),
    None,
    None,
    None,
)


class DifferentiableAllReduceSum(torch.autograd.Function):
"""All-reduce in a differentiable fashion"""

@staticmethod
def forward(
    ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool, op_name: str = None
    ctx,
    tensor,
    group: Optional[ProcessGroup],
    async_all_reduce: bool,
    op_name: str = None,
    comm_stream: torch.cuda.Stream = None,
) -> Tuple[torch.Tensor, Optional["dist.Work"]]:
ctx.async_all_reduce = async_all_reduce
ctx.comm_stream = comm_stream

if group.size() == 1:
    return tensor

# if handle_idx == "bwd.layer_mlp_1_batch_0":
# assert 1 == 1

# id(tensor)
# if async_all_reduce is True:
# # if isinstance(handle_idx, str):
# # do_async = is_last_batch_of_attn(handle_idx) is False
# # else:
# # do_async = async_all_reduce
# # from nanotron.parallel.comm import is_async_comm
# from nanotron.parallel.tensor_parallel.domino import is_async_comm

# do_async = is_async_comm(handle_idx)

# handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=do_async)
# if do_async:
# if "bwd" in handle_idx:
# assert 1 == 1

# # # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx
# # if handle_idx is not None and "bwd." in handle_idx:
# # AsyncCommBucket.add(orig_id if handle_idx is None else handle_idx, handle)
# # else:
# # AsyncCommBucket.add(orig_id, handle)
# # NOTE: id(tensor) is for the fwd pass, for the bwd pass, we do handle_idx
# assert handle_idx is not None
# AsyncCommBucket.add(handle_idx, handle)
# else:
# dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
if async_all_reduce:
    handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=True)
    AsyncCommBucket.add(op_name, handle)
else:
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
with torch.cuda.stream(comm_stream):
    if async_all_reduce:
        handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=True)
        AsyncCommBucket.add(op_name, handle)
    else:
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)

return tensor
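
Note, not part of the diff: the async branch above registers the in-flight dist.Work handle under op_name, and the model code in llama.py later waits on it. Only the add(op_name, handle) call is visible in this hunk; the real AsyncCommBucket lives elsewhere in nanotron and its interface may differ. A hypothetical stand-in with that shape:

    from typing import Dict

    import torch.distributed as dist

    class AsyncCommBucketSketch:
        """Hypothetical stand-in: in-flight all-reduce handles keyed by op name."""

        _in_flight: Dict[str, "dist.Work"] = {}

        @classmethod
        def add(cls, op_name: str, work: "dist.Work") -> None:
            assert op_name not in cls._in_flight, f"duplicate async op: {op_name}"
            cls._in_flight[op_name] = work

        @classmethod
        def wait(cls, op_name: str) -> None:
            # Pop so each handle is waited on exactly once.
            cls._in_flight.pop(op_name).wait()

Registering the handle where the collective is issued and waiting on it only where the result is needed is what lets the all-reduce overlap with the other half-batch's compute.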
