Skip to content

Commit f2b20be

Browse files
committed
force ones_like to fallback in test that relies on it not converting
1 parent c4c69d5 commit f2b20be

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,16 +176,16 @@ c10::optional<torch::jit::IValue> newTensorLikeImplementation(
176176
if (ctx->input_is_dynamic) {
177177
auto self = args.at(n->input(0)).ITensorOrFreeze(ctx);
178178
std::vector<int64_t> dims_vec(self->getDimensions().nbDims, 1);
179-
auto zeros = tensor_builder(dims_vec, options);
180-
auto zeros_itensor = converters::tensor_to_const(ctx, zeros);
179+
auto constant = tensor_builder(dims_vec, options);
180+
auto constant_itensor = converters::tensor_to_const(ctx, constant);
181181
// broadcast constant to output shape
182182
std::vector<int64_t> start_vec(self->getDimensions().nbDims, 0);
183183
auto start_offset = util::toDims(c10::IntArrayRef(start_vec));
184184
auto shape_layer = ctx->net->addShape(*self);
185185
TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
186186
shape_layer->setName((util::node_info(n) + "_shape").c_str());
187187
// slice implements expand
188-
auto slice_layer = ctx->net->addSlice(*zeros_itensor, start_offset, self->getDimensions(), start_offset);
188+
auto slice_layer = ctx->net->addSlice(*constant_itensor, start_offset, self->getDimensions(), start_offset);
189189
TORCHTRT_CHECK(slice_layer, "Unable to create slice layer from node: " << *n);
190190
slice_layer->setInput(2, *shape_layer->getOutput(0));
191191
slice_layer->setName((util::node_info(n) + "_slice").c_str());

tests/core/partitioning/test_loop_fallback.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ TEST(Partitioning, CheckLoopFallbackNoEvalCompilesCorrectly) {
5353

5454
std::vector<torch_tensorrt::core::ir::Input> input_ranges{torch_tensorrt::core::ir::Input({1, 10})};
5555
torch_tensorrt::core::CompileSpec cfg(input_ranges);
56+
cfg.partitioning_info.forced_fallback_operators.push_back("aten::ones_like");
5657
cfg.partitioning_info.enabled = true;
5758

5859
auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();

0 commit comments

Comments
 (0)