|
15 | 15 | UNSUPPORTED_PORTABLE_OPS = {
|
16 | 16 | "aten::_embedding_bag",
|
17 | 17 | "aten::_adaptive_avg_pool2d",
|
| 18 | + "aten::adaptive_max_pool2d", |
18 | 19 | "aten::median",
|
19 | 20 | "aten::median.dim",
|
20 | 21 | "aten::round.decimals",
|
|
34 | 35 | TestResult,
|
35 | 36 | )
|
36 | 37 | from executorch.exir import EdgeProgramManager
|
| 38 | +from executorch.exir.dialects._ops import ops as exir_ops |
37 | 39 |
|
38 | 40 |
|
39 | 41 | # A list of all runnable test suites and the corresponding python package.
|
|
43 | 45 | }
|
44 | 46 |
|
45 | 47 |
|
| 48 | +def _graph_has_unsupported_patterns(program: torch.export.ExportedProgram) -> bool: |
| 49 | + # Returns true if the model contains patterns that will fail when running on the ET |
| 50 | + # portable kernel library. |
| 51 | + |
| 52 | + # Check for 3d convolutions. All convs (1d, 2d, 3d) use the same op, so we need to look at |
| 53 | + # the input meta to determine the rank. |
| 54 | + for node in program.graph.nodes: |
| 55 | + if ( |
| 56 | + node.op == "call_function" |
| 57 | + and node.target == exir_ops.edge.aten.convolution.default |
| 58 | + ): |
| 59 | + in_rank = node.args[0].meta["val"].dim() |
| 60 | + if in_rank != 4: |
| 61 | + return True |
| 62 | + |
| 63 | + return False |
| 64 | + |
| 65 | + |
46 | 66 | def _get_test_seed(test_base_name: str) -> int:
|
47 | 67 | # Set the seed based on the test base name to give consistent inputs between backends. Add the
|
48 | 68 | # run seed to allow for reproducible results, but still allow for run-to-run variation.
|
@@ -162,7 +182,7 @@ def build_result(
|
162 | 182 | # Check if any undelegated ops are in the unsupported ops set.
|
163 | 183 | has_unsupported_ops = any(
|
164 | 184 | op in UNSUPPORTED_PORTABLE_OPS for op in undelegated_op_counts.keys()
|
165 |
| - ) |
| 185 | + ) or _graph_has_unsupported_patterns(edge_manager._etrecord.edge_dialect_program) |
166 | 186 |
|
167 | 187 | # Skip the test if there are unsupported portable ops remaining.
|
168 | 188 | if has_unsupported_ops:
|
|
0 commit comments