Skip to content

Commit 79f3b04

Browse files
authored
Add custom op Transpose2DCast (#737)
* Add custom op Transpose2DCast * fix compilation issues * fix compilation issues
1 parent 1e8c121 commit 79f3b04

File tree

6 files changed

+185
-8
lines changed

6 files changed

+185
-8
lines changed

operators/cuda/add_mul.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ struct AddOrMulSharedInput {
2929
auto length_c = tensor_c.NumberOfElement();
3030

3131
T* output_data_ab = output_ab.Allocate(length_a <= length_b ? tensor_b.Shape() : tensor_a.Shape());
32-
T* output_data_ac = output_ab.Allocate(length_a <= length_c ? tensor_c.Shape() : tensor_a.Shape());
32+
T* output_data_ac = output_ac.Allocate(length_a <= length_c ? tensor_c.Shape() : tensor_a.Shape());
3333

3434
if (0 == input_data_a || 0 == input_data_b || 0 == input_data_c) {
3535
return {};

operators/cuda/cuda_ops.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "cuda/add_mul.h"
88
#include "cuda/fast_gelu.h"
99
#include "cuda/negxplus1.h"
10+
#include "cuda/transpose_cast.h"
1011
#endif
1112

1213
FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
@@ -17,6 +18,8 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
1718
#if ORT_API_VERSION >= 16
1819
using AddSharedInputFloat16Type = typename contrib::AddOrMulSharedInput<ortc::MFloat16, true>;
1920
using MulSharedInputFloat16Type = typename contrib::AddOrMulSharedInput<ortc::MFloat16, false>;
21+
using Transpose2DCastFloat32ToFloat16Type = typename contrib::Transpose2DCast<float, ortc::MFloat16>;
22+
using Transpose2DCastFloat16ToFloat32Type = typename contrib::Transpose2DCast<ortc::MFloat16, float>;
2023
#endif
2124

2225

@@ -34,7 +37,9 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
3437
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
3538
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
3639
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
37-
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>)
40+
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
41+
CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type),
42+
CustomCudaStructV2("Transpose2DCastFP32", Transpose2DCastFloat16ToFloat32Type)
3843
#endif
3944
#endif
4045
);

operators/cuda/transpose_cast.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
#include "ocos.h"
6+
#include "transpose_cast_impl.cuh"
7+
#include "ortx_common.h"
8+
9+
namespace contrib {
10+
11+
template <typename TIN, typename TOUT>
12+
struct Transpose2DCast {
13+
template <typename TDict>
14+
OrtxStatus OnModelAttach(const TDict& /*dict*/) {
15+
return {};
16+
}
17+
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
18+
const ortc::Tensor<TIN>& input,
19+
ortc::Tensor<TOUT>& output) const {
20+
const TIN* input_data = input.Data();
21+
auto shape = input.Shape();
22+
if (shape.size() != 2) {
23+
ORTX_CXX_API_THROW("Input must be a 2D tensor", ORT_RUNTIME_EXCEPTION);
24+
}
25+
int n_rows = static_cast<int>(shape[0]);
26+
int n_cols = static_cast<int>(shape[1]);
27+
28+
std::vector<int64_t> new_shape{static_cast<int64_t>(n_cols), static_cast<int64_t>(n_rows)};
29+
TOUT* output_data = output.Allocate(new_shape);
30+
if (0 == n_rows || 0 == n_cols) {
31+
return {};
32+
}
33+
LaunchTranspose2DCastKernel<TIN, TOUT>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
34+
n_rows, n_cols, input_data, output_data);
35+
return {};
36+
}
37+
};
38+
39+
} // namespace contrib
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "device_prop.cuh"
5+
#include "utils.cuh"
6+
#include "transpose_cast_impl.cuh"
7+
#include "cuda_type.h"
8+
9+
using namespace Ort::Custom;
10+
11+
#define TILE_DIM 32
12+
#define BLOCK_ROWS 8
13+
14+
// Tiled transpose + cast: reads an (n_rows x n_cols) row-major matrix of TIN and writes
// its transpose, an (n_cols x n_rows) row-major matrix of TOUT.
//
// Launch layout expected: blockDim = (TILE_DIM, BLOCK_ROWS, 1) and
// gridDim = (ceil(n_cols / TILE_DIM), ceil(n_rows / TILE_DIM), 1).
// Shared memory: one statically allocated TILE_DIM x (TILE_DIM + 1) tile of TIN;
// no dynamic shared memory is required.
template <typename TOUT, typename TIN>
__global__ void Transpose2DCastKernel(TOUT *output_data, const TIN *input_data, int n_rows, int n_cols) {
  // The inner dimension is padded by one element so that reading the tile
  // column-wise below does not map a whole warp onto the same shared memory bank.
  __shared__ TIN tile[TILE_DIM][TILE_DIM + 1];

  int x = blockIdx.x * TILE_DIM + threadIdx.x;  // input column
  int y = blockIdx.y * TILE_DIM + threadIdx.y;  // input row (first of the rows this thread loads)

  // Each thread loads TILE_DIM / BLOCK_ROWS elements. The guards keep edge tiles from
  // reading past the matrix when a dimension is not a multiple of TILE_DIM.
  for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
    if (x < n_cols && (y + j) < n_rows) {
      // size_t arithmetic: the flat index may exceed the range of a 32-bit int.
      tile[threadIdx.y + j][threadIdx.x] = input_data[static_cast<size_t>(y + j) * n_cols + x];
    }
  }

  // Barrier is reached by every thread of the block (it is outside the guarded branches).
  __syncthreads();

  x = blockIdx.y * TILE_DIM + threadIdx.x;  // transpose block offset: output column
  y = blockIdx.x * TILE_DIM + threadIdx.y;  // output row

  // A tile entry is only read when the matching output element exists, which is
  // exactly the condition under which that entry was written above.
  for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
    if (x < n_rows && (y + j) < n_cols) {
      output_data[static_cast<size_t>(y + j) * n_rows + x] = (TOUT)(tile[threadIdx.x][threadIdx.y + j]);
    }
  }
}
33+
34+
// Launches Transpose2DCastKernel on `stream` for an (n_rows x n_cols) input.
// TIN / TOUT are the host-side element types; they are mapped to their CUDA
// counterparts through contrib::CudaT before the kernel is instantiated.
// Returns cudaSuccess for an empty matrix, cudaErrorInvalidValue for a negative
// dimension, and otherwise the result of cudaGetLastError() after the launch.
template <typename TIN, typename TOUT>
cudaError_t _LaunchTranspose2DCastKernel(cudaStream_t stream, int n_rows, int n_cols,
                                         const TIN* input, TOUT* output) {
  if (n_rows < 0 || n_cols < 0) {
    return cudaErrorInvalidValue;
  }
  // A zero-sized grid is an invalid launch configuration; there is nothing to do anyway.
  if (0 == n_rows || 0 == n_cols) {
    return cudaSuccess;
  }
  // One block per TILE_DIM x TILE_DIM tile, rounding up so partial edge tiles are covered.
  dim3 dimGrid((n_cols + TILE_DIM - 1) / TILE_DIM, (n_rows + TILE_DIM - 1) / TILE_DIM, 1);
  dim3 dimBlock(TILE_DIM, BLOCK_ROWS, 1);
  using TTIN = typename contrib::CudaT<TIN>::MappedType;
  using TTOUT = typename contrib::CudaT<TOUT>::MappedType;
  // The kernel's tile is a static __shared__ array, so no dynamic shared memory is requested.
  Transpose2DCastKernel<TTOUT, TTIN><<<dimGrid, dimBlock, 0, stream>>>(
      reinterpret_cast<TTOUT*>(output), reinterpret_cast<const TTIN*>(input), n_rows, n_cols);
  // Launch-configuration errors are only observable through cudaGetLastError().
  return cudaGetLastError();
}
45+
46+
// Explicit specialization: float input, ortc::MFloat16 output.
// Forwards to the generic launcher and returns its status unchanged.
template <>
cudaError_t LaunchTranspose2DCastKernel<float, ortc::MFloat16>(cudaStream_t stream, int n_rows, int n_cols,
                                                               const float* input, ortc::MFloat16* output) {
  const cudaError_t status =
      _LaunchTranspose2DCastKernel<float, ortc::MFloat16>(stream, n_rows, n_cols, input, output);
  return status;
}
51+
52+
// Explicit specialization: ortc::MFloat16 input, float output.
// Forwards to the generic launcher and returns its status unchanged.
template <>
cudaError_t LaunchTranspose2DCastKernel<ortc::MFloat16, float>(cudaStream_t stream, int n_rows, int n_cols,
                                                               const ortc::MFloat16* input, float* output) {
  const cudaError_t status =
      _LaunchTranspose2DCastKernel<ortc::MFloat16, float>(stream, n_rows, n_cols, input, output);
  return status;
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
#include <cuda.h>
6+
#include <cuda_runtime.h>
7+
8+
template <typename TIN, typename TOUT>
9+
cudaError_t LaunchTranspose2DCastKernel(cudaStream_t stream, int n_rows, int n_cols, const TIN* input, TOUT* output);

test/cuda/test_cudaops.py

Lines changed: 74 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,20 @@ def _run(self, X):
2121
return (1 - X,)
2222

2323

24+
class Transpose2DCastFP16(OpRun):
    """Reference implementation of the custom op: transpose X, then cast to float16."""

    op_domain = "ai.onnx.contrib"

    def _run(self, X):
        # numpy arrays have no ``.to`` method; ``astype`` performs the cast.
        return (X.T.astype(np.float16),)
29+
30+
31+
class Transpose2DCastFP32(OpRun):
    """Reference implementation of the custom op: transpose X, then cast to float32."""

    op_domain = "ai.onnx.contrib"

    def _run(self, X):
        # numpy arrays have no ``.to`` method; ``astype`` performs the cast.
        return (X.T.astype(np.float32),)
36+
37+
2438
class TestCudaOps(unittest.TestCase):
2539
@staticmethod
2640
def _create_negpos_test_model(domain="ai.onnx.contrib"):
@@ -151,8 +165,6 @@ def test_cuda_negxplus1(self):
151165
self._negxplus1_cuda(TensorProto.FLOAT16)
152166

153167
def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3, 2, 3), shapec=(3, 2, 3)):
154-
from onnx_extended.ortops.optim.cuda import get_ort_ext_libs
155-
156168
model1 = helper.make_model(
157169
helper.make_graph(
158170
[
@@ -181,7 +193,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3,
181193
f"{op_type}SharedInput",
182194
["X", "Y", "Z"],
183195
["XY", "XZ"],
184-
domain="onnx_extended.ortops.optim.cuda",
196+
domain="ai.onnx.contrib",
185197
)
186198
],
187199
"nd",
@@ -197,7 +209,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3,
197209
),
198210
opset_imports=[
199211
helper.make_opsetid("", 18),
200-
helper.make_opsetid("onnx_extended.ortops.optim.cuda", 1),
212+
helper.make_opsetid("ai.onnx.contrib", 1),
201213
],
202214
ir_version=9,
203215
)
@@ -212,7 +224,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3,
212224
expected = ref.run(None, feeds1)
213225

