6
6
(chunked in one or two dimensions, respectively), with support for efficient storage and retrieval
7
7
using the Zarr library.
8
8
"""
9
+ from __future__ import annotations
9
10
10
11
import logging
11
12
from abc import ABC , abstractmethod
12
- from typing import Callable , Generator , Generic , List , Optional , Tuple , Union
13
+ from typing import (
14
+ Callable ,
15
+ Generator ,
16
+ Generic ,
17
+ Iterator ,
18
+ List ,
19
+ Optional ,
20
+ Tuple ,
21
+ Union ,
22
+ cast ,
23
+ )
13
24
14
25
import zarr
15
26
from numpy .typing import NDArray
27
+ from tqdm import tqdm
16
28
from zarr .storage import StoreLike
17
29
18
30
from ..utils import log_duration
@@ -35,9 +47,12 @@ def from_numpy(self, x: NDArray) -> TensorType:
35
47
36
48
class SequenceAggregator (Generic [TensorType ], ABC ):
37
49
@abstractmethod
38
- def __call__ (self , tensor_generator : Generator [TensorType , None , None ]):
50
+ def __call__ (
51
+ self ,
52
+ tensor_sequence : LazyChunkSequence ,
53
+ ):
39
54
"""
40
- Aggregates tensors from a generator .
55
+ Aggregates tensors from a sequence .
41
56
42
57
Implement this method to define how a sequence of tensors, provided by a
43
58
generator, should be combined.
@@ -46,31 +61,37 @@ def __call__(self, tensor_generator: Generator[TensorType, None, None]):
46
61
47
62
class ListAggregator (SequenceAggregator ):
48
63
def __call__ (
49
- self , tensor_generator : Generator [TensorType , None , None ]
64
+ self ,
65
+ tensor_sequence : LazyChunkSequence ,
50
66
) -> List [TensorType ]:
51
67
"""
52
68
Aggregates tensors from a single-level generator into a list. This method simply
53
69
collects each tensor emitted by the generator into a single list.
54
70
55
71
Args:
56
- tensor_generator: A generator that yields TensorType objects.
72
+ tensor_sequence: Object wrapping a generator that yields `TensorType`
73
+ objects.
57
74
58
75
Returns:
59
76
A list containing all the tensors provided by the tensor_sequence.
60
77
"""
61
- return [t for t in tensor_generator ]
78
+
79
+ gen = cast (Iterator [TensorType ], tensor_sequence .generator_factory ())
80
+
81
+ if tensor_sequence .len_generator is not None :
82
+ gen = cast (
83
+ Iterator [TensorType ],
84
+ tqdm (gen , total = tensor_sequence .len_generator , desc = "Blocks" ),
85
+ )
86
+
87
+ return [t for t in gen ]
62
88
63
89
64
90
class NestedSequenceAggregator (Generic [TensorType ], ABC ):
65
91
@abstractmethod
66
- def __call__ (
67
- self ,
68
- nested_generators_of_tensors : Generator [
69
- Generator [TensorType , None , None ], None , None
70
- ],
71
- ):
92
+ def __call__ (self , nested_sequence_of_tensors : NestedLazyChunkSequence ):
72
93
"""
73
- Aggregates tensors from a generator of generators .
94
+ Aggregates tensors from a nested sequence of tensors .
74
95
75
96
Implement this method to specify how tensors, nested in two layers of
76
97
generators, should be combined. Useful for complex data structures where tensors
@@ -81,27 +102,36 @@ def __call__(
81
102
class NestedListAggregator (NestedSequenceAggregator ):
82
103
def __call__ (
83
104
self ,
84
- nested_generators_of_tensors : Generator [
85
- Generator [TensorType , None , None ], None , None
86
- ],
105
+ nested_sequence_of_tensors : NestedLazyChunkSequence ,
87
106
) -> List [List [TensorType ]]:
88
107
"""
89
108
Aggregates tensors from a nested generator structure into a list of lists.
90
109
Each inner generator is converted into a list of tensors, resulting in a nested
91
110
list structure.
92
111
93
112
Args:
94
- nested_generators_of_tensors: A generator of generators, where each inner
95
- generator yields TensorType objects.
113
+ nested_sequence_of_tensors: Object wrapping a generator of generators,
114
+ where each inner generator yields TensorType objects.
96
115
97
116
Returns:
98
117
A list of lists, where each inner list contains tensors returned from one
99
118
of the inner generators.
100
119
"""
101
- return [list (tensor_gen ) for tensor_gen in nested_generators_of_tensors ]
120
+ outer_gen = cast (
121
+ Iterator [Iterator [TensorType ]],
122
+ nested_sequence_of_tensors .generator_factory (),
123
+ )
124
+ len_outer_gen = nested_sequence_of_tensors .len_outer_generator
125
+ if len_outer_gen is not None :
126
+ outer_gen = cast (
127
+ Iterator [Iterator [TensorType ]],
128
+ tqdm (outer_gen , total = len_outer_gen , desc = "Row blocks" ),
129
+ )
102
130
131
+ return [list (tensor_gen ) for tensor_gen in outer_gen ]
103
132
104
- class LazyChunkSequence :
133
+
134
+ class LazyChunkSequence (Generic [TensorType ]):
105
135
"""
106
136
A class representing a chunked and lazily evaluated array,
107
137
where the chunking is restricted to the first dimension
@@ -114,12 +144,18 @@ class LazyChunkSequence:
114
144
Attributes:
115
145
generator_factory: A factory function that returns
116
146
a generator. This generator yields chunks of the large array when called.
147
+ len_generator: if the number of elements from the generator is
148
+ known from the context, this optional parameter can be used to improve
149
+ logging by adding a progressbar.
117
150
"""
118
151
119
152
def __init__ (
120
- self , generator_factory : Callable [[], Generator [TensorType , None , None ]]
153
+ self ,
154
+ generator_factory : Callable [[], Generator [TensorType , None , None ]],
155
+ len_generator : Optional [int ] = None ,
121
156
):
122
157
self .generator_factory = generator_factory
158
+ self .len_generator = len_generator
123
159
124
160
@log_duration (log_level = logging .INFO )
125
161
def compute (self , aggregator : Optional [SequenceAggregator ] = None ):
@@ -140,7 +176,7 @@ def compute(self, aggregator: Optional[SequenceAggregator] = None):
140
176
"""
141
177
if aggregator is None :
142
178
aggregator = ListAggregator ()
143
- return aggregator (self . generator_factory () )
179
+ return aggregator (self )
144
180
145
181
@log_duration (log_level = logging .INFO )
146
182
def to_zarr (
@@ -171,7 +207,15 @@ def to_zarr(
171
207
"""
172
208
row_idx = 0
173
209
z = None
174
- for block in self .generator_factory ():
210
+
211
+ gen = cast (Iterator [TensorType ], self .generator_factory ())
212
+
213
+ if self .len_generator is not None :
214
+ gen = cast (
215
+ Iterator [TensorType ], tqdm (gen , total = self .len_generator , desc = "Blocks" )
216
+ )
217
+
218
+ for block in gen :
175
219
numpy_block = converter .to_numpy (block )
176
220
177
221
if z is None :
@@ -204,7 +248,7 @@ def _initialize_zarr_array(block: NDArray, path_or_url: str, overwrite: bool):
204
248
)
205
249
206
250
207
- class NestedLazyChunkSequence :
251
+ class NestedLazyChunkSequence ( Generic [ TensorType ]) :
208
252
"""
209
253
A class representing a chunked and lazily evaluated array, where the chunking is
210
254
restricted to the first two dimensions.
@@ -216,16 +260,21 @@ class NestedLazyChunkSequence:
216
260
217
261
Attributes:
218
262
generator_factory: A factory function that returns a generator of generators.
219
- Each inner generator yields chunks.
263
+ Each inner generator yields chunks.
264
+ len_outer_generator: if the number of elements from the outer generator is
265
+ known from the context, this optional parameter can be used to improve
266
+ logging by adding a progressbar.
220
267
"""
221
268
222
269
def __init__ (
223
270
self ,
224
271
generator_factory : Callable [
225
272
[], Generator [Generator [TensorType , None , None ], None , None ]
226
273
],
274
+ len_outer_generator : Optional [int ] = None ,
227
275
):
228
276
self .generator_factory = generator_factory
277
+ self .len_outer_generator = len_outer_generator
229
278
230
279
@log_duration (log_level = logging .INFO )
231
280
def compute (self , aggregator : Optional [NestedSequenceAggregator ] = None ):
@@ -247,7 +296,7 @@ def compute(self, aggregator: Optional[NestedSequenceAggregator] = None):
247
296
"""
248
297
if aggregator is None :
249
298
aggregator = NestedListAggregator ()
250
- return aggregator (self . generator_factory () )
299
+ return aggregator (self )
251
300
252
301
@log_duration (log_level = logging .INFO )
253
302
def to_zarr (
@@ -280,7 +329,17 @@ def to_zarr(
280
329
row_idx = 0
281
330
z = None
282
331
numpy_block = None
283
- for row_blocks in self .generator_factory ():
332
+ block_generator = cast (Iterator [Iterator [TensorType ]], self .generator_factory ())
333
+
334
+ if self .len_outer_generator is not None :
335
+ block_generator = cast (
336
+ Iterator [Iterator [TensorType ]],
337
+ tqdm (
338
+ block_generator , total = self .len_outer_generator , desc = "Row blocks"
339
+ ),
340
+ )
341
+
342
+ for row_blocks in block_generator :
284
343
col_idx = 0
285
344
for block in row_blocks :
286
345
numpy_block = converter .to_numpy (block )
0 commit comments