Skip to content

Commit

Permalink
ENH: Implement cum* methods for PyArrow strings (#60633)
Browse files Browse the repository at this point in the history
* ENH: Implement cum* methods for PyArrow strings

* cleanup

* Cleanup

* fixup

* Fix extension tests

* xfail test when there is no pyarrow

* mypy fixups

* Change logic & whatsnew

* Change logic & whatsnew

* Fix fixture

* Fixup
  • Loading branch information
rhshadrach authored Jan 13, 2025
1 parent 1708e90 commit b5d4e89
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 11 deletions.
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v2.3.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ Other enhancements
- The semantics for the ``copy`` keyword in ``__array__`` methods (i.e. called
when using ``np.array()`` or ``np.asarray()`` on pandas objects) has been
updated to work correctly with NumPy >= 2 (:issue:`57739`)
- The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` reductions are now implemented for ``StringDtype`` columns when backed by PyArrow (:issue:`60633`)
- The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`)
-

.. ---------------------------------------------------------------------------
.. _whatsnew_230.notable_bug_fixes:
Expand Down
16 changes: 16 additions & 0 deletions pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1317,6 +1317,22 @@ def nullable_string_dtype(request):
return request.param


@pytest.fixture(
params=[
pytest.param(("pyarrow", np.nan), marks=td.skip_if_no("pyarrow")),
pytest.param(("pyarrow", pd.NA), marks=td.skip_if_no("pyarrow")),
]
)
def pyarrow_string_dtype(request):
"""
Parametrized fixture for string dtypes backed by Pyarrow.
* 'str[pyarrow]'
* 'string[pyarrow]'
"""
return pd.StringDtype(*request.param)


