-
-
Notifications
You must be signed in to change notification settings - Fork 18.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: Implement cum* methods for PyArrow strings #60633
Changes from 8 commits
8a00df4
170f2e2
4ccf0d4
be726f0
bf38cef
1f8e36e
dd8fcbe
ed895b9
46ff2c1
f2b448d
4d11a1d
d3468cc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,6 +41,7 @@ | |
is_list_like, | ||
is_numeric_dtype, | ||
is_scalar, | ||
is_string_dtype, | ||
pandas_dtype, | ||
) | ||
from pandas.core.dtypes.dtypes import DatetimeTZDtype | ||
|
@@ -1619,6 +1620,9 @@ def _accumulate( | |
------ | ||
NotImplementedError : subclass does not define accumulations | ||
""" | ||
if is_string_dtype(self): | ||
return self._str_accumulate(name=name, skipna=skipna, **kwargs) | ||
|
||
pyarrow_name = { | ||
"cummax": "cumulative_max", | ||
"cummin": "cumulative_min", | ||
|
@@ -1654,6 +1658,64 @@ def _accumulate( | |
|
||
return type(self)(result) | ||
|
||
def _str_accumulate( | ||
self, name: str, *, skipna: bool = True, **kwargs | ||
) -> ArrowExtensionArray | ExtensionArray: | ||
""" | ||
Accumulate implementation for strings, see `_accumulate` docstring for details. | ||
|
||
pyarrow.compute does not implement these methods for strings. | ||
""" | ||
if name == "cumprod": | ||
msg = f"operation '{name}' not supported for dtype '{self.dtype}'" | ||
raise TypeError(msg) | ||
|
||
# We may need to strip out leading / trailing NA values | ||
head: pa.array | None = None | ||
tail: pa.array | None = None | ||
pa_array = self._pa_array | ||
np_func = { | ||
"cumsum": np.cumsum, | ||
"cummin": np.minimum.accumulate, | ||
"cummax": np.maximum.accumulate, | ||
}[name] | ||
|
||
if self._hasna: | ||
if skipna: | ||
if name == "cumsum": | ||
pa_array = pc.fill_null(pa_array, "") | ||
else: | ||
# After the first non-NA value we can retain the running min/max | ||
# by forward filling. | ||
pa_array = pc.fill_null_forward(pa_array) | ||
# But any leading NA values should result in "". | ||
nulls = pc.is_null(pa_array) | ||
idx = pc.index(nulls, False).as_py() | ||
if idx == -1: | ||
idx = len(pa_array) | ||
if idx > 0: | ||
head = pa.array([""] * idx, type=pa_array.type) | ||
pa_array = pa_array[idx:].combine_chunks() | ||
else: | ||
# When not skipping NA values, the result should be null from | ||
# the first NA value onward. | ||
nulls = pc.is_null(pa_array) | ||
idx = pc.index(nulls, True).as_py() | ||
tail = pa.nulls(len(pa_array) - idx, type=pa_array.type) | ||
pa_array = pa_array[:idx].combine_chunks() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is the `combine_chunks()` call needed here? I would expect the conversion to NumPy (which happens when the NumPy function is called) to do this automatically, and potentially more efficiently. |
||
|
||
# error: Cannot call function of unknown type | ||
pa_result = pa.array(np_func(pa_array), type=pa_array.type) # type: ignore[operator] | ||
|
||
assert head is None or tail is None | ||
if head is not None: | ||
pa_result = pa.concat_arrays([head, pa_result]) | ||
elif tail is not None: | ||
pa_result = pa.concat_arrays([pa_result, tail]) | ||
|
||
result = type(self)(pa_result) | ||
return result | ||
|
||
def _reduce_pyarrow(self, name: str, *, skipna: bool = True, **kwargs) -> pa.Scalar: | ||
""" | ||
Return a pyarrow scalar result of performing the reduction operation. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,8 @@ | |
tests.frame.test_cumulative | ||
""" | ||
|
||
import re | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
|
@@ -227,3 +229,55 @@ def test_cumprod_timedelta(self): | |
ser = pd.Series([pd.Timedelta(days=1), pd.Timedelta(days=3)]) | ||
with pytest.raises(TypeError, match="cumprod not supported for Timedelta"): | ||
ser.cumprod() | ||
|
||
@pytest.mark.parametrize( | ||
"data, op, skipna, expected_data", | ||
[ | ||
([], "cumsum", True, []), | ||
([], "cumsum", False, []), | ||
(["x", "z", "y"], "cumsum", True, ["x", "xz", "xzy"]), | ||
(["x", "z", "y"], "cumsum", False, ["x", "xz", "xzy"]), | ||
(["x", pd.NA, "y"], "cumsum", True, ["x", "x", "xy"]), | ||
(["x", pd.NA, "y"], "cumsum", False, ["x", pd.NA, pd.NA]), | ||
([pd.NA, "x", "y"], "cumsum", True, ["", "x", "xy"]), | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It seems that for numerical data we actually (somewhat inconsistently?) propagate leading NAs: for example, `pd.Series([np.nan, 1.0, 2.0]).cumsum()` gives `[NaN, 1.0, 3.0]`
(i.e. the result doesn't have 0.0 for the first element). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Actually, this is not related to "leading" NAs. What seems to be happening is that missing values are ignored when calculating the cumulative result, but are then propagated to the result elementwise. This is also shown in the docstring example of `cumsum`, so it seems intentional. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks for catching this. Agreed, we should match this behavior. I do find it odd, but that's (possibly) for another day! |
||
([pd.NA, "x", "y"], "cumsum", False, [pd.NA, pd.NA, pd.NA]), | ||
([pd.NA, pd.NA, pd.NA], "cumsum", True, ["", "", ""]), | ||
([pd.NA, pd.NA, pd.NA], "cumsum", False, [pd.NA, pd.NA, pd.NA]), | ||
([], "cummin", True, []), | ||
([], "cummin", False, []), | ||
(["y", "z", "x"], "cummin", True, ["y", "y", "x"]), | ||
(["y", "z", "x"], "cummin", False, ["y", "y", "x"]), | ||
(["y", pd.NA, "x"], "cummin", True, ["y", "y", "x"]), | ||
(["y", pd.NA, "x"], "cummin", False, ["y", pd.NA, pd.NA]), | ||
([pd.NA, "y", "x"], "cummin", True, ["", "y", "x"]), | ||
([pd.NA, "y", "x"], "cummin", False, [pd.NA, pd.NA, pd.NA]), | ||
([pd.NA, pd.NA, pd.NA], "cummin", True, ["", "", ""]), | ||
([pd.NA, pd.NA, pd.NA], "cummin", False, [pd.NA, pd.NA, pd.NA]), | ||
([], "cummax", True, []), | ||
([], "cummax", False, []), | ||
(["x", "z", "y"], "cummax", True, ["x", "z", "z"]), | ||
(["x", "z", "y"], "cummax", False, ["x", "z", "z"]), | ||
(["x", pd.NA, "y"], "cummax", True, ["x", "x", "y"]), | ||
(["x", pd.NA, "y"], "cummax", False, ["x", pd.NA, pd.NA]), | ||
([pd.NA, "x", "y"], "cummax", True, ["", "x", "y"]), | ||
([pd.NA, "x", "y"], "cummax", False, [pd.NA, pd.NA, pd.NA]), | ||
([pd.NA, pd.NA, pd.NA], "cummax", True, ["", "", ""]), | ||
([pd.NA, pd.NA, pd.NA], "cummax", False, [pd.NA, pd.NA, pd.NA]), | ||
], | ||
) | ||
def test_cum_methods_pyarrow_strings( | ||
self, pyarrow_string_dtype, data, op, skipna, expected_data | ||
): | ||
# https://github.com/pandas-dev/pandas/pull/60633 | ||
ser = pd.Series(data, dtype=pyarrow_string_dtype) | ||
method = getattr(ser, op) | ||
expected = pd.Series(expected_data, dtype=pyarrow_string_dtype) | ||
result = method(skipna=skipna) | ||
tm.assert_series_equal(result, expected) | ||
|
||
def test_cumprod_pyarrow_strings(self, pyarrow_string_dtype, skipna): | ||
# https://github.com/pandas-dev/pandas/pull/60633 | ||
ser = pd.Series(list("xyz"), dtype=pyarrow_string_dtype) | ||
msg = re.escape(f"operation 'cumprod' not supported for dtype '{ser.dtype}'") | ||
with pytest.raises(TypeError, match=msg): | ||
ser.cumprod(skipna=skipna) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was going to comment that I don't think this can work, although it is then strange that the tests are passing :) It seems this was not doing what I think you expected it to do; see #60661.
I would use the same approach as some of the fixtures above and create the dtype through `StringDtype(..)`.