Skip to content

Commit 3c9f6eb

Browse files
committed
Refactor FunctionsCatalog and improve functions lookup
1 parent 1a5fcc7 commit 3c9f6eb

File tree

3 files changed

+194
-96
lines changed

3 files changed

+194
-96
lines changed

src/substrait/sql/extended_expression.py

+11-18
Original file line numberDiff line numberDiff line change
@@ -194,30 +194,20 @@ def _parse_function_invokation(
194194
invokation expression itself.
195195
"""
196196
arguments = [argument_parsed_expr] + list(additional_arguments)
197-
signature = self._functions_catalog.signature(
197+
signature = self._functions_catalog.make_signature(
198198
function_name, proto_argtypes=[arg.type for arg in arguments]
199199
)
200200

201-
try:
202-
function_anchor = self._functions_catalog.function_anchor(signature)
203-
except KeyError:
204-
# No function found with the exact types, try any1_any1 version
205-
# TODO: What about cases like i32_any1? What about any instead of any1?
206-
# TODO: What about optional arguments? IE: "i32_i32?"
207-
signature = f"{function_name}:{'_'.join(['any1']*len(arguments))}"
208-
function_anchor = self._functions_catalog.function_anchor(signature)
209-
210-
function_return_type = self._functions_catalog.function_return_type(signature)
211-
if function_return_type is None:
212-
print("No return type for", signature)
213-
# TODO: Is this the right way to handle this?
214-
function_return_type = left_type
201+
registered_function = self._functions_catalog.lookup_function(signature)
202+
if registered_function is None:
203+
raise KeyError(f"Function not found: {signature}")
204+
215205
return (
216-
signature,
217-
function_return_type,
206+
registered_function.signature,
207+
registered_function.return_type,
218208
proto.Expression(
219209
scalar_function=proto.Expression.ScalarFunction(
220-
function_reference=function_anchor,
210+
function_reference=registered_function.function_anchor,
221211
arguments=[
222212
proto.FunctionArgument(value=arg.expression)
223213
for arg in arguments
@@ -255,3 +245,6 @@ def duplicate(
255245
expression or self.expression,
256246
invoked_functions or self.invoked_functions,
257247
)
248+
249+
def __repr__(self):
250+
return f"<ParsedSubstraitExpression {self.output_name} {self.type}>"
+180-77
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,80 @@
1+
import os
12
import pathlib
3+
from collections.abc import Iterable
24

35
import yaml
46

5-
from substrait import proto
7+
from substrait.gen.proto.type_pb2 import Type as SubstraitType
8+
from substrait.gen.proto.extensions.extensions_pb2 import (
9+
SimpleExtensionURI,
10+
SimpleExtensionDeclaration,
11+
)
12+
13+
14+
class RegisteredSubstraitFunction:
15+
"""A Substrait function loaded from an extension file.
16+
17+
The FunctionsCatalog will keep a collection of RegisteredSubstraitFunction
18+
and will use them to generate the necessary extension URIs and extensions.
19+
"""
20+
21+
def __init__(self, signature: str, function_anchor: int | None, impl: dict):
22+
self.signature = signature
23+
self.function_anchor = function_anchor
24+
self.variadic = impl.get("variadic", False)
25+
26+
if "return" in impl:
27+
self.return_type = self._type_from_name(impl["return"])
28+
else:
29+
# We do always need a return type
30+
# to know which type to propagate up to the invoker
31+
_, argtypes = FunctionsCatalog.parse_signature(signature)
32+
# TODO: Is this the right way to handle this?
33+
self.return_type = self._type_from_name(argtypes[0])
34+
35+
@property
36+
def name(self) -> str:
37+
name, _ = FunctionsCatalog.parse_signature(self.signature)
38+
return name
39+
40+
@property
41+
def arguments(self) -> list[str]:
42+
_, argtypes = FunctionsCatalog.parse_signature(self.signature)
43+
return argtypes
44+
45+
@property
46+
def arguments_type(self) -> list[SubstraitType | None]:
47+
return [self._type_from_name(arg) for arg in self.arguments]
48+
49+
def _type_from_name(self, typename: str) -> SubstraitType | None:
50+
nullable = False
51+
if typename.endswith("?"):
52+
nullable = True
53+
54+
typename = typename.strip("?")
55+
if typename in ("any", "any1"):
56+
return None
57+
58+
if typename == "boolean":
59+
# For some reason boolean is an exception to the naming convention
60+
typename = "bool"
61+
62+
try:
63+
type_descriptor = SubstraitType.DESCRIPTOR.fields_by_name[
64+
typename
65+
].message_type
66+
except KeyError:
67+
# TODO: improve resolution of complext type like LIST?<any>
68+
print("Unsupported type", typename)
69+
return None
70+
71+
type_class = getattr(SubstraitType, type_descriptor.name)
72+
nullability = (
73+
SubstraitType.Nullability.NULLABILITY_REQUIRED
74+
if not nullable
75+
else SubstraitType.Nullability.NULLABILITY_NULLABLE
76+
)
77+
return SubstraitType(**{typename: type_class(nullability=nullability)})
678

779

880
class FunctionsCatalog:
@@ -32,20 +104,21 @@ class FunctionsCatalog:
32104
)
33105

34106
def __init__(self):
35-
self._registered_extensions = {}
107+
self._substrait_extension_uris = {}
108+
self._substrait_extension_functions = {}
36109
self._functions = {}
37-
self._functions_return_type = {}
38110

39-
def load_standard_extensions(self, dirpath):
111+
def load_standard_extensions(self, dirpath: str | os.PathLike):
112+
"""Load all standard substrait extensions from the target directory."""
40113
for ext in self.STANDARD_EXTENSIONS:
41114
self.load(dirpath, ext)
42115

43-
def load(self, dirpath, filename):
116+
def load(self, dirpath: str | os.PathLike, filename: str):
117+
"""Load an extension from a YAML file in a target directory."""
44118
with open(pathlib.Path(dirpath) / filename.strip("/")) as f:
45119
sections = yaml.safe_load(f)
46120

47-
loaded_functions = set()
48-
functions_return_type = {}
121+
loaded_functions = {}
49122
for functions in sections.values():
50123
for function in functions:
51124
function_name = function["name"]
@@ -56,100 +129,80 @@ def load(self, dirpath, filename):
56129
t.get("value", "unknown").strip("?")
57130
for t in impl.get("args", [])
58131
]
59-
if impl.get("variadic", False):
60-
# TODO: Variadic functions.
61-
argtypes *= 2
62-
63132
if not argtypes:
64133
signature = function_name
65134
else:
66135
signature = f"{function_name}:{'_'.join(argtypes)}"
67-
loaded_functions.add(signature)
68-
print("Loaded function", signature)
69-
functions_return_type[signature] = self._type_from_name(
70-
impl["return"]
136+
loaded_functions[signature] = RegisteredSubstraitFunction(
137+
signature, None, impl
71138
)
72139

73-
self._register_extensions(filename, loaded_functions, functions_return_type)
140+
self._register_extensions(filename, loaded_functions)
74141

75142
def _register_extensions(
76-
self, extension_uri, loaded_functions, functions_return_type
143+
self,
144+
extension_uri: str,
145+
loaded_functions: dict[str, RegisteredSubstraitFunction],
77146
):
78-
if extension_uri not in self._registered_extensions:
79-
ext_anchor_id = len(self._registered_extensions) + 1
80-
self._registered_extensions[extension_uri] = proto.SimpleExtensionURI(
147+
if extension_uri not in self._substrait_extension_uris:
148+
ext_anchor_id = len(self._substrait_extension_uris) + 1
149+
self._substrait_extension_uris[extension_uri] = SimpleExtensionURI(
81150
extension_uri_anchor=ext_anchor_id, uri=extension_uri
82151
)
83152

84-
for function in loaded_functions:
85-
if function in self._functions:
153+
for signature, registered_function in loaded_functions.items():
154+
if signature in self._substrait_extension_functions:
86155
extensions_by_anchor = self.extension_uris_by_anchor
87-
existing_function = self._functions[function]
156+
existing_function = self._substrait_extension_functions[signature]
88157
function_extension = extensions_by_anchor[
89158
existing_function.extension_uri_reference
90159
].uri
91160
raise ValueError(
92161
f"Duplicate function definition: {existing_function.name} from {extension_uri}, already loaded from {function_extension}"
93162
)
94-
extension_anchor = self._registered_extensions[
163+
extension_anchor = self._substrait_extension_uris[
95164
extension_uri
96165
].extension_uri_anchor
97-
function_anchor = len(self._functions) + 1
98-
self._functions[function] = (
99-
proto.SimpleExtensionDeclaration.ExtensionFunction(
166+
function_anchor = len(self._substrait_extension_functions) + 1
167+
self._substrait_extension_functions[signature] = (
168+
SimpleExtensionDeclaration.ExtensionFunction(
100169
extension_uri_reference=extension_anchor,
101-
name=function,
170+
name=signature,
102171
function_anchor=function_anchor,
103172
)
104173
)
105-
self._functions_return_type[function] = functions_return_type[function]
106-
107-
def _type_from_name(self, typename):
108-
nullable = False
109-
if typename.endswith("?"):
110-
nullable = True
111-
112-
typename = typename.strip("?")
113-
if typename in ("any", "any1"):
114-
return None
115-
116-
if typename == "boolean":
117-
# For some reason boolean is an exception to the naming convention
118-
typename = "bool"
119-
120-
try:
121-
type_descriptor = proto.Type.DESCRIPTOR.fields_by_name[
122-
typename
123-
].message_type
124-
except KeyError:
125-
# TODO: improve resolution of complext type like LIST?<any>
126-
print("Unsupported type", typename)
127-
return None
128-
129-
type_class = getattr(proto.Type, type_descriptor.name)
130-
nullability = (
131-
proto.Type.Nullability.NULLABILITY_REQUIRED
132-
if not nullable
133-
else proto.Type.Nullability.NULLABILITY_NULLABLE
134-
)
135-
return proto.Type(**{typename: type_class(nullability=nullability)})
174+
registered_function.function_anchor = function_anchor
175+
self._functions.setdefault(registered_function.name, []).append(
176+
registered_function
177+
)
136178

137179
@property
138-
def extension_uris_by_anchor(self):
180+
def extension_uris_by_anchor(self) -> dict[int, SimpleExtensionURI]:
139181
return {
140182
ext.extension_uri_anchor: ext
141-
for ext in self._registered_extensions.values()
183+
for ext in self._substrait_extension_uris.values()
142184
}
143185

144186
@property
145-
def extension_uris(self):
146-
return list(self._registered_extensions.values())
187+
def extension_uris(self) -> list[SimpleExtensionURI]:
188+
return list(self._substrait_extension_uris.values())
147189

148190
@property
149-
def extensions(self):
150-
return list(self._functions.values())
191+
def extensions_functions(
192+
self,
193+
) -> list[SimpleExtensionDeclaration.ExtensionFunction]:
194+
return list(self._substrait_extension_functions.values())
195+
196+
@classmethod
197+
def make_signature(
198+
cls, function_name: str, proto_argtypes: Iterable[SubstraitType]
199+
):
200+
"""Create a function signature from a function name and substrait types.
201+
202+
The signature is generated according to Function Signature Compound Names
203+
as described in the Substrait documentation.
204+
"""
151205

152-
def signature(self, function_name, proto_argtypes):
153206
def _normalize_arg_types(argtypes):
154207
for argtype in argtypes:
155208
kind = argtype.WhichOneof("kind")
@@ -160,23 +213,73 @@ def _normalize_arg_types(argtypes):
160213

161214
return f"{function_name}:{'_'.join(_normalize_arg_types(proto_argtypes))}"
162215

163-
def function_anchor(self, function):
164-
return self._functions[function].function_anchor
216+
@classmethod
217+
def parse_signature(cls, signature: str) -> tuple[str, list[str]]:
218+
"""Parse a function signature and returns name and type names"""
219+
try:
220+
function_name, signature_args = signature.split(":")
221+
except ValueError:
222+
function_name = signature
223+
argtypes = []
224+
else:
225+
argtypes = signature_args.split("_")
226+
return function_name, argtypes
165227

166-
def function_return_type(self, function):
167-
return self._functions_return_type[function]
228+
def extensions_for_functions(
229+
self, function_signatures: Iterable[str]
230+
) -> tuple[list[SimpleExtensionURI], list[SimpleExtensionDeclaration]]:
231+
"""Given a set of function signatures, return the necessary extensions.
168232
169-
def extensions_for_functions(self, functions):
233+
The function will return the URIs of the extensions and the extension
234+
that have to be declared in the plan to use the functions.
235+
"""
170236
uris_anchors = set()
171237
extensions = []
172-
for f in functions:
173-
ext = self._functions[f]
174-
if not ext.extension_uri_reference:
175-
# Built-in function
176-
continue
238+
for f in function_signatures:
239+
ext = self._substrait_extension_functions[f]
177240
uris_anchors.add(ext.extension_uri_reference)
178-
extensions.append(proto.SimpleExtensionDeclaration(extension_function=ext))
241+
extensions.append(SimpleExtensionDeclaration(extension_function=ext))
179242

180243
uris_by_anchor = self.extension_uris_by_anchor
181244
extension_uris = [uris_by_anchor[uri_anchor] for uri_anchor in uris_anchors]
182245
return extension_uris, extensions
246+
247+
def lookup_function(self, signature: str) -> RegisteredSubstraitFunction | None:
248+
"""Given the signature of a function invocation, return the matching function."""
249+
function_name, invocation_argtypes = self.parse_signature(signature)
250+
251+
functions = self._functions.get(function_name)
252+
if not functions:
253+
# No function with such a name at all.
254+
return None
255+
256+
is_variadic = functions[0].variadic
257+
if is_variadic:
258+
# If it's variadic we care about only the first parameter.
259+
invocation_argtypes = invocation_argtypes[:1]
260+
261+
found_function = None
262+
for function in functions:
263+
accepted_function_arguments = function.arguments
264+
for argidx, argtype in enumerate(invocation_argtypes):
265+
try:
266+
accepted_argument = accepted_function_arguments[argidx]
267+
except IndexError:
268+
# More arguments than available were provided
269+
break
270+
if accepted_argument != argtype and accepted_argument not in (
271+
"any",
272+
"any1",
273+
):
274+
break
275+
else:
276+
if argidx < len(accepted_function_arguments) - 1:
277+
# Not enough arguments were provided
278+
remainder = accepted_function_arguments[argidx + 1 :]
279+
if all(arg.endswith("?") for arg in remainder):
280+
# All remaining arguments are optional
281+
found_function = function
282+
else:
283+
found_function = function
284+
285+
return found_function

src/substrait/sql/utils.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ def __getitem__(self, argument):
3232
if isinstance(argument, dispatch_cls):
3333
return func
3434
else:
35-
raise ValueError(f"Unsupported SQL Node type: {cls}")
35+
raise ValueError(
36+
f"Unsupported SQL Node type: {argument.__class__.__name__} -> {argument}"
37+
)
3638

3739
def __call__(self, obj, dispatch_argument, *args, **kwargs):
3840
return self[dispatch_argument](obj, dispatch_argument, *args, **kwargs)

0 commit comments

Comments
 (0)