Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions nemo/collections/asr/parts/submodules/rnnt_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,15 +310,15 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu
if self.cfg.strategy in ['greedy', 'greedy_batch']:
self.preserve_alignments = self.cfg.greedy.get('preserve_alignments', False)

elif self.cfg.strategy in ['beam', 'tsd', 'alsd', 'maes']:
elif self.cfg.strategy in ['beam', 'tsd', 'alsd', 'maes', 'malsd_batch', 'maes_batch']:
self.preserve_alignments = self.cfg.beam.get('preserve_alignments', False)

# Update compute timestamps
if self.compute_timestamps is None:
if self.cfg.strategy in ['greedy', 'greedy_batch']:
self.compute_timestamps = self.cfg.greedy.get('compute_timestamps', False)

elif self.cfg.strategy in ['beam', 'tsd', 'alsd', 'maes']:
elif self.cfg.strategy in ['beam', 'tsd', 'alsd', 'maes', 'malsd_batch', 'maes_batch']:
self.compute_timestamps = self.cfg.beam.get('compute_timestamps', False)

# Check if the model supports punctuation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,10 @@ def __init__(
self.maes_num_expansions = self.beam_size + self.maes_expansion_beta

if self.preserve_alignments:
raise NotImplementedError("Preserve alignments is not supported")
logging.warning(
"Full alignment data (per-step logprobs) is not available in batched beam search. "
"Hypothesis.alignments will be None. Timestamps are still available via compute_timestamps."
)

