Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,15 @@ def __deepcopy__(self, memodict={}) -> "SelectableEntity": # noqa: B006
deepcopy(self.entity, memodict), analyzer=self.analyzer
)
_deepcopy_selectable_fields(from_selectable=self, to_selectable=copied)

if (
self._session.reduce_describe_query_enabled
and self._session.cte_optimization_enabled
):
copied._attributes = (
deepcopy(self._attributes, memodict)
if self._attributes is not None
else None
)
return copied

@property
Expand Down Expand Up @@ -940,6 +948,13 @@ def __copy__(self):
self._merge_projection_complexity_with_subquery
)
new.df_ast_ids = self.df_ast_ids.copy() if self.df_ast_ids is not None else None
if (
self._session.reduce_describe_query_enabled
and self._session.cte_optimization_enabled
):
new._attributes = (
self._attributes.copy() if self._attributes is not None else None
)
return new

def __deepcopy__(self, memodict={}) -> "SelectStatement": # noqa: B006
Expand All @@ -959,6 +974,15 @@ def __deepcopy__(self, memodict={}) -> "SelectStatement": # noqa: B006
)

_deepcopy_selectable_fields(from_selectable=self, to_selectable=copied)
if (
self._session.reduce_describe_query_enabled
and self._session.cte_optimization_enabled
):
copied._attributes = (
deepcopy(self._attributes, memodict)
if self._attributes is not None
else None
)
copied._projection_in_str = self._projection_in_str
copied._query_params = deepcopy(self._query_params)
copied._merge_projection_complexity_with_subquery = (
Expand Down Expand Up @@ -1404,7 +1428,11 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
if can_be_flattened:
new = copy(self)
final_projection = []

if (
self._session.reduce_describe_query_enabled
and self._session.cte_optimization_enabled
):
new._attributes = None # reset attributes since projection changed
assert new_column_states is not None
for col, state in new_column_states.items():
if state.change_state in (
Expand Down
4 changes: 3 additions & 1 deletion src/snowflake/snowpark/_internal/compiler/cte_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
if TYPE_CHECKING:
from snowflake.snowpark._internal.compiler.utils import TreeNode # pragma: no cover

HASH_LENGTH = 10


def find_duplicate_subtrees(
root: "TreeNode", propagate_complexity_hist: bool = False
Expand Down Expand Up @@ -272,7 +274,7 @@ def stringify(d):
string = f"{string}#{stringify(node.df_aliased_col_name_to_real_col_name)}"

try:
return hashlib.sha256(string.encode()).hexdigest()[:10]
return hashlib.sha256(string.encode()).hexdigest()[:HASH_LENGTH]
except Exception as ex:
logging.warning(f"Encode SnowflakePlan ID failed: {ex}")
return None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,22 @@
LogicalPlan,
WithQueryBlock,
)
from snowflake.snowpark._internal.compiler.cte_utils import find_duplicate_subtrees
from snowflake.snowpark._internal.compiler.cte_utils import (
find_duplicate_subtrees,
HASH_LENGTH,
)
from snowflake.snowpark._internal.compiler.query_generator import QueryGenerator
from snowflake.snowpark._internal.compiler.utils import (
TreeNode,
replace_child,
update_resolvable_node,
)
from snowflake.snowpark._internal.utils import (
TEMP_OBJECT_NAME_PREFIX,
TempObjectType,
random_name_for_temp_object,
)
import snowflake.snowpark.context as context


class RepeatedSubqueryEliminationResult:
Expand Down Expand Up @@ -164,10 +169,17 @@ def _update_parents(
node.encoded_node_id_with_query
]
else:
# create a WithQueryBlock node
with_block = WithQueryBlock(
name=random_name_for_temp_object(TempObjectType.CTE), child=node
)
if (
self._query_generator.session.reduce_describe_query_enabled
and context._is_snowpark_connect_compatible_mode
):
# create a deterministic name using the first 10 chars of encoded_node_id_with_query (SHA256 hash)
# It helps when DataFrame.queries is called multiple times.
# Consistent CTE names returned, reducing the number of describe queries from cached_analyze_attributes calls.
cte_name = f"{TEMP_OBJECT_NAME_PREFIX}{TempObjectType.CTE.value}_{node.encoded_node_id_with_query[:HASH_LENGTH].upper()}"
else:
cte_name = random_name_for_temp_object(TempObjectType.CTE)
with_block = WithQueryBlock(name=cte_name, child=node) # type: ignore
with_block._is_valid_for_replacement = True

resolved_with_block = self._query_generator.resolve(with_block)
Expand Down
39 changes: 39 additions & 0 deletions tests/integ/test_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import re
import tracemalloc
from unittest import mock

import pytest

Expand Down Expand Up @@ -32,6 +33,7 @@
StringType,
TimestampType,
)
import snowflake.snowpark.context as context
from tests.integ.scala.test_dataframe_reader_suite import get_reader
from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker
from tests.utils import IS_IN_STORED_PROC_LOCALFS, TestFiles, Utils
Expand Down Expand Up @@ -1313,3 +1315,40 @@ def test_table_select_cte(session):
union_count=1,
join_count=0,
)


