Skip to content

Commit b1348d3

Browse files
JohannesGaessler authored and ggerganov committed
CUDA/HIP: fix tests/test-backend-ops (llama/8896)
1 parent 90641b5 commit b1348d3

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

ggml/src/ggml-cuda.cu

+7 -6
Original file line number | Diff line number | Diff line change
@@ -2742,11 +2742,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
27422742
case GGML_OP_MUL_MAT_ID:
27432743
{
27442744
struct ggml_tensor * a = op->src[0];
2745-
if (op->op == GGML_OP_MUL_MAT) {
2746-
struct ggml_tensor * b = op->src[1];
2747-
if (a->ne[3] != b->ne[3]) {
2748-
return false;
2749-
}
2745+
struct ggml_tensor * b = op->src[1];
2746+
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
2747+
return false;
2748+
}
2749+
if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
2750+
return false;
27502751
}
27512752
switch (a->type) {
27522753
case GGML_TYPE_F32:
@@ -2877,7 +2878,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
28772878
return true;
28782879
case GGML_OP_FLASH_ATTN_EXT:
28792880
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
2880-
return op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128;
2881+
return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
28812882
#else
28822883
if (op->src[0]->ne[0] == 128) {
28832884
return true;

0 commit comments

Comments
 (0)