Skip to content

Commit

Permalink
CUDA/HIP: add support for selectable warp size to mmv (#11519)
Browse files Browse the repository at this point in the history
CUDA/HIP: add support for selectable warp size to mmv
  • Loading branch information
IMbackK authored Feb 2, 2025
1 parent 4d0598e commit 396856b
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 14 deletions.
8 changes: 8 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,14 @@ static constexpr bool new_mma_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
}

static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
return __AMDGCN_WAVEFRONT_SIZE;
#else
return 32;
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
}

[[noreturn]]
static __device__ void no_device_code(
const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
Expand Down
38 changes: 24 additions & 14 deletions ggml/src/ggml-cuda/mmv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ template <typename T, typename type_acc, int block_size>
static __global__ void mul_mat_vec(
const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
const int64_t row = blockIdx.x;
const int64_t channel = blockIdx.z;
const int tid = threadIdx.x;
const int64_t row = blockIdx.x;
const int64_t channel = blockIdx.z;
const int tid = threadIdx.x;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();

x += (channel/channel_ratio)*stride_channel_x + row*stride_row;
y += channel *stride_channel_y;
Expand All @@ -18,8 +19,8 @@ static __global__ void mul_mat_vec(
extern __shared__ char data_mmv[];
float * buf_iw = (float *) data_mmv;

if (block_size > WARP_SIZE) {
if (tid < WARP_SIZE) {
if (block_size > warp_size) {
if (tid < warp_size) {
buf_iw[tid] = 0.0f;
}
__syncthreads();
Expand Down Expand Up @@ -67,16 +68,16 @@ static __global__ void mul_mat_vec(
static_assert(std::is_same<T, void>::value, "unsupported type");
}

sumf = warp_reduce_sum(sumf);
sumf = warp_reduce_sum<warp_size>(sumf);

if (block_size > WARP_SIZE) {
buf_iw[tid/WARP_SIZE] = sumf;
if (block_size > warp_size) {
buf_iw[tid/warp_size] = sumf;
__syncthreads();
if (tid >= WARP_SIZE) {
if (tid >= warp_size) {
return;
}
sumf = buf_iw[tid];
sumf = warp_reduce_sum(sumf);
sumf = warp_reduce_sum<warp_size>(sumf);
}

if (tid != 0) {
Expand All @@ -96,18 +97,27 @@ static void launch_mul_mat_vec_cuda(
GGML_ASSERT(stride_row % 2 == 0);
GGML_ASSERT(nchannels_y % nchannels_x == 0);
const int64_t channel_ratio = nchannels_y / nchannels_x;
int device;
int warp_size;

int64_t block_size_best = WARP_SIZE;
int64_t niter_best = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) {
CUDA_CHECK(cudaGetDevice(&device));
warp_size = ggml_cuda_info().devices[device].warp_size;

int64_t block_size_best = warp_size;
int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size);
int64_t max_block_size = 256;
if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
max_block_size = 128;
}
for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
if (niter < niter_best) {
niter_best = niter;
block_size_best = block_size;
}
}

const int smem = WARP_SIZE*sizeof(float);
const int smem = warp_size*sizeof(float);
const dim3 block_nums(nrows, 1, nchannels_y);
const dim3 block_dims(block_size_best, 1, 1);
switch (block_size_best) {
Expand Down
2 changes: 2 additions & 0 deletions ggml/src/ggml-cuda/vendors/hip.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#define HIP_ENABLE_WARP_SYNC_BUILTINS 1
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
Expand All @@ -8,6 +9,7 @@
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__

#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
Expand Down

0 comments on commit 396856b

Please sign in to comment.