5
5
__all__ = ["create_multi_comparison_matrix" ]
6
6
7
7
import json
8
+ import logging
8
9
import os
9
10
from typing import Dict , List , Optional , Union
10
- import logging
11
+
11
12
import matplotlib .pyplot as plt
12
13
import numpy as np
13
14
import pandas as pd
19
20
logging .basicConfig (level = logging .INFO )
20
21
logger = logging .getLogger (__name__ )
21
22
23
+
22
24
def create_multi_comparison_matrix (
23
- df_results : Union [str , pd .DataFrame ],
25
+ df_results : Union [str , pd .DataFrame ],
24
26
output_dir : str = "./" ,
25
- save_files : Optional [Dict [str , str ]] = None ,
27
+ save_files : Optional [dict [str , str ]] = None ,
26
28
used_statistic : str = "Accuracy" ,
27
29
save_as_json : bool = True ,
28
30
plot_1v1_comparisons : bool = False ,
29
- order_settings : Dict [str , str ] = {"win_tie_loss" : "higher" , "better" : "decreasing" },
31
+ order_settings : dict [str , str ] = {"win_tie_loss" : "higher" , "better" : "decreasing" },
30
32
include_pvalue : bool = True ,
31
33
pvalue_test : str = "wilcoxon" ,
32
- pvalue_test_params : Optional [Dict [str , str ]] = None ,
34
+ pvalue_test_params : Optional [dict [str , str ]] = None ,
33
35
pvalue_correction : Optional [str ] = None ,
34
36
pvalue_threshold : float = 0.05 ,
35
37
use_mean : str = "mean-difference" ,
36
38
order_stats : str = "average-statistic" ,
37
39
dataset_column : Optional [str ] = None ,
38
40
precision : int = 4 ,
39
41
load_analysis : bool = False ,
40
- row_comparates : Optional [List [str ]] = None ,
41
- col_comparates : Optional [List [str ]] = None ,
42
- excluded_row_comparates : Optional [List [str ]] = None ,
43
- excluded_col_comparates : Optional [List [str ]] = None ,
42
+ row_comparates : Optional [list [str ]] = None ,
43
+ col_comparates : Optional [list [str ]] = None ,
44
+ excluded_row_comparates : Optional [list [str ]] = None ,
45
+ excluded_col_comparates : Optional [list [str ]] = None ,
44
46
colormap : str = "coolwarm" ,
45
47
fig_size : Union [str , tuple ] = "auto" ,
46
48
font_size : Union [str , int ] = "auto" ,
@@ -65,9 +67,9 @@ def create_multi_comparison_matrix(
65
67
row should contain the names of the estimators and the first column can
66
68
contain the names of the problems if `dataset_column` is true.
67
69
output_dir: str, default = './'
68
- The output directory for the results.
70
+ The output directory for the results.
69
71
save_files: dict, default = None
70
- Dictionary to handle all save options (e.g., {"pdf": "output.pdf", "png": "output.png"}).
72
+ Dictionary to handle all save options (e.g., {"pdf": "output.pdf", "png": "output.png"}).
71
73
used_statistic: str, default = 'Accuracy'
72
74
Name of the metric being assessed (e.g. accuracy, error, mse).
73
75
save_as_json: bool, default = True
@@ -167,9 +169,8 @@ def create_multi_comparison_matrix(
167
169
Evaluations That Is Stable Under Manipulation Of The Comparate Set
168
170
arXiv preprint arXiv:2305.11921, 2023.
169
171
"""
170
-
171
172
logger .info ("Starting MCM creation..." )
172
-
173
+
173
174
if isinstance (df_results , str ):
174
175
try :
175
176
df_results = pd .read_csv (df_results )
@@ -179,7 +180,7 @@ def create_multi_comparison_matrix(
179
180
raise ValueError (f"The file { df_results } is empty." )
180
181
except Exception as e :
181
182
raise ValueError (f"An error occurred while reading the file: { e } " )
182
-
183
+
183
184
analysis = _get_analysis (
184
185
df_results ,
185
186
output_dir = output_dir ,
@@ -203,7 +204,7 @@ def create_multi_comparison_matrix(
203
204
excluded_row_comparates = excluded_row_comparates ,
204
205
excluded_col_comparates = excluded_col_comparates ,
205
206
)
206
-
207
+
207
208
# Generate figure
208
209
fig = _draw (
209
210
analysis ,
@@ -218,7 +219,6 @@ def create_multi_comparison_matrix(
218
219
include_legend = include_legend ,
219
220
show_symmetry = show_symmetry ,
220
221
)
221
-
222
222
223
223
# start drawing heatmap
224
224
if save_files :
@@ -239,7 +239,6 @@ def create_multi_comparison_matrix(
239
239
return fig
240
240
241
241
242
-
243
242
def _get_analysis (
244
243
df_results ,
245
244
output_dir = "./" ,
0 commit comments