Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions src/snowflake/snowpark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,14 @@ def configure_development_features(
_debug_eager_schema_validation = enable_eager_schema_validation

if enable_dataframe_trace_on_error or enable_trace_sql_errors_to_dataframe:
_enable_dataframe_trace_on_error = enable_dataframe_trace_on_error
_enable_trace_sql_errors_to_dataframe = enable_trace_sql_errors_to_dataframe
sessions = snowflake.snowpark.session._get_active_sessions(
require_at_least_one=False
)
try:
session = get_active_session()
if session is None:
_logger.warning(
"No active session found. Please create a session first and call "
"`configure_development_features()` after creating the session.",
)
return
_enable_dataframe_trace_on_error = enable_dataframe_trace_on_error
_enable_trace_sql_errors_to_dataframe = enable_trace_sql_errors_to_dataframe
session.ast_enabled = True
for active_session in sessions:
active_session._set_ast_enabled_internal(True)
except Exception as e:
_logger.warning(
f"Cannot enable AST collection in the session due to {str(e)}. Some development features may not work as expected.",
Expand Down
25 changes: 20 additions & 5 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,14 +332,16 @@ def _get_active_session() -> "Session":
raise SnowparkClientExceptionMessages.SERVER_NO_DEFAULT_SESSION()


def _get_active_sessions() -> Set["Session"]:
def _get_active_sessions(require_at_least_one: bool = True) -> Set["Session"]:
with _session_management_lock:
if len(_active_sessions) >= 1:
# TODO: This function is allowing unsafe access to a mutex protected data
# structure, we should ONLY use it in tests
return _active_sessions
else:
raise SnowparkClientExceptionMessages.SERVER_NO_DEFAULT_SESSION()
if require_at_least_one:
raise SnowparkClientExceptionMessages.SERVER_NO_DEFAULT_SESSION()
return []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Type mismatch: Returns an empty list [] but the function signature declares return type as Set["Session"]. This should return an empty set instead.

return set()

While iteration works with both lists and sets, this type inconsistency could cause runtime errors if callers use Set-specific methods like .add(), .union(), etc. or type checkers will flag this as an error.

Suggested change
return []
return set()

Spotted by Graphite Agent

Fix in Graphite


Is this helpful? React 👍 or 👎 to let us know.



def _add_session(session: "Session") -> None:
Expand Down Expand Up @@ -736,6 +738,16 @@ def __init__(
ast_enabled = False

set_ast_state(AstFlagSource.SERVER, ast_enabled)

# development features require AST to be enabled
from snowflake.snowpark.context import (
_enable_trace_sql_errors_to_dataframe,
_enable_dataframe_trace_on_error,
)

if _enable_trace_sql_errors_to_dataframe or _enable_dataframe_trace_on_error:
self._set_ast_enabled_internal(True)

# The complexity score lower bound is set to match COMPILATION_MEMORY_LIMIT
# in Snowflake. This is the limit where we start seeing compilation errors.
self._large_query_breakdown_complexity_bounds: Tuple[int, int] = (
Expand Down Expand Up @@ -973,9 +985,7 @@ def ast_enabled(self) -> bool:
"""
return is_ast_enabled()

@ast_enabled.setter
@experimental_parameter(version="1.33.0")
def ast_enabled(self, value: bool) -> None:
def _set_ast_enabled_internal(self, value: bool) -> None:
# TODO: we could send here explicit telemetry if a user changes the behavior.
# In addition, we could introduce a server-side parameter to enable AST capture or not.
# self._conn._telemetry_client.send_ast_enabled_telemetry(
Expand All @@ -998,6 +1008,11 @@ def ast_enabled(self, value: bool) -> None:
self._auto_clean_up_temp_table_enabled = False
set_ast_state(AstFlagSource.USER, value)

@ast_enabled.setter
@experimental_parameter(version="1.33.0")
def ast_enabled(self, value: bool) -> None:
self._set_ast_enabled_internal(value)

@property
def cte_optimization_enabled(self) -> bool:
"""Set to ``True`` to enable the CTE optimization (defaults to ``False``).
Expand Down
30 changes: 14 additions & 16 deletions tests/integ/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from snowflake.snowpark.context import get_active_session
import snowflake.snowpark.context as context
import snowflake.snowpark.session as session
from unittest import mock


Expand All @@ -14,35 +15,32 @@ def test_get_active_session(session):

def test_context_configure_development_features():
try:
# Test when get_active_session() returns None
with mock.patch.object(context, "get_active_session", return_value=None):
# Test when _get_active_sessions() returns None
with mock.patch.object(session, "_get_active_sessions", return_value=None):
context.configure_development_features(
enable_trace_sql_errors_to_dataframe=True
)
assert context._enable_trace_sql_errors_to_dataframe is False
assert context._enable_trace_sql_errors_to_dataframe is True
assert context._enable_dataframe_trace_on_error is False
assert context._debug_eager_schema_validation is False

# Test when get_active_session() throws an exception
with mock.patch.object(
context, "get_active_session", side_effect=RuntimeError("test")
):
context.configure_development_features(
enable_trace_sql_errors_to_dataframe=True
)
assert context._enable_trace_sql_errors_to_dataframe is False
assert context._enable_dataframe_trace_on_error is False
# Test when _get_active_sessions() returns a valid session
mock_session1 = mock.MagicMock()
mock_session1._set_ast_enabled_internal = mock.MagicMock()
mock_session2 = mock.MagicMock()
mock_session2._set_ast_enabled_internal = mock.MagicMock()

# Test when get_active_session() returns a valid session
mock_session = mock.MagicMock()
with mock.patch.object(
context, "get_active_session", return_value=mock_session
session, "_get_active_sessions", return_value=[mock_session1, mock_session2]
):
context.configure_development_features(
enable_trace_sql_errors_to_dataframe=True
)
assert context._enable_trace_sql_errors_to_dataframe is True
assert context._enable_dataframe_trace_on_error is False
assert mock_session.ast_enabled is True
mock_session1._set_ast_enabled_internal.assert_called_once_with(True)
mock_session2._set_ast_enabled_internal.assert_called_once_with(True)

finally:
context.configure_development_features(
enable_trace_sql_errors_to_dataframe=False
Expand Down
Loading