Commit 09b93bd
[SPARK-51214][ML][PYTHON][CONNECT] Don't eagerly remove the cached models for fit_transform
### What changes were proposed in this pull request?

Don't eagerly remove the cached models for `fit_transform`:

1. Still keep the `Delete` ML command protobuf, but no longer call it in `__del__` on the Python client side;
2. Build the ML cache with Guava's `CacheBuilder` and soft references, and specify a maximum size and timeout.

### Why are the changes needed?

A common ML pipeline pattern is `fit_transform`:

```
def fit_transform(df):
    model = estimator.fit(df)
    return model.transform(df)

df2 = fit_transform(df)
df2.count()
```

The existing implementation eagerly deletes the intermediate model from the ML cache right after `fit_transform`, which causes an NPE:

```
pyspark.errors.exceptions.connect.SparkConnectGrpcException: (java.lang.NullPointerException) Cannot invoke "org.apache.spark.ml.Model.copy(org.apache.spark.ml.param.ParamMap)" because "model" is null

JVM stacktrace:
java.lang.NullPointerException
	at org.apache.spark.sql.connect.ml.ModelAttributeHelper.transform(MLHandler.scala:68)
	at org.apache.spark.sql.connect.ml.MLHandler$.transformMLRelation(MLHandler.scala:313)
	at org.apache.spark.sql.connect.planner.SparkConnectPlanner.$anonfun$transformRelation$1(SparkConnectPlanner.scala:231)
	at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$usePlanCache$3(SessionHolder.scala:477)
	at scala.Option.getOrElse(Option.scala:201)
	at org.apache.spark.sql.connect.service.SessionHolder.usePlanCache(SessionHolder.scala:476)
	at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelation(SparkConnectPlanner.scala:147)
	at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelation(SparkConnectPlanner.scala:133)
	at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelationalGroupedAggregate(SparkConnectPlanner.scala:2318)
	at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformAggregate(SparkConnectPlanner.scala:2299)
	at org.apache.spark.sql.connect.planner.SparkConnectPlanner.$anonfun$transformRelation$1(SparkConnectPlanner.scala:165)
	at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$usePlanCache$3(SessionHolder.scala:477)
```

### Does this PR introduce _any_ user-facing change?

Yes.

### How was this patch tested?

Added tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#49948 from zhengruifeng/ml_connect_del.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
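To make the fix concrete before the per-file diffs: the server-side cache is rebuilt on Guava's `CacheBuilder`, using the same builder calls as the `MLCache.scala` diff below. Here is a minimal standalone sketch of that configuration (illustrative names, not the PR's code; it assumes only Guava on the classpath). The property that matters is that a lookup for a missing or evicted entry returns `null` rather than throwing, which is why the server now adds explicit null checks:

```scala
import java.util.concurrent.{ConcurrentMap, TimeUnit}

import com.google.common.cache.CacheBuilder

object SoftCacheSketch {
  def main(args: Array[String]): Unit = {
    // Same shape as the new MLCache: soft-valued entries, a size cap,
    // and eviction after a period of no access.
    val cache: ConcurrentMap[String, Object] = CacheBuilder
      .newBuilder()
      .softValues()
      .maximumSize(100)
      .expireAfterAccess(60, TimeUnit.MINUTES)
      .build[String, Object]()
      .asMap()

    cache.put("model-1", "a fitted model would live here")

    // A present entry is returned as usual...
    assert(cache.get("model-1") != null)
    // ...but a missing (or evicted) entry yields null, not an exception.
    assert(cache.get("no-such-id") == null)
  }
}
```

Soft values let the JVM reclaim idle models under memory pressure instead of deleting them eagerly on the client's `__del__`, so the `fit_transform` pattern above keeps working while the cache stays bounded.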
1 parent d6ad779

7 files changed, +116 -16 lines

common/utils/src/main/resources/error/error-conditions.json (+5)

```diff
@@ -775,6 +775,11 @@
         "<attribute> in <className> is not allowed to be accessed."
       ]
     },
+    "CACHE_INVALID" : {
+      "message" : [
+        "Cannot retrieve <objectName> from the ML cache. It is probably because the entry has been evicted."
+      ]
+    },
     "UNSUPPORTED_EXCEPTION" : {
       "message" : [
         "<message>"
```

python/pyspark/ml/tests/test_pipeline.py (+20)

```diff
@@ -18,6 +18,7 @@
 import tempfile
 import unittest
 
+from pyspark.sql import Row
 from pyspark.ml.pipeline import Pipeline, PipelineModel
 from pyspark.ml.feature import (
     VectorAssembler,
@@ -26,6 +27,7 @@
     MinMaxScaler,
     MinMaxScalerModel,
 )
+from pyspark.ml.linalg import Vectors
 from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel
 from pyspark.ml.clustering import KMeans, KMeansModel
 from pyspark.testing.mlutils import MockDataset, MockEstimator, MockTransformer
@@ -172,6 +174,24 @@ def test_clustering_pipeline(self):
         self.assertEqual(str(model), str(model2))
         self.assertEqual(str(model.stages), str(model2.stages))
 
+    def test_model_gc(self):
+        spark = self.spark
+        df = spark.createDataFrame(
+            [
+                Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
+                Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
+                Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0])),
+            ]
+        )
+
+        def fit_transform(df):
+            lr = LogisticRegression(maxIter=1, regParam=0.01, weightCol="weight")
+            model = lr.fit(df)
+            return model.transform(df)
+
+        output = fit_transform(df)
+        self.assertEqual(output.count(), 3)
+
 
 class PipelineTests(PipelineTestsMixin, ReusedSQLTestCase):
     pass
```

python/pyspark/ml/util.py (+28 -12)

```diff
@@ -249,6 +249,21 @@ def wrapped(self: "JavaWrapper", name: str, *args: Any) -> Any:
     return cast(FuncT, wrapped)
 
 
+# delete the object from the ml cache eagerly
+def del_remote_cache(ref_id: str) -> None:
+    if ref_id is not None and "." not in ref_id:
+        try:
+            from pyspark.sql.connect.session import SparkSession
+
+            session = SparkSession.getActiveSession()
+            if session is not None:
+                session.client.remove_ml_cache(ref_id)
+            return
+        except Exception:
+            # SparkSession's down.
+            return
+
+
 def try_remote_del(f: FuncT) -> FuncT:
     """Mark the function/property to delete a model on the server side."""
 
@@ -261,18 +276,19 @@ def wrapped(self: "JavaWrapper") -> Any:
 
         if in_remote:
             # Delete the model if possible
-            model_id = self._java_obj
-            if model_id is not None and "." not in model_id:
-                try:
-                    from pyspark.sql.connect.session import SparkSession
-
-                    session = SparkSession.getActiveSession()
-                    if session is not None:
-                        session.client.remove_ml_cache(model_id)
-                    return
-                except Exception:
-                    # SparkSession's down.
-                    return
+            # model_id = self._java_obj
+            # del_remote_cache(model_id)
+            #
+            # Above codes delete the model from the ml cache eagerly, and may cause
+            # NPE in the server side in the case of 'fit_transform':
+            #
+            # def fit_transform(df):
+            #     model = estimator.fit(df)
+            #     return model.transform(df)
+            #
+            # output = fit_transform(df)
+            # output.show()
+            return
         else:
             return f(self)
```

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala (+18 -3)

```diff
@@ -17,7 +17,9 @@
 package org.apache.spark.sql.connect.ml
 
 import java.util.UUID
-import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.{ConcurrentMap, TimeUnit}
+
+import com.google.common.cache.CacheBuilder
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.util.ConnectHelper
@@ -29,8 +31,13 @@ private[connect] class MLCache extends Logging {
   private val helper = new ConnectHelper()
   private val helperID = "______ML_CONNECT_HELPER______"
 
-  private val cachedModel: ConcurrentHashMap[String, Object] =
-    new ConcurrentHashMap[String, Object]()
+  private val cachedModel: ConcurrentMap[String, Object] = CacheBuilder
+    .newBuilder()
+    .softValues()
+    .maximumSize(MLCache.MAX_CACHED_ITEMS)
+    .expireAfterAccess(MLCache.CACHE_TIMEOUT_MINUTE, TimeUnit.MINUTES)
+    .build[String, Object]()
+    .asMap()
 
   /**
   * Cache an object into a map of MLCache, and return its key
@@ -76,3 +83,11 @@
     cachedModel.clear()
   }
 }
+
+private[connect] object MLCache {
+  // The maximum number of distinct items in the cache.
+  private val MAX_CACHED_ITEMS = 100
+
+  // The maximum time for an item to stay in the cache.
+  private val CACHE_TIMEOUT_MINUTE = 60
+}
```
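The `MLCache` companion bounds the cache at 100 items and a 60-minute idle timeout, so an entry can disappear while a client still holds its id. A small hedged demonstration of the size-based eviction path (an illustrative cap of 2, not the PR's configuration; exact eviction timing is up to Guava):

```scala
import com.google.common.cache.CacheBuilder

object SizeEvictionSketch {
  def main(args: Array[String]): Unit = {
    // Cap the cache at two entries so eviction is easy to observe.
    val cache = CacheBuilder
      .newBuilder()
      .maximumSize(2)
      .build[String, Object]()
      .asMap()

    Seq("a", "b", "c").foreach(id => cache.put(id, s"model-$id"))

    // At most two entries survive; an evicted id now resolves to null,
    // which the server surfaces as CONNECT_ML.CACHE_INVALID (see below).
    println(Seq("a", "b", "c").map(id => id -> cache.get(id)))
    // e.g. List((a,null), (b,model-b), (c,model-c))
  }
}
```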

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala (+6)

```diff
@@ -30,3 +30,9 @@ private[spark] case class MLAttributeNotAllowedException(className: String, attr
     errorClass = "CONNECT_ML.ATTRIBUTE_NOT_ALLOWED",
     messageParameters = Map("className" -> className, "attribute" -> attribute),
     cause = null)
+
+private[spark] case class MLCacheInvalidException(objectName: String)
+    extends SparkException(
+      errorClass = "CONNECT_ML.CACHE_INVALID",
+      messageParameters = Map("objectName" -> objectName),
+      cause = null)
```

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala (+10 -1)

```diff
@@ -41,7 +41,13 @@ private class AttributeHelper(
     val sessionHolder: SessionHolder,
     val objRef: String,
     val methods: Array[Method]) {
-  protected lazy val instance = sessionHolder.mlCache.get(objRef)
+  protected lazy val instance = {
+    val obj = sessionHolder.mlCache.get(objRef)
+    if (obj == null) {
+      throw MLCacheInvalidException(s"object $objRef")
+    }
+    obj
+  }
   // Get the attribute by reflection
   def getAttribute: Any = {
     assert(methods.length >= 1)
@@ -181,6 +187,9 @@ private[connect] object MLHandler extends Logging {
       case proto.MlCommand.Write.TypeCase.OBJ_REF => // save a model
         val objId = mlCommand.getWrite.getObjRef.getId
         val model = mlCache.get(objId).asInstanceOf[Model[_]]
+        if (model == null) {
+          throw MLCacheInvalidException(s"model $objId")
+        }
         val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]]
         MLUtils.setInstanceParams(copiedModel, mlCommand.getWrite.getParams)
```

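The same check-then-throw pattern now guards both `AttributeHelper.instance` and the model-write path. A hypothetical helper capturing it (a sketch, not part of the PR; a plain `RuntimeException` stands in for the `SparkException`-based `MLCacheInvalidException`):

```scala
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}

// Stand-in for MLCacheInvalidException from this PR.
final class CacheMissException(name: String)
    extends RuntimeException(s"Cannot retrieve $name from the ML cache.")

object CacheLookupSketch {
  // Fetch an object from the cache, failing with a typed error on a miss
  // instead of letting the null propagate into an NPE later.
  def getOrInvalid(cache: ConcurrentMap[String, Object], name: String): Object = {
    val obj = cache.get(name)
    if (obj == null) throw new CacheMissException(name)
    obj
  }

  def main(args: Array[String]): Unit = {
    val cache = new ConcurrentHashMap[String, Object]()
    cache.put("model-1", "a fitted model")
    println(getOrInvalid(cache, "model-1")) // prints: a fitted model
    // getOrInvalid(cache, "evicted-id") would throw CacheMissException here.
  }
}
```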
sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala (+29)

```diff
@@ -256,6 +256,35 @@ class MLSuite extends MLHelper {
     }
   }
 
+  test("Exception: cannot retrieve object") {
+    val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+    val modelId = trainLogisticRegressionModel(sessionHolder)
+
+    // Fetch summary attribute
+    val accuracyCommand = proto.MlCommand
+      .newBuilder()
+      .setFetch(
+        proto.Fetch
+          .newBuilder()
+          .setObjRef(proto.ObjectRef.newBuilder().setId(modelId))
+          .addMethods(proto.Fetch.Method.newBuilder().setMethod("summary"))
+          .addMethods(proto.Fetch.Method.newBuilder().setMethod("accuracy")))
+      .build()
+
+    // Successfully fetch summary.accuracy from the cached model
+    MLHandler.handleMlCommand(sessionHolder, accuracyCommand)
+
+    // Remove the model from cache
+    sessionHolder.mlCache.clear()
+
+    // No longer able to retrieve the model from cache
+    val e = intercept[MLCacheInvalidException] {
+      MLHandler.handleMlCommand(sessionHolder, accuracyCommand)
+    }
+    val msg = e.getMessage
+    assert(msg.contains(s"$modelId from the ML cache"))
+  }
+
   test("access the attribute which is not in allowed list") {
     val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
     val modelId = trainLogisticRegressionModel(sessionHolder)
```
