Skip to content

Commit

Permalink
Merge pull request #49 from earthdaily/dev
Browse files Browse the repository at this point in the history
v0.0.10
  • Loading branch information
nkarasiak authored Mar 5, 2024
2 parents 5ca8fe5 + a035d47 commit 467b0e4
Show file tree
Hide file tree
Showing 8 changed files with 435 additions and 56 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).


## [0.0.10] - Unreleased


## [0.0.9] - 2024-02-29

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion earthdaily/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
# to hide warnings from rioxarray or nano seconds conversion
warnings.filterwarnings("ignore")

__version__ = "0.0.9"
__version__ = "0.0.10"
97 changes: 55 additions & 42 deletions earthdaily/accessor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from shapely.geometry import Point
from dask import array as da
import spyndex
from dask_image import ndfilters as ndimage

from dask_image import ndfilters as dask_ndimage
from scipy import ndimage
from xarray.core.extensions import AccessorRegistrationWarning
from ..earthdatastore.cube_utils import GeometryManager

warnings.filterwarnings("ignore", category=AccessorRegistrationWarning)

Expand Down Expand Up @@ -49,6 +50,8 @@ def force(*args, **kwargs):
if is_kwargs:
kwargs[key] = val(kwargs[key]) if val != list else [kwargs[key]]
elif len(args) >= idx:
if isinstance(val, (list, tuple)) and len(val) > 1:
val = val[0]
_args[idx] = val(args[idx]) if val != list else [args[idx]]
idx += 1
args = tuple(_args)
Expand Down Expand Up @@ -92,18 +95,11 @@ def _xr_loop_func(dataset, metafunc, loop_dimension, **kwargs):

@_typer()
def _lee_filter(img, window_size: int):
try:
from dask_image import ndfilters
except ImportError:
raise ImportError("Please install dask-image to run lee_filter")

img_ = img.copy()
ndimage_type = ndfilters
if hasattr(img, "data"):
if isinstance(img.data, (memoryview, np.ndarray)):
ndimage_type = ndimage
img = img.data
# print(ndimage_type)
if isinstance(img, np.ndarray):
ndimage_type = ndimage
else:
ndimage_type = dask_ndimage
binary_nan = ndimage_type.minimum_filter(
xr.where(np.isnan(img), 0, 1), size=window_size
)
Expand All @@ -124,30 +120,29 @@ def _lee_filter(img, window_size: int):
return img_output


def _xr_rio_clip(datacube, geom):
    """Clip ``datacube`` to ``geom`` through the rioxarray ``rio`` accessor.

    ``geom`` is first handed to ``GeometryManager`` and converted to a
    GeoDataFrame, which is then reprojected to the CRS reported by
    ``datacube.rio.crs`` before its geometries are used for clipping.

    Parameters
    ----------
    datacube : xarray object exposing the ``rio`` accessor
        Data to clip.
    geom : any input accepted by ``GeometryManager``
        Geometry describing the area to keep.

    Returns
    -------
    The result of ``datacube.rio.clip`` for the reprojected geometries.
    """
    # Reproject to the cube's CRS so the clip happens in matching coordinates.
    footprint = GeometryManager(geom).to_geopandas().to_crs(datacube.rio.crs)
    return datacube.rio.clip(footprint.geometry)


@xr.register_dataarray_accessor("ed")
class EarthDailyAccessorDataArray:
def __init__(self, xarray_obj):
self._obj = xarray_obj

def _max_time_wrap(self, wish=5):
return np.min((wish, self._obj["time"].size))
def clip(self, geom):
return _xr_rio_clip(self._obj, geom)

@_typer()
def plot_band(self, cmap="Greys", col="time", col_wrap=5, **kwargs):
return self._obj.plot.imshow(
cmap=cmap, col=col, col_wrap=self._max_time_wrap(col_wrap), **kwargs
)
def _max_time_wrap(self, wish=5, col="time"):
return np.min((wish, self._obj[col].size))

@_typer()
def plot_index(
self, cmap="RdYlGn", vmin=-1, vmax=1, col="time", col_wrap=5, **kwargs
):
def plot_band(self, cmap="Greys", col="time", col_wrap=5, **kwargs):
return self._obj.plot.imshow(
vmin=vmin,
vmax=vmax,
cmap=cmap,
col=col,
col_wrap=self._max_time_wrap(col_wrap),
col_wrap=self._max_time_wrap(col_wrap, col=col),
**kwargs,
)

Expand All @@ -157,8 +152,11 @@ class EarthDailyAccessorDataset:
def __init__(self, xarray_obj):
self._obj = xarray_obj

def _max_time_wrap(self, wish=5):
return np.min((wish, self._obj["time"].size))
def clip(self, geom):
return _xr_rio_clip(self._obj, geom)

def _max_time_wrap(self, wish=5, col="time"):
return np.min((wish, self._obj[col].size))