@pytest.mark.parametrize(
    "reduce_describe_enabled,expected_describe_counts",
    [
        (True, [1, 0]),  # With caching: first call misses, second call hits cache
        (False, [1, 1]),  # Without caching: both calls issue describe queries
    ],
)
def test_dataframe_queries_with_cte_reuses_schema_cache(
    session, reduce_describe_enabled, expected_describe_counts
):
    """Test that calling dataframe.queries (not same dataframe but same operation) multiple times with CTE optimization
    does not issue extra DESCRIBE queries when reduce_describe_query_enabled is True.

    This tests the deterministic CTE naming feature: when CTE optimization is enabled
    and reduce_describe_query is enabled, repeated calls to df.queries should produce
    identical SQL (with same CTE names), allowing the schema cache to hit.
    """

    def create_cte_dataframe():
        """Create a DataFrame that triggers CTE optimization (same df used twice)."""
        df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
        return df.union_all(df)

    def access_queries_and_schema(df):
        """Access both queries and schema properties."""
        # Accessing .queries generates SQL (with CTE names); accessing .schema
        # may issue a DESCRIBE unless the generated SQL hits the schema cache.
        _ = df.queries
        _ = df.schema

    # Patch the private session flag and the module-level compatibility flag so the
    # deterministic-CTE-name code path is exercised regardless of session defaults.
    # NOTE(review): assumes `_reduce_describe_query_enabled` backs the public
    # `reduce_describe_query_enabled` property — confirm against Session.
    with mock.patch.object(
        session, "_reduce_describe_query_enabled", reduce_describe_enabled
    ), mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True):
        # A fresh (but structurally identical) DataFrame is built per iteration:
        # with caching enabled the second iteration should need 0 describes.
        for expected_describe_count in expected_describe_counts:
            df_union = create_cte_dataframe()
            with SqlCounter(query_count=0, describe_count=expected_describe_count):
                access_queries_and_schema(df_union)
68 changes: 68 additions & 0 deletions tests/integ/test_deepcopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import copy
from typing import Callable, List, Optional
from unittest import mock

import pytest

Expand Down Expand Up @@ -415,3 +416,70 @@ def traverse_plan(plan, plan_id_map):
traverse_plan(child, plan_id_map)

traverse_plan(copied_plan, {})


def test_selectable_entity_deepcopy_attributes_with_flags_enabled(session):
    """Verify _attributes is deepcopied in SelectableEntity when both flags are True."""
    temp_table_name = random_name_for_temp_object(TempObjectType.TABLE)
    session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).write.save_as_table(
        temp_table_name, table_type="temp"
    )

    # Force both feature flags on; _attributes copying is gated on the pair.
    # NOTE(review): assumes these private names back the public
    # reduce_describe_query_enabled / cte_optimization_enabled properties — confirm.
    with mock.patch.object(
        session, "_reduce_describe_query_enabled", True
    ), mock.patch.object(session, "_cte_optimization_enabled", True):
        # Apply a filter to get a SelectStatement with SelectableEntity as child
        df = session.table(temp_table_name).filter(col("a") == 1)
        # Access the SelectableEntity from the plan's children
        if session.sql_simplifier_enabled:
            # The SelectableEntity child only exists under the SQL simplifier,
            # so the whole verification is a no-op when it is disabled.
            assert len(df._plan.children_plan_nodes) == 1
            assert isinstance(df._plan.children_plan_nodes[0], SelectableEntity)
            selectable = df._plan.children_plan_nodes[0]
            # Set _attributes to simulate cached attributes
            selectable._attributes = selectable.snowflake_plan.attributes

            copied_selectable = copy.deepcopy(selectable)

            # Verify attributes were deepcopied: present, a distinct container,
            # and each Attribute element is a distinct object with equal name.
            assert copied_selectable._attributes is not None
            assert copied_selectable._attributes is not selectable._attributes
            for copied_attr, original_attr in zip(
                copied_selectable._attributes, selectable._attributes
            ):
                assert copied_attr is not original_attr
                assert copied_attr.name == original_attr.name

@pytest.mark.parametrize(
    "reduce_describe_enabled,cte_enabled",
    [
        (False, False),
        (True, False),
        (False, True),
    ],
)
def test_selectable_entity_deepcopy_attributes_with_flags_disabled(
    session, reduce_describe_enabled, cte_enabled
):
    """Verify _attributes is NOT copied when either flag is False."""
    temp_table_name = random_name_for_temp_object(TempObjectType.TABLE)
    session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).write.save_as_table(
        temp_table_name, table_type="temp"
    )

    # Exercise all flag combinations except (True, True) — that case is covered by
    # the companion flags-enabled test; here deepcopy must drop _attributes.
    with mock.patch.object(
        session, "_reduce_describe_query_enabled", reduce_describe_enabled
    ), mock.patch.object(session, "_cte_optimization_enabled", cte_enabled):
        # Apply a filter to get a SelectStatement with SelectableEntity as child
        df = session.table(temp_table_name).filter(col("a") == 1)
        if session.sql_simplifier_enabled:
            # The SelectableEntity child only exists under the SQL simplifier,
            # so the whole verification is a no-op when it is disabled.
            assert len(df._plan.children_plan_nodes) == 1
            assert isinstance(df._plan.children_plan_nodes[0], SelectableEntity)
            selectable = df._plan.children_plan_nodes[0]
            # Set _attributes to simulate cached attributes
            selectable._attributes = selectable.snowflake_plan.attributes

            copied_selectable = copy.deepcopy(selectable)

            # Verify attributes were NOT copied: the deepcopy starts from a fresh
            # instance whose _attributes default is None and is left untouched.
            assert copied_selectable._attributes is None
Loading
Loading