Skip to content

Commit b60df02

Browse files
xaduprewenbingl
andauthored
Use of OrtxStatus in kernel NegXPlus1 (#732)
* first draft for NegXPlus1 * complete * fix unit test * rename one test * remove test if not cuda * switch to OrtxStatus --------- Co-authored-by: Wenbing Li <[email protected]>
1 parent 95a49fa commit b60df02

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

operators/cuda/negxplus1.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,30 @@
44
#pragma once
55
#include "ocos.h"
66
#include "negxplus1_impl.cuh"
7+
#include "ortx_common.h"
78

89
namespace contrib {
910

1011
template <typename T>
1112
struct NegXPlus1 {
1213
template <typename TDict>
13-
OrtStatusPtr OnModelAttach(const TDict& /*dict*/) {
14-
return nullptr;
14+
OrtxStatus OnModelAttach(const TDict& /*dict*/) {
15+
return {};
1516
}
16-
OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
17+
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
1718
const ortc::Tensor<T>& input,
1819
ortc::Tensor<T>& output) const {
1920
const T* input_data = input.Data();
2021
T* output_data = output.Allocate(input.Shape());
2122
auto input_length = input.NumberOfElement();
2223
if (0 == input_length) {
23-
return nullptr;
24+
return {};
2425
}
2526
LaunchNegXPlus1Kernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
2627
input_length,
2728
input_data,
2829
output_data);
29-
return nullptr;
30+
return {};
3031
}
3132
};
3233

0 commit comments

Comments
 (0)