
Commit 9d169a8

Add 2, 3, 4, 5 bit custom ops
Differential Revision: D62248716
Pull Request resolved: #828
1 parent 684f7cd commit 9d169a8

File tree

3 files changed: +360 additions, -127 deletions

torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py

Lines changed: 51 additions & 33 deletions
@@ -4,16 +4,21 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-from torch_custom_op import quantize, replace_linear_with_quantized_linear
-import torch
 import copy

-group_size = 16
+import torch
+from torch_custom_op import (
+    linear_a8sz_w_lowbit_reference_impl,
+    replace_linear_with_quantized_linear,
+)
+
+group_size = 256
 m = 1
 n = 4096
 k = 4096
-nbit = 4
-n_layers = 10
+nbit = 5
+has_weight_zeros = True
+n_layers = 5

 print("Creating random model")
 layers = [torch.nn.Linear(k, n, bias=False) for _ in range(n_layers)]
@@ -22,8 +27,15 @@

 print("Quantizing random model")
 quantized_model = copy.deepcopy(model)
-quantized_model = quantized_model.eval()
-replace_linear_with_quantized_linear(quantized_model, kwargs={"group_size": group_size, "nbit": nbit})
+quantized_model = quantized_model.eval()
+replace_linear_with_quantized_linear(
+    quantized_model,
+    kwargs={
+        "group_size": group_size,
+        "nbit": nbit,
+        "has_weight_zeros": has_weight_zeros,
+    },
+)

 print("Creating random activations")
 activations = torch.randn(m, k, dtype=torch.float32)
@@ -48,36 +60,42 @@
 fn(activations)


-print("Checking correctness on layer 0")
-
-rtol=1e-05
-
-# default is 1e-8, but PyTorch and C++ (and ARM neon) have different rounding
-# conventions for ties (PyTorch rounds half to even and C++ rounds half to odd)
-# TODO(T200109708): address this
-atol=1e-05
-
+print("\nChecking correctness on layer 0")
 linear = model[0]
 quantized_linear = quantized_model[0]
-weight_qvals, weight_scales = quantize(linear.weight, group_size, quantized_linear.nbit, scale_only=True)
-
-activation_qvals, activations_scales, activations_zeros = quantize(activations, k, 8, False)
-activations_dequantized = activations_scales * (activation_qvals - activations_zeros)
-weights_dequantized = (weight_qvals.reshape(-1, group_size) * weight_scales.reshape(-1, 1)).reshape(n, k)

 with torch.no_grad():
     result = quantized_linear(activations)
-    expected_result = torch.matmul(activations_dequantized, weights_dequantized.transpose(1, 0))
+    expected_result = linear_a8sz_w_lowbit_reference_impl(
+        linear.weight, activations, group_size, nbit, has_weight_zeros
+    )
     non_quantized_result = linear(activations)

-if not (torch.allclose(result, expected_result, rtol=rtol, atol=atol)):
-    rand_idxs = torch.randint(0, result.shape[1], (5,))
-    print("rand_idxs: ", rand_idxs)
-    print("kernel_result[rand_idxs]: ", result[0][rand_idxs])
-    print("expected_result[rand_idxs]: ", expected_result[0][rand_idxs])
-    assert False
-else:
-    print("Correctness check passed")
-
-print("kernel_result[0:5]: ", result[0][0:5])
-print("non_quantized_result[0:5]: ", non_quantized_result[0][0:5])
+
+# Check that entries in result match entries in expected_result
+num_mismatch_at_low_tol = 0
+num_total = result.reshape(-1).shape[0]
+for i in range(num_total):
+    actual_val = result.reshape(-1)[i]
+    expected_val = expected_result.reshape(-1)[i]
+    if not torch.allclose(actual_val, expected_val):
+        num_mismatch_at_low_tol += 1
+
+        # If results are not close at a relaxed tolerance, exit with failure
+        if not torch.allclose(actual_val, expected_val, atol=1e-6):
+            assert False, "Correctness check failed"
+
+# Assert at most 5% of entries are not close at a low tolerance
+assert num_mismatch_at_low_tol / num_total <= 0.05, "Correctness check failed"
+print(
+    "Correctness check passed. All results are close, and ",
+    (num_total - num_mismatch_at_low_tol),
+    "/",
+    num_total,
+    " entries are close at a low tolerance.",
+)
+print("Quantization errors:")
+print("\tL1 error: ", torch.mean(torch.abs(result - non_quantized_result)).item())
+print("\tL2 error: ", torch.mean((result - non_quantized_result) ** 2).item())
+print("\tquantized_result[0:5]: ", result[0][0:5])
+print("\tnon_quantized_result[0:5]: ", non_quantized_result[0][0:5])

torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.cpp

Lines changed: 164 additions & 19 deletions
@@ -11,7 +11,7 @@
 #include <torchao/experimental/kernels/cpu/parallel.h>

 template <int weight_nbit>
-at::Tensor pack_weights_cpu(
+at::Tensor pack_weights_without_zeros_cpu(
     const at::Tensor& weight_qvals,
     const at::Tensor& weight_scales,
     // TODO(T200095131): convert to int64_t when supported by AOTI
@@ -54,9 +54,8 @@ at::Tensor pack_weights_cpu(

   auto packed_weight_data_size =
       get_packed_weight_data_size(ukernel_config, n, k, group_size);
-  auto options = torch::TensorOptions().dtype(torch::kInt8);
-
-  at::Tensor packed_weights = torch::empty({packed_weight_data_size}, options);
+  at::Tensor packed_weights =
+      torch::empty({packed_weight_data_size}, torch::kInt8);
   pack_weight_data_operator(
       ukernel_config,
       pack_weight_tiling_params,
7271
}
7372

7473
template <int weight_nbit>
75-
at::Tensor pack_weights_meta(
74+
at::Tensor pack_weights_with_zeros_cpu(
75+
const at::Tensor& weight_qvals,
76+
const at::Tensor& weight_scales,
77+
const at::Tensor& weight_zeros,
78+
// TODO(T200095131): convert to int64_t when supported by AOTI
79+
// group_size is a meta tensor with size (group_size)
80+
const at::Tensor& group_size_tensor) {
81+
int64_t group_size = group_size_tensor.size(0);
82+
83+
TORCH_CHECK(
84+
weight_qvals.dtype() == torch::kInt8, "weight_qvals must be int8");
85+
TORCH_CHECK(weight_qvals.dim() == 2, "weight_qvals must be 2D");
86+
87+
// In PyTorch, weights are nxk in row-major format (with activations being
88+
// right-multiplied).
89+
// In kernel, activations are left-multiplied by kxn transposed
90+
// weights in column-major format.
91+
// Note the underlying data is the same in both cases
92+
int n = weight_qvals.size(0);
93+
int k = weight_qvals.size(1);
94+
95+
TORCH_CHECK(
96+
weight_scales.dtype() == torch::kFloat32,
97+
"weight_scales must be float32");
98+
TORCH_CHECK(weight_scales.dim() == 1, "weight_scales must be 1D");
99+
TORCH_CHECK(
100+
weight_scales.size(0) == ((n * k) / group_size),
101+
"expected 1 scale per group");
102+
TORCH_CHECK(
103+
weight_zeros.dtype() == torch::kInt8, "weight_zeros must be int8");
104+
TORCH_CHECK(weight_zeros.dim() == 1, "weight_zeros must be 1D");
105+
TORCH_CHECK(
106+
weight_zeros.size(0) == ((n * k) / group_size),
107+
"expected 1 zero per group");
108+
109+
using namespace torchao::operators::cpu::linear::
110+
channelwise_8bit_activation_groupwise_lowbit_weight;
111+
112+
auto ukernel_config = get_ukernel_config<
113+
weight_nbit,
114+
true /*has_weight_zeros*/,
115+
false /*has_bias*/,
116+
false /*has_clamp*/>();
117+
auto pack_weight_tiling_params = get_default_pack_weight_data_tiling_params(
118+
ukernel_config, n, /*target_panels_per_thread=*/1);
119+
120+
torchao::set_num_threads(torch::get_num_threads());
121+
122+
auto packed_weight_data_size =
123+
get_packed_weight_data_size(ukernel_config, n, k, group_size);
124+
at::Tensor packed_weights =
125+
torch::empty({packed_weight_data_size}, torch::kInt8);
126+
pack_weight_data_operator(
127+
ukernel_config,
128+
pack_weight_tiling_params,
129+
packed_weights.data_ptr<int8_t>(),
130+
n,
131+
k,
132+
group_size,
133+
weight_qvals.const_data_ptr<int8_t>(),
134+
weight_scales.const_data_ptr<float>(),
135+
weight_zeros.const_data_ptr<int8_t>());
136+
137+
return packed_weights;
138+
}
139+
140+
template <int weight_nbit>
141+
at::Tensor pack_weights_without_zeros_meta(
76142
const at::Tensor& weight_qvals,
77143
const at::Tensor& weight_scales,
78144
// TODO(T200095131): convert to int64_t when supported by AOTI
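
Aside (not part of the commit): the TORCH_CHECKs above pin down the groupwise layout: weight_qvals is an n x k int8 tensor, with one float32 scale and one int8 zero per contiguous group of group_size values. A minimal Python sketch of the dequantization this implies, mirroring the reference code deleted from run_custom_op.py (the shapes here are illustrative only):

import torch

n, k, group_size = 8, 64, 16
nbit = 4  # any of the supported 2/3/4/5 bits
qmin, qmax = -(2 ** (nbit - 1)), 2 ** (nbit - 1) - 1
weight_qvals = torch.randint(qmin, qmax + 1, (n, k), dtype=torch.int8)
weight_scales = torch.rand(n * k // group_size, dtype=torch.float32)  # 1 scale per group
weight_zeros = torch.randint(qmin, qmax + 1, (n * k // group_size,), dtype=torch.int8)  # 1 zero per group

# Dequantize: scale * (qval - zero), groupwise over contiguous runs of group_size values
weights_dequantized = (
    weight_scales.reshape(-1, 1)
    * (weight_qvals.reshape(-1, group_size).float() - weight_zeros.reshape(-1, 1).float())
).reshape(n, k)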
@@ -98,6 +164,33 @@ at::Tensor pack_weights_meta(
 }

 template <int weight_nbit>
+at::Tensor pack_weights_with_zeros_meta(
+    const at::Tensor& weight_qvals,
+    const at::Tensor& weight_scales,
+    const at::Tensor& weight_zeros,
+    // TODO(T200095131): convert to int64_t when supported by AOTI
+    // group_size is a meta tensor with size (group_size)
+    const at::Tensor& group_size_tensor) {
+  int64_t group_size = group_size_tensor.size(0);
+
+  int n = weight_qvals.size(0);
+  int k = weight_qvals.size(1);
+
+  using namespace torchao::operators::cpu::linear::
+      channelwise_8bit_activation_groupwise_lowbit_weight;
+
+  auto ukernel_config = get_ukernel_config<
+      weight_nbit,
+      true /*has_weight_zeros*/,
+      false /*has_bias*/,
+      false /*has_clamp*/>();
+
+  auto packed_weight_data_size =
+      get_packed_weight_data_size(ukernel_config, n, k, group_size);
+  return torch::empty({packed_weight_data_size}).to("meta");
+}
+
+template <int weight_nbit, bool has_weight_zeros>
 at::Tensor linear_cpu(
     const at::Tensor& packed_weights,
     // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
@@ -123,7 +216,7 @@ at::Tensor linear_cpu(

   auto ukernel_config = get_ukernel_config<
       weight_nbit,
-      false /*has_weight_zeros*/,
+      has_weight_zeros /*has_weight_zeros*/,
      false /*has_bias*/,
       false /*has_clamp*/>();
   auto linear_tiling_params = get_default_linear_tiling_params(
@@ -167,7 +260,7 @@ at::Tensor linear_cpu(
   return output_tensor;
 }

-template <int weight_nbit>
+template <int weight_nbit, bool has_weight_zeros>
 at::Tensor linear_meta(
     const at::Tensor& packed_weights,
     // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
@@ -187,26 +280,78 @@ at::Tensor linear_meta(
 }

 TORCH_LIBRARY(torchao, m) {
+  // Pack weights without zeros
+  m.def(
+      "_pack_weights_a8sz_w2s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
+  m.def(
+      "_pack_weights_a8sz_w3s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
+  m.def(
+      "_pack_weights_a8sz_w4s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
+  m.def(
+      "_pack_weights_a8sz_w5s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
+  // Pack weights with zeros
+  m.def(
+      "_pack_weights_a8sz_w2sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor");
+  m.def(
+      "_pack_weights_a8sz_w3sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor");
+  m.def(
+      "_pack_weights_a8sz_w4sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor");
+  m.def(
+      "_pack_weights_a8sz_w5sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor");
+  // Linear weights without zeros
+  m.def(
+      "_linear_a8sz_w2s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
+  m.def(
+      "_linear_a8sz_w3s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
+  m.def(
+      "_linear_a8sz_w4s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
+  m.def(
+      "_linear_a8sz_w5s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
+  // Linear weights with zeros
   m.def(
-      "_pack_weights_3bit(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
+      "_linear_a8sz_w2sz(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
   m.def(
-      "_linear_3bit(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
+      "_linear_a8sz_w3sz(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
   m.def(
-      "_pack_weights_4bit(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
+      "_linear_a8sz_w4sz(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
   m.def(
-      "_linear_4bit(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
+      "_linear_a8sz_w5sz(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
 }

 TORCH_LIBRARY_IMPL(torchao, CPU, m) {
-  m.impl("_pack_weights_3bit", &pack_weights_cpu<3>);
-  m.impl("_linear_3bit", &linear_cpu<3>);
-  m.impl("_pack_weights_4bit", &pack_weights_cpu<4>);
-  m.impl("_linear_4bit", &linear_cpu<4>);
+  m.impl("_pack_weights_a8sz_w2s", &pack_weights_without_zeros_cpu<2>);
+  m.impl("_pack_weights_a8sz_w3s", &pack_weights_without_zeros_cpu<3>);
+  m.impl("_pack_weights_a8sz_w4s", &pack_weights_without_zeros_cpu<4>);
+  m.impl("_pack_weights_a8sz_w5s", &pack_weights_without_zeros_cpu<5>);
+  m.impl("_pack_weights_a8sz_w2sz", &pack_weights_with_zeros_cpu<2>);
+  m.impl("_pack_weights_a8sz_w3sz", &pack_weights_with_zeros_cpu<3>);
+  m.impl("_pack_weights_a8sz_w4sz", &pack_weights_with_zeros_cpu<4>);
+  m.impl("_pack_weights_a8sz_w5sz", &pack_weights_with_zeros_cpu<5>);
+  m.impl("_linear_a8sz_w2s", &linear_cpu<2, false>);
+  m.impl("_linear_a8sz_w3s", &linear_cpu<3, false>);
+  m.impl("_linear_a8sz_w4s", &linear_cpu<4, false>);
+  m.impl("_linear_a8sz_w5s", &linear_cpu<5, false>);
+  m.impl("_linear_a8sz_w2sz", &linear_cpu<2, true>);
+  m.impl("_linear_a8sz_w3sz", &linear_cpu<3, true>);
+  m.impl("_linear_a8sz_w4sz", &linear_cpu<4, true>);
+  m.impl("_linear_a8sz_w5sz", &linear_cpu<5, true>);
 }

 TORCH_LIBRARY_IMPL(torchao, Meta, m) {
-  m.impl("_pack_weights_3bit", &pack_weights_meta<3>);
-  m.impl("_linear_3bit", &linear_meta<3>);
-  m.impl("_pack_weights_4bit", &pack_weights_meta<4>);
-  m.impl("_linear_4bit", &linear_meta<4>);
+  m.impl("_pack_weights_a8sz_w2s", &pack_weights_without_zeros_meta<2>);
+  m.impl("_pack_weights_a8sz_w3s", &pack_weights_without_zeros_meta<3>);
+  m.impl("_pack_weights_a8sz_w4s", &pack_weights_without_zeros_meta<4>);
+  m.impl("_pack_weights_a8sz_w5s", &pack_weights_without_zeros_meta<5>);
+  m.impl("_pack_weights_a8sz_w2sz", &pack_weights_with_zeros_meta<2>);
+  m.impl("_pack_weights_a8sz_w3sz", &pack_weights_with_zeros_meta<3>);
+  m.impl("_pack_weights_a8sz_w4sz", &pack_weights_with_zeros_meta<4>);
+  m.impl("_pack_weights_a8sz_w5sz", &pack_weights_with_zeros_meta<5>);
+  m.impl("_linear_a8sz_w2s", &linear_meta<2, false>);
+  m.impl("_linear_a8sz_w3s", &linear_meta<3, false>);
+  m.impl("_linear_a8sz_w4s", &linear_meta<4, false>);
+  m.impl("_linear_a8sz_w5s", &linear_meta<5, false>);
+  m.impl("_linear_a8sz_w2sz", &linear_meta<2, true>);
+  m.impl("_linear_a8sz_w3sz", &linear_meta<3, true>);
+  m.impl("_linear_a8sz_w4sz", &linear_meta<4, true>);
+  m.impl("_linear_a8sz_w5sz", &linear_meta<5, true>);
 }
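
Aside (not part of the commit): each op name encodes its variant. The a8sz prefix presumably denotes 8-bit activation quantization with scales and zeros (the deleted reference code quantized activations to 8 bits with scales and zeros), and the weight suffix distinguishes w{2,3,4,5}s (scales only) from w{2,3,4,5}sz (scales and zeros), matching the has_weight_zeros template argument. A hypothetical sketch of calling one variant directly from Python, assuming the compiled extension has already been loaded into the process; per the TODO(T200095131) comments, group_size is passed as a tensor whose size encodes the value, and this sketch assumes the same convention for n and k:

import torch

n, k, group_size = 4096, 4096, 256
ngroups = n * k // group_size
weight_qvals = torch.randint(-16, 16, (n, k), dtype=torch.int8)  # 5-bit signed range
weight_scales = torch.rand(ngroups, dtype=torch.float32)
weight_zeros = torch.randint(-16, 16, (ngroups,), dtype=torch.int8)

# Pack, then run the linear op; sizes of the empty tensors carry n, k, group_size
packed = torch.ops.torchao._pack_weights_a8sz_w5sz(
    weight_qvals, weight_scales, weight_zeros, torch.empty(group_size)
)
activations = torch.randn(1, k, dtype=torch.float32)
out = torch.ops.torchao._linear_a8sz_w5sz(
    packed, torch.empty(n), torch.empty(k), torch.empty(group_size), activations
)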
