Skip to content

Commit d6d70fe

Browse files
Add converter support for aten::frobenius_norm (#1422)
1 parent 5d9ddaa commit d6d70fe

File tree

2 files changed

+124
-17
lines changed

2 files changed

+124
-17
lines changed

core/conversion/converters/impl/normalize.cpp

Lines changed: 60 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -53,23 +53,66 @@ void create_plugin(
5353
LOG_DEBUG("Normalize layer output tensor shape: " << layer_output->getDimensions());
5454
}
5555

56-
auto normalize_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
57-
{"aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> (Tensor)",
58-
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
59-
auto in = args[0].ITensor();
60-
auto in_shape = util::toVec(in->getDimensions());
61-
auto order = args[1].unwrapToScalar().to<int32_t>();
62-
auto axes_values = args[2].unwrapToIntList().vec();
63-
std::vector<int32_t> axes(axes_values.begin(), axes_values.end());
64-
auto keep_dims = (int32_t)args[3].unwrapToBool();
65-
LOG_DEBUG("Order of normalize_plugin: " << order);
66-
LOG_DEBUG("Axis: " << axes);
67-
LOG_DEBUG("keep_dims: " << keep_dims);
68-
create_plugin(ctx, n, in, order, axes, keep_dims, "NormalizePluginTorchTRT");
69-
return true;
70-
}
71-
72-
});
56+
auto normalize_registrations TORCHTRT_UNUSED =
    RegisterNodeConversionPatterns()
        .pattern(
            {"aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> (Tensor)",
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               // Lp-norm over the given dims, lowered through the Normalize plugin.
               auto in = args[0].ITensorOrFreeze(ctx);
               auto order = args[1].unwrapToScalar().to<int32_t>();
               auto axes_values = args[2].unwrapToIntList().vec();
               std::vector<int32_t> axes(axes_values.begin(), axes_values.end());
               auto keep_dims = (int32_t)args[3].unwrapToBool();
               LOG_DEBUG("Order of normalize_plugin: " << order);
               LOG_DEBUG("Axis: " << axes);
               LOG_DEBUG("keep_dims: " << keep_dims);
               create_plugin(ctx, n, in, order, axes, keep_dims, "NormalizePluginTorchTRT");
               return true;
             }})
        .pattern(
            {"aten::frobenius_norm.dim(Tensor self, int[1] dim, bool keepdim=False) -> (Tensor)",
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               // Frobenius norm = sqrt(sum(x * x)) over the requested dims,
               // built from native TensorRT layers (elementwise square ->
               // reduce-sum -> unary sqrt) rather than a plugin.
               auto self = args[0].ITensorOrFreeze(ctx);
               auto axes_values = args[1].unwrapToIntList().vec();
               auto keep_dims = args[2].unwrapToBool();

               // addReduce takes the reduction dims as a bitmask; normalize
               // negative axes and validate both bounds before setting bits.
               uint32_t axes_mask = 0;
               auto self_nb_dims = self->getDimensions().nbDims;
               for (size_t i = 0UL; i < axes_values.size(); ++i) {
                 auto axis = axes_values[i];
                 if (axis < 0) {
                   axis += self_nb_dims;
                 }
                 TORCHTRT_CHECK(
                     axis >= 0 && axis < self_nb_dims,
                     "aten::frobenius_norm axis: " << i << " with value: " << axes_values[i] << " exceeds input rank");
                 // OR (not +=) so a duplicated axis cannot corrupt the mask.
                 axes_mask |= 1U << axis;
               }

               auto squared_layer = add_elementwise(
                   ctx, nvinfer1::ElementWiseOperation::kPROD, self, self, util::node_info(n) + "_squared");
               TORCHTRT_CHECK(squared_layer, "Unable to create square layer from node: " << *n);
               auto squared_output = squared_layer->getOutput(0);

               auto sum_layer =
                   ctx->net->addReduce(*squared_output, nvinfer1::ReduceOperation::kSUM, axes_mask, keep_dims);
               TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
               sum_layer->setName((util::node_info(n) + "_sum").c_str());
               auto sum_output = sum_layer->getOutput(0);

               auto sqrt_layer = ctx->net->addUnary(*sum_output, nvinfer1::UnaryOperation::kSQRT);
               TORCHTRT_CHECK(sqrt_layer, "Unable to create sqrt layer from node: " << *n);
               sqrt_layer->setName((util::node_info(n) + "_sqrt").c_str());
               auto sqrt_output = sqrt_layer->getOutput(0);

               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], sqrt_output);

               LOG_DEBUG("Output tensor shape: " << out->getDimensions());
               return true;
             }});
