Commit f604b15

to_arrow_batches
1 parent 1629d28 commit f604b15

2 files changed: +116 -15 lines changed


pyiceberg/io/pyarrow.py

Lines changed: 103 additions & 15 deletions
@@ -1009,24 +1009,39 @@ def _task_to_record_batches(
         for batch in batches:
             if positional_deletes:
                 # Create the mask of indices that we're interested in
-                indices = _combine_positional_deletes(positional_deletes, current_index, len(batch))
-
+                indices = _combine_positional_deletes(positional_deletes, current_index, current_index + len(batch))
+                print(f"DEBUG: {indices=} {current_index=} {len(batch)=}")
+                print(f"{batch=}")
                 batch = batch.take(indices)
+                print(f"{batch=}")
             # Apply the user filter
             if pyarrow_filter is not None:
                 # we need to switch back and forth between RecordBatch and Table
                 # as Expression filter isn't yet supported in RecordBatch
                 # https://github.com/apache/arrow/issues/39220
                 arrow_table = pa.Table.from_batches([batch])
                 arrow_table = arrow_table.filter(pyarrow_filter)
-                arrow_batches = arrow_table.to_batches()
-                for arrow_batch in arrow_batches:
-                    yield to_requested_schema(projected_schema, file_project_schema, arrow_table)
-            else:
-                yield to_requested_schema(projected_schema, file_project_schema, arrow_table)
+                batch = arrow_table.to_batches()[0]
+            yield to_requested_schema(projected_schema, file_project_schema, batch)
             current_index += len(batch)
 
 
+def _task_to_table(
+    fs: FileSystem,
+    task: FileScanTask,
+    bound_row_filter: BooleanExpression,
+    projected_schema: Schema,
+    projected_field_ids: Set[int],
+    positional_deletes: Optional[List[ChunkedArray]],
+    case_sensitive: bool,
+    name_mapping: Optional[NameMapping] = None,
+) -> pa.Table:
+    batches = _task_to_record_batches(
+        fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping
+    )
+    return pa.Table.from_batches(batches, schema=schema_to_pyarrow(projected_schema))
+
+
 def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
     deletes_per_file: Dict[str, List[ChunkedArray]] = {}
     unique_deletes = set(itertools.chain.from_iterable([task.delete_files for task in tasks]))
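
The key change in this hunk is the end bound passed to _combine_positional_deletes: delete positions are file-global, so a batch starting at current_index covers rows [current_index, current_index + len(batch)), not [current_index, len(batch)). A minimal standalone sketch of the same windowing (keep_indices is a hypothetical stand-in over a plain set of deleted positions, not pyiceberg's ChunkedArray-based helper):

    import pyarrow as pa

    def keep_indices(deleted: set, start: int, end: int) -> list:
        # Translate file-global delete positions into batch-local indices to keep.
        return [pos - start for pos in range(start, end) if pos not in deleted]

    batches = [
        pa.RecordBatch.from_pydict({"id": [0, 1, 2, 3, 4]}),
        pa.RecordBatch.from_pydict({"id": [5, 6, 7, 8, 9]}),
    ]
    deleted = {1, 6}  # file-global row positions that were deleted
    current_index = 0
    for batch in batches:
        # The end bound must be current_index + len(batch), not len(batch).
        indices = keep_indices(deleted, current_index, current_index + len(batch))
        print(batch.take(pa.array(indices)).column("id"))  # [0, 2, 3, 4], then [5, 7, 8, 9]
        current_index += len(batch)
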
@@ -1103,7 +1118,6 @@ def project_table(
             projected_field_ids,
             deletes_per_file.get(task.file.file_path),
             case_sensitive,
-            limit,
             table_metadata.name_mapping(),
         )
         for task in tasks

@@ -1137,8 +1151,78 @@ def project_table(
     return result
 
 
-def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.RecordBatch) -> pa.RecordBatch:
-    struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
+def project_batches(
+    tasks: Iterable[FileScanTask],
+    table_metadata: TableMetadata,
+    io: FileIO,
+    row_filter: BooleanExpression,
+    projected_schema: Schema,
+    case_sensitive: bool = True,
+    limit: Optional[int] = None,
+) -> Iterator[pa.RecordBatch]:
+    """Resolve the right columns based on the identifier.
+
+    Args:
+        tasks (Iterable[FileScanTask]): A URI or a path to a local file.
+        table_metadata (TableMetadata): The table metadata of the table that's being queried
+        io (FileIO): A FileIO to open streams to the object store
+        row_filter (BooleanExpression): The expression for filtering rows.
+        projected_schema (Schema): The output schema.
+        case_sensitive (bool): Case sensitivity when looking up column names.
+        limit (Optional[int]): Limit the number of records.
+
+    Raises:
+        ResolveError: When an incompatible query is done.
+    """
+    scheme, netloc, _ = PyArrowFileIO.parse_location(table_metadata.location)
+    if isinstance(io, PyArrowFileIO):
+        fs = io.fs_by_scheme(scheme, netloc)
+    else:
+        try:
+            from pyiceberg.io.fsspec import FsspecFileIO
+
+            if isinstance(io, FsspecFileIO):
+                from pyarrow.fs import PyFileSystem
+
+                fs = PyFileSystem(FSSpecHandler(io.get_fs(scheme)))
+            else:
+                raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}")
+        except ModuleNotFoundError as e:
+            # When FsSpec is not installed
+            raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}") from e
+
+    bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)
+
+    projected_field_ids = {
+        id for id in projected_schema.field_ids if not isinstance(projected_schema.find_type(id), (MapType, ListType))
+    }.union(extract_field_ids(bound_row_filter))
+
+    deletes_per_file = _read_all_delete_files(fs, tasks)
+
+    total_row_count = 0
+
+    for task in tasks:
+        batches = _task_to_record_batches(
+            fs,
+            task,
+            bound_row_filter,
+            projected_schema,
+            projected_field_ids,
+            deletes_per_file.get(task.file.file_path),
+            case_sensitive,
+            table_metadata.name_mapping(),
+        )
+        for batch in batches:
+            if limit is not None:
+                if total_row_count + len(batch) >= limit:
+                    yield batch.take(limit - total_row_count)
+                    break
+            yield batch
+            total_row_count += len(batch)
+
+
+def to_requested_schema(requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch) -> pa.RecordBatch:
+    struct_array = visit_with_partner(requested_schema, batch, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
 
     arrays = []
     fields = []
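
project_batches streams RecordBatches and stops once limit rows have been produced. A simplified, standalone sketch of that capping logic (assumption: RecordBatch.slice is used here for clarity; the commit itself trims the final batch with batch.take):

    from typing import Iterator, Optional

    import pyarrow as pa

    def limit_batches(batches: Iterator[pa.RecordBatch], limit: Optional[int]) -> Iterator[pa.RecordBatch]:
        total_row_count = 0
        for batch in batches:
            if limit is not None and total_row_count + len(batch) >= limit:
                # Trim the last batch so exactly `limit` rows are yielded in total.
                yield batch.slice(0, limit - total_row_count)
                return
            yield batch
            total_row_count += len(batch)
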
@@ -1247,8 +1331,8 @@ def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: st
 
         if isinstance(partner_struct, pa.StructArray):
            return partner_struct.field(name)
-        elif isinstance(partner_struct, pa.Table):
-            return partner_struct.column(name).combine_chunks()
+        elif isinstance(partner_struct, pa.RecordBatch):
+            return partner_struct.column(name)
 
         return None
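
The field_partner change follows from handing a RecordBatch (rather than a Table) to the projection visitor: Table.column(name) returns a ChunkedArray, which is why the old code needed combine_chunks(), while RecordBatch.column(name) already returns a contiguous Array. A quick illustration:

    import pyarrow as pa

    tbl = pa.table({"x": [1, 2, 3]})
    batch = tbl.to_batches()[0]

    print(type(tbl.column("x")))    # ChunkedArray -> previously needed combine_chunks()
    print(type(batch.column("x")))  # Array -> can be handed to the visitor as-is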

@@ -1785,15 +1869,19 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
 
     def write_parquet(task: WriteTask) -> DataFile:
         table_schema = task.schema
-        arrow_table = pa.Table.from_batches(task.record_batches)
+
         # if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly
         # otherwise use the original schema
         if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema:
             file_schema = sanitized_schema
-            arrow_table = to_requested_schema(requested_schema=file_schema, file_schema=table_schema, table=arrow_table)
+            batches = [
+                to_requested_schema(requested_schema=file_schema, file_schema=table_schema, batch=batch)
+                for batch in task.record_batches
+            ]
         else:
             file_schema = table_schema
-
+            batches = task.record_batches
+        arrow_table = pa.Table.from_batches(batches)
         file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
         fo = io.new_output(file_path)
         with fo.create(overwrite=True) as fos:
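
write_parquet now runs to_requested_schema per record batch and only then assembles the Arrow table. The same pattern in isolation, with a hypothetical rename_batch standing in for the schema transformation:

    import pyarrow as pa

    def rename_batch(batch: pa.RecordBatch, names: list) -> pa.RecordBatch:
        # Stand-in for to_requested_schema: transform each batch independently.
        return pa.RecordBatch.from_arrays(batch.columns, names=names)

    source_batches = pa.table({"a col": [1, 2], "b col": [3, 4]}).to_batches()
    batches = [rename_batch(b, ["a_col", "b_col"]) for b in source_batches]
    arrow_table = pa.Table.from_batches(batches)
    print(arrow_table.column_names)  # ['a_col', 'b_col']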

pyiceberg/table/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -1763,6 +1763,19 @@ def to_arrow(self) -> pa.Table:
             limit=self.limit,
         )
 
+    def to_arrow_batches(self) -> pa.Table:
+        from pyiceberg.io.pyarrow import project_batches
+
+        return project_batches(
+            self.plan_files(),
+            self.table_metadata,
+            self.io,
+            self.row_filter,
+            self.projection(),
+            case_sensitive=self.case_sensitive,
+            limit=self.limit,
+        )
+
     def to_pandas(self, **kwargs: Any) -> pd.DataFrame:
         return self.to_arrow().to_pandas(**kwargs)
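
A hypothetical usage sketch of the new method (the catalog and table names below are placeholders, not part of this commit):

    from pyiceberg.catalog import load_catalog

    catalog = load_catalog("default")            # placeholder catalog name
    table = catalog.load_table("examples.taxi")  # placeholder table identifier

    # Stream the scan as pyarrow.RecordBatch objects instead of materializing
    # one large pyarrow.Table via to_arrow().
    for batch in table.scan(limit=1_000).to_arrow_batches():
        print(batch.num_rows)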
