🐛 Describe the bug
When using an in-place activation (tested with relu and elu), lowering appears to alter the graph outputs. I'm not entirely sure what the expected behavior is, but this seems incorrect.
In the example below, note how the lowered forward returns two values instead of one, and the first element is no longer the primary output of the method.
Example Graph (after export):
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 4, 5]"):
            # File: /var/folders/90/5w9gk0fn4n3g7fw1bvq8r1_m0000gn/T/ipykernel_98835/3298680854.py:12 in forward, code: y = torch.nn.functional.relu(x, inplace=True)
            relu_: "f32[3, 4, 5]" = torch.ops.aten.relu_.default(x); x = None

            # File: /var/folders/90/5w9gk0fn4n3g7fw1bvq8r1_m0000gn/T/ipykernel_98835/3298680854.py:13 in forward, code: return x + y
            add: "f32[3, 4, 5]" = torch.ops.aten.add.Tensor(relu_, relu_); relu_ = None
            return (add,)
After lowering:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 4, 5]"):
            # No stacktrace found for following nodes
            alloc: "f32[3, 4, 5]" = executorch_exir_memory_alloc(((3, 4, 5), torch.float32))

            # File: /var/folders/90/5w9gk0fn4n3g7fw1bvq8r1_m0000gn/T/ipykernel_98835/3298680854.py:12 in forward, code: y = torch.nn.functional.relu(x, inplace=True)
            aten_relu_default: "f32[3, 4, 5]" = torch.ops.aten.relu.out(x, out = alloc); alloc = None

            # No stacktrace found for following nodes
            alloc_1: "f32[3, 4, 5]" = executorch_exir_memory_alloc(((3, 4, 5), torch.float32))

            # File: /var/folders/90/5w9gk0fn4n3g7fw1bvq8r1_m0000gn/T/ipykernel_98835/3298680854.py:13 in forward, code: return x + y
            aten_add_tensor: "f32[3, 4, 5]" = torch.ops.aten.add.out(aten_relu_default, aten_relu_default, out = alloc_1); alloc_1 = None

            # No stacktrace found for following nodes
            alloc_2: "f32[3, 4, 5]" = executorch_exir_memory_alloc(((3, 4, 5), torch.float32))
            aten_copy_default: "f32[3, 4, 5]" = torch.ops.aten.copy.out(x, aten_relu_default, out = alloc_2); aten_relu_default = alloc_2 = None
            copy_: "f32[3, 4, 5]" = torch.ops.aten.copy_.default(x, aten_copy_default); x = aten_copy_default = None
            return (copy_, aten_add_tensor)
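For reference, the extra output can also be seen programmatically via the graph signature of the lowered program (using the lowered object from the repro below). This is a minimal sketch assuming the standard torch.export graph_signature API; the first entry corresponds to the copy_ node above rather than the user-visible result:
# Hedged sketch: list the output specs of the lowered program.
# "lowered" is the ExecutorchProgramManager built in the repro below.
for spec in lowered.exported_program().graph_signature.output_specs:
    print(spec.kind, spec.arg.name)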
Repro:
import torch

from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower
from executorch.extension.pybindings.portable_lib import _load_for_executorch_from_buffer


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        y = torch.nn.functional.relu(x, inplace=True)
        return x + y


model = Model()
inputs = (torch.randn(3, 4, 5),)

eager_outputs = model(*inputs)
print(f"Eager: {eager_outputs.shape} {eager_outputs}")

ep = torch.export.export(model, inputs)
lowered = to_edge_transform_and_lower(
    ep,
    # partitioner=[CoreMLPartitioner()],
).to_executorch()

print(ep)
print(lowered.exported_program())
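If it helps with triage, here is a rough sketch of how the extra output shows up at runtime, using the _load_for_executorch_from_buffer import from the repro (assuming the usual pybindings API where forward takes the input tuple and returns a list of outputs):
# Hedged sketch: run the lowered program and compare the output count
# against eager mode, which returns a single tensor.
et_module = _load_for_executorch_from_buffer(lowered.buffer)
et_outputs = et_module.forward(inputs)
print(f"ExecuTorch output count: {len(et_outputs)}, eager output count: 1")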
Versions
Nightly