CPT: update eval support
- tb: add TGS and MFU metrics; update the gate load & importance logging step
- add eval support; fix the balance_loss = None bug during eval with gradient checkpointing
- update the logging strategy
Spico197 committed Oct 13, 2023
1 parent f0e5ae3 commit 678f83b
Showing 18 changed files with 290 additions and 75 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -174,3 +174,4 @@ results/gate_loss_original_clustering_model
results/llama_7B_MoE_16Select4-l2_norm
results/random_16select4_moe
results/gate_loss.png
smoe/utils/gpu_diag.py
2 changes: 1 addition & 1 deletion .vscode/launch.json
@@ -9,7 +9,7 @@
"type": "python",
"request": "attach",
"connect": {
"host": "SH-IDCA1404-10-140-54-122",
"host": "SH-IDCA1404-10-140-54-56",
"port": 5678
},
"pathMappings": [
7 changes: 7 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,7 @@
{
"python.testing.pytestArgs": [
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
3 changes: 2 additions & 1 deletion conf/deepspeed/bf16_zero1_default.json
@@ -10,5 +10,6 @@
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
"wall_clock_breakdown": false,
"reduce_bucket_size": 536870912
}
15 changes: 15 additions & 0 deletions conf/deepspeed/bf16_zero3.json
@@ -0,0 +1,15 @@
{
"bf16": {
"enabled": true
},
"zero_optimization": {
"stage": 3
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false,
"reduce_bucket_size": 536870912
}
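Note: both DeepSpeed configs above pin reduce_bucket_size to 536870912 instead of leaving it on "auto". If I read the DeepSpeed docs correctly this value is a number of elements, so the constant works out to 512M (2**29) elements per all-reduce bucket; a quick sanity check of the arithmetic (illustrative only, not part of the commit):

# 536870912 elements = 2**29 = 512 * 2**20; in bf16 that is roughly 1 GiB per bucket.
reduce_bucket_size = 536_870_912
assert reduce_bucket_size == 512 * 2**20 == 2**29
bytes_per_bf16 = 2
print(f"{reduce_bucket_size * bytes_per_bf16 / 2**30:.1f} GiB per bucket in bf16")  # -> 1.0 GiB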
38 changes: 25 additions & 13 deletions scripts/cpt/fpt_13b.sh
@@ -1,19 +1,19 @@
#!/usr/bin/bash

#SBATCH --job-name=cpt-13b-test
#SBATCH --job-name=cpt-7b-4_16_noisygate
#SBATCH --output=logs/%x-%j.log
#SBATCH --error=logs/%x-%j.log
##SBATCH --output=logs/%x.log
##SBATCH --error=logs/%x.log

#SBATCH --partition=MoE
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=32
#SBATCH --cpus-per-task=16
#SBATCH --mem=0

#SBATCH --nodes=2
#SBATCH --gres=gpu:8
#SBATCH --quotatype=auto
#SBATCH --quotatype=reserved
##SBATCH --time=5:00:00

source ~/anaconda3/bin/activate smoe
@@ -29,9 +29,11 @@ source ~/anaconda3/bin/activate smoe
# export TORCH_DISTRIBUTED_DEBUG=DETAIL
# export TORCH_SHOW_CPP_STACKTRACES=1
# export CUDA_LAUNCH_BLOCKING=1
# export ACCELERATE_DEBUG_MODE=1

# comment="13B, expert 4/16, noisy gate, seq len 2048, lr=4e-4, expert weight re-scale"
comment="13B, expert 4/16, noisy gate, seq len 2048, lr=4e-4"
# comment="13B, expert 4/16, noisy gate, seq len 2048, lr=4e-4"
comment="llama2 7B, expert 4/16, noisy gate, seq len 4096, lr=2e-4"
# comment="random initialized llama1-7B"
# comment="random initialized llama1-13B"
# comment="7B, expert 4/16, noisy gate, gradient shared neurons, w/o residual, w/o weight re-scale, lr2e-4"
@@ -42,9 +44,11 @@ source ~/anaconda3/bin/activate smoe
# pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/llama1_7B_random
# pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/llama1_7B_random
model_type="llama_moe"
# pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Random/llama_7B-16Select16-up_proj
pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama2_7B-16Select4-688Neurons-Share
# pretrained_model="/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama_3B-8Select2-4320Neurons-Share"
# pretrained_model="/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama_7B-16Select4-688Neurons-Share"
pretrained_model="/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM-copy/Gradient-max-l1_norm-sample-feature_change/llama_13B-16Select4-864Neurons-Share"
# pretrained_model="/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM-copy/Gradient-max-l1_norm-sample-feature_change/llama_13B-16Select4-864Neurons-Share"
# pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/llama_7B_MoE_16Select4-l2_norm
# pretrained_model="/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM-copy/Clustering-l2/llama_13B-16Select4-up_proj"
# pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM-no-softmax/Clustering-l2-l2_norm/llama_13B-16Select4-gate_proj
@@ -54,24 +58,28 @@ source ~/anaconda3/bin/activate smoe
# pretrained_model=$1
echo "==================> $pretrained_model <=================="

tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama2_7B/
# tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama_7B
# tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM-no-softmax-copy/Clustering-l2-l2_norm/llama_13B-16Select4-gate_proj
# tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama1_7B_random
tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama_13B
# tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama_13B
# tokenizer_path="/mnt/petrelfs/share_data/quxiaoye/models/llama_3B"

dataset_dir=/mnt/petrelfs/share_data/quxiaoye/pretrain_LLAMA_all_data_processed
# dataset_dir=/mnt/petrelfs/share_data/quxiaoye/pretrain_LLAMA_all_data_processed
dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed/
# dataset_dir=/mnt/petrelfs/zhutong/smoe/resources/slimpajama_samples_openllama3B_tokenized
validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/

lr=2e-4
final_lr_portion=0.1
per_device_train_batch_size=8
per_device_eval_batch_size=1
per_device_eval_batch_size=8
gradient_accumulation_steps=4
num_tokens="3*10^11"
num_tokens="1*10^11"
seed=1227
block_size=2048
block_size=4096
deepspeed_config_file=conf/deepspeed/bf16_zero1_default.json
# deepspeed_config_file=conf/deepspeed/bf16_zero3.json

max_steps=$(echo "${num_tokens} / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc)
max_train_samples=$(echo "${num_tokens} / $block_size" | bc)
@@ -114,10 +122,13 @@ source ~/anaconda3/bin/activate smoe
--tokenizer_name_or_path ${tokenizer_path} \
--dataset_dir ${dataset_dir} \
--data_cache_dir ${data_cache} \
--validation_split_percentage 0.001 \
--validation_dir ${validation_dir} \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--do_train \
--do_eval \
--evaluation_strategy steps \
--eval_steps 1000 \
--seed ${seed} \
--bf16 \
--num_train_epochs 1 \
Expand All @@ -135,6 +146,7 @@ source ~/anaconda3/bin/activate smoe
--save_total_limit 1 \
--save_steps 1000 \
--dataloader_num_workers 0 \
--dataloader_pin_memory True \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--block_size ${block_size} \
--output_dir ${output_dir} \
@@ -149,9 +161,9 @@ source ~/anaconda3/bin/activate smoe
--log_level info \
--log_level_replica warning \
--log_on_each_node False \
--report_to none \
--gate_type "TopKBalancedNoisyGate" \
--calculator_type "UniversalCalculator" \
--num_selects 4 \
--report_to none
--num_selects 4

}
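For reference, max_steps in this script is the token budget divided by the tokens consumed per optimizer step across all GPUs. A minimal recomputation with the values set above (num_nodes=2 and num_gpu_per_node=8 are assumed from the #SBATCH --nodes/--gres directives; // mirrors bc's integer truncation):

num_tokens = 1 * 10**11                      # num_tokens="1*10^11"
block_size = 4096
per_device_train_batch_size = 8
gradient_accumulation_steps = 4
num_nodes, num_gpu_per_node = 2, 8           # assumed from the SBATCH directives

tokens_per_step = (block_size * per_device_train_batch_size
                   * gradient_accumulation_steps * num_nodes * num_gpu_per_node)
max_steps = num_tokens // tokens_per_step          # tokens consumed per optimizer step
max_train_samples = num_tokens // block_size
print(tokens_per_step, max_steps, max_train_samples)  # 2097152 47683 24414062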
16 changes: 10 additions & 6 deletions scripts/tokenize/redpajama.sh
@@ -2,13 +2,17 @@

set -vx

# tokenizer_dir=/mnt/petrelfs/share_data/quxiaoye/models/llama_7B
# data_dir=/mnt/petrelfs/share_data/quxiaoye/pretrain_LLAMA_all_data
# out_dir=/mnt/petrelfs/share_data/quxiaoye/pretrain_LLAMA_all_data_processed
tokenizer_dir=/mnt/petrelfs/share_data/quxiaoye/models/llama_7B
data_dir=/mnt/petrelfs/share_data/quxiaoye/pretrain_LLAMA_all_data
out_dir=/mnt/petrelfs/share_data/quxiaoye/pretrain_LLAMA_all_data_processed

# tokenizer_dir=/mnt/petrelfs/share_data/quxiaoye/models/llama_3B
# data_dir=/mnt/petrelfs/zhutong/smoe/resources/slimpajama_samples
# out_dir=/mnt/petrelfs/zhutong/smoe/resources/slimpajama_samples_openllama3B_tokenized

tokenizer_dir=/mnt/petrelfs/share_data/quxiaoye/models/llama_3B
data_dir=/mnt/petrelfs/zhutong/smoe/resources/slimpajama_samples
out_dir=/mnt/petrelfs/zhutong/smoe/resources/slimpajama_samples_openllama3B_tokenized
# tokenizer_dir=/mnt/petrelfs/share_data/quxiaoye/models/llama_7B
# data_dir=/mnt/petrelfs/zhutong/lm-evaluation-harness-b281b0921b636bc36ad05c0b0b0763bd6dd43463/val_set/final
# out_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized

logs_dir=logs

77 changes: 59 additions & 18 deletions smoe/callbacks/tensorboard.py
@@ -1,11 +1,19 @@
import time
from typing import Iterable

import torch
from transformers import TrainerControl, TrainerState, TrainingArguments
from transformers.integrations import TensorBoardCallback, logger, rewrite_logs
from transformers.integrations import TensorBoardCallback, rewrite_logs

from smoe.utils.visualization.visualize import get_heatmap_img_grid_for_tb


class EnhancedTensorboardCallback(TensorBoardCallback):
def __init__(self, tb_writer=None):
super().__init__(tb_writer)

self._heatmap_img_dump_step = 200

def on_log(
self,
args: TrainingArguments,
@@ -21,15 +29,35 @@ def on_log(
self._init_summary_writer(args)

if self.tb_writer is not None:
logs.update({"Total_FLOPs": state.total_flos})
logs.update({"Buggy_Estimated_Total_FLOPs": state.total_flos})
free_mem, tot_mem = torch.cuda.mem_get_info()
used_mem = (tot_mem - free_mem) / 1024**3
gpu_util = torch.cuda.utilization()
logs.update({"GPU_Mem_GB": used_mem, "GPU_Util": gpu_util})
logs = rewrite_logs(logs)
for k, v in logs.items():
if isinstance(v, (int, float)):
self.tb_writer.add_scalar(k, v, state.global_step)
if "train/loss" == k:
if k == "train/loss":
tokens = state.global_step * args.num_tokens_per_batch
token_loss_key = "train/loss_on_tokens"
self.tb_writer.add_scalar(token_loss_key, v, tokens)
elif k == "train/Buggy_Estimated_Total_FLOPs":
# write tokens per GPU per second (TGS) and Model FLOPs Utilization (MFU)
seconds = time.time() - state.start_timestamp
tgs = (
state.global_step
* args.num_tokens_per_batch
/ args.world_size
/ seconds
)
self.tb_writer.add_scalar(
"train/Avg_TGS", tgs, state.global_step
)
mfu = 6 * args.num_training_params * tgs / args.flops_per_device
self.tb_writer.add_scalar(
"train/Avg_MFU_per_second", mfu, state.global_step
)
elif k == "train/balance_loss":
if isinstance(v, torch.Tensor) and hasattr(v, "item"):
_v = v.item()
@@ -38,28 +66,41 @@ def on_log(
else:
continue
self.tb_writer.add_scalar(k, _v, state.global_step)
elif k == "train/num_dropped_tokens" and isinstance(v, tuple):
elif k == "train/num_dropped_tokens" and isinstance(v, Iterable):
# (tensor(1.0), tensor(2.3)) -> [1.0, 2.3]
if all(isinstance(n, torch.Tensor) for n in v):
if (
state.global_step
% (self._heatmap_img_dump_step * args.logging_steps)
== 0
):
self.tb_writer.add_image(
k, get_heatmap_img_grid_for_tb(v), state.global_step
)
v = [n.item() for n in v]
self.tb_writer.add_scalars(
f"{k}/layer",
{str(i): n for i, n in enumerate(v)},
state.global_step,
)
# self.tb_writer.add_scalars(
# f"{k}/layer",
# {str(i): n for i, n in enumerate(v)},
# state.global_step,
# )
self.tb_writer.add_scalar(f"{k}/total", sum(v), state.global_step)
elif (
k == "train/gate_load" or k == "train/gate_importance"
) and isinstance(v, tuple):
) and isinstance(v, Iterable):
if not all(isinstance(n, torch.Tensor) for n in v):
v = [torch.tensor(n) for n in v]
# v: (tensor([1.0, 2.3, ... num_experts]), tensor([3.0, 4.5, ... num_experts]), ... num_layers)
self.tb_writer.add_scalars(
f"{k}/std/layer",
{str(i): n.std().item() for i, n in enumerate(v)},
state.global_step,
)
self.tb_writer.add_image(
k, get_heatmap_img_grid_for_tb(v), state.global_step
)
# self.tb_writer.add_scalars(
# f"{k}/std/layer",
# {str(i): n.std().item() for i, n in enumerate(v)},
# state.global_step,
# )
if (
state.global_step
% (self._heatmap_img_dump_step * args.logging_steps)
== 0
):
self.tb_writer.add_image(
k, get_heatmap_img_grid_for_tb(v), state.global_step
)
self.tb_writer.flush()
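The TGS/MFU logging added above is the standard throughput estimate: average tokens per GPU per second since training started, with MFU derived from the usual ~6 FLOPs per parameter per trained token. A standalone sketch of the same arithmetic, with purely illustrative numbers (the parameter count, world size, and 312 TFLOP/s A100 bf16 peak are assumptions, not values from this commit):

def avg_tgs_and_mfu(global_step, num_tokens_per_batch, world_size,
                    elapsed_seconds, num_training_params, flops_per_device):
    """Mirror of the callback's TGS/MFU formulas."""
    tgs = global_step * num_tokens_per_batch / world_size / elapsed_seconds
    mfu = 6 * num_training_params * tgs / flops_per_device  # ~6*N FLOPs per token
    return tgs, mfu

# Illustrative only: 7B dense-equivalent params, 16 GPUs, ~11 hours of training.
tgs, mfu = avg_tgs_and_mfu(global_step=1000, num_tokens_per_batch=2**21,
                           world_size=16, elapsed_seconds=40_000,
                           num_training_params=7e9, flops_per_device=312e12)
print(f"TGS ~ {tgs:.0f} tokens/gpu/s, MFU ~ {mfu:.1%}")  # TGS ~ 3277, MFU ~ 44.1%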
28 changes: 26 additions & 2 deletions smoe/data/streaming.py
@@ -10,10 +10,10 @@
from typing import Iterator

import torch
from torch.utils.data import IterableDataset
from torch.utils.data import Dataset, IterableDataset

from smoe.data.aggregation import group_instances
from smoe.utils.io import load_jsonlines_iter
from smoe.utils.io import load_jsonlines, load_jsonlines_iter
from smoe.utils.logging import get_logger
from smoe.utils.random import get_random_string
from smoe.utils.vars import JSONL_DATASET_CACHE_NAME
@@ -193,6 +193,30 @@ def __iter__(self) -> Iterator:
yield from ds


class CachedJsonlDataset(Dataset):
def __init__(
self,
filepath: str,
seed: int = 1227,
buffer_size: int = 700,
block_size: int = 2048,
):
super().__init__()
self.filepath = filepath
self.rng = random.Random(seed)
self.buffer_size = buffer_size
self.block_size = block_size

dataset = load_jsonlines(self.filepath)
self.cached = group_instances(dataset, self.block_size)

def __getitem__(self, index: int):
return self.cached[index]

def __len__(self):
return len(self.cached)


class PackedJsonlDataset(IterableDataset):
def __init__(
self,
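The new CachedJsonlDataset complements the streaming PackedJsonlDataset for evaluation: it loads an entire tokenized jsonl file into memory and packs it into fixed block_size chunks, presumably what --validation_dir points the new eval loop at. A hypothetical usage sketch (the file path is an assumption, not from this commit):

from smoe.data.streaming import CachedJsonlDataset

# Assumed path to one tokenized validation shard; group_instances() repacks the
# loaded lines into block_size-token examples at construction time.
eval_set = CachedJsonlDataset("val_set/shard_00.jsonl", block_size=4096)

print(f"{len(eval_set)} packed blocks of {eval_set.block_size} tokens each")
first_example = eval_set[0]   # a single packed block, ready for the eval dataloader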