@@ -655,12 +655,12 @@ def _read_deletes(fs: FileSystem, data_file: DataFile) -> Dict[str, pa.ChunkedAr
     }


-def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows: int) -> pa.Array:
+def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], start_index: int, end_index: int) -> pa.Array:
     if len(positional_deletes) == 1:
         all_chunks = positional_deletes[0]
     else:
         all_chunks = pa.chunked_array(itertools.chain(*[arr.chunks for arr in positional_deletes]))
-    return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False)
+    return np.subtract(np.setdiff1d(np.arange(start_index, end_index), all_chunks, assume_unique=False), start_index)


 def pyarrow_to_schema(schema: pa.Schema, name_mapping: Optional[NameMapping] = None) -> Schema:
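A quick sketch of what the reworked helper returns, assuming the new `_combine_positional_deletes` above is in scope along with the module's numpy/pyarrow imports; the sample values are illustrative only. Deletes come in as global row positions in the data file, and the result is the surviving rows expressed as indices relative to `start_index`:

import pyarrow as pa

# Positional deletes are global row positions within the data file.
deletes = [pa.chunked_array([[1, 5, 6]])]

# For a batch covering file rows 4..7 (start_index=4, end_index=8), rows 4 and 7
# survive and come back as batch-local indices 0 and 3.
keep = _combine_positional_deletes(deletes, start_index=4, end_index=8)
print(keep)  # [0 3]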
@@ -967,17 +967,16 @@ def _field_id(self, field: pa.Field) -> int:
         return -1


-def _task_to_table(
+def _task_to_record_batches(
     fs: FileSystem,
     task: FileScanTask,
     bound_row_filter: BooleanExpression,
     projected_schema: Schema,
     projected_field_ids: Set[int],
     positional_deletes: Optional[List[ChunkedArray]],
     case_sensitive: bool,
-    limit: Optional[int] = None,
     name_mapping: Optional[NameMapping] = None,
-) -> Optional[pa.Table]:
+) -> Iterator[pa.RecordBatch]:
     _, _, path = PyArrowFileIO.parse_location(task.file.file_path)
     arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
     with fs.open_input_file(path) as fin:
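With `limit` dropped from the signature, any row limit has to be enforced by whoever consumes the returned `Iterator[pa.RecordBatch]`. A hypothetical caller-side helper (not part of this change) sketching one way to do that with plain pyarrow:

from typing import Iterator

import pyarrow as pa


def _limit_batches(batches: Iterator[pa.RecordBatch], limit: int) -> Iterator[pa.RecordBatch]:
    # Hypothetical helper: stop yielding once `limit` rows have been produced.
    remaining = limit
    for batch in batches:
        if remaining <= 0:
            return
        if len(batch) > remaining:
            batch = batch.slice(0, remaining)
        yield batch
        remaining -= len(batch)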
@@ -1005,36 +1004,27 @@ def _task_to_table(
             columns=[col.name for col in file_project_schema.columns],
         )

-        if positional_deletes:
-            # Create the mask of indices that we're interested in
-            indices = _combine_positional_deletes(positional_deletes, fragment.count_rows())
-
-            if limit:
-                if pyarrow_filter is not None:
-                    # In case of the filter, we don't exactly know how many rows
-                    # we need to fetch upfront, can be optimized in the future:
-                    # https://github.com/apache/arrow/issues/35301
-                    arrow_table = fragment_scanner.take(indices)
-                    arrow_table = arrow_table.filter(pyarrow_filter)
-                    arrow_table = arrow_table.slice(0, limit)
-                else:
-                    arrow_table = fragment_scanner.take(indices[0:limit])
-            else:
-                arrow_table = fragment_scanner.take(indices)
+        current_index = 0
+        batches = fragment_scanner.to_batches()
+        for batch in batches:
+            if positional_deletes:
+                # Create the mask of indices that we're interested in
+                indices = _combine_positional_deletes(positional_deletes, current_index, len(batch))
+
+                batch = batch.take(indices)
             # Apply the user filter
             if pyarrow_filter is not None:
+                # we need to switch back and forth between RecordBatch and Table
+                # as Expression filter isn't yet supported in RecordBatch
+                # https://github.com/apache/arrow/issues/39220
+                arrow_table = pa.Table.from_batches([batch])
                 arrow_table = arrow_table.filter(pyarrow_filter)
-        else:
-            # If there are no deletes, we can just take the head
-            # and the user-filter is already applied
-            if limit:
-                arrow_table = fragment_scanner.head(limit)
+                arrow_batches = arrow_table.to_batches()
+                for arrow_batch in arrow_batches:
+                    yield to_requested_schema(projected_schema, file_project_schema, arrow_table)
             else:
-                arrow_table = fragment_scanner.to_table()
-
-        if len(arrow_table) < 1:
-            return None
-        return to_requested_schema(projected_schema, file_project_schema, arrow_table)
+                yield to_requested_schema(projected_schema, file_project_schema, arrow_table)
+            current_index += len(batch)


 def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
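The comments added above point at the RecordBatch/Table round-trip needed because Expression-based filtering is only available on `pa.Table` (arrow issue 39220). A self-contained illustration of that round-trip with plain pyarrow, independent of the Iceberg schema machinery; the column names and filter are made up:

import pyarrow as pa
import pyarrow.compute as pc

batch = pa.RecordBatch.from_pydict({"id": [1, 2, 3, 4], "flag": [True, False, True, False]})

# Expression filters work on Table but not (yet) on RecordBatch, so wrap the
# batch in a single-batch Table, filter there, and convert back to batches.
filtered = pa.Table.from_batches([batch]).filter(pc.field("id") > 2)
for out in filtered.to_batches():
    print(out.num_rows)  # 2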
@@ -1147,7 +1137,7 @@ def project_table(
     return result


-def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.Table) -> pa.Table:
+def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.RecordBatch) -> pa.RecordBatch:
     struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))

     arrays = []
@@ -1156,7 +1146,7 @@ def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa
         array = struct_array.field(pos)
         arrays.append(array)
         fields.append(pa.field(field.name, array.type, field.optional))
-    return pa.Table.from_arrays(arrays, schema=pa.schema(fields))
+    return pa.RecordBatch.from_arrays(arrays, schema=pa.schema(fields))


 class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
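Since `to_requested_schema` now emits one `pa.RecordBatch` per batch, a consumer that still wants the old single-table result can collect the generator's output and stitch it back together. A hedged sketch; the helper name is made up and not part of this diff:

from typing import Iterator, Optional

import pyarrow as pa


def _batches_to_table(batches: Iterator[pa.RecordBatch]) -> Optional[pa.Table]:
    # Hypothetical consumer: materialize the projected record batches into one
    # table; returning None for an empty scan mirrors the old Optional[pa.Table]
    # contract that the generator version drops.
    collected = list(batches)
    return pa.Table.from_batches(collected) if collected else None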