from benchmarks.eval_accuracy import eval_accuracy
from benchmarks.eval_accuracy_mmlu import eval_accuracy_mmlu
+from benchmarks.eval_accuracy_longcontext import eval_accuracy_longcontext
from benchmarks.metrics import CounterMetric, EventMetric
import grpc
from jetstream.core.proto import jetstream_pb2
@@ -166,6 +167,7 @@ class InputRequest:
  output: str = ""
  output_len: int = 0
  sample_idx: int = -1
+  metric: str = ""


@dataclass
@@ -187,10 +189,12 @@ def to_dict(self):
      prompt = self.input_request.prompt
      original_output = self.input_request.output
      sample_idx = self.input_request.sample_idx
+      metric = self.input_request.metric
    else:
      prompt = None
      original_output = None
      sample_idx = None
+      metric = None
    return {
        "prompt": prompt,
        "original_output": original_output,
@@ -201,6 +205,7 @@ def to_dict(self):
        "ttst_sec": self.ttst_sec,
        "prompt_len": self.prompt_len,
        "sample_idx": sample_idx,
+        "metric": metric,
    }

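For reference, after this change each element of the benchmark's per-request output carries the per-sample metric name alongside the existing fields. A minimal sketch of the shape, showing only the keys visible in the hunks above (the real dict has further fields not shown here, and every value below is invented for illustration):

example_entry = {
    "prompt": "long-context prompt ...",    # invented value
    "original_output": "reference answer",  # invented value
    "ttst_sec": 0.12,                       # invented value
    "prompt_len": 8192,                     # invented value
    "sample_idx": 0,
    "metric": "rouge",                      # per-sample metric name read from the dataset (name assumed)
}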
@@ -281,17 +286,19 @@ def load_openorca_dataset_pkl(

def load_longcontext_dataset_pkl(
    dataset_path: str,
-) -> list[tuple[Any, Any]]:
+) -> tuple[list[tuple[Any, Any]], list]:
  assert os.path.isfile(dataset_path)

  # read pickle file
  data = pandas.read_pickle(dataset_path)

  samples = []
+  metrics = []
  for _, row in data.iterrows():
-    samples.append((row["input"], row["ref_output"]))
+    samples.append((row["input"], row["gt_output"]))
+    metrics.append(row["metric"])

-  return samples
+  return samples, metrics


def load_mmlu_dataset_csv(dataset_path: str) -> tuple[Any, dict[str, str]]:
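A hedged sketch of the pickle layout the updated loader expects, inferred only from the columns it reads ("input", "gt_output", "metric"); the file path, row contents, and metric names are made up, and the example assumes load_longcontext_dataset_pkl is importable from the benchmark script:

import pandas

rows = [
    {"input": "Summarize the attached report ...", "gt_output": "A short summary.", "metric": "rouge"},
    {"input": "Which year is cited in section 3?", "gt_output": "1997", "metric": "em"},
]
pandas.DataFrame(rows).to_pickle("/tmp/longcontext_sample.pkl")  # hypothetical path

samples, metrics = load_longcontext_dataset_pkl("/tmp/longcontext_sample.pkl")
assert len(samples) == len(metrics) == 2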
@@ -417,7 +424,6 @@ def filter_dataset(
    tokenized_dataset: list[tuple[str, Any, str, int, int, int]],
    dataset_type: str,
    max_output_length: int = 0,
-    run_mmlu_dataset: bool = False,
    min_input_length: int = 4,
    max_input_length: int = 0,
    max_target_length: int = 0,
@@ -439,7 +445,8 @@ def filter_dataset(
      sample_idx,
  ) in tokenized_dataset:
    if prompt_len < min_input_length or (
-        not (run_mmlu_dataset or dataset_type == "math500") and output_len < 4
+        not (dataset_type == "mmlu" or dataset_type == "math500")
+        and output_len < 4
    ):
      # Prune too short sequences.
      # This is because TGI causes errors when the input or output length
@@ -474,11 +481,11 @@ def sample_requests(
    dataset_type: str,
    max_output_length: int = 0,
    oversample_multiplier: float = 1.2,
-    run_mmlu_dataset: bool = False,
    min_input_length: int = 4,
    max_input_length: int = 0,
    max_target_length: int = 0,
    max_output_multiplier: int = 0,
+    metrics: Optional[list[str]] = None,
) -> list[InputRequest]:

  # Original dataset size
@@ -514,13 +521,16 @@ def sample_requests(
      tokenized_dataset,
      dataset_type,
      max_output_length,
-      run_mmlu_dataset,
      min_input_length,
      max_input_length,
      max_target_length,
      max_output_multiplier,
  )

+  if metrics is not None:
+    for request in input_requests:
+      request.metric = metrics[request.sample_idx]
+
  # Sample the requests.
  if len(input_requests) > num_requests:
    input_requests = random.sample(input_requests, num_requests)
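The lookup by sample_idx matters because filtering and random sampling drop and reorder requests, so the metrics list (aligned with the original dataset rows) cannot simply be zipped onto the surviving requests. A standalone illustration of the indexing, using a toy dataclass rather than the benchmark's InputRequest and invented metric names:

from dataclasses import dataclass

@dataclass
class _ToyRequest:
  sample_idx: int = -1
  metric: str = ""

metrics = ["em", "rouge", "f1"]  # aligned with original dataset rows (values invented)
kept = [_ToyRequest(sample_idx=2), _ToyRequest(sample_idx=0)]  # order after filtering/sampling
for request in kept:
  request.metric = metrics[request.sample_idx]
assert [r.metric for r in kept] == ["f1", "em"]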
@@ -1035,11 +1045,6 @@ def parse_args() -> argparse.Namespace:
      choices=["HELM", "Harness", ""],
      help="mmlu method/format to generate shots",
  )
-  parser.add_argument(
-      "--run-mmlu-dataset",
-      action="store_true",
-      help="specify if it's for mmlu dataset",
-  )
  return parser.parse_args()

@@ -1058,6 +1063,7 @@ def main(args: argparse.Namespace):
  api_url = f"{args.server}:{args.port}"

  tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
+  metrics = None
  if tokenizer == "test" or args.dataset == "test":
    input_requests = mock_requests(
        args.total_mock_requests
@@ -1080,7 +1086,7 @@ def main(args: argparse.Namespace):
        args.dataset_path,
    )
  elif args.dataset == "longcontext":
-    dataset = load_longcontext_dataset_pkl(
+    dataset, metrics = load_longcontext_dataset_pkl(
        args.dataset_path,
    )
  else:
@@ -1097,11 +1103,11 @@ def main(args: argparse.Namespace):
      num_requests=args.num_prompts,
      dataset_type=args.dataset,
      max_output_length=args.max_output_length,
-      run_mmlu_dataset=args.run_mmlu_dataset,
      min_input_length=args.min_input_length,
      max_input_length=args.max_input_length,
      max_target_length=args.max_target_length,
      max_output_multiplier=args.max_output_multiplier,
+      metrics=metrics,
  )

  warmup_requests = None
@@ -1145,8 +1151,10 @@ def main(args: argparse.Namespace):
  # Process output
  output = [output.to_dict() for output in request_outputs]
  if args.run_eval:
-    if args.run_mmlu_dataset:
+    if args.dataset == "mmlu":
      eval_json = eval_accuracy_mmlu(output)
+    elif args.dataset == "longcontext":
+      eval_json = eval_accuracy_longcontext(output)
    else:
      eval_json = eval_accuracy(output, args.dataset[:4])
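Restated outside the diff for clarity (this is not the file's code verbatim, just the same dispatch logic): the dataset name alone now selects the evaluator, which is why the --run-mmlu-dataset flag could be dropped.

def pick_evaluator(dataset: str):
  # "mmlu" and "longcontext" get dedicated evaluators; everything else
  # falls back to the generic eval_accuracy keyed on the dataset prefix.
  if dataset == "mmlu":
    return eval_accuracy_mmlu
  if dataset == "longcontext":
    return eval_accuracy_longcontext
  return lambda output: eval_accuracy(output, dataset[:4])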