Skip to content

Commit a4c1aac

Browse files
authored
store activation cls instead of function (#10832)
* store cls instead of an obj * style
1 parent b2ca39c commit a4c1aac

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

src/diffusers/models/activations.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@
2424
if is_torch_npu_available():
2525
import torch_npu
2626

27-
ACTIVATION_FUNCTIONS = {
28-
"swish": nn.SiLU(),
29-
"silu": nn.SiLU(),
30-
"mish": nn.Mish(),
31-
"gelu": nn.GELU(),
32-
"relu": nn.ReLU(),
27+
ACT2CLS = {
28+
"swish": nn.SiLU,
29+
"silu": nn.SiLU,
30+
"mish": nn.Mish,
31+
"gelu": nn.GELU,
32+
"relu": nn.ReLU,
3333
}
3434

3535

@@ -44,10 +44,10 @@ def get_activation(act_fn: str) -> nn.Module:
4444
"""
4545

4646
act_fn = act_fn.lower()
47-
if act_fn in ACTIVATION_FUNCTIONS:
48-
return ACTIVATION_FUNCTIONS[act_fn]
47+
if act_fn in ACT2CLS:
48+
return ACT2CLS[act_fn]()
4949
else:
50-
raise ValueError(f"Unsupported activation function: {act_fn}")
50+
raise ValueError(f"activation function {act_fn} not found in ACT2FN mapping {list(ACT2CLS.keys())}")
5151

5252

5353
class FP32SiLU(nn.Module):

0 commit comments

Comments
 (0)