11import numpy as np
22import pytest
33import scanpy as sc
4+ from anndata import AnnData
45from pandas import DataFrame , Series
56from pytest import fixture , mark
67
78import pertpy as pt
9+ from pertpy .tools ._distances ._distances import Distance , Metric
810
9- actual_distances = [
11+ actual_distances : tuple [ Metric , ...] = (
1012 # Euclidean distances and related
1113 "euclidean" ,
1214 "mean_absolute_error" ,
2224 "t_test" ,
2325 "wasserstein" ,
2426 "mahalanobis" ,
25- ]
26- semi_distances = ["r2_distance" , "sym_kldiv" , "ks_test" ]
27- non_distances = ["classifier_proba" ]
28- onesided_only = ["classifier_cp" ]
29- pseudo_counts_distances = ["nb_ll" ]
30- lognorm_counts_distances = ["mean_var_distribution" ]
31- all_distances = (
32- actual_distances + semi_distances + non_distances + lognorm_counts_distances + pseudo_counts_distances
33- ) # + onesided_only
27+ )
28+ semi_distances : tuple [Metric , ...] = ("r2_distance" , "sym_kldiv" , "ks_test" )
29+ non_distances : tuple [Metric , ...] = ("classifier_proba" ,)
30+ onesided_only : tuple [Metric , ...] = ("classifier_cp" ,)
31+ pseudo_counts_distances : tuple [Metric , ...] = ("nb_ll" ,)
32+ lognorm_counts_distances : tuple [Metric , ...] = ("mean_var_distribution" ,)
33+ all_distances : tuple [Metric , ...] = (
34+ * actual_distances ,
35+ * semi_distances ,
36+ * non_distances ,
37+ * lognorm_counts_distances ,
38+ * pseudo_counts_distances ,
39+ # *onesided_only,
40+ )
3441
3542
3643@fixture
37- def adata (request ) :
38- low_subsample_distances = [
44+ def adata (distance : Metric , rng : np . random . Generator ) -> AnnData :
45+ low_subsample_distances = {
3946 "sym_kldiv" ,
4047 "t_test" ,
4148 "ks_test" ,
4249 "classifier_proba" ,
4350 "classifier_cp" ,
4451 "mahalanobis" ,
4552 "mean_var_distribution" ,
46- ]
47- no_subsample_distances = ["mahalanobis" ] # mahalanobis only works on the full data without subsampling
48-
49- distance = request .node .callspec .params ["distance" ]
53+ }
54+ no_subsample_distances = {"mahalanobis" } # mahalanobis only works on the full data without subsampling
5055
5156 adata = pt .dt .distance_example ()
5257 if distance not in no_subsample_distances :
@@ -55,7 +60,7 @@ def adata(request):
5560 else :
5661 adata = sc .pp .subsample (adata , 0.001 , copy = True )
5762
58- adata = adata [:, np . random . default_rng () .choice (adata .n_vars , 100 , replace = False )].copy ()
63+ adata = adata [:, rng .choice (adata .n_vars , 100 , replace = False )].copy ()
5964
6065 adata .layers ["lognorm" ] = adata .X .copy ()
6166 adata .layers ["counts" ] = np .round (adata .X .toarray ()).astype (int )
@@ -70,25 +75,23 @@ def adata(request):
7075
7176
@fixture
def distance_obj(distance: Metric) -> pt.tl.Distance:
    """Build the ``Distance`` object for the parametrized metric.

    The data source depends on the metric group:
    metrics in ``lognorm_counts_distances`` read the ``"lognorm"`` layer,
    metrics in ``pseudo_counts_distances`` read the ``"counts"`` layer,
    and every other metric reads the ``"X_pca"`` embedding in ``obsm``.
    """
    if distance in lognorm_counts_distances:
        source = {"layer_key": "lognorm"}
    elif distance in pseudo_counts_distances:
        source = {"layer_key": "counts"}
    else:
        source = {"obsm_key": "X_pca"}
    return pt.tl.Distance(distance, **source)
