Commit: llama tests
zzhhjjj committed Apr 30, 2024
1 parent 1c7f038 commit c565621
Showing 6 changed files with 396 additions and 48 deletions.
57 changes: 57 additions & 0 deletions .github/workflows/llama_tests.yaml
@@ -0,0 +1,57 @@
name: Run examples

on:
push:
branches: [ main ]
# Only run tests if we modify the following files
paths:
- "src/**/*.py"
- "examples/**/*.py"
- "tests/**/*.py"

pull_request:
branches: [ '**' ]
paths:
- "src/**/*.py"
- "examples/**/*.py"
- "tests/**/*.py"

jobs:
tests:
# NOTE: 8-a10 to run Llama
runs-on: [multi-gpu, nvidia-gpu, 8-a10, ci]
container:
image: runpod/pytorch:2.1.1-py3.10-cuda12.1.1-devel-ubuntu22.04
ports:
- 80
options: --gpus all --shm-size "8G"
steps:
- uses: actions/checkout@v3
- name: Python environment
run: |
which python
python --version
- name: Check Pytorch version
run: |
nvidia-smi
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Install nanotron's dependencies
run: |
python -m pip install --upgrade pip
pip install packaging
pip install wheel
pip install -e .
pip install -e .[dev]
pip install -e .[test]
- name: Show installed libraries and their versions
run: pip freeze | tee installed.txt

- name: Run tiny Llama example
run: ./examples/train_tiny_llama.sh

- name: Run Llama loss tests
run: pytest -sv tests/test_train_llama.py
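
For reference, the 8-GPU runner requested above matches the parallelism grid set in the updated examples/config_tiny_llama.yaml shown below; a minimal sketch of that check in Python, with dp/pp/tp copied from the diff that follows:

# dp, pp, tp are the values from the updated examples/config_tiny_llama.yaml in this commit
dp, pp, tp = 2, 2, 2
world_size = dp * pp * tp        # data-parallel x pipeline-parallel x tensor-parallel ranks
assert world_size == 8           # matches the "8-a10" runner label used by this workflow
print(f"The tiny Llama example needs {world_size} GPUs")
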
62 changes: 15 additions & 47 deletions examples/config_tiny_llama.yaml
@@ -1,6 +1,6 @@
checkpoints:
checkpoint_interval: 10
checkpoints_path: /fsx/ferdinandmom/ferdinand-hf/nanotron/checkpoints
checkpoints_path: /fsx/haojun/nanotron/checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
save_initial_state: false
@@ -37,24 +37,28 @@ general:
run: tiny_llama_%date_%jobid
seed: 42
step: null
lighteval: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
model:
ddp_bucket_cap_mb: 25
dtype: bfloat16
init_method:
std: 0.025
# use_mup: true # uncomment this and comment the std line above to use spectral µTransfer
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 1
eos_token_id: 2
hidden_act: silu
hidden_size: 32
hidden_size: 16
initializer_range: 0.02
intermediate_size: 128
intermediate_size: 64
is_llama_config: true
max_position_embeddings: 256
num_attention_heads: 4
num_hidden_layers: 10
num_hidden_layers: 2
num_key_value_heads: 4
pad_token_id: null
pretraining_tp: 1
@@ -67,11 +71,11 @@ optimizer:
accumulate_grad_in_fp32: true
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.001
learning_rate: 0.0003
lr_decay_starting_step: null
lr_decay_steps: null
lr_decay_steps: 48
lr_decay_style: cosine
lr_warmup_steps: 2000 # 20% of the total steps
lr_warmup_steps: 2
lr_warmup_style: linear
min_decay_lr: 1.0e-05
optimizer_factory:
@@ -85,37 +89,12 @@ optimizer:
parallelism:
dp: 2
expert_parallel_size: 1
pp: 1
pp: 2
pp_engine: 1f1b
tp: 2
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
data_stages:
- name: Stable Training Stage
start_training_step: 1
data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small
hf_dataset_splits: train
text_column_name: completion
num_loading_workers: 1
seed: 42
- name: Annealing Phase
start_training_step: 10
data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: HuggingFaceH4/testing_codealpaca_small
hf_dataset_splits: train
text_column_name: completion
num_loading_workers: 1
seed: 42
lighteval: null
profiler: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: gpt2
@@ -126,16 +105,5 @@ tokens:
limit_val_batches: 0
micro_batch_size: 2
sequence_length: 32
train_steps: 15
train_steps: 50
val_check_interval: -1
checkpoints:
checkpoint_interval: 10
checkpoints_path: checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: checkpoints
save_initial_state: false
profiler: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
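
The new schedule values are self-consistent: lr_decay_steps looks like train_steps minus lr_warmup_steps (48 = 50 - 2). A small sanity check of the updated numbers, assuming batch_accumulation_per_replica is 1 (that field is not shown in this hunk):

train_steps, lr_warmup_steps, lr_decay_steps = 50, 2, 48
assert lr_decay_steps == train_steps - lr_warmup_steps      # 48 = 50 - 2

micro_batch_size, dp, sequence_length = 2, 2, 32
grad_accum = 1                                              # assumption: not shown in this diff
tokens_per_step = micro_batch_size * dp * grad_accum * sequence_length
print(tokens_per_step)                                      # 128 tokens per optimizer step
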
119 changes: 119 additions & 0 deletions examples/config_train_llama.py
@@ -0,0 +1,119 @@
""" Example python script to generate a YAML config file which can be used to run a training with nanotron. Refer to "examples" section in the `/README.md` for more information."""
import os

from nanotron.config import (
AdamWOptimizerArgs,
CheckpointsArgs,
Config,
DataArgs,
DatasetStageArgs,
GeneralArgs,
LlamaConfig,
LoggingArgs,
LRSchedulerArgs,
ModelArgs,
OptimizerArgs,
ParallelismArgs,
PretrainDatasetsArgs,
RandomInit,
TokenizerArgs,
TokensArgs,
)
from nanotron.logging import human_format

model_config = LlamaConfig(
# Config for a small Llama model (~190M parameters according to the estimate printed below)
bos_token_id=1,
eos_token_id=2,
hidden_act="silu",
hidden_size=768,
initializer_range=0.02,
intermediate_size=3072,
max_position_embeddings=512,
num_attention_heads=16,
num_hidden_layers=12,
num_key_value_heads=16,
pretraining_tp=1,
rms_norm_eps=1e-05,
rope_scaling=None,
tie_word_embeddings=True,
use_cache=True,
vocab_size=50272,
)

num_params = human_format(
model_config.vocab_size * model_config.hidden_size * 2
+ model_config.num_hidden_layers
* (
3 * model_config.hidden_size * model_config.intermediate_size
+ 4 * model_config.hidden_size * model_config.hidden_size
)
).replace(".", "p")

print(f"Model has {num_params} parameters")

seed = 42

learning_rate = LRSchedulerArgs(
learning_rate=3e-4, lr_warmup_steps=2, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-5
)

optimizer = OptimizerArgs(
zero_stage=0,
weight_decay=0.01,
clip_grad=1.0,
accumulate_grad_in_fp32=True,
learning_rate_scheduler=learning_rate,
optimizer_factory=AdamWOptimizerArgs(
adam_eps=1e-08,
adam_beta1=0.9,
adam_beta2=0.95,
torch_adam_is_fused=True,
),
)

parallelism = ParallelismArgs(
dp=4,
pp=1,
tp=2,
pp_engine="1f1b",
tp_mode="REDUCE_SCATTER",
tp_linear_async_communication=True,
)

# a global batch-size of 1M tokens. micro_batch_size * dp * sequence_length * batch_accumulation_per_replica
tokens = TokensArgs(sequence_length=512, train_steps=200, micro_batch_size=128, batch_accumulation_per_replica=4)
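# Check: 128 (micro_batch_size) * 4 (dp) * 512 (sequence_length) * 4 (batch_accumulation_per_replica)
# = 1,048,576 tokens per optimizer step, i.e. the ~1M-token global batch mentioned above.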

checkpoints_path = os.path.dirname(os.path.dirname(__file__)) + "/checkpoints"
os.makedirs(checkpoints_path, exist_ok=True)

config = Config(
general=GeneralArgs(project="debug", run="tiny_llama_%date_%jobid", seed=seed),
checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10),
parallelism=parallelism,
model=ModelArgs(init_method=RandomInit(std=0.025), model_config=model_config),
tokenizer=TokenizerArgs("gpt2"),
optimizer=optimizer,
logging=LoggingArgs(),
tokens=tokens,
data_stages=[
DatasetStageArgs(
name="Stable Training Stage",
start_training_step=1,
data=DataArgs(
dataset=PretrainDatasetsArgs(hf_dataset_or_datasets="roneneldan/TinyStories", text_column_name="text"),
seed=seed,
),
)
],
profiler=None,
)


if __name__ == "__main__":
dir = os.path.dirname(__file__)

# Save config as YAML file
config.save_as_yaml(f"{dir}/config_train_llama.yaml")

# You can now train a model with this config using `/run_train.py`
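# With dp=4, tp=2, pp=1 this config expects 4 * 2 * 1 = 8 processes. Assuming the usual
# nanotron launcher, the run would look something like:
#   torchrun --nproc_per_node=8 run_train.py --config-file examples/config_train_llama.yaml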
97 changes: 97 additions & 0 deletions examples/config_train_llama.yaml
@@ -0,0 +1,97 @@
checkpoints:
checkpoint_interval: 10
checkpoints_path: /fsx/haojun/nanotron/checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
save_initial_state: false
data_stages:
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: roneneldan/TinyStories
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
name: Stable Training Stage
start_training_step: 1
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: true
project: debug
run: tiny_llama_%date_%jobid
seed: 42
step: null
lighteval: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
model:
ddp_bucket_cap_mb: 25
dtype: bfloat16
init_method:
std: 0.025
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 1
eos_token_id: 2
hidden_act: silu
hidden_size: 768
initializer_range: 0.02
intermediate_size: 3072
is_llama_config: true
max_position_embeddings: 512
num_attention_heads: 16
num_hidden_layers: 12
num_key_value_heads: 16
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-05
rope_scaling: null
tie_word_embeddings: true
use_cache: true
vocab_size: 50272
optimizer:
accumulate_grad_in_fp32: true
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.0003
lr_decay_starting_step: null
lr_decay_steps: 198
lr_decay_style: cosine
lr_warmup_steps: 2
lr_warmup_style: linear
min_decay_lr: 1.0e-05
optimizer_factory:
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
name: adamW
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
dp: 4
expert_parallel_size: 1
pp: 1
pp_engine: 1f1b
tp: 2
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: gpt2
tokenizer_revision: null
tokens:
batch_accumulation_per_replica: 4
limit_test_batches: 0
limit_val_batches: 0
micro_batch_size: 128
sequence_length: 512
train_steps: 200
val_check_interval: -1
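
One quick way to confirm the committed YAML stays in sync with examples/config_train_llama.py is to load it and spot-check a few of the fields shown above; a minimal sketch, assuming PyYAML is available:

import yaml

with open("examples/config_train_llama.yaml") as f:
    cfg = yaml.safe_load(f)

par = cfg["parallelism"]
assert par["dp"] * par["pp"] * par["tp"] == 8                 # 4 * 1 * 2 ranks in total
assert cfg["model"]["model_config"]["hidden_size"] == 768
sched = cfg["optimizer"]["learning_rate_scheduler"]
# lr_decay_steps appears to be filled in as train_steps - lr_warmup_steps at generation time
assert sched["lr_decay_steps"] == cfg["tokens"]["train_steps"] - sched["lr_warmup_steps"]   # 198 = 200 - 2
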
2 changes: 1 addition & 1 deletion run_train.py
@@ -133,7 +133,7 @@ def get_dataloader_from_data_stage(
)
assert num_tokens_needed_for_training <= total_tokens_dataset, (
f"Dataset is too small for steps ({total_tokens_dataset} < {num_tokens_needed_for_training}), "
f"Try train_steps<={len(dataloader.dataset) // trainer.global_batch_size + trainer.start_iteration_step}"
f"Try train_steps<={len(dataloader.dataset) // trainer.global_batch_size + trainer.iteration_step}"
)
else:
raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}")
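
The reworded hint bounds train_steps by how many steps the dataset can still cover from the trainer's current iteration rather than from the stage's configured start step. A toy illustration of the suggested bound, with hypothetical numbers:

dataset_len = 1_000_000      # hypothetical number of samples in the stage's dataset
global_batch_size = 2_048    # hypothetical samples consumed per training step
iteration_step = 100         # trainer's current step when this dataloader is built
suggested_max = dataset_len // global_batch_size + iteration_step
print(suggested_max)         # 588 -> the message suggests train_steps <= 588
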