Skip to content

Commit 145cb4a

Browse files
committed
Tweak dynamic dispatch and handle variadic and, or etc...
1 parent de51c99 commit 145cb4a

File tree

3 files changed

+43
-28
lines changed

3 files changed

+43
-28
lines changed

src/substrait/sql/extended_expression.py

+12-24
Original file line numberDiff line numberDiff line change
@@ -92,37 +92,25 @@ def __init__(self, functions_catalog, schema):
9292
self._schema = schema
9393
self._counter = itertools.count()
9494

95+
self._parse_expression = self.DISPATCH_REGISTRY.bind(self)
96+
9597
def expression_from_sqlglot(self, sqlglot_node):
9698
"""Parse a SQLGlot expression into a Substrait Expression."""
9799
return self._parse_expression(sqlglot_node)
98100

99-
def _parse_expression(self, expr):
100-
"""Parse a SQLGlot node and return a Substrait expression.
101-
102-
This is the internal implementation, expected to be
103-
invoked in a recursive manner to parse the whole
104-
expression tree.
105-
"""
106-
expr_class = expr.__class__
107-
return self.DISPATCH_REGISTRY[expr_class](self, expr)
108-
109101
@DISPATCH_REGISTRY.register(sqlglot.expressions.Literal)
110102
def _parse_Literal(self, expr):
111103
if expr.is_string:
112104
return ParsedSubstraitExpression(
113105
f"literal${next(self._counter)}",
114106
proto.Type(string=proto.Type.String()),
115-
proto.Expression(
116-
literal=proto.Expression.Literal(string=expr.text)
117-
),
107+
proto.Expression(literal=proto.Expression.Literal(string=expr.text)),
118108
)
119109
elif expr.is_int:
120110
return ParsedSubstraitExpression(
121111
f"literal${next(self._counter)}",
122112
proto.Type(i32=proto.Type.I32()),
123-
proto.Expression(
124-
literal=proto.Expression.Literal(i32=int(expr.name))
125-
),
113+
proto.Expression(literal=proto.Expression.Literal(i32=int(expr.name))),
126114
)
127115
elif sqlglot.helper.is_float(expr.name):
128116
return ParsedSubstraitExpression(
@@ -134,7 +122,7 @@ def _parse_Literal(self, expr):
134122
)
135123
else:
136124
raise ValueError(f"Unsupporter literal: {expr.text}")
137-
125+
138126
@DISPATCH_REGISTRY.register(sqlglot.expressions.Column)
139127
def _parse_Column(self, expr):
140128
column_name = expr.output_name
@@ -164,10 +152,8 @@ def _parser_Binary(self, expr):
164152
left_parsed_expr = self._parse_expression(expr.left)
165153
right_parsed_expr = self._parse_expression(expr.right)
166154
function_name = SQL_BINARY_FUNCTIONS[expr.key]
167-
signature, result_type, function_expression = (
168-
self._parse_function_invokation(
169-
function_name, left_parsed_expr, right_parsed_expr
170-
)
155+
signature, result_type, function_expression = self._parse_function_invokation(
156+
function_name, left_parsed_expr, right_parsed_expr
171157
)
172158
result_name = f"{function_name}_{left_parsed_expr.output_name}_{right_parsed_expr.output_name}_{next(self._counter)}"
173159
return ParsedSubstraitExpression(
@@ -183,10 +169,12 @@ def _parser_Binary(self, expr):
183169
def _parse_Unary(self, expr):
184170
argument_parsed_expr = self._parse_expression(expr.this)
185171
function_name = SQL_UNARY_FUNCTIONS[expr.key]
186-
signature, result_type, function_expression = (
187-
self._parse_function_invokation(function_name, argument_parsed_expr)
172+
signature, result_type, function_expression = self._parse_function_invokation(
173+
function_name, argument_parsed_expr
174+
)
175+
result_name = (
176+
f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}"
188177
)
189-
result_name = f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}"
190178
return ParsedSubstraitExpression(
191179
result_name,
192180
result_type,

src/substrait/sql/functions_catalog.py

+5
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,16 @@ def load(self, dirpath, filename):
5656
t.get("value", "unknown").strip("?")
5757
for t in impl.get("args", [])
5858
]
59+
if impl.get("variadic", False):
60+
# TODO: Variadic functions.
61+
argtypes *= 2
62+
5963
if not argtypes:
6064
signature = function_name
6165
else:
6266
signature = f"{function_name}:{'_'.join(argtypes)}"
6367
loaded_functions.add(signature)
68+
print("Loaded function", signature)
6469
functions_return_type[signature] = self._type_from_name(
6570
impl["return"]
6671
)

src/substrait/sql/utils.py

+26-4
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,38 @@
1+
import types
2+
3+
14
class DispatchRegistry:
    """Dispatch to a registered function based on the class of an argument.

    Functions are registered for a class with :meth:`register`.  The
    registry can then be exposed as a method of an object through
    :meth:`bind`; calling that method looks up the function matching the
    first positional argument and invokes it.

    It is similar to ``functools.singledispatch`` but it allows more
    customization in case the dispatch rules grow in complexity
    and works for class methods as well
    (singledispatch supports methods only in more recent versions).

    Lookup uses ``isinstance`` and walks the registrations in insertion
    order, so the first registered class that matches wins.  Register
    more specific classes before their base classes when both need
    their own handler.
    """

    def __init__(self):
        # Maps a class to the function handling instances of that class.
        # Insertion order is significant: see __getitem__.
        self._registry = {}

    def register(self, cls):
        """Return a decorator registering the decorated function for ``cls``.

        The decorated function is returned unchanged, so it remains usable
        under its own name.  Registering the same class twice replaces the
        previous function.
        """

        def decorator(func):
            self._registry[cls] = func
            return func

        return decorator

    def bind(self, obj):
        """Return the registry as a method bound to ``obj``.

        Calling the result as ``bound(argument, *args, **kwargs)`` is
        equivalent to ``registry(obj, argument, *args, **kwargs)``.
        """
        return types.MethodType(self, obj)

    def __getitem__(self, argument):
        """Return the function registered for the class of ``argument``.

        Raises:
            ValueError: if ``argument`` is not an instance of any
                registered class.
        """
        for dispatch_cls, func in self._registry.items():
            if isinstance(argument, dispatch_cls):
                return func
        # Report the class of the offending argument.  The previous code
        # referenced an undefined name here and failed with NameError
        # instead of raising the intended ValueError.
        raise ValueError(f"Unsupported SQL Node type: {type(argument)}")

    def __call__(self, obj, dispatch_argument, *args, **kwargs):
        """Invoke the function registered for ``dispatch_argument``.

        The function is called as
        ``func(obj, dispatch_argument, *args, **kwargs)`` and its result is
        returned unchanged.
        """
        return self[dispatch_argument](obj, dispatch_argument, *args, **kwargs)

0 commit comments

Comments
 (0)