152 changes: 152 additions & 0 deletions examples/pydantic_ai_examples/anthropic_prompt_caching.py
Collaborator:
Can we add a more basic example to the Anthropic docs, and drop this?

@@ -0,0 +1,152 @@
#!/usr/bin/env python3
"""Example demonstrating Anthropic prompt caching.

This example shows how to use CachePoint to reduce costs by caching:
- Long system prompts
- Large context (like documentation)
- Tool definitions

Run with: uv run -m pydantic_ai_examples.anthropic_prompt_caching
"""

from pydantic_ai import Agent, CachePoint

# Sample long context to demonstrate caching
# Need at least 1024 tokens - repeating 10x to be safe
LONG_CONTEXT = (
"""
# Product Documentation

## Overview
Our API provides comprehensive data access with the following features:

### Authentication
All requests require a Bearer token in the Authorization header.
Rate limits: 1000 requests/hour for standard tier.

### Endpoints

#### GET /api/users
Returns a list of users with pagination support.
Parameters:
- page: Page number (default: 1)
- limit: Items per page (default: 20, max: 100)
- filter: Optional filter expression

#### GET /api/products
Returns product catalog with detailed specifications.
Parameters:
- category: Filter by category
- in_stock: Boolean, filter available items
- sort: Sort order (price_asc, price_desc, name)

#### POST /api/orders
Create a new order. Requires authentication.
Request body:
- user_id: Integer, required
- items: Array of {product_id, quantity}
- shipping_address: Object with address details

#### Error Handling
Standard HTTP status codes are used:
- 200: Success
- 400: Bad request
- 401: Unauthorized
- 404: Not found
- 500: Server error

## Best Practices
1. Always handle rate limiting with exponential backoff
2. Cache responses where appropriate
3. Use pagination for large datasets
4. Validate input before submission
5. Monitor API usage through dashboard

## Code Examples
See detailed examples in our GitHub repository.
"""
* 10
) # Repeat 10x to ensure we exceed Anthropic's minimum cache size (1024 tokens)


async def main() -> None:
"""Demonstrate prompt caching with Anthropic."""
print('=== Anthropic Prompt Caching Demo ===\n')

agent = Agent(
'anthropic:claude-sonnet-4-5',
system_prompt='You are a helpful API documentation assistant.',
)

# First request with cache point - this will write to cache
print('First request (will cache context)...')
result1 = await agent.run(
[
LONG_CONTEXT,
CachePoint(), # Everything before this will be cached
'What authentication method does the API use?',
]
)

print(f'Response: {result1.output}\n')
usage1 = result1.usage()
print(f'Usage: {usage1}')
if usage1.cache_write_tokens:
print(
f' Cache write tokens: {usage1.cache_write_tokens} (tokens written to cache)'
)
print()

# Second request with same cached context - should use cache
print('Second request (should read from cache)...')
result2 = await agent.run(
[
LONG_CONTEXT,
CachePoint(), # Same content, should hit cache
'What are the available API endpoints?',
]
)

print(f'Response: {result2.output}\n')
usage2 = result2.usage()
print(f'Usage: {usage2}')
if usage2.cache_read_tokens:
print(
f' Cache read tokens: {usage2.cache_read_tokens} (tokens read from cache)'
)
print(
f' Cache savings: ~{usage2.cache_read_tokens * 0.9:.0f} token-equivalents (90% discount)'
)
print()

# Third request with different question, same cache
print('Third request (should also read from cache)...')
result3 = await agent.run(
[
LONG_CONTEXT,
CachePoint(),
'How should I handle rate limiting?',
]
)

print(f'Response: {result3.output}\n')
usage3 = result3.usage()
print(f'Usage: {usage3}')
if usage3.cache_read_tokens:
print(f' Cache read tokens: {usage3.cache_read_tokens}')
print()

print('=== Summary ===')
total_usage = usage1 + usage2 + usage3
print(f'Total input tokens: {total_usage.input_tokens}')
print(f'Total cache write: {total_usage.cache_write_tokens}')
print(f'Total cache read: {total_usage.cache_read_tokens}')
if total_usage.cache_read_tokens:
savings = total_usage.cache_read_tokens * 0.9
print(f'Estimated savings: ~{savings:.0f} token-equivalents')


if __name__ == '__main__':
import asyncio

asyncio.run(main())
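For context on the "~90% savings" estimate printed by the example above, a rough input-side cost sketch follows. The multipliers follow Anthropic's documented prompt-caching pricing for the default 5-minute cache (writes cost about 25% more than base input, reads about 10% of base input); the base price and token counts are illustrative assumptions, not values from this PR.

# Rough input-side cost arithmetic behind the ~90% figure (illustrative sketch).
BASE_INPUT_PRICE_PER_MTOK = 3.00  # USD per million input tokens (assumed base price)
CACHE_WRITE_MULTIPLIER = 1.25  # cache writes cost ~25% more than base input (5-minute cache)
CACHE_READ_MULTIPLIER = 0.10  # cache reads cost ~10% of base input, i.e. the ~90% discount


def estimate_input_cost(input_tokens: int, cache_write_tokens: int, cache_read_tokens: int) -> float:
    """Estimate input-side cost in USD; the three counts are assumed to be disjoint."""
    per_token = BASE_INPUT_PRICE_PER_MTOK / 1_000_000
    return (
        input_tokens * per_token
        + cache_write_tokens * per_token * CACHE_WRITE_MULTIPLIER
        + cache_read_tokens * per_token * CACHE_READ_MULTIPLIER
    )


# First request writes ~2000 context tokens to the cache, later requests read them back:
print(estimate_input_cost(50, 2000, 0))  # ~= 0.0077 USD
print(estimate_input_cost(50, 0, 2000))  # ~= 0.0008 USD, roughly 10x cheaper on the cached portion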
2 changes: 2 additions & 0 deletions pydantic_ai_slim/pydantic_ai/__init__.py
@@ -42,6 +42,7 @@
BinaryImage,
BuiltinToolCallPart,
BuiltinToolReturnPart,
CachePoint,
DocumentFormat,
DocumentMediaType,
DocumentUrl,
@@ -141,6 +142,7 @@
'BinaryContent',
'BuiltinToolCallPart',
'BuiltinToolReturnPart',
'CachePoint',
'DocumentFormat',
'DocumentMediaType',
'DocumentUrl',
17 changes: 16 additions & 1 deletion pydantic_ai_slim/pydantic_ai/messages.py
@@ -612,8 +612,20 @@ def __init__(
raise ValueError('`BinaryImage` must have a media type that starts with "image/"') # pragma: no cover


@dataclass
class CachePoint:
"""A cache point marker for prompt caching.

Can be inserted into UserPromptPart.content to mark cache boundaries.
Models that don't support caching will filter these out.
Collaborator suggested change for the line above:
Supported by:
- Anthropic

"""

kind: Literal['cache-point'] = 'cache-point'
"""Type identifier, this is available on all parts as a discriminator."""


MultiModalContent = ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent
UserContent: TypeAlias = str | MultiModalContent
UserContent: TypeAlias = str | MultiModalContent | CachePoint


@dataclass(repr=False)
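A minimal usage sketch for the CachePoint marker introduced above, assuming the placeholder strings below; it mirrors the example file in this PR and also shows the marker placed directly in UserPromptPart.content as the docstring describes.

from pydantic_ai import Agent, CachePoint
from pydantic_ai.messages import UserPromptPart

agent = Agent('anthropic:claude-sonnet-4-5')

# Everything before the CachePoint is eligible for caching on providers that
# support it (currently Anthropic); other models filter the marker out.
result = agent.run_sync(
    [
        'Some long, stable context...',  # placeholder text
        CachePoint(),
        'A short question about that context',
    ]
)

# The same marker can be placed directly in UserPromptPart.content:
part = UserPromptPart(content=['Some long, stable context...', CachePoint(), 'A question'])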
@@ -730,6 +742,9 @@ def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_me
if settings.include_content and settings.include_binary_content:
converted_part['content'] = base64.b64encode(part.data).decode()
parts.append(converted_part)
elif isinstance(part, CachePoint):
# CachePoint is a marker, not actual content - skip it for otel
pass
else:
parts.append({'type': part.kind}) # pragma: no cover
return parts
30 changes: 28 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -19,6 +19,7 @@
BinaryContent,
BuiltinToolCallPart,
BuiltinToolReturnPart,
CachePoint,
DocumentUrl,
FilePart,
FinishReason,
@@ -58,6 +59,7 @@
from anthropic.types.beta import (
BetaBase64PDFBlockParam,
BetaBase64PDFSourceParam,
BetaCacheControlEphemeralParam,
BetaCitationsDelta,
BetaCodeExecutionTool20250522Param,
BetaCodeExecutionToolResultBlock,
@@ -477,7 +479,10 @@ async def _map_message( # noqa: C901
system_prompt_parts.append(request_part.content)
elif isinstance(request_part, UserPromptPart):
async for content in self._map_user_prompt(request_part):
user_content_params.append(content)
if isinstance(content, CachePoint):
self._add_cache_control_to_last_param(user_content_params)
else:
user_content_params.append(content)
elif isinstance(request_part, ToolReturnPart):
tool_result_block_param = BetaToolResultBlockParam(
tool_use_id=_guard_tool_call_id(t=request_part),
@@ -639,10 +644,27 @@
system_prompt = '\n\n'.join(system_prompt_parts)
return system_prompt, anthropic_messages

@staticmethod
def _add_cache_control_to_last_param(params: list[BetaContentBlockParam]) -> None:
"""Add cache control to the last content block param."""
if not params:
raise UserError(
'CachePoint cannot be the first content in a user message - there must be previous content to attach the CachePoint to.'
)

Collaborator:
Copying in context from https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#what-can-be-cached:

- Tools: Tool definitions in the tools array
- System messages: Content blocks in the system array
- Text messages: Content blocks in the messages.content array, for both user and assistant turns
- Images & Documents: Content blocks in the messages.content array, in user turns
- Tool use and tool results: Content blocks in the messages.content array, in both user and assistant turns

I think we should support inserting a cache point after tool defs and system messages as well.

In the original PR I suggested doing this by supporting CachePoint as the first content in a user message (by adding it to whatever came before it: the system message, tool definition, or the last message of the assistant output), but that doesn't really feel natural from a code perspective.

What do you think about adding anthropic_cache_tools and anthropic_cache_instructions fields to AnthropicModelSettings, and setting cache_control on the relevant parts when set?

Author:
Seems reasonable, I'll look into it!
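A hedged sketch of the settings-based approach proposed in the comment above. The anthropic_cache_tools and anthropic_cache_instructions fields do not exist in this PR or in AnthropicModelSettings today; they are hypothetical names lifted from the comment, shown only to illustrate the suggested API shape.

from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModelSettings

# Hypothetical fields - proposed in the review comment, not implemented in this PR.
settings = AnthropicModelSettings(
    anthropic_cache_instructions=True,  # would attach cache_control to the last system prompt block
    anthropic_cache_tools=True,  # would attach cache_control to the last tool definition
)

agent = Agent(
    'anthropic:claude-sonnet-4-5',
    instructions='You are a helpful API documentation assistant.',
    model_settings=settings,
)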

# Only certain types support cache_control
cacheable_types = {'text', 'tool_use', 'server_tool_use', 'image', 'tool_result'}
Collaborator:
Can you please link to the doc this came from?

last_param = cast(dict[str, Any], params[-1]) # Cast to dict for mutation
if last_param['type'] not in cacheable_types:
raise UserError(f'Cache control not supported for param type: {last_param["type"]}')

# Add cache_control to the last param
last_param['cache_control'] = BetaCacheControlEphemeralParam(type='ephemeral')

@staticmethod
async def _map_user_prompt(
part: UserPromptPart,
) -> AsyncGenerator[BetaContentBlockParam]:
) -> AsyncGenerator[BetaContentBlockParam | CachePoint]:
if isinstance(part.content, str):
if part.content: # Only yield non-empty text
yield BetaTextBlockParam(text=part.content, type='text')
@@ -651,6 +673,8 @@ async def _map_user_prompt(
if isinstance(item, str):
if item: # Only yield non-empty text
yield BetaTextBlockParam(text=item, type='text')
elif isinstance(item, CachePoint):
yield item
elif isinstance(item, BinaryContent):
if item.is_image:
yield BetaImageBlockParam(
@@ -717,6 +741,8 @@ def _map_usage(
key: value for key, value in response_usage.model_dump().items() if isinstance(value, int)
}

# Note: genai-prices already extracts cache_creation_input_tokens and cache_read_input_tokens
# from the Anthropic response and maps them to cache_write_tokens and cache_read_tokens
return usage.RequestUsage.extract(
dict(model=model, usage=details),
provider=provider,
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/bedrock.py
@@ -18,6 +18,7 @@
BinaryContent,
BuiltinToolCallPart,
BuiltinToolReturnPart,
CachePoint,
DocumentUrl,
FinishReason,
ImageUrl,
@@ -624,6 +625,9 @@ async def _map_user_prompt(part: UserPromptPart, document_count: Iterator[int])
content.append({'video': video})
elif isinstance(item, AudioUrl): # pragma: no cover
raise NotImplementedError('Audio is not supported yet.')
elif isinstance(item, CachePoint):
# Bedrock doesn't support prompt caching via CachePoint in this implementation
pass
else:
assert_never(item)
return [{'role': 'user', 'content': content}]
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -21,6 +21,7 @@
BinaryContent,
BuiltinToolCallPart,
BuiltinToolReturnPart,
CachePoint,
FilePart,
FileUrl,
ModelMessage,
@@ -391,6 +392,9 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[_GeminiPartUnion]
else: # pragma: lax no cover
file_data = _GeminiFileDataPart(file_data={'file_uri': item.url, 'mime_type': item.media_type})
content.append(file_data)
elif isinstance(item, CachePoint):
# Gemini doesn't support prompt caching via CachePoint
pass
else:
assert_never(item) # pragma: lax no cover
return content
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/google.py
@@ -19,6 +19,7 @@
BinaryContent,
BuiltinToolCallPart,
BuiltinToolReturnPart,
CachePoint,
FilePart,
FileUrl,
FinishReason,
@@ -602,6 +603,9 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[PartDict]:
else:
file_data_dict: FileDataDict = {'file_uri': item.url, 'mime_type': item.media_type}
content.append({'file_data': file_data_dict}) # pragma: lax no cover
elif isinstance(item, CachePoint):
# Google Gemini doesn't support prompt caching via CachePoint
pass
else:
assert_never(item)
return content
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/huggingface.py
@@ -18,6 +18,7 @@
BinaryContent,
BuiltinToolCallPart,
BuiltinToolReturnPart,
CachePoint,
DocumentUrl,
FilePart,
FinishReason,
@@ -447,6 +448,9 @@ async def _map_user_prompt(part: UserPromptPart) -> ChatCompletionInputMessage:
raise NotImplementedError('DocumentUrl is not supported for Hugging Face')
elif isinstance(item, VideoUrl):
raise NotImplementedError('VideoUrl is not supported for Hugging Face')
elif isinstance(item, CachePoint):
# Hugging Face doesn't support prompt caching via CachePoint
pass
else:
assert_never(item)
return ChatCompletionInputMessage(role='user', content=content) # type: ignore
9 changes: 8 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -26,6 +26,7 @@
BinaryImage,
BuiltinToolCallPart,
BuiltinToolReturnPart,
CachePoint,
DocumentUrl,
FilePart,
FinishReason,
@@ -860,6 +861,9 @@ async def _map_user_prompt(self, part: UserPromptPart) -> chat.ChatCompletionUse
)
elif isinstance(item, VideoUrl): # pragma: no cover
raise NotImplementedError('VideoUrl is not supported for OpenAI')
elif isinstance(item, CachePoint):
# OpenAI doesn't support prompt caching via CachePoint, so we filter it out
pass
else:
assert_never(item)
return chat.ChatCompletionUserMessageParam(role='user', content=content)
@@ -1598,7 +1602,7 @@ def _map_json_schema(self, o: OutputObjectDefinition) -> responses.ResponseForma
return response_format_param

@staticmethod
async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessageParam:
async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessageParam: # noqa: C901
content: str | list[responses.ResponseInputContentParam]
if isinstance(part.content, str):
content = part.content
@@ -1673,6 +1677,9 @@ async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessagePa
)
elif isinstance(item, VideoUrl): # pragma: no cover
raise NotImplementedError('VideoUrl is not supported for OpenAI.')
elif isinstance(item, CachePoint):
# OpenAI doesn't support prompt caching via CachePoint, so we filter it out
pass
else:
assert_never(item)
return responses.EasyInputMessageParam(role='user', content=content)