Skip to content

Commit 038da09

Browse files
authored
Fix(snowflake)!: treat TABLE(...) as a UDTF (#4875)
1 parent 1268605 commit 038da09

File tree

6 files changed

+115
-40
lines changed

6 files changed

+115
-40
lines changed

sqlglot/dialects/snowflake.py

+28
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,7 @@ class Parser(parser.Parser):
450450
"REGEXP_SUBSTR_ALL": _build_regexp_extract(exp.RegexpExtractAll),
451451
"RLIKE": exp.RegexpLike.from_arg_list,
452452
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
453+
"TABLE": lambda args: exp.TableFromRows(this=seq_get(args, 0)),
453454
"TIMEADD": _build_date_time_add(exp.TimeAdd),
454455
"TIMEDIFF": _build_datediff,
455456
"TIMESTAMPADD": _build_date_time_add(exp.DateAdd),
@@ -747,6 +748,33 @@ def _parse_table_parts(
747748

748749
return table
749750

751+
def _parse_table(
    self,
    schema: bool = False,
    joins: bool = False,
    alias_tokens: t.Optional[t.Collection[TokenType]] = None,
    parse_bracket: bool = False,
    is_db_reference: bool = False,
    parse_partition: bool = False,
) -> t.Optional[exp.Expression]:
    """Parse a table reference, unwrapping a `TABLE(...)` call from its `exp.Table` wrapper.

    All arguments are forwarded unchanged to the parent parser's `_parse_table`.

    Returns:
        The parent's result as-is, unless it is an `exp.Table` whose `this` is an
        `exp.TableFromRows`. In that case the inner `exp.TableFromRows` is returned,
        after every arg it declares (other than `this`) has been set to the
        wrapping table's value for that arg.
    """
    parsed = super()._parse_table(
        schema=schema,
        joins=joins,
        alias_tokens=alias_tokens,
        parse_bracket=parse_bracket,
        is_db_reference=is_db_reference,
        parse_partition=parse_partition,
    )

    # Anything that is not Table(this=TableFromRows(...)) passes through untouched.
    if not (isinstance(parsed, exp.Table) and isinstance(parsed.this, exp.TableFromRows)):
        return parsed

    udtf = parsed.this
    for key in exp.TableFromRows.arg_types:
        if key == "this":
            # `this` is the TABLE(...) argument itself; it must not be overwritten.
            continue
        udtf.set(key, parsed.args.get(key))

    return udtf
777+
750778
def _parse_id_var(
751779
self,
752780
any_token: bool = True,

sqlglot/expressions.py

+12
Original file line numberDiff line numberDiff line change
@@ -2587,6 +2587,18 @@ class Lateral(UDTF):
25872587
}
25882588

25892589

2590+
# https://docs.snowflake.com/sql-reference/literals-table
2591+
# https://docs.snowflake.com/en/sql-reference/functions-table#using-a-table-function
2592+
class TableFromRows(UDTF):
    """UDTF expression for a `TABLE(...)` call (see the Snowflake links above).

    `this` is the only required arg; `alias`, `joins`, `pivots` and `sample`
    are optional.
    """

    arg_types = {
        "this": True,
        # Optional args, kept in the order: alias, joins, pivots, sample.
        **dict.fromkeys(("alias", "joins", "pivots", "sample"), False),
    }
2600+
2601+
25902602
class MatchRecognizeMeasure(Expression):
25912603
arg_types = {
25922604
"this": True,

sqlglot/generator.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,8 @@ class Generator(metaclass=_Generator):
191191
exp.StreamingTableProperty: lambda *_: "STREAMING",
192192
exp.StrictProperty: lambda *_: "STRICT",
193193
exp.SwapTable: lambda self, e: f"SWAP WITH {self.sql(e, 'this')}",
194-
exp.TemporaryProperty: lambda *_: "TEMPORARY",
195194
exp.Tags: lambda self, e: f"TAG ({self.expressions(e, flat=True)})",
195+
exp.TemporaryProperty: lambda *_: "TEMPORARY",
196196
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
197197
exp.ToMap: lambda self, e: f"MAP {self.sql(e, 'this')}",
198198
exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
@@ -1999,6 +1999,17 @@ def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:
19991999

20002000
return f"{only}{table}{changes}{partition}{version}{file_format}{sample_pre_alias}{alias}{hints}{pivots}{sample_post_alias}{joins}{laterals}{ordinality}"
20012001

2002+
def tablefromrows_sql(self, expression: exp.TableFromRows) -> str:
    """Generate SQL for an `exp.TableFromRows` node.

    The output is `TABLE(<this>)` followed, in this order, by an optional
    ` AS <alias>`, the pivots, the sample and the joins.
    """
    call_sql = self.func("TABLE", expression.this)

    alias_sql = self.sql(expression, "alias")
    if alias_sql:
        alias_sql = f" AS {alias_sql}"
    else:
        alias_sql = ""

    sample_sql = self.sql(expression, "sample")
    pivots_sql = self.expressions(expression, key="pivots", sep="", flat=True)

    # Joins are rendered first and then indented, leaving their first line as-is.
    joins_sql = self.expressions(expression, key="joins", sep="", flat=True)
    joins_sql = self.indent(joins_sql, skip_first=True)

    return f"{call_sql}{alias_sql}{pivots_sql}{sample_sql}{joins_sql}"
2012+
20022013
def tablesample_sql(
20032014
self,
20042015
expression: exp.TableSample,

sqlglot/parser.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1387,6 +1387,8 @@ class Parser(metaclass=_Parser):
13871387

13881388
RECURSIVE_CTE_SEARCH_KIND = {"BREADTH", "DEPTH", "CYCLE"}
13891389

1390+
MODIFIABLES = (exp.Query, exp.Table, exp.TableFromRows)
1391+
13901392
STRICT_CAST = True
13911393

13921394
PREFIXED_PIVOT_COLUMNS = False
@@ -3351,7 +3353,7 @@ def _implicit_unnests_to_explicit(self, this: E) -> E:
33513353
def _parse_query_modifiers(
33523354
self, this: t.Optional[exp.Expression]
33533355
) -> t.Optional[exp.Expression]:
3354-
if isinstance(this, (exp.Query, exp.Table)):
3356+
if isinstance(this, self.MODIFIABLES):
33553357
for join in self._parse_joins():
33563358
this.append("joins", join)
33573359
for lateral in iter(self._parse_lateral, None):

tests/dialects/test_snowflake.py

+15-38
Original file line numberDiff line numberDiff line change
@@ -1627,44 +1627,21 @@ def test_stored_procedures(self):
16271627
"CREATE PROCEDURE a.b.c(x INT, y VARIANT) RETURNS OBJECT EXECUTE AS CALLER AS 'BEGIN SELECT 1; END;'"
16281628
)
16291629

1630-
def test_table_literal(self):
1631-
# All examples from https://docs.snowflake.com/en/sql-reference/literals-table.html
1632-
self.validate_all(
1633-
r"""SELECT * FROM TABLE('MYTABLE')""",
1634-
write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""},
1635-
)
1636-
1637-
self.validate_all(
1638-
r"""SELECT * FROM TABLE('MYDB."MYSCHEMA"."MYTABLE"')""",
1639-
write={"snowflake": r"""SELECT * FROM TABLE('MYDB."MYSCHEMA"."MYTABLE"')"""},
1640-
)
1641-
1642-
# Per Snowflake documentation at https://docs.snowflake.com/en/sql-reference/literals-table.html
1643-
# one can use either a " ' " or " $$ " to enclose the object identifier.
1644-
# Capturing the single tokens seems like lot of work. Hence adjusting tests to use these interchangeably,
1645-
self.validate_all(
1646-
r"""SELECT * FROM TABLE($$MYDB. "MYSCHEMA"."MYTABLE"$$)""",
1647-
write={"snowflake": r"""SELECT * FROM TABLE('MYDB. "MYSCHEMA"."MYTABLE"')"""},
1648-
)
1649-
1650-
self.validate_all(
1651-
r"""SELECT * FROM TABLE($MYVAR)""",
1652-
write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""},
1653-
)
1654-
1655-
self.validate_all(
1656-
r"""SELECT * FROM TABLE(?)""",
1657-
write={"snowflake": r"""SELECT * FROM TABLE(?)"""},
1658-
)
1659-
1660-
self.validate_all(
1661-
r"""SELECT * FROM TABLE(:BINDING)""",
1662-
write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""},
1663-
)
1664-
1665-
self.validate_all(
1666-
r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10""",
1667-
write={"snowflake": r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10"""},
1630+
def test_table_function(self):
    """Run `validate_identity` over Snowflake `TABLE(...)` queries.

    The queries in the tuple are each passed on their own. The final
    `$$`-quoted query is passed together with a second string, presumably
    the expected single-quoted output (confirm against `validate_identity`).
    """
    single_argument_queries = (
        "SELECT * FROM TABLE('MYTABLE')",
        "SELECT * FROM TABLE($MYVAR)",
        "SELECT * FROM TABLE(?)",
        "SELECT * FROM TABLE(:BINDING)",
        "SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10",
        "SELECT * FROM TABLE('t1') AS f",
        "SELECT * FROM (TABLE('t1') CROSS JOIN TABLE('t2'))",
        "SELECT * FROM TABLE('t1'), LATERAL (SELECT * FROM t2)",
        "SELECT * FROM TABLE('t1') UNION ALL SELECT * FROM TABLE('t2')",
        "SELECT * FROM TABLE('t1') TABLESAMPLE BERNOULLI (20.3)",
        """SELECT * FROM TABLE('MYDB."MYSCHEMA"."MYTABLE"')""",
    )
    for query in single_argument_queries:
        self.validate_identity(query)

    self.validate_identity(
        'SELECT * FROM TABLE($$MYDB. "MYSCHEMA"."MYTABLE"$$)',
        """SELECT * FROM TABLE('MYDB. "MYSCHEMA"."MYTABLE"')""",
    )
