@@ -1138,3 +1138,33 @@ TEST(Converters, ScatterSrcConvertsCorrectly) {
1138
1138
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[i], trt, 2e-6 ));
1139
1139
}
1140
1140
}
1141
+
1142
+ TEST (Converters, WhereConvertsCorrectly) {
1143
+ const auto graph = R"IR(
1144
+ graph(%condition : Tensor,
1145
+ %x : Tensor,
1146
+ %y : Tensor):
1147
+ %out : Tensor = aten::where(%condition, %x, %y)
1148
+ return (%out))IR" ;
1149
+
1150
+ auto g = std::make_shared<torch::jit::Graph>();
1151
+
1152
+ torch::jit::parseIR (graph, g.get ());
1153
+
1154
+ auto condition = at::randint (0 , 2 , {5 , 5 }, {at::kCUDA }).to (torch::kBool );
1155
+ auto x = at::randn ({5 , 5 }, {at::kCUDA });
1156
+ auto y = at::randn ({5 , 5 }, {at::kCUDA });
1157
+
1158
+ auto jit_condition = at::clone (condition);
1159
+ auto jit_x = at::clone (x);
1160
+ auto jit_y = at::clone (y);
1161
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
1162
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_condition, jit_x, jit_y});
1163
+
1164
+ auto trt_condition = at::clone (condition);
1165
+ auto trt_x = at::clone (x);
1166
+ auto trt_y = at::clone (y);
1167
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_condition, trt_x, trt_y});
1168
+
1169
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
1170
+ }
0 commit comments