Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ What's New

v0.7.0 (unreleased)
-------------------
- ``interpolate``: Added support for timedelta coordinates;
fixed support for datetime coordinates other than ``M8[ns]`` in xarray 2025.1.2.


.. _whats-new.0.6.0:
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,13 @@ filterwarnings = [
"error",
# FIXME these need to be fixed in xarray
"ignore:__array_wrap__ must accept context and return:DeprecationWarning",
"ignore:__array__ implementation doesn't accept a copy keyword, so passing copy=False failed:DeprecationWarning",
# FIXME these need to be looked at
'ignore:.*will no longer be implicitly promoted:FutureWarning',
'ignore:.*updating coordinate .* with a PandasMultiIndex would leave the multi-index level coordinates .* in an inconsistent state:FutureWarning',
# xarray vs. pandas upstream
# These have been fixed; still needed for Python 3.9 CI
"ignore:__array__ implementation doesn't accept a copy keyword, so passing copy=False failed:DeprecationWarning",
'ignore:Converting non-nanosecond precision datetime:UserWarning',
'ignore:Converting non-nanosecond precision timedelta:UserWarning',
]

[tool.coverage.report]
Expand Down
17 changes: 12 additions & 5 deletions xarray_extras/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@ def splrep(a: xarray.DataArray, dim: Hashable, k: int = 3) -> xarray.Dataset:
a = a.transpose(dim, *[d for d in a.dims if d != dim])
x = a.coords[dim].values

if x.dtype.kind == "M":
if x.dtype.kind == "M": # datetime
# Same treatment will be applied to x_new.
# Allow x_new.dtype==M8[D] and x.dtype==M8[ns], or vice versa
x = x.astype("M8[ns]").astype(float)
elif x.dtype.kind == "m": # timedelta
x = x.astype("m8[ns]").astype(float)

t = kernels.make_interp_knots(x, k, check_finite=False)
if k < 2:
Expand Down Expand Up @@ -181,15 +183,20 @@ def splev(
if t.shape != (c.sizes[dim] + k + 1,):
raise ValueError("Interpolated dimension has been sliced")

if x_new.dtype.kind == "M":
if x_new.dtype.kind == "M": # datetime
# Note that we're modifying the x_new values, not the x_new coords
# xarray datetime objects are always in ns
x_new = x_new.astype(float)
x_new = x_new.astype("M8[ns]").astype(float)
elif x_new.dtype.kind == "m": # timedelta
x_new = x_new.astype("m8[ns]").astype(float)

if extrapolate == "clip":
x = tck.coords[dim].values
if x.dtype.kind == "M":

if x.dtype.kind == "M": # datetime
x = x.astype("M8[ns]").astype(float)
elif x.dtype.kind == "m": # timedelta
x = x.astype("m8[ns]").astype(float)

x_new = np.clip(x_new, x[0].tolist(), x[-1].tolist())
extrapolate = False

Expand Down
33 changes: 27 additions & 6 deletions xarray_extras/tests/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ def test_nonfloat(x_dtype, x_new_dtype):


@pytest.mark.filterwarnings("ignore:Converting non-nanosecond precision datetime ")
@pytest.mark.parametrize("x_new_dtype", ["<M8[D]", "<M8[s]", "<M8[ns]"])
@pytest.mark.parametrize("x_dtype", ["<M8[D]", "<M8[s]", "<M8[ns]"])
@pytest.mark.parametrize("x_new_dtype", ["M8[D]", "M8[s]", "M8[ns]"])
@pytest.mark.parametrize("x_dtype", ["M8[D]", "M8[s]", "M8[ns]"])
def test_dates(x_dtype, x_new_dtype):
"""
- Test mismatched date formats on x and x_new
Expand All @@ -202,15 +202,36 @@ def test_dates(x_dtype, x_new_dtype):
y = DataArray(
[10, 20],
dims=["x"],
coords={"x": np.array(["2000-01-01", "2001-01-01"]).astype(x_dtype)},
coords={"x": np.array(["2000-01-01", "2001-01-01"], dtype=x_dtype)},
)
x_new = np.array(["2000-04-20", "2002-07-28"]).astype(x_new_dtype)
expect = DataArray(
[13.00546448, 20.0], dims=["x"], coords={"x": x_new.astype("<M8[ns]")}
x_new = np.array(["2000-04-20", "2002-07-28"], dtype=x_new_dtype)
expect = DataArray([13.00546448, 20.0], dims=["x"], coords={"x": x_new})

tck = splrep(y, "x", k=1)
y_new = splev(x_new, tck, extrapolate="clip")
assert y_new.x.dtype == expect.x.dtype
assert_allclose(expect, y_new, atol=1e-6, rtol=0)


@pytest.mark.filterwarnings("ignore:Converting non-nanosecond precision datetime ")
@pytest.mark.parametrize("x_new_dtype", ["m8[D]", "m8[s]", "m8[ns]"])
@pytest.mark.parametrize("x_dtype", ["m8[D]", "m8[s]", "m8[ns]"])
def test_timedeltas(x_dtype, x_new_dtype):
"""
- Test mismatched date formats on x and x_new
- Test clip extrapolation on test_dates
"""
y = DataArray(
[10, 20],
dims=["x"],
coords={"x": np.array([30, 50], dtype="m8[D]").astype(x_dtype)},
)
x_new = np.array([35, 45], dtype="m8[D]").astype(x_new_dtype)
expect = DataArray([12.5, 17.5], dims=["x"], coords={"x": x_new})

tck = splrep(y, "x", k=1)
y_new = splev(x_new, tck, extrapolate="clip")
assert y_new.x.dtype == expect.x.dtype
assert_allclose(expect, y_new, atol=1e-6, rtol=0)


Expand Down
Loading