Skip to content

Commit

Permalink
Initial DSL implementation with grammar, semantic analysis, and codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
AlSchlo committed Feb 11, 2025
1 parent c41cf41 commit fbdeccf
Show file tree
Hide file tree
Showing 132 changed files with 8,265 additions and 607 deletions.
1,925 changes: 1,912 additions & 13 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[workspace]
members = ["optd-core"]
members = [ "optd-core", "optd-dsl"]
resolver = "2"
4 changes: 2 additions & 2 deletions docs/src/architecture/glossary.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ See the following sections for more information.

A logical expression is a version of a [Relational Expression].

TODO(connor) Add more details.
TODO(connor): Add more details.

Examples of logical expressions include Logical Scan, Logical Join, or Logical Sort expressions
(which can just be shorthanded to Scan, Join, or Sort).
Expand All @@ -135,7 +135,7 @@ Examples of logical expressions include Logical Scan, Logical Join, or Logical S

A physical expression is a version of a [Relational Expression].

TODO(connor) Add more details.
TODO(connor): Add more details.

Examples of physical expressions include Table Scan, Index Scan, Hash Join, or Sort Merge Join.

Expand Down
7 changes: 7 additions & 0 deletions optd-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@ version = "0.1.0"
edition = "2021"

[dependencies]
sqlx = { version = "0.8", features = [ "sqlite", "runtime-tokio", "migrate" ] }
trait-variant = "0.1.2"

# Pin more recent versions for `-Zminimal-versions`.
proc-macro2 = "1.0.60" # For a missing feature (https://github.com/rust-lang/rust/issues/113152).
anyhow = "1.0.95"
tokio = { version = "1.43.0", features = ["full"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = { version = "1", features = ["raw_value"] }
dotenvy = "0.15"
async-recursion = "1.1.1"
35 changes: 35 additions & 0 deletions optd-core/src/cascades/expressions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//! Types for logical and physical expressions in the optimizer.
use crate::operators::relational::physical::PhysicalOperator;
use crate::operators::scalar::ScalarOperator;
use crate::{operators::relational::logical::LogicalOperator, values::OptdValue};
use serde::Deserialize;

use super::groups::{RelationalGroupId, ScalarGroupId};

/// A logical expression in the memo table.
pub type LogicalExpression = LogicalOperator<OptdValue, RelationalGroupId, ScalarGroupId>;

/// A physical expression in the memo table.
pub type PhysicalExpression = PhysicalOperator<OptdValue, RelationalGroupId, ScalarGroupId>;

/// A scalar expression in the memo table.
pub type ScalarExpression = ScalarOperator<OptdValue, ScalarGroupId>;

/// A unique identifier for a logical expression in the memo table.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Deserialize)]
#[sqlx(transparent)]
pub struct LogicalExpressionId(pub i64);

/// A unique identifier for a physical expression in the memo table.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Deserialize)]
#[sqlx(transparent)]
pub struct PhysicalExpressionId(pub i64);

/// A unique identifier for a scalar expression in the memo table.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Deserialize)]
#[sqlx(transparent)]
pub struct ScalarExpressionId(pub i64);
49 changes: 49 additions & 0 deletions optd-core/src/cascades/groups.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use serde::Deserialize;

/// A unique identifier for a group of relational expressions in the memo table.
#[repr(transparent)]
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Eq,
PartialOrd,
Ord,
Hash,
sqlx::Type,
serde::Serialize,
Deserialize,
)]
#[sqlx(transparent)]
pub struct RelationalGroupId(pub i64);

/// A unique identifier for a group of scalar expressions in the memo table.
#[repr(transparent)]
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Eq,
PartialOrd,
Ord,
Hash,
sqlx::Type,
serde::Serialize,
Deserialize,
)]
#[sqlx(transparent)]
pub struct ScalarGroupId(pub i64);

/// The exploration status of a group or a logical expression in the memo table.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type)]
#[repr(i32)]
pub enum ExplorationStatus {
/// The group or the logical expression has not been explored.
Unexplored,
/// The group or the logical expression is currently being explored.
Exploring,
/// The group or the logical expression has been explored.
Explored,
}
70 changes: 70 additions & 0 deletions optd-core/src/cascades/memo.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
//! Memo table interface for query optimization.
//!
//! The memo table is a core data structure that stores expressions and their logical equivalences
//! during query optimization. It serves two main purposes:
//!
//! - Avoiding redundant optimization by memoizing already explored expressions
//! - Grouping logically equivalent expressions together to enable rule-based optimization
//!
use std::sync::Arc;

