Labels: `backend tester` (found by the backend test suite), `module: coreml` (Core ML delegation, code under backends/apple/coreml/)
Description
🐛 Describe the bug
Running BatchNorm3d ops through the Core ML delegate crashes the process fairly reliably.
Repro:

```python
import torch
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
from executorch.extension.pybindings.portable_lib import _load_for_executorch_from_buffer


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.BatchNorm3d(3)

    def forward(self, x):
        return self.norm(x)


model = Model()
inputs = (torch.randn(1, 3, 4, 4, 4),)

eager_outputs = model(*inputs)
print(f"Eager: {eager_outputs.shape} {eager_outputs}")

ep = torch.export.export(model.eval(), inputs)
lowered = to_edge_transform_and_lower(
    ep,
    partitioner=[CoreMLPartitioner()],
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
).to_executorch()

print(ep)
print(lowered.exported_program())

et_model = _load_for_executorch_from_buffer(lowered.buffer)
et_outputs = et_model([*inputs])[0]
et_outputs - eager_outputs
```
I captured the following output pre-crash (with slightly different shapes):
```
loc("tensor<fp16, [5, 10, 4, 4, 4]> aten__native_batch_norm_legit_no_training_default_cast_fp16 = batch_norm(beta = tensor<fp16, [10]>(BLOBFILE(path = string(\22/Users/gjcomer/Library/Caches/executorchcoreml/models/executorch_c8e2de39-f155-452a-af5b-13f794eb4932_all.mlmodelc/weights/weight.bin\22), offset = uint64(448))), epsilon = fp16(1.00135803e-05), gamma = tensor<fp16, [10]>(BLOBFILE(path = string(\22/Users/gjcomer/Library/Caches/executorchcoreml/models/executorch_c8e2de39-f155-452a-af5b-13f794eb4932_all.mlmodelc/weights/weight.bin\22), offset = uint64(320))), mean = tensor<fp16, [10]>(BLOBFILE(path = string(\22/Users/gjcomer/Library/Caches/executorchcoreml/models/executorch_c8e2de39-f155-452a-af5b-13f794eb4932_all.mlmodelc/weights/weight.bin\22), offset = uint64(64))), variance = tensor<fp16, [10]>(BLOBFILE(path = string(\22/Users/gjcomer/Library/Caches/executorchcoreml/models/executorch_c8e2de39-f155-452a-af5b-13f794eb4932_all.mlmodelc/weights/weight.bin\22), offset = uint64(192))), x = x_to_fp16)[milId = uint64(1), name = string(\22aten__native_batch_norm_legit_no_training_default_cast_fp16\22)]; - /Users/gjcomer/Library/Caches/executorchcoreml/models/executorch_c8e2de39-f155-452a-af5b-13f794eb4932_all.mlmodelc/model.mil":12:12): error: output type 'tensor<5x10x4x4x4xf16>' and mean type 'tensor<1x0x1x1x519870560xf16>' are not broadcast compatible
LLVM ERROR: Failed to infer result type(s).
```
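As a possible interim workaround (a sketch, not a confirmed fix for this bug): batch norm normalizes per channel across all remaining dimensions, so `BatchNorm3d` over an `(N, C, D, H, W)` input is numerically equivalent to `BatchNorm2d` over the same tensor reshaped to `(N, C, D*H, W)`. The hypothetical `BatchNorm3dAs2d` wrapper below is my own illustration; routing through the 4D variant may avoid the 3D `batch_norm` lowering path that crashes, though I haven't verified that on Core ML itself.

```python
import torch


class BatchNorm3dAs2d(torch.nn.Module):
    # Hypothetical drop-in for BatchNorm3d: apply BatchNorm2d to the input
    # flattened to 4D, then restore the original 5D shape. Per-channel
    # normalization is unaffected by how the non-channel dims are reshaped.
    def __init__(self, num_features):
        super().__init__()
        self.norm = torch.nn.BatchNorm2d(num_features)

    def forward(self, x):
        n, c, d, h, w = x.shape
        return self.norm(x.reshape(n, c, d * h, w)).reshape(n, c, d, h, w)


# Sanity check against the real BatchNorm3d in eval mode
# (same state dict keys, so the parameters transfer directly).
ref = torch.nn.BatchNorm3d(3).eval()
alt = BatchNorm3dAs2d(3).eval()
alt.norm.load_state_dict(ref.state_dict())

x = torch.randn(1, 3, 4, 4, 4)
assert torch.allclose(ref(x), alt(x), atol=1e-6)
```

If the equivalence holds through export, the wrapper could be swapped in before `torch.export.export` while the Core ML `batch_norm` issue is open.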
Versions
- coremltools 8.3
- executorch commit 67b6009 (Jun 14)