Skip to content

Commit

Permalink
v0.0.12
Browse files Browse the repository at this point in the history
v0.0.12
  • Loading branch information
nkarasiak authored Mar 6, 2024
2 parents 2178a63 + 963fc9e commit aa6648d
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 26 deletions.
50 changes: 29 additions & 21 deletions earthdaily/accessor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,38 +22,36 @@ class MisType(Warning):
_SUPPORTED_DTYPE = [int, float, list, bool, str]


def _typer(raise_mistype=False):
def _typer(raise_mistype=False, custom_types={}):
def decorator(func):
def force(*args, **kwargs):
_args = list(args)
idx = 1
func_arg = func.__code__.co_varnames
for key, val in func.__annotations__.items():
if not isinstance(val, (list, tuple)):
val = [val]
idx = [i for i in range(len(func_arg)) if func_arg[i] == key][0]
is_kwargs = key in kwargs.keys()
if not is_kwargs and idx >= len(args):
continue
input_value = kwargs.get(key, None) if is_kwargs else args[idx]
if type(input_value) == val:
if type(input_value) in val:
continue
if raise_mistype and (
val != type(kwargs.get(key))
if (
type(kwargs.get(key)) not in val
if is_kwargs
else val != type(args[idx])
else type(args[idx]) not in val
):
if raise_mistype:
if is_kwargs:
expected = f"{type(kwargs[key]).__name__} ({kwargs[key]})"
else:
expected = f"{type(args[idx]).__name__} ({args[idx]})"
raise MisType(f"{key} expected {val.__name__}, not {expected}.")
if is_kwargs:
expected = f"{type(kwargs[key]).__name__} ({kwargs[key]})"
kwargs[key] = val[0](kwargs[key])
else:
expected = f"{type(args[idx]).__name__} ({args[idx]})"

raise MisType(
f"{key} expected a {val.__name__}, not a {expected}."
)
if is_kwargs:
kwargs[key] = val(kwargs[key]) if val != list else [kwargs[key]]
elif len(args) >= idx:
if isinstance(val, (list, tuple)) and len(val) > 1:
val = val[0]
_args[idx] = val(args[idx]) if val != list else [args[idx]]
idx += 1
_args[idx] = val[0](args[idx])
args = tuple(_args)
return func(*args, **kwargs)

Expand Down Expand Up @@ -335,8 +333,18 @@ def whittaker(
max_iter=max_iter,
)

def zonal_stats(self, geometry, operations: list = ["mean"]):
def zonal_stats(
self,
geometry,
operations: list = ["mean"],
raise_missing_geometry: bool = False,
):
from ..earthdatastore.cube_utils import zonal_stats, GeometryManager

geometry = GeometryManager(geometry).to_geopandas()
return zonal_stats(self._obj, geometry, operations=operations)
return zonal_stats(
self._obj,
geometry,
operations=operations,
raise_missing_geometry=raise_missing_geometry,
)
54 changes: 52 additions & 2 deletions earthdaily/earthdatastore/cube_utils/_zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,27 +105,77 @@ def zonal_stats_numpy(
def zonal_stats(
dataset,
gdf,
operations=["mean"],
operations: list = ["mean"],
all_touched=False,
method="geocube",
verbose=False,
raise_missing_geometry=False,
):
"""
Parameters
----------
dataset : xr.Dataset
DESCRIPTION.
gdf : gpd.GeoDataFrame
DESCRIPTION.
operations : TYPE, list.
DESCRIPTION. The default is ["mean"].
all_touched : TYPE, optional
DESCRIPTION. The default is False.
method : TYPE, optional
DESCRIPTION. The default is "geocube".
verbose : TYPE, optional
DESCRIPTION. The default is False.
raise_missing_geometry : TYPE, optional
DESCRIPTION. The default is False.
Raises
------
ValueError
DESCRIPTION.
NotImplementedError
DESCRIPTION.
Returns
-------
TYPE
DESCRIPTION.
"""
if method == "geocube":
from geocube.api.core import make_geocube
from geocube.rasterize import rasterize_image

def custom_rasterize_image(all_touched=all_touched, **kwargs):
return rasterize_image(all_touched=all_touched, **kwargs)

gdf["tmp_index"] = np.arange(gdf.shape[0])
out_grid = make_geocube(
gdf,
measurements=["tmp_index"],
like=dataset, # ensure the data are on the same grid
rasterize_function=custom_rasterize_image,
)
cube = dataset.groupby(out_grid.tmp_index)
zonal_stats = xr.concat(
[getattr(cube, operation)() for operation in operations], dim="stats"
)
zonal_stats["stats"] = operations
zonal_stats["tmp_index"] = list(gdf.index)

if zonal_stats["tmp_index"].size != gdf.shape[0]:
index_list = [
gdf.index[i] for i in zonal_stats["tmp_index"].values.astype(np.int16)
]
if raise_missing_geometry:
diff = gdf.shape[0] - len(index_list)
raise ValueError(
f'{diff} geometr{"y is" if diff==1 else "ies are"} missing in the zonal stats. This can be due to too small geometries, duplicated...'
)
else:
index_list = list(gdf.index)
zonal_stats["tmp_index"] = index_list
return zonal_stats.rename(dict(tmp_index="feature"))

tqdm_bar = tqdm.tqdm(total=gdf.shape[0])
Expand Down
12 changes: 9 additions & 3 deletions tests/test_zonalstats.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,23 @@ def setUp(self, constant=np.random.randint(1, 12)):
"time": times,
},
).rio.write_crs("EPSG:4326")
ds = ds.transpose('time','x','y')
# first pixel

geometry = [
Polygon([(0, 0), (0, 0.5), (0.5, 0.5), (0.5, 0)]),
Polygon([(0, 0), (0, 1.2), (1.2, 1.2), (1.2, 0)]),
Polygon([(1, 1), (9, 1), (9, 2.1), (1, 1)])
]
# out of bound geom # Polygon([(10,10), (10,11), (11,11), (11,10)])]
gdf = gpd.GeoDataFrame({"geometry": geometry}, crs="EPSG:4326")
gdf.index = ['tosmall','ok','ok']
self.gdf = gdf
self.datacube = ds


def test_basic(self):
zonalstats = earthdaily.earthdatastore.cube_utils.zonal_stats(
self.datacube, self.gdf, all_touched=True, operations=["min", "max"]
self.datacube, self.gdf, operations=["min", "max"], raise_missing_geometry=False
)
for operation in ["min", "max"]:
self._check_results(
Expand All @@ -55,6 +57,10 @@ def _check_results(self, stats_values, operation="min"):
}
self.assertTrue(np.all(stats_values == results[operation]))


def test_error(self):
with self.assertRaises(ValueError):
earthdaily.earthdatastore.cube_utils.zonal_stats(
self.datacube, self.gdf, operations=["min", "max"], raise_missing_geometry=True)

if __name__ == "__main__":
unittest.main()

0 comments on commit aa6648d

Please sign in to comment.