mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Add prompt caching (Sonnet, Haiku only) (#3411)
* Add prompt caching * remove anthropic-version from extra_headers * change supports_prompt_caching method to attribute * change caching strat and log cache statistics * add reminder as a new message to fix caching * fix unit test * append reminder to the end of the last message content * move token logs to post completion function * fix unit test failure * fix reminder and prompt caching * unit tests for prompt caching * add test * clean up tests * separate reminder, use latest two messages * fix tests --------- Co-authored-by: tobitege <10787084+tobitege@users.noreply.github.com> Co-authored-by: Xingyao Wang <xingyao6@illinois.edu> Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
This commit is contained in:
parent
e72dc96d13
commit
5bb931e4d6
@ -172,26 +172,44 @@ class CodeActAgent(Agent):
|
||||
# prepare what we want to send to the LLM
|
||||
messages = self._get_messages(state)
|
||||
|
||||
response = self.llm.completion(
|
||||
messages=[message.model_dump() for message in messages],
|
||||
stop=[
|
||||
params = {
|
||||
'messages': [message.model_dump() for message in messages],
|
||||
'stop': [
|
||||
'</execute_ipython>',
|
||||
'</execute_bash>',
|
||||
'</execute_browse>',
|
||||
],
|
||||
temperature=0.0,
|
||||
)
|
||||
'temperature': 0.0,
|
||||
}
|
||||
|
||||
if self.llm.supports_prompt_caching:
|
||||
params['extra_headers'] = {
|
||||
'anthropic-beta': 'prompt-caching-2024-07-31',
|
||||
}
|
||||
|
||||
response = self.llm.completion(**params)
|
||||
|
||||
return self.action_parser.parse(response)
|
||||
|
||||
def _get_messages(self, state: State) -> list[Message]:
|
||||
messages: list[Message] = [
|
||||
Message(
|
||||
role='system',
|
||||
content=[TextContent(text=self.prompt_manager.system_message)],
|
||||
content=[
|
||||
TextContent(
|
||||
text=self.prompt_manager.system_message,
|
||||
cache_prompt=self.llm.supports_prompt_caching, # Cache system prompt
|
||||
)
|
||||
],
|
||||
),
|
||||
Message(
|
||||
role='user',
|
||||
content=[TextContent(text=self.prompt_manager.initial_user_message)],
|
||||
content=[
|
||||
TextContent(
|
||||
text=self.prompt_manager.initial_user_message,
|
||||
cache_prompt=self.llm.supports_prompt_caching, # if the user asks the same query,
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
@ -214,6 +232,16 @@ class CodeActAgent(Agent):
|
||||
else:
|
||||
messages.append(message)
|
||||
|
||||
# Add caching to the last 2 user messages
|
||||
if self.llm.supports_prompt_caching:
|
||||
user_turns_processed = 0
|
||||
for message in reversed(messages):
|
||||
if message.role == 'user' and user_turns_processed < 2:
|
||||
message.content[
|
||||
-1
|
||||
].cache_prompt = True # Last item inside the message content
|
||||
user_turns_processed += 1
|
||||
|
||||
# the latest user message is important:
|
||||
# we want to remind the agent of the environment constraints
|
||||
latest_user_message = next(
|
||||
@ -225,25 +253,8 @@ class CodeActAgent(Agent):
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
# Get the last user text inside content
|
||||
if latest_user_message:
|
||||
latest_user_message_text = next(
|
||||
(
|
||||
t
|
||||
for t in reversed(latest_user_message.content)
|
||||
if isinstance(t, TextContent)
|
||||
)
|
||||
)
|
||||
# add a reminder to the prompt
|
||||
reminder_text = f'\n\nENVIRONMENT REMINDER: You have {state.max_iterations - state.iteration} turns left to complete the task. When finished reply with <finish></finish>.'
|
||||
|
||||
if latest_user_message_text:
|
||||
latest_user_message_text.text = (
|
||||
latest_user_message_text.text + reminder_text
|
||||
)
|
||||
else:
|
||||
latest_user_message_text = TextContent(text=reminder_text)
|
||||
latest_user_message.content.append(latest_user_message_text)
|
||||
latest_user_message.content.append(TextContent(text=reminder_text))
|
||||
|
||||
return messages
|
||||
|
||||
@ -11,6 +11,7 @@ class ContentType(Enum):
|
||||
|
||||
class Content(BaseModel):
|
||||
type: ContentType
|
||||
cache_prompt: bool = False
|
||||
|
||||
@model_serializer
|
||||
def serialize_model(self):
|
||||
@ -23,7 +24,13 @@ class TextContent(Content):
|
||||
|
||||
@model_serializer
|
||||
def serialize_model(self):
|
||||
return {'type': self.type.value, 'text': self.text}
|
||||
data: dict[str, str | dict[str, str]] = {
|
||||
'type': self.type.value,
|
||||
'text': self.text,
|
||||
}
|
||||
if self.cache_prompt:
|
||||
data['cache_control'] = {'type': 'ephemeral'}
|
||||
return data
|
||||
|
||||
|
||||
class ImageContent(Content):
|
||||
@ -35,6 +42,8 @@ class ImageContent(Content):
|
||||
images: list[dict[str, str | dict[str, str]]] = []
|
||||
for url in self.image_urls:
|
||||
images.append({'type': self.type.value, 'image_url': {'url': url}})
|
||||
if self.cache_prompt and images:
|
||||
images[-1]['cache_control'] = {'type': 'ephemeral'}
|
||||
return images
|
||||
|
||||
|
||||
|
||||
@ -35,6 +35,11 @@ __all__ = ['LLM']
|
||||
|
||||
message_separator = '\n\n----------\n\n'
|
||||
|
||||
cache_prompting_supported_models = [
|
||||
'claude-3-5-sonnet-20240620',
|
||||
'claude-3-haiku-20240307',
|
||||
]
|
||||
|
||||
|
||||
class LLM:
|
||||
"""The LLM class represents a Language Model instance.
|
||||
@ -58,6 +63,9 @@ class LLM:
|
||||
self.config = copy.deepcopy(config)
|
||||
self.metrics = metrics if metrics is not None else Metrics()
|
||||
self.cost_metric_supported = True
|
||||
self.supports_prompt_caching = (
|
||||
self.config.model in cache_prompting_supported_models
|
||||
)
|
||||
|
||||
# Set up config attributes with default values to prevent AttributeError
|
||||
LLMConfig.set_missing_attributes(self.config)
|
||||
@ -184,6 +192,7 @@ class LLM:
|
||||
|
||||
# log the response
|
||||
message_back = resp['choices'][0]['message']['content']
|
||||
|
||||
llm_response_logger.debug(message_back)
|
||||
|
||||
# post-process to log costs
|
||||
@ -421,19 +430,51 @@ class LLM:
|
||||
def supports_vision(self):
|
||||
return litellm.supports_vision(self.config.model)
|
||||
|
||||
def _post_completion(self, response: str) -> None:
|
||||
def _post_completion(self, response) -> None:
|
||||
"""Post-process the completion response."""
|
||||
try:
|
||||
cur_cost = self.completion_cost(response)
|
||||
except Exception:
|
||||
cur_cost = 0
|
||||
|
||||
stats = ''
|
||||
if self.cost_metric_supported:
|
||||
logger.info(
|
||||
'Cost: %.2f USD | Accumulated Cost: %.2f USD',
|
||||
stats = 'Cost: %.2f USD | Accumulated Cost: %.2f USD\n' % (
|
||||
cur_cost,
|
||||
self.metrics.accumulated_cost,
|
||||
)
|
||||
|
||||
usage = response.get('usage')
|
||||
|
||||
if usage:
|
||||
input_tokens = usage.get('prompt_tokens')
|
||||
output_tokens = usage.get('completion_tokens')
|
||||
|
||||
if input_tokens:
|
||||
stats += 'Input tokens: ' + str(input_tokens) + '\n'
|
||||
|
||||
if output_tokens:
|
||||
stats += 'Output tokens: ' + str(output_tokens) + '\n'
|
||||
|
||||
model_extra = usage.get('model_extra', {})
|
||||
|
||||
cache_creation_input_tokens = model_extra.get('cache_creation_input_tokens')
|
||||
if cache_creation_input_tokens:
|
||||
stats += (
|
||||
'Input tokens (cache write): '
|
||||
+ str(cache_creation_input_tokens)
|
||||
+ '\n'
|
||||
)
|
||||
|
||||
cache_read_input_tokens = model_extra.get('cache_read_input_tokens')
|
||||
if cache_read_input_tokens:
|
||||
stats += (
|
||||
'Input tokens (cache read): ' + str(cache_read_input_tokens) + '\n'
|
||||
)
|
||||
|
||||
if stats:
|
||||
logger.info(stats)
|
||||
|
||||
def get_token_count(self, messages):
|
||||
"""Get the number of tokens in a list of messages.
|
||||
|
||||
|
||||
210
tests/unit/test_prompt_caching.py
Normal file
210
tests/unit/test_prompt_caching.py
Normal file
@ -0,0 +1,210 @@
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from agenthub.codeact_agent.codeact_agent import CodeActAgent
|
||||
from openhands.core.config import AgentConfig, LLMConfig
|
||||
from openhands.events import EventSource, EventStream
|
||||
from openhands.events.action import CmdRunAction, MessageAction
|
||||
from openhands.events.observation import CmdOutputObservation
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.storage import get_file_store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm():
|
||||
llm = Mock(spec=LLM)
|
||||
llm.config = LLMConfig(model='claude-3-5-sonnet-20240620')
|
||||
llm.supports_prompt_caching = True
|
||||
return llm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event_stream(tmp_path):
|
||||
file_store = get_file_store('local', str(tmp_path))
|
||||
return EventStream('test_session', file_store)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def codeact_agent(mock_llm):
|
||||
config = AgentConfig()
|
||||
return CodeActAgent(mock_llm, config)
|
||||
|
||||
|
||||
def test_get_messages_with_reminder(codeact_agent, mock_event_stream):
|
||||
# Add some events to the stream
|
||||
mock_event_stream.add_event(MessageAction('Initial user message'), EventSource.USER)
|
||||
mock_event_stream.add_event(MessageAction('Sure!'), EventSource.AGENT)
|
||||
mock_event_stream.add_event(MessageAction('Hello, agent!'), EventSource.USER)
|
||||
mock_event_stream.add_event(MessageAction('Hello, user!'), EventSource.AGENT)
|
||||
mock_event_stream.add_event(MessageAction('Laaaaaaaast!'), EventSource.USER)
|
||||
|
||||
codeact_agent.reset()
|
||||
messages = codeact_agent._get_messages(
|
||||
Mock(history=mock_event_stream, max_iterations=5, iteration=0)
|
||||
)
|
||||
|
||||
assert (
|
||||
len(messages) == 6
|
||||
) # System, initial user + user message, agent message, last user message
|
||||
assert messages[0].content[0].cache_prompt
|
||||
assert messages[1].role == 'user'
|
||||
assert messages[1].content[0].text.endswith("LET'S START!")
|
||||
assert messages[1].content[1].text.endswith('Initial user message')
|
||||
assert messages[1].content[0].cache_prompt
|
||||
|
||||
assert messages[3].role == 'user'
|
||||
assert messages[3].content[0].text == ('Hello, agent!')
|
||||
assert messages[4].role == 'assistant'
|
||||
assert messages[4].content[0].text == 'Hello, user!'
|
||||
assert messages[5].role == 'user'
|
||||
assert messages[5].content[0].text.startswith('Laaaaaaaast!')
|
||||
assert messages[5].content[0].cache_prompt
|
||||
assert (
|
||||
messages[5]
|
||||
.content[1]
|
||||
.text.endswith(
|
||||
'ENVIRONMENT REMINDER: You have 5 turns left to complete the task. When finished reply with <finish></finish>.'
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_get_messages_prompt_caching(codeact_agent, mock_event_stream):
|
||||
# Add multiple user and agent messages
|
||||
for i in range(15):
|
||||
mock_event_stream.add_event(
|
||||
MessageAction(f'User message {i}'), EventSource.USER
|
||||
)
|
||||
mock_event_stream.add_event(
|
||||
MessageAction(f'Agent message {i}'), EventSource.AGENT
|
||||
)
|
||||
|
||||
codeact_agent.reset()
|
||||
messages = codeact_agent._get_messages(
|
||||
Mock(history=mock_event_stream, max_iterations=10, iteration=5)
|
||||
)
|
||||
|
||||
# Check that only the last two user messages have cache_prompt=True
|
||||
cached_user_messages = [
|
||||
msg for msg in messages if msg.role == 'user' and msg.content[0].cache_prompt
|
||||
]
|
||||
assert len(cached_user_messages) == 3 # Including the initial system message
|
||||
|
||||
# Verify that these are indeed the last two user messages
|
||||
assert cached_user_messages[0].content[0].text.startswith('Here is an example')
|
||||
assert cached_user_messages[1].content[0].text == 'User message 13'
|
||||
assert cached_user_messages[2].content[0].text.startswith('User message 14')
|
||||
|
||||
|
||||
def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
|
||||
# Add a mix of actions and observations
|
||||
message_action_1 = MessageAction(
|
||||
"Let's list the contents of the current directory."
|
||||
)
|
||||
mock_event_stream.add_event(message_action_1, EventSource.USER)
|
||||
|
||||
cmd_action_1 = CmdRunAction('ls -l', thought='List files in current directory')
|
||||
mock_event_stream.add_event(cmd_action_1, EventSource.AGENT)
|
||||
|
||||
cmd_observation_1 = CmdOutputObservation(
|
||||
content='total 0\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file1.txt\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file2.txt',
|
||||
command_id=cmd_action_1._id,
|
||||
command='ls -l',
|
||||
exit_code=0,
|
||||
)
|
||||
mock_event_stream.add_event(cmd_observation_1, EventSource.USER)
|
||||
|
||||
message_action_2 = MessageAction("Now, let's create a new directory.")
|
||||
mock_event_stream.add_event(message_action_2, EventSource.AGENT)
|
||||
|
||||
cmd_action_2 = CmdRunAction('mkdir new_directory', thought='Create a new directory')
|
||||
mock_event_stream.add_event(cmd_action_2, EventSource.AGENT)
|
||||
|
||||
cmd_observation_2 = CmdOutputObservation(
|
||||
content='',
|
||||
command_id=cmd_action_2._id,
|
||||
command='mkdir new_directory',
|
||||
exit_code=0,
|
||||
)
|
||||
mock_event_stream.add_event(cmd_observation_2, EventSource.USER)
|
||||
|
||||
codeact_agent.reset()
|
||||
messages = codeact_agent._get_messages(
|
||||
Mock(history=mock_event_stream, max_iterations=5, iteration=0)
|
||||
)
|
||||
|
||||
# Assert the presence of key elements in the messages
|
||||
assert (
|
||||
messages[1]
|
||||
.content[1]
|
||||
.text.startswith("Let's list the contents of the current directory.")
|
||||
) # user, included in the initial message
|
||||
assert any(
|
||||
'List files in current directory\n<execute_bash>\nls -l\n</execute_bash>'
|
||||
in msg.content[0].text
|
||||
for msg in messages
|
||||
) # agent
|
||||
assert any(
|
||||
'total 0\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file1.txt\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file2.txt'
|
||||
in msg.content[0].text
|
||||
for msg in messages
|
||||
) # user, observation
|
||||
assert any(
|
||||
"Now, let's create a new directory." in msg.content[0].text for msg in messages
|
||||
) # agent
|
||||
assert messages[4].content[1].text.startswith('Create a new directory') # agent
|
||||
assert any(
|
||||
'finished with exit code 0' in msg.content[0].text for msg in messages
|
||||
) # user, observation
|
||||
assert (
|
||||
messages[5].content[0].text.startswith('OBSERVATION:\n\n')
|
||||
) # user, observation
|
||||
|
||||
# prompt cache is added to the system message
|
||||
assert messages[0].content[0].cache_prompt
|
||||
# and the first initial user message
|
||||
assert messages[1].content[0].cache_prompt
|
||||
# and to the last two user messages
|
||||
assert messages[3].content[0].cache_prompt
|
||||
assert messages[5].content[0].cache_prompt
|
||||
|
||||
# reminder is added to the last user message
|
||||
assert 'ENVIRONMENT REMINDER: You have 5 turns' in messages[5].content[1].text
|
||||
|
||||
|
||||
def test_prompt_caching_headers(codeact_agent, mock_event_stream):
|
||||
# Setup
|
||||
mock_event_stream.add_event(MessageAction('Hello, agent!'), EventSource.USER)
|
||||
mock_event_stream.add_event(MessageAction('Hello, user!'), EventSource.AGENT)
|
||||
|
||||
mock_short_term_history = MagicMock()
|
||||
mock_short_term_history.get_last_user_message.return_value = 'Hello, agent!'
|
||||
|
||||
mock_state = Mock()
|
||||
mock_state.history = mock_short_term_history
|
||||
mock_state.max_iterations = 5
|
||||
mock_state.iteration = 0
|
||||
|
||||
codeact_agent.reset()
|
||||
|
||||
# Replace mock LLM completion with a function that checks headers and returns a structured response
|
||||
def check_headers(**kwargs):
|
||||
assert 'extra_headers' in kwargs
|
||||
assert 'anthropic-beta' in kwargs['extra_headers']
|
||||
assert kwargs['extra_headers']['anthropic-beta'] == 'prompt-caching-2024-07-31'
|
||||
|
||||
# Create a mock response with the expected structure
|
||||
mock_response = Mock()
|
||||
mock_response.choices = [Mock()]
|
||||
mock_response.choices[0].message = Mock()
|
||||
mock_response.choices[0].message.content = 'Hello! How can I assist you today?'
|
||||
return mock_response
|
||||
|
||||
codeact_agent.llm.completion = check_headers
|
||||
|
||||
# Act
|
||||
result = codeact_agent.step(mock_state)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, MessageAction)
|
||||
assert 'Hello! How can I assist you today?' in result.content
|
||||
Loading…
x
Reference in New Issue
Block a user