Make ExecutionPlan::execute Sync #2434

Merged · 1 commit · May 4, 2022
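This change turns `ExecutionPlan::execute` from an `async fn` into a plain synchronous method. `execute` only constructs a `SendableRecordBatchStream`; the actual work still runs asynchronously as the stream is polled, so nothing is lost by dropping the `async`. Implementations no longer need the `#[async_trait]` macro, and every call site drops one `.await`. A minimal sketch of the signature change, reconstructed from the hunks below (the rest of the trait is elided):

```rust
// Before this PR: an async method, which forced #[async_trait] on the
// trait and on every implementation.
//
// async fn execute(
//     &self,
//     partition: usize,
//     context: Arc<TaskContext>,
// ) -> Result<SendableRecordBatchStream>;

// After this PR: a plain sync method. The returned stream is still
// consumed asynchronously by the caller.
fn execute(
    &self,
    partition: usize,
    context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream>;
```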
4 changes: 1 addition & 3 deletions ballista/rust/core/src/execution_plans/distributed_query.rs
@@ -41,7 +41,6 @@ use datafusion::physical_plan::{
 
 use crate::serde::protobuf::execute_query_params::OptionalSessionId;
 use crate::serde::{AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec};
-use async_trait::async_trait;
 use datafusion::arrow::error::{ArrowError, Result as ArrowResult};
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::execution::context::TaskContext;
@@ -122,7 +121,6 @@ impl<T: 'static + AsLogicalPlan> DistributedQueryExec<T> {
     }
 }
 
-#[async_trait]
 impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
     fn as_any(&self) -> &dyn Any {
         self
@@ -162,7 +160,7 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
         }))
     }
 
-    async fn execute(
+    fn execute(
         &self,
         partition: usize,
         _context: Arc<TaskContext>,
4 changes: 1 addition & 3 deletions ballista/rust/core/src/execution_plans/shuffle_reader.rs
@@ -21,7 +21,6 @@ use std::sync::Arc;
 use crate::client::BallistaClient;
 use crate::serde::scheduler::{PartitionLocation, PartitionStats};
 
-use async_trait::async_trait;
 use datafusion::arrow::datatypes::SchemaRef;
 
 use datafusion::error::{DataFusionError, Result};
@@ -64,7 +63,6 @@ impl ShuffleReaderExec {
     }
 }
 
-#[async_trait]
 impl ExecutionPlan for ShuffleReaderExec {
     fn as_any(&self) -> &dyn Any {
         self
@@ -101,7 +99,7 @@ impl ExecutionPlan for ShuffleReaderExec {
         ))
     }
 
-    async fn execute(
+    fn execute(
         &self,
         partition: usize,
         _context: Arc<TaskContext>,
10 changes: 4 additions & 6 deletions ballista/rust/core/src/execution_plans/shuffle_writer.rs
@@ -33,7 +33,6 @@ use crate::utils;
 
 use crate::serde::protobuf::ShuffleWritePartition;
 use crate::serde::scheduler::PartitionStats;
-use async_trait::async_trait;
 use datafusion::arrow::array::{
     ArrayBuilder, ArrayRef, StringBuilder, StructBuilder, UInt32Builder, UInt64Builder,
 };
@@ -155,7 +154,7 @@ impl ShuffleWriterExec {
 
         async move {
             let now = Instant::now();
-            let mut stream = plan.execute(input_partition, context).await?;
+            let mut stream = plan.execute(input_partition, context)?;
 
             match output_partitioning {
                 None => {
@@ -293,7 +292,6 @@ impl ShuffleWriterExec {
     }
 }
 
-#[async_trait]
 impl ExecutionPlan for ShuffleWriterExec {
     fn as_any(&self) -> &dyn Any {
         self
@@ -336,7 +334,7 @@ impl ExecutionPlan for ShuffleWriterExec {
         )?))
     }
 
-    async fn execute(
+    fn execute(
         &self,
         partition: usize,
         context: Arc<TaskContext>,
@@ -459,7 +457,7 @@ mod tests {
             work_dir.into_path().to_str().unwrap().to_owned(),
             Some(Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 2)),
         )?;
-        let mut stream = query_stage.execute(0, task_ctx).await?;
+        let mut stream = query_stage.execute(0, task_ctx)?;
         let batches = utils::collect_stream(&mut stream)
             .await
             .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
@@ -516,7 +514,7 @@ mod tests {
             work_dir.into_path().to_str().unwrap().to_owned(),
             Some(Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 2)),
         )?;
-        let mut stream = query_stage.execute(0, task_ctx).await?;
+        let mut stream = query_stage.execute(0, task_ctx)?;
         let batches = utils::collect_stream(&mut stream)
             .await
             .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
4 changes: 1 addition & 3 deletions ballista/rust/core/src/execution_plans/unresolved_shuffle.rs
@@ -18,7 +18,6 @@
 use std::any::Any;
 use std::sync::Arc;
 
-use async_trait::async_trait;
 use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::error::{DataFusionError, Result};
 use datafusion::execution::context::TaskContext;
@@ -63,7 +62,6 @@ impl UnresolvedShuffleExec {
     }
 }
 
-#[async_trait]
 impl ExecutionPlan for UnresolvedShuffleExec {
     fn as_any(&self) -> &dyn Any {
         self
@@ -101,7 +99,7 @@ impl ExecutionPlan for UnresolvedShuffleExec {
         ))
     }
 
-    async fn execute(
+    fn execute(
         &self,
         _partition: usize,
         _context: Arc<TaskContext>,
3 changes: 1 addition & 2 deletions ballista/rust/core/src/serde/mod.rs
@@ -477,7 +477,6 @@ mod tests {
         }
     }
 
-    #[async_trait]
     impl ExecutionPlan for TopKExec {
         /// Return a reference to Any that can be used for downcasting
         fn as_any(&self) -> &dyn Any {
@@ -515,7 +514,7 @@ mod tests {
         }
 
         /// Execute one partition and return an iterator over RecordBatch
-        async fn execute(
+        fn execute(
             &self,
             _partition: usize,
             _context: Arc<TaskContext>,
10 changes: 3 additions & 7 deletions ballista/rust/executor/src/collect.rs
@@ -22,7 +22,6 @@ use std::sync::Arc;
 use std::task::{Context, Poll};
 use std::{any::Any, pin::Pin};
 
-use async_trait::async_trait;
 use datafusion::arrow::{
     datatypes::SchemaRef, error::Result as ArrowResult, record_batch::RecordBatch,
 };
@@ -49,7 +48,6 @@ impl CollectExec {
     }
 }
 
-#[async_trait]
 impl ExecutionPlan for CollectExec {
     fn as_any(&self) -> &dyn Any {
         self
@@ -78,18 +76,16 @@ impl ExecutionPlan for CollectExec {
         unimplemented!()
     }
 
-    async fn execute(
+    fn execute(
         &self,
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
         assert_eq!(0, partition);
         let num_partitions = self.plan.output_partitioning().partition_count();
 
-        let futures = (0..num_partitions).map(|i| self.plan.execute(i, context.clone()));
-        let streams = futures::future::join_all(futures)
-            .await
-            .into_iter()
+        let streams = (0..num_partitions)
+            .map(|i| self.plan.execute(i, context.clone()))
             .collect::<Result<Vec<_>>>()
             .map_err(|e| DataFusionError::Execution(format!("BallistaError: {:?}", e)))?;
 
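The `CollectExec` hunk above shows the main caller-side simplification: with a synchronous `execute`, opening one stream per partition is an ordinary fallible iterator `collect`, and `futures::future::join_all` disappears. A self-contained sketch of that pattern with toy types (not the Ballista code; `open_partition` stands in for `plan.execute`):

```rust
fn open_partition(i: usize) -> Result<String, String> {
    // Stand-in for `self.plan.execute(i, context.clone())`: fallible and sync.
    if i < 3 {
        Ok(format!("stream-{}", i))
    } else {
        Err(format!("partition {} unavailable", i))
    }
}

fn main() {
    // Succeeds when every partition opens...
    let ok: Result<Vec<String>, String> = (0..3).map(open_partition).collect();
    assert_eq!(ok.unwrap().len(), 3);

    // ...and short-circuits on the first Err, which is what the
    // `.collect::<Result<Vec<_>>>()?` in CollectExec relies on.
    let err: Result<Vec<String>, String> = (0..5).map(open_partition).collect();
    assert!(err.is_err());
}
```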
7 changes: 3 additions & 4 deletions datafusion-examples/examples/custom_datasource.rs
@@ -196,7 +196,6 @@ impl CustomExec {
     }
 }
 
-#[async_trait]
 impl ExecutionPlan for CustomExec {
     fn as_any(&self) -> &dyn Any {
         self
@@ -225,7 +224,7 @@ impl ExecutionPlan for CustomExec {
         Ok(self)
     }
 
-    async fn execute(
+    fn execute(
         &self,
         _partition: usize,
         _context: Arc<TaskContext>,
@@ -243,7 +242,7 @@ impl ExecutionPlan for CustomExec {
             account_array.append_value(user.bank_account)?;
         }
 
-        return Ok(Box::pin(MemoryStream::try_new(
+        Ok(Box::pin(MemoryStream::try_new(
             vec![RecordBatch::try_new(
                 self.projected_schema.clone(),
                 vec![
@@ -253,7 +252,7 @@ impl ExecutionPlan for CustomExec {
             )?],
             self.schema(),
             None,
-        )?));
+        )?))
     }
 
     fn statistics(&self) -> Statistics {
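Besides dropping `async`, this hunk also replaces a trailing `return Ok(...);` with a tail expression. That is a pure style cleanup: a Rust function returns its final expression, so the `return` keyword and the semicolon are redundant there. A tiny standalone illustration:

```rust
fn answer_explicit() -> Result<i32, String> {
    return Ok(42); // compiles, but `return` is redundant in tail position
}

fn answer_idiomatic() -> Result<i32, String> {
    Ok(42) // tail expression: identical behavior, idiomatic style
}

fn main() {
    assert_eq!(answer_explicit(), answer_idiomatic());
}
```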
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/file_format/csv.rs
@@ -161,7 +161,7 @@ mod tests {
         let projection = Some(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12]);
         let exec = get_exec("aggregate_test_100.csv", &projection, None).await?;
         let task_ctx = ctx.task_ctx();
-        let stream = exec.execute(0, task_ctx).await?;
+        let stream = exec.execute(0, task_ctx)?;
 
         let tt_batches: i32 = stream
             .map(|batch| {
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/file_format/json.rs
@@ -131,7 +131,7 @@ mod tests {
         let projection = None;
         let exec = get_exec(&projection, None).await?;
         let task_ctx = ctx.task_ctx();
-        let stream = exec.execute(0, task_ctx).await?;
+        let stream = exec.execute(0, task_ctx)?;
 
         let tt_batches: i32 = stream
             .map(|batch| {
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/file_format/parquet.rs
@@ -512,7 +512,7 @@ mod tests {
         let projection = None;
         let exec = get_exec("alltypes_plain.parquet", &projection, None).await?;
         let task_ctx = ctx.task_ctx();
-        let stream = exec.execute(0, task_ctx).await?;
+        let stream = exec.execute(0, task_ctx)?;
 
         let tt_batches = stream
             .map(|batch| {
10 changes: 5 additions & 5 deletions datafusion/core/src/datasource/memory.rs
@@ -76,7 +76,7 @@ impl MemTable {
                 let context1 = context.clone();
                 let exec = exec.clone();
                 tokio::spawn(async move {
-                    let stream = exec.execute(part_i, context1.clone()).await?;
+                    let stream = exec.execute(part_i, context1.clone())?;
                     common::collect(stream).await
                 })
             })
@@ -103,7 +103,7 @@ impl MemTable {
         let mut output_partitions = vec![];
         for i in 0..exec.output_partitioning().partition_count() {
             // execute this *output* partition and collect all batches
-            let mut stream = exec.execute(i, context.clone()).await?;
+            let mut stream = exec.execute(i, context.clone())?;
             let mut batches = vec![];
             while let Some(result) = stream.next().await {
                 batches.push(result?);
@@ -177,7 +177,7 @@ mod tests {
 
         // scan with projection
         let exec = provider.scan(&Some(vec![2, 1]), &[], None).await?;
-        let mut it = exec.execute(0, task_ctx).await?;
+        let mut it = exec.execute(0, task_ctx)?;
         let batch2 = it.next().await.unwrap()?;
         assert_eq!(2, batch2.schema().fields().len());
         assert_eq!("c", batch2.schema().field(0).name());
@@ -209,7 +209,7 @@ mod tests {
         let provider = MemTable::try_new(schema, vec![vec![batch]])?;
 
         let exec = provider.scan(&None, &[], None).await?;
-        let mut it = exec.execute(0, task_ctx).await?;
+        let mut it = exec.execute(0, task_ctx)?;
         let batch1 = it.next().await.unwrap()?;
         assert_eq!(3, batch1.schema().fields().len());
         assert_eq!(3, batch1.num_columns());
@@ -365,7 +365,7 @@ mod tests {
             MemTable::try_new(Arc::new(merged_schema), vec![vec![batch1, batch2]])?;
 
         let exec = provider.scan(&None, &[], None).await?;
-        let mut it = exec.execute(0, task_ctx).await?;
+        let mut it = exec.execute(0, task_ctx)?;
         let batch1 = it.next().await.unwrap()?;
         assert_eq!(3, batch1.schema().fields().len());
         assert_eq!(3, batch1.num_columns());
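The `MemTable` hunks illustrate the calling convention after this change: `execute` itself needs no `.await`, while consuming the stream it returns stays asynchronous. A minimal sketch against the DataFusion API as of this PR (`collect_partition` is a hypothetical helper, not part of DataFusion; plan and context construction are elided):

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::execution::context::TaskContext;
use datafusion::physical_plan::{common, ExecutionPlan};

async fn collect_partition(
    plan: Arc<dyn ExecutionPlan>,
    partition: usize,
    task_ctx: Arc<TaskContext>,
) -> Result<usize> {
    // Synchronous: only constructs the stream, no .await needed anymore.
    let stream = plan.execute(partition, task_ctx)?;
    // Asynchronous: the real work happens while the stream is polled.
    let batches = common::collect(stream).await?;
    Ok(batches.iter().map(|b| b.num_rows()).sum())
}
```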
2 changes: 1 addition & 1 deletion (file name not shown)
@@ -304,7 +304,7 @@ mod tests {
 
         // A ProjectionExec is a sign that the count optimization was applied
         assert!(optimized.as_any().is::<ProjectionExec>());
-        let result = common::collect(optimized.execute(0, task_ctx).await?).await?;
+        let result = common::collect(optimized.execute(0, task_ctx)?).await?;
         assert_eq!(result[0].schema(), Arc::new(Schema::new(vec![col])));
         assert_eq!(
             result[0]
15 changes: 5 additions & 10 deletions datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -29,7 +29,6 @@ use crate::physical_plan::{
 };
 use arrow::array::ArrayRef;
 use arrow::datatypes::{Field, Schema, SchemaRef};
-use async_trait::async_trait;
 use datafusion_common::Result;
 use datafusion_expr::Accumulator;
 use datafusion_physical_expr::expressions::Column;
@@ -145,7 +144,6 @@ impl AggregateExec {
     }
 }
 
-#[async_trait]
 impl ExecutionPlan for AggregateExec {
     /// Return a reference to Any that can be used for down-casting
     fn as_any(&self) -> &dyn Any {
@@ -196,12 +194,12 @@ impl ExecutionPlan for AggregateExec {
         )?))
     }
 
-    async fn execute(
+    fn execute(
         &self,
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
-        let input = self.input.execute(partition, context).await?;
+        let input = self.input.execute(partition, context)?;
         let group_expr = self.group_expr.iter().map(|x| x.0.clone()).collect();
 
         let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
@@ -417,7 +415,6 @@ mod tests {
     use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
     use arrow::error::Result as ArrowResult;
     use arrow::record_batch::RecordBatch;
-    use async_trait::async_trait;
     use datafusion_common::{DataFusionError, Result};
     use datafusion_physical_expr::{AggregateExpr, PhysicalExpr, PhysicalSortExpr};
     use futures::{FutureExt, Stream};
@@ -489,8 +486,7 @@ mod tests {
         )?);
 
         let result =
-            common::collect(partial_aggregate.execute(0, task_ctx.clone()).await?)
-                .await?;
+            common::collect(partial_aggregate.execute(0, task_ctx.clone())?).await?;
 
         let expected = vec![
             "+---+---------------+-------------+",
@@ -522,7 +518,7 @@ mod tests {
         )?);
 
         let result =
-            common::collect(merged_aggregate.execute(0, task_ctx.clone()).await?).await?;
+            common::collect(merged_aggregate.execute(0, task_ctx.clone())?).await?;
         assert_eq!(result.len(), 1);
 
         let batch = &result[0];
@@ -556,7 +552,6 @@ mod tests {
         pub yield_first: bool,
     }
 
-    #[async_trait]
     impl ExecutionPlan for TestYieldingExec {
         fn as_any(&self) -> &dyn Any {
             self
@@ -587,7 +582,7 @@ mod tests {
             )))
         }
 
-        async fn execute(
+        fn execute(
             &self,
             _partition: usize,
             _context: Arc<TaskContext>,