Skip to content

Commit 3472949

Browse files
committed
1. update pool_decorator 2. convert lambda func to def func to avoid conflicts with pickle
1 parent a4d5c6d commit 3472949

File tree

1 file changed

+1
-113
lines changed

1 file changed

+1
-113
lines changed

tests/test_expected.py

+1-113
Original file line numberDiff line numberDiff line change
@@ -732,116 +732,4 @@ def test_diagsum_from_array():
732732
exp = _diagsum_symm_dense(ar, bad_bins=list(range(3, 5)))
733733
exp1 = diagsum_from_array(ar, ignore_diags=0)
734734
exp1["balanced.avg"] = exp1["balanced.sum"] / exp1["n_valid"]
735-
assert np.allclose(exp, exp1["balanced.avg"].values, equal_nan=True)
736-
737-
from concurrent.futures import ProcessPoolExecutor
738-
def test_multiprocessing_expected_cis(request):
739-
# perform test:
740-
clr = cooler.Cooler(op.join(request.fspath.dirname, "data/CN.mm9.1000kb.cool"))
741-
# symm result - engaging diagsum_symm
742-
with ProcessPoolExecutor(8) as p:
743-
res_symm = cooltools.api.expected.expected_cis(
744-
clr,
745-
view_df=view_df,
746-
clr_weight_name=clr_weight_name,
747-
chunksize=chunksize,
748-
ignore_diags=ignore_diags,
749-
map_functor=p.map
750-
)
751-
752-
# check column names
753-
assert list(res_symm.columns) == [
754-
"region1",
755-
"region2",
756-
"dist",
757-
"dist_bp",
758-
"contact_frequency",
759-
"n_total",
760-
"n_valid",
761-
"count.sum",
762-
"balanced.sum",
763-
"count.avg",
764-
"balanced.avg",
765-
"balanced.avg.smoothed",
766-
"balanced.avg.smoothed.agg",
767-
]
768-
769-
# check results for every block
770-
grouped = res_symm.groupby(["region1", "region2"])
771-
for (name1, name2), group in grouped:
772-
assert name1 == name2
773-
matrix = clr.matrix(balance=clr_weight_name).fetch(name1)
774-
desired_expected = _diagsum_symm_dense(matrix)
775-
# fill nan for ignored diags
776-
desired_expected = np.where(
777-
group["dist"] < ignore_diags, np.nan, desired_expected
778-
)
779-
testing.assert_allclose(
780-
actual=group["balanced.avg"].values,
781-
desired=desired_expected,
782-
equal_nan=True,
783-
)
784-
785-
# check column names, when clr_weight_name = None, which is the unbalanced case
786-
with ProcessPoolExecutor(8) as p:
787-
res_symm = cooltools.api.expected.expected_cis(
788-
clr,
789-
view_df=view_df,
790-
clr_weight_name=None,
791-
chunksize=chunksize,
792-
ignore_diags=ignore_diags,
793-
map_functor=p.map
794-
)
795-
assert list(res_symm.columns) == [
796-
"region1",
797-
"region2",
798-
"dist",
799-
"dist_bp",
800-
"contact_frequency",
801-
"n_total",
802-
"n_valid",
803-
"count.sum",
804-
"count.avg",
805-
"count.avg.smoothed",
806-
"count.avg.smoothed.agg",
807-
]
808-
809-
# asymm and symm result together - engaging diagsum_pairwise
810-
res_all = cooltools.api.expected.expected_cis(
811-
clr,
812-
view_df=view_df,
813-
intra_only=False,
814-
clr_weight_name=clr_weight_name,
815-
chunksize=chunksize,
816-
ignore_diags=ignore_diags,
817-
)
818-
# check results for every block
819-
grouped = res_all.groupby(["region1", "region2"])
820-
for (name1, name2), group in grouped:
821-
matrix = clr.matrix(balance=clr_weight_name).fetch(name1, name2)
822-
desired_expected = (
823-
_diagsum_asymm_dense(matrix)
824-
if (name1 != name2)
825-
else _diagsum_symm_dense(matrix)
826-
)
827-
# fill nan for ignored diags
828-
desired_expected = np.where(
829-
group["dist"] < ignore_diags, np.nan, desired_expected
830-
)
831-
testing.assert_allclose(
832-
actual=group["balanced.avg"].values,
833-
desired=desired_expected,
834-
equal_nan=True,
835-
)
836-
837-
# check multiprocessed result
838-
res_all_pooled = cooltools.api.expected.expected_cis(
839-
clr,
840-
view_df=view_df,
841-
intra_only=False,
842-
clr_weight_name=clr_weight_name,
843-
chunksize=chunksize,
844-
ignore_diags=ignore_diags,
845-
nproc=3,
846-
)
847-
assert res_all.equals(res_all_pooled)
735+
assert np.allclose(exp, exp1["balanced.avg"].values, equal_nan=True)

0 commit comments

Comments
 (0)