Skip to content

Commit 80d2806

Browse files
authored
Fix CI test errors introduced by TF 2.4.0 (onnx#842)
Fix CI test errors introduced by TF 2.4.0 including lstm: test_model.py gru: test_model.py rnn: test_model.py mod: mod.py Signed-off-by: Chin Huang <[email protected]>
1 parent 1af4aed commit 80d2806

File tree

2 files changed

+21
-15
lines changed

2 files changed

+21
-15
lines changed

onnx_tf/handlers/backend/mod.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@ class Mod(ArithmeticMixin, BackendHandler):
1515
tf.uint16: tf.int32,
1616
tf.uint32: tf.int64,
1717
tf.int8: tf.int32,
18-
tf.int16: tf.int32
18+
tf.int16: tf.int32,
19+
tf.float16: tf.float32
1920
}
2021
supported_types = [
21-
tf.int32, tf.int64, tf.float16, tf.float32, tf.float64, tf.bfloat16
22+
tf.int32, tf.int64, tf.float32, tf.float64, tf.bfloat16
2223
]
2324

2425
@classmethod

test/backend/test_model.py

+18-13
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,6 @@ def test_lstm_savedmodel(self):
5555
helper.make_tensor_value_info("Y_c", TensorProto.FLOAT, [1, 2, 3])
5656
])
5757

58-
# prepare the ONNX model and save it as a Tensorflow SavedModel
59-
tf_rep = prepare(helper.make_model(graph_def))
60-
model_path = "lstm_savedmodel"
61-
tf_rep.export_graph(model_path)
62-
6358
# Initializing Inputs
6459
X = np.array([[[1., 2., 3., 4.], [5., 6., 7., 8.]]]).astype(np.float32)
6560
W = weight_scale * np.ones((1, number_of_gates * hidden_size, input_size)).astype(np.float32)
@@ -68,6 +63,12 @@ def test_lstm_savedmodel(self):
6863
seq_lens = np.repeat(X.shape[0], X.shape[1]).astype(np.int32)
6964
init_h = weight_scale * np.ones((1, X.shape[1], hidden_size)).astype(np.float32)
7065

66+
# prepare the ONNX model and save it as a Tensorflow SavedModel
67+
tf_rep = prepare(helper.make_model(graph_def))
68+
tf_rep.run({"X": X, "W": W, "R": R, "B": B, "sequence_lens": seq_lens, "initial_h": init_h })
69+
model_path = "lstm_savedmodel"
70+
tf_rep.export_graph(model_path)
71+
7172
# use the ONNX reference implementation to get expected output
7273
lstm = LSTM_Helper(X=X, W=W, R=R, B=B, initial_h=init_h)
7374
_, Y_ref = lstm.step()
@@ -108,10 +109,6 @@ def test_gru_savedmodel(self):
108109
helper.make_tensor_value_info("Y_h", TensorProto.FLOAT, [1, 3, 3])
109110
])
110111

111-
tf_rep = prepare(helper.make_model(graph_def))
112-
model_path = "gru_savedmodel"
113-
tf_rep.export_graph(model_path)
114-
115112
# initializing Inputs
116113
X = np.array([[[1., 2., 3.], [4., 5., 6.], [7., 8.,
117114
9.]]]).astype(np.float32)
@@ -124,6 +121,12 @@ def test_gru_savedmodel(self):
124121
R_B = np.zeros((1, number_of_gates * hidden_size)).astype(np.float32)
125122
B = np.concatenate((W_B, R_B), axis=1)
126123

124+
# prepare the ONNX model and save it as a Tensorflow SavedModel
125+
tf_rep = prepare(helper.make_model(graph_def))
126+
tf_rep.run({"X": X, "W": W, "R": R, "B": B})
127+
model_path = "gru_savedmodel"
128+
tf_rep.export_graph(model_path)
129+
127130
# use the ONNX reference implementation to get the expected output
128131
gru = GRU_Helper(X=X, W=W, R=R, B=B)
129132
_, Y_ref = gru.step()
@@ -161,15 +164,17 @@ def test_rnn_savedmodel(self):
161164
helper.make_tensor_value_info("Y_h", TensorProto.FLOAT, [1, 2, 4])
162165
])
163166

164-
tf_rep = prepare(helper.make_model(graph_def))
165-
model_path = "rnn_savedmodel"
166-
tf_rep.export_graph(model_path)
167-
168167
# initializing Inputs
169168
X = np.array([[[1., 2.], [3., 4.], [5., 6.]]]).astype(np.float32)
170169
W = weight_scale * np.ones((1, hidden_size, input_size)).astype(np.float32)
171170
R = weight_scale * np.ones((1, hidden_size, hidden_size)).astype(np.float32)
172171

172+
# prepare the ONNX model and save it as a Tensorflow SavedModel
173+
tf_rep = prepare(helper.make_model(graph_def))
174+
tf_rep.run({"X": X, "W": W, "R": R})
175+
model_path = "rnn_savedmodel"
176+
tf_rep.export_graph(model_path)
177+
173178
# use the ONNX reference implementation to get the expected output
174179
rnn = RNN_Helper(X=X, W=W, R=R)
175180
_, Y_ref = rnn.step()

0 commit comments

Comments
 (0)