Skip to content

Commit c4c69d5

Browse files
committed
Add support for dynamic zeros_like and ones_like
1 parent bf46d54 commit c4c69d5

File tree

2 files changed

+318
-0
lines changed

2 files changed

+318
-0
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,77 @@ DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
125125
int64_t,
126126
{"aten::__round_to_zero_floordiv(int a, int b) -> (int)"});
127127

128+
// Resolves the (shape, options) pair for aten::new_zeros / aten::new_ones.
//
// The result shape is taken from the size argument (input 1). The dtype is
// taken from the explicit dtype argument (input 2) when it is present;
// otherwise it falls back to the dtype of the self tensor (input 0), which
// may be either a TRT ITensor or a frozen torch::Tensor. Layout is always
// strided and the device is always CUDA.
std::pair<std::vector<int64_t>, torch::TensorOptions> newTensorImplementation(const torch::jit::Node* n, kwargs& args) {
  auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);

  // Input 2 is the dtype
  auto dtype_arg = args.at(n->input(2));
  if (!dtype_arg.isNone() && !dtype_arg.IValue()->isNone()) {
    // Explicit dtype wins.
    options = options.dtype(c10::ScalarType(dtype_arg.unwrapToInt()));
  } else {
    // No dtype given: inherit it from self (input 0).
    auto self_var = args.at(n->input(0));
    if (self_var.isITensor()) {
      options = options.dtype(scalarTypeToTypeMeta(util::TRTDataTypeToScalarType(self_var.ITensor()->getType())));
    } else {
      options = options.dtype(self_var.unwrapToTensor().dtype());
    }
  }
  return {args.at(n->input(1)).unwrapToIntList().vec(), options};
}
146+
147+
// Shared implementation for aten::zeros_like / aten::ones_like evaluators.
// `tensor_builder` materializes the constant tensor (torch::zeros or
// torch::ones) for a given shape and options.
//
// Static-shape inputs: returns an IValue holding the fully built tensor.
// Dynamic-shape inputs: emits TRT layers (shape + slice) that broadcast a
// rank-preserving single-element constant to the runtime shape of input 0,
// registers the slice output as the node's output tensor, and returns an
// empty optional (no IValue is produced in that case).
c10::optional<torch::jit::IValue> newTensorLikeImplementation(
    ConversionCtx* ctx,
    const torch::jit::Node* n,
    kwargs& args,
    const std::function<torch::Tensor(const std::vector<int64_t>&, const torch::TensorOptions&)>& tensor_builder) {
  auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
  auto tensor_var = args.at(n->input(0));

  // Default the dtype to that of self (input 0); it may be a TRT ITensor or
  // a frozen torch::Tensor.
  if (tensor_var.isITensor()) {
    auto tensor = tensor_var.ITensor();
    auto dtype = util::TRTDataTypeToScalarType(tensor->getType());
    options = options.dtype(dtype);
  } else {
    auto tensor = tensor_var.unwrapToTensor();
    options = options.dtype(tensor.dtype());
  }

  // Input 1 is the dtype; when provided it overrides the dtype of self.
  if (!args.at(n->input(1)).isNone() && !args.at(n->input(1)).IValue()->isNone()) {
    options = options.dtype(c10::ScalarType(args.at(n->input(1)).unwrapToInt()));
  }
  // Static dims of self; only consumed by the non-dynamic return at the end.
  std::vector<int64_t> tensor_dims;
  if (tensor_var.isITensor()) {
    auto tensor = tensor_var.ITensor();
    tensor_dims = util::toVec(tensor->getDimensions());
  } else {
    auto tensor = tensor_var.unwrapToTensor();
    tensor_dims = tensor.sizes().vec();
  }
  if (ctx->input_is_dynamic) {
    auto self = args.at(n->input(0)).ITensorOrFreeze(ctx);
    // Constant of self's rank with every extent 1 (one fill element total).
    std::vector<int64_t> dims_vec(self->getDimensions().nbDims, 1);
    auto zeros = tensor_builder(dims_vec, options);
    auto zeros_itensor = converters::tensor_to_const(ctx, zeros);
    // broadcast constant to output shape
    std::vector<int64_t> start_vec(self->getDimensions().nbDims, 0);
    auto start_offset = util::toDims(c10::IntArrayRef(start_vec));
    // Shape layer captures self's runtime shape to drive the slice size.
    auto shape_layer = ctx->net->addShape(*self);
    TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
    shape_layer->setName((util::node_info(n) + "_shape").c_str());
    // slice implements expand: start and stride are all zeros (zero stride
    // repeats the single element along each axis); the static size argument
    // is a placeholder overridden by setInput(2, ...) below.
    auto slice_layer = ctx->net->addSlice(*zeros_itensor, start_offset, self->getDimensions(), start_offset);
    TORCHTRT_CHECK(slice_layer, "Unable to create slice layer from node: " << *n);
    slice_layer->setInput(2, *shape_layer->getOutput(0));
    slice_layer->setName((util::node_info(n) + "_slice").c_str());
    auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], slice_layer->getOutput(0));
    LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
    // Empty optional: output was already registered as a TRT tensor above.
    return {};
  }
  return tensor_builder(tensor_dims, options);
}
198+
128199
auto aten_registrations TORCHTRT_UNUSED =
129200
RegisterNodeEvaluators()
130201
.evaluator(
@@ -157,6 +228,63 @@ auto aten_registrations TORCHTRT_UNUSED =
157228
auto out_tensor = torch::ones(args.at(n->input(0)).unwrapToIntList().vec(), options);
158229
return out_tensor;
159230
}})
231+
        .evaluator(
            {c10::Symbol::fromQualString("aten::new_zeros"),
             // aten::new_zeros(Tensor self, int[] size, *, int? dtype=None, int? layout=None,
             // Device? device=None, bool? pin_memory=None) -> (Tensor)
             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               // Shape from the size arg, dtype from the dtype arg or self.
               auto tensor_info = newTensorImplementation(n, args);
               return torch::zeros(tensor_info.first, tensor_info.second);
             }})
        .evaluator(
            {c10::Symbol::fromQualString("aten::new_ones"),
             // aten::new_ones(Tensor self, int[] size, *, int? dtype=None, int? layout=None,
             // Device? device=None, bool? pin_memory=None) -> (Tensor)
             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               // Same resolution as new_zeros, filled with ones instead.
               auto tensor_info = newTensorImplementation(n, args);
               return torch::ones(tensor_info.first, tensor_info.second);
             }})
        .evaluator(
            {c10::Symbol::fromQualString("aten::zeros_like"),
             // aten::zeros_like(Tensor self, *, int? dtype=None, int? layout=None,
             // Device? device=None, bool? pin_memory=None, int? memory_format=None) -> (Tensor)
             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               // Delegates to the shared *_like implementation; handles both
               // static shapes (returns a tensor) and dynamic shapes (emits
               // TRT layers and returns an empty optional).
               return newTensorLikeImplementation(
                   ctx, n, args, [](const std::vector<int64_t>& dims, const torch::TensorOptions& options) {
                     return torch::zeros(dims, options);
                   });
             }})
        .evaluator(
            {c10::Symbol::fromQualString("aten::ones_like"),
             // aten::ones_like(Tensor self, *, int? dtype=None, int? layout=None,
             // Device? device=None, bool? pin_memory=None, int? memory_format=None) -> (Tensor)
             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               // Same as zeros_like with a ones builder.
               return newTensorLikeImplementation(
                   ctx, n, args, [](const std::vector<int64_t>& dims, const torch::TensorOptions& options) {
                     return torch::ones(dims, options);
                   });
             }})
        .evaluator(
            {c10::Symbol::fromQualString("aten::fill_"),
             // aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> (Tensor(a!))
             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               auto tensor_var = args.at(n->input(0));
               auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
               std::vector<int64_t> dims;
               // dtype and dims come from self (ITensor or frozen tensor).
               if (tensor_var.isITensor()) {
                 auto tensor = tensor_var.ITensor();
                 auto dtype = util::TRTDataTypeToScalarType(tensor->getType());
                 options = options.dtype(dtype);
                 dims = util::toVec(tensor->getDimensions());
               } else {
                 auto tensor = tensor_var.unwrapToTensor();
                 options = options.dtype(tensor.dtype());
                 dims = tensor.sizes().vec();
               }
               // Evaluates the in-place fill as a fresh full() tensor of the
               // same shape/dtype. NOTE(review): the aliasing (a!) semantics
               // of fill_ are modeled by producing a new value — confirm
               // downstream consumers only read the node output.
               auto scalar_value = args.at(n->input(1)).unwrapToScalar();
               auto out_tensor = torch::full(dims, scalar_value, options);
               return out_tensor;
             }})
