Skip to content

feat: instrument spawned tasks with current tracing span when tracing feature is enabled #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ Optional features:
- `backtrace`: include backtrace information in error messages
- `pyarrow`: conversions between PyArrow and DataFusion types
- `serde`: enable arrow-schema's `serde` feature
- `tracing`: propagates the current span across thread boundaries

[apache avro]: https://avro.apache.org/
[apache parquet]: https://parquet.apache.org/
Expand Down
4 changes: 3 additions & 1 deletion datafusion-examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async-trait = { workspace = true }
bytes = { workspace = true }
dashmap = { workspace = true }
# note only use main datafusion crate for examples
datafusion = { workspace = true, default-features = true, features = ["avro"] }
datafusion = { workspace = true, default-features = true, features = ["avro", "tracing"] }
datafusion-proto = { workspace = true }
env_logger = { workspace = true }
futures = { workspace = true }
Expand All @@ -73,6 +73,8 @@ tempfile = { workspace = true }
test-utils = { path = "../test-utils" }
tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] }
tonic = "0.12.1"
tracing = { version = "0.1" }
tracing-subscriber = { version = "0.3" }
url = { workspace = true }
uuid = "1.7"

Expand Down
127 changes: 127 additions & 0 deletions datafusion-examples/examples/tracing.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! This example demonstrates the trace feature in DataFusion’s runtime.
//! When the `trace` feature is enabled, spawned tasks in DataFusion (such as those
//! created during repartitioning or when reading Parquet files) are instrumented
//! with the current tracing span, allowing to propagate any existing tracing context.
//!
//! In this example we create a session configured to use multiple partitions,
//! register a Parquet table (based on the `alltypes_tiny_pages_plain.parquet` file),
//! and run a query that should trigger parallel execution on multiple threads.
//! We wrap the entire query execution within a custom span and log messages.
//! By inspecting the tracing output, we should see that the tasks spawned
//! internally inherit the span context.

use arrow::util::pretty::pretty_format_batches;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::file_format::parquet::ParquetFormat;
use datafusion::datasource::listing::ListingOptions;
use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion::test_util::parquet_test_data;
use std::sync::Arc;
use tracing::{info, Level, instrument};

#[tokio::main]
async fn main() -> Result<()> {
// Initialize a tracing subscriber that prints to stdout.
tracing_subscriber::fmt()
.with_thread_ids(true)
.with_thread_names(true)
.with_max_level(Level::DEBUG)
.init();

log::info!("Starting example, this log is not captured by tracing");

// execute the query within a tracing span
let result = run_instrumented_query().await;

info!(
"Finished example. Check the logs above for tracing span details showing \
that tasks were spawned within the 'run_instrumented_query' span on different threads."
);

result
}

#[instrument(level = "info")]
async fn run_instrumented_query() -> Result<()> {
info!("Starting query execution within the custom tracing span");

// The default session will set the number of partitions to `std::thread::available_parallelism()`.
let ctx = SessionContext::new();

// Get the path to the test parquet data.
let test_data = parquet_test_data();
// Build listing options that pick up only the "alltypes_tiny_pages_plain.parquet" file.
let file_format = ParquetFormat::default().with_enable_pruning(true);
let listing_options = ListingOptions::new(Arc::new(file_format))
.with_file_extension("alltypes_tiny_pages_plain.parquet");

info!("Registering Parquet table 'alltypes' from {test_data} in {listing_options:?}");

// Register a listing table using an absolute URL.
let table_path = format!("file://{test_data}/");
ctx.register_listing_table(
"alltypes",
&table_path,
listing_options.clone(),
None,
None,
)
.await
.expect("register_listing_table failed");

info!("Registered Parquet table 'alltypes' from {table_path}");

// Run a query that will trigger parallel execution on multiple threads.
let sql = "SELECT COUNT(*), bool_col, date_string_col, string_col
FROM (
SELECT bool_col, date_string_col, string_col FROM alltypes
UNION ALL
SELECT bool_col, date_string_col, string_col FROM alltypes
) AS t
GROUP BY bool_col, date_string_col, string_col
ORDER BY 1,2,3,4 DESC
LIMIT 5;";
info!(%sql, "Executing SQL query");
let df = ctx.sql(sql).await?;

let results: Vec<RecordBatch> = df.collect().await?;
info!("Query execution complete");

// Print out the results and tracing output.
datafusion::common::assert_batches_eq!(
[
"+----------+----------+-----------------+------------+",
"| count(*) | bool_col | date_string_col | string_col |",
"+----------+----------+-----------------+------------+",
"| 2 | false | 01/01/09 | 9 |",
"| 2 | false | 01/01/09 | 7 |",
"| 2 | false | 01/01/09 | 5 |",
"| 2 | false | 01/01/09 | 3 |",
"| 2 | false | 01/01/09 | 1 |",
"+----------+----------+-----------------+------------+",
],
&results
);

info!("Query results:\n{}", pretty_format_batches(&results)?);

Ok(())
}
5 changes: 5 additions & 0 deletions datafusion/common-runtime/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,18 @@ rust-version = { workspace = true }
[lints]
workspace = true

