From a756526bab529c733a4afc909047bca2fb657330 Mon Sep 17 00:00:00 2001
From: NouamaneTazi
Date: Wed, 4 Dec 2024 18:51:18 +0000
Subject: [PATCH] tp

---
 examples/config_tiny_llama.yaml | 10 +++----
 scaling_benchmarks.py           | 52 ++++++++++++++++++++-------------
 src/nanotron/helpers.py         | 52 ++++++++++++++++++++++++++-------
 src/nanotron/trainer.py         | 23 +++++++++++----
 4 files changed, 95 insertions(+), 42 deletions(-)

diff --git a/examples/config_tiny_llama.yaml b/examples/config_tiny_llama.yaml
index 9714a6d2..af4a835a 100644
--- a/examples/config_tiny_llama.yaml
+++ b/examples/config_tiny_llama.yaml
@@ -36,21 +36,21 @@ model:
     bos_token_id: 0
     eos_token_id: 0
     hidden_act: silu
-    hidden_size: 2048
+    hidden_size: 3072
     initializer_range: 0.02
     intermediate_size: 8192
     is_llama_config: true
     max_position_embeddings: 2048
-    num_attention_heads: 32
-    num_hidden_layers: 24
-    num_key_value_heads: 32
+    num_attention_heads: 24
+    num_hidden_layers: 28
+    num_key_value_heads: 8
     pad_token_id: null
     pretraining_tp: 1
     rms_norm_eps: 1.0e-05
     rope_scaling: null
     tie_word_embeddings: true
     use_cache: true
-    vocab_size: 49152
+    vocab_size: 128256
 optimizer:
   accumulate_grad_in_fp32: true
   clip_grad: 1.0
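
The new model shape above (hidden_size 3072, 28 layers, 24 heads, 8 KV heads, vocab 128256, tied embeddings) lands at roughly 3.2B parameters. The back-of-the-envelope check below uses standard Llama/GQA accounting with norms and biases ignored; it is an illustrative sketch, not the estimate_num_params() formula from scaling_benchmarks.py.

# Rough parameter count for the config above; illustrative only.
def rough_llama_params(layers, hidden, heads, kv_heads, intermediate, vocab, tie_embeddings=True):
    head_dim = hidden // heads
    attn = 2 * hidden * hidden + 2 * hidden * kv_heads * head_dim  # q/o plus GQA k/v projections
    mlp = 3 * hidden * intermediate                                # gate, up, down projections
    embed = vocab * hidden * (1 if tie_embeddings else 2)          # tied in/out embeddings counted once
    return layers * (attn + mlp) + embed

print(rough_llama_params(28, 3072, 24, 8, 8192, 128256))  # ~3.21e9 parameters
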
diff --git a/scaling_benchmarks.py b/scaling_benchmarks.py
index dc22e21b..1e37a1ec 100644
--- a/scaling_benchmarks.py
+++ b/scaling_benchmarks.py
@@ -1,4 +1,5 @@
 # python scaling_benchmarks.py --base-config elie.yaml --debug
+# python scaling_benchmarks.py --debug
 import argparse
 import math
 import os
@@ -6,11 +7,13 @@
 import yaml
 from nanotron.logging import human_format
 
-VOCAB_SIZE = 128256
-NUM_KEY_VALUE_HEADS = 8
+VOCAB_SIZE = 32768
+NUM_KEY_VALUE_HEADS = None
 TIE_WORD_EMBEDDINGS = True
 ZERO_STAGE = 0
-TP_MODE = "ALL_REDUCE"
+# TP_MODE = "REDUCE_SCATTER"  # "REDUCE_SCATTER" "ALL_REDUCE"
+TP_MODE = "ALL_REDUCE"  # "REDUCE_SCATTER" "ALL_REDUCE"
+PROFILE = True
 
 
 def estimate_num_params(layers, hidden_size, heads, intermediate_size, tie_word_embeddings):
@@ -26,7 +29,7 @@ def create_config(
     batch_accum: int,
     seq_len: int,
     micro_batch_size: int = 1,
-    base_config_path: str = "examples/config_tiny_llama.yaml",
+    base_config_path: str = "examples/config_tiny_llama_bench.yaml",
     zero_stage: int = ZERO_STAGE,
     num_layers: int = 24,
     hidden_size: int = 2048,
@@ -88,10 +91,15 @@
     # Update run name to reflect configuration
     config["general"][
         "run"
-    ] = f"{N}_dp{dp}_tp{tp}_pp{pp}_acc{batch_accum}_mbs{micro_batch_size}_seq{seq_len}_zero{zero_stage}_l{num_layers}_h{hidden_size}_heads{num_attention_heads}"
+    ] = f"{N}_dp{dp}_tp{tp}_pp{pp}_acc{batch_accum}_mbs{micro_batch_size}_seq{seq_len}_zero{zero_stage}_tpmode{tp_mode[:3]}_l{num_layers}_h{hidden_size}_heads{num_attention_heads}"
 
     # Update benchmark CSV path
-    config["general"]["benchmark_csv_path"] = "bench_elie.csv"
+    config["general"]["benchmark_csv_path"] = "bench_tp.csv"
+
+    if PROFILE:
+        config["profiler"] = {}
+        config["profiler"]["profiler_export_path"] = "./tb_logs"
+        config["tokens"]["train_steps"] = 10
 
     return config
 
@@ -192,22 +200,24 @@ def main():
         # (1, 4, 1, 1, 4096, 2, num_layers, hidden_size, num_heads, intermediate_size),
        # (1, 8, 1, 1, 4096, 2, num_layers, hidden_size, num_heads, intermediate_size),
         # find best tput on 16 nodes with 4GBS
-        # Format: (dp, tp, pp, batch_accum, seq_len, mbs, ...)
-        # (1, 8, 1, 1, 2048, 16, num_layers, hidden_size, num_heads, intermediate_size),  # test max MBS
-        # (16, 8, 1, 1, 2048, 16, num_layers, hidden_size, num_heads, intermediate_size),  # ideal run i guess
-        # (32, 4, 1, 1, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size),  # TP=4
-        # (64, 2, 1, 1, 2048, 4, num_layers, hidden_size, num_heads, intermediate_size),  # TP=2
-        # (128, 1, 1, 1, 2048, 2, num_layers, hidden_size, num_heads, intermediate_size),  # TP=1
-        # # same for 8 nodes
-        (8, 8, 1, 1, 2048, 16, num_layers, hidden_size, num_heads, intermediate_size),
-        (16, 4, 1, 1, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size),
-        (32, 2, 1, 1, 2048, 4, num_layers, hidden_size, num_heads, intermediate_size),
-        (64, 1, 1, 1, 2048, 2, num_layers, hidden_size, num_heads, intermediate_size),
+        (1, 8, 1, 1, 4096, 8, num_layers, hidden_size, num_heads, intermediate_size),  # test max MBS
+        # (8, 1, 1, 1, 4096, 1, num_layers, hidden_size, num_heads, intermediate_size),  # test max MBS
+        # (1, 8, 1, 1, 4096, 64, num_layers, hidden_size, num_heads, intermediate_size),  # test max MBS
+        # (16, 8, 1, 1, 4096, 16, num_layers, hidden_size, num_heads, intermediate_size),  # ideal run i guess
+        # (32, 4, 1, 1, 4096, 8, num_layers, hidden_size, num_heads, intermediate_size),  # TP=4
+        # (64, 2, 1, 1, 4096, 4, num_layers, hidden_size, num_heads, intermediate_size),  # TP=2
+        # (128, 1, 1, 1, 4096, 2, num_layers, hidden_size, num_heads, intermediate_size),  # TP=1
+        # find best tput on 8 nodes with 1GBS
+        # (8, 8, 1, 1, 4096, 32, num_layers, hidden_size, num_heads, intermediate_size),
+        # (8, 8, 1, 2, 4096, 16, num_layers, hidden_size, num_heads, intermediate_size),
+        # (16, 4, 1, 2, 4096, 8, num_layers, hidden_size, num_heads, intermediate_size),
+        # (32, 2, 1, 2, 4096, 4, num_layers, hidden_size, num_heads, intermediate_size),
+        # (64, 1, 1, 2, 4096, 2, num_layers, hidden_size, num_heads, intermediate_size),
         # same for 4 nodes
-        # (4, 8, 1, 1, 2048, 16, num_layers, hidden_size, num_heads, intermediate_size),
-        # (8, 4, 1, 1, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size),
-        # (16, 2, 1, 1, 2048, 4, num_layers, hidden_size, num_heads, intermediate_size),
-        # (32, 1, 1, 1, 2048, 2, num_layers, hidden_size, num_heads, intermediate_size),
+        # (4, 8, 1, 1, 4096, 16, num_layers, hidden_size, num_heads, intermediate_size),
+        # (8, 4, 1, 1, 4096, 8, num_layers, hidden_size, num_heads, intermediate_size),
+        # (16, 2, 1, 1, 4096, 4, num_layers, hidden_size, num_heads, intermediate_size),
+        # (32, 1, 1, 1, 4096, 2, num_layers, hidden_size, num_heads, intermediate_size),
     ]
 
     configurations.extend(model_configs)
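
For reference, the run-name template now embeds the TP mode. A minimal sketch of what the single uncommented configuration above (dp=1, tp=8, pp=1, acc=1, seq=4096, mbs=8) produces; N is a placeholder for the model-size tag, which the script presumably derives from estimate_num_params():

# Sketch of the updated run-name template; N and the layer/hidden/head values are placeholders.
N = "3.2G"
dp, tp, pp, batch_accum, micro_batch_size, seq_len = 1, 8, 1, 1, 8, 4096
zero_stage, tp_mode = 0, "ALL_REDUCE"
num_layers, hidden_size, num_attention_heads = 28, 3072, 24

run = (
    f"{N}_dp{dp}_tp{tp}_pp{pp}_acc{batch_accum}_mbs{micro_batch_size}_seq{seq_len}"
    f"_zero{zero_stage}_tpmode{tp_mode[:3]}_l{num_layers}_h{hidden_size}_heads{num_attention_heads}"
)
print(run)  # 3.2G_dp1_tp8_pp1_acc1_mbs8_seq4096_zero0_tpmodeALL_l28_h3072_heads24
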
diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py
index 2369e668..f61333fd 100644
--- a/src/nanotron/helpers.py
+++ b/src/nanotron/helpers.py
@@ -20,7 +20,7 @@
 from nanotron import logging
 from nanotron.config import Config, DatasetStageArgs, LRSchedulerArgs, OptimizerArgs, ParallelismArgs
 from nanotron.distributed import ProcessGroup
-from nanotron.logging import LogItem, log_rank
+from nanotron.logging import LogItem, human_format, log_rank
 from nanotron.models.base import NanotronModel
 from nanotron.optim.base import BaseOptimizer, Optimizer
 from nanotron.optim.gradient_accumulator import (
@@ -480,8 +480,8 @@ def get_profiler(config: Config):
             activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             schedule=torch.profiler.schedule(wait=1, warmup=1, active=1, repeat=1, skip_first=3),
             on_trace_ready=on_trace_ready,
-            # record_shapes=True,
-            # profile_memory=True,
+            record_shapes=True,
+            profile_memory=True,
             with_stack=True,
         )
     else:
@@ -592,8 +592,11 @@ def create_table_log(
     hardware_tflops,
     tokens_per_sec,
     bandwidth,
+    num_params,
     slurm_job_id,
 ):
+    print("num_params")
+    print(num_params)
     return [
         LogItem("job_id", slurm_job_id, "s"),
         LogItem("name", config.general.run, "s"),
@@ -613,23 +616,49 @@
         LogItem("RS Intra-node (GB/s)", bandwidth["reduce_scatter_intranode"], ".2f"),
         LogItem("Mem Alloc (GB)", torch.cuda.max_memory_allocated() / 1024**3, ".2f"),
         LogItem("Mem Res (GB)", torch.cuda.max_memory_reserved() / 1024**3, ".2f"),
+        # Important config columns
+        LogItem("dp", config.parallelism.dp, "d"),
+        LogItem("pp", config.parallelism.pp, "d"),
+        LogItem("tp", config.parallelism.tp, "d"),
+        LogItem("pp_engine", str(config.parallelism.pp_engine), "s"),
+        LogItem("tp_mode", config.parallelism.tp_mode, "s"),
+        LogItem("tp_async_comm", str(config.parallelism.tp_linear_async_communication), "s"),
+        LogItem("hidden_size", config.model.model_config.hidden_size, "d"),
+        LogItem("hidden_act", config.model.model_config.hidden_act, "s"),
+        LogItem("num_layers", config.model.model_config.num_hidden_layers, "d"),
+        LogItem("num_heads", config.model.model_config.num_attention_heads, "d"),
+        LogItem("num_kv_heads", config.model.model_config.num_key_value_heads, "d"),
+        LogItem("max_pos", config.model.model_config.max_position_embeddings, "d"),
+        LogItem("vocab_size", config.model.model_config.vocab_size, "d"),
+        LogItem("tie_word_embeddings", str(config.model.model_config.tie_word_embeddings), "s"),
+        LogItem("dtype", str(config.model.dtype), "s"),
+        LogItem("zero_stage", config.optimizer.zero_stage, "d"),
+        LogItem("ddp_bucket_cap_mb", config.model.ddp_bucket_cap_mb, "d"),
+        LogItem("accumulate_grad_in_fp32", str(config.optimizer.accumulate_grad_in_fp32), "s"),
+        # Params
+        LogItem("Total Params", num_params["total"], "human_format"),
+        LogItem("Local Params", num_params["local"], "human_format"),
     ]
 
 
+def get_formatted_value(item):
+    if item.log_format == "human_format":
+        return human_format(item.scalar_value)
+    return f"{item.scalar_value:{item.log_format}}"
+
+
 def create_table_output(table_log, column_widths):
     header_row = "| " + " | ".join([item.tag.ljust(width) for item, width in zip(table_log, column_widths)]) + " |"
     separator_row = "| " + " | ".join(["-" * width for width in column_widths]) + " |"
     data_row = (
         "| "
-        + " | ".join(
-            [f"{item.scalar_value:{item.log_format}}".ljust(width) for item, width in zip(table_log, column_widths)]
-        )
+        + " | ".join([get_formatted_value(item).ljust(width) for item, width in zip(table_log, column_widths)])
         + " |"
     )
     return f"{header_row}\n{separator_row}\n{data_row}"
 
 
-def write_to_csv(csv_filename, table_log, model_tflops, slurm_job_id, config):
+def write_to_csv(csv_filename, table_log, model_tflops, slurm_job_id):
     """Write benchmark results to a CSV file with file locking using fcntl."""
     import fcntl
 
@@ -646,7 +675,7 @@ def write_to_csv(csv_filename, table_log, model_tflops, slurm_job_id, config):
 
     # Format row data
     header = [item.tag for item in table_log]
-    row = [f"{item.scalar_value:{item.log_format}}" for item in table_log]
+    row = [get_formatted_value(item) for item in table_log]
 
     # Use fcntl for file locking
     max_attempts = 10
@@ -708,13 +737,14 @@ def log_throughput(
         "reduce_scatter_intranode": 0,
         "all_gather_intranode": 0,
     },
+    num_params={"total": 0, "local": 0},
 ):
     slurm_job_id = os.environ.get("SLURM_JOB_ID", "N/A")
 
     table_log = create_table_log(
-        config, parallel_context, model_tflops, hardware_tflops, tokens_per_sec, bandwidth, slurm_job_id
+        config, parallel_context, model_tflops, hardware_tflops, tokens_per_sec, bandwidth, num_params, slurm_job_id
     )
-    column_widths = [max(len(item.tag), len(f"{item.scalar_value:{item.log_format}}")) for item in table_log]
+    column_widths = [max(len(item.tag), len(get_formatted_value(item))) for item in table_log]
     table_output = create_table_output(table_log, column_widths)
 
     log_rank(
@@ -725,7 +755,7 @@
     )
 
     if dist.get_rank(parallel_context.world_pg) == 0:
-        write_to_csv(config.general.benchmark_csv_path, table_log, model_tflops, slurm_job_id, config)
+        write_to_csv(config.general.benchmark_csv_path, table_log, model_tflops, slurm_job_id)
 
     dist.barrier(group=parallel_context.world_pg)
 
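
The human_format code path added to the table and CSV output can be exercised in isolation. The sketch below is self-contained: LogItem is a minimal stand-in (the real one lives in nanotron.logging) and this human_format is a simplified approximation of nanotron's, so the exact strings may differ.

from dataclasses import dataclass


@dataclass
class LogItem:  # minimal stand-in for nanotron.logging.LogItem
    tag: str
    scalar_value: object
    log_format: str


def human_format(num):  # simplified approximation of nanotron.logging.human_format
    for unit in ["", "K", "M", "G", "T"]:
        if abs(num) < 1000:
            return f"{num:.2f}{unit}"
        num /= 1000
    return f"{num:.2f}P"


def get_formatted_value(item):  # same logic as the helper added in this patch
    if item.log_format == "human_format":
        return human_format(item.scalar_value)
    return f"{item.scalar_value:{item.log_format}}"


print(get_formatted_value(LogItem("Total Params", 3_212_574_720, "human_format")))  # -> 3.21G
print(get_formatted_value(LogItem("tokens_per_sec", 123456.789, ".2f")))            # -> 123456.79
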
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index 8aa98254..a9d8876a 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -487,6 +487,9 @@
         if self.iteration_step < self.initial_iter_step + 5:
             log_memory(logger=logger)
 
+        # if self.iteration_step == self.initial_iter_step and dist.get_rank(self.parallel_context.world_pg) == 0:
+        #     torch.cuda.memory._record_memory_history(max_entries=100000)
+
         outputs = self.pipeline_engine.train_batch_iter(
             model=self.model,
             pg=self.parallel_context.pp_pg,
@@ -495,6 +498,12 @@
             grad_accumulator=self.grad_accumulator,
         )
 
+        # if self.iteration_step == self.initial_iter_step and dist.get_rank(self.parallel_context.world_pg) == 0:
+        #     snapshot_save_path = "snapshots/" + os.environ["SLURM_JOB_ID"] + "_" + self.config.general.run + "_memory_snapshot.pkl"
+        #     log_rank(f"Dumping memory snapshot to {snapshot_save_path}", logger=logger, level=logging.INFO)
+        #     torch.cuda.memory._dump_snapshot(snapshot_save_path)
+        #     torch.cuda.memory._record_memory_history(enabled=None)
+
         if self.iteration_step < self.initial_iter_step + 5:
             log_memory(logger=logger)
 
@@ -667,12 +676,14 @@
                 hardware_tflops,
                 tokens_per_sec,
                 bandwidth=self.BANDWIDTHS,
+                num_params=self.num_params,
             )
             log_rank("Throughput logging complete", logger=logger, level=logging.INFO)
-            if "SLURM_JOB_ID" in os.environ:
-                os.system("scancel " + os.environ["SLURM_JOB_ID"])
-            else:
-                exit(0)
+            if not self.config.profiler:
+                if "SLURM_JOB_ID" in os.environ:
+                    os.system("scancel " + os.environ["SLURM_JOB_ID"])
+                else:
+                    exit(0)
 
     def init_model(self) -> Union[NanotronModel, DistributedDataParallel]:
         """Initialize the model and load weights from checkpoint if needed."""
@@ -811,10 +822,12 @@
         dist.all_reduce(total_params, group=parallel_context.pp_pg, async_op=False, op=dist.ReduceOp.SUM)  # PP
         dist.all_reduce(total_size, group=parallel_context.tp_pg, async_op=False, op=dist.ReduceOp.SUM)
         dist.all_reduce(total_size, group=parallel_context.pp_pg, async_op=False, op=dist.ReduceOp.SUM)
+        total_params = total_params.item()
+        self.num_params = {"total": total_params, "local": num_params}
 
         # TODO @nouamanetazi: better memory logs
         log_rank(
-            f"Total number of parameters: {human_format(total_params.item())} ({total_size.item() / 1024**2:.2f}MiB)",
+            f"Total number of parameters: {human_format(total_params)} ({total_size.item() / 1024**2:.2f}MiB)",
             logger=logger,
             level=logging.INFO,
             group=parallel_context.world_pg,
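
The commented-out hooks in training_step follow PyTorch's CUDA memory-snapshot workflow (torch.cuda.memory._record_memory_history / _dump_snapshot, both used verbatim in the hunk above). A stand-alone sketch of that workflow, with a placeholder workload and output path:

import torch

# Start recording allocation history (same call as the commented-out hook above).
torch.cuda.memory._record_memory_history(max_entries=100000)

# Placeholder workload standing in for pipeline_engine.train_batch_iter(...).
x = torch.randn(4096, 4096, device="cuda", requires_grad=True)
(x @ x).sum().backward()

# Dump the snapshot (viewable with PyTorch's memory_viz tool) and stop recording.
torch.cuda.memory._dump_snapshot("memory_snapshot.pkl")
torch.cuda.memory._record_memory_history(enabled=None)
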