
Commit 138587e

herman committed
1 parent 638ec7b commit 138587e

9 files changed, +120 -126 lines changed

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelineExecutionHolder.scala

Lines changed: 0 additions & 88 deletions
This file was deleted.

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala

Lines changed: 6 additions & 15 deletions
@@ -33,16 +33,7 @@ import org.apache.spark.sql.connect.service.SessionHolder
 import org.apache.spark.sql.pipelines.Language.Python
 import org.apache.spark.sql.pipelines.QueryOriginType
 import org.apache.spark.sql.pipelines.common.RunState.{CANCELED, FAILED}
-import org.apache.spark.sql.pipelines.graph.{
-  FlowAnalysis,
-  GraphIdentifierManager,
-  IdentifierHelper,
-  QueryContext,
-  QueryOrigin,
-  Table,
-  TemporaryView,
-  UnresolvedFlow
-}
+import org.apache.spark.sql.pipelines.graph.{FlowAnalysis, GraphIdentifierManager, IdentifierHelper, PipelineUpdateContextImpl, QueryContext, QueryOrigin, Table, TemporaryView, UnresolvedFlow}
 import org.apache.spark.sql.pipelines.logging.{PipelineEvent, RunProgress}
 import org.apache.spark.sql.types.StructType

@@ -341,11 +332,11 @@ private[connect] object PipelinesHandler extends Logging {
         )
       }
     }
-    PipelineExecutionHolder.executePipeline(
-      dataflowGraphId,
-      graphElementRegistry.toDataflowGraph,
-      eventCallback
-    )
+    val pipelineUpdateContext = new PipelineUpdateContextImpl(
+      graphElementRegistry.toDataflowGraph, eventCallback)
+    sessionHolder.cachePipelineExecution(dataflowGraphId, pipelineUpdateContext)
+    pipelineUpdateContext.pipelineExecution.runPipeline()
+
     // Rethrow any exceptions that caused the pipeline run to fail so that the exception is
     // propagated back to the SC client / CLI.
     runFailureEvent.foreach { event =>
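Net effect of this hunk: instead of delegating to the now-deleted process-global PipelineExecutionHolder, the handler builds a PipelineUpdateContextImpl, registers it with the owning session, and starts the run, so the execution's lifetime is bounded by the session's. An annotated restatement of the new path follows; it is a sketch only, assuming the surrounding handler scope shown above (sessionHolder, dataflowGraphId, graphElementRegistry, eventCallback).

    // Sketch; all four free names come from the enclosing handler scope above.
    val pipelineUpdateContext = new PipelineUpdateContextImpl(
      graphElementRegistry.toDataflowGraph, // resolve the graph registered for this run
      eventCallback)                        // forwards progress events back to the client
    // Registering with the session ties cleanup to SessionHolder.close(); it
    // throws if an execution for this graph ID is already cached.
    sessionHolder.cachePipelineExecution(dataflowGraphId, pipelineUpdateContext)
    pipelineUpdateContext.pipelineExecution.runPipeline()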

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala

Lines changed: 58 additions & 0 deletions
@@ -41,6 +41,7 @@ import org.apache.spark.sql.connect.ml.MLCache
 import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener
 import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper
 import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE, ERROR_CACHE_TIMEOUT_SEC}
+import org.apache.spark.sql.pipelines.graph.PipelineUpdateContext
 import org.apache.spark.sql.streaming.StreamingQueryListener
 import org.apache.spark.util.{SystemClock, Utils}

@@ -119,6 +120,11 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
   private lazy val listenerCache: ConcurrentMap[String, StreamingQueryListener] =
     new ConcurrentHashMap()

+  // Mapping from graphId to the pipeline update context. This is used to manage the lifecycle of
+  // pipeline executions.
+  private lazy val pipelineExecutions =
+    new ConcurrentHashMap[String, PipelineUpdateContext]()
+
   // Handles Python process clean up for streaming queries. Initialized on first use in a query.
   private[connect] lazy val streamingForeachBatchRunnerCleanerCache =
     new StreamingForeachBatchHelper.CleanerCache(this)

@@ -311,6 +317,8 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
     SparkConnectService.streamingSessionManager.cleanupRunningQueries(this, blocking = true)
     streamingForeachBatchRunnerCleanerCache.cleanUpAll() // Clean up any streaming workers.
     removeAllListeners() // removes all listener and stop python listener processes if necessary.
+    // Stops all pipeline execution and clears the pipeline execution cache
+    removeAllPipelineExecutions()

     // if there is a server side listener, clean up related resources
     if (streamingServersideListenerHolder.isServerSideListenerRegistered) {

@@ -426,6 +434,56 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
     listenerCache.keySet().asScala.toSeq
   }

+  /**
+   * Caches the pipeline execution context for a given graph ID.
+   * @param graphId The id of the graph being executed.
+   * @param pipelineUpdateContext The context for the pipeline execution.
+   */
+  private[connect] def cachePipelineExecution(
+      graphId: String,
+      pipelineUpdateContext: PipelineUpdateContext): Unit = {
+    pipelineExecutions.compute(
+      graphId,
+      (_, existing) => {
+        if (Option(existing).isDefined) {
+          throw new IllegalStateException(
+            s"Pipeline execution for graph ID $graphId already exists. " +
+              s"Stop the existing execution before starting a new one."
+          )
+        }
+        pipelineUpdateContext
+      }
+    )
+  }
+
+  /** Stops the pipeline execution and removes it from the cache. */
+  private def removeCachedPipelineExecution(graphId: String): Unit = {
+    pipelineExecutions.compute(graphId, (_, context) => {
+      if (context.pipelineExecution.executionStarted) {
+        context.pipelineExecution.stopPipeline()
+      }
+      // Remove the execution.
+      null
+    })
+  }
+
+  /** Stops all pipeline executions and clears the pipeline execution cache. */
+  def removeAllPipelineExecutions(): Unit = {
+    pipelineExecutions.forEach((graphId, _) => {
+      removeCachedPipelineExecution(graphId)
+    })
+    pipelineExecutions.clear()
+  }
+
+  /**
+   * Returns [[PipelineUpdateContext]] cached for the given graphId. If it is not found, return
+   * None.
+   */
+  private[connect] def getPipelineExecution(graphId: String): Option[PipelineUpdateContext] = {
+    Option(pipelineExecutions.get(graphId))
+  }
+
   /**
    * An accumulator for Python executors.
    *
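The new cache leans on ConcurrentHashMap.compute, which runs its remapping function atomically per key: registration is insert-or-throw, and returning null from the function removes the mapping, so stop-and-remove cannot race with a concurrent registration for the same graph ID. Below is a minimal, self-contained sketch of the same pattern with a toy execution type (hypothetical names, not Spark's; unlike the committed removeCachedPipelineExecution, the sketch also null-checks the existing value before touching it).

    import java.util.concurrent.ConcurrentHashMap

    // Toy stand-in for PipelineUpdateContext / PipelineExecution.
    final class Execution(val graphId: String) {
      @volatile var started = false
      def stop(): Unit = started = false
    }

    object ExecutionCacheDemo extends App {
      private val executions = new ConcurrentHashMap[String, Execution]()

      // compute() runs atomically per key, so two concurrent callers cannot
      // both register an execution for the same graphId.
      def cache(graphId: String, exec: Execution): Unit =
        executions.compute(graphId, (_, existing) => {
          if (existing != null) {
            throw new IllegalStateException(
              s"Execution for graph ID $graphId already exists.")
          }
          exec
        })

      // Returning null from compute() removes the mapping, so stop-and-remove
      // is atomic with respect to concurrent cache() calls on the same key.
      def remove(graphId: String): Unit =
        executions.compute(graphId, (_, existing) => {
          if (existing != null && existing.started) existing.stop()
          null
        })

      cache("g1", new Execution("g1"))
      remove("g1")
      assert(executions.isEmpty)
    }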

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala

Lines changed: 15 additions & 15 deletions
@@ -59,7 +59,7 @@ class PythonPipelineSuite
     with EventVerificationTestHelpers {

   def buildGraph(pythonText: String): DataflowGraph = {
-    val indentedPythonText = pythonText.linesIterator.map("  " + _).mkString("\n")
+    val indentedPythonText = pythonText.linesIterator.map(" " + _).mkString("\n")
     val pythonCode =
       s"""
         |from pyspark.sql import SparkSession

@@ -72,24 +72,24 @@ class PythonPipelineSuite
         |  graph_element_registration_context,
         |)
         |
-        |try:
-        |  spark = SparkSession.builder \\
-        |    .remote("sc://localhost:$serverPort") \\
-        |    .config("spark.connect.grpc.channel.timeout", "5s") \\
-        |    .create()
+        |spark = SparkSession.builder \\
+        |  .remote("sc://localhost:$serverPort") \\
+        |  .config("spark.connect.grpc.channel.timeout", "5s") \\
+        |  .create()
         |
-        |  dataflow_graph_id = create_dataflow_graph(
-        |    spark,
-        |    default_catalog=None,
-        |    default_database=None,
-        |    sql_conf={},
-        |  )
+        |dataflow_graph_id = create_dataflow_graph(
+        |  spark,
+        |  default_catalog=None,
+        |  default_database=None,
+        |  sql_conf={},
+        |)
         |
-        |  registry = SparkConnectGraphElementRegistry(spark, dataflow_graph_id)
-        |  with graph_element_registration_context(registry):
-        |    $indentedPythonText
+        |registry = SparkConnectGraphElementRegistry(spark, dataflow_graph_id)
+        |with graph_element_registration_context(registry):
+        |$indentedPythonText
         |""".stripMargin

+    logInfo(s"Running code: $pythonCode")
     val (exitCode, output) = executePythonCode(pythonCode)

     if (exitCode != 0) {
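The buildGraph change drops the try: wrapper around the generated Python, so the embedded user code now sits at module scope and the re-indentation pad shrinks accordingly (the page scrape collapsed the exact whitespace, so the pad widths shown above are a best-effort reconstruction). A minimal runnable sketch of the indentation helper, with an assumed four-space pad:

    object IndentDemo extends App {
      // Re-indent user code so it nests under the generated `with` block;
      // the pad width is an assumption, not taken from the commit.
      def indentForWithBlock(python: String, pad: String = "    "): String =
        python.linesIterator.map(pad + _).mkString("\n")

      assert(indentForWithBlock("a = 1\nb = 2") == "    a = 1\n    b = 2")
    }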

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala

Lines changed: 0 additions & 2 deletions
@@ -17,8 +17,6 @@

 package org.apache.spark.sql.connect.pipelines

-import scala.concurrent.duration.DurationInt
-
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{DatasetType, Expression, PipelineCommand, Relation, UnresolvedTableValuedFunction}
 import org.apache.spark.connect.proto.PipelineCommand.{DefineDataset, DefineFlow}

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala

Lines changed: 3 additions & 1 deletion
@@ -20,12 +20,14 @@ package org.apache.spark.sql.connect.pipelines
 import org.apache.spark.connect.{proto => sc}
 import org.apache.spark.sql.connect.{SparkConnectServerTest, SparkConnectTestUtils}
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.service.{SessionKey, SparkConnectService}
 import org.apache.spark.sql.pipelines.utils.PipelineTest

 class SparkDeclarativePipelinesServerTest extends SparkConnectServerTest {

   override def afterEach(): Unit = {
-    PipelineExecutionHolder.stopAllPipelineExecutions()
+    SparkConnectService.sessionManager.getIsolatedSessionIfPresent(
+      SessionKey(defaultUserId, defaultSessionId)).foreach(_.removeAllPipelineExecutions())
     DataflowGraphRegistry.dropAllDataflowGraphs()
     PipelineTest.cleanupMetastore(spark)
     super.afterEach()

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala

Lines changed: 17 additions & 0 deletions
@@ -37,6 +37,8 @@ import org.apache.spark.sql.connect.common.InvalidPlanInput
 import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.planner.{PythonStreamingQueryListener, SparkConnectPlanner, StreamingForeachBatchHelper}
 import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper.RunnerCleaner
+import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl}
+import org.apache.spark.sql.pipelines.logging.PipelineEvent
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.util.ArrayImplicits._

@@ -422,4 +424,19 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
     }
     assert(ex.getMessage.contains("already exists"))
   }
+
+  test("Pipeline execution cache") {
+    val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+    val graphId = "test_graph"
+    val pipelineUpdateContext = new PipelineUpdateContextImpl(
+      new DataflowGraph(Seq(), Seq(), Seq()),
+      (_: PipelineEvent) => None
+    )
+    sessionHolder.cachePipelineExecution(graphId, pipelineUpdateContext)
+    assert(
+      sessionHolder.getPipelineExecution(graphId).nonEmpty, "pipeline execution was not cached")
+    sessionHolder.removeAllPipelineExecutions()
+    assert(
+      sessionHolder.getPipelineExecution(graphId).isEmpty, "pipeline execution was not removed")
+  }
 }

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala

Lines changed: 19 additions & 1 deletion
@@ -23,6 +23,8 @@ import org.scalatest.BeforeAndAfterEach
 import org.scalatest.time.SpanSugar._

 import org.apache.spark.SparkSQLException
+import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl}
+import org.apache.spark.sql.pipelines.logging.PipelineEvent
 import org.apache.spark.sql.test.SharedSparkSession

 class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndAfterEach {

@@ -136,7 +138,6 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA
   test("SessionHolder is recorded with status closed after close") {
     val key = SessionKey("user", UUID.randomUUID().toString)
     val sessionHolder = SparkConnectService.sessionManager.getOrCreateIsolatedSession(key, None)
-
     val activeSessionInfo = SparkConnectService.sessionManager.listActiveSessions.find(
       _.sessionId == sessionHolder.sessionId)
     assert(activeSessionInfo.isDefined)

@@ -152,4 +153,21 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA
     assert(closedSessionInfo.get.status == SessionStatus.Closed)
     assert(closedSessionInfo.get.closedTimeMs.isDefined)
   }
+
+
+  test("Pipeline execution cache is cleared when the session holder is closed") {
+    val key = SessionKey("user", UUID.randomUUID().toString)
+    val sessionHolder = SparkConnectService.sessionManager.getOrCreateIsolatedSession(key, None)
+    val graphId = "test_graph"
+    val pipelineUpdateContext = new PipelineUpdateContextImpl(
+      new DataflowGraph(Seq(), Seq(), Seq()),
+      (_: PipelineEvent) => None
+    )
+    sessionHolder.cachePipelineExecution(graphId, pipelineUpdateContext)
+    assert(
+      sessionHolder.getPipelineExecution(graphId).nonEmpty, "pipeline execution was not cached")
+    sessionHolder.close()
+    assert(
+      sessionHolder.getPipelineExecution(graphId).isEmpty, "pipeline execution was not removed")
+  }
 }

sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala

Lines changed: 2 additions & 4 deletions
@@ -606,7 +606,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest {

     val graph = pipelineDef.toDataflowGraph
     val updateContext = TestPipelineUpdateContext(spark, graph)
-    updateContext.pipelineExecution.runPipeline()
+    updateContext.pipelineExecution.startPipeline()

     val graphExecution = updateContext.pipelineExecution.graphExecution.get

@@ -1033,9 +1033,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest {
     }.toDataflowGraph

     val updateContext = TestPipelineUpdateContext(spark = spark, unresolvedGraph = graph)
-    intercept[UnresolvedPipelineException] {
-      updateContext.pipelineExecution.runPipeline()
-    }
+    updateContext.pipelineExecution.runPipeline()

     assertFlowProgressEvent(
       updateContext.eventBuffer,
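Both hunks track a behavior change in PipelineExecution: the test that inspects graphExecution mid-run now calls startPipeline() (apparently a start that does not await completion), and the unresolved-graph test no longer intercepts UnresolvedPipelineException, because runPipeline() now reports the failure through the event buffer that assertFlowProgressEvent checks. A toy, self-contained sketch of that report-as-events contract (hypothetical names, not Spark's API):

    import scala.collection.mutable

    object EventBufferDemo extends App {
      final case class FlowEvent(flow: String, state: String, error: Option[String])

      final class EventBuffer {
        private val events = mutable.ArrayBuffer.empty[FlowEvent]
        def add(e: FlowEvent): Unit = synchronized { events += e }
        def snapshot: Seq[FlowEvent] = synchronized { events.toSeq }
      }

      // Run the body; on failure, record a FAILED event instead of propagating,
      // so callers assert on the buffer rather than intercepting an exception.
      def runPipeline(buffer: EventBuffer)(body: => Unit): Unit =
        try body
        catch {
          case e: Exception =>
            buffer.add(FlowEvent("pipeline", "FAILED", Some(e.getMessage)))
        }

      val buf = new EventBuffer
      runPipeline(buf) { throw new IllegalStateException("unresolved flow") }
      assert(buf.snapshot.exists(_.state == "FAILED"))
    }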
