HIP: Enable Matrix cores for MMQ Kernels, Enable stream-K for CDNA 3 #14624
base: master
Changes from all commits
68da4e5
79f348a
89ba8a6
e57e563
9784a51
dad79b3
ff60fa9
e8eeb34
a161900
75d386a
0215a80
aa35feb
ba17f62
5ab1491
fb2fd31
b55d44a
ab7c007
279b51e
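For context (an illustration, not part of the diff): on AMD CDNA GPUs a wavefront is 64 lanes wide, so the per-thread fragment size of `tile<I, J, T>` becomes `I * J / 64` in the HIP branch below, versus `I * J / 32` on the CUDA side. A minimal host-side sketch of the fragment sizes for the tile shapes handled in this PR:

```cpp
#include <cstdio>

// Per-thread fragment size ne = I*J / lane_count, mirroring the tile<> definitions in the diff.
constexpr int ne(int I, int J, int lanes) { return I * J / lanes; }

int main() {
    // 64-lane CDNA wavefront: these are the shapes covered by get_i()/get_j() below.
    printf("tile<16, 8>:  ne = %d\n", ne(16,  8, 64)); // 2  (A/B operand fragments)
    printf("tile<32, 4>:  ne = %d\n", ne(32,  4, 64)); // 2
    printf("tile<16,16>:  ne = %d\n", ne(16, 16, 64)); // 4  (16x16 accumulator)
    printf("tile<32,32>:  ne = %d\n", ne(32, 32, 64)); // 16 (32x32 accumulator)
    return 0;
}
```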
@@ -66,7 +66,40 @@ namespace ggml_cuda_mma {
    struct tile {
        static constexpr int I = I_;
        static constexpr int J = J_;
        static constexpr int ne = I * J / WARP_SIZE;

#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
        static constexpr int ne = I * J / 64;
        T x[ne] = {0};

        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 16 && J == 8) {
                return threadIdx.x % 16;
            } else if constexpr (I == 32 && J == 4) {
                return threadIdx.x % 32;
            } else if constexpr (I == 16 && J == 16) {
                return 4 * (threadIdx.x / 16) + l;
            } else if constexpr (I == 32 && J == 32) {
                return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
            } else {
                static_assert(I == -1 && J == -1, "template specialization not implemented");
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 16 && J == 8) {
                return 2 * (threadIdx.x / 16) + l;
            } else if constexpr (I == 32 && J == 4) {
                return 2 * (threadIdx.x / 32) + l;
            } else if constexpr (I == 16 && J == 16) {
                return threadIdx.x % 16;
            } else if constexpr (I == 32 && J == 32) {
                return threadIdx.x % 32;
            } else {
                static_assert(I == -1 && J == -1, "template specialization not implemented");
            }
        }
#else
        static constexpr int ne = I * J / 32;
        T x[ne] = {0};

        static __device__ __forceinline__ int get_i(const int l) {
@@ -94,6 +127,7 @@ namespace ggml_cuda_mma {
                static_assert(I == -1 && J == -1, "template specialization not implemented");
            }
        }
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
    };

    template <int I_, int J_>
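As a sanity check of the new AMD thread-to-element mapping (a standalone sketch, not part of the PR), the following host program reproduces `get_i`/`get_j` for `tile<16, 16, int>` and verifies that the 64 lanes of a wavefront, with 4 elements each, cover all 256 entries of the tile exactly once:

```cpp
#include <cassert>

int main() {
    int hits[16][16] = {};
    for (int lane = 0; lane < 64; ++lane) {      // threadIdx.x within one wavefront
        for (int l = 0; l < 4; ++l) {            // ne = 16*16/64 = 4 elements per lane
            const int i = 4 * (lane / 16) + l;   // get_i() for I == 16 && J == 16
            const int j = lane % 16;             // get_j() for I == 16 && J == 16
            ++hits[i][j];
        }
    }
    for (int i = 0; i < 16; ++i) {
        for (int j = 0; j < 16; ++j) {
            assert(hits[i][j] == 1);             // each element owned by exactly one (lane, l) pair
        }
    }
    return 0;
}
```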
@@ -148,10 +182,16 @@ namespace ggml_cuda_mma {

    template <int I, int J, typename T>
    static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
#if defined(AMD_MFMA_AVAILABLE)
        int64_t * xi = (int64_t *) t.x;
        const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
        xi[0] = xs[0];
#else
#pragma unroll
        for (int l = 0; l < t.ne; ++l) {
            t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
        }
#endif // defined(AMD_MFMA_AVAILABLE)
    }

    template <typename T>
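On the `AMD_MFMA_AVAILABLE` path, `load_generic` replaces the per-element loop with a single 8-byte load per lane. For `tile<16, 8, int>` (and likewise `tile<32, 4, int>`) the fragment is `ne = 2` ints with consecutive `get_j` values, so one `int64_t` load covers exactly what the scalar fallback would fetch. A small host-side check of that equivalence (an illustration, not part of the PR):

```cpp
#include <cassert>

int main() {
    const int stride = 40;  // arbitrary row stride in ints, as passed to load_generic
    for (int lane = 0; lane < 64; ++lane) {
        // Vector path: one int64_t load at int offset (lane % 16) * stride + 2 * (lane / 16),
        // which covers ints base and base + 1.
        const int base = (lane % 16) * stride + 2 * (lane / 16);
        for (int l = 0; l < 2; ++l) {            // ne = 16*8/64 = 2
            const int i = lane % 16;             // get_i() for tile<16, 8>
            const int j = 2 * (lane / 16) + l;   // get_j() for tile<16, 8>
            assert(i * stride + j == base + l);  // scalar fallback index matches the vector load
        }
    }
    return 0;
}
```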
@@ -186,7 +226,7 @@ namespace ggml_cuda_mma {
    template <typename T>
    static __device__ __forceinline__ void load_ldmatrix(
            tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
#if defined(NEW_MMA_AVAILABLE)
        int * xi = (int * ) t.x;
        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
        asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
@@ -386,6 +426,62 @@ namespace ggml_cuda_mma {
              : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
              : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
        GGML_UNUSED(D);
        GGML_UNUSED(A);
        GGML_UNUSED(B);
        NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
    }

    static __device__ __forceinline__ void mma(
            tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
#if defined(AMD_MFMA_AVAILABLE)
        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
        int32x4_t * acc = (int32x4_t *) D.x;
#if defined(CDNA3)
        acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0],
                                                       ((int64_t *) B.x)[0],
                                                       acc[0],
                                                       0, 0, 0);
#elif defined(CDNA2) || defined(CDNA)
        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0],
                                                      B.x[0],
                                                      acc[0],
                                                      0, 0, 0);
        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1],
                                                      B.x[1],
                                                      acc[0],
                                                      0, 0, 0);
#endif
#else
        GGML_UNUSED(D);
        GGML_UNUSED(A);
        GGML_UNUSED(B);
        NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
    }

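For readers unfamiliar with the MFMA builtins: on CDNA 3 a single `__builtin_amdgcn_mfma_i32_16x16x32_i8` call takes 8 packed int8 values per lane from each operand (passed as an `int64_t`) and accumulates a 16x16 int32 tile over K = 32, which is why the CDNA/CDNA 2 branch above needs two `16x16x16` calls to cover the same K. Below is a minimal, hypothetical standalone HIP kernel sketch using that builtin with the `tile<16, 16, int>` result layout from this diff; the `__gfx94x__` guard and kernel name are illustrative and not from the PR:

```cpp
#include <hip/hip_runtime.h>

// One wavefront (64 lanes) computes a 16x16 int32 tile D from int8 operands with K = 32.
// A_packed/B_packed hold 8 int8 values per lane, packed into one int64_t each.
__global__ void mfma_i8_demo(const int64_t * A_packed, const int64_t * B_packed, int * D) {
    using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
    const int lane = threadIdx.x;  // 0..63
    int32x4_t acc = {0, 0, 0, 0};
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)  // CDNA 3 targets only
    acc = __builtin_amdgcn_mfma_i32_16x16x32_i8(A_packed[lane], B_packed[lane], acc, 0, 0, 0);
#endif
    // Write back the 4 per-lane results using the tile<16, 16, int> layout:
    // row = 4 * (lane / 16) + l, col = lane % 16.
    for (int l = 0; l < 4; ++l) {
        D[(4 * (lane / 16) + l) * 16 + (lane % 16)] = acc[l];
    }
}
```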
    static __device__ __forceinline__ void mma(
            tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) {
#if defined(AMD_MFMA_AVAILABLE)
        using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
        int32x16_t * acc = (int32x16_t *) D.x;
#if defined(CDNA3)
        acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0],
                                                       ((int64_t *) B.x)[0],
                                                       acc[0],
                                                       0, 0, 0);
#elif defined(CDNA2) || defined(CDNA)
        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0],
                                                     B.x[0],
                                                     acc[0],
                                                     0, 0, 0);
        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1],
                                                     B.x[1],
                                                     acc[0],
                                                     0, 0, 0);
#endif
Review comment: Github doesn't let me comment on the line in question but please also fix the comment for the
#else
        GGML_UNUSED(D);
        GGML_UNUSED(A);
Review comment: Github doesn't let me comment on the line in question, but please amend the documentation at the top of the file: