|
20 | 20 | from torchvision.ops.misc import FrozenBatchNorm2d
|
21 | 21 |
|
22 | 22 | from .create_act import create_act_layer
|
23 |
| -from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm |
| 23 | +from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, rms_norm2d, fast_rms_norm2d |
| 24 | +from .norm import RmsNorm, RmsNorm2d |
24 | 25 | from .trace_utils import _assert
|
25 | 26 |
|
| 27 | +try: |
| 28 | + from torch.nn.functional import rms_norm |
| 29 | +except ImportError: |
| 30 | + from .fast_norm import rms_norm |
| 31 | + |
26 | 32 |
|
27 | 33 | def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True):
|
28 | 34 | act_kwargs = act_kwargs or {}
|
@@ -460,3 +466,69 @@ def forward(self, x):
|
460 | 466 | x = self.drop(x)
|
461 | 467 | x = self.act(x)
|
462 | 468 | return x
|
| 469 | + |
| 470 | + |
| 471 | +class RmsNormAct(RmsNorm): |
| 472 | + """ RMSNorm + Activation, normalizing over the trailing `normalized_shape` dim(s) |
| 473 | +
|
| 474 | + NOTE: For '2D' NCHW tensors use RmsNormAct2d instead. It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction |
| 475 | + on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something |
| 476 | + like https://github.com/pytorch/pytorch/pull/150576 lands. |
| 477 | + """ |
| 478 | + def __init__( |
| 479 | + self, |
| 480 | + num_channels, |
| 481 | + eps=1e-6, |
| 482 | + affine=True, |
| 483 | + apply_act=True, |
| 484 | + act_layer=nn.ReLU, |
| 485 | + act_kwargs=None, |
| 486 | + inplace=True, |
| 487 | + drop_layer=None, |
| 488 | + ): |
| 489 | + super().__init__(channels=num_channels, eps=eps, affine=affine) |
| 490 | + self.drop = drop_layer() if drop_layer is not None else nn.Identity() |
| 491 | + self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) |
| 492 | + self._fast_norm = is_fast_norm() |
| 493 | + |
| 494 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 495 | + if self._fast_norm: |
| 496 | + x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps) |
| 497 | + else: |
| 498 | + x = rms_norm(x, self.normalized_shape, self.weight, self.eps) |
| 499 | + x = self.drop(x) |
| 500 | + x = self.act(x) |
| 501 | + return x |
| 502 | + |
| 503 | + |
| 504 | +class RmsNormAct2d(RmsNorm2d): |
| 505 | + """ RMSNorm + Activation for '2D' NCHW tensors |
| 506 | +
|
| 507 | + NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction |
| 508 | + on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something |
| 509 | + like https://github.com/pytorch/pytorch/pull/150576 lands. |
| 510 | + """ |
| 511 | + def __init__( |
| 512 | + self, |
| 513 | + num_channels, |
| 514 | + eps=1e-6, |
| 515 | + affine=True, |
| 516 | + apply_act=True, |
| 517 | + act_layer=nn.ReLU, |
| 518 | + act_kwargs=None, |
| 519 | + inplace=True, |
| 520 | + drop_layer=None, |
| 521 | + ): |
| 522 | + super().__init__(channels=num_channels, eps=eps, affine=affine) |
| 523 | + self.drop = drop_layer() if drop_layer is not None else nn.Identity() |
| 524 | + self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) |
| 525 | + self._fast_norm = is_fast_norm() |
| 526 | + |
| 527 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 528 | + if self._fast_norm: |
| 529 | + x = fast_rms_norm2d(x, self.normalized_shape, self.weight, self.eps) |
| 530 | + else: |
| 531 | + x = rms_norm2d(x, self.normalized_shape, self.weight, self.eps) |
| 532 | + x = self.drop(x) |
| 533 | + x = self.act(x) |
| 534 | + return x |
0 commit comments