Skip to content

Commit 18d70ff

Browse files
authored
Graph should only have one (input) kParam node (onnx#1088)
that has all graph level outputs as its outputs
1 parent f28e2f1 commit 18d70ff

File tree

2 files changed

+21
-16
lines changed

2 files changed

+21
-16
lines changed

onnx/optimizer/passes/extract_constant_to_initializer.h

+10-10
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,18 @@ struct ExtractConstantToInitializer final : public OptimizePass {
2828
if (n->kind() == kConstant) {
2929
const auto name = n->output()->uniqueName();
3030
Tensor t = n->t(kvalue);
31+
32+
// add a new graph input
33+
Value* input = graph.addInput();
34+
input->setUniqueName(name);
35+
input->setSizes({t.sizes().begin(), t.sizes().end()});
36+
input->setElemType(t.elem_type());
37+
n->output()->replaceAllUsesWith(input);
38+
39+
// copy the tensor to initializer
3140
t.setName(name);
32-
std::vector<Dimension> tsizes;
33-
for (auto v : t.sizes()) {
34-
tsizes.push_back(v);
35-
}
36-
Node* param = graph.create(kParam, 1);
37-
param->output()->setUniqueName(name);
38-
param->output()->setSizes(tsizes);
39-
param->output()->setElemType(t.elem_type());
4041
graph.addInitializer(std::move(t), name);
41-
graph.addInput()->copyMetadata(param->output());
42-
n->replaceAllUsesWith(param);
42+
4343
it.destroyCurrent();
4444
}
4545
}

onnx/test/optimizer_test.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -276,14 +276,19 @@ def test_extract_constant_to_initializer(self): # type: () -> None
276276
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 5, 3, 3)),
277277
helper.make_tensor_value_info("Y", TensorProto.FLOAT, (16, 5, 3, 3))],
278278
[helper.make_tensor_value_info("B", TensorProto.FLOAT, (1, 16, 3, 3))],
279-
value_info=[
280-
helper.make_tensor_value_info("A", TensorProto.FLOAT, (16, 1, 1)),
281-
]
282279
)
283280
optimized_model = self._optimized(graph, ["extract_constant_to_initializer"])
284-
assert len(list(optimized_model.graph.initializer)) == 1
285-
assert len(list(optimized_model.graph.node)) == 2
286-
assert "A" in [i.name for i in optimized_model.graph.initializer]
281+
self.assertEqual(
282+
set(vi.name for vi in optimized_model.graph.input),
283+
{'X', 'Y', 'A'})
284+
285+
self.assertEqual(len(optimized_model.graph.initializer), 1)
286+
init = optimized_model.graph.initializer[0]
287+
self.assertEqual(init.name, 'A')
288+
self.assertEqual(init.dims, [16])
289+
self.assertEqual(init.data_type, TensorProto.FLOAT)
290+
291+
self.assertEqual([n.op_type for n in optimized_model.graph.node], ['Conv', 'Add'])
287292

288293
def test_fuse_transpose(self): # type: () -> None
289294
nodes = [helper.make_node("Transpose", ["X"], ["Y"], perm=[1, 0, 2]),

0 commit comments

Comments
 (0)