Skip to content

Commit 0c3dcd1

Browse files
[ONNX] Implementation of the SmoothQuant algorithm for the ONNX Backend (#3644)
### Changes - Implementation of the SmoothQuant algorithm for the ONNX Backend ### Related tickets Ref: 164847 ### Tests - tests/onnx/test_smooth_quant.py
1 parent 9ed7c8b commit 0c3dcd1

File tree

10 files changed

+457
-18
lines changed

10 files changed

+457
-18
lines changed

src/nncf/onnx/graph/model_transformer.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,19 @@
1616
import onnx
1717

1818
import nncf
19+
from nncf.common.factory import NNCFGraphFactory
1920
from nncf.common.graph.model_transformer import ModelTransformer
2021
from nncf.common.graph.transformations.commands import TargetType
2122
from nncf.common.graph.transformations.layout import TransformationLayout
2223
from nncf.onnx.graph.node_utils import get_input_edge
24+
from nncf.onnx.graph.node_utils import get_input_edges_mapping
2325
from nncf.onnx.graph.onnx_helper import get_children
2426
from nncf.onnx.graph.onnx_helper import get_children_node_mapping
2527
from nncf.onnx.graph.onnx_helper import get_edge_dtype
2628
from nncf.onnx.graph.onnx_helper import get_edge_info_mapping
2729
from nncf.onnx.graph.onnx_helper import get_name_to_node_map
2830
from nncf.onnx.graph.onnx_helper import get_node_index
31+
from nncf.onnx.graph.onnx_helper import get_parents_node_mapping
2932
from nncf.onnx.graph.onnx_helper import get_tensor
3033
from nncf.onnx.graph.transformations.commands import ONNXInitializerUpdateCommand
3134
from nncf.onnx.graph.transformations.commands import ONNXModelExtractionCommand
@@ -55,8 +58,8 @@ def __init__(self, model: onnx.ModelProto, inplace: bool = False):
5558
self.onnx_model_extractor = onnx.utils.Extractor(inferred_model)
5659
self._inplace = inplace
5760

61+
@staticmethod
5862
def _get_target_edge(
59-
self,
6063
port_id: int,
6164
node_name: str,
6265
transform_type: TargetType,
@@ -381,10 +384,25 @@ def _apply_initializer_update_transformations(
381384
:return: Copy of original model with updated biases.
382385
"""
383386
name_to_node_map = get_name_to_node_map(model)
387+
output_name_to_node_map = get_parents_node_mapping(model)
388+
384389
for transformation in transformations:
385390
node = name_to_node_map[transformation.target_point.target_node_name]
386-
initializer_name = node.input[transformation.target_point.port_id]
387-
set_initializer(initializer_name, model, transformation.new_value)
391+
# NOTE: An `input_name` is either the name of an initializer or the name of a `Constant` operation output
392+
input_name = node.input[transformation.target_point.port_id]
393+
394+
constant_node = output_name_to_node_map.get(input_name, None)
395+
396+
if constant_node is None:
397+
set_initializer(input_name, model, transformation.new_value)
398+
else:
399+
for attr in constant_node.attribute:
400+
if attr.name == "value":
401+
array = transformation.new_value.astype(onnx.helper.tensor_dtype_to_np_dtype(attr.t.data_type))
402+
tensor_proto = onnx.numpy_helper.from_array(array)
403+
attr.t.CopyFrom(tensor_proto)
404+
break
405+
388406
return model
389407

390408
def _apply_model_extraction_transformation(self, transformation: ONNXModelExtractionCommand) -> onnx.ModelProto:
@@ -484,16 +502,25 @@ def _apply_multiply_insertion_transformations(
484502
:returns: Transformed model with Multiply nodes.
485503
"""
486504
node_name_to_node = get_name_to_node_map(model)
505+
# TODO(andrey-churkin): Optimize it
506+
graph = NNCFGraphFactory.create(model)
507+
input_edges_mapping = get_input_edges_mapping(graph)
487508

488509
for transformation in transformations:
510+
port_id = transformation.target_point.port_id
489511
target_node_name = transformation.target_point.target_node_name
490-
target_output_port = transformation.target_point.port_id
491-
target_node = node_name_to_node[target_node_name]
492-
output_tensor_name = target_node.output[target_output_port]
512+
transform_type = transformation.target_point.type
513+
output_tensor_name = ONNXModelTransformer._get_target_edge(
514+
port_id, target_node_name, transform_type, node_name_to_node, input_edges_mapping
515+
)
516+
517+
# TODO(andrey-churkin): Check type of `transformation.scale_value`
493518

494519
# Create a new initializer for the scale constant
495520
scale_tensor_name = f"{transformation.multiply_node_name}_scale"
496-
scale_tensor = onnx.numpy_helper.from_array(transformation.scale_value, name=scale_tensor_name)
521+
scale_tensor = onnx.numpy_helper.from_array(
522+
transformation.scale_value.astype(np.float32), name=scale_tensor_name
523+
)
497524
model.graph.initializer.append(scale_tensor)
498525

499526
# Create a new Multiply node
@@ -505,7 +532,8 @@ def _apply_multiply_insertion_transformations(
505532
name=transformation.multiply_node_name,
506533
)
507534
target_index = get_node_index(model, target_node_name)
508-
model.graph.node.insert(target_index + 1, mul_node)
535+
insert_index = 0 if target_index is None else target_index + 1
536+
model.graph.node.insert(insert_index, mul_node)
509537

510538
for name in transformation.destination_node_names:
511539
node = node_name_to_node[name]
@@ -524,6 +552,11 @@ def set_initializer(initializer_name: str, model: onnx.ModelProto, new_value: np
524552
:param new_value: New value for the initializer tensor.
525553
"""
526554
initializer = get_tensor(model, initializer_name)
555+
556+
required_dtype = onnx.helper.tensor_dtype_to_np_dtype(initializer.data_type)
557+
if new_value.dtype != required_dtype:
558+
new_value = new_value.astype(required_dtype)
559+
527560
new_tensor = onnx.numpy_helper.from_array(new_value, initializer_name)
528561
initializer.CopyFrom(new_tensor)
529562

src/nncf/quantization/algorithms/smooth_quant/algorithm.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __init__(
6969

7070
@property
7171
def available_backends(self) -> list[BackendType]:
72-
return [BackendType.OPENVINO, BackendType.TORCH, BackendType.TORCH_FX]
72+
return [BackendType.OPENVINO, BackendType.TORCH, BackendType.TORCH_FX, BackendType.ONNX]
7373

7474
def _set_backend_entity(self, model: TModel) -> None:
7575
"""
@@ -90,6 +90,10 @@ def _set_backend_entity(self, model: TModel) -> None:
9090
from nncf.quantization.algorithms.smooth_quant.torch_fx_backend import FXSmoothQuantAlgoBackend
9191

9292
self._backend_entity = FXSmoothQuantAlgoBackend()
93+
elif model_backend == BackendType.ONNX:
94+
from nncf.quantization.algorithms.smooth_quant.onnx_backend import ONNXSmoothQuantAlgoBackend
95+
96+
self._backend_entity = ONNXSmoothQuantAlgoBackend()
9397
else:
9498
msg = f"Cannot return backend-specific entity because {model_backend.value} is not supported!"
9599
raise nncf.UnsupportedBackendError(msg)

src/nncf/quantization/algorithms/smooth_quant/backend.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,12 @@ def get_abs_max_channel_collector(
113113

114114
@staticmethod
115115
@abstractmethod
116-
def get_weight_value(node_with_weight: NNCFNode, model: TModel, port_id: int, nncf_graph: NNCFGraph) -> Tensor:
116+
def get_weight_value(node_with_weight: NNCFNode, model: TModel, nncf_graph: NNCFGraph) -> Tensor:
117117
"""
118118
Returns the weight value for the node with weight.
119119
120120
:param node_with_weight: The node with weight.
121121
:param model: The model that contains this operation.
122-
:param port_id: The input port ID to get weight input.
123122
:param nncf_graph: NNCFGraph instance.
124123
:return: The weight value.
125124
"""
@@ -141,7 +140,11 @@ def weight_update_command(
141140
@staticmethod
142141
@abstractmethod
143142
def scale_insertion_command(
144-
source_node: NNCFNode, scale_value: TTensor, source_output_port_id: int, nodes: list[NNCFNode]
143+
source_node: NNCFNode,
144+
scale_value: TTensor,
145+
source_output_port_id: int,
146+
nodes: list[NNCFNode],
147+
scale_node_name: str,
145148
) -> TransformationCommand:
146149
"""
147150
Returns command to insert Smooth Quant node.
@@ -150,6 +153,7 @@ def scale_insertion_command(
150153
:param scale_value: Smooth Quant value.
151154
:param source_output_port_id: Output port for source node.
152155
:param nodes: List of consumers for Smooth node.
156+
:param scale_node_name: Scale node name.
153157
:return: TransformationCommand instance.
154158
"""
155159

0 commit comments

Comments
 (0)