[MagpieTTS][bugfix] reset kv cache for longform inference and add missing utmosv2 score #15385

Closed
XuesongYang wants to merge 2 commits intoNVIDIA-NeMo:mainfrom
XuesongYang:xueyang/pr-bugfix-stale-kvcache

Conversation

@XuesongYang
Collaborator

Summary

Two inference bugfixes for MagpieTTS.

1. Reset KV cache at start of longform inference batch

generate_long_form_speech never resets the decoder KV cache. When the inference script
processes multiple datasets sequentially (e.g., a non-longform dataset followed by a longform
dataset), the prior generate_speech call leaves use_cache=True with populated tensors.
The longform path then inherits this stale cache, causing a RuntimeError ("Sizes of tensors must match") in the torch.cat that concatenates cached and new keys/values during self-attention.

Fix: call reset_cache(use_cache=self.model.use_kv_cache_for_inference) at the start of each
longform batch in _run_longform_inference, matching the pattern used by infer_batch.
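The failure mode and the fix can be sketched with a toy decoder. All names below are illustrative stand-ins, not the actual MagpieTTS API; only the reset_cache(use_cache=...) call mirrors the pattern this PR adds to _run_longform_inference:

```python
# Toy sketch of the stale-KV-cache bug and the fix (hypothetical names).
# Lists stand in for tensors; the "+" concatenation stands in for
# torch.cat([self.cache['self_k'], k], dim=1) in transformer_2501.py.

class ToyDecoder:
    def __init__(self):
        self.use_cache = False
        self.cache = {"self_k": None}

    def reset_cache(self, use_cache):
        # Clear any populated entries and record the caching mode --
        # the call infer_batch already makes and longform was missing.
        self.use_cache = use_cache
        self.cache = {"self_k": None}

    def step(self, k_new):
        # k_new: this step's keys. With a stale cache from a *different*
        # dataset, the real torch.cat raises the size-mismatch error;
        # here the stale entries would silently corrupt the sequence.
        if self.use_cache:
            if self.cache["self_k"] is not None:
                k_new = self.cache["self_k"] + k_new
            self.cache["self_k"] = k_new
        return k_new

dec = ToyDecoder()

# A non-longform generate_speech call populates the cache...
dec.reset_cache(use_cache=True)
dec.step([1] * 6)          # cache now holds 6 positions

# ...and the fix makes each longform batch start from a clean cache
# instead of inheriting those 6 stale positions:
dec.reset_cache(use_cache=True)
out = dec.step([2] * 64)
print(len(out))            # 64: only this batch's keys, no stale prefix
```

Without the second reset_cache call, the step would see 6 cached positions plus 64 new ones, which is exactly the "Expected size 6 but got size 64" mismatch in the traceback below.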

Error Details:

[NeMo I 2026-02-11 03:24:13 inference:317] Using longform inference path
[NeMo I 2026-02-11 03:24:13 inference:459] Cleaning up old generated files in: /results/moe16_sinkhorn_top1_valLoss5.0469_step2625132_epoch524_decoder-MoE_16x1_d3072_sinkhorn_Temp0.7_Topk80_Cfg_True_2.5_Prior_True_0.1_5_0_None_None_LT_False_MaskGit_3_None_None_EOS_argmax_or_multinomial_any_IgnoreFST_False_SV_titanet_libritts_seen/audio/repeat_0
[NeMo I 2026-02-11 03:24:14 inference:602] Processing batch 1/6 (longform)
[NeMo I 2026-02-11 03:24:15 magpietts:4621] Longform decoding timestep 0
Traceback (most recent call last):
  File "/code/examples/tts/magpietts_inference.py", line 668, in <module>
    main()
  File "/code/examples/tts/magpietts_inference.py", line 638, in main
    cer, ssim = run_inference_and_evaluation(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/code/examples/tts/magpietts_inference.py", line 257, in run_inference_and_evaluation
    rtf_metrics_list, _, codec_file_paths = runner.run_inference_on_dataset(
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/code/nemo/collections/tts/modules/magpietts_inference/inference.py", line 318, in run_inference_on_dataset
    return self._run_longform_inference(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/code/nemo/collections/tts/modules/magpietts_inference/inference.py", line 646, in _run_longform_inference
    output = self.model.generate_long_form_speech(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/code/nemo/collections/tts/models/magpietts.py", line 4650, in generate_long_form_speech
    all_code_logits, attn_probs, dec_out = self._run_longform_forward_with_cfg(
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/code/nemo/collections/tts/models/magpietts.py", line 4321, in _run_longform_forward_with_cfg
    combined_logits, attn_probs, dec_out, _ = self.forward(
                                              ^^^^^^^^^^^^^
  File "/code/nemo/collections/tts/models/magpietts.py", line 1262, in forward
    decoder_out = self.decoder(
                  ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/code/nemo/collections/tts/modules/transformer_2501.py", line 826, in forward
    out_dict = layer(x, x_mask, _cond, _cond_mask, attn_prior=_attn_prior)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/code/nemo/collections/tts/modules/transformer_2501.py", line 577, in forward
    x_, s_attn_prob = self.self_attention(query=self.norm_self(x), query_mask=x_mask)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/code/nemo/collections/tts/modules/transformer_2501.py", line 300, in forward
    y, attn_prob = self.attn_naive(query, query_mask, memory, memory_mask, attn_prior)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/code/nemo/collections/tts/modules/transformer_2501.py", line 222, in attn_naive
    q, k, v, mask = self.compute_qkv_and_mask(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/code/nemo/collections/tts/modules/transformer_2501.py", line 358, in compute_qkv_and_mask
    k = torch.cat([self.cache['self_k'], k], dim=1)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 6 but got size 64 for tensor number 1 in the list.

2. Save filewise utmosv2 score in evaluation output

The utmosv2 metric was computed per file but not included in the saved filewise metrics
JSON, so downstream visualization (box plots) could not display MOS scores.

Fix: add 'utmosv2' to filewise_metrics_keys_to_save in evaluate_generated_audio.py.
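The effect of the change can be sketched as follows. The key list mirrors the review hunk and the PR description, but the surrounding record-filtering code is an illustrative assumption, not the actual evaluate_generated_audio.py implementation:

```python
import json

# Sketch (hypothetical surrounding code): only whitelisted keys from each
# per-file metrics record are written to the filewise metrics JSON.
filewise_metrics_keys_to_save = [
    'gt_audio_filepath',
    'pred_audio_filepath',
    'context_audio_filepath',
    'utmosv2',   # the key this PR adds, so MOS reaches the saved JSON
]

# One per-file record as the evaluator might assemble it (values made up).
record = {
    'gt_audio_filepath': '/data/gt.wav',
    'pred_audio_filepath': '/results/pred.wav',
    'context_audio_filepath': '/data/ctx.wav',
    'utmosv2': 3.94,
    'internal_debug_state': object(),   # not JSON-serializable; filtered out
}

saved = {k: record[k] for k in filewise_metrics_keys_to_save if k in record}
print(json.dumps(saved, indent=2))
```

With 'utmosv2' absent from the list, the score was computed but dropped at this filtering step, which is why the downstream box plots had nothing to display.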

Error Details:
[Screenshot: 2026-02-11 at 10:22 AM]

Commits:

1. …nt stale cache from prior batch or datasets
   Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
2. … display MOS.
   Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Copilot AI (Contributor) left a comment:

Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

Review thread on evaluate_generated_audio.py:

    'gt_audio_filepath',
    'pred_audio_filepath',
    'context_audio_filepath',
    'utmosv2',
Collaborator:

This will be added in #15381, please remove from yours

Collaborator (Author):

I've reviewed the other PR and don't anticipate any conflicts during a rebase. I suggest we avoid reverting the commit here. Instead, let's simply merge whichever PR is ready first, and then rebase the remaining one.

@blisc
Collaborator

blisc commented Feb 12, 2026

@subhankar-ghosh please review

@blisc blisc marked this pull request as draft February 18, 2026 20:13
@blisc
Collaborator

blisc commented Feb 18, 2026

Drafting since we plan to add this to #15375

@XuesongYang
Collaborator Author

> Drafting since we plan to add this to #15375

let's close this PR and move our discussion to that PR.
