Skip to content

Create rollback and set snapshot APIs #758

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions dev/provision.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,51 @@
VALUES (4)
"""
)

spark.sql(
f"""
CREATE OR REPLACE TABLE {catalog_name}.default.test_table_rollback_to_snapshot_id (
timestamp int,
number integer
)
USING iceberg
TBLPROPERTIES (
'format-version'='2'
);
"""
)

spark.sql(
f"""
INSERT INTO {catalog_name}.default.test_table_rollback_to_snapshot_id
VALUES (200, 1)
"""
)

spark.sql(
f"""
INSERT INTO {catalog_name}.default.test_table_rollback_to_snapshot_id
VALUES (202, 2)
"""
)

spark.sql(
f"""
DELETE FROM {catalog_name}.default.test_table_rollback_to_snapshot_id
WHERE number = 2
"""
)

spark.sql(
f"""
INSERT INTO {catalog_name}.default.test_table_rollback_to_snapshot_id
VALUES (204, 3)
"""
)

spark.sql(
f"""
INSERT INTO {catalog_name}.default.test_table_rollback_to_snapshot_id
VALUES (206, 4)
"""
)
161 changes: 118 additions & 43 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,14 @@
NameMapping,
update_mapping,
)
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType
from pyiceberg.table.snapshots import (
Operation,
Snapshot,
SnapshotLogEntry,
SnapshotSummaryCollector,
Summary,
ancestor_right_before_timestamp,
ancestors_of,
update_snapshot_summaries,
)
Expand Down Expand Up @@ -299,7 +300,12 @@ def __exit__(self, _: Any, value: Any, traceback: Any) -> None:
"""Close and commit the transaction."""
self.commit_transaction()

def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequirement, ...] = ()) -> Transaction:
def _apply(
self,
updates: Tuple[TableUpdate, ...],
requirements: Tuple[TableRequirement, ...] = (),
commit_transaction_if_autocommit: bool = True,
) -> Transaction:
"""Check if the requirements are met, and applies the updates to the metadata."""
for requirement in requirements:
requirement.validate(self.table_metadata)
Expand All @@ -309,7 +315,7 @@ def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequ

self.table_metadata = update_table_metadata(self.table_metadata, updates)

if self._autocommit:
if self._autocommit and commit_transaction_if_autocommit:
self.commit_transaction()
self._updates = ()
self._requirements = ()
Expand Down Expand Up @@ -402,39 +408,6 @@ def set_ref_snapshot(
requirements = (AssertRefSnapshotId(snapshot_id=parent_snapshot_id, ref="main"),)
return self._apply(updates, requirements)

def _set_ref_snapshot(
self,
snapshot_id: int,
ref_name: str,
type: str,
max_ref_age_ms: Optional[int] = None,
max_snapshot_age_ms: Optional[int] = None,
min_snapshots_to_keep: Optional[int] = None,
) -> UpdatesAndRequirements:
"""Update a ref to a snapshot.

Returns:
The updates and requirements for the set-snapshot-ref staged
"""
updates = (
SetSnapshotRefUpdate(
snapshot_id=snapshot_id,
ref_name=ref_name,
type=type,
max_ref_age_ms=max_ref_age_ms,
max_snapshot_age_ms=max_snapshot_age_ms,
min_snapshots_to_keep=min_snapshots_to_keep,
),
)
requirements = (
AssertRefSnapshotId(
snapshot_id=self.table_metadata.refs[ref_name].snapshot_id if ref_name in self.table_metadata.refs else None,
ref=ref_name,
),
)

return updates, requirements

def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema:
"""Create a new UpdateSchema to alter the columns of this table.

Expand Down Expand Up @@ -1975,6 +1948,48 @@ def _commit(self) -> UpdatesAndRequirements:
"""Apply the pending changes and commit."""
return self._updates, self._requirements

def _commit_if_ref_updates_exist(self) -> None:
self._transaction._apply(*self._commit(), commit_transaction_if_autocommit=False)
self._updates, self._requirements = (), ()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to Java implementation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only issue here is that self.commit will commit the transaction if the ManageSnapshot object comes from

def manage_snapshots(self) -> ManageSnapshots:
"""
Shorthand to run snapshot management operations like create branch, create tag, etc.
Use table.manage_snapshots().<operation>().commit() to run a specific operation.
Use table.manage_snapshots().<operation-one>().<operation-two>().commit() to run multiple operations.
Pending changes are applied on commit.
We can also use context managers to make more changes. For example,
with table.manage_snapshots() as ms:
ms.create_tag(snapshot_id1, "Tag_A").create_tag(snapshot_id2, "Tag_B")
"""
return ManageSnapshots(transaction=Transaction(self, autocommit=True))

