Skip to content

Commit

Permalink
Merge with main
Browse files Browse the repository at this point in the history
  • Loading branch information
AlSchlo committed Feb 10, 2025
2 parents 3173aed + a94829a commit 3d87b42
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 68 deletions.
18 changes: 12 additions & 6 deletions optd-core/src/cascades/memo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,45 +17,51 @@ use anyhow::Result;

#[trait_variant::make(Send)]
pub trait Memoize: Send + Sync + 'static {
/// Gets all logical expressions in a group.
async fn get_all_logical_exprs_in_group(
&self,
group_id: RelationalGroupId,
) -> Result<Vec<(LogicalExpressionId, Arc<LogicalExpression>)>>;

// Returns the group id of new group if merge happened.
/// Adds a logical expression to an existing group.
/// Returns the group id of new group if merge happened.
async fn add_logical_expr_to_group(
&self,
logical_expr: &LogicalExpression,
group_id: RelationalGroupId,
) -> Result<RelationalGroupId>;

// Returns the group id of group if already exists, otherwise creates a new group.
/// Adds a logical expression to the memo table.
/// Returns the group id of group if already exists, otherwise creates a new group.
async fn add_logical_expr(&self, logical_expr: &LogicalExpression)
-> Result<RelationalGroupId>;

/// Gets all scalar expressions in a group.
async fn get_all_scalar_exprs_in_group(
&self,
group_id: ScalarGroupId,
) -> Result<Vec<(ScalarExpressionId, Arc<ScalarExpression>)>>;

// Returns the group id of new group if merge happened.
/// Adds a scalar expression to an existing group.
/// Returns the group id of new group if merge happened.
async fn add_scalar_expr_to_group(
&self,
scalar_expr: &ScalarExpression,
group_id: ScalarGroupId,
) -> Result<ScalarGroupId>;

// Returns the group id of group if already exists, otherwise creates a new group.
/// Adds a scalar expression to the memo table.
/// Returns the group id of group if already exists, otherwise creates a new group.
async fn add_scalar_expr(&self, scalar_expr: &ScalarExpression) -> Result<ScalarGroupId>;

// Merges two relational groups and returns the new group id.
/// Merges two relational groups and returns the new group id.
async fn merge_relation_group(
&self,
from: RelationalGroupId,
to: RelationalGroupId,
) -> Result<RelationalGroupId>;

// Merges two scalar groups and returns the new group id.
/// Merges two scalar groups and returns the new group id.
async fn merge_scalar_group(
&self,
from: ScalarGroupId,
Expand Down
83 changes: 49 additions & 34 deletions optd-core/src/storage/memo.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//! An implementation of the memo table using SQLite.
use std::{str::FromStr, sync::Arc, time::Duration};

use super::transaction::Transaction;
Expand Down Expand Up @@ -64,7 +66,7 @@ impl SqliteMemo {
/// Begin a new transaction.
pub(super) async fn begin(&self) -> anyhow::Result<Transaction<'_>> {
let txn = self.db.begin().await?;
Ok(Transaction::new(txn).await?)
Transaction::new(txn).await
}
}

Expand All @@ -80,9 +82,7 @@ impl Memoize for SqliteMemo {
}

let mut txn = self.begin().await?;
let representative_group_id = self
.get_representative_group_id(&mut *txn, group_id)
.await?;
let representative_group_id = self.get_representative_group_id(&mut txn, group_id).await?;
let logical_exprs: Vec<LogicalExprRecord> =
sqlx::query_as(&self.get_all_logical_exprs_in_group_query)
.bind(representative_group_id)
Expand Down Expand Up @@ -129,7 +129,7 @@ impl Memoize for SqliteMemo {

let mut txn = self.begin().await?;
let representative_group_id = self
.get_representative_scalar_group_id(&mut *txn, group_id)
.get_representative_scalar_group_id(&mut txn, group_id)
.await?;
let scalar_exprs: Vec<ScalarExprRecord> =
sqlx::query_as(&self.get_all_scalar_exprs_in_group_query)
Expand Down Expand Up @@ -168,8 +168,7 @@ impl Memoize for SqliteMemo {
to: RelationalGroupId,
) -> Result<RelationalGroupId> {
let mut txn = self.begin().await?;
self.set_representative_group_id(&mut *txn, from, to)
.await?;
self.set_representative_group_id(&mut txn, from, to).await?;
txn.commit().await?;
Ok(to)
}
Expand All @@ -180,15 +179,16 @@ impl Memoize for SqliteMemo {
to: ScalarGroupId,
) -> Result<ScalarGroupId> {
let mut txn = self.begin().await?;
self.set_representative_scalar_group_id(&mut *txn, from, to)
self.set_representative_scalar_group_id(&mut txn, from, to)
.await?;
txn.commit().await?;
Ok(to)
}
}

