Skip to content

Add background tasks with Docket #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -279,3 +279,35 @@ python -m pytest
3. Commit your changes
4. Push to the branch
5. Create a Pull Request

## Running the Background Task Worker

The Redis Memory Server uses Docket for background task management. There are two ways to run the worker:

### 1. Using the Docket CLI

After installing the package, you can run the worker using the Docket CLI command:

```bash
docket worker --tasks agent_memory_server.docket_tasks:task_collection
```

You can customize the concurrency and redelivery timeout:

```bash
docket worker --tasks agent_memory_server.docket_tasks:task_collection --concurrency 5 --redelivery-timeout 60
```

### 2. Using Python Code

Alternatively, you can run the worker directly in Python:

```bash
python -m agent_memory_server.worker
```

With customization options:

```bash
python -m agent_memory_server.worker --concurrency 5 --redelivery-timeout 60
```
17 changes: 9 additions & 8 deletions agent_memory_server/api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Literal

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException

from agent_memory_server import long_term_memory, messages
from agent_memory_server.config import settings
from agent_memory_server.dependencies import get_background_tasks
from agent_memory_server.llms import get_model_config
from agent_memory_server.logging import get_logger
from agent_memory_server.models import (
Expand Down Expand Up @@ -130,14 +131,15 @@ async def get_session_memory(
async def put_session_memory(
session_id: str,
memory: SessionMemory,
background_tasks: BackgroundTasks,
background_tasks=Depends(get_background_tasks),
):
"""
Set session memory. Replaces existing session memory.

Args:
session_id: The session ID
memory: Messages and context to save
background_tasks: DocketBackgroundTasks instance (injected automatically)

Returns:
Acknowledgement response
Expand Down Expand Up @@ -179,26 +181,25 @@ async def delete_session_memory(

@router.post("/long-term-memory", response_model=AckResponse)
async def create_long_term_memory(
payload: CreateLongTermMemoryPayload, background_tasks: BackgroundTasks
payload: CreateLongTermMemoryPayload,
background_tasks=Depends(get_background_tasks),
):
"""
Create a long-term memory

Args:
payload: Long-term memory payload
background_tasks: DocketBackgroundTasks instance (injected automatically)

Returns:
Acknowledgement response
"""
redis = get_redis_conn()

if not settings.long_term_memory:
raise HTTPException(status_code=400, detail="Long-term memory is disabled")

await long_term_memory.index_long_term_memories(
redis=redis,
await background_tasks.add_task(
long_term_memory.index_long_term_memories,
memories=payload.memories,
background_tasks=background_tasks,
)
return AckResponse(status="ok")

Expand Down
4 changes: 4 additions & 0 deletions agent_memory_server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,9 @@ class Settings(BaseSettings):
redisvl_index_name: str = "memory"
redisvl_index_prefix: str = "memory"

# Docket settings
docket_name: str = "memory-server"
use_docket: bool = True


settings = Settings()
35 changes: 35 additions & 0 deletions agent_memory_server/dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from collections.abc import Callable
from typing import Any

from fastapi import BackgroundTasks

from agent_memory_server.config import settings


class DocketBackgroundTasks(BackgroundTasks):
"""A BackgroundTasks implementation that uses Docket."""

async def add_task(
self, func: Callable[..., Any], *args: Any, **kwargs: Any
) -> None:
"""Run tasks either directly or through Docket"""
from docket import Docket

if settings.use_docket:
async with Docket(
name=settings.docket_name,
url=settings.redis_url,
) as docket:
# Schedule task through Docket
await docket.add(func)(*args, **kwargs)
else:
await func(*args, **kwargs)


def get_background_tasks() -> DocketBackgroundTasks:
"""
Dependency function that returns a DocketBackgroundTasks instance.

This is used by API endpoints to inject a consistent background tasks object.
"""
return DocketBackgroundTasks()
43 changes: 43 additions & 0 deletions agent_memory_server/docket_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
Background task management using Docket.
"""

import logging

from docket import Docket

from agent_memory_server.config import settings
from agent_memory_server.long_term_memory import (
extract_memory_structure,
index_long_term_memories,
)
from agent_memory_server.summarization import summarize_session


logger = logging.getLogger(__name__)


# Register functions in the task collection for the CLI worker
task_collection = [
extract_memory_structure,
summarize_session,
index_long_term_memories,
]


async def register_tasks() -> None:
"""Register all task functions with Docket."""
if not settings.use_docket:
logger.info("Docket is disabled, skipping task registration")
return

# Initialize Docket client
async with Docket(
name=settings.docket_name,
url=settings.redis_url,
) as docket:
# Register all tasks
for task in task_collection:
docket.register(task)

logger.info(f"Registered {len(task_collection)} background tasks with Docket")
19 changes: 9 additions & 10 deletions agent_memory_server/long_term_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from functools import reduce

import nanoid
from fastapi import BackgroundTasks
from redis.asyncio import Redis
from redisvl.query import VectorQuery, VectorRangeQuery
from redisvl.utils.vectorize import OpenAITextVectorizer

from agent_memory_server.dependencies import get_background_tasks
from agent_memory_server.extraction import handle_extraction
from agent_memory_server.filters import (
CreatedAt,
Expand All @@ -25,6 +25,7 @@
)
from agent_memory_server.utils import (
Keys,
get_redis_conn,
get_search_index,
safe_get,
)
Expand All @@ -33,11 +34,10 @@
logger = logging.getLogger(__name__)


async def extract_memory_structure(
redis: Redis, _id: str, text: str, namespace: str | None
):
async def extract_memory_structure(_id: str, text: str, namespace: str | None):
redis = get_redis_conn()

# Process messages for topic/entity extraction
# TODO: Move into background task.
topics, entities = await handle_extraction(text)

# Convert lists to comma-separated strings for TAG fields
Expand Down Expand Up @@ -65,14 +65,13 @@ async def compact_long_term_memories(redis: Redis) -> None:


async def index_long_term_memories(
redis: Redis,
memories: list[LongTermMemory],
background_tasks: BackgroundTasks,
) -> None:
"""
Index long-term memories in Redis for search
"""

redis = get_redis_conn()
background_tasks = get_background_tasks()
vectorizer = OpenAITextVectorizer()
embeddings = await vectorizer.aembed_many(
[memory.text for memory in memories],
Expand Down Expand Up @@ -100,8 +99,8 @@ async def index_long_term_memories(
},
)

background_tasks.add_task(
extract_memory_structure, redis, id_, memory.text, memory.namespace
await background_tasks.add_task(
extract_memory_structure, id_, memory.text, memory.namespace
)

await pipe.execute()
Expand Down
45 changes: 43 additions & 2 deletions agent_memory_server/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys
from contextlib import asynccontextmanager

import uvicorn
Expand All @@ -7,6 +8,7 @@
from agent_memory_server import utils
from agent_memory_server.api import router as memory_router
from agent_memory_server.config import settings
from agent_memory_server.docket_tasks import register_tasks
from agent_memory_server.healthcheck import router as health_router
from agent_memory_server.llms import MODEL_CONFIGS, ModelProvider
from agent_memory_server.logging import configure_logging, get_logger
Expand Down Expand Up @@ -87,6 +89,20 @@ async def lifespan(app: FastAPI):
logger.error(f"Failed to ensure RediSearch index: {e}")
raise

# Initialize Docket for background tasks if enabled
if settings.use_docket:
try:
await register_tasks()
logger.info("Initialized Docket for background tasks")
logger.info("To run the worker, use one of these methods:")
logger.info(
"1. CLI: docket worker --tasks agent_memory_server.docket_tasks:task_collection"
)
logger.info("2. Python: python -m agent_memory_server.worker")
except Exception as e:
logger.error(f"Failed to initialize Docket: {e}")
raise

# Show available models
openai_models = [
model
Expand Down Expand Up @@ -138,6 +154,31 @@ def on_start_logger(port: int):

# Run the application
if __name__ == "__main__":
port = int(os.environ.get("PORT", "8000"))
# Parse command line arguments for port
port = settings.port

# Check if --port argument is provided
if "--port" in sys.argv:
try:
port_index = sys.argv.index("--port") + 1
if port_index < len(sys.argv):
port = int(sys.argv[port_index])
print(f"Using port from command line: {port}")
except (ValueError, IndexError):
# If conversion fails or index out of bounds, use default
print(f"Invalid port argument, using default: {port}")
else:
print(f"No port argument provided, using default: {port}")

# Explicitly unset the PORT environment variable if it exists
if "PORT" in os.environ:
port_val = os.environ.pop("PORT")
print(f"Removed environment variable PORT={port_val}")

on_start_logger(port)
uvicorn.run("agent_memory_server.main:app", host="0.0.0.0", port=port, reload=False)
uvicorn.run(
app, # Using the app instance directly
host="0.0.0.0",
port=port,
reload=False,
)
5 changes: 3 additions & 2 deletions agent_memory_server/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import sys

from fastapi import BackgroundTasks, HTTPException
from fastapi import HTTPException
from mcp.server.fastmcp import FastMCP
from mcp.server.fastmcp.prompts import base
from mcp.types import TextContent
Expand All @@ -13,6 +13,7 @@
search_long_term_memory as core_search_long_term_memory,
)
from agent_memory_server.config import settings
from agent_memory_server.dependencies import get_background_tasks
from agent_memory_server.models import (
AckResponse,
CreateLongTermMemoryPayload,
Expand Down Expand Up @@ -75,7 +76,7 @@ async def create_long_term_memories(
An acknowledgement response indicating success
"""
return await core_create_long_term_memory(
payload, background_tasks=BackgroundTasks()
payload, background_tasks=get_background_tasks()
)


Expand Down
Loading
Loading