Commit 3bb01f7

[AWQ] Expand scale dims to match activation dims
1 parent 09246b2 · commit 3bb01f7

File tree

7 files changed, +156 -48 lines changed

src/nncf/quantization/algorithms/weight_compression/awq.py

Lines changed: 11 additions & 2 deletions

```diff
@@ -194,11 +194,20 @@ def apply(
             merge_weight = self._backend_entity.get_weight(merge_node, port_id, model, graph)
             merge_weight = (merge_weight * a_scale).astype(weight_dtype)
             self._backend_entity.set_weight(merge_node, port_id, model, graph, merge_weight)
-            a_scale = fns.transpose(a_scale)
         else:  # for Act->Multiply->MatMul and Act->MatMul patterns scale inserted after Act as extra node
             a_scale = fns.transpose(a_scale).astype(weight_dtype)
-            next_nodes = graph.get_next_nodes(merge_node)
+            out_edges = graph.get_output_edges(merge_node)
+            next_nodes = [edge.to_node for edge in out_edges]
             source_node_output_port = graph.get_output_edges(merge_node)[0].output_port_id
+
+            # Unsqueeze activation scale to match the size of the activation
+            # Output edges always have the same shape
+            tensor_shape = out_edges[0].tensor_shape
+            a_scale_shape = a_scale.shape
+            if len(tensor_shape) > len(a_scale_shape):
+                a_scale_shape = (1,) * (len(tensor_shape) - len(a_scale_shape)) + tuple(a_scale_shape)
+                a_scale = fns.reshape(a_scale, a_scale_shape)
+
             scale_insertion_command = self._backend_entity.scale_insertion_command(
                 merge_node, next_nodes, source_node_output_port, a_scale.data
             )
```
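The key change: when AWQ cannot fold the scale into a preceding node's weights, it inserts the scale as a standalone Multiply after the activation, and element-wise broadcasting only lines up per-channel values if the scale's rank matches the activation's rank. A minimal NumPy sketch of the expansion logic above (plain NumPy stands in for nncf's `fns` tensor functions; the shapes are illustrative):

```python
import numpy as np

def expand_scale_to_activation_rank(a_scale: np.ndarray, activation_shape: tuple) -> np.ndarray:
    """Prepend singleton dims so the scale broadcasts over the activation."""
    scale_shape = a_scale.shape
    if len(activation_shape) > len(scale_shape):
        scale_shape = (1,) * (len(activation_shape) - len(scale_shape)) + tuple(scale_shape)
        a_scale = a_scale.reshape(scale_shape)
    return a_scale

activation = np.zeros((1, 4, 8), dtype=np.float32)  # rank-3 activation
a_scale = np.ones((1, 8), dtype=np.float32)         # rank-2 transposed scale
a_scale = expand_scale_to_activation_rank(a_scale, activation.shape)
assert a_scale.shape == (1, 1, 8)
assert (activation * a_scale).shape == activation.shape  # broadcast is rank-aligned
```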

tests/cross_fw/test_templates/template_test_weights_compression.py

Lines changed: 10 additions & 6 deletions

```diff
@@ -364,7 +364,7 @@ def test_call_max_var_criterion_with_dataset_by_default_awq_act_matmul(self, int
 
     @staticmethod
     @abstractmethod
-    def get_awq_model() -> TModel:
+    def get_awq_model(non_mergable_pattern: bool) -> TModel:
         "Returns a backend model for test_awq_with_ignored_scope."
 
     @staticmethod
@@ -388,7 +388,7 @@ def get_ignored_scope_name() -> str:
         "Returns ignored scope name for test_awq_with_ignored_scope."
 
     def test_awq_with_ignored_scope(self, mocker):
-        model = self.get_awq_model()
+        model = self.get_awq_model(non_mergable_pattern=False)
         sz = 8
         n_samples = 10
 
@@ -455,12 +455,14 @@ def test_sam_pe_weight_compression(self):
 
     @staticmethod
     @abstractmethod
-    def get_reference_for_test_awq_scale_reference() -> dict[str, Tensor]:
+    @pytest.fixture
+    def test_awq_scale_ref() -> dict[str, Tensor]:
         "Returns reference for test_awq_scale_reference."
 
-    def test_awq_scale_reference(self, monkeypatch, mocker):
+    @pytest.mark.parametrize("non_mergable_pattern", [True, False])
+    def test_awq_scale_reference(self, monkeypatch, mocker, non_mergable_pattern, test_awq_scale_ref):
         monkeypatch.setattr("nncf.quantization.algorithms.weight_compression.algorithm.AWQ", SpyAWQ)
-        model = self.get_awq_model()
+        model = self.get_awq_model(non_mergable_pattern)
 
         input = 0.01 * np.arange(0, 4 * 8, dtype=np.float32).reshape(1, 4, 8) + 0.02
         input = self.to_tensor(input)
@@ -477,7 +479,9 @@ def test_awq_scale_reference(self, monkeypatch, mocker):
         )
         assert spy_instance is not None
         for node_name, scales in spy_instance._scale_per_target_node.items():
-            assert fns.allclose(scales, self.get_reference_for_test_awq_scale_reference()[node_name])
+            ref = test_awq_scale_ref[node_name]
+            assert fns.allclose(scales, ref)
+            assert scales.shape == ref.shape
 
     @pytest.mark.parametrize(
         ["group_size", "fallback_mode", "min_adjusted_group_size", "expected_outcome"],
```

tests/onnx/quantization/test_weights_compression.py

Lines changed: 32 additions & 8 deletions

```diff
@@ -624,7 +624,7 @@ def get_num_multiply_from_awq(model: onnx.ModelProto) -> int:
         return awq_num
 
     @staticmethod
-    def get_awq_model() -> onnx.ModelProto:
+    def get_awq_model(non_mergable_pattern: bool) -> onnx.ModelProto:
         """
         Builds a model to be used in the following tests:
         - TemplateWeightCompression.test_awq_with_ignored_scope()
@@ -641,11 +641,17 @@ def get_awq_model() -> onnx.ModelProto:
         w_data = w_data.T
 
         num_blocks = 2
+
         for i in range(num_blocks):
-            a = mb.add_matmul(x, shape=w_data.shape, data=w_data)
-            b = mb.add_matmul(x, shape=w_data.shape, data=w_data)
-            x = mb.add_mul(a, b)
-            x = mb.add_matmul(x, shape=w_data.shape, output=output if i == num_blocks - 1 else None, data=w_data)
+            if non_mergable_pattern:
+                a = mb.add_matmul(x, shape=w_data.shape, data=w_data)
+                b = mb.add_relu(a)
+                x = mb.add_matmul(b, shape=w_data.shape, output=output if i == num_blocks - 1 else None, data=w_data)
+            else:
+                a = mb.add_matmul(x, shape=w_data.shape, data=w_data)
+                b = mb.add_matmul(x, shape=w_data.shape, data=w_data)
+                x = mb.add_mul(a, b)
+                x = mb.add_matmul(x, shape=w_data.shape, output=output if i == num_blocks - 1 else None, data=w_data)
 
         return mb.build()
 
@@ -692,14 +698,32 @@ def get_ignored_scope_name() -> str:
         return "MatMul_4"  # Zero-based indices (e.g., MatMul_0, MatMul_1, ...)
 
     @staticmethod
-    def get_reference_for_test_awq_scale_reference() -> dict[str, Tensor]:
+    @pytest.fixture
+    def test_awq_scale_ref() -> dict[str, Tensor]:
         return {
             "MatMul_3": Tensor(
                 np.array(
                     [[1.2264546, 1.2054994, 1.1413403, 1.0974358, 1.0643553, 1.0379708, 1.0161183, 0.9975262]],
                     dtype=np.float32,
-                ).T
-            )
+                )
+            ),
+            "MatMul_2": Tensor(
+                np.array(
+                    [
+                        [
+                            [1.9909902],
+                            [1.8632966],
+                            [1.5759803],
+                            [1.3974594],
+                            [1.2722752],
+                            [1.1779976],
+                            [1.1035581],
+                            [1.042768],
+                        ]
+                    ],
+                    dtype=np.float32,
+                ),
+            ),
         }
 
     @staticmethod
```

tests/openvino/native/models.py

Lines changed: 20 additions & 12 deletions

```diff
@@ -1003,7 +1003,7 @@ def get_weights(weights_data, is_int8, name):
         )
         return (qw - zp) * scale
 
-    def _create_ov_model(self, n_extra_dims: int = 1, is_int8=False):
+    def _create_ov_model(self, n_extra_dims: int = 1, is_int8=False, non_mergable_pattern: bool = False):
         input_node = opset.parameter([1] * n_extra_dims + [-1, 8], name="Input_1")
 
         weights_data1 = 0.01 * np.arange(0, 64).reshape(8, 8) + 0.05
@@ -1012,27 +1012,35 @@ def _create_ov_model(self, n_extra_dims: int = 1, is_int8=False):
 
         weights_data2 = 0.01 * np.arange(0, 64).reshape(8, 8) + 0.05
         weights2 = self.get_weights(weights_data2, is_int8, name="weights_2")
-        node2 = opset.matmul(input_node, weights2, transpose_a=False, transpose_b=True, name="MatMul_2")
-
-        node_multiply = opset.multiply(node1, node2, name="Multiply")
+        if non_mergable_pattern:
+            relu = opset.relu(node1)
+            node3 = opset.matmul(relu, weights2, transpose_a=False, transpose_b=True, name="MatMul_2")
+        else:
+            node2 = opset.matmul(input_node, weights2, transpose_a=False, transpose_b=True, name="MatMul_2")
+            node_multiply = opset.multiply(node1, node2, name="Multiply")
 
-        weights_data3 = 0.01 * np.arange(0, 64).reshape(8, 8) + 0.05
-        weights3 = self.get_weights(weights_data3, is_int8, name="weights_3")
-        node3 = opset.matmul(node_multiply, weights3, transpose_a=False, transpose_b=True, name="MatMul_3")
+            weights_data3 = 0.01 * np.arange(0, 64).reshape(8, 8) + 0.05
+            weights3 = self.get_weights(weights_data3, is_int8, name="weights_3")
+            node3 = opset.matmul(node_multiply, weights3, transpose_a=False, transpose_b=True, name="MatMul_3")
 
         weights_data4 = 0.01 * np.arange(0, 64).reshape(8, 8) + 0.05
         weights4 = self.get_weights(weights_data4, is_int8, name="weights_4")
         node4 = opset.matmul(node3, weights4, transpose_a=False, transpose_b=True, name="MatMul_4")
 
         weights_data5 = 0.01 * np.arange(0, 64).reshape(8, 8) + 0.05
         weights5 = self.get_weights(weights_data5, is_int8, name="weights_5")
-        node5 = opset.matmul(node3, weights5, transpose_a=False, transpose_b=True, name="MatMul_5")
 
-        node_multiply_2 = opset.multiply(node4, node5, name="Multiply_2")
+        if non_mergable_pattern:
+            relu = opset.relu(node4)
+            node6 = opset.matmul(relu, weights5, transpose_a=False, transpose_b=True, name="MatMul_6")
+        else:
+            node5 = opset.matmul(node3, weights5, transpose_a=False, transpose_b=True, name="MatMul_5")
+
+            node_multiply_2 = opset.multiply(node4, node5, name="Multiply_2")
 
-        weights_data6 = 0.01 * np.arange(0, 64).reshape(8, 8) + 0.05
-        weights6 = self.get_weights(weights_data6, is_int8, name="weights_6")
-        node6 = opset.matmul(node_multiply_2, weights6, transpose_a=False, transpose_b=True, name="MatMul_6")
+            weights_data6 = 0.01 * np.arange(0, 64).reshape(8, 8) + 0.05
+            weights6 = self.get_weights(weights_data6, is_int8, name="weights_6")
+            node6 = opset.matmul(node_multiply_2, weights6, transpose_a=False, transpose_b=True, name="MatMul_6")
 
         result = opset.result(node6, name="Result")
         result.get_output_tensor(0).set_names(set(["Result"]))
```
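In the `non_mergable_pattern` variant a ReLU sits between two MatMuls, so there is no producer Multiply or weight constant into which the AWQ scale can be folded; the algorithm falls back to inserting the scale as an extra node on the activation (the `else` branch in awq.py above). The underlying identity is the usual AWQ trade: multiply the activation by a positive per-channel scale and divide the consumer's weight columns by the same factor. A NumPy sanity check, assuming the (1, 4, 8) activations these tests use:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 4, 8)).astype(np.float32)  # activation
w = rng.standard_normal((8, 8)).astype(np.float32)     # consumer MatMul weight
s = rng.random(8).astype(np.float32) + 0.5             # positive per-channel scale

y_ref = x @ w.T
y_awq = (x * s) @ (w / s).T  # scale activation up, weight columns down
assert np.allclose(y_ref, y_awq, atol=1e-5)

# With a rank-3 activation, the inserted Multiply constant must itself be
# rank 3 (e.g. (1, 1, 8)) -- the expansion this commit adds.
```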

tests/openvino/native/quantization/test_weights_compression.py

Lines changed: 11 additions & 4 deletions

```diff
@@ -2083,8 +2083,8 @@ def get_moe_model_for_test_scale_estimation():
         return SimpleMoEModel().ov_model
 
     @staticmethod
-    def get_awq_model() -> ov.Model:
-        return AWQMatmulModel().ov_model
+    def get_awq_model(non_mergable_pattern: bool) -> ov.Model:
+        return AWQMatmulModel(non_mergable_pattern=non_mergable_pattern).ov_model
 
     @staticmethod
     def get_different_channel_size_model(channel_sizes: list[int]) -> ov.Model:
@@ -2249,12 +2249,19 @@ def get_num_multiply_from_awq(model):
         return awq_num
 
     @staticmethod
-    def get_reference_for_test_awq_scale_reference() -> dict[str, Tensor]:
+    @pytest.fixture
+    def test_awq_scale_ref() -> dict[str, Tensor]:
         return {
             "MatMul_3": Tensor(
                 np.array(
                     [[1.2264546, 1.2054994, 1.1413403, 1.0974358, 1.0643553, 1.0379708, 1.0161183, 0.9975262]],
                     dtype=np.float32,
+                ).T
+            ),
+            "MatMul_2": Tensor(
+                np.array(
+                    [[[1.9909902, 1.8632966, 1.5759803, 1.3974594, 1.2722752, 1.1779976, 1.1035581, 1.042768]]],
+                    dtype=np.float32,
                 )
-            )
+            ),
         }
```

tests/torch2/function_hook/quantization/test_weights_compression.py

Lines changed: 47 additions & 11 deletions

```diff
@@ -174,16 +174,19 @@ def forward(self, x):
 
 
 class AWQLinearModel(nn.Module):
-    def __init__(self, is_int8=False):
+    def __init__(self, non_mergable_pattern: bool = False, is_int8=False):
         super().__init__()
         self.is_int8 = is_int8
+        self.non_mergable_pattern = non_mergable_pattern
 
         self.linear1 = self.get_linear_layer(0.01 * torch.arange(0, 64).reshape(8, 8) + 0.05, is_int8)
         self.linear2 = self.get_linear_layer(0.01 * torch.arange(0, 64).reshape(8, 8) + 0.05, is_int8)
         self.linear3 = self.get_linear_layer(0.01 * torch.arange(0, 64).reshape(8, 8) + 0.05, is_int8)
         self.linear4 = self.get_linear_layer(0.01 * torch.arange(0, 64).reshape(8, 8) + 0.05, is_int8)
-        self.linear5 = self.get_linear_layer(0.01 * torch.arange(0, 64).reshape(8, 8) + 0.05, is_int8)
-        self.linear6 = self.get_linear_layer(0.01 * torch.arange(0, 64).reshape(8, 8) + 0.05, is_int8)
+
+        if not non_mergable_pattern:
+            self.linear5 = self.get_linear_layer(0.01 * torch.arange(0, 64).reshape(8, 8) + 0.05, is_int8)
+            self.linear6 = self.get_linear_layer(0.01 * torch.arange(0, 64).reshape(8, 8) + 0.05, is_int8)
 
     def get_linear_layer(self, weights_data, is_int8):
         if not is_int8:
@@ -200,9 +203,19 @@ def get_linear_layer(self, weights_data, is_int8):
         return linear_layer
 
     def forward(self, x):
-        node1 = self.linear1(x)
-        node2 = self.linear2(x)
-        node_multiply = node1 * node2
+        if self.non_mergable_pattern:
+            node1 = self.linear1(x)
+            y = torch.relu(node1)
+            node_multiply = self.linear2(y)
+        else:
+            node1 = self.linear1(x)
+            node2 = self.linear2(x)
+            node_multiply = node1 * node2
+
+        if self.non_mergable_pattern:
+            node3 = self.linear3(node_multiply)
+            y = torch.relu(node3)
+            return self.linear4(y)
 
         node3 = self.linear3(node_multiply)
         node4 = self.linear4(node3)
@@ -516,8 +529,8 @@ def get_moe_model_for_test_scale_estimation():
         return model
 
     @staticmethod
-    def get_awq_model() -> torch.nn.Module:
-        return AWQLinearModel()
+    def get_awq_model(non_mergable_pattern: bool) -> torch.nn.Module:
+        return AWQLinearModel(non_mergable_pattern=non_mergable_pattern)
 
     @staticmethod
     def get_different_channel_size_model(channel_sizes: list[int]) -> torch.nn.Module:
@@ -674,11 +687,34 @@ def get_num_multiply_from_awq(model):
         return awq_num
 
     @staticmethod
-    def get_reference_for_test_awq_scale_reference() -> dict[str, Tensor]:
+    @pytest.fixture
+    def test_awq_scale_ref() -> dict[str, Tensor]:
         return {
             "linear3/linear/0": Tensor(
-                torch.tensor([[1.226455, 1.205499, 1.141340, 1.097436, 1.064355, 1.037971, 1.016118, 0.997526]])
-            )
+                torch.tensor(
+                    [[1.226455, 1.205499, 1.141340, 1.097436, 1.064355, 1.037971, 1.016118, 0.997526]],
+                    dtype=torch.float32,
+                ).T
+            ),
+            "linear2/linear/0": Tensor(
+                torch.tensor(
+                    [
+                        [
+                            [
+                                1.9909899235,
+                                1.8632963896,
+                                1.5759800673,
+                                1.3974593878,
+                                1.2722752094,
+                                1.1779977083,
+                                1.1035580635,
+                                1.0427680016,
+                            ]
+                        ]
+                    ],
+                    dtype=torch.float32,
+                )
+            ),
         }
```
tests/torch2/fx/test_compress_weights.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -364,8 +364,8 @@ def get_moe_model_for_test_scale_estimation():
364364
return exported_model
365365

366366
@staticmethod
367-
def get_awq_model() -> torch.fx.GraphModule:
368-
model = AWQLinearModel()
367+
def get_awq_model(non_mergable_pattern: bool) -> torch.fx.GraphModule:
368+
model = AWQLinearModel(non_mergable_pattern=non_mergable_pattern)
369369
dynamic_shapes = [[None, torch.export.Dim("dynamic_shape"), None]]
370370
ex_input = torch.ones([1, 4, 8], dtype=torch.float32)
371371
exported_model = get_torch_fx_model(model, ex_input, dynamic_shapes=dynamic_shapes)
@@ -539,9 +539,29 @@ def get_num_multiply_from_awq(model):
539539
return awq_num
540540

541541
@staticmethod
542-
def get_reference_for_test_awq_scale_reference() -> dict[str, Tensor]:
542+
@pytest.fixture
543+
def test_awq_scale_ref() -> dict[str, Tensor]:
543544
return {
544545
"linear_2": Tensor(
545-
torch.tensor([[1.226455, 1.205499, 1.141340, 1.097436, 1.064355, 1.037971, 1.016118, 0.997526]])
546-
)
546+
torch.tensor([[1.226455, 1.205499, 1.141340, 1.097436, 1.064355, 1.037971, 1.016118, 0.997526]]).T
547+
),
548+
"linear_1": Tensor(
549+
torch.tensor(
550+
[
551+
[
552+
[
553+
1.9909899235,
554+
1.8632963896,
555+
1.5759800673,
556+
1.3974593878,
557+
1.2722752094,
558+
1.1779977083,
559+
1.1035580635,
560+
1.0427680016,
561+
]
562+
]
563+
],
564+
dtype=torch.float32,
565+
)
566+
),
547567
}
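Each backend suite overrides the template's `test_awq_scale_ref` fixture with its own expected scales, and pytest injects the backend's version into the shared `test_awq_scale_reference` test. A stripped-down sketch of that override mechanism (class and key names here are illustrative, not the real NNCF test classes):

```python
import numpy as np
import pytest

class TemplateTest:
    @pytest.fixture
    def scale_ref(self):
        raise NotImplementedError("overridden by each backend suite")

    def test_scales(self, scale_ref):
        # The real template compares AWQ's computed scales against this reference.
        for ref in scale_ref.values():
            assert ref.shape == (1, 1, 8)

class TestSomeBackend(TemplateTest):
    @pytest.fixture
    def scale_ref(self):
        return {"matmul_2": np.ones((1, 1, 8), dtype=np.float32)}
```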
