Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
 into dev
  • Loading branch information
nicolasK committed Feb 28, 2024
2 parents f1c0b2b + 1b059bd commit 61daf65
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 74 deletions.
2 changes: 1 addition & 1 deletion earthdaily/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import earthdatastore, datasets
from .accessor import EarthDailyAccessorDataArray, EarthDailyAccessorDataset

__version__ = "0.0.7"
__version__ = "0.0.7"
227 changes: 154 additions & 73 deletions earthdaily/accessor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,145 +14,219 @@
class MisType(Warning):
pass


_SUPPORTED_DTYPE = [int, float, list, bool, str]


def _typer(raise_mistype=False):
def decorator(func):
def force(*args,**kwargs):
def force(*args, **kwargs):
for key, val in func.__annotations__.items():
if val not in _SUPPORTED_DTYPE or kwargs.get(key,None) is None:
if val not in _SUPPORTED_DTYPE or kwargs.get(key, None) is None:
continue
if raise_mistype and val != type(kwargs.get(key)):
raise MisType(f"{key} expected a {val.__name__}, not a {type(kwargs[key]).__name__} ({kwargs[key]})")
raise MisType(
f"{key} expected a {val.__name__}, not a {type(kwargs[key]).__name__} ({kwargs[key]})"
)
kwargs[key] = val(kwargs[key]) if val != list else [kwargs[key]]
return func(*args,**kwargs)
return func(*args, **kwargs)

return force

return decorator


@_typer()
def xr_loop_func(
    dataset: xr.Dataset,
    func,
    to_numpy: bool = False,
    loop_dimension: str = "time",
    **kwargs,
):
    """Apply ``func`` to every slice of each data variable along a dimension.

    Parameters
    ----------
    dataset : xr.Dataset
        Dataset whose data variables are processed one by one through
        ``Dataset.map``.
    func : callable
        Function called once per index of ``loop_dimension``. It receives the
        slice (its loaded ``.data`` when ``to_numpy`` is True, the xarray slice
        otherwise) followed by ``**kwargs``.
    to_numpy : bool, optional
        If True, results are gathered into an array and written back into a
        copy of the variable; otherwise the results are concatenated with
        ``xr.concat`` along ``loop_dimension``.
    loop_dimension : str, optional
        Name of the dimension to iterate over.
    **kwargs
        Extra keyword arguments forwarded to ``func``.

    Returns
    -------
    xr.Dataset
        Result of ``dataset.map`` over the per-variable loop.
    """

    def _loop_over_variable(data_array, metafunc, loop_dimension, **kwargs):
        n_steps = data_array[loop_dimension].size
        if to_numpy is True:
            result = data_array.copy()
            looped = [
                metafunc(data_array.isel({loop_dimension: i}).load().data, **kwargs)
                for i in range(n_steps)
            ]
            # np.asarray stacks the slices on axis 0; move that axis back to
            # where loop_dimension sits so the shape matches the variable even
            # when loop_dimension is not the leading dimension.
            result.data = np.moveaxis(
                np.asarray(looped), 0, data_array.get_axis_num(loop_dimension)
            )
            return result
        return xr.concat(
            [
                metafunc(data_array.isel({loop_dimension: i}), **kwargs)
                for i in range(n_steps)
            ],
            dim=loop_dimension,
        )

    return dataset.map(
        func=_loop_over_variable,
        metafunc=func,
        loop_dimension=loop_dimension,
        **kwargs,
    )


@_typer()
def _lee_filter(img, window_size: int):
    """Smooth ``img`` with local mean/variance weighting (Lee-style filter).

    NaN values are replaced by 0 before filtering. Every position whose
    ``window_size`` neighbourhood contains at least one NaN is restored to its
    original value in the output.

    Parameters
    ----------
    img : array-like
        Input array. NOTE(review): the ``[window_size, window_size, 1]``
        filter size and ``axis=(0, 1)`` variance suggest a 3-D array whose
        last axis is not filtered -- confirm against callers.
    window_size : int
        Size of the filtering window.

    Returns
    -------
    array-like
        Filtered array.
    """
    img_ = img.copy()  # untouched copy used to restore NaN-affected positions

    # Default to the ``ndfilters`` implementation; switch to ``ndimage`` when
    # the underlying buffer is an in-memory numpy array.
    ndimage_type = ndfilters
    if hasattr(img, "data"):
        if isinstance(img.data, (memoryview, np.ndarray)):
            ndimage_type = ndimage
        # NOTE(review): nesting level of this line was reconstructed from a
        # flattened diff -- confirm it belongs to the ``hasattr`` branch.
        img = img.data

    # 1 where the whole window is valid, NaN where it touches a NaN.
    binary_nan = ndimage_type.minimum_filter(
        xr.where(np.isnan(img), 0, 1), size=window_size
    )
    binary_nan = np.where(binary_nan == 0, np.nan, 1)
    img = xr.where(np.isnan(img), 0, img)
    window_size = da.from_array([window_size, window_size, 1])

    img_mean = ndimage_type.uniform_filter(img, window_size)
    img_sqr_mean = ndimage_type.uniform_filter(img**2, window_size)
    img_variance = img_sqr_mean - img_mean**2

    overall_variance = np.var(img, axis=(0, 1))

    img_weights = img_variance / (np.add(img_variance, overall_variance))

    img_output = img_mean + img_weights * (np.subtract(img, img_mean))
    img_output = xr.where(np.isnan(binary_nan), img_, img_output)
    return img_output


