-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial DSL implementation with grammar, semantic analysis, and codegen
- Loading branch information
Showing
132 changed files
with
8,265 additions
and
607 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
[workspace] | ||
members = ["optd-core"] | ||
members = [ "optd-core", "optd-dsl"] | ||
resolver = "2" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
//! Types for logical and physical expressions in the optimizer. | ||
use crate::operators::relational::physical::PhysicalOperator; | ||
use crate::operators::scalar::ScalarOperator; | ||
use crate::{operators::relational::logical::LogicalOperator, values::OptdValue}; | ||
use serde::Deserialize; | ||
|
||
use super::groups::{RelationalGroupId, ScalarGroupId}; | ||
|
||
/// A logical expression in the memo table. | ||
pub type LogicalExpression = LogicalOperator<OptdValue, RelationalGroupId, ScalarGroupId>; | ||
|
||
/// A physical expression in the memo table. | ||
pub type PhysicalExpression = PhysicalOperator<OptdValue, RelationalGroupId, ScalarGroupId>; | ||
|
||
/// A scalar expression in the memo table. | ||
pub type ScalarExpression = ScalarOperator<OptdValue, ScalarGroupId>; | ||
|
||
/// A unique identifier for a logical expression in the memo table. | ||
#[repr(transparent)] | ||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Deserialize)] | ||
#[sqlx(transparent)] | ||
pub struct LogicalExpressionId(pub i64); | ||
|
||
/// A unique identifier for a physical expression in the memo table. | ||
#[repr(transparent)] | ||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Deserialize)] | ||
#[sqlx(transparent)] | ||
pub struct PhysicalExpressionId(pub i64); | ||
|
||
/// A unique identifier for a scalar expression in the memo table. | ||
#[repr(transparent)] | ||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Deserialize)] | ||
#[sqlx(transparent)] | ||
pub struct ScalarExpressionId(pub i64); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
use serde::Deserialize; | ||
|
||
/// A unique identifier for a group of relational expressions in the memo table. | ||
#[repr(transparent)] | ||
#[derive( | ||
Debug, | ||
Clone, | ||
Copy, | ||
PartialEq, | ||
Eq, | ||
PartialOrd, | ||
Ord, | ||
Hash, | ||
sqlx::Type, | ||
serde::Serialize, | ||
Deserialize, | ||
)] | ||
#[sqlx(transparent)] | ||
pub struct RelationalGroupId(pub i64); | ||
|
||
/// A unique identifier for a group of scalar expressions in the memo table. | ||
#[repr(transparent)] | ||
#[derive( | ||
Debug, | ||
Clone, | ||
Copy, | ||
PartialEq, | ||
Eq, | ||
PartialOrd, | ||
Ord, | ||
Hash, | ||
sqlx::Type, | ||
serde::Serialize, | ||
Deserialize, | ||
)] | ||
#[sqlx(transparent)] | ||
pub struct ScalarGroupId(pub i64); | ||
|
||
/// The exploration status of a group or a logical expression in the memo table. | ||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type)] | ||
#[repr(i32)] | ||
pub enum ExplorationStatus { | ||
/// The group or the logical expression has not been explored. | ||
Unexplored, | ||
/// The group or the logical expression is currently being explored. | ||
Exploring, | ||
/// The group or the logical expression has been explored. | ||
Explored, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
//! Memo table interface for query optimization. | ||
//! | ||
//! The memo table is a core data structure that stores expressions and their logical equivalences | ||
//! during query optimization. It serves two main purposes: | ||
//! | ||
//! - Avoiding redundant optimization by memoizing already explored expressions | ||
//! - Grouping logically equivalent expressions together to enable rule-based optimization | ||
//! | ||
use std::sync::Arc; | ||
|
||
use super::{ | ||
expressions::{LogicalExpression, LogicalExpressionId, ScalarExpression, ScalarExpressionId}, | ||
groups::{RelationalGroupId, ScalarGroupId}, | ||
}; | ||
use anyhow::Result; | ||
|
||
#[trait_variant::make(Send)] | ||
pub trait Memoize: Send + Sync + 'static { | ||
/// Gets all logical expressions in a group. | ||
async fn get_all_logical_exprs_in_group( | ||
&self, | ||
group_id: RelationalGroupId, | ||
) -> Result<Vec<(LogicalExpressionId, Arc<LogicalExpression>)>>; | ||
|
||
/// Adds a logical expression to an existing group. | ||
/// Returns the group id of new group if merge happened. | ||
async fn add_logical_expr_to_group( | ||
&self, | ||
logical_expr: &LogicalExpression, | ||
group_id: RelationalGroupId, | ||
) -> Result<RelationalGroupId>; | ||
|
||
/// Adds a logical expression to the memo table. | ||
/// Returns the group id of group if already exists, otherwise creates a new group. | ||
async fn add_logical_expr(&self, logical_expr: &LogicalExpression) | ||
-> Result<RelationalGroupId>; | ||
|
||
/// Gets all scalar expressions in a group. | ||
async fn get_all_scalar_exprs_in_group( | ||
&self, | ||
group_id: ScalarGroupId, | ||
) -> Result<Vec<(ScalarExpressionId, Arc<ScalarExpression>)>>; | ||
|
||
/// Adds a scalar expression to an existing group. | ||
/// Returns the group id of new group if merge happened. | ||
async fn add_scalar_expr_to_group( | ||
&self, | ||
scalar_expr: &ScalarExpression, | ||
group_id: ScalarGroupId, | ||
) -> Result<ScalarGroupId>; | ||
|
||
/// Adds a scalar expression to the memo table. | ||
/// Returns the group id of group if already exists, otherwise creates a new group. | ||
async fn add_scalar_expr(&self, scalar_expr: &ScalarExpression) -> Result<ScalarGroupId>; | ||
|
||
/// Merges two relational groups and returns the new group id. | ||
async fn merge_relation_group( | ||
&self, | ||
from: RelationalGroupId, | ||
to: RelationalGroupId, | ||
) -> Result<RelationalGroupId>; | ||
|
||
/// Merges two scalar groups and returns the new group id. | ||
async fn merge_scalar_group( | ||
&self, | ||
from: ScalarGroupId, | ||
to: ScalarGroupId, | ||
) -> Result<ScalarGroupId>; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
use std::sync::Arc; | ||
|
||
use async_recursion::async_recursion; | ||
use expressions::{LogicalExpression, ScalarExpression}; | ||
use groups::{RelationalGroupId, ScalarGroupId}; | ||
use memo::Memoize; | ||
|
||
use crate::{ | ||
operators::{ | ||
relational::logical::{filter::Filter, join::Join, scan::Scan, LogicalOperator}, | ||
scalar::{add::Add, equal::Equal, ScalarOperator}, | ||
}, | ||
plans::{logical::PartialLogicalPlan, scalar::PartialScalarPlan}, | ||
}; | ||
|
||
pub mod expressions; | ||
pub mod groups; | ||
pub mod memo; | ||
|
||
#[async_recursion] | ||
pub async fn ingest_partial_logical_plan( | ||
memo: &impl Memoize, | ||
partial_logical_plan: &PartialLogicalPlan, | ||
) -> anyhow::Result<RelationalGroupId> { | ||
match partial_logical_plan { | ||
PartialLogicalPlan::PartialMaterialized { operator } => { | ||
let mut children_relations = Vec::new(); | ||
for child in operator.children_relations().iter() { | ||
children_relations.push(ingest_partial_logical_plan(memo, child).await?); | ||
} | ||
|
||
let mut children_scalars = Vec::new(); | ||
for child in operator.children_scalars().iter() { | ||
children_scalars.push(ingest_partial_scalar_plan(memo, child).await?); | ||
} | ||
|
||
memo.add_logical_expr(&operator.into_expr(&children_relations, &children_scalars)) | ||
.await | ||
} | ||
|
||
PartialLogicalPlan::UnMaterialized(group_id) => Ok(*group_id), | ||
} | ||
} | ||
|
||
#[async_recursion] | ||
pub async fn ingest_partial_scalar_plan( | ||
memo: &impl Memoize, | ||
partial_scalar_plan: &PartialScalarPlan, | ||
) -> anyhow::Result<ScalarGroupId> { | ||
match partial_scalar_plan { | ||
PartialScalarPlan::PartialMaterialized { operator } => { | ||
let mut children = Vec::new(); | ||
for child in operator.children_scalars().iter() { | ||
children.push(ingest_partial_scalar_plan(memo, child).await?); | ||
} | ||
|
||
memo.add_scalar_expr(&operator.into_expr(&children)).await | ||
} | ||
|
||
PartialScalarPlan::UnMaterialized(group_id) => { | ||
return Ok(*group_id); | ||
} | ||
} | ||
} | ||
|
||
#[async_recursion] | ||
async fn match_any_partial_logical_plan( | ||
memo: &impl Memoize, | ||
group: RelationalGroupId, | ||
) -> anyhow::Result<Arc<PartialLogicalPlan>> { | ||
let logical_exprs = memo.get_all_logical_exprs_in_group(group).await?; | ||
let last_logical_expr = logical_exprs.last().unwrap().1.clone(); | ||
|
||
match last_logical_expr.as_ref() { | ||
LogicalExpression::Scan(scan) => { | ||
let predicate = match_any_partial_scalar_plan(memo, scan.predicate).await?; | ||
Ok(Arc::new(PartialLogicalPlan::PartialMaterialized { | ||
operator: LogicalOperator::Scan(Scan { | ||
predicate, | ||
table_name: scan.table_name.clone(), | ||
}), | ||
})) | ||
} | ||
LogicalExpression::Filter(filter) => { | ||
let child = match_any_partial_logical_plan(memo, filter.child).await?; | ||
let predicate = match_any_partial_scalar_plan(memo, filter.predicate).await?; | ||
Ok(Arc::new(PartialLogicalPlan::PartialMaterialized { | ||
operator: LogicalOperator::Filter(Filter { child, predicate }), | ||
})) | ||
} | ||
LogicalExpression::Join(join) => { | ||
let left = match_any_partial_logical_plan(memo, join.left).await?; | ||
let right = match_any_partial_logical_plan(memo, join.right).await?; | ||
let condition = match_any_partial_scalar_plan(memo, join.condition).await?; | ||
Ok(Arc::new(PartialLogicalPlan::PartialMaterialized { | ||
operator: LogicalOperator::Join(Join { | ||
left, | ||
right, | ||
condition, | ||
join_type: join.join_type.clone(), | ||
}), | ||
})) | ||
} | ||
} | ||
} | ||
|
||
#[async_recursion] | ||
async fn match_any_partial_scalar_plan( | ||
memo: &impl Memoize, | ||
group: ScalarGroupId, | ||
) -> anyhow::Result<Arc<PartialScalarPlan>> { | ||
let scalar_exprs = memo.get_all_scalar_exprs_in_group(group).await?; | ||
let last_scalar_expr = scalar_exprs.last().unwrap().1.clone(); | ||
match last_scalar_expr.as_ref() { | ||
ScalarExpression::Constant(constant) => { | ||
Ok(Arc::new(PartialScalarPlan::PartialMaterialized { | ||
operator: ScalarOperator::Constant(constant.clone()), | ||
})) | ||
} | ||
ScalarExpression::ColumnRef(column_ref) => { | ||
Ok(Arc::new(PartialScalarPlan::PartialMaterialized { | ||
operator: ScalarOperator::ColumnRef(column_ref.clone()), | ||
})) | ||
} | ||
ScalarExpression::Add(add) => { | ||
let left = match_any_partial_scalar_plan(memo, add.left).await?; | ||
let right = match_any_partial_scalar_plan(memo, add.right).await?; | ||
Ok(Arc::new(PartialScalarPlan::PartialMaterialized { | ||
operator: ScalarOperator::Add(Add { left, right }), | ||
})) | ||
} | ||
ScalarExpression::Equal(equal) => { | ||
let left = match_any_partial_scalar_plan(memo, equal.left).await?; | ||
let right = match_any_partial_scalar_plan(memo, equal.right).await?; | ||
Ok(Arc::new(PartialScalarPlan::PartialMaterialized { | ||
operator: ScalarOperator::Equal(Equal { left, right }), | ||
})) | ||
} | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
use crate::{storage::memo::SqliteMemo, test_utils::*}; | ||
use anyhow::Ok; | ||
|
||
#[tokio::test] | ||
async fn test_ingest_partial_logical_plan() -> anyhow::Result<()> { | ||
let memo = SqliteMemo::new_in_memory().await?; | ||
// select * from t1, t2 where t1.id = t2.id and t2.name = 'Memo' and t2.v1 = 1 + 1 | ||
let partial_logical_plan = filter( | ||
join( | ||
"inner", | ||
scan("t1", boolean(true)), | ||
scan("t2", equal(column_ref(1), add(int64(1), int64(1)))), | ||
equal(column_ref(1), column_ref(2)), | ||
), | ||
equal(column_ref(2), string("Memo")), | ||
); | ||
|
||
let group_id = ingest_partial_logical_plan(&memo, &partial_logical_plan).await?; | ||
let group_id_2 = ingest_partial_logical_plan(&memo, &partial_logical_plan).await?; | ||
assert_eq!(group_id, group_id_2); | ||
|
||
// The plan should be the same, there is only one expression per group. | ||
let result: Arc<PartialLogicalPlan> = | ||
match_any_partial_logical_plan(&memo, group_id).await?; | ||
assert_eq!(result, partial_logical_plan); | ||
Ok(()) | ||
} | ||
} |
Oops, something went wrong.