163 changes: 140 additions & 23 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
@@ -20,6 +20,7 @@
Sequence,
Set,
Union,
Literal,
)

import snowflake.snowpark._internal.utils
@@ -86,6 +87,7 @@
is_sql_select_statement,
ExprAliasUpdateDict,
)
import snowflake.snowpark.context as context

# Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable
# Python 3.9 can use both
@@ -1362,9 +1364,9 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
):
# TODO: Clean up, this entire if case is parameter protection
can_be_flattened = False
elif (self.where or self.order_by or self.limit_) and has_data_generator_exp(
cols
):
elif (
self.where or self.order_by or self.limit_
) and has_data_generator_or_window_function_exp(cols):
can_be_flattened = False
elif self.where and (
(subquery_dependent_columns := derive_dependent_columns(self.where))
@@ -1375,6 +1377,20 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
subquery_dependent_columns & new_column_states.active_columns
)
)
or (
                # unflattenable condition: a dropped column is referenced in the subquery's WHERE clause and its status in the subquery is NEW or CHANGED
                # reason: we should not flatten, because the dropped column is not available in the new query, leading to a WHERE clause error
                # sample query: 'select "b" from (select "a" as "c", "b" from table where "c" > 1)' cannot be flattened to 'select "b" from table where "c" > 1'
context._is_snowpark_connect_compatible_mode
and new_column_states.dropped_columns
and any(
self.column_states[_col].change_state
in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP)
for _col in (
subquery_dependent_columns & new_column_states.dropped_columns
)
)
)
):
can_be_flattened = False
elif self.order_by and (
@@ -1387,6 +1403,20 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
subquery_dependent_columns & new_column_states.active_columns
)
)
or (
                # unflattenable condition: a dropped column is referenced in the subquery's ORDER BY clause and its status in the subquery is NEW or CHANGED
                # reason: we should not flatten, because the dropped column is not available in the new query, leading to an ORDER BY clause error
                # sample query: 'select "b" from (select "a" as "c", "b" from table order by "c")' cannot be flattened to 'select "b" from table order by "c"'
context._is_snowpark_connect_compatible_mode
and new_column_states.dropped_columns
and any(
self.column_states[_col].change_state
in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP)
for _col in (
subquery_dependent_columns & new_column_states.dropped_columns
)
)
)
):
can_be_flattened = False
elif self.distinct_:
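The two new `or (...)` branches above encode the same guard for WHERE and ORDER BY. A minimal Snowpark sketch of the WHERE-clause case (assuming an active session and a table t with columns "a" and "b"; the names are illustrative, not taken from the diff):

    from snowflake.snowpark.functions import col

    inner = session.table("t").select(col("a").alias("c"), col("b")).filter(col("c") > 1)
    inner.select(col("b"))
    # naive flattening would emit: SELECT "b" FROM t WHERE "c" > 1
    # which fails, because "c" exists only inside the subquery where "a" was renamed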
@@ -1450,12 +1480,17 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
return new

def filter(self, col: Expression) -> "SelectStatement":
self._session._retrieve_aggregation_function_list()
can_be_flattened = (
(not self.flatten_disabled)
and can_clause_dependent_columns_flatten(
derive_dependent_columns(col), self.column_states
derive_dependent_columns(col), self.column_states, "filter"
)
and not has_data_generator_exp(self.projection)
and not has_data_generator_or_window_function_exp(self.projection)
and not (
context._is_snowpark_connect_compatible_mode
and has_aggregation_function_exp(self.projection)
            )  # e.g. sum(col) as new_col: a filter referencing new_col cannot be flattened into the WHERE clause
and not (self.order_by and self.limit_ is not None)
)
if can_be_flattened:
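A sketch of the new aggregation guard (assuming an active session and a table t with a column "v"; names are illustrative):

    from snowflake.snowpark.functions import col, sum as sum_

    df = session.table("t").select(sum_(col("v")).alias("new_col"))
    df.filter(col("new_col") > 5)
    # flattening would emit: SELECT SUM("v") AS "new_col" FROM t WHERE "new_col" > 5
    # an aggregate cannot be referenced in WHERE, so the filter must stay in an outer query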
@@ -1490,9 +1525,12 @@ def sort(self, cols: List[Expression]) -> "SelectStatement":
and (not self.limit_)
and (not self.offset)
and can_clause_dependent_columns_flatten(
derive_dependent_columns(*cols), self.column_states
derive_dependent_columns(*cols), self.column_states, "sort"
)
and not has_data_generator_exp(self.projection)
            # unlike filter(), we do not check for aggregation functions here:
            # when an aggregation function is in the projection, ORDER BY is
            # evaluated after the aggregation, so no row-level information is involved
)
if can_be_flattened:
new = copy(self)
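By contrast with filter(), an aggregate in the projection does not block flattening here, since ORDER BY is evaluated after the aggregation. A hedged sketch (same illustrative table and imports as above):

    df = session.table("t").select(sum_(col("v")).alias("s"))
    df.sort(col("s"))
    # SELECT SUM("v") AS "s" FROM t ORDER BY "s" -- valid SQL, so flattening is safe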
@@ -1529,7 +1567,7 @@ def distinct(self) -> "SelectStatement":
# .order_by(col1).select(col2).distinct() cannot be flattened because
# SELECT DISTINCT B FROM TABLE ORDER BY A is not valid SQL
and (not (self.order_by and self.has_projection))
and not has_data_generator_exp(self.projection)
and not has_data_generator_or_window_function_exp(self.projection)
)
if can_be_flattened:
new = copy(self)
@@ -2020,7 +2058,12 @@ def can_projection_dependent_columns_be_flattened(
def can_clause_dependent_columns_flatten(
dependent_columns: Optional[AbstractSet[str]],
subquery_column_states: ColumnStateDict,
clause: Literal["filter", "sort"],
) -> bool:
assert clause in (
"filter",
"sort",
), f"Invalid clause called in can_clause_dependent_columns_flatten: {clause}"
if dependent_columns == COLUMN_DEPENDENCY_DOLLAR:
return False
elif (
@@ -2035,15 +2078,18 @@ def can_clause_dependent_columns_flatten(
dc_state = subquery_column_states.get(dc)
if dc_state:
if dc_state.change_state == ColumnChangeState.CHANGED_EXP:
return False
                if (
                    clause == "filter"
                ):  # WHERE cannot be flattened, because 'where' is evaluated before the projection; flattening leads to wrong results
                    # df.select((col('a') + 1).alias('a')).filter(col('a') > 5) -- the filter should apply to the new 'a'; flattening would evaluate it against the old 'a'
                    return False
                else:  # clause == 'sort'
                    # df.select((col('a') + 1).alias('a')).sort(col('a')) -- this is valid to flatten because 'order by' is evaluated after the projection
                    # however, if the ORDER BY expression is a data generator, it should not be flattened, because the generator is evaluated dynamically according to row order.
                    return context._is_snowpark_connect_compatible_mode
elif dc_state.change_state == ColumnChangeState.NEW:
# Most of the time this can be flattened. But if a new column uses window function and this column
# is used in a clause, the sql doesn't work in Snowflake.
# For instance `select a, rank() over(order by b) as d from test_table where d = 1` doesn't work.
# But `select a, b as d from test_table where d = 1` works
# We can inspect whether the referenced new column uses window function. Here we are being
# conservative for now to not flatten the SQL.
return False
return context._is_snowpark_connect_compatible_mode

return True


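The filter/sort asymmetry for CHANGED_EXP columns, condensed into one sketch that mirrors the comments above (assuming a table t with column "a"):

    df = session.table("t").select((col("a") + 1).alias("a"))
    df.filter(col("a") > 5)  # must not flatten: WHERE runs before the projection and would see the old "a"
    df.sort(col("a"))        # may flatten (connect mode only): ORDER BY runs after the projection and sees the new "a"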
@@ -2260,18 +2306,89 @@ def derive_column_states_from_subquery(
return column_states


def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool:
    if expressions is None:
        return False
    for exp in expressions:
        if isinstance(exp, WindowExpression):
            return True
        if isinstance(exp, FunctionExpression) and (
            exp.is_data_generator
            or exp.name.lower() in SEQUENCE_DEPENDENT_DATA_GENERATION
        ):
            # https://docs.snowflake.com/en/sql-reference/functions-data-generation
            return True
        if exp is not None and has_data_generator_exp(exp.children):
            return True
    return False
def _check_expressions_for_types(
    expressions: Optional[List["Expression"]],
    check_data_gen: bool = False,
    check_window: bool = False,
    check_aggregation: bool = False,
) -> bool:
    """Efficiently check if expressions contain specific types in a single pass.

    Args:
        expressions: List of expressions to check
        check_data_gen: Check for data generator functions
        check_window: Check for window functions
        check_aggregation: Check for aggregation functions

    Returns:
        True if any requested type is found
    """
    if expressions is None:
        return False

    for exp in expressions:
        if exp is None:
            continue

        # Check window functions
        if check_window and isinstance(exp, WindowExpression):
            return True

        # Check data generators (including window in non-connect mode)
        if check_data_gen:
            # In non-connect mode, windows are treated as data generators
            if not context._is_snowpark_connect_compatible_mode and isinstance(
                exp, WindowExpression
            ):
                return True
            # Check actual data generator functions
            if isinstance(exp, FunctionExpression) and (
                exp.is_data_generator
                or exp.name.lower() in SEQUENCE_DEPENDENT_DATA_GENERATION
            ):
                # https://docs.snowflake.com/en/sql-reference/functions-data-generation
                return True

        # Check aggregation functions
        if check_aggregation and isinstance(exp, FunctionExpression):
            if exp.name.lower() in context._aggregation_function_set:
                return True

        # Recursively check children
        if _check_expressions_for_types(
            exp.children, check_data_gen, check_window, check_aggregation
        ):
            return True

    return False


def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool:
"""Check if expressions contain data generator functions.

Note:
        In non-connect mode, check_data_gen checks both data generator and window expressions for backward compatibility.
In connect mode, check_data_gen only checks data generator expressions.
"""
return _check_expressions_for_types(expressions, check_data_gen=True)


def has_data_generator_or_window_function_exp(
expressions: Optional[List["Expression"]],
) -> bool:
"""Check if expressions contain data generators or window functions.

Optimized to do a single pass checking both types simultaneously.
"""
if not context._is_snowpark_connect_compatible_mode:
# In non-connect mode, windows are already treated as data generators
return _check_expressions_for_types(expressions, check_data_gen=True)
# In connect mode, check both in a single pass
return _check_expressions_for_types(
expressions, check_data_gen=True, check_window=True
)


def has_aggregation_function_exp(expressions: Optional[List["Expression"]]) -> bool:
"""Check if expressions contain aggregation functions."""
return _check_expressions_for_types(expressions, check_aggregation=True)
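Taken together, the three public helpers are thin wrappers over the single-pass walker. A hedged usage sketch (exprs stands for a plan's projection list; its construction is not shown):

    exprs: List[Expression] = ...  # e.g. a SelectStatement's projection
    has_data_generator_exp(exprs)                     # data generators (and windows, in non-connect mode)
    has_data_generator_or_window_function_exp(exprs)  # data generators and windows in one tree traversal
    has_aggregation_function_exp(exprs)               # names matched against context._aggregation_function_set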
4 changes: 4 additions & 0 deletions src/snowflake/snowpark/context.py
@@ -31,6 +31,10 @@

# This is an internal-only global flag, used to determine whether the api code which will be executed is compatible with snowflake.snowpark_connect
_is_snowpark_connect_compatible_mode = False
_aggregation_function_set = (
set()
)  # lower-cased names of aggregation functions, used in SQL simplification
_aggregation_function_set_lock = threading.RLock()

# Following are internal-only global flags, used to enable development features.
_enable_dataframe_trace_on_error = False
4 changes: 2 additions & 2 deletions src/snowflake/snowpark/mock/_select_statement.py
@@ -412,7 +412,7 @@ def filter(self, col: Expression) -> "MockSelectStatement":
else:
dependent_columns = derive_dependent_columns(col)
can_be_flattened = can_clause_dependent_columns_flatten(
dependent_columns, self.column_states
dependent_columns, self.column_states, "filter"
)
if can_be_flattened:
new = copy(self)
@@ -433,7 +433,7 @@ def sort(self, cols: List[Expression]) -> "MockSelectStatement":
else:
dependent_columns = derive_dependent_columns(*cols)
can_be_flattened = can_clause_dependent_columns_flatten(
dependent_columns, self.column_states
dependent_columns, self.column_states, "sort"
)
if can_be_flattened:
new = copy(self)
29 changes: 29 additions & 0 deletions src/snowflake/snowpark/session.py
@@ -4925,6 +4925,35 @@ def _execute_sproc_internal(
# Note the collect is implicit within the stored procedure call, so should not emit_ast here.
return df.collect(statement_params=statement_params, _emit_ast=False)[0][0]

def _retrieve_aggregation_function_list(self) -> None:
"""Retrieve the list of aggregation functions which will later be used in sql simplifier."""
if (
not context._is_snowpark_connect_compatible_mode
or context._aggregation_function_set
):
return

retrieved_set = set()

for sql in [
"""select function_name from information_schema.functions where is_aggregate = 'YES'""",
"""show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""",
]:
try:
retrieved_set.update({r[0].lower() for r in self.sql(sql).collect()})
except BaseException as e:
_logger.debug(
"Unable to get aggregation functions from the database: %s",
e,
)
                # we raise the error here as a pessimistic tactic:
                # if we fail to retrieve the aggregation function list, the set stays empty,
                # and the simplifier would flatten queries containing aggregation functions, leading to incorrect results
raise

with context._aggregation_function_set_lock:
context._aggregation_function_set.update(retrieved_set)

def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame:
"""
Returns a DataFrame representing the results of a directory table query on the specified stage.
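The pessimistic re-raise above matters for correctness, not just hygiene. A sketch of the failure mode it prevents (scenario illustrative; names as in the diff):

    session._retrieve_aggregation_function_list()  # populates context._aggregation_function_set once
    # if this failed silently, the set would stay empty, has_aggregation_function_exp(...)
    # would return False for SUM/AVG/..., and filter() would wrongly flatten
    # df.select(sum_(col("v")).alias("s")).filter(col("s") > 5)
    # into a single SELECT with an aggregate in its WHERE clause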
26 changes: 23 additions & 3 deletions tests/integ/test_query_line_intervals.py
@@ -57,8 +57,9 @@ def generate_test_data(session, sql_simplifier_enabled):
}


@pytest.mark.parametrize("snowpark_connect_compatible_mode", [True, False])
@pytest.mark.parametrize(
"op,sql_simplifier,line_to_expected_sql",
"op,sql_simplifier,line_to_expected_sql,snowpark_connect_compatible_mode_sql",
[
(
lambda data: data["df1"].union(data["df2"]),
@@ -68,12 +69,16 @@ def generate_test_data(session, sql_simplifier_enabled):
6: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)',
10: 'SELECT "_1" AS "ID", "_2" AS "NAME", "_3" AS "VALUE" FROM ( SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (3 :: INT, \'C\' :: STRING, 300 :: INT), (4 :: INT, \'D\' :: STRING, 400 :: INT) )',
},
None,
),
(
lambda data: data["df1"].filter(data["df1"].value > 150),
True,
{
8: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)',
8: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)'
},
{
8: """SELECT "_1" AS "ID", "_2" AS "NAME", "_3" AS "VALUE" FROM (SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, 'A' :: STRING, 100 :: INT), (2 :: INT, 'B' :: STRING, 200 :: INT)) WHERE ("VALUE" > 150)""",
},
),
(
@@ -83,6 +88,7 @@ def generate_test_data(session, sql_simplifier_enabled):
1: 'SELECT "_1" AS "ID", "_2" AS "NAME" FROM ( SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT) )',
4: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)',
},
None,
),
(
lambda data: data["df1"].pivot(F.col("name")).sum(F.col("value")),
@@ -92,12 +98,26 @@ def generate_test_data(session, sql_simplifier_enabled):
6: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)',
9: 'SELECT * FROM ( SELECT "_1" AS "ID", "_2" AS "NAME", "_3" AS "VALUE" FROM ( SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT) ) ) PIVOT ( sum("VALUE") FOR "NAME" IN ( ANY ) )',
},
None,
),
],
)
def test_get_plan_from_line_numbers_sql_content(
session, op, sql_simplifier, line_to_expected_sql
session,
op,
sql_simplifier,
line_to_expected_sql,
snowpark_connect_compatible_mode_sql,
snowpark_connect_compatible_mode,
monkeypatch,
):
if snowpark_connect_compatible_mode:
import snowflake.snowpark.context as context

monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True)
line_to_expected_sql = (
snowpark_connect_compatible_mode_sql or line_to_expected_sql
)
session.sql_simplifier_enabled = sql_simplifier
df = op(generate_test_data(session, sql_simplifier))
