[WIP] Add AWQ quantization with QDQLayout support for ExecuTorch #2399

Open · wants to merge 1 commit into main
115 changes: 115 additions & 0 deletions examples/awq_qdq_usage_example.py
@@ -0,0 +1,115 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
Simple usage example for AWQ with QDQLayout and ExecuTorch support.

This example demonstrates the complete workflow for using AWQ quantization
with QDQLayout support and 8-bit dynamic activation quantization.
"""

import torch
import torch.nn as nn

from torchao.prototype.awq import (
    insert_awq_observer_qdq_,
    AWQQDQConfig,
)
from torchao.prototype.awq.executorch_awq import _is_awq_observed_linear_qdq
from torchao.quantization import quantize_


def main():
    print("AWQ + QDQLayout + ExecuTorch Example")
    print("=" * 40)

    # 1. Create a simple model
    model = nn.Sequential(
        nn.Linear(512, 1024),
        nn.ReLU(),
        nn.Linear(1024, 256),
    )

    print(f"Original model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # 2. Insert AWQ observers with QDQLayout support
    print("\n1. Inserting AWQ observers...")
    insert_awq_observer_qdq_(
        model,
        n_validation_examples=5,
        validation_sequence_len=64,
        quant_dtype=torch.uint4,
        group_size=128,
        use_dynamic_activation_quant=True,  # Enable 8-bit dynamic activation quantization
    )

    print(" Observers inserted successfully!")

    # 3. Calibrate the model
    print("\n2. Calibrating model...")
    model.eval()
    with torch.no_grad():
        for i in range(5):
            # Generate random calibration data
            example_input = torch.randn(2, 64, 512)
            model(example_input)
            print(f" Calibration step {i + 1}/5 completed")

    print(" Calibration completed!")

    # 4. Apply AWQ quantization with QDQLayout
    print("\n3. Applying AWQ quantization with QDQLayout...")
    config = AWQQDQConfig(
        quant_dtype=torch.uint4,
        group_size=128,
        use_dynamic_activation_quant=True,
    )

    # Use the custom filter to target AWQObservedLinearQDQ modules
    quantize_(model, config, filter_fn=_is_awq_observed_linear_qdq)

    print(" Quantization applied successfully!")

    # 5. Test the quantized model
    print("\n4. Testing quantized model...")
    test_input = torch.randn(1, 64, 512)

    with torch.no_grad():
        output = model(test_input)
        print(f" Input shape: {test_input.shape}")
        print(f" Output shape: {output.shape}")

    # 6. Verify QDQLayout usage
    print("\n5. Verifying QDQLayout tensors...")
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            weight = module.weight
            if hasattr(weight, "__tensor_flatten__"):
                print(f" ✓ {name}: Uses quantized tensor (QDQLayout)")
                # Check for QDQLayout specific attributes
                if hasattr(weight, "int_data"):
                    print(f" - int_data shape: {weight.int_data.shape}")
                    print(f" - scale shape: {weight.scale.shape}")
            else:
                print(f" ✗ {name}: Uses regular tensor")

    print("\n" + "=" * 40)
    print("AWQ + QDQLayout quantization completed successfully!")
    print("The model is now ready for ExecuTorch export.")

    return model


if __name__ == "__main__":
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Run the example
    quantized_model = main()

    print(f"\nFinal model type: {type(quantized_model)}")
    print("Example completed successfully!")
245 changes: 245 additions & 0 deletions test/prototype/test_awq_executorch.py
@@ -0,0 +1,245 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest
import torch
import torch.nn.functional as F
from torchao.prototype.awq import (
    insert_awq_observer_qdq_,
    AWQQDQConfig,
)
from torchao.prototype.awq.executorch_awq import _is_awq_observed_linear_qdq
from torchao.quantization import quantize_
from torchao.dtypes.uintx.q_dq_layout import QDQLayout


class TestAWQExecutorchIntegration(unittest.TestCase):
    """Test suite for AWQ + QDQLayout + ExecuTorch integration."""

    def setUp(self):
        """Set up test fixtures."""
        torch.manual_seed(42)

        # Create a simple test model
        self.model = torch.nn.Sequential(
            torch.nn.Linear(64, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 32),
        )

        # Example input for testing
        self.example_input = torch.randn(2, 16, 64)
        self.batch_size, self.seq_len, self.hidden_size = self.example_input.shape

    def test_awq_observer_insertion(self):
        """Test insertion of AWQ observers with QDQLayout support."""
        model = torch.nn.Sequential(
            torch.nn.Linear(64, 128),
            torch.nn.Linear(128, 32),
        )

        # Insert AWQ observers
        insert_awq_observer_qdq_(
            model,
            n_validation_examples=2,
            validation_sequence_len=16,
            quant_dtype=torch.int4,
            group_size=64,
        )

        # Check that Linear layers were replaced with AWQObservedLinearQDQ
        from torchao.prototype.awq.executorch_awq import AWQObservedLinearQDQ

        for module in model.modules():
            if isinstance(module, torch.nn.Linear):
                # Should be replaced with AWQObservedLinearQDQ
                self.assertIsInstance(module, AWQObservedLinearQDQ)
                # Check observer configuration
                self.assertEqual(module.act_obs.n_validation_examples, 2)
                self.assertEqual(module.act_obs.validation_sequence_len, 16)

    def test_awq_calibration_and_quantization(self):
        """Test AWQ calibration and quantization with QDQLayout."""
        model = torch.nn.Sequential(torch.nn.Linear(64, 128))

        # Insert AWQ observer
        insert_awq_observer_qdq_(
            model,
            n_validation_examples=3,
            validation_sequence_len=16,
            quant_dtype=torch.int4,
            group_size=32,
        )

        # Calibrate the model
        model.eval()
        with torch.no_grad():
            for _ in range(3):
                example_input = torch.randn(2, 16, 64)
                model(example_input)

        # Apply quantization
        config = AWQQDQConfig(
            quant_dtype=torch.int4,
            group_size=32,
        )
        quantize_(model, config, filter_fn=_is_awq_observed_linear_qdq)

        # Verify the model is quantized (model is modified in-place)
        self.assertIsInstance(model, torch.nn.Sequential)
        self.assertIsInstance(model[0], torch.nn.Linear)

        # Check that weight uses QDQLayout
        weight_tensor = model[0].weight
        self.assertTrue(hasattr(weight_tensor, "__tensor_flatten__"))  # AQT tensor

        # Test forward pass
        with torch.no_grad():
            output = model(self.example_input)
        self.assertEqual(output.shape, (2, 16, 128))

    def test_multiple_quantization_dtypes(self):
        """Test AWQ with different quantization dtypes."""
        for quant_dtype in [torch.uint1, torch.uint2, torch.int4]:
            with self.subTest(quant_dtype=quant_dtype):
                model = torch.nn.Sequential(torch.nn.Linear(32, 64))

                # Insert observer
                insert_awq_observer_qdq_(
                    model,
                    n_validation_examples=2,
                    validation_sequence_len=4,
                    quant_dtype=quant_dtype,
                    group_size=16,
                )

                # Calibrate
                model.eval()
                with torch.no_grad():
                    for _ in range(2):
                        model(torch.randn(1, 4, 32))

                # Quantize
                config = AWQQDQConfig(quant_dtype=quant_dtype, group_size=16)
                quantize_(model, config, filter_fn=_is_awq_observed_linear_qdq)

                # Test forward pass
                with torch.no_grad():
                    output = model(torch.randn(1, 4, 32))
                self.assertEqual(output.shape, (1, 4, 64))

    def test_different_group_sizes(self):
        """Test AWQ with different group sizes."""
        for group_size in [16, 32, 64, 128]:
            with self.subTest(group_size=group_size):
                model = torch.nn.Sequential(torch.nn.Linear(128, 64))

                # Insert observer
                insert_awq_observer_qdq_(
                    model,
                    n_validation_examples=2,
                    validation_sequence_len=4,
                    quant_dtype=torch.int4,
                    group_size=group_size,
                )

                # Calibrate
                model.eval()
                with torch.no_grad():
                    for _ in range(2):
                        model(torch.randn(1, 4, 128))

                # Quantize
                config = AWQQDQConfig(quant_dtype=torch.int4, group_size=group_size)
                quantize_(model, config, filter_fn=_is_awq_observed_linear_qdq)

                # Test forward pass
                with torch.no_grad():
                    output = model(torch.randn(1, 4, 128))
                self.assertEqual(output.shape, (1, 4, 64))

    def test_graph_pattern_for_executorch(self):
        """Test that the graph pattern matches ExecuTorch expectations for XNNPACK lowering."""
        model = torch.nn.Sequential(torch.nn.Linear(128, 64))

        # Insert AWQ observers with dynamic activation quantization
        insert_awq_observer_qdq_(
            model,
            n_validation_examples=2,
            validation_sequence_len=8,
            quant_dtype=torch.int4,
            group_size=32,
        )

        # Calibrate
        model.eval()
        with torch.no_grad():
            for _ in range(2):
                model(torch.randn(1, 8, 128))

        # Quantize
        config = AWQQDQConfig(
            quant_dtype=torch.int4,
            group_size=32,
        )
        quantize_(model, config, filter_fn=_is_awq_observed_linear_qdq)

        # The forward method should apply the expected AWQ + dynamic activation
        # quantization pattern
        example_input = torch.randn(1, 8, 128)

        # The forward pass should run without error
        with torch.no_grad():
            actual_output = model(example_input)

        # Verify the output shape is correct
        self.assertEqual(actual_output.shape, (1, 8, 64))

        # Check the graph pattern via torch.export (the proper route to ExecuTorch);
        # export with strict=True for ExecuTorch compatibility
        exported_program = torch.export.export(model, (example_input,), strict=True)

        # The exported model should produce the same results as the eager model
        exported_results = exported_program.module()(example_input)
        self.assertTrue(
            torch.allclose(actual_output, exported_results, atol=1e-3),
            "Exported model should produce the same results as the original",
        )

        # Use FileCheck to verify the graph contains the operations required for
        # AWQ + dynamic activation quantization
        from torch.testing import FileCheck

        # Expected operations in the exported graph. This is the pattern that
        # ExecuTorch recognizes and lowers to XNNPACK:
        #   1. AWQ scaling (division operation)
        #   2. Dynamic activation quantization (choose_qparams, quantize, dequantize)
        #   3. Weight quantization/dequantization (from QDQLayout)
        #   4. Linear operation on dequantized tensors
        expected_operations = [
            # AWQ scaling - divide the input by the AWQ scale
            "torch.ops.aten.div.Tensor",
            # Dynamic activation quantization - choose quantization parameters
            "torch.ops.torchao.choose_qparams_affine.default",
            # Dynamic activation quantization - quantize the activation
            "torch.ops.torchao.quantize_affine.default",
            # Dequantization - dequantize the activation (and weight) for the linear op
            "torch.ops.torchao.dequantize_affine.default",
            # Linear operation on dequantized tensors
            "torch.ops.aten.linear.default",
        ]

        # Verify each required operation appears in the exported graph; two
        # dequantize ops are expected (one for the activation, one for the weight)
        for operation in expected_operations:
            count = 2 if operation == "torch.ops.torchao.dequantize_affine.default" else 1
            FileCheck().check_count(operation, count, exactly=True).run(
                exported_program.graph_module.code
            )


if __name__ == "__main__":
    unittest.main()
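As a reading aid for the FileCheck assertions above, here is an illustrative, pure-PyTorch sketch of the pattern being matched: per-group weight quantize/dequantize followed by a plain linear. The helper name and rounding details are simplified stand-ins, not torchao's actual quantize_affine/dequantize_affine implementation.

# Illustrative sketch only: the quantize -> dequantize -> linear (QDQ) shape
# that XNNPACK pattern-matches. qdq_linear is a hypothetical helper.
import torch
import torch.nn.functional as F

def qdq_linear(x, weight, bias=None, group_size=32):
    out_features, in_features = weight.shape
    w = weight.reshape(out_features, in_features // group_size, group_size)
    # Per-group symmetric scale: the largest magnitude in each group maps to 7 (int4 max)
    scale = (w.abs().amax(dim=-1, keepdim=True) / 7.0).clamp(min=1e-8)
    q = torch.clamp(torch.round(w / scale), -8, 7)          # "quantize_affine" step
    w_dq = (q * scale).reshape(out_features, in_features)   # "dequantize_affine" step
    return F.linear(x, w_dq, bias)                          # aten.linear on the dequantized weight

# e.g. qdq_linear(torch.randn(2, 128), torch.randn(64, 128)) has shape (2, 64)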
13 changes: 13 additions & 0 deletions torchao/prototype/awq/__init__.py
@@ -1,8 +1,21 @@
from .api import awq_uintx, insert_awq_observer_
from .core import AWQObservedLinear
from .executorch_awq import (
    insert_awq_observer_qdq_,
    AWQQDQConfig,
    AWQObserverQDQ,
    AWQObservedLinearQDQ,
    _is_awq_observed_linear_qdq,
)

__all__ = [
    "awq_uintx",
    "insert_awq_observer_",
    "AWQObservedLinear",
    # ExecuTorch AWQ support
    "insert_awq_observer_qdq_",
    "AWQQDQConfig",
    "AWQObserverQDQ",
    "AWQObservedLinearQDQ",
    "_is_awq_observed_linear_qdq",
]
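With these re-exports in place, downstream code can pull the whole new surface from the prototype package root; a minimal import check:

# Everything re-exported by this PR is importable from torchao.prototype.awq.
from torchao.prototype.awq import (
    insert_awq_observer_qdq_,
    AWQQDQConfig,
    AWQObserverQDQ,
    AWQObservedLinearQDQ,
    _is_awq_observed_linear_qdq,
)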