feat: DSL parser + AST implementation (#23)
## Problem

We want a DSL so that rules & operators can be written in a declarative
fashion. This makes the code more maintainable, maximizes compatibility,
and speeds up the writing of new rules.

## Summary of changes

Wrote a parser for the OPTD-DSL using Pest. The syntax is highly
functional and inspired by Scala. Some snippets are shown below.

## Next steps

Semantic analysis, type checking, and low-level IR generation.
Also switching to Chumsky as the parser, since Pest is not very
maintainable.
 
```scala
// Logical Properties
Logical Props(schema_len: Int64)

// Scalar Operators
Scalar ColumnRef(idx: Int64)
Scalar Mult(left: Int64, right: Int64)
Scalar Add(left: Int64, right: Int64)
Scalar And(children: [Scalar])
Scalar Or(children: [Scalar])
Scalar Not(child: Scalar)

// Logical Operators
Logical Scan(table_name: String) derive {
  schema_len = 5 // <call catalog>
}

Logical Filter(child: Logical, cond: Scalar) derive {
  schema_len = input.schema_len
}

Logical Project(child: Logical, exprs: [Scalar]) derive {
  schema_len = exprs.len
}

Logical Join(
    left: Logical,
    right: Logical,
    typ: String,
    cond: Scalar
) derive {
  schema_len = left.schema_len + right.schema_len
}

Logical Sort(child: Logical, keys: [(Scalar, String)]) derive {
  schema_len = input.schema_len
}

Logical Aggregate(child: Logical, group_keys: [Scalar], aggs: [(Scalar, String)]) 
  derive {
    schema_len = group_keys.len + aggs.len
}

// Rules
def rewrite_column_refs(predicate: Scalar, map: Map[Int64, Int64]): Scalar =
  match predicate
    case ColumnRef(idx) => ColumnRef(map(idx)),
    case other @ _ => predicate.apply_children(child => rewrite_column_refs(child, map))

@rule(Logical)
def join_commute(expr: Logical): Logical =
  match expr
    case Join("Inner", left, right, cond) =>
      val left_len = left.schema_len;
      val right_len = right.schema_len;
      
      val right_indices = 0..right_len;
      val left_indices = 0..left_len;
      
      val remapping = (left_indices.map(i => (i, i + right_len)) ++
        right_indices.map(i => (left_len + i, i))).to_map();

      Project(
        Join("Inner", right, left, rewrite_column_refs(cond, remapping)),
        left_indices.map(i => ColumnRef(i + right_len)) ++
          right_indices.map(i => ColumnRef(i))
      )

def has_refs_in_range(cond: Scalar, from: Int64, to: Int64): Bool =
  match cond
    case ColumnRef(idx) => from <= idx && idx < to,
    case _ => cond.children.any(child => has_refs_in_range(child, from, to))
  

@rule(Logical)
def join_associate(expr: Logical): Logical =
  match expr
    case op @ Join("Inner", Join("Inner", a, b, cond_inner), c, cond_outer) =>
      val a_len = a.schema_len;
      if !has_refs_in_range(cond_outer, 0, a_len) then
        val remap_inner = (a.schema_len..op.schema_len).map(i => (i, i - a_len)).to_map();
        Join(
          "Inner", a,
          Join("Inner", b, c, rewrite_column_refs(cond_outer, remap_inner)),
          cond_inner
        )
      else fail("")

@rule(Scalar)
def conjunctive_normal_form(expr: Scalar): Scalar = fail("unimplemented")

def with_optional_filter(key: String, old: Logical, grouped: Map[String, [Scalar]]): Logical = 
  match grouped(key)
     case Some(conds) => Filter(old, And(conds)),
     case _ => old

@rule(Logical)
def filter_pushdown_join(expr: Logical): Logical =
  match expr
    case op @ Filter(Join(join_type, left, right, join_cond), cond) =>
      val cnf = conjunctive_normal_form(cond);
      val grouped = cnf.children.groupBy(cond => {
        if has_refs_in_range(cond, 0, left.schema_len) && 
           !has_refs_in_range(cond, left.schema_len, op.schema_len) then
          "left"
        else if !has_refs_in_range(cond, 0, left.schema_len) && 
                has_refs_in_range(cond, left.schema_len, op.schema_len) then
          "right"
        else 
          "remain"
      });
 
      with_optional_filter("remain", 
        Join(join_type, 
          with_optional_filter("left", grouped), 
          with_optional_filter("right", grouped), 
          join_cond
        )
      )
```
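For orientation, here is a minimal sketch of what a Pest-based entry point for this DSL might look like. The grammar path, the `DslParser` name, and the `Rule::file` start rule are assumptions for illustration only; the actual names live in the parser module added by this commit.

```rust
use pest::Parser;
use pest_derive::Parser;

// Hypothetical grammar location, relative to the crate's src/ directory.
#[derive(Parser)]
#[grammar = "dsl/parser/grammar.pest"]
struct DslParser;

/// Parse a DSL source string and print the top-level rules that matched.
/// Each top-level pair (properties block, operator, or function) would then
/// be lowered into the AST types shown further below.
fn parse_source(source: &str) -> Result<(), pest::error::Error<Rule>> {
    let pairs = DslParser::parse(Rule::file, source)?;
    for pair in pairs {
        println!("{:?}", pair.as_rule());
    }
    Ok(())
}
```

Semantic analysis and type checking (listed under next steps) would walk these pairs and build the `ast::File` structure defined in `optd-core/src/dsl/parser/ast.rs`.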

---------

Signed-off-by: Yuchen Liang <[email protected]>
Co-authored-by: Yuchen Liang <[email protected]>
Co-authored-by: Yuchen Liang <[email protected]>
3 people authored Feb 11, 2025
1 parent a94829a commit baa09fc
Showing 15 changed files with 2,100 additions and 1 deletion.
53 changes: 53 additions & 0 deletions Cargo.lock


2 changes: 2 additions & 0 deletions optd-core/Cargo.toml
@@ -15,3 +15,5 @@ serde = { version = "1.0", features = ["derive"] }
 serde_json = { version = "1", features = ["raw_value"] }
 dotenvy = "0.15"
 async-recursion = "1.1.1"
+pest = "2.7.15"
+pest_derive = "2.7.15"
2 changes: 1 addition & 1 deletion optd-core/src/cascades/mod.rs
@@ -147,7 +147,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_ingest_partial_logical_plan() -> anyhow::Result<()> {
-        let memo = SqliteMemo::new("sqlite://memo.db").await?;
+        let memo = SqliteMemo::new_in_memory().await?;
         // select * from t1, t2 where t1.id = t2.id and t2.name = 'Memo' and t2.v1 = 1 + 1
         let partial_logical_plan = filter(
             join(
1 change: 1 addition & 0 deletions optd-core/src/dsl/mod.rs
@@ -0,0 +1 @@
pub mod parser;
151 changes: 151 additions & 0 deletions optd-core/src/dsl/parser/ast.rs
@@ -0,0 +1,151 @@
use std::collections::HashMap;

/// Types supported by the language
#[derive(Debug, Clone, PartialEq)]
pub enum Type {
    Int64,
    String,
    Bool,
    Float64,
    Array(Box<Type>),               // Array types like [T]
    Map(Box<Type>, Box<Type>),      // Map types like map[K->V]
    Tuple(Vec<Type>),               // Tuple types like (T1, T2)
    Function(Box<Type>, Box<Type>), // Function types like (T1)->T2
    Operator(OperatorKind),         // Operator types (scalar/logical)
}

/// Kinds of operators supported in the language
#[derive(Debug, Clone, PartialEq)]
pub enum OperatorKind {
    Scalar,  // Scalar operators
    Logical, // Logical operators with derivable properties
}

/// A field in an operator or properties block
#[derive(Debug, Clone)]
pub struct Field {
    pub name: String,
    pub ty: Type,
}

/// Logical properties block that must appear exactly once per file
#[derive(Debug, Clone)]
pub struct Properties {
    pub fields: Vec<Field>,
}

/// Top-level operator definition
#[derive(Debug, Clone)]
pub enum Operator {
    Scalar(ScalarOp),
    Logical(LogicalOp),
}

/// Scalar operator definition
#[derive(Debug, Clone)]
pub struct ScalarOp {
    pub name: String,
    pub fields: Vec<Field>,
}

/// Logical operator definition with derived properties
#[derive(Debug, Clone)]
pub struct LogicalOp {
    pub name: String,
    pub fields: Vec<Field>,
    pub derived_props: HashMap<String, Expr>, // Maps property names to their derivation expressions
}

/// Patterns used in match expressions
#[derive(Debug, Clone)]
pub enum Pattern {
    Bind(String, Box<Pattern>), // Binding patterns like x@p or x:p
    Constructor(
        String,       // Constructor name
        Vec<Pattern>, // Subpatterns, can be named (x:p) or positional
    ),
    Literal(Literal), // Literal patterns like 42 or "hello"
    Wildcard,         // Wildcard pattern _
    Var(String),      // Variable binding pattern
}

/// Literal values
#[derive(Debug, Clone)]
pub enum Literal {
    Int64(i64),
    String(String),
    Bool(bool),
    Float64(f64),
    Array(Vec<Expr>), // Array literals [e1, e2, ...]
    Tuple(Vec<Expr>), // Tuple literals (e1, e2, ...)
}

/// Expressions - the core of the language
#[derive(Debug, Clone)]
pub enum Expr {
    Match(Box<Expr>, Vec<MatchArm>),          // Pattern matching
    If(Box<Expr>, Box<Expr>, Box<Expr>),      // If-then-else
    Val(String, Box<Expr>, Box<Expr>),        // Local binding (val x = e1; e2)
    Constructor(String, Vec<Expr>),           // Constructor application (currently only operators)
    Binary(Box<Expr>, BinOp, Box<Expr>),      // Binary operations
    Unary(UnaryOp, Box<Expr>),                // Unary operations
    Call(Box<Expr>, Vec<Expr>),               // Function application
    Member(Box<Expr>, String),                // Field access (e.f)
    MemberCall(Box<Expr>, String, Vec<Expr>), // Method call (e.f(args))
    ArrayIndex(Box<Expr>, Box<Expr>),         // Array indexing (e[i])
    Var(String),                              // Variable reference
    Literal(Literal),                         // Literal values
    Fail(String),                             // Failure with message
    Closure(Vec<String>, Box<Expr>),          // Anonymous functions v = (x, y) => x + y;
}

/// A case in a match expression
#[derive(Debug, Clone)]
pub struct MatchArm {
    pub pattern: Pattern,
    pub expr: Expr,
}

/// Binary operators with fixed precedence
#[derive(Debug, Clone)]
pub enum BinOp {
    Add,    // +
    Sub,    // -
    Mul,    // *
    Div,    // /
    Concat, // ++
    Eq,     // ==
    Neq,    // !=
    Gt,     // >
    Lt,     // <
    Ge,     // >=
    Le,     // <=
    And,    // &&
    Or,     // ||
    Range,  // ..
}

/// Unary operators
#[derive(Debug, Clone)]
pub enum UnaryOp {
    Neg, // -
    Not, // !
}

/// Function definition
#[derive(Debug, Clone)]
pub struct Function {
    pub name: String,
    pub params: Vec<(String, Type)>, // Parameter name and type pairs
    pub return_type: Type,
    pub body: Expr,
    pub rule_type: Option<OperatorKind>, // Some if this is a rule, indicating what kind
}

/// A complete source file
#[derive(Debug, Clone)]
pub struct File {
    pub properties: Properties,   // The single logical properties block
    pub operators: Vec<Operator>, // All operator definitions
    pub functions: Vec<Function>, // All function definitions
}
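To illustrate how these AST types compose, here is a small hand-built fragment (not part of the commit). It encodes the `Scalar Add(left: Int64, right: Int64)` operator from the snippet above plus a purely hypothetical `double` function, just to show `Operator`, `Field`, `Function`, `Expr`, and `Type` in use.

```rust
use crate::dsl::parser::ast::*;

// Hand-constructed AST values, assuming the parser produces these exact shapes.
fn example_ast() -> (Operator, Function) {
    // Scalar Add(left: Int64, right: Int64)
    let add_op = Operator::Scalar(ScalarOp {
        name: "Add".to_string(),
        fields: vec![
            Field { name: "left".to_string(), ty: Type::Int64 },
            Field { name: "right".to_string(), ty: Type::Int64 },
        ],
    });

    // A nested type such as Map[Int64, Int64] becomes:
    let _map_ty = Type::Map(Box::new(Type::Int64), Box::new(Type::Int64));

    // Hypothetical helper: def double(x: Int64): Int64 = x + x
    let double_fn = Function {
        name: "double".to_string(),
        params: vec![("x".to_string(), Type::Int64)],
        return_type: Type::Int64,
        body: Expr::Binary(
            Box::new(Expr::Var("x".to_string())),
            BinOp::Add,
            Box::new(Expr::Var("x".to_string())),
        ),
        rule_type: None, // would be Some(OperatorKind::Logical) for an `@rule(Logical)` definition
    };

    (add_op, double_fn)
}
```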
