
Commit 1629d28

_task_to_table to _task_to_record_batches
committed · 1 parent e61ef57 · commit 1629d28

File tree: 1 file changed, +23 -33 lines changed


pyiceberg/io/pyarrow.py

Lines changed: 23 additions & 33 deletions
@@ -655,12 +655,12 @@ def _read_deletes(fs: FileSystem, data_file: DataFile) -> Dict[str, pa.ChunkedAr
     }


-def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows: int) -> pa.Array:
+def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], start_index: int, end_index: int) -> pa.Array:
     if len(positional_deletes) == 1:
         all_chunks = positional_deletes[0]
     else:
         all_chunks = pa.chunked_array(itertools.chain(*[arr.chunks for arr in positional_deletes]))
-    return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False)
+    return np.subtract(np.setdiff1d(np.arange(start_index, end_index), all_chunks, assume_unique=False), start_index)


 def pyarrow_to_schema(schema: pa.Schema, name_mapping: Optional[NameMapping] = None) -> Schema:
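
The window-based signature returns indices relative to start_index, so they can be applied directly to a single record batch with take(). A minimal standalone sketch of the arithmetic (the delete positions below are made up for illustration):

import numpy as np

# File-absolute row positions that have been deleted, as they would come
# out of a positional delete file (illustrative values).
deletes = np.array([3, 5, 10])

# Half-open window covered by the current record batch: file rows 4..7.
start_index, end_index = 4, 8

# Same arithmetic as the new _combine_positional_deletes: keep the positions
# in [start_index, end_index) that are not deleted, then shift them so they
# are relative to the start of the batch.
surviving = np.subtract(
    np.setdiff1d(np.arange(start_index, end_index), deletes, assume_unique=False),
    start_index,
)
print(surviving)  # [0 2 3]: batch rows 0, 2, 3 survive (file rows 4, 6, 7); file row 5 is deleted
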
@@ -967,17 +967,16 @@ def _field_id(self, field: pa.Field) -> int:
         return -1


-def _task_to_table(
+def _task_to_record_batches(
     fs: FileSystem,
     task: FileScanTask,
     bound_row_filter: BooleanExpression,
     projected_schema: Schema,
     projected_field_ids: Set[int],
     positional_deletes: Optional[List[ChunkedArray]],
     case_sensitive: bool,
-    limit: Optional[int] = None,
     name_mapping: Optional[NameMapping] = None,
-) -> Optional[pa.Table]:
+) -> Iterator[pa.RecordBatch]:
     _, _, path = PyArrowFileIO.parse_location(task.file.file_path)
     arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
     with fs.open_input_file(path) as fin:
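
Dropping the limit parameter moves early termination to whoever consumes the generator. A hypothetical helper showing how a caller could rebuild the old table-with-limit behaviour from the batch stream (batches_to_table is illustrative, not part of this commit):

from typing import Iterator, Optional
import pyarrow as pa

def batches_to_table(batches: Iterator[pa.RecordBatch], limit: Optional[int] = None) -> Optional[pa.Table]:
    # Collect streamed batches into a single Table, stopping early once
    # `limit` rows have been gathered (what the removed argument used to do).
    collected, rows = [], 0
    for batch in batches:
        if limit is not None and rows + len(batch) > limit:
            batch = batch.slice(0, limit - rows)
        collected.append(batch)
        rows += len(batch)
        if limit is not None and rows >= limit:
            break
    return pa.Table.from_batches(collected) if collected else None
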
@@ -1005,36 +1004,27 @@ def _task_to_table(
             columns=[col.name for col in file_project_schema.columns],
         )

-        if positional_deletes:
-            # Create the mask of indices that we're interested in
-            indices = _combine_positional_deletes(positional_deletes, fragment.count_rows())
-
-            if limit:
-                if pyarrow_filter is not None:
-                    # In case of the filter, we don't exactly know how many rows
-                    # we need to fetch upfront, can be optimized in the future:
-                    # https://github.com/apache/arrow/issues/35301
-                    arrow_table = fragment_scanner.take(indices)
-                    arrow_table = arrow_table.filter(pyarrow_filter)
-                    arrow_table = arrow_table.slice(0, limit)
-                else:
-                    arrow_table = fragment_scanner.take(indices[0:limit])
-            else:
-                arrow_table = fragment_scanner.take(indices)
+        current_index = 0
+        batches = fragment_scanner.to_batches()
+        for batch in batches:
+            if positional_deletes:
+                # Create the mask of indices that we're interested in
+                indices = _combine_positional_deletes(positional_deletes, current_index, len(batch))
+
+                batch = batch.take(indices)
             # Apply the user filter
             if pyarrow_filter is not None:
+                # we need to switch back and forth between RecordBatch and Table
+                # as Expression filter isn't yet supported in RecordBatch
+                # https://github.com/apache/arrow/issues/39220
+                arrow_table = pa.Table.from_batches([batch])
                 arrow_table = arrow_table.filter(pyarrow_filter)
-        else:
-            # If there are no deletes, we can just take the head
-            # and the user-filter is already applied
-            if limit:
-                arrow_table = fragment_scanner.head(limit)
+                arrow_batches = arrow_table.to_batches()
+                for arrow_batch in arrow_batches:
+                    yield to_requested_schema(projected_schema, file_project_schema, arrow_table)
             else:
-                arrow_table = fragment_scanner.to_table()
-
-        if len(arrow_table) < 1:
-            return None
-        return to_requested_schema(projected_schema, file_project_schema, arrow_table)
+                yield to_requested_schema(projected_schema, file_project_schema, arrow_table)
+            current_index += len(batch)


 def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
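
The Table round-trip inside the loop exists because a pyarrow dataset Expression can filter a Table but, at the time of this commit, not a RecordBatch (apache/arrow#39220). A small standalone illustration of that detour (the sample data and predicate are made up):

import pyarrow as pa
import pyarrow.compute as pc

batch = pa.RecordBatch.from_pydict({"id": [1, 2, 3, 4], "flag": [True, False, True, False]})

# Build an Expression (the kind of object expression_to_pyarrow produces)
# and apply it by bouncing through a Table, then converting back to batches.
expr = pc.field("id") > 2
filtered = pa.Table.from_batches([batch]).filter(expr).to_batches()
for fb in filtered:
    print(fb.to_pydict())  # {'id': [3, 4], 'flag': [True, False]}
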
@@ -1147,7 +1137,7 @@ def project_table(
     return result


-def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.Table) -> pa.Table:
+def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.RecordBatch) -> pa.RecordBatch:
     struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))

     arrays = []
@@ -1156,7 +1146,7 @@ def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa
         array = struct_array.field(pos)
         arrays.append(array)
         fields.append(pa.field(field.name, array.type, field.optional))
-    return pa.Table.from_arrays(arrays, schema=pa.schema(fields))
+    return pa.RecordBatch.from_arrays(arrays, schema=pa.schema(fields))


 class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
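
With to_requested_schema now taking and returning a RecordBatch, the final reassembly switches from pa.Table.from_arrays to the analogous pa.RecordBatch.from_arrays. A tiny side-by-side illustration of the two constructors (column data made up):

import pyarrow as pa

ids = pa.array([1, 2, 3], type=pa.int32())
names = pa.array(["a", "b", "c"], type=pa.string())
schema = pa.schema([pa.field("id", pa.int32(), nullable=False), pa.field("name", pa.string())])

# Same column data, assembled either as a single in-memory RecordBatch
# (so callers can keep streaming) or as a Table.
batch = pa.RecordBatch.from_arrays([ids, names], schema=schema)
table = pa.Table.from_arrays([ids, names], schema=schema)
assert batch.schema == table.schema and batch.num_rows == table.num_rows
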
