3
3
import sqlglot
4
4
5
5
from substrait import proto
6
+ from .utils import DispatchRegistry
7
+
6
8
7
9
SQL_UNARY_FUNCTIONS = {"not" : "not" }
8
10
SQL_BINARY_FUNCTIONS = {
@@ -83,6 +85,8 @@ def parse_sql_extended_expression(catalog, schema, sql):
83
85
84
86
85
87
class SQLGlotParser :
88
+ DISPATCH_REGISTRY = DispatchRegistry ()
89
+
86
90
def __init__ (self , functions_catalog , schema ):
87
91
self ._functions_catalog = functions_catalog
88
92
self ._schema = schema
@@ -99,88 +103,96 @@ def _parse_expression(self, expr):
99
103
invoked in a recursive manner to parse the whole
100
104
expression tree.
101
105
"""
102
- if isinstance (expr , sqlglot .expressions .Literal ):
103
- if expr .is_string :
104
- return ParsedSubstraitExpression (
105
- f"literal${ next (self ._counter )} " ,
106
- proto .Type (string = proto .Type .String ()),
107
- proto .Expression (
108
- literal = proto .Expression .Literal (string = expr .text )
109
- ),
110
- )
111
- elif expr .is_int :
112
- return ParsedSubstraitExpression (
113
- f"literal${ next (self ._counter )} " ,
114
- proto .Type (i32 = proto .Type .I32 ()),
115
- proto .Expression (
116
- literal = proto .Expression .Literal (i32 = int (expr .name ))
117
- ),
118
- )
119
- elif sqlglot .helper .is_float (expr .name ):
120
- return ParsedSubstraitExpression (
121
- f"literal${ next (self ._counter )} " ,
122
- proto .Type (fp32 = proto .Type .FP32 ()),
123
- proto .Expression (
124
- literal = proto .Expression .Literal (float = float (expr .name ))
125
- ),
126
- )
127
- else :
128
- raise ValueError (f"Unsupporter literal: { expr .text } " )
129
- elif isinstance (expr , sqlglot .expressions .Column ):
130
- column_name = expr .output_name
131
- schema_field = list (self ._schema .names ).index (column_name )
132
- schema_type = self ._schema .struct .types [schema_field ]
106
+ expr_class = expr .__class__
107
+ return self .DISPATCH_REGISTRY [expr_class ](self , expr )
108
+
109
+ @DISPATCH_REGISTRY .register (sqlglot .expressions .Literal )
110
+ def _parse_Literal (self , expr ):
111
+ if expr .is_string :
133
112
return ParsedSubstraitExpression (
134
- column_name ,
135
- schema_type ,
113
+ f"literal$ { next ( self . _counter ) } " ,
114
+ proto . Type ( string = proto . Type . String ()) ,
136
115
proto .Expression (
137
- selection = proto .Expression .FieldReference (
138
- direct_reference = proto .Expression .ReferenceSegment (
139
- struct_field = proto .Expression .ReferenceSegment .StructField (
140
- field = schema_field
141
- )
142
- )
143
- )
116
+ literal = proto .Expression .Literal (string = expr .text )
144
117
),
145
118
)
146
- elif isinstance (expr , sqlglot .expressions .Alias ):
147
- parsed_expression = self ._parse_expression (expr .this )
148
- return parsed_expression .duplicate (output_name = expr .output_name )
149
- elif expr .key in SQL_UNARY_FUNCTIONS :
150
- argument_parsed_expr = self ._parse_expression (expr .this )
151
- function_name = SQL_UNARY_FUNCTIONS [expr .key ]
152
- signature , result_type , function_expression = (
153
- self ._parse_function_invokation (function_name , argument_parsed_expr )
154
- )
155
- result_name = f"{ function_name } _{ argument_parsed_expr .output_name } _{ next (self ._counter )} "
119
+ elif expr .is_int :
156
120
return ParsedSubstraitExpression (
157
- result_name ,
158
- result_type ,
159
- function_expression ,
160
- argument_parsed_expr .invoked_functions | {signature },
161
- )
162
- elif expr .key in SQL_BINARY_FUNCTIONS :
163
- left_parsed_expr = self ._parse_expression (expr .left )
164
- right_parsed_expr = self ._parse_expression (expr .right )
165
- function_name = SQL_BINARY_FUNCTIONS [expr .key ]
166
- signature , result_type , function_expression = (
167
- self ._parse_function_invokation (
168
- function_name , left_parsed_expr , right_parsed_expr
169
- )
121
+ f"literal${ next (self ._counter )} " ,
122
+ proto .Type (i32 = proto .Type .I32 ()),
123
+ proto .Expression (
124
+ literal = proto .Expression .Literal (i32 = int (expr .name ))
125
+ ),
170
126
)
171
- result_name = f" { function_name } _ { left_parsed_expr . output_name } _ { right_parsed_expr . output_name } _ { next ( self . _counter ) } "
127
+ elif sqlglot . helper . is_float ( expr . name ):
172
128
return ParsedSubstraitExpression (
173
- result_name ,
174
- result_type ,
175
- function_expression ,
176
- left_parsed_expr .invoked_functions
177
- | right_parsed_expr .invoked_functions
178
- | {signature },
129
+ f"literal${ next (self ._counter )} " ,
130
+ proto .Type (fp32 = proto .Type .FP32 ()),
131
+ proto .Expression (
132
+ literal = proto .Expression .Literal (float = float (expr .name ))
133
+ ),
179
134
)
180
135
else :
181
- raise ValueError (
182
- f"Unsupported expression in ExtendedExpression: '{ expr .key } ' -> { expr } "
136
+ raise ValueError (f"Unsupporter literal: { expr .text } " )
137
+
138
+ @DISPATCH_REGISTRY .register (sqlglot .expressions .Column )
139
+ def _parse_Column (self , expr ):
140
+ column_name = expr .output_name
141
+ schema_field = list (self ._schema .names ).index (column_name )
142
+ schema_type = self ._schema .struct .types [schema_field ]
143
+ return ParsedSubstraitExpression (
144
+ column_name ,
145
+ schema_type ,
146
+ proto .Expression (
147
+ selection = proto .Expression .FieldReference (
148
+ direct_reference = proto .Expression .ReferenceSegment (
149
+ struct_field = proto .Expression .ReferenceSegment .StructField (
150
+ field = schema_field
151
+ )
152
+ )
153
+ )
154
+ ),
155
+ )
156
+
157
+ @DISPATCH_REGISTRY .register (sqlglot .expressions .Alias )
158
+ def _parse_Alias (self , expr ):
159
+ parsed_expression = self ._parse_expression (expr .this )
160
+ return parsed_expression .duplicate (output_name = expr .output_name )
161
+
162
+ @DISPATCH_REGISTRY .register (sqlglot .expressions .Binary )
163
+ def _parser_Binary (self , expr ):
164
+ left_parsed_expr = self ._parse_expression (expr .left )
165
+ right_parsed_expr = self ._parse_expression (expr .right )
166
+ function_name = SQL_BINARY_FUNCTIONS [expr .key ]
167
+ signature , result_type , function_expression = (
168
+ self ._parse_function_invokation (
169
+ function_name , left_parsed_expr , right_parsed_expr
183
170
)
171
+ )
172
+ result_name = f"{ function_name } _{ left_parsed_expr .output_name } _{ right_parsed_expr .output_name } _{ next (self ._counter )} "
173
+ return ParsedSubstraitExpression (
174
+ result_name ,
175
+ result_type ,
176
+ function_expression ,
177
+ left_parsed_expr .invoked_functions
178
+ | right_parsed_expr .invoked_functions
179
+ | {signature },
180
+ )
181
+
182
+ @DISPATCH_REGISTRY .register (sqlglot .expressions .Unary )
183
+ def _parse_Unary (self , expr ):
184
+ argument_parsed_expr = self ._parse_expression (expr .this )
185
+ function_name = SQL_UNARY_FUNCTIONS [expr .key ]
186
+ signature , result_type , function_expression = (
187
+ self ._parse_function_invokation (function_name , argument_parsed_expr )
188
+ )
189
+ result_name = f"{ function_name } _{ argument_parsed_expr .output_name } _{ next (self ._counter )} "
190
+ return ParsedSubstraitExpression (
191
+ result_name ,
192
+ result_type ,
193
+ function_expression ,
194
+ argument_parsed_expr .invoked_functions | {signature },
195
+ )
184
196
185
197
def _parse_function_invokation (
186
198
self , function_name , argument_parsed_expr , * additional_arguments
0 commit comments