Skip to content

Commit 43e3059

Browse files
Add ROCm support (#7)
* Add ROCm support Co-authored-by: [ ] <[email protected]> * Disable half2 by default when using HIP --------- Co-authored-by: [ ] <[email protected]>
1 parent 45de2b5 commit 43e3059

File tree

11 files changed: +85 additions, −4 deletions

cuda_ext.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,10 @@ def find_msvc():
5353
os.path.join(library_dir, "exllama_ext/cuda_func/q4_mlp.cu"),
5454
os.path.join(library_dir, "exllama_ext/cpu_func/rep_penalty.cpp")
5555
],
56+
extra_include_paths = [os.path.join(library_dir, "exllama_ext")],
5657
verbose = verbose,
57-
extra_ldflags = ["cublas.lib"] if windows else []
58+
extra_ldflags = ["cublas.lib"] if windows else [],
59+
extra_cuda_cflags = ["-U__HIP_NO_HALF_CONVERSIONS__"] if torch.version.hip else []
5860
# extra_cflags = ["-ftime-report", "-DTORCH_USE_CUDA_DSA"]
5961
)
6062

exllama_ext/cuda_compat.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
4141

4242
//
4343

44-
#ifdef __CUDA_ARCH__
45-
#if __CUDA_ARCH__ < 700
44+
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
45+
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
4646

4747
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
4848
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }

exllama_ext/cuda_func/column_remap.cu

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
#include "column_remap.cuh"
22
#include "../util.cuh"
33

4+
// Using 1024 makes me crash with "Memory access fault by GPU node-1 (Agent
5+
// handle: 0x012345678912) on address 0x012345678912. Reason: Page not present
6+
// or supervisor privilege."
7+
#if defined(USE_ROCM)
8+
const int SHUF_BLOCKSIZE_X = 256;
9+
#else
410
const int SHUF_BLOCKSIZE_X = 1024;
11+
#endif
512
const int SHUF_BLOCKSIZE_Y = 16;
613

714
__global__ void column_remap_kernel

exllama_ext/cuda_func/half_matmul.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
#include "../util.cuh"
33
#include "../matrix.cuh"
44
#include "../cuda_compat.cuh"
5+
#if defined(USE_ROCM)
6+
#include "../hip_compat.cuh"
7+
#endif
58

69
// Block size
710

exllama_ext/cuda_func/half_matmul.cuh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@
66
#include <cstdint>
77
#include <ATen/cuda/CUDAContext.h>
88

9+
// Workaround for hipify_python using rocblas instead of hipblas.
10+
#if defined(USE_ROCM)
11+
#include <hipblas/hipblas.h>
12+
#define rocblas_handle hipblasHandle_t
13+
#endif
14+
915
void half_matmul_cuda
1016
(
1117
const half* x,

exllama_ext/cuda_func/q4_matmul.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
#include "../matrix.cuh"
55
#include "../cuda_compat.cuh"
66
#include "../cuda_buffers.cuh"
7+
#if defined(USE_ROCM)
8+
#include "../hip_compat.cuh"
9+
#endif
710

811
const int THREADS_X = 32; // Block size and thread count along columns in w and out
912
const int THREADS_Y = 1; // Block size and thread count along rows in x and out

exllama_ext/cuda_func/q4_matmul.cuh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@
1010
#include "q4_matrix.cuh"
1111
#include "../tuning.h"
1212

13+
// Workaround for hipify_python using rocblas instead of hipblas.
14+
#if defined(USE_ROCM)
15+
#include <hipblas/hipblas.h>
16+
#define rocblas_handle hipblasHandle_t
17+
#endif
18+
1319
void q4_matmul_cuda
1420
(
1521
ExLlamaTuning* tuningParams,

exllama_ext/cuda_func/q4_mlp.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
#include "../cuda_buffers.cuh"
55
#include "../util.cuh"
66
#include "../matrix.cuh"
7+
#if defined(USE_ROCM)
8+
#include "../hip_compat.cuh"
9+
#endif
710

811
const int THREADS_X = 32;
912
const int THREADS_Y = 4;

exllama_ext/hip_compat.cuh

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#ifndef _hip_compat_cuh
2+
#define _hip_compat_cuh
3+
4+
// Workaround for a bug in hipamd, backported from upstream.
5+
__device__ __forceinline__ __half __compat_hrcp(__half x) {
6+
return __half_raw{
7+
static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
8+
}
9+
10+
__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
11+
return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
12+
static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
13+
}
14+
15+
#define hrcp __compat_hrcp
16+
#define h2rcp __compat_h2rcp
17+
18+
// Workaround for hipify_python using rocblas instead of hipblas.
19+
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
20+
hipblasOperation_t transA,
21+
hipblasOperation_t transB,
22+
int m,
23+
int n,
24+
int k,
25+
const half* alpha,
26+
const half* AP,
27+
int lda,
28+
const half* BP,
29+
int ldb,
30+
const half* beta,
31+
half* CP,
32+
int ldc) {
33+
return hipblasHgemm(handle, transA, transB, m, n, k,
34+
reinterpret_cast<const hipblasHalf *>(alpha),
35+
reinterpret_cast<const hipblasHalf *>(AP), lda,
36+
reinterpret_cast<const hipblasHalf *>(BP), ldb,
37+
reinterpret_cast<const hipblasHalf *>(beta),
38+
reinterpret_cast<hipblasHalf *>(CP), ldc);
39+
}
40+
41+
#define rocblas_handle hipblasHandle_t
42+
#define rocblas_operation_none HIPBLAS_OP_N
43+
#define rocblas_hgemm __compat_hipblasHgemm
44+
45+
#endif

exllama_ext/util.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
#include <cstdint>
77
#include <cstdio>
88

9+
#if defined(USE_ROCM)
10+
#define cudaUnspecified hipErrorUnknown
11+
#else
912
#define cudaUnspecified cudaErrorApiFailureBase
13+
#endif
1014

1115
// React to failure on return code != cudaSuccess
1216

0 commit comments

Comments
 (0)