Description
Hi, I'm testing out flox (with dask) as a replacement for Scala-based zonal stats on global rasters (30 m resolution), and I'm seeing promising performance with cleaner and much smaller code! However, I'm running into a memory issue and wanted to ask whether there's a cleaner solution than the workaround I'm using now.
Here is the simple code for calculating tree cover loss area at three political boundary levels, grouped by a total of six dask array layers with expected_groups sizes of (23, 5, 7, 248, 854, 86).
tcl_by_year = xarray_reduce(
    areas.band_data,
    tcl_data,
    drivers_data,
    tcd_thresholds_data,
    gadm_adm0_data,
    gadm_adm1_data,
    gadm_adm2_data,
    func='sum',
    expected_groups=(tcl_years, drivers_cats, tcd_threshold_levels,
                     gadm_adm0_ids, gadm_adm1_ids, gadm_adm2_ids),
).compute()
That runs into this error:
MemoryError: Unable to allocate 109 GiB for an array with shape (14662360160,) and data type int64
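For context, that allocation corresponds exactly to the dense Cartesian product of the expected_groups sizes (a quick check, nothing flox-specific):

import numpy as np

# dense product of the expected_groups sizes above
group_sizes = (23, 5, 7, 248, 854, 86)
n_combos = np.prod(group_sizes, dtype=np.int64)
print(n_combos)              # 14662360160 -> matches the shape in the error
print(n_combos * 8 / 2**30)  # ~109 GiB as int64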
I'm getting around this by chunking the group-by layer with the most unique labels (854), running the reduction above inside a dask delayed function for each chunk of IDs, and concatenating the results:
import dask
import xarray as xr
from dask import delayed

chunk_size = 200

# reduce_chunk (sketched below) is the same reduction as above, with the
# adm2 expected_groups restricted to one slice of IDs
tasks = [
    delayed(reduce_chunk)(gadm_adm2_ids[i:i + chunk_size])
    for i in range(0, len(gadm_adm2_ids), chunk_size)
]
results = dask.compute(*tasks)
combined = xr.concat(results, dim="gadm_adm2")

chunked = combined.chunk({'tcl_year': -1, 'drivers': -1, 'tcd_threshold': -1,
                          'gadm_adm0': 1, 'gadm_adm1': -1, 'gadm_adm2': -1})
chunked.to_zarr("s3://**/gadm_results.zarr", mode="w")
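For reference, reduce_chunk is essentially the reduction from the top of the issue with the adm2 expected_groups swapped for the given chunk of IDs (roughly):

from flox.xarray import xarray_reduce

def reduce_chunk(adm2_chunk):
    # same groupers and aggregation as above, but only this chunk's adm2 IDs
    return xarray_reduce(
        areas.band_data,
        tcl_data,
        drivers_data,
        tcd_thresholds_data,
        gadm_adm0_data,
        gadm_adm1_data,
        gadm_adm2_data,
        func='sum',
        expected_groups=(tcl_years, drivers_cats, tcd_threshold_levels,
                         gadm_adm0_ids, gadm_adm1_ids, adm2_chunk),
    ).compute()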

This works, but I may not be persisting some of these layers correctly for reuse across the expected_groups chunks, and it runs slower than expected. Is there a more efficient and elegant way to handle this situation? It would be great, for example, if this MultiIndex could be built dynamically on the workers from the groups that are actually present. Thanks.
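To illustrate that last point, here is a wishful sketch (passing None for the dask GADM groupers is not something flox accepts today as far as I know; it's only meant to show the idea of discovering the combinations on the workers instead of materializing the dense 248 * 854 * 86 product up front):

# hypothetical / wishful, just to illustrate the ask
tcl_by_year = xarray_reduce(
    areas.band_data,
    tcl_data,
    drivers_data,
    tcd_thresholds_data,
    gadm_adm0_data,
    gadm_adm1_data,
    gadm_adm2_data,
    func='sum',
    expected_groups=(tcl_years, drivers_cats, tcd_threshold_levels,
                     None, None, None),  # <- discovered from the data
).compute()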
Environment:
Flox: 0.10.1
Dask: 2025.2.0
Numpy: 2.2.3
Xarray: 2025.1.2