Skip to content

Commit 03056c2

Browse files
authored
[Kunlunxin] Add Flip/MaxPool2dWithIndices/MaxPool2dGrad (DeepLink-org#865)
* [kunlunxin] fix ci - add missing file * [CI][kunlunxin]fix sub dtype * [CI][kunlunxin]fix reduce_mean&reduce_sum;add more test cases * [KUNLUNXIN] add flip/max_pool2d_with_indices/max_pool2d_grad
1 parent 42ae84b commit 03056c2

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

impl/kunlunxin/convert_config.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,8 @@
11
- common_config:
22
dtype: (float64)->float32, (int64)->int32
3+
4+
- diopiMaxPool2dWithIndices:
5+
layout: NCHW
6+
7+
- diopiMaxPool2dBackward:
8+
layout: NCHW

impl/kunlunxin/functions/basic_op.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,5 +514,47 @@ DIOPI_API diopiError_t diopiCat(diopiContextHandle_t ctx, diopiTensorHandle_t ou
514514
return diopiSuccess;
515515
}
516516

517+
DIOPI_API diopiError_t diopiFlip(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiSize_t dims) {
518+
xdnn::Context* ctx_xpu = impl::kunlunxin::set_cur_ctx(ctx);
519+
xdnn_pytorch::Tensor _in = impl::kunlunxin::build_xtorch_tensor(input);
520+
xdnn_pytorch::Tensor _out = impl::kunlunxin::build_xtorch_tensor(out);
521+
xtorch_vec _dims = impl::kunlunxin::build_xtorch_vec(dims);
522+
523+
DIOPI_CALL_XDNN(xdnn_pytorch::flip(ctx_xpu, _in, _dims, _out));
524+
return diopiSuccess;
525+
}
526+
527+
DIOPI_API diopiError_t diopiMaxPool2dBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiConstTensorHandle_t grad_output,
528+
diopiConstTensorHandle_t input, diopiSize_t kernel_size, diopiSize_t stride, diopiSize_t padding,
529+
diopiSize_t dilation, bool ceil_mode, diopiConstTensorHandle_t indices) {
530+
xdnn::Context* ctx_xpu = impl::kunlunxin::set_cur_ctx(ctx);
531+
xdnn_pytorch::Tensor _in = impl::kunlunxin::build_xtorch_tensor(input);
532+
xdnn_pytorch::Tensor _grad_in = impl::kunlunxin::build_xtorch_tensor(grad_input);
533+
xdnn_pytorch::Tensor _grad_out = impl::kunlunxin::build_xtorch_tensor(grad_output);
534+
xdnn_pytorch::Tensor _indices = impl::kunlunxin::build_xtorch_tensor(indices);
535+
xtorch_vec _kernel_size = impl::kunlunxin::build_xtorch_vec(kernel_size);
536+
xtorch_vec _stride = impl::kunlunxin::build_xtorch_vec(stride);
537+
xtorch_vec _padding = impl::kunlunxin::build_xtorch_vec(padding);
538+
xtorch_vec _dilation = impl::kunlunxin::build_xtorch_vec(dilation);
539+
540+
DIOPI_CALL_XDNN(
541+
xdnn_pytorch::max_pool2d_with_indices_backward(ctx_xpu, _grad_out, _in, _kernel_size, _stride, _padding, _dilation, ceil_mode, _indices, _grad_in));
542+
return diopiSuccess;
543+
}
544+
545+
DIOPI_API diopiError_t diopiMaxPool2dWithIndices(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t indices, diopiConstTensorHandle_t input,
546+
diopiSize_t kernel_size, diopiSize_t stride, diopiSize_t padding, diopiSize_t dilation, bool ceil_mode) {
547+
xdnn::Context* ctx_xpu = impl::kunlunxin::set_cur_ctx(ctx);
548+
xdnn_pytorch::Tensor _in = impl::kunlunxin::build_xtorch_tensor(input);
549+
xdnn_pytorch::Tensor _out = impl::kunlunxin::build_xtorch_tensor(out);
550+
xdnn_pytorch::Tensor _indices = impl::kunlunxin::build_xtorch_tensor(indices);
551+
xtorch_vec _kernel_size = impl::kunlunxin::build_xtorch_vec(kernel_size);
552+
xtorch_vec _stride = impl::kunlunxin::build_xtorch_vec(stride);
553+
xtorch_vec _padding = impl::kunlunxin::build_xtorch_vec(padding);
554+
xtorch_vec _dilation = impl::kunlunxin::build_xtorch_vec(dilation);
555+
DIOPI_CALL_XDNN(xdnn_pytorch::max_pool2d_with_indices(ctx_xpu, _in, _kernel_size, _stride, _padding, _dilation, ceil_mode, _out, _indices));
556+
return diopiSuccess;
557+
}
558+
517559
} // namespace kunlunxin
518560
} // namespace impl

0 commit comments

Comments
 (0)