Skip to content

Commit a553949

Browse files
PravasMohantygithub-actions[bot]
authored andcommitted
Automatic pre-commit fixes
1 parent dbdc5dd commit a553949

File tree

1 file changed

+16
-17
lines changed

1 file changed

+16
-17
lines changed

aeon/visualisation/results/_mcm.py

+16-17
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
__all__ = ["create_multi_comparison_matrix"]
66

77
import json
8+
import logging
89
import os
910
from typing import Dict, List, Optional, Union
10-
import logging
11+
1112
import matplotlib.pyplot as plt
1213
import numpy as np
1314
import pandas as pd
@@ -19,28 +20,29 @@
1920
logging.basicConfig(level=logging.INFO)
2021
logger = logging.getLogger(__name__)
2122

23+
2224
def create_multi_comparison_matrix(
23-
df_results: Union[str, pd.DataFrame],
25+
df_results: Union[str, pd.DataFrame],
2426
output_dir: str = "./",
25-
save_files: Optional[Dict[str, str]] = None,
27+
save_files: Optional[dict[str, str]] = None,
2628
used_statistic: str = "Accuracy",
2729
save_as_json: bool = True,
2830
plot_1v1_comparisons: bool = False,
29-
order_settings: Dict[str, str] = {"win_tie_loss": "higher", "better": "decreasing"},
31+
order_settings: dict[str, str] = {"win_tie_loss": "higher", "better": "decreasing"},
3032
include_pvalue: bool = True,
3133
pvalue_test: str = "wilcoxon",
32-
pvalue_test_params: Optional[Dict[str, str]] = None,
34+
pvalue_test_params: Optional[dict[str, str]] = None,
3335
pvalue_correction: Optional[str] = None,
3436
pvalue_threshold: float = 0.05,
3537
use_mean: str = "mean-difference",
3638
order_stats: str = "average-statistic",
3739
dataset_column: Optional[str] = None,
3840
precision: int = 4,
3941
load_analysis: bool = False,
40-
row_comparates: Optional[List[str]] = None,
41-
col_comparates: Optional[List[str]] = None,
42-
excluded_row_comparates: Optional[List[str]] = None,
43-
excluded_col_comparates: Optional[List[str]] = None,
42+
row_comparates: Optional[list[str]] = None,
43+
col_comparates: Optional[list[str]] = None,
44+
excluded_row_comparates: Optional[list[str]] = None,
45+
excluded_col_comparates: Optional[list[str]] = None,
4446
colormap: str = "coolwarm",
4547
fig_size: Union[str, tuple] = "auto",
4648
font_size: Union[str, int] = "auto",
@@ -65,9 +67,9 @@ def create_multi_comparison_matrix(
6567
row should contain the names of the estimators and the first column can
6668
contain the names of the problems if `dataset_column` is true.
6769
output_dir: str, default = './'
68-
The output directory for the results.
70+
The output directory for the results.
6971
save_files: dict, default = None
70-
Dictionary to handle all save options (e.g., {"pdf": "output.pdf", "png": "output.png"}).
72+
Dictionary to handle all save options (e.g., {"pdf": "output.pdf", "png": "output.png"}).
7173
used_statistic: str, default = 'Score'
7274
Name of the metric being assesses (e.g. accuracy, error, mse).
7375
save_as_json: bool, default = True
@@ -167,9 +169,8 @@ def create_multi_comparison_matrix(
167169
Evaluations That Is Stable Under Manipulation Of The Comparate Set
168170
arXiv preprint arXiv:2305.11921, 2023.
169171
"""
170-
171172
logger.info("Starting MCM creation...")
172-
173+
173174
if isinstance(df_results, str):
174175
try:
175176
df_results = pd.read_csv(df_results)
@@ -179,7 +180,7 @@ def create_multi_comparison_matrix(
179180
raise ValueError(f"The file {df_results} is empty.")
180181
except Exception as e:
181182
raise ValueError(f"An error occurred while reading the file: {e}")
182-
183+
183184
analysis = _get_analysis(
184185
df_results,
185186
output_dir=output_dir,
@@ -203,7 +204,7 @@ def create_multi_comparison_matrix(
203204
excluded_row_comparates=excluded_row_comparates,
204205
excluded_col_comparates=excluded_col_comparates,
205206
)
206-
207+
207208
# Generate figure
208209
fig = _draw(
209210
analysis,
@@ -218,7 +219,6 @@ def create_multi_comparison_matrix(
218219
include_legend=include_legend,
219220
show_symmetry=show_symmetry,
220221
)
221-
222222

223223
# start drawing heatmap
224224
if save_files:
@@ -239,7 +239,6 @@ def create_multi_comparison_matrix(
239239
return fig
240240

241241

242-
243242
def _get_analysis(
244243
df_results,
245244
output_dir="./",

0 commit comments

Comments
 (0)