8284
8385
@fixture
def pairwise_distance(adata: AnnData, distance_obj: pt.tl.Distance) -> DataFrame:
    """Return the pairwise distances between the ``"perturbation"`` groups of ``adata``."""
    pairwise_df = distance_obj.pairwise(
        adata,
        groupby="perturbation",
        show_progressbar=True,
    )
    return pairwise_df
8889
8990
9091@mark .parametrize ("distance" , actual_distances + semi_distances )
91- def test_distance_axioms (pairwise_distance , distance ):
92+ def test_distance_axioms (pairwise_distance : DataFrame , distance : Metric ) -> None :
93+ del distance
94+
9295 # This is equivalent to testing for a semimetric, defined as fulfilling all axioms except triangle inequality.
9396 # (M1) Definiteness
9497 assert all (np .diag (pairwise_distance .values ) == 0 ) # distance to self is 0
@@ -102,12 +105,12 @@ def test_distance_axioms(pairwise_distance, distance):
102105
103106
104107@mark .parametrize ("distance" , actual_distances )
105- def test_triangle_inequality (pairwise_distance , distance , rng ) :
106- # Test if distances are well-defined in accordance with metric axioms
107- # (M4) Triangle inequality (we just probe this for a few random triplets)
108- # Some tests are not well defined for the triangle inequality. We skip those.
108+ def test_triangle_inequality (pairwise_distance : DataFrame , distance : Metric , rng : np . random . Generator ) -> None :
109+ """ Test if distances are well-defined in accordance with metric axioms
110+ (M4) Triangle inequality (we just probe this for a few random triplets)
111+ """
109112 if distance in {"mahalanobis" , "wasserstein" }:
110- return
113+ pytest . skip ( "Some tests not well defined for triangle inequality" )
111114
112115 for _ in range (5 ):
113116 triplet = rng .choice (pairwise_distance .index , size = 3 , replace = False )
@@ -118,30 +121,33 @@ def test_triangle_inequality(pairwise_distance, distance, rng):
118121
119122
@mark.parametrize("distance", all_distances)
def test_distance_layers(pairwise_distance: DataFrame, distance: Metric) -> None:
    """Check that the pairwise result is a square, symmetric, identically-labelled DataFrame."""
    # The parametrized value is not used in the body of this test.
    del distance

    assert isinstance(pairwise_distance, DataFrame)
    assert pairwise_distance.columns.equals(pairwise_distance.index)

    # Symmetry must be checked element-wise: the entries of (D - D.T) cancel
    # pairwise, so their sum is zero for *any* square matrix and proves nothing.
    values = pairwise_distance.to_numpy(dtype=float)
    assert np.allclose(values, values.T)  # symmetry
125130
126131
@mark.parametrize("distance", actual_distances + pseudo_counts_distances)
def test_distance_counts(adata: AnnData, distance: Metric) -> None:
    """Check pairwise distances computed on the ``"counts"`` layer.

    The result must be a DataFrame whose columns equal its index and whose
    values are symmetric. ``"mahalanobis"`` is skipped.
    """
    if distance == "mahalanobis":
        pytest.skip("covariance matrix is a singular matrix, not invertible")

    distance_obj = pt.tl.Distance(distance, layer_key="counts")
    df = distance_obj.pairwise(adata, groupby="perturbation")

    assert isinstance(df, DataFrame)
    assert df.columns.equals(df.index)

    # Symmetry must be checked element-wise: the entries of (D - D.T) cancel
    # pairwise, so their sum is zero for *any* square matrix and proves nothing.
    values = df.to_numpy(dtype=float)
    assert np.allclose(values, values.T)
135141
136142
@mark.parametrize("distance", all_distances)
def test_mutually_exclusive_keys(distance: Metric) -> None:
    """Passing both ``layer_key`` and ``obsm_key`` to ``Distance`` must raise ``ValueError``."""
    conflicting_sources = {"layer_key": "counts", "obsm_key": "X_pca"}
    with pytest.raises(ValueError):
        pt.tl.Distance(distance, **conflicting_sources)
