Skip to content

Commit 6b7d8be

Browse files
author
Joao
committed
refactor(generator): draft filter usage
1 parent 8d4567b commit 6b7d8be

File tree

3 files changed

+53
-7
lines changed

3 files changed

+53
-7
lines changed

pixels/generator/filters.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,9 @@ def _make_mask_on_value(img, mask_value):
2525
return mask_img
2626

2727

28-
def order_tensor_on_masks(images, mask_value, max_images=12):
28+
def order_tensor_on_masks(images: np.array, mask_value: float, max_images: int = 12):
2929
"""
3030
Order a set of images based on a mask count.
31-
3231
Parameters
3332
----------
3433
images : array
@@ -37,7 +36,6 @@ def order_tensor_on_masks(images, mask_value, max_images=12):
3736
Value to create mask.
3837
max_images : int
3938
The maximum number of images to return
40-
4139
Returns
4240
-------
4341
image : numpy array

pixels/generator/generator.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,6 @@ def get_data(self, index, metadata=False):
412412
if temp_dir:
413413
temp_dir.cleanup()
414414
x_imgs = np.array(x_imgs)
415-
416415
if (
417416
SENTINEL_2 in self.platforms or LANDSAT_8 in self.platforms
418417
) and self.cloud_sort:
@@ -421,7 +420,6 @@ def get_data(self, index, metadata=False):
421420
x_imgs = filters.order_tensor_on_cloud_mask(
422421
x_imgs, max_images=self.timesteps, sat_platform=sat_platform
423422
)
424-
425423
if not self.train:
426424
if metadata:
427425
return x_imgs, x_meta
@@ -928,6 +926,27 @@ def process_data(self):
928926
raise NotImplementedError
929927

930928

929+
"""
930+
931+
read data
932+
(several option)
933+
but return alwasys this shape
934+
935+
936+
process data
937+
a lot of options but shuold always be able to do:
938+
939+
upsampling (optional)
940+
augmentaiton (Optional)
941+
padding (Optional)
942+
nan_sorter (Optional)
943+
cloud sorter (Optional)
944+
945+
return data
946+
947+
"""
948+
949+
931950
class Data(Protocol):
932951
x_tensor: np.array
933952
y_tensor: np.array
@@ -972,8 +991,31 @@ def padd(self, padding: int = 0, padding_mode="same"):
972991
mode=padding_mode,
973992
)
974993

975-
def cloud_sort(self):
976-
...
994+
def nan_value_sorter(self):
995+
# TODO: Working only when timesteps is the first dimension. Solve to general case
996+
# to integrate batch.
997+
# That is why the loop on batch.
998+
ordered_tensor = []
999+
for batch_images in self.x_tensor:
1000+
# Choose and order timesteps by level of nan_value density.
1001+
# Expects X -> (Timesteps, num_bands, width, height)
1002+
x_imgs = filters.order_tensor_on_masks(
1003+
batch_images, self.x_nan_value, max_images=self.x_tensor.timesteps
1004+
)
1005+
ordered_tensor.append(x_imgs)
1006+
self.x_tensor = ordered_tensor
1007+
1008+
def cloud_sorter(self):
1009+
# TODO: Working only after data is read, so it is expecting shape:
1010+
# NOT working probably.
1011+
if (
1012+
SENTINEL_2 in self.platforms or LANDSAT_8 in self.platforms
1013+
) and self.cloud_sort:
1014+
# Now we only use one platform.
1015+
sat_platform = [f for f in self.platforms][0]
1016+
self.x_tensor = filters.order_tensor_on_cloud_mask(
1017+
self.x_tensor, max_images=self.timesteps, sat_platform=sat_platform
1018+
)
9771019

9781020
def normalize(self, normalization: float = None):
9791021
if normalization is not None:
@@ -1012,6 +1054,9 @@ def __init__(self, x_shape: XShape1D, y_shape: YShape1D):
10121054
self.x_shape = x_shape
10131055
self.y_shape = y_shape
10141056

1057+
def padding(self):
1058+
pass
1059+
10151060

10161061
class LearningMode:
10171062
def __init__():

pixels/generator/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import TypeVar
33

44
XShape3D = namedtuple("XShape3D", ["batch", "timesteps", "height", "width", "bands"])
5+
# batch = batch * timesteps
56
XShape2D = namedtuple("XShape2D", ["batch", "height", "width", "bands"])
67
XShape1D = namedtuple("XShape1D", ["batch", "timesteps", "bands"])
78

@@ -10,3 +11,5 @@
1011

1112
XShape = TypeVar("XShape", XShape1D, XShape2D, XShape3D)
1213
YShape = TypeVar("YShape", YShape1D, YShapeND)
14+
15+
# RESNET -> Batch, batch * (number of moving windoe on each batch)

0 commit comments

Comments
 (0)