Skip to content

Commit a9f50fe

Browse files
authored
Collapse repeated operations to remove unnecessary delayed layers. (#66)
This reduces the number of operations in each DelayedArray object without changing the result, which makes the DA easier to interpret and parse.
1 parent d3efd7f commit a9f50fe

File tree

8 files changed

+188
-22
lines changed

8 files changed

+188
-22
lines changed

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ python_requires = >=3.9
5050
install_requires =
5151
importlib-metadata; python_version<"3.8"
5252
numpy
53-
biocutils
53+
biocutils>=0.1.8
5454

5555

5656
[options.packages.find]

src/delayedarray/Combine.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import Callable, Tuple, Sequence
1+
from typing import Callable, Tuple, Sequence, Any
22
import numpy
3+
import copy
34

45
from .DelayedOp import DelayedOp
56
from ._mask import _concatenate_unmasked_ndarrays, _concatenate_maybe_masked_ndarrays
@@ -101,6 +102,26 @@ def along(self) -> int:
101102
return self._along
102103

103104

105+
def _simplify_combine(x: Combine) -> Any:
106+
if len(x.seeds) == 1:
107+
return x.seeds[0]
108+
all_seeds = []
109+
simplified = False
110+
for ss in x.seeds:
111+
if type(ss) is Combine and x.along == ss.along:
112+
# Don't use isinstance, we don't want to collapse for Combine
113+
# subclasses that might be doing god knows what.
114+
all_seeds += ss.seeds
115+
simplified = True
116+
else:
117+
all_seeds.append(ss)
118+
if not simplified:
119+
return x
120+
new_x = copy.copy(x)
121+
new_x._seeds = all_seeds
122+
return new_x
123+
124+
104125
def _extract_subarrays(x: Combine, subset: Tuple[Sequence[int], ...], f: Callable):
105126
# Figuring out which slices belong to who.
106127
chosen = subset[x._along]

src/delayedarray/DelayedArray.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
from .SparseNdarray import SparseNdarray
77
from .BinaryIsometricOp import BinaryIsometricOp
88
from .Cast import Cast
9-
from .Combine import Combine
9+
from .Combine import Combine, _simplify_combine
1010
from .Round import Round
11-
from .Subset import Subset
12-
from .Transpose import Transpose
11+
from .Subset import Subset, _simplify_subset
12+
from .Transpose import Transpose, _simplify_transpose
1313
from .UnaryIsometricOpSimple import UnaryIsometricOpSimple
1414
from .UnaryIsometricOpWithArgs import UnaryIsometricOpWithArgs
1515

@@ -136,7 +136,9 @@ def T(self) -> "DelayedArray":
136136
Returns:
137137
A ``DelayedArray`` containing the delayed transpose.
138138
"""
139-
return DelayedArray(Transpose(self._seed, perm=None))
139+
tout = Transpose(self._seed, perm=None)
140+
tout = _simplify_transpose(tout)
141+
return DelayedArray(tout)
140142

141143
def __repr__(self) -> str:
142144
"""Pretty-print this ``DelayedArray``. This uses
@@ -253,20 +255,23 @@ def __array_function__(self, func, types, args, kwargs) -> "DelayedArray":
253255
seeds = []
254256
for x in args[0]:
255257
seeds.append(_extract_seed(x))
256-
257258
if "axis" in kwargs:
258259
axis = kwargs["axis"]
259260
else:
260261
axis = 0
261-
return DelayedArray(Combine(seeds, along=axis))
262+
cout = Combine(seeds, along=axis)
263+
cout = _simplify_combine(cout)
264+
return DelayedArray(cout)
262265

263266
if func == numpy.transpose:
264267
seed = _extract_seed(args[0])
265268
if "axes" in kwargs:
266269
axes = kwargs["axes"]
267270
else:
268271
axes = None
269-
return DelayedArray(Transpose(seed, perm=axes))
272+
tout = Transpose(seed, perm=axes)
273+
tout = _simplify_transpose(tout)
274+
return DelayedArray(tout)
270275

271276
if func == numpy.round:
272277
seed = _extract_seed(args[0])
@@ -808,7 +813,9 @@ def __getitem__(self, subset: Tuple[Union[slice, Sequence], ...]) -> Union["Dela
808813
"""
809814
cleaned = _getitem_subset_preserves_dimensions(self.shape, subset)
810815
if cleaned is not None:
811-
return DelayedArray(Subset(self._seed, cleaned))
816+
sout = Subset(self._seed, cleaned)
817+
sout = _simplify_subset(sout)
818+
return DelayedArray(sout)
812819
return _getitem_subset_discards_dimensions(self._seed, subset, extract_dense_array)
813820

814821

src/delayedarray/Subset.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
from typing import Callable, Sequence, Tuple
1+
from typing import Callable, Sequence, Tuple, Any
22
from numpy import dtype, ndarray, ix_
33
import numpy
4+
import biocutils
5+
import copy
46

57
from .DelayedOp import DelayedOp
68
from .SparseNdarray import SparseNdarray
7-
from ._subset import _sanitize_subset
9+
from ._subset import _sanitize_subset, _is_single_subset_noop
810
from .extract_dense_array import extract_dense_array
911
from .extract_sparse_array import extract_sparse_array
1012
from .create_dask_array import create_dask_array
@@ -87,18 +89,36 @@ def subset(self) -> Tuple[Sequence[int], ...]:
8789
return self._subset
8890

8991

92+
def _simplify_subset(x: Subset) -> Any:
93+
seed = x.seed
94+
if not type(seed) is Subset:
95+
# Don't use isinstance, we don't want to collapse for Subset
96+
# subclasses that might be doing god knows what.
97+
return x
98+
all_subsets = []
99+
noop = True
100+
for i, sub in enumerate(x.subset):
101+
seed_sub = seed.subset[i]
102+
new_sub = biocutils.subset_sequence(seed_sub, sub)
103+
if noop and not _is_single_subset_noop(seed.seed.shape[i], new_sub):
104+
noop = False
105+
all_subsets.append(new_sub)
106+
if noop:
107+
return seed.seed
108+
new_x = copy.copy(x)
109+
new_x._seed = seed.seed
110+
new_x._subset = (*all_subsets,)
111+
return new_x
112+
113+
90114
def _extract_array(x: Subset, subset: Tuple[Sequence[int], ...], f: Callable):
91115
newsub = list(subset)
92116
expanded = []
93117
is_safe = 0
94118

95119
for i, s in enumerate(newsub):
96120
cursub = x._subset[i]
97-
if isinstance(cursub, ndarray):
98-
replacement = cursub[s]
99-
else:
100-
replacement = [cursub[j] for j in s]
101-
121+
replacement = biocutils.subset_sequence(cursub, s)
102122
san_sub, san_remap = _sanitize_subset(replacement)
103123
newsub[i] = san_sub
104124

src/delayedarray/Transpose.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from typing import Callable, Optional, Tuple, Sequence
1+
from typing import Callable, Optional, Tuple, Sequence, Any
22
from numpy import dtype, transpose
33
import numpy
4+
import copy
45

56
from .DelayedOp import DelayedOp
67
from .SparseNdarray import SparseNdarray
@@ -40,8 +41,6 @@ def __init__(self, seed, perm: Optional[Tuple[int, ...]]):
4041
dimension ordering is assumed to be reversed.
4142
"""
4243

43-
self._seed = seed
44-
4544
curshape = seed.shape
4645
ndim = len(curshape)
4746
if perm is not None:
@@ -52,12 +51,12 @@ def __init__(self, seed, perm: Optional[Tuple[int, ...]]):
5251
else:
5352
perm = (*range(ndim - 1, -1, -1),)
5453

55-
self._perm = perm
56-
5754
final_shape = []
5855
for x in perm:
5956
final_shape.append(curshape[x])
6057

58+
self._seed = seed
59+
self._perm = perm
6160
self._shape = (*final_shape,)
6261

6362
@property
@@ -94,6 +93,29 @@ def perm(self) -> Tuple[int, ...]:
9493
return self._perm
9594

9695

96+
def _simplify_transpose(x: Transpose) -> Any:
97+
seed = x.seed
98+
if not type(seed) is Transpose:
99+
# Don't use isinstance, we don't want to collapse for Transpose
100+
# subclasses that might be doing god knows what.
101+
return x
102+
103+
new_perm = []
104+
noop = True
105+
for i, p in enumerate(x.perm):
106+
new_p = seed.perm[p]
107+
if new_p != i:
108+
noop = False
109+
new_perm.append(new_p)
110+
if noop:
111+
return seed.seed
112+
113+
new_x = copy.copy(x)
114+
new_x._seed = seed.seed
115+
new_x._perm = (*new_perm,)
116+
return new_x
117+
118+
97119
def _extract_array(x: Transpose, subset: Tuple[Sequence[int], ...], f: Callable):
98120
permsub = [None] * len(subset)
99121
for i, j in enumerate(x._perm):

tests/test_Combine.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,48 @@ def test_Combine_otherdim(left_mask_rate, right_mask_rate):
4343
assert_identical_ndarrays(delayedarray.to_dense_array(x), safe_concatenate((y1, y2), axis=1))
4444

4545

46+
def test_Combine_simplified():
47+
y1 = simulate_ndarray((30, 23), mask_rate=0)
48+
y2 = simulate_ndarray((50, 23), mask_rate=0)
49+
y3 = simulate_ndarray((30, 41), mask_rate=0)
50+
51+
x1 = delayedarray.DelayedArray(y1)
52+
x2 = delayedarray.DelayedArray(y2)
53+
x3 = delayedarray.DelayedArray(y3)
54+
55+
com = numpy.concatenate((x1, x2))
56+
com2 = numpy.concatenate((com, x2))
57+
assert isinstance(com2, delayedarray.DelayedArray)
58+
assert isinstance(com2.seed, delayedarray.Combine)
59+
assert len(com2.seed.seeds) == 3
60+
assert [isinstance(s, delayedarray.Combine) for s in com2.seed.seeds] == [False] * 3
61+
assert_identical_ndarrays(delayedarray.to_dense_array(com2), safe_concatenate((y1, y2, y2)))
62+
63+
com = numpy.concatenate((x1, x3), axis=1)
64+
com2 = numpy.concatenate((com, x1), axis=1)
65+
assert isinstance(com2, delayedarray.DelayedArray)
66+
assert isinstance(com2.seed, delayedarray.Combine)
67+
assert len(com2.seed.seeds) == 3
68+
assert [isinstance(s, delayedarray.Combine) for s in com2.seed.seeds] == [False] * 3
69+
assert_identical_ndarrays(delayedarray.to_dense_array(com2), safe_concatenate((y1, y3, y1), axis=1))
70+
71+
# No-ops properly.
72+
com = numpy.concatenate((x1,))
73+
assert isinstance(com, delayedarray.DelayedArray)
74+
assert isinstance(com.seed, numpy.ndarray)
75+
assert_identical_ndarrays(delayedarray.to_dense_array(com), y1)
76+
77+
# Doesn't attempt to collapse if the axes are different.
78+
com = numpy.concatenate((x1, x2))
79+
com2 = numpy.concatenate((com, com), axis=1)
80+
assert isinstance(com2, delayedarray.DelayedArray)
81+
assert isinstance(com2.seed, delayedarray.Combine)
82+
assert len(com2.seed.seeds) == 2
83+
assert [isinstance(s, delayedarray.Combine) for s in com2.seed.seeds] == [True] * 2
84+
ref = numpy.concatenate((y1, y2))
85+
assert_identical_ndarrays(delayedarray.to_dense_array(com2), safe_concatenate((ref, ref), axis=1))
86+
87+
4688
@pytest.mark.parametrize("left_mask_rate", [0, 0.2])
4789
@pytest.mark.parametrize("right_mask_rate", [0, 0.2])
4890
def test_Combine_subset(left_mask_rate, right_mask_rate):

tests/test_Subset.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import delayedarray
22
import numpy
33
import pytest
4+
import biocutils
45

56
from utils import simulate_ndarray, assert_identical_ndarrays, simulate_SparseNdarray
67

@@ -13,6 +14,9 @@ def test_Subset_ix(mask_rate):
1314

1415
subix = numpy.ix_(range(1, 10), [20, 30, 40], [10, 11, 12, 13])
1516
sub = x[subix]
17+
assert isinstance(sub, delayedarray.DelayedArray)
18+
assert isinstance(sub.seed, delayedarray.Subset)
19+
1620
assert sub.shape == (9, 3, 4)
1721
assert isinstance(sub.seed.seed, numpy.ndarray)
1822
assert len(sub.seed.subset) == 3
@@ -88,6 +92,33 @@ def test_Subset_unsorted_duplicates(mask_rate):
8892
assert_identical_ndarrays(delayedarray.to_dense_array(sub), y[:, [5, 4, 3, 2, 1, 0], :])
8993

9094

95+
def test_Subset_simplified():
96+
test_shape = (30, 55)
97+
y = simulate_ndarray(test_shape, mask_rate=0)
98+
x = delayedarray.DelayedArray(y)
99+
100+
sub = x[:, list(range(0, 55, 2))]
101+
sub2 = sub[:, list(range(5, 20))]
102+
assert isinstance(sub2, delayedarray.DelayedArray)
103+
assert isinstance(sub2.seed, delayedarray.Subset)
104+
assert isinstance(sub2.seed.seed, numpy.ndarray)
105+
assert_identical_ndarrays(delayedarray.to_dense_array(sub2), y[:, biocutils.subset_sequence(range(0, 55, 2), range(5, 20))])
106+
107+
sub = x[list(range(10, 20)), :]
108+
sub2 = sub[:, list(range(0, 55, 5))]
109+
assert isinstance(sub2, delayedarray.DelayedArray)
110+
assert isinstance(sub2.seed, delayedarray.Subset)
111+
assert isinstance(sub2.seed.seed, numpy.ndarray)
112+
assert_identical_ndarrays(delayedarray.to_dense_array(sub2), y[10:20,0:55:5])
113+
114+
# Identifies no-ops and returns the seed directly.
115+
sub = x[::-1,::-1]
116+
sub2 = sub[::-1,::-1]
117+
assert isinstance(sub2, delayedarray.DelayedArray)
118+
assert isinstance(sub2.seed, numpy.ndarray)
119+
assert_identical_ndarrays(delayedarray.to_dense_array(sub2), y)
120+
121+
91122
@pytest.mark.parametrize("mask_rate", [0, 0.2])
92123
def test_Subset_subset(mask_rate):
93124
y = simulate_ndarray((99, 63), mask_rate=mask_rate)

tests/test_Transpose.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,29 @@ def test_Transpose_simple(mask_rate):
2525
assert_identical_ndarrays(delayedarray.to_dense_array(t), numpy.transpose(y))
2626

2727

28+
def test_Transpose_simplified():
29+
y = simulate_ndarray((30, 23, 5), mask_rate=0)
30+
x = delayedarray.DelayedArray(y)
31+
32+
t = x.T
33+
t2 = t.T
34+
assert isinstance(t2, delayedarray.DelayedArray)
35+
assert isinstance(t2.seed, numpy.ndarray)
36+
assert_identical_ndarrays(delayedarray.to_dense_array(t2), y.T.T)
37+
38+
t2 = numpy.transpose(t, axes=(2, 1, 0))
39+
assert isinstance(t2, delayedarray.DelayedArray)
40+
assert isinstance(t2.seed, numpy.ndarray)
41+
assert_identical_ndarrays(delayedarray.to_dense_array(t2), numpy.transpose(y.T, (2, 1, 0)))
42+
43+
t2 = numpy.transpose(t, axes=(1, 2, 0))
44+
assert isinstance(t2, delayedarray.DelayedArray)
45+
assert isinstance(t2.seed, delayedarray.Transpose)
46+
assert t2.seed.perm == (1, 0, 2)
47+
assert isinstance(t2.seed.seed, numpy.ndarray)
48+
assert_identical_ndarrays(delayedarray.to_dense_array(t2), numpy.transpose(y.T, axes=(1, 2, 0)))
49+
50+
2851
@pytest.mark.parametrize("mask_rate", [0, 0.2])
2952
def test_Transpose_more_dimensions(mask_rate):
3053
y = simulate_ndarray((30, 23, 10), mask_rate=mask_rate)

0 commit comments

Comments
 (0)