import json
import os
-
+from typing import Dict, List, Optional, Union
+import logging
+import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import wilcoxon

from aeon.utils.validation._dependencies import _check_soft_dependencies

+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)


def create_multi_comparison_matrix(
-    df_results,
-    output_dir="./",
-    pdf_savename=None,
-    png_savename=None,
-    csv_savename=None,
-    tex_savename=None,
-    used_statistic="Accuracy",
-    save_as_json=False,
-    plot_1v1_comparisons=False,
-    order_win_tie_loss="higher",
-    include_pvalue=True,
-    pvalue_test="wilcoxon",
-    pvalue_test_params=None,
-    pvalue_correction=None,
-    pvalue_threshold=0.05,
-    use_mean="mean-difference",
-    order_stats="average-statistic",
-    order_better="decreasing",
-    dataset_column=None,
-    precision=4,
-    load_analysis=False,
-    row_comparates=None,
-    col_comparates=None,
-    excluded_row_comparates=None,
-    excluded_col_comparates=None,
-    colormap="coolwarm",
-    fig_size="auto",
-    font_size="auto",
-    colorbar_orientation="vertical",
-    colorbar_value=None,
-    win_label="r>c",
-    tie_label="r=c",
-    loss_label="r<c",
-    include_legend=True,
-    show_symetry=True,
-):
+    df_results: Union[str, pd.DataFrame],
+    output_dir: str = "./",
+    save_files: Optional[Dict[str, str]] = None,
+    used_statistic: str = "Accuracy",
+    save_as_json: bool = True,
+    plot_1v1_comparisons: bool = False,
+    order_settings: Optional[Dict[str, str]] = None,
+    include_pvalue: bool = True,
+    pvalue_test: str = "wilcoxon",
+    pvalue_test_params: Optional[Dict[str, str]] = None,
+    pvalue_correction: Optional[str] = None,
+    pvalue_threshold: float = 0.05,
+    use_mean: str = "mean-difference",
+    order_stats: str = "average-statistic",
+    dataset_column: Optional[str] = None,
+    precision: int = 4,
+    load_analysis: bool = False,
+    row_comparates: Optional[List[str]] = None,
+    col_comparates: Optional[List[str]] = None,
+    excluded_row_comparates: Optional[List[str]] = None,
+    excluded_col_comparates: Optional[List[str]] = None,
+    colormap: str = "coolwarm",
+    fig_size: Union[str, tuple] = "auto",
+    font_size: Union[str, int] = "auto",
+    colorbar_orientation: str = "vertical",
+    colorbar_value: Optional[str] = None,
+    win_label: str = "r>c",
+    tie_label: str = "r=c",
+    loss_label: str = "r<c",
+    include_legend: bool = True,
+    show_symmetry: bool = True,
+) -> plt.Figure:
    """Generate the Multi-Comparison Matrix (MCM) [1]_.

    MCM summarises a set of results for multiple estimators evaluated on multiple
@@ -64,27 +65,15 @@ def create_multi_comparison_matrix(
        row should contain the names of the estimators and the first column can
        contain the names of the problems if `dataset_column` is specified.
    output_dir: str, default = './'
-        The output directory for the results.
-    pdf_savename: str, default = None
-        The name of the saved file into pdf format. if None, it will not be saved into
-        this format.
-    png_savename: str, default = None
-        The name of the saved file into png format, if None, it will not be saved
-        into this format.
-    csv_savename: str, default = None
-        The name of the saved file into csv format, if None, will not be saved into
-        this format.
-    tex_savename: str, default = None
-        The name of the saved file into tex format, if None, will not be saved into
-        this format.
+        The output directory for the results.
+    save_files: dict, default = None
+        Dictionary mapping output formats to file names, e.g.
+        {"pdf": "output.pdf", "png": "output.png"}. Supported keys are
+        "pdf", "png", "csv" and "tex"; if None, nothing is saved to disk.
    used_statistic: str, default = 'Accuracy'
        Name of the metric being assessed (e.g. accuracy, error, mse).
    save_as_json: bool, default = True
        Whether or not to save the python analysis dict into a json file format.
    plot_1v1_comparisons: bool, default = False
        Whether or not to plot the 1v1 scatter results.
-    order_win_tie_loss: str, default = 'higher'
-        The order on considering a win or a loss for a given statistics.
    include_pvalue: bool, default = True
        Whether or not to include a p-value statistic.
    pvalue_test: str, default = 'wilcoxon'
@@ -117,8 +106,8 @@ def create_multi_comparison_matrix(
            amean-amean     average over difference of use_mean
            pvalue          average pvalue over all comparates
        ================================================================
-    order_better: str, default = 'decreasing'
-        By which order to sort stats, from best to worse.
+    order_settings: dict, default = None
+        Settings for ordering the results. "win_tie_loss" sets whether a
+        higher or lower statistic counts as a win, and "better" sets the
+        direction in which the ordering statistic is sorted. If None,
+        defaults to {"win_tie_loss": "higher", "better": "decreasing"}.
    dataset_column: str, default = None
        The name of the datasets column in the csv file.
    precision: int, default = 4
@@ -178,45 +167,46 @@ def create_multi_comparison_matrix(
        Evaluations That Is Stable Under Manipulation Of The Comparate Set
        arXiv preprint arXiv:2305.11921, 2023.
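+
+    Examples
+    --------
+    A minimal sketch: the two-estimator results table and the output file
+    name below are illustrative placeholders.
+
+    >>> import pandas as pd
+    >>> scores = pd.DataFrame(
+    ...     {"clf1": [0.80, 0.90, 0.60], "clf2": [0.75, 0.85, 0.80]}
+    ... )
+    >>> fig = create_multi_comparison_matrix(
+    ...     scores, save_files={"png": "mcm.png"}
+    ... )  # doctest: +SKIP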
    """
+
+    logger.info("Starting MCM creation...")
+
+    # Avoid a mutable default argument: fall back to the default ordering here.
+    if order_settings is None:
+        order_settings = {"win_tie_loss": "higher", "better": "decreasing"}
+
    if isinstance(df_results, str):
        try:
            df_results = pd.read_csv(df_results)
+        except FileNotFoundError:
+            raise FileNotFoundError(f"The file {df_results} was not found.")
+        except pd.errors.EmptyDataError:
+            raise ValueError(f"The file {df_results} is empty.")
        except Exception as e:
-            raise ValueError(f"No dataframe or valid path is given: Exception {e}")
-
+            raise ValueError(f"An error occurred while reading the file: {e}")
+
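+    # Compute the pairwise win/tie/loss analysis that the matrix visualises.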
    analysis = _get_analysis(
        df_results,
        output_dir=output_dir,
        used_statistic=used_statistic,
        save_as_json=save_as_json,
        plot_1v1_comparisons=plot_1v1_comparisons,
-        order_win_tie_loss=order_win_tie_loss,
+        order_win_tie_loss=order_settings["win_tie_loss"],
        include_pvalue=include_pvalue,
        pvalue_test=pvalue_test,
        pvalue_test_params=pvalue_test_params,
        pvalue_correction=pvalue_correction,
-        pvalue_threshhold=pvalue_threshold,
+        pvalue_threshold=pvalue_threshold,
        use_mean=use_mean,
        order_stats=order_stats,
-        order_better=order_better,
+        order_better=order_settings["better"],
        dataset_column=dataset_column,
        precision=precision,
        load_analysis=load_analysis,
-    )
-
-    # start drawing heatmap
-    temp = _draw(
-        analysis,
-        pdf_savename=pdf_savename,
-        png_savename=png_savename,
-        tex_savename=tex_savename,
-        csv_savename=csv_savename,
-        output_dir=output_dir,
        row_comparates=row_comparates,
        col_comparates=col_comparates,
        excluded_row_comparates=excluded_row_comparates,
        excluded_col_comparates=excluded_col_comparates,
-        precision=precision,
+    )
+
+    # Generate figure
+    fig = _draw(
+        analysis,
        colormap=colormap,
        fig_size=fig_size,
        font_size=font_size,
@@ -226,9 +216,28 @@ def create_multi_comparison_matrix(
        tie_label=tie_label,
        loss_label=loss_label,
        include_legend=include_legend,
-        show_symetry=show_symetry,
+        show_symmetry=show_symmetry,
    )
-    return temp
+
+    # Save the figure and results in each requested format
+    if save_files:
+        for file_type, file_name in save_files.items():
+            if file_name:
+                save_path = os.path.join(output_dir, file_name)
+                if file_type == "pdf":
+                    fig.savefig(save_path, format="pdf", bbox_inches="tight")
+                elif file_type == "png":
+                    fig.savefig(save_path, format="png", dpi=300)
+                elif file_type == "csv":
+                    # Note: this saves the input results table, not the matrix.
+                    df_results.to_csv(save_path, index=False)
+                elif file_type == "tex":
+                    # A plain dict has no ``to_latex``; convert through pandas
+                    # first (adjust if _get_analysis returns another structure).
+                    df_analysis = pd.DataFrame.from_dict(analysis, orient="index")
+                    with open(save_path, "w") as f:
+                        f.write(df_analysis.to_latex())
+
+    logger.info("MCM creation completed.")
+    return fig
+


def _get_analysis(