
Commit c88c74f

vulkan: only use M-sized matmul on Apple GPUs (#5412)
* vulkan: refactor guess_matmul_pipeline for vendor

  Refactor ggml_vk_guess_matmul_pipeline to simplify adding per-vendor conditionals.

  Signed-off-by: Sergio Lopez <[email protected]>

* vulkan: only use M-sized matmul on Apple GPUs

  L-sized and S-sized matmuls are broken on Apple GPUs; force using the M-sized variant for this vendor.

  Signed-off-by: Sergio Lopez <[email protected]>

---------

Signed-off-by: Sergio Lopez <[email protected]>
1 parent a803333 commit c88c74f
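
The gist of the change: pipeline selection is now dispatched on the Vulkan vendor ID before the generic m/n size heuristic runs, and the Apple path always picks the M-sized shader. The snippet below is only a self-contained sketch of that selection logic, not code from the diff; the MatmulSize enum and guess_matmul_size helper are invented for illustration, while the real helpers return vk_pipeline pointers and also distinguish f16/f32 and aligned variants.

#include <cstdint>
#include <cstdio>

// Vendor IDs as defined in the patch below.
#define VK_VENDOR_ID_AMD    0x1002
#define VK_VENDOR_ID_APPLE  0x106b
#define VK_VENDOR_ID_INTEL  0x8086

// Illustrative only: the real code returns one of several vk_pipeline
// objects (f32/f16, aligned/unaligned); here we only model the size tier.
enum class MatmulSize { S, M, L };

static MatmulSize guess_matmul_size(uint32_t vendor_id, int m, int n) {
    switch (vendor_id) {
    case VK_VENDOR_ID_APPLE:
        return MatmulSize::M;                                        // L and S shaders are broken on Apple GPUs
    case VK_VENDOR_ID_INTEL:
        return MatmulSize::S;                                        // Intel always takes the S-sized pipeline
    case VK_VENDOR_ID_AMD:
        return (m <= 32 || n <= 32) ? MatmulSize::S : MatmulSize::M; // AMD never uses L
    default:
        // Other vendors keep the generic size heuristic.
        if (m <= 32 || n <= 32) return MatmulSize::S;
        if (m <= 64 || n <= 64) return MatmulSize::M;
        return MatmulSize::L;
    }
}

int main() {
    // An Apple GPU gets the M tier even for a large matmul.
    printf("%d\n", (int) guess_matmul_size(VK_VENDOR_ID_APPLE, 4096, 4096));  // 1 (M)
    printf("%d\n", (int) guess_matmul_size(VK_VENDOR_ID_AMD,   4096, 4096));  // 1 (M)
    printf("%d\n", (int) guess_matmul_size(0x10de /* NVIDIA */, 4096, 4096)); // 2 (L)
}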

File tree

1 file changed: +89 −6 lines changed

ggml-vulkan.cpp

Lines changed: 89 additions & 6 deletions
@@ -27,6 +27,7 @@
 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
 
 #define VK_VENDOR_ID_AMD 0x1002
+#define VK_VENDOR_ID_APPLE 0x106b
 #define VK_VENDOR_ID_INTEL 0x8086
 #define VK_VENDOR_ID_NVIDIA 0x10de
 
@@ -2034,18 +2035,100 @@ static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ct
     return ctx->pipeline_matmul_f32_aligned_l.align;
 }
 
+static vk_pipeline* ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, int m, int n, bool aligned) {
+    if (bit16_x && bit16_y) {
+        if (m <= 32 || n <= 32) {
+#ifdef GGML_VULKAN_DEBUG
+            std::cerr << " S" << std::endl;
+#endif
+            return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s;
+        }
+#ifdef GGML_VULKAN_DEBUG
+        std::cerr << " M" << std::endl;
+#endif
+        return aligned ? &ctx->pipeline_matmul_f16_aligned_m : &ctx->pipeline_matmul_f16_m;
+    }
+    if (bit16_x && !bit16_y) {
+        if (m <= 32 || n <= 32) {
+#ifdef GGML_VULKAN_DEBUG
+            std::cerr << " S" << std::endl;
+#endif
+            return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s;
+        }
+#ifdef GGML_VULKAN_DEBUG
+        std::cerr << " M" << std::endl;
+#endif
+        return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_m : &ctx->pipeline_matmul_f16_f32_m;
+    }
+    if (!bit16_x && bit16_y) {
+        GGML_ASSERT(false);
+    }
+
+    if (m <= 32 || n <= 32) {
+#ifdef GGML_VULKAN_DEBUG
+        std::cerr << " S" << std::endl;
+#endif
+        return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s;
+    }
+#ifdef GGML_VULKAN_DEBUG
+    std::cerr << " M" << std::endl;
+#endif
+    return aligned ? &ctx->pipeline_matmul_f32_aligned_m : &ctx->pipeline_matmul_f32_m;
+}
+
+static vk_pipeline* ggml_vk_guess_matmul_pipeline_apple(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, bool aligned) {
+#ifdef GGML_VULKAN_DEBUG
+    std::cerr << " M" << std::endl;
+#endif
+    if (bit16_x && bit16_y) {
+        return aligned ? &ctx->pipeline_matmul_f16_aligned_m : &ctx->pipeline_matmul_f16_m;
+    }
+    if (bit16_x && !bit16_y) {
+        return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_m : &ctx->pipeline_matmul_f16_f32_m;
+    }
+    if (!bit16_x && bit16_y) {
+        GGML_ASSERT(false);
+    }
+    return aligned ? &ctx->pipeline_matmul_f32_aligned_m : &ctx->pipeline_matmul_f32_m;
+}
+
+static vk_pipeline* ggml_vk_guess_matmul_pipeline_intel(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, bool aligned) {
+#ifdef GGML_VULKAN_DEBUG
+    std::cerr << " S" << std::endl;
+#endif
+    if (bit16_x && bit16_y) {
+        return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s;
+    }
+    if (bit16_x && !bit16_y) {
+        return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s;
+    }
+    if (!bit16_x && bit16_y) {
+        GGML_ASSERT(false);
+    }
+    return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s;
+}
+
 static vk_pipeline* ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, int m, int n, bool aligned) {
 #ifdef GGML_VULKAN_DEBUG
     std::cerr << "ggml_vk_guess_matmul_pipeline(" << bit16_x << ", " << bit16_y << ", " << m << ", " << n << ", " << aligned << ")";
 #endif
+    switch (ctx->device.lock()->vendor_id) {
+    case VK_VENDOR_ID_AMD:
+        return ggml_vk_guess_matmul_pipeline_amd(ctx, bit16_x, bit16_y, m, n, aligned);
+    case VK_VENDOR_ID_APPLE:
+        return ggml_vk_guess_matmul_pipeline_apple(ctx, bit16_x, bit16_y, aligned);
+    case VK_VENDOR_ID_INTEL:
+        return ggml_vk_guess_matmul_pipeline_intel(ctx, bit16_x, bit16_y, aligned);
+    }
+
     if (bit16_x && bit16_y) {
-        if (ctx->device.lock()->vendor_id == VK_VENDOR_ID_INTEL || m <= 32 || n <= 32) {
+        if (m <= 32 || n <= 32) {
 #ifdef GGML_VULKAN_DEBUG
             std::cerr << " S" << std::endl;
 #endif
             return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s;
         }
-        if (ctx->device.lock()->subgroup_size == 64 || m <= 64 || n <= 64) {
+        if (m <= 64 || n <= 64) {
 #ifdef GGML_VULKAN_DEBUG
             std::cerr << " M" << std::endl;
 #endif
@@ -2057,13 +2140,13 @@ static vk_pipeline* ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
         return aligned ? &ctx->pipeline_matmul_f16_aligned_l : &ctx->pipeline_matmul_f16_l;
     }
     if (bit16_x && !bit16_y) {
-        if (ctx->device.lock()->vendor_id == VK_VENDOR_ID_INTEL || m <= 32 || n <= 32) {
+        if (m <= 32 || n <= 32) {
 #ifdef GGML_VULKAN_DEBUG
             std::cerr << " S" << std::endl;
 #endif
             return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s;
         }
-        if (ctx->device.lock()->subgroup_size == 64 || m <= 64 || n <= 64) {
+        if (m <= 64 || n <= 64) {
 #ifdef GGML_VULKAN_DEBUG
             std::cerr << " M" << std::endl;
 #endif
@@ -2078,13 +2161,13 @@ static vk_pipeline* ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
         GGML_ASSERT(false);
     }
 
-    if (ctx->device.lock()->vendor_id == VK_VENDOR_ID_INTEL || m <= 32 || n <= 32) {
+    if (m <= 32 || n <= 32) {
 #ifdef GGML_VULKAN_DEBUG
         std::cerr << " S" << std::endl;
 #endif
         return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s;
     }
-    if (ctx->device.lock()->subgroup_size == 64 || m <= 64 || n <= 64) {
+    if (m <= 64 || n <= 64) {
 #ifdef GGML_VULKAN_DEBUG
         std::cerr << " M" << std::endl;
 #endif