16691646

16701647
def test_flatten(self):

tests/test_lineage.py

+45
Original file line numberDiff line numberDiff line change
@@ -576,3 +576,48 @@ def test_pivot_with_implicit_column_of_pivoted_source_and_cte(self) -> None:
576576
self.assertEqual(node.downstream[0].name, "t.empid")
577577
self.assertEqual(node.downstream[0].reference_node_name, "t")
578578
self.assertEqual(node.downstream[0].downstream[0].name, "quarterly_sales.empid")
579+
580+
def test_table_udtf_snowflake(self) -> None:
581+
lateral_flatten = """
582+
SELECT f.value:external_id::string AS external_id
583+
FROM database_name.schema_name.table_name AS raw,
584+
LATERAL FLATTEN(events) AS f
585+
"""
586+
table_flatten = """
587+
SELECT f.value:external_id::string AS external_id
588+
FROM database_name.schema_name.table_name AS raw
589+
JOIN TABLE(FLATTEN(events)) AS f
590+
"""
591+
592+
lateral_node = lineage("external_id", lateral_flatten, dialect="snowflake")
593+
table_node = lineage("external_id", table_flatten, dialect="snowflake")
594+
595+
self.assertEqual(lateral_node.name, "EXTERNAL_ID")
596+
self.assertEqual(table_node.name, "EXTERNAL_ID")
597+
598+
lateral_node = lateral_node.downstream[0]
599+
table_node = table_node.downstream[0]
600+
601+
self.assertEqual(lateral_node.name, "F.VALUE")
602+
self.assertEqual(
603+
lateral_node.source.sql("snowflake"),
604+
"LATERAL FLATTEN(RAW.EVENTS) AS F(SEQ, KEY, PATH, INDEX, VALUE, THIS)",
605+
)
606+
607+
self.assertEqual(table_node.name, "F.VALUE")
608+
self.assertEqual(table_node.source.sql("snowflake"), "TABLE(FLATTEN(RAW.EVENTS)) AS F")
609+
610+
lateral_node = lateral_node.downstream[0]
611+
table_node = table_node.downstream[0]
612+
613+
self.assertEqual(lateral_node.name, "RAW.EVENTS")
614+
self.assertEqual(
615+
lateral_node.source.sql("snowflake"),
616+
"DATABASE_NAME.SCHEMA_NAME.TABLE_NAME AS RAW",
617+
)
618+
619+
self.assertEqual(table_node.name, "RAW.EVENTS")
620+
self.assertEqual(
621+
table_node.source.sql("snowflake"),
622+
"DATABASE_NAME.SCHEMA_NAME.TABLE_NAME AS RAW",
623+
)

0 commit comments

Comments
 (0)