Skip to content

Commit

Permalink
Fix fmt again
Browse files Browse the repository at this point in the history
  • Loading branch information
AlSchlo committed Feb 11, 2025
1 parent bf92829 commit e5ed19b
Showing 1 changed file with 151 additions and 0 deletions.
151 changes: 151 additions & 0 deletions optd-dsl/src/gen/operator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
use crate::ast::{Field, LogicalOp, Operator, OperatorKind, ScalarOp, Type};
use proc_macro2::{Ident, TokenStream};
use quote::{format_ident, quote};
use syn::parse_quote;

/// Converts an AST type to its corresponding Rust type representation
fn type_to_tokens(ty: &Type, is_param: bool) -> TokenStream {
match ty {
Type::String => {
if is_param {
quote! { &str }
} else {
quote! { String }
}
}
Type::Bool => quote! { bool },
Type::Int64 => quote! { i64 },
Type::Float64 => quote! { f64 },
Type::Operator(OperatorKind::Logical) => quote! { Relation },
Type::Operator(OperatorKind::Scalar) => quote! { Scalar },
Type::Array(inner) => {
let inner_type = type_to_tokens(inner, false);
quote! { Vec<#inner_type> }
}
_ => panic!("Unexpected type: {:?}", ty),
}
}

/// Helper struct to hold field information for code generation
struct FieldInfo {
name: Ident,
ty: Type,
}

impl FieldInfo {
fn new(field: &Field) -> Self {
Self {
name: format_ident!("{}", field.name),
ty: field.ty.clone(),
}
}

fn struct_field(&self) -> TokenStream {
let name = &self.name;
let ty = type_to_tokens(&self.ty, false);
quote! {
pub #name: #ty
}
}

fn ctor_param(&self) -> TokenStream {
let name = &self.name;
let ty = type_to_tokens(&self.ty, true);
quote! { #name: #ty }
}

fn ctor_init(&self) -> TokenStream {
let name = &self.name;
match &self.ty {
Type::String => quote! { #name: #name.into() },
_ => quote! { #name },
}
}
}

fn generate_code(operators: &[Operator]) -> proc_macro2::TokenStream {
let mut generated_code = proc_macro2::TokenStream::new();

for operator in operators {
let operator_code = match operator {
Operator::Logical(op) => generate_logical_operator(op),
Operator::Scalar(op) => generate_scalar_operator(op),
};
generated_code.extend(operator_code);
}

generated_code
}

fn generate_logical_operator(operator: &LogicalOp) -> TokenStream {
let name = format_ident!("{}", &operator.name);
let fields: Vec<FieldInfo> = operator.fields.iter().map(FieldInfo::new).collect();
let struct_fields: Vec<_> = fields.iter().map(|f| f.struct_field()).collect();
let ctor_params: Vec<_> = fields.iter().map(|f| f.ctor_param()).collect();
let ctor_inits: Vec<_> = fields.iter().map(|f| f.ctor_init()).collect();
let field_names: Vec<_> = fields.iter().map(|f| &f.name).collect();
let fn_name = format_ident!("{}", operator.name.to_lowercase());

quote! {
use super::LogicalOperator;
use crate::values::OptdValue;
use serde::Deserialize;

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct #name<Value, Relation, Scalar> {
#(#struct_fields,)*
}

impl<Relation, Scalar> #name<OptdValue, Relation, Scalar> {
pub fn new(#(#ctor_params,)*) -> Self {
Self {
#(#ctor_inits,)*
}
}
}

pub fn #fn_name<Relation, Scalar>(
#(#ctor_params,)*
) -> LogicalOperator<OptdValue, Relation, Scalar> {
LogicalOperator::#name(#name::new(#(#field_names,)*))
}
}
}

fn generate_scalar_operator(_operator: &ScalarOp) -> proc_macro2::TokenStream {
unimplemented!()
}

#[test]
fn test_generate_logical_operator() {
use crate::ast::{Field, LogicalOp, OperatorKind, Type};
use std::collections::HashMap;

// Test with both Logical and Scalar operator types
let filter_op = LogicalOp {
name: "Filter".to_string(),
fields: vec![
Field {
name: "child".to_string(),
ty: Type::Operator(OperatorKind::Logical),
},
Field {
name: "predicate".to_string(),
ty: Type::Operator(OperatorKind::Scalar),
},
],
derived_props: HashMap::new(),
};

let generated = generate_logical_operator(&filter_op);
let syntax_tree: syn::File = parse_quote! {
#generated
};
let formatted = prettyplease::unparse(&syntax_tree);
println!("Generated code:\n{}", formatted);

// Basic validation
let code = formatted.to_string();
assert!(code.contains("pub child: Relation"));
assert!(code.contains("pub predicate: Scalar"));
}

0 comments on commit e5ed19b

Please sign in to comment.