Skip to content

Commit

Permalink
feat(accessor)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolasK committed Feb 28, 2024
1 parent 8e70444 commit 58bb461
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 1 deletion.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.0.7] - Unreleased

### Added

- `ed` accessor.

## [0.0.6] - 2024-02-23

### Fixed
Expand Down
3 changes: 2 additions & 1 deletion earthdaily/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import earthdatastore, datasets
from .accessor import EarthDailyAccessorDataArray, EarthDailyAccessorDataset

__version__ = "0.0.6"
__version__ = "0.0.7"
195 changes: 195 additions & 0 deletions earthdaily/accessor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
import warnings
import xarray as xr
import rioxarray as rxr
import numpy as np
import pandas as pd
import geopandas as gpd
from shapely.geometry import Point
from scipy import ndimage
from dask_image import ndfilters
from dask import array as da
import spyndex


class MisType(Warning):
pass

_SUPPORTED_DTYPE = [int, float, list, bool, str]

def _typer(raise_mistype=False):
def decorator(func):
def force(*args,**kwargs):
for key, val in func.__annotations__.items():
if val not in _SUPPORTED_DTYPE or kwargs.get(key,None) is None:
continue
if raise_mistype and val != type(kwargs.get(key)):
raise MisType(f"{key} expected a {val.__name__}, not a {type(kwargs[key]).__name__} ({kwargs[key]})")
kwargs[key] = val(kwargs[key]) if val != list else [kwargs[key]]
return func(*args,**kwargs)
return force
return decorator

@_typer()
def xr_loop_func(
    dataset: xr.Dataset,
    func,
    to_numpy: bool = False,
    loop_dimension: str = "time",
    **kwargs,
):
    """Apply ``func`` to every slice of each data variable along one dimension.

    Parameters
    ----------
    dataset : xr.Dataset
        Dataset whose data variables are processed one by one through
        ``Dataset.map``.
    func : callable
        Called once per position along ``loop_dimension``. It receives the
        loaded raw array of the slice when ``to_numpy`` is True, otherwise
        the xarray slice itself. ``**kwargs`` are forwarded to it.
    to_numpy : bool, optional
        Select the raw-array code path described above.
    loop_dimension : str, optional
        Name of the dimension to iterate over.
    **kwargs
        Extra keyword arguments forwarded to ``func``.

    Returns
    -------
    xr.Dataset
        The result of mapping the per-variable loop over ``dataset``.
    """

    def _xr_loop_func(dataset, metafunc, loop_dimension, **kwargs):
        n_steps = dataset[loop_dimension].size
        if to_numpy is True:
            dataset_func = dataset.copy()
            looped = [
                metafunc(
                    dataset.isel({loop_dimension: i}).load().data, **kwargs
                )
                for i in range(n_steps)
            ]
            # Re-insert the looped axis where the variable actually has it,
            # so the result matches the original layout even when the loop
            # dimension is not the first one.
            dataset_func.data = np.stack(
                looped, axis=dataset.get_axis_num(loop_dimension)
            )
            return dataset_func
        return xr.concat(
            [
                metafunc(dataset.isel({loop_dimension: i}), **kwargs)
                for i in range(n_steps)
            ],
            dim=loop_dimension,
        )

    return dataset.map(
        func=_xr_loop_func,
        metafunc=func,
        loop_dimension=loop_dimension,
        **kwargs,
    )

@_typer()
def _lee_filter(img, window_size: int):
    """Apply a Lee (local-statistics) smoothing filter to ``img``.

    Each output value is the local mean plus a weighted deviation from that
    mean, the weight being the local variance divided by the sum of the local
    variance and the overall variance. Values whose window contains a NaN
    keep their original value.

    Parameters
    ----------
    img : array-like
        Input data. The mean/variance window is built as
        ``(window_size, window_size, 1)`` and the overall variance is taken
        over axes 0 and 1, so a 3-D input is expected.
        NOTE(review): presumably (y, x, time) - confirm against callers.
    window_size : int
        Edge length of the moving window.

    Returns
    -------
    array-like
        The filtered data.
    """
    # Untouched copy, used to restore values where the window met a NaN.
    img_ = img.copy()
    ndimage_type = ndfilters
    if hasattr(img, "data"):
        # In-memory data (a NumPy array, or the memoryview that
        # ``ndarray.data`` exposes) is filtered with SciPy rather than
        # dask_image.
        if isinstance(img.data, (memoryview, np.ndarray)):
            ndimage_type = ndimage
        img = img.data

    # 0 wherever the window touches at least one NaN, turned into a NaN/1
    # mask. NOTE(review): an integer ``size`` is applied to every axis of the
    # input here, unlike the (w, w, 1) window used below - confirm intent.
    binary_nan = ndimage_type.minimum_filter(
        xr.where(np.isnan(img), 0, 1), size=window_size
    )
    binary_nan = np.where(binary_nan == 0, np.nan, 1)
    img = xr.where(np.isnan(img), 0, img)

    # A plain tuple is enough for a constant window shape; no dask array is
    # needed to describe it.
    window_shape = (window_size, window_size, 1)

    img_mean = ndimage_type.uniform_filter(img, window_shape)
    img_sqr_mean = ndimage_type.uniform_filter(img**2, window_shape)
    img_variance = img_sqr_mean - img_mean**2

    overall_variance = np.var(img, axis=(0, 1))

    img_weights = img_variance / (np.add(img_variance, overall_variance))

    img_output = img_mean + img_weights * (np.subtract(img, img_mean))
    # Restore the original values wherever the window contained a NaN.
    img_output = xr.where(np.isnan(binary_nan), img_, img_output)
    return img_output


