Skip to content

Commit 3d87b42

Browse files
committed
Merge with main
2 parents 3173aed + a94829a commit 3d87b42

File tree

3 files changed

+95
-68
lines changed

3 files changed

+95
-68
lines changed

optd-core/src/cascades/memo.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,45 +17,51 @@ use anyhow::Result;
1717

1818
#[trait_variant::make(Send)]
1919
pub trait Memoize: Send + Sync + 'static {
20+
/// Gets all logical expressions in a group.
2021
async fn get_all_logical_exprs_in_group(
2122
&self,
2223
group_id: RelationalGroupId,
2324
) -> Result<Vec<(LogicalExpressionId, Arc<LogicalExpression>)>>;
2425

25-
// Returns the group id of new group if merge happened.
26+
/// Adds a logical expression to an existing group.
27+
/// Returns the group id of new group if merge happened.
2628
async fn add_logical_expr_to_group(
2729
&self,
2830
logical_expr: &LogicalExpression,
2931
group_id: RelationalGroupId,
3032
) -> Result<RelationalGroupId>;
3133

32-
// Returns the group id of group if already exists, otherwise creates a new group.
34+
/// Adds a logical expression to the memo table.
35+
/// Returns the group id of group if already exists, otherwise creates a new group.
3336
async fn add_logical_expr(&self, logical_expr: &LogicalExpression)
3437
-> Result<RelationalGroupId>;
3538

39+
/// Gets all scalar expressions in a group.
3640
async fn get_all_scalar_exprs_in_group(
3741
&self,
3842
group_id: ScalarGroupId,
3943
) -> Result<Vec<(ScalarExpressionId, Arc<ScalarExpression>)>>;
4044

41-
// Returns the group id of new group if merge happened.
45+
/// Adds a scalar expression to an existing group.
46+
/// Returns the group id of new group if merge happened.
4247
async fn add_scalar_expr_to_group(
4348
&self,
4449
scalar_expr: &ScalarExpression,
4550
group_id: ScalarGroupId,
4651
) -> Result<ScalarGroupId>;
4752

48-
// Returns the group id of group if already exists, otherwise creates a new group.
53+
/// Adds a scalar expression to the memo table.
54+
/// Returns the group id of group if already exists, otherwise creates a new group.
4955
async fn add_scalar_expr(&self, scalar_expr: &ScalarExpression) -> Result<ScalarGroupId>;
5056

51-
// Merges two relational groups and returns the new group id.
57+
/// Merges two relational groups and returns the new group id.
5258
async fn merge_relation_group(
5359
&self,
5460
from: RelationalGroupId,
5561
to: RelationalGroupId,
5662
) -> Result<RelationalGroupId>;
5763

58-
// Merges two scalar groups and returns the new group id.
64+
/// Merges two scalar groups and returns the new group id.
5965
async fn merge_scalar_group(
6066
&self,
6167
from: ScalarGroupId,

optd-core/src/storage/memo.rs

Lines changed: 49 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
//! An implementation of the memo table using SQLite.
2+
13
use std::{str::FromStr, sync::Arc, time::Duration};
24

35
use super::transaction::Transaction;
@@ -64,7 +66,7 @@ impl SqliteMemo {
6466
/// Begin a new transaction.
6567
pub(super) async fn begin(&self) -> anyhow::Result<Transaction<'_>> {
6668
let txn = self.db.begin().await?;
67-
Ok(Transaction::new(txn).await?)
69+
Transaction::new(txn).await
6870
}
6971
}
7072

@@ -80,9 +82,7 @@ impl Memoize for SqliteMemo {
8082
}
8183

