Skip to content

Commit c66fb36

Browse files
committed
unroll n_expert_used loop + remove warp syncs
1 parent 8a6cfa4 commit c66fb36

File tree

1 file changed

+96
-15
lines changed

1 file changed

+96
-15
lines changed

ggml/src/ggml-cuda/mmf.cuh

Lines changed: 96 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
1111

1212
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols);
1313

14-
template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
14+
template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids, size_t n_expert_used = 0>
1515
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
1616
static __global__ void mul_mat_f(
1717
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
@@ -57,27 +57,40 @@ static __global__ void mul_mat_f(
5757
T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
5858

5959
if constexpr (has_ids) {
60+
6061
int found = 0;
6162

62-
for (int j = threadIdx.y; j < cols_per_block; j += nwarps) {
63+
#pragma unroll
64+
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
65+
const int j = j0 + threadIdx.y;
6366
const int32_t * __restrict__ id_row = ids + j*stride_row_id;
6467

6568
if (threadIdx.x == 0) {
6669
slot_map[j] = -1;
6770
}
6871

69-
for (int k_base = 0; k_base < nchannels_dst; k_base += warp_size) {
70-
int k = k_base + threadIdx.x;
71-
int match = (k < nchannels_dst) && (id_row[k*stride_col_id] == expert_idx);
72+
if constexpr (n_expert_used == 0) {
73+
for (int k_base = 0; k_base < nchannels_dst; k_base += warp_size) {
74+
int k = k_base + threadIdx.x;
75+
int match = (k < nchannels_dst) && (id_row[k*stride_col_id] == expert_idx);
7276

73-
unsigned mask = __ballot_sync(0xffffffff, match);
74-
if (mask) {
75-
int leader = __ffs(mask) - 1;
76-
if (threadIdx.x == leader) {
77-
slot_map[j] = k_base + leader;
77+
if (match) {
78+
slot_map[j] = k;
79+
found = 1;
80+
break;
81+
}
82+
}
83+
} else {
84+
#pragma unroll
85+
for (int k_base = 0; k_base < n_expert_used; k_base += warp_size) {
86+
int k = k_base + threadIdx.x;
87+
int match = (k < n_expert_used) && (id_row[k*stride_col_id] == expert_idx);
88+
89+
if (match) {
90+
slot_map[j] = k;
91+
found = 1;
92+
break;
7893
}
79-
found = 1;
80-
break;
8194
}
8295
}
8396
}
@@ -202,6 +215,73 @@ static __global__ void mul_mat_f(
202215
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
203216
}
204217

218+
template<typename T, int rows_per_block, int cols_per_block, int nwarps>
219+
static inline void launch_mul_mat_ids(
220+
const T * x, const float * y, const int32_t * ids, float * dst,
221+
const int64_t ncols_x, const int64_t nchannels_dst,
222+
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
223+
const int64_t stride_col_id, const int64_t stride_row_id,
224+
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
225+
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
226+
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
227+
228+
229+
const int n_expert_used = nchannels_dst;
230+
231+
switch (n_expert_used) {
232+
case 1: {
233+
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 1><<<block_nums, block_dims, nbytes_shared_total, stream>>>
234+
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
235+
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
236+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
237+
} break;
238+
case 2: {
239+
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 2><<<block_nums, block_dims, nbytes_shared_total, stream>>>
240+
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
241+
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
242+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
243+
} break;
244+
case 4: {
245+
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 4><<<block_nums, block_dims, nbytes_shared_total, stream>>>
246+
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
247+
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
248+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
249+
} break;
250+
case 6: {
251+
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 6><<<block_nums, block_dims, nbytes_shared_total, stream>>>
252+
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
253+
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
254+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
255+
} break;
256+
case 8: {
257+
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 8><<<block_nums, block_dims, nbytes_shared_total, stream>>>
258+
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
259+
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
260+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
261+
} break;
262+
case 16: {
263+
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 16><<<block_nums, block_dims, nbytes_shared_total, stream>>>
264+
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
265+
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
266+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
267+
} break;
268+
case 32: {
269+
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 32><<<block_nums, block_dims, nbytes_shared_total, stream>>>
270+
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
271+
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
272+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
273+
} break;
274+
default: {
275+
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true, 0><<<block_nums, block_dims, nbytes_shared_total, stream>>>
276+
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
277+
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
278+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
279+
} break;
280+
}
281+
282+
}
283+
284+
205285
template<typename T, int cols_per_block, int nwarps>
206286
static inline void mul_mat_f_switch_ids(
207287
const T * x, const float * y, const int32_t * ids, float * dst,
@@ -212,10 +292,11 @@ static inline void mul_mat_f_switch_ids(
212292
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
213293
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
214294
if (ids) {
215-
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
216-
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
295+
launch_mul_mat_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps>(
296+
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
217297
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
218-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
298+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst,
299+
block_nums, block_dims, nbytes_shared_total, stream);
219300
} else {
220301
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
221302
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,

0 commit comments

Comments
 (0)