141147
142148
143149@mark .parametrize ("distance" , actual_distances + semi_distances + non_distances )
144- def test_distance_output_type (distance , rng ) :
150+ def test_distance_output_type (distance : Metric , rng : np . random . Generator ) -> None :
145151 # Test if distances are outputting floats
146152 Distance = pt .tl .Distance (distance )
147153 X = rng .normal (size = (50 , 10 ))
@@ -151,15 +157,16 @@ def test_distance_output_type(distance, rng):
151157
152158
@mark.parametrize("distance", all_distances + onesided_only)
def test_distance_onesided(adata: AnnData, distance_obj: Distance, distance: Metric) -> None:
    """Check consistency of one-sided distance results.

    The distances from one reference group to all groups must come back as a
    ``Series``, and the reference group's distance to itself must be zero.
    """
    # The parametrized value is not used in the body of this test.
    del distance

    perturbations = adata.obs["perturbation"].unique()
    reference_group = perturbations[0]
    onesided = distance_obj.onesided_distances(
        adata,
        groupby="perturbation",
        selected_group=reference_group,
    )

    assert isinstance(onesided, Series)
    assert onesided.loc[reference_group] == 0  # distance to self is 0
160167
161168
162- def test_bootstrap_distance_output_type (rng ) :
169+ def test_bootstrap_distance_output_type (rng : np . random . Generator ) -> None :
163170 # Test if distances are outputting floats
164171 Distance = pt .tl .Distance (metric = "edistance" )
165172 X = rng .normal (size = (50 , 10 ))
@@ -170,7 +177,7 @@ def test_bootstrap_distance_output_type(rng):
170177
171178
172179@mark .parametrize ("distance" , ["edistance" ])
173- def test_bootstrap_distance_pairwise (adata , distance ) :
180+ def test_bootstrap_distance_pairwise (adata : AnnData , distance : Metric ) -> None :
174181 # Test consistency of pairwise distance results
175182 Distance = pt .tl .Distance (distance , obsm_key = "X_pca" )
176183 bootstrap_output = Distance .pairwise (adata , groupby = "perturbation" , bootstrap = True , n_bootstrap = 3 )
@@ -186,9 +193,9 @@ def test_bootstrap_distance_pairwise(adata, distance):
186193
187194
188195@mark .parametrize ("distance" , ["edistance" ])
189- def test_bootstrap_distance_onesided (adata , distance ) :
196+ def test_bootstrap_distance_onesided (adata : AnnData , distance : Metric ) -> None :
190197 # Test consistency of one-sided distance results
191- selected_group = adata .obs . perturbation .unique ()[0 ]
198+ selected_group = adata .obs [ " perturbation" ] .unique ()[0 ]
192199 Distance = pt .tl .Distance (distance , obsm_key = "X_pca" )
193200 bootstrap_output = Distance .onesided_distances (
194201 adata ,
@@ -201,7 +208,7 @@ def test_bootstrap_distance_onesided(adata, distance):
201208 assert isinstance (bootstrap_output , tuple )
202209
203210
204- def test_compare_distance (rng ) :
211+ def test_compare_distance (rng : np . random . Generator ) -> None :
205212 X = rng .normal (size = (50 , 10 ))
206213 Y = rng .normal (size = (50 , 10 ))
207214 C = rng .normal (size = (50 , 10 ))
@@ -211,4 +218,4 @@ def test_compare_distance(rng):
211218 res_scaled = Distance .compare_distance (X , Y , C , mode = "scaled" )
212219 assert isinstance (res_scaled , float )
213220 with pytest .raises (ValueError ):
214- Distance .compare_distance (X , Y , C , mode = "new_mode" )
221+ Distance .compare_distance (X , Y , C , mode = "new_mode" ) # type: ignore[arg-type]
0 commit comments