Skip to content

Add tensor support in edit distance #3036

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 70 additions & 22 deletions src/torchmetrics/functional/text/edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,61 @@


def _edit_distance_update(
preds: Union[str, Sequence[str]],
target: Union[str, Sequence[str]],
preds: Union[str, Sequence[str], Tensor, Sequence[Tensor]],
target: Union[str, Sequence[str], Tensor, Sequence[Tensor]],
substitution_cost: int = 1,
) -> Tensor:
if isinstance(preds, str):
preds = [preds]
if isinstance(target, str):
target = [target]
if not all(isinstance(x, str) for x in preds):
raise ValueError(f"Expected all values in argument `preds` to be string type, but got {preds}")
if not all(isinstance(x, str) for x in target):
raise ValueError(f"Expected all values in argument `target` to be string type, but got {target}")
if len(preds) != len(target):
raise ValueError(
f"Expected argument `preds` and `target` to have same length, but got {len(preds)} and {len(target)}"
)
"""Update the edit distance score with the current set of predictions and targets.

Args:
preds: An iterable of predicted texts (strings) or a tensor of categorical values or a list of tensors
target: An iterable of reference texts (strings) or a tensor of categorical values or a list of tensors
substitution_cost: The cost of substituting one character for another.

Returns:
A tensor containing the edit distance scores for each prediction-target pair.

"""
# Handle tensor inputs
if isinstance(preds, Tensor) and isinstance(target, Tensor):
if preds.dim() == 1:
preds = preds.unsqueeze(0)
if target.dim() == 1:
target = target.unsqueeze(0)
if preds.size(0) != target.size(0):
raise ValueError(
f"Expected argument `preds` and `target` to have same batch size, but got {preds.size(0)} and {target.size(0)}"
)
# Convert tensors to lists of lists for the edit distance algorithm
preds = [p.tolist() for p in preds]
target = [t.tolist() for t in target]
# Handle lists of tensors
elif isinstance(preds, (list, tuple)) and isinstance(target, (list, tuple)):
if not all(isinstance(x, Tensor) for x in preds):
raise ValueError(f"Expected all values in argument `preds` to be tensor type, but got {preds}")
if not all(isinstance(x, Tensor) for x in target):
raise ValueError(f"Expected all values in argument `target` to be tensor type, but got {target}")
if len(preds) != len(target):
raise ValueError(
f"Expected argument `preds` and `target` to have same length, but got {len(preds)} and {len(target)}"
)
# Convert tensors to lists for the edit distance algorithm
preds = [p.tolist() for p in preds]
target = [t.tolist() for t in target]
else:
# Handle string inputs (existing behavior)
if isinstance(preds, str):
preds = [preds]
if isinstance(target, str):
target = [target]
if not all(isinstance(x, str) for x in preds):
raise ValueError(f"Expected all values in argument `preds` to be string type, but got {preds}")
if not all(isinstance(x, str) for x in target):
raise ValueError(f"Expected all values in argument `target` to be string type, but got {target}")
if len(preds) != len(target):
raise ValueError(
f"Expected argument `preds` and `target` to have same length, but got {len(preds)} and {len(target)}"
)

