diff --git a/optd-dsl/src/analyzer/semantic.rs b/optd-dsl/src/analyzer/semantic.rs index 91d0fcf..e41df2a 100644 --- a/optd-dsl/src/analyzer/semantic.rs +++ b/optd-dsl/src/analyzer/semantic.rs @@ -441,7 +441,7 @@ mod tests { } #[test] - fn parse_working_file() { + fn test_working_file() { let input = include_str!("../programs/working.optd"); let out = parse_file(input).unwrap(); let mut analyzer = SemanticAnalyzer::new(); diff --git a/optd-dsl/src/gen/operator.rs b/optd-dsl/src/gen/operator.rs index 7789475..e3ffb16 100644 --- a/optd-dsl/src/gen/operator.rs +++ b/optd-dsl/src/gen/operator.rs @@ -1,4 +1,8 @@ -use crate::ast::{Field, LogicalOp, Operator, OperatorKind, ScalarOp, Type}; +use crate::{ + analyzer::semantic::SemanticAnalyzer, + ast::{Field, LogicalOp, Operator, OperatorKind, ScalarOp, Type}, + parser::parse_file, +}; use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote}; use syn::parse_quote; @@ -63,8 +67,8 @@ impl FieldInfo { } } -fn generate_code(operators: &[Operator]) -> proc_macro2::TokenStream { - let mut generated_code = proc_macro2::TokenStream::new(); +fn generate_code(operators: &[Operator]) -> TokenStream { + let mut generated_code = TokenStream::new(); for operator in operators { let operator_code = match operator { @@ -112,8 +116,39 @@ fn generate_logical_operator(operator: &LogicalOp) -> TokenStream { } } -fn generate_scalar_operator(_operator: &ScalarOp) -> proc_macro2::TokenStream { - unimplemented!() +fn generate_scalar_operator(operator: &ScalarOp) -> TokenStream { + let name = format_ident!("{}", &operator.name); + let fields: Vec = operator.fields.iter().map(FieldInfo::new).collect(); + let struct_fields: Vec<_> = fields.iter().map(|f| f.struct_field()).collect(); + let ctor_params: Vec<_> = fields.iter().map(|f| f.ctor_param()).collect(); + let ctor_inits: Vec<_> = fields.iter().map(|f| f.ctor_init()).collect(); + let field_names: Vec<_> = fields.iter().map(|f| &f.name).collect(); + let fn_name = format_ident!("{}", operator.name.to_lowercase()); + + quote! { + use super::ScalarOperator; + use crate::values::OptdValue; + use serde::Deserialize; + + #[derive(Debug, Clone, PartialEq, Deserialize)] + pub struct #name { + #(#struct_fields,)* + } + + impl #name { + pub fn new(#(#ctor_params,)*) -> Self { + Self { + #(#ctor_inits,)* + } + } + } + + pub fn #fn_name( + #(#ctor_params,)* + ) -> ScalarOperator { + ScalarOperator::#name(#name::new(#(#field_names,)*)) + } + } } #[test] @@ -142,10 +177,24 @@ fn test_generate_logical_operator() { #generated }; let formatted = prettyplease::unparse(&syntax_tree); - println!("Generated code:\n{}", formatted); // Basic validation let code = formatted.to_string(); assert!(code.contains("pub child: Relation")); assert!(code.contains("pub predicate: Scalar")); } + +#[test] +fn test_working_file() { + let input = include_str!("../programs/working.optd"); + let out = parse_file(input).unwrap(); + let mut analyzer = SemanticAnalyzer::new(); + analyzer.validate_file(&out).unwrap(); + + let generated_code = generate_code(&out.operators); + let syntax_tree: syn::File = parse_quote! { + #generated_code + }; + let formatted = prettyplease::unparse(&syntax_tree); + println!("Generated code:\n{}", formatted); +} diff --git a/optd-dsl/src/programs/working.optd b/optd-dsl/src/programs/working.optd index 8330473..2a92256 100644 --- a/optd-dsl/src/programs/working.optd +++ b/optd-dsl/src/programs/working.optd @@ -32,16 +32,6 @@ Logical Join( schema_len = left.schema_len + right.schema_len } -// memo(alexis): tuples should be allowed in here -Logical Sort(child: Logical, keys: [(Scalar, String)]) derive { - schema_len = input.schema_len -} - -Logical Aggregate(child: Logical, group_keys: [Scalar], aggs: [(Scalar, String)]) - derive { - schema_len = group_keys.len + aggs.len -} - // Rules // memo(alexis): support member function for operators like apply_children