@xr.register_dataarray_accessor("ed")
class EarthDailyAccessorDataArray:
    """Accessor registered as ``DataArray.ed`` providing plotting shortcuts."""

    def __init__(self, xarray_obj):
        # DataArray this accessor is attached to.
        self._obj = xarray_obj

    @_typer()
    def plot_band(self, cmap="Greys", col="time", col_wrap=5, **kwargs):
        """Plot the array with ``plot.imshow``, faceted along ``col``.

        Parameters
        ----------
        cmap : str, optional
            Colormap name passed to ``imshow``.
        col : str, optional
            Dimension used to create one panel per value.
        col_wrap : int, optional
            Number of panels per row.
        **kwargs
            Forwarded to ``plot.imshow``.
        """
        return self._obj.plot.imshow(cmap=cmap, col=col, col_wrap=col_wrap, **kwargs)

    @_typer()
    def plot_index(
        self, cmap="RdYlGn", vmin=-1, vmax=1, col="time", col_wrap=5, **kwargs
    ):
        """Plot the array with ``plot.imshow`` using a fixed colour range.

        Parameters
        ----------
        cmap : str, optional
            Colormap name passed to ``imshow``.
        vmin, vmax : number, optional
            Lower and upper bounds of the colour scale.
        col : str, optional
            Dimension used to create one panel per value.
        col_wrap : int, optional
            Number of panels per row.
        **kwargs
            Forwarded to ``plot.imshow``.
        """
        return self._obj.plot.imshow(
            vmin=vmin, vmax=vmax, cmap=cmap, col=col, col_wrap=col_wrap, **kwargs
        )



