Skip to content

Commit 1e8c121

Browse files
authored
Add custom kernels AddSharedInput, MulSharedInput (#734)
* Add custom kernel AddSharedInput, MulSharedInput * fix compilation * compilation issue * fix unit test
1 parent c9bba37 commit 1e8c121

File tree

5 files changed

+314
-8
lines changed

5 files changed

+314
-8
lines changed

operators/cuda/add_mul.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
#include "ocos.h"
6+
#include "add_mul_impl.cuh"
7+
#include "ortx_common.h"
8+
9+
namespace contrib {
10+
11+
template <typename T, bool addition>
12+
struct AddOrMulSharedInput {
13+
template <typename TDict>
14+
OrtxStatus OnModelAttach(const TDict& /*dict*/) {
15+
return {};
16+
}
17+
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
18+
const ortc::Tensor<T>& tensor_a,
19+
const ortc::Tensor<T>& tensor_b,
20+
const ortc::Tensor<T>& tensor_c,
21+
ortc::Tensor<T>& output_ab,
22+
ortc::Tensor<T>& output_ac) const {
23+
const T* input_data_a = tensor_a.Data();
24+
const T* input_data_b = tensor_b.Data();
25+
const T* input_data_c = tensor_c.Data();
26+
27+
auto length_a = tensor_a.NumberOfElement();
28+
auto length_b = tensor_b.NumberOfElement();
29+
auto length_c = tensor_c.NumberOfElement();
30+
31+
T* output_data_ab = output_ab.Allocate(length_a <= length_b ? tensor_b.Shape() : tensor_a.Shape());
32+
T* output_data_ac = output_ab.Allocate(length_a <= length_c ? tensor_c.Shape() : tensor_a.Shape());
33+
34+
if (0 == input_data_a || 0 == input_data_b || 0 == input_data_c) {
35+
return {};
36+
}
37+
LaunchAddOrMulSharedInputKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
38+
input_data_a, input_data_b, input_data_c,
39+
output_data_ab, output_data_ac,
40+
length_a, length_b, length_c,
41+
addition);
42+
return {};
43+
}
44+
};
45+
46+
} // namespace contrib

