Skip to content

Refactor write APIs to default to the main branch #312

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pyiceberg/cli/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from pyiceberg.catalog import Catalog, load_catalog
from pyiceberg.cli.output import ConsoleOutput, JsonOutput, Output
from pyiceberg.exceptions import NoSuchNamespaceError, NoSuchPropertyException, NoSuchTableError
from pyiceberg.table.refs import SnapshotRef
from pyiceberg.table.refs import SnapshotRef, SnapshotRefType

DEFAULT_MIN_SNAPSHOTS_TO_KEEP = 1
DEFAULT_MAX_SNAPSHOT_AGE_MS = 432000000
Expand Down Expand Up @@ -388,7 +388,7 @@ def list_refs(ctx: Context, identifier: str, type: str, verbose: bool) -> None:
refs = table.refs()
if type:
type = type.lower()
if type not in {"branch", "tag"}:
if type not in {SnapshotRefType.BRANCH, SnapshotRefType.TAG}:
raise ValueError(f"Type must be either branch or tag, got: {type}")

relevant_refs = [
Expand All @@ -402,7 +402,7 @@ def list_refs(ctx: Context, identifier: str, type: str, verbose: bool) -> None:

def _retention_properties(ref: SnapshotRef, table_properties: Dict[str, str]) -> Dict[str, str]:
retention_properties = {}
if ref.snapshot_ref_type == "branch":
if ref.snapshot_ref_type == SnapshotRefType.BRANCH:
default_min_snapshots_to_keep = table_properties.get(
"history.expire.min-snapshots-to-keep", DEFAULT_MIN_SNAPSHOTS_TO_KEEP
)
Expand Down
22 changes: 14 additions & 8 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
create_mapping_from_schema,
parse_mapping_from_json,
)
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType
from pyiceberg.table.snapshots import (
Operation,
Snapshot,
Expand Down Expand Up @@ -269,7 +269,7 @@ def set_ref_snapshot(
)
)

self._append_requirements(AssertRefSnapshotId(snapshot_id=parent_snapshot_id, ref="main"))
self._append_requirements(AssertRefSnapshotId(snapshot_id=parent_snapshot_id, ref=ref_name))
return self

def update_schema(self) -> UpdateSchema:
Expand Down Expand Up @@ -391,7 +391,7 @@ class AddSnapshotUpdate(TableUpdate):
class SetSnapshotRefUpdate(TableUpdate):
action: TableUpdateAction = TableUpdateAction.set_snapshot_ref
ref_name: str = Field(alias="ref-name")
type: Literal["tag", "branch"]
type: Literal[SnapshotRefType.TAG, SnapshotRefType.BRANCH]
snapshot_id: int = Field(alias="snapshot-id")
max_ref_age_ms: Annotated[Optional[int], Field(alias="max-ref-age-ms", default=None)]
max_snapshot_age_ms: Annotated[Optional[int], Field(alias="max-snapshot-age-ms", default=None)]
Expand Down Expand Up @@ -925,7 +925,7 @@ def name_mapping(self) -> NameMapping:
else:
return create_mapping_from_schema(self.schema())

def append(self, df: pa.Table) -> None:
def append(self, df: pa.Table, branch: str = MAIN_BRANCH) -> None:
"""
Append data to the table.

Expand All @@ -939,13 +939,13 @@ def append(self, df: pa.Table) -> None:
raise ValueError("Cannot write to tables with a sort-order")

data_files = _dataframe_to_data_files(self, df=df)
merge = _MergingSnapshotProducer(operation=Operation.APPEND, table=self)
merge = _MergingSnapshotProducer(operation=Operation.APPEND, table=self, branch=branch)
for data_file in data_files:
merge.append_data_file(data_file)

merge.commit()

def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE) -> None:
def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE, branch: str = MAIN_BRANCH) -> None:
"""
Overwrite all the data in the table.

Expand All @@ -967,6 +967,7 @@ def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_T
merge = _MergingSnapshotProducer(
operation=Operation.OVERWRITE if self.current_snapshot() is not None else Operation.APPEND,
table=self,
branch=branch,
)

for data_file in data_files:
Expand Down Expand Up @@ -2279,12 +2280,14 @@ class _MergingSnapshotProducer:
_parent_snapshot_id: Optional[int]
_added_data_files: List[DataFile]
_commit_uuid: uuid.UUID
_branch: str

def __init__(self, operation: Operation, table: Table) -> None:
def __init__(self, operation: Operation, table: Table, branch: str) -> None:
self._operation = operation
self._table = table
self._snapshot_id = table.new_snapshot_id()
# Branch that the produced snapshot will be committed to; defaults to the main branch
self._branch = branch
self._parent_snapshot_id = snapshot.snapshot_id if (snapshot := self._table.current_snapshot()) else None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is more involved than just setting the name of the branch. Here we take the parent snapshot, which by default is the current snapshot — and the current snapshot is the HEAD of the main branch. Before adding this, I would love to see some integration tests with Spark, and some tests that assert the snapshot tree.

self._added_data_files = []
self._commit_uuid = uuid.uuid4()
Expand Down Expand Up @@ -2445,7 +2448,10 @@ def commit(self) -> Snapshot:
with self._table.transaction() as tx:
tx.add_snapshot(snapshot=snapshot)
tx.set_ref_snapshot(
snapshot_id=self._snapshot_id, parent_snapshot_id=self._parent_snapshot_id, ref_name="main", type="branch"
snapshot_id=self._snapshot_id,
parent_snapshot_id=self._parent_snapshot_id,
ref_name=self._branch,
type=SnapshotRefType.BRANCH,
)

return snapshot