metal: SSM_SCAN performance #14743

Status: Open. Wants to merge 10 commits into base: master.
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
@@ -519,6 +519,7 @@ typedef struct {
int64_t n_group;
int64_t n_seq_tokens;
int64_t n_seqs;
+ int64_t s_off;
Review comment (Collaborator, PR author):
I'm not sure this is actually necessary. I did this when initially trying to better emulate the CUDA kernel, which passes this as an arg. It does seem like it might be slightly faster to avoid computing this in the kernel, though that could be offset by the latency of an additional int64_t being passed to the device?
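For reference, the two candidate computations agree by construction: `src1` here is the `x` tensor with dimensions `{dim, n_head, n_seq_tokens, n_seqs}`, so `ggml_nelements(src1)` is exactly the `nr * nh * n_t * n_seqs` product the kernel previously computed. A minimal C++ sketch of that identity (the dimension values are hypothetical):

```cpp
#include <cassert>
#include <cstdint>

int main(void) {
    // Hypothetical dims of src1 (x): {dim, n_head, n_seq_tokens, n_seqs}
    const int64_t nr = 64, nh = 24, n_t = 8, n_seqs = 2;

    // Host side (this PR): s_off = ggml_nelements(src1) * sizeof(float),
    // where ggml_nelements is the product of all tensor dimensions.
    const int64_t s_off_host = (nr * nh * n_t * n_seqs) * (int64_t) sizeof(float);

    // Device side (previous code): recomputed on every kernel invocation.
    const int64_t s_off_kernel = nr * nh * n_t * n_seqs * sizeof(float);

    assert(s_off_host == s_off_kernel);
    return 0;
}
```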

uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
43 changes: 40 additions & 3 deletions ggml/src/ggml-metal/ggml-metal.m
@@ -2909,7 +2909,26 @@ static bool ggml_metal_encode_node(
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&args length:sizeof(args) atIndex:3];

- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ const int64_t d_state = ne10;

+ // One shared memory bucket for each simd group in the threadgroup
+ if (d_state >= 32) {
+ const int64_t shmem_size = d_state / 32;
Review comment on lines +2912 to +2916, @compilade (Collaborator), Jul 22, 2025:
Should probably be called d_conv for consistency with how it's named in the graph and/or in the CPU op.

This is 4 for Mamba and Mamba-2 models, and 3 for LiquidAI's LFM2 models.

With such small sums, does it still make a measurable difference to use simd_sum for SSM_CONV?
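For scale, here is a CPU-side sketch of the two strategies the comment contrasts: the serial dot product the kernel used before, and a tree reduction of the kind `simd_sum` performs in hardware. For an `nc` of 3 or 4 the tree collapses in only two steps, so any speedup has to come from using threads that are already resident rather than from saved arithmetic. The values are made up; this is an illustration, not a benchmark.

```cpp
#include <cstdio>

// Serial reduction: what the pre-PR SSM_CONV kernel did with a single thread.
static float dot_serial(const float *s, const float *c, int nc) {
    float sumf = 0.0f;
    for (int i0 = 0; i0 < nc; ++i0) {
        sumf += s[i0] * c[i0];
    }
    return sumf;
}

// Tree reduction over a padded lane array: the shape of work simd_sum does.
// lanes must be a power of two >= nc; lanes beyond nc contribute 0, like
// idle SIMD lanes would.
static float dot_tree(const float *s, const float *c, int nc, int lanes) {
    float acc[32] = {0.0f};
    for (int i = 0; i < nc; ++i) {
        acc[i] = s[i] * c[i];
    }
    for (int stride = lanes / 2; stride > 0; stride /= 2) {
        for (int i = 0; i < stride; ++i) {
            acc[i] += acc[i + stride];
        }
    }
    return acc[0];
}

int main(void) {
    // d_conv is 4 for Mamba-1/2 and 3 for LFM2, per the comment above.
    const float s[4] = {0.5f, -1.0f, 2.0f, 0.25f};
    const float c[4] = {1.0f, 0.5f, -0.5f, 4.0f};
    printf("serial: %f  tree: %f\n", dot_serial(s, c, 4), dot_tree(s, c, 4, 4));
    return 0;
}
```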


+ // The final simd_sum won't work if the number of simd groups is
+ // larger than the size of a single simd group. If this case is
+ // hit at some point, the logic in the second simd_sum could be
+ // expanded to handle this with one more sequential simd_sum to
+ // collapse simd group sums another time.
+ GGML_ASSERT(shmem_size <= 32);

+ // One thread per element in d_state
+ GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);

+ [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
+ }

+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
} break;
case GGML_OP_SSM_SCAN:
{
@@ -2986,6 +3005,7 @@ static bool ggml_metal_encode_node(
/*.n_group =*/ n_group,
/*.n_seq_tokens =*/ n_seq_tokens,
/*.n_seqs =*/ n_seqs,
+ /*.s_off =*/ ggml_nelements(src1) * sizeof(float),
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
@@ -3014,12 +3034,29 @@ static bool ggml_metal_encode_node(
[encoder setBuffer:id_dst offset:offs_dst atIndex:7];
[encoder setBytes:&args length:sizeof(args) atIndex:8];

+ // One shared memory bucket for each simd group in the threadgroup
+ if (d_state >= 32) {
+ const int64_t shmem_size = d_state / 32;

+ // The final simd_sum won't work if the number of simd groups is
+ // larger than the size of a single simd group. If this case is
+ // hit at some point, the logic in the second simd_sum could be
+ // expanded to handle this with one more sequential simd_sum to
+ // collapse simd group sums another time.
+ GGML_ASSERT(shmem_size <= 32);

+ // One thread per element in d_state
+ GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);

+ [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
+ }

if (ne30 == 1) {
// Mamba-2
- [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
} else {
GGML_ASSERT(d_inner == 1);
- [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
}
} break;
case GGML_OP_RWKV_WKV6:
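Summarizing the host-side change: both kernels are now launched with `d_state` threads per threadgroup instead of one, and when `d_state` spans more than one simd group the encoder reserves one float of threadgroup memory per simd group. A standalone sketch of that sizing rule, assuming the 32-wide simd groups of current Apple GPUs (the helper name is illustrative):

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// Sketch of the threadgroup-memory sizing rule used above, assuming a
// simd-group width of 32 (true for current Apple GPUs). The function name
// is hypothetical; the logic mirrors the encoder code.
constexpr int64_t kSimdWidth = 32;

int64_t ssm_shmem_floats(int64_t d_state, int64_t max_threads_per_tg) {
    if (d_state < kSimdWidth) {
        return 0; // one simd group covers all of d_state; no shared buffer needed
    }
    const int64_t shmem_size = d_state / kSimdWidth; // one float per simd group
    // The second-level simd_sum can only collapse up to 32 partial sums.
    assert(shmem_size <= kSimdWidth);
    // One thread per element of d_state has to fit in a threadgroup.
    assert(d_state <= max_threads_per_tg);
    return shmem_size;
}

int main(void) {
    // e.g. Mamba-2 commonly uses d_state = 128: 4 simd groups, 4 floats.
    printf("%lld\n", (long long) ssm_shmem_floats(128, 1024));
    return 0;
}
```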
222 changes: 174 additions & 48 deletions ggml/src/ggml-metal/ggml-metal.metal
@@ -1663,10 +1663,16 @@ kernel void kernel_ssm_conv_f32(
device const void * src0,
device const void * src1,
device float * dst,
+ threadgroup float * shared [[threadgroup(0)]],
constant ggml_metal_kargs_ssm_conv & args,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgptg[[simdgroups_per_threadgroup]],
+ uint3 tgpg[[threadgroups_per_grid]]) {

+ const int64_t i0 = tpitg.x;
const int64_t ir = tgpig.x;
const int64_t i2 = tgpig.y;
const int64_t i3 = tgpig.z;
@@ -1681,13 +1687,31 @@ kernel void kernel_ssm_conv_f32(
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);

- float sumf = 0.0f;
+ float sumf = s[i0] * c[i0];

- for (int64_t i0 = 0; i0 < nc; ++i0) {
-     sumf += s[i0] * c[i0];
- }
+ // Parallel sum: first sum within each simd group, then sum the
+ // per-simd-group sums
+ sumf = simd_sum(sumf);

- x[0] = sumf;
+ // If multiple simd groups per threadgroup, sum over simd group sums
+ if (sgptg > 1) {
+ if (tiisg == 0) {
+ shared[sgitg] = sumf;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ sumf = 0.0f;
+ if (sgitg == 0) {
+ if (tiisg < sgptg) {
+ sumf = shared[tiisg];
+ }
+ sumf = simd_sum(sumf);
+ if (tiisg == 0) {
+ x[0] = sumf;
+ }
+ }
+ } else if (tiisg == 0) {
+ x[0] = sumf;
+ }
}
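The reduction above generalizes to any `d_state` spanning up to 32 simd groups. A CPU emulation of the same two-level scheme, with `simd_sum` modeled as a plain sum over one group's lanes and `shared` standing in for threadgroup memory (illustrative only):

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// CPU emulation of the two-level reduction used by the kernel above,
// assuming a simd-group width of 32. simd_sum() is modeled as a plain sum
// over one group's lanes; `shared` plays the role of threadgroup memory.
constexpr int kSimdWidth = 32;

float two_level_sum(const std::vector<float> &vals) {
    const int n_groups = (int) ((vals.size() + kSimdWidth - 1) / kSimdWidth);
    std::vector<float> shared(n_groups, 0.0f);

    // Level 1: each simd group reduces its own lanes (the first simd_sum).
    for (int g = 0; g < n_groups; ++g) {
        for (int lane = 0; lane < kSimdWidth; ++lane) {
            const size_t i = (size_t) g * kSimdWidth + lane;
            if (i < vals.size()) {
                shared[g] += vals[i];
            }
        }
    }

    // (threadgroup_barrier happens here on the GPU)

    // Level 2: simd group 0 reduces the per-group partial sums. This only
    // works while n_groups <= 32, hence GGML_ASSERT(shmem_size <= 32).
    float total = 0.0f;
    for (int g = 0; g < n_groups; ++g) {
        total += shared[g];
    }
    return total;
}

int main(void) {
    std::vector<float> v(128, 1.0f); // d_state = 128: 4 simd groups
    printf("%f\n", two_level_sum(v)); // 128.0
    return 0;
}
```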

// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
@@ -1700,10 +1724,16 @@ kernel void kernel_ssm_scan_f32(
device const void * src5,
device const void * src6,
device float * dst,
+ threadgroup float * shared [[threadgroup(0)]],
constant ggml_metal_kargs_ssm_scan & args,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgptg[[simdgroups_per_threadgroup]],
+ uint3 tgpg[[threadgroups_per_grid]]) {

+ const int64_t i0 = tpitg.x;
const int64_t i1 = 0;
const int64_t ir = tgpig.x; // current head
const int64_t i3 = tgpig.y; // current seq
@@ -1718,41 +1748,88 @@ kernel void kernel_ssm_scan_f32(
const int64_t ng = args.n_group;
const int64_t n_t = args.n_seq_tokens;

- const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+ const int64_t s_off = args.s_off;

device const int32_t * ids = (device const int32_t *) src6;

- device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
- device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+ device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+ device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+ const int64_t i = i0 + i1*nc;
+ float s0 = s0_buff[i];
+ float s = s_buff[i];

+ device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
+ device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
+ device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
+ device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
+ device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+ device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);

for (int64_t i2 = 0; i2 < n_t; ++i2) {
- device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
- device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
- device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
- device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
- device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
- device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+ device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns}
+ device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
+ device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns}
+ device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns}
+ device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}

const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
const float x_dt = x[0] * dt_soft_plus;
- float sumf = 0.0f;

- for (int64_t i0 = 0; i0 < nc; ++i0) {
- const int64_t i = i0 + i1*nc;
- const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
- sumf += state * C[i0];
- s[i] = state;
- }
+ const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
+ s = state;

+ // Parallel sum: This relies on the fact that this kernel will be
+ // dispatched with each threadgroup having (d_state, 1, 1) threads which
+ // are subdivided into `sgptg` SIMD groups. The goal is to
+ // compute y = sum({state * C[i] for i in range(d_state)}).
+ // To parallelize this effectively, we first use simd_sum over each SIMD
+ // group to compute the sum of each SIMD group, then place the result in
+ // the SIMD group's indexed bucket in the shared memory. We then sum
+ // over the individual group sums to compute the final sum.

+ // Computed for each thread
+ float sumf = state * C[i0];

+ // Sum the threads in the simd group => simd sum
+ sumf = simd_sum(sumf);

- y[0] = sumf;
+ if (sgptg > 1) {

+ // Once per simd group, place the group sum into the shared buffer
+ if (tiisg == 0) {
+ shared[sgitg] = sumf;
+ }

+ // Wait for all threads in the threadgroup to reach this point. This
+ // ensures that all elements of the shared buffer are populated with the
+ // sum of the individual simd groups.
+ threadgroup_barrier(mem_flags::mem_threadgroup);

+ // For simd group 0 at indices < num simd groups, extract the shared
+ // simd sum
+ sumf = 0.0f;
+ if (sgitg == 0) {
+ if (tiisg < sgptg) {
+ sumf = shared[tiisg];
+ }
+ sumf = simd_sum(sumf);
+ if (tiisg == 0) {
+ y[0] = sumf;
+ }
+ }
+ } else if (tiisg == 0) {
+ y[0] = sumf;
+ }

+ // recurse
+ s0 = s;
}

+ // Assign the final state to the output buffer
+ s_buff[i] = s;
}
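Viewed per thread, the rewrite gives each of the `d_state` threads one column of the recurrent state, so only the output dot product needs cross-thread communication. A scalar C++ reference of the recurrence for one (head, sequence) pair, with the tensor layouts flattened into vectors; `std::log1p(std::exp(x))` is the same softplus the kernel spells as `log(1.0f + exp(x))`, and the data in `main` is made up:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Scalar reference for what each thread above computes on its own state
// column i0, for one head over n_t tokens (Mamba-1: A has one value per
// state column). Names mirror the kernel.
static void ssm_scan_ref(std::vector<float> &s,             // [d_state] running state
                         const std::vector<float> &A,       // [d_state]
                         const std::vector<float> &x,       // [n_t]
                         const std::vector<float> &dt,      // [n_t]
                         const std::vector<std::vector<float>> &B, // [n_t][d_state]
                         const std::vector<std::vector<float>> &C, // [n_t][d_state]
                         std::vector<float> &y) {           // [n_t]
    const size_t d_state = s.size();
    for (size_t i2 = 0; i2 < x.size(); ++i2) {
        // softplus with the same cutoff as the kernel
        const float dt_sp = dt[i2] <= 20.0f ? std::log1p(std::exp(dt[i2])) : dt[i2];
        const float x_dt  = x[i2] * dt_sp;
        float sumf = 0.0f;
        for (size_t i0 = 0; i0 < d_state; ++i0) {
            // per-column state update; on the GPU each i0 is one thread
            s[i0] = s[i0] * std::exp(dt_sp * A[i0]) + B[i2][i0] * x_dt;
            // this accumulation is what simd_sum + shared memory replace
            sumf += s[i0] * C[i2][i0];
        }
        y[i2] = sumf;
    }
}

int main(void) {
    const size_t d_state = 4, n_t = 2;
    std::vector<float> s(d_state, 0.0f), A(d_state, -1.0f);
    std::vector<float> x(n_t, 1.0f), dt(n_t, 0.5f), y(n_t, 0.0f);
    std::vector<std::vector<float>> B(n_t, std::vector<float>(d_state, 0.1f));
    std::vector<std::vector<float>> C(n_t, std::vector<float>(d_state, 1.0f));
    ssm_scan_ref(s, A, x, dt, B, C, y);
    printf("%f %f\n", y[0], y[1]);
    return 0;
}
```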

// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
- // TODO: optimize (e.g. by parallelizing over d_state)
kernel void kernel_ssm_scan_f32_group(
device const void * src0,
device const void * src1,
@@ -1762,10 +1839,16 @@ kernel void kernel_ssm_scan_f32_group(
device const void * src5,
device const void * src6,
device float * dst,
+ threadgroup float * shared [[threadgroup(0)]],
constant ggml_metal_kargs_ssm_scan & args,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgptg[[simdgroups_per_threadgroup]],
+ uint3 tgpg[[threadgroups_per_grid]]) {

+ const int64_t i0 = tpitg.x;
const int64_t i1 = tgpig.x;
const int64_t ir = tgpig.y; // current head
const int64_t i3 = tgpig.z; // current seq
@@ -1780,38 +1863,81 @@ kernel void kernel_ssm_scan_f32_group(
const int64_t ng = args.n_group;
const int64_t n_t = args.n_seq_tokens;

- const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+ const int64_t s_off = args.s_off;

device const int32_t * ids = (device const int32_t *) src6;

- device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
- device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+ device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+ device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+ const int64_t i = i0 + i1*nc;
+ float s0 = s0_buff[i];
+ float s = s_buff[i];

+ device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
+ device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
+ device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
+ device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
+ device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+ device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);

for (int64_t i2 = 0; i2 < n_t; ++i2) {
- device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
- device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
- device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
- device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
- device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
- device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+ device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns}
+ device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
+ device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns}
+ device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns}
+ device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}

const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
const float x_dt = x[0] * dt_soft_plus;
const float dA = exp(dt_soft_plus * A[0]);
- float sumf = 0.0f;

- for (int64_t i0 = 0; i0 < nc; ++i0) {
- const int64_t i = i0 + i1*nc;
- const float state = (s0[i] * dA) + (B[i0] * x_dt);
- sumf += state * C[i0];
- s[i] = state;
- }
+ const float state = (s0 * dA) + (B[i0] * x_dt);
+ s = state;

+ // Parallel sum: This relies on the fact that this kernel will be
+ // dispatched with each threadgroup having (d_state, 1, 1) threads which
+ // are subdivided into `sgptg` SIMD groups. The goal is to
+ // compute y = sum({state * C[i] for i in range(d_state)}).
+ // To parallelize this effectively, we first use simd_sum over each SIMD
+ // group to compute the sum of each SIMD group, then place the result in
+ // the SIMD group's indexed bucket in the shared memory. We then sum
+ // over the individual group sums to compute the final sum.

+ // Computed for each thread
+ float sumf = state * C[i0];

+ // Sum the threads in the simd group => simd sum
+ sumf = simd_sum(sumf);

+ // Once per simd group, place the group sum into the shared buffer
+ if (tiisg == 0) {
+ shared[sgitg] = sumf;
+ }

- y[0] = sumf;
+ // Wait for all threads in the threadgroup to reach this point. This
+ // ensures that all elements of the shared buffer are populated with the
+ // sum of the individual simd groups.
+ threadgroup_barrier(mem_flags::mem_threadgroup);

+ // For simd group 0 at indices < num simd groups, extract the shared
+ // simd sum
+ sumf = 0.0f;
+ if (sgitg == 0) {
+ if (tiisg < sgptg) {
+ sumf = shared[tiisg];
+ }
+ sumf = simd_sum(sumf);
+ if (tiisg == 0) {
+ y[0] = sumf;
+ }
+ }

+ // recurse
+ s0 = s;
}

+ // Assign the final state to the output buffer
+ s_buff[i] = s;
}
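A final numerical note shared by both scan kernels: the softplus of `dt` switches to the identity above 20 because `softplus(x)` is already within float resolution of `x` there, which avoids evaluating `exp` on large inputs. A quick standalone check:

```cpp
#include <cmath>
#include <cstdio>

// The dt activation used by both scan kernels above: softplus with a cutoff.
// std::log1p(std::exp(dt)) is the quantity the kernels write as log(1.0f + exp(dt)).
static float dt_softplus(float dt) {
    return dt <= 20.0f ? std::log1p(std::exp(dt)) : dt;
}

int main(void) {
    printf("%f\n",   dt_softplus(0.0f));  // ~0.693147 (= ln 2)
    printf("%.9f\n", dt_softplus(20.0f)); // 20.000000000: softplus(20) rounds to 20.0f
    printf("%f\n",   dt_softplus(25.0f)); // 25.0 (cutoff branch)
    return 0;
}
```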

kernel void kernel_rwkv_wkv6_f32(