12 changes: 12 additions & 0 deletions README.md
@@ -83,6 +83,18 @@ ds = xarray.open_dataset(
)
```

Open an ImageCollection with lazy loading to defer metadata RPCs until data access time:

```python
ic = ee.ImageCollection('ECMWF/ERA5_LAND/HOURLY').filterDate(
'1992-10-05', '1993-03-31')
ds = xarray.open_dataset(
ic,
engine='ee',
lazy_load=True # Defers metadata RPCs for faster dataset opening
)
```
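
Accessing values then triggers the deferred RPCs. A minimal sketch (assuming this collection's `temperature_2m` band):

```python
# First read issues the deferred per-image metadata RPCs.
temperature = ds['temperature_2m'].isel(time=0).values
```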

Open multiple ImageCollections into one `xarray.Dataset`, all with the same
projection:

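Below is a minimal sketch, assuming the `ee://` URL convention supported by the engine:

```python
ds = xarray.open_mfdataset(
    ['ee://ECMWF/ERA5_LAND/HOURLY', 'ee://NASA/GDDP-CMIP6'],
    engine='ee',
    crs='EPSG:4326',
    scale=0.25,
)
```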
151 changes: 105 additions & 46 deletions xee/ext.py
@@ -163,6 +163,7 @@ def open(
executor_kwargs: dict[str, Any] | None = None,
getitem_kwargs: dict[str, int] | None = None,
fast_time_slicing: bool = False,
lazy_load: bool = False,
) -> EarthEngineStore:
if mode != 'r':
raise ValueError(
@@ -186,6 +187,7 @@
executor_kwargs=executor_kwargs,
getitem_kwargs=getitem_kwargs,
fast_time_slicing=fast_time_slicing,
lazy_load=lazy_load,
)

def __init__(
@@ -206,10 +208,12 @@ def __init__(
executor_kwargs: dict[str, Any] | None = None,
getitem_kwargs: dict[str, int] | None = None,
fast_time_slicing: bool = False,
lazy_load: bool = False,
):
self.ee_init_kwargs = ee_init_kwargs
self.ee_init_if_necessary = ee_init_if_necessary
self.fast_time_slicing = fast_time_slicing
self.lazy_load = lazy_load

# Initialize executor_kwargs
if executor_kwargs is None:
@@ -227,8 +231,11 @@ def __init__(
self.primary_dim_name = primary_dim_name or 'time'
self.primary_dim_property = primary_dim_property or 'system:time_start'

    # Always need the collection size for n_images, even when lazy loading.
    self.n_images = self.get_info['size']
    # Collection-level properties are fetched eagerly only when
    # lazy_load=False; otherwise this stays None until a full load runs.
    self._props = self.get_info.get('props')
    # Metadata should apply to all images in the collection.
    self._img_info: types.ImageInfo = self.get_info['first']

@@ -281,57 +288,106 @@ def __init__(

@functools.cached_property
def get_info(self) -> dict[str, Any]:
"""Make all getInfo() calls to EE at once."""
"""Make all getInfo() calls to EE at once.

If lazy_load is True, only performs essential metadata calls and defers
other calls until data access time.
"""

if not hasattr(self, '_info_cache'):
self._info_cache = {}

    # Perform minimal RPCs if lazy loading is enabled.
    if self.lazy_load:
# Only fetch essential metadata needed for dataset structure
if not self._info_cache:
rpcs = [
('size', self.image_collection.size()),
('first', self.image_collection.first()),
]

if isinstance(self.projection, ee.Projection):
rpcs.append(('projection', self.projection))

if isinstance(self.geometry, ee.Geometry):
rpcs.append(('bounds', self.geometry.bounds(1, proj=self.projection)))
else:
rpcs.append(
(
'bounds',
self.image_collection.first()
.geometry()
.bounds(1, proj=self.projection),
)
)

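        # Batching the computed objects into one ee.List lets a single
        # getInfo() round trip resolve every pending RPC at once.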
info = ee.List([rpc for _, rpc in rpcs]).getInfo()
self._info_cache.update(dict(zip((name for name, _ in rpcs), info)))

return self._info_cache

    # Full metadata load when lazy loading is disabled. The 'props' key is
    # only present after a full load, so its absence means one is needed.
    if 'props' not in self._info_cache:
      rpcs = [
          ('size', self.image_collection.size()),
          ('props', self.image_collection.toDictionary()),
          ('first', self.image_collection.first()),
      ]

      if isinstance(self.projection, ee.Projection):
        rpcs.append(('projection', self.projection))

      if isinstance(self.geometry, ee.Geometry):
        rpcs.append(('bounds', self.geometry.bounds(1, proj=self.projection)))
      else:
        rpcs.append(
            (
                'bounds',
                self.image_collection.first()
                .geometry()
                .bounds(1, proj=self.projection),
            )
        )

      # TODO(#29, #30): This RPC call takes the longest time to compute. It
      # requires a full scan of the images in the collection, which happens
      # on the EE backend. This is essential because we want the primary
      # dimension of the opened dataset to be something relevant to the
      # data, like time (start time), as opposed to a random index number.
      #
      # One optimization that could prove really fruitful: read the first
      # and last (few) values of the primary dim (read: time) and
      # interpolate the rest client-side. Ideally, this would live behind a
      # xarray-backend-specific feature flag, since it's not guaranteed that
      # data is this consistent.
      columns = ['system:id', self.primary_dim_property]
      rpcs.append(
          (
              'properties',
              self.image_collection.reduceColumns(
                  ee.Reducer.toList().repeat(len(columns)), columns
              ).get('list'),
          )
      )

      info = ee.List([rpc for _, rpc in rpcs]).getInfo()
      self._info_cache.update(dict(zip((name for name, _ in rpcs), info)))

    return self._info_cache

@property
def image_collection_properties(self) -> tuple[list[str], list[str]]:
if self.lazy_load and 'properties' not in self._info_cache:
# Fetch properties on-demand if lazy loading is enabled
columns = ['system:id', self.primary_dim_property]
properties = (
self.image_collection.reduceColumns(
ee.Reducer.toList().repeat(len(columns)), columns
).get('list')
).getInfo()
self._info_cache['properties'] = properties

system_ids, primary_coord = self.get_info['properties']
return (system_ids, primary_coord)

@@ -942,18 +998,16 @@ def _raw_indexing_method(
math.ceil(w_range / self._apparent_chunks['width']),
math.ceil(h_range / self._apparent_chunks['height']),
)
    # Pre-allocate the tile grid as a numpy object array instead of nested
    # lists; each cell holds one downloaded tile.
    tiles = np.empty(shape, dtype=object)

with concurrent.futures.ThreadPoolExecutor(
**self.store.executor_kwargs
) as pool:
for (i, j, k), arr in pool.map(
self._make_tile, self._tile_indexes(key[0], bbox)
):
tiles[i][j][k] = arr
tiles[i, j, k] = arr

    # np.block treats a bare ndarray as a single block, so convert the
    # object array back to nested lists before assembling the final output.
    out = np.block(tiles.tolist())

@@ -1044,6 +1098,7 @@ def open_dataset(
executor_kwargs: dict[str, Any] | None = None,
getitem_kwargs: dict[str, int] | None = None,
fast_time_slicing: bool = False,
lazy_load: bool = False,
) -> xarray.Dataset: # type: ignore
"""Open an Earth Engine ImageCollection as an Xarray Dataset.

@@ -1126,6 +1181,9 @@
makes slicing an ImageCollection across time faster. This optimization
loads EE images in a slice by ID, so any modifications to images in a
computed ImageCollection will not be reflected.
    lazy_load (optional): If True, defers non-essential metadata RPCs until
      data access time, making dataset opening faster. Defaults to False.

Returns:
An xarray.Dataset that streams in remote data from Earth Engine.
"""
@@ -1158,6 +1216,7 @@
executor_kwargs=executor_kwargs,
getitem_kwargs=getitem_kwargs,
fast_time_slicing=fast_time_slicing,
lazy_load=lazy_load,
)

store_entrypoint = backends_store.StoreBackendEntrypoint()
41 changes: 41 additions & 0 deletions xee/ext_integration_test.py
@@ -17,6 +17,7 @@
import os
import pathlib
import tempfile
import time

from absl.testing import absltest
from google.auth import identity_pool
@@ -556,6 +557,46 @@ def test_fast_time_slicing(self):
fast_slicing = xr.open_dataset(**params, fast_time_slicing=True)
fast_slicing_data = getattr(fast_slicing[dict(time=0)], band).as_numpy()
self.assertTrue(np.all(fast_slicing_data > 0))

def test_lazy_loading(self):
"""Test that lazy loading defers metadata RPCs until data access time."""
ic = ee.ImageCollection('ECMWF/ERA5_LAND/HOURLY').filterDate(
'1992-10-05', '1992-10-06') # Using a smaller date range for the test

# Open dataset with lazy loading
start_time = time.time()
lazy_ds = xr.open_dataset(
ic,
engine=xee.EarthEngineBackendEntrypoint,
lazy_load=True,
)
lazy_open_time = time.time() - start_time

# Open dataset without lazy loading
start_time = time.time()
regular_ds = xr.open_dataset(
ic,
engine=xee.EarthEngineBackendEntrypoint,
lazy_load=False,
)
regular_open_time = time.time() - start_time

    # Verify that lazy opening is faster than regular opening. Timing over
    # the network can be noisy, so this assertion may be flaky on slow or
    # contended connections.
    self.assertLess(
        lazy_open_time,
        regular_open_time,
        f'Lazy loading ({lazy_open_time:.2f}s) should be faster than'
        f' regular loading ({regular_open_time:.2f}s)',
    )

# Verify that both datasets have the same structure
self.assertEqual(lazy_ds.dims, regular_ds.dims)
self.assertEqual(list(lazy_ds.data_vars), list(regular_ds.data_vars))

# Access data and verify it's the same
var_name = list(lazy_ds.data_vars)[0]
lazy_data = lazy_ds[var_name].isel(time=0).values
regular_data = regular_ds[var_name].isel(time=0).values

    # Both should have the same shape and hold identical values (NaN-aware).
    self.assertEqual(lazy_data.shape, regular_data.shape)
    self.assertTrue(np.allclose(lazy_data, regular_data, equal_nan=True))

@absltest.skipIf(_SKIP_RASTERIO_TESTS, 'rioxarray module not loaded')
def test_write_projected_dataset_to_raster(self):