Skip to content

Commit b7b8ba0

Browse files
committed
Added tests and some refactoring
1 parent 398f6c0 commit b7b8ba0

File tree

3 files changed

+80
-6
lines changed

3 files changed

+80
-6
lines changed

pyiceberg/table/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,7 @@ def overwrite(
501501
self.table_metadata.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
502502
)
503503

504-
self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties,branch=branch)
504+
self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties, branch=branch)
505505

506506
with self.update_snapshot(branch=branch, snapshot_properties=snapshot_properties).fast_append() as update_snapshot:
507507
# skip writing data files if the dataframe is empty

pyiceberg/table/update/snapshot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
from pyiceberg.partitioning import (
5656
PartitionSpec,
5757
)
58-
from pyiceberg.table.refs import SnapshotRefType
58+
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRefType
5959
from pyiceberg.table.snapshots import (
6060
Operation,
6161
Snapshot,
@@ -622,7 +622,7 @@ class UpdateSnapshot:
622622
_snapshot_properties: Dict[str, str]
623623

624624
def __init__(
625-
self, transaction: Transaction, io: FileIO, branch: str, snapshot_properties: Dict[str, str] = EMPTY_DICT
625+
self, transaction: Transaction, io: FileIO, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: str = MAIN_BRANCH
626626
) -> None:
627627
self._transaction = transaction
628628
self._io = io

tests/integration/test_writes/test_writes.py

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from pyiceberg.catalog.hive import HiveCatalog
4040
from pyiceberg.catalog.rest import RestCatalog
4141
from pyiceberg.catalog.sql import SqlCatalog
42-
from pyiceberg.exceptions import NoSuchTableError
42+
from pyiceberg.exceptions import CommitFailedException, NoSuchTableError
4343
from pyiceberg.expressions import And, EqualTo, GreaterThanOrEqual, In, LessThan, Not
4444
from pyiceberg.io.pyarrow import _dataframe_to_data_files
4545
from pyiceberg.partitioning import PartitionField, PartitionSpec
@@ -1015,7 +1015,8 @@ def test_table_write_schema_with_valid_nullability_diff(
10151015
NestedField(field_id=1, name="long", field_type=LongType(), required=False),
10161016
)
10171017
other_schema = pa.schema((
1018-
pa.field("long", pa.int64(), nullable=False), # can support writing required pyarrow field to optional Iceberg field
1018+
pa.field("long", pa.int64(), nullable=False),
1019+
# can support writing required pyarrow field to optional Iceberg field
10191020
))
10201021
arrow_table = pa.Table.from_pydict(
10211022
{
@@ -1062,7 +1063,8 @@ def test_table_write_schema_with_valid_upcast(
10621063
pa.field("list", pa.large_list(pa.int64()), nullable=False),
10631064
pa.field("map", pa.map_(pa.large_string(), pa.int64()), nullable=False),
10641065
pa.field("double", pa.float64(), nullable=True), # can support upcasting float to double
1065-
pa.field("uuid", pa.binary(length=16), nullable=True), # can UUID is read as fixed length binary of length 16
1066+
pa.field("uuid", pa.binary(length=16), nullable=True),
1067+
# a UUID is read as fixed-length binary of length 16
10661068
))
10671069
)
10681070
lhs = spark.table(f"{identifier}").toPandas()
@@ -1448,3 +1450,75 @@ def test_rewrite_manifest_after_partition_evolution(session_catalog: Catalog) ->
14481450
EqualTo("category", "A"),
14491451
),
14501452
)
1453+
1454+
1455+
@pytest.mark.integration
def test_append_to_non_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
    """Appending to a branch that was never created must raise CommitFailedException."""
    identifier = "default.test_non_existing_branch"
    table = _create_table(session_catalog, identifier, {"format-version": "2"}, [])
    # The freshly created table has no snapshot, so no ref can be resolved for the branch.
    with pytest.raises(CommitFailedException, match="No snapshot available in table for ref:"):
        table.append(arrow_table_with_null, branch="non_existing_branch")
1461+
1462+
1463+
@pytest.mark.integration
def test_append_to_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
    """Appending to an existing branch advances only that branch; main is left untouched."""
    identifier = "default.test_existing_branch_append"
    branch = "existing_branch"
    table = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])

    assert table.metadata.current_snapshot_id is not None

    table.manage_snapshots().create_branch(snapshot_id=table.metadata.current_snapshot_id, branch_name=branch).commit()
    table.append(arrow_table_with_null, branch=branch)

    # Branch sees the original rows plus the appended batch; main still has only the original rows.
    assert len(table.scan().use_ref(branch).to_arrow()) == 6
    assert len(table.scan().to_arrow()) == 3
    snapshot_on_branch = table.metadata.snapshot_by_name(branch)
    assert snapshot_on_branch is not None
    snapshot_on_main = table.metadata.snapshot_by_name("main")
    assert snapshot_on_main is not None
    # The branch head descends directly from the snapshot main points at.
    assert snapshot_on_branch.parent_snapshot_id == snapshot_on_main.snapshot_id
1481+
1482+
1483+
@pytest.mark.integration
def test_delete_to_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
    """Deleting rows on a branch rewrites only that branch; main keeps all rows."""
    identifier = "default.test_existing_branch_delete"
    branch = "existing_branch"
    table = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])

    assert table.metadata.current_snapshot_id is not None

    table.manage_snapshots().create_branch(snapshot_id=table.metadata.current_snapshot_id, branch_name=branch).commit()
    table.delete(delete_filter="int = 9", branch=branch)

    # One row matched the filter on the branch; main is unaffected and still shows all three.
    assert len(table.scan().use_ref(branch).to_arrow()) == 2
    assert len(table.scan().to_arrow()) == 3
    snapshot_on_branch = table.metadata.snapshot_by_name(branch)
    assert snapshot_on_branch is not None
    snapshot_on_main = table.metadata.snapshot_by_name("main")
    assert snapshot_on_main is not None
    # The delete snapshot on the branch descends from the snapshot main points at.
    assert snapshot_on_branch.parent_snapshot_id == snapshot_on_main.snapshot_id
1501+
1502+
1503+
@pytest.mark.integration
def test_overwrite_to_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
    """Overwriting on a branch produces a delete+append snapshot pair on that branch only."""
    identifier = "default.test_existing_branch_overwrite"
    branch = "existing_branch"
    table = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])

    assert table.metadata.current_snapshot_id is not None

    table.manage_snapshots().create_branch(snapshot_id=table.metadata.current_snapshot_id, branch_name=branch).commit()
    table.overwrite(arrow_table_with_null, branch=branch)

    # Both refs expose three rows: the branch was fully overwritten, main never changed.
    assert len(table.scan().use_ref(branch).to_arrow()) == 3
    assert len(table.scan().to_arrow()) == 3
    snapshot_on_branch = table.metadata.snapshot_by_name(branch)
    assert snapshot_on_branch is not None and snapshot_on_branch.parent_snapshot_id is not None
    delete_snapshot = table.metadata.snapshot_by_id(snapshot_on_branch.parent_snapshot_id)
    assert delete_snapshot is not None
    snapshot_on_main = table.metadata.snapshot_by_name("main")
    assert snapshot_on_main is not None
    # Overwrite is currently implemented as a delete commit followed by an append commit,
    # so the branch head's parent (the delete snapshot) descends from main's snapshot.
    assert delete_snapshot.parent_snapshot_id == snapshot_on_main.snapshot_id

0 commit comments

Comments
 (0)