Skip to content
60 changes: 26 additions & 34 deletions examples/arm/aot_arm_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,6 @@
from executorch.backends.arm.vgf import VgfCompileSpec

# To use Cortex-M backend
from executorch.backends.cortex_m.passes.convert_to_cortex_m_pass import (
ConvertToCortexMPass,
)

from executorch.backends.cortex_m.passes.quantized_op_fusion_pass import (
QuantizedOpFusionPass,
)

from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import (
ReplaceQuantNodesPass,
)
Expand Down Expand Up @@ -207,6 +199,14 @@ def _load_serialized_model(
return model, example_inputs


def _apply_replace_quant_nodes(edge, args):
"""Apply the replace_quant_nodes pass to the edge graph module."""

if args.target != "vgf" and not args.direct_drive:
edge = edge.transform([ReplaceQuantNodesPass()])
return edge


def get_model_and_inputs_from_name(
model_name: str, model_input: str | None
) -> Tuple[torch.nn.Module, Any]:
Expand Down Expand Up @@ -606,7 +606,7 @@ def get_args():
parser.add_argument(
"--enable_qdq_fusion_pass",
action="store_true",
help="Enable the Quantized qdq fusion Op passes",
help="[DEPRECATED] This flag is no longer used and will be removed in a future release.",
)
parser.add_argument(
"--enable_debug_mode",
Expand Down Expand Up @@ -787,6 +787,11 @@ def to_edge_TOSA_delegate(
),
)

# Replace quantized_decomposed::{quantize,dequantize}_per_tensor nodes
# with cortex_m:: equivalents for int8 QDQ ops remaining outside the
# delegated subgraph.
edge = _apply_replace_quant_nodes(edge, args)

return model_quant, edge


Expand Down Expand Up @@ -822,27 +827,12 @@ def to_edge_no_delegate(
),
)

return model_quant, edge


def transform_for_cortex_m_backend(edge_program_manager, args):
# Let's make sure we are using optimized Cortex M backend
# NB: If we can't find and replace ops those are expected to be replaced,
# bad things will happen at runtime, like "missing operator" errors!
# Replace quantized_decomposed::{quantize,dequantize}_per_tensor nodes
# with cortex_m:: equivalents for int8 QDQ ops remaining outside the
# delegated subgraph.
edge = _apply_replace_quant_nodes(edge, args)

# Instantiate the mandatory ReplaceQuantNodesPass
passes = [ReplaceQuantNodesPass]
if args.enable_qdq_fusion_pass:
passes += [ConvertToCortexMPass, QuantizedOpFusionPass]
current_edge = edge_program_manager
for pass_cls in passes:
transform_pass = (
pass_cls(current_edge.exported_program())
if pass_cls.__name__ == "QuantizedLinearFusionPass"
else pass_cls()
)
current_edge = current_edge.transform([transform_pass])
return current_edge
return model_quant, edge


if __name__ == "__main__": # noqa: C901
Expand All @@ -863,6 +853,13 @@ def transform_for_cortex_m_backend(edge_program_manager, args):
model = exported_program.module()
model_fp32 = model

if args.enable_qdq_fusion_pass:
logging.warning(
"--enable_qdq_fusion_pass is deprecated and has no effect. "
"Quantized node replacement is now handled within the "
"respective compilation paths."
)

model_name = os.path.basename(os.path.splitext(args.model_name)[0])
if args.intermediates:
os.makedirs(args.intermediates, exist_ok=True)
Expand All @@ -885,11 +882,6 @@ def transform_for_cortex_m_backend(edge_program_manager, args):
exported_program, args, model, example_inputs
)

# Cortex-m ops are never included in vgf or direct-drive
if args.target != "vgf" and not args.direct_drive:
# Transform so we can use ops from the Cortex M backend
edge = transform_for_cortex_m_backend(edge, args)

dump_delegation_info(edge, args.intermediates)

edge_program_manager_copy = copy.deepcopy(edge)
Expand Down
Loading