Skip to content

Commit ff91514

Browse files
JoaoDaniel Moreno
authored andcommitted
refactor(generator): draft filter usage
1 parent a71b220 commit ff91514

File tree

3 files changed

+53
-7
lines changed

3 files changed

+53
-7
lines changed

pixels/generator/filters.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,9 @@ def _make_mask_on_value(img, mask_value):
2525
return mask_img
2626

2727

28-
def order_tensor_on_masks(images, mask_value, max_images=12):
28+
def order_tensor_on_masks(images: np.array, mask_value: float, max_images: int = 12):
2929
"""
3030
Order a set of images based on a mask count.
31-
3231
Parameters
3332
----------
3433
images : array
@@ -37,7 +36,6 @@ def order_tensor_on_masks(images, mask_value, max_images=12):
3736
Value to create mask.
3837
max_images : int
3938
The maximum number of images to return
40-
4139
Returns
4240
-------
4341
image : numpy array

pixels/generator/generator.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,6 @@ def get_data(self, index, metadata=False):
411411
if temp_dir:
412412
temp_dir.cleanup()
413413
x_imgs = np.array(x_imgs)
414-
415414
if (
416415
SENTINEL_2 in self.platforms or LANDSAT_8 in self.platforms
417416
) and self.cloud_sort:
@@ -420,7 +419,6 @@ def get_data(self, index, metadata=False):
420419
x_imgs = filters.order_tensor_on_cloud_mask(
421420
x_imgs, max_images=self.timesteps, sat_platform=sat_platform
422421
)
423-
424422
if not self.train:
425423
if metadata:
426424
return x_imgs, x_meta
@@ -927,6 +925,27 @@ def process_data(self):
927925
raise NotImplementedError
928926

929927

928+
"""
929+
930+
read data
931+
(several option)
932+
but return alwasys this shape
933+
934+
935+
process data
936+
a lot of options but shuold always be able to do:
937+
938+
upsampling (optional)
939+
augmentaiton (Optional)
940+
padding (Optional)
941+
nan_sorter (Optional)
942+
cloud sorter (Optional)
943+
944+
return data
945+
946+
"""
947+
948+
930949
class Data(Protocol):
931950
x_tensor: np.array
932951
y_tensor: np.array
@@ -971,8 +990,31 @@ def padd(self, padding: int = 0, padding_mode="same"):
971990
mode=padding_mode,
972991
)
973992

974-
def cloud_sort(self):
975-
...
993+
def nan_value_sorter(self):
994+
# TODO: Working only when timesteps is the first dimension. Solve to general case
995+
# to integrate batch.
996+
# That is why the loop on batch.
997+
ordered_tensor = []
998+
for batch_images in self.x_tensor:
999+
# Choose and order timesteps by level of nan_value density.
1000+
# Expects X -> (Timesteps, num_bands, width, height)
1001+
x_imgs = filters.order_tensor_on_masks(
1002+
batch_images, self.x_nan_value, max_images=self.x_tensor.timesteps
1003+
)
1004+
ordered_tensor.append(x_imgs)
1005+
self.x_tensor = ordered_tensor
1006+
1007+
def cloud_sorter(self):
1008+
# TODO: Working only after data is read, so it is expecting shape:
1009+
# NOT working probably.
1010+
if (
1011+
SENTINEL_2 in self.platforms or LANDSAT_8 in self.platforms
1012+
) and self.cloud_sort:
1013+
# Now we only use one platform.
1014+
sat_platform = [f for f in self.platforms][0]
1015+
self.x_tensor = filters.order_tensor_on_cloud_mask(
1016+
self.x_tensor, max_images=self.timesteps, sat_platform=sat_platform
1017+
)
9761018

9771019
def normalize(self, normalization: float = None):
9781020
if normalization is not None:
@@ -1011,6 +1053,9 @@ def __init__(self, x_shape: XShape1D, y_shape: YShape1D):
10111053
self.x_shape = x_shape
10121054
self.y_shape = y_shape
10131055

1056+
def padding(self):
1057+
pass
1058+
10141059

10151060
class LearningMode:
10161061
def __init__():

pixels/generator/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import TypeVar
33

44
XShape3D = namedtuple("XShape3D", ["batch", "timesteps", "height", "width", "bands"])
5+
# batch = batch * timesteps
56
XShape2D = namedtuple("XShape2D", ["batch", "height", "width", "bands"])
67
XShape1D = namedtuple("XShape1D", ["batch", "timesteps", "bands"])
78

@@ -10,3 +11,5 @@
1011

1112
XShape = TypeVar("XShape", XShape1D, XShape2D, XShape3D)
1213
YShape = TypeVar("YShape", YShape1D, YShapeND)
14+
15+
# RESNET -> Batch, batch * (number of moving windoe on each batch)

0 commit comments

Comments
 (0)