8284
let mut txn = self.begin().await?;
83-
let representative_group_id = self
84-
.get_representative_group_id(&mut *txn, group_id)
85-
.await?;
85+
let representative_group_id = self.get_representative_group_id(&mut txn, group_id).await?;
8686
let logical_exprs: Vec<LogicalExprRecord> =
8787
sqlx::query_as(&self.get_all_logical_exprs_in_group_query)
8888
.bind(representative_group_id)
@@ -129,7 +129,7 @@ impl Memoize for SqliteMemo {
129129

130130
let mut txn = self.begin().await?;
131131
let representative_group_id = self
132-
.get_representative_scalar_group_id(&mut *txn, group_id)
132+
.get_representative_scalar_group_id(&mut txn, group_id)
133133
.await?;
134134
let scalar_exprs: Vec<ScalarExprRecord> =
135135
sqlx::query_as(&self.get_all_scalar_exprs_in_group_query)
@@ -168,8 +168,7 @@ impl Memoize for SqliteMemo {
168168
to: RelationalGroupId,
169169
) -> Result<RelationalGroupId> {
170170
let mut txn = self.begin().await?;
171-
self.set_representative_group_id(&mut *txn, from, to)
172-
.await?;
171+
self.set_representative_group_id(&mut txn, from, to).await?;
173172
txn.commit().await?;
174173
Ok(to)
175174
}
@@ -180,15 +179,16 @@ impl Memoize for SqliteMemo {
180179
to: ScalarGroupId,
181180
) -> Result<ScalarGroupId> {
182181
let mut txn = self.begin().await?;
183-
self.set_representative_scalar_group_id(&mut *txn, from, to)
182+
self.set_representative_scalar_group_id(&mut txn, from, to)
184183
.await?;
185184
txn.commit().await?;
186185
Ok(to)
187186
}
188187
}
189188

