Skip to content

Commit 84ae3e8

Browse files
Added CoalesceBatchesExec, FilterExec, RepartitionExec
1 parent 3565bac commit 84ae3e8

File tree

1 file changed

+102
-13
lines changed

1 file changed

+102
-13
lines changed

src/query/mod.rs

Lines changed: 102 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,18 @@ use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, Tr
3333
// use datafusion::config::ConfigFileType;
3434
use datafusion::error::{DataFusionError, Result};
3535
use datafusion::execution::disk_manager::DiskManagerConfig;
36+
use datafusion::execution::runtime_env::{RuntimeEnv, RuntimeEnvBuilder};
3637
// use datafusion::execution::runtime_env::RuntimeEnvBuilder;
3738
use datafusion::execution::SessionStateBuilder;
3839
use datafusion::logical_expr::expr::Alias;
3940
use datafusion::logical_expr::{
4041
Aggregate, Explain, Filter, LogicalPlan, PlanType, Projection, ToStringifiedPlan,
4142
};
43+
use datafusion::physical_expr::create_physical_expr;
44+
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
45+
use datafusion::physical_plan::repartition::RepartitionExec;
46+
use datafusion::physical_plan::{collect as PhysicalPlanCollect, ExecutionPlan, Partitioning};
47+
use datafusion::physical_plan::filter::FilterExec;
4248
// use datafusion::physical_plan::execution_plan::EmissionType;
4349
// use datafusion::physical_plan::{collect, execute_stream, ExecutionPlanProperties};
4450
use datafusion::prelude::*;
@@ -759,50 +765,133 @@ struct BenchmarkResult {
759765
elapsed_seconds: f64,
760766
}
761767

762-
#[derive(Debug, Serialize)]
763-
struct BenchmarkResponse {
764-
results: Vec<BenchmarkResult>,
765-
}
766768

