Skip to content

Commit a368332

Browse files
committed
Simplify DelayedArrays by collapsing repeated operations.
This reduces the number of operations in each DelayedArray object without changing the result, which makes the DA easier to interpret and parse.
1 parent 3999fd0 commit a368332

File tree

7 files changed

+185
-21
lines changed

7 files changed

+185
-21
lines changed

src/delayedarray/Combine.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import Callable, Tuple, Sequence
1+
from typing import Callable, Tuple, Sequence, Any
22
import numpy
3+
import copy
34

45
from .DelayedOp import DelayedOp
56
from ._mask import _concatenate_unmasked_ndarrays, _concatenate_maybe_masked_ndarrays
@@ -101,6 +102,24 @@ def along(self) -> int:
101102
return self._along
102103

103104

105+
def _simplify_combine(x: Combine) -> Any:
106+
all_seeds = []
107+
for ss in x.seeds:
108+
if type(ss) is Combine and x.along == ss.along:
109+
# Don't use isinstance, we don't want to collapse for Combine
110+
# subclasses that might be doing god knows what.
111+
all_seeds += ss.seeds
112+
else:
113+
all_seeds.append(ss)
114+
if len(all_seeds) == 1:
115+
return all_seeds[0]
116+
if len(all_seeds) == len(x.seeds):
117+
return x
118+
new_x = copy.copy(x)
119+
new_x._seeds = all_seeds
120+
return new_x
121+
122+
104123
def _extract_subarrays(x: Combine, subset: Tuple[Sequence[int], ...], f: Callable):
105124
# Figuring out which slices belong to who.
106125
chosen = subset[x._along]

src/delayedarray/DelayedArray.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
from .SparseNdarray import SparseNdarray
77
from .BinaryIsometricOp import BinaryIsometricOp
88
from .Cast import Cast
9-
from .Combine import Combine
9+
from .Combine import Combine, _simplify_combine
1010
from .Round import Round
11-
from .Subset import Subset
12-
from .Transpose import Transpose
11+
from .Subset import Subset, _simplify_subset
12+
from .Transpose import Transpose, _simplify_transpose
1313
from .UnaryIsometricOpSimple import UnaryIsometricOpSimple
1414
from .UnaryIsometricOpWithArgs import UnaryIsometricOpWithArgs
1515

@@ -136,7 +136,9 @@ def T(self) -> "DelayedArray":
136136
Returns:
137137
A ``DelayedArray`` containing the delayed transpose.
138138
"""
139-
return DelayedArray(Transpose(self._seed, perm=None))
139+
tout = Transpose(self._seed, perm=None)
140+
tout = _simplify_transpose(tout)
141+
return DelayedArray(tout)
140142

141143
def __repr__(self) -> str:
142144
"""Pretty-print this ``DelayedArray``. This uses
@@ -253,20 +255,23 @@ def __array_function__(self, func, types, args, kwargs) -> "DelayedArray":
253255
seeds = []
254256
for x in args[0]:
255257
seeds.append(_extract_seed(x))
256-
257258
if "axis" in kwargs:
258259
axis = kwargs["axis"]
259260
else:
260261
axis = 0
261-
return DelayedArray(Combine(seeds, along=axis))
262+
cout = Combine(seeds, along=axis)
263+
cout = _simplify_combine(cout)
264+
return DelayedArray(cout)
262265

263266
if func == numpy.transpose:
264267
seed = _extract_seed(args[0])
265268
if "axes" in kwargs:
266269
axes = kwargs["axes"]
267270
else:
268271
axes = None
269-
return DelayedArray(Transpose(seed, perm=axes))
272+
tout = Transpose(seed, perm=axes)
273+
tout = _simplify_transpose(tout)
274+
return DelayedArray(tout)
270275

