1 file changed (+9, −9 lines):
```diff
@@ -24,12 +24,12 @@
 if is_torch_npu_available():
     import torch_npu
 
-ACTIVATION_FUNCTIONS = {
-    "swish": nn.SiLU(),
-    "silu": nn.SiLU(),
-    "mish": nn.Mish(),
-    "gelu": nn.GELU(),
-    "relu": nn.ReLU(),
+ACT2CLS = {
+    "swish": nn.SiLU,
+    "silu": nn.SiLU,
+    "mish": nn.Mish,
+    "gelu": nn.GELU,
+    "relu": nn.ReLU,
 }
 
 
@@ -44,10 +44,10 @@ def get_activation(act_fn: str) -> nn.Module:
     """
 
     act_fn = act_fn.lower()
-    if act_fn in ACTIVATION_FUNCTIONS:
-        return ACTIVATION_FUNCTIONS[act_fn]
+    if act_fn in ACT2CLS:
+        return ACT2CLS[act_fn]()
     else:
-        raise ValueError(f"Unsupported activation function: {act_fn}")
+        raise ValueError(f"activation function {act_fn} not found in ACT2FN mapping {list(ACT2CLS.keys())}")
 
 
 class FP32SiLU(nn.Module):
```
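For context, the key behavioral change is that the mapping now stores activation *classes* rather than pre-built module instances, so `get_activation` constructs a fresh `nn.Module` on every call instead of handing all callers the same shared object. Below is a minimal, self-contained sketch of the new behavior; the mapping and function mirror the diff, while the usage lines at the end are illustrative, not part of the change:

```python
import torch.nn as nn

# Mirrors the post-change mapping: names -> activation classes (not instances).
ACT2CLS = {
    "swish": nn.SiLU,
    "silu": nn.SiLU,
    "mish": nn.Mish,
    "gelu": nn.GELU,
    "relu": nn.ReLU,
}


def get_activation(act_fn: str) -> nn.Module:
    act_fn = act_fn.lower()
    if act_fn in ACT2CLS:
        # Instantiate on lookup, so each caller owns a distinct module.
        return ACT2CLS[act_fn]()
    else:
        raise ValueError(f"activation function {act_fn} not found in ACT2FN mapping {list(ACT2CLS.keys())}")


# Before this change, ACTIVATION_FUNCTIONS mapped names to singletons, so two
# callers shared one nn.SiLU() object; now each call returns a new instance.
a = get_activation("silu")
b = get_activation("silu")
assert a is not b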