Skip to content

Add derive based AST visitor #765

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Dec 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# will have compiled files and executables
/target/
/sqlparser_bench/target/
/derive/target/
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also note to self

I tested using

cargo test --all --features=visitor

We need to add this feature to the CI tests

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, that explains the poor coverage I guess

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here http://doc.crates.io/guide.html#cargotoml-vs-cargolock
Expand Down
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ version = "0.28.0"
authors = ["Andy Grove <[email protected]>"]
homepage = "https://github.com/sqlparser-rs/sqlparser-rs"
documentation = "https://docs.rs/sqlparser/"
keywords = [ "ansi", "sql", "lexer", "parser" ]
keywords = ["ansi", "sql", "lexer", "parser"]
repository = "https://github.com/sqlparser-rs/sqlparser-rs"
license = "Apache-2.0"
include = [
Expand All @@ -23,6 +23,7 @@ default = ["std"]
std = []
# Enable JSON output in the `cli` example:
json_example = ["serde_json", "serde"]
visitor = ["sqlparser_derive"]

[dependencies]
bigdecimal = { version = "0.3", features = ["serde"], optional = true }
Expand All @@ -32,6 +33,7 @@ serde = { version = "1.0", features = ["derive"], optional = true }
# of dev-dependencies because of
# https://github.com/rust-lang/cargo/issues/1596
serde_json = { version = "1.0", optional = true }
sqlparser_derive = { version = "0.1", path = "derive", optional = true }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to myself: if we merge this PR I think we should document this feature flag -- it seems documentation is missing for the existing feature flags

Copy link
Contributor

@alamb alamb Dec 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docs in #774 #776


[dev-dependencies]
simple_logger = "4.0"
Expand Down
23 changes: 23 additions & 0 deletions derive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
[package]
name = "sqlparser_derive"
description = "proc macro for sqlparser"
version = "0.1.0"
authors = ["Andy Grove <[email protected]>"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤔 maybe this should be something different. I will think about a more appropriate author -- maybe sqlparser-rs contributors 🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TBH I'm not sure as well. Theoretically I'd say that the owner/creator should be the author, but since this is an expansion, not sure whether put both of you or just the @tustvold

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is ok as is

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed it to "sqlparser authors" in #775

homepage = "https://github.com/sqlparser-rs/sqlparser-rs"
documentation = "https://docs.rs/sqlparser/"
keywords = ["ansi", "sql", "lexer", "parser"]
repository = "https://github.com/sqlparser-rs/sqlparser-rs"
license = "Apache-2.0"
include = [
"src/**/*.rs",
"Cargo.toml",
]
edition = "2021"

[lib]
proc-macro = true

[dependencies]
syn = "1.0"
proc-macro2 = "1.0"
quote = "1.0"
79 changes: 79 additions & 0 deletions derive/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# SQL Parser Derive Macro

## Visit

This crate contains a procedural macro that can automatically derive implementations of the `Visit` trait
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self -- this should have a doc link back to the sqlparser crate

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in #779


```rust
#[derive(Visit)]
struct Foo {
boolean: bool,
bar: Bar,
}

#[derive(Visit)]
enum Bar {
A(),
B(String, bool),
C { named: i32 },
}
```

Will generate code akin to

```rust
impl Visit for Foo {
fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
self.boolean.visit(visitor)?;
self.bar.visit(visitor)?;
ControlFlow::Continue(())
}
}

impl Visit for Bar {
fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
match self {
Self::A() => {}
Self::B(_1, _2) => {
_1.visit(visitor)?;
_2.visit(visitor)?;
}
Self::C { named } => {
named.visit(visitor)?;
}
}
ControlFlow::Continue(())
}
}
```

Additionally certain types may wish to call a corresponding method on visitor before recursing

```rust
#[derive(Visit)]
#[visit(with = "visit_expr")]
enum Expr {
A(),
B(String, #[cfg_attr(feature = "visitor", visit(with = "visit_relation"))] ObjectName, bool),
}
```

Will generate

```rust
impl Visit for Bar {
fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
visitor.visit_expr(self)?;
match self {
Self::A() => {}
Self::B(_1, _2, _3) => {
_1.visit(visitor)?;
visitor.visit_relation(_3)?;
_2.visit(visitor)?;
_3.visit(visitor)?;
}
}
ControlFlow::Continue(())
}
}
```
184 changes: 184 additions & 0 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use syn::spanned::Spanned;
use syn::{
parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics,
Ident, Index, Lit, Meta, MetaNameValue, NestedMeta,
};

/// Implementation of `[#derive(Visit)]`
#[proc_macro_derive(Visit, attributes(visit))]
pub fn derive_visit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
// Parse the input tokens into a syntax tree.
let input = parse_macro_input!(input as DeriveInput);
let name = input.ident;

let attributes = Attributes::parse(&input.attrs);
// Add a bound `T: HeapSize` to every type parameter T.
let generics = add_trait_bounds(input.generics);
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let (pre_visit, post_visit) = attributes.visit(quote!(self));
let children = visit_children(&input.data);

let expanded = quote! {
// The generated impl.
impl #impl_generics sqlparser::ast::Visit for #name #ty_generics #where_clause {
fn visit<V: sqlparser::ast::Visitor>(&self, visitor: &mut V) -> ::std::ops::ControlFlow<V::Break> {
#pre_visit
#children
#post_visit
::std::ops::ControlFlow::Continue(())
}
}
};

proc_macro::TokenStream::from(expanded)
}

