Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add fused_adamw operator on dipu #935

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dipu/SupportedDiopiFunctions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ diopiForeachmulInpTensor
diopiForeachmulScalar
diopiForeachmulTensor
diopiForeachnormScalar
diopiFusedAdamW
diopiGather
diopiGe
diopiGeInp
Expand Down
36 changes: 36 additions & 0 deletions dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1325,6 +1325,42 @@
::diopiConstTensorHandle_t self_dtype_diopi = dipu::diopi_helper::toDiopiTensorHandle(self_dtype);
interface: diopiProd(ctx, out, self_dtype_diopi, nullptr)

- schema: "_fused_adamw_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()"
custom_code_at_the_beginning: |
std::vector<diopiTensorHandle_t> diopiTensorHandles_self(self.size());
for(size_t i=0; i < self.size(); ++i){
diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(self.at(i));
diopiTensorHandle_t handle = const_cast<diopiTensorHandle_t>(const_handle);
diopiTensorHandles_self[i] = handle;
}
std::vector<diopiConstTensorHandle_t> diopiTensorHandles_grads(grads.size());
for(size_t i=0; i < grads.size(); ++i){
diopiTensorHandles_grads[i] = dipu::diopi_helper::toDiopiTensorHandle(grads.at(i));
}
std::vector<diopiTensorHandle_t> diopiTensorHandles_exp_avgs(exp_avgs.size());
for(size_t i=0; i < exp_avgs.size(); ++i){
diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(exp_avgs.at(i));
diopiTensorHandle_t handle = const_cast<diopiTensorHandle_t>(const_handle);
diopiTensorHandles_exp_avgs[i] = handle;
}
std::vector<diopiTensorHandle_t> diopiTensorHandles_exp_avg_sqs(exp_avg_sqs.size());
for(size_t i=0; i < exp_avg_sqs.size(); ++i){
diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(exp_avg_sqs.at(i));
diopiTensorHandle_t handle = const_cast<diopiTensorHandle_t>(const_handle);
diopiTensorHandles_exp_avg_sqs[i] = handle;
}
std::vector<diopiTensorHandle_t> diopiTensorHandles_max_exp_avg_sqs(max_exp_avg_sqs.size());
for(size_t i=0; i < max_exp_avg_sqs.size(); ++i){
diopiConstTensorHandle_t const_handle = dipu::diopi_helper::toDiopiTensorHandle(max_exp_avg_sqs.at(i));
diopiTensorHandle_t handle = const_cast<diopiTensorHandle_t>(const_handle);
diopiTensorHandles_max_exp_avg_sqs[i] = handle;
}
std::vector<diopiConstTensorHandle_t> diopiTensorHandles_state_steps(state_steps.size(), nullptr);
for(size_t i=0; i < state_steps.size(); ++i){
diopiTensorHandles_state_steps[i] = dipu::diopi_helper::toDiopiTensorHandle(state_steps.at(i));
}
interface: diopiFusedAdamW(ctx, diopiTensorHandles_self.data(), diopiTensorHandles_grads.data(), diopiTensorHandles_exp_avgs.data(), diopiTensorHandles_exp_avg_sqs.data(), diopiTensorHandles_max_exp_avg_sqs.data(), diopiTensorHandles_state_steps.data(), static_cast<int64_t>(self.size()), lr, beta1, beta2, eps, weight_decay, amsgrad, maximize)

- schema: prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
custom_code_at_the_beginning: |
const auto self_dtype = at::native::to(self, dtype);
Expand Down
2 changes: 1 addition & 1 deletion dipu/third_party/DIOPI
Submodule DIOPI updated 36 files
+1 −1 .github/workflows/main.yml
+3 −3 diopi_test/diopi_stub/csrc/litert.cpp
+94 −10 diopi_test/python/configs/diopi_configs.py
+33 −0 diopi_test/python/conformance/customized_test.py
+69 −0 diopi_test/python/conformance/diopi_functions.py
+6 −4 diopi_test/python/conformance/gen_input.py
+19 −5 diopi_test/python/conformance/gen_output.py
+1 −0 diopi_test/python/conformance/global_op_list.py
+17 −2 impl/ascend/aclnn/adaptor.hpp
+116 −0 impl/ascend/ascend_tensor.cpp
+5 −0 impl/ascend/ascend_tensor.hpp
+4 −0 impl/ascend/common/acloprunner.hpp
+43 −0 impl/ascend/common/utils.cpp
+9 −0 impl/ascend/convert_config.yaml
+7 −15 impl/ascend/device_configs.py
+335 −0 impl/ascend/functions/index.cpp
+34 −0 impl/ascend/functions/syn_batch_norm.cpp
+107 −0 impl/ascend/functions_ext/token_attention_inference.cpp
+103 −0 impl/ascend/functions_ext/token_softmax_reducev_inference.cpp
+4 −0 impl/ascend_npu/CMakeLists.txt
+7 −4 impl/ascend_npu/ascend_config.yaml
+5 −0 impl/camb/device_configs.py
+78 −0 impl/cuda/device_configs.py
+1 −0 impl/cuda/error.cpp
+49 −1 impl/cuda/functions.cu
+13 −17 impl/cuda/test/CMakeLists.txt
+12 −0 impl/cuda/test/conform_test.cpp
+4 −1 impl/droplet/CMakeLists.txt
+1 −1 impl/droplet/test/CMakeLists.txt
+49 −0 impl/torch/functions/functions.cpp
+62 −12 impl/torch/functions/functions_ext.cpp
+14 −1 impl/torch/functions/functions_ext/flash-attention/CMakeLists.txt
+43 −38 impl/torch/functions/functions_ext/flash-attention/include/flash_attn/flash_api.h
+4 −3 impl/torch/functions/functions_sparse.cpp
+4 −3 proto/include/diopi/diopirt.h
+42 −0 proto/include/diopi/functions.h
Loading