diff --git a/optd-core/src/dsl/parser/ast.rs b/optd-core/src/dsl/parser/ast.rs index ef48fb8..a34bfea 100644 --- a/optd-core/src/dsl/parser/ast.rs +++ b/optd-core/src/dsl/parser/ast.rs @@ -1,112 +1,153 @@ use std::collections::HashMap; +/// Types supported by the language #[derive(Debug, Clone, PartialEq)] pub enum Type { Int64, String, Bool, Float64, - Array(Box), - Map(Box, Box), - Tuple(Vec), - Function(Box, Box), // Function type: (input) -> output + Array(Box), // Array types like [T] + Map(Box, Box), // Map types like map[K->V] + Tuple(Vec), // Tuple types like (T1, T2) + Function(Box, Box), // Function types like (T1)->T2 + Operator(OperatorKind), // Operator types (scalar/logical) } +/// Kinds of operators supported in the language +#[derive(Debug, Clone, PartialEq)] +pub enum OperatorKind { + Scalar, // Scalar operators + Logical, // Logical operators with derivable properties +} + +/// A field in an operator or properties block #[derive(Debug, Clone)] pub struct Field { pub name: String, pub ty: Type, } +/// Logical properties block that must appear exactly once per file +#[derive(Debug, Clone)] +pub struct Properties { + pub fields: Vec, +} + +/// Top-level operator definition +#[derive(Debug, Clone)] +pub enum Operator { + Scalar(ScalarOp), + Logical(LogicalOp), +} + +/// Scalar operator definition #[derive(Debug, Clone)] pub struct ScalarOp { pub name: String, pub fields: Vec, } +/// Logical operator definition with derived properties #[derive(Debug, Clone)] pub struct LogicalOp { pub name: String, pub fields: Vec, - pub derived_props: HashMap, + pub derived_props: HashMap, // Maps property names to their derivation expressions } +/// Patterns used in match expressions #[derive(Debug, Clone)] pub enum Pattern { - Bind(String, Box), - Constructor(String, Vec<(String, Pattern)>), - Enum(String, Option>), - Literal(Literal), - Wildcard, - Var(String), + Bind(String, Box), // Binding patterns like x@p or x:p + Constructor( + String, // Constructor name + Vec, // Subpatterns, can be named (x:p) or positional + ), + Literal(Literal), // Literal patterns like 42 or "hello" + Wildcard, // Wildcard pattern _ + Var(String), // Variable binding pattern } +/// Literal values #[derive(Debug, Clone)] pub enum Literal { - Number(i64), - StringLit(String), - Boolean(bool), + Int64(i64), + String(String), + Bool(bool), + Float64(f64), + Array(Vec), // Array literals [e1, e2, ...] + Tuple(Vec), // Tuple literals (e1, e2, ...) } +/// Expressions - the core of the language #[derive(Debug, Clone)] pub enum Expr { - Match(Box, Vec<(Pattern, Block)>), - If(Box, Box, Option>), - Val(String, Box, Box), - Array(Vec), - Map(Vec<(Expr, Expr)>), - Range(Box, Box), - Binary(Box, BinOp, Box), - Call(Box, Vec), - Member(Box, String), - MemberCall(Box, String, Vec), - Var(String), - Literal(Literal), - Constructor(String, Vec<(String, Expr)>), - Fail(String), - Closure(Vec, Box), + Match(Box, Vec), // Pattern matching + If(Box, Box, Box), // If-then-else + Val(String, Box, Box), // Local binding (val x = e1; e2) + Array(Vec), // Array literals + Tuple(Vec), // Tuple literals + Constructor(String, Vec), // Constructor application (currently only operators) + Binary(Box, BinOp, Box), // Binary operations + Unary(UnaryOp, Box), // Unary operations + Call(Box, Vec), // Function application + Member(Box, String), // Field access (e.f) + MemberCall(Box, String, Vec), // Method call (e.f(args)) + ArrayIndex(Box, Box), // Array indexing (e[i]) + Var(String), // Variable reference + Literal(Literal), // Literal values + Fail(String), // Failure with message + Closure(Vec, Box), // Anonymous functions +} + +/// A case in a match expression +#[derive(Debug, Clone)] +pub struct MatchArm { + pub pattern: Pattern, + pub expr: Expr, } +/// Binary operators with fixed precedence #[derive(Debug, Clone)] pub enum BinOp { - Add, - Sub, - Mul, - Div, - Concat, - Eq, - Neq, - Gt, - Lt, - Ge, - Le, - And, - Or, + Add, // + + Sub, // - + Mul, // * + Div, // / + Concat, // ++ + Eq, // == + Neq, // != + Gt, // > + Lt, // < + Ge, // >= + Le, // <= + And, // && + Or, // || + Range, // .. } +/// Unary operators #[derive(Debug, Clone)] -pub struct Block { - pub exprs: Vec, +pub enum UnaryOp { + Neg, // - + Not, // ! } +/// Function definition #[derive(Debug, Clone)] pub struct Function { pub name: String, - pub params: Vec<(String, Type)>, + pub params: Vec<(String, Type)>, // Parameter name and type pairs pub return_type: Type, - pub body: Block, - pub is_rule: bool, - pub is_operator: bool, + pub body: Expr, + pub rule_type: Option, // Some if this is a rule, indicating what kind } +/// A complete source file #[derive(Debug, Clone)] pub struct File { - pub operators: Vec, - pub functions: Vec, + pub properties: Properties, // The single logical properties block + pub operators: Vec, // All operator definitions + pub functions: Vec, // All function definitions } - -#[derive(Debug, Clone)] -pub enum Operator { - Scalar(ScalarOp), - Logical(LogicalOp), -} \ No newline at end of file diff --git a/optd-core/src/dsl/parser/grammar.pest b/optd-core/src/dsl/parser/grammar.pest index 5de18f9..ce1f536 100644 --- a/optd-core/src/dsl/parser/grammar.pest +++ b/optd-core/src/dsl/parser/grammar.pest @@ -1,25 +1,56 @@ +KEYWORD = { + "Bool" | + "Float64" | + "Int64" | + "Logical" | + "Map" | + "Props" | + "Scalar" | + "String" | + "case" | + "def" | + "derive" | + "else" | + "fail" | + "false" | + "if" | + "match" | + "then" | + "true" | + "val" +} + // Whitespace and comments WHITESPACE = _{ " " | "\t" | "\r" | "\n" } COMMENT = _{ "//" ~ (!"\n" ~ ANY)* } // Basic identifiers and literals -identifier = @{ ASCII_ALPHA ~ (ASCII_ALPHANUMERIC | "_")* } +identifier = @{ + !KEYWORD ~ ASCII_ALPHA ~ (ASCII_ALPHANUMERIC | "_")* +} number = @{ "-"? ~ ASCII_DIGIT+ } string = @{ "\"" ~ (!"\"" ~ ANY)* ~ "\"" } boolean = { "true" | "false" } // Types -operator_type = { "scalar" | "logical" } -base_type = { "i64" | "String" | "Bool" | "f64" } +operator_type = { "Scalar" | "Logical" } +base_type = { "Int64" | "String" | "Bool" | "Float64" } array_type = { "[" ~ type_expr ~ "]" } -map_type = { "map" ~ "[" ~ type_expr ~ "->" ~ type_expr ~ "]" } -tuple_type = { "[" ~ type_expr ~ ("," ~ type_expr)* ~ "]" } +map_type = { "Map" ~ "[" ~ type_expr ~ "," ~ type_expr ~ "]" } +tuple_type = { "(" ~ type_expr ~ "," ~ type_expr ~ ("," ~ type_expr)* ~ ")" } function_type = { "(" ~ type_expr ~ ")" ~ "->" ~ type_expr } // Function composition type -type_expr = { operator_type | base_type | array_type | map_type | tuple_type | function_type } +type_expr = { + function_type + | operator_type + | base_type + | array_type + | map_type + | tuple_type +} // Annotations -props_annot = { "logical" ~ "props" } -rule_annot = { "@rule" ~ "(" ~ ("scalar" | "logical") ~ ")" } +props_annot = { "Logical" ~ "Props" } +rule_annot = { "@rule" ~ "(" ~ ("Scalar" | "Logical") ~ ")" } // Operators operator_def = { @@ -55,12 +86,12 @@ prop_derivation = { identifier ~ "=" ~ expr } // Functions function_def = { normal_function | rule_def } -normal_function = { +normal_function = _{ "def" ~ identifier ~ "(" ~ params? ~ ")" ~ ":" ~ type_expr ~ "=" ~ NEWLINE? ~ (expr | "{" ~ NEWLINE? ~ expr ~ NEWLINE? ~ "}") // Braces optional } -rule_def = { +rule_def = _{ rule_annot ~ NEWLINE? ~ normal_function } @@ -74,7 +105,7 @@ pattern = { | top_level_pattern } -top_level_pattern = { +top_level_pattern = _{ constructor_pattern | literal_pattern | wildcard_pattern @@ -111,6 +142,15 @@ expr = { | (closure | logical_or) } +// Operators as separate rules +or_op = { "||" } +and_op = { "&&" } +compare_op = { "==" | "!=" | ">=" | "<=" | ">" | "<" } +concat_op = { "++" } +add_op = { "+" | "-" } +range_op = { ".." } +mult_op = { "*" | "/" } + // Closure is lowest precedence closure = { closure_params ~ "=>" ~ expr } closure_params = { @@ -118,14 +158,14 @@ closure_params = { | identifier // Single param case } -logical_or = { logical_and ~ ("||" ~ logical_and)* } -logical_and = { comparison ~ ("&&" ~ comparison)* } -comparison = { concatenation ~ (("==" | "!=" | ">=" | "<=" | ">" | "<") ~ concatenation)* } -concatenation = { additive ~ ("++" ~ additive)* } -additive = { range ~ (("+" | "-") ~ range)* } -range = { multiplicative ~ (".." ~ multiplicative)* } -multiplicative = { postfix ~ (("*" | "/") ~ postfix)* } -postfix = { prefix ~ (call | member_access | array_index | member_call)* } +logical_or = { logical_and ~ (or_op ~ logical_and)* } +logical_and = { comparison ~ (and_op ~ comparison)* } +comparison = { concatenation ~ (compare_op ~ concatenation)* } +concatenation = { additive ~ (concat_op ~ additive)* } +additive = { range ~ (add_op ~ range)* } +range = { multiplicative ~ (range_op ~ multiplicative)* } +multiplicative = { postfix ~ (mult_op ~ postfix)* } +postfix = { prefix ~ (member_call | member_access | call | array_index)* } prefix = { ("!" | "-")? ~ primary } primary = _{ @@ -152,7 +192,7 @@ constructor_expr = { identifier ~ "(" ~ constructor_fields? ~ ")" } -constructor_fields = { +constructor_fields = _{ expr ~ ("," ~ expr)* } @@ -193,6 +233,7 @@ val_expr = { "val" ~ identifier ~ "=" ~ expr ~ ";" ~ expr } // Full file file = { SOI ~ - (operator_def | props_block | function_def)* ~ + props_block ~ + (operator_def | function_def)* ~ EOI } \ No newline at end of file diff --git a/optd-core/src/dsl/parser/mod.rs b/optd-core/src/dsl/parser/mod.rs index 1b4b24c..c575bd3 100644 --- a/optd-core/src/dsl/parser/mod.rs +++ b/optd-core/src/dsl/parser/mod.rs @@ -1,55 +1,38 @@ -pub mod ast; +use pest::{iterators::Pair, Parser}; +use pest_derive::Parser; use std::collections::HashMap; -use ast::*; -use pest::iterators::Pair; -use pest::Parser; -use pest_derive::Parser; +pub mod ast; #[derive(Parser)] #[grammar = "dsl/parser/grammar.pest"] pub struct DslParser; -#[cfg(test)] -mod tests { - use std::fs; - - use pest::Parser; - - use super::*; - - #[test] - fn test_parse_operator() { - let input = fs::read_to_string("/home/alexis/optd/optd-core/src/dsl/parser/test.optd") - .expect("Failed to read file"); - - let pairs = DslParser::parse(Rule::file, &input) - .map_err(|e| e.to_string()) - .unwrap(); - print!("{:?}", pairs); - - // assert_eq!(file.operators.len(), 1); - // Add more assertions... - } -} +use ast::*; +use pest::error::Error; -/*pub fn parse_file(input: &str) -> Result { - let pairs = DslParser::parse(Rule::file, input).map_err(|e| e.to_string())?; +pub fn parse_file(input: &str) -> Result> { + let pairs = DslParser::parse(Rule::file, input)? + .next() + .unwrap() + .into_inner(); let mut file = File { + properties: Properties { fields: Vec::new() }, operators: Vec::new(), functions: Vec::new(), }; for pair in pairs { match pair.as_rule() { + Rule::props_block => { + file.properties = parse_properties_block(pair); + } Rule::operator_def => { - let operator = parse_operator(pair)?; - file.operators.push(operator); + file.operators.push(parse_operator_def(pair)); } Rule::function_def => { - let function = parse_function(pair)?; - file.functions.push(function); + file.functions.push(parse_function_def(pair)); } _ => {} } @@ -58,344 +41,508 @@ mod tests { Ok(file) } -fn parse_operator(pair: Pair) -> Result { - let mut inner_pairs = pair.into_inner(); - let operator_type = inner_pairs.next().unwrap().as_str(); - let name = inner_pairs.next().unwrap().as_str().to_string(); +fn parse_properties_block(pair: Pair) -> Properties { + let mut properties = Properties { fields: Vec::new() }; + + for field_pair in pair.into_inner() { + if field_pair.as_rule() == Rule::field_def { + properties.fields.push(parse_field_def(field_pair)); + } + } + + properties +} + +fn parse_operator_def(pair: Pair) -> Operator { + let mut operator_type = None; + let mut name = None; + let mut fields = Vec::new(); + let mut derived_props = HashMap::new(); + + for inner_pair in pair.into_inner() { + match inner_pair.as_rule() { + Rule::operator_type => { + operator_type = Some(match inner_pair.as_str() { + "Scalar" => OperatorKind::Scalar, + "Logical" => OperatorKind::Logical, + _ => unreachable!(), + }); + } + Rule::identifier => { + name = Some(inner_pair.as_str().to_string()); + } + Rule::field_def_list => { + for field_pair in inner_pair.into_inner() { + fields.push(parse_field_def(field_pair)); + } + } + Rule::derive_props_block => { + for prop_pair in inner_pair.into_inner() { + if prop_pair.as_rule() == Rule::prop_derivation { + let mut prop_name = None; + let mut expr = None; + + for prop_inner_pair in prop_pair.into_inner() { + match prop_inner_pair.as_rule() { + Rule::identifier => { + prop_name = Some(prop_inner_pair.as_str().to_string()); + } + Rule::expr => { + expr = Some(parse_expr(prop_inner_pair)); + } + _ => unreachable!(), + } + } - let fields = parse_field_def_list(inner_pairs.next().unwrap())?; + if let (Some(name), Some(expr)) = (prop_name, expr) { + derived_props.insert(name, expr); + } + } + } + } + _ => unreachable!(), + } + } match operator_type { - "scalar" => Ok(Operator::Scalar(ScalarOp { name, fields })), - "logical" => { - let derived_props = parse_derive_props_block(inner_pairs.next().unwrap())?; - Ok(Operator::Logical(LogicalOp { - name, - fields, - derived_props, - })) + Some(OperatorKind::Scalar) => Operator::Scalar(ScalarOp { + name: name.unwrap(), + fields, + }), + Some(OperatorKind::Logical) => Operator::Logical(LogicalOp { + name: name.unwrap(), + fields, + derived_props, + }), + _ => unreachable!(), + } +} + +fn parse_function_def(pair: Pair) -> Function { + let mut name = None; + let mut params = Vec::new(); + let mut return_type = None; + let mut body = None; + let mut rule_type = None; + + for inner_pair in pair.into_inner() { + match inner_pair.as_rule() { + Rule::identifier => { + name = Some(inner_pair.as_str().to_string()); + } + Rule::params => { + for param_pair in inner_pair.into_inner() { + if param_pair.as_rule() == Rule::param { + let mut param_name = None; + let mut param_type = None; + + for param_inner_pair in param_pair.into_inner() { + match param_inner_pair.as_rule() { + Rule::identifier => { + param_name = Some(param_inner_pair.as_str().to_string()); + } + Rule::type_expr => { + param_type = Some(parse_type_expr(param_inner_pair)); + } + _ => {} + } + } + + if let (Some(name), Some(ty)) = (param_name, param_type) { + params.push((name, ty)); + } + } + } + } + Rule::type_expr => { + return_type = Some(parse_type_expr(inner_pair)); + } + Rule::expr => { + body = Some(parse_expr(inner_pair)); + } + Rule::rule_annot => { + for annot_inner_pair in inner_pair.into_inner() { + if annot_inner_pair.as_rule() == Rule::operator_type { + rule_type = Some(match annot_inner_pair.as_str() { + "scalar" => OperatorKind::Scalar, + "logical" => OperatorKind::Logical, + _ => unreachable!(), + }); + } + } + } + _ => unreachable!(), } - _ => Err("Unknown operator type".to_string()), + } + + Function { + name: name.unwrap(), + params, + return_type: return_type.unwrap(), + body: body.unwrap(), + rule_type, } } -fn parse_field_def_list(pair: Pair) -> Result, String> { - let mut fields = Vec::new(); - for field_pair in pair.into_inner() { - let mut field_pairs = field_pair.into_inner(); - let name = field_pairs.next().unwrap().as_str().to_string(); - let ty = parse_type_expr(field_pairs.next().unwrap())?; - fields.push(Field { name, ty }); +fn parse_field_def(pair: Pair) -> Field { + let mut name = None; + let mut ty = None; + for inner_pair in pair.into_inner() { + match inner_pair.as_rule() { + Rule::identifier => { + name = Some(inner_pair.as_str().to_string()); + } + Rule::type_expr => { + ty = Some(parse_type_expr(inner_pair)); + } + _ => {} + } + } + + Field { + name: name.unwrap(), + ty: ty.unwrap(), } - Ok(fields) } -fn parse_type_expr(pair: Pair) -> Result { +fn parse_type_expr(pair: Pair) -> Type { match pair.as_rule() { Rule::base_type => match pair.as_str() { - "i64" => Ok(Type::Int64), - "String" => Ok(Type::String), - "Bool" => Ok(Type::Bool), - "f64" => Ok(Type::Float64), - _ => Err("Unknown base type".to_string()), + "Int64" => Type::Int64, + "String" => Type::String, + "Bool" => Type::Bool, + "Float64" => Type::Float64, + _ => unreachable!(), }, Rule::array_type => { - let inner_type = parse_type_expr(pair.into_inner().next().unwrap())?; - Ok(Type::Array(Box::new(inner_type))) + let inner_type = pair.into_inner().next().unwrap(); + Type::Array(Box::new(parse_type_expr(inner_type))) } Rule::map_type => { - let mut inner_pairs = pair.into_inner(); - let key_type = parse_type_expr(inner_pairs.next().unwrap())?; - let value_type = parse_type_expr(inner_pairs.next().unwrap())?; - Ok(Type::Map(Box::new(key_type), Box::new(value_type))) + let mut inner_types = pair.into_inner(); + let key_type = parse_type_expr(inner_types.next().unwrap()); + let value_type = parse_type_expr(inner_types.next().unwrap()); + Type::Map(Box::new(key_type), Box::new(value_type)) } Rule::tuple_type => { let mut types = Vec::new(); - for type_pair in pair.into_inner() { - types.push(parse_type_expr(type_pair)?); + for inner_pair in pair.into_inner() { + types.push(parse_type_expr(inner_pair)); } - Ok(Type::Tuple(types)) + Type::Tuple(types) } Rule::function_type => { - let mut inner_pairs = pair.into_inner(); - let input_type = parse_type_expr(inner_pairs.next().unwrap())?; - let output_type = parse_type_expr(inner_pairs.next().unwrap())?; - Ok(Type::Function(Box::new(input_type), Box::new(output_type))) + let mut inner_types = pair.into_inner(); + let input_type = parse_type_expr(inner_types.next().unwrap()); + let output_type = parse_type_expr(inner_types.next().unwrap()); + Type::Function(Box::new(input_type), Box::new(output_type)) } - _ => Err("Unknown type expression".to_string()), + Rule::operator_type => match pair.as_str() { + "Scalar" => Type::Operator(OperatorKind::Scalar), + "Logical" => Type::Operator(OperatorKind::Logical), + _ => unreachable!(), + }, + Rule::type_expr => parse_type_expr(pair.into_inner().next().unwrap()), + _ => unreachable!(), } } -fn parse_derive_props_block(pair: Pair) -> Result, String> { - let mut derived_props = HashMap::new(); - for prop_pair in pair.into_inner() { - let mut prop_pairs = prop_pair.into_inner(); - let name = prop_pairs.next().unwrap().as_str().to_string(); - let expr = parse_expr(prop_pairs.next().unwrap())?; - derived_props.insert(name, expr); - } - Ok(derived_props) -} +fn parse_operator(pair: Pair) -> Expr { + let mut pairs = pair.clone().into_inner(); + let mut expr = parse_expr(pairs.next().unwrap()); -fn parse_function(pair: Pair) -> Result { - let mut inner_pairs = pair.into_inner(); - let name = inner_pairs.next().unwrap().as_str().to_string(); - let params = parse_params(inner_pairs.next().unwrap())?; - let return_type = parse_type_expr(inner_pairs.next().unwrap())?; - let body = parse_block(inner_pairs.next().unwrap())?; + while let Some(op_pair) = pairs.next() { + let op = parse_bin_op(op_pair); + let rhs = parse_expr(pairs.next().unwrap()); + expr = Expr::Binary(Box::new(expr), op, Box::new(rhs)); + } - Ok(Function { - name, - params, - return_type, - body, - is_rule: false, - is_operator: false, - }) + expr } -fn parse_params(pair: Pair) -> Result, String> { - let mut params = Vec::new(); - for param_pair in pair.into_inner() { - let mut param_pairs = param_pair.into_inner(); - let name = param_pairs.next().unwrap().as_str().to_string(); - let ty = parse_type_expr(param_pairs.next().unwrap())?; - params.push((name, ty)); +fn parse_postfix(pair: Pair) -> Expr { + let mut pairs = pair.into_inner(); + let mut expr = parse_expr(pairs.next().unwrap()); + for postfix_pair in pairs { + match postfix_pair.as_rule() { + Rule::call => { + let mut args = Vec::new(); + for arg_pair in postfix_pair.into_inner() { + args.push(parse_expr(arg_pair)); + } + expr = Expr::Call(Box::new(expr), args); + } + Rule::member_access => { + let member = postfix_pair.as_str().trim_start_matches('.').to_string(); + expr = Expr::Member(Box::new(expr), member); + } + Rule::array_index => { + let index = parse_expr(postfix_pair.into_inner().next().unwrap()); + expr = Expr::ArrayIndex(Box::new(expr), Box::new(index)); + } + Rule::member_call => { + let mut pairs = postfix_pair.into_inner(); + let member = pairs + .next() + .unwrap() + .as_str() + .trim_start_matches('.') + .to_string(); + let mut args = Vec::new(); + for arg_pair in pairs { + args.push(parse_expr(arg_pair)); + } + expr = Expr::MemberCall(Box::new(expr), member, args); + } + _ => unreachable!(), + } } - Ok(params) + expr } -fn parse_block(pair: Pair) -> Result { - let mut exprs = Vec::new(); - for expr_pair in pair.into_inner() { - exprs.push(parse_expr(expr_pair)?); +fn parse_prefix(pair: Pair) -> Expr { + let mut pairs = pair.into_inner(); + let first = pairs.next().unwrap(); + + if first.as_str() == "!" || first.as_str() == "-" { + let op = match first.as_str() { + "-" => UnaryOp::Neg, + "!" => UnaryOp::Not, + _ => unreachable!(), + }; + let expr = parse_expr(pairs.next().unwrap()); + Expr::Unary(op, Box::new(expr)) + } else { + parse_expr(first) } - Ok(Block { exprs }) } -fn parse_expr(pair: Pair) -> Result { +fn parse_expr(pair: Pair) -> Expr { match pair.as_rule() { + Rule::closure => { + let mut pairs = pair.into_inner(); + let params = pairs + .next() + .unwrap() + .as_str() + .split(',') + .map(|s| s.trim().to_string()) + .collect(); + let body = parse_expr(pairs.next().unwrap()); + Expr::Closure(params, Box::new(body)) + } + Rule::logical_or + | Rule::logical_and + | Rule::comparison + | Rule::concatenation + | Rule::additive + | Rule::range + | Rule::multiplicative => parse_operator(pair), + Rule::postfix => parse_postfix(pair), + Rule::prefix => parse_prefix(pair), Rule::match_expr => { - let mut inner_pairs = pair.into_inner(); - let expr = parse_expr(inner_pairs.next().unwrap())?; - let arms = parse_match_arms(inner_pairs.next().unwrap())?; - Ok(Expr::Match(Box::new(expr), arms)) + let mut pairs = pair.into_inner(); + let expr = parse_expr(pairs.next().unwrap()); + let mut arms = Vec::new(); + for arm_pair in pairs { + if arm_pair.as_rule() == Rule::match_arm { + arms.push(parse_match_arm(arm_pair)); + } + } + Expr::Match(Box::new(expr), arms) } Rule::if_expr => { - let mut inner_pairs = pair.into_inner(); - let cond = parse_expr(inner_pairs.next().unwrap())?; - let then_block = parse_block(inner_pairs.next().unwrap())?; - let else_block = if inner_pairs.peek().is_some() { - Some(Box::new(parse_block(inner_pairs.next().unwrap())?)) - } else { - None - }; - Ok(Expr::If(Box::new(cond), Box::new(then_block), else_block)) + let mut pairs = pair.into_inner(); + let cond = parse_expr(pairs.next().unwrap()); + let then_branch = parse_expr(pairs.next().unwrap()); + let else_branch = parse_expr(pairs.next().unwrap()); + Expr::If(Box::new(cond), Box::new(then_branch), Box::new(else_branch)) } Rule::val_expr => { - let mut inner_pairs = pair.into_inner(); - let name = inner_pairs.next().unwrap().as_str().to_string(); - let value = parse_expr(inner_pairs.next().unwrap())?; - let body = parse_expr(inner_pairs.next().unwrap())?; - Ok(Expr::Val(name, Box::new(value), Box::new(body))) + let mut pairs = pair.into_inner(); + let name = pairs.next().unwrap().as_str().to_string(); + let value = parse_expr(pairs.next().unwrap()); + let body = parse_expr(pairs.next().unwrap()); + Expr::Val(name, Box::new(value), Box::new(body)) } Rule::array_literal => { let mut exprs = Vec::new(); - for expr_pair in pair.into_inner() { - exprs.push(parse_expr(expr_pair)?); + for inner_pair in pair.into_inner() { + exprs.push(parse_expr(inner_pair)); } - Ok(Expr::Array(exprs)) + Expr::Array(exprs) } - Rule::map_literal => { - let mut entries = Vec::new(); - for entry_pair in pair.into_inner() { - let mut entry_pairs = entry_pair.into_inner(); - let key = parse_expr(entry_pairs.next().unwrap())?; - let value = parse_expr(entry_pairs.next().unwrap())?; - entries.push((key, value)); + Rule::tuple_literal => { + let mut exprs = Vec::new(); + for inner_pair in pair.into_inner() { + exprs.push(parse_expr(inner_pair)); } - Ok(Expr::Map(entries)) - } - Rule::range_expr => { - let mut inner_pairs = pair.into_inner(); - let start = parse_expr(inner_pairs.next().unwrap())?; - let end = parse_expr(inner_pairs.next().unwrap())?; - Ok(Expr::Range(Box::new(start), Box::new(end))) + Expr::Tuple(exprs) } - Rule::binary_expr => { - let mut inner_pairs = pair.into_inner(); - let left = parse_expr(inner_pairs.next().unwrap())?; - let op = parse_bin_op(inner_pairs.next().unwrap())?; - let right = parse_expr(inner_pairs.next().unwrap())?; - Ok(Expr::Binary(Box::new(left), op, Box::new(right))) + Rule::constructor_expr => { + let mut pairs = pair.into_inner(); + let name = pairs.next().unwrap().as_str().to_string(); + let mut args = Vec::new(); + for arg_pair in pairs { + args.push(parse_expr(arg_pair)); + } + Expr::Constructor(name, args) } - Rule::call_expr => { - let mut inner_pairs = pair.into_inner(); - let callee = parse_expr(inner_pairs.next().unwrap())?; - let args = parse_expr_list(inner_pairs.next().unwrap())?; - Ok(Expr::Call(Box::new(callee), args)) + Rule::number => { + let num = pair.as_str().parse().unwrap(); + Expr::Literal(Literal::Int64(num)) } - Rule::member_access => { - let mut inner_pairs = pair.into_inner(); - let object = parse_expr(inner_pairs.next().unwrap())?; - let member = inner_pairs.next().unwrap().as_str().to_string(); - Ok(Expr::Member(Box::new(object), member)) + Rule::string => { + let s = pair.as_str().to_string(); + Expr::Literal(Literal::String(s)) } - Rule::member_call => { - let mut inner_pairs = pair.into_inner(); - let object = parse_expr(inner_pairs.next().unwrap())?; - let member = inner_pairs.next().unwrap().as_str().to_string(); - let args = parse_expr_list(inner_pairs.next().unwrap())?; - Ok(Expr::MemberCall(Box::new(object), member, args)) + Rule::boolean => { + let b = pair.as_str().parse().unwrap(); + Expr::Literal(Literal::Bool(b)) } - Rule::var_expr => { + Rule::identifier => { let name = pair.as_str().to_string(); - Ok(Expr::Var(name)) - } - Rule::literal => { - let literal = parse_literal(pair.into_inner().next().unwrap())?; - Ok(Expr::Literal(literal)) - } - Rule::constructor_expr => { - let mut inner_pairs = pair.into_inner(); - let name = inner_pairs.next().unwrap().as_str().to_string(); - let fields = parse_constructor_fields(inner_pairs.next().unwrap())?; - Ok(Expr::Constructor(name, fields)) + Expr::Var(name) } Rule::fail_expr => { - let message = pair.into_inner().next().unwrap().as_str().to_string(); - Ok(Expr::Fail(message)) + let msg = pair.into_inner().next().unwrap().as_str().to_string(); + Expr::Fail(msg) } - Rule::closure => { - let mut inner_pairs = pair.into_inner(); - let params = parse_closure_params(inner_pairs.next().unwrap())?; - let body = parse_expr(inner_pairs.next().unwrap())?; - Ok(Expr::Closure(params, Box::new(body))) - } - _ => Err("Unknown expression".to_string()), + Rule::expr => parse_expr(pair.into_inner().next().unwrap()), + _ => unreachable!(), } } -fn parse_match_arms(pair: Pair) -> Result, String> { - let mut arms = Vec::new(); - for arm_pair in pair.into_inner() { - let mut arm_pairs = arm_pair.into_inner(); - let pattern = parse_pattern(arm_pairs.next().unwrap())?; - let block = parse_block(arm_pairs.next().unwrap())?; - arms.push((pattern, block)); - } - Ok(arms) +fn parse_match_arm(pair: Pair) -> MatchArm { + let mut pairs = pair.into_inner(); + let pattern = parse_pattern(pairs.next().unwrap()); + let expr = parse_expr(pairs.next().unwrap()); + MatchArm { pattern, expr } } -fn parse_pattern(pair: Pair) -> Result { +fn parse_pattern(pair: Pair) -> Pattern { match pair.as_rule() { - Rule::bind_pattern => { - let mut inner_pairs = pair.into_inner(); - let name = inner_pairs.next().unwrap().as_str().to_string(); - let pattern = parse_pattern(inner_pairs.next().unwrap())?; - Ok(Pattern::Bind(name, Box::new(pattern))) + Rule::pattern => { + let mut pairs = pair.into_inner(); + let first = pairs.next().unwrap(); + if first.as_rule() == Rule::identifier { + if let Some(at_pattern) = pairs.next() { + let name = first.as_str().to_string(); + let pattern = parse_pattern(at_pattern); + Pattern::Bind(name, Box::new(pattern)) + } else { + Pattern::Var(first.as_str().to_string()) + } + } else { + parse_pattern(first) + } } Rule::constructor_pattern => { - let mut inner_pairs = pair.into_inner(); - let name = inner_pairs.next().unwrap().as_str().to_string(); - let fields = parse_constructor_pattern_fields(inner_pairs.next().unwrap())?; - Ok(Pattern::Constructor(name, fields)) - } - Rule::enum_pattern => { - let mut inner_pairs = pair.into_inner(); - let name = inner_pairs.next().unwrap().as_str().to_string(); - let pattern = if inner_pairs.peek().is_some() { - Some(Box::new(parse_pattern(inner_pairs.next().unwrap())?)) - } else { - None - }; - Ok(Pattern::Enum(name, pattern)) + let mut pairs = pair.into_inner(); + let name = pairs.next().unwrap().as_str().to_string(); + let mut subpatterns = Vec::new(); + + if let Some(fields) = pairs.next() { + for field in fields.into_inner() { + if field.as_rule() == Rule::pattern_field { + let mut field_pairs = field.into_inner(); + let first = field_pairs.next().unwrap(); + + if first.as_rule() == Rule::identifier { + // Check if we have a field binding (identifier : pattern) + if let Some(colon_pattern) = field_pairs.next() { + let field_name = first.as_str().to_string(); + let pattern = parse_pattern(colon_pattern); + subpatterns.push(Pattern::Bind(field_name, Box::new(pattern))); + } else { + // Just an identifier + subpatterns.push(Pattern::Var(first.as_str().to_string())); + } + } else { + // Regular pattern + subpatterns.push(parse_pattern(first)); + } + } + } + } + + Pattern::Constructor(name, subpatterns) } Rule::literal_pattern => { - let literal = parse_literal(pair.into_inner().next().unwrap())?; - Ok(Pattern::Literal(literal)) + let lit = parse_literal(pair.into_inner().next().unwrap()); + Pattern::Literal(lit) } - Rule::wildcard_pattern => Ok(Pattern::Wildcard), - Rule::var_pattern => { - let name = pair.as_str().to_string(); - Ok(Pattern::Var(name)) - } - _ => Err("Unknown pattern".to_string()), + Rule::wildcard_pattern => Pattern::Wildcard, + _ => unreachable!(), } } -fn parse_constructor_pattern_fields(pair: Pair) -> Result, String> { - let mut fields = Vec::new(); - for field_pair in pair.into_inner() { - let mut field_pairs = field_pair.into_inner(); - let name = field_pairs.next().unwrap().as_str().to_string(); - let pattern = parse_pattern(field_pairs.next().unwrap())?; - fields.push((name, pattern)); - } - Ok(fields) -} - -fn parse_literal(pair: Pair) -> Result { +fn parse_literal(pair: Pair) -> Literal { match pair.as_rule() { - Rule::number => { - let num = pair.as_str().parse::().map_err(|e| e.to_string())?; - Ok(Literal::Number(num)) - } - Rule::string => { - let s = pair.as_str().to_string(); - Ok(Literal::StringLit(s)) + Rule::number => Literal::Int64(pair.as_str().parse().unwrap()), + Rule::string => Literal::String(pair.as_str().to_string()), + Rule::boolean => Literal::Bool(pair.as_str().parse().unwrap()), + Rule::array_literal => { + let mut exprs = Vec::new(); + for inner_pair in pair.into_inner() { + exprs.push(parse_expr(inner_pair)); + } + Literal::Array(exprs) } - Rule::boolean => { - let b = pair.as_str().parse::().map_err(|e| e.to_string())?; - Ok(Literal::Boolean(b)) + Rule::tuple_literal => { + let mut exprs = Vec::new(); + for inner_pair in pair.into_inner() { + exprs.push(parse_expr(inner_pair)); + } + Literal::Tuple(exprs) } - _ => Err("Unknown literal".to_string()), + _ => unreachable!(), } } -fn parse_bin_op(pair: Pair) -> Result { - match pair.as_str() { - "+" => Ok(BinOp::Add), - "-" => Ok(BinOp::Sub), - "*" => Ok(BinOp::Mul), - "/" => Ok(BinOp::Div), - "++" => Ok(BinOp::Concat), - "==" => Ok(BinOp::Eq), - "!=" => Ok(BinOp::Neq), - ">" => Ok(BinOp::Gt), - "<" => Ok(BinOp::Lt), - ">=" => Ok(BinOp::Ge), - "<=" => Ok(BinOp::Le), - "&&" => Ok(BinOp::And), - "||" => Ok(BinOp::Or), - _ => Err("Unknown binary operator".to_string()), - } -} +fn parse_bin_op(pair: Pair) -> BinOp { + let op_str = match pair.as_rule() { + Rule::add_op + | Rule::mult_op + | Rule::or_op + | Rule::and_op + | Rule::compare_op + | Rule::concat_op + | Rule::range_op => pair.as_str(), + _ => pair.as_str(), + }; -fn parse_expr_list(pair: Pair) -> Result, String> { - let mut exprs = Vec::new(); - for expr_pair in pair.into_inner() { - exprs.push(parse_expr(expr_pair)?); + match op_str { + "+" => BinOp::Add, + "-" => BinOp::Sub, + "*" => BinOp::Mul, + "/" => BinOp::Div, + "++" => BinOp::Concat, + "==" => BinOp::Eq, + "!=" => BinOp::Neq, + ">" => BinOp::Gt, + "<" => BinOp::Lt, + ">=" => BinOp::Ge, + "<=" => BinOp::Le, + "&&" => BinOp::And, + "||" => BinOp::Or, + ".." => BinOp::Range, + _ => unreachable!(), } - Ok(exprs) } -fn parse_constructor_fields(pair: Pair) -> Result, String> { - let mut fields = Vec::new(); - for field_pair in pair.into_inner() { - let mut field_pairs = field_pair.into_inner(); - let name = field_pairs.next().unwrap().as_str().to_string(); - let expr = parse_expr(field_pairs.next().unwrap())?; - fields.push((name, expr)); - } - Ok(fields) -} +#[cfg(test)] +mod tests { + use super::*; + use std::fs; -fn parse_closure_params(pair: Pair) -> Result, String> { - let mut params = Vec::new(); - for param_pair in pair.into_inner() { - params.push(param_pair.as_str().to_string()); + #[test] + fn test_parse_operator() { + let input = fs::read_to_string("/home/alexis/optd/optd-core/src/dsl/parser/test.optd") + .expect("Failed to read file"); + + let file = parse_file(&input).unwrap(); + assert!(file.properties.fields.len() == 1); + println!("{:#?}", file); } - Ok(params) } -*/ \ No newline at end of file diff --git a/optd-core/src/dsl/parser/test.optd b/optd-core/src/dsl/parser/test.optd index 5e405de..066c530 100644 --- a/optd-core/src/dsl/parser/test.optd +++ b/optd-core/src/dsl/parser/test.optd @@ -1,42 +1,42 @@ -// Shared properties for logical operators -logical props(schema_len: i64) +// Shared properties for Logical operators +Logical Props(schema_len: Map[(Int64, (Int64) -> String), Int64]) -scalar ColumnRef(idx: i64) -scalar Const(val: i64) -scalar Eq(left: scalar, right: scalar) -scalar Divide(left: scalar, right: scalar) -scalar Not(input: scalar) +Scalar ColumnRef(idx: (Int64) -> Map[(Int64, (Int64) -> String), Int64]) +Scalar Const(asd: Int64) +Scalar Eq(left: Scalar, right: Scalar) +Scalar Divide(left: Scalar, right: Scalar) +Scalar Not(input: Scalar) -// logical Operators -logical Filter(input: logical, cond: scalar) derive { +// Logical Operators +Logical Filter(input: Logical, cond: Scalar) derive { schema_len = input.schema_len } -logical Project (input: logical, exprs: [scalar]) derive { +Logical Project (input: Logical, exprs: [Scalar]) derive { schema_len = exprs.len, other = expr2 } -logical Join( - left: logical, - right: logical, +Logical Join( + left: Logical, + right: Logical, type: String, - cond: scalar + cond: Scalar ) derive { schema_len = left.schema_len + right.schema_len } -logical Sort(input: logical, keys: [scalar]) derive { +Logical Sort(input: Logical, keys: [Scalar]) derive { schema_len = input.schema_len } -logical Aggregate(input: logical, group_keys: [scalar], aggs: [scalar]) derive { +Logical Aggregate(input: Logical, group_keys: [Scalar], aggs: [Scalar]) derive { schema_len = group_keys.len + aggs.len } // Rules demonstrating all language features -@rule(scalar) -def constant_fold(expr: scalar): scalar = +@rule(Scalar) +def constant_fold(expr: Scalar): Scalar = match expr case op @ Add(left: _, Const(bla)) => Const(x + y), case op @ Multiply(left: Const(x), right: Const(y)) => { @@ -44,24 +44,25 @@ def constant_fold(expr: scalar): scalar = }, case And(ms) => { val folded = ms.map(ms); + val test = (5 + 5) * 3 ++ 87..bla; And(folded) }, case _ => expr -def rewrite_column_refs(expr: scalar, index_map: map[i64 -> i64]): scalar = +def rewrite_column_refs(expr: Scalar, index_map: Map[Int64, Int64]): Scalar = match expr case ColumnRef(i) => ColumnRef(index_map.get(i)), case _ => expr.with_children(child => rewrite_column_refs(child, index_map)) -@rule(scalar) -def has_refs_in_range(expr: scalar, start: i64, end: i64): Bool = +@rule(Scalar) +def has_refs_in_range(expr: Scalar, start: Int64, end: Int64): Bool = match expr { case ColumnRef(i) => i >= start && i < end, case _ => expr.children().any(child => has_refs_in_range(child, start, end)) } -@rule(logical) -def join_commute(expr: logical): logical = +@rule(Logical) +def join_commute(expr: Logical): Logical = match expr case Join("Inner", left, right, cond) => { val left_len = l.schema_len; @@ -75,10 +76,10 @@ def join_commute(expr: logical): logical = Join("Inner", r, l, rewrite_column_refs(c, refs_remap)) } -@rule(logical) -def join_associate(expr: logical): logical = +@rule(Logical) +def join_associate(expr: Logical): Logical = match expr - case Join("Inner", Join("Inner", a, b, c1), c, c2) => + case Join("Inner", l: Join("Inner", a, b, c1), c, c2) => val a_len = a.schema_len; val b_len = b.schema_len; val c_len = c.schema_len; @@ -90,8 +91,8 @@ def join_associate(expr: logical): logical = else fail("Cannot rewrite: outer join condition references left relation") -@rule(logical) -def complex_project(expr: logical): logical = +@rule(Logical) +def complex_project(expr: Logical): Logical = match expr case project @ Project(rel, es) => { if es.len == 0 then @@ -110,8 +111,8 @@ def complex_project(expr: logical): logical = } } -@rule(scalar) -def simplify_not(expr: scalar): scalar = +@rule(Scalar) +def simplify_not(expr: Scalar): Scalar = match expr case Not(Not(x)) => x, case Not(Eq(l, r)) => { @@ -121,8 +122,8 @@ def simplify_not(expr: scalar): scalar = }, case _ => expr -@rule(logical) -def push_filter_down(expr: logical): logical = +@rule(Logical) +def push_filter_down(expr: Logical): Logical = match expr case Filter(Project(rel, es), cond) => { val new_cond = rewrite_column_refs(cond, bla); @@ -130,7 +131,7 @@ def push_filter_down(expr: logical): logical = }, case _ => expr -def optimize_sort_keys(expr: logical): logical = +def optimize_sort_keys(expr: Logical): Logical = match expr case Sort(rel, ks) => { val simplified = ks.map(k => constant_fold(k)); diff --git a/optd-core/src/storage/memo.rs b/optd-core/src/storage/memo.rs index cde77cc..692f564 100644 --- a/optd-core/src/storage/memo.rs +++ b/optd-core/src/storage/memo.rs @@ -307,9 +307,9 @@ impl SqliteMemo { ScalarOperatorKind::Add, ) .await?; - println!("add: {:?}", add); - println!("scalar_expr_id: {:?}", scalar_expr_id); - println!("group_id: {:?}", group_id); + // println!("add: {:?}", add); + // println!("scalar_expr_id: {:?}", scalar_expr_id); + // println!("group_id: {:?}", group_id); let group_id = sqlx::query_scalar("INSERT INTO scalar_adds (scalar_expression_id, group_id, left_group_id, right_group_id) VALUES ($1, $2, $3, $4) ON CONFLICT DO UPDATE SET group_id = group_id RETURNING group_id") .bind(scalar_expr_id) .bind(group_id)