271276
if func == numpy.round:
272277
seed = _extract_seed(args[0])
@@ -808,7 +813,9 @@ def __getitem__(self, subset: Tuple[Union[slice, Sequence], ...]) -> Union["Dela
808813
"""
809814
cleaned = _getitem_subset_preserves_dimensions(self.shape, subset)
810815
if cleaned is not None:
811-
return DelayedArray(Subset(self._seed, cleaned))
816+
sout = Subset(self._seed, cleaned)
817+
sout = _simplify_subset(sout)
818+
return DelayedArray(sout)
812819
return _getitem_subset_discards_dimensions(self._seed, subset, extract_dense_array)
813820

814821

src/delayedarray/Subset.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
from typing import Callable, Sequence, Tuple
1+
from typing import Callable, Sequence, Tuple, Any
22
from numpy import dtype, ndarray, ix_
33
import numpy
4+
import biocutils
5+
import copy
46

57
from .DelayedOp import DelayedOp
68
from .SparseNdarray import SparseNdarray
7-
from ._subset import _sanitize_subset
9+
from ._subset import _sanitize_subset, _is_single_subset_noop
810
from .extract_dense_array import extract_dense_array
911
from .extract_sparse_array import extract_sparse_array
1012
from .create_dask_array import create_dask_array
@@ -87,18 +89,36 @@ def subset(self) -> Tuple[Sequence[int], ...]:
8789
return self._subset
8890

8991

92+
def _simplify_subset(x: Subset) -> Any:
93+
seed = x.seed
94+
if not type(seed) is Subset:
95+
# Don't use isinstance, we don't want to collapse for Subset
96+
# subclasses that might be doing god knows what.
97+
return x
98+
all_subsets = []
99+
noop = True
100+
for i, sub in enumerate(x.subset):
101+
seed_sub = seed.subset[i]
102+
new_sub = biocutils.subset_sequence(seed_sub, sub)
103+
if noop and not _is_single_subset_noop(seed.seed.shape[i], new_sub):
104+
noop = False
105+
all_subsets.append(new_sub)
106+
if noop:
107+
return seed.seed
108+
new_x = copy.copy(x)
109+
new_x._seed = seed.seed
110+
new_x._subset = (*all_subsets,)
111+
return new_x
112+
113+
90114
def _extract_array(x: Subset, subset: Tuple[Sequence[int], ...], f: Callable):
91115
newsub = list(subset)
92116
expanded = []
93117
is_safe = 0
94118

95119
for i, s in enumerate(newsub):
96120
cursub = x._subset[i]
97-
if isinstance(cursub, ndarray):
98-
replacement = cursub[s]
99-
else:
100-
replacement = [cursub[j] for j in s]
101-
121+
replacement = biocutils.subset_sequence(cursub, s)
102122
san_sub, san_remap = _sanitize_subset(replacement)
103123
newsub[i] = san_sub
104124

src/delayedarray/Transpose.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from typing import Callable, Optional, Tuple, Sequence
1+
from typing import Callable, Optional, Tuple, Sequence, Any
22
from numpy import dtype, transpose
33
import numpy
4+
import copy
45

56
from .DelayedOp import DelayedOp
67
from .SparseNdarray import SparseNdarray
@@ -40,8 +41,6 @@ def __init__(self, seed, perm: Optional[Tuple[int, ...]]):
4041
dimension ordering is assumed to be reversed.
4142
"""
4243

43-
self._seed = seed
44-
4544
curshape = seed.shape
4645
ndim = len(curshape)
4746
if perm is not None:
@@ -52,12 +51,12 @@ def __init__(self, seed, perm: Optional[Tuple[int, ...]]):
5251
else:
5352
perm = (*range(ndim - 1, -1, -1),)
5453

55-
self._perm = perm
56-
5754
final_shape = []
5855
for x in perm:
5956
final_shape.append(curshape[x])
6057

58+
self._seed = seed
59+
self._perm = perm
6160
self._shape = (*final_shape,)
6261

6362
@property
@@ -94,6 +93,29 @@ def perm(self) -> Tuple[int, ...]:
9493
return self._perm
9594

9695

96+
def _simplify_transpose(x: Transpose) -> Any:
97+
seed = x.seed
98+
if not type(seed) is Transpose:
99+
# Don't use isinstance, we don't want to collapse for Transpose
100+
# subclasses that might be doing god knows what.
101+
return x
102+
103+
new_perm = []
104+
noop = True
105+
for i, p in enumerate(x.perm):
106+
new_p = seed.perm[p]
107+
if new_p != i:
108+
noop = False
109+
new_perm.append(new_p)
110+
if noop:
111+
return seed.seed
112+
113+
new_x = copy.copy(x)
114+
new_x._seed = seed.seed
115+
new_x._perm = (*new_perm,)
116+
return new_x
117+
118+
97119
def _extract_array(x: Transpose, subset: Tuple[Sequence[int], ...], f: Callable):
98120
permsub = [None] * len(subset)
99121
for i, j in enumerate(x._perm):

tests/test_Combine.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,48 @@ def test_Combine_otherdim(left_mask_rate, right_mask_rate):
4343
assert_identical_ndarrays(delayedarray.to_dense_array(x), safe_concatenate((y1, y2), axis=1))
4444

4545

46+
def test_Combine_simplified():
47+
y1 = simulate_ndarray((30, 23), mask_rate=0)
48+
y2 = simulate_ndarray((50, 23), mask_rate=0)
49+
y3 = simulate_ndarray((30, 41), mask_rate=0)
50+
51+
x1 = delayedarray.DelayedArray(y1)
52+
x2 = delayedarray.DelayedArray(y2)
53+
x3 = delayedarray.DelayedArray(y3)
54+
55+
com = numpy.concatenate((x1, x2))
56+
com2 = numpy.concatenate((com, x2))
57+
assert isinstance(com2, delayedarray.DelayedArray)
58+
assert isinstance(com2.seed, delayedarray.Combine)
59+
assert len(com2.seed.seeds) == 3
60+
assert [isinstance(s, delayedarray.Combine) for s in com2.seed.seeds] == [False] * 3
61+
assert_identical_ndarrays(delayedarray.to_dense_array(com2), safe_concatenate((y1, y2, y2)))
62+
63+
com = numpy.concatenate((x1, x3), axis=1)
64+
com2 = numpy.concatenate((com, x1), axis=1)
65+
assert isinstance(com2, delayedarray.DelayedArray)
66+
assert isinstance(com2.seed, delayedarray.Combine)
67+
assert len(com2.seed.seeds) == 3
68+
assert [isinstance(s, delayedarray.Combine) for s in com2.seed.seeds] == [False] * 3
69+
assert_identical_ndarrays(delayedarray.to_dense_array(com2), safe_concatenate((y1, y3, y1), axis=1))
70+
71+
# No-ops properly.
72+
com = numpy.concatenate((x1,))
73+
assert isinstance(com, delayedarray.DelayedArray)
74+
assert isinstance(com.seed, numpy.ndarray)
75+
assert_identical_ndarrays(delayedarray.to_dense_array(com), y1)
76+
77+
# Doesn't attempt to collapse if the axes are different.
78+
com = numpy.concatenate((x1, x2))
79+
com2 = numpy.concatenate((com, com), axis=1)
80+
assert isinstance(com2, delayedarray.DelayedArray)
81+
assert isinstance(com2.seed, delayedarray.Combine)
82+
assert len(com2.seed.seeds) == 2
83+
assert [isinstance(s, delayedarray.Combine) for s in com2.seed.seeds] == [True] * 2
84+
ref = numpy.concatenate((y1, y2))
85+
assert_identical_ndarrays(delayedarray.to_dense_array(com2), safe_concatenate((ref, ref), axis=1))
86+
87+
4688
@pytest.mark.parametrize("left_mask_rate", [0, 0.2])
4789
@pytest.mark.parametrize("right_mask_rate", [0, 0.2])
4890
def test_Combine_subset(left_mask_rate, right_mask_rate):

tests/test_Subset.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import delayedarray
22
import numpy
33
import pytest
4+
import biocutils
45

56
from utils import simulate_ndarray, assert_identical_ndarrays, simulate_SparseNdarray
67

@@ -13,6 +14,9 @@ def test_Subset_ix(mask_rate):
1314

1415
subix = numpy.ix_(range(1, 10), [20, 30, 40], [10, 11, 12, 13])
1516
sub = x[subix]
17+
assert isinstance(sub, delayedarray.DelayedArray)
18+
assert isinstance(sub.seed, delayedarray.Subset)
19+
1620
assert sub.shape == (9, 3, 4)
1721
assert isinstance(sub.seed.seed, numpy.ndarray)
1822
assert len(sub.seed.subset) == 3
@@ -88,6 +92,33 @@ def test_Subset_unsorted_duplicates(mask_rate):
8892
assert_identical_ndarrays(delayedarray.to_dense_array(sub), y[:, [5, 4, 3, 2, 1, 0], :])
8993

9094

95+
def test_Subset_simplified():
96+
test_shape = (30, 55)
97+
y = simulate_ndarray(test_shape, mask_rate=0)
98+
x = delayedarray.DelayedArray(y)
99+
100+
sub = x[:, list(range(0, 55, 2))]
101+
sub2 = sub[:, list(range(5, 20))]
102+
assert isinstance(sub2, delayedarray.DelayedArray)
103+
assert isinstance(sub2.seed, delayedarray.Subset)
104+
assert isinstance(sub2.seed.seed, numpy.ndarray)
105+
assert_identical_ndarrays(delayedarray.to_dense_array(sub2), y[:, biocutils.subset_sequence(range(0, 55, 2), range(5, 20))])
106+
107+
sub = x[list(range(10, 20)), :]
108+
sub2 = sub[:, list(range(0, 55, 5))]
109+
assert isinstance(sub2, delayedarray.DelayedArray)
110+
assert isinstance(sub2.seed, delayedarray.Subset)
111+
assert isinstance(sub2.seed.seed, numpy.ndarray)
112+
assert_identical_ndarrays(delayedarray.to_dense_array(sub2), y[10:20,0:55:5])
113+
114+
# Identifies no-ops and returns the seed directly.
115+
sub = x[::-1,::-1]
116+
sub2 = sub[::-1,::-1]
117+
assert isinstance(sub2, delayedarray.DelayedArray)
118+
assert isinstance(sub2.seed, numpy.ndarray)
119+
assert_identical_ndarrays(delayedarray.to_dense_array(sub2), y)
120+
121+
91122
@pytest.mark.parametrize("mask_rate", [0, 0.2])
92123
def test_Subset_subset(mask_rate):
93124
y = simulate_ndarray((99, 63), mask_rate=mask_rate)

tests/test_Transpose.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,29 @@ def test_Transpose_simple(mask_rate):
2525
assert_identical_ndarrays(delayedarray.to_dense_array(t), numpy.transpose(y))
2626

2727

28+
def test_Transpose_simplified():
29+
y = simulate_ndarray((30, 23, 5), mask_rate=0)
30+
x = delayedarray.DelayedArray(y)
31+
32+
t = x.T
33+
t2 = t.T
34+
assert isinstance(t2, delayedarray.DelayedArray)
35+
assert isinstance(t2.seed, numpy.ndarray)
36+
assert_identical_ndarrays(delayedarray.to_dense_array(t2), y.T.T)
37+
38+
t2 = numpy.transpose(t, axes=(2, 1, 0))
39+
assert isinstance(t2, delayedarray.DelayedArray)
40+
assert isinstance(t2.seed, numpy.ndarray)
41+
assert_identical_ndarrays(delayedarray.to_dense_array(t2), numpy.transpose(y.T, (2, 1, 0)))
42+
43+
t2 = numpy.transpose(t, axes=(1, 2, 0))
44+
assert isinstance(t2, delayedarray.DelayedArray)
45+
assert isinstance(t2.seed, delayedarray.Transpose)
46+
assert t2.seed.perm == (1, 0, 2)
47+
assert isinstance(t2.seed.seed, numpy.ndarray)
48+
assert_identical_ndarrays(delayedarray.to_dense_array(t2), numpy.transpose(y.T, axes=(1, 2, 0)))
49+
50+
2851
@pytest.mark.parametrize("mask_rate", [0, 0.2])
2952
def test_Transpose_more_dimensions(mask_rate):
3053
y = simulate_ndarray((30, 23, 10), mask_rate=mask_rate)

0 commit comments

Comments
 (0)