Skip to content

Commit b638e78

Browse files
[fix] Fix crash when calling unbind on evaluated tensor (#1554)
1 parent 2ef6c3a commit b638e78

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ namespace impl {
1616
namespace {
1717

1818
bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool split_list, bool unbind) {
19-
auto in = args[0].ITensor();
19+
auto in = args[0].ITensorOrFreeze(ctx);
2020
auto numOutputs = 1, numRemainder = 0;
2121
std::vector<int64_t> sizes;
2222

tests/core/conversion/converters/test_select.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,6 +1122,34 @@ TEST(Converters, ATenUnbindNegativeAxisConvertsCorrectly) {
11221122
}
11231123
}
11241124

1125+
TEST(Converters, ATenUnbindEvaluatedTensor) {
1126+
const auto graph = R"IR(
1127+
graph(%x.1 : Tensor):
1128+
%2 : None = prim::Constant()
1129+
%3 : int[] = aten::size(%x.1)
1130+
%z.1 : Tensor = aten::zeros(%3, %2, %2, %2, %2)
1131+
%5 : int = prim::Constant[value=-1]()
1132+
%6 : Tensor[] = aten::unbind(%z.1, %5)
1133+
%o1.1 : Tensor, %o2.1 : Tensor = prim::ListUnpack(%6)
1134+
return (%o1.1, %o2.1))IR";
1135+
1136+
auto in = at::randint(1, 10, {2}, {at::kCUDA});
1137+
1138+
auto g = std::make_shared<torch::jit::Graph>();
1139+
1140+
torch::jit::parseIR(graph, g.get());
1141+
1142+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
1143+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
1144+
1145+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
1146+
1147+
for (size_t i = 0; i < jit_results.size(); i++) {
1148+
auto trt = trt_results[i];
1149+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i].cuda(), trt, 2e-6));
1150+
}
1151+
}
1152+
11251153
TEST(Converters, ScatterValueConvertsCorrectly) {
11261154
const auto graph = R"IR(
11271155
graph(%data : Tensor,

0 commit comments

Comments
 (0)