@@ -67,6 +67,7 @@ private[spark] object PythonEvalType {
val SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF = 213
val SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF = 214
val SQL_GROUPED_MAP_ARROW_ITER_UDF = 215
val SQL_GROUPED_MAP_PANDAS_ITER_UDF = 216

// Arrow UDFs
val SQL_SCALAR_ARROW_UDF = 250
@@ -102,6 +103,8 @@ private[spark] object PythonEvalType {
case SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF => "SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF"
case SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF =>
"SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF"
case SQL_GROUPED_MAP_ARROW_ITER_UDF => "SQL_GROUPED_MAP_ARROW_ITER_UDF"
case SQL_GROUPED_MAP_PANDAS_ITER_UDF => "SQL_GROUPED_MAP_PANDAS_ITER_UDF"

// Arrow UDFs
case SQL_SCALAR_ARROW_UDF => "SQL_SCALAR_ARROW_UDF"
15 changes: 13 additions & 2 deletions python/pyspark/sql/connect/group.py
@@ -294,14 +294,25 @@ def applyInPandas(
) -> "DataFrame":
from pyspark.sql.connect.udf import UserDefinedFunction
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.pandas.typehints import infer_group_pandas_eval_type_from_func

_validate_vectorized_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
# Try to infer the eval type from type hints
eval_type = None
try:
eval_type = infer_group_pandas_eval_type_from_func(func)
except Exception as e:
warnings.warn(f"Cannot infer the eval type from type hints: {e}", UserWarning)

if eval_type is None:
eval_type = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF

_validate_vectorized_udf(func, eval_type)
if isinstance(schema, str):
schema = cast(StructType, self._df._session._parse_ddl(schema))
udf_obj = UserDefinedFunction(
func,
returnType=schema,
evalType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
evalType=eval_type,
)

res = DataFrame(
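For reference, here is a minimal sketch of what `infer_group_pandas_eval_type_from_func` could look like; the actual helper in `pyspark.sql.pandas.typehints` may differ, and the eval-type constants are spelled out only for illustration. The idea is to dispatch on whether the function's data argument is annotated as `Iterator[pd.DataFrame]`:

```python
# Sketch only: the real infer_group_pandas_eval_type_from_func may differ.
import collections.abc
import typing
from typing import Optional, get_type_hints

import pandas as pd

SQL_GROUPED_MAP_PANDAS_UDF = 201       # existing grouped map eval type
SQL_GROUPED_MAP_PANDAS_ITER_UDF = 216  # iterator eval type added in this PR


def infer_eval_type_sketch(func) -> Optional[int]:
    hints = get_type_hints(func)
    arg_hints = [h for name, h in hints.items() if name != "return"]
    if not arg_hints:
        return None  # no usable hints; the caller falls back to the default
    data_hint = arg_hints[-1]  # the data argument is last; the first may be the key
    if typing.get_origin(data_hint) is collections.abc.Iterator:
        return SQL_GROUPED_MAP_PANDAS_ITER_UDF
    if data_hint is pd.DataFrame:
        return SQL_GROUPED_MAP_PANDAS_UDF
    return None
```

Returning `None` when no hints are usable matches the fallback to `SQL_GROUPED_MAP_PANDAS_UDF` in the code above.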
3 changes: 3 additions & 0 deletions python/pyspark/sql/pandas/_typing/__init__.pyi
@@ -61,6 +61,7 @@ PandasGroupedMapUDFTransformWithStateInitStateType = Literal[212]
GroupedMapUDFTransformWithStateType = Literal[213]
GroupedMapUDFTransformWithStateInitStateType = Literal[214]
ArrowGroupedMapIterUDFType = Literal[215]
PandasGroupedMapIterUDFType = Literal[216]

# Arrow UDFs
ArrowScalarUDFType = Literal[250]
@@ -347,6 +348,8 @@ PandasScalarIterFunction = Union[
PandasGroupedMapFunction = Union[
Callable[[DataFrameLike], DataFrameLike],
Callable[[Any, DataFrameLike], DataFrameLike],
Callable[[Iterator[DataFrameLike]], Iterator[DataFrameLike]],
Callable[[Any, Iterator[DataFrameLike]], Iterator[DataFrameLike]],
]

PandasGroupedMapFunctionWithState = Callable[
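For illustration, the two new members of the `PandasGroupedMapFunction` union accept function shapes like the following (a sketch; the function names are hypothetical):

```python
# Illustrative functions matching the two new iterator-based union members.
from typing import Any, Iterator, Tuple

import pandas as pd


def iter_form(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    for batch in batches:
        yield batch  # process each batch independently


def keyed_iter_form(
    key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
) -> Iterator[pd.DataFrame]:
    for batch in batches:
        yield batch  # the grouping key(s) arrive as the first argument
```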
19 changes: 19 additions & 0 deletions python/pyspark/sql/pandas/functions.py
@@ -322,6 +322,7 @@ def calculate(iterator: Iterator[pa.Array]) -> Iterator[pa.Array]:
pyspark.sql.GroupedData.applyInArrow
pyspark.sql.PandasCogroupedOps.applyInArrow
pyspark.sql.UDFRegistration.register
pyspark.sql.GroupedData.applyInPandas
"""
require_minimum_pyarrow_version()

@@ -346,6 +347,9 @@ def pandas_udf(f=None, returnType=None, functionType=None):
.. versionchanged:: 4.0.0
Supports keyword-arguments in SCALAR and GROUPED_AGG type.

.. versionchanged:: 4.1.0
Supports iterator API in GROUPED_MAP type.

Parameters
----------
f : function, optional
@@ -690,6 +694,7 @@ def vectorized_udf(
PythonEvalType.SQL_SCALAR_PANDAS_UDF,
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
@@ -771,6 +776,7 @@ def _validate_vectorized_udf(f, evalType, kind: str = "pandas") -> int:
)
elif evalType in [
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
@@ -836,6 +842,19 @@ def _validate_vectorized_udf(f, evalType, kind: str = "pandas") -> int:
},
)

if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF and len(argspec.args) not in (
1,
2,
):
raise PySparkValueError(
errorClass="INVALID_PANDAS_UDF",
messageParameters={
"detail": "the function in groupby.applyInPandas with iterator API must take "
"either one argument (batches: Iterator[pandas.DataFrame]) or two arguments "
"(key, batches: Iterator[pandas.DataFrame]).",
},
)

if evalType == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF and len(argspec.args) not in (1, 2):
raise PySparkValueError(
errorClass="INVALID_PANDAS_UDF",
14 changes: 10 additions & 4 deletions python/pyspark/sql/pandas/functions.pyi
@@ -29,6 +29,7 @@ from pyspark.sql.pandas._typing import (
PandasGroupedAggFunction,
PandasGroupedAggUDFType,
PandasGroupedMapFunction,
PandasGroupedMapIterUDFType,
PandasGroupedMapUDFType,
PandasScalarIterFunction,
PandasScalarIterUDFType,
@@ -145,19 +146,24 @@ def pandas_udf(
def pandas_udf(
f: PandasGroupedMapFunction,
returnType: Union[StructType, str],
functionType: PandasGroupedMapUDFType,
functionType: Union[PandasGroupedMapUDFType, PandasGroupedMapIterUDFType],
) -> GroupedMapPandasUserDefinedFunction: ...
@overload
def pandas_udf(
f: Union[StructType, str], returnType: PandasGroupedMapUDFType
f: Union[StructType, str],
returnType: Union[PandasGroupedMapUDFType, PandasGroupedMapIterUDFType],
) -> Callable[[PandasGroupedMapFunction], GroupedMapPandasUserDefinedFunction]: ...
@overload
def pandas_udf(
*, returnType: Union[StructType, str], functionType: PandasGroupedMapUDFType
*,
returnType: Union[StructType, str],
functionType: Union[PandasGroupedMapUDFType, PandasGroupedMapIterUDFType],
) -> Callable[[PandasGroupedMapFunction], GroupedMapPandasUserDefinedFunction]: ...
@overload
def pandas_udf(
f: Union[StructType, str], *, functionType: PandasGroupedMapUDFType
f: Union[StructType, str],
*,
functionType: Union[PandasGroupedMapUDFType, PandasGroupedMapIterUDFType],
) -> Callable[[PandasGroupedMapFunction], GroupedMapPandasUserDefinedFunction]: ...
@overload
def pandas_udf(
96 changes: 82 additions & 14 deletions python/pyspark/sql/pandas/group_ops.py
@@ -123,12 +123,13 @@ def applyInPandas(
Maps each group of the current :class:`DataFrame` using a pandas udf and returns the result
as a `DataFrame`.

The function should take a `pandas.DataFrame` and return another
`pandas.DataFrame`. Alternatively, the user can pass a function that takes
a tuple of the grouping key(s) and a `pandas.DataFrame`.
For each group, all columns are passed together as a `pandas.DataFrame`
to the user-function and the returned `pandas.DataFrame` are combined as a
:class:`DataFrame`.
The function can take one of two forms: it can take a `pandas.DataFrame` and return a
`pandas.DataFrame`, or it can take an iterator of `pandas.DataFrame` and yield
`pandas.DataFrame`. Alternatively, each form can take a tuple of the grouping key(s)
as the first argument in addition to the input described above.
For each group, all columns are passed together as a `pandas.DataFrame` or an iterator
of `pandas.DataFrame`, and the returned `pandas.DataFrame` objects are combined as a
:class:`DataFrame`.

The `schema` should be a :class:`StructType` describing the schema of the returned
`pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` must either match
@@ -141,12 +142,17 @@ def applyInPandas(
.. versionchanged:: 3.4.0
Support Spark Connect.

.. versionchanged:: 4.1.0
Added support for the iterator of `pandas.DataFrame` API.

Parameters
----------
func : function
a Python native function that takes a `pandas.DataFrame` and outputs a
`pandas.DataFrame`, or that takes one tuple (grouping keys) and a
`pandas.DataFrame` and outputs a `pandas.DataFrame`.
a Python native function that either takes a `pandas.DataFrame` and outputs a
`pandas.DataFrame` or takes an iterator of `pandas.DataFrame` and yields
`pandas.DataFrame`. Additionally, each form can take a tuple of grouping keys
as the first argument, with the `pandas.DataFrame` or iterator of `pandas.DataFrame`
as the second argument.
schema : :class:`pyspark.sql.types.DataType` or str
the return type of the `func` in PySpark. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
@@ -214,22 +220,84 @@ def applyInPandas(
| 2| 2| 3.0|
+---+-----------+----+

The function can also take and return an iterator of `pandas.DataFrame` when annotated
with the corresponding type hints.

>>> from typing import Iterator # doctest: +SKIP
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v")) # doctest: +SKIP
>>> def filter_func(
... batches: Iterator[pd.DataFrame]
... ) -> Iterator[pd.DataFrame]: # doctest: +SKIP
... for batch in batches:
... # Process and yield each batch independently
... filtered = batch[batch['v'] > 2.0]
... if not filtered.empty:
... yield filtered[['v']]
>>> df.groupby("id").applyInPandas(
... filter_func, schema="v double").show() # doctest: +SKIP
+----+
|   v|
+----+
| 3.0|
| 5.0|
|10.0|
+----+

Alternatively, the user can pass a function that takes two arguments.
In this case, the grouping key(s) will be passed as the first argument and the data will
be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy
data types. The data will still be passed in as an iterator of `pandas.DataFrame`.

>>> from typing import Iterator, Tuple, Any # doctest: +SKIP
>>> def transform_func(
... key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
... ) -> Iterator[pd.DataFrame]: # doctest: +SKIP
... for batch in batches:
... # Yield transformed results for each batch
... result = batch.assign(id=key[0], v_doubled=batch['v'] * 2)
... yield result[['id', 'v_doubled']]
>>> df.groupby("id").applyInPandas(
... transform_func, schema="id long, v_doubled double").show() # doctest: +SKIP
+---+---------+
| id|v_doubled|
+---+---------+
|  1|      2.0|
|  1|      4.0|
|  2|      6.0|
|  2|     10.0|
|  2|     20.0|
+---+---------+

Notes
-----
This function requires a full shuffle. All the data of a group will be loaded
into memory, so the user should be aware of the potential OOM risk if data is skewed
and certain groups are too large to fit in memory.
This function requires a full shuffle. With the `pandas.DataFrame` API, all data of a
group is loaded into memory, so the user should be aware of the potential OOM risk if
the data is skewed and certain groups are too large to fit in memory; the iterator of
`pandas.DataFrame` API can be used to mitigate this.

See Also
--------
pyspark.sql.functions.pandas_udf
"""
from pyspark.sql import GroupedData
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.functions import pandas_udf
from pyspark.sql.pandas.typehints import infer_group_pandas_eval_type_from_func

assert isinstance(self, GroupedData)

udf = pandas_udf(func, returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
# Try to infer the eval type from type hints
eval_type = None
try:
eval_type = infer_group_pandas_eval_type_from_func(func)
except Exception as e:
warnings.warn(f"Cannot infer the eval type from type hints: {e}", UserWarning)

if eval_type is None:
eval_type = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF

udf = pandas_udf(func, returnType=schema, functionType=eval_type)
df = self._df
udf_column = udf(*[df[col] for col in df.columns])
jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc)
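To make the dispatch above concrete, here is a minimal usage sketch, assuming an active `SparkSession` named `spark`; the `Iterator` type hints are what steer the inference to `SQL_GROUPED_MAP_PANDAS_ITER_UDF`:

```python
from typing import Iterator

import pandas as pd

df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))


def center_per_batch(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    # Each group arrives as a stream of batches, so only one batch per group
    # needs to be resident in memory at a time.
    for batch in batches:
        yield batch.assign(v=batch["v"] - batch["v"].mean())


df.groupby("id").applyInPandas(center_per_batch, schema="id long, v double").show()
```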
82 changes: 82 additions & 0 deletions python/pyspark/sql/pandas/serializers.py
@@ -1251,6 +1251,88 @@ def __repr__(self):
return "GroupPandasUDFSerializer"


class GroupPandasIterUDFSerializer(ArrowStreamPandasUDFSerializer):
"""
Serializer for grouped map Pandas iterator UDFs.

Loads grouped data as pandas.Series and serializes results from iterator UDFs.
Flattens the (dataframes_generator, arrow_type) tuple by iterating over the generator.
"""

def __init__(
self,
timezone,
safecheck,
assign_cols_by_name,
int_to_decimal_coercion_enabled,
):
super(GroupPandasIterUDFSerializer, self).__init__(
timezone=timezone,
safecheck=safecheck,
assign_cols_by_name=assign_cols_by_name,
df_for_struct=False,
struct_in_pandas="dict",
ndarray_as_list=False,
arrow_cast=True,
input_types=None,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
)

def load_stream(self, stream):
"""
Deserialize Grouped ArrowRecordBatches and yield a generator of pandas.Series lists
(one list per batch), allowing the iterator UDF to process data batch-by-batch.
"""
import pyarrow as pa

def process_group(batches: "Iterator[pa.RecordBatch]"):
# Convert each Arrow batch to pandas Series list on-demand, yielding one list per batch
for batch in batches:
series = [
self.arrow_to_pandas(c, i)
for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns())
]
yield series

dataframes_in_group = None

while dataframes_in_group is None or dataframes_in_group > 0:
dataframes_in_group = read_int(stream)

if dataframes_in_group == 1:
# Lazily read and convert Arrow batches one at a time from the stream
# This avoids loading all batches into memory for the group
batch_iter = process_group(ArrowStreamSerializer.load_stream(self, stream))
yield batch_iter
# Make sure the batches are fully iterated before getting the next group
for _ in batch_iter:
pass

elif dataframes_in_group != 0:
raise PySparkValueError(
errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
messageParameters={"dataframes_in_group": str(dataframes_in_group)},
)

def dump_stream(self, iterator, stream):
"""
Flatten the (dataframes_generator, arrow_type) tuples by iterating over each generator.
This allows the iterator UDF to stream results without materializing all DataFrames.
"""
# Flatten: (dataframes_generator, arrow_type) -> (df, arrow_type), (df, arrow_type), ...
flattened_iter = (
(df, arrow_type) for dataframes_gen, arrow_type in iterator for df in dataframes_gen
)

# Convert each (df, arrow_type) to the format expected by parent's dump_stream
series_iter = ([(df, arrow_type)] for df, arrow_type in flattened_iter)

super(GroupPandasIterUDFSerializer, self).dump_stream(series_iter, stream)

def __repr__(self):
return "GroupPandasIterUDFSerializer"


class CogroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer):
"""
Serializes pyarrow.RecordBatch data with Arrow streaming format.
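The flattening that `dump_stream` performs can be sketched in pure Python, independent of Arrow and Spark (the string arrow types below stand in for the real type objects):

```python
import pandas as pd


def flatten(udf_results):
    # (dataframes_generator, arrow_type) -> one (df, arrow_type) pair per
    # yielded DataFrame, so nothing forces a whole group's output into memory.
    for dataframes_gen, arrow_type in udf_results:
        for df in dataframes_gen:
            yield df, arrow_type


# Two groups whose UDFs yield DataFrames lazily:
groups = [
    ((pd.DataFrame({"v": [i]}) for i in range(2)), "type_a"),
    ((pd.DataFrame({"v": [i]}) for i in range(1)), "type_b"),
]
assert len(list(flatten(groups))) == 3
```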