Skip to content

Commit fbf07ca

Browse files
authored
[onnx] Fix create ONNX DFT op for scalar signal size (#32637)
### Details: - The ONNX specification for DFT-17, DFT-20 operator defines `dft_length` as scalar input which is not compatible with OV FFT operators (Must be 1-D). In case of scalar input unsqueeze this input to 1-D. ### Tickets: - N/A --------- Signed-off-by: Raasz, Pawel <[email protected]>
1 parent 4c44301 commit fbf07ca

File tree

3 files changed

+95
-6
lines changed

3 files changed

+95
-6
lines changed

src/frontends/onnx/frontend/src/utils/dft.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "openvino/op/idft.hpp"
1313
#include "openvino/op/irdft.hpp"
1414
#include "openvino/op/rdft.hpp"
15+
#include "openvino/op/reshape.hpp"
1516
#include "openvino/op/shape_of.hpp"
1617
#include "openvino/op/unsqueeze.hpp"
1718

@@ -55,27 +56,31 @@ ov::Output<ov::Node> make_dft(const ov::Output<ov::Node>& signal,
5556
conversion_to_complex_applied = try_convert_real_to_complex(processed_signal);
5657
}
5758

58-
bool dft_length_provided = !ov::op::util::is_null(length);
59+
const bool dft_length_provided = !ov::op::util::is_null(length);
60+
const auto& signal_size =
61+
dft_length_provided
62+
? std::make_shared<v1::Reshape>(length, v0::Constant::create(ov::element::i32, {1}, {1}), false)->output(0)
63+
: length;
5964

6065
ov::Output<ov::Node> result;
6166
if (is_inversed) {
6267
if (is_onesided) {
63-
result = dft_length_provided ? std::make_shared<v9::IRDFT>(processed_signal, axis_const, length)
68+
result = dft_length_provided ? std::make_shared<v9::IRDFT>(processed_signal, axis_const, signal_size)
6469
: std::make_shared<v9::IRDFT>(processed_signal, axis_const);
6570
if (conversion_to_complex_applied) { // align the output shape with a real numbers representation
66-
const auto unsqueeze_axis = v0::Constant::create(ov::element::i64, {}, {-1});
71+
const auto unsqueeze_axis = v0::Constant::create(ov::element::i32, {}, {-1});
6772
result = std::make_shared<v0::Unsqueeze>(result, unsqueeze_axis);
6873
}
6974
} else {
70-
result = dft_length_provided ? std::make_shared<v7::IDFT>(processed_signal, axis_const, length)
75+
result = dft_length_provided ? std::make_shared<v7::IDFT>(processed_signal, axis_const, signal_size)
7176
: std::make_shared<v7::IDFT>(processed_signal, axis_const);
7277
}
7378
} else {
7479
if (is_onesided) {
75-
result = dft_length_provided ? std::make_shared<v9::RDFT>(processed_signal, axis_const, length)
80+
result = dft_length_provided ? std::make_shared<v9::RDFT>(processed_signal, axis_const, signal_size)
7681
: std::make_shared<v9::RDFT>(processed_signal, axis_const);
7782
} else {
78-
result = dft_length_provided ? std::make_shared<v7::DFT>(processed_signal, axis_const, length)
83+
result = dft_length_provided ? std::make_shared<v7::DFT>(processed_signal, axis_const, signal_size)
7984
: std::make_shared<v7::DFT>(processed_signal, axis_const);
8085
}
8186
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
ir_version: 7
2+
graph {
3+
node {
4+
output: "dft_length"
5+
op_type: "Constant"
6+
attribute {
7+
name: "value"
8+
t {
9+
dims: 0
10+
data_type: 7
11+
int64_data: 1
12+
name: "const_tensor"
13+
}
14+
type: TENSOR
15+
}
16+
}
17+
node {
18+
input: "data"
19+
input: "dft_length"
20+
output: "out"
21+
op_type: "DFT"
22+
attribute {
23+
name: "inverse"
24+
i: 0
25+
type: INT
26+
}
27+
attribute {
28+
name: "onesided"
29+
i: 0
30+
type: INT
31+
}
32+
attribute {
33+
name: "axis"
34+
i: 0
35+
type: INT
36+
}
37+
}
38+
input {
39+
name: "data"
40+
type {
41+
tensor_type {
42+
elem_type: 1
43+
shape {
44+
dim {
45+
dim_value: 3
46+
}
47+
dim {
48+
dim_value: 5
49+
}
50+
dim {
51+
dim_value: 2
52+
}
53+
}
54+
}
55+
}
56+
}
57+
output {
58+
name: "out"
59+
type {
60+
tensor_type {
61+
elem_type: 1
62+
shape {
63+
}
64+
}
65+
}
66+
}
67+
}
68+
opset_import {
69+
domain: ""
70+
version: 13
71+
}

src/frontends/onnx/tests/onnx_import_signal.in.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,19 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_dft_length_provided) {
124124
{0.000000f, 0.000000f, 1.000000f, 0.000000f, 2.000000f, 0.000000f, 3.000000f, 0.000000f, 4.000000f, 0.000000f});
125125
}
126126

127+
OPENVINO_TEST(${BACKEND_NAME}, onnx_model_dft_scalar_length_provided) {
128+
auto model = convert_model("dft_scalar_length_provided.onnx");
129+
auto test_case = ov::test::TestCase(model, s_device);
130+
test_case.add_input<float>(Shape{3, 5, 2}, {0.000000f, 0.000000f, 1.000000f, 0.000000f, 2.000000f, 0.000000f,
131+
3.000000f, 0.000000f, 4.000000f, 0.000000f, 5.000000f, 0.000000f,
132+
6.000000f, 0.000000f, 7.000000f, 0.000000f, 8.000000f, 0.000000f,
133+
9.000000f, 0.000000f, 10.000000f, 0.000000f, 11.000000f, 0.000000f,
134+
12.000000f, 0.000000f, 13.000000f, 0.000000f, 14.000000f, 0.000000f});
135+
test_case.add_expected_output<float>(
136+
Shape{1, 5, 2},
137+
{0.000000f, 0.000000f, 1.000000f, 0.000000f, 2.000000f, 0.000000f, 3.000000f, 0.000000f, 4.000000f, 0.000000f});
138+
}
139+
127140
OPENVINO_TEST(${BACKEND_NAME}, onnx_model_dft_length_provided_onesided) {
128141
auto model = convert_model("dft_lenght_provided_onesided.onnx");
129142
auto test_case = ov::test::TestCase(model, s_device);

0 commit comments

Comments
 (0)