Skip to content

Improve docs and add examples for Visitor #778

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 29, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 159 additions & 5 deletions src/ast/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.

//! Recursive visitors for ast Nodes. See [`Visitor`] for more details.

use crate::ast::{Expr, ObjectName, Statement};
use core::ops::ControlFlow;

/// A type that can be visited by a `visitor`
/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
/// recursively visiting parsed SQL statements.
///
/// # Note
///
/// This trait should be automatically derived for sqlparser AST nodes
/// using the [Visit](sqlparser_derive::Visit) proc macro.
///
/// ```text
/// #[cfg_attr(feature = "visitor", derive(Visit))]
/// ```
pub trait Visit {
fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break>;
}
Expand Down Expand Up @@ -57,8 +69,70 @@ visit_noop!(u8, u16, u32, u64, i8, i16, i32, i64, char, bool, String);
#[cfg(feature = "bigdecimal")]
visit_noop!(bigdecimal::BigDecimal);

/// A visitor that can be used to walk an AST tree
/// A visitor that can be used to walk an AST tree.
///
/// `previst_` methods are invoked before visiting all children of the
/// node and `postvisit_` methods are invoked after visiting all
/// children of the node.
///
/// # See also
///
/// These methods provide a more concise way of visiting nodes of a certain type:
/// * [visit_relations]
/// * [visit_expressions]
/// * [visit_statements]
///
/// # Example
/// ```
/// # use sqlparser::parser::Parser;
/// # use sqlparser::dialect::GenericDialect;
/// # use sqlparser::ast::{Visit, Visitor, ObjectName, Expr};
/// # use core::ops::ControlFlow;
/// // A structure that records statements and relations
/// #[derive(Default)]
/// struct V {
/// visited: Vec<String>,
/// }
///
/// // Visit relations and exprs before children are visited (depth first walk)
/// // Note you can also visit statements and visit exprs after children have been visitoed
/// impl Visitor for V {
/// type Break = ();
///
/// fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
/// self.visited.push(format!("PRE: RELATION: {}", relation));
/// ControlFlow::Continue(())
/// }
///
/// fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
/// self.visited.push(format!("PRE: EXPR: {}", expr));
/// ControlFlow::Continue(())
/// }
/// }
///
/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
/// .unwrap();
///
/// // Drive the visitor through the AST
/// let mut visitor = V::default();
/// statements.visit(&mut visitor);
///
/// // The visitor has visited statements and expressions in pre-traversal order
/// let expected : Vec<_> = [
/// "PRE: EXPR: a",
/// "PRE: RELATION: foo",
/// "PRE: EXPR: x IN (SELECT y FROM bar)",
/// "PRE: EXPR: x",
/// "PRE: EXPR: y",
/// "PRE: RELATION: bar",
/// ]
/// .into_iter().map(|s| s.to_string()).collect();
///
/// assert_eq!(visitor.visited, expected);
/// ```
pub trait Visitor {
/// Type returned when the recursion returns early.
type Break;

/// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
Expand Down Expand Up @@ -102,7 +176,33 @@ impl<E, F: FnMut(&ObjectName) -> ControlFlow<E>> Visitor for RelationVisitor<F>
}
}

/// Invokes the provided closure on all relations present in v
/// Invokes the provided closure on all relations (e.g. table names) present in `v`
///
/// # Example
/// ```
/// # use sqlparser::parser::Parser;
/// # use sqlparser::dialect::GenericDialect;
/// # use sqlparser::ast::{visit_relations};
/// # use core::ops::ControlFlow;
/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
/// .unwrap();
///
/// // visit statements, capturing relations (table names)
/// let mut visited = vec![];
/// visit_relations(&statements, |relation| {
/// visited.push(format!("RELATION: {}", relation));
/// ControlFlow::<()>::Continue(())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised you need to type hint this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, me too -- I think it is because the return value is not used

/// });
///
/// let expected : Vec<_> = [
/// "RELATION: foo",
/// "RELATION: bar",
/// ]
/// .into_iter().map(|s| s.to_string()).collect();
///
/// assert_eq!(visited, expected);
/// ```
pub fn visit_relations<V, E, F>(v: &V, f: F) -> ControlFlow<E>
where
V: Visit,
Expand All @@ -123,7 +223,35 @@ impl<E, F: FnMut(&Expr) -> ControlFlow<E>> Visitor for ExprVisitor<F> {
}
}

/// Invokes the provided closure on all expressions present in v
/// Invokes the provided closure on all expressions (e.g. `1 + 2`) present in `v`
///
/// # Example
/// ```
/// # use sqlparser::parser::Parser;
/// # use sqlparser::dialect::GenericDialect;
/// # use sqlparser::ast::{visit_expressions};
/// # use core::ops::ControlFlow;
/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
/// .unwrap();
///
/// // visit all expressions
/// let mut visited = vec![];
/// visit_expressions(&statements, |expr| {
/// visited.push(format!("EXPR: {}", expr));
/// ControlFlow::<()>::Continue(())
/// });
///
/// let expected : Vec<_> = [
/// "EXPR: a",
/// "EXPR: x IN (SELECT y FROM bar)",
/// "EXPR: x",
/// "EXPR: y",
/// ]
/// .into_iter().map(|s| s.to_string()).collect();
///
/// assert_eq!(visited, expected);
/// ```
pub fn visit_expressions<V, E, F>(v: &V, f: F) -> ControlFlow<E>
where
V: Visit,
Expand All @@ -144,7 +272,33 @@ impl<E, F: FnMut(&Statement) -> ControlFlow<E>> Visitor for StatementVisitor<F>
}
}

/// Invokes the provided closure on all statements present in v
/// Invokes the provided closure on all statements (e.g. `SELECT`, `CREATE TABLE`, etc) present in `v`
///
/// # Example
/// ```
/// # use sqlparser::parser::Parser;
/// # use sqlparser::dialect::GenericDialect;
/// # use sqlparser::ast::{visit_statements};
/// # use core::ops::ControlFlow;
/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar); CREATE TABLE baz(q int)";
/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
/// .unwrap();
///
/// // visit all statements
/// let mut visited = vec![];
/// visit_statements(&statements, |stmt| {
/// visited.push(format!("STATEMENT: {}", stmt));
/// ControlFlow::<()>::Continue(())
/// });
///
/// let expected : Vec<_> = [
/// "STATEMENT: SELECT a FROM foo WHERE x IN (SELECT y FROM bar)",
/// "STATEMENT: CREATE TABLE baz (q INT)"
/// ]
/// .into_iter().map(|s| s.to_string()).collect();
///
/// assert_eq!(visited, expected);
/// ```
pub fn visit_statements<V, E, F>(v: &V, f: F) -> ControlFlow<E>
where
V: Visit,
Expand Down