[SPARK-40537][CONNECT] Enable mypy for Spark Connect Python Client
### What changes were proposed in this pull request?
This patch adds the missing type annotations for the Spark Connect Python client and re-enables the mypy checks. In addition, the patch adds `pyi` files for the generated proto code for better type-checking support in the client.

### Why are the changes needed?

Tooling / Debugging / Testing

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

* Passes existing unit tests.
* Manually verified that code generation is idempotent: running the `generate_protos.sh` script multiple times does not produce a diff in the output (a sketch of this check follows below).
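
For reference, a minimal sketch of that idempotency check (a hypothetical helper, not part of this patch) could look like this:

```python
# Hedged sketch: run the generator twice and confirm the checked-in proto modules are unchanged.
# Assumes it is run from a Spark checkout with `buf` and the required proto plugins installed.
import subprocess


def check_codegen_idempotent(spark_home: str = ".") -> bool:
    for _ in range(2):
        subprocess.run(["connect/dev/generate_protos.sh"], cwd=spark_home, check=True)
    status = subprocess.run(
        ["git", "status", "--porcelain", "python/pyspark/sql/connect/proto"],
        cwd=spark_home,
        capture_output=True,
        text=True,
        check=True,
    )
    # Empty porcelain output means regeneration produced no diff.
    return status.stdout.strip() == ""
```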

Closes apache#38037 from grundprinzip/spark-40537.

Authored-by: Martin Grund <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
grundprinzip authored and HyukjinKwon committed Oct 6, 2022
1 parent e8fdf8e commit 0414213
Showing 17 changed files with 3,205 additions and 156 deletions.
2 changes: 2 additions & 0 deletions connect/.gitignore
@@ -0,0 +1,2 @@
# Ignore generated proto files.
src/main/gen
82 changes: 82 additions & 0 deletions connect/dev/generate_protos.sh
@@ -0,0 +1,82 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
set -ex

SPARK_HOME="$(cd "`dirname $0`"/../..; pwd)"
cd "$SPARK_HOME"

pushd connect/src/main

LICENSE=$(cat <<'EOF'
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
EOF)
echo "$LICENSE" > /tmp/tmp_licence
# Delete the old generated protobuf files.
rm -Rf gen
# Now, regenerate the new files
buf generate --debug -vvv
# We need to edit the generated Python files to account for the actual package location and not
# the package declared in the proto files.
for f in `find gen/proto/python -name "*.py*"`; do
# First fix the imports.
if [[ $f == *_pb2.py || $f == *_pb2_grpc.py ]]; then
sed -i '' -e 's/from spark.connect import/from pyspark.sql.connect.proto import/g' $f
elif [[ $f == *.pyi ]]; then
sed -i '' -e 's/import spark.connect./import pyspark.sql.connect.proto./g' $f
sed -i '' -e 's/spark.connect./pyspark.sql.connect.proto./g' $f
fi
# Prepend the Apache licence header to the files.
cp $f $f.bak
cat /tmp/tmp_licence $f.bak > $f
LC=$(wc -l < $f)
echo $LC
if [[ $f == *_grpc.py && $LC -eq 20 ]]; then
rm $f
fi
rm $f.bak
done
black --config $SPARK_HOME/dev/pyproject.toml gen/proto/python
# As a last step, copy the resulting files to the destination module.
for f in `find gen/proto/python -name "*.py*"`; do
cp $f $SPARK_HOME/python/pyspark/sql/connect/proto
done
# Clean up everything.
rm -Rf gen
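
The sed rewrites in the loop above remap the Python package that the code generator emits (`spark.connect`) to the location where the modules are checked in (`pyspark.sql.connect.proto`). Roughly, an import line changes as follows (the module name is illustrative, not taken from a generated file):

```python
# Before the rewrite, a generated *_pb2.py module imports its siblings as:
#   from spark.connect import types_pb2
# After the rewrite it resolves inside the PySpark package instead:
from pyspark.sql.connect.proto import types_pb2  # needs a PySpark build that includes Spark Connect
```
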
12 changes: 8 additions & 4 deletions connect/src/main/buf.gen.yaml
@@ -22,11 +22,15 @@ plugins:
out: gen/proto/csharp
- remote: buf.build/protocolbuffers/plugins/java:v3.20.0-1
out: gen/proto/java
- remote: buf.build/protocolbuffers/plugins/python:v3.20.0-1
out: gen/proto/python
- remote: buf.build/grpc/plugins/python:v1.47.0-1
out: gen/proto/python
- remote: buf.build/grpc/plugins/ruby:v1.47.0-1
out: gen/proto/ruby
- remote: buf.build/protocolbuffers/plugins/ruby:v21.2.0-1
out: gen/proto/ruby
# Generate the Python code and the mypy interfaces.
- remote: buf.build/protocolbuffers/plugins/python:v3.20.0-1
out: gen/proto/python
- remote: buf.build/grpc/plugins/python:v1.47.0-1
out: gen/proto/python
- name: mypy
out: gen/proto/python
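
The added `mypy` plugin entry is what emits `.pyi` stubs next to the generated `_pb2.py` modules, so field access on the proto messages becomes type-checkable. A hedged illustration of what this enables (the `Request.user_context.user_id` field is taken from the client diff below):

```python
import pyspark.sql.connect.proto as pb2

req = pb2.Request()
req.user_context.user_id = "test_user"  # OK: the generated stub types user_id as str
# req.user_context.user_id = 42         # mypy flags this assignment; protobuf also rejects it at runtime
```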

5 changes: 1 addition & 4 deletions python/mypy.ini
@@ -23,10 +23,7 @@ show_error_codes = True
warn_unused_ignores = True
warn_redundant_casts = True

; TODO(SPARK-40537) re-enable mypy support.
[mypy-pyspark.sql.connect.*]
disallow_untyped_defs = False
ignore_missing_imports = True
[mypy-pyspark.sql.connect.proto.*]
ignore_errors = True

; TODO(SPARK-40537) re-enable mypy support.
19 changes: 19 additions & 0 deletions python/pyspark/sql/connect/_typing.py
@@ -0,0 +1,19 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Union

PrimitiveType = Union[str, int, bool, float]
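
`PrimitiveType` is consumed with `typing.get_args` in `column.py` further down, where literal Python values are detected and wrapped into `LiteralExpression`s. A minimal, self-contained sketch of that pattern:

```python
from typing import Union, get_args

PrimitiveType = Union[str, int, bool, float]


def is_primitive(value: object) -> bool:
    # get_args(PrimitiveType) yields the tuple (str, int, bool, float),
    # which isinstance accepts directly.
    return isinstance(value, get_args(PrimitiveType))


assert is_primitive(3.14) and not is_primitive([1, 2])
```
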
33 changes: 21 additions & 12 deletions python/pyspark/sql/connect/client.py
@@ -21,18 +21,20 @@
import typing
import uuid

import grpc
import grpc # type: ignore
import pandas
import pandas as pd
import pyarrow as pa

import pyspark.sql.connect.proto as pb2
import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
import pyspark.sql.types
from pyspark import cloudpickle
from pyspark.sql.connect.data_frame import DataFrame
from pyspark.sql.connect.readwriter import DataFrameReader
from pyspark.sql.connect.plan import SQL

from typing import Optional, Any, Union

NumericType = typing.Union[int, float]

@@ -62,7 +64,7 @@ def metric_type(self) -> str:


class PlanMetrics:
def __init__(self, name: str, id: str, parent: str, metrics: typing.List[MetricValue]):
def __init__(self, name: str, id: int, parent: int, metrics: typing.List[MetricValue]):
self._name = name
self._id = id
self._parent_id = parent
@@ -76,11 +78,11 @@ def name(self) -> str:
return self._name

@property
def plan_id(self) -> str:
def plan_id(self) -> int:
return self._id

@property
def parent_plan_id(self) -> str:
def parent_plan_id(self) -> int:
return self._parent_id

@property
@@ -102,7 +104,7 @@ def fromProto(cls, pb: typing.Any) -> "AnalyzeResult":
class RemoteSparkSession(object):
"""Conceptually the remote spark session that communicates with the server"""

def __init__(self, user_id: str, host: str = None, port: int = 15002):
def __init__(self, user_id: str, host: Optional[str] = None, port: int = 15002):
self._host = "localhost" if host is None else host
self._port = port
self._user_id = user_id
@@ -112,7 +114,9 @@ def __init__(self, user_id: str, host: str = None, port: int = 15002):
# Create the reader
self.read = DataFrameReader(self)

def register_udf(self, function, return_type) -> str:
def register_udf(
self, function: Any, return_type: Union[str, pyspark.sql.types.DataType]
) -> str:
"""Create a temporary UDF in the session catalog on the other side. We generate a
temporary name for it."""
name = f"fun_{uuid.uuid4().hex}"
@@ -141,7 +145,7 @@ def _build_metrics(self, metrics: "pb2.Response.Metrics") -> typing.List[PlanMet
def sql(self, sql_string: str) -> "DataFrame":
return DataFrame.withPlan(SQL(sql_string), self)

def _to_pandas(self, plan: pb2.Plan) -> pandas.DataFrame:
def _to_pandas(self, plan: pb2.Plan) -> Optional[pandas.DataFrame]:
req = pb2.Request()
req.user_context.user_id = self._user_id
req.plan.CopyFrom(plan)
@@ -155,26 +159,31 @@ def analyze(self, plan: pb2.Plan) -> AnalyzeResult:
resp = self._stub.AnalyzePlan(req)
return AnalyzeResult.fromProto(resp)

def _process_batch(self, b) -> pandas.DataFrame:
def _process_batch(self, b: pb2.Response) -> Optional[pandas.DataFrame]:
if b.batch is not None and len(b.batch.data) > 0:
with pa.ipc.open_stream(b.data) as rd:
with pa.ipc.open_stream(b.batch.data) as rd:
return rd.read_pandas()
elif b.csv_batch is not None and len(b.csv_batch.data) > 0:
return pd.read_csv(io.StringIO(b.csv_batch.data), delimiter="|")
return None

def _execute_and_fetch(self, req: pb2.Request) -> typing.Optional[pandas.DataFrame]:
m = None
m: Optional[pb2.Response.Metrics] = None
result_dfs = []

for b in self._stub.ExecutePlan(req):
if b.metrics is not None:
m = b.metrics
result_dfs.append(self._process_batch(b))

pb = self._process_batch(b)
if pb is not None:
result_dfs.append(pb)

if len(result_dfs) > 0:
df = pd.concat(result_dfs)
# Attach the metrics to the DataFrame attributes.
df.attrs["metrics"] = self._build_metrics(m)
if m is not None:
df.attrs["metrics"] = self._build_metrics(m)
return df
else:
return None
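
Because `_execute_and_fetch` attaches the plan metrics to the resulting pandas DataFrame via `attrs`, callers can inspect them after collection. A hedged usage sketch that only relies on the attribute layout shown in this diff:

```python
import pandas as pd


def print_plan_metrics(pdf: pd.DataFrame) -> None:
    # "metrics" holds a List[PlanMetrics]; the key is absent if the server sent no metrics.
    for plan in pdf.attrs.get("metrics", []):
        print(plan.name, plan.plan_id, plan.parent_plan_id)
```
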
44 changes: 22 additions & 22 deletions python/pyspark/sql/connect/column.py
@@ -15,13 +15,11 @@
# limitations under the License.
#

from typing import List, Union, cast, get_args, TYPE_CHECKING
from typing import List, cast, get_args, TYPE_CHECKING, Optional, Callable, Any

import pyspark.sql.connect.proto as proto

PrimitiveType = Union[str, int, bool, float]
ExpressionOrString = Union[str, "Expression"]
ColumnOrString = Union[str, "ColumnRef"]
import pyspark.sql.connect.proto as proto
from pyspark.sql.connect._typing import PrimitiveType

if TYPE_CHECKING:
from pyspark.sql.connect.client import RemoteSparkSession
@@ -33,10 +31,10 @@ class Expression(object):
Expression base class.
"""

def __init__(self) -> None: # type: ignore[name-defined]
def __init__(self) -> None:
pass

def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression": # type: ignore
def to_plan(self, session: Optional["RemoteSparkSession"]) -> "proto.Expression":
...

def __str__(self) -> str:
@@ -49,11 +47,11 @@ class LiteralExpression(Expression):
The Python types are converted best effort into the relevant proto types. On the Spark Connect
server side, the proto types are converted to the Catalyst equivalents."""

def __init__(self, value: PrimitiveType) -> None: # type: ignore[name-defined]
def __init__(self, value: PrimitiveType) -> None:
super().__init__()
self._value = value

def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression":
def to_plan(self, session: Optional["RemoteSparkSession"]) -> "proto.Expression":
"""Converts the literal expression to the literal in proto.
TODO(SPARK-40533) This method always assumes the largest type and can thus
@@ -75,8 +73,10 @@ def __str__(self) -> str:
return f"Literal({self._value})"


def _bin_op(name: str, doc: str = "binary function", reverse=False):
def _(self: "ColumnRef", other) -> Expression:
def _bin_op(
name: str, doc: str = "binary function", reverse: bool = False
) -> Callable[["ColumnRef", Any], Expression]:
def _(self: "ColumnRef", other: Any) -> Expression:
if isinstance(other, get_args(PrimitiveType)):
other = LiteralExpression(other)
if not reverse:
@@ -94,12 +94,12 @@ class ColumnRef(Expression):
qualified name are identical"""

@classmethod
def from_qualified_name(cls, name) -> "ColumnRef":
def from_qualified_name(cls, name: str) -> "ColumnRef":
return ColumnRef(*name.split("."))

def __init__(self, *parts: str) -> None: # type: ignore[name-defined]
def __init__(self, *parts: str) -> None:
super().__init__()
self._parts: List[str] = list(filter(lambda x: x is not None, list(parts)))
self._parts: List[str] = list(parts)

def name(self) -> str:
"""Returns the qualified name of the column reference."""
@@ -123,32 +123,32 @@ def name(self) -> str:
__ge__ = _bin_op("greterEquals")
__le__ = _bin_op("lessEquals")

def __eq__(self, other) -> Expression: # type: ignore[override]
def __eq__(self, other: Any) -> Expression: # type: ignore[override]
"""Returns a binary expression with the current column as the left
side and the other expression as the right side.
"""
if isinstance(other, get_args(PrimitiveType)):
other = LiteralExpression(other)
return ScalarFunctionExpression("eq", self, other)

def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
def to_plan(self, session: Optional["RemoteSparkSession"]) -> proto.Expression:
"""Returns the Proto representation of the expression."""
expr = proto.Expression()
expr.unresolved_attribute.parts.extend(self._parts)
return expr

def desc(self):
def desc(self) -> "SortOrder":
return SortOrder(self, ascending=False)

def asc(self):
def asc(self) -> "SortOrder":
return SortOrder(self, ascending=True)

def __str__(self) -> str:
return f"Column({'.'.join(self._parts)})"


class SortOrder(Expression):
def __init__(self, col: ColumnRef, ascending=True, nullsLast=True) -> None:
def __init__(self, col: ColumnRef, ascending: bool = True, nullsLast: bool = True) -> None:
super().__init__()
self.ref = col
self.ascending = ascending
@@ -157,8 +157,8 @@ def __init__(self, col: ColumnRef, ascending=True, nullsLast=True) -> None:
def __str__(self) -> str:
return str(self.ref) + " ASC" if self.ascending else " DESC"

def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
return self.ref.to_plan()
def to_plan(self, session: Optional["RemoteSparkSession"]) -> proto.Expression:
return self.ref.to_plan(session)


class ScalarFunctionExpression(Expression):
@@ -171,7 +171,7 @@ def __init__(
self._args = args
self._op = op

def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
def to_plan(self, session: Optional["RemoteSparkSession"]) -> proto.Expression:
fun = proto.Expression()
fun.unresolved_function.parts.append(self._op)
fun.unresolved_function.arguments.extend([x.to_plan(session) for x in self._args])
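
Taken together, `_bin_op` and the operator assignments above mean that comparing a `ColumnRef` with a Python literal builds an expression tree rather than evaluating to a boolean. A hedged sketch of that behaviour (not a test from this patch):

```python
from pyspark.sql.connect.column import ColumnRef

salary = ColumnRef.from_qualified_name("employee.salary")
expr = salary >= 1000  # __ge__ comes from _bin_op; the int is wrapped in a LiteralExpression
order = salary.desc()  # SortOrder over the column with ascending=False
print(salary)          # Column(employee.salary)
```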
