Skip to content

Commit cc98896

Browse files
authored
vulkan: optimize and reenable split_k (ggml-org#10637)
Use vector loads when possible in mul_mat_split_k_reduce. Use split_k when there aren't enough workgroups to fill the GPU's shader cores.
1 parent 91c36c2 commit cc98896

File tree

2 files changed

+65
-17
lines changed

2 files changed

+65
-17
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ struct vk_device_struct {
165165
vk_queue transfer_queue;
166166
bool single_queue;
167167
uint32_t subgroup_size;
168+
uint32_t shader_core_count;
168169
bool uma;
169170

170171
size_t idx;
@@ -1498,7 +1499,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
14981499
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
14991500
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
15001501

1501-
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1);
1502+
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
15021503

15031504
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
15041505
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
@@ -1610,23 +1611,36 @@ static vk_device ggml_vk_get_device(size_t idx) {
16101611
const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
16111612

16121613
bool maintenance4_support = false;
1614+
bool sm_builtins = false;
16131615

16141616
// Check if maintenance4 is supported
16151617
for (const auto& properties : ext_props) {
16161618
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
16171619
maintenance4_support = true;
1620+
} else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
1621+
sm_builtins = true;
16181622
}
16191623
}
16201624

16211625
vk::PhysicalDeviceProperties2 props2;
16221626
vk::PhysicalDeviceMaintenance3Properties props3;
16231627
vk::PhysicalDeviceMaintenance4Properties props4;
16241628
vk::PhysicalDeviceSubgroupProperties subgroup_props;
1629+
vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
16251630
props2.pNext = &props3;
16261631
props3.pNext = &subgroup_props;
1632+
1633+
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&subgroup_props;
1634+
16271635
if (maintenance4_support) {
1628-
subgroup_props.pNext = &props4;
1636+
last_struct->pNext = (VkBaseOutStructure *)&props4;
1637+
last_struct = (VkBaseOutStructure *)&props4;
1638+
}
1639+
if (sm_builtins) {
1640+
last_struct->pNext = (VkBaseOutStructure *)&sm_props;
1641+
last_struct = (VkBaseOutStructure *)&sm_props;
16291642
}
1643+
16301644
device->physical_device.getProperties2(&props2);
16311645
device->properties = props2.properties;
16321646

@@ -1643,6 +1657,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
16431657
device->vendor_id = device->properties.vendorID;
16441658
device->subgroup_size = subgroup_props.subgroupSize;
16451659
device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
1660+
if (sm_builtins) {
1661+
device->shader_core_count = sm_props.shaderSMCount;
1662+
} else {
1663+
device->shader_core_count = 0;
1664+
}
16461665

16471666
bool fp16_storage = false;
16481667
bool fp16_compute = false;
@@ -2732,15 +2751,25 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
27322751
dst->device->device.resetFences({ dst->device->fence });
27332752
}
27342753

2735-
static uint32_t ggml_vk_guess_split_k(int m, int n, int k) {
2754+
static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) {
27362755
VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
2737-
// if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) {
2738-
// return 4;
2739-
// }
27402756

2741-
return 1;
2757+
uint32_t split_k = 1;
2758+
if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) {
2759+
// If k is 'large' and the SMs will fill less than halfway, use split_k.
2760+
uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]);
2761+
uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]);
2762+
if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) {
2763+
split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
2764+
// Clamp to 2 or 4
2765+
split_k = std::min(split_k, 4u);
2766+
if (split_k == 3) {
2767+
split_k = 2;
2768+
}
2769+
}
2770+
}
27422771

2743-
GGML_UNUSED(m); GGML_UNUSED(n); GGML_UNUSED(k);
2772+
return split_k;
27442773
}
27452774

27462775
static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
@@ -2964,10 +2993,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
29642993
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
29652994
const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
29662995

2967-
const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
2968-
29692996
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
29702997

2998+
const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
2999+
29713000
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
29723001
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
29733002
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
@@ -2993,7 +3022,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
29933022
if (dryrun) {
29943023
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
29953024
const uint64_t y_sz_upd = y_sz * ne12 * ne13;
2996-
const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * 4 : 0;
3025+
const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0;
29973026
if (
29983027
(qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
29993028
(qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) ||

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,44 @@
55
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
66

77
layout (binding = 0) readonly buffer A {float data_a[];};
8+
layout (binding = 0) readonly buffer A4 {vec4 data_a4[];};
89
layout (binding = 1) writeonly buffer D {float data_d[];};
10+
layout (binding = 1) writeonly buffer D4 {vec4 data_d4[];};
911

1012
layout (push_constant) uniform parameter {
1113
uint ne;
1214
uint k_num;
1315
} p;
1416

1517
void main() {
16-
const uint idx = gl_GlobalInvocationID.x;
18+
// Each invocation handles four consecutive components
19+
const uint idx = gl_GlobalInvocationID.x * 4;
1720

1821
if (idx >= p.ne) {
1922
return;
2023
}
2124

22-
float result = 0.0f;
25+
// Check if all four components are in bounds and aligned,
26+
// then use vector loads
27+
if (idx + 3 < p.ne && (p.ne % 4) == 0) {
28+
vec4 result = vec4(0.0f);
2329

24-
[[unroll]] for (uint i = 0; i < p.k_num; i++) {
25-
result += data_a[i * p.ne + idx];
26-
}
30+
[[unroll]] for (uint i = 0; i < p.k_num; i++) {
31+
result += data_a4[(i * p.ne + idx) / 4];
32+
}
33+
34+
data_d4[idx / 4] = result;
35+
} else {
36+
[[unroll]] for (uint j = 0; j < 4; ++j) {
37+
if (idx + j < p.ne) {
38+
float result = 0.0f;
2739

28-
data_d[idx] = result;
40+
[[unroll]] for (uint i = 0; i < p.k_num; i++) {
41+
result += data_a[i * p.ne + idx + j];
42+
}
43+
44+
data_d[idx + j] = result;
45+
}
46+
}
47+
}
2948
}

0 commit comments

Comments
 (0)