diff --git a/Cargo.lock b/Cargo.lock index c7d57fc..96b3926 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -810,6 +810,18 @@ dependencies = [ "trait-variant", ] +[[package]] +name = "optd-dsl" +version = "0.1.0" +dependencies = [ + "pest", + "pest_derive", + "prettyplease", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "parking" version = "2.2.1" @@ -947,6 +959,16 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "prettyplease" +version = "0.2.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6924ced06e1f7dfe3fa48d57b9f74f55d8915f5036121bef647ef4b204895fac" +dependencies = [ + "proc-macro2", + "syn", +] + [[package]] name = "proc-macro2" version = "1.0.93" diff --git a/optd-dsl/src/gen/operator.rs b/optd-dsl/src/gen/operator.rs index 238ffdf..772fff4 100644 --- a/optd-dsl/src/gen/operator.rs +++ b/optd-dsl/src/gen/operator.rs @@ -1,4 +1,8 @@ -use crate::ast::{Field, LogicalOp, Operator, OperatorKind, ScalarOp, Type}; +use crate::{ + analyzer::semantic::SemanticAnalyzer, + ast::{Field, LogicalOp, Operator, OperatorKind, ScalarOp, Type}, + parser::parse_file, +}; use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote}; use syn::parse_quote; @@ -63,8 +67,99 @@ impl FieldInfo { } } -fn generate_code(operators: &[Operator]) -> proc_macro2::TokenStream { - let mut generated_code = proc_macro2::TokenStream::new(); +fn generate_logical_enum(operators: &[LogicalOp]) -> TokenStream { + let variants: Vec<_> = operators + .iter() + .map(|op| { + let name = format_ident!("{}", &op.name); + quote! { + #name(#name) + } + }) + .collect(); + + let variant_names: Vec<_> = operators + .iter() + .map(|op| format_ident!("{}", &op.name)) + .collect(); + + quote! { + #[derive(Debug, Clone, PartialEq, Deserialize)] + pub enum LogicalOperator { + #(#variants,)* + } + + #[derive(Debug, Clone, PartialEq, sqlx::Type)] + pub enum LogicalOperatorKind { + #(#variant_names,)* + } + + impl LogicalOperator + where + Relation: Clone, + Scalar: Clone, + { + pub fn operator_kind(&self) -> LogicalOperatorKind { + match self { + #(LogicalOperator::#variant_names => LogicalOperatorKind::#variant_names,)* + } + } + } + } +} + +fn generate_scalar_enum(operators: &[ScalarOp]) -> TokenStream { + let variants: Vec<_> = operators + .iter() + .map(|op| { + let name = format_ident!("{}", &op.name); + quote! { + #name(#name) + } + }) + .collect(); + + let variant_names: Vec<_> = operators + .iter() + .map(|op| format_ident!("{}", &op.name)) + .collect(); + + quote! { + #[derive(Debug, Clone, PartialEq, Deserialize)] + pub enum ScalarOperator { + #(#variants,)* + } + + #[derive(Debug, Clone, PartialEq, sqlx::Type)] + pub enum ScalarOperatorKind { + #(#variant_names,)* + } + + impl ScalarOperator + where + Scalar: Clone, + { + pub fn operator_kind(&self) -> ScalarOperatorKind { + match self { + #(ScalarOperator::#variant_names => ScalarOperatorKind::#variant_names,)* + } + } + } + } +} + +fn generate_code(operators: &[Operator]) -> TokenStream { + let mut logical_ops = Vec::new(); + let mut scalar_ops = Vec::new(); + + for operator in operators.into_iter().cloned() { + match operator { + Operator::Logical(op) => logical_ops.push(op), + Operator::Scalar(op) => scalar_ops.push(op), + } + } + + let mut generated_code = TokenStream::new(); // Generate enums first let logical_enum = generate_logical_enum(&logical_ops); @@ -118,8 +213,39 @@ fn generate_logical_operator(operator: &LogicalOp) -> TokenStream { } } -fn generate_scalar_operator(_operator: &ScalarOp) -> proc_macro2::TokenStream { - unimplemented!() +fn generate_scalar_operator(operator: &ScalarOp) -> TokenStream { + let name = format_ident!("{}", &operator.name); + let fields: Vec = operator.fields.iter().map(FieldInfo::new).collect(); + let struct_fields: Vec<_> = fields.iter().map(|f| f.struct_field()).collect(); + let ctor_params: Vec<_> = fields.iter().map(|f| f.ctor_param()).collect(); + let ctor_inits: Vec<_> = fields.iter().map(|f| f.ctor_init()).collect(); + let field_names: Vec<_> = fields.iter().map(|f| &f.name).collect(); + let fn_name = format_ident!("{}", operator.name.to_lowercase()); + + quote! { + use super::ScalarOperator; + use crate::values::OptdValue; + use serde::Deserialize; + + #[derive(Debug, Clone, PartialEq, Deserialize)] + pub struct #name { + #(#struct_fields,)* + } + + impl #name { + pub fn new(#(#ctor_params,)*) -> Self { + Self { + #(#ctor_inits,)* + } + } + } + + pub fn #fn_name( + #(#ctor_params,)* + ) -> ScalarOperator { + ScalarOperator::#name(#name::new(#(#field_names,)*)) + } + } } #[test] @@ -148,10 +274,24 @@ fn test_generate_logical_operator() { #generated }; let formatted = prettyplease::unparse(&syntax_tree); - println!("Generated code:\n{}", formatted); // Basic validation let code = formatted.to_string(); assert!(code.contains("pub child: Relation")); assert!(code.contains("pub predicate: Scalar")); } + +#[test] +fn test_working_file() { + let input = include_str!("../programs/working.optd"); + let out = parse_file(input).unwrap(); + let mut analyzer = SemanticAnalyzer::new(); + analyzer.validate_file(&out).unwrap(); + + let generated_code = generate_code(&out.operators); + let syntax_tree: syn::File = parse_quote! { + #generated_code + }; + let formatted = prettyplease::unparse(&syntax_tree); + println!("Generated code:\n{}", formatted); +}