@xr.register_dataarray_accessor("ed")
class EarthDailyAccessorDataArray:
    """``ed`` accessor exposing plotting shortcuts on ``xr.DataArray``."""

    def __init__(self, xarray_obj):
        # Keep a handle on the array this accessor is attached to.
        self._obj = xarray_obj

    @_typer()
    def plot_band(self, cmap="Greys", col="time", col_wrap=5, **kwargs):
        """Draw the array with ``imshow``, one panel per ``col`` value.

        Extra keyword arguments are forwarded to ``imshow``.
        """
        options = dict(kwargs, cmap=cmap, col=col, col_wrap=col_wrap)
        return self._obj.plot.imshow(**options)

    @_typer()
    def plot_index(
        self,
        cmap="RdYlGn",
        vmin=-1,
        vmax=1,
        col="time",
        col_wrap=5,
        **kwargs,
    ):
        """Draw the array with ``imshow`` using a fixed colour range.

        Extra keyword arguments are forwarded to ``imshow``.
        """
        options = dict(
            kwargs,
            vmin=vmin,
            vmax=vmax,
            cmap=cmap,
            col=col,
            col_wrap=col_wrap,
        )
        return self._obj.plot.imshow(**options)


@xr.register_dataset_accessor("ed")
class EarthDailyAccessorDataset:
    """``ed`` accessor adding plotting, filtering and index helpers to ``xr.Dataset``."""

    def __init__(self, xarray_obj):
        # Keep a handle on the dataset this accessor is attached to.
        self._obj = xarray_obj

    @_typer()
    def plot_rgb(
        self,
        red: str = "red",
        green: str = "green",
        blue: str = "blue",
        col="time",
        col_wrap=5,
        **kwargs,
    ):
        """Plot three data variables as an RGB composite, one panel per ``col``.

        Extra keyword arguments are forwarded to ``imshow``.
        """
        return (
            self._obj[[red, green, blue]]
            .to_array(dim="bands")
            .plot.imshow(col=col, col_wrap=col_wrap, **kwargs)
        )

    @_typer()
    def plot_band(self, band, cmap="Greys", col="time", col_wrap=5, **kwargs):
        """Plot the data variable ``band`` with ``imshow``."""
        return self._obj[band].plot.imshow(
            cmap=cmap, col=col, col_wrap=col_wrap, **kwargs
        )

    @_typer()
    def plot_index(
        self,
        index,
        cmap="RdYlGn",
        vmin=-1,
        vmax=1,
        col="time",
        col_wrap=5,
        **kwargs,
    ):
        """Plot the data variable ``index`` with a fixed colour range."""
        return self._obj[index].plot.imshow(
            vmin=vmin, vmax=vmax, cmap=cmap, col=col, col_wrap=col_wrap, **kwargs
        )

    @_typer()
    def lee_filter(self, window_size: int = 7):
        """Run ``_lee_filter`` over the dataset through ``xr.apply_ufunc``.

        ``time`` is declared as the core dimension on input and output, and
        dask arrays are passed to the filter as-is (``dask="allowed"``).
        """
        return xr.apply_ufunc(
            _lee_filter,
            self._obj,
            input_core_dims=[["time"]],
            dask="allowed",
            output_core_dims=[["time"]],
            kwargs=dict(window_size=window_size),
        )

    @_typer()
    def centroid(self, to_wkt: bool = False, to_4326: bool = True):
        """Return the middle ``x``/``y`` coordinate of this dataset as a point.

        Parameters
        ----------
        to_wkt : bool, optional
            Return the WKT string of the point instead of a GeoSeries.
        to_4326 : bool, optional
            Reproject the point to EPSG:4326.

        Returns
        -------
        geopandas.GeoSeries or str
            The point, or its WKT representation when ``to_wkt`` is true.
        """
        # ``to_wkt`` is annotated as bool: a ``str`` annotation would make the
        # type-coercing decorator turn an explicit False into the truthy
        # string "False".
        lon = float(self._obj.x[int(self._obj.x.size / 2)])
        lat = float(self._obj.y[int(self._obj.y.size / 2)])
        point = gpd.GeoSeries([Point(lon, lat)], crs=self._obj.rio.crs)
        if to_4326:
            point = point.to_crs(epsg=4326)
        if to_wkt:
            point = point.map(lambda x: x.wkt).iloc[0]
        return point

    def _auto_mapper(self):
        """Map spyndex band codes to the matching data variables.

        Variable names are matched case-insensitively against the known band
        names; each match is returned under its spyndex code.
        """
        _BAND_MAPPING = {
            "coastal": "A",
            "blue": "B",
            "green": "G",
            "yellow": "Y",
            "red": "R",
            "rededge1": "RE1",
            "rededge2": "RE2",
            "rededge3": "RE3",
            "nir": "N",
            "nir08": "N2",
            "watervapor": "WV",
            "swir16": "S1",
            "swir22": "S2",
            "lwir": "T1",
            "lwir11": "T2",
            "vv": "VV",
            "vh": "VH",
            "hh": "HH",
            "hv": "HV",
        }

        params = {}
        for var in self._obj.data_vars:
            code = _BAND_MAPPING.get(str(var).lower())
            if code is not None:
                # Index with the variable's real name: the lower-cased form
                # is only valid for the lookup in the band table.
                params[code] = self._obj[var]
        return params

    def list_available_index(self, details=False):
        """List the spyndex indices computable from this dataset's bands.

        Parameters
        ----------
        details : bool, optional
            Return the spyndex index objects instead of their names.

        Returns
        -------
        list
            Indices whose required bands are all present in the dataset.
        """
        mapper = set(self._auto_mapper())
        # Keep an index only when none of its required bands is missing.
        return [
            spec if details else name
            for name, spec in spyndex.indices.items()
            if all(band in mapper for band in spec.bands)
        ]

    @_typer()
    def add_index(self, index: list, **kwargs):
        """
        Uses spyndex to compute and add index.
        For list of indices, see https://github.com/awesome-spectral-indices/awesome-spectral-indices.

        Parameters
        ----------
        index : list
            ['NDVI'].
        **kwargs
            Extra parameters, merged into the bands handed to spyndex and
            also forwarded to ``spyndex.computeIndex``.

        Returns
        -------
        xr.Dataset
            The input xr.Dataset with new data_vars of indices.
        """
        # A bare name reaches this point when passed positionally; treat it
        # as a one-element list so the single-index branch below applies.
        if isinstance(index, str):
            index = [index]

        # ``_auto_mapper`` already returns the band arrays keyed by code.
        params = dict(self._auto_mapper())
        params.update(**kwargs)
        idx = spyndex.computeIndex(index=index, params=params, **kwargs)

        if len(index) == 1:
            # A single index comes back without an ``index`` dimension.
            idx = idx.expand_dims(index=index)
        idx = idx.to_dataset(dim="index")

        return xr.merge((self._obj, idx))

    @_typer()
    def sel_nearest_dates(
        self,
        target,
        max_delta: int = 0,
        method: str = "nearest",
        return_target: bool = False,
    ):
        """Select the dates of this dataset that match those of ``target``.

        Parameters
        ----------
        target : xarray object
            Object with a ``time`` coordinate providing the dates to match.
        max_delta : int, optional
            Maximum difference, in days, between a matched pair of dates.
        method : str, optional
            Lookup method given to ``sel`` ("nearest", "ffill" or "bfill").
        return_target : bool, optional
            Also return ``target`` restricted to the retained dates.

        Returns
        -------
        xr.Dataset or tuple
            The selected dataset, or ``(dataset, target)`` when
            ``return_target`` is true.
        """
        src_time = self._obj.sel(
            time=target.time.dt.date, method=method
        ).time.dt.date
        target_time = target.time.dt.date
        deltas = np.abs(src_time.data - target_time.data)
        pos = [
            src_time.isel(time=i).time.values
            for i, j in enumerate(deltas)
            if j.days <= max_delta
        ]
        if return_target:
            # Looking up from the other side requires the mirrored fill
            # direction.
            method_convert = {
                "bfill": "ffill",
                "ffill": "bfill",
                "nearest": "nearest",
            }
            return self._obj.sel(time=pos), target.sel(
                time=pos, method=method_convert[method]
            )
        return self._obj.sel(time=pos)

0 comments on commit 58bb461

Please sign in to comment.