
Commit 2352296

tianleiwu authored and guschmue committed
LayerNormalization broadcast (limited support for axis=2) (#23297)
### Description

The spec of LayerNormalization supports broadcasting: tensors Scale and B should be unidirectionally broadcastable to tensor X (https://onnx.ai/onnx/operators/onnx__LayerNormalization.html). However, the current implementation only allows the scale and bias size to be X.shape()[axis:].

Example of input tensors normalized with axis=2:

| X shape | Scale shape | B shape | Before | After |
| - | - | - | - | - |
| (B, S, D) | (D) | (D) | Supported | Supported |
| (B, S, D) | (1, 1, D) | (1, 1, D) | Supported | Supported |
| (B, S, D) | (B, 1, D) | (B, 1, D) | Not Supported | Supported |
| (B, S, D) | (1, S, D) | (1, S, D) | Not Supported | Supported |
| (B, S, D) | (B, S, D) | (B, S, D) | Not Supported | Supported |

Here we add limited support: axis is 2; scale and bias have the same shape; and scale, bias, and X have the same number of dimensions. This covers common use cases in LLMs and vision models.

### Motivation and Context

Support Stable Diffusion 3.x and Flux models.
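To make the supported-shapes rule above concrete, here is a minimal standalone sketch (illustration only, not ORT code; `infer_broadcast_param` is a hypothetical helper that mirrors the mapping this commit implements in `layer_norm_helper.h` below, specialized to 3-D inputs) showing how each scale/bias shape maps to the internal `broadcast_param` encoding for `X` of shape `(B, S, D)` and `axis=2`:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the shape -> broadcast_param mapping
// added in layer_norm_helper.h, specialized to 3-D inputs.
// Returns -1 (invalid) for unsupported shapes.
int64_t infer_broadcast_param(const std::vector<int64_t>& x_dims,
                              const std::vector<int64_t>& scale_dims) {
  const int64_t B = x_dims[0], S = x_dims[1], D = x_dims[2];
  if (scale_dims.size() == 1 && scale_dims[0] == D) return 0;  // (D)
  if (scale_dims.size() != 3 || scale_dims[2] != D) return -1; // last dim must match
  if (scale_dims[0] == 1 && scale_dims[1] == 1) return 0;      // (1, 1, D)
  if (scale_dims[0] == B && scale_dims[1] == S) return 1;      // (B, S, D)
  if (scale_dims[0] == B && scale_dims[1] == 1) return S;      // (B, 1, D)
  if (scale_dims[0] == 1 && scale_dims[1] == S) return -S;     // (1, S, D)
  return -1;
}

int main() {
  const std::vector<int64_t> x = {4, 8, 16};           // B=4, S=8, D=16
  assert(infer_broadcast_param(x, {16}) == 0);         // already supported
  assert(infer_broadcast_param(x, {1, 1, 16}) == 0);   // already supported
  assert(infer_broadcast_param(x, {4, 1, 16}) == 8);   // newly supported: S
  assert(infer_broadcast_param(x, {1, 8, 16}) == -8);  // newly supported: -S
  assert(infer_broadcast_param(x, {4, 8, 16}) == 1);   // newly supported
  assert(infer_broadcast_param(x, {2, 8, 16}) == -1);  // rejected
  return 0;
}
```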
1 parent d0db863 · commit 2352296

File tree: 9 files changed, +333 −66 lines

onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc

+1

@@ -101,6 +101,7 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
     (double)epsilon_,                                                               // epsilon
     reinterpret_cast<const CudaT*>(gamma->Data<T>()),                               // gamma
     (beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr,  // beta
+    0,                                                                              // no broadcast for gamma/beta
     reinterpret_cast<const CudaT*>(skip->Data<T>()),                                // skip or residual to add
     (bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,  // bias to add
     sum_output != nullptr ? reinterpret_cast<CudaT*>(sum_output->MutableData<T>()) : nullptr);
onnxruntime/core/providers/cpu/nn/layer_norm_helper.h (new file)

+116
@@ -0,0 +1,116 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/framework/tensor_shape.h"
+#include "core/common/status.h"
+#include "core/common/narrow.h"
+
+namespace onnxruntime {
+
+constexpr const char* kLayerNormInputShapeMismatchError =
+    "Size of scale and bias (if provided) must match X.shape[axis:], "
+    "or scale and bias (with same shape) can be broadcasted to X when axis is 2.";
+
+constexpr const char* kLayerNormInvalidSize = "Size of X.shape[axis:] must be larger than 1, got ";
+
+constexpr int64_t kLayerNormInvalidInput = -1;
+
+struct LayerNormParams {
+  int64_t num_rows;
+  int64_t norm_size;  // size per row
+  int64_t scale_size;
+  int64_t bias_size;
+  int64_t broadcast_param;
+};
+
+// We support broadcasting for axis=2, where the first two dimensions are rows, and the rest are columns.
+// When X shape is (B, S, ...), and x_row (index of one row in X) is in the range of [0, B * S).
+// We support scale and bias shape like below:
+// When scale and bias shape is (1, 1, ...) or (...), value of broadcast_param is 0.
+// When scale and bias shape is (B, 1, ...), value of broadcast_param is S.
+// When scale and bias shape is (B, S, ...), value of broadcast_param is 1.
+// When scale and bias shape is (1, S, ...), value of broadcast_param is -S.
+
+// Below is a macro to compute the offset for scale and bias data for a row of X.
+#ifndef LAYER_NORM_SCALE_BIAS_OFFSET
+#define LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, x_row, norm_size) \
+  ((broadcast_param == 0) ? 0                                           \
+                          : norm_size * (broadcast_param > 0 ? x_row / broadcast_param : x_row % (-broadcast_param)))
+#endif
+
+class LayerNormHelper {
+ public:
+  static Status CheckInputs(const TensorShape& x_shape,
+                            const TensorShape& scale_shape,
+                            const TensorShape& bias_shape,
+                            bool has_bias,
+                            int64_t axis,
+                            LayerNormParams& params) {
+    params.num_rows = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
+    params.norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));
+    params.scale_size = scale_shape.Size();
+    params.bias_size = bias_shape.Size();
+    params.broadcast_param = 0;
+
+    if (params.norm_size <= 1) {
+      params.broadcast_param = kLayerNormInvalidInput;
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, kLayerNormInvalidSize, params.norm_size);
+    } else if (params.scale_size != params.norm_size || (has_bias && params.bias_size != params.scale_size)) {
+      params.broadcast_param = GetBroadcastParam(x_shape, scale_shape, has_bias ? &bias_shape : nullptr, axis);
+      if (params.broadcast_param == kLayerNormInvalidInput) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               kLayerNormInputShapeMismatchError,
+                               " X.shape=", x_shape,
+                               " scale.shape=", scale_shape,
+                               " bias.shape=", bias_shape,
+                               " and axis=", axis);
+      }
+    }
+    return Status::OK();
+  }
+
+ private:
+  static int64_t GetBroadcastParam(const TensorShape& x_shape,
+                                   const TensorShape& scale_shape,
+                                   const TensorShape* bias_shape,
+                                   int64_t axis) {
+    // Note that when size of scale and bias is norm_size, it won't enter this function (see CheckInputs).
+
+    // X shape is (B, S, ...)
+    if (axis == 2 &&
+        x_shape.NumDimensions() >= 3 &&
+        x_shape.NumDimensions() == scale_shape.NumDimensions() &&
+        (bias_shape == nullptr || *bias_shape == scale_shape)) {
+      for (size_t i = 2; i < x_shape.NumDimensions(); ++i) {
+        if (x_shape.GetDims()[i] != scale_shape.GetDims()[i]) {
+          // scale cannot be broadcasted to X. It is invalid input.
+          return kLayerNormInvalidInput;
+        }
+      }
+
+      if (x_shape.GetDims()[0] == scale_shape.GetDims()[0]) {
+        // scale and bias shape is (B, S, ...).
+        if (x_shape.GetDims()[1] == scale_shape.GetDims()[1]) {
+          return 1;
+        }
+
+        // scale and bias shape is (B, 1, ...), returns S
+        if (scale_shape.GetDims()[1] == 1) {
+          return x_shape.GetDims()[1];
+        }
+      } else if (scale_shape.GetDims()[0] == 1) {
+        // scale and bias shape is (1, S, ...), returns -S
+        if (x_shape.GetDims()[1] == scale_shape.GetDims()[1]) {
+          return -(x_shape.GetDims()[1]);
+        }
+      }
+    }
+
+    // Other cases that are not supported.
+    return kLayerNormInvalidInput;
+  }
+};
+
+}  // namespace onnxruntime
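To sanity-check the encoding, the following self-contained snippet evaluates `LAYER_NORM_SCALE_BIAS_OFFSET` for each supported case. The macro body is copied verbatim from the header above; the shapes and expected offsets are worked out by hand for B=2, S=3, norm_size=4:

```cpp
#include <cassert>
#include <cstdint>

// Copied verbatim from layer_norm_helper.h above.
#define LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, x_row, norm_size) \
  ((broadcast_param == 0) ? 0                                           \
                          : norm_size * (broadcast_param > 0 ? x_row / broadcast_param : x_row % (-broadcast_param)))

int main() {
  // X shape (B=2, S=3, D=4): row index x_row runs over [0, 6), norm_size = 4.
  const int64_t S = 3, norm_size = 4, x_row = 4;  // row 4 => batch 1, sequence position 1

  const int64_t bp_same = 0;   // scale/bias (D) or (1, 1, D): every row reads offset 0
  const int64_t bp_full = 1;   // scale/bias (B, S, D): each row has its own block
  const int64_t bp_batch = S;  // scale/bias (B, 1, D): rows in one batch share a block
  const int64_t bp_seq = -S;   // scale/bias (1, S, D): rows with the same s share a block

  assert(LAYER_NORM_SCALE_BIAS_OFFSET(bp_same, x_row, norm_size) == 0);
  assert(LAYER_NORM_SCALE_BIAS_OFFSET(bp_full, x_row, norm_size) == 16);  // (4 / 1) * norm_size
  assert(LAYER_NORM_SCALE_BIAS_OFFSET(bp_batch, x_row, norm_size) == 4);  // (4 / 3) * norm_size
  assert(LAYER_NORM_SCALE_BIAS_OFFSET(bp_seq, x_row, norm_size) == 4);    // (4 % 3) * norm_size
  return 0;
}
```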

onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc

+31 −31
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 
 #include "layer_norm_impl.h"
+#include "layer_norm_helper.h"
 
 #include "core/common/safeint.h"
 #include "core/framework/tensor.h"
@@ -24,6 +25,7 @@ void ComputeJob(
     const T* bias_data,
     const ptrdiff_t task_idx,
     const int64_t norm_size,
+    const int64_t broadcast_param,
     const float* scale_float_ptr,
     const float* bias_float_ptr,
     float epsilon,
@@ -55,13 +57,16 @@
     mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
   }
 
-  for (int64_t h = 0; h < norm_size; h++) {
+  // Compute the offset of gamma and beta to support broadcasting.
+  int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, task_idx, norm_size);
+
+  for (int64_t h = 0; h < norm_size; h++, i++) {
     if (simplified) {
-      p_output[h] = p_output[h] / mean_square * scale_data[h];
+      p_output[h] = p_output[h] / mean_square * scale_data[i];
     } else if (nullptr == bias_data) {
-      p_output[h] = (p_output[h] - mean) / mean_square * scale_data[h];
+      p_output[h] = (p_output[h] - mean) / mean_square * scale_data[i];
     } else {
-      p_output[h] = (p_output[h] - mean) / mean_square * scale_data[h] + bias_data[h];
+      p_output[h] = (p_output[h] - mean) / mean_square * scale_data[i] + bias_data[i];
     }
   }
 
@@ -82,6 +87,7 @@ void ComputeJob(
     const MLFloat16* bias_data,
     const ptrdiff_t task_idx,
     const int64_t norm_size,
+    const int64_t broadcast_param,
     const float* scale_float_ptr,
     const float* bias_float_ptr,
     float epsilon,
@@ -120,13 +126,16 @@
     mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
   }
 
-  for (size_t h = 0; h < num_elems; h++) {
+  // Compute the offset of gamma and beta to support broadcasting.
+  int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, task_idx, norm_size);
+
+  for (size_t h = 0; h < num_elems; h++, i++) {
     if (simplified) {
-      output_float_ptr[h] = output_float_ptr[h] / mean_square * scale_float_ptr[h];
+      output_float_ptr[h] = output_float_ptr[h] / mean_square * scale_float_ptr[i];
     } else if (nullptr == bias_float_ptr) {
-      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h];
+      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[i];
     } else {
-      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h] + bias_float_ptr[h];
+      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[i] + bias_float_ptr[i];
     }
   }
 
@@ -161,9 +170,7 @@ LayerNormImpl::LayerNormImpl(const OpKernelInfo& op_kernel_info, bool simplified
       simplified_{simplified},
      contrib_op_{contrib_op},
      prepacked_scale_fp32_data_(nullptr),
-      prepacked_scale_fp32_size_(0),
-      prepacked_bias_fp32_data_(nullptr),
-      prepacked_bias_fp32_size_(0) {
+      prepacked_bias_fp32_data_(nullptr) {
   ORT_ENFORCE(op_kernel_info.GetAttr("axis", &axis_).IsOK());
   ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
 }
@@ -179,8 +186,8 @@ Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, flo
   const T* bias_data = (simplified || nullptr == bias) ? nullptr : bias->Data<T>();
 
   const TensorShape& x_shape = X->Shape();
-  size_t scale_size = scale ? static_cast<size_t>(scale->Shape().Size()) : prepacked_scale_fp32_size_;
-  size_t bias_size = bias ? static_cast<size_t>(bias->Shape().Size()) : prepacked_bias_fp32_size_;
+  const TensorShape& scale_shape = scale ? scale->Shape() : prepacked_scale_fp32_shape_;
+  const TensorShape& bias_shape = bias ? bias->Shape() : prepacked_bias_fp32_shape_;
   Tensor* Y = p_ctx->Output(0, x_shape);
   T* Y_data = Y->MutableData<T>();
 
@@ -215,7 +222,7 @@ Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, flo
 
   AllocatorPtr alloc;
   ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));
-  return ComputeWithoutContext<T, U>(X_data, x_shape, scale_data, scale_size, bias_data, bias_size, Y_data, mean_data,
+  return ComputeWithoutContext<T, U>(X_data, x_shape, scale_data, scale_shape, bias_data, bias_shape, Y_data, mean_data,
                                      inv_std_dev_data, thread_pool, axis, epsilon, simplified, alloc);
 }
 
@@ -234,10 +241,10 @@ Status LayerNormImpl::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr
 
   is_packed = false;
   if (input_idx == 1) {  // scale
-    prepacked_scale_fp32_size_ = static_cast<size_t>(tensor.Shape().Size());
+    prepacked_scale_fp32_shape_ = tensor.Shape();
     ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_scale_fp32_data_, is_packed);
   } else if (input_idx == 2) {  // bias
-    prepacked_bias_fp32_size_ = static_cast<size_t>(tensor.Shape().Size());
+    prepacked_bias_fp32_shape_ = tensor.Shape();
     ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed);
   }
 
@@ -249,9 +256,9 @@ Status LayerNormImpl::ComputeWithoutContext(
     const T* X_data,
     const TensorShape& x_shape,
     const T* scale_data,
-    size_t scale_size,
+    const TensorShape& scale_shape,
     const T* bias_data,
-    size_t bias_size,
+    const TensorShape& bias_shape,
     T* Y_data,
     U* mean_data,
     U* inv_std_dev_data,
@@ -260,35 +267,28 @@ Status LayerNormImpl::ComputeWithoutContext(
     float epsilon,
     bool simplified,
     AllocatorPtr alloc) const {
-  int64_t norm_count = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
-  int64_t norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));
-
-  if (static_cast<int64_t>(scale_size) != norm_size || (bias_data && static_cast<int64_t>(bias_size) != norm_size)) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "Size of X.shape()[axis:] == ", norm_size,
-                           ". Size of scale and bias (if provided) must match this. Got scale size of ",
-                           scale_size, " and bias size of ", bias_size);
-  }
+  LayerNormParams params;
+  ORT_RETURN_IF_ERROR(LayerNormHelper::CheckInputs(x_shape, scale_shape, bias_shape, bias_data != nullptr, axis, params));
 
   IAllocatorUniquePtr<float> scale_fp32;
   IAllocatorUniquePtr<float> bias_fp32;
   if constexpr (std::is_same_v<T, MLFloat16>) {
     if (prepacked_scale_fp32_data_ == nullptr) {
-      const size_t num_elems = static_cast<size_t>(norm_size);
+      const size_t num_elems = static_cast<size_t>(params.scale_size);
       scale_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
       MlasConvertHalfToFloatBuffer(scale_data, scale_fp32.get(), num_elems);
     }
     if (prepacked_bias_fp32_data_ == nullptr && bias_data) {
-      const size_t num_elems = static_cast<size_t>(norm_size);
+      const size_t num_elems = static_cast<size_t>(params.bias_size);
       bias_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
       MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems);
     }
   }
 
   concurrency::ThreadPool::TryBatchParallelFor(
-      thread_pool, static_cast<int32_t>(norm_count),
+      thread_pool, static_cast<int32_t>(params.num_rows),
      [&](ptrdiff_t task_idx) {
-        ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size,
+        ComputeJob(X_data, scale_data, bias_data, task_idx, params.norm_size, params.broadcast_param,
                    prepacked_scale_fp32_data_ ? prepacked_scale_fp32_data_.get() : scale_fp32.get(),
                    prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(),
                    epsilon, simplified, Y_data, mean_data, inv_std_dev_data, alloc);
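For readers following the CPU path: the snippet below is a self-contained sketch of the per-row pattern that `ComputeJob` now follows (float-only, bias omitted, no fp16 conversion; `layer_norm_row` is a hypothetical stand-in, not the actual kernel), showing how `broadcast_param` selects each row's scale block:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Copied verbatim from layer_norm_helper.h above.
#define LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, x_row, norm_size) \
  ((broadcast_param == 0) ? 0                                           \
                          : norm_size * (broadcast_param > 0 ? x_row / broadcast_param : x_row % (-broadcast_param)))

// Hypothetical stand-in for ComputeJob: normalizes row `task_idx` of X and
// applies a (possibly broadcast) scale, mirroring the indexing in the diff.
void layer_norm_row(const float* x, const float* scale, float* y,
                    int64_t task_idx, int64_t norm_size,
                    int64_t broadcast_param, float epsilon) {
  const float* p_input = x + task_idx * norm_size;
  float* p_output = y + task_idx * norm_size;

  float mean = 0.0f, mean_square = 0.0f;
  for (int64_t h = 0; h < norm_size; ++h) {
    mean += p_input[h];
    mean_square += p_input[h] * p_input[h];
  }
  mean = mean / norm_size;
  mean_square = std::sqrt(mean_square / norm_size - mean * mean + epsilon);

  // Offset of this row's scale block; advances together with h, as above.
  int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, task_idx, norm_size);
  for (int64_t h = 0; h < norm_size; ++h, ++i) {
    p_output[h] = (p_input[h] - mean) / mean_square * scale[i];
  }
}

int main() {
  // X shape (B=2, S=2, D=4) with scale shape (B, 1, D) => broadcast_param = S.
  const int64_t B = 2, S = 2, D = 4, broadcast_param = S;
  std::vector<float> x(B * S * D, 1.0f);
  x[0] = 2.0f;  // make row 0 non-constant so normalization does something
  std::vector<float> scale = {1, 2, 3, 4,   // scale block for batch 0
                              5, 6, 7, 8};  // scale block for batch 1
  std::vector<float> y(x.size());
  for (int64_t row = 0; row < B * S; ++row) {
    layer_norm_row(x.data(), scale.data(), y.data(), row, D, broadcast_param, 1e-5f);
  }
  // Rows 0-1 (batch 0) read scale[0..3]; rows 2-3 (batch 1) read scale[4..7].
  assert(LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, int64_t{1}, D) == 0);
  assert(LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, int64_t{2}, D) == D);
  return 0;
}
```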

onnxruntime/core/providers/cpu/nn/layer_norm_impl.h

+4 −4

@@ -24,9 +24,9 @@ class LayerNormImpl : public OpKernel {
       const T* X_data,
       const TensorShape& x_shape,
       const T* scale_data,
-      size_t scale_size,
+      const TensorShape& scale_shape,
       const T* bias_data,
-      size_t bias_size,
+      const TensorShape& bias_shape,
       T* Y_data,
       U* mean_data,
       U* inv_std_dev,
@@ -64,9 +64,9 @@ class LayerNormImpl : public OpKernel {
   const bool simplified_;
   const bool contrib_op_;
   IAllocatorUniquePtr<float> prepacked_scale_fp32_data_;
-  size_t prepacked_scale_fp32_size_;
+  TensorShape prepacked_scale_fp32_shape_;
   IAllocatorUniquePtr<float> prepacked_bias_fp32_data_;
-  size_t prepacked_bias_fp32_size_;
+  TensorShape prepacked_bias_fp32_shape_;
 };
 
 }  // namespace onnxruntime
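A note on why `PrePack` now records a full `TensorShape` instead of an element count: two scale layouts can hold the same number of elements yet broadcast differently, so a bare count cannot drive `LayerNormHelper::CheckInputs`. A tiny hand-worked illustration (standalone, using the encoding described in `layer_norm_helper.h` above):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // X shape (B=4, S=4, D=16). Both scale shapes below hold 4 * 16 = 64
  // elements, yet they broadcast differently under the rules above:
  //   (B, 1, D) = (4, 1, 16) -> broadcast_param = S  = 4 (a block per batch)
  //   (1, S, D) = (1, 4, 16) -> broadcast_param = -S = -4 (a block per position)
  // A bare size_t of 64 cannot distinguish the two; the full shape can.
  const int64_t size_b1d = 4 * 1 * 16;
  const int64_t size_1sd = 1 * 4 * 16;
  assert(size_b1d == size_1sd);  // identical element counts, different semantics
  return 0;
}
```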

onnxruntime/core/providers/cuda/nn/layer_norm.cc

+15 −17

@@ -4,6 +4,7 @@
 #include "core/providers/shared_library/provider_api.h"
 #include "core/providers/cuda/nn/layer_norm.h"
 #include "core/providers/cuda/nn/layer_norm_impl.h"
+#include "core/providers/cpu/nn/layer_norm_helper.h"
 #include "core/providers/cuda/cuda_common.h"
 
 namespace onnxruntime {
@@ -44,28 +45,22 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
   auto bias_data = (simplified || (nullptr == bias)) ? nullptr : reinterpret_cast<const CudaV*>(bias->Data<V>());
 
   const TensorShape& x_shape = X->Shape();
-  const int64_t axis = HandleNegativeAxis(axis_, x_shape.NumDimensions());
-
-  int n1 = gsl::narrow<int>(x_shape.SizeToDimension(axis));
-  int n2 = gsl::narrow<int>(x_shape.SizeFromDimension(axis));
-
-  const auto scale_size = scale->Shape().Size();
-  const auto bias_size = (bias_data) ? bias->Shape().Size() : 0;
-  if (n2 == 1 || scale_size != n2 || (bias_data && bias_size != n2)) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "Size of X.shape()[axis:] == ", n2,
-                           ". Size of scale and bias (if provided) must match this "
-                           "and the size must not be 1. Got scale size of ",
-                           scale_size, " and bias size of ", bias_size);
-  }
+  auto x_num_dims = x_shape.NumDimensions();
+  const int64_t axis = HandleNegativeAxis(axis_, x_num_dims);
+
+  const TensorShape& scale_shape = scale->Shape();
+  const TensorShape& bias_shape = bias_data ? bias->Shape() : TensorShape();
+
+  LayerNormParams params;
+  ORT_RETURN_IF_ERROR(LayerNormHelper::CheckInputs(x_shape, scale_shape, bias_shape, bias_data != nullptr, axis, params));
 
   // Outputs
   Tensor* Y = ctx->Output(0, x_shape);
   auto Y_data = reinterpret_cast<CudaV*>(Y->MutableData<V>());
 
   // Mean and variance
   std::vector<int64_t> mean_inv_std_var_dim;
-  for (int i = 0; i < static_cast<int>(x_shape.NumDimensions()); ++i) {
+  for (int i = 0; i < static_cast<int>(x_num_dims); ++i) {
     if (i < axis) {
       mean_inv_std_var_dim.emplace_back(x_shape.GetDims()[i]);
     } else {
@@ -93,8 +88,11 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
     return Status::OK();
   }
 
-  HostApplyLayerNorm<CudaT, CudaU, CudaV, simplified>(GetDeviceProp(), Stream(ctx), Y_data, mean_data, inv_var_data,
-                                                      X_data, n1, n2, epsilon_, scale_data, bias_data);
+  HostApplyLayerNorm<CudaT, CudaU, CudaV, simplified>(
+      GetDeviceProp(), Stream(ctx), Y_data, mean_data, inv_var_data, X_data,
+      onnxruntime::narrow<int>(params.num_rows), onnxruntime::narrow<int>(params.norm_size), epsilon_,
+      scale_data, bias_data,
+      onnxruntime::narrow<int>(params.broadcast_param));
   CUDA_RETURN_IF_ERROR(cudaGetLastError());
   return Status::OK();
 }
