Skip to content

Commit cd1bda3

Browse files
authored
fix: Update floor division schema in lowering (#1464)
- Lowering can occasionally replace the `aten::floor_divide` inputs with integers or floats; for scalar inputs the correct operator name is `aten::floordiv`, so keeping the tensor-op name ultimately throws an unknown-schema error.
- Fix the `RemoveSingleUse0DTensors` lowering pass to repair the schema when replacing an intermediate operation that includes an `aten::floor_divide` call.
- Add test cases to ensure that the resulting graph is schematically correct and also numerically accurate.
1 parent 158be87 commit cd1bda3

File tree

2 files changed

+140
-0
lines changed

2 files changed

+140
-0
lines changed

core/lowering/passes/remove_unnecessary_casts.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,13 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
131131
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
132132
user->destroy();
133133
break;
134+
case c10::aten::floor_divide:
135+
new_node = g->create(c10::aten::floordiv, user->inputs(), 1);
136+
new_node->insertAfter(user);
137+
new_node->outputs()[0]->setType(c10::IntType::get());
138+
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
139+
user->destroy();
140+
break;
134141
default:
135142
new_node = g->create(user->kind(), user->inputs(), 1);
136143
new_node->insertAfter(user);

tests/core/lowering/test_remove_unnecessary_casts.cpp

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,136 @@ TEST(LoweringPasses, RemoveSingleUse0DTensorsFloatCorrectly) {
153153

154154
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
155155
}
156+
157+
TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivIntCorrectly) {
  // A 0-D tensor floor_divide over int-backed tensors should lower to a
  // scalar aten::floordiv once the single-use 0-D tensors are removed.
  std::string source_graph = R"IR(
    graph(%0: int):
      %1: Tensor = prim::Constant[value=[7]]()
      %3: Tensor = prim::NumToTensor(%0)
      %4: Tensor = aten::floor_divide(%1, %3)
      %5: int = aten::Int(%4)
      return (%5))IR";
  std::string target_graph = R"IR(
    graph(%0: int):
      %1: int = prim::Constant[value=7]()
      %4: int = aten::floordiv(%1, %0)
      return (%4))IR";

  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, parsed_g.get());

  // Replace the parsed list-valued constant with a genuine 0-D tensor
  // constant carrying the same metadata, then drop the original node.
  auto list_const = *(parsed_g->block()->nodes().begin());
  torch::jit::WithInsertPoint insert_guard(list_const);
  torch::jit::Value* scalar_const =
      parsed_g->insertConstant(c10::scalar_to_tensor(7), c10::nullopt, list_const->scope());
  scalar_const->copyMetadata(list_const->output());
  scalar_const->setType(c10::TensorType::get());
  list_const->output()->replaceAllUsesWith(scalar_const);
  list_const->destroy();

  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(parsed_g);

  auto expected_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, expected_g.get());

  // The lowered graph must contain the scalar floordiv pattern.
  ASSERT_TRUE(!torch::jit::findPatternMatches(*expected_g, *parsed_g).empty());
}
191+
192+
TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivFloatCorrectly) {
  // A 0-D tensor floor_divide over float-backed tensors should lower to a
  // scalar aten::floordiv once the single-use 0-D tensors are removed.
  std::string source_graph = R"IR(
    graph(%0: float):
      %1: Tensor = prim::Constant[value=[8.]]()
      %3: Tensor = prim::NumToTensor(%0)
      %4: Tensor = aten::floor_divide(%1, %3)
      %5: float = aten::Float(%4)
      return (%5))IR";
  std::string target_graph = R"IR(
    graph(%0: float):
      %1: float = prim::Constant[value=8.]()
      %4: float = aten::floordiv(%1, %0)
      return (%4))IR";

  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, parsed_g.get());

  // Replace the parsed list-valued constant with a genuine 0-D tensor
  // constant carrying the same metadata, then drop the original node.
  auto list_const = *(parsed_g->block()->nodes().begin());
  torch::jit::WithInsertPoint insert_guard(list_const);
  torch::jit::Value* scalar_const =
      parsed_g->insertConstant(c10::scalar_to_tensor(8.0), c10::nullopt, list_const->scope());
  scalar_const->copyMetadata(list_const->output());
  scalar_const->setType(c10::TensorType::get());
  list_const->output()->replaceAllUsesWith(scalar_const);
  list_const->destroy();

  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(parsed_g);

  auto expected_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, expected_g.get());

  // The lowered graph must contain the scalar floordiv pattern.
  ASSERT_TRUE(!torch::jit::findPatternMatches(*expected_g, *parsed_g).empty());
}
226+
227+
TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivIntValuesAgree) {
  // Evaluate the tensorized source graph and its hand-lowered scalar
  // equivalent in the JIT interpreter and check the int results match
  // exactly (7 floor-div 2). NOTE(review): this test compares two
  // hand-written graphs; the lowering pass itself is exercised by the
  // *Correctly tests above.
  std::string source_graph_no_inputs = R"IR(
    graph():
      %0: int = prim::Constant[value=2]()
      %11: int = prim::Constant[value=7]()
      %3: Tensor = prim::NumToTensor(%0)
      %1: Tensor = prim::NumToTensor(%11)
      %4: Tensor = aten::floor_divide(%1, %3)
      %50: int = aten::Int(%4)
      %5: Tensor = prim::NumToTensor(%50)
      return (%5))IR";
  std::string target_graph_no_inputs = R"IR(
    graph():
      %0: int = prim::Constant[value=2]()
      %1: int = prim::Constant[value=7]()
      %40: int = aten::floordiv(%1, %0)
      %4: Tensor = prim::NumToTensor(%40)
      return (%4))IR";

  auto pre_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph_no_inputs, pre_graph.get());
  auto post_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph_no_inputs, post_graph.get());

  auto pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(pre_graph, {});
  auto post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(post_graph, {});

  // Integer outputs must be bit-identical.
  ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(pre_results[0].toTensor(), post_results[0].toTensor()));
}
257+
258+
TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivFloatValuesAgree) {
  // Evaluate the tensorized source graph and its hand-lowered scalar
  // equivalent in the JIT interpreter and check the float results agree
  // within tolerance (7. floor-div 2.). NOTE(review): this test compares
  // two hand-written graphs; the lowering pass itself is exercised by the
  // *Correctly tests above.
  std::string source_graph_no_inputs = R"IR(
    graph():
      %0: float = prim::Constant[value=2.]()
      %11: float = prim::Constant[value=7.]()
      %3: Tensor = prim::NumToTensor(%0)
      %1: Tensor = prim::NumToTensor(%11)
      %4: Tensor = aten::floor_divide(%1, %3)
      %50: float = aten::Float(%4)
      %5: Tensor = prim::NumToTensor(%50)
      return (%5))IR";
  std::string target_graph_no_inputs = R"IR(
    graph():
      %0: float = prim::Constant[value=2.]()
      %1: float = prim::Constant[value=7.]()
      %40: float = aten::floordiv(%1, %0)
      %4: Tensor = prim::NumToTensor(%40)
      return (%4))IR";

  auto pre_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph_no_inputs, pre_graph.get());
  auto post_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph_no_inputs, post_graph.get());

  auto pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(pre_graph, {});
  auto post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(post_graph, {});

  // Floating-point outputs compared with a small absolute tolerance.
  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(pre_results[0].toTensor(), post_results[0].toTensor(), 2e-6));
}

0 commit comments

Comments
 (0)