cuda : support Falcon-H1 state size for SSM_SCAN #14602

Merged: 1 commit, Jul 10, 2025
4 changes: 2 additions & 2 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3335,8 +3335,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SSM_SCAN: {
             if (op->src[3]->ne[0] == 1) {
                 // Mamba2
-                // (kernel only supports d_state == 128 && d_head % 16 == 0)
-                return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
+                // (kernel only supports (d_state == 128 || d_state == 256) && d_head % 16 == 0)
+                return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % 16 == 0;
             } else {
                 // Mamba
                 // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
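For reference, the predicate above reads the SSM_SCAN operands directly: src[0] is the recurrent state tensor, so its ne[0] is d_state and ne[1] is the head dimension, matching the d_state/d_head names in the comment. A standalone restatement is sketched below; the helper name is hypothetical and it is not part of the ggml API.

```cpp
#include <cstdint>

// Sketch of the support check with the dimensions named explicitly.
// Hypothetical helper, not ggml API code.
static bool cuda_ssm_scan_mamba2_supported(int64_t d_state, int64_t d_head) {
    // the CUDA kernel is only instantiated for d_state == 128 (Mamba-2) and
    // d_state == 256 (Falcon-H1), and processes the head dimension in chunks of 16
    return (d_state == 128 || d_state == 256) && d_head % 16 == 0;
}
```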
15 changes: 13 additions & 2 deletions ggml/src/ggml-cuda/ssm-scan.cu
@@ -201,11 +201,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
                               const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
                               const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                               cudaStream_t stream) {
-    const int threads = 128;
     // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
     if (src3_nb1 == sizeof(float)) {
         // Mamba-2
         if (d_state == 128) {
+            const int threads = 128;
             GGML_ASSERT(d_state % threads == 0);
             // NOTE: can be any power of two between 4 and 64
             const int splitH = 16;
@@ -215,10 +215,21 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
                 src0, src1, src2, src3, src4, src5, src6, dst,
                 src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                 src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
+        } else if (d_state == 256) { // Falcon-H1
+            const int threads = 256;
Contributor:

For learning purposes: the difference between the two calls is the number of threads used to launch the CUDA kernel. Is there any implementation difference in the kernel itself for the different values of d_state?

Collaborator Author compilade (Jul 9, 2025):

@younesbelkada Basically, the kernel currently assumes the number of threads in a block is the same as d_state.

It could also have been handled by restructuring the kernel to make each thread handle more than one intermediate state in the reduction (in the dot product with C), which might or might not be faster.

Each thread technically already handles multiple intermediate states by reducing over multiple head elements at once (i.e. splitH). This also allows calling expf less often per head.

I didn't particularly optimize the kernel, so there's most likely room for improvement.

(It could potentially be faster to use the semi-structured matrices implementation of Mamba-2 for better prompt processing speed, but from my (maybe wrong) understanding, that only allows starting from a blank state.)

Contributor:

Thank you for explaining, this is very clear @compilade!

+            // NOTE: can be any power of two between 8 and 64
+            const int splitH = 16;
+            GGML_ASSERT(head_dim % splitH == 0);
+            const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
+            ssm_scan_f32_group<16, 256><<<blocks, threads, 0, stream>>>(
+                src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
+                src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
         } else {
-            GGML_ABORT("doesn't support d_state!=128.");
+            GGML_ABORT("doesn't support d_state!=(128 or 256).");
         }
     } else {
+        const int threads = 128;
         // Mamba-1
         GGML_ASSERT(n_head % threads == 0);
         GGML_ASSERT(head_dim == 1);
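To make the exchange above concrete, here is a heavily simplified sketch of the pattern compilade describes; the kernel name, argument list, and memory layouts are assumptions made for this illustration, not the actual ssm_scan_f32_group kernel. The block is launched with exactly d_state threads, each thread owns one column of the recurrent state, and it walks over splitH head elements per token, so expf(dt * A) is evaluated once per token rather than once per state element.

```cuda
#include <cuda_runtime.h>

// Simplified illustration only; layouts are assumptions for this sketch:
//   x  [n_tok, splitH]    dt [n_tok]            A     per-head scalar
//   B  [n_tok, d_state]   C  [n_tok, d_state]   state [splitH, d_state]   y [n_tok, splitH]
// Launch with blockDim.x == d_state (128 for Mamba-2, 256 for Falcon-H1).
template <int splitH, int d_state>
__global__ void ssm_scan_sketch(const float * x, const float * dt, const float A,
                                const float * B, const float * C,
                                float * state, float * y, const int n_tok) {
    const int s = threadIdx.x;          // this thread's state column (0 .. d_state-1)
    __shared__ float partial[d_state];  // per-column contributions to the dot product with C

    for (int t = 0; t < n_tok; ++t) {
        // one expf per token, reused across all splitH head elements and all state columns
        const float dA = expf(dt[t] * A);

        for (int h = 0; h < splitH; ++h) {
            // recurrence: state = state * exp(dt * A) + B * x
            const float st = state[h * d_state + s] * dA + B[t * d_state + s] * x[t * splitH + h];
            state[h * d_state + s] = st;

            // y[t][h] = dot(C[t], state[h]) over d_state, via a naive block-wide reduction
            partial[s] = st * C[t * d_state + s];
            __syncthreads();
            if (s == 0) {
                float sum = 0.0f;
                for (int i = 0; i < d_state; ++i) {
                    sum += partial[i];
                }
                y[t * splitH + h] = sum;
            }
            __syncthreads();
        }
    }
}
```

Under this structure, the Falcon-H1 path added in the diff above amounts to instantiating the same template with d_state == 256 and launching 256 threads per block, e.g. ssm_scan_sketch<16, 256><<<blocks, 256>>>(...); the real kernel additionally handles n_group, multiple sequences, and the various byte strides that this sketch omits.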
1 change: 1 addition & 0 deletions tests/test-backend-ops.cpp
@@ -5066,6 +5066,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {

     test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1
     test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2
+    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 256, 64, 8, 2, 32, 4)); // Falcon-H1

     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));