Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions src/google/adk/sessions/_session_copy_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for safely copying session objects that may contain non-serializable objects."""

from __future__ import annotations

import copy
import inspect
import logging
from typing import Any

logger = logging.getLogger('google_adk.' + __name__)


def _is_async_generator(obj: Any) -> bool:
"""Check if an object is an async generator."""
return inspect.isasyncgen(obj)


# Marker for "this value was dropped". It must be distinct from None so that
# legitimate None values stored in session state are preserved.
_REMOVED = object()


def _filter_or_mark_removed(obj: Any, path: str) -> Any:
  """Filter ``obj`` recursively, returning ``_REMOVED`` if it must be dropped."""
  if _is_async_generator(obj):
    logger.warning(
        'Removing async generator from session state at %s. '
        'Async generators cannot be persisted in session state.',
        path,
    )
    return _REMOVED

  if isinstance(obj, dict):
    filtered_dict = {}
    for key, value in obj.items():
      filtered_value = _filter_or_mark_removed(value, f'{path}.{key}')
      if filtered_value is not _REMOVED:
        filtered_dict[key] = filtered_value
    return filtered_dict

  if isinstance(obj, (list, tuple)):
    filtered_items = []
    for i, item in enumerate(obj):
      filtered_item = _filter_or_mark_removed(item, f'{path}[{i}]')
      if filtered_item is not _REMOVED:
        filtered_items.append(filtered_item)
    # NOTE(review): the container is rebuilt by passing one iterable to its
    # own type; types whose constructor takes positional fields (e.g.
    # namedtuples) are not supported by this call.
    return type(obj)(filtered_items)

  # For other types, assume they're serializable.
  return obj


def _filter_non_serializable_objects(obj: Any, path: str = "root") -> Any:
  """Recursively filter out non-serializable objects from a data structure.

  Only async generators are currently treated as non-serializable. Dicts are
  rebuilt as plain ``dict`` objects and lists/tuples are rebuilt with their
  own type; every other value is returned as-is (not copied).

  Values that are ``None`` are kept: an internal sentinel, not ``None``, marks
  removed entries, so ``{"a": None}`` stays ``{"a": None}``.

  Args:
    obj: The object to filter.
    path: The current path in the object tree (used in the warning log).

  Returns:
    A filtered copy of the container structure with async generators removed,
    or ``None`` when ``obj`` itself is an async generator.
  """
  filtered = _filter_or_mark_removed(obj, path)
  # Keep the historical top-level contract: a removed root is reported as None.
  return None if filtered is _REMOVED else filtered


def safe_deepcopy_session(session: Any) -> Any:
  """Safely deepcopy a session object, filtering out non-serializable objects.

  A plain ``copy.deepcopy`` fails on objects that cannot be pickled, such as
  async generators. This function shallow-copies the session, replaces
  ``state`` with a filtered deep copy, and deep-copies ``events``. Any other
  attributes are carried over by the shallow copy only.

  Args:
    session: The session object to copy. Attributes named ``state`` and
      ``events`` are handled when present.

  Returns:
    A copy of the session whose ``state`` and ``events`` are independent of
    the original and whose ``state`` has non-serializable objects removed.
  """
  # Shallow copy first so the original session is never reassigned.
  session_copy = copy.copy(session)

  state = getattr(session_copy, 'state', None)
  if state is not None:
    # Copy even when the state is empty: skipping it would leave the copy
    # sharing the very same dict as the original, so later writes to the
    # copy's state would leak into the stored session.
    session_copy.state = copy.deepcopy(
        _filter_non_serializable_objects(state, 'state')
    )

  # Deep copy other attributes that should be safe.
  events = getattr(session_copy, 'events', None)
  if events is not None:
    session_copy.events = copy.deepcopy(events)

  return session_copy
Comment on lines +81 to +94

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The implementation of safe_deepcopy_session can be improved for clarity and robustness.

  • The session parameter and the function's return value should be type-hinted with Session for better static analysis and readability. This will require importing Session from .session.
  • The hasattr checks for state and events are redundant because Session is a Pydantic model and these fields are guaranteed to exist. You can simply check for truthiness if you want to skip operations on empty collections.
  • For consistency, copy.deepcopy(session_copy.events) should be used instead of copy.deepcopy(session.events). While it works here due to the shallow copy, it's less confusing to operate on the session_copy object throughout.
Suggested change
# Create a shallow copy first
session_copy = copy.copy(session)
# Deep copy the state while filtering non-serializable objects
if hasattr(session_copy, 'state') and session_copy.state:
session_copy.state = _filter_non_serializable_objects(session_copy.state, "state")
# Now we can safely deepcopy the filtered state
session_copy.state = copy.deepcopy(session_copy.state)
# Deep copy other attributes that should be safe
if hasattr(session_copy, 'events'):
session_copy.events = copy.deepcopy(session.events)
return session_copy
# Create a shallow copy first
session_copy = copy.copy(session)
# Deep copy the state while filtering non-serializable objects
if session_copy.state:
session_copy.state = _filter_non_serializable_objects(session_copy.state, "state")
# Now we can safely deepcopy the filtered state
session_copy.state = copy.deepcopy(session_copy.state)
# Deep copy other attributes that should be safe
if session_copy.events:
session_copy.events = copy.deepcopy(session_copy.events)
return session_copy

14 changes: 10 additions & 4 deletions src/google/adk/sessions/in_memory_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from .base_session_service import ListSessionsResponse
from .session import Session
from .state import State
from ._session_copy_utils import safe_deepcopy_session
from ._session_copy_utils import _filter_non_serializable_objects

logger = logging.getLogger('google_adk.' + __name__)

Expand Down Expand Up @@ -93,11 +95,15 @@ def _create_session_impl(
if session_id and session_id.strip()
else str(uuid.uuid4())
)

# Filter out non-serializable objects from the state before creating the session
filtered_state = _filter_non_serializable_objects(state or {}, "initial_state")

session = Session(
app_name=app_name,
user_id=user_id,
id=session_id,
state=state or {},
state=filtered_state,
last_update_time=time.time(),
)
Comment on lines +99 to 108

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The filtering of state here in _create_session_impl is redundant. The session object is passed to safe_deepcopy_session a few lines later, which already performs this filtering before returning a copy to the caller.

Additionally, this change introduces an import of a private function (_filter_non_serializable_objects on line 32), which is generally discouraged.

I suggest removing this filtering step and the associated import. This will centralize the filtering logic within safe_deepcopy_session, simplify this function, and avoid importing a private member. The session object stored in memory can safely contain the non-serializable objects until it's copied for external use.

    session = Session(
        app_name=app_name,
        user_id=user_id,
        id=session_id,
        state=state or {},
        last_update_time=time.time(),
    )


Expand All @@ -107,7 +113,7 @@ def _create_session_impl(
self.sessions[app_name][user_id] = {}
self.sessions[app_name][user_id][session_id] = session

copied_session = copy.deepcopy(session)
copied_session = safe_deepcopy_session(session)
return self._merge_state(app_name, user_id, copied_session)

@override
Expand Down Expand Up @@ -158,7 +164,7 @@ def _get_session_impl(
return None

session = self.sessions[app_name][user_id].get(session_id)
copied_session = copy.deepcopy(session)
copied_session = safe_deepcopy_session(session)

if config:
if config.num_recent_events:
Expand Down Expand Up @@ -222,7 +228,7 @@ def _list_sessions_impl(

sessions_without_events = []
for session in self.sessions[app_name][user_id].values():
copied_session = copy.deepcopy(session)
copied_session = safe_deepcopy_session(session)
copied_session.events = []
copied_session = self._merge_state(app_name, user_id, copied_session)
sessions_without_events.append(copied_session)
Expand Down
160 changes: 160 additions & 0 deletions tests/unittests/sessions/test_async_generator_session_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for async generator handling in session services.
This module tests the fix for issue #1862 where async generators in session
state would cause pickle errors during deepcopy operations.
"""

