2 changes: 0 additions & 2 deletions mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw.py
@@ -105,9 +105,7 @@ def mlstm_chunkwise_bw(
         vecN_out=vecN_out, # (B, NH, S)
         matDeltaC_last=matDeltaC_last, # (B, NH, DHQK, DHHV)
         qk_scale=qk_scale,
-        chunk_size=kernel_chunk_params.chunk_size_inter,
         eps=eps,
-        save_states_every_nth_chunk=kernel_chunk_params.save_states_every_nth_chunk,
         num_stages=num_stages_inter,
         num_warps=num_warps_inter,
     )
22 changes: 3 additions & 19 deletions mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw_recurrent.py
@@ -19,8 +19,6 @@ def mlstm_chunkwise__recurrent_bw_dC(
     vecN_out: torch.Tensor, # (B, NH, S)
     matDeltaC_last: torch.Tensor = None, # (B, NH, DHQK, DHHV)
     qk_scale: float = None,
-    chunk_size: int = 64,
-    save_states_every_nth_chunk: int = 1,
     num_warps: int | None = None,
     num_stages: int | None = None,
     eps: float = 0.0,
@@ -31,31 +29,18 @@ def mlstm_chunkwise__recurrent_bw_dC(
"""
B, NH, S, DHQK, DHHV = *matQ.shape, matDeltaH.shape[-1]
_dtype, _device = matQ.dtype, matQ.device
L = chunk_size
NC = scaM_inter.shape[-1] - 1
L = S // NC
assert is_power_of_2(L), "Chunk size must be a power of 2."
assert S % L == 0, "S must be divisible by chunk_size."
NC = S // L

assert (
save_states_every_nth_chunk > 0
), "save_states_every_nth_chunk must be positive."
assert (
save_states_every_nth_chunk <= NC
), "save_states_every_nth_chunk must be <= NC."

assert is_power_of_2(
save_states_every_nth_chunk
), f"save_states_every_nth_chunk must be a power of 2. Got {save_states_every_nth_chunk}."

if qk_scale is None:
qk_scale = DHQK**-0.5

USE_LAST_STATE = matDeltaC_last is not None

num_chunks_saved = NC // save_states_every_nth_chunk

matDeltaC_states = torch.empty(
(B, NH, (num_chunks_saved + 1) * DHQK, DHHV),
(B, NH, (NC + 1) * DHQK, DHHV),
dtype=torch.float32,
device=_device,
)
@@ -109,7 +94,6 @@ def mlstm_chunkwise__recurrent_bw_dC(
         L=L,
         siz_b_DHQK=siz_b_DHQK,
         siz_b_DHHV=siz_b_DHHV,
-        save_states_every_nth_chunk=save_states_every_nth_chunk,
         USE_LAST_STATE=USE_LAST_STATE,
         DTYPE=torch2triton_dtype(_dtype),
         EPS=eps,
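For orientation, a minimal sketch (not part of the diff) of the shape bookkeeping that mlstm_chunkwise__recurrent_bw_dC now performs: the chunk count NC is read off the saved inter-chunk max states and the chunk size L follows from the sequence length, rather than both being passed as arguments. The helper name infer_chunk_layout and the toy shapes are illustrative assumptions; the (B, NH, NC + 1) layout of scaM_inter is inferred from NC = scaM_inter.shape[-1] - 1 in the diff.

import torch


def infer_chunk_layout(matQ: torch.Tensor, scaM_inter: torch.Tensor) -> tuple[int, int]:
    """Hypothetical helper: derive (NC, L) from tensor shapes, as the updated
    backward pass does, instead of taking an explicit chunk_size argument.

    matQ:       (B, NH, S, DHQK)
    scaM_inter: (B, NH, NC + 1)  # assumed layout: one running max per chunk boundary
    """
    S = matQ.shape[2]
    NC = scaM_inter.shape[-1] - 1  # number of chunks
    L = S // NC                    # chunk size recovered from S and NC
    assert L > 0 and (L & (L - 1)) == 0, "Chunk size must be a power of 2."
    assert S % L == 0, "S must be divisible by chunk_size."
    return NC, L


# toy example: S=256 split into NC=4 chunks of L=64
q = torch.randn(1, 2, 256, 64)
m = torch.randn(1, 2, 4 + 1)
print(infer_chunk_layout(q, m))  # -> (4, 64)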
36 changes: 17 additions & 19 deletions mlstm_kernels/triton/chunkwise/xl_chunk/bw_kernel_recurrent.py
@@ -50,7 +50,6 @@ def mlstm_chunkwise__recurrent_bw_dC_kernel(
     L: tl.constexpr,
     siz_b_DHQK: tl.constexpr,
     siz_b_DHHV: tl.constexpr,
-    save_states_every_nth_chunk: tl.constexpr,
     USE_LAST_STATE: tl.constexpr,
     DTYPE: tl.constexpr = tl.float32,
     EPS: tl.constexpr = 1e-6,
@@ -100,24 +99,23 @@ def mlstm_chunkwise__recurrent_bw_dC_kernel(
             order=(1, 0),
         )
         # ? end pointers
-        if k % save_states_every_nth_chunk == 0:
-            idx_k_save = k // save_states_every_nth_chunk
-            # * store matDeltaC_k_val from previous iteration in HBM
-            matDeltaCstates_k_ptr = tl.make_block_ptr(
-                base=matDeltaC_states
-                + idx_b_NH * str_matDeltaC_states_B_NH
-                + idx_k_save * DHQK * DHHV,
-                shape=(DHQK, DHHV),
-                strides=(str_matDeltaC_states_NCDHQK, str_matDeltaC_states_DHHV),
-                offsets=(idx_b_DHQK * siz_b_DHQK, idx_b_DHHV * siz_b_DHHV),
-                block_shape=(siz_b_DHQK, siz_b_DHHV),
-                order=(1, 0),
-            )
-            tl.store(
-                matDeltaCstates_k_ptr,
-                matDeltaC_k_val.to(tl.float32),
-                boundary_check=(0, 1),
-            )
+
+        # * store matDeltaC_k_val from previous iteration in HBM
+        matDeltaCstates_k_ptr = tl.make_block_ptr(
+            base=matDeltaC_states
+            + idx_b_NH * str_matDeltaC_states_B_NH
+            + k * DHQK * DHHV,
+            shape=(DHQK, DHHV),
+            strides=(str_matDeltaC_states_NCDHQK, str_matDeltaC_states_DHHV),
+            offsets=(idx_b_DHQK * siz_b_DHQK, idx_b_DHHV * siz_b_DHHV),
+            block_shape=(siz_b_DHQK, siz_b_DHHV),
+            order=(1, 0),
+        )
+        tl.store(
+            matDeltaCstates_k_ptr,
+            matDeltaC_k_val.to(tl.float32),
+            boundary_check=(0, 1),
+        )

         # * compute matDeltaC_km1_val
         # load scaG_k, vecB_k, scaM_inter_km1, scaM_inter_k, vecM_combine_k
40 changes: 40 additions & 0 deletions tests/torch/chunkwise/test_chunkwise_triton_xl_chunk.py
@@ -54,6 +54,46 @@ def test_triton_chunkwise_xl_chunk_vs_native_parallel_stablef_fp32(
 )


+@pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available.")
+@pytest.mark.parametrize(["S", "B", "NH", "DHQK", "DHHV"], final_combinations)
+def test_inter_vs_intra_chunks(S, B, NH, DHQK, DHHV):
+    torch.manual_seed(2025)
+    q = torch.randn(B, NH, S, DHQK, device="cuda", requires_grad=True)
+    k = torch.randn(B, NH, S, DHQK, device="cuda", requires_grad=True)
+    v = torch.randn(B, NH, S, DHHV, device="cuda", requires_grad=True)
+    i = torch.randn(B, NH, S, device="cuda", requires_grad=True)
+    f = torch.randn(B, NH, S, device="cuda", requires_grad=True)
+
+    h_ref = mlstm_chunkwise__xl_chunk(
+        q, k, v, i, f,
+        chunk_size=128, chunk_size_inter=128, chunk_size_intra=128,
+        siz_b_L_parallel=64, siz_b_L_loop=64,
+        siz_b_DH_parallel=DHHV, siz_b_DH_loop=DHHV,
+    )
+
+    dh = torch.randn_like(h_ref)
+    dq_ref, dk_ref, dv_ref, di_ref, df_ref = torch.autograd.grad(
+        [h_ref], [q, k, v, i, f], [dh]
+    )
+
+    h = mlstm_chunkwise__xl_chunk(
+        q, k, v, i, f,
+        chunk_size=128, chunk_size_inter=64, chunk_size_intra=128,
+        siz_b_L_parallel=64, siz_b_L_loop=64,
+        siz_b_DH_parallel=DHHV, siz_b_DH_loop=DHHV
+    )
+    dq, dk, dv, di, df = torch.autograd.grad(
+        [h], [q, k, v, i, f], [dh]
+    )
+
+    torch.testing.assert_close(h, h_ref)
+    torch.testing.assert_close(dq, dq_ref)
+    torch.testing.assert_close(dk, dk_ref)
+    torch.testing.assert_close(dv, dv_ref)
+    torch.testing.assert_close(di, di_ref)
+    torch.testing.assert_close(df, df_ref)
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available.")
 def test_state_passing(mlstm_state_passing_test, state_passing_qkvif):
     num_chunks = state_passing_qkvif[0].shape[2] // 64 # <- chunk size = 64
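To run only the new consistency test locally, a usage sketch: the file path and test name come from the diff above, the pytest options are standard, and a CUDA device is required (otherwise the test is skipped).

import pytest

# run only the new inter-vs-intra chunk-size consistency test
exit_code = pytest.main([
    "tests/torch/chunkwise/test_chunkwise_triton_xl_chunk.py",
    "-k", "test_inter_vs_intra_chunks",
])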