Skip to content

Commit 1cba840

Browse files
committed
Add xarray-specific encoding convention for pd.IntervalArray
Closes #2847 xref #8005 (comment)
1 parent 5ce69b2 commit 1cba840

File tree

5 files changed

+107
-4
lines changed

5 files changed

+107
-4
lines changed

xarray/coding/times.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1358,9 +1358,9 @@ def __init__(
13581358
self.time_unit = time_unit
13591359

13601360
def encode(self, variable: Variable, name: T_Name = None) -> Variable:
1361-
if np.issubdtype(
1362-
variable.data.dtype, np.datetime64
1363-
) or contains_cftime_datetimes(variable):
1361+
if np.issubdtype(variable.dtype, np.datetime64) or contains_cftime_datetimes(
1362+
variable
1363+
):
13641364
dims, data, attrs, encoding = unpack_for_encoding(variable)
13651365

13661366
units = encoding.pop("units", None)
@@ -1477,7 +1477,7 @@ def __init__(
14771477
self._emit_decode_timedelta_future_warning = False
14781478

14791479
def encode(self, variable: Variable, name: T_Name = None) -> Variable:
1480-
if np.issubdtype(variable.data.dtype, np.timedelta64):
1480+
if np.issubdtype(variable.dtype, np.timedelta64):
14811481
dims, data, attrs, encoding = unpack_for_encoding(variable)
14821482
dtype = encoding.get("dtype", None)
14831483
units = encoding.pop("units", None)

xarray/coding/variables.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -698,3 +698,53 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
698698

699699
def decode(self, variable: Variable, name: T_Name = None) -> Variable:
700700
raise NotImplementedError()
701+
702+
703+
class IntervalCoder(VariableCoder):
    """
    Xarray-specific Interval Coder to roundtrip 1D pd.IntervalArray objects.

    Encoding stacks the left and right interval edges along a new leading
    dimension named ``encoded_bounds_dim`` and records ``closed``, ``dtype``
    and ``bounds_dim`` in the variable's attributes. Decoding looks for that
    marker and rebuilds a ``pd.arrays.IntervalArray`` from the two edge
    arrays.
    """

    # Marker written to ``attrs["dtype"]`` on encode and required on decode.
    encoded_dtype = "pandas_interval"
    # Name of the dimension (length 2 after encoding) holding left/right edges.
    encoded_bounds_dim = "__xarray_bounds__"

    def encode(self, variable: Variable, name: T_Name = None) -> Variable:
        """Encode an interval-dtype variable as stacked left/right edges.

        Parameters
        ----------
        variable : Variable
            Variable to encode. Returned unchanged unless its dtype is a
            ``pd.IntervalDtype``.
        name : T_Name, optional
            Variable name, forwarded to ``safe_setitem``.

        Returns
        -------
        Variable
            Variable with dims ``(encoded_bounds_dim, *dims)`` whose data is
            ``[left, right]`` stacked on axis 0, or the input unchanged.
        """
        dtype = variable.dtype
        if not isinstance(dtype, pd.IntervalDtype):
            return variable

        dims, data, attrs, encoding = unpack_for_encoding(variable)

        # Row 0 holds the left edges, row 1 the right edges.
        new_data = np.stack([data.left, data.right], axis=0)
        dims = (self.encoded_bounds_dim, *dims)
        safe_setitem(attrs, "closed", dtype.closed, name=name)
        safe_setitem(attrs, "dtype", self.encoded_dtype, name=name)
        safe_setitem(attrs, "bounds_dim", self.encoded_bounds_dim, name=name)
        return Variable(dims, new_data, attrs, encoding, fastpath=True)

    def decode(self, variable: Variable, name: T_Name = None) -> Variable:
        """Rebuild a ``pd.arrays.IntervalArray`` from an encoded variable.

        Parameters
        ----------
        variable : Variable
            Variable to decode. Returned unchanged unless
            ``attrs["dtype"] == encoded_dtype`` and ``encoded_bounds_dim`` is
            one of its dimensions.
        name : T_Name, optional
            Variable name, used in error messages and forwarded to ``pop_to``.

        Returns
        -------
        Variable
            Variable along the single non-bounds dimension holding the
            interval array, or the input unchanged.

        Raises
        ------
        ValueError
            If the variable is marked as interval-encoded but does not have
            exactly two dimensions.
        """
        is_encoded_interval = (
            variable.attrs.get("dtype", None) == self.encoded_dtype
            and self.encoded_bounds_dim in variable.dims
        )
        if not is_encoded_interval:
            return variable

        if variable.ndim > 2:
            raise ValueError(
                f"Cannot decode intervals for variable named {name!r} with more than two dimensions."
            )
        if variable.ndim < 2:
            # Only the bounds dimension is present: there is nothing to index
            # the intervals by, so this is not a valid encoded interval array.
            raise ValueError(
                f"Cannot decode intervals for variable named {name!r}: expected "
                f"the bounds dimension {self.encoded_bounds_dim!r} plus exactly "
                "one other dimension."
            )

        # Only attrs and encoding are needed from the unpacked variable.
        _, _, attrs, encoding = unpack_for_decoding(variable)
        pop_to(attrs, encoding, "dtype", name=name)
        pop_to(attrs, encoding, "bounds_dim", name=name)
        closed = pop_to(attrs, encoding, "closed", name=name)

        # Pick the dimension that is *not* the bounds dimension instead of
        # assuming the bounds dimension comes first; isel below selects by
        # name, so either layout decodes onto the correct dimension.
        (new_dim,) = (
            dim for dim in variable.dims if dim != self.encoded_bounds_dim
        )

        variable = variable.load()
        left = variable.isel({self.encoded_bounds_dim: 0}).data
        right = variable.isel({self.encoded_bounds_dim: 1}).data
        new_data = pd.arrays.IntervalArray.from_arrays(left, right, closed=closed)
        return Variable(dims=new_dim, data=new_data, attrs=attrs, encoding=encoding)

xarray/conventions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ def encode_cf_variable(
9090
ensure_not_multiindex(var, name=name)
9191

9292
for coder in [
93+
# IntervalCoder must be before CFDatetimeCoder,
94+
# so we can first encode the interval, then datetimes if necessary
95+
variables.IntervalCoder(),
9396
CFDatetimeCoder(),
9497
CFTimedeltaCoder(),
9598
variables.CFScaleOffsetCoder(),
@@ -238,6 +241,8 @@ def decode_cf_variable(
238241
)
239242
var = decode_times.decode(var, name=name)
240243

244+
var = variables.IntervalCoder().decode(var)
245+
241246
if decode_endianness and not var.dtype.isnative:
242247
var = variables.EndianCoder().decode(var)
243248
original_dtype = var.dtype

xarray/tests/test_coding.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,31 @@ def test_decode_signed_from_unsigned(bits) -> None:
147147
decoded = coder.decode(encoded)
148148
assert decoded.dtype == signed_dtype
149149
assert decoded.values == original_values
150+
151+
152+
@pytest.mark.parametrize(
    "data",
    [
        [1, 2, 3, 4],
        np.array([1, 2, 3, 4], dtype=float),
        pd.date_range("2001-01-01", "2002-01-01", freq="MS"),
    ],
)
@pytest.mark.parametrize("closed", ["left", "right", "both", "neither"])
def test_roundtrip_pandas_interval(data, closed) -> None:
    """IntervalCoder encodes to stacked bounds and decodes back unchanged."""
    interval_coder = variables.IntervalCoder()
    original = xr.Variable("time", pd.IntervalIndex.from_breaks(data, closed=closed))

    # Consecutive breaks form the left and right edges of each interval.
    left_edges, right_edges = data[:-1], data[1:]
    expected_attrs = {
        "dtype": "pandas_interval",
        "bounds_dim": "__xarray_bounds__",
        "closed": closed,
    }
    expected = xr.Variable(
        dims=("__xarray_bounds__", "time"),
        data=np.stack([left_edges, right_edges], axis=0),
        attrs=expected_attrs,
    )

    encoded = interval_coder.encode(original)
    assert_identical(encoded, expected)

    roundtripped = interval_coder.decode(encoded)
    assert_identical(roundtripped, original)

xarray/tests/test_conventions.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,3 +675,23 @@ def test_decode_cf_variables_decode_timedelta_warning() -> None:
675675

676676
with pytest.warns(FutureWarning, match="decode_timedelta"):
677677
conventions.decode_cf_variables(variables, {})
678+
679+
680+
@pytest.mark.parametrize(
    "data",
    [
        [1, 2, 3, 4],
        np.array([1, 2, 3, 4], dtype=float),
        pd.date_range("2001-01-01", "2002-01-01", freq="MS"),
    ],
)
@pytest.mark.parametrize("closed", ["left", "right", "both", "neither"])
def test_roundtrip_pandas_interval(data, closed) -> None:
    """Interval variables survive encode_cf_variable -> decode_cf_variable."""
    intervals = pd.IntervalIndex.from_breaks(data, closed=closed)
    original = Variable("time", intervals)
    encoded = conventions.encode_cf_variable(original)

    breaks_are_datetimes = isinstance(data, pd.DatetimeIndex)
    if breaks_are_datetimes:
        # make sure we've encoded datetimes.
        for key in ("units", "calendar"):
            assert key in encoded.attrs

    decoded = conventions.decode_cf_variable("foo", encoded)
    assert_identical(decoded, original)

0 commit comments

Comments
 (0)