Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: DSL parser + AST implementation #23

Merged
merged 66 commits into from
Feb 11, 2025
Merged

feat: DSL parser + AST implementation #23

merged 66 commits into from
Feb 11, 2025

Conversation

AlSchlo
Copy link
Collaborator

@AlSchlo AlSchlo commented Feb 9, 2025

Problem

We want a DSL to be able to write rules & operators in a declarative fashion. This makes the code more maintainable, maximizes compatibility, and speeds up the writing of newer rules.

Summary of changes

Wrote a parser of the OPTD-DSL using Pest. The syntax is highly functional and inspired from Scala. Some snippets below:

Next steps

Semantic analysis, type checking, and low-level IR generation.
Also switching to Chumsky as parser, since Pest is not very maintainable.

// Logical Properties
Logical Props(schema_len: Int64)

// Scalar Operators
Scalar ColumnRef(idx: Int64)
Scalar Mult(left: Int64, right: Int64)
Scalar Add(left: Int64, right: Int64)
Scalar And(children: [Scalar])
Scalar Or(children: [Scalar])
Scalar Not(child: Scalar)

// Logical Operators
Logical Scan(table_name: String) derive {
  schema_len = 5 // <call catalog>
}

Logical Filter(child: Logical, cond: Scalar) derive {
  schema_len = input.schema_len
}

Logical Project(child: Logical, exprs: [Scalar]) derive {
  schema_len = exprs.len
}

Logical Join(
    left: Logical,
    right: Logical,
    typ: String,
    cond: Scalar
) derive {
  schema_len = left.schema_len + right.schema_len
}

Logical Sort(child: Logical, keys: [(Scalar, String)]) derive {
  schema_len = input.schema_len
}

Logical Aggregate(child: Logical, group_keys: [Scalar], aggs: [(Scalar, String)]) 
  derive {
    schema_len = group_keys.len + aggs.len
}

// Rules
def rewrite_column_refs(predicate: Scalar, map: Map[Int64, Int64]): Scalar =
  match predicate
    case ColumnRef(idx) => ColumnRef(map(idx)),
    case other @ _ => predicate.apply_children(child => rewrite_column_refs(child, map))

@rule(Logical)
def join_commute(expr: Logical): Logical =
  match expr
    case Join("Inner", left, right, cond) =>
      val left_len = left.schema_len;
      val right_len = right.schema_len;
      
      val right_indices = 0..right_len;
      val left_indices = 0..left_len;
      
      val remapping = (left_indices.map(i => (i, i + right_len)) ++
        right_indices.map(i => (left_len + i, i))).to_map();

      Project(
        Join("Inner", right, left, rewrite_column_refs(cond, remapping)),
          left_indices.map(i => ColumnRef(i + right_len)) ++ 
          right_indices.map(i => ColumnRef(i)
        )
      )

def has_refs_in_range(cond: Scalar, from: Int64, to: Int64): Bool =
  match predicate
    case ColumnRef(idx) => from <= idx && idx < to,
    case _ => predicate.children.any(child => has_refs_in_range(child, from, to))
  

@rule(Logical)
def join_associate(expr: Logical): Logical =
  match expr
    case op @ Join("Inner", Join("Inner", a, b, cond_inner), c, cond_outer) =>
      val a_len = a.schema_len;
      if !has_refs_in_range(cond_outer, 0, a_len) then
        val remap_inner = (a.schema_len..op.schema_len).map(i => (i, i - a_len)).to_map();
        Join(
          "Inner", a, 
          Join("Inner", b, c, rewrite_column_refs(cond_outer, remap_inner), 
            cond_inner)
        )
      else fail("")

@rule(Scalar)
def conjunctive_normal_form(expr: Scalar): Scalar = fail("unimplemented")

def with_optional_filter(key: String, old: Scalar, grouped: Map[String, [Scalar]]): Scalar = 
  match grouped(key)
     case Some(conds) => Filter(conds, old),
     case _ => old

