@@ -53,23 +53,66 @@ void create_plugin(
53
53
LOG_DEBUG (" Normalize layer output tensor shape: " << layer_output->getDimensions ());
54
54
}
55
55
56
- auto normalize_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
57
- {" aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> (Tensor)" ,
58
- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
59
- auto in = args[0 ].ITensor ();
60
- auto in_shape = util::toVec (in->getDimensions ());
61
- auto order = args[1 ].unwrapToScalar ().to <int32_t >();
62
- auto axes_values = args[2 ].unwrapToIntList ().vec ();
63
- std::vector<int32_t > axes (axes_values.begin (), axes_values.end ());
64
- auto keep_dims = (int32_t )args[3 ].unwrapToBool ();
65
- LOG_DEBUG (" Order of normalize_plugin: " << order);
66
- LOG_DEBUG (" Axis: " << axes);
67
- LOG_DEBUG (" keep_dims: " << keep_dims);
68
- create_plugin (ctx, n, in, order, axes, keep_dims, " NormalizePluginTorchTRT" );
69
- return true ;
70
- }
71
-
72
- });
56
+ auto normalize_registrations TORCHTRT_UNUSED =
57
+ RegisterNodeConversionPatterns ()
58
+ .pattern(
59
+ {" aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> (Tensor)" ,
60
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
61
+ auto in = args[0 ].ITensorOrFreeze (ctx);
62
+ auto in_shape = util::toVec (in->getDimensions ());
63
+ auto order = args[1 ].unwrapToScalar ().to <int32_t >();
64
+ auto axes_values = args[2 ].unwrapToIntList ().vec ();
65
+ std::vector<int32_t > axes (axes_values.begin (), axes_values.end ());
66
+ auto keep_dims = (int32_t )args[3 ].unwrapToBool ();
67
+ LOG_DEBUG (" Order of normalize_plugin: " << order);
68
+ LOG_DEBUG (" Axis: " << axes);
69
+ LOG_DEBUG (" keep_dims: " << keep_dims);
70
+ create_plugin (ctx, n, in, order, axes, keep_dims, " NormalizePluginTorchTRT" );
71
+ return true ;
72
+ }
73
+
74
+ })
75
+ .pattern(
76
+ {" aten::frobenius_norm.dim(Tensor self, int[1] dim, bool keepdim=False) -> (Tensor)" ,
77
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
78
+ auto self = args[0 ].ITensorOrFreeze (ctx);
79
+ auto axes_values = args[1 ].unwrapToIntList ().vec ();
80
+ auto keep_dims = args[2 ].unwrapToBool ();
81
+
82
+ int32_t axes_mask = 0 ;
83
+ auto self_nb_dims = self->getDimensions ().nbDims ;
84
+ for (size_t i = 0UL ; i < axes_values.size (); ++i) {
85
+ auto axis = axes_values[i];
86
+ if (axis < 0 ) {
87
+ axis += self_nb_dims;
88
+ }
89
+ TORCHTRT_CHECK (
90
+ axis < self_nb_dims,
91
+ " aten::frobenius_norm axis: " << i << " with value: " << axis << " exceeds input rank" );
92
+ axes_mask += 1 << axis;
93
+ }
94
+
95
+ auto squared_layer = add_elementwise (
96
+ ctx, nvinfer1::ElementWiseOperation::kPROD , self, self, util::node_info (n) + " _squared" );
97
+ TORCHTRT_CHECK (squared_layer, " Unabled to create square layer from node: " << *n);
98
+ auto squared_output = squared_layer->getOutput (0 );
99
+
100
+ auto sum_layer =
101
+ ctx->net ->addReduce (*squared_output, nvinfer1::ReduceOperation::kSUM , axes_mask, keep_dims);
102
+ TORCHTRT_CHECK (sum_layer, " Unable to create sum layer from node: " << *n);
103
+ sum_layer->setName ((util::node_info (n) + " _sum" ).c_str ());
104
+ auto sum_output = sum_layer->getOutput (0 );
105
+
106
+ auto sqrt_layer = ctx->net ->addUnary (*sum_output, nvinfer1::UnaryOperation::kSQRT );
107
+ TORCHTRT_CHECK (sqrt_layer, " Unable to create sqrt layer from node: " << *n);
108
+ sqrt_layer->setName ((util::node_info (n) + " _sqrt" ).c_str ());
109
+ auto sqrt_output = sqrt_layer->getOutput (0 );
110
+
111
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], sqrt_layer->getOutput (0 ));
112
+
113
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
114
+ return true ;
115
+ }});
73
116
74
117
} // namespace
75
118
} // namespace impl
0 commit comments