106
106
NameMapping ,
107
107
update_mapping ,
108
108
)
109
- from pyiceberg .table .refs import MAIN_BRANCH , SnapshotRef
109
+ from pyiceberg .table .refs import MAIN_BRANCH , SnapshotRef , SnapshotRefType
110
110
from pyiceberg .table .snapshots import (
111
111
Operation ,
112
112
Snapshot ,
@@ -1980,6 +1980,13 @@ def _commit_if_ref_updates_exist(self) -> None:
1980
1980
self .commit ()
1981
1981
self ._updates , self ._requirements = (), ()
1982
1982
1983
+ def _stage_main_branch_snapshot_ref (self , snapshot_id : int ) -> None :
1984
+ update , requirement = self ._transaction ._set_ref_snapshot (
1985
+ snapshot_id = snapshot_id , ref_name = MAIN_BRANCH , type = SnapshotRefType .BRANCH
1986
+ )
1987
+ self ._updates += update
1988
+ self ._requirements += requirement
1989
+
1983
1990
def create_tag (self , snapshot_id : int , tag_name : str , max_ref_age_ms : Optional [int ] = None ) -> ManageSnapshots :
1984
1991
"""
1985
1992
Create a new tag pointing to the given snapshot id.
@@ -2052,10 +2059,7 @@ def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots:
2052
2059
for ancestor in ancestors_of (self ._transaction ._table .current_snapshot (), self ._transaction .table_metadata )
2053
2060
}:
2054
2061
raise ValidationError (f"Cannot roll back to snapshot, not an ancestor of the current state: { snapshot_id } " )
2055
-
2056
- update , requirement = self ._transaction ._set_ref_snapshot (snapshot_id = snapshot_id , ref_name = "main" , type = "branch" )
2057
- self ._updates += update
2058
- self ._requirements += requirement
2062
+ self ._stage_main_branch_snapshot_ref (snapshot_id = snapshot_id )
2059
2063
return self
2060
2064
2061
2065
def rollback_to_timestamp (self , timestamp : int ) -> ManageSnapshots :
@@ -2075,12 +2079,7 @@ def rollback_to_timestamp(self, timestamp: int) -> ManageSnapshots:
2075
2079
)
2076
2080
) is None :
2077
2081
raise ValidationError (f"Cannot roll back, no valid snapshot older than: { timestamp } " )
2078
-
2079
- update , requirement = self ._transaction ._set_ref_snapshot (
2080
- snapshot_id = snapshot .snapshot_id , ref_name = "main" , type = "branch"
2081
- )
2082
- self ._updates += update
2083
- self ._requirements += requirement
2082
+ self ._stage_main_branch_snapshot_ref (snapshot_id = snapshot .snapshot_id )
2084
2083
return self
2085
2084
2086
2085
def set_current_snapshot (self , snapshot_id : Optional [int ] = None , ref_name : Optional [str ] = None ) -> ManageSnapshots :
@@ -2099,17 +2098,15 @@ def set_current_snapshot(self, snapshot_id: Optional[int] = None, ref_name: Opti
2099
2098
raise ValidationError ("Either snapshot_id or ref must be provided" )
2100
2099
else :
2101
2100
if snapshot_id is None :
2102
- target_snapshot_id = self ._transaction .table_metadata .refs [ref_name ].snapshot_id # type:ignore
2101
+ if ref_name not in self ._transaction .table_metadata .refs :
2102
+ raise ValidationError (f"Cannot set snapshot current to unknown ref { ref_name } " )
2103
+ target_snapshot_id = self ._transaction .table_metadata .refs [ref_name ].snapshot_id
2103
2104
else :
2104
2105
target_snapshot_id = snapshot_id
2105
2106
if (snapshot := self ._transaction ._table .snapshot_by_id (target_snapshot_id )) is None :
2106
2107
raise ValidationError (f"Cannot set snapshot current with snapshot id: { snapshot_id } or ref_name: { ref_name } " )
2107
2108
2108
- update , requirement = self ._transaction ._set_ref_snapshot (
2109
- snapshot_id = snapshot .snapshot_id , ref_name = "main" , type = "branch"
2110
- )
2111
- self ._updates += update
2112
- self ._requirements += requirement
2109
+ self ._stage_main_branch_snapshot_ref (snapshot_id = snapshot .snapshot_id )
2113
2110
return self
2114
2111
2115
2112
0 commit comments