Skip to content

perf: Use a global tokio runtime #1614

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions common/src/main/java/org/apache/comet/parquet/Native.java
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,7 @@ public static native long initRecordBatchReader(
byte[] requiredSchema,
byte[] dataSchema,
String sessionTimezone,
int batchSize,
int workerThreads,
int blockingThreads);
int batchSize);

// arrow native version of read batch
/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,16 +357,6 @@ public void init() throws URISyntaxException, IOException {
conf.getInt(
CometConf.COMET_BATCH_SIZE().key(),
(Integer) CometConf.COMET_BATCH_SIZE().defaultValue().get());
int workerThreads =
conf.getInt(
CometConf.COMET_WORKER_THREADS().key(),
(Integer) CometConf.COMET_WORKER_THREADS().defaultValue().get());
;
int blockingThreads =
conf.getInt(
CometConf.COMET_BLOCKING_THREADS().key(),
(Integer) CometConf.COMET_BLOCKING_THREADS().defaultValue().get());
;
this.handle =
Native.initRecordBatchReader(
filePath,
Expand All @@ -377,9 +367,7 @@ public void init() throws URISyntaxException, IOException {
serializedRequestedArrowSchema,
serializedDataArrowSchema,
timeZoneId,
batchSize,
workerThreads,
blockingThreads);
batchSize);
isInitialized = true;
}

Expand Down
16 changes: 0 additions & 16 deletions common/src/main/scala/org/apache/comet/CometConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -457,22 +457,6 @@ object CometConf extends ShimCometConf {
.booleanConf
.createWithDefault(false)

val COMET_WORKER_THREADS: ConfigEntry[Int] =
conf("spark.comet.workerThreads")
.internal()
.doc("The number of worker threads used for Comet native execution. " +
"By default, this config is 4.")
.intConf
.createWithDefault(4)

val COMET_BLOCKING_THREADS: ConfigEntry[Int] =
conf("spark.comet.blockingThreads")
.internal()
.doc("The number of blocking threads used for Comet native execution. " +
"By default, this config is 10.")
.intConf
.createWithDefault(10)

val COMET_BATCH_SIZE: ConfigEntry[Int] = conf("spark.comet.batchSize")
.doc("The columnar batch size, i.e., the maximum number of rows that a batch can contain.")
.intConf
Expand Down
9 changes: 9 additions & 0 deletions docs/source/user-guide/tuning.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ under the License.

Comet provides some tuning options to help you get the best performance from your queries.

## Configuring Tokio Runtime

Comet uses a single global Tokio runtime per executor process, with Tokio's defaults of one worker thread per core
and a maximum of 512 blocking threads. These defaults can be overridden using the environment variables
`COMET_WORKER_THREADS` and `COMET_MAX_BLOCKING_THREADS`.

DataFusion currently has a known issue when merging spill files in sort operators: the process can deadlock if
there are more spill files than `COMET_MAX_BLOCKING_THREADS` ([tracking issue](https://github.com/apache/datafusion/issues/15323)).

## Memory Tuning

It is necessary to specify how much memory Comet can use in addition to memory already allocated to Spark. In some
Expand Down
37 changes: 24 additions & 13 deletions native/core/src/execution/jni_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,29 @@ use crate::execution::spark_plan::SparkPlan;
use log::info;
use once_cell::sync::{Lazy, OnceCell};

/// Process-wide multi-threaded Tokio runtime, created lazily on first access.
///
/// Thread counts fall back to Tokio's defaults (one worker thread per core,
/// up to 512 blocking threads) unless overridden via the environment
/// variables `COMET_WORKER_THREADS` / `COMET_MAX_BLOCKING_THREADS`.
static TOKIO_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
    let worker_threads = parse_usize_env_var("COMET_WORKER_THREADS");
    let max_blocking_threads = parse_usize_env_var("COMET_MAX_BLOCKING_THREADS");

    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.enable_all();
    if let Some(workers) = worker_threads {
        builder.worker_threads(workers);
    }
    if let Some(blocking) = max_blocking_threads {
        builder.max_blocking_threads(blocking);
    }
    builder.build().expect("Failed to create Tokio runtime")
});

/// Reads the environment variable `name` and parses it as a `usize`.
///
/// Returns `None` when the variable is unset, is not valid UTF-8, or does
/// not parse as a non-negative integer.
fn parse_usize_env_var(name: &str) -> Option<usize> {
    let raw = std::env::var_os(name)?;
    raw.to_str()?.parse().ok()
}

/// Function to get a handle to the global Tokio runtime
pub fn get_runtime() -> &'static Runtime {
&TOKIO_RUNTIME
}

