1
1
"""Decomposition plots like PCA, UMAP, t-SNE, etc."""
2
2
3
+ import itertools
3
4
from typing import Optional
4
5
5
6
import matplotlib
7
+ import matplotlib .pyplot as plt
6
8
import pandas as pd
7
9
import sklearn .decomposition
8
10
@@ -17,3 +19,53 @@ def plot_explained_variance(
17
19
exp_var .index .name = "PC"
18
20
ax = exp_var .plot (ax = ax )
19
21
return ax
22
+
23
+
24
def pca_grid(
    PCs: pd.DataFrame,
    meta_column: pd.Series,
    n_components: int = 4,
    meta_col_name: Optional[str] = None,
    figsize=(6, 8),
) -> plt.Figure:
    """Plot a grid of pairwise scatter plots for the first principal components.

    Every pair ``(i, j)`` with ``i < j`` among the first ``n_components``
    columns of ``PCs`` gets its own panel. Panels are laid out in two columns,
    and the number of rows is derived from the number of pairs so that no
    pair is dropped. Points are colored by ``meta_column`` cast to a
    categorical, and a colorbar is drawn on every second panel.

    Parameters
    ----------
    PCs : pd.DataFrame
        DataFrame with the principal components as columns.
    meta_column : pd.Series
        Series with categorical data to color the scatter plots. It is joined
        to ``PCs`` on the index.
    n_components : int, optional
        Number of first n components to plot, by default 4. Capped at the
        number of columns in ``PCs``.
    meta_col_name : Optional[str], optional
        If another name than the default series name should be used,
        by default None
    figsize : tuple, optional
        Figure size passed on to ``plt.subplots``, by default (6, 8)

    Returns
    -------
    plt.Figure
        Matplotlib figure with the scatter plots.

    Raises
    ------
    ValueError
        If fewer than two components are available, as no pair can be formed.
    """
    if meta_col_name is None:
        meta_col_name = meta_column.name
    else:
        meta_column = meta_column.rename(meta_col_name)

    up_to = min(PCs.shape[-1], n_components)
    if up_to < 2:
        raise ValueError(
            f"Need at least two components to plot a grid, got {up_to}."
        )

    # One panel per unordered pair of components. Sizing the grid from the
    # number of pairs (instead of ``up_to - 1`` rows) keeps every pair visible
    # for any number of components, not only for the default of four.
    pairs = list(itertools.combinations(range(up_to), 2))
    n_cols = 2
    n_rows = -(-len(pairs) // n_cols)  # ceiling division
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=figsize,
        layout="constrained",
        squeeze=False,  # always a 2D array, also for a single row
    )
    flat_axes = axes.flatten()

    # NOTE: meta_column is assumed to be categorical (not continuous); this is
    # not checked here.
    PCs = PCs.join(meta_column.astype("category"))

    for k, ((i, j), ax) in enumerate(zip(pairs, flat_axes)):
        # i and j are integers; pandas is expected to treat them as column
        # positions as long as the column labels are not integers themselves.
        PCs.plot.scatter(
            i,
            j,
            c=meta_col_name,
            cmap="Paired",
            ax=ax,
            colorbar=bool(k % 2),  # colorbar only on every second panel
        )

    # Hide the panel left over when the number of pairs is odd.
    for unused_ax in flat_axes[len(pairs):]:
        unused_ax.set_visible(False)

    return fig
0 commit comments