 from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers, undo_transformers_quantizers
 from QEfficient.transformers.quantizers.awq import WQLinear_GEMM
 from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ
+from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear
 
 
 def duplicate_weights_for_linear_layer(
@@ -49,6 +50,15 @@ def duplicate_weights_for_linear_layer(
             1,
         ).view(hidden_size // layer.group_size, new_kv_heads * head_dim)
         layer.out_features = layer.out_features * repeat
+
+    elif isinstance(layer, FP8DeQuantLinear):
+        layer.weight.data = torch.repeat_interleave(
+            layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
+        ).view(new_kv_heads * head_dim, hidden_size)
+        layer.weight_scale.data = torch.repeat_interleave(
+            layer.weight_scale.data.view(orig_kv_heads, head_dim), repeat, 0
+        ).view(new_kv_heads * head_dim, -1)
+
     else:
         layer.weight.data = torch.repeat_interleave(
             layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
@@ -65,7 +75,6 @@ def main(args):
     model_kwargs = {"attn_implementation": "eager"}
     if args.num_hidden_layers:
         model_kwargs["num_hidden_layers"] = args.num_hidden_layers
-
     model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
 
     # Undo the effect of replace_transformers_quantizers
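
For context, the new FP8DeQuantLinear branch mirrors the dense else-branch: weight rows are regrouped per KV head, each head is repeated, and the per-row weight_scale is duplicated the same way so every copied row keeps its dequantization scale. Below is a minimal standalone sketch of that transform; the toy shapes and the assumption of a per-channel scale of shape (out_features, 1) are illustrative, not taken from this commit.

# Minimal sketch of the KV-head duplication applied above (toy shapes, assumed
# per-channel scales; not code from this commit).
import torch

orig_kv_heads, head_dim, hidden_size, repeat = 2, 4, 8, 3
new_kv_heads = orig_kv_heads * repeat

# Assumed layout: weight is (out_features, in_features) and weight_scale is
# (out_features, 1), i.e. one scale per output channel.
weight = torch.randn(orig_kv_heads * head_dim, hidden_size)
weight_scale = torch.rand(orig_kv_heads * head_dim, 1)

# Same transform as the diff: group rows by KV head, repeat each head along
# dim 0, then flatten back to (new_out_features, in_features).
new_weight = torch.repeat_interleave(
    weight.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
).view(new_kv_heads * head_dim, hidden_size)
new_scale = torch.repeat_interleave(
    weight_scale.view(orig_kv_heads, head_dim), repeat, 0
).view(new_kv_heads * head_dim, -1)

assert new_weight.shape == (new_kv_heads * head_dim, hidden_size)
assert new_scale.shape == (new_kv_heads * head_dim, 1)
# repeat_interleave keeps the copies of each KV head adjacent, so head 0's
# rows now appear `repeat` times back to back.
assert torch.equal(new_weight[:head_dim], new_weight[head_dim : 2 * head_dim])

Using repeat_interleave along dim 0 (rather than tile/repeat) is what keeps each head's duplicates contiguous, matching the head ordering grouped-query attention expects after the KV heads are expanded.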