From bca1ee54e38258c4ea0def067883fa67071b7d3d Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 8 Feb 2023 20:46:13 +0900 Subject: [PATCH] [SPARK-42342][PYTHON][CONNECT] Introduce base hierarchy to exceptions ### What changes were proposed in this pull request? Introduces a base hierarchy for exceptions. As a common hierarchy for users, the base exception classes are subclasses of `PySparkException`. The concrete classes for both PySpark and Spark Connect inherit these base classes and should not be exposed to users. ### Why are the changes needed? Currently, the exception class hierarchy is split between PySpark and Spark Connect. If users want to check the exception type, they have to check against different error classes depending on whether they are running on PySpark or Spark Connect, which is not ideal. ### Does this PR introduce _any_ user-facing change? No. Users can still use the existing exception classes to check the exception type. ### How was this patch tested? Updated tests. Closes #39882 from ueshin/issues/SPARK-42342/exceptions. Authored-by: Takuya UESHIN Signed-off-by: Hyukjin Kwon (cherry picked from commit bd34b162d4774bcc19371096a04972b03f423bca) Signed-off-by: Hyukjin Kwon --- .../docs/source/reference/pyspark.errors.rst | 7 +- python/pyspark/errors/__init__.py | 16 +- python/pyspark/errors/exceptions/__init__.py | 16 ++ python/pyspark/errors/exceptions/base.py | 162 +++++++++++++++ .../{exceptions.py => exceptions/captured.py} | 188 ++---------------- python/pyspark/errors/exceptions/connect.py | 105 ++++++++++ python/pyspark/sql/catalog.py | 6 +- python/pyspark/sql/connect/client.py | 23 ++- python/pyspark/sql/connect/functions.py | 5 +- python/pyspark/sql/context.py | 2 +- python/pyspark/sql/functions.py | 5 +- python/pyspark/sql/session.py | 2 +- python/pyspark/sql/streaming/query.py | 5 +- .../sql/tests/connect/test_connect_basic.py | 34 ++-- .../sql/tests/connect/test_connect_column.py | 2 +- .../tests/connect/test_connect_function.py | 16 +- .../sql/tests/connect/test_parity_column.py | 2 +- .../tests/connect/test_parity_dataframe.py | 5 + .../tests/connect/test_parity_functions.py | 7 + .../tests/connect/test_parity_pandas_udf.py | 25 +-- .../sql/tests/connect/test_parity_udf.py | 23 --- .../sql/tests/pandas/test_pandas_udf.py | 20 +- .../sql/tests/streaming/test_streaming.py | 2 +- python/pyspark/sql/tests/test_dataframe.py | 19 +- python/pyspark/sql/tests/test_functions.py | 16 +- python/setup.py | 1 + 26 files changed, 410 insertions(+), 304 deletions(-) create mode 100644 python/pyspark/errors/exceptions/__init__.py create mode 100644 python/pyspark/errors/exceptions/base.py rename python/pyspark/errors/{exceptions.py => exceptions/captured.py} (59%) create mode 100644 python/pyspark/errors/exceptions/connect.py diff --git a/python/docs/source/reference/pyspark.errors.rst b/python/docs/source/reference/pyspark.errors.rst index 1d54c6babe0bb..13db9bd01fa73 100644 --- a/python/docs/source/reference/pyspark.errors.rst +++ b/python/docs/source/reference/pyspark.errors.rst @@ -30,6 +30,7 @@ Classes PySparkException AnalysisException + TempTableAlreadyExistsException ParseException IllegalArgumentException StreamingQueryException @@ -37,12 +38,6 @@ Classes PythonException UnknownException SparkUpgradeException - SparkConnectAnalysisException - SparkConnectException - SparkConnectGrpcException - SparkConnectParseException - SparkConnectTempTableAlreadyExistsException - SparkConnectIllegalArgumentException Methods diff --git a/python/pyspark/errors/__init__.py
b/python/pyspark/errors/__init__.py index 7faa0768a24cd..95da7ca2aa89f 100644 --- a/python/pyspark/errors/__init__.py +++ b/python/pyspark/errors/__init__.py @@ -18,9 +18,10 @@ """ PySpark exceptions. """ -from pyspark.errors.exceptions import ( # noqa: F401 +from pyspark.errors.exceptions.base import ( # noqa: F401 PySparkException, AnalysisException, + TempTableAlreadyExistsException, ParseException, IllegalArgumentException, StreamingQueryException, @@ -30,18 +31,13 @@ SparkUpgradeException, PySparkTypeError, PySparkValueError, - SparkConnectException, - SparkConnectGrpcException, - SparkConnectAnalysisException, - SparkConnectParseException, - SparkConnectTempTableAlreadyExistsException, - SparkConnectIllegalArgumentException, ) __all__ = [ "PySparkException", "AnalysisException", + "TempTableAlreadyExistsException", "ParseException", "IllegalArgumentException", "StreamingQueryException", @@ -51,10 +47,4 @@ "SparkUpgradeException", "PySparkTypeError", "PySparkValueError", - "SparkConnectException", - "SparkConnectGrpcException", - "SparkConnectAnalysisException", - "SparkConnectParseException", - "SparkConnectTempTableAlreadyExistsException", - "SparkConnectIllegalArgumentException", ] diff --git a/python/pyspark/errors/exceptions/__init__.py b/python/pyspark/errors/exceptions/__init__.py new file mode 100644 index 0000000000000..cce3acad34a49 --- /dev/null +++ b/python/pyspark/errors/exceptions/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/errors/exceptions/base.py b/python/pyspark/errors/exceptions/base.py new file mode 100644 index 0000000000000..6e67039374d90 --- /dev/null +++ b/python/pyspark/errors/exceptions/base.py @@ -0,0 +1,162 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Dict, Optional, cast + +from pyspark.errors.utils import ErrorClassesReader + + +class PySparkException(Exception): + """ + Base Exception for handling errors generated from PySpark. 
+ """ + + def __init__( + self, + message: Optional[str] = None, + error_class: Optional[str] = None, + message_parameters: Optional[Dict[str, str]] = None, + ): + # `message` vs `error_class` & `message_parameters` are mutually exclusive. + assert (message is not None and (error_class is None and message_parameters is None)) or ( + message is None and (error_class is not None and message_parameters is not None) + ) + + self.error_reader = ErrorClassesReader() + + if message is None: + self.message = self.error_reader.get_error_message( + cast(str, error_class), cast(Dict[str, str], message_parameters) + ) + else: + self.message = message + + self.error_class = error_class + self.message_parameters = message_parameters + + def getErrorClass(self) -> Optional[str]: + """ + Returns an error class as a string. + + .. versionadded:: 3.4.0 + + See Also + -------- + :meth:`PySparkException.getMessageParameters` + :meth:`PySparkException.getSqlState` + """ + return self.error_class + + def getMessageParameters(self) -> Optional[Dict[str, str]]: + """ + Returns a message parameters as a dictionary. + + .. versionadded:: 3.4.0 + + See Also + -------- + :meth:`PySparkException.getErrorClass` + :meth:`PySparkException.getSqlState` + """ + return self.message_parameters + + def getSqlState(self) -> None: + """ + Returns an SQLSTATE as a string. + + Errors generated in Python have no SQLSTATE, so it always returns None. + + .. versionadded:: 3.4.0 + + See Also + -------- + :meth:`PySparkException.getErrorClass` + :meth:`PySparkException.getMessageParameters` + """ + return None + + def __str__(self) -> str: + if self.getErrorClass() is not None: + return f"[{self.getErrorClass()}] {self.message}" + else: + return self.message + + +class AnalysisException(PySparkException): + """ + Failed to analyze a SQL query plan. + """ + + +class TempTableAlreadyExistsException(AnalysisException): + """ + Failed to create temp view since it is already exists. + """ + + +class ParseException(PySparkException): + """ + Failed to parse a SQL command. + """ + + +class IllegalArgumentException(PySparkException): + """ + Passed an illegal or inappropriate argument. + """ + + +class StreamingQueryException(PySparkException): + """ + Exception that stopped a :class:`StreamingQuery`. + """ + + +class QueryExecutionException(PySparkException): + """ + Failed to execute a query. + """ + + +class PythonException(PySparkException): + """ + Exceptions thrown from Python workers. + """ + + +class UnknownException(PySparkException): + """ + None of the above exceptions. + """ + + +class SparkUpgradeException(PySparkException): + """ + Exception thrown because of Spark upgrade. + """ + + +class PySparkValueError(PySparkException, ValueError): + """ + Wrapper class for ValueError to support error classes. + """ + + +class PySparkTypeError(PySparkException, TypeError): + """ + Wrapper class for TypeError to support error classes. 
+ """ diff --git a/python/pyspark/errors/exceptions.py b/python/pyspark/errors/exceptions/captured.py similarity index 59% rename from python/pyspark/errors/exceptions.py rename to python/pyspark/errors/exceptions/captured.py index a799f4522debb..1764ed7d02c28 100644 --- a/python/pyspark/errors/exceptions.py +++ b/python/pyspark/errors/exceptions/captured.py @@ -22,83 +22,17 @@ from py4j.java_gateway import is_instance_of from pyspark import SparkContext -from pyspark.errors.utils import ErrorClassesReader - - -class PySparkException(Exception): - """ - Base Exception for handling errors generated from PySpark. - """ - - def __init__( - self, - message: Optional[str] = None, - error_class: Optional[str] = None, - message_parameters: Optional[Dict[str, str]] = None, - ): - # `message` vs `error_class` & `message_parameters` are mutually exclusive. - assert (message is not None and (error_class is None and message_parameters is None)) or ( - message is None and (error_class is not None and message_parameters is not None) - ) - - self.error_reader = ErrorClassesReader() - - if message is None: - self.message = self.error_reader.get_error_message( - cast(str, error_class), cast(Dict[str, str], message_parameters) - ) - else: - self.message = message - - self.error_class = error_class - self.message_parameters = message_parameters - - def getErrorClass(self) -> Optional[str]: - """ - Returns an error class as a string. - - .. versionadded:: 3.4.0 - - See Also - -------- - :meth:`PySparkException.getMessageParameters` - :meth:`PySparkException.getSqlState` - """ - return self.error_class - - def getMessageParameters(self) -> Optional[Dict[str, str]]: - """ - Returns a message parameters as a dictionary. - - .. versionadded:: 3.4.0 - - See Also - -------- - :meth:`PySparkException.getErrorClass` - :meth:`PySparkException.getSqlState` - """ - return self.message_parameters - - def getSqlState(self) -> None: - """ - Returns an SQLSTATE as a string. - - Errors generated in Python have no SQLSTATE, so it always returns None. - - .. versionadded:: 3.4.0 - - See Also - -------- - :meth:`PySparkException.getErrorClass` - :meth:`PySparkException.getMessageParameters` - """ - return None - - def __str__(self) -> str: - if self.getErrorClass() is not None: - return f"[{self.getErrorClass()}] {self.message}" - else: - return self.message +from pyspark.errors.exceptions.base import ( + AnalysisException as BaseAnalysisException, + IllegalArgumentException as BaseIllegalArgumentException, + ParseException as BaseParseException, + PySparkException, + PythonException as BasePythonException, + QueryExecutionException as BaseQueryExecutionException, + SparkUpgradeException as BaseSparkUpgradeException, + StreamingQueryException as BaseStreamingQueryException, + UnknownException as BaseUnknownException, +) class CapturedException(PySparkException): @@ -247,133 +181,49 @@ def install_exception_handler() -> None: py4j.java_gateway.get_return_value = patched -class AnalysisException(CapturedException): +class AnalysisException(CapturedException, BaseAnalysisException): """ Failed to analyze a SQL query plan. """ -class ParseException(CapturedException): +class ParseException(CapturedException, BaseParseException): """ Failed to parse a SQL command. """ -class IllegalArgumentException(CapturedException): +class IllegalArgumentException(CapturedException, BaseIllegalArgumentException): """ Passed an illegal or inappropriate argument. 
""" -class StreamingQueryException(CapturedException): +class StreamingQueryException(CapturedException, BaseStreamingQueryException): """ Exception that stopped a :class:`StreamingQuery`. """ -class QueryExecutionException(CapturedException): +class QueryExecutionException(CapturedException, BaseQueryExecutionException): """ Failed to execute a query. """ -class PythonException(CapturedException): +class PythonException(CapturedException, BasePythonException): """ Exceptions thrown from Python workers. """ -class UnknownException(CapturedException): +class UnknownException(CapturedException, BaseUnknownException): """ None of the above exceptions. """ -class SparkUpgradeException(CapturedException): +class SparkUpgradeException(CapturedException, BaseSparkUpgradeException): """ Exception thrown because of Spark upgrade. """ - - -class SparkConnectException(PySparkException): - """ - Exception thrown from Spark Connect. - """ - - -class SparkConnectGrpcException(SparkConnectException): - """ - Base class to handle the errors from GRPC. - """ - - def __init__( - self, - message: Optional[str] = None, - error_class: Optional[str] = None, - message_parameters: Optional[Dict[str, str]] = None, - reason: Optional[str] = None, - ) -> None: - self.message = message # type: ignore[assignment] - if reason is not None: - self.message = f"({reason}) {self.message}" - - super().__init__( - message=self.message, - error_class=error_class, - message_parameters=message_parameters, - ) - - -class SparkConnectAnalysisException(SparkConnectGrpcException): - """ - Failed to analyze a SQL query plan from Spark Connect server. - """ - - def __init__( - self, - message: Optional[str] = None, - error_class: Optional[str] = None, - message_parameters: Optional[Dict[str, str]] = None, - plan: Optional[str] = None, - reason: Optional[str] = None, - ) -> None: - self.message = message # type: ignore[assignment] - if plan is not None: - self.message = f"{self.message}\nPlan: {plan}" - - super().__init__( - message=self.message, - error_class=error_class, - message_parameters=message_parameters, - reason=reason, - ) - - -class SparkConnectParseException(SparkConnectGrpcException): - """ - Failed to parse a SQL command from Spark Connect server. - """ - - -class SparkConnectTempTableAlreadyExistsException(SparkConnectAnalysisException): - """ - Failed to create temp view since it is already exists. - """ - - -class PySparkValueError(PySparkException, ValueError): - """ - Wrapper class for ValueError to support error classes. - """ - - -class PySparkTypeError(PySparkException, TypeError): - """ - Wrapper class for TypeError to support error classes. - """ - - -class SparkConnectIllegalArgumentException(SparkConnectGrpcException): - """ - Passed an illegal or inappropriate argument from Spark Connect server. - """ diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py new file mode 100644 index 0000000000000..ba3bc9f7576b7 --- /dev/null +++ b/python/pyspark/errors/exceptions/connect.py @@ -0,0 +1,105 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Dict, Optional + +from pyspark.errors.exceptions.base import ( + AnalysisException as BaseAnalysisException, + IllegalArgumentException as BaseIllegalArgumentException, + ParseException as BaseParseException, + PySparkException, + PythonException as BasePythonException, + TempTableAlreadyExistsException as BaseTempTableAlreadyExistsException, +) + + +class SparkConnectException(PySparkException): + """ + Exception thrown from Spark Connect. + """ + + +class SparkConnectGrpcException(SparkConnectException): + """ + Base class to handle the errors from GRPC. + """ + + def __init__( + self, + message: Optional[str] = None, + error_class: Optional[str] = None, + message_parameters: Optional[Dict[str, str]] = None, + reason: Optional[str] = None, + ) -> None: + self.message = message # type: ignore[assignment] + if reason is not None: + self.message = f"({reason}) {self.message}" + + super().__init__( + message=self.message, + error_class=error_class, + message_parameters=message_parameters, + ) + + +class AnalysisException(SparkConnectGrpcException, BaseAnalysisException): + """ + Failed to analyze a SQL query plan from Spark Connect server. + """ + + def __init__( + self, + message: Optional[str] = None, + error_class: Optional[str] = None, + message_parameters: Optional[Dict[str, str]] = None, + plan: Optional[str] = None, + reason: Optional[str] = None, + ) -> None: + self.message = message # type: ignore[assignment] + if plan is not None: + self.message = f"{self.message}\nPlan: {plan}" + + super().__init__( + message=self.message, + error_class=error_class, + message_parameters=message_parameters, + reason=reason, + ) + + +class TempTableAlreadyExistsException(AnalysisException, BaseTempTableAlreadyExistsException): + """ + Failed to create temp view from Spark Connect server since it is already exists. + """ + + +class ParseException(SparkConnectGrpcException, BaseParseException): + """ + Failed to parse a SQL command from Spark Connect server. + """ + + +class IllegalArgumentException(SparkConnectGrpcException, BaseIllegalArgumentException): + """ + Passed an illegal or inappropriate argument from Spark Connect server. + """ + + +class PythonException(SparkConnectGrpcException, BasePythonException): + """ + Exceptions thrown from Spark Connect server. + """ diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index a417f754a3625..6deee786164de 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -959,7 +959,7 @@ def isCached(self, tableName: str) -> bool: Throw an analysis exception when the table does not exist. - >>> spark.catalog.isCached("not_existing_table") # doctest: +SKIP + >>> spark.catalog.isCached("not_existing_table") Traceback (most recent call last): ... AnalysisException: ... @@ -997,7 +997,7 @@ def cacheTable(self, tableName: str) -> None: Throw an analysis exception when the table does not exist. - >>> spark.catalog.cacheTable("not_existing_table") # doctest: +SKIP + >>> spark.catalog.cacheTable("not_existing_table") Traceback (most recent call last): ... AnalysisException: ... 
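The user-facing payoff of the new hierarchy, and the reason the `# doctest: +SKIP` markers above can be dropped, is that a single `except` clause now works on both classic PySpark and Spark Connect: the Connect-specific classes defined in `connect.py` subclass the shared bases exported from `pyspark.errors`. A minimal sketch of what this looks like from application code (the table name and the session setup are illustrative, not taken from this patch):

```python
from pyspark.sql import SparkSession
from pyspark.errors import AnalysisException, PySparkException

# Works against a classic JVM-backed session; a Spark Connect session
# (e.g. SparkSession.builder.remote("sc://localhost").getOrCreate())
# behaves the same way for this purpose.
spark = SparkSession.builder.getOrCreate()

try:
    # Querying a table that does not exist fails analysis. On classic PySpark
    # the captured JVM error is raised as the AnalysisException from
    # pyspark.errors; on Spark Connect the client raises the Connect-side
    # subclass, which inherits the same base, so one except clause covers both.
    spark.sql("SELECT * FROM not_existing_table").show()
except AnalysisException as e:
    print("analysis failed:", e)
    assert isinstance(e, PySparkException)  # the whole hierarchy roots here
```

Code that needs Connect-specific handling can still check for `SparkConnectGrpcException` explicitly, since the concrete Connect class extends both bases.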
@@ -1037,7 +1037,7 @@ def uncacheTable(self, tableName: str) -> None: Throw an analysis exception when the table does not exist. - >>> spark.catalog.uncacheTable("not_existing_table") # doctest: +SKIP + >>> spark.catalog.uncacheTable("not_existing_table") Traceback (most recent call last): ... AnalysisException: ... diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index 36f1328b1bd08..8cf5fa5069357 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -58,13 +58,14 @@ import pyspark.sql.connect.proto as pb2 import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib import pyspark.sql.connect.types as types -from pyspark.errors import ( +from pyspark.errors.exceptions.connect import ( + AnalysisException, + ParseException, + PythonException, SparkConnectException, SparkConnectGrpcException, - SparkConnectAnalysisException, - SparkConnectParseException, - SparkConnectTempTableAlreadyExistsException, - SparkConnectIllegalArgumentException, + TempTableAlreadyExistsException, + IllegalArgumentException, ) from pyspark.sql.types import ( DataType, @@ -672,22 +673,26 @@ def _handle_error(self, rpc_error: grpc.RpcError) -> NoReturn: d.Unpack(info) reason = info.reason if reason == "org.apache.spark.sql.AnalysisException": - raise SparkConnectAnalysisException( + raise AnalysisException( info.metadata["message"], plan=info.metadata["plan"] ) from None elif reason == "org.apache.spark.sql.catalyst.parser.ParseException": - raise SparkConnectParseException(info.metadata["message"]) from None + raise ParseException(info.metadata["message"]) from None elif ( reason == "org.apache.spark.sql.catalyst.analysis.TempTableAlreadyExistsException" ): - raise SparkConnectTempTableAlreadyExistsException( + raise TempTableAlreadyExistsException( info.metadata["message"], plan=info.metadata["plan"] ) from None elif reason == "java.lang.IllegalArgumentException": message = info.metadata["message"] message = message if message != "" else status.message - raise SparkConnectIllegalArgumentException(message) from None + raise IllegalArgumentException(message) from None + elif reason == "org.apache.spark.api.python.PythonException": + message = info.metadata["message"] + message = message if message != "" else status.message + raise PythonException(message) from None else: raise SparkConnectGrpcException( status.message, reason=info.reason diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py index 1c9e740474ba3..d4984b1ba6741 100644 --- a/python/pyspark/sql/connect/functions.py +++ b/python/pyspark/sql/connect/functions.py @@ -37,10 +37,7 @@ import numpy as np -from pyspark.errors.exceptions import ( - PySparkTypeError, - PySparkValueError, -) +from pyspark.errors import PySparkTypeError, PySparkValueError from pyspark.sql.connect.column import Column from pyspark.sql.connect.expressions import ( CaseWhen, diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index c3c83d48bd98d..3c47ebfb97365 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -41,7 +41,7 @@ from pyspark.sql.readwriter import DataFrameReader from pyspark.sql.streaming import DataStreamReader from pyspark.sql.udf import UDFRegistration # noqa: F401 -from pyspark.errors.exceptions import install_exception_handler +from pyspark.errors.exceptions.captured import install_exception_handler from pyspark.context import SparkContext from pyspark.rdd import RDD from pyspark.sql.types import 
AtomicType, DataType, StructType diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 8bee517de6af7..39a3f036cc38c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -38,10 +38,7 @@ ) from pyspark import SparkContext -from pyspark.errors.exceptions import ( - PySparkTypeError, - PySparkValueError, -) +from pyspark.errors import PySparkTypeError, PySparkValueError from pyspark.rdd import PythonEvalType from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal from pyspark.sql.dataframe import DataFrame diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 36ad15006870d..942e1da95c8ac 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -60,7 +60,7 @@ _parse_datatype_string, _from_numpy_type, ) -from pyspark.errors.exceptions import install_exception_handler +from pyspark.errors.exceptions.captured import install_exception_handler from pyspark.sql.utils import is_timestamp_ntz_preferred, to_str if TYPE_CHECKING: diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py index a577e99d0c26d..3c43628bf3780 100644 --- a/python/pyspark/sql/streaming/query.py +++ b/python/pyspark/sql/streaming/query.py @@ -22,6 +22,9 @@ from py4j.java_gateway import JavaObject, java_import from pyspark.errors import StreamingQueryException +from pyspark.errors.exceptions.captured import ( + StreamingQueryException as CapturedStreamingQueryException, +) from pyspark.sql.streaming.listener import StreamingQueryListener __all__ = ["StreamingQuery", "StreamingQueryManager"] @@ -387,7 +390,7 @@ def exception(self) -> Optional[StreamingQueryException]: je = self._jsq.exception().get() msg = je.toString().split(": ", 1)[1] # Drop the Java StreamingQueryException type info stackTrace = "\n\t at ".join(map(lambda x: x.toString(), je.getStackTrace())) - return StreamingQueryException(msg, stackTrace, je.getCause()) + return CapturedStreamingQueryException(msg, stackTrace, je.getCause()) else: return None diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index a9beb71545d04..eebfaaa39d841 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -49,11 +49,11 @@ connect_requirement_message, ) from pyspark.testing.pandasutils import PandasOnSparkTestUtils -from pyspark.errors import ( +from pyspark.errors.exceptions.connect import ( + AnalysisException, + ParseException, SparkConnectException, - SparkConnectAnalysisException, - SparkConnectParseException, - SparkConnectTempTableAlreadyExistsException, + TempTableAlreadyExistsException, ) if should_test_connect: @@ -223,7 +223,7 @@ def test_df_get_item(self): def test_error_handling(self): # SPARK-41533 Proper error handling for Spark Connect df = self.connect.range(10).select("id2") - with self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): df.collect() def test_simple_read(self): @@ -472,7 +472,7 @@ def test_with_local_ndarray(self): ): self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"]) - with self.assertRaises(SparkConnectParseException): + with self.assertRaises(ParseException): self.connect.createDataFrame( data, "col1 magic_type, col2 int, col3 int, col4 int" ).show() @@ -518,7 +518,7 @@ def test_with_local_list(self): ): self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"]) - 
with self.assertRaises(SparkConnectParseException): + with self.assertRaises(ParseException): self.connect.createDataFrame( data, "col1 magic_type, col2 int, col3 int, col4 int" ).show() @@ -1005,7 +1005,7 @@ def assert_eq_schema(cdf: CDataFrame, df: DataFrame, schema: StructType): # incompatible field nullability schema = StructType([StructField("id", LongType(), False)]) self.assertRaisesRegex( - SparkConnectAnalysisException, + AnalysisException, "NULLABLE_COLUMN_OR_FIELD", lambda: cdf.to(schema).toPandas(), ) @@ -1013,7 +1013,7 @@ def assert_eq_schema(cdf: CDataFrame, df: DataFrame, schema: StructType): # field cannot upcast schema = StructType([StructField("name", LongType())]) self.assertRaisesRegex( - SparkConnectAnalysisException, + AnalysisException, "INVALID_COLUMN_OR_FIELD_DATA_TYPE", lambda: cdf.to(schema).toPandas(), ) @@ -1025,7 +1025,7 @@ def assert_eq_schema(cdf: CDataFrame, df: DataFrame, schema: StructType): ] ) self.assertRaisesRegex( - SparkConnectAnalysisException, + AnalysisException, "INVALID_COLUMN_OR_FIELD_DATA_TYPE", lambda: cdf.to(schema).toPandas(), ) @@ -1244,7 +1244,7 @@ def test_create_global_temp_view(self): # Test when creating a view which is already exists but self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1")) - with self.assertRaises(SparkConnectTempTableAlreadyExistsException): + with self.assertRaises(TempTableAlreadyExistsException): self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1") def test_create_session_local_temp_view(self): @@ -1256,7 +1256,7 @@ def test_create_session_local_temp_view(self): self.assertEqual(self.connect.sql("SELECT * FROM view_local_temp").count(), 0) # Test when creating a view which is already exists but - with self.assertRaises(SparkConnectTempTableAlreadyExistsException): + with self.assertRaises(TempTableAlreadyExistsException): self.connect.sql("SELECT 1 AS X LIMIT 0").createTempView("view_local_temp") def test_to_pandas(self): @@ -1499,7 +1499,7 @@ def test_replace(self): self.connect.sql(query).replace({None: 1}, subset="a").toPandas() self.assertTrue("Mixed type replacements are not supported" in str(context.exception)) - with self.assertRaises(SparkConnectAnalysisException) as context: + with self.assertRaises(AnalysisException) as context: self.connect.sql(query).replace({1: 2, 3: -1}, subset=("a", "x")).toPandas() self.assertIn( """Cannot resolve column name "x" among (a, b, c)""", str(context.exception) @@ -1607,7 +1607,7 @@ def test_hint(self): ) # Hint with unsupported parameter values - with self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): self.connect.read.table(self.tbl_name).hint("REPARTITION", "id+1").toPandas() # Hint with unsupported parameter types @@ -1621,7 +1621,7 @@ def test_hint(self): ).toPandas() # Hint with wrong combination - with self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): self.connect.read.table(self.tbl_name).hint("REPARTITION", "id", 3).toPandas() def test_join_hint(self): @@ -1950,7 +1950,7 @@ def test_repartition_by_expression(self) -> None: ) # repartition with unsupported parameter values - with self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): self.connect.read.table(self.tbl_name).repartition("id+1").toPandas() def test_repartition_by_range(self) -> None: @@ -1974,7 +1974,7 @@ def test_repartition_by_range(self) -> None: ) # repartitionByRange with unsupported parameter values - with 
self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): self.connect.read.table(self.tbl_name).repartitionByRange("id+1").toPandas() def test_agg_with_two_agg_exprs(self) -> None: diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py index 624bdf4f539d7..a2c786db180d4 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column.py +++ b/python/pyspark/sql/tests/connect/test_connect_column.py @@ -41,7 +41,7 @@ DecimalType, BooleanType, ) -from pyspark.errors import SparkConnectException +from pyspark.errors.exceptions.connect import SparkConnectException from pyspark.testing.connectutils import should_test_connect from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py index e3e668eb83590..243153be0832e 100644 --- a/python/pyspark/sql/tests/connect/test_connect_function.py +++ b/python/pyspark/sql/tests/connect/test_connect_function.py @@ -23,7 +23,7 @@ from pyspark.testing.pandasutils import PandasOnSparkTestUtils from pyspark.testing.connectutils import ReusedConnectTestCase from pyspark.testing.sqlutils import SQLTestUtils -from pyspark.errors import SparkConnectAnalysisException, SparkConnectException +from pyspark.errors.exceptions.connect import AnalysisException, SparkConnectException class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, SQLTestUtils): @@ -899,7 +899,7 @@ def test_window_functions(self): cdf.select(CF.rank().over(cdf.a)) # invalid window function - with self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): cdf.select(cdf.b.over(CW.orderBy("b"))).show() # invalid window frame @@ -913,34 +913,34 @@ def test_window_functions(self): CF.lead("c", 1), CF.ntile(1), ]: - with self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): cdf.select( ccol.over(CW.orderBy("b").rowsBetween(CW.currentRow, CW.currentRow + 123)) ).show() - with self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): cdf.select( ccol.over(CW.orderBy("b").rangeBetween(CW.currentRow, CW.currentRow + 123)) ).show() - with self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): cdf.select( ccol.over(CW.orderBy("b").rangeBetween(CW.unboundedPreceding, CW.currentRow)) ).show() # Function 'cume_dist' requires Windowframe(RangeFrame, UnboundedPreceding, CurrentRow) ccol = CF.cume_dist() - with self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): cdf.select( ccol.over(CW.orderBy("b").rangeBetween(CW.currentRow, CW.currentRow + 123)) ).show() - with self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): cdf.select( ccol.over(CW.orderBy("b").rowsBetween(CW.currentRow, CW.currentRow + 123)) ).show() - with self.assertRaises(SparkConnectAnalysisException): + with self.assertRaises(AnalysisException): cdf.select( ccol.over(CW.orderBy("b").rowsBetween(CW.unboundedPreceding, CW.currentRow)) ).show() diff --git a/python/pyspark/sql/tests/connect/test_parity_column.py b/python/pyspark/sql/tests/connect/test_parity_column.py index e157638935164..5cce063871ab8 100644 --- a/python/pyspark/sql/tests/connect/test_parity_column.py +++ b/python/pyspark/sql/tests/connect/test_parity_column.py @@ -32,7 
+32,7 @@ class ColumnParityTests(ColumnTestsMixin, ReusedConnectTestCase): - # TODO(SPARK-42017): Different error type AnalysisException vs SparkConnectAnalysisException + # TODO(SPARK-42017): df["bad_key"] does not raise AnalysisException @unittest.skip("Fails in Spark Connect, should enable.") def test_access_column(self): super().test_access_column() diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py index d3807285f3ebb..7e6735cb7cdb9 100644 --- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py +++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py @@ -85,6 +85,11 @@ def test_require_cross(self): def test_same_semantics_error(self): super().test_same_semantics_error() + # TODO(SPARK-42338): Different exception in DataFrame.sample + @unittest.skip("Fails in Spark Connect, should enable.") + def test_sample(self): + super().test_sample() + @unittest.skip("Spark Connect does not support RDD but the tests depend on them.") def test_toDF_with_schema_string(self): super().test_toDF_with_schema_string() diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py b/python/pyspark/sql/tests/connect/test_parity_functions.py index b151986cb24ba..3d390c13913fa 100644 --- a/python/pyspark/sql/tests/connect/test_parity_functions.py +++ b/python/pyspark/sql/tests/connect/test_parity_functions.py @@ -17,11 +17,15 @@ import unittest +from pyspark.errors.exceptions.connect import SparkConnectException from pyspark.sql.tests.test_functions import FunctionsTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase class FunctionsParityTests(FunctionsTestsMixin, ReusedConnectTestCase): + def test_assert_true(self): + self.check_assert_true(SparkConnectException) + @unittest.skip("Spark Connect does not support Spark Context but the test depends on that.") def test_basic_functions(self): super().test_basic_functions() @@ -49,6 +53,9 @@ def test_lit_list(self): def test_lit_np_scalar(self): super().test_lit_np_scalar() + def test_raise_error(self): + self.check_assert_true(SparkConnectException) + # Comparing column type of connect and pyspark @unittest.skip("Fails in Spark Connect, should enable.") def test_sorting_functions_with_column(self): diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py b/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py index b4d1a9dd31a8c..4b1ce0a958788 100644 --- a/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py +++ b/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py @@ -19,9 +19,6 @@ from pyspark.sql.tests.pandas.test_pandas_udf import PandasUDFTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.errors.exceptions import SparkConnectGrpcException -from pyspark.sql.connect.functions import udf -from pyspark.sql.functions import pandas_udf, PandasUDFType class PandasUDFParityTests(PandasUDFTestsMixin, ReusedConnectTestCase): @@ -53,24 +50,10 @@ def test_pandas_udf_decorator(self): def test_pandas_udf_basic(self): super().test_pandas_udf_basic() - def test_stopiteration_in_udf(self): - # The vanilla PySpark throws PythonException instead. 
- def foo(x): - raise StopIteration() - - exc_message = "Caught StopIteration thrown from user's code; failing the task" - df = self.spark.range(0, 100) - - self.assertRaisesRegex( - SparkConnectGrpcException, exc_message, df.withColumn("v", udf(foo)("id")).collect - ) - - # pandas scalar udf - self.assertRaisesRegex( - SparkConnectGrpcException, - exc_message, - df.withColumn("v", pandas_udf(foo, "double", PandasUDFType.SCALAR)("id")).collect, - ) + # TODO(SPARK-42340): implement GroupedData.applyInPandas + @unittest.skip("Fails in Spark Connect, should enable.") + def test_stopiteration_in_grouped_map(self): + super().test_stopiteration_in_grouped_map() if __name__ == "__main__": diff --git a/python/pyspark/sql/tests/connect/test_parity_udf.py b/python/pyspark/sql/tests/connect/test_parity_udf.py index f74d21f0d8621..8d4bb69bf1633 100644 --- a/python/pyspark/sql/tests/connect/test_parity_udf.py +++ b/python/pyspark/sql/tests/connect/test_parity_udf.py @@ -27,9 +27,6 @@ from pyspark.sql.tests.test_udf import BaseUDFTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.errors.exceptions import SparkConnectAnalysisException -from pyspark.sql.connect.functions import udf -from pyspark.sql.types import BooleanType class UDFParityTests(BaseUDFTestsMixin, ReusedConnectTestCase): @@ -182,26 +179,6 @@ def test_udf_with_string_return_type(self): def test_udf_in_subquery(self): super().test_udf_in_subquery() - def test_udf_not_supported_in_join_condition(self): - # The vanilla PySpark throws AnalysisException instead. - # test python udf is not supported in join type except inner join. - left = self.spark.createDataFrame([(1, 1, 1), (2, 2, 2)], ["a", "a1", "a2"]) - right = self.spark.createDataFrame([(1, 1, 1), (1, 3, 1)], ["b", "b1", "b2"]) - f = udf(lambda a, b: a == b, BooleanType()) - - def runWithJoinType(join_type, type_string): - with self.assertRaisesRegex( - SparkConnectAnalysisException, - """Python UDF in the ON clause of a %s JOIN.""" % type_string, - ): - left.join(right, [f("a", "b"), left.a1 == right.b1], join_type).collect() - - runWithJoinType("full", "FULL OUTER") - runWithJoinType("left", "LEFT OUTER") - runWithJoinType("right", "RIGHT OUTER") - runWithJoinType("leftanti", "LEFT ANTI") - runWithJoinType("leftsemi", "LEFT SEMI") - if __name__ == "__main__": import unittest diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py b/python/pyspark/sql/tests/pandas/test_pandas_udf.py index 768317ab60d74..1b3b4555d7ffd 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py @@ -171,9 +171,6 @@ def test_stopiteration_in_udf(self): def foo(x): raise StopIteration() - def foofoo(x, y): - raise StopIteration() - exc_message = "Caught StopIteration thrown from user's code; failing the task" df = self.spark.range(0, 100) @@ -189,6 +186,16 @@ def foofoo(x, y): df.withColumn("v", pandas_udf(foo, "double", PandasUDFType.SCALAR)("id")).collect, ) + def test_stopiteration_in_grouped_map(self): + def foo(x): + raise StopIteration() + + def foofoo(x, y): + raise StopIteration() + + exc_message = "Caught StopIteration thrown from user's code; failing the task" + df = self.spark.range(0, 100) + # pandas grouped map self.assertRaisesRegex( PythonException, @@ -204,6 +211,13 @@ def foofoo(x, y): .collect, ) + def test_stopiteration_in_grouped_agg(self): + def foo(x): + raise StopIteration() + + exc_message = "Caught StopIteration thrown from user's code; failing the task" + df = 
self.spark.range(0, 100) + # pandas grouped agg self.assertRaisesRegex( PythonException, diff --git a/python/pyspark/sql/tests/streaming/test_streaming.py b/python/pyspark/sql/tests/streaming/test_streaming.py index 5470f79ff2275..9f02ae848bf67 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming.py +++ b/python/pyspark/sql/tests/streaming/test_streaming.py @@ -254,7 +254,7 @@ def test_stream_exception(self): self._assert_exception_tree_contains_msg(e, "ZeroDivisionError") finally: sq.stop() - self.assertTrue(type(sq.exception()) is StreamingQueryException) + self.assertIsInstance(sq.exception(), StreamingQueryException) self._assert_exception_tree_contains_msg(sq.exception(), "ZeroDivisionError") def _assert_exception_tree_contains_msg(self, exception, msg): diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 033878470e193..1d52602a96f15 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -46,8 +46,6 @@ from pyspark.errors import ( AnalysisException, IllegalArgumentException, - SparkConnectException, - SparkConnectAnalysisException, PySparkTypeError, ) from pyspark.testing.sqlutils import ( @@ -948,8 +946,7 @@ def test_sample(self): self.assertRaises(TypeError, lambda: self.spark.range(1).sample(seed="abc")) self.assertRaises( - (IllegalArgumentException, SparkConnectException), - lambda: self.spark.range(1).sample(-1.0).count(), + IllegalArgumentException, lambda: self.spark.range(1).sample(-1.0).count() ) def test_toDF_with_schema_string(self): @@ -1041,17 +1038,17 @@ def test_cache(self): self.assertFalse(spark.catalog.isCached("tab1")) self.assertFalse(spark.catalog.isCached("tab2")) self.assertRaisesRegex( - Exception, + AnalysisException, "does_not_exist", lambda: spark.catalog.isCached("does_not_exist"), ) self.assertRaisesRegex( - Exception, + AnalysisException, "does_not_exist", lambda: spark.catalog.cacheTable("does_not_exist"), ) self.assertRaisesRegex( - Exception, + AnalysisException, "does_not_exist", lambda: spark.catalog.uncacheTable("does_not_exist"), ) @@ -1595,17 +1592,13 @@ def test_to(self): # incompatible field nullability schema4 = StructType([StructField("j", LongType(), False)]) self.assertRaisesRegex( - (AnalysisException, SparkConnectAnalysisException), - "NULLABLE_COLUMN_OR_FIELD", - lambda: df.to(schema4).count(), + AnalysisException, "NULLABLE_COLUMN_OR_FIELD", lambda: df.to(schema4).count() ) # field cannot upcast schema5 = StructType([StructField("i", LongType())]) self.assertRaisesRegex( - (AnalysisException, SparkConnectAnalysisException), - "INVALID_COLUMN_OR_FIELD_DATA_TYPE", - lambda: df.to(schema5).count(), + AnalysisException, "INVALID_COLUMN_OR_FIELD_DATA_TYPE", lambda: df.to(schema5).count() ) def test_repartition(self): diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 05492347755e1..d8343b4fb47ef 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -25,7 +25,7 @@ import unittest from py4j.protocol import Py4JJavaError -from pyspark.errors import PySparkTypeError, PySparkValueError, SparkConnectException +from pyspark.errors import PySparkTypeError, PySparkValueError from pyspark.sql import Row, Window, types from pyspark.sql.functions import ( udf, @@ -1056,6 +1056,9 @@ def test_datetime_functions(self): self.assertEqual(date(2017, 1, 22), parse_result["to_date(dateCol)"]) def test_assert_true(self): + 
self.check_assert_true(Py4JJavaError) + + def check_assert_true(self, tpe): from pyspark.sql.functions import assert_true df = self.spark.range(3) @@ -1065,10 +1068,10 @@ def test_assert_true(self): [Row(val=None), Row(val=None), Row(val=None)], ) - with self.assertRaisesRegex((Py4JJavaError, SparkConnectException), "too big"): + with self.assertRaisesRegex(tpe, "too big"): df.select(assert_true(df.id < 2, "too big")).toDF("val").collect() - with self.assertRaisesRegex((Py4JJavaError, SparkConnectException), "2000000"): + with self.assertRaisesRegex(tpe, "2000000"): df.select(assert_true(df.id < 2, df.id * 1e6)).toDF("val").collect() with self.assertRaises(PySparkTypeError) as pe: @@ -1081,14 +1084,17 @@ def test_assert_true(self): ) def test_raise_error(self): + self.check_raise_error(Py4JJavaError) + + def check_raise_error(self, tpe): from pyspark.sql.functions import raise_error df = self.spark.createDataFrame([Row(id="foobar")]) - with self.assertRaisesRegex((Py4JJavaError, SparkConnectException), "foobar"): + with self.assertRaisesRegex(tpe, "foobar"): df.select(raise_error(df.id)).collect() - with self.assertRaisesRegex((Py4JJavaError, SparkConnectException), "barfoo"): + with self.assertRaisesRegex(tpe, "barfoo"): df.select(raise_error("barfoo")).collect() with self.assertRaises(PySparkTypeError) as pe: diff --git a/python/setup.py b/python/setup.py index af5c5f9384c39..ead1139f8f873 100755 --- a/python/setup.py +++ b/python/setup.py @@ -266,6 +266,7 @@ def run(self): "pyspark.licenses", "pyspark.resource", "pyspark.errors", + "pyspark.errors.exceptions", "pyspark.examples.src.main.python", ], include_package_data=True,
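As a closing note on the design, the concrete exception classes in `captured.py` and `connect.py` rely on cooperative multiple inheritance: each derives from its stack-specific base (`CapturedException` or `SparkConnectGrpcException`) and from the shared user-facing base, so one `isinstance` check against the `pyspark.errors` class matches either implementation. Below is a toy, dependency-free sketch of that layering; the class bodies are stubs, and the concrete class is renamed here only to avoid shadowing (the real modules both keep the name `AnalysisException` in separate packages):

```python
# Toy model of the layering introduced by this patch; the real classes live in
# pyspark.errors.exceptions.base and pyspark.errors.exceptions.connect.

class PySparkException(Exception):
    """Root of the user-facing hierarchy (pyspark.errors)."""


class AnalysisException(PySparkException):
    """Shared user-facing base (exceptions/base.py)."""


class SparkConnectException(PySparkException):
    """Connect-side root (exceptions/connect.py)."""


class SparkConnectGrpcException(SparkConnectException):
    """Connect-side base that carries the gRPC reason."""


class ConnectAnalysisException(SparkConnectGrpcException, AnalysisException):
    """Stand-in for what the Connect client raises for analysis errors."""


err = ConnectAnalysisException("analysis failed on the server")
assert isinstance(err, AnalysisException)          # user code only checks this
assert isinstance(err, SparkConnectGrpcException)  # Connect internals still can
assert isinstance(err, PySparkException)
```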