Skip to content

Commit

Permalink
refactor scalars into unary_op, binary_op, and logic_op
Browse files Browse the repository at this point in the history
Signed-off-by: Yuchen Liang <[email protected]>
  • Loading branch information
yliang412 committed Feb 13, 2025
1 parent 86a8fa6 commit 05a856a
Show file tree
Hide file tree
Showing 26 changed files with 452 additions and 330 deletions.
88 changes: 52 additions & 36 deletions optd-core/src/cascades/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::{
project::PhysicalProject, scan::table_scan::TableScan,
},
},
scalar::{add::Add, and::And, equal::Equal, ScalarOperator},
scalar::{binary_op::BinaryOp, logic_op::LogicOp, unary_op::UnaryOp, ScalarOperator},
},
plans::{
logical::{LogicalPlan, PartialLogicalPlan},
Expand Down Expand Up @@ -354,25 +354,30 @@ async fn match_any_partial_scalar_plan(
operator: ScalarOperator::ColumnRef(column_ref.clone()),
}))
}
ScalarExpression::Add(add) => {
let left = match_any_partial_scalar_plan(memo, add.left).await?;
let right = match_any_partial_scalar_plan(memo, add.right).await?;
ScalarExpression::BinaryOp(binary_op) => {
let left = match_any_partial_scalar_plan(memo, binary_op.left).await?;
let right = match_any_partial_scalar_plan(memo, binary_op.right).await?;
Ok(Arc::new(PartialScalarPlan::PartialMaterialized {
operator: ScalarOperator::Add(Add { left, right }),
operator: ScalarOperator::BinaryOp(BinaryOp::new(
binary_op.kind.clone(),
left,
right,
)),
}))
}
ScalarExpression::Equal(equal) => {
let left = match_any_partial_scalar_plan(memo, equal.left).await?;
let right = match_any_partial_scalar_plan(memo, equal.right).await?;
ScalarExpression::UnaryOp(unary_op) => {
let child = match_any_partial_scalar_plan(memo, unary_op.child).await?;
Ok(Arc::new(PartialScalarPlan::PartialMaterialized {
operator: ScalarOperator::Equal(Equal { left, right }),
operator: ScalarOperator::UnaryOp(UnaryOp::new(unary_op.kind.clone(), child)),
}))
}
ScalarExpression::And(and) => {
let left = match_any_partial_scalar_plan(memo, and.left).await?;
let right = match_any_partial_scalar_plan(memo, and.right).await?;
ScalarExpression::LogicOp(logic) => {
let mut children = Vec::with_capacity(logic.children.len());
for child in logic.children.iter() {
children.push(match_any_partial_scalar_plan(memo, *child).await?);
}
Ok(Arc::new(PartialScalarPlan::PartialMaterialized {
operator: ScalarOperator::And(And { left, right }),
operator: ScalarOperator::LogicOp(LogicOp::new(logic.kind.clone(), children)),
}))
}
}
Expand All @@ -392,25 +397,30 @@ async fn match_any_scalar_plan(
ScalarExpression::ColumnRef(column_ref) => Ok(Arc::new(ScalarPlan {
operator: ScalarOperator::ColumnRef(column_ref.clone()),
})),
ScalarExpression::Add(add) => {
let left = match_any_scalar_plan(memo, add.left).await?;
let right = match_any_scalar_plan(memo, add.right).await?;
ScalarExpression::BinaryOp(binary_op) => {
let left = match_any_scalar_plan(memo, binary_op.left).await?;
let right = match_any_scalar_plan(memo, binary_op.right).await?;
Ok(Arc::new(ScalarPlan {
operator: ScalarOperator::Add(Add { left, right }),
operator: ScalarOperator::BinaryOp(BinaryOp::new(
binary_op.kind.clone(),
left,
right,
)),
}))
}
ScalarExpression::Equal(equal) => {
let left = match_any_scalar_plan(memo, equal.left).await?;
let right = match_any_scalar_plan(memo, equal.right).await?;
ScalarExpression::UnaryOp(unary_op) => {
let child = match_any_scalar_plan(memo, unary_op.child).await?;
Ok(Arc::new(ScalarPlan {
operator: ScalarOperator::Equal(Equal { left, right }),
operator: ScalarOperator::UnaryOp(UnaryOp::new(unary_op.kind.clone(), child)),
}))
}
ScalarExpression::And(and) => {
let left = match_any_scalar_plan(memo, and.left).await?;
let right = match_any_scalar_plan(memo, and.right).await?;
ScalarExpression::LogicOp(logic_op) => {
let mut children = Vec::with_capacity(logic_op.children.len());
for child in logic_op.children.iter() {
children.push(match_any_scalar_plan(memo, *child).await?);
}
Ok(Arc::new(ScalarPlan {
operator: ScalarOperator::And(And { left, right }),
operator: ScalarOperator::LogicOp(LogicOp::new(logic_op.kind.clone(), children)),
}))
}
}
Expand Down Expand Up @@ -470,7 +480,7 @@ mod tests {
// select * from t1 where t1.id = 1 and t1.name = 'Memo';
let logical_plan = filter(
scan("t1", boolean(true)),
and(boolean(true), equal(column_ref(2), string("Memo"))),
and(vec![boolean(true), equal(column_ref(2), string("Memo"))]),
);

let group_id = ingest_partial_logical_plan(&memo, &logical_plan).await?;
Expand Down Expand Up @@ -508,8 +518,8 @@ mod tests {

// select * from t1 where t1.#0 = 1 and true;
let logical_plan = filter(
scan("t1", boolean(true)),
and(equal(column_ref(0), int64(1)), boolean(true)),
scan("t1", or(vec![boolean(true), boolean(false)])),
and(vec![equal(column_ref(0), int64(1)), boolean(true)]),
);
let group_id = ingest_partial_logical_plan(&memo, &logical_plan).await?;

Expand All @@ -522,8 +532,8 @@ mod tests {
assert_eq!(
physical_plan,
physical_filter(
table_scan("t1", boolean(true)),
and(equal(column_ref(0), int64(1)), boolean(true))
table_scan("t1", or(vec![boolean(true), boolean(false)])),
and(vec![equal(column_ref(0), int64(1)), boolean(true)])
)
);

Expand All @@ -534,8 +544,8 @@ mod tests {
async fn test_join_e2e() -> anyhow::Result<()> {
let memo = SqliteMemo::new_in_memory().await?;

// select * from t1 where t1.#0 = 1 and true;
let scan_t1 = scan("t1", boolean(true));
// select * from t1 where t1.#0 = 1 and NOT false;
let scan_t1 = scan("t1", not(boolean(false)));
let logical_plan = join(
"inner",
scan_t1.clone(),
Expand All @@ -550,7 +560,7 @@ mod tests {
mock_optimize_relation_group(&memo, group_id).await?;
let physical_plan = match_any_partial_physical_plan(&memo, group_id).await?;

let table_scan_t1 = table_scan("t1", boolean(true));
let table_scan_t1 = table_scan("t1", not(boolean(false)));
assert_eq!(
physical_plan,
nested_loop_join(
Expand All @@ -568,10 +578,13 @@ mod tests {
async fn test_project_e2e() -> anyhow::Result<()> {
let memo = SqliteMemo::new_in_memory().await?;

// select t1.#0, t1.#1 + 1 from t1;
// select t1.#0, (t1.#1 + 1) - (-3) from t1;
let logical_plan = project(
scan("t1", boolean(true)),
vec![column_ref(0), add(column_ref(1), int64(1))],
vec![
column_ref(0),
minus(add(column_ref(1), int64(1)), neg(int64(3))),
],
);
let group_id = ingest_partial_logical_plan(&memo, &logical_plan).await?;

Expand All @@ -585,7 +598,10 @@ mod tests {
physical_plan,
physical_project(
table_scan("t1", boolean(true)),
vec![column_ref(0), add(column_ref(1), int64(1))],
vec![
column_ref(0),
minus(add(column_ref(1), int64(1)), neg(int64(3))),
],
)
);

Expand Down
25 changes: 0 additions & 25 deletions optd-core/src/operators/scalar/add.rs

This file was deleted.

28 changes: 0 additions & 28 deletions optd-core/src/operators/scalar/and.rs

This file was deleted.

43 changes: 43 additions & 0 deletions optd-core/src/operators/scalar/binary_op.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//! A scalar binary operator.
use crate::{operators::scalar::ScalarOperator, values::OptdValue};
use serde::Deserialize;

/// A scalar operator that performs a binary operation on two values.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct BinaryOp<Value, Scalar> {
/// The kind of operator.
pub kind: Value,
/// The left operand.
pub left: Scalar,
/// The right operand.
pub right: Scalar,
}

impl<Value, Scalar> BinaryOp<Value, Scalar> {
/// Create a new addition operator.
pub fn new(kind: Value, left: Scalar, right: Scalar) -> Self {
Self { kind, left, right }
}
}

/// Creates an addition scalar operator.
pub fn add<Scalar>(left: Scalar, right: Scalar) -> ScalarOperator<OptdValue, Scalar> {
ScalarOperator::BinaryOp(BinaryOp::new(OptdValue::String("add".into()), left, right))
}

pub fn minus<Scalar>(left: Scalar, right: Scalar) -> ScalarOperator<OptdValue, Scalar> {
ScalarOperator::BinaryOp(BinaryOp::new(
OptdValue::String("minus".into()),
left,
right,
))
}

/// Creates an equality scalar operator.
pub fn equal<Scalar>(left: Scalar, right: Scalar) -> ScalarOperator<OptdValue, Scalar> {
ScalarOperator::BinaryOp(BinaryOp::new(
OptdValue::String("equal".into()),
left,
right,
))
}
16 changes: 13 additions & 3 deletions optd-core/src/operators/scalar/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,17 @@ impl Constant<OptdValue> {
}
}

/// Creates a constant scalar operator.
pub fn constant<Scalar>(value: OptdValue) -> ScalarOperator<OptdValue, Scalar> {
ScalarOperator::Constant(Constant::new(value))
/// Creates a boolean constant scalar operator.
pub fn boolean<Scalar>(value: bool) -> ScalarOperator<OptdValue, Scalar> {
ScalarOperator::Constant(Constant::new(OptdValue::Bool(value)))
}

/// Creates an `int64` constant scalar operator.
pub fn int64<Scalar>(value: bool) -> ScalarOperator<OptdValue, Scalar> {
ScalarOperator::Constant(Constant::new(OptdValue::Bool(value)))
}

/// Creates a string constant scalar operator.
pub fn string<Scalar>(value: &str) -> ScalarOperator<OptdValue, Scalar> {
ScalarOperator::Constant(Constant::new(OptdValue::String(value.into())))
}
25 changes: 0 additions & 25 deletions optd-core/src/operators/scalar/equal.rs

This file was deleted.

33 changes: 33 additions & 0 deletions optd-core/src/operators/scalar/logic_op.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
//! A scalar logic operator.
use serde::Deserialize;

use crate::values::OptdValue;

use super::ScalarOperator;

/// A scalar operator that adds two values.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct LogicOp<Value, Scalar> {
/// The kind of logic operator.
pub kind: Value,
/// The operands to the logic operator.
pub children: Vec<Scalar>,
}

impl<Value, Scalar> LogicOp<Value, Scalar> {
/// Create a new logic scalar operator.
pub fn new(kind: Value, children: Vec<Scalar>) -> Self {
Self { kind, children }
}
}

/// Creates an `and` logic scalar operator.
pub fn and<Scalar>(children: Vec<Scalar>) -> ScalarOperator<OptdValue, Scalar> {
ScalarOperator::LogicOp(LogicOp::new(OptdValue::String("and".into()), children))
}

/// Creates an `and` logic scalar operator.
pub fn or<Scalar>(children: Vec<Scalar>) -> ScalarOperator<OptdValue, Scalar> {
ScalarOperator::LogicOp(LogicOp::new(OptdValue::String("or".into()), children))
}
Loading

0 comments on commit 05a856a

Please sign in to comment.