Skip to content

Commit 52f351c

Browse files
committed
Fix implementation of Rotary
1 parent 1c9c4a4 commit 52f351c

File tree

5 files changed

+95
-35
lines changed

5 files changed

+95
-35
lines changed

operators/cuda/cuda_ops.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "cuda/add_mul.h"
88
#include "cuda/fast_gelu.h"
99
#include "cuda/negxplus1.h"
10+
#include "cuda/rotary.h"
1011
#endif
1112

1213
FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
@@ -28,13 +29,15 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
2829
CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
2930
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type),
3031
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
32+
CustomCudaStructV2("Rotary", contrib::Rotary<float>),
3133
#if ORT_API_VERSION >= 16
3234

3335
CustomCudaStructV2("AddSharedInput", AddSharedInputFloat16Type),
3436
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
3537
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
3638
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
37-
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>)
39+
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
40+
CustomCudaStructV2("Rotary", contrib::Rotary<ortc::MFloat16>)
3841
#endif
3942
#endif
4043
);

operators/cuda/rotary.h

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,9 @@ namespace contrib {
1111
template <typename T>
1212
struct Rotary {
1313
template <typename TDict>
14-
OrtxStatus OnModelAttach(OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
15-
std::string side;
16-
auto status = OrtW::GetOpAttribute(info, "side", side);
17-
if (!status) {
18-
return {kOrtxErrorInvalidArgument, "Missing or wrong argument side."};
19-
}
14+
OrtxStatus OnModelAttach(const TDict& dict) {
15+
std::string empty;
16+
std::string side = dict.TryToGetAttributeWithDefault("side", empty);
2017
if (side == "left") {
2118
side_ = RotarySide::LEFT;
2219
}
@@ -45,15 +42,14 @@ struct Rotary {
4542
if (shape_split.size() != 1 || shape_split[0] != 2) {
4643
return {kOrtxErrorInvalidArgument, "Rotary only works when there are two sides."};
4744
}
48-
if (shape_split[0] != shape_split[1]) {
45+
const int64_t* split_data = split.Data();
46+
if (split_data[0] != split_data[1]) {
4947
return {kOrtxErrorInvalidArgument, "Only equal split are allowed."};
5048
}
51-
if (shape_split[0] * 2 != input_shape[input_shape.size()-1]) {
49+
if (split_data[0] * 2 != input_shape[input_shape.size()-1]) {
5250
return {kOrtxErrorInvalidArgument, "Sum of the splits are not equal to the last dimension."};
5351
}
5452

55-
const int64_t* split_data = split.Data();
56-
5753
LaunchRotaryKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
5854
input_length,
5955
static_cast<int>(input_shape[input_shape.size()-1]),

operators/cuda/rotary_impl.cu

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,13 @@
33

44
#include "device_prop.cuh"
55
#include "utils.cuh"
6-
#include "Rotary_impl.cuh"
6+
#include "rotary_impl.cuh"
77
#include "cuda_type.h"
88

9+
#ifndef CUDA_LONG
10+
#define CUDA_LONG int32_t
11+
#endif
12+
913
using namespace Ort::Custom;
1014

1115
template <typename T> __device__ __inline__ T _neg(const T x) { return -x; }
@@ -34,46 +38,44 @@ __global__ void RotaryKernel(T *output_data, const T *input_data, CUDA_LONG half
3438

3539
template <typename T>
3640
cudaError_t _LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim,
37-
const T* input, const int64_t* split_data, T* output, RotarySide side) {
38-
constexpr int blockSize = 256;
39-
const int gridSize = (input_length + blockSize - 1) / blockSize;
41+
const T* input_data, const int64_t* /* split_data */, T* output_data, RotarySide side) {
4042
if (input_length == 0)
41-
return;
43+
return cudaGetLastError();
4244
using TT = typename contrib::CudaT<T>::MappedType;
4345

44-
CUDA_LONG N = static_cast<CUDA_LONG>(count);
46+
CUDA_LONG N = static_cast<CUDA_LONG>(input_length);
4547
CUDA_LONG stride = static_cast<CUDA_LONG>(last_dim);
4648

47-
const int num_threads_per_block = GridDim::maxThreadsPerBlock;
49+
const int num_threads_per_block = 256;
4850
const int num_elements_per_thread =
4951
(N / 2 + num_threads_per_block - 1) / num_threads_per_block;
5052

5153
switch (side) {
5254
case RotarySide::LEFT:
53-
RotaryKernel<T, RotarySide::LEFT>
54-
<<<num_elements_per_thread, num_threads_per_block, 0, stream>>>(output_data, input_data,
55+
RotaryKernel<TT, RotarySide::LEFT>
56+
<<<num_elements_per_thread, num_threads_per_block, 0, stream>>>(reinterpret_cast<TT*>(output_data),
57+
reinterpret_cast<const TT*>(input_data),
5558
N / 2, stride / 2);
5659
break;
5760
case RotarySide::RIGHT:
58-
RotaryKernel<T, RotarySide::RIGHT>
59-
<<<num_elements_per_thread, num_threads_per_block, 0, stream>>>(output_data, input_data,
61+
RotaryKernel<TT, RotarySide::RIGHT>
62+
<<<num_elements_per_thread, num_threads_per_block, 0, stream>>>(reinterpret_cast<TT*>(output_data),
63+
reinterpret_cast<const TT*>(input_data),
6064
N / 2, stride / 2);
6165
break;
6266
}
63-
64-
RotaryKernel<TT><<<gridSize, blockSize, 0, stream>>>(reinterpret_cast<TT*>(output), reinterpret_cast<const TT*>(input), input_length);
6567
return cudaGetLastError();
6668
}
6769

6870
template <>
6971
cudaError_t LaunchRotaryKernel<float>(cudaStream_t stream, int input_length, int last_dim,
70-
const float* input, const int64_t* split_data, float* output, RotarySide side) {
71-
return _LaunchRotaryKernel(stream, input_length, last_dim, input, split_data, output, side);
72+
const float* input_data, const int64_t* split_data, float* output_data, RotarySide side) {
73+
return _LaunchRotaryKernel(stream, input_length, last_dim, input_data, split_data, output_data, side);
7274
}
7375

7476
template <>
7577
cudaError_t LaunchRotaryKernel<ortc::MFloat16>(cudaStream_t stream, int input_length, int last_dim,
76-
const ortc::MFloat16* input, const int64_t* split_data,
77-
ortc::MFloat16* output, RotarySide side) {
78-
return _LaunchRotaryKernel(stream, input_length, last_dim, input, split_data, output, side);
78+
const ortc::MFloat16* input_data, const int64_t* split_data,
79+
ortc::MFloat16* output_data, RotarySide side) {
80+
return _LaunchRotaryKernel(stream, input_length, last_dim, input_data, split_data, output_data, side);
7981
}

operators/cuda/roatry_impl.cuh renamed to operators/cuda/rotary_impl.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ enum class RotarySide : int {
1212

1313
template <typename T>
1414
cudaError_t LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim,
15-
const T* input, const int64_t* split_data, T* output, RotarySide side);
15+
const T* input_data, const int64_t* split_data, T* output_data, RotarySide side);

test/cuda/test_cudaops.py

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,6 @@ def test_cuda_negxplus1(self):
151151
self._negxplus1_cuda(TensorProto.FLOAT16)
152152

153153
def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3, 2, 3), shapec=(3, 2, 3)):
154-
from onnx_extended.ortops.optim.cuda import get_ort_ext_libs
155-
156154
model1 = helper.make_model(
157155
helper.make_graph(
158156
[
@@ -181,7 +179,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3,
181179
f"{op_type}SharedInput",
182180
["X", "Y", "Z"],
183181
["XY", "XZ"],
184-
domain="onnx_extended.ortops.optim.cuda",
182+
domain="ai.onnx.contrib",
185183
)
186184
],
187185
"nd",
@@ -197,7 +195,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3,
197195
),
198196
opset_imports=[
199197
helper.make_opsetid("", 18),
200-
helper.make_opsetid("onnx_extended.ortops.optim.cuda", 1),
198+
helper.make_opsetid("ai.onnx.contrib", 1),
201199
],
202200
ir_version=9,
203201
)
@@ -212,7 +210,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3,
212210
expected = ref.run(None, feeds1)
213211

214212
opts = _ort.SessionOptions()
215-
opts.register_custom_ops_library(get_ort_ext_libs()[0])
213+
opts.register_custom_ops_library(_get_library_path())
216214
sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
217215
got = sess.run(None, feeds1)
218216
for i in range(2):
@@ -262,6 +260,67 @@ def test_add_shared_input_cuda_broadcast2(self):
262260
shapec=(3, 2, 3),
263261
)
264262

263+
def _rotary_cuda(self, itype, side, input_shape=(3, 2, 3, 4)):
264+
model2 = helper.make_model(
265+
helper.make_graph(
266+
[
267+
helper.make_node(
268+
"Rotary",
269+
["X", "splits"],
270+
["Y"],
271+
domain="ai.onnx.contrib",
272+
side=side,
273+
)
274+
],
275+
"nd",
276+
[
277+
helper.make_tensor_value_info("X", itype, [None, None, None, None]),
278+
helper.make_tensor_value_info("splits", TensorProto.INT64, [2]),
279+
],
280+
[helper.make_tensor_value_info("Y", itype, [None, None, None, None])],
281+
),
282+
opset_imports=[
283+
helper.make_opsetid("", 18),
284+
helper.make_opsetid("ai.onnx.contrib", 1),
285+
],
286+
ir_version=9,
287+
)
288+
289+
dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
290+
x = (np.arange(np.prod(input_shape)) + 1).reshape(input_shape).astype(dtype)
291+
splits = np.array([x.shape[-1] // 2, x.shape[-1] // 2], dtype=np.int64)
292+
293+
expected = x.copy()
294+
half = x.shape[-1] // 2
295+
if side == "left":
296+
expected[:, :, :, :half] = x[:, :, :, half:]
297+
expected[:, :, :, half:] = -x[:, :, :, :half]
298+
else:
299+
expected[:, :, :, :half] = -x[:, :, :, half:]
300+
expected[:, :, :, half:] = x[:, :, :, :half]
301+
302+
feeds = dict(X=x, splits=splits)
303+
opts = _ort.SessionOptions()
304+
opts.register_custom_ops_library(_get_library_path())
305+
sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
306+
got = sess.run(None, feeds)[0]
307+
assert_almost_equal(expected, got)
308+
309+
@unittest.skipIf(not has_cuda(), reason="cuda not available")
310+
def test_rotary_cuda(self):
311+
self._rotary_cuda(TensorProto.FLOAT, "left")
312+
self._rotary_cuda(TensorProto.FLOAT, "right")
313+
self._rotary_cuda(TensorProto.FLOAT16, "left")
314+
self._rotary_cuda(TensorProto.FLOAT16, "right")
315+
316+
@unittest.skipIf(not has_cuda(), reason="cuda not available")
317+
def test_bigger_rotary_cuda(self):
318+
sh = (2, 2, 1024, 8)
319+
self._rotary_cuda(TensorProto.FLOAT, "left", input_shape=sh)
320+
self._rotary_cuda(TensorProto.FLOAT, "right", input_shape=sh)
321+
self._rotary_cuda(TensorProto.FLOAT16, "left", input_shape=sh)
322+
self._rotary_cuda(TensorProto.FLOAT16, "right", input_shape=sh)
323+
265324

266325
if __name__ == "__main__":
267326
unittest.main()

0 commit comments

Comments
 (0)