Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Commit

Permalink
dedup expressions when adding
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Chi <[email protected]>
  • Loading branch information
skyzh committed Nov 27, 2024
1 parent e0aaac5 commit b290f3f
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 27 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions optd-ng-kernel/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ repository.workspace = true

[dependencies]
anyhow = "1"
async-recursion = "1"
async-trait = "0.1"
arrow-schema = "47.0.0"
tracing = "0.1"
Expand Down
17 changes: 1 addition & 16 deletions optd-ng-kernel/src/cascades/memo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,6 @@ pub trait Memo<T: NodeType>: 'static + Send + Sync {
/// Get all group IDs in the memo table.
async fn get_all_group_ids(&self) -> Vec<GroupId>;

/// Get a group by ID
async fn get_group(&self, group_id: GroupId) -> &Group;

/// Get a predicate by ID
async fn get_pred(&self, pred_id: PredId) -> ArcPredNode<T>;

Expand All @@ -122,17 +119,5 @@ pub trait Memo<T: NodeType>: 'static + Send + Sync {
// are more efficient way to retrieve the information.

/// Get all expressions in the group, in sorted order.
async fn get_all_exprs_in_group(&self, group_id: GroupId) -> Vec<ExprId> {
    let mut sorted_ids = self
        .get_group(group_id)
        .await
        .group_exprs
        .iter()
        .copied()
        .collect_vec();
    // Sorting fixes the processing order of the expressions, so regression tests
    // yield the same result on every platform.
    sorted_ids.sort();
    sorted_ids
}

/// Get the `info` field of the group identified by `group_id`.
async fn get_group_info(&self, group_id: GroupId) -> &GroupInfo {
    let group = self.get_group(group_id).await;
    &group.info
}
async fn get_all_exprs_in_group(&self, group_id: GroupId) -> Vec<ExprId>;
}
13 changes: 10 additions & 3 deletions optd-ng-kernel/src/cascades/naive_memo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,15 @@ impl<T: NodeType> Memo<T> for NaiveMemo<T> {
self.get_all_group_ids_inner()
}

