1
- from typing import Optional , Sequence
1
+ from typing import Optional , Sequence , Tuple
2
2
from os .path import join
3
3
import logging
4
4
import click
@@ -173,7 +173,7 @@ def save(self, labels: SemanticSegmentationLabels) -> None:
173
173
local_root = get_local_path (self .root_uri , self .tmp_dir )
174
174
make_dir (local_root )
175
175
176
- out_smooth_profile = {
176
+ out_profile = {
177
177
'driver' : 'GTiff' ,
178
178
'height' : self .extent .ymax ,
179
179
'width' : self .extent .xmax ,
@@ -190,12 +190,11 @@ def save(self, labels: SemanticSegmentationLabels) -> None:
190
190
labels += old_labels
191
191
192
192
self .write_discrete_raster_output (
193
- out_smooth_profile , get_local_path (self .label_uri , self .tmp_dir ),
194
- labels )
193
+ out_profile , get_local_path (self .label_uri , self .tmp_dir ), labels )
195
194
196
195
if self .smooth_output :
197
196
self .write_smooth_raster_output (
198
- out_smooth_profile ,
197
+ out_profile ,
199
198
get_local_path (self .score_uri , self .tmp_dir ),
200
199
get_local_path (self .hits_uri , self .tmp_dir ),
201
200
labels ,
@@ -207,14 +206,14 @@ def save(self, labels: SemanticSegmentationLabels) -> None:
207
206
sync_to_dir (local_root , self .root_uri )
208
207
209
208
def write_smooth_raster_output (self ,
210
- out_smooth_profile : dict ,
209
+ out_profile : dict ,
211
210
scores_path : str ,
212
211
hits_path : str ,
213
212
labels : SemanticSegmentationLabels ,
214
213
chip_sz : Optional [int ] = None ) -> None :
215
214
dtype = np .uint8 if self .smooth_as_uint8 else np .float32
216
215
217
- out_smooth_profile .update ({
216
+ out_profile .update ({
218
217
'count' : labels .num_classes ,
219
218
'dtype' : dtype ,
220
219
})
@@ -224,45 +223,38 @@ def write_smooth_raster_output(self,
224
223
windows = labels .get_windows (chip_sz = chip_sz )
225
224
226
225
log .info ('Writing smooth labels to disk.' )
227
- with rio .open (scores_path , 'w' , ** out_smooth_profile ) as dataset :
226
+ with rio .open (scores_path , 'w' , ** out_profile ) as dataset :
228
227
with click .progressbar (windows ) as bar :
229
228
for window in bar :
230
- window = window . intersection (self .extent )
229
+ window , _ = self . _clip_to_extent (self .extent , window )
231
230
score_arr = labels .get_score_arr (window )
232
231
if self .smooth_as_uint8 :
233
- score_arr *= 255
234
- score_arr = np .around (score_arr , out = score_arr )
235
- score_arr = score_arr .astype (dtype )
236
- window = window .rasterio_format ()
237
- for i , class_scores in enumerate (score_arr , start = 1 ):
238
- dataset .write_band (i , class_scores , window = window )
232
+ score_arr = self ._scores_to_uint8 (score_arr )
233
+ self ._write_array (dataset , window , score_arr )
239
234
# save pixel hits too
240
235
np .save (hits_path , labels .pixel_hits )
241
236
242
237
def write_discrete_raster_output (
243
- self , out_smooth_profile : dict , path : str ,
238
+ self , out_profile : dict , path : str ,
244
239
labels : SemanticSegmentationLabels ) -> None :
245
240
246
241
num_bands = 1 if self .class_transformer is None else 3
247
- out_smooth_profile .update ({'count' : num_bands , 'dtype' : np .uint8 })
242
+ out_profile .update ({'count' : num_bands , 'dtype' : np .uint8 })
248
243
249
244
windows = labels .get_windows ()
250
245
251
246
log .info ('Writing labels to disk.' )
252
- with rio .open (path , 'w' , ** out_smooth_profile ) as dataset :
247
+ with rio .open (path , 'w' , ** out_profile ) as dataset :
253
248
with click .progressbar (windows ) as bar :
254
249
for window in bar :
255
- window = window .intersection (self .extent )
256
250
label_arr = labels .get_label_arr (window )
257
- window = window .rasterio_format ()
258
- if self .class_transformer is None :
259
- dataset .write_band (1 , label_arr , window = window )
260
- else :
261
- rgb_labels = self .class_transformer .class_to_rgb (
251
+ window , label_arr = self ._clip_to_extent (
252
+ self .extent , window , label_arr )
253
+ if self .class_transformer is not None :
254
+ label_arr = self .class_transformer .class_to_rgb (
262
255
label_arr )
263
- rgb_labels = rgb_labels .transpose (2 , 0 , 1 )
264
- for i , band in enumerate (rgb_labels , start = 1 ):
265
- dataset .write_band (i , band , window = window )
256
+ label_arr = label_arr .transpose (2 , 0 , 1 )
257
+ self ._write_array (dataset , window , label_arr )
266
258
267
259
def _labels_to_full_label_arr (
268
260
self , labels : SemanticSegmentationLabels ) -> np .ndarray :
@@ -273,7 +265,7 @@ def _labels_to_full_label_arr(
273
265
except KeyError :
274
266
pass
275
267
276
- # construct the array from individual windows
268
+ # we will construct the array from individual windows
277
269
windows = labels .get_windows ()
278
270
279
271
# value for pixels not convered by any windows
@@ -288,9 +280,9 @@ def _labels_to_full_label_arr(
288
280
self .extent .size , fill_value = default_class_id , dtype = np .uint8 )
289
281
290
282
for w in windows :
291
- w = w .intersection (self .extent )
292
283
ymin , xmin , ymax , xmax = w
293
284
arr = labels .get_label_arr (w )
285
+ w , arr = self ._clip_to_extent (self .extent , w , arr )
294
286
label_arr [ymin :ymax , xmin :xmax ] = arr
295
287
return label_arr
296
288
@@ -342,3 +334,33 @@ def empty_labels(self) -> SemanticSegmentationLabels:
342
334
extent = self .extent ,
343
335
num_classes = len (self .class_config ))
344
336
return labels
337
+
338
+ def _write_array (self , dataset : rio .DatasetReader , window : Box ,
339
+ arr : np .ndarray ) -> None :
340
+ """Write array out to a rasterio dataset. Array must be of shape
341
+ (C, H, W).
342
+ """
343
+ window = window .rasterio_format ()
344
+ if len (arr .shape ) == 2 :
345
+ dataset .write_band (1 , arr , window = window )
346
+ else :
347
+ for i , band in enumerate (arr , start = 1 ):
348
+ dataset .write_band (i , band , window = window )
349
+
350
+ def _clip_to_extent (self ,
351
+ extent : Box ,
352
+ window : Box ,
353
+ arr : Optional [np .ndarray ] = None
354
+ ) -> Tuple [Box , Optional [np .ndarray ]]:
355
+ clipped_window = window .intersection (extent )
356
+ if arr is not None :
357
+ h , w = clipped_window .size
358
+ arr = arr [:h , :w ]
359
+ return clipped_window , arr
360
+
361
+ def _scores_to_uint8 (self , score_arr : np .ndarray ) -> np .ndarray :
362
+ """Quantize scores to uint8 (0-255)."""
363
+ score_arr *= 255
364
+ score_arr = np .around (score_arr , out = score_arr )
365
+ score_arr = score_arr .astype (np .uint8 )
366
+ return score_arr
0 commit comments