190-
// Memoize helpers
189+
// Helper functions for implementing the `Memoize` trait.
191190
impl SqliteMemo {
191+
/// Gets the representative group id of a relational group.
192192
async fn get_representative_group_id(
193193
&self,
194194
db: &mut SqliteConnection,
@@ -202,6 +202,7 @@ impl SqliteMemo {
202202
Ok(representative_group_id)
203203
}
204204

205+
/// Sets the representative group id of a relational group.
205206
async fn set_representative_group_id(
206207
&self,
207208
db: &mut SqliteConnection,
@@ -216,6 +217,7 @@ impl SqliteMemo {
216217
Ok(())
217218
}
218219

220+
/// Gets the representative group id of a scalar group.
219221
async fn get_representative_scalar_group_id(
220222
&self,
221223
db: &mut SqliteConnection,
@@ -229,6 +231,7 @@ impl SqliteMemo {
229231
Ok(representative_group_id)
230232
}
231233

234+
/// Sets the representative group id of a scalar group.
232235
async fn set_representative_scalar_group_id(
233236
&self,
234237
db: &mut SqliteConnection,
@@ -243,6 +246,10 @@ impl SqliteMemo {
243246
Ok(())
244247
}
245248

249+
/// Inserts a scalar expression into the database. If the `add_to_group_id` is `Some`,
250+
/// we will attempt to add the scalar expression to the specified group.
251+
/// If the scalar expression already exists in the database, the existing group id will be returned.
252+
/// Otherwise, a new group id will be created.
246253
async fn add_scalar_expr_to_group_inner(
247254
&self,
248255
scalar_expr: &ScalarExpression,
@@ -275,13 +282,13 @@ impl SqliteMemo {
275282
ScalarOperatorKind::Constant,
276283
)
277284
.await?;
278-
let group_id = sqlx::query_scalar("INSERT INTO scalar_constants (scalar_expression_id, group_id, value) VALUES ($1, $2, $3) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")
285+
286+
sqlx::query_scalar("INSERT INTO scalar_constants (scalar_expression_id, group_id, value) VALUES ($1, $2, $3) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")
279287
.bind(scalar_expr_id)
280288
.bind(group_id)
281289
.bind(serde_json::to_string(&constant)?)
282290
.fetch_one(&mut *txn)
283-
.await?;
284-
group_id
291+
.await?
285292
}
286293
ScalarExpression::ColumnRef(column_ref) => {
287294
Self::insert_into_scalar_expressions(
@@ -291,13 +298,13 @@ impl SqliteMemo {
291298
ScalarOperatorKind::ColumnRef,
292299
)
293300
.await?;
294-
let group_id = sqlx::query_scalar("INSERT INTO scalar_column_refs (scalar_expression_id, group_id, column_index) VALUES ($1, $2, $3) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")
301+
302+
sqlx::query_scalar("INSERT INTO scalar_column_refs (scalar_expression_id, group_id, column_index) VALUES ($1, $2, $3) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")
295303
.bind(scalar_expr_id)
296304
.bind(group_id)
297305
.bind(serde_json::to_string(&column_ref.column_index)?)
298306
.fetch_one(&mut *txn)
299-
.await?;
300-
group_id
307+
.await?
301308
}
302309
ScalarExpression::Add(add) => {
303310
Self::insert_into_scalar_expressions(
@@ -307,17 +314,14 @@ impl SqliteMemo {
307314
ScalarOperatorKind::Add,
308315
)
309316
.await?;
310-
// println!("add: {:?}", add);
311-
// println!("scalar_expr_id: {:?}", scalar_expr_id);
312-
// println!("group_id: {:?}", group_id);
313-
let group_id = sqlx::query_scalar("INSERT INTO scalar_adds (scalar_expression_id, group_id, left_group_id, right_group_id) VALUES ($1, $2, $3, $4) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")
317+
318+
sqlx::query_scalar("INSERT INTO scalar_adds (scalar_expression_id, group_id, left_group_id, right_group_id) VALUES ($1, $2, $3, $4) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")
314319
.bind(scalar_expr_id)
315320
.bind(group_id)
316321
.bind(add.left)
317322
.bind(add.right)
318323
.fetch_one(&mut *txn)
319-
.await?;
320-
group_id
324+
.await?
321325
}
322326
ScalarExpression::Equal(equal) => {
323327
Self::insert_into_scalar_expressions(
@@ -327,14 +331,14 @@ impl SqliteMemo {
327331
ScalarOperatorKind::Equal,
328332
)
329333
.await?;
330-
let group_id = sqlx::query_scalar("INSERT INTO scalar_equals (scalar_expression_id, group_id, left_group_id, right_group_id) VALUES ($1, $2, $3, $4) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")
334+
335+
sqlx::query_scalar("INSERT INTO scalar_equals (scalar_expression_id, group_id, left_group_id, right_group_id) VALUES ($1, $2, $3, $4) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")
331336
.bind(scalar_expr_id)
332337
.bind(group_id)
333338
.bind(equal.left)
334339
.bind(equal.right)
335340
.fetch_one(&mut *txn)
336-
.await?;
337-
group_id
341+
.await?
338342
}
339343
};
340344

@@ -354,7 +358,7 @@ impl SqliteMemo {
354358
Ok(inserted_group_id)
355359
}
356360

357-
/// Inserts a scalar expression into the database.
361+
/// Inserts an entry into the `scalar_expressions` table.
358362
async fn insert_into_scalar_expressions(
359363
db: &mut SqliteConnection,
360364
scalar_expr_id: ScalarExpressionId,
@@ -371,6 +375,7 @@ impl SqliteMemo {
371375
Ok(())
372376
}
373377

378+
/// Removes a dangling scalar expression from the `scalar_expressions` table.
374379
async fn remove_dangling_scalar_expr(
375380
&self,
376381
db: &mut SqliteConnection,
@@ -383,6 +388,10 @@ impl SqliteMemo {
383388
Ok(())
384389
}
385390

391+
/// Inserts a logical expression into the memo table. If the `add_to_group_id` is `Some`,
392+
/// we will attempt to add the logical expression to the specified group.
393+
/// If the logical expression already exists in the database, the existing group id will be returned.
394+
/// Otherwise, a new group id will be created.
386395
async fn add_logical_expr_to_group_inner(
387396
&self,
388397
logical_expr: &LogicalExpression,
@@ -417,14 +426,14 @@ impl SqliteMemo {
417426
LogicalOperatorKind::Scan,
418427
)
419428
.await?;
420-
let group_id= sqlx::query_scalar("INSERT INTO scans (logical_expression_id, group_id, table_name, predicate_group_id) VALUES ($1, $2, $3, $4) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")
429+
430+
sqlx::query_scalar("INSERT INTO scans (logical_expression_id, group_id, table_name, predicate_group_id) VALUES ($1, $2, $3, $4) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")
421431
.bind(logical_expr_id)
422432
.bind(group_id)
423433
.bind(serde_json::to_string(&scan.table_name)?)
424434
.bind(scan.predicate)
425435
.fetch_one(&mut *txn)
426-
.await?;
427-
group_id
436+
.await?
428437
}
429438
LogicalExpression::Filter(filter) => {
430439
Self::insert_into_logical_expressions(
@@ -434,14 +443,14 @@ impl SqliteMemo {
434443
LogicalOperatorKind::Filter,
435444
)
436445
.await?;
437-
let group_id = sqlx::query_scalar("INSERT INTO filters (logical_expression_id, group_id, child_group_id, predicate_group_id) VALUES ($1, $2, $3, $4) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")
446+
447+
sqlx::query_scalar("INSERT INTO filters (logical_expression_id, group_id, child_group_id, predicate_group_id) VALUES ($1, $2, $3, $4) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")
438448
.bind(logical_expr_id)
439449
.bind(group_id)
440450
.bind(filter.child)
441451
.bind(filter.predicate)
442452
.fetch_one(&mut *txn)
443-
.await?;
444-
group_id
453+
.await?
445454
}
446455
LogicalExpression::Join(join) => {
447456
Self::insert_into_logical_expressions(
@@ -451,16 +460,16 @@ impl SqliteMemo {
451460
LogicalOperatorKind::Join,
452461
)
453462
.await?;
454-
let group_id = sqlx::query_scalar("INSERT INTO joins (logical_expression_id, group_id, join_type, left_group_id, right_group_id, condition_group_id) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")
463+
464+
sqlx::query_scalar("INSERT INTO joins (logical_expression_id, group_id, join_type, left_group_id, right_group_id, condition_group_id) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id")
455465
.bind(logical_expr_id)
456466
.bind(group_id)
457467
.bind(serde_json::to_string(&join.join_type)?)
458468
.bind(join.left)
459469
.bind(join.right)
460470
.bind(join.condition)
461471
.fetch_one(&mut *txn)
462-
.await?;
463-
group_id
472+
.await?
464473
}
465474
};
466475

@@ -480,6 +489,7 @@ impl SqliteMemo {
480489
Ok(inserted_group_id)
481490
}
482491

492+
/// Inserts an entry into the `logical_expressions` table.
483493
async fn insert_into_logical_expressions(
484494
txn: &mut SqliteConnection,
485495
logical_expr_id: LogicalExpressionId,
@@ -496,6 +506,7 @@ impl SqliteMemo {
496506
Ok(())
497507
}
498508

509+
/// Removes a dangling logical expression from the `logical_expressions` table.
499510
async fn remove_dangling_logical_expr(
500511
&self,
501512
db: &mut SqliteConnection,
@@ -510,6 +521,8 @@ impl SqliteMemo {
510521
}
511522

512523
/// The SQL query to get all logical expressions in a group.
524+
/// For each of the operators, the logical_expression_id is selected,
525+
/// as well as the data fields in json form.
513526
const fn get_all_logical_exprs_in_group_query() -> &'static str {
514527
concat!(
515528
"SELECT logical_expression_id, json_object('Scan', json_object('table_name', json(table_name), 'predicate', predicate_group_id)) as data FROM scans WHERE group_id = $1",
@@ -521,6 +534,8 @@ const fn get_all_logical_exprs_in_group_query() -> &'static str {
521534
}
522535

523536
/// The SQL query to get all scalar expressions in a group.
537+
/// For each of the operators, the scalar_expression_id is selected,
538+
/// as well as the data fields in json form.
524539
const fn get_all_scalar_exprs_in_group_query() -> &'static str {
525540
concat!(
526541
"SELECT scalar_expression_id, json_object('Constant', json(value)) as data FROM scalar_constants WHERE group_id = $1",

0 commit comments

Comments
 (0)