// Memoize helpers
// Helper functions for implementing the `Memoize` trait.
impl SqliteMemo {
/// Gets the representative group id of a relational group.
async fn get_representative_group_id(
&self,
db: &mut SqliteConnection,
Expand All @@ -202,6 +202,7 @@ impl SqliteMemo {
Ok(representative_group_id)
}

/// Sets the representative group id of a relational group.
async fn set_representative_group_id(
&self,
db: &mut SqliteConnection,
Expand All @@ -216,6 +217,7 @@ impl SqliteMemo {
Ok(())
}

/// Gets the representative group id of a scalar group.
async fn get_representative_scalar_group_id(
&self,
db: &mut SqliteConnection,
Expand All @@ -229,6 +231,7 @@ impl SqliteMemo {
Ok(representative_group_id)
}

/// Sets the representative group id of a scalar group.
async fn set_representative_scalar_group_id(
&self,
db: &mut SqliteConnection,
Expand All @@ -243,6 +246,10 @@ impl SqliteMemo {
Ok(())
}

/// Inserts a scalar expression into the database. If the `add_to_group_id` is `Some`,
/// we will attempt to add the scalar expression to the specified group.
/// If the scalar expression already exists in the database, the existing group id will be returned.
/// Otherwise, a new group id will be created.
async fn add_scalar_expr_to_group_inner(
&self,
scalar_expr: &ScalarExpression,
Expand Down Expand Up @@ -275,13 +282,13 @@ impl SqliteMemo {
ScalarOperatorKind::Constant,
)
.await?;
let group_id = sqlx::query_scalar("INSERT INTO scalar_constants (scalar_expression_id, group_id, value) VALUES ($1, $2, $3) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")

sqlx::query_scalar("INSERT INTO scalar_constants (scalar_expression_id, group_id, value) VALUES ($1, $2, $3) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")
.bind(scalar_expr_id)
.bind(group_id)
.bind(serde_json::to_string(&constant)?)
.fetch_one(&mut *txn)
.await?;
group_id
.await?
}
ScalarExpression::ColumnRef(column_ref) => {
Self::insert_into_scalar_expressions(
Expand All @@ -291,13 +298,13 @@ impl SqliteMemo {
ScalarOperatorKind::ColumnRef,
)
.await?;
let group_id = sqlx::query_scalar("INSERT INTO scalar_column_refs (scalar_expression_id, group_id, column_index) VALUES ($1, $2, $3) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")

sqlx::query_scalar("INSERT INTO scalar_column_refs (scalar_expression_id, group_id, column_index) VALUES ($1, $2, $3) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")
.bind(scalar_expr_id)
.bind(group_id)
.bind(serde_json::to_string(&column_ref.column_index)?)
.fetch_one(&mut *txn)
.await?;
group_id
.await?
}
ScalarExpression::Add(add) => {
Self::insert_into_scalar_expressions(
Expand All @@ -307,17 +314,14 @@ impl SqliteMemo {
ScalarOperatorKind::Add,
)
.await?;
// println!("add: {:?}", add);
// println!("scalar_expr_id: {:?}", scalar_expr_id);
// println!("group_id: {:?}", group_id);
let group_id = sqlx::query_scalar("INSERT INTO scalar_adds (scalar_expression_id, group_id, left_group_id, right_group_id) VALUES ($1, $2, $3, $4) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")

sqlx::query_scalar("INSERT INTO scalar_adds (scalar_expression_id, group_id, left_group_id, right_group_id) VALUES ($1, $2, $3, $4) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")
.bind(scalar_expr_id)
.bind(group_id)
.bind(add.left)
.bind(add.right)
.fetch_one(&mut *txn)
.await?;
group_id
.await?
}
ScalarExpression::Equal(equal) => {
Self::insert_into_scalar_expressions(
Expand All @@ -327,14 +331,14 @@ impl SqliteMemo {
ScalarOperatorKind::Equal,
)
.await?;
let group_id = sqlx::query_scalar("INSERT INTO scalar_equals (scalar_expression_id, group_id, left_group_id, right_group_id) VALUES ($1, $2, $3, $4) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")

sqlx::query_scalar("INSERT INTO scalar_equals (scalar_expression_id, group_id, left_group_id, right_group_id) VALUES ($1, $2, $3, $4) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")
.bind(scalar_expr_id)
.bind(group_id)
.bind(equal.left)
.bind(equal.right)
.fetch_one(&mut *txn)
.await?;
group_id
.await?
}
};

Expand All @@ -354,7 +358,7 @@ impl SqliteMemo {
Ok(inserted_group_id)
}

/// Inserts a scalar expression into the database.
/// Inserts an entry into the `scalar_expressions` table.
async fn insert_into_scalar_expressions(
db: &mut SqliteConnection,
scalar_expr_id: ScalarExpressionId,
Expand All @@ -371,6 +375,7 @@ impl SqliteMemo {
Ok(())
}

/// Removes a dangling scalar expression from the `scalar_expressions` table.
async fn remove_dangling_scalar_expr(
&self,
db: &mut SqliteConnection,
Expand All @@ -383,6 +388,10 @@ impl SqliteMemo {
Ok(())
}

/// Inserts a logical expression into the memo table. If the `add_to_group_id` is `Some`,
/// we will attempt to add the logical expression to the specified group.
/// If the logical expression already exists in the database, the existing group id will be returned.
/// Otherwise, a new group id will be created.
async fn add_logical_expr_to_group_inner(
&self,
logical_expr: &LogicalExpression,
Expand Down Expand Up @@ -417,14 +426,14 @@ impl SqliteMemo {
LogicalOperatorKind::Scan,
)
.await?;
let group_id= sqlx::query_scalar("INSERT INTO scans (logical_expression_id, group_id, table_name, predicate_group_id) VALUES ($1, $2, $3, $4) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")

sqlx::query_scalar("INSERT INTO scans (logical_expression_id, group_id, table_name, predicate_group_id) VALUES ($1, $2, $3, $4) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")
.bind(logical_expr_id)
.bind(group_id)
.bind(serde_json::to_string(&scan.table_name)?)
.bind(scan.predicate)
.fetch_one(&mut *txn)
.await?;
group_id
.await?
}
LogicalExpression::Filter(filter) => {
Self::insert_into_logical_expressions(
Expand All @@ -434,14 +443,14 @@ impl SqliteMemo {
LogicalOperatorKind::Filter,
)
.await?;
let group_id = sqlx::query_scalar("INSERT INTO filters (logical_expression_id, group_id, child_group_id, predicate_group_id) VALUES ($1, $2, $3, $4) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")

sqlx::query_scalar("INSERT INTO filters (logical_expression_id, group_id, child_group_id, predicate_group_id) VALUES ($1, $2, $3, $4) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")
.bind(logical_expr_id)
.bind(group_id)
.bind(filter.child)
.bind(filter.predicate)
.fetch_one(&mut *txn)
.await?;
group_id
.await?
}
LogicalExpression::Join(join) => {
Self::insert_into_logical_expressions(
Expand All @@ -451,16 +460,16 @@ impl SqliteMemo {
LogicalOperatorKind::Join,
)
.await?;
let group_id = sqlx::query_scalar("INSERT INTO joins (logical_expression_id, group_id, join_type, left_group_id, right_group_id, condition_group_id) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")

sqlx::query_scalar("INSERT INTO joins (logical_expression_id, group_id, join_type, left_group_id, right_group_id, condition_group_id) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")
.bind(logical_expr_id)
.bind(group_id)
.bind(serde_json::to_string(&join.join_type)?)
.bind(join.left)
.bind(join.right)
.bind(join.condition)
.fetch_one(&mut *txn)
.await?;
group_id
.await?
}
};

Expand All @@ -480,6 +489,7 @@ impl SqliteMemo {
Ok(inserted_group_id)
}

/// Inserts an entry into the `logical_expressions` table.
async fn insert_into_logical_expressions(
txn: &mut SqliteConnection,
logical_expr_id: LogicalExpressionId,
Expand All @@ -496,6 +506,7 @@ impl SqliteMemo {
Ok(())
}

/// Removes a dangling logical expression from the `logical_expressions` table.
async fn remove_dangling_logical_expr(
&self,
db: &mut SqliteConnection,
Expand All @@ -510,6 +521,8 @@ impl SqliteMemo {
}

/// The SQL query to get all logical expressions in a group.
/// For each of the operators, the logical_expression_id is selected,
/// as well as the data fields in json form.
const fn get_all_logical_exprs_in_group_query() -> &'static str {
concat!(
"SELECT logical_expression_id, json_object('Scan', json_object('table_name', json(table_name), 'predicate', predicate_group_id)) as data FROM scans WHERE group_id = $1",
Expand All @@ -521,6 +534,8 @@ const fn get_all_logical_exprs_in_group_query() -> &'static str {
}

/// The SQL query to get all scalar expressions in a group.
/// For each of the operators, the scalar_expression_id is selected,
/// as well as the data fields in json form.
const fn get_all_scalar_exprs_in_group_query() -> &'static str {
concat!(
"SELECT scalar_expression_id, json_object('Constant', json(value)) as data FROM scalar_constants WHERE group_id = $1",
Expand Down
Loading

0 comments on commit 3d87b42

Please sign in to comment.