import asyncio
import pytest
from typing import AsyncGenerator

from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.sessions._session_copy_utils import (
_filter_non_serializable_objects,
_is_async_generator,
safe_deepcopy_session,
)


async def test_async_generator() -> AsyncGenerator[str, None]:
    """Helper async generator that yields two fixed strings.

    Calling this function produces an async generator object, which the tests
    in this module place into session state.
    """
    yield "test_message_1"
    yield "test_message_2"


# The ``test_`` prefix matches pytest's default function discovery pattern.
# This is a helper, not a test, so opt it out of collection explicitly while
# keeping the name that the tests in this module reference.
test_async_generator.__test__ = False


class TestAsyncGeneratorSessionHandling:
    """Test class for async generator handling in sessions.

    Covers detection of async generators, recursive filtering of state, and
    the session service / copy helper paths that rely on that filtering.
    """

    def test_is_async_generator_detection(self):
        """Test that async generators are correctly detected."""

        async def regular_async_func():
            return "not a generator"

        def regular_func():
            return "regular function"

        # Test with actual async generator
        async_gen = test_async_generator()
        assert _is_async_generator(async_gen) is True

        # Test with non-generators. The functions themselves are passed (not
        # called), so no un-awaited coroutine is created.
        assert _is_async_generator(regular_async_func) is False
        assert _is_async_generator(regular_func) is False
        assert _is_async_generator("string") is False
        assert _is_async_generator(123) is False
        assert _is_async_generator([1, 2, 3]) is False
        assert _is_async_generator({"key": "value"}) is False

        # Clean up
        asyncio.run(async_gen.aclose())

    def test_filter_non_serializable_objects(self):
        """Test filtering of async generators from nested data structures."""
        async_gen = test_async_generator()

        # Test simple case
        state = {"async_tool": async_gen, "normal_data": "test_value"}
        filtered = _filter_non_serializable_objects(state)

        assert "normal_data" in filtered
        assert filtered["normal_data"] == "test_value"
        assert "async_tool" not in filtered

        # Test nested structure
        nested_state = {
            "level1": {
                "level2": {
                    "async_gen": async_gen,
                    "normal": "value",
                },
                "other": "data",
            },
            "top_level": "value",
        }

        filtered_nested = _filter_non_serializable_objects(nested_state)
        assert filtered_nested["level1"]["level2"]["normal"] == "value"
        assert "async_gen" not in filtered_nested["level1"]["level2"]
        assert filtered_nested["level1"]["other"] == "data"
        assert filtered_nested["top_level"] == "value"

        # Test list with async generator
        list_state = {"tools": [async_gen, "normal_tool"]}
        filtered_list = _filter_non_serializable_objects(list_state)
        assert len(filtered_list["tools"]) == 1
        assert filtered_list["tools"][0] == "normal_tool"

        # Clean up
        asyncio.run(async_gen.aclose())

    @pytest.mark.asyncio
    async def test_session_creation_with_async_generator(self):
        """Test that session creation works with async generators in state."""
        session_service = InMemorySessionService()
        async_gen = test_async_generator()

        # This should not raise an exception
        session = await session_service.create_session(
            app_name="test_app",
            user_id="test_user",
            state={
                "streaming_tool": async_gen,
                "normal_data": "test_value",
            },
        )

        # The async generator should be filtered out
        assert "streaming_tool" not in session.state
        assert "normal_data" in session.state
        assert session.state["normal_data"] == "test_value"

        # Clean up
        await async_gen.aclose()

    @pytest.mark.asyncio
    async def test_session_operations_with_filtered_state(self):
        """Test that all session operations work after filtering."""
        session_service = InMemorySessionService()

        # Create session with normal state
        session = await session_service.create_session(
            app_name="test_app",
            user_id="test_user",
            state={"normal_data": "test_value"},
        )

        # Test get_session
        retrieved_session = await session_service.get_session(
            app_name="test_app",
            user_id="test_user",
            session_id=session.id,
        )
        assert retrieved_session is not None
        assert retrieved_session.state["normal_data"] == "test_value"

        # Test list_sessions
        sessions_response = await session_service.list_sessions(
            app_name="test_app",
            user_id="test_user",
        )
        assert len(sessions_response.sessions) == 1
        assert sessions_response.sessions[0].id == session.id

    def test_safe_deepcopy_session(self):
        """Test that safe_deepcopy_session filters the copy, not the original."""
        # Imported locally so this test does not alter the module's imports.
        from google.adk.sessions.session import Session

        async_gen = test_async_generator()
        session = Session(
            id="test_session",
            app_name="test_app",
            user_id="test_user",
            state={"async_tool": async_gen, "normal_data": "value"},
            events=[],
        )

        copied_session = safe_deepcopy_session(session)

        # A distinct object must be returned and the original left untouched
        assert copied_session is not session
        assert "async_tool" in session.state

        # Copied session should have the generator removed
        assert "async_tool" not in copied_session.state
        assert "normal_data" in copied_session.state
        assert copied_session.state["normal_data"] == "value"

        # Ensure other attributes are copied
        assert copied_session.id == session.id
        assert copied_session.events == []

        # Clean up
        asyncio.run(async_gen.aclose())
Comment on lines +156 to +160

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This test for safe_deepcopy_session is currently a placeholder that only checks if the function is callable. It would be much more valuable to have a direct unit test that verifies its core logic: filtering non-serializable objects during a deep copy.

Please consider implementing a proper test case that creates a Session object with an async generator in its state, calls safe_deepcopy_session, and asserts that the generator is removed from the copied session's state.

    def test_safe_deepcopy_session(self):
        """Check that safe_deepcopy_session strips async generators from the copy."""
        from google.adk.sessions.session import Session

        async_gen = test_async_generator()
        initial_state = {"async_tool": async_gen, "normal_data": "value"}
        session = Session(
            id="test_session",
            app_name="test_app",
            user_id="test_user",
            state=initial_state,
            events=[],
        )

        copied_session = safe_deepcopy_session(session)
        copied_state = copied_session.state

        # The source session keeps its generator
        assert "async_tool" in session.state

        # The copy keeps only the ordinary entry
        assert "async_tool" not in copied_state
        assert "normal_data" in copied_state
        assert copied_state["normal_data"] == "value"

        # Scalar attributes carry over to the copy
        assert copied_session.id == session.id

        # Close the generator that was created for this test
        asyncio.run(async_gen.aclose())

Loading