@@ -165,6 +165,7 @@ struct vk_device_struct {
     vk_queue transfer_queue;
     bool single_queue;
     uint32_t subgroup_size;
+    uint32_t shader_core_count;
     bool uma;
 
     size_t idx;
@@ -1498,7 +1499,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
@@ -1610,23 +1611,36 @@ static vk_device ggml_vk_get_device(size_t idx) {
     const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
 
     bool maintenance4_support = false;
+    bool sm_builtins = false;
 
     // Check if maintenance4 is supported
     for (const auto& properties : ext_props) {
         if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
             maintenance4_support = true;
+        } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
+            sm_builtins = true;
         }
     }
 
     vk::PhysicalDeviceProperties2 props2;
     vk::PhysicalDeviceMaintenance3Properties props3;
     vk::PhysicalDeviceMaintenance4Properties props4;
     vk::PhysicalDeviceSubgroupProperties subgroup_props;
+    vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
     props2.pNext = &props3;
     props3.pNext = &subgroup_props;
+
+    VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&subgroup_props;
+
     if (maintenance4_support) {
-        subgroup_props.pNext = &props4;
+        last_struct->pNext = (VkBaseOutStructure *)&props4;
+        last_struct = (VkBaseOutStructure *)&props4;
+    }
+    if (sm_builtins) {
+        last_struct->pNext = (VkBaseOutStructure *)&sm_props;
+        last_struct = (VkBaseOutStructure *)&sm_props;
     }
+
     device->physical_device.getProperties2(&props2);
     device->properties = props2.properties;
 
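Aside (not part of the patch): the last_struct bookkeeping above implements a generic "append to the pNext chain" step, so optional extension structs can be linked in conditionally without hard-coding their order. A minimal sketch of the same pattern as a helper; chain_append is a hypothetical name, not in the patch:

    #include <vulkan/vulkan.h>

    // Link one extension struct onto the end of a pNext chain and
    // advance the tail pointer to it.
    static void chain_append(VkBaseOutStructure ** last, void * next) {
        (*last)->pNext = (VkBaseOutStructure *)next;
        *last = (VkBaseOutStructure *)next;
    }

    // Usage, mirroring the patch:
    //   chain_append(&last_struct, &props4);
    //   chain_append(&last_struct, &sm_props);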
@@ -1643,6 +1657,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
     device->vendor_id = device->properties.vendorID;
     device->subgroup_size = subgroup_props.subgroupSize;
     device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
+    if (sm_builtins) {
+        device->shader_core_count = sm_props.shaderSMCount;
+    } else {
+        device->shader_core_count = 0;
+    }
 
     bool fp16_storage = false;
     bool fp16_compute = false;
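Aside (not part of the patch): stripped to its essentials, the SM-count query only needs the NV extension struct linked into the properties2 chain. A minimal vulkan.hpp sketch; query_sm_count is a hypothetical name, and it assumes VK_NV_shader_sm_builtins was already found in the device's extension list:

    #include <vulkan/vulkan.hpp>

    // Returns the streaming-multiprocessor count reported by the driver.
    uint32_t query_sm_count(vk::PhysicalDevice physical_device) {
        vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props; // sType set by the wrapper
        vk::PhysicalDeviceProperties2 props2;
        props2.pNext = &sm_props;
        physical_device.getProperties2(&props2);
        return sm_props.shaderSMCount; // e.g. 82 on an RTX 3090
    }

On non-NVIDIA devices the extension is absent and the patch falls back to shader_core_count = 0, which the heuristic below treats as "unknown, don't split".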
@@ -2732,15 +2751,25 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
     dst->device->device.resetFences({ dst->device->fence });
 }
 
-static uint32_t ggml_vk_guess_split_k(int m, int n, int k) {
+static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) {
     VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
-    // if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) {
-    //     return 4;
-    // }
 
-    return 1;
+    uint32_t split_k = 1;
+    if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) {
+        // If k is 'large' and the SMs will fill less than halfway, use split_k.
+        uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]);
+        uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]);
+        if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) {
+            split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
+            // Clamp to 2 or 4
+            split_k = std::min(split_k, 4u);
+            if (split_k == 3) {
+                split_k = 2;
+            }
+        }
+    }
 
-    GGML_UNUSED(m); GGML_UNUSED(n); GGML_UNUSED(k);
+    return split_k;
 }
 
 static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
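Aside (not part of the patch): a self-contained sketch of the heuristic above with a worked example. guess_split_k is a hypothetical standalone copy; wg_m/wg_n stand in for pipeline->wg_denoms[0]/[1], and the 32-SM figure is illustrative:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    static uint32_t guess_split_k(int m, int n, int k, uint32_t wg_m, uint32_t wg_n, uint32_t sm_count) {
        uint32_t split_k = 1;
        if (sm_count != 0 && m >= (int)wg_m && n >= (int)wg_n) {
            uint32_t m_tiles = (m + wg_m - 1) / wg_m; // CEIL_DIV
            uint32_t n_tiles = (n + wg_n - 1) / wg_n;
            // Split only when k is large and the output tiles fill under half the SMs.
            if (k >= 2048 && m_tiles * n_tiles < sm_count / 2) {
                split_k = sm_count / (m_tiles * n_tiles);
                split_k = std::min(split_k, 4u); // clamp to 2 or 4
                if (split_k == 3) {
                    split_k = 2;
                }
            }
        }
        return split_k;
    }

    int main() {
        // A 128x128 output with 64x64 tiles covers 2*2 = 4 tiles; on a 32-SM GPU
        // that fills under half the SMs, and k = 4096 is large, so the K dimension
        // gets split: min(32 / 4, 4) = 4.
        printf("%u\n", guess_split_k(128, 128, 4096, 64, 64, 32)); // prints 4
    }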
@@ -2964,10 +2993,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
     const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
 
-    const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
-
     vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
 
+    const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
+
     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
     const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
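Note: the split_k guess moves below pipeline selection because the heuristic now reads the selected pipeline's wg_denoms tile sizes.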
@@ -2993,7 +3022,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     if (dryrun) {
         const uint64_t x_sz_upd = x_sz * ne02 * ne03;
         const uint64_t y_sz_upd = y_sz * ne12 * ne13;
-        const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * 4 : 0;
+        const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0;
         if (
             (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
             (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) ||
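Worked sizing example (not part of the patch): each of the split_k passes writes its own d_sz * ne12 * ne13 partial-output slice, which split_k_reduce later sums, so the scratch requirement is exactly split_k slices. The old expression always reserved 4 slices whenever splitting was active; with split_k = 2 the new expression reserves 2, halving that allocation.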