where autocommit is set to true.

One possible way to fix this is that we can add additional parameters in transaction._apply to override the autocommit behavior and call that directly here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated! Now there's an extra parameter commit_transaction_now that defaults to True, and we override it to False when staged refs need to be applied without commiting the transaction.

Copy link
Contributor Author

@chinmay-bhat chinmay-bhat Jul 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, I'm re-opening this resolved conversation, since I don't think adding the additional parameter is enough.

Say, in the future, we have more APIs like:

branch_name, min_snapshots_to_keep = "test_branch_min_snapshots_to_keep", 2
with tbl.manage_snapshots() as ms:
        ms.create_branch(branch_name=branch_name, snapshot_id=snapshot_id)
        ms.set_min_snapshots_to_keep(branch_name=branch_name, min_snapshots_to_keep=min_snapshots_to_keep)

The updates and requirements would be :
(SetSnapshotRefUpdate(action='set-snapshot-ref', ref_name='test_branch_min_snapshots_to_keep', type='branch', snapshot_id=71191752302974125, max_ref_age_ms=None, max_snapshot_age_ms=None, min_snapshots_to_keep=None), SetSnapshotRefUpdate(action='set-snapshot-ref', ref_name='test_branch_min_snapshots_to_keep', type='branch', snapshot_id=71191752302974125, max_ref_age_ms=None, max_snapshot_age_ms=None, min_snapshots_to_keep=2))

(AssertRefSnapshotId(type='assert-ref-snapshot-id', ref='test_branch_min_snapshots_to_keep', snapshot_id=None), AssertRefSnapshotId(type='assert-ref-snapshot-id', ref='test_branch_min_snapshots_to_keep', snapshot_id=71191752302974125))

The 2nd requirement will fail with a CommitFailedException as the branch would be missing.
With _commit_if_ref_updates_exist() , the transaction.table_metadata would get updated, but when the transaction exits, it will try to commit_transaction() which runs _do_commit() which runs _commit_table().

In _commit_table(), for non-REST catalogs, we _update_and_stage_table() where we check the requirements with current table metadata, here the 2nd requirement fails.

