|
| 1 | +use pest::iterators::Pair; |
| 2 | + |
| 3 | +use super::{ |
| 4 | + ast::{BinOp, Expr, Literal, MatchArm, UnaryOp}, |
| 5 | + patterns::parse_pattern, |
| 6 | + Rule, |
| 7 | +}; |
| 8 | + |
| 9 | +/// Parse a complete expression from a pest Pair |
| 10 | +/// |
| 11 | +/// # Arguments |
| 12 | +/// * `pair` - The pest Pair containing the expression |
| 13 | +/// |
| 14 | +/// # Returns |
| 15 | +/// * `Expr` - The parsed expression AST node |
| 16 | +pub fn parse_expr(pair: Pair<'_, Rule>) -> Expr { |
| 17 | + match pair.as_rule() { |
| 18 | + Rule::expr => parse_expr(pair.into_inner().next().unwrap()), |
| 19 | + Rule::closure => parse_closure(pair), |
| 20 | + Rule::logical_or |
| 21 | + | Rule::logical_and |
| 22 | + | Rule::comparison |
| 23 | + | Rule::concatenation |
| 24 | + | Rule::additive |
| 25 | + | Rule::range |
| 26 | + | Rule::multiplicative => parse_binary_operation(pair), |
| 27 | + Rule::postfix => parse_postfix(pair), |
| 28 | + Rule::prefix => parse_prefix(pair), |
| 29 | + Rule::match_expr => parse_match_expr(pair), |
| 30 | + Rule::if_expr => parse_if_expr(pair), |
| 31 | + Rule::val_expr => parse_val_expr(pair), |
| 32 | + Rule::array_literal => parse_array_literal(pair), |
| 33 | + Rule::tuple_literal => parse_tuple_literal(pair), |
| 34 | + Rule::constructor_expr => parse_constructor(pair), |
| 35 | + Rule::number => parse_number(pair), |
| 36 | + Rule::string => parse_string(pair), |
| 37 | + Rule::boolean => parse_boolean(pair), |
| 38 | + Rule::identifier => Expr::Var(pair.as_str().to_string()), |
| 39 | + Rule::fail_expr => parse_fail_expr(pair), |
| 40 | + _ => unreachable!("Unexpected expression rule: {:?}", pair.as_rule()), |
| 41 | + } |
| 42 | +} |
| 43 | + |
| 44 | +/// Parse a closure expression (e.g., "(x, y) => x + y") |
| 45 | +fn parse_closure(pair: Pair<'_, Rule>) -> Expr { |
| 46 | + let mut pairs = pair.into_inner(); |
| 47 | + let params_pair = pairs.next().unwrap(); |
| 48 | + |
| 49 | + let params = if params_pair.as_rule() == Rule::identifier { |
| 50 | + vec![params_pair.as_str().to_string()] |
| 51 | + } else { |
| 52 | + params_pair |
| 53 | + .into_inner() |
| 54 | + .map(|p| p.as_str().to_string()) |
| 55 | + .collect() |
| 56 | + }; |
| 57 | + |
| 58 | + let body = parse_expr(pairs.next().unwrap()); |
| 59 | + Expr::Closure(params, Box::new(body)) |
| 60 | +} |
| 61 | + |
| 62 | +/// Parse a binary operation with proper operator precedence |
| 63 | +fn parse_binary_operation(pair: Pair<'_, Rule>) -> Expr { |
| 64 | + let mut pairs = pair.into_inner(); |
| 65 | + let mut expr = parse_expr(pairs.next().unwrap()); |
| 66 | + |
| 67 | + while let Some(op_pair) = pairs.next() { |
| 68 | + let op = parse_binary_operator(op_pair); |
| 69 | + let rhs = parse_expr(pairs.next().unwrap()); |
| 70 | + expr = Expr::Binary(Box::new(expr), op, Box::new(rhs)); |
| 71 | + } |
| 72 | + |
| 73 | + expr |
| 74 | +} |
| 75 | + |
| 76 | +/// Parse a postfix expression (function calls, member access, array indexing) |
| 77 | +fn parse_postfix(pair: Pair<'_, Rule>) -> Expr { |
| 78 | + let mut pairs = pair.into_inner(); |
| 79 | + let mut expr = parse_expr(pairs.next().unwrap()); |
| 80 | + |
| 81 | + for postfix_pair in pairs { |
| 82 | + match postfix_pair.as_rule() { |
| 83 | + Rule::call => { |
| 84 | + let args = postfix_pair.into_inner().map(parse_expr).collect(); |
| 85 | + expr = Expr::Call(Box::new(expr), args); |
| 86 | + } |
| 87 | + Rule::member_access => { |
| 88 | + let member = postfix_pair.as_str().trim_start_matches('.').to_string(); |
| 89 | + expr = Expr::Member(Box::new(expr), member); |
| 90 | + } |
| 91 | + Rule::array_index => { |
| 92 | + let index = parse_expr(postfix_pair.into_inner().next().unwrap()); |
| 93 | + expr = Expr::ArrayIndex(Box::new(expr), Box::new(index)); |
| 94 | + } |
| 95 | + Rule::member_call => { |
| 96 | + let mut pairs = postfix_pair.into_inner(); |
| 97 | + let member = pairs |
| 98 | + .next() |
| 99 | + .unwrap() |
| 100 | + .as_str() |
| 101 | + .trim_start_matches('.') |
| 102 | + .to_string(); |
| 103 | + let args = pairs.map(parse_expr).collect(); |
| 104 | + expr = Expr::MemberCall(Box::new(expr), member, args); |
| 105 | + } |
| 106 | + _ => unreachable!("Unexpected postfix rule: {:?}", postfix_pair.as_rule()), |
| 107 | + } |
| 108 | + } |
| 109 | + expr |
| 110 | +} |
| 111 | + |
| 112 | +/// Parse a prefix expression (unary operators) |
| 113 | +fn parse_prefix(pair: Pair<'_, Rule>) -> Expr { |
| 114 | + let mut pairs = pair.into_inner(); |
| 115 | + let first = pairs.next().unwrap(); |
| 116 | + |
| 117 | + if first.as_str() == "!" || first.as_str() == "-" { |
| 118 | + let op = match first.as_str() { |
| 119 | + "-" => UnaryOp::Neg, |
| 120 | + "!" => UnaryOp::Not, |
| 121 | + _ => unreachable!("Unexpected prefix operator: {}", first.as_str()), |
| 122 | + }; |
| 123 | + let expr = parse_expr(pairs.next().unwrap()); |
| 124 | + Expr::Unary(op, Box::new(expr)) |
| 125 | + } else { |
| 126 | + parse_expr(first) |
| 127 | + } |
| 128 | +} |
| 129 | + |
| 130 | +/// Parse a match expression |
| 131 | +fn parse_match_expr(pair: Pair<'_, Rule>) -> Expr { |
| 132 | + let mut pairs = pair.into_inner(); |
| 133 | + let expr = parse_expr(pairs.next().unwrap()); |
| 134 | + let mut arms = Vec::new(); |
| 135 | + |
| 136 | + for arm_pair in pairs { |
| 137 | + if arm_pair.as_rule() == Rule::match_arm { |
| 138 | + arms.push(parse_match_arm(arm_pair)); |
| 139 | + } |
| 140 | + } |
| 141 | + |
| 142 | + Expr::Match(Box::new(expr), arms) |
| 143 | +} |
| 144 | + |
| 145 | +/// Parse a match arm within a match expression |
| 146 | +fn parse_match_arm(pair: Pair<'_, Rule>) -> MatchArm { |
| 147 | + let mut pairs = pair.into_inner(); |
| 148 | + let pattern = parse_pattern(pairs.next().unwrap()); |
| 149 | + let expr = parse_expr(pairs.next().unwrap()); |
| 150 | + MatchArm { pattern, expr } |
| 151 | +} |
| 152 | + |
| 153 | +/// Parse an if expression |
| 154 | +fn parse_if_expr(pair: Pair<'_, Rule>) -> Expr { |
| 155 | + let mut pairs = pair.into_inner(); |
| 156 | + let condition = parse_expr(pairs.next().unwrap()); |
| 157 | + let then_branch = parse_expr(pairs.next().unwrap()); |
| 158 | + let else_branch = parse_expr(pairs.next().unwrap()); |
| 159 | + |
| 160 | + Expr::If( |
| 161 | + Box::new(condition), |
| 162 | + Box::new(then_branch), |
| 163 | + Box::new(else_branch), |
| 164 | + ) |
| 165 | +} |
| 166 | + |
| 167 | +/// Parse a val expression (local binding) |
| 168 | +fn parse_val_expr(pair: Pair<'_, Rule>) -> Expr { |
| 169 | + let mut pairs = pair.into_inner(); |
| 170 | + let name = pairs.next().unwrap().as_str().to_string(); |
| 171 | + let value = parse_expr(pairs.next().unwrap()); |
| 172 | + let body = parse_expr(pairs.next().unwrap()); |
| 173 | + |
| 174 | + Expr::Val(name, Box::new(value), Box::new(body)) |
| 175 | +} |
| 176 | + |
| 177 | +/// Parse an array literal |
| 178 | +fn parse_array_literal(pair: Pair<'_, Rule>) -> Expr { |
| 179 | + let exprs = pair.into_inner().map(parse_expr).collect(); |
| 180 | + Expr::Array(exprs) |
| 181 | +} |
| 182 | + |
| 183 | +/// Parse a tuple literal |
| 184 | +fn parse_tuple_literal(pair: Pair<'_, Rule>) -> Expr { |
| 185 | + let exprs = pair.into_inner().map(parse_expr).collect(); |
| 186 | + Expr::Tuple(exprs) |
| 187 | +} |
| 188 | + |
| 189 | +/// Parse a constructor expression |
| 190 | +fn parse_constructor(pair: Pair<'_, Rule>) -> Expr { |
| 191 | + let mut pairs = pair.into_inner(); |
| 192 | + let name = pairs.next().unwrap().as_str().to_string(); |
| 193 | + let args = pairs.map(parse_expr).collect(); |
| 194 | + Expr::Constructor(name, args) |
| 195 | +} |
| 196 | + |
| 197 | +/// Parse a numeric literal |
| 198 | +fn parse_number(pair: Pair<'_, Rule>) -> Expr { |
| 199 | + let num = pair.as_str().parse().unwrap(); |
| 200 | + Expr::Literal(Literal::Int64(num)) |
| 201 | +} |
| 202 | + |
| 203 | +/// Parse a string literal |
| 204 | +fn parse_string(pair: Pair<'_, Rule>) -> Expr { |
| 205 | + let s = pair.as_str().to_string(); |
| 206 | + Expr::Literal(Literal::String(s)) |
| 207 | +} |
| 208 | + |
| 209 | +/// Parse a boolean literal |
| 210 | +fn parse_boolean(pair: Pair<'_, Rule>) -> Expr { |
| 211 | + let b = pair.as_str().parse().unwrap(); |
| 212 | + Expr::Literal(Literal::Bool(b)) |
| 213 | +} |
| 214 | + |
| 215 | +/// Parse a fail expression |
| 216 | +fn parse_fail_expr(pair: Pair<'_, Rule>) -> Expr { |
| 217 | + let msg = pair.into_inner().next().unwrap().as_str().to_string(); |
| 218 | + Expr::Fail(msg) |
| 219 | +} |
| 220 | + |
| 221 | +/// Parse a binary operator |
| 222 | +fn parse_binary_operator(pair: Pair<'_, Rule>) -> BinOp { |
| 223 | + let op_str = match pair.as_rule() { |
| 224 | + Rule::add_op |
| 225 | + | Rule::mult_op |
| 226 | + | Rule::or_op |
| 227 | + | Rule::and_op |
| 228 | + | Rule::compare_op |
| 229 | + | Rule::concat_op |
| 230 | + | Rule::range_op => pair.as_str(), |
| 231 | + _ => pair.as_str(), |
| 232 | + }; |
| 233 | + |
| 234 | + match op_str { |
| 235 | + "+" => BinOp::Add, |
| 236 | + "-" => BinOp::Sub, |
| 237 | + "*" => BinOp::Mul, |
| 238 | + "/" => BinOp::Div, |
| 239 | + "++" => BinOp::Concat, |
| 240 | + "==" => BinOp::Eq, |
| 241 | + "!=" => BinOp::Neq, |
| 242 | + ">" => BinOp::Gt, |
| 243 | + "<" => BinOp::Lt, |
| 244 | + ">=" => BinOp::Ge, |
| 245 | + "<=" => BinOp::Le, |
| 246 | + "&&" => BinOp::And, |
| 247 | + "||" => BinOp::Or, |
| 248 | + ".." => BinOp::Range, |
| 249 | + _ => unreachable!("Unexpected binary operator: {}", op_str), |
| 250 | + } |
| 251 | +} |
| 252 | + |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::parser::DslParser;
    use pest::Parser;

    /// Run the grammar's `expr` rule over `input` and convert the first
    /// resulting pair into an `Expr`. Panics if the input does not parse.
    fn parse_expr_from_str(input: &str) -> Expr {
        let pair = DslParser::parse(Rule::expr, input).unwrap().next().unwrap();
        parse_expr(pair)
    }

    #[test]
    fn test_parse_binary_operations() {
        // Multiplication must nest under addition: 1 + (2 * 3).
        let expr = parse_expr_from_str("1 + 2 * 3");
        match expr {
            Expr::Binary(left, BinOp::Add, right) => {
                assert!(matches!(*left, Expr::Literal(Literal::Int64(1))));
                match *right {
                    Expr::Binary(l, BinOp::Mul, r) => {
                        assert!(matches!(*l, Expr::Literal(Literal::Int64(2))));
                        assert!(matches!(*r, Expr::Literal(Literal::Int64(3))));
                    }
                    _ => panic!("Expected multiplication"),
                }
            }
            _ => panic!("Expected addition"),
        }
    }

    #[test]
    fn test_parse_if_expression() {
        let expr = parse_expr_from_str("if x > 0 then 1 else 2");
        match expr {
            Expr::If(cond, then_branch, else_branch) => {
                match *cond {
                    Expr::Binary(left, BinOp::Gt, right) => {
                        assert!(matches!(*left, Expr::Var(v) if v == "x"));
                        assert!(matches!(*right, Expr::Literal(Literal::Int64(0))));
                    }
                    _ => panic!("Expected comparison"),
                }
                assert!(matches!(*then_branch, Expr::Literal(Literal::Int64(1))));
                assert!(matches!(*else_branch, Expr::Literal(Literal::Int64(2))));
            }
            _ => panic!("Expected if expression"),
        }
    }

    #[test]
    fn test_parse_closure() {
        let expr = parse_expr_from_str("(x, y) => x + y");
        match expr {
            Expr::Closure(params, body) => {
                assert_eq!(params, vec!["x", "y"]);
                match *body {
                    Expr::Binary(left, BinOp::Add, right) => {
                        assert!(matches!(*left, Expr::Var(v) if v == "x"));
                        assert!(matches!(*right, Expr::Var(v) if v == "y"));
                    }
                    _ => panic!("Expected addition in closure body"),
                }
            }
            _ => panic!("Expected closure"),
        }
    }
}
0 commit comments