operators/cuda/add_mul_impl.cu

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "device_prop.cuh"
5+
#include "utils.cuh"
6+
#include "add_mul_impl.cuh"
7+
#include "cuda_type.h"
8+
9+
#ifndef CUDA_LONG
10+
#define CUDA_LONG int32_t
11+
#endif
12+
13+
using namespace Ort::Custom;
14+
15+
__device__ __forceinline__ void _add3_op(float* ab, float* ac, const float a, const float b, const float c) {
16+
*ab = a + b;
17+
*ac = a + c;
18+
}
19+
20+
__device__ __forceinline__ void _add3_op(half* ab, half* ac, const half a, const half b, const half c) {
21+
#if __CUDA_ARCH__ < 700
22+
*ab = __float2half(__half2float(a) + __half2float(b));
23+
*ac = __float2half(__half2float(a) + __half2float(c));
24+
#else
25+
*ab = a + b;
26+
*ac = a + c;
27+
#endif
28+
}
29+
30+
__device__ __forceinline__ void _mul3_op(float* ab, float* ac, const float a, const float b, const float c) {
31+
*ab = a * b;
32+
*ac = a * c;
33+
}
34+
35+
__device__ __forceinline__ void _mul3_op(half* ab, half* ac, const half a, const half b, const half c) {
36+
#if __CUDA_ARCH__ < 700
37+
*ab = __float2half(__half2float(a) * __half2float(b));
38+
*ac = __float2half(__half2float(a) * __half2float(c));
39+
#else
40+
*ab = a * b;
41+
*ac = a * c;
42+
#endif
43+
}
44+
45+
template <typename T>
46+
struct Mul3SharedOp {
47+
__device__ __forceinline__ void operator()(T* ab, T* ac, const T a, const T b, const T c) const {
48+
_mul3_op(ab, ac, a, b, c);
49+
}
50+
};
51+
52+
template <typename T>
53+
struct Add3SharedOp {
54+
__device__ __forceinline__ void operator()(T* ab, T* ac, const T a, const T b, const T c) const {
55+
_add3_op(ab, ac, a, b, c);
56+
}
57+
};
58+
59+
template <typename T, typename TFunc, int NumThreadsPerBlock, int NumElementsPerThread>
60+
__global__ void AddMulKernel(T* output_ab, T* output_ac, const T* pA, const T* pB,
61+
const T* pC, CUDA_LONG nA, CUDA_LONG nB, CUDA_LONG nC,
62+
CUDA_LONG N, const TFunc func) {
63+
CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
64+
CUDA_LONG id = start;
65+
#pragma unroll
66+
for (int i = 0; i < NumElementsPerThread; i++) {
67+
if (id < N) {
68+
func(output_ab + id, output_ac + id, pA[id % nA], pB[id % nB], pC[id % nC]);
69+
id += NumThreadsPerBlock;
70+
}
71+
}
72+
}
73+
74+
template <typename T>
75+
cudaError_t _LaunchAddOrMulSharedInputKernel(cudaStream_t stream,
76+
const T* pA, const T* pB, const T* pC,
77+
T* output_ab, T* output_ac,
78+
int64_t countA, int64_t countB, int64_t countC, bool addition) {
79+
int64_t max_count = std::max(std::max(countA, countB), countC);
80+
if (max_count == 0) // special case where there's a dim value of 0 in the output shape
81+
return cudaGetLastError();
82+
83+
const int num_elements_per_thread = 4;
84+
const int num_threads_per_block = 256;
85+
const int num_el_th = num_threads_per_block * num_elements_per_thread;
86+
87+
int blocksPerGrid = (max_count + num_el_th - 1) / num_el_th;
88+
89+
using TT = typename contrib::CudaT<T>::MappedType;
90+
91+
if (addition) {
92+
AddMulKernel<TT, Add3SharedOp<TT>, num_threads_per_block, num_elements_per_thread>
93+
<<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
94+
reinterpret_cast<TT*>(output_ab), reinterpret_cast<TT*>(output_ac),
95+
reinterpret_cast<const TT*>(pA), reinterpret_cast<const TT*>(pB), reinterpret_cast<const TT*>(pC), static_cast<CUDA_LONG>(countA),
96+
static_cast<CUDA_LONG>(countB), static_cast<CUDA_LONG>(countC),
97+
static_cast<CUDA_LONG>(max_count), Add3SharedOp<TT>());
98+
} else {
99+
AddMulKernel<TT, Mul3SharedOp<TT>, num_threads_per_block, num_elements_per_thread>
100+
<<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
101+
reinterpret_cast<TT*>(output_ab), reinterpret_cast<TT*>(output_ac),
102+
reinterpret_cast<const TT*>(pA), reinterpret_cast<const TT*>(pB), reinterpret_cast<const TT*>(pC), static_cast<CUDA_LONG>(countA),
103+
static_cast<CUDA_LONG>(countB), static_cast<CUDA_LONG>(countC),
104+
static_cast<CUDA_LONG>(max_count), Mul3SharedOp<TT>());
105+
}
106+
return cudaGetLastError();
107+
}
108+
109+
template <>
110+
cudaError_t LaunchAddOrMulSharedInputKernel<float>(cudaStream_t stream, const float* input_a, const float* input_b, const float* input_c,
111+
float* output_ab, float* output_ac,
112+
int64_t length_a, int64_t length_b, int64_t length_c, bool addition) {
113+
return _LaunchAddOrMulSharedInputKernel(stream, input_a, input_b, input_c, output_ab, output_ac, length_a, length_b, length_c, addition);
114+
}
115+
116+
template <>
117+
cudaError_t LaunchAddOrMulSharedInputKernel<ortc::MFloat16>(cudaStream_t stream, const ortc::MFloat16* input_a, const ortc::MFloat16* input_b, const ortc::MFloat16* input_c,
118+
ortc::MFloat16* output_ab, ortc::MFloat16* output_ac,
119+
int64_t length_a, int64_t length_b, int64_t length_c, bool addition) {
120+
return _LaunchAddOrMulSharedInputKernel(stream, input_a, input_b, input_c, output_ab, output_ac, length_a, length_b, length_c, addition);
121+
}

