62 changes: 21 additions & 41 deletions src/transformers/models/whisper/tokenization_whisper.py
@@ -1158,11 +1158,21 @@ def new_chunk():


def _find_longest_common_sequence(sequences, token_timestamp_sequences=None):
# It would be much harder to do O(n) because of fault tolerance.
# We actually have a really good property which is that the total sequence
# MUST be those subsequences in order.
# If token_timestamp_sequences is provided, will split those sequences in
# exactly the same way.
"""
Find the longest common sequence between consecutive Whisper speech recognition chunks.

Relies on the property that the total sequence MUST contain those subsequences in order,
and keeps the original sliding-match algorithm together with its timestamp handling and conflict resolution.

Args:
sequences: List of token sequences from speech recognition chunks
token_timestamp_sequences: Optional list of timestamp sequences corresponding to tokens

Returns:
List of tokens or tuple of (tokens, timestamps) if timestamps provided
"""
if not sequences:
return [] if token_timestamp_sequences is None else ([], [])

left_sequence = sequences[0]
left_length = len(left_sequence)
@@ -1173,39 +1183,12 @@ def _find_longest_common_sequence(sequences, token_timestamp_sequences=None):
total_token_timestamp_sequence = []

for seq_idx, right_sequence in enumerate(sequences[1:]):
# index = 0
right_length = len(right_sequence)

# Use the original algorithm exactly as it was
max_ = 0.0
max_indices = (left_length, left_length, 0, 0)
# Here we're sliding matches
# [a, b, c, d]
# [c, d, f]
# = [c] == [d]
#
# [a, b, c, d]
# [c, d, f]
# = [c, d] == [c, d]
#
#
# [a, b, c, d]
# [c, d, f]
#
# = [b, c, d] == [c, d, f]
#
# [a, b, c, d]
# [c, d, f]
#
# [a, b, c] == [c, d, f]
#
# [a, b, c, d]
# [d, f]
#
# [a, b] == [d, f]
#
# [a, b, c, d]
# [f]
#
# [a] == [f]
right_length = len(right_sequence)

for i in range(1, left_length + right_length):
# epsilon to favor long perfect matches
eps = i / 10000.0
@@ -1250,11 +1233,8 @@ def _find_longest_common_sequence(sequences, token_timestamp_sequences=None):

(left_start, left_stop, right_start, right_stop) = max_indices

# This is a small conflict optimization since those sequences overlap
# in audio.
# We're going to give more confidence to the left sequence
# for the left of the overlap,
# and to the right of the sequence, for the right of the overlap
# Conflict resolution optimization: give more confidence to the left sequence
# for the left of the overlap, and to the right sequence for the right of the overlap
left_mid = (left_stop + left_start) // 2
right_mid = (right_stop + right_start) // 2
total_sequence.extend(left_sequence[:left_mid])
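For reference, a minimal usage sketch of the tokenizer-side helper above on made-up token IDs; the toy values and the expected output are illustrative assumptions, not part of this diff:

# Sketch only: assumes the module-level helper shown above keeps this signature and is
# importable from the file in this diff; the token IDs are made-up toy values.
from transformers.models.whisper.tokenization_whisper import _find_longest_common_sequence

# Two overlapping chunk outputs; [3, 4] is the shared overlap between the chunks.
merged = _find_longest_common_sequence([[1, 2, 3, 4], [3, 4, 5]])
print(merged)  # expected to emit the overlap only once, e.g. [1, 2, 3, 4, 5]

# The overlap itself is split at its midpoint: tokens left of the midpoint come from the
# earlier chunk, tokens right of it from the later chunk (the conflict resolution
# described in the comments above).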
61 changes: 44 additions & 17 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -85,27 +85,54 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,


def _find_longest_common_sequence(sequences, tokenizer):
# TODO Use a faster algorithm this can probably be done in O(n)
# using suffix array.
# It might be tedious to do because of fault tolerance.
# We actually have a really good property which is that the total sequence
# MUST be those subsequences in order.
# Also the algorithm should be more tolerant to errors.
"""
Find the longest common sequence between consecutive speech recognition chunks.

Uses the property that the total sequence MUST contain those subsequences in order,
replacing the fuzzy sliding-window scoring with a longest-first exact-overlap scan.

Args:
sequences: List of token sequences from speech recognition chunks
tokenizer: Tokenizer to filter special tokens

Returns:
np.array: The merged sequence of tokens
"""
if not sequences:
return np.array([])

# Filter special tokens from first sequence
sequence = [tok_id for tok_id in sequences[0][0].tolist() if tok_id not in tokenizer.all_special_ids]

for new_seq in sequences[1:]:
new_sequence = [tok_id for tok_id in new_seq[0].tolist() if tok_id not in tokenizer.all_special_ids]

index = 0
max_ = 0.0
for i in range(1, len(new_sequence) + 1):
# epsilon to favor long perfect matches
eps = i / 10000.0
matches = np.sum(np.array(sequence[-i:]) == np.array(new_sequence[:i]))
matching = matches / i + eps
if matches > 1 and matching > max_:
index = i
max_ = matching
sequence.extend(new_sequence[index:])
if not new_sequence:
continue

# Find the longest overlap: the longest suffix of `sequence` that exactly matches
# a prefix of `new_sequence`, relying on the chunks arriving in order
best_overlap = 0
best_score = 0.0

# Start from the maximum possible overlap and work backwards
max_possible_overlap = min(len(sequence), len(new_sequence))

for overlap_len in range(max_possible_overlap, 0, -1):
# Check if the last 'overlap_len' tokens of sequence match the first 'overlap_len' tokens of new_sequence
if sequence[-overlap_len:] == new_sequence[:overlap_len]:
# Calculate score with epsilon to favor longer matches
eps = overlap_len / 10000.0
score = overlap_len + eps

if score > best_score:
best_score = score
best_overlap = overlap_len
break # Since we're going from longest to shortest, first match is best

# Add the non-overlapping part of the new sequence
sequence.extend(new_sequence[best_overlap:])

return np.array(sequence)


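For reference, a self-contained sketch of the exact-overlap merge step introduced above, operating on plain Python lists with no tokenizer or special-token filtering; merge_exact_overlap is an illustrative name, not part of the pipeline API:

def merge_exact_overlap(sequence, new_sequence):
    # Try the longest possible overlap first; the first exact match wins.
    max_possible_overlap = min(len(sequence), len(new_sequence))
    best_overlap = 0
    for overlap_len in range(max_possible_overlap, 0, -1):
        if sequence[-overlap_len:] == new_sequence[:overlap_len]:
            best_overlap = overlap_len
            break
    # Append only the non-overlapping tail of the new chunk.
    return sequence + new_sequence[best_overlap:]


print(merge_exact_overlap([7, 8, 9, 10], [9, 10, 11]))  # [7, 8, 9, 10, 11]
print(merge_exact_overlap([7, 8], [3, 4]))              # [7, 8, 3, 4] (no overlap found)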