Skip to content

Commit f67d3b1

Browse files
committed
fix merging conflicts
2 parents 52f351c + 79f3b04 commit f67d3b1

File tree

6 files changed

+182
-3
lines changed

6 files changed

+182
-3
lines changed

operators/cuda/add_mul.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ struct AddOrMulSharedInput {
2929
auto length_c = tensor_c.NumberOfElement();
3030

3131
T* output_data_ab = output_ab.Allocate(length_a <= length_b ? tensor_b.Shape() : tensor_a.Shape());
32-
T* output_data_ac = output_ab.Allocate(length_a <= length_c ? tensor_c.Shape() : tensor_a.Shape());
32+
T* output_data_ac = output_ac.Allocate(length_a <= length_c ? tensor_c.Shape() : tensor_a.Shape());
3333

3434
if (0 == input_data_a || 0 == input_data_b || 0 == input_data_c) {
3535
return {};

operators/cuda/cuda_ops.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "cuda/fast_gelu.h"
99
#include "cuda/negxplus1.h"
1010
#include "cuda/rotary.h"
11+
#include "cuda/transpose_cast.h"
1112
#endif
1213

1314
FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
@@ -18,6 +19,8 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
1819
#if ORT_API_VERSION >= 16
1920
using AddSharedInputFloat16Type = typename contrib::AddOrMulSharedInput<ortc::MFloat16, true>;
2021
using MulSharedInputFloat16Type = typename contrib::AddOrMulSharedInput<ortc::MFloat16, false>;
22+
using Transpose2DCastFloat32ToFloat16Type = typename contrib::Transpose2DCast<float, ortc::MFloat16>;
23+
using Transpose2DCastFloat16ToFloat32Type = typename contrib::Transpose2DCast<ortc::MFloat16, float>;
2124
#endif
2225

2326

@@ -37,7 +40,9 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
3740
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
3841
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
3942
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
40-
CustomCudaStructV2("Rotary", contrib::Rotary<ortc::MFloat16>)
43+
CustomCudaStructV2("Rotary", contrib::Rotary<ortc::MFloat16>),
44+
CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type),
45+
CustomCudaStructV2("Transpose2DCastFP32", Transpose2DCastFloat16ToFloat32Type)
4146
#endif
4247
#endif
4348
);