767769
pub async fn run_benchmark() {
768770
const TRIES: usize = 1;
769771

770772
let mut results = Vec::new();
771773
let mut query_num = 1;
774+
775+
// 1. Configure Runtime Environment with parallelism
776+
let runtime_config = RuntimeEnvBuilder::new() // Number of partitions for parallel processing
777+
.with_disk_manager(DiskManagerConfig::NewOs);
778+
779+
let runtime = RuntimeEnv::new(runtime_config).unwrap();
780+
781+
772782
// Create session context
773-
let ctx = SessionContext::new();
783+
let mut config = SessionConfig::new().with_coalesce_batches(true)
784+
.with_target_partitions(8)
785+
.with_batch_size(50000);
786+
config.options_mut().execution.parquet.binary_as_string = true;
787+
config.options_mut().execution.use_row_number_estimates_to_optimize_partitioning = true;
788+
config.options_mut().execution.parquet.pushdown_filters = true;
789+
config.options_mut().execution.parquet.enable_page_index = true;
790+
config.options_mut().execution.parquet.pruning = true;
791+
config.options_mut().execution.parquet.reorder_filters = true;
792+
let state = SessionStateBuilder::new()
793+
.with_default_features()
794+
.with_config(config)
795+
.with_runtime_env(Arc::new(runtime))
796+
.build();
797+
let ctx = SessionContext::new_with_state(state);
774798
let sql = "CREATE EXTERNAL TABLE hits STORED AS PARQUET LOCATION '/home/ubuntu/clickbench/hits.parquet' OPTIONS ('binary_as_string' 'true')";
775799
let _ = ctx.sql(&sql).await.unwrap().collect().await.unwrap();
776800
// Read queries from file
777801
let queries = fs::read_to_string("/home/ubuntu/queries.sql").unwrap();
778802

779803

780804
for query in queries.lines() {
781-
// Write current query to temporary file
782805
fs::write("/tmp/query.sql", &query).unwrap();
783-
806+
784807
for iteration in 1..=TRIES {
785-
// Clear caches
786808
clear_caches().unwrap();
809+
810+
811+
// Create the query plan
812+
let df = ctx.sql(&query).await.unwrap();
813+
let logical_plan = df.logical_plan().clone();
814+
let physical_plan = df.create_physical_plan().await.unwrap();
815+
816+
// Add coalesce
817+
let mut exec_plan: Arc<dyn ExecutionPlan> = Arc::new(CoalesceBatchesExec::new(physical_plan, 50000));
818+
819+
// Check if plan contains filter and add FilterExec
820+
fn has_filter(plan: &LogicalPlan) -> bool {
821+
match plan {
822+
LogicalPlan::Filter(_) => true,
823+
LogicalPlan::Projection(proj) => has_filter(proj.input.as_ref()),
824+
LogicalPlan::Aggregate(agg) => has_filter(agg.input.as_ref()),
825+
LogicalPlan::Join(join) => {
826+
has_filter(join.left.as_ref()) || has_filter(join.right.as_ref())
827+
},
828+
LogicalPlan::Window(window) => has_filter(window.input.as_ref()),
829+
LogicalPlan::Sort(sort) => has_filter(sort.input.as_ref()),
830+
LogicalPlan::Limit(limit) => has_filter(limit.input.as_ref()),
831+
_ => false,
832+
}
833+
}
834+
835+
// Extract filter expressions from logical plan
836+
fn extract_filters(plan: &LogicalPlan) -> Vec<Expr> {
837+
match plan {
838+
LogicalPlan::Filter(filter) => vec![filter.predicate.clone()],
839+
LogicalPlan::Projection(proj) => extract_filters(proj.input.as_ref()),
840+
LogicalPlan::Aggregate(agg) => extract_filters(agg.input.as_ref()),
841+
LogicalPlan::Join(join) => {
842+
let mut filters = extract_filters(join.left.as_ref());
843+
filters.extend(extract_filters(join.right.as_ref()));
844+
filters
845+
},
846+
_ => vec![],
847+
}
848+
}
849+
850+
if has_filter(&logical_plan) {
851+
let filters = extract_filters(&logical_plan);
852+
for filter_expr in filters {
853+
854+
855+
if let Ok(physical_filter_expr) = create_physical_expr(
856+
&filter_expr,
857+
&logical_plan.schema(),
858+
&ctx.state().execution_props(),
859+
860+
) {
861+
exec_plan = Arc::new(FilterExec::try_new(
862+
physical_filter_expr,
863+
exec_plan,
864+
).unwrap());
865+
}
787866

788-
// Execute and time the query
867+
868+
}
869+
}
870+
871+
// Execute the plan
872+
let task_ctx = ctx.task_ctx();
789873
let start = Instant::now();
790-
ctx.sql(&query).await.unwrap().collect().await.unwrap();
874+
875+
//let _ = execute_parallel(exec_plan.clone(), ctx.task_ctx()).await.unwrap();
876+
// Add repartitioning for better parallelism
877+
let repartitioned = Arc::new(RepartitionExec::try_new(
878+
exec_plan,
879+
Partitioning::RoundRobinBatch(8),
880+
).unwrap());
881+
let _ = PhysicalPlanCollect(repartitioned, task_ctx).await.unwrap();
882+
791883
let elapsed = start.elapsed().as_secs_f64();
792884
let benchmark_result = BenchmarkResult {
793885
query_num,
794886
iteration,
795887
elapsed_seconds: elapsed,
796888
};
797889
println!("Query {query_num} iteration {iteration} took {elapsed} seconds");
798-
// Record result
799890
results.push(benchmark_result);
800-
801891
}
802892
query_num += 1;
803893
}
804894

805-
println!("{}", serde_json::to_string(&BenchmarkResponse { results }).unwrap());
806895
}
807896

808897
fn clear_caches() -> io::Result<()> {

0 commit comments

Comments
 (0)