Skip to content

Commit a4c3531

Browse files
Check all models with concat in tf / make changes in export_helpers.py to make test clearer
1 parent 03b14ba commit a4c3531

File tree

2 files changed

+45
-3
lines changed

2 files changed

+45
-3
lines changed

nncf/common/pruning/export_helpers.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,14 +144,31 @@ def check_concat(cls, node: NNCFNode, graph: NNCFGraph) -> bool:
144144

145145
for input_node in graph.get_previous_nodes(node):
146146
# If input has mask -> it went from convolution (source of this node is a convolution)
147+
node_has_mask = False
147148
if input_node.data.get('output_mask', None) is not None:
148-
continue
149+
node_has_mask = True
149150

150151
source_nodes = get_sources_of_node(input_node, graph, cls.ConvolutionOp.get_all_op_aliases() +
151152
cls.StopMaskForwardOp.get_all_op_aliases() +
152153
cls.InputOp.get_all_op_aliases())
153-
sources_types = [node.node_type for node in source_nodes] + [input_node.node_type]
154-
if any(t in sources_types for t in cls.StopMaskForwardOp.get_all_op_aliases()):
154+
155+
source_types_old = [node.node_type for node in source_nodes]
156+
sources_types_new = source_types_old + [input_node.node_type]
157+
158+
decision_old_on_sources = any(t in source_types_old for t in cls.StopMaskForwardOp.get_all_op_aliases())
159+
decision_old = decision_old_on_sources and node_has_mask
160+
161+
decision_new_on_sources = any(t in sources_types_new for t in cls.StopMaskForwardOp.get_all_op_aliases())
162+
decision_new = decision_new_on_sources and not node_has_mask
163+
164+
if decision_new != decision_old:
165+
is_on_sources_equal = decision_new_on_sources == decision_old_on_sources
166+
if not is_on_sources_equal:
167+
raise ValueError('ALERT')
168+
169+
print(f'is_on_sources_equal = {is_on_sources_equal}')
170+
print('behaviour changed!!!')
171+
if decision_new:
155172
return False
156173
return True
157174

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import pytest
2+
3+
from tests.tensorflow.helpers import create_compressed_model_and_algo_for_test
4+
from tests.tensorflow.pruning.helpers import get_basic_pruning_config
5+
from tests.tensorflow import test_models
6+
7+
8+
MODELS = [
9+
{'model': test_models.InceptionV3,
10+
'input_shape': (75, 75, 3)},
11+
{'model': test_models.InceptionResNetV2,
12+
'input_shape': (75, 75, 3)},
13+
{'model': test_models.NASNetMobile,
14+
'input_shape': (32, 32, 3)},
15+
{'model': test_models.DenseNet121,
16+
'input_shape': (32, 32, 3)},
17+
]
18+
19+
20+
@pytest.mark.parametrize('model,input_shape', [list(elem.values()) for elem in MODELS])
21+
def test_concat(model, input_shape):
22+
config = get_basic_pruning_config(input_shape[1])
23+
model = model(list(input_shape))
24+
25+
model, _ = create_compressed_model_and_algo_for_test(model, config)

0 commit comments

Comments
 (0)