Commit 0cce7af

Sequential tensor reducer
TensorReducerSequence
Reducer adapter inside reducer
TensorCollectorAdapter
1 parent bc09b0d commit 0cce7af

File tree

11 files changed (+465, -53 lines)

examples/post_training_quantization/onnx/mobilenet_v2/main.py

Lines changed: 1 addition & 0 deletions
@@ -104,6 +104,7 @@ def run_benchmark(path_to_model: str, shape: Optional[List[int]] = None, verbose
     # >> output_names = [output.name for output in sess.get_outputs()]
     # >> for data_item in val_loader:
     # >> sess.run(output_names, input_feed=transform_fn(data_item))
+
     input_name = model.graph.input[0].name

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
{
    "compression": {
        "algorithms": [
            {
                "name": "DefaultQuantization",
                "params": {
                    "preset": "performance",
                    "stat_subset_size": 3
                }
            }
        ],
        "dump_intermediate_model": true
    },
    "engine": {
        "datasets": [
            {
                "metrics": [
                    {
                        "type": "wer"
                    }
                ],
                "name": "LibriSpeech_test_clean_wav",
                "data_source": "/mnt/omz_new/nn_icv_cv_externalN/omz-validation-datasets/librispeech/test/LibriSpeech/test-clean.wav",

                "annotation_conversion": {
                    "converter": "librispeech",
                    "data_dir": "/mnt/omz_new/nn_icv_cv_externalN/omz-validation-datasets/librispeech/test/LibriSpeech/test-clean.wav"
                },
                "preprocessing": [
                    {
                        "int16mode": true,
                        "type": "audio_normalization"
                    },
                    {
                        "duration": "512 samples",
                        "overlap": "192 samples",
                        "type": "clip_audio"
                    },
                    {
                        "base": 512,
                        "type": "hanning_window"
                    },
                    {
                        "fftbase": 512,
                        "magnitude_squared": true,
                        "skip_channels": true,
                        "type": "audio_spectrogram"
                    },
                    {
                        "base": 257,
                        "filterbank_channel_count": 40,
                        "lower_frequency_limit": 20,
                        "sample_rate": 16000,
                        "type": "audio_triangle_filtering",
                        "upper_frequency_limit": 4000
                    },
                    {
                        "filterbank_channel_count": 40,
                        "numceps": 26,
                        "type": "audio_dct"
                    },
                    {
                        "context": 9,
                        "numceps": 26,
                        "type": "clip_cepstrum"
                    },
                    {
                        "step": 16,
                        "type": "pack_cepstrum"
                    }
                ],
                "reader": "wav_reader"
            }
        ],
        "launchers": [
            {
                "adapter": {
                    "beam_size": 32,
                    "lm_alpha": 0.75,
                    "lm_beta": 1.05,
                    "lm_file": "/mnt/omz_new/nn_icv_cv_externalN/omz-validation-datasets/model_attributes/mozilla-deepspeech-0.6.1/lm.binary",
                    "lm_oov_score": -1000,
                    "lm_vocabulary_length": 4463723,
                    "lm_vocabulary_offset": 941235601,
                    "logarithmic_prob": false,
                    "probability_out": "logits",
                    "type": "ctc_beam_search_decoder_with_lm"
                },
                "framework": "dlsdk",
                "inputs": [
                    {
                        "layout": "NHWC",
                        "name": "input_node",
                        "type": "INPUT"
                    },
                    {
                        "name": "previous_state_c",
                        "type": "LSTM_INPUT",
                        "value": "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/BlockLSTM/TensorIterator.2"
                    },
                    {
                        "name": "previous_state_h",
                        "type": "LSTM_INPUT",
                        "value": "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/BlockLSTM/TensorIterator.1"
                    }
                ]
            },
            {
                "adapter": {
                    "beam_size": 32,
                    "lm_alpha": 0.75,
                    "lm_beta": 1.05,
                    "lm_file": "/mnt/omz_new/nn_icv_cv_externalN/omz-validation-datasets/model_attributes/mozilla-deepspeech-0.6.1/lm.binary",
                    "lm_oov_score": -1000,
                    "lm_vocabulary_length": 4463723,
                    "lm_vocabulary_offset": 941235601,
                    "logarithmic_prob": false,
                    "probability_out": "logits",
                    "type": "ctc_beam_search_decoder_with_lm"
                },
                "framework": "openvino",
                "inputs": [
                    {
                        "layout": "NHWC",
                        "name": "input_node",
                        "type": "INPUT"
                    },
                    {
                        "name": "previous_state_c",
                        "type": "LSTM_INPUT",
                        "value": "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/GatherNd:0"
                    },
                    {
                        "name": "previous_state_h",
                        "type": "LSTM_INPUT",
                        "value": "cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/GatherNd_1:0"
                    }
                ]
            }
        ]
    },
    "model": {
        "model": "/mnt/omz/cv_bench_cache/ww18_weekly_23.0.0-10862-40bf400b189-API2.0/mozilla-deepspeech-0.6.1/tf/tf_frozen/FP16/1/dldt/mozilla-deepspeech-0.6.1.xml",
        "model_name": "mozilla-deepspeech-0.6.1",
        "weights": "/mnt/omz/cv_bench_cache/ww18_weekly_23.0.0-10862-40bf400b189-API2.0/mozilla-deepspeech-0.6.1/tf/tf_frozen/FP16/1/dldt/mozilla-deepspeech-0.6.1.bin"
    }
}
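
This configuration drives the OpenVINO Accuracy Checker: mozilla-deepspeech-0.6.1 is scored by word error rate on LibriSpeech test-clean, the preprocessing chain turns audio into packed cepstral features, and DefaultQuantization collects statistics over a stat_subset_size of 3. As a quick orientation, here is a small, self-contained sketch that loads the file and prints the pieces the example script below relies on; it assumes the JSON is stored as accuracy_checker.json next to that script, which is the name the script uses.

import json
import os

# Hypothetical sanity check over the config above (file name assumed).
config_path = os.path.join(os.path.dirname(__file__), "accuracy_checker.json")
with open(config_path) as f:
    cfg = json.load(f)

dataset_cfg = cfg["engine"]["datasets"][0]
print(dataset_cfg["name"])                                                  # LibriSpeech_test_clean_wav
print([step["type"] for step in dataset_cfg["preprocessing"]])              # audio_normalization ... pack_cepstrum
print([launcher["framework"] for launcher in cfg["engine"]["launchers"]])   # ['dlsdk', 'openvino']
print(cfg["compression"]["algorithms"][0]["params"]["stat_subset_size"])    # 3
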
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
import json
import os
import subprocess

import numpy as np
import openvino.runtime as ov
from openvino.tools.accuracy_checker.evaluators.quantization_model_evaluator import create_model_evaluator
from openvino.tools.pot.configs.config import Config

import nncf

model_name = "mozilla-deepspeech-0.6.1"
cache_dir = os.path.dirname(__file__)
dataset_config = os.path.join(cache_dir, "accuracy_checker.json")

command = f"omz_downloader --name {model_name} --cache_dir {cache_dir}"
cmd_output = subprocess.call(command, shell=True)  # nosec

model_dir = os.path.join(cache_dir, model_name)
if not os.path.exists(model_dir):
    command = f"omz_converter --name {model_name} -o {os.path.join(cache_dir, model_name)}"
    cmd_output = subprocess.call(command, shell=True)  # nosec

xml_path = os.path.join(model_dir, f"public/{model_name}/FP16/{model_name}.xml")
ov_model = ov.Core().read_model(xml_path)

config = Config.read_config(dataset_config)
config.configure_params()
accuracy_checker_config = config.engine

model_evaluator = create_model_evaluator(accuracy_checker_config)
model_evaluator.load_network([{"model": ov_model}])
model_evaluator.select_dataset("")


def get_tokens_from_sequence_func(data_item):
    _, batch_annotation, batch_input, _ = data_item
    filled_inputs, _, _ = model_evaluator._get_batch_input(batch_input, batch_annotation)
    for filled_input in filled_inputs:
        input_data = {}
        for name, value in filled_input.items():
            input_data[model_evaluator.launcher.input_to_tensor_name[name]] = value
        yield input_data


def fill_sequential_inputs_fn(model_inputs, model_outputs):
    # Combine model inputs with state model outputs
    # or fill state model outputs if model_outputs is None
    state_inputs = model_evaluator.launcher._fill_lstm_inputs(model_outputs)
    model_inputs.update(state_inputs)
    return model_inputs


dataset = nncf.RecurentDataset(model_evaluator.dataset, get_tokens_from_sequence_func, fill_sequential_inputs_fn)
quantized_model = nncf.quantize(ov_model, dataset, subset_size=3)

nncf/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 from nncf.common.logging.logger import set_log_level
 from nncf.config import NNCFConfig
 from nncf.data import Dataset
+from nncf.data import RecurentDataset
 from nncf.parameters import DropType
 from nncf.parameters import ModelType
 from nncf.parameters import TargetDevice

nncf/common/tensor_statistics/aggregator.py

Lines changed: 41 additions & 2 deletions
@@ -10,6 +10,7 @@
 # limitations under the License.
 from abc import ABC
 from abc import abstractmethod
+from collections import defaultdict
 from itertools import islice
 from typing import Any, Dict, TypeVar

@@ -21,6 +22,7 @@
 from nncf.common.tensor import NNCFTensor
 from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
 from nncf.data.dataset import Dataset
+from nncf.data.dataset import RecurentDataset

 TensorType = TypeVar("TensorType")
 TModel = TypeVar("TModel")

@@ -31,10 +33,13 @@ class StatisticsAggregator(ABC):
     Base class for statistics collection.
     """

+    STACK_AXIS = 0
+
     def __init__(self, dataset: Dataset):
         self.dataset = dataset
         self.stat_subset_size = None
         self.statistic_points = StatisticPointsContainer()
+        self._is_sequential = isinstance(dataset, RecurentDataset)

     def collect_statistics(self, model: TModel) -> None:
         """

@@ -46,19 +51,44 @@ def collect_statistics(self, model: TModel) -> None:
         model_transformer = ModelTransformerFactory.create(model)

         merged_statistics = self._get_merged_statistic_points(self.statistic_points, model)
+        if self._is_sequential:
+            merged_statistics = self._adapt_collectors(merged_statistics, self.STACK_AXIS)
+
         transformation_layout = self._get_transformation_layout_extra_outputs(merged_statistics)
         model_with_outputs = model_transformer.transform(transformation_layout)
         engine = EngineFactory.create(model_with_outputs)
+        infer_fn = self._infer_sequential if self._is_sequential else self._infer

         for input_data in tqdm(
             islice(self.dataset.get_inference_data(), self.stat_subset_size),
             total=self.stat_subset_size,
             desc="Statistics collection",
         ):
-            outputs = engine.infer(input_data)
-            processed_outputs = self._process_outputs(outputs)
+            processed_outputs = infer_fn(engine, input_data)
             self._register_statistics(processed_outputs, merged_statistics)

+    def _infer(self, engine, input_data):
+        outputs = engine.infer(input_data)
+        return self._process_outputs(outputs)
+
+    def _infer_sequential(self, engine, sequence):
+        model_output = None
+        model_outputs = defaultdict(list)
+        for token in sequence.get_tokens_iter():
+            filled_inputs = sequence.fill_inputs(token, model_output)
+            model_output = engine.infer(filled_inputs)
+            processed_output = self._process_outputs(model_output)
+            for output_name, output_value in processed_output.items():
+                model_outputs[output_name].append(output_value)
+
+        # Stack model outputs and return them
+        stacked_outputs = {}
+        tensor_processor = self._get_tensor_processor()
+        for output_name, output_values in model_outputs.items():
+            stacked_outputs[output_name] = tensor_processor.stack(output_values, axis=self.STACK_AXIS)
+
+        return stacked_outputs
+
     def register_statistic_points(self, statistic_points: StatisticPointsContainer) -> None:
         """
         Register statistic points for statistics collection and recalculates the maximum number samples

@@ -115,6 +145,10 @@ def _get_merged_statistic_points(
         :return: Merged statistic points container bounded with given statistic point container.
         """

+    @staticmethod
+    def _adapt_collectors(statistic_points: StatisticPointsContainer, stack_axis: int):
+        return statistic_points
+
     @staticmethod
     @abstractmethod
     def _process_outputs(outputs: Any) -> Dict[str, NNCFTensor]:

@@ -124,3 +158,8 @@ def _process_outputs(outputs: Any) -> Dict[str, NNCFTensor]:
         :param outputs: raw model outputs
         :return: processed model outputs in Dict[str, NNCFTensor] format
         """
+
+    @staticmethod
+    @abstractmethod
+    def _get_tensor_processor():
+        pass
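
The new abstract _get_tensor_processor hook supplies the stack operation that _infer_sequential uses to merge per-token outputs before statistics registration. Below is a minimal sketch of what a concrete backend could plug in, using NumPy stand-ins; NumpyNNCFTensor, NumpyTensorProcessor and NumpyStatisticsAggregator are illustrative names for this sketch, not classes added by this commit.

from typing import Dict, List

import numpy as np


class NumpyNNCFTensor:
    # Illustrative stand-in for a backend-specific NNCFTensor wrapper.
    def __init__(self, tensor: np.ndarray):
        self.tensor = tensor


class NumpyTensorProcessor:
    @staticmethod
    def stack(tensors: List[NumpyNNCFTensor], axis: int = 0) -> NumpyNNCFTensor:
        # Stack per-token outputs along a new axis, as _infer_sequential expects.
        return NumpyNNCFTensor(np.stack([t.tensor for t in tensors], axis=axis))


class NumpyStatisticsAggregator:  # in practice this would derive from StatisticsAggregator
    @staticmethod
    def _process_outputs(outputs: Dict[str, np.ndarray]) -> Dict[str, NumpyNNCFTensor]:
        # Wrap raw backend outputs into NNCFTensor-like objects.
        return {name: NumpyNNCFTensor(value) for name, value in outputs.items()}

    @staticmethod
    def _get_tensor_processor():
        return NumpyTensorProcessor()
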

nncf/data/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -10,3 +10,5 @@
 # limitations under the License.

 from nncf.data.dataset import Dataset
+from nncf.data.dataset import RecurentDataset
+from nncf.data.dataset import Sequence

nncf/data/dataset.py

Lines changed: 28 additions & 1 deletion
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Callable, Generic, Iterable, List, Optional, TypeVar
+from typing import Any, Callable, Generic, Iterable, List, Optional, TypeVar

 from nncf.common.utils.api_marker import api


@@ -115,3 +115,30 @@ def _get_iterator_for_iter(
         if idx == indices[pos]:
             pos = pos + 1
             yield transform_func(data_item)
+
+
+@api(canonical_alias="nncf.RecurentDataset")
+class RecurentDataset(Dataset):
+    def __init__(self, data_source: Iterable, get_token_from_sequence_func, fill_sequential_inputs_fn):
+        def transform_fn_wrapper(data_item):
+            return Sequence(data_item, get_token_from_sequence_func, fill_sequential_inputs_fn)
+
+        super().__init__(data_source, transform_fn_wrapper)
+
+
+class Sequence:
+    def __init__(
+        self,
+        raw_sequence,
+        get_tokens_from_sequence_func: Callable[[DataItem], ModelInput],
+        fill_sequential_inputs_fn: Callable[[DataItem], ModelInput],
+    ):
+        self._raw_sequence = raw_sequence
+        self._get_tokens_from_sequence_func = get_tokens_from_sequence_func
+        self._fill_sequential_inputs_fn = fill_sequential_inputs_fn
+
+    def get_tokens_iter(self):
+        return self._get_tokens_from_sequence_func(self._raw_sequence)
+
+    def fill_inputs(self, token, model_outputs):
+        return self._fill_sequential_inputs_fn(token, model_outputs)
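
For a feel of how the wrapper behaves, here is a toy, self-contained usage sketch. It assumes an NNCF build that includes this change; the small state dict at the bottom is a stand-in for real engine outputs, and the input names are arbitrary.

import numpy as np

from nncf.data.dataset import RecurentDataset


def get_tokens_from_sequence(raw_sequence):
    # Yield one model-input dict per token of the raw sequence.
    for token in raw_sequence:
        yield {"input_node": np.asarray(token)}


def fill_sequential_inputs(model_inputs, model_outputs):
    # model_outputs is None for the first token of every sequence.
    if model_outputs is not None:
        model_inputs["previous_state_h"] = model_outputs["state_h"]
    return model_inputs


raw_sequences = [[1.0, 2.0], [3.0, 4.0, 5.0]]  # two toy sequences of tokens
dataset = RecurentDataset(raw_sequences, get_tokens_from_sequence, fill_sequential_inputs)

# get_inference_data() now yields Sequence wrappers instead of plain model inputs.
for sequence in dataset.get_inference_data():
    model_output = None
    for token_inputs in sequence.get_tokens_iter():
        filled = sequence.fill_inputs(token_inputs, model_output)
        model_output = {"state_h": filled["input_node"]}  # stand-in for engine.infer(filled)
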

0 commit comments