Skip to content

Commit a581609

Browse files
authored
[BE] Rename qparams for tinygemm (#2344)
1 parent 16e2d0a commit a581609

File tree

8 files changed

+37
-31
lines changed

8 files changed

+37
-31
lines changed

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@
2626
choose_qparams_and_quantize_affine_hqq,
2727
dequantize_affine,
2828
dequantize_affine_float8,
29-
dequantize_affine_float_zero_point,
3029
dequantize_affine_floatx,
3130
dequantize_affine_no_zero_point,
31+
dequantize_affine_tinygemm,
3232
quantize_affine,
3333
quantize_affine_float8,
34-
quantize_affine_float_zero_point,
3534
quantize_affine_floatx,
3635
quantize_affine_no_zero_point,
36+
quantize_affine_tinygemm,
3737
)
3838
from torchao.utils import (
3939
TORCH_VERSION_AT_LEAST_2_5,
@@ -155,7 +155,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
155155
else:
156156
data, scale, zero_point = self.tensor_impl.get_plain()
157157
if self.zero_point_domain == ZeroPointDomain.FLOAT:
158-
dq = dequantize_affine_float_zero_point(
158+
dq = dequantize_affine_tinygemm(
159159
data,
160160
self.block_size,
161161
scale,
@@ -339,7 +339,7 @@ def from_hp_to_intx(
339339
quant_max,
340340
)
341341
elif zero_point_domain == ZeroPointDomain.FLOAT:
342-
data = quantize_affine_float_zero_point(
342+
data = quantize_affine_tinygemm(
343343
input_float,
344344
block_size,
345345
scale,
@@ -410,7 +410,7 @@ def from_hp_to_intx_static(
410410
quant_max,
411411
)
412412
elif zero_point_domain == ZeroPointDomain.FLOAT:
413-
int_data = quantize_affine_float_zero_point(
413+
int_data = quantize_affine_tinygemm(
414414
input_float,
415415
block_size,
416416
scale,

torchao/dtypes/affine_quantized_tensor_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@
9393
from torchao.quantization.quant_primitives import (
9494
ZeroPointDomain,
9595
dequantize_affine,
96-
dequantize_affine_float_zero_point,
9796
dequantize_affine_no_zero_point,
97+
dequantize_affine_tinygemm,
9898
)
9999
from torchao.utils import (
100100
fill_defaults,
@@ -318,7 +318,7 @@ def _(func, types, args, kwargs):
318318
# we need to increase block size to correct dim
319319
new_blocks = idx.dim() - 1
320320
if args[1].zero_point_domain == ZeroPointDomain.FLOAT:
321-
_dequantize_affine = dequantize_affine_float_zero_point
321+
_dequantize_affine = dequantize_affine_tinygemm
322322
elif args[1].zero_point_domain == ZeroPointDomain.NONE:
323323
_dequantize_affine = dequantize_affine_no_zero_point
324324
else:

torchao/dtypes/uintx/int4_cpu_layout.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from torchao.dtypes.utils import AQTTensorImpl, Layout, is_device
2020
from torchao.quantization.quant_primitives import (
2121
ZeroPointDomain,
22-
quantize_affine_float_zero_point,
22+
quantize_affine_tinygemm,
2323
)
2424
from torchao.utils import (
2525
TORCH_VERSION_AT_LEAST_2_5,
@@ -266,7 +266,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
266266
# TODO: move this to `unpack_tinygemm_scales_and_zeros`?
267267
scale = scale.reshape(scale.shape[:-1]).contiguous()
268268
zero = zero.reshape(zero.shape[:-1]).contiguous()
269-
int_data = quantize_affine_float_zero_point(
269+
int_data = quantize_affine_tinygemm(
270270
dequantized,
271271
block_size,
272272
scale,

torchao/dtypes/uintx/int4_xpu_layout.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
17
from dataclasses import dataclass
28
from typing import Optional, Tuple
39

@@ -372,7 +378,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
372378
def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
373379
from torchao.quantization.quant_primitives import (
374380
quantize_affine,
375-
quantize_affine_float_zero_point,
381+
quantize_affine_tinygemm,
376382
)
377383
from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros
378384

@@ -423,7 +429,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
423429
# TODO: move this to `unpack_tinygemm_scales_and_zeros`?
424430
scale = scale.reshape(scale.shape[:-1]).contiguous()
425431
zero = zero.reshape(zero.shape[:-1]).contiguous()
426-
int_data = quantize_affine_float_zero_point(
432+
int_data = quantize_affine_tinygemm(
427433
dequantized,
428434
block_size,
429435
scale,

torchao/dtypes/uintx/tensor_core_tiled_layout.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from torchao.quantization.quant_primitives import (
2222
ZeroPointDomain,
2323
_get_reduction_params,
24-
quantize_affine_float_zero_point,
24+
quantize_affine_tinygemm,
2525
)
2626
from torchao.utils import (
2727
TORCH_VERSION_AT_LEAST_2_5,
@@ -511,7 +511,7 @@ def dequant_4d(self):
511511
target_dtype = torch.int32
512512
quant_min = 0
513513
quant_max = 15
514-
int_data = quantize_affine_float_zero_point(
514+
int_data = quantize_affine_tinygemm(
515515
dequantized,
516516
self.block_size,
517517
scale,

torchao/prototype/parq/quant/uniform_torchao.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
choose_qparams_affine_dont_preserve_zero,
1919
choose_qparams_affine_tinygemm,
2020
dequantize_affine,
21-
dequantize_affine_float_zero_point,
2221
dequantize_affine_no_zero_point,
22+
dequantize_affine_tinygemm,
2323
quantize_affine,
24-
quantize_affine_float_zero_point,
2524
quantize_affine_no_zero_point,
25+
quantize_affine_tinygemm,
2626
)
2727

2828
from .quantizer import Quantizer
@@ -76,8 +76,8 @@ def quantize(
7676

7777
if self.zero_point_domain == ZeroPointDomain.FLOAT and not self.preserve_zero:
7878
_choose_qparams_affine = choose_qparams_affine_tinygemm
79-
_quantize_affine = quantize_affine_float_zero_point
80-
_dequantize_affine = dequantize_affine_float_zero_point
79+
_quantize_affine = quantize_affine_tinygemm
80+
_dequantize_affine = dequantize_affine_tinygemm
8181
elif self.zero_point_domain == ZeroPointDomain.INT and not self.preserve_zero:
8282
_choose_qparams_affine = choose_qparams_affine_dont_preserve_zero
8383
_quantize_affine = quantize_affine

torchao/quantization/quant_primitives.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@
3030
"choose_qparams_affine_floatx",
3131
"quantize_affine",
3232
"quantize_affine_no_zero_point",
33-
"quantize_affine_float_zero_point",
33+
"quantize_affine_tinygemm",
3434
"dequantize_affine",
3535
"dequantize_affine_no_zero_point",
36-
"dequantize_affine_float_zero_point",
36+
"dequantize_affine_tinygemm",
3737
"quantize_affine_floatx",
3838
"dequantize_affine_floatx",
3939
"fake_quantize_affine",
@@ -428,7 +428,7 @@ def _quantize_affine_no_dtype_cast(
428428
return quant
429429

430430

431-
def quantize_affine_float_zero_point(
431+
def quantize_affine_tinygemm(
432432
input: torch.Tensor,
433433
block_size: List[int],
434434
scale: torch.Tensor,
@@ -453,7 +453,7 @@ def quantize_affine_float_zero_point(
453453
# torch.uintx dtypes yet
454454
if output_dtype in _SUB_BYTE_UINT_BOUNDS:
455455
output_dtype = torch.uint8
456-
return _quantize_affine_float_zero_point_no_dtype_cast(
456+
return _quantize_affine_tinygemm_no_dtype_cast(
457457
input,
458458
block_size,
459459
scale,
@@ -463,7 +463,7 @@ def quantize_affine_float_zero_point(
463463
).to(output_dtype)
464464

465465

466-
def _quantize_affine_float_zero_point_no_dtype_cast(
466+
def _quantize_affine_tinygemm_no_dtype_cast(
467467
input: torch.Tensor,
468468
block_size: Tuple[int, ...],
469469
scale: torch.Tensor,
@@ -803,7 +803,7 @@ def dequantize_affine_no_zero_point(
803803
)
804804

805805

806-
def _dequantize_affine_float_zero_point_no_dtype_check(
806+
def _dequantize_affine_tinygemm_no_dtype_check(
807807
input: torch.Tensor,
808808
block_size: List[int],
809809
scale: torch.Tensor,
@@ -848,7 +848,7 @@ def _dequantize_affine_float_zero_point_no_dtype_check(
848848
return dequant.view(original_shape).to(output_dtype)
849849

850850

851-
def dequantize_affine_float_zero_point(
851+
def dequantize_affine_tinygemm(
852852
input: torch.Tensor,
853853
block_size: Tuple[int, ...],
854854
scale: torch.Tensor,
@@ -887,7 +887,7 @@ def dequantize_affine_float_zero_point(
887887
torch.bfloat16,
888888
], f"Unsupported output dtype: {output_dtype}"
889889
quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)
890-
return _dequantize_affine_float_zero_point_no_dtype_check(
890+
return _dequantize_affine_tinygemm_no_dtype_check(
891891
input,
892892
block_size,
893893
scale,
@@ -1013,8 +1013,8 @@ def _do_fake_quantize_affine(
10131013
_quantize_affine = _quantize_affine_no_dtype_cast
10141014
_dequantize_affine = _dequantize_affine_no_dtype_check
10151015
elif zero_point_domain == ZeroPointDomain.FLOAT:
1016-
_quantize_affine = _quantize_affine_float_zero_point_no_dtype_cast
1017-
_dequantize_affine = _dequantize_affine_float_zero_point_no_dtype_check
1016+
_quantize_affine = _quantize_affine_tinygemm_no_dtype_cast
1017+
_dequantize_affine = _dequantize_affine_tinygemm_no_dtype_check
10181018
elif ZeroPointDomain == ZeroPointDomain.NONE:
10191019
_quantize_affine = _quantize_affine_no_zero_point_no_dtype_cast
10201020
_dequantize_affine = _dequantize_affine_no_zero_point_no_dtype_check

torchao/quantization/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@
1919
choose_qparams_affine_dont_preserve_zero,
2020
choose_qparams_affine_tinygemm,
2121
dequantize_affine,
22-
dequantize_affine_float_zero_point,
2322
dequantize_affine_no_zero_point,
23+
dequantize_affine_tinygemm,
2424
quantize_affine,
25-
quantize_affine_float_zero_point,
2625
quantize_affine_no_zero_point,
26+
quantize_affine_tinygemm,
2727
)
2828
from torchao.utils import (
2929
TORCH_VERSION_AT_LEAST_2_5,
@@ -439,7 +439,7 @@ def groupwise_affine_quantize_tensor_from_qparams(
439439
if zero_point_domain == ZeroPointDomain.INT:
440440
_quantize_affine = quantize_affine
441441
elif zero_point_domain == ZeroPointDomain.FLOAT:
442-
_quantize_affine = quantize_affine_float_zero_point
442+
_quantize_affine = quantize_affine_tinygemm
443443
elif ZeroPointDomain == ZeroPointDomain.NONE:
444444
_quantize_affine = quantize_affine_no_zero_point
445445
else:
@@ -508,7 +508,7 @@ def groupwise_affine_dequantize_tensor_from_qparams(
508508
if zero_point_domain == ZeroPointDomain.INT:
509509
_dequantize_affine = dequantize_affine
510510
elif zero_point_domain == ZeroPointDomain.FLOAT:
511-
_dequantize_affine = dequantize_affine_float_zero_point
511+
_dequantize_affine = dequantize_affine_tinygemm
512512
else:
513513
_dequantize_affine = dequantize_affine_no_zero_point
514514
return _dequantize_affine(

0 commit comments

Comments (0)