@@ -2,6 +2,7 @@
 // Licensed under the MIT License.

 #include "layer_norm_impl.h"
+#include "layer_norm_helper.h"

 #include "core/common/safeint.h"
 #include "core/framework/tensor.h"
@@ -24,6 +25,7 @@ void ComputeJob(
     const T* bias_data,
     const ptrdiff_t task_idx,
     const int64_t norm_size,
+    const int64_t broadcast_param,
     const float* scale_float_ptr,
     const float* bias_float_ptr,
     float epsilon,
@@ -55,13 +57,16 @@ void ComputeJob(
     mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
   }

-  for (int64_t h = 0; h < norm_size; h++) {
+  // Compute the offset of gamma and beta to support broadcasting.
+  int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, task_idx, norm_size);
+
+  for (int64_t h = 0; h < norm_size; h++, i++) {
     if (simplified) {
-      p_output[h] = p_output[h] / mean_square * scale_data[h];
+      p_output[h] = p_output[h] / mean_square * scale_data[i];
     } else if (nullptr == bias_data) {
-      p_output[h] = (p_output[h] - mean) / mean_square * scale_data[h];
+      p_output[h] = (p_output[h] - mean) / mean_square * scale_data[i];
     } else {
-      p_output[h] = (p_output[h] - mean) / mean_square * scale_data[h] + bias_data[h];
+      p_output[h] = (p_output[h] - mean) / mean_square * scale_data[i] + bias_data[i];
     }
   }

@@ -82,6 +87,7 @@ void ComputeJob(
     const MLFloat16* bias_data,
     const ptrdiff_t task_idx,
     const int64_t norm_size,
+    const int64_t broadcast_param,
     const float* scale_float_ptr,
     const float* bias_float_ptr,
     float epsilon,
@@ -120,13 +126,16 @@ void ComputeJob(
     mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
   }

-  for (size_t h = 0; h < num_elems; h++) {
+  // Compute the offset of gamma and beta to support broadcasting.
+  int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, task_idx, norm_size);
+
+  for (size_t h = 0; h < num_elems; h++, i++) {
     if (simplified) {
-      output_float_ptr[h] = output_float_ptr[h] / mean_square * scale_float_ptr[h];
+      output_float_ptr[h] = output_float_ptr[h] / mean_square * scale_float_ptr[i];
     } else if (nullptr == bias_float_ptr) {
-      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h];
+      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[i];
     } else {
-      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h] + bias_float_ptr[h];
+      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[i] + bias_float_ptr[i];
     }
   }

@@ -161,9 +170,7 @@ LayerNormImpl::LayerNormImpl(const OpKernelInfo& op_kernel_info, bool simplified
       simplified_{simplified},
       contrib_op_{contrib_op},
       prepacked_scale_fp32_data_(nullptr),
-      prepacked_scale_fp32_size_(0),
-      prepacked_bias_fp32_data_(nullptr),
-      prepacked_bias_fp32_size_(0) {
+      prepacked_bias_fp32_data_(nullptr) {
   ORT_ENFORCE(op_kernel_info.GetAttr("axis", &axis_).IsOK());
   ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
 }
@@ -179,8 +186,8 @@ Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, flo
   const T* bias_data = (simplified || nullptr == bias) ? nullptr : bias->Data<T>();

   const TensorShape& x_shape = X->Shape();
-  size_t scale_size = scale ? static_cast<size_t>(scale->Shape().Size()) : prepacked_scale_fp32_size_;
-  size_t bias_size = bias ? static_cast<size_t>(bias->Shape().Size()) : prepacked_bias_fp32_size_;
+  const TensorShape& scale_shape = scale ? scale->Shape() : prepacked_scale_fp32_shape_;
+  const TensorShape& bias_shape = bias ? bias->Shape() : prepacked_bias_fp32_shape_;
   Tensor* Y = p_ctx->Output(0, x_shape);
   T* Y_data = Y->MutableData<T>();

@@ -215,7 +222,7 @@ Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, flo

   AllocatorPtr alloc;
   ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));
-  return ComputeWithoutContext<T, U>(X_data, x_shape, scale_data, scale_size, bias_data, bias_size, Y_data, mean_data,
+  return ComputeWithoutContext<T, U>(X_data, x_shape, scale_data, scale_shape, bias_data, bias_shape, Y_data, mean_data,
                                      inv_std_dev_data, thread_pool, axis, epsilon, simplified, alloc);
 }

@@ -234,10 +241,10 @@ Status LayerNormImpl::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr

   is_packed = false;
   if (input_idx == 1) {  // scale
-    prepacked_scale_fp32_size_ = static_cast<size_t>(tensor.Shape().Size());
+    prepacked_scale_fp32_shape_ = tensor.Shape();
     ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_scale_fp32_data_, is_packed);
   } else if (input_idx == 2) {  // bias
-    prepacked_bias_fp32_size_ = static_cast<size_t>(tensor.Shape().Size());
+    prepacked_bias_fp32_shape_ = tensor.Shape();
     ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed);
   }

@@ -249,9 +256,9 @@ Status LayerNormImpl::ComputeWithoutContext(
     const T* X_data,
     const TensorShape& x_shape,
     const T* scale_data,
-    size_t scale_size,
+    const TensorShape& scale_shape,
     const T* bias_data,
-    size_t bias_size,
+    const TensorShape& bias_shape,
     T* Y_data,
     U* mean_data,
     U* inv_std_dev_data,
@@ -260,35 +267,28 @@ Status LayerNormImpl::ComputeWithoutContext(
     float epsilon,
     bool simplified,
     AllocatorPtr alloc) const {
-  int64_t norm_count = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
-  int64_t norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));
-
-  if (static_cast<int64_t>(scale_size) != norm_size || (bias_data && static_cast<int64_t>(bias_size) != norm_size)) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "Size of X.shape()[axis:] == ", norm_size,
-                           ". Size of scale and bias (if provided) must match this. Got scale size of ",
-                           scale_size, " and bias size of ", bias_size);
-  }
+  LayerNormParams params;
+  ORT_RETURN_IF_ERROR(LayerNormHelper::CheckInputs(x_shape, scale_shape, bias_shape, bias_data != nullptr, axis, params));

   IAllocatorUniquePtr<float> scale_fp32;
   IAllocatorUniquePtr<float> bias_fp32;
   if constexpr (std::is_same_v<T, MLFloat16>) {
     if (prepacked_scale_fp32_data_ == nullptr) {
-      const size_t num_elems = static_cast<size_t>(norm_size);
+      const size_t num_elems = static_cast<size_t>(params.scale_size);
       scale_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
       MlasConvertHalfToFloatBuffer(scale_data, scale_fp32.get(), num_elems);
     }
     if (prepacked_bias_fp32_data_ == nullptr && bias_data) {
-      const size_t num_elems = static_cast<size_t>(norm_size);
+      const size_t num_elems = static_cast<size_t>(params.bias_size);
       bias_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
       MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems);
     }
   }

   concurrency::ThreadPool::TryBatchParallelFor(
-      thread_pool, static_cast<int32_t>(norm_count),
+      thread_pool, static_cast<int32_t>(params.num_rows),
       [&](ptrdiff_t task_idx) {
-        ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size,
+        ComputeJob(X_data, scale_data, bias_data, task_idx, params.norm_size, params.broadcast_param,
                    prepacked_scale_fp32_data_ ? prepacked_scale_fp32_data_.get() : scale_fp32.get(),
                    prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(),
                    epsilon, simplified, Y_data, mean_data, inv_std_dev_data, alloc);
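Note: the new broadcast_param argument and the LAYER_NORM_SCALE_BIAS_OFFSET macro come from layer_norm_helper.h, which is not part of this diff. As a rough sketch of the idea only (the function name and the exact encoding below are assumptions, not the actual implementation), the per-row offset into a broadcast scale/bias buffer could look like this:

#include <cstdint>

// Illustrative sketch only; the real macro lives in layer_norm_helper.h and may differ.
// Assumed convention: broadcast_param == 0 means scale/bias hold exactly norm_size
// elements shared by every row, so the offset is always 0; a positive value is assumed
// to mean that each group of broadcast_param consecutive rows shares one scale/bias row.
inline int64_t ScaleBiasOffsetSketch(int64_t broadcast_param, int64_t row_idx, int64_t norm_size) {
  if (broadcast_param == 0) {
    return 0;
  }
  return (row_idx / broadcast_param) * norm_size;
}

With an offset of this form, the loops above can keep indexing scale_data[i] and bias_data[i] contiguously while i starts at a different row of the (smaller) scale/bias buffers for each task_idx.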
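Similarly, LayerNormParams and LayerNormHelper::CheckInputs are declared in the new header rather than here. Judging only from the members this diff reads back, a plausible shape for the struct (a hypothetical reconstruction, not the actual declaration) is:

#include <cstdint>

// Hypothetical reconstruction from usage in this diff; the field names match the
// params.* accesses above, but types, extra members, and semantics are assumptions.
struct LayerNormParams {
  int64_t num_rows;         // rows to normalize: product of dims before the axis
  int64_t norm_size;        // elements per row: product of dims from the axis onward
  int64_t scale_size;       // element count of the (possibly broadcast) scale input
  int64_t bias_size;        // element count of the (possibly broadcast) bias input
  int64_t broadcast_param;  // 0 when scale/bias match norm_size exactly
};

CheckInputs would then be expected to validate scale_shape and bias_shape against x_shape, fill these fields, and return a non-OK Status for unsupported shapes, replacing the inline validation removed from ComputeWithoutContext.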