Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import scala.concurrent.Awaitable
import scala.concurrent.duration.Duration
import scala.util.control.NonFatal

import org.apache.spark.{SparkException, SparkThrowable}

private[spark] object SparkThreadUtils {
// scalastyle:off awaitresult
Expand All @@ -41,13 +41,30 @@ private[spark] object SparkThreadUtils {
*/
@throws(classOf[SparkException])
def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T =
  // Default path: structured SparkThrowable errors are wrapped like any other exception.
  awaitResult(awaitable, atMost, preserveSparkThrowable = false)

@throws(classOf[SparkException])
def awaitResult[T](
awaitable: Awaitable[T],
atMost: Duration,
preserveSparkThrowable: Boolean): T = {
try {
awaitResultNoSparkExceptionConversion(awaitable, atMost)
} catch {
case e: SparkFatalException =>
throw e.throwable
// TimeoutException is thrown in the current thread, so no need to wrap
// the exception.
// Re-throw exceptions that already carry a structured condition (SparkThrowable)
// to avoid wrapping them in a generic SparkException and losing the SQL state.
case st: Exception with SparkThrowable
if preserveSparkThrowable
&& !st.isInstanceOf[TimeoutException] && st.getCondition != null =>
// Attach the caller's stack trace so it's not lost when re-throwing from a worker thread.
st.addSuppressed(
new SparkException("Exception thrown in awaitResult", cause = null))
throw st
case NonFatal(t)
if !t.isInstanceOf[TimeoutException] =>
throw new SparkException("Exception thrown in awaitResult: ", t)
Expand Down
43 changes: 40 additions & 3 deletions core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import scala.util.control.NonFatal

import com.google.common.util.concurrent.ThreadFactoryBuilder

import org.apache.spark.{SparkException, SparkThrowable}

private[spark] object ThreadUtils {

Expand Down Expand Up @@ -358,10 +358,26 @@ private[spark] object ThreadUtils {
def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T = {
  // Default behavior: structured SparkThrowable errors are wrapped in a SparkException.
  SparkThreadUtils.awaitResult(awaitable, atMost, preserveSparkThrowable = false)
}

// Awaits `awaitable` for at most `atMost`. When `preserveSparkThrowable` is true, exceptions
// implementing SparkThrowable with a non-null condition are re-thrown as-is instead of being
// wrapped in a generic SparkException, so the structured error class / SQL state survives.
@throws(classOf[SparkException])
def awaitResult[T](
    awaitable: Awaitable[T],
    atMost: Duration,
    preserveSparkThrowable: Boolean): T =
  SparkThreadUtils.awaitResult(awaitable, atMost, preserveSparkThrowable)
// scalastyle:on awaitresult

// Java-future variant; by default structured SparkThrowable errors are wrapped.
@throws(classOf[SparkException])
def awaitResult[T](future: JFuture[T], atMost: Duration): T =
  awaitResult(future, atMost, preserveSparkThrowable = false)

@throws(classOf[SparkException])
def awaitResult[T](
future: JFuture[T],
atMost: Duration,
preserveSparkThrowable: Boolean): T = {
try {
atMost match {
case Duration.Inf => future.get()
Expand All @@ -370,6 +386,16 @@ private[spark] object ThreadUtils {
} catch {
case e: SparkFatalException =>
throw e.throwable
// JFuture.get() wraps exceptions in ExecutionException. Unwrap and check if the
// cause carries a structured condition (SparkThrowable) to preserve the SQL state.
case e: ExecutionException
if preserveSparkThrowable
&& e.getCause.isInstanceOf[SparkThrowable]
&& e.getCause.asInstanceOf[SparkThrowable].getCondition != null =>
// Attach the caller's stack trace so it's not lost when re-throwing from a worker thread.
e.getCause.addSuppressed(
new SparkException("Exception thrown in awaitResult", cause = null))
throw e.getCause
case NonFatal(t)
if !t.isInstanceOf[TimeoutException] =>
throw new SparkException("Exception thrown in awaitResult: ", t)
Expand Down Expand Up @@ -407,6 +433,11 @@ private[spark] object ThreadUtils {
}
}

/**
 * Convenience overload of [[parmap]] that wraps structured SparkThrowable errors like any
 * other exception. See the four-argument overload for full documentation.
 */
def parmap[I, O](in: Seq[I], prefix: String, maxThreads: Int)(f: I => O): Seq[O] =
  parmap(in, prefix, maxThreads, preserveSparkThrowable = false)(f)

/**
* Transforms input collection by applying the given function to each element in parallel fashion.
* Comparing to the map() method of Scala parallel collections, this method can be interrupted
Expand All @@ -419,21 +450,27 @@ private[spark] object ThreadUtils {
* @param in - the input collection which should be transformed in parallel.
* @param prefix - the prefix assigned to the underlying thread pool.
* @param maxThreads - maximum number of thread can be created during execution.
* @param preserveSparkThrowable if true, re-throw exceptions that already carry a structured
* error class (SparkThrowable) instead of wrapping them in a generic SparkException.
* @param f - the lambda function will be applied to each element of `in`.
* @tparam I - the type of elements in the input collection.
* @tparam O - the type of elements in resulted collection.
* @return new collection in which each element was given from the input collection `in` by
* applying the lambda function `f`.
*/
def parmap[I, O](in: Seq[I], prefix: String, maxThreads: Int)(f: I => O): Seq[O] = {
def parmap[I, O](
in: Seq[I],
prefix: String,
maxThreads: Int,
preserveSparkThrowable: Boolean)(f: I => O): Seq[O] = {
val pool = newForkJoinPool(prefix, maxThreads)
try {
implicit val ec: ExecutionContextExecutor = ExecutionContext.fromExecutor(pool)

val futures = in.map(x => Future(f(x)))
val futureSeq = Future.sequence(futures)

awaitResult(futureSeq, Duration.Inf)
awaitResult(futureSeq, Duration.Inf, preserveSparkThrowable)
} finally {
pool.shutdownNow()
}
Expand Down
88 changes: 87 additions & 1 deletion core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import scala.util.Random

import org.scalatest.concurrent.Eventually._

import org.apache.spark.{SparkException, SparkFunSuite, SparkThrowable}

class ThreadUtilsSuite extends SparkFunSuite {

Expand Down Expand Up @@ -229,4 +229,90 @@ class ThreadUtilsSuite extends SparkFunSuite {
assert(!t.isAlive)
}
}

test("awaitResult preserves SparkThrowable when flag is true") {
  import java.io.IOException

  // Build a fresh exception per sub-case: awaitResult calls addSuppressed on the thrown
  // instance, so sharing one instance across sub-cases would leak suppressed exceptions
  // (and stack traces) from one assertion into the next.
  def structuredError(): RuntimeException with SparkThrowable =
    new RuntimeException("structured error") with SparkThrowable {
      override def getCondition: String = "TEST_ERROR_CLASS"
      override def getMessageParameters: java.util.Map[String, String] =
        java.util.Collections.emptyMap()
    }

  // With preserveSparkThrowable=true, the SparkThrowable is re-thrown directly.
  val f1 = Future {
    throw structuredError()
  }(ThreadUtils.sameThread)
  val caught1 = intercept[RuntimeException] {
    ThreadUtils.awaitResult(f1, 1.seconds, preserveSparkThrowable = true)
  }
  assert(caught1.isInstanceOf[SparkThrowable])
  assert(caught1.asInstanceOf[SparkThrowable].getCondition == "TEST_ERROR_CLASS")
  // The caller-side stack trace is attached as a suppressed exception.
  assert(caught1.getSuppressed.nonEmpty)

  // With preserveSparkThrowable=false (default), SparkThrowable is wrapped in SparkException.
  val f2 = Future {
    throw structuredError()
  }(ThreadUtils.sameThread)
  val caught2 = intercept[SparkException] {
    ThreadUtils.awaitResult(f2, 1.seconds)
  }
  assert(caught2.getCause.isInstanceOf[SparkThrowable])

  // Plain exceptions are always wrapped regardless of the flag.
  val plainEx = new IOException("plain error")
  val f3 = Future {
    throw plainEx
  }(ThreadUtils.sameThread)
  val caught3 = intercept[SparkException] {
    ThreadUtils.awaitResult(f3, 1.seconds, preserveSparkThrowable = true)
  }
  assert(caught3.getCause eq plainEx)
}

test("awaitResult (JFuture) preserves SparkThrowable when flag is true") {
  val structuredError = new RuntimeException("structured error") with SparkThrowable {
    override def getCondition: String = "TEST_ERROR_CLASS"
    override def getMessageParameters: java.util.Map[String, String] =
      java.util.Collections.emptyMap()
  }

  // CompletableFuture.get() reports the failure via an ExecutionException; awaitResult
  // should unwrap it and re-throw the SparkThrowable cause directly when the flag is set.
  // scalastyle:off sparkThreadPools
  val failedFuture = new java.util.concurrent.CompletableFuture[String]()
  // scalastyle:on sparkThreadPools
  failedFuture.completeExceptionally(structuredError)

  val thrown = intercept[RuntimeException] {
    ThreadUtils.awaitResult(failedFuture, 10.seconds, preserveSparkThrowable = true)
  }
  assert(thrown.isInstanceOf[SparkThrowable])
  assert(thrown.asInstanceOf[SparkThrowable].getCondition == "TEST_ERROR_CLASS")
  // The caller-side stack trace is attached as a suppressed exception.
  assert(thrown.getSuppressed.nonEmpty)
}

test("parmap preserves SparkThrowable when flag is true") {
  // A fresh exception per sub-case prevents the suppressed exception attached by the first
  // parmap/awaitResult call from leaking into the assertions of the second call.
  def structuredError(): RuntimeException with SparkThrowable =
    new RuntimeException("structured error") with SparkThrowable {
      override def getCondition: String = "TEST_ERROR_CLASS"
      override def getMessageParameters: java.util.Map[String, String] =
        java.util.Collections.emptyMap()
    }

  // With preserveSparkThrowable=true, the original SparkThrowable is re-thrown.
  val caught1 = intercept[RuntimeException] {
    ThreadUtils.parmap(Seq(1), "test", 1, preserveSparkThrowable = true) { _ =>
      throw structuredError()
    }
  }
  assert(caught1.isInstanceOf[SparkThrowable])
  assert(caught1.asInstanceOf[SparkThrowable].getCondition == "TEST_ERROR_CLASS")
  assert(caught1.getSuppressed.nonEmpty)

  // With preserveSparkThrowable=false, it is wrapped in SparkException.
  val caught2 = intercept[SparkException] {
    ThreadUtils.parmap(Seq(1), "test", 1, preserveSparkThrowable = false) { _ =>
      throw structuredError()
    }
  }
  assert(caught2.getCause.isInstanceOf[SparkThrowable])
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,8 @@ object ParquetFileFormat extends Logging {
partFiles: Seq[FileStatus],
ignoreCorruptFiles: Boolean,
ignoreMissingFiles: Boolean = false): Seq[Footer] = {
ThreadUtils.parmap(partFiles, "readingParquetFooters", 8) { currentFile =>
ThreadUtils.parmap(partFiles, "readingParquetFooters", 8,
preserveSparkThrowable = true) { currentFile =>
try {
// Skips row group information since we only need the schema.
// ParquetFileReader.readFooter throws RuntimeException, instead of IOException,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ abstract class ParquetFileFormatSuite

/**
 * Asserts that `body` fails with the structured FAILED_READ_FILE.CANNOT_READ_FILE_FOOTER
 * error. The intercepted SparkException is checked directly (no `.getCause` unwrapping),
 * since awaitResult now re-throws structured errors without wrapping them.
 */
private def checkCannotReadFooterError(body: => Unit): Unit = {
  checkErrorMatchPVals(
    exception = intercept[SparkException] { body },
    condition = "FAILED_READ_FILE.CANNOT_READ_FILE_FOOTER",
    parameters = Map("path" -> "file:.*")
  )
}
Expand Down Expand Up @@ -97,10 +97,12 @@ abstract class ParquetFileFormatSuite
}

testReadFooters(true)
// With preserveSparkThrowable=true, the structured error class is thrown directly
// without being wrapped in a generic SparkException by awaitResult, so the intercepted
// exception itself (not its cause) carries the condition.
checkErrorMatchPVals(
  exception = intercept[SparkException] {
    testReadFooters(false)
  },
  condition = "FAILED_READ_FILE.CANNOT_READ_FILE_FOOTER",
  parameters = Map("path" -> "file:.*")
)
Expand Down Expand Up @@ -143,7 +145,7 @@ abstract class ParquetFileFormatSuite
// The intercepted SparkException is checked directly (no `.getCause` unwrapping), since
// the structured error is re-thrown unwrapped by awaitResult.
exception = intercept[SparkException] {
  ParquetFileFormat.readParquetFootersInParallel(
    conf, Seq(fakeStatus), ignoreCorruptFiles = false, ignoreMissingFiles = false)
},
condition = "FAILED_READ_FILE.CANNOT_READ_FILE_FOOTER",
parameters = Map("path" -> s"${WrappingFNFLocalFileSystem.scheme}:.*")
)
Expand Down