From 97a03faf33c4e48c79a4f84afdec20aab98d9066 Mon Sep 17 00:00:00 2001 From: Cole Murray Date: Tue, 10 Sep 2024 09:34:41 -0700 Subject: [PATCH] Add Handling of Cache Prompt When Formatting Messages (#3773) * Add Handling of Cache Prompt When Formatting Messages * Fix Value for Cache Control * Fix Value for Cache Control * Update openhands/core/message.py Co-authored-by: Engel Nyst * Fix lint error * Serialize Messages if Prompt Caching Is Enabled * Remove formatting message change --------- Co-authored-by: Engel Nyst Co-authored-by: tobitege <10787084+tobitege@users.noreply.github.com> --- openhands/core/message.py | 7 +++++-- openhands/llm/llm.py | 4 +++- tests/integration/conftest.py | 4 +++- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/openhands/core/message.py b/openhands/core/message.py index 325b388474..6e35991c5f 100644 --- a/openhands/core/message.py +++ b/openhands/core/message.py @@ -72,12 +72,14 @@ class Message(BaseModel): def format_messages( - messages: Union[Message, list[Message]], with_images: bool + messages: Union[Message, list[Message]], + with_images: bool, + with_prompt_caching: bool, ) -> list[dict]: if not isinstance(messages, list): messages = [messages] - if with_images: + if with_images or with_prompt_caching: return [message.model_dump() for message in messages] converted_messages = [] @@ -113,4 +115,5 @@ def format_messages( 'content': content_str, } ) + return converted_messages diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py index 771522802a..553382e0e5 100644 --- a/openhands/llm/llm.py +++ b/openhands/llm/llm.py @@ -597,4 +597,6 @@ class LLM: def format_messages_for_llm( self, messages: Union[Message, list[Message]] ) -> list[dict]: - return format_messages(messages, self.vision_is_active()) + return format_messages( + messages, self.vision_is_active(), self.is_caching_prompt_active() + ) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index e4c18e9150..9b8387c9db 100644 --- 
a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -185,7 +185,9 @@ def mock_user_response(*args, test_name, **kwargs): def mock_completion(*args, test_name, **kwargs): global cur_id messages = kwargs['messages'] - plain_messages = format_messages(messages, with_images=False) + plain_messages = format_messages( + messages, with_images=False, with_prompt_caching=False + ) message_str = message_separator.join(msg['content'] for msg in plain_messages) # this assumes all response_(*).log filenames are in numerical order, starting from one