Skip to content

Save and output number of samples of each task #851

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/saving-and-reading-results.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,10 @@ The detail file contains the following columns:
"padded": 0,
"non_padded": 2,
"num_truncated_few_shots": 0
},
"num_samples": {
"lighteval|gsm8k|0": 1,
"all": 1
}
}
```
29 changes: 29 additions & 0 deletions src/lighteval/logging/evaluation_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import collections
import json
import logging
import os
Expand Down Expand Up @@ -228,6 +229,7 @@ def save(self) -> None:
"config_tasks": self.task_config_logger.tasks_configs,
"summary_tasks": self.details_logger.compiled_details,
"summary_general": asdict(self.details_logger.compiled_details_over_all_tasks),
"num_samples": self.calculate_num_samples(),
}

# Create the details datasets for later upload
Expand Down Expand Up @@ -351,6 +353,7 @@ def generate_final_dict(self) -> dict:
"config_tasks": self.task_config_logger.tasks_configs,
"summary_tasks": self.details_logger.compiled_details,
"summary_general": asdict(self.details_logger.compiled_details_over_all_tasks),
"num_samples": self.calculate_num_samples(),
}

final_dict = {
Expand Down Expand Up @@ -724,3 +727,29 @@ def push_to_tensorboard( # noqa: C901
f"Pushed to tensorboard at https://huggingface.co/{self.tensorboard_repo}/{output_dir_tb}/tensorboard"
f" at global_step {global_step}"
)

def calculate_num_samples(self) -> dict[str, int]:
    """Count the number of samples per task, per task group, and overall.

    Individual counts come from ``self.details_logger.details`` (one entry per
    task, whose value is the collection of that task's samples). Tasks whose
    name has the form ``suite|task[:subtask]|fewshot`` are additionally
    grouped under ``suite|task:_average|fewshot``; a group entry is only
    emitted when it contains more than one subtask. The grouping key scheme is
    intended to match the one used by MetricsLogger.aggregate(), so that the
    subgroup names line up with the aggregated metrics.

    Returns:
        A mapping from task name (and group name, and ``"all"``) to its
        number of samples. ``"all"`` is the total over individual tasks only.

    Raises:
        ValueError: If a task name contains ``"|"`` but does not split into
            exactly three ``|``-separated parts.
    """
    # Count samples of individual tasks
    per_task = {name: len(samples) for name, samples in self.details_logger.details.items()}

    # Group full task names by their "<suite>|<task prefix>:_average|<fewshot>" key.
    # The full name must be stored (not the middle part), because it is the key
    # used to look the count up again below.
    grouped_tasks = collections.defaultdict(list)
    for task_name in per_task:
        if "|" in task_name:
            suite, task, fewshot = task_name.split("|")
            grouped_tasks[f"{suite}|{task.split(':')[0]}:_average|{fewshot}"].append(task_name)

    num_samples = dict(per_task)
    for average_task, list_of_subtasks in grouped_tasks.items():
        if len(list_of_subtasks) > 1:
            num_samples[average_task] = sum(per_task[k] for k in list_of_subtasks)

    # Total over individual tasks only: summing after the group entries were
    # inserted would count every grouped subtask twice.
    num_samples["all"] = sum(count for name, count in per_task.items() if name != "all")

    return num_samples
11 changes: 8 additions & 3 deletions src/lighteval/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,24 +158,29 @@ def flatten(item: list[Union[list, str]]) -> list[str]:
def make_results_table(result_dict):
    """Generate a Markdown table of results.

    Args:
        result_dict: Mapping with a ``"results"`` entry (task name -> metric
            name -> value) and a ``"versions"`` entry (task name -> version).
            An optional ``"num_samples"`` entry (task name -> sample count)
            fills the "Number of Samples" column.

    Returns:
        The table rendered by ``MarkdownTableWriter.dumps()``.
    """
    md_writer = MarkdownTableWriter()
    md_writer.headers = ["Task", "Version", "Number of Samples", "Metric", "Value", "", "Stderr"]

    values = []

    # For backwards compatibility, fall back to an empty dict if result_dict doesn't contain num_samples
    num_samples_dict = result_dict.get("num_samples", {})
    versions = result_dict["versions"]

    for k in sorted(result_dict["results"]):
        dic = result_dict["results"][k]
        version = versions.get(k, "")
        num_samples = num_samples_dict.get(k, "")
        for m, v in dic.items():
            # Stderr entries are rendered on the row of the metric they belong to.
            if m.endswith("_stderr"):
                continue

            if m + "_stderr" in dic:
                se = dic[m + "_stderr"]
                values.append([k, version, num_samples, m, "%.4f" % v, "±", "%.4f" % se])
            else:
                values.append([k, version, num_samples, m, "%.4f" % v, "", ""])
            # Only the first row of a task shows its name, version and sample count.
            k = ""
            version = ""
            num_samples = ""
    md_writer.value_matrix = values

    return md_writer.dumps()
Expand Down