Skip to content

Commit 95a49fa

Browse files
xaduprewenbingl
andauthored
Add kernel NegXPlus1 = 1 - X (#709)
* first draft for NegXPlus1 * complete * fix unit test * rename one test * remove test if not cuda --------- Co-authored-by: Wenbing Li <[email protected]>
1 parent 1eaf5ca commit 95a49fa

File tree

5 files changed

+189
-61
lines changed

5 files changed

+189
-61
lines changed

operators/cuda/cuda_ops.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#ifdef USE_CUDA
77
#include "cuda/fast_gelu.h"
8+
#include "cuda/negxplus1.h"
89
#endif
910

1011
FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
@@ -13,10 +14,12 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
1314
#ifdef USE_CUDA
1415
,
1516
CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
17+
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
1618
#if ORT_API_VERSION >= 16
1719

1820
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
19-
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>)
21+
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
22+
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>)
2023
#endif
2124
#endif
2225
);

operators/cuda/negxplus1.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
#include "ocos.h"
6+
#include "negxplus1_impl.cuh"
7+
8+
namespace contrib {
9+
10+
template <typename T>
11+
struct NegXPlus1 {
12+
template <typename TDict>
13+
OrtStatusPtr OnModelAttach(const TDict& /*dict*/) {
14+
return nullptr;
15+
}
16+
OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
17+
const ortc::Tensor<T>& input,
18+
ortc::Tensor<T>& output) const {
19+
const T* input_data = input.Data();
20+
T* output_data = output.Allocate(input.Shape());
21+
auto input_length = input.NumberOfElement();
22+
if (0 == input_length) {
23+
return nullptr;
24+
}
25+
LaunchNegXPlus1Kernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
26+
input_length,
27+
input_data,
28+
output_data);
29+
return nullptr;
30+
}
31+
};
32+
33+
} // namespace contrib

operators/cuda/negxplus1_impl.cu

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "device_prop.cuh"
5+
#include "utils.cuh"
6+
#include "negxplus1_impl.cuh"
7+
#include "cuda_type.h"
8+
9+
using namespace Ort::Custom;
10+
11+
template <typename T>
12+
__device__ __inline__ T _negxplus1(const T x) {
13+
return (T)1 - x;
14+
}
15+
16+
template <>
17+
__device__ __inline__ half _negxplus1(const half x) {
18+
#if __CUDA_ARCH__ < 700
19+
return __float2half(1 - __half2float(x));
20+
#else
21+
return (half)1 - x;
22+
#endif
23+
}
24+
25+
template <typename T>
26+
__global__ void NegXPlus1Kernel(T* output_data, const T* input_data, int N) {
27+
int id = blockDim.x * blockIdx.x + threadIdx.x;
28+
if (id >= N)
29+
return;
30+
output_data[id] = _negxplus1(input_data[id]);
31+
}
32+
33+
template <typename T>
34+
cudaError_t _LaunchNegXPlus1Kernel(cudaStream_t stream, int input_length, const T* input, T* output) {
35+
constexpr int blockSize = 256;
36+
const int gridSize = (input_length + blockSize - 1) / blockSize;
37+
using TT = typename contrib::CudaT<T>::MappedType;
38+
NegXPlus1Kernel<TT><<<gridSize, blockSize, 0, stream>>>(reinterpret_cast<TT*>(output), reinterpret_cast<const TT*>(input), input_length);
39+
return cudaGetLastError();
40+
}
41+
42+
template <>
43+
cudaError_t LaunchNegXPlus1Kernel<float>(cudaStream_t stream, int input_length, const float* input, float* output) {
44+
return _LaunchNegXPlus1Kernel(stream, input_length, input, output);
45+
}
46+
47+
template <>
48+
cudaError_t LaunchNegXPlus1Kernel<ortc::MFloat16>(cudaStream_t stream, int input_length, const ortc::MFloat16* input, ortc::MFloat16* output) {
49+
return _LaunchNegXPlus1Kernel(stream, input_length, input, output);
50+
}

