Skip to content

feat: Add tracing regression tests #15673

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions datafusion/core/tests/core_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ mod serde;
/// Run all tests that are found in the `catalog` directory
mod catalog;

/// Run all tests that are found in the `tracing` directory
mod tracing;

#[cfg(test)]
#[ctor::ctor]
fn init() {
Expand Down
142 changes: 142 additions & 0 deletions datafusion/core/tests/tracing/asserting_tracer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::any::Any;
use std::collections::VecDeque;
use std::ops::Deref;
use std::sync::{Arc, LazyLock};

use datafusion_common::{HashMap, HashSet};
use datafusion_common_runtime::{set_join_set_tracer, JoinSetTracer};
use futures::future::BoxFuture;
use tokio::sync::{Mutex, MutexGuard};

/// Initializes the global join set tracer with the asserting tracer.
/// Call this function before spawning any tasks that should be traced.
pub fn init_asserting_tracer() {
set_join_set_tracer(ASSERTING_TRACER.deref())
.expect("Failed to initialize asserting tracer");
}

/// Verifies that the current task has a traceable ancestry back to "root".
///
/// The function performs a breadth-first search (BFS) in the global spawn graph:
/// - It starts at the current task and follows parent links.
/// - If it reaches the "root" task, the ancestry is valid.
/// - If a task is missing from the graph, it panics.
///
/// Note: Tokio task IDs are unique only while a task is active.
/// Once a task completes, its ID may be reused.
pub async fn assert_traceability() {
// Acquire the spawn graph lock.
let spawn_graph = acquire_spawn_graph().await;

// Start BFS with the current task.
let mut tasks_to_check = VecDeque::from(vec![current_task()]);

while let Some(task_id) = tasks_to_check.pop_front() {
if task_id == "root" {
// Ancestry reached the root.
continue;
}
// Obtain parent tasks, panicking if the task is not present.
let parents = spawn_graph
.get(&task_id)
.expect("Task ID not found in spawn graph");
// Queue each parent for checking.
for parent in parents {
tasks_to_check.push_back(parent.clone());
}
}
}

/// Tracer that maintains a graph of task ancestry for tracing purposes.
///
/// For each task, it records a set of parent task IDs to ensure that every
/// asynchronous task can be traced back to "root".
struct AssertingTracer {
/// An asynchronous map from task IDs to their parent task IDs.
spawn_graph: Arc<Mutex<HashMap<String, HashSet<String>>>>,
}

/// Lazily initialized global instance of `AssertingTracer`.
static ASSERTING_TRACER: LazyLock<AssertingTracer> = LazyLock::new(AssertingTracer::new);

impl AssertingTracer {
/// Creates a new `AssertingTracer` with an empty spawn graph.
fn new() -> Self {
Self {
spawn_graph: Arc::default(),
}
}
}

/// Returns the current task's ID as a string, or "root" if unavailable.
///
/// Tokio guarantees task IDs are unique only among active tasks,
/// so completed tasks may have their IDs reused.
fn current_task() -> String {
tokio::task::try_id()
.map(|id| format!("{id}"))
.unwrap_or_else(|| "root".to_string())
}

/// Asynchronously locks and returns the spawn graph.
///
/// The returned guard allows inspection or modification of task ancestry.
async fn acquire_spawn_graph<'a>() -> MutexGuard<'a, HashMap<String, HashSet<String>>> {
ASSERTING_TRACER.spawn_graph.lock().await
}

/// Registers the current task as a child of `parent_id` in the spawn graph.
async fn register_task(parent_id: String) {
acquire_spawn_graph()
.await
.entry(current_task())
.or_insert_with(HashSet::new)
.insert(parent_id);
}

impl JoinSetTracer for AssertingTracer {
/// Wraps an asynchronous future to record its parent task before execution.
fn trace_future(
&self,
fut: BoxFuture<'static, Box<dyn Any + Send>>,
) -> BoxFuture<'static, Box<dyn Any + Send>> {
// Capture the parent task ID.
let parent_id = current_task();
Box::pin(async move {
// Register the parent-child relationship.
register_task(parent_id).await;
// Execute the wrapped future.
fut.await
})
}

/// Wraps a blocking closure to record its parent task before execution.
fn trace_block(
&self,
f: Box<dyn FnOnce() -> Box<dyn Any + Send> + Send>,
) -> Box<dyn FnOnce() -> Box<dyn Any + Send> + Send> {
let parent_id = current_task();
Box::new(move || {
// Synchronously record the task relationship.
futures::executor::block_on(register_task(parent_id));
f()
})
}
}
108 changes: 108 additions & 0 deletions datafusion/core/tests/tracing/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! # JoinSetTracer Integration Tests
//!
//! These are smoke tests that verify `JoinSetTracer` can be correctly injected into DataFusion.
//!
//! They run a SQL query that reads Parquet data and performs an aggregation,
//! which causes DataFusion to spawn multiple tasks.
//! The object store is wrapped to assert that every task can be traced back to the root.
//!
//! These tests don't cover all edge cases, but they should fail if changes to
//! DataFusion's task spawning break tracing.

mod asserting_tracer;
mod traceable_object_store;

use asserting_tracer::init_asserting_tracer;
use datafusion::datasource::file_format::parquet::ParquetFormat;
use datafusion::datasource::listing::ListingOptions;
use datafusion::prelude::*;
use datafusion::test_util::parquet_test_data;
use datafusion_common::assert_contains;
use datafusion_common_runtime::SpawnedTask;
use log::info;
use object_store::local::LocalFileSystem;
use std::sync::Arc;
use traceable_object_store::traceable_object_store;
use url::Url;

