Add Index.validate_dataarray_coord #10137


Open
wants to merge 27 commits into base: main
Commits (27)
0707a8b
typing fixes and tweaks
benbovy Mar 17, 2025
75086ef
add Index.validate_dataarray_coord()
benbovy Mar 17, 2025
8aaf2b8
Dataset._construct_dataarray: validate index coord
benbovy Mar 17, 2025
c9b4baa
DataArray init: validate index coord
benbovy Mar 17, 2025
a47523f
clean-up old TODO
benbovy Mar 17, 2025
551808a
refactor dataarray coord update
benbovy Mar 17, 2025
818b7f5
docstring tweaks
benbovy Mar 17, 2025
e8df9b5
add tests
benbovy Mar 13, 2025
678c013
assert invariants: skip check IndexVariable ...
benbovy Mar 14, 2025
0f822b5
update cherry-picked tests
benbovy Mar 17, 2025
43c44ea
update assert datarray invariants
benbovy Mar 17, 2025
3b33263
doc: add Index.validate_dataarray_coords to API
benbovy Mar 17, 2025
a8e6e20
typo
benbovy Mar 17, 2025
f1440c4
update whats new
benbovy Mar 17, 2025
5da014e
add CoordinateValidationError
benbovy Mar 18, 2025
6026656
docstrings tweaks
benbovy Mar 18, 2025
1eeec9c
nit refactor
benbovy Mar 18, 2025
426ddce
small refactor
benbovy Mar 18, 2025
5c0cc0f
Merge branch 'main' into index-validate-dataarray-coords
benbovy Mar 27, 2025
4399036
docstrings improvements
benbovy Mar 31, 2025
828a4cc
docstrings improvements
benbovy Mar 31, 2025
273d70c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 31, 2025
f49c83a
Merge branch 'main' into index-validate-dataarray-coords
benbovy Apr 23, 2025
3e55af0
refactor index check method
benbovy Apr 24, 2025
073c0a2
small refactor
benbovy Apr 24, 2025
8d43dcc
forgot updating API docs and whats new
benbovy Apr 24, 2025
4e7c70a
nit docstrings
benbovy Apr 24, 2025
1 change: 1 addition & 0 deletions doc/api-hidden.rst
@@ -520,6 +520,7 @@
Index.stack
Index.unstack
Index.create_variables
Index.should_add_coord_in_dataarray
Index.to_pandas_index
Index.isel
Index.sel
1 change: 1 addition & 0 deletions doc/api.rst
@@ -1644,6 +1644,7 @@ Exceptions
.. autosummary::
:toctree: generated/

CoordinateValidationError
MergeError
SerializationWarning

4 changes: 4 additions & 0 deletions doc/whats-new.rst
@@ -27,6 +27,10 @@ New Features
- Improved compatibility with OPeNDAP DAP4 data model for backend engine ``pydap``. This
includes ``datatree`` support, and removing slashes from dimension names. By
`Miguel Jimenez-Urias <https://github.com/Mikejmnez>`_.
- Allow assigning index coordinates with non-array dimension(s) in a :py:class:`DataArray` by overriding
:py:meth:`Index.should_add_coord_in_dataarray`. For example, this enables support for CF boundary coordinates (e.g.,
``time(time)`` and ``time_bnds(time, nbnd)``) in a DataArray (:pull:`10137`).
By `Benoit Bovy <https://github.com/benbovy>`_.

Breaking changes
~~~~~~~~~~~~~~~~
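To illustrate the new-feature entry above, here is a minimal, hypothetical sketch of an ``Index`` subclass that opts in to keeping a CF-style ``time_bnds(time, nbnd)`` bounds coordinate on a DataArray. The class name and the hard-coded coordinate names are illustrative only, and the rest of the Index API (``from_variables``, ``sel``, ``isel``, ...) is omitted; only the hook added by this PR is shown.

```python
import xarray as xr

# Hypothetical sketch: an index associated with both "time" and "time_bnds"
# that keeps the 2-D bounds coordinate on a DataArray even though its "nbnd"
# dimension is not one of the array's dimensions.
class CFBoundsIndex(xr.Index):
    def should_add_coord_in_dataarray(self, name, var, dims):
        if name == "time_bnds" and set(var.dims) == {"time", "nbnd"}:
            # keep the bounds whenever the "time" dimension is present
            return "time" in dims
        return super().should_add_coord_in_dataarray(name, var, dims)
```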
3 changes: 2 additions & 1 deletion xarray/__init__.py
@@ -28,7 +28,7 @@
)
from xarray.conventions import SerializationWarning, decode_cf
from xarray.core.common import ALL_DIMS, full_like, ones_like, zeros_like
from xarray.core.coordinates import Coordinates
from xarray.core.coordinates import Coordinates, CoordinateValidationError
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree
@@ -128,6 +128,7 @@
"NamedArray",
"Variable",
# Exceptions
"CoordinateValidationError",
"InvalidTreeError",
"MergeError",
"NotFoundInTreeError",
79 changes: 60 additions & 19 deletions xarray/core/coordinates.py
@@ -486,7 +486,7 @@ def identical(self, other: Self) -> bool:
return self.to_dataset().identical(other.to_dataset())

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
# redirect to DatasetCoordinates._update_coords
self._data.coords._update_coords(coords, indexes)
@@ -780,7 +780,7 @@ def to_dataset(self) -> Dataset:
return self._data._copy_listed(names)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
variables = self._data._variables.copy()
variables.update(coords)
@@ -880,7 +880,7 @@ def to_dataset(self) -> Dataset:
return self._data.dataset._copy_listed(self._names)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
from xarray.core.datatree import check_alignment

@@ -964,22 +964,14 @@ def __getitem__(self, key: Hashable) -> T_DataArray:
return self._data._getitem_coord(key)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
coords_plus_data = coords.copy()
coords_plus_data[_THIS_ARRAY] = self._data.variable
dims = calculate_dimensions(coords_plus_data)
if not set(dims) <= set(self.dims):
raise ValueError(
"cannot add coordinates with new dimensions to a DataArray"
)
self._data._coords = coords
validate_dataarray_coords(
self._data.shape, Coordinates._construct_direct(coords, indexes), self.dims
)

# TODO(shoyer): once ._indexes is always populated by a dict, modify
# it to update inplace instead.
original_indexes = dict(self._data.xindexes)
original_indexes.update(indexes)
self._data._indexes = original_indexes
self._data._coords = coords
self._data._indexes = indexes

def _drop_coords(self, coord_names):
# should drop indexed coordinates only
@@ -1154,9 +1146,58 @@ def create_coords_with_default_indexes(
return new_coords


def _coordinates_from_variable(variable: Variable) -> Coordinates:
from xarray.core.indexes import create_default_index_implicit
class CoordinateValidationError(ValueError):
"""Error class for Xarray coordinate validation failures."""


def validate_dataarray_coords(
shape: tuple[int, ...],
coords: Coordinates | Mapping[Hashable, Variable],
dim: tuple[Hashable, ...],
):
"""Validate coordinates ``coords`` to include in a DataArray defined by
``shape`` and dimensions ``dim``.

If a coordinate is associated with an index, the validation is performed by
the index. By default the coordinate dimensions must be a subset of the
array dimensions (in any order) to conform to the DataArray model. An index
may, however, override this behavior with other validation rules.

Non-index coordinates must all conform to the DataArray model. Scalar
coordinates are always valid.
"""
sizes = dict(zip(dim, shape, strict=True))
dim_set = set(dim)

indexes: Mapping[Hashable, Index]
if isinstance(coords, Coordinates):
indexes = coords.xindexes
else:
indexes = {}

for k, v in coords.items():
if k in indexes:
invalid = not indexes[k].should_add_coord_in_dataarray(k, v, dim_set)
else:
invalid = any(d not in dim for d in v.dims)

if invalid:
raise CoordinateValidationError(
f"coordinate {k} has dimensions {v.dims}, but these "
"are not a subset of the DataArray "
f"dimensions {dim}"
)

for d, s in v.sizes.items():
if d in sizes and s != sizes[d]:
raise CoordinateValidationError(
f"conflicting sizes for dimension {d!r}: "
f"length {sizes[d]} on the data but length {s} on "
f"coordinate {k!r}"
)


def coordinates_from_variable(variable: Variable) -> Coordinates:
(name,) = variable.dims
new_index, index_vars = create_default_index_implicit(variable)
indexes = dict.fromkeys(index_vars, new_index)
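For reference, a small sketch of how the validation above surfaces to users on this branch: a coordinate whose dimensions are not a subset of the array's now raises ``CoordinateValidationError`` (still a subclass of ``ValueError``, so existing ``except ValueError`` code keeps working).

```python
import numpy as np
import xarray as xr

# A coordinate carrying a dimension ("y") that the array does not have is
# rejected by validate_dataarray_coords() during DataArray construction.
try:
    xr.DataArray(
        np.zeros(3),
        dims="x",
        coords={"bad": (("x", "y"), np.zeros((3, 2)))},
    )
except xr.CoordinateValidationError as err:
    print(err)
```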
22 changes: 2 additions & 20 deletions xarray/core/dataarray.py
@@ -41,6 +41,7 @@
DataArrayCoordinates,
assert_coordinate_consistent,
create_coords_with_default_indexes,
validate_dataarray_coords,
)
from xarray.core.dataset import Dataset
from xarray.core.extension_array import PandasExtensionArray
@@ -132,25 +133,6 @@
T_XarrayOther = TypeVar("T_XarrayOther", bound="DataArray" | Dataset)


def _check_coords_dims(shape, coords, dim):
sizes = dict(zip(dim, shape, strict=True))
for k, v in coords.items():
if any(d not in dim for d in v.dims):
raise ValueError(
f"coordinate {k} has dimensions {v.dims}, but these "
"are not a subset of the DataArray "
f"dimensions {dim}"
)

for d, s in v.sizes.items():
if s != sizes[d]:
raise ValueError(
f"conflicting sizes for dimension {d!r}: "
f"length {sizes[d]} on the data but length {s} on "
f"coordinate {k!r}"
)


def _infer_coords_and_dims(
shape: tuple[int, ...],
coords: (
@@ -214,7 +196,7 @@ def _infer_coords_and_dims(
var.dims = (dim,)
new_coords[dim] = var.to_index_variable()

_check_coords_dims(shape, new_coords, dims_tuple)
validate_dataarray_coords(shape, new_coords, dims_tuple)

return new_coords, dims_tuple

10 changes: 9 additions & 1 deletion xarray/core/dataset.py
@@ -1159,7 +1159,15 @@ def _construct_dataarray(self, name: Hashable) -> DataArray:
coords: dict[Hashable, Variable] = {}
# preserve ordering
for k in self._variables:
if k in self._coord_names and set(self._variables[k].dims) <= needed_dims:
if k in self._indexes:
add_coord = self._indexes[k].should_add_coord_in_dataarray(
k, self._variables[k], needed_dims
)
else:
var_dims = set(self._variables[k].dims)
add_coord = k in self._coord_names and var_dims <= needed_dims

if add_coord:
coords[k] = self._variables[k]

indexes = filter_indexes_from_coords(self._indexes, set(coords))
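A sketch of the behavior this hunk routes through the index: with the default ``Index.should_add_coord_in_dataarray`` (and for non-index coordinates), a coordinate whose dimensions are not a subset of the extracted variable's dimensions is silently dropped, exactly as before; a custom index can now choose to keep it.

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"tas": (("time",), np.arange(3.0))},
    coords={
        "time": np.arange(3),
        "time_bnds": (("time", "nbnd"), np.zeros((3, 2))),
    },
)

# With the default behavior, "time_bnds" is dropped when extracting "tas"
# because its "nbnd" dimension is not a dimension of the DataArray.
print("time_bnds" in ds["tas"].coords)  # False
```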
4 changes: 2 additions & 2 deletions xarray/core/groupby.py
@@ -23,7 +23,7 @@
DatasetGroupByAggregations,
)
from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
from xarray.core.coordinates import Coordinates, coordinates_from_variable
from xarray.core.duck_array_ops import where
from xarray.core.formatting import format_array_flat
from xarray.core.indexes import (
@@ -1138,7 +1138,7 @@ def _flox_reduce(
new_coords.append(
# Using IndexVariable here ensures we reconstruct PandasMultiIndex with
# all associated levels properly.
_coordinates_from_variable(
coordinates_from_variable(
IndexVariable(
dims=grouper.name,
data=output_index,
44 changes: 44 additions & 0 deletions xarray/core/indexes.py
@@ -195,6 +195,50 @@ def create_variables(
else:
return {}

def should_add_coord_in_dataarray(
self,
name: Hashable,
var: Variable,
dims: set[Hashable],
) -> bool:
"""Define whether or not an index coordinate variable should be added in
a new DataArray.

This method is called repeatedly for each Variable associated with
this index when creating a new DataArray (via its constructor or from a
Dataset) or updating an existing one.

By default returns ``True`` if the dimensions of the coordinate variable
are a subset of the array dimensions and ``False`` otherwise (DataArray
model). This default behavior may be overridden in Index subclasses to
bypass strict conformance with the DataArray model. This is useful, for
example, to include the (n+1)-dimensional cell boundary coordinate
associated with an interval index.

Returning ``False`` will either:

- raise a :py:class:`CoordinateValidationError` when passing the
coordinate directly to a new or an existing DataArray, e.g., via
``DataArray.__init__()`` or ``DataArray.assign_coords()``

- drop the coordinate (and therefore drop the index) when a new
DataArray is constructed by indexing a Dataset

Parameters
----------
name : Hashable
Name of a coordinate variable associated with this index.
var : Variable
Coordinate variable object.
dims : set of Hashable
The dimensions of the DataArray.

"""
if any(d not in dims for d in var.dims):
return False
else:
return True

def to_pandas_index(self) -> pd.Index:
"""Cast this xarray index to a pandas.Index object or raise a
``TypeError`` if this is not supported.
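And a sketch of the first bullet in the docstring above (the error path), assuming this branch: assigning such a coordinate directly to an existing DataArray raises ``CoordinateValidationError`` rather than dropping it.

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.zeros(3), dims="x")

# Default behavior: a coordinate with a dimension ("nbnd") that is not a
# DataArray dimension is rejected when assigned directly.
try:
    da.assign_coords(bounds=(("x", "nbnd"), np.zeros((3, 2))))
except xr.CoordinateValidationError as err:
    print(err)
```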
12 changes: 6 additions & 6 deletions xarray/groupers.py
@@ -18,7 +18,7 @@

from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq
from xarray.computation.apply_ufunc import apply_ufunc
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
from xarray.core.coordinates import Coordinates, coordinates_from_variable
from xarray.core.dataarray import DataArray
from xarray.core.duck_array_ops import array_all, isnull
from xarray.core.groupby import T_Group, _DummyGroup
@@ -115,7 +115,7 @@ def __init__(

if coords is None:
assert not isinstance(self.unique_coord, _DummyGroup)
self.coords = _coordinates_from_variable(self.unique_coord)
self.coords = coordinates_from_variable(self.unique_coord)
else:
self.coords = coords

@@ -252,7 +252,7 @@ def _factorize_unique(self) -> EncodedGroups:
codes=codes,
full_index=full_index,
unique_coord=unique_coord,
coords=_coordinates_from_variable(unique_coord),
coords=coordinates_from_variable(unique_coord),
)

def _factorize_dummy(self) -> EncodedGroups:
Expand Down Expand Up @@ -280,7 +280,7 @@ def _factorize_dummy(self) -> EncodedGroups:
else:
if TYPE_CHECKING:
assert isinstance(unique_coord, Variable)
coords = _coordinates_from_variable(unique_coord)
coords = coordinates_from_variable(unique_coord)

return EncodedGroups(
codes=codes,
@@ -409,7 +409,7 @@ def factorize(self, group: T_Group) -> EncodedGroups:
codes=codes,
full_index=full_index,
unique_coord=unique_coord,
coords=_coordinates_from_variable(unique_coord),
coords=coordinates_from_variable(unique_coord),
)


@@ -543,7 +543,7 @@ def factorize(self, group: T_Group) -> EncodedGroups:
group_indices=group_indices,
full_index=full_index,
unique_coord=unique_coord,
coords=_coordinates_from_variable(unique_coord),
coords=coordinates_from_variable(unique_coord),
)


8 changes: 4 additions & 4 deletions xarray/testing/assertions.py
@@ -401,12 +401,12 @@ def _assert_dataarray_invariants(da: DataArray, check_default_indexes: bool):

assert isinstance(da._coords, dict), da._coords
assert all(isinstance(v, Variable) for v in da._coords.values()), da._coords
assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), (
da.dims,
{k: v.dims for k, v in da._coords.items()},
)

if check_default_indexes:
assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), (
da.dims,
{k: v.dims for k, v in da._coords.items()},
)
assert all(
isinstance(v, IndexVariable)
for (k, v) in da._coords.items()
Expand Down