Skip to content

Commit d50f920

Browse files
Zzf/rms norm (DeepLink-org#751)
* add rms norm op * take functions_ext into the adaptor
1 parent d0ab3da commit d50f920

File tree

5 files changed

+36
-35
lines changed

5 files changed

+36
-35
lines changed

adaptor/codegen/gen.py

+3
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,10 @@ def gen_autogen_operators(
724724
# get the implemented functions
725725
impl_base_dir = os.path.dirname(config_file_path)
726726
impl_func_dir = os.path.join(impl_base_dir, "functions")
727+
impl_func_ext_dir = os.path.join(impl_base_dir, "functions_ext")
727728
impl_functions = obtain_impl_func(impl_func_dir)
729+
impl_functions_ext = obtain_impl_func(impl_func_ext_dir)
730+
impl_functions.update(impl_functions_ext)
728731

729732
if impl_plugin:
730733
impl_plugin_dir = os.path.join(impl_base_dir, "../ascend_npu/diopi_impl")

impl/ascend/device_configs.py

-35
Original file line numberDiff line numberDiff line change
@@ -1860,39 +1860,4 @@
18601860
]
18611861
),
18621862
),
1863-
1864-
'rms_norm': dict(
1865-
name=["rms_norm"],
1866-
tensor_para=dict(
1867-
args=[
1868-
{
1869-
"ins": ['input'],
1870-
"dtype": [Skip(np.float32)],
1871-
},
1872-
],
1873-
),
1874-
),
1875-
1876-
'topk_nonzero': dict(
1877-
name=['topk'],
1878-
para=dict(
1879-
k=[Skip(1)],
1880-
),
1881-
),
1882-
1883-
'topk_zero': dict(
1884-
name=['topk'],
1885-
interface=['torch'],
1886-
para=dict(
1887-
k=[Skip(1)],
1888-
),
1889-
),
1890-
1891-
# FIXME 特定参数组合报错
1892-
'embedding': dict(
1893-
name=["embedding"],
1894-
para=dict(
1895-
padding_idx=[Skip(92)],
1896-
),
1897-
),
18981863
}
+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/**
2+
* @file
3+
* @author DeepLink
4+
* @copyright (c) 2023, DeepLink.
5+
*/
6+
7+
#include "../common/acloprunner.hpp"
8+
9+
namespace impl {
10+
namespace ascend {
11+
12+
diopiError_t diopiRMSNorm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t invRms, diopiConstTensorHandle_t input,
13+
diopiSize_t normalizedShape, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, double eps) {
14+
AscendTensor inputTensor(input);
15+
ASCEND_CHECK_ABORT(1 == normalizedShape.len && normalizedShape.data[0] == inputTensor.shape()[inputTensor.dim() - 1], "normalized shape error!");
16+
AclOpRunner<2, 2>("RmsNorm", ctx).addInput(input).addInput(weight).setAttr("epsilon", static_cast<float>(eps)).addOutput(out).addOutput(invRms).run();
17+
return diopiSuccess;
18+
}
19+
20+
diopiError_t diopiRMSNormBackward(diopiContextHandle_t ctx, diopiTensorHandle_t gradInput, diopiTensorHandle_t gradWeight, diopiTensorHandle_t gradBias,
21+
diopiConstTensorHandle_t gradOutput, diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight,
22+
diopiConstTensorHandle_t bias, diopiConstTensorHandle_t invRms, diopiSize_t normalizedShape, double eps) {
23+
AscendTensor inputTensor(input);
24+
ASCEND_CHECK_ABORT(1 == normalizedShape.len && normalizedShape.data[0] == inputTensor.shape()[inputTensor.dim() - 1], "normalized shape error!");
25+
AclOpRunner<4, 2>("RmsNorm", ctx).addInput(gradOutput).addInput(input).addInput(invRms).addInput(weight).addOutput(gradInput).addOutput(gradWeight).run();
26+
return diopiSuccess;
27+
}
28+
29+
} // namespace ascend
30+
} // namespace impl

impl/ascend_npu/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,7 @@ set(OLD_IMPL_SRC
640640
${OLD_IMPL_DIR}/functions/linspace.cpp
641641
${OLD_IMPL_DIR}/functions/apply_penalty.cpp
642642
${OLD_IMPL_DIR}/functions/split.cpp
643+
${OLD_IMPL_DIR}/functions_ext/rms_norm.cpp
643644
#${OLD_IMPL_DIR}/test/export_functions.cpp
644645
#${OLD_IMPL_DIR}/test/conform_test.cpp
645646
${OLD_IMPL_DIR}/common/utils.cpp

impl/ascend_npu/ascend_config.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ ascend:
192192
- diopiConvolution2dBackward
193193
- diopiLogicalAnd
194194
- diopiLogicalOr
195+
- diopiRMSNorm
196+
- diopiRMSNormBackward
195197
- diopiExpand
196198
- diopiLinspace
197199
- diopiProd

0 commit comments

Comments
 (0)