@xr.register_dataset_accessor("ed")
class EarthDailyAccessorDataset:
    """Accessor registered as ``Dataset.ed``.

    Provides plotting shortcuts, a Lee-style filter, centroid extraction,
    spectral index computation through ``spyndex`` and date matching.
    """

    def __init__(self, xarray_obj):
        # Dataset this accessor is attached to.
        self._obj = xarray_obj

    @_typer()
    def plot_rgb(
        self,
        red: str = "red",
        green: str = "green",
        blue: str = "blue",
        col="time",
        col_wrap=5,
        **kwargs,
    ):
        """Plot three data variables as an RGB composite, faceted along ``col``.

        Parameters
        ----------
        red, green, blue : str, optional
            Names of the data variables stacked (in that order) into a
            ``"bands"`` dimension.
        col : str, optional
            Dimension used to create one panel per value.
        col_wrap : int, optional
            Number of panels per row.
        **kwargs
            Forwarded to ``plot.imshow``.
        """
        return (
            self._obj[[red, green, blue]]
            .to_array(dim="bands")
            .plot.imshow(col=col, col_wrap=col_wrap, **kwargs)
        )

    @_typer()
    def plot_band(self, band, cmap="Greys", col="time", col_wrap=5, **kwargs):
        """Plot the data variable ``band`` with ``plot.imshow``.

        ``cmap``, ``col``, ``col_wrap`` and ``**kwargs`` are forwarded to
        ``plot.imshow``.
        """
        return self._obj[band].plot.imshow(
            cmap=cmap, col=col, col_wrap=col_wrap, **kwargs
        )

    @_typer()
    def plot_index(
        self, index, cmap="RdYlGn", vmin=-1, vmax=1, col="time", col_wrap=5, **kwargs
    ):
        """Plot the data variable ``index`` with a fixed colour range.

        ``vmin``, ``vmax``, ``cmap``, ``col``, ``col_wrap`` and ``**kwargs``
        are forwarded to ``plot.imshow``.
        """
        return self._obj[index].plot.imshow(
            vmin=vmin, vmax=vmax, cmap=cmap, col=col, col_wrap=col_wrap, **kwargs
        )

    @_typer()
    def lee_filter(self, window_size: int = 7):
        """Apply ``_lee_filter`` to the dataset through ``xr.apply_ufunc``.

        ``"time"`` is declared as the core dimension for both input and
        output, and dask arrays are passed through (``dask="allowed"``).

        Parameters
        ----------
        window_size : int, optional
            Window size forwarded to ``_lee_filter``.
        """
        return xr.apply_ufunc(
            _lee_filter,
            self._obj,
            input_core_dims=[["time"]],
            dask="allowed",
            output_core_dims=[["time"]],
            kwargs=dict(window_size=window_size),
        )

    @_typer()
    def centroid(self, to_wkt: bool = False, to_4326: bool = True):
        """Return the center point of the dataset's x/y coordinates.

        Parameters
        ----------
        to_wkt : bool, optional
            If True, return the point as a WKT string instead of a GeoSeries.
        to_4326 : bool, optional
            If True, reproject the point to EPSG:4326.

        Returns
        -------
        geopandas.GeoSeries or str
            The point built from the middle ``x`` and ``y`` coordinate values.
        """
        # ``to_wkt`` is annotated as bool: with a ``str`` annotation the
        # ``_typer`` coercion turned ``to_wkt=False`` into the truthy "False".
        lon = float(self._obj.x[self._obj.x.size // 2])
        lat = float(self._obj.y[self._obj.y.size // 2])
        point = gpd.GeoSeries([Point(lon, lat)], crs=self._obj.rio.crs)
        if to_4326:
            point = point.to_crs(epsg="4326")
        if to_wkt:
            point = point.map(lambda x: x.wkt).iloc[0]
        return point

    def _auto_mapper(self):
        """Map standard band codes to the matching data variables.

        Variable names are matched case-insensitively against the known band
        names; unknown variables are ignored.

        Returns
        -------
        dict
            ``{band_code: DataArray}`` for every recognised data variable.
        """
        _BAND_MAPPING = {
            "coastal": "A",
            "blue": "B",
            "green": "G",
            "yellow": "Y",
            "red": "R",
            "rededge1": "RE1",
            "rededge2": "RE2",
            "rededge3": "RE3",
            "nir": "N",
            "nir08": "N2",
            "watervapor": "WV",
            "swir16": "S1",
            "swir22": "S2",
            "lwir": "T1",
            "lwir11": "T2",
            "vv": "VV",
            "vh": "VH",
            "hh": "HH",
            "hv": "HV",
        }

        params = {}
        for var in self._obj.data_vars:
            band_code = _BAND_MAPPING.get(var.lower())
            if band_code is not None:
                # Index with the original name: the lower-cased one may not
                # exist in the dataset (e.g. variable "Red").
                params[band_code] = self._obj[var]
        return params

    def list_available_index(self, details=False):
        """List the spyndex indices computable from the dataset's bands.

        Parameters
        ----------
        details : bool, optional
            If True, return the spyndex index objects instead of their names.

        Returns
        -------
        list
            Indices whose required bands are all present in the dataset.
        """
        mapper = set(self._auto_mapper())
        available_indices = []
        for name, spec in spyndex.indices.items():
            # Keep an index only when every band it needs is available.
            if all(band in mapper for band in spec.bands):
                available_indices.append(spec if details else name)
        return available_indices

    @_typer()
    def add_index(self, index: list, **kwargs):
        """
        Uses spyndex to compute and add index.

        For list of indices, see https://github.com/awesome-spectral-indices/awesome-spectral-indices.

        Parameters
        ----------
        index : list
            Name(s) of the index(es) to compute. A single name is accepted
            and treated as a one-element list.
        **kwargs
            Extra parameters added to the band mapping and forwarded to
            ``spyndex.computeIndex``.

        Returns
        -------
        xr.Dataset
            The input xr.Dataset with new data_vars of indices.
        """
        # ``_typer`` only coerces keyword arguments; handle a bare string
        # passed positionally as well.
        if isinstance(index, str):
            index = [index]

        # ``_auto_mapper`` already returns the DataArrays keyed by band code.
        params = dict(self._auto_mapper())
        params.update(**kwargs)
        idx = spyndex.computeIndex(index=index, params=params, **kwargs)

        if len(index) == 1:
            # Add the "index" dimension so the result can be split into one
            # data variable per index below.
            idx = idx.expand_dims(index=index)
        idx = idx.to_dataset(dim="index")

        return xr.merge((self._obj, idx))

    @_typer()
    def sel_nearest_dates(
        self,
        target,
        max_delta: int = 0,
        method: str = "nearest",
        return_target: bool = False,
    ):
        """Select the dates of this dataset that match the dates of ``target``.

        Parameters
        ----------
        target : xarray object
            Object with a ``time`` coordinate to match against.
        max_delta : int, optional
            Maximum difference, in days, between a selected date and the
            corresponding target date.
        method : str, optional
            Selection method passed to ``sel`` ("nearest", "bfill" or
            "ffill").
        return_target : bool, optional
            If True, also return ``target`` selected on the matched dates,
            using the mirrored method.

        Returns
        -------
        xr.Dataset or tuple
            The selected dataset, or ``(dataset, target)`` when
            ``return_target`` is True.
        """
        src_time = self._obj.sel(time=target.time.dt.date, method=method).time.dt.date
        target_time = target.time.dt.date
        deltas = np.abs(src_time.data - target_time.data)
        pos = [
            src_time.isel(time=i).time.values
            for i, delta in enumerate(deltas)
            if delta.days <= max_delta
        ]
        if return_target:
            method_convert = {"bfill": "ffill", "ffill": "bfill", "nearest": "nearest"}
            return self._obj.sel(time=pos), target.sel(
                time=pos, method=method_convert[method]
            )
        return self._obj.sel(time=pos)

0 comments on commit 61daf65

Please sign in to comment.