-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathcustom.py
81 lines (63 loc) · 2.24 KB
/
custom.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""
For defining a custom set of benchmarks.
"""
import csv
from dataclasses import dataclass, field
from os import PathLike
from pathlib import Path
from typing import List, TypedDict, cast
import jsonschema
import jsonschema.validators
from benchmark import LMC, TasksStore, LoadedTask, ResultStatus, ZeroShotTask, judge_result
# JSON Schema that every row loaded by CustomTasks.from_csv is validated against.
CUSTOM_TASK_SCHEMA = {
    "type": "object",
    "properties": {
        "id": {"type": "string"},
        "prompt": {"type": "string"},
        "answer": {"type": "string"},
    },
    # "properties" alone only constrains keys that are present. Listing the
    # keys as required makes a row with a missing column fail validation up
    # front instead of raising a KeyError later when the task is used.
    "required": ["id", "prompt", "answer"],
}
class CustomTask(TypedDict):
    """Dictionary shape of one user-defined benchmark task."""

    # Identifier of the task.
    id: str
    # Prompt text for the task.
    prompt: str
    # Expected answer that a response is judged against.
    answer: str
@dataclass
class LoadedCustomTask(LoadedTask[CustomTask]):
    """Wraps a CustomTask, exposing it as a zero-shot prompt and grading responses."""

    # The raw task record (id, prompt and expected answer).
    task: CustomTask

    def to_zero_shot(self) -> ZeroShotTask:
        """Return the task's id and prompt, without the expected answer."""
        record = self.task
        return {"id": record["id"], "prompt": record["prompt"]}

    def to_result_status(self, messages: List[LMC]) -> ResultStatus:
        """Classify a conversation by inspecting its final message.

        Returns ``"unknown"`` when there are no messages, or when the final
        message has no ``"role"`` or no ``"content"`` key; ``"error"`` when
        the final message's role is ``"error"``; otherwise whatever
        ``judge_result`` returns for the prompt, the final content and the
        expected answer.
        """
        if not messages:
            return "unknown"

        last = messages[-1]
        if "role" not in last:
            return "unknown"
        if last["role"] == "error":
            return "error"
        if "content" in last:
            expected_answer = self.task["answer"]
            question = self.to_zero_shot()["prompt"]
            return judge_result(question, last["content"], expected_answer)
        return "unknown"
@dataclass
class CustomTasks(TasksStore[CustomTask]):
    """In-memory store of custom tasks, buildable from a list or a CSV file."""

    # The tasks held by this store; defaults to a fresh empty list.
    tasks: List[CustomTask] = field(default_factory=list)

    @staticmethod
    def from_list(l: List[CustomTask]) -> "CustomTasks":
        """Build a store around the given list (stored as-is, not copied)."""
        return CustomTasks(l)

    @staticmethod
    def from_csv(path: PathLike) -> "CustomTasks":
        """Load tasks from a CSV file whose first row is the header.

        Every row is validated against ``CUSTOM_TASK_SCHEMA`` before being
        added to the store.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
            jsonschema.ValidationError: if a row does not satisfy the schema.
        """
        if not Path(path).exists():
            raise FileNotFoundError(f"'{path}' does not exist, so it can't be used to load a CustomBenchmark.")

        rows: List[CustomTask] = []
        # newline="" hands newline handling to the csv module, which is
        # required for quoted fields that contain line breaks (for example
        # multi-line prompts) to be parsed correctly.
        with open(path, "r", newline="") as file:
            for row in csv.DictReader(file):
                jsonschema.validate(row, CUSTOM_TASK_SCHEMA)
                rows.append(cast(CustomTask, row))
        return CustomTasks(rows)

    def get_tasks(self) -> List[CustomTask]:
        """Return the list of tasks held by this store."""
        return self.tasks

    def load_task(self, task: CustomTask) -> LoadedTask[CustomTask]:
        """Wrap ``task`` in a ``LoadedCustomTask``."""
        return LoadedCustomTask(task)