diff --git a/.github/labeler.yml b/.github/labeler.yml
index 0d04244f8822b..cf1d2a7117203 100644
--- a/.github/labeler.yml
+++ b/.github/labeler.yml
@@ -152,6 +152,6 @@ WEB UI:
DEPLOY:
- "sbin/**/*"
CONNECT:
- - "connect/**/*"
+ - "connector/connect/**/*"
- "**/sql/sparkconnect/**/*"
- "python/pyspark/sql/**/connect/**/*"
diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index b0847187dffdd..b7f8b10c00f5b 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -334,7 +334,7 @@ jobs:
- >-
pyspark-pandas-slow
- >-
- pyspark-sql-connect
+ pyspark-connect
env:
MODULES_TO_TEST: ${{ matrix.modules }}
HADOOP_PROFILE: ${{ inputs.hadoop }}
diff --git a/assembly/pom.xml b/assembly/pom.xml
index 218bf36795046..f37edcd7e49f4 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -74,11 +74,6 @@
       <artifactId>spark-repl_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
-    <dependency>
-      <groupId>org.apache.spark</groupId>
-      <artifactId>spark-connect_${scala.binary.version}</artifactId>
-      <version>${project.version}</version>
-    </dependency>
diff --git a/python/mypy.ini b/python/mypy.ini
index 1094f33e833d2..baf1a4048eb04 100644
--- a/python/mypy.ini
+++ b/python/mypy.ini
@@ -26,12 +26,6 @@ warn_redundant_casts = True
[mypy-pyspark.sql.connect.proto.*]
ignore_errors = True
-; TODO(SPARK-40537) reenable mypi support.
-[mypy-pyspark.sql.tests.connect.*]
-disallow_untyped_defs = False
-ignore_missing_imports = True
-ignore_errors = True
-
; Allow untyped def in internal modules and tests
[mypy-pyspark.daemon]
@@ -78,6 +72,9 @@ disallow_untyped_defs = False
[mypy-pyspark.sql.tests.*]
disallow_untyped_defs = False
+; TODO(SPARK-40537) re-enable mypy support.
+ignore_missing_imports = True
+ignore_errors = True
[mypy-pyspark.sql.pandas.serializers]
disallow_untyped_defs = False
diff --git a/python/pyspark/sql/connect/README.md b/python/pyspark/sql/connect/README.md
index e79e9aae9dd2b..ac3926a28b9c1 100644
--- a/python/pyspark/sql/connect/README.md
+++ b/python/pyspark/sql/connect/README.md
@@ -1,5 +1,4 @@
-
-# [EXPERIMENTAL] Spark Connect
+# Spark Connect
**Spark Connect is a strictly experimental feature and under heavy development.
All APIs should be considered volatile and should not be used in production.**
@@ -8,30 +7,32 @@ This module contains the implementation of Spark Connect which is a logical plan
facade for the implementation in Spark. Spark Connect is directly integrated into the build
of Spark. To enable it, you only need to activate the driver plugin for Spark Connect.
-
-
-
## Build
1. Build Spark as usual per the documentation.
+
2. Build and package the Spark Connect package
+
```bash
./build/mvn -Phive package
```
+
or
- ```shell
+
+ ```bash
./build/sbt -Phive package
```
## Run Spark Shell
```bash
-./bin/spark-shell --conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin
+./bin/spark-shell \
+ --packages org.apache.spark:spark-connect_2.12:3.4.0 \
+ --conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin
```
## Run Tests
-
```bash
./run-tests --testnames 'pyspark.sql.tests.connect.test_spark_connect'
```
diff --git a/python/pyspark/sql/tests/connect/__init__.py b/python/pyspark/sql/tests/connect/__init__.py
deleted file mode 100644
index cce3acad34a49..0000000000000
--- a/python/pyspark/sql/tests/connect/__init__.py
+++ /dev/null
@@ -1,16 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
diff --git a/python/pyspark/sql/tests/connect/utils/__init__.py b/python/pyspark/sql/tests/connect/utils/__init__.py
deleted file mode 100644
index b95812c8a297e..0000000000000
--- a/python/pyspark/sql/tests/connect/utils/__init__.py
+++ /dev/null
@@ -1,20 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-from pyspark.sql.tests.connect.utils.spark_connect_test_utils import ( # noqa: F401
- PlanOnlyTestFixture, # noqa: F401
-) # noqa: F401
diff --git a/python/pyspark/sql/tests/connect/test_spark_connect.py b/python/pyspark/sql/tests/test_connect_basic.py
similarity index 91%
rename from python/pyspark/sql/tests/connect/test_spark_connect.py
rename to python/pyspark/sql/tests/test_connect_basic.py
index 7e891c5cf19f8..3e83e1bd6eaa9 100644
--- a/python/pyspark/sql/tests/connect/test_spark_connect.py
+++ b/python/pyspark/sql/tests/test_connect_basic.py
@@ -22,13 +22,18 @@
from pyspark.sql import SparkSession, Row
from pyspark.sql.connect.client import RemoteSparkSession
from pyspark.sql.connect.function_builder import udf
+from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import ReusedPySparkTestCase
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
class SparkConnectSQLTestCase(ReusedPySparkTestCase):
"""Parent test fixture class for all Spark Connect related
test cases."""
+    connect: RemoteSparkSession
+    tbl_name: str
+
@classmethod
def setUpClass(cls: Any) -> None:
ReusedPySparkTestCase.setUpClass()
@@ -55,7 +60,6 @@ def spark_connect_test_data(cls: Any) -> None:
class SparkConnectTests(SparkConnectSQLTestCase):
def test_simple_read(self) -> None:
- """Tests that we can access the Spark Connect GRPC service locally."""
df = self.connect.read.table(self.tbl_name)
data = df.limit(10).toPandas()
# Check that the limit is applied
@@ -77,7 +81,7 @@ def test_simple_explain_string(self) -> None:
if __name__ == "__main__":
- from pyspark.sql.tests.connect.test_spark_connect import * # noqa: F401
+ from pyspark.sql.tests.test_connect_basic import * # noqa: F401
try:
import xmlrunner # type: ignore
diff --git a/python/pyspark/sql/tests/connect/test_column_expressions.py b/python/pyspark/sql/tests/test_connect_column_expressions.py
similarity index 94%
rename from python/pyspark/sql/tests/connect/test_column_expressions.py
rename to python/pyspark/sql/tests/test_connect_column_expressions.py
index 1f067bf799562..fc80d137b6cbb 100644
--- a/python/pyspark/sql/tests/connect/test_column_expressions.py
+++ b/python/pyspark/sql/tests/test_connect_column_expressions.py
@@ -14,12 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from pyspark.sql.tests.connect.utils import PlanOnlyTestFixture
+from pyspark.testing.connectutils import PlanOnlyTestFixture
import pyspark.sql.connect as c
import pyspark.sql.connect.plan as p
import pyspark.sql.connect.column as col
-
import pyspark.sql.connect.functions as fun
@@ -54,7 +53,7 @@ def test_column_literals(self):
if __name__ == "__main__":
import unittest
- from pyspark.sql.tests.connect.test_column_expressions import * # noqa: F401
+ from pyspark.sql.tests.test_connect_column_expressions import * # noqa: F401
try:
import xmlrunner # type: ignore
diff --git a/python/pyspark/sql/tests/connect/test_plan_only.py b/python/pyspark/sql/tests/test_connect_plan_only.py
similarity index 94%
rename from python/pyspark/sql/tests/connect/test_plan_only.py
rename to python/pyspark/sql/tests/test_connect_plan_only.py
index 9e6d30cbe1fda..ad59a682e9bea 100644
--- a/python/pyspark/sql/tests/connect/test_plan_only.py
+++ b/python/pyspark/sql/tests/test_connect_plan_only.py
@@ -16,10 +16,10 @@
#
import unittest
+from pyspark.testing.connectutils import PlanOnlyTestFixture
from pyspark.sql.connect import DataFrame
from pyspark.sql.connect.plan import Read
from pyspark.sql.connect.function_builder import UserDefinedFunction, udf
-from pyspark.sql.tests.connect.utils.spark_connect_test_utils import PlanOnlyTestFixture
from pyspark.sql.types import StringType
@@ -64,7 +64,7 @@ def read_table(x):
if __name__ == "__main__":
- from pyspark.sql.tests.connect.test_plan_only import * # noqa: F401
+ from pyspark.sql.tests.test_connect_plan_only import * # noqa: F401
try:
import xmlrunner # type: ignore
diff --git a/python/pyspark/sql/tests/connect/test_select_ops.py b/python/pyspark/sql/tests/test_connect_select_ops.py
similarity index 92%
rename from python/pyspark/sql/tests/connect/test_select_ops.py
rename to python/pyspark/sql/tests/test_connect_select_ops.py
index 818f82b33e863..fc624b0d5cc29 100644
--- a/python/pyspark/sql/tests/connect/test_select_ops.py
+++ b/python/pyspark/sql/tests/test_connect_select_ops.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from pyspark.sql.tests.connect.utils import PlanOnlyTestFixture
+from pyspark.testing.connectutils import PlanOnlyTestFixture
from pyspark.sql.connect import DataFrame
from pyspark.sql.connect.functions import col
from pyspark.sql.connect.plan import Read, InputValidationError
@@ -29,7 +29,7 @@ def test_select_with_literal(self):
if __name__ == "__main__":
import unittest
- from pyspark.sql.tests.connect.test_select_ops import * # noqa: F401
+ from pyspark.sql.tests.test_connect_select_ops import * # noqa: F401
try:
import xmlrunner # type: ignore
diff --git a/python/pyspark/sql/tests/connect/utils/spark_connect_test_utils.py b/python/pyspark/testing/connectutils.py
similarity index 60%
rename from python/pyspark/sql/tests/connect/utils/spark_connect_test_utils.py
rename to python/pyspark/testing/connectutils.py
index 34bf49db49458..dc66526010fdb 100644
--- a/python/pyspark/sql/tests/connect/utils/spark_connect_test_utils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -14,11 +14,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import os
from typing import Any, Dict
import functools
import unittest
import uuid
+from pyspark.testing.utils import search_jar
+
+
+connect_jar = search_jar("connector/connect", "spark-connect-assembly-", "spark-connect")
+if connect_jar is None:
+ connect_requirement_message = (
+ "Skipping all Spark Connect Python tests as the optional Spark Connect project was "
+ "not compiled into a JAR. To run these tests, you need to build Spark with "
+        "'build/sbt package' or 'build/mvn package' before running these tests."
+ )
+else:
+ existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
+ jars_args = "--jars %s" % connect_jar
+ plugin_args = "--conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin"
+ os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, plugin_args, existing_args])
+ connect_requirement_message = None # type: ignore
+
+should_test_connect = connect_requirement_message is None
+
class MockRemoteSession:
def __init__(self) -> None:
@@ -33,6 +53,7 @@ def __getattr__(self, item: str) -> Any:
return functools.partial(self.hooks[item])
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
class PlanOnlyTestFixture(unittest.TestCase):
@classmethod
def setUpClass(cls: Any) -> None:
diff --git a/python/run-tests.py b/python/run-tests.py
index af4c6f1c94bef..19e39c822cbb4 100755
--- a/python/run-tests.py
+++ b/python/run-tests.py
@@ -110,15 +110,6 @@ def run_individual_python_test(target_dir, test_name, pyspark_python, keep_test_
metastore_dir = os.path.join(metastore_dir, str(uuid.uuid4()))
os.mkdir(metastore_dir)
- # Check if we should enable the SparkConnectPlugin
- additional_config = []
- if test_name.startswith("pyspark.sql.tests.connect"):
- # Adding Spark Connect JAR and Config
- additional_config += [
- "--conf",
- "spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin"
- ]
-
# Also override the JVM's temp directory by setting driver and executor options.
java_options = "-Djava.io.tmpdir={0}".format(tmp_dir)
java_options = java_options + " -Dio.netty.tryReflectionSetAccessible=true -Xss4M"
@@ -126,9 +117,8 @@ def run_individual_python_test(target_dir, test_name, pyspark_python, keep_test_
"--conf", "spark.driver.extraJavaOptions='{0}'".format(java_options),
"--conf", "spark.executor.extraJavaOptions='{0}'".format(java_options),
"--conf", "spark.sql.warehouse.dir='{0}'".format(metastore_dir),
+ "pyspark-shell",
]
- spark_args += additional_config
- spark_args += ["pyspark-shell"]
env["PYSPARK_SUBMIT_ARGS"] = " ".join(spark_args)