From ade2f4621db7a7985379c6d960346dd9d71f2163 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com>
Date: Fri, 28 Jun 2024 22:58:51 +0900
Subject: [PATCH] chore(format): run black on main (#497)

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 ChatTTS/model/cuda/te_llama.py | 80 +++++++++++++++++++++-------------
 ChatTTS/model/gpt.py           |  5 ++-
 2 files changed, 54 insertions(+), 31 deletions(-)

diff --git a/ChatTTS/model/cuda/te_llama.py b/ChatTTS/model/cuda/te_llama.py
index 103afa62f..6909ebb18 100644
--- a/ChatTTS/model/cuda/te_llama.py
+++ b/ChatTTS/model/cuda/te_llama.py
@@ -20,7 +20,11 @@
     LlamaModel,
     LlamaConfig,
 )
-from transformers.modeling_utils import _add_variant, load_state_dict, _load_state_dict_into_model
+from transformers.modeling_utils import (
+    _add_variant,
+    load_state_dict,
+    _load_state_dict_into_model,
+)
 from transformers.utils import WEIGHTS_INDEX_NAME
 from transformers.utils.hub import get_checkpoint_shard_files
 
@@ -30,12 +34,16 @@ def replace_decoder(te_decoder_cls):
     """
     Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`.
     """
-    original_llama_decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer
+    original_llama_decoder_cls = (
+        transformers.models.llama.modeling_llama.LlamaDecoderLayer
+    )
     transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls
     try:
         yield
     finally:
-        transformers.models.llama.modeling_llama.LlamaDecoderLayer = original_llama_decoder_cls
+        transformers.models.llama.modeling_llama.LlamaDecoderLayer = (
+            original_llama_decoder_cls
+        )
 
 
 class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
@@ -64,7 +72,9 @@ def __init__(self, config, *args, **kwargs):
             attn_input_format="bshd",
             num_gqa_groups=config.num_key_value_heads,
         )
-        te_rope = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
+        te_rope = RotaryPositionEmbedding(
+            config.hidden_size // config.num_attention_heads
+        )
         self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()
 
     def forward(self, hidden_states, *args, attention_mask, **kwargs):
@@ -75,7 +85,9 @@ def forward(self, hidden_states, *args, attention_mask, **kwargs):
         """
         return (
             super().forward(
-                hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb
+                hidden_states,
+                attention_mask=attention_mask,
+                rotary_pos_emb=self.te_rope_emb,
             ),
         )
 
@@ -96,7 +108,9 @@ def __new__(cls, config: LlamaConfig):
         return model
 
     @classmethod
-    def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **kwargs):
+    def from_pretrained_local(
+        cls, pretrained_model_name_or_path, *args, config, **kwargs
+    ):
         """
         Custom method adapted from `from_pretrained` method in HuggingFace
         Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
@@ -120,16 +134,22 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k
             is_sharded = True
         elif os.path.isfile(
             os.path.join(
-                pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
+                pretrained_model_name_or_path,
+                subfolder,
+                _add_variant(WEIGHTS_INDEX_NAME, variant),
             )
         ):
             # Load from a sharded PyTorch checkpoint
             archive_file = os.path.join(
-                pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
+                pretrained_model_name_or_path,
+                subfolder,
+                _add_variant(WEIGHTS_INDEX_NAME, variant),
             )
             is_sharded = True
         else:
-            raise AssertionError("Only sharded PyTorch ckpt format supported at the moment")
+            raise AssertionError(
+                "Only sharded PyTorch ckpt format supported at the moment"
+            )
 
         resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
             pretrained_model_name_or_path,
@@ -168,34 +188,34 @@ def _replace_params(hf_state_dict, te_state_dict, config):
         # When loading weights into models with less number of layers, skip the
         # copy if the corresponding layer doesn't exist in HF model
         if layer_prefix + "input_layernorm.weight" in hf_state_dict:
-            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.layer_norm_weight"].data[
-                :
-            ] = hf_state_dict[layer_prefix + "input_layernorm.weight"].data[:]
+            te_state_dict[
+                layer_prefix + "self_attention.layernorm_qkv.layer_norm_weight"
+            ].data[:] = hf_state_dict[layer_prefix + "input_layernorm.weight"].data[:]
 
         if layer_prefix + "self_attn.q_proj.weight" in hf_state_dict:
-            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.query_weight"].data[:] = (
-                hf_state_dict[layer_prefix + "self_attn.q_proj.weight"].data[:]
-            )
+            te_state_dict[
+                layer_prefix + "self_attention.layernorm_qkv.query_weight"
+            ].data[:] = hf_state_dict[layer_prefix + "self_attn.q_proj.weight"].data[:]
 
         if layer_prefix + "self_attn.k_proj.weight" in hf_state_dict:
-            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.key_weight"].data[:] = (
-                hf_state_dict[layer_prefix + "self_attn.k_proj.weight"].data[:]
-            )
+            te_state_dict[
+                layer_prefix + "self_attention.layernorm_qkv.key_weight"
+            ].data[:] = hf_state_dict[layer_prefix + "self_attn.k_proj.weight"].data[:]
 
         if layer_prefix + "self_attn.v_proj.weight" in hf_state_dict:
-            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.value_weight"].data[:] = (
-                hf_state_dict[layer_prefix + "self_attn.v_proj.weight"].data[:]
-            )
+            te_state_dict[
+                layer_prefix + "self_attention.layernorm_qkv.value_weight"
+            ].data[:] = hf_state_dict[layer_prefix + "self_attn.v_proj.weight"].data[:]
 
         if layer_prefix + "self_attn.o_proj.weight" in hf_state_dict:
-            te_state_dict[layer_prefix + "self_attention.proj.weight"].data[:] = hf_state_dict[
-                layer_prefix + "self_attn.o_proj.weight"
-            ].data[:]
+            te_state_dict[layer_prefix + "self_attention.proj.weight"].data[:] = (
+                hf_state_dict[layer_prefix + "self_attn.o_proj.weight"].data[:]
+            )
 
         if layer_prefix + "post_attention_layernorm.weight" in hf_state_dict:
-            te_state_dict[layer_prefix + "layernorm_mlp.layer_norm_weight"].data[:] = hf_state_dict[
-                layer_prefix + "post_attention_layernorm.weight"
-            ].data[:]
+            te_state_dict[layer_prefix + "layernorm_mlp.layer_norm_weight"].data[:] = (
+                hf_state_dict[layer_prefix + "post_attention_layernorm.weight"].data[:]
+            )
 
         # It may happen that gate_proj.weight and up_proj.weight will be in the different files, so we need to
         # load them separately.
@@ -210,7 +230,7 @@ def _replace_params(hf_state_dict, te_state_dict, config):
             ] = hf_state_dict[layer_prefix + "mlp.up_proj.weight"].data
 
         if layer_prefix + "mlp.down_proj.weight" in hf_state_dict:
-            te_state_dict[layer_prefix + "layernorm_mlp.fc2_weight"].data[:] = hf_state_dict[
-                layer_prefix + "mlp.down_proj.weight"
-            ].data[:]
+            te_state_dict[layer_prefix + "layernorm_mlp.fc2_weight"].data[:] = (
+                hf_state_dict[layer_prefix + "mlp.down_proj.weight"].data[:]
+            )
     return all_layer_prefixes
diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py
index d9fa230f8..187517796 100644
--- a/ChatTTS/model/gpt.py
+++ b/ChatTTS/model/gpt.py
@@ -118,11 +118,14 @@ def _build_llama(
         if "cuda" in str(device):
             try:
                 from .cuda import TELlamaModel
+
                 model = TELlamaModel(llama_config)
                 self.logger.info("use NVIDIA accelerated TELlamaModel")
             except Exception as e:
                 model = None
-                self.logger.warn(f"use default LlamaModel for importing TELlamaModel error: {e}")
+                self.logger.warn(
+                    f"use default LlamaModel for importing TELlamaModel error: {e}"
+                )
         if model is None:
             model = LlamaModel(llama_config)
             del model.embed_tokens
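
Note on the first te_llama.py hunks: `replace_decoder` is a context manager that temporarily monkey-patches `transformers.models.llama.modeling_llama.LlamaDecoderLayer`, so any `LlamaModel` constructed inside the `with` block builds TransformerEngine layers instead. A minimal self-contained sketch of that pattern follows; the `@contextmanager` decorator is an assumption (the hunk begins at the docstring, so the decorator is not visible in the diff), the rest mirrors the patched code.

from contextlib import contextmanager

import transformers
import transformers.models.llama.modeling_llama  # ensure the submodule is loaded


@contextmanager  # assumed decorator; implied by the bare `yield` in try/finally
def replace_decoder(te_decoder_cls):
    """Temporarily swap LlamaDecoderLayer for `te_decoder_cls`."""
    original_llama_decoder_cls = (
        transformers.models.llama.modeling_llama.LlamaDecoderLayer
    )
    transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls
    try:
        # Models built here pick up te_decoder_cls via the patched module attribute.
        yield
    finally:
        # Restore the original class even if model construction raises.
        transformers.models.llama.modeling_llama.LlamaDecoderLayer = (
            original_llama_decoder_cls
        )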
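
Every `_replace_params` hunk reformats the same copy pattern: guard on the HF key (a layer may be absent when loading into a model with fewer layers), then copy the tensor in place into the TE-named parameter via `.data[:]`. A condensed sketch; `copy_if_present`, the shapes, and the example layer prefix are hypothetical illustrations, not code from the patch.

import torch


def copy_if_present(hf_state_dict, te_state_dict, hf_key, te_key):
    # Skip silently when the HF checkpoint has no such layer, as the
    # "skip the copy" comment in te_llama.py describes.
    if hf_key in hf_state_dict:
        # In-place copy: the TE parameter keeps its storage, only values change.
        te_state_dict[te_key].data[:] = hf_state_dict[hf_key].data[:]


# Toy usage with made-up shapes and a made-up layer prefix.
prefix = "model.layers.0."
hf_sd = {prefix + "input_layernorm.weight": torch.ones(8)}
te_sd = {prefix + "self_attention.layernorm_qkv.layer_norm_weight": torch.zeros(8)}
copy_if_present(
    hf_sd,
    te_sd,
    prefix + "input_layernorm.weight",
    prefix + "self_attention.layernorm_qkv.layer_norm_weight",
)
assert te_sd[prefix + "self_attention.layernorm_qkv.layer_norm_weight"].sum() == 8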
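
The gpt.py hunk only re-wraps the fallback path in `_build_llama`: try the TransformerEngine-backed `TELlamaModel` on CUDA devices, and fall back to the stock `LlamaModel` when the import or construction fails. A standalone sketch of that control flow, with `build_llama_backbone` as a hypothetical wrapper name; gpt.py itself keeps `logger.warn`, which is a deprecated alias of `logger.warning`.

import logging

from transformers import LlamaConfig, LlamaModel

logger = logging.getLogger(__name__)


def build_llama_backbone(llama_config, device):
    model = None
    if "cuda" in str(device):
        try:
            # Absolute form of gpt.py's relative `from .cuda import TELlamaModel`;
            # raises if ChatTTS or NVIDIA TransformerEngine is unavailable.
            from ChatTTS.model.cuda import TELlamaModel

            model = TELlamaModel(llama_config)
            logger.info("use NVIDIA accelerated TELlamaModel")
        except Exception as e:
            model = None
            logger.warning(
                f"use default LlamaModel for importing TELlamaModel error: {e}"
            )
    if model is None:
        model = LlamaModel(llama_config)
    return model


# Usage sketch: a deliberately tiny config so CPU construction stays cheap.
cfg = LlamaConfig(
    hidden_size=64, num_attention_heads=4, num_hidden_layers=1, intermediate_size=128
)
backbone = build_llama_backbone(cfg, device="cpu")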