
Commit 16f29b2

Authored Apr 7, 2024
[Fix] Simplify needlebench summarizer (#1024)
* Conflicts: configs/summarizers/needlebench.py
* Fix lint problems
1 parent f2af493 commit 16f29b2

File tree

4 files changed: +505 -865 lines
 
@@ -0,0 +1,43 @@
+from mmengine.config import read_base
+with read_base():
+    from .atc_choice_20 import *
+
+needle_num_list = list(range(2, 50, 1))
+needlebench_datasets = []
+
+for _name in list(single_choice_prompts.keys()):
+
+    needlebench_atc_infer_cfg = dict(
+        prompt_template=dict(
+            type=PromptTemplate,
+            template=dict(
+                round=(single_choice_prompts[_name])),
+        ),
+        retriever=dict(type=ZeroRetriever),
+        inferencer=dict(type=GenInferencer,),
+    )
+
+    needlebench_atc_eval_cfg = dict(
+        evaluator=dict(type=CircularEvaluator),
+        pred_postprocessor=dict(type=first_option_postprocess, options='ABCD'))
+
+    for num_needles in needle_num_list:
+        abbr = (f'NeedleBenchATCDataset-'
+                f'{num_needles}Needle-{"EN" if "en" in _name else "ZH"}')
+        language = "English" if "en" in _name else "Chinese"
+        if 'reasoning' in _name:
+            abbr += '-Reasoning'
+        dataset_dict = {
+            'abbr': abbr,
+            'type': NeedleBenchATCDataset,
+            'path': names_path,
+            'num_needles': num_needles,
+            'language': language,
+            'repeats': repeats,
+            'with_circular': with_circular_eval,
+            'reader_cfg': needlebench_atc_reader_cfg,
+            'infer_cfg': needlebench_atc_infer_cfg,
+            'eval_cfg': needlebench_atc_eval_cfg
+        }
+        needlebench_datasets.append(dataset_dict)
+
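As a quick check of the naming scheme above, the abbr construction can be replayed in isolation. A minimal sketch, assuming stand-in prompt keys (the real single_choice_prompts is imported from atc_choice_20.py, so the exact keys here are hypothetical):

    # Hypothetical keys standing in for those imported from atc_choice_20.py.
    single_choice_prompts = {
        'single_choice_en_reasoning': [],
        'single_choice_zh': [],
    }

    for _name in single_choice_prompts:
        for num_needles in (2, 3):  # the real config sweeps range(2, 50, 1)
            abbr = (f'NeedleBenchATCDataset-'
                    f'{num_needles}Needle-{"EN" if "en" in _name else "ZH"}')
            if 'reasoning' in _name:
                abbr += '-Reasoning'
            print(abbr)
    # -> NeedleBenchATCDataset-2Needle-EN-Reasoning
    #    NeedleBenchATCDataset-3Needle-EN-Reasoning
    #    NeedleBenchATCDataset-2Needle-ZH
    #    NeedleBenchATCDataset-3Needle-ZH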
@@ -0,0 +1,43 @@
+from mmengine.config import read_base
+with read_base():
+    from .atc_choice_20 import *
+
+needle_num_list = list(range(2, 80, 1))
+needlebench_datasets = []
+
+for _name in list(single_choice_prompts.keys()):
+
+    needlebench_atc_infer_cfg = dict(
+        prompt_template=dict(
+            type=PromptTemplate,
+            template=dict(
+                round=(single_choice_prompts[_name])),
+        ),
+        retriever=dict(type=ZeroRetriever),
+        inferencer=dict(type=GenInferencer,),
+    )
+
+    needlebench_atc_eval_cfg = dict(
+        evaluator=dict(type=CircularEvaluator),
+        pred_postprocessor=dict(type=first_option_postprocess, options='ABCD'))
+
+    for num_needles in needle_num_list:
+        abbr = (f'NeedleBenchATCDataset-'
+                f'{num_needles}Needle-{"EN" if "en" in _name else "ZH"}')
+        language = "English" if "en" in _name else "Chinese"
+        if 'reasoning' in _name:
+            abbr += '-Reasoning'
+        dataset_dict = {
+            'abbr': abbr,
+            'type': NeedleBenchATCDataset,
+            'path': names_path,
+            'num_needles': num_needles,
+            'language': language,
+            'repeats': repeats,
+            'with_circular': with_circular_eval,
+            'reader_cfg': needlebench_atc_reader_cfg,
+            'infer_cfg': needlebench_atc_infer_cfg,
+            'eval_cfg': needlebench_atc_eval_cfg
+        }
+        needlebench_datasets.append(dataset_dict)
+

configs/summarizers/needlebench.py
+188 -692
Large diffs are not rendered by default.

opencompass/summarizers/needlebench.py
+231 -173
@@ -5,6 +5,7 @@
 import math
 import os
 import os.path as osp
+import shutil
 from datetime import datetime
 from typing import Any, Dict, List, Optional

@@ -26,6 +27,92 @@
                               model_abbr_from_cfg)
 from opencompass.utils.prompt import get_prompt_hash

+model_name_mapping = {
+    'llama-2-7b-chat-hf': 'LLaMA-2-7B',
+    'llama-2-13b-chat-hf': 'LLaMA-2-13B',
+    'llama-2-70b-chat-hf': 'LLaMA-2-70B',
+    'baichuan2-7b-chat-hf': 'Baichuan2-7B',
+    'baichuan2-13b-chat-hf': 'Baichuan2-13B',
+    'yi-6b-chat-hf': 'Yi-6B',
+    'yi-34b-chat-hf': 'Yi-34B',
+    'deepseek-67b-chat-hf': 'DeepSeek-67B',
+    'wizardlm-70b-v1.0-vllm': 'WizardLM-70B',
+    'qwen-14b-chat-hf': 'Qwen-14B',
+    'qwen-72b-chat-hf': 'Qwen-72B',
+    'qwen-72b-chat-vllm': 'Qwen-72B-vLLM',
+    'internlm2-chat-7b-turbomind': 'InternLM2-7B-200K',
+    'internlm2-chat-20b-turbomind': 'InternLM2-20B-200K',
+    'internlm2-chat-7b-hf': 'InternLM2-7B',
+    'internlm2-chat-20b-hf': 'InternLM2-20B',
+    'qwen-7b-chat-hf': 'Qwen-7B',
+    'chatglm3-6b-hf': 'ChatGLM3-6B',
+    'chatglm3-6b-32k-hf': 'ChatGLM3-6B-32K',
+    'zephyr-7b-beta-vllm': 'Zephyr-7B Beta',
+    'mistral-7b-instruct-v0.2-vllm': 'Mistral-7B Inst. v0.2',
+    'mistral-7b-instruct-v0.1-vllm': 'Mistral-7B Inst. v0.1',
+    'mixtral-8x7b-instruct-v0.1-vllm': 'Mixtral-8x7B Inst. v0.1',
+    'orionstar-yi-34b-chat-hf': 'OrionStar-Yi-34B',
+    'orionstar-14b-long-chat-vllm': 'Orion-14B-LongChat',
+    'internlm-chat-7b-hf': 'InternLM-7B',
+    'gemma-2b-it-hf': 'Gemma-2B',
+    'gemma-7b-it-hf': 'Gemma-7B',
+    'qwen1.5-0.5b-chat-hf': 'Qwen-1.5-0.5B',
+    'qwen1.5-1.8b-chat-hf': 'Qwen-1.5-1.8B',
+    'qwen1.5-4b-chat-hf': 'Qwen-1.5-4B',
+    'qwen1.5-14b-chat-hf': 'Qwen-1.5-14B',
+    'qwen1.5-72b-chat-hf': 'Qwen-1.5-72B',
+    'qwen1.5-14b-chat-vllm': 'Qwen-1.5-14B-vLLM',
+    'qwen1.5-72b-chat-vllm': 'Qwen-1.5-72B-vLLM',
+    'glm4_notools': 'GLM-4',
+    'claude-3-opus': 'Claude-3-Opus',
+    # Add more mappings as necessary
+}
+
+dataset_mapping_dict = {}
+
+needle_counts = ['2', '3', '4', '5']
+languages = ['en', 'zh']
+sizes = ['4k', '8k', '32k', '200k', '1000k']
+types = ['origin', 'parallel']
+
+for needle_count in needle_counts:
+    for language in languages:
+        for size in sizes:
+            key = f'{needle_count}needle_{language}_{size}'
+            value = f'{needle_count}-Needle-Reasoning-{language.upper()}-{size.upper()}'
+            dataset_mapping_dict[key] = value
+for t in types:
+    for language in languages:
+        for size in sizes:
+            if t == 'origin':
+                key = f'{t}_{language}_{size}'
+                value = f'Single-Needle-Retrieval-{language.upper()}-{size.upper()}'
+            elif t == 'parallel':
+                key = f'{t}_{language}_{size}'
+                value = f'Multi-Needle-Retrieval-{language.upper()}-{size.upper()}'
+            dataset_mapping_dict[key] = value
+
+
+def calculate_elementwise_average(model_name, merged_df):
+    score_columns = [col for col in merged_df.columns if col != 'dataset']
+
+    origin_columns = [col for col in score_columns if 'origin' in col]
+    parallel_columns = [col for col in score_columns if 'parallel' in col]
+    multi_columns = [col for col in score_columns if 'needle' in col]
+
+    if origin_columns and parallel_columns and multi_columns:
+        origin_avg = merged_df[origin_columns].mean(axis=1) * 0.4
+        parallel_avg = merged_df[parallel_columns].mean(axis=1) * 0.3
+        multi_avg = merged_df[multi_columns].mean(axis=1) * 0.3
+        merged_df[model_name] = origin_avg + parallel_avg + multi_avg
+    else:
+        relevant_columns = origin_columns or parallel_columns or multi_columns
+        if relevant_columns:
+            merged_df[model_name] = merged_df[relevant_columns].mean(axis=1)
+        else:
+            merged_df[model_name] = pd.Series([0] * len(merged_df))
+
+    return merged_df.iloc[:, [0, -1]]

 def read_after_specific_line_except_last(file_name, keyword, offset):
     with open(file_name, 'r', encoding='utf-8') as file:
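The new calculate_elementwise_average weights the three dataset groups 0.4 (single-needle retrieval, 'origin'), 0.3 (multi-needle retrieval, 'parallel') and 0.3 (multi-needle reasoning, 'needle') when all three are present. A self-contained sketch of that arithmetic with made-up scores and hypothetical column names (they only need to contain the substrings the function groups by):

    import pandas as pd

    df = pd.DataFrame({
        'dataset': ['depth_0', 'depth_50'],   # hypothetical row labels
        'origin_en_4k': [90.0, 80.0],
        'parallel_en_4k': [70.0, 60.0],
        '2needle_en_4k': [50.0, 40.0],
    })

    # Same 0.4 / 0.3 / 0.3 weighting as calculate_elementwise_average above.
    df['my-model'] = (df[['origin_en_4k']].mean(axis=1) * 0.4
                      + df[['parallel_en_4k']].mean(axis=1) * 0.3
                      + df[['2needle_en_4k']].mean(axis=1) * 0.3)
    print(df[['dataset', 'my-model']])  # 72.0 and 62.0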
@@ -65,6 +152,12 @@ def create_model_dataframe(nested_dict, model_name, dataset_abbr, parallel=False
     df = pd.DataFrame(data, columns=['dataset', model_name])
     return df

+def convert_to_k(value):
+    try:
+        return f'{int(value) // 1000}k'
+    except ValueError:
+        return value
+
 def parse_model_scores(text):
     lines = text.split('\n')
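convert_to_k rescales integer-like axis labels to a 'k' suffix and passes anything else through unchanged:

    convert_to_k('32000')   # -> '32k'
    convert_to_k(200000)    # -> '200k'
    convert_to_k('depth')   # -> 'depth' (non-numeric input falls through via ValueError)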

@@ -82,8 +175,86 @@ def parse_model_scores(text):

     return result_dict

+def remove_empty_subfolders(plot_path):
+    for folder_name in tqdm(os.listdir(plot_path),
+                            desc='Deleting Empty folders'):
+        folder_path = os.path.join(plot_path, folder_name)
+        if os.path.isdir(folder_path):
+            if not os.listdir(folder_path):
+                shutil.rmtree(folder_path)
+
+def save_results_to_plots(txt_results_save_path):
+    content = read_after_specific_line_except_last(txt_results_save_path, 'raw format', 2)
+    parsed_data = parse_model_scores(content)
+    model_names = get_dict_model_names(parsed_data)
+    numbers = [2, 3, 4, 5]
+    languages = ['en', 'zh']
+    size_exists = []
+    sizes_origin = ['_4k', '_8k', '_32k', '_128k', '_200k', '_1000k']
+
+    for size in sizes_origin:
+        if size in content:
+            size_exists.append(size)
+
+    multi_dataset_abbrs = [f'{num}needle_{lang}{size}' for num in numbers for lang in languages for size in size_exists]
+    origin_dataset_abbrs = [f'origin_{lang}{size}' for lang in languages for size in size_exists]
+    parallel_dataset_abbrs = [f'parallel_{lang}{size}' for lang in languages for size in size_exists]
+
+    dataset_abbrs = multi_dataset_abbrs + origin_dataset_abbrs + \
+        parallel_dataset_abbrs
+    base_path = os.path.dirname(txt_results_save_path)
+    plot_path = os.path.join(base_path, 'plots')
+
+    model_scores = {}
+
+    for model_name in tqdm(model_names):
+        model_datasets_scores = {}  # Dictionary to store scores for each dataset for the current model
+        for dataset_abbr in dataset_abbrs:
+            parallel_flag = 'parallel' in dataset_abbr
+
+            folder_path = os.path.join(plot_path, dataset_mapping_dict[dataset_abbr])
+            ensure_directory(folder_path)
+
+            save_path = os.path.join(folder_path, f'{model_name}.png')
+
+            df = create_model_dataframe(parsed_data, model_name, dataset_abbr, parallel=parallel_flag)
+
+            score = visualize(df, save_path, model_name, dataset_abbr)
+
+            model_datasets_scores[dataset_abbr] = '{:.02f}'.format(score)
+
+        overall_dataset_abbrs = multi_dataset_abbrs + origin_dataset_abbrs + parallel_dataset_abbrs
+        overall_score_pic_path = os.path.join(plot_path, f'{model_name}_overall.png')
+        merged_df = merge_dataframes(model_name, overall_dataset_abbrs, parsed_data)
+        averaged_df = calculate_elementwise_average(model_name, merged_df)
+        overall_score = visualize(averaged_df, overall_score_pic_path, model_name, 'Overall Score')
+
+        # Single-Retrieval
+        single_retrieval_score_pic_path = os.path.join(plot_path, f'{model_name}_single_retrieval_overall.png')
+        single_retrieval_merged_df = merge_dataframes(model_name, origin_dataset_abbrs, parsed_data)
+        single_retrieval_averaged_df = calculate_elementwise_average(model_name, single_retrieval_merged_df)
+        single_retrieval_overall_score = visualize(single_retrieval_averaged_df, single_retrieval_score_pic_path, model_name, 'Single-Retrieval Overall Score')
+
+        # Multi-Retrieval
+        multi_retrieval_score_pic_path = os.path.join(plot_path, f'{model_name}_multi_retrieval_overall.png')
+        multi_retrieval_merged_df = merge_dataframes(model_name, parallel_dataset_abbrs, parsed_data)
+        multi_retrieval_averaged_df = calculate_elementwise_average(model_name, multi_retrieval_merged_df)
+        multi_retrieval_overall_score = visualize(multi_retrieval_averaged_df, multi_retrieval_score_pic_path, model_name, 'Multi-Retrieval Overall Score')
+
+        # Multi-Reasoning
+        multi_reasoning_score_pic_path = os.path.join(plot_path, f'{model_name}_multi_reasoning_overall.png')
+        multi_reasoning_merged_df = merge_dataframes(model_name, multi_dataset_abbrs, parsed_data)
+        multi_reasoning_averaged_df = calculate_elementwise_average(model_name, multi_reasoning_merged_df)
+        multi_reasoning_overall_score = visualize(multi_reasoning_averaged_df, multi_reasoning_score_pic_path, model_name, 'Multi-Reasoning Overall Score')
+
+        model_scores[model_name] = averaged_df
+    remove_empty_subfolders(plot_path)
+    return model_scores
+
 def visualize(df_raw, save_path: str,model_name: str ,dataset_type:str):
     df = df_raw.copy()
+    if df.empty:
+        return -1
     df['Context Length'] = df['dataset'].apply(
         lambda x: int(x.split('Length')[1].split('Depth')[0]))
     df['Document Depth'] = df['dataset'].apply(
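visualize() recovers the context length (and, in the lambda cut off above, the document depth) by string-splitting the dataset abbreviation. A minimal sketch of that parsing, assuming a hypothetical abbr in the 'Length...Depth...' form the splits imply:

    name = 'Length32000Depth50_origin_en_32k'  # hypothetical dataset abbr
    context_length = int(name.split('Length')[1].split('Depth')[0])  # 32000
    depth_part = name.split('Depth')[1]  # '50_origin_en_32k'; the truncated lambda extracts the depth from this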
@@ -98,144 +269,96 @@ def visualize(df_raw, save_path: str,model_name: str ,dataset_type:str):
     model_df = df[['Document Depth', 'Context Length',
                    model_name]].copy()
     model_df.rename(columns={model_name: 'Score'}, inplace=True)
-
-    # Create pivot table
     pivot_table = pd.pivot_table(model_df,
-                                 values='Score',
-                                 index=['Document Depth'],
-                                 columns=['Context Length'],
-                                 aggfunc='mean')
+                             values='Score',
+                             index=['Document Depth'],
+                             columns=['Context Length'],
+                             aggfunc='mean')

-    # Calculate mean scores
     mean_scores = pivot_table.mean().values
-
-    # Calculate overall score
     overall_score = mean_scores.mean()
-
-    # Create heatmap and line plot
-    plt.figure(figsize=(15.5, 8))
+    plt.figure(figsize=(10, 6))
     ax = plt.gca()
     cmap = LinearSegmentedColormap.from_list(
         'custom_cmap', ['#F0496E', '#EBB839', '#0CD79F'])

-    # Draw heatmap
     sns.heatmap(pivot_table,
                 cmap=cmap,
                 ax=ax,
-                cbar_kws={'label': 'Score'},
                 vmin=0,
                 vmax=100)
-
-    # Set line plot data
+    cbar = ax.collections[0].colorbar
     x_data = [i + 0.5 for i in range(len(mean_scores))]
     y_data = mean_scores

-    # Create twin axis for line plot
     ax2 = ax.twinx()
-    # Draw line plot
     ax2.plot(x_data,
-             y_data,
-             color='white',
-             marker='o',
-             linestyle='-',
-             linewidth=2,
-             markersize=8,
-             label='Average Depth Score')
-    # Set y-axis range
+             y_data,
+             color='white',
+             marker='o',
+             linestyle='-',
+             linewidth=2,
+             markersize=8,
+             label='Average Depth Score'
+             )
     ax2.set_ylim(0, 100)

-    # Hide original y-axis ticks and labels
     ax2.set_yticklabels([])
     ax2.set_yticks([])

-    # Add legend
-    ax2.legend(loc='upper left')
-
-    # Set chart title and labels
-    ax.set_title(f'{model_name} {dataset_type} Context '
-                 'Performance\nFact Retrieval Across '
-                 'Context Lengths ("Needle In A Haystack")')
-    ax.set_xlabel('Token Limit')
-    ax.set_ylabel('Depth Percent')
-    ax.set_xticklabels(pivot_table.columns.values, rotation=45)
-    ax.set_yticklabels(pivot_table.index.values, rotation=0)
-    # Add overall score as a subtitle
-    plt.text(0.5,
-             -0.13, f'Overall Score for {model_name}: '
-             f'{overall_score:.2f}',
-             ha='center',
-             va='center',
-             transform=ax.transAxes,
-             fontsize=13)
-
-    plt.tight_layout()
-    plt.subplots_adjust(right=1)
-    plt.draw()
-    plt.savefig(save_path)
-    print(f'Saved :{save_path}')
-    plt.close()  # Close figure to prevent memory leaks
-    return overall_score
-
-def save_results_to_plots(txt_results_save_path):
-
-    content = read_after_specific_line_except_last(txt_results_save_path, 'raw format', 2)
-
-    parsed_data = parse_model_scores(content)
-    model_names = get_dict_model_names(parsed_data)
-    numbers = [2, 3, 4, 5]
-    languages = ['en', 'zh']
-    size_exists = []
-    sizes_origin = ['_4k', '_8k', '_32k', '_128k', '_200k']
-
-    for size in sizes_origin:
-        if size in content:
-            size_exists.append(size)
-
-    multi_dataset_abbrs = [f'{num}needle_{lang}{size}' for num in numbers for lang in languages for size in size_exists]
-    origin_dataset_abbrs = [f'origin_{lang}{size}' for lang in languages for size in size_exists]
-    parallel_dataset_abbrs = [f'parallel_{lang}{size}' for lang in languages for size in size_exists]
-
-    dataset_abbrs = multi_dataset_abbrs + origin_dataset_abbrs + \
-        parallel_dataset_abbrs
-    base_path = os.path.dirname(txt_results_save_path)
-    plot_path = os.path.join(base_path, 'plots')
-    model_scores = {}
-    for model_name in tqdm(model_names):
-        model_datasets_scores = {}  # Dictionary to store scores for each dataset for the current model
-        for dataset_abbr in dataset_abbrs:
-            parallel_flag = 'parallel' in dataset_abbr
+    ax2.legend(loc='lower left')

-            # Create a directory for each dataset_abbr
-            folder_path = os.path.join(plot_path, dataset_abbr)
-            ensure_directory(folder_path)
+    if model_name in model_name_mapping:
+        title_name = model_name_mapping[model_name]
+    else:
+        title_name = model_name

-            # Construct the full path to save the image
-            save_path = os.path.join(folder_path, f'{model_name}.png')
+    ax.set_title(title_name, fontsize=12, fontweight='bold', pad=15)

-            # Create DataFrame for the model and dataset
-            df = create_model_dataframe(parsed_data, model_name, dataset_abbr, parallel=parallel_flag)
+    if dataset_type in dataset_mapping_dict:
+        dataset_name = dataset_mapping_dict[dataset_type]
+    else:
+        dataset_name = dataset_type
+
+    ax.text(0.5, 1.005, f'{dataset_name}:{overall_score:.2f}',
+            transform=ax.transAxes,
+            ha='center',
+            fontsize=12,
+            fontweight='normal')
+    ax.set_xlabel('Token Length', fontsize=13, fontweight='normal', labelpad=1)
+    ax.set_ylabel('Depth Percent(%)', fontsize=13, fontweight='normal', labelpad=1)
+    converted_labels = [convert_to_k(value) for value in pivot_table.columns.values]
+
+    ax.tick_params(axis='both', which='major', length=1, pad=1)
+    ax.tick_params(axis='both', which='minor', length=1, pad=1)
+    ax.set_xticklabels(converted_labels, rotation=45)
+    index_length = len(pivot_table.index)
+
+    selected_indices = pivot_table.index.values[::2]
+    labels = [str(int(index)) for index in selected_indices]
+    ax.set_yticks(np.arange(0, len(pivot_table.index), 2))
+    ax.set_yticklabels(labels, rotation=0)
+    for spine in ax.spines.values():
+        spine.set_visible(False)
+    for spine in ax2.spines.values():
+        spine.set_visible(False)

-            # Generate visualization and get the score
-            score = visualize(df, save_path, model_name, dataset_abbr)
+    plt.tight_layout()
+    plt.draw()
+    directory_path, original_filename = os.path.split(save_path)

-            # Store the score in the dictionary
-            model_datasets_scores[dataset_abbr] = '{:.02f}'.format(score)
+    filename_suffix = (title_name+'_'+dataset_name).replace(' ', '_')
+    new_filename = f'{filename_suffix}.png'

-        # Process and visualize the overall score
-        overall_score_pic_path = os.path.join(plot_path, f'{model_name}_overall.png')
-        merged_df = merge_dataframes(model_name, dataset_abbrs, parsed_data)
+    new_save_path = os.path.join(directory_path, new_filename)

-        print(merge_dataframes)
-        averaged_df = calculate_elementwise_average(merged_df)
+    plt.savefig(new_save_path, format='png', bbox_inches='tight', pad_inches=0)
+    print(f'Saved :{new_save_path}')

-        # Assume visualize returns the average score for the overall visualization
-        overall_score = visualize(averaged_df, overall_score_pic_path, 'weighted_average_score', 'Overall Score')
+    plt.close()

-        # Add the overall score to the dictionary
-        model_datasets_scores['Overall'] = '{:.02f}'.format(overall_score)
+    return overall_score

-        # Add the model's scores to the main dictionary
-        model_scores[model_name] = model_datasets_scores

 def ensure_directory(path):
     if not os.path.exists(path):
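Note that the rewritten visualize() no longer saves to the caller's save_path verbatim: it rebuilds the filename from the mapped model title and dataset name. A small sketch of that renaming with hypothetical inputs:

    import os

    title_name = 'InternLM2-7B-200K'                 # via model_name_mapping
    dataset_name = 'Single-Needle-Retrieval-EN-32K'  # via dataset_mapping_dict
    save_path = 'outputs/plots/example/model.png'    # hypothetical original target

    directory_path, _ = os.path.split(save_path)
    filename_suffix = (title_name + '_' + dataset_name).replace(' ', '_')
    new_save_path = os.path.join(directory_path, f'{filename_suffix}.png')
    # -> outputs/plots/example/InternLM2-7B-200K_Single-Needle-Retrieval-EN-32K.png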
@@ -263,29 +386,11 @@ def merge_dataframes(model_name, dataset_abbrs, parsed_data):
     merged_df = reduce(lambda left, right: pd.merge(left, right, on='dataset', how='outer'), dfs)

     if merged_df.isnull().any().any():
-        print('Warning: Some rows were filtered out due to NaN values. This is often due to mismatched row counts among DataFrames.')
+        print('Warning: Some rows were filtered out due to NaN values. '
+              'This is often due to mismatched row counts among DataFrames.')
         merged_df = merged_df.dropna()
     return merged_df

-def calculate_elementwise_average(merged_df):
-    score_columns = [col for col in merged_df.columns if col != 'dataset']
-
-    origin_columns = [col for col in score_columns if 'origin' in col]
-    parallel_columns = [col for col in score_columns if 'parallel' in col]
-    multi_columns = [col for col in score_columns if 'needle' in col]
-
-    if origin_columns and parallel_columns and multi_columns:
-        origin_avg = merged_df[origin_columns].mean(axis=1) * 0.4
-        parallel_avg = merged_df[parallel_columns].mean(axis=1) * 0.3
-        multi_avg = merged_df[multi_columns].mean(axis=1) * 0.3
-
-        merged_df['weighted_average_score'] = origin_avg + parallel_avg + multi_avg
-    else:
-        merged_df['weighted_average_score'] = pd.Series([0] * len(merged_df))
-
-    return merged_df.iloc[:, [0, -1]]
-
-

 class NeedleBenchSummarizer(DefaultSummarizer):
     """NeedleBench summarizer in OpenCompass.
@@ -303,20 +408,17 @@ def _format_table(self, parsed_results, dataset_metrics, dataset_eval_mode):

         summarizer_dataset_abbrs = []
         if self.dataset_abbrs is None:
-            # display all dataset metrics included in the config
             for dataset_abbr in dataset_abbrs:
                 if dataset_abbr in dataset_metrics:
                     for metric in dataset_metrics[dataset_abbr]:
                         summarizer_dataset_abbrs.append((dataset_abbr, metric))
                 else:
                     summarizer_dataset_abbrs.append((dataset_abbr, None))
-            # along with all possible group metrics
             for dataset_abbr in dataset_metrics:
                 for metric in dataset_metrics[dataset_abbr]:
                     if (dataset_abbr, metric) not in summarizer_dataset_abbrs:
                         summarizer_dataset_abbrs.append((dataset_abbr, metric))
         else:
-            # follow the required order
             for item in self.dataset_abbrs:
                 if isinstance(item, str):
                     summarizer_dataset_abbrs.append((item, None))
@@ -332,6 +434,7 @@ def _format_table(self, parsed_results, dataset_metrics, dataset_eval_mode):

         for dataset_abbr, metric in summarizer_dataset_abbrs:
             if dataset_abbr not in dataset_metrics:
+
                 table.append([dataset_abbr, '-', '-', '-'] + ['-'] * len(self.model_abbrs))
                 table.append(header)
                 continue
@@ -378,33 +481,7 @@ def _format_raw_txt(self, raw_results):
         raw_txts = '\n'.join(raw_txts)
         return raw_txts

-    def _read_and_sort_dataframe(self, file_path):
-        # Read the file without treating the first row as a header
-        df = pd.read_csv(file_path, header=None)
-
-        # Function to sort columns based on the value of a specific row, excluding the first column
-        def sort_columns_based_on_row_corrected(df, base_row_idx, start_row_idx, end_row_idx):
-            # Extract the rows for sorting
-            sort_values_row = df.iloc[base_row_idx, 1:].replace('-', np.nan).apply(pd.to_numeric, errors='coerce')
-            # Handle NaNs by setting them to a value less than the minimum or using a method to keep them at the end
-            min_possible_value = sort_values_row.min(skipna=True) - 1  # Use min value in the row minus 1 or another method
-            sort_values_row_filled = sort_values_row.fillna(min_possible_value)
-            # Get the sorted order of indices, excluding the first column
-            sorted_col_indices = sort_values_row_filled.sort_values(ascending=False).index
-            # Apply the sorted column indices to the whole DataFrame, adjusting for Python's 0-based index
-            df.iloc[start_row_idx:end_row_idx+1] = df.iloc[start_row_idx:end_row_idx+1, [0] + sorted_col_indices.tolist()]
-
-        # Apply the corrected sorting function based on the description
-        sort_columns_based_on_row_corrected(df, 1, 0, 2)  # For rows 1-2 based on row 2's values
-        sort_columns_based_on_row_corrected(df, 4, 3, 7)  # For rows 4-7 based on row 5's values
-        sort_columns_based_on_row_corrected(df, 9, 8, 12)  # For rows 9-12 based on row 10's values
-        sort_columns_based_on_row_corrected(df, 14, 13, 25)  # For rows 14-25 based on row 15's values
-
-        # Return the sorted DataFrame
-        return df
-
     def _output_to_file(self, output_path, time_str, table, raw_txts):
-        # output to file
         if output_path is None:
             output_path = osp.join(self.work_dir, 'summary', f'summary_{time_str}.txt')
             output_csv_path = osp.join(self.work_dir, 'summary', f'summary_{time_str}.csv')
@@ -436,38 +513,19 @@ def _output_to_file(self, output_path, time_str, table, raw_txts):
             f.write('\n'.join([','.join(row) for row in table]) + '\n')
         self.logger.info(f'write csv to {osp.abspath(output_csv_path)}')

-        df_sorted = self._read_and_sort_dataframe(output_csv_path)
-
-        sorted_file_path = osp.abspath(output_csv_path).split('.')[0] + '_sorted.csv'
-        df_sorted.to_csv(sorted_file_path, index=False, header=False)
-
-        self.logger.info(f'write sorted csv to {sorted_file_path}')
-

     def summarize(
             self,
             output_path: str = None,
             time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')):  # noqa

-        # pick up results
         raw_results, parsed_results, dataset_metrics, dataset_eval_mode = self._pick_up_results()
-
-        # calculate group metrics
         raw_results, parsed_results, dataset_metrics, dataset_eval_mode = \
             self._calculate_group_metrics(raw_results, parsed_results, dataset_metrics, dataset_eval_mode)
-
-        # format table
         table = self._format_table(parsed_results, dataset_metrics, dataset_eval_mode)
-
-        # format raw txt
         raw_txts = self._format_raw_txt(raw_results)
-
-        # output to screen
         print(tabulate.tabulate(table, headers='firstrow'))
-
-        # output to .text / .csv files
         self._output_to_file(output_path, time_str, table, raw_txts)
-
         if self.lark_reporter:
             content = f'{getpass.getuser()} 的'
             content += f'详细评测汇总已输出至 {osp.abspath(output_path)}'