async fn get_group(&self, group_id: GroupId) -> &Group {
self.get_group_inner(group_id)
/// Returns the IDs of every expression stored in `group_id`, sorted.
async fn get_all_exprs_in_group(&self, group_id: GroupId) -> Vec<ExprId> {
    let group = self.get_group_inner(group_id);
    let mut sorted_ids: Vec<ExprId> = Vec::with_capacity(group.group_exprs.len());
    sorted_ids.extend(group.group_exprs.iter().copied());
    sorted_ids.sort();
    sorted_ids
}

async fn estimated_plan_space(&self) -> usize {
Expand Down Expand Up @@ -458,7 +465,7 @@ pub(crate) mod tests {
group_id,
)
.await;
assert_eq!(memo.get_group(group_id).await.group_exprs.len(), 2);
assert_eq!(memo.get_all_exprs_in_group(group_id).await.len(), 2);
}

#[tokio::test]
Expand Down
136 changes: 128 additions & 8 deletions optd-ng-kernel/src/cascades/persistent_memo.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::marker::PhantomData;

use async_recursion::async_recursion;
use async_trait::async_trait;
use sqlx::{Row, SqlitePool};

use crate::nodes::{ArcPlanNode, ArcPredNode, PersistentNodeType, PlanNodeOrGroup};

use super::{
memo::{ArcMemoPlanNode, Group, Memo},
memo::{ArcMemoPlanNode, Group, Memo, MemoPlanNode},
optimizer::{ExprId, GroupId, PredId},
};

Expand All @@ -29,7 +30,10 @@ impl<T: PersistentNodeType> PersistentMemo<T> {
sqlx::query("CREATE TABLE groups(group_id INTEGER PRIMARY KEY AUTOINCREMENT)")
.execute(&self.db_conn)
.await?;
sqlx::query("CREATE TABLE group_exprs(group_expr_id INTEGER PRIMARY KEY AUTOINCREMENT, group_id INTEGER, tag TEXT, children JSON DEFAULT('[]'))")
sqlx::query("CREATE TABLE group_merges(from_group_id INTEGER PRIMARY KEY AUTOINCREMENT, to_group_id INTEGER)")
.execute(&self.db_conn)
.await?;
sqlx::query("CREATE TABLE group_exprs(group_expr_id INTEGER PRIMARY KEY AUTOINCREMENT, group_id INTEGER, tag TEXT, children JSON DEFAULT('[]'), predicates JSON DEFAULT('[]'))")
.execute(&self.db_conn)
.await?;
sqlx::query(
Expand All @@ -46,10 +50,73 @@ pub async fn new_in_memory<T: PersistentNodeType>() -> anyhow::Result<Persistent
Ok(PersistentMemo::new(db_conn).await)
}

impl<T: PersistentNodeType> PersistentMemo<T> {
    /// Adds `rel_node` (and, recursively, any plan-node children) to the memo.
    ///
    /// Children are resolved to group IDs first and predicates to predicate IDs.
    /// If `group_exprs` already holds a row with the same tag, children and
    /// predicates, that row's group and expression IDs are returned. Otherwise a
    /// new group and a new expression row are inserted and their IDs returned.
    ///
    /// # Panics
    /// Panics if any database statement fails.
    #[async_recursion]
    async fn add_new_expr_inner(&mut self, rel_node: ArcPlanNode<T>) -> (GroupId, ExprId) {
        // Resolve every child to a group ID; plan-node children are memoized recursively.
        let mut child_group_ids = Vec::new();
        for child in rel_node.children.iter() {
            let child_group = match child {
                PlanNodeOrGroup::Group(existing) => *existing,
                PlanNodeOrGroup::PlanNode(node) => self.add_new_expr_inner(node.clone()).await.0,
            };
            child_group_ids.push(child_group.0);
        }

        let mut pred_ids = Vec::new();
        for pred in rel_node.predicates.iter() {
            pred_ids.push(self.add_new_pred(pred.clone()).await.0);
        }

        // JSON encodings shared by the lookup and the insert below.
        let tag = T::serialize_plan_tag(rel_node.typ.clone());
        let children_json = serde_json::to_value(&child_group_ids).unwrap();
        let predicates_json = serde_json::to_value(&pred_ids).unwrap();

        // check if we already have an expr in the database
        let existing = sqlx::query("SELECT group_expr_id, group_id FROM group_exprs WHERE tag = ? AND children = ? AND predicates = ?")
            .bind(&tag)
            .bind(&children_json)
            .bind(&predicates_json)
            .fetch_optional(&self.db_conn)
            .await
            .unwrap();

        match existing {
            Some(row) => {
                let expr_id = row.get::<i64, _>("group_expr_id");
                let group_id = row.get::<i64, _>("group_id");
                (GroupId(group_id as usize), ExprId(expr_id as usize))
            }
            None => {
                let group_id = sqlx::query("INSERT INTO groups DEFAULT VALUES")
                    .execute(&self.db_conn)
                    .await
                    .unwrap()
                    .last_insert_rowid();
                let expr_id = sqlx::query(
                    "INSERT INTO group_exprs(group_id, tag, children, predicates) VALUES (?, ?, ?, ?)",
                )
                .bind(group_id)
                .bind(&tag)
                .bind(&children_json)
                .bind(&predicates_json)
                .execute(&self.db_conn)
                .await
                .unwrap()
                .last_insert_rowid();
                (GroupId(group_id as usize), ExprId(expr_id as usize))
            }
        }
    }

    /// Placeholder for adding an expression to an existing group; not implemented yet.
    async fn add_expr_to_group_inner(
        &mut self,
        rel_node: PlanNodeOrGroup<T>,
        group_id: GroupId,
    ) -> Option<ExprId> {
        unimplemented!()
    }
}

#[async_trait]
impl<T: PersistentNodeType> Memo<T> for PersistentMemo<T> {
/// Adds `rel_node` to the memo and returns the group and expression IDs it maps to.
///
/// Delegates to `add_new_expr_inner`. The stale `unimplemented!()` placeholder
/// that preceded the delegation (and made it unreachable) is removed.
async fn add_new_expr(&mut self, rel_node: ArcPlanNode<T>) -> (GroupId, ExprId) {
    self.add_new_expr_inner(rel_node).await
}

async fn add_expr_to_group(
Expand Down Expand Up @@ -90,19 +157,63 @@ impl<T: PersistentNodeType> Memo<T> for PersistentMemo<T> {
}

/// Looks up the group that owns expression `expr_id` in the `group_exprs` table.
///
/// # Panics
/// Panics if the query fails or no row matches `expr_id`.
async fn get_group_id(&self, expr_id: ExprId) -> GroupId {
    // The stale `unimplemented!()` placeholder is gone; it panicked before the query ran.
    let group_id = sqlx::query("SELECT group_id FROM group_exprs WHERE group_expr_id = ?")
        .bind(expr_id.0 as i64)
        .fetch_one(&self.db_conn)
        .await
        .unwrap()
        .get::<i64, _>(0);
    GroupId(group_id as usize)
}

async fn get_expr_memoed(&self, expr_id: ExprId) -> ArcMemoPlanNode<T> {
unimplemented!()
let row = sqlx::query(
"SELECT tag, children, predicates FROM group_exprs WHERE group_expr_id = ?",
)
.bind(expr_id.0 as i64)
.fetch_one(&self.db_conn)
.await
.unwrap();
let tag = row.get::<String, _>(0);
let children = row.get::<serde_json::Value, _>(1);
let predicates = row.get::<serde_json::Value, _>(2);
let children: Vec<usize> = serde_json::from_value(children).unwrap();
let children = children.into_iter().map(|x| GroupId(x)).collect();
let predicates: Vec<usize> = serde_json::from_value(predicates).unwrap();
let predicates = predicates.into_iter().map(|x| PredId(x)).collect();
MemoPlanNode {
typ: T::deserialize_plan_tag(serde_json::from_str(&tag).unwrap()),
children,
predicates,
}
.into()
}

/// Returns every group ID in the `groups` table, ordered by `group_id`.
///
/// # Panics
/// Panics if the query fails.
async fn get_all_group_ids(&self) -> Vec<GroupId> {
    // The stale `unimplemented!()` placeholder is gone; it panicked before the query ran.
    sqlx::query("SELECT group_id FROM groups ORDER BY group_id")
        .fetch_all(&self.db_conn)
        .await
        .unwrap()
        .into_iter()
        .map(|row| GroupId(row.get::<i64, _>(0) as usize))
        .collect()
}

async fn get_group(&self, group_id: GroupId) -> &Group {
unimplemented!()
/// Returns the IDs of all expressions whose `group_id` column equals `group_id`,
/// ordered by `group_expr_id`.
///
/// # Panics
/// Panics if the query fails.
async fn get_all_exprs_in_group(&self, group_id: GroupId) -> Vec<ExprId> {
    let rows = sqlx::query(
        "SELECT group_expr_id FROM group_exprs WHERE group_id = ? ORDER BY group_expr_id",
    )
    .bind(group_id.0 as i64)
    .fetch_all(&self.db_conn)
    .await
    .unwrap();

    let mut expr_ids: Vec<ExprId> = Vec::with_capacity(rows.len());
    for row in rows {
        expr_ids.push(ExprId(row.get::<i64, _>(0) as usize));
    }
    expr_ids
}

async fn estimated_plan_space(&self) -> usize {
Expand Down Expand Up @@ -137,4 +248,13 @@ mod tests {
let p2 = memo.add_new_pred(pred_node.clone()).await;
assert_eq!(p1, p2);
}

/// Adding the same plan node twice must return the same (group, expression) pair.
#[tokio::test]
async fn add_expr() {
    let mut memo = create_db_and_migrate().await;
    let scan_node = scan("t1");
    let first = memo.add_new_expr(scan_node.clone()).await;
    let second = memo.add_new_expr(scan_node.clone()).await;
    assert_eq!(first, second);
}
}
5 changes: 5 additions & 0 deletions optd-ng-kernel/src/nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,12 @@ pub trait NodeType:

/// A [`NodeType`] whose predicate nodes and plan node tags can be converted to
/// and from `serde_json::Value`.
pub trait PersistentNodeType: NodeType {
/// Converts a predicate node into a JSON value.
fn serialize_pred(pred: &ArcPredNode<Self>) -> serde_json::Value;

/// Rebuilds a predicate node from a JSON value.
/// Presumably the inverse of `serialize_pred` (confirm against implementors).
fn deserialize_pred(data: serde_json::Value) -> ArcPredNode<Self>;

/// Converts a plan node type (the "tag") into a JSON value.
fn serialize_plan_tag(tag: Self) -> serde_json::Value;

/// Rebuilds a plan node type from a JSON value.
/// Presumably the inverse of `serialize_plan_tag` (confirm against implementors).
fn deserialize_plan_tag(data: serde_json::Value) -> Self;
}

/// A pointer to a plan node
Expand Down
8 changes: 8 additions & 0 deletions optd-ng-kernel/src/tests/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ impl PersistentNodeType for MemoTestRelTyp {
let node: PersistentPredNode<MemoTestRelTyp> = serde_json::from_value(data).unwrap();
Arc::new(node.into())
}

/// Encodes the plan tag as a JSON value through serde.
///
/// # Panics
/// Panics if serialization fails.
fn serialize_plan_tag(tag: Self) -> serde_json::Value {
    let encoded = serde_json::to_value(tag);
    encoded.unwrap()
}

fn deserialize_plan_tag(data: serde_json::Value) -> Self {
serde_json::from_value(data).unwrap()
}
}

pub(crate) fn join(
Expand Down

0 comments on commit b290f3f

Please sign in to comment.