@@ -771,6 +771,16 @@ def test_execution_plan(aggregate_df):
771
771
assert rows_returned == 5
772
772
773
773
774
@pytest.mark.asyncio
async def test_async_iteration_of_df(aggregate_df):
    """Verify that a DataFrame's execution stream supports ``async for``.

    Consumes every batch produced by ``execute_stream`` asynchronously,
    accumulating the row count, and checks the expected total of 5 rows.
    """
    rows_returned = 0
    async for batch in aggregate_df.execute_stream():
        assert batch is not None
        # RecordBatch.num_rows is the canonical row count; indexing the
        # first column (len(batch.to_pyarrow()[0])) would raise on a
        # zero-column batch and obscures the intent.
        rows_returned += batch.to_pyarrow().num_rows

    assert rows_returned == 5
774
784
def test_repartition (df ):
775
785
df .repartition (2 )
776
786
@@ -958,6 +968,18 @@ def test_execute_stream(df):
958
968
assert not list (stream ) # after one iteration the generator must be exhausted
959
969
960
970
971
@pytest.mark.asyncio
async def test_execute_stream_async(df):
    """Check async consumption of an execution stream and its exhaustion.

    Every batch yielded must be non-None, and once fully consumed the
    stream must yield nothing further.
    """
    stream = df.execute_stream()

    # Consume the stream, asserting each batch as it arrives.
    async for batch in stream:
        assert batch is not None

    # A second pass over the same stream must produce no batches.
    leftovers = [batch async for batch in stream]
    assert not leftovers
961
983
@pytest .mark .parametrize ("schema" , [True , False ])
962
984
def test_execute_stream_to_arrow_table (df , schema ):
963
985
stream = df .execute_stream ()
@@ -974,6 +996,25 @@ def test_execute_stream_to_arrow_table(df, schema):
974
996
assert set (pyarrow_table .column_names ) == {"a" , "b" , "c" }
975
997
976
998
999
@pytest.mark.asyncio
@pytest.mark.parametrize("schema", [True, False])
async def test_execute_stream_to_arrow_table_async(df, schema):
    """Build a pyarrow Table from an async stream, with and without a schema.

    Parametrized so the ``pa.Table.from_batches`` call is exercised both
    with an explicit schema and with schema inference from the batches.
    """
    stream = df.execute_stream()
    record_batches = [batch.to_pyarrow() async for batch in stream]

    # Only fetch the DataFrame schema when the parametrized case asks
    # for it, mirroring the explicit/inferred branches.
    kwargs = {"schema": df.schema()} if schema else {}
    pyarrow_table = pa.Table.from_batches(record_batches, **kwargs)

    assert isinstance(pyarrow_table, pa.Table)
    assert pyarrow_table.shape == (3, 3)
    assert set(pyarrow_table.column_names) == {"a", "b", "c"}
977
1018
def test_execute_stream_partitioned (df ):
978
1019
streams = df .execute_stream_partitioned ()
979
1020
assert all (batch is not None for stream in streams for batch in stream )
@@ -982,6 +1023,19 @@ def test_execute_stream_partitioned(df):
982
1023
) # after one iteration all generators must be exhausted
983
1024
984
1025
1026
@pytest.mark.asyncio
async def test_execute_stream_partitioned_async(df):
    """Consume each partitioned execution stream asynchronously.

    For every partition stream: all yielded batches must be non-None,
    and the stream must be exhausted after one full iteration. The
    exhaustion check lives inside the loop so that *each* partition is
    verified, not only the last one.
    """
    streams = df.execute_stream_partitioned()

    for stream in streams:
        batches = [batch async for batch in stream]
        assert all(batch is not None for batch in batches)

        # Ensure this stream is exhausted after iteration.
        remaining_batches = [batch async for batch in stream]
        assert not remaining_batches
985
1039
def test_empty_to_arrow_table (df ):
986
1040
# Convert empty datafusion dataframe to pyarrow Table
987
1041
pyarrow_table = df .limit (0 ).to_arrow_table ()
0 commit comments