[features]
tracing = ["dep:tracing", "dep:tracing-futures"]

[lib]
name = "datafusion_common_runtime"
path = "src/lib.rs"

[dependencies]
log = { workspace = true }
tokio = { workspace = true }
tracing = { version = "0.1", optional = true }
tracing-futures = { version = "0.2", optional = true }

[dev-dependencies]
tokio = { version = "1.36", features = ["rt", "rt-multi-thread", "time"] }
4 changes: 3 additions & 1 deletion datafusion/common-runtime/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

use std::future::Future;

use tokio::task::{JoinError, JoinSet};
use tokio::task::{JoinError};
use crate::JoinSet;

/// Helper that provides a simple API to spawn a single task and join it.
/// Provides guarantees of aborting on `Drop` to keep it cancel-safe.
Expand All @@ -36,6 +37,7 @@ impl<R: 'static> SpawnedTask<R> {
R: Send,
{
let mut inner = JoinSet::new();

inner.spawn(task);
Self { inner }
}
Expand Down
207 changes: 207 additions & 0 deletions datafusion/common-runtime/src/join_set.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::future::Future;
use std::task::{Context, Poll};
use tokio::runtime::Handle;
use tokio::task::JoinSet as TokioJoinSet;
use tokio::task::{AbortHandle, Id, JoinError, LocalSet};

#[cfg(feature = "tracing")]
mod trace_utils {
use std::future::Future;
use tracing_futures::Instrument;

/// Instruments the provided future with the current tracing span.
pub fn trace_future<F, T>(future: F) -> impl Future<Output = T> + Send
where
F: Future<Output = T> + Send + 'static,
T: Send,
{
future.in_current_span()
}

/// Wraps the provided blocking function to execute within the current tracing span.
pub fn trace_block<F, T>(f: F) -> impl FnOnce() -> T + Send + 'static
where
F: FnOnce() -> T + Send + 'static,
T: Send,
{
let span = tracing::Span::current();
move || span.in_scope(f)
}
}

/// A wrapper around Tokio's [`JoinSet`](tokio::task::JoinSet) that transparently forwards all public API calls
/// while optionally instrumenting spawned tasks and blocking closures with the current tracing span
/// when the `trace` feature is enabled.
#[derive(Debug)]
pub struct JoinSet<T> {
inner: TokioJoinSet<T>,
}

impl<T> Default for JoinSet<T> {
fn default() -> Self {
Self::new()
}
}

impl<T> JoinSet<T> {
/// [JoinSet::new](tokio::task::JoinSet::new) - Create a new JoinSet.
pub fn new() -> Self {
Self {
inner: TokioJoinSet::new(),
}
}

/// [JoinSet::len](tokio::task::JoinSet::len) - Return the number of tasks.
pub fn len(&self) -> usize {
self.inner.len()
}

/// [JoinSet::is_empty](tokio::task::JoinSet::is_empty) - Check if the JoinSet is empty.
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
}
impl<T: 'static> JoinSet<T> {
/// [JoinSet::spawn](tokio::task::JoinSet::spawn) - Spawn a new task.
pub fn spawn<F>(&mut self, task: F) -> AbortHandle
where
F: Future<Output = T>,
F: Send + 'static,
T: Send,
{
#[cfg(feature = "tracing")]
let task = trace_utils::trace_future(task);

self.inner.spawn(task)
}

