diff --git a/infra/src/lib.rs b/infra/src/lib.rs
index 2d05bb7..d78ae85 100644
--- a/infra/src/lib.rs
+++ b/infra/src/lib.rs
@@ -17,7 +17,7 @@ use datafusion::execution::SessionStateBuilder;
 use datafusion::logical_expr::{Explain, LogicalPlan, PlanType, TableSource, ToStringifiedPlan};
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
-use datafusion::prelude::{SessionConfig, SessionContext};
+use datafusion::prelude::{Expr, SessionConfig, SessionContext};
 
 /// TODO make distinction between relational groups and scalar groups.
 #[repr(transparent)]
@@ -29,7 +29,11 @@ pub struct GroupId(u64);
 pub struct ExprId(u64);
 
 mod types;
-use types::plan::logical_plan::LogicalPlan as OptDLogicalPlan;
+use types::operator::logical::{
+    LogicalFilterOperator, LogicalJoinOperator, LogicalOperator, LogicalScanOperator,
+};
+use types::operator::ScalarOperator;
+use types::plan::logical_plan::{LogicalLink, LogicalPlan as OptDLogicalPlan, ScalarLink};
 
 struct OptdOptimizer {}
 
@@ -38,14 +42,47 @@ pub struct OptdQueryPlanner {
 }
 
 impl OptdQueryPlanner {
-    fn convert_into_optd_logical(plan_node: LogicalPlan) -> OptDLogicalPlan {
-        match plan_node {
-            LogicalPlan::Filter(filter) => todo!(),
-            LogicalPlan::Join(join) => todo!(),
-            LogicalPlan::TableScan(table_scan) => todo!(),
+    fn convert_into_optd_scalar(predicate_expr: Expr) -> Arc<ScalarOperator<ScalarLink>> {
+        // TODO: Implement the conversion logic here
+        Arc::new(ScalarOperator::new())
+    }
+
+    fn convert_into_optd_logical(plan_node: Arc<LogicalPlan>) -> Arc<LogicalOperator<LogicalLink>> {
+        match &*plan_node {
+            LogicalPlan::Filter(filter) => {
+                Arc::new(LogicalOperator::Filter(LogicalFilterOperator {
+                    child: LogicalLink::LogicalNode(Self::convert_into_optd_logical(
+                        filter.input.clone(),
+                    )),
+                    predicate: LogicalLink::ScalarNode(Self::convert_into_optd_scalar(
+                        filter.predicate.clone(),
+                    )),
+                }))
+            }
+
+            LogicalPlan::Join(join) => Arc::new(LogicalOperator::Join(
+                (LogicalJoinOperator {
+                    join_type: (),
+                    left: LogicalLink::LogicalNode(Self::convert_into_optd_logical(
+                        join.left.clone(),
+                    )),
+                    right: LogicalLink::LogicalNode(Self::convert_into_optd_logical(
+                        join.right.clone(),
+                    )),
+                    condition: LogicalLink::ScalarNode(Arc::new(todo!())),
+                }),
+            )),
+
+            LogicalPlan::TableScan(table_scan) => Arc::new(LogicalOperator::Scan(
+                (LogicalScanOperator {
+                    table_name: table_scan.table_name.to_quoted_string(),
+                    predicate: None, // TODO fix this: there are multiple predicates in the scan but our IR only accepts one
+                }),
+            )),
             _ => panic!("OptD does not support this type of query yet"),
         }
     }
+
     async fn create_physical_plan_inner(
         &self,
         logical_plan: &LogicalPlan,
diff --git a/infra/src/main.rs b/infra/src/main.rs
index a5de83f..f20f6b0 100644
--- a/infra/src/main.rs
+++ b/infra/src/main.rs
@@ -6,8 +6,8 @@ use datafusion::physical_plan::ExecutionPlanProperties;
 use datafusion::physical_plan::Partitioning;
 use datafusion::prelude::SessionConfig;
 use futures::StreamExt;
-use std::{io, time::SystemTime};
 use infra::create_df_context;
+use std::{io, time::SystemTime};
 
 #[tokio::main]
 async fn main() -> Result<()> {
@@ -16,13 +16,9 @@ async fn main() -> Result<()> {
 
     let session_config = SessionConfig::from_env()?.with_information_schema(true);
 
-    let ctx = crate::create_df_context(
-        Some(session_config.clone()),
-        Some(rt_config.clone()),
-        None
-    )
-    .await
-    .unwrap();
+    let ctx = crate::create_df_context(Some(session_config.clone()), Some(rt_config.clone()), None)
+        .await
+        .unwrap();
 
     // Create a DataFrame with the input query
     let queries = io::read_to_string(io::stdin())?;
diff --git a/infra/src/types/memo/rule.rs b/infra/src/types/memo/rule.rs
index 3b5da3f..7b75e09 100644
--- a/infra/src/types/memo/rule.rs
+++ b/infra/src/types/memo/rule.rs
@@ -1,7 +1,9 @@
 use super::Memo;
 use crate::{
     types::expression::{relational::logical::LogicalExpr, Expr},
-    types::plan::{partial_logical_plan::PartialLogicalPlan, partial_physical_plan::PartialPhysicalPlan},
+    types::plan::{
+        partial_logical_plan::PartialLogicalPlan, partial_physical_plan::PartialPhysicalPlan,
+    },
 };
 
 #[trait_variant::make(Send)]
diff --git a/infra/src/types/operator/logical.rs b/infra/src/types/operator/logical.rs
index 3d27b4a..e42e685 100644
--- a/infra/src/types/operator/logical.rs
+++ b/infra/src/types/operator/logical.rs
@@ -25,7 +25,7 @@ pub enum LogicalOperator<Link> {
 /// TODO Add docs.
 pub struct LogicalScanOperator<Link> {
     pub table_name: String,
-    pub predicate: Link,
+    pub predicate: Option<Link>,
 }
 
 /// TODO Add docs.
diff --git a/infra/src/types/operator/mod.rs b/infra/src/types/operator/mod.rs
index 056495f..7f6fbe5 100644
--- a/infra/src/types/operator/mod.rs
+++ b/infra/src/types/operator/mod.rs
@@ -1,5 +1,7 @@
 use std::{marker::PhantomData, sync::Arc};
 
+use super::plan::logical_plan::ScalarLink;
+
 pub mod logical;
 pub mod physical;
 
@@ -10,3 +12,12 @@ pub mod physical;
 pub struct ScalarOperator<Scalar> {
     _phantom: PhantomData<Scalar>,
 }
+
+impl<Scalar> ScalarOperator<Scalar> {
+    // Add a public constructor
+    pub fn new() -> Self {
+        ScalarOperator {
+            _phantom: std::marker::PhantomData,
+        }
+    }
+}
diff --git a/infra/src/types/plan/partial_physical_plan.rs b/infra/src/types/plan/partial_physical_plan.rs
index d10f95e..109021d 100644
--- a/infra/src/types/plan/partial_physical_plan.rs
+++ b/infra/src/types/plan/partial_physical_plan.rs
@@ -1,4 +1,6 @@
-use crate::types::operator::{logical::LogicalOperator, physical::PhysicalOperator, ScalarOperator};
+use crate::types::operator::{
+    logical::LogicalOperator, physical::PhysicalOperator, ScalarOperator,
+};
 use crate::GroupId;
 use std::sync::Arc;
 
diff --git a/infra/src/types/plan/physical_plan.rs b/infra/src/types/plan/physical_plan.rs
index eff7724..dd75537 100644
--- a/infra/src/types/plan/physical_plan.rs
+++ b/infra/src/types/plan/physical_plan.rs
@@ -4,12 +4,15 @@ use crate::types::operator::{
     ScalarOperator,
 };
 use datafusion::{
-    common::{arrow::datatypes::Schema, JoinType}, datasource::physical_plan::{CsvExecBuilder, FileScanConfig}, execution::object_store::ObjectStoreUrl, physical_plan::{
+    common::{arrow::datatypes::Schema, JoinType},
+    datasource::physical_plan::{CsvExecBuilder, FileScanConfig},
+    execution::object_store::ObjectStoreUrl,
+    physical_plan::{
         expressions::NoOp,
         filter::FilterExec,
         joins::{HashJoinExec, PartitionMode},
         ExecutionPlan,
-    }
+    },
 };
 
 use std::sync::Arc;