operators/cuda/add_mul_impl.cuh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
#include <cuda.h>
6+
#include <cuda_runtime.h>
7+
8+
template <typename T>
9+
cudaError_t LaunchAddOrMulSharedInputKernel(cudaStream_t stream, const T* input_a, const T* input_b, const T* input_c,
10+
T* output_ab, T* output_ac,
11+
int64_t length_a, int64_t length_b, int64_t length_c, bool addition);

operators/cuda/cuda_ops.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,36 @@
44
#include "ocos.h"
55

66
#ifdef USE_CUDA
7+
#include "cuda/add_mul.h"
78
#include "cuda/fast_gelu.h"
89
#include "cuda/negxplus1.h"
910
#endif
1011

1112
FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
13+
14+
using AddSharedInputFloat32Type = typename contrib::AddOrMulSharedInput<float, true>;
15+
using MulSharedInputFloat32Type = typename contrib::AddOrMulSharedInput<float, false>;
16+
17+
#if ORT_API_VERSION >= 16
18+
using AddSharedInputFloat16Type = typename contrib::AddOrMulSharedInput<ortc::MFloat16, true>;
19+
using MulSharedInputFloat16Type = typename contrib::AddOrMulSharedInput<ortc::MFloat16, false>;
20+
#endif
21+
22+
1223
static OrtOpLoader op_loader(
1324
[]() { return nullptr; }
1425
#ifdef USE_CUDA
1526
,
27+
CustomCudaStructV2("AddSharedInput", AddSharedInputFloat32Type),
1628
CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
29+
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type),
1730
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
1831
#if ORT_API_VERSION >= 16
1932

33+
CustomCudaStructV2("AddSharedInput", AddSharedInputFloat16Type),
2034
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
2135
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
36+
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
2237
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>)
2338
#endif
2439
#endif

test/cuda/test_cudaops.py

Lines changed: 121 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
import onnxruntime as _ort
1111

1212

13+
def has_cuda():
14+
return "CUDAExecutionProvider" in _ort.get_available_providers()
15+
16+
1317
class NegXPlus1(OpRun):
1418
op_domain = "ai.onnx.contrib"
1519

@@ -101,8 +105,6 @@ def test_cuda_fastgelu_f16(self):
101105
print("CUDAExecutionProvider not available, test_cuda_fastgelu_f16 skipped.")
102106

