Commit 8e7eebf

[1.8] Fix onnx mixed precision export for layernorm & fuseLogSoftmaxNllLoss (pytorch#52510)
Co-authored-by: Shubham Bhokare <[email protected]>
Parent: f8afb8b

File tree (4 files changed, +89 −3 lines):

test/onnx/test_pytorch_onnx_onnxruntime_cuda.py
torch/csrc/jit/passes/onnx/peephole.cpp
torch/onnx/symbolic_helper.py
torch/onnx/symbolic_opset9.py
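For context, here is a minimal export sketch of the scenario the commit addresses (not part of the commit; the module shape, file name, and opset are illustrative, and a CUDA build of PyTorch is assumed):

import torch

# LayerNorm fed fp16 CUDA inputs, the case the new test_layer_norm_fp16
# exercises. Before this fix, the layer_norm symbolic emitted fixed-dtype
# fp32 constants that clashed with the Half tensors in the exported graph.
model = torch.nn.LayerNorm([10, 10]).eval().half().cuda()
x = torch.randn(20, 5, 10, 10, dtype=torch.float16, device='cuda')
torch.onnx.export(model, x, "layer_norm_fp16.onnx", opset_version=12)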

test/onnx/test_pytorch_onnx_onnxruntime_cuda.py

Lines changed: 39 additions & 0 deletions
@@ -2,6 +2,8 @@
 import onnxruntime  # noqa
 import torch
 
+from torch.cuda.amp import autocast
+
 from test_pytorch_common import skipIfUnsupportedMinOpsetVersion
 from test_pytorch_common import skipIfNoCuda
@@ -24,6 +26,43 @@ def forward(self, x):
         x = torch.randn(2, 4, 5, 6, requires_grad=True, dtype=torch.float16, device=torch.device('cuda'))
         self.run_test(GeluModel(), x, rtol=1e-3, atol=1e-5)
 
+    @skipIfUnsupportedMinOpsetVersion(9)
+    @skipIfNoCuda
+    def test_layer_norm_fp16(self):
+        class LayerNormModel(torch.nn.Module):
+            def __init__(self):
+                super(LayerNormModel, self).__init__()
+                self.layer_norm = torch.nn.LayerNorm([10, 10])
+
+            def forward(self, x):
+                return self.layer_norm(x)
+
+        x = torch.randn(20, 5, 10, 10, requires_grad=True, dtype=torch.float16, device=torch.device('cuda'))
+        self.run_test(LayerNormModel(), x, rtol=1e-3, atol=1e-5)
+
+
+    @skipIfUnsupportedMinOpsetVersion(12)
+    @skipIfNoCuda
+    def test_softmaxCrossEntropy_fusion_fp16(self):
+        class FusionModel(torch.nn.Module):
+            def __init__(self):
+                super(FusionModel, self).__init__()
+                self.loss = torch.nn.NLLLoss(reduction='none')
+                self.m = torch.nn.LogSoftmax(dim=1)
+
+            @autocast()
+            def forward(self, input, target):
+                output = self.loss(self.m(2 * input), target)
+                return output
+
+        N, C = 5, 4
+        input = torch.randn(N, 16, dtype=torch.float16, device=torch.device('cuda'))
+        target = torch.empty(N, dtype=torch.long, device=torch.device('cuda')).random_(0, C)
+
+        # using test data containing default ignore_index=-100
+        target[target == 1] = -100
+        self.run_test(FusionModel(), (input, target))
+
 TestONNXRuntime_cuda.setUp = TestONNXRuntime.setUp
 TestONNXRuntime_cuda.run_test = TestONNXRuntime.run_test
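A side note on why a Cast shows up in the traced graph at all (a sketch, not from the commit; assumes a CUDA device and standard autocast behavior, under which log_softmax runs in float32):

import torch
from torch.cuda.amp import autocast

x = torch.randn(5, 16, dtype=torch.float16, device='cuda')
with autocast():
    y = torch.nn.functional.log_softmax(2 * x, dim=1)
# torch.float32: autocast upcasts log_softmax, so the exported graph gains
# an onnx::Cast between LogSoftmax and NegativeLogLikelihoodLoss
print(y.dtype)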

torch/csrc/jit/passes/onnx/peephole.cpp

Lines changed: 32 additions & 1 deletion
@@ -668,14 +668,32 @@ static void fuseLogSoftmaxNllLoss(Block* b) {
     auto prev = it->input(0)->node();
     Node* origNllLossNode = *it;
     Node* origLogSoftmaxNode;
+
+    // Check for patterns especially in cases with autocasting enabled
+    // in which a cast node is inserted before the NegativeLogLikelihoodLoss
+    // node and this causes the patterns below not to be recognizable by the
+    // fuseLogSoftmaxNllLoss function
+    // For example if the input is 2D
+    // graph(%input : Half(3, 5),
+    //       %target : Long(3)):
+    //   %4 : Half(3, 5) = onnx::LogSoftmax[axis=1](%input)
+    //   %8 : Float = onnx::Cast[to=1](%4)
+    //   %9 : Float(3) = onnx::NegativeLogLikelihoodLoss[reduction="none"](%8, %target)
+    //   return (%9)
+    Node* castNode = nullptr;
+    if (prev->kind() == onnx::Cast) {
+      castNode = prev;
+      prev = prev->input(0)->node();
+    }
+
     if (prev->kind() == onnx::LogSoftmax) {
       // if the input is 2D
       // graph(%input : Float(3, 5),
       //       %target : Long(3)):
       //   %4 : Float(3, 5) = onnx::LogSoftmax[axis=1](%input)
       //   %8 : Float(3) = onnx::NegativeLogLikelihoodLoss[reduction="none"](%4, %target)
       //   return (%8)
-      origLogSoftmaxNode = it->input(0)->node();
+      origLogSoftmaxNode = prev;
     } else if (
         prev->kind() == onnx::Transpose &&
         prev->input(0)->node()->kind() == onnx::LogSoftmax) {
@@ -751,6 +769,19 @@ static void fuseLogSoftmaxNllLoss(Block* b) {
       continue;
     }
 
+    // If the pattern indeed consists of a cast node before the
+    // NegativeLogLikelihoodLoss node, place a cast node in the beginning
+    // of the pattern instead
+    if (castNode != nullptr) {
+      auto onnx_type = castNode->i(attr::to);
+      Node* cast_node = b->owningGraph()->create(onnx::Cast, 1);
+      cast_node->addInput(origLogSoftmaxNode->inputs().at(0));
+      cast_node->i_(attr::to, onnx_type);
+      cast_node->insertBefore(origLogSoftmaxNode);
+      origLogSoftmaxNode->replaceInputWith(
+          origLogSoftmaxNode->inputs().at(0), cast_node->output());
+    }
+
     Node* softmaxCrossEntropyNode = b->owningGraph()->create(
         onnx::SoftmaxCrossEntropyLoss, it->outputs().size());
     for (size_t i = 0; i < softmaxCrossEntropyNode->outputs().size(); ++i) {
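To see the fusion end to end, a small export check (a sketch, not from the commit; CPU and fp32 are used here for simplicity, and the model and shapes are illustrative):

import io
import onnx
import torch

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.m = torch.nn.LogSoftmax(dim=1)
        self.loss = torch.nn.NLLLoss(reduction='none')

    def forward(self, input, target):
        return self.loss(self.m(input), target)

input = torch.randn(3, 5)
target = torch.randint(0, 5, (3,))
f = io.BytesIO()
torch.onnx.export(Model(), (input, target), f, opset_version=12)
ops = [n.op_type for n in onnx.load_from_string(f.getvalue()).graph.node]
# expect a single SoftmaxCrossEntropyLoss in place of the separate
# LogSoftmax + NegativeLogLikelihoodLoss pair
print(ops)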

torch/onnx/symbolic_helper.py

Lines changed: 16 additions & 0 deletions
@@ -296,6 +296,22 @@ def _is_fp(value):
         return (type == 'Float') or (type == 'Double') or (type == 'Half')
     return False
 
+def _generate_wrapped_number(g, scalar):
+    """
+    Create a wrapped number based on https://github.com/pytorch/pytorch/issues/9515
+    A Tensor is considered a "wrapped number" if it is
+    auto-wrapped from a C++ or Python number type. Integer types are
+    wrapped as 0-dim int64 tensors and floating-point types are
+    wrapped as 0-dim double tensors.
+
+    The input to this function is a constant value. If the data type
+    is a floating point type, it is converted to a 0-dim double
+    tensor, else it is converted to a 0-dim tensor of its original type.
+    """
+    assert not isinstance(scalar, torch.Tensor)
+    if isinstance(scalar, float):
+        return g.op("Constant", value_t=torch.tensor(scalar, dtype=torch.double))
+    return g.op("Constant", value_t=torch.tensor(scalar))
 
 def _sort_helper(g, input, dim, decending=True, out=None):
     if out is not None:
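The wrapped-number rule the docstring refers to is observable in eager mode. A quick illustration (not from the commit) of why a scalar constant leaves a Half tensor's dtype alone while an ordinary fp32 tensor would widen it:

import torch

x = torch.ones(3, dtype=torch.float16)
print(torch.result_type(x, 2.0))                # torch.float16 (wrapped number)
print(torch.result_type(x, torch.tensor(2.0)))  # torch.float16 (0-dim tensor)
print(torch.result_type(x, torch.ones(3)))      # torch.float32 (1-dim tensor wins)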

torch/onnx/symbolic_opset9.py

Lines changed: 2 additions & 2 deletions
@@ -1319,8 +1319,8 @@ def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
 
     axes = [-i for i in range(len(normalized_shape), 0, -1)]
 
-    two_cst = g.op("Constant", value_t=torch.tensor(2.))
-    eps_cst = g.op("Constant", value_t=torch.tensor(eps))
+    two_cst = sym_help._generate_wrapped_number(g, 2.)
+    eps_cst = sym_help._generate_wrapped_number(g, eps)
 
     mean = g.op("ReduceMean", input, axes_i=axes)
     numerator = sub(g, input, mean)
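For reference, an eager-mode sketch of the decomposition this symbolic emits (an approximation for illustration, not the exporter code; weight, bias, and the exact op sequence are omitted):

import torch

def layer_norm_reference(x, normalized_shape, eps=1e-5):
    # mirror the symbolic: ReduceMean, Sub, Pow(2), ReduceMean, Add(eps), Sqrt, Div
    dims = tuple(range(-len(normalized_shape), 0))
    mean = x.mean(dim=dims, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=dims, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)

x = torch.randn(20, 5, 10, 10)
expected = torch.nn.functional.layer_norm(x, [10, 10])
assert torch.allclose(layer_norm_reference(x, [10, 10]), expected, atol=1e-5)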
