Skip to content

Commit de51c99

Browse files
committed
Dynamic dispatch of parsing
1 parent ffbaf59 commit de51c99

File tree

2 files changed

+100
-72
lines changed

2 files changed

+100
-72
lines changed

src/substrait/sql/extended_expression.py

+84-72
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import sqlglot
44

55
from substrait import proto
6+
from .utils import DispatchRegistry
7+
68

79
SQL_UNARY_FUNCTIONS = {"not": "not"}
810
SQL_BINARY_FUNCTIONS = {
@@ -83,6 +85,8 @@ def parse_sql_extended_expression(catalog, schema, sql):
8385

8486

8587
class SQLGlotParser:
88+
DISPATCH_REGISTRY = DispatchRegistry()
89+
8690
def __init__(self, functions_catalog, schema):
8791
self._functions_catalog = functions_catalog
8892
self._schema = schema
@@ -99,88 +103,96 @@ def _parse_expression(self, expr):
99103
invoked in a recursive manner to parse the whole
100104
expression tree.
101105
"""
102-
if isinstance(expr, sqlglot.expressions.Literal):
103-
if expr.is_string:
104-
return ParsedSubstraitExpression(
105-
f"literal${next(self._counter)}",
106-
proto.Type(string=proto.Type.String()),
107-
proto.Expression(
108-
literal=proto.Expression.Literal(string=expr.text)
109-
),
110-
)
111-
elif expr.is_int:
112-
return ParsedSubstraitExpression(
113-
f"literal${next(self._counter)}",
114-
proto.Type(i32=proto.Type.I32()),
115-
proto.Expression(
116-
literal=proto.Expression.Literal(i32=int(expr.name))
117-
),
118-
)
119-
elif sqlglot.helper.is_float(expr.name):
120-
return ParsedSubstraitExpression(
121-
f"literal${next(self._counter)}",
122-
proto.Type(fp32=proto.Type.FP32()),
123-
proto.Expression(
124-
literal=proto.Expression.Literal(float=float(expr.name))
125-
),
126-
)
127-
else:
128-
raise ValueError(f"Unsupporter literal: {expr.text}")
129-
elif isinstance(expr, sqlglot.expressions.Column):
130-
column_name = expr.output_name
131-
schema_field = list(self._schema.names).index(column_name)
132-
schema_type = self._schema.struct.types[schema_field]
106+
expr_class = expr.__class__
107+
return self.DISPATCH_REGISTRY[expr_class](self, expr)
108+
109+
@DISPATCH_REGISTRY.register(sqlglot.expressions.Literal)
def _parse_Literal(self, expr):
    """Convert a SQL literal node into a Substrait literal expression.

    String literals become Substrait ``string`` literals, integer literals
    become ``i32`` and floating point literals become ``fp32``.  The
    resulting expression is named ``literal$<n>`` where ``<n>`` is the next
    value taken from ``self._counter``.

    Raises:
        ValueError: if the literal is neither a string, an integer nor a
            float.
    """
    if expr.is_string:
        literal_type = proto.Type(string=proto.Type.String())
        literal_value = proto.Expression.Literal(string=expr.text)
    elif expr.is_int:
        literal_type = proto.Type(i32=proto.Type.I32())
        literal_value = proto.Expression.Literal(i32=int(expr.name))
    elif sqlglot.helper.is_float(expr.name):
        literal_type = proto.Type(fp32=proto.Type.FP32())
        # The Substrait literal message names its 32-bit float field "fp32"
        # (matching the declared type above); there is no "float" field.
        literal_value = proto.Expression.Literal(fp32=float(expr.name))
    else:
        raise ValueError(f"Unsupported literal: {expr.text}")

    return ParsedSubstraitExpression(
        f"literal${next(self._counter)}",
        literal_type,
        proto.Expression(literal=literal_value),
    )
137+
138+
@DISPATCH_REGISTRY.register(sqlglot.expressions.Column)
def _parse_Column(self, expr):
    """Translate a column reference into a Substrait field selection.

    The column's output name is looked up in ``self._schema.names`` to get
    its position; that position is used both to pick the column type from
    ``self._schema.struct.types`` and as the struct field index of the
    emitted direct reference.
    """
    name = expr.output_name
    position = list(self._schema.names).index(name)
    column_type = self._schema.struct.types[position]

    # Build the reference from the innermost message outwards.
    struct_field = proto.Expression.ReferenceSegment.StructField(field=position)
    segment = proto.Expression.ReferenceSegment(struct_field=struct_field)
    reference = proto.Expression.FieldReference(direct_reference=segment)

    return ParsedSubstraitExpression(
        name,
        column_type,
        proto.Expression(selection=reference),
    )
156+
157+
@DISPATCH_REGISTRY.register(sqlglot.expressions.Alias)
def _parse_Alias(self, expr):
    """Parse the aliased expression and relabel it with the alias name.

    The wrapped node (``expr.this``) is parsed recursively, then a copy of
    the result is returned whose output name is the alias' output name.
    """
    aliased = self._parse_expression(expr.this)
    alias_name = expr.output_name
    return aliased.duplicate(output_name=alias_name)
161+
162+
@DISPATCH_REGISTRY.register(sqlglot.expressions.Binary)
def _parser_Binary(self, expr):
    """Parse a binary SQL operator into a Substrait function invocation.

    The operator key is mapped to a function name through
    ``SQL_BINARY_FUNCTIONS``; both operands are parsed recursively and
    passed to ``_parse_function_invokation``.  The result carries the
    invoked functions of both operands plus the signature of this call.

    Raises:
        ValueError: if the operator has no entry in
            ``SQL_BINARY_FUNCTIONS``.
    """
    # Every Binary subclass is dispatched here, so an unmapped operator must
    # be reported as an unsupported expression rather than a bare KeyError.
    try:
        function_name = SQL_BINARY_FUNCTIONS[expr.key]
    except KeyError:
        raise ValueError(
            f"Unsupported expression in ExtendedExpression: '{expr.key}' -> {expr}"
        ) from None

    left_parsed_expr = self._parse_expression(expr.left)
    right_parsed_expr = self._parse_expression(expr.right)
    signature, result_type, function_expression = (
        self._parse_function_invokation(
            function_name, left_parsed_expr, right_parsed_expr
        )
    )
    result_name = (
        f"{function_name}_{left_parsed_expr.output_name}"
        f"_{right_parsed_expr.output_name}_{next(self._counter)}"
    )
    return ParsedSubstraitExpression(
        result_name,
        result_type,
        function_expression,
        left_parsed_expr.invoked_functions
        | right_parsed_expr.invoked_functions
        | {signature},
    )
181+
182+
@DISPATCH_REGISTRY.register(sqlglot.expressions.Unary)
def _parse_Unary(self, expr):
    """Parse a unary SQL operator into a Substrait function invocation.

    The operator key is mapped to a function name through
    ``SQL_UNARY_FUNCTIONS``; the operand (``expr.this``) is parsed
    recursively and passed to ``_parse_function_invokation``.  The result
    carries the operand's invoked functions plus the signature of this call.

    Raises:
        ValueError: if the operator has no entry in ``SQL_UNARY_FUNCTIONS``.
    """
    # Every Unary subclass is dispatched here, so an unmapped operator must
    # be reported as an unsupported expression rather than a bare KeyError.
    try:
        function_name = SQL_UNARY_FUNCTIONS[expr.key]
    except KeyError:
        raise ValueError(
            f"Unsupported expression in ExtendedExpression: '{expr.key}' -> {expr}"
        ) from None

    argument_parsed_expr = self._parse_expression(expr.this)
    signature, result_type, function_expression = (
        self._parse_function_invokation(function_name, argument_parsed_expr)
    )
    result_name = (
        f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}"
    )
    return ParsedSubstraitExpression(
        result_name,
        result_type,
        function_expression,
        argument_parsed_expr.invoked_functions | {signature},
    )
184196

185197
def _parse_function_invokation(
186198
self, function_name, argument_parsed_expr, *additional_arguments

src/substrait/sql/utils.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
class DispatchRegistry:
    """Registry of handler functions keyed by class, with subclass-aware lookup.

    Handlers are attached with the ``register`` decorator and retrieved by
    indexing the registry with a class.  A lookup succeeds when the class
    itself or one of its base classes has a registered handler.
    """

    def __init__(self):
        # Maps each registered class to the handler registered for it.
        self._registry = {}

    def register(self, cls):
        """Return a decorator that registers the decorated function for ``cls``.

        The decorated function is returned unchanged, so it stays usable
        under its own name.  Registering the same class twice replaces the
        earlier handler.
        """

        def decorator(func):
            self._registry[cls] = func
            return func

        return decorator

    def __getitem__(self, cls):
        """Return the handler registered for ``cls`` or its closest base class.

        Raises:
            ValueError: if neither ``cls`` nor any of its base classes has a
                registered handler.
        """
        # Walk the MRO so the most specific registered class wins, regardless
        # of the order in which handlers were registered.
        for ancestor in getattr(cls, "__mro__", ()):
            if ancestor in self._registry:
                return self._registry[ancestor]

        # Fall back to issubclass() so classes that only match through
        # virtual subclassing (e.g. ABC registration) are still resolved.
        for dispatch_cls, func in self._registry.items():
            if issubclass(cls, dispatch_cls):
                return func

        raise ValueError(f"Unsupported SQL Node type: {cls}")

0 commit comments

Comments
 (0)