Skip to content

Commit 4f45703

Browse files
authored
added pytest asyncio tests (#1063)
1 parent 55141ba commit 4f45703

File tree

3 files changed

+71
-1
lines changed

3 files changed

+71
-1
lines changed

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ dev = [
150150
"maturin>=1.8.1",
151151
"numpy>1.25.0",
152152
"pytest>=7.4.4",
153+
"pytest-asyncio>=0.23.3",
153154
"ruff>=0.9.1",
154155
"toml>=0.10.2",
155156
"pygithub==2.5.0",

python/tests/test_dataframe.py

+54
Original file line numberDiff line numberDiff line change
@@ -771,6 +771,16 @@ def test_execution_plan(aggregate_df):
771771
assert rows_returned == 5
772772

773773

774+
@pytest.mark.asyncio
async def test_async_iteration_of_df(aggregate_df):
    """Verify the DataFrame's record-batch stream supports ``async for``.

    Consumes every batch of the aggregate fixture asynchronously and checks
    the total row count across batches matches the expected 5 rows.
    """
    total_rows = 0
    async for record_batch in aggregate_df.execute_stream():
        assert record_batch is not None
        # to_pyarrow()[0] is the first column; its length is the batch row count.
        total_rows += len(record_batch.to_pyarrow()[0])

    assert total_rows == 5
782+
783+
774784
def test_repartition(df):
775785
df.repartition(2)
776786

@@ -958,6 +968,18 @@ def test_execute_stream(df):
958968
assert not list(stream) # after one iteration the generator must be exhausted
959969

960970

971+
@pytest.mark.asyncio
async def test_execute_stream_async(df):
    """Consume ``execute_stream()`` asynchronously and verify it is single-use."""
    record_stream = df.execute_stream()

    # First pass: every yielded batch must be non-null.
    async for item in record_stream:
        assert item is not None

    # After consuming all batches, the stream should be exhausted
    leftovers = [item async for item in record_stream]
    assert not leftovers
981+
982+
961983
@pytest.mark.parametrize("schema", [True, False])
962984
def test_execute_stream_to_arrow_table(df, schema):
963985
stream = df.execute_stream()
@@ -974,6 +996,25 @@ def test_execute_stream_to_arrow_table(df, schema):
974996
assert set(pyarrow_table.column_names) == {"a", "b", "c"}
975997

976998

999+
@pytest.mark.asyncio
@pytest.mark.parametrize("schema", [True, False])
async def test_execute_stream_to_arrow_table_async(df, schema):
    """Build a pyarrow Table from an async batch stream.

    Parametrized to exercise both paths of ``pa.Table.from_batches``:
    with an explicit schema taken from the DataFrame, and without one.
    """
    arrow_batches = [rb.to_pyarrow() async for rb in df.execute_stream()]

    if schema:
        table = pa.Table.from_batches(arrow_batches, schema=df.schema())
    else:
        table = pa.Table.from_batches(arrow_batches)

    assert isinstance(table, pa.Table)
    assert table.shape == (3, 3)
    assert set(table.column_names) == {"a", "b", "c"}
1016+
1017+
9771018
def test_execute_stream_partitioned(df):
9781019
streams = df.execute_stream_partitioned()
9791020
assert all(batch is not None for stream in streams for batch in stream)
@@ -982,6 +1023,19 @@ def test_execute_stream_partitioned(df):
9821023
) # after one iteration all generators must be exhausted
9831024

9841025

1026+
@pytest.mark.asyncio
async def test_execute_stream_partitioned_async(df):
    """Each partitioned stream yields non-null batches and is single-use."""
    for partition in df.execute_stream_partitioned():
        async for rb in partition:
            assert rb is not None

        # Ensure the stream is exhausted after iteration
        assert not [rb async for rb in partition]
1037+
1038+
9851039
def test_empty_to_arrow_table(df):
9861040
# Convert empty datafusion dataframe to pyarrow Table
9871041
pyarrow_table = df.limit(0).to_arrow_table()

uv.lock

+16-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)