Skip to content

Commit bc6fb68

Browse files
committed
Added integration tests with spark
1 parent 4cf9198 commit bc6fb68

File tree

1 file changed

+59
-1
lines changed

1 file changed

+59
-1
lines changed

tests/integration/test_writes/test_writes.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1648,7 +1648,7 @@ def test_overwrite_to_existing_branch(session_catalog: Catalog, arrow_table_with
16481648

16491649
@pytest.mark.integration
16501650
def test_intertwined_branch_writes(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
1651-
identifier = "default.test_concurrent_branch_operations"
1651+
identifier = "default.test_intertwined_branch_operations"
16521652
branch1 = "existing_branch_1"
16531653
branch2 = "existing_branch_2"
16541654

@@ -1669,3 +1669,61 @@ def test_intertwined_branch_writes(session_catalog: Catalog, arrow_table_with_nu
16691669
assert len(tbl.scan().use_ref(branch1).to_arrow()) == 2
16701670
assert len(tbl.scan().use_ref(branch2).to_arrow()) == 3
16711671
assert len(tbl.scan().to_arrow()) == 6
1672+
1673+
1674+
@pytest.mark.integration
def test_branch_spark_write_py_read(session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table) -> None:
    """A delete committed to a branch through Spark must be visible to a PyIceberg scan of that branch.

    The table starts with one snapshot on main; Spark creates a branch from it and
    deletes the row where ``int = 9`` on the branch only. Main must stay untouched.
    """
    # Initialize a v2 table seeded with one arrow_table_with_null snapshot.
    identifier = "default.test_branch_spark_write_py_read"
    table = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])
    branch_name = "existing_spark_branch"

    # Branch off the current snapshot via Spark DDL.
    spark.sql(f"ALTER TABLE {identifier} CREATE BRANCH {branch_name}")

    # Write through Spark, targeting the branch ref (branch_<name> path syntax).
    delete_stmt = f"""
        DELETE FROM {identifier}.branch_{branch_name}
        WHERE int = 9
    """
    spark.sql(delete_stmt)

    # Pick up the refs Spark just committed before reading from Python.
    table.refresh()

    # Main keeps all 3 rows; the branch lost the int = 9 row.
    assert len(table.scan().to_arrow()) == 3
    assert len(table.scan().use_ref(branch_name).to_arrow()) == 2
1698+
1699+
1700+
@pytest.mark.integration
def test_branch_py_write_spark_read(session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table) -> None:
    """A delete committed to a branch through PyIceberg must be visible to a Spark read of that branch.

    Mirror image of the Spark-write test: the branch is created and written from
    Python, then both main and the branch are counted from Spark.
    """
    # Initialize a v2 table seeded with one arrow_table_with_null snapshot.
    identifier = "default.test_branch_py_write_spark_read"
    table = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])
    branch_name = "existing_py_branch"

    # The seed write must have produced a snapshot to branch from.
    assert table.metadata.current_snapshot_id is not None

    # Branch off the current snapshot via the PyIceberg snapshot-management API.
    table.manage_snapshots().create_branch(
        snapshot_id=table.metadata.current_snapshot_id, branch_name=branch_name
    ).commit()

    # Write through Python, targeting the branch only.
    table.delete("int = 9", branch=branch_name)

    # Read both refs back through Spark.
    df_main = spark.sql(
        f"""
        SELECT *
        FROM {identifier}
        """
    )
    df_branch = spark.sql(
        f"""
        SELECT *
        FROM {identifier}.branch_{branch_name}
        """
    )

    # Main keeps all 3 rows; the branch lost the int = 9 row.
    assert df_main.count() == 3
    assert df_branch.count() == 2

0 commit comments

Comments
 (0)