1
+ import os
1
2
import pathlib
3
+ from collections .abc import Iterable
2
4
3
5
import yaml
4
6
5
- from substrait import proto
7
+ from substrait .gen .proto .type_pb2 import Type as SubstraitType
8
+ from substrait .gen .proto .extensions .extensions_pb2 import (
9
+ SimpleExtensionURI ,
10
+ SimpleExtensionDeclaration ,
11
+ )
12
+
13
+
14
+ class RegisteredSubstraitFunction :
15
+ """A Substrait function loaded from an extension file.
16
+
17
+ The FunctionsCatalog will keep a collection of RegisteredSubstraitFunction
18
+ and will use them to generate the necessary extension URIs and extensions.
19
+ """
20
+
21
+ def __init__ (self , signature : str , function_anchor : int | None , impl : dict ):
22
+ self .signature = signature
23
+ self .function_anchor = function_anchor
24
+ self .variadic = impl .get ("variadic" , False )
25
+
26
+ if "return" in impl :
27
+ self .return_type = self ._type_from_name (impl ["return" ])
28
+ else :
29
+ # We do always need a return type
30
+ # to know which type to propagate up to the invoker
31
+ _ , argtypes = FunctionsCatalog .parse_signature (signature )
32
+ # TODO: Is this the right way to handle this?
33
+ self .return_type = self ._type_from_name (argtypes [0 ])
34
+
35
+ @property
36
+ def name (self ) -> str :
37
+ name , _ = FunctionsCatalog .parse_signature (self .signature )
38
+ return name
39
+
40
+ @property
41
+ def arguments (self ) -> list [str ]:
42
+ _ , argtypes = FunctionsCatalog .parse_signature (self .signature )
43
+ return argtypes
44
+
45
+ @property
46
+ def arguments_type (self ) -> list [SubstraitType | None ]:
47
+ return [self ._type_from_name (arg ) for arg in self .arguments ]
48
+
49
+ def _type_from_name (self , typename : str ) -> SubstraitType | None :
50
+ nullable = False
51
+ if typename .endswith ("?" ):
52
+ nullable = True
53
+
54
+ typename = typename .strip ("?" )
55
+ if typename in ("any" , "any1" ):
56
+ return None
57
+
58
+ if typename == "boolean" :
59
+ # For some reason boolean is an exception to the naming convention
60
+ typename = "bool"
61
+
62
+ try :
63
+ type_descriptor = SubstraitType .DESCRIPTOR .fields_by_name [
64
+ typename
65
+ ].message_type
66
+ except KeyError :
67
+ # TODO: improve resolution of complext type like LIST?<any>
68
+ print ("Unsupported type" , typename )
69
+ return None
70
+
71
+ type_class = getattr (SubstraitType , type_descriptor .name )
72
+ nullability = (
73
+ SubstraitType .Nullability .NULLABILITY_REQUIRED
74
+ if not nullable
75
+ else SubstraitType .Nullability .NULLABILITY_NULLABLE
76
+ )
77
+ return SubstraitType (** {typename : type_class (nullability = nullability )})
6
78
7
79
8
80
class FunctionsCatalog :
@@ -32,20 +104,21 @@ class FunctionsCatalog:
32
104
)
33
105
34
106
def __init__ (self ):
35
- self ._registered_extensions = {}
107
+ self ._substrait_extension_uris = {}
108
+ self ._substrait_extension_functions = {}
36
109
self ._functions = {}
37
- self ._functions_return_type = {}
38
110
39
- def load_standard_extensions (self , dirpath ):
111
+ def load_standard_extensions (self , dirpath : str | os .PathLike ):
112
+ """Load all standard substrait extensions from the target directory."""
40
113
for ext in self .STANDARD_EXTENSIONS :
41
114
self .load (dirpath , ext )
42
115
43
- def load (self , dirpath , filename ):
116
+ def load (self , dirpath : str | os .PathLike , filename : str ):
117
+ """Load an extension from a YAML file in a target directory."""
44
118
with open (pathlib .Path (dirpath ) / filename .strip ("/" )) as f :
45
119
sections = yaml .safe_load (f )
46
120
47
- loaded_functions = set ()
48
- functions_return_type = {}
121
+ loaded_functions = {}
49
122
for functions in sections .values ():
50
123
for function in functions :
51
124
function_name = function ["name" ]
@@ -56,100 +129,80 @@ def load(self, dirpath, filename):
56
129
t .get ("value" , "unknown" ).strip ("?" )
57
130
for t in impl .get ("args" , [])
58
131
]
59
- if impl .get ("variadic" , False ):
60
- # TODO: Variadic functions.
61
- argtypes *= 2
62
-
63
132
if not argtypes :
64
133
signature = function_name
65
134
else :
66
135
signature = f"{ function_name } :{ '_' .join (argtypes )} "
67
- loaded_functions .add (signature )
68
- print ("Loaded function" , signature )
69
- functions_return_type [signature ] = self ._type_from_name (
70
- impl ["return" ]
136
+ loaded_functions [signature ] = RegisteredSubstraitFunction (
137
+ signature , None , impl
71
138
)
72
139
73
- self ._register_extensions (filename , loaded_functions , functions_return_type )
140
+ self ._register_extensions (filename , loaded_functions )
74
141
75
142
def _register_extensions (
76
- self , extension_uri , loaded_functions , functions_return_type
143
+ self ,
144
+ extension_uri : str ,
145
+ loaded_functions : dict [str , RegisteredSubstraitFunction ],
77
146
):
78
- if extension_uri not in self ._registered_extensions :
79
- ext_anchor_id = len (self ._registered_extensions ) + 1
80
- self ._registered_extensions [extension_uri ] = proto . SimpleExtensionURI (
147
+ if extension_uri not in self ._substrait_extension_uris :
148
+ ext_anchor_id = len (self ._substrait_extension_uris ) + 1
149
+ self ._substrait_extension_uris [extension_uri ] = SimpleExtensionURI (
81
150
extension_uri_anchor = ext_anchor_id , uri = extension_uri
82
151
)
83
152
84
- for function in loaded_functions :
85
- if function in self ._functions :
153
+ for signature , registered_function in loaded_functions . items () :
154
+ if signature in self ._substrait_extension_functions :
86
155
extensions_by_anchor = self .extension_uris_by_anchor
87
- existing_function = self ._functions [ function ]
156
+ existing_function = self ._substrait_extension_functions [ signature ]
88
157
function_extension = extensions_by_anchor [
89
158
existing_function .extension_uri_reference
90
159
].uri
91
160
raise ValueError (
92
161
f"Duplicate function definition: { existing_function .name } from { extension_uri } , already loaded from { function_extension } "
93
162
)
94
- extension_anchor = self ._registered_extensions [
163
+ extension_anchor = self ._substrait_extension_uris [
95
164
extension_uri
96
165
].extension_uri_anchor
97
- function_anchor = len (self ._functions ) + 1
98
- self ._functions [ function ] = (
99
- proto . SimpleExtensionDeclaration .ExtensionFunction (
166
+ function_anchor = len (self ._substrait_extension_functions ) + 1
167
+ self ._substrait_extension_functions [ signature ] = (
168
+ SimpleExtensionDeclaration .ExtensionFunction (
100
169
extension_uri_reference = extension_anchor ,
101
- name = function ,
170
+ name = signature ,
102
171
function_anchor = function_anchor ,
103
172
)
104
173
)
105
- self ._functions_return_type [function ] = functions_return_type [function ]
106
-
107
- def _type_from_name (self , typename ):
108
- nullable = False
109
- if typename .endswith ("?" ):
110
- nullable = True
111
-
112
- typename = typename .strip ("?" )
113
- if typename in ("any" , "any1" ):
114
- return None
115
-
116
- if typename == "boolean" :
117
- # For some reason boolean is an exception to the naming convention
118
- typename = "bool"
119
-
120
- try :
121
- type_descriptor = proto .Type .DESCRIPTOR .fields_by_name [
122
- typename
123
- ].message_type
124
- except KeyError :
125
- # TODO: improve resolution of complext type like LIST?<any>
126
- print ("Unsupported type" , typename )
127
- return None
128
-
129
- type_class = getattr (proto .Type , type_descriptor .name )
130
- nullability = (
131
- proto .Type .Nullability .NULLABILITY_REQUIRED
132
- if not nullable
133
- else proto .Type .Nullability .NULLABILITY_NULLABLE
134
- )
135
- return proto .Type (** {typename : type_class (nullability = nullability )})
174
+ registered_function .function_anchor = function_anchor
175
+ self ._functions .setdefault (registered_function .name , []).append (
176
+ registered_function
177
+ )
136
178
137
179
@property
138
- def extension_uris_by_anchor (self ):
180
+ def extension_uris_by_anchor (self ) -> dict [ int , SimpleExtensionURI ] :
139
181
return {
140
182
ext .extension_uri_anchor : ext
141
- for ext in self ._registered_extensions .values ()
183
+ for ext in self ._substrait_extension_uris .values ()
142
184
}
143
185
144
186
@property
145
- def extension_uris (self ):
146
- return list (self ._registered_extensions .values ())
187
+ def extension_uris (self ) -> list [ SimpleExtensionURI ] :
188
+ return list (self ._substrait_extension_uris .values ())
147
189
148
190
@property
149
- def extensions (self ):
150
- return list (self ._functions .values ())
191
+ def extensions_functions (
192
+ self ,
193
+ ) -> list [SimpleExtensionDeclaration .ExtensionFunction ]:
194
+ return list (self ._substrait_extension_functions .values ())
195
+
196
+ @classmethod
197
+ def make_signature (
198
+ cls , function_name : str , proto_argtypes : Iterable [SubstraitType ]
199
+ ):
200
+ """Create a function signature from a function name and substrait types.
201
+
202
+ The signature is generated according to Function Signature Compound Names
203
+ as described in the Substrait documentation.
204
+ """
151
205
152
- def signature (self , function_name , proto_argtypes ):
153
206
def _normalize_arg_types (argtypes ):
154
207
for argtype in argtypes :
155
208
kind = argtype .WhichOneof ("kind" )
@@ -160,23 +213,73 @@ def _normalize_arg_types(argtypes):
160
213
161
214
return f"{ function_name } :{ '_' .join (_normalize_arg_types (proto_argtypes ))} "
162
215
163
- def function_anchor (self , function ):
164
- return self ._functions [function ].function_anchor
216
+ @classmethod
217
+ def parse_signature (cls , signature : str ) -> tuple [str , list [str ]]:
218
+ """Parse a function signature and returns name and type names"""
219
+ try :
220
+ function_name , signature_args = signature .split (":" )
221
+ except ValueError :
222
+ function_name = signature
223
+ argtypes = []
224
+ else :
225
+ argtypes = signature_args .split ("_" )
226
+ return function_name , argtypes
165
227
166
- def function_return_type (self , function ):
167
- return self ._functions_return_type [function ]
228
+ def extensions_for_functions (
229
+ self , function_signatures : Iterable [str ]
230
+ ) -> tuple [list [SimpleExtensionURI ], list [SimpleExtensionDeclaration ]]:
231
+ """Given a set of function signatures, return the necessary extensions.
168
232
169
- def extensions_for_functions (self , functions ):
233
+ The function will return the URIs of the extensions and the extension
234
+ that have to be declared in the plan to use the functions.
235
+ """
170
236
uris_anchors = set ()
171
237
extensions = []
172
- for f in functions :
173
- ext = self ._functions [f ]
174
- if not ext .extension_uri_reference :
175
- # Built-in function
176
- continue
238
+ for f in function_signatures :
239
+ ext = self ._substrait_extension_functions [f ]
177
240
uris_anchors .add (ext .extension_uri_reference )
178
- extensions .append (proto . SimpleExtensionDeclaration (extension_function = ext ))
241
+ extensions .append (SimpleExtensionDeclaration (extension_function = ext ))
179
242
180
243
uris_by_anchor = self .extension_uris_by_anchor
181
244
extension_uris = [uris_by_anchor [uri_anchor ] for uri_anchor in uris_anchors ]
182
245
return extension_uris , extensions
246
+
247
+ def lookup_function (self , signature : str ) -> RegisteredSubstraitFunction | None :
248
+ """Given the signature of a function invocation, return the matching function."""
249
+ function_name , invocation_argtypes = self .parse_signature (signature )
250
+
251
+ functions = self ._functions .get (function_name )
252
+ if not functions :
253
+ # No function with such a name at all.
254
+ return None
255
+
256
+ is_variadic = functions [0 ].variadic
257
+ if is_variadic :
258
+ # If it's variadic we care about only the first parameter.
259
+ invocation_argtypes = invocation_argtypes [:1 ]
260
+
261
+ found_function = None
262
+ for function in functions :
263
+ accepted_function_arguments = function .arguments
264
+ for argidx , argtype in enumerate (invocation_argtypes ):
265
+ try :
266
+ accepted_argument = accepted_function_arguments [argidx ]
267
+ except IndexError :
268
+ # More arguments than available were provided
269
+ break
270
+ if accepted_argument != argtype and accepted_argument not in (
271
+ "any" ,
272
+ "any1" ,
273
+ ):
274
+ break
275
+ else :
276
+ if argidx < len (accepted_function_arguments ) - 1 :
277
+ # Not enough arguments were provided
278
+ remainder = accepted_function_arguments [argidx + 1 :]
279
+ if all (arg .endswith ("?" ) for arg in remainder ):
280
+ # All remaining arguments are optional
281
+ found_function = function
282
+ else :
283
+ found_function = function
284
+
285
+ return found_function
0 commit comments