Skip to content

Commit a20c24a

Browse files
committed
address review comments
1 parent f2b20be commit a20c24a

File tree

3 files changed

+78
-71
lines changed

3 files changed

+78
-71
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 0 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -125,77 +125,6 @@ DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
125125
int64_t,
126126
{"aten::__round_to_zero_floordiv(int a, int b) -> (int)"});
127127

128-
// Shared helper for aten::new_* evaluators: derives the requested output
// sizes (input 1) and the TensorOptions (layout/device fixed to strided CUDA;
// dtype from input 2 when given, otherwise inherited from self at input 0).
std::pair<std::vector<int64_t>, torch::TensorOptions> newTensorImplementation(const torch::jit::Node* n, kwargs& args) {
  auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);

  // Input 2 is the dtype; an explicit (non-None) value takes precedence.
  auto& dtype_arg = args.at(n->input(2));
  if (!dtype_arg.isNone() && !dtype_arg.IValue()->isNone()) {
    options = options.dtype(c10::ScalarType(dtype_arg.unwrapToInt()));
  } else {
    // No override: fall back to the dtype of self (input 0), which may be
    // either a TRT ITensor or a plain torch::Tensor.
    auto self_var = args.at(n->input(0));
    if (self_var.isITensor()) {
      auto self_itensor = self_var.ITensor();
      options = options.dtype(scalarTypeToTypeMeta(util::TRTDataTypeToScalarType(self_itensor->getType())));
    } else {
      auto self_tensor = self_var.unwrapToTensor();
      options = options.dtype(self_tensor.dtype());
    }
  }

  // Input 1 holds the requested size list.
  return std::make_pair(args.at(n->input(1)).unwrapToIntList().vec(), options);
}
146-
147-
// Shared helper for aten::*_like evaluators. Builds a tensor shaped like
// self (input 0) via tensor_builder, with dtype taken from self unless
// input 1 supplies an override. When the network input is dynamic, the value
// cannot be materialized at conversion time; instead TRT layers are emitted
// (constant broadcast to self's runtime shape via a shape-driven slice), the
// result is bound to the node's output, and an empty optional is returned.
c10::optional<torch::jit::IValue> newTensorLikeImplementation(
    ConversionCtx* ctx,
    const torch::jit::Node* n,
    kwargs& args,
    const std::function<torch::Tensor(const std::vector<int64_t>&, const torch::TensorOptions&)>& tensor_builder) {
  auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
  auto self_var = args.at(n->input(0));

  // Default dtype comes from self, which may be a TRT ITensor or torch::Tensor.
  if (self_var.isITensor()) {
    options = options.dtype(util::TRTDataTypeToScalarType(self_var.ITensor()->getType()));
  } else {
    options = options.dtype(self_var.unwrapToTensor().dtype());
  }

  // Input 1 is the dtype; a non-None value overrides the inherited one.
  if (!args.at(n->input(1)).isNone() && !args.at(n->input(1)).IValue()->isNone()) {
    options = options.dtype(c10::ScalarType(args.at(n->input(1)).unwrapToInt()));
  }

  // The output mirrors self's dimensions.
  std::vector<int64_t> out_dims;
  if (self_var.isITensor()) {
    out_dims = util::toVec(self_var.ITensor()->getDimensions());
  } else {
    out_dims = self_var.unwrapToTensor().sizes().vec();
  }

  if (ctx->input_is_dynamic) {
    // Dynamic input: build a rank-matched all-ones-shaped constant and let a
    // slice layer expand it to self's runtime extent.
    auto self = args.at(n->input(0)).ITensorOrFreeze(ctx);
    std::vector<int64_t> unit_dims(self->getDimensions().nbDims, 1);
    auto constant_itensor = converters::tensor_to_const(ctx, tensor_builder(unit_dims, options));

    std::vector<int64_t> zeros(self->getDimensions().nbDims, 0);
    auto start_offset = util::toDims(c10::IntArrayRef(zeros));

    // Shape layer supplies self's runtime shape to drive the slice extent.
    auto shape_layer = ctx->net->addShape(*self);
    TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
    shape_layer->setName((util::node_info(n) + "_shape").c_str());

    // slice implements expand
    auto slice_layer = ctx->net->addSlice(*constant_itensor, start_offset, self->getDimensions(), start_offset);
    TORCHTRT_CHECK(slice_layer, "Unable to create slice layer from node: " << *n);
    slice_layer->setInput(2, *shape_layer->getOutput(0));
    slice_layer->setName((util::node_info(n) + "_slice").c_str());

    auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], slice_layer->getOutput(0));
    LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
    return {};
  }
  return tensor_builder(out_dims, options);
}
198-
199128
auto aten_registrations TORCHTRT_UNUSED =
200129
RegisterNodeEvaluators()
201130
.evaluator(

core/conversion/evaluators/eval_util.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,77 @@ at::Tensor createTensorFromList(
367367
return tensor;
368368
}
369369

370+
// Shared helper for aten::new_* evaluators: derives the requested output
// sizes (input 1) and the TensorOptions (layout/device fixed to strided CUDA;
// dtype from input 2 when given, otherwise inherited from self at input 0).
std::pair<std::vector<int64_t>, torch::TensorOptions> newTensorImplementation(const torch::jit::Node* n, kwargs& args) {
  auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);

  // Input 2 is the dtype; an explicit (non-None) value takes precedence.
  auto& dtype_arg = args.at(n->input(2));
  if (!dtype_arg.isNone() && !dtype_arg.IValue()->isNone()) {
    options = options.dtype(c10::ScalarType(dtype_arg.unwrapToInt()));
  } else {
    // No override: fall back to the dtype of self (input 0), which may be
    // either a TRT ITensor or a plain torch::Tensor.
    auto self_var = args.at(n->input(0));
    if (self_var.isITensor()) {
      auto self_itensor = self_var.ITensor();
      options = options.dtype(scalarTypeToTypeMeta(util::TRTDataTypeToScalarType(self_itensor->getType())));
    } else {
      auto self_tensor = self_var.unwrapToTensor();
      options = options.dtype(self_tensor.dtype());
    }
  }

  // Input 1 holds the requested size list.
  return std::make_pair(args.at(n->input(1)).unwrapToIntList().vec(), options);
}
388+
389+
// Shared helper for aten::*_like evaluators. Builds a tensor shaped like
// self (input 0) via tensor_builder, with dtype taken from self unless
// input 1 supplies an override. When shape tensors are allowed and the input
// is dynamic, the value cannot be materialized at conversion time; instead
// TRT layers are emitted (constant broadcast to self's runtime shape via a
// shape-driven slice), the result is bound to the node's output, and an
// empty optional is returned.
c10::optional<torch::jit::IValue> newTensorLikeImplementation(
    ConversionCtx* ctx,
    const torch::jit::Node* n,
    kwargs& args,
    const std::function<torch::Tensor(const std::vector<int64_t>&, const torch::TensorOptions&)>& tensor_builder) {
  auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
  auto self_var = args.at(n->input(0));

  // Default dtype comes from self, which may be a TRT ITensor or torch::Tensor.
  if (self_var.isITensor()) {
    options = options.dtype(util::TRTDataTypeToScalarType(self_var.ITensor()->getType()));
  } else {
    options = options.dtype(self_var.unwrapToTensor().dtype());
  }

  // Input 1 is the dtype; a non-None value overrides the inherited one.
  if (!args.at(n->input(1)).isNone() && !args.at(n->input(1)).IValue()->isNone()) {
    options = options.dtype(c10::ScalarType(args.at(n->input(1)).unwrapToInt()));
  }

  // The output mirrors self's dimensions.
  std::vector<int64_t> out_dims;
  if (self_var.isITensor()) {
    out_dims = util::toVec(self_var.ITensor()->getDimensions());
  } else {
    out_dims = self_var.unwrapToTensor().sizes().vec();
  }

  if (ctx->settings.allow_shape_tensors && ctx->input_is_dynamic) {
    // Dynamic input with shape tensors enabled: build a rank-matched
    // all-ones-shaped constant and let a slice layer expand it to self's
    // runtime extent.
    auto self = args.at(n->input(0)).ITensorOrFreeze(ctx);
    std::vector<int64_t> unit_dims(self->getDimensions().nbDims, 1);
    auto constant_itensor = converters::tensor_to_const(ctx, tensor_builder(unit_dims, options));

    std::vector<int64_t> zeros(self->getDimensions().nbDims, 0);
    auto start_offset = util::toDims(c10::IntArrayRef(zeros));

    // Shape layer supplies self's runtime shape to drive the slice extent.
    auto shape_layer = ctx->net->addShape(*self);
    TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
    shape_layer->setName((util::node_info(n) + "_shape").c_str());

    // slice implements expand
    auto slice_layer = ctx->net->addSlice(*constant_itensor, start_offset, self->getDimensions(), start_offset);
    TORCHTRT_CHECK(slice_layer, "Unable to create slice layer from node: " << *n);
    slice_layer->setInput(2, *shape_layer->getOutput(0));
    slice_layer->setName((util::node_info(n) + "_slice").c_str());

    auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], slice_layer->getOutput(0));
    LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
    return {};
  }
  return tensor_builder(out_dims, options);
}
440+
370441
} // namespace evaluators
371442
} // namespace conversion
372443
} // namespace core

core/conversion/evaluators/eval_util.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@ int64_t normalizeIndex(int64_t idx, int64_t list_size);
2626

2727
at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device = at::kCPU);
2828

29+
// Derives the output sizes (node input 1) and TensorOptions (dtype from
// input 2 when set, otherwise from self at input 0) for aten::new_* style
// evaluators. Defined in eval_util.cpp.
std::pair<std::vector<int64_t>, torch::TensorOptions> newTensorImplementation(const torch::jit::Node* n, kwargs& args);
// Shared implementation for aten::*_like evaluators: builds a tensor matching
// self's shape/dtype via tensor_builder. In the dynamic-shape path it instead
// emits TRT layers, binds the node's output, and returns an empty optional.
// Defined in eval_util.cpp.
c10::optional<torch::jit::IValue> newTensorLikeImplementation(
    ConversionCtx* ctx,
    const torch::jit::Node* n,
    kwargs& args,
    const std::function<torch::Tensor(const std::vector<int64_t>&, const torch::TensorOptions&)>& tensor_builder);
35+
2936
} // namespace evaluators
3037
} // namespace conversion
3138
} // namespace core

0 commit comments

Comments
 (0)