
Commit 9a56a1d

Fix aqt implementation for aten.mm/aten.addmm fallback path (#2072)
Summary: Previously, the try/except block in the aten.mm/aten.addmm overrides had a side effect: weight_tensor was rebound to its transpose even when the dispatch failed, which could cause errors in the fallback path. This PR fixes it by binding the transpose to a new name so the fallback sees the original tensor.

Test Plan: python test/dtypes/test_affine_quantized.py -k test_matmul
1 parent a96eeb1 commit 9a56a1d
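The bug class fixed here is worth spelling out. Below is a minimal, self-contained sketch of the anti-pattern and the fix, using hypothetical dispatch/fallback callables rather than torchao's actual API: rebinding a name inside try survives a raised exception, so the except fallback sees the mutated value.

import torch

# Hypothetical helpers, for illustration only.
def dispatch(x, w):
    raise NotImplementedError("no quantized kernel matched")

def fallback(x, w):
    return torch.nn.functional.linear(x, w)

def linear_buggy(x, w):
    try:
        w = w.t()                     # rebinding survives the raise below
        return dispatch(x, w)
    except NotImplementedError:
        return fallback(x, w)         # BUG: w is already transposed here

def linear_fixed(x, w):
    try:
        transposed_w = w.t()          # new name: w itself is untouched
        return dispatch(x, transposed_w)
    except NotImplementedError:
        return fallback(x, w)         # w still has its original orientation

x, w = torch.randn(4, 8), torch.randn(16, 8)
print(linear_fixed(x, w).shape)       # torch.Size([4, 16])
# linear_buggy(x, w) raises a RuntimeError: the fallback multiplies
# (4, 8) by the transpose of (8, 16), a shape mismatch.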

File tree

2 files changed (+39, -4 lines)

test/dtypes/test_affine_quantized.py

Lines changed: 18 additions & 0 deletions
@@ -21,6 +21,7 @@
     Int4XPULayout,
     PlainLayout,
     SemiSparseLayout,
+    to_affine_quantized_intx,
     to_affine_quantized_intx_static,
 )
 from torchao.quantization import (

@@ -352,6 +353,23 @@ def test_slice(self, device, dtype):
         _ = dummy.weight.narrow(0, 0, 64)
         _ = dummy.weight.narrow(1, 0, 128)
 
+    @common_utils.parametrize("device", ["cuda"])
+    @common_utils.parametrize("dtype", [torch.bfloat16])
+    def test_matmul(self, device, dtype):
+        x = torch.randn(53, 2048)
+        w = torch.randn(53, 2048)
+        w = to_affine_quantized_intx(
+            w,
+            mapping_type=MappingType.SYMMETRIC,
+            block_size=(1, 32),
+            target_dtype=torch.int8,
+            quant_min=-8,
+            quant_max=7,
+            eps=torch.finfo(torch.float32).eps,
+        )
+        # make sure it runs
+        torch.matmul(x, w.t())
+
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantized)
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic)
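For context on the shapes in the new test: torch.matmul(x, w.t()) multiplies (53, 2048) by (2048, 53), so it exercises the aten.mm path with a quantized second operand and produces a (53, 53) result. A quick plain-PyTorch sketch of the dimension contract that the asserts added in affine_quantized_tensor_ops.py (shown in the next file) enforce, with no quantization involved:

import torch

x = torch.randn(53, 2048)   # mat1: (M, K)
w = torch.randn(53, 2048)   # weight as stored: (out_features, in_features)

mat2 = w.t()                          # (K, N) after the transpose
assert x.shape[-1] == mat2.shape[0]   # the shape check added in this PR
print(torch.matmul(x, mat2).shape)    # torch.Size([53, 53])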

torchao/dtypes/affine_quantized_tensor_ops.py

Lines changed: 21 additions & 4 deletions
@@ -157,6 +157,9 @@ class QuantizedLinearNotImplementedError(NotImplementedError):
     pass
 
 
+# input_tensor: dimension is (M1, M2, ..., in_features)
+# weight_tensor: dimension is (out_features, in_features)
+# bias: dimension is (out_features,)
 @staticmethod
 def _quantized_linear_op(input_tensor, weight_tensor, bias):
     for dispatch_condition, impl in _AQT_QLINEAR_DISPATCH_TABLE.items():

@@ -335,12 +338,19 @@ def _(func, types, args, kwargs):
             f"{func} is not implemented for non floating point input"
         )
 
+    assert input_tensor.shape[-1] == weight_tensor.shape[0], (
+        f"need mat1 shape: {input_tensor.shape} final dim "
+        f"to match mat2 shape: {weight_tensor.shape} first dim"
+    )
+
     # using try/except here so that we can have a general fallback when input_tensor/weight_tensor
     # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
     # make the branches easier to understand in `_quantized_linear_op`
     try:
-        weight_tensor = weight_tensor.t()
-        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
+        transposed_weight_tensor = weight_tensor.t()
+        return weight_tensor._quantized_linear_op(
+            input_tensor, transposed_weight_tensor, bias
+        )
     except QuantizedLinearNotImplementedError as e:
         # fallback path is only called when user did not specify a specific quantized linear implementation with `_layout.quantized_linear_impl`
         if (

@@ -365,9 +375,16 @@ def _(func, types, args, kwargs):
             f"{func} is not implemented for non floating point input"
         )
 
+    assert input_tensor.shape[-1] == weight_tensor.shape[0], (
+        f"need mat1 shape: {input_tensor.shape} final dim "
+        f"to match mat2 shape: {weight_tensor.shape} first dim"
+    )
+
     try:
-        weight_tensor = weight_tensor.t()
-        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
+        transposed_weight_tensor = weight_tensor.t()
+        return weight_tensor._quantized_linear_op(
+            input_tensor, transposed_weight_tensor, bias
+        )
     except QuantizedLinearNotImplementedError as e:
         # fallback path is only called when user did not specify a specific quantized linear implementation with `_layout.quantized_linear_impl`
         if (
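The new comments on _quantized_linear_op document a standard linear-layer shape contract. A small illustration with hypothetical sizes, using plain torch.nn.functional.linear rather than the quantized op itself:

import torch

batch, seq, in_features, out_features = 2, 16, 2048, 512

input_tensor = torch.randn(batch, seq, in_features)  # (M1, M2, ..., in_features)
weight = torch.randn(out_features, in_features)      # (out_features, in_features)
bias = torch.randn(out_features)                     # (out_features,)

out = torch.nn.functional.linear(input_tensor, weight, bias)
print(out.shape)  # torch.Size([2, 16, 512]) == (M1, M2, ..., out_features)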
