Skip to content

Commit

Permalink
Add missing files
Browse files Browse the repository at this point in the history
  • Loading branch information
AlSchlo committed Feb 5, 2025
1 parent 270d61a commit f24c089
Show file tree
Hide file tree
Showing 6 changed files with 451 additions and 0 deletions.
101 changes: 101 additions & 0 deletions optd-core/src/engine/interpreter/analyzers/interpreter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// PartialLogicalPlan + Transformation IR => PartialLogicalPlan

use std::collections::HashMap;

use crate::{
engine::patterns::{scalar::ScalarPattern, value::ValuePattern},
plans::scalar::PartialScalarPlan,
values::OptdValue,
};

use super::scalar::ScalarAnalyzer;

// TODO(Alexis): it is totally fair for analyzers to have transformer compostions actually. just their return type should differ.
// This is much more powerful. No reason not to do it.
pub struct Context {
pub value_bindings: HashMap<String, OptdValue>,
pub scalar_bindings: HashMap<String, PartialScalarPlan>,
}

pub fn scalar_analyze(
plan: PartialScalarPlan,
transformer: &ScalarAnalyzer,
) -> anyhow::Result<Option<OptdValue>> {
for matcher in transformer.matches.iter() {
let mut context = Context {
value_bindings: HashMap::new(),
scalar_bindings: HashMap::new(),
};
if match_scalar(&plan, &matcher.pattern, &mut context)? {
for (name, comp) in matcher.composition.iter() {
let value = scalar_analyze(context.scalar_bindings[name].clone(), &comp.borrow())?;
let Some(value) = value else {
return Ok(None);
};
context.value_bindings.insert(name.clone(), value);
}

return Ok(Some(matcher.output.evaluate(&context.value_bindings)));
}
}
Ok(None)
}

fn match_scalar(
plan: &PartialScalarPlan,
pattern: &ScalarPattern,
context: &mut Context,
) -> anyhow::Result<bool> {
match pattern {
ScalarPattern::Any => Ok(true),
ScalarPattern::Not(scalar_pattern) => {
let x = match_scalar(plan, scalar_pattern, context)?;
Ok(!x)
}
ScalarPattern::Bind(name, scalar_pattern) => {
context.scalar_bindings.insert(name.clone(), plan.clone());
match_scalar(plan, scalar_pattern, context)
}
ScalarPattern::Operator {
op_type,
content,
scalar_children,
} => {
let PartialScalarPlan::PartialMaterialized { operator } = plan else {
return Ok(false); //TODO: Call memo!!
};

if operator.operator_kind() != *op_type {
return Ok(false);
}

for (subpattern, subplan) in scalar_children
.iter()
.zip(operator.children_scalars().iter())
{
if !match_scalar(subplan, subpattern, context)? {
return Ok(false);
}
}

for (subpattern, value) in content.iter().zip(operator.values().iter()) {
if !match_value(value, subpattern, context) {
return Ok(false);
}
}

Ok(true)
}
}
}

fn match_value(value: &OptdValue, pattern: &ValuePattern, context: &mut Context) -> bool {
match pattern {
ValuePattern::Any => true,
ValuePattern::Bind(name, optd_expr) => {
context.value_bindings.insert(name.clone(), value.clone());
match_value(value, optd_expr, context)
}
ValuePattern::Match { expr } => expr.evaluate(&context.value_bindings) == *value,
}
}
2 changes: 2 additions & 0 deletions optd-core/src/engine/interpreter/expressions/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub mod plans;
pub mod values;
30 changes: 30 additions & 0 deletions optd-core/src/engine/interpreter/expressions/plans.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use std::collections::HashMap;

use crate::{plans::PartialPlanExpr, values::OptdValue};

/// Evaluates a PartialPlanExpr to an PartialPlan using provided bindings.
impl<Plan: Clone> PartialPlanExpr<Plan> {
pub fn interpret(
&self,
plan_bindings: &HashMap<String, Plan>,
value_bindings: &HashMap<String, OptdValue>,
) -> Plan {
match self {
PartialPlanExpr::Plan(plan) => plan.clone(),

PartialPlanExpr::Ref(name) => plan_bindings.get(name).cloned().unwrap_or_else(|| {
panic!("Undefined reference: {}", name);
}),

PartialPlanExpr::IfThenElse {
cond,
then,
otherwise,
} => match cond.interpret(value_bindings) {
OptdValue::Bool(true) => then.interpret(plan_bindings, value_bindings),
OptdValue::Bool(false) => otherwise.interpret(plan_bindings, value_bindings),
_ => panic!("IfThenElse condition must be boolean"),
},
}
}
}
197 changes: 197 additions & 0 deletions optd-core/src/engine/interpreter/expressions/values.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
//! Interpreter for OPTD-DSL expressions.
//! Evaluates OptdExpr with given bindings to produce OptdValue values.
use std::collections::HashMap;

use crate::values::{OptdExpr, OptdValue};