@pytest.fixture(
params=[
"python",
Expand Down
55 changes: 55 additions & 0 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
is_list_like,
is_numeric_dtype,
is_scalar,
is_string_dtype,
pandas_dtype,
)
from pandas.core.dtypes.dtypes import DatetimeTZDtype
Expand Down Expand Up @@ -1619,6 +1620,9 @@ def _accumulate(
------
NotImplementedError : subclass does not define accumulations
"""
if is_string_dtype(self):
return self._str_accumulate(name=name, skipna=skipna, **kwargs)

pyarrow_name = {
"cummax": "cumulative_max",
"cummin": "cumulative_min",
Expand Down Expand Up @@ -1654,6 +1658,57 @@ def _accumulate(

return type(self)(result)

def _str_accumulate(
self, name: str, *, skipna: bool = True, **kwargs
) -> ArrowExtensionArray | ExtensionArray:
"""
Accumulate implementation for strings, see `_accumulate` docstring for details.
pyarrow.compute does not implement these methods for strings.
"""
if name == "cumprod":
msg = f"operation '{name}' not supported for dtype '{self.dtype}'"
raise TypeError(msg)

# We may need to strip out trailing NA values
tail: pa.array | None = None
na_mask: pa.array | None = None
pa_array = self._pa_array
np_func = {
"cumsum": np.cumsum,
"cummin": np.minimum.accumulate,
"cummax": np.maximum.accumulate,
}[name]

if self._hasna:
na_mask = pc.is_null(pa_array)
if pc.all(na_mask) == pa.scalar(True):
return type(self)(pa_array)
if skipna:
if name == "cumsum":
pa_array = pc.fill_null(pa_array, "")
else:
# We can retain the running min/max by forward/backward filling.
pa_array = pc.fill_null_forward(pa_array)
pa_array = pc.fill_null_backward(pa_array)
else:
# When not skipping NA values, the result should be null from
# the first NA value onward.
idx = pc.index(na_mask, True).as_py()
tail = pa.nulls(len(pa_array) - idx, type=pa_array.type)
pa_array = pa_array[:idx]

# error: Cannot call function of unknown type
pa_result = pa.array(np_func(pa_array), type=pa_array.type) # type: ignore[operator]

if tail is not None:
pa_result = pa.concat_arrays([pa_result, tail])
elif na_mask is not None:
pa_result = pc.if_else(na_mask, None, pa_result)

result = type(self)(pa_result)
return result

def _reduce_pyarrow(self, name: str, *, skipna: bool = True, **kwargs) -> pa.Scalar:
"""
Return a pyarrow scalar result of performing the reduction operation.
Expand Down
9 changes: 6 additions & 3 deletions pandas/tests/apply/test_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import numpy as np
import pytest

from pandas.compat import WASM
from pandas.compat import (
HAS_PYARROW,
WASM,
)

from pandas.core.dtypes.common import is_number

Expand Down Expand Up @@ -163,10 +166,10 @@ def test_agg_cython_table_transform_series(request, series, func, expected):
# GH21224
# test transforming functions in
# pandas.core.base.SelectionMixin._cython_table (cumprod, cumsum)
if series.dtype == "string" and func == "cumsum":
if series.dtype == "string" and func == "cumsum" and not HAS_PYARROW:
request.applymarker(
pytest.mark.xfail(
raises=(TypeError, NotImplementedError),
raises=NotImplementedError,
reason="TODO(infer_string) cumsum not yet implemented for string",
)
)
Expand Down
5 changes: 3 additions & 2 deletions pandas/tests/extension/base/accumulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool):
try:
alt = ser.astype("float64")
except TypeError:
# e.g. Period can't be cast to float64
except (TypeError, ValueError):
# e.g. Period can't be cast to float64 (TypeError)
# String can't be cast to float64 (ValueError)
alt = ser.astype(object)

result = getattr(ser, op_name)(skipna=skipna)
Expand Down
15 changes: 10 additions & 5 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,13 +393,12 @@ def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
# attribute "pyarrow_dtype"
pa_type = ser.dtype.pyarrow_dtype # type: ignore[union-attr]

if (
pa.types.is_string(pa_type)
or pa.types.is_binary(pa_type)
or pa.types.is_decimal(pa_type)
):
if pa.types.is_binary(pa_type) or pa.types.is_decimal(pa_type):
if op_name in ["cumsum", "cumprod", "cummax", "cummin"]:
return False
elif pa.types.is_string(pa_type):
if op_name == "cumprod":
return False
elif pa.types.is_boolean(pa_type):
if op_name in ["cumprod", "cummax", "cummin"]:
return False
Expand All @@ -414,6 +413,12 @@ def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
def test_accumulate_series(self, data, all_numeric_accumulations, skipna, request):
pa_type = data.dtype.pyarrow_dtype
op_name = all_numeric_accumulations

if pa.types.is_string(pa_type) and op_name in ["cumsum", "cummin", "cummax"]:
# https://github.com/pandas-dev/pandas/pull/60633
# Doesn't fit test structure, tested in series/test_cumulative.py instead.
return

ser = pd.Series(data)

if not self._supports_accumulation(ser, op_name):
Expand Down
10 changes: 10 additions & 0 deletions pandas/tests/extension/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

from pandas.compat import HAS_PYARROW

from pandas.core.dtypes.base import StorageExtensionDtype

import pandas as pd
import pandas._testing as tm
from pandas.api.types import is_string_dtype
Expand Down Expand Up @@ -192,6 +194,14 @@ def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
and op_name in ("any", "all")
)

def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
assert isinstance(ser.dtype, StorageExtensionDtype)
return ser.dtype.storage == "pyarrow" and op_name in [
"cummin",
"cummax",
"cumsum",
]

