Skip to content

Commit 1b4da49

Browse files
authored
Merge pull request #1078 from AdeelH/ss-pred-error
Fix bug in SemanticSegmentationLabelStore
2 parents df1896a + 395be73 commit 1b4da49

File tree

1 file changed

+51
-29
lines changed

1 file changed

+51
-29
lines changed

rastervision_core/rastervision/core/data/label_store/semantic_segmentation_label_store.py

+51-29
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Sequence
1+
from typing import Optional, Sequence, Tuple
22
from os.path import join
33
import logging
44
import click
@@ -173,7 +173,7 @@ def save(self, labels: SemanticSegmentationLabels) -> None:
173173
local_root = get_local_path(self.root_uri, self.tmp_dir)
174174
make_dir(local_root)
175175

176-
out_smooth_profile = {
176+
out_profile = {
177177
'driver': 'GTiff',
178178
'height': self.extent.ymax,
179179
'width': self.extent.xmax,
@@ -190,12 +190,11 @@ def save(self, labels: SemanticSegmentationLabels) -> None:
190190
labels += old_labels
191191

192192
self.write_discrete_raster_output(
193-
out_smooth_profile, get_local_path(self.label_uri, self.tmp_dir),
194-
labels)
193+
out_profile, get_local_path(self.label_uri, self.tmp_dir), labels)
195194

196195
if self.smooth_output:
197196
self.write_smooth_raster_output(
198-
out_smooth_profile,
197+
out_profile,
199198
get_local_path(self.score_uri, self.tmp_dir),
200199
get_local_path(self.hits_uri, self.tmp_dir),
201200
labels,
@@ -207,14 +206,14 @@ def save(self, labels: SemanticSegmentationLabels) -> None:
207206
sync_to_dir(local_root, self.root_uri)
208207

209208
def write_smooth_raster_output(self,
210-
out_smooth_profile: dict,
209+
out_profile: dict,
211210
scores_path: str,
212211
hits_path: str,
213212
labels: SemanticSegmentationLabels,
214213
chip_sz: Optional[int] = None) -> None:
215214
dtype = np.uint8 if self.smooth_as_uint8 else np.float32
216215

217-
out_smooth_profile.update({
216+
out_profile.update({
218217
'count': labels.num_classes,
219218
'dtype': dtype,
220219
})
@@ -224,45 +223,38 @@ def write_smooth_raster_output(self,
224223
windows = labels.get_windows(chip_sz=chip_sz)
225224

226225
log.info('Writing smooth labels to disk.')
227-
with rio.open(scores_path, 'w', **out_smooth_profile) as dataset:
226+
with rio.open(scores_path, 'w', **out_profile) as dataset:
228227
with click.progressbar(windows) as bar:
229228
for window in bar:
230-
window = window.intersection(self.extent)
229+
window, _ = self._clip_to_extent(self.extent, window)
231230
score_arr = labels.get_score_arr(window)
232231
if self.smooth_as_uint8:
233-
score_arr *= 255
234-
score_arr = np.around(score_arr, out=score_arr)
235-
score_arr = score_arr.astype(dtype)
236-
window = window.rasterio_format()
237-
for i, class_scores in enumerate(score_arr, start=1):
238-
dataset.write_band(i, class_scores, window=window)
232+
score_arr = self._scores_to_uint8(score_arr)
233+
self._write_array(dataset, window, score_arr)
239234
# save pixel hits too
240235
np.save(hits_path, labels.pixel_hits)
241236

242237
def write_discrete_raster_output(
243-
self, out_smooth_profile: dict, path: str,
238+
self, out_profile: dict, path: str,
244239
labels: SemanticSegmentationLabels) -> None:
245240

246241
num_bands = 1 if self.class_transformer is None else 3
247-
out_smooth_profile.update({'count': num_bands, 'dtype': np.uint8})
242+
out_profile.update({'count': num_bands, 'dtype': np.uint8})
248243

249244
windows = labels.get_windows()
250245

251246
log.info('Writing labels to disk.')
252-
with rio.open(path, 'w', **out_smooth_profile) as dataset:
247+
with rio.open(path, 'w', **out_profile) as dataset:
253248
with click.progressbar(windows) as bar:
254249
for window in bar:
255-
window = window.intersection(self.extent)
256250
label_arr = labels.get_label_arr(window)
257-
window = window.rasterio_format()
258-
if self.class_transformer is None:
259-
dataset.write_band(1, label_arr, window=window)
260-
else:
261-
rgb_labels = self.class_transformer.class_to_rgb(
251+
window, label_arr = self._clip_to_extent(
252+
self.extent, window, label_arr)
253+
if self.class_transformer is not None:
254+
label_arr = self.class_transformer.class_to_rgb(
262255
label_arr)
263-
rgb_labels = rgb_labels.transpose(2, 0, 1)
264-
for i, band in enumerate(rgb_labels, start=1):
265-
dataset.write_band(i, band, window=window)
256+
label_arr = label_arr.transpose(2, 0, 1)
257+
self._write_array(dataset, window, label_arr)
266258

267259
def _labels_to_full_label_arr(
268260
self, labels: SemanticSegmentationLabels) -> np.ndarray:
@@ -273,7 +265,7 @@ def _labels_to_full_label_arr(
273265
except KeyError:
274266
pass
275267

276-
# construct the array from individual windows
268+
# we will construct the array from individual windows
277269
windows = labels.get_windows()
278270

279271
# value for pixels not convered by any windows
@@ -288,9 +280,9 @@ def _labels_to_full_label_arr(
288280
self.extent.size, fill_value=default_class_id, dtype=np.uint8)
289281

290282
for w in windows:
291-
w = w.intersection(self.extent)
292283
ymin, xmin, ymax, xmax = w
293284
arr = labels.get_label_arr(w)
285+
w, arr = self._clip_to_extent(self.extent, w, arr)
294286
label_arr[ymin:ymax, xmin:xmax] = arr
295287
return label_arr
296288

@@ -342,3 +334,33 @@ def empty_labels(self) -> SemanticSegmentationLabels:
342334
extent=self.extent,
343335
num_classes=len(self.class_config))
344336
return labels
337+
338+
def _write_array(self, dataset: rio.DatasetReader, window: Box,
339+
arr: np.ndarray) -> None:
340+
"""Write array out to a rasterio dataset. Array must be of shape
341+
(C, H, W).
342+
"""
343+
window = window.rasterio_format()
344+
if len(arr.shape) == 2:
345+
dataset.write_band(1, arr, window=window)
346+
else:
347+
for i, band in enumerate(arr, start=1):
348+
dataset.write_band(i, band, window=window)
349+
350+
def _clip_to_extent(self,
351+
extent: Box,
352+
window: Box,
353+
arr: Optional[np.ndarray] = None
354+
) -> Tuple[Box, Optional[np.ndarray]]:
355+
clipped_window = window.intersection(extent)
356+
if arr is not None:
357+
h, w = clipped_window.size
358+
arr = arr[:h, :w]
359+
return clipped_window, arr
360+
361+
def _scores_to_uint8(self, score_arr: np.ndarray) -> np.ndarray:
362+
"""Quantize scores to uint8 (0-255)."""
363+
score_arr *= 255
364+
score_arr = np.around(score_arr, out=score_arr)
365+
score_arr = score_arr.astype(np.uint8)
366+
return score_arr

0 commit comments

Comments
 (0)