Commit 282d04f

[BE] Convert quant_primitives methods private (#2350)
1 parent 346baf6 · commit 282d04f

23 files changed: +399 −222 lines changed
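
The commit is a mechanical rename: each quantization primitive that is considered internal gains a leading underscore (for example, choose_qparams_affine_float8 becomes _choose_qparams_affine_float8), the docs drop the now-private names, and all call sites are updated. A minimal sketch of how downstream code that reached into quant_primitives can stay compatible across versions — the shim itself is illustrative and not part of this commit:

# Hypothetical compatibility shim (not part of this commit): prefer the new
# underscore-prefixed names, fall back to the old public ones on older torchao.
try:
    from torchao.quantization.quant_primitives import (
        _choose_qparams_affine_float8,
        _dequantize_affine_float8,
        _quantize_affine_float8,
    )
except ImportError:
    from torchao.quantization.quant_primitives import (
        choose_qparams_affine_float8 as _choose_qparams_affine_float8,
        dequantize_affine_float8 as _dequantize_affine_float8,
        quantize_affine_float8 as _quantize_affine_float8,
    )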

docs/source/api_ref_quantization.rst
Lines changed: 0 additions & 6 deletions

@@ -63,14 +63,8 @@ Quantization Primitives

     choose_qparams_affine
     choose_qparams_affine_with_min_max
-    choose_qparams_affine_floatx
     quantize_affine
-    quantize_affine_floatx
     dequantize_affine
-    dequantize_affine_floatx
-    choose_qparams_and_quantize_affine_hqq
-    fake_quantize_affine
-    fake_quantize_affine_cachemask
     safe_int_mm
     int_scaled_matmul
     MappingType

test/dtypes/test_affine_quantized_float.py
Lines changed: 10 additions & 10 deletions

@@ -42,10 +42,10 @@
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
+    _choose_qparams_affine_float8,
+    _dequantize_affine_float8,
+    _quantize_affine_float8,
     choose_qparams_affine,
-    choose_qparams_affine_float8,
-    dequantize_affine_float8,
-    quantize_affine_float8,
 )
 from torchao.utils import (
     is_sm_at_least_89,
@@ -358,21 +358,21 @@ def test_mm_float8dq_per_row(
     @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
     @common_utils.parametrize("block_size", [None, (1, 32), (2, 16), (4, 8)])
     def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
-        """Test dequantize_affine_float8 with various configurations"""
+        """Test _dequantize_affine_float8 with various configurations"""

         device = "cuda"
         input_tensor = torch.randn(8, 64, device=device, dtype=torch.float32)

         # Choose quantization parameters
-        scale = choose_qparams_affine_float8(
+        scale = _choose_qparams_affine_float8(
             input_tensor, float8_dtype=float8_dtype, block_size=block_size
         )

         # Quantize
-        quantized = quantize_affine_float8(input_tensor, scale, float8_dtype)
+        quantized = _quantize_affine_float8(input_tensor, scale, float8_dtype)

         # Dequantize
-        dequantized = dequantize_affine_float8(quantized, scale, output_dtype)
+        dequantized = _dequantize_affine_float8(quantized, scale, output_dtype)

         # Verify output properties
         self.assertEqual(dequantized.dtype, output_dtype)
@@ -395,7 +395,7 @@ def test_dequantize_affine_float8_scale_broadcasting(self):
         block_size = (2, 16)  # 2x2 blocks in first dim, 2x16 blocks in second dim

         # Choose quantization parameters
-        scale = choose_qparams_affine_float8(
+        scale = _choose_qparams_affine_float8(
             input_tensor, float8_dtype=torch.float8_e4m3fn, block_size=block_size
         )
@@ -407,10 +407,10 @@ def test_dequantize_affine_float8_scale_broadcasting(self):
         self.assertEqual(scale.shape, expected_scale_shape)

         # Quantize
-        quantized = quantize_affine_float8(input_tensor, scale, torch.float8_e4m3fn)
+        quantized = _quantize_affine_float8(input_tensor, scale, torch.float8_e4m3fn)

         # Dequantize
-        dequantized = dequantize_affine_float8(quantized, scale, torch.float32)
+        dequantized = _dequantize_affine_float8(quantized, scale, torch.float32)

         # Verify shapes match
         self.assertEqual(dequantized.shape, input_tensor.shape)
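
The test exercises the renamed float8 primitives end to end. A condensed sketch of that round trip, with the CUDA device and the (1, 32) block size taken from the parametrization above (signatures follow the calls in the diff):

import torch
from torchao.quantization.quant_primitives import (
    _choose_qparams_affine_float8,
    _dequantize_affine_float8,
    _quantize_affine_float8,
)

x = torch.randn(8, 64, device="cuda", dtype=torch.float32)
# Per-block scale; block_size=None would give a single per-tensor scale.
scale = _choose_qparams_affine_float8(
    x, float8_dtype=torch.float8_e4m3fn, block_size=(1, 32)
)
q = _quantize_affine_float8(x, scale, torch.float8_e4m3fn)   # fp8 payload
x_dq = _dequantize_affine_float8(q, scale, torch.float32)    # back to fp32
assert x_dq.shape == x.shape and x_dq.dtype == torch.float32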

test/dtypes/test_floatx.py
Lines changed: 4 additions & 4 deletions

@@ -91,13 +91,13 @@ def test_from_scaled_tc_floatx_compile(self, ebits, mbits, device):
     @parametrize("ebits,mbits", _Floatx_DTYPES)
     def test_to_copy_device(self, ebits, mbits):
         from torchao.quantization.quant_primitives import (
-            choose_qparams_affine_floatx,
-            quantize_affine_floatx,
+            _choose_qparams_affine_floatx,
+            _quantize_affine_floatx,
         )

         x = torch.randn(256, 64)
-        scale = choose_qparams_affine_floatx(x, ebits, mbits)
-        x = quantize_affine_floatx(x, scale, ebits, mbits)
+        scale = _choose_qparams_affine_floatx(x, ebits, mbits)
+        x = _quantize_affine_floatx(x, scale, ebits, mbits)
         _layout = FloatxTensorCoreLayout(ebits, mbits)
         floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain(
             x, scale, None, _layout
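
The same two-step choose-then-quantize pattern applies to the floatx primitives used here. A sketch with an FP6-style exponent/mantissa split of ebits=3, mbits=2 — assumed example values; the test parametrizes these over _Floatx_DTYPES:

import torch
from torchao.quantization.quant_primitives import (
    _choose_qparams_affine_floatx,
    _quantize_affine_floatx,
)

x = torch.randn(256, 64)
ebits, mbits = 3, 2  # assumed e3m2 (FP6-style) split for illustration
scale = _choose_qparams_affine_floatx(x, ebits, mbits)   # per-axis scale
x_fpx = _quantize_affine_floatx(x, scale, ebits, mbits)  # unpacked floatx codes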

test/prototype/test_gguf_quant.py
Lines changed: 2 additions & 2 deletions

@@ -13,7 +13,7 @@
     GGUFWeightOnlyConfig,
 )
 from torchao.quantization import quantize_
-from torchao.quantization.quant_primitives import choose_qparams_gguf
+from torchao.quantization.quant_primitives import _choose_qparams_gguf
 from torchao.quantization.utils import compute_error


@@ -31,7 +31,7 @@ def test_choose_qparams_gguf(self):
             super_block_min_scale,
             quantized_block_scale,
             quantized_block_min,
-        ) = choose_qparams_gguf(self.input, self.block_size, self.dtype)
+        ) = _choose_qparams_gguf(self.input, self.block_size, self.dtype)

         assert super_block_scale_scale.shape, (2, 8)
         assert super_block_min_scale.shape, (2, 8)

test/quantization/test_marlin_qqq.py
Lines changed: 2 additions & 2 deletions

@@ -21,7 +21,7 @@
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
-    choose_qparams_and_quantize_affine_qqq,
+    _choose_qparams_and_quantize_affine_qqq,
 )
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
@@ -102,7 +102,7 @@ def test_pack_unpack_equivalence(self):

         for group_size in [-1, 128]:
             # Quantize weights
-            q_w, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq(
+            q_w, s_group, s_channel, _ = _choose_qparams_and_quantize_affine_qqq(
                 w, num_bits, group_size
             )
test/quantization/test_qat.py
Lines changed: 2 additions & 2 deletions

@@ -64,9 +64,9 @@
     MappingType,
     TorchAODType,
     ZeroPointDomain,
+    _fake_quantize_affine,
     choose_qparams_affine,
     dequantize_affine,
-    fake_quantize_affine,
     quantize_affine,
 )
 from torchao.quantization.unified import (
@@ -637,7 +637,7 @@ def test_qat_4w_primitives(self):
             group_size,
             scales_precision,
         )
-        w_fq = fake_quantize_affine(
+        w_fq = _fake_quantize_affine(
            weight,
            block_size,
            scales,

test/quantization/test_quant_primitives.py
Lines changed: 6 additions & 6 deletions

@@ -13,11 +13,11 @@
 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
+    _choose_qparams_affine_tinygemm,
+    _fake_quantize_affine,
+    _fake_quantize_affine_cachemask,
     choose_qparams_affine,
-    choose_qparams_affine_tinygemm,
     dequantize_affine,
-    fake_quantize_affine,
-    fake_quantize_affine_cachemask,
     quantize_affine,
 )

@@ -672,7 +672,7 @@ def test_get_groupwise_affine_qparams(self):
                 zero_point_domain=zero_point_domain,
             )
             if zero_point_domain == ZeroPointDomain.FLOAT:
-                scale, zero_point = choose_qparams_affine_tinygemm(
+                scale, zero_point = _choose_qparams_affine_tinygemm(
                     input,
                     mapping_type,
                     block_size,
@@ -780,7 +780,7 @@ def test_fake_quantize_affine(self):
         dequantized = dequantize_affine(
             quantized, block_size, scale, zero_point, dtype, quant_min, quant_max
         )
-        fake_quantized = fake_quantize_affine(
+        fake_quantized = _fake_quantize_affine(
             input, block_size, scale, zero_point, dtype, quant_min, quant_max
         )
         torch.testing.assert_close(dequantized, fake_quantized)
@@ -816,7 +816,7 @@ def test_fake_quantize_affine_cachemask(self):
         dequantized = dequantize_affine(
             quantized, block_size, scale, zero_point, dtype, quant_min, quant_max
         )
-        (fake_quantized, mask) = fake_quantize_affine_cachemask(
+        (fake_quantized, mask) = _fake_quantize_affine_cachemask(
             input,
             block_size,
             scale,
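
Both tests check the defining property of fake quantization: _fake_quantize_affine(x, ...) must equal dequantize_affine(quantize_affine(x, ...), ...), and the cachemask variant additionally returns the mask of values that survived clamping. A self-contained sketch of the first identity, with hand-picked qparams for illustration:

import torch
from torchao.quantization.quant_primitives import (
    _fake_quantize_affine,
    dequantize_affine,
    quantize_affine,
)

x = torch.randn(2, 4)
block_size = (1, 4)                                # one group per row (assumed)
scale = torch.full((2, 1), 0.1)                    # hand-picked scale
zero_point = torch.zeros(2, 1, dtype=torch.int32)  # hand-picked zero point
dtype, qmin, qmax = torch.int8, -128, 127

q = quantize_affine(x, block_size, scale, zero_point, dtype, qmin, qmax)
dq = dequantize_affine(q, block_size, scale, zero_point, dtype, qmin, qmax)
fq = _fake_quantize_affine(x, block_size, scale, zero_point, dtype, qmin, qmax)
torch.testing.assert_close(dq, fq)  # fake-quantize == quantize -> dequantize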

test/test_ops.py
Lines changed: 4 additions & 2 deletions

@@ -23,7 +23,9 @@
     marlin_qqq_workspace,
     pack_to_marlin_qqq,
 )
-from torchao.quantization.quant_primitives import choose_qparams_and_quantize_affine_qqq
+from torchao.quantization.quant_primitives import (
+    _choose_qparams_and_quantize_affine_qqq,
+)
 from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
@@ -713,7 +715,7 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact
     )

     # Quantize weights
-    q_w, s_group, s_channel, w_ref = choose_qparams_and_quantize_affine_qqq(
+    q_w, s_group, s_channel, w_ref = _choose_qparams_and_quantize_affine_qqq(
         b_weight, num_bits, group_size
     )
     q_w = q_w.t()
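
As in the marlin test above, the QQQ entry point returns the quantized weight plus group and channel scales and a dequantized reference. A sketch with assumed shapes and settings (num_bits=4, group_size=128; the tests also cover group_size=-1 for per-channel):

import torch
from torchao.quantization.quant_primitives import (
    _choose_qparams_and_quantize_affine_qqq,
)

w = torch.randn(256, 512, dtype=torch.float16)  # assumed weight shape/dtype
q_w, s_group, s_channel, w_ref = _choose_qparams_and_quantize_affine_qqq(
    w, 4, 128  # num_bits, group_size
)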

torchao/dtypes/affine_quantized_tensor.py
Lines changed: 29 additions & 29 deletions

@@ -18,22 +18,22 @@
     FP8_TYPES,
     MappingType,
     ZeroPointDomain,
+    _choose_qparams_affine_dont_preserve_zero,
+    _choose_qparams_affine_float8,
+    _choose_qparams_affine_floatx,
+    _choose_qparams_affine_tinygemm,
+    _choose_qparams_and_quantize_affine_hqq,
+    _dequantize_affine_float8,
+    _dequantize_affine_floatx,
+    _dequantize_affine_no_zero_point,
+    _dequantize_affine_tinygemm,
+    _quantize_affine_float8,
+    _quantize_affine_floatx,
+    _quantize_affine_no_zero_point,
+    _quantize_affine_tinygemm,
     choose_qparams_affine,
-    choose_qparams_affine_dont_preserve_zero,
-    choose_qparams_affine_float8,
-    choose_qparams_affine_floatx,
-    choose_qparams_affine_tinygemm,
-    choose_qparams_and_quantize_affine_hqq,
     dequantize_affine,
-    dequantize_affine_float8,
-    dequantize_affine_floatx,
-    dequantize_affine_no_zero_point,
-    dequantize_affine_tinygemm,
     quantize_affine,
-    quantize_affine_float8,
-    quantize_affine_floatx,
-    quantize_affine_no_zero_point,
-    quantize_affine_tinygemm,
 )
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
@@ -142,7 +142,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor

         if isinstance(self._layout, FloatxTensorCoreLayout):
             int_data, scale = self.tensor_impl.get_plain()
-            return dequantize_affine_floatx(
+            return _dequantize_affine_floatx(
                 int_data,
                 scale,
                 self._layout.ebits,
@@ -151,11 +151,11 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
             )
         elif isinstance(self._layout, Float8Layout):
             data, scale, _ = self.tensor_impl.get_plain()
-            return dequantize_affine_float8(data, scale, output_dtype)
+            return _dequantize_affine_float8(data, scale, output_dtype)
         else:
             data, scale, zero_point = self.tensor_impl.get_plain()
             if self.zero_point_domain == ZeroPointDomain.FLOAT:
-                dq = dequantize_affine_tinygemm(
+                dq = _dequantize_affine_tinygemm(
                     data,
                     self.block_size,
                     scale,
@@ -166,7 +166,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
                     output_dtype=output_dtype,
                 )
             elif self.zero_point_domain == ZeroPointDomain.NONE:
-                dq = dequantize_affine_no_zero_point(
+                dq = _dequantize_affine_no_zero_point(
                     data,
                     self.block_size,
                     scale,
@@ -270,7 +270,7 @@ def from_hp_to_intx(
             from torchao.dtypes import Int4CPULayout
             from torchao.dtypes.uintx import TensorCoreTiledLayout

-            data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq(
+            data, scale, zero_point, _ = _choose_qparams_and_quantize_affine_hqq(
                 input_float,
                 nbits=nbits,
                 group_size=group_size,
@@ -291,7 +291,7 @@ def from_hp_to_intx(
             data = data.to(target_dtype)
         else:
             if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero:
-                scale, zero_point = choose_qparams_affine_tinygemm(
+                scale, zero_point = _choose_qparams_affine_tinygemm(
                     input_float,
                     mapping_type,
                     block_size,
@@ -303,7 +303,7 @@ def from_hp_to_intx(
                     zero_point_dtype,
                 )
             elif zero_point_domain == ZeroPointDomain.INT and not preserve_zero:
-                scale, zero_point = choose_qparams_affine_dont_preserve_zero(
+                scale, zero_point = _choose_qparams_affine_dont_preserve_zero(
                     input_float,
                     mapping_type,
                     block_size,
@@ -329,7 +329,7 @@ def from_hp_to_intx(
             # choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None
             if zero_point_domain == ZeroPointDomain.NONE:
                 zero_point = None
-                data = quantize_affine_no_zero_point(
+                data = _quantize_affine_no_zero_point(
                     input_float,
                     block_size,
                     scale,
@@ -339,7 +339,7 @@ def from_hp_to_intx(
                     quant_max,
                 )
             elif zero_point_domain == ZeroPointDomain.FLOAT:
-                data = quantize_affine_tinygemm(
+                data = _quantize_affine_tinygemm(
                     input_float,
                     block_size,
                     scale,
@@ -400,7 +400,7 @@ def from_hp_to_intx_static(

         if zero_point_domain == ZeroPointDomain.NONE:
             zero_point = None
-            int_data = quantize_affine_no_zero_point(
+            int_data = _quantize_affine_no_zero_point(
                 input_float,
                 block_size,
                 scale,
@@ -410,7 +410,7 @@ def from_hp_to_intx_static(
                 quant_max,
             )
         elif zero_point_domain == ZeroPointDomain.FLOAT:
-            int_data = quantize_affine_tinygemm(
+            int_data = _quantize_affine_tinygemm(
                 input_float,
                 block_size,
                 scale,
@@ -462,10 +462,10 @@ def from_hp_to_floatx(
         if target_dtype in FP8_TYPES:
             original_shape = input_float.shape
             input_float = _layout.pre_process(input_float)
-            scale = choose_qparams_affine_float8(
+            scale = _choose_qparams_affine_float8(
                 input_float, float8_dtype=target_dtype, block_size=block_size
             )
-            data = quantize_affine_float8(input_float, scale, target_dtype)
+            data = _quantize_affine_float8(input_float, scale, target_dtype)
             data, scale, zero_point = _layout.post_process(
                 data, scale, None, block_size
             )
@@ -499,7 +499,7 @@ def from_hp_to_floatx_static(
                 input_float, scale, ZeroPointDomain.NONE, block_size
             )

-            data = quantize_affine_float8(
+            data = _quantize_affine_float8(
                 input_float,
                 scale,
                 target_dtype,
@@ -545,8 +545,8 @@ def from_hp_to_fpx(

         ebits, mbits = _layout.ebits, _layout.mbits
         # Note: these ops are hardcoded to have per axis quantization (axis=1) right now
-        scale = choose_qparams_affine_floatx(input_float, ebits, mbits)
-        floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits)
+        scale = _choose_qparams_affine_floatx(input_float, ebits, mbits)
+        floatx_unpacked = _quantize_affine_floatx(input_float, scale, ebits, mbits)
         floatx_packed, scale, _ = _layout.post_process(
             floatx_unpacked, scale, None, block_size
         )
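
Note that only the helpers are renamed: public entry points such as choose_qparams_affine, quantize_affine, dequantize_affine, and the AffineQuantizedTensor constructors keep their signatures and now route to the underscore-prefixed functions internally. A minimal sketch through one public path (argument values are assumed for illustration):

import torch
from torchao.dtypes import to_affine_quantized_intx
from torchao.quantization.quant_primitives import MappingType

w = torch.randn(128, 256)
aqt = to_affine_quantized_intx(
    w,
    MappingType.SYMMETRIC,
    block_size=(1, 64),       # assumed group size along the last dim
    target_dtype=torch.int8,
)
w_dq = aqt.dequantize()       # dispatches to the renamed _dequantize_* helpers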
