Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 48 additions & 46 deletions examples/models/voxtral_realtime/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,13 +773,24 @@ class EncoderRingKVCache(nn.Module):
Negative results indicate unwritten slots.
"""

def __init__(
    self,
    window_size: int,
    n_heads: int,
    head_dim: int,
    dtype: torch.dtype = torch.float32,
    device: torch.device | None = None,
):
    """Allocate the ring-buffer K/V caches.

    The diff residue here contained both the old and the new signature;
    this keeps the post-change version (dtype/device-aware buffers).

    Args:
        window_size: Sliding attention window length. The internal buffer
            is 2x this size so writes can wrap around safely.
        n_heads: Number of attention heads.
        head_dim: Per-head embedding dimension.
        dtype: Element dtype for the cache buffers (default float32);
            pass the encoder's dtype so buffers match at runtime.
        device: Device for the cache buffers; ``None`` uses the default.
    """
    super().__init__()
    self.window_size = window_size
    # Buffer is 2x the window internally for safe wraparound.
    self.buf_size = window_size * 2
    # Layout: [B=1, S(buf), H, D] — matches the encoder's convention.
    cache_shape = (1, self.buf_size, n_heads, head_dim)
    self.register_buffer(
        "k_cache", torch.zeros(cache_shape, dtype=dtype, device=device)
    )
    self.register_buffer(
        "v_cache", torch.zeros(cache_shape, dtype=dtype, device=device)
    )

def update(
self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
Expand All @@ -796,37 +807,24 @@ def update(
def create_causal_mask(
self, start_pos: torch.Tensor | int, seq_len: int
) -> torch.Tensor:
device = (
start_pos.device
if isinstance(start_pos, torch.Tensor)
else self.k_cache.device
)
device = self.k_cache.device
total_written = start_pos + seq_len
j = torch.arange(self.buf_size, dtype=torch.long, device=device)
cache_pos = j + ((total_written - 1 - j) // self.buf_size) * self.buf_size
pos_q = (
start_pos + torch.arange(seq_len, dtype=torch.long, device=device)
).view(-1, 1)
q_offsets = torch.arange(seq_len, dtype=torch.long, device=device)
pos_q = (start_pos + q_offsets).view(-1, 1)
delta = pos_q - cache_pos.unsqueeze(0)
valid = (cache_pos >= 0) & (delta >= 0) & (delta < self.window_size)
return torch.where(valid, 0.0, float("-inf"))


class StandardEncoderRingKVCache(nn.Module):
"""Export-friendly ring buffer KV cache using index_copy_ for updates.
class StandardEncoderRingKVCache(EncoderRingKVCache):
"""EncoderRingKVCache variant that uses index_copy_ instead of
torch.ops.llama.update_cache_with_indices for the ring buffer update.

Compatible with torch.export and AOTI. Uses [B, S, H, D] layout
matching the encoder's convention. Ring buffer enables unlimited streaming.
Compatible with torch.export and AOTI where the custom llama op is unavailable.
"""

def __init__(self, window_size: int, n_heads: int, head_dim: int):
super().__init__()
self.window_size = window_size
self.buf_size = window_size * 2
cache_shape = (1, self.buf_size, n_heads, head_dim)
self.register_buffer("k_cache", torch.zeros(cache_shape))
self.register_buffer("v_cache", torch.zeros(cache_shape))

def update(
self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
Expand All @@ -847,25 +845,6 @@ def update(

return self.k_cache, self.v_cache

def create_causal_mask(self, start_pos: torch.Tensor, seq_len: int) -> torch.Tensor:
"""Create sliding window attention mask for ring buffer.

Args:
start_pos: Tensor containing the starting position (scalar tensor)
seq_len: Number of query positions
"""
total_written = start_pos + seq_len
j = torch.arange(self.buf_size, dtype=torch.long, device=start_pos.device)
cache_pos = j + ((total_written - 1 - j) // self.buf_size) * self.buf_size

# Query positions using tensor operations
q_offsets = torch.arange(seq_len, dtype=torch.long, device=start_pos.device)
pos_q = (start_pos + q_offsets).view(-1, 1)

delta = pos_q - cache_pos.unsqueeze(0)
valid = (cache_pos >= 0) & (delta >= 0) & (delta < self.window_size)
return torch.where(valid, 0.0, float("-inf"))


class StreamingAudioEncoderExport(nn.Module):
"""Streaming encoder: processes one 8-mel-frame chunk at a time.
Expand Down Expand Up @@ -893,22 +872,45 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750):
self.n_heads = config.enc_n_heads
self.head_dim = config.enc_head_dim

# Infer dtype/device from encoder weights so all buffers match at runtime.
enc_param = next(model.encoder.parameters())

# Register conv states as buffers (mutable state for streaming)
self.register_buffer("conv1_state", torch.zeros(1, config.num_mel_bins, 2))
self.register_buffer("conv2_state", torch.zeros(1, config.enc_dim, 2))
self.register_buffer(
"conv1_state",
torch.zeros(
1,
config.num_mel_bins,
2,
dtype=enc_param.dtype,
device=enc_param.device,
),
)
self.register_buffer(
"conv2_state",
torch.zeros(
1, config.enc_dim, 2, dtype=enc_param.dtype, device=enc_param.device
),
)

# Ring buffer KV caches for unlimited streaming.
# Window size = max_enc_len (encoder sliding window from params.json).
# Buffer is 2x internally for safe wraparound.
# Choose cache implementation based on backend
# Choose cache implementation based on backend.
cache_class = (
StandardEncoderRingKVCache
if config.use_standard_attention
else EncoderRingKVCache
)
self.kv_caches = nn.ModuleList(
[
cache_class(max_enc_len, config.enc_n_heads, config.enc_head_dim)
cache_class(
max_enc_len,
config.enc_n_heads,
config.enc_head_dim,
dtype=enc_param.dtype,
device=enc_param.device,
)
for _ in range(config.enc_n_layers)
]
)
Expand Down
Loading