Skip to content

Commit fbec013

Browse files
add the raw version
1 parent 0c3dcd1 commit fbec013

File tree

5 files changed

+35
-5
lines changed

5 files changed

+35
-5
lines changed

src/nncf/onnx/graph/metatypes/groups.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@
139139
OPERATIONS_WITH_BIAS_REDUCED = [
140140
onnx_metatypes.ONNXConvolutionMetatype,
141141
onnx_metatypes.ONNXGemmMetatype,
142-
# TODO: Need to add MatMul with the separate bias support (CVS-135433)
142+
onnx_metatypes.ONNXMatMulMetatype,
143143
]
144144

145145
OPERATIONS_WITH_BIAS = [

src/nncf/onnx/graph/metatypes/onnx_metatypes.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,6 @@ class ONNXMatMulMetatype(ONNXOpWithWeightsMetatype):
143143
op_names = ["MatMul"]
144144
hw_config_names = [HWConfigOpName.MATMUL]
145145
weight_channel_axis = -1 # For port_id=1
146-
bias_port_id = 2
147146
possible_weight_ports = [0, 1]
148147
output_channel_axis = -1
149148

src/nncf/onnx/graph/nncf_graph_builder.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from nncf.onnx.graph.metatypes.groups import CONSTANT_WEIGHT_LAYER_METATYPES
2626
from nncf.onnx.graph.metatypes.groups import OPERATIONS_WITH_BIAS
2727
from nncf.onnx.graph.metatypes.groups import POSSIBLE_WEIGHT_LAYER_METATYPES
28+
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXMatMulMetatype
2829
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXGemmMetatype
2930
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXOpMetatype
3031
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXOpWithWeightsMetatype
@@ -186,6 +187,7 @@ def _get_bias_attr(
186187
node: onnx.NodeProto,
187188
model: onnx.ModelProto,
188189
parents_node_mapping: dict[str, onnx.NodeProto],
190+
children_node_mapping
189191
) -> dict[str, str]:
190192
"""
191193
Returns bias tensor attributes.
@@ -197,6 +199,25 @@ def _get_bias_attr(
197199
"""
198200
bias_attrs = {}
199201
metatype = get_metatype(model, node)
202+
203+
if metatype == ONNXMatMulMetatype:
204+
weight_port_ids = _get_weight_port_ids(node, model, parents_node_mapping)
205+
if weight_port_ids:
206+
y = node.output[0] # only 1 output
207+
consumers = children_node_mapping[y]
208+
if len(consumers) == 1 and consumers[0].op_type == "Add":
209+
add_node = consumers[0]
210+
for port_id, input_name in enumerate(add_node.input):
211+
if input_name != y:
212+
bias_attrs["name"] = input_name
213+
bias_attrs["port_id"] = port_id
214+
215+
bias_attrs["add_node"] = add_node.name
216+
217+
# `name` should be either the output of a constant or an `initializer`
218+
219+
return bias_attrs
220+
200221
if _is_node_with_bias(node, model):
201222
bias_tensor_port_id = get_bias_tensor_port_id(metatype)
202223
bias_edge_name = get_tensor_edge_name(model, node, bias_tensor_port_id, parents_node_mapping)
@@ -358,7 +379,8 @@ def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph:
358379
is_shared = None
359380
weight_attrs = {}
360381
node_attrs = _get_node_attrs(node, onnx_model)
361-
bias_attrs = _get_bias_attr(node, onnx_model, parents_node_mapping)
382+
bias_attrs = _get_bias_attr(node, onnx_model, parents_node_mapping, children_node_mapping)
383+
362384
if weight_port_ids: # If node has weight
363385
weight_edge_names = []
364386
for weight_port_id in weight_port_ids:

src/nncf/onnx/graph/transformations/command_creation.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,13 @@ def create_bias_correction_command(node: NNCFNode, bias_value: np.ndarray) -> ON
2929
:param bias_value: The new bias value that will be set.
3030
:return: The `ONNXInitializerUpdateCommand` command to update bias.
3131
"""
32-
bias_port_id = node.metatype.bias_port_id
33-
target_point = ONNXTargetPoint(TargetType.LAYER, node.node_name, bias_port_id)
32+
node_name = node.layer_attributes.bias_attrs.get("add_node")
33+
if node_name:
34+
port_id = node.layer_attributes.bias_attrs["port_id"]
35+
target_point = ONNXTargetPoint(TargetType.LAYER, node_name, port_id)
36+
else:
37+
bias_port_id = node.metatype.bias_port_id
38+
target_point = ONNXTargetPoint(TargetType.LAYER, node.node_name, bias_port_id)
3439
return ONNXInitializerUpdateCommand(target_point, bias_value)
3540

3641

src/nncf/quantization/algorithms/fast_bias_correction/onnx_backend.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,10 @@ def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
104104

105105
@staticmethod
106106
def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFGraph) -> tuple[str, str]:
107+
output = node.layer_attributes.bias_attrs.get("add_node", None)
108+
if output:
109+
return node.node_name, output
110+
107111
return node.node_name, node.node_name
108112

109113
@staticmethod

0 commit comments

Comments
 (0)