Skip to content

Commit 8a7e3bf

Browse files
vulkan: initial support for IQ4_XS quantization (ggml-org#11501)
1 parent 1b598b3 commit 8a7e3bf

13 files changed

+169
-13
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 25 additions & 0 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
1212
#endif
1313

1414
void main() {
15-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
15+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
1616
init_iq_shmem(gl_WorkGroupSize);
1717
if (gl_LocalInvocationIndex.x != 0) {
1818
return;

ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ void quantize(uint dst_idx, uint src_idx)
217217
#endif
218218

219219
void main() {
220-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
220+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
221221
init_iq_shmem(gl_WorkGroupSize);
222222
if (gl_LocalInvocationIndex.x != 0) {
223223
return;

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,42 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
304304
}
305305
#endif
306306

307+
#if defined(DATA_A_IQ4_XS)
308+
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
309+
const uint ib32 = iqs / 32;
310+
const uint iq = 16 * ib32 + (iqs % 16);
311+
312+
const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
313+
const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3;
314+
const uint qshift = (iqs & 16) >> 2;
315+
u8vec2 qs = u8vec2(data_a[a_offset + ib].qs[iq], data_a[a_offset + ib].qs[iq + 1]);
316+
qs = (qs >> qshift) & uint8_t(0xF);
317+
318+
const float dl = float(int(sl | (sh << 4)) - 32);
319+
return dl * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
320+
}
321+
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
322+
const uint ib32 = iqs / 32;
323+
const uint iq = 16 * ib32 + (iqs % 16);
324+
325+
const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
326+
const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3;
327+
const uint qshift = (iqs & 16) >> 2;
328+
u8vec4 qs = u8vec4(
329+
data_a[a_offset + ib].qs[iq + 0],
330+
data_a[a_offset + ib].qs[iq + 1],
331+
data_a[a_offset + ib].qs[iq + 2],
332+
data_a[a_offset + ib].qs[iq + 3]
333+
);
334+
qs = (qs >> qshift) & uint8_t(0xF);
335+
336+
const float dl = float(int(sl | (sh << 4)) - 32);
337+
return dl * vec4(
338+
kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y],
339+
kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]);
340+
}
341+
#endif
342+
307343
#if defined(DATA_A_IQ4_NL)
308344
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
309345
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
@@ -321,7 +357,7 @@ vec2 get_dm(uint ib, uint a_offset) {
321357
}
322358
#endif
323359

324-
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
360+
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
325361
vec2 get_dm(uint ib, uint a_offset) {
326362
return vec2(float(data_a[a_offset + ib].d), 0);
327363
}

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,27 @@ float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords
454454
}
455455
#endif
456456

457+
#if defined(DATA_A_IQ4_XS)
458+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_XS {
459+
block_iq4_xs block;
460+
};
461+
462+
float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
463+
{
464+
const float16_t d = bl.block.d;
465+
const uint idx = coordInBlock[1];
466+
467+
const uint ib32 = (idx & 0xE0) >> 5; // 0..7
468+
469+
const uint sl = (bl.block.scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
470+
const uint sh = ((bl.block.scales_h) >> (2 * ib32)) & 3;
471+
const uint qshift = (idx & 16) >> 2;
472+
const uint q = (bl.block.qs[16 * ib32 + (idx % 16)] >> qshift) & 0xF;
473+
474+
float16_t ret = d * float16_t(int(sl | (sh << 4)) - 32) * float16_t(kvalues_iq4nl[q]);
475+
return ret;
476+
}
477+
#endif
457478

458479
#if defined(DATA_A_IQ4_NL)
459480
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL {
@@ -504,6 +525,8 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
504525
#define dequantFuncA dequantFuncIQ3_XXS
505526
#elif defined(DATA_A_IQ3_S)
506527
#define dequantFuncA dequantFuncIQ3_S
528+
#elif defined(DATA_A_IQ4_XS)
529+
#define dequantFuncA dequantFuncIQ4_XS
507530
#elif defined(DATA_A_IQ4_NL)
508531
#define dequantFuncA dequantFuncIQ4_NL
509532
#endif
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#version 450
2+
3+
#include "dequant_head.comp"
4+
5+
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
6+
7+
layout (binding = 0) readonly buffer A {block_iq4_xs data_a[];};
8+
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
9+
10+
void main() {
11+
// Each thread handles 1 subblock (1 scale and 32 quantized values)
12+
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
13+
14+
init_iq_shmem(gl_WorkGroupSize);
15+
16+
if (ib >= p.nel / 256) {
17+
return;
18+
}
19+
20+
const uint ib32 = gl_LocalInvocationID.x % 8;
21+
22+
const float d = float(data_a[ib].d);
23+
// Scales are 6 bits
24+
const uint scale = ((data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF)
25+
| (((data_a[ib].scales_h >> (2 * ib32)) & 3) << 4);
26+
const float dl = d * (int(scale) - 32);
27+
28+
const uint b_idx = 256 * ib + 32 * ib32;
29+
const uint q_idx = 16 * ib32;
30+
[[unroll]] for (uint l = 0; l < 16; ++l) {
31+
data_b[b_idx + l + 0] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]);
32+
data_b[b_idx + l + 16] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]);
33+
}
34+
}

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
104104
#endif
105105

