Skip to content

Commit 65a1373

Browse files
committed
Summary:
Note: slice is not working yet, others are working Test Plan: python test/dtypes/test_int4_groupwise_preshuffle.py Reviewers: Subscribers: Tasks: Tags: stack-info: PR: #2421, branch: jerryzh168/stack/1
1 parent 2898903 commit 65a1373

File tree

4 files changed

+595
-5
lines changed

4 files changed

+595
-5
lines changed
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from torch.testing._internal.common_utils import (
11+
TestCase,
12+
run_tests,
13+
)
14+
15+
from torchao.quantization import (
16+
FbgemmConfig,
17+
quantize_,
18+
)
19+
from torchao.quantization.utils import compute_error
20+
from torchao.utils import (
21+
TORCH_VERSION_AT_LEAST_2_8,
22+
is_sm_at_least_90,
23+
)
24+
25+
26+
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
27+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
28+
@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
29+
class TestInt4GroupwisePreshuffleTensor(TestCase):
30+
def setUp(self):
31+
self.config = FbgemmConfig(
32+
input_dtype=torch.bfloat16,
33+
weight_dtype=torch.int4,
34+
output_dtype=torch.bfloat16,
35+
block_size=[1, 128],
36+
preshuffle=True,
37+
)
38+
self.bmm_config = FbgemmConfig(
39+
input_dtype=torch.bfloat16,
40+
weight_dtype=torch.int4,
41+
output_dtype=torch.bfloat16,
42+
block_size=[1, 1, 128],
43+
preshuffle=True,
44+
)
45+
self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
46+
47+
def test_linear(self):
48+
dtype = torch.bfloat16
49+
device = "cuda"
50+
input = torch.randn(1, 128, dtype=dtype, device=device)
51+
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
52+
original = linear(input)
53+
quantize_(linear, self.config)
54+
quantized = linear(input)
55+
self.assertTrue(compute_error(original, quantized) > 20)
56+
57+
@unittest.skip("WIP: this doesn't work yet")
58+
def test_slice(self):
59+
dtype = torch.bfloat16
60+
device = "cuda"
61+
dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
62+
dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
63+
dummy1.weight = torch.nn.Parameter(
64+
dummy.weight.narrow(0, 0, 64), requires_grad=False
65+
)
66+
dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
67+
dummy2.weight = torch.nn.Parameter(
68+
dummy.weight.narrow(1, 0, 128), requires_grad=False
69+
)
70+
71+
quantize_(dummy, self.config)
72+
weight1 = dummy.weight.narrow(0, 0, 64)
73+
weight2 = dummy.weight.narrow(1, 0, 128)
74+
self.assertEqual(
75+
weight1.packed_weight, dummy.weight.packed_weight.narrow(0, 0, 64)
76+
)
77+
self.assertEqual(weight1.group_scale, dummy.weight.group_scale.narrow(1, 0, 64))
78+
self.assertEqual(
79+
weight2.packed_weight, dummy.weight.packed_weight.narrow(1, 0, 64)
80+
)
81+
self.assertEqual(weight2.group_scale, dummy.weight.group_scale.narrow(0, 0, 1))
82+
83+
# check for sliced weight, before and after float8 quantization
84+
# does not differ too much
85+
input = torch.randn(2, 256, dtype=dtype, device=device)
86+
res_ref = dummy1(input)
87+
dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
88+
res = dummy(input)
89+
sqnr = compute_error(res, res_ref)
90+
assert sqnr > 20, f"Got: {sqnr}"
91+
92+
input = torch.randn(2, 128, dtype=dtype, device=device)
93+
res_ref = dummy2(input)
94+
dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
95+
res = dummy(input)
96+
sqnr = compute_error(res, res_ref)
97+
assert sqnr > 15, f"Got: {sqnr}"
98+
99+
def test_slice_and_copy_(self):
100+
l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
101+
l.weight = torch.nn.Parameter(
102+
torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
103+
)
104+
quantize_(l, self.config)
105+
param = l.weight
106+
param_data = param.data
107+
param_data = param_data.narrow(0, 0, 512)
108+
assert (
109+
param.data.packed_weight.data_ptr() == param_data.packed_weight.data_ptr()
110+
)
111+
assert param.data.group_scale.data_ptr() == param_data.group_scale.data_ptr()
112+
assert param.data.row_scale.data_ptr() == param_data.row_scale.data_ptr()
113+
orig_value = param.data.packed_weight[0][0].item()
114+
115+
# dummy_l has random input (shouldn't be 0)
116+
dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
117+
quantize_(dummy_l, self.config)
118+
quantized = dummy_l.weight
119+
quantized = quantized.narrow(0, 0, 512)
120+
121+
param_data.copy_(quantized)
122+
123+
# making sure param.data is updated
124+
assert param.data.packed_weight[0][0] != orig_value
125+
126+
def test_bmm(self):
127+
class M(torch.nn.Module):
128+
def __init__(self, weight):
129+
super().__init__()
130+
self.weight = weight
131+
132+
def forward(self, x):
133+
return torch.bmm(x, self.weight)
134+
135+
dtype = torch.bfloat16
136+
device = "cuda"
137+
input = torch.randn(10, 32, 128, dtype=dtype, device=device)
138+
weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
139+
m = M(weight).eval()
140+
original = m(input)
141+
m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
142+
quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
143+
quantized = m(input)
144+
self.assertTrue(compute_error(original, quantized) > 18)
145+
146+
def test_to_device(self):
147+
for device in self.GPU_DEVICES:
148+
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
149+
quantize_(linear, self.config)
150+
linear.to(device)
151+
152+
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
153+
quantize_(linear, self.config)
154+
linear.to(device=device)
155+
156+
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
157+
quantize_(linear, self.config)
158+
linear.to(device)
159+
160+
161+
if __name__ == "__main__":
162+
run_tests()

torchao/dtypes/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
CutlassSemiSparseLayout,
1515
Float8Layout,
1616
)
17+
from .int4_groupwise_preshuffle_tensor import (
18+
Int4GroupwisePreshuffleTensor,
19+
to_int4_groupwise_preshuffle,
20+
)
1721
from .nf4tensor import NF4Tensor, to_nf4
1822
from .uintx import (
1923
BlockSparseLayout,
@@ -67,4 +71,6 @@
6771
"FbgemmInt4Tensor",
6872
"to_fbgemm_fp8",
6973
"FbgemmFp8Tensor",
74+
"Int4GroupwisePreshuffleTensor",
75+
"to_int4_groupwise_preshuffle",
7076
]

0 commit comments

Comments
 (0)