Skip to content

Commit 1bd61c9

Browse files
committed
wip
1 parent 5f71c4f commit 1bd61c9

File tree

3 files changed

+48
-44
lines changed

3 files changed

+48
-44
lines changed

pyiceberg/table/__init__.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,22 +1290,22 @@ def snapshot_by_name(self, name: str) -> Optional[Snapshot]:
12901290
return self.snapshot_by_id(ref.snapshot_id)
12911291
return None
12921292

1293-
def rollback_to_snapshot(self, snapshot_id: int) -> Transaction:
1293+
def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots:
12941294
if self.snapshot_by_id(snapshot_id) is None:
12951295
raise ValidationError(f"Cannot roll back to unknown snapshot id: {snapshot_id}")
1296-
if snapshot_id not in [ancestor.ancestor_id for ancestor in self.current_ancestors()]:
1296+
if snapshot_id not in set(ancestor.ancestor_id for ancestor in ancestors_of(self.current_snapshot(), self._transaction._table.metadata)):
12971297
raise ValidationError(f"Cannot roll back to snapshot, not an ancestor of the current state: {snapshot_id}")
1298-
return self.set_ref_snapshot(snapshot_id=snapshot_id, parent_snapshot_id=self.current_snapshot().snapshot_id, ref_name="main", type="branch")
1298+
return self._transaction.set_ref_snapshot(snapshot_id=snapshot_id, ref_name="main", type="branch")
12991299

1300-
def rollback_to_timestamp(self, timestamp: int) -> Transaction:
1301-
if (snapshot := self.latest_snapshot_before_timestamp(timestamp)) is None:
1300+
def rollback_to_timestamp(self, timestamp: int) -> ManageSnapshots:
1301+
if (snapshot := self.snapshot_as_of_timestamp(timestamp)) is None:
13021302
raise ValidationError(f"Cannot roll back, no valid snapshot older than: {timestamp}")
1303-
return self.set_ref_snapshot(snapshot_id=snapshot.snapshot_id, parent_snapshot_id=self.current_snapshot().snapshot_id, ref_name="main", type="branch")
1303+
return self._transaction.set_ref_snapshot(snapshot_id=snapshot.snapshot_id, ref_name="main", type="branch")
13041304

1305-
def set_current_snapshot(self, snapshot_id: int) -> Transaction:
1305+
def set_current_snapshot(self, snapshot_id: int) -> ManageSnapshots:
13061306
if self.snapshot_by_id(snapshot_id) is None:
13071307
raise ValidationError(f"Cannot roll back to unknown snapshot id: {snapshot_id}")
1308-
return self.set_ref_snapshot(snapshot_id=snapshot_id, parent_snapshot_id=self.current_snapshot().snapshot_id, ref_name="main", type="branch")
1308+
return self._transaction.set_ref_snapshot(snapshot_id=snapshot_id, ref_name="main", type="branch")
13091309

13101310
def history(self) -> List[SnapshotLogEntry]:
13111311
"""Get the snapshot history of this table."""

tests/integration/test_writes/test_writes.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -869,3 +869,43 @@ def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null
869869
tbl.append(arrow_table_without_some_columns)
870870
# overwrite and then append should produce twice the data
871871
assert len(tbl.scan().to_arrow()) == len(arrow_table_without_some_columns) * 2
872+
873+
874+
@pytest.mark.integration
875+
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
876+
def test_update_metadata_rollback_to_snapshot(table_v2: Table) -> None:
877+
assert table_v2.rollback_to_snapshot(snapshot_id=3051729675574597004) == Snapshot(
878+
snapshot_id=3051729675574597004,
879+
parent_snapshot_id=None,
880+
sequence_number=0,
881+
timestamp_ms=1515100955770,
882+
manifest_list='s3://a/b/1.avro',
883+
summary=Summary(Operation.APPEND),
884+
schema_id=None,
885+
)
886+
887+
888+
@pytest.mark.integration
889+
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
890+
def test_update_metadata_rollback_to_timestamp(table_v2: Table) -> None:
891+
assert table_v2.rollback_to_timestamp(timestamp=1555100955771) == Snapshot(
892+
snapshot_id=3055729675574597004,
893+
parent_snapshot_id=3051729675574597004,
894+
sequence_number=1,
895+
timestamp_ms=1555100955770,
896+
manifest_list='s3://a/b/2.avro',
897+
summary=Summary(Operation.APPEND),
898+
schema_id=1,
899+
)
900+
901+
902+
def test_update_metadata_set_current_snapshot(table_v2: Table) -> None:
903+
assert table_v2.set_current_snapshot(snapshot_id=3051729675574597004) == Snapshot(
904+
snapshot_id=3051729675574597004,
905+
parent_snapshot_id=None,
906+
sequence_number=0,
907+
timestamp_ms=1515100955770,
908+
manifest_list='s3://a/b/1.avro',
909+
summary=Summary(Operation.APPEND),
910+
schema_id=None,
911+
)

tests/table/test_init.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -652,42 +652,6 @@ def test_update_metadata_add_snapshot(table_v2: Table) -> None:
652652
assert new_metadata.last_updated_ms == new_snapshot.timestamp_ms
653653

654654

655-
def test_update_metadata_rollback_to_snapshot(table_v2: Table) -> None:
656-
assert table_v2.rollback_to_snapshot(snapshot_id=3051729675574597004) == Snapshot(
657-
snapshot_id=3051729675574597004,
658-
parent_snapshot_id=None,
659-
sequence_number=0,
660-
timestamp_ms=1515100955770,
661-
manifest_list='s3://a/b/1.avro',
662-
summary=Summary(Operation.APPEND),
663-
schema_id=None,
664-
)
665-
666-
667-
def test_update_metadata_rollback_to_timestamp(table_v2: Table) -> None:
668-
assert table_v2.rollback_to_timestamp(timestamp=1555100955771) == Snapshot(
669-
snapshot_id=3055729675574597004,
670-
parent_snapshot_id=3051729675574597004,
671-
sequence_number=1,
672-
timestamp_ms=1555100955770,
673-
manifest_list='s3://a/b/2.avro',
674-
summary=Summary(Operation.APPEND),
675-
schema_id=1,
676-
)
677-
678-
679-
def test_update_metadata_set_current_snapshot(table_v2: Table) -> None:
680-
assert table_v2.set_current_snapshot(snapshot_id=3051729675574597004) == Snapshot(
681-
snapshot_id=3051729675574597004,
682-
parent_snapshot_id=None,
683-
sequence_number=0,
684-
timestamp_ms=1515100955770,
685-
manifest_list='s3://a/b/1.avro',
686-
summary=Summary(Operation.APPEND),
687-
schema_id=None,
688-
)
689-
690-
691655
def test_update_metadata_set_snapshot_ref(table_v2: Table) -> None:
692656
update = SetSnapshotRefUpdate(
693657
ref_name="main",

0 commit comments

Comments
 (0)