/// Parses attributes that can be provided to this macro
///
/// `#[visit(leaf, with = "visit_expr")]`
#[derive(Default)]
struct Attributes {
/// Content for the `with` attribute
with: Option<Ident>,
}

impl Attributes {
fn parse(attrs: &[Attribute]) -> Self {
let mut out = Self::default();
for attr in attrs.iter().filter(|a| a.path.is_ident("visit")) {
let meta = attr.parse_meta().expect("visit attribute");
match meta {
Meta::List(l) => {
for nested in &l.nested {
match nested {
NestedMeta::Meta(Meta::NameValue(v)) => out.parse_name_value(v),
_ => panic!("Expected #[visit(key = \"value\")]"),
}
}
}
_ => panic!("Expected #[visit(...)]"),
}
}
out
}

/// Updates self with a name value attribute
fn parse_name_value(&mut self, v: &MetaNameValue) {
if v.path.is_ident("with") {
match &v.lit {
Lit::Str(s) => self.with = Some(format_ident!("{}", s.value(), span = s.span())),
_ => panic!("Expected a string value, got {}", v.lit.to_token_stream()),
}
return;
}
panic!("Unrecognised kv attribute {}", v.path.to_token_stream())
}

/// Returns the pre and post visit token streams
fn visit(&self, s: TokenStream) -> (Option<TokenStream>, Option<TokenStream>) {
let pre_visit = self.with.as_ref().map(|m| {
let m = format_ident!("pre_{}", m);
quote!(visitor.#m(#s)?;)
});
let post_visit = self.with.as_ref().map(|m| {
let m = format_ident!("post_{}", m);
quote!(visitor.#m(#s)?;)
});
(pre_visit, post_visit)
}
}

// Add a bound `T: Visit` to every type parameter T.
fn add_trait_bounds(mut generics: Generics) -> Generics {
for param in &mut generics.params {
if let GenericParam::Type(ref mut type_param) = *param {
type_param.bounds.push(parse_quote!(sqlparser::ast::Visit));
}
}
generics
}

// Generate the body of the visit implementation for the given type
fn visit_children(data: &Data) -> TokenStream {
match data {
Data::Struct(data) => match &data.fields {
Fields::Named(fields) => {
let recurse = fields.named.iter().map(|f| {
let name = &f.ident;
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(quote!(&self.#name));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(&self.#name, visitor)?; #post_visit)
});
quote! {
#(#recurse)*
}
}
Fields::Unnamed(fields) => {
let recurse = fields.unnamed.iter().enumerate().map(|(i, f)| {
let index = Index::from(i);
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(quote!(&self.#index));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(&self.#index, visitor)?; #post_visit)
});
quote! {
#(#recurse)*
}
}
Fields::Unit => {
quote!()
}
},
Data::Enum(data) => {
let statements = data.variants.iter().map(|v| {
let name = &v.ident;
match &v.fields {
Fields::Named(fields) => {
let names = fields.named.iter().map(|f| &f.ident);
let visit = fields.named.iter().map(|f| {
let name = &f.ident;
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(quote!(&#name));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(#name, visitor)?; #post_visit)
});

quote!(
Self::#name { #(#names),* } => {
#(#visit)*
}
)
}
Fields::Unnamed(fields) => {
let names = fields.unnamed.iter().enumerate().map(|(i, f)| format_ident!("_{}", i, span = f.span()));
let visit = fields.unnamed.iter().enumerate().map(|(i, f)| {
let name = format_ident!("_{}", i);
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(quote!(&#name));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(#name, visitor)?; #post_visit)
});

quote! {
Self::#name ( #(#names),*) => {
#(#visit)*
}
}
}
Fields::Unit => {
quote! {
Self::#name => {}
}
}
}
});

quote! {
match self {
#(#statements),*
}
}
}
Data::Union(_) => unimplemented!(),
}
}
8 changes: 8 additions & 0 deletions src/ast/data_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@ use core::fmt;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

#[cfg(feature = "visitor")]
use sqlparser_derive::Visit;

use crate::ast::ObjectName;

use super::value::escape_single_quote_string;

/// SQL data types
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
pub enum DataType {
/// Fixed-length character type e.g. CHARACTER(10)
Character(Option<CharacterLength>),
Expand Down Expand Up @@ -337,6 +341,7 @@ fn format_datetime_precision_and_tz(
/// guarantee compatibility with the input query we must maintain its exact information.
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a big fan of this method as a way to minimize maintenance burden on new contributions

I picked a random struct and removed the #[cfg_attr(feature = "visitor", derive(Visit))] to see what would happen if it was forgotten

I think the error is fairly clear 👍

   Compiling sqlparser v0.28.0 (/Users/alamb/Software/sqlparser-rs)
error[E0277]: the trait bound `ast::Password: visitor::Visit` is not satisfied
    --> src/ast/mod.rs:1262:9
     |
1262 |         password: Option<Password>,
     |         ^^^^^^^^ the trait `visitor::Visit` is not implemented for `ast::Password`
     |
     = help: the following other types implement trait `visitor::Visit`:
               Box<T>
               CommentObject
               CreateTableBuilder
               Keyword
               Option<T>
               String
               Vec<T>
               ast::Action
             and 120 others
note: required for `Option<ast::Password>` to implement `visitor::Visit`
    --> src/ast/visitor.rs:21:16
     |
21   | impl<T: Visit> Visit for Option<T> {
     |                ^^^^^     ^^^^^^^^^

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's also making it easier to keep up with the merge conflicts 😆

pub enum TimezoneInfo {
/// No information about time zone. E.g., TIMESTAMP
None,
Expand Down Expand Up @@ -384,6 +389,7 @@ impl fmt::Display for TimezoneInfo {
/// [standard]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#exact-numeric-type
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
pub enum ExactNumberInfo {
/// No additional information e.g. `DECIMAL`
None,
Expand Down Expand Up @@ -414,6 +420,7 @@ impl fmt::Display for ExactNumberInfo {
/// [1]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#character-length
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
pub struct CharacterLength {
/// Default (if VARYING) or maximum (if not VARYING) length
pub length: u64,
Expand All @@ -436,6 +443,7 @@ impl fmt::Display for CharacterLength {
/// [1]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#char-length-units
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
pub enum CharLengthUnits {
/// CHARACTERS unit
Characters,
Expand Down
Loading