Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion gpt_oss/evals/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
OPENAI_SYSTEM_MESSAGE_API,
ChatCompletionsSampler,
)
from .mmlu_eval import MMLUEval
from .responses_sampler import ResponsesSampler


Expand Down Expand Up @@ -47,7 +48,7 @@ def main():
parser.add_argument(
"--eval",
type=str,
default="gpqa,healthbench,healthbench_hard,healthbench_consensus,aime25",
default="gpqa,healthbench,healthbench_hard,healthbench_consensus,aime25,mmlu",
help="Select an eval by name. Accepts a comma-separated list.",
)
parser.add_argument(
Expand Down Expand Up @@ -139,6 +140,11 @@ def get_evals(eval_name, debug_mode):
num_examples=num_examples,
n_threads=args.n_threads or 1,
)
case "mmlu":
return MMLUEval(
num_examples=num_examples,
n_threads=args.n_threads or 1,
)
case _:
raise Exception(f"Unrecognized eval type: {eval_name}")

Expand Down
118 changes: 118 additions & 0 deletions gpt_oss/evals/mmlu_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""
Measuring Massive Multitask Language Understanding
Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, Jacob Steinhardt
https://arxiv.org/abs/2009.03300
"""

import random

import pandas

from . import report
from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
from .abcd_grader import extract_abcd
from .gpqa_eval import format_multichoice_question

# Maps each of the 57 MMLU subjects to one of four reporting buckets
# ("stem", "humanities", "social_sciences", "other"); MMLUEval uses this
# to key the per-example `metrics` dict, falling back to "other" for any
# subject missing from this table.
subject2category = {
    "abstract_algebra": "stem",
    "anatomy": "other",
    "astronomy": "stem",
    "business_ethics": "other",
    "clinical_knowledge": "other",
    "college_biology": "stem",
    "college_chemistry": "stem",
    "college_computer_science": "stem",
    "college_mathematics": "stem",
    "college_medicine": "other",
    "college_physics": "stem",
    "computer_security": "stem",
    "conceptual_physics": "stem",
    "econometrics": "social_sciences",
    "electrical_engineering": "stem",
    "elementary_mathematics": "stem",
    "formal_logic": "humanities",
    "global_facts": "other",
    "high_school_biology": "stem",
    "high_school_chemistry": "stem",
    "high_school_computer_science": "stem",
    "high_school_european_history": "humanities",
    "high_school_geography": "social_sciences",
    "high_school_government_and_politics": "social_sciences",
    "high_school_macroeconomics": "social_sciences",
    "high_school_mathematics": "stem",
    "high_school_microeconomics": "social_sciences",
    "high_school_physics": "stem",
    "high_school_psychology": "social_sciences",
    "high_school_statistics": "stem",
    "high_school_us_history": "humanities",
    "high_school_world_history": "humanities",
    "human_aging": "other",
    "human_sexuality": "social_sciences",
    "international_law": "humanities",
    "jurisprudence": "humanities",
    "logical_fallacies": "humanities",
    "machine_learning": "stem",
    "management": "other",
    "marketing": "other",
    "medical_genetics": "other",
    "miscellaneous": "other",
    "moral_disputes": "humanities",
    "moral_scenarios": "humanities",
    "nutrition": "other",
    "philosophy": "humanities",
    "prehistory": "humanities",
    "professional_accounting": "other",
    "professional_law": "humanities",
    "professional_medicine": "other",
    "professional_psychology": "social_sciences",
    "public_relations": "social_sciences",
    "security_studies": "social_sciences",
    "sociology": "social_sciences",
    "us_foreign_policy": "social_sciences",
    "virology": "other",
    "world_religions": "humanities",
}


class MMLUEval(Eval):
    """Eval for MMLU (Hendrycks et al., https://arxiv.org/abs/2009.03300).

    Downloads the question CSV from the simple-evals public blob store,
    prompts the sampler with each multiple-choice question, extracts the
    A/B/C/D answer letter from the free-form response, and aggregates
    accuracy per subject category.
    """

    def __init__(self, num_examples: int | None = None, language: str = "EN-US", n_threads: int = 1):
        """
        Args:
            num_examples: If truthy, evaluate a fixed, seeded random subsample
                of this size instead of the full dataset.
            language: "EN-US" selects the original English CSV; any other
                value selects the translated variant ``mmlu_{language}.csv``.
            n_threads: Number of threads used to query the sampler in parallel.
        """
        if language != "EN-US":
            url = f"https://openaipublic.blob.core.windows.net/simple-evals/mmlu_{language}.csv"
        else:
            url = "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv"
        df = pandas.read_csv(url)
        examples = [row.to_dict() for _, row in df.iterrows()]
        if num_examples:
            # Seed 0 makes the subsample reproducible across runs. Cap at the
            # dataset size: random.sample raises ValueError when asked for more
            # items than the population holds.
            examples = random.Random(0).sample(examples, min(num_examples, len(examples)))
        self.examples = examples
        self.n_threads = n_threads

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        def fn(row: dict):
            # One grading unit: build the single-turn multiple-choice prompt,
            # query the sampler, and score the extracted answer letter.
            prompt_messages = [
                sampler._pack_message(
                    content=format_multichoice_question(row), role="user"
                )
            ]
            sampler_response = sampler(prompt_messages)
            response_text = sampler_response.response_text
            actual_queried_prompt_messages = sampler_response.actual_queried_message_list
            # extract_abcd pulls the A/B/C/D letter out of the free-form reply;
            # exact match against the gold "Answer" column scores 1.0, else 0.0.
            extracted_answer = extract_abcd(response_text)
            score = 1.0 if extracted_answer == row["Answer"] else 0.0
            html = report.jinja_env.from_string(report.HTML_JINJA).render(
                prompt_messages=actual_queried_prompt_messages,
                next_message=dict(content=response_text, role="assistant"),
                score=score,
                correct_answer=row["Answer"],
                extracted_answer=extracted_answer,
            )
            convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")]
            # Bucket the score under the subject's reporting category; subjects
            # missing from the table fall back to "other".
            category = subject2category.get(row["Subject"], "other")
            return SingleEvalResult(
                html=html, score=score, metrics={category: score}, convo=convo
            )

        results = report.map_with_progress(fn, self.examples, num_threads=self.n_threads)
        return report.aggregate_results(results)