Commit 46788ac

to_arrow_batches
1 parent dc76dad commit 46788ac

File tree

2 files changed: +116 -15 lines changed


pyiceberg/io/pyarrow.py

Lines changed: 103 additions & 15 deletions
@@ -1002,24 +1002,39 @@ def _task_to_record_batches(
         for batch in batches:
             if positional_deletes:
                 # Create the mask of indices that we're interested in
-                indices = _combine_positional_deletes(positional_deletes, current_index, len(batch))
-
+                indices = _combine_positional_deletes(positional_deletes, current_index, current_index + len(batch))
+                print(f"DEBUG: {indices=} {current_index=} {len(batch)=}")
+                print(f"{batch=}")
                 batch = batch.take(indices)
+                print(f"{batch=}")
             # Apply the user filter
             if pyarrow_filter is not None:
                 # we need to switch back and forth between RecordBatch and Table
                 # as Expression filter isn't yet supported in RecordBatch
                 # https://github.com/apache/arrow/issues/39220
                 arrow_table = pa.Table.from_batches([batch])
                 arrow_table = arrow_table.filter(pyarrow_filter)
-                arrow_batches = arrow_table.to_batches()
-                for arrow_batch in arrow_batches:
-                    yield to_requested_schema(projected_schema, file_project_schema, arrow_table)
-            else:
-                yield to_requested_schema(projected_schema, file_project_schema, arrow_table)
+                batch = arrow_table.to_batches()[0]
+            yield to_requested_schema(projected_schema, file_project_schema, batch)
             current_index += len(batch)
 
 
+def _task_to_table(
+    fs: FileSystem,
+    task: FileScanTask,
+    bound_row_filter: BooleanExpression,
+    projected_schema: Schema,
+    projected_field_ids: Set[int],
+    positional_deletes: Optional[List[ChunkedArray]],
+    case_sensitive: bool,
+    name_mapping: Optional[NameMapping] = None,
+) -> pa.Table:
+    batches = _task_to_record_batches(
+        fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping
+    )
+    return pa.Table.from_batches(batches, schema=schema_to_pyarrow(projected_schema))
+
+
 def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
     deletes_per_file: Dict[str, List[ChunkedArray]] = {}
     unique_deletes = set(itertools.chain.from_iterable([task.delete_files for task in tasks]))
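
Side note on the corrected call above: _combine_positional_deletes now receives the half-open global range [current_index, current_index + len(batch)) instead of a bare batch length, so file-level delete positions are matched against the rows the current batch actually covers. A minimal standalone sketch of that index arithmetic (illustrative only, not the pyiceberg implementation; surviving_batch_indices is a hypothetical helper):

import numpy as np

def surviving_batch_indices(deleted_positions, current_index, batch_len):
    # Global row positions covered by this batch: [current_index, current_index + batch_len)
    window = np.arange(current_index, current_index + batch_len)
    # Drop the globally deleted positions, then re-base to batch-local indices for RecordBatch.take
    keep = np.setdiff1d(window, np.asarray(deleted_positions))
    return keep - current_index

# Rows 5 and 6 of the file are deleted; the batch covers global rows 4..7
print(surviving_batch_indices([5, 6], current_index=4, batch_len=4))  # -> [0 3]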
@@ -1096,7 +1111,6 @@ def project_table(
             projected_field_ids,
             deletes_per_file.get(task.file.file_path),
             case_sensitive,
-            limit,
             table_metadata.name_mapping(),
         )
         for task in tasks
@@ -1130,8 +1144,78 @@ def project_table(
     return result
 
 
-def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.RecordBatch) -> pa.RecordBatch:
-    struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
+def project_batches(
+    tasks: Iterable[FileScanTask],
+    table_metadata: TableMetadata,
+    io: FileIO,
+    row_filter: BooleanExpression,
+    projected_schema: Schema,
+    case_sensitive: bool = True,
+    limit: Optional[int] = None,
+) -> Iterator[pa.RecordBatch]:
+    """Resolve the right columns based on the identifier.
+
+    Args:
+        tasks (Iterable[FileScanTask]): A URI or a path to a local file.
+        table_metadata (TableMetadata): The table metadata of the table that's being queried
+        io (FileIO): A FileIO to open streams to the object store
+        row_filter (BooleanExpression): The expression for filtering rows.
+        projected_schema (Schema): The output schema.
+        case_sensitive (bool): Case sensitivity when looking up column names.
+        limit (Optional[int]): Limit the number of records.
+
+    Raises:
+        ResolveError: When an incompatible query is done.
+    """
+    scheme, netloc, _ = PyArrowFileIO.parse_location(table_metadata.location)
+    if isinstance(io, PyArrowFileIO):
+        fs = io.fs_by_scheme(scheme, netloc)
+    else:
+        try:
+            from pyiceberg.io.fsspec import FsspecFileIO
+
+            if isinstance(io, FsspecFileIO):
+                from pyarrow.fs import PyFileSystem
+
+                fs = PyFileSystem(FSSpecHandler(io.get_fs(scheme)))
+            else:
+                raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}")
+        except ModuleNotFoundError as e:
+            # When FsSpec is not installed
+            raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}") from e
+
+    bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)
+
+    projected_field_ids = {
+        id for id in projected_schema.field_ids if not isinstance(projected_schema.find_type(id), (MapType, ListType))
+    }.union(extract_field_ids(bound_row_filter))
+
+    deletes_per_file = _read_all_delete_files(fs, tasks)
+
+    total_row_count = 0
+
+    for task in tasks:
+        batches = _task_to_record_batches(
+            fs,
+            task,
+            bound_row_filter,
+            projected_schema,
+            projected_field_ids,
+            deletes_per_file.get(task.file.file_path),
+            case_sensitive,
+            table_metadata.name_mapping(),
+        )
+        for batch in batches:
+            if limit is not None:
+                if total_row_count + len(batch) >= limit:
+                    yield batch.take(limit - total_row_count)
+                    break
+            yield batch
+            total_row_count += len(batch)
+
+
+def to_requested_schema(requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch) -> pa.RecordBatch:
+    struct_array = visit_with_partner(requested_schema, batch, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
 
     arrays = []
     fields = []
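
The limit handling in project_batches streams batches until the requested row count is reached and truncates the final batch. A hedged sketch of that behaviour in isolation (this example uses RecordBatch.slice to keep the leading rows, rather than the take call shown in the diff; all names are local to the example):

import pyarrow as pa
from typing import Iterator, Optional

def limited(batches: Iterator[pa.RecordBatch], limit: Optional[int]) -> Iterator[pa.RecordBatch]:
    total = 0
    for batch in batches:
        if limit is not None and total + len(batch) >= limit:
            # Truncate the last batch to the rows still needed, then stop
            yield batch.slice(0, limit - total)
            return
        yield batch
        total += len(batch)

batches = [pa.RecordBatch.from_pydict({"x": list(range(i, i + 4))}) for i in (0, 4, 8)]
print(sum(len(b) for b in limited(iter(batches), limit=6)))  # -> 6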
@@ -1240,8 +1324,8 @@ def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: st
 
             if isinstance(partner_struct, pa.StructArray):
                 return partner_struct.field(name)
-            elif isinstance(partner_struct, pa.Table):
-                return partner_struct.column(name).combine_chunks()
+            elif isinstance(partner_struct, pa.RecordBatch):
+                return partner_struct.column(name)
 
         return None
 
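The ArrowAccessor change above swaps pa.Table for pa.RecordBatch and drops the combine_chunks() call: RecordBatch.column() already returns a contiguous pa.Array, while Table.column() returns a ChunkedArray. A quick check:

import pyarrow as pa

tbl = pa.table({"name": ["a", "b"]})
batch = tbl.to_batches()[0]
print(type(tbl.column("name")))    # pyarrow.lib.ChunkedArray
print(type(batch.column("name")))  # pyarrow.lib.StringArray (a pa.Array)
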
@@ -1778,15 +1862,19 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
 
     def write_parquet(task: WriteTask) -> DataFile:
         table_schema = task.schema
-        arrow_table = pa.Table.from_batches(task.record_batches)
+
         # if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly
         # otherwise use the original schema
         if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema:
             file_schema = sanitized_schema
-            arrow_table = to_requested_schema(requested_schema=file_schema, file_schema=table_schema, table=arrow_table)
+            batches = [
+                to_requested_schema(requested_schema=file_schema, file_schema=table_schema, batch=batch)
+                for batch in task.record_batches
+            ]
         else:
             file_schema = table_schema
-
+            batches = task.record_batches
+        arrow_table = pa.Table.from_batches(batches)
         file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
         fo = io.new_output(file_path)
         with fo.create(overwrite=True) as fos:

pyiceberg/table/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -1763,6 +1763,19 @@ def to_arrow(self) -> pa.Table:
             limit=self.limit,
         )
 
+    def to_arrow_batches(self) -> Iterator[pa.RecordBatch]:
+        from pyiceberg.io.pyarrow import project_batches
+
+        return project_batches(
+            self.plan_files(),
+            self.table_metadata,
+            self.io,
+            self.row_filter,
+            self.projection(),
+            case_sensitive=self.case_sensitive,
+            limit=self.limit,
+        )
+
     def to_pandas(self, **kwargs: Any) -> pd.DataFrame:
         return self.to_arrow().to_pandas(**kwargs)
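
A hedged usage sketch of the new API added by this commit: iterate a scan as RecordBatches instead of materializing one large Table with to_arrow(). The catalog name, table identifier, and process() below are placeholders, not part of the commit.

from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")
tbl = catalog.load_table("db.events")

for batch in tbl.scan(limit=1_000_000).to_arrow_batches():
    # Each batch is a pyarrow.RecordBatch, already projected and filtered
    process(batch)  # placeholder for downstream processing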
