This repository was archived by the owner on Jan 7, 2025. It is now read-only.

feat(df-repr): add back join order enumeration #204

Merged (5 commits) on Oct 30, 2024
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions optd-core/src/cascades.rs
@@ -4,6 +4,6 @@ mod memo;
 mod optimizer;
 mod tasks;

-use memo::Memo;
-pub use optimizer::{CascadesOptimizer, GroupId, OptimizerProperties, RelNodeContext};
+pub use memo::Memo;
+pub use optimizer::{CascadesOptimizer, ExprId, GroupId, OptimizerProperties, RelNodeContext};
 use tasks::Task;
19 changes: 17 additions & 2 deletions optd-core/src/cascades/memo.rs
@@ -26,6 +26,20 @@ pub struct RelMemoNode<T: RelNodeTyp> {
     pub data: Option<Value>,
 }

+impl<T: RelNodeTyp> RelMemoNode<T> {
+    pub fn into_rel_node(self) -> RelNode<T> {
+        RelNode {
+            typ: self.typ,
+            children: self
+                .children
+                .into_iter()
+                .map(|x| Arc::new(RelNode::new_group(x)))
+                .collect(),
+            data: self.data,
+        }
+    }
+}
+
 impl<T: RelNodeTyp> std::fmt::Display for RelMemoNode<T> {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(f, "({}", self.typ)?;
@@ -401,7 +415,7 @@ impl<T: RelNodeTyp> Memo<T> {
     }

     /// Get the memoized representation of a node, only for debugging purpose
-    pub(crate) fn get_expr_memoed(&self, mut expr_id: ExprId) -> RelMemoNodeRef<T> {
+    pub fn get_expr_memoed(&self, mut expr_id: ExprId) -> RelMemoNodeRef<T> {
         while let Some(new_expr_id) = self.dup_expr_mapping.get(&expr_id) {
             expr_id = *new_expr_id;
         }
@@ -411,7 +425,8 @@ impl<T: RelNodeTyp> Memo<T> {
             .clone()
     }

-    pub(crate) fn get_all_exprs_in_group(&self, group_id: GroupId) -> Vec<ExprId> {
+    pub fn get_all_exprs_in_group(&self, group_id: GroupId) -> Vec<ExprId> {
+        let group_id = self.reduce_group(group_id);
         let group = self.groups.get(&group_id).expect("group not found");
         let mut exprs = group.group_exprs.iter().copied().collect_vec();
         exprs.sort();
4 changes: 4 additions & 0 deletions optd-core/src/cascades/optimizer.rs
@@ -369,6 +369,10 @@ impl<T: RelNodeTyp> CascadesOptimizer<T> {
             .map(|x| x.cost.0[0])
             .unwrap_or(0.0)
     }
+
+    pub fn memo(&self) -> &Memo<T> {
+        &self.memo
+    }
 }

 impl<T: RelNodeTyp> Optimizer<T> for CascadesOptimizer<T> {

Review thread on the new memo() accessor:

Member: Exposing the memo table publicly strikes me as a bit scary. I am worried users of the library might try to manipulate the memo table manually.

I'm hoping access is read-only (looks like it is?)

Member (Author): yes it's read only 🤪
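As a rough illustration of the read-only access discussed above, here is a hypothetical consumer-side helper built only on the accessors this PR makes public (memo(), get_all_exprs_in_group, get_expr_memoed). The module paths are assumptions, and because memo() hands back a shared reference, nothing here can mutate the memo table:

use optd_core::cascades::{CascadesOptimizer, GroupId};
use optd_core::rel_node::RelNodeTyp;

// Hypothetical inspection helper: print every memoized expression in a group.
// `memo()` returns `&Memo<T>`, so only non-mutating methods are reachable here.
fn dump_group<T: RelNodeTyp>(optimizer: &CascadesOptimizer<T>, group_id: GroupId) {
    let memo = optimizer.memo();
    for expr_id in memo.get_all_exprs_in_group(group_id) {
        // RelMemoNode implements Display (see the memo.rs diff above), so this
        // prints the node type followed by its memoized children.
        println!("{}", memo.get_expr_memoed(expr_id));
    }
}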
1 change: 1 addition & 0 deletions optd-datafusion-bridge/Cargo.toml
@@ -15,3 +15,4 @@ anyhow = "1"
 async-recursion = "1"
 futures-lite = "2"
 futures-util = "0.3"
+itertools = "0.11"
152 changes: 14 additions & 138 deletions optd-datafusion-bridge/src/lib.rs
@@ -16,13 +16,11 @@ use datafusion::{
     physical_plan::{displayable, explain::ExplainExec, ExecutionPlan},
     physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner},
 };
+use itertools::Itertools;
 use optd_datafusion_repr::{
-    plan_nodes::{
-        ConstantType, OptRelNode, OptRelNodeRef, OptRelNodeTyp, PhysicalHashJoin,
-        PhysicalNestedLoopJoin, PlanNode,
-    },
+    plan_nodes::{ConstantType, OptRelNode, PlanNode},
     properties::schema::Catalog,
-    DatafusionOptimizer,
+    DatafusionOptimizer, MemoExt,
 };
 use std::{
     collections::HashMap,
@@ -89,93 +87,6 @@ pub struct OptdQueryPlanner {
     pub optimizer: Arc<Mutex<Option<Box<DatafusionOptimizer>>>>,
 }

-#[derive(Debug, Eq, PartialEq, Hash, PartialOrd, Ord)]
-enum JoinOrder {
-    Table(String),
-    HashJoin(Box<Self>, Box<Self>),
-    NestedLoopJoin(Box<Self>, Box<Self>),
-}
-
-#[allow(dead_code)]
-impl JoinOrder {
-    pub fn conv_into_logical_join_order(&self) -> LogicalJoinOrder {
-        match self {
-            JoinOrder::Table(name) => LogicalJoinOrder::Table(name.clone()),
-            JoinOrder::HashJoin(left, right) => LogicalJoinOrder::Join(
-                Box::new(left.conv_into_logical_join_order()),
-                Box::new(right.conv_into_logical_join_order()),
-            ),
-            JoinOrder::NestedLoopJoin(left, right) => LogicalJoinOrder::Join(
-                Box::new(left.conv_into_logical_join_order()),
-                Box::new(right.conv_into_logical_join_order()),
-            ),
-        }
-    }
-}
-
-#[allow(unused)]
-#[derive(Debug, Eq, PartialEq, Hash, PartialOrd, Ord)]
-enum LogicalJoinOrder {
-    Table(String),
-    Join(Box<Self>, Box<Self>),
-}
-
-#[allow(dead_code)]
-fn get_join_order(rel_node: OptRelNodeRef) -> Option<JoinOrder> {
-    match rel_node.typ {
-        OptRelNodeTyp::PhysicalHashJoin(_) => {
-            let join = PhysicalHashJoin::from_rel_node(rel_node.clone()).unwrap();
-            let left = get_join_order(join.left().into_rel_node())?;
-            let right = get_join_order(join.right().into_rel_node())?;
-            Some(JoinOrder::HashJoin(Box::new(left), Box::new(right)))
-        }
-        OptRelNodeTyp::PhysicalNestedLoopJoin(_) => {
-            let join = PhysicalNestedLoopJoin::from_rel_node(rel_node.clone()).unwrap();
-            let left = get_join_order(join.left().into_rel_node())?;
-            let right = get_join_order(join.right().into_rel_node())?;
-            Some(JoinOrder::NestedLoopJoin(Box::new(left), Box::new(right)))
-        }
-        OptRelNodeTyp::PhysicalScan => {
-            let scan =
-                optd_datafusion_repr::plan_nodes::PhysicalScan::from_rel_node(rel_node).unwrap();
-            Some(JoinOrder::Table(scan.table().to_string()))
-        }
-        _ => {
-            for child in &rel_node.children {
-                if let Some(res) = get_join_order(child.clone()) {
-                    return Some(res);
-                }
-            }
-            None
-        }
-    }
-}
-
-impl std::fmt::Display for LogicalJoinOrder {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            LogicalJoinOrder::Table(name) => write!(f, "{}", name),
-            LogicalJoinOrder::Join(left, right) => {
-                write!(f, "(Join {} {})", left, right)
-            }
-        }
-    }
-}
-
-impl std::fmt::Display for JoinOrder {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            JoinOrder::Table(name) => write!(f, "{}", name),
-            JoinOrder::HashJoin(left, right) => {
-                write!(f, "(HashJoin {} {})", left, right)
-            }
-            JoinOrder::NestedLoopJoin(left, right) => {
-                write!(f, "(NLJ {} {})", left, right)
-            }
-        }
-    }
-}
-
 impl OptdQueryPlanner {
     pub fn enable_adaptive(&self) {
         self.optimizer
@@ -247,7 +158,7 @@ impl OptdQueryPlanner {
             }
         }

-        let (_, optimized_rel, meta) = optimizer.cascades_optimize(optd_rel)?;
+        let (group_id, optimized_rel, meta) = optimizer.cascades_optimize(optd_rel)?;

         if let Some(explains) = &mut explains {
             explains.push(StringifiedPlan::new(
@@ -258,52 +169,17 @@ impl OptdQueryPlanner {
                     .unwrap()
                     .explain_to_string(if verbose { Some(&meta) } else { None }),
             ));
-
-            // const ENABLE_JOIN_ORDER: bool = false;
-
-            // if ENABLE_JOIN_ORDER {
-            //     let join_order = get_join_order(optimized_rel.clone());
-            //     explains.push(StringifiedPlan::new(
-            //         PlanType::OptimizedPhysicalPlan {
-            //             optimizer_name: "optd-join-order".to_string(),
-            //         },
-            //         if let Some(join_order) = join_order {
-            //             join_order.to_string()
-            //         } else {
-            //             "None".to_string()
-            //         },
-            //     ));
-            //     let bindings = optimizer
-            //         .optd_cascades_optimizer()
-            //         .get_all_group_bindings(group_id, true);
-            //     let mut join_orders = BTreeSet::new();
-            //     let mut logical_join_orders = BTreeSet::new();
-            //     for binding in bindings {
-            //         if let Some(join_order) = get_join_order(binding) {
-            //             logical_join_orders.insert(join_order.conv_into_logical_join_order());
-            //             join_orders.insert(join_order);
-            //         }
-            //     }
-            //     explains.push(StringifiedPlan::new(
-            //         PlanType::OptimizedPhysicalPlan {
-            //             optimizer_name: "optd-all-join-orders".to_string(),
-            //         },
-            //         join_orders.iter().map(|x| x.to_string()).join("\n"),
-            //     ));
-            //     explains.push(StringifiedPlan::new(
-            //         PlanType::OptimizedPhysicalPlan {
-            //             optimizer_name: "optd-all-logical-join-orders".to_string(),
-            //         },
-            //         logical_join_orders.iter().map(|x| x.to_string()).join("\n"),
-            //     ));
-            // }
+            let join_orders = optimizer
+                .optd_cascades_optimizer()
+                .memo()
+                .enumerate_join_order(group_id);
+            explains.push(StringifiedPlan::new(
+                PlanType::OptimizedPhysicalPlan {
+                    optimizer_name: "optd-all-logical-join-orders".to_string(),
+                },
+                join_orders.iter().map(|x| x.to_string()).join("\n"),
+            ));
         }
-        // println!(
-        //     "{} cost={}",
-        //     get_join_order(optimized_rel.clone()).unwrap(),
-        //     optimizer.optd_optimizer().get_cost_of(group_id)
-        // );
-        // optimizer.dump(Some(group_id));
         ctx.optimizer = Some(&optimizer);
         let physical_plan = ctx.conv_from_optd(optimized_rel, meta).await?;
         if let Some(explains) = &mut explains {

Review thread on the enumerate_join_order call:

Member: Is this optional or always calculated?

Member (Author): Always calculated, as before, unless we can find a better way of passing options through DataFusion SQL...

Member (Author): That's why I didn't close the issue; this is mentioned in the issue.
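The memo_ext.rs module that provides MemoExt::enumerate_join_order is not part of this diff, so as a standalone illustration of the idea (not the actual optd implementation), the toy below walks memo groups and collects every distinct logical join order in the same "(Join a b)" shape the deleted LogicalJoinOrder printed. The group-id representation, the MemoExpr enum, and the lack of cycle handling are all simplifications:

use std::collections::{BTreeSet, HashMap};

// Toy stand-in for a memo: each group holds alternative expressions, either a
// base-table scan or a join of two child groups (all physical join flavors
// collapse to one logical join, as the old conv_into_logical_join_order did).
enum MemoExpr {
    Scan(String),
    Join(usize, usize), // child group ids
}

type Memo = HashMap<usize, Vec<MemoExpr>>;

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
enum LogicalJoinOrder {
    Table(String),
    Join(Box<LogicalJoinOrder>, Box<LogicalJoinOrder>),
}

impl std::fmt::Display for LogicalJoinOrder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            LogicalJoinOrder::Table(name) => write!(f, "{}", name),
            LogicalJoinOrder::Join(l, r) => write!(f, "(Join {} {})", l, r),
        }
    }
}

fn enumerate_join_order(memo: &Memo, group: usize) -> BTreeSet<LogicalJoinOrder> {
    let mut out = BTreeSet::new();
    for expr in &memo[&group] {
        match expr {
            MemoExpr::Scan(t) => {
                out.insert(LogicalJoinOrder::Table(t.clone()));
            }
            MemoExpr::Join(l, r) => {
                // Cross-product of the orders reachable from each child group;
                // the BTreeSet deduplicates and keeps a stable output order.
                for lo in enumerate_join_order(memo, *l) {
                    for ro in enumerate_join_order(memo, *r) {
                        out.insert(LogicalJoinOrder::Join(Box::new(lo.clone()), Box::new(ro)));
                    }
                }
            }
        }
    }
    out
}

fn main() {
    // Groups: 0 = scan t1, 1 = scan t2, 2 = join of groups 0 and 1 with both
    // commuted alternatives, roughly what a Cascades memo holds after exploration.
    let memo: Memo = HashMap::from([
        (0, vec![MemoExpr::Scan("t1".into())]),
        (1, vec![MemoExpr::Scan("t2".into())]),
        (2, vec![MemoExpr::Join(0, 1), MemoExpr::Join(1, 0)]),
    ]);
    for order in enumerate_join_order(&memo, 2) {
        println!("{}", order); // prints "(Join t1 t2)" and "(Join t2 t1)"
    }
}

The real trait presumably also memoizes per-group results and skips non-join operators; the join("\n") in the diff then just stringifies whatever collection it returns for the EXPLAIN output.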
4 changes: 3 additions & 1 deletion optd-datafusion-repr/src/lib.rs
@@ -34,14 +34,16 @@ use crate::rules::{
     DepInitialDistinct, DepJoinEliminateAtScan, DepJoinPastAgg, DepJoinPastFilter, DepJoinPastProj,
 };

+pub use memo_ext::{LogicalJoinOrder, MemoExt};
+
 pub mod cost;
 mod explain;
+mod memo_ext;
 pub mod plan_nodes;
 pub mod properties;
 pub mod rules;
 #[cfg(test)]
 mod testing;
-// mod expand;

 pub struct DatafusionOptimizer {
     heuristic_optimizer: HeuristicsOptimizer<OptRelNodeTyp>,