
Commit 16d9758

Author: morelos
[ET-VK][Ops] linear_qta8a_qga4w_qta8o impl and shaders
Pull Request resolved: #12006

# Operator Description

The linear_qta8a_qga4w_qta8o operator implements a quantized linear transformation that enables efficient neural network inference through dynamic quantization. The operator performs matrix multiplication between quantized 8-bit activations and 4-bit group-quantized weights, producing quantized 8-bit outputs.

The quantization scheme follows the standard affine mapping `real_value = scale * (quantized_value - zero_point)`. Input activations use 8-bit signed integers with per-token scale and zero-point parameters, while weights use 4-bit quantization with group-wise parameters.

# Implementation Architecture

The operator provides two distinct computational approaches optimized for different matrix multiplication scenarios: the TILED algorithm for general matrix-matrix multiplication (GEMM) and the COOPERATIVE algorithm for matrix-vector multiplication (GEMV).

## TILED Algorithm (GEMM Cases)

The tiled implementation processes the output matrix in rectangular blocks. Each thread computes a tile of output values, typically covering 3 rows and 2 columns of results per iteration. Each thread loads blocks of quantized weights and activations, accumulates in integer arithmetic, and then applies the necessary scaling operations.

Weight data is pre-packed in a specialized format where two 4-bit values are stored in each byte. Each thread loads multiple weight elements at once and unpacks them during computation. The quantization parameters for weights are organized by groups, where each group of consecutive weight elements shares the same scale and zero-point values.

## COOPERATIVE Algorithm (GEMV Cases)

The cooperative implementation relies on shared memory and thread cooperation: workgroups of 64 threads are arranged as 8 groups of 8 workers each. The key insight is that GEMV operations have limited parallelism in the output dimension but substantial parallelism in the reduction dimension, making cooperative reduction more effective than independent per-thread computation. Each group of 8 worker threads collaboratively computes a portion of the output vector; the workers divide the reduction work along the input feature dimension, with each worker processing every 8th element in a strided pattern.

# Future Performance Improvements

- Making use of dotPacked4x8EXT (this requires upgrading glslc and Vulkan)
- Fixed-point math for pure integer operations
- It might be more performant to avoid preloading tensors
- It might also be more performant to avoid excessive register overhead by defining the ivec4 accumulators within each block operation (allowing more threads to run even when they are register intensive)

ghstack-source-id: 292886840
Differential Revision: [D77173441](https://our.internmc.facebook.com/intern/diff/D77173441/)
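To make the end-to-end semantics concrete, below is a small NumPy sketch of the dequantize → matmul → requantize flow described above. The function name, argument shapes, nibble layout, and the group-wise weight dequantization form (`scale * (nibble - 8) + zero`, inferred from the cooperative shader in this diff) are illustrative assumptions, not the operator's authoritative definition.

```python
import numpy as np

def unpack_4bit(packed):
    """Unpack two 4-bit weights per byte into centered int values in [-8, 7]."""
    K, half = packed.shape
    hi = (packed.astype(np.int32) >> 4) - 8
    lo = (packed.astype(np.int32) & 0x0F) - 8
    # Layout inferred from the shader below: within each block of 8 output
    # columns, the four high nibbles hold columns 0..3, the low nibbles 4..7.
    hi = hi.reshape(K, half // 4, 4)
    lo = lo.reshape(K, half // 4, 4)
    return np.concatenate([hi, lo], axis=-1).reshape(K, 2 * half)

def linear_qta8a_qga4w_qta8o_ref(
    x_q, x_scale, x_zp,        # int8 [M, K]; per-token float [M] and int [M]
    w_packed, w_scale, w_zero, # uint8 [K, N//2]; per-group float [K//group_size, N]
    out_scale, out_zp,         # per-token float [M] and int [M]
    group_size=64,
):
    w_q = unpack_4bit(w_packed)                                   # [K, N]
    # Dequantize activations per token: real = scale * (q - zero_point).
    x_real = x_scale[:, None] * (x_q.astype(np.int32) - x_zp[:, None])
    # Dequantize weights per group: real = scale * (nibble - 8) + zero (assumed).
    g = np.repeat(np.arange(w_q.shape[0] // group_size), group_size)
    w_real = w_scale[g] * w_q + w_zero[g]                         # [K, N]
    acc = x_real @ w_real                                         # float [M, N]
    # Requantize the accumulated result per token back to int8.
    q = np.round(acc / out_scale[:, None]) + out_zp[:, None]
    return np.clip(q, -128, 127).astype(np.int8)
```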
1 parent ef5e756 commit 16d9758

File tree

6 files changed (+1125, -0 lines changed)

Lines changed: 241 additions & 0 deletions
@@ -0,0 +1,241 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define T ${buffer_scalar_type(DTYPE)}
#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}

#define TILE_ROWS ${TILE_ROWS}

#define NGROUPS 8
#define NWORKERS 8

${define_required_extensions(DTYPE)}
$if WEIGHT_STORAGE == "buffer":
  ${define_required_extensions("uint8")}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_qparams", "float", PARAMS_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_input_scale", "float", "buffer", is_scalar_array=True)}
${layout_declare_tensor(B, "r", "t_input_zero_point", "int", "buffer", is_scalar_array=True)}
${layout_declare_tensor(B, "r", "t_output_scale", "float", "buffer", is_scalar_array=True)}
${layout_declare_tensor(B, "r", "t_output_zero_point", "int", "buffer", is_scalar_array=True)}

layout(push_constant) uniform restrict Block {
  ivec4 out_sizes;
  ivec4 mat1_sizes;
  ivec4 qmat2_sizes;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int group_size = 64;

shared vec4 partial_results[NGROUPS][NWORKERS][TILE_ROWS][2];

/*
 * This shader computes a linear operator between a quantized int8 input matrix
 * x and a weights matrix that is quantized to 4 bits, producing a quantized int8 output.
 *
 * This shader implements a co-operative algorithm to compute the output. The
 * work group size is {NGROUPS, 1, NWORKERS}, and each group of NWORKERS threads
 * cooperates to compute TILE_ROWS * 2 output texels. Therefore,
 * NGROUPS * TILE_ROWS * 2 output texels are computed across one work group.
 *
 * The threads co-operate by each thread computing a partial reduction along the
 * K dimension. To illustrate the computation, consider a scalar variant of the
 * algorithm that computes the dot product of 2 vectors. Also assume that
 * NWORKERS is 8.
 *
 * Thread 1 in each group will compute:
 * (mat1[0] * mat2[0]) + (mat1[8] * mat2[8]) + (mat1[16] * mat2[16]) + ...
 *
 * Thread 2 in each group will compute:
 * (mat1[1] * mat2[1]) + (mat1[9] * mat2[9]) + (mat1[17] * mat2[17]) + ...
 *
 * Thread 3 in each group will compute:
 * (mat1[2] * mat2[2]) + (mat1[10] * mat2[10]) + (mat1[18] * mat2[18]) + ...
 *
 * The partial accumulation is structured such that memory accesses in each
 * loop iteration can be coalesced.
 *
 * Then, at the end, the first thread in each group accumulates the partial
 * sums computed by each thread to obtain the final result.
 *
 * Note that this shader assumes that all tensors are width packed.
 */

void main() {
  const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
  const uint out_col = gl_GlobalInvocationID.x << 3;
  const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1;

  const uint gid = gl_LocalInvocationID.x; // group id
  const uint wid = gl_LocalInvocationID.z; // worker id

  if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
    return;
  }

  const int num_blocks = mat1_sizes.x / group_size;

  VEC4_T mat1_quantized[TILE_ROWS];
  ivec4 qmat2_quantized[4][2];
  vec4 final_result[TILE_ROWS][2];

  // Initialize accumulators
  [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
    final_result[r][0] = vec4(0.0);
    final_result[r][1] = vec4(0.0);
  }

  vec4 scales[2];
  vec4 zeros[2];

  $if WEIGHT_STORAGE == "buffer":
    const int qmat2_stride = qmat2_sizes.x >> 2;
  $if PARAMS_STORAGE == "buffer":
    const int qparams_y_stride = out_sizes.x >> 2;
    const int qparams_z_stride = qparams_y_stride * 2;

  for (int block_idx = 0; block_idx < num_blocks; ++block_idx) {
    $if PARAMS_STORAGE == "buffer":
      scales[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx];
      zeros[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + qparams_y_stride];

      scales[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1];
      zeros[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1 + qparams_y_stride];
    $else:
      scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0);
      zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0);

      scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0);
      zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0);

    ivec4 int32_sums[TILE_ROWS][2];
    int input_sums[TILE_ROWS];

    // Initialize accumulators for this block
    [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
      int32_sums[r][0] = ivec4(0);
      int32_sums[r][1] = ivec4(0);
      input_sums[r] = 0;
    }

    for (int g_idx = 4 * int(wid); g_idx < group_size; g_idx += (4 * NWORKERS)) {
      const int k = block_idx * group_size + g_idx;

      // Preload B (weights) - keep as quantized integers
      [[unroll]] for (int r = 0; r < 4; ++r) {
        $if WEIGHT_STORAGE == "buffer":
          const uvec4 packed_weight_tex = t_qmat2[(k + r) * qmat2_stride + gl_GlobalInvocationID.x];
        $else:
          const uvec4 packed_weight_tex = texelFetch(
              t_qmat2,
              ivec2(gl_GlobalInvocationID.x, k + r),
              0);

        // Unpack 4-bit weights to integers and subtract zero point (8 for 4-bit)
        qmat2_quantized[r][0] = ivec4((packed_weight_tex & 0xF0) >> 4) - ivec4(8);
        qmat2_quantized[r][1] = ivec4(packed_weight_tex & 0x0F) - ivec4(8);
      }

      // Preload A (quantized input) - keep as quantized integers
      [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
        $if IN_STORAGE == "buffer":
          mat1_quantized[r] = VEC4_T(t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2] - t_input_zero_point[int(out_row) + r]);
        $else:
          mat1_quantized[r] = VEC4_T(texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0) - t_input_zero_point[int(out_row) + r]);

        input_sums[r] += mat1_quantized[r].x + mat1_quantized[r].y + mat1_quantized[r].z + mat1_quantized[r].w;
      }

      // Accumulate in integer arithmetic: (input_quantized - input_zero_point) * (weight_quantized - weight_zero_point)
      [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
        int32_sums[r][0] += mat1_quantized[r].x * qmat2_quantized[0][0]
                          + mat1_quantized[r].y * qmat2_quantized[1][0]
                          + mat1_quantized[r].z * qmat2_quantized[2][0]
                          + mat1_quantized[r].w * qmat2_quantized[3][0];

        int32_sums[r][1] += mat1_quantized[r].x * qmat2_quantized[0][1]
                          + mat1_quantized[r].y * qmat2_quantized[1][1]
                          + mat1_quantized[r].z * qmat2_quantized[2][1]
                          + mat1_quantized[r].w * qmat2_quantized[3][1];
      }
    }

    // Incorporates this block's results into the final accumulation
    [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
      if (out_row + r >= out_sizes.y) {
        continue;
      }

      float input_scale = t_input_scale[int(out_row) + r];
      float input_sum_scalar = float(input_sums[r]);

      final_result[r][0] += input_scale * (vec4(int32_sums[r][0]) * scales[0] + input_sum_scalar * zeros[0]);
      final_result[r][1] += input_scale * (vec4(int32_sums[r][1]) * scales[1] + input_sum_scalar * zeros[1]);
    }
  }

  // Store worker results in shared memory
  [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
    partial_results[gid][wid][r][0] = final_result[r][0];
    partial_results[gid][wid][r][1] = final_result[r][1];
  }

  memoryBarrierShared();
  barrier();

  // Only the first worker in each group accumulates and writes output
  if (wid != 0) {
    return;
  }

  vec4 cooperative_result[TILE_ROWS][2];

  for (int r = 0; r < TILE_ROWS; ++r) {
    cooperative_result[r][0] = vec4(0.0);
    cooperative_result[r][1] = vec4(0.0);
    [[unroll]] for (int worker = 0; worker < NWORKERS; ++worker) {
      cooperative_result[r][0] += partial_results[gid][worker][r][0];
      cooperative_result[r][1] += partial_results[gid][worker][r][1];
    }
  }

  // Apply final output quantization
  [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
    int token_idx = int(out_row) + r;

    float output_scale = t_output_scale[token_idx];
    int output_zero_point = t_output_zero_point[token_idx];

    VEC4_T quantized_out_0 = VEC4_T(clamp(
        ivec4(round(cooperative_result[r][0] / output_scale)) + float(output_zero_point),
        -128, 127));
    VEC4_T quantized_out_1 = VEC4_T(clamp(
        ivec4(round(cooperative_result[r][1] / output_scale)) + float(output_zero_point),
        -128, 127));

    $if OUT_STORAGE == "buffer":
      t_out[((out_row + r) * out_sizes.x + out_col) >> 2] = quantized_out_0;
      t_out[((out_row + r) * out_sizes.x + out_col + 4) >> 2] = quantized_out_1;
    $else:
      imageStore(t_out, ivec3(out_col_texel_idx, out_row + r, 0), quantized_out_0);
      imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row + r, 0), quantized_out_1);
  }
}
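As a sanity check on the per-block update above, `final_result[r] += input_scale * (vec4(int32_sums[r]) * scales + input_sum_scalar * zeros)` can be read as a factored form of the dequantized dot product, assuming (as the shader body suggests) activation dequantization a = s_a * (a_q - z_a) and group-wise weight dequantization w = s_w * (w_q - 8) + z_w:

```math
\sum_{k} s_a (a_{q,k} - z_a)\,\bigl(s_w (w_{q,k} - 8) + z_w\bigr)
  = s_a \Bigl( s_w \underbrace{\sum_{k} (a_{q,k} - z_a)(w_{q,k} - 8)}_{\mathtt{int32\_sums}}
  + z_w \underbrace{\sum_{k} (a_{q,k} - z_a)}_{\mathtt{input\_sums}} \Bigr)
```

This is why only the two integer accumulators, int32_sums and input_sums, need to be carried through the inner K loop, with the floating-point scales applied once per group.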
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

linear_qta8a_qga4w_qta8o_coop:
  parameter_names_with_default_values:
    DTYPE: int8
    OUT_STORAGE: texture3d
    IN_STORAGE: texture3d
    WEIGHT_STORAGE: texture2d
    PARAMS_STORAGE: buffer
    TILE_ROWS: 1
  shader_variants:
    - NAME: linear_qta8a_qga4w_qta8o_coop_texture3d_texture3d_texture2d_int8
    - NAME: linear_qta8a_qga4w_qta8o_coop_buffer_buffer_texture2d_int8
      OUT_STORAGE: buffer
      IN_STORAGE: buffer
    - NAME: linear_qta8a_qga4w_qta8o_coop_buffer_buffer_buffer_int8
      OUT_STORAGE: buffer
      IN_STORAGE: buffer
      WEIGHT_STORAGE: buffer
    - NAME: linear_qta8a_qga4w_qta8o_coop_buffer_texture2d_buffer_int8
      OUT_STORAGE: buffer
      WEIGHT_STORAGE: buffer
