diff --git a/.gitignore b/.gitignore
index 5c2bc94..c5387a4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -174,3 +174,4 @@ results/gate_loss_original_clustering_model
 results/llama_7B_MoE_16Select4-l2_norm
 results/random_16select4_moe
 results/gate_loss.png
+smoe/utils/gpu_diag.py
diff --git a/.vscode/launch.json b/.vscode/launch.json
index 6010699..b3f5f9c 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -9,7 +9,7 @@
             "type": "python",
             "request": "attach",
             "connect": {
-                "host": "SH-IDCA1404-10-140-54-122",
+                "host": "SH-IDCA1404-10-140-54-56",
                 "port": 5678
             },
             "pathMappings": [
diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..a3a1838
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,7 @@
+{
+    "python.testing.pytestArgs": [
+        "tests"
+    ],
+    "python.testing.unittestEnabled": false,
+    "python.testing.pytestEnabled": true
+}
diff --git a/conf/deepspeed/bf16_zero1_default.json b/conf/deepspeed/bf16_zero1_default.json
index 10f23ac..199498a 100644
--- a/conf/deepspeed/bf16_zero1_default.json
+++ b/conf/deepspeed/bf16_zero1_default.json
@@ -10,5 +10,6 @@
     "steps_per_print": 2000,
     "train_batch_size": "auto",
     "train_micro_batch_size_per_gpu": "auto",
-    "wall_clock_breakdown": false
+    "wall_clock_breakdown": false,
+    "reduce_bucket_size": 536870912
 }
diff --git a/conf/deepspeed/bf16_zero3.json b/conf/deepspeed/bf16_zero3.json
new file mode 100644
index 0000000..de7b831
--- /dev/null
+++ b/conf/deepspeed/bf16_zero3.json
@@ -0,0 +1,15 @@
+{
+    "bf16": {
+        "enabled": true
+    },
+    "zero_optimization": {
+        "stage": 3
+    },
+    "gradient_accumulation_steps": "auto",
+    "gradient_clipping": "auto",
+    "steps_per_print": 2000,
+    "train_batch_size": "auto",
+    "train_micro_batch_size_per_gpu": "auto",
+    "wall_clock_breakdown": false,
+    "reduce_bucket_size": 536870912
+}
diff --git a/scripts/cpt/fpt_13b.sh b/scripts/cpt/fpt_13b.sh
index cab4512..0ade3c9 100644
--- a/scripts/cpt/fpt_13b.sh
+++ b/scripts/cpt/fpt_13b.sh
@@ -1,6 +1,6 @@
 #!/usr/bin/bash
 
-#SBATCH --job-name=cpt-13b-test
+#SBATCH --job-name=cpt-7b-4_16_noisygate
 #SBATCH --output=logs/%x-%j.log
 #SBATCH --error=logs/%x-%j.log
 ##SBATCH --output=logs/%x.log
@@ -8,12 +8,12 @@
 #SBATCH --partition=MoE
 #SBATCH --ntasks-per-node=1
-#SBATCH --cpus-per-task=32
+#SBATCH --cpus-per-task=16
 #SBATCH --mem=0
 #SBATCH --nodes=2
 #SBATCH --gres=gpu:8
-#SBATCH --quotatype=auto
+#SBATCH --quotatype=reserved
 ##SBATCH --time=5:00:00
 
 source ~/anaconda3/bin/activate smoe
@@ -29,9 +29,11 @@ source ~/anaconda3/bin/activate smoe
     # export TORCH_DISTRIBUTED_DEBUG=DETAIL
    # export TORCH_SHOW_CPP_STACKTRACES=1
     # export CUDA_LAUNCH_BLOCKING=1
+    # export ACCELERATE_DEBUG_MODE=1
 
     # comment="13B, expert 4/16, noisy gate, seq len 2048, lr=4e-4, expert weight re-scale"
-    comment="13B, expert 4/16, noisy gate, seq len 2048, lr=4e-4"
+    # comment="13B, expert 4/16, noisy gate, seq len 2048, lr=4e-4"
+    comment="llama2 7B, expert 4/16, noisy gate, seq len 4096, lr=2e-4"
     # comment="random initialized llama1-7B"
     # comment="random initialized llama1-13B"
     # comment="7B, expert 4/16, noisy gate, gradient shared neurons, w/o residual, w/o weight re-scale, lr2e-4"
@@ -42,9 +44,11 @@ source ~/anaconda3/bin/activate smoe
     # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/llama1_7B_random
     # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/llama1_7B_random
     model_type="llama_moe"
+    # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Random/llama_7B-16Select16-up_proj
+    pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama2_7B-16Select4-688Neurons-Share
     # pretrained_model="/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama_3B-8Select2-4320Neurons-Share"
     # pretrained_model="/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM/Gradient-max-l1_norm-sample-feature_change/llama_7B-16Select4-688Neurons-Share"
-    pretrained_model="/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM-copy/Gradient-max-l1_norm-sample-feature_change/llama_13B-16Select4-864Neurons-Share"
+    # pretrained_model="/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM-copy/Gradient-max-l1_norm-sample-feature_change/llama_13B-16Select4-864Neurons-Share"
     # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/llama_7B_MoE_16Select4-l2_norm
     # pretrained_model="/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM-copy/Clustering-l2/llama_13B-16Select4-up_proj"
     # pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM-no-softmax/Clustering-l2-l2_norm/llama_13B-16Select4-gate_proj
@@ -54,24 +58,28 @@ source ~/anaconda3/bin/activate smoe
     # pretrained_model=$1
     echo "==================> $pretrained_model <=================="
 
+    tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama2_7B/
     # tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama_7B
     # tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/LlamaMoEForCausalLM-no-softmax-copy/Clustering-l2-l2_norm/llama_13B-16Select4-gate_proj
     # tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama1_7B_random
-    tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama_13B
+    # tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama_13B
     # tokenizer_path="/mnt/petrelfs/share_data/quxiaoye/models/llama_3B"
 
-    dataset_dir=/mnt/petrelfs/share_data/quxiaoye/pretrain_LLAMA_all_data_processed
+    # dataset_dir=/mnt/petrelfs/share_data/quxiaoye/pretrain_LLAMA_all_data_processed
+    dataset_dir=/mnt/petrelfs/share_data/quxiaoye/SlimPajama_processed/
     # dataset_dir=/mnt/petrelfs/zhutong/smoe/resources/slimpajama_samples_openllama3B_tokenized
+    validation_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized/
 
     lr=2e-4
     final_lr_portion=0.1
     per_device_train_batch_size=8
-    per_device_eval_batch_size=1
+    per_device_eval_batch_size=8
     gradient_accumulation_steps=4
-    num_tokens="3*10^11"
+    num_tokens="1*10^11"
     seed=1227
-    block_size=2048
+    block_size=4096
     deepspeed_config_file=conf/deepspeed/bf16_zero1_default.json
+    # deepspeed_config_file=conf/deepspeed/bf16_zero3.json
 
     max_steps=$(echo "${num_tokens} / ($block_size * $per_device_train_batch_size * $gradient_accumulation_steps * $num_nodes * $num_gpu_per_node)" | bc)
     max_train_samples=$(echo "${num_tokens} / $block_size" | bc)
@@ -114,10 +122,13 @@ source ~/anaconda3/bin/activate smoe
         --tokenizer_name_or_path ${tokenizer_path} \
         --dataset_dir ${dataset_dir} \
         --data_cache_dir ${data_cache} \
-        --validation_split_percentage 0.001 \
+        --validation_dir ${validation_dir} \
         --per_device_train_batch_size ${per_device_train_batch_size} \
         --per_device_eval_batch_size ${per_device_eval_batch_size} \
         --do_train \
+        --do_eval \
+        --evaluation_strategy steps \
+        --eval_steps 1000 \
         --seed ${seed} \
         --bf16 \
         --num_train_epochs 1 \
@@ -135,6 +146,7 @@ source ~/anaconda3/bin/activate smoe
         --save_total_limit 1 \
         --save_steps 1000 \
         --dataloader_num_workers 0 \
+        --dataloader_pin_memory True \
         --gradient_accumulation_steps ${gradient_accumulation_steps} \
         --block_size ${block_size} \
         --output_dir ${output_dir} \
@@ -149,9 +161,9 @@ source ~/anaconda3/bin/activate smoe
         --log_level info \
         --log_level_replica warning \
         --log_on_each_node False \
+        --report_to none \
         --gate_type "TopKBalancedNoisyGate" \
         --calculator_type "UniversalCalculator" \
-        --num_selects 4 \
-        --report_to none
+        --num_selects 4
 }
diff --git a/scripts/tokenize/redpajama.sh b/scripts/tokenize/redpajama.sh
index 1c2400b..f4f7bec 100644
--- a/scripts/tokenize/redpajama.sh
+++ b/scripts/tokenize/redpajama.sh
@@ -2,13 +2,17 @@
 
 set -vx
 
-# tokenizer_dir=/mnt/petrelfs/share_data/quxiaoye/models/llama_7B
-# data_dir=/mnt/petrelfs/share_data/quxiaoye/pretrain_LLAMA_all_data
-# out_dir=/mnt/petrelfs/share_data/quxiaoye/pretrain_LLAMA_all_data_processed
+tokenizer_dir=/mnt/petrelfs/share_data/quxiaoye/models/llama_7B
+data_dir=/mnt/petrelfs/share_data/quxiaoye/pretrain_LLAMA_all_data
+out_dir=/mnt/petrelfs/share_data/quxiaoye/pretrain_LLAMA_all_data_processed
+
+# tokenizer_dir=/mnt/petrelfs/share_data/quxiaoye/models/llama_3B
+# data_dir=/mnt/petrelfs/zhutong/smoe/resources/slimpajama_samples
+# out_dir=/mnt/petrelfs/zhutong/smoe/resources/slimpajama_samples_openllama3B_tokenized
 
-tokenizer_dir=/mnt/petrelfs/share_data/quxiaoye/models/llama_3B
-data_dir=/mnt/petrelfs/zhutong/smoe/resources/slimpajama_samples
-out_dir=/mnt/petrelfs/zhutong/smoe/resources/slimpajama_samples_openllama3B_tokenized
+# tokenizer_dir=/mnt/petrelfs/share_data/quxiaoye/models/llama_7B
+# data_dir=/mnt/petrelfs/zhutong/lm-evaluation-harness-b281b0921b636bc36ad05c0b0b0763bd6dd43463/val_set/final
+# out_dir=/mnt/petrelfs/share_data/quxiaoye/data/llama1_7B_val_set_tokenized
 
 logs_dir=logs
diff --git a/smoe/callbacks/tensorboard.py b/smoe/callbacks/tensorboard.py
index 6696024..a3e0469 100644
--- a/smoe/callbacks/tensorboard.py
+++ b/smoe/callbacks/tensorboard.py
@@ -1,11 +1,19 @@
+import time
+from typing import Iterable
+
 import torch
 from transformers import TrainerControl, TrainerState, TrainingArguments
-from transformers.integrations import TensorBoardCallback, logger, rewrite_logs
+from transformers.integrations import TensorBoardCallback, rewrite_logs
 
 from smoe.utils.visualization.visualize import get_heatmap_img_grid_for_tb
 
 
 class EnhancedTensorboardCallback(TensorBoardCallback):
+    def __init__(self, tb_writer=None):
+        super().__init__(tb_writer)
+
+        self._heatmap_img_dump_step = 200
+
     def on_log(
         self,
         args: TrainingArguments,
@@ -21,15 +29,35 @@ def on_log(
             self._init_summary_writer(args)
 
         if self.tb_writer is not None:
-            logs.update({"Total_FLOPs": state.total_flos})
+            logs.update({"Buggy_Estimated_Total_FLOPs": state.total_flos})
+            free_mem, tot_mem = torch.cuda.mem_get_info()
+            used_mem = (tot_mem - free_mem) / 1024**3
+            gpu_util = torch.cuda.utilization()
+            logs.update({"GPU_Mem_GB": used_mem, "GPU_Util": gpu_util})
             logs = rewrite_logs(logs)
             for k, v in logs.items():
                 if isinstance(v, (int, float)):
                     self.tb_writer.add_scalar(k, v, state.global_step)
-                    if "train/loss" == k:
+                    if k == "train/loss":
                         tokens = state.global_step * args.num_tokens_per_batch
                         token_loss_key = "train/loss_on_tokens"
                         self.tb_writer.add_scalar(token_loss_key, v, tokens)
+                    elif k == "train/Buggy_Estimated_Total_FLOPs":
+                        # write tokens per GPU per second (TGS) and Model FLOPs Utilization (MFU)
+                        seconds = time.time() - state.start_timestamp
+                        tgs = (
+                            state.global_step
+                            * args.num_tokens_per_batch
+                            / args.world_size
+                            / seconds
+                        )
+                        self.tb_writer.add_scalar(
+                            "train/Avg_TGS", tgs, state.global_step
+                        )
+                        mfu = 6 * args.num_training_params * tgs / args.flops_per_device
+                        self.tb_writer.add_scalar(
+                            "train/Avg_MFU_per_second", mfu, state.global_step
+                        )
                 elif k == "train/balance_loss":
                     if isinstance(v, torch.Tensor) and hasattr(v, "item"):
                         _v = v.item()
@@ -38,28 +66,41 @@ def on_log(
                     else:
                         continue
                     self.tb_writer.add_scalar(k, _v, state.global_step)
-                elif k == "train/num_dropped_tokens" and isinstance(v, tuple):
+                elif k == "train/num_dropped_tokens" and isinstance(v, Iterable):
                     # (tensor(1.0), tensor(2.3)) -> [1.0, 2.3]
                     if all(isinstance(n, torch.Tensor) for n in v):
+                        if (
+                            state.global_step
+                            % (self._heatmap_img_dump_step * args.logging_steps)
+                            == 0
+                        ):
+                            self.tb_writer.add_image(
+                                k, get_heatmap_img_grid_for_tb(v), state.global_step
+                            )
                         v = [n.item() for n in v]
-                    self.tb_writer.add_scalars(
-                        f"{k}/layer",
-                        {str(i): n for i, n in enumerate(v)},
-                        state.global_step,
-                    )
+                    # self.tb_writer.add_scalars(
+                    #     f"{k}/layer",
+                    #     {str(i): n for i, n in enumerate(v)},
+                    #     state.global_step,
+                    # )
                     self.tb_writer.add_scalar(f"{k}/total", sum(v), state.global_step)
                 elif (
                     k == "train/gate_load" or k == "train/gate_importance"
-                ) and isinstance(v, tuple):
+                ) and isinstance(v, Iterable):
                     if not all(isinstance(n, torch.Tensor) for n in v):
                         v = [torch.tensor(n) for n in v]
                     # v: (tensor([1.0, 2.3, ... num_experts]), tensor([3.0, 4.5, ... num_experts]), ... num_layers)
-                    self.tb_writer.add_scalars(
-                        f"{k}/std/layer",
-                        {str(i): n.std().item() for i, n in enumerate(v)},
-                        state.global_step,
-                    )
-                    self.tb_writer.add_image(
-                        k, get_heatmap_img_grid_for_tb(v), state.global_step
-                    )
+                    # self.tb_writer.add_scalars(
+                    #     f"{k}/std/layer",
+                    #     {str(i): n.std().item() for i, n in enumerate(v)},
+                    #     state.global_step,
+                    # )
+                    if (
+                        state.global_step
+                        % (self._heatmap_img_dump_step * args.logging_steps)
+                        == 0
+                    ):
+                        self.tb_writer.add_image(
+                            k, get_heatmap_img_grid_for_tb(v), state.global_step
+                        )
             self.tb_writer.flush()
diff --git a/smoe/data/streaming.py b/smoe/data/streaming.py
index 4a018c6..576f05d 100644
--- a/smoe/data/streaming.py
+++ b/smoe/data/streaming.py
@@ -10,10 +10,10 @@ from typing import Iterator
 
 import torch
-from torch.utils.data import IterableDataset
+from torch.utils.data import Dataset, IterableDataset
 
 from smoe.data.aggregation import group_instances
-from smoe.utils.io import load_jsonlines_iter
+from smoe.utils.io import load_jsonlines, load_jsonlines_iter
 from smoe.utils.logging import get_logger
 from smoe.utils.random import get_random_string
 from smoe.utils.vars import JSONL_DATASET_CACHE_NAME
@@ -193,6 +193,30 @@ def __iter__(self) -> Iterator:
             yield from ds
 
 
+class CachedJsonlDataset(Dataset):
+    def __init__(
+        self,
+        filepath: str,
+        seed: int = 1227,
+        buffer_size: int = 700,
+        block_size: int = 2048,
+    ):
+        super().__init__()
+        self.filepath = filepath
+        self.rng = random.Random(seed)
+        self.buffer_size = buffer_size
+        self.block_size = block_size
+
+        dataset = load_jsonlines(self.filepath)
+        self.cached = group_instances(dataset, self.block_size)
+
+    def __getitem__(self, index: int):
+        return self.cached[index]
+
+    def __len__(self):
+        return len(self.cached)
+
+
 class PackedJsonlDataset(IterableDataset):
     def __init__(
         self,
diff --git a/smoe/entrypoint/cpt/cpt_fpt.py b/smoe/entrypoint/cpt/cpt_fpt.py
index abebc05..2c5d1d8 100644
--- a/smoe/entrypoint/cpt/cpt_fpt.py
+++ b/smoe/entrypoint/cpt/cpt_fpt.py
@@ -1,7 +1,11 @@
+import logging
 import os
+import sys
+from pathlib import Path
 
+import datasets
 import torch
-from torch.distributed.elastic.multiprocessing.errors import record
+import transformers
 from transformers import (
     CONFIG_MAPPING,
     AutoConfig,
@@ -19,8 +23,7 @@
 from smoe.callbacks.tensorboard import EnhancedTensorboardCallback
 from smoe.data.collate_fn import fault_tolerance_data_collator
 from smoe.data.redpajama import load_streaming_datasets
-from smoe.data.streaming import SubDirWeightedPackedJsonlDataset
-from smoe.metrics.accuracy import compute_metrics
+from smoe.data.streaming import CachedJsonlDataset, SubDirWeightedPackedJsonlDataset
 from smoe.metrics.preprocess import logits_argmax
 from smoe.models.llama_moefication.configuration_llama_moe import LlamaMoEConfig
 from smoe.models.llama_moefication.modeling_llama_moe import LlamaMoEForCausalLM
@@ -32,7 +35,6 @@
     ModelArguments,
     parse_args,
 )
-from smoe.utils.logging import get_logger_from_training_args
 from smoe.utils.notification import wechat_sender
 from smoe.utils.param import get_trainable_parameters
 
@@ -49,12 +51,32 @@
 )
 
 
-@wechat_sender()
+logger = logging.getLogger(__name__)
+
+
+# @wechat_sender()
 def main():
     model_args, data_args, training_args = parse_args(
         ModelArguments, DataArguments, EnhancedTrainingArguments
     )
-    logger = get_logger_from_training_args(__name__, training_args)
+
+    # Setup logging
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        handlers=[logging.StreamHandler(sys.stdout)],
+    )
+    if training_args.should_log:
+        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
+        transformers.utils.logging.set_verbosity_info()
+
+    log_level = training_args.get_process_log_level()
+    logger.setLevel(log_level)
+    datasets.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.enable_default_handler()
+    transformers.utils.logging.enable_explicit_format()
+
     logger.warning(
         f"Process local rank: {training_args.local_rank}, "
         f"device: {training_args.device}, "
@@ -68,12 +90,9 @@ def main():
     logger.info(f"Training args: {training_args.to_json_string()}")
 
     if training_args.debug_mode:
-        import torch.distributed as dist
-
         from smoe.utils.debugging import remote_breakpoint
 
-        if dist.get_rank() == 0:
-            remote_breakpoint()
+        remote_breakpoint()
 
     # Detecting last checkpoint.
     last_checkpoint = None
@@ -132,7 +151,7 @@ def main():
 
     # zhutong: this is for debug usage only
     if training_args.debug_mode:
-        config.num_hidden_layers = 1
+        config.num_hidden_layers = 2
 
     tokenizer_kwargs = {
         "cache_dir": model_args.cache_dir,
@@ -235,7 +254,14 @@ def main():
 
     eval_dataset = None
     if training_args.do_eval:
-        raise NotImplementedError
+        paths = Path(data_args.validation_dir).glob("*.jsonl")
+        eval_dataset = {
+            path.stem: CachedJsonlDataset(
+                str(path), training_args.seed, block_size=data_args.block_size
+            )
+            for path in paths
+        }
+        logger.info(f"eval types: {list(eval_dataset.keys())}")
 
     if model_args.model_name_or_path:
         torch_dtype = (
@@ -286,7 +312,8 @@ def main():
             f" tokenizer ({len(tokenizer)})"
         )
 
-    get_trainable_parameters(model, verbose=True)
+    trainable_params, _ = get_trainable_parameters(model, verbose=True)
+    training_args.num_training_params = trainable_params
 
     # Initialize our Trainer
     trainer = LlamaLrSchedulingTrainer(
@@ -296,11 +323,7 @@ def main():
         eval_dataset=eval_dataset if training_args.do_eval else None,
         tokenizer=tokenizer,
        data_collator=fault_tolerance_data_collator,
-        compute_metrics=(
-            compute_metrics
-            if training_args.do_eval and not is_torch_tpu_available()
-            else None
-        ),
+        compute_metrics=None,
         preprocess_logits_for_metrics=(
             logits_argmax
             if training_args.do_eval and not is_torch_tpu_available()
diff --git a/smoe/models/llama_moefication/modeling_llama_moe.py b/smoe/models/llama_moefication/modeling_llama_moe.py
index 2c0749e..e005df2 100644
--- a/smoe/models/llama_moefication/modeling_llama_moe.py
+++ b/smoe/models/llama_moefication/modeling_llama_moe.py
@@ -484,7 +484,7 @@ def forward(
             # Enable model parallelism
             shift_labels = shift_labels.to(shift_logits.device)
             loss = loss_fct(shift_logits, shift_labels)
-            if outputs.balance_loss is not None:
+            if outputs.balance_loss is not None and outputs.balance_loss > 0:
                 loss += outputs.balance_loss
 
         if not return_dict:
@@ -628,7 +628,7 @@ def forward(
             elif self.config.problem_type == "multi_label_classification":
                 loss_fct = BCEWithLogitsLoss()
                 loss = loss_fct(pooled_logits, labels)
-        if loss is not None and balance_loss is not None:
+        if loss is not None and balance_loss is not None and balance_loss > 0:
             loss += balance_loss
 
         if not return_dict:
             output = (pooled_logits,) + transformer_outputs[1:]
diff --git a/smoe/modules/flash_attn.py b/smoe/modules/flash_attn.py
index 873fb18..6805239 100644
--- a/smoe/modules/flash_attn.py
+++ b/smoe/modules/flash_attn.py
@@ -1,17 +1,17 @@
 from types import MethodType
-from typing import Optional, Tuple
+from typing import Callable, Optional, Tuple
 
 import torch
 import torch.nn as nn
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
+    LlamaRMSNorm,
     apply_rotary_pos_emb,
 )
 
 SUPPORT_XFORMERS = True
 SUPPORT_FLASH2 = False
-# from flash_attn import flash_attn_func, flash_attn_qkvpacked_func
 
 try:
     import xformers.ops as xops
@@ -95,7 +95,24 @@ def llama_flash_attention(
     return attn_output, attn_weights, past_key_value
 
 
+def llama_fast_rms_norm(self: LlamaRMSNorm, hidden_states: torch.FloatTensor):
+    # return rms_norm(hidden_states, self.weight, self.variance_epsilon)
+    from apex.normalization.fused_layer_norm import (
+        manual_rms_norm,
+        mixed_dtype_fused_rms_norm_affine,
+    )
+
+    hsz = self.weight.shape[0]
+    if not hidden_states.is_cuda:
+        return manual_rms_norm(hidden_states, hsz, self.weight, self.variance_epsilon)
+    return mixed_dtype_fused_rms_norm_affine(
+        hidden_states, self.weight, hsz, self.variance_epsilon
+    )
+
+
 def replace_xformers(model: nn.Module):
     for module in model.modules():
         if isinstance(module, LlamaAttention):
             module.forward = MethodType(llama_flash_attention, module)
+        # if isinstance(module, LlamaRMSNorm) and rms_norm is not None and isinstance(rms_norm, Callable):
+        #     module.forward = MethodType(llama_fast_rms_norm, module)
diff --git a/smoe/modules/moe/moe_calculators.py b/smoe/modules/moe/moe_calculators.py
index 02c0752..fa2676d 100644
--- a/smoe/modules/moe/moe_calculators.py
+++ b/smoe/modules/moe/moe_calculators.py
@@ -23,6 +23,8 @@ class UniversalCalculator(nn.Module):
     def __init__(self, experts, multiply_gate_scores=True):
         super(UniversalCalculator, self).__init__()
         self.experts = experts
+        # TODO (zhutong): use vmap to boost the training efficiency
+        # self.experts_vmap = torch.vmap(self.experts)
         self.multiply_gate_scores = multiply_gate_scores
         self.num_experts = experts.num_experts
 
@@ -36,25 +38,27 @@ def forward(
         num_selects = topK_indices.size(1)
         topK_indices = topK_indices.flatten()  # shape(batch_size*num_selects)
         topK_scores = topK_scores.flatten()  # shape(batch_size*num_selects)
-        batch_indices = torch.arange(batch_size, device=topK_scores.device).repeat_interleave(num_selects)  # batch index for each selected expert, shape(batch_size*num_selects)
+        # batch index for each selected expert, shape(batch_size*num_selects)
+        # repeat_interleave(repeats=2): [1,2,3] -> [1,1,2,2,3,3]
+        batch_indices = torch.arange(batch_size, device=topK_scores.device).repeat_interleave(num_selects)
 
         """Sort the expert indices in ascending order of expert id"""
         _, index_sorted_topK_indices = topK_indices.sort(0)
 
         """Reorder scores and batch_indices accordingly, and compute each expert's batch_size"""
         sorted_topK_scores = topK_scores.index_select(0, index_sorted_topK_indices)  # weights for each output
-        sorted_batch_indices = batch_indices.index_select(0, index_sorted_topK_indices)  # batch indices assigned to each expert
+        sorted_batch_indices = batch_indices.index_select(0, index_sorted_topK_indices)  # indices of all tokens in the batch assigned to each expert
 
         if expert_batch_size is None:
-            expert_batch_size = topK_indices.bincount().tolist()  # batch_size for each expert
-            if len(expert_batch_size) < self.num_experts:  # the list is shorter than the number of experts, i.e. the largest selected expert id is smaller than the largest expert id
-                expert_batch_size.extend([0] * (self.num_experts - len(expert_batch_size)))  # pad the list with zeros
+            expert_batch_size = topK_indices.bincount(minlength=self.num_experts).tolist()  # batch_size for each expert
 
         """Re-assemble the batch for each expert"""
-        sorted_x = x.index_select(0, sorted_batch_indices).squeeze(1)  # reorder the inputs according to the sorted batch indices
+        sorted_x = x.index_select(0, sorted_batch_indices)  # reorder the inputs according to the sorted batch indices
         split_x = torch.split(sorted_x, expert_batch_size, dim=0)  # split by each expert's batch_size, yielding exactly the batch each expert needs
 
         """Forward pass of each expert"""
         # there should be room for parallel optimization here (if a single forward cannot saturate the GPU)
+        # args = [(split_x[i], i) for i in range(self.num_experts) if split_x[i].shape[0] > 0]
+        # expert_outputs = self.experts_vmap(args)
         expert_outputs = [self.experts(split_x[i], i) for i in range(self.num_experts) if split_x[i].shape[0] > 0]
 
         """Recombine the outputs of all experts and apply weights"""
diff --git a/smoe/modules/moe/moe_gates.py b/smoe/modules/moe/moe_gates.py
index 92aeb5f..2261c22 100644
--- a/smoe/modules/moe/moe_gates.py
+++ b/smoe/modules/moe/moe_gates.py
@@ -119,6 +119,7 @@ def forward(self, x):
         zeros = torch.zeros_like(logits, requires_grad=True, device=logits.device)
         scores_filtered = zeros.scatter(dim=1, index=top_k_indices, src=top_k_scores)  # shape(batch_size, num_experts)
         importance = scores_filtered.sum(0)  # shape(num_experts)
+        # importance = scores_filtered.float().sum(0)  # shape(num_experts)
 
         """Compute load"""
         batch_size = logits_gate.size(0)
@@ -133,14 +134,15 @@ def forward(self, x):
             prob_if_in = self.normal.cdf((logits_gate - threshold_if_in) / noise_control)
             prob_if_out = self.normal.cdf((logits_gate - threshold_if_out) / noise_control)
             prob = torch.where(is_in, prob_if_in, prob_if_out)
+            # load = prob.float().sum(0)
             load = prob.sum(0)
 
             """Compute balance loss"""
             balance_loss = self.cv_squared(importance) + self.cv_squared(load)
             balance_loss *= self.balance_loss_weight
-
+            # balance_loss = balance_loss.to(logits)  # fallback to fp16
         else:
-            balance_loss = None
+            balance_loss = torch.tensor(-100.0)
 
         return {
             "topK_indices": top_k_indices,
diff --git a/smoe/trainer/llama_lr_scheduling.py b/smoe/trainer/llama_lr_scheduling.py
index ff95812..a35ba20 100644
--- a/smoe/trainer/llama_lr_scheduling.py
+++ b/smoe/trainer/llama_lr_scheduling.py
@@ -4,6 +4,7 @@
 import shutil
 import sys
 import time
+from dataclasses import dataclass
 from functools import partial
 from pathlib import Path
 from typing import Any, Dict, Union
@@ -89,6 +90,12 @@ def _get_cosine_schedule_with_warmup_lr_lambda(
     )
 
 
+@dataclass
+class EnhancedTrainerState(TrainerState):
+    # start timestamp, used to compute tokens/GPU/second (TGS)
+    start_timestamp: float = 0.0
+
+
 class LlamaLrSchedulingTrainer(Trainer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -113,6 +120,7 @@ def create_scheduler(
             )
             last_epoch = -1
             self.lr_scheduler = LambdaLR(optimizer, lr_lambda, last_epoch)
+            self._created_lr_scheduler = True
         return self.lr_scheduler
 
     def training_step(
@@ -387,7 +395,7 @@ def _inner_training_loop(
         if not delay_optimizer_creation:
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)
 
-        self.state = TrainerState()
+        self.state = EnhancedTrainerState()
         self.state.is_hyper_param_search = trial is not None
 
         # Activate gradient checkpointing if needed
@@ -544,6 +552,7 @@ def _inner_training_loop(
                 for _ in train_dataloader:
                     break
 
+        self.state.start_timestamp = time.time()
         total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
             epoch_iterator = train_dataloader
diff --git a/smoe/utils/config.py b/smoe/utils/config.py
index 0fd8594..76642ed 100644
--- a/smoe/utils/config.py
+++ b/smoe/utils/config.py
@@ -163,6 +163,15 @@ class DataArguments:
         default=None,
         metadata={"help": "The name of the dataset to use (via the datasets library)."},
     )
+    validation_dir: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "An optional input evaluation data file to evaluate the perplexity on"
+                " (a folder of text files)."
+            )
+        },
+    )
     dataset_config_name: Optional[str] = field(
         default=None,
         metadata={
@@ -292,6 +301,18 @@ class EnhancedTrainingArguments(TrainingArguments):
         default=-1,
         metadata={"help": "the number of max_tokens"},
     )
+    flops_per_device: Optional[int] = field(
+        default=312e12,
+        metadata={
+            "help": "FLOPS of one device. A100 312 TFLOPS",
+        },
+    )
+    num_training_params: Optional[int] = field(
+        default=-1,
+        metadata={
+            "help": "The number of model parameters used for training. If set to -1, it will be calculated automatically."
+        },
+    )
 
     @property
     def block_size(self):
diff --git a/smoe/utils/visualization/visualize.py b/smoe/utils/visualization/visualize.py
index f0d0511..b654efd 100644
--- a/smoe/utils/visualization/visualize.py
+++ b/smoe/utils/visualization/visualize.py
@@ -217,9 +217,9 @@ def find_factors_with_minimal_sum(number):
 def plot_to_image(figure):
     """Converts the matplotlib plot specified by 'figure' to a PNG image and
     returns it. The supplied figure is closed and inaccessible after this call."""
-    # Save the plot to a PNG in memory.
+    # Save the plot to an image in memory.
     buf = io.BytesIO()
-    plt.savefig(buf, format="png")
+    plt.savefig(buf, format="jpg")
     # Closing the figure prevents it from being displayed directly inside
     # the notebook.
     plt.close(figure)
@@ -301,7 +301,7 @@ def vis_tuple_heatmaps(tensors: tuple[torch.FloatTensor]):
             ax.text(
                 col,
                 row,
-                f"{data[i, row, col]:.5f}",
+                f"{data[i, row, col]:.0f}",
                 ha="center",
                 va="center",
                 color="black",
diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py
index cae1582..72b897d 100644
--- a/tests/data/test_streaming.py
+++ b/tests/data/test_streaming.py
@@ -4,7 +4,9 @@
 from pathlib import Path
 
 import pytest
+from torch.utils.data import DataLoader
 
+from smoe.data.collate_fn import fault_tolerance_data_collator
 from smoe.data.streaming import JsonlDataset, SubDirWeightedPackedJsonlDataset
 from smoe.utils.io import load_jsonlines
 
@@ -99,6 +101,38 @@ def test_weighted_streaming():
             break
 
 
+def test_weighted_streaming_loader():
+    prob_map = {
+        "en_cc": 0.67,
+        "en_c4": 0.15,
+        "github": 0.045,
+        "en_wikipedia": 0.045,
+        "en_book": 0.045,
+        "en_arxiv": 0.025,
+        "en_stack": 0.02,
+    }
+    lm_datasets = SubDirWeightedPackedJsonlDataset(
+        "/mnt/petrelfs/share_data/quxiaoye/pretrain_LLAMA_all_data_processed",
+        prob_map=prob_map,
+        seed=1227,
+        block_size=2048,
+    )
+    num_test_case = 2000
+    bsz = 8
+    loader = DataLoader(
+        lm_datasets,
+        batch_size=bsz,
+        num_workers=4,
+        collate_fn=fault_tolerance_data_collator,
+        pin_memory=True,
+    )
+    for batch in loader:
+        if num_test_case <= 0:
+            break
+        assert len(batch["input_ids"]) == bsz
+        num_test_case -= 1
+
+
 if __name__ == "__main__":
     # test_jsonl_dataset()
     # test_subdir_weighted_pack_with_type()
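Reference note, not part of the patch: the Avg_TGS and Avg_MFU_per_second scalars added to EnhancedTensorboardCallback reduce to the small standalone computation sketched below. The helper name and the example numbers are hypothetical; flops_per_device mirrors the 312e12 A100 peak assumed in smoe/utils/config.py, and num_tokens_per_batch follows the block_size * micro-batch * gradient-accumulation * world-size convention used by the callback. The 6 * N * tokens term is the common dense-transformer estimate of forward+backward FLOPs.

def throughput_metrics(
    global_step: int,
    num_tokens_per_batch: int,
    world_size: int,
    elapsed_seconds: float,
    num_training_params: int,
    flops_per_device: float = 312e12,
) -> tuple[float, float]:
    """Return (TGS, MFU): tokens per GPU per second and model FLOPs utilization."""
    # Tokens processed so far, averaged over GPUs and wall-clock time.
    tgs = global_step * num_tokens_per_batch / world_size / elapsed_seconds
    # ~6 * N FLOPs per token for forward+backward of an N-parameter dense model,
    # compared against one device's peak throughput.
    mfu = 6 * num_training_params * tgs / flops_per_device
    return tgs, mfu


if __name__ == "__main__":
    # Illustrative numbers only: 2 nodes x 8 GPUs, block_size 4096, micro-batch 8,
    # gradient accumulation 4, after 1000 steps in 10 hours.
    tgs, mfu = throughput_metrics(
        global_step=1000,
        num_tokens_per_batch=4096 * 8 * 4 * 16,
        world_size=16,
        elapsed_seconds=36_000.0,
        num_training_params=7_000_000_000,
    )
    print(f"TGS = {tgs:.0f} tokens/GPU/s, MFU = {mfu:.1%}")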