@@ -11,7 +11,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
11
11
12
12
bool ggml_cuda_should_use_mmf (enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols);
13
13
14
- template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
14
+ template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids, size_t n_expert_used = 0 >
15
15
__launch_bounds__ (ggml_cuda_get_physical_warp_size()*nwarps, 1)
16
16
static __global__ void mul_mat_f(
17
17
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
@@ -57,27 +57,40 @@ static __global__ void mul_mat_f(
57
57
T * tile_xy = (T *) compute_base + threadIdx .y *(tile_A::I * tile_k_padded);
58
58
59
59
if constexpr (has_ids) {
60
+
60
61
int found = 0 ;
61
62
62
- for (int j = threadIdx .y ; j < cols_per_block; j += nwarps) {
63
+ #pragma unroll
64
+ for (int j0 = 0 ; j0 < cols_per_block; j0 += nwarps) {
65
+ const int j = j0 + threadIdx .y ;
63
66
const int32_t * __restrict__ id_row = ids + j*stride_row_id;
64
67
65
68
if (threadIdx .x == 0 ) {
66
69
slot_map[j] = -1 ;
67
70
}
68
71
69
- for (int k_base = 0 ; k_base < nchannels_dst; k_base += warp_size) {
70
- int k = k_base + threadIdx .x ;
71
- int match = (k < nchannels_dst) && (id_row[k*stride_col_id] == expert_idx);
72
+ if constexpr (n_expert_used == 0 ) {
73
+ for (int k_base = 0 ; k_base < nchannels_dst; k_base += warp_size) {
74
+ int k = k_base + threadIdx .x ;
75
+ int match = (k < nchannels_dst) && (id_row[k*stride_col_id] == expert_idx);
72
76
73
- unsigned mask = __ballot_sync (0xffffffff , match);
74
- if (mask) {
75
- int leader = __ffs (mask) - 1 ;
76
- if (threadIdx .x == leader) {
77
- slot_map[j] = k_base + leader;
77
+ if (match) {
78
+ slot_map[j] = k;
79
+ found = 1 ;
80
+ break ;
81
+ }
82
+ }
83
+ } else {
84
+ #pragma unroll
85
+ for (int k_base = 0 ; k_base < n_expert_used; k_base += warp_size) {
86
+ int k = k_base + threadIdx .x ;
87
+ int match = (k < n_expert_used) && (id_row[k*stride_col_id] == expert_idx);
88
+
89
+ if (match) {
90
+ slot_map[j] = k;
91
+ found = 1 ;
92
+ break ;
78
93
}
79
- found = 1 ;
80
- break ;
81
94
}
82
95
}
83
96
}
@@ -202,6 +215,73 @@ static __global__ void mul_mat_f(
202
215
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
203
216
}
204
217
218
+ template <typename T, int rows_per_block, int cols_per_block, int nwarps>
219
+ static inline void launch_mul_mat_ids (
220
+ const T * x, const float * y, const int32_t * ids, float * dst,
221
+ const int64_t ncols_x, const int64_t nchannels_dst,
222
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
223
+ const int64_t stride_col_id, const int64_t stride_row_id,
224
+ const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
225
+ const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
226
+ const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
227
+
228
+
229
+ const int n_expert_used = nchannels_dst;
230
+
231
+ switch (n_expert_used) {
232
+ case 1 : {
233
+ mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true , 1 ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
234
+ (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
235
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
236
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
237
+ } break ;
238
+ case 2 : {
239
+ mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true , 2 ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
240
+ (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
241
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
242
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
243
+ } break ;
244
+ case 4 : {
245
+ mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true , 4 ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
246
+ (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
247
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
248
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
249
+ } break ;
250
+ case 6 : {
251
+ mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true , 6 ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
252
+ (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
253
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
254
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
255
+ } break ;
256
+ case 8 : {
257
+ mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true , 8 ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
258
+ (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
259
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
260
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
261
+ } break ;
262
+ case 16 : {
263
+ mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true , 16 ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
264
+ (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
265
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
266
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
267
+ } break ;
268
+ case 32 : {
269
+ mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true , 32 ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
270
+ (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
271
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
272
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
273
+ } break ;
274
+ default : {
275
+ mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true , 0 ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
276
+ (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
277
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
278
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
279
+ } break ;
280
+ }
281
+
282
+ }
283
+
284
+
205
285
template <typename T, int cols_per_block, int nwarps>
206
286
static inline void mul_mat_f_switch_ids (
207
287
const T * x, const float * y, const int32_t * ids, float * dst,
@@ -212,10 +292,11 @@ static inline void mul_mat_f_switch_ids(
212
292
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
213
293
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
214
294
if (ids) {
215
- mul_mat_f <T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true > <<<block_nums, block_dims, nbytes_shared_total, stream>>>
216
- ( x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
295
+ launch_mul_mat_ids <T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps>(
296
+ x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
217
297
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
218
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
298
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst,
299
+ block_nums, block_dims, nbytes_shared_total, stream);
219
300
} else {
220
301
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
221
302
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
0 commit comments