Skip to content

Add LayerNorm support #74

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions onnx2pytorch/convert/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ def extract_attributes(node):
)
elif attr.name == "axis" and node.op_type == "Flatten":
kwargs["start_dim"] = extract_attr_values(attr)
elif attr.name == "axis" and node.op_type == "LayerNormalization":
continue
elif attr.name == "axis" or attr.name == "axes":
v = extract_attr_values(attr)
if isinstance(v, (tuple, list)) and len(v) == 1:
Expand Down
2 changes: 2 additions & 0 deletions onnx2pytorch/convert/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=Tr
op = nn.Identity()
elif node.op_type == "InstanceNormalization":
op = convert_instance_norm_layer(node, params=params)
elif node.op_type == "LayerNormalization":
op = LayerNorm(list(params[0].dims), **extract_attributes(node))
elif node.op_type == "LeakyRelu":
op = nn.LeakyReLU(**extract_attributes(node), inplace=True)
elif node.op_type == "Less":
Expand Down
2 changes: 2 additions & 0 deletions onnx2pytorch/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .globalaveragepool import GlobalAveragePool
from .hardsigmoid import Hardsigmoid
from .instancenorm import InstanceNormWrapper
from .layernorm import LayerNorm
from .loop import Loop
from .lstm import LSTMWrapper
from .matmul import MatMul
Expand Down Expand Up @@ -55,6 +56,7 @@
"GatherND",
"GlobalAveragePool",
"InstanceNormWrapper",
"LayerNorm",
"Loop",
"LSTMWrapper",
"MatMul",
Expand Down
24 changes: 24 additions & 0 deletions onnx2pytorch/operations/layernorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import torch
from torch import nn
from typing import Optional


class LayerNorm(nn.Module): # pylint: disable=missing-docstring
def __init__(self, normalized_shape: list, eps: float):
super().__init__()
self.normalized_shape = normalized_shape
self.eps = eps

def forward( # pylint: disable=missing-function-docstring
self,
inputs: torch.Tensor,
scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return nn.functional.layer_norm(
input=inputs,
normalized_shape=self.normalized_shape,
weight=scale,
bias=bias,
eps=self.eps,
)