
Commit d187f78

Add support for Float8ActivationInt4GroupwisePreshuffleTensor for fbgemm
Summary: Note: slice is not working yet, others are working

Test Plan: python test/dtypes/test_float8_activation_int4_groupwise_preshuffle.py

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2437, branch: jerryzh168/stack/4
1 parent 9c2d239 commit d187f78
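
For context, a minimal usage sketch of the new path, assembled from the setUp of the test added below; the FbgemmConfig arguments and the quantize_ call are taken from the diff, while the Linear module and the input shapes are placeholder assumptions:

import torch
from torchao.quantization import FbgemmConfig, quantize_

# Requires CUDA sm90+ and PyTorch 2.8+, per the skip conditions in the new test.
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")

# Config values mirror setUp(): int4 groupwise weights with preshuffled packing
# and float8 activation quantization, block_size [1, 128].
config = FbgemmConfig(
    input_dtype=torch.bfloat16,
    weight_dtype=torch.int4,
    output_dtype=torch.bfloat16,
    block_size=[1, 128],
    preshuffle=True,
    float8_activation=True,
)
quantize_(linear, config)

output = linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))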

File tree

4 files changed: +631 -1 lines changed

test/dtypes/test_float8_activation_int4_groupwise_preshuffle.py

Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)

from torchao.quantization import (
    FbgemmConfig,
    quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_8,
    is_sm_at_least_90,
)


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
class TestInt4GroupwisePreshuffleTensor(TestCase):
    def setUp(self):
        self.config = FbgemmConfig(
            input_dtype=torch.bfloat16,
            weight_dtype=torch.int4,
            output_dtype=torch.bfloat16,
            block_size=[1, 128],
            preshuffle=True,
            float8_activation=True,
        )
        self.bmm_config = FbgemmConfig(
            input_dtype=torch.bfloat16,
            weight_dtype=torch.int4,
            output_dtype=torch.bfloat16,
            block_size=[1, 1, 128],
            preshuffle=True,
            float8_activation=True,
        )
        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

    def test_linear(self):
        dtype = torch.bfloat16
        device = "cuda"
        input = torch.randn(1, 128, dtype=dtype, device=device)
        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
        original = linear(input)
        quantize_(linear, self.config)
        quantized = linear(input)
        self.assertTrue(compute_error(original, quantized) > 20)

    # @unittest.skip("WIP: this doesn't work yet")
    def test_slice(self):
        dtype = torch.bfloat16
        device = "cuda"
        dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
        dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
        dummy1.weight = torch.nn.Parameter(
            dummy.weight.narrow(0, 0, 64), requires_grad=False
        )
        dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
        dummy2.weight = torch.nn.Parameter(
            dummy.weight.narrow(1, 0, 128), requires_grad=False
        )

        quantize_(dummy, self.config)
        weight1 = dummy.weight.narrow(0, 0, 64)
        weight2 = dummy.weight.narrow(1, 0, 128)
        # check that the slicing operation is performed correctly on the constituent Tensors
        self.assertEqual(
            weight1.packed_weight, dummy.weight.packed_weight.narrow(0, 0, 64)
        )
        self.assertEqual(weight1.group_scale, dummy.weight.group_scale.narrow(2, 0, 64))
        self.assertEqual(
            weight2.packed_weight, dummy.weight.packed_weight.narrow(1, 0, 64)
        )
        self.assertEqual(weight2.group_scale, dummy.weight.group_scale.narrow(0, 0, 1))

        # check that 1. the sliced bf16 weight and 2. the sliced quantized weight
        # produce similar results when doing matmul on the same input Tensor

        input = torch.randn(2, 256, dtype=dtype, device=device)
        res_ref = dummy1(input)
        dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
        res = dummy(input)
        sqnr = compute_error(res, res_ref)
        assert sqnr > 20, f"Got: {sqnr}"

        input = torch.randn(2, 128, dtype=dtype, device=device)
        res_ref = dummy2(input)
        dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
        res = dummy(input)
        sqnr = compute_error(res, res_ref)
        assert sqnr > 15, f"Got: {sqnr}"

    def test_slice_and_copy_(self):
        l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
        l.weight = torch.nn.Parameter(
            torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
        )
        quantize_(l, self.config)
        param = l.weight
        param_data = param.data
        param_data = param_data.narrow(0, 0, 512)
        assert (
            param.data.packed_weight.data_ptr() == param_data.packed_weight.data_ptr()
        )
        assert param.data.group_scale.data_ptr() == param_data.group_scale.data_ptr()
        assert param.data.row_scale.data_ptr() == param_data.row_scale.data_ptr()
        orig_value = param.data.packed_weight[0][0].item()

        # dummy_l has random weights (shouldn't be 0)
        dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
        quantize_(dummy_l, self.config)
        quantized = dummy_l.weight
        quantized = quantized.narrow(0, 0, 512)

        param_data.copy_(quantized)

        # making sure param.data is updated
        assert param.data.packed_weight[0][0] != orig_value

    def test_bmm(self):
        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def forward(self, x):
                return torch.bmm(x, self.weight)

        dtype = torch.bfloat16
        device = "cuda"
        input = torch.randn(10, 32, 128, dtype=dtype, device=device)
        weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
        m = M(weight).eval()
        original = m(input)
        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
        quantized = m(input)
        self.assertTrue(compute_error(original, quantized) > 18)

    def test_to_device(self):
        for device in self.GPU_DEVICES:
            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, self.config)
            linear.to(device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, self.config)
            linear.to(device=device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, self.config)
            linear.to(device)


if __name__ == "__main__":
    run_tests()

torchao/dtypes/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -10,6 +10,10 @@
 )
 from .fbgemm_fp8_tensor import FbgemmFp8Tensor, to_fbgemm_fp8
 from .fbgemm_int4_tensor import FbgemmInt4Tensor, to_fbgemm_int4
+from .float8_activation_int4_groupwise_preshuffle_tensor import (
+    Float8ActivationInt4GroupwisePreshuffleTensor,
+    to_float8_activation_int4_groupwise_preshuffle,
+)
 from .floatx import (
     CutlassSemiSparseLayout,
     Float8Layout,
@@ -73,4 +77,6 @@
     "FbgemmFp8Tensor",
     "Int4GroupwisePreshuffleTensor",
     "to_int4_groupwise_preshuffle",
+    "Float8ActivationInt4GroupwisePreshuffleTensor",
+    "to_float8_activation_int4_groupwise_preshuffle",
 ]
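
For reference, a minimal sketch of what this __init__.py change enables; nothing beyond importability is implied by the diff:

# The two names added to __all__ above become importable from torchao.dtypes.
from torchao.dtypes import (
    Float8ActivationInt4GroupwisePreshuffleTensor,
    to_float8_activation_int4_groupwise_preshuffle,
)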
