Skip to content

Commit e27f572

Browse files
committed
Merge branch 'main' into bugfix-scalar-arr-casting
* main: (feat): Support for `pandas` `ExtensionArray` (pydata#8723) Migrate datatree mapping.py (pydata#8948) Add mypy to dev dependencies (pydata#8947) Convert 360_day calendars by choosing random dates to drop or add (pydata#8603)
2 parents e3493b0 + 9eb180b commit e27f572

25 files changed

+562
-90
lines changed

Diff for: doc/whats-new.rst

+10-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,13 @@ v2024.04.0 (unreleased)
2222

2323
New Features
2424
~~~~~~~~~~~~
25-
25+
- New "random" method for converting to and from 360_day calendars (:pull:`8603`).
26+
By `Pascal Bourgault <https://github.com/aulemahal>`_.
27+
- Xarray now makes a best attempt not to coerce :py:class:`pandas.api.extensions.ExtensionArray` to a numpy array
28+
by supporting 1D `ExtensionArray` objects internally where possible. Thus, `Dataset`s initialized with a `pd.Categorical`,
29+
for example, will retain the object. However, one cannot do operations that are not possible on the `ExtensionArray`
30+
then, such as broadcasting.
31+
By `Ilan Gold <https://github.com/ilan-gold>`_.
2632

2733
Breaking changes
2834
~~~~~~~~~~~~~~~~
@@ -34,6 +40,9 @@ Bug fixes
3440

3541
Internal Changes
3642
~~~~~~~~~~~~~~~~
43+
- Migrates ``datatree_mapping`` functionality into ``xarray/core`` (:pull:`8948`)
44+
By `Matt Savoie <https://github.com/flamingbear>`_, `Owen Littlejohns
45+
<https://github.com/owenlittlejohns>`_ and `Tom Nicholas <https://github.com/TomNicholas>`_.
3746

3847

3948
.. _whats-new.2024.03.0:

Diff for: properties/test_pandas_roundtrip.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
from hypothesis import given # isort:skip
1818

1919
numeric_dtypes = st.one_of(
20-
npst.unsigned_integer_dtypes(), npst.integer_dtypes(), npst.floating_dtypes()
20+
npst.unsigned_integer_dtypes(endianness="="),
21+
npst.integer_dtypes(endianness="="),
22+
npst.floating_dtypes(endianness="="),
2123
)
2224

2325
numeric_series = numeric_dtypes.flatmap(lambda dt: pdst.series(dtype=dt))

Diff for: pyproject.toml

+8-6
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ accel = ["scipy", "bottleneck", "numbagg", "flox", "opt_einsum"]
3333
complete = ["xarray[accel,io,parallel,viz,dev]"]
3434
dev = [
3535
"hypothesis",
36+
"mypy",
3637
"pre-commit",
3738
"pytest",
3839
"pytest-cov",
@@ -86,8 +87,8 @@ exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"]
8687
[tool.mypy]
8788
enable_error_code = "redundant-self"
8889
exclude = [
89-
'xarray/util/generate_.*\.py',
90-
'xarray/datatree_/.*\.py',
90+
'xarray/util/generate_.*\.py',
91+
'xarray/datatree_/.*\.py',
9192
]
9293
files = "xarray"
9394
show_error_codes = true
@@ -98,8 +99,8 @@ warn_unused_ignores = true
9899

99100
# Ignore mypy errors for modules imported from datatree_.
100101
[[tool.mypy.overrides]]
101-
module = "xarray.datatree_.*"
102102
ignore_errors = true
103+
module = "xarray.datatree_.*"
103104

104105
# Much of the numerical computing stack doesn't have type annotations yet.
105106
[[tool.mypy.overrides]]
@@ -129,6 +130,7 @@ module = [
129130
"opt_einsum.*",
130131
"pandas.*",
131132
"pooch.*",
133+
"pyarrow.*",
132134
"pydap.*",
133135
"pytest.*",
134136
"scipy.*",
@@ -255,6 +257,9 @@ target-version = "py39"
255257
# E402: module level import not at top of file
256258
# E501: line too long - let black worry about that
257259
# E731: do not assign a lambda expression, use a def
260+
extend-safe-fixes = [
261+
"TID252", # absolute imports
262+
]
258263
ignore = [
259264
"E402",
260265
"E501",
@@ -268,9 +273,6 @@ select = [
268273
"I", # isort
269274
"UP", # Pyupgrade
270275
]
271-
extend-safe-fixes = [
272-
"TID252", # absolute imports
273-
]
274276

275277
[tool.ruff.lint.per-file-ignores]
276278
# don't enforce absolute imports

Diff for: xarray/coding/calendar_ops.py

+45-8
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def convert_calendar(
6464
The target calendar name.
6565
dim : str
6666
Name of the time coordinate in the input DataArray or Dataset.
67-
align_on : {None, 'date', 'year'}
67+
align_on : {None, 'date', 'year', 'random'}
6868
Must be specified when either the source or target is a `"360_day"`
6969
calendar; ignored otherwise. See Notes.
7070
missing : any, optional
@@ -143,6 +143,16 @@ def convert_calendar(
143143
will be dropped as there are no equivalent dates in a standard calendar.
144144
145145
This option is best used with data on a frequency coarser than daily.
146+
147+
"random"
148+
Similar to "year", each day of year of the source is mapped to another day of year
149+
of the target. However, instead of having always the same missing days according
150+
to the source and target years, here 5 days are chosen randomly, one for each fifth
151+
of the year. However, February 29th is always missing when converting to a leap year,
152+
or its value is dropped when converting from a leap year. This is similar to the method
153+
used in the LOCA dataset (see Pierce, Cayan, and Thrasher (2014). doi:10.1175/JHM-D-14-0082.1).
154+
155+
This option is best used on daily data.
146156
"""
147157
from xarray.core.dataarray import DataArray
148158

@@ -174,14 +184,20 @@ def convert_calendar(
174184

175185
out = obj.copy()
176186

177-
if align_on == "year":
187+
if align_on in ["year", "random"]:
178188
# Special case for conversion involving 360_day calendar
179-
# Instead of translating dates directly, this tries to keep the position within a year similar.
180-
181-
new_doy = time.groupby(f"{dim}.year").map(
182-
_interpolate_day_of_year, target_calendar=calendar, use_cftime=use_cftime
183-
)
184-
189+
if align_on == "year":
190+
# Instead of translating dates directly, this tries to keep the position within a year similar.
191+
new_doy = time.groupby(f"{dim}.year").map(
192+
_interpolate_day_of_year,
193+
target_calendar=calendar,
194+
use_cftime=use_cftime,
195+
)
196+
elif align_on == "random":
197+
# The 5 days to remove are randomly chosen, one for each of the five 72-day periods of the year.
198+
new_doy = time.groupby(f"{dim}.year").map(
199+
_random_day_of_year, target_calendar=calendar, use_cftime=use_cftime
200+
)
185201
# Convert the source datetimes, but override the day of year with our new day of years.
186202
out[dim] = DataArray(
187203
[
@@ -229,6 +245,27 @@ def _interpolate_day_of_year(time, target_calendar, use_cftime):
229245
).astype(int)
230246

231247

248+
def _random_day_of_year(time, target_calendar, use_cftime):
249+
"""Return a day of year in the new calendar.
250+
251+
Removes Feb 29th and five other days chosen randomly within five sections of 72 days.
252+
"""
253+
year = int(time.dt.year[0])
254+
source_calendar = time.dt.calendar
255+
new_doy = np.arange(360) + 1
256+
rm_idx = np.random.default_rng().integers(0, 72, 5) + 72 * np.arange(5)
257+
if source_calendar == "360_day":
258+
for idx in rm_idx:
259+
new_doy[idx + 1 :] = new_doy[idx + 1 :] + 1
260+
if _days_in_year(year, target_calendar, use_cftime) == 366:
261+
new_doy[new_doy >= 60] = new_doy[new_doy >= 60] + 1
262+
elif target_calendar == "360_day":
263+
new_doy = np.insert(new_doy, rm_idx - np.arange(5), -1)
264+
if _days_in_year(year, source_calendar, use_cftime) == 366:
265+
new_doy = np.insert(new_doy, 60, -1)
266+
return new_doy[time.dt.dayofyear - 1]
267+
268+
232269
def _convert_to_new_calendar_with_new_day_of_year(
233270
date, day_of_year, calendar, use_cftime
234271
):

Diff for: xarray/core/dataset.py

+47-13
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from typing import IO, TYPE_CHECKING, Any, Callable, Generic, Literal, cast, overload
2525

2626
import numpy as np
27+
from pandas.api.types import is_extension_array_dtype
2728

2829
# remove once numpy 2.0 is the oldest supported version
2930
try:
@@ -6852,10 +6853,13 @@ def reduce(
68526853
if (
68536854
# Some reduction functions (e.g. std, var) need to run on variables
68546855
# that don't have the reduce dims: PR5393
6855-
not reduce_dims
6856-
or not numeric_only
6857-
or np.issubdtype(var.dtype, np.number)
6858-
or (var.dtype == np.bool_)
6856+
not is_extension_array_dtype(var.dtype)
6857+
and (
6858+
not reduce_dims
6859+
or not numeric_only
6860+
or np.issubdtype(var.dtype, np.number)
6861+
or (var.dtype == np.bool_)
6862+
)
68596863
):
68606864
# prefer to aggregate over axis=None rather than
68616865
# axis=(0, 1) if they will be equivalent, because
@@ -7168,13 +7172,37 @@ def to_pandas(self) -> pd.Series | pd.DataFrame:
71687172
)
71697173

71707174
def _to_dataframe(self, ordered_dims: Mapping[Any, int]):
7171-
columns = [k for k in self.variables if k not in self.dims]
7175+
columns_in_order = [k for k in self.variables if k not in self.dims]
7176+
non_extension_array_columns = [
7177+
k
7178+
for k in columns_in_order
7179+
if not is_extension_array_dtype(self.variables[k].data)
7180+
]
7181+
extension_array_columns = [
7182+
k
7183+
for k in columns_in_order
7184+
if is_extension_array_dtype(self.variables[k].data)
7185+
]
71727186
data = [
71737187
self._variables[k].set_dims(ordered_dims).values.reshape(-1)
7174-
for k in columns
7188+
for k in non_extension_array_columns
71757189
]
71767190
index = self.coords.to_index([*ordered_dims])
7177-
return pd.DataFrame(dict(zip(columns, data)), index=index)
7191+
broadcasted_df = pd.DataFrame(
7192+
dict(zip(non_extension_array_columns, data)), index=index
7193+
)
7194+
for extension_array_column in extension_array_columns:
7195+
extension_array = self.variables[extension_array_column].data.array
7196+
index = self[self.variables[extension_array_column].dims[0]].data
7197+
extension_array_df = pd.DataFrame(
7198+
{extension_array_column: extension_array},
7199+
index=self[self.variables[extension_array_column].dims[0]].data,
7200+
)
7201+
extension_array_df.index.name = self.variables[extension_array_column].dims[
7202+
0
7203+
]
7204+
broadcasted_df = broadcasted_df.join(extension_array_df)
7205+
return broadcasted_df[columns_in_order]
71787206

71797207
def to_dataframe(self, dim_order: Sequence[Hashable] | None = None) -> pd.DataFrame:
71807208
"""Convert this dataset into a pandas.DataFrame.
@@ -7321,11 +7349,13 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
73217349
"cannot convert a DataFrame with a non-unique MultiIndex into xarray"
73227350
)
73237351

7324-
# Cast to a NumPy array first, in case the Series is a pandas Extension
7325-
# array (which doesn't have a valid NumPy dtype)
7326-
# TODO: allow users to control how this casting happens, e.g., by
7327-
# forwarding arguments to pandas.Series.to_numpy?
7328-
arrays = [(k, np.asarray(v)) for k, v in dataframe.items()]
7352+
arrays = []
7353+
extension_arrays = []
7354+
for k, v in dataframe.items():
7355+
if not is_extension_array_dtype(v):
7356+
arrays.append((k, np.asarray(v)))
7357+
else:
7358+
extension_arrays.append((k, v))
73297359

73307360
indexes: dict[Hashable, Index] = {}
73317361
index_vars: dict[Hashable, Variable] = {}
@@ -7339,6 +7369,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
73397369
xr_idx = PandasIndex(lev, dim)
73407370
indexes[dim] = xr_idx
73417371
index_vars.update(xr_idx.create_variables())
7372+
arrays += [(k, np.asarray(v)) for k, v in extension_arrays]
7373+
extension_arrays = []
73427374
else:
73437375
index_name = idx.name if idx.name is not None else "index"
73447376
dims = (index_name,)
@@ -7352,7 +7384,9 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
73527384
obj._set_sparse_data_from_dataframe(idx, arrays, dims)
73537385
else:
73547386
obj._set_numpy_data_from_dataframe(idx, arrays, dims)
7355-
return obj
7387+
for name, extension_array in extension_arrays:
7388+
obj[name] = (dims, extension_array)
7389+
return obj[dataframe.columns] if len(dataframe.columns) else obj
73567390

73577391
def to_dask_dataframe(
73587392
self, dim_order: Sequence[Hashable] | None = None, set_index: bool = False

Diff for: xarray/core/datatree.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@
1818
from xarray.core.coordinates import DatasetCoordinates
1919
from xarray.core.dataarray import DataArray
2020
from xarray.core.dataset import Dataset, DataVariables
21+
from xarray.core.datatree_mapping import (
22+
TreeIsomorphismError,
23+
check_isomorphic,
24+
map_over_subtree,
25+
)
2126
from xarray.core.indexes import Index, Indexes
2227
from xarray.core.merge import dataset_update_method
2328
from xarray.core.options import OPTIONS as XR_OPTS
@@ -36,11 +41,6 @@
3641
from xarray.datatree_.datatree.formatting_html import (
3742
datatree_repr as datatree_repr_html,
3843
)
39-
from xarray.datatree_.datatree.mapping import (
40-
TreeIsomorphismError,
41-
check_isomorphic,
42-
map_over_subtree,
43-
)
4444
from xarray.datatree_.datatree.ops import (
4545
DataTreeArithmeticMixin,
4646
MappedDatasetMethodsMixin,

Diff for: xarray/datatree_/datatree/mapping.py renamed to xarray/core/datatree_mapping.py

+16-17
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44
import sys
55
from itertools import repeat
66
from textwrap import dedent
7-
from typing import TYPE_CHECKING, Callable, Tuple
7+
from typing import TYPE_CHECKING, Callable
88

99
from xarray import DataArray, Dataset
10-
1110
from xarray.core.iterators import LevelOrderIter
1211
from xarray.core.treenode import NodePath, TreeNode
1312

@@ -84,14 +83,13 @@ def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> s
8483
for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b)):
8584
path_a, path_b = node_a.path, node_b.path
8685

87-
if require_names_equal:
88-
if node_a.name != node_b.name:
89-
diff = dedent(
90-
f"""\
86+
if require_names_equal and node_a.name != node_b.name:
87+
diff = dedent(
88+
f"""\
9189
Node '{path_a}' in the left object has name '{node_a.name}'
9290
Node '{path_b}' in the right object has name '{node_b.name}'"""
93-
)
94-
return diff
91+
)
92+
return diff
9593

9694
if len(node_a.children) != len(node_b.children):
9795
diff = dedent(
@@ -125,7 +123,7 @@ def map_over_subtree(func: Callable) -> Callable:
125123
func : callable
126124
Function to apply to datasets with signature:
127125
128-
`func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`.
126+
`func(*args, **kwargs) -> Union[DataTree, Iterable[DataTree]]`.
129127
130128
(i.e. func must accept at least one Dataset and return at least one Dataset.)
131129
Function will not be applied to any nodes without datasets.
@@ -154,7 +152,7 @@ def map_over_subtree(func: Callable) -> Callable:
154152
# TODO inspect function to work out immediately if the wrong number of arguments were passed for it?
155153

156154
@functools.wraps(func)
157-
def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
155+
def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]:
158156
"""Internal function which maps func over every node in tree, returning a tree of the results."""
159157
from xarray.core.datatree import DataTree
160158

@@ -259,19 +257,18 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
259257
return _map_over_subtree
260258

261259

262-
def _handle_errors_with_path_context(path):
260+
def _handle_errors_with_path_context(path: str):
263261
"""Wraps given function so that if it fails it also raises path to node on which it failed."""
264262

265263
def decorator(func):
266264
def wrapper(*args, **kwargs):
267265
try:
268266
return func(*args, **kwargs)
269267
except Exception as e:
270-
if sys.version_info >= (3, 11):
271-
# Add the context information to the error message
272-
e.add_note(
273-
f"Raised whilst mapping function over node with path {path}"
274-
)
268+
# Add the context information to the error message
269+
add_note(
270+
e, f"Raised whilst mapping function over node with path {path}"
271+
)
275272
raise
276273

277274
return wrapper
@@ -287,7 +284,9 @@ def add_note(err: BaseException, msg: str) -> None:
287284
err.add_note(msg)
288285

289286

290-
def _check_single_set_return_values(path_to_node, obj):
287+
def _check_single_set_return_values(
288+
path_to_node: str, obj: Dataset | DataArray | tuple[Dataset | DataArray]
289+
):
291290
"""Check types returned from single evaluation of func, and return number of return values received from func."""
292291
if isinstance(obj, (Dataset, DataArray)):
293292
return 1

0 commit comments

Comments
 (0)