Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: DSL parser + AST implementation (#23)
## Problem We want a DSL to be able to write rules & operators in a declarative fashion. This makes the code more maintainable, maximizes compatibility, and speeds up the writing of newer rules. ## Summary of changes Wrote a parser of the OPTD-DSL using Pest. The syntax is highly functional and inspired from Scala. Some snippets below: ## Next steps Semantic analysis, type checking, and low-level IR generation. Also switching to Chumsky as parser, since Pest is not very maintainable. ```scala // Logical Properties Logical Props(schema_len: Int64) // Scalar Operators Scalar ColumnRef(idx: Int64) Scalar Mult(left: Int64, right: Int64) Scalar Add(left: Int64, right: Int64) Scalar And(children: [Scalar]) Scalar Or(children: [Scalar]) Scalar Not(child: Scalar) // Logical Operators Logical Scan(table_name: String) derive { schema_len = 5 // <call catalog> } Logical Filter(child: Logical, cond: Scalar) derive { schema_len = input.schema_len } Logical Project(child: Logical, exprs: [Scalar]) derive { schema_len = exprs.len } Logical Join( left: Logical, right: Logical, typ: String, cond: Scalar ) derive { schema_len = left.schema_len + right.schema_len } Logical Sort(child: Logical, keys: [(Scalar, String)]) derive { schema_len = input.schema_len } Logical Aggregate(child: Logical, group_keys: [Scalar], aggs: [(Scalar, String)]) derive { schema_len = group_keys.len + aggs.len } // Rules def rewrite_column_refs(predicate: Scalar, map: Map[Int64, Int64]): Scalar = match predicate case ColumnRef(idx) => ColumnRef(map(idx)), case other @ _ => predicate.apply_children(child => rewrite_column_refs(child, map)) @rule(Logical) def join_commute(expr: Logical): Logical = match expr case Join("Inner", left, right, cond) => val left_len = left.schema_len; val right_len = right.schema_len; val right_indices = 0..right_len; val left_indices = 0..left_len; val remapping = (left_indices.map(i => (i, i + right_len)) ++ right_indices.map(i => (left_len + i, i))).to_map(); Project( Join("Inner", right, left, rewrite_column_refs(cond, remapping)), left_indices.map(i => ColumnRef(i + right_len)) ++ right_indices.map(i => ColumnRef(i) ) ) def has_refs_in_range(cond: Scalar, from: Int64, to: Int64): Bool = match predicate case ColumnRef(idx) => from <= idx && idx < to, case _ => predicate.children.any(child => has_refs_in_range(child, from, to)) @rule(Logical) def join_associate(expr: Logical): Logical = match expr case op @ Join("Inner", Join("Inner", a, b, cond_inner), c, cond_outer) => val a_len = a.schema_len; if !has_refs_in_range(cond_outer, 0, a_len) then val remap_inner = (a.schema_len..op.schema_len).map(i => (i, i - a_len)).to_map(); Join( "Inner", a, Join("Inner", b, c, rewrite_column_refs(cond_outer, remap_inner), cond_inner) ) else fail("") @rule(Scalar) def conjunctive_normal_form(expr: Scalar): Scalar = fail("unimplemented") def with_optional_filter(key: String, old: Scalar, grouped: Map[String, [Scalar]]): Scalar = match grouped(key) case Some(conds) => Filter(conds, old), case _ => old @rule(Logical) def filter_pushdown_join(expr: Logical): Logical = match expr case op @ Filter(Join(join_type, left, right, join_cond), cond) => val cnf = conjunctive_normal_form(cond); val grouped = cnf.children.groupBy(cond => { if has_refs_in_range(cond, 0, left.schema_len) && !has_refs_in_range(cond, left.schema_len, op.schema_len) then "left" else if !has_refs_in_range(cond, 0, left.schema_len) && has_refs_in_range(cond, left.schema_len, op.schema_len) then "right" else "remain" }); with_optional_filter("remain", Join(join_type, with_optional_filter("left", grouped), with_optional_filter("right", grouped), join_cond ) ) ``` --------- Signed-off-by: Yuchen Liang <[email protected]> Co-authored-by: Yuchen Liang <[email protected]> Co-authored-by: Yuchen Liang <[email protected]>
- Loading branch information