operators/cuda/negxplus1_impl.cuh

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
#include <cuda.h>
6+
#include <cuda_runtime.h>
7+
8+
template <typename T>
9+
cudaError_t LaunchNegXPlus1Kernel(cudaStream_t stream, int input_length, const T* input, T* output);

test/cuda/test_cudaops.py

Lines changed: 93 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,35 @@
11
import unittest
22
import numpy as np
33
from numpy.testing import assert_almost_equal
4-
from onnx import helper, onnx_pb as onnx_proto
4+
from onnx import helper, numpy_helper, onnx_pb as onnx_proto, TensorProto
5+
from onnx.reference import ReferenceEvaluator
6+
from onnx.reference.op_run import OpRun
57
from onnxruntime_extensions import make_onnx_model
68
from onnxruntime_extensions import get_library_path as _get_library_path
79

810
import onnxruntime as _ort
911

1012

13+
class NegXPlus1(OpRun):
14+
op_domain = "ai.onnx.contrib"
15+
16+
def _run(self, X):
17+
return (1 - X,)
18+
19+
1120
class TestCudaOps(unittest.TestCase):
1221
@staticmethod
13-
def _create_negpos_test_model(domain='ai.onnx.contrib'):
22+
def _create_negpos_test_model(domain="ai.onnx.contrib"):
1423
nodes = [
15-
helper.make_node('Identity', ['x'], ['identity1']),
16-
helper.make_node(
17-
'NegPos', ['identity1'], ['neg', 'pos'],
18-
domain=domain)
24+
helper.make_node("Identity", ["x"], ["identity1"]),
25+
helper.make_node("NegPos", ["identity1"], ["neg", "pos"], domain=domain),
1926
]
2027

21-
input0 = helper.make_tensor_value_info(
22-
'x', onnx_proto.TensorProto.FLOAT, [None, None])
23-
output1 = helper.make_tensor_value_info(
24-
'neg', onnx_proto.TensorProto.FLOAT, [None, None])
25-
output2 = helper.make_tensor_value_info(
26-
'pos', onnx_proto.TensorProto.FLOAT, [None, None])
28+
input0 = helper.make_tensor_value_info("x", onnx_proto.TensorProto.FLOAT, [None, None])
29+
output1 = helper.make_tensor_value_info("neg", onnx_proto.TensorProto.FLOAT, [None, None])
30+
output2 = helper.make_tensor_value_info("pos", onnx_proto.TensorProto.FLOAT, [None, None])
2731

28-
graph = helper.make_graph(nodes, 'test0', [input0], [output1, output2])
32+
graph = helper.make_graph(nodes, "test0", [input0], [output1, output2])
2933
model = make_onnx_model(graph)
3034
return model
3135

@@ -34,87 +38,116 @@ def test_cuda_negpos(self):
3438
so.register_custom_ops_library(_get_library_path())
3539
onnx_model = self._create_negpos_test_model()
3640
self.assertIn('op_type: "NegPos"', str(onnx_model))
37-
sess = _ort.InferenceSession(onnx_model.SerializeToString(),
38-
so,
39-
providers=['CUDAExecutionProvider'])
40-
x = np.array([[0., 1., 1.5], [7., 8., -5.5]]).astype(np.float32)
41-
neg, pos = sess.run(None, {'x': x})
41+
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=["CUDAExecutionProvider"])
42+
x = np.array([[0.0, 1.0, 1.5], [7.0, 8.0, -5.5]]).astype(np.float32)
43+
neg, pos = sess.run(None, {"x": x})
4244
diff = x - (neg + pos)
4345
assert_almost_equal(diff, np.zeros(diff.shape))
4446