distance = [
_LE_distance(t, op_substitute=substitution_cost)(p)[0] # type: ignore[arg-type]
Expand All @@ -63,8 +102,8 @@ def _edit_distance_compute(


def edit_distance(
preds: Union[str, Sequence[str]],
target: Union[str, Sequence[str]],
preds: Union[str, Sequence[str], Tensor],
target: Union[str, Sequence[str], Tensor],
substitution_cost: int = 1,
reduction: Optional[Literal["mean", "sum", "none"]] = "mean",
) -> Tensor:
Expand All @@ -76,8 +115,8 @@ def edit_distance(
Implementation is similar to `nltk.edit_distance <https://www.nltk.org/_modules/nltk/metrics/distance.html>`_.

Args:
preds: An iterable of predicted texts (strings).
target: An iterable of reference texts (strings).
preds: An iterable of predicted texts (strings) or a tensor of categorical values
target: An iterable of reference texts (strings) or a tensor of categorical values
substitution_cost: The cost of substituting one character for another.
reduction: a method to reduce metric score over samples.

Expand All @@ -87,19 +126,19 @@ def edit_distance(

Raises:
ValueError:
If ``preds`` and ``target`` do not have the same length.
If ``preds`` and ``target`` do not have the same length/batch size.
ValueError:
If ``preds`` or ``target`` contain non-string values.
If ``preds`` or ``target`` contain non-string values when using string inputs.

Example::
Basic example with two strings. Going from rain -> sain -> shin -> shine takes 3 edits:
Basic example with two strings. Going from "rain" -> "sain" -> "shin" -> "shine" takes 3 edits:

>>> from torchmetrics.functional.text import edit_distance
>>> edit_distance(["rain"], ["shine"])
tensor(3.)

Example::
Basic example with two strings and substitution cost of 2. Going from rain -> sain -> shin -> shine
Basic example with two strings and substitution cost of 2. Going from "rain" -> "sain" -> "shin" -> "shine"
takes 3 edits, where two of them are substitutions:

>>> from torchmetrics.functional.text import edit_distance
Expand All @@ -115,6 +154,15 @@ def edit_distance(
>>> edit_distance(["rain", "lnaguaeg"], ["shine", "language"], reduction="mean")
tensor(3.5000)

Example::
Using tensors of categorical values. Each prediction differs from its target by a single
substitution, so the per-sample distances are [1, 1] and the mean is 1:

>>> import torch
>>> from torchmetrics.functional.text import edit_distance
>>> preds = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> target = torch.tensor([[1, 2, 4], [4, 5, 7]])
>>> edit_distance(preds, target)
tensor(1.)

"""
distance = _edit_distance_update(preds, target, substitution_cost)
return _edit_distance_compute(distance, num_elements=distance.numel(), reduction=reduction)
71 changes: 50 additions & 21 deletions src/torchmetrics/functional/text/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,19 @@ class _LevenshteinEditDistance:
where the most of this implementation is adapted and copied from.

Args:
reference_tokens: list of reference tokens
reference_tokens: list of reference tokens or tensor values
op_insert: cost of insertion operation
op_delete: cost of deletion operation
op_substitute: cost of substitution operation

"""

def __init__(
self, reference_tokens: list[str], op_insert: int = 1, op_delete: int = 1, op_substitute: int = 1
self,
reference_tokens: Union[list[str], list[int], list[float]],
op_insert: int = 1,
op_delete: int = 1,
op_substitute: int = 1,
) -> None:
self.reference_tokens = reference_tokens
self.reference_len = len(reference_tokens)
Expand All @@ -82,11 +86,13 @@ def __init__(
self.op_nothing = 0
self.op_undefined = _INT_INFINITY

def __call__(self, prediction_tokens: list[str]) -> tuple[int, tuple[_EditOperations, ...]]:
def __call__(
self, prediction_tokens: Union[list[str], list[int], list[float]]
) -> tuple[int, tuple[_EditOperations, ...]]:
"""Calculate edit distance between self._words_ref and the hypothesis. Uses cache to skip some computations.

Args:
prediction_tokens: A tokenized predicted sentence.
prediction_tokens: A tokenized predicted sentence or tensor values.

Return:
A tuple of a calculated edit distance and a trace of executed operations.
Expand All @@ -105,14 +111,14 @@ def __call__(self, prediction_tokens: list[str]) -> tuple[int, tuple[_EditOperat

def _levenshtein_edit_distance(
self,
prediction_tokens: list[str],
prediction_tokens: Union[list[str], list[int], list[float]],
prediction_start: int,
cache: list[list[tuple[int, _EditOperations]]],
) -> tuple[int, list[list[tuple[int, _EditOperations]]], tuple[_EditOperations, ...]]:
"""Dynamic programming algorithm to compute the Levenshtein edit distance.

Args:
prediction_tokens: A tokenized predicted sentence.
prediction_tokens: A tokenized predicted sentence or tensor values.
prediction_start: An index where a predicted sentence to be considered from.
cache: A cached Levenshtein edit distance.

Expand Down Expand Up @@ -146,7 +152,15 @@ def _levenshtein_edit_distance(
_EditOperations.OP_DELETE,
)
else:
if prediction_tokens[i - 1] == self.reference_tokens[j - 1]:
# Handle both string and numeric comparisons
if isinstance(prediction_tokens[i - 1], (int, float)) and isinstance(
self.reference_tokens[j - 1], (int, float)
):
is_equal = abs(prediction_tokens[i - 1] - self.reference_tokens[j - 1]) < 1e-10
else:
is_equal = prediction_tokens[i - 1] == self.reference_tokens[j - 1]

if is_equal:
cost_substitute = self.op_nothing
operation_substitute = _EditOperations.OP_NOTHING
else:
Expand Down Expand Up @@ -209,14 +223,18 @@ def _get_trace(

return trace

def _add_cache(self, prediction_tokens: list[str], edit_distance: list[list[tuple[int, _EditOperations]]]) -> None:
def _add_cache(
self,
prediction_tokens: Union[list[str], list[int], list[float]],
edit_distance: list[list[tuple[int, _EditOperations]]],
) -> None:
"""Add newly computed rows to cache.

Since edit distance is only calculated on the hypothesis suffix that was not in cache, the number of rows in
the `edit_distance` matrix may be shorter than the hypothesis length. In that case we skip over these initial words.

Args:
prediction_tokens: A tokenized predicted sentence.
prediction_tokens: A tokenized predicted sentence or tensor values.
edit_distance:
A matrix of the Levenshtein edit distance. Each element of the matrix is a tuple of an edit
operation cost and the edit operation itself.
Expand All @@ -232,21 +250,23 @@ def _add_cache(self, prediction_tokens: list[str], edit_distance: list[list[tupl

# Jump through the cache to the current position
for i in range(skip_num):
node = node[prediction_tokens[i]][0] # type: ignore
node = node[str(prediction_tokens[i])][0] # type: ignore

# Update cache with newly computed rows
for word, row in zip(prediction_tokens[skip_num:], edit_distance):
if word not in node:
node[word] = ({}, tuple(row)) # type: ignore
if str(word) not in node:
node[str(word)] = ({}, tuple(row)) # type: ignore
self.cache_size += 1
value = node[word]
value = node[str(word)]
node = value[0] # type: ignore

def _find_cache(self, prediction_tokens: list[str]) -> tuple[int, list[list[tuple[int, _EditOperations]]]]:
def _find_cache(
self, prediction_tokens: Union[list[str], list[int], list[float]]
) -> tuple[int, list[list[tuple[int, _EditOperations]]]]:
"""Find the already calculated rows of the Levenshtein edit distance metric.

Args:
prediction_tokens: A tokenized predicted sentence.
prediction_tokens: A tokenized predicted sentence or tensor values.

Return:
A tuple of a start hypothesis position and `edit_distance` matrix.
Expand All @@ -261,9 +281,9 @@ def _find_cache(self, prediction_tokens: list[str]) -> tuple[int, list[list[tupl
start_position = 0
edit_distance: list[list[tuple[int, _EditOperations]]] = [self._get_initial_row(self.reference_len)]
for word in prediction_tokens:
if word in node:
if str(word) in node:
start_position += 1
node, row = node[word] # type: ignore
node, row = node[str(word)] # type: ignore
edit_distance.append(row) # type: ignore
else:
break
Expand Down Expand Up @@ -327,12 +347,15 @@ def _validate_inputs(
return ref_corpus, hypothesis_corpus


def _edit_distance(prediction_tokens: list[str], reference_tokens: list[str]) -> int:
def _edit_distance(
prediction_tokens: Union[list[str], list[int], list[float]],
reference_tokens: Union[list[str], list[int], list[float]],
) -> int:
"""Dynamic programming algorithm to compute the edit distance.

Args:
prediction_tokens: A tokenized predicted sentence
reference_tokens: A tokenized reference sentence
prediction_tokens: A tokenized predicted sentence or tensor values
reference_tokens: A tokenized reference sentence or tensor values
Returns:
Edit distance between the predicted sentence and the reference sentence

Expand All @@ -344,7 +367,13 @@ def _edit_distance(prediction_tokens: list[str], reference_tokens: list[str]) ->
dp[0][j] = j
for i in range(1, len(prediction_tokens) + 1):
for j in range(1, len(reference_tokens) + 1):
if prediction_tokens[i - 1] == reference_tokens[j - 1]:
# Handle both string and numeric comparisons
if isinstance(prediction_tokens[i - 1], (int, float)) and isinstance(reference_tokens[j - 1], (int, float)):
is_equal = abs(prediction_tokens[i - 1] - reference_tokens[j - 1]) < 1e-10
else:
is_equal = prediction_tokens[i - 1] == reference_tokens[j - 1]

if is_equal:
dp[i][j] = dp[i - 1][j - 1]
else:
dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1
Expand Down
20 changes: 15 additions & 5 deletions src/torchmetrics/text/edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ class EditDistance(Metric):

As input to ``forward`` and ``update`` the metric accepts the following input:

- ``preds`` (:class:`~Sequence`): An iterable of hypothesis corpus
- ``target`` (:class:`~Sequence`): An iterable of iterables of reference corpus
- ``preds`` (:class:`~Sequence`): An iterable of hypothesis corpus or a tensor of categorical values
- ``target`` (:class:`~Sequence`): An iterable of iterables of reference corpus or a tensor of categorical values

As output of ``forward`` and ``compute`` the metric returns the following output:

Expand All @@ -56,15 +56,15 @@ class EditDistance(Metric):
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example::
Basic example with two strings. Going from rain -> sain -> shin -> shine takes 3 edits:
Basic example with two strings. Going from "rain" -> "sain" -> "shin" -> "shine" takes 3 edits:

>>> from torchmetrics.text import EditDistance
>>> metric = EditDistance()
>>> metric(["rain"], ["shine"])
tensor(3.)

Example::
Basic example with two strings and substitution cost of 2. Going from rain -> sain -> shin -> shine
Basic example with two strings and substitution cost of 2. Going from "rain" -> "sain" -> "shin" -> "shine"
takes 3 edits, where two of them are substitutions:

>>> from torchmetrics.text import EditDistance
Expand All @@ -83,6 +83,16 @@ class EditDistance(Metric):
>>> metric(["rain", "lnaguaeg"], ["shine", "language"])
tensor(3.5000)

Example::
Using tensors of categorical values. Each prediction differs from its target by a single
substitution, so the per-sample distances are [1, 1] and the mean is 1:

>>> import torch
>>> from torchmetrics.text import EditDistance
>>> metric = EditDistance()
>>> preds = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> target = torch.tensor([[1, 2, 4], [4, 5, 7]])
>>> metric(preds, target)
tensor(1.)

"""

higher_is_better: bool = False
Expand Down Expand Up @@ -115,7 +125,7 @@ def __init__(
self.add_state("edit_scores", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("num_elements", default=torch.tensor(0), dist_reduce_fx="sum")

def update(self, preds: Union[str, Sequence[str]], target: Union[str, Sequence[str]]) -> None:
def update(self, preds: Union[str, Sequence[str], Tensor], target: Union[str, Sequence[str], Tensor]) -> None:
"""Update state with predictions and targets."""
distance = _edit_distance_update(preds, target, self.substitution_cost)
if self.reduction == "none" or self.reduction is None:
Expand Down
Loading
Loading