diff --git a/agenthub/codeact_agent/codeact_agent.py b/agenthub/codeact_agent/codeact_agent.py
index 6f397baa3fc3..e1bde044c758 100644
--- a/agenthub/codeact_agent/codeact_agent.py
+++ b/agenthub/codeact_agent/codeact_agent.py
@@ -201,6 +201,12 @@ def step(self, state: State) -> Action:
             ],
             'temperature': 0.0,
         }
+
+        if self.llm.is_caching_prompt_active():
+            params['extra_headers'] = {
+                'anthropic-beta': 'prompt-caching-2024-07-31',
+            }
+
         try:
             response = self.llm.completion(**params)
         except Exception:
@@ -217,7 +223,7 @@ def _get_messages(self, state: State) -> list[Message]:
                 content=[
                     TextContent(
                         text=self.prompt_manager.system_message,
-                        cache_prompt=self.llm.supports_prompt_caching,
+                        cache_prompt=self.llm.is_caching_prompt_active(),  # Cache system prompt
                     )
                 ],
             ),
@@ -226,7 +232,7 @@ def _get_messages(self, state: State) -> list[Message]:
                 content=[
                     TextContent(
                         text=self.prompt_manager.initial_user_message,
-                        cache_prompt=self.llm.supports_prompt_caching,
+                        cache_prompt=self.llm.is_caching_prompt_active(),  # if the user asks the same query,
                     )
                 ],
             ),
@@ -252,14 +258,14 @@ def _get_messages(self, state: State) -> list[Message]:
             messages.append(message)

         # Add caching to the last 2 user messages
-        if self.llm.supports_prompt_caching:
-            user_messages = list(
-                islice((m for m in reversed(messages) if m.role == 'user'), 2)
-            )
-            for message in user_messages:
-                message.content[
-                    -1
-                ].cache_prompt = True  # Last item inside the message content
+        if self.llm.is_caching_prompt_active():
+            user_turns_processed = 0
+            for message in reversed(messages):
+                if message.role == 'user' and user_turns_processed < 2:
+                    message.content[
+                        -1
+                    ].cache_prompt = True  # Last item inside the message content
+                    user_turns_processed += 1

         # The latest user message is important:
         # we want to remind the agent of the environment constraints
diff --git a/config.template.toml b/config.template.toml
index af26af7a49c2..9f65614b9af5 100644
--- a/config.template.toml
+++ b/config.template.toml
@@ -141,6 +141,9 @@ model = "gpt-4o"
 # Drop any unmapped (unsupported) params without causing an exception
 #drop_params = false

+# Using the prompt caching feature provided by the LLM
+#caching_prompt = false
+
 # Base URL for the OLLAMA API
 #ollama_base_url = ""

diff --git a/docs/modules/usage/llms/llms.md b/docs/modules/usage/llms/llms.md
index 9e0b05465c3c..8311c127a2fe 100644
--- a/docs/modules/usage/llms/llms.md
+++ b/docs/modules/usage/llms/llms.md
@@ -44,6 +44,7 @@ The following environment variables might be necessary for some LLMs/providers:
 * `LLM_EMBEDDING_DEPLOYMENT_NAME`
 * `LLM_DROP_PARAMS`
 * `LLM_DISABLE_VISION`
+* `LLM_CACHING_PROMPT`

 We have a few guides for running OpenHands with specific model providers:

diff --git a/openhands/core/config.py b/openhands/core/config.py
index 6fb7bbb95d19..6253cc5ac1de 100644
--- a/openhands/core/config.py
+++ b/openhands/core/config.py
@@ -52,6 +52,7 @@ class LLMConfig:
         ollama_base_url: The base URL for the OLLAMA API.
         drop_params: Drop any unmapped (unsupported) params without causing an exception.
         disable_vision: If model is vision capable, this option allows to disable image processing (useful for cost reduction).
+        caching_prompt: Using the prompt caching feature provided by the LLM.
     """

     model: str = 'gpt-4o'
@@ -80,6 +81,7 @@ class LLMConfig:
     ollama_base_url: str | None = None
     drop_params: bool | None = None
     disable_vision: bool | None = None
+    caching_prompt: bool = False

     def defaults_to_dict(self) -> dict:
         """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
diff --git a/openhands/core/schema/config.py b/openhands/core/schema/config.py
index 6c34cebe34d3..1272ebe655a5 100644
--- a/openhands/core/schema/config.py
+++ b/openhands/core/schema/config.py
@@ -21,6 +21,7 @@ class ConfigType(str, Enum):
     LLM_API_KEY = 'LLM_API_KEY'
     LLM_API_VERSION = 'LLM_API_VERSION'
     LLM_BASE_URL = 'LLM_BASE_URL'
+    LLM_CACHING_PROMPT = 'LLM_CACHING_PROMPT'
     LLM_CUSTOM_LLM_PROVIDER = 'LLM_CUSTOM_LLM_PROVIDER'
     LLM_DROP_PARAMS = 'LLM_DROP_PARAMS'
     LLM_EMBEDDING_BASE_URL = 'LLM_EMBEDDING_BASE_URL'
diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py
index 708321d005a5..dc6daba81370 100644
--- a/openhands/llm/llm.py
+++ b/openhands/llm/llm.py
@@ -70,11 +70,6 @@ def __init__(
         # Set up config attributes with default values to prevent AttributeError
         LLMConfig.set_missing_attributes(self.config)

-        self.supports_prompt_caching = (
-            self.vision_is_active()
-            and self.config.model in cache_prompting_supported_models
-        )
-
         # litellm actually uses base Exception here for unknown model
         self.model_info = None
         try:
@@ -190,7 +185,7 @@ def wrapper(*args, **kwargs):
             if debug_str:
                 debug_message += message_separator + debug_str

-            if self.supports_prompt_caching:
+            if self.is_caching_prompt_active():
                 # Anthropic-specific prompt caching
                 if 'claude-3' in self.config.model:
                     kwargs['extra_headers'] = {
@@ -467,6 +462,17 @@ def _supports_vision(self):
         except Exception:
             return False

+    def is_caching_prompt_active(self) -> bool:
+        """Check if prompt caching is enabled and supported for current model.
+
+        Returns:
+            boolean: True if prompt caching is active for the given model.
+        """
+        return (
+            self.config.caching_prompt is True
+            and self.config.model in cache_prompting_supported_models
+        )
+
     def _post_completion(self, response) -> None:
         """Post-process the completion response."""
         try:
diff --git a/tests/unit/test_prompt_caching.py b/tests/unit/test_prompt_caching.py
index 85a386311c43..7acf01413096 100644
--- a/tests/unit/test_prompt_caching.py
+++ b/tests/unit/test_prompt_caching.py
@@ -14,8 +14,8 @@
 @pytest.fixture
 def mock_llm():
     llm = Mock(spec=LLM)
-    llm.config = LLMConfig(model='claude-3-5-sonnet-20240620')
-    llm.supports_prompt_caching = True
+    llm.config = LLMConfig(model='claude-3-5-sonnet-20240620', caching_prompt=True)
+    llm.is_caching_prompt_active.return_value = True
     return llm