
Commit 951c217

(fix): more typing!
1 parent 4c20d8a commit 951c217

3 files changed: +39 −28 lines changed

properties/test_pandas_roundtrip.py

Lines changed: 2 additions & 2 deletions
@@ -93,7 +93,7 @@ def test_roundtrip_dataarray(data, arr) -> None:
 def test_roundtrip_dataset(dataset: Dataset) -> None:
     df = dataset.to_dataframe()
     assert isinstance(df, pd.DataFrame)
-    roundtripped = xr.Dataset(df)
+    roundtripped = xr.Dataset.from_dataframe(df)
     xr.testing.assert_identical(dataset, roundtripped)


@@ -103,7 +103,7 @@ def test_roundtrip_pandas_series(ser, ix_name) -> None:
     ser.index.name = ix_name
     arr = xr.DataArray(ser)
     roundtripped = arr.to_pandas()
-    pd.testing.assert_series_equal(ser, roundtripped)
+    pd.testing.assert_series_equal(ser, roundtripped)  # type: ignore[arg-type]
     xr.testing.assert_identical(arr, roundtripped.to_xarray())

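For reference, a minimal standalone sketch of the roundtrip this property test exercises, using a small hand-written dataset rather than a hypothesis-generated one; `Dataset.from_dataframe` is the explicitly typed classmethod the diff switches to, while the old `xr.Dataset(df)` spelling went through the generic constructor:

import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical, hand-written dataset; the real test draws datasets from hypothesis strategies.
ds = xr.Dataset(
    {"a": ("x", np.arange(3)), "b": ("x", np.linspace(0.0, 1.0, 3))},
    coords={"x": [10, 20, 30]},
)

df = ds.to_dataframe()
assert isinstance(df, pd.DataFrame)

# Dataset.from_dataframe is annotated to take a DataFrame and return a Dataset,
# which keeps the type checker happy without an ignore comment.
roundtripped = xr.Dataset.from_dataframe(df)
xr.testing.assert_identical(ds, roundtripped)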

xarray/core/dtypes.py

Lines changed: 22 additions & 15 deletions
@@ -2,9 +2,10 @@
 
 import functools
 from collections.abc import Iterable
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, TypeVar, cast
 
 import numpy as np
+from pandas.api.extensions import ExtensionDtype
 from pandas.api.types import is_extension_array_dtype
 
 from xarray.compat import array_api_compat, npcompat
@@ -14,7 +15,6 @@
 if TYPE_CHECKING:
     from typing import Any
 
-    from pandas.api.extensions import ExtensionDtype
 
 # Use as a sentinel value to indicate a dtype appropriate NA value.
 NA = utils.ReprObject("<NA>")
@@ -53,10 +53,10 @@ def __eq__(self, other):
     (np.bytes_, np.str_),  # numpy promotes to unicode
 )
 
+T_dtype = TypeVar("T_dtype", np.dtype, ExtensionDtype)
 
-def maybe_promote(
-    dtype: np.dtype | ExtensionDtype,
-) -> tuple[np.dtype | ExtensionDtype, Any]:
+
+def maybe_promote(dtype: T_dtype) -> tuple[T_dtype, Any]:
     """Simpler equivalent of pandas.core.common._maybe_promote
 
     Parameters
@@ -72,10 +72,12 @@ def maybe_promote(
     dtype_: np.typing.DTypeLike
     fill_value: Any
     if is_extension_array_dtype(dtype):
-        return dtype, dtype.na_value
-    else:
-        dtype = cast(np.dtype, dtype)
-    if HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()):
+        return dtype, cast(ExtensionDtype, dtype).na_value  # type: ignore[redundant-cast]
+    if not isinstance(dtype, np.dtype):
+        raise TypeError(
+            f"dtype {dtype} must be one of an extension array dtype or numpy dtype"
+        )
+    elif HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()):
         # for now, we always promote string dtypes to object for consistency with existing behavior
         # TODO: refactor this once we have a better way to handle numpy vlen-string dtypes
         dtype_ = object
@@ -235,10 +237,14 @@ def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool:
 
 
 def maybe_promote_to_variable_width(
-    array_or_dtype: np.typing.ArrayLike | np.typing.DTypeLike,
+    array_or_dtype: np.typing.ArrayLike
+    | np.typing.DTypeLike
+    | ExtensionDtype
+    | str
+    | bytes,
     *,
     should_return_str_or_bytes: bool = False,
-) -> np.typing.ArrayLike | np.typing.DTypeLike:
+) -> np.typing.ArrayLike | np.typing.DTypeLike | ExtensionDtype:
     if isinstance(array_or_dtype, str | bytes):
         if should_return_str_or_bytes:
             return array_or_dtype
@@ -256,7 +262,10 @@ def maybe_promote_to_variable_width(
 
 
 def should_promote_to_object(
-    arrays_and_dtypes: Iterable[np.typing.ArrayLike | np.typing.DTypeLike], xp
+    arrays_and_dtypes: Iterable[
+        np.typing.ArrayLike | np.typing.DTypeLike | ExtensionDtype
+    ],
+    xp,
 ) -> bool:
     """
     Test whether the given arrays_and_dtypes, when evaluated individually, match the
@@ -286,9 +295,7 @@ def should_promote_to_object(
 
 
 def result_type(
-    *arrays_and_dtypes: list[
-        np.typing.ArrayLike | np.typing.DTypeLike | ExtensionDtype
-    ],
+    *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike | ExtensionDtype,
     xp=None,
 ) -> np.dtype:
     """Like np.result_type, but with type promotion rules matching pandas.

xarray/core/extension_array.py

Lines changed: 15 additions & 11 deletions
@@ -68,7 +68,7 @@ def __extension_duck_array__astype(
 ) -> ExtensionArray:
     if (
         not (
-            is_extension_array_dtype(array_or_scalar) or is_extension_array_dtype(dtype)  # type: ignore[arg-dtype]
+            is_extension_array_dtype(array_or_scalar) or is_extension_array_dtype(dtype)
         )
         or casting != "unsafe"
         or not subok
@@ -96,21 +96,23 @@ def as_extension_array(
     copy: bool = False,
 ) -> ExtensionArray:
     if is_scalar(array_or_scalar):
-        return dtype.construct_array_type()._from_sequence(  # type: ignore[attr-defined]
+        return dtype.construct_array_type()._from_sequence(  # type: ignore[union-attr]
             [array_or_scalar], dtype=dtype
         )
     else:
-        return array_or_scalar.astype(dtype, copy=copy)
+        return array_or_scalar.astype(dtype, copy=copy)  # type: ignore[union-attr]
 
 
 @implements(np.result_type)
 def __extension_duck_array__result_type(
-    *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
+    *arrays_and_dtypes: list[
+        np.typing.ArrayLike | np.typing.DTypeLike | ExtensionDtype | ExtensionArray
+    ],
 ) -> DtypeObj:
     extension_arrays_and_dtypes: list[ExtensionDtype | ExtensionArray] = [
-        x
+        cast(ExtensionDtype | ExtensionArray, x)
         for x in arrays_and_dtypes
-        if is_extension_array_dtype(x)  # type: ignore[arg-type, misc]
+        if is_extension_array_dtype(x)
     ]
     if not extension_arrays_and_dtypes:
         return NotImplemented
@@ -120,15 +122,17 @@ def __extension_duck_array__result_type(
         for x in extension_arrays_and_dtypes
     ]
     scalars: list[Scalar] = [
-        x for x in arrays_and_dtypes if is_scalar(x) and x not in {pd.NA, np.nan}
+        cast(Scalar, x)
+        for x in arrays_and_dtypes
+        if is_scalar(x) and x not in {pd.NA, np.nan}
     ]
     # other_stuff could include:
     # - arrays such as pd.ABCSeries, np.ndarray, or other array-api duck arrays
     # - dtypes such as pd.DtypeObj, np.dtype, or other array-api duck dtypes
     other_stuff = [
         x
         for x in arrays_and_dtypes
-        if not is_extension_array_dtype(x) and not is_scalar(x)  # type: ignore[arg-type, misc]
+        if not is_extension_array_dtype(x) and not is_scalar(x)
     ]
     # We implement one special case: when possible, preserve Categoricals (avoid promoting
     # to object) by merging the categories of all given Categoricals + scalars + NA.
@@ -178,14 +182,14 @@ def __extension_duck_array__concatenate(
 
 @implements(np.where)
 def __extension_duck_array__where(
-    condition: T_ExtensionArray | np.ArrayLike,
+    condition: T_ExtensionArray | np.typing.ArrayLike,
     x: T_ExtensionArray,
-    y: T_ExtensionArray | np.ArrayLike,
+    y: T_ExtensionArray | np.typing.ArrayLike,
 ) -> T_ExtensionArray:
     return cast(T_ExtensionArray, pd.Series(x).where(condition, y).array)  # type: ignore[arg-type]
 
 
-def _replace_duck(args, replacer: Callable[[PandasExtensionArray]]) -> list:
+def _replace_duck(args, replacer: Callable[[PandasExtensionArray], list]) -> list:
     args_as_list = list(args)
     for index, value in enumerate(args_as_list):
         if isinstance(value, PandasExtensionArray):
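Several changes in this file swap blanket # type: ignore comments for explicit cast(...) calls inside comprehensions. is_extension_array_dtype() filters values at runtime, but the checker does not narrow the heterogeneous *arrays_and_dtypes element type from that call, so the cast spells out what actually survives the filter. A small sketch of that pattern with a hypothetical split_extension_inputs() helper:

from typing import Any, cast

import numpy as np
import pandas as pd
from pandas.api.extensions import ExtensionArray, ExtensionDtype
from pandas.api.types import is_extension_array_dtype


def split_extension_inputs(
    *arrays_and_dtypes: Any,
) -> tuple[list[ExtensionDtype | ExtensionArray], list[Any]]:
    """Hypothetical helper: separate extension-typed inputs from everything else."""
    extension: list[ExtensionDtype | ExtensionArray] = [
        # is_extension_array_dtype() only filters at runtime; the cast tells the
        # type checker which element type survives the filter.
        cast(ExtensionDtype | ExtensionArray, x)
        for x in arrays_and_dtypes
        if is_extension_array_dtype(x)
    ]
    other = [x for x in arrays_and_dtypes if not is_extension_array_dtype(x)]
    return extension, other


ext, rest = split_extension_inputs(
    pd.Categorical(["a", "b"]), pd.CategoricalDtype(), np.dtype("int64"), 3.14
)
assert len(ext) == 2 and len(rest) == 2

At runtime the cast is a no-op that returns its argument unchanged; it only affects what the checker infers for the comprehension's elements.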
