Skip to content

Commit 0ba142c

Browse files
committed
Resolved Type hints for primitives and string arguments Improve create_multi_comparison_matrix parameters and saving #2650
1 parent adfcc59 commit 0ba142c

File tree

2 files changed

+82
-72
lines changed

2 files changed

+82
-72
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"name":"Local: _mcm","url":"c:\\Users\\Pravas\\Documents\\Proj\\aeon\\aeon\\visualisation\\results\\_mcm.py","tests":[{"id":1742476901320,"input":"","output":""}],"interactive":false,"memoryLimit":1024,"timeLimit":3000,"srcPath":"c:\\Users\\Pravas\\Documents\\Proj\\aeon\\aeon\\visualisation\\results\\_mcm.py","group":"local","local":true}

aeon/visualisation/results/_mcm.py

+81-72
Original file line numberDiff line numberDiff line change
@@ -6,51 +6,52 @@
66

77
import json
88
import os
9-
9+
from typing import Dict, List, Optional, Union
10+
import logging
11+
import matplotlib.pyplot as plt
1012
import numpy as np
1113
import pandas as pd
1214
from scipy.stats import wilcoxon
1315

1416
from aeon.utils.validation._dependencies import _check_soft_dependencies
1517

18+
# Set up logging
19+
logging.basicConfig(level=logging.INFO)
20+
logger = logging.getLogger(__name__)
1621

