Add prompt caching (Sonnet, Haiku only) (#3411)

* Add prompt caching

* remove anthropic-version from extra_headers

* change supports_prompt_caching method to attribute

* change caching strategy and log cache statistics

* add reminder as a new message to fix caching

* fix unit test

* append reminder to the end of the last message content

* move token logs to post completion function

* fix unit test failure

* fix reminder and prompt caching

* unit tests for prompt caching

* add test

* clean up tests

* separate reminder, use latest two messages

* fix tests

---------

Co-authored-by: tobitege <10787084+tobitege@users.noreply.github.com>
Co-authored-by: Xingyao Wang <xingyao6@illinois.edu>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
Kaushik Deka 2024-08-26 17:46:44 -07:00 committed by GitHub
parent e72dc96d13
commit 5bb931e4d6
4 changed files with 300 additions and 29 deletions
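
For context on the mechanism being wired in: Anthropic prompt caching marks individual content blocks with a cache_control entry and requires the anthropic-beta: prompt-caching-2024-07-31 request header. A minimal, standalone sketch of the request shape these changes end up producing (the model name and prompt texts are placeholders, and passing the header through litellm's extra_headers is assumed to reach the provider):

import litellm

# Illustrative only: roughly the request that self.llm.completion(**params)
# ends up sending once caching is enabled.
response = litellm.completion(
    model='claude-3-5-sonnet-20240620',
    messages=[
        {
            'role': 'system',
            'content': [
                {
                    'type': 'text',
                    'text': 'You are a helpful coding agent...',  # placeholder
                    'cache_control': {'type': 'ephemeral'},  # cache everything up to here
                }
            ],
        },
        {
            'role': 'user',
            'content': [{'type': 'text', 'text': 'List the files in this repo.'}],  # placeholder
        },
    ],
    extra_headers={'anthropic-beta': 'prompt-caching-2024-07-31'},
)
print(response.choices[0].message.content)

Anything before the last cache_control marker can be reused on later calls, which is why the diff below pins markers on the stable prefix (system prompt, initial user message) and on the most recent user turns.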


@@ -172,26 +172,44 @@ class CodeActAgent(Agent):
         # prepare what we want to send to the LLM
         messages = self._get_messages(state)
-        response = self.llm.completion(
-            messages=[message.model_dump() for message in messages],
-            stop=[
+        params = {
+            'messages': [message.model_dump() for message in messages],
+            'stop': [
                 '</execute_ipython>',
                 '</execute_bash>',
                 '</execute_browse>',
             ],
-            temperature=0.0,
-        )
+            'temperature': 0.0,
+        }
+        if self.llm.supports_prompt_caching:
+            params['extra_headers'] = {
+                'anthropic-beta': 'prompt-caching-2024-07-31',
+            }
+        response = self.llm.completion(**params)
 
         return self.action_parser.parse(response)
 
     def _get_messages(self, state: State) -> list[Message]:
         messages: list[Message] = [
             Message(
                 role='system',
-                content=[TextContent(text=self.prompt_manager.system_message)],
+                content=[
+                    TextContent(
+                        text=self.prompt_manager.system_message,
+                        cache_prompt=self.llm.supports_prompt_caching,  # Cache system prompt
+                    )
+                ],
             ),
             Message(
                 role='user',
-                content=[TextContent(text=self.prompt_manager.initial_user_message)],
+                content=[
+                    TextContent(
+                        text=self.prompt_manager.initial_user_message,
+                        cache_prompt=self.llm.supports_prompt_caching,  # if the user asks the same query,
+                    )
+                ],
             ),
         ]
@@ -214,6 +232,16 @@ class CodeActAgent(Agent):
             else:
                 messages.append(message)
 
+        # Add caching to the last 2 user messages
+        if self.llm.supports_prompt_caching:
+            user_turns_processed = 0
+            for message in reversed(messages):
+                if message.role == 'user' and user_turns_processed < 2:
+                    message.content[
+                        -1
+                    ].cache_prompt = True  # Last item inside the message content
+                    user_turns_processed += 1
+
         # the latest user message is important:
         # we want to remind the agent of the environment constraints
         latest_user_message = next(
@@ -225,25 +253,8 @@ class CodeActAgent(Agent):
             ),
             None,
         )
-        # Get the last user text inside content
         if latest_user_message:
-            latest_user_message_text = next(
-                (
-                    t
-                    for t in reversed(latest_user_message.content)
-                    if isinstance(t, TextContent)
-                )
-            )
             # add a reminder to the prompt
             reminder_text = f'\n\nENVIRONMENT REMINDER: You have {state.max_iterations - state.iteration} turns left to complete the task. When finished reply with <finish></finish>.'
-            if latest_user_message_text:
-                latest_user_message_text.text = (
-                    latest_user_message_text.text + reminder_text
-                )
-            else:
-                latest_user_message_text = TextContent(text=reminder_text)
-                latest_user_message.content.append(latest_user_message_text)
+            latest_user_message.content.append(TextContent(text=reminder_text))
 
         return messages
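
Taken together, the strategy above sets cache breakpoints on the system prompt, the initial user message, and the last two user turns, and appends the environment reminder as a separate, uncached TextContent block on the latest user message. A hedged sketch (plain dicts, placeholder texts) of roughly what the serialized messages look like mid-conversation:

# Sketch of the serialized message list sent to the LLM when caching is enabled.
serialized_messages = [
    {'role': 'system',
     'content': [{'type': 'text', 'text': '<system prompt>',
                  'cache_control': {'type': 'ephemeral'}}]},
    {'role': 'user',
     'content': [{'type': 'text', 'text': '<initial user message>',
                  'cache_control': {'type': 'ephemeral'}}]},
    {'role': 'assistant', 'content': [{'type': 'text', 'text': '<agent reply>'}]},
    # second-to-last user turn: its last content block is cache-marked
    {'role': 'user',
     'content': [{'type': 'text', 'text': '<observation>',
                  'cache_control': {'type': 'ephemeral'}}]},
    {'role': 'assistant', 'content': [{'type': 'text', 'text': '<agent reply>'}]},
    # latest user turn: cache-marked, with the reminder appended as a separate block
    {'role': 'user',
     'content': [
         {'type': 'text', 'text': '<latest user message>',
          'cache_control': {'type': 'ephemeral'}},
         {'type': 'text', 'text': '\n\nENVIRONMENT REMINDER: You have 3 turns left '
                                  'to complete the task. When finished reply with '
                                  '<finish></finish>.'},
     ]},
]

Appending the reminder as its own block, after the cache marker has already been set, keeps the cached prefix byte-stable even though the remaining-turns count changes every step; that is what the "separate reminder" and "add reminder as a new message to fix caching" commits above refer to.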


@@ -11,6 +11,7 @@ class ContentType(Enum):
 
 class Content(BaseModel):
     type: ContentType
+    cache_prompt: bool = False
 
     @model_serializer
     def serialize_model(self):
@@ -23,7 +24,13 @@ class TextContent(Content):
 
     @model_serializer
     def serialize_model(self):
-        return {'type': self.type.value, 'text': self.text}
+        data: dict[str, str | dict[str, str]] = {
+            'type': self.type.value,
+            'text': self.text,
+        }
+        if self.cache_prompt:
+            data['cache_control'] = {'type': 'ephemeral'}
+        return data
 
 
 class ImageContent(Content):
@@ -35,6 +42,8 @@ class ImageContent(Content):
         images: list[dict[str, str | dict[str, str]]] = []
         for url in self.image_urls:
             images.append({'type': self.type.value, 'image_url': {'url': url}})
+        if self.cache_prompt and images:
+            images[-1]['cache_control'] = {'type': 'ephemeral'}
         return images
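
Usage-wise, the new flag only changes the serialized payload when it is actually set, so providers without caching see the exact same dict as before. A small sketch (the import path is a guess; adjust to wherever TextContent lives in this repo):

# Hypothetical import path, not confirmed by this diff.
from openhands.core.message import TextContent

plain = TextContent(text='Hello')
cached = TextContent(text='Hello', cache_prompt=True)

print(plain.model_dump())
# expected per the serializer above: {'type': 'text', 'text': 'Hello'}
print(cached.model_dump())
# expected per the serializer above:
# {'type': 'text', 'text': 'Hello', 'cache_control': {'type': 'ephemeral'}}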


@@ -35,6 +35,11 @@ __all__ = ['LLM']
 
 message_separator = '\n\n----------\n\n'
 
+cache_prompting_supported_models = [
+    'claude-3-5-sonnet-20240620',
+    'claude-3-haiku-20240307',
+]
+
 
 class LLM:
     """The LLM class represents a Language Model instance.
@@ -58,6 +63,9 @@ class LLM:
         self.config = copy.deepcopy(config)
         self.metrics = metrics if metrics is not None else Metrics()
         self.cost_metric_supported = True
+        self.supports_prompt_caching = (
+            self.config.model in cache_prompting_supported_models
+        )
 
         # Set up config attributes with default values to prevent AttributeError
         LLMConfig.set_missing_attributes(self.config)
@@ -184,6 +192,7 @@ class LLM:
             # log the response
            message_back = resp['choices'][0]['message']['content']
            llm_response_logger.debug(message_back)
+
            # post-process to log costs
@@ -421,19 +430,51 @@ class LLM:
     def supports_vision(self):
         return litellm.supports_vision(self.config.model)
 
-    def _post_completion(self, response: str) -> None:
+    def _post_completion(self, response) -> None:
         """Post-process the completion response."""
         try:
             cur_cost = self.completion_cost(response)
         except Exception:
             cur_cost = 0
+
+        stats = ''
         if self.cost_metric_supported:
-            logger.info(
-                'Cost: %.2f USD | Accumulated Cost: %.2f USD',
+            stats = 'Cost: %.2f USD | Accumulated Cost: %.2f USD\n' % (
                 cur_cost,
                 self.metrics.accumulated_cost,
             )
+
+        usage = response.get('usage')
+        if usage:
+            input_tokens = usage.get('prompt_tokens')
+            output_tokens = usage.get('completion_tokens')
+
+            if input_tokens:
+                stats += 'Input tokens: ' + str(input_tokens) + '\n'
+
+            if output_tokens:
+                stats += 'Output tokens: ' + str(output_tokens) + '\n'
+
+            model_extra = usage.get('model_extra', {})
+
+            cache_creation_input_tokens = model_extra.get('cache_creation_input_tokens')
+            if cache_creation_input_tokens:
+                stats += (
+                    'Input tokens (cache write): '
+                    + str(cache_creation_input_tokens)
+                    + '\n'
+                )
+
+            cache_read_input_tokens = model_extra.get('cache_read_input_tokens')
+            if cache_read_input_tokens:
+                stats += (
+                    'Input tokens (cache read): ' + str(cache_read_input_tokens) + '\n'
+                )
+
+        if stats:
+            logger.info(stats)
 
     def get_token_count(self, messages):
         """Get the number of tokens in a list of messages.


@@ -0,0 +1,210 @@
from unittest.mock import MagicMock, Mock

import pytest

from agenthub.codeact_agent.codeact_agent import CodeActAgent
from openhands.core.config import AgentConfig, LLMConfig
from openhands.events import EventSource, EventStream
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation
from openhands.llm.llm import LLM
from openhands.storage import get_file_store


@pytest.fixture
def mock_llm():
    llm = Mock(spec=LLM)
    llm.config = LLMConfig(model='claude-3-5-sonnet-20240620')
    llm.supports_prompt_caching = True
    return llm


@pytest.fixture
def mock_event_stream(tmp_path):
    file_store = get_file_store('local', str(tmp_path))
    return EventStream('test_session', file_store)


@pytest.fixture
def codeact_agent(mock_llm):
    config = AgentConfig()
    return CodeActAgent(mock_llm, config)


def test_get_messages_with_reminder(codeact_agent, mock_event_stream):
    # Add some events to the stream
    mock_event_stream.add_event(MessageAction('Initial user message'), EventSource.USER)
    mock_event_stream.add_event(MessageAction('Sure!'), EventSource.AGENT)
    mock_event_stream.add_event(MessageAction('Hello, agent!'), EventSource.USER)
    mock_event_stream.add_event(MessageAction('Hello, user!'), EventSource.AGENT)
    mock_event_stream.add_event(MessageAction('Laaaaaaaast!'), EventSource.USER)

    codeact_agent.reset()
    messages = codeact_agent._get_messages(
        Mock(history=mock_event_stream, max_iterations=5, iteration=0)
    )

    assert (
        len(messages) == 6
    )  # System, initial user + user message, agent message, last user message
    assert messages[0].content[0].cache_prompt
    assert messages[1].role == 'user'
    assert messages[1].content[0].text.endswith("LET'S START!")
    assert messages[1].content[1].text.endswith('Initial user message')
    assert messages[1].content[0].cache_prompt
    assert messages[3].role == 'user'
    assert messages[3].content[0].text == ('Hello, agent!')
    assert messages[4].role == 'assistant'
    assert messages[4].content[0].text == 'Hello, user!'
    assert messages[5].role == 'user'
    assert messages[5].content[0].text.startswith('Laaaaaaaast!')
    assert messages[5].content[0].cache_prompt
    assert (
        messages[5]
        .content[1]
        .text.endswith(
            'ENVIRONMENT REMINDER: You have 5 turns left to complete the task. When finished reply with <finish></finish>.'
        )
    )


def test_get_messages_prompt_caching(codeact_agent, mock_event_stream):
    # Add multiple user and agent messages
    for i in range(15):
        mock_event_stream.add_event(
            MessageAction(f'User message {i}'), EventSource.USER
        )
        mock_event_stream.add_event(
            MessageAction(f'Agent message {i}'), EventSource.AGENT
        )

    codeact_agent.reset()
    messages = codeact_agent._get_messages(
        Mock(history=mock_event_stream, max_iterations=10, iteration=5)
    )

    # Check that only the last two user messages have cache_prompt=True
    cached_user_messages = [
        msg for msg in messages if msg.role == 'user' and msg.content[0].cache_prompt
    ]
    assert len(cached_user_messages) == 3  # Including the initial system message

    # Verify that these are indeed the last two user messages
    assert cached_user_messages[0].content[0].text.startswith('Here is an example')
    assert cached_user_messages[1].content[0].text == 'User message 13'
    assert cached_user_messages[2].content[0].text.startswith('User message 14')


def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
    # Add a mix of actions and observations
    message_action_1 = MessageAction(
        "Let's list the contents of the current directory."
    )
    mock_event_stream.add_event(message_action_1, EventSource.USER)

    cmd_action_1 = CmdRunAction('ls -l', thought='List files in current directory')
    mock_event_stream.add_event(cmd_action_1, EventSource.AGENT)

    cmd_observation_1 = CmdOutputObservation(
        content='total 0\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file1.txt\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file2.txt',
        command_id=cmd_action_1._id,
        command='ls -l',
        exit_code=0,
    )
    mock_event_stream.add_event(cmd_observation_1, EventSource.USER)

    message_action_2 = MessageAction("Now, let's create a new directory.")
    mock_event_stream.add_event(message_action_2, EventSource.AGENT)

    cmd_action_2 = CmdRunAction('mkdir new_directory', thought='Create a new directory')
    mock_event_stream.add_event(cmd_action_2, EventSource.AGENT)

    cmd_observation_2 = CmdOutputObservation(
        content='',
        command_id=cmd_action_2._id,
        command='mkdir new_directory',
        exit_code=0,
    )
    mock_event_stream.add_event(cmd_observation_2, EventSource.USER)

    codeact_agent.reset()
    messages = codeact_agent._get_messages(
        Mock(history=mock_event_stream, max_iterations=5, iteration=0)
    )

    # Assert the presence of key elements in the messages
    assert (
        messages[1]
        .content[1]
        .text.startswith("Let's list the contents of the current directory.")
    )  # user, included in the initial message
    assert any(
        'List files in current directory\n<execute_bash>\nls -l\n</execute_bash>'
        in msg.content[0].text
        for msg in messages
    )  # agent
    assert any(
        'total 0\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file1.txt\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file2.txt'
        in msg.content[0].text
        for msg in messages
    )  # user, observation
    assert any(
        "Now, let's create a new directory." in msg.content[0].text for msg in messages
    )  # agent
    assert messages[4].content[1].text.startswith('Create a new directory')  # agent
    assert any(
        'finished with exit code 0' in msg.content[0].text for msg in messages
    )  # user, observation
    assert (
        messages[5].content[0].text.startswith('OBSERVATION:\n\n')
    )  # user, observation

    # prompt cache is added to the system message
    assert messages[0].content[0].cache_prompt
    # and the first initial user message
    assert messages[1].content[0].cache_prompt
    # and to the last two user messages
    assert messages[3].content[0].cache_prompt
    assert messages[5].content[0].cache_prompt

    # reminder is added to the last user message
    assert 'ENVIRONMENT REMINDER: You have 5 turns' in messages[5].content[1].text


def test_prompt_caching_headers(codeact_agent, mock_event_stream):
    # Setup
    mock_event_stream.add_event(MessageAction('Hello, agent!'), EventSource.USER)
    mock_event_stream.add_event(MessageAction('Hello, user!'), EventSource.AGENT)

    mock_short_term_history = MagicMock()
    mock_short_term_history.get_last_user_message.return_value = 'Hello, agent!'

    mock_state = Mock()
    mock_state.history = mock_short_term_history
    mock_state.max_iterations = 5
    mock_state.iteration = 0

    codeact_agent.reset()

    # Replace mock LLM completion with a function that checks headers and returns a structured response
    def check_headers(**kwargs):
        assert 'extra_headers' in kwargs
        assert 'anthropic-beta' in kwargs['extra_headers']
        assert kwargs['extra_headers']['anthropic-beta'] == 'prompt-caching-2024-07-31'
        # Create a mock response with the expected structure
        mock_response = Mock()
        mock_response.choices = [Mock()]
        mock_response.choices[0].message = Mock()
        mock_response.choices[0].message.content = 'Hello! How can I assist you today?'
        return mock_response

    codeact_agent.llm.completion = check_headers

    # Act
    result = codeact_agent.step(mock_state)

    # Assert
    assert isinstance(result, MessageAction)
    assert 'Hello! How can I assist you today?' in result.content
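
Assuming the new test module lands under the project's unit-test directory (the exact path and file name are not shown in this diff), the added tests can be run in isolation with something like:

python -m pytest tests/unit -k "prompt_caching or get_messages"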