Skip to content

Commit c60938f

Browse files
committed
updates
1 parent 1bd61c9 commit c60938f

File tree

2 files changed

+58
-38
lines changed

2 files changed

+58
-38
lines changed

pyiceberg/table/__init__.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,21 +1291,29 @@ def snapshot_by_name(self, name: str) -> Optional[Snapshot]:
12911291
return None
12921292

12931293
def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots:
1294+
"""Rollback the table to the given snapshot id, whose snapshot needs to be an ancestor of the current table state."""
12941295
if self.snapshot_by_id(snapshot_id) is None:
12951296
raise ValidationError(f"Cannot roll back to unknown snapshot id: {snapshot_id}")
12961297
if snapshot_id not in set(ancestor.ancestor_id for ancestor in ancestors_of(self.current_snapshot(), self._transaction._table.metadata)):
12971298
raise ValidationError(f"Cannot roll back to snapshot, not an ancestor of the current state: {snapshot_id}")
12981299
return self._transaction.set_ref_snapshot(snapshot_id=snapshot_id, ref_name="main", type="branch")
12991300

13001301
def rollback_to_timestamp(self, timestamp: int) -> ManageSnapshots:
1301-
if (snapshot := self.snapshot_as_of_timestamp(timestamp)) is None:
1302+
"""Rollback the table to the snapshot right before the given timestamp."""
1303+
if (snapshot := self.snapshot_as_of_timestamp(timestamp, inclusive=False)) is None:
13021304
raise ValidationError(f"Cannot roll back, no valid snapshot older than: {timestamp}")
13031305
return self._transaction.set_ref_snapshot(snapshot_id=snapshot.snapshot_id, ref_name="main", type="branch")
13041306

1305-
def set_current_snapshot(self, snapshot_id: int) -> ManageSnapshots:
1306-
if self.snapshot_by_id(snapshot_id) is None:
1307-
raise ValidationError(f"Cannot roll back to unknown snapshot id: {snapshot_id}")
1308-
return self._transaction.set_ref_snapshot(snapshot_id=snapshot_id, ref_name="main", type="branch")
1307+
def set_current_snapshot(self, snapshot_id: Optional[int] = None, ref_name: Optional[str] = None) -> ManageSnapshots:
1308+
"""Set the table to a specific snapshot identified either by its id or the branch/tag its on, not both."""
1309+
if (snapshot_id or ref_name) and not (snapshot_id or ref_name):
1310+
target_snapshot_id = snapshot_id if snapshot_id is not None else self.metadata.refs[ref_name].snapshot_id
1311+
snapshot = self.snapshot_by_id(target_snapshot_id)
1312+
else:
1313+
raise ValidationError("Either snapshot_id or ref must be provided")
1314+
if snapshot is None:
1315+
raise ValidationError(f"Cannot set snapshot current with snapshot id: {snapshot_id} or ref_name: {ref_name}")
1316+
return self._transaction.set_ref_snapshot(snapshot_id=snapshot.snapshot_id, ref_name="main", type="branch")
13091317

13101318
def history(self) -> List[SnapshotLogEntry]:
13111319
"""Get the snapshot history of this table."""

tests/integration/test_writes/test_writes.py

Lines changed: 45 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from pyiceberg.exceptions import NoSuchTableError
3939
from pyiceberg.partitioning import PartitionField, PartitionSpec
4040
from pyiceberg.schema import Schema
41-
from pyiceberg.table import TableProperties, _dataframe_to_data_files
41+
from pyiceberg.table import TableProperties, _dataframe_to_data_files, SnapshotRef
4242
from pyiceberg.transforms import IdentityTransform
4343
from pyiceberg.types import IntegerType, NestedField
4444
from tests.conftest import TEST_DATA_WITH_NULL
@@ -873,39 +873,51 @@ def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null
873873

