Skip to content

Commit 2c6c8df

Browse files
authored
vulkan: optimize coopmat2 iq2/iq3 callbacks (ggml-org#11521)
* vulkan: optimize coopmat2 iq2/iq3 callbacks * build: trigger CI on GLSL compute shader changes
1 parent 8a7e3bf commit 2c6c8df

File tree

2 files changed

+40
-43
lines changed

2 files changed

+40
-43
lines changed

.github/workflows/build.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ on:
1010
push:
1111
branches:
1212
- master
13-
paths: ['.github/workflows/build.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m', '**/*.metal']
13+
paths: ['.github/workflows/build.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m', '**/*.metal', '**/*.comp']
1414
pull_request:
1515
types: [opened, synchronize, reopened]
16-
paths: ['.github/workflows/build.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m', '**/*.metal']
16+
paths: ['.github/workflows/build.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m', '**/*.metal', '**/*.comp']
1717

1818
concurrency:
1919
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp

Lines changed: 38 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -323,15 +323,16 @@ float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCo
323323
const uint8_t qs = bl.block.qs[iqs];
324324
const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3]));
325325

326-
const float16_t dscale = bl.block.d * 0.25hf * (0.5hf + float16_t(signscale >> 28));
326+
const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28));
327327
uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7);
328328
sign |= bitCount(sign) << 7;
329329

330-
const uint8_t g = unpack8(iq2xxs_grid[qs][(idx & 4) >> 2])[idx & 3];
330+
uint g2 = iq2xxs_grid[qs][(idx & 4) >> 2];
331+
g2 >>= (idx & 2) * 8;
332+
const vec2 g = vec2(unpack8(g2));
331333

332-
float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
333-
334-
return ret;
334+
vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
335+
return float16_t(ret[idx & 1]);
335336
}
336337
#endif
337338

@@ -350,14 +351,16 @@ float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoor
350351
const uint iqs = (idx & 0xF8) >> 3; // 0..63
351352

352353
const uint16_t qs = bl.block.qs[iqs];
353-
const float16_t dscale = bl.block.d * 0.25hf * (0.5hf + float16_t((bl.block.scales[is] >> sshift) & 0xF));
354+
const float dscale = float(bl.block.d) * 0.25 * (0.5 + float((bl.block.scales[is] >> sshift) & 0xF));
354355

355356
uint sign = uint(qs >> 9);
356357
sign |= bitCount(sign) << 7;
357-
const uint8_t g = unpack8(iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2])[idx & 3];
358+
uint g2 = iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2];
359+
g2 >>= (idx & 2) * 8;
360+
const vec2 g = vec2(unpack8(g2));
358361

359-
float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
360-
return ret;
362+
vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
363+
return float16_t(ret[idx & 1]);
361364
}
362365
#endif
363366

@@ -369,24 +372,23 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2
369372
float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
370373
{
371374
uint idx = coordInBlock[1];
372-
uint lsb = idx & 1;
373-
idx /= 2;
374375

375-
const uint ib8 = (idx % 128) / 4; // 0..31
376-
const uint ib32 = ib8 / 4; // 0..7
376+
const uint ib32 = (idx & 0xE0) >> 5; // 0..7
377+
const uint ib8 = (idx & 0xF8) >> 3; // 0..31
378+
const uint qhshift = 2 * (ib8 % 4);
377379

378-
const uint scale = (bl.block.scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
380+
const uint scale = (bl.block.scales[ib32] >> ((idx & 0x10) >> 2)) & 0xf;
379381
const uint qs = bl.block.qs[ib8];
380382
const uint qh = bl.block.qh[ib32];
381-
const uint qhshift = 2 * (ib8 % 4);
382-
const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4));
383+
const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (idx & 0x6);
383384

384385
const float d = float(bl.block.d);
385386
const float db = d * 0.25 * (0.5 + scale);
386-
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
387-
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
388-
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid));
389-
return float16_t(v[lsb]);
387+
const ivec2 sign01 = 1 - (2 & ivec2(sign << 1, sign));
388+
uint g2 = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 4) >> 2];
389+
g2 >>= (idx & 2) * 8;
390+
const vec2 v = db * vec2(sign01) * vec2(unpack8(g2));
391+
return float16_t(v[idx & 1]);
390392
}
391393
#endif
392394

@@ -401,28 +403,25 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3
401403

402404
float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
403405
{
406+
decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl);
404407
uint idx = coordInBlock[1];
405-
uint lsb = idx & 1;
406-
idx /= 2;
407408

408-
const uint iqs = (idx % 128) / 2; // 0..63
409-
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
409+
const uint iqs = (idx & 0xFC) >> 2; // 0..63
410+
const uint is = QUANT_K / 4 + ((idx & 0xE0) >> 3);// 8 values
410411

411412
const float d = float(bl.block.d);
412413
const uint qs = bl.block.qs[iqs];
413-
const uint signs = pack32(u8vec4(
414-
bl.block.qs[is+0],
415-
bl.block.qs[is+1],
416-
bl.block.qs[is+2],
417-
bl.block.qs[is+3]
414+
const uint signs = pack32(u16vec2(
415+
bl16.block.qs[is/2+0],
416+
bl16.block.qs[is/2+1]
418417
));
419418
const float db = d * 0.5 * (0.5 + (signs >> 28));
420419
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
421-
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
422-
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
423-
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
420+
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (idx & 0x6);
421+
const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign)));
422+
const uint grid = iq3xxs_grid[qs] >> (16 * ((idx & 2) >> 1));
424423
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
425-
return float16_t(v[lsb]);
424+
return float16_t(v[idx & 1]);
426425
}
427426
#endif
428427

@@ -434,23 +433,21 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3
434433
float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
435434
{
436435
uint idx = coordInBlock[1];
437-
uint lsb = idx & 1;
438-
idx /= 2;
439436

440-
const uint iqs = (idx % 128) / 2; // 0..63
441-
const uint iqh = iqs / 8;
437+
const uint iqs = (idx & 0xFC) >> 2; // 0..63
438+
const uint iqh = (idx & 0xE0) >> 5;
442439

443440
const float d = float(bl.block.d);
444441
const uint qs = bl.block.qs[iqs];
445442
const uint qh = bl.block.qh[iqh];
446-
const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (2 * (idx % 4)));
443+
const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (idx & 0x6));
447444
const uint scale = bl.block.scales[iqs / 16];
448-
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
445+
const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign)));
449446
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
450-
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
447+
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> ((idx & 2) << 3);
451448
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
452449

453-
return float16_t(v[lsb]);
450+
return float16_t(v[idx & 1]);
454451
}
455452
#endif
456453

0 commit comments

Comments
 (0)