/// [JoinSet::spawn_on](tokio::task::JoinSet::spawn_on) - Spawn a task on a provided runtime.
pub fn spawn_on<F>(&mut self, task: F, handle: &Handle) -> AbortHandle
where
F: Future<Output = T>,
F: Send + 'static,
T: Send,
{
#[cfg(feature = "tracing")]
let task = trace_utils::trace_future(task);

self.inner.spawn_on(task, handle)
}

/// [JoinSet::spawn_local](tokio::task::JoinSet::spawn_local) - Spawn a local task.
pub fn spawn_local<F>(&mut self, task: F) -> AbortHandle
where
F: Future<Output = T>,
F: 'static,
{
self.inner.spawn_local(task)
}

/// [JoinSet::spawn_local_on](tokio::task::JoinSet::spawn_local_on) - Spawn a local task on a provided LocalSet.
pub fn spawn_local_on<F>(&mut self, task: F, local_set: &LocalSet) -> AbortHandle
where
F: Future<Output = T>,
F: 'static,
{
self.inner.spawn_local_on(task, local_set)
}

/// [JoinSet::spawn_blocking](tokio::task::JoinSet::spawn_blocking) - Spawn a blocking task.
pub fn spawn_blocking<F>(&mut self, f: F) -> AbortHandle
where
F: FnOnce() -> T,
F: Send + 'static,
T: Send,
{
#[cfg(feature = "tracing")]
let f = trace_utils::trace_block(f);

self.inner.spawn_blocking(f)
}

/// [JoinSet::spawn_blocking_on](tokio::task::JoinSet::spawn_blocking_on) - Spawn a blocking task on a provided runtime.
pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle) -> AbortHandle
where
F: FnOnce() -> T,
F: Send + 'static,
T: Send,
{
#[cfg(feature = "tracing")]
let f = trace_utils::trace_block(f);

self.inner.spawn_blocking_on(f, handle)
}

/// [JoinSet::join_next](tokio::task::JoinSet::join_next) - Await the next completed task.
pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
self.inner.join_next().await
}

/// [JoinSet::try_join_next](tokio::task::JoinSet::try_join_next) - Try to join the next completed task.
pub fn try_join_next(&mut self) -> Option<Result<T, JoinError>> {
self.inner.try_join_next()
}

/// [JoinSet::abort_all](tokio::task::JoinSet::abort_all) - Abort all tasks.
pub fn abort_all(&mut self) {
self.inner.abort_all()
}

/// [JoinSet::detach_all](tokio::task::JoinSet::detach_all) - Detach all tasks.
pub fn detach_all(&mut self) {
self.inner.detach_all()
}

/// [JoinSet::poll_join_next](tokio::task::JoinSet::poll_join_next) - Poll for the next completed task.
pub fn poll_join_next(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<T, JoinError>>> {
self.inner.poll_join_next(cx)
}

/// [JoinSet::join_next_with_id](tokio::task::JoinSet::join_next_with_id) - Await the next completed task with its ID.
pub async fn join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
self.inner.join_next_with_id().await
}

/// [JoinSet::try_join_next_with_id](tokio::task::JoinSet::try_join_next_with_id) - Try to join the next completed task with its ID.
pub fn try_join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
self.inner.try_join_next_with_id()
}

/// [JoinSet::poll_join_next_with_id](tokio::task::JoinSet::poll_join_next_with_id) - Poll for the next completed task with its ID.
pub fn poll_join_next_with_id(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(Id, T), JoinError>>> {
self.inner.poll_join_next_with_id(cx)
}

/// [JoinSet::shutdown](tokio::task::JoinSet::shutdown) - Abort all tasks and wait for shutdown.
pub async fn shutdown(&mut self) {
self.inner.shutdown().await
}

/// [JoinSet::join_all](tokio::task::JoinSet::join_all) - Await all tasks.
pub async fn join_all(self) -> Vec<T> {
self.inner.join_all().await
}
}
Loading