Fix DataFrame.aggregate to preserve extension dtypes with callable functions #61816

Open: wants to merge 4 commits into base: main
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
@@ -932,6 +932,7 @@ Other
- Bug in ``Series.list`` methods not preserving the original name. (:issue:`60522`)
- Bug in printing a :class:`DataFrame` with a :class:`DataFrame` stored in :attr:`DataFrame.attrs` raised a ``ValueError`` (:issue:`60455`)
- Bug in printing a :class:`Series` with a :class:`DataFrame` stored in :attr:`Series.attrs` raised a ``ValueError`` (:issue:`60568`)
- Bug in :meth:`DataFrame.aggregate` dropping pyarrow backend for lambda aggregation functions (:issue:`61812`)
- Fixed bug where the :class:`DataFrame` constructor misclassified array-like objects with a ``.name`` attribute as :class:`Series` or :class:`Index` (:issue:`61443`)
- Fixed regression in :meth:`DataFrame.from_records` not initializing subclasses properly (:issue:`57008`)
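
For context, the behavior described in the new entry for issue 61812 can be illustrated with a small comparison. This is a minimal sketch, not taken from the PR: it assumes pandas is installed and uses the nullable `Int64` extension dtype as a stand-in for a pyarrow-backed dtype, so pyarrow itself is not required. It aggregates the same frame once by name and once with a lambda and prints the dtype of each result. The printed dtypes can differ between pandas versions, so no expected output is shown.

```python
import pandas as pd

# Hypothetical example frame; "Int64" is pandas' nullable integer extension dtype.
df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}, dtype="Int64")

by_name = df.agg("sum")                      # string aggregation
by_callable = df.agg(lambda col: col.sum())  # callable aggregation

# Compare the result dtypes of the two paths.
print("by name:    ", by_name.dtype)
print("by callable:", by_callable.dtype)
```

If the two printed dtypes differ, the callable path has not kept the extension dtype, which is the situation the changelog entry refers to.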

85 changes: 85 additions & 0 deletions pandas/core/apply.py
@@ -291,6 +291,8 @@ def agg(self) -> DataFrame | Series | None:
        elif is_list_like(func):
            # we require a list, but not a 'str'
            return self.agg_list_like()
        elif callable(func):
            return self.agg_callable()

        # caller can react
        return None
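
The branch order in this hunk can be sketched in isolation. The following is a simplified stand-in, not pandas code: `pick_agg_branch` is a hypothetical helper, and `Mapping` and `Iterable` from `collections.abc` only approximate pandas' `is_dict_like` and `is_list_like` checks. The string and dict-like checks are not visible in the hunk and are assumed here from the surrounding method. The sketch reports which branch a given `func` would reach, showing that the new `callable(func)` check runs only after the earlier checks have been ruled out.

```python
from collections.abc import Iterable, Mapping


def pick_agg_branch(func):
    """Return the name of the branch a given ``func`` would reach (toy model)."""
    if isinstance(func, str):
        return "str"
    if isinstance(func, Mapping):    # stand-in for is_dict_like
        return "dict_like"
    if isinstance(func, Iterable):   # stand-in for is_list_like (str handled above)
        return "list_like"
    if callable(func):               # the branch this hunk adds
        return "callable"
    return None                      # mirrors "caller can react"


for candidate in ("sum", {"a": "sum"}, ["sum", "max"], lambda col: col.sum(), 42):
    print(type(candidate).__name__, "->", pick_agg_branch(candidate))
```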
@@ -797,6 +799,89 @@ def _apply_str(self, obj, func: str, *args, **kwargs):
            msg = f"'{func}' is not a valid function for '{type(obj).__name__}' object"
            raise AttributeError(msg)

    def agg_callable(self) -> DataFrame | Series:
        """
        Compute aggregation in the case of a callable argument.

        This method handles callable functions while preserving extension dtypes
        by delegating to the same infrastructure used for string aggregations.

        Returns
        -------
        Result of aggregation.
        """
        obj = self.obj
        func = self.func

        if obj.ndim == 1:
            return func(obj, *self.args, **self.kwargs)

        # Use _reduce to preserve extension dtypes like on string aggregation
        try:
            result = obj._reduce(
                func,
                name=getattr(func, "__name__", "<lambda>"),
                axis=self.axis,
                skipna=True,
                numeric_only=False,
                **self.kwargs,
            )
            return result

        except (AttributeError, TypeError):
            # If _reduce fails, fallback to column-wise
            return self._agg_callable_fallback()

    def _agg_callable_fallback(self) -> DataFrame | Series:
        """
        Fallback method for callable aggregation when _reduce fails.

        This method applies the function column-wise while preserving dtypes,
        but avoids the performance overhead of row-by-row processing.
        """
        obj = self.obj
        func = self.func

        if self.axis == 1:
            # For row-wise aggregation, transpose and recurse
            transposed_result = obj.T._aggregate(
                func, *self.args, axis=0, **self.kwargs
            )
            return transposed_result

        from pandas import Series

        try:
            # Apply function to each column
            results = {}
            for name in obj.columns:
                col = obj._get_column_reference(name)
                result_val = func(col, *self.args, **self.kwargs)
                results[name] = result_val

            result = Series(results, name=None)

            # Preserve extension dtypes where possible
            for name in result.index:
                if name in obj.columns:
                    original_dtype = obj.dtypes[name]
                    if hasattr(original_dtype, "construct_array_type"):
                        try:
                            array_type = original_dtype.construct_array_type()
                            if hasattr(array_type, "_from_sequence"):
                                preserved_val = array_type._from_sequence(
                                    [result[name]], dtype=original_dtype
                                )[0]
                                result.loc[name] = preserved_val
                        except Exception:
                            # If dtype preservation fails, keep the computed value
                            pass

            return result

        except Exception:
            return None
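
To see what the `_from_sequence(...)[0]` round trip in the fallback returns, the step can be tried on its own. This is a minimal sketch under assumptions, not part of the PR: it assumes pandas is installed, uses the nullable `Int64` dtype as an example extension dtype, and the `results` dict is made up. It wraps one computed value in a one-element extension array and takes the element back out. For comparison, it also casts a whole result Series with `astype`, which is a different technique from the one in the diff.

```python
import pandas as pd

dtype = pd.Int64Dtype()        # example extension dtype (assumption)
results = {"a": 3, "b": 7}     # made-up per-column aggregation results

# Same shape as the loop body in the diff: wrap one value, then index it out.
array_type = dtype.construct_array_type()
round_tripped = array_type._from_sequence([results["a"]], dtype=dtype)[0]
print(type(round_tripped).__name__, round_tripped)

# Alternative for comparison: cast the assembled Series as a whole.
as_series = pd.Series(results).astype(dtype)
print(as_series.dtype)
```

Indexing a single element out of an extension array yields a scalar, so it may be worth checking on the target pandas version whether assigning that scalar back into `result` changes the dtype of the Series at all.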

@arthurlw (Member), Jul 9, 2025:

NIT: Could you make sure there are two blank lines between function and class definitions for consistency?

Running the command below should fix it and other linting errors:

pre-commit run --all-files

Author:

Sure


class NDFrameApply(Apply):
"""