if allow_cuda_graphs:
logging.info("CUDA Graphs are unsupported for `maes_batch`; preceeding pure pytorch decoding")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,10 @@ def __init__(
self.allow_cuda_graphs = allow_cuda_graphs

if self.preserve_alignments:
raise NotImplementedError("Preserve alignments is not supported")
logging.warning(
"Full alignment data (per-step logprobs) is not available in batched beam search. "
"Hypothesis.alignments will be None. Timestamps are still available via compute_timestamps."
)

self.state = None
self.full_graph = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,10 @@ def __init__(
self.allow_cuda_graphs = allow_cuda_graphs

if self.preserve_alignments:
raise NotImplementedError("Preserve alignments is not supported")
logging.warning(
"Full alignment data (per-step logprobs) is not available in batched beam search. "
"Hypothesis.alignments will be None. Timestamps are still available via compute_timestamps."
)

self.state = None
self.full_graph = None
Expand Down
34 changes: 33 additions & 1 deletion nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,11 @@ def __init__(
self.next_timestamp = torch.zeros((batch_size, self.beam_size), device=device, dtype=torch.long)
self.last_timestamp_lasts = torch.zeros((batch_size, self.beam_size), device=device, dtype=torch.long)

if self.model_type == ASRModelTypeEnum.TDT:
self.token_durations = torch.zeros(
(batch_size, self.beam_size, self._max_length), device=device, dtype=torch.long
)

def clear_(self):
"""
Clears and resets the internal state of the object.
Expand Down Expand Up @@ -198,6 +203,8 @@ def clear_(self):
self.timestamps.fill_(0)
self.next_timestamp.fill_(0)
self.last_timestamp_lasts.fill_(0)
if self.model_type == ASRModelTypeEnum.TDT:
self.token_durations.fill_(0)

def _allocate_more(self):
"""
Expand All @@ -215,6 +222,10 @@ def _allocate_more(self):
self.timestamps = self._create_timestamps_tensor(2 * self._max_length)
else:
self.timestamps = torch.cat((self.timestamps, torch.zeros_like(self.timestamps)), dim=-1)
if self.model_type == ASRModelTypeEnum.TDT:
self.token_durations = torch.cat(
(self.token_durations, torch.zeros_like(self.token_durations)), dim=-1
)

self._max_length *= 2

Expand Down Expand Up @@ -299,7 +310,12 @@ def add_results_no_checks_(
self.timestamps.scatter_(
dim=-1,
index=self.current_lengths_wb.unsqueeze(-1),
src=(timesteps + next_label_durations).unsqueeze(-1),
src=timesteps.unsqueeze(-1),
)
self.token_durations.scatter_(
dim=-1,
index=self.current_lengths_wb.unsqueeze(-1),
src=next_label_durations.unsqueeze(-1),
)
torch.where(is_extended, timesteps + next_label_durations, timesteps, out=self.next_timestamp)
torch.where(
Expand Down Expand Up @@ -474,6 +490,8 @@ def to_hyps_list(self, score_norm: bool = True) -> list[Hypothesis]:
max_idx = self.current_lengths_wb.max() - 1
timestamps = self.timestamps[..., 0, : max_idx + 1]
transcripts = self.transcript_wb[..., 0, : max_idx + 1]
if self.model_type == ASRModelTypeEnum.TDT:
token_durations = self.token_durations[..., 0, : max_idx + 1]
hypotheses = [
Hypothesis(
score=scores[batch_idx],
Expand All @@ -482,6 +500,11 @@ def to_hyps_list(self, score_norm: bool = True) -> list[Hypothesis]:
.detach()
.numpy(),
timestamp=timestamps[batch_idx][mask].cpu().detach().numpy(),
token_duration=(
token_durations[batch_idx][mask].cpu().detach().numpy()
if self.model_type == ASRModelTypeEnum.TDT
else None
),
alignments=None,
dec_state=None,
)
Expand All @@ -506,6 +529,8 @@ def to_nbest_hyps_list(self, score_norm: bool = True) -> list[NBestHypotheses]:
max_idx = self.current_lengths_wb.max() - 1
transcripts = self.transcript_wb[..., : max_idx + 1]
timestamps = self.timestamps[..., : max_idx + 1]
if self.model_type == ASRModelTypeEnum.TDT:
token_durations = self.token_durations[..., : max_idx + 1]
hypotheses = [
NBestHypotheses(
[
Expand All @@ -518,6 +543,11 @@ def to_nbest_hyps_list(self, score_norm: bool = True) -> list[NBestHypotheses]:
.detach()
.numpy(),
timestamp=timestamps[batch_idx][beam_idx][mask].cpu().detach().numpy(),
token_duration=(
token_durations[batch_idx][beam_idx][mask].cpu().detach().numpy()
if self.model_type == ASRModelTypeEnum.TDT
else None
),
alignments=None,
dec_state=None,
)
Expand Down Expand Up @@ -556,6 +586,8 @@ def flatten_sort_(self, score_norm: bool = True):
self.transcript_wb[..., idx].copy_(self.transcript_wb[self.batch_indices.unsqueeze(-1), ptrs, idx])
if self.model_type == ASRModelTypeEnum.TDT or self.model_type == ASRModelTypeEnum.RNNT:
self.timestamps[..., idx].copy_(self.timestamps[self.batch_indices.unsqueeze(-1), ptrs, idx])
if self.model_type == ASRModelTypeEnum.TDT:
self.token_durations[..., idx].copy_(self.token_durations[self.batch_indices.unsqueeze(-1), ptrs, idx])
ptrs = self.transcript_wb_prev_ptr[self.batch_indices.unsqueeze(-1), ptrs, idx]
self.transcript_wb_prev_ptr[..., : max_idx + 1].copy_(self.beam_indices.unsqueeze(0).unsqueeze(-1))

Expand Down
142 changes: 142 additions & 0 deletions tests/collections/asr/decoding/test_batched_beam_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,3 +880,145 @@ def test_ctc_beam_decoding_kenlm(
check_res_best_hyps(num_samples, hyps)
hyps = decode_text_from_hypotheses(hyps, model)
print_res_best_hyps(hyps)


class TestBeamDecodingTimestamps:
    """
    Tests that MALSD batch beam decoding produces valid timestamps for both RNN-T and TDT models.

    NOTE(review): relies on fixtures defined elsewhere in this test module
    (``test_audio_filenames``, ``rnnt_model``, ``tdt_model``, the encoder-output
    fixtures, ``DEVICES``) — confirm they are in scope when moving these tests.
    """

    @pytest.mark.skipif(
        not NUMBA_RNNT_LOSS_AVAILABLE,
        reason='RNNTLoss has not been compiled with appropriate numba version.',
    )
    @pytest.mark.with_downloads
    @pytest.mark.unit
    @pytest.mark.parametrize("device", DEVICES)
    def test_rnnt_beam_decoding_with_preserve_alignments(
        self, test_audio_filenames, rnnt_model, get_rnnt_encoder_output, device
    ):
        """Test that RNN-T MALSD batch decoding works with preserve_alignments=True and produces timestamps."""
        batch_size = 4
        beam_size = 4
        # Cap at the number of available audio fixtures so a small fixture set does not break the test.
        num_samples = min(batch_size, len(test_audio_filenames))
        model = rnnt_model.to(device)
        encoder_output, encoded_lengths = get_rnnt_encoder_output
        encoder_output = encoder_output[:num_samples].to(device)
        encoded_lengths = encoded_lengths[:num_samples].to(device)

        # blank_index == vocab_size: blank is the extra label appended after the real vocabulary.
        vocab_size = model.tokenizer.vocab_size
        decoding = BeamBatchedRNNTInfer(
            model.decoder,
            model.joint,
            blank_index=vocab_size,
            beam_size=beam_size,
            score_norm=True,
            return_best_hypothesis=True,
            search_type="malsd_batch",
            allow_cuda_graphs=False,
            # preserve_alignments must not raise anymore; it now only logs a warning.
            preserve_alignments=True,
        )

        # Inference only — no gradients needed.
        with torch.no_grad():
            hyps = decoding(encoder_output=encoder_output, encoded_lengths=encoded_lengths)[0]

        assert len(hyps) == num_samples
        for hyp in hyps:
            # Each emitted token must carry a timestamp, one per element of y_sequence.
            assert hyp.timestamp is not None
            assert len(hyp.timestamp) > 0, "Timestamp should not be empty for non-empty transcription"
            assert len(hyp.timestamp) == len(
                hyp.y_sequence
            ), f"Timestamp length {len(hyp.timestamp)} should match y_sequence length {len(hyp.y_sequence)}"

    @pytest.mark.skipif(
        not NUMBA_RNNT_LOSS_AVAILABLE,
        reason='RNNTLoss has not been compiled with appropriate numba version.',
    )
    @pytest.mark.with_downloads
    @pytest.mark.unit
    @pytest.mark.parametrize("device", DEVICES)
    def test_tdt_beam_decoding_with_preserve_alignments(
        self, test_audio_filenames, tdt_model, get_tdt_encoder_output, device
    ):
        """Test that TDT MALSD batch decoding works with preserve_alignments=True and produces timestamps + durations."""
        batch_size = 4
        beam_size = 4
        # Cap at the number of available audio fixtures so a small fixture set does not break the test.
        num_samples = min(batch_size, len(test_audio_filenames))
        model = tdt_model.to(device)
        encoder_output, encoded_lengths = get_tdt_encoder_output
        encoder_output = encoder_output[:num_samples].to(device)
        encoded_lengths = encoded_lengths[:num_samples].to(device)

        # TDT decoding requires the model's configured duration bins.
        model_config = model.to_config_dict()
        durations = list(model_config["model_defaults"]["tdt_durations"])

        # blank_index == vocab_size: blank is the extra label appended after the real vocabulary.
        vocab_size = model.tokenizer.vocab_size
        decoding = BeamBatchedTDTInfer(
            model.decoder,
            model.joint,
            blank_index=vocab_size,
            durations=durations,
            beam_size=beam_size,
            score_norm=True,
            return_best_hypothesis=True,
            search_type="malsd_batch",
            allow_cuda_graphs=False,
            # preserve_alignments must not raise anymore; it now only logs a warning.
            preserve_alignments=True,
        )

        # Inference only — no gradients needed.
        with torch.no_grad():
            hyps = decoding(encoder_output=encoder_output, encoded_lengths=encoded_lengths)[0]

        assert len(hyps) == num_samples
        for hyp in hyps:
            assert hyp.timestamp is not None
            assert len(hyp.timestamp) > 0, "Timestamp should not be empty for non-empty transcription"
            assert len(hyp.timestamp) == len(
                hyp.y_sequence
            ), f"Timestamp length {len(hyp.timestamp)} should match y_sequence length {len(hyp.y_sequence)}"
            # TDT-specific: token_duration should be populated
            assert hyp.token_duration is not None, "TDT hypothesis should have token_duration populated"
            assert len(hyp.token_duration) == len(
                hyp.y_sequence
            ), f"token_duration length {len(hyp.token_duration)} should match y_sequence length {len(hyp.y_sequence)}"

    @pytest.mark.with_downloads
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires CUDA")
    @pytest.mark.parametrize("model_type", ["rnnt", "tdt"])
    def test_beam_decoding_compute_timestamps(self, test_audio_filenames, rnnt_model, tdt_model, model_type):
        """Test the full compute_timestamps pipeline via model.transcribe() with malsd_batch strategy."""
        batch_size = 4
        device = torch.device("cuda")
        model = rnnt_model.to(device) if model_type == "rnnt" else tdt_model.to(device)
        # Deep-copy so mutating the decoding config does not leak into other tests sharing the fixture.
        decoding_config = copy.deepcopy(model.cfg.decoding)

        with open_dict(decoding_config):
            decoding_config["strategy"] = "malsd_batch"
            decoding_config["beam"]["beam_size"] = 4
            decoding_config["beam"]["return_best_hypothesis"] = True
            decoding_config["beam"]["allow_cuda_graphs"] = False
            decoding_config["compute_timestamps"] = True

        model.change_decoding_strategy(decoding_config)

        hypotheses = model.transcribe(
            test_audio_filenames[:batch_size], batch_size=batch_size, num_workers=None, return_hypotheses=True
        )

        assert len(hypotheses) > 0
        for hyp in hypotheses:
            # After the compute_timestamps pipeline, the raw per-token timestamp array is
            # replaced by a dict of offset lists keyed by granularity.
            assert hyp.timestamp is not None
            assert isinstance(
                hyp.timestamp, dict
            ), f"After compute_timestamps, timestamp should be a dict, got {type(hyp.timestamp)}"
            assert 'timestep' in hyp.timestamp, "timestamp dict should contain 'timestep' key"
            assert 'char' in hyp.timestamp, "timestamp dict should contain 'char' key"
            assert 'word' in hyp.timestamp, "timestamp dict should contain 'word' key"

            # Verify char offsets have correct structure
            if len(hyp.timestamp['char']) > 0:
                char_offset = hyp.timestamp['char'][0]
                assert 'start_offset' in char_offset, "char offset should have start_offset"
                assert 'end_offset' in char_offset, "char offset should have end_offset"
                assert char_offset['start_offset'] >= 0, "start_offset should be non-negative"
                assert char_offset['end_offset'] >= char_offset['start_offset'], "end_offset should be >= start_offset"
Loading