Skip to content

Commit 35bb64d

Browse files
fix: adapt fixture usage to newer pytest versions (#879)
* Fix fixture usage
* we don’t even need the fixture
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci
* fmt
* mypy

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 77c2b39 commit 35bb64d

File tree

3 files changed

+65
-56
lines changed

3 files changed

+65
-56
lines changed

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33

44

55
@pytest.fixture
6-
def rng():
6+
def rng() -> np.random.Generator:
77
return np.random.default_rng()

tests/tools/_distances/test_distance_tests.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import pytest
22
import scanpy as sc
3+
from anndata import AnnData
34
from pandas import DataFrame
45

56
import pertpy as pt
7+
from pertpy.tools._distances._distances import Metric
68

7-
distances = [
9+
distances: tuple[Metric, ...] = (
810
"edistance",
911
"euclidean",
1012
"mse",
@@ -25,21 +27,21 @@
2527
# "nbll",
2628
"mahalanobis",
2729
"mean_var_distribution",
28-
]
30+
)
2931

3032
count_distances = ["nb_ll"]
3133

3234

3335
@pytest.fixture
34-
def adata():
36+
def adata() -> AnnData:
3537
adata = pt.dt.distance_example()
3638
adata = sc.pp.subsample(adata, 0.1, copy=True)
3739

3840
return adata
3941

4042

4143
@pytest.mark.parametrize("distance", distances)
42-
def test_distancetest(adata, distance):
44+
def test_distancetest(adata: AnnData, distance: Metric) -> None:
4345
etest = pt.tl.DistanceTest(distance, n_perms=10, obsm_key="X_pca", alpha=0.05, correction="holm-sidak")
4446
tab = etest(adata, groupby="perturbation", contrast="control")
4547

Lines changed: 58 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import numpy as np
22
import pytest
33
import scanpy as sc
4+
from anndata import AnnData
45
from pandas import DataFrame, Series
56
from pytest import fixture, mark
67

78
import pertpy as pt
9+
from pertpy.tools._distances._distances import Distance, Metric
810

9-
actual_distances = [
11+
actual_distances: tuple[Metric, ...] = (
1012
# Euclidean distances and related
1113
"euclidean",
1214
"mean_absolute_error",
@@ -22,31 +24,34 @@
2224
"t_test",
2325
"wasserstein",
2426
"mahalanobis",
25-
]
26-
semi_distances = ["r2_distance", "sym_kldiv", "ks_test"]
27-
non_distances = ["classifier_proba"]
28-
onesided_only = ["classifier_cp"]
29-
pseudo_counts_distances = ["nb_ll"]
30-
lognorm_counts_distances = ["mean_var_distribution"]
31-
all_distances = (
32-
actual_distances + semi_distances + non_distances + lognorm_counts_distances + pseudo_counts_distances
33-
) # + onesided_only
27+
)
28+
semi_distances: tuple[Metric, ...] = ("r2_distance", "sym_kldiv", "ks_test")
29+
non_distances: tuple[Metric, ...] = ("classifier_proba",)
30+
onesided_only: tuple[Metric, ...] = ("classifier_cp",)
31+
pseudo_counts_distances: tuple[Metric, ...] = ("nb_ll",)
32+
lognorm_counts_distances: tuple[Metric, ...] = ("mean_var_distribution",)
33+
all_distances: tuple[Metric, ...] = (
34+
*actual_distances,
35+
*semi_distances,
36+
*non_distances,
37+
*lognorm_counts_distances,
38+
*pseudo_counts_distances,
39+
# *onesided_only,
40+
)
3441

3542

3643
@fixture
37-
def adata(request):
38-
low_subsample_distances = [
44+
def adata(distance: Metric, rng: np.random.Generator) -> AnnData:
45+
low_subsample_distances = {
3946
"sym_kldiv",
4047
"t_test",
4148
"ks_test",
4249
"classifier_proba",
4350
"classifier_cp",
4451
"mahalanobis",
4552
"mean_var_distribution",
46-
]
47-
no_subsample_distances = ["mahalanobis"] # mahalanobis only works on the full data without subsampling
48-
49-
distance = request.node.callspec.params["distance"]
53+
}
54+
no_subsample_distances = {"mahalanobis"} # mahalanobis only works on the full data without subsampling
5055

5156
adata = pt.dt.distance_example()
5257
if distance not in no_subsample_distances:
@@ -55,7 +60,7 @@ def adata(request):
5560
else:
5661
adata = sc.pp.subsample(adata, 0.001, copy=True)
5762

58-
adata = adata[:, np.random.default_rng().choice(adata.n_vars, 100, replace=False)].copy()
63+
adata = adata[:, rng.choice(adata.n_vars, 100, replace=False)].copy()
5964

6065
adata.layers["lognorm"] = adata.X.copy()
6166
adata.layers["counts"] = np.round(adata.X.toarray()).astype(int)
@@ -70,25 +75,23 @@ def adata(request):
7075

7176

7277
@fixture
73-
def distance_obj(request):
74-
distance = request.node.callspec.params["distance"]
78+
def distance_obj(distance: Metric) -> pt.tl.Distance:
7579
if distance in lognorm_counts_distances:
76-
Distance = pt.tl.Distance(distance, layer_key="lognorm")
77-
elif distance in pseudo_counts_distances:
78-
Distance = pt.tl.Distance(distance, layer_key="counts")
79-
else:
80-
Distance = pt.tl.Distance(distance, obsm_key="X_pca")
81-
return Distance
80+
return pt.tl.Distance(distance, layer_key="lognorm")
81+
if distance in pseudo_counts_distances:
82+
return pt.tl.Distance(distance, layer_key="counts")
83+
return pt.tl.Distance(distance, obsm_key="X_pca")
8284

8385

8486
@fixture
85-
@mark.parametrize("distance", all_distances)
86-
def pairwise_distance(adata, distance_obj, distance):
87+
def pairwise_distance(adata: AnnData, distance_obj: pt.tl.Distance) -> DataFrame:
8788
return distance_obj.pairwise(adata, groupby="perturbation", show_progressbar=True)
8889

8990

9091
@mark.parametrize("distance", actual_distances + semi_distances)
91-
def test_distance_axioms(pairwise_distance, distance):
92+
def test_distance_axioms(pairwise_distance: DataFrame, distance: Metric) -> None:
93+
del distance
94+
9295
# This is equivalent to testing for a semimetric, defined as fulfilling all axioms except triangle inequality.
9396
# (M1) Definiteness
9497
assert all(np.diag(pairwise_distance.values) == 0) # distance to self is 0
@@ -102,12 +105,12 @@ def test_distance_axioms(pairwise_distance, distance):
102105

103106

104107
@mark.parametrize("distance", actual_distances)
105-
def test_triangle_inequality(pairwise_distance, distance, rng):
106-
# Test if distances are well-defined in accordance with metric axioms
107-
# (M4) Triangle inequality (we just probe this for a few random triplets)
108-
# Some tests are not well defined for the triangle inequality. We skip those.
108+
def test_triangle_inequality(pairwise_distance: DataFrame, distance: Metric, rng: np.random.Generator) -> None:
109+
"""Test if distances are well-defined in accordance with metric axioms
110+
(M4) Triangle inequality (we just probe this for a few random triplets)
111+
"""
109112
if distance in {"mahalanobis", "wasserstein"}:
110-
return
113+
pytest.skip("Some tests not well defined for triangle inequality")
111114

112115
for _ in range(5):
113116
triplet = rng.choice(pairwise_distance.index, size=3, replace=False)
@@ -118,30 +121,33 @@ def test_triangle_inequality(pairwise_distance, distance, rng):
118121

119122

120123
@mark.parametrize("distance", all_distances)
121-
def test_distance_layers(pairwise_distance, distance):
124+
def test_distance_layers(pairwise_distance: DataFrame, distance: Metric) -> None:
125+
del distance
126+
122127
assert isinstance(pairwise_distance, DataFrame)
123128
assert pairwise_distance.columns.equals(pairwise_distance.index)
124129
assert np.sum(pairwise_distance.values - pairwise_distance.values.T) == 0 # symmetry
125130

126131

127132
@mark.parametrize("distance", actual_distances + pseudo_counts_distances)
128-
def test_distance_counts(adata, distance):
129-
if distance != "mahalanobis": # skip, doesn't work because covariance matrix is a singular matrix, not invertible
130-
distance = pt.tl.Distance(distance, layer_key="counts")
131-
df = distance.pairwise(adata, groupby="perturbation")
132-
assert isinstance(df, DataFrame)
133-
assert df.columns.equals(df.index)
134-
assert np.sum(df.values - df.values.T) == 0
133+
def test_distance_counts(adata: AnnData, distance: Metric) -> None:
134+
if distance == "mahalanobis":
135+
pytest.skip("covariance matrix is a singular matrix, not invertible")
136+
distance_obj = pt.tl.Distance(distance, layer_key="counts")
137+
df = distance_obj.pairwise(adata, groupby="perturbation")
138+
assert isinstance(df, DataFrame)
139+
assert df.columns.equals(df.index)
140+
assert np.sum(df.values - df.values.T) == 0
135141

136142

137143
@mark.parametrize("distance", all_distances)
138-
def test_mutually_exclusive_keys(distance):
144+
def test_mutually_exclusive_keys(distance: Metric) -> None:
139145
with pytest.raises(ValueError):
140146
_ = pt.tl.Distance(distance, layer_key="counts", obsm_key="X_pca")
141147

142148

143149
@mark.parametrize("distance", actual_distances + semi_distances + non_distances)
144-
def test_distance_output_type(distance, rng):
150+
def test_distance_output_type(distance: Metric, rng: np.random.Generator) -> None:
145151
# Test if distances are outputting floats
146152
Distance = pt.tl.Distance(distance)
147153
X = rng.normal(size=(50, 10))
@@ -151,15 +157,16 @@ def test_distance_output_type(distance, rng):
151157

152158

153159
@mark.parametrize("distance", all_distances + onesided_only)
154-
def test_distance_onesided(adata, distance_obj, distance):
160+
def test_distance_onesided(adata: AnnData, distance_obj: Distance, distance: Metric) -> None:
161+
del distance
155162
# Test consistency of one-sided distance results
156-
selected_group = adata.obs.perturbation.unique()[0]
163+
selected_group = adata.obs["perturbation"].unique()[0]
157164
df = distance_obj.onesided_distances(adata, groupby="perturbation", selected_group=selected_group)
158165
assert isinstance(df, Series)
159166
assert df.loc[selected_group] == 0 # distance to self is 0
160167

161168

162-
def test_bootstrap_distance_output_type(rng):
169+
def test_bootstrap_distance_output_type(rng: np.random.Generator) -> None:
163170
# Test if distances are outputting floats
164171
Distance = pt.tl.Distance(metric="edistance")
165172
X = rng.normal(size=(50, 10))
@@ -170,7 +177,7 @@ def test_bootstrap_distance_output_type(rng):
170177

171178

172179
@mark.parametrize("distance", ["edistance"])
173-
def test_bootstrap_distance_pairwise(adata, distance):
180+
def test_bootstrap_distance_pairwise(adata: AnnData, distance: Metric) -> None:
174181
# Test consistency of pairwise distance results
175182
Distance = pt.tl.Distance(distance, obsm_key="X_pca")
176183
bootstrap_output = Distance.pairwise(adata, groupby="perturbation", bootstrap=True, n_bootstrap=3)
@@ -186,9 +193,9 @@ def test_bootstrap_distance_pairwise(adata, distance):
186193

187194

188195
@mark.parametrize("distance", ["edistance"])
189-
def test_bootstrap_distance_onesided(adata, distance):
196+
def test_bootstrap_distance_onesided(adata: AnnData, distance: Metric) -> None:
190197
# Test consistency of one-sided distance results
191-
selected_group = adata.obs.perturbation.unique()[0]
198+
selected_group = adata.obs["perturbation"].unique()[0]
192199
Distance = pt.tl.Distance(distance, obsm_key="X_pca")
193200
bootstrap_output = Distance.onesided_distances(
194201
adata,
@@ -201,7 +208,7 @@ def test_bootstrap_distance_onesided(adata, distance):
201208
assert isinstance(bootstrap_output, tuple)
202209

203210

204-
def test_compare_distance(rng):
211+
def test_compare_distance(rng: np.random.Generator) -> None:
205212
X = rng.normal(size=(50, 10))
206213
Y = rng.normal(size=(50, 10))
207214
C = rng.normal(size=(50, 10))
@@ -211,4 +218,4 @@ def test_compare_distance(rng):
211218
res_scaled = Distance.compare_distance(X, Y, C, mode="scaled")
212219
assert isinstance(res_scaled, float)
213220
with pytest.raises(ValueError):
214-
Distance.compare_distance(X, Y, C, mode="new_mode")
221+
Distance.compare_distance(X, Y, C, mode="new_mode") # type: ignore[arg-type]

0 commit comments

Comments
 (0)