@rule(Logical)
def filter_pushdown_join(expr: Logical): Logical =
  match expr
    case op @ Filter(Join(join_type, left, right, join_cond), cond) =>
      val cnf = conjunctive_normal_form(cond);
      val grouped = cnf.children.groupBy(cond => {
        if has_refs_in_range(cond, 0, left.schema_len) && 
           !has_refs_in_range(cond, left.schema_len, op.schema_len) then
          "left"
        else if !has_refs_in_range(cond, 0, left.schema_len) && 
                has_refs_in_range(cond, left.schema_len, op.schema_len) then
          "right"
        else 
          "remain"
      });
 
      with_optional_filter("remain", 
        Join(join_type, 
          with_optional_filter("left", grouped), 
          with_optional_filter("right", grouped), 
          join_cond
        )
      )

@AlSchlo AlSchlo changed the title feat: DSL Parser & AST Implementation feat: DSL parser and AST implementation Feb 9, 2025
@AlSchlo AlSchlo changed the title feat: DSL parser and AST implementation feat: DSL parser + AST implementation Feb 9, 2025
@codecov-commenter
Copy link

codecov-commenter commented Feb 10, 2025

@AlSchlo AlSchlo marked this pull request as ready for review February 10, 2025 01:12
String,
Bool,
Float64,
Array(Box<Type>), // Array types like [T]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: remember to change these to rust comments.

Copy link
Member

@skyzh skyzh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, seems like we can parse the DSL now, and we need some more work to make the interpreter to work with it.

@@ -0,0 +1,258 @@
/*use std::collections::HashSet;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rm unused files?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah damn i pushed name analyzer in here.

@@ -0,0 +1,325 @@
use pest::iterators::Pair;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if this could be simplified with https://github.com/pest-parser/ast in the future, basically it's a conversion from pest::Pair -> AST

Copy link
Member

@skyzh skyzh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⬇️ more reviews on the rule correctness

schema_len = input.schema_len
}

Logical Aggregate(child: Logical, group_keys: [Scalar], aggs: [(Scalar, String)])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why aggs [Scalar, String]? should be something like [Scalar] or [AggFunction]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

string is the aggfunction for now

we don't support UDTs yet

// Rules

// memo(alexis): support member function for operators like apply_children
def rewrite_column_refs(predicate: Scalar, map: Map[Int64, Int64]): Scalar =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need handle the case where map doesn't contain a mapping of a column referenced, throw an error that it can't rewrite in such cases

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i was planning to do that in the map implementation

@rule(Logical)
def join_commute(expr: Logical): Logical =
match expr
case Join("Inner", left, right, cond) =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

grammar wise, is the order required here? i.e., would Join("Inner", right, left, cond) produce a correct match?

proposing syntex,

Join { type: "Inner", left, right, cond }

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes order is required

val left_indices = 0..left_len;

val remapping = (left_indices.map(i => (i, i + right_len)) ++
right_indices.map(i => (left_len + i, i))).to_map();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right_indices don't seem to be set correctly, should be i -> i - left_len?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no this is correct, the pair is reversed


Project(
Join("Inner", right, left, rewrite_column_refs(cond, remapping)),
left_indices.map(i => ColumnRef(i + right_len)) ++
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be right -> right + left_len, left -> left - right_len?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah there seems to be smt wrong here

case op @ Join("Inner", Join("Inner", a, b, cond_inner), c, cond_outer) =>
val a_len = a.schema_len;
if !has_refs_in_range(cond_outer, 0, a_len) then
val remap_inner = (a.schema_len..op.schema_len).map(i => (i, i - a_len)).to_map();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/remap_inner/remap_outer?

@AlSchlo AlSchlo merged commit baa09fc into main Feb 11, 2025
12 checks passed
@AlSchlo AlSchlo deleted the alexis/dsl-compiler branch February 11, 2025 04:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants