### 🐛 Describe the bug
Now that export defaults to non-strict mode, unused parameters are left in the graph. This means unquantized weights get serialized alongside their quantized copies, bloating PTE size by 5x or more. We should strip out unused parameters somewhere in to_edge or to_executorch.
As a repro (requiring the latest PyTorch):
```python
import torch
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
    DuplicateDynamicQuantChainPass,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.export import export, export_for_training, Dim
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(16, 1024)
        self.relu1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(1024, 16)
        self.relu2 = torch.nn.ReLU()

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu1(x)
        x = self.linear2(x)
        x = self.relu2(x)
        return x


model = SimpleModel()
inputs = (torch.randn(1, 16),)

# Capture the pre-autograd graph for PT2E quantization.
pre_autograd_aten_dialect = torch.export.export_for_training(
    model,
    inputs,
).module()

quantizer = XNNPACKQuantizer()
# qparams = get_symmetric_quantization_config(is_dynamic=True, is_per_channel=True)
qparams = get_symmetric_quantization_config(is_per_channel=False)
quantizer.set_global(qparams)

prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer)
prepared_graph.to("cpu")
converted_graph = convert_pt2e(prepared_graph)
DuplicateDynamicQuantChainPass()(converted_graph)

# Re-export the quantized module; non-strict export keeps the now-unused f32 parameters.
ep = export(converted_graph, inputs, strict=False)

lowered = to_edge_transform_and_lower(
    ep,
    partitioner=[XnnpackPartitioner()],
)
```
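To see the size impact, here is a minimal sketch (assuming the default `to_executorch()` configuration; exact byte counts will vary) that serializes the lowered program and prints the PTE size:

```python
# Sketch: serialize with default settings and check the flatbuffer size.
# The two 16x1024 linears are ~32 KB as u8 weights, but the retained f32
# copies add roughly another 128 KB, so the PTE comes out ~5x larger than expected.
exec_prog = lowered.to_executorch()
print(f"PTE size: {len(exec_prog.buffer)} bytes")

# Optionally write it out for inspection.
with open("simple_model.pte", "wb") as f:
    f.write(exec_prog.buffer)
```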
When printing the lowered program, note the extra unused f32 weights. You can also observe that the PTE size is much larger than expected. Specifically, p_linear1_weight and p_linear2_weight are the original (unquantized) f32 weights and are unused; a u8 copy of the weights is consumed by the delegate as expected.
```python
print(lowered.exported_program())
```

```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_linear1_weight: "f32[1024, 16]", p_linear2_weight: "f32[16, 1024]", x: "f32[1, 16]"):
            # No stacktrace found for following nodes
            lowered_module_0 = self.lowered_module_0
            executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, x);  lowered_module_0 = x = None
            getitem: "f32[1, 16]" = executorch_call_delegate[0];  executorch_call_delegate = None
            return (getitem,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear1_weight'), target='linear1.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear2_weight'), target='linear2.weight', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}
```
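Until this is handled in to_edge or to_executorch, a possible interim workaround is to prune the dead parameter placeholders before serialization. The sketch below is an illustration against torch.export internals (mutating `graph_signature` and `state_dict` in place), not a vetted ExecuTorch API, and the helper name `strip_unused_parameters` is made up for this example:

```python
from torch.export.graph_signature import InputKind


def strip_unused_parameters(ep):
    """Sketch: drop parameter placeholders that have no users from an ExportedProgram.

    Assumes ep.state_dict and ep.graph_signature expose the live objects and
    that no later pass re-introduces uses of the removed parameters.
    """
    gm = ep.graph_module
    sig = ep.graph_signature

    # Parameter input names, as recorded in the graph signature.
    param_names = {
        spec.arg.name
        for spec in sig.input_specs
        if spec.kind == InputKind.PARAMETER
    }

    # Parameter placeholders that nothing in the graph consumes.
    dead = [
        node
        for node in gm.graph.nodes
        if node.op == "placeholder"
        and node.name in param_names
        and len(node.users) == 0
    ]
    dead_names = {node.name for node in dead}

    # Drop their input specs and backing tensors, then erase the nodes.
    kept_specs = []
    for spec in sig.input_specs:
        if spec.kind == InputKind.PARAMETER and spec.arg.name in dead_names:
            ep.state_dict.pop(spec.target, None)
        else:
            kept_specs.append(spec)
    sig.input_specs = kept_specs

    for node in dead:
        gm.graph.erase_node(node)
    gm.recompile()
    return ep


strip_unused_parameters(lowered.exported_program())
```

Re-running `to_executorch()` after this should drop the f32 blobs from the PTE, assuming nothing else still references them.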
### Versions

All