Skip to content

Commit ec602a5

Browse files
authored
wq/impl atan operator (DeepLink-org#812)
* wq/impl atan operator
1 parent d636cd7 commit ec602a5

File tree

5 files changed

+44
-6
lines changed

5 files changed

+44
-6
lines changed

impl/ascend/convert_config.yaml

+7-1
Original file line numberDiff line numberDiff line change
@@ -276,11 +276,17 @@
276276
- diopiRemainder:
277277
dtype: (bool, uint8, int8, int16, uint16)->int32, (float64)->float32
278278

279+
- diopiAtan:
280+
dtype: (uint8, int8, int32, int16, int64, bool)->float32
281+
282+
- diopiAtanInp:
283+
dtype: (uint8, int8, int32, int16, int64, bool)->float32
284+
279285
- diopiNormalTensor:
280286
dtype: (float64)->float32
281287

282288
- diopiNormalScalarTensor:
283289
dtype: (float64)->float32
284290

285291
- diopiNormalTensorScalar:
286-
dtype: (float64)->float32
292+
dtype: (float64)->float32

impl/ascend/device_configs.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@
338338
),
339339

340340
'pointwise_op': dict(
341-
name=['erf', 'erfinv', 'asin', 'ceil', 'atan'],
341+
name=['erf', 'erfinv', 'asin', 'ceil'],
342342
tensor_para=dict(
343343
args=[
344344
{
@@ -350,7 +350,7 @@
350350
),
351351

352352
'pointwise_op_int_without_inplace': dict(
353-
name=['erf', 'asin', 'atan'],
353+
name=['erf', 'asin'],
354354
tensor_para=dict(
355355
args=[
356356
{
@@ -362,7 +362,7 @@
362362
),
363363

364364
'pointwise_op_uint8': dict(
365-
name=['erf', 'asin', 'atan'],
365+
name=['erf', 'asin'],
366366
tensor_para=dict(
367367
args=[
368368
{
@@ -374,7 +374,7 @@
374374
),
375375

376376
'pointwise_op_bool': dict(
377-
name=['erf', 'asin', 'atan'],
377+
name=['erf', 'asin'],
378378
tensor_para=dict(
379379
args=[
380380
{

impl/ascend_npu/ascend_config.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -223,3 +223,5 @@ ascend_npu:
223223
- diopiTokenSoftmaxReduceVInference
224224
- diopiApplyPenalty
225225
- diopiContextAttentionInference
226+
- diopiAtan
227+
- diopiAtanInp

impl/ascend_npu/diopi_impl/atan.cpp

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/**
2+
* @file
3+
* @author DeepLink
4+
* @copyright (c) 2023, DeepLink.
5+
*/
6+
7+
#include "helper.hpp"
8+
#include "op_plugin/AclOpsInterface.h"
9+
10+
namespace OP_IMPL_NS {
11+
12+
diopiError_t diopiAtan(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
13+
BEGIN_CALL_ACL_OP(input, out);
14+
if (!outAt.defined() || outAt.numel() <= 0) {
15+
return diopiSuccess;
16+
}
17+
acl_op::atan_out(inputAt, outAt);
18+
END_CALL_ACL_OP();
19+
}
20+
21+
diopiError_t diopiAtanInp(diopiContextHandle_t ctx, diopiTensorHandle_t input) {
22+
BEGIN_CALL_ACL_OP(input);
23+
if (!inputAt.defined() || inputAt.numel() <= 0) {
24+
return diopiSuccess;
25+
}
26+
acl_op::atan_(inputAt);
27+
END_CALL_ACL_OP();
28+
}
29+
30+
} // namespace OP_IMPL_NS

impl/ascend_npu/diopi_impl/helper.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ template <>
326326
inline std::string dumpArgs(const at::IntArrayRef& t) {
327327
std::stringstream stream;
328328
stream << "[";
329-
for (size_t i = 0; i < t.size(); i++) {
329+
for (long i : t) {
330330
stream << i << ",";
331331
}
332332
stream << "]";

0 commit comments

Comments
 (0)