Skip to content

Commit

Permalink
fix(_typer)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolasK committed Feb 28, 2024
1 parent c6470c3 commit 9fdf797
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 24 deletions.
47 changes: 32 additions & 15 deletions earthdaily/accessor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,27 @@ class MisType(Warning):
def _typer(raise_mistype=False):
def decorator(func):
def force(*args, **kwargs):
_args = list(args)
idx = 1
for key, val in func.__annotations__.items():
if val not in _SUPPORTED_DTYPE or kwargs.get(key, None) is None:
is_kwargs = key in kwargs.keys()
if val not in _SUPPORTED_DTYPE or kwargs.get(key, None) is None and is_kwargs or len(args)==1:
continue
if raise_mistype and val != type(kwargs.get(key)):
raise MisType(
f"{key} expected a {val.__name__}, not a {type(kwargs[key]).__name__} ({kwargs[key]})"
if raise_mistype and (val != type(kwargs.get(key)) if is_kwargs else val != type(args[idx])):
if is_kwargs:
expected = f"{type(kwargs[key]).__name__} ({kwargs[key]})"
else:
expected = f"{type(args[idx]).__name__} ({args[idx]})"

raise MisType(
f"{key} expected a {val.__name__}, not a {expected}."
)
kwargs[key] = val(kwargs[key]) if val != list else [kwargs[key]]
if is_kwargs:
kwargs[key] = val(kwargs[key]) if val != list else [kwargs[key]]
else:
_args[idx] = val(args[idx]) if val != list else [args[idx]]
idx+=1
args = tuple(_args)
return func(*args, **kwargs)

return force
Expand Down Expand Up @@ -103,22 +116,24 @@ def _lee_filter(img, window_size: int):
img_output = xr.where(np.isnan(binary_nan), img_, img_output)
return img_output


@xr.register_dataarray_accessor("ed")
class EarthDailyAccessorDataArray:
def __init__(self, xarray_obj):
self._obj = xarray_obj

def _max_time_wrap(self, wish=5):
return np.min((wish,self._obj['time'].size))

@_typer()
def plot_band(self, cmap="Greys", col="time", col_wrap=5, **kwargs):
return self._obj.plot.imshow(cmap=cmap, col=col, col_wrap=col_wrap, **kwargs)
return self._obj.plot.imshow(cmap=cmap, col=col, col_wrap=self._max_time_wrap(col_wrap), **kwargs)

@_typer()
def plot_index(
self, cmap="RdYlGn", vmin=-1, vmax=1, col="time", col_wrap=5, **kwargs
):
return self._obj.plot.imshow(
vmin=vmin, vmax=vmax, cmap=cmap, col=col, col_wrap=col_wrap, **kwargs
vmin=vmin, vmax=vmax, cmap=cmap, col=col, col_wrap=self._max_time_wrap(col_wrap), **kwargs
)


Expand All @@ -127,6 +142,10 @@ class EarthDailyAccessorDataset:
def __init__(self, xarray_obj):
self._obj = xarray_obj

def _max_time_wrap(self, wish=5):
return np.min((wish,self._obj['time'].size))


@_typer()
def plot_rgb(
self,
Expand All @@ -140,21 +159,21 @@ def plot_rgb(
return (
self._obj[[red, green, blue]]
.to_array(dim="bands")
.plot.imshow(col=col, col_wrap=col_wrap, **kwargs)
.plot.imshow(col=col, col_wrap=self._max_time_wrap(col_wrap), **kwargs)
)

@_typer()
def plot_band(self, band, cmap="Greys", col="time", col_wrap=5, **kwargs):
return self._obj[band].plot.imshow(
cmap=cmap, col=col, col_wrap=col_wrap, **kwargs
cmap=cmap, col=col, col_wrap=self._max_time_wrap(col_wrap), **kwargs
)

@_typer()
def plot_index(
self, index, cmap="RdYlGn", vmin=-1, vmax=1, col="time", col_wrap=5, **kwargs
):
return self._obj[index].plot.imshow(
vmin=vmin, vmax=vmax, cmap=cmap, col=col, col_wrap=col_wrap, **kwargs
vmin=vmin, vmax=vmax, cmap=cmap, col=col, col_wrap=self._max_time_wrap(col_wrap), **kwargs
)

@_typer()
Expand Down Expand Up @@ -216,7 +235,7 @@ def _auto_mapper(self):
params[_BAND_MAPPING[v]] = self._obj[v]
return params

def list_available_index(self, details=False):
def available_index(self, details=False):
mapper = list(self._auto_mapper().keys())
indices = spyndex.indices
available_indices = []
Expand Down Expand Up @@ -248,9 +267,7 @@ def add_index(self, index: list, **kwargs):
"""

params = {}
bands_mapping = self._auto_mapper()
for k, v in bands_mapping.items():
params[k] = self._obj[v]
params = self._auto_mapper()
params.update(**kwargs)
idx = spyndex.computeIndex(index=index, params=params, **kwargs)

Expand Down
4 changes: 2 additions & 2 deletions examples/compare_scale_s2.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def get_cube(rescale=True):
# Plots cube with SCL with at least 50% of clear data
# ----------------------------------------------------

pivot_cube.ed.plot_rgb(vmin=0, vmax=0.33, col="time", col_wrap=3)
pivot_cube.ed.plot_rgb(col_wrap=3)
plt.show()

##############################################################################
Expand All @@ -66,6 +66,6 @@ def get_cube(rescale=True):
# ----------------------------------------------------


pivot_cube.ed.plot_rgb(vmin=0, vmax=0.33, col="time", col_wrap=3)
pivot_cube.ed.plot_rgb(col_wrap=3)

plt.show()
8 changes: 3 additions & 5 deletions examples/multisensors_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@
)

# Add the NDVI
datacube["ndvi"] = (datacube["nir"] - datacube["red"]) / (
datacube["nir"] + datacube["red"]
)
datacube = datacube.ed.add_index('NDVI')

# Load in memory
datacube = datacube.load()
Expand All @@ -63,7 +61,7 @@
# See the NDVI evolution
# -------------------------------------------

datacube["ndvi"].plot.imshow(
datacube["NDVI"].plot.imshow(
col="time", col_wrap=3, vmin=0, vmax=0.8, cmap="RdYlGn"
)
plt.show()
Expand All @@ -72,6 +70,6 @@
# See the NDVI mean evolution
# -------------------------------------------

datacube["ndvi"].groupby("time").mean(...).plot.line(x="time")
datacube["NDVI"].groupby("time").mean(...).plot.line(x="time")
plt.title("NDVI evolution")
plt.show()
5 changes: 3 additions & 2 deletions examples/venus_cube_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
# Search for items
# -------------------------------------------

items = eds.search(collection, query=query, prefer_alternate="download")
items = eds.search(collection, query=query, prefer_alternate="download", limit=5)

##############################################################################
# .. note::
Expand Down Expand Up @@ -72,4 +72,5 @@
)
print(venus_datacube)

venus_datacube.isel(time=slice(29, 31), x=slice(4000, 4500), y=slice(4000, 4500)).ed.plot_rgb()
venus_datacube.isel(time=slice(29, 31), x=slice(4000, 4500), y=slice(4000, 4500)).ed.plot_rgb(vmax=0.2)

0 comments on commit 9fdf797

Please sign in to comment.