@@ -153,3 +153,136 @@ TEST(LoweringPasses, RemoveSingleUse0DTensorsFloatCorrectly) {
   ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
 }
+
+TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivIntCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%0: int):
+      %1: Tensor = prim::Constant[value=[7]]()
+      %3: Tensor = prim::NumToTensor(%0)
+      %4: Tensor = aten::floor_divide(%1, %3)
+      %5: int = aten::Int(%4)
+      return (%5))IR";
+  std::string target_graph = R"IR(
+    graph(%0: int):
+      %1: int = prim::Constant[value=7]()
+      %4: int = aten::floordiv(%1, %0)
+      return (%4))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, sg.get());
+
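+  // %1 parses as a tensor constant; swap in a genuine 0-D tensor constant
+  // (scalar_to_tensor(7)) so RemoveSingleUse0DTensors sees the single-use
+  // 0-D pattern it rewrites.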
+  auto first_op = *(sg->block()->nodes().begin());
+  torch::jit::WithInsertPoint guard(first_op);
+  torch::jit::Value* r = sg->insertConstant(c10::scalar_to_tensor(7), c10::nullopt, first_op->scope());
+  r->copyMetadata(first_op->output());
+  r->setType(c10::TensorType::get());
+  first_op->output()->replaceAllUsesWith(r);
+  first_op->destroy();
+
+  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
+
+TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivFloatCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%0: float):
+      %1: Tensor = prim::Constant[value=[8.]]()
+      %3: Tensor = prim::NumToTensor(%0)
+      %4: Tensor = aten::floor_divide(%1, %3)
+      %5: float = aten::Float(%4)
+      return (%5))IR";
+  std::string target_graph = R"IR(
+    graph(%0: float):
+      %1: float = prim::Constant[value=8.]()
+      %4: float = aten::floordiv(%1, %0)
+      return (%4))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, sg.get());
+
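+  // Same 0-D constant substitution as the int test above, this time with a
+  // float scalar (8.0).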
+  auto first_op = *(sg->block()->nodes().begin());
+  torch::jit::WithInsertPoint guard(first_op);
+  torch::jit::Value* r = sg->insertConstant(c10::scalar_to_tensor(8.0), c10::nullopt, first_op->scope());
+  r->copyMetadata(first_op->output());
+  r->setType(c10::TensorType::get());
+  first_op->output()->replaceAllUsesWith(r);
+  first_op->destroy();
+
+  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
+
+TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivIntValuesAgree) {
+  std::string source_graph_no_inputs = R"IR(
+    graph():
+      %0: int = prim::Constant[value=2]()
+      %11: int = prim::Constant[value=7]()
+      %3: Tensor = prim::NumToTensor(%0)
+      %1: Tensor = prim::NumToTensor(%11)
+      %4: Tensor = aten::floor_divide(%1, %3)
+      %50: int = aten::Int(%4)
+      %5: Tensor = prim::NumToTensor(%50)
+      return (%5))IR";
+  std::string target_graph_no_inputs = R"IR(
+    graph():
+      %0: int = prim::Constant[value=2]()
+      %1: int = prim::Constant[value=7]()
+      %40: int = aten::floordiv(%1, %0)
+      %4: Tensor = prim::NumToTensor(%40)
+      return (%4))IR";
+
+  auto g_in = std::make_shared<torch::jit::Graph>();
+  auto g_out = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(source_graph_no_inputs, g_in.get());
+  torch::jit::parseIR(target_graph_no_inputs, g_out.get());
+
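+  // Evaluate both the original and hand-lowered graphs with TorchScript and
+  // check that they produce the same value (7 floor-divided by 2 is 3).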
+  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_in, {});
+  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_out, {});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor()));
+}
+
+TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivFloatValuesAgree) {
+  std::string source_graph_no_inputs = R"IR(
+    graph():
+      %0: float = prim::Constant[value=2.]()
+      %11: float = prim::Constant[value=7.]()
+      %3: Tensor = prim::NumToTensor(%0)
+      %1: Tensor = prim::NumToTensor(%11)
+      %4: Tensor = aten::floor_divide(%1, %3)
+      %50: float = aten::Float(%4)
+      %5: Tensor = prim::NumToTensor(%50)
+      return (%5))IR";
+  std::string target_graph_no_inputs = R"IR(
+    graph():
+      %0: float = prim::Constant[value=2.]()
+      %1: float = prim::Constant[value=7.]()
+      %40: float = aten::floordiv(%1, %0)
+      %4: Tensor = prim::NumToTensor(%40)
+      return (%4))IR";
+
+  auto g_in = std::make_shared<torch::jit::Graph>();
+  auto g_out = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(source_graph_no_inputs, g_in.get());
+  torch::jit::parseIR(target_graph_no_inputs, g_out.get());
+
+  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_in, {});
+  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_out, {});
+
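+  // The lowered graph computes the floor division on scalar floats rather
+  // than tensors, so compare the results with a small tolerance instead of
+  // exact equality.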
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
+}