diff --git a/Cargo.lock b/Cargo.lock index 5f394202..fcdf85c0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3248,6 +3248,7 @@ version = "0.1.1" dependencies = [ "anyhow", "arrow-schema", + "async-recursion", "async-trait", "chrono", "itertools 0.13.0", diff --git a/optd-ng-kernel/Cargo.toml b/optd-ng-kernel/Cargo.toml index 460c9a26..6e2f77bc 100644 --- a/optd-ng-kernel/Cargo.toml +++ b/optd-ng-kernel/Cargo.toml @@ -9,6 +9,7 @@ repository.workspace = true [dependencies] anyhow = "1" +async-recursion = "1" async-trait = "0.1" arrow-schema = "47.0.0" tracing = "0.1" diff --git a/optd-ng-kernel/src/cascades/memo.rs b/optd-ng-kernel/src/cascades/memo.rs index 975fdf8a..e9865521 100644 --- a/optd-ng-kernel/src/cascades/memo.rs +++ b/optd-ng-kernel/src/cascades/memo.rs @@ -108,9 +108,6 @@ pub trait Memo: 'static + Send + Sync { /// Get all groups IDs in the memo table. async fn get_all_group_ids(&self) -> Vec; - /// Get a group by ID - async fn get_group(&self, group_id: GroupId) -> &Group; - /// Get a predicate by ID async fn get_pred(&self, pred_id: PredId) -> ArcPredNode; @@ -122,17 +119,5 @@ pub trait Memo: 'static + Send + Sync { // are more efficient way to retrieve the information. /// Get all expressions in the group. - async fn get_all_exprs_in_group(&self, group_id: GroupId) -> Vec { - let group = self.get_group(group_id).await; - let mut exprs = group.group_exprs.iter().copied().collect_vec(); - // Sort so that we can get a stable processing order for the expressions, therefore making regression test - // yield a stable result across different platforms. - exprs.sort(); - exprs - } - - /// Get group info of a group. - async fn get_group_info(&self, group_id: GroupId) -> &GroupInfo { - &self.get_group(group_id).await.info - } + async fn get_all_exprs_in_group(&self, group_id: GroupId) -> Vec; } diff --git a/optd-ng-kernel/src/cascades/naive_memo.rs b/optd-ng-kernel/src/cascades/naive_memo.rs index 69f446e8..a8271267 100644 --- a/optd-ng-kernel/src/cascades/naive_memo.rs +++ b/optd-ng-kernel/src/cascades/naive_memo.rs @@ -74,8 +74,15 @@ impl Memo for NaiveMemo { self.get_all_group_ids_inner() } - async fn get_group(&self, group_id: GroupId) -> &Group { - self.get_group_inner(group_id) + async fn get_all_exprs_in_group(&self, group_id: GroupId) -> Vec { + let mut expr_ids: Vec = self + .get_group_inner(group_id) + .group_exprs + .iter() + .copied() + .collect(); + expr_ids.sort(); + expr_ids } async fn estimated_plan_space(&self) -> usize { @@ -458,7 +465,7 @@ pub(crate) mod tests { group_id, ) .await; - assert_eq!(memo.get_group(group_id).await.group_exprs.len(), 2); + assert_eq!(memo.get_all_exprs_in_group(group_id).await.len(), 2); } #[tokio::test] diff --git a/optd-ng-kernel/src/cascades/persistent_memo.rs b/optd-ng-kernel/src/cascades/persistent_memo.rs index 91460ca0..8ff1c07f 100644 --- a/optd-ng-kernel/src/cascades/persistent_memo.rs +++ b/optd-ng-kernel/src/cascades/persistent_memo.rs @@ -1,12 +1,13 @@ use std::marker::PhantomData; +use async_recursion::async_recursion; use async_trait::async_trait; use sqlx::{Row, SqlitePool}; use crate::nodes::{ArcPlanNode, ArcPredNode, PersistentNodeType, PlanNodeOrGroup}; use super::{ - memo::{ArcMemoPlanNode, Group, Memo}, + memo::{ArcMemoPlanNode, Group, Memo, MemoPlanNode}, optimizer::{ExprId, GroupId, PredId}, }; @@ -29,7 +30,10 @@ impl PersistentMemo { sqlx::query("CREATE TABLE groups(group_id INTEGER PRIMARY KEY AUTOINCREMENT)") .execute(&self.db_conn) .await?; - sqlx::query("CREATE TABLE group_exprs(group_expr_id INTEGER PRIMARY KEY AUTOINCREMENT, group_id INTEGER, tag TEXT, children JSON DEFAULT('[]'))") + sqlx::query("CREATE TABLE group_merges(from_group_id INTEGER PRIMARY KEY AUTOINCREMENT, to_group_id INTEGER)") + .execute(&self.db_conn) + .await?; + sqlx::query("CREATE TABLE group_exprs(group_expr_id INTEGER PRIMARY KEY AUTOINCREMENT, group_id INTEGER, tag TEXT, children JSON DEFAULT('[]'), predicates JSON DEFAULT('[]'))") .execute(&self.db_conn) .await?; sqlx::query( @@ -46,10 +50,73 @@ pub async fn new_in_memory() -> anyhow::Result PersistentMemo { + #[async_recursion] + async fn add_new_expr_inner(&mut self, rel_node: ArcPlanNode) -> (GroupId, ExprId) { + let mut children_groups = Vec::new(); + for child in rel_node.children.iter() { + let group = match child { + PlanNodeOrGroup::Group(group) => *group, + PlanNodeOrGroup::PlanNode(child) => { + let (group_id, _) = self.add_new_expr_inner(child.clone()).await; + group_id + } + }; + children_groups.push(group.0); + } + let mut predicates = Vec::new(); + for pred in rel_node.predicates.iter() { + let pred_id = self.add_new_pred(pred.clone()).await; + predicates.push(pred_id.0); + } + let tag = T::serialize_plan_tag(rel_node.typ.clone()); + // check if we already have an expr in the database + let row = + sqlx::query("SELECT group_expr_id, group_id FROM group_exprs WHERE tag = ? AND children = ? AND predicates = ?") + .bind(&tag) + .bind(serde_json::to_value(&children_groups).unwrap()) + .bind(serde_json::to_value(&predicates).unwrap()) + .fetch_optional(&self.db_conn) + .await + .unwrap(); + if let Some(row) = row { + let expr_id = row.get::("group_expr_id"); + let group_id = row.get::("group_id"); + (GroupId(group_id as usize), ExprId(expr_id as usize)) + } else { + let group_id = sqlx::query("INSERT INTO groups DEFAULT VALUES") + .execute(&self.db_conn) + .await + .unwrap() + .last_insert_rowid(); + let expr_id = sqlx::query( + "INSERT INTO group_exprs(group_id, tag, children, predicates) VALUES (?, ?, ?, ?)", + ) + .bind(group_id) + .bind(&tag) + .bind(serde_json::to_value(&children_groups).unwrap()) + .bind(serde_json::to_value(&predicates).unwrap()) + .execute(&self.db_conn) + .await + .unwrap() + .last_insert_rowid(); + (GroupId(group_id as usize), ExprId(expr_id as usize)) + } + } + + async fn add_expr_to_group_inner( + &mut self, + rel_node: PlanNodeOrGroup, + group_id: GroupId, + ) -> Option { + unimplemented!() + } +} + #[async_trait] impl Memo for PersistentMemo { async fn add_new_expr(&mut self, rel_node: ArcPlanNode) -> (GroupId, ExprId) { - unimplemented!() + self.add_new_expr_inner(rel_node).await } async fn add_expr_to_group( @@ -90,19 +157,63 @@ impl Memo for PersistentMemo { } async fn get_group_id(&self, expr_id: ExprId) -> GroupId { - unimplemented!() + let group_id = sqlx::query("SELECT group_id FROM group_exprs WHERE group_expr_id = ?") + .bind(expr_id.0 as i64) + .fetch_one(&self.db_conn) + .await + .unwrap() + .get::(0); + GroupId(group_id as usize) } async fn get_expr_memoed(&self, expr_id: ExprId) -> ArcMemoPlanNode { - unimplemented!() + let row = sqlx::query( + "SELECT tag, children, predicates FROM group_exprs WHERE group_expr_id = ?", + ) + .bind(expr_id.0 as i64) + .fetch_one(&self.db_conn) + .await + .unwrap(); + let tag = row.get::(0); + let children = row.get::(1); + let predicates = row.get::(2); + let children: Vec = serde_json::from_value(children).unwrap(); + let children = children.into_iter().map(|x| GroupId(x)).collect(); + let predicates: Vec = serde_json::from_value(predicates).unwrap(); + let predicates = predicates.into_iter().map(|x| PredId(x)).collect(); + MemoPlanNode { + typ: T::deserialize_plan_tag(serde_json::from_str(&tag).unwrap()), + children, + predicates, + } + .into() } async fn get_all_group_ids(&self) -> Vec { - unimplemented!() + let group_ids = sqlx::query("SELECT group_id FROM groups ORDER BY group_id") + .fetch_all(&self.db_conn) + .await + .unwrap(); + let group_ids: Vec = group_ids + .into_iter() + .map(|row| GroupId(row.get::(0) as usize)) + .collect(); + group_ids } - async fn get_group(&self, group_id: GroupId) -> &Group { - unimplemented!() + async fn get_all_exprs_in_group(&self, group_id: GroupId) -> Vec { + let expr_ids = sqlx::query( + "SELECT group_expr_id FROM group_exprs WHERE group_id = ? ORDER BY group_expr_id", + ) + .bind(group_id.0 as i64) + .fetch_all(&self.db_conn) + .await + .unwrap(); + let expr_ids: Vec = expr_ids + .into_iter() + .map(|row| ExprId(row.get::(0) as usize)) + .collect(); + expr_ids } async fn estimated_plan_space(&self) -> usize { @@ -137,4 +248,13 @@ mod tests { let p2 = memo.add_new_pred(pred_node.clone()).await; assert_eq!(p1, p2); } + + #[tokio::test] + async fn add_expr() { + let mut memo = create_db_and_migrate().await; + let scan_node = scan("t1"); + let p1 = memo.add_new_expr(scan_node.clone()).await; + let p2 = memo.add_new_expr(scan_node.clone()).await; + assert_eq!(p1, p2); + } } diff --git a/optd-ng-kernel/src/nodes.rs b/optd-ng-kernel/src/nodes.rs index 17017353..1a30b5e4 100644 --- a/optd-ng-kernel/src/nodes.rs +++ b/optd-ng-kernel/src/nodes.rs @@ -224,7 +224,12 @@ pub trait NodeType: pub trait PersistentNodeType: NodeType { fn serialize_pred(pred: &ArcPredNode) -> serde_json::Value; + fn deserialize_pred(data: serde_json::Value) -> ArcPredNode; + + fn serialize_plan_tag(tag: Self) -> serde_json::Value; + + fn deserialize_plan_tag(data: serde_json::Value) -> Self; } /// A pointer to a plan node diff --git a/optd-ng-kernel/src/tests/common.rs b/optd-ng-kernel/src/tests/common.rs index 36589cdb..0a9f7fd1 100644 --- a/optd-ng-kernel/src/tests/common.rs +++ b/optd-ng-kernel/src/tests/common.rs @@ -114,6 +114,14 @@ impl PersistentNodeType for MemoTestRelTyp { let node: PersistentPredNode = serde_json::from_value(data).unwrap(); Arc::new(node.into()) } + + fn serialize_plan_tag(tag: Self) -> serde_json::Value { + serde_json::to_value(tag).unwrap() + } + + fn deserialize_plan_tag(data: serde_json::Value) -> Self { + serde_json::from_value(data).unwrap() + } } pub(crate) fn join(