-
Notifications
You must be signed in to change notification settings - Fork 1.9k
fix: Handle async generators in session state to prevent pickle errors #2958
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# Copyright 2025 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Utilities for safely copying session objects that may contain non-serializable objects.""" | ||
|
||
from __future__ import annotations | ||
|
||
import copy | ||
import inspect | ||
import logging | ||
from typing import Any | ||
|
||
logger = logging.getLogger('google_adk.' + __name__) | ||
|
||
|
||
def _is_async_generator(obj: Any) -> bool: | ||
"""Check if an object is an async generator.""" | ||
return inspect.isasyncgen(obj) | ||
|
||
|
||
def _filter_non_serializable_objects(obj: Any, path: str = "root") -> Any: | ||
"""Recursively filter out non-serializable objects from a data structure. | ||
Args: | ||
obj: The object to filter | ||
path: The current path in the object tree (for logging) | ||
Returns: | ||
A copy of the object with non-serializable objects removed | ||
""" | ||
if _is_async_generator(obj): | ||
logger.warning( | ||
f"Removing async generator from session state at {path}. " | ||
"Async generators cannot be persisted in session state." | ||
) | ||
return None | ||
|
||
if isinstance(obj, dict): | ||
filtered_dict = {} | ||
for key, value in obj.items(): | ||
filtered_value = _filter_non_serializable_objects(value, f"{path}.{key}") | ||
if filtered_value is not None: | ||
filtered_dict[key] = filtered_value | ||
return filtered_dict | ||
|
||
elif isinstance(obj, (list, tuple)): | ||
filtered_items = [] | ||
for i, item in enumerate(obj): | ||
filtered_item = _filter_non_serializable_objects(item, f"{path}[{i}]") | ||
if filtered_item is not None: | ||
filtered_items.append(filtered_item) | ||
return type(obj)(filtered_items) | ||
|
||
# For other types, assume they're serializable | ||
return obj | ||
|
||
|
||
def safe_deepcopy_session(session): | ||
"""Safely deepcopy a session object, filtering out non-serializable objects. | ||
This function creates a deep copy of a session while filtering out objects | ||
that cannot be pickled, such as async generators. | ||
Args: | ||
session: The session object to copy | ||
Returns: | ||
A deep copy of the session with non-serializable objects filtered out | ||
""" | ||
# Create a shallow copy first | ||
session_copy = copy.copy(session) | ||
|
||
# Deep copy the state while filtering non-serializable objects | ||
if hasattr(session_copy, 'state') and session_copy.state: | ||
session_copy.state = _filter_non_serializable_objects(session_copy.state, "state") | ||
# Now we can safely deepcopy the filtered state | ||
session_copy.state = copy.deepcopy(session_copy.state) | ||
|
||
# Deep copy other attributes that should be safe | ||
if hasattr(session_copy, 'events'): | ||
session_copy.events = copy.deepcopy(session.events) | ||
|
||
return session_copy | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,8 @@ | |
from .base_session_service import ListSessionsResponse | ||
from .session import Session | ||
from .state import State | ||
from ._session_copy_utils import safe_deepcopy_session | ||
from ._session_copy_utils import _filter_non_serializable_objects | ||
|
||
logger = logging.getLogger('google_adk.' + __name__) | ||
|
||
|
@@ -93,11 +95,15 @@ def _create_session_impl( | |
if session_id and session_id.strip() | ||
else str(uuid.uuid4()) | ||
) | ||
|
||
# Filter out non-serializable objects from the state before creating the session | ||
filtered_state = _filter_non_serializable_objects(state or {}, "initial_state") | ||
|
||
session = Session( | ||
app_name=app_name, | ||
user_id=user_id, | ||
id=session_id, | ||
state=state or {}, | ||
state=filtered_state, | ||
last_update_time=time.time(), | ||
) | ||
Comment on lines
+99
to
108
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The filtering of Additionally, this change introduces an import of a private function ( I suggest removing this filtering step and the associated import. This will centralize the filtering logic within session = Session(
app_name=app_name,
user_id=user_id,
id=session_id,
state=state or {},
last_update_time=time.time(),
) |
||
|
||
|
@@ -107,7 +113,7 @@ def _create_session_impl( | |
self.sessions[app_name][user_id] = {} | ||
self.sessions[app_name][user_id][session_id] = session | ||
|
||
copied_session = copy.deepcopy(session) | ||
copied_session = safe_deepcopy_session(session) | ||
return self._merge_state(app_name, user_id, copied_session) | ||
|
||
@override | ||
|
@@ -158,7 +164,7 @@ def _get_session_impl( | |
return None | ||
|
||
session = self.sessions[app_name][user_id].get(session_id) | ||
copied_session = copy.deepcopy(session) | ||
copied_session = safe_deepcopy_session(session) | ||
|
||
if config: | ||
if config.num_recent_events: | ||
|
@@ -222,7 +228,7 @@ def _list_sessions_impl( | |
|
||
sessions_without_events = [] | ||
for session in self.sessions[app_name][user_id].values(): | ||
copied_session = copy.deepcopy(session) | ||
copied_session = safe_deepcopy_session(session) | ||
copied_session.events = [] | ||
copied_session = self._merge_state(app_name, user_id, copied_session) | ||
sessions_without_events.append(copied_session) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
# Copyright 2025 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Tests for async generator handling in session services. | ||
This module tests the fix for issue #1862 where async generators in session | ||
state would cause pickle errors during deepcopy operations. | ||
""" | ||
|
||
import asyncio | ||
import pytest | ||
from typing import AsyncGenerator | ||
|
||
from google.adk.sessions.in_memory_session_service import InMemorySessionService | ||
from google.adk.sessions._session_copy_utils import ( | ||
_filter_non_serializable_objects, | ||
_is_async_generator, | ||
safe_deepcopy_session, | ||
) | ||
|
||
|
||
async def test_async_generator() -> AsyncGenerator[str, None]: | ||
"""A test async generator function.""" | ||
yield "test_message_1" | ||
yield "test_message_2" | ||
|
||
|
||
class TestAsyncGeneratorSessionHandling: | ||
"""Test class for async generator handling in sessions.""" | ||
|
||
def test_is_async_generator_detection(self): | ||
"""Test that async generators are correctly detected.""" | ||
async def regular_async_func(): | ||
return "not a generator" | ||
|
||
def regular_func(): | ||
return "regular function" | ||
|
||
# Test with actual async generator | ||
async_gen = test_async_generator() | ||
assert _is_async_generator(async_gen) is True | ||
|
||
# Test with non-generators | ||
assert _is_async_generator(regular_func) is False | ||
assert _is_async_generator("string") is False | ||
assert _is_async_generator(123) is False | ||
assert _is_async_generator([1, 2, 3]) is False | ||
assert _is_async_generator({"key": "value"}) is False | ||
|
||
# Clean up | ||
asyncio.run(async_gen.aclose()) | ||
|
||
def test_filter_non_serializable_objects(self): | ||
"""Test filtering of async generators from nested data structures.""" | ||
async_gen = test_async_generator() | ||
|
||
# Test simple case | ||
state = {"async_tool": async_gen, "normal_data": "test_value"} | ||
filtered = _filter_non_serializable_objects(state) | ||
|
||
assert "normal_data" in filtered | ||
assert filtered["normal_data"] == "test_value" | ||
assert "async_tool" not in filtered | ||
|
||
# Test nested structure | ||
nested_state = { | ||
"level1": { | ||
"level2": { | ||
"async_gen": async_gen, | ||
"normal": "value" | ||
}, | ||
"other": "data" | ||
}, | ||
"top_level": "value" | ||
} | ||
|
||
filtered_nested = _filter_non_serializable_objects(nested_state) | ||
assert filtered_nested["level1"]["level2"]["normal"] == "value" | ||
assert "async_gen" not in filtered_nested["level1"]["level2"] | ||
assert filtered_nested["level1"]["other"] == "data" | ||
assert filtered_nested["top_level"] == "value" | ||
|
||
# Test list with async generator | ||
list_state = {"tools": [async_gen, "normal_tool"]} | ||
filtered_list = _filter_non_serializable_objects(list_state) | ||
assert len(filtered_list["tools"]) == 1 | ||
assert filtered_list["tools"][0] == "normal_tool" | ||
|
||
# Clean up | ||
asyncio.run(async_gen.aclose()) | ||
|
||
@pytest.mark.asyncio | ||
async def test_session_creation_with_async_generator(self): | ||
"""Test that session creation works with async generators in state.""" | ||
session_service = InMemorySessionService() | ||
async_gen = test_async_generator() | ||
|
||
# This should not raise an exception | ||
session = await session_service.create_session( | ||
app_name="test_app", | ||
user_id="test_user", | ||
state={ | ||
"streaming_tool": async_gen, | ||
"normal_data": "test_value" | ||
} | ||
) | ||
|
||
# The async generator should be filtered out | ||
assert "streaming_tool" not in session.state | ||
assert "normal_data" in session.state | ||
assert session.state["normal_data"] == "test_value" | ||
|
||
# Clean up | ||
await async_gen.aclose() | ||
|
||
@pytest.mark.asyncio | ||
async def test_session_operations_with_filtered_state(self): | ||
"""Test that all session operations work after filtering.""" | ||
session_service = InMemorySessionService() | ||
|
||
# Create session with normal state | ||
session = await session_service.create_session( | ||
app_name="test_app", | ||
user_id="test_user", | ||
state={"normal_data": "test_value"} | ||
) | ||
|
||
# Test get_session | ||
retrieved_session = await session_service.get_session( | ||
app_name="test_app", | ||
user_id="test_user", | ||
session_id=session.id | ||
) | ||
assert retrieved_session is not None | ||
assert retrieved_session.state["normal_data"] == "test_value" | ||
|
||
# Test list_sessions | ||
sessions_response = await session_service.list_sessions( | ||
app_name="test_app", | ||
user_id="test_user" | ||
) | ||
assert len(sessions_response.sessions) == 1 | ||
assert sessions_response.sessions[0].id == session.id | ||
|
||
def test_safe_deepcopy_session(self): | ||
"""Test the safe_deepcopy_session function.""" | ||
# This test would require creating a mock session object | ||
# For now, we test that the function exists and can be imported | ||
assert callable(safe_deepcopy_session) | ||
Comment on lines
+156
to
+160
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test for Please consider implementing a proper test case that creates a def test_safe_deepcopy_session(self):
"""Test the safe_deepcopy_session function."""
from google.adk.sessions.session import Session
async_gen = test_async_generator()
session = Session(
id="test_session",
app_name="test_app",
user_id="test_user",
state={"async_tool": async_gen, "normal_data": "value"},
events=[],
)
copied_session = safe_deepcopy_session(session)
# Original session should be untouched
assert "async_tool" in session.state
# Copied session should have generator removed
assert "async_tool" not in copied_session.state
assert "normal_data" in copied_session.state
assert copied_session.state["normal_data"] == "value"
# Ensure other attributes are copied
assert copied_session.id == session.id
# Clean up
asyncio.run(async_gen.aclose()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The implementation of
safe_deepcopy_session
can be improved for clarity and robustness.session
parameter and the function's return value should be type-hinted withSession
for better static analysis and readability. This will require importingSession
from.session
.hasattr
checks forstate
andevents
are redundant becauseSession
is a Pydantic model and these fields are guaranteed to exist. You can simply check for truthiness if you want to skip operations on empty collections.copy.deepcopy(session_copy.events)
should be used instead ofcopy.deepcopy(session.events)
. While it works here due to the shallow copy, it's less confusing to operate on thesession_copy
object throughout.