103107
def _negxplus1_cuda(self, itype):
104-
import onnxruntime
105-
106108
dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
107109
model1 = helper.make_model(
108110
helper.make_graph(
@@ -137,17 +139,128 @@ def _negxplus1_cuda(self, itype):
137139
ref = ReferenceEvaluator(model1, new_ops=[NegXPlus1])
138140
expected = ref.run(None, feeds1)[0]
139141

140-
opts = onnxruntime.SessionOptions()
142+
opts = _ort.SessionOptions()
141143
opts.register_custom_ops_library(_get_library_path())
142-
sess = onnxruntime.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
144+
sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
143145
got = sess.run(None, feeds1)[0]
144146
assert_almost_equal(expected, got, decimal=5)
145147

148+
@unittest.skipIf(not has_cuda(), reason="CUDA is missing")
146149
def test_cuda_negxplus1(self):
147-
eps = _ort.get_available_providers()
148-
if "CUDAExecutionProvider" in eps:
149-
self._negxplus1_cuda(TensorProto.FLOAT)
150-
self._negxplus1_cuda(TensorProto.FLOAT16)
150+
self._negxplus1_cuda(TensorProto.FLOAT)
151+
self._negxplus1_cuda(TensorProto.FLOAT16)
152+
153+
def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3, 2, 3), shapec=(3, 2, 3)):
154+
from onnx_extended.ortops.optim.cuda import get_ort_ext_libs
155+
156+
model1 = helper.make_model(
157+
helper.make_graph(
158+
[
159+
helper.make_node(op_type, ["X", "Y"], ["XY"]),
160+
helper.make_node(op_type, ["X", "Z"], ["XZ"]),
161+
],
162+
"nd",
163+
[
164+
helper.make_tensor_value_info("X", itype, [None, None, None]),
165+
helper.make_tensor_value_info("Y", itype, [None, None, None]),
166+
helper.make_tensor_value_info("Z", itype, [None, None, None]),
167+
],
168+
[
169+
helper.make_tensor_value_info("XY", itype, [None, None, None]),
170+
helper.make_tensor_value_info("XZ", itype, [None, None, None]),
171+
],
172+
),
173+
opset_imports=[helper.make_opsetid("", 18)],
174+
ir_version=9,
175+
)
176+
177+
model2 = helper.make_model(
178+
helper.make_graph(
179+
[
180+
helper.make_node(
181+
f"{op_type}SharedInput",
182+
["X", "Y", "Z"],
183+
["XY", "XZ"],
184+
domain="onnx_extended.ortops.optim.cuda",
185+
)
186+
],
187+
"nd",
188+
[
189+
helper.make_tensor_value_info("X", itype, [None, None, None]),
190+
helper.make_tensor_value_info("Y", itype, [None, None, None]),
191+
helper.make_tensor_value_info("Z", itype, [None, None, None]),
192+
],
193+
[
194+
helper.make_tensor_value_info("XY", itype, [None, None, None]),
195+
helper.make_tensor_value_info("XZ", itype, [None, None, None]),
196+
],
197+
),
198+
opset_imports=[
199+
helper.make_opsetid("", 18),
200+
helper.make_opsetid("onnx_extended.ortops.optim.cuda", 1),
201+
],
202+
ir_version=9,
203+
)
204+
205+
dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
206+
x = (np.arange(np.prod(shapea)) + 1).reshape((shapea)).astype(dtype)
207+
y = (np.arange(np.prod(shapeb)) + 2).reshape((shapeb)).astype(dtype)
208+
z = (np.arange(np.prod(shapec)) + 3).reshape((shapec)).astype(dtype)
209+
210+
feeds1 = dict(X=x, Y=y, Z=z)
211+
ref = ReferenceEvaluator(model1)
212+
expected = ref.run(None, feeds1)
213+
214+
opts = _ort.SessionOptions()
215+
opts.register_custom_ops_library(get_ort_ext_libs()[0])
216+
sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
217+
got = sess.run(None, feeds1)
218+
for i in range(2):
219+
assert_almost_equal(expected[i], got[i])
220+
221+
@unittest.skipIf(not has_cuda(), reason="CUDA is missing")
222+
def test_add_shared_input_cuda(self):
223+
self._addmul_shared_input_cuda(TensorProto.FLOAT, "Add")
224+
self._addmul_shared_input_cuda(TensorProto.FLOAT16, "Add")
225+
226+
@unittest.skipIf(not has_cuda(), reason="CUDA is missing")
227+
def test_mul_shared_input_cuda(self):
228+
self._addmul_shared_input_cuda(TensorProto.FLOAT, "Mul")
229+
self._addmul_shared_input_cuda(TensorProto.FLOAT16, "Mul")
230+
231+
@unittest.skipIf(not has_cuda(), reason="CUDA is missing")
232+
def test_add_shared_input_cuda_broadcast1(self):
233+
self._addmul_shared_input_cuda(
234+
TensorProto.FLOAT,
235+
"Add",
236+
shapea=(3, 2, 3),
237+
shapeb=(1, 2, 3),
238+
shapec=(1, 2, 3),
239+
)
240+
self._addmul_shared_input_cuda(
241+
TensorProto.FLOAT16,
242+
"Add",
243+
shapea=(3, 2, 3),
244+
shapeb=(1, 2, 3),
245+
shapec=(1, 2, 3),
246+
)
247+
248+
@unittest.skipIf(not has_cuda(), reason="CUDA is missing")
249+
def test_add_shared_input_cuda_broadcast2(self):
250+
self._addmul_shared_input_cuda(
251+
TensorProto.FLOAT,
252+
"Add",
253+
shapea=(1, 2, 3),
254+
shapeb=(3, 2, 3),
255+
shapec=(3, 2, 3),
256+
)
257+
self._addmul_shared_input_cuda(
258+
TensorProto.FLOAT16,
259+
"Add",
260+
shapea=(1, 2, 3),
261+
shapeb=(3, 2, 3),
262+
shapec=(3, 2, 3),
263+
)
151264

152265

153266
if __name__ == "__main__":

0 commit comments

Comments
 (0)