diff --git a/CHANGELOG.md b/CHANGELOG.md index e275f150..9113b38e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,11 +4,19 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [0.0.8] - unreleased +## [0.0.8] - 2024-02-28 + +### Added + +- better management of `col_wrap` in `ed` xarray accessor. + +### Fixed + +- some bugs in xarray `ed` accessor. ## [0.0.7] - 2024-02-28 -#### Added +### Added - xarray `ed` accessor. diff --git a/earthdaily/accessor/__init__.py b/earthdaily/accessor/__init__.py index e23c93b3..dac4bba2 100644 --- a/earthdaily/accessor/__init__.py +++ b/earthdaily/accessor/__init__.py @@ -24,14 +24,27 @@ class MisType(Warning): def _typer(raise_mistype=False): def decorator(func): def force(*args, **kwargs): + _args = list(args) + idx = 1 for key, val in func.__annotations__.items(): - if val not in _SUPPORTED_DTYPE or kwargs.get(key, None) is None: + is_kwargs = key in kwargs.keys() + if val not in _SUPPORTED_DTYPE or kwargs.get(key, None) is None and is_kwargs or len(args)==1: continue - if raise_mistype and val != type(kwargs.get(key)): - raise MisType( - f"{key} expected a {val.__name__}, not a {type(kwargs[key]).__name__} ({kwargs[key]})" + if raise_mistype and (val != type(kwargs.get(key)) if is_kwargs else val != type(args[idx])): + if is_kwargs: + expected = f"{type(kwargs[key]).__name__} ({kwargs[key]})" + else: + expected = f"{type(args[idx]).__name__} ({args[idx]})" + + raise MisType( + f"{key} expected a {val.__name__}, not a {expected}." ) - kwargs[key] = val(kwargs[key]) if val != list else [kwargs[key]] + if is_kwargs: + kwargs[key] = val(kwargs[key]) if val != list else [kwargs[key]] + else: + _args[idx] = val(args[idx]) if val != list else [args[idx]] + idx+=1 + args = tuple(_args) return func(*args, **kwargs) return force @@ -103,22 +116,24 @@ def _lee_filter(img, window_size: int): img_output = xr.where(np.isnan(binary_nan), img_, img_output) return img_output - @xr.register_dataarray_accessor("ed") class EarthDailyAccessorDataArray: def __init__(self, xarray_obj): self._obj = xarray_obj + + def _max_time_wrap(self, wish=5): + return np.min((wish,self._obj['time'].size)) @_typer() def plot_band(self, cmap="Greys", col="time", col_wrap=5, **kwargs): - return self._obj.plot.imshow(cmap=cmap, col=col, col_wrap=col_wrap, **kwargs) + return self._obj.plot.imshow(cmap=cmap, col=col, col_wrap=self._max_time_wrap(col_wrap), **kwargs) @_typer() def plot_index( self, cmap="RdYlGn", vmin=-1, vmax=1, col="time", col_wrap=5, **kwargs ): return self._obj.plot.imshow( - vmin=vmin, vmax=vmax, cmap=cmap, col=col, col_wrap=col_wrap, **kwargs + vmin=vmin, vmax=vmax, cmap=cmap, col=col, col_wrap=self._max_time_wrap(col_wrap), **kwargs ) @@ -127,6 +142,10 @@ class EarthDailyAccessorDataset: def __init__(self, xarray_obj): self._obj = xarray_obj + def _max_time_wrap(self, wish=5): + return np.min((wish,self._obj['time'].size)) + + @_typer() def plot_rgb( self, @@ -140,13 +159,13 @@ def plot_rgb( return ( self._obj[[red, green, blue]] .to_array(dim="bands") - .plot.imshow(col=col, col_wrap=col_wrap, **kwargs) + .plot.imshow(col=col, col_wrap=self._max_time_wrap(col_wrap), **kwargs) ) @_typer() def plot_band(self, band, cmap="Greys", col="time", col_wrap=5, **kwargs): return self._obj[band].plot.imshow( - cmap=cmap, col=col, col_wrap=col_wrap, **kwargs + cmap=cmap, col=col, col_wrap=self._max_time_wrap(col_wrap), **kwargs ) @_typer() @@ -154,7 +173,7 @@ def plot_index( self, index, cmap="RdYlGn", vmin=-1, vmax=1, col="time", col_wrap=5, **kwargs ): return self._obj[index].plot.imshow( - vmin=vmin, vmax=vmax, cmap=cmap, col=col, col_wrap=col_wrap, **kwargs + vmin=vmin, vmax=vmax, cmap=cmap, col=col, col_wrap=self._max_time_wrap(col_wrap), **kwargs ) @_typer() @@ -216,7 +235,7 @@ def _auto_mapper(self): params[_BAND_MAPPING[v]] = self._obj[v] return params - def list_available_index(self, details=False): + def available_index(self, details=False): mapper = list(self._auto_mapper().keys()) indices = spyndex.indices available_indices = [] @@ -248,9 +267,7 @@ def add_index(self, index: list, **kwargs): """ params = {} - bands_mapping = self._auto_mapper() - for k, v in bands_mapping.items(): - params[k] = self._obj[v] + params = self._auto_mapper() params.update(**kwargs) idx = spyndex.computeIndex(index=index, params=params, **kwargs) diff --git a/examples/compare_scale_s2.py b/examples/compare_scale_s2.py index e88ca1e6..f159068f 100644 --- a/examples/compare_scale_s2.py +++ b/examples/compare_scale_s2.py @@ -50,7 +50,7 @@ def get_cube(rescale=True): # Plots cube with SCL with at least 50% of clear data # ---------------------------------------------------- -pivot_cube.ed.plot_rgb(vmin=0, vmax=0.33, col="time", col_wrap=3) +pivot_cube.ed.plot_rgb(col_wrap=3) plt.show() ############################################################################## @@ -66,6 +66,6 @@ def get_cube(rescale=True): # ---------------------------------------------------- -pivot_cube.ed.plot_rgb(vmin=0, vmax=0.33, col="time", col_wrap=3) +pivot_cube.ed.plot_rgb(col_wrap=3) plt.show() diff --git a/examples/multisensors_cube.py b/examples/multisensors_cube.py index 1150947b..70fc8461 100644 --- a/examples/multisensors_cube.py +++ b/examples/multisensors_cube.py @@ -43,9 +43,7 @@ ) # Add the NDVI -datacube["ndvi"] = (datacube["nir"] - datacube["red"]) / ( - datacube["nir"] + datacube["red"] -) +datacube = datacube.ed.add_index('NDVI') # Load in memory datacube = datacube.load() @@ -63,7 +61,7 @@ # See the NDVI evolution # ------------------------------------------- -datacube["ndvi"].plot.imshow( +datacube["NDVI"].plot.imshow( col="time", col_wrap=3, vmin=0, vmax=0.8, cmap="RdYlGn" ) plt.show() @@ -72,6 +70,6 @@ # See the NDVI mean evolution # ------------------------------------------- -datacube["ndvi"].groupby("time").mean(...).plot.line(x="time") +datacube["NDVI"].groupby("time").mean(...).plot.line(x="time") plt.title("NDVI evolution") plt.show() diff --git a/examples/venus_cube_mask.py b/examples/venus_cube_mask.py index 00c08e0b..00390cd3 100644 --- a/examples/venus_cube_mask.py +++ b/examples/venus_cube_mask.py @@ -37,7 +37,7 @@ # Search for items # ------------------------------------------- -items = eds.search(collection, query=query, prefer_alternate="download") +items = eds.search(collection, query=query, prefer_alternate="download", limit=5) ############################################################################## # .. note:: @@ -72,4 +72,5 @@ ) print(venus_datacube) -venus_datacube.isel(time=slice(29, 31), x=slice(4000, 4500), y=slice(4000, 4500)).ed.plot_rgb() +venus_datacube.isel(time=slice(29, 31), x=slice(4000, 4500), y=slice(4000, 4500)).ed.plot_rgb(vmax=0.2) +