diff --git a/optd-core/src/dsl/analyzer/mod.rs b/optd-core/src/dsl/analyzer/mod.rs new file mode 100644 index 0000000..26f160d --- /dev/null +++ b/optd-core/src/dsl/analyzer/mod.rs @@ -0,0 +1 @@ +pub mod semantic; \ No newline at end of file diff --git a/optd-core/src/dsl/analyzer/semantic.rs b/optd-core/src/dsl/analyzer/semantic.rs new file mode 100644 index 0000000..f237cb2 --- /dev/null +++ b/optd-core/src/dsl/analyzer/semantic.rs @@ -0,0 +1,258 @@ +/*use std::collections::HashSet; + +use crate::dsl::parser::ast::{Expr, File, Function, Operator, Pattern, Properties, Type}; + +#[derive(Debug)] +pub struct SemanticAnalyzer { + logical_properties: HashSet, + operators: HashSet, + identifiers: Vec>, +} + +impl SemanticAnalyzer { + pub fn new() -> Self { + SemanticAnalyzer { + logical_properties: HashSet::new(), + operators: HashSet::new(), + identifiers: Vec::new(), + } + } + + fn enter_scope(&mut self) { + self.identifiers.push(HashSet::new()); + } + + fn exit_scope(&mut self) { + self.identifiers.pop(); + } + + fn add_identifier(&mut self, name: String) -> Result<(), String> { + if let Some(scope) = self.identifiers.last_mut() { + if scope.contains(&name) { + return Err(format!("Duplicate identifier name: {}", name)); + } + scope.insert(name); + } + Ok(()) + } + + fn lookup_identifier(&self, name: &str) -> bool { + self.identifiers + .iter() + .rev() + .any(|scope| scope.contains(name)) + } + + fn is_valid_scalar_type(&self, ty: &Type) -> bool { + match ty { + Type::Array(inner) => self.is_valid_scalar_type(inner), + Type::Int64 | Type::String | Type::Bool | Type::Float64 => true, + _ => false, + } + } + + fn is_valid_logical_type(&self, ty: &Type) -> bool { + match ty { + Type::Array(inner) => self.is_valid_logical_type(inner), + Type::Int64 | Type::String | Type::Bool | Type::Float64 => true, + _ => false, + } + } + + fn is_valid_property_type(&self, ty: &Type) -> bool { + match ty { + Type::Array(inner) => self.is_valid_property_type(inner), + Type::Tuple(fields) => fields.iter().all(|f| self.is_valid_property_type(f)), + Type::Map(a, b) => self.is_valid_property_type(a) && self.is_valid_property_type(b), + Type::Int64 | Type::String | Type::Bool | Type::Float64 => true, + Type::Function(_, _) => false, + Type::Operator(_) => false, + } + } + + fn validate_properties(&mut self, properties: &Properties) -> Result<(), String> { + for field in &properties.fields { + if !self.is_valid_property_type(&field.ty) { + return Err(format!("Invalid type in properties: {:?}", field.ty)); + } + } + + self.logical_properties = properties + .fields + .iter() + .map(|field| field.name.clone()) + .collect(); + + Ok(()) + } + + fn validate_operator(&mut self, operator: &Operator) -> Result<(), String> { + match operator { + Operator::Scalar(scalar_op) => { + if self.operators.contains(&scalar_op.name) { + return Err(format!("Duplicate operator name: {}", scalar_op.name)); + } + self.operators.insert(scalar_op.name.clone()); + + for field in &scalar_op.fields { + if !self.is_valid_scalar_type(&field.ty) { + return Err(format!("Invalid type in scalar operator: {:?}", field.ty)); + } + } + } + Operator::Logical(logical_op) => { + if self.operators.contains(&logical_op.name) { + return Err(format!("Duplicate operator name: {}", logical_op.name)); + } + self.operators.insert(logical_op.name.clone()); + + for field in &logical_op.fields { + if !self.is_valid_logical_type(&field.ty) { + return Err(format!("Invalid type in logical operator: {:?}", field.ty)); + } + } + + // Check that derived properties match the logical properties fields + for (prop_name, _) in &logical_op.derived_props { + if !self.operators.iter().any(|f| f == prop_name) { + return Err(format!( + "Derived property not found in logical properties: {}", + prop_name + )); + } + } + + // Check that all logical properties fields have corresponding derived properties + for field in &self.logical_properties { + if !logical_op.derived_props.contains_key(field) { + return Err(format!( + "Logical property field '{}' is missing a derived property", + field + )); + } + } + } + } + Ok(()) + } + + // Validate a function definition + fn validate_function(&mut self, function: &Function) -> Result<(), String> { + if self.function_names.contains(&function.name) { + return Err(format!("Duplicate function name: {}", function.name)); + } + self.function_names.insert(function.name.clone()); + + self.enter_scope(); + for (param_name, _) in &function.params { + self.add_identifier(param_name.clone())?; + } + self.validate_expr(&function.body)?; + self.exit_scope(); + + Ok(()) + } + + // Validate an expression + fn validate_expr(&mut self, expr: &Expr) -> Result<(), String> { + match expr { + Expr::Var(name) => { + if self.lookup_identifier(name) { + return Err(format!("Undefined identifier: {}", name)); + } + } + Expr::Val(name, expr1, expr2) => { + self.validate_expr(expr1)?; + self.add_identifier(name.clone())?; + self.validate_expr(expr2)?; + } + Expr::Match(expr, arms) => { + self.validate_expr(expr)?; + for arm in arms { + self.validate_pattern(&arm.pattern)?; + self.validate_expr(&arm.expr)?; + } + } + Expr::If(cond, then_expr, else_expr) => { + self.validate_expr(cond)?; + self.validate_expr(then_expr)?; + self.validate_expr(else_expr)?; + } + Expr::Binary(left, _, right) => { + self.validate_expr(left)?; + self.validate_expr(right)?; + } + Expr::Unary(_, expr) => { + self.validate_expr(expr)?; + } + Expr::Call(func, args) => { + self.validate_expr(func)?; + for arg in args { + self.validate_expr(arg)?; + } + } + Expr::Member(expr, _) => { + self.validate_expr(expr)?; + } + Expr::MemberCall(expr, _, args) => { + self.validate_expr(expr)?; + for arg in args { + self.validate_expr(arg)?; + } + } + Expr::ArrayIndex(array, index) => { + self.validate_expr(array)?; + self.validate_expr(index)?; + } + Expr::Literal(_) => {} + Expr::Fail(_) => {} + Expr::Closure(params, body) => { + self.enter_scope(); + for param in params { + self.add_identifier(param.clone())?; + } + self.validate_expr(body)?; + self.exit_scope(); + } + _ => {} + } + Ok(()) + } + + // Validate a pattern + fn validate_pattern(&mut self, pattern: &Pattern) -> Result<(), String> { + match pattern { + Pattern::Bind(name, pat) => { + self.add_identifier(name.clone())?; + self.validate_pattern(pat)?; + } + Pattern::Constructor(_, pats) => { + for pat in pats { + self.validate_pattern(pat)?; + } + } + Pattern::Literal(_) => {} + Pattern::Wildcard => {} + Pattern::Var(name) => { + self.add_identifier(name.clone())?; + } + } + Ok(()) + } + + // Validate a complete file + pub fn validate_file(&mut self, file: &File) -> Result<(), String> { + self.validate_properties(&file.properties)?; + + for operator in &file.operators { + self.validate_operator(operator)?; + } + + for function in &file.functions { + self.validate_function(function)?; + } + + Ok(()) + } +} +*/ \ No newline at end of file diff --git a/optd-core/src/dsl/mod.rs b/optd-core/src/dsl/mod.rs index 67c567f..451496b 100644 --- a/optd-core/src/dsl/mod.rs +++ b/optd-core/src/dsl/mod.rs @@ -1 +1,2 @@ +pub mod analyzer; pub mod parser; diff --git a/optd-core/src/dsl/parser/expr.rs b/optd-core/src/dsl/parser/expr.rs index 9a3af4f..4c85766 100644 --- a/optd-core/src/dsl/parser/expr.rs +++ b/optd-core/src/dsl/parser/expr.rs @@ -191,7 +191,18 @@ fn parse_constructor(pair: Pair<'_, Rule>) -> Expr { let mut pairs = pair.into_inner(); let name = pairs.next().unwrap().as_str().to_string(); let args = pairs.map(parse_expr).collect(); - Expr::Constructor(name, args) + + // Check if first character is lowercase + // TODO(alexis): small hack until I rewrite the grammar using Chumsky + if name + .chars() + .next() + .map_or(false, |c| c.is_ascii_lowercase()) + { + Expr::Call(Box::new(Expr::Var(name)), args) + } else { + Expr::Constructor(name, args) + } } /// Parse a numeric literal diff --git a/optd-core/src/dsl/parser/mod.rs b/optd-core/src/dsl/parser/mod.rs index 05b91c8..95da158 100644 --- a/optd-core/src/dsl/parser/mod.rs +++ b/optd-core/src/dsl/parser/mod.rs @@ -119,9 +119,13 @@ mod tests { } #[test] - fn parse_example_file() { + fn parse_example_files() { let input = include_str!("../parser/programs/example.optd"); parse_file(input).unwrap(); + + let input = include_str!("../parser/programs/working.optd"); + let out = parse_file(input).unwrap(); + println!("{:#?}", out); } #[test] diff --git a/optd-core/src/dsl/parser/programs/example.optd b/optd-core/src/dsl/parser/programs/example.optd index 71828e9..a07e4fe 100644 --- a/optd-core/src/dsl/parser/programs/example.optd +++ b/optd-core/src/dsl/parser/programs/example.optd @@ -1,8 +1,9 @@ // Shared properties for Logical operators Logical Props(vale: [Int64]) + Scalar ColumnRef (idx: (Int64) -> Map[(Int64, (Int64) -> String), Int64]) Scalar Const(asd: Int64) -Scalar Eq(left: [Int64]) +Scalar Eq(left: Map[Int64, Int64]) Scalar Divide(left: Scalar, right: Scalar) Scalar Not(input: Scalar) diff --git a/optd-core/src/dsl/parser/programs/working.optd b/optd-core/src/dsl/parser/programs/working.optd new file mode 100644 index 0000000..d20fd8a --- /dev/null +++ b/optd-core/src/dsl/parser/programs/working.optd @@ -0,0 +1,123 @@ +// Logical Properties +Logical Props(schema_len: Int64) + +// Scalar Operators +Scalar ColumnRef(idx: Int64) +Scalar Mult(left: Int64, right: Int64) +Scalar Add(left: Int64, right: Int64) +Scalar And(children: [Scalar]) +Scalar Or(children: [Scalar]) +Scalar Not(child: Scalar) + +// Logical Operators +Logical Scan(table_name: String) derive { + schema_len = 5 // +} + +Logical Filter(child: Logical, cond: Scalar) derive { + schema_len = input.schema_len +} + + +Logical Project(child: Logical, exprs: [Scalar]) derive { + schema_len = exprs.len +} + +Logical Join( + left: Logical, + right: Logical, + typ: String, + cond: Scalar +) derive { + schema_len = left.schema_len + right.schema_len +} + +// memo(alexis): tuples should be allowed in here +Logical Sort(child: Logical, keys: [(Scalar, String)]) derive { + schema_len = input.schema_len +} + +Logical Aggregate(child: Logical, group_keys: [Scalar], aggs: [(Scalar, String)]) + derive { + schema_len = group_keys.len + aggs.len +} + +// Rules + +// memo(alexis): support member function for operators like apply_children +def rewrite_column_refs(predicate: Scalar, map: Map[Int64, Int64]): Scalar = + match predicate + case ColumnRef(idx) => ColumnRef(map(idx)), + case other @ _ => predicate.apply_children(child => rewrite_column_refs(child, map)) + +@rule(Logical) +def join_commute(expr: Logical): Logical = + match expr + case Join("Inner", left, right, cond) => + val left_len = left.schema_len; + val right_len = right.schema_len; + + val right_indices = 0..right_len; + val left_indices = 0..left_len; + + val remapping = (left_indices.map(i => (i, i + right_len)) ++ + right_indices.map(i => (left_len + i, i))).to_map(); + + Project( + Join("Inner", right, left, rewrite_column_refs(cond, remapping)), + left_indices.map(i => ColumnRef(i + right_len)) ++ + right_indices.map(i => ColumnRef(i) + ) + ) + +def has_refs_in_range(cond: Scalar, from: Int64, to: Int64): Bool = + match predicate + case ColumnRef(idx) => from <= idx && idx < to, + case _ => predicate.children.any(child => has_refs_in_range(child, from, to)) + + +@rule(Logical) +def join_associate(expr: Logical): Logical = + match expr + case op @ Join("Inner", Join("Inner", a, b, cond_inner), c, cond_outer) => + val a_len = a.schema_len; + if !has_refs_in_range(cond_outer, 0, a_len) then + val remap_inner = (a.schema_len..op.schema_len).map(i => (i, i - a_len)).to_map(); + Join( + "Inner", a, + Join("Inner", b, c, rewrite_column_refs(cond_outer, remap_inner), + cond_inner) + ) + else fail("") + +def conjunctive_normal_form(expr: Scalar): Bool = fail("unimplemented") + +def with_optional_filter(key: String, old: Scalar, grouped: Map[String, [Scalar]]): Scalar = match grouped(key) + case Some(conds) => Filter(conds, old), + case _ => old + +@rule(Logical) +def filter_pushdown_join(expr: Logical): Logical = + match expr + case op @ Filter(Join(join_type, left, right, join_cond), cond) => + val conditions_with_refs = conjunctive_normal_form(cond) + .map(c => (c, extract_column_refs(c))); + + val grouped = conditions_with_refs.groupBy((cond, refs) => { + if has_refs_in_range(refs, 0, left.schema_len) && + !has_refs_in_range(refs, left.schema_len, op.schema_len) then + "left" + else if !has_refs_in_range(refs, 0, left.schema_len) && + has_refs_in_range(refs, left.schema_len, op.schema_len) then + "right" + else + "remain" + }); + + with_optional_filter("remain", + Join(join_type, + with_optional_filter("left", grouped), + with_optional_filter("right", grouped), + join_cond + ) + )