This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit 6fc6da4

Add ReadWriterOptionType (#363)

1 parent: d7f9181

5 files changed: +41 −16 lines

test-data/unit/sql-session.test (21 additions, 1 deletion)

@@ -66,7 +66,7 @@ schema = StructType([
 ])

 # Invalid product should have StructType schema
-spark.createDataFrame(data, IntegerType()) # E: Argument 1 to "createDataFrame" of "SparkSession" has incompatible type "List[Tuple[str, int]]"; expected "Union[RDD[Union[Union[datetime, date], Union[bool, int, float, str], Decimal]], Iterable[Union[Union[datetime, date], Union[bool, int, float, str], Decimal]]]"
+spark.createDataFrame(data, IntegerType()) # E: Argument 1 to "createDataFrame" of "SparkSession" has incompatible type "List[Tuple[str, int]]"; expected "Union[RDD[Union[Union[datetime, date], Union[bool, float, int, str], Decimal]], Iterable[Union[Union[datetime, date], Union[bool, float, int, str], Decimal]]]"

 # This shouldn't type check, though is technically speaking valid
 # because samplingRatio is ignored
@@ -76,3 +76,23 @@ spark.createDataFrame(data, schema, samplingRatio=0.1) # E: No overload variant
 # N: def [RowLike in (List[Any], Tuple[Any, ...], Row)] createDataFrame(self, data: Union[RDD[RowLike], Iterable[RowLike]], schema: Union[List[str], Tuple[str, ...]] = ..., verifySchema: bool = ...) -> DataFrame \
 # N: <4 more similar overloads not shown, out of 6 total overloads>
 [out]
+
+
+[case readWriterOptions]
+from pyspark.sql import SparkSession
+
+spark = SparkSession.builder.getOrCreate()
+
+spark.read.option("foo", True).option("foo", 1).option("foo", 1.0).option("foo", "1")
+spark.readStream.option("foo", True).option("foo", 1).option("foo", 1.0).option("foo", "1")
+
+spark.read.options(foo=True, bar=1).options(foo=1.0, bar="1")
+spark.readStream.options(foo=True, bar=1).options(foo=1.0, bar="1")
+
+spark.read.load(foo=True)
+spark.readStream.load(foo=True)
+
+spark.read.load(foo=["a"]) # E: Argument "foo" to "load" of "DataFrameReader" has incompatible type "List[str]"; expected "Union[bool, float, int, str]"
+spark.read.option("foo", (1, )) # E: Argument 2 to "option" of "DataFrameReader" has incompatible type "Tuple[int]"; expected "Union[bool, float, int, str]"
+spark.read.options(bar={1}) # E: Argument "bar" to "options" of "DataFrameReader" has incompatible type "Set[int]"; expected "Union[bool, float, int, str]"
+[out]
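For context, these cases follow mypy's test-data format: each [case ...] block is type-checked in isolation, a # E: comment asserts the exact error mypy must report on that line, # N: asserts an accompanying note, and the [out] section holds any remaining expected output (empty here). The new readWriterOptions case exercises both the values that should now be accepted (bool, int, float, str) and collection types that must still be rejected.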

third_party/3/pyspark/_typing.pyi (2 additions, 0 deletions)

@@ -3,6 +3,8 @@ from typing_extensions import Protocol

 T = TypeVar('T', covariant=True)

+PrimitiveType = Union[bool, float, int, str]
+
 class SupportsIAdd(Protocol):
     def __iadd__(self, other: SupportsIAdd) -> SupportsIAdd: ...

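A minimal sketch of how a checker treats values against this new alias; the set_option helper below is hypothetical and not part of the stubs, it only illustrates the union:

from typing import Union

PrimitiveType = Union[bool, float, int, str]

def set_option(key: str, value: PrimitiveType) -> None:
    """Hypothetical helper: accepts any scalar option value."""

set_option("header", True)   # OK: bool is a member of the union
set_option("ratio", 0.5)     # OK: float
set_option("paths", ["a"])   # mypy error: List[str] is not in the union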

third_party/3/pyspark/sql/_typing.pyi (3 additions, 1 deletion)

@@ -5,6 +5,7 @@ from types import FunctionType
 import datetime
 import decimal

+from pyspark._typing import PrimitiveType
 import pyspark.sql.column
 import pyspark.sql.types
 from pyspark.sql.column import Column
@@ -16,9 +17,10 @@ import pandas.core.series # type: ignore
 ColumnOrName = Union[pyspark.sql.column.Column, str]
 DecimalLiteral = decimal.Decimal
 DateTimeLiteral = Union[datetime.datetime, datetime.date]
-LiteralType = Union[bool, int, float, str]
+LiteralType = PrimitiveType
 AtomicDataTypeOrString = Union[pyspark.sql.types.AtomicType, str]
 DataTypeOrString = Union[pyspark.sql.types.DataType, str]
+ReadWriterOptionType = PrimitiveType

 RowLike = TypeVar("RowLike", List[Any], Tuple[Any, ...], pyspark.sql.types.Row)

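Note that LiteralType and ReadWriterOptionType currently point at the same union; keeping a separate alias for reader/writer options means either one can later be narrowed or widened without touching the other's call sites, which is presumably the motivation for the indirection.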

third_party/3/pyspark/sql/readwriter.pyi (8 additions, 7 deletions)

@@ -4,6 +4,7 @@
 from typing import overload
 from typing import Any, Dict, List, Optional, Tuple, Union

+from pyspark.sql._typing import ReadWriterOptionType
 from pyspark.sql.dataframe import DataFrame
 from pyspark.rdd import RDD
 from pyspark.sql.context import SQLContext
@@ -19,11 +20,11 @@ class DataFrameReader(OptionUtils):
     def format(self, source: str) -> DataFrameReader: ...
     def schema(self, schema: Union[StructType, str]) -> DataFrameReader: ...
     def option(self, key: str, value: Union[bool, float, int, str]) -> DataFrameReader: ...
-    def options(self, **options: str) -> DataFrameReader: ...
-    def load(self, path: Optional[PathOrPaths] = ..., format: Optional[str] = ..., schema: Optional[StructType] = ..., **options: str) -> DataFrame: ...
+    def options(self, **options: ReadWriterOptionType) -> DataFrameReader: ...
+    def load(self, path: Optional[PathOrPaths] = ..., format: Optional[str] = ..., schema: Optional[StructType] = ..., **options: ReadWriterOptionType) -> DataFrame: ...
     def json(self, path: Union[str, List[str], RDD[str]], schema: Optional[StructType] = ..., primitivesAsString: Optional[Union[bool, str]] = ..., prefersDecimal: Optional[Union[bool, str]] = ..., allowComments: Optional[Union[bool, str]] = ..., allowUnquotedFieldNames: Optional[Union[bool, str]] = ..., allowSingleQuotes: Optional[Union[bool, str]] = ..., allowNumericLeadingZero: Optional[Union[bool, str]] = ..., allowBackslashEscapingAnyCharacter: Optional[Union[bool, str]] = ..., mode: Optional[str] = ..., columnNameOfCorruptRecord: Optional[str] = ..., dateFormat: Optional[str] = ..., timestampFormat: Optional[str] = ..., multiLine: Optional[Union[bool, str]] = ..., allowUnquotedControlChars: Optional[Union[bool, str]] = ..., lineSep: Optional[str] = ..., samplingRatio: Optional[Union[float, str]] = ..., dropFieldIfAllNull: Optional[Union[bool, str]] = ..., encoding: Optional[str] = ..., locale: Optional[str] = ..., recursiveFileLookup: Optional[bool] = ...) -> DataFrame: ...
     def table(self, tableName: str) -> DataFrame: ...
-    def parquet(self, *paths: str, **options: str) -> DataFrame: ...
+    def parquet(self, *paths: str, **options: ReadWriterOptionType) -> DataFrame: ...
     def text(self, paths: PathOrPaths, wholetext: bool = ..., lineSep: Optional[str] = ..., recursiveFileLookup: Optional[bool] = ...) -> DataFrame: ...
     def csv(self, path: PathOrPaths, schema: Optional[StructType] = ..., sep: Optional[str] = ..., encoding: Optional[str] = ..., quote: Optional[str] = ..., escape: Optional[str] = ..., comment: Optional[str] = ..., header: Optional[Union[bool, str]] = ..., inferSchema: Optional[Union[bool, str]] = ..., ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = ..., ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = ..., nullValue: Optional[str] = ..., nanValue: Optional[str] = ..., positiveInf: Optional[str] = ..., negativeInf: Optional[str] = ..., dateFormat: Optional[str] = ..., timestampFormat: Optional[str] = ..., maxColumns: Optional[int] = ..., maxCharsPerColumn: Optional[int] = ..., maxMalformedLogPerPartition: Optional[int] = ..., mode: Optional[str] = ..., columnNameOfCorruptRecord: Optional[str] = ..., multiLine: Optional[Union[bool, str]] = ..., charToEscapeQuoteEscaping: Optional[str] = ..., samplingRatio: Optional[Union[float, str]] = ..., enforceSchema: Optional[Union[bool, str]] = ..., emptyValue: Optional[str] = ..., locale: Optional[str] = ..., lineSep: Optional[str] = ...) -> DataFrame: ...
     def orc(self, path: PathOrPaths, mergeSchema: Optional[bool] = ..., recursiveFileLookup: Optional[bool] = ...) -> DataFrame: ...
@@ -38,8 +39,8 @@ class DataFrameWriter(OptionUtils):
     def __init__(self, df: DataFrame) -> None: ...
     def mode(self, saveMode: str) -> DataFrameWriter: ...
     def format(self, source: str) -> DataFrameWriter: ...
-    def option(self, key: str, value: Union[bool, float, int, str]) -> DataFrameWriter: ...
-    def options(self, **options: str) -> DataFrameWriter: ...
+    def option(self, key: str, value: ReadWriterOptionType) -> DataFrameWriter: ...
+    def options(self, **options: ReadWriterOptionType) -> DataFrameWriter: ...
     @overload
     def partitionBy(self, *cols: str) -> DataFrameWriter: ...
     @overload
@@ -52,9 +53,9 @@ class DataFrameWriter(OptionUtils):
     def sortBy(self, col: str, *cols: str) -> DataFrameWriter: ...
     @overload
     def sortBy(self, col: TupleOrListOfString) -> DataFrameWriter: ...
-    def save(self, path: Optional[str] = ..., format: Optional[str] = ..., mode: Optional[str] = ..., partitionBy: Optional[List[str]] = ..., **options: str) -> None: ...
+    def save(self, path: Optional[str] = ..., format: Optional[str] = ..., mode: Optional[str] = ..., partitionBy: Optional[List[str]] = ..., **options: ReadWriterOptionType) -> None: ...
     def insertInto(self, tableName: str, overwrite: Optional[bool] = ...) -> None: ...
-    def saveAsTable(self, name: str, format: Optional[str] = ..., mode: Optional[str] = ..., partitionBy: Optional[List[str]] = ..., **options: str) -> None: ...
+    def saveAsTable(self, name: str, format: Optional[str] = ..., mode: Optional[str] = ..., partitionBy: Optional[List[str]] = ..., **options: ReadWriterOptionType) -> None: ...
     def json(self, path: str, mode: Optional[str] = ..., compression: Optional[str] = ..., dateFormat: Optional[str] = ..., timestampFormat: Optional[str] = ..., lineSep: Optional[str] = ..., encoding: Optional[str] = ..., ignoreNullFields: Optional[bool] = ...) -> None: ...
     def parquet(self, path: str, mode: Optional[str] = ..., partitionBy: Optional[List[str]] = ..., compression: Optional[str] = ...) -> None: ...
     def text(self, path: str, compression: Optional[str] = ..., lineSep: Optional[str] = ...) -> None: ...
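A quick usage sketch of what the relaxed **options kwargs allow; the file paths are hypothetical and a live SparkSession is assumed. Scalar values such as True or 0.5 now pass the checker directly instead of requiring their string forms:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# bool and float kwargs were previously rejected ("**options: str")
df = spark.read.options(header=True, inferSchema=True, samplingRatio=0.5).csv("data.csv")

# int kwarg on the writer side; "maxRecordsPerFile" is a standard Spark write option
df.write.options(maxRecordsPerFile=1000).mode("overwrite").parquet("out")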

third_party/3/pyspark/sql/streaming.pyi (7 additions, 7 deletions)

@@ -4,7 +4,7 @@
 from typing import overload
 from typing import Any, Callable, Dict, List, Optional, Union

-from pyspark.sql._typing import SupportsProcess
+from pyspark.sql._typing import SupportsProcess, ReadWriterOptionType
 from pyspark.sql.context import SQLContext
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.readwriter import OptionUtils
@@ -47,9 +47,9 @@ class DataStreamReader(OptionUtils):
     def __init__(self, spark: SQLContext) -> None: ...
     def format(self, source: str) -> DataStreamReader: ...
     def schema(self, schema: Union[StructType, str]) -> DataStreamReader: ...
-    def option(self, key: str, value: Union[bool, float, int, str]) -> DataStreamReader: ...
-    def options(self, **options: str) -> DataStreamReader: ...
-    def load(self, path: Optional[str] = ..., format: Optional[str] = ..., schema: Optional[StructType] = ..., **options: str) -> DataFrame: ...
+    def option(self, key: str, value: ReadWriterOptionType) -> DataStreamReader: ...
+    def options(self, **options: ReadWriterOptionType) -> DataStreamReader: ...
+    def load(self, path: Optional[str] = ..., format: Optional[str] = ..., schema: Optional[StructType] = ..., **options: ReadWriterOptionType) -> DataFrame: ...
     def json(self, path: str, schema: Optional[str] = ..., primitivesAsString: Optional[Union[bool, str]] = ..., prefersDecimal: Optional[Union[bool, str]] = ..., allowComments: Optional[Union[bool, str]] = ..., allowUnquotedFieldNames: Optional[Union[bool, str]] = ..., allowSingleQuotes: Optional[Union[bool, str]] = ..., allowNumericLeadingZero: Optional[Union[bool, str]] = ..., allowBackslashEscapingAnyCharacter: Optional[Union[bool, str]] = ..., mode: Optional[str] = ..., columnNameOfCorruptRecord: Optional[str] = ..., dateFormat: Optional[str] = ..., timestampFormat: Optional[str] = ..., multiLine: Optional[Union[bool, str]] = ..., allowUnquotedControlChars: Optional[Union[bool, str]] = ..., lineSep: Optional[str] = ..., locale: Optional[str] = ..., dropFieldIfAllNull: Optional[Union[bool, str]] = ..., encoding: Optional[str] = ..., recursiveFileLookup: Optional[bool] = ...) -> DataFrame: ...
     def orc(self, path: str, mergeSchema: Optional[bool] = ..., recursiveFileLookup: Optional[bool] = ...) -> DataFrame: ...
     def parquet(self, path: str, mergeSchema: Optional[bool] = ..., recursiveFileLookup: Optional[bool] = ...) -> DataFrame: ...
@@ -60,8 +60,8 @@ class DataStreamWriter:
     def __init__(self, df: DataFrame) -> None: ...
     def outputMode(self, outputMode: str) -> DataStreamWriter: ...
     def format(self, source: str) -> DataStreamWriter: ...
-    def option(self, key: str, value: Union[bool, float, int, str]) -> DataStreamWriter: ...
-    def options(self, **options: str) -> DataStreamWriter: ...
+    def option(self, key: str, value: ReadWriterOptionType) -> DataStreamWriter: ...
+    def options(self, **options: ReadWriterOptionType) -> DataStreamWriter: ...
     @overload
     def partitionBy(self, *cols: str) -> DataStreamWriter: ...
     @overload
@@ -73,7 +73,7 @@ class DataStreamWriter:
     def trigger(self, once: bool) -> DataStreamWriter: ...
     @overload
     def trigger(self, continuous: bool) -> DataStreamWriter: ...
-    def start(self, path: Optional[str] = ..., format: Optional[str] = ..., outputMode: Optional[str] = ..., partitionBy: Optional[Union[str, List[str]]] = ..., queryName: Optional[str] = ..., **options: str) -> StreamingQuery: ...
+    def start(self, path: Optional[str] = ..., format: Optional[str] = ..., outputMode: Optional[str] = ..., partitionBy: Optional[Union[str, List[str]]] = ..., queryName: Optional[str] = ..., **options: ReadWriterOptionType) -> StreamingQuery: ...
     @overload
     def foreach(self, f: Callable[[Row], None]) -> DataStreamWriter: ...
     @overload
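The same relaxation applies on the streaming side. A minimal sketch, assuming a running SparkSession and a hypothetical socket source on localhost:9999:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# "port" as an int kwarg used to fail under "**options: str"
stream = spark.readStream.format("socket").options(host="localhost", port=9999).load()

# console-sink options as int/bool kwargs now type-check too
query = stream.writeStream.format("console").options(numRows=20, truncate=False).start()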
