Skip to content

Commit 0439fde

Browse files
authored
[Backend Tester] Skip tests with undelegated conv3d ops (#14185)
We are missing a portable kernel implementation for Conv3d, so any tests that contain one will fail. Since conv1d, 2d, and 3d all share the same op, we need some logic to handle this constraint and skip the test. This fixes a small number of (incorrect) test failures on the backend suite when decomposed conv3d patterns are partially delegated.
1 parent e7fecdf commit 0439fde

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

backends/test/suite/runner.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
UNSUPPORTED_PORTABLE_OPS = {
1616
"aten::_embedding_bag",
1717
"aten::_adaptive_avg_pool2d",
18+
"aten::adaptive_max_pool2d",
1819
"aten::median",
1920
"aten::median.dim",
2021
"aten::round.decimals",
@@ -34,6 +35,7 @@
3435
TestResult,
3536
)
3637
from executorch.exir import EdgeProgramManager
38+
from executorch.exir.dialects._ops import ops as exir_ops
3739

3840

3941
# A list of all runnable test suites and the corresponding python package.
@@ -43,6 +45,24 @@
4345
}
4446

4547

48+
def _graph_has_unsupported_patterns(program: torch.export.ExportedProgram) -> bool:
49+
# Returns true if the model contains patterns that will fail when running on the ET
50+
# portable kernel library.
51+
52+
# Check for 3d convolutions. All convs (1d, 2d, 3d) use the same op, so we need to look at
53+
# the input meta to determine the rank.
54+
for node in program.graph.nodes:
55+
if (
56+
node.op == "call_function"
57+
and node.target == exir_ops.edge.aten.convolution.default
58+
):
59+
in_rank = node.args[0].meta["val"].dim()
60+
if in_rank != 4:
61+
return True
62+
63+
return False
64+
65+
4666
def _get_test_seed(test_base_name: str) -> int:
4767
# Set the seed based on the test base name to give consistent inputs between backends. Add the
4868
# run seed to allow for reproducible results, but still allow for run-to-run variation.
@@ -162,7 +182,7 @@ def build_result(
162182
# Check if any undelegated ops are in the unsupported ops set.
163183
has_unsupported_ops = any(
164184
op in UNSUPPORTED_PORTABLE_OPS for op in undelegated_op_counts.keys()
165-
)
185+
) or _graph_has_unsupported_patterns(edge_manager._etrecord.edge_dialect_program)
166186

167187
# Skip the test if there are unsupported portable ops remaining.
168188
if has_unsupported_ops:

0 commit comments

Comments
 (0)