
Commit 4ec01a5

Fix for Replicate_kv_heads script (#412)
Since we have moved to transformers version 4.50.0, the attention implementation has changed, and the number of attention heads and the hidden size are no longer attributes of the attention class. We have added these parameters as optional arguments to the replicate_kv_heads script. Users can now pass them explicitly; if not passed, they are read from the model's config.json file. --------- Signed-off-by: Hem Agnihotri <[email protected]>
1 parent f1fb026 commit 4ec01a5

File tree: 2 files changed, +71 −13 lines

  scripts/replicate_kv_head/README.md
  scripts/replicate_kv_head/replicate_kv_heads.py

scripts/replicate_kv_head/README.md

Lines changed: 3 additions & 1 deletion
@@ -30,4 +30,6 @@ Replace `<hf_token>` with your actual token.

 ### Arguments

 - **--model_name**: Model card name to use (default: “meta-llama/Meta-Llama-3-8B-Instruct”).
 - **--prompt**: Prompt to use for the model (default: “My name is”).
-- **--repeat**: Factor to repeat key-value heads (default: 2).
+- **--repeat**: Factor to repeat key-value heads (default: 2).
+- **--num_attention_heads**: Number of attention heads (default: None). Optional; if not passed explicitly, it is read from config.json.
+- **--hidden_size**: Hidden size (default: None). Optional; if not passed explicitly, it is read from config.json.
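
To make the fallback behaviour concrete, here is a minimal sketch (not part of the commit) of what happens when `--num_attention_heads` and `--hidden_size` are omitted: the values come from the model's config.json, loaded here via transformers' `AutoConfig`. The model name is the README default and may require a Hugging Face token to download.

```python
from transformers import AutoConfig

# Illustrative only: mirrors the fallback added in this commit.
num_attention_heads = None  # e.g. --num_attention_heads was not passed
hidden_size = None          # e.g. --hidden_size was not passed

config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

if num_attention_heads is None:
    num_attention_heads = config.num_attention_heads
if hidden_size is None:
    hidden_size = config.hidden_size

print(num_attention_heads, hidden_size)
```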

scripts/replicate_kv_head/replicate_kv_heads.py

Lines changed: 68 additions & 12 deletions
@@ -6,6 +6,7 @@
 # -----------------------------------------------------------------------------

 import argparse
+from typing import Optional

 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -70,46 +71,78 @@ def duplicate_weights_for_linear_layer(
     )


-def main(args):
+def replicate_kv_heads(
+    model_name: str = "meta-llama/Meta-Llama-3-8B-Instruct",
+    prompt: str = "My name is",
+    repeat: int = 2,
+    full_batch_size: Optional[int] = None,
+    num_hidden_layers: Optional[int] = None,
+    num_attention_heads: Optional[int] = None,
+    hidden_size: Optional[int] = None,
+):
+    """
+    Replicate the KV heads. The script performs the following steps:
+    1. Runs inference with the original model.
+    2. Replicates the KV heads.
+    3. Runs inference on the modified model to validate the changes.
+    4. Exports the modified model to ONNX format.
+
+    ``Mandatory`` Args:
+        :model_name (str): Model card name to use, defaults to meta-llama/Meta-Llama-3-8B-Instruct.
+        :prompt (str): Prompt to use for the model, defaults to "My name is".
+        :repeat (int): Factor to repeat key-value heads.
+    ``Optional`` Args:
+        :full_batch_size (int): Set full batch size to enable continuous batching mode, default is None.
+        :num_hidden_layers (int): Number of hidden layers to use, default is None.
+        :num_attention_heads (int): Number of attention heads; if not passed explicitly, it is picked from config.json.
+        :hidden_size (int): Hidden size to use; if not passed explicitly, it is picked from config.json.
+
+    """
     # Load the model and tokenizer
-    model_name = args.model_name
     model_base_name = model_name.split("/")[-1]
     # Replace quantizers for loading Quantized AWQ/GPTQ models on CPU.
     replace_transformers_quantizers()
     # Prepare kwargs for model loading
     model_kwargs = {"attn_implementation": "eager"}
-    if args.num_hidden_layers:
-        model_kwargs["num_hidden_layers"] = args.num_hidden_layers
+
+    if num_hidden_layers:
+        model_kwargs["num_hidden_layers"] = num_hidden_layers

     pretrained_model_name_or_path = login_and_download_hf_lm(model_name)
     model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **model_kwargs)

     # Undo the effect of replace_transformers_quantizers
     undo_transformers_quantizers()
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    inputs = tokenizer(args.prompt, return_tensors="pt")
+    inputs = tokenizer(prompt, return_tensors="pt")

     # Generate original outputs and tokens
     with torch.inference_mode():
         _ = model(**inputs)  # original output
         orig_tokens = model.generate(**inputs, max_new_tokens=10, num_beams=1, do_sample=False)

     # Modify the number of key-value heads
-    repeat = args.repeat
     orig_kv_heads = model.config.num_key_value_heads
     new_kv_heads = repeat * orig_kv_heads
     model.config.num_key_value_heads = new_kv_heads

     print("Original KV heads:", orig_kv_heads)
     print("Modified KV heads:", new_kv_heads)

+    # Check whether hidden size and number of attention heads were passed explicitly
+    if num_attention_heads is None:
+        num_attention_heads = model.config.num_attention_heads
+
+    if hidden_size is None:
+        hidden_size = model.config.hidden_size
+
     # Update the model's attention layers with new key-value heads
     for block in model.model.layers:
         attn = block.self_attn
         attn.num_key_value_heads = new_kv_heads
-        attn.num_key_value_groups = block.self_attn.num_heads // new_kv_heads
-        duplicate_weights_for_linear_layer(attn.k_proj, orig_kv_heads, repeat, attn.head_dim, attn.hidden_size)
-        duplicate_weights_for_linear_layer(attn.v_proj, orig_kv_heads, repeat, attn.head_dim, attn.hidden_size)
+        attn.num_key_value_groups = num_attention_heads // new_kv_heads
+        duplicate_weights_for_linear_layer(attn.k_proj, orig_kv_heads, repeat, attn.head_dim, hidden_size)
+        duplicate_weights_for_linear_layer(attn.v_proj, orig_kv_heads, repeat, attn.head_dim, hidden_size)

     # Generate modified outputs and tokens
     with torch.inference_mode():
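
For context, the weight duplication that `duplicate_weights_for_linear_layer` performs on `k_proj`/`v_proj` amounts to repeating each key-value head's block of rows in the projection weight. The helper itself is not shown in this diff; the sketch below only illustrates the general technique under assumed shapes (the function name and example sizes are illustrative), not the repository's implementation.

```python
import torch

def repeat_kv_head_rows(weight: torch.Tensor, orig_kv_heads: int, repeat: int, head_dim: int) -> torch.Tensor:
    # weight of k_proj/v_proj is assumed to have shape [orig_kv_heads * head_dim, hidden_size]
    hidden_size = weight.shape[1]
    w = weight.view(orig_kv_heads, head_dim, hidden_size)   # split rows into per-head blocks
    w = torch.repeat_interleave(w, repeats=repeat, dim=0)   # each head block appears `repeat` times
    return w.reshape(orig_kv_heads * repeat * head_dim, hidden_size)

# Example: 8 KV heads, head_dim 128, hidden_size 4096, repeated 2x -> 16 KV heads
w = torch.randn(8 * 128, 4096)
print(repeat_kv_head_rows(w, orig_kv_heads=8, repeat=2, head_dim=128).shape)  # torch.Size([2048, 4096])
```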
@@ -126,13 +159,13 @@ def main(args):
     )

     # Export the modified model
-    q_model = QEFFAutoModelForCausalLM(model, continuous_batching=(True if args.full_batch_size else False))
+    q_model = QEFFAutoModelForCausalLM(model, continuous_batching=(True if full_batch_size else False))
     export(
         model_name,
         q_model,
         tokenizer=tokenizer,
         onnx_dir_path=f"{model_base_name}-{new_kv_heads}kvheads",
-        full_batch_size=(args.full_batch_size if args.full_batch_size else None),
+        full_batch_size=(full_batch_size if full_batch_size else None),
     )

@@ -162,6 +195,29 @@ def main(args):
         default=None,
         help="Number of hidden layers to use, default is None",
     )
+    parser.add_argument(
+        "--num_attention_heads",
+        "--num-attention-heads",
+        type=int,
+        default=None,
+        help="Number of attention heads; if not passed explicitly, it is picked from config.json",
+    )
+    parser.add_argument(
+        "--hidden_size",
+        "--hidden-size",
+        type=int,
+        default=None,
+        help="Hidden size to use; if not passed explicitly, it is picked from config.json",
+    )

     args = parser.parse_args()
-    main(args)
+
+    replicate_kv_heads(
+        model_name=args.model_name,
+        prompt=args.prompt,
+        repeat=args.repeat,
+        full_batch_size=args.full_batch_size,
+        num_hidden_layers=args.num_hidden_layers,
+        num_attention_heads=args.num_attention_heads,
+        hidden_size=args.hidden_size,
+    )
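
Because the logic now lives in `replicate_kv_heads` rather than a CLI-only `main`, it can also be called directly from Python. A hypothetical direct call might look like the following (the import assumes the script's directory is on `sys.path`; argument values are illustrative). `num_attention_heads` and `hidden_size` are left unset, so they fall back to the model's config.json as described in the commit message.

```python
from replicate_kv_heads import replicate_kv_heads

# num_attention_heads and hidden_size are omitted, so the function reads
# them from the model's config.json.
replicate_kv_heads(
    model_name="meta-llama/Meta-Llama-3-8B-Instruct",
    prompt="My name is",
    repeat=2,
)
```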
