Commit dc76dad

_task_to_table to _task_to_record_batches
1 parent 65a03d2 commit dc76dad

File tree

1 file changed: +23 -33 lines changed

pyiceberg/io/pyarrow.py

Lines changed: 23 additions & 33 deletions
@@ -648,12 +648,12 @@ def _read_deletes(fs: FileSystem, data_file: DataFile) -> Dict[str, pa.ChunkedAr
     }
 
 
-def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows: int) -> pa.Array:
+def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], start_index: int, end_index: int) -> pa.Array:
     if len(positional_deletes) == 1:
         all_chunks = positional_deletes[0]
     else:
         all_chunks = pa.chunked_array(itertools.chain(*[arr.chunks for arr in positional_deletes]))
-    return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False)
+    return np.subtract(np.setdiff1d(np.arange(start_index, end_index), all_chunks, assume_unique=False), start_index)
 
 
 def pyarrow_to_schema(schema: pa.Schema, name_mapping: Optional[NameMapping] = None) -> Schema:
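The reworked helper now takes a window of file rows and returns batch-local indices of the rows that survive the positional deletes. A minimal sketch of that arithmetic, with hypothetical delete positions and window, and plain NumPy arrays standing in for the ChunkedArray inputs:

import numpy as np

# Hypothetical example: file rows 2 and 5 carry positional deletes,
# and the current batch covers file rows [4, 8).
deletes = np.array([2, 5])
start_index, end_index = 4, 8

# Keep every file-level row index in the window that is not deleted,
# then shift to batch-local positions by subtracting the window start.
keep = np.subtract(np.setdiff1d(np.arange(start_index, end_index), deletes, assume_unique=False), start_index)
print(keep)  # [0 2 3] -> batch rows 0, 2 and 3 survive (file rows 4, 6 and 7)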
@@ -960,17 +960,16 @@ def _field_id(self, field: pa.Field) -> int:
         return -1
 
 
-def _task_to_table(
+def _task_to_record_batches(
     fs: FileSystem,
     task: FileScanTask,
     bound_row_filter: BooleanExpression,
     projected_schema: Schema,
     projected_field_ids: Set[int],
     positional_deletes: Optional[List[ChunkedArray]],
     case_sensitive: bool,
-    limit: Optional[int] = None,
     name_mapping: Optional[NameMapping] = None,
-) -> Optional[pa.Table]:
+) -> Iterator[pa.RecordBatch]:
     _, _, path = PyArrowFileIO.parse_location(task.file.file_path)
     arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
     with fs.open_input_file(path) as fin:
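For context, the function now returns a generator of record batches instead of a single table. A minimal, hypothetical sketch (not part of this commit) of how such an Iterator[pa.RecordBatch] is typically consumed:

from typing import Iterator

import pyarrow as pa


def batches() -> Iterator[pa.RecordBatch]:
    # Hypothetical stand-in for _task_to_record_batches: yield two small batches.
    schema = pa.schema([pa.field("id", pa.int64())])
    yield pa.RecordBatch.from_arrays([pa.array([1, 2], type=pa.int64())], schema=schema)
    yield pa.RecordBatch.from_arrays([pa.array([3], type=pa.int64())], schema=schema)


# Callers can stream batch by batch, or materialize everything at once.
table = pa.Table.from_batches(list(batches()))
print(table.num_rows)  # 3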
@@ -998,36 +997,27 @@ def _task_to_table(
             columns=[col.name for col in file_project_schema.columns],
         )
 
-        if positional_deletes:
-            # Create the mask of indices that we're interested in
-            indices = _combine_positional_deletes(positional_deletes, fragment.count_rows())
-
-            if limit:
-                if pyarrow_filter is not None:
-                    # In case of the filter, we don't exactly know how many rows
-                    # we need to fetch upfront, can be optimized in the future:
-                    # https://github.com/apache/arrow/issues/35301
-                    arrow_table = fragment_scanner.take(indices)
-                    arrow_table = arrow_table.filter(pyarrow_filter)
-                    arrow_table = arrow_table.slice(0, limit)
-                else:
-                    arrow_table = fragment_scanner.take(indices[0:limit])
-            else:
-                arrow_table = fragment_scanner.take(indices)
+        current_index = 0
+        batches = fragment_scanner.to_batches()
+        for batch in batches:
+            if positional_deletes:
+                # Create the mask of indices that we're interested in
+                indices = _combine_positional_deletes(positional_deletes, current_index, len(batch))
+
+                batch = batch.take(indices)
             # Apply the user filter
             if pyarrow_filter is not None:
+                # we need to switch back and forth between RecordBatch and Table
+                # as Expression filter isn't yet supported in RecordBatch
+                # https://github.com/apache/arrow/issues/39220
+                arrow_table = pa.Table.from_batches([batch])
                 arrow_table = arrow_table.filter(pyarrow_filter)
-        else:
-            # If there are no deletes, we can just take the head
-            # and the user-filter is already applied
-            if limit:
-                arrow_table = fragment_scanner.head(limit)
+                arrow_batches = arrow_table.to_batches()
+                for arrow_batch in arrow_batches:
+                    yield to_requested_schema(projected_schema, file_project_schema, arrow_table)
             else:
-                arrow_table = fragment_scanner.to_table()
-
-        if len(arrow_table) < 1:
-            return None
-        return to_requested_schema(projected_schema, file_project_schema, arrow_table)
+                yield to_requested_schema(projected_schema, file_project_schema, arrow_table)
+            current_index += len(batch)
 
 
 def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
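The filter branch above routes each batch through a pa.Table because Expression-based filtering was not yet supported on RecordBatch (apache/arrow#39220, linked in the diff). A standalone sketch of that round-trip, with made-up data and a made-up filter:

import pyarrow as pa
import pyarrow.compute as pc

# Hypothetical batch and filter, standing in for the scanner output and the
# converted Iceberg row filter.
batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3, 4])], names=["id"])
pyarrow_filter = pc.field("id") > 2

# Round-trip through a Table to apply the Expression, then back to batches.
arrow_table = pa.Table.from_batches([batch])
arrow_table = arrow_table.filter(pyarrow_filter)
for filtered_batch in arrow_table.to_batches():
    print(filtered_batch.num_rows)  # 2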
@@ -1140,7 +1130,7 @@ def project_table(
     return result
 
 
-def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.Table) -> pa.Table:
+def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.RecordBatch) -> pa.RecordBatch:
     struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
 
     arrays = []
@@ -1149,7 +1139,7 @@ def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa
         array = struct_array.field(pos)
         arrays.append(array)
         fields.append(pa.field(field.name, array.type, field.optional))
-    return pa.Table.from_arrays(arrays, schema=pa.schema(fields))
+    return pa.RecordBatch.from_arrays(arrays, schema=pa.schema(fields))
 
 
 class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
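to_requested_schema now assembles its output with pa.RecordBatch.from_arrays rather than pa.Table.from_arrays. A small, made-up example of that constructor with an explicit schema:

import pyarrow as pa

# Hypothetical projected arrays and fields, mirroring the new return statement.
arrays = [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])]
fields = [pa.field("id", pa.int64(), nullable=False), pa.field("name", pa.string(), nullable=True)]

batch = pa.RecordBatch.from_arrays(arrays, schema=pa.schema(fields))
print(batch.schema)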
