diff --git a/run_train.py b/run_train.py
index b33231f4..aa54937e 100644
--- a/run_train.py
+++ b/run_train.py
@@ -130,14 +130,12 @@ def get_dataloader_from_data_stage(
         )
 
         # Check if we have enough samples for train_steps
-        total_tokens_dataset = len(dataloader.dataset) * trainer.sequence_length
-        num_tokens_needed_for_training = (
-            num_remaining_train_steps * trainer.global_batch_size * trainer.sequence_length
-        )
-        assert num_tokens_needed_for_training <= total_tokens_dataset, (
-            f"Dataset is too small for steps ({total_tokens_dataset} < {num_tokens_needed_for_training}), "
-            f"Try train_steps<={len(dataloader.dataset) // trainer.global_batch_size + trainer.iteration_step}"
-        )
+        len(dataloader.dataset) * trainer.sequence_length
+        (num_remaining_train_steps * trainer.global_batch_size * trainer.sequence_length)
+        # assert num_tokens_needed_for_training <= total_tokens_dataset, (
+        #     f"Dataset is too small for steps ({total_tokens_dataset} < {num_tokens_needed_for_training}), "
+        #     f"Try train_steps<={len(dataloader.dataset) // trainer.global_batch_size + trainer.iteration_step}"
+        # )
 
     # Case 3: Nanosets
     elif isinstance(data.dataset, NanosetDatasetsArgs):
diff --git a/src/nanotron/dataloader.py b/src/nanotron/dataloader.py
index 61f73557..8eb51e71 100644
--- a/src/nanotron/dataloader.py
+++ b/src/nanotron/dataloader.py
@@ -505,7 +505,13 @@ def get_train_dataloader(
         consumed_train_samples=consumed_train_samples,
     )
 
-    return DataLoader(
+    class CyclingDataLoader(DataLoader):
+        def __iter__(self):
+            import itertools
+
+            return itertools.cycle(super().__iter__())
+
+    return CyclingDataLoader(
         train_dataset,
         batch_size=micro_batch_size,
         sampler=train_sampler,
diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index ca8894b9..94f2f131 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """PyTorch LLaMa model."""
 
-from typing import Dict, Optional, Union, List
+from typing import Dict, Optional, Union
 
 import torch
 from torch import nn
@@ -27,7 +27,6 @@
 from nanotron.logging import log_rank
 from nanotron.models import NanotronModel
 from nanotron.nn.activations import ACT2FN
-from nanotron.nn.layer_norm import TritonRMSNorm
 from nanotron.parallel import ParallelContext
 from nanotron.parallel.parameters import NanotronParameter
 from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer
@@ -608,7 +607,8 @@ def __init__(
         layer_idx: int,
     ):
         super().__init__()
-        self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        # self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.attn = CausalSelfAttention(
             config=config,
             parallel_config=parallel_config,
@@ -616,7 +616,8 @@ def __init__(
             layer_idx=layer_idx,
         )
 
-        self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        # self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)
 
     def forward(
@@ -725,8 +726,10 @@ def __init__(
 
         self.final_layer_norm = PipelineBlock(
             p2p=self.p2p,
-            module_builder=TritonRMSNorm,
-            module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps},
+            # module_builder=TritonRMSNorm,
+            # module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps},
+            module_builder=nn.LayerNorm,
+            module_kwargs={"normalized_shape": config.hidden_size, "eps": config.rms_norm_eps},
             module_input_keys={"input"},
             module_output_keys={"hidden_states"},
         )  # TODO
diff --git a/src/nanotron/scaling/parametrization.py b/src/nanotron/scaling/parametrization.py
index e6241651..a8f5f93d 100644
--- a/src/nanotron/scaling/parametrization.py
+++ b/src/nanotron/scaling/parametrization.py
@@ -37,6 +37,7 @@ def __init__(self, config: ModelArgs):
             TensorParallelColumnLinear: self._parametrize_column_linear,
             TensorParallelRowLinear: self._parametrize_row_linear,
             TritonRMSNorm: self._parametrize_layer_norm,
+            nn.LayerNorm: self._parametrize_layer_norm,
             TensorParallelEmbedding: self._parametrize_embedding,
         }
 
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index 0eda00dc..da590c88 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -60,7 +60,6 @@
 from nanotron.models.starcoder2 import Starcoder2ForTraining
 from nanotron.optim.clip_grads import clip_grad_norm
 from nanotron.parallel import ParallelContext
-from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp
 from nanotron.parallel.parameters import NanotronParameter, sanity_check
 from nanotron.parallel.pipeline_parallel.engine import (
     PipelineEngine,
@@ -72,7+71,6 @@
 from nanotron.parallel.tied_parameters import (
     create_pg_for_tied_weights,
     get_tied_id_to_param,
-    sync_tied_weights_gradients,
     tie_parameters,
 )
 from nanotron.random import set_random_seed
@@ -272,11 +270,11 @@ def pre_training(self, *args, **kwargs):
             rank=0,
         )
 
-        current_time = datetime.datetime.now().strftime("%d/%m/%Y_%H:%M:%S")
+        datetime.datetime.now().strftime("%d/%m/%Y_%H:%M:%S")
         if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None:
             wandb.init(
                 project=self.config.general.project,
-                name=f"{current_time}_{self.config.general.run}",
+                name=f"{self.config.general.run}",
                 config={"nanotron_config": self.config.as_dict()},
             )
 
@@ -474,23 +472,23 @@ def training_step(
             self.grad_accumulator.fp32_grads_allreduce_handle.wait()
 
         # Sync tied weights
-        if not isinstance(self.model, DistributedDataParallel):
-            # Manually sync across DP if it's not handled by DDP
-            sync_gradients_across_dp(
-                module=self.model,
-                dp_pg=self.parallel_context.dp_pg,
-                reduce_op=dist.ReduceOp.AVG,
-                # TODO @thomasw21: This is too memory hungry, instead we run all_reduce
-                reduce_scatter=False,  # optimizer.inherit_from(ZeroDistributedOptimizer),
-                grad_accumulator=self.grad_accumulator,
-            )
+        # if not isinstance(self.model, DistributedDataParallel):
+        #     # Manually sync across DP if it's not handled by DDP
+        #     sync_gradients_across_dp(
+        #         module=self.model,
+        #         dp_pg=self.parallel_context.dp_pg,
+        #         reduce_op=dist.ReduceOp.AVG,
+        #         # TODO @thomasw21: This is too memory hungry, instead we run all_reduce
+        #         reduce_scatter=False,  # optimizer.inherit_from(ZeroDistributedOptimizer),
+        #         grad_accumulator=self.grad_accumulator,
+        #     )
 
         # TODO @nouamane: Put this in hooks so we can overlap communication with gradient computation on the last backward pass.
-        sync_tied_weights_gradients(
-            module=self.unwrapped_model,
-            parallel_context=self.parallel_context,
-            grad_accumulator=self.grad_accumulator,
-        )
+        # sync_tied_weights_gradients(
+        #     module=self.unwrapped_model,
+        #     parallel_context=self.parallel_context,
+        #     grad_accumulator=self.grad_accumulator,
+        # )
 
         # Clip gradients
         if self.config.optimizer.clip_grad is not None: