(feat) making prompt caching optional instead of enabled default (#3689)

* (feat) making prompt caching optional instead of enabled default

At present, only the Claude models support prompt caching as a experimental feature, therefore, this feature should be implemented as an optional setting rather than being enabled by default.

Signed-off-by: Yi Lin <teroincn@gmail.com>

* handle the conflict

* fix unittest mock return value

* fix lint error in whitespace

---------

Signed-off-by: Yi Lin <teroincn@gmail.com>
This commit is contained in:
niliy01 2024-09-06 00:52:26 +08:00 committed by GitHub
parent 5b7ab28511
commit 82a154f7e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 37 additions and 18 deletions

View File

@ -201,6 +201,12 @@ class CodeActAgent(Agent):
],
'temperature': 0.0,
}
if self.llm.is_caching_prompt_active():
params['extra_headers'] = {
'anthropic-beta': 'prompt-caching-2024-07-31',
}
try:
response = self.llm.completion(**params)
except Exception:
@ -217,7 +223,7 @@ class CodeActAgent(Agent):
content=[
TextContent(
text=self.prompt_manager.system_message,
cache_prompt=self.llm.supports_prompt_caching,
cache_prompt=self.llm.is_caching_prompt_active(), # Cache system prompt
)
],
),
@ -226,7 +232,7 @@ class CodeActAgent(Agent):
content=[
TextContent(
text=self.prompt_manager.initial_user_message,
cache_prompt=self.llm.supports_prompt_caching,
cache_prompt=self.llm.is_caching_prompt_active(), # if the user asks the same query,
)
],
),
@ -252,14 +258,14 @@ class CodeActAgent(Agent):
messages.append(message)
# Add caching to the last 2 user messages
if self.llm.supports_prompt_caching:
user_messages = list(
islice((m for m in reversed(messages) if m.role == 'user'), 2)
)
for message in user_messages:
message.content[
-1
].cache_prompt = True # Last item inside the message content
if self.llm.is_caching_prompt_active():
user_turns_processed = 0
for message in reversed(messages):
if message.role == 'user' and user_turns_processed < 2:
message.content[
-1
].cache_prompt = True # Last item inside the message content
user_turns_processed += 1
# The latest user message is important:
# we want to remind the agent of the environment constraints

View File

@ -141,6 +141,9 @@ model = "gpt-4o"
# Drop any unmapped (unsupported) params without causing an exception
#drop_params = false
# Using the prompt caching feature provided by the LLM
#caching_prompt = false
# Base URL for the OLLAMA API
#ollama_base_url = ""

View File

@ -44,6 +44,7 @@ The following environment variables might be necessary for some LLMs/providers:
* `LLM_EMBEDDING_DEPLOYMENT_NAME`
* `LLM_DROP_PARAMS`
* `LLM_DISABLE_VISION`
* `LLM_CACHING_PROMPT`
We have a few guides for running OpenHands with specific model providers:

View File

@ -52,6 +52,7 @@ class LLMConfig:
ollama_base_url: The base URL for the OLLAMA API.
drop_params: Drop any unmapped (unsupported) params without causing an exception.
disable_vision: If model is vision capable, this option allows to disable image processing (useful for cost reduction).
caching_prompt: Using the prompt caching feature provided by the LLM.
"""
model: str = 'gpt-4o'
@ -80,6 +81,7 @@ class LLMConfig:
ollama_base_url: str | None = None
drop_params: bool | None = None
disable_vision: bool | None = None
caching_prompt: bool = False
def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""

View File

@ -21,6 +21,7 @@ class ConfigType(str, Enum):
LLM_API_KEY = 'LLM_API_KEY'
LLM_API_VERSION = 'LLM_API_VERSION'
LLM_BASE_URL = 'LLM_BASE_URL'
LLM_CACHING_PROMPT = 'LLM_CACHING_PROMPT'
LLM_CUSTOM_LLM_PROVIDER = 'LLM_CUSTOM_LLM_PROVIDER'
LLM_DROP_PARAMS = 'LLM_DROP_PARAMS'
LLM_EMBEDDING_BASE_URL = 'LLM_EMBEDDING_BASE_URL'

View File

@ -70,11 +70,6 @@ class LLM:
# Set up config attributes with default values to prevent AttributeError
LLMConfig.set_missing_attributes(self.config)
self.supports_prompt_caching = (
self.vision_is_active()
and self.config.model in cache_prompting_supported_models
)
# litellm actually uses base Exception here for unknown model
self.model_info = None
try:
@ -190,7 +185,7 @@ class LLM:
if debug_str:
debug_message += message_separator + debug_str
if self.supports_prompt_caching:
if self.is_caching_prompt_active():
# Anthropic-specific prompt caching
if 'claude-3' in self.config.model:
kwargs['extra_headers'] = {
@ -467,6 +462,17 @@ class LLM:
except Exception:
return False
def is_caching_prompt_active(self) -> bool:
"""Check if prompt caching is enabled and supported for current model.
Returns:
boolean: True if prompt caching is active for the given model.
"""
return (
self.config.caching_prompt is True
and self.config.model in cache_prompting_supported_models
)
def _post_completion(self, response) -> None:
"""Post-process the completion response."""
try:

View File

@ -14,8 +14,8 @@ from openhands.storage import get_file_store
@pytest.fixture
def mock_llm():
llm = Mock(spec=LLM)
llm.config = LLMConfig(model='claude-3-5-sonnet-20240620')
llm.supports_prompt_caching = True
llm.config = LLMConfig(model='claude-3-5-sonnet-20240620', caching_prompt=True)
llm.is_caching_prompt_active.return_value = True
return llm