@@ -27,7 +27,7 @@ import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.kyuubi.SparkDatasetHelper._

import org.apache.kyuubi.{KyuubiSQLException, Logging}
import org.apache.kyuubi.config.KyuubiConf.{ENGINE_SPARK_OPERATION_INCREMENTAL_COLLECT_CANCEL_JOB_GROUP, OPERATION_RESULT_MAX_ROWS, OPERATION_RESULT_SAVE_TO_FILE, OPERATION_RESULT_SAVE_TO_FILE_MIN_ROWS, OPERATION_RESULT_SAVE_TO_FILE_MINSIZE}
import org.apache.kyuubi.config.KyuubiConf.{ENGINE_SPARK_OPERATION_INCREMENTAL_COLLECT_CANCEL_JOB_GROUP, OPERATION_RESULT_MAX_ROWS, OPERATION_RESULT_PREFETCH, OPERATION_RESULT_SAVE_TO_FILE, OPERATION_RESULT_SAVE_TO_FILE_MIN_ROWS, OPERATION_RESULT_SAVE_TO_FILE_MINSIZE}
import org.apache.kyuubi.engine.spark.KyuubiSparkUtil._
import org.apache.kyuubi.engine.spark.session.{SparkSessionImpl, SparkSQLSessionManager}
import org.apache.kyuubi.operation.{ArrayFetchIterator, FetchIterator, IterableFetchIterator, OperationHandle, OperationState}
@@ -195,7 +195,12 @@ class ExecuteStatement(
}
info(s"Save result to ${saveFilePath.get}")
fetchOrcStatement = Some(new FetchOrcStatement(spark))
return fetchOrcStatement.get.getIterator(saveFilePath.get.toString, resultSchema)
asyncFetchHdfsResultMode = getSessionConf(OPERATION_RESULT_PREFETCH, spark)
return fetchOrcStatement.get.getIterator(
saveFilePath.get.toString,
resultSchema,
getProtocolVersion,
asyncFetchHdfsResultMode)
}
val internalArray = if (resultMaxRows <= 0) {
info("Execute in full collect mode")
@@ -17,6 +17,8 @@

package org.apache.kyuubi.engine.spark.operation

import java.util.concurrent.{ExecutorService, Executors}

import scala.Array._
import scala.collection.mutable.ListBuffer

@@ -35,12 +37,23 @@ import org.apache.spark.sql.execution.datasources.RecordReaderIterator
import org.apache.spark.sql.execution.datasources.orc.OrcDeserializer
import org.apache.spark.sql.types.StructType

import org.apache.kyuubi.Logging
import org.apache.kyuubi.config.KyuubiConf.{OPERATION_RESULT_PREFETCH_QUEUE_SIZE, OPERATION_RESULT_PREFETCH_TIMEOUT}
import org.apache.kyuubi.engine.spark.KyuubiSparkUtil.getSessionConf
import org.apache.kyuubi.operation.{FetchIterator, IterableFetchIterator}
import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TProtocolVersion
import org.apache.kyuubi.util.NamedThreadFactory

class FetchOrcStatement(spark: SparkSession) {
class FetchOrcStatement(spark: SparkSession) extends Logging {

var orcIter: OrcFileIterator = _
def getIterator(path: String, orcSchema: StructType): FetchIterator[Row] = {
var fetchThreadPool: ExecutorService = _

def getIterator(
path: String,
orcSchema: StructType,
protocolVersion: TProtocolVersion,
asyncFetchEnabled: Boolean): FetchIterator[Row] = {
val conf = spark.sparkContext.hadoopConfiguration
val savePath = new Path(path)
val fsIterator = savePath.getFileSystem(conf).listFiles(savePath, false)
@@ -64,13 +77,38 @@ class FetchOrcStatement(spark: SparkSession) {
val iterRow = orcIter.map(value =>
unsafeProjection(deserializer.deserialize(value)))
.map(value => toRowConverter(value))
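// When prefetching is enabled, wrap the row iterator in an async fetch iterator that builds
// TRowSets on a dedicated thread; otherwise fall back to the plain iterable fetch iterator.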
if (asyncFetchEnabled) {
info("Creating thread pool for result prefetching")
fetchThreadPool =
Executors.newFixedThreadPool(1, new NamedThreadFactory("Result-Prefetch-Pool", false))
val asyncFetchTimeout: Long = getSessionConf(OPERATION_RESULT_PREFETCH_TIMEOUT, spark)
val rowsetsQueueSize: Int = getSessionConf(OPERATION_RESULT_PREFETCH_QUEUE_SIZE, spark)
new IterableAsyncFetchIterator[Row](
new Iterable[Row] {
override def iterator: Iterator[Row] = iterRow
},
fetchThreadPool,
orcSchema,
protocolVersion,
asyncFetchTimeout,
rowsetsQueueSize)
} else {
new IterableFetchIterator[Row](new Iterable[Row] {
override def iterator: Iterator[Row] = iterRow
})
}

new IterableFetchIterator[Row](new Iterable[Row] {
override def iterator: Iterator[Row] = iterRow
})
}

def close(): Unit = {
orcIter.close()
if (fetchThreadPool != null) {
fetchThreadPool.shutdown()
info("Closing the result prefetch thread pool")
}
}
}

@@ -0,0 +1,174 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.kyuubi.engine.spark.operation

import java.util.concurrent._

import scala.reflect.ClassTag

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType

import org.apache.kyuubi.Logging
import org.apache.kyuubi.engine.spark.schema.SparkTRowSetGenerator
import org.apache.kyuubi.operation.FetchIterator
import org.apache.kyuubi.shaded.hive.service.rpc.thrift.{TProtocolVersion, TRowSet}

class IterableAsyncFetchIterator[A: ClassTag](
iterable: Iterable[A],
prefetchExecutor: ExecutorService,
resultSchema: StructType,
protocolVersion: TProtocolVersion,
asyncFetchTimeout: Long,
rowsetsQueueSize: Int) extends FetchIterator[A] with Logging {
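// Converts rows into Thrift row sets on a dedicated prefetch thread and buffers them in a
// bounded queue, so the RPC thread only needs to poll prefetched row sets via takeRowSet.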
val results = new LinkedBlockingQueue[(TRowSet, Int)](rowsetsQueueSize)
private var future: Future[Boolean] = _
private val tRowSetGenerator = new SparkTRowSetGenerator()
var rowSetSize: Int = -1

private var iter: Iterator[A] = iterable.iterator
private var iterEx: Iterator[Row] = iter.asInstanceOf[Iterator[Row]]

private var fetchStart: Long = 0
private var position: Long = 0

private def startPrefetchThread(): Unit = {
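// Builds a task that repeatedly takes up to rowSetSize rows, converts them into a TRowSet
// and puts it into the bounded queue; an empty batch marks the end of the iterator.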
val task = new Callable[Boolean] {
override def call(): Boolean = {
var succeeded = true
try {
var isEmptyIter = false
while (!isEmptyIter) {
val taken = iterEx.take(rowSetSize)
val rows = taken.toArray
val rowSet = tRowSetGenerator.toTRowSet(
rows.toSeq,
resultSchema,
protocolVersion)
results.put(rowSet, rows.length)
isEmptyIter = (rows.length == 0)
}
} catch {
case e: Throwable =>
error(s"An exception occurred in the prefetch thread: ${e.getMessage}", e)
results.clear()
succeeded = false
}
succeeded
}
}
future = prefetchExecutor.submit(task)
}

def takeRowSet(numRows: Int): TRowSet = {
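// The first call fixes the row set size and starts the prefetch task; subsequent calls must
// request the same size and simply poll the next prefetched TRowSet from the queue.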
if (rowSetSize < 0) {
rowSetSize = numRows
startPrefetchThread()
} else {
if (numRows != rowSetSize) {
throw new Exception("The requested row set size differs from " +
"the size used in the initial call.")
}
}
var result: (TRowSet, Int) = results.poll(3000, TimeUnit.MILLISECONDS)
if (result == null && !future.isDone) {
warn("Queue of prefetched results is empty, the prefetch thread may be stuck; " +
s"retrying with a timeout of ${asyncFetchTimeout} ms")
result = results.poll(asyncFetchTimeout, TimeUnit.MILLISECONDS)
}
if (result == null) {
if (future.isDone) {
val prefetchSucceeded = future.get()
if (!prefetchSucceeded) {
throw new Exception("An exception occurred in the prefetch thread.")
} else {
result = results.poll()
if (result == null) {
error("Thrift client fetched past the final empty rowSet, returning another empty rowSet")
result = (
tRowSetGenerator.toTRowSet(
Seq.empty[Row],
resultSchema,
protocolVersion),
0)
}
}
} else {
future.cancel(true)
throw new Exception("The prefetch thread is stuck.")
}
}
position += result._2
result._1
}

/**
* Begin a fetch block, forward from the current position.
* Resets the fetch start offset.
*/
override def fetchNext(): Unit = fetchStart = position

/**
* Begin a fetch block, moving the iterator to the given position.
* Resets the fetch start offset.
*
* @param pos index to move a position of iterator.
*/
override def fetchAbsolute(pos: Long): Unit = {
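// Cancel any in-flight prefetch task and drop buffered row sets before repositioning.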
if (future != null) {
future.cancel(true)
try {
future.get(5, TimeUnit.SECONDS)
} catch {
case _: CancellationException => // expected when the in-flight task was cancelled
case _: TimeoutException => // fall through to the isDone check below
}
if (!future.isDone) {
throw new Exception("Failed to cancel the prefetch thread")
}
future = null
results.clear()
}
val newPos = pos max 0
resetPosition()
while (position < newPos && hasNextInternal) {
nextInternal()
}
rowSetSize = -1
}

override def getFetchStart: Long = fetchStart

override def getPosition: Long = position

override def hasNext: Boolean = {
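// Row-level iteration is not exposed: rows are converted to TRowSets on the prefetch
// thread, so callers must use takeRowSet instead of hasNext/next.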
throw new Exception("Unsupported function: IterableAsyncFetchIterator.hasNext")
}
def hasNextInternal: Boolean = iter.hasNext
override def next(): A = {
throw new Exception("Unsupported function: IterableAsyncFetchIterator.next")
}
def nextInternal(): A = {
position += 1
iter.next()
}

private def resetPosition(): Unit = {
if (position != 0) {
iter = iterable.iterator
iterEx = iter.asInstanceOf[Iterator[Row]]
position = 0
fetchStart = 0
}
}
}
@@ -88,7 +88,7 @@ abstract class SparkOperation(session: Session)
}

private val progressEnable: Boolean = getSessionConf(SESSION_PROGRESS_ENABLE, spark)

var asyncFetchHdfsResultMode: Boolean = false
protected def supportProgress: Boolean = false

protected def outputMode: EngineSparkOutputMode.EngineSparkOutputMode =
@@ -261,13 +261,19 @@
try {
withLocalProperties {
validateDefaultFetchOrientation(order)
if (asyncFetchHdfsResultMode && order != FETCH_NEXT) {
throw KyuubiSQLException(s"The fetch orientation ${order} is not supported for this result set.")
}
assertState(OperationState.FINISHED)
setHasResultSet(true)
order match {
case FETCH_NEXT => iter.fetchNext()
case FETCH_PRIOR => iter.fetchPrior(rowSetSize);
case FETCH_FIRST => iter.fetchAbsolute(0);
}
if (iter.getPosition <= 0) {
info(s"Fetching the first rowSet, order: ${order}, rowSetSize: ${rowSetSize}")
}
resultRowSet =
if (isArrowBasedOperation) {
if (iter.hasNext) {
@@ -280,11 +286,16 @@
ThriftUtils.newEmptyRowSet
}
} else {
val taken = iter.take(rowSetSize)
new SparkTRowSetGenerator().toTRowSet(
taken.toSeq.asInstanceOf[Seq[Row]],
resultSchema,
getProtocolVersion)
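// In prefetch mode the TRowSet has already been built by the prefetch thread, so take it
// from the async iterator directly instead of converting rows on this thread.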
if (asyncFetchHdfsResultMode) {
val rowSet = iter.asInstanceOf[IterableAsyncFetchIterator[Row]].takeRowSet(rowSetSize)
rowSet
} else {
val taken = iter.take(rowSetSize)
new SparkTRowSetGenerator().toTRowSet(
taken.toSeq.asInstanceOf[Seq[Row]],
resultSchema,
getProtocolVersion)
}
}
resultRowSet.setStartRowOffset(iter.getPosition)
}