diff --git a/.github/workflows/llama_tests.yaml b/.github/workflows/llama_tests.yaml new file mode 100644 index 00000000..56d55a8e --- /dev/null +++ b/.github/workflows/llama_tests.yaml @@ -0,0 +1,57 @@ +name: Run examples + +on: + push: + branches: [ main ] + # Only run tests if we modify the following files + paths: + - "src/**/*.py" + - "examples/**/*.py" + - "tests/**/*.py" + + pull_request: + branches: [ '**' ] + paths: + - "src/**/*.py" + - "examples/**/*.py" + - "tests/**/*.py" + +jobs: + tests: + # NOTE: 8-a10 to run LLama + runs-on: [multi-gpu, nvidia-gpu, 8-a10, ci] + container: + image: runpod/pytorch:2.1.1-py3.10-cuda12.1.1-devel-ubuntu22.04 + ports: + - 80 + options: --gpus all --shm-size "8G" + steps: + - uses: actions/checkout@v3 + - name: Python environment + run: | + which python + python --version + + - name: Check Pytorch version + run: | + nvidia-smi + python -c "import torch; print('torch:', torch.__version__, torch)" + python -c "import torch; print('CUDA available:', torch.cuda.is_available())" + + - name: Install nanotron's dependencies + run: | + python -m pip install --upgrade pip + pip install packaging + pip install wheel + pip install -e . + pip install -e .[dev] + pip install -e .[test] + + - name: Show installed libraries and their versions + run: pip freeze | tee installed.txt + + - name: Run tiny Llama example + run: ./examples/train_tiny_llama.sh + + - name: Run Llama loss tests + run: pytest -sv tests/test_train_llama.py diff --git a/examples/config_tiny_llama.yaml b/examples/config_tiny_llama.yaml index 0e87c663..53aed0a9 100644 --- a/examples/config_tiny_llama.yaml +++ b/examples/config_tiny_llama.yaml @@ -1,6 +1,6 @@ checkpoints: checkpoint_interval: 10 - checkpoints_path: /fsx/ferdinandmom/ferdinand-hf/nanotron/checkpoints + checkpoints_path: /fsx/haojun/nanotron/checkpoints checkpoints_path_is_shared_file_system: false resume_checkpoint_path: null save_initial_state: false @@ -37,24 +37,28 @@ general: run: tiny_llama_%date_%jobid seed: 42 step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info model: ddp_bucket_cap_mb: 25 dtype: bfloat16 init_method: std: 0.025 - # use_mup: true # uncomment this and comment the std line above to use spectral µTransfer make_vocab_size_divisible_by: 1 model_config: bos_token_id: 1 eos_token_id: 2 hidden_act: silu - hidden_size: 32 + hidden_size: 16 initializer_range: 0.02 - intermediate_size: 128 + intermediate_size: 64 is_llama_config: true max_position_embeddings: 256 num_attention_heads: 4 - num_hidden_layers: 10 + num_hidden_layers: 2 num_key_value_heads: 4 pad_token_id: null pretraining_tp: 1 @@ -67,11 +71,11 @@ optimizer: accumulate_grad_in_fp32: true clip_grad: 1.0 learning_rate_scheduler: - learning_rate: 0.001 + learning_rate: 0.0003 lr_decay_starting_step: null - lr_decay_steps: null + lr_decay_steps: 48 lr_decay_style: cosine - lr_warmup_steps: 2000 # 20% of the total steps + lr_warmup_steps: 2 lr_warmup_style: linear min_decay_lr: 1.0e-05 optimizer_factory: @@ -85,37 +89,12 @@ optimizer: parallelism: dp: 2 expert_parallel_size: 1 - pp: 1 + pp: 2 pp_engine: 1f1b tp: 2 tp_linear_async_communication: true tp_mode: REDUCE_SCATTER -data_stages: - - name: Stable Training Stage - start_training_step: 1 - data: - dataset: - dataset_overwrite_cache: false - dataset_processing_num_proc_per_process: 1 - hf_dataset_config_name: null - hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small - hf_dataset_splits: train - text_column_name: completion - num_loading_workers: 1 - seed: 42 - - name: Annealing Phase - start_training_step: 10 - data: - dataset: - dataset_overwrite_cache: false - dataset_processing_num_proc_per_process: 1 - hf_dataset_config_name: null - hf_dataset_or_datasets: HuggingFaceH4/testing_codealpaca_small - hf_dataset_splits: train - text_column_name: completion - num_loading_workers: 1 - seed: 42 -lighteval: null +profiler: null tokenizer: tokenizer_max_length: null tokenizer_name_or_path: gpt2 @@ -126,16 +105,5 @@ tokens: limit_val_batches: 0 micro_batch_size: 2 sequence_length: 32 - train_steps: 15 + train_steps: 50 val_check_interval: -1 -checkpoints: - checkpoint_interval: 10 - checkpoints_path: checkpoints - checkpoints_path_is_shared_file_system: false - resume_checkpoint_path: checkpoints - save_initial_state: false -profiler: null -logging: - iteration_step_info_interval: 1 - log_level: info - log_level_replica: info diff --git a/examples/config_train_llama.py b/examples/config_train_llama.py new file mode 100644 index 00000000..30104219 --- /dev/null +++ b/examples/config_train_llama.py @@ -0,0 +1,119 @@ +""" Example python script to generate a YAML config file which can be used to run a training with nanotron. Refer to "examples" section in the `/README.md` for more information.""" +import os + +from nanotron.config import ( + AdamWOptimizerArgs, + CheckpointsArgs, + Config, + DataArgs, + DatasetStageArgs, + GeneralArgs, + LlamaConfig, + LoggingArgs, + LRSchedulerArgs, + ModelArgs, + OptimizerArgs, + ParallelismArgs, + PretrainDatasetsArgs, + RandomInit, + TokenizerArgs, + TokensArgs, +) +from nanotron.logging import human_format + +model_config = LlamaConfig( + # Config for a tiny model model with 1.62M parameters + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=768, + initializer_range=0.02, + intermediate_size=3072, + max_position_embeddings=512, + num_attention_heads=16, + num_hidden_layers=12, + num_key_value_heads=16, + pretraining_tp=1, + rms_norm_eps=1e-05, + rope_scaling=None, + tie_word_embeddings=True, + use_cache=True, + vocab_size=50272, +) + +num_params = human_format( + model_config.vocab_size * model_config.hidden_size * 2 + + model_config.num_hidden_layers + * ( + 3 * model_config.hidden_size * model_config.intermediate_size + + 4 * model_config.hidden_size * model_config.hidden_size + ) +).replace(".", "p") + +print(f"Model has {num_params} parameters") + +seed = 42 + +learning_rate = LRSchedulerArgs( + learning_rate=3e-4, lr_warmup_steps=2, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-5 +) + +optimizer = OptimizerArgs( + zero_stage=0, + weight_decay=0.01, + clip_grad=1.0, + accumulate_grad_in_fp32=True, + learning_rate_scheduler=learning_rate, + optimizer_factory=AdamWOptimizerArgs( + adam_eps=1e-08, + adam_beta1=0.9, + adam_beta2=0.95, + torch_adam_is_fused=True, + ), +) + +parallelism = ParallelismArgs( + dp=4, + pp=1, + tp=2, + pp_engine="1f1b", + tp_mode="REDUCE_SCATTER", + tp_linear_async_communication=True, +) + +# a global batch-size of 1M tokens. micro_batch_size * dp * sequence_length * batch_accumulation_per_replica +tokens = TokensArgs(sequence_length=512, train_steps=200, micro_batch_size=128, batch_accumulation_per_replica=4) + +checkpoints_path = os.path.dirname(os.path.dirname(__file__)) + "/checkpoints" +os.makedirs(checkpoints_path, exist_ok=True) + +config = Config( + general=GeneralArgs(project="debug", run="tiny_llama_%date_%jobid", seed=seed), + checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10), + parallelism=parallelism, + model=ModelArgs(init_method=RandomInit(std=0.025), model_config=model_config), + tokenizer=TokenizerArgs("gpt2"), + optimizer=optimizer, + logging=LoggingArgs(), + tokens=tokens, + data_stages=[ + DatasetStageArgs( + name="Stable Training Stage", + start_training_step=1, + data=DataArgs( + dataset=PretrainDatasetsArgs(hf_dataset_or_datasets="roneneldan/TinyStories", text_column_name="text"), + seed=seed, + ), + ) + ], + profiler=None, +) + + +if __name__ == "__main__": + dir = os.path.dirname(__file__) + + # Save config as YAML file + config.save_as_yaml(f"{dir}/config_train_llama.yaml") + + # You can now train a model with this config using `/run_train.py` diff --git a/examples/config_train_llama.yaml b/examples/config_train_llama.yaml new file mode 100644 index 00000000..d6ae7b32 --- /dev/null +++ b/examples/config_train_llama.yaml @@ -0,0 +1,97 @@ +checkpoints: + checkpoint_interval: 10 + checkpoints_path: /fsx/haojun/nanotron/checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: roneneldan/TinyStories + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Stable Training Stage + start_training_step: 1 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: debug + run: tiny_llama_%date_%jobid + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + std: 0.025 + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 1 + eos_token_id: 2 + hidden_act: silu + hidden_size: 768 + initializer_range: 0.02 + intermediate_size: 3072 + is_llama_config: true + max_position_embeddings: 512 + num_attention_heads: 16 + num_hidden_layers: 12 + num_key_value_heads: 16 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-05 + rope_scaling: null + tie_word_embeddings: true + use_cache: true + vocab_size: 50272 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 198 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 4 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 2 + tp_linear_async_communication: true + tp_mode: REDUCE_SCATTER +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: gpt2 + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 4 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 128 + sequence_length: 512 + train_steps: 200 + val_check_interval: -1 diff --git a/run_train.py b/run_train.py index 8dc16f7a..617d231b 100644 --- a/run_train.py +++ b/run_train.py @@ -133,7 +133,7 @@ def get_dataloader_from_data_stage( ) assert num_tokens_needed_for_training <= total_tokens_dataset, ( f"Dataset is too small for steps ({total_tokens_dataset} < {num_tokens_needed_for_training}), " - f"Try train_steps<={len(dataloader.dataset) // trainer.global_batch_size + trainer.start_iteration_step}" + f"Try train_steps<={len(dataloader.dataset) // trainer.global_batch_size + trainer.iteration_step}" ) else: raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}") diff --git a/tests/test_train_llama.py b/tests/test_train_llama.py new file mode 100644 index 00000000..631b1b68 --- /dev/null +++ b/tests/test_train_llama.py @@ -0,0 +1,107 @@ +# """Script to test correctness of training script by comparing loss value after 100th iteration with expected loss value + +# ```bash +# pytest -sv tests/test_train_llama.py or python tests/test_train_llama.py +# ``` +# """ + +import atexit +import os +import re +import signal +import subprocess + +CONFIG_FILE = "examples/config_train_llama.yaml" +CREATE_CONFIG_FILE = "examples/config_train_llama.py" +TRAIN_SCRIPT = "run_train.py" +NUM_GPUS = 8 + +## 100+ steps: lm_loss < 3.5 +## 200 steps: lm_loss < 3 + +EXPECTED_LOSS = 3.5 +CHECK_ITERATION = 100 + +EXPECTED_LOSS_END = 3 +CHECK_ITERATION_END = 200 + + +def exit_with_children(): + """Kill all children processes when this process exits""" + os.killpg(0, signal.SIGKILL) + + +def extract_loss(line): + """Extract loss value from the line""" + # extract loss value of the type | lm_loss: 5.33 + try: + return float(re.search(r"lm_loss: (\d+.\d)", line.decode("utf-8")).group(1)) + except AttributeError: + raise ValueError(f"Could not extract loss value from line: {line}") + + +def test_tiny_llama(): + cmd = f"python {CREATE_CONFIG_FILE}" + subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + + cmd = f'FI_PROVIDER="efa" CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={NUM_GPUS} --rdzv_endpoint=localhost:29800 {TRAIN_SCRIPT} --config-file {CONFIG_FILE}' + os.setpgrp() # create new process group, become its leader + atexit.register(exit_with_children) # kill all children processes when this process exits + + process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + + # Read and print output in real-time + while True: + line = process.stdout.readline() + if process.poll() is not None and line == b"": + break + if line: + print(line.decode("utf-8"), end="") + # for all iterations >= CHECK_ITERATION, loss should be below EXPECTED_LOSS + if re.search(r"iteration: (\d+) / ", line.decode("utf-8")): + if int(re.search(r"iteration: (\d+) / ", line.decode("utf-8")).group(1)) >= CHECK_ITERATION: + loss = extract_loss(line) + assert loss < EXPECTED_LOSS + + if re.search(rf"iteration: {CHECK_ITERATION_END} / ", line.decode("utf-8")): + loss = extract_loss(line) + assert loss < EXPECTED_LOSS_END + + process.wait() # Wait for the process to finish + assert process.returncode == 0 + + +if __name__ == "__main__": + cmd = f"python {CREATE_CONFIG_FILE}" + subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + + cmd = f'FI_PROVIDER="efa" CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={NUM_GPUS} --rdzv_endpoint=localhost:29800 {TRAIN_SCRIPT} --config-file {CONFIG_FILE}' + os.setpgrp() # create new process group, become its leader + atexit.register(exit_with_children) # kill all children processes when this process exits + + process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + + try: + # Read and print output in real-time + while True: + line = process.stdout.readline() + if process.poll() is not None and line == b"": + break + if line: + print(line.decode("utf-8"), end="") + + # for all iterations >= CHECK_ITERATION, loss should be below EXPECTED_LOSS + if re.search(r"iteration: (\d+) / ", line.decode("utf-8")): + if int(re.search(r"iteration: (\d+) / ", line.decode("utf-8")).group(1)) >= CHECK_ITERATION: + loss = extract_loss(line) + assert loss < EXPECTED_LOSS + # at iteration= CHECK_ITERATION, loss should be below EXPECTED_LOSS_END + if re.search(rf"iteration: {CHECK_ITERATION_END} / ", line.decode("utf-8")): + loss = extract_loss(line) + assert loss < EXPECTED_LOSS_END + process.wait() # Wait for the process to finish + assert process.returncode == 0 + except AssertionError: + print("Command failed with exit code:", process.returncode) + else: + print("Command executed successfully.")