// Integer ceiling division: smallest integer >= M/N for positive integral N.
// NOTE: the macro name must be immediately followed by '(' — a space before
// the parameter list would turn this into an object-like macro whose expansion
// starts with the literal text "(M, N)", breaking every call site.
#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))

// PCI vendor IDs, used to select per-vendor matmul pipeline heuristics.
#define VK_VENDOR_ID_AMD    0x1002
#define VK_VENDOR_ID_APPLE  0x106b
#define VK_VENDOR_ID_INTEL  0x8086
#define VK_VENDOR_ID_NVIDIA 0x10de
@@ -2034,18 +2035,100 @@ static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ct
2034
2035
return ctx->pipeline_matmul_f32_aligned_l .align ;
2035
2036
}
2036
2037
2038
+ static vk_pipeline* ggml_vk_guess_matmul_pipeline_amd (ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, int m, int n, bool aligned) {
2039
+ if (bit16_x && bit16_y) {
2040
+ if (m <= 32 || n <= 32 ) {
2041
+ #ifdef GGML_VULKAN_DEBUG
2042
+ std::cerr << " S" << std::endl;
2043
+ #endif
2044
+ return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s ;
2045
+ }
2046
+ #ifdef GGML_VULKAN_DEBUG
2047
+ std::cerr << " M" << std::endl;
2048
+ #endif
2049
+ return aligned ? &ctx->pipeline_matmul_f16_aligned_m : &ctx->pipeline_matmul_f16_m ;
2050
+ }
2051
+ if (bit16_x && !bit16_y) {
2052
+ if (m <= 32 || n <= 32 ) {
2053
+ #ifdef GGML_VULKAN_DEBUG
2054
+ std::cerr << " S" << std::endl;
2055
+ #endif
2056
+ return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s ;
2057
+ }
2058
+ #ifdef GGML_VULKAN_DEBUG
2059
+ std::cerr << " M" << std::endl;
2060
+ #endif
2061
+ return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_m : &ctx->pipeline_matmul_f16_f32_m ;
2062
+ }
2063
+ if (!bit16_x && bit16_y) {
2064
+ GGML_ASSERT (false );
2065
+ }
2066
+
2067
+ if (m <= 32 || n <= 32 ) {
2068
+ #ifdef GGML_VULKAN_DEBUG
2069
+ std::cerr << " S" << std::endl;
2070
+ #endif
2071
+ return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s ;
2072
+ }
2073
+ #ifdef GGML_VULKAN_DEBUG
2074
+ std::cerr << " M" << std::endl;
2075
+ #endif
2076
+ return aligned ? &ctx->pipeline_matmul_f32_aligned_m : &ctx->pipeline_matmul_f32_m ;
2077
+ }
2078
+
2079
+ static vk_pipeline* ggml_vk_guess_matmul_pipeline_apple (ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, bool aligned) {
2080
+ #ifdef GGML_VULKAN_DEBUG
2081
+ std::cerr << " M" << std::endl;
2082
+ #endif
2083
+ if (bit16_x && bit16_y) {
2084
+ return aligned ? &ctx->pipeline_matmul_f16_aligned_m : &ctx->pipeline_matmul_f16_m ;
2085
+ }
2086
+ if (bit16_x && !bit16_y) {
2087
+ return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_m : &ctx->pipeline_matmul_f16_f32_m ;
2088
+ }
2089
+ if (!bit16_x && bit16_y) {
2090
+ GGML_ASSERT (false );
2091
+ }
2092
+ return aligned ? &ctx->pipeline_matmul_f32_aligned_m : &ctx->pipeline_matmul_f32_m ;
2093
+ }
2094
+
2095
+ static vk_pipeline* ggml_vk_guess_matmul_pipeline_intel (ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, bool aligned) {
2096
+ #ifdef GGML_VULKAN_DEBUG
2097
+ std::cerr << " S" << std::endl;
2098
+ #endif
2099
+ if (bit16_x && bit16_y) {
2100
+ return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s ;
2101
+ }
2102
+ if (bit16_x && !bit16_y) {
2103
+ return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s ;
2104
+ }
2105
+ if (!bit16_x && bit16_y) {
2106
+ GGML_ASSERT (false );
2107
+ }
2108
+ return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s ;
2109
+ }
2110
+
2037
2111
static vk_pipeline* ggml_vk_guess_matmul_pipeline (ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, int m, int n, bool aligned) {
2038
2112
#ifdef GGML_VULKAN_DEBUG
2039
2113
std::cerr << " ggml_vk_guess_matmul_pipeline(" << bit16_x << " , " << bit16_y << " , " << m << " , " << n << " , " << aligned << " )" ;
2040
2114
#endif
2115
+ switch (ctx->device .lock ()->vendor_id ) {
2116
+ case VK_VENDOR_ID_AMD:
2117
+ return ggml_vk_guess_matmul_pipeline_amd (ctx, bit16_x, bit16_y, m, n, aligned);
2118
+ case VK_VENDOR_ID_APPLE:
2119
+ return ggml_vk_guess_matmul_pipeline_apple (ctx, bit16_x, bit16_y, aligned);
2120
+ case VK_VENDOR_ID_INTEL:
2121
+ return ggml_vk_guess_matmul_pipeline_intel (ctx, bit16_x, bit16_y, aligned);
2122
+ }
2123
+
2041
2124
if (bit16_x && bit16_y) {
2042
- if (ctx-> device . lock ()-> vendor_id == VK_VENDOR_ID_INTEL || m <= 32 || n <= 32 ) {
2125
+ if (m <= 32 || n <= 32 ) {
2043
2126
#ifdef GGML_VULKAN_DEBUG
2044
2127
std::cerr << " S" << std::endl;
2045
2128
#endif
2046
2129
return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s ;
2047
2130
}
2048
- if (ctx-> device . lock ()-> subgroup_size == 64 || m <= 64 || n <= 64 ) {
2131
+ if (m <= 64 || n <= 64 ) {
2049
2132
#ifdef GGML_VULKAN_DEBUG
2050
2133
std::cerr << " M" << std::endl;
2051
2134
#endif
@@ -2057,13 +2140,13 @@ static vk_pipeline* ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
2057
2140
return aligned ? &ctx->pipeline_matmul_f16_aligned_l : &ctx->pipeline_matmul_f16_l ;
2058
2141
}
2059
2142
if (bit16_x && !bit16_y) {
2060
- if (ctx-> device . lock ()-> vendor_id == VK_VENDOR_ID_INTEL || m <= 32 || n <= 32 ) {
2143
+ if (m <= 32 || n <= 32 ) {
2061
2144
#ifdef GGML_VULKAN_DEBUG
2062
2145
std::cerr << " S" << std::endl;
2063
2146
#endif
2064
2147
return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s ;
2065
2148
}
2066
- if (ctx-> device . lock ()-> subgroup_size == 64 || m <= 64 || n <= 64 ) {
2149
+ if (m <= 64 || n <= 64 ) {
2067
2150
#ifdef GGML_VULKAN_DEBUG
2068
2151
std::cerr << " M" << std::endl;
2069
2152
#endif
@@ -2078,13 +2161,13 @@ static vk_pipeline* ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
2078
2161
GGML_ASSERT (false );
2079
2162
}
2080
2163
2081
- if (ctx-> device . lock ()-> vendor_id == VK_VENDOR_ID_INTEL || m <= 32 || n <= 32 ) {
2164
+ if (m <= 32 || n <= 32 ) {
2082
2165
#ifdef GGML_VULKAN_DEBUG
2083
2166
std::cerr << " S" << std::endl;
2084
2167
#endif
2085
2168
return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s ;
2086
2169
}
2087
- if (ctx-> device . lock ()-> subgroup_size == 64 || m <= 64 || n <= 64 ) {
2170
+ if (m <= 64 || n <= 64 ) {
2088
2171
#ifdef GGML_VULKAN_DEBUG
2089
2172
std::cerr << " M" << std::endl;
2090
2173
#endif
0 commit comments