@@ -57,16 +57,8 @@ void ComputeJob(
57
57
mean_square = sqrt (mean_square / norm_size - mean * mean + epsilon);
58
58
}
59
59
60
- // When X shape is (B, S, ...), and task_idx is in the range of [0, B * S).
61
- // We support scale and bias shape like below:
62
- // When scale and bias shape is (1, 1, ...) or (...), value of broadcast_param is 0.
63
- // When scale and bias shape is (B, 1, ...), value of broadcast_param is S.
64
- // When scale and bias shape is (B, S, ...), value of broadcast_param is 1.
65
- // When scale and bias shape is (1, S, ...), value of broadcast_param is -S.
66
- // Here we compute the initial index for scale and bias data.
67
- int64_t i = (broadcast_param == 0 )
68
- ? 0
69
- : norm_size * (broadcast_param > 0 ? (task_idx / broadcast_param) : (task_idx % (-broadcast_param)));
60
+ // Compute the offset of gamma and beta to support broadcasting.
61
+ int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET (broadcast_param, task_idx, norm_size);
70
62
71
63
for (int64_t h = 0 ; h < norm_size; h++, i++) {
72
64
if (simplified) {
@@ -134,16 +126,8 @@ void ComputeJob(
134
126
mean_square = sqrt (mean_square / norm_size - mean * mean + epsilon);
135
127
}
136
128
137
- // When X shape is (B, S, ...), and task_idx is in the range of [0, B * S).
138
- // We support scale and bias shape like below:
139
- // When scale and bias shape is (1, 1, ...) or (...), value of broadcast_param is 0.
140
- // When scale and bias shape is (B, 1, ...), value of broadcast_param is S.
141
- // When scale and bias shape is (B, S, ...), value of broadcast_param is 1.
142
- // When scale and bias shape is (1, S, ...), value of broadcast_param is -S.
143
- // Here we compute the initial index for scale and bias data.
144
- int64_t i = (broadcast_param == 0 )
145
- ? 0
146
- : norm_size * (broadcast_param > 0 ? (task_idx / broadcast_param) : (task_idx % (-broadcast_param)));
129
+ // Compute the offset of gamma and beta to support broadcasting.
130
+ int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET (broadcast_param, task_idx, norm_size);
147
131
148
132
for (size_t h = 0 ; h < num_elems; h++, i++) {
149
133
if (simplified) {
@@ -283,38 +267,28 @@ Status LayerNormImpl::ComputeWithoutContext(
283
267
float epsilon,
284
268
bool simplified,
285
269
AllocatorPtr alloc) const {
286
- int64_t norm_count = x_shape.SizeToDimension (onnxruntime::narrow<size_t >(axis));
287
- int64_t norm_size = x_shape.SizeFromDimension (onnxruntime::narrow<size_t >(axis));
288
-
289
- int64_t scale_size = scale_shape.Size ();
290
- int64_t bias_size = bias_shape.Size ();
291
- int64_t broadcast_param = 0 ;
292
-
293
- if (norm_size <= 1 ) {
294
- return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, kLayerNormInvalidSize , norm_size);
295
- } else if (static_cast <int64_t >(scale_size) != norm_size || (bias_data && static_cast <int64_t >(bias_size) != norm_size)) {
296
- ORT_RETURN_IF_ERROR (LayerNormHelper::CheckBroadcast (x_shape, scale_shape, bias_shape, bias_data != nullptr , axis, broadcast_param));
297
- }
270
+ LayerNormParams params;
271
+ ORT_RETURN_IF_ERROR (LayerNormHelper::CheckInputs (x_shape, scale_shape, bias_shape, bias_data != nullptr , axis, params));
298
272
299
273
IAllocatorUniquePtr<float > scale_fp32;
300
274
IAllocatorUniquePtr<float > bias_fp32;
301
275
if constexpr (std::is_same_v<T, MLFloat16>) {
302
276
if (prepacked_scale_fp32_data_ == nullptr ) {
303
- const size_t num_elems = static_cast <size_t >(scale_size);
277
+ const size_t num_elems = static_cast <size_t >(params. scale_size );
304
278
scale_fp32 = IAllocator::MakeUniquePtr<float >(alloc, num_elems);
305
279
MlasConvertHalfToFloatBuffer (scale_data, scale_fp32.get (), num_elems);
306
280
}
307
281
if (prepacked_bias_fp32_data_ == nullptr && bias_data) {
308
- const size_t num_elems = static_cast <size_t >(bias_size);
282
+ const size_t num_elems = static_cast <size_t >(params. bias_size );
309
283
bias_fp32 = IAllocator::MakeUniquePtr<float >(alloc, num_elems);
310
284
MlasConvertHalfToFloatBuffer (bias_data, bias_fp32.get (), num_elems);
311
285
}
312
286
}
313
287
314
288
concurrency::ThreadPool::TryBatchParallelFor (
315
- thread_pool, static_cast <int32_t >(norm_count ),
289
+ thread_pool, static_cast <int32_t >(params. num_rows ),
316
290
[&](ptrdiff_t task_idx) {
317
- ComputeJob (X_data, scale_data, bias_data, task_idx, norm_size, broadcast_param,
291
+ ComputeJob (X_data, scale_data, bias_data, task_idx, params. norm_size , params. broadcast_param ,
318
292
prepacked_scale_fp32_data_ ? prepacked_scale_fp32_data_.get () : scale_fp32.get (),
319
293
prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get () : bias_fp32.get (),
320
294
epsilon, simplified, Y_data, mean_data, inv_std_dev_data, alloc);
0 commit comments