Skip to content

Commit d96ed83

Browse files
authored
🚚 add functionality from acore (#7)
1 parent 0dab945 commit d96ed83

File tree

3 files changed

+308
-0
lines changed

3 files changed

+308
-0
lines changed

src/vuecore/__init__.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,211 @@
1+
"""
2+
This module contains functions to plot data. It will be moved to a separate
3+
visualization package.
4+
"""
5+
16
from importlib import metadata
27

38
__version__ = metadata.version("vuecore")
9+
10+
11+
import logging
12+
import pathlib
13+
from typing import Iterable
14+
15+
import matplotlib
16+
import matplotlib.pyplot as plt
17+
import numpy as np
18+
import pandas as pd
19+
20+
plt.rcParams["figure.figsize"] = [4.0, 3.0]
21+
plt.rcParams["pdf.fonttype"] = 42
22+
plt.rcParams["ps.fonttype"] = 42
23+
24+
plt.rcParams["figure.dpi"] = 147
25+
26+
figsize_a4 = (8.3, 11.7)
27+
28+
logger = logging.getLogger(__name__)
29+
30+
31+
def savefig(
32+
fig: matplotlib.figure.Figure,
33+
name: str,
34+
folder: pathlib.Path = ".",
35+
pdf=True,
36+
tight_layout=True,
37+
dpi=300,
38+
):
39+
"""Save matplotlib Figure (having method `savefig`) as pdf and png."""
40+
folder = pathlib.Path(folder)
41+
fname = folder / name
42+
folder = fname.parent # in case name specifies folders
43+
folder.mkdir(exist_ok=True, parents=True)
44+
if not fig.get_constrained_layout() and tight_layout:
45+
fig.tight_layout()
46+
fig.savefig(fname.with_suffix(".png"), bbox_inches="tight", dpi=dpi)
47+
if pdf:
48+
fig.savefig(fname.with_suffix(".pdf"), bbox_inches="tight", dpi=dpi)
49+
logger.info(f"Saved Figures to {fname}")
50+
51+
52+
def select_xticks(ax: matplotlib.axes.Axes, max_ticks: int = 50) -> list:
53+
"""Limit the number of xticks displayed.
54+
55+
Parameters
56+
----------
57+
ax : matplotlib.axes.Axes
58+
Axes object to manipulate
59+
max_ticks : int, optional
60+
maximum number of set ticks on x-axis, by default 50
61+
62+
Returns
63+
-------
64+
list
65+
list of current ticks for x-axis. Either new
66+
or old (depending if something was changed).
67+
"""
68+
x_ticks = ax.get_xticks()
69+
offset = len(x_ticks) // max_ticks
70+
if offset > 1: # if larger than 1
71+
return ax.set_xticks(x_ticks[::offset])
72+
return x_ticks
73+
74+
75+
def select_dates(date_series: pd.Series, max_ticks=30) -> np.array:
76+
"""Get unique dates (single days) for selection in pd.plot.line
77+
with xticks argument.
78+
79+
Parameters
80+
----------
81+
date_series : pd.Series
82+
datetime series to use (values, not index)
83+
max_ticks : int, optional
84+
maximum number of unique ticks to select, by default 30
85+
86+
Returns
87+
-------
88+
np.array
89+
array of selected dates
90+
"""
91+
xticks = date_series.dt.date.unique()
92+
offset = len(xticks) // max_ticks
93+
if offset > 1:
94+
return xticks[::offset]
95+
else:
96+
xticks
97+
98+
99+
def make_large_descriptors(size="xx-large"):
100+
"""Helper function to have very large titles, labes and tick texts for
101+
matplotlib plots per default.
102+
103+
size: str
104+
fontsize or allowed category. Change default if necessary, default 'xx-large'
105+
"""
106+
plt.rcParams.update(
107+
{
108+
k: size
109+
for k in [
110+
"xtick.labelsize",
111+
"ytick.labelsize",
112+
"axes.titlesize",
113+
"axes.labelsize",
114+
"legend.fontsize",
115+
"legend.title_fontsize",
116+
]
117+
}
118+
)
119+
120+
121+
set_font_sizes = make_large_descriptors
122+
123+
124+
def add_prop_as_second_yaxis(
125+
ax: matplotlib.axes.Axes, n_samples: int, format_str: str = "{x:,.3f}"
126+
) -> matplotlib.axes.Axes:
127+
"""Add proportion as second axis. Try to align cleverly
128+
129+
Parameters
130+
----------
131+
ax : matplotlib.axes.Axes
132+
Axes for which you want to add a second y-axis
133+
n_samples : int
134+
Number of total samples (to normalize against)
135+
136+
Returns
137+
-------
138+
matplotlib.axes.Axes
139+
Second layover twin Axes with right-hand side y-axis
140+
"""
141+
ax2 = ax.twinx()
142+
n_min, n_max = np.round(ax.get_ybound())
143+
logger.info(f"{n_min = }, {n_max = }")
144+
lower_prop = n_min / n_samples + (ax.get_ybound()[0] - n_min) / n_samples
145+
upper_prop = n_max / n_samples + (ax.get_ybound()[1] - n_max) / n_samples
146+
logger.info(f"{lower_prop = }, {upper_prop = }")
147+
ax2.set_ybound(lower_prop, upper_prop)
148+
# _ = ax2.set_yticks(np.linspace(n_min/n_samples,
149+
# n_max /n_samples, len(ax.get_yticks())-2))
150+
_ = ax2.set_yticks(ax.get_yticks()[1:-1] / n_samples)
151+
ax2.yaxis.set_major_formatter(matplotlib.ticker.StrMethodFormatter(format_str))
152+
return ax2
153+
154+
155+
def add_height_to_barplot(
156+
ax: matplotlib.axes.Axes, size: int = 15
157+
) -> matplotlib.axes.Axes:
158+
"""Add height of bar to each bar in a barplot."""
159+
for bar in ax.patches:
160+
ax.annotate(
161+
text=format(bar.get_height(), ".2f"),
162+
xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
163+
xytext=(0, 7),
164+
ha="center",
165+
va="center",
166+
size=size,
167+
textcoords="offset points",
168+
)
169+
return ax
170+
171+
172+
def add_text_to_barplot(
173+
ax: matplotlib.axes.Axes, text: Iterable[str], size=15
174+
) -> matplotlib.axes.Axes:
175+
"""Add custom text from Iterable to each bar in a barplot."""
176+
for bar, text_bar in zip(ax.patches, text):
177+
msg = f"{bar = }, {text = }, {bar.get_height() = }"
178+
logger.debug(msg)
179+
ax.annotate(
180+
text=text_bar,
181+
xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
182+
xytext=(0, -5),
183+
rotation=90,
184+
ha="center",
185+
va="top",
186+
size=size,
187+
textcoords="offset points",
188+
)
189+
return ax
190+
191+
192+
def format_large_numbers(
193+
ax: matplotlib.axes.Axes, format_str: str = "{x:,.0f}"
194+
) -> matplotlib.axes.Axes:
195+
"""Format large integer numbers to be read more easily.
196+
197+
Parameters
198+
----------
199+
ax : matplotlib.axes.Axes
200+
Axes which labels should be manipulated.
201+
format_str : str, optional
202+
Default float format string, by default '{x:,.0f}'
203+
204+
Returns
205+
-------
206+
matplotlib.axes.Axes
207+
Return reference to modified input Axes object.
208+
"""
209+
ax.xaxis.set_major_formatter(matplotlib.ticker.StrMethodFormatter(format_str))
210+
ax.yaxis.set_major_formatter(matplotlib.ticker.StrMethodFormatter(format_str))
211+
return ax

src/vuecore/decomposition.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""Decompositon plots like pca, umap, tsne, etc."""
2+
3+
from typing import Optional
4+
5+
import matplotlib
6+
import pandas as pd
7+
import sklearn.decomposition
8+
9+
10+
def plot_explained_variance(
11+
pca: sklearn.decomposition.PCA, ax: Optional[matplotlib.axes.Axes] = None
12+
) -> matplotlib.axes.Axes:
13+
"""Plot explained variance of PCA from scikit-learn."""
14+
exp_var = pd.Series(pca.explained_variance_ratio_).to_frame("explained variance")
15+
exp_var.index += 1 # start at 1
16+
exp_var["explained variance (cummulated)"] = exp_var["explained variance"].cumsum()
17+
exp_var.index.name = "PC"
18+
ax = exp_var.plot(ax=ax)
19+
return ax

src/vuecore/metrics.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""Plot metrics for binary classification."""
2+
3+
import matplotlib
4+
import matplotlib.pyplot as plt
5+
import pandas as pd
6+
from njab.sklearn.types import Results, ResultsSplit
7+
8+
LIMITS = (-0.05, 1.05)
9+
10+
11+
def plot_split_auc(
12+
result: ResultsSplit, name: str, ax: matplotlib.axes.Axes
13+
) -> matplotlib.axes.Axes:
14+
"""Add receiver operation curve to ax of a split of the data."""
15+
col_name = f"{name} (auc: {result.auc:.3f})"
16+
roc = pd.DataFrame(result.roc, index="fpr tpr cutoffs".split()).rename(
17+
{"tpr": col_name}
18+
)
19+
ax = roc.T.plot(
20+
"fpr",
21+
col_name,
22+
xlabel="false positive rate",
23+
ylabel="true positive rate",
24+
style=".-",
25+
ylim=LIMITS,
26+
xlim=LIMITS,
27+
ax=ax,
28+
)
29+
return ax
30+
31+
32+
# ! should be roc
33+
def plot_auc(
34+
results: Results,
35+
ax: matplotlib.axes.Axes = None,
36+
label_train="train",
37+
label_test="test",
38+
**kwargs,
39+
) -> matplotlib.axes.Axes:
40+
"""Plot ROC curve for train and test data."""
41+
if ax is None:
42+
fig, ax = plt.subplots(1, 1, **kwargs)
43+
ax = plot_split_auc(results.train, f"{label_train}", ax)
44+
ax = plot_split_auc(results.test, f"{label_test}", ax)
45+
return ax
46+
47+
48+
def plot_split_prc(
49+
result: ResultsSplit, name: str, ax: matplotlib.axes.Axes
50+
) -> matplotlib.axes.Axes:
51+
"""Add precision recall curve to ax of a split of the data."""
52+
col_name = f"{name} (aps: {result.aps:.3f})"
53+
roc = pd.DataFrame(result.prc, index="precision recall cutoffs".split()).rename(
54+
{"precision": col_name}
55+
)
56+
ax = roc.T.plot(
57+
"recall",
58+
col_name,
59+
xlabel="true positive rate",
60+
ylabel="precision",
61+
style=".-",
62+
ylim=LIMITS,
63+
xlim=LIMITS,
64+
ax=ax,
65+
)
66+
return ax
67+
68+
69+
def plot_prc(
70+
results: ResultsSplit,
71+
ax: matplotlib.axes.Axes = None,
72+
label_train="train",
73+
label_test="test",
74+
**kwargs,
75+
):
76+
"""Plot precision recall curve for train and test data."""
77+
if ax is None:
78+
fig, ax = plt.subplots(1, 1, **kwargs)
79+
ax = plot_split_prc(results.train, f"{label_train}", ax)
80+
ax = plot_split_prc(results.test, f"{label_test}", ax)
81+
return ax

0 commit comments

Comments
 (0)