@@ -1648,7 +1648,7 @@ def test_overwrite_to_existing_branch(session_catalog: Catalog, arrow_table_with
16481648
16491649@pytest .mark .integration
16501650def test_intertwined_branch_writes (session_catalog : Catalog , arrow_table_with_null : pa .Table ) -> None :
1651- identifier = "default.test_concurrent_branch_operations "
1651+ identifier = "default.test_intertwined_branch_operations "
16521652 branch1 = "existing_branch_1"
16531653 branch2 = "existing_branch_2"
16541654
@@ -1669,3 +1669,61 @@ def test_intertwined_branch_writes(session_catalog: Catalog, arrow_table_with_nu
16691669 assert len (tbl .scan ().use_ref (branch1 ).to_arrow ()) == 2
16701670 assert len (tbl .scan ().use_ref (branch2 ).to_arrow ()) == 3
16711671 assert len (tbl .scan ().to_arrow ()) == 6
1672+
1673+
@pytest.mark.integration
def test_branch_spark_write_py_read(session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table) -> None:
    """A delete committed to a branch via Spark is visible when reading that branch from Python.

    Writes through Spark SQL's ``.branch_<name>`` identifier, then verifies via a
    PyIceberg scan that main is untouched while the branch reflects the delete.
    """
    # Initialize table with an initial snapshot so the branch has a base to fork from.
    identifier = "default.test_branch_spark_write_py_read"
    tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])
    branch = "existing_spark_branch"

    # Create the branch in Spark.
    spark.sql(f"ALTER TABLE {identifier} CREATE BRANCH {branch}")

    # Spark write: delete one row on the branch only (main is untouched).
    spark.sql(
        f"""
        DELETE FROM {identifier}.branch_{branch}
        WHERE int = 9
        """
    )

    # Refresh the table metadata to pick up the branch ref written by Spark.
    tbl.refresh()

    # Python read: main still has all 3 rows; the branch has 2 after the delete.
    assert len(tbl.scan().to_arrow()) == 3
    assert len(tbl.scan().use_ref(branch).to_arrow()) == 2
1699+
@pytest.mark.integration
def test_branch_py_write_spark_read(session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table) -> None:
    """A delete committed to a branch via PyIceberg is visible when reading that branch from Spark.

    Creates the branch through the Python snapshot-management API, deletes on the
    branch from Python, then verifies row counts through Spark SQL for both main
    and the ``.branch_<name>`` identifier.
    """
    # Initialize table with an initial snapshot so the branch has a base to fork from.
    identifier = "default.test_branch_py_write_spark_read"
    tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])
    branch = "existing_py_branch"

    # A current snapshot must exist to anchor the new branch.
    assert tbl.metadata.current_snapshot_id is not None

    # Create the branch from the current snapshot.
    tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch).commit()

    # Python write: delete one row on the branch only (main is untouched).
    tbl.delete("int = 9", branch=branch)

    # Spark read: compare main vs. branch row counts.
    main_df = spark.sql(
        f"""
        SELECT *
        FROM {identifier}
        """
    )
    branch_df = spark.sql(
        f"""
        SELECT *
        FROM {identifier}.branch_{branch}
        """
    )
    assert main_df.count() == 3
    assert branch_df.count() == 2
0 commit comments