Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,11 @@ eval = [
test = [
# go/keep-sorted start
"a2a-sdk>=0.3.0,<0.4.0;python_version>='3.10'",
"aiosqlite>=0.21.0", # For database session service tests
"anthropic>=0.43.0", # For anthropic model tests
"kubernetes>=29.0.0", # For GkeCodeExecutor
"langchain-community>=0.3.17",
"langgraph>=0.2.60, <= 0.4.10", # For LangGraphAgent
"langgraph>=0.2.60, <= 0.4.10", # For LangGraphAgent
"litellm>=1.75.5, <2.0.0", # For LiteLLM tests
"llama-index-readers-file>=0.4.0", # For retrieval tests
"openai>=1.100.2", # For LiteLLM
Expand Down
118 changes: 69 additions & 49 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import annotations

import asyncio
import copy
from datetime import datetime
from datetime import timezone
Expand All @@ -29,20 +30,21 @@
from sqlalchemy import event
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import Text
from sqlalchemy.dialects import mysql
from sqlalchemy.dialects import postgresql
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.exc import ArgumentError
from sqlalchemy.ext.asyncio import async_sessionmaker
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session as DatabaseSessionFactory
from sqlalchemy.orm import sessionmaker
from sqlalchemy.schema import MetaData
from sqlalchemy.types import DateTime
from sqlalchemy.types import PickleType
Expand Down Expand Up @@ -390,11 +392,11 @@ def __init__(self, db_url: str, **kwargs: Any):
# 2. Create all tables based on schema
# 3. Initialize all properties
try:
db_engine = create_engine(db_url, **kwargs)
db_engine = create_async_engine(db_url, **kwargs)

if db_engine.dialect.name == "sqlite":
# Set sqlite pragma to enable foreign keys constraints
event.listen(db_engine, "connect", set_sqlite_pragma)
event.listen(db_engine.sync_engine, "connect", set_sqlite_pragma)

except Exception as e:
if isinstance(e, ArgumentError):
Expand All @@ -413,18 +415,30 @@ def __init__(self, db_url: str, **kwargs: Any):
local_timezone = get_localzone()
logger.info("Local timezone: %s", local_timezone)

self.db_engine: Engine = db_engine
self.db_engine: AsyncEngine = db_engine
self.metadata: MetaData = MetaData()
self.inspector = inspect(self.db_engine)

# DB session factory method
self.database_session_factory: sessionmaker[DatabaseSessionFactory] = (
sessionmaker(bind=self.db_engine)
)

# Uncomment to recreate DB every time
# Base.metadata.drop_all(self.db_engine)
Base.metadata.create_all(self.db_engine)
self.database_session_factory: async_sessionmaker[
DatabaseSessionFactory
] = async_sessionmaker(bind=self.db_engine)

# Flag to indicate if tables are created
self._tables_created = False
# Lock to ensure thread-safe table creation
self._table_creation_lock = asyncio.Lock()

async def _ensure_tables_created(self):
"""Ensure database tables are created. This is called lazily."""
if self._tables_created:
return

async with self._table_creation_lock:
# Double-check after acquiring the lock
if not self._tables_created:
async with self.db_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
self._tables_created = True

@override
async def create_session(
Expand All @@ -440,12 +454,11 @@ async def create_session(
# 3. Add the object to the table
# 4. Build the session object with generated id
# 5. Return the session

with self.database_session_factory() as sql_session:

await self._ensure_tables_created()
async with self.database_session_factory() as sql_session:
# Fetch app and user states from storage
storage_app_state = sql_session.get(StorageAppState, (app_name))
storage_user_state = sql_session.get(
storage_app_state = await sql_session.get(StorageAppState, (app_name))
storage_user_state = await sql_session.get(
StorageUserState, (app_name, user_id)
)

Expand Down Expand Up @@ -485,9 +498,9 @@ async def create_session(
state=session_state,
)
sql_session.add(storage_session)
sql_session.commit()
await sql_session.commit()

sql_session.refresh(storage_session)
await sql_session.refresh(storage_session)

# Merge states for response
merged_state = _merge_state(app_state, user_state, session_state)
Expand All @@ -503,11 +516,12 @@ async def get_session(
session_id: str,
config: Optional[GetSessionConfig] = None,
) -> Optional[Session]:
await self._ensure_tables_created()
# 1. Get the storage session entry from session table
# 2. Get all the events based on session id and filtering config
# 3. Convert and return the session
with self.database_session_factory() as sql_session:
storage_session = sql_session.get(
async with self.database_session_factory() as sql_session:
storage_session = await sql_session.get(
StorageSession, (app_name, user_id, session_id)
)
if storage_session is None:
Expand All @@ -519,24 +533,24 @@ async def get_session(
else:
timestamp_filter = True

storage_events = (
sql_session.query(StorageEvent)
stmt = (
select(StorageEvent)
.filter(StorageEvent.app_name == app_name)
.filter(StorageEvent.session_id == storage_session.id)
.filter(StorageEvent.user_id == user_id)
.filter(timestamp_filter)
.order_by(StorageEvent.timestamp.desc())
.limit(
config.num_recent_events
if config and config.num_recent_events
else None
)
.all()
)

if config and config.num_recent_events:
stmt = stmt.limit(config.num_recent_events)

result = await sql_session.execute(stmt)
storage_events = result.scalars().all()

# Fetch states from storage
storage_app_state = sql_session.get(StorageAppState, (app_name))
storage_user_state = sql_session.get(
storage_app_state = await sql_session.get(StorageAppState, (app_name))
storage_user_state = await sql_session.get(
StorageUserState, (app_name, user_id)
)

Expand All @@ -556,17 +570,19 @@ async def get_session(
async def list_sessions(
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
with self.database_session_factory() as sql_session:
results = (
sql_session.query(StorageSession)
await self._ensure_tables_created()
async with self.database_session_factory() as sql_session:
stmt = (
select(StorageSession)
.filter(StorageSession.app_name == app_name)
.filter(StorageSession.user_id == user_id)
.all()
)
result = await sql_session.execute(stmt)
results = result.scalars().all()

# Fetch states from storage
storage_app_state = sql_session.get(StorageAppState, (app_name))
storage_user_state = sql_session.get(
storage_app_state = await sql_session.get(StorageAppState, (app_name))
storage_user_state = await sql_session.get(
StorageUserState, (app_name, user_id)
)

Expand All @@ -585,25 +601,27 @@ async def list_sessions(
async def delete_session(
self, app_name: str, user_id: str, session_id: str
) -> None:
with self.database_session_factory() as sql_session:
await self._ensure_tables_created()
async with self.database_session_factory() as sql_session:
stmt = delete(StorageSession).where(
StorageSession.app_name == app_name,
StorageSession.user_id == user_id,
StorageSession.id == session_id,
)
sql_session.execute(stmt)
sql_session.commit()
await sql_session.execute(stmt)
await sql_session.commit()

@override
async def append_event(self, session: Session, event: Event) -> Event:
await self._ensure_tables_created()
if event.partial:
return event

# 1. Check if timestamp is stale
# 2. Update session attributes based on event config
# 3. Store event to table
with self.database_session_factory() as sql_session:
storage_session = sql_session.get(
async with self.database_session_factory() as sql_session:
storage_session = await sql_session.get(
StorageSession, (session.app_name, session.user_id, session.id)
)

Expand All @@ -617,8 +635,10 @@ async def append_event(self, session: Session, event: Event) -> Event:
)

# Fetch states from storage
storage_app_state = sql_session.get(StorageAppState, (session.app_name))
storage_user_state = sql_session.get(
storage_app_state = await sql_session.get(
StorageAppState, (session.app_name)
)
storage_user_state = await sql_session.get(
StorageUserState, (session.app_name, session.user_id)
)

Expand Down Expand Up @@ -649,8 +669,8 @@ async def append_event(self, session: Session, event: Event) -> Event:

sql_session.add(StorageEvent.from_event(session, event))

sql_session.commit()
sql_session.refresh(storage_session)
await sql_session.commit()
await sql_session.refresh(storage_session)

# Update timestamp with commit time
session.last_update_time = storage_session.update_timestamp_tz
Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/sessions/test_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_session_service(
):
"""Creates a session service for testing."""
if service_type == SessionServiceType.DATABASE:
return DatabaseSessionService('sqlite:///:memory:')
return DatabaseSessionService('sqlite+aiosqlite:///:memory:')
return InMemorySessionService()


Expand Down