Skip to content

Commit 1306663

Browse files
committed
Add tests
1 parent f4e23e5 commit 1306663

File tree

1 file changed

+68
-0
lines changed

1 file changed

+68
-0
lines changed

tests/integration/test_catalog.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,18 @@
2626
from pyiceberg.catalog.rest import RestCatalog
2727
from pyiceberg.catalog.sql import SqlCatalog
2828
from pyiceberg.exceptions import (
29+
CommitFailedException,
2930
NamespaceAlreadyExistsError,
3031
NamespaceNotEmptyError,
3132
NoSuchNamespaceError,
3233
NoSuchTableError,
3334
TableAlreadyExistsError,
3435
)
3536
from pyiceberg.io import WAREHOUSE
37+
from pyiceberg.partitioning import PartitionField, PartitionSpec
3638
from pyiceberg.schema import Schema
39+
from pyiceberg.transforms import BucketTransform
40+
from pyiceberg.types import LongType, NestedField, StringType
3741
from tests.conftest import clean_up
3842

3943

@@ -98,6 +102,9 @@ def hive_catalog() -> Generator[Catalog, None, None]:
98102
]
99103

100104

105+
# Two-column schema used by the partition-spec tests in this module:
# field 1 is a required long `id`, field 2 is an optional string `data`.
SIMPLE_SCHEMA = Schema(NestedField(1, "id", LongType(), required=True), NestedField(2, "data", StringType(), required=False))
106+
107+
101108
@pytest.mark.integration
102109
@pytest.mark.parametrize("test_catalog", CATALOGS)
103110
def test_create_table_with_default_location(
@@ -343,3 +350,64 @@ def test_update_namespace_properties(test_catalog: Catalog, database_name: str)
343350
else:
344351
assert k in update_report.removed
345352
assert "updated test description" == test_catalog.load_namespace_properties(database_name)["comment"]
353+
354+
355+
@pytest.mark.integration
@pytest.mark.parametrize("test_catalog", CATALOGS)
def test_update_table_spec(test_catalog: Catalog, table_name: str, database_name: str) -> None:
    """Add a bucket partition field and check that the reloaded table exposes it.

    A table is created from SIMPLE_SCHEMA, a ``bucket[16]`` field named
    ``shard`` is added on the ``id`` column through ``update_spec()``, and the
    table is then loaded again from the catalog to compare partition fields.
    """
    test_catalog.create_namespace(database_name)
    table_id = (database_name, table_name)
    created = test_catalog.create_table(table_id, SIMPLE_SCHEMA)

    # Leaving the ``with`` block commits the staged spec change.
    with created.update_spec() as spec_update:
        spec_update.add_field(
            source_column_name="id",
            transform=BucketTransform(16),
            partition_field_name="shard",
        )

    reloaded = test_catalog.load_table(table_id)
    shard_field = PartitionField(source_id=1, field_id=1000, transform=BucketTransform(16), name="shard")
    wanted = PartitionSpec(shard_field, spec_id=1)

    # The spec ID may not match, so only the fields are compared.
    assert reloaded.spec().fields == wanted.fields
371+
372+
373+
@pytest.mark.integration
@pytest.mark.parametrize("test_catalog", CATALOGS)
def test_update_table_spec_conflict(test_catalog: Catalog, table_name: str, database_name: str) -> None:
    """Check that a stale spec update is rejected after a concurrent spec change.

    One update is staged (but not committed) on the table returned by
    ``create_table``; a second handle obtained via ``load_table`` then commits
    its own spec change. Committing the staged update afterwards must raise
    ``CommitFailedException``, and the reloaded table must carry the spec
    produced by the second handle.
    """
    test_catalog.create_namespace(database_name)
    table_id = (database_name, table_name)

    id_bucket = PartitionField(source_id=1, field_id=1000, transform=BucketTransform(16), name="id_bucket")
    first_handle = test_catalog.create_table(table_id, SIMPLE_SCHEMA, partition_spec=PartitionSpec(id_bucket))

    # Stage a change without committing it yet.
    pending = first_handle.update_spec()
    pending.add_field(source_column_name="data", transform=BucketTransform(16), partition_field_name="shard")

    # Concurrent update: a second handle on the same table commits first.
    second_handle = test_catalog.load_table(table_id)
    with second_handle.update_spec() as winning_update:
        winning_update.remove_field("id_bucket")

    # Either message is accepted by the regex alternation below.
    expected_error = "Requirement failed: default partition spec changed|Cannot commit"
    with pytest.raises(CommitFailedException, match=expected_error):
        pending.commit()

    reloaded = test_catalog.load_table(table_id)
    assert reloaded.spec() == PartitionSpec(spec_id=1)
394+
395+
396+
@pytest.mark.integration
@pytest.mark.parametrize("test_catalog", CATALOGS)
def test_update_table_spec_then_revert(test_catalog: Catalog, table_name: str, database_name: str) -> None:
    """Add a partition field, remove it again, and expect the initial spec back.

    The table is created as format version 2 with a single ``id_bucket``
    field. An identity field on ``id`` is added in one committed update and
    removed in a second one; the table's spec is then compared with the spec
    it was created with.
    """
    test_catalog.create_namespace(database_name)
    table_id = (database_name, table_name)

    id_bucket = PartitionField(source_id=1, field_id=1000, transform=BucketTransform(16), name="id_bucket")
    initial_spec = PartitionSpec(id_bucket)

    table = test_catalog.create_table(
        table_id,
        SIMPLE_SCHEMA,
        partition_spec=initial_spec,
        properties={"format-version": "2"},
    )
    assert table.format_version == 2

    # First committed update: add an identity partition on "id".
    with table.update_spec() as add_step:
        add_step.add_identity(source_column_name="id")

    # Second committed update: drop the field that was just added.
    with table.update_spec() as revert_step:
        revert_step.remove_field("id")

    assert table.spec() == initial_spec

0 commit comments

Comments
 (0)