@_typer()
def plot_rgb(
Expand All @@ -173,30 +171,22 @@ def plot_rgb(
return (
self._obj[[red, green, blue]]
.to_array(dim="bands")
.plot.imshow(col=col, col_wrap=self._max_time_wrap(col_wrap), **kwargs)
.plot.imshow(
col=col, col_wrap=self._max_time_wrap(col_wrap, col=col), **kwargs
)
)

@_typer()
def plot_band(self, band, cmap="Greys", col="time", col_wrap=5, **kwargs):
return self._obj[band].plot.imshow(
cmap=cmap, col=col, col_wrap=self._max_time_wrap(col_wrap), **kwargs
)

@_typer()
def plot_index(
self, index, cmap="RdYlGn", vmin=-1, vmax=1, col="time", col_wrap=5, **kwargs
):
return self._obj[index].plot.imshow(
vmin=vmin,
vmax=vmax,
cmap=cmap,
col=col,
col_wrap=self._max_time_wrap(col_wrap),
col_wrap=self._max_time_wrap(col_wrap, col=col),
**kwargs,
)

@_typer()
def lee_filter(self, window_size: int = 7):
def lee_filter(self, window_size: int):
return xr.apply_ufunc(
_lee_filter,
self._obj,
Expand Down Expand Up @@ -302,7 +292,7 @@ def add_indices(self, index: list, **kwargs):
@_typer()
def sel_nearest_dates(
self,
target,
target: (xr.Dataset, xr.DataArray),
max_delta: int = 0,
method: str = "nearest",
return_target: bool = False,
Expand All @@ -321,3 +311,26 @@ def sel_nearest_dates(
time=pos, method=method_convert[method]
)
return self._obj.sel(time=pos)

@_typer()
def whittaker(
    self,
    lmbd: float,
    weights: np.ndarray = None,
    a: float = 0.5,
    min_value: float = -np.inf,
    max_value: float = np.inf,
    max_iter: int = 10,
):
    """Smooth the wrapped object along ``time`` with ``whittaker.xr_wt``.

    Every argument is forwarded unchanged to ``whittaker.xr_wt``; the
    dimension name is fixed to ``"time"``.

    Parameters
    ----------
    lmbd : float
        Lambda value handed to ``xr_wt``.
    weights : np.ndarray, optional
        Weights handed to ``xr_wt``. ``None`` lets ``xr_wt`` build its own
        default weights.
    a : float
        Forwarded to ``xr_wt`` as ``a``.
    min_value, max_value : float
        Forwarded to ``xr_wt`` as ``min_value`` / ``max_value``.
    max_iter : int
        Forwarded to ``xr_wt`` as ``max_iter``.

    Returns
    -------
    Whatever ``whittaker.xr_wt`` returns for the wrapped object.
    """
    # Imported lazily so the submodule is only loaded when smoothing is used.
    from . import whittaker

    return whittaker.xr_wt(
        self._obj,
        lmbd,
        time="time",
        # Forward the caller's values: these were previously hard-coded to
        # ``None`` and ``0.5``, which silently discarded the arguments.
        weights=weights,
        a=a,
        min_value=min_value,
        max_value=max_value,
        max_iter=max_iter,
    )
113 changes: 113 additions & 0 deletions earthdaily/accessor/whittaker/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import xarray as xr
import numpy as np
from ._pywapor_core import _wt1, _wt2, cve1, second_order_diff_matrix, dist_to_finite
import logging as log


def xr_dist_to_finite(y, dim="time"):
    """Apply ``dist_to_finite`` along ``dim`` of ``y`` via ``xr.apply_ufunc``.

    Parameters
    ----------
    y : xarray object
        Input whose ``dims`` must contain ``dim``.
    dim : str
        Name of the core dimension. Both ``y`` and its ``dim`` coordinate
        are passed to ``dist_to_finite`` with ``dim`` as core dimension.

    Returns
    -------
    The ``xr.apply_ufunc`` result, with ``dim`` as output core dimension.

    Raises
    ------
    ValueError
        If ``dim`` is not one of ``y.dims``.
    """
    if dim not in y.dims:
        raise ValueError(
            f"Dimension {dim!r} not found in input dimensions {tuple(y.dims)}."
        )

    return xr.apply_ufunc(
        dist_to_finite,
        y,
        y[dim],
        input_core_dims=[[dim], [dim]],
        output_core_dims=[[dim]],
        vectorize=False,
        dask="parallelized",
    )


def xr_choose_func(y, lmbd, dim):
    """Pick the smoothing kernel and its core-dimension layout.

    The choice depends only on the number of dimensions (``ndim``) of ``y``
    and ``lmbd``; objects without an ``ndim`` attribute count as 0-d.

    Parameters
    ----------
    y : array-like
        Data to smooth.
    lmbd : array-like or scalar
        Lambda value(s).
    dim : str
        Name of the dimension along which smoothing is applied.

    Returns
    -------
    tuple
        ``(func, input_core_dims, output_core_dims)``. The input list has
        eight entries, one per positional argument that ``xr_wt`` passes to
        ``xr.apply_ufunc``.

    Raises
    ------
    ValueError
        If ``y`` is 2-D and ``lmbd`` is 2-D, a combination that is not
        supported.
    """
    y_dims = getattr(y, "ndim", 0)
    lmbd_dims = getattr(lmbd, "ndim", 0)

    # A 1-D set of lambdas: the output gains a leading "lmbda" dimension.
    if y_dims in (2, 3) and lmbd_dims == 1:
        icd = [[dim], [], ["lmbda"], [], [], [], [], []]
        ocd = [["lmbda", dim]]
        return _wt2, icd, ocd

    if y_dims == 2 and lmbd_dims == 2:
        raise ValueError(
            "Unsupported combination: 2-dimensional data with a "
            "2-dimensional lmbd."
        )

    # Every other combination uses the single-lambda kernel.
    icd = [[dim], [], [], [], [], [], [], []]
    ocd = [[dim]]
    return _wt1, icd, ocd


def assert_lmbd(lmbd):
    """Validate ``lmbd`` and wrap array/scalar input in an ``xr.DataArray``.

    Python floats, ints and lists are converted with ``np.array`` first.
    NumPy arrays must then be at most 1-D: a 0-d array becomes a 0-d
    ``DataArray`` and a 1-D array becomes a ``DataArray`` with a ``"lmbda"``
    dimension whose coordinate holds the values themselves. Objects that are
    neither NumPy arrays nor scalars are returned as they are once the
    ``ndim <= 2`` check has passed.

    Parameters
    ----------
    lmbd : float, int, list, np.ndarray or object exposing ``ndim``
        Lambda value(s) to validate.

    Returns
    -------
    The validated (and possibly wrapped) lambda value(s).

    Raises
    ------
    AssertionError
        If ``lmbd`` has more than two dimensions, or is a NumPy array with
        more than one dimension.
    """
    # Plain Python numbers and lists go through NumPy so ``ndim`` is available.
    if isinstance(lmbd, (float, int, list)):
        lmbd = np.array(lmbd)
    assert lmbd.ndim <= 2

    # Anything that is not NumPy-like is handed back untouched.
    if not (isinstance(lmbd, np.ndarray) or np.isscalar(lmbd)):
        return lmbd

    if not np.isscalar(lmbd):
        assert lmbd.ndim <= 1
        lmbd = (
            float(lmbd)
            if lmbd.ndim == 0
            else xr.DataArray(lmbd, dims=["lmbda"], coords={"lmbda": lmbd})
        )

    return xr.DataArray(lmbd)


def xr_wt(
    datacube,
    lmbd,
    time="time",
    weights=None,
    a=0.5,
    min_value=-np.inf,
    max_value=np.inf,
    max_iter=10,
):
    """Apply Whittaker smoothing to ``datacube`` along the ``time`` dimension.

    Parameters
    ----------
    datacube : xarray object
        Data to smooth; it is rechunked so ``time`` sits in a single chunk.
    lmbd : float, int, list, np.ndarray or xarray object
        Lambda value(s), validated by ``assert_lmbd``.
    time : str
        Name of the dimension to smooth along.
    weights : np.ndarray, optional
        Weights passed to the kernel. Defaults to ones with the shape of the
        ``time`` coordinate.
    a, min_value, max_value, max_iter
        Passed unchanged to the kernel selected by ``xr_choose_func``.

    Returns
    -------
    The smoothed data, with positions that were NaN in the input kept as NaN.
    """
    # Use the ``time`` argument as the dimension name (it used to be
    # hard-coded as the keyword ``time=-1``, ignoring a custom name).
    datacube = datacube.chunk({time: -1})
    # Keep the untouched input so its NaN positions can be restored at the end.
    original = datacube.copy()
    lmbd = assert_lmbd(lmbd)

    # Normalize x-coordinates.
    x = datacube[time]
    x = (x - x.min()) / (x.max() - x.min()) * x.size

    # Create x-aware delta matrix.
    A = second_order_diff_matrix(x)

    # Make default weights if necessary.
    if weights is None:
        weights = np.ones(x.shape)

    # Choose which vectorized function to use.
    _wt, icd, ocd = xr_choose_func(datacube, lmbd, time)

    # Make sure lmbd is chunked like the data. The data was chunked above, so
    # this always applies (the former guard compared the bound method
    # ``datacube.chunk`` to None and was therefore always true).
    lmbd = lmbd.chunk(
        {
            k: v
            for k, v in datacube.unify_chunks().chunksizes.items()
            if k in lmbd.dims
        }
    )

    # Apply whittaker smoothing along axis.
    datacube = xr.apply_ufunc(
        _wt,
        datacube,
        A,
        lmbd,
        weights,
        a,
        min_value,
        max_value,
        max_iter,
        input_core_dims=icd,
        output_core_dims=ocd,
        dask="allowed",
    )
    return xr.where(np.isnan(original), original, datacube)
Loading

0 comments on commit 467b0e4

Please sign in to comment.