@@ -72,5 +72,6 @@
 from benchmarks.eval_accuracy import eval_accuracy
 from benchmarks.eval_accuracy_mmlu import eval_accuracy_mmlu
+from benchmarks.eval_accuracy_longcontext import eval_accuracy_longcontext
 from benchmarks.metrics import CounterMetric, EventMetric
 import grpc
 from jetstream.core.proto import jetstream_pb2
@@ -166,6 +167,7 @@ class InputRequest:
   output: str = ""
   output_len: int = 0
   sample_idx: int = -1
+  metric: str = ""


 @dataclass
@@ -187,6 +189,7 @@ def to_dict(self):
       prompt = self.input_request.prompt
       original_output = self.input_request.output
       sample_idx = self.input_request.sample_idx
+      metric = self.input_request.metric
     else:
       prompt = None
       original_output = None
@@ -201,6 +204,7 @@ def to_dict(self):
         "ttst_sec": self.ttst_sec,
         "prompt_len": self.prompt_len,
         "sample_idx": sample_idx,
+        "metric": metric,
     }


@@ -281,17 +285,19 @@ def load_openorca_dataset_pkl(

 def load_longcontext_dataset_pkl(
     dataset_path: str,
-) -> list[tuple[Any, Any]]:
+) -> tuple[list[tuple[Any, Any]], list]:
   assert os.path.isfile(dataset_path)

   # read pickle file
   data = pandas.read_pickle(dataset_path)

   samples = []
+  metrics = []
   for _, row in data.iterrows():
-    samples.append((row["input"], row["ref_output"]))
+    samples.append((row["input"], row["gt_output"]))
+    metrics.append(row["metric"])

-  return samples
+  return samples, metrics


 def load_mmlu_dataset_csv(dataset_path: str) -> tuple[Any, dict[str, str]]:
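Note: with this change the loader reads the reference answer from a gt_output column (previously ref_output) and returns a parallel list of per-row metric names. A minimal sketch of a compatible pickle file, using made-up values and a hypothetical path:

# Sketch only: builds a DataFrame with the columns the updated loader
# indexes ("input", "gt_output", "metric"); values and filename are made up.
import pandas as pd

df = pd.DataFrame({
    "input": ["<long context prompt>"],
    "gt_output": ["<reference answer>"],
    "metric": ["rouge"],  # per-sample metric name (hypothetical value)
})
df.to_pickle("longcontext_eval.pkl")  # hypothetical filename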
@@ -417,7 +423,6 @@ def filter_dataset(
     tokenized_dataset: list[tuple[str, Any, str, int, int, int]],
     dataset_type: str,
     max_output_length: int = 0,
-    run_mmlu_dataset: bool = False,
     min_input_length: int = 4,
     max_input_length: int = 0,
     max_target_length: int = 0,
@@ -439,7 +444,8 @@ def filter_dataset(
       sample_idx,
   ) in tokenized_dataset:
     if prompt_len < min_input_length or (
-        not (run_mmlu_dataset or dataset_type == "math500") and output_len < 4
+        not (dataset_type == "mmlu" or dataset_type == "math500")
+        and output_len < 4
     ):
       # Prune too short sequences.
       # This is because TGI causes errors when the input or output length
@@ -474,11 +480,11 @@ def sample_requests(
     dataset_type: str,
     max_output_length: int = 0,
     oversample_multiplier: float = 1.2,
-    run_mmlu_dataset: bool = False,
     min_input_length: int = 4,
     max_input_length: int = 0,
     max_target_length: int = 0,
     max_output_multiplier: int = 0,
+    metrics: Optional[list[str]] = None,
 ) -> list[InputRequest]:

   # Original dataset size
@@ -514,13 +520,16 @@ def sample_requests(
       tokenized_dataset,
       dataset_type,
       max_output_length,
-      run_mmlu_dataset,
       min_input_length,
       max_input_length,
       max_target_length,
       max_output_multiplier,
   )

+  if metrics is not None:
+    for request in input_requests:
+      request.metric = metrics[request.sample_idx]
+
   # Sample the requests.
   if len(input_requests) > num_requests:
     input_requests = random.sample(input_requests, num_requests)
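Note: indexing metrics by request.sample_idx is what keeps the pairing correct: filtering and random sampling may drop or reorder requests, but each request carries the index of its original dataset row. A self-contained toy illustration of that invariant (not JetStream code; all names made up):

# Toy illustration of the sample_idx invariant.
import random
from dataclasses import dataclass

@dataclass
class Req:
  sample_idx: int
  metric: str = ""

metrics = ["rouge", "exact_match", "f1"]  # one entry per dataset row
reqs = [Req(sample_idx=i) for i in range(3)]
reqs = random.sample(reqs, 2)             # sampling drops/reorders requests...
for r in reqs:
  r.metric = metrics[r.sample_idx]        # ...but the row lookup still holds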
@@ -1035,11 +1044,6 @@ def parse_args() -> argparse.Namespace:
       choices=["HELM", "Harness", ""],
       help="mmlu method/format to generate shots",
   )
-  parser.add_argument(
-      "--run-mmlu-dataset",
-      action="store_true",
-      help="specify if it's for mmlu dataset",
-  )
   return parser.parse_args()


@@ -1058,6 +1062,7 @@ def main(args: argparse.Namespace):
   api_url = f"{args.server}:{args.port}"

   tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
+  metrics = None
   if tokenizer == "test" or args.dataset == "test":
     input_requests = mock_requests(
         args.total_mock_requests
@@ -1080,7 +1085,7 @@ def main(args: argparse.Namespace):
         args.dataset_path,
     )
   elif args.dataset == "longcontext":
-    dataset = load_longcontext_dataset_pkl(
+    dataset, metrics = load_longcontext_dataset_pkl(
         args.dataset_path,
     )
   else:
@@ -1097,11 +1102,11 @@ def main(args: argparse.Namespace):
         num_requests=args.num_prompts,
         dataset_type=args.dataset,
         max_output_length=args.max_output_length,
-        run_mmlu_dataset=args.run_mmlu_dataset,
         min_input_length=args.min_input_length,
         max_input_length=args.max_input_length,
         max_target_length=args.max_target_length,
         max_output_multiplier=args.max_output_multiplier,
+        metrics=metrics,
     )

   warmup_requests = None
@@ -1145,8 +1150,10 @@ def main(args: argparse.Namespace):
   # Process output
   output = [output.to_dict() for output in request_outputs]
   if args.run_eval:
-    if args.run_mmlu_dataset:
+    if args.dataset == "mmlu":
       eval_json = eval_accuracy_mmlu(output)
+    elif args.dataset == "longcontext":
+      eval_json = eval_accuracy_longcontext(output)
     else:
       eval_json = eval_accuracy(output, args.dataset[:4])

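Note: eval_accuracy_longcontext receives the same list of to_dict() records as the other evaluators, each now carrying the per-sample "metric" field added above. A sketch of one record's shape, limited to the keys visible in this diff (all values hypothetical):

# Sketch of one record passed to eval_accuracy_longcontext; only keys
# shown in this diff are listed, and the values are made up.
record = {
    "ttst_sec": 0.12,
    "prompt_len": 8192,
    "sample_idx": 0,
    "metric": "rouge",  # selects which accuracy metric applies to this sample
}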