73116

74117
} // namespace
75118
} // namespace impl

tests/core/conversion/converters/test_normalize.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,67 @@ ATEN_INTERPOLATE_TESTS(
7575
%5 : Tensor = aten::norm(%x.1, %3, %2, %4)
7676
return (%5))IR",
7777
std::vector<int64_t>({3, 4, 3}));
78+
79+
// Frobenius norm over a single (leading) dim, keepdim=False:
// TRT engine output must match the TorchScript interpreter.
TEST(Converters, ATenFrobeniusNorm) {
  const auto graph = R"IR(
      graph(%x : Tensor):
        %0 : int = prim::Constant[value=0]()
        %keep : bool = prim::Constant[value=0]()
        %dims : int[] = prim::ListConstruct(%0)
        %out : Tensor = aten::frobenius_norm(%x, %dims, %keep)
        return (%out))IR";

  auto parsed = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed.get());

  auto input = at::randn({5, 5, 5}, {at::kCUDA});

  // Reference run through the TorchScript interpreter.
  auto jit_params = torch_tensorrt::core::ir::get_static_params(parsed->inputs(), {});
  auto jit_out = torch_tensorrt::tests::util::RunGraph(parsed, jit_params, {input});

  // Converted run through a TensorRT engine.
  auto trt_params = torch_tensorrt::core::ir::get_static_params(parsed->inputs(), {});
  auto trt_out = torch_tensorrt::tests::util::RunGraphEngine(parsed, trt_params, {input});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_out[0], trt_out[0]));
}
99+
100+
// Frobenius norm over a negative dim (-1) with keepdim=True, exercising
// the converter's negative-axis normalization and keep_dims handling.
TEST(Converters, ATenFrobeniusNormKeep) {
  const auto graph = R"IR(
      graph(%x : Tensor):
        %1 : int = prim::Constant[value=-1]()
        %keep : bool = prim::Constant[value=1]()
        %dims : int[] = prim::ListConstruct(%1)
        %out : Tensor = aten::frobenius_norm(%x, %dims, %keep)
        return (%out))IR";

  auto parsed = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed.get());

  auto input = at::randn({5, 5, 5}, {at::kCUDA});

  // Reference run through the TorchScript interpreter.
  auto jit_params = torch_tensorrt::core::ir::get_static_params(parsed->inputs(), {});
  auto jit_out = torch_tensorrt::tests::util::RunGraph(parsed, jit_params, {input});

  // Converted run through a TensorRT engine.
  auto trt_params = torch_tensorrt::core::ir::get_static_params(parsed->inputs(), {});
  auto trt_out = torch_tensorrt::tests::util::RunGraphEngine(parsed, trt_params, {input});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_out[0], trt_out[0]));
}
120+
121+
// Frobenius norm over two dims (one positive, one negative) on a rank-5
// input, exercising the multi-axis reduction bitmask in the converter.
TEST(Converters, ATenFrobeniusNormMatrix) {
  const auto graph = R"IR(
      graph(%x : Tensor):
        %0 : int = prim::Constant[value=0]()
        %1 : int = prim::Constant[value=-1]()
        %keep : bool = prim::Constant[value=0]()
        %dims : int[] = prim::ListConstruct(%0, %1)
        %out : Tensor = aten::frobenius_norm(%x, %dims, %keep)
        return (%out))IR";

  auto parsed = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed.get());

  auto input = at::randn({3, 5, 7, 11, 13}, {at::kCUDA});

  // Reference run through the TorchScript interpreter.
  auto jit_params = torch_tensorrt::core::ir::get_static_params(parsed->inputs(), {});
  auto jit_out = torch_tensorrt::tests::util::RunGraph(parsed, jit_params, {input});

  // Converted run through a TensorRT engine.
  auto trt_params = torch_tensorrt::core::ir::get_static_params(parsed->inputs(), {});
  auto trt_out = torch_tensorrt::tests::util::RunGraphEngine(parsed, trt_params, {input});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_out[0], trt_out[0]));
}

0 commit comments

Comments
 (0)