
Commit 6199051

Enabled FP8 models for replicate_kv_heads script (#353)
Signed-off-by: Onkar Chougule <[email protected]>
1 parent: 7b64b33

1 file changed (+10 −1)

scripts/replicate_kv_head/replicate_kv_heads.py (+10 −1)
@@ -14,6 +14,7 @@
 from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers, undo_transformers_quantizers
 from QEfficient.transformers.quantizers.awq import WQLinear_GEMM
 from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ
+from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear


 def duplicate_weights_for_linear_layer(
@@ -49,6 +50,15 @@ def duplicate_weights_for_linear_layer(
             1,
         ).view(hidden_size // layer.group_size, new_kv_heads * head_dim)
         layer.out_features = layer.out_features * repeat
+
+    elif isinstance(layer, FP8DeQuantLinear):
+        layer.weight.data = torch.repeat_interleave(
+            layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
+        ).view(new_kv_heads * head_dim, hidden_size)
+        layer.weight_scale.data = torch.repeat_interleave(
+            layer.weight_scale.data.view(orig_kv_heads, head_dim), repeat, 0
+        ).view(new_kv_heads * head_dim, -1)
+
     else:
         layer.weight.data = torch.repeat_interleave(
             layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
@@ -65,7 +75,6 @@ def main(args):
     model_kwargs = {"attn_implementation": "eager"}
     if args.num_hidden_layers:
         model_kwargs["num_hidden_layers"] = args.num_hidden_layers
-
     model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

     # Undo the effect of replace_transformers_quantizers
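For readers of the diff: the new elif branch views the K/V projection weight as (orig_kv_heads, head_dim, hidden_size), duplicates each head's block of rows `repeat` times with torch.repeat_interleave, then applies the same interleaving to the per-output-channel weight_scale so every duplicated row still dequantizes with the scale of the row it was copied from. Below is a minimal, self-contained sketch of that row-and-scale replication; the sizes are illustrative and ordinary floats stand in for the FP8 weight tensor, so nothing here is taken from QEfficient itself.

import torch

# Illustrative sizes, not from the script.
orig_kv_heads, head_dim, hidden_size, repeat = 4, 8, 32, 2
new_kv_heads = orig_kv_heads * repeat

# Stand-ins for FP8DeQuantLinear's weight and per-output-channel scale.
weight = torch.randn(orig_kv_heads * head_dim, hidden_size)
weight_scale = torch.rand(orig_kv_heads * head_dim, 1)

# Duplicate each KV head's contiguous block of output rows, as the diff does.
new_weight = torch.repeat_interleave(
    weight.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
).view(new_kv_heads * head_dim, hidden_size)

# Replicate the scales with the same pattern so each duplicated row keeps
# the scale of the row it was copied from.
new_scale = torch.repeat_interleave(
    weight_scale.view(orig_kv_heads, head_dim), repeat, 0
).view(new_kv_heads * head_dim, -1)

assert new_weight.shape == (new_kv_heads * head_dim, hidden_size)
assert new_scale.shape == (new_kv_heads * head_dim, 1)
# Rows from the same original head now appear `repeat` times in a row:
assert torch.equal(new_weight[:head_dim], new_weight[head_dim:2 * head_dim])

Replicating the scale tensor alongside the weight is what keeps the dequantized values identical before and after replication; duplicating only the weight rows would pair copied rows with the wrong scales.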
