
Commit 5218ea2

slaren authored and ggerganov committed
cuda : fix dmmv cols requirement to 2*GGML_CUDA_DMMV_X (llama/8800)
* cuda : fix dmmv cols requirement to 2*GGML_CUDA_DMMV_X
* update asserts
* only use dmmv for supported types
* add test
1 parent e60be82 commit 5218ea2
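For illustration only (not part of the commit): a minimal standalone sketch of the old versus the new column requirement on src0->ne[0], assuming the backend's default GGML_CUDA_DMMV_X of 32. A column count such as 96 passes the old check (a multiple of GGML_CUDA_DMMV_X and at least twice that), but is not a multiple of 2*GGML_CUDA_DMMV_X, which is the kind of shape this fix keeps off the dmmv path.

// Standalone model of the eligibility change, not the ggml sources.
// GGML_CUDA_DMMV_X is assumed to be the backend's default of 32.
#include <cstdio>

constexpr int GGML_CUDA_DMMV_X = 32;

// old requirement on the number of columns (src0->ne[0])
static bool old_cols_ok(int ncols) {
    return ncols % GGML_CUDA_DMMV_X == 0 && ncols >= GGML_CUDA_DMMV_X*2;
}

// new requirement from this commit
static bool new_cols_ok(int ncols) {
    return ncols % (GGML_CUDA_DMMV_X*2) == 0;
}

int main() {
    // 96 is accepted by the old condition but rejected by the new one
    for (int ncols : {64, 96, 128}) {
        std::printf("ncols=%3d old=%d new=%d\n", ncols, old_cols_ok(ncols), new_cols_ok(ncols));
    }
    return 0;
}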

File tree

3 files changed: +19 −9 lines changed

ggml/src/ggml-cuda.cu

Lines changed: 2 additions & 3 deletions
@@ -1885,10 +1885,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);
 
-    bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
+    bool use_dequantize_mul_mat_vec = ggml_cuda_dmmv_type_supported(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[0] >= GGML_CUDA_DMMV_X*2
-        && src1->ne[1] == 1;
+        && src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
         && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;

ggml/src/ggml-cuda/dmmv.cu

Lines changed: 15 additions & 6 deletions
@@ -500,7 +500,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
 }
 
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
     const dim3 block_nums(block_num_y, 1, 1);
@@ -510,7 +510,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y,
 }
 
 static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -519,7 +519,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y,
 }
 
 static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -528,7 +528,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y,
 }
 
 static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -537,7 +537,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y,
 }
 
 static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -588,7 +588,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
 }
 
 static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -672,3 +672,12 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
     GGML_UNUSED(src1_ncols);
     GGML_UNUSED(src1_padded_row_size);
 }
+
+bool ggml_cuda_dmmv_type_supported(ggml_type src0_type) {
+    return src0_type == GGML_TYPE_Q4_0 || src0_type == GGML_TYPE_Q4_1 ||
+        src0_type == GGML_TYPE_Q5_0 || src0_type == GGML_TYPE_Q5_1 ||
+        src0_type == GGML_TYPE_Q8_0 || src0_type == GGML_TYPE_Q2_K ||
+        src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q4_K ||
+        src0_type == GGML_TYPE_Q5_K || src0_type == GGML_TYPE_Q6_K ||
+        src0_type == GGML_TYPE_F16;
+}
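Context for the tightened asserts, as a sketch under assumptions rather than the CUDA kernel itself: the updated checks require ncols to be a multiple of 2*GGML_CUDA_DMMV_X, which matches a loop that consumes the row in strides of that size. A small CPU-side model of such a loop, again assuming the default GGML_CUDA_DMMV_X of 32:

// CPU-side model of a loop that consumes a row in strides of 2*GGML_CUDA_DMMV_X
// values; illustration only, with GGML_CUDA_DMMV_X assumed to be 32.
#include <cassert>
#include <vector>

constexpr int GGML_CUDA_DMMV_X = 32;

static float model_dot_row(const std::vector<float> & row, const std::vector<float> & y) {
    const int ncols       = (int) row.size();
    const int iter_stride = 2*GGML_CUDA_DMMV_X;   // columns consumed per iteration

    assert(ncols % iter_stride == 0);             // mirrors the updated GGML_ASSERT

    float acc = 0.0f;
    for (int i = 0; i < ncols; i += iter_stride) {
        for (int j = 0; j < iter_stride; ++j) {
            acc += row[i + j] * y[i + j];         // a full stride never runs past the row
        }
    }
    return acc;
}

With ncols = 96 the assert in this model fires, matching the shapes that the updated condition in ggml_cuda_mul_mat no longer routes to dmmv.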

ggml/src/ggml-cuda/dmmv.cuh

Lines changed: 2 additions & 0 deletions
@@ -16,3 +16,5 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream);
+
+bool ggml_cuda_dmmv_type_supported(ggml_type src0_type);