106106
void main() {
107-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
107+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
108108
init_iq_shmem(gl_WorkGroupSize);
109109
#endif
110110

ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ void main() {
1212
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
1313
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
1414

15-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
15+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
1616
init_iq_shmem(gl_WorkGroupSize);
1717
#endif
1818

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
133133
void main() {
134134
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
135135

136-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
136+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
137137
init_iq_shmem(gl_WorkGroupSize);
138138
#endif
139139

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
9595
#endif
9696

9797
void main() {
98-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
98+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
9999
init_iq_shmem(gl_WorkGroupSize);
100100
#endif
101101

@@ -547,6 +547,25 @@ void main() {
547547
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
548548
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
549549

550+
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
551+
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
552+
#elif defined(DATA_A_IQ4_XS)
553+
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
554+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
555+
556+
const uint ib = idx / 128; // 2 values per idx
557+
const uint ib32 = (idx % 128) / 16; // 0..7
558+
const uint iq = 16 * ib32 + 2 * (idx % 8);
559+
560+
const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
561+
const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3;
562+
const uint qshift = (idx & 8) >> 1;
563+
u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]);
564+
qs = (qs >> qshift) & uint8_t(0xF);
565+
566+
const float d = float(data_a[ib].d);
567+
const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
568+
550569
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
551570
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
552571
#elif defined(DATA_A_IQ4_NL)

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
106106
#endif
107107

108108
void main() {
109-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
109+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
110110
init_iq_shmem(gl_WorkGroupSize);
111111
#endif
112112

ggml/src/ggml-vulkan/vulkan-shaders/types.comp

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,23 @@ void init_iq_shmem(uvec3 wgsize)
10261026
#define A_TYPE_PACKED16 block_iq3_s_packed16
10271027
#endif
10281028

1029+
#define QUANT_K_IQ4_XS 256
1030+
#define QUANT_R_IQ4_XS 1
1031+
1032+
struct block_iq4_xs
1033+
{
1034+
float16_t d;
1035+
uint16_t scales_h;
1036+
uint8_t scales_l[QUANT_K_IQ4_XS/64];
1037+
uint8_t qs[QUANT_K_IQ4_XS/2];
1038+
};
1039+
1040+
#if defined(DATA_A_IQ4_XS)
1041+
#define QUANT_K QUANT_K_IQ4_XS
1042+
#define QUANT_R QUANT_R_IQ4_XS
1043+
#define A_TYPE block_iq4_xs
1044+
#endif
1045+
10291046
#define QUANT_K_IQ4_NL 32
10301047
#define QUANT_R_IQ4_NL 2
10311048

@@ -1042,7 +1059,13 @@ struct block_iq4_nl_packed16
10421059
};
10431060

10441061
#if defined(DATA_A_IQ4_NL)
1062+
#define QUANT_K QUANT_K_IQ4_NL
1063+
#define QUANT_R QUANT_R_IQ4_NL
1064+
#define A_TYPE block_iq4_nl
1065+
#define A_TYPE_PACKED16 block_iq4_nl_packed16
1066+
#endif
10451067

1068+
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
10461069
const int8_t kvalues_iq4nl_const[16] = {
10471070
int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
10481071
int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113)
@@ -1058,11 +1081,6 @@ void init_iq_shmem(uvec3 wgsize)
10581081
}
10591082
barrier();
10601083
}
1061-
1062-
#define QUANT_K QUANT_K_IQ4_NL
1063-
#define QUANT_R QUANT_R_IQ4_NL
1064-
#define A_TYPE block_iq4_nl
1065-
#define A_TYPE_PACKED16 block_iq4_nl_packed16
10661084
#endif
10671085

10681086
#endif // !defined(GGML_TYPES_COMP)

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ const std::vector<std::string> type_names = {
6060
"iq2_s",
6161
"iq3_xxs",
6262
"iq3_s",
63+
"iq4_xs",
6364
"iq4_nl"
6465
};
6566

0 commit comments

Comments
 (0)