Multilingual Overhaul #833

Open · wants to merge 3 commits into main

66 changes: 58 additions & 8 deletions src/lighteval/metrics/dynamic_metrics.py
@@ -25,7 +25,9 @@

import numpy as np

from lighteval.metrics.metrics_corpus import CorpusLevelTranslationMetric
from lighteval.metrics.metrics_sample import (
BLEU,
ExactMatches,
F1_score,
LoglikelihoodAcc,
@@ -38,6 +40,7 @@
LogProbTokenNorm,
get_multilingual_normalizer,
)
from lighteval.metrics.sample_preparator import GenerativePreparator
from lighteval.metrics.utils.extractive_match_utils import ( # noqa: F401
ExprExtractionConfig,
ExtractionTarget,
@@ -47,7 +50,7 @@
get_extraction_regexes,
)
from lighteval.metrics.utils.math_comparison import compare_gold_target
from lighteval.metrics.utils.metric_utils import MetricCategory, MetricUseCase, SampleLevelMetric
from lighteval.metrics.utils.metric_utils import CorpusLevelMetric, MetricCategory, MetricUseCase, SampleLevelMetric
from lighteval.tasks.requests import Doc
from lighteval.utils.language import Language
from lighteval.utils.timeout import timeout
@@ -122,26 +125,34 @@ def probability_metric(


def multilingual_quasi_f1_score_metric(
language: Language, aggregation_function: Callable[[list[float]], float] = max
language: Language,
aggregation_function: Callable[[list[float]], float] = max,
normalize_gold: Callable[[str], str] | None = None,
normalize_pred: Callable[[str], str] | None = None,
) -> SampleLevelMetric:
"""
Creates a language-aware F1 score metric, which returns the F1 score.

Args:
language: The language of the samples.
aggregation_function: Aggregation function to use when multiple golds are present.
normalize_gold: Normalization function for gold answers.
normalize_pred: Normalization function for predictions.

Returns:
F1 score metric.
"""
metric_name = f"f1_{language.value}"

multilang_normalizer = get_multilingual_normalizer(language)
base_normalizer = get_multilingual_normalizer(language)
gold_normalizer = (lambda x: base_normalizer(normalize_gold(x))) if normalize_gold is not None else base_normalizer
pred_normalizer = (lambda x: base_normalizer(normalize_pred(x))) if normalize_pred is not None else base_normalizer

return SampleLevelMetric(
metric_name=metric_name,
sample_level_fn=F1_score(
normalize_gold=multilang_normalizer,
normalize_pred=multilang_normalizer,
normalize_gold=gold_normalizer,
normalize_pred=pred_normalizer,
aggregation_function=aggregation_function,
).compute,
category=MetricCategory.GENERATIVE,
@@ -155,6 +166,8 @@ def multilingual_quasi_exact_match_metric(
language: Language,
match_type: Literal["prefix", "suffix", "full"] = "full",
aggregation_function: Callable[[list[float]], float] = max,
normalize_gold: Callable[[str], str] | None = None,
normalize_pred: Callable[[str], str] | None = None,
) -> SampleLevelMetric:
"""
Creates a language-aware exact match metric, which returns the exact match score
@@ -165,16 +178,21 @@
- "suffix": Suffixes must match
- "full": Full strings must match
aggregation_function: Aggregation function to use when multiple golds are present.
normalize_gold: Normalization function for gold answers.
normalize_pred: Normalization function for predictions.
Returns:
Exact match metric.
"""
metric_name = f"exact_match_{language.value}_{match_type}"
multilang_normalizer = get_multilingual_normalizer(language)
base_normalizer = get_multilingual_normalizer(language)
gold_normalizer = (lambda x: base_normalizer(normalize_gold(x))) if normalize_gold is not None else base_normalizer
pred_normalizer = (lambda x: base_normalizer(normalize_pred(x))) if normalize_pred is not None else base_normalizer

return SampleLevelMetric(
metric_name=metric_name,
sample_level_fn=ExactMatches(
normalize_gold=multilang_normalizer,
normalize_pred=multilang_normalizer,
normalize_gold=gold_normalizer,
normalize_pred=pred_normalizer,
aggregation_function=aggregation_function,
type_exact_match=match_type,
).compute,
@@ -185,6 +203,38 @@ def multilingual_quasi_exact_match_metric(
)


def translation_metric(
metric_name: Literal["bleu", "bleu_1", "bleu_4", "chrf", "chrf++"],
normalize_pred: Callable[[str], str] | None = None,
normalize_gold: Callable[[str], str] | None = None,
) -> CorpusLevelMetric | SampleLevelMetric:
"""
Creates a translation metric, which returns the translation score.
"""
if metric_name.startswith("bleu_"):
return SampleLevelMetric(
metric_name=metric_name,
sample_level_fn=BLEU(
n_gram=int(metric_name.split("_")[1]), normalize_pred=normalize_pred, normalize_gold=normalize_gold
).compute,
category=MetricCategory.GENERATIVE,
use_case=MetricUseCase.TRANSLATION,
corpus_level_fn=np.mean,
higher_is_better=True,
)
else:
return CorpusLevelMetric(
metric_name=metric_name,
sample_level_fn=GenerativePreparator().prepare,
category=MetricCategory.GENERATIVE,
use_case=MetricUseCase.TRANSLATION,
corpus_level_fn=CorpusLevelTranslationMetric(
metric_name, normalize_pred=normalize_pred, normalize_gold=normalize_gold
).compute, # type: ignore
higher_is_better=True,
)


def multilingual_extractive_match_metric(
language: Language = Language.ENGLISH,
gold_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
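
A minimal usage sketch of the new factory signatures, assuming only the import paths shown in this diff; strip_final_punct is a made-up pre-normalizer and runs before the built-in language normalizer:

from lighteval.metrics.dynamic_metrics import (
    multilingual_quasi_exact_match_metric,
    translation_metric,
)
from lighteval.utils.language import Language


# Made-up pre-normalizer; it is applied before get_multilingual_normalizer(language).
def strip_final_punct(text: str) -> str:
    return text.rstrip("。．.!?")


# Exact match with an extra prediction-side normalization step.
em_metric = multilingual_quasi_exact_match_metric(
    language=Language.CHINESE,
    match_type="full",
    normalize_pred=strip_final_punct,
)

# "bleu_4" yields a sample-level metric, "chrf++" a corpus-level one.
sample_bleu = translation_metric("bleu_4", normalize_pred=strip_final_punct)
corpus_chrf = translation_metric("chrf++")
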
16 changes: 12 additions & 4 deletions src/lighteval/metrics/metrics_corpus.py
@@ -27,7 +27,7 @@

import logging
import math
from typing import Literal
from typing import Callable, Literal

import numpy as np
import sacrebleu
@@ -91,14 +91,22 @@ def compute(self, items: list[LogprobCorpusMetricInput]):


class CorpusLevelTranslationMetric:
def __init__(self, metric_type: str, lang: Literal["zh", "ja", "ko", ""] = ""):
def __init__(
self,
metric_type: Literal["bleu", "chrf", "chrf++", "ter"],
lang: Literal["zh", "ja", "ko", ""] = "",
normalize_pred: Callable[[str], str] | None = None,
normalize_gold: Callable[[str], str] | None = None,
):
"""Stores the relevant parameters for a corpus level translation metric.

Args:
metric_type: Can be any of bleu, chrf, chrf++ or ter depending on the metric to use.
lang: Target language code ("zh", "ja", "ko" or "" for the default).
normalize_pred: Normalization function applied to each prediction before scoring.
normalize_gold: Normalization function applied to each gold reference before scoring.
"""
self.metric_type = metric_type
self.lang = lang
self.normalize_pred = normalize_pred if normalize_pred is not None else lambda x: x
self.normalize_gold = normalize_gold if normalize_gold is not None else lambda x: x

def get_metric(self):
if self.metric_type == "bleu":
@@ -115,15 +123,15 @@ def get_metric(self):
def compute(self, items: list[GenerativeCorpusMetricInput]) -> float:
"""Computes the metric score over all the corpus generated items, by using the sacrebleu implementation."""
metric = self.get_metric()
golds = [i.golds for i in items]
golds = [[self.normalize_gold(gold) for gold in i.golds] for i in items]
preds = []
for i in items:
pred = as_list(i.preds)
if len(pred) > 1:
logger.info(
f"Multiple predictions present, keeping only the first prediction (when computing sacrebleu.{metric.__name__})."
)
preds.append(pred[0])
preds.append(self.normalize_pred(pred[0]))
return float(metric.corpus_score(hypotheses=preds, references=golds).score)


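For intuition, a standalone sketch of what the corpus-level scoring delegates to, calling sacrebleu directly with lowercasing as a stand-in for the new normalizers (single reference stream):

import sacrebleu

normalize = str.lower  # stand-in for normalize_pred / normalize_gold

preds = ["The cat sat on a mat.", "Hello World"]
golds = [["the cat sat on the mat.", "hello world"]]  # one reference stream

preds = [normalize(p) for p in preds]
golds = [[normalize(g) for g in stream] for stream in golds]

print(sacrebleu.CHRF().corpus_score(hypotheses=preds, references=golds).score)
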
13 changes: 12 additions & 1 deletion src/lighteval/metrics/metrics_sample.py
@@ -744,14 +744,21 @@ def compute(self, golds: list[str], predictions: list[str], **kwargs) -> float:


class BLEU:
def __init__(self, n_gram: int):
def __init__(
self,
n_gram: int,
normalize_pred: Callable[[str], str] | None = None,
normalize_gold: Callable[[str], str] | None = None,
):
"""BLEU scorer class. Relies on `nltk`'s sentencebleu for scoring.
TODO: Will have to move this to sacrebleu.

Args:
n_gram (int): Number of n_grams to use for scoring.
normalize_pred: Normalization function applied to each prediction before scoring.
normalize_gold: Normalization function applied to each gold reference before scoring.
"""
self.n_gram = n_gram
self.normalize_pred = normalize_pred
self.normalize_gold = normalize_gold

def compute(self, golds: list[str], predictions: list[str], **kwargs):
"""Computes the sentence level BLEU between the golds and each prediction, then takes the average.
@@ -763,6 +770,10 @@ def compute(self, golds: list[str], predictions: list[str], **kwargs):
Returns:
float: Score over the current sample's items.
"""
if self.normalize_pred:
predictions = [self.normalize_pred(p) for p in predictions]
if self.normalize_gold:
golds = [self.normalize_gold(g) for g in golds]
return np.mean([self._bleu_score(golds, p) for p in predictions])

def _bleu_score(self, gold: list[str], pred: str) -> float:
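
A quick sample-level sketch of the extended BLEU scorer; squash_spaces is just an illustrative normalizer:

from lighteval.metrics.metrics_sample import BLEU


def squash_spaces(text: str) -> str:
    return " ".join(text.split())


bleu4 = BLEU(n_gram=4, normalize_pred=squash_spaces, normalize_gold=squash_spaces)
score = bleu4.compute(
    golds=["the cat  sat on the mat"],
    predictions=["the cat sat on the mat"],
)
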
4 changes: 4 additions & 0 deletions src/lighteval/metrics/utils/extractive_match_utils.py
@@ -89,6 +89,7 @@ class IndicesExtractionConfig:
"""

prefix_for_extraction: ChoicePrefix
bb_match_priority: int = -1  # when >= 0, also extract answers wrapped in literal <b></b> tags, at this priority
try_extract_without_anchor: bool = True


@@ -340,6 +341,9 @@ def lazy_indices_regex(
]
)

if indices_config.bb_match_priority >= 0:
regexes.append((rf"<b>\s*{indice_str_re}\s*</b>", indices_config.bb_match_priority))

return [(re.compile(pattern), priority) for pattern, priority in regexes]


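To illustrate what the appended pattern matches when bb_match_priority >= 0, a toy version with a simplified letter group standing in for the real indice_str_re:

import re

indice_str_re = r"(?P<indices>[A-D])"  # simplified stand-in for the real pattern
bold_answer_re = re.compile(rf"<b>\s*{indice_str_re}\s*</b>")

m = bold_answer_re.search("The correct option is <b> C </b>.")
print(m.group("indices") if m else None)  # -> C
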
55 changes: 54 additions & 1 deletion src/lighteval/tasks/multilingual/adapters.py
@@ -29,6 +29,7 @@
from lighteval.tasks.default_prompts import LETTER_INDICES
from lighteval.tasks.multilingual.utils.adapters_utils import (
extract_answers_from_string,
float_to_choice_string,
multichoice_join,
multichoice_to_single_choice,
)
@@ -79,7 +80,7 @@ def thai_exams_adapter(line: dict) -> MCQInput | None:

def alghafa_adapter(line: dict) -> MCQInput | None:
answer_index = int(line["label"])
choices_keys = [key for key in line.keys() if key not in ["query", "label", "__few_shots"]]
choices_keys = [key for key in line.keys() if key not in ["query", "label", "__index", "__few_shots"]]
choices = [line[key] for key in choices_keys]
return {
"question": line["query"],
@@ -298,3 +299,55 @@ def enem_adapter(lang: Language, line: dict) -> MCQInput | None:
"choices": line["alternatives"],
"gold_idx": LETTER_INDICES.index(line["label"]),
}


CMM_MATH_ANSWER_RE = re.compile(r"([A-D])\.(.*?)(?=[A-D]\.|$)")


def cmm_math_adapter(line: dict) -> MCQInput | None:
"""Adapter for CMM-Math dataset.

Processes questions and options, handling cases where:
- Question ends with parentheses that need to be stripped
- Options are space-separated strings starting with A./B./C./D.
"""
# Strip ending parentheses from question
question = line["question"].strip().rstrip("( )")

# Split options and store as dict with letter keys
choices = {}
for match in CMM_MATH_ANSWER_RE.finditer(line["options"]):
letter, choice = match.groups()
choices[letter] = choice.strip()

try:
gold_idx = list(choices.keys()).index(line["answer"])
except ValueError:
gold_idx = None

# Validate we have enough options and answer
if len(choices) <= 1 or not line.get("answer") or gold_idx is None:
return None

return {"question": question, "choices": list(choices.values()), "gold_idx": gold_idx}


def qazuntv2_adapter(line: dict) -> MCQInput | None:
gold_idx = LETTER_INDICES.index(line["answer"])
choices = line["options"]
if gold_idx >= len(choices):
return None
return {"question": line["question"], "choices": choices, "gold_idx": gold_idx}


MGSM_COT_PREFIX_RE = re.compile(
r"\s*(ধাপে ধাপে উত্তর|Schritt-für-Schritt-Antwort|Step-by-Step Answer|Respuesta paso a paso|Réponse étape par étape|ステップごとの答え|Пошаговое решение|Jibu la Hatua kwa Hatua|దశలవారీగా సమాధానంi|คำตอบทีละขั้นตอน|逐步解答)\s*:\s*"
)
MGSM_QUESTION_RE = re.compile(r"\s*(প্রশ্ন|Frage|Question|Pregunta|Question|問題|Задача|Swali|ప్రశ్న|โจทย์|问题)\s*:\s*")


def mgsm_adapter(line: dict) -> QAInput | None:
question = MGSM_QUESTION_RE.sub("", line["question"])
answer_cot = MGSM_COT_PREFIX_RE.sub("", line["answer"]) if line["answer"] else ""
answer_number = line["answer_number"]
return {"question": question, "few_shot_cot": answer_cot, "choices": [float_to_choice_string(answer_number)]}