23 changes: 23 additions & 0 deletions test/quantization/test_qat.py
@@ -81,6 +81,17 @@ def forward(self, x):
        x = self.linear2(x)
        return x

class M2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(10, 512)

    def example_inputs(self):
        return (torch.randint(1, 10, (1, 512)),)

    def forward(self, x):
        return self.embedding(x)


class TestQAT(unittest.TestCase):
SEED = 123
@@ -669,5 +680,17 @@ def test_composable_qat_quantizer(self):
        values_list = getattr(model.linear1, self._MyQATQuantizer.ATTR_NAME)
        self.assertEqual(values_list, ["quantizer1", "quantizer2", "quantizer1", "quantizer2"])

    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.3 or lower")
    def test_qat_4w_embedding(self):
        from torchao.quantization.prototype.qat import Int4WeightOnlyEmbeddingQATQuantizer
        model = M2()
        x = model.example_inputs()
        out = model(*x)
        # Smoke test: prepare and convert should run end to end
        quantizer = Int4WeightOnlyEmbeddingQATQuantizer()
        prepared = quantizer.prepare(model)
        prepared_out = prepared(*x)
        converted = quantizer.convert(prepared)
        converted_out = converted(*x)

if __name__ == "__main__":
    unittest.main()
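Note on the new test: it is a smoke test that only checks prepare and convert run end to end. In a real QAT run, fine-tuning happens between the two steps; a minimal sketch of that flow using the API exercised above (the loop, optimizer, and loss below are illustrative assumptions, not part of this PR):

import torch
from torchao.quantization.prototype.qat import Int4WeightOnlyEmbeddingQATQuantizer

model = M2()
quantizer = Int4WeightOnlyEmbeddingQATQuantizer()
model = quantizer.prepare(model)  # swaps nn.Embedding for a fake-quantized version

# Illustrative fine-tuning loop: the model trains against int4 fake quantization
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for _ in range(10):
    x = torch.randint(1, 10, (1, 512))  # same shape as M2.example_inputs()
    loss = model(x).sum()               # stand-in loss for the sketch
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

model = quantizer.convert(model)  # replaces fake quant with real int4 weights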
4 changes: 4 additions & 0 deletions torchao/quantization/prototype/qat/__init__.py
@@ -13,6 +13,9 @@
from ._module_swap_api import (
    Int8DynActInt4WeightQATLinear,
)
from .embedding import (
    Int4WeightOnlyEmbeddingQATQuantizer,
)

__all__ = [
    "disable_4w_fake_quant",
@@ -23,6 +26,7 @@
    "int8_dynamic_activation_int4_weight_fake_quantize",
    "ComposableQATQuantizer",
    "Int4WeightOnlyQATQuantizer",
    "Int4WeightOnlyEmbeddingQATQuantizer",
    "Int8DynActInt4WeightQATQuantizer",
    "Int8DynActInt4WeightQATLinear",
]
36 changes: 13 additions & 23 deletions torchao/quantization/prototype/qat/_module_swap_api.py
@@ -28,23 +28,23 @@
    _choose_qparams_per_token_asymmetric,
    _fake_quantize_per_channel_group,
    _fake_quantize_per_token,
    _get_qmin_qmax,
)


# TODO: deprecate this flow in favor of the tensor subclass flow under qat/api.py
# This is currently needed for DDP and FSDP1, which are not compatible with the
# subclass flow.
# TODO: make module swap the main flow again, and remove the quantize_ flow
# TODO: rename this file to linear.py

# =========================================================
# | Linear int8 dynamic activations + int4 weight QAT |
# =========================================================


class Int8DynActInt4WeightQATQuantizerModuleSwap(Int8DynActInt4WeightQATQuantizer):
    """
    Quantizer for performing QAT on a model, where linear layers have int8
    dynamic per token fake quantized activations and int4 fake quantized
    grouped per channel weights.

    Note: This quantizer is implemented using module swaps and may be
    deprecated in the future. Please use `Int8DynActInt4WeightQATQuantizer`
    instead if possible.
    """

    def prepare(
@@ -92,7 +92,7 @@ def _convert_qat_linear_8da4w(module: torch.nn.Module):

        # Load weights and qparams into quantized linear
        n_bit = 4
        (qmin, qmax) = _get_qmin_qmax(n_bit)
        (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize)
        from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper
        q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
@@ -156,7 +156,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            (act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
                x, self.scales_precision, self.zero_points_precision,
            )
            (act_qmin, act_qmax) = _get_qmin_qmax(8)
            x_fq = _fake_quantize_per_token(
                x, act_scales, act_zp, act_qmin, act_qmax,
            )
@@ -170,7 +170,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            )
            # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
            weight_zp = weight_zp.to(self.zero_points_precision)
            (weight_qmin, weight_qmax) = _get_qmin_qmax(4)
            w_fq = _fake_quantize_per_channel_group(
                self.weight,
                weight_scales,
@@ -183,12 +183,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            w_fq = self.weight
        return F.linear(x_fq, w_fq)
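For context: the _fake_quantize_per_token and _fake_quantize_per_channel_group calls above perform a quantize-dequantize round trip in floating point, so training observes the rounding and clamping error that real int8/int4 quantization will introduce, while gradients flow through via the straight-through estimator. A generic sketch of that round trip (illustrative only, not the torchao implementation):

import torch

def fake_quantize_sketch(x, scale, zero_point, qmin, qmax):
    # quantize: scale, shift, round, and clamp to the integer range
    q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)
    # dequantize immediately, so the output stays in floating point
    return (q - zero_point) * scale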

    # TODO: move this to common util
    def _get_qmin_qmax(self, n_bit: int):
        qmin = -(2 ** (n_bit - 1))
        qmax = 2 ** (n_bit - 1) - 1
        return (qmin, qmax)
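The method deleted above is now shared: given the new _get_qmin_qmax import at the top of this file and the updated call sites, the relocated helper in utils.py presumably looks like the following (inferred from this diff, not shown in it):

def _get_qmin_qmax(n_bit: int):
    # signed n-bit range, e.g. (-8, 7) for n_bit=4 and (-128, 127) for n_bit=8
    qmin = -(2 ** (n_bit - 1))
    qmax = 2 ** (n_bit - 1) - 1
    return (qmin, qmax)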


def enable_8da4w_fake_quant_module_swap(mod: torch.nn.Module):
"""
Expand All @@ -206,19 +200,15 @@ def disable_8da4w_fake_quant_module_swap(mod: torch.nn.Module):
mod.disable_fake_quant()


# ===================================
# |   Linear int4 weight-only QAT   |
# ===================================


class Int4WeightOnlyQATQuantizerModuleSwap(Int4WeightOnlyQATQuantizer):
    """
    Quantizer for performing QAT on a model, where linear layers have
    int4 fake quantized grouped per channel weights.

    Note: This quantizer is implemented using module swaps and may be
    deprecated in the future. Please use `Int4WeightOnlyQATQuantizer`
    instead if possible.
    """

    def prepare(