214226
opts = _ort.SessionOptions()
215-
opts.register_custom_ops_library(get_ort_ext_libs()[0])
227+
opts.register_custom_ops_library(_get_library_path())
216228
sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
217229
got = sess.run(None, feeds1)
218230
for i in range(2):
@@ -262,6 +274,62 @@ def test_add_shared_input_cuda_broadcast2(self):
262274
shapec=(3, 2, 3),
263275
)
264276

277+
def _transpose_cast_cuda(self, itype, shape=(32, 32 * 3)):
    """Compare the fused custom op against a Transpose + Cast reference model.

    :param itype: input element type, ``TensorProto.FLOAT`` or ``TensorProto.FLOAT16``;
        the output uses the other of the two types
    :param shape: ``(rows, cols)`` of the generated input matrix
    """
    dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
    itype2 = TensorProto.FLOAT if itype == TensorProto.FLOAT16 else TensorProto.FLOAT16

    # Reference model: standard Transpose followed by Cast.
    model1 = helper.make_model(
        helper.make_graph(
            [
                helper.make_node("Transpose", ["X"], ["t"], perm=[1, 0]),
                helper.make_node("Cast", ["t"], ["Y"], to=itype2),
            ],
            "nd",
            [helper.make_tensor_value_info("X", itype, [None, None])],
            [helper.make_tensor_value_info("Y", itype2, [None, None])],
        ),
        opset_imports=[helper.make_opsetid("", 18)],
        ir_version=9,
    )

    # Model under test: a single fused custom node, named after its output type.
    model2 = helper.make_model(
        helper.make_graph(
            [
                helper.make_node(
                    ("Transpose2DCastFP16" if itype2 == TensorProto.FLOAT16 else "Transpose2DCastFP32"),
                    ["X"],
                    ["Y"],
                    domain="ai.onnx.contrib",
                )
            ],
            "nd",
            [helper.make_tensor_value_info("X", itype, [None, None])],
            [helper.make_tensor_value_info("Y", itype2, [None, None])],
        ),
        opset_imports=[
            helper.make_opsetid("", 18),
            helper.make_opsetid("ai.onnx.contrib", 1),
        ],
        ir_version=9,
    )

    n_rows, n_cols = shape
    x = (np.arange(n_rows * n_cols) + 1).reshape((n_rows, n_cols)).astype(dtype)

    feeds1 = dict(X=x)
    ref = ReferenceEvaluator(model1, new_ops=[Transpose2DCastFP16, Transpose2DCastFP32])
    expected = ref.run(None, feeds1)[0]

    opts = _ort.SessionOptions()
    opts.register_custom_ops_library(_get_library_path())
    sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
    got = sess.run(None, feeds1)[0]
    assert_almost_equal(expected, got, decimal=5)
327+
328+
@unittest.skipIf(not has_cuda(), reason="cuda not available")
def test_transpose_cast_cuda(self):
    """Run the transpose + cast comparison for float32 input, then float16 input."""
    for input_type in (TensorProto.FLOAT, TensorProto.FLOAT16):
        self._transpose_cast_cuda(input_type)
332+
265333

266334
if __name__ == "__main__":
267-
unittest.main()
335+
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)