Skip to content

feat: basic support for SQL SELECT -> ExtendedExpression #55

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ authors = [{name = "Substrait contributors", email = "[email protected]
license = {text = "Apache-2.0"}
readme = "README.md"
requires-python = ">=3.8.1"
dependencies = ["protobuf >= 3.20"]
dependencies = ["protobuf >= 3.20", "sqlglot >= 23.10.0", "PyYAML"]
dynamic = ["version"]

[tool.setuptools_scm]
Expand Down
2 changes: 2 additions & 0 deletions src/substrait/sql/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .extended_expression import parse_sql_extended_expression
from .functions_catalog import FunctionsCatalog
67 changes: 67 additions & 0 deletions src/substrait/sql/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import pathlib
import argparse

from substrait import proto
from .functions_catalog import FunctionsCatalog
from .extended_expression import parse_sql_extended_expression


def main():
"""Commandline tool to test the SQL to ExtendedExpression parser.

Run as python -m substrait.sql first_name=String,surname=String,age=I32 "SELECT surname, age + 1 as next_birthday, age + 2 WHERE age = 32"
"""
parser = argparse.ArgumentParser(
description="Convert a SQL SELECT statement to an ExtendedExpression"
)
parser.add_argument("schema", type=str, help="Schema of the input data")
parser.add_argument("sql", type=str, help="SQL SELECT statement")
args = parser.parse_args()

catalog = FunctionsCatalog()
catalog.load_standard_extensions(
pathlib.Path(__file__).parent.parent.parent.parent
/ "third_party"
/ "substrait"
/ "extensions",
)
schema = parse_schema(args.schema)
projection_expr, filter_expr = parse_sql_extended_expression(
catalog, schema, args.sql
)

print("---- SQL INPUT ----")
print(args.sql)
print("---- PROJECTION ----")
print(projection_expr)
print("---- FILTER ----")
print(filter_expr)


def parse_schema(schema_string):
"""Parse Schema from a comma separated string of fieldname=fieldtype pairs.

For example: "first_name=String,surname=String,age=I32"
"""
types = []
names = []

fields = schema_string.split(",")
for field in fields:
fieldname, fieldtype = field.split("=")
proto_type = getattr(proto.Type, fieldtype)
names.append(fieldname)
types.append(
proto.Type(
**{
fieldtype.lower(): proto_type(
nullability=proto.Type.Nullability.NULLABILITY_REQUIRED
)
}
)
)
return proto.NamedStruct(names=names, struct=proto.Type.Struct(types=types))


if __name__ == "__main__":
main()
273 changes: 273 additions & 0 deletions src/substrait/sql/extended_expression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
import itertools

import sqlglot

from substrait import proto
from .utils import DispatchRegistry


SQL_FUNCTIONS = {
# Arithmetic
sqlglot.expressions.Add: "add",
sqlglot.expressions.Div: "div",
sqlglot.expressions.Mul: "mul",
sqlglot.expressions.Sub: "sub",
sqlglot.expressions.Mod: "modulus",
sqlglot.expressions.BitwiseAnd: "bitwise_and",
sqlglot.expressions.BitwiseOr: "bitwise_or",
sqlglot.expressions.BitwiseXor: "bitwise_xor",
sqlglot.expressions.BitwiseNot: "bitwise_not",
# Comparisons
sqlglot.expressions.EQ: "equal",
sqlglot.expressions.NullSafeEQ: "is_not_distinct_from",
sqlglot.expressions.NEQ: "not_equal",
sqlglot.expressions.GT: "gt",
sqlglot.expressions.GTE: "gte",
sqlglot.expressions.LT: "lt",
sqlglot.expressions.LTE: "lte",
sqlglot.expressions.IsNan: "is_nan",
# logical
sqlglot.expressions.And: "and",
sqlglot.expressions.Or: "or",
sqlglot.expressions.Not: "not",
}


def parse_sql_extended_expression(catalog, schema, sql):
"""Parse a SQL SELECT statement into an ExtendedExpression.

Only supports SELECT statements with projections and WHERE clauses.
"""
select = sqlglot.parse_one(sql)
if not isinstance(select, sqlglot.expressions.Select):
raise ValueError("a SELECT statement was expected")

sqlglot_parser = SQLGlotParser(catalog, schema)

# Handle the projections in the SELECT statemenent.
project_expressions = []
projection_invoked_functions = set()
for sqlexpr in select.expressions:
parsed_expr = sqlglot_parser.expression_from_sqlglot(sqlexpr)
projection_invoked_functions.update(parsed_expr.invoked_functions)
project_expressions.append(
proto.ExpressionReference(
expression=parsed_expr.expression,
output_names=[parsed_expr.output_name],
)
)
extension_uris, extensions = catalog.extensions_for_functions(
projection_invoked_functions
)
projection_extended_expr = proto.ExtendedExpression(
extension_uris=extension_uris,
extensions=extensions,
base_schema=schema,
referred_expr=project_expressions,
)

# Handle WHERE clause in the SELECT statement.
filter_parsed_expr = sqlglot_parser.expression_from_sqlglot(
select.find(sqlglot.expressions.Where).this
)
extension_uris, extensions = catalog.extensions_for_functions(
filter_parsed_expr.invoked_functions
)
filter_extended_expr = proto.ExtendedExpression(
extension_uris=extension_uris,
extensions=extensions,
base_schema=schema,
referred_expr=[
proto.ExpressionReference(expression=filter_parsed_expr.expression)
],
)

return projection_extended_expr, filter_extended_expr


class SQLGlotParser:
DISPATCH_REGISTRY = DispatchRegistry()

def __init__(self, functions_catalog, schema):
self._functions_catalog = functions_catalog
self._schema = schema
self._counter = itertools.count()

self._parse_expression = self.DISPATCH_REGISTRY.bind(self)

def expression_from_sqlglot(self, sqlglot_node):
"""Parse a SQLGlot expression into a Substrait Expression."""
return self._parse_expression(sqlglot_node)

@DISPATCH_REGISTRY.register(sqlglot.expressions.Literal)
def _parse_Literal(self, expr):
if expr.is_string:
return ParsedSubstraitExpression(
f"literal${next(self._counter)}",
proto.Type(string=proto.Type.String()),
proto.Expression(literal=proto.Expression.Literal(string=expr.name)),
)
elif expr.is_int:
return ParsedSubstraitExpression(
f"literal${next(self._counter)}",
proto.Type(i32=proto.Type.I32()),
proto.Expression(literal=proto.Expression.Literal(i32=int(expr.name))),
)
elif sqlglot.helper.is_float(expr.name):
return ParsedSubstraitExpression(
f"literal${next(self._counter)}",
proto.Type(fp32=proto.Type.FP32()),
proto.Expression(
literal=proto.Expression.Literal(float=float(expr.name))
),
)
else:
raise ValueError(f"Unsupporter literal: {expr.text}")

@DISPATCH_REGISTRY.register(sqlglot.expressions.Column)
def _parse_Column(self, expr):
column_name = expr.output_name
schema_field = list(self._schema.names).index(column_name)
schema_type = self._schema.struct.types[schema_field]
return ParsedSubstraitExpression(
column_name,
schema_type,
proto.Expression(
selection=proto.Expression.FieldReference(
direct_reference=proto.Expression.ReferenceSegment(
struct_field=proto.Expression.ReferenceSegment.StructField(
field=schema_field
)
)
)
),
)

@DISPATCH_REGISTRY.register(sqlglot.expressions.Alias)
def _parse_Alias(self, expr):
parsed_expression = self._parse_expression(expr.this)
return parsed_expression.duplicate(output_name=expr.output_name)

@DISPATCH_REGISTRY.register(sqlglot.expressions.Is)
def _parse_IS(self, expr):
# IS NULL is a special case because in SQLGlot is a binary expression with argument
# while in Substrait there are only the is_null and is_not_null unary functions
argument_parsed_expr = self._parse_expression(expr.left)
if isinstance(expr.right, sqlglot.expressions.Null):
function_name = "is_null"
else:
raise ValueError(f"Unsupported IS expression: {expr}")
signature, result_type, function_expression = self._parse_function_invokation(
function_name, argument_parsed_expr
)
result_name = (
f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}"
)
return ParsedSubstraitExpression(
result_name,
result_type,
function_expression,
argument_parsed_expr.invoked_functions | {signature},
)

@DISPATCH_REGISTRY.register(sqlglot.expressions.Binary)
def _parser_Binary(self, expr):
left_parsed_expr = self._parse_expression(expr.left)
right_parsed_expr = self._parse_expression(expr.right)
function_name = SQL_FUNCTIONS[type(expr)]
signature, result_type, function_expression = self._parse_function_invokation(
function_name, left_parsed_expr, right_parsed_expr
)
result_name = f"{function_name}_{left_parsed_expr.output_name}_{right_parsed_expr.output_name}_{next(self._counter)}"
return ParsedSubstraitExpression(
result_name,
result_type,
function_expression,
left_parsed_expr.invoked_functions
| right_parsed_expr.invoked_functions
| {signature},
)

@DISPATCH_REGISTRY.register(sqlglot.expressions.Unary)
def _parse_Unary(self, expr):
argument_parsed_expr = self._parse_expression(expr.this)
function_name = SQL_FUNCTIONS[type(expr)]
signature, result_type, function_expression = self._parse_function_invokation(
function_name, argument_parsed_expr
)
result_name = (
f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}"
)
return ParsedSubstraitExpression(
result_name,
result_type,
function_expression,
argument_parsed_expr.invoked_functions | {signature},
)

