20 changes: 15 additions & 5 deletions backends/arm/quantizer/arm_quantizer.py
@@ -105,15 +105,27 @@ def get_symmetric_quantization_config(
     # Determine the right observer/fake-quant constructor
     if is_qat:
         if is_per_channel:
-            weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
+            weight_observer_or_fake_quant_ctr = FakeQuantize.with_args(
+                observer=PerChannelMinMaxObserver,
+                quant_min=weight_qmin,
+                quant_max=weight_qmax,
+                dtype=torch.qint8,
+                qscheme=torch.per_channel_symmetric,
+                reduce_range=False,
+                ch_axis=0,
+                **extra_args,
+            )
         else:
             # Set plain fake-quant with true min/max
-            weight_observer_or_fake_quant_ctr = FakeQuantize
+            weight_observer_or_fake_quant_ctr = FakeQuantize.with_args(**extra_args)
     else:
         # PTQ: set min/max observer
         weight_observer_or_fake_quant_ctr = (
             PerChannelMinMaxObserver if is_per_channel else MinMaxObserver
         )
+        weight_observer_or_fake_quant_ctr = weight_observer_or_fake_quant_ctr.with_args(
+            **extra_args,
+        )

     weight_quantization_spec = QuantizationSpec(
         dtype=torch.int8,
@@ -122,9 +134,7 @@
         qscheme=weight_qscheme,
         ch_axis=0,
         is_dynamic=False,
-        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
-            **extra_args
-        ),
+        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr,
     )

     bias_quantization_spec = None
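For readers unfamiliar with the `with_args` pattern, here is a minimal sketch (not part of this PR) of what the new per-channel QAT branch builds: `FakeQuantize.with_args(...)` returns a partial constructor, so the fake-quant module and its per-channel observer are only instantiated when the PT2E prepare step attaches them to a weight. The import path and the qmin/qmax values below are assumptions for illustration; `arm_quantizer.py` may import the same classes from a different module.

```python
import torch
from torch.ao.quantization import FakeQuantize, PerChannelMinMaxObserver

# Mirrors the per-channel QAT branch above; the qmin/qmax values are placeholders.
weight_qmin, weight_qmax = -127, 127
ctr = FakeQuantize.with_args(
    observer=PerChannelMinMaxObserver,
    quant_min=weight_qmin,
    quant_max=weight_qmax,
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric,
    reduce_range=False,
    ch_axis=0,
)

# Nothing is constructed yet; the prepare step calls the partial when it needs it.
fake_quant = ctr()
print(type(fake_quant).__name__)                          # FakeQuantize
print(type(fake_quant.activation_post_process).__name__)  # PerChannelMinMaxObserver
```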
100 changes: 100 additions & 0 deletions backends/arm/test/misc/test_qat_training_loop.py
@@ -0,0 +1,100 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
from executorch.backends.arm.quantizer import (
get_symmetric_quantization_config,
TOSAQuantizer,
)

from executorch.backends.arm.tosa.specification import TosaSpecification
from torch.export import export
from torchao.quantization.pt2e import (
move_exported_model_to_eval,
move_exported_model_to_train,
)
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e

logger = logging.getLogger(__name__)


class MLP(torch.nn.Module):
def __init__(self):
super().__init__()
self.sequential = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1),
)

def forward(self, x):
return self.sequential(x)


def evaluate_model(model, inputs, expected_outputs):
with torch.no_grad():
test_outputs = model(inputs)
loss = torch.nn.functional.mse_loss(test_outputs, expected_outputs)
logger.info(f"Mean squared error: {loss.item()}")


def test_qat_training_loop():
"""Test the QAT training loop with a simple MLP model.
This function creates a simple MLP model, prepares it for QAT, runs a training loop,
and evaluates the quantized model to make sure everything works as expected."""

model = MLP()
logger.info("Starting training loop test")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(100):
model.train()
optimizer.zero_grad()
inputs = torch.randn(100, 1).clamp(-1, 1)
outputs = model(inputs)
loss = torch.nn.functional.mse_loss(outputs, torch.sin(inputs))
loss.backward()
optimizer.step()
if epoch % 5 == 0:
logger.info(f"Epoch {epoch}, Loss: {loss.item()}")
logger.info("Training loop test completed successfully")

logger.info("Evaluating model before QAT")
test_inputs = torch.randn(20, 1).clamp(-1, 1)
test_outputs = torch.sin(test_inputs)
evaluate_model(model, test_inputs, test_outputs)

exported_model = export(model, (torch.randn(1, 1),), strict=True)
quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))

prepared_model = prepare_qat_pt2e(exported_model.module(), quantizer)
prepared_model = move_exported_model_to_train(prepared_model)
logger.info("QAT model prepared successfully")

logger.info("Starting QAT training loop")

for epoch in range(25):
inputs = torch.randn(100, 1).clamp(-1, 1)
optimizer.zero_grad()
outputs = prepared_model(inputs)
loss = torch.nn.functional.mse_loss(outputs, torch.sin(inputs))
loss.backward()
optimizer.step()
if epoch % 5 == 0:
logger.info(f"QAT Epoch {epoch}, Loss: {loss.item()}")
logger.info("QAT training loop completed successfully")
prepared_model = move_exported_model_to_eval(prepared_model)

quantized_model = convert_pt2e(prepared_model)
logger.info("QAT model quantized successfully")

logger.info("Evaluating quantized model")
test_inputs = torch.randn(100, 1).clamp(-1, 1)
test_outputs = torch.sin(test_inputs)
evaluate_model(quantized_model, test_inputs, test_outputs)
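For completeness, a hedged sketch (again not part of this PR) of the same preparation steps requesting the new per-channel QAT weight config explicitly. It assumes `is_per_channel` is a keyword argument of `get_symmetric_quantization_config`, as the diff above suggests, and reuses the test's import paths; the stand-in model is arbitrary.

```python
import torch
from executorch.backends.arm.quantizer import (
    get_symmetric_quantization_config,
    TOSAQuantizer,
)
from executorch.backends.arm.tosa.specification import TosaSpecification
from torch.export import export
from torchao.quantization.pt2e.quantize_pt2e import prepare_qat_pt2e

# A small stand-in model; any exportable nn.Module works the same way.
model = torch.nn.Sequential(
    torch.nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1)
)
exported = export(model, (torch.randn(1, 1),), strict=True)

# Request symmetric QAT quantization with per-channel weight fake-quants.
quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
quantizer.set_global(
    get_symmetric_quantization_config(is_qat=True, is_per_channel=True)
)
prepared = prepare_qat_pt2e(exported.module(), quantizer)
# Train `prepared` as in the QAT loop above, then call convert_pt2e(prepared).
```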