/// Evaluates an OptdExpr to an OptdValue value using provided bindings.
impl OptdExpr {
pub fn interpret(&self, bindings: &HashMap<String, OptdValue>) -> OptdValue {
match self {
OptdExpr::Value(val) => val.clone(),

OptdExpr::Ref(name) => bindings.get(name).cloned().unwrap_or_else(|| {
panic!("Undefined reference: {}", name);
}),

OptdExpr::IfThenElse {
cond,
then,
otherwise,
} => match cond.interpret(bindings) {
OptdValue::Bool(true) => then.interpret(bindings),
OptdValue::Bool(false) => otherwise.interpret(bindings),
_ => panic!("IfThenElse condition must be boolean"),
},

OptdExpr::Eq { left, right } => {
OptdValue::Bool(left.interpret(bindings) == right.interpret(bindings))
}

OptdExpr::Lt { left, right } => {
match (left.interpret(bindings), right.interpret(bindings)) {
(OptdValue::Int64(l), OptdValue::Int64(r)) => OptdValue::Bool(l < r),
_ => panic!("Lt requires integer operands"),
}
}

OptdExpr::Gt { left, right } => {
match (left.interpret(bindings), right.interpret(bindings)) {
(OptdValue::Int64(l), OptdValue::Int64(r)) => OptdValue::Bool(l > r),
_ => panic!("Gt requires integer operands"),
}
}

OptdExpr::Add { left, right } => {
// TODO(alexis): overflow checks
match (left.interpret(bindings), right.interpret(bindings)) {
(OptdValue::Int64(l), OptdValue::Int64(r)) => OptdValue::Int64(l + r),
_ => panic!("Add requires integer operands"),
}
}

OptdExpr::Sub { left, right } => {
// TODO(alexis): underflow checks
match (left.interpret(bindings), right.interpret(bindings)) {
(OptdValue::Int64(l), OptdValue::Int64(r)) => OptdValue::Int64(l - r),
_ => panic!("Sub requires integer operands"),
}
}

OptdExpr::Mul { left, right } => {
// TODO(alexis): overflow checks
match (left.interpret(bindings), right.interpret(bindings)) {
(OptdValue::Int64(l), OptdValue::Int64(r)) => OptdValue::Int64(l * r),
_ => panic!("Mul requires integer operands"),
}
}

OptdExpr::Div { left, right } => {
// TODO(alexis): div by 0 checks
match (left.interpret(bindings), right.interpret(bindings)) {
(OptdValue::Int64(l), OptdValue::Int64(r)) => {
if r == 0 {
panic!("Division by zero");
}
OptdValue::Int64(l / r)
}
_ => panic!("Div requires integer operands"),
}
}

OptdExpr::And { left, right } => {
match (left.interpret(bindings), right.interpret(bindings)) {
(OptdValue::Bool(l), OptdValue::Bool(r)) => OptdValue::Bool(l && r),
_ => panic!("And requires boolean operands"),
}
}

OptdExpr::Or { left, right } => {
match (left.interpret(bindings), right.interpret(bindings)) {
(OptdValue::Bool(l), OptdValue::Bool(r)) => OptdValue::Bool(l || r),
_ => panic!("Or requires boolean operands"),
}
}

OptdExpr::Not(expr) => match expr.interpret(bindings) {
OptdValue::Bool(b) => OptdValue::Bool(!b),
_ => panic!("Not requires boolean operand"),
},
}
}
}

#[cfg(test)]
mod tests {
use super::*;

// Helper to create test bindings
fn test_bindings() -> HashMap<String, OptdValue> {
let mut map = HashMap::new();
map.insert("x".to_string(), OptdValue::Int64(5));
map.insert("y".to_string(), OptdValue::Int64(3));
map.insert("flag".to_string(), OptdValue::Bool(true));
map
}

#[test]
fn test_basic_values() {
let bindings = test_bindings();
assert_eq!(
OptdExpr::Value(OptdValue::Int64(42)).interpret(&bindings),
OptdValue::Int64(42)
);
}

#[test]
fn test_references() {
let bindings = test_bindings();
assert_eq!(
OptdExpr::Ref("x".to_string()).interpret(&bindings),
OptdValue::Int64(5)
);
}

#[test]
fn test_arithmetic() {
let bindings = test_bindings();

// Addition
assert_eq!(
OptdExpr::Add {
left: Box::new(OptdExpr::Ref("x".to_string())),
right: Box::new(OptdExpr::Ref("y".to_string()))
}
.interpret(&bindings),
OptdValue::Int64(8)
);

// Multiplication
assert_eq!(
OptdExpr::Mul {
left: Box::new(OptdExpr::Ref("x".to_string())),
right: Box::new(OptdExpr::Ref("y".to_string()))
}
.interpret(&bindings),
OptdValue::Int64(15)
);
}

#[test]
fn test_boolean_logic() {
let bindings = test_bindings();

// AND
assert_eq!(
OptdExpr::And {
left: Box::new(OptdExpr::Ref("flag".to_string())),
right: Box::new(OptdExpr::Value(OptdValue::Bool(true)))
}
.interpret(&bindings),
OptdValue::Bool(true)
);

// NOT
assert_eq!(
OptdExpr::Not(Box::new(OptdExpr::Ref("flag".to_string()))).interpret(&bindings),
OptdValue::Bool(false)
);
}

#[test]
fn test_conditionals() {
let bindings = test_bindings();

let expr = OptdExpr::IfThenElse {
cond: Box::new(OptdExpr::Gt {
left: Box::new(OptdExpr::Ref("x".to_string())),
right: Box::new(OptdExpr::Ref("y".to_string())),
}),
then: Box::new(OptdExpr::Value(OptdValue::Int64(1))),
otherwise: Box::new(OptdExpr::Value(OptdValue::Int64(0))),
};

assert_eq!(expr.interpret(&bindings), OptdValue::Int64(1));
}
}
1 change: 1 addition & 0 deletions optd-core/src/engine/interpreter/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod expressions;
Loading

0 comments on commit f24c089

Please sign in to comment.