[backport 2.3.x] ENH: Implement cum* methods for PyArrow strings (#60633) (#60753)

(cherry picked from commit b5d4e89)

Co-authored-by: Richard Shadrach <[email protected]>
jorisvandenbossche and rhshadrach authored Jan 22, 2025
1 parent c638e69 commit 1c0c351
Showing 8 changed files with 157 additions and 10 deletions.
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v2.3.0.rst
@@ -35,8 +35,8 @@ Other enhancements
- The semantics for the ``copy`` keyword in ``__array__`` methods (i.e. called
  when using ``np.array()`` or ``np.asarray()`` on pandas objects) has been
  updated to raise FutureWarning with NumPy >= 2 (:issue:`60340`)
- The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` accumulations are now implemented for ``StringDtype`` columns when backed by PyArrow (:issue:`60633`)
- The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`)
-

.. ---------------------------------------------------------------------------
.. _whatsnew_230.notable_bug_fixes:
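For context, a minimal sketch of the behavior the new whatsnew entry describes, with expected values mirroring the new tests in pandas/tests/series/test_cumulative.py (assumes pandas 2.3 with PyArrow installed; "string[pyarrow]" is one of the dtypes those tests exercise):

import pandas as pd

ser = pd.Series(["x", "z", "y"], dtype="string[pyarrow]")

ser.cumsum()   # running concatenation: ["x", "xz", "xzy"]
ser.cummin()   # running lexicographic minimum: ["x", "x", "x"]
ser.cummax()   # running lexicographic maximum: ["x", "z", "z"]

ser.cumprod()  # still raises TypeError: cumprod is not supported for string dtypes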
16 changes: 16 additions & 0 deletions pandas/conftest.py
@@ -1273,6 +1273,22 @@ def nullable_string_dtype(request):
    return request.param


@pytest.fixture(
    params=[
        pytest.param(("pyarrow", np.nan), marks=td.skip_if_no("pyarrow")),
        pytest.param(("pyarrow", pd.NA), marks=td.skip_if_no("pyarrow")),
    ]
)
def pyarrow_string_dtype(request):
    """
    Parametrized fixture for string dtypes backed by PyArrow.

    * 'str[pyarrow]'
    * 'string[pyarrow]'
    """
    return pd.StringDtype(*request.param)


@pytest.fixture(
    params=[
        "python",
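For orientation, the two parametrizations of the pyarrow_string_dtype fixture above correspond to the following dtypes (a sketch only; assumes PyArrow is installed and uses the same positional storage/na_value arguments the fixture passes to pd.StringDtype):

import numpy as np
import pandas as pd

pd.StringDtype("pyarrow", np.nan)  # the 'str[pyarrow]' variant: NaN as the missing value
pd.StringDtype("pyarrow", pd.NA)   # the 'string[pyarrow]' variant: pd.NA as the missing value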
55 changes: 55 additions & 0 deletions pandas/core/arrays/arrow/array.py
@@ -45,6 +45,7 @@
    is_list_like,
    is_numeric_dtype,
    is_scalar,
    is_string_dtype,
)
from pandas.core.dtypes.dtypes import DatetimeTZDtype
from pandas.core.dtypes.missing import isna
@@ -1617,6 +1618,9 @@ def _accumulate(
        ------
        NotImplementedError : subclass does not define accumulations
        """
        if is_string_dtype(self):
            return self._str_accumulate(name=name, skipna=skipna, **kwargs)

        pyarrow_name = {
            "cummax": "cumulative_max",
            "cummin": "cumulative_min",
@@ -1652,6 +1656,57 @@ def _accumulate(

        return type(self)(result)

    def _str_accumulate(
        self, name: str, *, skipna: bool = True, **kwargs
    ) -> ArrowExtensionArray | ExtensionArray:
        """
        Accumulate implementation for strings, see `_accumulate` docstring for details.
        pyarrow.compute does not implement these methods for strings.
        """
        if name == "cumprod":
            msg = f"operation '{name}' not supported for dtype '{self.dtype}'"
            raise TypeError(msg)

        # We may need to strip out trailing NA values
        tail: pa.array | None = None
        na_mask: pa.array | None = None
        pa_array = self._pa_array
        np_func = {
            "cumsum": np.cumsum,
            "cummin": np.minimum.accumulate,
            "cummax": np.maximum.accumulate,
        }[name]

        if self._hasna:
            na_mask = pc.is_null(pa_array)
            if pc.all(na_mask) == pa.scalar(True):
                return type(self)(pa_array)
            if skipna:
                if name == "cumsum":
                    pa_array = pc.fill_null(pa_array, "")
                else:
                    # We can retain the running min/max by forward/backward filling.
                    pa_array = pc.fill_null_forward(pa_array)
                    pa_array = pc.fill_null_backward(pa_array)
            else:
                # When not skipping NA values, the result should be null from
                # the first NA value onward.
                idx = pc.index(na_mask, True).as_py()
                tail = pa.nulls(len(pa_array) - idx, type=pa_array.type)
                pa_array = pa_array[:idx]

        # error: Cannot call function of unknown type
        pa_result = pa.array(np_func(pa_array), type=pa_array.type)  # type: ignore[operator]

        if tail is not None:
            pa_result = pa.concat_arrays([pa_result, tail])
        elif na_mask is not None:
            pa_result = pc.if_else(na_mask, None, pa_result)

        result = type(self)(pa_result)
        return result

    def _reduce_pyarrow(self, name: str, *, skipna: bool = True, **kwargs) -> pa.Scalar:
        """
        Return a pyarrow scalar result of performing the reduction operation.
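As a rough illustration of the NA handling in _str_accumulate above, with expected values taken from the new tests in pandas/tests/series/test_cumulative.py (assumes a PyArrow-backed StringDtype series):

import pandas as pd

ser = pd.Series(["x", pd.NA, "y"], dtype="string[pyarrow]")

# skipna=True: nulls are filled before accumulating ("" for cumsum,
# forward/backward fill for cummin/cummax), then the original NA positions
# are masked back to null via pc.if_else(na_mask, None, pa_result).
ser.cumsum(skipna=True)   # ["x", <NA>, "xy"]

# skipna=False: values before the first null are accumulated and a null
# "tail" of matching length is concatenated, so the result is null from
# the first NA onward.
ser.cumsum(skipna=False)  # ["x", <NA>, <NA>]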
10 changes: 8 additions & 2 deletions pandas/tests/apply/test_str.py
@@ -4,6 +4,8 @@
import numpy as np
import pytest

from pandas.compat import HAS_PYARROW

from pandas.core.dtypes.common import is_number

from pandas import (
@@ -170,10 +172,14 @@ def test_agg_cython_table_transform_series(request, series, func, expected):
    # GH21224
    # test transforming functions in
    # pandas.core.base.SelectionMixin._cython_table (cumprod, cumsum)
    if series.dtype == "string" and func in ("cumsum", np.cumsum, np.nancumsum):
    if (
        series.dtype == "string"
        and func in ("cumsum", np.cumsum, np.nancumsum)
        and not HAS_PYARROW
    ):
        request.applymarker(
            pytest.mark.xfail(
                raises=(TypeError, NotImplementedError),
                raises=NotImplementedError,
                reason="TODO(infer_string) cumsum not yet implemented for string",
            )
        )
5 changes: 3 additions & 2 deletions pandas/tests/extension/base/accumulate.py
@@ -18,8 +18,9 @@ def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
    def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool):
        try:
            alt = ser.astype("float64")
        except TypeError:
            # e.g. Period can't be cast to float64
        except (TypeError, ValueError):
            # e.g. Period can't be cast to float64 (TypeError)
            # String can't be cast to float64 (ValueError)
            alt = ser.astype(object)

        result = getattr(ser, op_name)(skipna=skipna)
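The widened except clause above matters for the string dtypes: casting non-numeric strings to float64 raises ValueError rather than TypeError, as the new comment notes. A small sketch of the fallback (assumes a PyArrow-backed StringDtype series):

import pandas as pd

ser = pd.Series(["x", "y"], dtype="string[pyarrow]")
try:
    alt = ser.astype("float64")   # non-numeric strings cannot be parsed as floats
except (TypeError, ValueError):
    alt = ser.astype(object)      # fall back to object, as check_accumulate does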
15 changes: 10 additions & 5 deletions pandas/tests/extension/test_arrow.py
@@ -388,13 +388,12 @@ def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
        # attribute "pyarrow_dtype"
        pa_type = ser.dtype.pyarrow_dtype  # type: ignore[union-attr]

        if (
            pa.types.is_string(pa_type)
            or pa.types.is_binary(pa_type)
            or pa.types.is_decimal(pa_type)
        ):
        if pa.types.is_binary(pa_type) or pa.types.is_decimal(pa_type):
            if op_name in ["cumsum", "cumprod", "cummax", "cummin"]:
                return False
        elif pa.types.is_string(pa_type):
            if op_name == "cumprod":
                return False
        elif pa.types.is_boolean(pa_type):
            if op_name in ["cumprod", "cummax", "cummin"]:
                return False
@@ -409,6 +408,12 @@ def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
    def test_accumulate_series(self, data, all_numeric_accumulations, skipna, request):
        pa_type = data.dtype.pyarrow_dtype
        op_name = all_numeric_accumulations

        if pa.types.is_string(pa_type) and op_name in ["cumsum", "cummin", "cummax"]:
            # https://github.com/pandas-dev/pandas/pull/60633
            # Doesn't fit test structure, tested in series/test_cumulative.py instead.
            return

        ser = pd.Series(data)

        if not self._supports_accumulation(ser, op_name):
10 changes: 10 additions & 0 deletions pandas/tests/extension/test_string.py
@@ -23,6 +23,8 @@

from pandas.compat import HAS_PYARROW

from pandas.core.dtypes.base import StorageExtensionDtype

import pandas as pd
import pandas._testing as tm
from pandas.api.types import is_string_dtype
@@ -196,6 +198,14 @@ def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
            and op_name in ("any", "all")
        )

    def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
        assert isinstance(ser.dtype, StorageExtensionDtype)
        return ser.dtype.storage == "pyarrow" and op_name in [
            "cummin",
            "cummax",
            "cumsum",
        ]

    def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
        dtype = cast(StringDtype, tm.get_dtype(obj))
        if op_name in ["__add__", "__radd__"]:
54 changes: 54 additions & 0 deletions pandas/tests/series/test_cumulative.py
@@ -6,6 +6,8 @@
tests.frame.test_cumulative
"""

import re

import numpy as np
import pytest

@@ -155,3 +157,55 @@ def test_cumprod_timedelta(self):
        ser = pd.Series([pd.Timedelta(days=1), pd.Timedelta(days=3)])
        with pytest.raises(TypeError, match="cumprod not supported for Timedelta"):
            ser.cumprod()

    @pytest.mark.parametrize(
        "data, op, skipna, expected_data",
        [
            ([], "cumsum", True, []),
            ([], "cumsum", False, []),
            (["x", "z", "y"], "cumsum", True, ["x", "xz", "xzy"]),
            (["x", "z", "y"], "cumsum", False, ["x", "xz", "xzy"]),
            (["x", pd.NA, "y"], "cumsum", True, ["x", pd.NA, "xy"]),
            (["x", pd.NA, "y"], "cumsum", False, ["x", pd.NA, pd.NA]),
            ([pd.NA, "x", "y"], "cumsum", True, [pd.NA, "x", "xy"]),
            ([pd.NA, "x", "y"], "cumsum", False, [pd.NA, pd.NA, pd.NA]),
            ([pd.NA, pd.NA, pd.NA], "cumsum", True, [pd.NA, pd.NA, pd.NA]),
            ([pd.NA, pd.NA, pd.NA], "cumsum", False, [pd.NA, pd.NA, pd.NA]),
            ([], "cummin", True, []),
            ([], "cummin", False, []),
            (["y", "z", "x"], "cummin", True, ["y", "y", "x"]),
            (["y", "z", "x"], "cummin", False, ["y", "y", "x"]),
            (["y", pd.NA, "x"], "cummin", True, ["y", pd.NA, "x"]),
            (["y", pd.NA, "x"], "cummin", False, ["y", pd.NA, pd.NA]),
            ([pd.NA, "y", "x"], "cummin", True, [pd.NA, "y", "x"]),
            ([pd.NA, "y", "x"], "cummin", False, [pd.NA, pd.NA, pd.NA]),
            ([pd.NA, pd.NA, pd.NA], "cummin", True, [pd.NA, pd.NA, pd.NA]),
            ([pd.NA, pd.NA, pd.NA], "cummin", False, [pd.NA, pd.NA, pd.NA]),
            ([], "cummax", True, []),
            ([], "cummax", False, []),
            (["x", "z", "y"], "cummax", True, ["x", "z", "z"]),
            (["x", "z", "y"], "cummax", False, ["x", "z", "z"]),
            (["x", pd.NA, "y"], "cummax", True, ["x", pd.NA, "y"]),
            (["x", pd.NA, "y"], "cummax", False, ["x", pd.NA, pd.NA]),
            ([pd.NA, "x", "y"], "cummax", True, [pd.NA, "x", "y"]),
            ([pd.NA, "x", "y"], "cummax", False, [pd.NA, pd.NA, pd.NA]),
            ([pd.NA, pd.NA, pd.NA], "cummax", True, [pd.NA, pd.NA, pd.NA]),
            ([pd.NA, pd.NA, pd.NA], "cummax", False, [pd.NA, pd.NA, pd.NA]),
        ],
    )
    def test_cum_methods_pyarrow_strings(
        self, pyarrow_string_dtype, data, op, skipna, expected_data
    ):
        # https://github.com/pandas-dev/pandas/pull/60633
        ser = pd.Series(data, dtype=pyarrow_string_dtype)
        method = getattr(ser, op)
        expected = pd.Series(expected_data, dtype=pyarrow_string_dtype)
        result = method(skipna=skipna)
        tm.assert_series_equal(result, expected)

    def test_cumprod_pyarrow_strings(self, pyarrow_string_dtype, skipna):
        # https://github.com/pandas-dev/pandas/pull/60633
        ser = pd.Series(list("xyz"), dtype=pyarrow_string_dtype)
        msg = re.escape(f"operation 'cumprod' not supported for dtype '{ser.dtype}'")
        with pytest.raises(TypeError, match=msg):
            ser.cumprod(skipna=skipna)
