diff --git a/python/docs/source/reference/pyspark.errors.rst b/python/docs/source/reference/pyspark.errors.rst index 1d54c6babe0bb..13db9bd01fa73 100644 --- a/python/docs/source/reference/pyspark.errors.rst +++ b/python/docs/source/reference/pyspark.errors.rst @@ -30,6 +30,7 @@ Classes PySparkException AnalysisException + TempTableAlreadyExistsException ParseException IllegalArgumentException StreamingQueryException @@ -37,12 +38,6 @@ Classes PythonException UnknownException SparkUpgradeException - SparkConnectAnalysisException - SparkConnectException - SparkConnectGrpcException - SparkConnectParseException - SparkConnectTempTableAlreadyExistsException - SparkConnectIllegalArgumentException Methods diff --git a/python/pyspark/errors/__init__.py b/python/pyspark/errors/__init__.py index 7faa0768a24cd..95da7ca2aa89f 100644 --- a/python/pyspark/errors/__init__.py +++ b/python/pyspark/errors/__init__.py @@ -18,9 +18,10 @@ """ PySpark exceptions. """ -from pyspark.errors.exceptions import ( # noqa: F401 +from pyspark.errors.exceptions.base import ( # noqa: F401 PySparkException, AnalysisException, + TempTableAlreadyExistsException, ParseException, IllegalArgumentException, StreamingQueryException, @@ -30,18 +31,13 @@ SparkUpgradeException, PySparkTypeError, PySparkValueError, - SparkConnectException, - SparkConnectGrpcException, - SparkConnectAnalysisException, - SparkConnectParseException, - SparkConnectTempTableAlreadyExistsException, - SparkConnectIllegalArgumentException, ) __all__ = [ "PySparkException", "AnalysisException", + "TempTableAlreadyExistsException", "ParseException", "IllegalArgumentException", "StreamingQueryException", @@ -51,10 +47,4 @@ "SparkUpgradeException", "PySparkTypeError", "PySparkValueError", - "SparkConnectException", - "SparkConnectGrpcException", - "SparkConnectAnalysisException", - "SparkConnectParseException", - "SparkConnectTempTableAlreadyExistsException", - "SparkConnectIllegalArgumentException", ] diff --git a/python/pyspark/errors/exceptions/__init__.py b/python/pyspark/errors/exceptions/__init__.py new file mode 100644 index 0000000000000..cce3acad34a49 --- /dev/null +++ b/python/pyspark/errors/exceptions/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/errors/exceptions/base.py b/python/pyspark/errors/exceptions/base.py new file mode 100644 index 0000000000000..6e67039374d90 --- /dev/null +++ b/python/pyspark/errors/exceptions/base.py @@ -0,0 +1,162 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Dict, Optional, cast + +from pyspark.errors.utils import ErrorClassesReader + + +class PySparkException(Exception): + """ + Base Exception for handling errors generated from PySpark. + """ + + def __init__( + self, + message: Optional[str] = None, + error_class: Optional[str] = None, + message_parameters: Optional[Dict[str, str]] = None, + ): + # `message` vs `error_class` & `message_parameters` are mutually exclusive. + assert (message is not None and (error_class is None and message_parameters is None)) or ( + message is None and (error_class is not None and message_parameters is not None) + ) + + self.error_reader = ErrorClassesReader() + + if message is None: + self.message = self.error_reader.get_error_message( + cast(str, error_class), cast(Dict[str, str], message_parameters) + ) + else: + self.message = message + + self.error_class = error_class + self.message_parameters = message_parameters + + def getErrorClass(self) -> Optional[str]: + """ + Returns an error class as a string. + + .. versionadded:: 3.4.0 + + See Also + -------- + :meth:`PySparkException.getMessageParameters` + :meth:`PySparkException.getSqlState` + """ + return self.error_class + + def getMessageParameters(self) -> Optional[Dict[str, str]]: + """ + Returns a message parameters as a dictionary. + + .. versionadded:: 3.4.0 + + See Also + -------- + :meth:`PySparkException.getErrorClass` + :meth:`PySparkException.getSqlState` + """ + return self.message_parameters + + def getSqlState(self) -> None: + """ + Returns an SQLSTATE as a string. + + Errors generated in Python have no SQLSTATE, so it always returns None. + + .. versionadded:: 3.4.0 + + See Also + -------- + :meth:`PySparkException.getErrorClass` + :meth:`PySparkException.getMessageParameters` + """ + return None + + def __str__(self) -> str: + if self.getErrorClass() is not None: + return f"[{self.getErrorClass()}] {self.message}" + else: + return self.message + + +class AnalysisException(PySparkException): + """ + Failed to analyze a SQL query plan. + """ + + +class TempTableAlreadyExistsException(AnalysisException): + """ + Failed to create temp view since it is already exists. + """ + + +class ParseException(PySparkException): + """ + Failed to parse a SQL command. + """ + + +class IllegalArgumentException(PySparkException): + """ + Passed an illegal or inappropriate argument. + """ + + +class StreamingQueryException(PySparkException): + """ + Exception that stopped a :class:`StreamingQuery`. + """ + + +class QueryExecutionException(PySparkException): + """ + Failed to execute a query. + """ + + +class PythonException(PySparkException): + """ + Exceptions thrown from Python workers. + """ + + +class UnknownException(PySparkException): + """ + None of the above exceptions. + """ + + +class SparkUpgradeException(PySparkException): + """ + Exception thrown because of Spark upgrade. 
+ """ + + +class PySparkValueError(PySparkException, ValueError): + """ + Wrapper class for ValueError to support error classes. + """ + + +class PySparkTypeError(PySparkException, TypeError): + """ + Wrapper class for TypeError to support error classes. + """ diff --git a/python/pyspark/errors/exceptions.py b/python/pyspark/errors/exceptions/captured.py similarity index 59% rename from python/pyspark/errors/exceptions.py rename to python/pyspark/errors/exceptions/captured.py index a799f4522debb..1764ed7d02c28 100644 --- a/python/pyspark/errors/exceptions.py +++ b/python/pyspark/errors/exceptions/captured.py @@ -22,83 +22,17 @@ from py4j.java_gateway import is_instance_of from pyspark import SparkContext -from pyspark.errors.utils import ErrorClassesReader - - -class PySparkException(Exception): - """ - Base Exception for handling errors generated from PySpark. - """ - - def __init__( - self, - message: Optional[str] = None, - error_class: Optional[str] = None, - message_parameters: Optional[Dict[str, str]] = None, - ): - # `message` vs `error_class` & `message_parameters` are mutually exclusive. - assert (message is not None and (error_class is None and message_parameters is None)) or ( - message is None and (error_class is not None and message_parameters is not None) - ) - - self.error_reader = ErrorClassesReader() - - if message is None: - self.message = self.error_reader.get_error_message( - cast(str, error_class), cast(Dict[str, str], message_parameters) - ) - else: - self.message = message - - self.error_class = error_class - self.message_parameters = message_parameters - - def getErrorClass(self) -> Optional[str]: - """ - Returns an error class as a string. - - .. versionadded:: 3.4.0 - - See Also - -------- - :meth:`PySparkException.getMessageParameters` - :meth:`PySparkException.getSqlState` - """ - return self.error_class - - def getMessageParameters(self) -> Optional[Dict[str, str]]: - """ - Returns a message parameters as a dictionary. - - .. versionadded:: 3.4.0 - - See Also - -------- - :meth:`PySparkException.getErrorClass` - :meth:`PySparkException.getSqlState` - """ - return self.message_parameters - - def getSqlState(self) -> None: - """ - Returns an SQLSTATE as a string. - - Errors generated in Python have no SQLSTATE, so it always returns None. - - .. versionadded:: 3.4.0 - - See Also - -------- - :meth:`PySparkException.getErrorClass` - :meth:`PySparkException.getMessageParameters` - """ - return None - - def __str__(self) -> str: - if self.getErrorClass() is not None: - return f"[{self.getErrorClass()}] {self.message}" - else: - return self.message +from pyspark.errors.exceptions.base import ( + AnalysisException as BaseAnalysisException, + IllegalArgumentException as BaseIllegalArgumentException, + ParseException as BaseParseException, + PySparkException, + PythonException as BasePythonException, + QueryExecutionException as BaseQueryExecutionException, + SparkUpgradeException as BaseSparkUpgradeException, + StreamingQueryException as BaseStreamingQueryException, + UnknownException as BaseUnknownException, +) class CapturedException(PySparkException): @@ -247,133 +181,49 @@ def install_exception_handler() -> None: py4j.java_gateway.get_return_value = patched -class AnalysisException(CapturedException): +class AnalysisException(CapturedException, BaseAnalysisException): """ Failed to analyze a SQL query plan. """ -class ParseException(CapturedException): +class ParseException(CapturedException, BaseParseException): """ Failed to parse a SQL command. 
""" -class IllegalArgumentException(CapturedException): +class IllegalArgumentException(CapturedException, BaseIllegalArgumentException): """ Passed an illegal or inappropriate argument. """ -class StreamingQueryException(CapturedException): +class StreamingQueryException(CapturedException, BaseStreamingQueryException): """ Exception that stopped a :class:`StreamingQuery`. """ -class QueryExecutionException(CapturedException): +class QueryExecutionException(CapturedException, BaseQueryExecutionException): """ Failed to execute a query. """ -class PythonException(CapturedException): +class PythonException(CapturedException, BasePythonException): """ Exceptions thrown from Python workers. """ -class UnknownException(CapturedException): +class UnknownException(CapturedException, BaseUnknownException): """ None of the above exceptions. """ -class SparkUpgradeException(CapturedException): +class SparkUpgradeException(CapturedException, BaseSparkUpgradeException): """ Exception thrown because of Spark upgrade. """ - - -class SparkConnectException(PySparkException): - """ - Exception thrown from Spark Connect. - """ - - -class SparkConnectGrpcException(SparkConnectException): - """ - Base class to handle the errors from GRPC. - """ - - def __init__( - self, - message: Optional[str] = None, - error_class: Optional[str] = None, - message_parameters: Optional[Dict[str, str]] = None, - reason: Optional[str] = None, - ) -> None: - self.message = message # type: ignore[assignment] - if reason is not None: - self.message = f"({reason}) {self.message}" - - super().__init__( - message=self.message, - error_class=error_class, - message_parameters=message_parameters, - ) - - -class SparkConnectAnalysisException(SparkConnectGrpcException): - """ - Failed to analyze a SQL query plan from Spark Connect server. - """ - - def __init__( - self, - message: Optional[str] = None, - error_class: Optional[str] = None, - message_parameters: Optional[Dict[str, str]] = None, - plan: Optional[str] = None, - reason: Optional[str] = None, - ) -> None: - self.message = message # type: ignore[assignment] - if plan is not None: - self.message = f"{self.message}\nPlan: {plan}" - - super().__init__( - message=self.message, - error_class=error_class, - message_parameters=message_parameters, - reason=reason, - ) - - -class SparkConnectParseException(SparkConnectGrpcException): - """ - Failed to parse a SQL command from Spark Connect server. - """ - - -class SparkConnectTempTableAlreadyExistsException(SparkConnectAnalysisException): - """ - Failed to create temp view since it is already exists. - """ - - -class PySparkValueError(PySparkException, ValueError): - """ - Wrapper class for ValueError to support error classes. - """ - - -class PySparkTypeError(PySparkException, TypeError): - """ - Wrapper class for TypeError to support error classes. - """ - - -class SparkConnectIllegalArgumentException(SparkConnectGrpcException): - """ - Passed an illegal or inappropriate argument from Spark Connect server. - """ diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py new file mode 100644 index 0000000000000..ba3bc9f7576b7 --- /dev/null +++ b/python/pyspark/errors/exceptions/connect.py @@ -0,0 +1,105 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Dict, Optional + +from pyspark.errors.exceptions.base import ( + AnalysisException as BaseAnalysisException, + IllegalArgumentException as BaseIllegalArgumentException, + ParseException as BaseParseException, + PySparkException, + PythonException as BasePythonException, + TempTableAlreadyExistsException as BaseTempTableAlreadyExistsException, +) + + +class SparkConnectException(PySparkException): + """ + Exception thrown from Spark Connect. + """ + + +class SparkConnectGrpcException(SparkConnectException): + """ + Base class to handle the errors from GRPC. + """ + + def __init__( + self, + message: Optional[str] = None, + error_class: Optional[str] = None, + message_parameters: Optional[Dict[str, str]] = None, + reason: Optional[str] = None, + ) -> None: + self.message = message # type: ignore[assignment] + if reason is not None: + self.message = f"({reason}) {self.message}" + + super().__init__( + message=self.message, + error_class=error_class, + message_parameters=message_parameters, + ) + + +class AnalysisException(SparkConnectGrpcException, BaseAnalysisException): + """ + Failed to analyze a SQL query plan from Spark Connect server. + """ + + def __init__( + self, + message: Optional[str] = None, + error_class: Optional[str] = None, + message_parameters: Optional[Dict[str, str]] = None, + plan: Optional[str] = None, + reason: Optional[str] = None, + ) -> None: + self.message = message # type: ignore[assignment] + if plan is not None: + self.message = f"{self.message}\nPlan: {plan}" + + super().__init__( + message=self.message, + error_class=error_class, + message_parameters=message_parameters, + reason=reason, + ) + + +class TempTableAlreadyExistsException(AnalysisException, BaseTempTableAlreadyExistsException): + """ + Failed to create temp view from Spark Connect server since it is already exists. + """ + + +class ParseException(SparkConnectGrpcException, BaseParseException): + """ + Failed to parse a SQL command from Spark Connect server. + """ + + +class IllegalArgumentException(SparkConnectGrpcException, BaseIllegalArgumentException): + """ + Passed an illegal or inappropriate argument from Spark Connect server. + """ + + +class PythonException(SparkConnectGrpcException, BasePythonException): + """ + Exceptions thrown from Spark Connect server. + """ diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index a417f754a3625..6deee786164de 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -959,7 +959,7 @@ def isCached(self, tableName: str) -> bool: Throw an analysis exception when the table does not exist. - >>> spark.catalog.isCached("not_existing_table") # doctest: +SKIP + >>> spark.catalog.isCached("not_existing_table") Traceback (most recent call last): ... AnalysisException: ... @@ -997,7 +997,7 @@ def cacheTable(self, tableName: str) -> None: Throw an analysis exception when the table does not exist. 
- >>> spark.catalog.cacheTable("not_existing_table") # doctest: +SKIP + >>> spark.catalog.cacheTable("not_existing_table") Traceback (most recent call last): ... AnalysisException: ... @@ -1037,7 +1037,7 @@ def uncacheTable(self, tableName: str) -> None: Throw an analysis exception when the table does not exist. - >>> spark.catalog.uncacheTable("not_existing_table") # doctest: +SKIP + >>> spark.catalog.uncacheTable("not_existing_table") Traceback (most recent call last): ... AnalysisException: ... diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index 36f1328b1bd08..8cf5fa5069357 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -58,13 +58,14 @@ import pyspark.sql.connect.proto as pb2 import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib import pyspark.sql.connect.types as types -from pyspark.errors import ( +from pyspark.errors.exceptions.connect import ( + AnalysisException, + ParseException, + PythonException, SparkConnectException, SparkConnectGrpcException, - SparkConnectAnalysisException, - SparkConnectParseException, - SparkConnectTempTableAlreadyExistsException, - SparkConnectIllegalArgumentException, + TempTableAlreadyExistsException, + IllegalArgumentException, ) from pyspark.sql.types import ( DataType, @@ -672,22 +673,26 @@ def _handle_error(self, rpc_error: grpc.RpcError) -> NoReturn: d.Unpack(info) reason = info.reason if reason == "org.apache.spark.sql.AnalysisException": - raise SparkConnectAnalysisException( + raise AnalysisException( info.metadata["message"], plan=info.metadata["plan"] ) from None elif reason == "org.apache.spark.sql.catalyst.parser.ParseException": - raise SparkConnectParseException(info.metadata["message"]) from None + raise ParseException(info.metadata["message"]) from None elif ( reason == "org.apache.spark.sql.catalyst.analysis.TempTableAlreadyExistsException" ): - raise SparkConnectTempTableAlreadyExistsException( + raise TempTableAlreadyExistsException( info.metadata["message"], plan=info.metadata["plan"] ) from None elif reason == "java.lang.IllegalArgumentException": message = info.metadata["message"] message = message if message != "" else status.message - raise SparkConnectIllegalArgumentException(message) from None + raise IllegalArgumentException(message) from None + elif reason == "org.apache.spark.api.python.PythonException": + message = info.metadata["message"] + message = message if message != "" else status.message + raise PythonException(message) from None else: raise SparkConnectGrpcException( status.message, reason=info.reason diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py index 1c9e740474ba3..d4984b1ba6741 100644 --- a/python/pyspark/sql/connect/functions.py +++ b/python/pyspark/sql/connect/functions.py @@ -37,10 +37,7 @@ import numpy as np -from pyspark.errors.exceptions import ( - PySparkTypeError, - PySparkValueError, -) +from pyspark.errors import PySparkTypeError, PySparkValueError from pyspark.sql.connect.column import Column from pyspark.sql.connect.expressions import ( CaseWhen, diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index c3c83d48bd98d..3c47ebfb97365 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -41,7 +41,7 @@ from pyspark.sql.readwriter import DataFrameReader from pyspark.sql.streaming import DataStreamReader from pyspark.sql.udf import UDFRegistration # noqa: F401 -from pyspark.errors.exceptions import 
install_exception_handler +from pyspark.errors.exceptions.captured import install_exception_handler from pyspark.context import SparkContext from pyspark.rdd import RDD from pyspark.sql.types import AtomicType, DataType, StructType diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 8bee517de6af7..39a3f036cc38c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -38,10 +38,7 @@ ) from pyspark import SparkContext -from pyspark.errors.exceptions import ( - PySparkTypeError, - PySparkValueError, -) +from pyspark.errors import PySparkTypeError, PySparkValueError from pyspark.rdd import PythonEvalType from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal from pyspark.sql.dataframe import DataFrame diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 36ad15006870d..942e1da95c8ac 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -60,7 +60,7 @@ _parse_datatype_string, _from_numpy_type, ) -from pyspark.errors.exceptions import install_exception_handler +from pyspark.errors.exceptions.captured import install_exception_handler from pyspark.sql.utils import is_timestamp_ntz_preferred, to_str if TYPE_CHECKING: diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py index a577e99d0c26d..3c43628bf3780 100644 --- a/python/pyspark/sql/streaming/query.py +++ b/python/pyspark/sql/streaming/query.py @@ -22,6 +22,9 @@ from py4j.java_gateway import JavaObject, java_import from pyspark.errors import StreamingQueryException +from pyspark.errors.exceptions.captured import ( + StreamingQueryException as CapturedStreamingQueryException, +) from pyspark.sql.streaming.listener import StreamingQueryListener __all__ = ["StreamingQuery", "StreamingQueryManager"] @@ -387,7 +390,7 @@ def exception(self) -> Optional[StreamingQueryException]: je = self._jsq.exception().get() msg = je.toString().split(": ", 1)[1] # Drop the Java StreamingQueryException type info stackTrace = "\n\t at ".join(map(lambda x: x.toString(), je.getStackTrace())) - return StreamingQueryException(msg, stackTrace, je.getCause()) + return CapturedStreamingQueryException(msg, stackTrace, je.getCause()) else: return None diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index a9beb71545d04..eebfaaa39d841 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -49,11 +49,11 @@ connect_requirement_message, ) from pyspark.testing.pandasutils import PandasOnSparkTestUtils -from pyspark.errors import ( +from pyspark.errors.exceptions.connect import ( + AnalysisException, + ParseException, SparkConnectException, - SparkConnectAnalysisException, - SparkConnectParseException, - SparkConnectTempTableAlreadyExistsException, + TempTableAlreadyExistsException, ) if should_test_connect: @@ -223,7 +223,7 @@ def test_df_get_item(self): def test_error_handling(self): # SPARK-41533 Proper error handling for Spark Connect df = self.connect.range(10).select("id2") - with self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): df.collect() def test_simple_read(self): @@ -472,7 +472,7 @@ def test_with_local_ndarray(self): ): self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"]) - with self.assertRaises(SparkConnectParseException): + with self.assertRaises(ParseException): 
self.connect.createDataFrame( data, "col1 magic_type, col2 int, col3 int, col4 int" ).show() @@ -518,7 +518,7 @@ def test_with_local_list(self): ): self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"]) - with self.assertRaises(SparkConnectParseException): + with self.assertRaises(ParseException): self.connect.createDataFrame( data, "col1 magic_type, col2 int, col3 int, col4 int" ).show() @@ -1005,7 +1005,7 @@ def assert_eq_schema(cdf: CDataFrame, df: DataFrame, schema: StructType): # incompatible field nullability schema = StructType([StructField("id", LongType(), False)]) self.assertRaisesRegex( - SparkConnectAnalysisException, + AnalysisException, "NULLABLE_COLUMN_OR_FIELD", lambda: cdf.to(schema).toPandas(), ) @@ -1013,7 +1013,7 @@ def assert_eq_schema(cdf: CDataFrame, df: DataFrame, schema: StructType): # field cannot upcast schema = StructType([StructField("name", LongType())]) self.assertRaisesRegex( - SparkConnectAnalysisException, + AnalysisException, "INVALID_COLUMN_OR_FIELD_DATA_TYPE", lambda: cdf.to(schema).toPandas(), ) @@ -1025,7 +1025,7 @@ def assert_eq_schema(cdf: CDataFrame, df: DataFrame, schema: StructType): ] ) self.assertRaisesRegex( - SparkConnectAnalysisException, + AnalysisException, "INVALID_COLUMN_OR_FIELD_DATA_TYPE", lambda: cdf.to(schema).toPandas(), ) @@ -1244,7 +1244,7 @@ def test_create_global_temp_view(self): # Test when creating a view which is already exists but self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1")) - with self.assertRaises(SparkConnectTempTableAlreadyExistsException): + with self.assertRaises(TempTableAlreadyExistsException): self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1") def test_create_session_local_temp_view(self): @@ -1256,7 +1256,7 @@ def test_create_session_local_temp_view(self): self.assertEqual(self.connect.sql("SELECT * FROM view_local_temp").count(), 0) # Test when creating a view which is already exists but - with self.assertRaises(SparkConnectTempTableAlreadyExistsException): + with self.assertRaises(TempTableAlreadyExistsException): self.connect.sql("SELECT 1 AS X LIMIT 0").createTempView("view_local_temp") def test_to_pandas(self): @@ -1499,7 +1499,7 @@ def test_replace(self): self.connect.sql(query).replace({None: 1}, subset="a").toPandas() self.assertTrue("Mixed type replacements are not supported" in str(context.exception)) - with self.assertRaises(SparkConnectAnalysisException) as context: + with self.assertRaises(AnalysisException) as context: self.connect.sql(query).replace({1: 2, 3: -1}, subset=("a", "x")).toPandas() self.assertIn( """Cannot resolve column name "x" among (a, b, c)""", str(context.exception) @@ -1607,7 +1607,7 @@ def test_hint(self): ) # Hint with unsupported parameter values - with self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): self.connect.read.table(self.tbl_name).hint("REPARTITION", "id+1").toPandas() # Hint with unsupported parameter types @@ -1621,7 +1621,7 @@ def test_hint(self): ).toPandas() # Hint with wrong combination - with self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): self.connect.read.table(self.tbl_name).hint("REPARTITION", "id", 3).toPandas() def test_join_hint(self): @@ -1950,7 +1950,7 @@ def test_repartition_by_expression(self) -> None: ) # repartition with unsupported parameter values - with self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): 
self.connect.read.table(self.tbl_name).repartition("id+1").toPandas() def test_repartition_by_range(self) -> None: @@ -1974,7 +1974,7 @@ def test_repartition_by_range(self) -> None: ) # repartitionByRange with unsupported parameter values - with self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): self.connect.read.table(self.tbl_name).repartitionByRange("id+1").toPandas() def test_agg_with_two_agg_exprs(self) -> None: diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py index 624bdf4f539d7..a2c786db180d4 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column.py +++ b/python/pyspark/sql/tests/connect/test_connect_column.py @@ -41,7 +41,7 @@ DecimalType, BooleanType, ) -from pyspark.errors import SparkConnectException +from pyspark.errors.exceptions.connect import SparkConnectException from pyspark.testing.connectutils import should_test_connect from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py index e3e668eb83590..243153be0832e 100644 --- a/python/pyspark/sql/tests/connect/test_connect_function.py +++ b/python/pyspark/sql/tests/connect/test_connect_function.py @@ -23,7 +23,7 @@ from pyspark.testing.pandasutils import PandasOnSparkTestUtils from pyspark.testing.connectutils import ReusedConnectTestCase from pyspark.testing.sqlutils import SQLTestUtils -from pyspark.errors import SparkConnectAnalysisException, SparkConnectException +from pyspark.errors.exceptions.connect import AnalysisException, SparkConnectException class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, SQLTestUtils): @@ -899,7 +899,7 @@ def test_window_functions(self): cdf.select(CF.rank().over(cdf.a)) # invalid window function - with self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): cdf.select(cdf.b.over(CW.orderBy("b"))).show() # invalid window frame @@ -913,34 +913,34 @@ def test_window_functions(self): CF.lead("c", 1), CF.ntile(1), ]: - with self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): cdf.select( ccol.over(CW.orderBy("b").rowsBetween(CW.currentRow, CW.currentRow + 123)) ).show() - with self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): cdf.select( ccol.over(CW.orderBy("b").rangeBetween(CW.currentRow, CW.currentRow + 123)) ).show() - with self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): cdf.select( ccol.over(CW.orderBy("b").rangeBetween(CW.unboundedPreceding, CW.currentRow)) ).show() # Function 'cume_dist' requires Windowframe(RangeFrame, UnboundedPreceding, CurrentRow) ccol = CF.cume_dist() - with self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): cdf.select( ccol.over(CW.orderBy("b").rangeBetween(CW.currentRow, CW.currentRow + 123)) ).show() - with self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): cdf.select( ccol.over(CW.orderBy("b").rowsBetween(CW.currentRow, CW.currentRow + 123)) ).show() - with self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): cdf.select( ccol.over(CW.orderBy("b").rowsBetween(CW.unboundedPreceding, CW.currentRow)) ).show() diff --git 
a/python/pyspark/sql/tests/connect/test_parity_column.py b/python/pyspark/sql/tests/connect/test_parity_column.py index e157638935164..5cce063871ab8 100644 --- a/python/pyspark/sql/tests/connect/test_parity_column.py +++ b/python/pyspark/sql/tests/connect/test_parity_column.py @@ -32,7 +32,7 @@ class ColumnParityTests(ColumnTestsMixin, ReusedConnectTestCase): - # TODO(SPARK-42017): Different error type AnalysisException vs SparkConnectAnalysisException + # TODO(SPARK-42017): df["bad_key"] does not raise AnalysisException @unittest.skip("Fails in Spark Connect, should enable.") def test_access_column(self): super().test_access_column() diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py index d3807285f3ebb..7e6735cb7cdb9 100644 --- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py +++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py @@ -85,6 +85,11 @@ def test_require_cross(self): def test_same_semantics_error(self): super().test_same_semantics_error() + # TODO(SPARK-42338): Different exception in DataFrame.sample + @unittest.skip("Fails in Spark Connect, should enable.") + def test_sample(self): + super().test_sample() + @unittest.skip("Spark Connect does not support RDD but the tests depend on them.") def test_toDF_with_schema_string(self): super().test_toDF_with_schema_string() diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py b/python/pyspark/sql/tests/connect/test_parity_functions.py index b151986cb24ba..3d390c13913fa 100644 --- a/python/pyspark/sql/tests/connect/test_parity_functions.py +++ b/python/pyspark/sql/tests/connect/test_parity_functions.py @@ -17,11 +17,15 @@ import unittest +from pyspark.errors.exceptions.connect import SparkConnectException from pyspark.sql.tests.test_functions import FunctionsTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase class FunctionsParityTests(FunctionsTestsMixin, ReusedConnectTestCase): + def test_assert_true(self): + self.check_assert_true(SparkConnectException) + @unittest.skip("Spark Connect does not support Spark Context but the test depends on that.") def test_basic_functions(self): super().test_basic_functions() @@ -49,6 +53,9 @@ def test_lit_list(self): def test_lit_np_scalar(self): super().test_lit_np_scalar() + def test_raise_error(self): + self.check_assert_true(SparkConnectException) + # Comparing column type of connect and pyspark @unittest.skip("Fails in Spark Connect, should enable.") def test_sorting_functions_with_column(self): diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py b/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py index b4d1a9dd31a8c..4b1ce0a958788 100644 --- a/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py +++ b/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py @@ -19,9 +19,6 @@ from pyspark.sql.tests.pandas.test_pandas_udf import PandasUDFTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.errors.exceptions import SparkConnectGrpcException -from pyspark.sql.connect.functions import udf -from pyspark.sql.functions import pandas_udf, PandasUDFType class PandasUDFParityTests(PandasUDFTestsMixin, ReusedConnectTestCase): @@ -53,24 +50,10 @@ def test_pandas_udf_decorator(self): def test_pandas_udf_basic(self): super().test_pandas_udf_basic() - def test_stopiteration_in_udf(self): - # The vanilla PySpark throws PythonException instead. 
- def foo(x): - raise StopIteration() - - exc_message = "Caught StopIteration thrown from user's code; failing the task" - df = self.spark.range(0, 100) - - self.assertRaisesRegex( - SparkConnectGrpcException, exc_message, df.withColumn("v", udf(foo)("id")).collect - ) - - # pandas scalar udf - self.assertRaisesRegex( - SparkConnectGrpcException, - exc_message, - df.withColumn("v", pandas_udf(foo, "double", PandasUDFType.SCALAR)("id")).collect, - ) + # TODO(SPARK-42340): implement GroupedData.applyInPandas + @unittest.skip("Fails in Spark Connect, should enable.") + def test_stopiteration_in_grouped_map(self): + super().test_stopiteration_in_grouped_map() if __name__ == "__main__": diff --git a/python/pyspark/sql/tests/connect/test_parity_udf.py b/python/pyspark/sql/tests/connect/test_parity_udf.py index f74d21f0d8621..8d4bb69bf1633 100644 --- a/python/pyspark/sql/tests/connect/test_parity_udf.py +++ b/python/pyspark/sql/tests/connect/test_parity_udf.py @@ -27,9 +27,6 @@ from pyspark.sql.tests.test_udf import BaseUDFTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.errors.exceptions import SparkConnectAnalysisException -from pyspark.sql.connect.functions import udf -from pyspark.sql.types import BooleanType class UDFParityTests(BaseUDFTestsMixin, ReusedConnectTestCase): @@ -182,26 +179,6 @@ def test_udf_with_string_return_type(self): def test_udf_in_subquery(self): super().test_udf_in_subquery() - def test_udf_not_supported_in_join_condition(self): - # The vanilla PySpark throws AnalysisException instead. - # test python udf is not supported in join type except inner join. - left = self.spark.createDataFrame([(1, 1, 1), (2, 2, 2)], ["a", "a1", "a2"]) - right = self.spark.createDataFrame([(1, 1, 1), (1, 3, 1)], ["b", "b1", "b2"]) - f = udf(lambda a, b: a == b, BooleanType()) - - def runWithJoinType(join_type, type_string): - with self.assertRaisesRegex( - SparkConnectAnalysisException, - """Python UDF in the ON clause of a %s JOIN.""" % type_string, - ): - left.join(right, [f("a", "b"), left.a1 == right.b1], join_type).collect() - - runWithJoinType("full", "FULL OUTER") - runWithJoinType("left", "LEFT OUTER") - runWithJoinType("right", "RIGHT OUTER") - runWithJoinType("leftanti", "LEFT ANTI") - runWithJoinType("leftsemi", "LEFT SEMI") - if __name__ == "__main__": import unittest diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py b/python/pyspark/sql/tests/pandas/test_pandas_udf.py index 768317ab60d74..1b3b4555d7ffd 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py @@ -171,9 +171,6 @@ def test_stopiteration_in_udf(self): def foo(x): raise StopIteration() - def foofoo(x, y): - raise StopIteration() - exc_message = "Caught StopIteration thrown from user's code; failing the task" df = self.spark.range(0, 100) @@ -189,6 +186,16 @@ def foofoo(x, y): df.withColumn("v", pandas_udf(foo, "double", PandasUDFType.SCALAR)("id")).collect, ) + def test_stopiteration_in_grouped_map(self): + def foo(x): + raise StopIteration() + + def foofoo(x, y): + raise StopIteration() + + exc_message = "Caught StopIteration thrown from user's code; failing the task" + df = self.spark.range(0, 100) + # pandas grouped map self.assertRaisesRegex( PythonException, @@ -204,6 +211,13 @@ def foofoo(x, y): .collect, ) + def test_stopiteration_in_grouped_agg(self): + def foo(x): + raise StopIteration() + + exc_message = "Caught StopIteration thrown from user's code; failing the task" + df = 
self.spark.range(0, 100) + # pandas grouped agg self.assertRaisesRegex( PythonException, diff --git a/python/pyspark/sql/tests/streaming/test_streaming.py b/python/pyspark/sql/tests/streaming/test_streaming.py index 5470f79ff2275..9f02ae848bf67 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming.py +++ b/python/pyspark/sql/tests/streaming/test_streaming.py @@ -254,7 +254,7 @@ def test_stream_exception(self): self._assert_exception_tree_contains_msg(e, "ZeroDivisionError") finally: sq.stop() - self.assertTrue(type(sq.exception()) is StreamingQueryException) + self.assertIsInstance(sq.exception(), StreamingQueryException) self._assert_exception_tree_contains_msg(sq.exception(), "ZeroDivisionError") def _assert_exception_tree_contains_msg(self, exception, msg): diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 033878470e193..1d52602a96f15 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -46,8 +46,6 @@ from pyspark.errors import ( AnalysisException, IllegalArgumentException, - SparkConnectException, - SparkConnectAnalysisException, PySparkTypeError, ) from pyspark.testing.sqlutils import ( @@ -948,8 +946,7 @@ def test_sample(self): self.assertRaises(TypeError, lambda: self.spark.range(1).sample(seed="abc")) self.assertRaises( - (IllegalArgumentException, SparkConnectException), - lambda: self.spark.range(1).sample(-1.0).count(), + IllegalArgumentException, lambda: self.spark.range(1).sample(-1.0).count() ) def test_toDF_with_schema_string(self): @@ -1041,17 +1038,17 @@ def test_cache(self): self.assertFalse(spark.catalog.isCached("tab1")) self.assertFalse(spark.catalog.isCached("tab2")) self.assertRaisesRegex( - Exception, + AnalysisException, "does_not_exist", lambda: spark.catalog.isCached("does_not_exist"), ) self.assertRaisesRegex( - Exception, + AnalysisException, "does_not_exist", lambda: spark.catalog.cacheTable("does_not_exist"), ) self.assertRaisesRegex( - Exception, + AnalysisException, "does_not_exist", lambda: spark.catalog.uncacheTable("does_not_exist"), ) @@ -1595,17 +1592,13 @@ def test_to(self): # incompatible field nullability schema4 = StructType([StructField("j", LongType(), False)]) self.assertRaisesRegex( - (AnalysisException, SparkConnectAnalysisException), - "NULLABLE_COLUMN_OR_FIELD", - lambda: df.to(schema4).count(), + AnalysisException, "NULLABLE_COLUMN_OR_FIELD", lambda: df.to(schema4).count() ) # field cannot upcast schema5 = StructType([StructField("i", LongType())]) self.assertRaisesRegex( - (AnalysisException, SparkConnectAnalysisException), - "INVALID_COLUMN_OR_FIELD_DATA_TYPE", - lambda: df.to(schema5).count(), + AnalysisException, "INVALID_COLUMN_OR_FIELD_DATA_TYPE", lambda: df.to(schema5).count() ) def test_repartition(self): diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 05492347755e1..d8343b4fb47ef 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -25,7 +25,7 @@ import unittest from py4j.protocol import Py4JJavaError -from pyspark.errors import PySparkTypeError, PySparkValueError, SparkConnectException +from pyspark.errors import PySparkTypeError, PySparkValueError from pyspark.sql import Row, Window, types from pyspark.sql.functions import ( udf, @@ -1056,6 +1056,9 @@ def test_datetime_functions(self): self.assertEqual(date(2017, 1, 22), parse_result["to_date(dateCol)"]) def test_assert_true(self): + 
self.check_assert_true(Py4JJavaError) + + def check_assert_true(self, tpe): from pyspark.sql.functions import assert_true df = self.spark.range(3) @@ -1065,10 +1068,10 @@ def test_assert_true(self): [Row(val=None), Row(val=None), Row(val=None)], ) - with self.assertRaisesRegex((Py4JJavaError, SparkConnectException), "too big"): + with self.assertRaisesRegex(tpe, "too big"): df.select(assert_true(df.id < 2, "too big")).toDF("val").collect() - with self.assertRaisesRegex((Py4JJavaError, SparkConnectException), "2000000"): + with self.assertRaisesRegex(tpe, "2000000"): df.select(assert_true(df.id < 2, df.id * 1e6)).toDF("val").collect() with self.assertRaises(PySparkTypeError) as pe: @@ -1081,14 +1084,17 @@ def test_assert_true(self): ) def test_raise_error(self): + self.check_raise_error(Py4JJavaError) + + def check_raise_error(self, tpe): from pyspark.sql.functions import raise_error df = self.spark.createDataFrame([Row(id="foobar")]) - with self.assertRaisesRegex((Py4JJavaError, SparkConnectException), "foobar"): + with self.assertRaisesRegex(tpe, "foobar"): df.select(raise_error(df.id)).collect() - with self.assertRaisesRegex((Py4JJavaError, SparkConnectException), "barfoo"): + with self.assertRaisesRegex(tpe, "barfoo"): df.select(raise_error("barfoo")).collect() with self.assertRaises(PySparkTypeError) as pe: diff --git a/python/setup.py b/python/setup.py index af5c5f9384c39..ead1139f8f873 100755 --- a/python/setup.py +++ b/python/setup.py @@ -266,6 +266,7 @@ def run(self): "pyspark.licenses", "pyspark.resource", "pyspark.errors", + "pyspark.errors.exceptions", "pyspark.examples.src.main.python", ], include_package_data=True,
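
With this restructuring, pyspark.errors.exceptions becomes a package: the plain Python exception classes live in exceptions/base.py, the Py4J-captured variants in exceptions/captured.py, and the Spark Connect variants in exceptions/connect.py. Since the captured and Connect classes both inherit from the shared base classes, user code can catch errors through the public pyspark.errors names whichever client raised them, rather than handling the removed SparkConnect* types separately. A minimal sketch of that effect (not part of the patch), assuming an already configured session (classic or Connect) and a table name that is absent from the catalog:

    from pyspark.errors import AnalysisException, PySparkException
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    try:
        # "no_such_table" stands in for any table missing from the catalog.
        spark.sql("SELECT * FROM no_such_table").collect()
    except AnalysisException as e:
        # Both the captured (Py4J) and the Spark Connect subclasses extend
        # pyspark.errors.AnalysisException, which itself extends PySparkException,
        # so the error-class accessors are available on either code path.
        assert isinstance(e, PySparkException)
        print(e.getErrorClass(), e.getMessageParameters())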