5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -92,9 +92,10 @@ class SparkEnv (
* @param daemonModule The daemon module name to reuse the worker, e.g., "pyspark.daemon".
* @param envVars The environment variables for the worker.
*/
private case class PythonWorkersKey(
case class PythonWorkersKey(
pythonExec: String, workerModule: String, daemonModule: String, envVars: Map[String, String])
private val pythonWorkers = mutable.HashMap[PythonWorkersKey, PythonWorkerFactory]()
val pythonWorkers: mutable.Map[PythonWorkersKey, PythonWorkerFactory] =
mutable.HashMap[PythonWorkersKey, PythonWorkerFactory]()

// A general, soft-reference map for metadata needed during HadoopRDD split computation
// (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
@@ -281,6 +281,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
}
// allow the user to set the batch size for the BatchedSerializer on UDFs
envVars.put("PYTHON_UDF_BATCH_SIZE", batchSizeForPythonUDF.toString)
envVars.put("PYSPARK_RUNTIME_PROFILE", true.toString)

envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default"))

@@ -38,7 +38,9 @@ import org.apache.spark.internal.config.Python.PYTHON_FACTORY_IDLE_WORKER_MAX_PO
import org.apache.spark.security.SocketAuthHelper
import org.apache.spark.util.{RedirectThread, Utils}

case class PythonWorker(channel: SocketChannel) {
case class PythonWorker(
channel: SocketChannel,
extraChannel: Option[SocketChannel] = None) {

private[this] var selectorOpt: Option[Selector] = None
private[this] var selectionKeyOpt: Option[SelectionKey] = None
@@ -68,6 +70,7 @@ case class PythonWorker(channel: SocketChannel) {
def stop(): Unit = synchronized {
closeSelector()
Option(channel).foreach(_.close())
extraChannel.foreach(_.close())
}
}

@@ -129,6 +132,10 @@ private[spark] class PythonWorkerFactory(
envVars.getOrElse("PYTHONPATH", ""),
sys.env.getOrElse("PYTHONPATH", ""))

def getAllDaemonWorkers: Seq[(PythonWorker, ProcessHandle)] = self.synchronized {
daemonWorkers.filter { case (_, handle) => handle.isAlive}.toSeq
}

def create(): (PythonWorker, Option[ProcessHandle]) = {
if (useDaemon) {
self.synchronized {
@@ -163,22 +170,36 @@ private[spark] class PythonWorkerFactory(
private def createThroughDaemon(): (PythonWorker, Option[ProcessHandle]) = {

def createWorker(): (PythonWorker, Option[ProcessHandle]) = {
val socketChannel = if (isUnixDomainSock) {
val mainChannel = if (isUnixDomainSock) {
SocketChannel.open(UnixDomainSocketAddress.of(daemonSockPath))
} else {
SocketChannel.open(new InetSocketAddress(daemonHost, daemonPort))
}

val extraChannel = if (envVars.getOrElse("PYSPARK_RUNTIME_PROFILE", "false").toBoolean) {
if (isUnixDomainSock) {
Some(SocketChannel.open(UnixDomainSocketAddress.of(daemonSockPath)))
} else {
Some(SocketChannel.open(new InetSocketAddress(daemonHost, daemonPort)))
}
} else {
None
}

// These calls are blocking.
val pid = new DataInputStream(Channels.newInputStream(socketChannel)).readInt()
val pid = new DataInputStream(Channels.newInputStream(mainChannel)).readInt()
if (pid < 0) {
throw new IllegalStateException("Python daemon failed to launch worker with code " + pid)
}
val processHandle = ProcessHandle.of(pid).orElseThrow(
() => new IllegalStateException("Python daemon failed to launch worker.")
)
authHelper.authToServer(socketChannel)
socketChannel.configureBlocking(false)
val worker = PythonWorker(socketChannel)

authHelper.authToServer(mainChannel)
mainChannel.configureBlocking(false)
extraChannel.foreach(_.configureBlocking(true))

val worker = PythonWorker(mainChannel, extraChannel)
daemonWorkers.put(worker, processHandle)
(worker.refresh(), Some(processHandle))
}
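
Note that the daemon pairs the two connections purely by accept order on the same listening socket (or Unix domain socket): the JVM opens the main channel first and, when PYSPARK_RUNTIME_PROFILE is enabled, the profile channel second. A toy Python sketch, not PySpark code, of that positional pairing:

```python
# Toy illustration of the pairing contract: two connects from the JVM side,
# two accepts on the daemon side, matched only by connection order.
import socket
import threading

def toy_daemon(listener: socket.socket) -> None:
    main_conn, _ = listener.accept()     # first connect -> main channel
    profile_conn, _ = listener.accept()  # second connect -> profile channel
    main_conn.sendall(b"main")
    profile_conn.sendall(b"profile")

listener = socket.socket()
listener.bind(("127.0.0.1", 0))
listener.listen(2)
threading.Thread(target=toy_daemon, args=(listener,), daemon=True).start()

host, port = listener.getsockname()
main = socket.create_connection((host, port))    # mirrors mainChannel
extra = socket.create_connection((host, port))   # mirrors extraChannel
print(main.recv(16), extra.recv(16))             # b'main' b'profile'
```
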
@@ -271,7 +292,7 @@ private[spark] class PythonWorkerFactory(
if (!blockingMode) {
socketChannel.configureBlocking(false)
}
val worker = PythonWorker(socketChannel)
val worker = PythonWorker(socketChannel, None)
self.synchronized {
simpleWorkers.put(worker, workerProcess)
}
33 changes: 30 additions & 3 deletions python/pyspark/daemon.py
@@ -47,7 +47,7 @@ def compute_real_exit_code(exit_code):
return 1


def worker(sock, authenticated):
def worker(sock, sock2, authenticated):
"""
Called by a worker process after the fork().
"""
@@ -64,6 +64,9 @@ def worker(sock, authenticated):
buffer_size = int(os.environ.get("SPARK_BUFFER_SIZE", 65536))
infile = os.fdopen(os.dup(sock.fileno()), "rb", buffer_size)
outfile = os.fdopen(os.dup(sock.fileno()), "wb", buffer_size)
outfile2 = None
if sock2 is not None:
outfile2 = os.fdopen(os.dup(sock2.fileno()), "wb", buffer_size)

if not authenticated:
client_secret = UTF8Deserializer().loads(infile)
@@ -74,11 +77,16 @@ def worker(sock, authenticated):
write_with_length("err".encode("utf-8"), outfile)
outfile.flush()
sock.close()
if sock2 is not None:
sock2.close()
return 1

exit_code = 0
try:
worker_main(infile, outfile)
if sock2 is not None:
worker_main(infile, (outfile, outfile2))
else:
worker_main(infile, outfile)
except SystemExit as exc:
exit_code = compute_real_exit_code(exc.code)
finally:
@@ -94,6 +102,7 @@ def manager():
os.setpgid(0, 0)

is_unix_domain_sock = os.environ.get("PYTHON_UNIX_DOMAIN_ENABLED", "false").lower() == "true"
is_python_runtime_profile = os.environ.get("PYSPARK_RUNTIME_PROFILE", "false").lower() == "true"
socket_path = None

# Create a listening socket on the loopback interface
@@ -173,6 +182,15 @@ def handle_sigterm(*args):
continue
raise

sock2 = None
if is_python_runtime_profile:
try:
sock2, _ = listen_sock.accept()
except OSError as e:
if e.errno == EINTR:
continue
raise

# Launch a worker process
try:
pid = os.fork()
@@ -186,6 +204,13 @@ def handle_sigterm(*args):
outfile.flush()
outfile.close()
sock.close()

if sock2 is not None:
outfile = sock2.makefile(mode="wb")
write_int(e.errno, outfile) # Signal that the fork failed
outfile.flush()
outfile.close()
sock2.close()
continue

if pid == 0:
@@ -217,14 +242,16 @@ def handle_sigterm(*args):
or False
)
while True:
code = worker(sock, authenticated)
code = worker(sock, sock2, authenticated)
if code == 0:
authenticated = True
if not reuse or code:
# wait for closing
try:
while sock.recv(1024):
pass
while sock2 is not None and sock2.recv(1024):
pass
except Exception:
pass
break
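
The second connection accepted above becomes outfile2 in the forked worker; the frames that later flow over it are produced by write_profile() in worker.py (next file) via write_with_length, i.e. a 4-byte big-endian length followed by a pickled payload. A minimal decoding sketch, not part of this PR and assuming an already-connected socket:

```python
# Sketch of a consumer for the profile channel's framing: 4-byte big-endian
# length (as written by pyspark.serializers.write_with_length) followed by a
# pickled list of {str: str} dicts. The socket handling here is hypothetical.
import pickle
import socket
import struct

def iter_profile_frames(conn: socket.socket):
    infile = conn.makefile("rb")
    while True:
        header = infile.read(4)
        if len(header) < 4:          # channel closed by the worker
            return
        (length,) = struct.unpack("!i", header)
        payload = infile.read(length)
        yield pickle.loads(payload)  # one snapshot: list of per-function dicts
```
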
24 changes: 24 additions & 0 deletions python/pyspark/worker.py
@@ -18,6 +18,8 @@
"""
Worker that receives input from Piped RDD.
"""
import pickle
import threading
import itertools
import os
import sys
@@ -45,6 +47,7 @@
read_bool,
write_long,
read_int,
write_with_length,
SpecialLengths,
CPickleSerializer,
BatchedSerializer,
@@ -3167,7 +3170,28 @@ def func(_, it):
return func, None, ser, ser


def write_profile(outfile):
import yappi

while True:
stats = []
for thread in yappi.get_thread_stats():
data = list(yappi.get_func_stats(ctx_id=thread.id))
stats.extend([{str(k): str(v) for k, v in d.items()} for d in data])
pickled = pickle.dumps(stats)
write_with_length(pickled, outfile)
outfile.flush()
time.sleep(1)


def main(infile, outfile):
if isinstance(outfile, tuple):
import yappi

outfile, outfile2 = outfile
yappi.start()
threading.Thread(target=write_profile, args=(outfile2,), daemon=True).start()

faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
tracebackDumpIntervalSeconds = os.environ.get("PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS", None)
try:
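
write_profile() snapshots yappi's per-thread function statistics once a second and pickles them as a list of string-keyed dicts. A standalone sketch (assuming yappi is installed) of what one of those snapshots contains:

```python
# Standalone illustration of the payload write_profile() serializes: yappi
# function stats gathered per thread. yappi stat entries also behave like
# dicts, which is what lets the worker build {str(k): str(v) for k, v in d.items()}.
import yappi

def busy_loop() -> int:
    return sum(i * i for i in range(200_000))

yappi.start()
busy_loop()

snapshot = []
for thread in yappi.get_thread_stats():
    for stat in yappi.get_func_stats(ctx_id=thread.id):
        snapshot.append({str(k): str(v) for k, v in stat.items()})

yappi.stop()
print(len(snapshot), "function entries, e.g.", snapshot[0].get("name"))
```
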
@@ -26,6 +26,7 @@ org.apache.spark.sql.execution.datasources.xml.XmlFileFormat
org.apache.spark.sql.execution.streaming.ConsoleSinkProvider
org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider
org.apache.spark.sql.execution.streaming.sources.PythonProfileSourceProvider
org.apache.spark.sql.execution.datasources.binaryfile.BinaryFileFormat
org.apache.spark.sql.execution.streaming.sources.RatePerMicroBatchProvider
org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataSource
@@ -0,0 +1,150 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming.sources

import java.io.DataInputStream
import java.nio.channels.Channels
import java.util.concurrent.atomic.AtomicBoolean
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable.ListBuffer
import scala.jdk.CollectionConverters._

import net.razorvine.pickle.Unpickler

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset}
import org.apache.spark.sql.execution.streaming.runtime.LongOffset


class PythonProfileMicroBatchStream
extends MicroBatchStream with Logging {

@GuardedBy("this")
private var readThread: Thread = null

@GuardedBy("this")
private val batches = new ListBuffer[java.util.List[java.util.Map[String, String]]]

@GuardedBy("this")
private var currentOffset: LongOffset = LongOffset(-1L)

@GuardedBy("this")
private var lastOffsetCommitted: LongOffset = LongOffset(-1L)

private val initialized: AtomicBoolean = new AtomicBoolean(false)

private def initialize(): Unit = synchronized {
readThread = new Thread(s"PythonProfileMicroBatchStream") {
setDaemon(true)

override def run(): Unit = {
val unpickler = new Unpickler
val extraChannel = SparkEnv.get.pythonWorkers.values
.head.getAllDaemonWorkers.map(_._1.extraChannel).head
extraChannel.foreach { s =>
val inputStream = new DataInputStream(Channels.newInputStream(s))
while (true) {
val len = inputStream.readInt()
val buf = new Array[Byte](len)
var totalRead = 0
while (totalRead < len) {
val readNow = inputStream.read(buf, totalRead, len - totalRead)
assert(readNow != -1)
totalRead += readNow
}
currentOffset += 1
batches.append(
unpickler.loads(buf).asInstanceOf[java.util.List[java.util.Map[String, String]]])
}
}
}
}
readThread.start()
}

override def initialOffset(): Offset = LongOffset(-1L)

override def latestOffset(): Offset = currentOffset

override def deserializeOffset(json: String): Offset = {
LongOffset(json.toLong)
}

override def planInputPartitions(start: Offset, end: Offset): Array[InputPartition] = {
val startOrdinal = start.asInstanceOf[LongOffset].offset.toInt + 1
val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1

val rawList = synchronized {
if (initialized.compareAndSet(false, true)) {
initialize()
}

val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
batches.slice(sliceStart, sliceEnd)
}

Array(PythonProfileInputPartition(rawList))
}

override def createReaderFactory(): PartitionReaderFactory =
(partition: InputPartition) => {
val stats = partition.asInstanceOf[PythonProfileInputPartition].stats
new PartitionReader[InternalRow] {
private var currentIdx = -1

override def next(): Boolean = {
currentIdx += 1
currentIdx < stats.size
}

override def get(): InternalRow = {
InternalRow.fromSeq(
CatalystTypeConverters.convertToCatalyst(
stats(currentIdx).asScala.toSeq.map(_.asScala)) :: Nil)
}

override def close(): Unit = {}
}
}

override def commit(end: Offset): Unit = synchronized {
val newOffset = end.asInstanceOf[LongOffset]

val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt

if (offsetDiff < 0) {
throw new IllegalStateException(
s"Offsets committed out of order: $lastOffsetCommitted followed by $end")
}

batches.dropInPlace(offsetDiff)
lastOffsetCommitted = newOffset
}

override def toString: String = s"PythonProfile"

override def stop(): Unit = { }
}

case class PythonProfileInputPartition(
stats: ListBuffer[java.util.List[java.util.Map[String, String]]]) extends InputPartition
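
With the provider registered in META-INF/services, the profile stream can be read like any other streaming source. A hypothetical PySpark usage sketch: the provider's DataSourceRegister short name is not shown in this diff, so the fully-qualified class name is used, and the output schema and options are assumptions.

```python
# Hypothetical consumption of the runtime-profile stream from PySpark.
# The format string below is the provider class registered in this PR's
# META-INF/services file; a short name, if one exists, is not shown here.
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("pyspark-runtime-profile-demo").getOrCreate()

profile = (
    spark.readStream
    .format("org.apache.spark.sql.execution.streaming.sources.PythonProfileSourceProvider")
    .load()
)

query = (
    profile.writeStream
    .format("console")
    .option("truncate", False)
    .start()
)
query.awaitTermination()
```
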