|
18 | 18 | import tempfile
|
19 | 19 | import unittest
|
20 | 20 |
|
| 21 | +from pyspark.sql import Row |
21 | 22 | from pyspark.ml.pipeline import Pipeline, PipelineModel
|
22 | 23 | from pyspark.ml.feature import (
|
23 | 24 | VectorAssembler,
|
|
26 | 27 | MinMaxScaler,
|
27 | 28 | MinMaxScalerModel,
|
28 | 29 | )
|
| 30 | +from pyspark.ml.linalg import Vectors |
29 | 31 | from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel
|
30 | 32 | from pyspark.ml.clustering import KMeans, KMeansModel
|
31 | 33 | from pyspark.testing.mlutils import MockDataset, MockEstimator, MockTransformer
|
@@ -172,6 +174,24 @@ def test_clustering_pipeline(self):
|
172 | 174 | self.assertEqual(str(model), str(model2))
|
173 | 175 | self.assertEqual(str(model.stages), str(model2.stages))
|
174 | 176 |
|
| 177 | + def test_model_gc(self): |
| 178 | + spark = self.spark |
| 179 | + df = spark.createDataFrame( |
| 180 | + [ |
| 181 | + Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])), |
| 182 | + Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])), |
| 183 | + Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0])), |
| 184 | + ] |
| 185 | + ) |
| 186 | + |
| 187 | + def fit_transform(df): |
| 188 | + lr = LogisticRegression(maxIter=1, regParam=0.01, weightCol="weight") |
| 189 | + model = lr.fit(df) |
| 190 | + return model.transform(df) |
| 191 | + |
| 192 | + output = fit_transform(df) |
| 193 | + self.assertEqual(output.count(), 3) |
| 194 | + |
175 | 195 |
|
class PipelineTests(PipelineTestsMixin, ReusedSQLTestCase):
    """Concrete test case running PipelineTestsMixin against a reused SQL session."""
|
|
0 commit comments