diff --git a/Cargo.lock b/Cargo.lock index 8107755..c7d57fc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -800,6 +800,8 @@ dependencies = [ "anyhow", "async-recursion", "dotenvy", + "pest", + "pest_derive", "proc-macro2", "serde", "serde_json", @@ -852,6 +854,51 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "pest" +version = "2.7.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b7cafe60d6cf8e62e1b9b2ea516a089c008945bb5a275416789e7db0bc199dc" +dependencies = [ + "memchr", + "thiserror", + "ucd-trie", +] + +[[package]] +name = "pest_derive" +version = "2.7.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "816518421cfc6887a0d62bf441b6ffb4536fcc926395a69e1a85852d4363f57e" +dependencies = [ + "pest", + "pest_generator", +] + +[[package]] +name = "pest_generator" +version = "2.7.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d1396fd3a870fc7838768d171b4616d5c91f6cc25e377b673d714567d99377b" +dependencies = [ + "pest", + "pest_meta", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "pest_meta" +version = "2.7.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1e58089ea25d717bfd31fb534e4f3afcc2cc569c70de3e239778991ea3b7dea" +dependencies = [ + "once_cell", + "pest", + "sha2", +] + [[package]] name = "pin-project-lite" version = "0.2.16" @@ -1526,6 +1573,12 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +[[package]] +name = "ucd-trie" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" + [[package]] name = "unicode-bidi" version = "0.3.18" diff --git a/optd-core/Cargo.toml b/optd-core/Cargo.toml index f464e87..94efe51 100644 --- a/optd-core/Cargo.toml +++ b/optd-core/Cargo.toml @@ -15,3 +15,5 @@ serde = { version = "1.0", features = ["derive"] } serde_json = { version = "1", features = ["raw_value"] } dotenvy = "0.15" async-recursion = "1.1.1" +pest = "2.7.15" +pest_derive = "2.7.15" diff --git a/optd-core/src/cascades/mod.rs b/optd-core/src/cascades/mod.rs index 984d487..72ed19a 100644 --- a/optd-core/src/cascades/mod.rs +++ b/optd-core/src/cascades/mod.rs @@ -147,7 +147,7 @@ mod tests { #[tokio::test] async fn test_ingest_partial_logical_plan() -> anyhow::Result<()> { - let memo = SqliteMemo::new("sqlite://memo.db").await?; + let memo = SqliteMemo::new_in_memory().await?; // select * from t1, t2 where t1.id = t2.id and t2.name = 'Memo' and t2.v1 = 1 + 1 let partial_logical_plan = filter( join( diff --git a/optd-core/src/dsl/mod.rs b/optd-core/src/dsl/mod.rs new file mode 100644 index 0000000..67c567f --- /dev/null +++ b/optd-core/src/dsl/mod.rs @@ -0,0 +1 @@ +pub mod parser; diff --git a/optd-core/src/dsl/parser/ast.rs b/optd-core/src/dsl/parser/ast.rs new file mode 100644 index 0000000..609f2ad --- /dev/null +++ b/optd-core/src/dsl/parser/ast.rs @@ -0,0 +1,151 @@ +use std::collections::HashMap; + +/// Types supported by the language +#[derive(Debug, Clone, PartialEq)] +pub enum Type { + Int64, + String, + Bool, + Float64, + Array(Box), // Array types like [T] + Map(Box, Box), // Map types like map[K->V] + Tuple(Vec), // Tuple types like (T1, T2) + Function(Box, Box), // Function types like (T1)->T2 + Operator(OperatorKind), // Operator types (scalar/logical) +} + +/// Kinds of operators supported in the language +#[derive(Debug, Clone, PartialEq)] +pub enum OperatorKind { + Scalar, // Scalar operators + Logical, // Logical operators with derivable properties +} + +/// A field in an operator or properties block +#[derive(Debug, Clone)] +pub struct Field { + pub name: String, + pub ty: Type, +} + +/// Logical properties block that must appear exactly once per file +#[derive(Debug, Clone)] +pub struct Properties { + pub fields: Vec, +} + +/// Top-level operator definition +#[derive(Debug, Clone)] +pub enum Operator { + Scalar(ScalarOp), + Logical(LogicalOp), +} + +/// Scalar operator definition +#[derive(Debug, Clone)] +pub struct ScalarOp { + pub name: String, + pub fields: Vec, +} + +/// Logical operator definition with derived properties +#[derive(Debug, Clone)] +pub struct LogicalOp { + pub name: String, + pub fields: Vec, + pub derived_props: HashMap, // Maps property names to their derivation expressions +} + +/// Patterns used in match expressions +#[derive(Debug, Clone)] +pub enum Pattern { + Bind(String, Box), // Binding patterns like x@p or x:p + Constructor( + String, // Constructor name + Vec, // Subpatterns, can be named (x:p) or positional + ), + Literal(Literal), // Literal patterns like 42 or "hello" + Wildcard, // Wildcard pattern _ + Var(String), // Variable binding pattern +} + +/// Literal values +#[derive(Debug, Clone)] +pub enum Literal { + Int64(i64), + String(String), + Bool(bool), + Float64(f64), + Array(Vec), // Array literals [e1, e2, ...] + Tuple(Vec), // Tuple literals (e1, e2, ...) +} + +/// Expressions - the core of the language +#[derive(Debug, Clone)] +pub enum Expr { + Match(Box, Vec), // Pattern matching + If(Box, Box, Box), // If-then-else + Val(String, Box, Box), // Local binding (val x = e1; e2) + Constructor(String, Vec), // Constructor application (currently only operators) + Binary(Box, BinOp, Box), // Binary operations + Unary(UnaryOp, Box), // Unary operations + Call(Box, Vec), // Function application + Member(Box, String), // Field access (e.f) + MemberCall(Box, String, Vec), // Method call (e.f(args)) + ArrayIndex(Box, Box), // Array indexing (e[i]) + Var(String), // Variable reference + Literal(Literal), // Literal values + Fail(String), // Failure with message + Closure(Vec, Box), // Anonymous functions v = (x, y) => x + y; +} + +/// A case in a match expression +#[derive(Debug, Clone)] +pub struct MatchArm { + pub pattern: Pattern, + pub expr: Expr, +} + +/// Binary operators with fixed precedence +#[derive(Debug, Clone)] +pub enum BinOp { + Add, // + + Sub, // - + Mul, // * + Div, // / + Concat, // ++ + Eq, // == + Neq, // != + Gt, // > + Lt, // < + Ge, // >= + Le, // <= + And, // && + Or, // || + Range, // .. +} + +/// Unary operators +#[derive(Debug, Clone)] +pub enum UnaryOp { + Neg, // - + Not, // ! +} + +/// Function definition +#[derive(Debug, Clone)] +pub struct Function { + pub name: String, + pub params: Vec<(String, Type)>, // Parameter name and type pairs + pub return_type: Type, + pub body: Expr, + pub rule_type: Option, // Some if this is a rule, indicating what kind +} + +/// A complete source file +#[derive(Debug, Clone)] +pub struct File { + pub properties: Properties, // The single logical properties block + pub operators: Vec, // All operator definitions + pub functions: Vec, // All function definitions +} diff --git a/optd-core/src/dsl/parser/expr.rs b/optd-core/src/dsl/parser/expr.rs new file mode 100644 index 0000000..e15907d --- /dev/null +++ b/optd-core/src/dsl/parser/expr.rs @@ -0,0 +1,325 @@ +use pest::iterators::Pair; + +use super::{ + ast::{BinOp, Expr, Literal, MatchArm, UnaryOp}, + patterns::parse_pattern, + Rule, +}; + +/// Parse a complete expression from a pest Pair +/// +/// # Arguments +/// * `pair` - The pest Pair containing the expression +/// +/// # Returns +/// * `Expr` - The parsed expression AST node +pub fn parse_expr(pair: Pair<'_, Rule>) -> Expr { + match pair.as_rule() { + Rule::expr => parse_expr(pair.into_inner().next().unwrap()), + Rule::closure => parse_closure(pair), + Rule::logical_or + | Rule::logical_and + | Rule::comparison + | Rule::concatenation + | Rule::additive + | Rule::range + | Rule::multiplicative => parse_binary_operation(pair), + Rule::postfix => parse_postfix(pair), + Rule::prefix => parse_prefix(pair), + Rule::match_expr => parse_match_expr(pair), + Rule::if_expr => parse_if_expr(pair), + Rule::val_expr => parse_val_expr(pair), + Rule::array_literal => parse_array_literal(pair), + Rule::tuple_literal => parse_tuple_literal(pair), + Rule::constructor_expr => parse_constructor(pair), + Rule::number => parse_number(pair), + Rule::string => parse_string(pair), + Rule::boolean => parse_boolean(pair), + Rule::identifier => Expr::Var(pair.as_str().to_string()), + Rule::fail_expr => parse_fail_expr(pair), + _ => unreachable!("Unexpected expression rule: {:?}", pair.as_rule()), + } +} + +/// Parse a closure expression (e.g., "(x, y) => x + y") +fn parse_closure(pair: Pair<'_, Rule>) -> Expr { + let mut pairs = pair.into_inner(); + let params_pair = pairs.next().unwrap(); + + let params = if params_pair.as_rule() == Rule::identifier { + vec![params_pair.as_str().to_string()] + } else { + params_pair + .into_inner() + .map(|p| p.as_str().to_string()) + .collect() + }; + + let body = parse_expr(pairs.next().unwrap()); + Expr::Closure(params, Box::new(body)) +} + +/// Parse a binary operation with proper operator precedence +fn parse_binary_operation(pair: Pair<'_, Rule>) -> Expr { + let mut pairs = pair.into_inner(); + let mut expr = parse_expr(pairs.next().unwrap()); + + while let Some(op_pair) = pairs.next() { + let op = parse_binary_operator(op_pair); + let rhs = parse_expr(pairs.next().unwrap()); + expr = Expr::Binary(Box::new(expr), op, Box::new(rhs)); + } + + expr +} + +/// Parse a postfix expression (function calls, member access, array indexing) +fn parse_postfix(pair: Pair<'_, Rule>) -> Expr { + let mut pairs = pair.into_inner(); + let mut expr = parse_expr(pairs.next().unwrap()); + + for postfix_pair in pairs { + match postfix_pair.as_rule() { + Rule::call => { + let args = postfix_pair.into_inner().map(parse_expr).collect(); + expr = Expr::Call(Box::new(expr), args); + } + Rule::member_access => { + let member = postfix_pair.as_str().trim_start_matches('.').to_string(); + expr = Expr::Member(Box::new(expr), member); + } + Rule::array_index => { + let index = parse_expr(postfix_pair.into_inner().next().unwrap()); + expr = Expr::ArrayIndex(Box::new(expr), Box::new(index)); + } + Rule::member_call => { + let mut pairs = postfix_pair.into_inner(); + let member = pairs + .next() + .unwrap() + .as_str() + .trim_start_matches('.') + .to_string(); + let args = pairs.map(parse_expr).collect(); + expr = Expr::MemberCall(Box::new(expr), member, args); + } + _ => unreachable!("Unexpected postfix rule: {:?}", postfix_pair.as_rule()), + } + } + expr +} + +/// Parse a prefix expression (unary operators) +fn parse_prefix(pair: Pair<'_, Rule>) -> Expr { + let mut pairs = pair.into_inner(); + let first = pairs.next().unwrap(); + + if first.as_str() == "!" || first.as_str() == "-" { + let op = match first.as_str() { + "-" => UnaryOp::Neg, + "!" => UnaryOp::Not, + _ => unreachable!("Unexpected prefix operator: {}", first.as_str()), + }; + let expr = parse_expr(pairs.next().unwrap()); + Expr::Unary(op, Box::new(expr)) + } else { + parse_expr(first) + } +} + +/// Parse a match expression +fn parse_match_expr(pair: Pair<'_, Rule>) -> Expr { + let mut pairs = pair.into_inner(); + let expr = parse_expr(pairs.next().unwrap()); + let mut arms = Vec::new(); + + for arm_pair in pairs { + if arm_pair.as_rule() == Rule::match_arm { + arms.push(parse_match_arm(arm_pair)); + } + } + + Expr::Match(Box::new(expr), arms) +} + +/// Parse a match arm within a match expression +fn parse_match_arm(pair: Pair<'_, Rule>) -> MatchArm { + let mut pairs = pair.into_inner(); + let pattern = parse_pattern(pairs.next().unwrap()); + let expr = parse_expr(pairs.next().unwrap()); + MatchArm { pattern, expr } +} + +/// Parse an if expression +fn parse_if_expr(pair: Pair<'_, Rule>) -> Expr { + let mut pairs = pair.into_inner(); + let condition = parse_expr(pairs.next().unwrap()); + let then_branch = parse_expr(pairs.next().unwrap()); + let else_branch = parse_expr(pairs.next().unwrap()); + + Expr::If( + Box::new(condition), + Box::new(then_branch), + Box::new(else_branch), + ) +} + +/// Parse a val expression (local binding) +fn parse_val_expr(pair: Pair<'_, Rule>) -> Expr { + let mut pairs = pair.into_inner(); + let name = pairs.next().unwrap().as_str().to_string(); + let value = parse_expr(pairs.next().unwrap()); + let body = parse_expr(pairs.next().unwrap()); + + Expr::Val(name, Box::new(value), Box::new(body)) +} + +/// Parse an array literal +fn parse_array_literal(pair: Pair<'_, Rule>) -> Expr { + let exprs = pair.into_inner().map(parse_expr).collect(); + Expr::Literal(Literal::Array(exprs)) +} + +/// Parse a tuple literal +fn parse_tuple_literal(pair: Pair<'_, Rule>) -> Expr { + let exprs = pair.into_inner().map(parse_expr).collect(); + Expr::Literal(Literal::Tuple(exprs)) +} + +/// Parse a constructor expression +fn parse_constructor(pair: Pair<'_, Rule>) -> Expr { + let mut pairs = pair.into_inner(); + let name = pairs.next().unwrap().as_str().to_string(); + let args = pairs.map(parse_expr).collect(); + + // Check if first character is lowercase + // TODO(alexis): small hack until I rewrite the grammar using Chumsky + if name.chars().next().is_some_and(|c| c.is_ascii_lowercase()) { + Expr::Call(Box::new(Expr::Var(name)), args) + } else { + Expr::Constructor(name, args) + } +} + +/// Parse a numeric literal +fn parse_number(pair: Pair<'_, Rule>) -> Expr { + let num = pair.as_str().parse().unwrap(); + Expr::Literal(Literal::Int64(num)) +} + +/// Parse a string literal +fn parse_string(pair: Pair<'_, Rule>) -> Expr { + let s = pair.as_str().to_string(); + Expr::Literal(Literal::String(s)) +} + +/// Parse a boolean literal +fn parse_boolean(pair: Pair<'_, Rule>) -> Expr { + let b = pair.as_str().parse().unwrap(); + Expr::Literal(Literal::Bool(b)) +} + +/// Parse a fail expression +fn parse_fail_expr(pair: Pair<'_, Rule>) -> Expr { + let msg = pair.into_inner().next().unwrap().as_str().to_string(); + Expr::Fail(msg) +} + +/// Parse a binary operator +fn parse_binary_operator(pair: Pair<'_, Rule>) -> BinOp { + let op_str = match pair.as_rule() { + Rule::add_op + | Rule::mult_op + | Rule::or_op + | Rule::and_op + | Rule::compare_op + | Rule::concat_op + | Rule::range_op => pair.as_str(), + _ => pair.as_str(), + }; + + match op_str { + "+" => BinOp::Add, + "-" => BinOp::Sub, + "*" => BinOp::Mul, + "/" => BinOp::Div, + "++" => BinOp::Concat, + "==" => BinOp::Eq, + "!=" => BinOp::Neq, + ">" => BinOp::Gt, + "<" => BinOp::Lt, + ">=" => BinOp::Ge, + "<=" => BinOp::Le, + "&&" => BinOp::And, + "||" => BinOp::Or, + ".." => BinOp::Range, + _ => unreachable!("Unexpected binary operator: {}", op_str), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::dsl::parser::DslParser; + use pest::Parser; + + fn parse_expr_from_str(input: &str) -> Expr { + let pair = DslParser::parse(Rule::expr, input).unwrap().next().unwrap(); + parse_expr(pair) + } + + #[test] + fn test_parse_binary_operations() { + let expr = parse_expr_from_str("1 + 2 * 3"); + match expr { + Expr::Binary(left, BinOp::Add, right) => { + assert!(matches!(*left, Expr::Literal(Literal::Int64(1)))); + match *right { + Expr::Binary(l, BinOp::Mul, r) => { + assert!(matches!(*l, Expr::Literal(Literal::Int64(2)))); + assert!(matches!(*r, Expr::Literal(Literal::Int64(3)))); + } + _ => panic!("Expected multiplication"), + } + } + _ => panic!("Expected addition"), + } + } + + #[test] + fn test_parse_if_expression() { + let expr = parse_expr_from_str("if x > 0 then 1 else 2"); + match expr { + Expr::If(cond, then_branch, else_branch) => { + match *cond { + Expr::Binary(left, BinOp::Gt, right) => { + assert!(matches!(*left, Expr::Var(v) if v == "x")); + assert!(matches!(*right, Expr::Literal(Literal::Int64(0)))); + } + _ => panic!("Expected comparison"), + } + assert!(matches!(*then_branch, Expr::Literal(Literal::Int64(1)))); + assert!(matches!(*else_branch, Expr::Literal(Literal::Int64(2)))); + } + _ => panic!("Expected if expression"), + } + } + + #[test] + fn test_parse_closure() { + let expr = parse_expr_from_str("(x, y) => x + y"); + match expr { + Expr::Closure(params, body) => { + assert_eq!(params, vec!["x", "y"]); + match *body { + Expr::Binary(left, BinOp::Add, right) => { + assert!(matches!(*left, Expr::Var(v) if v == "x")); + assert!(matches!(*right, Expr::Var(v) if v == "y")); + } + _ => panic!("Expected addition in closure body"), + } + } + _ => panic!("Expected closure"), + } + } +} diff --git a/optd-core/src/dsl/parser/functions.rs b/optd-core/src/dsl/parser/functions.rs new file mode 100644 index 0000000..3807cd7 --- /dev/null +++ b/optd-core/src/dsl/parser/functions.rs @@ -0,0 +1,227 @@ +use pest::iterators::Pair; + +use super::ast::{Function, OperatorKind, Type}; +use super::expr::parse_expr; +use super::types::parse_type_expr; +use super::Rule; + +/// Parse a function definition from a pest Pair +/// +/// # Arguments +/// * `pair` - The pest Pair containing the function definition +/// +/// # Returns +/// * `Function` - The parsed function AST node +pub fn parse_function_def(pair: Pair<'_, Rule>) -> Function { + let mut name = None; + let mut params = Vec::new(); + let mut return_type = None; + let mut body = None; + let mut rule_type = None; + + for inner_pair in pair.into_inner() { + match inner_pair.as_rule() { + Rule::identifier => { + name = Some(inner_pair.as_str().to_string()); + } + Rule::params => { + params = parse_params(inner_pair); + } + Rule::type_expr => { + return_type = Some(parse_type_expr(inner_pair)); + } + Rule::expr => { + body = Some(parse_expr(inner_pair)); + } + Rule::rule_annot => { + rule_type = Some(parse_rule_annotation(inner_pair)); + } + _ => unreachable!( + "Unexpected function definition rule: {:?}", + inner_pair.as_rule() + ), + } + } + + Function { + name: name.unwrap(), + params, + return_type: return_type.unwrap(), + body: body.unwrap(), + rule_type, + } +} + +/// Parse function parameters +fn parse_params(pair: Pair<'_, Rule>) -> Vec<(String, Type)> { + let mut params = Vec::new(); + + for param_pair in pair.into_inner() { + if param_pair.as_rule() == Rule::param { + let (name, ty) = parse_param(param_pair); + params.push((name, ty)); + } + } + + params +} + +/// Parse a single parameter +fn parse_param(pair: Pair<'_, Rule>) -> (String, Type) { + let mut param_name = None; + let mut param_type = None; + + for param_inner_pair in pair.into_inner() { + match param_inner_pair.as_rule() { + Rule::identifier => { + param_name = Some(param_inner_pair.as_str().to_string()); + } + Rule::type_expr => { + param_type = Some(parse_type_expr(param_inner_pair)); + } + _ => unreachable!( + "Unexpected parameter rule: {:?}", + param_inner_pair.as_rule() + ), + } + } + + (param_name.unwrap(), param_type.unwrap()) +} + +/// Parse a rule annotation (@rule(Scalar) or @rule(Logical)) +fn parse_rule_annotation(pair: Pair<'_, Rule>) -> OperatorKind { + let annot_inner_pair = pair.into_inner().next().unwrap(); + match annot_inner_pair.as_str() { + "Scalar" => OperatorKind::Scalar, + "Logical" => OperatorKind::Logical, + _ => unreachable!("Unexpected rule type: {}", annot_inner_pair.as_str()), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::dsl::parser::ast::{BinOp, Expr, Literal, Type}; + use crate::dsl::parser::{DslParser, Rule}; + use pest::Parser; + + fn parse_function_from_str(input: &str) -> Function { + let pair = DslParser::parse(Rule::function_def, input) + .unwrap() + .next() + .unwrap(); + parse_function_def(pair) + } + + #[test] + fn test_parse_simple_function() { + let input = "def add(x: Int64, y: Int64): Int64 = x + y"; + + let function = parse_function_from_str(input); + assert_eq!(function.name, "add"); + assert_eq!(function.params.len(), 2); + + // Check parameters + assert_eq!(function.params[0].0, "x"); + assert!(matches!(function.params[0].1, Type::Int64)); + assert_eq!(function.params[1].0, "y"); + assert!(matches!(function.params[1].1, Type::Int64)); + + // Check return type + assert!(matches!(function.return_type, Type::Int64)); + + // Check body + match function.body { + Expr::Binary(left, BinOp::Add, right) => { + assert!(matches!(*left, Expr::Var(v) if v == "x")); + assert!(matches!(*right, Expr::Var(v) if v == "y")); + } + _ => panic!("Expected binary addition expression"), + } + + assert!(function.rule_type.is_none()); + } + + #[test] + fn test_parse_function_with_rule_annotation() { + let input = r#"@rule(Scalar) + def increment(x: Int64): Int64 = x + 1 + "#; + + let function = parse_function_from_str(input); + assert_eq!(function.name, "increment"); + assert_eq!(function.params.len(), 1); + assert!(matches!(function.rule_type, Some(OperatorKind::Scalar))); + } + + #[test] + fn test_parse_function_with_block_body() { + let input = r#"def max(x: Int64, y: Int64): Int64 = { + if x > y then x else y + } + "#; + + let function = parse_function_from_str(input); + assert_eq!(function.name, "max"); + assert_eq!(function.params.len(), 2); + + match function.body { + Expr::If(cond, then_branch, else_branch) => { + match *cond { + Expr::Binary(left, BinOp::Gt, right) => { + assert!(matches!(*left, Expr::Var(v) if v == "x")); + assert!(matches!(*right, Expr::Var(v) if v == "y")); + } + _ => panic!("Expected comparison expression"), + } + assert!(matches!(*then_branch, Expr::Var(v) if v == "x")); + assert!(matches!(*else_branch, Expr::Var(v) if v == "y")); + } + _ => panic!("Expected if expression"), + } + } + + #[test] + fn test_parse_function_with_array_parameter() { + let input = "def sum(values: [Int64]): Int64 = 0"; + + let function = parse_function_from_str(input); + assert_eq!(function.name, "sum"); + assert_eq!(function.params.len(), 1); + + match &function.params[0].1 { + Type::Array(inner) => assert!(matches!(**inner, Type::Int64)), + _ => panic!("Expected array type parameter"), + } + } + + #[test] + fn test_parse_function_with_no_parameters() { + let input = "def get_zero(): Int64 = 0"; + + let function = parse_function_from_str(input); + assert_eq!(function.name, "get_zero"); + assert_eq!(function.params.len(), 0); + assert!(matches!(function.return_type, Type::Int64)); + match function.body { + Expr::Literal(Literal::Int64(n)) => assert_eq!(n, 0), + _ => panic!("Expected integer literal"), + } + } + + #[test] + fn test_parse_function_with_complex_return_type() { + let input = "def make_pair(x: Int64): (Int64, Int64) = (x, x)"; + + let function = parse_function_from_str(input); + match function.return_type { + Type::Tuple(types) => { + assert_eq!(types.len(), 2); + assert!(matches!(types[0], Type::Int64)); + assert!(matches!(types[1], Type::Int64)); + } + _ => panic!("Expected tuple return type"), + } + } +} diff --git a/optd-core/src/dsl/parser/grammar.pest b/optd-core/src/dsl/parser/grammar.pest new file mode 100644 index 0000000..4a415a9 --- /dev/null +++ b/optd-core/src/dsl/parser/grammar.pest @@ -0,0 +1,239 @@ +KEYWORD = { + "Bool" | + "Float64" | + "Int64" | + "Logical" | + "Map" | + "Props" | + "Scalar" | + "String" | + "case" | + "def" | + "derive" | + "else" | + "fail" | + "false" | + "if" | + "match" | + "then" | + "true" | + "val" +} + +// Whitespace and comments +WHITESPACE = _{ " " | "\t" | "\r" | "\n" } +COMMENT = _{ "//" ~ (!"\n" ~ ANY)* } + +// Basic identifiers and literals +identifier = @{ + !(KEYWORD ~ !(ASCII_ALPHANUMERIC | "_")) ~ ASCII_ALPHA ~ (ASCII_ALPHANUMERIC | "_")* +} +number = @{ "-"? ~ ASCII_DIGIT+ } +string = @{ "\"" ~ (!"\"" ~ ANY)* ~ "\"" } +boolean = { "true" | "false" } + +// Types +operator_type = { "Scalar" | "Logical" } +base_type = { "Int64" | "String" | "Bool" | "Float64" } +array_type = { "[" ~ type_expr ~ "]" } +map_type = { "Map" ~ "[" ~ type_expr ~ "," ~ type_expr ~ "]" } +tuple_type = { "(" ~ type_expr ~ "," ~ type_expr ~ ("," ~ type_expr)* ~ ")" } +function_type = { "(" ~ type_expr ~ ")" ~ "->" ~ type_expr } // Function composition type +type_expr = { + function_type + | operator_type + | base_type + | array_type + | map_type + | tuple_type +} + +// Annotations +props_annot = { "Logical" ~ "Props" } +rule_annot = { "@rule" ~ "(" ~ operator_type ~ ")" } + +// Operators +operator_def = { + operator_type ~ identifier ~ NEWLINE? + ~ "(" ~ NEWLINE? + ~ (field_def_list ~ NEWLINE?)? + ~ ")" ~ NEWLINE? + ~ (derive_props_block ~ NEWLINE?)? +} + +field_def_list = { + (field_def ~ "," ~ NEWLINE?)* ~ field_def +} + +field_def = { identifier ~ ":" ~ type_expr } + +// Property blocks +props_block = { + props_annot ~ NEWLINE? + ~ "(" ~ NEWLINE? + ~ (field_def ~ "," ~ NEWLINE?)* ~ field_def + ~ ")" +} + +derive_props_block = { + "derive" ~ NEWLINE? + ~ "{" ~ NEWLINE? + ~ (prop_derivation ~ ("," ~ NEWLINE? ~ prop_derivation)*)? ~ NEWLINE? + ~ "}" +} + +prop_derivation = { identifier ~ "=" ~ expr } + +// Functions +function_def = { normal_function | rule_def } +normal_function = _{ + "def" ~ identifier ~ "(" ~ params? ~ ")" ~ ":" ~ type_expr ~ "=" ~ NEWLINE? + ~ (expr | "{" ~ NEWLINE? ~ expr ~ NEWLINE? ~ "}") // Braces optional +} + +rule_def = _{ + rule_annot ~ NEWLINE? + ~ normal_function +} + +params = { param ~ ("," ~ param)* } +param = { identifier ~ ":" ~ type_expr } + +// Pattern matching +pattern = { + (identifier ~ "@" ~ top_level_pattern) + | top_level_pattern +} + +top_level_pattern = _{ + constructor_pattern + | literal_pattern + | wildcard_pattern +} + +constructor_pattern = { + identifier ~ "(" ~ constructor_pattern_fields? ~ ")" +} + +constructor_pattern_fields = { + pattern_field ~ ("," ~ pattern_field)* +} + +// Inside a constructor +pattern_field = { + (identifier ~ ":" ~ top_level_pattern) // Bind & match + | top_level_pattern // Recurse match + | identifier // Simple bind +} + +literal_pattern = { + number + | string + | boolean + | array_literal + | tuple_literal +} + +wildcard_pattern = { "_" } + +// Expressions with precedence (lowest to highest) +expr = { + "{" ~ NEWLINE? ~ (closure | logical_or) ~ NEWLINE? ~ "}" + | (closure | logical_or) +} + +// Operators as separate rules +or_op = { "||" } +and_op = { "&&" } +compare_op = { "==" | "!=" | ">=" | "<=" | ">" | "<" } +concat_op = { "++" } +add_op = { "+" | "-" } +range_op = { ".." } +mult_op = { "*" | "/" } + +// Closure is lowest precedence +closure = { closure_params ~ "=>" ~ expr } +closure_params = { + "(" ~ (identifier ~ ("," ~ identifier)*)? ~ ")" + | identifier // Single param case +} + +logical_or = { logical_and ~ (or_op ~ logical_and)* } +logical_and = { comparison ~ (and_op ~ comparison)* } +comparison = { concatenation ~ (compare_op ~ concatenation)* } +concatenation = { additive ~ (concat_op ~ additive)* } +additive = { range ~ (add_op ~ range)* } +range = { multiplicative ~ (range_op ~ multiplicative)* } +multiplicative = { postfix ~ (mult_op ~ postfix)* } +postfix = { prefix ~ (member_call | member_access | call | array_index)* } +prefix = { ("!" | "-")? ~ primary } + +primary = _{ + match_expr + | if_expr + | val_expr + | array_literal + | tuple_literal + | constructor_expr + | term +} + +term = _{ + number + | string + | boolean + | identifier + | "(" ~ expr ~ ")" + | fail_expr +} + +// Constructor syntax +constructor_expr = { + identifier ~ "(" ~ constructor_fields? ~ ")" +} + +constructor_fields = _{ + expr ~ ("," ~ expr)* +} + +// Method calls and operations +call = { "(" ~ (expr ~ ("," ~ expr)*)? ~ ")" } +member_access = { "." ~ identifier } +member_call = { "." ~ identifier ~ "(" ~ (expr ~ ("," ~ expr)*)? ~ ")"} +array_index = { "[" ~ expr ~ "]" } + +// Array and map literals +array_literal = { "[" ~ (expr ~ ("," ~ expr)*)? ~ "]" } +tuple_literal = { "(" ~ (expr ~ ("," ~ expr)*)? ~ ")" } + +// Control expressions +fail_expr = { "fail" ~ "(" ~ string ~ ")" } + +// Braces optional +match_expr = { + "match" ~ expr ~ NEWLINE? + ~ (("{" ~ NEWLINE? + ~ (match_arm ~ ("," ~ NEWLINE? ~ match_arm)*)? + ~ "}") + | (match_arm ~ ("," ~ NEWLINE? ~ match_arm)*)) + ~ NEWLINE? +} +match_arm = { "case" ~ pattern ~ "=>" ~ expr } + +// Braces optional +if_expr = { + "if" ~ expr ~ NEWLINE? + ~ (("then" ~ expr) | "{" ~ NEWLINE? ~ expr ~ NEWLINE? ~ "}") + ~ "else" ~ NEWLINE? + ~ (expr | "{" ~ NEWLINE? ~ expr ~ NEWLINE? ~ "}") +} + +val_expr = { "val" ~ identifier ~ "=" ~ expr ~ ";" ~ expr } + +// Full file +file = { + SOI ~ + props_block ~ + (operator_def | function_def)* ~ + EOI +} \ No newline at end of file diff --git a/optd-core/src/dsl/parser/mod.rs b/optd-core/src/dsl/parser/mod.rs new file mode 100644 index 0000000..95da158 --- /dev/null +++ b/optd-core/src/dsl/parser/mod.rs @@ -0,0 +1,214 @@ +use pest::{iterators::Pair, Parser}; +use pest_derive::Parser; + +pub mod ast; +pub mod expr; +pub mod functions; +pub mod operators; +pub mod patterns; +pub mod types; + +use ast::*; +use functions::parse_function_def; +use operators::parse_operator_def; +use types::parse_type_expr; + +use pest::error::Error; + +/// The main parser for the DSL, derived using pest +#[derive(Parser)] +#[grammar = "dsl/parser/grammar.pest"] +pub struct DslParser; + +/// Parses a complete DSL file into the AST +/// +/// # Arguments +/// * `input` - The input string containing the DSL code +/// +/// # Returns +/// * `Result>` - The parsed AST or a parsing error +pub fn parse_file(input: &str) -> Result>> { + let pairs = DslParser::parse(Rule::file, input)? + .next() + .unwrap() + .into_inner(); + + let mut file = File { + properties: Properties { fields: Vec::new() }, + operators: Vec::new(), + functions: Vec::new(), + }; + + for pair in pairs { + match pair.as_rule() { + Rule::props_block => { + file.properties = parse_properties_block(pair); + } + Rule::operator_def => { + file.operators.push(parse_operator_def(pair)); + } + Rule::function_def => { + file.functions.push(parse_function_def(pair)); + } + _ => {} + } + } + + Ok(file) +} + +/// Parses a properties block in the DSL +/// +/// # Arguments +/// * `pair` - The pest Pair containing the properties block +/// +/// # Returns +/// * `Properties` - The parsed properties structure +fn parse_properties_block(pair: Pair) -> Properties { + let mut properties = Properties { fields: Vec::new() }; + + for field_pair in pair.into_inner() { + if field_pair.as_rule() == Rule::field_def { + properties.fields.push(parse_field_def(field_pair)); + } + } + + properties +} + +/// Parses a field definition +/// +/// # Arguments +/// * `pair` - The pest Pair containing the field definition +/// +/// # Returns +/// * `Field` - The parsed field structure +pub(crate) fn parse_field_def(pair: Pair) -> Field { + let mut name = None; + let mut ty = None; + + for inner_pair in pair.into_inner() { + match inner_pair.as_rule() { + Rule::identifier => { + name = Some(inner_pair.as_str().to_string()); + } + Rule::type_expr => { + ty = Some(parse_type_expr(inner_pair)); + } + _ => {} + } + } + + Field { + name: name.unwrap(), + ty: ty.unwrap(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use pest::Parser; + + fn parse_props_from_str(input: &str) -> Properties { + let pair = DslParser::parse(Rule::props_block, input) + .unwrap() + .next() + .unwrap(); + parse_properties_block(pair) + } + + #[test] + fn parse_example_files() { + let input = include_str!("../parser/programs/example.optd"); + parse_file(input).unwrap(); + + let input = include_str!("../parser/programs/working.optd"); + let out = parse_file(input).unwrap(); + println!("{:#?}", out); + } + + #[test] + fn test_parse_single_property() { + let input = r#"Logical Props ( + name: String + )"#; + + let props = parse_props_from_str(input); + assert_eq!(props.fields.len(), 1); + assert_eq!(props.fields[0].name, "name"); + assert!(matches!(props.fields[0].ty, Type::String)); + } + + #[test] + fn test_parse_multiple_properties() { + let input = r#"Logical Props( + name: String, + value: Int64 + )"#; + + let props = parse_props_from_str(input); + assert_eq!(props.fields.len(), 2); + assert_eq!(props.fields[0].name, "name"); + assert!(matches!(props.fields[0].ty, Type::String)); + assert_eq!(props.fields[1].name, "value"); + assert!(matches!(props.fields[1].ty, Type::Int64)); + } + + #[test] + fn test_parse_array_property() { + let input = r#"Logical Props ( + values: [Int64] + )"#; + + let props = parse_props_from_str(input); + assert_eq!(props.fields.len(), 1); + assert_eq!(props.fields[0].name, "values"); + match &props.fields[0].ty { + Type::Array(inner) => assert!(matches!(**inner, Type::Int64)), + _ => panic!("Expected array type"), + } + } + + #[test] + fn test_parse_complex_property_types() { + let input = r#"Logical Props ( + pairs: Map[String, Int64], + coords: (Int64, Int64) + )"#; + + let props = parse_props_from_str(input); + assert_eq!(props.fields.len(), 2); + + // Check Map type + assert_eq!(props.fields[0].name, "pairs"); + match &props.fields[0].ty { + Type::Map(key, value) => { + assert!(matches!(**key, Type::String)); + assert!(matches!(**value, Type::Int64)); + } + _ => panic!("Expected map type"), + } + + // Check Tuple type + assert_eq!(props.fields[1].name, "coords"); + match &props.fields[1].ty { + Type::Tuple(types) => { + assert_eq!(types.len(), 2); + assert!(matches!(types[0], Type::Int64)); + assert!(matches!(types[1], Type::Int64)); + } + _ => panic!("Expected tuple type"), + } + } + + #[test] + #[should_panic] + fn test_parse_invalid_property() { + let input = r#"Logical Props ( + invalid + )"#; + + parse_props_from_str(input); + } +} diff --git a/optd-core/src/dsl/parser/operators.rs b/optd-core/src/dsl/parser/operators.rs new file mode 100644 index 0000000..fa96019 --- /dev/null +++ b/optd-core/src/dsl/parser/operators.rs @@ -0,0 +1,243 @@ +use pest::iterators::Pair; +use std::collections::HashMap; + +use super::{ + ast::{Expr, Field, LogicalOp, Operator, OperatorKind, ScalarOp}, + expr::parse_expr, + parse_field_def, Rule, +}; + +/// Parse an operator definition from a pest Pair +/// +/// # Arguments +/// * `pair` - The pest Pair containing the operator definition +/// +/// # Returns +/// * `Operator` - The parsed operator AST node +pub fn parse_operator_def(pair: Pair<'_, Rule>) -> Operator { + let mut operator_type = None; + let mut name = None; + let mut fields = Vec::new(); + let mut derived_props = HashMap::new(); + + for inner_pair in pair.into_inner() { + match inner_pair.as_rule() { + Rule::operator_type => { + operator_type = Some(parse_operator_type(inner_pair)); + } + Rule::identifier => { + name = Some(inner_pair.as_str().to_string()); + } + Rule::field_def_list => { + fields = parse_field_def_list(inner_pair); + } + Rule::derive_props_block => { + derived_props = parse_derive_props_block(inner_pair); + } + _ => unreachable!( + "Unexpected operator definition rule: {:?}", + inner_pair.as_rule() + ), + } + } + + create_operator(operator_type.unwrap(), name.unwrap(), fields, derived_props) +} + +/// Parse an operator type (Scalar or Logical) +fn parse_operator_type(pair: Pair<'_, Rule>) -> OperatorKind { + match pair.as_str() { + "Scalar" => OperatorKind::Scalar, + "Logical" => OperatorKind::Logical, + _ => unreachable!("Unexpected operator type: {}", pair.as_str()), + } +} + +/// Parse a list of field definitions +fn parse_field_def_list(pair: Pair<'_, Rule>) -> Vec { + pair.into_inner().map(parse_field_def).collect() +} + +/// Parse a derive properties block +fn parse_derive_props_block(pair: Pair<'_, Rule>) -> HashMap { + let mut props = HashMap::new(); + + for prop_pair in pair.into_inner() { + if prop_pair.as_rule() == Rule::prop_derivation { + let (name, expr) = parse_prop_derivation(prop_pair); + props.insert(name, expr); + } + } + + props +} + +/// Parse a single property derivation +fn parse_prop_derivation(pair: Pair<'_, Rule>) -> (String, Expr) { + let mut prop_name = None; + let mut expr = None; + + for inner_pair in pair.into_inner() { + match inner_pair.as_rule() { + Rule::identifier => { + prop_name = Some(inner_pair.as_str().to_string()); + } + Rule::expr => { + expr = Some(parse_expr(inner_pair)); + } + _ => unreachable!( + "Unexpected property derivation rule: {:?}", + inner_pair.as_rule() + ), + } + } + + (prop_name.unwrap(), expr.unwrap()) +} + +/// Create an operator based on its type and components +fn create_operator( + operator_type: OperatorKind, + name: String, + fields: Vec, + derived_props: HashMap, +) -> Operator { + match operator_type { + OperatorKind::Scalar => Operator::Scalar(ScalarOp { name, fields }), + OperatorKind::Logical => Operator::Logical(LogicalOp { + name, + fields, + derived_props, + }), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::dsl::parser::{ + ast::{BinOp, Type}, + DslParser, + }; + use pest::Parser; + + fn parse_operator_from_str(input: &str) -> Operator { + let pair = DslParser::parse(Rule::operator_def, input) + .unwrap() + .next() + .unwrap(); + parse_operator_def(pair) + } + + #[test] + fn test_parse_scalar_operator() { + let input = r#"Scalar Add( + x: Int64, + y: Int64 + ) + "#; + + match parse_operator_from_str(input) { + Operator::Scalar(op) => { + assert_eq!(op.name, "Add"); + assert_eq!(op.fields.len(), 2); + assert_eq!(op.fields[0].name, "x"); + assert!(matches!(op.fields[0].ty, Type::Int64)); + assert_eq!(op.fields[1].name, "y"); + assert!(matches!(op.fields[1].ty, Type::Int64)); + } + _ => panic!("Expected scalar operator"), + } + } + + #[test] + fn test_parse_logical_operator() { + let input = r#"Logical And( + left: Bool, + right: Bool + ) derive { + result = left && right + } + "#; + + match parse_operator_from_str(input) { + Operator::Logical(op) => { + assert_eq!(op.name, "And"); + assert_eq!(op.fields.len(), 2); + assert_eq!(op.fields[0].name, "left"); + assert!(matches!(op.fields[0].ty, Type::Bool)); + assert_eq!(op.fields[1].name, "right"); + assert!(matches!(op.fields[1].ty, Type::Bool)); + + assert_eq!(op.derived_props.len(), 1); + assert!(op.derived_props.contains_key("result")); + + match op.derived_props.get("result").unwrap() { + Expr::Binary(left, BinOp::And, right) => { + assert!(matches!(**left, Expr::Var(ref v) if v == "left")); + assert!(matches!(**right, Expr::Var(ref v) if v == "right")); + } + _ => panic!("Expected binary AND expression"), + } + } + _ => panic!("Expected logical operator"), + } + } + + #[test] + fn test_parse_operator_with_multiple_derived_props() { + let input = r#"Logical Or( + left: Bool, + right: Bool + ) derive { + result = left || right, + description = "Logical OR" + } + "#; + + match parse_operator_from_str(input) { + Operator::Logical(op) => { + assert_eq!(op.name, "Or"); + assert_eq!(op.derived_props.len(), 2); + assert!(op.derived_props.contains_key("result")); + assert!(op.derived_props.contains_key("description")); + } + _ => panic!("Expected logical operator"), + } + } + + #[test] + fn test_parse_operator_without_fields() { + let input = r#"Scalar True() + "#; + + match parse_operator_from_str(input) { + Operator::Scalar(op) => { + assert_eq!(op.name, "True"); + assert_eq!(op.fields.len(), 0); + } + _ => panic!("Expected scalar operator"), + } + } + + #[test] + fn test_parse_operator_with_array_field() { + let input = r#"Scalar Sum( + values: [Int64] + ) + "#; + + match parse_operator_from_str(input) { + Operator::Scalar(op) => { + assert_eq!(op.name, "Sum"); + assert_eq!(op.fields.len(), 1); + assert_eq!(op.fields[0].name, "values"); + match &op.fields[0].ty { + Type::Array(inner) => assert!(matches!(**inner, Type::Int64)), + _ => panic!("Expected array type"), + } + } + _ => panic!("Expected scalar operator"), + } + } +} diff --git a/optd-core/src/dsl/parser/patterns.rs b/optd-core/src/dsl/parser/patterns.rs new file mode 100644 index 0000000..6065d82 --- /dev/null +++ b/optd-core/src/dsl/parser/patterns.rs @@ -0,0 +1,205 @@ +use pest::iterators::Pair; + +use super::{ + ast::{Literal, Pattern}, + expr::parse_expr, + Rule, +}; + +/// Parse a pattern from a pest Pair +/// +/// # Arguments +/// * `pair` - The pest Pair containing the pattern +/// +/// # Returns +/// * `Pattern` - The parsed pattern AST node +pub fn parse_pattern(pair: Pair<'_, Rule>) -> Pattern { + match pair.as_rule() { + Rule::pattern => { + let mut pairs = pair.into_inner(); + let first = pairs.next().unwrap(); + + if first.as_rule() == Rule::identifier { + if let Some(at_pattern) = pairs.next() { + // This is a binding pattern (x @ pattern) + let name = first.as_str().to_string(); + let pattern = parse_pattern(at_pattern); + Pattern::Bind(name, Box::new(pattern)) + } else { + // This is a simple variable pattern + Pattern::Var(first.as_str().to_string()) + } + } else { + parse_pattern(first) + } + } + Rule::constructor_pattern => parse_constructor_pattern(pair), + Rule::literal_pattern => parse_literal_pattern(pair), + Rule::wildcard_pattern => Pattern::Wildcard, + _ => unreachable!("Unexpected pattern rule: {:?}", pair.as_rule()), + } +} + +/// Parse a constructor pattern (e.g., Some(x), Node(left, right)) +fn parse_constructor_pattern(pair: Pair<'_, Rule>) -> Pattern { + let mut pairs = pair.into_inner(); + let name = pairs.next().unwrap().as_str().to_string(); + let mut subpatterns = Vec::new(); + + if let Some(fields) = pairs.next() { + for field in fields.into_inner() { + if field.as_rule() == Rule::pattern_field { + let mut field_pairs = field.into_inner(); + let first = field_pairs.next().unwrap(); + + if first.as_rule() == Rule::identifier { + // Check for field binding (identifier : pattern) + if let Some(colon_pattern) = field_pairs.next() { + let field_name = first.as_str().to_string(); + let pattern = parse_pattern(colon_pattern); + subpatterns.push(Pattern::Bind(field_name, Box::new(pattern))); + } else { + // Simple identifier pattern + subpatterns.push(Pattern::Var(first.as_str().to_string())); + } + } else { + // Regular pattern + subpatterns.push(parse_pattern(first)); + } + } + } + } + + Pattern::Constructor(name, subpatterns) +} + +/// Parse a literal pattern (numbers, strings, booleans, arrays, tuples) +fn parse_literal_pattern(pair: Pair<'_, Rule>) -> Pattern { + let literal = parse_literal(pair.into_inner().next().unwrap()); + Pattern::Literal(literal) +} + +/// Parse a literal value +fn parse_literal(pair: Pair<'_, Rule>) -> Literal { + match pair.as_rule() { + Rule::number => Literal::Int64(pair.as_str().parse().unwrap()), + Rule::string => Literal::String(pair.as_str().to_string()), + Rule::boolean => Literal::Bool(pair.as_str().parse().unwrap()), + Rule::array_literal => { + let exprs = pair.into_inner().map(parse_expr).collect(); + Literal::Array(exprs) + } + Rule::tuple_literal => { + let exprs = pair.into_inner().map(parse_expr).collect(); + Literal::Tuple(exprs) + } + _ => unreachable!("Unexpected literal rule: {:?}", pair.as_rule()), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::dsl::parser::DslParser; + use pest::Parser; + + fn parse_pattern_from_str(input: &str) -> Pattern { + let pair = DslParser::parse(Rule::pattern, input) + .unwrap() + .next() + .unwrap(); + parse_pattern(pair) + } + + #[test] + fn test_parse_constructor_pattern() { + match parse_pattern_from_str("Some(x)") { + Pattern::Constructor(name, patterns) => { + assert_eq!(name, "Some"); + assert_eq!(patterns.len(), 1); + match &patterns[0] { + Pattern::Var(var_name) => assert_eq!(var_name, "x"), + _ => panic!("Expected variable pattern inside constructor"), + } + } + _ => panic!("Expected constructor pattern"), + } + } + + #[test] + fn test_parse_literal_pattern() { + match parse_pattern_from_str("42") { + Pattern::Literal(Literal::Int64(n)) => assert_eq!(n, 42), + _ => panic!("Expected integer literal pattern"), + } + + match parse_pattern_from_str("\"hello\"") { + Pattern::Literal(Literal::String(s)) => assert_eq!(s, "\"hello\""), + _ => panic!("Expected string literal pattern"), + } + + match parse_pattern_from_str("true") { + Pattern::Literal(Literal::Bool(b)) => assert!(b), + _ => panic!("Expected boolean literal pattern"), + } + } + + #[test] + fn test_parse_wildcard_pattern() { + assert!(matches!(parse_pattern_from_str("_"), Pattern::Wildcard)); + } + + #[test] + fn test_parse_binding_pattern() { + match parse_pattern_from_str("x @ Some(y)") { + Pattern::Bind(name, pattern) => { + assert_eq!(name, "x"); + match *pattern { + Pattern::Constructor(cname, patterns) => { + assert_eq!(cname, "Some"); + assert_eq!(patterns.len(), 1); + match &patterns[0] { + Pattern::Var(var_name) => assert_eq!(var_name, "y"), + _ => panic!("Expected variable pattern inside constructor"), + } + } + _ => panic!("Expected constructor pattern after binding"), + } + } + _ => panic!("Expected binding pattern"), + } + } + + #[test] + fn test_parse_complex_constructor_pattern() { + match parse_pattern_from_str("Node(_, right: Some(x))") { + Pattern::Constructor(name, patterns) => { + assert_eq!(name, "Node"); + assert_eq!(patterns.len(), 2); + + // Check left pattern + match &patterns[0] { + Pattern::Wildcard => {} + _ => panic!("Expected wildcard pattern for left field"), + } + + // Check right pattern + match &patterns[1] { + Pattern::Bind(name, pattern) => { + assert_eq!(name, "right"); + match **pattern { + Pattern::Constructor(ref cname, ref inner_patterns) => { + assert_eq!(cname, "Some"); + assert_eq!(inner_patterns.len(), 1); + assert!(matches!(&inner_patterns[0], Pattern::Var(n) if n == "x")); + } + _ => panic!("Expected Some constructor in right pattern"), + } + } + _ => panic!("Expected binding pattern for right field"), + } + } + _ => panic!("Expected complex constructor pattern"), + } + } +} diff --git a/optd-core/src/dsl/parser/programs/example.optd b/optd-core/src/dsl/parser/programs/example.optd new file mode 100644 index 0000000..a07e4fe --- /dev/null +++ b/optd-core/src/dsl/parser/programs/example.optd @@ -0,0 +1,148 @@ +// Shared properties for Logical operators +Logical Props(vale: [Int64]) + +Scalar ColumnRef (idx: (Int64) -> Map[(Int64, (Int64) -> String), Int64]) +Scalar Const(asd: Int64) +Scalar Eq(left: Map[Int64, Int64]) +Scalar Divide(left: Scalar, right: Scalar) +Scalar Not(input: Scalar) + +// Logical Operators +Logical Filter(input: Logical, cond: Scalar) derive { + schema_len = input.schema_len +} + +Logical Project (input: Logical, exprs: [Scalar]) derive { + schema_len = exprs.len, + other = expr2 +} + +Logical Join( + left: Logical, + right: Logical, + type: String, + cond: Scalar +) derive { + schema_len = left.schema_len + right.schema_len +} + +Logical Sort(input: Logical, keys: [Scalar]) derive { + schema_len = input.schema_len +} + +Logical Aggregate(input: Logical, group_keys: [Scalar], aggs: [Scalar]) derive { + schema_len = group_keys.len + aggs.len +} + + @rule(Scalar) +def increment(x: Int64): Int64 = x + 1 + +// Rules demonstrating all language features +@rule(Scalar) +def constant_fold(expr: Scalar): Scalar = + match expr + case op @ Add(left: _, Const(bla)) => Const(x + y), + case op @ Multiply(left: Const(x), right: Const(y)) => { + Const(x * y) + }, + case And(ms) => { + val folded = ms.map(ms); + val test = (5 + 5) * 3 ++ 87..bla; + And(folded) + }, + case _ => expr + +def rewrite_column_refs(expr: Scalar, index_map: Map[Int64, Int64]): Scalar = + match expr + case ColumnRef(i) => ColumnRef(index_map.get(i)), + case _ => expr.with_children(child => rewrite_column_refs(child, index_map)) + +@rule(Scalar) +def has_refs_in_range(expr: Scalar, start: Int64, end: Int64): Bool = + match expr { + case ColumnRef(i) => i >= start && i < end, + case _ => expr.children().any(child => has_refs_in_range(child, start, end)) + } + +@rule(Logical) +def join_commute(expr: Logical): Logical = + match expr + case Join("Inner", left, right, cond) => { + val left_len = l.schema_len; + val right_len = r.schema_len; + + val refs_remap = + ((0..left_len).map(i => (i, i + right_len)) ++ + (0..right_len).map(i => (left_len + i, i))) + .to_map(); + + Join("Inner", r, l, rewrite_column_refs(c, refs_remap)) + } + +@rule(Logical) +def join_associate(expr: Logical): Logical = + match expr + case Join("Inner", l: Join("Inner", a, b, c1), c, c2) => + val a_len = a.schema_len; + val b_len = b.schema_len; + val c_len = c.schema_len; + if !has_refs_in_range(c2, 0, a_len) then + val inner_map = (a_len..(a_len + b_len + c_len)) + .map(i => (i, i - a_len)) + .to_map(); + Join("Inner", a, Join("Inner", b, c, rewrite_column_refs(c2, inner_map)), c1) + else + fail("Cannot rewrite: outer join condition references left relation") + +@rule(Logical) +def complex_project(expr: Logical): Logical = + match expr + case project @ Project(rel, es) => { + if es.len == 0 then + fail("Empty projection list") + else { + val mapped = es.map(e => constant_fold(e)); + val filtered = mapped.filter(e => + match e + case Const(_) => false, + case _ => true + ); + if filtered.len == 0 then + fail("All expressions folded to constants") + else + Project(rel, filtered) + } + } + +@rule(Scalar) +def simplify_not(expr: Scalar): Scalar = + match expr + case Not(Not(x)) => x, + case Not(Eq(l, r)) => { + val new_left = simplify_not(l); + val new_right = simplify_not(r); + Not(Eq(new_left, new_right)) + }, + case _ => expr + +@rule(Logical) +def push_filter_down(expr: Logical): Logical = + match expr + case Filter(Project(rel, es), cond) => { + val new_cond = rewrite_column_refs(cond, bla); + Project(Filter(rel, new_cond), es) + }, + case _ => expr + +def optimize_sort_keys(expr: Logical): Logical = + match expr + case Sort(rel, ks) => { + val simplified = ks.map(k => constant_fold(k)); + val filtered = simplified.filter(k => + match k + case Const(_) => false, + case _ => true + ); + Sort(rel, filtered) + }, + case _ => expr \ No newline at end of file diff --git a/optd-core/src/dsl/parser/programs/working.optd b/optd-core/src/dsl/parser/programs/working.optd new file mode 100644 index 0000000..d20fd8a --- /dev/null +++ b/optd-core/src/dsl/parser/programs/working.optd @@ -0,0 +1,123 @@ +// Logical Properties +Logical Props(schema_len: Int64) + +// Scalar Operators +Scalar ColumnRef(idx: Int64) +Scalar Mult(left: Int64, right: Int64) +Scalar Add(left: Int64, right: Int64) +Scalar And(children: [Scalar]) +Scalar Or(children: [Scalar]) +Scalar Not(child: Scalar) + +// Logical Operators +Logical Scan(table_name: String) derive { + schema_len = 5 // +} + +Logical Filter(child: Logical, cond: Scalar) derive { + schema_len = input.schema_len +} + + +Logical Project(child: Logical, exprs: [Scalar]) derive { + schema_len = exprs.len +} + +Logical Join( + left: Logical, + right: Logical, + typ: String, + cond: Scalar +) derive { + schema_len = left.schema_len + right.schema_len +} + +// memo(alexis): tuples should be allowed in here +Logical Sort(child: Logical, keys: [(Scalar, String)]) derive { + schema_len = input.schema_len +} + +Logical Aggregate(child: Logical, group_keys: [Scalar], aggs: [(Scalar, String)]) + derive { + schema_len = group_keys.len + aggs.len +} + +// Rules + +// memo(alexis): support member function for operators like apply_children +def rewrite_column_refs(predicate: Scalar, map: Map[Int64, Int64]): Scalar = + match predicate + case ColumnRef(idx) => ColumnRef(map(idx)), + case other @ _ => predicate.apply_children(child => rewrite_column_refs(child, map)) + +@rule(Logical) +def join_commute(expr: Logical): Logical = + match expr + case Join("Inner", left, right, cond) => + val left_len = left.schema_len; + val right_len = right.schema_len; + + val right_indices = 0..right_len; + val left_indices = 0..left_len; + + val remapping = (left_indices.map(i => (i, i + right_len)) ++ + right_indices.map(i => (left_len + i, i))).to_map(); + + Project( + Join("Inner", right, left, rewrite_column_refs(cond, remapping)), + left_indices.map(i => ColumnRef(i + right_len)) ++ + right_indices.map(i => ColumnRef(i) + ) + ) + +def has_refs_in_range(cond: Scalar, from: Int64, to: Int64): Bool = + match predicate + case ColumnRef(idx) => from <= idx && idx < to, + case _ => predicate.children.any(child => has_refs_in_range(child, from, to)) + + +@rule(Logical) +def join_associate(expr: Logical): Logical = + match expr + case op @ Join("Inner", Join("Inner", a, b, cond_inner), c, cond_outer) => + val a_len = a.schema_len; + if !has_refs_in_range(cond_outer, 0, a_len) then + val remap_inner = (a.schema_len..op.schema_len).map(i => (i, i - a_len)).to_map(); + Join( + "Inner", a, + Join("Inner", b, c, rewrite_column_refs(cond_outer, remap_inner), + cond_inner) + ) + else fail("") + +def conjunctive_normal_form(expr: Scalar): Bool = fail("unimplemented") + +def with_optional_filter(key: String, old: Scalar, grouped: Map[String, [Scalar]]): Scalar = match grouped(key) + case Some(conds) => Filter(conds, old), + case _ => old + +@rule(Logical) +def filter_pushdown_join(expr: Logical): Logical = + match expr + case op @ Filter(Join(join_type, left, right, join_cond), cond) => + val conditions_with_refs = conjunctive_normal_form(cond) + .map(c => (c, extract_column_refs(c))); + + val grouped = conditions_with_refs.groupBy((cond, refs) => { + if has_refs_in_range(refs, 0, left.schema_len) && + !has_refs_in_range(refs, left.schema_len, op.schema_len) then + "left" + else if !has_refs_in_range(refs, 0, left.schema_len) && + has_refs_in_range(refs, left.schema_len, op.schema_len) then + "right" + else + "remain" + }); + + with_optional_filter("remain", + Join(join_type, + with_optional_filter("left", grouped), + with_optional_filter("right", grouped), + join_cond + ) + ) diff --git a/optd-core/src/dsl/parser/types.rs b/optd-core/src/dsl/parser/types.rs new file mode 100644 index 0000000..87ef1e8 --- /dev/null +++ b/optd-core/src/dsl/parser/types.rs @@ -0,0 +1,167 @@ +use pest::iterators::Pair; + +use super::{ + ast::{OperatorKind, Type}, + Rule, +}; + +/// Parse a type expression from a pest Pair +/// +/// # Arguments +/// * `pair` - The pest Pair containing the type expression +/// +/// # Returns +/// * `Type` - The parsed type AST node +pub fn parse_type_expr(pair: Pair<'_, Rule>) -> Type { + match pair.as_rule() { + Rule::type_expr => parse_type_expr(pair.into_inner().next().unwrap()), + Rule::base_type => parse_base_type(pair), + Rule::array_type => parse_array_type(pair), + Rule::map_type => parse_map_type(pair), + Rule::tuple_type => parse_tuple_type(pair), + Rule::function_type => parse_function_type(pair), + Rule::operator_type => parse_operator_type(pair), + _ => unreachable!("Unexpected type rule: {:?}", pair.as_rule()), + } +} + +/// Parse a base type (Int64, String, Bool, Float64) +fn parse_base_type(pair: Pair<'_, Rule>) -> Type { + match pair.as_str() { + "Int64" => Type::Int64, + "String" => Type::String, + "Bool" => Type::Bool, + "Float64" => Type::Float64, + _ => unreachable!("Unexpected base type: {}", pair.as_str()), + } +} + +/// Parse an array type (e.g., [Int64]) +fn parse_array_type(pair: Pair<'_, Rule>) -> Type { + let inner_type = pair.into_inner().next().unwrap(); + Type::Array(Box::new(parse_type_expr(inner_type))) +} + +/// Parse a map type (e.g., Map[String, Int64]) +fn parse_map_type(pair: Pair<'_, Rule>) -> Type { + let mut inner_types = pair.into_inner(); + let key_type = parse_type_expr(inner_types.next().unwrap()); + let value_type = parse_type_expr(inner_types.next().unwrap()); + Type::Map(Box::new(key_type), Box::new(value_type)) +} + +/// Parse a tuple type (e.g., (Int64, String)) +fn parse_tuple_type(pair: Pair<'_, Rule>) -> Type { + let types = pair.into_inner().map(parse_type_expr).collect(); + Type::Tuple(types) +} + +/// Parse a function type (e.g., (Int64) -> String) +fn parse_function_type(pair: Pair<'_, Rule>) -> Type { + let mut inner_types = pair.into_inner(); + let input_type = parse_type_expr(inner_types.next().unwrap()); + let output_type = parse_type_expr(inner_types.next().unwrap()); + Type::Function(Box::new(input_type), Box::new(output_type)) +} + +/// Parse an operator type (Scalar or Logical) +fn parse_operator_type(pair: Pair<'_, Rule>) -> Type { + match pair.as_str() { + "Scalar" => Type::Operator(OperatorKind::Scalar), + "Logical" => Type::Operator(OperatorKind::Logical), + _ => unreachable!("Unexpected operator type: {}", pair.as_str()), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::dsl::parser::DslParser; + use pest::Parser; + + fn parse_type_from_str(input: &str) -> Type { + let pair = DslParser::parse(Rule::type_expr, input) + .unwrap() + .next() + .unwrap(); + parse_type_expr(pair) + } + + #[test] + fn test_parse_base_types() { + assert!(matches!(parse_type_from_str("Int64"), Type::Int64)); + assert!(matches!(parse_type_from_str("String"), Type::String)); + assert!(matches!(parse_type_from_str("Bool"), Type::Bool)); + assert!(matches!(parse_type_from_str("Float64"), Type::Float64)); + } + + #[test] + fn test_parse_array_type() { + match parse_type_from_str("[Int64]") { + Type::Array(inner) => assert!(matches!(*inner, Type::Int64)), + _ => panic!("Expected array type"), + } + } + + #[test] + fn test_parse_map_type() { + match parse_type_from_str("Map[String, Int64]") { + Type::Map(key, value) => { + assert!(matches!(*key, Type::String)); + assert!(matches!(*value, Type::Int64)); + } + _ => panic!("Expected map type"), + } + } + + #[test] + fn test_parse_tuple_type() { + match parse_type_from_str("(Int64, String, Bool)") { + Type::Tuple(types) => { + assert_eq!(types.len(), 3); + assert!(matches!(types[0], Type::Int64)); + assert!(matches!(types[1], Type::String)); + assert!(matches!(types[2], Type::Bool)); + } + _ => panic!("Expected tuple type"), + } + } + + #[test] + fn test_parse_function_type() { + match parse_type_from_str("(Int64) -> String") { + Type::Function(input, output) => { + assert!(matches!(*input, Type::Int64)); + assert!(matches!(*output, Type::String)); + } + _ => panic!("Expected function type"), + } + } + + #[test] + fn test_parse_operator_type() { + match parse_type_from_str("Scalar") { + Type::Operator(kind) => assert!(matches!(kind, OperatorKind::Scalar)), + _ => panic!("Expected scalar operator type"), + } + + match parse_type_from_str("Logical") { + Type::Operator(kind) => assert!(matches!(kind, OperatorKind::Logical)), + _ => panic!("Expected logical operator type"), + } + } + + #[test] + fn test_parse_nested_types() { + match parse_type_from_str("Map[String, [Int64]]") { + Type::Map(key, value) => { + assert!(matches!(*key, Type::String)); + match *value { + Type::Array(inner) => assert!(matches!(*inner, Type::Int64)), + _ => panic!("Expected array type"), + } + } + _ => panic!("Expected map type"), + } + } +} diff --git a/optd-core/src/lib.rs b/optd-core/src/lib.rs index 494ff29..975eaba 100644 --- a/optd-core/src/lib.rs +++ b/optd-core/src/lib.rs @@ -1,5 +1,6 @@ #[allow(dead_code)] pub mod cascades; +pub mod dsl; pub mod engine; pub mod operators; pub mod plans;