Skip to content

Commit 31b76a3

Browse files
feat: Add converter for aten::where (#1421)
1 parent ce29cc7 commit 31b76a3

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,23 @@ auto select_registrations TORCHTRT_UNUSED =
721721

722722
layer->setName(util::node_info(n).c_str());
723723

724+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0));
725+
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
726+
return true;
727+
}})
728+
.pattern(
729+
{"aten::where.self(Tensor condition, Tensor self, Tensor other) -> (Tensor)",
730+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
731+
auto condition = args[0].ITensorOrFreeze(ctx);
732+
auto x = args[1].ITensorOrFreeze(ctx);
733+
auto y = args[2].ITensorOrFreeze(ctx);
734+
735+
auto layer = ctx->net->addSelect(*condition, *x, *y);
736+
737+
TORCHTRT_CHECK(layer, "Unable to create select layer for aten::where.self");
738+
739+
layer->setName(util::node_info(n).c_str());
740+
724741
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0));
725742
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
726743
return true;

tests/core/conversion/converters/test_select.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,3 +1138,33 @@ TEST(Converters, ScatterSrcConvertsCorrectly) {
11381138
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
11391139
}
11401140
}
1141+
1142+
TEST(Converters, WhereConvertsCorrectly) {
1143+
const auto graph = R"IR(
1144+
graph(%condition : Tensor,
1145+
%x : Tensor,
1146+
%y : Tensor):
1147+
%out : Tensor = aten::where(%condition, %x, %y)
1148+
return (%out))IR";
1149+
1150+
auto g = std::make_shared<torch::jit::Graph>();
1151+
1152+
torch::jit::parseIR(graph, g.get());
1153+
1154+
auto condition = at::randint(0, 2, {5, 5}, {at::kCUDA}).to(torch::kBool);
1155+
auto x = at::randn({5, 5}, {at::kCUDA});
1156+
auto y = at::randn({5, 5}, {at::kCUDA});
1157+
1158+
auto jit_condition = at::clone(condition);
1159+
auto jit_x = at::clone(x);
1160+
auto jit_y = at::clone(y);
1161+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
1162+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_condition, jit_x, jit_y});
1163+
1164+
auto trt_condition = at::clone(condition);
1165+
auto trt_x = at::clone(x);
1166+
auto trt_y = at::clone(y);
1167+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_condition, trt_x, trt_y});
1168+
1169+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
1170+
}

0 commit comments

Comments
 (0)