@@ -55,8 +55,8 @@ def __init__(
55
55
postprocessing : List [Processing ],
56
56
model_adapter : ModelAdapter ,
57
57
default_ns : Union [
58
- v0_5 .ParameterizedSize . N ,
59
- Mapping [Tuple [MemberId , AxisId ], v0_5 .ParameterizedSize . N ],
58
+ v0_5 .ParameterizedSize_N ,
59
+ Mapping [Tuple [MemberId , AxisId ], v0_5 .ParameterizedSize_N ],
60
60
] = 10 ,
61
61
default_batch_size : int = 1 ,
62
62
) -> None :
@@ -179,40 +179,17 @@ def get_output_sample_id(self, input_sample_id: SampleId):
179
179
self .model_description .id or self .model_description .name
180
180
)
181
181
182
- def predict_sample_with_blocking (
182
+ def predict_sample_with_fixed_blocking (
183
183
self ,
184
184
sample : Sample ,
185
+ input_block_shape : Mapping [MemberId , Mapping [AxisId , int ]],
186
+ * ,
185
187
skip_preprocessing : bool = False ,
186
188
skip_postprocessing : bool = False ,
187
- ns : Optional [
188
- Union [
189
- v0_5 .ParameterizedSize .N ,
190
- Mapping [Tuple [MemberId , AxisId ], v0_5 .ParameterizedSize .N ],
191
- ]
192
- ] = None ,
193
- batch_size : Optional [int ] = None ,
194
189
) -> Sample :
195
- """predict a sample by splitting it into blocks according to the model and the `ns` parameter"""
196
190
if not skip_preprocessing :
197
191
self .apply_preprocessing (sample )
198
192
199
- if isinstance (self .model_description , v0_4 .ModelDescr ):
200
- raise NotImplementedError (
201
- "predict with blocking not implemented for v0_4.ModelDescr {self.model_description.name}"
202
- )
203
-
204
- ns = ns or self ._default_ns
205
- if isinstance (ns , int ):
206
- ns = {
207
- (ipt .id , a .id ): ns
208
- for ipt in self .model_description .inputs
209
- for a in ipt .axes
210
- if isinstance (a .size , v0_5 .ParameterizedSize )
211
- }
212
- input_block_shape = self .model_description .get_tensor_sizes (
213
- ns , batch_size or self ._default_batch_size
214
- ).inputs
215
-
216
193
n_blocks , input_blocks = sample .split_into_blocks (
217
194
input_block_shape ,
218
195
halo = self ._default_input_halo ,
@@ -239,6 +216,47 @@ def predict_sample_with_blocking(
239
216
240
217
return predicted_sample
241
218
219
+ def predict_sample_with_blocking (
220
+ self ,
221
+ sample : Sample ,
222
+ skip_preprocessing : bool = False ,
223
+ skip_postprocessing : bool = False ,
224
+ ns : Optional [
225
+ Union [
226
+ v0_5 .ParameterizedSize_N ,
227
+ Mapping [Tuple [MemberId , AxisId ], v0_5 .ParameterizedSize_N ],
228
+ ]
229
+ ] = None ,
230
+ batch_size : Optional [int ] = None ,
231
+ ) -> Sample :
232
+ """predict a sample by splitting it into blocks according to the model and the `ns` parameter"""
233
+
234
+ if isinstance (self .model_description , v0_4 .ModelDescr ):
235
+ raise NotImplementedError (
236
+ "`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
237
+ + f" { self .model_description .name } ."
238
+ + " Consider using `predict_sample_with_fixed_blocking`"
239
+ )
240
+
241
+ ns = ns or self ._default_ns
242
+ if isinstance (ns , int ):
243
+ ns = {
244
+ (ipt .id , a .id ): ns
245
+ for ipt in self .model_description .inputs
246
+ for a in ipt .axes
247
+ if isinstance (a .size , v0_5 .ParameterizedSize )
248
+ }
249
+ input_block_shape = self .model_description .get_tensor_sizes (
250
+ ns , batch_size or self ._default_batch_size
251
+ ).inputs
252
+
253
+ return self .predict_sample_with_fixed_blocking (
254
+ sample ,
255
+ input_block_shape = input_block_shape ,
256
+ skip_preprocessing = skip_preprocessing ,
257
+ skip_postprocessing = skip_postprocessing ,
258
+ )
259
+
242
260
# def predict(
243
261
# self,
244
262
# inputs: Predict_IO,
@@ -310,8 +328,8 @@ def create_prediction_pipeline(
310
328
),
311
329
model_adapter : Optional [ModelAdapter ] = None ,
312
330
ns : Union [
313
- v0_5 .ParameterizedSize . N ,
314
- Mapping [Tuple [MemberId , AxisId ], v0_5 .ParameterizedSize . N ],
331
+ v0_5 .ParameterizedSize_N ,
332
+ Mapping [Tuple [MemberId , AxisId ], v0_5 .ParameterizedSize_N ],
315
333
] = 10 ,
316
334
** deprecated_kwargs : Any ,
317
335
) -> PredictionPipeline :
0 commit comments