def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
dtype = cast(StringDtype, tm.get_dtype(obj))
if op_name in ["__add__", "__radd__"]:
Expand Down
54 changes: 54 additions & 0 deletions pandas/tests/series/test_cumulative.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
tests.frame.test_cumulative
"""

import re

import numpy as np
import pytest

Expand Down Expand Up @@ -227,3 +229,55 @@ def test_cumprod_timedelta(self):
ser = pd.Series([pd.Timedelta(days=1), pd.Timedelta(days=3)])
with pytest.raises(TypeError, match="cumprod not supported for Timedelta"):
ser.cumprod()

@pytest.mark.parametrize(
"data, op, skipna, expected_data",
[
([], "cumsum", True, []),
([], "cumsum", False, []),
(["x", "z", "y"], "cumsum", True, ["x", "xz", "xzy"]),
(["x", "z", "y"], "cumsum", False, ["x", "xz", "xzy"]),
(["x", pd.NA, "y"], "cumsum", True, ["x", pd.NA, "xy"]),
(["x", pd.NA, "y"], "cumsum", False, ["x", pd.NA, pd.NA]),
([pd.NA, "x", "y"], "cumsum", True, [pd.NA, "x", "xy"]),
([pd.NA, "x", "y"], "cumsum", False, [pd.NA, pd.NA, pd.NA]),
([pd.NA, pd.NA, pd.NA], "cumsum", True, [pd.NA, pd.NA, pd.NA]),
([pd.NA, pd.NA, pd.NA], "cumsum", False, [pd.NA, pd.NA, pd.NA]),
([], "cummin", True, []),
([], "cummin", False, []),
(["y", "z", "x"], "cummin", True, ["y", "y", "x"]),
(["y", "z", "x"], "cummin", False, ["y", "y", "x"]),
(["y", pd.NA, "x"], "cummin", True, ["y", pd.NA, "x"]),
(["y", pd.NA, "x"], "cummin", False, ["y", pd.NA, pd.NA]),
([pd.NA, "y", "x"], "cummin", True, [pd.NA, "y", "x"]),
([pd.NA, "y", "x"], "cummin", False, [pd.NA, pd.NA, pd.NA]),
([pd.NA, pd.NA, pd.NA], "cummin", True, [pd.NA, pd.NA, pd.NA]),
([pd.NA, pd.NA, pd.NA], "cummin", False, [pd.NA, pd.NA, pd.NA]),
([], "cummax", True, []),
([], "cummax", False, []),
(["x", "z", "y"], "cummax", True, ["x", "z", "z"]),
(["x", "z", "y"], "cummax", False, ["x", "z", "z"]),
(["x", pd.NA, "y"], "cummax", True, ["x", pd.NA, "y"]),
(["x", pd.NA, "y"], "cummax", False, ["x", pd.NA, pd.NA]),
([pd.NA, "x", "y"], "cummax", True, [pd.NA, "x", "y"]),
([pd.NA, "x", "y"], "cummax", False, [pd.NA, pd.NA, pd.NA]),
([pd.NA, pd.NA, pd.NA], "cummax", True, [pd.NA, pd.NA, pd.NA]),
([pd.NA, pd.NA, pd.NA], "cummax", False, [pd.NA, pd.NA, pd.NA]),
],
)
def test_cum_methods_pyarrow_strings(
self, pyarrow_string_dtype, data, op, skipna, expected_data
):
# https://github.com/pandas-dev/pandas/pull/60633
ser = pd.Series(data, dtype=pyarrow_string_dtype)
method = getattr(ser, op)
expected = pd.Series(expected_data, dtype=pyarrow_string_dtype)
result = method(skipna=skipna)
tm.assert_series_equal(result, expected)

def test_cumprod_pyarrow_strings(self, pyarrow_string_dtype, skipna):
# https://github.com/pandas-dev/pandas/pull/60633
ser = pd.Series(list("xyz"), dtype=pyarrow_string_dtype)
msg = re.escape(f"operation 'cumprod' not supported for dtype '{ser.dtype}'")
with pytest.raises(TypeError, match=msg):
ser.cumprod(skipna=skipna)

0 comments on commit b5d4e89

Please sign in to comment.