Skip to content

Commit b70ef4e

Browse files
committed
Initial Import
1 parent cbac90e commit b70ef4e

File tree

3 files changed

+317
-1
lines changed

3 files changed

+317
-1
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ authors = [{name = "Substrait contributors", email = "[email protected]
55
license = {text = "Apache-2.0"}
66
readme = "README.md"
77
requires-python = ">=3.8.1"
8-
dependencies = ["protobuf >= 3.20"]
8+
dependencies = ["protobuf >= 3.20", "sqlglot >= 23.10.0"]
99
dynamic = ["version"]
1010

1111
[tool.setuptools_scm]

src/substrait/sql/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .extended_expression import parse_sql_extended_expression
+315
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
import pathlib
2+
3+
import sqlglot
4+
import yaml
5+
6+
from substrait import proto
7+
8+
9+
# Maps the sqlglot expression key (``expr.key``) of a supported binary
# operator to the Substrait function name used when building the function
# signature that is looked up in the functions catalog.
SQL_BINARY_FUNCTIONS = {
    # Arithmetic
    "add": "add",
    "div": "div",
    "mul": "mul",
    "sub": "sub",
    # Comparisons
    "eq": "equal",
}
18+
19+
20+
def parse_sql_extended_expression(catalog, schema, sql):
    """Parse a SQL SELECT statement into Substrait ExtendedExpressions.

    Args:
        catalog: Functions catalog used to resolve the extensions declaring
            the functions invoked by the statement (it must provide
            ``extensions_for_functions``).
        schema: ``proto.NamedStruct`` describing the columns the SQL refers to.
        sql: Text of a single SELECT statement.

    Returns:
        A ``(projection, filter)`` tuple. ``projection`` is an
        ``ExtendedExpression`` with one referred expression per SELECT item.
        ``filter`` is an ``ExtendedExpression`` holding the WHERE predicate,
        or ``None`` when the statement has no WHERE clause.

    Raises:
        ValueError: If ``sql`` is not a SELECT statement or it contains an
            expression that cannot be converted.
    """
    select = sqlglot.parse_one(sql)
    if not isinstance(select, sqlglot.expressions.Select):
        raise ValueError("a SELECT statement was expected")

    invoked_functions_projection, projections = _substrait_projection_from_sqlglot(
        catalog, schema, select.expressions
    )
    projection_extended_expr = _build_extended_expression(
        catalog, schema, invoked_functions_projection, projections
    )

    # Read the WHERE clause attached to this SELECT itself. A tree search
    # (``select.find``) could pick up a WHERE belonging to a nested query and
    # returns None when there is no WHERE at all, which used to crash on
    # ``.this``.
    where = select.args.get("where")
    if where is None:
        return projection_extended_expr, None

    invoked_functions_filter, filter_expr = _substrait_expression_from_sqlglot(
        catalog, schema, where.this
    )
    filter_extended_expr = _build_extended_expression(
        catalog,
        schema,
        invoked_functions_filter,
        [proto.ExpressionReference(expression=filter_expr)],
    )

    return projection_extended_expr, filter_extended_expr


def _build_extended_expression(catalog, schema, invoked_functions, referred_expr):
    """Wrap expression references and the extensions they need in an ExtendedExpression."""
    extension_uris, extensions = catalog.extensions_for_functions(invoked_functions)
    return proto.ExtendedExpression(
        extension_uris=extension_uris,
        extensions=extensions,
        base_schema=schema,
        referred_expr=referred_expr,
    )
52+
53+
54+
def _substrait_projection_from_sqlglot(catalog, schema, expressions):
55+
if not expressions:
56+
return set(), []
57+
58+
# My understanding of ExtendedExpressions is that they are meant to directly
59+
# point to the Expression that ProjectRel would contain, so we don't actually
60+
# need a ProjectRel at all.
61+
"""
62+
projection_sub = proto.ProjectRel(
63+
input=proto.Rel(
64+
read=proto.ReadRel(
65+
named_table=proto.ReadRel.NamedTable(names=["__table__"]),
66+
base_schema=schema,
67+
)
68+
),
69+
expressions=[],
70+
)
71+
"""
72+
73+
substrait_expressions = []
74+
invoked_functions = set()
75+
for sqlexpr in expressions:
76+
output_names = []
77+
if isinstance(sqlexpr, sqlglot.expressions.Alias):
78+
output_names = [sqlexpr.output_name]
79+
sqlexpr = sqlexpr.this
80+
_, substrait_expr = _parse_expression(
81+
catalog, schema, sqlexpr, invoked_functions
82+
)
83+
substrait_expr_reference = proto.ExpressionReference(
84+
expression=substrait_expr, output_names=output_names
85+
)
86+
substrait_expressions.append(substrait_expr_reference)
87+
88+
return invoked_functions, substrait_expressions
89+
90+
91+
def _substrait_expression_from_sqlglot(catalog, schema, sqlglot_node):
92+
if not sqlglot_node:
93+
return set(), None
94+
95+
invoked_functions = set()
96+
_, substrait_expr = _parse_expression(
97+
catalog, schema, sqlglot_node, invoked_functions
98+
)
99+
return invoked_functions, substrait_expr
100+
101+
102+
def _parse_expression(catalog, schema, expr, invoked_functions):
    """Convert a sqlglot expression tree into a Substrait expression.

    Args:
        catalog: Functions catalog used to resolve function anchors.
        schema: ``proto.NamedStruct`` used to resolve column references.
        expr: The sqlglot expression to convert (literal, column or one of
            the binary operators listed in ``SQL_BINARY_FUNCTIONS``).
        invoked_functions: Set mutated in place; the signature of every
            function used by the expression is added to it.

    Returns:
        A ``(type, expression)`` tuple with the ``proto.Type`` of the result
        and the ``proto.Expression`` computing it.

    Raises:
        ValueError: On unsupported literals or expressions, or when a column
            is not part of ``schema``.
    """
    # TODO: Propagate up column names (output_names) so that the projections
    # _always_ have an output_name
    if isinstance(expr, sqlglot.expressions.Literal):
        if expr.is_string:
            return proto.Type(string=proto.Type.String()), proto.Expression(
                literal=proto.Expression.Literal(string=expr.text)
            )
        elif expr.is_int:
            value = int(expr.name)
            if -(2**31) <= value < 2**31:
                return proto.Type(i32=proto.Type.I32()), proto.Expression(
                    literal=proto.Expression.Literal(i32=value)
                )
            # Too large for a 32 bit literal: promote instead of failing.
            return proto.Type(i64=proto.Type.I64()), proto.Expression(
                literal=proto.Expression.Literal(i64=value)
            )
        elif sqlglot.helper.is_float(expr.name):
            # The Substrait literal field for 32 bit floats is named ``fp32``;
            # the message has no ``float`` field.
            return proto.Type(fp32=proto.Type.FP32()), proto.Expression(
                literal=proto.Expression.Literal(fp32=float(expr.name))
            )
        else:
            raise ValueError(f"Unsupported literal: {expr.text}")
    elif isinstance(expr, sqlglot.expressions.Column):
        column_name = expr.output_name
        try:
            schema_field = list(schema.names).index(column_name)
        except ValueError:
            raise ValueError(
                f"Unknown column '{column_name}': not found in schema"
            ) from None
        schema_type = schema.struct.types[schema_field]
        return schema_type, proto.Expression(
            selection=proto.Expression.FieldReference(
                direct_reference=proto.Expression.ReferenceSegment(
                    struct_field=proto.Expression.ReferenceSegment.StructField(
                        field=schema_field
                    )
                )
            )
        )
    elif expr.key in SQL_BINARY_FUNCTIONS:
        left_type, left = _parse_expression(
            catalog, schema, expr.left, invoked_functions
        )
        right_type, right = _parse_expression(
            catalog, schema, expr.right, invoked_functions
        )
        function_name = SQL_BINARY_FUNCTIONS[expr.key]
        signature, result_type, function_expression = _parse_function_invokation(
            function_name, left_type, left, right_type, right, catalog=catalog
        )
        invoked_functions.add(signature)
        return result_type, function_expression
    else:
        raise ValueError(
            f"Unsupported expression in ExtendedExpression: '{expr.key}' -> {expr}"
        )


def _parse_function_invokation(
    function_name, left_type, left, right_type, right, catalog=None
):
    """Build the Substrait scalar function call for a binary operator.

    Args:
        function_name: Substrait name of the function (e.g. ``"add"``).
        left_type: ``proto.Type`` of the left argument.
        left: ``proto.Expression`` of the left argument.
        right_type: ``proto.Type`` of the right argument.
        right: ``proto.Expression`` of the right argument.
        catalog: Functions catalog used to resolve the function anchor. When
            omitted, the module-level ``catalog`` is used, which is what this
            helper always did before the parameter existed.

    Returns:
        A ``(signature, result_type, expression)`` tuple.

    Raises:
        KeyError: If neither the exact signature nor the ``any1_any1``
            variant of the function is registered in the catalog.
    """
    if catalog is None:
        # Backward compatibility for callers that do not pass a catalog.
        catalog = globals()["catalog"]

    signature = f"{function_name}:{left_type.WhichOneof('kind')}_{right_type.WhichOneof('kind')}"
    try:
        function_anchor = catalog.function_anchor(signature)
    except KeyError:
        # no function found with the exact types, try any1_any1 version
        signature = f"{function_name}:any1_any1"
        function_anchor = catalog.function_anchor(signature)
    return (
        signature,
        # NOTE(review): the left argument type is reported as the result type
        # for every function, including comparisons such as "equal" --
        # confirm whether those should report a boolean type instead.
        left_type,
        proto.Expression(
            scalar_function=proto.Expression.ScalarFunction(
                function_reference=function_anchor,
                arguments=[
                    proto.FunctionArgument(value=left),
                    proto.FunctionArgument(value=right),
                ],
            )
        ),
    )
172+
173+
174+
class FunctionsCatalog:
    """Catalog of Substrait extension functions loaded from YAML files.

    Every loaded extension file is registered as a ``SimpleExtensionURI`` and
    every function implementation it declares is registered under a signature
    of the form ``name:argtype_argtype`` (or just ``name`` when it takes no
    arguments) with its own function anchor.
    """

    STANDARD_EXTENSIONS = (
        "/functions_aggregate_approx.yaml",
        "/functions_aggregate_generic.yaml",
        "/functions_arithmetic.yaml",
        "/functions_arithmetic_decimal.yaml",
        "/functions_boolean.yaml",
        "/functions_comparison.yaml",
        "/functions_datetime.yaml",
        "/functions_geometry.yaml",
        "/functions_logarithmic.yaml",
        "/functions_rounding.yaml",
        "/functions_set.yaml",
        "/functions_string.yaml",
    )

    def __init__(self):
        # signature -> name of the extension file that last declared it
        self._declarations = {}
        # extension uri -> proto.SimpleExtensionURI
        self._registered_extensions = {}
        # signature -> proto.SimpleExtensionDeclaration.ExtensionFunction
        self._functions = {}

    def load_standard_extensions(self, dirpath):
        """Load every file listed in ``STANDARD_EXTENSIONS`` from ``dirpath``."""
        for ext in self.STANDARD_EXTENSIONS:
            self.load(dirpath, ext)

    def load(self, dirpath, filename):
        """Load the functions declared by one extension YAML file.

        Args:
            dirpath: Directory containing the extension file.
            filename: Name of the file; it is also used as the extension URI.
                Leading and trailing ``/`` are ignored when building the path.
        """
        path = pathlib.Path(dirpath) / filename.strip("/")
        with open(path, encoding="utf-8") as f:
            # An empty document parses to None.
            sections = yaml.safe_load(f) or {}

        # A dict is used as an ordered set so that function anchors are
        # assigned in declaration order instead of set-iteration order,
        # which changes from run to run for strings.
        loaded_functions = {}
        for functions in sections.values():
            if not isinstance(functions, list):
                # Top-level scalar entries are not lists of declarations.
                continue
            for function in functions:
                if not isinstance(function, dict):
                    continue
                function_name = function["name"]
                for impl in function.get("impls", []):
                    argtypes = [t.get("value", "unknown") for t in impl.get("args", [])]
                    if not argtypes:
                        signature = function_name
                    else:
                        signature = f"{function_name}:{'_'.join(argtypes)}"
                    self._declarations[signature] = filename
                    loaded_functions.setdefault(signature)

        self._register_extensions(filename, list(loaded_functions))

    def _register_extensions(self, extension_uri, loaded_functions):
        """Register ``extension_uri`` and assign anchors to its new functions.

        Signatures that are already registered keep their first registration;
        duplicates are skipped silently.
        """
        if extension_uri not in self._registered_extensions:
            ext_anchor_id = len(self._registered_extensions) + 1
            self._registered_extensions[extension_uri] = proto.SimpleExtensionURI(
                extension_uri_anchor=ext_anchor_id, uri=extension_uri
            )
        extension_anchor = self._registered_extensions[
            extension_uri
        ].extension_uri_anchor

        for function in loaded_functions:
            if function in self._functions:
                continue
            function_anchor = len(self._functions) + 1
            self._functions[function] = (
                proto.SimpleExtensionDeclaration.ExtensionFunction(
                    extension_uri_reference=extension_anchor,
                    name=function,
                    function_anchor=function_anchor,
                )
            )

    @property
    def extension_uris_by_anchor(self):
        """Registered extension URIs keyed by their anchor."""
        return {
            ext.extension_uri_anchor: ext
            for ext in self._registered_extensions.values()
        }

    @property
    def extension_uris(self):
        """List of every registered extension URI."""
        return list(self._registered_extensions.values())

    @property
    def extensions(self):
        """List of every registered extension function."""
        return list(self._functions.values())

    def function_anchor(self, function):
        """Return the anchor of the function registered under the signature.

        Raises:
            KeyError: If no function is registered under ``function``.
        """
        return self._functions[function].function_anchor

    def extensions_for_functions(self, functions):
        """Return the extension URIs and declarations needed by ``functions``.

        Args:
            functions: Iterable of registered function signatures.

        Returns:
            A ``(extension_uris, extensions)`` tuple of lists, both ordered
            by anchor so the result does not depend on the iteration order
            of ``functions``.

        Raises:
            KeyError: If a signature is not registered.
        """
        declared = sorted(
            (self._functions[f] for f in functions),
            key=lambda ext: ext.function_anchor,
        )
        extensions = [
            proto.SimpleExtensionDeclaration(extension_function=ext)
            for ext in declared
        ]

        uris_by_anchor = self.extension_uris_by_anchor
        uris_anchors = {ext.extension_uri_reference for ext in declared}
        extension_uris = [
            uris_by_anchor[uri_anchor] for uri_anchor in sorted(uris_anchors)
        ]
        return extension_uris, extensions
277+
278+
279+
# Demo setup: a catalog with the standard extensions and a three column
# schema used by the example under ``__main__`` below.
# NOTE(review): this runs whenever the module is imported (it reads the
# extension YAML files from disk), and code in this module may look up the
# module-level ``catalog`` name -- confirm nothing depends on it before
# moving this under the ``__main__`` guard.
catalog = FunctionsCatalog()
catalog.load_standard_extensions(
    pathlib.Path(__file__).parent.parent / "third_party" / "substrait" / "extensions",
)
# Schema with two required string columns and one required i32 column.
schema = proto.NamedStruct(
    names=["first_name", "surname", "age"],
    struct=proto.Type.Struct(
        types=[
            proto.Type(
                string=proto.Type.String(
                    nullability=proto.Type.Nullability.NULLABILITY_REQUIRED
                )
            ),
            proto.Type(
                string=proto.Type.String(
                    nullability=proto.Type.Nullability.NULLABILITY_REQUIRED
                )
            ),
            proto.Type(
                i32=proto.Type.I32(
                    nullability=proto.Type.Nullability.NULLABILITY_REQUIRED
                )
            ),
        ]
    ),
)

# Example run: parse one statement and print the projection and filter
# ExtendedExpressions produced for it.
if __name__ == '__main__':
    sql = "SELECT surname, age + 1 as next_birthday WHERE age = 32"
    projection_expr, filter_expr = parse_sql_extended_expression(catalog, schema, sql)
    print("---- SQL INPUT ----")
    print(sql)
    print("---- PROJECTION ----")
    print(projection_expr)
    print("---- FILTER ----")
    print(filter_expr)
    # parse_extended_expression("INSERT INTO table VALUES(1, 2, 3)")

0 commit comments

Comments
 (0)