Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Commit

Permalink
prune based on upper bound
Browse files (browse the repository at this point in the history)
Signed-off-by: Alex Chi Z <[email protected]>
  • Loading branch information
skyzh committed Dec 14, 2024
1 parent: 81a2e80 · commit: 3140324
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 21 deletions.
2 changes: 1 addition & 1 deletion optd-core/src/cascades/memo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ impl<T: NodeType> NaiveMemo<T> {
}

fn verify_integrity(&self) {
if cfg!(debug_assertions) {
if false {
let num_of_exprs = self.expr_id_to_expr_node.len();
assert_eq!(num_of_exprs, self.expr_node_to_expr_id.len());
assert_eq!(num_of_exprs, self.expr_id_to_group_id.len());
Expand Down
2 changes: 1 addition & 1 deletion optd-core/src/cascades/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ impl<T: NodeType, M: Memo<T>> CascadesOptimizer<T, M> {
fn fire_optimize_tasks(&mut self, group_id: GroupId) -> Result<()> {
trace!(event = "fire_optimize_tasks", root_group_id = %group_id);
self.tasks
.push_back(Box::new(OptimizeGroupTask::new(group_id)));
.push_back(Box::new(OptimizeGroupTask::new(group_id, None)));
// get the task from the stack
self.ctx.budget_used = false;
let plan_space_begin = self.memo.estimated_plan_space();
Expand Down
12 changes: 10 additions & 2 deletions optd-core/src/cascades/tasks/apply_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,21 @@ pub struct ApplyRuleTask {
rule_id: RuleId,
expr_id: ExprId,
exploring: bool,
upper_bound: Option<f64>,
}

impl ApplyRuleTask {
pub fn new(rule_id: RuleId, expr_id: ExprId, exploring: bool) -> Self {
pub fn new(
rule_id: RuleId,
expr_id: ExprId,
exploring: bool,
upper_bound: Option<f64>,
) -> Self {
Self {
rule_id,
expr_id,
exploring,
upper_bound,
}
}
}
Expand Down Expand Up @@ -181,13 +188,14 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for ApplyRuleTask {
let typ = expr.unwrap_typ();
if typ.is_logical() {
tasks.push(
Box::new(OptimizeExpressionTask::new(expr_id, self.exploring))
Box::new(OptimizeExpressionTask::new(expr_id, self.exploring, self.upper_bound))
as Box<dyn Task<T, M>>,
);
} else {
tasks.push(Box::new(OptimizeInputsTask::new(
expr_id,
!optimizer.prop.disable_pruning,
self.upper_bound
)) as Box<dyn Task<T, M>>);
}
optimizer.unmark_expr_explored(expr_id);
Expand Down
10 changes: 7 additions & 3 deletions optd-core/src/cascades/tasks/explore_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@ use crate::nodes::NodeType;

pub struct ExploreGroupTask {
group_id: GroupId,
upper_bound: Option<f64>,
}

impl ExploreGroupTask {
pub fn new(group_id: GroupId) -> Self {
Self { group_id }
pub fn new(group_id: GroupId, upper_bound: Option<f64>) -> Self {
Self {
group_id,
upper_bound,
}
}
}

Expand All @@ -36,7 +40,7 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for ExploreGroupTask {
let typ = optimizer.get_expr_memoed(expr).typ.clone();
if typ.is_logical() {
tasks
.push(Box::new(OptimizeExpressionTask::new(expr, true)) as Box<dyn Task<T, M>>);
.push(Box::new(OptimizeExpressionTask::new(expr, true, self.upper_bound)) as Box<dyn Task<T, M>>);
}
}
optimizer.mark_group_explored(self.group_id);
Expand Down
13 changes: 9 additions & 4 deletions optd-core/src/cascades/tasks/optimize_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@ use crate::rules::RuleMatcher;
pub struct OptimizeExpressionTask {
expr_id: ExprId,
exploring: bool,
upper_bound: Option<f64>,
}

impl OptimizeExpressionTask {
pub fn new(expr_id: ExprId, exploring: bool) -> Self {
Self { expr_id, exploring }
pub fn new(expr_id: ExprId, exploring: bool, upper_bound: Option<f64>) -> Self {
Self {
expr_id,
exploring,
upper_bound,
}
}
}

Expand Down Expand Up @@ -53,12 +58,12 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeExpressionTask {
}
if top_matches(rule.matcher(), expr.typ.clone()) {
tasks.push(
Box::new(ApplyRuleTask::new(rule_id, self.expr_id, self.exploring))
Box::new(ApplyRuleTask::new(rule_id, self.expr_id, self.exploring, self.upper_bound))
as Box<dyn Task<T, M>>,
);
for &input_group_id in &expr.children {
tasks.push(
Box::new(ExploreGroupTask::new(input_group_id)) as Box<dyn Task<T, M>>
Box::new(ExploreGroupTask::new(input_group_id, self.upper_bound)) as Box<dyn Task<T, M>>
);
}
}
Expand Down
11 changes: 8 additions & 3 deletions optd-core/src/cascades/tasks/optimize_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@ use crate::nodes::NodeType;

pub struct OptimizeGroupTask {
group_id: GroupId,
upper_bound: Option<f64>,
}

impl OptimizeGroupTask {
pub fn new(group_id: GroupId) -> Self {
Self { group_id }
pub fn new(group_id: GroupId, upper_bound: Option<f64>) -> Self {
Self {
group_id,
upper_bound,
}
}
}

Expand All @@ -37,7 +41,7 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeGroupTask {
for &expr in &exprs {
let typ = optimizer.get_expr_memoed(expr).typ.clone();
if typ.is_logical() {
tasks.push(Box::new(OptimizeExpressionTask::new(expr, false)) as Box<dyn Task<T, M>>);
tasks.push(Box::new(OptimizeExpressionTask::new(expr, false, self.upper_bound)) as Box<dyn Task<T, M>>);
}
}
for &expr in &exprs {
Expand All @@ -46,6 +50,7 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeGroupTask {
tasks.push(Box::new(OptimizeInputsTask::new(
expr,
!optimizer.prop.disable_pruning,
self.upper_bound
)) as Box<dyn Task<T, M>>);
}
}
Expand Down
34 changes: 27 additions & 7 deletions optd-core/src/cascades/tasks/optimize_inputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,25 @@ pub struct OptimizeInputsTask {
expr_id: ExprId,
continue_from: Option<ContinueTask>,
pruning: bool,
upper_bound: Option<f64>,
}

impl OptimizeInputsTask {
pub fn new(expr_id: ExprId, pruning: bool) -> Self {
pub fn new(expr_id: ExprId, pruning: bool, upper_bound: Option<f64>) -> Self {
Self {
expr_id,
continue_from: None,
pruning,
upper_bound,
}
}

fn continue_from(&self, cont: ContinueTask, pruning: bool) -> Self {
fn continue_from(&self, cont: ContinueTask, pruning: bool, upper_bound: Option<f64>) -> Self {
Self {
expr_id: self.expr_id,
continue_from: Some(cont),
pruning,
upper_bound,
}
}

Expand Down Expand Up @@ -153,6 +156,19 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeInputsTask {

trace!(event = "task_begin", task = "optimize_inputs", expr_id = %self.expr_id, continue_from = %ContinueTaskDisplay(&self.continue_from), total_children = %children_group_ids.len());

let upper_bound = if self.pruning {
if let Some(upper_bound) = self.upper_bound {
Some(upper_bound)
} else if let Some(winner) = optimizer.get_group_info(group_id).winner.as_full_winner()
{
Some(winner.total_weighted_cost)
} else {
None
}
} else {
None
};

if let Some(ContinueTask {
next_group_idx,
return_from_optimize_group,
Expand Down Expand Up @@ -219,9 +235,9 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeInputsTask {
winner_weighted_cost = %trace_fmt(&group_info.winner),
current_processing = %next_group_idx,
total_child_groups = %children_group_ids.len());
if let Some(winner) = group_info.winner.as_full_winner() {
if let Some(upper_bound) = upper_bound {
let cost_so_far = cost.weighted_cost(&total_cost);
if winner.total_weighted_cost <= cost_so_far {
if upper_bound <= cost_so_far {
trace!(event = "task_finish", task = "optimize_inputs", expr_id = %self.expr_id, result = "pruned");
return Ok(vec![]);
}
Expand All @@ -232,7 +248,7 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeInputsTask {
let child_group_id = children_group_ids[next_group_idx];
let group_idx = next_group_idx;
let child_group_info = optimizer.get_group_info(child_group_id);
if !child_group_info.winner.has_full_winner() {
let Some(child_winner) = child_group_info.winner.as_full_winner() else {
if !return_from_optimize_group {
trace!(event = "task_yield", task = "optimize_inputs", expr_id = %self.expr_id, group_idx = %group_idx, yield_to = "optimize_group", optimize_group_id = %child_group_id);
return Ok(vec![
Expand All @@ -242,22 +258,25 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeInputsTask {
return_from_optimize_group: true,
},
self.pruning,
upper_bound,
)) as Box<dyn Task<T, M>>,
Box::new(OptimizeGroupTask::new(child_group_id)) as Box<dyn Task<T, M>>,
Box::new(OptimizeGroupTask::new(child_group_id, upper_bound))
as Box<dyn Task<T, M>>,
]);
} else {
self.update_winner_impossible(optimizer);
trace!(event = "task_finish", task = "optimize_inputs", expr_id = %self.expr_id, result = "impossible");
return Ok(vec![]);
}
}
};
trace!(event = "task_yield", task = "optimize_inputs", expr_id = %self.expr_id, group_idx = %group_idx, yield_to = "next_optimize_input");
Ok(vec![Box::new(self.continue_from(
ContinueTask {
next_group_idx: group_idx + 1,
return_from_optimize_group: false,
},
self.pruning,
upper_bound.map(|bound| bound - child_winner.total_weighted_cost),
)) as Box<dyn Task<T, M>>])
} else {
self.update_winner(input_statistics_ref, operation_cost, total_cost, optimizer);
Expand All @@ -272,6 +291,7 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeInputsTask {
return_from_optimize_group: false,
},
self.pruning,
upper_bound,
)) as Box<dyn Task<T, M>>])
}
}
Expand Down

0 comments on commit 3140324

Please sign in to comment.