Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/nncf/quantization/algorithms/weight_compression/awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,20 @@ def apply(
merge_weight = self._backend_entity.get_weight(merge_node, port_id, model, graph)
merge_weight = (merge_weight * a_scale).astype(weight_dtype)
self._backend_entity.set_weight(merge_node, port_id, model, graph, merge_weight)
a_scale = fns.transpose(a_scale)
Copy link
Collaborator Author

@daniil-lyakhov daniil-lyakhov Nov 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As this a_scale is always being squeezed before the usage, there is no need to do transpose here

else: # for Act->Multiply->MatMul and Act->MatMul patterns scale inserted after Act as extra node
a_scale = fns.transpose(a_scale).astype(weight_dtype)
next_nodes = graph.get_next_nodes(merge_node)
out_edges = graph.get_output_edges(merge_node)
next_nodes = [edge.to_node for edge in out_edges]
source_node_output_port = graph.get_output_edges(merge_node)[0].output_port_id

# Unsqueeze activation scale to match the size of the activation
# Output edges always have the same shape
tensor_shape = out_edges[0].tensor_shape
a_scale_shape = a_scale.shape
if len(tensor_shape) > len(a_scale_shape):
a_scale_shape = (1,) * (len(tensor_shape) - len(a_scale_shape)) + tuple(a_scale_shape)
a_scale = fns.reshape(a_scale, a_scale_shape)

scale_insertion_command = self._backend_entity.scale_insertion_command(
merge_node, next_nodes, source_node_output_port, a_scale.data
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def test_call_max_var_criterion_with_dataset_by_default_awq_act_matmul(self, int

@staticmethod
@abstractmethod
def get_awq_model() -> TModel:
def get_awq_model(non_mergable_pattern: bool) -> TModel:
"Returns a backend model for test_awq_with_ignored_scope."

@staticmethod
Expand All @@ -388,7 +388,7 @@ def get_ignored_scope_name() -> str:
"Returns ignored scope name for test_awq_with_ignored_scope."

def test_awq_with_ignored_scope(self, mocker):
model = self.get_awq_model()
model = self.get_awq_model(non_mergable_pattern=False)
sz = 8
n_samples = 10

Expand Down Expand Up @@ -455,12 +455,14 @@ def test_sam_pe_weight_compression(self):

@staticmethod
@abstractmethod
def get_reference_for_test_awq_scale_reference() -> dict[str, Tensor]:
@pytest.fixture
def test_awq_scale_ref() -> dict[str, Tensor]:
"Returns reference for test_awq_scale_reference."

def test_awq_scale_reference(self, monkeypatch, mocker):
@pytest.mark.parametrize("non_mergable_pattern", [True, False])
def test_awq_scale_reference(self, monkeypatch, mocker, non_mergable_pattern, test_awq_scale_ref):
monkeypatch.setattr("nncf.quantization.algorithms.weight_compression.algorithm.AWQ", SpyAWQ)
model = self.get_awq_model()
model = self.get_awq_model(non_mergable_pattern)

input = 0.01 * np.arange(0, 4 * 8, dtype=np.float32).reshape(1, 4, 8) + 0.02
input = self.to_tensor(input)
Expand All @@ -477,7 +479,9 @@ def test_awq_scale_reference(self, monkeypatch, mocker):
)
assert spy_instance is not None
for node_name, scales in spy_instance._scale_per_target_node.items():
assert fns.allclose(scales, self.get_reference_for_test_awq_scale_reference()[node_name])
ref = test_awq_scale_ref[node_name]
assert fns.allclose(scales, ref)
assert scales.shape == ref.shape

@pytest.mark.parametrize(
["group_size", "fallback_mode", "min_adjusted_group_size", "expected_outcome"],
Expand Down
40 changes: 32 additions & 8 deletions tests/onnx/quantization/test_weights_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ def get_num_multiply_from_awq(model: onnx.ModelProto) -> int:
return awq_num

@staticmethod
def get_awq_model() -> onnx.ModelProto:
def get_awq_model(non_mergable_pattern: bool) -> onnx.ModelProto:
"""
Builds a model to be used in the following tests:
- TemplateWeightCompression.test_awq_with_ignored_scope()
Expand All @@ -641,11 +641,17 @@ def get_awq_model() -> onnx.ModelProto:
w_data = w_data.T

num_blocks = 2

for i in range(num_blocks):
a = mb.add_matmul(x, shape=w_data.shape, data=w_data)
b = mb.add_matmul(x, shape=w_data.shape, data=w_data)
x = mb.add_mul(a, b)
x = mb.add_matmul(x, shape=w_data.shape, output=output if i == num_blocks - 1 else None, data=w_data)
if non_mergable_pattern:
a = mb.add_matmul(x, shape=w_data.shape, data=w_data)
b = mb.add_relu(a)
x = mb.add_matmul(b, shape=w_data.shape, output=output if i == num_blocks - 1 else None, data=w_data)
else:
a = mb.add_matmul(x, shape=w_data.shape, data=w_data)
b = mb.add_matmul(x, shape=w_data.shape, data=w_data)
x = mb.add_mul(a, b)
x = mb.add_matmul(x, shape=w_data.shape, output=output if i == num_blocks - 1 else None, data=w_data)

return mb.build()

Expand Down Expand Up @@ -692,14 +698,32 @@ def get_ignored_scope_name() -> str:
return "MatMul_4" # Zero-based indices (e.g., MatMul_0, MatMul_1, ...)

@staticmethod
def get_reference_for_test_awq_scale_reference() -> dict[str, Tensor]:
@pytest.fixture
def test_awq_scale_ref() -> dict[str, Tensor]:
return {
"MatMul_3": Tensor(
np.array(
[[1.2264546, 1.2054994, 1.1413403, 1.0974358, 1.0643553, 1.0379708, 1.0161183, 0.9975262]],
dtype=np.float32,
).T
)
)
),
"MatMul_2": Tensor(
np.array(
[
[
[1.9909902],
[1.8632966],
[1.5759803],
[1.3974594],
[1.2722752],
[1.1779976],
[1.1035581],
[1.042768],
]
],
dtype=np.float32,
),
),
}

@staticmethod
Expand Down
32 changes: 20 additions & 12 deletions tests/openvino/native/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,7 +1003,7 @@ def get_weights(weights_data, is_int8, name):
)
return (qw - zp) * scale

def _create_ov_model(self, n_extra_dims: int = 1, is_int8=False):
def _create_ov_model(self, n_extra_dims: int = 1, is_int8=False, non_mergable_pattern: bool = False):
input_node = opset.parameter([1] * n_extra_dims + [-1, 8], name="Input_1")

weights_data1 = 0.01 * np.arange(0, 64).reshape(8, 8) + 0.05
Expand All @@ -1012,27 +1012,35 @@ def _create_ov_model(self, n_extra_dims: int = 1, is_int8=False):

weights_data2 = 0.01 * np.arange(0, 64).reshape(8, 8) + 0.05
weights2 = self.get_weights(weights_data2, is_int8, name="weights_2")
node2 = opset.matmul(input_node, weights2, transpose_a=False, transpose_b=True, name="MatMul_2")

node_multiply = opset.multiply(node1, node2, name="Multiply")
if non_mergable_pattern:
relu = opset.relu(node1)
node3 = opset.matmul(relu, weights2, transpose_a=False, transpose_b=True, name="MatMul_2")
else:
node2 = opset.matmul(input_node, weights2, transpose_a=False, transpose_b=True, name="MatMul_2")
node_multiply = opset.multiply(node1, node2, name="Multiply")

weights_data3 = 0.01 * np.arange(0, 64).reshape(8, 8) + 0.05
weights3 = self.get_weights(weights_data3, is_int8, name="weights_3")
node3 = opset.matmul(node_multiply, weights3, transpose_a=False, transpose_b=True, name="MatMul_3")
weights_data3 = 0.01 * np.arange(0, 64).reshape(8, 8) + 0.05
weights3 = self.get_weights(weights_data3, is_int8, name="weights_3")
node3 = opset.matmul(node_multiply, weights3, transpose_a=False, transpose_b=True, name="MatMul_3")

weights_data4 = 0.01 * np.arange(0, 64).reshape(8, 8) + 0.05
weights4 = self.get_weights(weights_data4, is_int8, name="weights_4")
node4 = opset.matmul(node3, weights4, transpose_a=False, transpose_b=True, name="MatMul_4")

weights_data5 = 0.01 * np.arange(0, 64).reshape(8, 8) + 0.05
weights5 = self.get_weights(weights_data5, is_int8, name="weights_5")
node5 = opset.matmul(node3, weights5, transpose_a=False, transpose_b=True, name="MatMul_5")

node_multiply_2 = opset.multiply(node4, node5, name="Multiply_2")
if non_mergable_pattern:
relu = opset.relu(node4)
node6 = opset.matmul(relu, weights5, transpose_a=False, transpose_b=True, name="MatMul_6")
else:
node5 = opset.matmul(node3, weights5, transpose_a=False, transpose_b=True, name="MatMul_5")

node_multiply_2 = opset.multiply(node4, node5, name="Multiply_2")

weights_data6 = 0.01 * np.arange(0, 64).reshape(8, 8) + 0.05
weights6 = self.get_weights(weights_data6, is_int8, name="weights_6")
node6 = opset.matmul(node_multiply_2, weights6, transpose_a=False, transpose_b=True, name="MatMul_6")
weights_data6 = 0.01 * np.arange(0, 64).reshape(8, 8) + 0.05
weights6 = self.get_weights(weights_data6, is_int8, name="weights_6")
node6 = opset.matmul(node_multiply_2, weights6, transpose_a=False, transpose_b=True, name="MatMul_6")

result = opset.result(node6, name="Result")
result.get_output_tensor(0).set_names(set(["Result"]))
Expand Down
15 changes: 11 additions & 4 deletions tests/openvino/native/quantization/test_weights_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2083,8 +2083,8 @@ def get_moe_model_for_test_scale_estimation():
return SimpleMoEModel().ov_model

@staticmethod
def get_awq_model() -> ov.Model:
return AWQMatmulModel().ov_model
def get_awq_model(non_mergable_pattern: bool) -> ov.Model:
return AWQMatmulModel(non_mergable_pattern=non_mergable_pattern).ov_model

@staticmethod
def get_different_channel_size_model(channel_sizes: list[int]) -> ov.Model:
Expand Down Expand Up @@ -2249,12 +2249,19 @@ def get_num_multiply_from_awq(model):
return awq_num

@staticmethod
def get_reference_for_test_awq_scale_reference() -> dict[str, Tensor]:
@pytest.fixture
def test_awq_scale_ref() -> dict[str, Tensor]:
return {
"MatMul_3": Tensor(
np.array(
[[1.2264546, 1.2054994, 1.1413403, 1.0974358, 1.0643553, 1.0379708, 1.0161183, 0.9975262]],
dtype=np.float32,
).T
),
"MatMul_2": Tensor(
np.array(
[[[1.9909902, 1.8632966, 1.5759803, 1.3974594, 1.2722752, 1.1779976, 1.1035581, 1.042768]]],
dtype=np.float32,
)
)
),
}
Original file line number Diff line number Diff line change
Expand Up @@ -174,16 +174,19 @@ def forward(self, x):


class AWQLinearModel(nn.Module):
def __init__(self, is_int8=False):
def __init__(self, non_mergable_pattern: bool = False, is_int8=False):
super().__init__()
self.is_int8 = is_int8
self.non_mergable_pattern = non_mergable_pattern

self.linear1 = self.get_linear_layer(0.01 * torch.arange(0, 64).reshape(8, 8) + 0.05, is_int8)
self.linear2 = self.get_linear_layer(0.01 * torch.arange(0, 64).reshape(8, 8) + 0.05, is_int8)
self.linear3 = self.get_linear_layer(0.01 * torch.arange(0, 64).reshape(8, 8) + 0.05, is_int8)
self.linear4 = self.get_linear_layer(0.01 * torch.arange(0, 64).reshape(8, 8) + 0.05, is_int8)
self.linear5 = self.get_linear_layer(0.01 * torch.arange(0, 64).reshape(8, 8) + 0.05, is_int8)
self.linear6 = self.get_linear_layer(0.01 * torch.arange(0, 64).reshape(8, 8) + 0.05, is_int8)

if not non_mergable_pattern:
self.linear5 = self.get_linear_layer(0.01 * torch.arange(0, 64).reshape(8, 8) + 0.05, is_int8)
self.linear6 = self.get_linear_layer(0.01 * torch.arange(0, 64).reshape(8, 8) + 0.05, is_int8)

def get_linear_layer(self, weights_data, is_int8):
if not is_int8:
Expand All @@ -200,9 +203,19 @@ def get_linear_layer(self, weights_data, is_int8):
return linear_layer

def forward(self, x):
node1 = self.linear1(x)
node2 = self.linear2(x)
node_multiply = node1 * node2
if self.non_mergable_pattern:
node1 = self.linear1(x)
y = torch.relu(node1)
node_multiply = self.linear2(y)
else:
node1 = self.linear1(x)
node2 = self.linear2(x)
node_multiply = node1 * node2

if self.non_mergable_pattern:
node3 = self.linear3(node_multiply)
y = torch.relu(node3)
return self.linear4(y)

node3 = self.linear3(node_multiply)
node4 = self.linear4(node3)
Expand Down Expand Up @@ -516,8 +529,8 @@ def get_moe_model_for_test_scale_estimation():
return model

@staticmethod
def get_awq_model() -> torch.nn.Module:
return AWQLinearModel()
def get_awq_model(non_mergable_pattern: bool) -> torch.nn.Module:
return AWQLinearModel(non_mergable_pattern=non_mergable_pattern)

@staticmethod
def get_different_channel_size_model(channel_sizes: list[int]) -> torch.nn.Module:
Expand Down Expand Up @@ -674,11 +687,34 @@ def get_num_multiply_from_awq(model):
return awq_num

@staticmethod
def get_reference_for_test_awq_scale_reference() -> dict[str, Tensor]:
@pytest.fixture
def test_awq_scale_ref() -> dict[str, Tensor]:
return {
"linear3/linear/0": Tensor(
torch.tensor([[1.226455, 1.205499, 1.141340, 1.097436, 1.064355, 1.037971, 1.016118, 0.997526]])
)
torch.tensor(
[[1.226455, 1.205499, 1.141340, 1.097436, 1.064355, 1.037971, 1.016118, 0.997526]],
dtype=torch.float32,
).T
),
"linear2/linear/0": Tensor(
torch.tensor(
[
[
[
1.9909899235,
1.8632963896,
1.5759800673,
1.3974593878,
1.2722752094,
1.1779977083,
1.1035580635,
1.0427680016,
]
]
],
dtype=torch.float32,
)
),
}


Expand Down
30 changes: 25 additions & 5 deletions tests/torch2/fx/test_compress_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,8 @@ def get_moe_model_for_test_scale_estimation():
return exported_model

@staticmethod
def get_awq_model() -> torch.fx.GraphModule:
model = AWQLinearModel()
def get_awq_model(non_mergable_pattern: bool) -> torch.fx.GraphModule:
model = AWQLinearModel(non_mergable_pattern=non_mergable_pattern)
dynamic_shapes = [[None, torch.export.Dim("dynamic_shape"), None]]
ex_input = torch.ones([1, 4, 8], dtype=torch.float32)
exported_model = get_torch_fx_model(model, ex_input, dynamic_shapes=dynamic_shapes)
Expand Down Expand Up @@ -539,9 +539,29 @@ def get_num_multiply_from_awq(model):
return awq_num

@staticmethod
def get_reference_for_test_awq_scale_reference() -> dict[str, Tensor]:
@pytest.fixture
def test_awq_scale_ref() -> dict[str, Tensor]:
return {
"linear_2": Tensor(
torch.tensor([[1.226455, 1.205499, 1.141340, 1.097436, 1.064355, 1.037971, 1.016118, 0.997526]])
)
torch.tensor([[1.226455, 1.205499, 1.141340, 1.097436, 1.064355, 1.037971, 1.016118, 0.997526]]).T
),
"linear_1": Tensor(
torch.tensor(
[
[
[
1.9909899235,
1.8632963896,
1.5759800673,
1.3974593878,
1.2722752094,
1.1779977083,
1.1035580635,
1.0427680016,
]
]
],
dtype=torch.float32,
)
),
}