Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
273 changes: 155 additions & 118 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ impl fmt::Display for Array {
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct Interval {
/// The interval value expression (commonly a string literal).
pub value: Box<Expr>,
pub value: Expr,
/// Optional leading time unit (e.g., `HOUR`, `MINUTE`).
pub leading_field: Option<DateTimeField>,
/// Optional leading precision for the leading field.
Expand All @@ -475,7 +475,7 @@ pub struct Interval {

impl fmt::Display for Interval {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let value = self.value.as_ref();
let value = &self.value;
match (
&self.leading_field,
self.leading_precision,
Expand Down Expand Up @@ -1025,42 +1025,9 @@ pub enum Expr {
expr: Box<Expr>,
},
/// CONVERT a value to a different data type or character encoding. e.g. `CONVERT(foo USING utf8mb4)`
Convert {
/// CONVERT (false) or TRY_CONVERT (true)
/// <https://learn.microsoft.com/en-us/sql/t-sql/functions/try-convert-transact-sql?view=sql-server-ver16>
is_try: bool,
/// The expression to convert.
expr: Box<Expr>,
/// The target data type, if provided.
data_type: Option<DataType>,
/// Optional target character encoding (e.g., `utf8mb4`).
charset: Option<ObjectName>,
/// `true` when target precedes the value (MSSQL syntax).
target_before_value: bool,
/// How to translate the expression.
///
/// [MSSQL]: https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16#style
styles: Vec<Expr>,
},
Convert(Box<ConvertExpr>),
/// `CAST` an expression to a different data type e.g. `CAST(foo AS VARCHAR(123))`
Cast {
/// The cast kind (e.g., `CAST`, `TRY_CAST`).
kind: CastKind,
/// Expression being cast.
expr: Box<Expr>,
/// Target data type.
data_type: DataType,
/// [MySQL] allows CAST(... AS type ARRAY) in functional index definitions for InnoDB
/// multi-valued indices. It's not really a datatype, and is only allowed in `CAST` in key
/// specifications, so it's a flag here.
///
/// [MySQL]: https://dev.mysql.com/doc/refman/8.4/en/cast-functions.html#function_cast
array: bool,
/// Optional CAST(string_expression AS type FORMAT format_string_expression) as used by [BigQuery]
///
/// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#formatting_syntax
format: Option<CastFormat>,
},
Cast(Box<CastExpr>),
/// AT a timestamp to a different timezone e.g. `FROM_UNIXTIME(0) AT TIME ZONE 'UTC-06:00'`
AtTimeZone {
/// Timestamp expression to shift.
Expand Down Expand Up @@ -1192,26 +1159,15 @@ pub enum Expr {
/// A constant of form `<data_type> 'value'`.
/// This can represent ANSI SQL `DATE`, `TIME`, and `TIMESTAMP` literals (such as `DATE '2020-01-01'`),
/// as well as constants of other types (a non-standard PostgreSQL extension).
TypedString(TypedString),
TypedString(Box<TypedString>),
/// Scalar function call e.g. `LEFT(foo, 5)`
Function(Function),
Function(Box<Function>),
/// `CASE [<operand>] WHEN <condition> THEN <result> ... [ELSE <result>] END`
///
/// Note we only recognize a complete single expression as `<condition>`,
/// not `< 0` nor `1, 2, 3` as allowed in a `<simple when clause>` per
/// <https://jakewheat.github.io/sql-overview/sql-2011-foundation-grammar.html#simple-when-clause>
Case {
/// The attached `CASE` token (keeps original spacing/comments).
case_token: AttachedToken,
/// The attached `END` token (keeps original spacing/comments).
end_token: AttachedToken,
/// Optional operand expression after `CASE` (for simple CASE).
operand: Option<Box<Expr>>,
/// The `WHEN ... THEN` conditions and results.
conditions: Vec<CaseWhen>,
/// Optional `ELSE` result expression.
else_result: Option<Box<Expr>>,
},
Case(Box<CaseExpr>),
/// An exists expression `[ NOT ] EXISTS(SELECT ...)`, used in expressions like
/// `WHERE [ NOT ] EXISTS (SELECT ...)`.
Exists {
Expand Down Expand Up @@ -1277,7 +1233,7 @@ pub enum Expr {
/// An array expression e.g. `ARRAY[1, 2]`
Array(Array),
/// An interval expression e.g. `INTERVAL '1' YEAR`
Interval(Interval),
Interval(Box<Interval>),
/// `MySQL` specific text search function [(1)].
///
/// Syntax:
Expand Down Expand Up @@ -1328,7 +1284,7 @@ pub enum Expr {
/// [ClickHouse](https://clickhouse.com/docs/en/sql-reference/functions#higher-order-functions---operator-and-lambdaparams-expr-function)
/// [Databricks](https://docs.databricks.com/en/sql/language-manual/sql-ref-lambda-functions.html)
/// [DuckDB](https://duckdb.org/docs/stable/sql/functions/lambda)
Lambda(LambdaFunction),
Lambda(Box<LambdaFunction>),
/// Checks membership of a value in a JSON array
MemberOf(MemberOf),
}
Expand All @@ -1338,6 +1294,78 @@ impl Expr {
pub fn value(value: impl Into<ValueWithSpan>) -> Self {
Expr::Value(value.into())
}

/// Convenience method to retrieve `Expr::Function`'s value if `self` is a
/// function expression.
pub fn as_function(&self) -> Option<&Function> {
if let Expr::Function(f) = self {
Some(&**f)
} else {
None
}
}
}

/// A [`CONVERT` expression](Expr::Convert)
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct ConvertExpr {
/// CONVERT (false) or TRY_CONVERT (true)
/// <https://learn.microsoft.com/en-us/sql/t-sql/functions/try-convert-transact-sql?view=sql-server-ver16>
pub is_try: bool,
/// The expression to convert.
pub expr: Expr,
/// The target data type, if provided.
pub data_type: Option<DataType>,
/// Optional target character encoding (e.g., `utf8mb4`).
pub charset: Option<ObjectName>,
/// `true` when target precedes the value (MSSQL syntax).
pub target_before_value: bool,
/// How to translate the expression.
///
/// [MSSQL]: https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16#style
pub styles: Vec<Expr>,
}

/// A [`CAST` expression](Expr::Cast)
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct CastExpr {
/// The cast kind (e.g., `CAST`, `TRY_CAST`).
pub kind: CastKind,
/// Expression being cast.
pub expr: Expr,
/// Target data type.
pub data_type: DataType,
/// [MySQL] allows CAST(... AS type ARRAY) in functional index definitions for InnoDB
/// multi-valued indices. It's not really a datatype, and is only allowed in `CAST` in key
/// specifications, so it's a flag here.
///
/// [MySQL]: https://dev.mysql.com/doc/refman/8.4/en/cast-functions.html#function_cast
pub array: bool,
/// Optional CAST(string_expression AS type FORMAT format_string_expression) as used by [BigQuery]
///
/// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#formatting_syntax
pub format: Option<CastFormat>,
}

/// A [`CASE` expression](Expr::Case)
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct CaseExpr {
/// The attached `CASE` token (keeps original spacing/comments).
pub case_token: AttachedToken,
/// The attached `END` token (keeps original spacing/comments).
pub end_token: AttachedToken,
/// Optional operand expression after `CASE` (for simple CASE).
pub operand: Option<Box<Expr>>,
/// The `WHEN ... THEN` conditions and results.
pub conditions: Vec<CaseWhen>,
/// Optional `ELSE` result expression.
pub else_result: Option<Box<Expr>>,
}

/// The contents inside the `[` and `]` in a subscript expression.
Expand Down Expand Up @@ -1437,7 +1465,7 @@ pub struct LambdaFunction {
/// The parameters to the lambda function.
pub params: OneOrManyWithParens<LambdaFunctionParameter>,
/// The body of the lambda function.
pub body: Box<Expr>,
pub body: Expr,
/// The syntax style used to write the lambda function.
pub syntax: LambdaSyntax,
}
Expand Down Expand Up @@ -1892,14 +1920,15 @@ impl fmt::Display for Expr {
write!(f, "{op}{expr}")
}
}
Expr::Convert {
is_try,
expr,
target_before_value,
data_type,
charset,
styles,
} => {
Expr::Convert(convert) => {
let ConvertExpr {
is_try,
expr,
target_before_value,
data_type,
charset,
styles,
} = &**convert;
write!(f, "{}CONVERT(", if *is_try { "TRY_" } else { "" })?;
if let Some(data_type) = data_type {
if let Some(charset) = charset {
Expand All @@ -1919,41 +1948,44 @@ impl fmt::Display for Expr {
}
write!(f, ")")
}
Expr::Cast {
kind,
expr,
data_type,
array,
format,
} => match kind {
CastKind::Cast => {
write!(f, "CAST({expr} AS {data_type}")?;
if *array {
write!(f, " ARRAY")?;
Expr::Cast(cast) => {
let CastExpr {
kind,
expr,
data_type,
array,
format,
} = &**cast;
match kind {
CastKind::Cast => {
write!(f, "CAST({expr} AS {data_type}")?;
if *array {
write!(f, " ARRAY")?;
}
if let Some(format) = format {
write!(f, " FORMAT {format}")?;
}
write!(f, ")")
}
if let Some(format) = format {
write!(f, " FORMAT {format}")?;
CastKind::TryCast => {
if let Some(format) = format {
write!(f, "TRY_CAST({expr} AS {data_type} FORMAT {format})")
} else {
write!(f, "TRY_CAST({expr} AS {data_type})")
}
}
write!(f, ")")
}
CastKind::TryCast => {
if let Some(format) = format {
write!(f, "TRY_CAST({expr} AS {data_type} FORMAT {format})")
} else {
write!(f, "TRY_CAST({expr} AS {data_type})")
CastKind::SafeCast => {
if let Some(format) = format {
write!(f, "SAFE_CAST({expr} AS {data_type} FORMAT {format})")
} else {
write!(f, "SAFE_CAST({expr} AS {data_type})")
}
}
}
CastKind::SafeCast => {
if let Some(format) = format {
write!(f, "SAFE_CAST({expr} AS {data_type} FORMAT {format})")
} else {
write!(f, "SAFE_CAST({expr} AS {data_type})")
CastKind::DoubleColon => {
write!(f, "{expr}::{data_type}")
}
}
CastKind::DoubleColon => {
write!(f, "{expr}::{data_type}")
}
},
}
Expr::Extract {
field,
syntax,
Expand Down Expand Up @@ -1983,13 +2015,14 @@ impl fmt::Display for Expr {
Expr::Prefixed { prefix, value } => write!(f, "{prefix} {value}"),
Expr::TypedString(ts) => ts.fmt(f),
Expr::Function(fun) => fun.fmt(f),
Expr::Case {
case_token: _,
end_token: _,
operand,
conditions,
else_result,
} => {
Expr::Case(case) => {
let CaseExpr {
case_token: _,
end_token: _,
operand,
conditions,
else_result,
} = &**case;
f.write_str("CASE")?;
if let Some(operand) = operand {
f.write_str(" ")?;
Expand Down Expand Up @@ -10887,7 +10920,7 @@ pub enum TableObject {
/// INSERT INTO TABLE FUNCTION remote('localhost', default.simple_table)
/// ```
/// [Clickhouse](https://clickhouse.com/docs/en/sql-reference/table-functions)
TableFunction(Function),
TableFunction(Box<Function>),
}

impl fmt::Display for TableObject {
Expand Down Expand Up @@ -12325,29 +12358,33 @@ mod tests {

#[test]
fn test_interval_display() {
let interval = Expr::Interval(Interval {
value: Box::new(Expr::Value(
Value::SingleQuotedString(String::from("123:45.67")).with_empty_span(),
)),
leading_field: Some(DateTimeField::Minute),
leading_precision: Some(10),
last_field: Some(DateTimeField::Second),
fractional_seconds_precision: Some(9),
});
let interval = Expr::Interval(
Interval {
value: Expr::Value(
Value::SingleQuotedString(String::from("123:45.67")).with_empty_span(),
),
leading_field: Some(DateTimeField::Minute),
leading_precision: Some(10),
last_field: Some(DateTimeField::Second),
fractional_seconds_precision: Some(9),
}
.into(),
);
assert_eq!(
"INTERVAL '123:45.67' MINUTE (10) TO SECOND (9)",
format!("{interval}"),
);

let interval = Expr::Interval(Interval {
value: Box::new(Expr::Value(
Value::SingleQuotedString(String::from("5")).with_empty_span(),
)),
leading_field: Some(DateTimeField::Second),
leading_precision: Some(1),
last_field: None,
fractional_seconds_precision: Some(3),
});
let interval = Expr::Interval(
Interval {
value: Expr::Value(Value::SingleQuotedString(String::from("5")).with_empty_span()),
leading_field: Some(DateTimeField::Second),
leading_precision: Some(1),
last_field: None,
fractional_seconds_precision: Some(3),
}
.into(),
);
assert_eq!("INTERVAL '5' SECOND (1, 3)", format!("{interval}"));
}

Expand Down
Loading
Loading