From 31a16fbb405a19dc3eb732347e0e1f873b16971d Mon Sep 17 00:00:00 2001
From: zero323
Date: Thu, 24 Sep 2020 14:15:36 +0900
Subject: [PATCH] [SPARK-32714][PYTHON] Initial pyspark-stubs port

### What changes were proposed in this pull request?

This PR proposes migration of [`pyspark-stubs`](https://github.com/zero323/pyspark-stubs) into the Spark codebase.

### Why are the changes needed?

### Does this PR introduce _any_ user-facing change?

Yes. This PR adds type annotations directly to Spark source.

This can impact interaction with development tools for users who haven't used `pyspark-stubs`.

### How was this patch tested?

- [x] MyPy tests of the PySpark source

```
mypy --no-incremental --config python/mypy.ini python/pyspark
```

- [x] MyPy tests of Spark examples

```
MYPYPATH=python/ mypy --no-incremental --config python/mypy.ini examples/src/main/python/ml examples/src/main/python/sql examples/src/main/python/sql/streaming
```

- [x] Existing Flake8 linter
- [x] Existing unit tests

Tested against:

- `mypy==0.790+dev.e959952d9001e9713d329a2f9b196705b028f894`
- `mypy==0.782`

Closes #29591 from zero323/SPARK-32681.

Authored-by: zero323
Signed-off-by: HyukjinKwon
---
 dev/.rat-excludes | 1 +
 dev/tox.ini | 2 +-
 .../ml/estimator_transformer_param_example.py | 8 +-
 .../main/python/ml/fm_classifier_example.py | 6 +-
 .../main/python/ml/fm_regressor_example.py | 6 +-
 .../src/main/python/ml/pipeline_example.py | 8 +-
 examples/src/main/python/sql/arrow.py | 4 +-
 python/MANIFEST.in | 1 +
 python/mypy.ini | 36 +
 python/pyspark/__init__.pyi | 73 +
 python/pyspark/_globals.pyi | 27 +
 python/pyspark/_typing.pyi | 33 +
 python/pyspark/accumulators.pyi | 71 +
 python/pyspark/broadcast.pyi | 46 +
 python/pyspark/conf.pyi | 44 +
 python/pyspark/context.pyi | 176 ++
 python/pyspark/daemon.pyi | 29 +
 python/pyspark/files.pyi | 24 +
 python/pyspark/find_spark_home.pyi | 17 +
 python/pyspark/java_gateway.pyi | 24 +
 python/pyspark/join.pyi | 50 +
 python/pyspark/ml/__init__.pyi | 45 +
 python/pyspark/ml/_typing.pyi | 76 +
 python/pyspark/ml/base.pyi | 103 ++
 python/pyspark/ml/classification.pyi | 922 ++++++++++
 python/pyspark/ml/clustering.pyi | 437 +++++
 python/pyspark/ml/common.pyi | 20 +
 python/pyspark/ml/evaluation.pyi | 281 +++
 python/pyspark/ml/feature.pyi | 1629 +++++++++++++++++
 python/pyspark/ml/fpm.pyi | 109 ++
 python/pyspark/ml/functions.pyi | 22 +
 python/pyspark/ml/image.pyi | 40 +
 python/pyspark/ml/linalg/__init__.pyi | 255 +++
 python/pyspark/ml/param/__init__.pyi | 96 +
 .../ml/param/_shared_params_code_gen.pyi | 19 +
 python/pyspark/ml/param/shared.pyi | 187 ++
 python/pyspark/ml/pipeline.pyi | 97 +
 python/pyspark/ml/recommendation.pyi | 152 ++
 python/pyspark/ml/regression.pyi | 825 +++++++++
 python/pyspark/ml/stat.pyi | 89 +
 python/pyspark/ml/tests/test_algorithms.py | 2 +-
 python/pyspark/ml/tests/test_base.py | 2 +-
 python/pyspark/ml/tests/test_evaluation.py | 2 +-
 python/pyspark/ml/tests/test_feature.py | 2 +-
 python/pyspark/ml/tests/test_image.py | 2 +-
 python/pyspark/ml/tests/test_linalg.py | 2 +-
 python/pyspark/ml/tests/test_param.py | 2 +-
 python/pyspark/ml/tests/test_persistence.py | 2 +-
 python/pyspark/ml/tests/test_pipeline.py | 2 +-
 python/pyspark/ml/tests/test_stat.py | 2 +-
 .../pyspark/ml/tests/test_training_summary.py | 2 +-
 python/pyspark/ml/tests/test_tuning.py | 2 +-
 python/pyspark/ml/tests/test_wrapper.py | 6 +-
 python/pyspark/ml/tree.pyi | 112 ++
 python/pyspark/ml/tuning.pyi | 185 ++
 python/pyspark/ml/util.pyi | 128 ++
 python/pyspark/ml/wrapper.pyi | 48 +
 python/pyspark/mllib/__init__.pyi | 32 +
python/pyspark/mllib/_typing.pyi | 23 + python/pyspark/mllib/classification.pyi | 151 ++ python/pyspark/mllib/clustering.pyi | 196 ++ python/pyspark/mllib/common.pyi | 27 + python/pyspark/mllib/evaluation.pyi | 94 + python/pyspark/mllib/feature.pyi | 167 ++ python/pyspark/mllib/fpm.pyi | 57 + python/pyspark/mllib/linalg/__init__.pyi | 273 +++ python/pyspark/mllib/linalg/distributed.pyi | 147 ++ python/pyspark/mllib/random.pyi | 126 ++ python/pyspark/mllib/recommendation.pyi | 75 + python/pyspark/mllib/regression.pyi | 155 ++ python/pyspark/mllib/stat/KernelDensity.pyi | 27 + python/pyspark/mllib/stat/__init__.pyi | 29 + python/pyspark/mllib/stat/_statistics.pyi | 69 + python/pyspark/mllib/stat/distribution.pyi | 25 + python/pyspark/mllib/stat/test.pyi | 39 + python/pyspark/mllib/tests/test_algorithms.py | 2 +- python/pyspark/mllib/tests/test_feature.py | 2 +- python/pyspark/mllib/tests/test_linalg.py | 6 +- python/pyspark/mllib/tests/test_stat.py | 2 +- .../mllib/tests/test_streaming_algorithms.py | 2 +- python/pyspark/mllib/tests/test_util.py | 4 +- python/pyspark/mllib/tree.pyi | 126 ++ python/pyspark/mllib/util.pyi | 90 + python/pyspark/profiler.pyi | 56 + python/pyspark/py.typed | 1 + python/pyspark/rdd.pyi | 479 +++++ python/pyspark/rddsampler.pyi | 54 + python/pyspark/resource/__init__.pyi | 31 + python/pyspark/resource/information.pyi | 26 + python/pyspark/resource/profile.pyi | 51 + python/pyspark/resource/requests.pyi | 71 + .../pyspark/resource/tests/test_resources.py | 2 +- python/pyspark/resultiterable.pyi | 30 + python/pyspark/serializers.pyi | 122 ++ python/pyspark/shell.pyi | 31 + python/pyspark/shuffle.pyi | 109 ++ python/pyspark/sql/__init__.pyi | 41 + python/pyspark/sql/_typing.pyi | 57 + python/pyspark/sql/avro/__init__.pyi | 22 + python/pyspark/sql/avro/functions.pyi | 27 + python/pyspark/sql/catalog.pyi | 63 + python/pyspark/sql/column.pyi | 112 ++ python/pyspark/sql/conf.pyi | 27 + python/pyspark/sql/context.pyi | 139 ++ python/pyspark/sql/dataframe.pyi | 324 ++++ python/pyspark/sql/functions.pyi | 343 ++++ python/pyspark/sql/group.pyi | 44 + python/pyspark/sql/pandas/__init__.pyi | 17 + .../pyspark/sql/pandas/_typing/__init__.pyi | 338 ++++ .../sql/pandas/_typing/protocols/__init__.pyi | 17 + .../sql/pandas/_typing/protocols/frame.pyi | 428 +++++ .../sql/pandas/_typing/protocols/series.pyi | 253 +++ python/pyspark/sql/pandas/conversion.pyi | 58 + python/pyspark/sql/pandas/functions.pyi | 176 ++ python/pyspark/sql/pandas/group_ops.pyi | 49 + python/pyspark/sql/pandas/map_ops.pyi | 30 + python/pyspark/sql/pandas/serializers.pyi | 65 + python/pyspark/sql/pandas/typehints.pyi | 33 + python/pyspark/sql/pandas/types.pyi | 41 + python/pyspark/sql/pandas/utils.pyi | 20 + python/pyspark/sql/readwriter.pyi | 250 +++ python/pyspark/sql/session.pyi | 125 ++ python/pyspark/sql/streaming.pyi | 179 ++ python/pyspark/sql/tests/test_arrow.py | 6 +- python/pyspark/sql/tests/test_catalog.py | 2 +- python/pyspark/sql/tests/test_column.py | 2 +- python/pyspark/sql/tests/test_conf.py | 2 +- python/pyspark/sql/tests/test_context.py | 2 +- python/pyspark/sql/tests/test_dataframe.py | 22 +- python/pyspark/sql/tests/test_datasources.py | 2 +- python/pyspark/sql/tests/test_functions.py | 2 +- python/pyspark/sql/tests/test_group.py | 2 +- .../sql/tests/test_pandas_cogrouped_map.py | 4 +- .../sql/tests/test_pandas_grouped_map.py | 4 +- python/pyspark/sql/tests/test_pandas_map.py | 4 +- python/pyspark/sql/tests/test_pandas_udf.py | 4 +- .../sql/tests/test_pandas_udf_grouped_agg.py | 4 +- 
.../sql/tests/test_pandas_udf_scalar.py | 6 +- .../sql/tests/test_pandas_udf_typehints.py | 4 +- .../sql/tests/test_pandas_udf_window.py | 4 +- python/pyspark/sql/tests/test_readwriter.py | 2 +- python/pyspark/sql/tests/test_serde.py | 2 +- python/pyspark/sql/tests/test_session.py | 2 +- python/pyspark/sql/tests/test_streaming.py | 2 +- python/pyspark/sql/tests/test_types.py | 9 +- python/pyspark/sql/tests/test_udf.py | 11 +- python/pyspark/sql/tests/test_utils.py | 2 +- python/pyspark/sql/types.pyi | 204 +++ python/pyspark/sql/udf.pyi | 57 + python/pyspark/sql/utils.pyi | 55 + python/pyspark/sql/window.pyi | 40 + python/pyspark/statcounter.pyi | 44 + python/pyspark/status.pyi | 42 + python/pyspark/storagelevel.pyi | 43 + python/pyspark/streaming/__init__.pyi | 23 + python/pyspark/streaming/context.pyi | 75 + python/pyspark/streaming/dstream.pyi | 208 +++ python/pyspark/streaming/kinesis.pyi | 46 + python/pyspark/streaming/listener.pyi | 35 + .../pyspark/streaming/tests/test_context.py | 2 +- .../pyspark/streaming/tests/test_dstream.py | 2 +- .../pyspark/streaming/tests/test_kinesis.py | 2 +- .../pyspark/streaming/tests/test_listener.py | 2 +- python/pyspark/streaming/util.pyi | 48 + python/pyspark/taskcontext.pyi | 45 + python/pyspark/testing/mlutils.py | 5 +- python/pyspark/testing/sqlutils.py | 2 +- python/pyspark/testing/streamingutils.py | 4 +- python/pyspark/tests/test_appsubmit.py | 2 +- python/pyspark/tests/test_broadcast.py | 2 +- python/pyspark/tests/test_conf.py | 2 +- python/pyspark/tests/test_context.py | 11 +- python/pyspark/tests/test_daemon.py | 2 +- python/pyspark/tests/test_join.py | 2 +- python/pyspark/tests/test_pin_thread.py | 2 +- python/pyspark/tests/test_profiler.py | 2 +- python/pyspark/tests/test_rdd.py | 2 +- python/pyspark/tests/test_rddbarrier.py | 2 +- python/pyspark/tests/test_readwrite.py | 2 +- python/pyspark/tests/test_serializers.py | 4 +- python/pyspark/tests/test_shuffle.py | 2 +- python/pyspark/tests/test_taskcontext.py | 2 +- python/pyspark/tests/test_util.py | 2 +- python/pyspark/tests/test_worker.py | 2 +- python/pyspark/traceback_utils.pyi | 29 + python/pyspark/util.pyi | 35 + python/pyspark/version.pyi | 19 + python/pyspark/worker.pyi | 73 + python/setup.py | 3 +- 189 files changed, 14053 insertions(+), 119 deletions(-) create mode 100644 python/mypy.ini create mode 100644 python/pyspark/__init__.pyi create mode 100644 python/pyspark/_globals.pyi create mode 100644 python/pyspark/_typing.pyi create mode 100644 python/pyspark/accumulators.pyi create mode 100644 python/pyspark/broadcast.pyi create mode 100644 python/pyspark/conf.pyi create mode 100644 python/pyspark/context.pyi create mode 100644 python/pyspark/daemon.pyi create mode 100644 python/pyspark/files.pyi create mode 100644 python/pyspark/find_spark_home.pyi create mode 100644 python/pyspark/java_gateway.pyi create mode 100644 python/pyspark/join.pyi create mode 100644 python/pyspark/ml/__init__.pyi create mode 100644 python/pyspark/ml/_typing.pyi create mode 100644 python/pyspark/ml/base.pyi create mode 100644 python/pyspark/ml/classification.pyi create mode 100644 python/pyspark/ml/clustering.pyi create mode 100644 python/pyspark/ml/common.pyi create mode 100644 python/pyspark/ml/evaluation.pyi create mode 100644 python/pyspark/ml/feature.pyi create mode 100644 python/pyspark/ml/fpm.pyi create mode 100644 python/pyspark/ml/functions.pyi create mode 100644 python/pyspark/ml/image.pyi create mode 100644 python/pyspark/ml/linalg/__init__.pyi create mode 100644 
python/pyspark/ml/param/__init__.pyi create mode 100644 python/pyspark/ml/param/_shared_params_code_gen.pyi create mode 100644 python/pyspark/ml/param/shared.pyi create mode 100644 python/pyspark/ml/pipeline.pyi create mode 100644 python/pyspark/ml/recommendation.pyi create mode 100644 python/pyspark/ml/regression.pyi create mode 100644 python/pyspark/ml/stat.pyi create mode 100644 python/pyspark/ml/tree.pyi create mode 100644 python/pyspark/ml/tuning.pyi create mode 100644 python/pyspark/ml/util.pyi create mode 100644 python/pyspark/ml/wrapper.pyi create mode 100644 python/pyspark/mllib/__init__.pyi create mode 100644 python/pyspark/mllib/_typing.pyi create mode 100644 python/pyspark/mllib/classification.pyi create mode 100644 python/pyspark/mllib/clustering.pyi create mode 100644 python/pyspark/mllib/common.pyi create mode 100644 python/pyspark/mllib/evaluation.pyi create mode 100644 python/pyspark/mllib/feature.pyi create mode 100644 python/pyspark/mllib/fpm.pyi create mode 100644 python/pyspark/mllib/linalg/__init__.pyi create mode 100644 python/pyspark/mllib/linalg/distributed.pyi create mode 100644 python/pyspark/mllib/random.pyi create mode 100644 python/pyspark/mllib/recommendation.pyi create mode 100644 python/pyspark/mllib/regression.pyi create mode 100644 python/pyspark/mllib/stat/KernelDensity.pyi create mode 100644 python/pyspark/mllib/stat/__init__.pyi create mode 100644 python/pyspark/mllib/stat/_statistics.pyi create mode 100644 python/pyspark/mllib/stat/distribution.pyi create mode 100644 python/pyspark/mllib/stat/test.pyi create mode 100644 python/pyspark/mllib/tree.pyi create mode 100644 python/pyspark/mllib/util.pyi create mode 100644 python/pyspark/profiler.pyi create mode 100644 python/pyspark/py.typed create mode 100644 python/pyspark/rdd.pyi create mode 100644 python/pyspark/rddsampler.pyi create mode 100644 python/pyspark/resource/__init__.pyi create mode 100644 python/pyspark/resource/information.pyi create mode 100644 python/pyspark/resource/profile.pyi create mode 100644 python/pyspark/resource/requests.pyi create mode 100644 python/pyspark/resultiterable.pyi create mode 100644 python/pyspark/serializers.pyi create mode 100644 python/pyspark/shell.pyi create mode 100644 python/pyspark/shuffle.pyi create mode 100644 python/pyspark/sql/__init__.pyi create mode 100644 python/pyspark/sql/_typing.pyi create mode 100644 python/pyspark/sql/avro/__init__.pyi create mode 100644 python/pyspark/sql/avro/functions.pyi create mode 100644 python/pyspark/sql/catalog.pyi create mode 100644 python/pyspark/sql/column.pyi create mode 100644 python/pyspark/sql/conf.pyi create mode 100644 python/pyspark/sql/context.pyi create mode 100644 python/pyspark/sql/dataframe.pyi create mode 100644 python/pyspark/sql/functions.pyi create mode 100644 python/pyspark/sql/group.pyi create mode 100644 python/pyspark/sql/pandas/__init__.pyi create mode 100644 python/pyspark/sql/pandas/_typing/__init__.pyi create mode 100644 python/pyspark/sql/pandas/_typing/protocols/__init__.pyi create mode 100644 python/pyspark/sql/pandas/_typing/protocols/frame.pyi create mode 100644 python/pyspark/sql/pandas/_typing/protocols/series.pyi create mode 100644 python/pyspark/sql/pandas/conversion.pyi create mode 100644 python/pyspark/sql/pandas/functions.pyi create mode 100644 python/pyspark/sql/pandas/group_ops.pyi create mode 100644 python/pyspark/sql/pandas/map_ops.pyi create mode 100644 python/pyspark/sql/pandas/serializers.pyi create mode 100644 python/pyspark/sql/pandas/typehints.pyi create mode 100644 
python/pyspark/sql/pandas/types.pyi create mode 100644 python/pyspark/sql/pandas/utils.pyi create mode 100644 python/pyspark/sql/readwriter.pyi create mode 100644 python/pyspark/sql/session.pyi create mode 100644 python/pyspark/sql/streaming.pyi create mode 100644 python/pyspark/sql/types.pyi create mode 100644 python/pyspark/sql/udf.pyi create mode 100644 python/pyspark/sql/utils.pyi create mode 100644 python/pyspark/sql/window.pyi create mode 100644 python/pyspark/statcounter.pyi create mode 100644 python/pyspark/status.pyi create mode 100644 python/pyspark/storagelevel.pyi create mode 100644 python/pyspark/streaming/__init__.pyi create mode 100644 python/pyspark/streaming/context.pyi create mode 100644 python/pyspark/streaming/dstream.pyi create mode 100644 python/pyspark/streaming/kinesis.pyi create mode 100644 python/pyspark/streaming/listener.pyi create mode 100644 python/pyspark/streaming/util.pyi create mode 100644 python/pyspark/taskcontext.pyi create mode 100644 python/pyspark/traceback_utils.pyi create mode 100644 python/pyspark/util.pyi create mode 100644 python/pyspark/version.pyi create mode 100644 python/pyspark/worker.pyi diff --git a/dev/.rat-excludes b/dev/.rat-excludes index df1dd51a7c519..98786437f7b1c 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -124,3 +124,4 @@ GangliaReporter.java application_1578436911597_0052 config.properties app-20200706201101-0003 +py.typed diff --git a/dev/tox.ini b/dev/tox.ini index c14e6b9446cca..7edf7d597fb58 100644 --- a/dev/tox.ini +++ b/dev/tox.ini @@ -20,5 +20,5 @@ exclude=python/pyspark/cloudpickle/*.py,shared.py,python/docs/source/conf.py,wor [flake8] select = E901,E999,F821,F822,F823,F401,F405 -exclude = python/pyspark/cloudpickle/*.py,shared.py,python/docs/source/conf.py,work/*/*.py,python/.eggs/*,dist/*,.git/* +exclude = python/pyspark/cloudpickle/*.py,shared.py*,python/docs/source/conf.py,work/*/*.py,python/.eggs/*,dist/*,.git/*,python/out,python/pyspark/sql/pandas/functions.pyi,python/pyspark/sql/column.pyi,python/pyspark/worker.pyi,python/pyspark/java_gateway.pyi max-line-length = 100 diff --git a/examples/src/main/python/ml/estimator_transformer_param_example.py b/examples/src/main/python/ml/estimator_transformer_param_example.py index 1dcca6c201119..2cf9432646b5e 100644 --- a/examples/src/main/python/ml/estimator_transformer_param_example.py +++ b/examples/src/main/python/ml/estimator_transformer_param_example.py @@ -56,12 +56,14 @@ # We may alternatively specify parameters using a Python dictionary as a paramMap paramMap = {lr.maxIter: 20} paramMap[lr.maxIter] = 30 # Specify 1 Param, overwriting the original maxIter. - paramMap.update({lr.regParam: 0.1, lr.threshold: 0.55}) # Specify multiple Params. + # Specify multiple Params. + paramMap.update({lr.regParam: 0.1, lr.threshold: 0.55}) # type: ignore # You can combine paramMaps, which are python dictionaries. - paramMap2 = {lr.probabilityCol: "myProbability"} # Change output column name + # Change output column name + paramMap2 = {lr.probabilityCol: "myProbability"} # type: ignore paramMapCombined = paramMap.copy() - paramMapCombined.update(paramMap2) + paramMapCombined.update(paramMap2) # type: ignore # Now learn a new model using the paramMapCombined parameters. # paramMapCombined overrides all parameters set earlier via lr.set* methods. 
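The `# type: ignore` comments added to the example above are there because, with the new stubs in place, mypy infers a narrow type for the first `paramMap` literal and then rejects updates that mix `Param`s with other value types. A minimal sketch of the effect, assuming `Param` is generic over its value type as declared in the stubs (the inferred `Dict` type in the comments is an illustration, not output copied from mypy):

```
# check_parammap.py -- illustrative only; check with: mypy check_parammap.py
from pyspark.ml.classification import LogisticRegression

lr = LogisticRegression(maxIter=10)

# Inferred from the literal as something like Dict[Param[int], int].
paramMap = {lr.maxIter: 20}

# lr.regParam / lr.threshold are float-valued Params, so this update no longer
# matches the inferred dict type; the example suppresses the error instead of
# restructuring otherwise-valid runtime code.
paramMap.update({lr.regParam: 0.1, lr.threshold: 0.55})  # type: ignore
```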
diff --git a/examples/src/main/python/ml/fm_classifier_example.py b/examples/src/main/python/ml/fm_classifier_example.py index b47bdc5275beb..da49e5fc2baa9 100644 --- a/examples/src/main/python/ml/fm_classifier_example.py +++ b/examples/src/main/python/ml/fm_classifier_example.py @@ -67,9 +67,9 @@ print("Test set accuracy = %g" % accuracy) fmModel = model.stages[2] - print("Factors: " + str(fmModel.factors)) - print("Linear: " + str(fmModel.linear)) - print("Intercept: " + str(fmModel.intercept)) + print("Factors: " + str(fmModel.factors)) # type: ignore + print("Linear: " + str(fmModel.linear)) # type: ignore + print("Intercept: " + str(fmModel.intercept)) # type: ignore # $example off$ spark.stop() diff --git a/examples/src/main/python/ml/fm_regressor_example.py b/examples/src/main/python/ml/fm_regressor_example.py index 5c8133996ae83..47544b6324203 100644 --- a/examples/src/main/python/ml/fm_regressor_example.py +++ b/examples/src/main/python/ml/fm_regressor_example.py @@ -64,9 +64,9 @@ print("Root Mean Squared Error (RMSE) on test data = %g" % rmse) fmModel = model.stages[1] - print("Factors: " + str(fmModel.factors)) - print("Linear: " + str(fmModel.linear)) - print("Intercept: " + str(fmModel.intercept)) + print("Factors: " + str(fmModel.factors)) # type: ignore + print("Linear: " + str(fmModel.linear)) # type: ignore + print("Intercept: " + str(fmModel.intercept)) # type: ignore # $example off$ spark.stop() diff --git a/examples/src/main/python/ml/pipeline_example.py b/examples/src/main/python/ml/pipeline_example.py index e1fab7cbe6d80..66fdd73632a70 100644 --- a/examples/src/main/python/ml/pipeline_example.py +++ b/examples/src/main/python/ml/pipeline_example.py @@ -62,8 +62,12 @@ prediction = model.transform(test) selected = prediction.select("id", "text", "probability", "prediction") for row in selected.collect(): - rid, text, prob, prediction = row - print("(%d, %s) --> prob=%s, prediction=%f" % (rid, text, str(prob), prediction)) + rid, text, prob, prediction = row # type: ignore + print( + "(%d, %s) --> prob=%s, prediction=%f" % ( + rid, text, str(prob), prediction # type: ignore + ) + ) # $example off$ spark.stop() diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py index 1789a54f0276e..9978e8601449a 100644 --- a/examples/src/main/python/sql/arrow.py +++ b/examples/src/main/python/sql/arrow.py @@ -32,8 +32,8 @@ def dataframe_with_arrow_example(spark): - import numpy as np - import pandas as pd + import numpy as np # type: ignore[import] + import pandas as pd # type: ignore[import] # Enable Arrow-based columnar data transfers spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true") diff --git a/python/MANIFEST.in b/python/MANIFEST.in index 2d78a001a4d98..862d62b1d3b29 100644 --- a/python/MANIFEST.in +++ b/python/MANIFEST.in @@ -22,4 +22,5 @@ recursive-include deps/data *.data *.txt recursive-include deps/licenses *.txt recursive-include deps/examples *.py recursive-include lib *.zip +recursive-include pyspark *.pyi py.typed include README.md diff --git a/python/mypy.ini b/python/mypy.ini new file mode 100644 index 0000000000000..a9523e622ca0d --- /dev/null +++ b/python/mypy.ini @@ -0,0 +1,36 @@ +; +; Licensed to the Apache Software Foundation (ASF) under one or more +; contributor license agreements. See the NOTICE file distributed with +; this work for additional information regarding copyright ownership. 
+; The ASF licenses this file to You under the Apache License, Version 2.0 +; (the "License"); you may not use this file except in compliance with +; the License. You may obtain a copy of the License at +; +; http://www.apache.org/licenses/LICENSE-2.0 +; +; Unless required by applicable law or agreed to in writing, software +; distributed under the License is distributed on an "AS IS" BASIS, +; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +; See the License for the specific language governing permissions and +; limitations under the License. +; + +[mypy] + +[mypy-pyspark.cloudpickle.*] +ignore_errors = True + +[mypy-py4j.*] +ignore_missing_imports = True + +[mypy-numpy] +ignore_missing_imports = True + +[mypy-scipy.*] +ignore_missing_imports = True + +[mypy-pandas.*] +ignore_missing_imports = True + +[mypy-pyarrow] +ignore_missing_imports = True diff --git a/python/pyspark/__init__.pyi b/python/pyspark/__init__.pyi new file mode 100644 index 0000000000000..98bd40684c01b --- /dev/null +++ b/python/pyspark/__init__.pyi @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Callable, Optional, TypeVar + +from pyspark.accumulators import ( # noqa: F401 + Accumulator as Accumulator, + AccumulatorParam as AccumulatorParam, +) +from pyspark.broadcast import Broadcast as Broadcast # noqa: F401 +from pyspark.conf import SparkConf as SparkConf # noqa: F401 +from pyspark.context import SparkContext as SparkContext # noqa: F401 +from pyspark.files import SparkFiles as SparkFiles # noqa: F401 +from pyspark.status import ( + StatusTracker as StatusTracker, + SparkJobInfo as SparkJobInfo, + SparkStageInfo as SparkStageInfo, +) # noqa: F401 +from pyspark.profiler import ( # noqa: F401 + BasicProfiler as BasicProfiler, + Profiler as Profiler, +) +from pyspark.rdd import RDD as RDD, RDDBarrier as RDDBarrier # noqa: F401 +from pyspark.serializers import ( # noqa: F401 + MarshalSerializer as MarshalSerializer, + PickleSerializer as PickleSerializer, +) +from pyspark.status import ( # noqa: F401 + SparkJobInfo as SparkJobInfo, + SparkStageInfo as SparkStageInfo, + StatusTracker as StatusTracker, +) +from pyspark.storagelevel import StorageLevel as StorageLevel # noqa: F401 +from pyspark.taskcontext import ( # noqa: F401 + BarrierTaskContext as BarrierTaskContext, + BarrierTaskInfo as BarrierTaskInfo, + TaskContext as TaskContext, +) +from pyspark.util import InheritableThread as InheritableThread # noqa: F401 + +# Compatiblity imports +from pyspark.sql import ( # noqa: F401 + SQLContext as SQLContext, + HiveContext as HiveContext, + Row as Row, +) + +T = TypeVar("T") +F = TypeVar("F", bound=Callable) + +def since(version: str) -> Callable[[T], T]: ... 
+def copy_func( + f: F, + name: Optional[str] = ..., + sinceversion: Optional[str] = ..., + doc: Optional[str] = ..., +) -> F: ... +def keyword_only(func: F) -> F: ... diff --git a/python/pyspark/_globals.pyi b/python/pyspark/_globals.pyi new file mode 100644 index 0000000000000..9453775621196 --- /dev/null +++ b/python/pyspark/_globals.pyi @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE: This dynamically typed stub was automatically generated by stubgen. + +from typing import Any + +__ALL__: Any + +class _NoValueType: + def __new__(cls): ... + def __reduce__(self): ... diff --git a/python/pyspark/_typing.pyi b/python/pyspark/_typing.pyi new file mode 100644 index 0000000000000..637e4cb4fbccc --- /dev/null +++ b/python/pyspark/_typing.pyi @@ -0,0 +1,33 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Callable, Iterable, Sized, TypeVar, Union +from typing_extensions import Protocol + +F = TypeVar("F", bound=Callable) +T = TypeVar("T", covariant=True) + +PrimitiveType = Union[bool, float, int, str] + +class SupportsIAdd(Protocol): + def __iadd__(self, other: SupportsIAdd) -> SupportsIAdd: ... + +class SupportsOrdering(Protocol): + def __le__(self, other: SupportsOrdering) -> bool: ... + +class SizedIterable(Protocol, Sized, Iterable[T]): ... diff --git a/python/pyspark/accumulators.pyi b/python/pyspark/accumulators.pyi new file mode 100644 index 0000000000000..94f8023d1102b --- /dev/null +++ b/python/pyspark/accumulators.pyi @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Callable, Generic, Tuple, Type, TypeVar + +import socketserver.BaseRequestHandler # type: ignore + +from pyspark._typing import SupportsIAdd + +T = TypeVar("T") +U = TypeVar("U", bound=SupportsIAdd) + +import socketserver as SocketServer + +class Accumulator(Generic[T]): + aid: int + accum_param: AccumulatorParam[T] + def __init__( + self, aid: int, value: T, accum_param: AccumulatorParam[T] + ) -> None: ... + def __reduce__( + self, + ) -> Tuple[ + Callable[[int, int, AccumulatorParam[T]], Accumulator[T]], + Tuple[int, int, AccumulatorParam[T]], + ]: ... + @property + def value(self) -> T: ... + @value.setter + def value(self, value: T) -> None: ... + def add(self, term: T) -> None: ... + def __iadd__(self, term: T) -> Accumulator[T]: ... + +class AccumulatorParam(Generic[T]): + def zero(self, value: T) -> T: ... + def addInPlace(self, value1: T, value2: T) -> T: ... + +class AddingAccumulatorParam(AccumulatorParam[U]): + zero_value: U + def __init__(self, zero_value: U) -> None: ... + def zero(self, value: U) -> U: ... + def addInPlace(self, value1: U, value2: U) -> U: ... + +class _UpdateRequestHandler(SocketServer.StreamRequestHandler): + def handle(self) -> None: ... + +class AccumulatorServer(SocketServer.TCPServer): + auth_token: str + def __init__( + self, + server_address: Tuple[str, int], + RequestHandlerClass: Type[socketserver.BaseRequestHandler], + auth_token: str, + ) -> None: ... + server_shutdown: bool + def shutdown(self) -> None: ... diff --git a/python/pyspark/broadcast.pyi b/python/pyspark/broadcast.pyi new file mode 100644 index 0000000000000..c2ea3c6f7d8b4 --- /dev/null +++ b/python/pyspark/broadcast.pyi @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import threading +from typing import Any, Generic, Optional, TypeVar + +T = TypeVar("T") + +class Broadcast(Generic[T]): + def __init__( + self, + sc: Optional[Any] = ..., + value: Optional[T] = ..., + pickle_registry: Optional[Any] = ..., + path: Optional[Any] = ..., + sock_file: Optional[Any] = ..., + ) -> None: ... + def dump(self, value: Any, f: Any) -> None: ... + def load_from_path(self, path: Any): ... + def load(self, file: Any): ... + @property + def value(self) -> T: ... + def unpersist(self, blocking: bool = ...) -> None: ... + def destroy(self, blocking: bool = ...) -> None: ... + def __reduce__(self): ... 
+ +class BroadcastPickleRegistry(threading.local): + def __init__(self) -> None: ... + def __iter__(self) -> None: ... + def add(self, bcast: Any) -> None: ... + def clear(self) -> None: ... diff --git a/python/pyspark/conf.pyi b/python/pyspark/conf.pyi new file mode 100644 index 0000000000000..f7ca61dea9cd2 --- /dev/null +++ b/python/pyspark/conf.pyi @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import overload +from typing import List, Optional, Tuple + +from py4j.java_gateway import JVMView, JavaObject # type: ignore[import] + +class SparkConf: + def __init__( + self, + loadDefaults: bool = ..., + _jvm: Optional[JVMView] = ..., + _jconf: Optional[JavaObject] = ..., + ) -> None: ... + def set(self, key: str, value: str) -> SparkConf: ... + def setIfMissing(self, key: str, value: str) -> SparkConf: ... + def setMaster(self, value: str) -> SparkConf: ... + def setAppName(self, value: str) -> SparkConf: ... + def setSparkHome(self, value: str) -> SparkConf: ... + @overload + def setExecutorEnv(self, key: str, value: str) -> SparkConf: ... + @overload + def setExecutorEnv(self, *, pairs: List[Tuple[str, str]]) -> SparkConf: ... + def setAll(self, pairs: List[Tuple[str, str]]) -> SparkConf: ... + def get(self, key: str, defaultValue: Optional[str] = ...) -> str: ... + def getAll(self) -> List[Tuple[str, str]]: ... + def contains(self, key: str) -> bool: ... + def toDebugString(self) -> str: ... diff --git a/python/pyspark/context.pyi b/python/pyspark/context.pyi new file mode 100644 index 0000000000000..76ecf8911471a --- /dev/null +++ b/python/pyspark/context.pyi @@ -0,0 +1,176 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar + +from py4j.java_gateway import JavaGateway, JavaObject # type: ignore[import] + +from pyspark.accumulators import Accumulator, AccumulatorParam +from pyspark.broadcast import Broadcast +from pyspark.conf import SparkConf +from pyspark.profiler import Profiler # noqa: F401 +from pyspark.resource.information import ResourceInformation +from pyspark.rdd import RDD +from pyspark.serializers import Serializer +from pyspark.status import StatusTracker + +T = TypeVar("T") +U = TypeVar("U") + +class SparkContext: + master: str + appName: str + sparkHome: str + PACKAGE_EXTENSIONS: Iterable[str] + def __init__( + self, + master: Optional[str] = ..., + appName: Optional[str] = ..., + sparkHome: Optional[str] = ..., + pyFiles: Optional[List[str]] = ..., + environment: Optional[Dict[str, str]] = ..., + batchSize: int = ..., + serializer: Serializer = ..., + conf: Optional[SparkConf] = ..., + gateway: Optional[JavaGateway] = ..., + jsc: Optional[JavaObject] = ..., + profiler_cls: type = ..., + ) -> None: ... + def __getnewargs__(self): ... + def __enter__(self): ... + def __exit__(self, type, value, trace): ... + @classmethod + def getOrCreate(cls, conf: Optional[SparkConf] = ...) -> SparkContext: ... + def setLogLevel(self, logLevel: str) -> None: ... + @classmethod + def setSystemProperty(cls, key: str, value: str) -> None: ... + @property + def version(self) -> str: ... + @property + def applicationId(self) -> str: ... + @property + def uiWebUrl(self) -> str: ... + @property + def startTime(self) -> int: ... + @property + def defaultParallelism(self) -> int: ... + @property + def defaultMinPartitions(self) -> int: ... + def stop(self) -> None: ... + def emptyRDD(self) -> RDD[Any]: ... + def range( + self, + start: int, + end: Optional[int] = ..., + step: int = ..., + numSlices: Optional[int] = ..., + ) -> RDD[int]: ... + def parallelize(self, c: Iterable[T], numSlices: Optional[int] = ...) -> RDD[T]: ... + def pickleFile(self, name: str, minPartitions: Optional[int] = ...) -> RDD[Any]: ... + def textFile( + self, name: str, minPartitions: Optional[int] = ..., use_unicode: bool = ... + ) -> RDD[str]: ... + def wholeTextFiles( + self, path: str, minPartitions: Optional[int] = ..., use_unicode: bool = ... + ) -> RDD[Tuple[str, str]]: ... + def binaryFiles( + self, path: str, minPartitions: Optional[int] = ... + ) -> RDD[Tuple[str, bytes]]: ... + def binaryRecords(self, path: str, recordLength: int) -> RDD[bytes]: ... + def sequenceFile( + self, + path: str, + keyClass: Optional[str] = ..., + valueClass: Optional[str] = ..., + keyConverter: Optional[str] = ..., + valueConverter: Optional[str] = ..., + minSplits: Optional[int] = ..., + batchSize: int = ..., + ) -> RDD[Tuple[T, U]]: ... + def newAPIHadoopFile( + self, + path: str, + inputFormatClass: str, + keyClass: str, + valueClass: str, + keyConverter: Optional[str] = ..., + valueConverter: Optional[str] = ..., + conf: Optional[Dict[str, str]] = ..., + batchSize: int = ..., + ) -> RDD[Tuple[T, U]]: ... + def newAPIHadoopRDD( + self, + inputFormatClass: str, + keyClass: str, + valueClass: str, + keyConverter: Optional[str] = ..., + valueConverter: Optional[str] = ..., + conf: Optional[Dict[str, str]] = ..., + batchSize: int = ..., + ) -> RDD[Tuple[T, U]]: ... 
+ def hadoopFile( + self, + path: str, + inputFormatClass: str, + keyClass: str, + valueClass: str, + keyConverter: Optional[str] = ..., + valueConverter: Optional[str] = ..., + conf: Optional[Dict[str, str]] = ..., + batchSize: int = ..., + ) -> RDD[Tuple[T, U]]: ... + def hadoopRDD( + self, + inputFormatClass: str, + keyClass: str, + valueClass: str, + keyConverter: Optional[str] = ..., + valueConverter: Optional[str] = ..., + conf: Optional[Dict[str, str]] = ..., + batchSize: int = ..., + ) -> RDD[Tuple[T, U]]: ... + def union(self, rdds: Iterable[RDD[T]]) -> RDD[T]: ... + def broadcast(self, value: T) -> Broadcast[T]: ... + def accumulator( + self, value: T, accum_param: Optional[AccumulatorParam[T]] = ... + ) -> Accumulator[T]: ... + def addFile(self, path: str, recursive: bool = ...) -> None: ... + def addPyFile(self, path: str) -> None: ... + def setCheckpointDir(self, dirName: str) -> None: ... + def setJobGroup( + self, groupId: str, description: str, interruptOnCancel: bool = ... + ) -> None: ... + def setLocalProperty(self, key: str, value: str) -> None: ... + def getLocalProperty(self, key: str) -> Optional[str]: ... + def sparkUser(self) -> str: ... + def setJobDescription(self, value: str) -> None: ... + def cancelJobGroup(self, groupId: str) -> None: ... + def cancelAllJobs(self) -> None: ... + def statusTracker(self) -> StatusTracker: ... + def runJob( + self, + rdd: RDD[T], + partitionFunc: Callable[[Iterable[T]], Iterable[U]], + partitions: Optional[List[int]] = ..., + allowLocal: bool = ..., + ) -> List[U]: ... + def show_profiles(self) -> None: ... + def dump_profiles(self, path: str) -> None: ... + def getConf(self) -> SparkConf: ... + @property + def resources(self) -> Dict[str, ResourceInformation]: ... diff --git a/python/pyspark/daemon.pyi b/python/pyspark/daemon.pyi new file mode 100644 index 0000000000000..dfacf30a9f8a7 --- /dev/null +++ b/python/pyspark/daemon.pyi @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pyspark.serializers import ( # noqa: F401 + UTF8Deserializer as UTF8Deserializer, + read_int as read_int, + write_int as write_int, + write_with_length as write_with_length, +) +from typing import Any + +def compute_real_exit_code(exit_code: Any): ... +def worker(sock: Any, authenticated: Any): ... +def manager() -> None: ... diff --git a/python/pyspark/files.pyi b/python/pyspark/files.pyi new file mode 100644 index 0000000000000..9e7cad17ebbdb --- /dev/null +++ b/python/pyspark/files.pyi @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +class SparkFiles: + def __init__(self) -> None: ... + @classmethod + def get(cls, filename: str) -> str: ... + @classmethod + def getRootDirectory(cls) -> str: ... diff --git a/python/pyspark/find_spark_home.pyi b/python/pyspark/find_spark_home.pyi new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/python/pyspark/find_spark_home.pyi @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/python/pyspark/java_gateway.pyi b/python/pyspark/java_gateway.pyi new file mode 100644 index 0000000000000..5b45206dc045c --- /dev/null +++ b/python/pyspark/java_gateway.pyi @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pyspark.serializers import UTF8Deserializer as UTF8Deserializer, read_int as read_int, write_with_length as write_with_length # type: ignore[attr-defined] +from typing import Any, Optional + +def launch_gateway(conf: Optional[Any] = ..., popen_kwargs: Optional[Any] = ...): ... +def local_connect_and_auth(port: Any, auth_secret: Any): ... +def ensure_callback_server_started(gw: Any) -> None: ... diff --git a/python/pyspark/join.pyi b/python/pyspark/join.pyi new file mode 100644 index 0000000000000..e89e0fbbcda9b --- /dev/null +++ b/python/pyspark/join.pyi @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Hashable, Iterable, Optional, Tuple, TypeVar + +from pyspark.resultiterable import ResultIterable +import pyspark.rdd + +K = TypeVar("K", bound=Hashable) +V = TypeVar("V") +U = TypeVar("U") + +def python_join( + rdd: pyspark.rdd.RDD[Tuple[K, V]], + other: pyspark.rdd.RDD[Tuple[K, U]], + numPartitions: int, +) -> pyspark.rdd.RDD[Tuple[K, Tuple[V, U]]]: ... +def python_right_outer_join( + rdd: pyspark.rdd.RDD[Tuple[K, V]], + other: pyspark.rdd.RDD[Tuple[K, U]], + numPartitions: int, +) -> pyspark.rdd.RDD[Tuple[K, Tuple[V, Optional[U]]]]: ... +def python_left_outer_join( + rdd: pyspark.rdd.RDD[Tuple[K, V]], + other: pyspark.rdd.RDD[Tuple[K, U]], + numPartitions: int, +) -> pyspark.rdd.RDD[Tuple[K, Tuple[Optional[V], U]]]: ... +def python_full_outer_join( + rdd: pyspark.rdd.RDD[Tuple[K, V]], + other: pyspark.rdd.RDD[Tuple[K, U]], + numPartitions: int, +) -> pyspark.rdd.RDD[Tuple[K, Tuple[Optional[V], Optional[U]]]]: ... +def python_cogroup( + rdds: Iterable[pyspark.rdd.RDD[Tuple[K, V]]], numPartitions: int +) -> pyspark.rdd.RDD[Tuple[K, Tuple[ResultIterable[V], ...]]]: ... diff --git a/python/pyspark/ml/__init__.pyi b/python/pyspark/ml/__init__.pyi new file mode 100644 index 0000000000000..8e3b8a5daeb08 --- /dev/null +++ b/python/pyspark/ml/__init__.pyi @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from pyspark.ml import ( # noqa: F401 + classification as classification, + clustering as clustering, + evaluation as evaluation, + feature as feature, + fpm as fpm, + image as image, + linalg as linalg, + param as param, + recommendation as recommendation, + regression as regression, + stat as stat, + tuning as tuning, + util as util, +) +from pyspark.ml.base import ( # noqa: F401 + Estimator as Estimator, + Model as Model, + PredictionModel as PredictionModel, + Predictor as Predictor, + Transformer as Transformer, + UnaryTransformer as UnaryTransformer, +) +from pyspark.ml.pipeline import ( # noqa: F401 + Pipeline as Pipeline, + PipelineModel as PipelineModel, +) diff --git a/python/pyspark/ml/_typing.pyi b/python/pyspark/ml/_typing.pyi new file mode 100644 index 0000000000000..d966a787c0fca --- /dev/null +++ b/python/pyspark/ml/_typing.pyi @@ -0,0 +1,76 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, TypeVar, Union +from typing_extensions import Literal + +import pyspark.ml.base +import pyspark.ml.param +import pyspark.ml.util +import pyspark.ml.wrapper + +ParamMap = Dict[pyspark.ml.param.Param, Any] +PipelineStage = Union[pyspark.ml.base.Estimator, pyspark.ml.base.Transformer] + +T = TypeVar("T") +P = TypeVar("P", bound=pyspark.ml.param.Params) +M = TypeVar("M", bound=pyspark.ml.base.Transformer) +JM = TypeVar("JM", bound=pyspark.ml.wrapper.JavaTransformer) + +BinaryClassificationEvaluatorMetricType = Union[ + Literal["areaUnderROC"], Literal["areaUnderPR"] +] +RegressionEvaluatorMetricType = Union[ + Literal["rmse"], Literal["mse"], Literal["r2"], Literal["mae"], Literal["var"] +] +MulticlassClassificationEvaluatorMetricType = Union[ + Literal["f1"], + Literal["accuracy"], + Literal["weightedPrecision"], + Literal["weightedRecall"], + Literal["weightedTruePositiveRate"], + Literal["weightedFalsePositiveRate"], + Literal["weightedFMeasure"], + Literal["truePositiveRateByLabel"], + Literal["falsePositiveRateByLabel"], + Literal["precisionByLabel"], + Literal["recallByLabel"], + Literal["fMeasureByLabel"], +] +MultilabelClassificationEvaluatorMetricType = Union[ + Literal["subsetAccuracy"], + Literal["accuracy"], + Literal["hammingLoss"], + Literal["precision"], + Literal["recall"], + Literal["f1Measure"], + Literal["precisionByLabel"], + Literal["recallByLabel"], + Literal["f1MeasureByLabel"], + Literal["microPrecision"], + Literal["microRecall"], + Literal["microF1Measure"], +] +ClusteringEvaluatorMetricType = Union[Literal["silhouette"]] +RankingEvaluatorMetricType = Union[ + Literal["meanAveragePrecision"], + Literal["meanAveragePrecisionAtK"], + Literal["precisionAtK"], + Literal["ndcgAtK"], + Literal["recallAtK"], +] diff --git a/python/pyspark/ml/base.pyi b/python/pyspark/ml/base.pyi new file mode 100644 
index 0000000000000..7fd8c3b70b672 --- /dev/null +++ b/python/pyspark/ml/base.pyi @@ -0,0 +1,103 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import overload +from typing import ( + Callable, + Generic, + Iterable, + List, + Optional, + Tuple, +) +from pyspark.ml._typing import M, P, T, ParamMap + +import _thread + +import abc +from abc import abstractmethod +from pyspark import since as since # noqa: F401 +from pyspark.ml.common import inherit_doc as inherit_doc # noqa: F401 +from pyspark.ml.param.shared import ( + HasFeaturesCol as HasFeaturesCol, + HasInputCol as HasInputCol, + HasLabelCol as HasLabelCol, + HasOutputCol as HasOutputCol, + HasPredictionCol as HasPredictionCol, + Params as Params, +) +from pyspark.sql.functions import udf as udf # noqa: F401 +from pyspark.sql.types import ( # noqa: F401 + DataType, + StructField as StructField, + StructType as StructType, +) + +from pyspark.sql.dataframe import DataFrame + +class _FitMultipleIterator: + fitSingleModel: Callable[[int], Transformer] + numModel: int + counter: int = ... + lock: _thread.LockType + def __init__( + self, fitSingleModel: Callable[[int], Transformer], numModels: int + ) -> None: ... + def __iter__(self) -> _FitMultipleIterator: ... + def __next__(self) -> Tuple[int, Transformer]: ... + def next(self) -> Tuple[int, Transformer]: ... + +class Estimator(Generic[M], Params, metaclass=abc.ABCMeta): + @overload + def fit(self, dataset: DataFrame, params: Optional[ParamMap] = ...) -> M: ... + @overload + def fit(self, dataset: DataFrame, params: List[ParamMap]) -> List[M]: ... + def fitMultiple( + self, dataset: DataFrame, params: List[ParamMap] + ) -> Iterable[Tuple[int, M]]: ... + +class Transformer(Params, metaclass=abc.ABCMeta): + def transform( + self, dataset: DataFrame, params: Optional[ParamMap] = ... + ) -> DataFrame: ... + +class Model(Transformer, metaclass=abc.ABCMeta): ... + +class UnaryTransformer(HasInputCol, HasOutputCol, Transformer, metaclass=abc.ABCMeta): + def createTransformFunc(self) -> Callable: ... + def outputDataType(self) -> DataType: ... + def validateInputType(self, inputType: DataType) -> None: ... + def transformSchema(self, schema: StructType) -> StructType: ... + def setInputCol(self: M, value: str) -> M: ... + def setOutputCol(self: M, value: str) -> M: ... + +class _PredictorParams(HasLabelCol, HasFeaturesCol, HasPredictionCol): ... + +class Predictor(Estimator[M], _PredictorParams, metaclass=abc.ABCMeta): + def setLabelCol(self: P, value: str) -> P: ... + def setFeaturesCol(self: P, value: str) -> P: ... + def setPredictionCol(self: P, value: str) -> P: ... + +class PredictionModel(Generic[T], Model, _PredictorParams, metaclass=abc.ABCMeta): + def setFeaturesCol(self: M, value: str) -> M: ... 
+ def setPredictionCol(self: M, value: str) -> M: ... + @property + @abc.abstractmethod + def numFeatures(self) -> int: ... + @abstractmethod + def predict(self, value: T) -> float: ... diff --git a/python/pyspark/ml/classification.pyi b/python/pyspark/ml/classification.pyi new file mode 100644 index 0000000000000..55afc20a54cb9 --- /dev/null +++ b/python/pyspark/ml/classification.pyi @@ -0,0 +1,922 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, List, Optional +from pyspark.ml._typing import JM, M, P, T, ParamMap + +import abc +from abc import abstractmethod +from pyspark.ml import Estimator, Model, PredictionModel, Predictor, Transformer +from pyspark.ml.base import _PredictorParams +from pyspark.ml.param.shared import ( + HasAggregationDepth, + HasBlockSize, + HasElasticNetParam, + HasFitIntercept, + HasMaxIter, + HasParallelism, + HasProbabilityCol, + HasRawPredictionCol, + HasRegParam, + HasSeed, + HasSolver, + HasStandardization, + HasStepSize, + HasThreshold, + HasThresholds, + HasTol, + HasWeightCol, +) +from pyspark.ml.regression import _FactorizationMachinesParams +from pyspark.ml.tree import ( + _DecisionTreeModel, + _DecisionTreeParams, + _GBTParams, + _HasVarianceImpurity, + _RandomForestParams, + _TreeClassifierParams, + _TreeEnsembleModel, +) +from pyspark.ml.util import HasTrainingSummary, JavaMLReadable, JavaMLWritable +from pyspark.ml.wrapper import JavaPredictionModel, JavaPredictor, JavaWrapper + +from pyspark.ml.linalg import Matrix, Vector +from pyspark.ml.param import Param +from pyspark.ml.regression import DecisionTreeRegressionModel +from pyspark.sql.dataframe import DataFrame + +class _ClassifierParams(HasRawPredictionCol, _PredictorParams): ... + +class Classifier(Predictor, _ClassifierParams, metaclass=abc.ABCMeta): + def setRawPredictionCol(self: P, value: str) -> P: ... + +class ClassificationModel(PredictionModel, _ClassifierParams, metaclass=abc.ABCMeta): + def setRawPredictionCol(self: P, value: str) -> P: ... + @property + @abc.abstractmethod + def numClasses(self) -> int: ... + @abstractmethod + def predictRaw(self, value: Vector) -> Vector: ... + +class _ProbabilisticClassifierParams( + HasProbabilityCol, HasThresholds, _ClassifierParams +): ... + +class ProbabilisticClassifier( + Classifier, _ProbabilisticClassifierParams, metaclass=abc.ABCMeta +): + def setProbabilityCol(self: P, value: str) -> P: ... + def setThresholds(self: P, value: List[float]) -> P: ... + +class ProbabilisticClassificationModel( + ClassificationModel, _ProbabilisticClassifierParams, metaclass=abc.ABCMeta +): + def setProbabilityCol(self: M, value: str) -> M: ... + def setThresholds(self: M, value: List[float]) -> M: ... + @abstractmethod + def predictProbability(self, value: Vector) -> Vector: ... 
+ +class _JavaClassifier(Classifier, JavaPredictor[JM], metaclass=abc.ABCMeta): + def setRawPredictionCol(self: P, value: str) -> P: ... + +class _JavaClassificationModel(ClassificationModel, JavaPredictionModel[T]): + @property + def numClasses(self) -> int: ... + def predictRaw(self, value: Vector) -> Vector: ... + +class _JavaProbabilisticClassifier( + ProbabilisticClassifier, _JavaClassifier[JM], metaclass=abc.ABCMeta +): ... + +class _JavaProbabilisticClassificationModel( + ProbabilisticClassificationModel, _JavaClassificationModel[T] +): + def predictProbability(self, value: Any): ... + +class _ClassificationSummary(JavaWrapper): + @property + def predictions(self) -> DataFrame: ... + @property + def predictionCol(self) -> str: ... + @property + def labelCol(self) -> str: ... + @property + def weightCol(self) -> str: ... + @property + def labels(self) -> List[str]: ... + @property + def truePositiveRateByLabel(self) -> List[float]: ... + @property + def falsePositiveRateByLabel(self) -> List[float]: ... + @property + def precisionByLabel(self) -> List[float]: ... + @property + def recallByLabel(self) -> List[float]: ... + def fMeasureByLabel(self, beta: float = ...) -> List[float]: ... + @property + def accuracy(self) -> float: ... + @property + def weightedTruePositiveRate(self) -> float: ... + @property + def weightedFalsePositiveRate(self) -> float: ... + @property + def weightedRecall(self) -> float: ... + @property + def weightedPrecision(self) -> float: ... + def weightedFMeasure(self, beta: float = ...) -> float: ... + +class _TrainingSummary(JavaWrapper): + @property + def objectiveHistory(self) -> List[float]: ... + @property + def totalIterations(self) -> int: ... + +class _BinaryClassificationSummary(_ClassificationSummary): + @property + def scoreCol(self) -> str: ... + @property + def roc(self) -> DataFrame: ... + @property + def areaUnderROC(self) -> float: ... + @property + def pr(self) -> DataFrame: ... + @property + def fMeasureByThreshold(self) -> DataFrame: ... + @property + def precisionByThreshold(self) -> DataFrame: ... + @property + def recallByThreshold(self) -> DataFrame: ... + +class _LinearSVCParams( + _ClassifierParams, + HasRegParam, + HasMaxIter, + HasFitIntercept, + HasTol, + HasStandardization, + HasWeightCol, + HasAggregationDepth, + HasThreshold, + HasBlockSize, +): + threshold: Param[float] + def __init__(self, *args: Any) -> None: ... + +class LinearSVC( + _JavaClassifier[LinearSVCModel], + _LinearSVCParams, + JavaMLWritable, + JavaMLReadable[LinearSVC], +): + def __init__( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + maxIter: int = ..., + regParam: float = ..., + tol: float = ..., + rawPredictionCol: str = ..., + fitIntercept: bool = ..., + standardization: bool = ..., + threshold: float = ..., + weightCol: Optional[str] = ..., + aggregationDepth: int = ..., + blockSize: int = ... + ) -> None: ... + def setParams( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + maxIter: int = ..., + regParam: float = ..., + tol: float = ..., + rawPredictionCol: str = ..., + fitIntercept: bool = ..., + standardization: bool = ..., + threshold: float = ..., + weightCol: Optional[str] = ..., + aggregationDepth: int = ..., + blockSize: int = ... + ) -> LinearSVC: ... + def setMaxIter(self, value: int) -> LinearSVC: ... + def setRegParam(self, value: float) -> LinearSVC: ... + def setTol(self, value: float) -> LinearSVC: ... 
+ def setFitIntercept(self, value: bool) -> LinearSVC: ... + def setStandardization(self, value: bool) -> LinearSVC: ... + def setThreshold(self, value: float) -> LinearSVC: ... + def setWeightCol(self, value: str) -> LinearSVC: ... + def setAggregationDepth(self, value: int) -> LinearSVC: ... + def setBlockSize(self, value: int) -> LinearSVC: ... + +class LinearSVCModel( + _JavaClassificationModel[Vector], + _LinearSVCParams, + JavaMLWritable, + JavaMLReadable[LinearSVCModel], + HasTrainingSummary[LinearSVCTrainingSummary], +): + def setThreshold(self, value: float) -> LinearSVCModel: ... + @property + def coefficients(self) -> Vector: ... + @property + def intercept(self) -> float: ... + def summary(self) -> LinearSVCTrainingSummary: ... + def evaluate(self, dataset: DataFrame) -> LinearSVCSummary: ... + +class LinearSVCSummary(_BinaryClassificationSummary): ... +class LinearSVCTrainingSummary(LinearSVCSummary, _TrainingSummary): ... + +class _LogisticRegressionParams( + _ProbabilisticClassifierParams, + HasRegParam, + HasElasticNetParam, + HasMaxIter, + HasFitIntercept, + HasTol, + HasStandardization, + HasWeightCol, + HasAggregationDepth, + HasThreshold, + HasBlockSize, +): + threshold: Param[float] + family: Param[str] + lowerBoundsOnCoefficients: Param[Matrix] + upperBoundsOnCoefficients: Param[Matrix] + lowerBoundsOnIntercepts: Param[Vector] + upperBoundsOnIntercepts: Param[Vector] + def __init__(self, *args: Any): ... + def setThreshold(self: P, value: float) -> P: ... + def getThreshold(self) -> float: ... + def setThresholds(self: P, value: List[float]) -> P: ... + def getThresholds(self) -> List[float]: ... + def getFamily(self) -> str: ... + def getLowerBoundsOnCoefficients(self) -> Matrix: ... + def getUpperBoundsOnCoefficients(self) -> Matrix: ... + def getLowerBoundsOnIntercepts(self) -> Vector: ... + def getUpperBoundsOnIntercepts(self) -> Vector: ... + +class LogisticRegression( + _JavaProbabilisticClassifier[LogisticRegressionModel], + _LogisticRegressionParams, + JavaMLWritable, + JavaMLReadable[LogisticRegression], +): + def __init__( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + maxIter: int = ..., + regParam: float = ..., + elasticNetParam: float = ..., + tol: float = ..., + fitIntercept: bool = ..., + threshold: float = ..., + thresholds: Optional[List[float]] = ..., + probabilityCol: str = ..., + rawPredictionCol: str = ..., + standardization: bool = ..., + weightCol: Optional[str] = ..., + aggregationDepth: int = ..., + family: str = ..., + lowerBoundsOnCoefficients: Optional[Matrix] = ..., + upperBoundsOnCoefficients: Optional[Matrix] = ..., + lowerBoundsOnIntercepts: Optional[Vector] = ..., + upperBoundsOnIntercepts: Optional[Vector] = ..., + blockSize: int = ... + ) -> None: ... + def setParams( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + maxIter: int = ..., + regParam: float = ..., + elasticNetParam: float = ..., + tol: float = ..., + fitIntercept: bool = ..., + threshold: float = ..., + thresholds: Optional[List[float]] = ..., + probabilityCol: str = ..., + rawPredictionCol: str = ..., + standardization: bool = ..., + weightCol: Optional[str] = ..., + aggregationDepth: int = ..., + family: str = ..., + lowerBoundsOnCoefficients: Optional[Matrix] = ..., + upperBoundsOnCoefficients: Optional[Matrix] = ..., + lowerBoundsOnIntercepts: Optional[Vector] = ..., + upperBoundsOnIntercepts: Optional[Vector] = ..., + blockSize: int = ... + ) -> LogisticRegression: ... 
+ def setFamily(self, value: str) -> LogisticRegression: ... + def setLowerBoundsOnCoefficients(self, value: Matrix) -> LogisticRegression: ... + def setUpperBoundsOnCoefficients(self, value: Matrix) -> LogisticRegression: ... + def setLowerBoundsOnIntercepts(self, value: Vector) -> LogisticRegression: ... + def setUpperBoundsOnIntercepts(self, value: Vector) -> LogisticRegression: ... + def setMaxIter(self, value: int) -> LogisticRegression: ... + def setRegParam(self, value: float) -> LogisticRegression: ... + def setTol(self, value: float) -> LogisticRegression: ... + def setElasticNetParam(self, value: float) -> LogisticRegression: ... + def setFitIntercept(self, value: bool) -> LogisticRegression: ... + def setStandardization(self, value: bool) -> LogisticRegression: ... + def setWeightCol(self, value: str) -> LogisticRegression: ... + def setAggregationDepth(self, value: int) -> LogisticRegression: ... + def setBlockSize(self, value: int) -> LogisticRegression: ... + +class LogisticRegressionModel( + _JavaProbabilisticClassificationModel[Vector], + _LogisticRegressionParams, + JavaMLWritable, + JavaMLReadable[LogisticRegressionModel], + HasTrainingSummary[LogisticRegressionTrainingSummary], +): + @property + def coefficients(self) -> Vector: ... + @property + def intercept(self) -> float: ... + @property + def coefficientMatrix(self) -> Matrix: ... + @property + def interceptVector(self) -> Vector: ... + @property + def summary(self) -> LogisticRegressionTrainingSummary: ... + def evaluate(self, dataset: DataFrame) -> LogisticRegressionSummary: ... + +class LogisticRegressionSummary(_ClassificationSummary): + @property + def probabilityCol(self) -> str: ... + @property + def featuresCol(self) -> str: ... + +class LogisticRegressionTrainingSummary( + LogisticRegressionSummary, _TrainingSummary +): ... +class BinaryLogisticRegressionSummary( + _BinaryClassificationSummary, LogisticRegressionSummary +): ... +class BinaryLogisticRegressionTrainingSummary( + BinaryLogisticRegressionSummary, LogisticRegressionTrainingSummary +): ... + +class _DecisionTreeClassifierParams(_DecisionTreeParams, _TreeClassifierParams): + def __init__(self, *args: Any): ... + +class DecisionTreeClassifier( + _JavaProbabilisticClassifier[DecisionTreeClassificationModel], + _DecisionTreeClassifierParams, + JavaMLWritable, + JavaMLReadable[DecisionTreeClassifier], +): + def __init__( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + probabilityCol: str = ..., + rawPredictionCol: str = ..., + maxDepth: int = ..., + maxBins: int = ..., + minInstancesPerNode: int = ..., + minInfoGain: float = ..., + maxMemoryInMB: int = ..., + cacheNodeIds: bool = ..., + checkpointInterval: int = ..., + impurity: str = ..., + seed: Optional[int] = ..., + weightCol: Optional[str] = ..., + leafCol: str = ..., + minWeightFractionPerNode: float = ... + ) -> None: ... + def setParams( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + probabilityCol: str = ..., + rawPredictionCol: str = ..., + maxDepth: int = ..., + maxBins: int = ..., + minInstancesPerNode: int = ..., + minInfoGain: float = ..., + maxMemoryInMB: int = ..., + cacheNodeIds: bool = ..., + checkpointInterval: int = ..., + impurity: str = ..., + seed: Optional[int] = ..., + weightCol: Optional[str] = ..., + leafCol: str = ..., + minWeightFractionPerNode: float = ... + ) -> DecisionTreeClassifier: ... + def setMaxDepth(self, value: int) -> DecisionTreeClassifier: ... 
+ def setMaxBins(self, value: int) -> DecisionTreeClassifier: ... + def setMinInstancesPerNode(self, value: int) -> DecisionTreeClassifier: ... + def setMinWeightFractionPerNode(self, value: float) -> DecisionTreeClassifier: ... + def setMinInfoGain(self, value: float) -> DecisionTreeClassifier: ... + def setMaxMemoryInMB(self, value: int) -> DecisionTreeClassifier: ... + def setCacheNodeIds(self, value: bool) -> DecisionTreeClassifier: ... + def setImpurity(self, value: str) -> DecisionTreeClassifier: ... + def setCheckpointInterval(self, value: int) -> DecisionTreeClassifier: ... + def setSeed(self, value: int) -> DecisionTreeClassifier: ... + def setWeightCol(self, value: str) -> DecisionTreeClassifier: ... + +class DecisionTreeClassificationModel( + _DecisionTreeModel, + _JavaProbabilisticClassificationModel[Vector], + _DecisionTreeClassifierParams, + JavaMLWritable, + JavaMLReadable[DecisionTreeClassificationModel], +): + @property + def featureImportances(self) -> Vector: ... + +class _RandomForestClassifierParams(_RandomForestParams, _TreeClassifierParams): + def __init__(self, *args: Any): ... + +class RandomForestClassifier( + _JavaProbabilisticClassifier[RandomForestClassificationModel], + _RandomForestClassifierParams, + JavaMLWritable, + JavaMLReadable[RandomForestClassifier], +): + def __init__( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + probabilityCol: str = ..., + rawPredictionCol: str = ..., + maxDepth: int = ..., + maxBins: int = ..., + minInstancesPerNode: int = ..., + minInfoGain: float = ..., + maxMemoryInMB: int = ..., + cacheNodeIds: bool = ..., + checkpointInterval: int = ..., + impurity: str = ..., + numTrees: int = ..., + featureSubsetStrategy: str = ..., + seed: Optional[int] = ..., + subsamplingRate: float = ..., + leafCol: str = ..., + minWeightFractionPerNode: float = ..., + weightCol: Optional[str] = ..., + bootstrap: Optional[bool] = ... + ) -> None: ... + def setParams( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + probabilityCol: str = ..., + rawPredictionCol: str = ..., + maxDepth: int = ..., + maxBins: int = ..., + minInstancesPerNode: int = ..., + minInfoGain: float = ..., + maxMemoryInMB: int = ..., + cacheNodeIds: bool = ..., + checkpointInterval: int = ..., + seed: Optional[int] = ..., + impurity: str = ..., + numTrees: int = ..., + featureSubsetStrategy: str = ..., + subsamplingRate: float = ..., + leafCol: str = ..., + minWeightFractionPerNode: float = ..., + weightCol: Optional[str] = ..., + bootstrap: Optional[bool] = ... + ) -> RandomForestClassifier: ... + def setMaxDepth(self, value: int) -> RandomForestClassifier: ... + def setMaxBins(self, value: int) -> RandomForestClassifier: ... + def setMinInstancesPerNode(self, value: int) -> RandomForestClassifier: ... + def setMinInfoGain(self, value: float) -> RandomForestClassifier: ... + def setMaxMemoryInMB(self, value: int) -> RandomForestClassifier: ... + def setCacheNodeIds(self, value: bool) -> RandomForestClassifier: ... + def setImpurity(self, value: str) -> RandomForestClassifier: ... + def setNumTrees(self, value: int) -> RandomForestClassifier: ... + def setBootstrap(self, value: bool) -> RandomForestClassifier: ... + def setSubsamplingRate(self, value: float) -> RandomForestClassifier: ... + def setFeatureSubsetStrategy(self, value: str) -> RandomForestClassifier: ... + def setSeed(self, value: int) -> RandomForestClassifier: ... 
+ def setCheckpointInterval(self, value: int) -> RandomForestClassifier: ... + def setWeightCol(self, value: str) -> RandomForestClassifier: ... + def setMinWeightFractionPerNode(self, value: float) -> RandomForestClassifier: ... + +class RandomForestClassificationModel( + _TreeEnsembleModel, + _JavaProbabilisticClassificationModel[Vector], + _RandomForestClassifierParams, + JavaMLWritable, + JavaMLReadable[RandomForestClassificationModel], + HasTrainingSummary[RandomForestClassificationTrainingSummary], +): + @property + def featureImportances(self) -> Vector: ... + @property + def trees(self) -> List[DecisionTreeClassificationModel]: ... + def summary(self) -> RandomForestClassificationTrainingSummary: ... + def evaluate(self, dataset) -> RandomForestClassificationSummary: ... + +class RandomForestClassificationSummary(_ClassificationSummary): ... +class RandomForestClassificationTrainingSummary( + RandomForestClassificationSummary, _TrainingSummary +): ... +class BinaryRandomForestClassificationSummary(_BinaryClassificationSummary): ... +class BinaryRandomForestClassificationTrainingSummary( + BinaryRandomForestClassificationSummary, RandomForestClassificationTrainingSummary +): ... + +class _GBTClassifierParams(_GBTParams, _HasVarianceImpurity): + supportedLossTypes: List[str] + lossType: Param[str] + def __init__(self, *args: Any): ... + def getLossType(self) -> str: ... + +class GBTClassifier( + _JavaProbabilisticClassifier[GBTClassificationModel], + _GBTClassifierParams, + JavaMLWritable, + JavaMLReadable[GBTClassifier], +): + def __init__( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + maxDepth: int = ..., + maxBins: int = ..., + minInstancesPerNode: int = ..., + minInfoGain: float = ..., + maxMemoryInMB: int = ..., + cacheNodeIds: bool = ..., + checkpointInterval: int = ..., + lossType: str = ..., + maxIter: int = ..., + stepSize: float = ..., + seed: Optional[int] = ..., + subsamplingRate: float = ..., + featureSubsetStrategy: str = ..., + validationTol: float = ..., + validationIndicatorCol: Optional[str] = ..., + leafCol: str = ..., + minWeightFractionPerNode: float = ..., + weightCol: Optional[str] = ... + ) -> None: ... + def setParams( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + maxDepth: int = ..., + maxBins: int = ..., + minInstancesPerNode: int = ..., + minInfoGain: float = ..., + maxMemoryInMB: int = ..., + cacheNodeIds: bool = ..., + checkpointInterval: int = ..., + lossType: str = ..., + maxIter: int = ..., + stepSize: float = ..., + seed: Optional[int] = ..., + subsamplingRate: float = ..., + featureSubsetStrategy: str = ..., + validationTol: float = ..., + validationIndicatorCol: Optional[str] = ..., + leafCol: str = ..., + minWeightFractionPerNode: float = ..., + weightCol: Optional[str] = ... + ) -> GBTClassifier: ... + def setMaxDepth(self, value: int) -> GBTClassifier: ... + def setMaxBins(self, value: int) -> GBTClassifier: ... + def setMinInstancesPerNode(self, value: int) -> GBTClassifier: ... + def setMinInfoGain(self, value: float) -> GBTClassifier: ... + def setMaxMemoryInMB(self, value: int) -> GBTClassifier: ... + def setCacheNodeIds(self, value: bool) -> GBTClassifier: ... + def setImpurity(self, value: str) -> GBTClassifier: ... + def setLossType(self, value: str) -> GBTClassifier: ... + def setSubsamplingRate(self, value: float) -> GBTClassifier: ... + def setFeatureSubsetStrategy(self, value: str) -> GBTClassifier: ... 
+ def setValidationIndicatorCol(self, value: str) -> GBTClassifier: ... + def setMaxIter(self, value: int) -> GBTClassifier: ... + def setCheckpointInterval(self, value: int) -> GBTClassifier: ... + def setSeed(self, value: int) -> GBTClassifier: ... + def setStepSize(self, value: float) -> GBTClassifier: ... + def setWeightCol(self, value: str) -> GBTClassifier: ... + def setMinWeightFractionPerNode(self, value: float) -> GBTClassifier: ... + +class GBTClassificationModel( + _TreeEnsembleModel, + _JavaProbabilisticClassificationModel[Vector], + _GBTClassifierParams, + JavaMLWritable, + JavaMLReadable[GBTClassificationModel], +): + @property + def featureImportances(self) -> Vector: ... + @property + def trees(self) -> List[DecisionTreeRegressionModel]: ... + def evaluateEachIteration(self, dataset: DataFrame) -> List[float]: ... + +class _NaiveBayesParams(_PredictorParams, HasWeightCol): + smoothing: Param[float] + modelType: Param[str] + def __init__(self, *args: Any): ... + def getSmoothing(self) -> float: ... + def getModelType(self) -> str: ... + +class NaiveBayes( + _JavaProbabilisticClassifier[NaiveBayesModel], + _NaiveBayesParams, + HasThresholds, + HasWeightCol, + JavaMLWritable, + JavaMLReadable[NaiveBayes], +): + def __init__( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + probabilityCol: str = ..., + rawPredictionCol: str = ..., + smoothing: float = ..., + modelType: str = ..., + thresholds: Optional[List[float]] = ..., + weightCol: Optional[str] = ... + ) -> None: ... + def setParams( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + probabilityCol: str = ..., + rawPredictionCol: str = ..., + smoothing: float = ..., + modelType: str = ..., + thresholds: Optional[List[float]] = ..., + weightCol: Optional[str] = ... + ) -> NaiveBayes: ... + def setSmoothing(self, value: float) -> NaiveBayes: ... + def setModelType(self, value: str) -> NaiveBayes: ... + def setWeightCol(self, value: str) -> NaiveBayes: ... + +class NaiveBayesModel( + _JavaProbabilisticClassificationModel[Vector], + _NaiveBayesParams, + JavaMLWritable, + JavaMLReadable[NaiveBayesModel], +): + @property + def pi(self) -> Vector: ... + @property + def theta(self) -> Matrix: ... + @property + def sigma(self) -> Matrix: ... + +class _MultilayerPerceptronParams( + _ProbabilisticClassifierParams, + HasSeed, + HasMaxIter, + HasTol, + HasStepSize, + HasSolver, + HasBlockSize, +): + layers: Param[List[int]] + solver: Param[str] + initialWeights: Param[Vector] + def __init__(self, *args: Any): ... + def getLayers(self) -> List[int]: ... + def getInitialWeights(self) -> Vector: ... + +class MultilayerPerceptronClassifier( + _JavaProbabilisticClassifier[MultilayerPerceptronClassificationModel], + _MultilayerPerceptronParams, + JavaMLWritable, + JavaMLReadable[MultilayerPerceptronClassifier], +): + def __init__( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + maxIter: int = ..., + tol: float = ..., + seed: Optional[int] = ..., + layers: Optional[List[int]] = ..., + blockSize: int = ..., + stepSize: float = ..., + solver: str = ..., + initialWeights: Optional[Vector] = ..., + probabilityCol: str = ..., + rawPredictionCol: str = ... + ) -> None: ... 
+ def setParams( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + maxIter: int = ..., + tol: float = ..., + seed: Optional[int] = ..., + layers: Optional[List[int]] = ..., + blockSize: int = ..., + stepSize: float = ..., + solver: str = ..., + initialWeights: Optional[Vector] = ..., + probabilityCol: str = ..., + rawPredictionCol: str = ... + ) -> MultilayerPerceptronClassifier: ... + def setLayers(self, value: List[int]) -> MultilayerPerceptronClassifier: ... + def setBlockSize(self, value: int) -> MultilayerPerceptronClassifier: ... + def setInitialWeights(self, value: Vector) -> MultilayerPerceptronClassifier: ... + def setMaxIter(self, value: int) -> MultilayerPerceptronClassifier: ... + def setSeed(self, value: int) -> MultilayerPerceptronClassifier: ... + def setTol(self, value: float) -> MultilayerPerceptronClassifier: ... + def setStepSize(self, value: float) -> MultilayerPerceptronClassifier: ... + def setSolver(self, value: str) -> MultilayerPerceptronClassifier: ... + +class MultilayerPerceptronClassificationModel( + _JavaProbabilisticClassificationModel[Vector], + _MultilayerPerceptronParams, + JavaMLWritable, + JavaMLReadable[MultilayerPerceptronClassificationModel], + HasTrainingSummary[MultilayerPerceptronClassificationTrainingSummary], +): + @property + def weights(self) -> Vector: ... + def summary(self) -> MultilayerPerceptronClassificationTrainingSummary: ... + def evaluate( + self, dataset: DataFrame + ) -> MultilayerPerceptronClassificationSummary: ... + +class MultilayerPerceptronClassificationSummary(_ClassificationSummary): ... +class MultilayerPerceptronClassificationTrainingSummary( + MultilayerPerceptronClassificationSummary, _TrainingSummary +): ... + +class _OneVsRestParams(_ClassifierParams, HasWeightCol): + classifier: Param[Estimator] + def getClassifier(self) -> Estimator[M]: ... + +class OneVsRest( + Estimator[OneVsRestModel], + _OneVsRestParams, + HasParallelism, + JavaMLReadable[OneVsRest], + JavaMLWritable, +): + def __init__( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + rawPredictionCol: str = ..., + classifier: Optional[Estimator[M]] = ..., + weightCol: Optional[str] = ..., + parallelism: int = ... + ) -> None: ... + def setParams( + self, + *, + featuresCol: Optional[str] = ..., + labelCol: Optional[str] = ..., + predictionCol: Optional[str] = ..., + rawPredictionCol: str = ..., + classifier: Optional[Estimator[M]] = ..., + weightCol: Optional[str] = ..., + parallelism: int = ... + ) -> OneVsRest: ... + def setClassifier(self, value: Estimator[M]) -> OneVsRest: ... + def setLabelCol(self, value: str) -> OneVsRest: ... + def setFeaturesCol(self, value: str) -> OneVsRest: ... + def setPredictionCol(self, value: str) -> OneVsRest: ... + def setRawPredictionCol(self, value: str) -> OneVsRest: ... + def setWeightCol(self, value: str) -> OneVsRest: ... + def setParallelism(self, value: int) -> OneVsRest: ... + def copy(self, extra: Optional[ParamMap] = ...) -> OneVsRest: ... + +class OneVsRestModel( + Model, _OneVsRestParams, JavaMLReadable[OneVsRestModel], JavaMLWritable +): + models: List[Transformer] + def __init__(self, models: List[Transformer]) -> None: ... + def setFeaturesCol(self, value: str) -> OneVsRestModel: ... + def setPredictionCol(self, value: str) -> OneVsRestModel: ... + def setRawPredictionCol(self, value: str) -> OneVsRestModel: ... + def copy(self, extra: Optional[ParamMap] = ...) -> OneVsRestModel: ... 
+ +class FMClassifier( + _JavaProbabilisticClassifier[FMClassificationModel], + _FactorizationMachinesParams, + JavaMLWritable, + JavaMLReadable[FMClassifier], +): + factorSize: Param[int] + fitLinear: Param[bool] + miniBatchFraction: Param[float] + initStd: Param[float] + solver: Param[str] + def __init__( + self, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + probabilityCol: str = ..., + rawPredictionCol: str = ..., + factorSize: int = ..., + fitIntercept: bool = ..., + fitLinear: bool = ..., + regParam: float = ..., + miniBatchFraction: float = ..., + initStd: float = ..., + maxIter: int = ..., + stepSize: float = ..., + tol: float = ..., + solver: str = ..., + thresholds: Optional[Any] = ..., + seed: Optional[Any] = ..., + ) -> None: ... + def setParams( + self, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + probabilityCol: str = ..., + rawPredictionCol: str = ..., + factorSize: int = ..., + fitIntercept: bool = ..., + fitLinear: bool = ..., + regParam: float = ..., + miniBatchFraction: float = ..., + initStd: float = ..., + maxIter: int = ..., + stepSize: float = ..., + tol: float = ..., + solver: str = ..., + thresholds: Optional[Any] = ..., + seed: Optional[Any] = ..., + ): ... + def setFactorSize(self, value: int) -> FMClassifier: ... + def setFitLinear(self, value: bool) -> FMClassifier: ... + def setMiniBatchFraction(self, value: float) -> FMClassifier: ... + def setInitStd(self, value: float) -> FMClassifier: ... + def setMaxIter(self, value: int) -> FMClassifier: ... + def setStepSize(self, value: float) -> FMClassifier: ... + def setTol(self, value: float) -> FMClassifier: ... + def setSolver(self, value: str) -> FMClassifier: ... + def setSeed(self, value: int) -> FMClassifier: ... + def setFitIntercept(self, value: bool) -> FMClassifier: ... + def setRegParam(self, value: float) -> FMClassifier: ... + +class FMClassificationModel( + _JavaProbabilisticClassificationModel[Vector], + _FactorizationMachinesParams, + JavaMLWritable, + JavaMLReadable[FMClassificationModel], +): + @property + def intercept(self) -> float: ... + @property + def linear(self) -> Vector: ... + @property + def factors(self) -> Matrix: ... + def summary(self) -> FMClassificationTrainingSummary: ... + def evaluate(self, dataset: DataFrame) -> FMClassificationSummary: ... + +class FMClassificationSummary(_BinaryClassificationSummary): ... +class FMClassificationTrainingSummary(FMClassificationSummary, _TrainingSummary): ... diff --git a/python/pyspark/ml/clustering.pyi b/python/pyspark/ml/clustering.pyi new file mode 100644 index 0000000000000..e2a2d7e888367 --- /dev/null +++ b/python/pyspark/ml/clustering.pyi @@ -0,0 +1,437 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Any, List, Optional + +from pyspark.ml.linalg import Matrix, Vector +from pyspark.ml.util import ( + GeneralJavaMLWritable, + HasTrainingSummary, + JavaMLReadable, + JavaMLWritable, +) +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaWrapper +from pyspark.ml.param.shared import ( + HasAggregationDepth, + HasCheckpointInterval, + HasDistanceMeasure, + HasFeaturesCol, + HasMaxIter, + HasPredictionCol, + HasProbabilityCol, + HasSeed, + HasTol, + HasWeightCol, +) + +from pyspark.ml.param import Param +from pyspark.ml.stat import MultivariateGaussian +from pyspark.sql.dataframe import DataFrame + +from numpy import ndarray # type: ignore[import] + +class ClusteringSummary(JavaWrapper): + @property + def predictionCol(self) -> str: ... + @property + def predictions(self) -> DataFrame: ... + @property + def featuresCol(self) -> str: ... + @property + def k(self) -> int: ... + @property + def cluster(self) -> DataFrame: ... + @property + def clusterSizes(self) -> List[int]: ... + @property + def numIter(self) -> int: ... + +class _GaussianMixtureParams( + HasMaxIter, + HasFeaturesCol, + HasSeed, + HasPredictionCol, + HasProbabilityCol, + HasTol, + HasAggregationDepth, + HasWeightCol, +): + k: Param[int] + def __init__(self, *args: Any): ... + def getK(self) -> int: ... + +class GaussianMixtureModel( + JavaModel, + _GaussianMixtureParams, + JavaMLWritable, + JavaMLReadable[GaussianMixtureModel], + HasTrainingSummary[GaussianMixtureSummary], +): + def setFeaturesCol(self, value: str) -> GaussianMixtureModel: ... + def setPredictionCol(self, value: str) -> GaussianMixtureModel: ... + def setProbabilityCol(self, value: str) -> GaussianMixtureModel: ... + @property + def weights(self) -> List[float]: ... + @property + def gaussians(self) -> List[MultivariateGaussian]: ... + @property + def gaussiansDF(self) -> DataFrame: ... + @property + def summary(self) -> GaussianMixtureSummary: ... + def predict(self, value: Vector) -> int: ... + def predictProbability(self, value: Vector) -> Vector: ... + +class GaussianMixture( + JavaEstimator[GaussianMixtureModel], + _GaussianMixtureParams, + JavaMLWritable, + JavaMLReadable[GaussianMixture], +): + def __init__( + self, + *, + featuresCol: str = ..., + predictionCol: str = ..., + k: int = ..., + probabilityCol: str = ..., + tol: float = ..., + maxIter: int = ..., + seed: Optional[int] = ..., + aggregationDepth: int = ..., + weightCol: Optional[str] = ... + ) -> None: ... + def setParams( + self, + *, + featuresCol: str = ..., + predictionCol: str = ..., + k: int = ..., + probabilityCol: str = ..., + tol: float = ..., + maxIter: int = ..., + seed: Optional[int] = ..., + aggregationDepth: int = ..., + weightCol: Optional[str] = ... + ) -> GaussianMixture: ... + def setK(self, value: int) -> GaussianMixture: ... + def setMaxIter(self, value: int) -> GaussianMixture: ... + def setFeaturesCol(self, value: str) -> GaussianMixture: ... + def setPredictionCol(self, value: str) -> GaussianMixture: ... + def setProbabilityCol(self, value: str) -> GaussianMixture: ... + def setWeightCol(self, value: str) -> GaussianMixture: ... + def setSeed(self, value: int) -> GaussianMixture: ... + def setTol(self, value: float) -> GaussianMixture: ... + def setAggregationDepth(self, value: int) -> GaussianMixture: ... + +class GaussianMixtureSummary(ClusteringSummary): + @property + def probabilityCol(self) -> str: ... + @property + def probability(self) -> DataFrame: ... + @property + def logLikelihood(self) -> float: ... 
+ +class KMeansSummary(ClusteringSummary): + def trainingCost(self) -> float: ... + +class _KMeansParams( + HasMaxIter, + HasFeaturesCol, + HasSeed, + HasPredictionCol, + HasTol, + HasDistanceMeasure, + HasWeightCol, +): + k: Param[int] + initMode: Param[str] + initSteps: Param[int] + def __init__(self, *args: Any): ... + def getK(self) -> int: ... + def getInitMode(self) -> str: ... + def getInitSteps(self) -> int: ... + +class KMeansModel( + JavaModel, + _KMeansParams, + GeneralJavaMLWritable, + JavaMLReadable[KMeansModel], + HasTrainingSummary[KMeansSummary], +): + def setFeaturesCol(self, value: str) -> KMeansModel: ... + def setPredictionCol(self, value: str) -> KMeansModel: ... + def clusterCenters(self) -> List[ndarray]: ... + @property + def summary(self) -> KMeansSummary: ... + def predict(self, value: Vector) -> int: ... + +class KMeans( + JavaEstimator[KMeansModel], _KMeansParams, JavaMLWritable, JavaMLReadable[KMeans] +): + def __init__( + self, + *, + featuresCol: str = ..., + predictionCol: str = ..., + k: int = ..., + initMode: str = ..., + initSteps: int = ..., + tol: float = ..., + maxIter: int = ..., + seed: Optional[int] = ..., + distanceMeasure: str = ..., + weightCol: Optional[str] = ... + ) -> None: ... + def setParams( + self, + *, + featuresCol: str = ..., + predictionCol: str = ..., + k: int = ..., + initMode: str = ..., + initSteps: int = ..., + tol: float = ..., + maxIter: int = ..., + seed: Optional[int] = ..., + distanceMeasure: str = ..., + weightCol: Optional[str] = ... + ) -> KMeans: ... + def setK(self, value: int) -> KMeans: ... + def setInitMode(self, value: str) -> KMeans: ... + def setInitSteps(self, value: int) -> KMeans: ... + def setDistanceMeasure(self, value: str) -> KMeans: ... + def setMaxIter(self, value: int) -> KMeans: ... + def setFeaturesCol(self, value: str) -> KMeans: ... + def setPredictionCol(self, value: str) -> KMeans: ... + def setSeed(self, value: int) -> KMeans: ... + def setTol(self, value: float) -> KMeans: ... + def setWeightCol(self, value: str) -> KMeans: ... + +class _BisectingKMeansParams( + HasMaxIter, + HasFeaturesCol, + HasSeed, + HasPredictionCol, + HasDistanceMeasure, + HasWeightCol, +): + k: Param[int] + minDivisibleClusterSize: Param[float] + def __init__(self, *args: Any): ... + def getK(self) -> int: ... + def getMinDivisibleClusterSize(self) -> float: ... + +class BisectingKMeansModel( + JavaModel, + _BisectingKMeansParams, + JavaMLWritable, + JavaMLReadable[BisectingKMeansModel], + HasTrainingSummary[BisectingKMeansSummary], +): + def setFeaturesCol(self, value: str) -> BisectingKMeansModel: ... + def setPredictionCol(self, value: str) -> BisectingKMeansModel: ... + def clusterCenters(self) -> List[ndarray]: ... + def computeCost(self, dataset: DataFrame) -> float: ... + @property + def summary(self) -> BisectingKMeansSummary: ... + def predict(self, value: Vector) -> int: ... + +class BisectingKMeans( + JavaEstimator[BisectingKMeansModel], + _BisectingKMeansParams, + JavaMLWritable, + JavaMLReadable[BisectingKMeans], +): + def __init__( + self, + *, + featuresCol: str = ..., + predictionCol: str = ..., + maxIter: int = ..., + seed: Optional[int] = ..., + k: int = ..., + minDivisibleClusterSize: float = ..., + distanceMeasure: str = ..., + weightCol: Optional[str] = ... + ) -> None: ... 
+ def setParams( + self, + *, + featuresCol: str = ..., + predictionCol: str = ..., + maxIter: int = ..., + seed: Optional[int] = ..., + k: int = ..., + minDivisibleClusterSize: float = ..., + distanceMeasure: str = ..., + weightCol: Optional[str] = ... + ) -> BisectingKMeans: ... + def setK(self, value: int) -> BisectingKMeans: ... + def setMinDivisibleClusterSize(self, value: float) -> BisectingKMeans: ... + def setDistanceMeasure(self, value: str) -> BisectingKMeans: ... + def setMaxIter(self, value: int) -> BisectingKMeans: ... + def setFeaturesCol(self, value: str) -> BisectingKMeans: ... + def setPredictionCol(self, value: str) -> BisectingKMeans: ... + def setSeed(self, value: int) -> BisectingKMeans: ... + def setWeightCol(self, value: str) -> BisectingKMeans: ... + +class BisectingKMeansSummary(ClusteringSummary): + @property + def trainingCost(self) -> float: ... + +class _LDAParams(HasMaxIter, HasFeaturesCol, HasSeed, HasCheckpointInterval): + k: Param[int] + optimizer: Param[str] + learningOffset: Param[float] + learningDecay: Param[float] + subsamplingRate: Param[float] + optimizeDocConcentration: Param[bool] + docConcentration: Param[List[float]] + topicConcentration: Param[float] + topicDistributionCol: Param[str] + keepLastCheckpoint: Param[bool] + def __init__(self, *args: Any): ... + def getK(self) -> int: ... + def getOptimizer(self) -> str: ... + def getLearningOffset(self) -> float: ... + def getLearningDecay(self) -> float: ... + def getSubsamplingRate(self) -> float: ... + def getOptimizeDocConcentration(self) -> bool: ... + def getDocConcentration(self) -> List[float]: ... + def getTopicConcentration(self) -> float: ... + def getTopicDistributionCol(self) -> str: ... + def getKeepLastCheckpoint(self) -> bool: ... + +class LDAModel(JavaModel, _LDAParams): + def setFeaturesCol(self, value: str) -> LDAModel: ... + def setSeed(self, value: int) -> LDAModel: ... + def setTopicDistributionCol(self, value: str) -> LDAModel: ... + def isDistributed(self) -> bool: ... + def vocabSize(self) -> int: ... + def topicsMatrix(self) -> Matrix: ... + def logLikelihood(self, dataset: DataFrame) -> float: ... + def logPerplexity(self, dataset: DataFrame) -> float: ... + def describeTopics(self, maxTermsPerTopic: int = ...) -> DataFrame: ... + def estimatedDocConcentration(self) -> Vector: ... + +class DistributedLDAModel( + LDAModel, JavaMLReadable[DistributedLDAModel], JavaMLWritable +): + def toLocal(self) -> LDAModel: ... + def trainingLogLikelihood(self) -> float: ... + def logPrior(self) -> float: ... + def getCheckpointFiles(self) -> List[str]: ... + +class LocalLDAModel(LDAModel, JavaMLReadable[LocalLDAModel], JavaMLWritable): ... + +class LDA(JavaEstimator[LDAModel], _LDAParams, JavaMLReadable[LDA], JavaMLWritable): + def __init__( + self, + *, + featuresCol: str = ..., + maxIter: int = ..., + seed: Optional[int] = ..., + checkpointInterval: int = ..., + k: int = ..., + optimizer: str = ..., + learningOffset: float = ..., + learningDecay: float = ..., + subsamplingRate: float = ..., + optimizeDocConcentration: bool = ..., + docConcentration: Optional[List[float]] = ..., + topicConcentration: Optional[float] = ..., + topicDistributionCol: str = ..., + keepLastCheckpoint: bool = ... + ) -> None: ...
+ def setParams( + self, + *, + featuresCol: str = ..., + maxIter: int = ..., + seed: Optional[int] = ..., + checkpointInterval: int = ..., + k: int = ..., + optimizer: str = ..., + learningOffset: float = ..., + learningDecay: float = ..., + subsamplingRate: float = ..., + optimizeDocConcentration: bool = ..., + docConcentration: Optional[List[float]] = ..., + topicConcentration: Optional[float] = ..., + topicDistributionCol: str = ..., + keepLastCheckpoint: bool = ... + ) -> LDA: ... + def setCheckpointInterval(self, value: int) -> LDA: ... + def setSeed(self, value: int) -> LDA: ... + def setK(self, value: int) -> LDA: ... + def setOptimizer(self, value: str) -> LDA: ... + def setLearningOffset(self, value: float) -> LDA: ... + def setLearningDecay(self, value: float) -> LDA: ... + def setSubsamplingRate(self, value: float) -> LDA: ... + def setOptimizeDocConcentration(self, value: bool) -> LDA: ... + def setDocConcentration(self, value: List[float]) -> LDA: ... + def setTopicConcentration(self, value: float) -> LDA: ... + def setTopicDistributionCol(self, value: str) -> LDA: ... + def setKeepLastCheckpoint(self, value: bool) -> LDA: ... + def setMaxIter(self, value: int) -> LDA: ... + def setFeaturesCol(self, value: str) -> LDA: ... + +class _PowerIterationClusteringParams(HasMaxIter, HasWeightCol): + k: Param[int] + initMode: Param[str] + srcCol: Param[str] + dstCol: Param[str] + def __init__(self, *args: Any): ... + def getK(self) -> int: ... + def getInitMode(self) -> str: ... + def getSrcCol(self) -> str: ... + def getDstCol(self) -> str: ... + +class PowerIterationClustering( + _PowerIterationClusteringParams, + JavaParams, + JavaMLReadable[PowerIterationClustering], + JavaMLWritable, +): + def __init__( + self, + *, + k: int = ..., + maxIter: int = ..., + initMode: str = ..., + srcCol: str = ..., + dstCol: str = ..., + weightCol: Optional[str] = ... + ) -> None: ... + def setParams( + self, + *, + k: int = ..., + maxIter: int = ..., + initMode: str = ..., + srcCol: str = ..., + dstCol: str = ..., + weightCol: Optional[str] = ... + ) -> PowerIterationClustering: ... + def setK(self, value: int) -> PowerIterationClustering: ... + def setInitMode(self, value: str) -> PowerIterationClustering: ... + def setSrcCol(self, value: str) -> PowerIterationClustering: ... + def setDstCol(self, value: str) -> PowerIterationClustering: ... + def setMaxIter(self, value: int) -> PowerIterationClustering: ... + def setWeightCol(self, value: str) -> PowerIterationClustering: ... + def assignClusters(self, dataset: DataFrame) -> DataFrame: ... diff --git a/python/pyspark/ml/common.pyi b/python/pyspark/ml/common.pyi new file mode 100644 index 0000000000000..7bf0ed6183d8a --- /dev/null +++ b/python/pyspark/ml/common.pyi @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+ +def callJavaFunc(sc, func, *args): ... +def inherit_doc(cls): ... diff --git a/python/pyspark/ml/evaluation.pyi b/python/pyspark/ml/evaluation.pyi new file mode 100644 index 0000000000000..ea0a9f045cd6a --- /dev/null +++ b/python/pyspark/ml/evaluation.pyi @@ -0,0 +1,281 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import abc +from typing import Optional +from pyspark.ml._typing import ( + ParamMap, + BinaryClassificationEvaluatorMetricType, + ClusteringEvaluatorMetricType, + MulticlassClassificationEvaluatorMetricType, + MultilabelClassificationEvaluatorMetricType, + RankingEvaluatorMetricType, + RegressionEvaluatorMetricType, +) + +from pyspark.ml.wrapper import JavaParams +from pyspark.ml.param import Param, Params +from pyspark.ml.param.shared import ( + HasFeaturesCol, + HasLabelCol, + HasPredictionCol, + HasProbabilityCol, + HasRawPredictionCol, + HasWeightCol, +) +from pyspark.ml.util import JavaMLReadable, JavaMLWritable + +class Evaluator(Params, metaclass=abc.ABCMeta): + def evaluate(self, dataset, params: Optional[ParamMap] = ...) -> float: ... + def isLargerBetter(self) -> bool: ... + +class JavaEvaluator(JavaParams, Evaluator, metaclass=abc.ABCMeta): + def isLargerBetter(self) -> bool: ... + +class BinaryClassificationEvaluator( + JavaEvaluator, + HasLabelCol, + HasRawPredictionCol, + HasWeightCol, + JavaMLReadable[BinaryClassificationEvaluator], + JavaMLWritable, +): + metricName: Param[BinaryClassificationEvaluatorMetricType] + numBins: Param[int] + def __init__( + self, + *, + rawPredictionCol: str = ..., + labelCol: str = ..., + metricName: BinaryClassificationEvaluatorMetricType = ..., + weightCol: Optional[str] = ..., + numBins: int = ... + ) -> None: ... + def setMetricName( + self, value: BinaryClassificationEvaluatorMetricType + ) -> BinaryClassificationEvaluator: ... + def getMetricName(self) -> BinaryClassificationEvaluatorMetricType: ... + def setNumBins(self, value: int) -> BinaryClassificationEvaluator: ... + def getNumBins(self) -> int: ... + def setLabelCol(self, value: str) -> BinaryClassificationEvaluator: ... + def setRawPredictionCol(self, value: str) -> BinaryClassificationEvaluator: ... + def setWeightCol(self, value: str) -> BinaryClassificationEvaluator: ... + + def setParams( + self, + *, + rawPredictionCol: str = ..., + labelCol: str = ..., + metricName: BinaryClassificationEvaluatorMetricType = ..., + weightCol: Optional[str] = ..., + numBins: int = ... + ) -> BinaryClassificationEvaluator: ...
+ +class RegressionEvaluator( + JavaEvaluator, + HasLabelCol, + HasPredictionCol, + HasWeightCol, + JavaMLReadable[RegressionEvaluator], + JavaMLWritable, +): + metricName: Param[RegressionEvaluatorMetricType] + throughOrigin: Param[bool] + def __init__( + self, + *, + predictionCol: str = ..., + labelCol: str = ..., + metricName: RegressionEvaluatorMetricType = ..., + weightCol: Optional[str] = ..., + throughOrigin: bool = ... + ) -> None: ... + def setMetricName( + self, value: RegressionEvaluatorMetricType + ) -> RegressionEvaluator: ... + def getMetricName(self) -> RegressionEvaluatorMetricType: ... + def setThroughOrigin(self, value: bool) -> RegressionEvaluator: ... + def getThroughOrigin(self) -> bool: ... + def setLabelCol(self, value: str) -> RegressionEvaluator: ... + def setPredictionCol(self, value: str) -> RegressionEvaluator: ... + def setWeightCol(self, value: str) -> RegressionEvaluator: ... + def setParams( + self, + *, + predictionCol: str = ..., + labelCol: str = ..., + metricName: RegressionEvaluatorMetricType = ..., + weightCol: Optional[str] = ..., + throughOrigin: bool = ... + ) -> RegressionEvaluator: ... + +class MulticlassClassificationEvaluator( + JavaEvaluator, + HasLabelCol, + HasPredictionCol, + HasWeightCol, + HasProbabilityCol, + JavaMLReadable[MulticlassClassificationEvaluator], + JavaMLWritable, +): + metricName: Param[MulticlassClassificationEvaluatorMetricType] + metricLabel: Param[float] + beta: Param[float] + eps: Param[float] + def __init__( + self, + *, + predictionCol: str = ..., + labelCol: str = ..., + metricName: MulticlassClassificationEvaluatorMetricType = ..., + weightCol: Optional[str] = ..., + metricLabel: float = ..., + beta: float = ..., + probabilityCol: str = ..., + eps: float = ... + ) -> None: ... + def setMetricName( + self, value: MulticlassClassificationEvaluatorMetricType + ) -> MulticlassClassificationEvaluator: ... + def getMetricName(self) -> MulticlassClassificationEvaluatorMetricType: ... + def setMetricLabel(self, value: float) -> MulticlassClassificationEvaluator: ... + def getMetricLabel(self) -> float: ... + def setBeta(self, value: float) -> MulticlassClassificationEvaluator: ... + def getBeta(self) -> float: ... + def setEps(self, value: float) -> MulticlassClassificationEvaluator: ... + def getEps(self) -> float: ... + def setLabelCol(self, value: str) -> MulticlassClassificationEvaluator: ... + def setPredictionCol(self, value: str) -> MulticlassClassificationEvaluator: ... + def setProbabilityCol(self, value: str) -> MulticlassClassificationEvaluator: ... + def setWeightCol(self, value: str) -> MulticlassClassificationEvaluator: ... + def setParams( + self, + *, + predictionCol: str = ..., + labelCol: str = ..., + metricName: MulticlassClassificationEvaluatorMetricType = ..., + weightCol: Optional[str] = ..., + metricLabel: float = ..., + beta: float = ..., + probabilityCol: str = ..., + eps: float = ... + ) -> MulticlassClassificationEvaluator: ... + +class MultilabelClassificationEvaluator( + JavaEvaluator, + HasLabelCol, + HasPredictionCol, + JavaMLReadable[MultilabelClassificationEvaluator], + JavaMLWritable, +): + metricName: Param[MultilabelClassificationEvaluatorMetricType] + metricLabel: Param[float] + def __init__( + self, + *, + predictionCol: str = ..., + labelCol: str = ..., + metricName: MultilabelClassificationEvaluatorMetricType = ..., + metricLabel: float = ... + ) -> None: ... + def setMetricName( + self, value: MultilabelClassificationEvaluatorMetricType + ) -> MultilabelClassificationEvaluator: ... 
+ def getMetricName(self) -> MultilabelClassificationEvaluatorMetricType: ... + def setMetricLabel(self, value: float) -> MultilabelClassificationEvaluator: ... + def getMetricLabel(self) -> float: ... + def setLabelCol(self, value: str) -> MultilabelClassificationEvaluator: ... + def setPredictionCol(self, value: str) -> MultilabelClassificationEvaluator: ... + def setParams( + self, + *, + predictionCol: str = ..., + labelCol: str = ..., + metricName: MultilabelClassificationEvaluatorMetricType = ..., + metricLabel: float = ... + ) -> MultilabelClassificationEvaluator: ... + +class ClusteringEvaluator( + JavaEvaluator, + HasPredictionCol, + HasFeaturesCol, + HasWeightCol, + JavaMLReadable[ClusteringEvaluator], + JavaMLWritable, +): + metricName: Param[ClusteringEvaluatorMetricType] + distanceMeasure: Param[str] + def __init__( + self, + *, + predictionCol: str = ..., + featuresCol: str = ..., + metricName: ClusteringEvaluatorMetricType = ..., + distanceMeasure: str = ..., + weightCol: Optional[str] = ... + ) -> None: ... + def setParams( + self, + *, + predictionCol: str = ..., + featuresCol: str = ..., + metricName: ClusteringEvaluatorMetricType = ..., + distanceMeasure: str = ..., + weightCol: Optional[str] = ... + ) -> ClusteringEvaluator: ... + def setMetricName( + self, value: ClusteringEvaluatorMetricType + ) -> ClusteringEvaluator: ... + def getMetricName(self) -> ClusteringEvaluatorMetricType: ... + def setDistanceMeasure(self, value: str) -> ClusteringEvaluator: ... + def getDistanceMeasure(self) -> str: ... + def setFeaturesCol(self, value: str) -> ClusteringEvaluator: ... + def setPredictionCol(self, value: str) -> ClusteringEvaluator: ... + def setWeightCol(self, value: str) -> ClusteringEvaluator: ... + +class RankingEvaluator( + JavaEvaluator, + HasLabelCol, + HasPredictionCol, + JavaMLReadable[RankingEvaluator], + JavaMLWritable, +): + metricName: Param[RankingEvaluatorMetricType] + k: Param[int] + def __init__( + self, + *, + predictionCol: str = ..., + labelCol: str = ..., + metricName: RankingEvaluatorMetricType = ..., + k: int = ... + ) -> None: ... + def setMetricName(self, value: RankingEvaluatorMetricType) -> RankingEvaluator: ... + def getMetricName(self) -> RankingEvaluatorMetricType: ... + def setK(self, value: int) -> RankingEvaluator: ... + def getK(self) -> int: ... + def setLabelCol(self, value: str) -> RankingEvaluator: ... + def setPredictionCol(self, value: str) -> RankingEvaluator: ... + def setParams( + self, + *, + predictionCol: str = ..., + labelCol: str = ..., + metricName: RankingEvaluatorMetricType = ..., + k: int = ... + ) -> RankingEvaluator: ... diff --git a/python/pyspark/ml/feature.pyi b/python/pyspark/ml/feature.pyi new file mode 100644 index 0000000000000..f5b12a5b2ffc6 --- /dev/null +++ b/python/pyspark/ml/feature.pyi @@ -0,0 +1,1629 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import overload +from typing import Any, Dict, Generic, List, Optional, Tuple +from pyspark.ml._typing import JM, P + +from pyspark.ml.param.shared import ( + HasFeaturesCol, + HasHandleInvalid, + HasInputCol, + HasInputCols, + HasLabelCol, + HasMaxIter, + HasNumFeatures, + HasOutputCol, + HasOutputCols, + HasRelativeError, + HasSeed, + HasStepSize, + HasThreshold, + HasThresholds, +) +from pyspark.ml.util import JavaMLReadable, JavaMLWritable +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaTransformer +from pyspark.ml.linalg import Vector, DenseVector, DenseMatrix +from pyspark.sql.dataframe import DataFrame +from pyspark.ml.param import Param + +class Binarizer( + JavaTransformer, + HasThreshold, + HasThresholds, + HasInputCol, + HasOutputCol, + HasInputCols, + HasOutputCols, + JavaMLReadable[Binarizer], + JavaMLWritable, +): + threshold: Param[float] + thresholds: Param[List[float]] + @overload + def __init__( + self, + *, + threshold: float = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> None: ... + @overload + def __init__( + self, + *, + thresholds: Optional[List[float]] = ..., + inputCols: Optional[List[str]] = ..., + outputCols: Optional[List[str]] = ... + ) -> None: ... + @overload + def setParams( + self, + *, + threshold: float = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> Binarizer: ... + @overload + def setParams( + self, + *, + thresholds: Optional[List[float]] = ..., + inputCols: Optional[List[str]] = ..., + outputCols: Optional[List[str]] = ... + ) -> Binarizer: ... + def setThreshold(self, value: float) -> Binarizer: ... + def setThresholds(self, value: List[float]) -> Binarizer: ... + def setInputCol(self, value: str) -> Binarizer: ... + def setInputCols(self, value: List[str]) -> Binarizer: ... + def setOutputCol(self, value: str) -> Binarizer: ... + def setOutputCols(self, value: List[str]) -> Binarizer: ... + +class _LSHParams(HasInputCol, HasOutputCol): + numHashTables: Param[int] + def __init__(self, *args: Any): ... + def getNumHashTables(self) -> int: ... + +class _LSH(Generic[JM], JavaEstimator[JM], _LSHParams, JavaMLReadable, JavaMLWritable): + def setNumHashTables(self: P, value) -> P: ... + def setInputCol(self: P, value) -> P: ... + def setOutputCol(self: P, value) -> P: ... + +class _LSHModel(JavaModel, _LSHParams): + def setInputCol(self: P, value: str) -> P: ... + def setOutputCol(self: P, value: str) -> P: ... + def approxNearestNeighbors( + self, + dataset: DataFrame, + key: Vector, + numNearestNeighbors: int, + distCol: str = ..., + ) -> DataFrame: ... + def approxSimilarityJoin( + self, + datasetA: DataFrame, + datasetB: DataFrame, + threshold: float, + distCol: str = ..., + ) -> DataFrame: ... + +class _BucketedRandomProjectionLSHParams: + bucketLength: Param[float] + def getBucketLength(self) -> float: ... + +class BucketedRandomProjectionLSH( + _LSH[BucketedRandomProjectionLSHModel], + _LSHParams, + HasSeed, + JavaMLReadable[BucketedRandomProjectionLSH], + JavaMLWritable, +): + def __init__( + self, + *, + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + seed: Optional[int] = ..., + numHashTables: int = ..., + bucketLength: Optional[float] = ... + ) -> None: ... 
+ def setParams( + self, + *, + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + seed: Optional[int] = ..., + numHashTables: int = ..., + bucketLength: Optional[float] = ... + ) -> BucketedRandomProjectionLSH: ... + def setBucketLength(self, value: float) -> BucketedRandomProjectionLSH: ... + def setSeed(self, value: int) -> BucketedRandomProjectionLSH: ... + +class BucketedRandomProjectionLSHModel( + _LSHModel, + _BucketedRandomProjectionLSHParams, + JavaMLReadable[BucketedRandomProjectionLSHModel], + JavaMLWritable, +): ... + +class Bucketizer( + JavaTransformer, + HasInputCol, + HasOutputCol, + HasInputCols, + HasOutputCols, + HasHandleInvalid, + JavaMLReadable[Bucketizer], + JavaMLWritable, +): + splits: Param[List[float]] + handleInvalid: Param[str] + splitsArray: Param[List[List[float]]] + @overload + def __init__( + self, + *, + splits: Optional[List[float]] = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + handleInvalid: str = ... + ) -> None: ... + @overload + def __init__( + self, + *, + handleInvalid: str = ..., + splitsArray: Optional[List[List[float]]] = ..., + inputCols: Optional[List[str]] = ..., + outputCols: Optional[List[str]] = ... + ) -> None: ... + @overload + def setParams( + self, + *, + splits: Optional[List[float]] = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + handleInvalid: str = ... + ) -> Bucketizer: ... + @overload + def setParams( + self, + *, + handleInvalid: str = ..., + splitsArray: Optional[List[List[float]]] = ..., + inputCols: Optional[List[str]] = ..., + outputCols: Optional[List[str]] = ... + ) -> Bucketizer: ... + def setSplits(self, value: List[float]) -> Bucketizer: ... + def getSplits(self) -> List[float]: ... + def setSplitsArray(self, value: List[List[float]]) -> Bucketizer: ... + def getSplitsArray(self) -> List[List[float]]: ... + def setInputCol(self, value: str) -> Bucketizer: ... + def setInputCols(self, value: List[str]) -> Bucketizer: ... + def setOutputCol(self, value: str) -> Bucketizer: ... + def setOutputCols(self, value: List[str]) -> Bucketizer: ... + def setHandleInvalid(self, value: str) -> Bucketizer: ... + +class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol): + minTF: Param[float] + minDF: Param[float] + maxDF: Param[float] + vocabSize: Param[int] + binary: Param[bool] + def __init__(self, *args: Any) -> None: ... + def getMinTF(self) -> float: ... + def getMinDF(self) -> float: ... + def getMaxDF(self) -> float: ... + def getVocabSize(self) -> int: ... + def getBinary(self) -> bool: ... + +class CountVectorizer( + JavaEstimator[CountVectorizerModel], + _CountVectorizerParams, + JavaMLReadable[CountVectorizer], + JavaMLWritable, +): + def __init__( + self, + *, + minTF: float = ..., + minDF: float = ..., + maxDF: float = ..., + vocabSize: int = ..., + binary: bool = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> None: ... + def setParams( + self, + *, + minTF: float = ..., + minDF: float = ..., + maxDF: float = ..., + vocabSize: int = ..., + binary: bool = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> CountVectorizer: ... + def setMinTF(self, value: float) -> CountVectorizer: ... + def setMinDF(self, value: float) -> CountVectorizer: ... + def setMaxDF(self, value: float) -> CountVectorizer: ... + def setVocabSize(self, value: int) -> CountVectorizer: ... + def setBinary(self, value: bool) -> CountVectorizer: ... + def setInputCol(self, value: str) -> CountVectorizer: ... 
+ def setOutputCol(self, value: str) -> CountVectorizer: ... + +class CountVectorizerModel( + JavaModel, JavaMLReadable[CountVectorizerModel], JavaMLWritable +): + def setInputCol(self, value: str) -> CountVectorizerModel: ... + def setOutputCol(self, value: str) -> CountVectorizerModel: ... + def setMinTF(self, value: float) -> CountVectorizerModel: ... + def setBinary(self, value: bool) -> CountVectorizerModel: ... + @classmethod + def from_vocabulary( + cls, + vocabulary: List[str], + inputCol: str, + outputCol: Optional[str] = ..., + minTF: Optional[float] = ..., + binary: Optional[bool] = ..., + ) -> CountVectorizerModel: ... + @property + def vocabulary(self) -> List[str]: ... + +class DCT( + JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable[DCT], JavaMLWritable +): + inverse: Param[bool] + def __init__( + self, + *, + inverse: bool = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> None: ... + def setParams( + self, + *, + inverse: bool = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> DCT: ... + def setInverse(self, value: bool) -> DCT: ... + def getInverse(self) -> bool: ... + def setInputCol(self, value: str) -> DCT: ... + def setOutputCol(self, value: str) -> DCT: ... + +class ElementwiseProduct( + JavaTransformer, + HasInputCol, + HasOutputCol, + JavaMLReadable[ElementwiseProduct], + JavaMLWritable, +): + scalingVec: Param[Vector] + def __init__( + self, + *, + scalingVec: Optional[Vector] = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> None: ... + def setParams( + self, + *, + scalingVec: Optional[Vector] = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> ElementwiseProduct: ... + def setScalingVec(self, value: Vector) -> ElementwiseProduct: ... + def getScalingVec(self) -> Vector: ... + def setInputCol(self, value: str) -> ElementwiseProduct: ... + def setOutputCol(self, value: str) -> ElementwiseProduct: ... + +class FeatureHasher( + JavaTransformer, + HasInputCols, + HasOutputCol, + HasNumFeatures, + JavaMLReadable[FeatureHasher], + JavaMLWritable, +): + categoricalCols: Param[List[str]] + def __init__( + self, + *, + numFeatures: int = ..., + inputCols: Optional[List[str]] = ..., + outputCol: Optional[str] = ..., + categoricalCols: Optional[List[str]] = ... + ) -> None: ... + def setParams( + self, + *, + numFeatures: int = ..., + inputCols: Optional[List[str]] = ..., + outputCol: Optional[str] = ..., + categoricalCols: Optional[List[str]] = ... + ) -> FeatureHasher: ... + def setCategoricalCols(self, value: List[str]) -> FeatureHasher: ... + def getCategoricalCols(self) -> List[str]: ... + def setInputCols(self, value: List[str]) -> FeatureHasher: ... + def setOutputCol(self, value: str) -> FeatureHasher: ... + def setNumFeatures(self, value: int) -> FeatureHasher: ... + +class HashingTF( + JavaTransformer, + HasInputCol, + HasOutputCol, + HasNumFeatures, + JavaMLReadable[HashingTF], + JavaMLWritable, +): + binary: Param[bool] + def __init__( + self, + *, + numFeatures: int = ..., + binary: bool = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> None: ... + def setParams( + self, + *, + numFeatures: int = ..., + binary: bool = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> HashingTF: ... + def setBinary(self, value: bool) -> HashingTF: ... + def getBinary(self) -> bool: ... + def setInputCol(self, value: str) -> HashingTF: ... + def setOutputCol(self, value: str) -> HashingTF: ... 
+    def setNumFeatures(self, value: int) -> HashingTF: ...
+    def indexOf(self, term: Any) -> int: ...
+
+class _IDFParams(HasInputCol, HasOutputCol):
+    minDocFreq: Param[int]
+    def __init__(self, *args: Any): ...
+    def getMinDocFreq(self) -> int: ...
+
+class IDF(JavaEstimator[IDFModel], _IDFParams, JavaMLReadable[IDF], JavaMLWritable):
+    def __init__(
+        self,
+        *,
+        minDocFreq: int = ...,
+        inputCol: Optional[str] = ...,
+        outputCol: Optional[str] = ...
+    ) -> None: ...
+    def setParams(
+        self,
+        *,
+        minDocFreq: int = ...,
+        inputCol: Optional[str] = ...,
+        outputCol: Optional[str] = ...
+    ) -> IDF: ...
+    def setMinDocFreq(self, value: int) -> IDF: ...
+    def setInputCol(self, value: str) -> IDF: ...
+    def setOutputCol(self, value: str) -> IDF: ...
+
+class IDFModel(JavaModel, _IDFParams, JavaMLReadable[IDFModel], JavaMLWritable):
+    def setInputCol(self, value: str) -> IDFModel: ...
+    def setOutputCol(self, value: str) -> IDFModel: ...
+    @property
+    def idf(self) -> Vector: ...
+    @property
+    def docFreq(self) -> List[int]: ...
+    @property
+    def numDocs(self) -> int: ...
+
+class _ImputerParams(
+    HasInputCol, HasInputCols, HasOutputCol, HasOutputCols, HasRelativeError
+):
+    strategy: Param[str]
+    missingValue: Param[float]
+    def getStrategy(self) -> str: ...
+    def getMissingValue(self) -> float: ...
+
+class Imputer(
+    JavaEstimator[ImputerModel], _ImputerParams, JavaMLReadable[Imputer], JavaMLWritable
+):
+    @overload
+    def __init__(
+        self,
+        *,
+        strategy: str = ...,
+        missingValue: float = ...,
+        inputCols: Optional[List[str]] = ...,
+        outputCols: Optional[List[str]] = ...,
+        relativeError: float = ...
+    ) -> None: ...
+    @overload
+    def __init__(
+        self,
+        *,
+        strategy: str = ...,
+        missingValue: float = ...,
+        inputCol: Optional[str] = ...,
+        outputCol: Optional[str] = ...,
+        relativeError: float = ...
+    ) -> None: ...
+    @overload
+    def setParams(
+        self,
+        *,
+        strategy: str = ...,
+        missingValue: float = ...,
+        inputCols: Optional[List[str]] = ...,
+        outputCols: Optional[List[str]] = ...,
+        relativeError: float = ...
+    ) -> Imputer: ...
+    @overload
+    def setParams(
+        self,
+        *,
+        strategy: str = ...,
+        missingValue: float = ...,
+        inputCol: Optional[str] = ...,
+        outputCol: Optional[str] = ...,
+        relativeError: float = ...
+    ) -> Imputer: ...
+    def setStrategy(self, value: str) -> Imputer: ...
+    def setMissingValue(self, value: float) -> Imputer: ...
+    def setInputCols(self, value: List[str]) -> Imputer: ...
+    def setOutputCols(self, value: List[str]) -> Imputer: ...
+    def setInputCol(self, value: str) -> Imputer: ...
+    def setOutputCol(self, value: str) -> Imputer: ...
+    def setRelativeError(self, value: float) -> Imputer: ...
+
+class ImputerModel(
+    JavaModel, _ImputerParams, JavaMLReadable[ImputerModel], JavaMLWritable
+):
+    def setInputCols(self, value: List[str]) -> ImputerModel: ...
+    def setOutputCols(self, value: List[str]) -> ImputerModel: ...
+    def setInputCol(self, value: str) -> ImputerModel: ...
+    def setOutputCol(self, value: str) -> ImputerModel: ...
+    @property
+    def surrogateDF(self) -> DataFrame: ...
+
+class Interaction(
+    JavaTransformer,
+    HasInputCols,
+    HasOutputCol,
+    JavaMLReadable[Interaction],
+    JavaMLWritable,
+):
+    def __init__(
+        self, *, inputCols: Optional[List[str]] = ..., outputCol: Optional[str] = ...
+    ) -> None: ...
+    def setParams(
+        self, *, inputCols: Optional[List[str]] = ..., outputCol: Optional[str] = ...
+    ) -> Interaction: ...
+    def setInputCols(self, value: List[str]) -> Interaction: ...
+ def setOutputCol(self, value: str) -> Interaction: ... + +class _MaxAbsScalerParams(HasInputCol, HasOutputCol): ... + +class MaxAbsScaler( + JavaEstimator[MaxAbsScalerModel], + _MaxAbsScalerParams, + JavaMLReadable[MaxAbsScaler], + JavaMLWritable, +): + def __init__( + self, *, inputCol: Optional[str] = ..., outputCol: Optional[str] = ... + ) -> None: ... + def setParams( + self, *, inputCol: Optional[str] = ..., outputCol: Optional[str] = ... + ) -> MaxAbsScaler: ... + def setInputCol(self, value: str) -> MaxAbsScaler: ... + def setOutputCol(self, value: str) -> MaxAbsScaler: ... + +class MaxAbsScalerModel( + JavaModel, _MaxAbsScalerParams, JavaMLReadable[MaxAbsScalerModel], JavaMLWritable +): + def setInputCol(self, value: str) -> MaxAbsScalerModel: ... + def setOutputCol(self, value: str) -> MaxAbsScalerModel: ... + @property + def maxAbs(self) -> Vector: ... + +class MinHashLSH( + _LSH[MinHashLSHModel], + HasInputCol, + HasOutputCol, + HasSeed, + JavaMLReadable[MinHashLSH], + JavaMLWritable, +): + def __init__( + self, + *, + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + seed: Optional[int] = ..., + numHashTables: int = ... + ) -> None: ... + def setParams( + self, + *, + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + seed: Optional[int] = ..., + numHashTables: int = ... + ) -> MinHashLSH: ... + def setSeed(self, value: int) -> MinHashLSH: ... + +class MinHashLSHModel(_LSHModel, JavaMLReadable[MinHashLSHModel], JavaMLWritable): ... + +class _MinMaxScalerParams(HasInputCol, HasOutputCol): + min: Param[float] + max: Param[float] + def __init__(self, *args: Any): ... + def getMin(self) -> float: ... + def getMax(self) -> float: ... + +class MinMaxScaler( + JavaEstimator[MinMaxScalerModel], + _MinMaxScalerParams, + JavaMLReadable[MinMaxScaler], + JavaMLWritable, +): + def __init__( + self, + *, + min: float = ..., + max: float = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> None: ... + def setParams( + self, + *, + min: float = ..., + max: float = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> MinMaxScaler: ... + def setMin(self, value: float) -> MinMaxScaler: ... + def setMax(self, value: float) -> MinMaxScaler: ... + def setInputCol(self, value: str) -> MinMaxScaler: ... + def setOutputCol(self, value: str) -> MinMaxScaler: ... + +class MinMaxScalerModel( + JavaModel, _MinMaxScalerParams, JavaMLReadable[MinMaxScalerModel], JavaMLWritable +): + def setInputCol(self, value: str) -> MinMaxScalerModel: ... + def setOutputCol(self, value: str) -> MinMaxScalerModel: ... + def setMin(self, value: float) -> MinMaxScalerModel: ... + def setMax(self, value: float) -> MinMaxScalerModel: ... + @property + def originalMin(self) -> Vector: ... + @property + def originalMax(self) -> Vector: ... + +class NGram( + JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable[NGram], JavaMLWritable +): + n: Param[int] + def __init__( + self, + *, + n: int = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> None: ... + def setParams( + self, + *, + n: int = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> NGram: ... + def setN(self, value: int) -> NGram: ... + def getN(self) -> int: ... + def setInputCol(self, value: str) -> NGram: ... + def setOutputCol(self, value: str) -> NGram: ... 
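For reference, a minimal usage sketch of the MinMaxScaler signatures annotated above; the data and column names are illustrative and a local SparkSession is assumed:

```
from pyspark.ml.feature import MinMaxScaler
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(Vectors.dense([0.0, 10.0]),), (Vectors.dense([5.0, 20.0]),)], ["features"]
)

scaler = MinMaxScaler(min=0.0, max=1.0, inputCol="features", outputCol="scaled")
model = scaler.fit(df)            # MinMaxScalerModel
model.transform(df).show()
print(model.originalMin)          # annotated above as Vector
print(model.originalMax)          # annotated above as Vector
```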
+ +class Normalizer( + JavaTransformer, + HasInputCol, + HasOutputCol, + JavaMLReadable[Normalizer], + JavaMLWritable, +): + p: Param[float] + def __init__( + self, + *, + p: float = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> None: ... + def setParams( + self, + *, + p: float = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> Normalizer: ... + def setP(self, value: float) -> Normalizer: ... + def getP(self) -> float: ... + def setInputCol(self, value: str) -> Normalizer: ... + def setOutputCol(self, value: str) -> Normalizer: ... + +class _OneHotEncoderParams(HasInputCols, HasOutputCols, HasHandleInvalid): + handleInvalid: Param[str] + dropLast: Param[bool] + def __init__(self, *args: Any): ... + def getDropLast(self) -> bool: ... + +class OneHotEncoder( + JavaEstimator[OneHotEncoderModel], + _OneHotEncoderParams, + JavaMLReadable[OneHotEncoder], + JavaMLWritable, +): + @overload + def __init__( + self, + *, + inputCols: Optional[List[str]] = ..., + outputCols: Optional[List[str]] = ..., + handleInvalid: str = ..., + dropLast: bool = ... + ) -> None: ... + @overload + def __init__( + self, + *, + handleInvalid: str = ..., + dropLast: bool = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> None: ... + @overload + def setParams( + self, + *, + inputCols: Optional[List[str]] = ..., + outputCols: Optional[List[str]] = ..., + handleInvalid: str = ..., + dropLast: bool = ... + ) -> OneHotEncoder: ... + @overload + def setParams( + self, + *, + handleInvalid: str = ..., + dropLast: bool = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> OneHotEncoder: ... + def setDropLast(self, value: bool) -> OneHotEncoder: ... + def setInputCols(self, value: List[str]) -> OneHotEncoder: ... + def setOutputCols(self, value: List[str]) -> OneHotEncoder: ... + def setHandleInvalid(self, value: str) -> OneHotEncoder: ... + def setInputCol(self, value: str) -> OneHotEncoder: ... + def setOutputCol(self, value: str) -> OneHotEncoder: ... + +class OneHotEncoderModel( + JavaModel, _OneHotEncoderParams, JavaMLReadable[OneHotEncoderModel], JavaMLWritable +): + def setDropLast(self, value: bool) -> OneHotEncoderModel: ... + def setInputCols(self, value: List[str]) -> OneHotEncoderModel: ... + def setOutputCols(self, value: List[str]) -> OneHotEncoderModel: ... + def setInputCol(self, value: str) -> OneHotEncoderModel: ... + def setOutputCol(self, value: str) -> OneHotEncoderModel: ... + def setHandleInvalid(self, value: str) -> OneHotEncoderModel: ... + @property + def categorySizes(self) -> List[int]: ... + +class PolynomialExpansion( + JavaTransformer, + HasInputCol, + HasOutputCol, + JavaMLReadable[PolynomialExpansion], + JavaMLWritable, +): + degree: Param[int] + def __init__( + self, + *, + degree: int = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> None: ... + def setParams( + self, + *, + degree: int = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> PolynomialExpansion: ... + def setDegree(self, value: int) -> PolynomialExpansion: ... + def getDegree(self) -> int: ... + def setInputCol(self, value: str) -> PolynomialExpansion: ... + def setOutputCol(self, value: str) -> PolynomialExpansion: ... 
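The two OneHotEncoder overloads above (multi-column and single-column) can be exercised as in this illustrative sketch, again assuming a local SparkSession and toy data:

```
from pyspark.ml.feature import OneHotEncoder
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(0.0, 1.0), (1.0, 0.0), (2.0, 1.0)], ["c1", "c2"])

# Multi-column overload (inputCols/outputCols).
encoder = OneHotEncoder(inputCols=["c1", "c2"], outputCols=["c1_vec", "c2_vec"])
model = encoder.fit(df)           # OneHotEncoderModel
print(model.categorySizes)        # annotated above as List[int]

# Single-column overload (inputCol/outputCol).
single = OneHotEncoder(inputCol="c1", outputCol="c1_vec", dropLast=False)
single.fit(df).transform(df).show()
```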
+ +class QuantileDiscretizer( + JavaEstimator[Bucketizer], + HasInputCol, + HasOutputCol, + HasInputCols, + HasOutputCols, + HasHandleInvalid, + HasRelativeError, + JavaMLReadable[QuantileDiscretizer], + JavaMLWritable, +): + numBuckets: Param[int] + handleInvalid: Param[str] + numBucketsArray: Param[List[int]] + @overload + def __init__( + self, + *, + numBuckets: int = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + relativeError: float = ..., + handleInvalid: str = ... + ) -> None: ... + @overload + def __init__( + self, + *, + relativeError: float = ..., + handleInvalid: str = ..., + numBucketsArray: Optional[List[int]] = ..., + inputCols: Optional[List[str]] = ..., + outputCols: Optional[List[str]] = ... + ) -> None: ... + @overload + def setParams( + self, + *, + numBuckets: int = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + relativeError: float = ..., + handleInvalid: str = ... + ) -> QuantileDiscretizer: ... + @overload + def setParams( + self, + *, + relativeError: float = ..., + handleInvalid: str = ..., + numBucketsArray: Optional[List[int]] = ..., + inputCols: Optional[List[str]] = ..., + outputCols: Optional[List[str]] = ... + ) -> QuantileDiscretizer: ... + def setNumBuckets(self, value: int) -> QuantileDiscretizer: ... + def getNumBuckets(self) -> int: ... + def setNumBucketsArray(self, value: List[int]) -> QuantileDiscretizer: ... + def getNumBucketsArray(self) -> List[int]: ... + def setRelativeError(self, value: float) -> QuantileDiscretizer: ... + def setInputCol(self, value: str) -> QuantileDiscretizer: ... + def setInputCols(self, value: List[str]) -> QuantileDiscretizer: ... + def setOutputCol(self, value: str) -> QuantileDiscretizer: ... + def setOutputCols(self, value: List[str]) -> QuantileDiscretizer: ... + def setHandleInvalid(self, value: str) -> QuantileDiscretizer: ... + +class _RobustScalerParams(HasInputCol, HasOutputCol, HasRelativeError): + lower: Param[float] + upper: Param[float] + withCentering: Param[bool] + withScaling: Param[bool] + def __init__(self, *args: Any): ... + def getLower(self) -> float: ... + def getUpper(self) -> float: ... + def getWithCentering(self) -> bool: ... + def getWithScaling(self) -> bool: ... + +class RobustScaler( + JavaEstimator, _RobustScalerParams, JavaMLReadable[RobustScaler], JavaMLWritable +): + def __init__( + self, + *, + lower: float = ..., + upper: float = ..., + withCentering: bool = ..., + withScaling: bool = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + relativeError: float = ... + ) -> None: ... + def setParams( + self, + *, + lower: float = ..., + upper: float = ..., + withCentering: bool = ..., + withScaling: bool = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + relativeError: float = ... + ) -> RobustScaler: ... + def setLower(self, value: float) -> RobustScaler: ... + def setUpper(self, value: float) -> RobustScaler: ... + def setWithCentering(self, value: bool) -> RobustScaler: ... + def setWithScaling(self, value: bool) -> RobustScaler: ... + def setInputCol(self, value: str) -> RobustScaler: ... + def setOutputCol(self, value: str) -> RobustScaler: ... + def setRelativeError(self, value: float) -> RobustScaler: ... + +class RobustScalerModel( + JavaModel, _RobustScalerParams, JavaMLReadable[RobustScalerModel], JavaMLWritable +): + def setInputCol(self, value: str) -> RobustScalerModel: ... + def setOutputCol(self, value: str) -> RobustScalerModel: ... + @property + def median(self) -> Vector: ... 
+ @property + def range(self) -> Vector: ... + +class RegexTokenizer( + JavaTransformer, + HasInputCol, + HasOutputCol, + JavaMLReadable[RegexTokenizer], + JavaMLWritable, +): + minTokenLength: Param[int] + gaps: Param[bool] + pattern: Param[str] + toLowercase: Param[bool] + def __init__( + self, + *, + minTokenLength: int = ..., + gaps: bool = ..., + pattern: str = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + toLowercase: bool = ... + ) -> None: ... + def setParams( + self, + *, + minTokenLength: int = ..., + gaps: bool = ..., + pattern: str = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + toLowercase: bool = ... + ) -> RegexTokenizer: ... + def setMinTokenLength(self, value: int) -> RegexTokenizer: ... + def getMinTokenLength(self) -> int: ... + def setGaps(self, value: bool) -> RegexTokenizer: ... + def getGaps(self) -> bool: ... + def setPattern(self, value: str) -> RegexTokenizer: ... + def getPattern(self) -> str: ... + def setToLowercase(self, value: bool) -> RegexTokenizer: ... + def getToLowercase(self) -> bool: ... + def setInputCol(self, value: str) -> RegexTokenizer: ... + def setOutputCol(self, value: str) -> RegexTokenizer: ... + +class SQLTransformer(JavaTransformer, JavaMLReadable[SQLTransformer], JavaMLWritable): + statement: Param[str] + def __init__(self, *, statement: Optional[str] = ...) -> None: ... + def setParams(self, *, statement: Optional[str] = ...) -> SQLTransformer: ... + def setStatement(self, value: str) -> SQLTransformer: ... + def getStatement(self) -> str: ... + +class _StandardScalerParams(HasInputCol, HasOutputCol): + withMean: Param[bool] + withStd: Param[bool] + def __init__(self, *args: Any): ... + def getWithMean(self) -> bool: ... + def getWithStd(self) -> bool: ... + +class StandardScaler( + JavaEstimator[StandardScalerModel], + _StandardScalerParams, + JavaMLReadable[StandardScaler], + JavaMLWritable, +): + def __init__( + self, + *, + withMean: bool = ..., + withStd: bool = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> None: ... + def setParams( + self, + *, + withMean: bool = ..., + withStd: bool = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> StandardScaler: ... + def setWithMean(self, value: bool) -> StandardScaler: ... + def setWithStd(self, value: bool) -> StandardScaler: ... + def setInputCol(self, value: str) -> StandardScaler: ... + def setOutputCol(self, value: str) -> StandardScaler: ... + +class StandardScalerModel( + JavaModel, + _StandardScalerParams, + JavaMLReadable[StandardScalerModel], + JavaMLWritable, +): + def setInputCol(self, value: str) -> StandardScalerModel: ... + def setOutputCol(self, value: str) -> StandardScalerModel: ... + @property + def std(self) -> Vector: ... + @property + def mean(self) -> Vector: ... + +class _StringIndexerParams( + JavaParams, HasHandleInvalid, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols +): + stringOrderType: Param[str] + handleInvalid: Param[str] + def __init__(self, *args: Any) -> None: ... + def getStringOrderType(self) -> str: ... + +class StringIndexer( + JavaEstimator[StringIndexerModel], + _StringIndexerParams, + JavaMLReadable[StringIndexer], + JavaMLWritable, +): + @overload + def __init__( + self, + *, + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + handleInvalid: str = ..., + stringOrderType: str = ... + ) -> None: ... 
+ @overload + def __init__( + self, + *, + inputCols: Optional[List[str]] = ..., + outputCols: Optional[List[str]] = ..., + handleInvalid: str = ..., + stringOrderType: str = ... + ) -> None: ... + @overload + def setParams( + self, + *, + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + handleInvalid: str = ..., + stringOrderType: str = ... + ) -> StringIndexer: ... + @overload + def setParams( + self, + *, + inputCols: Optional[List[str]] = ..., + outputCols: Optional[List[str]] = ..., + handleInvalid: str = ..., + stringOrderType: str = ... + ) -> StringIndexer: ... + def setStringOrderType(self, value: str) -> StringIndexer: ... + def setInputCol(self, value: str) -> StringIndexer: ... + def setInputCols(self, value: List[str]) -> StringIndexer: ... + def setOutputCol(self, value: str) -> StringIndexer: ... + def setOutputCols(self, value: List[str]) -> StringIndexer: ... + def setHandleInvalid(self, value: str) -> StringIndexer: ... + +class StringIndexerModel( + JavaModel, _StringIndexerParams, JavaMLReadable[StringIndexerModel], JavaMLWritable +): + def setInputCol(self, value: str) -> StringIndexerModel: ... + def setInputCols(self, value: List[str]) -> StringIndexerModel: ... + def setOutputCol(self, value: str) -> StringIndexerModel: ... + def setOutputCols(self, value: List[str]) -> StringIndexerModel: ... + def setHandleInvalid(self, value: str) -> StringIndexerModel: ... + @classmethod + def from_labels( + cls, + labels: List[str], + inputCol: str, + outputCol: Optional[str] = ..., + handleInvalid: Optional[str] = ..., + ) -> StringIndexerModel: ... + @classmethod + def from_arrays_of_labels( + cls, + arrayOfLabels: List[List[str]], + inputCols: List[str], + outputCols: Optional[List[str]] = ..., + handleInvalid: Optional[str] = ..., + ) -> StringIndexerModel: ... + @property + def labels(self) -> List[str]: ... + +class IndexToString( + JavaTransformer, + HasInputCol, + HasOutputCol, + JavaMLReadable[IndexToString], + JavaMLWritable, +): + labels: Param[List[str]] + def __init__( + self, + *, + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + labels: Optional[List[str]] = ... + ) -> None: ... + def setParams( + self, + *, + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + labels: Optional[List[str]] = ... + ) -> IndexToString: ... + def setLabels(self, value: List[str]) -> IndexToString: ... + def getLabels(self) -> List[str]: ... + def setInputCol(self, value: str) -> IndexToString: ... + def setOutputCol(self, value: str) -> IndexToString: ... + +class StopWordsRemover( + JavaTransformer, + HasInputCol, + HasOutputCol, + HasInputCols, + HasOutputCols, + JavaMLReadable[StopWordsRemover], + JavaMLWritable, +): + stopWords: Param[List[str]] + caseSensitive: Param[bool] + locale: Param[str] + @overload + def __init__( + self, + *, + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + stopWords: Optional[List[str]] = ..., + caseSensitive: bool = ..., + locale: Optional[str] = ... + ) -> None: ... + @overload + def __init__( + self, + *, + stopWords: Optional[List[str]] = ..., + caseSensitive: bool = ..., + locale: Optional[str] = ..., + inputCols: Optional[List[str]] = ..., + outputCols: Optional[List[str]] = ... + ) -> None: ... + @overload + def setParams( + self, + *, + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + stopWords: Optional[List[str]] = ..., + caseSensitive: bool = ..., + locale: Optional[str] = ... + ) -> StopWordsRemover: ... 
+ @overload + def setParams( + self, + *, + stopWords: Optional[List[str]] = ..., + caseSensitive: bool = ..., + locale: Optional[str] = ..., + inputCols: Optional[List[str]] = ..., + outputCols: Optional[List[str]] = ... + ) -> StopWordsRemover: ... + def setStopWords(self, value: List[str]) -> StopWordsRemover: ... + def getStopWords(self) -> List[str]: ... + def setCaseSensitive(self, value: bool) -> StopWordsRemover: ... + def getCaseSensitive(self) -> bool: ... + def setLocale(self, value: str) -> StopWordsRemover: ... + def getLocale(self) -> str: ... + def setInputCol(self, value: str) -> StopWordsRemover: ... + def setOutputCol(self, value: str) -> StopWordsRemover: ... + def setInputCols(self, value: List[str]) -> StopWordsRemover: ... + def setOutputCols(self, value: List[str]) -> StopWordsRemover: ... + @staticmethod + def loadDefaultStopWords(language: str) -> List[str]: ... + +class Tokenizer( + JavaTransformer, + HasInputCol, + HasOutputCol, + JavaMLReadable[Tokenizer], + JavaMLWritable, +): + def __init__( + self, *, inputCol: Optional[str] = ..., outputCol: Optional[str] = ... + ) -> None: ... + def setParams( + self, *, inputCol: Optional[str] = ..., outputCol: Optional[str] = ... + ) -> Tokenizer: ... + def setInputCol(self, value: str) -> Tokenizer: ... + def setOutputCol(self, value: str) -> Tokenizer: ... + +class VectorAssembler( + JavaTransformer, + HasInputCols, + HasOutputCol, + HasHandleInvalid, + JavaMLReadable[VectorAssembler], + JavaMLWritable, +): + handleInvalid: Param[str] + def __init__( + self, + *, + inputCols: Optional[List[str]] = ..., + outputCol: Optional[str] = ..., + handleInvalid: str = ... + ) -> None: ... + def setParams( + self, + *, + inputCols: Optional[List[str]] = ..., + outputCol: Optional[str] = ..., + handleInvalid: str = ... + ) -> VectorAssembler: ... + def setInputCols(self, value: List[str]) -> VectorAssembler: ... + def setOutputCol(self, value: str) -> VectorAssembler: ... + def setHandleInvalid(self, value: str) -> VectorAssembler: ... + +class _VectorIndexerParams(HasInputCol, HasOutputCol, HasHandleInvalid): + maxCategories: Param[int] + handleInvalid: Param[str] + def __init__(self, *args: Any): ... + def getMaxCategories(self) -> int: ... + +class VectorIndexer( + JavaEstimator[VectorIndexerModel], + _VectorIndexerParams, + HasHandleInvalid, + JavaMLReadable[VectorIndexer], + JavaMLWritable, +): + def __init__( + self, + *, + maxCategories: int = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + handleInvalid: str = ... + ) -> None: ... + def setParams( + self, + *, + maxCategories: int = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + handleInvalid: str = ... + ) -> VectorIndexer: ... + def setMaxCategories(self, value: int) -> VectorIndexer: ... + def setInputCol(self, value: str) -> VectorIndexer: ... + def setOutputCol(self, value: str) -> VectorIndexer: ... + def setHandleInvalid(self, value: str) -> VectorIndexer: ... + +class VectorIndexerModel( + JavaModel, _VectorIndexerParams, JavaMLReadable[VectorIndexerModel], JavaMLWritable +): + def setInputCol(self, value: str) -> VectorIndexerModel: ... + def setOutputCol(self, value: str) -> VectorIndexerModel: ... + @property + def numFeatures(self) -> int: ... + @property + def categoryMaps(self) -> Dict[int, Tuple[float, int]]: ... 
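A short sketch of the VectorAssembler signature above; the column names and the handleInvalid value are only examples:

```
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1.0, 2.0, 3.0), (4.0, 5.0, 6.0)], ["a", "b", "c"])

assembler = VectorAssembler(
    inputCols=["a", "b", "c"], outputCol="features", handleInvalid="skip"
)
assembler.transform(df).show()    # DataFrame with an assembled "features" column
```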
+ +class VectorSlicer( + JavaTransformer, + HasInputCol, + HasOutputCol, + JavaMLReadable[VectorSlicer], + JavaMLWritable, +): + indices: Param[List[int]] + names: Param[List[str]] + def __init__( + self, + *, + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + indices: Optional[List[int]] = ..., + names: Optional[List[str]] = ... + ) -> None: ... + def setParams( + self, + *, + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + indices: Optional[List[int]] = ..., + names: Optional[List[str]] = ... + ) -> VectorSlicer: ... + def setIndices(self, value: List[int]) -> VectorSlicer: ... + def getIndices(self) -> List[int]: ... + def setNames(self, value: List[str]) -> VectorSlicer: ... + def getNames(self) -> List[str]: ... + def setInputCol(self, value: str) -> VectorSlicer: ... + def setOutputCol(self, value: str) -> VectorSlicer: ... + +class _Word2VecParams(HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCol): + vectorSize: Param[int] + numPartitions: Param[int] + minCount: Param[int] + windowSize: Param[int] + maxSentenceLength: Param[int] + def __init__(self, *args: Any): ... + def getVectorSize(self) -> int: ... + def getNumPartitions(self) -> int: ... + def getMinCount(self) -> int: ... + def getWindowSize(self) -> int: ... + def getMaxSentenceLength(self) -> int: ... + +class Word2Vec( + JavaEstimator[Word2VecModel], + _Word2VecParams, + JavaMLReadable[Word2Vec], + JavaMLWritable, +): + def __init__( + self, + *, + vectorSize: int = ..., + minCount: int = ..., + numPartitions: int = ..., + stepSize: float = ..., + maxIter: int = ..., + seed: Optional[int] = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + windowSize: int = ..., + maxSentenceLength: int = ... + ) -> None: ... + def setParams( + self, + *, + vectorSize: int = ..., + minCount: int = ..., + numPartitions: int = ..., + stepSize: float = ..., + maxIter: int = ..., + seed: Optional[int] = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + windowSize: int = ..., + maxSentenceLength: int = ... + ) -> Word2Vec: ... + def setVectorSize(self, value: int) -> Word2Vec: ... + def setNumPartitions(self, value: int) -> Word2Vec: ... + def setMinCount(self, value: int) -> Word2Vec: ... + def setWindowSize(self, value: int) -> Word2Vec: ... + def setMaxSentenceLength(self, value: int) -> Word2Vec: ... + def setMaxIter(self, value: int) -> Word2Vec: ... + def setInputCol(self, value: str) -> Word2Vec: ... + def setOutputCol(self, value: str) -> Word2Vec: ... + def setSeed(self, value: int) -> Word2Vec: ... + def setStepSize(self, value: float) -> Word2Vec: ... + +class Word2VecModel( + JavaModel, _Word2VecParams, JavaMLReadable[Word2VecModel], JavaMLWritable +): + def getVectors(self) -> DataFrame: ... + def setInputCol(self, value: str) -> Word2VecModel: ... + def setOutputCol(self, value: str) -> Word2VecModel: ... + @overload + def findSynonyms(self, word: str, num: int) -> DataFrame: ... + @overload + def findSynonyms(self, word: Vector, num: int) -> DataFrame: ... + @overload + def findSynonymsArray(self, word: str, num: int) -> List[Tuple[str, float]]: ... + @overload + def findSynonymsArray(self, word: Vector, num: int) -> List[Tuple[str, float]]: ... + +class _PCAParams(HasInputCol, HasOutputCol): + k: Param[int] + def getK(self) -> int: ... 
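The findSynonyms overloads above accept either a token string or a query Vector; a hypothetical end-to-end sketch (toy corpus, local SparkSession assumed):

```
from pyspark.ml.feature import Word2Vec
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# Each row is a sequence of tokens.
doc = spark.createDataFrame(
    [("a b c".split(" "),), ("a b b c a".split(" "),)], ["sentence"]
)

w2v = Word2Vec(vectorSize=5, minCount=0, seed=42, inputCol="sentence", outputCol="vecs")
model = w2v.fit(doc)                       # Word2VecModel

model.findSynonyms("a", 2).show()          # str overload
query = model.getVectors().head().vector   # a vector from the word-vector table
model.findSynonyms(query, 2).show()        # Vector overload
```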
+ +class PCA(JavaEstimator[PCAModel], _PCAParams, JavaMLReadable[PCA], JavaMLWritable): + def __init__( + self, + *, + k: Optional[int] = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> None: ... + def setParams( + self, + *, + k: Optional[int] = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ... + ) -> PCA: ... + def setK(self, value: int) -> PCA: ... + def setInputCol(self, value: str) -> PCA: ... + def setOutputCol(self, value: str) -> PCA: ... + +class PCAModel(JavaModel, _PCAParams, JavaMLReadable[PCAModel], JavaMLWritable): + def setInputCol(self, value: str) -> PCAModel: ... + def setOutputCol(self, value: str) -> PCAModel: ... + @property + def pc(self) -> DenseMatrix: ... + @property + def explainedVariance(self) -> DenseVector: ... + +class _RFormulaParams(HasFeaturesCol, HasLabelCol, HasHandleInvalid): + formula: Param[str] + forceIndexLabel: Param[bool] + stringIndexerOrderType: Param[str] + handleInvalid: Param[str] + def __init__(self, *args: Any): ... + def getFormula(self) -> str: ... + def getForceIndexLabel(self) -> bool: ... + def getStringIndexerOrderType(self) -> str: ... + +class RFormula( + JavaEstimator[RFormulaModel], + _RFormulaParams, + JavaMLReadable[RFormula], + JavaMLWritable, +): + def __init__( + self, + *, + formula: Optional[str] = ..., + featuresCol: str = ..., + labelCol: str = ..., + forceIndexLabel: bool = ..., + stringIndexerOrderType: str = ..., + handleInvalid: str = ... + ) -> None: ... + def setParams( + self, + *, + formula: Optional[str] = ..., + featuresCol: str = ..., + labelCol: str = ..., + forceIndexLabel: bool = ..., + stringIndexerOrderType: str = ..., + handleInvalid: str = ... + ) -> RFormula: ... + def setFormula(self, value: str) -> RFormula: ... + def setForceIndexLabel(self, value: bool) -> RFormula: ... + def setStringIndexerOrderType(self, value: str) -> RFormula: ... + def setFeaturesCol(self, value: str) -> RFormula: ... + def setLabelCol(self, value: str) -> RFormula: ... + def setHandleInvalid(self, value: str) -> RFormula: ... + +class RFormulaModel( + JavaModel, _RFormulaParams, JavaMLReadable[RFormulaModel], JavaMLWritable +): ... + +class _SelectorParams(HasFeaturesCol, HasOutputCol, HasLabelCol): + selectorType: Param[str] + numTopFeatures: Param[int] + percentile: Param[float] + fpr: Param[float] + fdr: Param[float] + fwe: Param[float] + def __init__(self, *args: Any): ... + def getSelectorType(self) -> str: ... + def getNumTopFeatures(self) -> int: ... + def getPercentile(self) -> float: ... + def getFpr(self) -> float: ... + def getFdr(self) -> float: ... + def getFwe(self) -> float: ... + +class _Selector(JavaEstimator[JM], _SelectorParams, JavaMLReadable, JavaMLWritable): + def setSelectorType(self: P, value: str) -> P: ... + def setNumTopFeatures(self: P, value: int) -> P: ... + def setPercentile(self: P, value: float) -> P: ... + def setFpr(self: P, value: float) -> P: ... + def setFdr(self: P, value: float) -> P: ... + def setFwe(self: P, value: float) -> P: ... + def setFeaturesCol(self: P, value: str) -> P: ... + def setOutputCol(self: P, value: str) -> P: ... + def setLabelCol(self: P, value: str) -> P: ... + +class _SelectorModel(JavaModel, _SelectorParams): + def setFeaturesCol(self: P, value: str) -> P: ... + def setOutputCol(self: P, value: str) -> P: ... + @property + def selectedFeatures(self) -> List[int]: ... 
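A minimal sketch of the PCA estimator/model pair annotated above (toy data, illustrative column names, local SparkSession assumed):

```
from pyspark.ml.feature import PCA
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(Vectors.dense([1.0, 0.0, 7.0]),),
     (Vectors.dense([2.0, 1.0, 5.0]),),
     (Vectors.dense([4.0, 3.0, 2.0]),)],
    ["features"],
)

pca = PCA(k=2, inputCol="features", outputCol="pca_features")
model = pca.fit(df)               # PCAModel
print(model.explainedVariance)    # annotated above as DenseVector
print(model.pc)                   # annotated above as DenseMatrix
```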
+
+class ANOVASelector(
+    _Selector[ANOVASelectorModel], JavaMLReadable[ANOVASelector], JavaMLWritable
+):
+    def __init__(
+        self,
+        numTopFeatures: int = ...,
+        featuresCol: str = ...,
+        outputCol: Optional[str] = ...,
+        labelCol: str = ...,
+        selectorType: str = ...,
+        percentile: float = ...,
+        fpr: float = ...,
+        fdr: float = ...,
+        fwe: float = ...,
+    ) -> None: ...
+    def setParams(
+        self,
+        numTopFeatures: int = ...,
+        featuresCol: str = ...,
+        outputCol: Optional[str] = ...,
+        labelCol: str = ...,
+        selectorType: str = ...,
+        percentile: float = ...,
+        fpr: float = ...,
+        fdr: float = ...,
+        fwe: float = ...,
+    ) -> ANOVASelector: ...
+
+class ANOVASelectorModel(
+    _SelectorModel, JavaMLReadable[ANOVASelectorModel], JavaMLWritable
+): ...
+
+class ChiSqSelector(
+    _Selector[ChiSqSelectorModel],
+    JavaMLReadable[ChiSqSelector],
+    JavaMLWritable,
+):
+    def __init__(
+        self,
+        *,
+        numTopFeatures: int = ...,
+        featuresCol: str = ...,
+        outputCol: Optional[str] = ...,
+        labelCol: str = ...,
+        selectorType: str = ...,
+        percentile: float = ...,
+        fpr: float = ...,
+        fdr: float = ...,
+        fwe: float = ...
+    ) -> None: ...
+    def setParams(
+        self,
+        *,
+        numTopFeatures: int = ...,
+        featuresCol: str = ...,
+        outputCol: Optional[str] = ...,
+        labelCol: str = ...,
+        selectorType: str = ...,
+        percentile: float = ...,
+        fpr: float = ...,
+        fdr: float = ...,
+        fwe: float = ...
+    ) -> ChiSqSelector: ...
+    def setSelectorType(self, value: str) -> ChiSqSelector: ...
+    def setNumTopFeatures(self, value: int) -> ChiSqSelector: ...
+    def setPercentile(self, value: float) -> ChiSqSelector: ...
+    def setFpr(self, value: float) -> ChiSqSelector: ...
+    def setFdr(self, value: float) -> ChiSqSelector: ...
+    def setFwe(self, value: float) -> ChiSqSelector: ...
+    def setFeaturesCol(self, value: str) -> ChiSqSelector: ...
+    def setOutputCol(self, value: str) -> ChiSqSelector: ...
+    def setLabelCol(self, value: str) -> ChiSqSelector: ...
+
+class ChiSqSelectorModel(
+    _SelectorModel, JavaMLReadable[ChiSqSelectorModel], JavaMLWritable
+):
+    def setFeaturesCol(self, value: str) -> ChiSqSelectorModel: ...
+    def setOutputCol(self, value: str) -> ChiSqSelectorModel: ...
+    @property
+    def selectedFeatures(self) -> List[int]: ...
+
+class VectorSizeHint(
+    JavaTransformer,
+    HasInputCol,
+    HasHandleInvalid,
+    JavaMLReadable[VectorSizeHint],
+    JavaMLWritable,
+):
+    size: Param[int]
+    handleInvalid: Param[str]
+    def __init__(
+        self,
+        *,
+        inputCol: Optional[str] = ...,
+        size: Optional[int] = ...,
+        handleInvalid: str = ...
+    ) -> None: ...
+    def setParams(
+        self,
+        *,
+        inputCol: Optional[str] = ...,
+        size: Optional[int] = ...,
+        handleInvalid: str = ...
+    ) -> VectorSizeHint: ...
+    def setSize(self, value: int) -> VectorSizeHint: ...
+    def getSize(self) -> int: ...
+    def setInputCol(self, value: str) -> VectorSizeHint: ...
+    def setHandleInvalid(self, value: str) -> VectorSizeHint: ...
+
+class FValueSelector(
+    _Selector[FValueSelectorModel], JavaMLReadable[FValueSelector], JavaMLWritable
+):
+    def __init__(
+        self,
+        numTopFeatures: int = ...,
+        featuresCol: str = ...,
+        outputCol: Optional[str] = ...,
+        labelCol: str = ...,
+        selectorType: str = ...,
+        percentile: float = ...,
+        fpr: float = ...,
+        fdr: float = ...,
+        fwe: float = ...,
+    ) -> None: ...
+    def setParams(
+        self,
+        numTopFeatures: int = ...,
+        featuresCol: str = ...,
+        outputCol: Optional[str] = ...,
+        labelCol: str = ...,
+        selectorType: str = ...,
+        percentile: float = ...,
+        fpr: float = ...,
+        fdr: float = ...,
+        fwe: float = ...,
+    ) -> FValueSelector: ...
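An illustrative ChiSqSelector sketch matching the keyword-only signature above (toy data, local SparkSession assumed):

```
from pyspark.ml.feature import ChiSqSelector
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0),
     (Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0),
     (Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0)],
    ["features", "label"],
)

selector = ChiSqSelector(
    numTopFeatures=1, featuresCol="features", outputCol="selected", labelCol="label"
)
model = selector.fit(df)          # ChiSqSelectorModel
print(model.selectedFeatures)     # annotated above as List[int]
```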
+
+class FValueSelectorModel(
+    _SelectorModel, JavaMLReadable[FValueSelectorModel], JavaMLWritable
+): ...
+
+class _VarianceThresholdSelectorParams(HasFeaturesCol, HasOutputCol):
+    varianceThreshold: Param[float] = ...
+    def getVarianceThreshold(self) -> float: ...
+
+class VarianceThresholdSelector(
+    JavaEstimator, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable
+):
+    def __init__(
+        self,
+        featuresCol: str = ...,
+        outputCol: Optional[str] = ...,
+        varianceThreshold: float = ...,
+    ) -> None: ...
+    def setParams(
+        self,
+        featuresCol: str = ...,
+        outputCol: Optional[str] = ...,
+        varianceThreshold: float = ...,
+    ) -> VarianceThresholdSelector: ...
+    def setVarianceThreshold(self, value: float) -> VarianceThresholdSelector: ...
+    def setFeaturesCol(self, value: str) -> VarianceThresholdSelector: ...
+    def setOutputCol(self, value: str) -> VarianceThresholdSelector: ...
+
+class VarianceThresholdSelectorModel(
+    JavaModel, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable
+):
+    def setFeaturesCol(self, value: str) -> VarianceThresholdSelectorModel: ...
+    def setOutputCol(self, value: str) -> VarianceThresholdSelectorModel: ...
+    @property
+    def selectedFeatures(self) -> List[int]: ...
diff --git a/python/pyspark/ml/fpm.pyi b/python/pyspark/ml/fpm.pyi
new file mode 100644
index 0000000000000..7cc304a2ffa39
--- /dev/null
+++ b/python/pyspark/ml/fpm.pyi
@@ -0,0 +1,109 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Any, Optional
+
+from pyspark.ml.util import JavaMLReadable, JavaMLWritable
+from pyspark.ml.wrapper import JavaEstimator, JavaParams, JavaModel
+from pyspark.ml.param.shared import HasPredictionCol
+from pyspark.sql.dataframe import DataFrame
+
+from pyspark.ml.param import Param
+
+class _FPGrowthParams(HasPredictionCol):
+    itemsCol: Param[str]
+    minSupport: Param[float]
+    numPartitions: Param[int]
+    minConfidence: Param[float]
+    def __init__(self, *args: Any): ...
+    def getItemsCol(self) -> str: ...
+    def getMinSupport(self) -> float: ...
+    def getNumPartitions(self) -> int: ...
+    def getMinConfidence(self) -> float: ...
+
+class FPGrowthModel(
+    JavaModel, _FPGrowthParams, JavaMLWritable, JavaMLReadable[FPGrowthModel]
+):
+    def setItemsCol(self, value: str) -> FPGrowthModel: ...
+    def setMinConfidence(self, value: float) -> FPGrowthModel: ...
+    def setPredictionCol(self, value: str) -> FPGrowthModel: ...
+    @property
+    def freqItemsets(self) -> DataFrame: ...
+    @property
+    def associationRules(self) -> DataFrame: ...
+
+class FPGrowth(
+    JavaEstimator[FPGrowthModel],
+    _FPGrowthParams,
+    JavaMLWritable,
+    JavaMLReadable[FPGrowth],
+):
+    def __init__(
+        self,
+        *,
+        minSupport: float = ...,
+        minConfidence: float = ...,
+        itemsCol: str = ...,
+        predictionCol: str = ...,
+        numPartitions: Optional[int] = ...
+ ) -> None: ... + def setParams( + self, + *, + minSupport: float = ..., + minConfidence: float = ..., + itemsCol: str = ..., + predictionCol: str = ..., + numPartitions: Optional[int] = ... + ) -> FPGrowth: ... + def setItemsCol(self, value: str) -> FPGrowth: ... + def setMinSupport(self, value: float) -> FPGrowth: ... + def setNumPartitions(self, value: int) -> FPGrowth: ... + def setMinConfidence(self, value: float) -> FPGrowth: ... + def setPredictionCol(self, value: str) -> FPGrowth: ... + +class PrefixSpan(JavaParams): + minSupport: Param[float] + maxPatternLength: Param[int] + maxLocalProjDBSize: Param[int] + sequenceCol: Param[str] + def __init__( + self, + *, + minSupport: float = ..., + maxPatternLength: int = ..., + maxLocalProjDBSize: int = ..., + sequenceCol: str = ... + ) -> None: ... + def setParams( + self, + *, + minSupport: float = ..., + maxPatternLength: int = ..., + maxLocalProjDBSize: int = ..., + sequenceCol: str = ... + ) -> PrefixSpan: ... + def setMinSupport(self, value: float) -> PrefixSpan: ... + def getMinSupport(self) -> float: ... + def setMaxPatternLength(self, value: int) -> PrefixSpan: ... + def getMaxPatternLength(self) -> int: ... + def setMaxLocalProjDBSize(self, value: int) -> PrefixSpan: ... + def getMaxLocalProjDBSize(self) -> int: ... + def setSequenceCol(self, value: str) -> PrefixSpan: ... + def getSequenceCol(self) -> str: ... + def findFrequentSequentialPatterns(self, dataset: DataFrame) -> DataFrame: ... diff --git a/python/pyspark/ml/functions.pyi b/python/pyspark/ml/functions.pyi new file mode 100644 index 0000000000000..42650e742e781 --- /dev/null +++ b/python/pyspark/ml/functions.pyi @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pyspark import SparkContext as SparkContext, since as since # noqa: F401 +from pyspark.sql.column import Column as Column + +def vector_to_array(col: Column) -> Column: ... diff --git a/python/pyspark/ml/image.pyi b/python/pyspark/ml/image.pyi new file mode 100644 index 0000000000000..9ff3a8817aadd --- /dev/null +++ b/python/pyspark/ml/image.pyi @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Dict, List + +from pyspark.sql.types import Row, StructType + +from numpy import ndarray # type: ignore[import] + +class _ImageSchema: + def __init__(self) -> None: ... + @property + def imageSchema(self) -> StructType: ... + @property + def ocvTypes(self) -> Dict[str, int]: ... + @property + def columnSchema(self) -> StructType: ... + @property + def imageFields(self) -> List[str]: ... + @property + def undefinedImageType(self) -> str: ... + def toNDArray(self, image: Row) -> ndarray: ... + def toImage(self, array: ndarray, origin: str = ...) -> Row: ... + +ImageSchema: _ImageSchema diff --git a/python/pyspark/ml/linalg/__init__.pyi b/python/pyspark/ml/linalg/__init__.pyi new file mode 100644 index 0000000000000..a576b30aec308 --- /dev/null +++ b/python/pyspark/ml/linalg/__init__.pyi @@ -0,0 +1,255 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import overload +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +from pyspark.ml import linalg as newlinalg # noqa: F401 +from pyspark.sql.types import StructType, UserDefinedType + +from numpy import float64, ndarray # type: ignore[import] + +class VectorUDT(UserDefinedType): + @classmethod + def sqlType(cls) -> StructType: ... + @classmethod + def module(cls) -> str: ... + @classmethod + def scalaUDT(cls) -> str: ... + def serialize( + self, obj: Vector + ) -> Tuple[int, Optional[int], Optional[List[int]], List[float]]: ... + def deserialize(self, datum: Any) -> Vector: ... + def simpleString(self) -> str: ... + +class MatrixUDT(UserDefinedType): + @classmethod + def sqlType(cls) -> StructType: ... + @classmethod + def module(cls) -> str: ... + @classmethod + def scalaUDT(cls) -> str: ... + def serialize( + self, obj + ) -> Tuple[ + int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool + ]: ... + def deserialize(self, datum: Any) -> Matrix: ... + def simpleString(self) -> str: ... + +class Vector: + __UDT__: VectorUDT + def toArray(self) -> ndarray: ... + +class DenseVector(Vector): + array: ndarray + @overload + def __init__(self, *elements: float) -> None: ... + @overload + def __init__(self, __arr: bytes) -> None: ... + @overload + def __init__(self, __arr: Iterable[float]) -> None: ... + @staticmethod + def parse(s) -> DenseVector: ... + def __reduce__(self) -> Tuple[type, bytes]: ... + def numNonzeros(self) -> int: ... + def norm(self, p: Union[float, str]) -> float64: ... + def dot(self, other: Iterable[float]) -> float64: ... + def squared_distance(self, other: Iterable[float]) -> float64: ... + def toArray(self) -> ndarray: ... + @property + def values(self) -> ndarray: ... 
+    def __getitem__(self, item: int) -> float64: ...
+    def __len__(self) -> int: ...
+    def __eq__(self, other: Any) -> bool: ...
+    def __ne__(self, other: Any) -> bool: ...
+    def __hash__(self) -> int: ...
+    def __getattr__(self, item: str) -> Any: ...
+    def __neg__(self) -> DenseVector: ...
+    def __add__(self, other: Union[float, Iterable[float]]) -> DenseVector: ...
+    def __sub__(self, other: Union[float, Iterable[float]]) -> DenseVector: ...
+    def __mul__(self, other: Union[float, Iterable[float]]) -> DenseVector: ...
+    def __div__(self, other: Union[float, Iterable[float]]) -> DenseVector: ...
+    def __truediv__(self, other: Union[float, Iterable[float]]) -> DenseVector: ...
+    def __mod__(self, other: Union[float, Iterable[float]]) -> DenseVector: ...
+    def __radd__(self, other: Union[float, Iterable[float]]) -> DenseVector: ...
+    def __rsub__(self, other: Union[float, Iterable[float]]) -> DenseVector: ...
+    def __rmul__(self, other: Union[float, Iterable[float]]) -> DenseVector: ...
+    def __rdiv__(self, other: Union[float, Iterable[float]]) -> DenseVector: ...
+    def __rtruediv__(self, other: Union[float, Iterable[float]]) -> DenseVector: ...
+    def __rmod__(self, other: Union[float, Iterable[float]]) -> DenseVector: ...
+
+class SparseVector(Vector):
+    size: int
+    indices: ndarray
+    values: ndarray
+    @overload
+    def __init__(self, size: int, *args: Tuple[int, float]) -> None: ...
+    @overload
+    def __init__(self, size: int, __indices: bytes, __values: bytes) -> None: ...
+    @overload
+    def __init__(
+        self, size: int, __indices: Iterable[int], __values: Iterable[float]
+    ) -> None: ...
+    @overload
+    def __init__(self, size: int, __pairs: Iterable[Tuple[int, float]]) -> None: ...
+    @overload
+    def __init__(self, size: int, __map: Dict[int, float]) -> None: ...
+    def numNonzeros(self) -> int: ...
+    def norm(self, p: Union[float, str]) -> float64: ...
+    def __reduce__(self): ...
+    @staticmethod
+    def parse(s: str) -> SparseVector: ...
+    def dot(self, other: Iterable[float]) -> float64: ...
+    def squared_distance(self, other: Iterable[float]) -> float64: ...
+    def toArray(self) -> ndarray: ...
+    def __len__(self) -> int: ...
+    def __eq__(self, other) -> bool: ...
+    def __getitem__(self, index: int) -> float64: ...
+    def __ne__(self, other) -> bool: ...
+    def __hash__(self) -> int: ...
+
+class Vectors:
+    @overload
+    @staticmethod
+    def sparse(size: int, *args: Tuple[int, float]) -> SparseVector: ...
+    @overload
+    @staticmethod
+    def sparse(size: int, __indices: bytes, __values: bytes) -> SparseVector: ...
+    @overload
+    @staticmethod
+    def sparse(
+        size: int, __indices: Iterable[int], __values: Iterable[float]
+    ) -> SparseVector: ...
+    @overload
+    @staticmethod
+    def sparse(size: int, __pairs: Iterable[Tuple[int, float]]) -> SparseVector: ...
+    @overload
+    @staticmethod
+    def sparse(size: int, __map: Dict[int, float]) -> SparseVector: ...
+    @overload
+    @staticmethod
+    def dense(*elements: float) -> DenseVector: ...
+    @overload
+    @staticmethod
+    def dense(__arr: bytes) -> DenseVector: ...
+    @overload
+    @staticmethod
+    def dense(__arr: Iterable[float]) -> DenseVector: ...
+    @staticmethod
+    def stringify(vector: Vector) -> str: ...
+    @staticmethod
+    def squared_distance(v1: Vector, v2: Vector) -> float64: ...
+    @staticmethod
+    def norm(vector: Vector, p: Union[float, str]) -> float64: ...
+    @staticmethod
+    def parse(s: str) -> Vector: ...
+    @staticmethod
+    def zeros(size: int) -> DenseVector: ...
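The Vectors factory overloads above correspond to the following equivalent call styles; this sketch runs locally and does not need a SparkSession:

```
from pyspark.ml.linalg import Vectors

dense = Vectors.dense(1.0, 0.0, 3.0)                    # *elements overload
from_pairs = Vectors.sparse(4, [(1, 1.0), (3, 5.5)])    # list-of-(index, value) overload
from_arrays = Vectors.sparse(4, [1, 3], [1.0, 5.5])     # separate indices/values overload
from_dict = Vectors.sparse(4, {1: 1.0, 3: 5.5})         # dict overload

print(from_arrays.toArray())                            # numpy ndarray
print(Vectors.squared_distance(dense, from_pairs))      # numpy float64
```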
+ +class Matrix: + __UDT__: MatrixUDT + numRows: int + numCols: int + isTransposed: bool + def __init__( + self, numRows: int, numCols: int, isTransposed: bool = ... + ) -> None: ... + def toArray(self): ... + +class DenseMatrix(Matrix): + values: Any + @overload + def __init__( + self, numRows: int, numCols: int, values: bytes, isTransposed: bool = ... + ) -> None: ... + @overload + def __init__( + self, + numRows: int, + numCols: int, + values: Iterable[float], + isTransposed: bool = ..., + ) -> None: ... + def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, int]]: ... + def toArray(self) -> ndarray: ... + def toSparse(self) -> SparseMatrix: ... + def __getitem__(self, indices: Tuple[int, int]) -> float64: ... + def __eq__(self, other) -> bool: ... + +class SparseMatrix(Matrix): + colPtrs: ndarray + rowIndices: ndarray + values: ndarray + @overload + def __init__( + self, + numRows: int, + numCols: int, + colPtrs: bytes, + rowIndices: bytes, + values: bytes, + isTransposed: bool = ..., + ) -> None: ... + @overload + def __init__( + self, + numRows: int, + numCols: int, + colPtrs: Iterable[int], + rowIndices: Iterable[int], + values: Iterable[float], + isTransposed: bool = ..., + ) -> None: ... + def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, bytes, bytes, int]]: ... + def __getitem__(self, indices: Tuple[int, int]) -> float64: ... + def toArray(self) -> ndarray: ... + def toDense(self) -> DenseMatrix: ... + def __eq__(self, other) -> bool: ... + +class Matrices: + @overload + @staticmethod + def dense( + numRows: int, numCols: int, values: bytes, isTransposed: bool = ... + ) -> DenseMatrix: ... + @overload + @staticmethod + def dense( + numRows: int, numCols: int, values: Iterable[float], isTransposed: bool = ... + ) -> DenseMatrix: ... + @overload + @staticmethod + def sparse( + numRows: int, + numCols: int, + colPtrs: bytes, + rowIndices: bytes, + values: bytes, + isTransposed: bool = ..., + ) -> SparseMatrix: ... + @overload + @staticmethod + def sparse( + numRows: int, + numCols: int, + colPtrs: Iterable[int], + rowIndices: Iterable[int], + values: Iterable[float], + isTransposed: bool = ..., + ) -> SparseMatrix: ... diff --git a/python/pyspark/ml/param/__init__.pyi b/python/pyspark/ml/param/__init__.pyi new file mode 100644 index 0000000000000..23a63c573e452 --- /dev/null +++ b/python/pyspark/ml/param/__init__.pyi @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import abc +from typing import overload +from typing import Any, Callable, Generic, List, Optional +from pyspark.ml._typing import T +import pyspark.ml._typing + +import pyspark.ml.util +from pyspark.ml.linalg import DenseVector, Matrix + +class Param(Generic[T]): + parent: str + name: str + doc: str + typeConverter: Callable[[Any], T] + def __init__( + self, + parent: pyspark.ml.util.Identifiable, + name: str, + doc: str, + typeConverter: Optional[Callable[[Any], T]] = ..., + ) -> None: ... + def __hash__(self) -> int: ... + def __eq__(self, other: Any) -> bool: ... + +class TypeConverters: + @staticmethod + def identity(value: T) -> T: ... + @staticmethod + def toList(value: Any) -> List: ... + @staticmethod + def toListFloat(value: Any) -> List[float]: ... + @staticmethod + def toListInt(value: Any) -> List[int]: ... + @staticmethod + def toListString(value: Any) -> List[str]: ... + @staticmethod + def toVector(value: Any) -> DenseVector: ... + @staticmethod + def toMatrix(value: Any) -> Matrix: ... + @staticmethod + def toFloat(value: Any) -> float: ... + @staticmethod + def toInt(value: Any) -> int: ... + @staticmethod + def toString(value: Any) -> str: ... + @staticmethod + def toBoolean(value: Any) -> bool: ... + +class Params(pyspark.ml.util.Identifiable, metaclass=abc.ABCMeta): + def __init__(self) -> None: ... + @property + def params(self) -> List[Param]: ... + def explainParam(self, param: str) -> str: ... + def explainParams(self) -> str: ... + def getParam(self, paramName: str) -> Param: ... + @overload + def isSet(self, param: str) -> bool: ... + @overload + def isSet(self, param: Param[Any]) -> bool: ... + @overload + def hasDefault(self, param: str) -> bool: ... + @overload + def hasDefault(self, param: Param[Any]) -> bool: ... + @overload + def isDefined(self, param: str) -> bool: ... + @overload + def isDefined(self, param: Param[Any]) -> bool: ... + def hasParam(self, paramName: str) -> bool: ... + @overload + def getOrDefault(self, param: str) -> Any: ... + @overload + def getOrDefault(self, param: Param[T]) -> T: ... + def extractParamMap( + self, extra: Optional[pyspark.ml._typing.ParamMap] = ... + ) -> pyspark.ml._typing.ParamMap: ... + def copy(self, extra: Optional[pyspark.ml._typing.ParamMap] = ...) -> Params: ... + def set(self, param: Param, value: Any) -> None: ... + def clear(self, param: Param) -> None: ... diff --git a/python/pyspark/ml/param/_shared_params_code_gen.pyi b/python/pyspark/ml/param/_shared_params_code_gen.pyi new file mode 100644 index 0000000000000..e436a54c0eaa4 --- /dev/null +++ b/python/pyspark/ml/param/_shared_params_code_gen.pyi @@ -0,0 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +header: str diff --git a/python/pyspark/ml/param/shared.pyi b/python/pyspark/ml/param/shared.pyi new file mode 100644 index 0000000000000..5999c0eaa4661 --- /dev/null +++ b/python/pyspark/ml/param/shared.pyi @@ -0,0 +1,187 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Generic, List +from pyspark.ml._typing import T + +from pyspark.ml.param import * + +class HasMaxIter(Params): + maxIter: Param[int] + def __init__(self) -> None: ... + def getMaxIter(self) -> int: ... + +class HasRegParam(Params): + regParam: Param[float] + def __init__(self) -> None: ... + def getRegParam(self) -> float: ... + +class HasFeaturesCol(Params): + featuresCol: Param[str] + def __init__(self) -> None: ... + def getFeaturesCol(self) -> str: ... + +class HasLabelCol(Params): + labelCol: Param[str] + def __init__(self) -> None: ... + def getLabelCol(self) -> str: ... + +class HasPredictionCol(Params): + predictionCol: Param[str] + def __init__(self) -> None: ... + def getPredictionCol(self) -> str: ... + +class HasProbabilityCol(Params): + probabilityCol: Param[str] + def __init__(self) -> None: ... + def getProbabilityCol(self) -> str: ... + +class HasRawPredictionCol(Params): + rawPredictionCol: Param[str] + def __init__(self) -> None: ... + def getRawPredictionCol(self) -> str: ... + +class HasInputCol(Params): + inputCol: Param[str] + def __init__(self) -> None: ... + def getInputCol(self) -> str: ... + +class HasInputCols(Params): + inputCols: Param[List[str]] + def __init__(self) -> None: ... + def getInputCols(self) -> List[str]: ... + +class HasOutputCol(Params): + outputCol: Param[str] + def __init__(self) -> None: ... + def getOutputCol(self) -> str: ... + +class HasOutputCols(Params): + outputCols: Param[List[str]] + def __init__(self) -> None: ... + def getOutputCols(self) -> List[str]: ... + +class HasNumFeatures(Params): + numFeatures: Param[int] + def __init__(self) -> None: ... + def getNumFeatures(self) -> int: ... + +class HasCheckpointInterval(Params): + checkpointInterval: Param[int] + def __init__(self) -> None: ... + def getCheckpointInterval(self) -> int: ... + +class HasSeed(Params): + seed: Param[int] + def __init__(self) -> None: ... + def getSeed(self) -> int: ... + +class HasTol(Params): + tol: Param[float] + def __init__(self) -> None: ... + def getTol(self) -> float: ... + +class HasRelativeError(Params): + relativeError: Param[float] + def __init__(self) -> None: ... + def getRelativeError(self) -> float: ... + +class HasStepSize(Params): + stepSize: Param[float] + def __init__(self) -> None: ... + def getStepSize(self) -> float: ... + +class HasHandleInvalid(Params): + handleInvalid: Param[str] + def __init__(self) -> None: ... + def getHandleInvalid(self) -> str: ... 
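The typed `Param[T]` attributes on these mixins, combined with the `getOrDefault` overloads in `pyspark/ml/param/__init__.pyi` above, let mypy infer concrete value types for shared params. A minimal sketch of the kind of check this enables (standard PySpark install assumed; the choice of `LogisticRegression` and the variable names are only illustrative):

```
from pyspark.ml.classification import LogisticRegression
from pyspark.sql import SparkSession

# A running JVM gateway is needed to construct ML params.
spark = SparkSession.builder.getOrCreate()

lr = LogisticRegression(maxIter=5)

# HasMaxIter.getMaxIter() is annotated to return int, so the arithmetic checks.
n_iter: int = lr.getMaxIter() + 1

# getOrDefault is overloaded on Param[T]; regParam is Param[float], so mypy
# infers float here without a cast.
reg: float = lr.getOrDefault(lr.regParam)
```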
+ +class HasElasticNetParam(Params): + elasticNetParam: Param[float] + def __init__(self) -> None: ... + def getElasticNetParam(self) -> float: ... + +class HasFitIntercept(Params): + fitIntercept: Param[bool] + def __init__(self) -> None: ... + def getFitIntercept(self) -> bool: ... + +class HasStandardization(Params): + standardization: Param[bool] + def __init__(self) -> None: ... + def getStandardization(self) -> bool: ... + +class HasThresholds(Params): + thresholds: Param[List[float]] + def __init__(self) -> None: ... + def getThresholds(self) -> List[float]: ... + +class HasThreshold(Params): + threshold: Param[float] + def __init__(self) -> None: ... + def getThreshold(self) -> float: ... + +class HasWeightCol(Params): + weightCol: Param[str] + def __init__(self) -> None: ... + def getWeightCol(self) -> str: ... + +class HasSolver(Params): + solver: Param[str] + def __init__(self) -> None: ... + def getSolver(self) -> str: ... + +class HasVarianceCol(Params): + varianceCol: Param[str] + def __init__(self) -> None: ... + def getVarianceCol(self) -> str: ... + +class HasAggregationDepth(Params): + aggregationDepth: Param[int] + def __init__(self) -> None: ... + def getAggregationDepth(self) -> int: ... + +class HasParallelism(Params): + parallelism: Param[int] + def __init__(self) -> None: ... + def getParallelism(self) -> int: ... + +class HasCollectSubModels(Params): + collectSubModels: Param[bool] + def __init__(self) -> None: ... + def getCollectSubModels(self) -> bool: ... + +class HasLoss(Params): + loss: Param[str] + def __init__(self) -> None: ... + def getLoss(self) -> str: ... + +class HasValidationIndicatorCol(Params): + validationIndicatorCol: Param[str] + def __init__(self) -> None: ... + def getValidationIndicatorCol(self) -> str: ... + +class HasDistanceMeasure(Params): + distanceMeasure: Param[str] + def __init__(self) -> None: ... + def getDistanceMeasure(self) -> str: ... + +class HasBlockSize(Params): + blockSize: Param[int] + def __init__(self) -> None: ... + def getBlockSize(self) -> int: ... diff --git a/python/pyspark/ml/pipeline.pyi b/python/pyspark/ml/pipeline.pyi new file mode 100644 index 0000000000000..44680586d70d1 --- /dev/null +++ b/python/pyspark/ml/pipeline.pyi @@ -0,0 +1,97 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +from pyspark.ml._typing import PipelineStage +from pyspark.context import SparkContext +from pyspark.ml.base import Estimator, Model, Transformer +from pyspark.ml.param import Param +from pyspark.ml.util import ( # noqa: F401 + DefaultParamsReader as DefaultParamsReader, + DefaultParamsWriter as DefaultParamsWriter, + JavaMLReader as JavaMLReader, + JavaMLWritable as JavaMLWritable, + JavaMLWriter as JavaMLWriter, + MLReadable as MLReadable, + MLReader as MLReader, + MLWritable as MLWritable, + MLWriter as MLWriter, +) + +class Pipeline(Estimator[PipelineModel], MLReadable[Pipeline], MLWritable): + stages: List[PipelineStage] + def __init__(self, *, stages: Optional[List[PipelineStage]] = ...) -> None: ... + def setStages(self, stages: List[PipelineStage]) -> Pipeline: ... + def getStages(self) -> List[PipelineStage]: ... + def setParams(self, *, stages: Optional[List[PipelineStage]] = ...) -> Pipeline: ... + def copy(self, extra: Optional[Dict[Param, str]] = ...) -> Pipeline: ... + def write(self) -> JavaMLWriter: ... + def save(self, path: str) -> None: ... + @classmethod + def read(cls) -> PipelineReader: ... + +class PipelineWriter(MLWriter): + instance: Pipeline + def __init__(self, instance: Pipeline) -> None: ... + def saveImpl(self, path: str) -> None: ... + +class PipelineReader(MLReader): + cls: Type[Pipeline] + def __init__(self, cls: Type[Pipeline]) -> None: ... + def load(self, path: str) -> Pipeline: ... + +class PipelineModelWriter(MLWriter): + instance: PipelineModel + def __init__(self, instance: PipelineModel) -> None: ... + def saveImpl(self, path: str) -> None: ... + +class PipelineModelReader(MLReader): + cls: Type[PipelineModel] + def __init__(self, cls: Type[PipelineModel]) -> None: ... + def load(self, path: str) -> PipelineModel: ... + +class PipelineModel(Model, MLReadable[PipelineModel], MLWritable): + stages: List[PipelineStage] + def __init__(self, stages: List[Transformer]) -> None: ... + def copy(self, extra: Optional[Dict[Param, Any]] = ...) -> PipelineModel: ... + def write(self) -> JavaMLWriter: ... + def save(self, path: str) -> None: ... + @classmethod + def read(cls) -> PipelineModelReader: ... + +class PipelineSharedReadWrite: + @staticmethod + def checkStagesForJava(stages: List[PipelineStage]) -> bool: ... + @staticmethod + def validateStages(stages: List[PipelineStage]) -> None: ... + @staticmethod + def saveImpl( + instance: Union[Pipeline, PipelineModel], + stages: List[PipelineStage], + sc: SparkContext, + path: str, + ) -> None: ... + @staticmethod + def load( + metadata: Dict[str, Any], sc: SparkContext, path: str + ) -> Tuple[str, List[PipelineStage]]: ... + @staticmethod + def getStagePath( + stageUid: str, stageIdx: int, numStages: int, stagesDir: str + ) -> str: ... diff --git a/python/pyspark/ml/recommendation.pyi b/python/pyspark/ml/recommendation.pyi new file mode 100644 index 0000000000000..390486b45c5e6 --- /dev/null +++ b/python/pyspark/ml/recommendation.pyi @@ -0,0 +1,152 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Optional + +import sys # noqa: F401 + +from pyspark import since, keyword_only # noqa: F401 +from pyspark.ml.param.shared import ( + HasBlockSize, + HasCheckpointInterval, + HasMaxIter, + HasPredictionCol, + HasRegParam, + HasSeed, +) +from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.common import inherit_doc # noqa: F401 +from pyspark.ml.param import Param +from pyspark.ml.util import JavaMLWritable, JavaMLReadable + +from pyspark.sql.dataframe import DataFrame + +class _ALSModelParams(HasPredictionCol, HasBlockSize): + userCol: Param[str] + itemCol: Param[str] + coldStartStrategy: Param[str] + def getUserCol(self) -> str: ... + def getItemCol(self) -> str: ... + def getColdStartStrategy(self) -> str: ... + +class _ALSParams( + _ALSModelParams, HasMaxIter, HasRegParam, HasCheckpointInterval, HasSeed +): + rank: Param[int] + numUserBlocks: Param[int] + numItemBlocks: Param[int] + implicitPrefs: Param[bool] + alpha: Param[float] + ratingCol: Param[str] + nonnegative: Param[bool] + intermediateStorageLevel: Param[str] + finalStorageLevel: Param[str] + def __init__(self, *args: Any): ... + def getRank(self) -> int: ... + def getNumUserBlocks(self) -> int: ... + def getNumItemBlocks(self) -> int: ... + def getImplicitPrefs(self) -> bool: ... + def getAlpha(self) -> float: ... + def getRatingCol(self) -> str: ... + def getNonnegative(self) -> bool: ... + def getIntermediateStorageLevel(self) -> str: ... + def getFinalStorageLevel(self) -> str: ... + +class ALS(JavaEstimator[ALSModel], _ALSParams, JavaMLWritable, JavaMLReadable[ALS]): + def __init__( + self, + *, + rank: int = ..., + maxIter: int = ..., + regParam: float = ..., + numUserBlocks: int = ..., + numItemBlocks: int = ..., + implicitPrefs: bool = ..., + alpha: float = ..., + userCol: str = ..., + itemCol: str = ..., + seed: Optional[int] = ..., + ratingCol: str = ..., + nonnegative: bool = ..., + checkpointInterval: int = ..., + intermediateStorageLevel: str = ..., + finalStorageLevel: str = ..., + coldStartStrategy: str = ..., + blockSize: int = ... + ) -> None: ... + def setParams( + self, + *, + rank: int = ..., + maxIter: int = ..., + regParam: float = ..., + numUserBlocks: int = ..., + numItemBlocks: int = ..., + implicitPrefs: bool = ..., + alpha: float = ..., + userCol: str = ..., + itemCol: str = ..., + seed: Optional[int] = ..., + ratingCol: str = ..., + nonnegative: bool = ..., + checkpointInterval: int = ..., + intermediateStorageLevel: str = ..., + finalStorageLevel: str = ..., + coldStartStrategy: str = ..., + blockSize: int = ... + ) -> ALS: ... + def setRank(self, value: int) -> ALS: ... + def setNumUserBlocks(self, value: int) -> ALS: ... + def setNumItemBlocks(self, value: int) -> ALS: ... + def setNumBlocks(self, value: int) -> ALS: ... + def setImplicitPrefs(self, value: bool) -> ALS: ... + def setAlpha(self, value: float) -> ALS: ... + def setUserCol(self, value: str) -> ALS: ... + def setItemCol(self, value: str) -> ALS: ... + def setRatingCol(self, value: str) -> ALS: ... + def setNonnegative(self, value: bool) -> ALS: ... 
+ def setIntermediateStorageLevel(self, value: str) -> ALS: ... + def setFinalStorageLevel(self, value: str) -> ALS: ... + def setColdStartStrategy(self, value: str) -> ALS: ... + def setMaxIter(self, value: int) -> ALS: ... + def setRegParam(self, value: float) -> ALS: ... + def setPredictionCol(self, value: str) -> ALS: ... + def setCheckpointInterval(self, value: int) -> ALS: ... + def setSeed(self, value: int) -> ALS: ... + def setBlockSize(self, value: int) -> ALS: ... + +class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable, JavaMLReadable[ALSModel]): + def setUserCol(self, value: str) -> ALSModel: ... + def setItemCol(self, value: str) -> ALSModel: ... + def setColdStartStrategy(self, value: str) -> ALSModel: ... + def setPredictionCol(self, value: str) -> ALSModel: ... + def setBlockSize(self, value: int) -> ALSModel: ... + @property + def rank(self) -> int: ... + @property + def userFactors(self) -> DataFrame: ... + @property + def itemFactors(self) -> DataFrame: ... + def recommendForAllUsers(self, numItems: int) -> DataFrame: ... + def recommendForAllItems(self, numUsers: int) -> DataFrame: ... + def recommendForUserSubset( + self, dataset: DataFrame, numItems: int + ) -> DataFrame: ... + def recommendForItemSubset( + self, dataset: DataFrame, numUsers: int + ) -> DataFrame: ... diff --git a/python/pyspark/ml/regression.pyi b/python/pyspark/ml/regression.pyi new file mode 100644 index 0000000000000..991eb4f12ac85 --- /dev/null +++ b/python/pyspark/ml/regression.pyi @@ -0,0 +1,825 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, List, Optional +from pyspark.ml._typing import JM, M, T + +import abc +from pyspark.ml import PredictionModel, Predictor +from pyspark.ml.base import _PredictorParams +from pyspark.ml.param.shared import ( + HasAggregationDepth, + HasBlockSize, + HasElasticNetParam, + HasFeaturesCol, + HasFitIntercept, + HasLabelCol, + HasLoss, + HasMaxIter, + HasPredictionCol, + HasRegParam, + HasSeed, + HasSolver, + HasStandardization, + HasStepSize, + HasTol, + HasVarianceCol, + HasWeightCol, +) +from pyspark.ml.tree import ( + _DecisionTreeModel, + _DecisionTreeParams, + _GBTParams, + _RandomForestParams, + _TreeEnsembleModel, + _TreeRegressorParams, +) +from pyspark.ml.util import ( + GeneralJavaMLWritable, + HasTrainingSummary, + JavaMLReadable, + JavaMLWritable, +) +from pyspark.ml.wrapper import ( + JavaEstimator, + JavaModel, + JavaPredictionModel, + JavaPredictor, + JavaWrapper, +) + +from pyspark.ml.linalg import Matrix, Vector +from pyspark.ml.param import Param +from pyspark.sql.dataframe import DataFrame + +class Regressor(Predictor[M], _PredictorParams, metaclass=abc.ABCMeta): ... +class RegressionModel(PredictionModel[T], _PredictorParams, metaclass=abc.ABCMeta): ... 
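Because the estimators above are declared against generic bases such as `JavaEstimator[ALSModel]`, `fit()` is annotated to return the concrete model class rather than a bare `Model`. A small usage sketch (tiny in-memory data; the column names are illustrative, not prescribed by the stubs):

```
from pyspark.ml.recommendation import ALS, ALSModel
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
ratings = spark.createDataFrame(
    [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0)],
    ["user", "item", "rating"],
)

als = ALS(rank=5, maxIter=2, userCol="user", itemCol="item", ratingCol="rating")

# fit() is typed as returning ALSModel, so model-specific methods resolve.
model: ALSModel = als.fit(ratings)
recs = model.recommendForAllUsers(3)  # annotated to return DataFrame
```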
+class _JavaRegressor(Regressor, JavaPredictor[JM], metaclass=abc.ABCMeta): ... +class _JavaRegressionModel( + RegressionModel, JavaPredictionModel[T], metaclass=abc.ABCMeta +): ... + +class _LinearRegressionParams( + _PredictorParams, + HasRegParam, + HasElasticNetParam, + HasMaxIter, + HasTol, + HasFitIntercept, + HasStandardization, + HasWeightCol, + HasSolver, + HasAggregationDepth, + HasLoss, + HasBlockSize, +): + solver: Param[str] + loss: Param[str] + epsilon: Param[float] + def __init__(self, *args: Any): ... + def getEpsilon(self) -> float: ... + +class LinearRegression( + _JavaRegressor[LinearRegressionModel], + _LinearRegressionParams, + JavaMLWritable, + JavaMLReadable[LinearRegression], +): + def __init__( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + maxIter: int = ..., + regParam: float = ..., + elasticNetParam: float = ..., + tol: float = ..., + fitIntercept: bool = ..., + standardization: bool = ..., + solver: str = ..., + weightCol: Optional[str] = ..., + aggregationDepth: int = ..., + epsilon: float = ..., + blockSize: int = ... + ) -> None: ... + def setParams( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + maxIter: int = ..., + regParam: float = ..., + elasticNetParam: float = ..., + tol: float = ..., + fitIntercept: bool = ..., + standardization: bool = ..., + solver: str = ..., + weightCol: Optional[str] = ..., + aggregationDepth: int = ..., + epsilon: float = ..., + blockSize: int = ... + ) -> LinearRegression: ... + def setEpsilon(self, value: float) -> LinearRegression: ... + def setMaxIter(self, value: int) -> LinearRegression: ... + def setRegParam(self, value: float) -> LinearRegression: ... + def setTol(self, value: float) -> LinearRegression: ... + def setElasticNetParam(self, value: float) -> LinearRegression: ... + def setFitIntercept(self, value: bool) -> LinearRegression: ... + def setStandardization(self, value: bool) -> LinearRegression: ... + def setWeightCol(self, value: str) -> LinearRegression: ... + def setSolver(self, value: str) -> LinearRegression: ... + def setAggregationDepth(self, value: int) -> LinearRegression: ... + def setLoss(self, value: str) -> LinearRegression: ... + def setBlockSize(self, value: int) -> LinearRegression: ... + +class LinearRegressionModel( + _JavaRegressionModel[Vector], + _LinearRegressionParams, + GeneralJavaMLWritable, + JavaMLReadable[LinearRegressionModel], + HasTrainingSummary[LinearRegressionSummary], +): + @property + def coefficients(self) -> Vector: ... + @property + def intercept(self) -> float: ... + @property + def summary(self) -> LinearRegressionTrainingSummary: ... + def evaluate(self, dataset: DataFrame) -> LinearRegressionSummary: ... + +class LinearRegressionSummary(JavaWrapper): + @property + def predictions(self) -> DataFrame: ... + @property + def predictionCol(self) -> str: ... + @property + def labelCol(self) -> str: ... + @property + def featuresCol(self) -> str: ... + @property + def explainedVariance(self) -> float: ... + @property + def meanAbsoluteError(self) -> float: ... + @property + def meanSquaredError(self) -> float: ... + @property + def rootMeanSquaredError(self) -> float: ... + @property + def r2(self) -> float: ... + @property + def r2adj(self) -> float: ... + @property + def residuals(self) -> DataFrame: ... + @property + def numInstances(self) -> int: ... + @property + def devianceResiduals(self) -> List[float]: ... + @property + def coefficientStandardErrors(self) -> List[float]: ... 
+ @property + def tValues(self) -> List[float]: ... + @property + def pValues(self) -> List[float]: ... + +class LinearRegressionTrainingSummary(LinearRegressionSummary): + @property + def objectiveHistory(self) -> List[float]: ... + @property + def totalIterations(self) -> int: ... + +class _IsotonicRegressionParams( + HasFeaturesCol, HasLabelCol, HasPredictionCol, HasWeightCol +): + isotonic: Param[bool] + featureIndex: Param[int] + def getIsotonic(self) -> bool: ... + def getFeatureIndex(self) -> int: ... + +class IsotonicRegression( + JavaEstimator[IsotonicRegressionModel], + _IsotonicRegressionParams, + HasWeightCol, + JavaMLWritable, + JavaMLReadable[IsotonicRegression], +): + def __init__( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + weightCol: Optional[str] = ..., + isotonic: bool = ..., + featureIndex: int = ... + ) -> None: ... + def setParams( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + weightCol: Optional[str] = ..., + isotonic: bool = ..., + featureIndex: int = ... + ) -> IsotonicRegression: ... + def setIsotonic(self, value: bool) -> IsotonicRegression: ... + def setFeatureIndex(self, value: int) -> IsotonicRegression: ... + def setFeaturesCol(self, value: str) -> IsotonicRegression: ... + def setPredictionCol(self, value: str) -> IsotonicRegression: ... + def setLabelCol(self, value: str) -> IsotonicRegression: ... + def setWeightCol(self, value: str) -> IsotonicRegression: ... + +class IsotonicRegressionModel( + JavaModel, + _IsotonicRegressionParams, + JavaMLWritable, + JavaMLReadable[IsotonicRegressionModel], +): + def setFeaturesCol(self, value: str) -> IsotonicRegressionModel: ... + def setPredictionCol(self, value: str) -> IsotonicRegressionModel: ... + def setFeatureIndex(self, value: int) -> IsotonicRegressionModel: ... + @property + def boundaries(self) -> Vector: ... + @property + def predictions(self) -> Vector: ... + @property + def numFeatures(self) -> int: ... + def predict(self, value: float) -> float: ... + +class _DecisionTreeRegressorParams( + _DecisionTreeParams, _TreeRegressorParams, HasVarianceCol +): + def __init__(self, *args: Any): ... + +class DecisionTreeRegressor( + _JavaRegressor[DecisionTreeRegressionModel], + _DecisionTreeRegressorParams, + JavaMLWritable, + JavaMLReadable[DecisionTreeRegressor], +): + def __init__( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + maxDepth: int = ..., + maxBins: int = ..., + minInstancesPerNode: int = ..., + minInfoGain: float = ..., + maxMemoryInMB: int = ..., + cacheNodeIds: bool = ..., + checkpointInterval: int = ..., + impurity: str = ..., + seed: Optional[int] = ..., + varianceCol: Optional[str] = ..., + weightCol: Optional[str] = ..., + leafCol: str = ..., + minWeightFractionPerNode: float = ... + ) -> None: ... + def setParams( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + maxDepth: int = ..., + maxBins: int = ..., + minInstancesPerNode: int = ..., + minInfoGain: float = ..., + maxMemoryInMB: int = ..., + cacheNodeIds: bool = ..., + checkpointInterval: int = ..., + impurity: str = ..., + seed: Optional[int] = ..., + varianceCol: Optional[str] = ..., + weightCol: Optional[str] = ..., + leafCol: str = ..., + minWeightFractionPerNode: float = ... + ) -> DecisionTreeRegressor: ... + def setMaxDepth(self, value: int) -> DecisionTreeRegressor: ... + def setMaxBins(self, value: int) -> DecisionTreeRegressor: ... 
+ def setMinInstancesPerNode(self, value: int) -> DecisionTreeRegressor: ... + def setMinWeightFractionPerNode(self, value: float) -> DecisionTreeRegressor: ... + def setMinInfoGain(self, value: float) -> DecisionTreeRegressor: ... + def setMaxMemoryInMB(self, value: int) -> DecisionTreeRegressor: ... + def setCacheNodeIds(self, value: bool) -> DecisionTreeRegressor: ... + def setImpurity(self, value: str) -> DecisionTreeRegressor: ... + def setCheckpointInterval(self, value: int) -> DecisionTreeRegressor: ... + def setSeed(self, value: int) -> DecisionTreeRegressor: ... + def setWeightCol(self, value: str) -> DecisionTreeRegressor: ... + def setVarianceCol(self, value: str) -> DecisionTreeRegressor: ... + +class DecisionTreeRegressionModel( + _JavaRegressionModel[Vector], + _DecisionTreeModel, + _DecisionTreeRegressorParams, + JavaMLWritable, + JavaMLReadable[DecisionTreeRegressionModel], +): + def setVarianceCol(self, value: str) -> DecisionTreeRegressionModel: ... + @property + def featureImportances(self) -> Vector: ... + +class _RandomForestRegressorParams(_RandomForestParams, _TreeRegressorParams): + def __init__(self, *args: Any): ... + +class RandomForestRegressor( + _JavaRegressor[RandomForestRegressionModel], + _RandomForestRegressorParams, + JavaMLWritable, + JavaMLReadable[RandomForestRegressor], +): + def __init__( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + maxDepth: int = ..., + maxBins: int = ..., + minInstancesPerNode: int = ..., + minInfoGain: float = ..., + maxMemoryInMB: int = ..., + cacheNodeIds: bool = ..., + checkpointInterval: int = ..., + impurity: str = ..., + subsamplingRate: float = ..., + seed: Optional[int] = ..., + numTrees: int = ..., + featureSubsetStrategy: str = ..., + leafCol: str = ..., + minWeightFractionPerNode: float = ..., + weightCol: Optional[str] = ..., + bootstrap: Optional[bool] = ... + ) -> None: ... + def setParams( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + maxDepth: int = ..., + maxBins: int = ..., + minInstancesPerNode: int = ..., + minInfoGain: float = ..., + maxMemoryInMB: int = ..., + cacheNodeIds: bool = ..., + checkpointInterval: int = ..., + impurity: str = ..., + subsamplingRate: float = ..., + seed: Optional[int] = ..., + numTrees: int = ..., + featureSubsetStrategy: str = ..., + leafCol: str = ..., + minWeightFractionPerNode: float = ..., + weightCol: Optional[str] = ..., + bootstrap: Optional[bool] = ... + ) -> RandomForestRegressor: ... + def setMaxDepth(self, value: int) -> RandomForestRegressor: ... + def setMaxBins(self, value: int) -> RandomForestRegressor: ... + def setMinInstancesPerNode(self, value: int) -> RandomForestRegressor: ... + def setMinInfoGain(self, value: float) -> RandomForestRegressor: ... + def setMaxMemoryInMB(self, value: int) -> RandomForestRegressor: ... + def setCacheNodeIds(self, value: bool) -> RandomForestRegressor: ... + def setImpurity(self, value: str) -> RandomForestRegressor: ... + def setNumTrees(self, value: int) -> RandomForestRegressor: ... + def setBootstrap(self, value: bool) -> RandomForestRegressor: ... + def setSubsamplingRate(self, value: float) -> RandomForestRegressor: ... + def setFeatureSubsetStrategy(self, value: str) -> RandomForestRegressor: ... + def setCheckpointInterval(self, value: int) -> RandomForestRegressor: ... + def setSeed(self, value: int) -> RandomForestRegressor: ... + def setWeightCol(self, value: str) -> RandomForestRegressor: ... 
+    def setMinWeightFractionPerNode(self, value: float) -> RandomForestRegressor: ...
+
+class RandomForestRegressionModel(
+    _JavaRegressionModel[Vector],
+    _TreeEnsembleModel,
+    _RandomForestRegressorParams,
+    JavaMLWritable,
+    JavaMLReadable,
+):
+    @property
+    def trees(self) -> List[DecisionTreeRegressionModel]: ...
+    @property
+    def featureImportances(self) -> Vector: ...
+
+class _GBTRegressorParams(_GBTParams, _TreeRegressorParams):
+    supportedLossTypes: List[str]
+    lossType: Param[str]
+    def __init__(self, *args: Any): ...
+    def getLossType(self) -> str: ...
+
+class GBTRegressor(
+    _JavaRegressor[GBTRegressionModel],
+    _GBTRegressorParams,
+    JavaMLWritable,
+    JavaMLReadable[GBTRegressor],
+):
+    def __init__(
+        self,
+        *,
+        featuresCol: str = ...,
+        labelCol: str = ...,
+        predictionCol: str = ...,
+        maxDepth: int = ...,
+        maxBins: int = ...,
+        minInstancesPerNode: int = ...,
+        minInfoGain: float = ...,
+        maxMemoryInMB: int = ...,
+        cacheNodeIds: bool = ...,
+        subsamplingRate: float = ...,
+        checkpointInterval: int = ...,
+        lossType: str = ...,
+        maxIter: int = ...,
+        stepSize: float = ...,
+        seed: Optional[int] = ...,
+        impurity: str = ...,
+        featureSubsetStrategy: str = ...,
+        validationTol: float = ...,
+        validationIndicatorCol: Optional[str] = ...,
+        leafCol: str = ...,
+        minWeightFractionPerNode: float = ...,
+        weightCol: Optional[str] = ...
+    ) -> None: ...
+    def setParams(
+        self,
+        *,
+        featuresCol: str = ...,
+        labelCol: str = ...,
+        predictionCol: str = ...,
+        maxDepth: int = ...,
+        maxBins: int = ...,
+        minInstancesPerNode: int = ...,
+        minInfoGain: float = ...,
+        maxMemoryInMB: int = ...,
+        cacheNodeIds: bool = ...,
+        subsamplingRate: float = ...,
+        checkpointInterval: int = ...,
+        lossType: str = ...,
+        maxIter: int = ...,
+        stepSize: float = ...,
+        seed: Optional[int] = ...,
+        impurity: str = ...,
+        featureSubsetStrategy: str = ...,
+        validationTol: float = ...,
+        validationIndicatorCol: Optional[str] = ...,
+        leafCol: str = ...,
+        minWeightFractionPerNode: float = ...,
+        weightCol: Optional[str] = ...
+    ) -> GBTRegressor: ...
+    def setMaxDepth(self, value: int) -> GBTRegressor: ...
+    def setMaxBins(self, value: int) -> GBTRegressor: ...
+    def setMinInstancesPerNode(self, value: int) -> GBTRegressor: ...
+    def setMinInfoGain(self, value: float) -> GBTRegressor: ...
+    def setMaxMemoryInMB(self, value: int) -> GBTRegressor: ...
+    def setCacheNodeIds(self, value: bool) -> GBTRegressor: ...
+    def setImpurity(self, value: str) -> GBTRegressor: ...
+    def setLossType(self, value: str) -> GBTRegressor: ...
+    def setSubsamplingRate(self, value: float) -> GBTRegressor: ...
+    def setFeatureSubsetStrategy(self, value: str) -> GBTRegressor: ...
+    def setValidationIndicatorCol(self, value: str) -> GBTRegressor: ...
+    def setMaxIter(self, value: int) -> GBTRegressor: ...
+    def setCheckpointInterval(self, value: int) -> GBTRegressor: ...
+    def setSeed(self, value: int) -> GBTRegressor: ...
+    def setStepSize(self, value: float) -> GBTRegressor: ...
+    def setWeightCol(self, value: str) -> GBTRegressor: ...
+    def setMinWeightFractionPerNode(self, value: float) -> GBTRegressor: ...
+
+class GBTRegressionModel(
+    _JavaRegressionModel[Vector],
+    _TreeEnsembleModel,
+    _GBTRegressorParams,
+    JavaMLWritable,
+    JavaMLReadable[GBTRegressionModel],
+):
+    @property
+    def featureImportances(self) -> Vector: ...
+    @property
+    def trees(self) -> List[DecisionTreeRegressionModel]: ...
+    def evaluateEachIteration(self, dataset: DataFrame, loss: str) -> List[float]: ...
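The ensemble models above keep concrete member types (`trees` as `List[DecisionTreeRegressionModel]`, `featureImportances` as `Vector`), so code that walks a fitted model stays precisely typed. A hedged sketch with toy data (the dataset and hyperparameters are illustrative only):

```
from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import GBTRegressionModel, GBTRegressor
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1.0, Vectors.dense(0.0)), (0.0, Vectors.dense(1.0))],
    ["label", "features"],
)

model: GBTRegressionModel = GBTRegressor(maxIter=2, maxDepth=2).fit(df)

importances = model.featureImportances  # pyspark.ml.linalg.Vector
first_tree = model.trees[0]             # DecisionTreeRegressionModel
depth: int = first_tree.depth           # int, via _DecisionTreeModel
```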
+ +class _AFTSurvivalRegressionParams( + _PredictorParams, + HasMaxIter, + HasTol, + HasFitIntercept, + HasAggregationDepth, + HasBlockSize, +): + censorCol: Param[str] + quantileProbabilities: Param[List[float]] + quantilesCol: Param[str] + def __init__(self, *args: Any): ... + def getCensorCol(self) -> str: ... + def getQuantileProbabilities(self) -> List[float]: ... + def getQuantilesCol(self) -> str: ... + +class AFTSurvivalRegression( + _JavaRegressor[AFTSurvivalRegressionModel], + _AFTSurvivalRegressionParams, + JavaMLWritable, + JavaMLReadable[AFTSurvivalRegression], +): + def __init__( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + fitIntercept: bool = ..., + maxIter: int = ..., + tol: float = ..., + censorCol: str = ..., + quantileProbabilities: List[float] = ..., + quantilesCol: Optional[str] = ..., + aggregationDepth: int = ..., + blockSize: int = ... + ) -> None: ... + def setParams( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + fitIntercept: bool = ..., + maxIter: int = ..., + tol: float = ..., + censorCol: str = ..., + quantileProbabilities: List[float] = ..., + quantilesCol: Optional[str] = ..., + aggregationDepth: int = ..., + blockSize: int = ... + ) -> AFTSurvivalRegression: ... + def setCensorCol(self, value: str) -> AFTSurvivalRegression: ... + def setQuantileProbabilities(self, value: List[float]) -> AFTSurvivalRegression: ... + def setQuantilesCol(self, value: str) -> AFTSurvivalRegression: ... + def setMaxIter(self, value: int) -> AFTSurvivalRegression: ... + def setTol(self, value: float) -> AFTSurvivalRegression: ... + def setFitIntercept(self, value: bool) -> AFTSurvivalRegression: ... + def setAggregationDepth(self, value: int) -> AFTSurvivalRegression: ... + def setBlockSize(self, value: int) -> AFTSurvivalRegression: ... + +class AFTSurvivalRegressionModel( + _JavaRegressionModel[Vector], + _AFTSurvivalRegressionParams, + JavaMLWritable, + JavaMLReadable[AFTSurvivalRegressionModel], +): + def setQuantileProbabilities( + self, value: List[float] + ) -> AFTSurvivalRegressionModel: ... + def setQuantilesCol(self, value: str) -> AFTSurvivalRegressionModel: ... + @property + def coefficients(self) -> Vector: ... + @property + def intercept(self) -> float: ... + @property + def scale(self) -> float: ... + def predictQuantiles(self, features: Vector) -> Vector: ... + def predict(self, features: Vector) -> float: ... + +class _GeneralizedLinearRegressionParams( + _PredictorParams, + HasFitIntercept, + HasMaxIter, + HasTol, + HasRegParam, + HasWeightCol, + HasSolver, + HasAggregationDepth, +): + family: Param[str] + link: Param[str] + linkPredictionCol: Param[str] + variancePower: Param[float] + linkPower: Param[float] + solver: Param[str] + offsetCol: Param[str] + def __init__(self, *args: Any): ... + def getFamily(self) -> str: ... + def getLinkPredictionCol(self) -> str: ... + def getLink(self) -> str: ... + def getVariancePower(self) -> float: ... + def getLinkPower(self) -> float: ... + def getOffsetCol(self) -> str: ... 
+ +class GeneralizedLinearRegression( + _JavaRegressor[GeneralizedLinearRegressionModel], + _GeneralizedLinearRegressionParams, + JavaMLWritable, + JavaMLReadable[GeneralizedLinearRegression], +): + def __init__( + self, + *, + labelCol: str = ..., + featuresCol: str = ..., + predictionCol: str = ..., + family: str = ..., + link: Optional[str] = ..., + fitIntercept: bool = ..., + maxIter: int = ..., + tol: float = ..., + regParam: float = ..., + weightCol: Optional[str] = ..., + solver: str = ..., + linkPredictionCol: Optional[str] = ..., + variancePower: float = ..., + linkPower: Optional[float] = ..., + offsetCol: Optional[str] = ..., + aggregationDepth: int = ... + ) -> None: ... + def setParams( + self, + *, + labelCol: str = ..., + featuresCol: str = ..., + predictionCol: str = ..., + family: str = ..., + link: Optional[str] = ..., + fitIntercept: bool = ..., + maxIter: int = ..., + tol: float = ..., + regParam: float = ..., + weightCol: Optional[str] = ..., + solver: str = ..., + linkPredictionCol: Optional[str] = ..., + variancePower: float = ..., + linkPower: Optional[float] = ..., + offsetCol: Optional[str] = ..., + aggregationDepth: int = ... + ) -> GeneralizedLinearRegression: ... + def setFamily(self, value: str) -> GeneralizedLinearRegression: ... + def setLinkPredictionCol(self, value: str) -> GeneralizedLinearRegression: ... + def setLink(self, value: str) -> GeneralizedLinearRegression: ... + def setVariancePower(self, value: float) -> GeneralizedLinearRegression: ... + def setLinkPower(self, value: float) -> GeneralizedLinearRegression: ... + def setOffsetCol(self, value: str) -> GeneralizedLinearRegression: ... + def setMaxIter(self, value: int) -> GeneralizedLinearRegression: ... + def setRegParam(self, value: float) -> GeneralizedLinearRegression: ... + def setTol(self, value: float) -> GeneralizedLinearRegression: ... + def setFitIntercept(self, value: bool) -> GeneralizedLinearRegression: ... + def setWeightCol(self, value: str) -> GeneralizedLinearRegression: ... + def setSolver(self, value: str) -> GeneralizedLinearRegression: ... + def setAggregationDepth(self, value: int) -> GeneralizedLinearRegression: ... + +class GeneralizedLinearRegressionModel( + _JavaRegressionModel[Vector], + _GeneralizedLinearRegressionParams, + JavaMLWritable, + JavaMLReadable[GeneralizedLinearRegressionModel], + HasTrainingSummary[GeneralizedLinearRegressionTrainingSummary], +): + def setLinkPredictionCol(self, value: str) -> GeneralizedLinearRegressionModel: ... + @property + def coefficients(self) -> Vector: ... + @property + def intercept(self) -> float: ... + @property + def summary(self) -> GeneralizedLinearRegressionTrainingSummary: ... + def evaluate(self, dataset: DataFrame) -> GeneralizedLinearRegressionSummary: ... + +class GeneralizedLinearRegressionSummary(JavaWrapper): + @property + def predictions(self) -> DataFrame: ... + @property + def predictionCol(self) -> str: ... + @property + def rank(self) -> int: ... + @property + def degreesOfFreedom(self) -> int: ... + @property + def residualDegreeOfFreedom(self) -> int: ... + @property + def residualDegreeOfFreedomNull(self) -> int: ... + def residuals(self, residualsType: str = ...) -> DataFrame: ... + @property + def nullDeviance(self) -> float: ... + @property + def deviance(self) -> float: ... + @property + def dispersion(self) -> float: ... + @property + def aic(self) -> float: ... + +class GeneralizedLinearRegressionTrainingSummary(GeneralizedLinearRegressionSummary): + @property + def numIterations(self) -> int: ... 
+ @property + def solver(self) -> str: ... + @property + def coefficientStandardErrors(self) -> List[float]: ... + @property + def tValues(self) -> List[float]: ... + @property + def pValues(self) -> List[float]: ... + +class _FactorizationMachinesParams( + _PredictorParams, + HasMaxIter, + HasStepSize, + HasTol, + HasSolver, + HasSeed, + HasFitIntercept, + HasRegParam, + HasWeightCol, +): + factorSize: Param[int] + fitLinear: Param[bool] + miniBatchFraction: Param[float] + initStd: Param[float] + solver: Param[str] + def __init__(self, *args: Any): ... + def getFactorSize(self): ... + def getFitLinear(self): ... + def getMiniBatchFraction(self): ... + def getInitStd(self): ... + +class FMRegressor( + _JavaRegressor[FMRegressionModel], + _FactorizationMachinesParams, + JavaMLWritable, + JavaMLReadable[FMRegressor], +): + factorSize: Param[int] + fitLinear: Param[bool] + miniBatchFraction: Param[float] + initStd: Param[float] + solver: Param[str] + def __init__( + self, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + factorSize: int = ..., + fitIntercept: bool = ..., + fitLinear: bool = ..., + regParam: float = ..., + miniBatchFraction: float = ..., + initStd: float = ..., + maxIter: int = ..., + stepSize: float = ..., + tol: float = ..., + solver: str = ..., + seed: Optional[int] = ..., + ) -> None: ... + def setParams( + self, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + factorSize: int = ..., + fitIntercept: bool = ..., + fitLinear: bool = ..., + regParam: float = ..., + miniBatchFraction: float = ..., + initStd: float = ..., + maxIter: int = ..., + stepSize: float = ..., + tol: float = ..., + solver: str = ..., + seed: Optional[int] = ..., + ) -> FMRegressor: ... + def setFactorSize(self, value: int) -> FMRegressor: ... + def setFitLinear(self, value: bool) -> FMRegressor: ... + def setMiniBatchFraction(self, value: float) -> FMRegressor: ... + def setInitStd(self, value: float) -> FMRegressor: ... + def setMaxIter(self, value: int) -> FMRegressor: ... + def setStepSize(self, value: float) -> FMRegressor: ... + def setTol(self, value: float) -> FMRegressor: ... + def setSolver(self, value: str) -> FMRegressor: ... + def setSeed(self, value: int) -> FMRegressor: ... + def setFitIntercept(self, value: bool) -> FMRegressor: ... + def setRegParam(self, value: float) -> FMRegressor: ... + +class FMRegressionModel( + _JavaRegressionModel, + _FactorizationMachinesParams, + JavaMLWritable, + JavaMLReadable[FMRegressionModel], +): + @property + def intercept(self) -> float: ... + @property + def linear(self) -> Vector: ... + @property + def factors(self) -> Matrix: ... diff --git a/python/pyspark/ml/stat.pyi b/python/pyspark/ml/stat.pyi new file mode 100644 index 0000000000000..83b0f7eacb8f0 --- /dev/null +++ b/python/pyspark/ml/stat.pyi @@ -0,0 +1,89 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional + +from pyspark.ml.linalg import Matrix, Vector +from pyspark.ml.wrapper import JavaWrapper +from pyspark.sql.column import Column +from pyspark.sql.dataframe import DataFrame + +from py4j.java_gateway import JavaObject # type: ignore[import] + +class ChiSquareTest: + @staticmethod + def test( + dataset: DataFrame, featuresCol: str, labelCol: str, flatten: bool = ... + ) -> DataFrame: ... + +class Correlation: + @staticmethod + def corr(dataset: DataFrame, column: str, method: str = ...) -> DataFrame: ... + +class KolmogorovSmirnovTest: + @staticmethod + def test( + dataset: DataFrame, sampleCol: str, distName: str, *params: float + ) -> DataFrame: ... + +class Summarizer: + @staticmethod + def mean(col: Column, weightCol: Optional[Column] = ...) -> Column: ... + @staticmethod + def sum(col: Column, weightCol: Optional[Column] = ...) -> Column: ... + @staticmethod + def variance(col: Column, weightCol: Optional[Column] = ...) -> Column: ... + @staticmethod + def std(col: Column, weightCol: Optional[Column] = ...) -> Column: ... + @staticmethod + def count(col: Column, weightCol: Optional[Column] = ...) -> Column: ... + @staticmethod + def numNonZeros(col: Column, weightCol: Optional[Column] = ...) -> Column: ... + @staticmethod + def max(col: Column, weightCol: Optional[Column] = ...) -> Column: ... + @staticmethod + def min(col: Column, weightCol: Optional[Column] = ...) -> Column: ... + @staticmethod + def normL1(col: Column, weightCol: Optional[Column] = ...) -> Column: ... + @staticmethod + def normL2(col: Column, weightCol: Optional[Column] = ...) -> Column: ... + @staticmethod + def metrics(*metrics: str) -> SummaryBuilder: ... + +class SummaryBuilder(JavaWrapper): + def __init__(self, jSummaryBuilder: JavaObject) -> None: ... + def summary( + self, featuresCol: Column, weightCol: Optional[Column] = ... + ) -> Column: ... + +class MultivariateGaussian: + mean: Vector + cov: Matrix + def __init__(self, mean: Vector, cov: Matrix) -> None: ... + +class ANOVATest: + @staticmethod + def test( + dataset: DataFrame, featuresCol: str, labelCol: str, flatten: bool = ... + ) -> DataFrame: ... + +class FValueTest: + @staticmethod + def test( + dataset: DataFrame, featuresCol: str, labelCol: str, flatten: bool = ... + ) -> DataFrame: ... 
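All of the `pyspark.ml.stat` helpers above are annotated to return `Column` or `DataFrame`, so they compose with the SQL stubs under mypy. A short sketch (toy vectors; the `features` column name is illustrative):

```
from pyspark.ml.linalg import Vectors
from pyspark.ml.stat import Correlation, Summarizer
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(Vectors.dense(1.0, 2.0),), (Vectors.dense(3.0, 4.0),)],
    ["features"],
)

# metrics(...).summary(...) and mean(...) are annotated as Column, so they can
# be passed straight to DataFrame.select under the type checker.
stats = df.select(Summarizer.metrics("mean", "count").summary(df.features))
means = df.select(Summarizer.mean(df.features))

# Correlation.corr is annotated to return a DataFrame.
corr_df = Correlation.corr(df, "features")
```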
diff --git a/python/pyspark/ml/tests/test_algorithms.py b/python/pyspark/ml/tests/test_algorithms.py index 492e849658f7a..03653c25b4ad4 100644 --- a/python/pyspark/ml/tests/test_algorithms.py +++ b/python/pyspark/ml/tests/test_algorithms.py @@ -333,7 +333,7 @@ def test_linear_regression_with_huber_loss(self): from pyspark.ml.tests.test_algorithms import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/ml/tests/test_base.py b/python/pyspark/ml/tests/test_base.py index cba5369ca2623..d2c0bdfdf8556 100644 --- a/python/pyspark/ml/tests/test_base.py +++ b/python/pyspark/ml/tests/test_base.py @@ -70,7 +70,7 @@ def testDefaultFitMultiple(self): from pyspark.ml.tests.test_base import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/ml/tests/test_evaluation.py b/python/pyspark/ml/tests/test_evaluation.py index 7883df7882769..746605076f86b 100644 --- a/python/pyspark/ml/tests/test_evaluation.py +++ b/python/pyspark/ml/tests/test_evaluation.py @@ -56,7 +56,7 @@ def test_clustering_evaluator_with_cosine_distance(self): from pyspark.ml.tests.test_evaluation import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py index 7fd8c0b669d9a..244110a986138 100644 --- a/python/pyspark/ml/tests/test_feature.py +++ b/python/pyspark/ml/tests/test_feature.py @@ -303,7 +303,7 @@ def test_apply_binary_term_freqs(self): from pyspark.ml.tests.test_feature import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/ml/tests/test_image.py b/python/pyspark/ml/tests/test_image.py index 069ffceb50103..ceecdae971c99 100644 --- a/python/pyspark/ml/tests/test_image.py +++ b/python/pyspark/ml/tests/test_image.py @@ -69,7 +69,7 @@ def test_read_images(self): from pyspark.ml.tests.test_image import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/ml/tests/test_linalg.py b/python/pyspark/ml/tests/test_linalg.py index 60dda82fe0911..18c01ddf88e67 100644 --- a/python/pyspark/ml/tests/test_linalg.py +++ b/python/pyspark/ml/tests/test_linalg.py @@ -381,7 +381,7 @@ def test_infer_schema(self): from pyspark.ml.tests.test_linalg import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/ml/tests/test_param.py b/python/pyspark/ml/tests/test_param.py index abee6d1be5e29..4cddf50f36bdf 100644 --- a/python/pyspark/ml/tests/test_param.py +++ b/python/pyspark/ml/tests/test_param.py @@ -372,7 +372,7 @@ def test_java_params(self): from pyspark.ml.tests.test_param import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: 
ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/ml/tests/test_persistence.py b/python/pyspark/ml/tests/test_persistence.py index 4acf58da21531..826e6cd351d32 100644 --- a/python/pyspark/ml/tests/test_persistence.py +++ b/python/pyspark/ml/tests/test_persistence.py @@ -456,7 +456,7 @@ def test_default_read_write_default_params(self): from pyspark.ml.tests.test_persistence import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/ml/tests/test_pipeline.py b/python/pyspark/ml/tests/test_pipeline.py index 011e6537a8db5..c29b2d3f44679 100644 --- a/python/pyspark/ml/tests/test_pipeline.py +++ b/python/pyspark/ml/tests/test_pipeline.py @@ -62,7 +62,7 @@ def doTransform(pipeline): from pyspark.ml.tests.test_pipeline import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/ml/tests/test_stat.py b/python/pyspark/ml/tests/test_stat.py index 666d0aec58db5..a2403b38873db 100644 --- a/python/pyspark/ml/tests/test_stat.py +++ b/python/pyspark/ml/tests/test_stat.py @@ -43,7 +43,7 @@ def test_chisquaretest(self): from pyspark.ml.tests.test_stat import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/ml/tests/test_training_summary.py b/python/pyspark/ml/tests/test_training_summary.py index cb0effbe2bf2a..7dafdcb3d683b 100644 --- a/python/pyspark/ml/tests/test_training_summary.py +++ b/python/pyspark/ml/tests/test_training_summary.py @@ -445,7 +445,7 @@ def test_kmeans_summary(self): from pyspark.ml.tests.test_training_summary import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py index c9163627fdd54..729e46419ae2c 100644 --- a/python/pyspark/ml/tests/test_tuning.py +++ b/python/pyspark/ml/tests/test_tuning.py @@ -864,7 +864,7 @@ def test_copy(self): from pyspark.ml.tests.test_tuning import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/ml/tests/test_wrapper.py b/python/pyspark/ml/tests/test_wrapper.py index e6eef8a7de97a..31475299c7b98 100644 --- a/python/pyspark/ml/tests/test_wrapper.py +++ b/python/pyspark/ml/tests/test_wrapper.py @@ -21,7 +21,9 @@ from pyspark.ml.linalg import DenseVector, Vectors from pyspark.ml.regression import LinearRegression -from pyspark.ml.wrapper import _java2py, _py2java, JavaParams, JavaWrapper +from pyspark.ml.wrapper import ( # type: ignore[attr-defined] + _java2py, _py2java, JavaParams, JavaWrapper +) from pyspark.testing.mllibutils import MLlibTestCase from pyspark.testing.mlutils import SparkSessionTestCase from pyspark.testing.utils import eventually @@ -120,7 +122,7 @@ def test_new_java_array(self): from 
pyspark.ml.tests.test_wrapper import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/ml/tree.pyi b/python/pyspark/ml/tree.pyi new file mode 100644 index 0000000000000..ff6307654c569 --- /dev/null +++ b/python/pyspark/ml/tree.pyi @@ -0,0 +1,112 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import List, Sequence +from pyspark.ml._typing import P, T + +from pyspark.ml.linalg import Vector +from pyspark import since as since # noqa: F401 +from pyspark.ml.common import inherit_doc as inherit_doc # noqa: F401 +from pyspark.ml.param import Param, Params as Params +from pyspark.ml.param.shared import ( # noqa: F401 + HasCheckpointInterval as HasCheckpointInterval, + HasMaxIter as HasMaxIter, + HasSeed as HasSeed, + HasStepSize as HasStepSize, + HasValidationIndicatorCol as HasValidationIndicatorCol, + HasWeightCol as HasWeightCol, + Param as Param, + TypeConverters as TypeConverters, +) +from pyspark.ml.wrapper import JavaPredictionModel as JavaPredictionModel + +class _DecisionTreeModel(JavaPredictionModel[T]): + @property + def numNodes(self) -> int: ... + @property + def depth(self) -> int: ... + @property + def toDebugString(self) -> str: ... + def predictLeaf(self, value: Vector) -> float: ... + +class _DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol): + leafCol: Param[str] + maxDepth: Param[int] + maxBins: Param[int] + minInstancesPerNode: Param[int] + minWeightFractionPerNode: Param[float] + minInfoGain: Param[float] + maxMemoryInMB: Param[int] + cacheNodeIds: Param[bool] + def __init__(self) -> None: ... + def setLeafCol(self: P, value: str) -> P: ... + def getLeafCol(self) -> str: ... + def getMaxDepth(self) -> int: ... + def getMaxBins(self) -> int: ... + def getMinInstancesPerNode(self) -> int: ... + def getMinInfoGain(self) -> float: ... + def getMaxMemoryInMB(self) -> int: ... + def getCacheNodeIds(self) -> bool: ... + +class _TreeEnsembleModel(JavaPredictionModel[T]): + @property + def trees(self) -> Sequence[_DecisionTreeModel]: ... + @property + def getNumTrees(self) -> int: ... + @property + def treeWeights(self) -> List[float]: ... + @property + def totalNumNodes(self) -> int: ... + @property + def toDebugString(self) -> str: ... + +class _TreeEnsembleParams(_DecisionTreeParams): + subsamplingRate: Param[float] + supportedFeatureSubsetStrategies: List[str] + featureSubsetStrategy: Param[str] + def __init__(self) -> None: ... + def getSubsamplingRate(self) -> float: ... + def getFeatureSubsetStrategy(self) -> str: ... 
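`setLeafCol(self: P, value: str) -> P` above binds the return type to the caller's class, so chained setters on subclasses do not widen to the mixin. A standalone illustration of the same typing idiom (not PySpark code; the class names are made up):

```
from typing import TypeVar

P = TypeVar("P", bound="HasLeafCol")

class HasLeafCol:
    leaf_col: str = ""

    # Annotating self with the bound TypeVar makes mypy report the caller's
    # concrete class, so chained calls on subclasses stay precisely typed.
    def setLeafCol(self: P, value: str) -> P:
        self.leaf_col = value
        return self

class MyRegressor(HasLeafCol):
    pass

reg: MyRegressor = MyRegressor().setLeafCol("leaf")  # inferred as MyRegressor
```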
+ +class _RandomForestParams(_TreeEnsembleParams): + numTrees: Param[int] + bootstrap: Param[bool] + def __init__(self) -> None: ... + def getNumTrees(self) -> int: ... + def getBootstrap(self) -> bool: ... + +class _GBTParams( + _TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndicatorCol +): + stepSize: Param[float] + validationTol: Param[float] + def getValidationTol(self) -> float: ... + +class _HasVarianceImpurity(Params): + supportedImpurities: List[str] + impurity: Param[str] + def __init__(self) -> None: ... + def getImpurity(self) -> str: ... + +class _TreeClassifierParams(Params): + supportedImpurities: List[str] + impurity: Param[str] + def __init__(self) -> None: ... + def getImpurity(self) -> str: ... + +class _TreeRegressorParams(_HasVarianceImpurity): ... diff --git a/python/pyspark/ml/tuning.pyi b/python/pyspark/ml/tuning.pyi new file mode 100644 index 0000000000000..63cd75f0e1d74 --- /dev/null +++ b/python/pyspark/ml/tuning.pyi @@ -0,0 +1,185 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import overload +from typing import Any, List, Optional, Tuple, Type +from pyspark.ml._typing import ParamMap + +from pyspark.ml import Estimator, Model +from pyspark.ml.evaluation import Evaluator +from pyspark.ml.param import Param +from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed +from pyspark.ml.util import MLReader, MLReadable, MLWriter, MLWritable + +class ParamGridBuilder: + def __init__(self) -> None: ... + def addGrid(self, param: Param, values: List[Any]) -> ParamGridBuilder: ... + @overload + def baseOn(self, __args: ParamMap) -> ParamGridBuilder: ... + @overload + def baseOn(self, *args: Tuple[Param, Any]) -> ParamGridBuilder: ... + def build(self) -> List[ParamMap]: ... + +class _ValidatorParams(HasSeed): + estimator: Param[Estimator] + estimatorParamMaps: Param[List[ParamMap]] + evaluator: Param[Evaluator] + def getEstimator(self) -> Estimator: ... + def getEstimatorParamMaps(self) -> List[ParamMap]: ... + def getEvaluator(self) -> Evaluator: ... + +class _CrossValidatorParams(_ValidatorParams): + numFolds: Param[int] + foldCol: Param[str] + def __init__(self, *args: Any): ... + def getNumFolds(self) -> int: ... + def getFoldCol(self) -> str: ... + +class CrossValidator( + Estimator[CrossValidatorModel], + _CrossValidatorParams, + HasParallelism, + HasCollectSubModels, + MLReadable[CrossValidator], + MLWritable, +): + def __init__( + self, + *, + estimator: Optional[Estimator] = ..., + estimatorParamMaps: Optional[List[ParamMap]] = ..., + evaluator: Optional[Evaluator] = ..., + numFolds: int = ..., + seed: Optional[int] = ..., + parallelism: int = ..., + collectSubModels: bool = ..., + foldCol: str = ... + ) -> None: ... 
+ def setParams( + self, + *, + estimator: Optional[Estimator] = ..., + estimatorParamMaps: Optional[List[ParamMap]] = ..., + evaluator: Optional[Evaluator] = ..., + numFolds: int = ..., + seed: Optional[int] = ..., + parallelism: int = ..., + collectSubModels: bool = ..., + foldCol: str = ... + ) -> CrossValidator: ... + def setEstimator(self, value: Estimator) -> CrossValidator: ... + def setEstimatorParamMaps(self, value: List[ParamMap]) -> CrossValidator: ... + def setEvaluator(self, value: Evaluator) -> CrossValidator: ... + def setNumFolds(self, value: int) -> CrossValidator: ... + def setFoldCol(self, value: str) -> CrossValidator: ... + def setSeed(self, value: int) -> CrossValidator: ... + def setParallelism(self, value: int) -> CrossValidator: ... + def setCollectSubModels(self, value: bool) -> CrossValidator: ... + def copy(self, extra: Optional[ParamMap] = ...) -> CrossValidator: ... + def write(self) -> MLWriter: ... + @classmethod + def read(cls: Type[CrossValidator]) -> MLReader: ... + +class CrossValidatorModel( + Model, _CrossValidatorParams, MLReadable[CrossValidatorModel], MLWritable +): + bestModel: Model + avgMetrics: List[float] + subModels: List[List[Model]] + def __init__( + self, + bestModel: Model, + avgMetrics: List[float] = ..., + subModels: Optional[List[List[Model]]] = ..., + ) -> None: ... + def copy(self, extra: Optional[ParamMap] = ...) -> CrossValidatorModel: ... + def write(self) -> MLWriter: ... + @classmethod + def read(cls: Type[CrossValidatorModel]) -> MLReader: ... + +class _TrainValidationSplitParams(_ValidatorParams): + trainRatio: Param[float] + def __init__(self, *args: Any): ... + def getTrainRatio(self) -> float: ... + +class TrainValidationSplit( + Estimator[TrainValidationSplitModel], + _TrainValidationSplitParams, + HasParallelism, + HasCollectSubModels, + MLReadable[TrainValidationSplit], + MLWritable, +): + def __init__( + self, + *, + estimator: Optional[Estimator] = ..., + estimatorParamMaps: Optional[List[ParamMap]] = ..., + evaluator: Optional[Evaluator] = ..., + trainRatio: float = ..., + parallelism: int = ..., + collectSubModels: bool = ..., + seed: Optional[int] = ... + ) -> None: ... + def setParams( + self, + *, + estimator: Optional[Estimator] = ..., + estimatorParamMaps: Optional[List[ParamMap]] = ..., + evaluator: Optional[Evaluator] = ..., + trainRatio: float = ..., + parallelism: int = ..., + collectSubModels: bool = ..., + seed: Optional[int] = ... + ) -> TrainValidationSplit: ... + def setEstimator(self, value: Estimator) -> TrainValidationSplit: ... + def setEstimatorParamMaps(self, value: List[ParamMap]) -> TrainValidationSplit: ... + def setEvaluator(self, value: Evaluator) -> TrainValidationSplit: ... + def setTrainRatio(self, value: float) -> TrainValidationSplit: ... + def setSeed(self, value: int) -> TrainValidationSplit: ... + def setParallelism(self, value: int) -> TrainValidationSplit: ... + def setCollectSubModels(self, value: bool) -> TrainValidationSplit: ... + def copy(self, extra: Optional[ParamMap] = ...) -> TrainValidationSplit: ... + def write(self) -> MLWriter: ... + @classmethod + def read(cls: Type[TrainValidationSplit]) -> MLReader: ... + +class TrainValidationSplitModel( + Model, + _TrainValidationSplitParams, + MLReadable[TrainValidationSplitModel], + MLWritable, +): + bestModel: Model + validationMetrics: List[float] + subModels: List[Model] + def __init__( + self, + bestModel: Model, + validationMetrics: List[float] = ..., + subModels: Optional[List[Model]] = ..., + ) -> None: ... 
+ def setEstimator(self, value: Estimator) -> TrainValidationSplitModel: ... + def setEstimatorParamMaps( + self, value: List[ParamMap] + ) -> TrainValidationSplitModel: ... + def setEvaluator(self, value: Evaluator) -> TrainValidationSplitModel: ... + def copy(self, extra: Optional[ParamMap] = ...) -> TrainValidationSplitModel: ... + def write(self) -> MLWriter: ... + @classmethod + def read(cls: Type[TrainValidationSplitModel]) -> MLReader: ... diff --git a/python/pyspark/ml/util.pyi b/python/pyspark/ml/util.pyi new file mode 100644 index 0000000000000..d0781b2e26ed5 --- /dev/null +++ b/python/pyspark/ml/util.pyi @@ -0,0 +1,128 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, Generic, Optional, Type, TypeVar, Union + +from pyspark import SparkContext as SparkContext, since as since # noqa: F401 +from pyspark.ml.common import inherit_doc as inherit_doc # noqa: F401 +from pyspark.sql import SparkSession as SparkSession +from pyspark.util import VersionUtils as VersionUtils # noqa: F401 + +S = TypeVar("S") +R = TypeVar("R", bound=MLReadable) + +class Identifiable: + uid: str + def __init__(self) -> None: ... + +class BaseReadWrite: + def __init__(self) -> None: ... + def session(self, sparkSession: SparkSession) -> Union[MLWriter, MLReader]: ... + @property + def sparkSession(self) -> SparkSession: ... + @property + def sc(self) -> SparkContext: ... + +class MLWriter(BaseReadWrite): + shouldOverwrite: bool = ... + def __init__(self) -> None: ... + def save(self, path: str) -> None: ... + def saveImpl(self, path: str) -> None: ... + def overwrite(self) -> MLWriter: ... + +class GeneralMLWriter(MLWriter): + source: str + def format(self, source: str) -> MLWriter: ... + +class JavaMLWriter(MLWriter): + def __init__(self, instance: JavaMLWritable) -> None: ... + def save(self, path: str) -> None: ... + def overwrite(self) -> JavaMLWriter: ... + def option(self, key: str, value: Any) -> JavaMLWriter: ... + def session(self, sparkSession: SparkSession) -> JavaMLWriter: ... + +class GeneralJavaMLWriter(JavaMLWriter): + def __init__(self, instance: MLWritable) -> None: ... + def format(self, source: str) -> GeneralJavaMLWriter: ... + +class MLWritable: + def write(self) -> MLWriter: ... + def save(self, path: str) -> None: ... + +class JavaMLWritable(MLWritable): + def write(self) -> JavaMLWriter: ... + +class GeneralJavaMLWritable(JavaMLWritable): + def write(self) -> GeneralJavaMLWriter: ... + +class MLReader(BaseReadWrite, Generic[R]): + def load(self, path: str) -> R: ... + +class JavaMLReader(MLReader[R]): + def __init__(self, clazz: Type[JavaMLReadable]) -> None: ... + def load(self, path: str) -> R: ... + def session(self, sparkSession: SparkSession) -> JavaMLReader[R]: ... 
+ +class MLReadable(Generic[R]): + @classmethod + def read(cls: Type[R]) -> MLReader[R]: ... + @classmethod + def load(cls: Type[R], path: str) -> R: ... + +class JavaMLReadable(MLReadable[R]): + @classmethod + def read(cls: Type[R]) -> JavaMLReader[R]: ... + +class DefaultParamsWritable(MLWritable): + def write(self) -> MLWriter: ... + +class DefaultParamsWriter(MLWriter): + instance: DefaultParamsWritable + def __init__(self, instance: DefaultParamsWritable) -> None: ... + def saveImpl(self, path: str) -> None: ... + @staticmethod + def saveMetadata( + instance: DefaultParamsWritable, + path: str, + sc: SparkContext, + extraMetadata: Optional[Dict[str, Any]] = ..., + paramMap: Optional[Dict[str, Any]] = ..., + ) -> None: ... + +class DefaultParamsReadable(MLReadable[R]): + @classmethod + def read(cls: Type[R]) -> MLReader[R]: ... + +class DefaultParamsReader(MLReader[R]): + cls: Type[R] + def __init__(self, cls: Type[MLReadable]) -> None: ... + def load(self, path: str) -> R: ... + @staticmethod + def loadMetadata( + path: str, sc: SparkContext, expectedClassName: str = ... + ) -> Dict[str, Any]: ... + @staticmethod + def getAndSetParams(instance: R, metadata: Dict[str, Any]) -> None: ... + @staticmethod + def loadParamsInstance(path: str, sc: SparkContext) -> R: ... + +class HasTrainingSummary(Generic[S]): + @property + def hasSummary(self) -> bool: ... + @property + def summary(self) -> S: ... diff --git a/python/pyspark/ml/wrapper.pyi b/python/pyspark/ml/wrapper.pyi new file mode 100644 index 0000000000000..830224c177d1e --- /dev/null +++ b/python/pyspark/ml/wrapper.pyi @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import abc +from typing import Any, Optional +from pyspark.ml._typing import P, T, JM, ParamMap + +from pyspark.ml import Estimator, Predictor, PredictionModel, Transformer, Model +from pyspark.ml.base import _PredictorParams +from pyspark.ml.param import Param, Params + +class JavaWrapper: + def __init__(self, java_obj: Optional[Any] = ...) -> None: ... + def __del__(self) -> None: ... + +class JavaParams(JavaWrapper, Params, metaclass=abc.ABCMeta): + def copy(self: P, extra: Optional[ParamMap] = ...) -> P: ... + def clear(self, param: Param) -> None: ... + +class JavaEstimator(JavaParams, Estimator[JM], metaclass=abc.ABCMeta): ... +class JavaTransformer(JavaParams, Transformer, metaclass=abc.ABCMeta): ... + +class JavaModel(JavaTransformer, Model, metaclass=abc.ABCMeta): + def __init__(self, java_model: Optional[Any] = ...) -> None: ... + +class JavaPredictor( + Predictor[JM], JavaEstimator, _PredictorParams, metaclass=abc.ABCMeta +): ... + +class JavaPredictionModel(PredictionModel[T], JavaModel, _PredictorParams): + @property + def numFeatures(self) -> int: ... + def predict(self, value: T) -> float: ... 
diff --git a/python/pyspark/mllib/__init__.pyi b/python/pyspark/mllib/__init__.pyi new file mode 100644 index 0000000000000..83032c4580fc8 --- /dev/null +++ b/python/pyspark/mllib/__init__.pyi @@ -0,0 +1,32 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE: This dynamically typed stub was automatically generated by stubgen. + +# Names in __all__ with no definition: +# classification +# clustering +# feature +# fpm +# linalg +# random +# recommendation +# regression +# stat +# tree +# util diff --git a/python/pyspark/mllib/_typing.pyi b/python/pyspark/mllib/_typing.pyi new file mode 100644 index 0000000000000..213a69996b0ad --- /dev/null +++ b/python/pyspark/mllib/_typing.pyi @@ -0,0 +1,23 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import List, Tuple, Union +from pyspark.mllib.linalg import Vector +from numpy import ndarray # noqa: F401 + +VectorLike = Union[Vector, List[float], Tuple[float, ...]] diff --git a/python/pyspark/mllib/classification.pyi b/python/pyspark/mllib/classification.pyi new file mode 100644 index 0000000000000..c51882c87bfc2 --- /dev/null +++ b/python/pyspark/mllib/classification.pyi @@ -0,0 +1,151 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import overload +from typing import Optional, Union + +from pyspark.context import SparkContext +from pyspark.rdd import RDD +from pyspark.mllib._typing import VectorLike +from pyspark.mllib.linalg import Vector +from pyspark.mllib.regression import LabeledPoint, LinearModel, StreamingLinearAlgorithm +from pyspark.mllib.util import Saveable, Loader +from pyspark.streaming.dstream import DStream + +from numpy import float64, ndarray # type: ignore[import] + +class LinearClassificationModel(LinearModel): + def __init__(self, weights: Vector, intercept: float) -> None: ... + def setThreshold(self, value: float) -> None: ... + @property + def threshold(self) -> Optional[float]: ... + def clearThreshold(self) -> None: ... + @overload + def predict(self, test: VectorLike) -> Union[int, float, float64]: ... + @overload + def predict(self, test: RDD[VectorLike]) -> RDD[Union[int, float]]: ... + +class LogisticRegressionModel(LinearClassificationModel): + def __init__( + self, weights: Vector, intercept: float, numFeatures: int, numClasses: int + ) -> None: ... + @property + def numFeatures(self) -> int: ... + @property + def numClasses(self) -> int: ... + @overload + def predict(self, x: VectorLike) -> Union[int, float]: ... + @overload + def predict(self, x: RDD[VectorLike]) -> RDD[Union[int, float]]: ... + def save(self, sc: SparkContext, path: str) -> None: ... + @classmethod + def load(cls, sc: SparkContext, path: str) -> LogisticRegressionModel: ... + +class LogisticRegressionWithSGD: + @classmethod + def train( + cls, + data: RDD[LabeledPoint], + iterations: int = ..., + step: float = ..., + miniBatchFraction: float = ..., + initialWeights: Optional[VectorLike] = ..., + regParam: float = ..., + regType: str = ..., + intercept: bool = ..., + validateData: bool = ..., + convergenceTol: float = ..., + ) -> LogisticRegressionModel: ... + +class LogisticRegressionWithLBFGS: + @classmethod + def train( + cls, + data: RDD[LabeledPoint], + iterations: int = ..., + initialWeights: Optional[VectorLike] = ..., + regParam: float = ..., + regType: str = ..., + intercept: bool = ..., + corrections: int = ..., + tolerance: float = ..., + validateData: bool = ..., + numClasses: int = ..., + ) -> LogisticRegressionModel: ... + +class SVMModel(LinearClassificationModel): + def __init__(self, weights: Vector, intercept: float) -> None: ... + @overload + def predict(self, x: VectorLike) -> float64: ... + @overload + def predict(self, x: RDD[VectorLike]) -> RDD[float64]: ... + def save(self, sc: SparkContext, path: str) -> None: ... + @classmethod + def load(cls, sc: SparkContext, path: str) -> SVMModel: ... + +class SVMWithSGD: + @classmethod + def train( + cls, + data: RDD[LabeledPoint], + iterations: int = ..., + step: float = ..., + regParam: float = ..., + miniBatchFraction: float = ..., + initialWeights: Optional[VectorLike] = ..., + regType: str = ..., + intercept: bool = ..., + validateData: bool = ..., + convergenceTol: float = ..., + ) -> SVMModel: ... + +class NaiveBayesModel(Saveable, Loader[NaiveBayesModel]): + labels: ndarray + pi: ndarray + theta: ndarray + def __init__(self, labels, pi, theta) -> None: ... + @overload + def predict(self, x: VectorLike) -> float64: ... + @overload + def predict(self, x: RDD[VectorLike]) -> RDD[float64]: ... + def save(self, sc: SparkContext, path: str) -> None: ... + @classmethod + def load(cls, sc: SparkContext, path: str) -> NaiveBayesModel: ... + +class NaiveBayes: + @classmethod + def train(cls, data: RDD[VectorLike], lambda_: float = ...) 
-> NaiveBayesModel: ... + +class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm): + stepSize: float + numIterations: int + regParam: float + miniBatchFraction: float + convergenceTol: float + def __init__( + self, + stepSize: float = ..., + numIterations: int = ..., + miniBatchFraction: float = ..., + regParam: float = ..., + convergenceTol: float = ..., + ) -> None: ... + def setInitialWeights( + self, initialWeights: VectorLike + ) -> StreamingLogisticRegressionWithSGD: ... + def trainOn(self, dstream: DStream[LabeledPoint]) -> None: ... diff --git a/python/pyspark/mllib/clustering.pyi b/python/pyspark/mllib/clustering.pyi new file mode 100644 index 0000000000000..1c3eba17e201c --- /dev/null +++ b/python/pyspark/mllib/clustering.pyi @@ -0,0 +1,196 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import overload +from typing import List, NamedTuple, Optional, Tuple, TypeVar + +import array + +from numpy import float64, int64, ndarray # type: ignore[import] +from py4j.java_gateway import JavaObject # type: ignore[import] + +from pyspark.mllib._typing import VectorLike +from pyspark.context import SparkContext +from pyspark.rdd import RDD +from pyspark.mllib.common import JavaModelWrapper +from pyspark.mllib.stat.distribution import MultivariateGaussian +from pyspark.mllib.util import Saveable, Loader, JavaLoader, JavaSaveable +from pyspark.streaming.dstream import DStream + +T = TypeVar("T") + +class BisectingKMeansModel(JavaModelWrapper): + centers: List[ndarray] + def __init__(self, java_model: JavaObject) -> None: ... + @property + def clusterCenters(self) -> List[ndarray]: ... + @property + def k(self) -> int: ... + @overload + def predict(self, x: VectorLike) -> int: ... + @overload + def predict(self, x: RDD[VectorLike]) -> RDD[int]: ... + @overload + def computeCost(self, x: VectorLike) -> float: ... + @overload + def computeCost(self, x: RDD[VectorLike]) -> float: ... + +class BisectingKMeans: + @classmethod + def train( + self, + rdd: RDD[VectorLike], + k: int = ..., + maxIterations: int = ..., + minDivisibleClusterSize: float = ..., + seed: int = ..., + ) -> BisectingKMeansModel: ... + +class KMeansModel(Saveable, Loader[KMeansModel]): + centers: List[ndarray] + def __init__(self, centers: List[ndarray]) -> None: ... + @property + def clusterCenters(self) -> List[ndarray]: ... + @property + def k(self) -> int: ... + @overload + def predict(self, x: VectorLike) -> int: ... + @overload + def predict(self, x: RDD[VectorLike]) -> RDD[int]: ... + def computeCost(self, rdd: RDD[VectorLike]) -> float: ... + def save(self, sc: SparkContext, path: str) -> None: ... + @classmethod + def load(cls, sc: SparkContext, path: str) -> KMeansModel: ... 
+ +class KMeans: + @classmethod + def train( + cls, + rdd: RDD[VectorLike], + k: int, + maxIterations: int = ..., + initializationMode: str = ..., + seed: Optional[int] = ..., + initializationSteps: int = ..., + epsilon: float = ..., + initialModel: Optional[KMeansModel] = ..., + ) -> KMeansModel: ... + +class GaussianMixtureModel( + JavaModelWrapper, JavaSaveable, JavaLoader[GaussianMixtureModel] +): + @property + def weights(self) -> ndarray: ... + @property + def gaussians(self) -> List[MultivariateGaussian]: ... + @property + def k(self) -> int: ... + @overload + def predict(self, x: VectorLike) -> int64: ... + @overload + def predict(self, x: RDD[VectorLike]) -> RDD[int]: ... + @overload + def predictSoft(self, x: VectorLike) -> ndarray: ... + @overload + def predictSoft(self, x: RDD[VectorLike]) -> RDD[array.array]: ... + @classmethod + def load(cls, sc: SparkContext, path: str) -> GaussianMixtureModel: ... + +class GaussianMixture: + @classmethod + def train( + cls, + rdd: RDD[VectorLike], + k: int, + convergenceTol: float = ..., + maxIterations: int = ..., + seed: Optional[int] = ..., + initialModel: Optional[GaussianMixtureModel] = ..., + ) -> GaussianMixtureModel: ... + +class PowerIterationClusteringModel( + JavaModelWrapper, JavaSaveable, JavaLoader[PowerIterationClusteringModel] +): + @property + def k(self) -> int: ... + def assignments(self) -> RDD[PowerIterationClustering.Assignment]: ... + @classmethod + def load(cls, sc: SparkContext, path: str) -> PowerIterationClusteringModel: ... + +class PowerIterationClustering: + @classmethod + def train( + cls, + rdd: RDD[Tuple[int, int, float]], + k: int, + maxIterations: int = ..., + initMode: str = ..., + ) -> PowerIterationClusteringModel: ... + class Assignment(NamedTuple("Assignment", [("id", int), ("cluster", int)])): ... + +class StreamingKMeansModel(KMeansModel): + def __init__(self, clusterCenters, clusterWeights) -> None: ... + @property + def clusterWeights(self) -> List[float64]: ... + centers: ndarray + def update( + self, data: RDD[VectorLike], decayFactor: float, timeUnit: str + ) -> StreamingKMeansModel: ... + +class StreamingKMeans: + def __init__( + self, k: int = ..., decayFactor: float = ..., timeUnit: str = ... + ) -> None: ... + def latestModel(self) -> StreamingKMeansModel: ... + def setK(self, k: int) -> StreamingKMeans: ... + def setDecayFactor(self, decayFactor: float) -> StreamingKMeans: ... + def setHalfLife(self, halfLife: float, timeUnit: str) -> StreamingKMeans: ... + def setInitialCenters( + self, centers: List[VectorLike], weights: List[float] + ) -> StreamingKMeans: ... + def setRandomCenters( + self, dim: int, weight: float, seed: int + ) -> StreamingKMeans: ... + def trainOn(self, dstream: DStream[VectorLike]) -> None: ... + def predictOn(self, dstream: DStream[VectorLike]) -> DStream[int]: ... + def predictOnValues( + self, dstream: DStream[Tuple[T, VectorLike]] + ) -> DStream[Tuple[T, int]]: ... + +class LDAModel(JavaModelWrapper, JavaSaveable, Loader[LDAModel]): + def topicsMatrix(self) -> ndarray: ... + def vocabSize(self) -> int: ... + def describeTopics( + self, maxTermsPerTopic: Optional[int] = ... + ) -> List[Tuple[List[int], List[float]]]: ... + @classmethod + def load(cls, sc: SparkContext, path: str) -> LDAModel: ... 
+ +class LDA: + @classmethod + def train( + cls, + rdd: RDD[Tuple[int, VectorLike]], + k: int = ..., + maxIterations: int = ..., + docConcentration: float = ..., + topicConcentration: float = ..., + seed: Optional[int] = ..., + checkpointInterval: int = ..., + optimizer: str = ..., + ) -> LDAModel: ... diff --git a/python/pyspark/mllib/common.pyi b/python/pyspark/mllib/common.pyi new file mode 100644 index 0000000000000..1df308b91b5a1 --- /dev/null +++ b/python/pyspark/mllib/common.pyi @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +def callJavaFunc(sc, func, *args): ... +def callMLlibFunc(name, *args): ... + +class JavaModelWrapper: + def __init__(self, java_model) -> None: ... + def __del__(self): ... + def call(self, name, *a): ... + +def inherit_doc(cls): ... diff --git a/python/pyspark/mllib/evaluation.pyi b/python/pyspark/mllib/evaluation.pyi new file mode 100644 index 0000000000000..03583784f0c3b --- /dev/null +++ b/python/pyspark/mllib/evaluation.pyi @@ -0,0 +1,94 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import List, Optional, Tuple, TypeVar +from pyspark.rdd import RDD +from pyspark.mllib.common import JavaModelWrapper +from pyspark.mllib.linalg import Matrix + +T = TypeVar("T") + +class BinaryClassificationMetrics(JavaModelWrapper): + def __init__(self, scoreAndLabels: RDD[Tuple[float, float]]) -> None: ... + @property + def areaUnderROC(self) -> float: ... + @property + def areaUnderPR(self) -> float: ... + def unpersist(self) -> None: ... + +class RegressionMetrics(JavaModelWrapper): + def __init__(self, predictionAndObservations: RDD[Tuple[float, float]]) -> None: ... + @property + def explainedVariance(self) -> float: ... + @property + def meanAbsoluteError(self) -> float: ... + @property + def meanSquaredError(self) -> float: ... + @property + def rootMeanSquaredError(self) -> float: ... + @property + def r2(self) -> float: ... + +class MulticlassMetrics(JavaModelWrapper): + def __init__(self, predictionAndLabels: RDD[Tuple[float, float]]) -> None: ... 
+ def confusionMatrix(self) -> Matrix: ... + def truePositiveRate(self, label: float) -> float: ... + def falsePositiveRate(self, label: float) -> float: ... + def precision(self, label: float = ...) -> float: ... + def recall(self, label: float = ...) -> float: ... + def fMeasure(self, label: float = ..., beta: Optional[float] = ...) -> float: ... + @property + def accuracy(self) -> float: ... + @property + def weightedTruePositiveRate(self) -> float: ... + @property + def weightedFalsePositiveRate(self) -> float: ... + @property + def weightedRecall(self) -> float: ... + @property + def weightedPrecision(self) -> float: ... + def weightedFMeasure(self, beta: Optional[float] = ...) -> float: ... + +class RankingMetrics(JavaModelWrapper): + def __init__(self, predictionAndLabels: RDD[Tuple[List[T], List[T]]]) -> None: ... + def precisionAt(self, k: int) -> float: ... + @property + def meanAveragePrecision(self) -> float: ... + def meanAveragePrecisionAt(self, k: int) -> float: ... + def ndcgAt(self, k: int) -> float: ... + def recallAt(self, k: int) -> float: ... + +class MultilabelMetrics(JavaModelWrapper): + def __init__( + self, predictionAndLabels: RDD[Tuple[List[float], List[float]]] + ) -> None: ... + def precision(self, label: Optional[float] = ...) -> float: ... + def recall(self, label: Optional[float] = ...) -> float: ... + def f1Measure(self, label: Optional[float] = ...) -> float: ... + @property + def microPrecision(self) -> float: ... + @property + def microRecall(self) -> float: ... + @property + def microF1Measure(self) -> float: ... + @property + def hammingLoss(self) -> float: ... + @property + def subsetAccuracy(self) -> float: ... + @property + def accuracy(self) -> float: ... diff --git a/python/pyspark/mllib/feature.pyi b/python/pyspark/mllib/feature.pyi new file mode 100644 index 0000000000000..9ccec36abd6ff --- /dev/null +++ b/python/pyspark/mllib/feature.pyi @@ -0,0 +1,167 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import overload +from typing import Iterable, Hashable, List, Tuple + +from pyspark.mllib._typing import VectorLike +from pyspark.context import SparkContext +from pyspark.rdd import RDD +from pyspark.mllib.common import JavaModelWrapper +from pyspark.mllib.linalg import Vector +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import JavaLoader, JavaSaveable + +from py4j.java_collections import JavaMap # type: ignore[import] + +class VectorTransformer: + @overload + def transform(self, vector: VectorLike) -> Vector: ... + @overload + def transform(self, vector: RDD[VectorLike]) -> RDD[Vector]: ... + +class Normalizer(VectorTransformer): + p: float + def __init__(self, p: float = ...) -> None: ... + @overload + def transform(self, vector: VectorLike) -> Vector: ... 
+ @overload + def transform(self, vector: RDD[VectorLike]) -> RDD[Vector]: ... + +class JavaVectorTransformer(JavaModelWrapper, VectorTransformer): + @overload + def transform(self, vector: VectorLike) -> Vector: ... + @overload + def transform(self, vector: RDD[VectorLike]) -> RDD[Vector]: ... + +class StandardScalerModel(JavaVectorTransformer): + @overload + def transform(self, vector: VectorLike) -> Vector: ... + @overload + def transform(self, vector: RDD[VectorLike]) -> RDD[Vector]: ... + def setWithMean(self, withMean: bool) -> StandardScalerModel: ... + def setWithStd(self, withStd: bool) -> StandardScalerModel: ... + @property + def withStd(self) -> bool: ... + @property + def withMean(self) -> bool: ... + @property + def std(self) -> Vector: ... + @property + def mean(self) -> Vector: ... + +class StandardScaler: + withMean: bool + withStd: bool + def __init__(self, withMean: bool = ..., withStd: bool = ...) -> None: ... + def fit(self, dataset: RDD[VectorLike]) -> StandardScalerModel: ... + +class ChiSqSelectorModel(JavaVectorTransformer): + @overload + def transform(self, vector: VectorLike) -> Vector: ... + @overload + def transform(self, vector: RDD[VectorLike]) -> RDD[Vector]: ... + +class ChiSqSelector: + numTopFeatures: int + selectorType: str + percentile: float + fpr: float + fdr: float + fwe: float + def __init__( + self, + numTopFeatures: int = ..., + selectorType: str = ..., + percentile: float = ..., + fpr: float = ..., + fdr: float = ..., + fwe: float = ..., + ) -> None: ... + def setNumTopFeatures(self, numTopFeatures: int) -> ChiSqSelector: ... + def setPercentile(self, percentile: float) -> ChiSqSelector: ... + def setFpr(self, fpr: float) -> ChiSqSelector: ... + def setFdr(self, fdr: float) -> ChiSqSelector: ... + def setFwe(self, fwe: float) -> ChiSqSelector: ... + def setSelectorType(self, selectorType: str) -> ChiSqSelector: ... + def fit(self, data: RDD[LabeledPoint]) -> ChiSqSelectorModel: ... + +class PCAModel(JavaVectorTransformer): ... + +class PCA: + k: int + def __init__(self, k: int) -> None: ... + def fit(self, data: RDD[VectorLike]) -> PCAModel: ... + +class HashingTF: + numFeatures: int + binary: bool + def __init__(self, numFeatures: int = ...) -> None: ... + def setBinary(self, value: bool) -> HashingTF: ... + def indexOf(self, term: Hashable) -> int: ... + @overload + def transform(self, document: Iterable[Hashable]) -> Vector: ... + @overload + def transform(self, document: RDD[Iterable[Hashable]]) -> RDD[Vector]: ... + +class IDFModel(JavaVectorTransformer): + @overload + def transform(self, x: VectorLike) -> Vector: ... + @overload + def transform(self, x: RDD[VectorLike]) -> RDD[Vector]: ... + def idf(self) -> Vector: ... + def docFreq(self) -> List[int]: ... + def numDocs(self) -> int: ... + +class IDF: + minDocFreq: int + def __init__(self, minDocFreq: int = ...) -> None: ... + def fit(self, dataset: RDD[VectorLike]) -> IDFModel: ... + +class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader[Word2VecModel]): + def transform(self, word: str) -> Vector: ... # type: ignore + def findSynonyms(self, word: str, num: int) -> Iterable[Tuple[str, float]]: ... + def getVectors(self) -> JavaMap: ... + @classmethod + def load(cls, sc: SparkContext, path: str) -> Word2VecModel: ... + +class Word2Vec: + vectorSize: int + learningRate: float + numPartitions: int + numIterations: int + seed: int + minCount: int + windowSize: int + def __init__(self) -> None: ... + def setVectorSize(self, vectorSize: int) -> Word2Vec: ... 
+ def setLearningRate(self, learningRate: float) -> Word2Vec: ... + def setNumPartitions(self, numPartitions: int) -> Word2Vec: ... + def setNumIterations(self, numIterations: int) -> Word2Vec: ... + def setSeed(self, seed: int) -> Word2Vec: ... + def setMinCount(self, minCount: int) -> Word2Vec: ... + def setWindowSize(self, windowSize: int) -> Word2Vec: ... + def fit(self, data: RDD[List[str]]) -> Word2VecModel: ... + +class ElementwiseProduct(VectorTransformer): + scalingVector: Vector + def __init__(self, scalingVector: Vector) -> None: ... + @overload + def transform(self, vector: VectorLike) -> Vector: ... + @overload + def transform(self, vector: RDD[VectorLike]) -> RDD[Vector]: ... diff --git a/python/pyspark/mllib/fpm.pyi b/python/pyspark/mllib/fpm.pyi new file mode 100644 index 0000000000000..880baae1a91a5 --- /dev/null +++ b/python/pyspark/mllib/fpm.pyi @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Generic, List, TypeVar +from pyspark.context import SparkContext +from pyspark.rdd import RDD +from pyspark.mllib.common import JavaModelWrapper +from pyspark.mllib.util import JavaSaveable, JavaLoader + +T = TypeVar("T") + +class FPGrowthModel( + JavaModelWrapper, JavaSaveable, JavaLoader[FPGrowthModel], Generic[T] +): + def freqItemsets(self) -> RDD[FPGrowth.FreqItemset[T]]: ... + @classmethod + def load(cls, sc: SparkContext, path: str) -> FPGrowthModel: ... + +class FPGrowth: + @classmethod + def train( + cls, data: RDD[List[T]], minSupport: float = ..., numPartitions: int = ... + ) -> FPGrowthModel[T]: ... + class FreqItemset(Generic[T]): + items = ... # List[T] + freq = ... # int + +class PrefixSpanModel(JavaModelWrapper, Generic[T]): + def freqSequences(self) -> RDD[PrefixSpan.FreqSequence[T]]: ... + +class PrefixSpan: + @classmethod + def train( + cls, + data: RDD[List[List[T]]], + minSupport: float = ..., + maxPatternLength: int = ..., + maxLocalProjDBSize: int = ..., + ) -> PrefixSpanModel[T]: ... + class FreqSequence(tuple, Generic[T]): + sequence: List[T] + freq: int diff --git a/python/pyspark/mllib/linalg/__init__.pyi b/python/pyspark/mllib/linalg/__init__.pyi new file mode 100644 index 0000000000000..c0719c535c8f4 --- /dev/null +++ b/python/pyspark/mllib/linalg/__init__.pyi @@ -0,0 +1,273 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import overload +from typing import Any, Dict, Generic, Iterable, List, Optional, Tuple, TypeVar, Union +from pyspark.ml import linalg as newlinalg +from pyspark.sql.types import StructType, UserDefinedType +from numpy import float64, ndarray # type: ignore[import] + +QT = TypeVar("QT") +RT = TypeVar("RT") + +class VectorUDT(UserDefinedType): + @classmethod + def sqlType(cls) -> StructType: ... + @classmethod + def module(cls) -> str: ... + @classmethod + def scalaUDT(cls) -> str: ... + def serialize( + self, obj: Vector + ) -> Tuple[int, Optional[int], Optional[List[int]], List[float]]: ... + def deserialize(self, datum: Any) -> Vector: ... + def simpleString(self) -> str: ... + +class MatrixUDT(UserDefinedType): + @classmethod + def sqlType(cls) -> StructType: ... + @classmethod + def module(cls) -> str: ... + @classmethod + def scalaUDT(cls) -> str: ... + def serialize( + self, obj + ) -> Tuple[ + int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool + ]: ... + def deserialize(self, datum: Any) -> Matrix: ... + def simpleString(self) -> str: ... + +class Vector: + __UDT__: VectorUDT + def toArray(self) -> ndarray: ... + def asML(self) -> newlinalg.Vector: ... + +class DenseVector(Vector): + array: ndarray + @overload + def __init__(self, *elements: float) -> None: ... + @overload + def __init__(self, __arr: bytes) -> None: ... + @overload + def __init__(self, __arr: Iterable[float]) -> None: ... + @staticmethod + def parse(s) -> DenseVector: ... + def __reduce__(self) -> Tuple[type, bytes]: ... + def numNonzeros(self) -> int: ... + def norm(self, p: Union[float, str]) -> float64: ... + def dot(self, other: Iterable[float]) -> float64: ... + def squared_distance(self, other: Iterable[float]) -> float64: ... + def toArray(self) -> ndarray: ... + def asML(self) -> newlinalg.DenseVector: ... + @property + def values(self) -> ndarray: ... + def __getitem__(self, item: int) -> float64: ... + def __len__(self) -> int: ... + def __eq__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... + def __hash__(self) -> int: ... + def __getattr__(self, item: str) -> Any: ... + def __neg__(self) -> DenseVector: ... + def __add__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... + def __sub__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... + def __mul__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... + def __div__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... + def __truediv__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... + def __mod__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... + def __radd__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... + def __rsub__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... + def __rmul__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... + def __rdiv__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... + def __rtruediv__(self, other: Union[float, Iterable[float]]) -> DenseVector: ... 
+    def __rmod__(self, other: Union[float, Iterable[float]]) -> DenseVector: ...
+
+class SparseVector(Vector):
+    size: int
+    indices: ndarray
+    values: ndarray
+    @overload
+    def __init__(self, size: int, *args: Tuple[int, float]) -> None: ...
+    @overload
+    def __init__(self, size: int, __indices: bytes, __values: bytes) -> None: ...
+    @overload
+    def __init__(
+        self, size: int, __indices: Iterable[int], __values: Iterable[float]
+    ) -> None: ...
+    @overload
+    def __init__(self, size: int, __pairs: Iterable[Tuple[int, float]]) -> None: ...
+    @overload
+    def __init__(self, size: int, __map: Dict[int, float]) -> None: ...
+    def numNonzeros(self) -> int: ...
+    def norm(self, p: Union[float, str]) -> float64: ...
+    def __reduce__(self): ...
+    @staticmethod
+    def parse(s: str) -> SparseVector: ...
+    def dot(self, other: Iterable[float]) -> float64: ...
+    def squared_distance(self, other: Iterable[float]) -> float64: ...
+    def toArray(self) -> ndarray: ...
+    def asML(self) -> newlinalg.SparseVector: ...
+    def __len__(self) -> int: ...
+    def __eq__(self, other) -> bool: ...
+    def __getitem__(self, index: int) -> float64: ...
+    def __ne__(self, other) -> bool: ...
+    def __hash__(self) -> int: ...
+
+class Vectors:
+    @overload
+    @staticmethod
+    def sparse(size: int, *args: Tuple[int, float]) -> SparseVector: ...
+    @overload
+    @staticmethod
+    def sparse(size: int, __indices: bytes, __values: bytes) -> SparseVector: ...
+    @overload
+    @staticmethod
+    def sparse(
+        size: int, __indices: Iterable[int], __values: Iterable[float]
+    ) -> SparseVector: ...
+    @overload
+    @staticmethod
+    def sparse(size: int, __pairs: Iterable[Tuple[int, float]]) -> SparseVector: ...
+    @overload
+    @staticmethod
+    def sparse(size: int, __map: Dict[int, float]) -> SparseVector: ...
+    @overload
+    @staticmethod
+    def dense(*elements: float) -> DenseVector: ...
+    @overload
+    @staticmethod
+    def dense(__arr: bytes) -> DenseVector: ...
+    @overload
+    @staticmethod
+    def dense(__arr: Iterable[float]) -> DenseVector: ...
+    @staticmethod
+    def fromML(vec: newlinalg.DenseVector) -> DenseVector: ...
+    @staticmethod
+    def stringify(vector: Vector) -> str: ...
+    @staticmethod
+    def squared_distance(v1: Vector, v2: Vector) -> float64: ...
+    @staticmethod
+    def norm(vector: Vector, p: Union[float, str]) -> float64: ...
+    @staticmethod
+    def parse(s: str) -> Vector: ...
+    @staticmethod
+    def zeros(size: int) -> DenseVector: ...
+
+class Matrix:
+    __UDT__: MatrixUDT
+    numRows: int
+    numCols: int
+    isTransposed: bool
+    def __init__(
+        self, numRows: int, numCols: int, isTransposed: bool = ...
+    ) -> None: ...
+    def toArray(self): ...
+    def asML(self): ...
+
+class DenseMatrix(Matrix):
+    values: Any
+    @overload
+    def __init__(
+        self, numRows: int, numCols: int, values: bytes, isTransposed: bool = ...
+    ) -> None: ...
+    @overload
+    def __init__(
+        self,
+        numRows: int,
+        numCols: int,
+        values: Iterable[float],
+        isTransposed: bool = ...,
+    ) -> None: ...
+    def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, int]]: ...
+    def toArray(self) -> ndarray: ...
+    def toSparse(self) -> SparseMatrix: ...
+    def asML(self) -> newlinalg.DenseMatrix: ...
+    def __getitem__(self, indices: Tuple[int, int]) -> float64: ...
+    def __eq__(self, other) -> bool: ...
+
+class SparseMatrix(Matrix):
+    colPtrs: ndarray
+    rowIndices: ndarray
+    values: ndarray
+    @overload
+    def __init__(
+        self,
+        numRows: int,
+        numCols: int,
+        colPtrs: bytes,
+        rowIndices: bytes,
+        values: bytes,
+        isTransposed: bool = ...,
+    ) -> None: ...
+ @overload + def __init__( + self, + numRows: int, + numCols: int, + colPtrs: Iterable[int], + rowIndices: Iterable[int], + values: Iterable[float], + isTransposed: bool = ..., + ) -> None: ... + def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, bytes, bytes, int]]: ... + def __getitem__(self, indices: Tuple[int, int]) -> float64: ... + def toArray(self) -> ndarray: ... + def toDense(self) -> DenseMatrix: ... + def asML(self) -> newlinalg.SparseMatrix: ... + def __eq__(self, other) -> bool: ... + +class Matrices: + @overload + @staticmethod + def dense( + numRows: int, numCols: int, values: bytes, isTransposed: bool = ... + ) -> DenseMatrix: ... + @overload + @staticmethod + def dense( + numRows: int, numCols: int, values: Iterable[float], isTransposed: bool = ... + ) -> DenseMatrix: ... + @overload + @staticmethod + def sparse( + numRows: int, + numCols: int, + colPtrs: bytes, + rowIndices: bytes, + values: bytes, + isTransposed: bool = ..., + ) -> SparseMatrix: ... + @overload + @staticmethod + def sparse( + numRows: int, + numCols: int, + colPtrs: Iterable[int], + rowIndices: Iterable[int], + values: Iterable[float], + isTransposed: bool = ..., + ) -> SparseMatrix: ... + @staticmethod + def fromML(mat: newlinalg.Matrix) -> Matrix: ... + +class QRDecomposition(Generic[QT, RT]): + def __init__(self, Q: QT, R: RT) -> None: ... + @property + def Q(self) -> QT: ... + @property + def R(self) -> RT: ... diff --git a/python/pyspark/mllib/linalg/distributed.pyi b/python/pyspark/mllib/linalg/distributed.pyi new file mode 100644 index 0000000000000..238c4ea32e4e8 --- /dev/null +++ b/python/pyspark/mllib/linalg/distributed.pyi @@ -0,0 +1,147 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Generic, Sequence, Optional, Tuple, TypeVar, Union +from pyspark.rdd import RDD +from pyspark.storagelevel import StorageLevel +from pyspark.mllib.common import JavaModelWrapper +from pyspark.mllib.linalg import Vector, Matrix, QRDecomposition +from pyspark.mllib.stat import MultivariateStatisticalSummary +from numpy import ndarray # noqa: F401 + +VectorLike = Union[Vector, Sequence[Union[float, int]]] + +UT = TypeVar("UT") +VT = TypeVar("VT") + +class DistributedMatrix: + def numRows(self) -> int: ... + def numCols(self) -> int: ... + +class RowMatrix(DistributedMatrix): + def __init__( + self, rows: RDD[Vector], numRows: int = ..., numCols: int = ... + ) -> None: ... + @property + def rows(self) -> RDD[Vector]: ... + def numRows(self) -> int: ... + def numCols(self) -> int: ... + def computeColumnSummaryStatistics(self) -> MultivariateStatisticalSummary: ... + def computeCovariance(self) -> Matrix: ... + def computeGramianMatrix(self) -> Matrix: ... + def columnSimilarities(self, threshold: float = ...) -> CoordinateMatrix: ... 
+ def tallSkinnyQR( + self, computeQ: bool = ... + ) -> QRDecomposition[RowMatrix, Matrix]: ... + def computeSVD( + self, k: int, computeU: bool = ..., rCond: float = ... + ) -> SingularValueDecomposition[RowMatrix, Matrix]: ... + def computePrincipalComponents(self, k: int) -> Matrix: ... + def multiply(self, matrix: Matrix) -> RowMatrix: ... + +class SingularValueDecomposition(JavaModelWrapper, Generic[UT, VT]): + @property + def U(self) -> Optional[UT]: ... + @property + def s(self) -> Vector: ... + @property + def V(self) -> VT: ... + +class IndexedRow: + index: int + vector: VectorLike + def __init__(self, index: int, vector: VectorLike) -> None: ... + +class IndexedRowMatrix(DistributedMatrix): + def __init__( + self, + rows: RDD[Union[Tuple[int, VectorLike], IndexedRow]], + numRows: int = ..., + numCols: int = ..., + ) -> None: ... + @property + def rows(self) -> RDD[IndexedRow]: ... + def numRows(self) -> int: ... + def numCols(self) -> int: ... + def columnSimilarities(self) -> CoordinateMatrix: ... + def computeGramianMatrix(self) -> Matrix: ... + def toRowMatrix(self) -> RowMatrix: ... + def toCoordinateMatrix(self) -> CoordinateMatrix: ... + def toBlockMatrix( + self, rowsPerBlock: int = ..., colsPerBlock: int = ... + ) -> BlockMatrix: ... + def computeSVD( + self, k: int, computeU: bool = ..., rCond: float = ... + ) -> SingularValueDecomposition[IndexedRowMatrix, Matrix]: ... + def multiply(self, matrix: Matrix) -> IndexedRowMatrix: ... + +class MatrixEntry: + i: int + j: int + value: float + def __init__(self, i: int, j: int, value: float) -> None: ... + +class CoordinateMatrix(DistributedMatrix): + def __init__( + self, + entries: RDD[Union[Tuple[int, int, float], MatrixEntry]], + numRows: int = ..., + numCols: int = ..., + ) -> None: ... + @property + def entries(self) -> RDD[MatrixEntry]: ... + def numRows(self) -> int: ... + def numCols(self) -> int: ... + def transpose(self) -> CoordinateMatrix: ... + def toRowMatrix(self) -> RowMatrix: ... + def toIndexedRowMatrix(self) -> IndexedRowMatrix: ... + def toBlockMatrix( + self, rowsPerBlock: int = ..., colsPerBlock: int = ... + ) -> BlockMatrix: ... + +class BlockMatrix(DistributedMatrix): + def __init__( + self, + blocks: RDD[Tuple[Tuple[int, int], Matrix]], + rowsPerBlock: int, + colsPerBlock: int, + numRows: int = ..., + numCols: int = ..., + ) -> None: ... + @property + def blocks(self) -> RDD[Tuple[Tuple[int, int], Matrix]]: ... + @property + def rowsPerBlock(self) -> int: ... + @property + def colsPerBlock(self) -> int: ... + @property + def numRowBlocks(self) -> int: ... + @property + def numColBlocks(self) -> int: ... + def numRows(self) -> int: ... + def numCols(self) -> int: ... + def cache(self) -> BlockMatrix: ... + def persist(self, storageLevel: StorageLevel) -> BlockMatrix: ... + def validate(self) -> None: ... + def add(self, other: BlockMatrix) -> BlockMatrix: ... + def subtract(self, other: BlockMatrix) -> BlockMatrix: ... + def multiply(self, other: BlockMatrix) -> BlockMatrix: ... + def transpose(self) -> BlockMatrix: ... + def toLocalMatrix(self) -> Matrix: ... + def toIndexedRowMatrix(self) -> IndexedRowMatrix: ... + def toCoordinateMatrix(self) -> CoordinateMatrix: ... diff --git a/python/pyspark/mllib/random.pyi b/python/pyspark/mllib/random.pyi new file mode 100644 index 0000000000000..dc5f4701614da --- /dev/null +++ b/python/pyspark/mllib/random.pyi @@ -0,0 +1,126 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional +from pyspark.context import SparkContext +from pyspark.rdd import RDD +from pyspark.mllib.linalg import Vector + +class RandomRDDs: + @staticmethod + def uniformRDD( + sc: SparkContext, + size: int, + numPartitions: Optional[int] = ..., + seed: Optional[int] = ..., + ) -> RDD[float]: ... + @staticmethod + def normalRDD( + sc: SparkContext, + size: int, + numPartitions: Optional[int] = ..., + seed: Optional[int] = ..., + ) -> RDD[float]: ... + @staticmethod + def logNormalRDD( + sc: SparkContext, + mean: float, + std: float, + size: int, + numPartitions: Optional[int] = ..., + seed: Optional[int] = ..., + ) -> RDD[float]: ... + @staticmethod + def poissonRDD( + sc: SparkContext, + mean: float, + size: int, + numPartitions: Optional[int] = ..., + seed: Optional[int] = ..., + ) -> RDD[float]: ... + @staticmethod + def exponentialRDD( + sc: SparkContext, + mean: float, + size: int, + numPartitions: Optional[int] = ..., + seed: Optional[int] = ..., + ) -> RDD[float]: ... + @staticmethod + def gammaRDD( + sc: SparkContext, + shape: float, + scale: float, + size: int, + numPartitions: Optional[int] = ..., + seed: Optional[int] = ..., + ) -> RDD[float]: ... + @staticmethod + def uniformVectorRDD( + sc: SparkContext, + numRows: int, + numCols: int, + numPartitions: Optional[int] = ..., + seed: Optional[int] = ..., + ) -> RDD[Vector]: ... + @staticmethod + def normalVectorRDD( + sc: SparkContext, + numRows: int, + numCols: int, + numPartitions: Optional[int] = ..., + seed: Optional[int] = ..., + ) -> RDD[Vector]: ... + @staticmethod + def logNormalVectorRDD( + sc: SparkContext, + mean: float, + std, + numRows: int, + numCols: int, + numPartitions: Optional[int] = ..., + seed: Optional[int] = ..., + ) -> RDD[Vector]: ... + @staticmethod + def poissonVectorRDD( + sc: SparkContext, + mean: float, + numRows: int, + numCols: int, + numPartitions: Optional[int] = ..., + seed: Optional[int] = ..., + ) -> RDD[Vector]: ... + @staticmethod + def exponentialVectorRDD( + sc: SparkContext, + mean: float, + numRows: int, + numCols: int, + numPartitions: Optional[int] = ..., + seed: Optional[int] = ..., + ) -> RDD[Vector]: ... + @staticmethod + def gammaVectorRDD( + sc: SparkContext, + shape: float, + scale: float, + numRows: int, + numCols: int, + numPartitions: Optional[int] = ..., + seed: Optional[int] = ..., + ) -> RDD[Vector]: ... diff --git a/python/pyspark/mllib/recommendation.pyi b/python/pyspark/mllib/recommendation.pyi new file mode 100644 index 0000000000000..e2f15494209e9 --- /dev/null +++ b/python/pyspark/mllib/recommendation.pyi @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import List, Optional, Tuple, Union + +import array +from collections import namedtuple + +from pyspark.context import SparkContext +from pyspark.rdd import RDD +from pyspark.mllib.common import JavaModelWrapper +from pyspark.mllib.util import JavaLoader, JavaSaveable + +class Rating(namedtuple("Rating", ["user", "product", "rating"])): + def __reduce__(self): ... + +class MatrixFactorizationModel( + JavaModelWrapper, JavaSaveable, JavaLoader[MatrixFactorizationModel] +): + def predict(self, user: int, product: int) -> float: ... + def predictAll(self, user_product: RDD[Tuple[int, int]]) -> RDD[Rating]: ... + def userFeatures(self) -> RDD[Tuple[int, array.array]]: ... + def productFeatures(self) -> RDD[Tuple[int, array.array]]: ... + def recommendUsers(self, product: int, num: int) -> List[Rating]: ... + def recommendProducts(self, user: int, num: int) -> List[Rating]: ... + def recommendProductsForUsers( + self, num: int + ) -> RDD[Tuple[int, Tuple[Rating, ...]]]: ... + def recommendUsersForProducts( + self, num: int + ) -> RDD[Tuple[int, Tuple[Rating, ...]]]: ... + @property + def rank(self) -> int: ... + @classmethod + def load(cls, sc: SparkContext, path: str) -> MatrixFactorizationModel: ... + +class ALS: + @classmethod + def train( + cls, + ratings: Union[RDD[Rating], RDD[Tuple[int, int, float]]], + rank: int, + iterations: int = ..., + lambda_: float = ..., + blocks: int = ..., + nonnegative: bool = ..., + seed: Optional[int] = ..., + ) -> MatrixFactorizationModel: ... + @classmethod + def trainImplicit( + cls, + ratings: Union[RDD[Rating], RDD[Tuple[int, int, float]]], + rank: int, + iterations: int = ..., + lambda_: float = ..., + blocks: int = ..., + alpha: float = ..., + nonnegative: bool = ..., + seed: Optional[int] = ..., + ) -> MatrixFactorizationModel: ... diff --git a/python/pyspark/mllib/regression.pyi b/python/pyspark/mllib/regression.pyi new file mode 100644 index 0000000000000..0283378b98cf3 --- /dev/null +++ b/python/pyspark/mllib/regression.pyi @@ -0,0 +1,155 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import overload +from typing import Iterable, Optional, Tuple, TypeVar +from pyspark.rdd import RDD +from pyspark.mllib._typing import VectorLike +from pyspark.context import SparkContext +from pyspark.mllib.linalg import Vector +from pyspark.mllib.util import Saveable, Loader +from pyspark.streaming.dstream import DStream +from numpy import ndarray # type: ignore[import] + +K = TypeVar("K") + +class LabeledPoint: + label: int + features: Vector + def __init__(self, label: float, features: Iterable[float]) -> None: ... + def __reduce__(self) -> Tuple[type, Tuple[bytes]]: ... + +class LinearModel: + def __init__(self, weights: Vector, intercept: float) -> None: ... + @property + def weights(self) -> Vector: ... + @property + def intercept(self) -> float: ... + +class LinearRegressionModelBase(LinearModel): + @overload + def predict(self, x: Vector) -> float: ... + @overload + def predict(self, x: RDD[Vector]) -> RDD[float]: ... + +class LinearRegressionModel(LinearRegressionModelBase): + def save(self, sc: SparkContext, path: str) -> None: ... + @classmethod + def load(cls, sc: SparkContext, path: str) -> LinearRegressionModel: ... + +class LinearRegressionWithSGD: + @classmethod + def train( + cls, + data: RDD[LabeledPoint], + iterations: int = ..., + step: float = ..., + miniBatchFraction: float = ..., + initialWeights: Optional[VectorLike] = ..., + regParam: float = ..., + regType: Optional[str] = ..., + intercept: bool = ..., + validateData: bool = ..., + convergenceTol: float = ..., + ) -> LinearRegressionModel: ... + +class LassoModel(LinearRegressionModelBase): + def save(self, sc: SparkContext, path: str) -> None: ... + @classmethod + def load(cls, sc: SparkContext, path: str) -> LassoModel: ... + +class LassoWithSGD: + @classmethod + def train( + cls, + data: RDD[LabeledPoint], + iterations: int = ..., + step: float = ..., + regParam: float = ..., + miniBatchFraction: float = ..., + initialWeights: Optional[VectorLike] = ..., + intercept: bool = ..., + validateData: bool = ..., + convergenceTol: float = ..., + ) -> LassoModel: ... + +class RidgeRegressionModel(LinearRegressionModelBase): + def save(self, sc: SparkContext, path: str) -> None: ... + @classmethod + def load(cls, sc: SparkContext, path: str) -> RidgeRegressionModel: ... + +class RidgeRegressionWithSGD: + @classmethod + def train( + cls, + data: RDD[LabeledPoint], + iterations: int = ..., + step: float = ..., + regParam: float = ..., + miniBatchFraction: float = ..., + initialWeights: Optional[VectorLike] = ..., + intercept: bool = ..., + validateData: bool = ..., + convergenceTol: float = ..., + ) -> RidgeRegressionModel: ... + +class IsotonicRegressionModel(Saveable, Loader[IsotonicRegressionModel]): + boundaries: ndarray + predictions: ndarray + isotonic: bool + def __init__( + self, boundaries: ndarray, predictions: ndarray, isotonic: bool + ) -> None: ... + @overload + def predict(self, x: Vector) -> ndarray: ... + @overload + def predict(self, x: RDD[Vector]) -> RDD[ndarray]: ... + def save(self, sc: SparkContext, path: str) -> None: ... + @classmethod + def load(cls, sc: SparkContext, path: str) -> IsotonicRegressionModel: ... + +class IsotonicRegression: + @classmethod + def train( + cls, data: RDD[VectorLike], isotonic: bool = ... + ) -> IsotonicRegressionModel: ... + +class StreamingLinearAlgorithm: + def __init__(self, model: LinearModel) -> None: ... + def latestModel(self) -> LinearModel: ... + def predictOn(self, dstream: DStream[VectorLike]) -> DStream[float]: ... 
+ def predictOnValues( + self, dstream: DStream[Tuple[K, VectorLike]] + ) -> DStream[Tuple[K, float]]: ... + +class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm): + stepSize: float + numIterations: int + miniBatchFraction: float + convergenceTol: float + def __init__( + self, + stepSize: float = ..., + numIterations: int = ..., + miniBatchFraction: float = ..., + convergenceTol: float = ..., + ) -> None: ... + def setInitialWeights( + self, initialWeights: VectorLike + ) -> StreamingLinearRegressionWithSGD: ... + def trainOn(self, dstream: DStream[LabeledPoint]) -> None: ... diff --git a/python/pyspark/mllib/stat/KernelDensity.pyi b/python/pyspark/mllib/stat/KernelDensity.pyi new file mode 100644 index 0000000000000..efc70c9470dbe --- /dev/null +++ b/python/pyspark/mllib/stat/KernelDensity.pyi @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Iterable +from pyspark.rdd import RDD +from numpy import ndarray # type: ignore[import] + +class KernelDensity: + def __init__(self) -> None: ... + def setBandwidth(self, bandwidth: float) -> None: ... + def setSample(self, sample: RDD[float]) -> None: ... + def estimate(self, points: Iterable[float]) -> ndarray: ... diff --git a/python/pyspark/mllib/stat/__init__.pyi b/python/pyspark/mllib/stat/__init__.pyi new file mode 100644 index 0000000000000..bdd080a08cd56 --- /dev/null +++ b/python/pyspark/mllib/stat/__init__.pyi @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
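As an illustration (not part of the patch), the `KernelDensity` stub above corresponds to usage along these lines; the sample values and bandwidth are arbitrary:

```
from pyspark import SparkContext
from pyspark.mllib.stat import KernelDensity

sc = SparkContext.getOrCreate()

kd = KernelDensity()
kd.setSample(sc.parallelize([1.0, 2.0, 3.0, 4.0]))  # RDD[float]
kd.setBandwidth(0.5)
densities = kd.estimate([2.0, 3.5])  # numpy.ndarray of density estimates
```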
+
+from pyspark.mllib.stat.KernelDensity import (  # noqa: F401
+    KernelDensity as KernelDensity,
+)
+from pyspark.mllib.stat._statistics import (  # noqa: F401
+    MultivariateStatisticalSummary as MultivariateStatisticalSummary,
+    Statistics as Statistics,
+)
+from pyspark.mllib.stat.distribution import (  # noqa: F401
+    MultivariateGaussian as MultivariateGaussian,
+)
+from pyspark.mllib.stat.test import ChiSqTestResult as ChiSqTestResult  # noqa: F401
diff --git a/python/pyspark/mllib/stat/_statistics.pyi b/python/pyspark/mllib/stat/_statistics.pyi
new file mode 100644
index 0000000000000..4d2701d486881
--- /dev/null
+++ b/python/pyspark/mllib/stat/_statistics.pyi
@@ -0,0 +1,69 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import List, Optional, overload, Union
+from typing_extensions import Literal
+
+from numpy import ndarray  # type: ignore[import]
+
+from pyspark.mllib.common import JavaModelWrapper
+from pyspark.mllib.linalg import Vector, Matrix
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.stat.test import ChiSqTestResult, KolmogorovSmirnovTestResult
+from pyspark.rdd import RDD
+
+CorrelationMethod = Union[Literal["spearman"], Literal["pearson"]]
+
+class MultivariateStatisticalSummary(JavaModelWrapper):
+    def mean(self) -> ndarray: ...
+    def variance(self) -> ndarray: ...
+    def count(self) -> int: ...
+    def numNonzeros(self) -> ndarray: ...
+    def max(self) -> ndarray: ...
+    def min(self) -> ndarray: ...
+    def normL1(self) -> ndarray: ...
+    def normL2(self) -> ndarray: ...
+
+class Statistics:
+    @staticmethod
+    def colStats(rdd: RDD[Vector]) -> MultivariateStatisticalSummary: ...
+    @overload
+    @staticmethod
+    def corr(
+        x: RDD[Vector], *, method: Optional[CorrelationMethod] = ...
+    ) -> Matrix: ...
+    @overload
+    @staticmethod
+    def corr(
+        x: RDD[float], y: RDD[float], method: Optional[CorrelationMethod] = ...
+    ) -> float: ...
+    @overload
+    @staticmethod
+    def chiSqTest(observed: Matrix) -> ChiSqTestResult: ...
+    @overload
+    @staticmethod
+    def chiSqTest(
+        observed: Vector, expected: Optional[Vector] = ...
+    ) -> ChiSqTestResult: ...
+    @overload
+    @staticmethod
+    def chiSqTest(observed: RDD[LabeledPoint]) -> List[ChiSqTestResult]: ...
+    @staticmethod
+    def kolmogorovSmirnovTest(
+        data: RDD[float], distName: Literal["norm"] = ..., *params: float
+    ) -> KolmogorovSmirnovTestResult: ...
diff --git a/python/pyspark/mllib/stat/distribution.pyi b/python/pyspark/mllib/stat/distribution.pyi
new file mode 100644
index 0000000000000..8bb93f91b07b5
--- /dev/null
+++ b/python/pyspark/mllib/stat/distribution.pyi
@@ -0,0 +1,25 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import NamedTuple + +from pyspark.mllib.linalg import Vector, Matrix + +class MultivariateGaussian(NamedTuple): + mu: Vector + sigma: Matrix diff --git a/python/pyspark/mllib/stat/test.pyi b/python/pyspark/mllib/stat/test.pyi new file mode 100644 index 0000000000000..a65f8e40e87d8 --- /dev/null +++ b/python/pyspark/mllib/stat/test.pyi @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Generic, Tuple, TypeVar + +from pyspark.mllib.common import JavaModelWrapper + +DF = TypeVar("DF", int, float, Tuple[int, ...], Tuple[float, ...]) + +class TestResult(JavaModelWrapper, Generic[DF]): + @property + def pValue(self) -> float: ... + @property + def degreesOfFreedom(self) -> DF: ... + @property + def statistic(self) -> float: ... + @property + def nullHypothesis(self) -> str: ... + +class ChiSqTestResult(TestResult[int]): + @property + def method(self) -> str: ... + +class KolmogorovSmirnovTestResult(TestResult[int]): ... 
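For reference, a minimal sketch (not part of the patch) showing how the `Statistics` overloads above resolve for a type checker; the data is a toy example and an active `SparkContext` is assumed:

```
from pyspark import SparkContext
from pyspark.mllib.linalg import Vectors
from pyspark.mllib.stat import Statistics

sc = SparkContext.getOrCreate()

vectors = sc.parallelize([Vectors.dense([1.0, 10.0]), Vectors.dense([2.0, 20.0])])
xs = sc.parallelize([1.0, 2.0, 3.0])
ys = sc.parallelize([2.0, 4.0, 6.0])

corr_matrix = Statistics.corr(vectors, method="pearson")  # first overload: Matrix
corr_value = Statistics.corr(xs, ys)                      # second overload: float

summary = Statistics.colStats(vectors)  # MultivariateStatisticalSummary
col_means = summary.mean()              # numpy.ndarray
```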
diff --git a/python/pyspark/mllib/tests/test_algorithms.py b/python/pyspark/mllib/tests/test_algorithms.py index 27a340068a52a..89d09fae5cfbc 100644 --- a/python/pyspark/mllib/tests/test_algorithms.py +++ b/python/pyspark/mllib/tests/test_algorithms.py @@ -295,7 +295,7 @@ def test_fpgrowth(self): from pyspark.mllib.tests.test_algorithms import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/mllib/tests/test_feature.py b/python/pyspark/mllib/tests/test_feature.py index 165c1466ddfa8..7fba83b3ea35f 100644 --- a/python/pyspark/mllib/tests/test_feature.py +++ b/python/pyspark/mllib/tests/test_feature.py @@ -185,7 +185,7 @@ def test_pca(self): from pyspark.mllib.tests.test_feature import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/mllib/tests/test_linalg.py b/python/pyspark/mllib/tests/test_linalg.py index 0e25836599307..a8303ba4341f3 100644 --- a/python/pyspark/mllib/tests/test_linalg.py +++ b/python/pyspark/mllib/tests/test_linalg.py @@ -22,8 +22,10 @@ import pyspark.ml.linalg as newlinalg from pyspark.serializers import PickleSerializer -from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector, \ +from pyspark.mllib.linalg import ( # type: ignore[attr-defined] + Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector, DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT +) from pyspark.mllib.linalg.distributed import RowMatrix, IndexedRowMatrix from pyspark.mllib.regression import LabeledPoint from pyspark.sql import Row @@ -641,7 +643,7 @@ def test_regression(self): from pyspark.mllib.tests.test_linalg import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/mllib/tests/test_stat.py b/python/pyspark/mllib/tests/test_stat.py index 6ed0589387a46..414106fe51cc8 100644 --- a/python/pyspark/mllib/tests/test_stat.py +++ b/python/pyspark/mllib/tests/test_stat.py @@ -180,7 +180,7 @@ def test_R_implementation_equivalence(self): from pyspark.mllib.tests.test_stat import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/mllib/tests/test_streaming_algorithms.py b/python/pyspark/mllib/tests/test_streaming_algorithms.py index 666f6f4d8628b..b94fb2778d88d 100644 --- a/python/pyspark/mllib/tests/test_streaming_algorithms.py +++ b/python/pyspark/mllib/tests/test_streaming_algorithms.py @@ -469,7 +469,7 @@ def condition(): from pyspark.mllib.tests.test_streaming_algorithms import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/mllib/tests/test_util.py b/python/pyspark/mllib/tests/test_util.py index 12578e417bcdf..2be3f17069fd4 100644 --- a/python/pyspark/mllib/tests/test_util.py +++ b/python/pyspark/mllib/tests/test_util.py @@ -19,7 +19,7 @@ import tempfile import unittest -from 
pyspark.mllib.common import _to_java_object_rdd +from pyspark.mllib.common import _to_java_object_rdd # type: ignore[attr-defined] from pyspark.mllib.util import LinearDataGenerator from pyspark.mllib.util import MLUtils from pyspark.mllib.linalg import SparseVector, DenseVector, Vectors @@ -97,7 +97,7 @@ def test_to_java_object_rdd(self): # SPARK-6660 from pyspark.mllib.tests.test_util import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/mllib/tree.pyi b/python/pyspark/mllib/tree.pyi new file mode 100644 index 0000000000000..511afdeb063d9 --- /dev/null +++ b/python/pyspark/mllib/tree.pyi @@ -0,0 +1,126 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import overload +from typing import Dict, Optional, Tuple +from pyspark.mllib._typing import VectorLike +from pyspark.rdd import RDD +from pyspark.mllib.common import JavaModelWrapper +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import JavaLoader, JavaSaveable + +class TreeEnsembleModel(JavaModelWrapper, JavaSaveable): + @overload + def predict(self, x: VectorLike) -> float: ... + @overload + def predict(self, x: RDD[VectorLike]) -> RDD[VectorLike]: ... + def numTrees(self) -> int: ... + def totalNumNodes(self) -> int: ... + def toDebugString(self) -> str: ... + +class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader[DecisionTreeModel]): + @overload + def predict(self, x: VectorLike) -> float: ... + @overload + def predict(self, x: RDD[VectorLike]) -> RDD[VectorLike]: ... + def numNodes(self) -> int: ... + def depth(self) -> int: ... + def toDebugString(self) -> str: ... + +class DecisionTree: + @classmethod + def trainClassifier( + cls, + data: RDD[LabeledPoint], + numClasses: int, + categoricalFeaturesInfo: Dict[int, int], + impurity: str = ..., + maxDepth: int = ..., + maxBins: int = ..., + minInstancesPerNode: int = ..., + minInfoGain: float = ..., + ) -> DecisionTreeModel: ... + @classmethod + def trainRegressor( + cls, + data: RDD[LabeledPoint], + categoricalFeaturesInfo: Dict[int, int], + impurity: str = ..., + maxDepth: int = ..., + maxBins: int = ..., + minInstancesPerNode: int = ..., + minInfoGain: float = ..., + ) -> DecisionTreeModel: ... + +class RandomForestModel(TreeEnsembleModel, JavaLoader[RandomForestModel]): ... + +class RandomForest: + supportedFeatureSubsetStrategies: Tuple[str, ...] 
+ @classmethod + def trainClassifier( + cls, + data: RDD[LabeledPoint], + numClasses: int, + categoricalFeaturesInfo: Dict[int, int], + numTrees: int, + featureSubsetStrategy: str = ..., + impurity: str = ..., + maxDepth: int = ..., + maxBins: int = ..., + seed: Optional[int] = ..., + ) -> RandomForestModel: ... + @classmethod + def trainRegressor( + cls, + data: RDD[LabeledPoint], + categoricalFeaturesInfo: Dict[int, int], + numTrees: int, + featureSubsetStrategy: str = ..., + impurity: str = ..., + maxDepth: int = ..., + maxBins: int = ..., + seed: Optional[int] = ..., + ) -> RandomForestModel: ... + +class GradientBoostedTreesModel( + TreeEnsembleModel, JavaLoader[GradientBoostedTreesModel] +): ... + +class GradientBoostedTrees: + @classmethod + def trainClassifier( + cls, + data: RDD[LabeledPoint], + categoricalFeaturesInfo: Dict[int, int], + loss: str = ..., + numIterations: int = ..., + learningRate: float = ..., + maxDepth: int = ..., + maxBins: int = ..., + ) -> GradientBoostedTreesModel: ... + @classmethod + def trainRegressor( + cls, + data: RDD[LabeledPoint], + categoricalFeaturesInfo: Dict[int, int], + loss: str = ..., + numIterations: int = ..., + learningRate: float = ..., + maxDepth: int = ..., + maxBins: int = ..., + ) -> GradientBoostedTreesModel: ... diff --git a/python/pyspark/mllib/util.pyi b/python/pyspark/mllib/util.pyi new file mode 100644 index 0000000000000..265f765ee263a --- /dev/null +++ b/python/pyspark/mllib/util.pyi @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Generic, List, Optional, TypeVar + +from pyspark.mllib._typing import VectorLike +from pyspark.context import SparkContext +from pyspark.mllib.linalg import Vector +from pyspark.mllib.regression import LabeledPoint +from pyspark.rdd import RDD +from pyspark.sql.dataframe import DataFrame + +T = TypeVar("T") + +class MLUtils: + @staticmethod + def loadLibSVMFile( + sc: SparkContext, + path: str, + numFeatures: int = ..., + minPartitions: Optional[int] = ..., + ) -> RDD[LabeledPoint]: ... + @staticmethod + def saveAsLibSVMFile(data: RDD[LabeledPoint], dir: str) -> None: ... + @staticmethod + def loadLabeledPoints( + sc: SparkContext, path: str, minPartitions: Optional[int] = ... + ) -> RDD[LabeledPoint]: ... + @staticmethod + def appendBias(data: Vector) -> Vector: ... + @staticmethod + def loadVectors(sc: SparkContext, path: str) -> RDD[Vector]: ... + @staticmethod + def convertVectorColumnsToML(dataset: DataFrame, *cols: str) -> DataFrame: ... + @staticmethod + def convertVectorColumnsFromML(dataset: DataFrame, *cols: str) -> DataFrame: ... + @staticmethod + def convertMatrixColumnsToML(dataset: DataFrame, *cols: str) -> DataFrame: ... 
+ @staticmethod + def convertMatrixColumnsFromML(dataset: DataFrame, *cols: str) -> DataFrame: ... + +class Saveable: + def save(self, sc: SparkContext, path: str) -> None: ... + +class JavaSaveable(Saveable): + def save(self, sc: SparkContext, path: str) -> None: ... + +class Loader(Generic[T]): + @classmethod + def load(cls, sc: SparkContext, path: str) -> T: ... + +class JavaLoader(Loader[T]): + @classmethod + def load(cls, sc: SparkContext, path: str) -> T: ... + +class LinearDataGenerator: + @staticmethod + def generateLinearInput( + intercept: float, + weights: VectorLike, + xMean: VectorLike, + xVariance: VectorLike, + nPoints: int, + seed: int, + eps: float, + ) -> List[LabeledPoint]: ... + @staticmethod + def generateLinearRDD( + sc: SparkContext, + nexamples: int, + nfeatures: int, + eps: float, + nParts: int = ..., + intercept: float = ..., + ) -> RDD[LabeledPoint]: ... diff --git a/python/pyspark/profiler.pyi b/python/pyspark/profiler.pyi new file mode 100644 index 0000000000000..7276da529fa17 --- /dev/null +++ b/python/pyspark/profiler.pyi @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Callable, List, Optional, Tuple, Type + +import pstats + +from pyspark.accumulators import AccumulatorParam +from pyspark.context import SparkContext + +class ProfilerCollector: + profiler_cls: Type[Profiler] + profile_dump_path: Optional[str] + profilers: List[Tuple[int, Profiler, bool]] + def __init__( + self, profiler_cls: Type[Profiler], dump_path: Optional[str] = ... + ) -> None: ... + def new_profiler(self, ctx: SparkContext) -> Profiler: ... + def add_profiler(self, id: int, profiler: Profiler) -> None: ... + def dump_profiles(self, path: str) -> None: ... + def show_profiles(self) -> None: ... + +class Profiler: + def __init__(self, ctx: SparkContext) -> None: ... + def profile(self, func: Callable[[], Any]) -> None: ... + def stats(self) -> pstats.Stats: ... + def show(self, id: int) -> None: ... + def dump(self, id: int, path: str) -> None: ... + +class PStatsParam(AccumulatorParam): + @staticmethod + def zero(value: pstats.Stats) -> None: ... + @staticmethod + def addInPlace( + value1: Optional[pstats.Stats], value2: Optional[pstats.Stats] + ) -> Optional[pstats.Stats]: ... + +class BasicProfiler(Profiler): + def __init__(self, ctx: SparkContext) -> None: ... + def profile(self, func: Callable[[], Any]) -> None: ... + def stats(self) -> pstats.Stats: ... 
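For reference, a minimal sketch (not part of the patch) of code that the `MLUtils` annotations above are intended to check; the LibSVM path is a placeholder and an active `SparkContext` is assumed:

```
from pyspark import SparkContext
from pyspark.mllib.util import MLUtils

sc = SparkContext.getOrCreate()

# Placeholder path; any LibSVM-formatted file would do.
points = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")  # RDD[LabeledPoint]
features = points.map(lambda lp: lp.features)  # RDD[Vector]
```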
diff --git a/python/pyspark/py.typed b/python/pyspark/py.typed new file mode 100644 index 0000000000000..b648ac9233330 --- /dev/null +++ b/python/pyspark/py.typed @@ -0,0 +1 @@ +partial diff --git a/python/pyspark/rdd.pyi b/python/pyspark/rdd.pyi new file mode 100644 index 0000000000000..35c49e952b0cd --- /dev/null +++ b/python/pyspark/rdd.pyi @@ -0,0 +1,479 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import overload +from typing import ( + Any, + Callable, + Dict, + Generic, + Hashable, + Iterable, + Iterator, + List, + Optional, + Tuple, + Union, + TypeVar, +) +from typing_extensions import Literal + +from numpy import int32, int64, float32, float64, ndarray # type: ignore[import] + +from pyspark._typing import SupportsOrdering +from pyspark.sql.pandas._typing import ( + PandasScalarUDFType, + PandasScalarIterUDFType, + PandasGroupedMapUDFType, + PandasCogroupedMapUDFType, + PandasGroupedAggUDFType, + PandasMapIterUDFType, +) +import pyspark.context +from pyspark.resultiterable import ResultIterable +from pyspark.serializers import Serializer +from pyspark.storagelevel import StorageLevel +from pyspark.resource.requests import ( # noqa: F401 + ExecutorResourceRequests, + TaskResourceRequests, +) +from pyspark.resource.profile import ResourceProfile +from pyspark.statcounter import StatCounter +from pyspark.sql.dataframe import DataFrame +from pyspark.sql.types import StructType +from pyspark.sql._typing import RowLike +from py4j.java_gateway import JavaObject # type: ignore[import] + +T = TypeVar("T") +U = TypeVar("U") +K = TypeVar("K", bound=Hashable) +V = TypeVar("V") +V1 = TypeVar("V1") +V2 = TypeVar("V2") +V3 = TypeVar("V3") +O = TypeVar("O", bound=SupportsOrdering) +NumberOrArray = TypeVar( + "NumberOrArray", float, int, complex, int32, int64, float32, float64, ndarray +) + +def portable_hash(x: Hashable) -> int: ... + +class PythonEvalType: + NON_UDF: Literal[0] + SQL_BATCHED_UDF: Literal[100] + SQL_SCALAR_PANDAS_UDF: PandasScalarUDFType + SQL_GROUPED_MAP_PANDAS_UDF: PandasGroupedMapUDFType + SQL_GROUPED_AGG_PANDAS_UDF: PandasGroupedAggUDFType + SQL_WINDOW_AGG_PANDAS_UDF: Literal[203] + SQL_SCALAR_PANDAS_ITER_UDF: PandasScalarIterUDFType + SQL_MAP_PANDAS_ITER_UDF: PandasMapIterUDFType + SQL_COGROUPED_MAP_PANDAS_UDF: PandasCogroupedMapUDFType + +class BoundedFloat(float): + def __new__(cls, mean: float, confidence: float, low: float, high: float): ... + +class Partitioner: + numPartitions: int + partitionFunc: Callable[[Any], int] + def __init__(self, numPartitions, partitionFunc) -> None: ... + def __eq__(self, other: Any) -> bool: ... + def __call__(self, k: Any) -> int: ... 
+ +class RDD(Generic[T]): + is_cached: bool + is_checkpointed: bool + ctx: pyspark.context.SparkContext + partitioner: Optional[Partitioner] + def __init__( + self, + jrdd: JavaObject, + ctx: pyspark.context.SparkContext, + jrdd_deserializer: Serializer = ..., + ) -> None: ... + def id(self) -> int: ... + def __getnewargs__(self) -> Any: ... + @property + def context(self) -> pyspark.context.SparkContext: ... + def cache(self) -> RDD[T]: ... + def persist(self, storageLevel: StorageLevel = ...) -> RDD[T]: ... + def unpersist(self, blocking: bool = ...) -> RDD[T]: ... + def checkpoint(self) -> None: ... + def isCheckpointed(self) -> bool: ... + def localCheckpoint(self) -> None: ... + def isLocallyCheckpointed(self) -> bool: ... + def getCheckpointFile(self) -> Optional[str]: ... + def map(self, f: Callable[[T], U], preservesPartitioning: bool = ...) -> RDD[U]: ... + def flatMap( + self, f: Callable[[T], Iterable[U]], preservesPartitioning: bool = ... + ) -> RDD[U]: ... + def mapPartitions( + self, f: Callable[[Iterable[T]], Iterable[U]], preservesPartitioning: bool = ... + ) -> RDD[U]: ... + def mapPartitionsWithIndex( + self, + f: Callable[[int, Iterable[T]], Iterable[U]], + preservesPartitioning: bool = ..., + ) -> RDD[U]: ... + def mapPartitionsWithSplit( + self, + f: Callable[[int, Iterable[T]], Iterable[U]], + preservesPartitioning: bool = ..., + ) -> RDD[U]: ... + def getNumPartitions(self) -> int: ... + def filter(self, f: Callable[[T], bool]) -> RDD[T]: ... + def distinct(self, numPartitions: Optional[int] = ...) -> RDD[T]: ... + def sample( + self, withReplacement: bool, fraction: float, seed: Optional[int] = ... + ) -> RDD[T]: ... + def randomSplit( + self, weights: List[Union[int, float]], seed: Optional[int] = ... + ) -> List[RDD[T]]: ... + def takeSample( + self, withReplacement: bool, num: int, seed: Optional[int] = ... + ) -> List[T]: ... + def union(self, other: RDD[U]) -> RDD[Union[T, U]]: ... + def intersection(self, other: RDD[T]) -> RDD[T]: ... + def __add__(self, other: RDD[T]) -> RDD[T]: ... + @overload + def repartitionAndSortWithinPartitions( + self: RDD[Tuple[O, V]], + numPartitions: Optional[int] = ..., + partitionFunc: Callable[[O], int] = ..., + ascending: bool = ..., + ) -> RDD[Tuple[O, V]]: ... + @overload + def repartitionAndSortWithinPartitions( + self: RDD[Tuple[K, V]], + numPartitions: Optional[int], + partitionFunc: Callable[[K], int], + ascending: bool, + keyfunc: Callable[[K], O], + ) -> RDD[Tuple[K, V]]: ... + @overload + def repartitionAndSortWithinPartitions( + self: RDD[Tuple[K, V]], + numPartitions: Optional[int] = ..., + partitionFunc: Callable[[K], int] = ..., + ascending: bool = ..., + *, + keyfunc: Callable[[K], O] + ) -> RDD[Tuple[K, V]]: ... + @overload + def sortByKey( + self: RDD[Tuple[O, V]], + ascending: bool = ..., + numPartitions: Optional[int] = ..., + ) -> RDD[Tuple[K, V]]: ... + @overload + def sortByKey( + self: RDD[Tuple[K, V]], + ascending: bool, + numPartitions: int, + keyfunc: Callable[[K], O], + ) -> RDD[Tuple[K, V]]: ... + @overload + def sortByKey( + self: RDD[Tuple[K, V]], + ascending: bool = ..., + numPartitions: Optional[int] = ..., + *, + keyfunc: Callable[[K], O] + ) -> RDD[Tuple[K, V]]: ... + def sortBy( + self: RDD[T], + keyfunc: Callable[[T], O], + ascending: bool = ..., + numPartitions: Optional[int] = ..., + ) -> RDD[T]: ... + def glom(self) -> RDD[List[T]]: ... + def cartesian(self, other: RDD[U]) -> RDD[Tuple[T, U]]: ... 
+ def groupBy( + self, + f: Callable[[T], K], + numPartitions: Optional[int] = ..., + partitionFunc: Callable[[K], int] = ..., + ) -> RDD[Tuple[K, Iterable[T]]]: ... + def pipe( + self, command: str, env: Optional[Dict[str, str]] = ..., checkCode: bool = ... + ) -> RDD[str]: ... + def foreach(self, f: Callable[[T], None]) -> None: ... + def foreachPartition(self, f: Callable[[Iterable[T]], None]) -> None: ... + def collect(self) -> List[T]: ... + def collectWithJobGroup( + self, groupId: str, description: str, interruptOnCancel: bool = ... + ) -> List[T]: ... + def reduce(self, f: Callable[[T, T], T]) -> T: ... + def treeReduce(self, f: Callable[[T, T], T], depth: int = ...) -> T: ... + def fold(self, zeroValue: T, op: Callable[[T, T], T]) -> T: ... + def aggregate( + self, zeroValue: U, seqOp: Callable[[U, T], U], combOp: Callable[[U, U], U] + ) -> U: ... + def treeAggregate( + self, + zeroValue: U, + seqOp: Callable[[U, T], U], + combOp: Callable[[U, U], U], + depth: int = ..., + ) -> U: ... + @overload + def max(self: RDD[O]) -> O: ... + @overload + def max(self, key: Callable[[T], O]) -> T: ... + @overload + def min(self: RDD[O]) -> O: ... + @overload + def min(self, key: Callable[[T], O]) -> T: ... + def sum(self: RDD[NumberOrArray]) -> NumberOrArray: ... + def count(self) -> int: ... + def stats(self: RDD[NumberOrArray]) -> StatCounter: ... + def histogram(self, buckets: List[T]) -> Tuple[List[T], List[int]]: ... + def mean(self: RDD[NumberOrArray]) -> NumberOrArray: ... + def variance(self: RDD[NumberOrArray]) -> NumberOrArray: ... + def stdev(self: RDD[NumberOrArray]) -> NumberOrArray: ... + def sampleStdev(self: RDD[NumberOrArray]) -> NumberOrArray: ... + def sampleVariance(self: RDD[NumberOrArray]) -> NumberOrArray: ... + def countByValue(self: RDD[K]) -> Dict[K, int]: ... + @overload + def top(self: RDD[O], num: int) -> List[O]: ... + @overload + def top(self: RDD[T], num: int, key: Callable[[T], O]) -> List[T]: ... + @overload + def takeOrdered(self: RDD[O], num: int) -> List[O]: ... + @overload + def takeOrdered(self: RDD[T], num: int, key: Callable[[T], O]) -> List[T]: ... + def take(self, num: int) -> List[T]: ... + def first(self) -> T: ... + def isEmpty(self) -> bool: ... + def saveAsNewAPIHadoopDataset( + self: RDD[Tuple[K, V]], + conf: Dict[str, str], + keyConverter: Optional[str] = ..., + valueConverter: Optional[str] = ..., + ) -> None: ... + def saveAsNewAPIHadoopFile( + self: RDD[Tuple[K, V]], + path: str, + outputFormatClass: str, + keyClass: Optional[str] = ..., + valueClass: Optional[str] = ..., + keyConverter: Optional[str] = ..., + valueConverter: Optional[str] = ..., + conf: Optional[Dict[str, str]] = ..., + ) -> None: ... + def saveAsHadoopDataset( + self: RDD[Tuple[K, V]], + conf: Dict[str, str], + keyConverter: Optional[str] = ..., + valueConverter: Optional[str] = ..., + ) -> None: ... + def saveAsHadoopFile( + self: RDD[Tuple[K, V]], + path: str, + outputFormatClass: str, + keyClass: Optional[str] = ..., + valueClass: Optional[str] = ..., + keyConverter: Optional[str] = ..., + valueConverter: Optional[str] = ..., + conf: Optional[str] = ..., + compressionCodecClass: Optional[str] = ..., + ) -> None: ... + def saveAsSequenceFile( + self: RDD[Tuple[K, V]], path: str, compressionCodecClass: Optional[str] = ... + ) -> None: ... + def saveAsPickleFile(self, path: str, batchSize: int = ...) -> None: ... + def saveAsTextFile( + self, path: str, compressionCodecClass: Optional[str] = ... + ) -> None: ... 
+ def collectAsMap(self: RDD[Tuple[K, V]]) -> Dict[K, V]: ... + def keys(self: RDD[Tuple[K, V]]) -> RDD[K]: ... + def values(self: RDD[Tuple[K, V]]) -> RDD[V]: ... + def reduceByKey( + self: RDD[Tuple[K, V]], + func: Callable[[V, V], V], + numPartitions: Optional[int] = ..., + partitionFunc: Callable[[K], int] = ..., + ) -> RDD[Tuple[K, V]]: ... + def reduceByKeyLocally( + self: RDD[Tuple[K, V]], func: Callable[[V, V], V] + ) -> Dict[K, V]: ... + def countByKey(self: RDD[Tuple[K, V]]) -> Dict[K, int]: ... + def join( + self: RDD[Tuple[K, V]], + other: RDD[Tuple[K, U]], + numPartitions: Optional[int] = ..., + ) -> RDD[Tuple[K, Tuple[V, U]]]: ... + def leftOuterJoin( + self: RDD[Tuple[K, V]], + other: RDD[Tuple[K, U]], + numPartitions: Optional[int] = ..., + ) -> RDD[Tuple[K, Tuple[V, Optional[U]]]]: ... + def rightOuterJoin( + self: RDD[Tuple[K, V]], + other: RDD[Tuple[K, U]], + numPartitions: Optional[int] = ..., + ) -> RDD[Tuple[K, Tuple[Optional[V], U]]]: ... + def fullOuterJoin( + self: RDD[Tuple[K, V]], + other: RDD[Tuple[K, U]], + numPartitions: Optional[int] = ..., + ) -> RDD[Tuple[K, Tuple[Optional[V], Optional[U]]]]: ... + def partitionBy( + self: RDD[Tuple[K, V]], + numPartitions: int, + partitionFunc: Callable[[K], int] = ..., + ) -> RDD[Tuple[K, V]]: ... + def combineByKey( + self: RDD[Tuple[K, V]], + createCombiner: Callable[[V], U], + mergeValue: Callable[[U, V], U], + mergeCombiners: Callable[[U, U], U], + numPartitions: Optional[int] = ..., + partitionFunc: Callable[[K], int] = ..., + ) -> RDD[Tuple[K, U]]: ... + def aggregateByKey( + self: RDD[Tuple[K, V]], + zeroValue: U, + seqFunc: Callable[[U, V], U], + combFunc: Callable[[U, U], U], + numPartitions: Optional[int] = ..., + partitionFunc: Callable[[K], int] = ..., + ) -> RDD[Tuple[K, U]]: ... + def foldByKey( + self: RDD[Tuple[K, V]], + zeroValue: V, + func: Callable[[V, V], V], + numPartitions: Optional[int] = ..., + partitionFunc: Callable[[K], int] = ..., + ) -> RDD[Tuple[K, V]]: ... + def groupByKey( + self: RDD[Tuple[K, V]], + numPartitions: Optional[int] = ..., + partitionFunc: Callable[[K], int] = ..., + ) -> RDD[Tuple[K, Iterable[V]]]: ... + def flatMapValues( + self: RDD[Tuple[K, V]], f: Callable[[V], Iterable[U]] + ) -> RDD[Tuple[K, U]]: ... + def mapValues(self: RDD[Tuple[K, V]], f: Callable[[V], U]) -> RDD[Tuple[K, U]]: ... + @overload + def groupWith( + self: RDD[Tuple[K, V]], __o: RDD[Tuple[K, V1]] + ) -> RDD[Tuple[K, Tuple[ResultIterable[V], ResultIterable[V1]]]]: ... + @overload + def groupWith( + self: RDD[Tuple[K, V]], __o1: RDD[Tuple[K, V1]], __o2: RDD[Tuple[K, V2]] + ) -> RDD[ + Tuple[K, Tuple[ResultIterable[V], ResultIterable[V1], ResultIterable[V2]]] + ]: ... + @overload + def groupWith( + self: RDD[Tuple[K, V]], + other1: RDD[Tuple[K, V1]], + other2: RDD[Tuple[K, V2]], + other3: RDD[Tuple[K, V3]], + ) -> RDD[ + Tuple[ + K, + Tuple[ + ResultIterable[V], + ResultIterable[V1], + ResultIterable[V2], + ResultIterable[V3], + ], + ] + ]: ... + def cogroup( + self: RDD[Tuple[K, V]], + other: RDD[Tuple[K, U]], + numPartitions: Optional[int] = ..., + ) -> RDD[Tuple[K, Tuple[ResultIterable[V], ResultIterable[U]]]]: ... + def sampleByKey( + self: RDD[Tuple[K, V]], + withReplacement: bool, + fractions: Dict[K, Union[float, int]], + seed: Optional[int] = ..., + ) -> RDD[Tuple[K, V]]: ... + def subtractByKey( + self: RDD[Tuple[K, V]], + other: RDD[Tuple[K, U]], + numPartitions: Optional[int] = ..., + ) -> RDD[Tuple[K, V]]: ... + def subtract( + self: RDD[T], other: RDD[T], numPartitions: Optional[int] = ... 
+ ) -> RDD[T]: ... + def keyBy(self: RDD[T], f: Callable[[T], K]) -> RDD[Tuple[K, T]]: ... + def repartition(self, numPartitions: int) -> RDD[T]: ... + def coalesce(self, numPartitions: int, shuffle: bool = ...) -> RDD[T]: ... + def zip(self, other: RDD[U]) -> RDD[Tuple[T, U]]: ... + def zipWithIndex(self) -> RDD[Tuple[T, int]]: ... + def zipWithUniqueId(self) -> RDD[Tuple[T, int]]: ... + def name(self) -> str: ... + def setName(self, name: str) -> RDD[T]: ... + def toDebugString(self) -> bytes: ... + def getStorageLevel(self) -> StorageLevel: ... + def lookup(self: RDD[Tuple[K, V]], key: K) -> List[V]: ... + def countApprox(self, timeout: int, confidence: float = ...) -> int: ... + def sumApprox( + self: RDD[Union[float, int]], timeout: int, confidence: float = ... + ) -> BoundedFloat: ... + def meanApprox( + self: RDD[Union[float, int]], timeout: int, confidence: float = ... + ) -> BoundedFloat: ... + def countApproxDistinct(self, relativeSD: float = ...) -> int: ... + def toLocalIterator(self, prefetchPartitions: bool = ...) -> Iterator[T]: ... + def barrier(self: RDD[T]) -> RDDBarrier[T]: ... + def withResources(self: RDD[T], profile: ResourceProfile) -> RDD[T]: ... + def getResourceProfile(self) -> Optional[ResourceProfile]: ... + @overload + def toDF( + self: RDD[RowLike], + schema: Optional[List[str]] = ..., + sampleRatio: Optional[float] = ..., + ) -> DataFrame: ... + @overload + def toDF(self: RDD[RowLike], schema: Optional[StructType] = ...) -> DataFrame: ... + +class RDDBarrier(Generic[T]): + rdd: RDD[T] + def __init__(self, rdd: RDD[T]) -> None: ... + def mapPartitions( + self, f: Callable[[Iterable[T]], Iterable[U]], preservesPartitioning: bool = ... + ) -> RDD[U]: ... + def mapPartitionsWithIndex( + self, + f: Callable[[int, Iterable[T]], Iterable[U]], + preservesPartitioning: bool = ..., + ) -> RDD[U]: ... + +class PipelinedRDD(RDD[U], Generic[T, U]): + func: Callable[[T], U] + preservesPartitioning: bool + is_cached: bool + is_checkpointed: bool + ctx: pyspark.context.SparkContext + prev: RDD[T] + partitioner: Optional[Partitioner] + is_barrier: bool + def __init__( + self, + prev: RDD[T], + func: Callable[[Iterable[T]], Iterable[U]], + preservesPartitioning: bool = ..., + isFromBarrier: bool = ..., + ) -> None: ... + def getNumPartitions(self) -> int: ... + def id(self) -> int: ... diff --git a/python/pyspark/rddsampler.pyi b/python/pyspark/rddsampler.pyi new file mode 100644 index 0000000000000..8fbf72d90025c --- /dev/null +++ b/python/pyspark/rddsampler.pyi @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
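As a sketch (not part of the patch) of what the generic `RDD[T]` annotations above buy under mypy; it assumes an active `SparkSession`:

```
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

nums = sc.parallelize([1, 2, 3])        # RDD[int]
pairs = nums.map(lambda n: (n, n * n))  # RDD[Tuple[int, int]]
squares = pairs.values()                # values() requires a pair RDD
total = nums.sum()                      # int
df = pairs.toDF(["n", "n_squared"])     # DataFrame, via the RowLike overload
```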
+ +from typing import Any, Dict, Iterator, Optional, Tuple, TypeVar + +T = TypeVar("T") +U = TypeVar("U") +K = TypeVar("K") +V = TypeVar("V") + +class RDDSamplerBase: + def __init__(self, withReplacement: bool, seed: Optional[int] = ...) -> None: ... + def initRandomGenerator(self, split: int) -> None: ... + def getUniformSample(self) -> float: ... + def getPoissonSample(self, mean: float) -> int: ... + def func(self, split: int, iterator: Iterator[Any]) -> Iterator[Any]: ... + +class RDDSampler(RDDSamplerBase): + def __init__( + self, withReplacement: bool, fraction: float, seed: Optional[int] = ... + ) -> None: ... + def func(self, split: int, iterator: Iterator[T]) -> Iterator[T]: ... + +class RDDRangeSampler(RDDSamplerBase): + def __init__( + self, lowerBound: T, upperBound: T, seed: Optional[Any] = ... + ) -> None: ... + def func(self, split: int, iterator: Iterator[T]) -> Iterator[T]: ... + +class RDDStratifiedSampler(RDDSamplerBase): + def __init__( + self, + withReplacement: bool, + fractions: Dict[K, float], + seed: Optional[int] = ..., + ) -> None: ... + def func( + self, split: int, iterator: Iterator[Tuple[K, V]] + ) -> Iterator[Tuple[K, V]]: ... diff --git a/python/pyspark/resource/__init__.pyi b/python/pyspark/resource/__init__.pyi new file mode 100644 index 0000000000000..87a9b53c268ac --- /dev/null +++ b/python/pyspark/resource/__init__.pyi @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pyspark.resource.information import ( # noqa: F401 + ResourceInformation as ResourceInformation, +) +from pyspark.resource.profile import ( # noqa: F401 + ResourceProfile as ResourceProfile, + ResourceProfileBuilder as ResourceProfileBuilder, +) +from pyspark.resource.requests import ( # noqa: F401 + ExecutorResourceRequest as ExecutorResourceRequest, + ExecutorResourceRequests as ExecutorResourceRequests, + TaskResourceRequest as TaskResourceRequest, + TaskResourceRequests as TaskResourceRequests, +) diff --git a/python/pyspark/resource/information.pyi b/python/pyspark/resource/information.pyi new file mode 100644 index 0000000000000..7baa6ca8520bd --- /dev/null +++ b/python/pyspark/resource/information.pyi @@ -0,0 +1,26 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any + +class ResourceInformation: + def __init__(self, name: Any, addresses: Any) -> None: ... + @property + def name(self): ... + @property + def addresses(self): ... diff --git a/python/pyspark/resource/profile.pyi b/python/pyspark/resource/profile.pyi new file mode 100644 index 0000000000000..8ce7d93b29e93 --- /dev/null +++ b/python/pyspark/resource/profile.pyi @@ -0,0 +1,51 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pyspark.resource.requests import ( # noqa: F401 + ExecutorResourceRequest as ExecutorResourceRequest, + ExecutorResourceRequests as ExecutorResourceRequests, + TaskResourceRequest as TaskResourceRequest, + TaskResourceRequests as TaskResourceRequests, +) +from typing import Any, Optional + +class ResourceProfile: + def __init__( + self, + _java_resource_profile: Optional[Any] = ..., + _exec_req: Any = ..., + _task_req: Any = ..., + ) -> None: ... + @property + def id(self): ... + @property + def taskResources(self): ... + @property + def executorResources(self): ... + +class ResourceProfileBuilder: + def __init__(self) -> None: ... + def require(self, resourceRequest: Any): ... + def clearExecutorResourceRequests(self) -> None: ... + def clearTaskResourceRequests(self) -> None: ... + @property + def taskResources(self): ... + @property + def executorResources(self): ... + @property + def build(self): ... diff --git a/python/pyspark/resource/requests.pyi b/python/pyspark/resource/requests.pyi new file mode 100644 index 0000000000000..f9448d0780409 --- /dev/null +++ b/python/pyspark/resource/requests.pyi @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Optional + +class ExecutorResourceRequest: + def __init__( + self, + resourceName: Any, + amount: Any, + discoveryScript: str = ..., + vendor: str = ..., + ) -> None: ... + @property + def resourceName(self): ... + @property + def amount(self): ... + @property + def discoveryScript(self): ... + @property + def vendor(self): ... + +class ExecutorResourceRequests: + def __init__( + self, _jvm: Optional[Any] = ..., _requests: Optional[Any] = ... + ) -> None: ... + def memory(self, amount: Any): ... + def memoryOverhead(self, amount: Any): ... + def pysparkMemory(self, amount: Any): ... + def offheapMemory(self, amount: Any): ... + def cores(self, amount: Any): ... + def resource( + self, + resourceName: Any, + amount: Any, + discoveryScript: str = ..., + vendor: str = ..., + ): ... + @property + def requests(self): ... + +class TaskResourceRequest: + def __init__(self, resourceName: Any, amount: Any) -> None: ... + @property + def resourceName(self): ... + @property + def amount(self): ... + +class TaskResourceRequests: + def __init__( + self, _jvm: Optional[Any] = ..., _requests: Optional[Any] = ... + ) -> None: ... + def cpus(self, amount: Any): ... + def resource(self, resourceName: Any, amount: Any): ... + @property + def requests(self): ... diff --git a/python/pyspark/resource/tests/test_resources.py b/python/pyspark/resource/tests/test_resources.py index c2b574c61abc5..6149f1ff7205a 100644 --- a/python/pyspark/resource/tests/test_resources.py +++ b/python/pyspark/resource/tests/test_resources.py @@ -75,7 +75,7 @@ def assert_request_contents(exec_reqs, task_reqs): from pyspark.resource.tests.test_resources import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/resultiterable.pyi b/python/pyspark/resultiterable.pyi new file mode 100644 index 0000000000000..69596ad82c8cc --- /dev/null +++ b/python/pyspark/resultiterable.pyi @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pyspark._typing import SizedIterable +from typing import Iterator, TypeVar + +T = TypeVar("T") + +class ResultIterable(SizedIterable[T]): + data: SizedIterable[T] + index: int + maxindex: int + def __init__(self, data: SizedIterable[T]) -> None: ... + def __iter__(self) -> Iterator[T]: ... + def __len__(self) -> int: ... 
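For reference, a minimal sketch (not part of the patch) of the builder-style resource API covered by the stubs above; the amounts are arbitrary:

```
from pyspark.resource import (
    ExecutorResourceRequests,
    ResourceProfileBuilder,
    TaskResourceRequests,
)

# The builder methods return self, so calls chain.
exec_reqs = ExecutorResourceRequests().cores(2).memory("2g")
task_reqs = TaskResourceRequests().cpus(1)

profile = ResourceProfileBuilder().require(exec_reqs).require(task_reqs).build
# The resulting ResourceProfile can be attached to an RDD via rdd.withResources(profile).
```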
diff --git a/python/pyspark/serializers.pyi b/python/pyspark/serializers.pyi new file mode 100644 index 0000000000000..26ef17c38d227 --- /dev/null +++ b/python/pyspark/serializers.pyi @@ -0,0 +1,122 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any + +class SpecialLengths: + END_OF_DATA_SECTION: int = ... + PYTHON_EXCEPTION_THROWN: int = ... + TIMING_DATA: int = ... + END_OF_STREAM: int = ... + NULL: int = ... + START_ARROW_STREAM: int = ... + +class Serializer: + def dump_stream(self, iterator: Any, stream: Any) -> None: ... + def load_stream(self, stream: Any) -> None: ... + def __eq__(self, other: Any) -> Any: ... + def __ne__(self, other: Any) -> Any: ... + def __hash__(self) -> Any: ... + +class FramedSerializer(Serializer): + def __init__(self) -> None: ... + def dump_stream(self, iterator: Any, stream: Any) -> None: ... + def load_stream(self, stream: Any) -> None: ... + def dumps(self, obj: Any) -> None: ... + def loads(self, obj: Any) -> None: ... + +class BatchedSerializer(Serializer): + UNLIMITED_BATCH_SIZE: int = ... + UNKNOWN_BATCH_SIZE: int = ... + serializer: Any = ... + batchSize: Any = ... + def __init__(self, serializer: Any, batchSize: Any = ...) -> None: ... + def dump_stream(self, iterator: Any, stream: Any) -> None: ... + def load_stream(self, stream: Any): ... + +class FlattenedValuesSerializer(BatchedSerializer): + def __init__(self, serializer: Any, batchSize: int = ...) -> None: ... + def load_stream(self, stream: Any): ... + +class AutoBatchedSerializer(BatchedSerializer): + bestSize: Any = ... + def __init__(self, serializer: Any, bestSize: Any = ...) -> None: ... + def dump_stream(self, iterator: Any, stream: Any) -> None: ... + +class CartesianDeserializer(Serializer): + key_ser: Any = ... + val_ser: Any = ... + def __init__(self, key_ser: Any, val_ser: Any) -> None: ... + def load_stream(self, stream: Any): ... + +class PairDeserializer(Serializer): + key_ser: Any = ... + val_ser: Any = ... + def __init__(self, key_ser: Any, val_ser: Any) -> None: ... + def load_stream(self, stream: Any): ... + +class NoOpSerializer(FramedSerializer): + def loads(self, obj: Any): ... + def dumps(self, obj: Any): ... + +class PickleSerializer(FramedSerializer): + def dumps(self, obj: Any): ... + def loads(self, obj: Any, encoding: str = ...): ... + +class CloudPickleSerializer(PickleSerializer): + def dumps(self, obj: Any): ... + +class MarshalSerializer(FramedSerializer): + def dumps(self, obj: Any): ... + def loads(self, obj: Any): ... + +class AutoSerializer(FramedSerializer): + def __init__(self) -> None: ... + def dumps(self, obj: Any): ... + def loads(self, obj: Any): ... + +class CompressedSerializer(FramedSerializer): + serializer: Any = ... + def __init__(self, serializer: Any) -> None: ... 
+ def dumps(self, obj: Any): ... + def loads(self, obj: Any): ... + +class UTF8Deserializer(Serializer): + use_unicode: Any = ... + def __init__(self, use_unicode: bool = ...) -> None: ... + def loads(self, stream: Any): ... + def load_stream(self, stream: Any) -> None: ... + +class ChunkedStream: + buffer_size: Any = ... + buffer: Any = ... + current_pos: int = ... + wrapped: Any = ... + def __init__(self, wrapped: Any, buffer_size: Any) -> None: ... + def write(self, bytes: Any) -> None: ... + def close(self) -> None: ... + @property + def closed(self): ... + +def write_with_length(obj: Any, stream: Any): ... +def pack_long(value): ... +def read_int(stream): ... +def read_long(stream): ... +def read_bool(stream): ... +def write_int(value, stream): ... +def write_long(value, stream): ... diff --git a/python/pyspark/shell.pyi b/python/pyspark/shell.pyi new file mode 100644 index 0000000000000..0760309542f8d --- /dev/null +++ b/python/pyspark/shell.pyi @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pyspark import SparkConf as SparkConf # noqa: F401 +from pyspark.context import SparkContext as SparkContext +from pyspark.sql import SQLContext as SQLContext, SparkSession as SparkSession +from typing import Any, Callable + +from pyspark.sql.dataframe import DataFrame + +spark: SparkSession +sc: SparkContext +sql: Callable[[str], DataFrame] +sqlContext: SQLContext +sqlCtx: SQLContext +code: Any diff --git a/python/pyspark/shuffle.pyi b/python/pyspark/shuffle.pyi new file mode 100644 index 0000000000000..10648c51dca8f --- /dev/null +++ b/python/pyspark/shuffle.pyi @@ -0,0 +1,109 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from pyspark.serializers import ( # noqa: F401 + AutoBatchedSerializer as AutoBatchedSerializer, + BatchedSerializer as BatchedSerializer, + CompressedSerializer as CompressedSerializer, + FlattenedValuesSerializer as FlattenedValuesSerializer, + PickleSerializer as PickleSerializer, +) +from pyspark.util import fail_on_stopiteration as fail_on_stopiteration # noqa: F401 +from typing import Any, Optional + +process: Any + +def get_used_memory(): ... + +MemoryBytesSpilled: int +DiskBytesSpilled: int + +class Aggregator: + createCombiner: Any = ... + mergeValue: Any = ... + mergeCombiners: Any = ... + def __init__( + self, createCombiner: Any, mergeValue: Any, mergeCombiners: Any + ) -> None: ... + +class SimpleAggregator(Aggregator): + def __init__(self, combiner: Any): ... + +class Merger: + agg: Any = ... + def __init__(self, aggregator: Any) -> None: ... + def mergeValues(self, iterator: Any) -> None: ... + def mergeCombiners(self, iterator: Any) -> None: ... + def items(self) -> None: ... + +class ExternalMerger(Merger): + MAX_TOTAL_PARTITIONS: int = ... + memory_limit: Any = ... + serializer: Any = ... + localdirs: Any = ... + partitions: Any = ... + batch: Any = ... + scale: Any = ... + data: Any = ... + pdata: Any = ... + spills: int = ... + def __init__( + self, + aggregator: Any, + memory_limit: int = ..., + serializer: Optional[Any] = ..., + localdirs: Optional[Any] = ..., + scale: int = ..., + partitions: int = ..., + batch: int = ..., + ) -> None: ... + def mergeValues(self, iterator: Any) -> None: ... + def mergeCombiners(self, iterator: Any, limit: Optional[Any] = ...) -> None: ... + def items(self): ... + +class ExternalSorter: + memory_limit: Any = ... + local_dirs: Any = ... + serializer: Any = ... + def __init__(self, memory_limit: Any, serializer: Optional[Any] = ...) -> None: ... + def sorted(self, iterator: Any, key: Optional[Any] = ..., reverse: bool = ...): ... + +class ExternalList: + LIMIT: int = ... + values: Any = ... + count: Any = ... + def __init__(self, values: Any) -> None: ... + def __iter__(self) -> Any: ... + def __len__(self): ... + def append(self, value: Any) -> None: ... + def __del__(self) -> None: ... + +class ExternalListOfList(ExternalList): + count: Any = ... + def __init__(self, values: Any) -> None: ... + def append(self, value: Any) -> None: ... + def __iter__(self) -> Any: ... + +class GroupByKey: + iterator: Any = ... + def __init__(self, iterator: Any) -> None: ... + def __iter__(self) -> Any: ... + +class ExternalGroupBy(ExternalMerger): + SORT_KEY_LIMIT: int = ... + def flattened_serializer(self): ... diff --git a/python/pyspark/sql/__init__.pyi b/python/pyspark/sql/__init__.pyi new file mode 100644 index 0000000000000..787be5647772e --- /dev/null +++ b/python/pyspark/sql/__init__.pyi @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +from pyspark.sql.catalog import Catalog as Catalog # noqa: F401 +from pyspark.sql.column import Column as Column # noqa: F401 +from pyspark.sql.context import ( # noqa: F401 + HiveContext as HiveContext, + SQLContext as SQLContext, + UDFRegistration as UDFRegistration, +) +from pyspark.sql.dataframe import ( # noqa: F401 + DataFrame as DataFrame, + DataFrameNaFunctions as DataFrameNaFunctions, + DataFrameStatFunctions as DataFrameStatFunctions, +) +from pyspark.sql.group import GroupedData as GroupedData # noqa: F401 +from pyspark.sql.pandas.group_ops import ( # noqa: F401 + PandasCogroupedOps as PandasCogroupedOps, +) +from pyspark.sql.readwriter import ( # noqa: F401 + DataFrameReader as DataFrameReader, + DataFrameWriter as DataFrameWriter, +) +from pyspark.sql.session import SparkSession as SparkSession # noqa: F401 +from pyspark.sql.types import Row as Row # noqa: F401 +from pyspark.sql.window import Window as Window, WindowSpec as WindowSpec # noqa: F401 diff --git a/python/pyspark/sql/_typing.pyi b/python/pyspark/sql/_typing.pyi new file mode 100644 index 0000000000000..799a73204a639 --- /dev/null +++ b/python/pyspark/sql/_typing.pyi @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import ( + Any, + List, + Optional, + Tuple, + TypeVar, + Union, +) +from typing_extensions import Protocol + +import datetime +import decimal + +from pyspark._typing import PrimitiveType +import pyspark.sql.column +import pyspark.sql.types +from pyspark.sql.column import Column + +ColumnOrName = Union[pyspark.sql.column.Column, str] +DecimalLiteral = decimal.Decimal +DateTimeLiteral = Union[datetime.datetime, datetime.date] +LiteralType = PrimitiveType +AtomicDataTypeOrString = Union[pyspark.sql.types.AtomicType, str] +DataTypeOrString = Union[pyspark.sql.types.DataType, str] +OptionalPrimitiveType = Optional[PrimitiveType] + +RowLike = TypeVar("RowLike", List[Any], Tuple[Any, ...], pyspark.sql.types.Row) + +class SupportsOpen(Protocol): + def open(self, partition_id: int, epoch_id: int) -> bool: ... + +class SupportsProcess(Protocol): + def process(self, row: pyspark.sql.types.Row) -> None: ... + +class SupportsClose(Protocol): + def close(self, error: Exception) -> None: ... + +class UserDefinedFunctionLike(Protocol): + def __call__(self, *_: ColumnOrName) -> Column: ... diff --git a/python/pyspark/sql/avro/__init__.pyi b/python/pyspark/sql/avro/__init__.pyi new file mode 100644 index 0000000000000..0d7871da4c100 --- /dev/null +++ b/python/pyspark/sql/avro/__init__.pyi @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE: This dynamically typed stub was automatically generated by stubgen. + +# Names in __all__ with no definition: +# functions diff --git a/python/pyspark/sql/avro/functions.pyi b/python/pyspark/sql/avro/functions.pyi new file mode 100644 index 0000000000000..4c2e3814a9e94 --- /dev/null +++ b/python/pyspark/sql/avro/functions.pyi @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Dict + +from pyspark.sql._typing import ColumnOrName +from pyspark.sql.column import Column + +def from_avro( + data: ColumnOrName, jsonFormatSchema: str, options: Dict[str, str] = ... +) -> Column: ... +def to_avro(data: ColumnOrName, jsonFormatSchema: str = ...) -> Column: ... diff --git a/python/pyspark/sql/catalog.pyi b/python/pyspark/sql/catalog.pyi new file mode 100644 index 0000000000000..86263fff63ce8 --- /dev/null +++ b/python/pyspark/sql/catalog.pyi @@ -0,0 +1,63 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Any, Callable, List, Optional +from pyspark.sql.dataframe import DataFrame +from pyspark.sql.session import SparkSession +from pyspark.sql.types import DataType, StructType +from collections import namedtuple + +Database = namedtuple("Database", "name description locationUri") + +Table = namedtuple("Table", "name database description tableType isTemporary") + +Column = namedtuple("Column", "name description dataType nullable isPartition isBucket") + +Function = namedtuple("Function", "name description className isTemporary") + +class Catalog: + def __init__(self, sparkSession: SparkSession) -> None: ... + def currentDatabase(self) -> str: ... + def setCurrentDatabase(self, dbName: str) -> None: ... + def listDatabases(self) -> List[Database]: ... + def listTables(self, dbName: Optional[str] = ...) -> List[Table]: ... + def listFunctions(self, dbName: Optional[str] = ...) -> List[Function]: ... + def listColumns( + self, tableName: str, dbName: Optional[str] = ... + ) -> List[Column]: ... + def createTable( + self, + tableName: str, + path: Optional[str] = ..., + source: Optional[str] = ..., + schema: Optional[StructType] = ..., + description: Optional[str] = ..., + **options: str + ) -> DataFrame: ... + def dropTempView(self, viewName: str) -> None: ... + def dropGlobalTempView(self, viewName: str) -> None: ... + def registerFunction( + self, name: str, f: Callable[..., Any], returnType: DataType = ... + ) -> None: ... + def isCached(self, tableName: str) -> bool: ... + def cacheTable(self, tableName: str) -> None: ... + def uncacheTable(self, tableName: str) -> None: ... + def clearCache(self) -> None: ... + def refreshTable(self, tableName: str) -> None: ... + def recoverPartitions(self, tableName: str) -> None: ... + def refreshByPath(self, path: str) -> None: ... diff --git a/python/pyspark/sql/column.pyi b/python/pyspark/sql/column.pyi new file mode 100644 index 0000000000000..261fb6e5f3911 --- /dev/null +++ b/python/pyspark/sql/column.pyi @@ -0,0 +1,112 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import overload +from typing import Any, Union + +from pyspark.sql._typing import LiteralType, DecimalLiteral, DateTimeLiteral +from pyspark.sql.types import ( # noqa: F401 + DataType, + StructField, + StructType, + IntegerType, + StringType, +) +from pyspark.sql.window import WindowSpec + +from py4j.java_gateway import JavaObject # type: ignore[import] + +class Column: + def __init__(self, JavaObject) -> None: ... + def __neg__(self) -> Column: ... + def __add__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ... + def __sub__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ... + def __mul__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ... 
+ def __div__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ... + def __truediv__( + self, other: Union[Column, LiteralType, DecimalLiteral] + ) -> Column: ... + def __mod__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ... + def __radd__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ... + def __rsub__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ... + def __rmul__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ... + def __rdiv__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ... + def __rtruediv__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ... + def __rmod__(self, other: Union[bool, int, float, DecimalLiteral]) -> Column: ... + def __pow__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ... + def __rpow__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ... + def __eq__(self, other: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral]) -> Column: ... # type: ignore[override] + def __ne__(self, other: Any) -> Column: ... # type: ignore[override] + def __lt__( + self, other: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral] + ) -> Column: ... + def __le__( + self, other: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral] + ) -> Column: ... + def __ge__( + self, other: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral] + ) -> Column: ... + def __gt__( + self, other: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral] + ) -> Column: ... + def eqNullSafe( + self, other: Union[Column, LiteralType, DecimalLiteral] + ) -> Column: ... + def __and__(self, other: Column) -> Column: ... + def __or__(self, other: Column) -> Column: ... + def __invert__(self) -> Column: ... + def __rand__(self, other: Column) -> Column: ... + def __ror__(self, other: Column) -> Column: ... + def __contains__(self, other: Any) -> Column: ... + def __getitem__(self, other: Any) -> Column: ... + def bitwiseOR(self, other: Union[Column, int]) -> Column: ... + def bitwiseAND(self, other: Union[Column, int]) -> Column: ... + def bitwiseXOR(self, other: Union[Column, int]) -> Column: ... + def getItem(self, key: Any) -> Column: ... + def getField(self, name: Any) -> Column: ... + def withField(self, fieldName: str, col: Column) -> Column: ... + def __getattr__(self, item: Any) -> Column: ... + def __iter__(self) -> None: ... + def rlike(self, item: str) -> Column: ... + def like(self, item: str) -> Column: ... + def startswith(self, item: Union[str, Column]) -> Column: ... + def endswith(self, item: Union[str, Column]) -> Column: ... + @overload + def substr(self, startPos: int, length: int) -> Column: ... + @overload + def substr(self, startPos: Column, length: Column) -> Column: ... + def __getslice__(self, startPos: int, length: int) -> Column: ... + def isin(self, *cols: Any) -> Column: ... + def asc(self) -> Column: ... + def asc_nulls_first(self) -> Column: ... + def asc_nulls_last(self) -> Column: ... + def desc(self) -> Column: ... + def desc_nulls_first(self) -> Column: ... + def desc_nulls_last(self) -> Column: ... + def isNull(self) -> Column: ... + def isNotNull(self) -> Column: ... + def alias(self, *alias: str, **kwargs: Any) -> Column: ... + def name(self, *alias: str) -> Column: ... + def cast(self, dataType: Union[DataType, str]) -> Column: ... + def astype(self, dataType: Union[DataType, str]) -> Column: ... + def between(self, lowerBound, upperBound) -> Column: ... 
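+    # Illustrative note: with the operator overloads above, arithmetic and
+    # comparisons on a Column are themselves typed as Column, so mypy can
+    # check expressions such as the following (hypothetical column names):
+    #
+    #     from pyspark.sql.functions import col
+    #     c = (col("price") * 1.1 + 5) > col("threshold")
+    #     reveal_type(c)  # Revealed type is 'pyspark.sql.column.Column'
+    #
+    # Plain Python literals (1.1, 5) are accepted via LiteralType/DecimalLiteral.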
+ def when(self, condition: Column, value: Any) -> Column: ... + def otherwise(self, value: Any) -> Column: ... + def over(self, window: WindowSpec) -> Column: ... + def __nonzero__(self) -> None: ... + def __bool__(self) -> None: ... diff --git a/python/pyspark/sql/conf.pyi b/python/pyspark/sql/conf.pyi new file mode 100644 index 0000000000000..304dfcb3f9e53 --- /dev/null +++ b/python/pyspark/sql/conf.pyi @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional +from py4j.java_gateway import JavaObject # type: ignore[import] + +class RuntimeConfig: + def __init__(self, jconf: JavaObject) -> None: ... + def set(self, key: str, value: str) -> str: ... + def get(self, key: str, default: Optional[str] = ...) -> str: ... + def unset(self, key: str) -> None: ... + def isModifiable(self, key: str) -> bool: ... diff --git a/python/pyspark/sql/context.pyi b/python/pyspark/sql/context.pyi new file mode 100644 index 0000000000000..64927b37ac2a9 --- /dev/null +++ b/python/pyspark/sql/context.pyi @@ -0,0 +1,139 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import overload +from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union + +from py4j.java_gateway import JavaObject # type: ignore[import] + +from pyspark.sql._typing import ( + DateTimeLiteral, + LiteralType, + DecimalLiteral, + RowLike, +) +from pyspark.sql.pandas._typing import DataFrameLike +from pyspark.context import SparkContext +from pyspark.rdd import RDD +from pyspark.sql.dataframe import DataFrame +from pyspark.sql.session import SparkSession +from pyspark.sql.types import AtomicType, DataType, StructType +from pyspark.sql.udf import UDFRegistration as UDFRegistration +from pyspark.sql.readwriter import DataFrameReader +from pyspark.sql.streaming import DataStreamReader, StreamingQueryManager + +T = TypeVar("T") + +class SQLContext: + sparkSession: SparkSession + def __init__( + self, + sparkContext, + sparkSession: Optional[SparkSession] = ..., + jsqlContext: Optional[JavaObject] = ..., + ) -> None: ... + @classmethod + def getOrCreate(cls: type, sc: SparkContext) -> SQLContext: ... + def newSession(self) -> SQLContext: ... + def setConf(self, key: str, value) -> None: ... + def getConf(self, key: str, defaultValue: Optional[str] = ...) -> str: ... + @property + def udf(self) -> UDFRegistration: ... + def range( + self, + start: int, + end: Optional[int] = ..., + step: int = ..., + numPartitions: Optional[int] = ..., + ) -> DataFrame: ... + def registerFunction( + self, name: str, f: Callable[..., Any], returnType: DataType = ... + ) -> None: ... + def registerJavaFunction( + self, name: str, javaClassName: str, returnType: Optional[DataType] = ... + ) -> None: ... + @overload + def createDataFrame( + self, + data: Union[RDD[RowLike], Iterable[RowLike]], + samplingRatio: Optional[float] = ..., + ) -> DataFrame: ... + @overload + def createDataFrame( + self, + data: Union[RDD[RowLike], Iterable[RowLike]], + schema: Union[List[str], Tuple[str, ...]] = ..., + verifySchema: bool = ..., + ) -> DataFrame: ... + @overload + def createDataFrame( + self, + data: Union[ + RDD[Union[DateTimeLiteral, LiteralType, DecimalLiteral]], + Iterable[Union[DateTimeLiteral, LiteralType, DecimalLiteral]], + ], + schema: Union[AtomicType, str], + verifySchema: bool = ..., + ) -> DataFrame: ... + @overload + def createDataFrame( + self, + data: Union[RDD[RowLike], Iterable[RowLike]], + schema: Union[StructType, str], + verifySchema: bool = ..., + ) -> DataFrame: ... + @overload + def createDataFrame( + self, data: DataFrameLike, samplingRatio: Optional[float] = ... + ) -> DataFrame: ... + @overload + def createDataFrame( + self, + data: DataFrameLike, + schema: Union[StructType, str], + verifySchema: bool = ..., + ) -> DataFrame: ... + def registerDataFrameAsTable(self, df: DataFrame, tableName: str) -> None: ... + def dropTempTable(self, tableName: str) -> None: ... + def createExternalTable( + self, + tableName: str, + path: Optional[str] = ..., + source: Optional[str] = ..., + schema: Optional[StructType] = ..., + **options + ) -> DataFrame: ... + def sql(self, sqlQuery: str) -> DataFrame: ... + def table(self, tableName: str) -> DataFrame: ... + def tables(self, dbName: Optional[str] = ...) -> DataFrame: ... + def tableNames(self, dbName: Optional[str] = ...) -> List[str]: ... + def cacheTable(self, tableName: str) -> None: ... + def uncacheTable(self, tableName: str) -> None: ... + def clearCache(self) -> None: ... + @property + def read(self) -> DataFrameReader: ... + @property + def readStream(self) -> DataStreamReader: ... 
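+    # Illustrative sketch of how the createDataFrame overloads above resolve;
+    # sqlContext, sc and the pandas DataFrame pdf are assumed to exist:
+    #
+    #     sqlContext.createDataFrame([(1, "a"), (2, "b")], ["id", "label"])  # rows + column names
+    #     sqlContext.createDataFrame(sc.parallelize([1, 2, 3]), "int")       # atomic values + DDL string
+    #     sqlContext.createDataFrame(pdf)                                    # pandas DataFrame
+    #
+    # Each call is inferred as returning DataFrame; mypy selects the overload
+    # from the types of the data and schema arguments.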
+ @property + def streams(self) -> StreamingQueryManager: ... + +class HiveContext(SQLContext): + def __init__( + self, sparkContext: SparkContext, jhiveContext: Optional[JavaObject] = ... + ) -> None: ... + def refreshTable(self, tableName: str) -> None: ... diff --git a/python/pyspark/sql/dataframe.pyi b/python/pyspark/sql/dataframe.pyi new file mode 100644 index 0000000000000..c498d529d820f --- /dev/null +++ b/python/pyspark/sql/dataframe.pyi @@ -0,0 +1,324 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import overload +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Optional, + Tuple, + Union, +) + +from py4j.java_gateway import JavaObject # type: ignore[import] + +from pyspark.sql._typing import ColumnOrName, LiteralType, OptionalPrimitiveType +from pyspark.sql.types import ( # noqa: F401 + StructType, + StructField, + StringType, + IntegerType, + Row, +) # noqa: F401 +from pyspark.sql.context import SQLContext +from pyspark.sql.group import GroupedData +from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2 +from pyspark.sql.streaming import DataStreamWriter +from pyspark.sql.column import Column +from pyspark.rdd import RDD +from pyspark.storagelevel import StorageLevel + +from pyspark.sql.pandas.conversion import PandasConversionMixin +from pyspark.sql.pandas.map_ops import PandasMapOpsMixin + +class DataFrame(PandasMapOpsMixin, PandasConversionMixin): + sql_ctx: SQLContext + is_cached: bool + def __init__(self, jdf: JavaObject, sql_ctx: SQLContext) -> None: ... + @property + def rdd(self) -> RDD[Row]: ... + @property + def na(self) -> DataFrameNaFunctions: ... + @property + def stat(self) -> DataFrameStatFunctions: ... + def toJSON(self, use_unicode: bool = ...) -> RDD[str]: ... + def registerTempTable(self, name: str) -> None: ... + def createTempView(self, name: str) -> None: ... + def createOrReplaceTempView(self, name: str) -> None: ... + def createGlobalTempView(self, name: str) -> None: ... + @property + def write(self) -> DataFrameWriter: ... + @property + def writeStream(self) -> DataStreamWriter: ... + @property + def schema(self) -> StructType: ... + def printSchema(self) -> None: ... + def explain( + self, extended: Optional[Union[bool, str]] = ..., mode: Optional[str] = ... + ) -> None: ... + def exceptAll(self, other: DataFrame) -> DataFrame: ... + def isLocal(self) -> bool: ... + @property + def isStreaming(self) -> bool: ... + def show( + self, n: int = ..., truncate: Union[bool, int] = ..., vertical: bool = ... + ) -> None: ... + def checkpoint(self, eager: bool = ...) -> DataFrame: ... + def localCheckpoint(self, eager: bool = ...) -> DataFrame: ... + def withWatermark( + self, eventTime: ColumnOrName, delayThreshold: str + ) -> DataFrame: ... 
+ def hint(self, name: str, *parameters: Any) -> DataFrame: ... + def count(self) -> int: ... + def collect(self) -> List[Row]: ... + def toLocalIterator(self, prefetchPartitions: bool = ...) -> Iterator[Row]: ... + def limit(self, num: int) -> DataFrame: ... + def take(self, num: int) -> List[Row]: ... + def tail(self, num: int) -> List[Row]: ... + def foreach(self, f: Callable[[Row], None]) -> None: ... + def foreachPartition(self, f: Callable[[Iterator[Row]], None]) -> None: ... + def cache(self) -> DataFrame: ... + def persist(self, storageLevel: StorageLevel = ...) -> DataFrame: ... + @property + def storageLevel(self) -> StorageLevel: ... + def unpersist(self, blocking: bool = ...) -> DataFrame: ... + def coalesce(self, numPartitions: int) -> DataFrame: ... + @overload + def repartition(self, numPartitions: int, *cols: ColumnOrName) -> DataFrame: ... + @overload + def repartition(self, *cols: ColumnOrName) -> DataFrame: ... + @overload + def repartitionByRange( + self, numPartitions: int, *cols: ColumnOrName + ) -> DataFrame: ... + @overload + def repartitionByRange(self, *cols: ColumnOrName) -> DataFrame: ... + def distinct(self) -> DataFrame: ... + @overload + def sample(self, fraction: float, seed: Optional[int] = ...) -> DataFrame: ... + @overload + def sample( + self, + withReplacement: Optional[bool], + fraction: float, + seed: Optional[int] = ..., + ) -> DataFrame: ... + def sampleBy( + self, col: str, fractions: Dict[Any, float], seed: Optional[int] = ... + ) -> DataFrame: ... + def randomSplit( + self, weights: List[float], seed: Optional[int] = ... + ) -> List[DataFrame]: ... + @property + def dtypes(self) -> List[Tuple[str, str]]: ... + @property + def columns(self) -> List[str]: ... + def colRegex(self, colName: str) -> Column: ... + def alias(self, alias: str) -> DataFrame: ... + def crossJoin(self, other: DataFrame) -> DataFrame: ... + def join( + self, + other: DataFrame, + on: Optional[Union[str, List[str], Column, List[Column]]] = ..., + how: Optional[str] = ..., + ) -> DataFrame: ... + def sortWithinPartitions( + self, + *cols: Union[str, Column, List[Union[str, Column]]], + ascending: Union[bool, List[bool]] = ... + ) -> DataFrame: ... + def sort( + self, + *cols: Union[str, Column, List[Union[str, Column]]], + ascending: Union[bool, List[bool]] = ... + ) -> DataFrame: ... + def orderBy( + self, + *cols: Union[str, Column, List[Union[str, Column]]], + ascending: Union[bool, List[bool]] = ... + ) -> DataFrame: ... + def describe(self, *cols: Union[str, List[str]]) -> DataFrame: ... + def summary(self, *statistics: str) -> DataFrame: ... + @overload + def head(self) -> Row: ... + @overload + def head(self, n: int) -> List[Row]: ... + def first(self) -> Row: ... + def __getitem__(self, item: Union[int, str, Column, List, Tuple]) -> Column: ... + def __getattr__(self, name: str) -> Column: ... + @overload + def select(self, *cols: ColumnOrName) -> DataFrame: ... + @overload + def select(self, __cols: Union[List[Column], List[str]]) -> DataFrame: ... + @overload + def selectExpr(self, *expr: str) -> DataFrame: ... + @overload + def selectExpr(self, *expr: List[str]) -> DataFrame: ... + def filter(self, condition: ColumnOrName) -> DataFrame: ... + @overload + def groupBy(self, *cols: ColumnOrName) -> GroupedData: ... + @overload + def groupBy(self, __cols: Union[List[Column], List[str]]) -> GroupedData: ... + @overload + def rollup(self, *cols: ColumnOrName) -> GroupedData: ... + @overload + def rollup(self, __cols: Union[List[Column], List[str]]) -> GroupedData: ... 
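+    # Illustrative example (hypothetical df, other_df and column names): since
+    # the transformation methods above all return DataFrame, chained calls
+    # remain fully typed under mypy:
+    #
+    #     result = (
+    #         df.filter(df.age > 21)
+    #           .join(other_df, on="id", how="left")
+    #           .select("id", "name")
+    #           .orderBy("name", ascending=True)
+    #     )
+    #     # inferred as DataFrame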
+ @overload + def cube(self, *cols: ColumnOrName) -> GroupedData: ... + @overload + def cube(self, __cols: Union[List[Column], List[str]]) -> GroupedData: ... + def agg(self, *exprs: Union[Column, Dict[str, str]]) -> DataFrame: ... + def union(self, other: DataFrame) -> DataFrame: ... + def unionAll(self, other: DataFrame) -> DataFrame: ... + def unionByName( + self, other: DataFrame, allowMissingColumns: bool = ... + ) -> DataFrame: ... + def intersect(self, other: DataFrame) -> DataFrame: ... + def intersectAll(self, other: DataFrame) -> DataFrame: ... + def subtract(self, other: DataFrame) -> DataFrame: ... + def dropDuplicates(self, subset: Optional[List[str]] = ...) -> DataFrame: ... + def dropna( + self, + how: str = ..., + thresh: Optional[int] = ..., + subset: Optional[List[str]] = ..., + ) -> DataFrame: ... + @overload + def fillna( + self, + value: LiteralType, + subset: Optional[Union[str, Tuple[str, ...], List[str]]] = ..., + ) -> DataFrame: ... + @overload + def fillna(self, value: Dict[str, LiteralType]) -> DataFrame: ... + @overload + def replace( + self, + to_replace: LiteralType, + value: OptionalPrimitiveType, + subset: Optional[List[str]] = ..., + ) -> DataFrame: ... + @overload + def replace( + self, + to_replace: List[LiteralType], + value: List[OptionalPrimitiveType], + subset: Optional[List[str]] = ..., + ) -> DataFrame: ... + @overload + def replace( + self, + to_replace: Dict[LiteralType, OptionalPrimitiveType], + subset: Optional[List[str]] = ..., + ) -> DataFrame: ... + @overload + def replace( + self, + to_replace: List[LiteralType], + value: OptionalPrimitiveType, + subset: Optional[List[str]] = ..., + ) -> DataFrame: ... + def approxQuantile( + self, col: str, probabilities: List[float], relativeError: float + ) -> List[float]: ... + def corr(self, col1: str, col2: str, method: Optional[str] = ...) -> float: ... + def cov(self, col1: str, col2: str) -> float: ... + def crosstab(self, col1: str, col2: str) -> DataFrame: ... + def freqItems( + self, cols: List[str], support: Optional[float] = ... + ) -> DataFrame: ... + def withColumn(self, colName: str, col: Column) -> DataFrame: ... + def withColumnRenamed(self, existing: str, new: str) -> DataFrame: ... + @overload + def drop(self, cols: ColumnOrName) -> DataFrame: ... + @overload + def drop(self, *cols: str) -> DataFrame: ... + def toDF(self, *cols: ColumnOrName) -> DataFrame: ... + def transform(self, func: Callable[[DataFrame], DataFrame]) -> DataFrame: ... + @overload + def groupby(self, *cols: ColumnOrName) -> GroupedData: ... + @overload + def groupby(self, __cols: Union[List[Column], List[str]]) -> GroupedData: ... + def drop_duplicates(self, subset: Optional[List[str]] = ...) -> DataFrame: ... + def where(self, condition: ColumnOrName) -> DataFrame: ... + def sameSemantics(self, other: DataFrame) -> bool: ... + def semanticHash(self) -> int: ... + def inputFiles(self) -> List[str]: ... + def writeTo(self, table: str) -> DataFrameWriterV2: ... + +class DataFrameNaFunctions: + df: DataFrame + def __init__(self, df: DataFrame) -> None: ... + def drop( + self, + how: str = ..., + thresh: Optional[int] = ..., + subset: Optional[List[str]] = ..., + ) -> DataFrame: ... + @overload + def fill( + self, value: LiteralType, subset: Optional[List[str]] = ... + ) -> DataFrame: ... + @overload + def fill(self, value: Dict[str, LiteralType]) -> DataFrame: ... + @overload + def replace( + self, + to_replace: LiteralType, + value: OptionalPrimitiveType, + subset: Optional[List[str]] = ..., + ) -> DataFrame: ... 
+ @overload + def replace( + self, + to_replace: List[LiteralType], + value: List[OptionalPrimitiveType], + subset: Optional[List[str]] = ..., + ) -> DataFrame: ... + @overload + def replace( + self, + to_replace: Dict[LiteralType, OptionalPrimitiveType], + subset: Optional[List[str]] = ..., + ) -> DataFrame: ... + @overload + def replace( + self, + to_replace: List[LiteralType], + value: OptionalPrimitiveType, + subset: Optional[List[str]] = ..., + ) -> DataFrame: ... + +class DataFrameStatFunctions: + df: DataFrame + def __init__(self, df: DataFrame) -> None: ... + def approxQuantile( + self, col: str, probabilities: List[float], relativeError: float + ) -> List[float]: ... + def corr(self, col1: str, col2: str, method: Optional[str] = ...) -> float: ... + def cov(self, col1: str, col2: str) -> float: ... + def crosstab(self, col1: str, col2: str) -> DataFrame: ... + def freqItems( + self, cols: List[str], support: Optional[float] = ... + ) -> DataFrame: ... + def sampleBy( + self, col: str, fractions: Dict[Any, float], seed: Optional[int] = ... + ) -> DataFrame: ... diff --git a/python/pyspark/sql/functions.pyi b/python/pyspark/sql/functions.pyi new file mode 100644 index 0000000000000..3b0b2030178ef --- /dev/null +++ b/python/pyspark/sql/functions.pyi @@ -0,0 +1,343 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import overload +from typing import Any, Callable, Dict, List, Optional, Union + +from pyspark.sql._typing import ( + ColumnOrName, + DataTypeOrString, +) +from pyspark.sql.pandas.functions import ( # noqa: F401 + pandas_udf as pandas_udf, + PandasUDFType as PandasUDFType, +) +from pyspark.sql.column import Column +from pyspark.sql.dataframe import DataFrame +from pyspark.sql.types import ( # noqa: F401 + ArrayType, + StringType, + StructType, + DataType, +) +from pyspark.sql.utils import to_str # noqa: F401 + +def approxCountDistinct(col: ColumnOrName, rsd: Optional[float] = ...) -> Column: ... +def approx_count_distinct(col: ColumnOrName, rsd: Optional[float] = ...) -> Column: ... +def broadcast(df: DataFrame) -> DataFrame: ... +def coalesce(*cols: ColumnOrName) -> Column: ... +def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column: ... +def covar_pop(col1: ColumnOrName, col2: ColumnOrName) -> Column: ... +def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column: ... +def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column: ... +def first(col: ColumnOrName, ignorenulls: bool = ...) -> Column: ... +def grouping(col: ColumnOrName) -> Column: ... +def grouping_id(*cols: ColumnOrName) -> Column: ... +def input_file_name() -> Column: ... +def isnan(col: ColumnOrName) -> Column: ... +def isnull(col: ColumnOrName) -> Column: ... +def last(col: ColumnOrName, ignorenulls: bool = ...) -> Column: ... 
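+# Illustrative note: ColumnOrName (defined in pyspark.sql._typing) is
+# Union[Column, str], so most functions in this module accept either a Column
+# or a column name; both spellings below type-check (hypothetical column name):
+#
+#     from pyspark.sql.functions import col, countDistinct
+#     countDistinct("user_id")        # plain string
+#     countDistinct(col("user_id"))   # Column
+#
+# Either call is inferred as returning Column.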
+def monotonically_increasing_id() -> Column: ... +def nanvl(col1: ColumnOrName, col2: ColumnOrName) -> Column: ... +def percentile_approx( + col: ColumnOrName, + percentage: Union[Column, float, List[float]], + accuracy: Union[Column, float] = ..., +) -> Column: ... +def rand(seed: Optional[int] = ...) -> Column: ... +def randn(seed: Optional[int] = ...) -> Column: ... +def round(col: ColumnOrName, scale: int = ...) -> Column: ... +def bround(col: ColumnOrName, scale: int = ...) -> Column: ... +def shiftLeft(col: ColumnOrName, numBits: int) -> Column: ... +def shiftRight(col: ColumnOrName, numBits: int) -> Column: ... +def shiftRightUnsigned(col, numBits) -> Column: ... +def spark_partition_id() -> Column: ... +def expr(str: str) -> Column: ... +def struct(*cols: ColumnOrName) -> Column: ... +def greatest(*cols: ColumnOrName) -> Column: ... +def least(*cols: Column) -> Column: ... +def when(condition: Column, value) -> Column: ... +@overload +def log(arg1: ColumnOrName) -> Column: ... +@overload +def log(arg1: float, arg2: ColumnOrName) -> Column: ... +def log2(col: ColumnOrName) -> Column: ... +def conv(col: ColumnOrName, fromBase: int, toBase: int) -> Column: ... +def factorial(col: ColumnOrName) -> Column: ... +def lag( + col: ColumnOrName, offset: int = ..., default: Optional[Any] = ... +) -> Column: ... +def lead( + col: ColumnOrName, offset: int = ..., default: Optional[Any] = ... +) -> Column: ... +def ntile(n: int) -> Column: ... +def current_date() -> Column: ... +def current_timestamp() -> Column: ... +def date_format(date: ColumnOrName, format: str) -> Column: ... +def year(col: ColumnOrName) -> Column: ... +def quarter(col: ColumnOrName) -> Column: ... +def month(col: ColumnOrName) -> Column: ... +def dayofweek(col: ColumnOrName) -> Column: ... +def dayofmonth(col: ColumnOrName) -> Column: ... +def dayofyear(col: ColumnOrName) -> Column: ... +def hour(col: ColumnOrName) -> Column: ... +def minute(col: ColumnOrName) -> Column: ... +def second(col: ColumnOrName) -> Column: ... +def weekofyear(col: ColumnOrName) -> Column: ... +def date_add(start: ColumnOrName, days: int) -> Column: ... +def date_sub(start: ColumnOrName, days: int) -> Column: ... +def datediff(end: ColumnOrName, start: ColumnOrName) -> Column: ... +def add_months(start: ColumnOrName, months: int) -> Column: ... +def months_between( + date1: ColumnOrName, date2: ColumnOrName, roundOff: bool = ... +) -> Column: ... +def to_date(col: ColumnOrName, format: Optional[str] = ...) -> Column: ... +@overload +def to_timestamp(col: ColumnOrName) -> Column: ... +@overload +def to_timestamp(col: ColumnOrName, format: str) -> Column: ... +def trunc(date: ColumnOrName, format: str) -> Column: ... +def date_trunc(format: str, timestamp: ColumnOrName) -> Column: ... +def next_day(date: ColumnOrName, dayOfWeek: str) -> Column: ... +def last_day(date: ColumnOrName) -> Column: ... +def from_unixtime(timestamp: ColumnOrName, format: str = ...) -> Column: ... +def unix_timestamp( + timestamp: Optional[ColumnOrName] = ..., format: str = ... +) -> Column: ... +def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column: ... +def to_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column: ... +def timestamp_seconds(col: ColumnOrName) -> Column: ... +def window( + timeColumn: ColumnOrName, + windowDuration: str, + slideDuration: Optional[str] = ..., + startTime: Optional[str] = ..., +) -> Column: ... +def crc32(col: ColumnOrName) -> Column: ... +def md5(col: ColumnOrName) -> Column: ... 
+def sha1(col: ColumnOrName) -> Column: ... +def sha2(col: ColumnOrName, numBits: int) -> Column: ... +def hash(*cols: ColumnOrName) -> Column: ... +def xxhash64(*cols: ColumnOrName) -> Column: ... +def concat(*cols: ColumnOrName) -> Column: ... +def concat_ws(sep: str, *cols: ColumnOrName) -> Column: ... +def decode(col: ColumnOrName, charset: str) -> Column: ... +def encode(col: ColumnOrName, charset: str) -> Column: ... +def format_number(col: ColumnOrName, d: int) -> Column: ... +def format_string(format: str, *cols: ColumnOrName) -> Column: ... +def instr(str: ColumnOrName, substr: str) -> Column: ... +def overlay( + src: ColumnOrName, + replace: ColumnOrName, + pos: Union[Column, int], + len: Union[Column, int] = ..., +) -> Column: ... +def substring(str: ColumnOrName, pos: int, len: int) -> Column: ... +def substring_index(str: ColumnOrName, delim: str, count: int) -> Column: ... +def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column: ... +def locate(substr: str, str: Column, pos: int = ...) -> Column: ... +def lpad(col: Column, len: int, pad: str) -> Column: ... +def rpad(col: Column, len: int, pad: str) -> Column: ... +def repeat(col: Column, n: int) -> Column: ... +def split(str: Column, pattern: str, limit: int = ...) -> Column: ... +def regexp_extract(str: ColumnOrName, pattern: str, idx: int) -> Column: ... +def regexp_replace(str: ColumnOrName, pattern: str, replacement: str) -> Column: ... +def initcap(col: ColumnOrName) -> Column: ... +def soundex(col: ColumnOrName) -> Column: ... +def bin(col: ColumnOrName) -> Column: ... +def hex(col: ColumnOrName) -> Column: ... +def unhex(col: ColumnOrName) -> Column: ... +def length(col: ColumnOrName) -> Column: ... +def translate(srcCol: ColumnOrName, matching: str, replace: str) -> Column: ... +def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column: ... +def create_map(*cols: ColumnOrName) -> Column: ... +def array(*cols: ColumnOrName) -> Column: ... +def array_contains(col: ColumnOrName, value: Any) -> Column: ... +def arrays_overlap(a1: ColumnOrName, a2: ColumnOrName) -> Column: ... +def slice(x: ColumnOrName, start: int, length: int) -> Column: ... +def array_join( + col: ColumnOrName, delimiter: str, null_replacement: Optional[str] = ... +) -> Column: ... +def array_position(col: ColumnOrName, value: Any) -> Column: ... +def element_at(col: ColumnOrName, extraction: Any) -> Column: ... +def array_remove(col: ColumnOrName, element: Any) -> Column: ... +def array_distinct(col: ColumnOrName) -> Column: ... +def array_intersect(col1: ColumnOrName, col2: ColumnOrName) -> Column: ... +def array_union(col1: ColumnOrName, col2: ColumnOrName) -> Column: ... +def array_except(col1: ColumnOrName, col2: ColumnOrName) -> Column: ... +def explode(col: ColumnOrName) -> Column: ... +def explode_outer(col: ColumnOrName) -> Column: ... +def posexplode(col: ColumnOrName) -> Column: ... +def posexplode_outer(col: ColumnOrName) -> Column: ... +def get_json_object(col: ColumnOrName, path: str) -> Column: ... +def json_tuple(col: ColumnOrName, *fields: str) -> Column: ... +def from_json( + col: ColumnOrName, + schema: Union[ArrayType, StructType, Column, str], + options: Dict[str, str] = ..., +) -> Column: ... +def to_json(col: ColumnOrName, options: Dict[str, str] = ...) -> Column: ... +def schema_of_json(json: ColumnOrName, options: Dict[str, str] = ...) -> Column: ... +def schema_of_csv(csv: ColumnOrName, options: Dict[str, str] = ...) -> Column: ... +def to_csv(col: ColumnOrName, options: Dict[str, str] = ...) 
-> Column: ... +def size(col: ColumnOrName) -> Column: ... +def array_min(col: ColumnOrName) -> Column: ... +def array_max(col: ColumnOrName) -> Column: ... +def sort_array(col: ColumnOrName, asc: bool = ...) -> Column: ... +def array_sort(col: ColumnOrName) -> Column: ... +def shuffle(col: ColumnOrName) -> Column: ... +def reverse(col: ColumnOrName) -> Column: ... +def flatten(col: ColumnOrName) -> Column: ... +def map_keys(col: ColumnOrName) -> Column: ... +def map_values(col: ColumnOrName) -> Column: ... +def map_entries(col: ColumnOrName) -> Column: ... +def map_from_entries(col: ColumnOrName) -> Column: ... +def array_repeat(col: ColumnOrName, count: Union[Column, int]) -> Column: ... +def arrays_zip(*cols: ColumnOrName) -> Column: ... +def map_concat(*cols: ColumnOrName) -> Column: ... +def sequence( + start: ColumnOrName, stop: ColumnOrName, step: Optional[ColumnOrName] = ... +) -> Column: ... +def from_csv( + col: ColumnOrName, + schema: Union[StructType, Column, str], + options: Dict[str, str] = ..., +) -> Column: ... +@overload +def transform(col: ColumnOrName, f: Callable[[Column], Column]) -> Column: ... +@overload +def transform(col: ColumnOrName, f: Callable[[Column, Column], Column]) -> Column: ... +def exists(col: ColumnOrName, f: Callable[[Column], Column]) -> Column: ... +def forall(col: ColumnOrName, f: Callable[[Column], Column]) -> Column: ... +@overload +def filter(col: ColumnOrName, f: Callable[[Column], Column]) -> Column: ... +@overload +def filter(col: ColumnOrName, f: Callable[[Column, Column], Column]) -> Column: ... +def aggregate( + col: ColumnOrName, + zero: ColumnOrName, + merge: Callable[[Column, Column], Column], + finish: Optional[Callable[[Column], Column]] = ..., +) -> Column: ... +def zip_with( + col1: ColumnOrName, + ColumnOrName: ColumnOrName, + f: Callable[[Column, Column], Column], +) -> Column: ... +def transform_keys( + col: ColumnOrName, f: Callable[[Column, Column], Column] +) -> Column: ... +def transform_values( + col: ColumnOrName, f: Callable[[Column, Column], Column] +) -> Column: ... +def map_filter(col: ColumnOrName, f: Callable[[Column, Column], Column]) -> Column: ... +def map_zip_with( + col1: ColumnOrName, + col2: ColumnOrName, + f: Callable[[Column, Column, Column], Column], +) -> Column: ... +def abs(col: ColumnOrName) -> Column: ... +def acos(col: ColumnOrName) -> Column: ... +def asc(col: ColumnOrName) -> Column: ... +def asc_nulls_first(col: ColumnOrName) -> Column: ... +def asc_nulls_last(col: ColumnOrName) -> Column: ... +def ascii(col: ColumnOrName) -> Column: ... +def asin(col: ColumnOrName) -> Column: ... +def atan(col: ColumnOrName) -> Column: ... +@overload +def atan2(col1: ColumnOrName, col2: ColumnOrName) -> Column: ... +@overload +def atan2(col1: float, col2: ColumnOrName) -> Column: ... +@overload +def atan2(col1: ColumnOrName, col2: float) -> Column: ... +def avg(col: ColumnOrName) -> Column: ... +def base64(col: ColumnOrName) -> Column: ... +def bitwiseNOT(col: ColumnOrName) -> Column: ... +def cbrt(col: ColumnOrName) -> Column: ... +def ceil(col: ColumnOrName) -> Column: ... +def col(col: str) -> Column: ... +def collect_list(col: ColumnOrName) -> Column: ... +def collect_set(col: ColumnOrName) -> Column: ... +def column(col: str) -> Column: ... +def cos(col: ColumnOrName) -> Column: ... +def cosh(col: ColumnOrName) -> Column: ... +def count(col: ColumnOrName) -> Column: ... +def cume_dist() -> Column: ... +def degrees(col: ColumnOrName) -> Column: ... +def dense_rank() -> Column: ... 
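+# Illustrative example for the higher-order functions above (transform, filter,
+# exists, forall, aggregate, ...): their callbacks are typed as callables over
+# Column, so plain lambdas type-check (hypothetical column name "values"):
+#
+#     from pyspark.sql.functions import transform, filter as array_filter
+#     doubled = transform("values", lambda x: x * 2)
+#     positive = array_filter("values", lambda x: x > 0)
+#
+# Both expressions are inferred as Column.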
+def desc(col: ColumnOrName) -> Column: ... +def desc_nulls_first(col: ColumnOrName) -> Column: ... +def desc_nulls_last(col: ColumnOrName) -> Column: ... +def exp(col: ColumnOrName) -> Column: ... +def expm1(col: ColumnOrName) -> Column: ... +def floor(col: ColumnOrName) -> Column: ... +@overload +def hypot(col1: ColumnOrName, col2: ColumnOrName) -> Column: ... +@overload +def hypot(col1: float, col2: ColumnOrName) -> Column: ... +@overload +def hypot(col1: ColumnOrName, col2: float) -> Column: ... +def kurtosis(col: ColumnOrName) -> Column: ... +def lit(col: Any) -> Column: ... +def log10(col: ColumnOrName) -> Column: ... +def log1p(col: ColumnOrName) -> Column: ... +def lower(col: ColumnOrName) -> Column: ... +def ltrim(col: ColumnOrName) -> Column: ... +def max(col: ColumnOrName) -> Column: ... +def mean(col: ColumnOrName) -> Column: ... +def min(col: ColumnOrName) -> Column: ... +def percent_rank() -> Column: ... +@overload +def pow(col1: ColumnOrName, col2: ColumnOrName) -> Column: ... +@overload +def pow(col1: float, col2: ColumnOrName) -> Column: ... +@overload +def pow(col1: ColumnOrName, col2: float) -> Column: ... +def radians(col: ColumnOrName) -> Column: ... +def rank() -> Column: ... +def rint(col: ColumnOrName) -> Column: ... +def row_number() -> Column: ... +def rtrim(col: ColumnOrName) -> Column: ... +def signum(col: ColumnOrName) -> Column: ... +def sin(col: ColumnOrName) -> Column: ... +def sinh(col: ColumnOrName) -> Column: ... +def skewness(col: ColumnOrName) -> Column: ... +def sqrt(col: ColumnOrName) -> Column: ... +def stddev(col: ColumnOrName) -> Column: ... +def stddev_pop(col: ColumnOrName) -> Column: ... +def stddev_samp(col: ColumnOrName) -> Column: ... +def sum(col: ColumnOrName) -> Column: ... +def sumDistinct(col: ColumnOrName) -> Column: ... +def tan(col: ColumnOrName) -> Column: ... +def tanh(col: ColumnOrName) -> Column: ... +def toDegrees(col: ColumnOrName) -> Column: ... +def toRadians(col: ColumnOrName) -> Column: ... +def trim(col: ColumnOrName) -> Column: ... +def unbase64(col: ColumnOrName) -> Column: ... +def upper(col: ColumnOrName) -> Column: ... +def var_pop(col: ColumnOrName) -> Column: ... +def var_samp(col: ColumnOrName) -> Column: ... +def variance(col: ColumnOrName) -> Column: ... +@overload +def udf( + f: Callable[..., Any], returnType: DataTypeOrString = ... +) -> Callable[..., Column]: ... +@overload +def udf( + f: DataTypeOrString = ..., +) -> Callable[[Callable[..., Any]], Callable[..., Column]]: ... diff --git a/python/pyspark/sql/group.pyi b/python/pyspark/sql/group.pyi new file mode 100644 index 0000000000000..0b0df8c63cfdd --- /dev/null +++ b/python/pyspark/sql/group.pyi @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import overload +from typing import Dict, List, Optional + +from pyspark.sql._typing import LiteralType +from pyspark.sql.context import SQLContext +from pyspark.sql.column import Column +from pyspark.sql.dataframe import DataFrame +from pyspark.sql.pandas.group_ops import PandasGroupedOpsMixin +from py4j.java_gateway import JavaObject # type: ignore[import] + +class GroupedData(PandasGroupedOpsMixin): + sql_ctx: SQLContext + def __init__(self, jgd: JavaObject, df: DataFrame) -> None: ... + @overload + def agg(self, *exprs: Column) -> DataFrame: ... + @overload + def agg(self, __exprs: Dict[str, str]) -> DataFrame: ... + def count(self) -> DataFrame: ... + def mean(self, *cols: str) -> DataFrame: ... + def avg(self, *cols: str) -> DataFrame: ... + def max(self, *cols: str) -> DataFrame: ... + def min(self, *cols: str) -> DataFrame: ... + def sum(self, *cols: str) -> DataFrame: ... + def pivot( + self, pivot_col: str, values: Optional[List[LiteralType]] = ... + ) -> GroupedData: ... diff --git a/python/pyspark/sql/pandas/__init__.pyi b/python/pyspark/sql/pandas/__init__.pyi new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/python/pyspark/sql/pandas/__init__.pyi @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi b/python/pyspark/sql/pandas/_typing/__init__.pyi new file mode 100644 index 0000000000000..dda1b3341b31c --- /dev/null +++ b/python/pyspark/sql/pandas/_typing/__init__.pyi @@ -0,0 +1,338 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import ( + Any, + Callable, + Iterable, + NewType, + Tuple, + Type, + Union, +) +from typing_extensions import Protocol, Literal +from types import FunctionType + +from pyspark.sql._typing import LiteralType +from pyspark.sql.pandas._typing.protocols.frame import DataFrameLike as DataFrameLike +from pyspark.sql.pandas._typing.protocols.series import SeriesLike as SeriesLike + +import pandas.core.frame # type: ignore[import] +import pandas.core.series # type: ignore[import] + +# POC compatibility annotations +PandasDataFrame: Type[DataFrameLike] = pandas.core.frame.DataFrame +PandasSeries: Type[SeriesLike] = pandas.core.series.Series + +DataFrameOrSeriesLike = Union[DataFrameLike, SeriesLike] + +# UDF annotations +PandasScalarUDFType = Literal[200] +PandasScalarIterUDFType = Literal[204] +PandasGroupedMapUDFType = Literal[201] +PandasCogroupedMapUDFType = Literal[206] +PandasGroupedAggUDFType = Literal[202] +PandasMapIterUDFType = Literal[205] + +class PandasVariadicScalarToScalarFunction(Protocol): + def __call__(self, *_: DataFrameOrSeriesLike) -> SeriesLike: ... + +PandasScalarToScalarFunction = Union[ + PandasVariadicScalarToScalarFunction, + Callable[[DataFrameOrSeriesLike], SeriesLike], + Callable[[DataFrameOrSeriesLike, DataFrameOrSeriesLike], SeriesLike], + Callable[ + [DataFrameOrSeriesLike, DataFrameOrSeriesLike, DataFrameOrSeriesLike], + SeriesLike, + ], + Callable[ + [ + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + ], + SeriesLike, + ], + Callable[ + [ + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + ], + SeriesLike, + ], + Callable[ + [ + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + ], + SeriesLike, + ], + Callable[ + [ + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + ], + SeriesLike, + ], + Callable[ + [ + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + ], + SeriesLike, + ], + Callable[ + [ + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + ], + SeriesLike, + ], + Callable[ + [ + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + ], + SeriesLike, + ], +] + +class PandasVariadicScalarToStructFunction(Protocol): + def __call__(self, *_: DataFrameOrSeriesLike) -> DataFrameLike: ... 
+ +PandasScalarToStructFunction = Union[ + PandasVariadicScalarToStructFunction, + Callable[[DataFrameOrSeriesLike], DataFrameLike], + Callable[[DataFrameOrSeriesLike, DataFrameOrSeriesLike], DataFrameLike], + Callable[ + [DataFrameOrSeriesLike, DataFrameOrSeriesLike, DataFrameOrSeriesLike], + DataFrameLike, + ], + Callable[ + [ + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + ], + DataFrameLike, + ], + Callable[ + [ + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + ], + DataFrameLike, + ], + Callable[ + [ + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + ], + DataFrameLike, + ], + Callable[ + [ + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + ], + DataFrameLike, + ], + Callable[ + [ + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + ], + DataFrameLike, + ], + Callable[ + [ + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + ], + DataFrameLike, + ], + Callable[ + [ + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + DataFrameOrSeriesLike, + ], + DataFrameLike, + ], +] + +PandasScalarIterFunction = Callable[ + [Iterable[Union[DataFrameOrSeriesLike, Tuple[DataFrameOrSeriesLike, ...]]]], + Iterable[SeriesLike], +] + +PandasGroupedMapFunction = Union[ + Callable[[DataFrameLike], DataFrameLike], + Callable[[Any, DataFrameLike], DataFrameLike], +] + +class PandasVariadicGroupedAggFunction(Protocol): + def __call__(self, *_: SeriesLike) -> LiteralType: ... 
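`PandasScalarIterFunction` and `PandasGroupedMapFunction` just above admit, respectively, iterator-of-batches functions and grouped-map functions with or without the grouping key. A hedged sketch of the shapes they accept (the column name `v` is an assumption for illustration):

```
# Sketch only; plain pandas functions shaped to match the aliases above.
from typing import Any, Iterator, Tuple
import pandas as pd

def plus_one_batches(batches: Iterator[pd.Series]) -> Iterator[pd.Series]:
    # PandasScalarIterFunction: iterable of batches in, iterable of Series out
    for s in batches:
        yield s + 1

def subtract_mean(pdf: pd.DataFrame) -> pd.DataFrame:
    # PandasGroupedMapFunction, one-argument form
    return pdf.assign(v=pdf["v"] - pdf["v"].mean())

def subtract_mean_with_key(key: Tuple[Any, ...], pdf: pd.DataFrame) -> pd.DataFrame:
    # PandasGroupedMapFunction, two-argument form that also receives the grouping key
    return subtract_mean(pdf)
```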
+ +PandasGroupedAggFunction = Union[ + Callable[[SeriesLike], LiteralType], + Callable[[SeriesLike, SeriesLike], LiteralType], + Callable[[SeriesLike, SeriesLike, SeriesLike], LiteralType], + Callable[[SeriesLike, SeriesLike, SeriesLike, SeriesLike], LiteralType], + Callable[[SeriesLike, SeriesLike, SeriesLike, SeriesLike, SeriesLike], LiteralType], + Callable[ + [SeriesLike, SeriesLike, SeriesLike, SeriesLike, SeriesLike, SeriesLike], + LiteralType, + ], + Callable[ + [ + SeriesLike, + SeriesLike, + SeriesLike, + SeriesLike, + SeriesLike, + SeriesLike, + SeriesLike, + ], + LiteralType, + ], + Callable[ + [ + SeriesLike, + SeriesLike, + SeriesLike, + SeriesLike, + SeriesLike, + SeriesLike, + SeriesLike, + SeriesLike, + ], + LiteralType, + ], + Callable[ + [ + SeriesLike, + SeriesLike, + SeriesLike, + SeriesLike, + SeriesLike, + SeriesLike, + SeriesLike, + SeriesLike, + SeriesLike, + ], + LiteralType, + ], + Callable[ + [ + SeriesLike, + SeriesLike, + SeriesLike, + SeriesLike, + SeriesLike, + SeriesLike, + SeriesLike, + SeriesLike, + SeriesLike, + SeriesLike, + ], + LiteralType, + ], + PandasVariadicGroupedAggFunction, +] + +PandasMapIterFunction = Callable[[Iterable[DataFrameLike]], Iterable[DataFrameLike]] + +PandasCogroupedMapFunction = Callable[[DataFrameLike, DataFrameLike], DataFrameLike] + +MapIterPandasUserDefinedFunction = NewType( + "MapIterPandasUserDefinedFunction", FunctionType +) +GroupedMapPandasUserDefinedFunction = NewType( + "GroupedMapPandasUserDefinedFunction", FunctionType +) +CogroupedMapPandasUserDefinedFunction = NewType( + "CogroupedMapPandasUserDefinedFunction", FunctionType +) diff --git a/python/pyspark/sql/pandas/_typing/protocols/__init__.pyi b/python/pyspark/sql/pandas/_typing/protocols/__init__.pyi new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/python/pyspark/sql/pandas/_typing/protocols/__init__.pyi @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/python/pyspark/sql/pandas/_typing/protocols/frame.pyi b/python/pyspark/sql/pandas/_typing/protocols/frame.pyi new file mode 100644 index 0000000000000..de679ee2cd017 --- /dev/null +++ b/python/pyspark/sql/pandas/_typing/protocols/frame.pyi @@ -0,0 +1,428 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# This Protocol reuses core Pandas annotations.
+# The overall pipeline looks as follows:
+# - Stubgen pandas.core.frame
+# - Add Protocol as a base class
+# - Replace imports with Any
+
+import numpy.ma as np  # type: ignore[import]
+from typing import Any, Hashable, IO, Iterable, List, Optional, Sequence, Tuple, Union
+from typing_extensions import Protocol
+from .series import SeriesLike
+
+Axes = Any
+Dtype = Any
+Index = Any
+Renamer = Any
+Axis = Any
+Level = Any
+
+class DataFrameLike(Protocol):
+    def __init__(
+        self,
+        data: Any = ...,
+        index: Optional[Axes] = ...,
+        columns: Optional[Axes] = ...,
+        dtype: Optional[Dtype] = ...,
+        copy: bool = ...,
+    ) -> None: ...
+    @property
+    def axes(self) -> List[Index]: ...
+    @property
+    def shape(self) -> Tuple[int, int]: ...
+    @property
+    def style(self) -> Any: ...
+    def items(self) -> Iterable[Tuple[Optional[Hashable], SeriesLike]]: ...
+    def iteritems(self) -> Iterable[Tuple[Optional[Hashable], SeriesLike]]: ...
+    def iterrows(self) -> Iterable[Tuple[Optional[Hashable], SeriesLike]]: ...
+    def itertuples(self, index: bool = ..., name: str = ...): ...
+    def __len__(self) -> int: ...
+    def dot(self, other: Any): ...
+    def __matmul__(self, other: Any): ...
+    def __rmatmul__(self, other: Any): ...
+    @classmethod
+    def from_dict(
+        cls: Any, data: Any, orient: Any = ..., dtype: Any = ..., columns: Any = ...
+    ) -> DataFrameLike: ...
+    def to_numpy(self, dtype: Any = ..., copy: Any = ...) -> np.ndarray: ...
+    def to_dict(self, orient: str = ..., into: Any = ...): ...
+    def to_gbq(
+        self,
+        destination_table: Any,
+        project_id: Any = ...,
+        chunksize: Any = ...,
+        reauth: Any = ...,
+        if_exists: Any = ...,
+        auth_local_webserver: Any = ...,
+        table_schema: Any = ...,
+        location: Any = ...,
+        progress_bar: Any = ...,
+        credentials: Any = ...,
+    ) -> None: ...
+    @classmethod
+    def from_records(
+        cls: Any,
+        data: Any,
+        index: Any = ...,
+        exclude: Any = ...,
+        columns: Any = ...,
+        coerce_float: Any = ...,
+        nrows: Any = ...,
+    ) -> DataFrameLike: ...
+    def to_records(
+        self, index: Any = ..., column_dtypes: Any = ..., index_dtypes: Any = ...
+    ) -> np.recarray: ...
+    def to_stata(
+        self,
+        path: Any,
+        convert_dates: Optional[Any] = ...,
+        write_index: bool = ...,
+        byteorder: Optional[Any] = ...,
+        time_stamp: Optional[Any] = ...,
+        data_label: Optional[Any] = ...,
+        variable_labels: Optional[Any] = ...,
+        version: int = ...,
+        convert_strl: Optional[Any] = ...,
+    ) -> None: ...
+    def to_feather(self, path: Any) -> None: ...
+    def to_markdown(
+        self, buf: Optional[IO[str]] = ..., mode: Optional[str] = ..., **kwargs: Any
+    ) -> Optional[str]: ...
+    def to_parquet(
+        self,
+        path: Any,
+        engine: Any = ...,
+        compression: Any = ...,
+        index: Any = ...,
+        partition_cols: Any = ...,
+        **kwargs: Any
+    ) -> None: ...
+ def to_html( + self, + buf: Optional[Any] = ..., + columns: Optional[Any] = ..., + col_space: Optional[Any] = ..., + header: bool = ..., + index: bool = ..., + na_rep: str = ..., + formatters: Optional[Any] = ..., + float_format: Optional[Any] = ..., + sparsify: Optional[Any] = ..., + index_names: bool = ..., + justify: Optional[Any] = ..., + max_rows: Optional[Any] = ..., + max_cols: Optional[Any] = ..., + show_dimensions: bool = ..., + decimal: str = ..., + bold_rows: bool = ..., + classes: Optional[Any] = ..., + escape: bool = ..., + notebook: bool = ..., + border: Optional[Any] = ..., + table_id: Optional[Any] = ..., + render_links: bool = ..., + encoding: Optional[Any] = ..., + ): ... + def info( + self, + verbose: Any = ..., + buf: Any = ..., + max_cols: Any = ..., + memory_usage: Any = ..., + null_counts: Any = ..., + ) -> None: ... + def memory_usage(self, index: Any = ..., deep: Any = ...) -> SeriesLike: ... + def transpose(self, *args: Any, copy: bool = ...) -> DataFrameLike: ... + T: Any = ... + def __getitem__(self, key: Any): ... + def __setitem__(self, key: Any, value: Any): ... + def query(self, expr: Any, inplace: bool = ..., **kwargs: Any): ... + def eval(self, expr: Any, inplace: bool = ..., **kwargs: Any): ... + def select_dtypes( + self, include: Any = ..., exclude: Any = ... + ) -> DataFrameLike: ... + def insert( + self, loc: Any, column: Any, value: Any, allow_duplicates: Any = ... + ) -> None: ... + def assign(self, **kwargs: Any) -> DataFrameLike: ... + def lookup(self, row_labels: Any, col_labels: Any) -> np.ndarray: ... + def align( + self, + other: Any, + join: Any = ..., + axis: Any = ..., + level: Any = ..., + copy: Any = ..., + fill_value: Any = ..., + method: Any = ..., + limit: Any = ..., + fill_axis: Any = ..., + broadcast_axis: Any = ..., + ) -> DataFrameLike: ... + def reindex(self, *args: Any, **kwargs: Any) -> DataFrameLike: ... + def drop( + self, + labels: Optional[Any] = ..., + axis: int = ..., + index: Optional[Any] = ..., + columns: Optional[Any] = ..., + level: Optional[Any] = ..., + inplace: bool = ..., + errors: str = ..., + ): ... + def rename( + self, + mapper: Optional[Renamer] = ..., + *, + index: Optional[Renamer] = ..., + columns: Optional[Renamer] = ..., + axis: Optional[Axis] = ..., + copy: bool = ..., + inplace: bool = ..., + level: Optional[Level] = ..., + errors: str = ... + ) -> Optional[DataFrameLike]: ... + def fillna( + self, + value: Any = ..., + method: Any = ..., + axis: Any = ..., + inplace: Any = ..., + limit: Any = ..., + downcast: Any = ..., + ) -> Optional[DataFrameLike]: ... + def replace( + self, + to_replace: Optional[Any] = ..., + value: Optional[Any] = ..., + inplace: bool = ..., + limit: Optional[Any] = ..., + regex: bool = ..., + method: str = ..., + ): ... + def shift( + self, + periods: Any = ..., + freq: Any = ..., + axis: Any = ..., + fill_value: Any = ..., + ) -> DataFrameLike: ... + def set_index( + self, + keys: Any, + drop: bool = ..., + append: bool = ..., + inplace: bool = ..., + verify_integrity: bool = ..., + ): ... + def reset_index( + self, + level: Optional[Union[Hashable, Sequence[Hashable]]] = ..., + drop: bool = ..., + inplace: bool = ..., + col_level: Hashable = ..., + col_fill: Optional[Hashable] = ..., + ) -> Optional[DataFrameLike]: ... + def isna(self) -> DataFrameLike: ... + def isnull(self) -> DataFrameLike: ... + def notna(self) -> DataFrameLike: ... + def notnull(self) -> DataFrameLike: ... 
+ def dropna( + self, + axis: int = ..., + how: str = ..., + thresh: Optional[Any] = ..., + subset: Optional[Any] = ..., + inplace: bool = ..., + ): ... + def drop_duplicates( + self, + subset: Optional[Union[Hashable, Sequence[Hashable]]] = ..., + keep: Union[str, bool] = ..., + inplace: bool = ..., + ignore_index: bool = ..., + ) -> Optional[DataFrameLike]: ... + def duplicated( + self, + subset: Optional[Union[Hashable, Sequence[Hashable]]] = ..., + keep: Union[str, bool] = ..., + ) -> SeriesLike: ... + def sort_values( + self, + by: Any, + axis: int = ..., + ascending: bool = ..., + inplace: bool = ..., + kind: str = ..., + na_position: str = ..., + ignore_index: bool = ..., + ): ... + def sort_index( + self, + axis: Any = ..., + level: Any = ..., + ascending: Any = ..., + inplace: Any = ..., + kind: Any = ..., + na_position: Any = ..., + sort_remaining: Any = ..., + ignore_index: bool = ..., + ) -> Any: ... + def nlargest(self, n: Any, columns: Any, keep: Any = ...) -> DataFrameLike: ... + def nsmallest(self, n: Any, columns: Any, keep: Any = ...) -> DataFrameLike: ... + def swaplevel( + self, i: Any = ..., j: Any = ..., axis: Any = ... + ) -> DataFrameLike: ... + def reorder_levels(self, order: Any, axis: Any = ...) -> DataFrameLike: ... + def combine( + self, + other: DataFrameLike, + func: Any, + fill_value: Any = ..., + overwrite: Any = ..., + ) -> DataFrameLike: ... + def combine_first(self, other: DataFrameLike) -> DataFrameLike: ... + def update( + self, + other: Any, + join: Any = ..., + overwrite: Any = ..., + filter_func: Any = ..., + errors: Any = ..., + ) -> None: ... + def groupby( + self, + by: Any = ..., + axis: Any = ..., + level: Any = ..., + as_index: bool = ..., + sort: bool = ..., + group_keys: bool = ..., + squeeze: bool = ..., + observed: bool = ..., + ) -> Any: ... + def pivot( + self, index: Any = ..., columns: Any = ..., values: Any = ... + ) -> DataFrameLike: ... + def pivot_table( + self, + values: Any = ..., + index: Any = ..., + columns: Any = ..., + aggfunc: Any = ..., + fill_value: Any = ..., + margins: Any = ..., + dropna: Any = ..., + margins_name: Any = ..., + observed: Any = ..., + ) -> DataFrameLike: ... + def stack(self, level: int = ..., dropna: bool = ...): ... + def explode(self, column: Union[str, Tuple]) -> DataFrameLike: ... + def unstack(self, level: int = ..., fill_value: Optional[Any] = ...): ... + def melt( + self, + id_vars: Any = ..., + value_vars: Any = ..., + var_name: Any = ..., + value_name: Any = ..., + col_level: Any = ..., + ) -> DataFrameLike: ... + def diff(self, periods: Any = ..., axis: Any = ...) -> DataFrameLike: ... + def aggregate(self, func: Any, axis: int = ..., *args: Any, **kwargs: Any): ... + agg: Any = ... + def transform( + self, func: Any, axis: Any = ..., *args: Any, **kwargs: Any + ) -> DataFrameLike: ... + def apply( + self, + func: Any, + axis: int = ..., + raw: bool = ..., + result_type: Optional[Any] = ..., + args: Any = ..., + **kwds: Any + ): ... + def applymap(self, func: Any) -> DataFrameLike: ... + def append( + self, + other: Any, + ignore_index: Any = ..., + verify_integrity: Any = ..., + sort: Any = ..., + ) -> DataFrameLike: ... + def join( + self, + other: Any, + on: Any = ..., + how: Any = ..., + lsuffix: Any = ..., + rsuffix: Any = ..., + sort: Any = ..., + ) -> DataFrameLike: ... 
+    def merge(
+        self,
+        right: Any,
+        how: Any = ...,
+        on: Any = ...,
+        left_on: Any = ...,
+        right_on: Any = ...,
+        left_index: Any = ...,
+        right_index: Any = ...,
+        sort: Any = ...,
+        suffixes: Any = ...,
+        copy: Any = ...,
+        indicator: Any = ...,
+        validate: Any = ...,
+    ) -> DataFrameLike: ...
+    def round(
+        self, decimals: Any = ..., *args: Any, **kwargs: Any
+    ) -> DataFrameLike: ...
+    def corr(self, method: Any = ..., min_periods: Any = ...) -> DataFrameLike: ...
+    def cov(self, min_periods: Any = ...) -> DataFrameLike: ...
+    def corrwith(
+        self, other: Any, axis: Any = ..., drop: Any = ..., method: Any = ...
+    ) -> SeriesLike: ...
+    def count(
+        self, axis: int = ..., level: Optional[Any] = ..., numeric_only: bool = ...
+    ): ...
+    def nunique(self, axis: Any = ..., dropna: Any = ...) -> SeriesLike: ...
+    def idxmin(self, axis: Any = ..., skipna: Any = ...) -> SeriesLike: ...
+    def idxmax(self, axis: Any = ..., skipna: Any = ...) -> SeriesLike: ...
+    def mode(
+        self, axis: Any = ..., numeric_only: Any = ..., dropna: Any = ...
+    ) -> DataFrameLike: ...
+    def quantile(
+        self,
+        q: float = ...,
+        axis: int = ...,
+        numeric_only: bool = ...,
+        interpolation: str = ...,
+    ): ...
+    def to_timestamp(
+        self, freq: Any = ..., how: Any = ..., axis: Any = ..., copy: Any = ...
+    ) -> DataFrameLike: ...
+    def to_period(
+        self, freq: Any = ..., axis: Any = ..., copy: Any = ...
+    ) -> DataFrameLike: ...
+    def isin(self, values: Any) -> DataFrameLike: ...
+    plot: Any = ...
+    hist: Any = ...
+    boxplot: Any = ...
+    sparse: Any = ...
diff --git a/python/pyspark/sql/pandas/_typing/protocols/series.pyi b/python/pyspark/sql/pandas/_typing/protocols/series.pyi
new file mode 100644
index 0000000000000..14babb067da0d
--- /dev/null
+++ b/python/pyspark/sql/pandas/_typing/protocols/series.pyi
@@ -0,0 +1,253 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# This Protocol reuses core Pandas annotations.
+# The overall pipeline looks as follows:
+# - Stubgen pandas.core.series
+# - Add Protocol as a base class
+# - Replace imports with Any
+
+import numpy as np  # type: ignore[import]
+from typing import Any, Callable, Hashable, IO, Optional
+from typing_extensions import Protocol
+
+groupby_generic = Any
+
+class SeriesLike(Protocol):
+    hasnans: Any = ...
+    div: Callable[[SeriesLike, Any], SeriesLike]
+    rdiv: Callable[[SeriesLike, Any], SeriesLike]
+    def __init__(
+        self,
+        data: Optional[Any] = ...,
+        index: Optional[Any] = ...,
+        dtype: Optional[Any] = ...,
+        name: Optional[Any] = ...,
+        copy: bool = ...,
+        fastpath: bool = ...,
+    ) -> None: ...
+    @property
+    def dtype(self): ...
+    @property
+    def dtypes(self): ...
+    @property
+    def name(self) -> Optional[Hashable]: ...
+    @name.setter
+    def name(self, value: Optional[Hashable]) -> None: ...
+ @property + def values(self): ... + def ravel(self, order: str = ...): ... + def __len__(self) -> int: ... + def view(self, dtype: Optional[Any] = ...): ... + def __array_ufunc__( + self, ufunc: Callable, method: str, *inputs: Any, **kwargs: Any + ) -> Any: ... + def __array__(self, dtype: Any = ...) -> np.ndarray: ... + __float__: Any = ... + __long__: Any = ... + __int__: Any = ... + @property + def axes(self): ... + def take( + self, indices: Any, axis: int = ..., is_copy: bool = ..., **kwargs: Any + ): ... + def __getitem__(self, key: Any): ... + def __setitem__(self, key: Any, value: Any) -> None: ... + def repeat(self, repeats: Any, axis: Optional[Any] = ...): ... + index: Any = ... + def reset_index( + self, + level: Optional[Any] = ..., + drop: bool = ..., + name: Optional[Any] = ..., + inplace: bool = ..., + ): ... + def to_string( + self, + buf: Optional[Any] = ..., + na_rep: str = ..., + float_format: Optional[Any] = ..., + header: bool = ..., + index: bool = ..., + length: bool = ..., + dtype: bool = ..., + name: bool = ..., + max_rows: Optional[Any] = ..., + min_rows: Optional[Any] = ..., + ): ... + def to_markdown( + self, buf: Optional[IO[str]] = ..., mode: Optional[str] = ..., **kwargs: Any + ) -> Optional[str]: ... + def items(self): ... + def iteritems(self): ... + def keys(self): ... + def to_dict(self, into: Any = ...): ... + def to_frame(self, name: Optional[Any] = ...): ... + def groupby( + self, + by: Any = ..., + axis: Any = ..., + level: Any = ..., + as_index: bool = ..., + sort: bool = ..., + group_keys: bool = ..., + squeeze: bool = ..., + observed: bool = ..., + ) -> Any: ... + def count(self, level: Optional[Any] = ...): ... + def mode(self, dropna: bool = ...): ... + def unique(self): ... + def drop_duplicates(self, keep: str = ..., inplace: bool = ...): ... + def duplicated(self, keep: str = ...): ... + def idxmin( + self, axis: int = ..., skipna: bool = ..., *args: Any, **kwargs: Any + ): ... + def idxmax( + self, axis: int = ..., skipna: bool = ..., *args: Any, **kwargs: Any + ): ... + def round(self, decimals: int = ..., *args: Any, **kwargs: Any): ... + def quantile(self, q: float = ..., interpolation: str = ...): ... + def corr(self, other: Any, method: str = ..., min_periods: Optional[Any] = ...): ... + def cov(self, other: Any, min_periods: Optional[Any] = ...): ... + def diff(self, periods: int = ...): ... + def autocorr(self, lag: int = ...): ... + def dot(self, other: Any): ... + def __matmul__(self, other: Any): ... + def __rmatmul__(self, other: Any): ... + def searchsorted( + self, value: Any, side: str = ..., sorter: Optional[Any] = ... + ): ... + def append( + self, to_append: Any, ignore_index: bool = ..., verify_integrity: bool = ... + ): ... + def combine(self, other: Any, func: Any, fill_value: Optional[Any] = ...): ... + def combine_first(self, other: Any): ... + def update(self, other: Any) -> None: ... + def sort_values( + self, + axis: int = ..., + ascending: bool = ..., + inplace: bool = ..., + kind: str = ..., + na_position: str = ..., + ignore_index: bool = ..., + ): ... + def sort_index( + self, + axis: Any = ..., + level: Any = ..., + ascending: Any = ..., + inplace: Any = ..., + kind: Any = ..., + na_position: Any = ..., + sort_remaining: Any = ..., + ignore_index: bool = ..., + ) -> Any: ... + def argsort(self, axis: int = ..., kind: str = ..., order: Optional[Any] = ...): ... + def nlargest(self, n: int = ..., keep: str = ...): ... + def nsmallest(self, n: int = ..., keep: str = ...): ... 
+ def swaplevel(self, i: int = ..., j: int = ..., copy: bool = ...): ... + def reorder_levels(self, order: Any): ... + def explode(self) -> SeriesLike: ... + def unstack(self, level: int = ..., fill_value: Optional[Any] = ...): ... + def map(self, arg: Any, na_action: Optional[Any] = ...): ... + def aggregate(self, func: Any, axis: int = ..., *args: Any, **kwargs: Any): ... + agg: Any = ... + def transform(self, func: Any, axis: int = ..., *args: Any, **kwargs: Any): ... + def apply( + self, func: Any, convert_dtype: bool = ..., args: Any = ..., **kwds: Any + ): ... + def align( + self, + other: Any, + join: str = ..., + axis: Optional[Any] = ..., + level: Optional[Any] = ..., + copy: bool = ..., + fill_value: Optional[Any] = ..., + method: Optional[Any] = ..., + limit: Optional[Any] = ..., + fill_axis: int = ..., + broadcast_axis: Optional[Any] = ..., + ): ... + def rename( + self, + index: Optional[Any] = ..., + *, + axis: Optional[Any] = ..., + copy: bool = ..., + inplace: bool = ..., + level: Optional[Any] = ..., + errors: str = ... + ): ... + def reindex(self, index: Optional[Any] = ..., **kwargs: Any): ... + def drop( + self, + labels: Optional[Any] = ..., + axis: int = ..., + index: Optional[Any] = ..., + columns: Optional[Any] = ..., + level: Optional[Any] = ..., + inplace: bool = ..., + errors: str = ..., + ): ... + def fillna( + self, + value: Any = ..., + method: Any = ..., + axis: Any = ..., + inplace: Any = ..., + limit: Any = ..., + downcast: Any = ..., + ) -> Optional[SeriesLike]: ... + def replace( + self, + to_replace: Optional[Any] = ..., + value: Optional[Any] = ..., + inplace: bool = ..., + limit: Optional[Any] = ..., + regex: bool = ..., + method: str = ..., + ): ... + def shift( + self, + periods: int = ..., + freq: Optional[Any] = ..., + axis: int = ..., + fill_value: Optional[Any] = ..., + ): ... + def memory_usage(self, index: bool = ..., deep: bool = ...): ... + def isin(self, values: Any): ... + def between(self, left: Any, right: Any, inclusive: bool = ...): ... + def isna(self): ... + def isnull(self): ... + def notna(self): ... + def notnull(self): ... + def dropna( + self, axis: int = ..., inplace: bool = ..., how: Optional[Any] = ... + ): ... + def to_timestamp( + self, freq: Optional[Any] = ..., how: str = ..., copy: bool = ... + ): ... + def to_period(self, freq: Optional[Any] = ..., copy: bool = ...): ... + str: Any = ... + dt: Any = ... + cat: Any = ... + plot: Any = ... + sparse: Any = ... + hist: Any = ... diff --git a/python/pyspark/sql/pandas/conversion.pyi b/python/pyspark/sql/pandas/conversion.pyi new file mode 100644 index 0000000000000..031852fcc053d --- /dev/null +++ b/python/pyspark/sql/pandas/conversion.pyi @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
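The two Protocol files above (`frame.pyi`, `series.pyi`) exist so that the conversion and UDF stubs can refer to pandas objects structurally; the conversion mixins declared just below type `createDataFrame` and `toPandas` against `DataFrameLike`. A hedged round-trip sketch (the `TYPE_CHECKING` guard is needed because `pyspark.sql.pandas._typing` ships as stubs only):

```
# Sketch only; assumes pandas and pyarrow are installed.
from typing import TYPE_CHECKING
import pandas as pd
from pyspark.sql import SparkSession

if TYPE_CHECKING:  # stub-only module, not importable at runtime
    from pyspark.sql.pandas._typing import DataFrameLike

spark = SparkSession.builder.master("local[1]").appName("pandas-roundtrip").getOrCreate()

pdf: "DataFrameLike" = pd.DataFrame({"x": [1.0, 2.0, 3.0]})
sdf = spark.createDataFrame(pdf)          # typed by SparkConversionMixin.createDataFrame
back: "DataFrameLike" = sdf.toPandas()    # typed by PandasConversionMixin.toPandas

spark.stop()
```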
+ +from typing import overload +from typing import Optional, Union + +from pyspark.sql.pandas._typing import DataFrameLike +from pyspark import since as since # noqa: F401 +from pyspark.rdd import RDD # noqa: F401 +import pyspark.sql.dataframe +from pyspark.sql.pandas.serializers import ( # noqa: F401 + ArrowCollectSerializer as ArrowCollectSerializer, +) +from pyspark.sql.types import ( # noqa: F401 + BooleanType as BooleanType, + ByteType as ByteType, + DataType as DataType, + DoubleType as DoubleType, + FloatType as FloatType, + IntegerType as IntegerType, + IntegralType as IntegralType, + LongType as LongType, + ShortType as ShortType, + StructType as StructType, + TimestampType as TimestampType, +) +from pyspark.traceback_utils import SCCallSiteSync as SCCallSiteSync # noqa: F401 + +class PandasConversionMixin: + def toPandas(self) -> DataFrameLike: ... + +class SparkConversionMixin: + @overload + def createDataFrame( + self, data: DataFrameLike, samplingRatio: Optional[float] = ... + ) -> pyspark.sql.dataframe.DataFrame: ... + @overload + def createDataFrame( + self, + data: DataFrameLike, + schema: Union[StructType, str], + verifySchema: bool = ..., + ) -> pyspark.sql.dataframe.DataFrame: ... diff --git a/python/pyspark/sql/pandas/functions.pyi b/python/pyspark/sql/pandas/functions.pyi new file mode 100644 index 0000000000000..09318e43f8aa1 --- /dev/null +++ b/python/pyspark/sql/pandas/functions.pyi @@ -0,0 +1,176 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import overload +from typing import Union, Callable + +from pyspark.sql._typing import ( + AtomicDataTypeOrString, + UserDefinedFunctionLike, +) +from pyspark.sql.pandas._typing import ( + GroupedMapPandasUserDefinedFunction, + MapIterPandasUserDefinedFunction, + CogroupedMapPandasUserDefinedFunction, + PandasCogroupedMapFunction, + PandasCogroupedMapUDFType, + PandasGroupedAggFunction, + PandasGroupedAggUDFType, + PandasGroupedMapFunction, + PandasGroupedMapUDFType, + PandasMapIterFunction, + PandasMapIterUDFType, + PandasScalarIterFunction, + PandasScalarIterUDFType, + PandasScalarToScalarFunction, + PandasScalarToStructFunction, + PandasScalarUDFType, +) + +from pyspark import since as since # noqa: F401 +from pyspark.rdd import PythonEvalType as PythonEvalType # noqa: F401 +from pyspark.sql.types import ArrayType, StructType + +class PandasUDFType: + SCALAR: PandasScalarUDFType + SCALAR_ITER: PandasScalarIterUDFType + GROUPED_MAP: PandasGroupedMapUDFType + GROUPED_AGG: PandasGroupedAggUDFType + +@overload +def pandas_udf( + f: PandasScalarToScalarFunction, + returnType: Union[AtomicDataTypeOrString, ArrayType], + functionType: PandasScalarUDFType, +) -> UserDefinedFunctionLike: ... 
+@overload +def pandas_udf(f: Union[AtomicDataTypeOrString, ArrayType], returnType: PandasScalarUDFType) -> Callable[[PandasScalarToScalarFunction], UserDefinedFunctionLike]: ... # type: ignore[misc] +@overload +def pandas_udf(f: Union[AtomicDataTypeOrString, ArrayType], *, functionType: PandasScalarUDFType) -> Callable[[PandasScalarToScalarFunction], UserDefinedFunctionLike]: ... # type: ignore[misc] +@overload +def pandas_udf(*, returnType: Union[AtomicDataTypeOrString, ArrayType], functionType: PandasScalarUDFType) -> Callable[[PandasScalarToScalarFunction], UserDefinedFunctionLike]: ... # type: ignore[misc] +@overload +def pandas_udf( + f: PandasScalarToStructFunction, + returnType: Union[StructType, str], + functionType: PandasScalarUDFType, +) -> UserDefinedFunctionLike: ... +@overload +def pandas_udf(f: Union[StructType, str], returnType: PandasScalarUDFType) -> Callable[[PandasScalarToStructFunction], UserDefinedFunctionLike]: ... # type: ignore[misc] +@overload +def pandas_udf(f: Union[StructType, str], *, functionType: PandasScalarUDFType) -> Callable[[PandasScalarToStructFunction], UserDefinedFunctionLike]: ... # type: ignore[misc] +@overload +def pandas_udf(*, returnType: Union[StructType, str], functionType: PandasScalarUDFType) -> Callable[[PandasScalarToStructFunction], UserDefinedFunctionLike]: ... # type: ignore[misc] +@overload +def pandas_udf( + f: PandasScalarIterFunction, + returnType: Union[AtomicDataTypeOrString, ArrayType], + functionType: PandasScalarIterUDFType, +) -> UserDefinedFunctionLike: ... +@overload +def pandas_udf( + f: Union[AtomicDataTypeOrString, ArrayType], returnType: PandasScalarIterUDFType +) -> Callable[[PandasScalarIterFunction], UserDefinedFunctionLike]: ... +@overload +def pandas_udf( + *, + returnType: Union[AtomicDataTypeOrString, ArrayType], + functionType: PandasScalarIterUDFType +) -> Callable[[PandasScalarIterFunction], UserDefinedFunctionLike]: ... +@overload +def pandas_udf( + f: Union[AtomicDataTypeOrString, ArrayType], + *, + functionType: PandasScalarIterUDFType +) -> Callable[[PandasScalarIterFunction], UserDefinedFunctionLike]: ... +@overload +def pandas_udf( + f: PandasGroupedMapFunction, + returnType: Union[StructType, str], + functionType: PandasGroupedMapUDFType, +) -> GroupedMapPandasUserDefinedFunction: ... +@overload +def pandas_udf( + f: Union[StructType, str], returnType: PandasGroupedMapUDFType +) -> Callable[[PandasGroupedMapFunction], GroupedMapPandasUserDefinedFunction]: ... +@overload +def pandas_udf( + *, returnType: Union[StructType, str], functionType: PandasGroupedMapUDFType +) -> Callable[[PandasGroupedMapFunction], GroupedMapPandasUserDefinedFunction]: ... +@overload +def pandas_udf( + f: Union[StructType, str], *, functionType: PandasGroupedMapUDFType +) -> Callable[[PandasGroupedMapFunction], GroupedMapPandasUserDefinedFunction]: ... +@overload +def pandas_udf( + f: PandasGroupedAggFunction, + returnType: Union[AtomicDataTypeOrString, ArrayType], + functionType: PandasGroupedAggUDFType, +) -> UserDefinedFunctionLike: ... +@overload +def pandas_udf( + f: Union[AtomicDataTypeOrString, ArrayType], returnType: PandasGroupedAggUDFType +) -> Callable[[PandasGroupedAggFunction], UserDefinedFunctionLike]: ... +@overload +def pandas_udf( + *, + returnType: Union[AtomicDataTypeOrString, ArrayType], + functionType: PandasGroupedAggUDFType +) -> Callable[[PandasGroupedAggFunction], UserDefinedFunctionLike]: ... 
+@overload +def pandas_udf( + f: Union[AtomicDataTypeOrString, ArrayType], + *, + functionType: PandasGroupedAggUDFType +) -> Callable[[PandasGroupedAggFunction], UserDefinedFunctionLike]: ... +@overload +def pandas_udf( + f: PandasMapIterFunction, + returnType: Union[StructType, str], + functionType: PandasMapIterUDFType, +) -> MapIterPandasUserDefinedFunction: ... +@overload +def pandas_udf( + f: Union[StructType, str], returnType: PandasMapIterUDFType +) -> Callable[[PandasMapIterFunction], MapIterPandasUserDefinedFunction]: ... +@overload +def pandas_udf( + *, returnType: Union[StructType, str], functionType: PandasMapIterUDFType +) -> Callable[[PandasMapIterFunction], MapIterPandasUserDefinedFunction]: ... +@overload +def pandas_udf( + f: Union[StructType, str], *, functionType: PandasMapIterUDFType +) -> Callable[[PandasMapIterFunction], MapIterPandasUserDefinedFunction]: ... +@overload +def pandas_udf( + f: PandasCogroupedMapFunction, + returnType: Union[StructType, str], + functionType: PandasCogroupedMapUDFType, +) -> CogroupedMapPandasUserDefinedFunction: ... +@overload +def pandas_udf( + f: Union[StructType, str], returnType: PandasCogroupedMapUDFType +) -> Callable[[PandasCogroupedMapFunction], CogroupedMapPandasUserDefinedFunction]: ... +@overload +def pandas_udf( + *, returnType: Union[StructType, str], functionType: PandasCogroupedMapUDFType +) -> Callable[[PandasCogroupedMapFunction], CogroupedMapPandasUserDefinedFunction]: ... +@overload +def pandas_udf( + f: Union[StructType, str], *, functionType: PandasCogroupedMapUDFType +) -> Callable[[PandasCogroupedMapFunction], CogroupedMapPandasUserDefinedFunction]: ... diff --git a/python/pyspark/sql/pandas/group_ops.pyi b/python/pyspark/sql/pandas/group_ops.pyi new file mode 100644 index 0000000000000..2c543e0dc77b9 --- /dev/null +++ b/python/pyspark/sql/pandas/group_ops.pyi @@ -0,0 +1,49 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Union + +from pyspark.sql.pandas._typing import ( + GroupedMapPandasUserDefinedFunction, + PandasGroupedMapFunction, + PandasCogroupedMapFunction, +) + +from pyspark import since as since # noqa: F401 +from pyspark.rdd import PythonEvalType as PythonEvalType # noqa: F401 +from pyspark.sql.column import Column as Column # noqa: F401 +from pyspark.sql.context import SQLContext +import pyspark.sql.group +from pyspark.sql.dataframe import DataFrame as DataFrame +from pyspark.sql.types import StructType + +class PandasGroupedOpsMixin: + def cogroup(self, other: pyspark.sql.group.GroupedData) -> PandasCogroupedOps: ... + def apply(self, udf: GroupedMapPandasUserDefinedFunction) -> DataFrame: ... + def applyInPandas( + self, func: PandasGroupedMapFunction, schema: Union[StructType, str] + ) -> DataFrame: ... 
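Taken together, the `pandas_udf` overloads above and the `PandasGroupedOpsMixin` stub give type checkers enough to follow the classic pandas UDF styles. A hedged end-to-end sketch (the schema string and column names are invented for illustration):

```
# Sketch only; uses the legacy functionType form that the overloads above cover.
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType

spark = SparkSession.builder.master("local[1]").appName("pandas-udfs").getOrCreate()
df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ["id", "v"])

@pandas_udf("double", PandasUDFType.SCALAR)  # matches the Literal[200] scalar overload
def plus_one(s: pd.Series) -> pd.Series:
    return s + 1

df.select(plus_one("v")).show()

def subtract_mean(pdf: pd.DataFrame) -> pd.DataFrame:
    return pdf.assign(v=pdf["v"] - pdf["v"].mean())

# applyInPandas takes a PandasGroupedMapFunction and a schema (StructType or DDL string)
df.groupBy("id").applyInPandas(subtract_mean, schema="id long, v double").show()

spark.stop()
```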
+ +class PandasCogroupedOps: + sql_ctx: SQLContext + def __init__( + self, gd1: pyspark.sql.group.GroupedData, gd2: pyspark.sql.group.GroupedData + ) -> None: ... + def applyInPandas( + self, func: PandasCogroupedMapFunction, schema: Union[StructType, str] + ) -> DataFrame: ... diff --git a/python/pyspark/sql/pandas/map_ops.pyi b/python/pyspark/sql/pandas/map_ops.pyi new file mode 100644 index 0000000000000..cab885278c388 --- /dev/null +++ b/python/pyspark/sql/pandas/map_ops.pyi @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Union + +from pyspark.sql.pandas._typing import PandasMapIterFunction +from pyspark import since as since # noqa: F401 +from pyspark.rdd import PythonEvalType as PythonEvalType # noqa: F401 +from pyspark.sql.types import StructType +import pyspark.sql.dataframe + +class PandasMapOpsMixin: + def mapInPandas( + self, udf: PandasMapIterFunction, schema: Union[StructType, str] + ) -> pyspark.sql.dataframe.DataFrame: ... diff --git a/python/pyspark/sql/pandas/serializers.pyi b/python/pyspark/sql/pandas/serializers.pyi new file mode 100644 index 0000000000000..8be3c0dcbc9ad --- /dev/null +++ b/python/pyspark/sql/pandas/serializers.pyi @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pyspark.serializers import ( # noqa: F401 + Serializer as Serializer, + UTF8Deserializer as UTF8Deserializer, + read_int as read_int, + write_int as write_int, +) +from typing import Any + +class SpecialLengths: + END_OF_DATA_SECTION: int = ... + PYTHON_EXCEPTION_THROWN: int = ... + TIMING_DATA: int = ... + END_OF_STREAM: int = ... + NULL: int = ... + START_ARROW_STREAM: int = ... + +class ArrowCollectSerializer(Serializer): + serializer: Any = ... + def __init__(self) -> None: ... + def dump_stream(self, iterator: Any, stream: Any): ... + def load_stream(self, stream: Any) -> None: ... + +class ArrowStreamSerializer(Serializer): + def dump_stream(self, iterator: Any, stream: Any) -> None: ... + def load_stream(self, stream: Any) -> None: ... 
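Stepping back from the serializer stubs for a moment: `PandasCogroupedOps.applyInPandas` and `PandasMapOpsMixin.mapInPandas`, typed above, accept the `PandasCogroupedMapFunction` and `PandasMapIterFunction` aliases from `_typing`. A hedged sketch (schemas and column names are invented):

```
# Sketch only.
from typing import Iterator
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("cogroup-and-map").getOrCreate()
left = spark.createDataFrame([(1, 1.0), (2, 2.0)], ["id", "x"])
right = spark.createDataFrame([(1, 10.0), (2, 20.0)], ["id", "y"])

def merge_groups(l: pd.DataFrame, r: pd.DataFrame) -> pd.DataFrame:
    # PandasCogroupedMapFunction: (DataFrameLike, DataFrameLike) -> DataFrameLike
    return pd.merge(l, r, on="id")

left.groupBy("id").cogroup(right.groupBy("id")).applyInPandas(
    merge_groups, schema="id long, x double, y double"
).show()

def add_one(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    # PandasMapIterFunction: Iterable[DataFrameLike] -> Iterable[DataFrameLike]
    for pdf in batches:
        yield pdf.assign(x=pdf["x"] + 1)

left.mapInPandas(add_one, schema="id long, x double").show()

spark.stop()
```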
+ +class ArrowStreamPandasSerializer(ArrowStreamSerializer): + def __init__( + self, timezone: Any, safecheck: Any, assign_cols_by_name: Any + ) -> None: ... + def arrow_to_pandas(self, arrow_column: Any): ... + def dump_stream(self, iterator: Any, stream: Any) -> None: ... + def load_stream(self, stream: Any) -> None: ... + +class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer): + def __init__( + self, + timezone: Any, + safecheck: Any, + assign_cols_by_name: Any, + df_for_struct: bool = ..., + ) -> None: ... + def arrow_to_pandas(self, arrow_column: Any): ... + def dump_stream(self, iterator: Any, stream: Any): ... + +class CogroupUDFSerializer(ArrowStreamPandasUDFSerializer): + def load_stream(self, stream: Any) -> None: ... diff --git a/python/pyspark/sql/pandas/typehints.pyi b/python/pyspark/sql/pandas/typehints.pyi new file mode 100644 index 0000000000000..eea9c86225332 --- /dev/null +++ b/python/pyspark/sql/pandas/typehints.pyi @@ -0,0 +1,33 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pyspark.sql.pandas.utils import ( # noqa: F401 + require_minimum_pandas_version as require_minimum_pandas_version, +) +from typing import Any, Optional + +def infer_eval_type(sig: Any): ... +def check_tuple_annotation( + annotation: Any, parameter_check_func: Optional[Any] = ... +): ... +def check_iterator_annotation( + annotation: Any, parameter_check_func: Optional[Any] = ... +): ... +def check_union_annotation( + annotation: Any, parameter_check_func: Optional[Any] = ... +): ... diff --git a/python/pyspark/sql/pandas/types.pyi b/python/pyspark/sql/pandas/types.pyi new file mode 100644 index 0000000000000..5ae29bd273180 --- /dev/null +++ b/python/pyspark/sql/pandas/types.pyi @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
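The helpers in `typehints.pyi` above back the Python-type-hint UDF style introduced in Spark 3.0, where the pandas UDF eval type is inferred at runtime from the function signature rather than passed as `functionType`. A hedged sketch of that style (runtime inference via `infer_eval_type`; whether a given checker also resolves the single-argument `pandas_udf` call depends on the overloads earlier in this patch):

```
# Sketch only; relies on runtime inference from the type hints.
import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("long")
def plus_one(s: pd.Series) -> pd.Series:  # Series -> Series hints imply a scalar pandas UDF
    return s + 1

@pandas_udf("double")
def mean_udf(s: pd.Series) -> float:      # Series -> scalar hints imply a grouped-agg pandas UDF
    return s.mean()
```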
+ +from pyspark.sql.types import ( # noqa: F401 + ArrayType as ArrayType, + BinaryType as BinaryType, + BooleanType as BooleanType, + ByteType as ByteType, + DateType as DateType, + DecimalType as DecimalType, + DoubleType as DoubleType, + FloatType as FloatType, + IntegerType as IntegerType, + LongType as LongType, + ShortType as ShortType, + StringType as StringType, + StructField as StructField, + StructType as StructType, + TimestampType as TimestampType, +) +from typing import Any + +def to_arrow_type(dt: Any): ... +def to_arrow_schema(schema: Any): ... +def from_arrow_type(at: Any): ... +def from_arrow_schema(arrow_schema: Any): ... diff --git a/python/pyspark/sql/pandas/utils.pyi b/python/pyspark/sql/pandas/utils.pyi new file mode 100644 index 0000000000000..e4d315b0ce205 --- /dev/null +++ b/python/pyspark/sql/pandas/utils.pyi @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +def require_minimum_pandas_version() -> None: ... +def require_minimum_pyarrow_version() -> None: ... diff --git a/python/pyspark/sql/readwriter.pyi b/python/pyspark/sql/readwriter.pyi new file mode 100644 index 0000000000000..a111cbe416c2f --- /dev/null +++ b/python/pyspark/sql/readwriter.pyi @@ -0,0 +1,250 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import overload +from typing import Dict, List, Optional, Tuple, Union + +from pyspark.sql._typing import OptionalPrimitiveType +from pyspark.sql.dataframe import DataFrame +from pyspark.rdd import RDD +from pyspark.sql.column import Column +from pyspark.sql.context import SQLContext +from pyspark.sql.types import StructType + +PathOrPaths = Union[str, List[str]] +TupleOrListOfString = Union[List[str], Tuple[str, ...]] + +class OptionUtils: ... + +class DataFrameReader(OptionUtils): + def __init__(self, spark: SQLContext) -> None: ... + def format(self, source: str) -> DataFrameReader: ... + def schema(self, schema: Union[StructType, str]) -> DataFrameReader: ... + def option(self, key: str, value: OptionalPrimitiveType) -> DataFrameReader: ... 
+ def options(self, **options: OptionalPrimitiveType) -> DataFrameReader: ... + def load( + self, + path: Optional[PathOrPaths] = ..., + format: Optional[str] = ..., + schema: Optional[StructType] = ..., + **options: OptionalPrimitiveType + ) -> DataFrame: ... + def json( + self, + path: Union[str, List[str], RDD[str]], + schema: Optional[Union[StructType, str]] = ..., + primitivesAsString: Optional[Union[bool, str]] = ..., + prefersDecimal: Optional[Union[bool, str]] = ..., + allowComments: Optional[Union[bool, str]] = ..., + allowUnquotedFieldNames: Optional[Union[bool, str]] = ..., + allowSingleQuotes: Optional[Union[bool, str]] = ..., + allowNumericLeadingZero: Optional[Union[bool, str]] = ..., + allowBackslashEscapingAnyCharacter: Optional[Union[bool, str]] = ..., + mode: Optional[str] = ..., + columnNameOfCorruptRecord: Optional[str] = ..., + dateFormat: Optional[str] = ..., + timestampFormat: Optional[str] = ..., + multiLine: Optional[Union[bool, str]] = ..., + allowUnquotedControlChars: Optional[Union[bool, str]] = ..., + lineSep: Optional[str] = ..., + samplingRatio: Optional[Union[float, str]] = ..., + dropFieldIfAllNull: Optional[Union[bool, str]] = ..., + encoding: Optional[str] = ..., + locale: Optional[str] = ..., + recursiveFileLookup: Optional[bool] = ..., + ) -> DataFrame: ... + def table(self, tableName: str) -> DataFrame: ... + def parquet(self, *paths: str, **options: OptionalPrimitiveType) -> DataFrame: ... + def text( + self, + paths: PathOrPaths, + wholetext: bool = ..., + lineSep: Optional[str] = ..., + recursiveFileLookup: Optional[bool] = ..., + ) -> DataFrame: ... + def csv( + self, + path: PathOrPaths, + schema: Optional[Union[StructType, str]] = ..., + sep: Optional[str] = ..., + encoding: Optional[str] = ..., + quote: Optional[str] = ..., + escape: Optional[str] = ..., + comment: Optional[str] = ..., + header: Optional[Union[bool, str]] = ..., + inferSchema: Optional[Union[bool, str]] = ..., + ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = ..., + ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = ..., + nullValue: Optional[str] = ..., + nanValue: Optional[str] = ..., + positiveInf: Optional[str] = ..., + negativeInf: Optional[str] = ..., + dateFormat: Optional[str] = ..., + timestampFormat: Optional[str] = ..., + maxColumns: Optional[int] = ..., + maxCharsPerColumn: Optional[int] = ..., + maxMalformedLogPerPartition: Optional[int] = ..., + mode: Optional[str] = ..., + columnNameOfCorruptRecord: Optional[str] = ..., + multiLine: Optional[Union[bool, str]] = ..., + charToEscapeQuoteEscaping: Optional[str] = ..., + samplingRatio: Optional[Union[float, str]] = ..., + enforceSchema: Optional[Union[bool, str]] = ..., + emptyValue: Optional[str] = ..., + locale: Optional[str] = ..., + lineSep: Optional[str] = ..., + ) -> DataFrame: ... + def orc( + self, + path: PathOrPaths, + mergeSchema: Optional[bool] = ..., + recursiveFileLookup: Optional[bool] = ..., + ) -> DataFrame: ... + @overload + def jdbc( + self, url: str, table: str, *, properties: Optional[Dict[str, str]] = ... + ) -> DataFrame: ... + @overload + def jdbc( + self, + url: str, + table: str, + column: str, + lowerBound: int, + upperBound: int, + numPartitions: int, + *, + properties: Optional[Dict[str, str]] = ... + ) -> DataFrame: ... + @overload + def jdbc( + self, + url: str, + table: str, + *, + predicates: List[str], + properties: Optional[Dict[str, str]] = ... + ) -> DataFrame: ... + +class DataFrameWriter(OptionUtils): + def __init__(self, df: DataFrame) -> None: ... 
+ def mode(self, saveMode: str) -> DataFrameWriter: ... + def format(self, source: str) -> DataFrameWriter: ... + def option(self, key: str, value: OptionalPrimitiveType) -> DataFrameWriter: ... + def options(self, **options: OptionalPrimitiveType) -> DataFrameWriter: ... + @overload + def partitionBy(self, *cols: str) -> DataFrameWriter: ... + @overload + def partitionBy(self, __cols: List[str]) -> DataFrameWriter: ... + @overload + def bucketBy(self, numBuckets: int, col: str, *cols: str) -> DataFrameWriter: ... + @overload + def bucketBy( + self, numBuckets: int, col: TupleOrListOfString + ) -> DataFrameWriter: ... + @overload + def sortBy(self, col: str, *cols: str) -> DataFrameWriter: ... + @overload + def sortBy(self, col: TupleOrListOfString) -> DataFrameWriter: ... + def save( + self, + path: Optional[str] = ..., + format: Optional[str] = ..., + mode: Optional[str] = ..., + partitionBy: Optional[List[str]] = ..., + **options: OptionalPrimitiveType + ) -> None: ... + def insertInto(self, tableName: str, overwrite: Optional[bool] = ...) -> None: ... + def saveAsTable( + self, + name: str, + format: Optional[str] = ..., + mode: Optional[str] = ..., + partitionBy: Optional[List[str]] = ..., + **options: OptionalPrimitiveType + ) -> None: ... + def json( + self, + path: str, + mode: Optional[str] = ..., + compression: Optional[str] = ..., + dateFormat: Optional[str] = ..., + timestampFormat: Optional[str] = ..., + lineSep: Optional[str] = ..., + encoding: Optional[str] = ..., + ignoreNullFields: Optional[bool] = ..., + ) -> None: ... + def parquet( + self, + path: str, + mode: Optional[str] = ..., + partitionBy: Optional[List[str]] = ..., + compression: Optional[str] = ..., + ) -> None: ... + def text( + self, path: str, compression: Optional[str] = ..., lineSep: Optional[str] = ... + ) -> None: ... + def csv( + self, + path: str, + mode: Optional[str] = ..., + compression: Optional[str] = ..., + sep: Optional[str] = ..., + quote: Optional[str] = ..., + escape: Optional[str] = ..., + header: Optional[Union[bool, str]] = ..., + nullValue: Optional[str] = ..., + escapeQuotes: Optional[Union[bool, str]] = ..., + quoteAll: Optional[Union[bool, str]] = ..., + dateFormat: Optional[str] = ..., + timestampFormat: Optional[str] = ..., + ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = ..., + ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = ..., + charToEscapeQuoteEscaping: Optional[str] = ..., + encoding: Optional[str] = ..., + emptyValue: Optional[str] = ..., + lineSep: Optional[str] = ..., + ) -> None: ... + def orc( + self, + path: str, + mode: Optional[str] = ..., + partitionBy: Optional[List[str]] = ..., + compression: Optional[str] = ..., + ) -> None: ... + def jdbc( + self, + url: str, + table: str, + mode: Optional[str] = ..., + properties: Optional[Dict[str, str]] = ..., + ) -> None: ... + +class DataFrameWriterV2: + def __init__(self, df: DataFrame, table: str) -> None: ... + def using(self, provider: str) -> DataFrameWriterV2: ... + def option(self, key: str, value: OptionalPrimitiveType) -> DataFrameWriterV2: ... + def options(self, **options: OptionalPrimitiveType) -> DataFrameWriterV2: ... + def tableProperty(self, property: str, value: str) -> DataFrameWriterV2: ... + def partitionedBy(self, col: Column, *cols: Column) -> DataFrameWriterV2: ... + def create(self) -> None: ... + def replace(self) -> None: ... + def createOrReplace(self) -> None: ... + def append(self) -> None: ... + def overwrite(self, condition: Column) -> None: ... 
+ def overwritePartitions(self) -> None: ... diff --git a/python/pyspark/sql/session.pyi b/python/pyspark/sql/session.pyi new file mode 100644 index 0000000000000..17ba8894c1731 --- /dev/null +++ b/python/pyspark/sql/session.pyi @@ -0,0 +1,125 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import overload +from typing import Any, Iterable, List, Optional, Tuple, TypeVar, Union + +from py4j.java_gateway import JavaObject # type: ignore[import] + +from pyspark.sql._typing import DateTimeLiteral, LiteralType, DecimalLiteral, RowLike +from pyspark.sql.pandas._typing import DataFrameLike +from pyspark.conf import SparkConf +from pyspark.context import SparkContext +from pyspark.rdd import RDD +from pyspark.sql.catalog import Catalog +from pyspark.sql.conf import RuntimeConfig +from pyspark.sql.dataframe import DataFrame +from pyspark.sql.pandas.conversion import SparkConversionMixin +from pyspark.sql.types import AtomicType, StructType +from pyspark.sql.readwriter import DataFrameReader +from pyspark.sql.streaming import DataStreamReader, StreamingQueryManager +from pyspark.sql.udf import UDFRegistration + +T = TypeVar("T") + +class SparkSession(SparkConversionMixin): + class Builder: + @overload + def config(self, *, conf: SparkConf) -> SparkSession.Builder: ... + @overload + def config(self, key: str, value: Any) -> SparkSession.Builder: ... + def master(self, master: str) -> SparkSession.Builder: ... + def appName(self, name: str) -> SparkSession.Builder: ... + def enableHiveSupport(self) -> SparkSession.Builder: ... + def getOrCreate(self) -> SparkSession: ... + builder: SparkSession.Builder + def __init__( + self, sparkContext: SparkContext, jsparkSession: Optional[JavaObject] = ... + ) -> None: ... + def newSession(self) -> SparkSession: ... + @classmethod + def getActiveSession(cls) -> SparkSession: ... + @property + def sparkContext(self) -> SparkContext: ... + @property + def version(self) -> str: ... + @property + def conf(self) -> RuntimeConfig: ... + @property + def catalog(self) -> Catalog: ... + @property + def udf(self) -> UDFRegistration: ... + def range( + self, + start: int, + end: Optional[int] = ..., + step: int = ..., + numPartitions: Optional[int] = ..., + ) -> DataFrame: ... + @overload + def createDataFrame( + self, + data: Union[RDD[RowLike], Iterable[RowLike]], + samplingRatio: Optional[float] = ..., + ) -> DataFrame: ... + @overload + def createDataFrame( + self, + data: Union[RDD[RowLike], Iterable[RowLike]], + schema: Union[List[str], Tuple[str, ...]] = ..., + verifySchema: bool = ..., + ) -> DataFrame: ... 
+ @overload + def createDataFrame( + self, + data: Union[ + RDD[Union[DateTimeLiteral, LiteralType, DecimalLiteral]], + Iterable[Union[DateTimeLiteral, LiteralType, DecimalLiteral]], + ], + schema: Union[AtomicType, str], + verifySchema: bool = ..., + ) -> DataFrame: ... + @overload + def createDataFrame( + self, + data: Union[RDD[RowLike], Iterable[RowLike]], + schema: Union[StructType, str], + verifySchema: bool = ..., + ) -> DataFrame: ... + @overload + def createDataFrame( + self, data: DataFrameLike, samplingRatio: Optional[float] = ... + ) -> DataFrame: ... + @overload + def createDataFrame( + self, + data: DataFrameLike, + schema: Union[StructType, str], + verifySchema: bool = ..., + ) -> DataFrame: ... + def sql(self, sqlQuery: str) -> DataFrame: ... + def table(self, tableName: str) -> DataFrame: ... + @property + def read(self) -> DataFrameReader: ... + @property + def readStream(self) -> DataStreamReader: ... + @property + def streams(self) -> StreamingQueryManager: ... + def stop(self) -> None: ... + def __enter__(self) -> SparkSession: ... + def __exit__(self, exc_type, exc_val, exc_tb) -> None: ... diff --git a/python/pyspark/sql/streaming.pyi b/python/pyspark/sql/streaming.pyi new file mode 100644 index 0000000000000..22055b2efc06b --- /dev/null +++ b/python/pyspark/sql/streaming.pyi @@ -0,0 +1,179 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import overload +from typing import Any, Callable, Dict, List, Optional, Union + +from pyspark.sql._typing import SupportsProcess, OptionalPrimitiveType +from pyspark.sql.context import SQLContext +from pyspark.sql.dataframe import DataFrame +from pyspark.sql.readwriter import OptionUtils +from pyspark.sql.types import Row, StructType +from pyspark.sql.utils import StreamingQueryException + +from py4j.java_gateway import JavaObject # type: ignore[import] + +class StreamingQuery: + def __init__(self, jsq: JavaObject) -> None: ... + @property + def id(self) -> str: ... + @property + def runId(self) -> str: ... + @property + def name(self) -> str: ... + @property + def isActive(self) -> bool: ... + def awaitTermination(self, timeout: Optional[int] = ...) -> Optional[bool]: ... + @property + def status(self) -> Dict[str, Any]: ... + @property + def recentProgress(self) -> List[Dict[str, Any]]: ... + @property + def lastProgress(self) -> Optional[Dict[str, Any]]: ... + def processAllAvailable(self) -> None: ... + def stop(self) -> None: ... + def explain(self, extended: bool = ...) -> None: ... + def exception(self) -> Optional[StreamingQueryException]: ... + +class StreamingQueryManager: + def __init__(self, jsqm: JavaObject) -> None: ... + @property + def active(self) -> List[StreamingQuery]: ... + def get(self, id: str) -> StreamingQuery: ... 
+ def awaitAnyTermination(self, timeout: Optional[int] = ...) -> bool: ... + def resetTerminated(self) -> None: ... + +class DataStreamReader(OptionUtils): + def __init__(self, spark: SQLContext) -> None: ... + def format(self, source: str) -> DataStreamReader: ... + def schema(self, schema: Union[StructType, str]) -> DataStreamReader: ... + def option(self, key: str, value: OptionalPrimitiveType) -> DataStreamReader: ... + def options(self, **options: OptionalPrimitiveType) -> DataStreamReader: ... + def load( + self, + path: Optional[str] = ..., + format: Optional[str] = ..., + schema: Optional[StructType] = ..., + **options: OptionalPrimitiveType + ) -> DataFrame: ... + def json( + self, + path: str, + schema: Optional[Union[StructType, str]] = ..., + primitivesAsString: Optional[Union[bool, str]] = ..., + prefersDecimal: Optional[Union[bool, str]] = ..., + allowComments: Optional[Union[bool, str]] = ..., + allowUnquotedFieldNames: Optional[Union[bool, str]] = ..., + allowSingleQuotes: Optional[Union[bool, str]] = ..., + allowNumericLeadingZero: Optional[Union[bool, str]] = ..., + allowBackslashEscapingAnyCharacter: Optional[Union[bool, str]] = ..., + mode: Optional[str] = ..., + columnNameOfCorruptRecord: Optional[str] = ..., + dateFormat: Optional[str] = ..., + timestampFormat: Optional[str] = ..., + multiLine: Optional[Union[bool, str]] = ..., + allowUnquotedControlChars: Optional[Union[bool, str]] = ..., + lineSep: Optional[str] = ..., + locale: Optional[str] = ..., + dropFieldIfAllNull: Optional[Union[bool, str]] = ..., + encoding: Optional[str] = ..., + recursiveFileLookup: Optional[bool] = ..., + ) -> DataFrame: ... + def orc( + self, + path: str, + mergeSchema: Optional[bool] = ..., + recursiveFileLookup: Optional[bool] = ..., + ) -> DataFrame: ... + def parquet( + self, + path: str, + mergeSchema: Optional[bool] = ..., + recursiveFileLookup: Optional[bool] = ..., + ) -> DataFrame: ... + def text( + self, + path: str, + wholetext: bool = ..., + lineSep: Optional[str] = ..., + recursiveFileLookup: Optional[bool] = ..., + ) -> DataFrame: ... + def csv( + self, + path: str, + schema: Optional[Union[StructType, str]] = ..., + sep: Optional[str] = ..., + encoding: Optional[str] = ..., + quote: Optional[str] = ..., + escape: Optional[str] = ..., + comment: Optional[str] = ..., + header: Optional[Union[bool, str]] = ..., + inferSchema: Optional[Union[bool, str]] = ..., + ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = ..., + ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = ..., + nullValue: Optional[str] = ..., + nanValue: Optional[str] = ..., + positiveInf: Optional[str] = ..., + negativeInf: Optional[str] = ..., + dateFormat: Optional[str] = ..., + timestampFormat: Optional[str] = ..., + maxColumns: Optional[Union[int, str]] = ..., + maxCharsPerColumn: Optional[Union[int, str]] = ..., + mode: Optional[str] = ..., + columnNameOfCorruptRecord: Optional[str] = ..., + multiLine: Optional[Union[bool, str]] = ..., + charToEscapeQuoteEscaping: Optional[Union[bool, str]] = ..., + enforceSchema: Optional[Union[bool, str]] = ..., + emptyValue: Optional[str] = ..., + locale: Optional[str] = ..., + lineSep: Optional[str] = ..., + ) -> DataFrame: ... + +class DataStreamWriter: + def __init__(self, df: DataFrame) -> None: ... + def outputMode(self, outputMode: str) -> DataStreamWriter: ... + def format(self, source: str) -> DataStreamWriter: ... + def option(self, key: str, value: OptionalPrimitiveType) -> DataStreamWriter: ... 
+ def options(self, **options: OptionalPrimitiveType) -> DataStreamWriter: ... + @overload + def partitionBy(self, *cols: str) -> DataStreamWriter: ... + @overload + def partitionBy(self, __cols: List[str]) -> DataStreamWriter: ... + def queryName(self, queryName: str) -> DataStreamWriter: ... + @overload + def trigger(self, processingTime: str) -> DataStreamWriter: ... + @overload + def trigger(self, once: bool) -> DataStreamWriter: ... + @overload + def trigger(self, continuous: bool) -> DataStreamWriter: ... + def start( + self, + path: Optional[str] = ..., + format: Optional[str] = ..., + outputMode: Optional[str] = ..., + partitionBy: Optional[Union[str, List[str]]] = ..., + queryName: Optional[str] = ..., + **options: OptionalPrimitiveType + ) -> StreamingQuery: ... + @overload + def foreach(self, f: Callable[[Row], None]) -> DataStreamWriter: ... + @overload + def foreach(self, f: SupportsProcess) -> DataStreamWriter: ... + def foreachBatch( + self, func: Callable[[DataFrame, int], None] + ) -> DataStreamWriter: ... diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index fb4f619c8bf63..c6497923d84fb 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -42,7 +42,7 @@ @unittest.skipIf( not have_pandas or not have_pyarrow, - pandas_requirement_message or pyarrow_requirement_message) + pandas_requirement_message or pyarrow_requirement_message) # type: ignore class ArrowTests(ReusedSQLTestCase): @classmethod @@ -465,7 +465,7 @@ def test_createDataFrame_empty_partition(self): @unittest.skipIf( not have_pandas or not have_pyarrow, - pandas_requirement_message or pyarrow_requirement_message) + pandas_requirement_message or pyarrow_requirement_message) # type: ignore class MaxResultArrowTests(unittest.TestCase): # These tests are separate as 'spark.driver.maxResultSize' configuration # is a static configuration to Spark context. 
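For context on the recurring `# type: ignore` additions to the `unittest.skipIf` decorators in these test modules: the optional-dependency requirement messages exposed by the test helpers are `Optional[str]`, while `unittest.skipIf` annotates its `reason` argument as `str`, so mypy reports an `[arg-type]` error at each call site. A minimal, self-contained sketch of the pattern (the names below are illustrative and not part of the patch):

```
# Hypothetical reproduction of the mypy [arg-type] error silenced above.
from typing import Optional
import unittest

# None when the optional dependency is importable, an explanation string otherwise.
pandas_requirement_message: Optional[str] = None

class ExampleTests(unittest.TestCase):
    @unittest.skipIf(
        pandas_requirement_message is not None,
        pandas_requirement_message)  # type: ignore[arg-type]
    def test_needs_pandas(self) -> None:
        self.assertTrue(True)
```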
@@ -500,7 +500,7 @@ def conf(cls): from pyspark.sql.tests.test_arrow import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/tests/test_catalog.py b/python/pyspark/sql/tests/test_catalog.py index 141b249db0fc6..ca4e427a7db28 100644 --- a/python/pyspark/sql/tests/test_catalog.py +++ b/python/pyspark/sql/tests/test_catalog.py @@ -206,7 +206,7 @@ def test_list_columns(self): from pyspark.sql.tests.test_catalog import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index 8a89e6e9d5599..7e03e2ef3e6d0 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -161,7 +161,7 @@ def test_with_field(self): from pyspark.sql.tests.test_column import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/tests/test_conf.py b/python/pyspark/sql/tests/test_conf.py index dd2e0be85d508..1cc0c1b7562c5 100644 --- a/python/pyspark/sql/tests/test_conf.py +++ b/python/pyspark/sql/tests/test_conf.py @@ -49,7 +49,7 @@ def test_conf(self): from pyspark.sql.tests.test_conf import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/tests/test_context.py b/python/pyspark/sql/tests/test_context.py index ce22a52dc119e..d506908b784db 100644 --- a/python/pyspark/sql/tests/test_context.py +++ b/python/pyspark/sql/tests/test_context.py @@ -276,7 +276,7 @@ def test_get_or_create(self): from pyspark.sql.tests.test_context import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index d03939821a176..d941707b8969f 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -518,7 +518,7 @@ def _to_pandas(self): df = self.spark.createDataFrame(data, schema) return df.toPandas() - @unittest.skipIf(not have_pandas, pandas_requirement_message) + @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore def test_to_pandas(self): import numpy as np pdf = self._to_pandas() @@ -530,7 +530,7 @@ def test_to_pandas(self): self.assertEquals(types[4], np.object) # datetime.date self.assertEquals(types[5], 'datetime64[ns]') - @unittest.skipIf(not have_pandas, pandas_requirement_message) + @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore def test_to_pandas_with_duplicated_column_names(self): import numpy as np @@ -543,7 +543,7 @@ def test_to_pandas_with_duplicated_column_names(self): self.assertEquals(types.iloc[0], np.int32) self.assertEquals(types.iloc[1], np.int32) - @unittest.skipIf(not have_pandas, pandas_requirement_message) + @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: 
ignore def test_to_pandas_on_cross_join(self): import numpy as np @@ -569,7 +569,7 @@ def test_to_pandas_required_pandas_not_found(self): with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'): self._to_pandas() - @unittest.skipIf(not have_pandas, pandas_requirement_message) + @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore def test_to_pandas_avoid_astype(self): import numpy as np schema = StructType().add("a", IntegerType()).add("b", StringType())\ @@ -581,7 +581,7 @@ def test_to_pandas_avoid_astype(self): self.assertEquals(types[1], np.object) self.assertEquals(types[2], np.float64) - @unittest.skipIf(not have_pandas, pandas_requirement_message) + @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore def test_to_pandas_from_empty_dataframe(self): with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}): # SPARK-29188 test that toPandas() on an empty dataframe has the correct dtypes @@ -601,7 +601,7 @@ def test_to_pandas_from_empty_dataframe(self): dtypes_when_empty_df = self.spark.sql(sql).filter("False").toPandas().dtypes self.assertTrue(np.all(dtypes_when_empty_df == dtypes_when_nonempty_df)) - @unittest.skipIf(not have_pandas, pandas_requirement_message) + @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore def test_to_pandas_from_null_dataframe(self): with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}): # SPARK-29188 test that toPandas() on a dataframe with only nulls has correct dtypes @@ -629,7 +629,7 @@ def test_to_pandas_from_null_dataframe(self): self.assertEqual(types[7], np.object) self.assertTrue(np.can_cast(np.datetime64, types[8])) - @unittest.skipIf(not have_pandas, pandas_requirement_message) + @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore def test_to_pandas_from_mixed_dataframe(self): with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}): # SPARK-29188 test that toPandas() on a dataframe with some nulls has correct dtypes @@ -657,7 +657,7 @@ def test_create_dataframe_from_array_of_long(self): df = self.spark.createDataFrame(data) self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807])) - @unittest.skipIf(not have_pandas, pandas_requirement_message) + @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore def test_create_dataframe_from_pandas_with_timestamp(self): import pandas as pd from datetime import datetime @@ -685,7 +685,7 @@ def test_create_dataframe_required_pandas_not_found(self): self.spark.createDataFrame(pdf) # Regression test for SPARK-23360 - @unittest.skipIf(not have_pandas, pandas_requirement_message) + @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore def test_create_dataframe_from_pandas_with_dst(self): import pandas as pd from pandas.util.testing import assert_frame_equal @@ -889,7 +889,7 @@ def test_query_execution_listener_on_collect(self): @unittest.skipIf( not have_pandas or not have_pyarrow, - pandas_requirement_message or pyarrow_requirement_message) + pandas_requirement_message or pyarrow_requirement_message) # type: ignore def test_query_execution_listener_on_collect_with_arrow(self): with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}): self.assertFalse( @@ -907,7 +907,7 @@ def test_query_execution_listener_on_collect_with_arrow(self): from pyspark.sql.tests.test_dataframe import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore 
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/tests/test_datasources.py b/python/pyspark/sql/tests/test_datasources.py index dfef8f5740050..9425494fb0d90 100644 --- a/python/pyspark/sql/tests/test_datasources.py +++ b/python/pyspark/sql/tests/test_datasources.py @@ -164,7 +164,7 @@ def test_ignore_column_of_all_nulls(self): from pyspark.sql.tests.test_datasources import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 09f5960c6f648..5638cad51b755 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -396,7 +396,7 @@ def test_higher_order_function_failures(self): from pyspark.sql.tests.test_functions import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/tests/test_group.py b/python/pyspark/sql/tests/test_group.py index 2fab7a08da1da..324c964f4f0cf 100644 --- a/python/pyspark/sql/tests/test_group.py +++ b/python/pyspark/sql/tests/test_group.py @@ -39,7 +39,7 @@ def test_aggregator(self): from pyspark.sql.tests.test_group import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/tests/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/test_pandas_cogrouped_map.py index 5013e2d4d6bd9..f9a7dd69b61fb 100644 --- a/python/pyspark/sql/tests/test_pandas_cogrouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_cogrouped_map.py @@ -33,7 +33,7 @@ @unittest.skipIf( not have_pandas or not have_pyarrow, - pandas_requirement_message or pyarrow_requirement_message) + pandas_requirement_message or pyarrow_requirement_message) # type: ignore[arg-type] class CogroupedMapInPandasTests(ReusedSQLTestCase): @property @@ -247,7 +247,7 @@ def merge_pandas(l, r): from pyspark.sql.tests.test_pandas_cogrouped_map import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/tests/test_pandas_grouped_map.py b/python/pyspark/sql/tests/test_pandas_grouped_map.py index 6eb5355044bb0..81b6d5efb710a 100644 --- a/python/pyspark/sql/tests/test_pandas_grouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_grouped_map.py @@ -41,7 +41,7 @@ @unittest.skipIf( not have_pandas or not have_pyarrow, - pandas_requirement_message or pyarrow_requirement_message) + pandas_requirement_message or pyarrow_requirement_message) # type: ignore[arg-type] class GroupedMapInPandasTests(ReusedSQLTestCase): @property @@ -611,7 +611,7 @@ def my_pandas_udf(pdf): from pyspark.sql.tests.test_pandas_grouped_map import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/tests/test_pandas_map.py 
b/python/pyspark/sql/tests/test_pandas_map.py index bda370dffbf6a..3ca437f75fc23 100644 --- a/python/pyspark/sql/tests/test_pandas_map.py +++ b/python/pyspark/sql/tests/test_pandas_map.py @@ -27,7 +27,7 @@ @unittest.skipIf( not have_pandas or not have_pyarrow, - pandas_requirement_message or pyarrow_requirement_message) + pandas_requirement_message or pyarrow_requirement_message) # type: ignore[arg-type] class MapInPandasTests(ReusedSQLTestCase): @classmethod @@ -117,7 +117,7 @@ def func(iterator): from pyspark.sql.tests.test_pandas_map import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/tests/test_pandas_udf.py b/python/pyspark/sql/tests/test_pandas_udf.py index 24b98182b7fcf..cc742fc4267cb 100644 --- a/python/pyspark/sql/tests/test_pandas_udf.py +++ b/python/pyspark/sql/tests/test_pandas_udf.py @@ -28,7 +28,7 @@ @unittest.skipIf( not have_pandas or not have_pyarrow, - pandas_requirement_message or pyarrow_requirement_message) + pandas_requirement_message or pyarrow_requirement_message) # type: ignore[arg-type] class PandasUDFTests(ReusedSQLTestCase): def test_pandas_udf_basic(self): @@ -244,7 +244,7 @@ def udf(column): from pyspark.sql.tests.test_pandas_udf import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py index f63f52239fdf2..451308927629b 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py +++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py @@ -35,7 +35,7 @@ @unittest.skipIf( not have_pandas or not have_pyarrow, - pandas_requirement_message or pyarrow_requirement_message) + pandas_requirement_message or pyarrow_requirement_message) # type: ignore[arg-type] class GroupedAggPandasUDFTests(ReusedSQLTestCase): @property @@ -514,7 +514,7 @@ def mean(x): from pyspark.sql.tests.test_pandas_udf_grouped_agg import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index 522807b03af70..6d325c9085ce1 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -46,7 +46,7 @@ @unittest.skipIf( not have_pandas or not have_pyarrow, - pandas_requirement_message or pyarrow_requirement_message) + pandas_requirement_message or pyarrow_requirement_message) # type: ignore class ScalarPandasUDFTests(ReusedSQLTestCase): @classmethod @@ -1095,7 +1095,7 @@ def f3i(it): self.assertEquals(expected, df1.collect()) # SPARK-24721 - @unittest.skipIf(not test_compiled, test_not_compiled_message) + @unittest.skipIf(not test_compiled, test_not_compiled_message) # type: ignore def test_datasource_with_udf(self): # Same as SQLTests.test_datasource_with_udf, but with Pandas UDF # This needs to a separate test because Arrow dependency is optional @@ -1142,7 +1142,7 @@ def test_datasource_with_udf(self): from pyspark.sql.tests.test_pandas_udf_scalar import * # noqa: F401 try: - import xmlrunner + import xmlrunner # 
type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/tests/test_pandas_udf_typehints.py b/python/pyspark/sql/tests/test_pandas_udf_typehints.py index 7be81f82808e4..d9717da4d2fbd 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_typehints.py +++ b/python/pyspark/sql/tests/test_pandas_udf_typehints.py @@ -34,7 +34,7 @@ @unittest.skipIf( not have_pandas or not have_pyarrow, - pandas_requirement_message or pyarrow_requirement_message) + pandas_requirement_message or pyarrow_requirement_message) # type: ignore[arg-type] class PandasUDFTypeHintsTests(ReusedSQLTestCase): def test_type_annotation_scalar(self): def func(col: pd.Series) -> pd.Series: @@ -246,7 +246,7 @@ def pandas_plus_one(iter: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]: from pyspark.sql.tests.test_pandas_udf_typehints import * # noqa: #401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/tests/test_pandas_udf_window.py b/python/pyspark/sql/tests/test_pandas_udf_window.py index 6e59255da13a2..5ad2ecd8f85d4 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_window.py +++ b/python/pyspark/sql/tests/test_pandas_udf_window.py @@ -31,7 +31,7 @@ @unittest.skipIf( not have_pandas or not have_pyarrow, - pandas_requirement_message or pyarrow_requirement_message) + pandas_requirement_message or pyarrow_requirement_message) # type: ignore[arg-type] class WindowPandasUDFTests(ReusedSQLTestCase): @property def data(self): @@ -355,7 +355,7 @@ def test_bounded_mixed(self): from pyspark.sql.tests.test_pandas_udf_window import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py index 55ffefc43c105..80b4118ae796a 100644 --- a/python/pyspark/sql/tests/test_readwriter.py +++ b/python/pyspark/sql/tests/test_readwriter.py @@ -204,7 +204,7 @@ def test_partitioning_functions(self): from pyspark.sql.tests.test_readwriter import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/tests/test_serde.py b/python/pyspark/sql/tests/test_serde.py index 35c14e430af50..ce087ff4ce550 100644 --- a/python/pyspark/sql/tests/test_serde.py +++ b/python/pyspark/sql/tests/test_serde.py @@ -142,7 +142,7 @@ def test_bytes_as_binary_type(self): from pyspark.sql.tests.test_serde import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py index d10f7bf906c3b..7faeb1857b983 100644 --- a/python/pyspark/sql/tests/test_session.py +++ b/python/pyspark/sql/tests/test_session.py @@ -361,7 +361,7 @@ def test_use_custom_class_for_extensions(self): from pyspark.sql.tests.test_session import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', 
verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/tests/test_streaming.py b/python/pyspark/sql/tests/test_streaming.py index 21ce04618a904..28a50f9575a0a 100644 --- a/python/pyspark/sql/tests/test_streaming.py +++ b/python/pyspark/sql/tests/test_streaming.py @@ -575,7 +575,7 @@ def collectBatch(df, id): from pyspark.sql.tests.test_streaming import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 7256db055fb9c..e85e8a6e6d1ee 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -25,12 +25,15 @@ import unittest from pyspark.sql import Row -from pyspark.sql.functions import col, UserDefinedFunction +from pyspark.sql.functions import col +from pyspark.sql.udf import UserDefinedFunction from pyspark.sql.types import ByteType, ShortType, IntegerType, FloatType, DateType, \ TimestampType, MapType, StringType, StructType, StructField, ArrayType, DoubleType, LongType, \ DecimalType, BinaryType, BooleanType, NullType -from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings, \ +from pyspark.sql.types import ( # type: ignore + _array_signed_int_typecode_ctype_mappings, _array_type_mappings, _array_unsigned_int_typecode_ctype_mappings, _infer_type, _make_type_verifier, _merge_type +) from pyspark.testing.sqlutils import ReusedSQLTestCase, ExamplePointUDT, PythonOnlyUDT, \ ExamplePoint, PythonOnlyPoint, MyObject @@ -974,7 +977,7 @@ def test_row_without_field_sorting(self): from pyspark.sql.tests.test_types import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py index ad94bc83cc5be..a7dcbfd32ac1c 100644 --- a/python/pyspark/sql/tests/test_udf.py +++ b/python/pyspark/sql/tests/test_udf.py @@ -23,7 +23,8 @@ from pyspark import SparkContext from pyspark.sql import SparkSession, Column, Row -from pyspark.sql.functions import UserDefinedFunction, udf +from pyspark.sql.functions import udf +from pyspark.sql.udf import UserDefinedFunction from pyspark.sql.types import StringType, IntegerType, BooleanType, DoubleType, LongType, \ ArrayType, StructType, StructField from pyspark.sql.utils import AnalysisException @@ -356,7 +357,7 @@ def test_udf_registration_returns_udf(self): df.select(add_four("id").alias("plus_four")).collect() ) - @unittest.skipIf(not test_compiled, test_not_compiled_message) + @unittest.skipIf(not test_compiled, test_not_compiled_message) # type: ignore def test_register_java_function(self): self.spark.udf.registerJavaFunction( "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType()) @@ -373,7 +374,7 @@ def test_register_java_function(self): [value] = self.spark.sql("SELECT javaStringLength3('test')").first() self.assertEqual(value, 4) - @unittest.skipIf(not test_compiled, test_not_compiled_message) + @unittest.skipIf(not test_compiled, test_not_compiled_message) # type: ignore def test_register_java_udaf(self): self.spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg") df = self.spark.createDataFrame([(1, "a"), (2, "b"), (3, "a")], ["id", "name"]) 
@@ -560,7 +561,7 @@ def test_nonparam_udf_with_aggregate(self): self.assertEqual(rows, [Row(_1=1, _2=2, a=u'const_str')]) # SPARK-24721 - @unittest.skipIf(not test_compiled, test_not_compiled_message) + @unittest.skipIf(not test_compiled, test_not_compiled_message) # type: ignore def test_datasource_with_udf(self): from pyspark.sql.functions import lit, col @@ -699,7 +700,7 @@ def test_udf_init_shouldnt_initialize_context(self): from pyspark.sql.tests.test_udf import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py index c6e7fcd8ec11a..b08e17208d8af 100644 --- a/python/pyspark/sql/tests/test_utils.py +++ b/python/pyspark/sql/tests/test_utils.py @@ -55,7 +55,7 @@ def test_capture_illegalargument_exception(self): from pyspark.sql.tests.test_utils import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/sql/types.pyi b/python/pyspark/sql/types.pyi new file mode 100644 index 0000000000000..31765e94884d7 --- /dev/null +++ b/python/pyspark/sql/types.pyi @@ -0,0 +1,204 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import overload +from typing import Any, Callable, Dict, Iterator, List, Optional, Union, Tuple, TypeVar +import datetime + +T = TypeVar("T") +U = TypeVar("U") + +class DataType: + def __hash__(self) -> int: ... + def __eq__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... + @classmethod + def typeName(cls) -> str: ... + def simpleString(self) -> str: ... + def jsonValue(self) -> Union[str, Dict[str, Any]]: ... + def json(self) -> str: ... + def needConversion(self) -> bool: ... + def toInternal(self, obj: Any) -> Any: ... + def fromInternal(self, obj: Any) -> Any: ... + +class DataTypeSingleton(type): + def __call__(cls): ... + +class NullType(DataType, metaclass=DataTypeSingleton): ... +class AtomicType(DataType): ... +class NumericType(AtomicType): ... +class IntegralType(NumericType, metaclass=DataTypeSingleton): ... +class FractionalType(NumericType): ... +class StringType(AtomicType, metaclass=DataTypeSingleton): ... +class BinaryType(AtomicType, metaclass=DataTypeSingleton): ... +class BooleanType(AtomicType, metaclass=DataTypeSingleton): ... + +class DateType(AtomicType, metaclass=DataTypeSingleton): + EPOCH_ORDINAL: int + def needConversion(self) -> bool: ... + def toInternal(self, d: datetime.date) -> int: ... + def fromInternal(self, v: int) -> datetime.date: ... 
+
+class TimestampType(AtomicType, metaclass=DataTypeSingleton):
+ def needConversion(self) -> bool: ...
+ def toInternal(self, dt: datetime.datetime) -> int: ...
+ def fromInternal(self, ts: int) -> datetime.datetime: ...
+
+class DecimalType(FractionalType):
+ precision: int
+ scale: int
+ hasPrecisionInfo: bool
+ def __init__(self, precision: int = ..., scale: int = ...) -> None: ...
+ def simpleString(self) -> str: ...
+ def jsonValue(self) -> str: ...
+
+class DoubleType(FractionalType, metaclass=DataTypeSingleton): ...
+class FloatType(FractionalType, metaclass=DataTypeSingleton): ...
+
+class ByteType(IntegralType):
+ def simpleString(self) -> str: ...
+
+class IntegerType(IntegralType):
+ def simpleString(self) -> str: ...
+
+class LongType(IntegralType):
+ def simpleString(self) -> str: ...
+
+class ShortType(IntegralType):
+ def simpleString(self) -> str: ...
+
+class ArrayType(DataType):
+ elementType: DataType
+ containsNull: bool
+ def __init__(self, elementType: DataType, containsNull: bool = ...) -> None: ...
+ def simpleString(self) -> str: ...
+ def jsonValue(self) -> Dict[str, Any]: ...
+ @classmethod
+ def fromJson(cls, json: Dict[str, Any]) -> ArrayType: ...
+ def needConversion(self) -> bool: ...
+ def toInternal(self, obj: List[Optional[T]]) -> List[Optional[T]]: ...
+ def fromInternal(self, obj: List[Optional[T]]) -> List[Optional[T]]: ...
+
+class MapType(DataType):
+ keyType: DataType
+ valueType: DataType
+ valueContainsNull: bool
+ def __init__(
+ self, keyType: DataType, valueType: DataType, valueContainsNull: bool = ...
+ ) -> None: ...
+ def simpleString(self) -> str: ...
+ def jsonValue(self) -> Dict[str, Any]: ...
+ @classmethod
+ def fromJson(cls, json: Dict[str, Any]) -> MapType: ...
+ def needConversion(self) -> bool: ...
+ def toInternal(self, obj: Dict[T, Optional[U]]) -> Dict[T, Optional[U]]: ...
+ def fromInternal(self, obj: Dict[T, Optional[U]]) -> Dict[T, Optional[U]]: ...
+
+class StructField(DataType):
+ name: str
+ dataType: DataType
+ nullable: bool
+ metadata: Dict[str, Any]
+ def __init__(
+ self,
+ name: str,
+ dataType: DataType,
+ nullable: bool = ...,
+ metadata: Optional[Dict[str, Any]] = ...,
+ ) -> None: ...
+ def simpleString(self) -> str: ...
+ def jsonValue(self) -> Dict[str, Any]: ...
+ @classmethod
+ def fromJson(cls, json: Dict[str, Any]) -> StructField: ...
+ def needConversion(self) -> bool: ...
+ def toInternal(self, obj: T) -> T: ...
+ def fromInternal(self, obj: T) -> T: ...
+
+class StructType(DataType):
+ fields: List[StructField]
+ names: List[str]
+ def __init__(self, fields: Optional[List[StructField]] = ...) -> None: ...
+ @overload
+ def add(
+ self,
+ field: str,
+ data_type: Union[str, DataType],
+ nullable: bool = ...,
+ metadata: Optional[Dict[str, Any]] = ...,
+ ) -> StructType: ...
+ @overload
+ def add(self, field: StructField) -> StructType: ...
+ def __iter__(self) -> Iterator[StructField]: ...
+ def __len__(self) -> int: ...
+ def __getitem__(self, key: Union[str, int]) -> StructField: ...
+ def simpleString(self) -> str: ...
+ def jsonValue(self) -> Dict[str, Any]: ...
+ @classmethod
+ def fromJson(cls, json: Dict[str, Any]) -> StructType: ...
+ def fieldNames(self) -> List[str]: ...
+ def needConversion(self) -> bool: ...
+ def toInternal(self, obj: Tuple) -> Tuple: ...
+ def fromInternal(self, obj: Tuple) -> Row: ...
+
+class UserDefinedType(DataType):
+ @classmethod
+ def typeName(cls) -> str: ...
+ @classmethod
+ def sqlType(cls) -> DataType: ...
+ @classmethod
+ def module(cls) -> str: ...
+ @classmethod + def scalaUDT(cls) -> str: ... + def needConversion(self) -> bool: ... + def toInternal(self, obj: Any) -> Any: ... + def fromInternal(self, obj: Any) -> Any: ... + def serialize(self, obj: Any) -> Any: ... + def deserialize(self, datum: Any) -> Any: ... + def simpleString(self) -> str: ... + def json(self) -> str: ... + def jsonValue(self) -> Dict[str, Any]: ... + @classmethod + def fromJson(cls, json: Dict[str, Any]) -> UserDefinedType: ... + def __eq__(self, other: Any) -> bool: ... + +class Row(tuple): + @overload + def __new__(self, *args: str) -> Row: ... + @overload + def __new__(self, **kwargs: Any) -> Row: ... + @overload + def __init__(self, *args: str) -> None: ... + @overload + def __init__(self, **kwargs: Any) -> None: ... + def asDict(self, recursive: bool = ...) -> Dict[str, Any]: ... + def __contains__(self, item: Any) -> bool: ... + def __call__(self, *args: Any) -> Row: ... + def __getitem__(self, item: Any) -> Any: ... + def __getattr__(self, item: str) -> Any: ... + def __setattr__(self, key: Any, value: Any) -> None: ... + def __reduce__( + self, + ) -> Tuple[Callable[[List[str], List[Any]], Row], Tuple[List[str], Tuple]]: ... + +class DateConverter: + def can_convert(self, obj: Any) -> bool: ... + def convert(self, obj, gateway_client) -> Any: ... + +class DatetimeConverter: + def can_convert(self, obj) -> bool: ... + def convert(self, obj, gateway_client) -> Any: ... diff --git a/python/pyspark/sql/udf.pyi b/python/pyspark/sql/udf.pyi new file mode 100644 index 0000000000000..87c3672780037 --- /dev/null +++ b/python/pyspark/sql/udf.pyi @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Callable, Optional + +from pyspark.sql._typing import ColumnOrName, DataTypeOrString +from pyspark.sql.column import Column +import pyspark.sql.session + +class UserDefinedFunction: + func: Callable[..., Any] + evalType: int + deterministic: bool + def __init__( + self, + func: Callable[..., Any], + returnType: DataTypeOrString = ..., + name: Optional[str] = ..., + evalType: int = ..., + deterministic: bool = ..., + ) -> None: ... + @property + def returnType(self): ... + def __call__(self, *cols: ColumnOrName) -> Column: ... + def asNondeterministic(self) -> UserDefinedFunction: ... + +class UDFRegistration: + sparkSession: pyspark.sql.session.SparkSession + def __init__(self, sparkSession: pyspark.sql.session.SparkSession) -> None: ... + def register( + self, + name: str, + f: Callable[..., Any], + returnType: Optional[DataTypeOrString] = ..., + ): ... + def registerJavaFunction( + self, + name: str, + javaClassName: str, + returnType: Optional[DataTypeOrString] = ..., + ) -> None: ... + def registerJavaUDAF(self, name: str, javaClassName: str) -> None: ... 
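As a usage illustration for the `udf.pyi` stub above (a sketch, not part of the patch): `UserDefinedFunction.__call__` accepts `ColumnOrName` and returns `Column`, and `UDFRegistration.register` covers SQL-side registration, so code like the following type-checks under the new annotations, assuming an active `SparkSession`:

```
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.getOrCreate()

# udf(...) yields a UserDefinedFunction; calling it on a column name is
# typed as returning Column, so select() accepts the result.
plus_one = udf(lambda x: x + 1, IntegerType())
spark.range(5).select(plus_one("id").alias("id_plus_one")).show()

# Registration makes the function available from SQL as well.
spark.udf.register("plus_one_sql", lambda x: x + 1, IntegerType())
spark.sql("SELECT plus_one_sql(id) FROM range(5)").show()
```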
diff --git a/python/pyspark/sql/utils.pyi b/python/pyspark/sql/utils.pyi new file mode 100644 index 0000000000000..c11e4bed54e7f --- /dev/null +++ b/python/pyspark/sql/utils.pyi @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE: This dynamically typed stub was automatically generated by stubgen. + +from pyspark import SparkContext as SparkContext # noqa: F401 +from typing import Any, Optional + +class CapturedException(Exception): + desc: Any = ... + stackTrace: Any = ... + cause: Any = ... + def __init__( + self, desc: Any, stackTrace: Any, cause: Optional[Any] = ... + ) -> None: ... + +class AnalysisException(CapturedException): ... +class ParseException(CapturedException): ... +class IllegalArgumentException(CapturedException): ... +class StreamingQueryException(CapturedException): ... +class QueryExecutionException(CapturedException): ... +class PythonException(CapturedException): ... +class UnknownException(CapturedException): ... + +def convert_exception(e: Any): ... +def capture_sql_exception(f: Any): ... +def install_exception_handler() -> None: ... +def toJArray(gateway: Any, jtype: Any, arr: Any): ... +def require_test_compiled() -> None: ... + +class ForeachBatchFunction: + sql_ctx: Any = ... + func: Any = ... + def __init__(self, sql_ctx: Any, func: Any) -> None: ... + error: Any = ... + def call(self, jdf: Any, batch_id: Any) -> None: ... + class Java: + implements: Any = ... + +def to_str(value: Any): ... diff --git a/python/pyspark/sql/window.pyi b/python/pyspark/sql/window.pyi new file mode 100644 index 0000000000000..4e31d57bec4d0 --- /dev/null +++ b/python/pyspark/sql/window.pyi @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pyspark.sql._typing import ColumnOrName +from py4j.java_gateway import JavaObject # type: ignore[import] + +class Window: + unboundedPreceding: int + unboundedFollowing: int + currentRow: int + @staticmethod + def partitionBy(*cols: ColumnOrName) -> WindowSpec: ... 
+ @staticmethod + def orderBy(*cols: ColumnOrName) -> WindowSpec: ... + @staticmethod + def rowsBetween(start: int, end: int) -> WindowSpec: ... + @staticmethod + def rangeBetween(start: int, end: int) -> WindowSpec: ... + +class WindowSpec: + def __init__(self, jspec: JavaObject) -> None: ... + def partitionBy(self, *cols: ColumnOrName) -> WindowSpec: ... + def orderBy(self, *cols: ColumnOrName) -> WindowSpec: ... + def rowsBetween(self, start: int, end: int) -> WindowSpec: ... + def rangeBetween(self, start: int, end: int) -> WindowSpec: ... diff --git a/python/pyspark/statcounter.pyi b/python/pyspark/statcounter.pyi new file mode 100644 index 0000000000000..38e5970501527 --- /dev/null +++ b/python/pyspark/statcounter.pyi @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, Iterable, Optional, Union + +maximum: Any +minimum: Any +sqrt: Any + +class StatCounter: + n: int + mu: float + m2: float + maxValue: float + minValue: float + def __init__(self, values: Optional[Iterable[float]] = ...) -> None: ... + def merge(self, value: float) -> StatCounter: ... + def mergeStats(self, other: StatCounter) -> StatCounter: ... + def copy(self) -> StatCounter: ... + def count(self) -> int: ... + def mean(self) -> float: ... + def sum(self) -> float: ... + def min(self) -> float: ... + def max(self) -> float: ... + def variance(self) -> float: ... + def sampleVariance(self) -> float: ... + def stdev(self) -> float: ... + def sampleStdev(self) -> float: ... + def asDict(self, sample: bool = ...) -> Dict[str, Union[float, int]]: ... diff --git a/python/pyspark/status.pyi b/python/pyspark/status.pyi new file mode 100644 index 0000000000000..0558e245f49cc --- /dev/null +++ b/python/pyspark/status.pyi @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import List, NamedTuple, Optional +from py4j.java_gateway import JavaArray, JavaObject # type: ignore[import] + +class SparkJobInfo(NamedTuple): + jobId: int + stageIds: JavaArray + status: str + +class SparkStageInfo(NamedTuple): + stageId: int + currentAttemptId: int + name: str + numTasks: int + numActiveTasks: int + numCompletedTasks: int + numFailedTasks: int + +class StatusTracker: + def __init__(self, jtracker: JavaObject) -> None: ... + def getJobIdsForGroup(self, jobGroup: Optional[str] = ...) -> List[int]: ... + def getActiveStageIds(self) -> List[int]: ... + def getActiveJobsIds(self) -> List[int]: ... + def getJobInfo(self, jobId: int) -> SparkJobInfo: ... + def getStageInfo(self, stageId: int) -> SparkStageInfo: ... diff --git a/python/pyspark/storagelevel.pyi b/python/pyspark/storagelevel.pyi new file mode 100644 index 0000000000000..2eb05850bae78 --- /dev/null +++ b/python/pyspark/storagelevel.pyi @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import ClassVar + +class StorageLevel: + DISK_ONLY: ClassVar[StorageLevel] + DISK_ONLY_2: ClassVar[StorageLevel] + MEMORY_ONLY: ClassVar[StorageLevel] + MEMORY_ONLY_2: ClassVar[StorageLevel] + DISK_ONLY_3: ClassVar[StorageLevel] + MEMORY_AND_DISK: ClassVar[StorageLevel] + MEMORY_AND_DISK_2: ClassVar[StorageLevel] + OFF_HEAP: ClassVar[StorageLevel] + + useDisk: bool + useMemory: bool + useOffHeap: bool + deserialized: bool + replication: int + def __init__( + self, + useDisk: bool, + useMemory: bool, + useOffHeap: bool, + deserialized: bool, + replication: int = ..., + ) -> None: ... diff --git a/python/pyspark/streaming/__init__.pyi b/python/pyspark/streaming/__init__.pyi new file mode 100644 index 0000000000000..281c06e51cc60 --- /dev/null +++ b/python/pyspark/streaming/__init__.pyi @@ -0,0 +1,23 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from pyspark.streaming.context import StreamingContext as StreamingContext # noqa: F401 +from pyspark.streaming.dstream import DStream as DStream # noqa: F401 +from pyspark.streaming.listener import ( # noqa: F401 + StreamingListener as StreamingListener, +) diff --git a/python/pyspark/streaming/context.pyi b/python/pyspark/streaming/context.pyi new file mode 100644 index 0000000000000..f4b3dad38f1fb --- /dev/null +++ b/python/pyspark/streaming/context.pyi @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Callable, List, Optional, TypeVar, Union + +from py4j.java_gateway import JavaObject # type: ignore[import] + +from pyspark.context import SparkContext +from pyspark.rdd import RDD +from pyspark.storagelevel import StorageLevel +from pyspark.streaming.dstream import DStream +from pyspark.streaming.listener import StreamingListener + +T = TypeVar("T") + +class StreamingContext: + def __init__( + self, + sparkContext: SparkContext, + batchDuration: Union[float, int] = ..., + jssc: Optional[JavaObject] = ..., + ) -> None: ... + @classmethod + def getOrCreate( + cls, checkpointPath: str, setupFunc: Callable[[], StreamingContext] + ) -> StreamingContext: ... + @classmethod + def getActive(cls) -> StreamingContext: ... + @classmethod + def getActiveOrCreate( + cls, checkpointPath: str, setupFunc: Callable[[], StreamingContext] + ) -> StreamingContext: ... + @property + def sparkContext(self) -> SparkContext: ... + def start(self) -> None: ... + def awaitTermination(self, timeout: Optional[int] = ...) -> None: ... + def awaitTerminationOrTimeout(self, timeout: int) -> None: ... + def stop( + self, stopSparkContext: bool = ..., stopGraceFully: bool = ... + ) -> None: ... + def remember(self, duration: int) -> None: ... + def checkpoint(self, directory: str) -> None: ... + def socketTextStream( + self, hostname: str, port: int, storageLevel: StorageLevel = ... + ) -> DStream[str]: ... + def textFileStream(self, directory: str) -> DStream[str]: ... + def binaryRecordsStream( + self, directory: str, recordLength: int + ) -> DStream[bytes]: ... + def queueStream( + self, + rdds: List[RDD[T]], + oneAtATime: bool = ..., + default: Optional[RDD[T]] = ..., + ) -> DStream[T]: ... + def transform( + self, dstreams: List[DStream[Any]], transformFunc: Callable[..., RDD[T]] + ) -> DStream[T]: ... + def union(self, *dstreams: DStream[T]) -> DStream[T]: ... + def addStreamingListener(self, streamingListener: StreamingListener) -> None: ... 
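To show how the `StreamingContext` annotations above compose with the `DStream` stub that follows (a sketch only, not part of the patch): `socketTextStream` is typed as `DStream[str]`, and the combinators propagate element types through the pipeline, assuming a local `SparkContext`:

```
from pyspark import SparkContext
from pyspark.streaming import StreamingContext

sc = SparkContext("local[2]", "TypedStreamingWordCount")
ssc = StreamingContext(sc, batchDuration=1)

# DStream[str] -> DStream[Tuple[str, int]] per the stubbed signatures.
lines = ssc.socketTextStream("localhost", 9999)
counts = (lines.flatMap(lambda line: line.split(" "))
               .map(lambda word: (word, 1))
               .reduceByKey(lambda a, b: a + b))
counts.pprint()

ssc.start()
ssc.awaitTermination()
```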
diff --git a/python/pyspark/streaming/dstream.pyi b/python/pyspark/streaming/dstream.pyi new file mode 100644 index 0000000000000..bbeea69ee9ac2 --- /dev/null +++ b/python/pyspark/streaming/dstream.pyi @@ -0,0 +1,208 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import overload +from typing import ( + Callable, + Generic, + Hashable, + Iterable, + List, + Optional, + Tuple, + TypeVar, + Union, +) +import datetime +from pyspark.rdd import RDD +from pyspark.storagelevel import StorageLevel +import pyspark.streaming.context + +S = TypeVar("S") +T = TypeVar("T") +U = TypeVar("U") +K = TypeVar("K", bound=Hashable) +V = TypeVar("V") + +class DStream(Generic[T]): + is_cached: bool + is_checkpointed: bool + def __init__(self, jdstream, ssc, jrdd_deserializer) -> None: ... + def context(self) -> pyspark.streaming.context.StreamingContext: ... + def count(self) -> DStream[int]: ... + def filter(self, f: Callable[[T], bool]) -> DStream[T]: ... + def flatMap( + self: DStream[T], + f: Callable[[T], Iterable[U]], + preservesPartitioning: bool = ..., + ) -> DStream[U]: ... + def map( + self: DStream[T], f: Callable[[T], U], preservesPartitioning: bool = ... + ) -> DStream[U]: ... + def mapPartitions( + self, f: Callable[[Iterable[T]], Iterable[U]], preservesPartitioning: bool = ... + ) -> DStream[U]: ... + def mapPartitionsWithIndex( + self, + f: Callable[[int, Iterable[T]], Iterable[U]], + preservesPartitioning: bool = ..., + ) -> DStream[U]: ... + def reduce(self, func: Callable[[T, T], T]) -> DStream[T]: ... + def reduceByKey( + self: DStream[Tuple[K, V]], + func: Callable[[V, V], V], + numPartitions: Optional[int] = ..., + ) -> DStream[Tuple[K, V]]: ... + def combineByKey( + self: DStream[Tuple[K, V]], + createCombiner: Callable[[V], U], + mergeValue: Callable[[U, V], U], + mergeCombiners: Callable[[U, U], U], + numPartitions: Optional[int] = ..., + ) -> DStream[Tuple[K, U]]: ... + def partitionBy( + self: DStream[Tuple[K, V]], + numPartitions: int, + partitionFunc: Callable[[K], int] = ..., + ) -> DStream[Tuple[K, V]]: ... + @overload + def foreachRDD(self, func: Callable[[RDD[T]], None]) -> None: ... + @overload + def foreachRDD(self, func: Callable[[datetime.datetime, RDD[T]], None]) -> None: ... + def pprint(self, num: int = ...) -> None: ... + def mapValues( + self: DStream[Tuple[K, V]], f: Callable[[V], U] + ) -> DStream[Tuple[K, U]]: ... + def flatMapValues( + self: DStream[Tuple[K, V]], f: Callable[[V], Iterable[U]] + ) -> DStream[Tuple[K, U]]: ... + def glom(self) -> DStream[List[T]]: ... + def cache(self) -> DStream[T]: ... + def persist(self, storageLevel: StorageLevel) -> DStream[T]: ... + def checkpoint(self, interval: Union[float, int]) -> DStream[T]: ... 
+ def groupByKey( + self: DStream[Tuple[K, V]], numPartitions: Optional[int] = ... + ) -> DStream[Tuple[K, Iterable[V]]]: ... + def countByValue(self) -> DStream[Tuple[T, int]]: ... + def saveAsTextFiles(self, prefix: str, suffix: Optional[str] = ...) -> None: ... + @overload + def transform(self, func: Callable[[RDD[T]], RDD[U]]) -> TransformedDStream[U]: ... + @overload + def transform( + self, func: Callable[[datetime.datetime, RDD[T]], RDD[U]] + ) -> TransformedDStream[U]: ... + @overload + def transformWith( + self, + func: Callable[[RDD[T], RDD[U]], RDD[V]], + other: RDD[U], + keepSerializer: bool = ..., + ) -> DStream[V]: ... + @overload + def transformWith( + self, + func: Callable[[datetime.datetime, RDD[T], RDD[U]], RDD[V]], + other: RDD[U], + keepSerializer: bool = ..., + ) -> DStream[V]: ... + def repartition(self, numPartitions: int) -> DStream[T]: ... + def union(self, other: DStream[U]) -> DStream[Union[T, U]]: ... + def cogroup( + self: DStream[Tuple[K, V]], + other: DStream[Tuple[K, U]], + numPartitions: Optional[int] = ..., + ) -> DStream[Tuple[K, Tuple[List[V], List[U]]]]: ... + def join( + self: DStream[Tuple[K, V]], + other: DStream[Tuple[K, U]], + numPartitions: Optional[int] = ..., + ) -> DStream[Tuple[K, Tuple[V, U]]]: ... + def leftOuterJoin( + self: DStream[Tuple[K, V]], + other: DStream[Tuple[K, U]], + numPartitions: Optional[int] = ..., + ) -> DStream[Tuple[K, Tuple[V, Optional[U]]]]: ... + def rightOuterJoin( + self: DStream[Tuple[K, V]], + other: DStream[Tuple[K, U]], + numPartitions: Optional[int] = ..., + ) -> DStream[Tuple[K, Tuple[Optional[V], U]]]: ... + def fullOuterJoin( + self: DStream[Tuple[K, V]], + other: DStream[Tuple[K, U]], + numPartitions: Optional[int] = ..., + ) -> DStream[Tuple[K, Tuple[Optional[V], Optional[U]]]]: ... + def slice( + self, begin: Union[datetime.datetime, int], end: Union[datetime.datetime, int] + ) -> List[RDD[T]]: ... + def window( + self, windowDuration: int, slideDuration: Optional[int] = ... + ) -> DStream[T]: ... + def reduceByWindow( + self, + reduceFunc: Callable[[T, T], T], + invReduceFunc: Optional[Callable[[T, T], T]], + windowDuration: int, + slideDuration: int, + ) -> DStream[T]: ... + def countByWindow( + self, windowDuration: int, slideDuration: int + ) -> DStream[Tuple[T, int]]: ... + def countByValueAndWindow( + self, + windowDuration: int, + slideDuration: int, + numPartitions: Optional[int] = ..., + ) -> DStream[Tuple[T, int]]: ... + def groupByKeyAndWindow( + self: DStream[Tuple[K, V]], + windowDuration: int, + slideDuration: int, + numPartitions: Optional[int] = ..., + ) -> DStream[Tuple[K, Iterable[V]]]: ... + def reduceByKeyAndWindow( + self: DStream[Tuple[K, V]], + func: Callable[[V, V], V], + invFunc: Optional[Callable[[V, V], V]], + windowDuration: int, + slideDuration: Optional[int] = ..., + numPartitions: Optional[int] = ..., + filterFunc: Optional[Callable[[Tuple[K, V]], bool]] = ..., + ) -> DStream[Tuple[K, V]]: ... + def updateStateByKey( + self: DStream[Tuple[K, V]], + updateFunc: Callable[[Iterable[V], Optional[S]], S], + numPartitions: Optional[int] = ..., + initialRDD: Optional[RDD[Tuple[K, S]]] = ..., + ) -> DStream[Tuple[K, S]]: ... + +class TransformedDStream(DStream[U]): + is_cached: bool + is_checkpointed: bool + func: Callable + prev: DStream + @overload + def __init__( + self: DStream[U], prev: DStream[T], func: Callable[[RDD[T]], RDD[U]] + ) -> None: ... 
+ @overload + def __init__( + self: DStream[U], + prev: DStream[T], + func: Callable[[datetime.datetime, RDD[T]], RDD[U]], + ) -> None: ... diff --git a/python/pyspark/streaming/kinesis.pyi b/python/pyspark/streaming/kinesis.pyi new file mode 100644 index 0000000000000..246fa58ca6da3 --- /dev/null +++ b/python/pyspark/streaming/kinesis.pyi @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE: This dynamically typed stub was automatically generated by stubgen. + +from typing import Any, Optional + +def utf8_decoder(s): ... + +class KinesisUtils: + @staticmethod + def createStream( + ssc, + kinesisAppName, + streamName, + endpointUrl, + regionName, + initialPositionInStream, + checkpointInterval, + storageLevel: Any = ..., + awsAccessKeyId: Optional[Any] = ..., + awsSecretKey: Optional[Any] = ..., + decoder: Any = ..., + stsAssumeRoleArn: Optional[Any] = ..., + stsSessionName: Optional[Any] = ..., + stsExternalId: Optional[Any] = ..., + ): ... + +class InitialPositionInStream: + LATEST: Any + TRIM_HORIZON: Any diff --git a/python/pyspark/streaming/listener.pyi b/python/pyspark/streaming/listener.pyi new file mode 100644 index 0000000000000..4033529607cea --- /dev/null +++ b/python/pyspark/streaming/listener.pyi @@ -0,0 +1,35 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE: This dynamically typed stub was automatically generated by stubgen. + +from typing import Any + +class StreamingListener: + def __init__(self) -> None: ... + def onStreamingStarted(self, streamingStarted: Any) -> None: ... + def onReceiverStarted(self, receiverStarted: Any) -> None: ... + def onReceiverError(self, receiverError: Any) -> None: ... + def onReceiverStopped(self, receiverStopped: Any) -> None: ... + def onBatchSubmitted(self, batchSubmitted: Any) -> None: ... + def onBatchStarted(self, batchStarted: Any) -> None: ... + def onBatchCompleted(self, batchCompleted: Any) -> None: ... + def onOutputOperationStarted(self, outputOperationStarted: Any) -> None: ... 
+ def onOutputOperationCompleted(self, outputOperationCompleted: Any) -> None: ... + class Java: + implements: Any = ... diff --git a/python/pyspark/streaming/tests/test_context.py b/python/pyspark/streaming/tests/test_context.py index 26f1d24f644ea..b255796cdcdd7 100644 --- a/python/pyspark/streaming/tests/test_context.py +++ b/python/pyspark/streaming/tests/test_context.py @@ -178,7 +178,7 @@ def test_await_termination_or_timeout(self): from pyspark.streaming.tests.test_context import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/streaming/tests/test_dstream.py b/python/pyspark/streaming/tests/test_dstream.py index 00d00b50c9283..ea5353c77b6b2 100644 --- a/python/pyspark/streaming/tests/test_dstream.py +++ b/python/pyspark/streaming/tests/test_dstream.py @@ -647,7 +647,7 @@ def check_output(n): from pyspark.streaming.tests.test_dstream import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/streaming/tests/test_kinesis.py b/python/pyspark/streaming/tests/test_kinesis.py index b39809e2f69c2..70c9a012e7a03 100644 --- a/python/pyspark/streaming/tests/test_kinesis.py +++ b/python/pyspark/streaming/tests/test_kinesis.py @@ -83,7 +83,7 @@ def get_output(_, rdd): from pyspark.streaming.tests.test_kinesis import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/streaming/tests/test_listener.py b/python/pyspark/streaming/tests/test_listener.py index 3970cf6589394..e4dab1bba3a6c 100644 --- a/python/pyspark/streaming/tests/test_listener.py +++ b/python/pyspark/streaming/tests/test_listener.py @@ -152,7 +152,7 @@ def func(dstream): from pyspark.streaming.tests.test_listener import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/streaming/util.pyi b/python/pyspark/streaming/util.pyi new file mode 100644 index 0000000000000..d552eb15f4818 --- /dev/null +++ b/python/pyspark/streaming/util.pyi @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE: This dynamically typed stub was automatically generated by stubgen. 
+ +from typing import Any, Optional + +class TransformFunction: + ctx: Any + func: Any + deserializers: Any + rdd_wrap_func: Any + failure: Any + def __init__(self, ctx, func, *deserializers) -> None: ... + def rdd_wrapper(self, func): ... + def call(self, milliseconds, jrdds): ... + def getLastFailure(self): ... + class Java: + implements: Any + +class TransformFunctionSerializer: + ctx: Any + serializer: Any + gateway: Any + failure: Any + def __init__(self, ctx, serializer, gateway: Optional[Any] = ...) -> None: ... + def dumps(self, id): ... + def loads(self, data): ... + def getLastFailure(self): ... + class Java: + implements: Any + +def rddToFileName(prefix, suffix, timestamp): ... diff --git a/python/pyspark/taskcontext.pyi b/python/pyspark/taskcontext.pyi new file mode 100644 index 0000000000000..3415c69f02177 --- /dev/null +++ b/python/pyspark/taskcontext.pyi @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Dict, List +from typing_extensions import Literal +from pyspark.resource.information import ResourceInformation + +class TaskContext: + def __new__(cls) -> TaskContext: ... + @classmethod + def get(cls) -> TaskContext: ... + def stageId(self) -> int: ... + def partitionId(self) -> int: ... + def attemptNumber(self) -> int: ... + def taskAttemptId(self) -> int: ... + def getLocalProperty(self, key: str) -> str: ... + def resources(self) -> Dict[str, ResourceInformation]: ... + +BARRIER_FUNCTION = Literal[1] + +class BarrierTaskContext(TaskContext): + @classmethod + def get(cls) -> BarrierTaskContext: ... + def barrier(self) -> None: ... + def allGather(self, message: str = ...) -> List[str]: ... + def getTaskInfos(self) -> List[BarrierTaskInfo]: ... + +class BarrierTaskInfo: + address: str + def __init__(self, address: str) -> None: ... 
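For the `TaskContext` stub above, a hedged usage sketch (the app name and partition count are invented for illustration): inside a task, `TaskContext.get()` is annotated to return `TaskContext`, so `partitionId()` checks as `int` and flows through the typed `mapPartitions` signature.

```python
# Illustrative only: exercises the TaskContext annotations from inside a
# regular (non-barrier) task; "tc-demo" and the partition count are arbitrary.
from typing import Iterable, Iterator, Tuple

from pyspark import SparkContext, TaskContext

sc = SparkContext("local[2]", "tc-demo")

def tag_with_partition(rows: Iterable[int]) -> Iterator[Tuple[int, int]]:
    tc = TaskContext.get()        # annotated as -> TaskContext in the stub
    pid: int = tc.partitionId()   # -> int under the stub
    return ((pid, row) for row in rows)

print(sc.parallelize(range(8), 2).mapPartitions(tag_with_partition).collect())
sc.stop()
```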
diff --git a/python/pyspark/testing/mlutils.py b/python/pyspark/testing/mlutils.py index a36d0709d8013..a8cf53b31f8c9 100644 --- a/python/pyspark/testing/mlutils.py +++ b/python/pyspark/testing/mlutils.py @@ -20,7 +20,7 @@ from pyspark.ml import Estimator, Model, Transformer, UnaryTransformer from pyspark.ml.param import Param, Params, TypeConverters from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable -from pyspark.ml.wrapper import _java2py +from pyspark.ml.wrapper import _java2py # type: ignore from pyspark.sql import DataFrame, SparkSession from pyspark.sql.types import DoubleType from pyspark.testing.utils import ReusedPySparkTestCase as PySparkTestCase @@ -116,7 +116,8 @@ def _transform(self, dataset): class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable): - shift = Param(Params._dummy(), "shift", "The amount by which to shift " + + shift = Param(Params._dummy(), # type: ignore + "shift", "The amount by which to shift " + "data in a DataFrame", typeConverter=TypeConverters.toFloat) diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py index e85cae7dda2c6..a394e8eecc69e 100644 --- a/python/pyspark/testing/sqlutils.py +++ b/python/pyspark/testing/sqlutils.py @@ -147,7 +147,7 @@ class PythonOnlyPoint(ExamplePoint): """ An example class to demonstrate UDT in only Python """ - __UDT__ = PythonOnlyUDT() + __UDT__ = PythonOnlyUDT() # type: ignore class MyObject(object): diff --git a/python/pyspark/testing/streamingutils.py b/python/pyspark/testing/streamingutils.py index a6abc2ef673b7..f6a317e97331c 100644 --- a/python/pyspark/testing/streamingutils.py +++ b/python/pyspark/testing/streamingutils.py @@ -37,7 +37,7 @@ "spark-streaming-kinesis-asl-assembly-", "spark-streaming-kinesis-asl-assembly_") if kinesis_asl_assembly_jar is None: - kinesis_requirement_message = ( + kinesis_requirement_message = ( # type: ignore "Skipping all Kinesis Python tests as the optional Kinesis project was " "not compiled into a JAR. 
To run these tests, " "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/package " @@ -47,7 +47,7 @@ existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") jars_args = "--jars %s" % kinesis_asl_assembly_jar os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args]) - kinesis_requirement_message = None + kinesis_requirement_message = None # type: ignore should_test_kinesis = kinesis_requirement_message is None diff --git a/python/pyspark/tests/test_appsubmit.py b/python/pyspark/tests/test_appsubmit.py index 15170b878eb22..3f45bf039d3a9 100644 --- a/python/pyspark/tests/test_appsubmit.py +++ b/python/pyspark/tests/test_appsubmit.py @@ -241,7 +241,7 @@ def test_user_configuration(self): from pyspark.tests.test_appsubmit import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/tests/test_broadcast.py b/python/pyspark/tests/test_broadcast.py index 543dc98660fde..c35c5a68e4986 100644 --- a/python/pyspark/tests/test_broadcast.py +++ b/python/pyspark/tests/test_broadcast.py @@ -148,7 +148,7 @@ def random_bytes(n): from pyspark.tests.test_broadcast import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/tests/test_conf.py b/python/pyspark/tests/test_conf.py index 3e80c17f4931c..a8d65b8919777 100644 --- a/python/pyspark/tests/test_conf.py +++ b/python/pyspark/tests/test_conf.py @@ -36,7 +36,7 @@ def test_memory_conf(self): from pyspark.tests.test_conf import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/tests/test_context.py b/python/pyspark/tests/test_context.py index 9f159f7703950..9b6b74a111288 100644 --- a/python/pyspark/tests/test_context.py +++ b/python/pyspark/tests/test_context.py @@ -93,7 +93,7 @@ def test_add_py_file(self): # this job fails due to `userlibrary` not being on the Python path: # disable logging in log4j temporarily def func(x): - from userlibrary import UserClass + from userlibrary import UserClass # type: ignore return UserClass().hello() with QuietTest(self.sc): self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first) @@ -137,7 +137,8 @@ def test_add_egg_file_locally(self): # To ensure that we're actually testing addPyFile's effects, check that # this fails due to `userlibrary` not being on the Python path: def func(): - from userlib import UserClass # noqa: F401 + from userlib import UserClass # type: ignore[import] + UserClass() self.assertRaises(ImportError, func) path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1.zip") self.sc.addPyFile(path) @@ -147,11 +148,11 @@ def func(): def test_overwrite_system_module(self): self.sc.addPyFile(os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py")) - import SimpleHTTPServer + import SimpleHTTPServer # type: ignore[import] self.assertEqual("My Server", SimpleHTTPServer.__name__) def func(x): - import SimpleHTTPServer + import SimpleHTTPServer # type: ignore[import] return SimpleHTTPServer.__name__ self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect()) @@ -321,7 +322,7 @@ def tearDown(self): 
from pyspark.tests.test_context import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/tests/test_daemon.py b/python/pyspark/tests/test_daemon.py index b1f8c71c77ba9..c3fd89fef72c2 100644 --- a/python/pyspark/tests/test_daemon.py +++ b/python/pyspark/tests/test_daemon.py @@ -76,7 +76,7 @@ def test_termination_sigterm(self): from pyspark.tests.test_daemon import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/tests/test_join.py b/python/pyspark/tests/test_join.py index 815c78ef9a8e2..63dd1cfef9a6a 100644 --- a/python/pyspark/tests/test_join.py +++ b/python/pyspark/tests/test_join.py @@ -62,7 +62,7 @@ def test_narrow_dependency_in_join(self): from pyspark.tests.test_join import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/tests/test_pin_thread.py b/python/pyspark/tests/test_pin_thread.py index efe7d7f6639b1..b612796c963a0 100644 --- a/python/pyspark/tests/test_pin_thread.py +++ b/python/pyspark/tests/test_pin_thread.py @@ -169,7 +169,7 @@ def get_outer_local_prop(): from pyspark.tests.test_pin_thread import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/tests/test_profiler.py b/python/pyspark/tests/test_profiler.py index ca144cc6e1eb6..de72a547b0844 100644 --- a/python/pyspark/tests/test_profiler.py +++ b/python/pyspark/tests/test_profiler.py @@ -101,7 +101,7 @@ def test_profiler_disabled(self): from pyspark.tests.test_profiler import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py index c154bda00d605..47b8f10a5b05e 100644 --- a/python/pyspark/tests/test_rdd.py +++ b/python/pyspark/tests/test_rdd.py @@ -884,7 +884,7 @@ def run_job(job_group, index): from pyspark.tests.test_rdd import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/tests/test_rddbarrier.py b/python/pyspark/tests/test_rddbarrier.py index f0a05a23cc4e0..ba2c4b9ba84d4 100644 --- a/python/pyspark/tests/test_rddbarrier.py +++ b/python/pyspark/tests/test_rddbarrier.py @@ -43,7 +43,7 @@ def f(index, iterator): from pyspark.tests.test_rddbarrier import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/tests/test_readwrite.py b/python/pyspark/tests/test_readwrite.py index adbc343c650a7..145b53a5eaaa1 100644 --- a/python/pyspark/tests/test_readwrite.py +++ b/python/pyspark/tests/test_readwrite.py @@ -307,7 +307,7 @@ def test_malformed_RDD(self): from 
pyspark.tests.test_readwrite import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/tests/test_serializers.py b/python/pyspark/tests/test_serializers.py index 8eaa9c7d5a8d2..cc7838e595b8a 100644 --- a/python/pyspark/tests/test_serializers.py +++ b/python/pyspark/tests/test_serializers.py @@ -87,7 +87,7 @@ def __getattr__(self, item): def test_pickling_file_handles(self): # to be corrected with SPARK-11160 try: - import xmlrunner # noqa: F401 + import xmlrunner # type: ignore[import] # noqa: F401 except ImportError: ser = CloudPickleSerializer() out1 = sys.stderr @@ -227,7 +227,7 @@ def test_chunked_stream(self): from pyspark.tests.test_serializers import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/tests/test_shuffle.py b/python/pyspark/tests/test_shuffle.py index 061b93f32c56c..6a245a26b4551 100644 --- a/python/pyspark/tests/test_shuffle.py +++ b/python/pyspark/tests/test_shuffle.py @@ -170,7 +170,7 @@ def test_external_sort_in_rdd(self): from pyspark.tests.test_shuffle import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py index f5be685643dd5..f0e6672957c13 100644 --- a/python/pyspark/tests/test_taskcontext.py +++ b/python/pyspark/tests/test_taskcontext.py @@ -324,7 +324,7 @@ def tearDown(self): from pyspark.tests.test_taskcontext import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/tests/test_util.py b/python/pyspark/tests/test_util.py index e853bc322c184..a25c41b296944 100644 --- a/python/pyspark/tests/test_util.py +++ b/python/pyspark/tests/test_util.py @@ -77,7 +77,7 @@ def test_parsing_version_string(self): from pyspark.tests.test_util import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py index a855eaafc1927..bfaf3a3186cad 100644 --- a/python/pyspark/tests/test_worker.py +++ b/python/pyspark/tests/test_worker.py @@ -202,7 +202,7 @@ def tearDown(self): from pyspark.tests.test_worker import * # noqa: F401 try: - import xmlrunner + import xmlrunner # type: ignore[import] testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) except ImportError: testRunner = None diff --git a/python/pyspark/traceback_utils.pyi b/python/pyspark/traceback_utils.pyi new file mode 100644 index 0000000000000..33b1b7dc3227f --- /dev/null +++ b/python/pyspark/traceback_utils.pyi @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from collections import namedtuple +from typing import Any + +CallSite = namedtuple("CallSite", "function file linenum") + +def first_spark_call(): ... + +class SCCallSiteSync: + def __init__(self, sc: Any) -> None: ... + def __enter__(self) -> None: ... + def __exit__(self, type: Any, value: Any, tb: Any) -> None: ... diff --git a/python/pyspark/util.pyi b/python/pyspark/util.pyi new file mode 100644 index 0000000000000..023b409831459 --- /dev/null +++ b/python/pyspark/util.pyi @@ -0,0 +1,35 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from typing import Any, Tuple +from pyspark._typing import F + +import threading + +def print_exec(stream: Any) -> None: ... + +class VersionUtils: + @staticmethod + def majorMinorVersion(sparkVersion: str) -> Tuple[int, int]: ... + +def fail_on_stopiteration(f: F) -> F: ... + +class InheritableThread(threading.Thread): + def __init__(self, target: Any, *args: Any, **kwargs: Any): ... + def __del__(self) -> None: ... diff --git a/python/pyspark/version.pyi b/python/pyspark/version.pyi new file mode 100644 index 0000000000000..444dae62f0c09 --- /dev/null +++ b/python/pyspark/version.pyi @@ -0,0 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +__version__: str diff --git a/python/pyspark/worker.pyi b/python/pyspark/worker.pyi new file mode 100644 index 0000000000000..cc264823cc867 --- /dev/null +++ b/python/pyspark/worker.pyi @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pyspark import shuffle as shuffle +from pyspark.broadcast import Broadcast as Broadcast +from pyspark.files import SparkFiles as SparkFiles +from pyspark.java_gateway import local_connect_and_auth as local_connect_and_auth +from pyspark.rdd import PythonEvalType as PythonEvalType +from pyspark.resource import ResourceInformation as ResourceInformation +from pyspark.serializers import ( + BatchedSerializer as BatchedSerializer, + PickleSerializer as PickleSerializer, + SpecialLengths as SpecialLengths, + UTF8Deserializer as UTF8Deserializer, + read_bool as read_bool, + read_int as read_int, + read_long as read_long, + write_int as write_int, + write_long as write_long, + write_with_length as write_with_length, +) +from pyspark.sql.pandas.serializers import ( + ArrowStreamPandasUDFSerializer as ArrowStreamPandasUDFSerializer, + CogroupUDFSerializer as CogroupUDFSerializer, +) +from pyspark.sql.pandas.types import to_arrow_type as to_arrow_type +from pyspark.sql.types import StructType as StructType +from pyspark.taskcontext import ( + BarrierTaskContext as BarrierTaskContext, + TaskContext as TaskContext, +) +from pyspark.util import fail_on_stopiteration as fail_on_stopiteration +from typing import Any + +has_resource_module: bool +pickleSer: Any +utf8_deserializer: Any + +def report_times(outfile: Any, boot: Any, init: Any, finish: Any) -> None: ... +def add_path(path: Any) -> None: ... +def read_command(serializer: Any, file: Any): ... +def chain(f: Any, g: Any): ... +def wrap_udf(f: Any, return_type: Any): ... +def wrap_scalar_pandas_udf(f: Any, return_type: Any): ... +def wrap_pandas_iter_udf(f: Any, return_type: Any): ... +def wrap_cogrouped_map_pandas_udf(f: Any, return_type: Any, argspec: Any): ... +def wrap_grouped_map_pandas_udf(f: Any, return_type: Any, argspec: Any): ... +def wrap_grouped_agg_pandas_udf(f: Any, return_type: Any): ... +def wrap_window_agg_pandas_udf( + f: Any, return_type: Any, runner_conf: Any, udf_index: Any +): ... +def wrap_unbounded_window_agg_pandas_udf(f: Any, return_type: Any): ... +def wrap_bounded_window_agg_pandas_udf(f: Any, return_type: Any): ... +def read_single_udf( + pickleSer: Any, infile: Any, eval_type: Any, runner_conf: Any, udf_index: Any +): ... +def read_udfs(pickleSer: Any, infile: Any, eval_type: Any): ... +def main(infile: Any, outfile: Any) -> None: ... 
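The `worker.pyi` stub above, like the streaming `__init__.pyi` earlier in this patch, relies on the redundant-alias idiom `import X as X`: under the PEP 484 stub conventions that mypy enforces, the alias marks the name as an intentional re-export rather than an incidental import. A hypothetical two-line stub showing the same convention (package and module names are invented):

```python
# mypackage/__init__.pyi  (hypothetical, for illustration only)
# The redundant `as helper` alias marks the name as a deliberate re-export,
# so `from mypackage import helper` type-checks for downstream users.
from mypackage._impl import helper as helper  # noqa: F401
```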
diff --git a/python/setup.py b/python/setup.py index 7c12b112acd65..206765389335f 100755 --- a/python/setup.py +++ b/python/setup.py @@ -265,7 +265,8 @@ def run(self): 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: Implementation :: CPython', - 'Programming Language :: Python :: Implementation :: PyPy'], + 'Programming Language :: Python :: Implementation :: PyPy', + 'Typing :: Typed'], cmdclass={ 'install': InstallCommand, },
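With the `Typing :: Typed` classifier added here and the `py.typed` marker shipped earlier in this patch, installed PySpark advertises its annotations per PEP 561, so a PEP 561-aware type checker applies the bundled stubs to user code. A deliberately ill-typed sketch (file name is arbitrary) that such a checker should now reject, based on the `StreamingContext` stub shown above:

```python
# user_job.py -- deliberately ill-typed; a PEP 561-aware checker picking up
# the bundled stubs should flag the remember() call (illustrative file name).
from pyspark import SparkContext
from pyspark.streaming import StreamingContext

sc = SparkContext("local[2]", "typed-user-job")
ssc = StreamingContext(sc, batchDuration=1)

ssc.remember("ten seconds")   # stub declares remember(duration: int) -> None
ssc.awaitTermination(5)       # fine: timeout is Optional[int]
```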