diff --git a/openhands/core/message.py b/openhands/core/message.py index 325b388474..6e35991c5f 100644 --- a/openhands/core/message.py +++ b/openhands/core/message.py @@ -72,12 +72,14 @@ class Message(BaseModel): def format_messages( - messages: Union[Message, list[Message]], with_images: bool + messages: Union[Message, list[Message]], + with_images: bool, + with_prompt_caching: bool, ) -> list[dict]: if not isinstance(messages, list): messages = [messages] - if with_images: + if with_images or with_prompt_caching: return [message.model_dump() for message in messages] converted_messages = [] @@ -113,4 +115,5 @@ def format_messages( 'content': content_str, } ) + return converted_messages diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py index 771522802a..553382e0e5 100644 --- a/openhands/llm/llm.py +++ b/openhands/llm/llm.py @@ -597,4 +597,6 @@ class LLM: def format_messages_for_llm( self, messages: Union[Message, list[Message]] ) -> list[dict]: - return format_messages(messages, self.vision_is_active()) + return format_messages( + messages, self.vision_is_active(), self.is_caching_prompt_active() + ) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index e4c18e9150..9b8387c9db 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -185,7 +185,9 @@ def mock_user_response(*args, test_name, **kwargs): def mock_completion(*args, test_name, **kwargs): global cur_id messages = kwargs['messages'] - plain_messages = format_messages(messages, with_images=False) + plain_messages = format_messages( + messages, with_images=False, with_prompt_caching=False + ) message_str = message_separator.join(msg['content'] for msg in plain_messages) # this assumes all response_(*).log filenames are in numerical order, starting from one