operators/cuda/transpose_cast.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
#include "ocos.h"
6+
#include "transpose_cast_impl.cuh"
7+
#include "ortx_common.h"
8+
9+
namespace contrib {
10+
11+
template <typename TIN, typename TOUT>
12+
struct Transpose2DCast {
13+
template <typename TDict>
14+
OrtxStatus OnModelAttach(const TDict& /*dict*/) {
15+
return {};
16+
}
17+
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
18+
const ortc::Tensor<TIN>& input,
19+
ortc::Tensor<TOUT>& output) const {
20+
const TIN* input_data = input.Data();
21+
auto shape = input.Shape();
22+
if (shape.size() != 2) {
23+
ORTX_CXX_API_THROW("Input must be a 2D tensor", ORT_RUNTIME_EXCEPTION);
24+
}
25+
int n_rows = static_cast<int>(shape[0]);
26+
int n_cols = static_cast<int>(shape[1]);
27+
28+
std::vector<int64_t> new_shape{static_cast<int64_t>(n_cols), static_cast<int64_t>(n_rows)};
29+
TOUT* output_data = output.Allocate(new_shape);
30+
if (0 == n_rows || 0 == n_cols) {
31+
return {};
32+
}
33+
LaunchTranspose2DCastKernel<TIN, TOUT>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
34+
n_rows, n_cols, input_data, output_data);
35+
return {};
36+
}
37+
};
38+
39+
} // namespace contrib
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "device_prop.cuh"
5+
#include "utils.cuh"
6+
#include "transpose_cast_impl.cuh"
7+
#include "cuda_type.h"
8+
9+
using namespace Ort::Custom;
10+
11+
#define TILE_DIM 32
12+
#define BLOCK_ROWS 8
13+
14+
template <typename TOUT, typename TIN>
15+
__global__ void Transpose2DCastKernel(TOUT *output_data, const TIN *input_data, int n_rows, int n_cols) {
16+
__shared__ TIN tile[TILE_DIM][TILE_DIM + 1];
17+
18+
int x = blockIdx.x * TILE_DIM + threadIdx.x;
19+
int y = blockIdx.y * TILE_DIM + threadIdx.y;
20+
// int width = gridDim.x * TILE_DIM;
21+
22+
for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS)
23+
tile[threadIdx.y + j][threadIdx.x] = input_data[(y + j) * n_cols + x];
24+
25+
__syncthreads();
26+
27+
x = blockIdx.y * TILE_DIM + threadIdx.x; // transpose block offset
28+
y = blockIdx.x * TILE_DIM + threadIdx.y;
29+
30+
for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS)
31+
output_data[(y + j) * n_rows + x] = (TOUT)(tile[threadIdx.x][threadIdx.y + j]);
32+
}
33+
34+
template <typename TIN, typename TOUT>
35+
cudaError_t _LaunchTranspose2DCastKernel(cudaStream_t stream, int n_rows, int n_cols,
36+
const TIN* input, TOUT* output) {
37+
dim3 dimGrid((n_cols + TILE_DIM - 1) / TILE_DIM, (n_rows + TILE_DIM - 1) / TILE_DIM, 1);
38+
dim3 dimBlock(TILE_DIM, BLOCK_ROWS, 1);
39+
using TTIN = typename contrib::CudaT<TIN>::MappedType;
40+
using TTOUT = typename contrib::CudaT<TOUT>::MappedType;
41+
Transpose2DCastKernel<TTOUT, TTIN><<<dimGrid, dimBlock, TILE_DIM * TILE_DIM + TILE_DIM, stream>>>(
42+
reinterpret_cast<TTOUT*>(output), reinterpret_cast<const TTIN*>(input), n_rows, n_cols);
43+
return cudaGetLastError();
44+
}
45+
46+
template <>
47+
cudaError_t LaunchTranspose2DCastKernel<float, ortc::MFloat16>(cudaStream_t stream, int n_rows, int n_cols,
48+
const float* input, ortc::MFloat16* output) {
49+
return _LaunchTranspose2DCastKernel(stream, n_rows, n_cols, input, output);
50+
}
51+
52+
template <>
53+
cudaError_t LaunchTranspose2DCastKernel<ortc::MFloat16, float>(cudaStream_t stream, int n_rows, int n_cols,
54+
const ortc::MFloat16* input, float* output) {
55+
return _LaunchTranspose2DCastKernel(stream, n_rows, n_cols, input, output);
56+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
#include <cuda.h>
6+
#include <cuda_runtime.h>
7+
8+
template <typename TIN, typename TOUT>
9+
cudaError_t LaunchTranspose2DCastKernel(cudaStream_t stream, int n_rows, int n_cols, const TIN* input, TOUT* output);

test/cuda/test_cudaops.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,20 @@ def _run(self, X):
2121
return (1 - X,)
2222

2323

24+
class Transpose2DCastFP16(OpRun):
25+
op_domain = "ai.onnx.contrib"
26+
27+
def _run(self, X):
28+
return (X.T.to(np.float16),)
29+
30+
31+
class Transpose2DCastFP32(OpRun):
32+
op_domain = "ai.onnx.contrib"
33+
34+
def _run(self, X):
35+
return (X.T.to(np.float32),)
36+
37+
2438
class TestCudaOps(unittest.TestCase):
2539
@staticmethod
2640
def _create_negpos_test_model(domain="ai.onnx.contrib"):
@@ -321,6 +335,62 @@ def test_bigger_rotary_cuda(self):
321335
self._rotary_cuda(TensorProto.FLOAT16, "left", input_shape=sh)
322336
self._rotary_cuda(TensorProto.FLOAT16, "right", input_shape=sh)
323337

338+
def _transpose_cast_cuda(self, itype):
339+
dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
340+
itype2 = TensorProto.FLOAT if itype == TensorProto.FLOAT16 else TensorProto.FLOAT16
341+
model1 = helper.make_model(
342+
helper.make_graph(
343+
[
344+
helper.make_node("Transpose", ["X"], ["t"], perm=[1, 0]),
345+
helper.make_node("Cast", ["t"], ["Y"], to=itype2),
346+
],
347+
"nd",
348+
[helper.make_tensor_value_info("X", itype, [None, None])],
349+
[helper.make_tensor_value_info("Y", itype2, [None, None])],
350+
),
351+
opset_imports=[helper.make_opsetid("", 18)],
352+
ir_version=9,
353+
)
354+
355+
model2 = helper.make_model(
356+
helper.make_graph(
357+
[
358+
helper.make_node(
359+
("Transpose2DCastFP16" if itype2 == TensorProto.FLOAT16 else "Transpose2DCastFP32"),
360+
["X"],
361+
["Y"],
362+
domain="ai.onnx.contrib",
363+
)
364+
],
365+
"nd",
366+
[helper.make_tensor_value_info("X", itype, [None, None])],
367+
[helper.make_tensor_value_info("Y", itype2, [None, None])],
368+
),
369+
opset_imports=[
370+
helper.make_opsetid("", 18),
371+
helper.make_opsetid("ai.onnx.contrib", 1),
372+
],
373+
ir_version=9,
374+
)
375+
376+
dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
377+
x = (np.arange(32 * 32 * 3) + 1).reshape((32, 32 * 3)).astype(dtype)
378+
379+
feeds1 = dict(X=x)
380+
ref = ReferenceEvaluator(model1, new_ops=[Transpose2DCastFP16, Transpose2DCastFP32])
381+
expected = ref.run(None, feeds1)[0]
382+
383+
opts = _ort.SessionOptions()
384+
opts.register_custom_ops_library(_get_library_path())
385+
sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
386+
got = sess.run(None, feeds1)[0]
387+
assert_almost_equal(expected, got, decimal=5)
388+
389+
@unittest.skipIf(not has_cuda(), reason="cuda not available")
390+
def test_transpose_cast_cuda(self):
391+
self._transpose_cast_cuda(TensorProto.FLOAT)
392+
self._transpose_cast_cuda(TensorProto.FLOAT16)
393+
324394

325395
if __name__ == "__main__":
326-
unittest.main()
396+
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)