4547
@staticmethod
46-
def _create_fastgelu_test_model(domain='ai.onnx.contrib'):
47-
nodes = [
48-
helper.make_node(
49-
'FastGelu', ['x', 'bias'], ['y'],
50-
domain=domain)
51-
]
48+
def _create_fastgelu_test_model(domain="ai.onnx.contrib"):
49+
nodes = [helper.make_node("FastGelu", ["x", "bias"], ["y"], domain=domain)]
5250

53-
input0 = helper.make_tensor_value_info(
54-
'x', onnx_proto.TensorProto.FLOAT, [])
55-
input1 = helper.make_tensor_value_info(
56-
'bias', onnx_proto.TensorProto.FLOAT, [])
57-
output0 = helper.make_tensor_value_info(
58-
'y', onnx_proto.TensorProto.FLOAT, [])
51+
input0 = helper.make_tensor_value_info("x", onnx_proto.TensorProto.FLOAT, [])
52+
input1 = helper.make_tensor_value_info("bias", onnx_proto.TensorProto.FLOAT, [])
53+
output0 = helper.make_tensor_value_info("y", onnx_proto.TensorProto.FLOAT, [])
5954

60-
graph = helper.make_graph(nodes, 'test1', [input0, input1], [output0])
55+
graph = helper.make_graph(nodes, "test1", [input0, input1], [output0])
6156
model = make_onnx_model(graph)
6257
return model
6358

6459
@staticmethod
65-
def _create_fastgelu_test_model_f16(domain='ai.onnx.contrib'):
66-
nodes = [
67-
helper.make_node(
68-
'FastGelu', ['x', 'bias'], ['y'],
69-
domain=domain)
70-
]
60+
def _create_fastgelu_test_model_f16(domain="ai.onnx.contrib"):
61+
nodes = [helper.make_node("FastGelu", ["x", "bias"], ["y"], domain=domain)]
7162

72-
input0 = helper.make_tensor_value_info(
73-
'x', onnx_proto.TensorProto.FLOAT16, [])
74-
input1 = helper.make_tensor_value_info(
75-
'bias', onnx_proto.TensorProto.FLOAT16, [])
76-
output0 = helper.make_tensor_value_info(
77-
'y', onnx_proto.TensorProto.FLOAT16, [])
63+
input0 = helper.make_tensor_value_info("x", onnx_proto.TensorProto.FLOAT16, [])
64+
input1 = helper.make_tensor_value_info("bias", onnx_proto.TensorProto.FLOAT16, [])
65+
output0 = helper.make_tensor_value_info("y", onnx_proto.TensorProto.FLOAT16, [])
7866

79-
graph = helper.make_graph(nodes, 'test1', [input0, input1], [output0])
67+
graph = helper.make_graph(nodes, "test1", [input0, input1], [output0])
8068
model = make_onnx_model(graph)
8169
return model
8270

8371
def test_cuda_fastgelu(self):
8472
eps = _ort.get_available_providers()
85-
if 'CUDAExecutionProvider' in eps:
73+
if "CUDAExecutionProvider" in eps:
8674
so = _ort.SessionOptions()
8775
so.register_custom_ops_library(_get_library_path())
8876
onnx_model = self._create_fastgelu_test_model()
8977
self.assertIn('op_type: "FastGelu"', str(onnx_model))
90-
sess = _ort.InferenceSession(onnx_model.SerializeToString(),
91-
so,
92-
providers=['CUDAExecutionProvider'])
93-
x = np.array([0., 1., 2., 3., 4., 5.]).astype(np.float32)
78+
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=["CUDAExecutionProvider"])
79+
x = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]).astype(np.float32)
9480
bias = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5]).astype(np.float32)
95-
expected_y = np.array([0., 0.9505811, 2.1696784, 3.298689, 4.399991, 5.5]).astype(np.float32)
96-
y = sess.run(None, {'x': x, 'bias':bias})[0]
81+
expected_y = np.array([0.0, 0.9505811, 2.1696784, 3.298689, 4.399991, 5.5]).astype(np.float32)
82+
y = sess.run(None, {"x": x, "bias": bias})[0]
9783
assert_almost_equal(y, expected_y)
9884
else:
99-
print ('CUDAExecutionProvider not available, test_cuda_fastgelu skipped.')
85+
print("CUDAExecutionProvider not available, test_cuda_fastgelu skipped.")
10086

