(feat) making prompt caching optional instead of enabled default (All-Hands-AI#3689)

* (feat) making prompt caching optional instead of enabled default

At present, only the Claude models support prompt caching, and only as an experimental feature; therefore, it should be implemented as an optional setting rather than being enabled by default.

Signed-off-by: Yi Lin <[email protected]>

* handle the conflict

* fix unittest mock return value

* fix lint error in whitespace

---------

Signed-off-by: Yi Lin <[email protected]>
WannaTen authored Sep 5, 2024
1 parent 5b7ab28 commit 82a154f
Showing 7 changed files with 37 additions and 18 deletions.
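
Before the per-file diffs, a minimal, self-contained sketch of the opt-in gate this commit introduces; the names mirror the diff (caching_prompt, cache_prompting_supported_models, is_caching_prompt_active), but the dataclass and the model list below are illustrative stand-ins rather than the real OpenHands definitions:

from dataclasses import dataclass

# Illustrative subset; the diff refers to the real list as
# cache_prompting_supported_models in openhands/llm/llm.py.
CACHE_PROMPTING_SUPPORTED_MODELS = ['claude-3-5-sonnet-20240620']

@dataclass
class LLMConfigSketch:
    model: str = 'gpt-4o'
    caching_prompt: bool = False  # the new flag: opt-in, off by default

def is_caching_prompt_active(config: LLMConfigSketch) -> bool:
    # Caching is used only when the user enables it AND the model supports it.
    return config.caching_prompt and config.model in CACHE_PROMPTING_SUPPORTED_MODELS

print(is_caching_prompt_active(LLMConfigSketch()))  # False: flag off
print(is_caching_prompt_active(LLMConfigSketch(model='claude-3-5-sonnet-20240620')))  # False: flag still off
print(is_caching_prompt_active(LLMConfigSketch(model='claude-3-5-sonnet-20240620', caching_prompt=True)))  # True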
26 changes: 16 additions & 10 deletions agenthub/codeact_agent/codeact_agent.py
@@ -201,6 +201,12 @@ def step(self, state: State) -> Action:
],
'temperature': 0.0,
}

+if self.llm.is_caching_prompt_active():
+    params['extra_headers'] = {
+        'anthropic-beta': 'prompt-caching-2024-07-31',
+    }

try:
response = self.llm.completion(**params)
except Exception:
@@ -217,7 +223,7 @@ def _get_messages(self, state: State) -> list[Message]:
content=[
TextContent(
text=self.prompt_manager.system_message,
-cache_prompt=self.llm.supports_prompt_caching,
+cache_prompt=self.llm.is_caching_prompt_active(), # Cache system prompt
)
],
),
@@ -226,7 +232,7 @@
content=[
TextContent(
text=self.prompt_manager.initial_user_message,
-cache_prompt=self.llm.supports_prompt_caching,
+cache_prompt=self.llm.is_caching_prompt_active(), # if the user asks the same query,
)
],
),
@@ -252,14 +258,14 @@ def _get_messages(self, state: State) -> list[Message]:
messages.append(message)

# Add caching to the last 2 user messages
-if self.llm.supports_prompt_caching:
-    user_messages = list(
-        islice((m for m in reversed(messages) if m.role == 'user'), 2)
-    )
-    for message in user_messages:
-        message.content[
-            -1
-        ].cache_prompt = True # Last item inside the message content
+if self.llm.is_caching_prompt_active():
+    user_turns_processed = 0
+    for message in reversed(messages):
+        if message.role == 'user' and user_turns_processed < 2:
+            message.content[
+                -1
+            ].cache_prompt = True # Last item inside the message content
+            user_turns_processed += 1

# The latest user message is important:
# we want to remind the agent of the environment constraints
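
A runnable sketch of the "cache the last two user turns" walk added above, assuming a minimal Message/TextContent shape (a role plus a list of content items carrying a cache_prompt flag); these classes are stand-ins for the agent's real message types:

from dataclasses import dataclass, field

@dataclass
class TextContent:
    text: str
    cache_prompt: bool = False

@dataclass
class Message:
    role: str
    content: list[TextContent] = field(default_factory=list)

def mark_last_two_user_turns(messages: list[Message]) -> None:
    # Walk newest-to-oldest and flag the last content item of the two most
    # recent user turns, mirroring the loop in _get_messages above.
    user_turns_processed = 0
    for message in reversed(messages):
        if message.role == 'user' and user_turns_processed < 2:
            message.content[-1].cache_prompt = True
            user_turns_processed += 1

msgs = [
    Message('user', [TextContent('first task')]),
    Message('assistant', [TextContent('working on it')]),
    Message('user', [TextContent('follow-up')]),
    Message('user', [TextContent('latest instruction')]),
]
mark_last_two_user_turns(msgs)
print([(m.role, m.content[-1].cache_prompt) for m in msgs])
# [('user', False), ('assistant', False), ('user', True), ('user', True)]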
3 changes: 3 additions & 0 deletions config.template.toml
@@ -141,6 +141,9 @@ model = "gpt-4o"
# Drop any unmapped (unsupported) params without causing an exception
#drop_params = false

+# Using the prompt caching feature provided by the LLM
+#caching_prompt = false

# Base URL for the OLLAMA API
#ollama_base_url = ""

1 change: 1 addition & 0 deletions docs/modules/usage/llms/llms.md
@@ -44,6 +44,7 @@ The following environment variables might be necessary for some LLMs/providers:
* `LLM_EMBEDDING_DEPLOYMENT_NAME`
* `LLM_DROP_PARAMS`
* `LLM_DISABLE_VISION`
+* `LLM_CACHING_PROMPT`

We have a few guides for running OpenHands with specific model providers:

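
A hedged sketch of how the LLM_CACHING_PROMPT variable listed above could be parsed into the boolean config flag; OpenHands' actual environment loading lives in openhands/core/config.py and is more general, so treat this purely as an illustration:

import os

def caching_prompt_from_env(default: bool = False) -> bool:
    # Hypothetical helper: map the env var onto the caching_prompt field.
    raw = os.environ.get('LLM_CACHING_PROMPT')
    if raw is None:
        return default
    return raw.strip().lower() in ('1', 'true', 'yes', 'on')

os.environ['LLM_CACHING_PROMPT'] = 'true'
print(caching_prompt_from_env())  # True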
2 changes: 2 additions & 0 deletions openhands/core/config.py
@@ -52,6 +52,7 @@ class LLMConfig:
ollama_base_url: The base URL for the OLLAMA API.
drop_params: Drop any unmapped (unsupported) params without causing an exception.
disable_vision: If model is vision capable, this option allows to disable image processing (useful for cost reduction).
+caching_prompt: Using the prompt caching feature provided by the LLM.
"""

model: str = 'gpt-4o'
@@ -80,6 +81,7 @@ class LLMConfig:
ollama_base_url: str | None = None
drop_params: bool | None = None
disable_vision: bool | None = None
+caching_prompt: bool = False

def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
1 change: 1 addition & 0 deletions openhands/core/schema/config.py
@@ -21,6 +21,7 @@ class ConfigType(str, Enum):
LLM_API_KEY = 'LLM_API_KEY'
LLM_API_VERSION = 'LLM_API_VERSION'
LLM_BASE_URL = 'LLM_BASE_URL'
+LLM_CACHING_PROMPT = 'LLM_CACHING_PROMPT'
LLM_CUSTOM_LLM_PROVIDER = 'LLM_CUSTOM_LLM_PROVIDER'
LLM_DROP_PARAMS = 'LLM_DROP_PARAMS'
LLM_EMBEDDING_BASE_URL = 'LLM_EMBEDDING_BASE_URL'
18 changes: 12 additions & 6 deletions openhands/llm/llm.py
@@ -70,11 +70,6 @@ def __init__(
# Set up config attributes with default values to prevent AttributeError
LLMConfig.set_missing_attributes(self.config)

-self.supports_prompt_caching = (
-    self.vision_is_active()
-    and self.config.model in cache_prompting_supported_models
-)

# litellm actually uses base Exception here for unknown model
self.model_info = None
try:
@@ -190,7 +185,7 @@ def wrapper(*args, **kwargs):
if debug_str:
debug_message += message_separator + debug_str

-if self.supports_prompt_caching:
+if self.is_caching_prompt_active():
# Anthropic-specific prompt caching
if 'claude-3' in self.config.model:
kwargs['extra_headers'] = {
@@ -467,6 +462,17 @@ def _supports_vision(self):
except Exception:
return False

+def is_caching_prompt_active(self) -> bool:
+    """Check if prompt caching is enabled and supported for current model.
+    Returns:
+        boolean: True if prompt caching is active for the given model.
+    """
+    return (
+        self.config.caching_prompt is True
+        and self.config.model in cache_prompting_supported_models
+    )

def _post_completion(self, response) -> None:
"""Post-process the completion response."""
try:
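
A small sketch of the gate in the completion wrapper above: the Anthropic beta header is attached only when caching is active and the model is a claude-3 variant. The free function below is an illustrative stand-in for the logic inside the real LLM wrapper:

def add_caching_header(kwargs: dict, model: str, caching_active: bool) -> dict:
    # Only claude-3 models get the Anthropic prompt-caching beta header,
    # and only when the opt-in check passed.
    if caching_active and 'claude-3' in model:
        headers = kwargs.setdefault('extra_headers', {})
        headers['anthropic-beta'] = 'prompt-caching-2024-07-31'
    return kwargs

print(add_caching_header({}, 'claude-3-5-sonnet-20240620', True))
# {'extra_headers': {'anthropic-beta': 'prompt-caching-2024-07-31'}}
print(add_caching_header({}, 'gpt-4o', True))
# {}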
4 changes: 2 additions & 2 deletions tests/unit/test_prompt_caching.py
@@ -14,8 +14,8 @@
@pytest.fixture
def mock_llm():
llm = Mock(spec=LLM)
-llm.config = LLMConfig(model='claude-3-5-sonnet-20240620')
-llm.supports_prompt_caching = True
+llm.config = LLMConfig(model='claude-3-5-sonnet-20240620', caching_prompt=True)
+llm.is_caching_prompt_active.return_value = True
return llm


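
A self-contained sketch of what the fixture change above relies on: with Mock(spec=...), the caching check has to be stubbed as a method return value rather than assigned as a plain attribute. FakeLLM is a hypothetical stand-in so the snippet runs without OpenHands installed:

from unittest.mock import Mock

class FakeLLM:
    def is_caching_prompt_active(self) -> bool:
        raise NotImplementedError

llm = Mock(spec=FakeLLM)
llm.is_caching_prompt_active.return_value = True

assert llm.is_caching_prompt_active() is True
print('mocked caching check:', llm.is_caching_prompt_active())  # True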
