8 changes: 4 additions & 4 deletions python/pyspark/sql/pandas/serializers.py
@@ -1302,16 +1302,16 @@ def load_stream(self, stream):
dataframes_in_group = read_int(stream)

if dataframes_in_group == 2:
- batch1 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
- batch2 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
+ batches1 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
+ batches2 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
yield (
[
self.arrow_to_pandas(c, i)
- for i, c in enumerate(pa.Table.from_batches(batch1).itercolumns())
+ for i, c in enumerate(pa.Table.from_batches(batches1).itercolumns())
],
[
self.arrow_to_pandas(c, i)
- for i, c in enumerate(pa.Table.from_batches(batch2).itercolumns())
+ for i, c in enumerate(pa.Table.from_batches(batches2).itercolumns())
],
)

63 changes: 58 additions & 5 deletions python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
@@ -20,7 +20,7 @@

from pyspark.errors import PythonException
from pyspark.sql import Row
- from pyspark.sql.functions import col
+ from pyspark.sql import functions as sf
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
have_pyarrow,
@@ -39,16 +39,16 @@
class CogroupedMapInArrowTestsMixin:
@property
def left(self):
- return self.spark.range(0, 10, 2, 3).withColumn("v", col("id") * 10)
+ return self.spark.range(0, 10, 2, 3).withColumn("v", sf.col("id") * 10)

@property
def right(self):
- return self.spark.range(0, 10, 3, 3).withColumn("v", col("id") * 10)
+ return self.spark.range(0, 10, 3, 3).withColumn("v", sf.col("id") * 10)

@property
def cogrouped(self):
- grouped_left_df = self.left.groupBy((col("id") / 4).cast("int"))
- grouped_right_df = self.right.groupBy((col("id") / 4).cast("int"))
+ grouped_left_df = self.left.groupBy((sf.col("id") / 4).cast("int"))
+ grouped_right_df = self.right.groupBy((sf.col("id") / 4).cast("int"))
return grouped_left_df.cogroup(grouped_right_df)

@staticmethod
@@ -309,6 +309,59 @@ def arrow_func(key, left, right):

self.assertEqual(df2.join(df2).count(), 1)

