Skip to content

Commit d4262a0

Browse files
committed
add public and private APIs, register RemoveSnapshotRefUpdate with apply metadata fn
1 parent 1dde51a commit d4262a0

File tree

2 files changed

+96
-0
lines changed

2 files changed

+96
-0
lines changed

pyiceberg/table/__init__.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,24 @@ def _set_ref_snapshot(
434434

435435
return updates, requirements
436436

437+
def _remove_ref_snapshot(self, ref_name: str) -> UpdatesAndRequirements:
    """Build the update and requirement needed to drop a snapshot ref.

    Args:
        ref_name: Name of the branch or tag to remove.

    Returns:
        A pair of one-element tuples: the remove-snapshot-ref update, and
        the requirement asserting the snapshot id the ref is expected to
        point at.
    """
    removal = RemoveSnapshotRefUpdate(ref_name=ref_name)

    # The requirement pins the ref to the snapshot id it currently points at,
    # or to None when the ref is absent from the table metadata.
    expected_snapshot_id: Optional[int]
    if ref_name in self.table_metadata.refs:
        expected_snapshot_id = self.table_metadata.refs[ref_name].snapshot_id
    else:
        expected_snapshot_id = None

    guard = AssertRefSnapshotId(snapshot_id=expected_snapshot_id, ref=ref_name)
    return (removal,), (guard,)
454+
437455
def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema:
438456
"""Create a new UpdateSchema to alter the columns of this table.
439457
@@ -1022,6 +1040,23 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _Tabl
10221040
return base_metadata.model_copy(update=metadata_updates)
10231041

10241042

1043+
@_apply_table_update.register(RemoveSnapshotRefUpdate)
def _(update: RemoveSnapshotRefUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
    """Apply a remove-snapshot-ref update to the table metadata.

    Args:
        update: The update naming the ref to remove.
        base_metadata: The metadata the update is applied to.
        context: The update context that records applied updates.

    Returns:
        ``base_metadata`` itself when the ref is not present; otherwise a
        copy of it whose ``refs`` no longer contain the named ref.

    Raises:
        ValueError: If the ref points at a snapshot that ``base_metadata``
            does not know, or if the ref is the main branch.
    """
    existing_ref = base_metadata.refs.get(update.ref_name)
    if existing_ref is None:
        # Nothing to remove; the update is not recorded in the context.
        return base_metadata

    referenced_snapshot = base_metadata.snapshot_by_id(existing_ref.snapshot_id)
    if referenced_snapshot is None:
        raise ValueError(f"Cannot remove {update.ref_name} ref with unknown snapshot {existing_ref.snapshot_id}")

    if update.ref_name == MAIN_BRANCH:
        raise ValueError("Cannot remove main branch")

    # Rebuild the ref mapping without the removed entry, leaving the mapping
    # held by base_metadata untouched.
    remaining_refs = {name: ref for name, ref in base_metadata.refs.items() if name != update.ref_name}
    context.add_update(update)
    return base_metadata.model_copy(update={"refs": remaining_refs})
1058+
1059+
10251060
@_apply_table_update.register(AddSortOrderUpdate)
10261061
def _(update: AddSortOrderUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
10271062
context.add_update(update)
@@ -1978,6 +2013,21 @@ def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[i
19782013
self._requirements += requirement
19792014
return self
19802015

2016+
def remove_tag(self, tag_name: str) -> ManageSnapshots:
    """
    Queue the removal of a tag.

    Args:
        tag_name (str): name of the tag to remove

    Returns:
        This for method chaining
    """
    # The transaction builds the update/requirement pair for the ref; both
    # are appended to the pending changes held by this object.
    ref_updates, ref_requirements = self._transaction._remove_ref_snapshot(ref_name=tag_name)
    self._updates += ref_updates
    self._requirements += ref_requirements
    return self
2030+
19812031
def create_branch(
19822032
self,
19832033
snapshot_id: int,
@@ -2010,6 +2060,20 @@ def create_branch(
20102060
self._requirements += requirement
20112061
return self
20122062

2063+
def remove_branch(self, branch_name: str) -> ManageSnapshots:
    """
    Queue the removal of a branch.

    Args:
        branch_name (str): name of the branch to remove

    Returns:
        This for method chaining
    """
    # The transaction builds the update/requirement pair for the ref; both
    # are appended to the pending changes held by this object.
    ref_updates, ref_requirements = self._transaction._remove_ref_snapshot(ref_name=branch_name)
    self._updates += ref_updates
    self._requirements += ref_requirements
    return self
2076+
20132077

20142078
class UpdateSchema(UpdateTableMetadata["UpdateSchema"]):
20152079
_schema: Schema

tests/integration/test_snapshot_operations.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,35 @@ def test_create_branch(catalog: Catalog) -> None:
4040
branch_snapshot_id = tbl.history()[-2].snapshot_id
4141
tbl.manage_snapshots().create_branch(snapshot_id=branch_snapshot_id, branch_name="branch123").commit()
4242
assert tbl.metadata.refs["branch123"] == SnapshotRef(snapshot_id=branch_snapshot_id, snapshot_ref_type="branch")
43+
44+
45+
@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_remove_tag(catalog: Catalog) -> None:
    """Create a tag on an older snapshot, then check that remove_tag drops it."""
    table = catalog.load_table("default.test_table_snapshot_operations")
    assert len(table.history()) > 3

    # Arrange: create the tag that will be removed.
    name = "tag_to_remove"
    snapshot_id = table.history()[-3].snapshot_id
    table.manage_snapshots().create_tag(snapshot_id=snapshot_id, tag_name=name).commit()
    expected_ref = SnapshotRef(snapshot_id=snapshot_id, snapshot_ref_type="tag")
    assert table.metadata.refs[name] == expected_ref

    # Act / assert: after removal the ref is gone.
    table.manage_snapshots().remove_tag(tag_name=name).commit()
    assert table.metadata.refs.get(name) is None
59+
60+
61+
@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_remove_branch(catalog: Catalog) -> None:
    """Create a branch on an older snapshot, then check that remove_branch drops it."""
    table = catalog.load_table("default.test_table_snapshot_operations")
    assert len(table.history()) > 2

    # Arrange: create the branch that will be removed.
    name = "branch_to_remove"
    snapshot_id = table.history()[-2].snapshot_id
    table.manage_snapshots().create_branch(snapshot_id=snapshot_id, branch_name=name).commit()
    expected_ref = SnapshotRef(snapshot_id=snapshot_id, snapshot_ref_type="branch")
    assert table.metadata.refs[name] == expected_ref

    # Act / assert: after removal the ref is gone.
    table.manage_snapshots().remove_branch(branch_name=name).commit()
    assert table.metadata.refs.get(name) is None

0 commit comments

Comments
 (0)