
Commit ee01f92: refactoring

1 parent 64381ae

4 files changed: +73 -78 lines
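In short: LayerNormHelper::CheckBroadcast becomes LayerNormHelper::CheckInputs, which validates the shapes and fills a new LayerNormParams struct (num_rows, norm_size, scale_size, bias_size, broadcast_param), and the gamma/beta offset computation that was duplicated in the CPU ComputeJob and the CUDA cuApplyLayerNorm kernel moves into a shared LAYER_NORM_SCALE_BIAS_OFFSET macro in layer_norm_helper.h.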

onnxruntime/core/providers/cpu/nn/layer_norm_helper.h (+53 -17)

@@ -5,6 +5,7 @@
 
 #include "core/framework/tensor_shape.h"
 #include "core/common/status.h"
+#include "core/common/narrow.h"
 
 namespace onnxruntime {
 
@@ -14,24 +15,57 @@ constexpr const char* kLayerNormInputShapeMismatchError =
 
 constexpr const char* kLayerNormInvalidSize = "Size of X.shape[axis:] must be larger than 1, got ";
 
+constexpr int64_t kLayerNormInvalidInput = -1;
+
+struct LayerNormParams {
+  int64_t num_rows;
+  int64_t norm_size;  // size per row
+  int64_t scale_size;
+  int64_t bias_size;
+  int64_t broadcast_param;
+};
+
+// When X shape is (B, S, ...), and task_idx is in the range of [0, B * S).
+// We support scale and bias shape like below:
+// When scale and bias shape is (1, 1, ...) or (...), value of broadcast_param is 0.
+// When scale and bias shape is (B, 1, ...), value of broadcast_param is S.
+// When scale and bias shape is (B, S, ...), value of broadcast_param is 1.
+// When scale and bias shape is (1, S, ...), value of broadcast_param is -S.
+// Below is a macro to compute the initial index for scale and bias data.
+#ifndef LAYER_NORM_SCALE_BIAS_OFFSET
+#define LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, row_idx, norm_size) \
+  ((broadcast_param == 0) ? 0                                             \
+                          : norm_size * (broadcast_param > 0 ? row_idx / broadcast_param : row_idx % (-broadcast_param)))
+#endif
+
 class LayerNormHelper {
  public:
-  static Status CheckBroadcast(const TensorShape& x_shape,
-                               const TensorShape& scale_shape,
-                               const TensorShape& bias_shape,
-                               bool has_bias,
-                               int64_t axis,
-                               int64_t& broadcast_param) {
-    broadcast_param = GetBroadcastParam(x_shape, scale_shape, has_bias ? &bias_shape : nullptr, axis);
-    if (broadcast_param == 0) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                             kLayerNormInputShapeMismatchError,
-                             " X.shape=", x_shape,
-                             " scale.shape=", scale_shape,
-                             " bias.shape=", bias_shape,
-                             " and axis=", axis);
-    }
+  static Status CheckInputs(const TensorShape& x_shape,
+                            const TensorShape& scale_shape,
+                            const TensorShape& bias_shape,
+                            bool has_bias,
+                            int64_t axis,
+                            LayerNormParams& params) {
+    params.num_rows = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
+    params.norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));
+    params.scale_size = scale_shape.Size();
+    params.bias_size = bias_shape.Size();
+    params.broadcast_param = 0;
 
+    if (params.norm_size <= 1) {
+      params.broadcast_param = kLayerNormInvalidInput;
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, kLayerNormInvalidSize, params.norm_size);
+    } else if (params.scale_size != params.norm_size || (has_bias && params.bias_size != params.scale_size)) {
+      params.broadcast_param = GetBroadcastParam(x_shape, scale_shape, has_bias ? &bias_shape : nullptr, axis);
+      if (params.broadcast_param == kLayerNormInvalidInput) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               kLayerNormInputShapeMismatchError,
+                               " X.shape=", x_shape,
+                               " scale.shape=", scale_shape,
+                               " bias.shape=", bias_shape,
+                               " and axis=", axis);
+      }
+    }
     return Status::OK();
   }
 
@@ -47,7 +81,8 @@ class LayerNormHelper {
         (bias_shape == nullptr || *bias_shape == scale_shape)) {
       for (size_t i = 2; i < x_shape.NumDimensions(); ++i) {
         if (x_shape.GetDims()[i] != scale_shape.GetDims()[i]) {
-          return 0;
+          // scale cannot be broadcasted to X. It is invalid input.
+          return kLayerNormInvalidInput;
         }
       }
 
@@ -69,7 +104,8 @@ class LayerNormHelper {
       }
     }
 
-    return 0;
+    // Other cases that are not supported.
+    return kLayerNormInvalidInput;
  }
};

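The broadcast_param encoding packs the four supported scale/bias layouts into one integer. To make it concrete, here is a small standalone sketch (my illustration, not part of the commit; the shape values are made up) that replays the macro for X of shape (B, S, H) = (2, 3, 4) and row 4, i.e. (b, s) = (1, 1):

// Standalone check of the broadcast_param encoding (illustration only).
// Each row of X (shape (B, S, H)) normalizes H elements; the macro maps a
// row index to the first element of the scale/bias data for that row.
#include <cassert>
#include <cstdint>

#define LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, row_idx, norm_size) \
  ((broadcast_param == 0) ? 0                                             \
                          : norm_size * (broadcast_param > 0 ? row_idx / broadcast_param : row_idx % (-broadcast_param)))

int main() {
  const int64_t S = 3, H = 4;  // X shape (B, S, H) = (2, 3, 4); norm_size = H
  const int64_t row = 4;       // row = b * S + s, so row 4 is (b, s) = (1, 1)

  int64_t bp = 0;  // scale shape (1, 1, H) or (H): every row shares one copy
  assert(LAYER_NORM_SCALE_BIAS_OFFSET(bp, row, H) == 0);
  bp = S;          // scale shape (B, 1, H): one copy per batch -> b * H = 4
  assert(LAYER_NORM_SCALE_BIAS_OFFSET(bp, row, H) == 4);
  bp = 1;          // scale shape (B, S, H): one copy per row -> row * H = 16
  assert(LAYER_NORM_SCALE_BIAS_OFFSET(bp, row, H) == 16);
  bp = -S;         // scale shape (1, S, H): one copy per position -> s * H = 4
  assert(LAYER_NORM_SCALE_BIAS_OFFSET(bp, row, H) == 4);
  return 0;
}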
onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc (+10 -36)

@@ -57,16 +57,8 @@ void ComputeJob(
     mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
   }
 
-  // When X shape is (B, S, ...), and task_idx is in the range of [0, B * S).
-  // We support scale and bias shape like below:
-  // When scale and bias shape is (1, 1, ...) or (...), value of broadcast_param is 0.
-  // When scale and bias shape is (B, 1, ...), value of broadcast_param is S.
-  // When scale and bias shape is (B, S, ...), value of broadcast_param is 1.
-  // When scale and bias shape is (1, S, ...), value of broadcast_param is -S.
-  // Here we compute the initial index for scale and bias data.
-  int64_t i = (broadcast_param == 0)
-                  ? 0
-                  : norm_size * (broadcast_param > 0 ? (task_idx / broadcast_param) : (task_idx % (-broadcast_param)));
+  // Compute the offset of gamma and beta to support broadcasting.
+  int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, task_idx, norm_size);
 
   for (int64_t h = 0; h < norm_size; h++, i++) {
     if (simplified) {
@@ -134,16 +126,8 @@ void ComputeJob(
     mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
   }
 
-  // When X shape is (B, S, ...), and task_idx is in the range of [0, B * S).
-  // We support scale and bias shape like below:
-  // When scale and bias shape is (1, 1, ...) or (...), value of broadcast_param is 0.
-  // When scale and bias shape is (B, 1, ...), value of broadcast_param is S.
-  // When scale and bias shape is (B, S, ...), value of broadcast_param is 1.
-  // When scale and bias shape is (1, S, ...), value of broadcast_param is -S.
-  // Here we compute the initial index for scale and bias data.
-  int64_t i = (broadcast_param == 0)
-                  ? 0
-                  : norm_size * (broadcast_param > 0 ? (task_idx / broadcast_param) : (task_idx % (-broadcast_param)));
+  // Compute the offset of gamma and beta to support broadcasting.
+  int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, task_idx, norm_size);
 
   for (size_t h = 0; h < num_elems; h++, i++) {
     if (simplified) {
@@ -283,38 +267,28 @@ Status LayerNormImpl::ComputeWithoutContext(
     float epsilon,
     bool simplified,
     AllocatorPtr alloc) const {
-  int64_t norm_count = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
-  int64_t norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));
-
-  int64_t scale_size = scale_shape.Size();
-  int64_t bias_size = bias_shape.Size();
-  int64_t broadcast_param = 0;
-
-  if (norm_size <= 1) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, kLayerNormInvalidSize, norm_size);
-  } else if (static_cast<int64_t>(scale_size) != norm_size || (bias_data && static_cast<int64_t>(bias_size) != norm_size)) {
-    ORT_RETURN_IF_ERROR(LayerNormHelper::CheckBroadcast(x_shape, scale_shape, bias_shape, bias_data != nullptr, axis, broadcast_param));
-  }
+  LayerNormParams params;
+  ORT_RETURN_IF_ERROR(LayerNormHelper::CheckInputs(x_shape, scale_shape, bias_shape, bias_data != nullptr, axis, params));
 
   IAllocatorUniquePtr<float> scale_fp32;
   IAllocatorUniquePtr<float> bias_fp32;
   if constexpr (std::is_same_v<T, MLFloat16>) {
     if (prepacked_scale_fp32_data_ == nullptr) {
-      const size_t num_elems = static_cast<size_t>(scale_size);
+      const size_t num_elems = static_cast<size_t>(params.scale_size);
       scale_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
       MlasConvertHalfToFloatBuffer(scale_data, scale_fp32.get(), num_elems);
     }
     if (prepacked_bias_fp32_data_ == nullptr && bias_data) {
-      const size_t num_elems = static_cast<size_t>(bias_size);
+      const size_t num_elems = static_cast<size_t>(params.bias_size);
       bias_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
       MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems);
     }
   }
 
   concurrency::ThreadPool::TryBatchParallelFor(
-      thread_pool, static_cast<int32_t>(norm_count),
+      thread_pool, static_cast<int32_t>(params.num_rows),
      [&](ptrdiff_t task_idx) {
-        ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size, broadcast_param,
+        ComputeJob(X_data, scale_data, bias_data, task_idx, params.norm_size, params.broadcast_param,
                    prepacked_scale_fp32_data_ ? prepacked_scale_fp32_data_.get() : scale_fp32.get(),
                    prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(),
                    epsilon, simplified, Y_data, mean_data, inv_std_dev_data, alloc);

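Both ComputeJob overloads now share this offset rule with the CUDA kernel. The sketch below (my own illustration, not ORT code: compute_row, the plain-float math, and the chosen shapes are all assumptions) shows the pattern a per-row job follows: normalize one row of norm_size elements, then read scale starting at the broadcast offset. With scale of shape (B, 1, H), broadcast_param = S, so every row of a batch reuses that batch's copy of scale.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Normalize row row_idx of X (row-major, norm_size elements per row) and
// apply a broadcast scale, mirroring the offset rule in layer_norm_helper.h.
void compute_row(const float* X, const float* scale, int64_t row_idx,
                 int64_t norm_size, int64_t broadcast_param, float epsilon,
                 float* Y) {
  const float* x = X + row_idx * norm_size;
  float* y = Y + row_idx * norm_size;

  double mean = 0.0, mean_square = 0.0;
  for (int64_t h = 0; h < norm_size; ++h) {
    mean += x[h];
    mean_square += static_cast<double>(x[h]) * x[h];
  }
  mean /= static_cast<double>(norm_size);
  const double inv_std_dev =
      1.0 / std::sqrt(mean_square / norm_size - mean * mean + epsilon);

  // Same computation as LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, row_idx, norm_size).
  int64_t i = (broadcast_param == 0)
                  ? 0
                  : norm_size * (broadcast_param > 0 ? row_idx / broadcast_param
                                                     : row_idx % (-broadcast_param));
  for (int64_t h = 0; h < norm_size; ++h, ++i) {
    y[h] = static_cast<float>((x[h] - mean) * inv_std_dev) * scale[i];
  }
}

int main() {
  const int64_t B = 2, S = 3, H = 4;
  std::vector<float> X(B * S * H, 1.0f), Y(B * S * H);
  std::vector<float> scale(B * H, 0.5f);  // scale shape (B, 1, H)
  for (int64_t row = 0; row < B * S; ++row) {
    compute_row(X.data(), scale.data(), row, H, /*broadcast_param=*/S, 1e-5f, Y.data());
  }
  std::printf("Y[0] = %f\n", Y[0]);
  return 0;
}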
onnxruntime/core/providers/cuda/nn/layer_norm.cc (+7 -14)

@@ -48,20 +48,11 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
   auto x_num_dims = x_shape.NumDimensions();
   const int64_t axis = HandleNegativeAxis(axis_, x_num_dims);
 
-  int n1 = gsl::narrow<int>(x_shape.SizeToDimension(axis));
-  int n2 = gsl::narrow<int>(x_shape.SizeFromDimension(axis));
-
   const TensorShape& scale_shape = scale->Shape();
-
   const TensorShape& bias_shape = bias_data ? bias->Shape() : TensorShape();
 
-  int64_t broadcast_param = 0;
-  if (n2 <= 1) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, kLayerNormInvalidSize, n2);
-  } else if (scale_shape.Size() != n2 || (bias_data && bias_shape.Size() != n2)) {
-    // Check if scale and bias can be broadcasted to X (only limited cases are supported).
-    ORT_RETURN_IF_ERROR(LayerNormHelper::CheckBroadcast(x_shape, scale_shape, bias_shape, bias_data != nullptr, axis, broadcast_param));
-  }
+  LayerNormParams params;
+  ORT_RETURN_IF_ERROR(LayerNormHelper::CheckInputs(x_shape, scale_shape, bias_shape, bias_data != nullptr, axis, params));
 
   // Outputs
   Tensor* Y = ctx->Output(0, x_shape);
@@ -97,9 +88,11 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
     return Status::OK();
   }
 
-  HostApplyLayerNorm<CudaT, CudaU, CudaV, simplified>(GetDeviceProp(), Stream(ctx), Y_data, mean_data, inv_var_data,
-                                                      X_data, n1, n2, epsilon_, scale_data, bias_data,
-                                                      gsl::narrow_cast<int>(broadcast_param));
+  HostApplyLayerNorm<CudaT, CudaU, CudaV, simplified>(
+      GetDeviceProp(), Stream(ctx), Y_data, mean_data, inv_var_data, X_data,
+      onnxruntime::narrow<int>(params.num_rows), onnxruntime::narrow<int>(params.norm_size), epsilon_,
+      scale_data, bias_data,
+      onnxruntime::narrow<int>(params.broadcast_param));
   CUDA_RETURN_IF_ERROR(cudaGetLastError());
   return Status::OK();
 }
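
One detail worth noting in the new call: broadcast_param used to be converted with gsl::narrow_cast, which is an unchecked static_cast, while the refactored code converts all three values with onnxruntime::narrow, which (like gsl::narrow) fails loudly when a value does not fit in an int instead of silently truncating. A minimal sketch of that idea (checked_narrow is a hypothetical stand-in, not ORT's implementation):

#include <cstdint>
#include <stdexcept>

// Checked narrowing: convert, then round-trip to detect lost information.
template <typename To, typename From>
To checked_narrow(From value) {
  const To result = static_cast<To>(value);
  if (static_cast<From>(result) != value) {
    throw std::runtime_error("narrowing conversion lost information");
  }
  return result;
}

int main() {
  (void)checked_narrow<int>(int64_t{42});         // fits, returns 42
  try {
    (void)checked_narrow<int>(int64_t{1} << 40);  // does not fit in int: throws
  } catch (const std::runtime_error&) {
  }
  return 0;
}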

onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu (+3 -11)

@@ -23,8 +23,8 @@
 /* Modifications Copyright (c) Microsoft. */
 
 #include "core/providers/cuda/cu_inc/common.cuh"
-
 #include "layer_norm_impl.h"
+#include "core/providers/cpu/nn/layer_norm_helper.h"
 
 namespace onnxruntime {
 namespace cuda {
@@ -355,16 +355,8 @@ __global__ void cuApplyLayerNorm(
   T* skip_input_bias_add_ovals = (skip_input_bias_add_output != nullptr) ? skip_input_bias_add_output + offset : nullptr;
   U c_inv_std_dev = rsqrt(sigma2 + epsilon);
 
-  // When X shape is (B, S, ...), and i1 is in the range of [0, B * S).
-  // We support scale and bias shape like below:
-  // When scale and bias shape is (1, 1, ...) or (...), value of broadcast_param is 0.
-  // When scale and bias shape is (B, 1, ...), value of broadcast_param is S.
-  // When scale and bias shape is (B, S, ...), value of broadcast_param is 1.
-  // When scale and bias shape is (1, S, ...), value of broadcast_param is -S.
-  // Here we compute the offset of gamma and beta (assuming they have same shape) to support broadcasting.
-  int gamma_beta_offset = (broadcast_param == 0)
-                              ? 0
-                              : n2 * (broadcast_param > 0 ? (i1 / broadcast_param) : (i1 % (-broadcast_param)));
+  // Compute the offset of gamma and beta to support broadcasting.
+  int gamma_beta_offset = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, i1, n2);
 
   const int numx = blockDim.x * blockDim.y;
   const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
