Skip to content

Commit 4c4e6c7

Browse files
authored
zq/fix ascend index_select (DeepLink-org#832)
remove device_config of ascend index_select and add diopiIndexSelectBackward
1 parent ec602a5 commit 4c4e6c7

File tree

2 files changed

+14
-24
lines changed

2 files changed

+14
-24
lines changed

impl/ascend/device_configs.py

-24
Original file line numberDiff line numberDiff line change
@@ -657,30 +657,6 @@
657657
),
658658
),
659659

660-
'index_select': dict(
661-
name=['index_select'],
662-
tensor_para=dict(
663-
args=[
664-
{
665-
"ins": ['input'],
666-
"dtype": [Skip(np.float32),Skip(np.float64),Skip(np.float16),],
667-
},
668-
]
669-
),
670-
),
671-
672-
'index_select_not_float': dict(
673-
name=['index_select'],
674-
tensor_para=dict(
675-
args=[
676-
{
677-
"ins": ['input'],
678-
"dtype": [Skip(np.int32),Skip(np.int16),Skip(np.int64),Skip(np.uint8),Skip(np.int8),Skip(np.bool_),],
679-
},
680-
]
681-
),
682-
),
683-
684660
'masked_scatter': dict(
685661
name=['masked_scatter'],
686662
tensor_para=dict(

impl/ascend/functions/index_select.cpp

100644100755
+14
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,19 @@ diopiError_t diopiIndexSelect(diopiContextHandle_t ctx, diopiTensorHandle_t out,
1515
return diopiSuccess;
1616
}
1717

18+
diopiError_t diopiIndexSelectBackward(diopiContextHandle_t ctx, diopiTensorHandle_t gradInput, diopiConstTensorHandle_t grad, diopiSize_t inputSizes,
19+
int64_t dim, diopiConstTensorHandle_t index) {
20+
AscendTensor gradInputAt(gradInput);
21+
if (dim < 0) {
22+
dim = dim + inputSizes.len;
23+
}
24+
std::vector<int64_t> dimVec({dim});
25+
diopiSize_t dimInput = vectorToDiopiSize(dimVec);
26+
diopiScalar_t scalarZero = constructDiopiScalarT(gradInputAt.dtype(), 0);
27+
diopiFill(ctx, gradInput, &scalarZero);
28+
AclOpRunner<3, 1>("InplaceIndexAdd", ctx).addInput(gradInput).addInput(index).addInput(grad).setAttr<int64_t>("axis", dim).addOutput(gradInput).run();
29+
return diopiSuccess;
30+
}
31+
1832
} // namespace ascend
1933
} // namespace impl

0 commit comments

Comments
 (0)