From 8d097efb4fdabeb9abd8be388315d49da5478b81 Mon Sep 17 00:00:00 2001 From: Calvin Smith Date: Tue, 18 Feb 2025 11:23:06 -0700 Subject: [PATCH] enh: Refactor `Event` -> `Message` pipeline outside of `CodeActAgent` (#6715) Co-authored-by: Calvin Smith Co-authored-by: Engel Nyst --- .../agenthub/codeact_agent/codeact_agent.py | 392 +++--------------- openhands/core/message_utils.py | 367 ++++++++++++++++ openhands/events/serialization/event.py | 4 +- tests/unit/test_codeact_agent.py | 259 +----------- tests/unit/test_message_utils.py | 271 ++++++++++++ 5 files changed, 693 insertions(+), 600 deletions(-) create mode 100644 openhands/core/message_utils.py create mode 100644 tests/unit/test_message_utils.py diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py index 5a1f6d54a8..b636e40cb9 100644 --- a/openhands/agenthub/codeact_agent/codeact_agent.py +++ b/openhands/agenthub/codeact_agent/codeact_agent.py @@ -2,41 +2,21 @@ import json import os from collections import deque -from litellm import ModelResponse - import openhands import openhands.agenthub.codeact_agent.function_calling as codeact_function_calling from openhands.controller.agent import Agent from openhands.controller.state.state import State from openhands.core.config import AgentConfig from openhands.core.logger import openhands_logger as logger -from openhands.core.message import ImageContent, Message, TextContent -from openhands.core.schema import ActionType +from openhands.core.message import Message, TextContent +from openhands.core.message_utils import ( + apply_prompt_caching, + events_to_messages, +) from openhands.events.action import ( Action, - AgentDelegateAction, AgentFinishAction, - BrowseInteractiveAction, - BrowseURLAction, - CmdRunAction, - FileEditAction, - FileReadAction, - IPythonRunCellAction, - MessageAction, ) -from openhands.events.observation import ( - AgentCondensationObservation, - AgentDelegateObservation, - BrowserOutputObservation, - CmdOutputObservation, - FileEditObservation, - FileReadObservation, - IPythonRunCellObservation, - UserRejectObservation, -) -from openhands.events.observation.error import ErrorObservation -from openhands.events.observation.observation import Observation -from openhands.events.serialization.event import truncate_content from openhands.llm.llm import LLM from openhands.memory.condenser import Condenser from openhands.runtime.plugins import ( @@ -113,247 +93,6 @@ class CodeActAgent(Agent): self.condenser = Condenser.from_config(self.config.condenser) logger.debug(f'Using condenser: {self.condenser}') - def get_action_message( - self, - action: Action, - pending_tool_call_action_messages: dict[str, Message], - ) -> list[Message]: - """Converts an action into a message format that can be sent to the LLM. - - This method handles different types of actions and formats them appropriately: - 1. For tool-based actions (AgentDelegate, CmdRun, IPythonRunCell, FileEdit) and agent-sourced AgentFinish: - - In function calling mode: Stores the LLM's response in pending_tool_call_action_messages - - In non-function calling mode: Creates a message with the action string - 2. For MessageActions: Creates a message with the text content and optional image content - - Args: - action (Action): The action to convert. Can be one of: - - CmdRunAction: For executing bash commands - - IPythonRunCellAction: For running IPython code - - FileEditAction: For editing files - - FileReadAction: For reading files using openhands-aci commands - - BrowseInteractiveAction: For browsing the web - - AgentFinishAction: For ending the interaction - - MessageAction: For sending messages - pending_tool_call_action_messages (dict[str, Message]): Dictionary mapping response IDs - to their corresponding messages. Used in function calling mode to track tool calls - that are waiting for their results. - - Returns: - list[Message]: A list containing the formatted message(s) for the action. - May be empty if the action is handled as a tool call in function calling mode. - - Note: - In function calling mode, tool-based actions are stored in pending_tool_call_action_messages - rather than being returned immediately. They will be processed later when all corresponding - tool call results are available. - """ - # create a regular message from an event - if isinstance( - action, - ( - AgentDelegateAction, - IPythonRunCellAction, - FileEditAction, - FileReadAction, - BrowseInteractiveAction, - BrowseURLAction, - ), - ) or (isinstance(action, CmdRunAction) and action.source == 'agent'): - tool_metadata = action.tool_call_metadata - assert tool_metadata is not None, ( - 'Tool call metadata should NOT be None when function calling is enabled. Action: ' - + str(action) - ) - - llm_response: ModelResponse = tool_metadata.model_response - assistant_msg = llm_response.choices[0].message - - # Add the LLM message (assistant) that initiated the tool calls - # (overwrites any previous message with the same response_id) - logger.debug( - f'Tool calls type: {type(assistant_msg.tool_calls)}, value: {assistant_msg.tool_calls}' - ) - pending_tool_call_action_messages[llm_response.id] = Message( - role=assistant_msg.role, - # tool call content SHOULD BE a string - content=[TextContent(text=assistant_msg.content or '')] - if assistant_msg.content is not None - else [], - tool_calls=assistant_msg.tool_calls, - ) - return [] - elif isinstance(action, AgentFinishAction): - role = 'user' if action.source == 'user' else 'assistant' - - # when agent finishes, it has tool_metadata - # which has already been executed, and it doesn't have a response - # when the user finishes (/exit), we don't have tool_metadata - tool_metadata = action.tool_call_metadata - if tool_metadata is not None: - # take the response message from the tool call - assistant_msg = tool_metadata.model_response.choices[0].message - content = assistant_msg.content or '' - - # save content if any, to thought - if action.thought: - if action.thought != content: - action.thought += '\n' + content - else: - action.thought = content - - # remove the tool call metadata - action.tool_call_metadata = None - return [ - Message( - role=role, - content=[TextContent(text=action.thought)], - ) - ] - elif isinstance(action, MessageAction): - role = 'user' if action.source == 'user' else 'assistant' - content = [TextContent(text=action.content or '')] - if self.llm.vision_is_active() and action.image_urls: - content.append(ImageContent(image_urls=action.image_urls)) - return [ - Message( - role=role, - content=content, - ) - ] - elif isinstance(action, CmdRunAction) and action.source == 'user': - content = [ - TextContent(text=f'User executed the command:\n{action.command}') - ] - return [ - Message( - role='user', - content=content, - ) - ] - return [] - - def get_observation_message( - self, - obs: Observation, - tool_call_id_to_message: dict[str, Message], - ) -> list[Message]: - """Converts an observation into a message format that can be sent to the LLM. - - This method handles different types of observations and formats them appropriately: - - CmdOutputObservation: Formats command execution results with exit codes - - IPythonRunCellObservation: Formats IPython cell execution results, replacing base64 images - - FileEditObservation: Formats file editing results - - FileReadObservation: Formats file reading results from openhands-aci - - AgentDelegateObservation: Formats results from delegated agent tasks - - ErrorObservation: Formats error messages from failed actions - - UserRejectObservation: Formats user rejection messages - - In function calling mode, observations with tool_call_metadata are stored in - tool_call_id_to_message for later processing instead of being returned immediately. - - Args: - obs (Observation): The observation to convert - tool_call_id_to_message (dict[str, Message]): Dictionary mapping tool call IDs - to their corresponding messages (used in function calling mode) - - Returns: - list[Message]: A list containing the formatted message(s) for the observation. - May be empty if the observation is handled as a tool response in function calling mode. - - Raises: - ValueError: If the observation type is unknown - """ - message: Message - max_message_chars = self.llm.config.max_message_chars - if isinstance(obs, CmdOutputObservation): - # if it doesn't have tool call metadata, it was triggered by a user action - if obs.tool_call_metadata is None: - text = truncate_content( - f'\nObserved result of command executed by user:\n{obs.to_agent_observation()}', - max_message_chars, - ) - else: - text = truncate_content(obs.to_agent_observation(), max_message_chars) - message = Message(role='user', content=[TextContent(text=text)]) - elif isinstance(obs, IPythonRunCellObservation): - text = obs.content - # replace base64 images with a placeholder - splitted = text.split('\n') - for i, line in enumerate(splitted): - if '![image](data:image/png;base64,' in line: - splitted[i] = ( - '![image](data:image/png;base64, ...) already displayed to user' - ) - text = '\n'.join(splitted) - text = truncate_content(text, max_message_chars) - message = Message(role='user', content=[TextContent(text=text)]) - elif isinstance(obs, FileEditObservation): - text = truncate_content(str(obs), max_message_chars) - message = Message(role='user', content=[TextContent(text=text)]) - elif isinstance(obs, FileReadObservation): - message = Message( - role='user', content=[TextContent(text=obs.content)] - ) # Content is already truncated by openhands-aci - elif isinstance(obs, BrowserOutputObservation): - text = obs.get_agent_obs_text() - if ( - obs.trigger_by_action == ActionType.BROWSE_INTERACTIVE - and obs.set_of_marks is not None - and len(obs.set_of_marks) > 0 - and self.config.enable_som_visual_browsing - and self.llm.vision_is_active() - ): - text += 'Image: Current webpage screenshot (Note that only visible portion of webpage is present in the screenshot. You may need to scroll to view the remaining portion of the web-page.)\n' - message = Message( - role='user', - content=[ - TextContent(text=text), - ImageContent(image_urls=[obs.set_of_marks]), - ], - ) - else: - message = Message( - role='user', - content=[TextContent(text=text)], - ) - elif isinstance(obs, AgentDelegateObservation): - text = truncate_content( - obs.outputs['content'] if 'content' in obs.outputs else '', - max_message_chars, - ) - message = Message(role='user', content=[TextContent(text=text)]) - elif isinstance(obs, ErrorObservation): - text = truncate_content(obs.content, max_message_chars) - text += '\n[Error occurred in processing last action]' - message = Message(role='user', content=[TextContent(text=text)]) - elif isinstance(obs, UserRejectObservation): - text = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars) - text += '\n[Last action has been rejected by the user]' - message = Message(role='user', content=[TextContent(text=text)]) - elif isinstance(obs, AgentCondensationObservation): - text = truncate_content(obs.content, max_message_chars) - message = Message(role='user', content=[TextContent(text=text)]) - else: - # If an observation message is not returned, it will cause an error - # when the LLM tries to return the next message - raise ValueError(f'Unknown observation type: {type(obs)}') - - # Update the message as tool response properly - if (tool_call_metadata := obs.tool_call_metadata) is not None: - tool_call_id_to_message[tool_call_metadata.tool_call_id] = Message( - role='tool', - content=message.content, - tool_call_id=tool_call_metadata.tool_call_id, - name=tool_call_metadata.function_name, - ) - # No need to return the observation message - # because it will be added by get_action_message when all the corresponding - # tool calls in the SAME request are processed - return [] - - return [message] - def reset(self) -> None: """Resets the CodeAct Agent.""" super().reset() @@ -429,7 +168,30 @@ class CodeActAgent(Agent): if not self.prompt_manager: raise Exception('Prompt Manager not instantiated.') - messages: list[Message] = [ + messages: list[Message] = self._initial_messages() + + # Condense the events from the state. + events = self.condenser.condensed_history(state) + + messages += events_to_messages( + events, + max_message_chars=self.llm.config.max_message_chars, + vision_is_active=self.llm.vision_is_active(), + enable_som_visual_browsing=self.config.enable_som_visual_browsing, + ) + + messages = self._enhance_messages(messages) + + if self.llm.is_caching_prompt_active(): + apply_prompt_caching(messages) + + return messages + + def _initial_messages(self) -> list[Message]: + """Creates the initial messages (including the system prompt) for the LLM conversation.""" + assert self.prompt_manager, 'Prompt Manager not instantiated.' + + return [ Message( role='system', content=[ @@ -441,84 +203,34 @@ class CodeActAgent(Agent): ) ] - pending_tool_call_action_messages: dict[str, Message] = {} - tool_call_id_to_message: dict[str, Message] = {} + def _enhance_messages(self, messages: list[Message]) -> list[Message]: + """Enhances the user message with additional context based on keywords matched. - # Condense the events from the state. - events = self.condenser.condensed_history(state) + Args: + messages (list[Message]): The list of messages to enhance + Returns: + list[Message]: The enhanced list of messages + """ + assert self.prompt_manager, 'Prompt Manager not instantiated.' + + results: list[Message] = [] is_first_message_handled = False - for event in events: - # create a regular message from an event - if isinstance(event, Action): - messages_to_add = self.get_action_message( - action=event, - pending_tool_call_action_messages=pending_tool_call_action_messages, - ) - elif isinstance(event, Observation): - messages_to_add = self.get_observation_message( - obs=event, - tool_call_id_to_message=tool_call_id_to_message, - ) - else: - raise ValueError(f'Unknown event type: {type(event)}') - # Check pending tool call action messages and see if they are complete - _response_ids_to_remove = [] - for ( - response_id, - pending_message, - ) in pending_tool_call_action_messages.items(): - assert pending_message.tool_calls is not None, ( - 'Tool calls should NOT be None when function calling is enabled & the message is considered pending tool call. ' - f'Pending message: {pending_message}' - ) - if all( - tool_call.id in tool_call_id_to_message - for tool_call in pending_message.tool_calls - ): - # If complete: - # -- 1. Add the message that **initiated** the tool calls - messages_to_add.append(pending_message) - # -- 2. Add the tool calls **results*** - for tool_call in pending_message.tool_calls: - messages_to_add.append(tool_call_id_to_message[tool_call.id]) - tool_call_id_to_message.pop(tool_call.id) - _response_ids_to_remove.append(response_id) - # Cleanup the processed pending tool messages - for response_id in _response_ids_to_remove: - pending_tool_call_action_messages.pop(response_id) + for msg in messages: + if msg.role == 'user' and not is_first_message_handled: + is_first_message_handled = True + # compose the first user message with examples + self.prompt_manager.add_examples_to_initial_message(msg) - for msg in messages_to_add: - if msg: - if msg.role == 'user' and not is_first_message_handled: - is_first_message_handled = True - # compose the first user message with examples - self.prompt_manager.add_examples_to_initial_message(msg) + # and/or repo/runtime info + if self.config.enable_prompt_extensions: + self.prompt_manager.add_info_to_initial_message(msg) - # and/or repo/runtime info - if self.config.enable_prompt_extensions: - self.prompt_manager.add_info_to_initial_message(msg) + # enhance the user message with additional context based on keywords matched + if msg.role == 'user': + self.prompt_manager.enhance_message(msg) - # enhance the user message with additional context based on keywords matched - if msg.role == 'user': - self.prompt_manager.enhance_message(msg) + results.append(msg) - messages.append(msg) - - if self.llm.is_caching_prompt_active(): - # NOTE: this is only needed for anthropic - # following logic here: - # https://github.com/anthropics/anthropic-quickstarts/blob/8f734fd08c425c6ec91ddd613af04ff87d70c5a0/computer-use-demo/computer_use_demo/loop.py#L241-L262 - breakpoints_remaining = 3 # remaining 1 for system/tool - for message in reversed(messages): - if message.role in ('user', 'tool'): - if breakpoints_remaining > 0: - message.content[ - -1 - ].cache_prompt = True # Last item inside the message content - breakpoints_remaining -= 1 - else: - break - - return messages + return results diff --git a/openhands/core/message_utils.py b/openhands/core/message_utils.py new file mode 100644 index 0000000000..25be128731 --- /dev/null +++ b/openhands/core/message_utils.py @@ -0,0 +1,367 @@ +from litellm import ModelResponse + +from openhands.core.logger import openhands_logger as logger +from openhands.core.message import ImageContent, Message, TextContent +from openhands.core.schema import ActionType +from openhands.events.action import ( + Action, + AgentDelegateAction, + AgentFinishAction, + BrowseInteractiveAction, + BrowseURLAction, + CmdRunAction, + FileEditAction, + FileReadAction, + IPythonRunCellAction, + MessageAction, +) +from openhands.events.event import Event +from openhands.events.observation import ( + AgentCondensationObservation, + AgentDelegateObservation, + BrowserOutputObservation, + CmdOutputObservation, + FileEditObservation, + FileReadObservation, + IPythonRunCellObservation, + UserRejectObservation, +) +from openhands.events.observation.error import ErrorObservation +from openhands.events.observation.observation import Observation +from openhands.events.serialization.event import truncate_content + + +def events_to_messages( + events: list[Event], + max_message_chars: int | None = None, + vision_is_active: bool = False, + enable_som_visual_browsing: bool = False, +) -> list[Message]: + """Converts a list of events into a list of messages that can be sent to the LLM. + + Ensures that tool call actions are processed correctly in function calling mode. + + Args: + events: A list of events to convert. Each event can be an Action or Observation. + max_message_chars: The maximum number of characters in the content of an event included in the prompt to the LLM. + Larger observations are truncated. + vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included. + enable_som_visual_browsing: Whether to enable visual browsing for the SOM model. + """ + messages = [] + + pending_tool_call_action_messages: dict[str, Message] = {} + tool_call_id_to_message: dict[str, Message] = {} + + for event in events: + # create a regular message from an event + if isinstance(event, Action): + messages_to_add = get_action_message( + action=event, + pending_tool_call_action_messages=pending_tool_call_action_messages, + vision_is_active=vision_is_active, + ) + elif isinstance(event, Observation): + messages_to_add = get_observation_message( + obs=event, + tool_call_id_to_message=tool_call_id_to_message, + max_message_chars=max_message_chars, + vision_is_active=vision_is_active, + enable_som_visual_browsing=enable_som_visual_browsing, + ) + else: + raise ValueError(f'Unknown event type: {type(event)}') + + # Check pending tool call action messages and see if they are complete + _response_ids_to_remove = [] + for ( + response_id, + pending_message, + ) in pending_tool_call_action_messages.items(): + assert pending_message.tool_calls is not None, ( + 'Tool calls should NOT be None when function calling is enabled & the message is considered pending tool call. ' + f'Pending message: {pending_message}' + ) + if all( + tool_call.id in tool_call_id_to_message + for tool_call in pending_message.tool_calls + ): + # If complete: + # -- 1. Add the message that **initiated** the tool calls + messages_to_add.append(pending_message) + # -- 2. Add the tool calls **results*** + for tool_call in pending_message.tool_calls: + messages_to_add.append(tool_call_id_to_message[tool_call.id]) + tool_call_id_to_message.pop(tool_call.id) + _response_ids_to_remove.append(response_id) + # Cleanup the processed pending tool messages + for response_id in _response_ids_to_remove: + pending_tool_call_action_messages.pop(response_id) + + messages += messages_to_add + + return messages + + +def get_action_message( + action: Action, + pending_tool_call_action_messages: dict[str, Message], + vision_is_active: bool = False, +) -> list[Message]: + """Converts an action into a message format that can be sent to the LLM. + + This method handles different types of actions and formats them appropriately: + 1. For tool-based actions (AgentDelegate, CmdRun, IPythonRunCell, FileEdit) and agent-sourced AgentFinish: + - In function calling mode: Stores the LLM's response in pending_tool_call_action_messages + - In non-function calling mode: Creates a message with the action string + 2. For MessageActions: Creates a message with the text content and optional image content + + Args: + action: The action to convert. Can be one of: + - CmdRunAction: For executing bash commands + - IPythonRunCellAction: For running IPython code + - FileEditAction: For editing files + - FileReadAction: For reading files using openhands-aci commands + - BrowseInteractiveAction: For browsing the web + - AgentFinishAction: For ending the interaction + - MessageAction: For sending messages + + pending_tool_call_action_messages: Dictionary mapping response IDs to their corresponding messages. + Used in function calling mode to track tool calls that are waiting for their results. + + vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included + + Returns: + list[Message]: A list containing the formatted message(s) for the action. + May be empty if the action is handled as a tool call in function calling mode. + + Note: + In function calling mode, tool-based actions are stored in pending_tool_call_action_messages + rather than being returned immediately. They will be processed later when all corresponding + tool call results are available. + """ + # create a regular message from an event + if isinstance( + action, + ( + AgentDelegateAction, + IPythonRunCellAction, + FileEditAction, + FileReadAction, + BrowseInteractiveAction, + BrowseURLAction, + ), + ) or (isinstance(action, CmdRunAction) and action.source == 'agent'): + tool_metadata = action.tool_call_metadata + assert tool_metadata is not None, ( + 'Tool call metadata should NOT be None when function calling is enabled. Action: ' + + str(action) + ) + + llm_response: ModelResponse = tool_metadata.model_response + assistant_msg = llm_response.choices[0].message + + # Add the LLM message (assistant) that initiated the tool calls + # (overwrites any previous message with the same response_id) + logger.debug( + f'Tool calls type: {type(assistant_msg.tool_calls)}, value: {assistant_msg.tool_calls}' + ) + pending_tool_call_action_messages[llm_response.id] = Message( + role=assistant_msg.role, + # tool call content SHOULD BE a string + content=[TextContent(text=assistant_msg.content or '')] + if assistant_msg.content is not None + else [], + tool_calls=assistant_msg.tool_calls, + ) + return [] + elif isinstance(action, AgentFinishAction): + role = 'user' if action.source == 'user' else 'assistant' + + # when agent finishes, it has tool_metadata + # which has already been executed, and it doesn't have a response + # when the user finishes (/exit), we don't have tool_metadata + tool_metadata = action.tool_call_metadata + if tool_metadata is not None: + # take the response message from the tool call + assistant_msg = tool_metadata.model_response.choices[0].message + content = assistant_msg.content or '' + + # save content if any, to thought + if action.thought: + if action.thought != content: + action.thought += '\n' + content + else: + action.thought = content + + # remove the tool call metadata + action.tool_call_metadata = None + return [ + Message( + role=role, + content=[TextContent(text=action.thought)], + ) + ] + elif isinstance(action, MessageAction): + role = 'user' if action.source == 'user' else 'assistant' + content = [TextContent(text=action.content or '')] + if vision_is_active and action.image_urls: + content.append(ImageContent(image_urls=action.image_urls)) + return [ + Message( + role=role, + content=content, + ) + ] + elif isinstance(action, CmdRunAction) and action.source == 'user': + content = [TextContent(text=f'User executed the command:\n{action.command}')] + return [ + Message( + role='user', + content=content, + ) + ] + return [] + + +def get_observation_message( + obs: Observation, + tool_call_id_to_message: dict[str, Message], + max_message_chars: int | None = None, + vision_is_active: bool = False, + enable_som_visual_browsing: bool = False, +) -> list[Message]: + """Converts an observation into a message format that can be sent to the LLM. + + This method handles different types of observations and formats them appropriately: + - CmdOutputObservation: Formats command execution results with exit codes + - IPythonRunCellObservation: Formats IPython cell execution results, replacing base64 images + - FileEditObservation: Formats file editing results + - FileReadObservation: Formats file reading results from openhands-aci + - AgentDelegateObservation: Formats results from delegated agent tasks + - ErrorObservation: Formats error messages from failed actions + - UserRejectObservation: Formats user rejection messages + + In function calling mode, observations with tool_call_metadata are stored in + tool_call_id_to_message for later processing instead of being returned immediately. + + Args: + obs: The observation to convert + tool_call_id_to_message: Dictionary mapping tool call IDs to their corresponding messages (used in function calling mode) + max_message_chars: The maximum number of characters in the content of an observation included in the prompt to the LLM + vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included + enable_som_visual_browsing: Whether to enable visual browsing for the SOM model + + Returns: + list[Message]: A list containing the formatted message(s) for the observation. + May be empty if the observation is handled as a tool response in function calling mode. + + Raises: + ValueError: If the observation type is unknown + """ + message: Message + + if isinstance(obs, CmdOutputObservation): + # if it doesn't have tool call metadata, it was triggered by a user action + if obs.tool_call_metadata is None: + text = truncate_content( + f'\nObserved result of command executed by user:\n{obs.to_agent_observation()}', + max_message_chars, + ) + else: + text = truncate_content(obs.to_agent_observation(), max_message_chars) + message = Message(role='user', content=[TextContent(text=text)]) + elif isinstance(obs, IPythonRunCellObservation): + text = obs.content + # replace base64 images with a placeholder + splitted = text.split('\n') + for i, line in enumerate(splitted): + if '![image](data:image/png;base64,' in line: + splitted[i] = ( + '![image](data:image/png;base64, ...) already displayed to user' + ) + text = '\n'.join(splitted) + text = truncate_content(text, max_message_chars) + message = Message(role='user', content=[TextContent(text=text)]) + elif isinstance(obs, FileEditObservation): + text = truncate_content(str(obs), max_message_chars) + message = Message(role='user', content=[TextContent(text=text)]) + elif isinstance(obs, FileReadObservation): + message = Message( + role='user', content=[TextContent(text=obs.content)] + ) # Content is already truncated by openhands-aci + elif isinstance(obs, BrowserOutputObservation): + text = obs.get_agent_obs_text() + if ( + obs.trigger_by_action == ActionType.BROWSE_INTERACTIVE + and obs.set_of_marks is not None + and len(obs.set_of_marks) > 0 + and enable_som_visual_browsing + and vision_is_active + ): + text += 'Image: Current webpage screenshot (Note that only visible portion of webpage is present in the screenshot. You may need to scroll to view the remaining portion of the web-page.)\n' + message = Message( + role='user', + content=[ + TextContent(text=text), + ImageContent(image_urls=[obs.set_of_marks]), + ], + ) + else: + message = Message( + role='user', + content=[TextContent(text=text)], + ) + elif isinstance(obs, AgentDelegateObservation): + text = truncate_content( + obs.outputs['content'] if 'content' in obs.outputs else '', + max_message_chars, + ) + message = Message(role='user', content=[TextContent(text=text)]) + elif isinstance(obs, ErrorObservation): + text = truncate_content(obs.content, max_message_chars) + text += '\n[Error occurred in processing last action]' + message = Message(role='user', content=[TextContent(text=text)]) + elif isinstance(obs, UserRejectObservation): + text = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars) + text += '\n[Last action has been rejected by the user]' + message = Message(role='user', content=[TextContent(text=text)]) + elif isinstance(obs, AgentCondensationObservation): + text = truncate_content(obs.content, max_message_chars) + message = Message(role='user', content=[TextContent(text=text)]) + else: + # If an observation message is not returned, it will cause an error + # when the LLM tries to return the next message + raise ValueError(f'Unknown observation type: {type(obs)}') + + # Update the message as tool response properly + if (tool_call_metadata := obs.tool_call_metadata) is not None: + tool_call_id_to_message[tool_call_metadata.tool_call_id] = Message( + role='tool', + content=message.content, + tool_call_id=tool_call_metadata.tool_call_id, + name=tool_call_metadata.function_name, + ) + # No need to return the observation message + # because it will be added by get_action_message when all the corresponding + # tool calls in the SAME request are processed + return [] + + return [message] + + +def apply_prompt_caching(messages: list[Message]) -> None: + """Applies caching breakpoints to the messages.""" + # NOTE: this is only needed for anthropic + # following logic here: + # https://github.com/anthropics/anthropic-quickstarts/blob/8f734fd08c425c6ec91ddd613af04ff87d70c5a0/computer-use-demo/computer_use_demo/loop.py#L241-L262 + breakpoints_remaining = 3 # remaining 1 for system/tool + for message in reversed(messages): + if message.role in ('user', 'tool'): + if breakpoints_remaining > 0: + message.content[ + -1 + ].cache_prompt = True # Last item inside the message content + breakpoints_remaining -= 1 + else: + break diff --git a/openhands/events/serialization/event.py b/openhands/events/serialization/event.py index cf2637ab66..71f591fd7d 100644 --- a/openhands/events/serialization/event.py +++ b/openhands/events/serialization/event.py @@ -130,9 +130,9 @@ def event_to_memory(event: 'Event', max_message_chars: int) -> dict: return d -def truncate_content(content: str, max_chars: int) -> str: +def truncate_content(content: str, max_chars: int | None = None) -> str: """Truncate the middle of the observation content if it is too long.""" - if len(content) <= max_chars or max_chars == -1: + if max_chars is None or len(content) <= max_chars or max_chars < 0: return content # truncate the middle and include a message to the LLM about it diff --git a/tests/unit/test_codeact_agent.py b/tests/unit/test_codeact_agent.py index 58ce8d8329..675ac17c31 100644 --- a/tests/unit/test_codeact_agent.py +++ b/tests/unit/test_codeact_agent.py @@ -19,23 +19,14 @@ from openhands.agenthub.codeact_agent.function_calling import ( from openhands.controller.state.state import State from openhands.core.config import AgentConfig, LLMConfig from openhands.core.exceptions import FunctionCallNotExistsError -from openhands.core.message import ImageContent, TextContent from openhands.events.action import ( - AgentFinishAction, CmdRunAction, MessageAction, ) -from openhands.events.event import EventSource, FileEditSource, FileReadSource -from openhands.events.observation.browse import BrowserOutputObservation +from openhands.events.event import EventSource from openhands.events.observation.commands import ( - CmdOutputMetadata, CmdOutputObservation, - IPythonRunCellObservation, ) -from openhands.events.observation.delegate import AgentDelegateObservation -from openhands.events.observation.error import ErrorObservation -from openhands.events.observation.files import FileEditObservation, FileReadObservation -from openhands.events.observation.reject import UserRejectObservation from openhands.events.tool import ToolCallMetadata from openhands.llm.llm import LLM @@ -59,254 +50,6 @@ def mock_state() -> State: return state -def test_cmd_output_observation_message(agent: CodeActAgent): - obs = CmdOutputObservation( - command='echo hello', - content='Command output', - metadata=CmdOutputMetadata( - exit_code=0, - prefix='[THIS IS PREFIX]', - suffix='[THIS IS SUFFIX]', - ), - ) - - tool_call_id_to_message = {} - results = agent.get_observation_message( - obs, tool_call_id_to_message=tool_call_id_to_message - ) - assert len(results) == 1 - - result = results[0] - assert result is not None - assert result.role == 'user' - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) - assert 'Observed result of command executed by user:' in result.content[0].text - assert '[Command finished with exit code 0]' in result.content[0].text - assert '[THIS IS PREFIX]' in result.content[0].text - assert '[THIS IS SUFFIX]' in result.content[0].text - - -def test_ipython_run_cell_observation_message(agent: CodeActAgent): - obs = IPythonRunCellObservation( - code='plt.plot()', - content='IPython output\n![image](data:image/png;base64,ABC123)', - ) - - results = agent.get_observation_message(obs, tool_call_id_to_message={}) - assert len(results) == 1 - - result = results[0] - assert result is not None - assert result.role == 'user' - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) - assert 'IPython output' in result.content[0].text - assert ( - '![image](data:image/png;base64, ...) already displayed to user' - in result.content[0].text - ) - assert 'ABC123' not in result.content[0].text - - -def test_agent_delegate_observation_message(agent: CodeActAgent): - obs = AgentDelegateObservation( - content='Content', outputs={'content': 'Delegated agent output'} - ) - - results = agent.get_observation_message(obs, tool_call_id_to_message={}) - assert len(results) == 1 - - result = results[0] - assert result is not None - assert result.role == 'user' - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) - assert 'Delegated agent output' in result.content[0].text - - -def test_error_observation_message(agent: CodeActAgent): - obs = ErrorObservation('Error message') - - results = agent.get_observation_message(obs, tool_call_id_to_message={}) - assert len(results) == 1 - - result = results[0] - assert result is not None - assert result.role == 'user' - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) - assert 'Error message' in result.content[0].text - assert 'Error occurred in processing last action' in result.content[0].text - - -def test_unknown_observation_message(agent: CodeActAgent): - obs = Mock() - - with pytest.raises(ValueError, match='Unknown observation type'): - agent.get_observation_message(obs, tool_call_id_to_message={}) - - -def test_file_edit_observation_message(agent: CodeActAgent): - obs = FileEditObservation( - path='/test/file.txt', - prev_exist=True, - old_content='old content', - new_content='new content', - content='diff content', - impl_source=FileEditSource.LLM_BASED_EDIT, - ) - - results = agent.get_observation_message(obs, tool_call_id_to_message={}) - assert len(results) == 1 - - result = results[0] - assert result is not None - assert result.role == 'user' - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) - assert '[Existing file /test/file.txt is edited with' in result.content[0].text - - -def test_file_read_observation_message(agent: CodeActAgent): - obs = FileReadObservation( - path='/test/file.txt', - content='File content', - impl_source=FileReadSource.DEFAULT, - ) - - results = agent.get_observation_message(obs, tool_call_id_to_message={}) - assert len(results) == 1 - - result = results[0] - assert result is not None - assert result.role == 'user' - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == 'File content' - - -def test_browser_output_observation_message(agent: CodeActAgent): - obs = BrowserOutputObservation( - url='http://example.com', - trigger_by_action='browse', - screenshot='', - content='Page loaded', - error=False, - ) - - results = agent.get_observation_message(obs, tool_call_id_to_message={}) - assert len(results) == 1 - - result = results[0] - assert result is not None - assert result.role == 'user' - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) - assert '[Current URL: http://example.com]' in result.content[0].text - - -def test_user_reject_observation_message(agent: CodeActAgent): - obs = UserRejectObservation('Action rejected') - - results = agent.get_observation_message(obs, tool_call_id_to_message={}) - assert len(results) == 1 - - result = results[0] - assert result is not None - assert result.role == 'user' - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) - assert 'Action rejected' in result.content[0].text - assert '[Last action has been rejected by the user]' in result.content[0].text - - -def test_function_calling_observation_message(agent: CodeActAgent): - mock_response = { - 'id': 'mock_id', - 'total_calls_in_response': 1, - 'choices': [{'message': {'content': 'Task completed'}}], - } - obs = CmdOutputObservation( - command='echo hello', - content='Command output', - command_id=1, - exit_code=0, - ) - obs.tool_call_metadata = ToolCallMetadata( - tool_call_id='123', - function_name='execute_bash', - model_response=mock_response, - total_calls_in_response=1, - ) - - results = agent.get_observation_message(obs, tool_call_id_to_message={}) - assert len(results) == 0 # No direct message when using function calling - - -def test_message_action_with_image(agent: CodeActAgent): - action = MessageAction( - content='Message with image', - image_urls=['http://example.com/image.jpg'], - ) - action._source = EventSource.AGENT - - results = agent.get_action_message(action, {}) - assert len(results) == 1 - - result = results[0] - assert result is not None - assert result.role == 'assistant' - assert len(result.content) == 2 - assert isinstance(result.content[0], TextContent) - assert isinstance(result.content[1], ImageContent) - assert result.content[0].text == 'Message with image' - assert result.content[1].image_urls == ['http://example.com/image.jpg'] - - -def test_user_cmd_action_message(agent: CodeActAgent): - action = CmdRunAction(command='ls -l') - action._source = EventSource.USER - - results = agent.get_action_message(action, {}) - assert len(results) == 1 - - result = results[0] - assert result is not None - assert result.role == 'user' - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) - assert 'User executed the command' in result.content[0].text - assert 'ls -l' in result.content[0].text - - -def test_agent_finish_action_with_tool_metadata(agent: CodeActAgent): - mock_response = { - 'id': 'mock_id', - 'total_calls_in_response': 1, - 'choices': [{'message': {'content': 'Task completed'}}], - } - - action = AgentFinishAction(thought='Initial thought') - action._source = EventSource.AGENT - action.tool_call_metadata = ToolCallMetadata( - tool_call_id='123', - function_name='finish', - model_response=mock_response, - total_calls_in_response=1, - ) - - results = agent.get_action_message(action, {}) - assert len(results) == 1 - - result = results[0] - assert result is not None - assert result.role == 'assistant' - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) - assert 'Initial thought\nTask completed' in result.content[0].text - - def test_reset(agent: CodeActAgent): # Add some state action = MessageAction(content='test') diff --git a/tests/unit/test_message_utils.py b/tests/unit/test_message_utils.py new file mode 100644 index 0000000000..d3114519c8 --- /dev/null +++ b/tests/unit/test_message_utils.py @@ -0,0 +1,271 @@ +from unittest.mock import Mock + +import pytest + +from openhands.core.message import ImageContent, TextContent +from openhands.core.message_utils import get_action_message, get_observation_message +from openhands.events.action import ( + AgentFinishAction, + CmdRunAction, + MessageAction, +) +from openhands.events.event import EventSource, FileEditSource, FileReadSource +from openhands.events.observation.browse import BrowserOutputObservation +from openhands.events.observation.commands import ( + CmdOutputMetadata, + CmdOutputObservation, + IPythonRunCellObservation, +) +from openhands.events.observation.delegate import AgentDelegateObservation +from openhands.events.observation.error import ErrorObservation +from openhands.events.observation.files import FileEditObservation, FileReadObservation +from openhands.events.observation.reject import UserRejectObservation +from openhands.events.tool import ToolCallMetadata + + +def test_cmd_output_observation_message(): + obs = CmdOutputObservation( + command='echo hello', + content='Command output', + metadata=CmdOutputMetadata( + exit_code=0, + prefix='[THIS IS PREFIX]', + suffix='[THIS IS SUFFIX]', + ), + ) + + tool_call_id_to_message = {} + results = get_observation_message( + obs, tool_call_id_to_message=tool_call_id_to_message + ) + assert len(results) == 1 + + result = results[0] + assert result is not None + assert result.role == 'user' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert 'Observed result of command executed by user:' in result.content[0].text + assert '[Command finished with exit code 0]' in result.content[0].text + assert '[THIS IS PREFIX]' in result.content[0].text + assert '[THIS IS SUFFIX]' in result.content[0].text + + +def test_ipython_run_cell_observation_message(): + obs = IPythonRunCellObservation( + code='plt.plot()', + content='IPython output\n![image](data:image/png;base64,ABC123)', + ) + + results = get_observation_message(obs, tool_call_id_to_message={}) + assert len(results) == 1 + + result = results[0] + assert result is not None + assert result.role == 'user' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert 'IPython output' in result.content[0].text + assert ( + '![image](data:image/png;base64, ...) already displayed to user' + in result.content[0].text + ) + assert 'ABC123' not in result.content[0].text + + +def test_agent_delegate_observation_message(): + obs = AgentDelegateObservation( + content='Content', outputs={'content': 'Delegated agent output'} + ) + + results = get_observation_message(obs, tool_call_id_to_message={}) + assert len(results) == 1 + + result = results[0] + assert result is not None + assert result.role == 'user' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert 'Delegated agent output' in result.content[0].text + + +def test_error_observation_message(): + obs = ErrorObservation('Error message') + + results = get_observation_message(obs, tool_call_id_to_message={}) + assert len(results) == 1 + + result = results[0] + assert result is not None + assert result.role == 'user' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert 'Error message' in result.content[0].text + assert 'Error occurred in processing last action' in result.content[0].text + + +def test_unknown_observation_message(): + obs = Mock() + + with pytest.raises(ValueError, match='Unknown observation type'): + get_observation_message(obs, tool_call_id_to_message={}) + + +def test_file_edit_observation_message(): + obs = FileEditObservation( + path='/test/file.txt', + prev_exist=True, + old_content='old content', + new_content='new content', + content='diff content', + impl_source=FileEditSource.LLM_BASED_EDIT, + ) + + results = get_observation_message(obs, tool_call_id_to_message={}) + assert len(results) == 1 + + result = results[0] + assert result is not None + assert result.role == 'user' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert '[Existing file /test/file.txt is edited with' in result.content[0].text + + +def test_file_read_observation_message(): + obs = FileReadObservation( + path='/test/file.txt', + content='File content', + impl_source=FileReadSource.DEFAULT, + ) + + results = get_observation_message(obs, tool_call_id_to_message={}) + assert len(results) == 1 + + result = results[0] + assert result is not None + assert result.role == 'user' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == 'File content' + + +def test_browser_output_observation_message(): + obs = BrowserOutputObservation( + url='http://example.com', + trigger_by_action='browse', + screenshot='', + content='Page loaded', + error=False, + ) + + results = get_observation_message(obs, tool_call_id_to_message={}) + assert len(results) == 1 + + result = results[0] + assert result is not None + assert result.role == 'user' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert '[Current URL: http://example.com]' in result.content[0].text + + +def test_user_reject_observation_message(): + obs = UserRejectObservation('Action rejected') + + results = get_observation_message(obs, tool_call_id_to_message={}) + assert len(results) == 1 + + result = results[0] + assert result is not None + assert result.role == 'user' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert 'Action rejected' in result.content[0].text + assert '[Last action has been rejected by the user]' in result.content[0].text + + +def test_function_calling_observation_message(): + mock_response = { + 'id': 'mock_id', + 'total_calls_in_response': 1, + 'choices': [{'message': {'content': 'Task completed'}}], + } + obs = CmdOutputObservation( + command='echo hello', + content='Command output', + command_id=1, + exit_code=0, + ) + obs.tool_call_metadata = ToolCallMetadata( + tool_call_id='123', + function_name='execute_bash', + model_response=mock_response, + total_calls_in_response=1, + ) + + results = get_observation_message(obs, tool_call_id_to_message={}) + assert len(results) == 0 # No direct message when using function calling + + +def test_message_action_with_image(): + action = MessageAction( + content='Message with image', + image_urls=['http://example.com/image.jpg'], + ) + action._source = EventSource.AGENT + + results = get_action_message(action, {}, vision_is_active=True) + assert len(results) == 1 + + result = results[0] + assert result is not None + assert result.role == 'assistant' + assert len(result.content) == 2 + assert isinstance(result.content[0], TextContent) + assert isinstance(result.content[1], ImageContent) + assert result.content[0].text == 'Message with image' + assert result.content[1].image_urls == ['http://example.com/image.jpg'] + + +def test_user_cmd_action_message(): + action = CmdRunAction(command='ls -l') + action._source = EventSource.USER + + results = get_action_message(action, {}) + assert len(results) == 1 + + result = results[0] + assert result is not None + assert result.role == 'user' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert 'User executed the command' in result.content[0].text + assert 'ls -l' in result.content[0].text + + +def test_agent_finish_action_with_tool_metadata(): + mock_response = { + 'id': 'mock_id', + 'total_calls_in_response': 1, + 'choices': [{'message': {'content': 'Task completed'}}], + } + + action = AgentFinishAction(thought='Initial thought') + action._source = EventSource.AGENT + action.tool_call_metadata = ToolCallMetadata( + tool_call_id='123', + function_name='finish', + model_response=mock_response, + total_calls_in_response=1, + ) + + results = get_action_message(action, {}) + assert len(results) == 1 + + result = results[0] + assert result is not None + assert result.role == 'assistant' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert 'Initial thought\nTask completed' in result.content[0].text