Merged
12 changes: 6 additions & 6 deletions examples/models/voxtral_realtime/export_voxtral_rt.py
@@ -12,15 +12,16 @@
- token_embedding: token_ids (1, seq_len) -> embeds (1, seq_len, 3072)

With --streaming, produces a streaming .pte instead:
- encode_audio_chunk: mel_chunk (1,128,8) + conv states + enc_pos -> audio_embeds + new states
- encode_audio_chunk: mel_chunk (1,128,8) + enc_pos (4,) -> audio_embeds (1,1,3072)
- text_decoder: same as above
- token_embedding: same as above

Backend support:
- XNNPACK (default): Uses custom SDPA op (torch.ops.llama.custom_sdpa) for optimal performance
- Metal/AOTI: Automatically switches to standard PyTorch SDPA (F.scaled_dot_product_attention)
for text_decoder to avoid AOTI compilation issues. Uses Dim.AUTO for audio encoder
dynamic shapes (explicit bounds cause issues with AOTI). All components run on Metal GPU.
- Metal/AOTI: Uses MetalSDPA (_scaled_dot_product_attention_math_for_mps) for text_decoder
and StandardEncoderSDPA (F.scaled_dot_product_attention) for streaming encoder,
avoiding custom_sdpa which is incompatible with AOTI. Uses Dim.AUTO for audio
encoder dynamic shapes (explicit bounds cause issues with AOTI).
- Portable: Uses custom SDPA like XNNPACK

Usage:
@@ -475,12 +476,11 @@ def main():

# Load model
print("Loading model...")
use_standard_attention = args.backend == "metal"
model = load_model(
args.model_path,
max_seq_len=args.max_seq_len,
n_delay_tokens=args.delay_tokens,
use_standard_attention=use_standard_attention,
backend=args.backend,
)

# Untie output/embedding weights before quantization so each layer gets
74 changes: 38 additions & 36 deletions examples/models/voxtral_realtime/model.md
@@ -41,8 +41,9 @@ The model exports three methods (offline mode):
| `token_embedding` | token IDs `(1, seq_len)` | embeddings `(1, seq_len, 3072)` |

With `--streaming`, `audio_encoder` is replaced by `encode_audio_chunk`
which takes a mel chunk `(1, 128, 8)` + conv states + encoder positions
and returns audio embeddings `(1, 1, 3072)` + updated conv states.
which takes a mel chunk `(1, 128, 8)` + encoder positions `(4,)` and
returns audio embeddings `(1, 1, 3072)`. Conv states are maintained as
internal buffers.

Audio and text embeddings are **summed** at each position (not concatenated
or masked-scatter like the original non-realtime Voxtral).
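The summation fusion can be sketched in a few lines (illustrative shapes only; the real embeddings come from the exported `token_embedding` and `encode_audio_chunk` methods):

```python
import torch

# Illustrative shapes only; real embeddings come from token_embedding
# and encode_audio_chunk.
text_embeds = torch.randn(1, 4, 3072)
audio_embeds = torch.randn(1, 4, 3072)

# Realtime Voxtral fuses the modalities by elementwise addition at each
# position, instead of the masked-scatter used by offline Voxtral.
fused = text_embeds + audio_embeds
```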
@@ -101,21 +102,21 @@ VoxtralRealtimeModel
attention: LMAttention
wq/wk/wv/wo: Linear (no bias)
kv_cache: KVCache (XNNPACK) or StaticKVCache (Metal)
sdpa: SDPA (XNNPACK) or StandardSDPA (Metal)
sdpa: SDPA (XNNPACK) or MetalSDPA (Metal)
ffn_norm: RMSNorm
ada_rms_norm_t_cond: Sequential(Linear, GELU, Linear)
feed_forward: LMMLP (w1/w2/w3)
norm: RMSNorm
output: Linear (tied to tok_embeddings)

StreamingAudioEncoderExport (XNNPACK/Portable only)
StreamingAudioEncoderExport
conv1: nn.Conv1d (shared from encoder.conv_layers[0].conv)
conv2: nn.Conv1d (shared from encoder.conv_layers[1].conv)
layers: 32x CausalEncoderLayer (shared from encoder.layers)
enc_norm: RMSNorm (shared from encoder.norm)
adapter: AudioLanguageAdapter (shared from model.adapter)
kv_caches: 32x EncoderRingKVCache (ring buffer for sliding window attention)
sdpa: SDPA (for streaming attention with custom_sdpa op)
kv_caches: 32x EncoderRingKVCache (XNNPACK) or StandardEncoderRingKVCache (Metal)
sdpa: SDPA (XNNPACK) or StandardEncoderSDPA (Metal)
inv_freq: RoPE inverse frequencies (owned, on-the-fly computation)
```

@@ -137,11 +138,8 @@ than 750 encoder frames (~15s), full causal is equivalent.

The text decoder (`MistralDecoder`) is a 26-layer Mistral decoder with
GQA (32 query heads, 8 KV heads). Backend selection is controlled by the
`use_standard_attention` config flag, set by the export script:

```python
use_standard_attention = (args.backend == "metal")
```
`backend` config field, passed through from the export script's `--backend`
flag (e.g., `"xnnpack"`, `"metal"`, `"portable"`).

### KV cache

@@ -164,9 +162,10 @@ backend-specific implementations.
fused kernel with causal masking via `start_pos` + `is_causal=True`.
Handles GQA expansion internally and upcasts to float32.

**Metal:** `StandardSDPA` uses `F.scaled_dot_product_attention` with
explicit attention masks. AOTInductor has compatibility issues with the
`custom_sdpa` custom op.
**Metal:** `MetalSDPA` uses `torch.ops.aten._scaled_dot_product_attention_math_for_mps`
which handles GQA natively via `gqa_factor`, avoiding the memory bandwidth
overhead of `repeat_interleave`. Uses explicit additive attention masks.
AOTInductor has compatibility issues with the `custom_sdpa` custom op.

### Attention layout

@@ -178,8 +177,9 @@ explicit attention masks. AOTInductor has compatibility issues with the
require when using `[B, H, S, D]` attention with `[B, S, H, D]` cache.

**Metal:** Q/K/V projections still produce `[B, T, H, D]`, but
`StaticKVCache` stores `[B, H, S, D]` and `StandardSDPA` transposes q to
`[B, H, T, D]` for `F.scaled_dot_product_attention`, then transposes back.
`StaticKVCache` stores `[B, H, S, D]` and `MetalSDPA` transposes q to
`[B, H, T, D]` for `_scaled_dot_product_attention_math_for_mps`, then
transposes back.
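The transpose round-trip can be sketched as below — a minimal standalone illustration with equal Q/KV head counts so GQA expansion is elided, and with `F.scaled_dot_product_attention` standing in for the MPS math kernel the repo actually uses:

```python
import torch
import torch.nn.functional as F

B, T, H, D = 1, 4, 32, 128  # illustrative sizes
S = 16                      # cached key/value sequence length

q = torch.randn(B, T, H, D)  # projections produce [B, T, H, D]
k = torch.randn(B, H, S, D)  # StaticKVCache stores [B, H, S, D]
v = torch.randn(B, H, S, D)

# Transpose q to [B, H, T, D] for SDPA, then transpose back and flatten
# the heads for the output projection.
y = F.scaled_dot_product_attention(q.transpose(1, 2), k, v)
y = y.transpose(1, 2).reshape(B, T, H * D)
```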

### Adaptive RMSNorm

@@ -205,28 +205,25 @@ mel at once. It shares all weights with the offline encoder but uses a
different forward path:

```
mel_chunk (1, 128, 8)
+ conv1_state (1, 128, 2) + conv2_state (1, 1280, 2)
mel_chunk (1, 128, 8) + enc_input_pos (4,)
conv1_state (1, 128, 2) and conv2_state (1, 1280, 2) are internal buffers
-> cat(state, chunk) -> raw Conv1d (no CausalConv1d padding) -> GELU
-> cat(state, conv1_out) -> raw Conv1d -> GELU
(1, 1280, 4) -> transpose -> (1, 4, 1280)
-> 32x streaming encoder layer (EncoderRingKVCache + custom_sdpa)
-> 32x streaming encoder layer (ring KV cache + SDPA)
-> RMSNorm
(1, 4, 1280)
-> Reshape downsample (1, 1, 5120) -> Adapter (1, 1, 3072)
-> audio_embeds, new_conv1_state, new_conv2_state
-> audio_embeds (1, 1, 3072)
```
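The conv-state handoff for the first convolution can be sketched as below (a hypothetical standalone snippet, assuming Whisper-style kernel size 3 and stride 1 for conv1; in the real model the state lives in internal buffers of `StreamingAudioEncoderExport`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical standalone sketch; state width 2 = kernel_size - 1 frames
# of left context carried between chunks.
conv1 = nn.Conv1d(128, 1280, kernel_size=3, stride=1)
mel_chunk = torch.randn(1, 128, 8)
conv1_state = torch.zeros(1, 128, 2)             # persistent between chunks

x = torch.cat([conv1_state, mel_chunk], dim=-1)  # (1, 128, 10)
out = F.gelu(conv1(x))                           # (1, 1280, 8): no extra padding
new_conv1_state = x[:, :, -2:]                   # last 2 input frames for next chunk
```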

**XNNPACK/Portable only.** Metal does not yet support streaming mode.
The custom ops used by `StreamingAudioEncoderExport`
(`update_cache_with_indices`, `custom_sdpa`) are incompatible with AOTI.
Adding Metal streaming support would require:
**XNNPACK/Portable:** Uses `EncoderRingKVCache` (`update_cache_with_indices`
custom op) and `SDPA` (`custom_sdpa`).

- Replace `EncoderRingKVCache` with an `index_copy_`-based ring buffer
(similar to `StaticKVCache` but with modular index arithmetic)
- Replace `SDPA` (`custom_sdpa`) with `StandardSDPA` using explicit
sliding window masks
- These are the same patterns already used in the Metal text decoder
**Metal:** Uses `StandardEncoderRingKVCache` (`index_copy_`-based ring
buffer) and `StandardEncoderSDPA` (`F.scaled_dot_product_attention` with
explicit sliding window masks) — the same patterns used in the Metal
text decoder.
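An explicit additive sliding-window mask of the kind the Metal path needs can be sketched as follows (an illustrative helper, not the repo's code): a query at absolute position `pos` may attend only to keys within the last `window` positions.

```python
import torch

def sliding_window_mask(pos: int, window: int, kv_len: int) -> torch.Tensor:
    """Additive mask for a single query step: 0 where attention is
    allowed, -inf where it is masked."""
    key_pos = torch.arange(kv_len)
    visible = (key_pos <= pos) & (key_pos > pos - window)
    mask = torch.zeros(kv_len)
    mask[~visible] = float("-inf")
    return mask  # added to attention scores before softmax

m = sliding_window_mask(pos=5, window=3, kv_len=8)  # positions 3, 4, 5 visible
```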

### Streaming decode loop

@@ -258,9 +255,10 @@ encoder — verified to within fp32 precision (max diff < 2e-5).

### Encoder KV cache

Each of the 32 encoder transformer layers gets its own `EncoderRingKVCache`
instance — a ring buffer that overwrites old entries when the window is
exceeded, enabling streaming of arbitrary length audio.
Each of the 32 encoder transformer layers gets its own ring-buffer KV
cache (`EncoderRingKVCache` for XNNPACK/Portable, `StandardEncoderRingKVCache`
for Metal) that overwrites old entries when the window is exceeded,
enabling streaming of arbitrary-length audio.

- Cache shape: `(1, 2*max_enc_len, 32, 64)` per layer. The buffer is 2x the
window size because writes happen *before* attention. With a 1x buffer
@@ -281,10 +279,14 @@ ring buffer. This is unrelated to `max_enc_len=16384` in
`CausalWhisperEncoder.__init__`, which is the RoPE frequency table size
for the offline encoder.

Cache writes use `torch.ops.llama.update_cache_with_indices` (a custom op
that scatter-writes via an indices tensor). Write indices are computed
analytically: `(arange(seq_len) + start_pos) % buf_size`. No mutable
position state is needed.
**XNNPACK/Portable:** Cache writes use `torch.ops.llama.update_cache_with_indices`
(a custom op that scatter-writes via an indices tensor). Write indices are
computed analytically: `(arange(seq_len) + start_pos) % buf_size`.

**Metal:** Cache writes use `index_copy_` with wrapped indices
(`input_pos % buf_size`).

No mutable position state is needed in either variant.
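The analytic index computation can be sketched as below, with `index_copy_` standing in for the `update_cache_with_indices` custom op so the snippet runs without ExecuTorch:

```python
import torch

buf_size, seq_len, start_pos = 8, 3, 6

# Write indices wrap past the end of the ring buffer; no position
# counter is stored, only start_pos is passed in.
indices = (torch.arange(seq_len) + start_pos) % buf_size  # slots 6, 7, 0

cache = torch.zeros(1, buf_size, 2, 4)  # (B, slots, kv_heads, head_dim)
new_kv = torch.randn(1, seq_len, 2, 4)
cache.index_copy_(1, indices, new_kv)   # stand-in for the custom op
```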

Position tracking is analytic — no mutable state buffer. For buffer
slot `j` after `total_written` frames have been stored:
31 changes: 17 additions & 14 deletions examples/models/voxtral_realtime/model.py
@@ -50,9 +50,7 @@ class VoxtralRealtimeConfig:
downsample_factor: int = 4
# Runtime
max_seq_len: int = 4096
use_standard_attention: bool = (
False # Use standard PyTorch attention instead of custom ops
)
backend: str = "xnnpack" # "xnnpack", "metal", or "portable"

@staticmethod
def from_params_json(path: str) -> "VoxtralRealtimeConfig":
@@ -563,15 +561,15 @@ def __init__(self, config: VoxtralRealtimeConfig, max_seq_len: int):
self.n_kv_heads = config.n_kv_heads
self.head_dim = config.head_dim
self.dim = config.dim
self.use_standard_attention = config.use_standard_attention
self.backend = config.backend

self.wq = nn.Linear(config.dim, self.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(config.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(config.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(self.n_heads * self.head_dim, config.dim, bias=False)

# Choose KV cache and SDPA based on backend
if self.use_standard_attention:
if self.backend == "metal":
self.kv_cache = StaticKVCache(max_seq_len, self.n_kv_heads, self.head_dim)
self.sdpa = MetalSDPA(self.n_heads, self.n_kv_heads, self.head_dim)
else:
@@ -595,7 +593,7 @@ def forward(

k, v = self.kv_cache.update(input_pos, k, v)

if self.use_standard_attention:
if self.backend == "metal":
y = self.sdpa(input_pos, q, k, v, B, T, attn_mask)
else:
y = self.sdpa(input_pos, q, k, v, B, T)
@@ -685,7 +683,7 @@ def forward(

# Compute attention mask once for all 26 layers (P3 optimization).
attn_mask: torch.Tensor | None = None
if self.config.use_standard_attention:
if self.config.backend == "metal":
max_seq_len = self.freqs_cos.shape[0]
attn_mask = _build_attn_mask(input_pos, max_seq_len, input_embeds.device)

@@ -909,7 +907,7 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750):
# Choose cache implementation based on backend
cache_class = (
StandardEncoderRingKVCache
if config.use_standard_attention
if config.backend == "metal"
else EncoderRingKVCache
)
self.kv_caches = nn.ModuleList(
@@ -920,7 +918,7 @@
)

# Choose SDPA based on backend
if config.use_standard_attention:
if config.backend == "metal":
self.sdpa = StandardEncoderSDPA(config.enc_n_heads, config.enc_head_dim)
else:
self.sdpa = SDPA(config.enc_n_heads, config.enc_head_dim)
@@ -1067,7 +1065,7 @@ def load_model(
max_seq_len: int = 4096,
n_delay_tokens: int = 6,
dtype: torch.dtype = torch.float32,
use_standard_attention: bool = False,
backend: str = "xnnpack",
) -> VoxtralRealtimeModel:
"""Load VoxtralRealtimeModel from a Mistral consolidated checkpoint.

@@ -1080,20 +1078,25 @@ max_seq_len: Maximum sequence length for KV cache.
max_seq_len: Maximum sequence length for KV cache.
n_delay_tokens: Transcription delay in tokens (default 6 = 480ms).
dtype: Weight dtype (default: float32).
use_standard_attention: Use standard PyTorch attention instead of custom ops
(required for Metal/AOTI backends).
backend: Backend for acceleration ("xnnpack", "metal", or "portable").
"""
_VALID_BACKENDS = ("xnnpack", "metal", "portable")
if backend not in _VALID_BACKENDS:
raise ValueError(
f"Unknown backend '{backend}'. Must be one of {_VALID_BACKENDS}."
)

from safetensors import safe_open

model_dir = Path(model_path)
config = VoxtralRealtimeConfig.from_params_json(str(model_dir / "params.json"))
config.max_seq_len = max_seq_len
config.use_standard_attention = use_standard_attention
config.backend = backend

print(
f"Building model on meta device (dim={config.dim}, enc_dim={config.enc_dim}, "
f"layers={config.n_layers}, enc_layers={config.enc_n_layers}, "
f"attention={'standard' if use_standard_attention else 'custom'})..."
f"backend={backend})..."
)
with torch.device("meta"):
model = VoxtralRealtimeModel(config, max_seq_len)