use super::{
expressions::{LogicalExpression, LogicalExpressionId, ScalarExpression, ScalarExpressionId},
groups::{RelationalGroupId, ScalarGroupId},
};
use anyhow::Result;

#[trait_variant::make(Send)]
pub trait Memoize: Send + Sync + 'static {
/// Gets all logical expressions in a group.
async fn get_all_logical_exprs_in_group(
&self,
group_id: RelationalGroupId,
) -> Result<Vec<(LogicalExpressionId, Arc<LogicalExpression>)>>;

/// Adds a logical expression to an existing group.
/// Returns the group id of new group if merge happened.
async fn add_logical_expr_to_group(
&self,
logical_expr: &LogicalExpression,
group_id: RelationalGroupId,
) -> Result<RelationalGroupId>;

/// Adds a logical expression to the memo table.
/// Returns the group id of group if already exists, otherwise creates a new group.
async fn add_logical_expr(&self, logical_expr: &LogicalExpression)
-> Result<RelationalGroupId>;

/// Gets all scalar expressions in a group.
async fn get_all_scalar_exprs_in_group(
&self,
group_id: ScalarGroupId,
) -> Result<Vec<(ScalarExpressionId, Arc<ScalarExpression>)>>;

/// Adds a scalar expression to an existing group.
/// Returns the group id of new group if merge happened.
async fn add_scalar_expr_to_group(
&self,
scalar_expr: &ScalarExpression,
group_id: ScalarGroupId,
) -> Result<ScalarGroupId>;

/// Adds a scalar expression to the memo table.
/// Returns the group id of group if already exists, otherwise creates a new group.
async fn add_scalar_expr(&self, scalar_expr: &ScalarExpression) -> Result<ScalarGroupId>;

/// Merges two relational groups and returns the new group id.
async fn merge_relation_group(
&self,
from: RelationalGroupId,
to: RelationalGroupId,
) -> Result<RelationalGroupId>;

/// Merges two scalar groups and returns the new group id.
async fn merge_scalar_group(
&self,
from: ScalarGroupId,
to: ScalarGroupId,
) -> Result<ScalarGroupId>;
}
172 changes: 172 additions & 0 deletions optd-core/src/cascades/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
use std::sync::Arc;

use async_recursion::async_recursion;
use expressions::{LogicalExpression, ScalarExpression};
use groups::{RelationalGroupId, ScalarGroupId};
use memo::Memoize;

use crate::{
operators::{
relational::logical::{filter::Filter, join::Join, scan::Scan, LogicalOperator},
scalar::{add::Add, equal::Equal, ScalarOperator},
},
plans::{logical::PartialLogicalPlan, scalar::PartialScalarPlan},
};

pub mod expressions;
pub mod groups;
pub mod memo;

#[async_recursion]
pub async fn ingest_partial_logical_plan(
memo: &impl Memoize,
partial_logical_plan: &PartialLogicalPlan,
) -> anyhow::Result<RelationalGroupId> {
match partial_logical_plan {
PartialLogicalPlan::PartialMaterialized { operator } => {
let mut children_relations = Vec::new();
for child in operator.children_relations().iter() {
children_relations.push(ingest_partial_logical_plan(memo, child).await?);
}

let mut children_scalars = Vec::new();
for child in operator.children_scalars().iter() {
children_scalars.push(ingest_partial_scalar_plan(memo, child).await?);
}

memo.add_logical_expr(&operator.into_expr(&children_relations, &children_scalars))
.await
}

PartialLogicalPlan::UnMaterialized(group_id) => Ok(*group_id),
}
}

#[async_recursion]
pub async fn ingest_partial_scalar_plan(
memo: &impl Memoize,
partial_scalar_plan: &PartialScalarPlan,
) -> anyhow::Result<ScalarGroupId> {
match partial_scalar_plan {
PartialScalarPlan::PartialMaterialized { operator } => {
let mut children = Vec::new();
for child in operator.children_scalars().iter() {
children.push(ingest_partial_scalar_plan(memo, child).await?);
}

memo.add_scalar_expr(&operator.into_expr(&children)).await
}

PartialScalarPlan::UnMaterialized(group_id) => {
return Ok(*group_id);
}
}
}

