
Commit f3035be

Fixed layernorm when weight and bias is None in Stable Diffusion 3 (#2936)
1 parent f909b07 commit f3035be

File tree

2 files changed: 21 additions & 3 deletions


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 3 additions & 3 deletions
@@ -27,7 +27,7 @@
 def args_bounds_check(
     args: Tuple[Argument, ...], i: int, replacement: Optional[Any] = None
 ) -> Any:
-    return args[i] if len(args) > i else replacement
+    return args[i] if len(args) > i and args[i] is not None else replacement


 def get_ir(target: Target) -> SourceIR:
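
The helper change is self-contained enough to try in isolation. Below is a minimal standalone sketch of the patched args_bounds_check; the real file types args with TorchFX's Argument, while plain values are used here for illustration:

from typing import Any, Optional, Tuple

# Standalone copy of the patched helper, for illustration only.
def args_bounds_check(
    args: Tuple[Any, ...], i: int, replacement: Optional[Any] = None
) -> Any:
    # Before the fix, an explicit None at position i was returned as-is;
    # the added "args[i] is not None" guard makes None fall back to the
    # replacement value, just like an out-of-range index.
    return args[i] if len(args) > i and args[i] is not None else replacement

args = ("input", "shape", None)          # weight passed explicitly as None
print(args_bounds_check(args, 2, 1.0))   # 1.0 (was None before the fix)
print(args_bounds_check(args, 5, True))  # True (index out of range)

With the old body, a None passed in place of an optional argument leaked through to the converter; the guard routes it to the replacement instead.
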
@@ -156,8 +156,8 @@ def aten_ops_layer_norm(
        name,
        input=args[0],
        normalized_shape=args[1],
-        weight=args_bounds_check(args, 2),
-        bias=args_bounds_check(args, 3),
+        weight=args_bounds_check(args, 2, 1.0),
+        bias=args_bounds_check(args, 3, 0.0),
        eps=args_bounds_check(args, 4, 1e-05),
        cudnn_enable=args_bounds_check(args, 5, True),
        return_mean_rstd=(target == torch.ops.aten.native_layer_norm.default),
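
The defaults of 1.0 and 0.0 are safe because layer normalization without affine parameters is exactly affine normalization with unit scale and zero shift. A quick eager-mode check of that equivalence, using torch.nn.functional.layer_norm rather than the TensorRT converter itself:

import torch
import torch.nn.functional as F

x = torch.randn(5, 3, 2, 4)
normalized_shape = [2, 4]

# PyTorch treats missing affine parameters as scale 1 and shift 0,
# which is what the converter's new scalar defaults reproduce.
out_none = F.layer_norm(x, normalized_shape, weight=None, bias=None)
out_ones = F.layer_norm(
    x,
    normalized_shape,
    weight=torch.ones(normalized_shape),
    bias=torch.zeros(normalized_shape),
)
print(torch.allclose(out_none, out_ones))  # True
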

tests/py/dynamo/conversion/test_layer_norm_aten.py

Lines changed: 18 additions & 0 deletions
@@ -119,6 +119,24 @@ def forward(self, x):
            input_specs,
        )

+    @parameterized.expand([((5, 3, 2, 4), [2, 4])])
+    def test_layer_norm_without_Scaling(self, input_shape, normalized_shape, eps=1e-05):
+        class LayerNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.native_layer_norm.default(
+                    x,
+                    normalized_shape,
+                    None,
+                    None,
+                    eps,
+                )[0]
+
+        inputs = [torch.randn(input_shape)]
+        self.run_test(
+            LayerNorm(),
+            inputs,
+        )
+

 if __name__ == "__main__":
     run_tests()
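
For context on the [0] in the new test: aten.native_layer_norm returns an (output, mean, rstd) triple, and the test compares only the normalized output. A quick eager-mode illustration, with shapes chosen to match the test's parameterization:

import torch

x = torch.randn(5, 3, 2, 4)
# native_layer_norm with weight=None and bias=None is the case this
# commit fixes; the test keeps only the first element of the triple.
out, mean, rstd = torch.ops.aten.native_layer_norm.default(
    x, [2, 4], None, None, 1e-05
)
print(out.shape)   # torch.Size([5, 3, 2, 4])
print(mean.shape)  # torch.Size([5, 3, 1, 1])
print(rstd.shape)  # torch.Size([5, 3, 1, 1])
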
