From cbfbb88ea2e8c98b6fe31c1f8b091d14d8d38386 Mon Sep 17 00:00:00 2001 From: SarveshOO7 Date: Tue, 28 Jan 2025 12:22:40 -0500 Subject: [PATCH] Change join condition to match datafusion and fix type errors from prev commit --- infra/src/lib.rs | 81 ++++++++++++---------------- infra/src/types/operator/logical.rs | 4 +- infra/src/types/operator/physical.rs | 4 +- 3 files changed, 40 insertions(+), 49 deletions(-) diff --git a/infra/src/lib.rs b/infra/src/lib.rs index bf20de2..c77f651 100644 --- a/infra/src/lib.rs +++ b/infra/src/lib.rs @@ -36,10 +36,9 @@ use types::operator::logical::{ use types::operator::physical::{ HashJoinOperator, PhysicalFilterOperator, PhysicalOperator, TableScanOperator, }; -use types::operator::ScalarOperator; -use types::plan::logical_plan::{LogicalLink, LogicalPlan as OptDLogicalPlan, ScalarLink}; -use types::plan::partial_physical_plan::PhysicalLink; -use types::plan::physical_plan::PhysicalPlan; +use types::operator::Scalar; +use types::plan::logical_plan::{LogicalLink, LogicalPlan as OptDLogicalPlan}; +use types::plan::physical_plan::{PhysicalLink, PhysicalPlan}; struct OptdOptimizer {} @@ -49,44 +48,27 @@ impl OptdOptimizer { ) -> Arc> { match &*logical_node { LogicalOperator::Scan(logical_scan_operator) => { - Arc::new(PhysicalOperator::TableScan(TableScanOperator::< - PhysicalLink, - > { + Arc::new(PhysicalOperator::TableScan(TableScanOperator { table_name: logical_scan_operator.table_name.clone(), predicate: None, })) } LogicalOperator::Filter(logical_filter_operator) => { - let LogicalLink::LogicalNode(ref child) = logical_filter_operator.child else { - panic!("The child of filter is not a logical node") - }; - - let LogicalLink::ScalarNode(ref predicate) = logical_filter_operator.predicate - else { - panic!("The predicate of filter is not a scalar node") - }; + let LogicalLink::LogicalNode(ref child) = logical_filter_operator.child; + let predicate = logical_filter_operator.predicate.clone(); Arc::new(PhysicalOperator::Filter(PhysicalFilterOperator::< PhysicalLink, > { child: PhysicalLink::PhysicalNode(Self::conv_logical_to_physical( child.clone(), )), - predicate: PhysicalLink::ScalarNode(todo!()), + predicate: predicate, })) } LogicalOperator::Join(logical_join_operator) => { - let LogicalLink::LogicalNode(ref left_join) = logical_join_operator.left else { - panic!("The left child of join is not a logical node") - }; - - let LogicalLink::LogicalNode(ref right_join) = logical_join_operator.right else { - panic!("The right child of join is not a logical node") - }; - - let LogicalLink::ScalarNode(ref condition) = logical_join_operator.condition else { - panic!("The condition child of join is not a Scalar Node") - }; - + let LogicalLink::LogicalNode(ref left_join) = logical_join_operator.left; + let LogicalLink::LogicalNode(ref right_join) = logical_join_operator.right; + let condition = logical_join_operator.condition.clone(); Arc::new(PhysicalOperator::HashJoin( HashJoinOperator:: { join_type: (), @@ -96,7 +78,7 @@ impl OptdOptimizer { right: PhysicalLink::PhysicalNode(Self::conv_logical_to_physical( right_join.clone(), )), - condition: PhysicalLink::ScalarNode(todo!()), + condition: condition, }, )) } @@ -112,9 +94,9 @@ pub struct OptdQueryPlanner { } impl OptdQueryPlanner { - fn convert_into_optd_scalar(predicate_expr: Expr) -> Arc> { + fn convert_into_optd_scalar(predicate_expr: Expr) -> Scalar { // TODO: Implement the conversion logic here - Arc::new(ScalarOperator::new()) + Scalar {} } fn convert_into_optd_logical(plan_node: &LogicalPlan) -> Arc> { @@ -122,27 +104,32 @@ impl OptdQueryPlanner { LogicalPlan::Filter(filter) => { Arc::new(LogicalOperator::Filter(LogicalFilterOperator { child: LogicalLink::LogicalNode(Self::convert_into_optd_logical(&filter.input)), - predicate: LogicalLink::ScalarNode(Self::convert_into_optd_scalar( - filter.predicate.clone(), - )), + predicate: Self::convert_into_optd_scalar(filter.predicate.clone()), })) } - LogicalPlan::Join(join) => Arc::new(LogicalOperator::Join( - (LogicalJoinOperator { - join_type: (), - left: LogicalLink::LogicalNode(Self::convert_into_optd_logical(&join.left)), - right: LogicalLink::LogicalNode(Self::convert_into_optd_logical(&join.right)), - condition: LogicalLink::ScalarNode(Arc::new(todo!())), - }), - )), - - LogicalPlan::TableScan(table_scan) => Arc::new(LogicalOperator::Scan( - (LogicalScanOperator { + LogicalPlan::Join(join) => Arc::new(LogicalOperator::Join(LogicalJoinOperator { + join_type: (), + left: LogicalLink::LogicalNode(Self::convert_into_optd_logical(&join.left)), + right: LogicalLink::LogicalNode(Self::convert_into_optd_logical(&join.right)), + condition: Arc::new( + join.on + .iter() + .map(|(left, right)| { + let left_scalar = Self::convert_into_optd_scalar(left.clone()); + let right_scalar = Self::convert_into_optd_scalar(right.clone()); + (left_scalar, right_scalar) + }) + .collect(), + ), + })), + + LogicalPlan::TableScan(table_scan) => { + Arc::new(LogicalOperator::Scan(LogicalScanOperator { table_name: table_scan.table_name.to_quoted_string(), predicate: None, // TODO fix this: there are multiple predicates in the scan but our IR only accepts one - }), - )), + })) + } _ => panic!("OptD does not support this type of query yet"), } } diff --git a/infra/src/types/operator/logical.rs b/infra/src/types/operator/logical.rs index 800e5ac..9aa6af3 100644 --- a/infra/src/types/operator/logical.rs +++ b/infra/src/types/operator/logical.rs @@ -1,5 +1,7 @@ //! Type representations of logical operators in (materialized) query plans. +use std::sync::Arc; + use crate::types::operator::Scalar; /// A type representing a logical operator in an input logical query plan. @@ -40,5 +42,5 @@ pub struct LogicalJoinOperator { pub join_type: (), pub left: Link, pub right: Link, - pub condition: Scalar, + pub condition: Arc>, } diff --git a/infra/src/types/operator/physical.rs b/infra/src/types/operator/physical.rs index 9eeda34..ef76577 100644 --- a/infra/src/types/operator/physical.rs +++ b/infra/src/types/operator/physical.rs @@ -1,5 +1,7 @@ //! Type representations of physical operators in (materialized) query plans. +use std::sync::Arc; + use crate::types::operator::Scalar; /// A type representing a physical operator in an output physical query execution plan. @@ -41,5 +43,5 @@ pub struct HashJoinOperator { pub join_type: (), pub left: Link, pub right: Link, - pub condition: Scalar, + pub condition: Arc>, }