@@ -500,7 +500,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
500
500
}
501
501
502
502
static void dequantize_mul_mat_vec_q4_0_cuda (const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
503
- GGML_ASSERT (ncols % GGML_CUDA_DMMV_X == 0 );
503
+ GGML_ASSERT (ncols % ( GGML_CUDA_DMMV_X* 2 ) == 0 );
504
504
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
505
505
// the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
506
506
const dim3 block_nums (block_num_y, 1 , 1 );
@@ -510,7 +510,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y,
510
510
}
511
511
512
512
static void dequantize_mul_mat_vec_q4_1_cuda (const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
513
- GGML_ASSERT (ncols % GGML_CUDA_DMMV_X == 0 );
513
+ GGML_ASSERT (ncols % ( GGML_CUDA_DMMV_X* 2 ) == 0 );
514
514
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
515
515
const dim3 block_nums (block_num_y, 1 , 1 );
516
516
const dim3 block_dims (WARP_SIZE, GGML_CUDA_MMV_Y, 1 );
@@ -519,7 +519,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y,
519
519
}
520
520
521
521
static void dequantize_mul_mat_vec_q5_0_cuda (const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
522
- GGML_ASSERT (ncols % GGML_CUDA_DMMV_X == 0 );
522
+ GGML_ASSERT (ncols % ( GGML_CUDA_DMMV_X* 2 ) == 0 );
523
523
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
524
524
const dim3 block_nums (block_num_y, 1 , 1 );
525
525
const dim3 block_dims (WARP_SIZE, GGML_CUDA_MMV_Y, 1 );
@@ -528,7 +528,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y,
528
528
}
529
529
530
530
static void dequantize_mul_mat_vec_q5_1_cuda (const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
531
- GGML_ASSERT (ncols % GGML_CUDA_DMMV_X == 0 );
531
+ GGML_ASSERT (ncols % ( GGML_CUDA_DMMV_X* 2 ) == 0 );
532
532
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
533
533
const dim3 block_nums (block_num_y, 1 , 1 );
534
534
const dim3 block_dims (WARP_SIZE, GGML_CUDA_MMV_Y, 1 );
@@ -537,7 +537,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y,
537
537
}
538
538
539
539
static void dequantize_mul_mat_vec_q8_0_cuda (const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
540
- GGML_ASSERT (ncols % GGML_CUDA_DMMV_X == 0 );
540
+ GGML_ASSERT (ncols % ( GGML_CUDA_DMMV_X* 2 ) == 0 );
541
541
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
542
542
const dim3 block_nums (block_num_y, 1 , 1 );
543
543
const dim3 block_dims (WARP_SIZE, GGML_CUDA_MMV_Y, 1 );
@@ -588,7 +588,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
588
588
}
589
589
590
590
static void convert_mul_mat_vec_f16_cuda (const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
591
- GGML_ASSERT (ncols % GGML_CUDA_DMMV_X == 0 );
591
+ GGML_ASSERT (ncols % ( GGML_CUDA_DMMV_X* 2 ) == 0 );
592
592
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
593
593
const dim3 block_nums (block_num_y, 1 , 1 );
594
594
const dim3 block_dims (WARP_SIZE, GGML_CUDA_MMV_Y, 1 );
@@ -672,3 +672,12 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
672
672
GGML_UNUSED (src1_ncols);
673
673
GGML_UNUSED (src1_padded_row_size);
674
674
}
675
+
676
// Returns true when the dequantize_mul_mat_vec (DMMV) path has a kernel for
// the given src0 tensor type; all other types must use a different mat-vec path.
bool ggml_cuda_dmmv_type_supported(ggml_type src0_type) {
    switch (src0_type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_F16:
            return true;
        default:
            return false;
    }
}
0 commit comments