diff --git a/README.md b/README.md index 381b4f1..5314038 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ ## 🌴 Dependencies -- Python >= 3.10 +- Python >= 3.11 - scikit-learn>=1.3.0 - omegaconf>=2.0.6 - tqdm>=4.65.0 @@ -33,6 +33,12 @@ ```bash $ git clone git@github.com:pjlab-sys4nlp/train-moe.git +$ cd train-moe $ pip install -e .[dev] $ pre-commit install ``` + +## 🔗 Experiments + +- CPT + - [MoEfication L2-norm 8-choose-4 continued pre-training experiment](https://m04hsypyylv.feishu.cn/docx/R9Tid61U0oOuQ4xwrbGcyCyvnMf) diff --git a/scripts/cpt/fpt.sh b/scripts/cpt/fpt.sh index 62f54e9..9d0179e 100644 --- a/scripts/cpt/fpt.sh +++ b/scripts/cpt/fpt.sh @@ -1,6 +1,6 @@ #!/usr/bin/bash -#SBATCH --job-name=cpt-moe-fpt-bs8 +#SBATCH --job-name=cpt-moe-fpt-bs1-debug #SBATCH --partition=MoE #SBATCH --output=logs/%x-%j.log #SBATCH --error=logs/%x-%j.log @@ -18,6 +18,8 @@ num_gpu_per_node=8 # should match with --gres # #cpu/#num_gpu_per_node export OMP_NUM_THREADS=1 +export NCCL_DEBUG=INFO +export LOGLEVEL=INFO lr=1e-4 @@ -29,7 +31,7 @@ pretrained_model=/mnt/petrelfs/share_data/quxiaoye/models/llama_7B_MoE_16Select4 tokenizer_path=/mnt/petrelfs/share_data/quxiaoye/models/llama_7B dataset_dir=/mnt/petrelfs/share_data/quxiaoye/pretrain_LLAMA_all_data_processed -per_device_train_batch_size=8 +per_device_train_batch_size=1 per_device_eval_batch_size=1 gradient_accumulation_steps=1 block_size=2048 @@ -53,7 +55,6 @@ head_node=${nodes_array[0]} head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) echo "Node: $head_node" echo "Node IP: $head_node_ip" -export LOGLEVEL=INFO srun torchrun \ --nnodes ${num_nodes} \ diff --git a/smoe/data/llama_moefication_datasets.py b/smoe/data/llama_moefication_datasets.py index c722446..d1b7e48 100644 --- a/smoe/data/llama_moefication_datasets.py +++ b/smoe/data/llama_moefication_datasets.py @@ -21,7 +21,8 @@ def __init__( """numthreads should be set <=1, otherwise it will slow down the reading process by ~4 times""" if num_threads > 1: warnings.warn( - "num_threads should be set <=1, otherwise it will slow down the reading process by ~4 times!" + "num_threads should be set <=1, otherwise it will slow down the reading" + " process by ~4 times!" ) if os.path.isfile(file_path) is False: diff --git a/smoe/entrypoint/cpt_fpt.py b/smoe/entrypoint/cpt_fpt.py index 7fb67ac..d2d4bee 100644 --- a/smoe/entrypoint/cpt_fpt.py +++ b/smoe/entrypoint/cpt_fpt.py @@ -31,6 +31,7 @@ parse_args, ) from smoe.utils.logging import get_logger_from_training_args +from smoe.utils.param import get_trainable_parameters MODEL_MAP = { "llama": LlamaForCausalLM, @@ -72,15 +73,16 @@ def main(): last_checkpoint = get_last_checkpoint(training_args.output_dir) if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty. " - "Use --overwrite_output_dir to overcome." + f"Output directory ({training_args.output_dir}) already exists and is" + " not empty. Use --overwrite_output_dir to overcome." ) elif ( last_checkpoint is not None and training_args.resume_from_checkpoint is None ): logger.info( - f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " - "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid" + " this behavior, change the `--output_dir` or add" + " `--overwrite_output_dir` to train from scratch." ) # Set seed before initializing model.
@@ -128,8 +130,9 @@ def main(): ) else: raise ValueError( - "You are instantiating a new tokenizer from scratch. This is not supported by this script." - "You can do it from another script, save it, and load it from here, using --tokenizer_name." + "You are instantiating a new tokenizer from scratch. This is not supported" + " by this script. You can do it from another script, save it, and load it" + " from here, using --tokenizer_name." ) # Preprocessing the datasets. @@ -137,16 +140,18 @@ def main(): block_size = tokenizer.model_max_length if block_size > 1024: logger.warning( - "The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value" - " of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can" + "The chosen tokenizer supports a `model_max_length` that is longer than" + " the default `block_size` value of 1024. If you would like to use a" + " longer `block_size` up to `tokenizer.model_max_length` you can" " override this default with `--block_size xxx`." ) block_size = 1024 else: if data_args.block_size > tokenizer.model_max_length: logger.warning( - f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model" - f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." + f"The block_size passed ({data_args.block_size}) is larger than the" + f" maximum length for the model ({tokenizer.model_max_length}). Using" + f" block_size={tokenizer.model_max_length}." ) block_size = min(data_args.block_size, tokenizer.model_max_length) @@ -200,11 +205,14 @@ def main(): torch_dtype=torch_dtype, low_cpu_mem_usage=True, ) - for name, param in model.named_parameters(): - if "weight_noise.weight" in name: - nn.init.zeros_(param) - model.change_moe_gate_add_noise(False) - model.change_moe_gate_use_balance(False) + # train an MoE model from scratch 👇 + # model: LlamaMoEForCausalLM = LlamaMoEForCausalLM(config) + if isinstance(model, LlamaMoEForCausalLM): + for name, param in model.named_parameters(): + if "weight_noise.weight" in name: + nn.init.zeros_(param) + model.change_moe_gate_add_noise(False) + model.change_moe_gate_use_balance(False) replace_xformers(model) else: model = AutoModelForCausalLM.from_config(config) @@ -217,9 +225,12 @@ def main(): if model_vocab_size != len(tokenizer): model.resize_token_embeddings(len(tokenizer)) raise ValueError( - f"The model's vocab size ({model_vocab_size}) does not match with the tokenizer ({len(tokenizer)})" + f"The model's vocab size ({model_vocab_size}) does not match with the" + f" tokenizer ({len(tokenizer)})" ) + get_trainable_parameters(model, verbose=True) + # Initialize our Trainer trainer = LlamaLrSchedulingTrainer( model=model, @@ -228,12 +239,16 @@ def main(): eval_dataset=eval_dataset if training_args.do_eval else None, tokenizer=tokenizer, data_collator=fault_tolerance_data_collator, - compute_metrics=compute_metrics - if training_args.do_eval and not is_torch_tpu_available() - else None, - preprocess_logits_for_metrics=logits_argmax - if training_args.do_eval and not is_torch_tpu_available() - else None, + compute_metrics=( + compute_metrics + if training_args.do_eval and not is_torch_tpu_available() + else None + ), + preprocess_logits_for_metrics=( + logits_argmax + if training_args.do_eval and not is_torch_tpu_available() + else None + ), ) trainer.add_callback(SaveModelCallback) # Training diff --git a/smoe/entrypoint/cpt_lora.py b/smoe/entrypoint/cpt_lora.py index 805a809..1cc1cc5 100644 ---
a/smoe/entrypoint/cpt_lora.py +++ b/smoe/entrypoint/cpt_lora.py @@ -32,6 +32,7 @@ parse_args, ) from smoe.utils.logging import get_logger_from_training_args +from smoe.utils.param import get_trainable_parameters MODEL_MAP = { "llama": LlamaForCausalLM, @@ -73,15 +74,16 @@ def main(): last_checkpoint = get_last_checkpoint(training_args.output_dir) if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty. " - "Use --overwrite_output_dir to overcome." + f"Output directory ({training_args.output_dir}) already exists and is" + " not empty. Use --overwrite_output_dir to overcome." ) elif ( last_checkpoint is not None and training_args.resume_from_checkpoint is None ): logger.info( - f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " - "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid" + " this behavior, change the `--output_dir` or add" + " `--overwrite_output_dir` to train from scratch." ) # Set seed before initializing model. @@ -129,8 +131,9 @@ def main(): ) else: raise ValueError( - "You are instantiating a new tokenizer from scratch. This is not supported by this script." - "You can do it from another script, save it, and load it from here, using --tokenizer_name." + "You are instantiating a new tokenizer from scratch. This is not supported" + " by this script. You can do it from another script, save it, and load it" + " from here, using --tokenizer_name." ) # Preprocessing the datasets. @@ -138,16 +141,18 @@ def main(): block_size = tokenizer.model_max_length if block_size > 1024: logger.warning( - "The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value" - " of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can" + "The chosen tokenizer supports a `model_max_length` that is longer than" + " the default `block_size` value of 1024. If you would like to use a" + " longer `block_size` up to `tokenizer.model_max_length` you can" " override this default with `--block_size xxx`." ) block_size = 1024 else: if data_args.block_size > tokenizer.model_max_length: logger.warning( - f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model" - f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." + f"The block_size passed ({data_args.block_size}) is larger than the" + f" maximum length for the model ({tokenizer.model_max_length}). Using" + f" block_size={tokenizer.model_max_length}."
) block_size = min(data_args.block_size, tokenizer.model_max_length) @@ -221,7 +226,8 @@ def main(): if model_vocab_size != len(tokenizer): model.resize_token_embeddings(len(tokenizer)) raise ValueError( - f"The model's vocab size ({model_vocab_size}) does not match with the tokenizer ({len(tokenizer)})" + f"The model's vocab size ({model_vocab_size}) does not match with the" + f" tokenizer ({len(tokenizer)})" ) if training_args.peft_path is not None: logger.info("Peft from pre-trained model") @@ -258,7 +264,7 @@ def make_inputs_require_grad(module, input, output): model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) model = get_peft_model(model, peft_config) - model.print_trainable_parameters() + get_trainable_parameters(model, verbose=True) # Initialize our Trainer trainer = LlamaLrSchedulingTrainer( @@ -268,12 +274,16 @@ def make_inputs_require_grad(module, input, output): eval_dataset=eval_dataset if training_args.do_eval else None, tokenizer=tokenizer, data_collator=fault_tolerance_data_collator, - compute_metrics=compute_metrics - if training_args.do_eval and not is_torch_tpu_available() - else None, - preprocess_logits_for_metrics=logits_argmax - if training_args.do_eval and not is_torch_tpu_available() - else None, + compute_metrics=( + compute_metrics + if training_args.do_eval and not is_torch_tpu_available() + else None + ), + preprocess_logits_for_metrics=( + logits_argmax + if training_args.do_eval and not is_torch_tpu_available() + else None + ), ) trainer.add_callback(SaveModelCallback) # Training diff --git a/smoe/entrypoint/moefication/llama_split_clustering.py b/smoe/entrypoint/moefication/llama_split_clustering.py index fbc101e..0a25a77 100644 --- a/smoe/entrypoint/moefication/llama_split_clustering.py +++ b/smoe/entrypoint/moefication/llama_split_clustering.py @@ -19,7 +19,10 @@ "--templates", type=str, default="layers.{}.mlp.gate_proj.weight", - help="weight names of the first linear layer in each FFN (use comma to separate multiple templates)", + help=( + "weight names of the first linear layer in each FFN (use comma to separate" + " multiple templates)" + ), ) parser.add_argument("--num_experts", type=int, default=8, help="number of experts") diff --git a/smoe/models/llama_moefication/modeling_llama_moe.py b/smoe/models/llama_moefication/modeling_llama_moe.py index 14d3c4b..d7a71f7 100644 --- a/smoe/models/llama_moefication/modeling_llama_moe.py +++ b/smoe/models/llama_moefication/modeling_llama_moe.py @@ -36,9 +36,11 @@ def __init__(self, config: LlamaMoEConfig, layer_index): hidden_act=config.hidden_act, num_experts=config.num_experts, num_selects=config.num_selects, - size_experts=config.size_experts[layer_index] - if config.size_experts is not None - else None, + size_experts=( + config.size_experts[layer_index] + if config.size_experts is not None + else None + ), bias=False, gate_network=config.gates, gate_use_balance=True, @@ -145,7 +147,8 @@ def forward( # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at" + " the same time" ) elif input_ids is not None: batch_size, seq_length = input_ids.shape @@ -197,7 +200,8 @@ def forward( if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ "`use_cache=True` is incompatible with gradient checkpointing." + " Setting `use_cache=False`..." ) use_cache = False diff --git a/smoe/modules/moefication/moe_experts.py b/smoe/modules/moefication/moe_experts.py index dd95b73..7c43aee 100644 --- a/smoe/modules/moefication/moe_experts.py +++ b/smoe/modules/moefication/moe_experts.py @@ -179,12 +179,15 @@ def forward(self, input, i): return down def extra_repr(self): - return "in_features={}, hidden_features={}, out_features={}, hidden_act={}, num_experts={}, size_experts={}, bias={}".format( - self.in_features, - self.hidden_features, - self.out_features, - self.hidden_act, - self.num_experts, - self.size_experts, - self.bias_gate is not None, + return ( + "in_features={}, hidden_features={}, out_features={}, hidden_act={}," + " num_experts={}, size_experts={}, bias={}".format( + self.in_features, + self.hidden_features, + self.out_features, + self.hidden_act, + self.num_experts, + self.size_experts, + self.bias_gate is not None, + ) ) diff --git a/smoe/utils/config.py b/smoe/utils/config.py index 33604e4..4995d52 100644 --- a/smoe/utils/config.py +++ b/smoe/utils/config.py @@ -24,7 +24,8 @@ class ModelArguments: default=None, metadata={ "help": ( - "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch." + "The model checkpoint for weights initialization. Don't set if you want" + " to train a model from scratch." ) }, ) @@ -32,22 +33,26 @@ class ModelArguments: default=None, metadata={ "help": ( - "The tokenizer for weights initialization.Don't set if you want to train a model from scratch." + "The tokenizer for weights initialization. Don't set if you want to" + " train a model from scratch." ) }, ) model_type: Optional[str] = field( default=None, metadata={ - "help": "If training from scratch, pass a model type from the list: " - + ", ".join(MODEL_TYPES) + "help": ( + "If training from scratch, pass a model type from the list: " + + ", ".join(MODEL_TYPES) + ) }, ) config_overrides: Optional[str] = field( default=None, metadata={ "help": ( - "Override some existing default config settings when a model is trained from scratch. Example: " + "Override some existing default config settings when a model is trained" + " from scratch. Example: " "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" ) }, @@ -67,27 +72,36 @@ class ModelArguments: cache_dir: Optional[str] = field( default=None, metadata={ - "help": "Where do you want to store the pretrained models downloaded from huggingface.co" + "help": ( + "Where do you want to store the pretrained models downloaded from" + " huggingface.co" + ) }, ) use_fast_tokenizer: bool = field( default=True, metadata={ - "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." + "help": ( + "Whether to use one of the fast tokenizer (backed by the tokenizers" + " library) or not." + ) }, ) model_revision: str = field( default="main", metadata={ - "help": "The specific model version to use (can be a branch name, tag name or commit id)." + "help": ( + "The specific model version to use (can be a branch name, tag name or" + " commit id)." + ) }, ) use_auth_token: bool = field( default=False, metadata={ "help": ( - "Will use the token generated when running `huggingface-cli login` (necessary to use this script " - "with private models)." + "Will use the token generated when running `huggingface-cli login`" + " (necessary to use this script with private models)."
) }, ) @@ -95,8 +109,9 @@ class ModelArguments: default=None, metadata={ "help": ( - "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " - "dtype will be automatically derived from the model's weights." + "Override the default `torch.dtype` and load the model under this" + " dtype. If `auto` is passed, the dtype will be automatically derived" + " from the model's weights." ), "choices": ["auto", "bfloat16", "float16", "float32"], }, @@ -107,7 +122,8 @@ def __post_init__(self): self.config_name is not None or self.model_name_or_path is not None ): raise ValueError( - "--config_overrides can't be used in combination with --config_name or --model_name_or_path" + "--config_overrides can't be used in combination with --config_name or" + " --model_name_or_path" ) @@ -124,7 +140,10 @@ class DataArguments: dataset_config_name: Optional[str] = field( default=None, metadata={ - "help": "The configuration name of the dataset to use (via the datasets library)." + "help": ( + "The configuration name of the dataset to use (via the datasets" + " library)." + ) }, ) train_file: Optional[str] = field( @@ -133,15 +152,18 @@ class DataArguments: validation_file: Optional[str] = field( default=None, metadata={ - "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)." + "help": ( + "An optional input evaluation data file to evaluate the perplexity on" + " (a text file)." + ) }, ) max_train_samples: Optional[int] = field( default=None, metadata={ "help": ( - "For debugging purposes or quicker training, truncate the number of training examples to this " - "value if set." + "For debugging purposes or quicker training, truncate the number of" + " training examples to this value if set." ) }, ) @@ -149,8 +171,8 @@ class DataArguments: default=None, metadata={ "help": ( - "For debugging purposes or quicker training, truncate the number of evaluation examples to this " - "value if set." + "For debugging purposes or quicker training, truncate the number of" + " evaluation examples to this value if set." ) }, ) @@ -159,9 +181,10 @@ class DataArguments: default=None, metadata={ "help": ( - "Optional input sequence length after tokenization. " - "The training dataset will be truncated in block of this size for training. " - "Default to the model max input length for single sentence inputs (take into account special tokens)." + "Optional input sequence length after tokenization. The training" + " dataset will be truncated in block of this size for training. Default" + " to the model max input length for single sentence inputs (take into" + " account special tokens)." ) }, ) @@ -172,7 +195,10 @@ class DataArguments: validation_split_percentage: Optional[float] = field( default=0.05, metadata={ - "help": "The percentage of the train set used as validation set in case there's no validation split" + "help": ( + "The percentage of the train set used as validation set in case there's" + " no validation split" + ) }, ) preprocessing_num_workers: Optional[int] = field( @@ -189,7 +215,10 @@ class DataArguments: prob_map: Optional[dict[str, float]] = field( default=None, metadata={ - "help": 'data type to sampling probabilities. e.g. {"commoncrawl": 0.67, "c4": 0.15}' + "help": ( + 'data type to sampling probabilities. e.g. 
{"commoncrawl": 0.67, "c4":' + " 0.15}" + ) }, ) @@ -241,7 +270,8 @@ def parse_args(*args: Type[Arguments]) -> tuple[Arguments, ...]: arg_tuple = parser.parse_yaml_file(yaml_file=os.path.abspath(sys.argv[1])) else: raise ValueError( - f"Only yaml, yml, and json config files are supported, got {sys.argv[1]}" + "Only yaml, yml, and json config files are supported, got" + f" {sys.argv[1]}" ) else: arg_tuple = parser.parse_args_into_dataclasses() diff --git a/smoe/utils/merge_llama_with_lora.py b/smoe/utils/merge_llama_with_lora.py new file mode 100644 index 0000000..d76030a --- /dev/null +++ b/smoe/utils/merge_llama_with_lora.py @@ -0,0 +1,436 @@ +""" +Credit to: https://github.com/ymcui/Chinese-LLaMA-Alpaca/blob/main/scripts/merge_llama_with_chinese_lora_low_mem.py + +License: https://github.com/ymcui/Chinese-LLaMA-Alpaca/blob/main/LICENSE.md + +Usage: + + python -m smoe.utils.merge_llama_with_lora \ + --base_model path/to/llama/model \ + --lora_model path/to/first/lora[,path/to/second/lora] \ + --output_type [pth|huggingface] \ + --output_dir path/to/output/dir +""" + +import argparse +import gc +import json +import os +import re + +import peft +import torch +from huggingface_hub import snapshot_download +from transformers import LlamaTokenizer +from transformers.modeling_utils import dtype_byte_size + +parser = argparse.ArgumentParser() +parser.add_argument( + "--base_model", + default=None, + required=True, + type=str, + help="Please specify a base model", +) +parser.add_argument( + "--lora_model", + default=None, + required=True, + type=str, + help=( + "Please specify LoRA models to be merged (ordered); use commas to separate" + " multiple LoRA models" + ), +) +parser.add_argument( + "--output_type", + default="pth", + choices=["pth", "huggingface"], + type=str, + help="Save the merged model in pth or huggingface format", +) +parser.add_argument( + "--output_dir", + default="./merged_model", + type=str, + help="The output folder to save the merged model", +) +parser.add_argument( + "--verbose", default=False, action="store_true", help="Show detailed messages" +) + + +emb_to_model_size = { + 4096: "7B", + 5120: "13B", + 6656: "33B", + 8192: "65B", +} +num_shards_of_models = {"7B": 1, "13B": 2, "33B": 4, "65B": 8} +params_of_models = { + "7B": { + "dim": 4096, + "multiple_of": 256, + "n_heads": 32, + "n_layers": 32, + "norm_eps": 1e-06, + "vocab_size": -1, + }, + "13B": { + "dim": 5120, + "multiple_of": 256, + "n_heads": 40, + "n_layers": 40, + "norm_eps": 1e-06, + "vocab_size": -1, + }, + "33B": { + "dim": 6656, + "multiple_of": 256, + "n_heads": 52, + "n_layers": 60, + "norm_eps": 1e-06, + "vocab_size": -1, + }, + "65B": { + "dim": 8192, + "multiple_of": 256, + "n_heads": 64, + "n_layers": 80, + "norm_eps": 1e-05, + "vocab_size": -1, + }, +} + + +def transpose(weight, fan_in_fan_out): + return weight.T if fan_in_fan_out else weight + + +# Borrowed and modified from https://github.com/tloen/alpaca-lora +def translate_state_dict_key(k): + k = k.replace("base_model.model.", "") + if k == "model.embed_tokens.weight": + return "tok_embeddings.weight" + elif k == "model.norm.weight": + return "norm.weight" + elif k == "lm_head.weight": + return "output.weight" + elif k.startswith("model.layers."): + layer = k.split(".")[2] + if k.endswith(".self_attn.q_proj.weight"): + return f"layers.{layer}.attention.wq.weight" + elif k.endswith(".self_attn.k_proj.weight"): + return f"layers.{layer}.attention.wk.weight" + elif k.endswith(".self_attn.v_proj.weight"): + return 
f"layers.{layer}.attention.wv.weight" + elif k.endswith(".self_attn.o_proj.weight"): + return f"layers.{layer}.attention.wo.weight" + elif k.endswith(".mlp.gate_proj.weight"): + return f"layers.{layer}.feed_forward.w1.weight" + elif k.endswith(".mlp.down_proj.weight"): + return f"layers.{layer}.feed_forward.w2.weight" + elif k.endswith(".mlp.up_proj.weight"): + return f"layers.{layer}.feed_forward.w3.weight" + elif k.endswith(".input_layernorm.weight"): + return f"layers.{layer}.attention_norm.weight" + elif k.endswith(".post_attention_layernorm.weight"): + return f"layers.{layer}.ffn_norm.weight" + elif k.endswith("rotary_emb.inv_freq") or "lora" in k: + return None + else: + print(layer, k) + raise NotImplementedError + else: + print(k) + raise NotImplementedError + + +def unpermute(w): + return ( + w.view(n_heads, 2, dim // n_heads // 2, dim).transpose(1, 2).reshape(dim, dim) + ) + + +def save_shards(model_sd, num_shards: int, prefix="", verbose=False): + """ + Convert and save the HF format weights to PTH format weights + """ + with torch.no_grad(): + if num_shards == 1: + new_state_dict = {} + for k, v in model_sd.items(): + new_k = translate_state_dict_key(k) + if new_k is not None: + if "wq" in new_k or "wk" in new_k: + new_state_dict[new_k] = unpermute(v) + else: + new_state_dict[new_k] = v + + os.makedirs(output_dir, exist_ok=True) + print( + f"Saving shard 1 of {num_shards} into" + f" {output_dir}/{prefix}consolidated.00.pth" + ) + torch.save(new_state_dict, output_dir + f"/{prefix}consolidated.00.pth") + else: + new_state_dicts = [dict() for _ in range(num_shards)] + for k in list(model_sd.keys()): + v = model_sd[k] + new_k = translate_state_dict_key(k) + if new_k is not None: + if new_k == "tok_embeddings.weight": + assert v.size(1) % num_shards == 0 + splits = v.split(v.size(1) // num_shards, dim=1) + elif new_k == "output.weight": + if v.size(0) % num_shards == 0: + splits = v.split(v.size(0) // num_shards, dim=0) + else: + size_list = [v.size(0) // num_shards] * num_shards + size_list[-1] += v.size(0) % num_shards + splits = v.split( + size_list, dim=0 + ) # 13B: size_list == [24976,24977] + elif new_k == "norm.weight": + splits = [v] * num_shards + elif "ffn_norm.weight" in new_k: + splits = [v] * num_shards + elif "attention_norm.weight" in new_k: + splits = [v] * num_shards + + elif "w1.weight" in new_k: + splits = v.split(v.size(0) // num_shards, dim=0) + elif "w2.weight" in new_k: + splits = v.split(v.size(1) // num_shards, dim=1) + elif "w3.weight" in new_k: + splits = v.split(v.size(0) // num_shards, dim=0) + + elif "wo.weight" in new_k: + splits = v.split(v.size(1) // num_shards, dim=1) + + elif "wv.weight" in new_k: + splits = v.split(v.size(0) // num_shards, dim=0) + + elif "wq.weight" in new_k or "wk.weight" in new_k: + v = unpermute(v) + splits = v.split(v.size(0) // num_shards, dim=0) + else: + print(f"Unexpected key {new_k}") + raise ValueError + if verbose: + print(f"Processing {new_k}") + for sd, split in zip(new_state_dicts, splits): + sd[new_k] = split.clone() + del split + del splits + del model_sd[k], v + gc.collect() # Effectively enforce garbage collection + + os.makedirs(output_dir, exist_ok=True) + for i, new_state_dict in enumerate(new_state_dicts): + print( + f"Saving shard {i+1} of {num_shards} into" + f" {output_dir}/{prefix}consolidated.0{i}.pth" + ) + torch.save( + new_state_dict, output_dir + f"/{prefix}consolidated.0{i}.pth" + ) + + +def merge_shards(output_dir, num_shards: int): + ckpt_filenames = sorted( + [ + f + for f in 
os.listdir(output_dir) + if re.match("L(\d+)-consolidated.(\d+).pth", f) + ] + ) + + for i in range(num_shards): + shards_filenames = sorted( + [f for f in ckpt_filenames if re.match(f"L(\d+)-consolidated.0{i}.pth", f)] + ) + print(f"Loading {shards_filenames} ...") + shards_dicts = [ + torch.load(os.path.join(output_dir, fn)) for fn in shards_filenames + ] + shards_merged = {} + for d in shards_dicts: + shards_merged |= d + + print( + "Saving the merged shard to " + + os.path.join(output_dir, f"consolidated.0{i}.pth") + ) + torch.save(shards_merged, os.path.join(output_dir, f"consolidated.0{i}.pth")) + + print("Cleaning up...") + del shards_merged + for d in shards_dicts: + del d + del shards_dicts + gc.collect() # Effectively enforce garbage collection + for fn in shards_filenames: + os.remove(os.path.join(output_dir, fn)) + + +if __name__ == "__main__": + args = parser.parse_args() + base_model_path = args.base_model + lora_model_paths = [ + s.strip() for s in args.lora_model.split(",") if len(s.strip()) != 0 + ] + output_dir = args.output_dir + output_type = args.output_type + os.makedirs(output_dir, exist_ok=True) + + print(f"Base model: {base_model_path}") + print(f"LoRA model(s) {lora_model_paths}:") + + tokenizers_and_loras = [] + for lora_model_path in lora_model_paths: + print(f"Loading {lora_model_path}") + if not os.path.exists(lora_model_path): + print( + "Cannot find lora model on the disk. Downloading lora model from hub..." + ) + lora_model_path = snapshot_download(repo_id=lora_model_path) + tokenizer = LlamaTokenizer.from_pretrained(lora_model_path) + lora_config = peft.LoraConfig.from_pretrained(lora_model_path) + lora_state_dict = torch.load( + os.path.join(lora_model_path, "adapter_model.bin"), map_location="cpu" + ) + if "base_model.model.model.embed_tokens.weight" in lora_state_dict: + lora_vocab_size = lora_state_dict[ + "base_model.model.model.embed_tokens.weight" + ].shape[0] + assert lora_vocab_size == len(tokenizer), ( + f"The vocab size of the tokenizer {len(tokenizer)} does not match the" + f" vocab size of the LoRA weight {lora_vocab_size}.\nMake sure that you" + " use LLaMA tokenizer with the LLaMA-LoRA weight and Alpaca tokenizer" + " with the Alpaca-LoRA weight!" + ) + tokenizers_and_loras.append( + { + "tokenizer": tokenizer, + "state_dict": lora_state_dict, + "config": lora_config, + "scaling": lora_config.lora_alpha / lora_config.r, + "fan_in_fan_out": lora_config.fan_in_fan_out, + } + ) + if len(tokenizers_and_loras) == 2: + t1_vocab_size = len(tokenizers_and_loras[0]["tokenizer"]) + t2_vocab_size = len(tokenizers_and_loras[1]["tokenizer"]) + assert t1_vocab_size <= t2_vocab_size, ( + f"The vocab size of the first tokenizer is {t1_vocab_size}\nThe vocab size" + f" of the second tokenizer is {t2_vocab_size}, found to be smaller than" + f" {t1_vocab_size}\nThis is not the intended use. Please check your model" + " and tokenizer." + ) + + if not os.path.exists(base_model_path): + print("Cannot find lora model on the disk. 
Downloading lora model from hub...") + base_model_path = snapshot_download(repo_id=base_model_path) + ckpt_filenames = sorted( + [ + f + for f in os.listdir(base_model_path) + if re.match("pytorch_model-(\d+)-of-(\d+).bin", f) + ] + ) + + embedding_size = None + model_size = None + + total_size = 0 + for index, filename in enumerate(ckpt_filenames): + print(f"Loading ckpt {filename}") + state_dict = torch.load( + os.path.join(base_model_path, filename), map_location="cpu" + ) + if index == 0: + embedding_size = state_dict["model.embed_tokens.weight"].shape[1] + model_size = emb_to_model_size[embedding_size] + if output_type == "pth": + params = params_of_models[model_size] + num_shards = num_shards_of_models[model_size] + n_layers = params["n_layers"] + n_heads = params["n_heads"] + dim = params["dim"] + dims_per_head = dim // n_heads + base = 10000.0 + inv_freq = 1.0 / ( + base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head) + ) + print("Merging...") + for k in state_dict: + for tl_idx, t_and_l in enumerate(tokenizers_and_loras): + saved_key = "base_model.model." + k + lora_key_A = saved_key.replace(".weight", ".lora_A.weight") + if saved_key in t_and_l["state_dict"]: + if args.verbose: + print( + f"copying {saved_key} from {tl_idx}-th LoRA weight to {k}" + ) + state_dict[k] = ( + t_and_l["state_dict"][saved_key].half().clone() + ) # do we need half()? + if lora_key_A in t_and_l["state_dict"]: + lora_key_B = lora_key_A.replace("lora_A.weight", "lora_B.weight") + if args.verbose: + print( + f"merging {lora_key_A} and lora_B.weight from {tl_idx}-th" + f" LoRA weight to {k}" + ) + state_dict[k] += ( + transpose( + t_and_l["state_dict"][lora_key_B].float() + @ t_and_l["state_dict"][lora_key_A].float(), + t_and_l["fan_in_fan_out"], + ) + * t_and_l["scaling"] + ) + weight_size = state_dict[k].numel() * dtype_byte_size(state_dict[k].dtype) + total_size += weight_size + + if output_type == "huggingface": + print(f"Saving ckpt {filename} to {output_dir} in HF format...") + torch.save(state_dict, os.path.join(output_dir, filename)) + elif output_type == "pth": + print("Converting to pth format...") + save_shards( + model_sd=state_dict, + num_shards=num_shards, + prefix=f"L{index+1}-", + verbose=args.verbose, + ) + del state_dict + gc.collect() # Effectively enforce garbage collection + + print("Saving tokenizer") + tokenizers_and_loras[-1]["tokenizer"].save_pretrained(output_dir) + if output_type == "pth": + with open(output_dir + "/params.json", "w") as f: + print(f"Saving params.json into {output_dir}/params.json") + json.dump(params, f) + merge_shards(output_dir, num_shards=num_shards) + + if output_type == "huggingface": + configs = ( + "config.json", + "generation_config.json", + "pytorch_model.bin.index.json", + ) + for config in configs: + if os.path.exists(os.path.join(base_model_path, config)): + print(f"Saving {config}") + with open(os.path.join(base_model_path, config), "r") as f: + obj = json.load(f) + if config == "config.json": + obj["vocab_size"] = len(tokenizers_and_loras[-1]["tokenizer"]) + if config == "pytorch_model.bin.index.json": + obj["metadata"]["total_size"] = total_size + with open(os.path.join(output_dir, config), "w") as f: + json.dump(obj, f, indent=2) + print("Done.") diff --git a/smoe/utils/param.py b/smoe/utils/param.py new file mode 100644 index 0000000..5364fb6 --- /dev/null +++ b/smoe/utils/param.py @@ -0,0 +1,38 @@ +import torch.nn as nn + +from smoe.utils.logging import get_logger + +logger = get_logger(__name__) + + +def
get_trainable_parameters(model: nn.Module, verbose: bool = True): + """ + Count the trainable parameters in the model, log them if `verbose`, and return (trainable_params, all_param). + + Credit to https://github.com/huggingface/peft/blob/main/src/peft/peft_model.py + """ + trainable_params = 0 + all_param = 0 + for _, param in model.named_parameters(): + num_params = param.numel() + # if using DeepSpeed ZeRO-3 and the weights are initialized empty + if num_params == 0 and hasattr(param, "ds_numel"): + num_params = param.ds_numel + + # Due to the design of 4bit linear layers from bitsandbytes + # one needs to multiply the number of parameters by 2 to get + # the correct number of parameters + if param.__class__.__name__ == "Params4bit": + num_params = num_params * 2 + + all_param += num_params + if param.requires_grad: + trainable_params += num_params + if verbose: + logger.info( + f"trainable params: {trainable_params:,d}" + f" || all params: {all_param:,d}" + f" || trainable%: {100 * trainable_params / all_param}" + ) + + return trainable_params, all_param
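For context on the new `smoe/utils/param.py` helper that both entrypoints now call in place of PEFT's `print_trainable_parameters`, here is a minimal, self-contained sketch of the same counting logic applied to a toy model. The toy `nn.Sequential` module, the frozen embedding, and the `count_trainable_parameters` name are illustrative assumptions, not code from the diff.

```python
import torch.nn as nn


def count_trainable_parameters(model: nn.Module) -> tuple[int, int]:
    """Count trainable vs. total parameters, mirroring get_trainable_parameters."""
    trainable, total = 0, 0
    for param in model.parameters():
        num = param.numel()
        # DeepSpeed ZeRO-3 can leave parameters empty; fall back to ds_numel if present.
        if num == 0 and hasattr(param, "ds_numel"):
            num = param.ds_numel
        total += num
        if param.requires_grad:
            trainable += num
    return trainable, total


if __name__ == "__main__":
    # Toy model with a frozen embedding so the trainable/total split is visible.
    toy = nn.Sequential(nn.Embedding(1000, 64), nn.Linear(64, 64), nn.Linear(64, 1000))
    toy[0].weight.requires_grad_(False)
    trainable, total = count_trainable_parameters(toy)
    print(f"trainable: {trainable:,d} || all: {total:,d} || trainable%: {100 * trainable / total:.4f}")
```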
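The core of the new `smoe/utils/merge_llama_with_lora.py` script is the per-tensor update it applies while walking the base checkpoint: `W += transpose(B @ A, fan_in_fan_out) * (lora_alpha / r)`. The sketch below reproduces just that arithmetic on random tensors; the `merge_lora_weight` helper and the shape, rank, and alpha values are placeholders for illustration, not code from the diff.

```python
import torch


def merge_lora_weight(base_w, lora_A, lora_B, lora_alpha, r, fan_in_fan_out=False):
    """Fold one LoRA pair into its base weight: W + (alpha / r) * (B @ A)."""
    delta = lora_B.float() @ lora_A.float()  # (out, r) @ (r, in) -> (out, in)
    if fan_in_fan_out:
        delta = delta.T  # some layer types store the weight transposed
    return base_w.float() + delta * (lora_alpha / r)


if __name__ == "__main__":
    out_dim, in_dim, rank, alpha = 32, 16, 4, 8
    W = torch.randn(out_dim, in_dim)
    A = torch.randn(rank, in_dim)
    B = torch.zeros(out_dim, rank)  # a freshly initialized LoRA B is all zeros
    merged = merge_lora_weight(W, A, B, alpha, rank)
    assert torch.allclose(merged, W)  # zero B means merging is a no-op
    print(merged.shape)
```

In the actual script this update is applied in place to every matching key of each checkpoint shard before the state dict is saved back out in HF format or re-sharded into `consolidated.*.pth` files.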