160288
.evaluator(
161289
{c10::Symbol::fromQualString("aten::full"),
162290
// aten::full(int[] size, Scalar fill_value, *, int? dtype=None, int? layout=None,

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,196 @@ TEST(Evaluators, ZerosDataTypeEvaluatesCorrectly) {
207207
ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
208208
}
209209

210+
// aten::new_zeros with every optional arg None: shape comes from
// aten::size(self); the TRT evaluator must match the JIT interpreter.
TEST(Evaluators, NewZerosEvaluatesCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : None = prim::Constant() # :0:0
        %3 : int[] = aten::size(%x.1) # <string>:7:9
        %z.1 : Tensor = aten::new_zeros(%x.1, %3, %2, %2, %2, %2)
        return (%z.1))IR";

  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {in});

  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
}
228+
229+
// aten::new_zeros with an explicit dtype (5 == at::kHalf): the dtype arg
// must override self's dtype in both the JIT and TRT evaluations.
TEST(Evaluators, NewZerosDataTypeEvaluatesCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int = prim::Constant[value=5]() # :0:0 (Float16)
        %3 : None = prim::Constant() # :0:0
        %4 : int[] = aten::size(%x.1) # <string>:7:9
        %z.1 : Tensor = aten::new_zeros(%x.1, %4, %2, %3, %3, %3)
        return (%z.1))IR";

  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {in});

  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
}
248+
249+
// aten::new_ones with every optional arg None — mirrors the new_zeros test.
TEST(Evaluators, NewOnesEvaluatesCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : None = prim::Constant() # :0:0
        %3 : int[] = aten::size(%x.1) # <string>:7:9
        %z.1 : Tensor = aten::new_ones(%x.1, %3, %2, %2, %2, %2)
        return (%z.1))IR";

  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {in});

  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
}
267+
268+
// aten::new_ones with an explicit dtype (5 == at::kHalf) overriding self.
TEST(Evaluators, NewOnesDataTypeEvaluatesCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int = prim::Constant[value=5]() # :0:0 (Float16)
        %3 : None = prim::Constant() # :0:0
        %4 : int[] = aten::size(%x.1) # <string>:7:9
        %z.1 : Tensor = aten::new_ones(%x.1, %4, %2, %3, %3, %3)
        return (%z.1))IR";

  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {in});

  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
}
287+
288+
// aten::zeros_like with all options None: output inherits self's shape and
// dtype (static-shape path of newTensorLikeImplementation).
TEST(Evaluators, ZerosLikeEvaluatesCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : None = prim::Constant() # :0:0
        %z.1 : Tensor = aten::zeros_like(%x.1, %2, %2, %2, %2, %2)
        return (%z.1))IR";

  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {in});

  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
}
305+
306+
// aten::zeros_like with an explicit dtype (5 == at::kHalf): the dtype arg
// (input 1) must override self's dtype.
TEST(Evaluators, ZerosLikeDataTypeEvaluatesCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int = prim::Constant[value=5]() # :0:0 (Float16)
        %3 : None = prim::Constant()
        %z.1 : Tensor = aten::zeros_like(%x.1, %2, %3, %3, %3, %3)
        return (%z.1))IR";

  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {in});

  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
}
324+
325+
// aten::zeros_like through the dynamic-shape path: builds a full TRT engine
// with dynamic input shapes (RunGraphEngineDynamic), exercising the
// shape + zero-stride-slice broadcast emitted by newTensorLikeImplementation.
TEST(Evaluators, ZerosLikeDynamic) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int = prim::Constant[value=5]() # :0:0 (Float16)
        %3 : None = prim::Constant()
        %z.1 : Tensor = aten::zeros_like(%x.1, %2, %3, %3, %3, %3)
        return (%z.1))IR";
  auto in = at::randint(1, 10, {23, 17, 5, 29}, {at::kCUDA});

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in});

  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0]));
}
343+
344+
// aten::ones_like with all options None — mirrors the zeros_like test.
TEST(Evaluators, OnesLikeEvaluatesCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : None = prim::Constant() # :0:0
        %z.1 : Tensor = aten::ones_like(%x.1, %2, %2, %2, %2, %2)
        return (%z.1))IR";

  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {in});

  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
}
361+
362+
// aten::ones_like with an explicit dtype (5 == at::kHalf) overriding self.
TEST(Evaluators, OnesLikeDataTypeEvaluatesCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int = prim::Constant[value=5]() # :0:0 (Float16)
        %3 : None = prim::Constant()
        %z.1 : Tensor = aten::ones_like(%x.1, %2, %3, %3, %3, %3)
        return (%z.1))IR";

  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {in});

  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
}
380+
381+
// aten::ones_like through the dynamic-shape engine path, with a rank-2 input
// to cover a different rank than the zeros_like dynamic test.
TEST(Evaluators, OnesLikeDynamic) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int = prim::Constant[value=5]() # :0:0 (Float16)
        %3 : None = prim::Constant()
        %z.1 : Tensor = aten::ones_like(%x.1, %2, %3, %3, %3, %3)
        return (%z.1))IR";
  auto in = at::randint(1, 10, {3, 6}, {at::kCUDA});

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in});

  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0]));
}
399+
210400
TEST(Evaluators, ATenArangeIntEvaluatesCorrectly) {
211401
const auto graph = R"IR(
212402
graph():

0 commit comments

Comments
 (0)