v0.0.7 #46

Merged
15 commits merged on Feb 28, 2024
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.0.7] - 2024-02-28

### Added

- `ed` accessor for `xr.Dataset` and `xr.DataArray`: plotting helpers, spectral indices via `spyndex`, Lee filtering, and nearest-date selection.

## [0.0.6] - 2024-02-23

### Fixed
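The `ed` accessor is the headline change of this release. As a minimal sketch of what it enables (synthetic data standing in for a real datacube; assumes `earthdaily` and its dependencies are installed):

```python
import numpy as np
import pandas as pd
import xarray as xr
import earthdaily  # noqa: F401  importing registers the "ed" accessor

# tiny synthetic datacube with the band names the accessor expects
rng = np.random.default_rng(42)
cube = xr.Dataset(
    {
        band: (("time", "y", "x"), rng.random((2, 10, 10)))
        for band in ("red", "green", "blue")
    },
    coords=dict(
        time=pd.date_range("2024-02-01", periods=2),
        y=np.arange(10),
        x=np.arange(10),
    ),
)

cube.ed.plot_rgb(col="time")  # one RGB panel per acquisition date
```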
3 changes: 2 additions & 1 deletion earthdaily/__init__.py
@@ -1,3 +1,4 @@
from . import earthdatastore, datasets
from .accessor import EarthDailyAccessorDataArray, EarthDailyAccessorDataset

__version__ = "0.0.6"
__version__ = "0.0.7"
284 changes: 284 additions & 0 deletions earthdaily/accessor/__init__.py
@@ -0,0 +1,284 @@
import warnings
import xarray as xr
import rioxarray as rxr  # noqa: F401  activates the .rio accessor used in centroid()
import numpy as np
import pandas as pd
import geopandas as gpd
from shapely.geometry import Point
import spyndex
from scipy import ndimage  # in-memory counterpart to the dask-image filters

from xarray.core.extensions import AccessorRegistrationWarning

warnings.filterwarnings("ignore", category=AccessorRegistrationWarning)


class MisType(Warning):
    pass


_SUPPORTED_DTYPE = [int, float, list, bool, str]


def _typer(raise_mistype=False):
    def decorator(func):
        def force(*args, **kwargs):
            for key, val in func.__annotations__.items():
                if val not in _SUPPORTED_DTYPE or kwargs.get(key, None) is None:
                    continue
                if raise_mistype and val != type(kwargs.get(key)):
                    raise MisType(
                        f"{key} expected a {val.__name__}, not a "
                        f"{type(kwargs[key]).__name__} ({kwargs[key]})"
                    )
                # coerce to the annotated type; wrap single values when a list is expected
                if val is list:
                    if not isinstance(kwargs[key], list):
                        kwargs[key] = [kwargs[key]]
                else:
                    kwargs[key] = val(kwargs[key])
            return func(*args, **kwargs)

        return force

    return decorator
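# Hypothetical illustration (not part of the module) of what ``_typer`` does:
#
#     @_typer()
#     def f(n: int = None, names: list = None):
#         return n, names
#
#     f(n="3", names="red")  # -> (3, ["red"]): kwargs coerced to their annotations
#
# With _typer(raise_mistype=True) the same call raises MisType instead of coercing.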


@_typer()
def xr_loop_func(
    dataset: xr.Dataset,
    func,
    to_numpy: bool = False,
    loop_dimension: str = "time",
    **kwargs,
):
    def _xr_loop_func(dataset, metafunc, loop_dimension, **kwargs):
        if to_numpy is True:
            dataset_func = dataset.copy()
            looped = [
                metafunc(dataset.isel({loop_dimension: i}).load().data, **kwargs)
                for i in range(dataset[loop_dimension].size)
            ]
            dataset_func.data = np.asarray(looped)
            return dataset_func
        else:
            return xr.concat(
                [
                    metafunc(dataset.isel({loop_dimension: i}), **kwargs)
                    for i in range(dataset[loop_dimension].size)
                ],
                dim=loop_dimension,
            )

    return dataset.map(
        func=_xr_loop_func, metafunc=func, loop_dimension=loop_dimension, **kwargs
    )
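# Hypothetical usage: apply a per-slice numpy filter along the time axis,
#
#     from scipy import ndimage as scipy_ndimage
#     smoothed = xr_loop_func(ds, scipy_ndimage.uniform_filter, to_numpy=True, size=3)
#
# Each data variable is filtered one time slice at a time, then reassembled.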


@_typer()
def _lee_filter(img, window_size: int):
    try:
        from dask_image import ndfilters
    except ImportError:
        raise ImportError("Please install dask-image to run lee_filter")

    img_ = img.copy()
    # dask-image filters for dask-backed data, scipy.ndimage for in-memory arrays
    ndimage_type = ndfilters
    if hasattr(img, "data"):
        if isinstance(img.data, (memoryview, np.ndarray)):
            ndimage_type = ndimage
        img = img.data
    # flag pixels whose window touches a NaN, then zero-fill the NaNs
    binary_nan = ndimage_type.minimum_filter(
        xr.where(np.isnan(img), 0, 1), size=window_size
    )
    binary_nan = np.where(binary_nan == 0, np.nan, 1)
    img = xr.where(np.isnan(img), 0, img)
    # per-axis window (y, x, time): filter each time slice independently
    window_size = [window_size, window_size, 1]

    img_mean = ndimage_type.uniform_filter(img, window_size)
    img_sqr_mean = ndimage_type.uniform_filter(img**2, window_size)
    img_variance = img_sqr_mean - img_mean**2

    overall_variance = np.var(img, axis=(0, 1))

    img_weights = img_variance / (np.add(img_variance, overall_variance))

    img_output = img_mean + img_weights * (np.subtract(img, img_mean))
    # restore NaN-contaminated pixels from the original image
    img_output = xr.where(np.isnan(binary_nan), img_, img_output)
    return img_output
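# The estimator above is the classic Lee (1980) speckle filter:
#     output = mean + var / (var + overall_var) * (pixel - mean)
# computed over a window_size x window_size window, one time slice at a time.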


@xr.register_dataarray_accessor("ed")
class EarthDailyAccessorDataArray:
    def __init__(self, xarray_obj):
        self._obj = xarray_obj

    @_typer()
    def plot_band(self, cmap="Greys", col="time", col_wrap=5, **kwargs):
        return self._obj.plot.imshow(cmap=cmap, col=col, col_wrap=col_wrap, **kwargs)

    @_typer()
    def plot_index(
        self, cmap="RdYlGn", vmin=-1, vmax=1, col="time", col_wrap=5, **kwargs
    ):
        return self._obj.plot.imshow(
            vmin=vmin, vmax=vmax, cmap=cmap, col=col, col_wrap=col_wrap, **kwargs
        )
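# Hypothetical usage on a single band: datacube["red"].ed.plot_band(col_wrap=4)
# facets one greyscale panel per time step.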


@xr.register_dataset_accessor("ed")
class EarthDailyAccessorDataset:
    def __init__(self, xarray_obj):
        self._obj = xarray_obj

    @_typer()
    def plot_rgb(
        self,
        red: str = "red",
        green: str = "green",
        blue: str = "blue",
        col="time",
        col_wrap=5,
        **kwargs,
    ):
        return (
            self._obj[[red, green, blue]]
            .to_array(dim="bands")
            .plot.imshow(col=col, col_wrap=col_wrap, **kwargs)
        )

    @_typer()
    def plot_band(self, band, cmap="Greys", col="time", col_wrap=5, **kwargs):
        return self._obj[band].plot.imshow(
            cmap=cmap, col=col, col_wrap=col_wrap, **kwargs
        )

    @_typer()
    def plot_index(
        self, index, cmap="RdYlGn", vmin=-1, vmax=1, col="time", col_wrap=5, **kwargs
    ):
        return self._obj[index].plot.imshow(
            vmin=vmin, vmax=vmax, cmap=cmap, col=col, col_wrap=col_wrap, **kwargs
        )

    @_typer()
    def lee_filter(self, window_size: int = 7):
        return xr.apply_ufunc(
            _lee_filter,
            self._obj,
            input_core_dims=[["time"]],
            dask="allowed",
            output_core_dims=[["time"]],
            kwargs=dict(window_size=window_size),
        )

    @_typer()
    def centroid(self, to_wkt: bool = False, to_4326: bool = True):
        """Return the geographic center point in 4326/WKT of this dataset."""
        lon = float(self._obj.x[int(self._obj.x.size / 2)])
        lat = float(self._obj.y[int(self._obj.y.size / 2)])
        point = gpd.GeoSeries([Point(lon, lat)], crs=self._obj.rio.crs)
        if to_4326:
            point = point.to_crs(epsg="4326")
        if to_wkt:
            point = point.map(lambda x: x.wkt).iloc[0]
        return point
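    # Hypothetical usage: datacube.ed.centroid(to_wkt=True) returns e.g.
    # "POINT (1.23 43.56)" for a rioxarray-georeferenced datacube.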

    def _auto_mapper(self):
        _BAND_MAPPING = {
            "coastal": "A",
            "blue": "B",
            "green": "G",
            "yellow": "Y",
            "red": "R",
            "rededge1": "RE1",
            "rededge2": "RE2",
            "rededge3": "RE3",
            "nir": "N",
            "nir08": "N2",
            "watervapor": "WV",
            "swir16": "S1",
            "swir22": "S2",
            "lwir": "T1",
            "lwir11": "T2",
            "vv": "VV",
            "vh": "VH",
            "hh": "HH",
            "hv": "HV",
        }

        # match data variables case-insensitively against the spyndex band names
        params = {}
        for var in self._obj.data_vars:
            if var.lower() in _BAND_MAPPING:
                params[_BAND_MAPPING[var.lower()]] = self._obj[var]
        return params

    def list_available_index(self, details=False):
        mapper = list(self._auto_mapper().keys())
        indices = spyndex.indices
        available_indices = []
        for k, v in indices.items():
            needed_bands = v.bands
            for needed_band in needed_bands:
                if needed_band not in mapper:
                    break
            else:
                # only reached when every needed band is available
                available_indices.append(spyndex.indices[k] if details else k)
        return available_indices
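    # Hypothetical usage: datacube.ed.list_available_index() returns e.g.
    # ["NDVI", ...] for a cube exposing red and nir variables;
    # details=True returns the spyndex index objects instead of names.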

    @_typer()
    def add_index(self, index: list, **kwargs):
        """
        Uses spyndex to compute indices and add them as data variables.

        For the list of available indices, see
        https://github.com/awesome-spectral-indices/awesome-spectral-indices.

        Parameters
        ----------
        index : list
            Indices to compute, e.g. ['NDVI'].

        Returns
        -------
        xr.Dataset
            The input xr.Dataset with new data_vars, one per index.
        """

        params = self._auto_mapper()
        params.update(**kwargs)
        idx = spyndex.computeIndex(index=index, params=params)

        if len(index) == 1:
            idx = idx.expand_dims(index=index)
        idx = idx.to_dataset(dim="index")

        return xr.merge((self._obj, idx))
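    # Hypothetical usage:
    #     datacube = datacube.ed.add_index(index=["NDVI"])
    #     datacube.ed.plot_index("NDVI", col_wrap=4)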

    @_typer()
    def sel_nearest_dates(
        self,
        target,
        max_delta: int = 0,
        method: str = "nearest",
        return_target: bool = False,
    ):
        # dates of this dataset closest to each target date
        src_time = self._obj.sel(time=target.time.dt.date, method=method).time.dt.date
        target_time = target.time.dt.date
        pos = np.abs(src_time.data - target_time.data)
        # keep only the matches within max_delta days
        pos = [
            src_time.isel(time=i).time.values
            for i, j in enumerate(pos)
            if j.days <= max_delta
        ]
        if return_target:
            method_convert = {"bfill": "ffill", "ffill": "bfill", "nearest": "nearest"}
            return self._obj.sel(time=pos), target.sel(
                time=pos, method=method_convert[method]
            )
        return self._obj.sel(time=pos)
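Taken together, the accessor turns common datacube chores into one-liners. A minimal end-to-end sketch of the spectral-index workflow (synthetic data; `list_available_index`, `add_index`, and `plot_index` are the methods defined above):

```python
import numpy as np
import pandas as pd
import xarray as xr
import earthdaily  # noqa: F401  importing registers the "ed" accessor

# synthetic red/nir cube, enough for NDVI
rng = np.random.default_rng(0)
cube = xr.Dataset(
    {
        band: (("time", "y", "x"), rng.random((4, 20, 20)))
        for band in ("red", "nir")
    },
    coords=dict(
        time=pd.date_range("2024-01-01", periods=4),
        y=np.arange(20),
        x=np.arange(20),
    ),
)

print(cube.ed.list_available_index())  # e.g. ['NDVI', ...]
cube = cube.ed.add_index(index=["NDVI"])
cube.ed.plot_index("NDVI", col_wrap=2)
```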
2 changes: 1 addition & 1 deletion earthdaily/earthdatastore/cube_utils/_zonal.py
@@ -6,12 +6,12 @@
"""

from rasterio import features
from scipy.sparse import csr_matrix
import numpy as np
import xarray as xr
import tqdm
from . import custom_operations
from .preprocessing import rasterize
from scipy.sparse import csr_matrix


def _compute_M(data):
10 changes: 5 additions & 5 deletions examples/compare_scale_s2.py
@@ -46,26 +46,26 @@ def get_cube(rescale=True):

pivot_cube = get_cube(rescale=False) * 0.0001

#####################################################################da#########
##############################################################################
# Plots cube with SCL with at least 50% of clear data
# ----------------------------------------------------

pivot_cube.to_array(dim="band").plot.imshow(vmin=0, vmax=0.33, col="time", col_wrap=3)
pivot_cube.ed.plot_rgb(vmin=0, vmax=0.33, col="time", col_wrap=3)
plt.show()

#####################################################################da#########
##############################################################################
# Get cube with automatic rescale (default option)
# ----------------------------------------------------

pivot_cube = get_cube()
pivot_cube.clear_percent.plot.scatter(x="time")
plt.show()

#####################################################################da#########
##############################################################################
# Plots cube with SCL with at least 50% of clear data
# ----------------------------------------------------


pivot_cube.to_array(dim="band").plot.imshow(vmin=0, vmax=0.33, col="time", col_wrap=3)
pivot_cube.ed.plot_rgb(vmin=0, vmax=0.33, col="time", col_wrap=3)

plt.show()
2 changes: 1 addition & 1 deletion examples/earthdaily_simulated_dataset.py
@@ -60,7 +60,7 @@
# Plot RGB image time series
# -------------------------------------------

datacube[["red", "green", "blue"]].to_array(dim="band").plot.imshow(
datacube[["red", "green", "blue"]].ed.plot_rgb(
col="time", col_wrap=4, vmax=0.2
)

3 changes: 2 additions & 1 deletion examples/field_evolution.py
@@ -58,7 +58,8 @@
# ----------------------------------------------------

zonal_stats = earthdatastore.cube_utils.zonal_stats(
pivot_cube, pivot, operations=["mean", "max", "min"]
pivot_cube, pivot, operations=["mean", "max", "min"],
method="standard"
)
zonal_stats = zonal_stats.load()

6 changes: 2 additions & 4 deletions examples/first_steps_create_datacube.py
@@ -44,9 +44,7 @@
plt.title("Percentage of clear pixels on the study site")
plt.show()

s2_datacube[["red", "green", "blue"]].to_array(dim="band").plot.imshow(
vmin=0, vmax=0.2, col="time", col_wrap=4
)
s2_datacube.ed.plot_rgb(vmin=0, vmax=0.2, col="time", col_wrap=4)

###########################################################
# Create datacube in three steps
@@ -82,6 +80,6 @@
s2_datacube, 50
) # at least 50% of clear pixels
#
s2_datacube[["red", "green", "blue"]].to_array(dim="band").plot.imshow(
s2_datacube.ed.plot_rgb(
vmin=0, vmax=0.2, col="time", col_wrap=4
)
4 changes: 1 addition & 3 deletions examples/venus_cube_mask.py
@@ -72,6 +72,4 @@
)
print(venus_datacube)

venus_datacube.isel(time=slice(29, 31), x=slice(4000, 4500), y=slice(4000, 4500))[
["red", "green", "blue"]
].to_array(dim="band").plot.imshow(col="time", vmin=0, vmax=0.30)
venus_datacube.isel(time=slice(29, 31), x=slice(4000, 4500), y=slice(4000, 4500)).ed.plot_rgb(col="time", vmin=0, vmax=0.30)