 import math
 import os
 import os.path as osp
+import shutil
 from datetime import datetime
 from typing import Any, Dict, List, Optional

                                model_abbr_from_cfg)
 from opencompass.utils.prompt import get_prompt_hash

+model_name_mapping = {
+    'llama-2-7b-chat-hf': 'LLaMA-2-7B',
+    'llama-2-13b-chat-hf': 'LLaMA-2-13B',
+    'llama-2-70b-chat-hf': 'LLaMA-2-70B',
+    'baichuan2-7b-chat-hf': 'Baichuan2-7B',
+    'baichuan2-13b-chat-hf': 'Baichuan2-13B',
+    'yi-6b-chat-hf': 'Yi-6B',
+    'yi-34b-chat-hf': 'Yi-34B',
+    'deepseek-67b-chat-hf': 'DeepSeek-67B',
+    'wizardlm-70b-v1.0-vllm': 'WizardLM-70B',
+    'qwen-14b-chat-hf': 'Qwen-14B',
+    'qwen-72b-chat-hf': 'Qwen-72B',
+    'qwen-72b-chat-vllm': 'Qwen-72B-vLLM',
+    'internlm2-chat-7b-turbomind': 'InternLM2-7B-200K',
+    'internlm2-chat-20b-turbomind': 'InternLM2-20B-200K',
+    'internlm2-chat-7b-hf': 'InternLM2-7B',
+    'internlm2-chat-20b-hf': 'InternLM2-20B',
+    'qwen-7b-chat-hf': 'Qwen-7B',
+    'chatglm3-6b-hf': 'ChatGLM3-6B',
+    'chatglm3-6b-32k-hf': 'ChatGLM3-6B-32K',
+    'zephyr-7b-beta-vllm': 'Zephyr-7B Beta',
+    'mistral-7b-instruct-v0.2-vllm': 'Mistral-7B Inst. v0.2',
+    'mistral-7b-instruct-v0.1-vllm': 'Mistral-7B Inst. v0.1',
+    'mixtral-8x7b-instruct-v0.1-vllm': 'Mixtral-8x7B Inst. v0.1',
+    'orionstar-yi-34b-chat-hf': 'OrionStar-Yi-34B',
+    'orionstar-14b-long-chat-vllm': 'Orion-14B-LongChat',
+    'internlm-chat-7b-hf': 'InternLM-7B',
+    'gemma-2b-it-hf': 'Gemma-2B',
+    'gemma-7b-it-hf': 'Gemma-7B',
+    'qwen1.5-0.5b-chat-hf': 'Qwen-1.5-0.5B',
+    'qwen1.5-1.8b-chat-hf': 'Qwen-1.5-1.8B',
+    'qwen1.5-4b-chat-hf': 'Qwen-1.5-4B',
+    'qwen1.5-14b-chat-hf': 'Qwen-1.5-14B',
+    'qwen1.5-72b-chat-hf': 'Qwen-1.5-72B',
+    'qwen1.5-14b-chat-vllm': 'Qwen-1.5-14B-vLLM',
+    'qwen1.5-72b-chat-vllm': 'Qwen-1.5-72B-vLLM',
+    'glm4_notools': 'GLM-4',
+    'claude-3-opus': 'Claude-3-Opus',
+    # Add more mappings as necessary
+}
+
+dataset_mapping_dict = {}
+
+needle_counts = ['2', '3', '4', '5']
+languages = ['en', 'zh']
+sizes = ['4k', '8k', '32k', '200k', '1000k']
+types = ['origin', 'parallel']
+
+for needle_count in needle_counts:
+    for language in languages:
+        for size in sizes:
+            key = f'{needle_count}needle_{language}_{size}'
+            value = f'{needle_count}-Needle-Reasoning-{language.upper()}-{size.upper()}'
+            dataset_mapping_dict[key] = value
+for t in types:
+    for language in languages:
+        for size in sizes:
+            if t == 'origin':
+                key = f'{t}_{language}_{size}'
+                value = f'Single-Needle-Retrieval-{language.upper()}-{size.upper()}'
+            elif t == 'parallel':
+                key = f'{t}_{language}_{size}'
+                value = f'Multi-Needle-Retrieval-{language.upper()}-{size.upper()}'
+            dataset_mapping_dict[key] = value
+
+
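For reference, a minimal sketch of the lookups the loops above produce; the keys and display names are derived from the f-strings in the added code, not copied from the commit:

# Illustrative lookups, assuming the mapping code above has run:
assert dataset_mapping_dict['2needle_en_4k'] == '2-Needle-Reasoning-EN-4K'
assert dataset_mapping_dict['origin_zh_200k'] == 'Single-Needle-Retrieval-ZH-200K'
assert dataset_mapping_dict['parallel_en_1000k'] == 'Multi-Needle-Retrieval-EN-1000K'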
+def calculate_elementwise_average(model_name, merged_df):
+    score_columns = [col for col in merged_df.columns if col != 'dataset']
+
+    origin_columns = [col for col in score_columns if 'origin' in col]
+    parallel_columns = [col for col in score_columns if 'parallel' in col]
+    multi_columns = [col for col in score_columns if 'needle' in col]
+
+    if origin_columns and parallel_columns and multi_columns:
+        origin_avg = merged_df[origin_columns].mean(axis=1) * 0.4
+        parallel_avg = merged_df[parallel_columns].mean(axis=1) * 0.3
+        multi_avg = merged_df[multi_columns].mean(axis=1) * 0.3
+        merged_df[model_name] = origin_avg + parallel_avg + multi_avg
+    else:
+        relevant_columns = origin_columns or parallel_columns or multi_columns
+        if relevant_columns:
+            merged_df[model_name] = merged_df[relevant_columns].mean(axis=1)
+        else:
+            merged_df[model_name] = pd.Series([0] * len(merged_df))
+
+    return merged_df.iloc[:, [0, -1]]
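A small, hypothetical example of the weighting above: single-needle retrieval counts 0.4, multi-needle (parallel) retrieval 0.3, and multi-needle reasoning 0.3. The column and model names below are placeholders, not from the commit:

import pandas as pd

example = pd.DataFrame({
    'dataset': ['Length4000Depth0'],   # placeholder dataset label
    'origin_en_4k': [90.0],            # single-needle retrieval score
    'parallel_en_4k': [80.0],          # multi-needle retrieval score
    '2needle_en_4k': [70.0],           # multi-needle reasoning score
})
weighted = calculate_elementwise_average('demo-model', example)
# 90 * 0.4 + 80 * 0.3 + 70 * 0.3 == 81.0; the result keeps ['dataset', 'demo-model']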

 def read_after_specific_line_except_last(file_name, keyword, offset):
     with open(file_name, 'r', encoding='utf-8') as file:
@@ -65,6 +152,12 @@ def create_model_dataframe(nested_dict, model_name, dataset_abbr, parallel=False
     df = pd.DataFrame(data, columns=['dataset', model_name])
     return df

+def convert_to_k(value):
+    try:
+        return f'{int(value) // 1000}k'
+    except ValueError:
+        return value
+
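For instance, the helper above shortens numeric token counts into axis labels and leaves anything non-numeric alone:

assert convert_to_k('32000') == '32k'
assert convert_to_k(200000) == '200k'
assert convert_to_k('32k') == '32k'   # non-numeric labels pass through unchanged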
 def parse_model_scores(text):
     lines = text.split('\n')

@@ -82,8 +175,86 @@ def parse_model_scores(text):

     return result_dict

+def remove_empty_subfolders(plot_path):
+    for folder_name in tqdm(os.listdir(plot_path),
+                            desc='Deleting Empty folders'):
+        folder_path = os.path.join(plot_path, folder_name)
+        if os.path.isdir(folder_path):
+            if not os.listdir(folder_path):
+                shutil.rmtree(folder_path)
+
+def save_results_to_plots(txt_results_save_path):
+    content = read_after_specific_line_except_last(txt_results_save_path, 'raw format', 2)
+    parsed_data = parse_model_scores(content)
+    model_names = get_dict_model_names(parsed_data)
+    numbers = [2, 3, 4, 5]
+    languages = ['en', 'zh']
+    size_exists = []
+    sizes_origin = ['_4k', '_8k', '_32k', '_128k', '_200k', '_1000k']
+
+    for size in sizes_origin:
+        if size in content:
+            size_exists.append(size)
+
+    multi_dataset_abbrs = [f'{num}needle_{lang}{size}' for num in numbers for lang in languages for size in size_exists]
+    origin_dataset_abbrs = [f'origin_{lang}{size}' for lang in languages for size in size_exists]
+    parallel_dataset_abbrs = [f'parallel_{lang}{size}' for lang in languages for size in size_exists]
+
+    dataset_abbrs = multi_dataset_abbrs + origin_dataset_abbrs + \
+        parallel_dataset_abbrs
+    base_path = os.path.dirname(txt_results_save_path)
+    plot_path = os.path.join(base_path, 'plots')
+
+    model_scores = {}
+
+    for model_name in tqdm(model_names):
+        model_datasets_scores = {}  # Dictionary to store scores for each dataset for the current model
+        for dataset_abbr in dataset_abbrs:
+            parallel_flag = 'parallel' in dataset_abbr
+
+            folder_path = os.path.join(plot_path, dataset_mapping_dict[dataset_abbr])
+            ensure_directory(folder_path)
+
+            save_path = os.path.join(folder_path, f'{model_name}.png')
+
+            df = create_model_dataframe(parsed_data, model_name, dataset_abbr, parallel=parallel_flag)
+
+            score = visualize(df, save_path, model_name, dataset_abbr)
+
+            model_datasets_scores[dataset_abbr] = '{:.02f}'.format(score)
+
+        overall_dataset_abbrs = multi_dataset_abbrs + origin_dataset_abbrs + parallel_dataset_abbrs
+        overall_score_pic_path = os.path.join(plot_path, f'{model_name}_overall.png')
+        merged_df = merge_dataframes(model_name, overall_dataset_abbrs, parsed_data)
+        averaged_df = calculate_elementwise_average(model_name, merged_df)
+        overall_score = visualize(averaged_df, overall_score_pic_path, model_name, 'Overall Score')
+
+        # Single-Retrieval
+        single_retrieval_score_pic_path = os.path.join(plot_path, f'{model_name}_single_retrieval_overall.png')
+        single_retrieval_merged_df = merge_dataframes(model_name, origin_dataset_abbrs, parsed_data)
+        single_retrieval_averaged_df = calculate_elementwise_average(model_name, single_retrieval_merged_df)
+        single_retrieval_overall_score = visualize(single_retrieval_averaged_df, single_retrieval_score_pic_path, model_name, 'Single-Retrieval Overall Score')
+
+        # Multi-Retrieval
+        multi_retrieval_score_pic_path = os.path.join(plot_path, f'{model_name}_multi_retrieval_overall.png')
+        multi_retrieval_merged_df = merge_dataframes(model_name, parallel_dataset_abbrs, parsed_data)
+        multi_retrieval_averaged_df = calculate_elementwise_average(model_name, multi_retrieval_merged_df)
+        multi_retrieval_overall_score = visualize(multi_retrieval_averaged_df, multi_retrieval_score_pic_path, model_name, 'Multi-Retrieval Overall Score')
+
+        # Multi-Reasoning
+        multi_reasoning_score_pic_path = os.path.join(plot_path, f'{model_name}_multi_reasoning_overall.png')
+        multi_reasoning_merged_df = merge_dataframes(model_name, multi_dataset_abbrs, parsed_data)
+        multi_reasoning_averaged_df = calculate_elementwise_average(model_name, multi_reasoning_merged_df)
+        multi_reasoning_overall_score = visualize(multi_reasoning_averaged_df, multi_reasoning_score_pic_path, model_name, 'Multi-Reasoning Overall Score')
+
+        model_scores[model_name] = averaged_df
+    remove_empty_subfolders(plot_path)
+    return model_scores
+
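A hypothetical call, assuming the summarizer has already written a summary .txt containing the 'raw format' section (the path below is a placeholder):

# Heatmaps are written to a 'plots/' directory next to the summary file;
# the return value maps each model name to its weighted-average DataFrame.
scores = save_results_to_plots('outputs/needlebench/summary/summary_20240101_000000.txt')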
 def visualize(df_raw, save_path: str, model_name: str, dataset_type: str):
     df = df_raw.copy()
+    if df.empty:
+        return -1
     df['Context Length'] = df['dataset'].apply(
         lambda x: int(x.split('Length')[1].split('Depth')[0]))
     df['Document Depth'] = df['dataset'].apply(
@@ -98,144 +269,96 @@ def visualize(df_raw, save_path: str,model_name: str ,dataset_type:str):
     model_df = df[['Document Depth', 'Context Length',
                    model_name]].copy()
     model_df.rename(columns={model_name: 'Score'}, inplace=True)
-
-    # Create pivot table
     pivot_table = pd.pivot_table(model_df,
-                                 values='Score',
-                                 index=['Document Depth'],
-                                 columns=['Context Length'],
-                                 aggfunc='mean')
+                                 values='Score',
+                                 index=['Document Depth'],
+                                 columns=['Context Length'],
+                                 aggfunc='mean')

-    # Calculate mean scores
     mean_scores = pivot_table.mean().values
-
-    # Calculate overall score
     overall_score = mean_scores.mean()
-
-    # Create heatmap and line plot
-    plt.figure(figsize=(15.5, 8))
+    plt.figure(figsize=(10, 6))
     ax = plt.gca()
     cmap = LinearSegmentedColormap.from_list(
         'custom_cmap', ['#F0496E', '#EBB839', '#0CD79F'])

-    # Draw heatmap
     sns.heatmap(pivot_table,
                 cmap=cmap,
                 ax=ax,
-                cbar_kws={'label': 'Score'},
                 vmin=0,
                 vmax=100)
-
-    # Set line plot data
+    cbar = ax.collections[0].colorbar
     x_data = [i + 0.5 for i in range(len(mean_scores))]
     y_data = mean_scores

-    # Create twin axis for line plot
     ax2 = ax.twinx()
-    # Draw line plot
     ax2.plot(x_data,
-             y_data,
-             color='white',
-             marker='o',
-             linestyle='-',
-             linewidth=2,
-             markersize=8,
-             label='Average Depth Score')
-    # Set y-axis range
+             y_data,
+             color='white',
+             marker='o',
+             linestyle='-',
+             linewidth=2,
+             markersize=8,
+             label='Average Depth Score'
+             )
     ax2.set_ylim(0, 100)

-    # Hide original y-axis ticks and labels
     ax2.set_yticklabels([])
     ax2.set_yticks([])

-    # Add legend
-    ax2.legend(loc='upper left')
-
-    # Set chart title and labels
-    ax.set_title(f'{model_name} {dataset_type} Context '
-                 'Performance\nFact Retrieval Across '
-                 'Context Lengths ("Needle In A Haystack")')
-    ax.set_xlabel('Token Limit')
-    ax.set_ylabel('Depth Percent')
-    ax.set_xticklabels(pivot_table.columns.values, rotation=45)
-    ax.set_yticklabels(pivot_table.index.values, rotation=0)
-    # Add overall score as a subtitle
-    plt.text(0.5,
-             -0.13, f'Overall Score for {model_name}: '
-             f'{overall_score:.2f}',
-             ha='center',
-             va='center',
-             transform=ax.transAxes,
-             fontsize=13)
-
-    plt.tight_layout()
-    plt.subplots_adjust(right=1)
-    plt.draw()
-    plt.savefig(save_path)
-    print(f'Saved :{save_path}')
-    plt.close()  # Close figure to prevent memory leaks
-    return overall_score
-
-def save_results_to_plots(txt_results_save_path):
-
-    content = read_after_specific_line_except_last(txt_results_save_path, 'raw format', 2)
-
-    parsed_data = parse_model_scores(content)
-    model_names = get_dict_model_names(parsed_data)
-    numbers = [2, 3, 4, 5]
-    languages = ['en', 'zh']
-    size_exists = []
-    sizes_origin = ['_4k', '_8k', '_32k', '_128k', '_200k']
-
-    for size in sizes_origin:
-        if size in content:
-            size_exists.append(size)
-
-    multi_dataset_abbrs = [f'{num}needle_{lang}{size}' for num in numbers for lang in languages for size in size_exists]
-    origin_dataset_abbrs = [f'origin_{lang}{size}' for lang in languages for size in size_exists]
-    parallel_dataset_abbrs = [f'parallel_{lang}{size}' for lang in languages for size in size_exists]
-
-    dataset_abbrs = multi_dataset_abbrs + origin_dataset_abbrs + \
-        parallel_dataset_abbrs
-    base_path = os.path.dirname(txt_results_save_path)
-    plot_path = os.path.join(base_path, 'plots')
-    model_scores = {}
-    for model_name in tqdm(model_names):
-        model_datasets_scores = {}  # Dictionary to store scores for each dataset for the current model
-        for dataset_abbr in dataset_abbrs:
-            parallel_flag = 'parallel' in dataset_abbr
+    ax2.legend(loc='lower left')

-            # Create a directory for each dataset_abbr
-            folder_path = os.path.join(plot_path, dataset_abbr)
-            ensure_directory(folder_path)
+    if model_name in model_name_mapping:
+        title_name = model_name_mapping[model_name]
+    else:
+        title_name = model_name

-            # Construct the full path to save the image
-            save_path = os.path.join(folder_path, f'{model_name}.png')
+    ax.set_title(title_name, fontsize=12, fontweight='bold', pad=15)

-            # Create DataFrame for the model and dataset
-            df = create_model_dataframe(parsed_data, model_name, dataset_abbr, parallel=parallel_flag)
+    if dataset_type in dataset_mapping_dict:
+        dataset_name = dataset_mapping_dict[dataset_type]
+    else:
+        dataset_name = dataset_type
+
+    ax.text(0.5, 1.005, f'{dataset_name}:{overall_score:.2f}',
+            transform=ax.transAxes,
+            ha='center',
+            fontsize=12,
+            fontweight='normal')
+    ax.set_xlabel('Token Length', fontsize=13, fontweight='normal', labelpad=1)
+    ax.set_ylabel('Depth Percent(%)', fontsize=13, fontweight='normal', labelpad=1)
+    converted_labels = [convert_to_k(value) for value in pivot_table.columns.values]
+
+    ax.tick_params(axis='both', which='major', length=1, pad=1)
+    ax.tick_params(axis='both', which='minor', length=1, pad=1)
+    ax.set_xticklabels(converted_labels, rotation=45)
+    index_length = len(pivot_table.index)
+
+    selected_indices = pivot_table.index.values[::2]
+    labels = [str(int(index)) for index in selected_indices]
+    ax.set_yticks(np.arange(0, len(pivot_table.index), 2))
+    ax.set_yticklabels(labels, rotation=0)
+    for spine in ax.spines.values():
+        spine.set_visible(False)
+    for spine in ax2.spines.values():
+        spine.set_visible(False)

-            # Generate visualization and get the score
-            score = visualize(df, save_path, model_name, dataset_abbr)
+    plt.tight_layout()
+    plt.draw()
+    directory_path, original_filename = os.path.split(save_path)

-            # Store the score in the dictionary
-            model_datasets_scores[dataset_abbr] = '{:.02f}'.format(score)
+    filename_suffix = (title_name + '_' + dataset_name).replace(' ', '_')
+    new_filename = f'{filename_suffix}.png'

-        # Process and visualize the overall score
-        overall_score_pic_path = os.path.join(plot_path, f'{model_name}_overall.png')
-        merged_df = merge_dataframes(model_name, dataset_abbrs, parsed_data)
+    new_save_path = os.path.join(directory_path, new_filename)

-        print(merge_dataframes)
-        averaged_df = calculate_elementwise_average(merged_df)
+    plt.savefig(new_save_path, format='png', bbox_inches='tight', pad_inches=0)
+    print(f'Saved: {new_save_path}')

-        # Assume visualize returns the average score for the overall visualization
-        overall_score = visualize(averaged_df, overall_score_pic_path, 'weighted_average_score', 'Overall Score')
+    plt.close()

-        # Add the overall score to the dictionary
-        model_datasets_scores['Overall'] = '{:.02f}'.format(overall_score)
+    return overall_score

-        # Add the model's scores to the main dictionary
-        model_scores[model_name] = model_datasets_scores

 def ensure_directory(path):
     if not os.path.exists(path):
@@ -263,29 +386,11 @@ def merge_dataframes(model_name, dataset_abbrs, parsed_data):
     merged_df = reduce(lambda left, right: pd.merge(left, right, on='dataset', how='outer'), dfs)

     if merged_df.isnull().any().any():
-        print('Warning: Some rows were filtered out due to NaN values. This is often due to mismatched row counts among DataFrames.')
+        print('Warning: Some rows were filtered out due to NaN values. '
+              'This is often due to mismatched row counts among DataFrames.')
         merged_df = merged_df.dropna()
     return merged_df

-def calculate_elementwise_average(merged_df):
-    score_columns = [col for col in merged_df.columns if col != 'dataset']
-
-    origin_columns = [col for col in score_columns if 'origin' in col]
-    parallel_columns = [col for col in score_columns if 'parallel' in col]
-    multi_columns = [col for col in score_columns if 'needle' in col]
-
-    if origin_columns and parallel_columns and multi_columns:
-        origin_avg = merged_df[origin_columns].mean(axis=1) * 0.4
-        parallel_avg = merged_df[parallel_columns].mean(axis=1) * 0.3
-        multi_avg = merged_df[multi_columns].mean(axis=1) * 0.3
-
-        merged_df['weighted_average_score'] = origin_avg + parallel_avg + multi_avg
-    else:
-        merged_df['weighted_average_score'] = pd.Series([0] * len(merged_df))
-
-    return merged_df.iloc[:, [0, -1]]
-
-
 class NeedleBenchSummarizer(DefaultSummarizer):
     """NeedleBench summarizer in OpenCompass.

@@ -303,20 +408,17 @@ def _format_table(self, parsed_results, dataset_metrics, dataset_eval_mode):

         summarizer_dataset_abbrs = []
         if self.dataset_abbrs is None:
-            # display all dataset metrics included in the config
             for dataset_abbr in dataset_abbrs:
                 if dataset_abbr in dataset_metrics:
                     for metric in dataset_metrics[dataset_abbr]:
                         summarizer_dataset_abbrs.append((dataset_abbr, metric))
                 else:
                     summarizer_dataset_abbrs.append((dataset_abbr, None))
-            # along with all possible group metrics
             for dataset_abbr in dataset_metrics:
                 for metric in dataset_metrics[dataset_abbr]:
                     if (dataset_abbr, metric) not in summarizer_dataset_abbrs:
                         summarizer_dataset_abbrs.append((dataset_abbr, metric))
         else:
-            # follow the required order
             for item in self.dataset_abbrs:
                 if isinstance(item, str):
                     summarizer_dataset_abbrs.append((item, None))
@@ -332,6 +434,7 @@ def _format_table(self, parsed_results, dataset_metrics, dataset_eval_mode):

         for dataset_abbr, metric in summarizer_dataset_abbrs:
             if dataset_abbr not in dataset_metrics:
+
                 table.append([dataset_abbr, '-', '-', '-'] + ['-'] * len(self.model_abbrs))
                 table.append(header)
                 continue
@@ -378,33 +481,7 @@ def _format_raw_txt(self, raw_results):
         raw_txts = '\n'.join(raw_txts)
         return raw_txts

-    def _read_and_sort_dataframe(self, file_path):
-        # Read the file without treating the first row as a header
-        df = pd.read_csv(file_path, header=None)
-
-        # Function to sort columns based on the value of a specific row, excluding the first column
-        def sort_columns_based_on_row_corrected(df, base_row_idx, start_row_idx, end_row_idx):
-            # Extract the rows for sorting
-            sort_values_row = df.iloc[base_row_idx, 1:].replace('-', np.nan).apply(pd.to_numeric, errors='coerce')
-            # Handle NaNs by setting them to a value less than the minimum or using a method to keep them at the end
-            min_possible_value = sort_values_row.min(skipna=True) - 1  # Use min value in the row minus 1 or another method
-            sort_values_row_filled = sort_values_row.fillna(min_possible_value)
-            # Get the sorted order of indices, excluding the first column
-            sorted_col_indices = sort_values_row_filled.sort_values(ascending=False).index
-            # Apply the sorted column indices to the whole DataFrame, adjusting for Python's 0-based index
-            df.iloc[start_row_idx:end_row_idx + 1] = df.iloc[start_row_idx:end_row_idx + 1, [0] + sorted_col_indices.tolist()]
-
-        # Apply the corrected sorting function based on the description
-        sort_columns_based_on_row_corrected(df, 1, 0, 2)  # For rows 1-2 based on row 2's values
-        sort_columns_based_on_row_corrected(df, 4, 3, 7)  # For rows 4-7 based on row 5's values
-        sort_columns_based_on_row_corrected(df, 9, 8, 12)  # For rows 9-12 based on row 10's values
-        sort_columns_based_on_row_corrected(df, 14, 13, 25)  # For rows 14-25 based on row 15's values
-
-        # Return the sorted DataFrame
-        return df
-
     def _output_to_file(self, output_path, time_str, table, raw_txts):
-        # output to file
         if output_path is None:
             output_path = osp.join(self.work_dir, 'summary', f'summary_{time_str}.txt')
             output_csv_path = osp.join(self.work_dir, 'summary', f'summary_{time_str}.csv')
@@ -436,38 +513,19 @@ def _output_to_file(self, output_path, time_str, table, raw_txts):
             f.write('\n'.join([','.join(row) for row in table]) + '\n')
         self.logger.info(f'write csv to {osp.abspath(output_csv_path)}')

-        df_sorted = self._read_and_sort_dataframe(output_csv_path)
-
-        sorted_file_path = osp.abspath(output_csv_path).split('.')[0] + '_sorted.csv'
-        df_sorted.to_csv(sorted_file_path, index=False, header=False)
-
-        self.logger.info(f'write sorted csv to {sorted_file_path}')
-

     def summarize(
             self,
             output_path: str = None,
             time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')):  # noqa

-        # pick up results
         raw_results, parsed_results, dataset_metrics, dataset_eval_mode = self._pick_up_results()
-
-        # calculate group metrics
         raw_results, parsed_results, dataset_metrics, dataset_eval_mode = \
             self._calculate_group_metrics(raw_results, parsed_results, dataset_metrics, dataset_eval_mode)
-
-        # format table
         table = self._format_table(parsed_results, dataset_metrics, dataset_eval_mode)
-
-        # format raw txt
         raw_txts = self._format_raw_txt(raw_results)
-
-        # output to screen
         print(tabulate.tabulate(table, headers='firstrow'))
-
-        # output to .text / .csv files
         self._output_to_file(output_path, time_str, table, raw_txts)
-
         if self.lark_reporter:
             content = f'{getpass.getuser()} 的'
             content += f'详细评测汇总已输出至 {osp.abspath(output_path)}'