implement the PintIndex #163

Merged
merged 61 commits into from
Jul 9, 2024
Commits
b9baa9c
add a `PintMetaIndex` that for now can only `sel`
keewis Mar 25, 2022
89c5e2a
add a function to compare indexers
keewis Mar 27, 2022
17e9aec
expect indexer dicts for strip_indexer_units
keewis Mar 27, 2022
30a1d80
move the indexer comparison function to the utils
keewis Mar 27, 2022
34caf09
change extract_indexer_units to expect a dict
keewis Mar 27, 2022
05aa5a6
fix a few calls to extract_indexer_units
keewis Mar 27, 2022
2e0e5bd
one more call
keewis Mar 27, 2022
818db6c
Merge branch 'main' into pint-meta-index
keewis Sep 11, 2023
a049d03
implement `create_variables` and `from_variables`
keewis Sep 13, 2023
0706bc6
use the new index to attach units to dimension coordinates
keewis Sep 13, 2023
a603860
pass the dictionary of indexers instead of iterating manually
keewis Sep 13, 2023
dea881c
use `Coordinates._construct_direct`
keewis Sep 13, 2023
9427b20
Merge branch 'main' into pint-meta-index
keewis Sep 14, 2023
a54b94a
delegate `isel` to the wrapped index and wrap the result
keewis Sep 14, 2023
23d1f76
Merge branch 'main' into pint-meta-index
keewis Sep 16, 2023
f0d0890
add an inline `repr` for the index
keewis Dec 6, 2023
fa9f1b3
stubs for the remaining methods
keewis Dec 9, 2023
fb01e32
rename the index class to `PintIndex`
keewis Dec 9, 2023
9278a2d
add a utility method to wrap the output of the wrapped index's methods
keewis Dec 9, 2023
3200bc8
implement `equals`
keewis Dec 9, 2023
281f03c
implement `roll`, `rename`, and `__getitem__` by forwarding
keewis Dec 9, 2023
c5e9022
start adding tests
keewis Dec 10, 2023
2c2c814
add tests for `create_variables`
keewis Dec 10, 2023
e5d8369
add tests for `sel`
keewis Dec 10, 2023
50a7287
add tests for `isel`
keewis Dec 10, 2023
2b3c5bb
improve the tests for `sel`
keewis Dec 10, 2023
57ea8e5
add tests for `equals`
keewis Dec 10, 2023
3eed8c9
add tests for `roll`
keewis Dec 10, 2023
aebaf37
add tests for `rename`
keewis Dec 10, 2023
58c540f
add tests for `__getitem__`
keewis Dec 10, 2023
9cb7e91
add tests for `_repr_inline_`
keewis Dec 10, 2023
9822520
configure coverage, just in case
keewis Dec 11, 2023
55ccb00
use `_replace` instead of manually constructing the new index
keewis Dec 11, 2023
6bd6726
explicitly check that the pint index gets created
keewis Dec 22, 2023
c7d523b
also verify that non-quantity variables don't become `PintIndex`ed
keewis Dec 22, 2023
bae3c3e
Merge branch 'main' into pint-meta-index
keewis Jun 23, 2024
9dde67d
don't use `.pint.sel`
keewis Jun 23, 2024
235ca0e
Merge branch 'main' into pint-meta-index
keewis Jun 23, 2024
b927436
fix `PintIndex.from_variables` and properly test it
keewis Jun 23, 2024
c31e6b0
quantify the test data
keewis Jun 23, 2024
415d059
explicitly quantify the input of the `interp_like` tests
keewis Jun 25, 2024
1939b2d
also strip the units of `other`
keewis Jul 6, 2024
2538104
change expectations in the conversion tests
keewis Jul 6, 2024
eb2c405
refactor `attach_units_dataset`
keewis Jul 6, 2024
0d46b66
get `convert_units` to accept indexes
keewis Jul 6, 2024
caf4668
strip indexes as well
keewis Jul 6, 2024
7303960
change the `.pint.to` tests to not include indexes
keewis Jul 6, 2024
e88738d
extract the units of `other` in `.pint.interp_like`
keewis Jul 6, 2024
0b400d0
quantify the input and expected data in the `reindex` tests
keewis Jul 6, 2024
5bd3ec7
remove the left-over explicit quantification in the `interp` tests
keewis Jul 6, 2024
77eef6d
get `.pint.reindex` to work by explicitly converting, stripping, and …
keewis Jul 6, 2024
c38eb5a
quantify the input and expected objects in the `reindex_like` tests
keewis Jul 6, 2024
7277eb5
get `reindex_like` to work with indexes
keewis Jul 6, 2024
c7cf340
quantify expected only if we expect to make use of it
keewis Jul 6, 2024
948d20f
quantify input and expected objects in the `sel` and `loc` tests
keewis Jul 6, 2024
8c76cbc
get `.pint.sel` and `.pint.loc` to work with the indexes
keewis Jul 6, 2024
f9cb15c
remove the warning about indexed coordinates
keewis Jul 6, 2024
49942bf
preserve the order of the variables
keewis Jul 6, 2024
20dd15c
remove the remaining uses of `Coordinates._construct_direct`
keewis Jul 8, 2024
5efb318
whats-new entry
keewis Jul 9, 2024
f53539a
expose the index
keewis Jul 9, 2024
2 changes: 1 addition & 1 deletion docs/examples/plotting.ipynb
@@ -108,7 +108,7 @@
"metadata": {},
"outputs": [],
"source": [
"monthly_means.pint.sel(\n",
"monthly_means.sel(\n",
" lat=ureg.Quantity(4350, \"angular_minute\"),\n",
" lon=ureg.Quantity(12000, \"angular_minute\"),\n",
")"
2 changes: 2 additions & 0 deletions docs/whats-new.rst
@@ -6,6 +6,8 @@ What's new
------------------
- drop support for python 3.9 (:pull:`266`)
By `Justus Magin <https://github.com/keewis>`_.
- create a `PintIndex` to allow units on indexed coordinates (:pull:`163`, :issue:`162`)
By `Justus Magin <https://github.com/keewis>`_ and `Benoit Bovy <https://github.com/benbovy>`_.

0.4 (23 Jun 2024)
-----------------
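
The entry above is what lets plain `.sel` take quantities as labels, as in the `docs/examples/plotting.ipynb` change: labels are converted to the units of the index before the lookup. A minimal arithmetic sketch of that conversion for the two labels used in the notebook, assuming the `lat` and `lon` coordinates are stored in degrees (the hunk does not show their units; `arcmin_to_degree` is a hypothetical helper, not part of pint-xarray):

```python
# Assumption: the indexed coordinates are stored in degrees; the diff does not show this.
ARCMIN_PER_DEGREE = 60


def arcmin_to_degree(value):
    """Convert an angle from angular minutes to degrees."""
    return value / ARCMIN_PER_DEGREE


# labels from the notebook cell, expressed in the assumed index units
lat_label = arcmin_to_degree(4350)
lon_label = arcmin_to_degree(12000)
print(lat_label, lon_label)
```

In the real code this conversion is done by pint inside `PintIndex.sel`, so no manual arithmetic is needed.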
2 changes: 2 additions & 0 deletions pint_xarray/__init__.py
@@ -5,6 +5,7 @@
from . import accessors, formatting, testing # noqa: F401
from .accessors import default_registry as unit_registry
from .accessors import setup_registry
from .index import PintIndex

try:
__version__ = version("pint-xarray")
@@ -21,4 +22,5 @@
"testing",
"unit_registry",
"setup_registry",
"PintIndex",
]
183 changes: 81 additions & 102 deletions pint_xarray/accessors.py

Large diffs are not rendered by default.

146 changes: 118 additions & 28 deletions pint_xarray/conversion.py
@@ -2,10 +2,11 @@
import re

import pint
from xarray import DataArray, Dataset, IndexVariable, Variable
from xarray import Coordinates, DataArray, Dataset, IndexVariable, Variable

from .compat import call_on_dataset
from .errors import format_error_message
from .index import PintIndex

no_unit_values = ("none", None)
unit_attribute_name = "units"
@@ -121,28 +122,62 @@ def attach_units_variable(variable, units):
return new_obj


def dataset_from_variables(variables, coords, attrs):
data_vars = {name: var for name, var in variables.items() if name not in coords}
coords = {name: var for name, var in variables.items() if name in coords}
def dataset_from_variables(variables, coordinate_names, indexes, attrs):
data_vars = {
name: var for name, var in variables.items() if name not in coordinate_names
}
coords = {name: var for name, var in variables.items() if name in coordinate_names}

new_coords = Coordinates(coords, indexes=indexes)
return Dataset(data_vars=data_vars, coords=new_coords, attrs=attrs)


def attach_units_index(index, index_vars, units):
if all(unit is None for unit in units.values()):
# skip non-quantity indexed variables
return index

if isinstance(index, PintIndex) and index.units != units:
raise ValueError(
f"cannot attach units to quantified index: {index.units} != {units}"
)

return Dataset(data_vars=data_vars, coords=coords, attrs=attrs)
return PintIndex(index=index, units=units)


def attach_units_dataset(obj, units):
attached = {}
rejected_vars = {}

indexed_variables = obj.xindexes.variables
for name, var in obj.variables.items():
if name in indexed_variables:
continue

unit = units.get(name)
try:
converted = attach_units_variable(var, unit)
attached[name] = converted
except ValueError as e:
rejected_vars[name] = (unit, e)

indexes, index_vars = obj.xindexes.copy_indexes()
for idx, idx_vars in obj.xindexes.group_by_index():
idx_units = {name: units.get(name) for name in idx_vars.keys()}
try:
attached_idx = attach_units_index(idx, idx_vars, idx_units)
indexes.update({k: attached_idx for k in idx_vars})
index_vars.update(attached_idx.create_variables(idx_vars))
except ValueError as e:
rejected_vars[name] = (units, e)

attached.update(index_vars)

if rejected_vars:
raise ValueError(rejected_vars)

return dataset_from_variables(attached, obj._coord_names, obj.attrs)
reordered = {name: attached[name] for name in obj.variables.keys()}
return dataset_from_variables(reordered, obj._coord_names, indexes, obj.attrs)


def attach_units(obj, units):
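
The three cases handled by `attach_units_index` in the hunk above may be easier to see in isolation. The following is a stand-alone sketch, not the library code: `WrappedIndex` is a hypothetical stand-in for `PintIndex`, and units are plain strings.

```python
from dataclasses import dataclass


@dataclass
class WrappedIndex:
    """Hypothetical stand-in for PintIndex: an index plus a units mapping."""

    index: object
    units: dict


def attach_units_to_index(index, units):
    # case 1: no units requested for any indexed variable, leave the index untouched
    if all(unit is None for unit in units.values()):
        return index

    # case 2: the index is already quantified with different units, refuse
    if isinstance(index, WrappedIndex) and index.units != units:
        raise ValueError(
            f"cannot attach units to quantified index: {index.units} != {units}"
        )

    # case 3: wrap the index together with its units
    # (as in the hunk, this also applies to an already wrapped index with identical units)
    return WrappedIndex(index=index, units=units)
```

Returning the index unchanged in the first case is what keeps non-quantity coordinates from being wrapped at all.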
@@ -215,20 +250,64 @@ def convert_units_variable(variable, units):
return new_obj


def convert_units_index(index, index_vars, units):
if not isinstance(index, PintIndex):
raise ValueError("cannot convert non-quantified index")

converted_vars = {}
failed = {}
for name, var in index_vars.items():
unit = units.get(name)
try:
converted = convert_units_variable(var, unit)
converted_vars[name] = strip_units_variable(converted)
except (ValueError, pint.errors.PintTypeError) as e:
failed[name] = e

if failed:
# raise exception group
raise ValueError("failed to convert index variables:", failed)

# TODO: figure out how to pull out `options`
converted_index = index.index.from_variables(converted_vars, options={})
return PintIndex(index=converted_index, units=units)


def convert_units_dataset(obj, units):
converted = {}
failed = {}
indexed_variables = obj.xindexes.variables
for name, var in obj.variables.items():
if name in indexed_variables:
continue

unit = units.get(name)
try:
converted[name] = convert_units_variable(var, unit)
except (ValueError, pint.errors.PintTypeError) as e:
failed[name] = e

indexes, index_vars = obj.xindexes.copy_indexes()
for idx, idx_vars in obj.xindexes.group_by_index():
idx_units = {name: units.get(name) for name in idx_vars.keys()}
if all(unit is None for unit in idx_units.values()):
continue

try:
converted_index = convert_units_index(idx, idx_vars, idx_units)
indexes.update({k: converted_index for k in idx_vars})
index_vars.update(converted_index.create_variables())
except (ValueError, pint.errors.PintTypeError) as e:
names = tuple(idx_vars)
failed[names] = e

converted.update(index_vars)

if failed:
raise ValueError(failed)

return dataset_from_variables(converted, obj._coord_names, obj.attrs)
reordered = {name: converted[name] for name in obj.variables.keys()}
return dataset_from_variables(reordered, obj._coord_names, indexes, obj.attrs)


def convert_units(obj, units):
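
`convert_units_index` follows a convert, strip, rebuild, re-wrap sequence and collects per-variable failures before raising. A rough stand-alone sketch of that flow, assuming a hypothetical table of conversion factors (`TO_METRE`) in place of a pint registry and plain lists in place of index variables:

```python
# Hypothetical conversion factors to metres; a real registry (pint) handles this.
TO_METRE = {"m": 1.0, "mm": 0.001, "km": 1000.0}


def convert_values(values, src, dst):
    """Convert plain numbers from unit `src` to unit `dst`."""
    if src not in TO_METRE or dst not in TO_METRE:
        raise ValueError(f"cannot convert {src!r} to {dst!r}")
    factor = TO_METRE[src] / TO_METRE[dst]
    return [value * factor for value in values]


def convert_index(values, units, new_units):
    """Return the unit-stripped converted values and the new units, or raise."""
    converted, failed = {}, {}
    for name, data in values.items():
        try:
            converted[name] = convert_values(data, units[name], new_units[name])
        except ValueError as e:
            failed[name] = e  # collect failures instead of stopping at the first one

    if failed:
        raise ValueError("failed to convert index variables:", failed)

    return converted, new_units
```

Collecting failures first means a single error can name every index variable that could not be converted.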
@@ -308,7 +387,12 @@ def strip_units_variable(var):
def strip_units_dataset(obj):
variables = {name: strip_units_variable(var) for name, var in obj.variables.items()}

return dataset_from_variables(variables, obj._coord_names, obj.attrs)
indexes = {
name: (index.index if isinstance(index, PintIndex) else index)
for name, index in obj.xindexes.items()
}

return dataset_from_variables(variables, obj._coord_names, indexes, obj.attrs)


def strip_units(obj):
@@ -403,25 +487,31 @@ def convert(indexer, units):
return converted


def extract_indexer_units(indexer):
if isinstance(indexer, slice):
return slice_extract_units(indexer)
elif isinstance(indexer, (DataArray, Variable)):
return array_extract_units(indexer.data)
else:
return array_extract_units(indexer)
def extract_indexer_units(indexers):
def extract(indexer):
if isinstance(indexer, slice):
return slice_extract_units(indexer)
elif isinstance(indexer, (DataArray, Variable)):
return array_extract_units(indexer.data)
else:
return array_extract_units(indexer)

return {name: extract(indexer) for name, indexer in indexers.items()}

def strip_indexer_units(indexer):
if isinstance(indexer, slice):
return slice(
array_strip_units(indexer.start),
array_strip_units(indexer.stop),
array_strip_units(indexer.step),
)
elif isinstance(indexer, DataArray):
return strip_units(indexer)
elif isinstance(indexer, Variable):
return strip_units_variable(indexer)
else:
return array_strip_units(indexer)

def strip_indexer_units(indexers):
def strip(indexer):
if isinstance(indexer, slice):
return slice(
array_strip_units(indexer.start),
array_strip_units(indexer.stop),
array_strip_units(indexer.step),
)
elif isinstance(indexer, DataArray):
return strip_units(indexer)
elif isinstance(indexer, Variable):
return strip_units_variable(indexer)
else:
return array_strip_units(indexer)

return {name: strip(indexer) for name, indexer in indexers.items()}
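
Both indexer helpers now take a mapping of indexers and apply a per-indexer inner function to each entry. A simplified stand-alone sketch of the pattern, assuming a hypothetical `Quantity` namedtuple in place of `pint.Quantity`; unlike the real `slice_extract_units`, the slice branch here does not check that the units are compatible:

```python
from collections import namedtuple

# Hypothetical stand-in for pint.Quantity
Quantity = namedtuple("Quantity", ["magnitude", "units"])


def extract_indexer_units(indexers):
    def extract(indexer):
        if isinstance(indexer, slice):
            parts = (indexer.start, indexer.stop, indexer.step)
            found = [part.units for part in parts if isinstance(part, Quantity)]
            # simplification: no compatibility check between the units found
            return found[0] if found else None
        elif isinstance(indexer, Quantity):
            return indexer.units
        else:
            return None

    return {name: extract(indexer) for name, indexer in indexers.items()}


def strip_indexer_units(indexers):
    def strip(indexer):
        if isinstance(indexer, slice):
            return slice(strip(indexer.start), strip(indexer.stop), strip(indexer.step))
        elif isinstance(indexer, Quantity):
            return indexer.magnitude
        else:
            return indexer

    return {name: strip(indexer) for name, indexer in indexers.items()}
```

Taking the whole mapping lets callers such as `PintIndex.sel` pass `labels` through without looping over the entries themselves.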
95 changes: 95 additions & 0 deletions pint_xarray/index.py
@@ -0,0 +1,95 @@
from xarray import Variable
from xarray.core.indexes import Index, PandasIndex

from . import conversion


class PintIndex(Index):
    def __init__(self, *, index, units):
        """create a unit-aware MetaIndex

        Parameters
        ----------
        index : xarray.Index
            The wrapped index object.
        units : mapping of hashable to unit-like
            The units of the indexed coordinates
        """
        self.index = index
        self.units = units

    def _replace(self, new_index):
        return self.__class__(index=new_index, units=self.units)

    def create_variables(self, variables=None):
        index_vars = self.index.create_variables(variables)

        index_vars_units = {}
        for name, var in index_vars.items():
            data = conversion.array_attach_units(var.data, self.units[name])
            var_units = Variable(var.dims, data, attrs=var.attrs, encoding=var.encoding)
            index_vars_units[name] = var_units

        return index_vars_units

    @classmethod
    def from_variables(cls, variables, options):
        if len(variables) != 1:
            raise ValueError("can only create a default index from single variables")

        units = options.pop("units", None)
        index = PandasIndex.from_variables(variables, options=options)
        return cls(index=index, units={index.index.name: units})

    @classmethod
    def concat(cls, indexes, dim, positions):
        raise NotImplementedError()

    @classmethod
    def stack(cls, variables, dim):
        raise NotImplementedError()

    def unstack(self):
        raise NotImplementedError()

    def sel(self, labels):
        converted_labels = conversion.convert_indexer_units(labels, self.units)
        stripped_labels = conversion.strip_indexer_units(converted_labels)

        return self.index.sel(stripped_labels)

    def isel(self, indexers):
        subset = self.index.isel(indexers)
        if subset is None:
            return None

        return self._replace(subset)

    def join(self, other, how="inner"):
        raise NotImplementedError()

    def reindex_like(self, other):
        raise NotImplementedError()

    def equals(self, other):
        if not isinstance(other, PintIndex):
            return False

        # for now we require exactly matching units to avoid the potentially expensive conversion
        if self.units != other.units:
            return False

        # last to avoid the potentially expensive comparison
        return self.index.equals(other.index)

    def roll(self, shifts):
        return self._replace(self.index.roll(shifts))

    def rename(self, name_dict, dims_dict):
        return self._replace(self.index.rename(name_dict, dims_dict))

    def __getitem__(self, indexer):
        return self._replace(self.index[indexer])

    def _repr_inline_(self, max_width):
        return f"{self.__class__.__name__}({self.index.__class__.__name__})"
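
The core idea of `PintIndex` is a wrapper: `sel` converts labels into the units of the index and passes bare numbers to the wrapped index, while positional operations such as `isel` are delegated and the result is re-wrapped via `_replace`. A toy sketch of that pattern, where `ListIndex`, `UnitIndex`, and the `TO_METRE` factors are all hypothetical stand-ins (a single coordinate, no xarray or pint, and `sel` takes a magnitude and a unit rather than a mapping of labels):

```python
from fractions import Fraction

# Hypothetical conversion factors to metres; pint does this in the real code.
TO_METRE = {"m": Fraction(1), "mm": Fraction(1, 1000), "km": Fraction(1000)}


class ListIndex:
    """Toy wrapped index: label lookup on a plain list of numbers."""

    def __init__(self, values):
        self.values = list(values)

    def sel(self, label):
        return self.values.index(label)

    def isel(self, positions):
        return ListIndex(self.values[i] for i in positions)


class UnitIndex:
    """Toy counterpart of PintIndex for a single coordinate."""

    def __init__(self, *, index, units):
        self.index = index
        self.units = units

    def _replace(self, new_index):
        return type(self)(index=new_index, units=self.units)

    def sel(self, magnitude, units):
        # convert the label into the index units, then pass the bare number on
        converted = magnitude * TO_METRE[units] / TO_METRE[self.units]
        return self.index.sel(converted)

    def isel(self, positions):
        # positional indexing needs no unit handling: delegate and re-wrap
        return self._replace(self.index.isel(positions))


index = UnitIndex(index=ListIndex([1, 2, 3, 4]), units="m")
position = index.sel(3000, "mm")
subset = index.isel([0, 2])
```

Keeping the units on the wrapper and the values on the wrapped index is what allows most methods to be one-line forwards.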
654 changes: 302 additions & 352 deletions pint_xarray/tests/test_accessors.py

Large diffs are not rendered by default.

215 changes: 136 additions & 79 deletions pint_xarray/tests/test_conversion.py
@@ -1,16 +1,19 @@
import numpy as np
import pandas as pd
import pint
import pytest
from xarray import DataArray, Dataset, Variable
from xarray import Coordinates, DataArray, Dataset, Variable
from xarray.core.indexes import PandasIndex

from pint_xarray import conversion
from pint_xarray.index import PintIndex

from .utils import (
assert_array_equal,
assert_array_units_equal,
assert_identical,
assert_indexer_equal,
assert_indexer_units_equal,
assert_indexers_equal,
)

unit_registry = pint.UnitRegistry()
@@ -245,17 +248,22 @@ def test_attach_units(self, type, units):

q_a = to_quantity(a, units.get("a"))
q_b = to_quantity(b, units.get("b"))
q_x = to_quantity(x, units.get("x"))
q_u = to_quantity(u, units.get("u"))

units_x = units.get("x")
index = PandasIndex(x, dim="x")
if units.get("x") is not None:
index = PintIndex(index=index, units=units.get("x"))

obj = Dataset({"a": ("x", a), "b": ("x", b)}, coords={"u": ("x", u), "x": x})
coords = Coordinates(
coords={"u": Variable("x", q_u), "x": Variable("x", q_x)},
indexes={"x": index},
)
expected = Dataset(
{"a": ("x", q_a), "b": ("x", q_b)},
coords={"u": ("x", q_u), "x": x},
coords=coords,
)
if units_x is not None:
expected.x.attrs["units"] = units_x

if type == "DataArray":
obj = obj["a"]
@@ -264,6 +272,12 @@ def test_attach_units(self, type, units):
actual = conversion.attach_units(obj, units)
assert_identical(actual, expected)

if units.get("x") is None:
assert not isinstance(actual.xindexes["x"], PintIndex)
else:
assert isinstance(actual.xindexes["x"], PintIndex)
assert actual.xindexes["x"].units == {"x": units.get("x")}

@pytest.mark.parametrize("type", ("DataArray", "Dataset"))
def test_attach_unit_attributes(self, type):
units = {"a": "K", "b": "hPa", "u": "m", "x": "s"}
@@ -372,15 +386,19 @@ def test_convert_units(self, type, variant, units, error, match):
q_u = to_quantity(u, original_units.get("u"))
q_x = to_quantity(x, original_units.get("x"))

x_index = PandasIndex(pd.Index(x), "x")
if original_units.get("x") is not None:
x_index = PintIndex(index=x_index, units={"x": original_units.get("x")})

obj = Dataset(
{
"a": ("x", q_a),
"b": ("x", q_b),
},
coords={
"u": ("x", q_u),
"x": ("x", x, {"units": original_units.get("x")}),
},
coords=Coordinates(
{"u": ("x", q_u), "x": ("x", q_x)},
indexes={"x": x_index},
),
)
if type == "DataArray":
obj = obj["a"]
@@ -394,20 +412,22 @@ def test_convert_units(self, type, variant, units, error, match):
expected_a = convert_quantity(q_a, units.get("a", original_units.get("a")))
expected_b = convert_quantity(q_b, units.get("b", original_units.get("b")))
expected_u = convert_quantity(q_u, units.get("u", original_units.get("u")))
expected_x = strip_quantity(convert_quantity(q_x, units.get("x")))
expected_x = convert_quantity(q_x, units.get("x"))
expected_index = PandasIndex(pd.Index(strip_quantity(expected_x)), "x")
if units.get("x") is not None:
expected_index = PintIndex(
index=expected_index, units={"x": units.get("x")}
)

expected = Dataset(
{
"a": ("x", expected_a),
"b": ("x", expected_b),
},
coords={
"u": ("x", expected_u),
"x": (
"x",
expected_x,
{"units": units.get("x", original_units.get("x"))},
),
},
coords=Coordinates(
{"u": ("x", expected_u), "x": ("x", expected_x)},
indexes={"x": expected_index},
),
)

if type == "DataArray":
@@ -416,7 +436,7 @@ def test_convert_units(self, type, variant, units, error, match):
actual = conversion.convert_units(obj, units)

assert conversion.extract_units(actual) == conversion.extract_units(expected)
assert_identical(expected, actual)
assert_identical(actual, expected)

@pytest.mark.parametrize(
"units",
@@ -436,15 +456,22 @@ def test_extract_units(self, type, units):
u = np.linspace(0, 100, 2)
x = np.arange(2)

index = PandasIndex(x, "x")
if units.get("x") is not None:
index = PintIndex(index=index, units={"x": units.get("x")})

obj = Dataset(
{
"a": ("x", to_quantity(a, units.get("a"))),
"b": ("x", to_quantity(b, units.get("b"))),
},
coords={
"u": ("x", to_quantity(u, units.get("u"))),
"x": ("x", x, {"units": units.get("x")}),
},
coords=Coordinates(
{
"u": ("x", to_quantity(u, units.get("u"))),
"x": ("x", to_quantity(x, units.get("x"))),
},
indexes={"x": index},
),
)
if type == "DataArray":
obj = obj["a"]
@@ -499,21 +526,33 @@ def test_extract_unit_attributes(self, obj, expected):
pytest.param(
DataArray(
dims="x",
data=[0, 4, 3] * unit_registry.m,
coords={"u": ("x", [2, 3, 4] * unit_registry.s)},
data=Quantity([0, 4, 3], "kg"),
coords=Coordinates(
{
"u": ("x", Quantity([2, 3, 4], "s")),
"x": Quantity([0, 1, 2], "m"),
},
indexes={},
),
),
{None: None, "u": None},
{None: None, "u": None, "x": None},
id="DataArray",
),
pytest.param(
Dataset(
data_vars={
"a": ("x", [3, 2, 5] * unit_registry.Pa),
"b": ("x", [0, 2, -1] * unit_registry.kg),
"a": ("x", Quantity([3, 2, 5], "Pa")),
"b": ("x", Quantity([0, 2, -1], "kg")),
},
coords={"u": ("x", [2, 3, 4] * unit_registry.s)},
coords=Coordinates(
{
"u": ("x", Quantity([2, 3, 4], "s")),
"x": Quantity([0, 1, 2], "m"),
},
indexes={},
),
),
{"a": None, "b": None, "u": None},
{"a": None, "b": None, "u": None, "x": None},
id="Dataset",
),
),
@@ -694,100 +733,118 @@ def test_convert_indexer_units(self, indexers, units, expected, error, match):
conversion.convert_indexer_units(indexers, units)
else:
actual = conversion.convert_indexer_units(indexers, units)
assert_indexer_equal(actual["x"], expected["x"])
assert_indexer_units_equal(actual["x"], expected["x"])
assert_indexers_equal(actual, expected)
assert_indexer_units_equal(actual, expected)

@pytest.mark.parametrize(
["indexer", "expected"],
["indexers", "expected"],
(
pytest.param(1, None, id="scalar-no units"),
pytest.param(Quantity(1, "m"), Unit("m"), id="scalar-units"),
pytest.param(np.array([1, 2]), None, id="array-no units"),
pytest.param(Quantity([1, 2], "s"), Unit("s"), id="array-units"),
pytest.param(Variable("x", [1, 2]), None, id="Variable-no units"),
pytest.param({"x": 1}, {"x": None}, id="scalar-no units"),
pytest.param({"x": Quantity(1, "m")}, {"x": Unit("m")}, id="scalar-units"),
pytest.param({"x": np.array([1, 2])}, {"x": None}, id="array-no units"),
pytest.param(
{"x": Quantity([1, 2], "s")}, {"x": Unit("s")}, id="array-units"
),
pytest.param(
Variable("x", Quantity([1, 2], "m")), Unit("m"), id="Variable-units"
{"x": Variable("x", [1, 2])}, {"x": None}, id="Variable-no units"
),
pytest.param(DataArray([1, 2], dims="x"), None, id="DataArray-no units"),
pytest.param(
DataArray(Quantity([1, 2], "s"), dims="x"),
Unit("s"),
{"x": Variable("x", Quantity([1, 2], "m"))},
{"x": Unit("m")},
id="Variable-units",
),
pytest.param(
{"x": DataArray([1, 2], dims="x")}, {"x": None}, id="DataArray-no units"
),
pytest.param(
{"x": DataArray(Quantity([1, 2], "s"), dims="x")},
{"x": Unit("s")},
id="DataArray-units",
),
pytest.param(slice(None), None, id="empty slice-no units"),
pytest.param(slice(1, None), None, id="slice-no units"),
pytest.param({"x": slice(None)}, {"x": None}, id="empty slice-no units"),
pytest.param({"x": slice(1, None)}, {"x": None}, id="slice-no units"),
pytest.param(
slice(Quantity(1, "m"), Quantity(2, "m")),
Unit("m"),
{"x": slice(Quantity(1, "m"), Quantity(2, "m"))},
{"x": Unit("m")},
id="slice-identical units",
),
pytest.param(
slice(Quantity(1, "m"), Quantity(2000, "mm")),
Unit("m"),
{"x": slice(Quantity(1, "m"), Quantity(2000, "mm"))},
{"x": Unit("m")},
id="slice-compatible units",
),
pytest.param(
slice(Quantity(1, "m"), Quantity(2, "ms")),
{"x": slice(Quantity(1, "m"), Quantity(2, "ms"))},
ValueError,
id="slice-incompatible units",
),
pytest.param(
slice(1, Quantity(2, "ms")),
{"x": slice(1, Quantity(2, "ms"))},
ValueError,
id="slice-incompatible units-mixed",
),
pytest.param(
slice(1, Quantity(2, "rad")),
Unit("rad"),
{"x": slice(1, Quantity(2, "rad"))},
{"x": Unit("rad")},
id="slice-incompatible units-mixed-dimensionless",
),
),
)
def test_extract_indexer_units(self, indexer, expected):
if expected is not None and not isinstance(expected, Unit):
def test_extract_indexer_units(self, indexers, expected):
if isinstance(expected, type) and issubclass(expected, Exception):
with pytest.raises(expected):
conversion.extract_indexer_units(indexer)
conversion.extract_indexer_units(indexers)
else:
actual = conversion.extract_indexer_units(indexer)
actual = conversion.extract_indexer_units(indexers)
assert actual == expected

@pytest.mark.parametrize(
["indexer", "expected"],
["indexers", "expected"],
(
pytest.param(1, 1, id="scalar-no units"),
pytest.param(Quantity(1, "m"), 1, id="scalar-units"),
pytest.param(np.array([1, 2]), np.array([1, 2]), id="array-no units"),
pytest.param(Quantity([1, 2], "s"), np.array([1, 2]), id="array-units"),
pytest.param({"x": 1}, {"x": 1}, id="scalar-no units"),
pytest.param({"x": Quantity(1, "m")}, {"x": 1}, id="scalar-units"),
pytest.param(
Variable("x", [1, 2]), Variable("x", [1, 2]), id="Variable-no units"
{"x": np.array([1, 2])},
{"x": np.array([1, 2])},
id="array-no units",
),
pytest.param(
{"x": Quantity([1, 2], "s")}, {"x": np.array([1, 2])}, id="array-units"
),
pytest.param(
{"x": Variable("x", [1, 2])},
{"x": Variable("x", [1, 2])},
id="Variable-no units",
),
pytest.param(
Variable("x", Quantity([1, 2], "m")),
Variable("x", [1, 2]),
{"x": Variable("x", Quantity([1, 2], "m"))},
{"x": Variable("x", [1, 2])},
id="Variable-units",
),
pytest.param(
DataArray([1, 2], dims="x"),
DataArray([1, 2], dims="x"),
{"x": DataArray([1, 2], dims="x")},
{"x": DataArray([1, 2], dims="x")},
id="DataArray-no units",
),
pytest.param(
DataArray(Quantity([1, 2], "s"), dims="x"),
DataArray([1, 2], dims="x"),
{"x": DataArray(Quantity([1, 2], "s"), dims="x")},
{"x": DataArray([1, 2], dims="x")},
id="DataArray-units",
),
pytest.param(slice(None), slice(None), id="empty slice-no units"),
pytest.param(slice(1, None), slice(1, None), id="slice-no units"),
pytest.param(
slice(Quantity(1, "m"), Quantity(2, "m")),
slice(1, 2),
{"x": slice(None)}, {"x": slice(None)}, id="empty slice-no units"
),
pytest.param(
{"x": slice(1, None)}, {"x": slice(1, None)}, id="slice-no units"
),
pytest.param(
{"x": slice(Quantity(1, "m"), Quantity(2, "m"))},
{"x": slice(1, 2)},
id="slice-units",
),
),
)
def test_strip_indexer_units(self, indexer, expected):
actual = conversion.strip_indexer_units(indexer)
if isinstance(indexer, DataArray):
assert_identical(actual, expected)
else:
assert_array_equal(actual, expected)
def test_strip_indexer_units(self, indexers, expected):
actual = conversion.strip_indexer_units(indexers)

assert_indexers_equal(actual, expected)
227 changes: 227 additions & 0 deletions pint_xarray/tests/test_index.py
@@ -0,0 +1,227 @@
import numpy as np
import pandas as pd
import pytest
import xarray as xr
from xarray.core.indexes import IndexSelResult, PandasIndex

from pint_xarray import unit_registry as ureg
from pint_xarray.index import PintIndex


def indexer_equal(first, second):
    if type(first) is not type(second):
        return False

    if isinstance(first, np.ndarray):
        return np.all(first == second)
    else:
        return first == second


@pytest.mark.parametrize(
    "base_index",
    [
        PandasIndex(pd.Index([1, 2, 3]), dim="x"),
        PandasIndex(pd.Index([0.1, 0.2, 0.3]), dim="x"),
        PandasIndex(pd.Index([1j, 2j, 3j]), dim="y"),
    ],
)
@pytest.mark.parametrize("units", [ureg.Unit("m"), ureg.Unit("s")])
def test_init(base_index, units):
    index = PintIndex(index=base_index, units=units)

    assert index.index.equals(base_index)
    assert index.units == units


def test_replace():
    old_index = PandasIndex([1, 2, 3], dim="y")
    new_index = PandasIndex([0.1, 0.2, 0.3], dim="x")

    old = PintIndex(index=old_index, units=ureg.Unit("m"))
    new = old._replace(new_index)

    assert new.index.equals(new_index)
    assert new.units == old.units
    # no mutation
    assert old.index.equals(old_index)


@pytest.mark.parametrize(
    ["wrapped_index", "units", "expected"],
    (
        pytest.param(
            PandasIndex(pd.Index([1, 2, 3]), dim="x"),
            {"x": ureg.Unit("m")},
            {"x": xr.Variable("x", ureg.Quantity([1, 2, 3], "m"))},
        ),
        pytest.param(
            PandasIndex(pd.Index([1j, 2j, 3j]), dim="y"),
            {"y": ureg.Unit("ms")},
            {"y": xr.Variable("y", ureg.Quantity([1j, 2j, 3j], "ms"))},
        ),
    ),
)
def test_create_variables(wrapped_index, units, expected):
    index = PintIndex(index=wrapped_index, units=units)

    actual = index.create_variables()

    assert list(actual.keys()) == list(expected.keys())
    assert all([actual[k].equals(expected[k]) for k in expected.keys()])


@pytest.mark.parametrize(
    ["labels", "expected"],
    (
        ({"x": ureg.Quantity(1, "m")}, IndexSelResult(dim_indexers={"x": 0})),
        ({"x": ureg.Quantity(3000, "mm")}, IndexSelResult(dim_indexers={"x": 2})),
        ({"x": ureg.Quantity(0.002, "km")}, IndexSelResult(dim_indexers={"x": 1})),
        (
            {"x": ureg.Quantity([0.002, 0.004], "km")},
            IndexSelResult(dim_indexers={"x": np.array([1, 3])}),
        ),
        (
            {"x": slice(ureg.Quantity(2, "m"), ureg.Quantity(3000, "mm"))},
            IndexSelResult(dim_indexers={"x": slice(1, 3)}),
        ),
    ),
)
def test_sel(labels, expected):
    index = PintIndex(
        index=PandasIndex(pd.Index([1, 2, 3, 4]), dim="x"), units={"x": ureg.Unit("m")}
    )

    actual = index.sel(labels)

    assert isinstance(actual, IndexSelResult)
    assert list(actual.dim_indexers.keys()) == list(expected.dim_indexers.keys())
    assert all(
        [
            indexer_equal(actual.dim_indexers[k], expected.dim_indexers[k])
            for k in expected.dim_indexers.keys()
        ]
    )


@pytest.mark.parametrize(
    "indexers",
    ({"y": 0}, {"y": [1, 2]}, {"y": slice(0, None, 2)}, {"y": xr.Variable("y", [1])}),
)
def test_isel(indexers):
    wrapped_index = PandasIndex(pd.Index([1, 2, 3, 4]), dim="y")
    index = PintIndex(index=wrapped_index, units={"y": ureg.Unit("s")})

    actual = index.isel(indexers)

    wrapped_ = wrapped_index.isel(indexers)
    if wrapped_ is not None:
        expected = PintIndex(
            index=wrapped_index.isel(indexers), units={"y": ureg.Unit("s")}
        )
    else:
        expected = None

    assert (actual is None and expected is None) or actual.equals(expected)


@pytest.mark.parametrize(
    ["other", "expected"],
    (
        (
            PintIndex(
                index=PandasIndex(pd.Index([1, 2, 3, 4]), dim="x"),
                units={"x": ureg.Unit("cm")},
            ),
            True,
        ),
        (PandasIndex(pd.Index([1, 2, 3, 4]), dim="x"), False),
        (
            PintIndex(
                index=PandasIndex(pd.Index([1, 2, 3, 4]), dim="x"),
                units={"x": ureg.Unit("m")},
            ),
            False,
        ),
        (
            PintIndex(
                index=PandasIndex(pd.Index([1, 2, 3, 4]), dim="y"),
                units={"y": ureg.Unit("cm")},
            ),
            False,
        ),
        (
            PintIndex(
                index=PandasIndex(pd.Index([1, 3, 3, 4]), dim="x"),
                units={"x": ureg.Unit("cm")},
            ),
            False,
        ),
    ),
)
def test_equals(other, expected):
    index = PintIndex(
        index=PandasIndex(pd.Index([1, 2, 3, 4]), dim="x"), units={"x": ureg.Unit("cm")}
    )

    actual = index.equals(other)

    assert actual == expected


@pytest.mark.parametrize(
    ["shifts", "expected_index"],
    (
        ({"x": 0}, PandasIndex(pd.Index([-2, -1, 0, 1, 2]), dim="x")),
        ({"x": 1}, PandasIndex(pd.Index([2, -2, -1, 0, 1]), dim="x")),
        ({"x": 2}, PandasIndex(pd.Index([1, 2, -2, -1, 0]), dim="x")),
        ({"x": -1}, PandasIndex(pd.Index([-1, 0, 1, 2, -2]), dim="x")),
        ({"x": -2}, PandasIndex(pd.Index([0, 1, 2, -2, -1]), dim="x")),
    ),
)
def test_roll(shifts, expected_index):
    index = PintIndex(
        index=PandasIndex(pd.Index([-2, -1, 0, 1, 2]), dim="x"),
        units={"x": ureg.Unit("m")},
    )

    actual = index.roll(shifts)
    expected = index._replace(expected_index)

    assert actual.equals(expected)


@pytest.mark.parametrize("dims_dict", ({"y": "x"}, {"y": "z"}))
@pytest.mark.parametrize("name_dict", ({"y2": "y3"}, {"y2": "y1"}))
def test_rename(name_dict, dims_dict):
    wrapped_index = PandasIndex(pd.Index([1, 2], name="y2"), dim="y")
    index = PintIndex(index=wrapped_index, units={"y": ureg.Unit("m")})

    actual = index.rename(name_dict, dims_dict)
    expected = PintIndex(
        index=wrapped_index.rename(name_dict, dims_dict), units=index.units
    )

    assert actual.equals(expected)


@pytest.mark.parametrize("indexer", ([0], slice(0, 2)))
def test_getitem(indexer):
    wrapped_index = PandasIndex(pd.Index([1, 2], name="y2"), dim="y")
    index = PintIndex(index=wrapped_index, units={"y": ureg.Unit("m")})

    actual = index[indexer]
    expected = PintIndex(index=wrapped_index[indexer], units=index.units)

    assert actual.equals(expected)


@pytest.mark.parametrize("wrapped_index", (PandasIndex(pd.Index([1, 2]), dim="x"),))
def test_repr_inline(wrapped_index):
    index = PintIndex(index=wrapped_index, units=ureg.Unit("m"))

    # TODO: parametrize
    actual = index._repr_inline_(90)

    assert "PintIndex" in actual
    assert wrapped_index.__class__.__name__ in actual
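
The expected indexes in `test_roll` follow a circular shift in which positive shifts move values toward the end and wrap around. A small stand-alone sketch that reproduces the expected values from the parametrization with `collections.deque` (`roll` here is a hypothetical helper, not the `PandasIndex.roll` implementation):

```python
from collections import deque


def roll(values, shift):
    """Rotate `values` by `shift` positions; positive shifts move items to the right."""
    rotated = deque(values)
    rotated.rotate(shift)
    return list(rotated)


values = [-2, -1, 0, 1, 2]
for shift in (0, 1, 2, -1, -2):
    print(shift, roll(values, shift))
```

Since rolling only reorders the values, `PintIndex.roll` can forward to the wrapped index and keep the units unchanged.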
28 changes: 28 additions & 0 deletions pint_xarray/tests/utils.py
@@ -1,5 +1,6 @@
import re
from contextlib import contextmanager
from textwrap import indent

import numpy as np
import pytest
@@ -97,6 +98,33 @@ def assert_indexer_equal(a, b):
assert a_ == b_, f"different values: {a_!r} ←→ {b_!r}"


def assert_indexers_equal(first, second):
    __tracebackhide__ = True
    # same keys
    assert first.keys() == second.keys(), "different keys"

    errors = {}
    for name in first:
        first_value = first[name]
        second_value = second[name]

        try:
            assert_indexer_equal(first_value, second_value)
        except AssertionError as e:
            errors[name] = e

    if errors:
        message = "\n".join(
            ["indexers are not equal:"]
            + [
                f" - {name}:\n{indent(str(error), ' ' * 4)}"
                for name, error in errors.items()
            ]
        )

        raise AssertionError(message)
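
`assert_indexers_equal` collects one `AssertionError` per key and reports them together instead of stopping at the first mismatch. A minimal sketch of the same aggregation pattern on plain mappings, where `assert_mappings_equal` is a hypothetical name and values are compared with `==` rather than `assert_indexer_equal`:

```python
from textwrap import indent


def assert_mappings_equal(first, second):
    """Compare two mappings entry by entry and report every mismatch at once."""
    assert first.keys() == second.keys(), "different keys"

    errors = {}
    for name in first:
        try:
            assert first[name] == second[name], (
                f"different values: {first[name]!r} != {second[name]!r}"
            )
        except AssertionError as e:
            errors[name] = e  # remember the failure and keep checking the other keys

    if errors:
        lines = ["entries are not equal:"] + [
            f" - {name}:\n{indent(str(error), ' ' * 4)}"
            for name, error in errors.items()
        ]
        raise AssertionError("\n".join(lines))
```

Indenting each nested message under its key keeps multi-line errors readable when several indexers differ.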


def assert_indexer_units_equal(a, b):
    __tracebackhide__ = True

8 changes: 8 additions & 0 deletions pyproject.toml
@@ -52,3 +52,11 @@ skip_gitignore = "true"
force_to_top = "true"
default_section = "THIRDPARTY"
known_first_party = "pint_xarray"

[tool.coverage.run]
source = ["pint_xarray"]
branch = true

[tool.coverage.report]
show_missing = true
exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"]