
Commit 1ed3abd

sungwy, HonahX, and Fokko authored
Allow writing a pa.Table whose columns are a subset of the table schema or in arbitrary order, and support type promotion on write (#921)
* merge
* thanks @HonahX :) Co-authored-by: Honah J. <[email protected]>
* support promote
* revert promote
* use a visitor
* support promotion on write
* fix
* Thank you @Fokko ! Co-authored-by: Fokko Driesprong <[email protected]>
* revert
* add-files promotion test
* support promote for add_files
* add tests for uuid
* add_files subset schema test

---------

Co-authored-by: Honah J. <[email protected]>
Co-authored-by: Fokko Driesprong <[email protected]>
1 parent 0f2e19e commit 1ed3abd

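As a minimal sketch of what this change enables (the catalog name, table identifier, and columns below are illustrative assumptions, not part of this commit), an appended pa.Table no longer has to match the Iceberg schema column-for-column. Assume a table whose Iceberg schema is (id: required long, name: optional string, score: optional double):

```python
import pyarrow as pa
from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")        # assumed catalog configuration
table = catalog.load_table("db.events")  # assumed existing table

# Columns are in a different order than the table schema, the optional
# "name" column is omitted entirely, and narrower Arrow types are used:
# int32 and float32 are promotable to the table's long and double columns.
df = pa.table({
    "score": pa.array([0.5, 1.25], type=pa.float32()),
    "id": pa.array([1, 2], type=pa.int32()),
})

# With this change the schema check should pass and the values are written
# against the table schema, instead of requiring an exact schema match.
table.append(df)
```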
File tree: 7 files changed (+545 / -79 lines changed)


pyiceberg/io/pyarrow.py

Lines changed: 33 additions & 48 deletions
@@ -120,6 +120,7 @@
     Schema,
     SchemaVisitorPerPrimitiveType,
     SchemaWithPartnerVisitor,
+    _check_schema_compatible,
     pre_order_visit,
     promote,
     prune_columns,
@@ -1407,7 +1408,7 @@ def list(self, list_type: ListType, list_array: Optional[pa.Array], value_array:
             # This can be removed once this has been fixed:
             # https://github.com/apache/arrow/issues/38809
             list_array = pa.LargeListArray.from_arrays(list_array.offsets, value_array)
-
+            value_array = self._cast_if_needed(list_type.element_field, value_array)
             arrow_field = pa.large_list(self._construct_field(list_type.element_field, value_array.type))
             return list_array.cast(arrow_field)
         else:
@@ -1417,6 +1418,8 @@ def map(
         self, map_type: MapType, map_array: Optional[pa.Array], key_result: Optional[pa.Array], value_result: Optional[pa.Array]
     ) -> Optional[pa.Array]:
         if isinstance(map_array, pa.MapArray) and key_result is not None and value_result is not None:
+            key_result = self._cast_if_needed(map_type.key_field, key_result)
+            value_result = self._cast_if_needed(map_type.value_field, value_result)
             arrow_field = pa.map_(
                 self._construct_field(map_type.key_field, key_result.type),
                 self._construct_field(map_type.value_field, value_result.type),
@@ -1549,9 +1552,16 @@ def __init__(self, iceberg_type: PrimitiveType, physical_type_string: str, trunc
 
         expected_physical_type = _primitive_to_physical(iceberg_type)
         if expected_physical_type != physical_type_string:
-            raise ValueError(
-                f"Unexpected physical type {physical_type_string} for {iceberg_type}, expected {expected_physical_type}"
-            )
+            # Allow promotable physical types
+            # INT32 -> INT64 and FLOAT -> DOUBLE are safe type casts
+            if (physical_type_string == "INT32" and expected_physical_type == "INT64") or (
+                physical_type_string == "FLOAT" and expected_physical_type == "DOUBLE"
+            ):
+                pass
+            else:
+                raise ValueError(
+                    f"Unexpected physical type {physical_type_string} for {iceberg_type}, expected {expected_physical_type}"
+                )
 
         self.primitive_type = iceberg_type
 
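The two physical-type pairs accepted above line up with Iceberg's safe type promotions. A small illustration (not part of the diff) using the existing `promote` helper from `pyiceberg.schema`:

```python
# INT32 -> INT64 and FLOAT -> DOUBLE at the Parquet physical level mirror
# Iceberg's int -> long and float -> double promotions, which promote()
# already resolves to the wider (read) type.
from pyiceberg.schema import promote
from pyiceberg.types import DoubleType, FloatType, IntegerType, LongType

assert promote(IntegerType(), LongType()) == LongType()
assert promote(FloatType(), DoubleType()) == DoubleType()
```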
@@ -1896,16 +1906,6 @@ def data_file_statistics_from_parquet_metadata(
             set the mode for column metrics collection
         parquet_column_mapping (Dict[str, int]): The mapping of the parquet file name to the field ID
     """
-    if parquet_metadata.num_columns != len(stats_columns):
-        raise ValueError(
-            f"Number of columns in statistics configuration ({len(stats_columns)}) is different from the number of columns in pyarrow table ({parquet_metadata.num_columns})"
-        )
-
-    if parquet_metadata.num_columns != len(parquet_column_mapping):
-        raise ValueError(
-            f"Number of columns in column mapping ({len(parquet_column_mapping)}) is different from the number of columns in pyarrow table ({parquet_metadata.num_columns})"
-        )
-
     column_sizes: Dict[int, int] = {}
     value_counts: Dict[int, int] = {}
     split_offsets: List[int] = []
@@ -1998,8 +1998,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
     )
 
     def write_parquet(task: WriteTask) -> DataFile:
-        table_schema = task.schema
-
+        table_schema = table_metadata.schema()
         # if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly
         # otherwise use the original schema
         if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema:
@@ -2011,7 +2010,7 @@ def write_parquet(task: WriteTask) -> DataFile:
         batches = [
             _to_requested_schema(
                 requested_schema=file_schema,
-                file_schema=table_schema,
+                file_schema=task.schema,
                 batch=batch,
                 downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
                 include_field_ids=True,
@@ -2070,47 +2069,30 @@ def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[List[
     return bin_packed_record_batches
 
 
-def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False) -> None:
+def _check_pyarrow_schema_compatible(
+    requested_schema: Schema, provided_schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False
+) -> None:
     """
-    Check if the `table_schema` is compatible with `other_schema`.
+    Check if the `requested_schema` is compatible with `provided_schema`.
 
     Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.
 
     Raises:
         ValueError: If the schemas are not compatible.
     """
-    name_mapping = table_schema.name_mapping
+    name_mapping = requested_schema.name_mapping
     try:
-        task_schema = pyarrow_to_schema(
-            other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
+        provided_schema = pyarrow_to_schema(
+            provided_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )
     except ValueError as e:
-        other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
-        additional_names = set(other_schema.column_names) - set(table_schema.column_names)
+        provided_schema = _pyarrow_to_schema_without_ids(provided_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
+        additional_names = set(provided_schema._name_to_id.keys()) - set(requested_schema._name_to_id.keys())
         raise ValueError(
             f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
         ) from e
 
-    if table_schema.as_struct() != task_schema.as_struct():
-        from rich.console import Console
-        from rich.table import Table as RichTable
-
-        console = Console(record=True)
-
-        rich_table = RichTable(show_header=True, header_style="bold")
-        rich_table.add_column("")
-        rich_table.add_column("Table field")
-        rich_table.add_column("Dataframe field")
-
-        for lhs in table_schema.fields:
-            try:
-                rhs = task_schema.find_field(lhs.field_id)
-                rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs))
-            except ValueError:
-                rich_table.add_row("❌", str(lhs), "Missing")
-
-        console.print(rich_table)
-        raise ValueError(f"Mismatch in fields:\n{console.export_text()}")
+    _check_schema_compatible(requested_schema, provided_schema)
 
 
 def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_paths: Iterator[str]) -> Iterator[DataFile]:
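A hedged sketch of how the renamed helper is meant to be called (the schema and field names below are invented for illustration): the requested schema is the Iceberg table schema, the provided schema is the Arrow schema of the incoming data.

```python
import pyarrow as pa
from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible
from pyiceberg.schema import Schema
from pyiceberg.types import LongType, NestedField, StringType

requested = Schema(
    NestedField(field_id=1, name="id", field_type=LongType(), required=True),
    NestedField(field_id=2, name="name", field_type=StringType(), required=False),
)

# int32 is promotable to long, and the optional "name" column may be absent,
# so this should pass without raising.
_check_pyarrow_schema_compatible(requested, pa.schema([pa.field("id", pa.int32(), nullable=False)]))

# Dropping the required "id" column should raise ValueError, with the
# field-by-field comparison rendered by the schema-level check.
_check_pyarrow_schema_compatible(requested, pa.schema([pa.field("name", pa.string())]))
```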
@@ -2124,7 +2106,7 @@ def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_
                 f"Cannot add file {file_path} because it has field IDs. `add_files` only supports addition of files without field_ids"
             )
         schema = table_metadata.schema()
-        _check_schema_compatible(schema, parquet_metadata.schema.to_arrow_schema())
+        _check_pyarrow_schema_compatible(schema, parquet_metadata.schema.to_arrow_schema())
 
         statistics = data_file_statistics_from_parquet_metadata(
             parquet_metadata=parquet_metadata,
@@ -2205,7 +2187,7 @@ def _dataframe_to_data_files(
     Returns:
         An iterable that supplies datafiles that represent the table.
     """
-    from pyiceberg.table import PropertyUtil, TableProperties, WriteTask
+    from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, PropertyUtil, TableProperties, WriteTask
 
     counter = counter or itertools.count(0)
     write_uuid = write_uuid or uuid.uuid4()
@@ -2214,13 +2196,16 @@ def _dataframe_to_data_files(
         property_name=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES,
         default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT,
     )
+    name_mapping = table_metadata.schema().name_mapping
+    downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
+    task_schema = pyarrow_to_schema(df.schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
 
     if table_metadata.spec().is_unpartitioned():
         yield from write_file(
             io=io,
             table_metadata=table_metadata,
             tasks=iter([
-                WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=table_metadata.schema())
+                WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=task_schema)
                 for batches in bin_pack_arrow_table(df, target_file_size)
             ]),
         )
@@ -2235,7 +2220,7 @@ def _dataframe_to_data_files(
                     task_id=next(counter),
                     record_batches=batches,
                     partition_key=partition.partition_key,
-                    schema=table_metadata.schema(),
+                    schema=task_schema,
                 )
                 for partition in partitions
                 for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)

pyiceberg/schema.py

Lines changed: 100 additions & 0 deletions
@@ -1616,3 +1616,103 @@ def _(file_type: FixedType, read_type: IcebergType) -> IcebergType:
         return read_type
     else:
         raise ResolveError(f"Cannot promote {file_type} to {read_type}")
+
+
+def _check_schema_compatible(requested_schema: Schema, provided_schema: Schema) -> None:
+    """
+    Check if the `provided_schema` is compatible with `requested_schema`.
+
+    Both Schemas must have valid IDs and share the same ID for the same field names.
+
+    Two schemas are considered compatible when:
+    1. All `required` fields in `requested_schema` are present and are also `required` in the `provided_schema`
+    2. Field Types are consistent for fields that are present in both schemas. I.e. the field type
+       in the `provided_schema` can be promoted to the field type of the same field ID in `requested_schema`
+
+    Raises:
+        ValueError: If the schemas are not compatible.
+    """
+    pre_order_visit(requested_schema, _SchemaCompatibilityVisitor(provided_schema))
+
+
+class _SchemaCompatibilityVisitor(PreOrderSchemaVisitor[bool]):
+    provided_schema: Schema
+
+    def __init__(self, provided_schema: Schema):
+        from rich.console import Console
+        from rich.table import Table as RichTable
+
+        self.provided_schema = provided_schema
+        self.rich_table = RichTable(show_header=True, header_style="bold")
+        self.rich_table.add_column("")
+        self.rich_table.add_column("Table field")
+        self.rich_table.add_column("Dataframe field")
+        self.console = Console(record=True)
+
+    def _is_field_compatible(self, lhs: NestedField) -> bool:
+        # Validate nullability first.
+        # An optional field can be missing in the provided schema,
+        # but a required field must exist as a required field.
+        try:
+            rhs = self.provided_schema.find_field(lhs.field_id)
+        except ValueError:
+            if lhs.required:
+                self.rich_table.add_row("❌", str(lhs), "Missing")
+                return False
+            else:
+                self.rich_table.add_row("✅", str(lhs), "Missing")
+                return True
+
+        if lhs.required and not rhs.required:
+            self.rich_table.add_row("❌", str(lhs), str(rhs))
+            return False
+
+        # Check type compatibility
+        if lhs.field_type == rhs.field_type:
+            self.rich_table.add_row("✅", str(lhs), str(rhs))
+            return True
+        # We only check that the parent node is also of the same type.
+        # We check the type of the child nodes when we traverse them later.
+        elif any(
+            (isinstance(lhs.field_type, container_type) and isinstance(rhs.field_type, container_type))
+            for container_type in {StructType, MapType, ListType}
+        ):
+            self.rich_table.add_row("✅", str(lhs), str(rhs))
+            return True
+        else:
+            try:
+                # If the type can be promoted to the requested schema,
+                # it is considered compatible
+                promote(rhs.field_type, lhs.field_type)
+                self.rich_table.add_row("✅", str(lhs), str(rhs))
+                return True
+            except ResolveError:
+                self.rich_table.add_row("❌", str(lhs), str(rhs))
+                return False
+
+    def schema(self, schema: Schema, struct_result: Callable[[], bool]) -> bool:
+        if not (result := struct_result()):
+            self.console.print(self.rich_table)
+            raise ValueError(f"Mismatch in fields:\n{self.console.export_text()}")
+        return result
+
+    def struct(self, struct: StructType, field_results: List[Callable[[], bool]]) -> bool:
+        results = [result() for result in field_results]
+        return all(results)
+
+    def field(self, field: NestedField, field_result: Callable[[], bool]) -> bool:
+        return self._is_field_compatible(field) and field_result()
+
+    def list(self, list_type: ListType, element_result: Callable[[], bool]) -> bool:
+        return self._is_field_compatible(list_type.element_field) and element_result()
+
+    def map(self, map_type: MapType, key_result: Callable[[], bool], value_result: Callable[[], bool]) -> bool:
+        return all([
+            self._is_field_compatible(map_type.key_field),
+            self._is_field_compatible(map_type.value_field),
+            key_result(),
+            value_result(),
+        ])
+
+    def primitive(self, primitive: PrimitiveType) -> bool:
+        return True
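A small usage sketch for the new schema-level check (the schemas here are invented for illustration): fields are matched by field ID, and a provided type only has to be promotable to the requested type.

```python
from pyiceberg.schema import Schema, _check_schema_compatible
from pyiceberg.types import IntegerType, LongType, NestedField

requested = Schema(NestedField(field_id=1, name="n", field_type=LongType(), required=True))
provided = Schema(NestedField(field_id=1, name="n", field_type=IntegerType(), required=True))

# int can be promoted to long, so this passes silently.
_check_schema_compatible(requested, provided)

# long cannot be demoted to int, so swapping the arguments should raise
# ValueError and print the rich field-by-field comparison table.
_check_schema_compatible(provided, requested)
```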

pyiceberg/table/__init__.py

Lines changed: 10 additions & 5 deletions
@@ -73,7 +73,6 @@
     manifest_evaluator,
 )
 from pyiceberg.io import FileIO, OutputFile, load_file_io
-from pyiceberg.io.pyarrow import _check_schema_compatible, _dataframe_to_data_files, expression_to_pyarrow, project_table
 from pyiceberg.manifest import (
     POSITIONAL_DELETE_SCHEMA,
     DataFile,
@@ -471,6 +470,8 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
         except ModuleNotFoundError as e:
             raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e
 
+        from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files
+
         if not isinstance(df, pa.Table):
             raise ValueError(f"Expected PyArrow table, got: {df}")
 
@@ -481,8 +482,8 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
                 f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
             )
         downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
-        _check_schema_compatible(
-            self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
+        _check_pyarrow_schema_compatible(
+            self._table.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )
 
         manifest_merge_enabled = PropertyUtil.property_as_bool(
@@ -528,6 +529,8 @@ def overwrite(
         except ModuleNotFoundError as e:
             raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e
 
+        from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files
+
         if not isinstance(df, pa.Table):
             raise ValueError(f"Expected PyArrow table, got: {df}")
 
@@ -538,8 +541,8 @@ def overwrite(
                 f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
             )
         downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
-        _check_schema_compatible(
-            self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
+        _check_pyarrow_schema_compatible(
+            self._table.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )
 
         self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties)
@@ -566,6 +569,8 @@ def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properti
             delete_filter: A boolean expression to delete rows from a table
             snapshot_properties: Custom properties to be added to the snapshot summary
         """
+        from pyiceberg.io.pyarrow import _dataframe_to_data_files, expression_to_pyarrow, project_table
+
         if (
             self.table_metadata.properties.get(TableProperties.DELETE_MODE, TableProperties.DELETE_MODE_DEFAULT)
             == TableProperties.DELETE_MODE_MERGE_ON_READ
