11
11
#include < torchao/experimental/kernels/cpu/parallel.h>
12
12
13
13
template <int weight_nbit>
14
- at::Tensor pack_weights_cpu (
14
+ at::Tensor pack_weights_without_zeros_cpu (
15
15
const at::Tensor& weight_qvals,
16
16
const at::Tensor& weight_scales,
17
17
// TODO(T200095131): convert to int64_t when supported by AOTI
@@ -54,9 +54,8 @@ at::Tensor pack_weights_cpu(
54
54
55
55
auto packed_weight_data_size =
56
56
get_packed_weight_data_size (ukernel_config, n, k, group_size);
57
- auto options = torch::TensorOptions ().dtype (torch::kInt8 );
58
-
59
- at::Tensor packed_weights = torch::empty ({packed_weight_data_size}, options);
57
+ at::Tensor packed_weights =
58
+ torch::empty ({packed_weight_data_size}, torch::kInt8 );
60
59
pack_weight_data_operator (
61
60
ukernel_config,
62
61
pack_weight_tiling_params,
@@ -72,7 +71,74 @@ at::Tensor pack_weights_cpu(
72
71
}
73
72
74
73
template <int weight_nbit>
75
- at::Tensor pack_weights_meta (
74
+ at::Tensor pack_weights_with_zeros_cpu (
75
+ const at::Tensor& weight_qvals,
76
+ const at::Tensor& weight_scales,
77
+ const at::Tensor& weight_zeros,
78
+ // TODO(T200095131): convert to int64_t when supported by AOTI
79
+ // group_size is a meta tensor with size (group_size)
80
+ const at::Tensor& group_size_tensor) {
81
+ int64_t group_size = group_size_tensor.size (0 );
82
+
83
+ TORCH_CHECK (
84
+ weight_qvals.dtype () == torch::kInt8 , " weight_qvals must be int8" );
85
+ TORCH_CHECK (weight_qvals.dim () == 2 , " weight_qvals must be 2D" );
86
+
87
+ // In PyTorch, weights are nxk in row-major format (with activations being
88
+ // right-multiplied).
89
+ // In kernel, activations are left-multiplied by kxn transposed
90
+ // weights in column-major format.
91
+ // Note the underlying data is the same in both cases
92
+ int n = weight_qvals.size (0 );
93
+ int k = weight_qvals.size (1 );
94
+
95
+ TORCH_CHECK (
96
+ weight_scales.dtype () == torch::kFloat32 ,
97
+ " weight_scales must be float32" );
98
+ TORCH_CHECK (weight_scales.dim () == 1 , " weight_scales must be 1D" );
99
+ TORCH_CHECK (
100
+ weight_scales.size (0 ) == ((n * k) / group_size),
101
+ " expected 1 scale per group" );
102
+ TORCH_CHECK (
103
+ weight_zeros.dtype () == torch::kInt8 , " weight_zeros must be int8" );
104
+ TORCH_CHECK (weight_zeros.dim () == 1 , " weight_zeros must be 1D" );
105
+ TORCH_CHECK (
106
+ weight_zeros.size (0 ) == ((n * k) / group_size),
107
+ " expected 1 zero per group" );
108
+
109
+ using namespace torchao ::operators::cpu::linear::
110
+ channelwise_8bit_activation_groupwise_lowbit_weight;
111
+
112
+ auto ukernel_config = get_ukernel_config<
113
+ weight_nbit,
114
+ true /* has_weight_zeros*/ ,
115
+ false /* has_bias*/ ,
116
+ false /* has_clamp*/ >();
117
+ auto pack_weight_tiling_params = get_default_pack_weight_data_tiling_params (
118
+ ukernel_config, n, /* target_panels_per_thread=*/ 1 );
119
+
120
+ torchao::set_num_threads (torch::get_num_threads ());
121
+
122
+ auto packed_weight_data_size =
123
+ get_packed_weight_data_size (ukernel_config, n, k, group_size);
124
+ at::Tensor packed_weights =
125
+ torch::empty ({packed_weight_data_size}, torch::kInt8 );
126
+ pack_weight_data_operator (
127
+ ukernel_config,
128
+ pack_weight_tiling_params,
129
+ packed_weights.data_ptr <int8_t >(),
130
+ n,
131
+ k,
132
+ group_size,
133
+ weight_qvals.const_data_ptr <int8_t >(),
134
+ weight_scales.const_data_ptr <float >(),
135
+ weight_zeros.const_data_ptr <int8_t >());
136
+
137
+ return packed_weights;
138
+ }
139
+
140
+ template <int weight_nbit>
141
+ at::Tensor pack_weights_without_zeros_meta (
76
142
const at::Tensor& weight_qvals,
77
143
const at::Tensor& weight_scales,
78
144
// TODO(T200095131): convert to int64_t when supported by AOTI
@@ -98,6 +164,33 @@ at::Tensor pack_weights_meta(
98
164
}
99
165
100
166
template <int weight_nbit>
167
+ at::Tensor pack_weights_with_zeros_meta (
168
+ const at::Tensor& weight_qvals,
169
+ const at::Tensor& weight_scales,
170
+ const at::Tensor& weight_zeros,
171
+ // TODO(T200095131): convert to int64_t when supported by AOTI
172
+ // group_size is a meta tensor with size (group_size)
173
+ const at::Tensor& group_size_tensor) {
174
+ int64_t group_size = group_size_tensor.size (0 );
175
+
176
+ int n = weight_qvals.size (0 );
177
+ int k = weight_qvals.size (1 );
178
+
179
+ using namespace torchao ::operators::cpu::linear::
180
+ channelwise_8bit_activation_groupwise_lowbit_weight;
181
+
182
+ auto ukernel_config = get_ukernel_config<
183
+ weight_nbit,
184
+ true /* has_weight_zeros*/ ,
185
+ false /* has_bias*/ ,
186
+ false /* has_clamp*/ >();
187
+
188
+ auto packed_weight_data_size =
189
+ get_packed_weight_data_size (ukernel_config, n, k, group_size);
190
+ return torch::empty ({packed_weight_data_size}).to (" meta" );
191
+ }
192
+
193
+ template <int weight_nbit, bool has_weight_zeros>
101
194
at::Tensor linear_cpu (
102
195
const at::Tensor& packed_weights,
103
196
// TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
@@ -123,7 +216,7 @@ at::Tensor linear_cpu(
123
216
124
217
auto ukernel_config = get_ukernel_config<
125
218
weight_nbit,
126
- false /* has_weight_zeros*/ ,
219
+ has_weight_zeros /* has_weight_zeros*/ ,
127
220
false /* has_bias*/ ,
128
221
false /* has_clamp*/ >();
129
222
auto linear_tiling_params = get_default_linear_tiling_params (
@@ -167,7 +260,7 @@ at::Tensor linear_cpu(
167
260
return output_tensor;
168
261
}
169
262
170
- template <int weight_nbit>
263
+ template <int weight_nbit, bool has_weight_zeros >
171
264
at::Tensor linear_meta (
172
265
const at::Tensor& packed_weights,
173
266
// TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
@@ -187,26 +280,78 @@ at::Tensor linear_meta(
187
280
}
188
281
189
282
// Operator schema registration for the torchao namespace.
//
// Naming convention: _<op>_a8sz_w<NBIT><s|sz>
//   a8sz — activations are 8-bit with scale and zero point
//   wNs  — N-bit weights with scales only
//   wNsz — N-bit weights with scales and zero points
TORCH_LIBRARY(torchao, m) {
  // Pack weights without zeros
  m.def(
      "_pack_weights_a8sz_w2s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
  m.def(
      "_pack_weights_a8sz_w3s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
  m.def(
      "_pack_weights_a8sz_w4s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
  m.def(
      "_pack_weights_a8sz_w5s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
  // Pack weights with zeros
  m.def(
      "_pack_weights_a8sz_w2sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor");
  m.def(
      "_pack_weights_a8sz_w3sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor");
  m.def(
      "_pack_weights_a8sz_w4sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor");
  m.def(
      "_pack_weights_a8sz_w5sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor");
  // Linear weights without zeros
  m.def(
      "_linear_a8sz_w2s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
  m.def(
      "_linear_a8sz_w3s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
  m.def(
      "_linear_a8sz_w4s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
  m.def(
      "_linear_a8sz_w5s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
  // Linear weights with zeros
  m.def(
      "_linear_a8sz_w2sz(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
  m.def(
      "_linear_a8sz_w3sz(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
  m.def(
      "_linear_a8sz_w4sz(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
  m.def(
      "_linear_a8sz_w5sz(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
}
199
320
200
321
// CPU kernel bindings. The second template argument of linear_cpu selects
// the has_weight_zeros kernel variant and must agree with how the weights
// were packed (wNs -> false, wNsz -> true).
TORCH_LIBRARY_IMPL(torchao, CPU, m) {
  m.impl("_pack_weights_a8sz_w2s", &pack_weights_without_zeros_cpu<2>);
  m.impl("_pack_weights_a8sz_w3s", &pack_weights_without_zeros_cpu<3>);
  m.impl("_pack_weights_a8sz_w4s", &pack_weights_without_zeros_cpu<4>);
  m.impl("_pack_weights_a8sz_w5s", &pack_weights_without_zeros_cpu<5>);
  m.impl("_pack_weights_a8sz_w2sz", &pack_weights_with_zeros_cpu<2>);
  m.impl("_pack_weights_a8sz_w3sz", &pack_weights_with_zeros_cpu<3>);
  m.impl("_pack_weights_a8sz_w4sz", &pack_weights_with_zeros_cpu<4>);
  m.impl("_pack_weights_a8sz_w5sz", &pack_weights_with_zeros_cpu<5>);
  m.impl("_linear_a8sz_w2s", &linear_cpu<2, false>);
  m.impl("_linear_a8sz_w3s", &linear_cpu<3, false>);
  m.impl("_linear_a8sz_w4s", &linear_cpu<4, false>);
  m.impl("_linear_a8sz_w5s", &linear_cpu<5, false>);
  m.impl("_linear_a8sz_w2sz", &linear_cpu<2, true>);
  m.impl("_linear_a8sz_w3sz", &linear_cpu<3, true>);
  m.impl("_linear_a8sz_w4sz", &linear_cpu<4, true>);
  m.impl("_linear_a8sz_w5sz", &linear_cpu<5, true>);
}
206
339
207
340
// Meta (shape-inference) kernel bindings; each entry mirrors its CPU
// registration above so fake-tensor tracing sees the same op set.
TORCH_LIBRARY_IMPL(torchao, Meta, m) {
  m.impl("_pack_weights_a8sz_w2s", &pack_weights_without_zeros_meta<2>);
  m.impl("_pack_weights_a8sz_w3s", &pack_weights_without_zeros_meta<3>);
  m.impl("_pack_weights_a8sz_w4s", &pack_weights_without_zeros_meta<4>);
  m.impl("_pack_weights_a8sz_w5s", &pack_weights_without_zeros_meta<5>);
  m.impl("_pack_weights_a8sz_w2sz", &pack_weights_with_zeros_meta<2>);
  m.impl("_pack_weights_a8sz_w3sz", &pack_weights_with_zeros_meta<3>);
  m.impl("_pack_weights_a8sz_w4sz", &pack_weights_with_zeros_meta<4>);
  m.impl("_pack_weights_a8sz_w5sz", &pack_weights_with_zeros_meta<5>);
  m.impl("_linear_a8sz_w2s", &linear_meta<2, false>);
  m.impl("_linear_a8sz_w3s", &linear_meta<3, false>);
  m.impl("_linear_a8sz_w4s", &linear_meta<4, false>);
  m.impl("_linear_a8sz_w5s", &linear_meta<5, false>);
  m.impl("_linear_a8sz_w2sz", &linear_meta<2, true>);
  m.impl("_linear_a8sz_w3sz", &linear_meta<3, true>);
  m.impl("_linear_a8sz_w4sz", &linear_meta<4, true>);
  m.impl("_linear_a8sz_w5sz", &linear_meta<5, true>);
}
0 commit comments