diff --git a/Cargo.lock b/Cargo.lock index 96b3926..a39a078 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -800,8 +800,6 @@ dependencies = [ "anyhow", "async-recursion", "dotenvy", - "pest", - "pest_derive", "proc-macro2", "serde", "serde_json", diff --git a/optd-core/Cargo.toml b/optd-core/Cargo.toml index 94efe51..65fd14b 100644 --- a/optd-core/Cargo.toml +++ b/optd-core/Cargo.toml @@ -14,6 +14,4 @@ tokio = { version = "1.43.0", features = ["full"] } serde = { version = "1.0", features = ["derive"] } serde_json = { version = "1", features = ["raw_value"] } dotenvy = "0.15" -async-recursion = "1.1.1" -pest = "2.7.15" -pest_derive = "2.7.15" +async-recursion = "1.1.1" \ No newline at end of file diff --git a/optd-core/src/dsl/analyzer/mod.rs b/optd-core/src/dsl/analyzer/mod.rs deleted file mode 100644 index 515416e..0000000 --- a/optd-core/src/dsl/analyzer/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod semantic; diff --git a/optd-core/src/dsl/analyzer/semantic.rs b/optd-core/src/dsl/analyzer/semantic.rs deleted file mode 100644 index 698b16e..0000000 --- a/optd-core/src/dsl/analyzer/semantic.rs +++ /dev/null @@ -1,575 +0,0 @@ -use std::collections::HashSet; - -use crate::dsl::ast::upper_layer::{ - Expr, File, Function, Operator, OperatorKind, Pattern, Properties, Type, -}; - -/// SemanticAnalyzer performs static analysis on the DSL code to ensure semantic correctness. -/// It validates properties, operators, functions, and expressions while maintaining scope information. -/// -/// # Fields -/// * `logical_properties` - Set of valid logical property names defined in the current context -/// * `operators` - Set of operator names to prevent duplicates -/// * `identifiers` - Stack of identifier sets representing different scopes -#[derive(Debug)] -pub struct SemanticAnalyzer { - logical_properties: HashSet, - operators: HashSet, - identifiers: Vec>, -} - -impl Default for SemanticAnalyzer { - /// Default implementation for SemanticAnalyzer. - fn default() -> Self { - Self::new() - } -} - -impl SemanticAnalyzer { - /// Creates a new SemanticAnalyzer instance with empty sets and a single global scope. - pub fn new() -> Self { - SemanticAnalyzer { - logical_properties: HashSet::new(), - operators: HashSet::new(), - identifiers: vec![HashSet::new()], // Initialize with global scope - } - } - - /// Creates a new scope for local variables. - /// Used when entering functions, closures, or blocks. - fn enter_scope(&mut self) { - self.identifiers.push(HashSet::new()); - } - - /// Removes the current scope when exiting a block. - /// Any variables defined in this scope become inaccessible. - fn exit_scope(&mut self) { - self.identifiers.pop(); - } - - /// Adds a new identifier to the current scope. - /// Returns an error if the identifier is already defined in the current scope. - /// - /// # Arguments - /// * `name` - The identifier name to add - /// - /// # Returns - /// * `Ok(())` if the identifier was added successfully - /// * `Err(String)` if the identifier already exists in the current scope - fn add_identifier(&mut self, name: String) -> Result<(), String> { - if let Some(scope) = self.identifiers.last_mut() { - if scope.contains(&name) { - return Err(format!("Duplicate identifier name: {}", name)); - } - scope.insert(name); - } - Ok(()) - } - - /// Checks if an identifier is defined in any accessible scope. - /// Searches from innermost to outermost scope. - /// - /// # Arguments - /// * `name` - The identifier name to look up - /// - /// # Returns - /// * `true` if the identifier is found in any accessible scope - /// * `false` otherwise - fn lookup_identifier(&self, name: &str) -> bool { - self.identifiers - .iter() - .rev() - .any(|scope| scope.contains(name)) - } - - /// Validates that a type is valid for scalar operators. - /// Scalar types include basic types and arrays/tuples of scalar types. - /// - /// # Arguments - /// * `ty` - The type to validate - /// - /// # Returns - /// * `true` if the type is valid for scalar operators - /// * `false` otherwise - fn is_valid_scalar_type(ty: &Type) -> bool { - match ty { - Type::Array(inner) => Self::is_valid_scalar_type(inner), - Type::Tuple(fields) => fields.iter().all(Self::is_valid_scalar_type), - Type::Int64 | Type::String | Type::Bool | Type::Float64 => true, - Type::Operator(OperatorKind::Scalar) => true, - _ => false, - } - } - - /// Validates that a type is valid for logical operators. - /// Logical types include basic types, operators, and arrays/tuples of logical types. - /// - /// # Arguments - /// * `ty` - The type to validate - /// - /// # Returns - /// * `true` if the type is valid for logical operators - /// * `false` otherwise - fn is_valid_logical_type(ty: &Type) -> bool { - match ty { - Type::Array(inner) => Self::is_valid_logical_type(inner), - Type::Tuple(fields) => fields.iter().all(Self::is_valid_logical_type), - Type::Int64 | Type::String | Type::Bool | Type::Float64 => true, - Type::Operator(_) => true, - _ => false, - } - } - - /// Validates that a type is valid for properties. - /// Property types include basic types, maps, and arrays/tuples of property types. - /// - /// # Arguments - /// * `ty` - The type to validate - /// - /// # Returns - /// * `true` if the type is valid for properties - /// * `false` otherwise - fn is_valid_property_type(ty: &Type) -> bool { - match ty { - Type::Array(inner) => Self::is_valid_property_type(inner), - Type::Tuple(fields) => fields.iter().all(Self::is_valid_property_type), - Type::Map(a, b) => Self::is_valid_property_type(a) && Self::is_valid_property_type(b), - Type::Int64 | Type::String | Type::Bool | Type::Float64 => true, - Type::Function(_, _) | Type::Operator(_) => false, - } - } - - /// Validates property definitions and updates the logical_properties set. - /// - /// # Arguments - /// * `properties` - The properties to validate - /// - /// # Returns - /// * `Ok(())` if all properties are valid - /// * `Err(String)` if any property has an invalid type - fn validate_properties(&mut self, properties: &Properties) -> Result<(), String> { - // Validate all property types - for field in &properties.fields { - if !Self::is_valid_property_type(&field.ty) { - return Err(format!("Invalid type in properties: {:?}", field.ty)); - } - } - - // Update logical properties set - self.logical_properties = properties - .fields - .iter() - .map(|field| field.name.clone()) - .collect(); - - Ok(()) - } - - /// Validates operator definitions, including field types and derived properties. - /// - /// # Arguments - /// * `operator` - The operator to validate - /// - /// # Returns - /// * `Ok(())` if the operator is valid - /// * `Err(String)` containing the validation error - fn validate_operator(&mut self, operator: &Operator) -> Result<(), String> { - let (name, fields, is_logical) = match operator { - Operator::Scalar(op) => (&op.name, &op.fields, false), - Operator::Logical(op) => (&op.name, &op.fields, true), - }; - - // Check for duplicate operator names - if self.operators.contains(name) { - return Err(format!("Duplicate operator name: {}", name)); - } - self.operators.insert(name.clone()); - - // Validate field types based on operator kind - if let Some(field) = fields.iter().find(|f| { - if is_logical { - !Self::is_valid_logical_type(&f.ty) - } else { - !Self::is_valid_scalar_type(&f.ty) - } - }) { - return Err(format!( - "Invalid type in {} operator: {:?}", - if is_logical { "logical" } else { "scalar" }, - field.ty - )); - } - - // Additional validation for logical operators - if let Operator::Logical(op) = operator { - // Validate derived properties exist in logical properties - if let Some(prop) = op - .derived_props - .keys() - .find(|&p| !self.logical_properties.contains(p)) - { - return Err(format!( - "Derived property not found in logical properties: {}", - prop - )); - } - - // Ensure all logical properties have derived implementations - if let Some(field) = self - .logical_properties - .iter() - .find(|&f| !op.derived_props.contains_key(f)) - { - return Err(format!( - "Logical property field '{}' is missing a derived property", - field - )); - } - } - - Ok(()) - } - - /// Validates function definitions, including parameters and body. - /// - /// # Arguments - /// * `function` - The function to validate - /// - /// # Returns - /// * `Ok(())` if the function is valid - /// * `Err(String)` containing the validation error - fn validate_function(&mut self, function: &Function) -> Result<(), String> { - // Add function name to current scope - self.add_identifier(function.name.clone())?; - - // Create new scope for function parameters and body - self.enter_scope(); - - // Add parameters to function scope - for (param_name, _) in &function.params { - self.add_identifier(param_name.clone())?; - } - - // Validate function body - self.validate_expr(&function.body)?; - - // Exit function scope - self.exit_scope(); - - Ok(()) - } - - /// Validates expressions recursively, ensuring all variables are defined - /// and sub-expressions are valid. - /// - /// # Arguments - /// * `expr` - The expression to validate - /// - /// # Returns - /// * `Ok(())` if the expression is valid - /// * `Err(String)` containing the validation error - fn validate_expr(&mut self, expr: &Expr) -> Result<(), String> { - match expr { - Expr::Var(name) => { - if !self.lookup_identifier(name) { - return Err(format!("Undefined identifier: {}", name)); - } - } - Expr::Val(name, expr1, expr2) => { - self.validate_expr(expr1)?; - self.add_identifier(name.clone())?; - self.validate_expr(expr2)?; - } - Expr::Match(expr, arms) => { - self.validate_expr(expr)?; - for arm in arms { - self.validate_pattern(&arm.pattern)?; - self.validate_expr(&arm.expr)?; - } - } - Expr::If(cond, then_expr, else_expr) => { - self.validate_expr(cond)?; - self.validate_expr(then_expr)?; - self.validate_expr(else_expr)?; - } - Expr::Binary(left, _, right) => { - self.validate_expr(left)?; - self.validate_expr(right)?; - } - Expr::Unary(_, expr) => { - self.validate_expr(expr)?; - } - Expr::Call(func, args) => { - self.validate_expr(func)?; - for arg in args { - self.validate_expr(arg)?; - } - } - Expr::Member(expr, _) => { - self.validate_expr(expr)?; - } - Expr::MemberCall(expr, _, args) => { - self.validate_expr(expr)?; - for arg in args { - self.validate_expr(arg)?; - } - } - Expr::ArrayIndex(array, index) => { - self.validate_expr(array)?; - self.validate_expr(index)?; - } - Expr::Literal(_) => {} - Expr::Fail(_) => {} - Expr::Closure(params, body) => { - self.enter_scope(); - for param in params { - self.add_identifier(param.clone())?; - } - self.validate_expr(body)?; - self.exit_scope(); - } - _ => {} - } - Ok(()) - } - - /// Validates pattern matching constructs, ensuring all bindings are valid. - /// - /// # Arguments - /// * `pattern` - The pattern to validate - /// - /// # Returns - /// * `Ok(())` if the pattern is valid - /// * `Err(String)` containing the validation error - fn validate_pattern(&mut self, pattern: &Pattern) -> Result<(), String> { - match pattern { - Pattern::Bind(name, pat) => { - self.add_identifier(name.clone())?; - self.validate_pattern(pat)?; - } - Pattern::Constructor(_, pats) => { - for pat in pats { - self.validate_pattern(pat)?; - } - } - Pattern::Literal(_) | Pattern::Wildcard => {} - Pattern::Var(name) => { - self.add_identifier(name.clone())?; - } - } - Ok(()) - } - - /// Main entry point for validating a complete DSL file. - /// Validates properties, operators, and functions in order. - /// - /// # Arguments - /// * `file` - The File AST node to validate - /// - /// # Returns - /// * `Ok(())` if the file is semantically valid - /// * `Err(String)` containing the first validation error encountered - pub fn validate_file(&mut self, file: &File) -> Result<(), String> { - self.validate_properties(&file.properties)?; - - for operator in &file.operators { - self.validate_operator(operator)?; - } - - for function in &file.functions { - self.validate_function(function)?; - } - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::dsl::{ - ast::upper_layer::{ - Expr, Field, File, Function, Literal, LogicalOp, Operator, Properties, ScalarOp, Type, - }, - parser::parse_file, - }; - use std::collections::HashMap; - - fn create_test_file() -> File { - File { - properties: Properties { - fields: vec![ - Field { - name: "prop1".to_string(), - ty: Type::Int64, - }, - Field { - name: "prop2".to_string(), - ty: Type::String, - }, - ], - }, - operators: vec![ - Operator::Scalar(ScalarOp { - name: "op1".to_string(), - fields: vec![Field { - name: "field1".to_string(), - ty: Type::Int64, - }], - }), - Operator::Logical(LogicalOp { - name: "op2".to_string(), - fields: vec![Field { - name: "field2".to_string(), - ty: Type::String, - }], - derived_props: HashMap::from([ - ("prop1".to_string(), Expr::Literal(Literal::Int64(42))), - ( - "prop2".to_string(), - Expr::Literal(Literal::String("test".to_string())), - ), - ]), - }), - ], - functions: vec![Function { - name: "func1".to_string(), - params: vec![("x".to_string(), Type::Int64)], - return_type: Type::Int64, - body: Expr::Literal(Literal::Int64(42)), - rule_type: None, - }], - } - } - - #[test] - fn parse_working_file() { - let input = include_str!("../programs/working.optd"); - let out = parse_file(input).unwrap(); - let mut analyzer = SemanticAnalyzer::new(); - analyzer.validate_file(&out).unwrap(); - } - - #[test] - fn test_valid_file() { - let file = create_test_file(); - let mut analyzer = SemanticAnalyzer::new(); - assert!(analyzer.validate_file(&file).is_ok()); - } - - #[test] - fn test_duplicate_operator_name() { - let mut file = create_test_file(); - file.operators.push(Operator::Scalar(ScalarOp { - name: "op1".to_string(), // Duplicate name - fields: vec![Field { - name: "field3".to_string(), - ty: Type::Int64, - }], - })); - - let mut analyzer = SemanticAnalyzer::new(); - let result = analyzer.validate_file(&file); - assert!(result.is_err()); - assert_eq!(result.unwrap_err(), "Duplicate operator name: op1"); - } - - #[test] - fn test_duplicate_function_name() { - let mut file = create_test_file(); - file.functions.push(Function { - name: "func1".to_string(), // Duplicate name - params: vec![("y".to_string(), Type::Int64)], - return_type: Type::Int64, - body: Expr::Literal(Literal::Int64(42)), - rule_type: None, - }); - - let mut analyzer = SemanticAnalyzer::new(); - let result = analyzer.validate_file(&file); - assert!(result.is_err()); - assert_eq!(result.unwrap_err(), "Duplicate identifier name: func1"); - } - - #[test] - fn test_undefined_variable() { - let mut file = create_test_file(); - file.functions[0].body = Expr::Var("undefined_var".to_string()); // Undefined variable - - let mut analyzer = SemanticAnalyzer::new(); - let result = analyzer.validate_file(&file); - assert!(result.is_err()); - assert_eq!(result.unwrap_err(), "Undefined identifier: undefined_var"); - } - - #[test] - fn test_invalid_scalar_operator_type() { - let mut file = create_test_file(); - file.operators[0] = Operator::Scalar(ScalarOp { - name: "op1".to_string(), - fields: vec![Field { - name: "field1".to_string(), - ty: Type::Function(Box::new(Type::Int64), Box::new(Type::Int64)), - }], - }); - - let mut analyzer = SemanticAnalyzer::new(); - let result = analyzer.validate_file(&file); - assert!(result.is_err()); - assert_eq!( - result.unwrap_err(), - "Invalid type in scalar operator: Function(Int64, Int64)" - ); - } - - #[test] - fn test_invalid_logical_operator_type() { - let mut file = create_test_file(); - file.operators[1] = Operator::Logical(LogicalOp { - name: "op2".to_string(), - fields: vec![Field { - name: "field2".to_string(), - ty: Type::Function(Box::new(Type::Int64), Box::new(Type::Int64)), - }], - derived_props: HashMap::new(), - }); - - let mut analyzer = SemanticAnalyzer::new(); - let result = analyzer.validate_file(&file); - assert!(result.is_err()); - assert_eq!( - result.unwrap_err(), - "Invalid type in logical operator: Function(Int64, Int64)" - ); - } - - #[test] - fn test_missing_derived_property() { - let mut file = create_test_file(); - if let Operator::Logical(op) = &mut file.operators[1] { - op.derived_props.remove("prop2"); // Missing derived property - } - - let mut analyzer = SemanticAnalyzer::new(); - let result = analyzer.validate_file(&file); - assert!(result.is_err()); - assert_eq!( - result.unwrap_err(), - "Logical property field 'prop2' is missing a derived property" - ); - } - - #[test] - fn test_invalid_property_type() { - let mut file = create_test_file(); - file.properties.fields[0].ty = Type::Function(Box::new(Type::Int64), Box::new(Type::Int64)); - - let mut analyzer = SemanticAnalyzer::new(); - let result = analyzer.validate_file(&file); - assert!(result.is_err()); - assert_eq!( - result.unwrap_err(), - "Invalid type in properties: Function(Int64, Int64)" - ); - } -} diff --git a/optd-core/src/dsl/ast/lower_layer.rs b/optd-core/src/dsl/ast/lower_layer.rs deleted file mode 100644 index 8b13789..0000000 --- a/optd-core/src/dsl/ast/lower_layer.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/optd-core/src/dsl/ast/mod.rs b/optd-core/src/dsl/ast/mod.rs deleted file mode 100644 index 9c6d62b..0000000 --- a/optd-core/src/dsl/ast/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod lower_layer; -pub mod upper_layer; diff --git a/optd-core/src/dsl/ast/upper_layer.rs b/optd-core/src/dsl/ast/upper_layer.rs deleted file mode 100644 index 609f2ad..0000000 --- a/optd-core/src/dsl/ast/upper_layer.rs +++ /dev/null @@ -1,151 +0,0 @@ -use std::collections::HashMap; - -/// Types supported by the language -#[derive(Debug, Clone, PartialEq)] -pub enum Type { - Int64, - String, - Bool, - Float64, - Array(Box), // Array types like [T] - Map(Box, Box), // Map types like map[K->V] - Tuple(Vec), // Tuple types like (T1, T2) - Function(Box, Box), // Function types like (T1)->T2 - Operator(OperatorKind), // Operator types (scalar/logical) -} - -/// Kinds of operators supported in the language -#[derive(Debug, Clone, PartialEq)] -pub enum OperatorKind { - Scalar, // Scalar operators - Logical, // Logical operators with derivable properties -} - -/// A field in an operator or properties block -#[derive(Debug, Clone)] -pub struct Field { - pub name: String, - pub ty: Type, -} - -/// Logical properties block that must appear exactly once per file -#[derive(Debug, Clone)] -pub struct Properties { - pub fields: Vec, -} - -/// Top-level operator definition -#[derive(Debug, Clone)] -pub enum Operator { - Scalar(ScalarOp), - Logical(LogicalOp), -} - -/// Scalar operator definition -#[derive(Debug, Clone)] -pub struct ScalarOp { - pub name: String, - pub fields: Vec, -} - -/// Logical operator definition with derived properties -#[derive(Debug, Clone)] -pub struct LogicalOp { - pub name: String, - pub fields: Vec, - pub derived_props: HashMap, // Maps property names to their derivation expressions -} - -/// Patterns used in match expressions -#[derive(Debug, Clone)] -pub enum Pattern { - Bind(String, Box), // Binding patterns like x@p or x:p - Constructor( - String, // Constructor name - Vec, // Subpatterns, can be named (x:p) or positional - ), - Literal(Literal), // Literal patterns like 42 or "hello" - Wildcard, // Wildcard pattern _ - Var(String), // Variable binding pattern -} - -/// Literal values -#[derive(Debug, Clone)] -pub enum Literal { - Int64(i64), - String(String), - Bool(bool), - Float64(f64), - Array(Vec), // Array literals [e1, e2, ...] - Tuple(Vec), // Tuple literals (e1, e2, ...) -} - -/// Expressions - the core of the language -#[derive(Debug, Clone)] -pub enum Expr { - Match(Box, Vec), // Pattern matching - If(Box, Box, Box), // If-then-else - Val(String, Box, Box), // Local binding (val x = e1; e2) - Constructor(String, Vec), // Constructor application (currently only operators) - Binary(Box, BinOp, Box), // Binary operations - Unary(UnaryOp, Box), // Unary operations - Call(Box, Vec), // Function application - Member(Box, String), // Field access (e.f) - MemberCall(Box, String, Vec), // Method call (e.f(args)) - ArrayIndex(Box, Box), // Array indexing (e[i]) - Var(String), // Variable reference - Literal(Literal), // Literal values - Fail(String), // Failure with message - Closure(Vec, Box), // Anonymous functions v = (x, y) => x + y; -} - -/// A case in a match expression -#[derive(Debug, Clone)] -pub struct MatchArm { - pub pattern: Pattern, - pub expr: Expr, -} - -/// Binary operators with fixed precedence -#[derive(Debug, Clone)] -pub enum BinOp { - Add, // + - Sub, // - - Mul, // * - Div, // / - Concat, // ++ - Eq, // == - Neq, // != - Gt, // > - Lt, // < - Ge, // >= - Le, // <= - And, // && - Or, // || - Range, // .. -} - -/// Unary operators -#[derive(Debug, Clone)] -pub enum UnaryOp { - Neg, // - - Not, // ! -} - -/// Function definition -#[derive(Debug, Clone)] -pub struct Function { - pub name: String, - pub params: Vec<(String, Type)>, // Parameter name and type pairs - pub return_type: Type, - pub body: Expr, - pub rule_type: Option, // Some if this is a rule, indicating what kind -} - -/// A complete source file -#[derive(Debug, Clone)] -pub struct File { - pub properties: Properties, // The single logical properties block - pub operators: Vec, // All operator definitions - pub functions: Vec, // All function definitions -} diff --git a/optd-core/src/dsl/mod.rs b/optd-core/src/dsl/mod.rs deleted file mode 100644 index 9e6ab75..0000000 --- a/optd-core/src/dsl/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod analyzer; -pub mod ast; -pub mod parser; diff --git a/optd-core/src/dsl/parser/ast.rs b/optd-core/src/dsl/parser/ast.rs deleted file mode 100644 index 609f2ad..0000000 --- a/optd-core/src/dsl/parser/ast.rs +++ /dev/null @@ -1,151 +0,0 @@ -use std::collections::HashMap; - -/// Types supported by the language -#[derive(Debug, Clone, PartialEq)] -pub enum Type { - Int64, - String, - Bool, - Float64, - Array(Box), // Array types like [T] - Map(Box, Box), // Map types like map[K->V] - Tuple(Vec), // Tuple types like (T1, T2) - Function(Box, Box), // Function types like (T1)->T2 - Operator(OperatorKind), // Operator types (scalar/logical) -} - -/// Kinds of operators supported in the language -#[derive(Debug, Clone, PartialEq)] -pub enum OperatorKind { - Scalar, // Scalar operators - Logical, // Logical operators with derivable properties -} - -/// A field in an operator or properties block -#[derive(Debug, Clone)] -pub struct Field { - pub name: String, - pub ty: Type, -} - -/// Logical properties block that must appear exactly once per file -#[derive(Debug, Clone)] -pub struct Properties { - pub fields: Vec, -} - -/// Top-level operator definition -#[derive(Debug, Clone)] -pub enum Operator { - Scalar(ScalarOp), - Logical(LogicalOp), -} - -/// Scalar operator definition -#[derive(Debug, Clone)] -pub struct ScalarOp { - pub name: String, - pub fields: Vec, -} - -/// Logical operator definition with derived properties -#[derive(Debug, Clone)] -pub struct LogicalOp { - pub name: String, - pub fields: Vec, - pub derived_props: HashMap, // Maps property names to their derivation expressions -} - -/// Patterns used in match expressions -#[derive(Debug, Clone)] -pub enum Pattern { - Bind(String, Box), // Binding patterns like x@p or x:p - Constructor( - String, // Constructor name - Vec, // Subpatterns, can be named (x:p) or positional - ), - Literal(Literal), // Literal patterns like 42 or "hello" - Wildcard, // Wildcard pattern _ - Var(String), // Variable binding pattern -} - -/// Literal values -#[derive(Debug, Clone)] -pub enum Literal { - Int64(i64), - String(String), - Bool(bool), - Float64(f64), - Array(Vec), // Array literals [e1, e2, ...] - Tuple(Vec), // Tuple literals (e1, e2, ...) -} - -/// Expressions - the core of the language -#[derive(Debug, Clone)] -pub enum Expr { - Match(Box, Vec), // Pattern matching - If(Box, Box, Box), // If-then-else - Val(String, Box, Box), // Local binding (val x = e1; e2) - Constructor(String, Vec), // Constructor application (currently only operators) - Binary(Box, BinOp, Box), // Binary operations - Unary(UnaryOp, Box), // Unary operations - Call(Box, Vec), // Function application - Member(Box, String), // Field access (e.f) - MemberCall(Box, String, Vec), // Method call (e.f(args)) - ArrayIndex(Box, Box), // Array indexing (e[i]) - Var(String), // Variable reference - Literal(Literal), // Literal values - Fail(String), // Failure with message - Closure(Vec, Box), // Anonymous functions v = (x, y) => x + y; -} - -/// A case in a match expression -#[derive(Debug, Clone)] -pub struct MatchArm { - pub pattern: Pattern, - pub expr: Expr, -} - -/// Binary operators with fixed precedence -#[derive(Debug, Clone)] -pub enum BinOp { - Add, // + - Sub, // - - Mul, // * - Div, // / - Concat, // ++ - Eq, // == - Neq, // != - Gt, // > - Lt, // < - Ge, // >= - Le, // <= - And, // && - Or, // || - Range, // .. -} - -/// Unary operators -#[derive(Debug, Clone)] -pub enum UnaryOp { - Neg, // - - Not, // ! -} - -/// Function definition -#[derive(Debug, Clone)] -pub struct Function { - pub name: String, - pub params: Vec<(String, Type)>, // Parameter name and type pairs - pub return_type: Type, - pub body: Expr, - pub rule_type: Option, // Some if this is a rule, indicating what kind -} - -/// A complete source file -#[derive(Debug, Clone)] -pub struct File { - pub properties: Properties, // The single logical properties block - pub operators: Vec, // All operator definitions - pub functions: Vec, // All function definitions -} diff --git a/optd-core/src/dsl/parser/expr.rs b/optd-core/src/dsl/parser/expr.rs deleted file mode 100644 index 8a346cd..0000000 --- a/optd-core/src/dsl/parser/expr.rs +++ /dev/null @@ -1,323 +0,0 @@ -use pest::iterators::Pair; - -use crate::dsl::ast::upper_layer::{BinOp, Expr, Literal, MatchArm, UnaryOp}; - -use super::{patterns::parse_pattern, Rule}; - -/// Parse a complete expression from a pest Pair -/// -/// # Arguments -/// * `pair` - The pest Pair containing the expression -/// -/// # Returns -/// * `Expr` - The parsed expression AST node -pub fn parse_expr(pair: Pair<'_, Rule>) -> Expr { - match pair.as_rule() { - Rule::expr => parse_expr(pair.into_inner().next().unwrap()), - Rule::closure => parse_closure(pair), - Rule::logical_or - | Rule::logical_and - | Rule::comparison - | Rule::concatenation - | Rule::additive - | Rule::range - | Rule::multiplicative => parse_binary_operation(pair), - Rule::postfix => parse_postfix(pair), - Rule::prefix => parse_prefix(pair), - Rule::match_expr => parse_match_expr(pair), - Rule::if_expr => parse_if_expr(pair), - Rule::val_expr => parse_val_expr(pair), - Rule::array_literal => parse_array_literal(pair), - Rule::tuple_literal => parse_tuple_literal(pair), - Rule::constructor_expr => parse_constructor(pair), - Rule::number => parse_number(pair), - Rule::string => parse_string(pair), - Rule::boolean => parse_boolean(pair), - Rule::identifier => Expr::Var(pair.as_str().to_string()), - Rule::fail_expr => parse_fail_expr(pair), - _ => unreachable!("Unexpected expression rule: {:?}", pair.as_rule()), - } -} - -/// Parse a closure expression (e.g., "(x, y) => x + y") -fn parse_closure(pair: Pair<'_, Rule>) -> Expr { - let mut pairs = pair.into_inner(); - let params_pair = pairs.next().unwrap(); - - let params = if params_pair.as_rule() == Rule::identifier { - vec![params_pair.as_str().to_string()] - } else { - params_pair - .into_inner() - .map(|p| p.as_str().to_string()) - .collect() - }; - - let body = parse_expr(pairs.next().unwrap()); - Expr::Closure(params, Box::new(body)) -} - -/// Parse a binary operation with proper operator precedence -fn parse_binary_operation(pair: Pair<'_, Rule>) -> Expr { - let mut pairs = pair.into_inner(); - let mut expr = parse_expr(pairs.next().unwrap()); - - while let Some(op_pair) = pairs.next() { - let op = parse_binary_operator(op_pair); - let rhs = parse_expr(pairs.next().unwrap()); - expr = Expr::Binary(Box::new(expr), op, Box::new(rhs)); - } - - expr -} - -/// Parse a postfix expression (function calls, member access, array indexing) -fn parse_postfix(pair: Pair<'_, Rule>) -> Expr { - let mut pairs = pair.into_inner(); - let mut expr = parse_expr(pairs.next().unwrap()); - - for postfix_pair in pairs { - match postfix_pair.as_rule() { - Rule::call => { - let args = postfix_pair.into_inner().map(parse_expr).collect(); - expr = Expr::Call(Box::new(expr), args); - } - Rule::member_access => { - let member = postfix_pair.as_str().trim_start_matches('.').to_string(); - expr = Expr::Member(Box::new(expr), member); - } - Rule::array_index => { - let index = parse_expr(postfix_pair.into_inner().next().unwrap()); - expr = Expr::ArrayIndex(Box::new(expr), Box::new(index)); - } - Rule::member_call => { - let mut pairs = postfix_pair.into_inner(); - let member = pairs - .next() - .unwrap() - .as_str() - .trim_start_matches('.') - .to_string(); - let args = pairs.map(parse_expr).collect(); - expr = Expr::MemberCall(Box::new(expr), member, args); - } - _ => unreachable!("Unexpected postfix rule: {:?}", postfix_pair.as_rule()), - } - } - expr -} - -/// Parse a prefix expression (unary operators) -fn parse_prefix(pair: Pair<'_, Rule>) -> Expr { - let mut pairs = pair.into_inner(); - let first = pairs.next().unwrap(); - - if first.as_str() == "!" || first.as_str() == "-" { - let op = match first.as_str() { - "-" => UnaryOp::Neg, - "!" => UnaryOp::Not, - _ => unreachable!("Unexpected prefix operator: {}", first.as_str()), - }; - let expr = parse_expr(pairs.next().unwrap()); - Expr::Unary(op, Box::new(expr)) - } else { - parse_expr(first) - } -} - -/// Parse a match expression -fn parse_match_expr(pair: Pair<'_, Rule>) -> Expr { - let mut pairs = pair.into_inner(); - let expr = parse_expr(pairs.next().unwrap()); - let mut arms = Vec::new(); - - for arm_pair in pairs { - if arm_pair.as_rule() == Rule::match_arm { - arms.push(parse_match_arm(arm_pair)); - } - } - - Expr::Match(Box::new(expr), arms) -} - -/// Parse a match arm within a match expression -fn parse_match_arm(pair: Pair<'_, Rule>) -> MatchArm { - let mut pairs = pair.into_inner(); - let pattern = parse_pattern(pairs.next().unwrap()); - let expr = parse_expr(pairs.next().unwrap()); - MatchArm { pattern, expr } -} - -/// Parse an if expression -fn parse_if_expr(pair: Pair<'_, Rule>) -> Expr { - let mut pairs = pair.into_inner(); - let condition = parse_expr(pairs.next().unwrap()); - let then_branch = parse_expr(pairs.next().unwrap()); - let else_branch = parse_expr(pairs.next().unwrap()); - - Expr::If( - Box::new(condition), - Box::new(then_branch), - Box::new(else_branch), - ) -} - -/// Parse a val expression (local binding) -fn parse_val_expr(pair: Pair<'_, Rule>) -> Expr { - let mut pairs = pair.into_inner(); - let name = pairs.next().unwrap().as_str().to_string(); - let value = parse_expr(pairs.next().unwrap()); - let body = parse_expr(pairs.next().unwrap()); - - Expr::Val(name, Box::new(value), Box::new(body)) -} - -/// Parse an array literal -fn parse_array_literal(pair: Pair<'_, Rule>) -> Expr { - let exprs = pair.into_inner().map(parse_expr).collect(); - Expr::Literal(Literal::Array(exprs)) -} - -/// Parse a tuple literal -fn parse_tuple_literal(pair: Pair<'_, Rule>) -> Expr { - let exprs = pair.into_inner().map(parse_expr).collect(); - Expr::Literal(Literal::Tuple(exprs)) -} - -/// Parse a constructor expression -fn parse_constructor(pair: Pair<'_, Rule>) -> Expr { - let mut pairs = pair.into_inner(); - let name = pairs.next().unwrap().as_str().to_string(); - let args = pairs.map(parse_expr).collect(); - - // Check if first character is lowercase - // TODO(alexis): small hack until I rewrite the grammar using Chumsky - if name.chars().next().is_some_and(|c| c.is_ascii_lowercase()) { - Expr::Call(Box::new(Expr::Var(name)), args) - } else { - Expr::Constructor(name, args) - } -} - -/// Parse a numeric literal -fn parse_number(pair: Pair<'_, Rule>) -> Expr { - let num = pair.as_str().parse().unwrap(); - Expr::Literal(Literal::Int64(num)) -} - -/// Parse a string literal -fn parse_string(pair: Pair<'_, Rule>) -> Expr { - let s = pair.as_str().to_string(); - Expr::Literal(Literal::String(s)) -} - -/// Parse a boolean literal -fn parse_boolean(pair: Pair<'_, Rule>) -> Expr { - let b = pair.as_str().parse().unwrap(); - Expr::Literal(Literal::Bool(b)) -} - -/// Parse a fail expression -fn parse_fail_expr(pair: Pair<'_, Rule>) -> Expr { - let msg = pair.into_inner().next().unwrap().as_str().to_string(); - Expr::Fail(msg) -} - -/// Parse a binary operator -fn parse_binary_operator(pair: Pair<'_, Rule>) -> BinOp { - let op_str = match pair.as_rule() { - Rule::add_op - | Rule::mult_op - | Rule::or_op - | Rule::and_op - | Rule::compare_op - | Rule::concat_op - | Rule::range_op => pair.as_str(), - _ => pair.as_str(), - }; - - match op_str { - "+" => BinOp::Add, - "-" => BinOp::Sub, - "*" => BinOp::Mul, - "/" => BinOp::Div, - "++" => BinOp::Concat, - "==" => BinOp::Eq, - "!=" => BinOp::Neq, - ">" => BinOp::Gt, - "<" => BinOp::Lt, - ">=" => BinOp::Ge, - "<=" => BinOp::Le, - "&&" => BinOp::And, - "||" => BinOp::Or, - ".." => BinOp::Range, - _ => unreachable!("Unexpected binary operator: {}", op_str), - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::dsl::parser::DslParser; - use pest::Parser; - - fn parse_expr_from_str(input: &str) -> Expr { - let pair = DslParser::parse(Rule::expr, input).unwrap().next().unwrap(); - parse_expr(pair) - } - - #[test] - fn test_parse_binary_operations() { - let expr = parse_expr_from_str("1 + 2 * 3"); - match expr { - Expr::Binary(left, BinOp::Add, right) => { - assert!(matches!(*left, Expr::Literal(Literal::Int64(1)))); - match *right { - Expr::Binary(l, BinOp::Mul, r) => { - assert!(matches!(*l, Expr::Literal(Literal::Int64(2)))); - assert!(matches!(*r, Expr::Literal(Literal::Int64(3)))); - } - _ => panic!("Expected multiplication"), - } - } - _ => panic!("Expected addition"), - } - } - - #[test] - fn test_parse_if_expression() { - let expr = parse_expr_from_str("if x > 0 then 1 else 2"); - match expr { - Expr::If(cond, then_branch, else_branch) => { - match *cond { - Expr::Binary(left, BinOp::Gt, right) => { - assert!(matches!(*left, Expr::Var(v) if v == "x")); - assert!(matches!(*right, Expr::Literal(Literal::Int64(0)))); - } - _ => panic!("Expected comparison"), - } - assert!(matches!(*then_branch, Expr::Literal(Literal::Int64(1)))); - assert!(matches!(*else_branch, Expr::Literal(Literal::Int64(2)))); - } - _ => panic!("Expected if expression"), - } - } - - #[test] - fn test_parse_closure() { - let expr = parse_expr_from_str("(x, y) => x + y"); - match expr { - Expr::Closure(params, body) => { - assert_eq!(params, vec!["x", "y"]); - match *body { - Expr::Binary(left, BinOp::Add, right) => { - assert!(matches!(*left, Expr::Var(v) if v == "x")); - assert!(matches!(*right, Expr::Var(v) if v == "y")); - } - _ => panic!("Expected addition in closure body"), - } - } - _ => panic!("Expected closure"), - } - } -} diff --git a/optd-core/src/dsl/parser/functions.rs b/optd-core/src/dsl/parser/functions.rs deleted file mode 100644 index e8db8de..0000000 --- a/optd-core/src/dsl/parser/functions.rs +++ /dev/null @@ -1,228 +0,0 @@ -use pest::iterators::Pair; - -use crate::dsl::ast::upper_layer::{Function, OperatorKind, Type}; - -use super::expr::parse_expr; -use super::types::parse_type_expr; -use super::Rule; - -/// Parse a function definition from a pest Pair -/// -/// # Arguments -/// * `pair` - The pest Pair containing the function definition -/// -/// # Returns -/// * `Function` - The parsed function AST node -pub fn parse_function_def(pair: Pair<'_, Rule>) -> Function { - let mut name = None; - let mut params = Vec::new(); - let mut return_type = None; - let mut body = None; - let mut rule_type = None; - - for inner_pair in pair.into_inner() { - match inner_pair.as_rule() { - Rule::identifier => { - name = Some(inner_pair.as_str().to_string()); - } - Rule::params => { - params = parse_params(inner_pair); - } - Rule::type_expr => { - return_type = Some(parse_type_expr(inner_pair)); - } - Rule::expr => { - body = Some(parse_expr(inner_pair)); - } - Rule::rule_annot => { - rule_type = Some(parse_rule_annotation(inner_pair)); - } - _ => unreachable!( - "Unexpected function definition rule: {:?}", - inner_pair.as_rule() - ), - } - } - - Function { - name: name.unwrap(), - params, - return_type: return_type.unwrap(), - body: body.unwrap(), - rule_type, - } -} - -/// Parse function parameters -fn parse_params(pair: Pair<'_, Rule>) -> Vec<(String, Type)> { - let mut params = Vec::new(); - - for param_pair in pair.into_inner() { - if param_pair.as_rule() == Rule::param { - let (name, ty) = parse_param(param_pair); - params.push((name, ty)); - } - } - - params -} - -/// Parse a single parameter -fn parse_param(pair: Pair<'_, Rule>) -> (String, Type) { - let mut param_name = None; - let mut param_type = None; - - for param_inner_pair in pair.into_inner() { - match param_inner_pair.as_rule() { - Rule::identifier => { - param_name = Some(param_inner_pair.as_str().to_string()); - } - Rule::type_expr => { - param_type = Some(parse_type_expr(param_inner_pair)); - } - _ => unreachable!( - "Unexpected parameter rule: {:?}", - param_inner_pair.as_rule() - ), - } - } - - (param_name.unwrap(), param_type.unwrap()) -} - -/// Parse a rule annotation (@rule(Scalar) or @rule(Logical)) -fn parse_rule_annotation(pair: Pair<'_, Rule>) -> OperatorKind { - let annot_inner_pair = pair.into_inner().next().unwrap(); - match annot_inner_pair.as_str() { - "Scalar" => OperatorKind::Scalar, - "Logical" => OperatorKind::Logical, - _ => unreachable!("Unexpected rule type: {}", annot_inner_pair.as_str()), - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::dsl::ast::upper_layer::{BinOp, Expr, Literal, Type}; - use crate::dsl::parser::{DslParser, Rule}; - use pest::Parser; - - fn parse_function_from_str(input: &str) -> Function { - let pair = DslParser::parse(Rule::function_def, input) - .unwrap() - .next() - .unwrap(); - parse_function_def(pair) - } - - #[test] - fn test_parse_simple_function() { - let input = "def add(x: Int64, y: Int64): Int64 = x + y"; - - let function = parse_function_from_str(input); - assert_eq!(function.name, "add"); - assert_eq!(function.params.len(), 2); - - // Check parameters - assert_eq!(function.params[0].0, "x"); - assert!(matches!(function.params[0].1, Type::Int64)); - assert_eq!(function.params[1].0, "y"); - assert!(matches!(function.params[1].1, Type::Int64)); - - // Check return type - assert!(matches!(function.return_type, Type::Int64)); - - // Check body - match function.body { - Expr::Binary(left, BinOp::Add, right) => { - assert!(matches!(*left, Expr::Var(v) if v == "x")); - assert!(matches!(*right, Expr::Var(v) if v == "y")); - } - _ => panic!("Expected binary addition expression"), - } - - assert!(function.rule_type.is_none()); - } - - #[test] - fn test_parse_function_with_rule_annotation() { - let input = r#"@rule(Scalar) - def increment(x: Int64): Int64 = x + 1 - "#; - - let function = parse_function_from_str(input); - assert_eq!(function.name, "increment"); - assert_eq!(function.params.len(), 1); - assert!(matches!(function.rule_type, Some(OperatorKind::Scalar))); - } - - #[test] - fn test_parse_function_with_block_body() { - let input = r#"def max(x: Int64, y: Int64): Int64 = { - if x > y then x else y - } - "#; - - let function = parse_function_from_str(input); - assert_eq!(function.name, "max"); - assert_eq!(function.params.len(), 2); - - match function.body { - Expr::If(cond, then_branch, else_branch) => { - match *cond { - Expr::Binary(left, BinOp::Gt, right) => { - assert!(matches!(*left, Expr::Var(v) if v == "x")); - assert!(matches!(*right, Expr::Var(v) if v == "y")); - } - _ => panic!("Expected comparison expression"), - } - assert!(matches!(*then_branch, Expr::Var(v) if v == "x")); - assert!(matches!(*else_branch, Expr::Var(v) if v == "y")); - } - _ => panic!("Expected if expression"), - } - } - - #[test] - fn test_parse_function_with_array_parameter() { - let input = "def sum(values: [Int64]): Int64 = 0"; - - let function = parse_function_from_str(input); - assert_eq!(function.name, "sum"); - assert_eq!(function.params.len(), 1); - - match &function.params[0].1 { - Type::Array(inner) => assert!(matches!(**inner, Type::Int64)), - _ => panic!("Expected array type parameter"), - } - } - - #[test] - fn test_parse_function_with_no_parameters() { - let input = "def get_zero(): Int64 = 0"; - - let function = parse_function_from_str(input); - assert_eq!(function.name, "get_zero"); - assert_eq!(function.params.len(), 0); - assert!(matches!(function.return_type, Type::Int64)); - match function.body { - Expr::Literal(Literal::Int64(n)) => assert_eq!(n, 0), - _ => panic!("Expected integer literal"), - } - } - - #[test] - fn test_parse_function_with_complex_return_type() { - let input = "def make_pair(x: Int64): (Int64, Int64) = (x, x)"; - - let function = parse_function_from_str(input); - match function.return_type { - Type::Tuple(types) => { - assert_eq!(types.len(), 2); - assert!(matches!(types[0], Type::Int64)); - assert!(matches!(types[1], Type::Int64)); - } - _ => panic!("Expected tuple return type"), - } - } -} diff --git a/optd-core/src/dsl/parser/grammar.pest b/optd-core/src/dsl/parser/grammar.pest deleted file mode 100644 index 4a415a9..0000000 --- a/optd-core/src/dsl/parser/grammar.pest +++ /dev/null @@ -1,239 +0,0 @@ -KEYWORD = { - "Bool" | - "Float64" | - "Int64" | - "Logical" | - "Map" | - "Props" | - "Scalar" | - "String" | - "case" | - "def" | - "derive" | - "else" | - "fail" | - "false" | - "if" | - "match" | - "then" | - "true" | - "val" -} - -// Whitespace and comments -WHITESPACE = _{ " " | "\t" | "\r" | "\n" } -COMMENT = _{ "//" ~ (!"\n" ~ ANY)* } - -// Basic identifiers and literals -identifier = @{ - !(KEYWORD ~ !(ASCII_ALPHANUMERIC | "_")) ~ ASCII_ALPHA ~ (ASCII_ALPHANUMERIC | "_")* -} -number = @{ "-"? ~ ASCII_DIGIT+ } -string = @{ "\"" ~ (!"\"" ~ ANY)* ~ "\"" } -boolean = { "true" | "false" } - -// Types -operator_type = { "Scalar" | "Logical" } -base_type = { "Int64" | "String" | "Bool" | "Float64" } -array_type = { "[" ~ type_expr ~ "]" } -map_type = { "Map" ~ "[" ~ type_expr ~ "," ~ type_expr ~ "]" } -tuple_type = { "(" ~ type_expr ~ "," ~ type_expr ~ ("," ~ type_expr)* ~ ")" } -function_type = { "(" ~ type_expr ~ ")" ~ "->" ~ type_expr } // Function composition type -type_expr = { - function_type - | operator_type - | base_type - | array_type - | map_type - | tuple_type -} - -// Annotations -props_annot = { "Logical" ~ "Props" } -rule_annot = { "@rule" ~ "(" ~ operator_type ~ ")" } - -// Operators -operator_def = { - operator_type ~ identifier ~ NEWLINE? - ~ "(" ~ NEWLINE? - ~ (field_def_list ~ NEWLINE?)? - ~ ")" ~ NEWLINE? - ~ (derive_props_block ~ NEWLINE?)? -} - -field_def_list = { - (field_def ~ "," ~ NEWLINE?)* ~ field_def -} - -field_def = { identifier ~ ":" ~ type_expr } - -// Property blocks -props_block = { - props_annot ~ NEWLINE? - ~ "(" ~ NEWLINE? - ~ (field_def ~ "," ~ NEWLINE?)* ~ field_def - ~ ")" -} - -derive_props_block = { - "derive" ~ NEWLINE? - ~ "{" ~ NEWLINE? - ~ (prop_derivation ~ ("," ~ NEWLINE? ~ prop_derivation)*)? ~ NEWLINE? - ~ "}" -} - -prop_derivation = { identifier ~ "=" ~ expr } - -// Functions -function_def = { normal_function | rule_def } -normal_function = _{ - "def" ~ identifier ~ "(" ~ params? ~ ")" ~ ":" ~ type_expr ~ "=" ~ NEWLINE? - ~ (expr | "{" ~ NEWLINE? ~ expr ~ NEWLINE? ~ "}") // Braces optional -} - -rule_def = _{ - rule_annot ~ NEWLINE? - ~ normal_function -} - -params = { param ~ ("," ~ param)* } -param = { identifier ~ ":" ~ type_expr } - -// Pattern matching -pattern = { - (identifier ~ "@" ~ top_level_pattern) - | top_level_pattern -} - -top_level_pattern = _{ - constructor_pattern - | literal_pattern - | wildcard_pattern -} - -constructor_pattern = { - identifier ~ "(" ~ constructor_pattern_fields? ~ ")" -} - -constructor_pattern_fields = { - pattern_field ~ ("," ~ pattern_field)* -} - -// Inside a constructor -pattern_field = { - (identifier ~ ":" ~ top_level_pattern) // Bind & match - | top_level_pattern // Recurse match - | identifier // Simple bind -} - -literal_pattern = { - number - | string - | boolean - | array_literal - | tuple_literal -} - -wildcard_pattern = { "_" } - -// Expressions with precedence (lowest to highest) -expr = { - "{" ~ NEWLINE? ~ (closure | logical_or) ~ NEWLINE? ~ "}" - | (closure | logical_or) -} - -// Operators as separate rules -or_op = { "||" } -and_op = { "&&" } -compare_op = { "==" | "!=" | ">=" | "<=" | ">" | "<" } -concat_op = { "++" } -add_op = { "+" | "-" } -range_op = { ".." } -mult_op = { "*" | "/" } - -// Closure is lowest precedence -closure = { closure_params ~ "=>" ~ expr } -closure_params = { - "(" ~ (identifier ~ ("," ~ identifier)*)? ~ ")" - | identifier // Single param case -} - -logical_or = { logical_and ~ (or_op ~ logical_and)* } -logical_and = { comparison ~ (and_op ~ comparison)* } -comparison = { concatenation ~ (compare_op ~ concatenation)* } -concatenation = { additive ~ (concat_op ~ additive)* } -additive = { range ~ (add_op ~ range)* } -range = { multiplicative ~ (range_op ~ multiplicative)* } -multiplicative = { postfix ~ (mult_op ~ postfix)* } -postfix = { prefix ~ (member_call | member_access | call | array_index)* } -prefix = { ("!" | "-")? ~ primary } - -primary = _{ - match_expr - | if_expr - | val_expr - | array_literal - | tuple_literal - | constructor_expr - | term -} - -term = _{ - number - | string - | boolean - | identifier - | "(" ~ expr ~ ")" - | fail_expr -} - -// Constructor syntax -constructor_expr = { - identifier ~ "(" ~ constructor_fields? ~ ")" -} - -constructor_fields = _{ - expr ~ ("," ~ expr)* -} - -// Method calls and operations -call = { "(" ~ (expr ~ ("," ~ expr)*)? ~ ")" } -member_access = { "." ~ identifier } -member_call = { "." ~ identifier ~ "(" ~ (expr ~ ("," ~ expr)*)? ~ ")"} -array_index = { "[" ~ expr ~ "]" } - -// Array and map literals -array_literal = { "[" ~ (expr ~ ("," ~ expr)*)? ~ "]" } -tuple_literal = { "(" ~ (expr ~ ("," ~ expr)*)? ~ ")" } - -// Control expressions -fail_expr = { "fail" ~ "(" ~ string ~ ")" } - -// Braces optional -match_expr = { - "match" ~ expr ~ NEWLINE? - ~ (("{" ~ NEWLINE? - ~ (match_arm ~ ("," ~ NEWLINE? ~ match_arm)*)? - ~ "}") - | (match_arm ~ ("," ~ NEWLINE? ~ match_arm)*)) - ~ NEWLINE? -} -match_arm = { "case" ~ pattern ~ "=>" ~ expr } - -// Braces optional -if_expr = { - "if" ~ expr ~ NEWLINE? - ~ (("then" ~ expr) | "{" ~ NEWLINE? ~ expr ~ NEWLINE? ~ "}") - ~ "else" ~ NEWLINE? - ~ (expr | "{" ~ NEWLINE? ~ expr ~ NEWLINE? ~ "}") -} - -val_expr = { "val" ~ identifier ~ "=" ~ expr ~ ";" ~ expr } - -// Full file -file = { - SOI ~ - props_block ~ - (operator_def | function_def)* ~ - EOI -} \ No newline at end of file diff --git a/optd-core/src/dsl/parser/mod.rs b/optd-core/src/dsl/parser/mod.rs deleted file mode 100644 index 2dc23fd..0000000 --- a/optd-core/src/dsl/parser/mod.rs +++ /dev/null @@ -1,216 +0,0 @@ -use pest::{iterators::Pair, Parser}; -use pest_derive::Parser; - -pub mod expr; -pub mod functions; -pub mod operators; -pub mod patterns; -pub mod types; - -use functions::parse_function_def; -use operators::parse_operator_def; -use types::parse_type_expr; - -use pest::error::Error; - -use super::ast::upper_layer::{Field, File, Properties}; - -/// The main parser for the DSL, derived using pest -#[derive(Parser)] -#[grammar = "dsl/parser/grammar.pest"] -pub struct DslParser; - -/// Parses a complete DSL file into the AST -/// -/// # Arguments -/// * `input` - The input string containing the DSL code -/// -/// # Returns -/// * `Result>` - The parsed AST or a parsing error -pub fn parse_file(input: &str) -> Result>> { - let pairs = DslParser::parse(Rule::file, input)? - .next() - .unwrap() - .into_inner(); - - let mut file = File { - properties: Properties { fields: Vec::new() }, - operators: Vec::new(), - functions: Vec::new(), - }; - - for pair in pairs { - match pair.as_rule() { - Rule::props_block => { - file.properties = parse_properties_block(pair); - } - Rule::operator_def => { - file.operators.push(parse_operator_def(pair)); - } - Rule::function_def => { - file.functions.push(parse_function_def(pair)); - } - _ => {} - } - } - - Ok(file) -} - -/// Parses a properties block in the DSL -/// -/// # Arguments -/// * `pair` - The pest Pair containing the properties block -/// -/// # Returns -/// * `Properties` - The parsed properties structure -fn parse_properties_block(pair: Pair) -> Properties { - let mut properties = Properties { fields: Vec::new() }; - - for field_pair in pair.into_inner() { - if field_pair.as_rule() == Rule::field_def { - properties.fields.push(parse_field_def(field_pair)); - } - } - - properties -} - -/// Parses a field definition -/// -/// # Arguments -/// * `pair` - The pest Pair containing the field definition -/// -/// # Returns -/// * `Field` - The parsed field structure -pub(crate) fn parse_field_def(pair: Pair) -> Field { - let mut name = None; - let mut ty = None; - - for inner_pair in pair.into_inner() { - match inner_pair.as_rule() { - Rule::identifier => { - name = Some(inner_pair.as_str().to_string()); - } - Rule::type_expr => { - ty = Some(parse_type_expr(inner_pair)); - } - _ => {} - } - } - - Field { - name: name.unwrap(), - ty: ty.unwrap(), - } -} - -#[cfg(test)] -mod tests { - use crate::dsl::ast::upper_layer::Type; - - use super::*; - use pest::Parser; - - fn parse_props_from_str(input: &str) -> Properties { - let pair = DslParser::parse(Rule::props_block, input) - .unwrap() - .next() - .unwrap(); - parse_properties_block(pair) - } - - #[test] - fn parse_example_files() { - let input = include_str!("../programs/example.optd"); - parse_file(input).unwrap(); - - let input = include_str!("../programs/working.optd"); - let out = parse_file(input).unwrap(); - println!("{:#?}", out); - } - - #[test] - fn test_parse_single_property() { - let input = r#"Logical Props ( - name: String - )"#; - - let props = parse_props_from_str(input); - assert_eq!(props.fields.len(), 1); - assert_eq!(props.fields[0].name, "name"); - assert!(matches!(props.fields[0].ty, Type::String)); - } - - #[test] - fn test_parse_multiple_properties() { - let input = r#"Logical Props( - name: String, - value: Int64 - )"#; - - let props = parse_props_from_str(input); - assert_eq!(props.fields.len(), 2); - assert_eq!(props.fields[0].name, "name"); - assert!(matches!(props.fields[0].ty, Type::String)); - assert_eq!(props.fields[1].name, "value"); - assert!(matches!(props.fields[1].ty, Type::Int64)); - } - - #[test] - fn test_parse_array_property() { - let input = r#"Logical Props ( - values: [Int64] - )"#; - - let props = parse_props_from_str(input); - assert_eq!(props.fields.len(), 1); - assert_eq!(props.fields[0].name, "values"); - match &props.fields[0].ty { - Type::Array(inner) => assert!(matches!(**inner, Type::Int64)), - _ => panic!("Expected array type"), - } - } - - #[test] - fn test_parse_complex_property_types() { - let input = r#"Logical Props ( - pairs: Map[String, Int64], - coords: (Int64, Int64) - )"#; - - let props = parse_props_from_str(input); - assert_eq!(props.fields.len(), 2); - - // Check Map type - assert_eq!(props.fields[0].name, "pairs"); - match &props.fields[0].ty { - Type::Map(key, value) => { - assert!(matches!(**key, Type::String)); - assert!(matches!(**value, Type::Int64)); - } - _ => panic!("Expected map type"), - } - - // Check Tuple type - assert_eq!(props.fields[1].name, "coords"); - match &props.fields[1].ty { - Type::Tuple(types) => { - assert_eq!(types.len(), 2); - assert!(matches!(types[0], Type::Int64)); - assert!(matches!(types[1], Type::Int64)); - } - _ => panic!("Expected tuple type"), - } - } - - #[test] - #[should_panic] - fn test_parse_invalid_property() { - let input = r#"Logical Props ( - invalid - )"#; - - parse_props_from_str(input); - } -} diff --git a/optd-core/src/dsl/parser/operators.rs b/optd-core/src/dsl/parser/operators.rs deleted file mode 100644 index f8d2ee8..0000000 --- a/optd-core/src/dsl/parser/operators.rs +++ /dev/null @@ -1,238 +0,0 @@ -use pest::iterators::Pair; -use std::collections::HashMap; - -use super::{expr::parse_expr, parse_field_def, Rule}; -use crate::dsl::ast::upper_layer::{Expr, Field, LogicalOp, Operator, OperatorKind, ScalarOp}; - -/// Parse an operator definition from a pest Pair -/// -/// # Arguments -/// * `pair` - The pest Pair containing the operator definition -/// -/// # Returns -/// * `Operator` - The parsed operator AST node -pub fn parse_operator_def(pair: Pair<'_, Rule>) -> Operator { - let mut operator_type = None; - let mut name = None; - let mut fields = Vec::new(); - let mut derived_props = HashMap::new(); - - for inner_pair in pair.into_inner() { - match inner_pair.as_rule() { - Rule::operator_type => { - operator_type = Some(parse_operator_type(inner_pair)); - } - Rule::identifier => { - name = Some(inner_pair.as_str().to_string()); - } - Rule::field_def_list => { - fields = parse_field_def_list(inner_pair); - } - Rule::derive_props_block => { - derived_props = parse_derive_props_block(inner_pair); - } - _ => unreachable!( - "Unexpected operator definition rule: {:?}", - inner_pair.as_rule() - ), - } - } - - create_operator(operator_type.unwrap(), name.unwrap(), fields, derived_props) -} - -/// Parse an operator type (Scalar or Logical) -fn parse_operator_type(pair: Pair<'_, Rule>) -> OperatorKind { - match pair.as_str() { - "Scalar" => OperatorKind::Scalar, - "Logical" => OperatorKind::Logical, - _ => unreachable!("Unexpected operator type: {}", pair.as_str()), - } -} - -/// Parse a list of field definitions -fn parse_field_def_list(pair: Pair<'_, Rule>) -> Vec { - pair.into_inner().map(parse_field_def).collect() -} - -/// Parse a derive properties block -fn parse_derive_props_block(pair: Pair<'_, Rule>) -> HashMap { - let mut props = HashMap::new(); - - for prop_pair in pair.into_inner() { - if prop_pair.as_rule() == Rule::prop_derivation { - let (name, expr) = parse_prop_derivation(prop_pair); - props.insert(name, expr); - } - } - - props -} - -/// Parse a single property derivation -fn parse_prop_derivation(pair: Pair<'_, Rule>) -> (String, Expr) { - let mut prop_name = None; - let mut expr = None; - - for inner_pair in pair.into_inner() { - match inner_pair.as_rule() { - Rule::identifier => { - prop_name = Some(inner_pair.as_str().to_string()); - } - Rule::expr => { - expr = Some(parse_expr(inner_pair)); - } - _ => unreachable!( - "Unexpected property derivation rule: {:?}", - inner_pair.as_rule() - ), - } - } - - (prop_name.unwrap(), expr.unwrap()) -} - -/// Create an operator based on its type and components -fn create_operator( - operator_type: OperatorKind, - name: String, - fields: Vec, - derived_props: HashMap, -) -> Operator { - match operator_type { - OperatorKind::Scalar => Operator::Scalar(ScalarOp { name, fields }), - OperatorKind::Logical => Operator::Logical(LogicalOp { - name, - fields, - derived_props, - }), - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::dsl::ast::upper_layer::{BinOp, Type}; - use crate::dsl::parser::DslParser; - use pest::Parser; - - fn parse_operator_from_str(input: &str) -> Operator { - let pair = DslParser::parse(Rule::operator_def, input) - .unwrap() - .next() - .unwrap(); - parse_operator_def(pair) - } - - #[test] - fn test_parse_scalar_operator() { - let input = r#"Scalar Add( - x: Int64, - y: Int64 - ) - "#; - - match parse_operator_from_str(input) { - Operator::Scalar(op) => { - assert_eq!(op.name, "Add"); - assert_eq!(op.fields.len(), 2); - assert_eq!(op.fields[0].name, "x"); - assert!(matches!(op.fields[0].ty, Type::Int64)); - assert_eq!(op.fields[1].name, "y"); - assert!(matches!(op.fields[1].ty, Type::Int64)); - } - _ => panic!("Expected scalar operator"), - } - } - - #[test] - fn test_parse_logical_operator() { - let input = r#"Logical And( - left: Bool, - right: Bool - ) derive { - result = left && right - } - "#; - - match parse_operator_from_str(input) { - Operator::Logical(op) => { - assert_eq!(op.name, "And"); - assert_eq!(op.fields.len(), 2); - assert_eq!(op.fields[0].name, "left"); - assert!(matches!(op.fields[0].ty, Type::Bool)); - assert_eq!(op.fields[1].name, "right"); - assert!(matches!(op.fields[1].ty, Type::Bool)); - - assert_eq!(op.derived_props.len(), 1); - assert!(op.derived_props.contains_key("result")); - - match op.derived_props.get("result").unwrap() { - Expr::Binary(left, BinOp::And, right) => { - assert!(matches!(**left, Expr::Var(ref v) if v == "left")); - assert!(matches!(**right, Expr::Var(ref v) if v == "right")); - } - _ => panic!("Expected binary AND expression"), - } - } - _ => panic!("Expected logical operator"), - } - } - - #[test] - fn test_parse_operator_with_multiple_derived_props() { - let input = r#"Logical Or( - left: Bool, - right: Bool - ) derive { - result = left || right, - description = "Logical OR" - } - "#; - - match parse_operator_from_str(input) { - Operator::Logical(op) => { - assert_eq!(op.name, "Or"); - assert_eq!(op.derived_props.len(), 2); - assert!(op.derived_props.contains_key("result")); - assert!(op.derived_props.contains_key("description")); - } - _ => panic!("Expected logical operator"), - } - } - - #[test] - fn test_parse_operator_without_fields() { - let input = r#"Scalar True() - "#; - - match parse_operator_from_str(input) { - Operator::Scalar(op) => { - assert_eq!(op.name, "True"); - assert_eq!(op.fields.len(), 0); - } - _ => panic!("Expected scalar operator"), - } - } - - #[test] - fn test_parse_operator_with_array_field() { - let input = r#"Scalar Sum( - values: [Int64] - ) - "#; - - match parse_operator_from_str(input) { - Operator::Scalar(op) => { - assert_eq!(op.name, "Sum"); - assert_eq!(op.fields.len(), 1); - assert_eq!(op.fields[0].name, "values"); - match &op.fields[0].ty { - Type::Array(inner) => assert!(matches!(**inner, Type::Int64)), - _ => panic!("Expected array type"), - } - } - _ => panic!("Expected scalar operator"), - } - } -} diff --git a/optd-core/src/dsl/parser/patterns.rs b/optd-core/src/dsl/parser/patterns.rs deleted file mode 100644 index 857093a..0000000 --- a/optd-core/src/dsl/parser/patterns.rs +++ /dev/null @@ -1,202 +0,0 @@ -use pest::iterators::Pair; - -use super::{expr::parse_expr, Rule}; -use crate::dsl::ast::upper_layer::{Literal, Pattern}; - -/// Parse a pattern from a pest Pair -/// -/// # Arguments -/// * `pair` - The pest Pair containing the pattern -/// -/// # Returns -/// * `Pattern` - The parsed pattern AST node -pub fn parse_pattern(pair: Pair<'_, Rule>) -> Pattern { - match pair.as_rule() { - Rule::pattern => { - let mut pairs = pair.into_inner(); - let first = pairs.next().unwrap(); - - if first.as_rule() == Rule::identifier { - if let Some(at_pattern) = pairs.next() { - // This is a binding pattern (x @ pattern) - let name = first.as_str().to_string(); - let pattern = parse_pattern(at_pattern); - Pattern::Bind(name, Box::new(pattern)) - } else { - // This is a simple variable pattern - Pattern::Var(first.as_str().to_string()) - } - } else { - parse_pattern(first) - } - } - Rule::constructor_pattern => parse_constructor_pattern(pair), - Rule::literal_pattern => parse_literal_pattern(pair), - Rule::wildcard_pattern => Pattern::Wildcard, - _ => unreachable!("Unexpected pattern rule: {:?}", pair.as_rule()), - } -} - -/// Parse a constructor pattern (e.g., Some(x), Node(left, right)) -fn parse_constructor_pattern(pair: Pair<'_, Rule>) -> Pattern { - let mut pairs = pair.into_inner(); - let name = pairs.next().unwrap().as_str().to_string(); - let mut subpatterns = Vec::new(); - - if let Some(fields) = pairs.next() { - for field in fields.into_inner() { - if field.as_rule() == Rule::pattern_field { - let mut field_pairs = field.into_inner(); - let first = field_pairs.next().unwrap(); - - if first.as_rule() == Rule::identifier { - // Check for field binding (identifier : pattern) - if let Some(colon_pattern) = field_pairs.next() { - let field_name = first.as_str().to_string(); - let pattern = parse_pattern(colon_pattern); - subpatterns.push(Pattern::Bind(field_name, Box::new(pattern))); - } else { - // Simple identifier pattern - subpatterns.push(Pattern::Var(first.as_str().to_string())); - } - } else { - // Regular pattern - subpatterns.push(parse_pattern(first)); - } - } - } - } - - Pattern::Constructor(name, subpatterns) -} - -/// Parse a literal pattern (numbers, strings, booleans, arrays, tuples) -fn parse_literal_pattern(pair: Pair<'_, Rule>) -> Pattern { - let literal = parse_literal(pair.into_inner().next().unwrap()); - Pattern::Literal(literal) -} - -/// Parse a literal value -fn parse_literal(pair: Pair<'_, Rule>) -> Literal { - match pair.as_rule() { - Rule::number => Literal::Int64(pair.as_str().parse().unwrap()), - Rule::string => Literal::String(pair.as_str().to_string()), - Rule::boolean => Literal::Bool(pair.as_str().parse().unwrap()), - Rule::array_literal => { - let exprs = pair.into_inner().map(parse_expr).collect(); - Literal::Array(exprs) - } - Rule::tuple_literal => { - let exprs = pair.into_inner().map(parse_expr).collect(); - Literal::Tuple(exprs) - } - _ => unreachable!("Unexpected literal rule: {:?}", pair.as_rule()), - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::dsl::parser::DslParser; - use pest::Parser; - - fn parse_pattern_from_str(input: &str) -> Pattern { - let pair = DslParser::parse(Rule::pattern, input) - .unwrap() - .next() - .unwrap(); - parse_pattern(pair) - } - - #[test] - fn test_parse_constructor_pattern() { - match parse_pattern_from_str("Some(x)") { - Pattern::Constructor(name, patterns) => { - assert_eq!(name, "Some"); - assert_eq!(patterns.len(), 1); - match &patterns[0] { - Pattern::Var(var_name) => assert_eq!(var_name, "x"), - _ => panic!("Expected variable pattern inside constructor"), - } - } - _ => panic!("Expected constructor pattern"), - } - } - - #[test] - fn test_parse_literal_pattern() { - match parse_pattern_from_str("42") { - Pattern::Literal(Literal::Int64(n)) => assert_eq!(n, 42), - _ => panic!("Expected integer literal pattern"), - } - - match parse_pattern_from_str("\"hello\"") { - Pattern::Literal(Literal::String(s)) => assert_eq!(s, "\"hello\""), - _ => panic!("Expected string literal pattern"), - } - - match parse_pattern_from_str("true") { - Pattern::Literal(Literal::Bool(b)) => assert!(b), - _ => panic!("Expected boolean literal pattern"), - } - } - - #[test] - fn test_parse_wildcard_pattern() { - assert!(matches!(parse_pattern_from_str("_"), Pattern::Wildcard)); - } - - #[test] - fn test_parse_binding_pattern() { - match parse_pattern_from_str("x @ Some(y)") { - Pattern::Bind(name, pattern) => { - assert_eq!(name, "x"); - match *pattern { - Pattern::Constructor(cname, patterns) => { - assert_eq!(cname, "Some"); - assert_eq!(patterns.len(), 1); - match &patterns[0] { - Pattern::Var(var_name) => assert_eq!(var_name, "y"), - _ => panic!("Expected variable pattern inside constructor"), - } - } - _ => panic!("Expected constructor pattern after binding"), - } - } - _ => panic!("Expected binding pattern"), - } - } - - #[test] - fn test_parse_complex_constructor_pattern() { - match parse_pattern_from_str("Node(_, right: Some(x))") { - Pattern::Constructor(name, patterns) => { - assert_eq!(name, "Node"); - assert_eq!(patterns.len(), 2); - - // Check left pattern - match &patterns[0] { - Pattern::Wildcard => {} - _ => panic!("Expected wildcard pattern for left field"), - } - - // Check right pattern - match &patterns[1] { - Pattern::Bind(name, pattern) => { - assert_eq!(name, "right"); - match **pattern { - Pattern::Constructor(ref cname, ref inner_patterns) => { - assert_eq!(cname, "Some"); - assert_eq!(inner_patterns.len(), 1); - assert!(matches!(&inner_patterns[0], Pattern::Var(n) if n == "x")); - } - _ => panic!("Expected Some constructor in right pattern"), - } - } - _ => panic!("Expected binding pattern for right field"), - } - } - _ => panic!("Expected complex constructor pattern"), - } - } -} diff --git a/optd-core/src/dsl/parser/types.rs b/optd-core/src/dsl/parser/types.rs deleted file mode 100644 index 6b85d2f..0000000 --- a/optd-core/src/dsl/parser/types.rs +++ /dev/null @@ -1,166 +0,0 @@ -use pest::iterators::Pair; - -use crate::dsl::ast::upper_layer::{OperatorKind, Type}; - -use super::Rule; - -/// Parse a type expression from a pest Pair -/// -/// # Arguments -/// * `pair` - The pest Pair containing the type expression -/// -/// # Returns -/// * `Type` - The parsed type AST node -pub fn parse_type_expr(pair: Pair<'_, Rule>) -> Type { - match pair.as_rule() { - Rule::type_expr => parse_type_expr(pair.into_inner().next().unwrap()), - Rule::base_type => parse_base_type(pair), - Rule::array_type => parse_array_type(pair), - Rule::map_type => parse_map_type(pair), - Rule::tuple_type => parse_tuple_type(pair), - Rule::function_type => parse_function_type(pair), - Rule::operator_type => parse_operator_type(pair), - _ => unreachable!("Unexpected type rule: {:?}", pair.as_rule()), - } -} - -/// Parse a base type (Int64, String, Bool, Float64) -fn parse_base_type(pair: Pair<'_, Rule>) -> Type { - match pair.as_str() { - "Int64" => Type::Int64, - "String" => Type::String, - "Bool" => Type::Bool, - "Float64" => Type::Float64, - _ => unreachable!("Unexpected base type: {}", pair.as_str()), - } -} - -/// Parse an array type (e.g., [Int64]) -fn parse_array_type(pair: Pair<'_, Rule>) -> Type { - let inner_type = pair.into_inner().next().unwrap(); - Type::Array(Box::new(parse_type_expr(inner_type))) -} - -/// Parse a map type (e.g., Map[String, Int64]) -fn parse_map_type(pair: Pair<'_, Rule>) -> Type { - let mut inner_types = pair.into_inner(); - let key_type = parse_type_expr(inner_types.next().unwrap()); - let value_type = parse_type_expr(inner_types.next().unwrap()); - Type::Map(Box::new(key_type), Box::new(value_type)) -} - -/// Parse a tuple type (e.g., (Int64, String)) -fn parse_tuple_type(pair: Pair<'_, Rule>) -> Type { - let types = pair.into_inner().map(parse_type_expr).collect(); - Type::Tuple(types) -} - -/// Parse a function type (e.g., (Int64) -> String) -fn parse_function_type(pair: Pair<'_, Rule>) -> Type { - let mut inner_types = pair.into_inner(); - let input_type = parse_type_expr(inner_types.next().unwrap()); - let output_type = parse_type_expr(inner_types.next().unwrap()); - Type::Function(Box::new(input_type), Box::new(output_type)) -} - -/// Parse an operator type (Scalar or Logical) -fn parse_operator_type(pair: Pair<'_, Rule>) -> Type { - match pair.as_str() { - "Scalar" => Type::Operator(OperatorKind::Scalar), - "Logical" => Type::Operator(OperatorKind::Logical), - _ => unreachable!("Unexpected operator type: {}", pair.as_str()), - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::dsl::parser::DslParser; - use pest::Parser; - - fn parse_type_from_str(input: &str) -> Type { - let pair = DslParser::parse(Rule::type_expr, input) - .unwrap() - .next() - .unwrap(); - parse_type_expr(pair) - } - - #[test] - fn test_parse_base_types() { - assert!(matches!(parse_type_from_str("Int64"), Type::Int64)); - assert!(matches!(parse_type_from_str("String"), Type::String)); - assert!(matches!(parse_type_from_str("Bool"), Type::Bool)); - assert!(matches!(parse_type_from_str("Float64"), Type::Float64)); - } - - #[test] - fn test_parse_array_type() { - match parse_type_from_str("[Int64]") { - Type::Array(inner) => assert!(matches!(*inner, Type::Int64)), - _ => panic!("Expected array type"), - } - } - - #[test] - fn test_parse_map_type() { - match parse_type_from_str("Map[String, Int64]") { - Type::Map(key, value) => { - assert!(matches!(*key, Type::String)); - assert!(matches!(*value, Type::Int64)); - } - _ => panic!("Expected map type"), - } - } - - #[test] - fn test_parse_tuple_type() { - match parse_type_from_str("(Int64, String, Bool)") { - Type::Tuple(types) => { - assert_eq!(types.len(), 3); - assert!(matches!(types[0], Type::Int64)); - assert!(matches!(types[1], Type::String)); - assert!(matches!(types[2], Type::Bool)); - } - _ => panic!("Expected tuple type"), - } - } - - #[test] - fn test_parse_function_type() { - match parse_type_from_str("(Int64) -> String") { - Type::Function(input, output) => { - assert!(matches!(*input, Type::Int64)); - assert!(matches!(*output, Type::String)); - } - _ => panic!("Expected function type"), - } - } - - #[test] - fn test_parse_operator_type() { - match parse_type_from_str("Scalar") { - Type::Operator(kind) => assert!(matches!(kind, OperatorKind::Scalar)), - _ => panic!("Expected scalar operator type"), - } - - match parse_type_from_str("Logical") { - Type::Operator(kind) => assert!(matches!(kind, OperatorKind::Logical)), - _ => panic!("Expected logical operator type"), - } - } - - #[test] - fn test_parse_nested_types() { - match parse_type_from_str("Map[String, [Int64]]") { - Type::Map(key, value) => { - assert!(matches!(*key, Type::String)); - match *value { - Type::Array(inner) => assert!(matches!(*inner, Type::Int64)), - _ => panic!("Expected array type"), - } - } - _ => panic!("Expected map type"), - } - } -} diff --git a/optd-core/src/dsl/programs/example.optd b/optd-core/src/dsl/programs/example.optd deleted file mode 100644 index a07e4fe..0000000 --- a/optd-core/src/dsl/programs/example.optd +++ /dev/null @@ -1,148 +0,0 @@ -// Shared properties for Logical operators -Logical Props(vale: [Int64]) - -Scalar ColumnRef (idx: (Int64) -> Map[(Int64, (Int64) -> String), Int64]) -Scalar Const(asd: Int64) -Scalar Eq(left: Map[Int64, Int64]) -Scalar Divide(left: Scalar, right: Scalar) -Scalar Not(input: Scalar) - -// Logical Operators -Logical Filter(input: Logical, cond: Scalar) derive { - schema_len = input.schema_len -} - -Logical Project (input: Logical, exprs: [Scalar]) derive { - schema_len = exprs.len, - other = expr2 -} - -Logical Join( - left: Logical, - right: Logical, - type: String, - cond: Scalar -) derive { - schema_len = left.schema_len + right.schema_len -} - -Logical Sort(input: Logical, keys: [Scalar]) derive { - schema_len = input.schema_len -} - -Logical Aggregate(input: Logical, group_keys: [Scalar], aggs: [Scalar]) derive { - schema_len = group_keys.len + aggs.len -} - - @rule(Scalar) -def increment(x: Int64): Int64 = x + 1 - -// Rules demonstrating all language features -@rule(Scalar) -def constant_fold(expr: Scalar): Scalar = - match expr - case op @ Add(left: _, Const(bla)) => Const(x + y), - case op @ Multiply(left: Const(x), right: Const(y)) => { - Const(x * y) - }, - case And(ms) => { - val folded = ms.map(ms); - val test = (5 + 5) * 3 ++ 87..bla; - And(folded) - }, - case _ => expr - -def rewrite_column_refs(expr: Scalar, index_map: Map[Int64, Int64]): Scalar = - match expr - case ColumnRef(i) => ColumnRef(index_map.get(i)), - case _ => expr.with_children(child => rewrite_column_refs(child, index_map)) - -@rule(Scalar) -def has_refs_in_range(expr: Scalar, start: Int64, end: Int64): Bool = - match expr { - case ColumnRef(i) => i >= start && i < end, - case _ => expr.children().any(child => has_refs_in_range(child, start, end)) - } - -@rule(Logical) -def join_commute(expr: Logical): Logical = - match expr - case Join("Inner", left, right, cond) => { - val left_len = l.schema_len; - val right_len = r.schema_len; - - val refs_remap = - ((0..left_len).map(i => (i, i + right_len)) ++ - (0..right_len).map(i => (left_len + i, i))) - .to_map(); - - Join("Inner", r, l, rewrite_column_refs(c, refs_remap)) - } - -@rule(Logical) -def join_associate(expr: Logical): Logical = - match expr - case Join("Inner", l: Join("Inner", a, b, c1), c, c2) => - val a_len = a.schema_len; - val b_len = b.schema_len; - val c_len = c.schema_len; - if !has_refs_in_range(c2, 0, a_len) then - val inner_map = (a_len..(a_len + b_len + c_len)) - .map(i => (i, i - a_len)) - .to_map(); - Join("Inner", a, Join("Inner", b, c, rewrite_column_refs(c2, inner_map)), c1) - else - fail("Cannot rewrite: outer join condition references left relation") - -@rule(Logical) -def complex_project(expr: Logical): Logical = - match expr - case project @ Project(rel, es) => { - if es.len == 0 then - fail("Empty projection list") - else { - val mapped = es.map(e => constant_fold(e)); - val filtered = mapped.filter(e => - match e - case Const(_) => false, - case _ => true - ); - if filtered.len == 0 then - fail("All expressions folded to constants") - else - Project(rel, filtered) - } - } - -@rule(Scalar) -def simplify_not(expr: Scalar): Scalar = - match expr - case Not(Not(x)) => x, - case Not(Eq(l, r)) => { - val new_left = simplify_not(l); - val new_right = simplify_not(r); - Not(Eq(new_left, new_right)) - }, - case _ => expr - -@rule(Logical) -def push_filter_down(expr: Logical): Logical = - match expr - case Filter(Project(rel, es), cond) => { - val new_cond = rewrite_column_refs(cond, bla); - Project(Filter(rel, new_cond), es) - }, - case _ => expr - -def optimize_sort_keys(expr: Logical): Logical = - match expr - case Sort(rel, ks) => { - val simplified = ks.map(k => constant_fold(k)); - val filtered = simplified.filter(k => - match k - case Const(_) => false, - case _ => true - ); - Sort(rel, filtered) - }, - case _ => expr \ No newline at end of file diff --git a/optd-core/src/dsl/programs/working.optd b/optd-core/src/dsl/programs/working.optd deleted file mode 100644 index 56b421c..0000000 --- a/optd-core/src/dsl/programs/working.optd +++ /dev/null @@ -1,119 +0,0 @@ -// Logical Properties -Logical Props(schema_len: Int64) - -// Scalar Operators -Scalar ColumnRef(idx: Int64) -Scalar Mult(left: Int64, right: Int64) -Scalar Add(left: Int64, right: Int64) -Scalar And(children: [Scalar]) -Scalar Or(children: [Scalar]) -Scalar Not(child: Scalar) - -// Logical Operators -Logical Scan(table_name: String) derive { - schema_len = 5 // -} - -Logical Filter(child: Logical, cond: Scalar) derive { - schema_len = input.schema_len -} - -Logical Project(child: Logical, exprs: [Scalar]) derive { - schema_len = exprs.len -} - -Logical Join( - left: Logical, - right: Logical, - typ: String, - cond: Scalar -) derive { - schema_len = left.schema_len + right.schema_len -} - -Logical Sort(child: Logical, keys: [(Scalar, String)]) derive { - schema_len = input.schema_len -} - -Logical Aggregate(child: Logical, group_keys: [Scalar], aggs: [(Scalar, String)]) - derive { - schema_len = group_keys.len + aggs.len -} - -// Rules -def rewrite_column_refs(predicate: Scalar, map: Map[Int64, Int64]): Scalar = - match predicate - case ColumnRef(idx) => ColumnRef(map(idx)), - case other @ _ => predicate.apply_children(child => rewrite_column_refs(child, map)) - -@rule(Logical) -def join_commute(expr: Logical): Logical = - match expr - case Join("Inner", left, right, cond) => - val left_len = left.schema_len; - val right_len = right.schema_len; - - val right_indices = 0..right_len; - val left_indices = 0..left_len; - - val remapping = (left_indices.map(i => (i, i + right_len)) ++ - right_indices.map(i => (left_len + i, i))).to_map(); - - Project( - Join("Inner", right, left, rewrite_column_refs(cond, remapping)), - left_indices.map(i => ColumnRef(i + right_len)) ++ - right_indices.map(i => ColumnRef(i) - ) - ) - -def has_refs_in_range(predicate: Scalar, from: Int64, to: Int64): Bool = - match predicate - case ColumnRef(idx) => from <= idx && idx < to, - case _ => predicate.children.any(child => has_refs_in_range(child, from, to)) - - -@rule(Logical) -def join_associate(expr: Logical): Logical = - match expr - case op @ Join("Inner", Join("Inner", a, b, cond_inner), c, cond_outer) => - val a_len = a.schema_len; - if !has_refs_in_range(cond_outer, 0, a_len) then - val remap_inner = (a.schema_len..op.schema_len).map(i => (i, i - a_len)).to_map(); - Join( - "Inner", a, - Join("Inner", b, c, rewrite_column_refs(cond_outer, remap_inner), - cond_inner) - ) - else fail("") - -@rule(Scalar) -def conjunctive_normal_form(expr: Scalar): Scalar = fail("unimplemented") - -def with_optional_filter(key: String, old: Scalar, grouped: Map[String, [Scalar]]): Scalar = - match grouped(key) - case Some(conds) => Filter(conds, old), - case _ => old - -@rule(Logical) -def filter_pushdown_join(expr: Logical): Logical = - match expr - case op @ Filter(Join(join_type, left, right, join_cond), cond) => - val cnf = conjunctive_normal_form(cond); - val grouped = cnf.children.groupBy(cond => { - if has_refs_in_range(cond, 0, left.schema_len) && - !has_refs_in_range(cond, left.schema_len, op.schema_len) then - "left" - else if !has_refs_in_range(cond, 0, left.schema_len) && - has_refs_in_range(cond, left.schema_len, op.schema_len) then - "right" - else - "remain" - }); - - with_optional_filter("remain", - Join(join_type, - with_optional_filter("left", grouped), - with_optional_filter("right", grouped), - join_cond - ) - ) diff --git a/optd-core/src/lib.rs b/optd-core/src/lib.rs index 975eaba..494ff29 100644 --- a/optd-core/src/lib.rs +++ b/optd-core/src/lib.rs @@ -1,6 +1,5 @@ #[allow(dead_code)] pub mod cascades; -pub mod dsl; pub mod engine; pub mod operators; pub mod plans;