10187
def test_cuda_fastgelu_f16(self):
10288
eps = _ort.get_available_providers()
103-
if 'CUDAExecutionProvider' in eps:
89+
if "CUDAExecutionProvider" in eps:
10490
so = _ort.SessionOptions()
10591
so.register_custom_ops_library(_get_library_path())
10692
onnx_model = self._create_fastgelu_test_model_f16()
10793
self.assertIn('op_type: "FastGelu"', str(onnx_model))
108-
sess = _ort.InferenceSession(onnx_model.SerializeToString(),
109-
so,
110-
providers=['CUDAExecutionProvider'])
111-
x = np.array([0., 1., 2., 3., 4., 5.]).astype(np.float16)
94+
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=["CUDAExecutionProvider"])
95+
x = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]).astype(np.float16)
11296
bias = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5]).astype(np.float16)
113-
expected_y = np.array([0., 0.95, 2.17, 3.299, 4.4, 5.5]).astype(np.float16)
114-
y = sess.run(None, {'x': x, 'bias':bias})[0]
97+
expected_y = np.array([0.0, 0.95, 2.17, 3.299, 4.4, 5.5]).astype(np.float16)
98+
y = sess.run(None, {"x": x, "bias": bias})[0]
11599
assert_almost_equal(y, expected_y)
116100
else:
117-
print ('CUDAExecutionProvider not available, test_cuda_fastgelu_f16 skipped.')
101+
print("CUDAExecutionProvider not available, test_cuda_fastgelu_f16 skipped.")
102+
103+
def _negxplus1_cuda(self, itype):
104+
import onnxruntime
105+
106+
dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
107+
model1 = helper.make_model(
108+
helper.make_graph(
109+
[helper.make_node("Sub", ["one", "X"], ["Y"])],
110+
"nd",
111+
[helper.make_tensor_value_info("X", itype, [None, None, None])],
112+
[helper.make_tensor_value_info("Y", itype, [None, None, None])],
113+
[numpy_helper.from_array(np.array([1], dtype=dtype), name="one")],
114+
),
115+
opset_imports=[helper.make_opsetid("", 18)],
116+
ir_version=9,
117+
)
118+
119+
model2 = helper.make_model(
120+
helper.make_graph(
121+
[helper.make_node("NegXPlus1", ["X"], ["Y"], domain="ai.onnx.contrib")],
122+
"nd",
123+
[helper.make_tensor_value_info("X", itype, [None, None, None])],
124+
[helper.make_tensor_value_info("Y", itype, [None, None, None])],
125+
),
126+
opset_imports=[
127+
helper.make_opsetid("", 18),
128+
helper.make_opsetid("ai.onnx.contrib", 1),
129+
],
130+
ir_version=9,
131+
)
132+
133+
dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
134+
x = (np.arange(18) - 4).reshape((3, 2, 3)).astype(dtype)
135+
136+
feeds1 = dict(X=x)
137+
ref = ReferenceEvaluator(model1, new_ops=[NegXPlus1])
138+
expected = ref.run(None, feeds1)[0]
139+
140+
opts = onnxruntime.SessionOptions()
141+
opts.register_custom_ops_library(_get_library_path())
142+
sess = onnxruntime.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
143+
got = sess.run(None, feeds1)[0]
144+
assert_almost_equal(expected, got, decimal=5)
145+
146+
def test_cuda_negxplus1(self):
147+
eps = _ort.get_available_providers()
148+
if "CUDAExecutionProvider" in eps:
149+
self._negxplus1_cuda(TensorProto.FLOAT)
150+
self._negxplus1_cuda(TensorProto.FLOAT16)
118151

119152

120153
if __name__ == "__main__":

0 commit comments

Comments
 (0)