|
| 1 | +import pathlib |
| 2 | + |
| 3 | +import sqlglot |
| 4 | +import yaml |
| 5 | + |
| 6 | +from substrait import proto |
| 7 | + |
| 8 | + |
| 9 | +SQL_BINARY_FUNCTIONS = { |
| 10 | + # Arithmetic |
| 11 | + "add": "add", |
| 12 | + "div": "div", |
| 13 | + "mul": "mul", |
| 14 | + "sub": "sub", |
| 15 | + # Comparisons |
| 16 | + "eq": "equal", |
| 17 | +} |
| 18 | + |
| 19 | + |
| 20 | +def parse_sql_extended_expression(catalog, schema, sql): |
| 21 | + select = sqlglot.parse_one(sql) |
| 22 | + if not isinstance(select, sqlglot.expressions.Select): |
| 23 | + raise ValueError("a SELECT statement was expected") |
| 24 | + |
| 25 | + invoked_functions_projection, projections = _substrait_projection_from_sqlglot( |
| 26 | + catalog, schema, select.expressions |
| 27 | + ) |
| 28 | + extension_uris, extensions = catalog.extensions_for_functions( |
| 29 | + invoked_functions_projection |
| 30 | + ) |
| 31 | + projection_extended_expr = proto.ExtendedExpression( |
| 32 | + extension_uris=extension_uris, |
| 33 | + extensions=extensions, |
| 34 | + base_schema=schema, |
| 35 | + referred_expr=projections, |
| 36 | + ) |
| 37 | + |
| 38 | + invoked_functions_filter, filter_expr = _substrait_expression_from_sqlglot( |
| 39 | + catalog, schema, select.find(sqlglot.expressions.Where).this |
| 40 | + ) |
| 41 | + extension_uris, extensions = catalog.extensions_for_functions( |
| 42 | + invoked_functions_filter |
| 43 | + ) |
| 44 | + filter_extended_expr = proto.ExtendedExpression( |
| 45 | + extension_uris=extension_uris, |
| 46 | + extensions=extensions, |
| 47 | + base_schema=schema, |
| 48 | + referred_expr=[proto.ExpressionReference(expression=filter_expr)], |
| 49 | + ) |
| 50 | + |
| 51 | + return projection_extended_expr, filter_extended_expr |
| 52 | + |
| 53 | + |
| 54 | +def _substrait_projection_from_sqlglot(catalog, schema, expressions): |
| 55 | + if not expressions: |
| 56 | + return set(), [] |
| 57 | + |
| 58 | + # My understanding of ExtendedExpressions is that they are meant to directly |
| 59 | + # point to the Expression that ProjectRel would contain, so we don't actually |
| 60 | + # need a ProjectRel at all. |
| 61 | + """ |
| 62 | + projection_sub = proto.ProjectRel( |
| 63 | + input=proto.Rel( |
| 64 | + read=proto.ReadRel( |
| 65 | + named_table=proto.ReadRel.NamedTable(names=["__table__"]), |
| 66 | + base_schema=schema, |
| 67 | + ) |
| 68 | + ), |
| 69 | + expressions=[], |
| 70 | + ) |
| 71 | + """ |
| 72 | + |
| 73 | + substrait_expressions = [] |
| 74 | + invoked_functions = set() |
| 75 | + for sqlexpr in expressions: |
| 76 | + output_names = [] |
| 77 | + if isinstance(sqlexpr, sqlglot.expressions.Alias): |
| 78 | + output_names = [sqlexpr.output_name] |
| 79 | + sqlexpr = sqlexpr.this |
| 80 | + _, substrait_expr = _parse_expression( |
| 81 | + catalog, schema, sqlexpr, invoked_functions |
| 82 | + ) |
| 83 | + substrait_expr_reference = proto.ExpressionReference( |
| 84 | + expression=substrait_expr, output_names=output_names |
| 85 | + ) |
| 86 | + substrait_expressions.append(substrait_expr_reference) |
| 87 | + |
| 88 | + return invoked_functions, substrait_expressions |
| 89 | + |
| 90 | + |
| 91 | +def _substrait_expression_from_sqlglot(catalog, schema, sqlglot_node): |
| 92 | + if not sqlglot_node: |
| 93 | + return set(), None |
| 94 | + |
| 95 | + invoked_functions = set() |
| 96 | + _, substrait_expr = _parse_expression( |
| 97 | + catalog, schema, sqlglot_node, invoked_functions |
| 98 | + ) |
| 99 | + return invoked_functions, substrait_expr |
| 100 | + |
| 101 | + |
| 102 | +def _parse_expression(catalog, schema, expr, invoked_functions): |
| 103 | + # TODO: Propagate up column names (output_names) so that the projections _always_ have an output_name |
| 104 | + if isinstance(expr, sqlglot.expressions.Literal): |
| 105 | + if expr.is_string: |
| 106 | + return proto.Type(string=proto.Type.String()), proto.Expression( |
| 107 | + literal=proto.Expression.Literal(string=expr.text) |
| 108 | + ) |
| 109 | + elif expr.is_int: |
| 110 | + return proto.Type(i32=proto.Type.I32()), proto.Expression( |
| 111 | + literal=proto.Expression.Literal(i32=int(expr.name)) |
| 112 | + ) |
| 113 | + elif sqlglot.helper.is_float(expr.name): |
| 114 | + return proto.Type(fp32=proto.Type.FP32()), proto.Expression( |
| 115 | + literal=proto.Expression.Literal(float=float(expr.name)) |
| 116 | + ) |
| 117 | + else: |
| 118 | + raise ValueError(f"Unsupporter literal: {expr.text}") |
| 119 | + elif isinstance(expr, sqlglot.expressions.Column): |
| 120 | + column_name = expr.output_name |
| 121 | + schema_field = list(schema.names).index(column_name) |
| 122 | + schema_type = schema.struct.types[schema_field] |
| 123 | + return schema_type, proto.Expression( |
| 124 | + selection=proto.Expression.FieldReference( |
| 125 | + direct_reference=proto.Expression.ReferenceSegment( |
| 126 | + struct_field=proto.Expression.ReferenceSegment.StructField( |
| 127 | + field=schema_field |
| 128 | + ) |
| 129 | + ) |
| 130 | + ) |
| 131 | + ) |
| 132 | + elif expr.key in SQL_BINARY_FUNCTIONS: |
| 133 | + left_type, left = _parse_expression( |
| 134 | + catalog, schema, expr.left, invoked_functions |
| 135 | + ) |
| 136 | + right_type, right = _parse_expression( |
| 137 | + catalog, schema, expr.right, invoked_functions |
| 138 | + ) |
| 139 | + function_name = SQL_BINARY_FUNCTIONS[expr.key] |
| 140 | + signature, result_type, function_expression = _parse_function_invokation( |
| 141 | + function_name, left_type, left, right_type, right |
| 142 | + ) |
| 143 | + invoked_functions.add(signature) |
| 144 | + return result_type, function_expression |
| 145 | + else: |
| 146 | + raise ValueError( |
| 147 | + f"Unsupported expression in ExtendedExpression: '{expr.key}' -> {expr}" |
| 148 | + ) |
| 149 | + |
| 150 | + |
| 151 | +def _parse_function_invokation(function_name, left_type, left, right_type, right): |
| 152 | + signature = f"{function_name}:{left_type.WhichOneof('kind')}_{right_type.WhichOneof('kind')}" |
| 153 | + try: |
| 154 | + function_anchor = catalog.function_anchor(signature) |
| 155 | + except KeyError: |
| 156 | + # not function found with the exact types, try any1_any1 version |
| 157 | + signature = f"{function_name}:any1_any1" |
| 158 | + function_anchor = catalog.function_anchor(signature) |
| 159 | + return ( |
| 160 | + signature, |
| 161 | + left_type, |
| 162 | + proto.Expression( |
| 163 | + scalar_function=proto.Expression.ScalarFunction( |
| 164 | + function_reference=function_anchor, |
| 165 | + arguments=[ |
| 166 | + proto.FunctionArgument(value=left), |
| 167 | + proto.FunctionArgument(value=right), |
| 168 | + ], |
| 169 | + ) |
| 170 | + ), |
| 171 | + ) |
| 172 | + |
| 173 | + |
| 174 | +class FunctionsCatalog: |
| 175 | + STANDARD_EXTENSIONS = ( |
| 176 | + "/functions_aggregate_approx.yaml", |
| 177 | + "/functions_aggregate_generic.yaml", |
| 178 | + "/functions_arithmetic.yaml", |
| 179 | + "/functions_arithmetic_decimal.yaml", |
| 180 | + "/functions_boolean.yaml", |
| 181 | + "/functions_comparison.yaml", |
| 182 | + "/functions_datetime.yaml", |
| 183 | + "/functions_geometry.yaml", |
| 184 | + "/functions_logarithmic.yaml", |
| 185 | + "/functions_rounding.yaml", |
| 186 | + "/functions_set.yaml", |
| 187 | + "/functions_string.yaml", |
| 188 | + ) |
| 189 | + |
| 190 | + def __init__(self): |
| 191 | + self._declarations = {} |
| 192 | + self._registered_extensions = {} |
| 193 | + self._functions = {} |
| 194 | + |
| 195 | + def load_standard_extensions(self, dirpath): |
| 196 | + for ext in self.STANDARD_EXTENSIONS: |
| 197 | + self.load(dirpath, ext) |
| 198 | + |
| 199 | + def load(self, dirpath, filename): |
| 200 | + with open(pathlib.Path(dirpath) / filename.strip("/")) as f: |
| 201 | + sections = yaml.safe_load(f) |
| 202 | + |
| 203 | + loaded_functions = set() |
| 204 | + for functions in sections.values(): |
| 205 | + for function in functions: |
| 206 | + function_name = function["name"] |
| 207 | + for impl in function.get("impls", []): |
| 208 | + argtypes = [t.get("value", "unknown") for t in impl.get("args", [])] |
| 209 | + if not argtypes: |
| 210 | + signature = function_name |
| 211 | + else: |
| 212 | + signature = f"{function_name}:{'_'.join(argtypes)}" |
| 213 | + self._declarations[signature] = filename |
| 214 | + loaded_functions.add(signature) |
| 215 | + |
| 216 | + self._register_extensions(filename, loaded_functions) |
| 217 | + |
| 218 | + def _register_extensions(self, extension_uri, loaded_functions): |
| 219 | + if extension_uri not in self._registered_extensions: |
| 220 | + ext_anchor_id = len(self._registered_extensions) + 1 |
| 221 | + self._registered_extensions[extension_uri] = proto.SimpleExtensionURI( |
| 222 | + extension_uri_anchor=ext_anchor_id, uri=extension_uri |
| 223 | + ) |
| 224 | + |
| 225 | + for function in loaded_functions: |
| 226 | + if function in self._functions: |
| 227 | + extensions_by_anchor = self.extension_uris_by_anchor |
| 228 | + function = self._functions[function] |
| 229 | + function_extension = extensions_by_anchor[ |
| 230 | + function.extension_uri_reference |
| 231 | + ].uri |
| 232 | + continue |
| 233 | + raise ValueError( |
| 234 | + f"Duplicate function definition: {function} from {extension_uri}, already loaded from {function_extension}" |
| 235 | + ) |
| 236 | + extension_anchor = self._registered_extensions[ |
| 237 | + extension_uri |
| 238 | + ].extension_uri_anchor |
| 239 | + function_anchor = len(self._functions) + 1 |
| 240 | + self._functions[function] = ( |
| 241 | + proto.SimpleExtensionDeclaration.ExtensionFunction( |
| 242 | + extension_uri_reference=extension_anchor, |
| 243 | + name=function, |
| 244 | + function_anchor=function_anchor, |
| 245 | + ) |
| 246 | + ) |
| 247 | + |
| 248 | + @property |
| 249 | + def extension_uris_by_anchor(self): |
| 250 | + return { |
| 251 | + ext.extension_uri_anchor: ext |
| 252 | + for ext in self._registered_extensions.values() |
| 253 | + } |
| 254 | + |
| 255 | + @property |
| 256 | + def extension_uris(self): |
| 257 | + return list(self._registered_extensions.values()) |
| 258 | + |
| 259 | + @property |
| 260 | + def extensions(self): |
| 261 | + return list(self._functions.values()) |
| 262 | + |
| 263 | + def function_anchor(self, function): |
| 264 | + return self._functions[function].function_anchor |
| 265 | + |
| 266 | + def extensions_for_functions(self, functions): |
| 267 | + uris_anchors = set() |
| 268 | + extensions = [] |
| 269 | + for f in functions: |
| 270 | + ext = self._functions[f] |
| 271 | + uris_anchors.add(ext.extension_uri_reference) |
| 272 | + extensions.append(proto.SimpleExtensionDeclaration(extension_function=ext)) |
| 273 | + |
| 274 | + uris_by_anchor = self.extension_uris_by_anchor |
| 275 | + extension_uris = [uris_by_anchor[uri_anchor] for uri_anchor in uris_anchors] |
| 276 | + return extension_uris, extensions |
| 277 | + |
| 278 | + |
| 279 | +catalog = FunctionsCatalog() |
| 280 | +catalog.load_standard_extensions( |
| 281 | + pathlib.Path(__file__).parent.parent / "third_party" / "substrait" / "extensions", |
| 282 | +) |
| 283 | +schema = proto.NamedStruct( |
| 284 | + names=["first_name", "surname", "age"], |
| 285 | + struct=proto.Type.Struct( |
| 286 | + types=[ |
| 287 | + proto.Type( |
| 288 | + string=proto.Type.String( |
| 289 | + nullability=proto.Type.Nullability.NULLABILITY_REQUIRED |
| 290 | + ) |
| 291 | + ), |
| 292 | + proto.Type( |
| 293 | + string=proto.Type.String( |
| 294 | + nullability=proto.Type.Nullability.NULLABILITY_REQUIRED |
| 295 | + ) |
| 296 | + ), |
| 297 | + proto.Type( |
| 298 | + i32=proto.Type.I32( |
| 299 | + nullability=proto.Type.Nullability.NULLABILITY_REQUIRED |
| 300 | + ) |
| 301 | + ), |
| 302 | + ] |
| 303 | + ), |
| 304 | +) |
| 305 | + |
| 306 | +if __name__ == '__main__': |
| 307 | + sql = "SELECT surname, age + 1 as next_birthday WHERE age = 32" |
| 308 | + projection_expr, filter_expr = parse_sql_extended_expression(catalog, schema, sql) |
| 309 | + print("---- SQL INPUT ----") |
| 310 | + print(sql) |
| 311 | + print("---- PROJECTION ----") |
| 312 | + print(projection_expr) |
| 313 | + print("---- FILTER ----") |
| 314 | + print(filter_expr) |
| 315 | + # parse_extended_expression("INSERT INTO table VALUES(1, 2, 3)") |
0 commit comments