To fix this, we might consider one of the following solutions:

  1. in transaction._apply identify the differences between current table metadata and staged metadata, and only pass those differences in self._updates, while not sending the ref updates requirements (since we've already validated them once in transaction._apply) OR
  2. improve _update_and_stage_table() to iteratively apply the update with corresponding requirement and always check the requirements with updated_metadata. This is easier than (1), but only serves non-REST catalogs. OR
  3. continue the original implementation, i.e. for every commit_if_ref_exists(), the Transaction commits to the table. This would be expensive IMO, but the result would remain atomic and correct, with minimal changes in the PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chinmay-bhat Thank you so much for digging into this issue! I think you've made a great point. I am thinking of a similar solution like your first point: to derive a list of requirements when we commit the transaction: https://github.com/apache/iceberg/blob/d69ba0568a2e07dfb5af233350ad5668d9aef134/core/src/main/java/org/apache/iceberg/UpdateRequirements.java#L50-L58

This will save us from manually specifying requirements for every UpdateTableMetadata definition and also prevent the problems described above.

Let me research more on this and get back to you.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @HonahX, should I make a new issue for this? Since changing how we specify requirements is not strictly in the scope of this PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @chinmay-bhat. Sorry for the long wait🙏. I was distracted by other stuff and some blocking issues for 0.7.0 release. Yes, please feel free to create an issue to further discuss it. I can reply to that when I get something.


def _set_ref_snapshot(
self,
snapshot_id: int,
ref_name: str,
type: str,
max_ref_age_ms: Optional[int] = None,
max_snapshot_age_ms: Optional[int] = None,
min_snapshots_to_keep: Optional[int] = None,
) -> ManageSnapshots:
"""Update a ref to a snapshot.

Stages the updates and requirements for the set-snapshot-ref

Returns:
This for method chaining
"""
updates = (
SetSnapshotRefUpdate(
snapshot_id=snapshot_id,
ref_name=ref_name,
type=type,
max_ref_age_ms=max_ref_age_ms,
max_snapshot_age_ms=max_snapshot_age_ms,
min_snapshots_to_keep=min_snapshots_to_keep,
),
)
requirements = (
AssertRefSnapshotId(
snapshot_id=self._transaction.table_metadata.refs[ref_name].snapshot_id
if ref_name in self._transaction.table_metadata.refs
else None,
ref=ref_name,
),
)
self._updates += updates
self._requirements += requirements
return self

def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[int] = None) -> ManageSnapshots:
"""
Create a new tag pointing to the given snapshot id.
Expand All @@ -1987,15 +2002,12 @@ def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[i
Returns:
This for method chaining
"""
update, requirement = self._transaction._set_ref_snapshot(
return self._set_ref_snapshot(
snapshot_id=snapshot_id,
ref_name=tag_name,
type="tag",
max_ref_age_ms=max_ref_age_ms,
)
self._updates += update
self._requirements += requirement
return self

def create_branch(
self,
Expand All @@ -2017,16 +2029,79 @@ def create_branch(
Returns:
This for method chaining
"""
update, requirement = self._transaction._set_ref_snapshot(
return self._set_ref_snapshot(
snapshot_id=snapshot_id,
ref_name=branch_name,
type="branch",
max_ref_age_ms=max_ref_age_ms,
max_snapshot_age_ms=max_snapshot_age_ms,
min_snapshots_to_keep=min_snapshots_to_keep,
)
self._updates += update
self._requirements += requirement

def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots:
"""Rollback the table to the given snapshot id.

The snapshot needs to be an ancestor of the current table state.

Args:
snapshot_id (int): rollback to this snapshot_id that used to be current.
Returns:
This for method chaining
"""
self._commit_if_ref_updates_exist()
if self._transaction._table.snapshot_by_id(snapshot_id) is None:
raise ValidationError(f"Cannot roll back to unknown snapshot id: {snapshot_id}")
if snapshot_id not in {
ancestor.snapshot_id
for ancestor in ancestors_of(self._transaction._table.current_snapshot(), self._transaction.table_metadata)
}:
raise ValidationError(f"Cannot roll back to snapshot, not an ancestor of the current state: {snapshot_id}")
return self._set_ref_snapshot(snapshot_id=snapshot_id, ref_name=MAIN_BRANCH, type=str(SnapshotRefType.BRANCH))

def rollback_to_timestamp(self, timestamp: int) -> ManageSnapshots:
"""Rollback the table to the snapshot right before the given timestamp.

The snapshot needs to be an ancestor of the current table state.

Args:
timestamp (int): rollback to the snapshot that used to be current right before this timestamp.
Returns:
This for method chaining
"""
self._commit_if_ref_updates_exist()
if (
snapshot := ancestor_right_before_timestamp(
self._transaction._table.current_snapshot(), self._transaction.table_metadata, timestamp
)
) is None:
raise ValidationError(f"Cannot roll back, no valid snapshot older than: {timestamp}")
return self._set_ref_snapshot(snapshot_id=snapshot.snapshot_id, ref_name=MAIN_BRANCH, type=str(SnapshotRefType.BRANCH))

def set_current_snapshot(self, snapshot_id: Optional[int] = None, ref_name: Optional[str] = None) -> ManageSnapshots:
"""Set the table to a specific snapshot identified either by its id or the branch/tag its on, not both.

The snapshot is not required to be an ancestor of the current table state.

Args:
snapshot_id (Optional[int]): id of the snapshot to be set as current
ref_name (Optional[str]): branch/tag where the snapshot to be set as current exists.
Returns:
This for method chaining
"""
self._commit_if_ref_updates_exist()
if (not snapshot_id or ref_name) and (snapshot_id or not ref_name):
raise ValidationError("Either snapshot_id or ref must be provided")
else:
if snapshot_id is None:
if ref_name not in self._transaction.table_metadata.refs:
raise ValidationError(f"Cannot set snapshot current to unknown ref {ref_name}")
target_snapshot_id = self._transaction.table_metadata.refs[ref_name].snapshot_id
else:
target_snapshot_id = snapshot_id
if (snapshot := self._transaction._table.snapshot_by_id(target_snapshot_id)) is None:
raise ValidationError(f"Cannot set snapshot current with snapshot id: {snapshot_id} or ref_name: {ref_name}")

self._set_ref_snapshot(snapshot_id=snapshot.snapshot_id, ref_name=MAIN_BRANCH, type=str(SnapshotRefType.BRANCH))
return self


Expand Down
11 changes: 11 additions & 0 deletions pyiceberg/table/snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,17 @@ def set_when_positive(properties: Dict[str, str], num: int, property_name: str)
properties[property_name] = str(num)


def ancestor_right_before_timestamp(
current_snapshot: Optional[Snapshot], table_metadata: TableMetadata, timestamp_ms: int
) -> Optional[Snapshot]:
"""Get the ancestor right before the given timestamp."""
if current_snapshot:
for ancestor in ancestors_of(current_snapshot, table_metadata):
if ancestor.timestamp_ms < timestamp_ms:
return ancestor
return None


def ancestors_of(current_snapshot: Optional[Snapshot], table_metadata: TableMetadata) -> Iterable[Snapshot]:
"""Get the ancestors of and including the given snapshot."""
snapshot = current_snapshot
Expand Down
Loading