/// Comet native execution context. Kept alive across JNI calls.
struct ExecutionContext {
/// The id of the execution context.
Expand All @@ -89,8 +112,6 @@ struct ExecutionContext {
pub input_sources: Vec<Arc<GlobalRef>>,
/// The record batch stream to pull results from
pub stream: Option<SendableRecordBatchStream>,
/// The Tokio runtime used for async.
pub runtime: Runtime,
/// Native metrics
pub metrics: Arc<GlobalRef>,
// The interval in milliseconds to update metrics
Expand Down Expand Up @@ -177,8 +198,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
task_attempt_id: jlong,
debug_native: jboolean,
explain_native: jboolean,
worker_threads: jint,
blocking_threads: jint,
) -> jlong {
try_unwrap_or_throw(&e, |mut env| {
// Init JVM classes
Expand All @@ -192,13 +211,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
// Deserialize query plan
let spark_plan = serde::deserialize_op(bytes.as_slice())?;

// Use multi-threaded tokio runtime to prevent blocking spawned tasks if any
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(worker_threads as usize)
.max_blocking_threads(blocking_threads as usize)
.enable_all()
.build()?;

let metrics = Arc::new(jni_new_global_ref!(env, metrics_node)?);

// Get the global references of input sources
Expand Down Expand Up @@ -258,7 +270,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
scans: vec![],
input_sources,
stream: None,
runtime,
metrics,
metrics_update_interval,
metrics_last_update_time: Instant::now(),
Expand Down Expand Up @@ -559,7 +570,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
loop {
// Polling the stream.
let next_item = exec_context.stream.as_mut().unwrap().next();
let poll_output = exec_context.runtime.block_on(async { poll!(next_item) });
let poll_output = get_runtime().block_on(async { poll!(next_item) });

// update metrics at interval
if let Some(interval) = exec_context.metrics_update_interval {
Expand Down
13 changes: 2 additions & 11 deletions native/core/src/parquet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ use jni::{
};

use self::util::jni::TypePromotionInfo;
use crate::execution::jni_api::get_runtime;
use crate::execution::operators::ExecutionError;
use crate::execution::planner::PhysicalPlanner;
use crate::execution::serde;
Expand Down Expand Up @@ -606,7 +607,6 @@ enum ParquetReaderState {
}
/// Parquet read context maintained across multiple JNI calls.
struct BatchContext {
runtime: tokio::runtime::Runtime,
batch_stream: Option<SendableRecordBatchStream>,
current_batch: Option<RecordBatch>,
reader_state: ParquetReaderState,
Expand Down Expand Up @@ -652,8 +652,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat
data_schema: jbyteArray,
session_timezone: jstring,
batch_size: jint,
worker_threads: jint,
blocking_threads: jint,
) -> jlong {
try_unwrap_or_throw(&e, |mut env| unsafe {
let session_config = SessionConfig::new().with_batch_size(batch_size as usize);
Expand All @@ -666,12 +664,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat
.unwrap()
.into();

let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(worker_threads as usize)
.max_blocking_threads(blocking_threads as usize)
.enable_all()
.build()?;

let (object_store_url, object_store_path) =
prepare_object_store(session_ctx.runtime_env(), path.clone())?;

Expand Down Expand Up @@ -718,7 +710,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat
let batch_stream = Some(scan.execute(partition_index, session_ctx.task_ctx())?);

let ctx = BatchContext {
runtime,
batch_stream,
current_batch: None,
reader_state: ParquetReaderState::Init,
Expand All @@ -738,7 +729,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_readNextRecordBatch(
let context = get_batch_context(handle)?;
let mut rows_read: i32 = 0;
let batch_stream = context.batch_stream.as_mut().unwrap();
let runtime = &context.runtime;
let runtime = get_runtime();

loop {
let next_item = batch_stream.next();
Expand Down
6 changes: 2 additions & 4 deletions spark/src/main/scala/org/apache/comet/CometExecIterator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.comet.CometMetricNode
import org.apache.spark.sql.vectorized._

import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_BLOCKING_THREADS, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_POOL_TYPE, COMET_EXPLAIN_NATIVE_ENABLED, COMET_METRICS_UPDATE_INTERVAL, COMET_WORKER_THREADS}
import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_POOL_TYPE, COMET_EXPLAIN_NATIVE_ENABLED, COMET_METRICS_UPDATE_INTERVAL}
import org.apache.comet.vector.NativeUtil

/**
Expand Down Expand Up @@ -92,9 +92,7 @@ class CometExecIterator(
memoryLimitPerTask = getMemoryLimitPerTask(conf),
taskAttemptId = TaskContext.get().taskAttemptId,
debug = COMET_DEBUG_ENABLED.get(),
explain = COMET_EXPLAIN_NATIVE_ENABLED.get(),
workerThreads = COMET_WORKER_THREADS.get(),
blockingThreads = COMET_BLOCKING_THREADS.get())
explain = COMET_EXPLAIN_NATIVE_ENABLED.get())
}

private var nextBatch: Option[ColumnarBatch] = None
Expand Down
4 changes: 1 addition & 3 deletions spark/src/main/scala/org/apache/comet/Native.scala
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,7 @@ class Native extends NativeBase {
memoryLimitPerTask: Long,
taskAttemptId: Long,
debug: Boolean,
explain: Boolean,
workerThreads: Int,
blockingThreads: Int): Long
explain: Boolean): Long
// scalastyle:on

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,6 @@ object CometParquetFileFormat extends Logging with ShimSQLConf {
CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.key,
CometConf.COMET_EXCEPTION_ON_LEGACY_DATE_TIMESTAMP.get())
hadoopConf.setInt(CometConf.COMET_BATCH_SIZE.key, CometConf.COMET_BATCH_SIZE.get())
hadoopConf.setInt(CometConf.COMET_WORKER_THREADS.key, CometConf.COMET_WORKER_THREADS.get())
hadoopConf.setInt(
CometConf.COMET_BLOCKING_THREADS.key,
CometConf.COMET_BLOCKING_THREADS.get())
}

def getDatetimeRebaseSpec(
Expand Down
Loading