Fix MMLU in Eval Harness (#859)
Fixes averaging and also initialization of aggregates
dlwh authored Jan 22, 2025
1 parent b60fc96 commit 237851b
Showing 4 changed files with 100 additions and 20 deletions.
6 changes: 5 additions & 1 deletion config/gpt2_nano_harness.yaml
@@ -1,5 +1,9 @@
 eval_harness:
-  task_spec: ["piqa", "hellaswag"]
+  task_spec:
+    - mmlu
+#    - task: mmlu
+#      task_alias: mmlu_0shot
+#      num_fewshot: 0
   max_examples: 32
   eval_harness_steps: 50
 data:
4 changes: 3 additions & 1 deletion config/harness/harness_nano.yaml
@@ -1,8 +1,10 @@
 eval_harness:
 #  task_spec: ["hellaswag"]
   task_spec:
     - task: commonsense_qa # 5-way multiple-choice questions based on common-sense, everyday scenarios
-#    - mmlu
+    - task: mmlu
+      num_fewshot: 1
+      task_alias: mmlu_1shot
 tokenizer: "gpt2"
 model:
   type: gpt2
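Both configs use the two forms a task_spec entry can take: a bare lm-eval task name, or a mapping with task plus optional task_alias and num_fewshot, which is what lets the same task run under different fewshot settings. The snippet below is an illustrative sketch, not Levanter's API: the helper name normalize_task_entry and the None default for num_fewshot are assumptions; the alias defaulting to the task name mirrors the to_task_dict logic in the diff further down.

from typing import Union

def normalize_task_entry(entry: Union[str, dict]) -> dict:
    """Illustration only: turn a task_spec entry into one uniform shape."""
    if isinstance(entry, str):
        # A bare string like "mmlu" uses the task's own name as its alias.
        return {"task": entry, "task_alias": entry, "num_fewshot": None}
    return {
        "task": entry["task"],
        "task_alias": entry.get("task_alias", entry["task"]),
        "num_fewshot": entry.get("num_fewshot"),
    }

print(normalize_task_entry("mmlu"))
# -> {'task': 'mmlu', 'task_alias': 'mmlu', 'num_fewshot': None}
print(normalize_task_entry({"task": "mmlu", "task_alias": "mmlu_1shot", "num_fewshot": 1}))
# -> {'task': 'mmlu', 'task_alias': 'mmlu_1shot', 'num_fewshot': 1}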
106 changes: 90 additions & 16 deletions src/levanter/eval_harness.py
@@ -292,11 +292,11 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
             result_greedy[out_ids[valid_indices]] = out_correct[valid_indices]
             covered_points[out_ids[valid_indices]] = True

+            total_padding += padding_count
+            total_tokens += batch_tokens

             pbar.set_postfix(
-                padding=(
-                    f"{total_padding + padding_count}/{total_tokens + batch_tokens} ="
-                    f" {(total_padding + padding_count) / (total_tokens + batch_tokens):.2f}"
-                ),
+                padding=f"{total_padding}/{total_tokens} = {(total_padding) / (total_tokens):.2f}",
+                this_padding=f"{padding_count}/{batch_tokens}= {padding_count / batch_tokens:.2f}",
             )
             pbar.update(len(segments_this_batch))
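In the hunk above, the running totals are now accumulated before the progress bar is refreshed, and the bar reports both the cumulative padding fraction and the fraction for the current batch. A toy standalone sketch of that accumulation pattern (made-up batch numbers, not Levanter code):

# Update the running totals first, then report cumulative and per-batch fractions.
total_padding, total_tokens = 0, 0
for padding_count, batch_tokens in [(3, 16), (5, 16), (1, 16)]:
    total_padding += padding_count
    total_tokens += batch_tokens
    print(
        f"padding {total_padding}/{total_tokens} = {total_padding / total_tokens:.2f}, "
        f"this_padding {padding_count}/{batch_tokens} = {padding_count / batch_tokens:.2f}"
    )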
@@ -420,11 +420,14 @@ def to_task_dict(self) -> dict:
                 else:
                     our_name = task.get("task_alias", task["task"]) if isinstance(task, dict) else task
                     our_name = our_name.replace(" ", "_")
-                    this_task = self._get_task_and_rename(manager, our_name, task)
-                    this_tasks[our_name] = this_task
-            except Exception:
+                    tasks_for_this_task_spec = self._get_task_and_rename(manager, our_name, task)
+                    for k, v in tasks_for_this_task_spec.items():
+                        if k in this_tasks:
+                            raise ValueError(f"Task {k} already exists")
+                        this_tasks[k] = v
+            except Exception as e:
                 logger.exception(f"Failed to load task {task}")
-                raise ValueError(f"Failed to load task {task}")
+                raise ValueError(f"Failed to load task {task}") from e

         logger.info(f"Loaded {len(this_tasks)} tasks")
         return this_tasks
@@ -437,12 +440,84 @@ def _get_task_and_rename(self, manager, our_name, task: dict | str):
"""
import lm_eval.tasks as tasks

task_name = task if isinstance(task, str) else task["task"]

task_dict = tasks.get_task_dict([task], manager)
this_task = task_dict.popitem()[1]
# hacky, but this allows us to run multiple instances of the same task with different fewshot settings
this_task.config.task = our_name
assert len(task_dict) == 1, f"Expected 1 task, got {len(task_dict)}"
try:
this_task = self._rename_tasks_for_eval_harness(task_dict, task_name, our_name)
except AttributeError:
logger.exception(f"Failed to rename task {task}: {task_dict}")
raise ValueError(f"Failed to rename task {task}: {task_dict}")
return this_task

def _rename_tasks_for_eval_harness(self, this_task, lm_eval_task_name, our_name):
import lm_eval.tasks as tasks

# hacky, but this allows us to run multiple instances of the same task with different fewshot settings
if isinstance(this_task, dict):
out = {}
for k, v in this_task.items():
v = self._rename_tasks_for_eval_harness(v, lm_eval_task_name, our_name)

if isinstance(k, tasks.ConfigurableGroup):
k._config.group = self._replace_name_with_our_name(k.group, lm_eval_task_name, our_name)
out[k] = v
elif isinstance(k, str):
k = self._replace_name_with_our_name(k, lm_eval_task_name, our_name)
if isinstance(v, dict):
subtask_list = self._get_child_tasks(v)
# ok so inexplicably, lm_eval_harness doesn't wrap the key in a ConfigurableGroup when you pass
# in a task dict (it seems like a mistake), so we need to do that here
# subtask is the name of all of the child tasks in v
group = tasks.ConfigurableGroup(config={"group": k, "task": subtask_list})
out[group] = v
else:
out[k] = v
else:
raise ValueError(f"Unknown key type: {k}")

return out

elif isinstance(this_task, tasks.ConfigurableTask):
this_task.config.task = self._replace_name_with_our_name(
this_task.config.task, lm_eval_task_name, our_name
)
return this_task
else:
raise ValueError(f"Unknown task type: {this_task}")

def _replace_name_with_our_name(self, lm_eval_name, lm_eval_prefix, our_name_prefix):
if our_name_prefix.startswith(lm_eval_prefix):
suffix = our_name_prefix[len(lm_eval_prefix) :]
prefix = lm_eval_prefix
else:
suffix = ""
prefix = our_name_prefix
if lm_eval_prefix in lm_eval_name:
lm_eval_name = lm_eval_name.replace(lm_eval_prefix, prefix) + suffix
else:
lm_eval_name = prefix + "_" + lm_eval_name + suffix
return lm_eval_name

def _get_child_tasks(self, task_group):
import lm_eval.tasks as tasks

out = []
for k, v in task_group.items():
if isinstance(k, tasks.ConfigurableGroup):
subtask_or_tasks = k.config.task
if isinstance(subtask_or_tasks, str):
out.append(subtask_or_tasks)
else:
out.extend(subtask_or_tasks)
elif isinstance(k, str):
out.append(k)
else:
raise ValueError(f"Unknown key type: {k}")

return out


@dataclass(frozen=True)
class EvalHarnessMainConfig:
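To make the renaming concrete: with the harness_nano.yaml config above (task mmlu, task_alias mmlu_1shot), the string logic in _replace_name_with_our_name maps each lm-eval name onto the aliased name. The standalone copy below (illustration only, not part of the diff) reproduces just that string logic; the example subtask name mmlu_anatomy and the alias my_eval are assumptions chosen to show the two branches.

def replace_name_with_our_name(lm_eval_name: str, lm_eval_prefix: str, our_name_prefix: str) -> str:
    # Standalone copy of the renaming logic above, for illustration.
    if our_name_prefix.startswith(lm_eval_prefix):
        suffix = our_name_prefix[len(lm_eval_prefix):]
        prefix = lm_eval_prefix
    else:
        suffix = ""
        prefix = our_name_prefix
    if lm_eval_prefix in lm_eval_name:
        return lm_eval_name.replace(lm_eval_prefix, prefix) + suffix
    return prefix + "_" + lm_eval_name + suffix

# The group itself and each child task pick up the alias as a suffix:
print(replace_name_with_our_name("mmlu", "mmlu", "mmlu_1shot"))          # mmlu_1shot
print(replace_name_with_our_name("mmlu_anatomy", "mmlu", "mmlu_1shot"))  # mmlu_anatomy_1shot
# An alias that does not extend the lm-eval prefix replaces it instead:
print(replace_name_with_our_name("mmlu_anatomy", "mmlu", "my_eval"))     # my_eval_anatomy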
@@ -567,15 +642,14 @@ def _compute_averages(outputs):
     for task_results in outputs["results"].values():
         metric_keys.update(k for k in task_results.keys() if "stderr" not in k and k != "alias")

-    examples_per_task = [task_samples["effective"] for task_samples in outputs["n-samples"].values()]
-
     # Compute macro and micro averages
     for metric in metric_keys:
         # Collect valid tasks for this metric
+        # We iterate over the n-samples because real tasks (as opposed to aggregates like "mmlu") have counts
         valid_tasks = [
-            (task_results.get(metric), examples_per_task[i])
-            for i, (task_name, task_results) in enumerate(outputs["results"].items())
-            if metric in task_results
+            (outputs["results"][task_name].get(metric), outputs["n-samples"][task_name]["effective"])
+            for task_name in outputs["n-samples"]
+            if outputs["results"][task_name].get(metric, None) is not None
         ]

         if not valid_tasks:
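The rewritten comprehension in _compute_averages iterates over n-samples, so aggregate entries such as "mmlu" itself, which report a metric but have no per-task example counts, no longer misalign the weights. A sketch with made-up numbers of the macro (unweighted) versus micro (example-weighted) averaging over the (value, effective count) pairs collected above:

# Illustrative numbers only: per-task (metric value, effective example count) pairs,
# in the shape collected into `valid_tasks` above.
valid_tasks = [(0.62, 100), (0.48, 250), (0.71, 50)]

values = [v for v, _ in valid_tasks]
total_examples = sum(n for _, n in valid_tasks)

macro_avg = sum(values) / len(values)                            # every task weighted equally
micro_avg = sum(v * n for v, n in valid_tasks) / total_examples  # weighted by example count

print(f"macro = {macro_avg:.4f}, micro = {micro_avg:.4f}")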
4 changes: 2 additions & 2 deletions src/levanter/store/cache.py
@@ -1237,7 +1237,7 @@ async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArra
     source_offsets = source_array.offsets[1 : source_num_rows + 1][ts.d[:].translate_to[0]]
     source_offsets = _virtual_offset(source_offsets, data_offset)

-    delay = 1
+    delay = 4
     while True:
         try:
             async with ts.Transaction() as txn:
@@ -1253,7 +1253,7 @@ async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArra
logger.info("Rate limit exceeded. Retrying.")
await asyncio.sleep(delay)
delay *= 2
if delay > 60:
if delay > 120:
raise

futures.append(offset_future)
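The cache.py change only makes the rate-limit retry more patient: the backoff now starts at 4 seconds and gives up once the delay exceeds 120 seconds instead of 60. A generic, self-contained sketch of that retry shape (the helper name with_backoff is made up, and RuntimeError stands in for the tensorstore error actually caught):

import asyncio

async def with_backoff(op, initial_delay: float = 4.0, max_delay: float = 120.0):
    """Retry `op` on failure, doubling the sleep each time, like the loop above."""
    delay = initial_delay
    while True:
        try:
            return await op()
        except RuntimeError:  # stand-in for the rate-limit error handled above
            await asyncio.sleep(delay)
            delay *= 2
            if delay > max_delay:
                raise

attempts = {"n": 0}

async def flaky():
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise RuntimeError("Rate limit exceeded")
    return "ok"

# Tiny delays here so the demo finishes quickly; the commit itself uses 4s and 120s.
print(asyncio.run(with_backoff(flaky, initial_delay=0.01, max_delay=1.0)))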
