@@ -9,9 +9,9 @@ use async_trait::async_trait;
9
9
use sqlparser:: ast:: helpers:: attached_token:: AttachedToken ;
10
10
use sqlparser:: ast:: {
11
11
BinaryOperator , CastKind , CharacterLength , DataType , Expr , Function , FunctionArg ,
12
- FunctionArgExpr , FunctionArgumentList , FunctionArguments , Ident , ObjectName ,
13
- OneOrManyWithParens , SelectItem , SetExpr , Spanned , Statement , Value , Visit , VisitMut , Visitor ,
14
- VisitorMut ,
12
+ FunctionArgExpr , FunctionArgumentList , FunctionArguments , Ident , ObjectName , ObjectNamePart ,
13
+ OneOrManyWithParens , SelectFlavor , SelectItem , SetExpr , Spanned , Statement , Value ,
14
+ ValueWithSpan , Visit , VisitMut , Visitor , VisitorMut ,
15
15
} ;
16
16
use sqlparser:: dialect:: { Dialect , MsSqlDialect , MySqlDialect , PostgreSqlDialect , SQLiteDialect } ;
17
17
use sqlparser:: parser:: { Parser , ParserError } ;
@@ -329,7 +329,7 @@ fn extract_toplevel_functions(stmt: &mut Statement) -> Vec<DelayedFunctionCall>
329
329
let argument_col_name = format ! ( "_sqlpage_f{func_idx}_a{arg_idx}" ) ;
330
330
argument_col_names. push ( argument_col_name. clone ( ) ) ;
331
331
let expr_to_insert = SelectItem :: ExprWithAlias {
332
- expr : std:: mem:: replace ( expr, Expr :: Value ( Value :: Null ) ) ,
332
+ expr : std:: mem:: replace ( expr, Expr :: value ( Value :: Null ) ) ,
333
333
alias : Ident :: new ( argument_col_name) ,
334
334
} ;
335
335
select_items_to_add. push ( SelectItemToAdd {
@@ -417,10 +417,21 @@ fn extract_static_simple_select(
417
417
return None ;
418
418
} ;
419
419
let value = match expr {
420
- Expr :: Value ( Value :: Boolean ( b) ) => Static ( Bool ( * b) ) ,
421
- Expr :: Value ( Value :: Number ( n, _) ) => Static ( Number ( n. parse ( ) . ok ( ) ?) ) ,
422
- Expr :: Value ( Value :: SingleQuotedString ( s) ) => Static ( String ( s. clone ( ) ) ) ,
423
- Expr :: Value ( Value :: Null ) => Static ( Null ) ,
420
+ Expr :: Value ( ValueWithSpan {
421
+ value : Value :: Boolean ( b) ,
422
+ ..
423
+ } ) => Static ( Bool ( * b) ) ,
424
+ Expr :: Value ( ValueWithSpan {
425
+ value : Value :: Number ( n, _) ,
426
+ ..
427
+ } ) => Static ( Number ( n. parse ( ) . ok ( ) ?) ) ,
428
+ Expr :: Value ( ValueWithSpan {
429
+ value : Value :: SingleQuotedString ( s) ,
430
+ ..
431
+ } ) => Static ( String ( s. clone ( ) ) ) ,
432
+ Expr :: Value ( ValueWithSpan {
433
+ value : Value :: Null , ..
434
+ } ) => Static ( Null ) ,
424
435
e if is_simple_select_placeholder ( e) => {
425
436
if let Some ( p) = params_iter. next ( ) {
426
437
Dynamic ( p)
@@ -446,7 +457,10 @@ fn extract_static_simple_select(
446
457
447
458
fn is_simple_select_placeholder ( e : & Expr ) -> bool {
448
459
match e {
449
- Expr :: Value ( Value :: Placeholder ( _) ) => true ,
460
+ Expr :: Value ( ValueWithSpan {
461
+ value : Value :: Placeholder ( _) ,
462
+ ..
463
+ } ) => true ,
450
464
Expr :: Cast {
451
465
expr,
452
466
data_type : DataType :: Text | DataType :: Varchar ( _) | DataType :: Char ( _) ,
@@ -469,13 +483,15 @@ fn extract_set_variable(
469
483
hivevar : false ,
470
484
} = stmt
471
485
{
472
- if let ( [ ident] , [ value] ) = ( name. as_mut_slice ( ) , value. as_mut_slice ( ) ) {
486
+ if let ( [ ObjectNamePart :: Identifier ( ident) ] , [ value] ) =
487
+ ( name. as_mut_slice ( ) , value. as_mut_slice ( ) )
488
+ {
473
489
let variable = if let Some ( variable) = extract_ident_param ( ident) {
474
490
variable
475
491
} else {
476
492
StmtParam :: PostOrGet ( std:: mem:: take ( & mut ident. value ) )
477
493
} ;
478
- let owned_expr = std:: mem:: replace ( value, Expr :: Value ( Value :: Null ) ) ;
494
+ let owned_expr = std:: mem:: replace ( value, Expr :: value ( Value :: Null ) ) ;
479
495
let mut select_stmt: Statement = expr_to_statement ( owned_expr) ;
480
496
let delayed_functions = extract_toplevel_functions ( & mut select_stmt) ;
481
497
if let Err ( err) = validate_function_calls ( & select_stmt) {
@@ -576,7 +592,7 @@ impl ParameterExtractor {
576
592
AnyKind :: Mssql => DataType :: Varchar ( Some ( CharacterLength :: Max ) ) ,
577
593
_ => DataType :: Text ,
578
594
} ;
579
- let value = Expr :: Value ( Value :: Placeholder ( name) ) ;
595
+ let value = Expr :: value ( Value :: Placeholder ( name) ) ;
580
596
Expr :: Cast {
581
597
expr : Box :: new ( value) ,
582
598
data_type,
@@ -693,9 +709,10 @@ pub(super) fn function_args_to_stmt_params(
693
709
694
710
fn expr_to_stmt_param ( arg : & mut Expr ) -> Option < StmtParam > {
695
711
match arg {
696
- Expr :: Value ( Value :: Placeholder ( placeholder) ) => {
697
- Some ( map_param ( std:: mem:: take ( placeholder) ) )
698
- }
712
+ Expr :: Value ( ValueWithSpan {
713
+ value : Value :: Placeholder ( placeholder) ,
714
+ ..
715
+ } ) => Some ( map_param ( std:: mem:: take ( placeholder) ) ) ,
699
716
Expr :: Identifier ( ident) => extract_ident_param ( ident) ,
700
717
Expr :: Function ( Function {
701
718
name : ObjectName ( func_name_parts) ,
@@ -710,13 +727,17 @@ fn expr_to_stmt_param(arg: &mut Expr) -> Option<StmtParam> {
710
727
sqlpage_func_name ( func_name_parts) ,
711
728
args. as_mut_slice ( ) ,
712
729
) ) ,
713
- Expr :: Value ( Value :: SingleQuotedString ( param_value) ) => {
714
- Some ( StmtParam :: Literal ( std:: mem:: take ( param_value) ) )
715
- }
716
- Expr :: Value ( Value :: Number ( param_value, _is_long) ) => {
717
- Some ( StmtParam :: Literal ( param_value. clone ( ) ) )
718
- }
719
- Expr :: Value ( Value :: Null ) => Some ( StmtParam :: Null ) ,
730
+ Expr :: Value ( ValueWithSpan {
731
+ value : Value :: SingleQuotedString ( param_value) ,
732
+ ..
733
+ } ) => Some ( StmtParam :: Literal ( std:: mem:: take ( param_value) ) ) ,
734
+ Expr :: Value ( ValueWithSpan {
735
+ value : Value :: Number ( param_value, _is_long) ,
736
+ ..
737
+ } ) => Some ( StmtParam :: Literal ( param_value. clone ( ) ) ) ,
738
+ Expr :: Value ( ValueWithSpan {
739
+ value : Value :: Null , ..
740
+ } ) => Some ( StmtParam :: Null ) ,
720
741
Expr :: BinaryOp {
721
742
// 'str1' || 'str2'
722
743
left,
@@ -741,7 +762,10 @@ fn expr_to_stmt_param(arg: &mut Expr) -> Option<StmtParam> {
741
762
} ) ,
742
763
..
743
764
} ) if func_name_parts. len ( ) == 1 => {
744
- let func_name = func_name_parts[ 0 ] . value . as_str ( ) ;
765
+ let func_name = func_name_parts[ 0 ]
766
+ . as_ident ( )
767
+ . map ( |ident| ident. value . as_str ( ) )
768
+ . unwrap_or_default ( ) ;
745
769
if func_name. eq_ignore_ascii_case ( "concat" ) {
746
770
let mut concat_args = Vec :: with_capacity ( args. len ( ) ) ;
747
771
for arg in args {
@@ -829,7 +853,10 @@ impl VisitorMut for ParameterExtractor {
829
853
self . replace_with_placeholder ( value, param) ;
830
854
}
831
855
}
832
- Expr :: Value ( Value :: Placeholder ( param) ) if !self . is_own_placeholder ( param) =>
856
+ Expr :: Value ( ValueWithSpan {
857
+ value : Value :: Placeholder ( param) ,
858
+ ..
859
+ } ) if !self . is_own_placeholder ( param) =>
833
860
// this check is to avoid recursively replacing placeholders in the form of '?', or '$1', '$2', which we emit ourselves
834
861
{
835
862
let name = std:: mem:: take ( param) ;
@@ -860,10 +887,10 @@ impl VisitorMut for ParameterExtractor {
860
887
op : BinaryOperator :: StringConcat ,
861
888
right,
862
889
} if self . db_kind == AnyKind :: Mssql => {
863
- let left = std:: mem:: replace ( left. as_mut ( ) , Expr :: Value ( Value :: Null ) ) ;
864
- let right = std:: mem:: replace ( right. as_mut ( ) , Expr :: Value ( Value :: Null ) ) ;
890
+ let left = std:: mem:: replace ( left. as_mut ( ) , Expr :: value ( Value :: Null ) ) ;
891
+ let right = std:: mem:: replace ( right. as_mut ( ) , Expr :: value ( Value :: Null ) ) ;
865
892
* value = Expr :: Function ( Function {
866
- name : ObjectName ( vec ! [ Ident :: new( "CONCAT" ) ] ) ,
893
+ name : ObjectName ( vec ! [ ObjectNamePart :: Identifier ( Ident :: new( "CONCAT" ) ) ] ) ,
867
894
args : FunctionArguments :: List ( FunctionArgumentList {
868
895
args : vec ! [
869
896
FunctionArg :: Unnamed ( FunctionArgExpr :: Expr ( left) ) ,
@@ -896,18 +923,22 @@ impl VisitorMut for ParameterExtractor {
896
923
897
924
const SQLPAGE_FUNCTION_NAMESPACE : & str = "sqlpage" ;
898
925
899
- fn is_sqlpage_func ( func_name_parts : & [ Ident ] ) -> bool {
900
- if let [ Ident { value, .. } , Ident { .. } ] = func_name_parts {
926
+ fn is_sqlpage_func ( func_name_parts : & [ ObjectNamePart ] ) -> bool {
927
+ if let [ ObjectNamePart :: Identifier ( Ident { value, .. } ) , ObjectNamePart :: Identifier ( Ident { .. } ) ] =
928
+ func_name_parts
929
+ {
901
930
value == SQLPAGE_FUNCTION_NAMESPACE
902
931
} else {
903
932
false
904
933
}
905
934
}
906
935
907
- fn extract_sqlpage_function_name ( func_name_parts : & [ Ident ] ) -> Option < SqlPageFunctionName > {
908
- if let [ Ident {
936
+ fn extract_sqlpage_function_name (
937
+ func_name_parts : & [ ObjectNamePart ] ,
938
+ ) -> Option < SqlPageFunctionName > {
939
+ if let [ ObjectNamePart :: Identifier ( Ident {
909
940
value : namespace, ..
910
- } , Ident { value, .. } ] = func_name_parts
941
+ } ) , ObjectNamePart :: Identifier ( Ident { value, .. } ) ] = func_name_parts
911
942
{
912
943
if namespace == SQLPAGE_FUNCTION_NAMESPACE {
913
944
return SqlPageFunctionName :: from_str ( value) . ok ( ) ;
@@ -916,8 +947,10 @@ fn extract_sqlpage_function_name(func_name_parts: &[Ident]) -> Option<SqlPageFun
916
947
None
917
948
}
918
949
919
- fn sqlpage_func_name ( func_name_parts : & [ Ident ] ) -> & str {
920
- if let [ Ident { .. } , Ident { value, .. } ] = func_name_parts {
950
+ fn sqlpage_func_name ( func_name_parts : & [ ObjectNamePart ] ) -> & str {
951
+ if let [ ObjectNamePart :: Identifier ( Ident { .. } ) , ObjectNamePart :: Identifier ( Ident { value, .. } ) ] =
952
+ func_name_parts
953
+ {
921
954
value
922
955
} else {
923
956
debug_assert ! (
@@ -955,7 +988,7 @@ fn extract_json_columns(stmt: &Statement, db_kind: AnyKind) -> Vec<String> {
955
988
fn is_json_function ( expr : & Expr ) -> bool {
956
989
match expr {
957
990
Expr :: Function ( function) => {
958
- if let [ Ident { value, .. } ] = function. name . 0 . as_slice ( ) {
991
+ if let [ ObjectNamePart :: Identifier ( Ident { value, .. } ) ] = function. name . 0 . as_slice ( ) {
959
992
[
960
993
"json_object" ,
961
994
"json_array" ,
@@ -979,10 +1012,17 @@ fn is_json_function(expr: &Expr) -> bool {
979
1012
}
980
1013
}
981
1014
Expr :: Cast { data_type, .. } => {
982
- matches ! ( data_type, DataType :: JSON | DataType :: JSONB )
983
- || ( matches ! ( data_type, DataType :: Custom ( ObjectName ( parts) , _) if
984
- ( parts. len( ) == 1 )
985
- && ( parts[ 0 ] . value. eq_ignore_ascii_case( "json" ) ) ) )
1015
+ if matches ! ( data_type, DataType :: JSON | DataType :: JSONB ) {
1016
+ true
1017
+ } else if let DataType :: Custom ( ObjectName ( parts) , _) = data_type {
1018
+ if let [ ObjectNamePart :: Identifier ( ident) ] = parts. as_slice ( ) {
1019
+ ident. value . eq_ignore_ascii_case ( "json" )
1020
+ } else {
1021
+ false
1022
+ }
1023
+ } else {
1024
+ false
1025
+ }
986
1026
}
987
1027
_ => false ,
988
1028
}
@@ -1019,6 +1059,7 @@ fn expr_to_statement(expr: Expr) -> Statement {
1019
1059
window_before_qualify : false ,
1020
1060
value_table_mode : None ,
1021
1061
connect_by : None ,
1062
+ flavor : SelectFlavor :: Standard ,
1022
1063
} ,
1023
1064
) ) ) ,
1024
1065
order_by : None ,
0 commit comments