From 409607394e1ec7883b659941af803709879dd153 Mon Sep 17 00:00:00 2001 From: "Alex Chi Z." <4198311+skyzh@users.noreply.github.com> Date: Tue, 29 Oct 2024 23:57:04 -0400 Subject: [PATCH] feat(df-repr): add back join order enumeration (#204) ref https://github.com/cmu-db/optd/issues/194 after the memo table refactor, adding back a more efficient join order enumeration implementation. --------- Signed-off-by: Alex Chi --- Cargo.lock | 1 + optd-core/src/cascades.rs | 4 +- optd-core/src/cascades/memo.rs | 19 +- optd-core/src/cascades/optimizer.rs | 4 + optd-datafusion-bridge/Cargo.toml | 1 + optd-datafusion-bridge/src/lib.rs | 152 ++------------ optd-datafusion-repr/src/lib.rs | 4 +- optd-datafusion-repr/src/memo_ext.rs | 186 ++++++++++++++++++ optd-datafusion-repr/src/plan_nodes.rs | 6 +- optd-sqlplannertest/src/lib.rs | 11 -- .../joins/join_enumerate.planner.sql | 0 .../joins/join_enumerate.yml | 10 +- .../tests/joins/self-join.planner.sql | 2 + optd-sqlplannertest/tests/joins/self-join.yml | 2 +- 14 files changed, 241 insertions(+), 161 deletions(-) create mode 100644 optd-datafusion-repr/src/memo_ext.rs rename optd-sqlplannertest/{disabled_tests => tests}/joins/join_enumerate.planner.sql (100%) rename optd-sqlplannertest/{disabled_tests => tests}/joins/join_enumerate.yml (75%) diff --git a/Cargo.lock b/Cargo.lock index 1ce2208c..15f2e142 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2767,6 +2767,7 @@ dependencies = [ "datafusion-expr", "futures-lite", "futures-util", + "itertools", "optd-core", "optd-datafusion-repr", ] diff --git a/optd-core/src/cascades.rs b/optd-core/src/cascades.rs index 1d415fd1..cbeda5fe 100644 --- a/optd-core/src/cascades.rs +++ b/optd-core/src/cascades.rs @@ -4,6 +4,6 @@ mod memo; mod optimizer; mod tasks; -use memo::Memo; -pub use optimizer::{CascadesOptimizer, GroupId, OptimizerProperties, RelNodeContext}; +pub use memo::Memo; +pub use optimizer::{CascadesOptimizer, ExprId, GroupId, OptimizerProperties, RelNodeContext}; use tasks::Task; 
diff --git a/optd-core/src/cascades/memo.rs b/optd-core/src/cascades/memo.rs index 3cd14947..364bab6b 100644 --- a/optd-core/src/cascades/memo.rs +++ b/optd-core/src/cascades/memo.rs @@ -26,6 +26,20 @@ pub struct RelMemoNode { pub data: Option, } +impl RelMemoNode { + pub fn into_rel_node(self) -> RelNode { + RelNode { + typ: self.typ, + children: self + .children + .into_iter() + .map(|x| Arc::new(RelNode::new_group(x))) + .collect(), + data: self.data, + } + } +} + impl std::fmt::Display for RelMemoNode { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "({}", self.typ)?; @@ -401,7 +415,7 @@ impl Memo { } /// Get the memoized representation of a node, only for debugging purpose - pub(crate) fn get_expr_memoed(&self, mut expr_id: ExprId) -> RelMemoNodeRef { + pub fn get_expr_memoed(&self, mut expr_id: ExprId) -> RelMemoNodeRef { while let Some(new_expr_id) = self.dup_expr_mapping.get(&expr_id) { expr_id = *new_expr_id; } @@ -411,7 +425,8 @@ impl Memo { .clone() } - pub(crate) fn get_all_exprs_in_group(&self, group_id: GroupId) -> Vec { + pub fn get_all_exprs_in_group(&self, group_id: GroupId) -> Vec { + let group_id = self.reduce_group(group_id); let group = self.groups.get(&group_id).expect("group not found"); let mut exprs = group.group_exprs.iter().copied().collect_vec(); exprs.sort(); diff --git a/optd-core/src/cascades/optimizer.rs b/optd-core/src/cascades/optimizer.rs index 9b0d88ab..9f7233ee 100644 --- a/optd-core/src/cascades/optimizer.rs +++ b/optd-core/src/cascades/optimizer.rs @@ -369,6 +369,10 @@ impl CascadesOptimizer { .map(|x| x.cost.0[0]) .unwrap_or(0.0) } + + pub fn memo(&self) -> &Memo { + &self.memo + } } impl Optimizer for CascadesOptimizer { diff --git a/optd-datafusion-bridge/Cargo.toml b/optd-datafusion-bridge/Cargo.toml index 2779891e..9c4edf46 100644 --- a/optd-datafusion-bridge/Cargo.toml +++ b/optd-datafusion-bridge/Cargo.toml @@ -15,3 +15,4 @@ anyhow = "1" async-recursion = "1" futures-lite = "2" 
futures-util = "0.3" +itertools = "0.11" diff --git a/optd-datafusion-bridge/src/lib.rs b/optd-datafusion-bridge/src/lib.rs index 6f7f30a1..1f4d3724 100644 --- a/optd-datafusion-bridge/src/lib.rs +++ b/optd-datafusion-bridge/src/lib.rs @@ -16,13 +16,11 @@ use datafusion::{ physical_plan::{displayable, explain::ExplainExec, ExecutionPlan}, physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}, }; +use itertools::Itertools; use optd_datafusion_repr::{ - plan_nodes::{ - ConstantType, OptRelNode, OptRelNodeRef, OptRelNodeTyp, PhysicalHashJoin, - PhysicalNestedLoopJoin, PlanNode, - }, + plan_nodes::{ConstantType, OptRelNode, PlanNode}, properties::schema::Catalog, - DatafusionOptimizer, + DatafusionOptimizer, MemoExt, }; use std::{ collections::HashMap, @@ -89,93 +87,6 @@ pub struct OptdQueryPlanner { pub optimizer: Arc>>>, } -#[derive(Debug, Eq, PartialEq, Hash, PartialOrd, Ord)] -enum JoinOrder { - Table(String), - HashJoin(Box, Box), - NestedLoopJoin(Box, Box), -} - -#[allow(dead_code)] -impl JoinOrder { - pub fn conv_into_logical_join_order(&self) -> LogicalJoinOrder { - match self { - JoinOrder::Table(name) => LogicalJoinOrder::Table(name.clone()), - JoinOrder::HashJoin(left, right) => LogicalJoinOrder::Join( - Box::new(left.conv_into_logical_join_order()), - Box::new(right.conv_into_logical_join_order()), - ), - JoinOrder::NestedLoopJoin(left, right) => LogicalJoinOrder::Join( - Box::new(left.conv_into_logical_join_order()), - Box::new(right.conv_into_logical_join_order()), - ), - } - } -} - -#[allow(unused)] -#[derive(Debug, Eq, PartialEq, Hash, PartialOrd, Ord)] -enum LogicalJoinOrder { - Table(String), - Join(Box, Box), -} - -#[allow(dead_code)] -fn get_join_order(rel_node: OptRelNodeRef) -> Option { - match rel_node.typ { - OptRelNodeTyp::PhysicalHashJoin(_) => { - let join = PhysicalHashJoin::from_rel_node(rel_node.clone()).unwrap(); - let left = get_join_order(join.left().into_rel_node())?; - let right = get_join_order(join.right().into_rel_node())?; - 
Some(JoinOrder::HashJoin(Box::new(left), Box::new(right))) - } - OptRelNodeTyp::PhysicalNestedLoopJoin(_) => { - let join = PhysicalNestedLoopJoin::from_rel_node(rel_node.clone()).unwrap(); - let left = get_join_order(join.left().into_rel_node())?; - let right = get_join_order(join.right().into_rel_node())?; - Some(JoinOrder::NestedLoopJoin(Box::new(left), Box::new(right))) - } - OptRelNodeTyp::PhysicalScan => { - let scan = - optd_datafusion_repr::plan_nodes::PhysicalScan::from_rel_node(rel_node).unwrap(); - Some(JoinOrder::Table(scan.table().to_string())) - } - _ => { - for child in &rel_node.children { - if let Some(res) = get_join_order(child.clone()) { - return Some(res); - } - } - None - } - } -} - -impl std::fmt::Display for LogicalJoinOrder { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - LogicalJoinOrder::Table(name) => write!(f, "{}", name), - LogicalJoinOrder::Join(left, right) => { - write!(f, "(Join {} {})", left, right) - } - } - } -} - -impl std::fmt::Display for JoinOrder { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - JoinOrder::Table(name) => write!(f, "{}", name), - JoinOrder::HashJoin(left, right) => { - write!(f, "(HashJoin {} {})", left, right) - } - JoinOrder::NestedLoopJoin(left, right) => { - write!(f, "(NLJ {} {})", left, right) - } - } - } -} - impl OptdQueryPlanner { pub fn enable_adaptive(&self) { self.optimizer @@ -247,7 +158,7 @@ impl OptdQueryPlanner { } } - let (_, optimized_rel, meta) = optimizer.cascades_optimize(optd_rel)?; + let (group_id, optimized_rel, meta) = optimizer.cascades_optimize(optd_rel)?; if let Some(explains) = &mut explains { explains.push(StringifiedPlan::new( @@ -258,52 +169,17 @@ impl OptdQueryPlanner { .unwrap() .explain_to_string(if verbose { Some(&meta) } else { None }), )); - - // const ENABLE_JOIN_ORDER: bool = false; - - // if ENABLE_JOIN_ORDER { - // let join_order = get_join_order(optimized_rel.clone()); - // 
explains.push(StringifiedPlan::new( - // PlanType::OptimizedPhysicalPlan { - // optimizer_name: "optd-join-order".to_string(), - // }, - // if let Some(join_order) = join_order { - // join_order.to_string() - // } else { - // "None".to_string() - // }, - // )); - // let bindings = optimizer - // .optd_cascades_optimizer() - // .get_all_group_bindings(group_id, true); - // let mut join_orders = BTreeSet::new(); - // let mut logical_join_orders = BTreeSet::new(); - // for binding in bindings { - // if let Some(join_order) = get_join_order(binding) { - // logical_join_orders.insert(join_order.conv_into_logical_join_order()); - // join_orders.insert(join_order); - // } - // } - // explains.push(StringifiedPlan::new( - // PlanType::OptimizedPhysicalPlan { - // optimizer_name: "optd-all-join-orders".to_string(), - // }, - // join_orders.iter().map(|x| x.to_string()).join("\n"), - // )); - // explains.push(StringifiedPlan::new( - // PlanType::OptimizedPhysicalPlan { - // optimizer_name: "optd-all-logical-join-orders".to_string(), - // }, - // logical_join_orders.iter().map(|x| x.to_string()).join("\n"), - // )); - // } + let join_orders = optimizer + .optd_cascades_optimizer() + .memo() + .enumerate_join_order(group_id); + explains.push(StringifiedPlan::new( + PlanType::OptimizedPhysicalPlan { + optimizer_name: "optd-all-logical-join-orders".to_string(), + }, + join_orders.iter().map(|x| x.to_string()).join("\n"), + )); } - // println!( - // "{} cost={}", - // get_join_order(optimized_rel.clone()).unwrap(), - // optimizer.optd_optimizer().get_cost_of(group_id) - // ); - // optimizer.dump(Some(group_id)); ctx.optimizer = Some(&optimizer); let physical_plan = ctx.conv_from_optd(optimized_rel, meta).await?; if let Some(explains) = &mut explains { diff --git a/optd-datafusion-repr/src/lib.rs b/optd-datafusion-repr/src/lib.rs index e98cfd2f..e7634ac1 100644 --- a/optd-datafusion-repr/src/lib.rs +++ b/optd-datafusion-repr/src/lib.rs @@ -34,14 +34,16 @@ use crate::rules::{ 
DepInitialDistinct, DepJoinEliminateAtScan, DepJoinPastAgg, DepJoinPastFilter, DepJoinPastProj, }; +pub use memo_ext::{LogicalJoinOrder, MemoExt}; + pub mod cost; mod explain; +mod memo_ext; pub mod plan_nodes; pub mod properties; pub mod rules; #[cfg(test)] mod testing; -// mod expand; pub struct DatafusionOptimizer { heuristic_optimizer: HeuristicsOptimizer, diff --git a/optd-datafusion-repr/src/memo_ext.rs b/optd-datafusion-repr/src/memo_ext.rs new file mode 100644 index 00000000..0075fda1 --- /dev/null +++ b/optd-datafusion-repr/src/memo_ext.rs @@ -0,0 +1,186 @@ +//! Memo table extensions + +use std::{ + collections::{BTreeSet, HashMap}, + sync::Arc, +}; + +use itertools::Itertools; +use optd_core::{ + cascades::{ExprId, GroupId, Memo}, + rel_node::RelNodeTyp, +}; + +use crate::plan_nodes::{LogicalScan, OptRelNode, OptRelNodeTyp}; + +#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] +pub enum LogicalJoinOrder { + Table(Arc), + Join(Box, Box), +} + +impl std::fmt::Display for LogicalJoinOrder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + LogicalJoinOrder::Table(name) => write!(f, "{}", name), + LogicalJoinOrder::Join(left, right) => { + write!(f, "(Join {} {})", left, right) + } + } + } +} + +pub trait MemoExt { + fn enumerate_join_order(&self, entry: GroupId) -> Vec; +} + +fn enumerate_join_order_expr_inner( + memo: &Memo, + current: ExprId, + visited: &mut HashMap>, +) -> Vec { + let expr = memo + .get_expr_memoed(current) + .as_ref() + .clone() + .into_rel_node(); + match expr.typ { + OptRelNodeTyp::Scan => { + let scan = LogicalScan::from_rel_node(Arc::new(expr)).unwrap(); + vec![LogicalJoinOrder::Table(scan.table())] + } + OptRelNodeTyp::Join(_) | OptRelNodeTyp::DepJoin(_) | OptRelNodeTyp::RawDepJoin(_) => { + // Assume child 0 == left, child 1 == right + let left = expr.children[0].typ.extract_group().unwrap(); + let right = expr.children[1].typ.extract_group().unwrap(); + let left_join_orders = 
enumerate_join_order_group_inner(memo, left, visited); + let right_join_orders = enumerate_join_order_group_inner(memo, right, visited); + let mut join_orders = BTreeSet::new(); + for left_join_order in left_join_orders { + for right_join_order in right_join_orders.iter() { + join_orders.insert(LogicalJoinOrder::Join( + Box::new(left_join_order.clone()), + Box::new(right_join_order.clone()), + )); + } + } + join_orders.into_iter().collect() + } + typ if typ.is_logical() => { + let mut join_orders = BTreeSet::new(); + for (idx, child) in expr.children.iter().enumerate() { + let child_join_orders = enumerate_join_order_group_inner( + memo, + child.typ.extract_group().unwrap(), + visited, + ); + if idx == 0 { + for child_join_order in child_join_orders { + join_orders.insert(child_join_order); + } + } else { + assert!( + child_join_orders.is_empty(), + "missing join node? found a node with join orders on multiple children" + ); + } + } + join_orders.into_iter().collect() + } + _ => Vec::new(), + } +} + +fn enumerate_join_order_group_inner( + memo: &Memo, + current: GroupId, + visited: &mut HashMap>, +) -> Vec { + if let Some(result) = visited.get(&current) { + return result.clone(); + } + // If the current node is processed again before the result gets populated, simply return an empty list, as another + // search path will eventually return a correct result for it, and then get combined with this empty list. 
+ visited.insert(current, Vec::new()); + let group_exprs = memo.get_all_exprs_in_group(current); + let mut join_orders = BTreeSet::new(); + for expr_id in group_exprs { + let expr_join_orders = enumerate_join_order_expr_inner(memo, expr_id, visited); + for expr_join_order in expr_join_orders { + join_orders.insert(expr_join_order); + } + } + let res = join_orders.into_iter().collect_vec(); + visited.insert(current, res.clone()); + res +} + +impl MemoExt for Memo { + fn enumerate_join_order(&self, entry: GroupId) -> Vec { + let mut visited = HashMap::new(); + enumerate_join_order_group_inner(self, entry, &mut visited) + } +} + +#[cfg(test)] +mod tests { + use optd_core::rel_node::{RelNode, Value}; + + use crate::plan_nodes::{ + ConstantExpr, ExprList, JoinType, LogicalJoin, LogicalProjection, PlanNode, + }; + + use super::*; + + #[test] + fn enumerate_join_orders() { + let mut memo = Memo::::new(Arc::new([])); + let (group, _) = memo.add_new_expr( + LogicalJoin::new( + LogicalScan::new("t1".to_string()).into_plan_node(), + LogicalScan::new("t2".to_string()).into_plan_node(), + ConstantExpr::new(Value::Bool(true)).into_expr(), + JoinType::Inner, + ) + .into_rel_node(), + ); + // Add an alternative join order + memo.add_expr_to_group( + LogicalProjection::new( + LogicalJoin::new( + LogicalScan::new("t2".to_string()).into_plan_node(), + LogicalScan::new("t1".to_string()).into_plan_node(), + ConstantExpr::new(Value::Bool(true)).into_expr(), + JoinType::Inner, + ) + .into_plan_node(), + ExprList::new(Vec::new()), + ) + .into_rel_node(), + group, + ); + // Self-reference group + memo.add_expr_to_group( + LogicalProjection::new( + PlanNode::from_group(Arc::new(RelNode::new_group(group))), + ExprList::new(Vec::new()), + ) + .into_rel_node(), + group, + ); + let orders = memo.enumerate_join_order(group); + assert_eq!( + orders, + vec![ + LogicalJoinOrder::Join( + Box::new(LogicalJoinOrder::Table("t1".into())), + Box::new(LogicalJoinOrder::Table("t2".into())), + ), + 
LogicalJoinOrder::Join( + Box::new(LogicalJoinOrder::Table("t2".into())), + Box::new(LogicalJoinOrder::Table("t1".into())), + ) + ] + ); + } +} diff --git a/optd-datafusion-repr/src/plan_nodes.rs b/optd-datafusion-repr/src/plan_nodes.rs index 581c503d..782d1541 100644 --- a/optd-datafusion-repr/src/plan_nodes.rs +++ b/optd-datafusion-repr/src/plan_nodes.rs @@ -136,7 +136,11 @@ impl OptRelNodeTyp { impl std::fmt::Display for OptRelNodeTyp { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) + if let Self::Placeholder(group_id) = self { + write!(f, "{}", group_id) + } else { + write!(f, "{:?}", self) + } } } diff --git a/optd-sqlplannertest/src/lib.rs b/optd-sqlplannertest/src/lib.rs index aeec7287..2a3a7af5 100644 --- a/optd-sqlplannertest/src/lib.rs +++ b/optd-sqlplannertest/src/lib.rs @@ -243,17 +243,6 @@ impl DatafusionDBMS { .map(|x| &x[1]) .unwrap() )?; - } else if subtask == "join_orders" { - writeln!( - r, - "{}", - result - .iter() - .find(|x| x[0] == "physical_plan after optd-all-join-orders") - .map(|x| &x[1]) - .unwrap() - )?; - writeln!(r)?; } else if subtask == "logical_join_orders" { writeln!( r, diff --git a/optd-sqlplannertest/disabled_tests/joins/join_enumerate.planner.sql b/optd-sqlplannertest/tests/joins/join_enumerate.planner.sql similarity index 100% rename from optd-sqlplannertest/disabled_tests/joins/join_enumerate.planner.sql rename to optd-sqlplannertest/tests/joins/join_enumerate.planner.sql diff --git a/optd-sqlplannertest/disabled_tests/joins/join_enumerate.yml b/optd-sqlplannertest/tests/joins/join_enumerate.yml similarity index 75% rename from optd-sqlplannertest/disabled_tests/joins/join_enumerate.yml rename to optd-sqlplannertest/tests/joins/join_enumerate.yml index 74bcd736..6040d312 100644 --- a/optd-sqlplannertest/disabled_tests/joins/join_enumerate.yml +++ b/optd-sqlplannertest/tests/joins/join_enumerate.yml @@ -6,16 +6,16 @@ insert into t2 values (0, 200), (1, 201), (2, 202); insert 
into t3 values (0, 300), (1, 301), (2, 302); tasks: - - execute[use_df_logical] + - execute - sql: | select * from t2, t1, t3 where t1v1 = t2v1 and t1v2 = t3v2; desc: Test whether the optimizer enumerates all join orders. tasks: - - explain[use_df_logical]:logical_join_orders - - execute[use_df_logical] + - explain:logical_join_orders + - execute - sql: | select * from t1, t2, t3 where t1v1 = t2v1 and t1v2 = t3v2; desc: Test whether the optimizer enumerates all join orders. tasks: - - explain[use_df_logical]:logical_join_orders - - execute[use_df_logical] + - explain:logical_join_orders + - execute diff --git a/optd-sqlplannertest/tests/joins/self-join.planner.sql b/optd-sqlplannertest/tests/joins/self-join.planner.sql index c8e615cb..77e6a559 100644 --- a/optd-sqlplannertest/tests/joins/self-join.planner.sql +++ b/optd-sqlplannertest/tests/joins/self-join.planner.sql @@ -13,6 +13,8 @@ insert into t2 values (0, 200), (1, 201), (2, 202); select * from t1 as a, t1 as b where a.t1v1 = b.t1v1; /* +(Join t1 t1) + LogicalProjection { exprs: [ #0, #1, #2, #3 ] } └── LogicalFilter ├── cond:Eq diff --git a/optd-sqlplannertest/tests/joins/self-join.yml b/optd-sqlplannertest/tests/joins/self-join.yml index 627986eb..5e0ba928 100644 --- a/optd-sqlplannertest/tests/joins/self-join.yml +++ b/optd-sqlplannertest/tests/joins/self-join.yml @@ -9,5 +9,5 @@ select * from t1 as a, t1 as b where a.t1v1 = b.t1v1; desc: test self join tasks: - - explain:logical_optd,physical_optd + - explain:logical_join_orders,logical_optd,physical_optd - execute