6
6
7
7
import json
8
8
import os
9
- from typing import Dict , List , Optional , Union
10
- import logging
11
- import matplotlib .pyplot as plt
9
+ from typing import Dict , Optional
12
10
import numpy as np
13
11
import pandas as pd
14
12
from scipy .stats import wilcoxon
15
13
16
14
from aeon .utils .validation ._dependencies import _check_soft_dependencies
17
15
18
- # Set up logging
19
- logging .basicConfig (level = logging .INFO )
20
- logger = logging .getLogger (__name__ )
21
-
22
16
def create_multi_comparison_matrix (
23
- df_results : Union [ str , pd . DataFrame ] ,
24
- output_dir : str = "./" ,
25
- save_files : Optional [Dict [str , str ]] = None ,
26
- used_statistic : str = "Accuracy" ,
27
- save_as_json : bool = True ,
28
- plot_1v1_comparisons : bool = False ,
29
- order_settings : Dict [str , str ] = {"win_tie_loss" : "higher" , "better" : "decreasing" },
30
- include_pvalue : bool = True ,
31
- pvalue_test : str = "wilcoxon" ,
32
- pvalue_test_params : Optional [ Dict [ str , str ]] = None ,
33
- pvalue_correction : Optional [ str ] = None ,
34
- pvalue_threshold : float = 0.05 ,
35
- use_mean : str = "mean-difference" ,
36
- order_stats : str = "average-statistic" ,
37
- dataset_column : Optional [ str ] = None ,
38
- precision : int = 4 ,
39
- load_analysis : bool = False ,
40
- row_comparates : Optional [ List [ str ]] = None ,
41
- col_comparates : Optional [ List [ str ]] = None ,
42
- excluded_row_comparates : Optional [ List [ str ]] = None ,
43
- excluded_col_comparates : Optional [ List [ str ]] = None ,
44
- colormap : str = "coolwarm" ,
45
- fig_size : Union [ str , tuple ] = "auto" ,
46
- font_size : Union [ str , int ] = "auto" ,
47
- colorbar_orientation : str = "vertical" ,
48
- colorbar_value : Optional [ str ] = None ,
49
- win_label : str = "r>c" ,
50
- tie_label : str = "r=c" ,
51
- loss_label : str = "r<c" ,
52
- include_legend : bool = True ,
53
- show_symmetry : bool = True ,
54
- ) -> plt . Figure :
17
+ df_results ,
18
+ output_dir = "./" ,
19
+ save_files : Optional [Dict [str , str ]] = None , # Changed to dictionary
20
+ used_statistic = "Accuracy" ,
21
+ save_as_json = False ,
22
+ plot_1v1_comparisons = False ,
23
+ order_settings : Dict [str , str ] = {"win_tie_loss" : "higher" , "better" : "decreasing" }, # Combined order settings
24
+ include_pvalue = True ,
25
+ pvalue_test = "wilcoxon" ,
26
+ pvalue_test_params = None ,
27
+ pvalue_correction = None ,
28
+ pvalue_threshold = 0.05 ,
29
+ use_mean = "mean-difference" ,
30
+ order_stats = "average-statistic" ,
31
+ dataset_column = None ,
32
+ precision = 4 ,
33
+ load_analysis = False ,
34
+ row_comparates = None ,
35
+ col_comparates = None ,
36
+ excluded_row_comparates = None ,
37
+ excluded_col_comparates = None ,
38
+ colormap = "coolwarm" ,
39
+ fig_size = "auto" ,
40
+ font_size = "auto" ,
41
+ colorbar_orientation = "vertical" ,
42
+ colorbar_value = None ,
43
+ win_label = "r>c" ,
44
+ tie_label = "r=c" ,
45
+ loss_label = "r<c" ,
46
+ include_legend = True ,
47
+ show_symetry = True ,
48
+ ):
55
49
"""Generate the Multi-Comparison Matrix (MCM) [1]_.
56
50
57
51
MCM summarises a set of results for multiple estimators evaluated on multiple
@@ -65,15 +59,17 @@ def create_multi_comparison_matrix(
65
59
row should contain the names of the estimators and the first column can
66
60
contain the names of the problems if `dataset_column` is true.
67
61
output_dir: str, default = './'
68
- The output directory for the results.
62
+ The output directory for the results.
69
63
save_files: dict, default = None
70
- Dictionary to handle all save options (e.g., {"pdf": "output.pdf", "png": "output.png"}).
64
+ Dictionary to handle all save options (e.g., {"pdf": "output.pdf", "png": "output.png"}).
71
65
used_statistic: str, default = 'Accuracy'
72
66
Name of the metric being assessed (e.g. accuracy, error, mse).
73
67
save_as_json: bool, default = False
74
68
Whether or not to save the python analysis dict into a json file format.
75
69
plot_1v1_comparisons: bool, default = False
76
70
Whether or not to plot the 1v1 scatter results.
71
+ order_settings: dict, default = {"win_tie_loss": "higher", "better": "decreasing"}
72
+ Settings for ordering the results.
77
73
include_pvalue bool, default = True
78
74
Condition whether or not include a pvalue stats.
79
75
pvalue_test: str, default = 'wilcoxon'
@@ -106,8 +102,6 @@ def create_multi_comparison_matrix(
106
102
amean-amean average over difference of use_mean
107
103
pvalue average pvalue over all comparates
108
104
================================================================
109
- order_settings: dict, default = {"win_tie_loss": "higher", "better": "decreasing"}
110
- Settings for ordering the results.
111
105
dataset_column: str, default = 'dataset_name'
112
106
The name of the datasets column in the csv file.
113
107
precision: int, default = 4
@@ -155,7 +149,11 @@ def create_multi_comparison_matrix(
155
149
Example
156
150
-------
157
151
>>> from aeon.visualisation import create_multi_comparison_matrix # doctest: +SKIP
158
- >>> create_multi_comparison_matrix(df_results='results.csv') # doctest: +SKIP
152
+ >>> create_multi_comparison_matrix(
153
+ ... df_results='results.csv',
154
+ ... save_files={"pdf": "output.pdf", "png": "output.png"},
155
+ ... order_settings={"win_tie_loss": "higher", "better": "decreasing"}
156
+ ... ) # doctest: +SKIP
159
157
160
158
Notes
161
159
-----
@@ -167,19 +165,12 @@ def create_multi_comparison_matrix(
167
165
Evaluations That Is Stable Under Manipulation Of The Comparate Set
168
166
arXiv preprint arXiv:2305.11921, 2023.
169
167
"""
170
-
171
- logger .info ("Starting MCM creation..." )
172
-
173
168
if isinstance (df_results , str ):
174
169
try :
175
170
df_results = pd .read_csv (df_results )
176
- except FileNotFoundError :
177
- raise FileNotFoundError (f"The file { df_results } was not found." )
178
- except pd .errors .EmptyDataError :
179
- raise ValueError (f"The file { df_results } is empty." )
180
171
except Exception as e :
181
- raise ValueError (f"An error occurred while reading the file: { e } " )
182
-
172
+ raise ValueError (f"No dataframe or valid path is given: Exception { e } " )
173
+
183
174
analysis = _get_analysis (
184
175
df_results ,
185
176
output_dir = output_dir ,
@@ -191,22 +182,24 @@ def create_multi_comparison_matrix(
191
182
pvalue_test = pvalue_test ,
192
183
pvalue_test_params = pvalue_test_params ,
193
184
pvalue_correction = pvalue_correction ,
194
- pvalue_threshold = pvalue_threshold ,
185
+ pvalue_threshhold = pvalue_threshold ,
195
186
use_mean = use_mean ,
196
187
order_stats = order_stats ,
197
188
order_better = order_settings ["better" ],
198
189
dataset_column = dataset_column ,
199
190
precision = precision ,
200
191
load_analysis = load_analysis ,
192
+ )
193
+
194
+ # start drawing heatmap
195
+ fig = _draw (
196
+ analysis ,
197
+ output_dir = output_dir ,
201
198
row_comparates = row_comparates ,
202
199
col_comparates = col_comparates ,
203
200
excluded_row_comparates = excluded_row_comparates ,
204
201
excluded_col_comparates = excluded_col_comparates ,
205
- )
206
-
207
- # Generate figure
208
- fig = _draw (
209
- analysis ,
202
+ precision = precision ,
210
203
colormap = colormap ,
211
204
fig_size = fig_size ,
212
205
font_size = font_size ,
@@ -216,14 +209,13 @@ def create_multi_comparison_matrix(
216
209
tie_label = tie_label ,
217
210
loss_label = loss_label ,
218
211
include_legend = include_legend ,
219
- show_symmetry = show_symmetry ,
212
+ show_symetry = show_symetry ,
220
213
)
221
-
222
214
223
- # start drawing heatmap
215
+ # Handle saving files if save_files dictionary is provided
224
216
if save_files :
225
217
for file_type , file_name in save_files .items ():
226
- if file_name :
218
+ if file_name : # Only save if filename is not None or empty
227
219
save_path = os .path .join (output_dir , file_name )
228
220
if file_type == "pdf" :
229
221
fig .savefig (save_path , format = "pdf" , bbox_inches = "tight" )
@@ -235,11 +227,9 @@ def create_multi_comparison_matrix(
235
227
with open (save_path , "w" ) as f :
236
228
f .write (analysis .to_latex ())
237
229
238
- logger .info ("MCM creation completed." )
239
230
return fig
240
231
241
232
242
-
243
233
def _get_analysis (
244
234
df_results ,
245
235
output_dir = "./" ,
0 commit comments