def test_arrow_batch_slicing(self):
df1 = self.spark.range(10000000).select(
(sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
)
cols = {f"col_{i}": sf.col("v") + i for i in range(10)}
df1 = df1.withColumns(cols)

df2 = self.spark.range(100000).select(
(sf.col("id") % 4).alias("key"), sf.col("id").alias("v")
)
cols = {f"col_{i}": sf.col("v") + i for i in range(20)}
df2 = df2.withColumns(cols)

def summarize(key, left, right):
assert len(left) == 10000000 / 2 or len(left) == 0, len(left)
assert len(right) == 100000 / 4, len(right)
return pa.Table.from_pydict(
{
"key": [key[0].as_py()],
"left_rows": [left.num_rows],
"left_columns": [left.num_columns],
"right_rows": [right.num_rows],
"right_columns": [right.num_columns],
}
)

schema = "key long, left_rows long, left_columns long, right_rows long, right_columns long"

expected = [
Row(key=0, left_rows=5000000, left_columns=12, right_rows=25000, right_columns=22),
Row(key=1, left_rows=5000000, left_columns=12, right_rows=25000, right_columns=22),
Row(key=2, left_rows=0, left_columns=12, right_rows=25000, right_columns=22),
Row(key=3, left_rows=0, left_columns=12, right_rows=25000, right_columns=22),
]

for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
with self.sql_conf(
{
"spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
"spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
}
):
result = (
df1.groupby("key")
.cogroup(df2.groupby("key"))
.applyInArrow(summarize, schema=schema)
.sort("key")
.collect()
)

self.assertEqual(expected, result)


class CogroupedMapInArrowTests(CogroupedMapInArrowTestsMixin, ReusedSQLTestCase):
@classmethod
96 changes: 76 additions & 20 deletions python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -18,7 +18,8 @@
import unittest
from typing import cast

- from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf, sum
+ from pyspark.sql import functions as sf
+ from pyspark.sql.functions import pandas_udf, udf
from pyspark.sql.types import (
ArrayType,
DoubleType,
@@ -55,35 +56,35 @@ class CogroupedApplyInPandasTestsMixin:
def data1(self):
return (
self.spark.range(10)
.withColumn("ks", array([lit(i) for i in range(20, 30)]))
.withColumn("k", explode(col("ks")))
.withColumn("v", col("k") * 10)
.withColumn("ks", sf.array([sf.lit(i) for i in range(20, 30)]))
.withColumn("k", sf.explode(sf.col("ks")))
.withColumn("v", sf.col("k") * 10)
.drop("ks")
)

@property
def data2(self):
return (
self.spark.range(10)
.withColumn("ks", array([lit(i) for i in range(20, 30)]))
.withColumn("k", explode(col("ks")))
.withColumn("v2", col("k") * 100)
.withColumn("ks", sf.array([sf.lit(i) for i in range(20, 30)]))
.withColumn("k", sf.explode(sf.col("ks")))
.withColumn("v2", sf.col("k") * 100)
.drop("ks")
)

def test_simple(self):
self._test_merge(self.data1, self.data2)

def test_left_group_empty(self):
- left = self.data1.where(col("id") % 2 == 0)
+ left = self.data1.where(sf.col("id") % 2 == 0)
self._test_merge(left, self.data2)

def test_right_group_empty(self):
- right = self.data2.where(col("id") % 2 == 0)
+ right = self.data2.where(sf.col("id") % 2 == 0)
self._test_merge(self.data1, right)

def test_different_schemas(self):
right = self.data2.withColumn("v3", lit("a"))
right = self.data2.withColumn("v3", sf.lit("a"))
self._test_merge(
self.data1, right, output_schema="id long, k int, v int, v2 int, v3 string"
)
@@ -116,9 +117,9 @@ def test_complex_group_by(self):

right = pd.DataFrame.from_dict({"id": [11, 12, 13], "k": [5, 6, 7], "v2": [90, 100, 110]})

- left_gdf = self.spark.createDataFrame(left).groupby(col("id") % 2 == 0)
+ left_gdf = self.spark.createDataFrame(left).groupby(sf.col("id") % 2 == 0)

- right_gdf = self.spark.createDataFrame(right).groupby(col("id") % 2 == 0)
+ right_gdf = self.spark.createDataFrame(right).groupby(sf.col("id") % 2 == 0)

def merge_pandas(lft, rgt):
return pd.merge(lft[["k", "v"]], rgt[["k", "v2"]], on=["k"])
@@ -354,20 +355,20 @@ def test_with_key_right(self):
self._test_with_key(self.data1, self.data1, isLeft=False)

def test_with_key_left_group_empty(self):
- left = self.data1.where(col("id") % 2 == 0)
+ left = self.data1.where(sf.col("id") % 2 == 0)
self._test_with_key(left, self.data1, isLeft=True)

def test_with_key_right_group_empty(self):
- right = self.data1.where(col("id") % 2 == 0)
+ right = self.data1.where(sf.col("id") % 2 == 0)
self._test_with_key(self.data1, right, isLeft=False)

def test_with_key_complex(self):
def left_assign_key(key, lft, _):
return lft.assign(key=key[0])

result = (
- self.data1.groupby(col("id") % 2 == 0)
- .cogroup(self.data2.groupby(col("id") % 2 == 0))
+ self.data1.groupby(sf.col("id") % 2 == 0)
+ .cogroup(self.data2.groupby(sf.col("id") % 2 == 0))
.applyInPandas(left_assign_key, "id long, k int, v int, key boolean")
.sort(["id", "k"])
.toPandas()
@@ -456,7 +457,9 @@ def test_with_window_function(self):
left_df = df.withColumnRenamed("value", "left").repartition(parts).cache()
# SPARK-42132: this bug requires us to alias all columns from df here
right_df = (
df.select(col("id").alias("id"), col("day").alias("day"), col("value").alias("right"))
df.select(
sf.col("id").alias("id"), sf.col("day").alias("day"), sf.col("value").alias("right")
)
.repartition(parts)
.cache()
)
@@ -465,9 +468,9 @@
window = Window.partitionBy("day", "id")

left_grouped_df = left_df.groupBy("id", "day")
right_grouped_df = right_df.withColumn("day_sum", sum(col("day")).over(window)).groupBy(
"id", "day"
)
right_grouped_df = right_df.withColumn(
"day_sum", sf.sum(sf.col("day")).over(window)
).groupBy("id", "day")

def cogroup(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
return pd.DataFrame(
@@ -653,6 +656,59 @@ def __test_merge_error(
with self.assertRaisesRegex(errorClass, error_message_regex):
self.__test_merge(left, right, by, fn, output_schema)

def test_arrow_batch_slicing(self):
df1 = self.spark.range(10000000).select(
(sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
)
cols = {f"col_{i}": sf.col("v") + i for i in range(10)}
df1 = df1.withColumns(cols)

df2 = self.spark.range(100000).select(
(sf.col("id") % 4).alias("key"), sf.col("id").alias("v")
)
cols = {f"col_{i}": sf.col("v") + i for i in range(20)}
df2 = df2.withColumns(cols)

def summarize(key, left, right):
assert len(left) == 10000000 / 2 or len(left) == 0, len(left)
assert len(right) == 100000 / 4, len(right)
return pd.DataFrame(
{
"key": [key[0]],
"left_rows": [len(left)],
"left_columns": [len(left.columns)],
"right_rows": [len(right)],
"right_columns": [len(right.columns)],
}
)

schema = "key long, left_rows long, left_columns long, right_rows long, right_columns long"

expected = [
Row(key=0, left_rows=5000000, left_columns=12, right_rows=25000, right_columns=22),
Row(key=1, left_rows=5000000, left_columns=12, right_rows=25000, right_columns=22),
Row(key=2, left_rows=0, left_columns=12, right_rows=25000, right_columns=22),
Row(key=3, left_rows=0, left_columns=12, right_rows=25000, right_columns=22),
]

for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
with self.sql_conf(
{
"spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
"spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
}
):
result = (
df1.groupby("key")
.cogroup(df2.groupby("key"))
.applyInPandas(summarize, schema=schema)
.sort("key")
.collect()
)

self.assertEqual(expected, result)


class CogroupedApplyInPandasTests(CogroupedApplyInPandasTestsMixin, ReusedSQLTestCase):
pass