/// Combined test that first verifies the query panics when no tracer is registered,
/// then initializes the tracer and confirms the query runs successfully.
///
/// Using a single test function prevents global tracer leakage between tests.
#[tokio::test(flavor = "multi_thread", worker_threads = 8)]
async fn test_tracer_injection() {
// Without initializing the tracer, run the query.
// Spawn the query in a separate task so we can catch its panic.
info!("Running query without tracer");
// The absence of the tracer should cause the task to panic inside the `TraceableObjectStore`.
let untraced_result = SpawnedTask::spawn(run_query()).join().await;
if let Err(e) = untraced_result {
// Check if the error message contains the expected error.
assert!(e.is_panic(), "Expected a panic, but got: {:?}", e);
assert_contains!(e.to_string(), "Task ID not found in spawn graph");
info!("Caught expected panic: {}", e);
} else {
panic!("Expected the task to panic, but it completed successfully");
};

// Initialize the asserting tracer and run the query.
info!("Initializing tracer and re-running query");
init_asserting_tracer();
SpawnedTask::spawn(run_query()).join().await.unwrap(); // Should complete without panics or errors.
}

/// Executes a sample task-spawning SQL query using a traceable object store.
async fn run_query() {
info!("Starting query execution");

// Create a new session context
let ctx = SessionContext::new();

// Get the test data directory
let test_data = parquet_test_data();

// Define a Parquet file format with pruning enabled
let file_format = ParquetFormat::default().with_enable_pruning(true);

// Set listing options for the parquet file with a specific extension
let listing_options = ListingOptions::new(Arc::new(file_format))
.with_file_extension("alltypes_tiny_pages_plain.parquet");

// Wrap the local file system in a traceable object store to verify task traceability.
let local_fs = Arc::new(LocalFileSystem::new());
let traceable_store = traceable_object_store(local_fs);

// Register the traceable object store with a test URL.
let url = Url::parse("test://").unwrap();
ctx.register_object_store(&url, traceable_store.clone());

// Register a listing table from the test data directory.
let table_path = format!("test://{}/", test_data);
ctx.register_listing_table("alltypes", &table_path, listing_options, None, None)
.await
.expect("Failed to register table");

// Define and execute an SQL query against the registered table, which should
// spawn multiple tasks due to the aggregation and parquet file read.
let sql = "SELECT COUNT(*), string_col FROM alltypes GROUP BY string_col";
let result_batches = ctx.sql(sql).await.unwrap().collect().await.unwrap();

info!("Query complete: {} batches returned", result_batches.len());
}
125 changes: 125 additions & 0 deletions datafusion/core/tests/tracing/traceable_object_store.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Object store implementation used for testing

use crate::tracing::asserting_tracer::assert_traceability;
use futures::stream::BoxStream;
use object_store::{
path::Path, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta,
ObjectStore, PutMultipartOpts, PutOptions, PutPayload, PutResult,
};
use std::fmt::{Debug, Display, Formatter};
use std::sync::Arc;

/// Returns an `ObjectStore` that asserts it can trace its calls back to the root tokio task.
pub fn traceable_object_store(
object_store: Arc<dyn ObjectStore>,
) -> Arc<dyn ObjectStore> {
Arc::new(TraceableObjectStore::new(object_store))
}

/// An object store that asserts it can trace all its calls back to the root tokio task.
#[derive(Debug)]
struct TraceableObjectStore {
inner: Arc<dyn ObjectStore>,
}

impl TraceableObjectStore {
fn new(inner: Arc<dyn ObjectStore>) -> Self {
Self { inner }
}
}

impl Display for TraceableObjectStore {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Display::fmt(&self.inner, f)
}
}

/// All trait methods are forwarded to the inner object store,
/// after asserting they can trace their calls back to the root tokio task.
#[async_trait::async_trait]
impl ObjectStore for TraceableObjectStore {
async fn put_opts(
&self,
location: &Path,
payload: PutPayload,
opts: PutOptions,
) -> object_store::Result<PutResult> {
assert_traceability().await;
self.inner.put_opts(location, payload, opts).await
}

async fn put_multipart_opts(
&self,
location: &Path,
opts: PutMultipartOpts,
) -> object_store::Result<Box<dyn MultipartUpload>> {
assert_traceability().await;
self.inner.put_multipart_opts(location, opts).await
}

async fn get_opts(
&self,
location: &Path,
options: GetOptions,
) -> object_store::Result<GetResult> {
assert_traceability().await;
self.inner.get_opts(location, options).await
}

async fn head(&self, location: &Path) -> object_store::Result<ObjectMeta> {
assert_traceability().await;
self.inner.head(location).await
}

async fn delete(&self, location: &Path) -> object_store::Result<()> {
assert_traceability().await;
self.inner.delete(location).await
}

fn list(
&self,
prefix: Option<&Path>,
) -> BoxStream<'_, object_store::Result<ObjectMeta>> {
futures::executor::block_on(assert_traceability());
self.inner.list(prefix)
}

async fn list_with_delimiter(
&self,
prefix: Option<&Path>,
) -> object_store::Result<ListResult> {
assert_traceability().await;
self.inner.list_with_delimiter(prefix).await
}

async fn copy(&self, from: &Path, to: &Path) -> object_store::Result<()> {
assert_traceability().await;
self.inner.copy(from, to).await
}

async fn copy_if_not_exists(
&self,
from: &Path,
to: &Path,
) -> object_store::Result<()> {
assert_traceability().await;
self.inner.copy_if_not_exists(from, to).await
}
}
Loading