1722
def create_multi_comparison_matrix(
18-
df_results,
19-
output_dir="./",
20-
pdf_savename=None,
21-
png_savename=None,
22-
csv_savename=None,
23-
tex_savename=None,
24-
used_statistic="Accuracy",
25-
save_as_json=False,
26-
plot_1v1_comparisons=False,
27-
order_win_tie_loss="higher",
28-
include_pvalue=True,
29-
pvalue_test="wilcoxon",
30-
pvalue_test_params=None,
31-
pvalue_correction=None,
32-
pvalue_threshold=0.05,
33-
use_mean="mean-difference",
34-
order_stats="average-statistic",
35-
order_better="decreasing",
36-
dataset_column=None,
37-
precision=4,
38-
load_analysis=False,
39-
row_comparates=None,
40-
col_comparates=None,
41-
excluded_row_comparates=None,
42-
excluded_col_comparates=None,
43-
colormap="coolwarm",
44-
fig_size="auto",
45-
font_size="auto",
46-
colorbar_orientation="vertical",
47-
colorbar_value=None,
48-
win_label="r>c",
49-
tie_label="r=c",
50-
loss_label="r<c",
51-
include_legend=True,
52-
show_symetry=True,
53-
):
23+
df_results: Union[str, pd.DataFrame],
24+
output_dir: str = "./",
25+
save_files: Optional[Dict[str, str]] = None,
26+
used_statistic: str = "Accuracy",
27+
save_as_json: bool = True,
28+
plot_1v1_comparisons: bool = False,
29+
order_settings: Dict[str, str] = {"win_tie_loss": "higher", "better": "decreasing"},
30+
include_pvalue: bool = True,
31+
pvalue_test: str = "wilcoxon",
32+
pvalue_test_params: Optional[Dict[str, str]] = None,
33+
pvalue_correction: Optional[str] = None,
34+
pvalue_threshold: float = 0.05,
35+
use_mean: str = "mean-difference",
36+
order_stats: str = "average-statistic",
37+
dataset_column: Optional[str] = None,
38+
precision: int = 4,
39+
load_analysis: bool = False,
40+
row_comparates: Optional[List[str]] = None,
41+
col_comparates: Optional[List[str]] = None,
42+
excluded_row_comparates: Optional[List[str]] = None,
43+
excluded_col_comparates: Optional[List[str]] = None,
44+
colormap: str = "coolwarm",
45+
fig_size: Union[str, tuple] = "auto",
46+
font_size: Union[str, int] = "auto",
47+
colorbar_orientation: str = "vertical",
48+
colorbar_value: Optional[str] = None,
49+
win_label: str = "r>c",
50+
tie_label: str = "r=c",
51+
loss_label: str = "r<c",
52+
include_legend: bool = True,
53+
show_symmetry: bool = True,
54+
) -> plt.Figure:
5455
"""Generate the Multi-Comparison Matrix (MCM) [1]_.
5556
5657
MCM summarises a set of results for multiple estimators evaluated on multiple
@@ -64,27 +65,15 @@ def create_multi_comparison_matrix(
6465
row should contain the names of the estimators and the first column can
6566
contain the names of the problems if `dataset_column` is provided.
6667
output_dir: str, default = './'
67-
The output directory for the results.
68-
pdf_savename: str, default = None
69-
The name of the saved file into pdf format. if None, it will not be saved into
70-
this format.
71-
png_savename: str, default = None
72-
The name of the saved file into png format, if None, it will not be saved
73-
into this format.
74-
csv_savename: str, default = None
75-
The name of the saved file into csv format, if None, will not be saved into
76-
this format.
77-
tex_savename: str, default = None
78-
The name of the saved file into tex format, if None, will not be saved into
79-
this format.
68+
The output directory for the results.
69+
save_files: dict, default = None
70+
Dictionary to handle all save options (e.g., {"pdf": "output.pdf", "png": "output.png"}).
8071
used_statistic: str, default = 'Accuracy'
8172
Name of the metric being assessed (e.g. accuracy, error, mse).
8273
save_as_json: bool, default = True
8374
Whether or not to save the python analysis dict into a json file format.
8475
plot_1v1_comparisons: bool, default = False
8576
Whether or not to plot the 1v1 scatter results.
86-
order_win_tie_loss: str, default = 'higher'
87-
The order on considering a win or a loss for a given statistics.
8877
include_pvalue: bool, default = True
8978
Whether or not to include p-value statistics.
9079
pvalue_test: str, default = 'wilcoxon'
@@ -117,8 +106,8 @@ def create_multi_comparison_matrix(
117106
amean-amean average over difference of use_mean
118107
pvalue average pvalue over all comparates
119108
================================================================
120-
order_better: str, default = 'decreasing'
121-
By which order to sort stats, from best to worse.
109+
order_settings: dict, default = {"win_tie_loss": "higher", "better": "decreasing"}
110+
Settings for ordering the results.
122111
dataset_column: str, default = 'dataset_name'
123112
The name of the datasets column in the csv file.
124113
precision: int, default = 4
@@ -178,45 +167,46 @@ def create_multi_comparison_matrix(
178167
Evaluations That Is Stable Under Manipulation Of The Comparate Set
179168
arXiv preprint arXiv:2305.11921, 2023.
180169
"""
170+
171+
logger.info("Starting MCM creation...")
172+
181173
if isinstance(df_results, str):
182174
try:
183175
df_results = pd.read_csv(df_results)
176+
except FileNotFoundError:
177+
raise FileNotFoundError(f"The file {df_results} was not found.")
178+
except pd.errors.EmptyDataError:
179+
raise ValueError(f"The file {df_results} is empty.")
184180
except Exception as e:
185-
raise ValueError(f"No dataframe or valid path is given: Exception {e}")
186-
181+
raise ValueError(f"An error occurred while reading the file: {e}")
182+
187183
analysis = _get_analysis(
188184
df_results,
189185
output_dir=output_dir,
190186
used_statistic=used_statistic,
191187
save_as_json=save_as_json,
192188
plot_1v1_comparisons=plot_1v1_comparisons,
193-
order_win_tie_loss=order_win_tie_loss,
189+
order_win_tie_loss=order_settings["win_tie_loss"],
194190
include_pvalue=include_pvalue,
195191
pvalue_test=pvalue_test,
196192
pvalue_test_params=pvalue_test_params,
197193
pvalue_correction=pvalue_correction,
198-
pvalue_threshhold=pvalue_threshold,
194+
pvalue_threshold=pvalue_threshold,
199195
use_mean=use_mean,
200196
order_stats=order_stats,
201-
order_better=order_better,
197+
order_better=order_settings["better"],
202198
dataset_column=dataset_column,
203199
precision=precision,
204200
load_analysis=load_analysis,
205-
)
206-
207-
# start drawing heatmap
208-
temp = _draw(
209-
analysis,
210-
pdf_savename=pdf_savename,
211-
png_savename=png_savename,
212-
tex_savename=tex_savename,
213-
csv_savename=csv_savename,
214-
output_dir=output_dir,
215201
row_comparates=row_comparates,
216202
col_comparates=col_comparates,
217203
excluded_row_comparates=excluded_row_comparates,
218204
excluded_col_comparates=excluded_col_comparates,
219-
precision=precision,
205+
)
206+
207+
# Generate figure
208+
fig = _draw(
209+
analysis,
220210
colormap=colormap,
221211
fig_size=fig_size,
222212
font_size=font_size,
@@ -226,9 +216,28 @@ def create_multi_comparison_matrix(
226216
tie_label=tie_label,
227217
loss_label=loss_label,
228218
include_legend=include_legend,
229-
show_symetry=show_symetry,
219+
show_symmetry=show_symmetry,
230220
)
231-
return temp
221+
222+
223+
# start drawing heatmap
224+
if save_files:
225+
for file_type, file_name in save_files.items():
226+
if file_name:
227+
save_path = os.path.join(output_dir, file_name)
228+
if file_type == "pdf":
229+
fig.savefig(save_path, format="pdf", bbox_inches="tight")
230+
elif file_type == "png":
231+
fig.savefig(save_path, format="png", dpi=300)
232+
elif file_type == "csv":
233+
df_results.to_csv(save_path, index=False)
234+
elif file_type == "tex":
235+
with open(save_path, "w") as f:
236+
f.write(analysis.to_latex())
237+
238+
logger.info("MCM creation completed.")
239+
return fig
240+
232241

233242

234243
def _get_analysis(

0 commit comments

Comments
 (0)