Skip to content

Commit fb02a73

Browse files
authored
Merge pull request #1079 from AdeelH/overflow
Zero out overflowing regions when reading a window from a RasterioSource with a cropped extent
2 parents 1b4da49 + 533a89a commit fb02a73

File tree

2 files changed

+73
-5
lines changed

2 files changed

+73
-5
lines changed

rastervision_core/rastervision/core/data/raster_source/rasterio_source.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,26 @@ def load_window(image_dataset, window=None, is_masked=False):
7373
return im
7474

7575

76+
def fill_overflow(extent: Box,
77+
window: Box,
78+
arr: np.ndarray,
79+
fill_value: int = 0) -> np.ndarray:
80+
"""Given a window and corresponding array of values, if the window
81+
overflows the extent, fill the overflowing regions with fill_value.
82+
"""
83+
top_overflow = max(0, extent.ymin - window.ymin)
84+
bottom_overflow = max(0, window.ymax - extent.ymax)
85+
left_overflow = max(0, extent.xmin - window.xmin)
86+
right_overflow = max(0, window.xmax - extent.xmax)
87+
88+
h, w = arr.shape[:2]
89+
arr[:top_overflow] = fill_value
90+
arr[h - bottom_overflow:] = fill_value
91+
arr[:, :left_overflow] = fill_value
92+
arr[:, w - right_overflow:] = fill_value
93+
return arr
94+
95+
7696
class RasterioSource(ActivateMixin, RasterSource):
7797
def __init__(self,
7898
uris,
@@ -175,14 +195,17 @@ def get_dtype(self):
175195
"""Return the numpy.dtype of this scene"""
176196
return self.dtype
177197

178-
def _get_chip(self, window):
198+
def _get_chip(self, window: Box) -> np.ndarray:
179199
if self.image_dataset is None:
180200
raise ActivationError('RasterSource must be activated before use')
181201
shifted_window = self._get_shifted_window(window)
182-
return load_window(
202+
chip = load_window(
183203
self.image_dataset,
184204
window=shifted_window.rasterio_format(),
185205
is_masked=self.is_masked)
206+
if self.extent_crop is not None:
207+
chip = fill_overflow(self.get_extent(), window, chip)
208+
return chip
186209

187210
def _activate(self):
188211
# Download images to temporary directory and delete them when done.

tests/core/data/raster_source/test_rasterio_source.py

+48-3
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
import unittest
22
from os.path import join
33
from pydantic import ValidationError
4+
from tempfile import NamedTemporaryFile
45

56
import numpy as np
67
import rasterio
78
from rasterio.enums import ColorInterp
89

910
from rastervision.core import (RasterStats)
11+
from rastervision.core.box import Box
1012
from rastervision.core.utils.misc import save_img
1113
from rastervision.core.data import (ChannelOrderError, RasterioSourceConfig,
12-
StatsTransformerConfig, CropOffsets)
14+
StatsTransformerConfig, CropOffsets,
15+
fill_overflow)
1316
from rastervision.pipeline import rv_config
1417

1518
from tests import data_file_path
@@ -306,20 +309,62 @@ def test_extent_crop(self):
306309
self.assertRaises(
307310
ValidationError,
308311
lambda: RasterioSourceConfig(uris=[img_path],
309-
extent_crop=extent_crop))
312+
extent_crop=extent_crop))
310313

311314
extent_crop = CropOffsets(skip_left=.5, skip_right=.5)
312315
self.assertRaises(
313316
ValidationError,
314317
lambda: RasterioSourceConfig(uris=[img_path],
315-
extent_crop=extent_crop))
318+
extent_crop=extent_crop))
316319

317320
# test extent_crop=None
318321
try:
319322
_ = RasterioSourceConfig(uris=[img_path], extent_crop=None) # noqa
320323
except Exception:
321324
self.fail('extent_crop=None caused an error.')
322325

326+
def test_fill_overflow(self):
327+
extent = Box(10, 10, 90, 90)
328+
window = Box(0, 0, 100, 100)
329+
arr = np.ones((100, 100), dtype=np.uint8)
330+
out = fill_overflow(extent, window, arr)
331+
mask = np.zeros_like(arr).astype(np.bool)
332+
mask[10:90, 10:90] = 1
333+
self.assertTrue(np.all(out[mask] == 1))
334+
self.assertTrue(np.all(out[~mask] == 0))
335+
336+
window = Box(0, 0, 80, 100)
337+
arr = np.ones((80, 100), dtype=np.uint8)
338+
out = fill_overflow(extent, window, arr)
339+
mask = np.zeros((80, 100), dtype=np.bool)
340+
mask[10:90, 10:90] = 1
341+
self.assertTrue(np.all(out[mask] == 1))
342+
self.assertTrue(np.all(out[~mask] == 0))
343+
344+
def test_extent_crop_overflow(self):
345+
f = 1 / 10
346+
arr = np.ones((100, 100), dtype=np.uint8)
347+
mask = np.zeros_like(arr).astype(np.bool)
348+
mask[10:90, 10:90] = 1
349+
with NamedTemporaryFile('wb') as fp:
350+
uri = fp.name
351+
with rasterio.open(
352+
uri,
353+
'w',
354+
driver='GTiff',
355+
height=100,
356+
width=100,
357+
count=1,
358+
dtype=np.uint8) as ds:
359+
ds.write_band(1, arr)
360+
cfg = RasterioSourceConfig(uris=[uri], extent_crop=(f, f, f, f))
361+
rs = cfg.build(tmp_dir=self.tmp_dir)
362+
with rs.activate():
363+
out = rs.get_chip(Box(0, 0, 100, 100))[..., 0]
364+
365+
self.assertTrue(np.all(out[mask] == 1))
366+
self.assertTrue(np.all(out[~mask] == 0))
367+
323368

324369
if __name__ == '__main__':
325370
unittest.main()

0 commit comments

Comments
 (0)