#[async_recursion]
async fn match_any_partial_logical_plan(
memo: &impl Memoize,
group: RelationalGroupId,
) -> anyhow::Result<Arc<PartialLogicalPlan>> {
let logical_exprs = memo.get_all_logical_exprs_in_group(group).await?;
let last_logical_expr = logical_exprs.last().unwrap().1.clone();

match last_logical_expr.as_ref() {
LogicalExpression::Scan(scan) => {
let predicate = match_any_partial_scalar_plan(memo, scan.predicate).await?;
Ok(Arc::new(PartialLogicalPlan::PartialMaterialized {
operator: LogicalOperator::Scan(Scan {
predicate,
table_name: scan.table_name.clone(),
}),
}))
}
LogicalExpression::Filter(filter) => {
let child = match_any_partial_logical_plan(memo, filter.child).await?;
let predicate = match_any_partial_scalar_plan(memo, filter.predicate).await?;
Ok(Arc::new(PartialLogicalPlan::PartialMaterialized {
operator: LogicalOperator::Filter(Filter { child, predicate }),
}))
}
LogicalExpression::Join(join) => {
let left = match_any_partial_logical_plan(memo, join.left).await?;
let right = match_any_partial_logical_plan(memo, join.right).await?;
let condition = match_any_partial_scalar_plan(memo, join.condition).await?;
Ok(Arc::new(PartialLogicalPlan::PartialMaterialized {
operator: LogicalOperator::Join(Join {
left,
right,
condition,
join_type: join.join_type.clone(),
}),
}))
}
}
}

#[async_recursion]
async fn match_any_partial_scalar_plan(
memo: &impl Memoize,
group: ScalarGroupId,
) -> anyhow::Result<Arc<PartialScalarPlan>> {
let scalar_exprs = memo.get_all_scalar_exprs_in_group(group).await?;
let last_scalar_expr = scalar_exprs.last().unwrap().1.clone();
match last_scalar_expr.as_ref() {
ScalarExpression::Constant(constant) => {
Ok(Arc::new(PartialScalarPlan::PartialMaterialized {
operator: ScalarOperator::Constant(constant.clone()),
}))
}
ScalarExpression::ColumnRef(column_ref) => {
Ok(Arc::new(PartialScalarPlan::PartialMaterialized {
operator: ScalarOperator::ColumnRef(column_ref.clone()),
}))
}
ScalarExpression::Add(add) => {
let left = match_any_partial_scalar_plan(memo, add.left).await?;
let right = match_any_partial_scalar_plan(memo, add.right).await?;
Ok(Arc::new(PartialScalarPlan::PartialMaterialized {
operator: ScalarOperator::Add(Add { left, right }),
}))
}
ScalarExpression::Equal(equal) => {
let left = match_any_partial_scalar_plan(memo, equal.left).await?;
let right = match_any_partial_scalar_plan(memo, equal.right).await?;
Ok(Arc::new(PartialScalarPlan::PartialMaterialized {
operator: ScalarOperator::Equal(Equal { left, right }),
}))
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{storage::memo::SqliteMemo, test_utils::*};
use anyhow::Ok;

#[tokio::test]
async fn test_ingest_partial_logical_plan() -> anyhow::Result<()> {
let memo = SqliteMemo::new_in_memory().await?;
// select * from t1, t2 where t1.id = t2.id and t2.name = 'Memo' and t2.v1 = 1 + 1
let partial_logical_plan = filter(
join(
"inner",
scan("t1", boolean(true)),
scan("t2", equal(column_ref(1), add(int64(1), int64(1)))),
equal(column_ref(1), column_ref(2)),
),
equal(column_ref(2), string("Memo")),
);

let group_id = ingest_partial_logical_plan(&memo, &partial_logical_plan).await?;
let group_id_2 = ingest_partial_logical_plan(&memo, &partial_logical_plan).await?;
assert_eq!(group_id, group_id_2);

// The plan should be the same, there is only one expression per group.
let result: Arc<PartialLogicalPlan> =
match_any_partial_logical_plan(&memo, group_id).await?;
assert_eq!(result, partial_logical_plan);
Ok(())
}
}
Loading

0 comments on commit fbdeccf

Please sign in to comment.