Skip to content

Commit 59f1626

Browse files
committed
fixes based on review
1 parent 1f4a404 commit 59f1626

File tree

1 file changed

+14
-17
lines changed

1 file changed

+14
-17
lines changed

pyiceberg/table/__init__.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@
106106
NameMapping,
107107
update_mapping,
108108
)
109-
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
109+
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType
110110
from pyiceberg.table.snapshots import (
111111
Operation,
112112
Snapshot,
@@ -1980,6 +1980,13 @@ def _commit_if_ref_updates_exist(self) -> None:
19801980
self.commit()
19811981
self._updates, self._requirements = (), ()
19821982

1983+
def _stage_main_branch_snapshot_ref(self, snapshot_id: int) -> None:
1984+
update, requirement = self._transaction._set_ref_snapshot(
1985+
snapshot_id=snapshot_id, ref_name=MAIN_BRANCH, type=SnapshotRefType.BRANCH
1986+
)
1987+
self._updates += update
1988+
self._requirements += requirement
1989+
19831990
def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[int] = None) -> ManageSnapshots:
19841991
"""
19851992
Create a new tag pointing to the given snapshot id.
@@ -2052,10 +2059,7 @@ def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots:
20522059
for ancestor in ancestors_of(self._transaction._table.current_snapshot(), self._transaction.table_metadata)
20532060
}:
20542061
raise ValidationError(f"Cannot roll back to snapshot, not an ancestor of the current state: {snapshot_id}")
2055-
2056-
update, requirement = self._transaction._set_ref_snapshot(snapshot_id=snapshot_id, ref_name="main", type="branch")
2057-
self._updates += update
2058-
self._requirements += requirement
2062+
self._stage_main_branch_snapshot_ref(snapshot_id=snapshot_id)
20592063
return self
20602064

20612065
def rollback_to_timestamp(self, timestamp: int) -> ManageSnapshots:
@@ -2075,12 +2079,7 @@ def rollback_to_timestamp(self, timestamp: int) -> ManageSnapshots:
20752079
)
20762080
) is None:
20772081
raise ValidationError(f"Cannot roll back, no valid snapshot older than: {timestamp}")
2078-
2079-
update, requirement = self._transaction._set_ref_snapshot(
2080-
snapshot_id=snapshot.snapshot_id, ref_name="main", type="branch"
2081-
)
2082-
self._updates += update
2083-
self._requirements += requirement
2082+
self._stage_main_branch_snapshot_ref(snapshot_id=snapshot.snapshot_id)
20842083
return self
20852084

20862085
def set_current_snapshot(self, snapshot_id: Optional[int] = None, ref_name: Optional[str] = None) -> ManageSnapshots:
@@ -2099,17 +2098,15 @@ def set_current_snapshot(self, snapshot_id: Optional[int] = None, ref_name: Opti
20992098
raise ValidationError("Either snapshot_id or ref must be provided")
21002099
else:
21012100
if snapshot_id is None:
2102-
target_snapshot_id = self._transaction.table_metadata.refs[ref_name].snapshot_id # type:ignore
2101+
if ref_name not in self._transaction.table_metadata.refs:
2102+
raise ValidationError(f"Cannot set snapshot current to unknown ref {ref_name}")
2103+
target_snapshot_id = self._transaction.table_metadata.refs[ref_name].snapshot_id
21032104
else:
21042105
target_snapshot_id = snapshot_id
21052106
if (snapshot := self._transaction._table.snapshot_by_id(target_snapshot_id)) is None:
21062107
raise ValidationError(f"Cannot set snapshot current with snapshot id: {snapshot_id} or ref_name: {ref_name}")
21072108

2108-
update, requirement = self._transaction._set_ref_snapshot(
2109-
snapshot_id=snapshot.snapshot_id, ref_name="main", type="branch"
2110-
)
2111-
self._updates += update
2112-
self._requirements += requirement
2109+
self._stage_main_branch_snapshot_ref(snapshot_id=snapshot.snapshot_id)
21132110
return self
21142111

21152112

0 commit comments

Comments
 (0)