
Commit

fix
xrsrke committed Jul 10, 2024
1 parent f1adf52 commit c05e032
Showing 5 changed files with 40 additions and 34 deletions.
14 changes: 6 additions & 8 deletions run_train.py
@@ -130,14 +130,12 @@ def get_dataloader_from_data_stage(
         )

         # Check if we have enough samples for train_steps
-        total_tokens_dataset = len(dataloader.dataset) * trainer.sequence_length
-        num_tokens_needed_for_training = (
-            num_remaining_train_steps * trainer.global_batch_size * trainer.sequence_length
-        )
-        assert num_tokens_needed_for_training <= total_tokens_dataset, (
-            f"Dataset is too small for steps ({total_tokens_dataset} < {num_tokens_needed_for_training}), "
-            f"Try train_steps<={len(dataloader.dataset) // trainer.global_batch_size + trainer.iteration_step}"
-        )
+        len(dataloader.dataset) * trainer.sequence_length
+        (num_remaining_train_steps * trainer.global_batch_size * trainer.sequence_length)
+        # assert num_tokens_needed_for_training <= total_tokens_dataset, (
+        #     f"Dataset is too small for steps ({total_tokens_dataset} < {num_tokens_needed_for_training}), "
+        #     f"Try train_steps<={len(dataloader.dataset) // trainer.global_batch_size + trainer.iteration_step}"
+        # )

     # Case 3: Nanosets
     elif isinstance(data.dataset, NanosetDatasetsArgs):
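The commented-out block was a sanity check that the tokenized dataset holds enough tokens for the remaining training steps; with it disabled, the two bare expressions above it are now dead code. Presumably this is acceptable because the cycling dataloader added in src/nanotron/dataloader.py below lets a smaller dataset repeat. For reference, a minimal standalone sketch of the same token-budget check, using a hypothetical helper name and the trainer attributes shown in the hunk:

    def has_enough_tokens(dataset_len: int, sequence_length: int,
                          num_remaining_train_steps: int, global_batch_size: int) -> bool:
        # Tokens available in the dataset vs. tokens the remaining steps will consume.
        total_tokens_dataset = dataset_len * sequence_length
        num_tokens_needed = num_remaining_train_steps * global_batch_size * sequence_length
        return num_tokens_needed <= total_tokens_dataset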
8 changes: 7 additions & 1 deletion src/nanotron/dataloader.py
@@ -505,7 +505,13 @@ def get_train_dataloader(
         consumed_train_samples=consumed_train_samples,
     )

-    return DataLoader(
+    class CyclingDataLoader(DataLoader):
+        def __iter__(self):
+            import itertools
+
+            return itertools.cycle(super().__iter__())
+
+    return CyclingDataLoader(
         train_dataset,
         batch_size=micro_batch_size,
         sampler=train_sampler,
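The new CyclingDataLoader makes the training dataloader yield batches indefinitely. One caveat worth noting: itertools.cycle caches every batch it sees on the first pass and then replays the cached sequence, so per-epoch reshuffling does not happen and memory use grows with the number of batches. A self-contained illustration of the same pattern on a toy dataset (not from the repository):

    import itertools

    from torch.utils.data import DataLoader, Dataset

    class ToyDataset(Dataset):
        def __len__(self):
            return 4

        def __getitem__(self, idx):
            return idx

    class CyclingDataLoader(DataLoader):
        def __iter__(self):
            # itertools.cycle stores each batch it sees and repeats them forever.
            return itertools.cycle(super().__iter__())

    loader = CyclingDataLoader(ToyDataset(), batch_size=2)
    batches = [b.tolist() for b in itertools.islice(iter(loader), 5)]
    print(batches)  # [[0, 1], [2, 3], [0, 1], [2, 3], [0, 1]]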
15 changes: 9 additions & 6 deletions src/nanotron/models/llama.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """PyTorch LLaMa model."""

-from typing import Dict, Optional, Union, List
+from typing import Dict, Optional, Union

 import torch
 from torch import nn
@@ -27,7 +27,6 @@
 from nanotron.logging import log_rank
 from nanotron.models import NanotronModel
 from nanotron.nn.activations import ACT2FN
-from nanotron.nn.layer_norm import TritonRMSNorm
 from nanotron.parallel import ParallelContext
 from nanotron.parallel.parameters import NanotronParameter
 from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer
@@ -608,15 +607,17 @@ def __init__(
         layer_idx: int,
     ):
         super().__init__()
-        self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        # self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.attn = CausalSelfAttention(
             config=config,
             parallel_config=parallel_config,
             tp_pg=tp_pg,
             layer_idx=layer_idx,
         )

-        self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        # self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)

     def forward(
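Swapping TritonRMSNorm for nn.LayerNorm changes the normalization itself, not just the kernel: RMSNorm only rescales by the root mean square and carries no bias, while LayerNorm subtracts the mean, divides by the standard deviation, and adds a learnable bias. A minimal sketch of the two formulas under the usual definitions (for comparison only, not code from this repository):

    import torch

    def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        # RMSNorm: rescale by the root mean square over the last dimension; no centering, no bias.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        return weight * (x / rms)

    def layer_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
                   eps: float = 1e-5) -> torch.Tensor:
        # LayerNorm: center, normalize by the standard deviation, then apply the affine transform.
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, unbiased=False, keepdim=True)
        return weight * (x - mean) / torch.sqrt(var + eps) + bias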
@@ -725,8 +726,10 @@ def __init__(

         self.final_layer_norm = PipelineBlock(
             p2p=self.p2p,
-            module_builder=TritonRMSNorm,
-            module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps},
+            # module_builder=TritonRMSNorm,
+            # module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps},
+            module_builder=nn.LayerNorm,
+            module_kwargs={"normalized_shape": config.hidden_size, "eps": config.rms_norm_eps},
             module_input_keys={"input"},
             module_output_keys={"hidden_states"},
         )  # TODO
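The module_kwargs key changes along with the builder because nn.LayerNorm names its first constructor argument normalized_shape, whereas the RMSNorm builder took hidden_size. A quick usage sketch with a placeholder hidden size and epsilon (not values from this commit's config):

    import torch
    from torch import nn

    hidden_size = 4096  # placeholder value
    final_layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-5)
    hidden_states = torch.randn(2, 16, hidden_size)
    print(final_layer_norm(hidden_states).shape)  # torch.Size([2, 16, 4096])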
1 change: 1 addition & 0 deletions src/nanotron/scaling/parametrization.py
@@ -37,6 +37,7 @@ def __init__(self, config: ModelArgs):
             TensorParallelColumnLinear: self._parametrize_column_linear,
             TensorParallelRowLinear: self._parametrize_row_linear,
             TritonRMSNorm: self._parametrize_layer_norm,
+            nn.LayerNorm: self._parametrize_layer_norm,
             TensorParallelEmbedding: self._parametrize_embedding,
         }

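The extra entry is needed because parametrization dispatches on the module's class; without an nn.LayerNorm key, the new layers introduced in llama.py would not be matched by the table. A minimal sketch of this dispatch-by-type pattern with a hypothetical initializer (not the repository's actual code):

    from torch import nn

    def _parametrize_layer_norm(module: nn.Module) -> None:
        # Hypothetical initializer: unit scale, zero bias (when a bias exists).
        nn.init.ones_(module.weight)
        if getattr(module, "bias", None) is not None:
            nn.init.zeros_(module.bias)

    PARAMETRIZATIONS = {nn.LayerNorm: _parametrize_layer_norm}

    def parametrize(module: nn.Module) -> None:
        # Look up the initializer registered for this exact module class.
        init_fn = PARAMETRIZATIONS.get(type(module))
        if init_fn is None:
            raise ValueError(f"No parametrization registered for {type(module).__name__}")
        init_fn(module)

    parametrize(nn.LayerNorm(128))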
36 changes: 17 additions & 19 deletions src/nanotron/trainer.py
@@ -60,7 +60,6 @@
 from nanotron.models.starcoder2 import Starcoder2ForTraining
 from nanotron.optim.clip_grads import clip_grad_norm
 from nanotron.parallel import ParallelContext
-from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp
 from nanotron.parallel.parameters import NanotronParameter, sanity_check
 from nanotron.parallel.pipeline_parallel.engine import (
     PipelineEngine,
@@ -72,7 +71,6 @@
 from nanotron.parallel.tied_parameters import (
     create_pg_for_tied_weights,
     get_tied_id_to_param,
-    sync_tied_weights_gradients,
     tie_parameters,
 )
 from nanotron.random import set_random_seed
@@ -272,11 +270,11 @@ def pre_training(self, *args, **kwargs):
             rank=0,
         )

-        current_time = datetime.datetime.now().strftime("%d/%m/%Y_%H:%M:%S")
+        datetime.datetime.now().strftime("%d/%m/%Y_%H:%M:%S")
         if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None:
             wandb.init(
                 project=self.config.general.project,
-                name=f"{current_time}_{self.config.general.run}",
+                name=f"{self.config.general.run}",
                 config={"nanotron_config": self.config.as_dict()},
             )
@@ -474,23 +472,23 @@ def training_step(
             self.grad_accumulator.fp32_grads_allreduce_handle.wait()

         # Sync tied weights
-        if not isinstance(self.model, DistributedDataParallel):
-            # Manually sync across DP if it's not handled by DDP
-            sync_gradients_across_dp(
-                module=self.model,
-                dp_pg=self.parallel_context.dp_pg,
-                reduce_op=dist.ReduceOp.AVG,
-                # TODO @thomasw21: This is too memory hungry, instead we run all_reduce
-                reduce_scatter=False,  # optimizer.inherit_from(ZeroDistributedOptimizer),
-                grad_accumulator=self.grad_accumulator,
-            )
+        # if not isinstance(self.model, DistributedDataParallel):
+        #     # Manually sync across DP if it's not handled by DDP
+        #     sync_gradients_across_dp(
+        #         module=self.model,
+        #         dp_pg=self.parallel_context.dp_pg,
+        #         reduce_op=dist.ReduceOp.AVG,
+        #         # TODO @thomasw21: This is too memory hungry, instead we run all_reduce
+        #         reduce_scatter=False,  # optimizer.inherit_from(ZeroDistributedOptimizer),
+        #         grad_accumulator=self.grad_accumulator,
+        #     )

         # TODO @nouamane: Put this in hooks so we can overlap communication with gradient computation on the last backward pass.
-        sync_tied_weights_gradients(
-            module=self.unwrapped_model,
-            parallel_context=self.parallel_context,
-            grad_accumulator=self.grad_accumulator,
-        )
+        # sync_tied_weights_gradients(
+        #     module=self.unwrapped_model,
+        #     parallel_context=self.parallel_context,
+        #     grad_accumulator=self.grad_accumulator,
+        # )

         # Clip gradients
         if self.config.optimizer.clip_grad is not None:
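The two disabled calls handled gradient synchronization that DDP does not cover: averaging gradients across the data-parallel group for non-DDP models, and reducing gradients of tied parameters. Skipping them is generally only safe when gradients are synchronized elsewhere (for example by DDP) or when running without data parallelism and tied weights. For context, a minimal sketch of what a manual data-parallel gradient all-reduce looks like, assuming torch.distributed is already initialized (an illustration of the general pattern, not nanotron's implementation):

    import torch.distributed as dist
    from torch import nn

    def sync_gradients_across_dp(module: nn.Module, dp_pg: dist.ProcessGroup) -> None:
        # Average every parameter gradient across the data-parallel group,
        # mirroring what DistributedDataParallel would do automatically.
        # ReduceOp.AVG requires a backend that supports it (e.g. NCCL).
        for param in module.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad, op=dist.ReduceOp.AVG, group=dp_pg)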