def _parse_function_invokation(
self, function_name, argument_parsed_expr, *additional_arguments
):
"""Generates a Substrait function invokation expression.

The function invocation will be generated from the function name
and the arguments as ParsedSubstraitExpression.

Returns the function signature, the return type and the
invokation expression itself.
"""
arguments = [argument_parsed_expr] + list(additional_arguments)
signature = self._functions_catalog.make_signature(
function_name, proto_argtypes=[arg.type for arg in arguments]
)

registered_function = self._functions_catalog.lookup_function(signature)
if registered_function is None:
raise KeyError(f"Function not found: {signature}")

return (
registered_function.signature,
registered_function.return_type,
proto.Expression(
scalar_function=proto.Expression.ScalarFunction(
function_reference=registered_function.function_anchor,
arguments=[
proto.FunctionArgument(value=arg.expression)
for arg in arguments
],
)
),
)


class ParsedSubstraitExpression:
"""A Substrait expression that was parsed from a SQLGlot node.

This stores the expression itself, with an associated output name
in case it is required to emit projections.

It also stores the type of the expression (i64, string, boolean, etc...)
and the functions that the expression in going to invoke.
"""

def __init__(self, output_name, type, expression, invoked_functions=None):
self.expression = expression
self.output_name = output_name
self.type = type

if invoked_functions is None:
invoked_functions = set()
self.invoked_functions = invoked_functions

def duplicate(
self, output_name=None, type=None, expression=None, invoked_functions=None
):
return ParsedSubstraitExpression(
output_name or self.output_name,
type or self.type,
expression or self.expression,
invoked_functions or self.invoked_functions,
)

def __repr__(self):
return f"<ParsedSubstraitExpression {self.output_name} {self.type}>"
Loading
Loading