874874
@pytest.mark.integration
875875
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
876-
def test_update_metadata_rollback_to_snapshot(table_v2: Table) -> None:
877-
assert table_v2.rollback_to_snapshot(snapshot_id=3051729675574597004) == Snapshot(
878-
snapshot_id=3051729675574597004,
879-
parent_snapshot_id=None,
880-
sequence_number=0,
881-
timestamp_ms=1515100955770,
882-
manifest_list='s3://a/b/1.avro',
883-
summary=Summary(Operation.APPEND),
884-
schema_id=None,
885-
)
876+
def test_rollback_to_snapshot(catalog: Catalog) -> None:
877+
identifier = "default.test_table_snapshot_operations"
878+
tbl = catalog.load_table(identifier)
879+
assert len(tbl.history()) > 3
880+
rollback_snapshot_id = tbl.history()[-3].snapshot_id
881+
current_snapshot_id = tbl.current_snapshot().snapshot_id
882+
ms = tbl.manage_snapshots().rollback_to_snapshot(snapshot_id=rollback_snapshot_id)
883+
ms.commit()
884+
tbl.refresh()
885+
assert tbl.current_snapshot().snapshot_id is not current_snapshot_id
886+
assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=rollback_snapshot_id)
886887

887888

888889
@pytest.mark.integration
889890
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
890-
def test_update_metadata_rollback_to_timestamp(table_v2: Table) -> None:
891-
assert table_v2.rollback_to_timestamp(timestamp=1555100955771) == Snapshot(
892-
snapshot_id=3055729675574597004,
893-
parent_snapshot_id=3051729675574597004,
894-
sequence_number=1,
895-
timestamp_ms=1555100955770,
896-
manifest_list='s3://a/b/2.avro',
897-
summary=Summary(Operation.APPEND),
898-
schema_id=1,
899-
)
900-
901-
902-
def test_update_metadata_set_current_snapshot(table_v2: Table) -> None:
903-
assert table_v2.set_current_snapshot(snapshot_id=3051729675574597004) == Snapshot(
904-
snapshot_id=3051729675574597004,
905-
parent_snapshot_id=None,
906-
sequence_number=0,
907-
timestamp_ms=1515100955770,
908-
manifest_list='s3://a/b/1.avro',
909-
summary=Summary(Operation.APPEND),
910-
schema_id=None,
911-
)
891+
def test_rollback_to_timestamp(catalog: Catalog) -> None:
892+
identifier = "default.test_table_snapshot_operations"
893+
tbl = catalog.load_table(identifier)
894+
assert len(tbl.history()) > 3
895+
timestamp = tbl.history()[-2].timestamp_ms
896+
current_snapshot_id = tbl.current_snapshot().snapshot_id
897+
# not inclusive of rollback_timestamp
898+
ms = tbl.manage_snapshots().rollback_to_timestamp(timestamp=timestamp)
899+
ms.commit()
900+
tbl.refresh()
901+
assert tbl.current_snapshot().snapshot_id is not current_snapshot_id
902+
assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=tbl.history()[-4].snapshot_id)
903+
904+
905+
def test_set_current_snapshot(catalog: Catalog) -> None:
906+
identifier = "default.test_table_snapshot_operations"
907+
tbl = catalog.load_table(identifier)
908+
assert len(tbl.history()) > 3
909+
# test with snapshot_id
910+
target_snapshot_id = tbl.history()[-4].snapshot_id
911+
current_snapshot_id = tbl.current_snapshot().snapshot_id
912+
tbl.manage_snapshots().set_current_snapshot(snapshot_id=target_snapshot_id).commit()
913+
tbl.refresh()
914+
assert tbl.current_snapshot().snapshot_id is not current_snapshot_id
915+
assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=target_snapshot_id)
916+
# test with ref_name
917+
new_current_snapshot_id = tbl.current_snapshot().snapshot_id
918+
expected_snapshot_id = tbl.history()[-3].snapshot_id
919+
tbl.manage_snapshots().create_tag(snapshot_id=expected_snapshot_id, tag_name="test-tag").commit()
920+
tbl.manage_snapshots().set_current_snapshot(ref_name="test-tag").commit()
921+
tbl.refresh()
922+
assert tbl.current_snapshot().snapshot_id is not new_current_snapshot_id
923+
assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=expected_snapshot_id)

0 commit comments

Comments
 (0)