From ae13171194565f21ae10afe1e9d126173d8fdc1d Mon Sep 17 00:00:00 2001 From: Xingyao Wang Date: Mon, 28 Oct 2024 22:06:33 -0500 Subject: [PATCH] feat(agent): CodeAct with function calling (#4537) Signed-off-by: dependabot[bot] Co-authored-by: tobitege <10787084+tobitege@users.noreply.github.com> Co-authored-by: Engel Nyst Co-authored-by: tofarr Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- evaluation/integration_tests/run_infer.py | 10 +- evaluation/swe_bench/run_infer.py | 59 +-- ...onvert_oh_folder_to_swebench_submission.sh | 28 ++ evaluation/utils/shared.py | 2 +- .../agenthub/codeact_agent/action_parser.py | 17 + .../agenthub/codeact_agent/codeact_agent.py | 407 +++++++++++++----- .../codeact_agent/function_calling.py | 397 +++++++++++++++++ openhands/controller/agent_controller.py | 6 + openhands/core/config/agent_config.py | 8 + openhands/core/config/llm_config.py | 4 +- openhands/core/logger.py | 12 +- openhands/core/message.py | 59 ++- openhands/core/utils/json.py | 3 + openhands/events/action/commands.py | 3 + openhands/events/event.py | 12 + openhands/events/serialization/event.py | 18 +- openhands/events/tool.py | 11 + openhands/llm/llm.py | 101 +++-- openhands/runtime/action_execution_server.py | 9 +- openhands/runtime/base.py | 5 +- .../impl/eventstream/eventstream_runtime.py | 5 +- .../plugins/agent_skills/agentskills.py | 6 + .../agent_skills/file_editor/README.md | 3 + .../agent_skills/file_editor/__init__.py | 60 +++ .../plugins/agent_skills/file_editor/base.py | 50 +++ .../plugins/agent_skills/file_editor/impl.py | 279 ++++++++++++ .../plugins/agent_skills/file_editor/run.py | 44 ++ poetry.lock | 26 +- pyproject.toml | 4 +- tests/unit/test_agent_skill.py | 300 +++++++++++++ tests/unit/test_codeact_agent.py | 24 +- tests/unit/test_message_serialization.py | 10 +- tests/unit/test_prompt_caching.py | 84 +++- tests/unit/test_security.py | 3 +- 34 files changed, 1834 insertions(+), 235 deletions(-) create mode 100755 evaluation/swe_bench/scripts/eval/convert_oh_folder_to_swebench_submission.sh create mode 100644 openhands/agenthub/codeact_agent/function_calling.py create mode 100644 openhands/events/tool.py create mode 100644 openhands/runtime/plugins/agent_skills/file_editor/README.md create mode 100644 openhands/runtime/plugins/agent_skills/file_editor/__init__.py create mode 100644 openhands/runtime/plugins/agent_skills/file_editor/base.py create mode 100644 openhands/runtime/plugins/agent_skills/file_editor/impl.py create mode 100644 openhands/runtime/plugins/agent_skills/file_editor/run.py diff --git a/evaluation/integration_tests/run_infer.py b/evaluation/integration_tests/run_infer.py index a530041f92..621a0fa91c 100644 --- a/evaluation/integration_tests/run_infer.py +++ b/evaluation/integration_tests/run_infer.py @@ -16,6 +16,7 @@ from evaluation.utils.shared import ( ) from openhands.controller.state.state import State from openhands.core.config import ( + AgentConfig, AppConfig, SandboxConfig, get_llm_config_arg, @@ -24,6 +25,7 @@ from openhands.core.config import ( from openhands.core.logger import openhands_logger as logger from openhands.core.main import create_runtime, run_controller from openhands.events.action import MessageAction +from openhands.events.serialization.event import event_to_dict from openhands.runtime.base import Runtime from openhands.utils.async_utils import call_async_from_sync @@ -60,6 +62,12 @@ def get_config( f'{metadata.llm_config.log_completions_folder}' ) config.set_llm_config(metadata.llm_config) + agent_config = AgentConfig( + codeact_enable_jupyter=True, + codeact_enable_browsing_delegate=True, + codeact_enable_llm_editor=False, + ) + config.set_agent_config(agent_config) return config @@ -122,7 +130,7 @@ def process_instance( # # result evaluation # # ============================================= - histories = state.history.get_events() + histories = [event_to_dict(event) for event in state.history.get_events()] test_result: TestResult = test_class.verify_result(runtime, histories) metrics = state.metrics.get() if state.metrics else None diff --git a/evaluation/swe_bench/run_infer.py b/evaluation/swe_bench/run_infer.py index 9ac1e0cf66..8b8b45a463 100644 --- a/evaluation/swe_bench/run_infer.py +++ b/evaluation/swe_bench/run_infer.py @@ -23,6 +23,7 @@ from evaluation.utils.shared import ( ) from openhands.controller.state.state import State from openhands.core.config import ( + AgentConfig, AppConfig, SandboxConfig, get_llm_config_arg, @@ -45,11 +46,6 @@ AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = { 'CodeActSWEAgent': codeact_user_response, } -AGENT_CLS_TO_INST_SUFFIX = { - 'CodeActAgent': 'When you think you have fixed the issue through code changes, please run the following command: exit .\n', - 'CodeActSWEAgent': 'When you think you have fixed the issue through code changes, please run the following command: exit .\n', -} - def _get_swebench_workspace_dir_name(instance: pd.Series) -> str: return f'{instance.repo}__{instance.version}'.replace('/', '__') @@ -71,25 +67,27 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata): ) instruction += CODEACT_SWE_PROMPT.format(workspace_dir_name=workspace_dir_name) else: - # Testing general agents + # Instruction based on Anthropic's official trajectory + # https://github.com/eschluntz/swe-bench-experiments/tree/main/evaluation/verified/20241022_tools_claude-3-5-sonnet-updated/trajs instruction = ( - f'Please fix the following issue for the repository in /workspace/{workspace_dir_name}.\n' - 'Environment has been set up for you to start working. You may assume all necessary tools are installed.\n\n' - '# Problem Statement\n' - f'{instance.problem_statement}\n\n' + '\n' + f'/workspace/{workspace_dir_name}\n' + '\n' + f"I've uploaded a python code repository in the directory {workspace_dir_name}. Consider the following PR description:\n\n" + f'\n' + f'{instance.problem_statement}\n' + '\n\n' + 'Can you help me implement the necessary changes to the repository so that the requirements specified in the are met?\n' + "I've already taken care of all changes to any of the test files described in the . This means you DON'T have to modify the testing logic or any of the tests in any way!\n" + 'Your task is to make the minimal changes to non-tests files in the /repo directory to ensure the is satisfied.\n' + 'Follow these steps to resolve the issue:\n' + '1. As a first step, it might be a good idea to explore the repo to familiarize yourself with its structure.\n' + '2. Create a script to reproduce the error and execute it with `python ` using the BashTool, to confirm the error\n' + '3. Edit the sourcecode of the repo to resolve the issue\n' + '4. Rerun your reproduce script and confirm that the error is fixed!\n' + '5. Think about edgecases and make sure your fix handles them as well\n' + "Your thinking should be thorough and so it's fine if it's very long.\n" ) - if USE_HINT_TEXT and instance.hints_text: - instruction += f'# Hints\n{instance.hints_text}\n\n' - instruction += ( - 'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n' - 'You should NOT modify any existing test case files. You SHOULD add new test in a NEW file to reproduce the issue.\n' - 'You should verify that the issue is resolved and any new tests you create pass successfully.\n' - 'You should NEVER use web browsing or any other web-based tools.\n' - 'You should ALWAYS use the default Python interpreter available in the environment to run code related to the provided issue and/or repository.\n' - ) - - # NOTE: You can actually set slightly different instruction for different agents - instruction += AGENT_CLS_TO_INST_SUFFIX[metadata.agent_class] return instruction @@ -153,6 +151,12 @@ def get_config( f'{metadata.llm_config.log_completions_folder}' ) config.set_llm_config(metadata.llm_config) + agent_config = AgentConfig( + codeact_enable_jupyter=False, + codeact_enable_browsing_delegate=False, + codeact_enable_llm_editor=False, + ) + config.set_agent_config(agent_config) return config @@ -312,7 +316,7 @@ def complete_runtime( obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) assert_and_raise( - obs.exit_code == 0, + isinstance(obs, CmdOutputObservation) and obs.exit_code == 0, f'Failed to cd to /workspace/{workspace_dir_name}: {str(obs)}', ) @@ -322,7 +326,7 @@ def complete_runtime( obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) assert_and_raise( - obs.exit_code == 0, + isinstance(obs, CmdOutputObservation) and obs.exit_code == 0, f'Failed to git config --global core.pager "": {str(obs)}', ) @@ -331,7 +335,10 @@ def complete_runtime( logger.info(action, extra={'msg_type': 'ACTION'}) obs = runtime.run_action(action) logger.info(obs, extra={'msg_type': 'OBSERVATION'}) - assert_and_raise(obs.exit_code == 0, f'Failed to git add -A: {str(obs)}') + assert_and_raise( + isinstance(obs, CmdOutputObservation) and obs.exit_code == 0, + f'Failed to git add -A: {str(obs)}', + ) n_retries = 0 git_patch = None @@ -404,6 +411,7 @@ def process_instance( if ( state.last_error and 'fatal error during agent execution' in state.last_error + and 'stuck in a loop' not in state.last_error ): raise EvalException('Fatal error detected: ' + state.last_error) @@ -488,6 +496,7 @@ if __name__ == '__main__': llm_config = None if args.llm_config: llm_config = get_llm_config_arg(args.llm_config) + llm_config.log_completions = True if llm_config is None: raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}') diff --git a/evaluation/swe_bench/scripts/eval/convert_oh_folder_to_swebench_submission.sh b/evaluation/swe_bench/scripts/eval/convert_oh_folder_to_swebench_submission.sh new file mode 100755 index 0000000000..8bbaa6ddce --- /dev/null +++ b/evaluation/swe_bench/scripts/eval/convert_oh_folder_to_swebench_submission.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +FOLDER_PATH=$1 +NEW_FOLDER_PATH=${FOLDER_PATH}.swebench_submission +mkdir -p $NEW_FOLDER_PATH + +# Build all_preds.jsonl +poetry run python evaluation/swe_bench/scripts/eval/convert_oh_output_to_swe_json.py $FOLDER_PATH/output.jsonl +mv $FOLDER_PATH/output.swebench.jsonl $NEW_FOLDER_PATH/all_preds.jsonl + +# Build trajs/ +mkdir -p $NEW_FOLDER_PATH/trajs +for instance_dir in $FOLDER_PATH/llm_completions/*/; do + instance_id=$(basename "$instance_dir") + latest_json=$(ls -t "$instance_dir"/*.json | head -n1) + if [ -n "$latest_json" ]; then + cat "$latest_json" | jq -r '.messages' > "$NEW_FOLDER_PATH/trajs/$instance_id.json" + fi +done + +# Build logs/ +# check if $FOLDER_PATH/eval_outputs exists, if so copy over - else raise error +if [ -d "$FOLDER_PATH/eval_outputs" ]; then + cp -r $FOLDER_PATH/eval_outputs $NEW_FOLDER_PATH/logs +else + echo "Error: $FOLDER_PATH/eval_outputs does not exist. You should run the local docker eval_infer.sh first." + exit 1 +fi diff --git a/evaluation/utils/shared.py b/evaluation/utils/shared.py index d33658f339..b8d2ad281a 100644 --- a/evaluation/utils/shared.py +++ b/evaluation/utils/shared.py @@ -104,7 +104,7 @@ def codeact_user_response( ) msg = ( 'Please continue working on the task on whatever approach you think is suitable.\n' - 'If you think you have solved the task, please first send your answer to user through message and then exit .\n' + 'If you think you have solved the task, please first send your answer to user through message and then finish the interaction.\n' f'{encaps_str}' 'IMPORTANT: YOU SHOULD NEVER ASK FOR HUMAN HELP.\n' ) diff --git a/openhands/agenthub/codeact_agent/action_parser.py b/openhands/agenthub/codeact_agent/action_parser.py index bf926a6739..893effd825 100644 --- a/openhands/agenthub/codeact_agent/action_parser.py +++ b/openhands/agenthub/codeact_agent/action_parser.py @@ -64,6 +64,23 @@ class CodeActResponseParser(ResponseParser): return action_parser.parse(action_str) return self.default_parser.parse(action_str) + def action_to_str(self, action: Action) -> str: + if isinstance(action, CmdRunAction): + return ( + f'{action.thought}\n\n{action.command}\n' + ) + elif isinstance(action, IPythonRunCellAction): + return f'{action.thought}\n\n{action.code}\n' + elif isinstance(action, AgentDelegateAction): + return f'{action.thought}\n\n{action.inputs["task"]}\n' + elif isinstance(action, FileEditAction): + return f'{action.thought}\n\n{action.content}\n' + elif isinstance(action, MessageAction): + return action.content + elif isinstance(action, AgentFinishAction) and action.source == 'agent': + return action.thought + return '' + class CodeActActionParserFinish(ActionParser): """Parser action: diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py index c8342ca11f..997d424c51 100644 --- a/openhands/agenthub/codeact_agent/codeact_agent.py +++ b/openhands/agenthub/codeact_agent/codeact_agent.py @@ -1,10 +1,16 @@ +import json import os +from collections import deque from itertools import islice +from litellm import ModelResponse + +import openhands.agenthub.codeact_agent.function_calling as codeact_function_calling from openhands.agenthub.codeact_agent.action_parser import CodeActResponseParser from openhands.controller.agent import Agent from openhands.controller.state.state import State from openhands.core.config import AgentConfig +from openhands.core.logger import openhands_logger as logger from openhands.core.message import ImageContent, Message, TextContent from openhands.events.action import ( Action, @@ -36,7 +42,7 @@ from openhands.utils.prompt import PromptManager class CodeActAgent(Agent): - VERSION = '2.0' + VERSION = '2.1' """ The Code Act Agent is a minimalist agent. The agent works by passing the model a list of action-observation pairs and prompting the model to take the next step. @@ -63,8 +69,7 @@ class CodeActAgent(Agent): AgentSkillsRequirement(), JupyterRequirement(), ] - - action_parser = CodeActResponseParser() + obs_prefix = 'OBSERVATION:\n' def __init__( self, @@ -88,66 +93,165 @@ class CodeActAgent(Agent): if config.micro_agent_name else None ) - - self.prompt_manager = PromptManager( - prompt_dir=os.path.join(os.path.dirname(__file__)), - agent_skills_docs=AgentSkillsRequirement.documentation, - micro_agent=self.micro_agent, - ) - - def action_to_str(self, action: Action) -> str: - if isinstance(action, CmdRunAction): - return ( - f'{action.thought}\n\n{action.command}\n' + if ( + self.config.function_calling + and not self.llm.config.supports_function_calling + ): + logger.warning( + f'Function calling not supported for model {self.llm.config.model}. ' + 'Disabling function calling.' ) - elif isinstance(action, IPythonRunCellAction): - return f'{action.thought}\n\n{action.code}\n' - elif isinstance(action, AgentDelegateAction): - return f'{action.thought}\n\n{action.inputs["task"]}\n' - elif isinstance(action, FileEditAction): - return f'{action.thought}\n\n{action.content}\n' - elif isinstance(action, MessageAction): - return action.content - elif isinstance(action, AgentFinishAction) and action.source == 'agent': - return action.thought - return '' + self.config.function_calling = False - def get_action_message(self, action: Action) -> Message | None: + if self.config.function_calling: + # Function calling mode + self.tools = codeact_function_calling.get_tools( + codeact_enable_browsing_delegate=self.config.codeact_enable_browsing_delegate, + codeact_enable_jupyter=self.config.codeact_enable_jupyter, + codeact_enable_llm_editor=self.config.codeact_enable_llm_editor, + ) + logger.info( + f'TOOLS loaded for CodeActAgent: {json.dumps(self.tools, indent=2)}' + ) + self.system_prompt = codeact_function_calling.SYSTEM_PROMPT + self.initial_user_message = None + else: + # Non-function-calling mode + self.action_parser = CodeActResponseParser() + self.prompt_manager = PromptManager( + prompt_dir=os.path.join(os.path.dirname(__file__)), + agent_skills_docs=AgentSkillsRequirement.documentation, + micro_agent=self.micro_agent, + ) + self.system_prompt = self.prompt_manager.system_message + self.initial_user_message = self.prompt_manager.initial_user_message + + self.pending_actions: deque[Action] = deque() + + def get_action_message( + self, + action: Action, + pending_tool_call_action_messages: dict[str, Message], + ) -> list[Message]: + """Converts an action into a message format that can be sent to the LLM. + + This method handles different types of actions and formats them appropriately: + 1. For tool-based actions (AgentDelegate, CmdRun, IPythonRunCell, FileEdit) and agent-sourced AgentFinish: + - In function calling mode: Stores the LLM's response in pending_tool_call_action_messages + - In non-function calling mode: Creates a message with the action string + 2. For MessageActions: Creates a message with the text content and optional image content + + Args: + action (Action): The action to convert. Can be one of: + - AgentDelegateAction: For delegating tasks to other agents + - CmdRunAction: For executing bash commands + - IPythonRunCellAction: For running IPython code + - FileEditAction: For editing files + - AgentFinishAction: For ending the interaction + - MessageAction: For sending messages + pending_tool_call_action_messages (dict[str, Message]): Dictionary mapping response IDs + to their corresponding messages. Used in function calling mode to track tool calls + that are waiting for their results. + + Returns: + list[Message]: A list containing the formatted message(s) for the action. + May be empty if the action is handled as a tool call in function calling mode. + + Note: + In function calling mode, tool-based actions are stored in pending_tool_call_action_messages + rather than being returned immediately. They will be processed later when all corresponding + tool call results are available. + """ + # create a regular message from an event if isinstance( action, ( AgentDelegateAction, CmdRunAction, IPythonRunCellAction, - MessageAction, FileEditAction, ), ) or (isinstance(action, AgentFinishAction) and action.source == 'agent'): - content = [TextContent(text=self.action_to_str(action))] + if self.config.function_calling: + tool_metadata = action.tool_call_metadata + assert tool_metadata is not None, ( + 'Tool call metadata should NOT be None when function calling is enabled. Action: ' + + str(action) + ) - if ( - self.llm.vision_is_active() - and isinstance(action, MessageAction) - and action.images_urls - ): + llm_response: ModelResponse = tool_metadata.model_response + assistant_msg = llm_response.choices[0].message + # Add the LLM message (assistant) that initiated the tool calls + # (overwrites any previous message with the same response_id) + pending_tool_call_action_messages[llm_response.id] = Message( + role=assistant_msg.role, + # tool call content SHOULD BE a string + content=[TextContent(text=assistant_msg.content)] + if assistant_msg.content is not None + else [], + tool_calls=assistant_msg.tool_calls, + ) + return [] + else: + content = [TextContent(text=self.action_parser.action_to_str(action))] + return [ + Message( + role='user' if action.source == 'user' else 'assistant', + content=content, + ) + ] + elif isinstance(action, MessageAction): + role = 'user' if action.source == 'user' else 'assistant' + content = [TextContent(text=action.content)] + if self.llm.vision_is_active() and action.images_urls: content.append(ImageContent(image_urls=action.images_urls)) + return [ + Message( + role=role, + content=content, + ) + ] + return [] - return Message( - role='user' if action.source == 'user' else 'assistant', content=content - ) - return None + def get_observation_message( + self, + obs: Observation, + tool_call_id_to_message: dict[str, Message], + ) -> list[Message]: + """Converts an observation into a message format that can be sent to the LLM. - def get_observation_message(self, obs: Observation) -> Message | None: + This method handles different types of observations and formats them appropriately: + - CmdOutputObservation: Formats command execution results with exit codes + - IPythonRunCellObservation: Formats IPython cell execution results, replacing base64 images + - FileEditObservation: Formats file editing results + - AgentDelegateObservation: Formats results from delegated agent tasks + - ErrorObservation: Formats error messages from failed actions + - UserRejectObservation: Formats user rejection messages + + In function calling mode, observations with tool_call_metadata are stored in + tool_call_id_to_message for later processing instead of being returned immediately. + + Args: + obs (Observation): The observation to convert + tool_call_id_to_message (dict[str, Message]): Dictionary mapping tool call IDs + to their corresponding messages (used in function calling mode) + + Returns: + list[Message]: A list containing the formatted message(s) for the observation. + May be empty if the observation is handled as a tool response in function calling mode. + + Raises: + ValueError: If the observation type is unknown + """ + message: Message max_message_chars = self.llm.config.max_message_chars obs_prefix = 'OBSERVATION:\n' if isinstance(obs, CmdOutputObservation): text = obs_prefix + truncate_content( obs.content + obs.interpreter_details, max_message_chars ) - text += ( - f'\n[Command {obs.command_id} finished with exit code {obs.exit_code}]' - ) - return Message(role='user', content=[TextContent(text=text)]) + text += f'\n[Command finished with exit code {obs.exit_code}]' + message = Message(role='user', content=[TextContent(text=text)]) elif isinstance(obs, IPythonRunCellObservation): text = obs_prefix + obs.content # replace base64 images with a placeholder @@ -159,29 +263,45 @@ class CodeActAgent(Agent): ) text = '\n'.join(splitted) text = truncate_content(text, max_message_chars) - return Message(role='user', content=[TextContent(text=text)]) + message = Message(role='user', content=[TextContent(text=text)]) elif isinstance(obs, FileEditObservation): text = obs_prefix + truncate_content(str(obs), max_message_chars) - return Message(role='user', content=[TextContent(text=text)]) + message = Message(role='user', content=[TextContent(text=text)]) elif isinstance(obs, AgentDelegateObservation): text = obs_prefix + truncate_content( obs.outputs['content'] if 'content' in obs.outputs else '', max_message_chars, ) - return Message(role='user', content=[TextContent(text=text)]) + message = Message(role='user', content=[TextContent(text=text)]) elif isinstance(obs, ErrorObservation): text = obs_prefix + truncate_content(obs.content, max_message_chars) text += '\n[Error occurred in processing last action]' - return Message(role='user', content=[TextContent(text=text)]) + message = Message(role='user', content=[TextContent(text=text)]) elif isinstance(obs, UserRejectObservation): text = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars) text += '\n[Last action has been rejected by the user]' - return Message(role='user', content=[TextContent(text=text)]) + message = Message(role='user', content=[TextContent(text=text)]) else: # If an observation message is not returned, it will cause an error # when the LLM tries to return the next message raise ValueError(f'Unknown observation type: {type(obs)}') + if self.config.function_calling: + # Update the message as tool response properly + if (tool_call_metadata := obs.tool_call_metadata) is not None: + tool_call_id_to_message[tool_call_metadata.tool_call_id] = Message( + role='tool', + content=message.content, + tool_call_id=tool_call_metadata.tool_call_id, + name=tool_call_metadata.function_name, + ) + # No need to return the observation message + # because it will be added by get_action_message when all the corresponding + # tool calls in the SAME request are processed + return [] + + return [message] + def reset(self) -> None: """Resets the CodeAct Agent.""" super().reset() @@ -200,6 +320,10 @@ class CodeActAgent(Agent): - MessageAction(content) - Message action to run (e.g. ask for clarification) - AgentFinishAction() - end the interaction """ + # Continue with pending actions if any + if self.pending_actions: + return self.pending_actions.popleft() + # if we're done, go back latest_user_message = state.history.get_last_user_message() if latest_user_message and latest_user_message.strip() == '/exit': @@ -207,87 +331,172 @@ class CodeActAgent(Agent): # prepare what we want to send to the LLM messages = self._get_messages(state) - params = { + params: dict = { 'messages': self.llm.format_messages_for_llm(messages), - 'stop': [ + } + if self.config.function_calling: + params['tools'] = self.tools + else: + params['stop'] = [ '', '', '', '', - ], - } - + ] response = self.llm.completion(**params) - return self.action_parser.parse(response) + if self.config.function_calling: + actions = codeact_function_calling.response_to_actions(response) + for action in actions: + self.pending_actions.append(action) + return self.pending_actions.popleft() + else: + return self.action_parser.parse(response) def _get_messages(self, state: State) -> list[Message]: + """Constructs the message history for the LLM conversation. + + This method builds a structured conversation history by processing events from the state + and formatting them into messages that the LLM can understand. It handles both regular + message flow and function-calling scenarios. + + The method performs the following steps: + 1. Initializes with system prompt and optional initial user message + 2. Processes events (Actions and Observations) into messages + 3. Handles tool calls and their responses in function-calling mode + 4. Manages message role alternation (user/assistant/tool) + 5. Applies caching for specific LLM providers (e.g., Anthropic) + 6. Adds environment reminders for non-function-calling mode + + Args: + state (State): The current state object containing conversation history and other metadata + + Returns: + list[Message]: A list of formatted messages ready for LLM consumption, including: + - System message with prompt + - Initial user message (if configured) + - Action messages (from both user and assistant) + - Observation messages (including tool responses) + - Environment reminders (in non-function-calling mode) + + Note: + - In function-calling mode, tool calls and their responses are carefully tracked + to maintain proper conversation flow + - Messages from the same role are combined to prevent consecutive same-role messages + - For Anthropic models, specific messages are cached according to their documentation + """ messages: list[Message] = [ Message( role='system', content=[ TextContent( - text=self.prompt_manager.system_message, + text=self.system_prompt, cache_prompt=self.llm.is_caching_prompt_active(), # Cache system prompt ) ], - ), - Message( - role='user', - content=[ - TextContent( - text=self.prompt_manager.initial_user_message, - cache_prompt=self.llm.is_caching_prompt_active(), # if the user asks the same query, - ) - ], - ), + ) ] + if self.initial_user_message: + messages.append( + Message( + role='user', + content=[TextContent(text=self.initial_user_message)], + ) + ) - for event in state.history.get_events(): + pending_tool_call_action_messages: dict[str, Message] = {} + tool_call_id_to_message: dict[str, Message] = {} + events = list(state.history.get_events()) + for event in events: # create a regular message from an event if isinstance(event, Action): - message = self.get_action_message(event) + messages_to_add = self.get_action_message( + action=event, + pending_tool_call_action_messages=pending_tool_call_action_messages, + ) elif isinstance(event, Observation): - message = self.get_observation_message(event) + messages_to_add = self.get_observation_message( + obs=event, + tool_call_id_to_message=tool_call_id_to_message, + ) else: raise ValueError(f'Unknown event type: {type(event)}') - # add regular message - if message: - # handle error if the message is the SAME role as the previous message - # litellm.exceptions.BadRequestError: litellm.BadRequestError: OpenAIException - Error code: 400 - {'detail': 'Only supports u/a/u/a/u...'} - # there shouldn't be two consecutive messages from the same role - if messages and messages[-1].role == message.role: - messages[-1].content.extend(message.content) - else: - messages.append(message) + # Check pending tool call action messages and see if they are complete + _response_ids_to_remove = [] + for ( + response_id, + pending_message, + ) in pending_tool_call_action_messages.items(): + assert pending_message.tool_calls is not None, ( + 'Tool calls should NOT be None when function calling is enabled & the message is considered pending tool call. ' + f'Pending message: {pending_message}' + ) + if all( + tool_call.id in tool_call_id_to_message + for tool_call in pending_message.tool_calls + ): + # If complete: + # -- 1. Add the message that **initiated** the tool calls + messages_to_add.append(pending_message) + # -- 2. Add the tool calls **results*** + for tool_call in pending_message.tool_calls: + messages_to_add.append(tool_call_id_to_message[tool_call.id]) + tool_call_id_to_message.pop(tool_call.id) + _response_ids_to_remove.append(response_id) + # Cleanup the processed pending tool messages + for response_id in _response_ids_to_remove: + pending_tool_call_action_messages.pop(response_id) + + for message in messages_to_add: + # add regular message + if message: + # handle error if the message is the SAME role as the previous message + # litellm.exceptions.BadRequestError: litellm.BadRequestError: OpenAIException - Error code: 400 - {'detail': 'Only supports u/a/u/a/u...'} + # there shouldn't be two consecutive messages from the same role + # NOTE: we shouldn't combine tool messages because each of them has a different tool_call_id + if ( + messages + and messages[-1].role == message.role + and message.role != 'tool' + ): + messages[-1].content.extend(message.content) + else: + messages.append(message) - # Add caching to the last 2 user messages if self.llm.is_caching_prompt_active(): - user_turns_processed = 0 + # NOTE: this is only needed for anthropic + # following logic here: + # https://github.com/anthropics/anthropic-quickstarts/blob/8f734fd08c425c6ec91ddd613af04ff87d70c5a0/computer-use-demo/computer_use_demo/loop.py#L241-L262 + breakpoints_remaining = 3 # remaining 1 for system/tool for message in reversed(messages): - if message.role == 'user' and user_turns_processed < 2: - message.content[ - -1 - ].cache_prompt = True # Last item inside the message content - user_turns_processed += 1 + if message.role == 'user' or message.role == 'tool': + if breakpoints_remaining > 0: + message.content[ + -1 + ].cache_prompt = True # Last item inside the message content + breakpoints_remaining -= 1 + else: + break - # The latest user message is important: - # we want to remind the agent of the environment constraints - latest_user_message = next( - islice( - ( - m - for m in reversed(messages) - if m.role == 'user' - and any(isinstance(c, TextContent) for c in m.content) + if not self.config.function_calling: + # The latest user message is important: + # we want to remind the agent of the environment constraints + latest_user_message = next( + islice( + ( + m + for m in reversed(messages) + if m.role == 'user' + and any(isinstance(c, TextContent) for c in m.content) + ), + 1, ), - 1, - ), - None, - ) - if latest_user_message: - reminder_text = f'\n\nENVIRONMENT REMINDER: You have {state.max_iterations - state.iteration} turns left to complete the task. When finished reply with .' - latest_user_message.content.append(TextContent(text=reminder_text)) + None, + ) + # do not add this for function calling + if latest_user_message: + reminder_text = f'\n\nENVIRONMENT REMINDER: You have {state.max_iterations - state.iteration} turns left to complete the task. When finished reply with .' + latest_user_message.content.append(TextContent(text=reminder_text)) return messages diff --git a/openhands/agenthub/codeact_agent/function_calling.py b/openhands/agenthub/codeact_agent/function_calling.py new file mode 100644 index 0000000000..f5519124ac --- /dev/null +++ b/openhands/agenthub/codeact_agent/function_calling.py @@ -0,0 +1,397 @@ +"""This file contains the function calling implementation for different actions. + +This is similar to the functionality of `CodeActResponseParser`. +""" + +import json + +from litellm import ( + ChatCompletionToolParam, + ChatCompletionToolParamFunctionChunk, + ModelResponse, +) + +from openhands.core.logger import openhands_logger as logger +from openhands.events.action import ( + Action, + AgentDelegateAction, + AgentFinishAction, + CmdRunAction, + FileEditAction, + IPythonRunCellAction, + MessageAction, +) +from openhands.events.tool import ToolCallMetadata + +SYSTEM_PROMPT = """You are a helpful assistant that can interact with a computer to solve tasks. + +* If user provides a path, you should NOT assume it's relative to the current working directory. Instead, you should explore the file system to find the file before working on it. + +""" + +_BASH_DESCRIPTION = """Execute a bash command in the terminal. +* Long running commands: For commands that may run indefinitely, it should be run in the background and the output should be redirected to a file, e.g. command = `python3 app.py > server.log 2>&1 &`. +* Interactive: If a bash command returns exit code `-1`, this means the process is not yet finished. The assistant must then send a second call to terminal with an empty `command` (which will retrieve any additional logs), or it can send additional text (set `command` to the text) to STDIN of the running process, or it can send command=`ctrl+c` to interrupt the process. +* Timeout: If a command execution result says "Command timed out. Sending SIGINT to the process", the assistant should retry running the command in the background. +""" + +CmdRunTool = ChatCompletionToolParam( + type='function', + function=ChatCompletionToolParamFunctionChunk( + name='execute_bash', + description=_BASH_DESCRIPTION, + parameters={ + 'type': 'object', + 'properties': { + 'command': { + 'type': 'string', + 'description': 'The bash command to execute. Can be empty to view additional logs when previous exit code is `-1`. Can be `ctrl+c` to interrupt the currently running process.', + }, + }, + 'required': ['command'], + }, + ), +) + +_IPYTHON_DESCRIPTION = """Run a cell of Python code in an IPython environment. +* The assistant should define variables and import packages before using them. +* The variable defined in the IPython environment will not be available outside the IPython environment (e.g., in terminal). +""" +# We are not using agentskills's file_ops for viewing files now because StrReplaceEditorTool already supports viewing files +# """* Apart from the standard Python library, the assistant can also use the following functions (already imported): +# {AgentSkillsRequirement.documentation}""" + +IPythonTool = ChatCompletionToolParam( + type='function', + function=ChatCompletionToolParamFunctionChunk( + name='execute_ipython_cell', + description=_IPYTHON_DESCRIPTION, + parameters={ + 'type': 'object', + 'properties': { + 'code': { + 'type': 'string', + 'description': 'The Python code to execute. Supports magic commands like %pip.', + }, + }, + 'required': ['code'], + }, + ), +) + +_FILE_EDIT_DESCRIPTION = """Edit a file. +* The assistant can edit files by specifying the file path and providing a draft of the new file content. +* The draft content doesn't need to be exactly the same as the existing file; the assistant may skip unchanged lines using comments like `# unchanged` to indicate unchanged sections. +* IMPORTANT: For large files (e.g., > 300 lines), specify the range of lines to edit using `start` and `end` (1-indexed, inclusive). The range should be smaller than 300 lines. +* To append to a file, set both `start` and `end` to `-1`. +* If the file doesn't exist, a new file will be created with the provided content. + +**Example 1: general edit for short files** +For example, given an existing file `/path/to/file.py` that looks like this: +(this is the end of the file) +1|class MyClass: +2| def __init__(self): +3| self.x = 1 +4| self.y = 2 +5| self.z = 3 +6| +7|print(MyClass().z) +8|print(MyClass().x) +(this is the end of the file) + +The assistant wants to edit the file to look like this: +(this is the end of the file) +1|class MyClass: +2| def __init__(self): +3| self.x = 1 +4| self.y = 2 +5| +6|print(MyClass().y) +(this is the end of the file) + +The assistant may produce an edit action like this: +path="/path/to/file.txt" start=1 end=-1 +content=``` +class MyClass: + def __init__(self): + # no changes before + self.y = 2 + # self.z is removed + +# MyClass().z is removed +print(MyClass().y) +``` + +**Example 2: append to file for short files** +For example, given an existing file `/path/to/file.py` that looks like this: +(this is the end of the file) +1|class MyClass: +2| def __init__(self): +3| self.x = 1 +4| self.y = 2 +5| self.z = 3 +6| +7|print(MyClass().z) +8|print(MyClass().x) +(this is the end of the file) + +To append the following lines to the file: +```python +print(MyClass().y) +``` + +The assistant may produce an edit action like this: +path="/path/to/file.txt" start=-1 end=-1 +content=``` +print(MyClass().y) +``` + +**Example 3: edit for long files** + +Given an existing file `/path/to/file.py` that looks like this: +(1000 more lines above) +1001|class MyClass: +1002| def __init__(self): +1003| self.x = 1 +1004| self.y = 2 +1005| self.z = 3 +1006| +1007|print(MyClass().z) +1008|print(MyClass().x) +(2000 more lines below) + +The assistant wants to edit the file to look like this: + +(1000 more lines above) +1001|class MyClass: +1002| def __init__(self): +1003| self.x = 1 +1004| self.y = 2 +1005| +1006|print(MyClass().y) +(2000 more lines below) + +The assistant may produce an edit action like this: +path="/path/to/file.txt" start=1001 end=1008 +content=``` +class MyClass: + def __init__(self): + # no changes before + self.y = 2 + # self.z is removed + +# MyClass().z is removed +print(MyClass().y) +``` +""" + +LLMBasedFileEditTool = ChatCompletionToolParam( + type='function', + function=ChatCompletionToolParamFunctionChunk( + name='edit_file', + description=_FILE_EDIT_DESCRIPTION, + parameters={ + 'type': 'object', + 'properties': { + 'path': { + 'type': 'string', + 'description': 'The absolute path to the file to be edited.', + }, + 'new_content_draft': { + 'type': 'string', + 'description': 'A draft of the new content for the file being edited. Note that the assistant may skip unchanged lines.', + }, + 'start': { + 'type': 'integer', + 'description': 'The starting line number for the edit (1-indexed, inclusive). Default is 1.', + }, + 'end': { + 'type': 'integer', + 'description': 'The ending line number for the edit (1-indexed, inclusive). Default is -1 (end of file).', + }, + }, + 'required': ['path', 'content'], + }, + ), +) + +_STR_REPLACE_EDITOR_DESCRIPTION = """Custom editing tool for viewing, creating and editing files +* State is persistent across command calls and discussions with the user +* If `path` is a file, `view` displays the result of applying `cat -n`. If `path` is a directory, `view` lists non-hidden files and directories up to 2 levels deep +* The `create` command cannot be used if the specified `path` already exists as a file +* If a `command` generates a long output, it will be truncated and marked with `` +* The `undo_edit` command will revert the last edit made to the file at `path` + +Notes for using the `str_replace` command: +* The `old_str` parameter should match EXACTLY one or more consecutive lines from the original file. Be mindful of whitespaces! +* If the `old_str` parameter is not unique in the file, the replacement will not be performed. Make sure to include enough context in `old_str` to make it unique +* The `new_str` parameter should contain the edited lines that should replace the `old_str` +""" + +StrReplaceEditorTool = ChatCompletionToolParam( + type='function', + function=ChatCompletionToolParamFunctionChunk( + name='str_replace_editor', + description=_STR_REPLACE_EDITOR_DESCRIPTION, + parameters={ + 'type': 'object', + 'properties': { + 'command': { + 'description': 'The commands to run. Allowed options are: `view`, `create`, `str_replace`, `insert`, `undo_edit`.', + 'enum': ['view', 'create', 'str_replace', 'insert', 'undo_edit'], + 'type': 'string', + }, + 'path': { + 'description': 'Absolute path to file or directory, e.g. `/repo/file.py` or `/repo`.', + 'type': 'string', + }, + 'file_text': { + 'description': 'Required parameter of `create` command, with the content of the file to be created.', + 'type': 'string', + }, + 'old_str': { + 'description': 'Required parameter of `str_replace` command containing the string in `path` to replace.', + 'type': 'string', + }, + 'new_str': { + 'description': 'Optional parameter of `str_replace` command containing the new string (if not given, no string will be added). Required parameter of `insert` command containing the string to insert.', + 'type': 'string', + }, + 'insert_line': { + 'description': 'Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`.', + 'type': 'integer', + }, + 'view_range': { + 'description': 'Optional parameter of `view` command when `path` points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting `[start_line, -1]` shows all lines from `start_line` to the end of the file.', + 'items': {'type': 'integer'}, + 'type': 'array', + }, + }, + 'required': ['command', 'path'], + }, + ), +) + +_BROWSER_DELEGATION = """Delegate the task to another browsing agent. +The assistant should delegate the task if it needs to browse the Internet. +""" + +BrowserDelegationTool = ChatCompletionToolParam( + type='function', + function=ChatCompletionToolParamFunctionChunk( + name='delegate_to_browsing_agent', + description=_BROWSER_DELEGATION, + parameters={ + 'type': 'object', + 'properties': { + 'task': { + 'type': 'string', + 'description': 'The task for the browsing agent to execute. It should include all the necessary context and specify what information the browsing agent should return.', + }, + }, + 'required': ['task'], + }, + ), +) + +_FINISH_DESCRIPTION = """Finish the interaction when the task is complete OR if the assistant cannot proceed further with the task.""" + +FinishTool = ChatCompletionToolParam( + type='function', + function=ChatCompletionToolParamFunctionChunk( + name='finish', + description=_FINISH_DESCRIPTION, + ), +) + + +def combine_thought(action: Action, thought: str) -> Action: + if not hasattr(action, 'thought'): + return action + if thought: + action.thought = thought + return action + + +def response_to_actions(response: ModelResponse) -> list[Action]: + actions: list[Action] = [] + assert len(response.choices) == 1, 'Only one choice is supported for now' + assistant_msg = response.choices[0].message + if assistant_msg.tool_calls: + # Check if there's assistant_msg.content. If so, add it to the thought + thought = '' + if isinstance(assistant_msg.content, str): + thought = assistant_msg.content + elif isinstance(assistant_msg.content, list): + for msg in assistant_msg.content: + if msg['type'] == 'text': + thought += msg['text'] + + # Process each tool call to OpenHands action + for i, tool_call in enumerate(assistant_msg.tool_calls): + action: Action + try: + arguments = json.loads(tool_call.function.arguments) + except json.decoder.JSONDecodeError as e: + raise RuntimeError( + f'Failed to parse tool call arguments: {tool_call.function.arguments}' + ) from e + if tool_call.function.name == 'execute_bash': + action = CmdRunAction(**arguments) + elif tool_call.function.name == 'execute_ipython_cell': + action = IPythonRunCellAction(**arguments) + elif tool_call.function.name == 'delegate_to_browsing_agent': + action = AgentDelegateAction( + agent='BrowsingAgent', + inputs=arguments, + ) + elif tool_call.function.name == 'finish': + action = AgentFinishAction() + elif tool_call.function.name == 'edit_file': + action = FileEditAction(**arguments) + elif tool_call.function.name == 'str_replace_editor': + # We implement this in agent_skills, which can be used via Jupyter + # convert tool_call.function.arguments to kwargs that can be passed to file_editor + code = f'print(file_editor(**{arguments}))' + logger.debug( + f'TOOL CALL: str_replace_editor -> file_editor with code: {code}' + ) + action = IPythonRunCellAction(code=code, include_extra=False) + else: + raise RuntimeError(f'Unknown tool call: {tool_call.function.name}') + + # We only add thought to the first action + if i == 0: + action = combine_thought(action, thought) + # Add metadata for tool calling + action.tool_call_metadata = ToolCallMetadata( + tool_call_id=tool_call.id, + function_name=tool_call.function.name, + model_response=response, + total_calls_in_response=len(assistant_msg.tool_calls), + ) + actions.append(action) + else: + actions.append( + MessageAction(content=assistant_msg.content, wait_for_response=True) + ) + + assert len(actions) >= 1 + return actions + + +def get_tools( + codeact_enable_browsing_delegate: bool = False, + codeact_enable_llm_editor: bool = False, + codeact_enable_jupyter: bool = False, +) -> list[ChatCompletionToolParam]: + tools = [CmdRunTool, FinishTool] + if codeact_enable_browsing_delegate: + tools.append(BrowserDelegationTool) + if codeact_enable_jupyter: + tools.append(IPythonTool) + if codeact_enable_llm_editor: + tools.append(LLMBasedFileEditTool) + else: + tools.append(StrReplaceEditorTool) + return tools diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index 0660cd0309..7b7522e33a 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -453,6 +453,12 @@ class AgentController: # and send the underlying exception to the LLM for self-correction await self.report_error(str(e)) return + # FIXME: more graceful handling of litellm.exceptions.ContextWindowExceededError + # e.g. try to condense the memory and try again + except litellm.exceptions.ContextWindowExceededError as e: + self.state.last_error = str(e) + await self.set_agent_state_to(AgentState.ERROR) + return if action.runnable: if self.state.confirmation_mode and ( diff --git a/openhands/core/config/agent_config.py b/openhands/core/config/agent_config.py index 839d09277e..db6225be8c 100644 --- a/openhands/core/config/agent_config.py +++ b/openhands/core/config/agent_config.py @@ -8,12 +8,20 @@ class AgentConfig: """Configuration for the agent. Attributes: + function_calling: Whether function calling is enabled. Default is True. + codeact_enable_browsing_delegate: Whether browsing delegate is enabled in the action space. Default is False. Only works with function calling. + codeact_enable_llm_editor: Whether LLM editor is enabled in the action space. Default is False. Only works with function calling. + codeact_enable_jupyter: Whether Jupyter is enabled in the action space. Default is False. micro_agent_name: The name of the micro agent to use for this agent. memory_enabled: Whether long-term memory (embeddings) is enabled. memory_max_threads: The maximum number of threads indexing at the same time for embeddings. llm_config: The name of the llm config to use. If specified, this will override global llm config. """ + function_calling: bool = True + codeact_enable_browsing_delegate: bool = False + codeact_enable_llm_editor: bool = False + codeact_enable_jupyter: bool = False micro_agent_name: str | None = None memory_enabled: bool = False memory_max_threads: int = 3 diff --git a/openhands/core/config/llm_config.py b/openhands/core/config/llm_config.py index ac07b70e0b..7ad6476d7c 100644 --- a/openhands/core/config/llm_config.py +++ b/openhands/core/config/llm_config.py @@ -42,6 +42,7 @@ class LLMConfig: log_completions: Whether to log LLM completions to the state. log_completions_folder: The folder to log LLM completions to. Required if log_completions is True. draft_editor: A more efficient LLM to use for file editing. Introduced in [PR 3985](https://github.com/All-Hands-AI/OpenHands/pull/3985). + supports_function_calling: Whether the model supports function calling. """ model: str = 'gpt-4o' @@ -61,7 +62,7 @@ class LLMConfig: retry_min_wait: int = 15 retry_max_wait: int = 120 timeout: int | None = None - max_message_chars: int = 10_000 # maximum number of characters in an observation's content when sent to the llm + max_message_chars: int = 30_000 # maximum number of characters in an observation's content when sent to the llm temperature: float = 0.0 top_p: float = 1.0 custom_llm_provider: str | None = None @@ -76,6 +77,7 @@ class LLMConfig: log_completions: bool = False log_completions_folder: str | None = None draft_editor: Optional['LLMConfig'] = None + supports_function_calling: bool = False def defaults_to_dict(self) -> dict: """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional.""" diff --git a/openhands/core/logger.py b/openhands/core/logger.py index 13b91e451e..af3ba1e35f 100644 --- a/openhands/core/logger.py +++ b/openhands/core/logger.py @@ -12,6 +12,9 @@ LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO').upper() DEBUG = os.getenv('DEBUG', 'False').lower() in ['true', '1', 'yes'] if DEBUG: LOG_LEVEL = 'DEBUG' + import litellm + + litellm.set_verbose = True LOG_TO_FILE = os.getenv('LOG_TO_FILE', 'False').lower() in ['true', '1', 'yes'] DISABLE_COLOR_PRINTING = False @@ -70,11 +73,6 @@ class ColoredFormatter(logging.Formatter): return super().format(record) -console_formatter = ColoredFormatter( - '\033[92m%(asctime)s - %(name)s:%(levelname)s\033[0m: %(filename)s:%(lineno)s - %(message)s', - datefmt='%H:%M:%S', -) - file_formatter = logging.Formatter( '%(asctime)s - %(name)s:%(levelname)s: %(filename)s:%(lineno)s - %(message)s', datefmt='%H:%M:%S', @@ -123,10 +121,10 @@ def get_console_handler(log_level=logging.INFO, extra_info: str | None = None): """Returns a console handler for logging.""" console_handler = logging.StreamHandler() console_handler.setLevel(log_level) - formatter_str = '%(asctime)s - %(levelname)s - %(message)s' + formatter_str = '\033[92m%(asctime)s - %(name)s:%(levelname)s\033[0m: %(filename)s:%(lineno)s - %(message)s' if extra_info: formatter_str = f'{extra_info} - ' + formatter_str - console_handler.setFormatter(logging.Formatter(formatter_str)) + console_handler.setFormatter(ColoredFormatter(formatter_str, datefmt='%H:%M:%S')) return console_handler diff --git a/openhands/core/message.py b/openhands/core/message.py index 57fadabde7..042028498b 100644 --- a/openhands/core/message.py +++ b/openhands/core/message.py @@ -1,6 +1,7 @@ from enum import Enum from typing import Literal +from litellm import ChatCompletionMessageToolCall from pydantic import BaseModel, Field, model_serializer @@ -48,10 +49,16 @@ class ImageContent(Content): class Message(BaseModel): - role: Literal['user', 'system', 'assistant'] - content: list[TextContent | ImageContent] = Field(default=list) + role: Literal['user', 'system', 'assistant', 'tool'] + content: list[TextContent | ImageContent] = Field(default_factory=list) cache_enabled: bool = False vision_enabled: bool = False + # function calling + # - tool calls (from LLM) + tool_calls: list[ChatCompletionMessageToolCall] | None = None + # - tool execution result (to LLM) + tool_call_id: str | None = None + name: str | None = None # name of the tool @property def contains_image(self) -> bool: @@ -59,23 +66,31 @@ class Message(BaseModel): @model_serializer def serialize_model(self) -> dict: - content: list[dict] | str - # two kinds of serializer: - # 1. vision serializer: when prompt caching or vision is enabled - # 2. single text serializer: for other cases - # remove this when liteLLM or providers support this format translation - if self.cache_enabled or self.vision_enabled: - # when prompt caching or vision is enabled, use vision serializer - content = [] - for item in self.content: - if isinstance(item, TextContent): - content.append(item.model_dump()) - elif isinstance(item, ImageContent): - content.extend(item.model_dump()) - else: - # for other cases, concatenate all text content - # into a single string per message - content = '\n'.join( - item.text for item in self.content if isinstance(item, TextContent) - ) - return {'content': content, 'role': self.role} + content: list[dict] = [] + role_tool_with_prompt_caching = False + for item in self.content: + d = item.model_dump() + # We have to remove cache_prompt for tool content and move it up to the message level + # See discussion here for details: https://github.com/BerriAI/litellm/issues/6422#issuecomment-2438765472 + if self.role == 'tool' and item.cache_prompt: + role_tool_with_prompt_caching = True + d.pop('cache_control') + if isinstance(item, TextContent): + content.append(d) + elif isinstance(item, ImageContent) and self.vision_enabled: + content.extend(d) + + ret: dict = {'content': content, 'role': self.role} + + if role_tool_with_prompt_caching: + ret['cache_control'] = {'type': 'ephemeral'} + + if self.tool_call_id is not None: + assert ( + self.name is not None + ), 'name is required when tool_call_id is not None' + ret['tool_call_id'] = self.tool_call_id + ret['name'] = self.name + if self.tool_calls: + ret['tool_calls'] = self.tool_calls + return ret diff --git a/openhands/core/utils/json.py b/openhands/core/utils/json.py index 859cab1450..c0b22740be 100644 --- a/openhands/core/utils/json.py +++ b/openhands/core/utils/json.py @@ -2,6 +2,7 @@ import json from datetime import datetime from json_repair import repair_json +from litellm.types.utils import ModelResponse from openhands.core.exceptions import LLMResponseError from openhands.events.event import Event @@ -17,6 +18,8 @@ def my_default_encoder(obj): return event_to_dict(obj) if isinstance(obj, Metrics): return obj.get() + if isinstance(obj, ModelResponse): + return obj.model_dump() return json.JSONEncoder().default(obj) diff --git a/openhands/events/action/commands.py b/openhands/events/action/commands.py index 5a7e1fef52..83dd19f9d1 100644 --- a/openhands/events/action/commands.py +++ b/openhands/events/action/commands.py @@ -47,6 +47,9 @@ class CmdRunAction(Action): class IPythonRunCellAction(Action): code: str thought: str = '' + include_extra: bool = ( + True # whether to include CWD & Python interpreter in the output + ) action: str = ActionType.RUN_IPYTHON runnable: ClassVar[bool] = True confirmation_state: ActionConfirmationStatus = ActionConfirmationStatus.CONFIRMED diff --git a/openhands/events/event.py b/openhands/events/event.py index 43a8840671..6ec68acc55 100644 --- a/openhands/events/event.py +++ b/openhands/events/event.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum +from openhands.events.tool import ToolCallMetadata from openhands.llm.metrics import Metrics @@ -71,3 +72,14 @@ class Event: @llm_metrics.setter def llm_metrics(self, value: Metrics) -> None: self._llm_metrics = value + + # optional field + @property + def tool_call_metadata(self) -> ToolCallMetadata | None: + if hasattr(self, '_tool_call_metadata'): + return self._tool_call_metadata # type: ignore[attr-defined] + return None + + @tool_call_metadata.setter + def tool_call_metadata(self, value: ToolCallMetadata) -> None: + self._tool_call_metadata = value diff --git a/openhands/events/serialization/event.py b/openhands/events/serialization/event.py index ee15ab6955..f0430b4cdd 100644 --- a/openhands/events/serialization/event.py +++ b/openhands/events/serialization/event.py @@ -6,10 +6,20 @@ from openhands.events.observation.observation import Observation from openhands.events.serialization.action import action_from_dict from openhands.events.serialization.observation import observation_from_dict from openhands.events.serialization.utils import remove_fields +from openhands.events.tool import ToolCallMetadata # TODO: move `content` into `extras` -TOP_KEYS = ['id', 'timestamp', 'source', 'message', 'cause', 'action', 'observation'] -UNDERSCORE_KEYS = ['id', 'timestamp', 'source', 'cause'] +TOP_KEYS = [ + 'id', + 'timestamp', + 'source', + 'message', + 'cause', + 'action', + 'observation', + 'tool_call_metadata', +] +UNDERSCORE_KEYS = ['id', 'timestamp', 'source', 'cause', 'tool_call_metadata'] DELETE_FROM_TRAJECTORY_EXTRAS = { 'screenshot', @@ -40,6 +50,8 @@ def event_from_dict(data) -> 'Event': value = value.isoformat() if key == 'source': value = EventSource(value) + if key == 'tool_call_metadata': + value = ToolCallMetadata(**value) setattr(evt, '_' + key, value) return evt @@ -59,6 +71,8 @@ def event_to_dict(event: 'Event') -> dict: d['timestamp'] = d['timestamp'].isoformat() if key == 'source' and 'source' in d: d['source'] = d['source'].value + if key == 'tool_call_metadata' and 'tool_call_metadata' in d: + d['tool_call_metadata'] = d['tool_call_metadata'].model_dump() props.pop(key, None) if 'security_risk' in props and props['security_risk'] is None: props.pop('security_risk') diff --git a/openhands/events/tool.py b/openhands/events/tool.py new file mode 100644 index 0000000000..30e288dc2f --- /dev/null +++ b/openhands/events/tool.py @@ -0,0 +1,11 @@ +from litellm import ModelResponse +from pydantic import BaseModel + + +class ToolCallMetadata(BaseModel): + # See https://docs.litellm.ai/docs/completion/function_call#step-3---second-litellmcompletion-call + function_name: str # Name of the function that was called + tool_call_id: str # ID of the tool call + + model_response: ModelResponse + total_calls_in_response: int diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py index 9eb3a08aa9..1696fd13b9 100644 --- a/openhands/llm/llm.py +++ b/openhands/llm/llm.py @@ -1,11 +1,12 @@ import copy -import json import os import time import warnings from functools import partial from typing import Any +import requests + from openhands.core.config import LLMConfig with warnings.catch_warnings(): @@ -81,16 +82,47 @@ class LLM(RetryMixin, DebugMixin): # litellm actually uses base Exception here for unknown model self.model_info: ModelInfo | None = None - try: - if self.config.model.startswith('openrouter'): - self.model_info = litellm.get_model_info(self.config.model) - else: + + if self.config.model.startswith('openrouter'): + self.model_info = litellm.get_model_info(self.config.model) + elif self.config.model.startswith('litellm_proxy/'): + # IF we are using LiteLLM proxy, get model info from LiteLLM proxy + # GET {base_url}/v1/model/info with litellm_model_id as path param + response = requests.get( + f'{self.config.base_url}/v1/model/info', + headers={'Authorization': f'Bearer {self.config.api_key}'}, + ) + all_model_info = response.json()['data'] + current_model_info = next( + ( + info + for info in all_model_info + if info['model_name'] + == self.config.model.removeprefix('litellm_proxy/') + ), + None, + ) + if current_model_info: + self.model_info = current_model_info['model_info'] + + # Last two attempts to get model info from NAME + if not self.model_info: + try: self.model_info = litellm.get_model_info( self.config.model.split(':')[0] ) - # noinspection PyBroadException - except Exception as e: - logger.warning(f'Could not get model info for {config.model}:\n{e}') + # noinspection PyBroadException + except Exception: + pass + if not self.model_info: + try: + self.model_info = litellm.get_model_info( + self.config.model.split('/')[-1] + ) + # noinspection PyBroadException + except Exception: + pass + logger.info(f'Model info: {self.model_info}') if self.config.log_completions: if self.config.log_completions_folder is None: @@ -126,6 +158,11 @@ class LLM(RetryMixin, DebugMixin): ): self.config.max_output_tokens = self.model_info['max_tokens'] + self.config.supports_function_calling = ( + self.model_info is not None + and self.model_info.get('supports_function_calling', False) + ) + self._completion = partial( litellm_completion, model=self.config.model, @@ -144,6 +181,8 @@ class LLM(RetryMixin, DebugMixin): logger.debug('LLM: model has vision enabled') if self.is_caching_prompt_active(): logger.debug('LLM: caching prompt enabled') + if self.config.supports_function_calling: + logger.debug('LLM: model supports function calling') completion_unwrapped = self._completion @@ -195,26 +234,32 @@ class LLM(RetryMixin, DebugMixin): try: # we don't support streaming here, thus we get a ModelResponse resp: ModelResponse = completion_unwrapped(*args, **kwargs) - # log for evals or other scripts that need the raw completion if self.config.log_completions: assert self.config.log_completions_folder is not None log_file = os.path.join( self.config.log_completions_folder, # use the metric model name (for draft editor) - f'{self.metrics.model_name}-{time.time()}.json', + f'{self.metrics.model_name.replace("/", "__")}-{time.time()}.json', ) + from openhands.core.utils import json + with open(log_file, 'w') as f: - json.dump( - { - 'messages': messages, - 'response': resp, - 'args': args, - 'kwargs': kwargs, - 'timestamp': time.time(), - 'cost': self._completion_cost(resp), - }, - f, + f.write( + json.dumps( + { + 'messages': messages, + 'response': resp, + 'args': args, + 'kwargs': { + k: v + for k, v in kwargs.items() + if k != 'messages' + }, + 'timestamp': time.time(), + 'cost': self._completion_cost(resp), + }, + ) ) message_back: str = resp['choices'][0]['message']['content'] @@ -390,16 +435,18 @@ class LLM(RetryMixin, DebugMixin): logger.info(f'Using custom cost per token: {cost_per_token}') extra_kwargs['custom_cost_per_token'] = cost_per_token - if not self._is_local(): - try: + try: + # try directly get response_cost from response + cost = getattr(response, '_hidden_params', {}).get('response_cost', None) + if cost is None: cost = litellm_completion_cost( completion_response=response, **extra_kwargs ) - self.metrics.add_cost(cost) - return cost - except Exception: - self.cost_metric_supported = False - logger.warning('Cost calculation not supported for this model.') + self.metrics.add_cost(cost) + return cost + except Exception: + self.cost_metric_supported = False + logger.warning('Cost calculation not supported for this model.') return 0.0 def __str__(self): diff --git a/openhands/runtime/action_execution_server.py b/openhands/runtime/action_execution_server.py index 2da1372532..1cec1b7083 100644 --- a/openhands/runtime/action_execution_server.py +++ b/openhands/runtime/action_execution_server.py @@ -194,10 +194,11 @@ class ActionExecutor: obs: IPythonRunCellObservation = await _jupyter_plugin.run(action) obs.content = obs.content.rstrip() - obs.content += ( - f'\n[Jupyter current working directory: {self.bash_session.pwd}]' - ) - obs.content += f'\n[Jupyter Python interpreter: {_jupyter_plugin.python_interpreter_path}]' + if action.include_extra: + obs.content += ( + f'\n[Jupyter current working directory: {self.bash_session.pwd}]' + ) + obs.content += f'\n[Jupyter Python interpreter: {_jupyter_plugin.python_interpreter_path}]' return obs else: raise RuntimeError( diff --git a/openhands/runtime/base.py b/openhands/runtime/base.py index 0d752d71be..81e0366665 100644 --- a/openhands/runtime/base.py +++ b/openhands/runtime/base.py @@ -127,8 +127,11 @@ class Runtime(FileEditRuntimeMixin): if event.timeout is None: event.timeout = self.config.sandbox.timeout assert event.timeout is not None - observation = await call_sync_from_async(self.run_action, event) + observation: Observation = await call_sync_from_async( + self.run_action, event + ) observation._cause = event.id # type: ignore[attr-defined] + observation.tool_call_metadata = event.tool_call_metadata source = event.source if event.source else EventSource.AGENT await self.event_stream.async_add_event(observation, source) # type: ignore[arg-type] diff --git a/openhands/runtime/impl/eventstream/eventstream_runtime.py b/openhands/runtime/impl/eventstream/eventstream_runtime.py index 1cf4096922..5121e9ede9 100644 --- a/openhands/runtime/impl/eventstream/eventstream_runtime.py +++ b/openhands/runtime/impl/eventstream/eventstream_runtime.py @@ -1,6 +1,7 @@ import os import tempfile import threading +import traceback from functools import lru_cache from typing import Callable from zipfile import ZipFile @@ -496,7 +497,9 @@ class EventStreamRuntime(Runtime): ) except Exception as e: logger.error(f'Error during action execution: {e}') - obs = FatalErrorObservation(f'Action execution failed: {str(e)}') + obs = FatalErrorObservation( + f'Action execution failed: {str(e)}.\n{traceback.format_exc()}' + ) self._refresh_logs() return obs diff --git a/openhands/runtime/plugins/agent_skills/agentskills.py b/openhands/runtime/plugins/agent_skills/agentskills.py index dd34e3878d..046f8af20c 100644 --- a/openhands/runtime/plugins/agent_skills/agentskills.py +++ b/openhands/runtime/plugins/agent_skills/agentskills.py @@ -23,3 +23,9 @@ for func_name in __all__: fn_signature = f'{func.__name__}' + str(signature(func)) DOCUMENTATION += f'{fn_signature}:\n{cur_doc}\n\n' + + +# Add file_editor (a function) +from openhands.runtime.plugins.agent_skills.file_editor import file_editor # noqa: E402 + +__all__ += ['file_editor'] diff --git a/openhands/runtime/plugins/agent_skills/file_editor/README.md b/openhands/runtime/plugins/agent_skills/file_editor/README.md new file mode 100644 index 0000000000..37c6c2818a --- /dev/null +++ b/openhands/runtime/plugins/agent_skills/file_editor/README.md @@ -0,0 +1,3 @@ +# File Editor + +This file editor is largely based on Anthorpic released [`str_replace_editor`](https://github.com/anthropics/anthropic-quickstarts/tree/main/computer-use-demo/computer_use_demo/tools/edit.py). The original code was released under [MIT license](https://github.com/anthropics/anthropic-quickstarts/blob/e373524f07594d48c3f9563248ea282a4c306c0c/LICENSE). diff --git a/openhands/runtime/plugins/agent_skills/file_editor/__init__.py b/openhands/runtime/plugins/agent_skills/file_editor/__init__.py new file mode 100644 index 0000000000..f6d3eb39a0 --- /dev/null +++ b/openhands/runtime/plugins/agent_skills/file_editor/__init__.py @@ -0,0 +1,60 @@ +"""This file contains a global singleton of the `EditTool` class as well as raw functions that expose its __call__.""" + +from .base import CLIResult, ToolError, ToolResult +from .impl import Command, EditTool + +_GLOBAL_EDITOR = EditTool() + + +def _make_api_tool_result( + result: ToolResult, +) -> str: + """Convert an agent ToolResult to an API ToolResultBlockParam.""" + tool_result_content: str = '' + is_error = False + if result.error: + is_error = True + tool_result_content = _maybe_prepend_system_tool_result(result, result.error) + else: + assert result.output, 'Expecting output in file_editor' + tool_result_content = _maybe_prepend_system_tool_result(result, result.output) + assert ( + not result.base64_image + ), 'Not expecting base64_image as output in file_editor' + if is_error: + return f'ERROR:\n{tool_result_content}' + else: + return tool_result_content + + +def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str) -> str: + if result.system: + result_text = f'{result.system}\n{result_text}' + return result_text + + +def file_editor( + command: Command, + path: str, + file_text: str | None = None, + view_range: list[int] | None = None, + old_str: str | None = None, + new_str: str | None = None, + insert_line: int | None = None, +) -> str: + try: + result: CLIResult = _GLOBAL_EDITOR( + command=command, + path=path, + file_text=file_text, + view_range=view_range, + old_str=old_str, + new_str=new_str, + insert_line=insert_line, + ) + except ToolError as e: + return _make_api_tool_result(ToolResult(error=e.message)) + return _make_api_tool_result(result) + + +__all__ = ['file_editor'] diff --git a/openhands/runtime/plugins/agent_skills/file_editor/base.py b/openhands/runtime/plugins/agent_skills/file_editor/base.py new file mode 100644 index 0000000000..6ad2a4b5b6 --- /dev/null +++ b/openhands/runtime/plugins/agent_skills/file_editor/base.py @@ -0,0 +1,50 @@ +from dataclasses import dataclass, fields, replace + + +@dataclass(kw_only=True, frozen=True) +class ToolResult: + """Represents the result of a tool execution.""" + + output: str | None = None + error: str | None = None + base64_image: str | None = None + system: str | None = None + + def __bool__(self): + return any(getattr(self, field.name) for field in fields(self)) + + def __add__(self, other: 'ToolResult'): + def combine_fields( + field: str | None, other_field: str | None, concatenate: bool = True + ): + if field and other_field: + if concatenate: + return field + other_field + raise ValueError('Cannot combine tool results') + return field or other_field + + return ToolResult( + output=combine_fields(self.output, other.output), + error=combine_fields(self.error, other.error), + base64_image=combine_fields(self.base64_image, other.base64_image, False), + system=combine_fields(self.system, other.system), + ) + + def replace(self, **kwargs): + """Returns a new ToolResult with the given fields replaced.""" + return replace(self, **kwargs) + + +class CLIResult(ToolResult): + """A ToolResult that can be rendered as a CLI output.""" + + +class ToolFailure(ToolResult): + """A ToolResult that represents a failure.""" + + +class ToolError(Exception): + """Raised when a tool encounters an error.""" + + def __init__(self, message): + self.message = message diff --git a/openhands/runtime/plugins/agent_skills/file_editor/impl.py b/openhands/runtime/plugins/agent_skills/file_editor/impl.py new file mode 100644 index 0000000000..e0944ab659 --- /dev/null +++ b/openhands/runtime/plugins/agent_skills/file_editor/impl.py @@ -0,0 +1,279 @@ +from collections import defaultdict +from pathlib import Path +from typing import Literal, get_args + +from .base import CLIResult, ToolError, ToolResult +from .run import maybe_truncate, run + +Command = Literal[ + 'view', + 'create', + 'str_replace', + 'insert', + 'undo_edit', +] +SNIPPET_LINES: int = 4 + + +class EditTool: + """ + An filesystem editor tool that allows the agent to view, create, and edit files. + The tool parameters are defined by Anthropic and are not editable. + + Original implementation: https://github.com/anthropics/anthropic-quickstarts/blob/main/computer-use-demo/computer_use_demo/tools/edit.py + """ + + _file_history: dict[Path, list[str]] + + def __init__(self): + self._file_history = defaultdict(list) + super().__init__() + + def __call__( + self, + *, + command: Command, + path: str, + file_text: str | None = None, + view_range: list[int] | None = None, + old_str: str | None = None, + new_str: str | None = None, + insert_line: int | None = None, + **kwargs, + ): + _path = Path(path) + self.validate_path(command, _path) + if command == 'view': + return self.view(_path, view_range) + elif command == 'create': + if not file_text: + raise ToolError('Parameter `file_text` is required for command: create') + self.write_file(_path, file_text) + self._file_history[_path].append(file_text) + return ToolResult(output=f'File created successfully at: {_path}') + elif command == 'str_replace': + if not old_str: + raise ToolError( + 'Parameter `old_str` is required for command: str_replace' + ) + return self.str_replace(_path, old_str, new_str) + elif command == 'insert': + if insert_line is None: + raise ToolError( + 'Parameter `insert_line` is required for command: insert' + ) + if not new_str: + raise ToolError('Parameter `new_str` is required for command: insert') + return self.insert(_path, insert_line, new_str) + elif command == 'undo_edit': + return self.undo_edit(_path) + raise ToolError( + f'Unrecognized command {command}. The allowed commands for the {self.name} tool are: {", ".join(get_args(Command))}' + ) + + def validate_path(self, command: str, path: Path): + """ + Check that the path/command combination is valid. + """ + # Check if its an absolute path + if not path.is_absolute(): + suggested_path = Path('') / path + raise ToolError( + f'The path {path} is not an absolute path, it should start with `/`. Maybe you meant {suggested_path}?' + ) + # Check if path exists + if not path.exists() and command != 'create': + raise ToolError( + f'The path {path} does not exist. Please provide a valid path.' + ) + if path.exists() and command == 'create': + raise ToolError( + f'File already exists at: {path}. Cannot overwrite files using command `create`.' + ) + # Check if the path points to a directory + if path.is_dir(): + if command != 'view': + raise ToolError( + f'The path {path} is a directory and only the `view` command can be used on directories' + ) + + def view(self, path: Path, view_range: list[int] | None = None): + """Implement the view command""" + if path.is_dir(): + if view_range: + raise ToolError( + 'The `view_range` parameter is not allowed when `path` points to a directory.' + ) + + _, stdout, stderr = run(rf"find {path} -maxdepth 2 -not -path '*/\.*'") + if not stderr: + stdout = f"Here's the files and directories up to 2 levels deep in {path}, excluding hidden items:\n{stdout}\n" + return CLIResult(output=stdout, error=stderr) + + file_content = self.read_file(path) + init_line = 1 + if view_range: + if len(view_range) != 2 or not all(isinstance(i, int) for i in view_range): + raise ToolError( + 'Invalid `view_range`. It should be a list of two integers.' + ) + file_lines = file_content.split('\n') + n_lines_file = len(file_lines) + init_line, final_line = view_range + if init_line < 1 or init_line > n_lines_file: + raise ToolError( + f"Invalid `view_range`: {view_range}. It's first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}" + ) + if final_line > n_lines_file: + raise ToolError( + f"Invalid `view_range`: {view_range}. It's second element `{final_line}` should be smaller than the number of lines in the file: `{n_lines_file}`" + ) + if final_line != -1 and final_line < init_line: + raise ToolError( + f"Invalid `view_range`: {view_range}. It's second element `{final_line}` should be larger or equal than its first `{init_line}`" + ) + + if final_line == -1: + file_content = '\n'.join(file_lines[init_line - 1 :]) + else: + file_content = '\n'.join(file_lines[init_line - 1 : final_line]) + + return CLIResult( + output=self._make_output(file_content, str(path), init_line=init_line) + ) + + def str_replace(self, path: Path, old_str: str, new_str: str | None): + """Implement the str_replace command, which replaces old_str with new_str in the file content""" + # Read the file content + file_content = self.read_file(path).expandtabs() + old_str = old_str.expandtabs() + new_str = new_str.expandtabs() if new_str is not None else '' + + # Check if old_str is unique in the file + occurrences = file_content.count(old_str) + if occurrences == 0: + raise ToolError( + f'No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}.' + ) + elif occurrences > 1: + file_content_lines = file_content.split('\n') + lines = [ + idx + 1 + for idx, line in enumerate(file_content_lines) + if old_str in line + ] + raise ToolError( + f'No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique' + ) + + # Replace old_str with new_str + new_file_content = file_content.replace(old_str, new_str) + + # Write the new content to the file + self.write_file(path, new_file_content) + + # Save the content to history + self._file_history[path].append(file_content) + + # Create a snippet of the edited section + replacement_line = file_content.split(old_str)[0].count('\n') + start_line = max(0, replacement_line - SNIPPET_LINES) + end_line = replacement_line + SNIPPET_LINES + new_str.count('\n') + snippet = '\n'.join(new_file_content.split('\n')[start_line : end_line + 1]) + + # Prepare the success message + success_msg = f'The file {path} has been edited. ' + success_msg += self._make_output( + snippet, f'a snippet of {path}', start_line + 1 + ) + success_msg += 'Review the changes and make sure they are as expected. Edit the file again if necessary.' + + return CLIResult(output=success_msg) + + def insert(self, path: Path, insert_line: int, new_str: str): + """Implement the insert command, which inserts new_str at the specified line in the file content.""" + file_text = self.read_file(path).expandtabs() + new_str = new_str.expandtabs() + file_text_lines = file_text.split('\n') + n_lines_file = len(file_text_lines) + + if insert_line < 0 or insert_line > n_lines_file: + raise ToolError( + f'Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}' + ) + + new_str_lines = new_str.split('\n') + new_file_text_lines = ( + file_text_lines[:insert_line] + + new_str_lines + + file_text_lines[insert_line:] + ) + snippet_lines = ( + file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line] + + new_str_lines + + file_text_lines[insert_line : insert_line + SNIPPET_LINES] + ) + + new_file_text = '\n'.join(new_file_text_lines) + snippet = '\n'.join(snippet_lines) + + self.write_file(path, new_file_text) + self._file_history[path].append(file_text) + + success_msg = f'The file {path} has been edited. ' + success_msg += self._make_output( + snippet, + 'a snippet of the edited file', + max(1, insert_line - SNIPPET_LINES + 1), + ) + success_msg += 'Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary.' + return CLIResult(output=success_msg) + + def undo_edit(self, path: Path): + """Implement the undo_edit command.""" + if not self._file_history[path]: + raise ToolError(f'No edit history found for {path}.') + + old_text = self._file_history[path].pop() + self.write_file(path, old_text) + + return CLIResult( + output=f'Last edit to {path} undone successfully. {self._make_output(old_text, str(path))}' + ) + + def read_file(self, path: Path): + """Read the content of a file from a given path; raise a ToolError if an error occurs.""" + try: + return path.read_text() + except Exception as e: + raise ToolError(f'Ran into {e} while trying to read {path}') from None + + def write_file(self, path: Path, file: str): + """Write the content of a file to a given path; raise a ToolError if an error occurs.""" + try: + path.write_text(file) + except Exception as e: + raise ToolError(f'Ran into {e} while trying to write to {path}') from None + + def _make_output( + self, + file_content: str, + file_descriptor: str, + init_line: int = 1, + expand_tabs: bool = True, + ): + """Generate output for the CLI based on the content of a file.""" + file_content = maybe_truncate(file_content) + if expand_tabs: + file_content = file_content.expandtabs() + file_content = '\n'.join( + [ + f'{i + init_line:6}\t{line}' + for i, line in enumerate(file_content.split('\n')) + ] + ) + return ( + f"Here's the result of running `cat -n` on {file_descriptor}:\n" + + file_content + + '\n' + ) diff --git a/openhands/runtime/plugins/agent_skills/file_editor/run.py b/openhands/runtime/plugins/agent_skills/file_editor/run.py new file mode 100644 index 0000000000..29c604256f --- /dev/null +++ b/openhands/runtime/plugins/agent_skills/file_editor/run.py @@ -0,0 +1,44 @@ +"""Utility to run shell commands asynchronously with a timeout.""" + +import subprocess +import time + +TRUNCATED_MESSAGE: str = 'To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for.' +MAX_RESPONSE_LEN: int = 16000 + + +def maybe_truncate(content: str, truncate_after: int | None = MAX_RESPONSE_LEN): + """Truncate content and append a notice if content exceeds the specified length.""" + return ( + content + if not truncate_after or len(content) <= truncate_after + else content[:truncate_after] + TRUNCATED_MESSAGE + ) + + +def run( + cmd: str, + timeout: float | None = 120.0, # seconds + truncate_after: int | None = MAX_RESPONSE_LEN, +): + """Run a shell command synchronously with a timeout.""" + start_time = time.time() + + try: + process = subprocess.Popen( + cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True + ) + + stdout, stderr = process.communicate(timeout=timeout) + + return ( + process.returncode or 0, + maybe_truncate(stdout, truncate_after=truncate_after), + maybe_truncate(stderr, truncate_after=truncate_after), + ) + except subprocess.TimeoutExpired: + process.kill() + elapsed_time = time.time() - start_time + raise TimeoutError( + f"Command '{cmd}' timed out after {elapsed_time:.2f} seconds" + ) diff --git a/poetry.lock b/poetry.lock index 20316ac24f..6ef4a43322 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "aenum" @@ -3901,22 +3901,20 @@ name = "litellm" version = "1.50.4" description = "Library to easily interface with LLM API providers" optional = false -python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8" -files = [ - {file = "litellm-1.50.4-py3-none-any.whl", hash = "sha256:cc6992275e24a0bbb4a3b377e6842d45a8510fc85d7f255930a64bb872980a36"}, - {file = "litellm-1.50.4.tar.gz", hash = "sha256:a7e68ef614f631b58969c2c7a5154a565ba5974558d437c8cd6c8623654880ea"}, -] +python-versions = ">=3.8.1,<4.0, !=3.9.7" +files = [] +develop = false [package.dependencies] aiohttp = "*" click = "*" importlib-metadata = ">=6.8.0" -jinja2 = ">=3.1.2,<4.0.0" -jsonschema = ">=4.22.0,<5.0.0" +jinja2 = "^3.1.2" +jsonschema = "^4.22.0" openai = ">=1.52.0" -pydantic = ">=2.0.0,<3.0.0" +pydantic = "^2.0.0" python-dotenv = ">=0.2.0" -requests = ">=2.31.0,<3.0.0" +requests = "^2.31.0" tiktoken = ">=0.7.0" tokenizers = "*" @@ -3924,6 +3922,12 @@ tokenizers = "*" extra-proxy = ["azure-identity (>=1.15.0,<2.0.0)", "azure-keyvault-secrets (>=4.8.0,<5.0.0)", "google-cloud-kms (>=2.21.3,<3.0.0)", "prisma (==0.11.0)", "resend (>=0.8.0,<0.9.0)"] proxy = ["PyJWT (>=2.8.0,<3.0.0)", "apscheduler (>=3.10.4,<4.0.0)", "backoff", "cryptography (>=42.0.5,<43.0.0)", "fastapi (>=0.111.0,<0.112.0)", "fastapi-sso (>=0.10.0,<0.11.0)", "gunicorn (>=22.0.0,<23.0.0)", "orjson (>=3.9.7,<4.0.0)", "pynacl (>=1.5.0,<2.0.0)", "python-multipart (>=0.0.9,<0.0.10)", "pyyaml (>=6.0.1,<7.0.0)", "rq", "uvicorn (>=0.22.0,<0.23.0)"] +[package.source] +type = "git" +url = "https://github.com/BerriAI/litellm.git" +reference = "58fe6610601297c5a90367fd4583469e2df3fcf9" +resolved_reference = "58fe6610601297c5a90367fd4583469e2df3fcf9" + [[package]] name = "llama-index" version = "0.10.45.post1" @@ -10095,4 +10099,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"] [metadata] lock-version = "2.0" python-versions = "^3.12" -content-hash = "aeb09e429a789c3f8ced605e7e1a5932fd6cce7f7f4ce30a960da77fba18b9a3" +content-hash = "880b0251ec1ac83a7a8f6b1637b0860f75553d1f9c7c67313af9f3bb686c166a" diff --git a/pyproject.toml b/pyproject.toml index 393bab9594..a7daebd37d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ packages = [ python = "^3.12" datasets = "*" pandas = "*" -litellm = "*" +litellm = { git = "https://github.com/BerriAI/litellm.git", rev = "58fe6610601297c5a90367fd4583469e2df3fcf9" } google-generativeai = "*" # To use litellm with Gemini Pro API termcolor = "*" seaborn = "*" @@ -89,6 +89,7 @@ reportlab = "*" [tool.coverage.run] concurrency = ["gevent"] + [tool.poetry.group.runtime.dependencies] jupyterlab = "*" notebook = "*" @@ -119,6 +120,7 @@ ignore = ["D1"] [tool.ruff.lint.pydocstyle] convention = "google" + [tool.poetry.group.evaluation.dependencies] streamlit = "*" whatthepatch = "*" diff --git a/tests/unit/test_agent_skill.py b/tests/unit/test_agent_skill.py index 63745f4dd2..f619bc00bf 100644 --- a/tests/unit/test_agent_skill.py +++ b/tests/unit/test_agent_skill.py @@ -5,6 +5,7 @@ import sys import docx import pytest +from openhands.runtime.plugins.agent_skills.agentskills import file_editor from openhands.runtime.plugins.agent_skills.file_ops.file_ops import ( WINDOW, _print_window, @@ -715,3 +716,302 @@ def test_parse_pptx(tmp_path): 'Hello, this is the second test PPTX slide.\n\n' ) assert output == expected_output, f'Expected output does not match. Got: {output}' + + +# ============================================================================= + + +def test_file_editor_view(tmp_path): + # generate a random directory + random_dir = tmp_path / 'dir_1' + random_dir.mkdir() + # create a file in the directory + random_file = random_dir / 'a.txt' + random_file.write_text('Line 1\nLine 2\nLine 3\nLine 4\nLine 5') + random_dir_2 = tmp_path / 'dir_2' + random_dir_2.mkdir() + random_file_2 = random_dir_2 / 'b.txt' + random_file_2.write_text('Line 1\nLine 2\nLine 3\nLine 4\nLine 5') + + from openhands.runtime.plugins.agent_skills.agentskills import file_editor + + # view the file + result = file_editor(command='view', path=str(random_file)) + print('\n', result) + assert result is not None + assert ( + result.split('\n') + == f"""Here's the result of running `cat -n` on {random_file}: + 1\tLine 1 + 2\tLine 2 + 3\tLine 3 + 4\tLine 4 + 5\tLine 5 +""".split('\n') + ) + + # view the directory + result = file_editor(command='view', path=str(tmp_path)) + print('\n', result) + assert result is not None + assert ( + result.strip().split('\n') + == f"""Here's the files and directories up to 2 levels deep in {tmp_path}, excluding hidden items: +{tmp_path} +{tmp_path}/dir_2 +{tmp_path}/dir_2/b.txt +{tmp_path}/dir_1 +{tmp_path}/dir_1/a.txt +""".strip().split('\n') + ) + + +def test_file_editor_create(tmp_path): + # generate a random directory + random_dir = tmp_path / 'dir_1' + random_dir.mkdir() + # create a file in the directory + random_file = random_dir / 'a.txt' + + from openhands.runtime.plugins.agent_skills.agentskills import file_editor + + # view an unexist file + result = file_editor(command='view', path=str(random_file)) + print(result) + assert result is not None + assert ( + result + == f'ERROR:\nThe path {random_file} does not exist. Please provide a valid path.' + ) + + # create a file + result = file_editor(command='create', path=str(random_file), file_text='Line 6') + print(result) + assert result is not None + assert result == f'File created successfully at: {random_file}' + + # view again + result = file_editor(command='view', path=str(random_file)) + print(result) + assert result is not None + assert ( + result.strip().split('\n') + == f"""Here's the result of running `cat -n` on {random_file}: + 1\tLine 6 +""".strip().split('\n') + ) + + +@pytest.fixture +def setup_file(tmp_path): + random_dir = tmp_path / 'dir_1' + random_dir.mkdir() + random_file = random_dir / 'a.txt' + return random_file + + +def test_file_editor_create_and_view(setup_file): + random_file = setup_file + + # Test create command + result = file_editor( + command='create', path=str(random_file), file_text='Line 1\nLine 2\nLine 3' + ) + print(result) + assert result == f'File created successfully at: {random_file}' + + # Test view command for file + result = file_editor(command='view', path=str(random_file)) + print(result) + assert ( + result.strip().split('\n') + == f"""Here's the result of running `cat -n` on {random_file}: + 1\tLine 1 + 2\tLine 2 + 3\tLine 3 +""".strip().split('\n') + ) + + # Test view command for directory + result = file_editor(command='view', path=str(random_file.parent)) + assert f'{random_file.parent}' in result + assert f'{random_file.name}' in result + + +def test_file_editor_view_nonexistent(setup_file): + random_file = setup_file + + # Test view command for non-existent file + result = file_editor(command='view', path=str(random_file)) + assert ( + result + == f'ERROR:\nThe path {random_file} does not exist. Please provide a valid path.' + ) + + +def test_file_editor_str_replace(setup_file): + random_file = setup_file + file_editor( + command='create', path=str(random_file), file_text='Line 1\nLine 2\nLine 3' + ) + + # Test str_replace command + result = file_editor( + command='str_replace', + path=str(random_file), + old_str='Line 2', + new_str='New Line 2', + ) + print(result) + assert ( + result + == f"""The file {random_file} has been edited. Here's the result of running `cat -n` on a snippet of {random_file}: + 1\tLine 1 + 2\tNew Line 2 + 3\tLine 3 +Review the changes and make sure they are as expected. Edit the file again if necessary.""" + ) + + # View the file after str_replace + result = file_editor(command='view', path=str(random_file)) + print(result) + assert ( + result.strip().split('\n') + == f"""Here's the result of running `cat -n` on {random_file}: + 1\tLine 1 + 2\tNew Line 2 + 3\tLine 3 +""".strip().split('\n') + ) + + +def test_file_editor_str_replace_non_existent(setup_file): + random_file = setup_file + file_editor( + command='create', path=str(random_file), file_text='Line 1\nLine 2\nLine 3' + ) + + # Test str_replace with non-existent string + result = file_editor( + command='str_replace', + path=str(random_file), + old_str='Non-existent Line', + new_str='New Line', + ) + print(result) + assert ( + result + == f'ERROR:\nNo replacement was performed, old_str `Non-existent Line` did not appear verbatim in {random_file}.' + ) + + +def test_file_editor_insert(setup_file): + random_file = setup_file + file_editor( + command='create', path=str(random_file), file_text='Line 1\nLine 2\nLine 3' + ) + + # Test insert command + result = file_editor( + command='insert', path=str(random_file), insert_line=2, new_str='Inserted Line' + ) + print(result) + assert ( + result + == f"""The file {random_file} has been edited. Here's the result of running `cat -n` on a snippet of the edited file: + 1\tLine 1 + 2\tLine 2 + 3\tInserted Line + 4\tLine 3 +Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary.""" + ) + + # View the file after insert + result = file_editor(command='view', path=str(random_file)) + assert ( + result.strip().split('\n') + == f"""Here's the result of running `cat -n` on {random_file}: + 1\tLine 1 + 2\tLine 2 + 3\tInserted Line + 4\tLine 3 +""".strip().split('\n') + ) + + +def test_file_editor_insert_invalid_line(setup_file): + random_file = setup_file + file_editor( + command='create', path=str(random_file), file_text='Line 1\nLine 2\nLine 3' + ) + + # Test insert with invalid line number + result = file_editor( + command='insert', + path=str(random_file), + insert_line=10, + new_str='Invalid Insert', + ) + assert ( + result + == 'ERROR:\nInvalid `insert_line` parameter: 10. It should be within the range of lines of the file: [0, 3]' + ) + + +def test_file_editor_undo_edit(setup_file): + random_file = setup_file + result = file_editor( + command='create', path=str(random_file), file_text='Line 1\nLine 2\nLine 3' + ) + print(result) + assert result == f"""File created successfully at: {random_file}""" + + # Make an edit + result = file_editor( + command='str_replace', + path=str(random_file), + old_str='Line 2', + new_str='New Line 2', + ) + print(result) + assert ( + result + == f"""The file {random_file} has been edited. Here's the result of running `cat -n` on a snippet of {random_file}: + 1\tLine 1 + 2\tNew Line 2 + 3\tLine 3 +Review the changes and make sure they are as expected. Edit the file again if necessary.""" + ) + + # Test undo_edit command + result = file_editor(command='undo_edit', path=str(random_file)) + print(result) + assert ( + result + == f"""Last edit to {random_file} undone successfully. Here's the result of running `cat -n` on {random_file}: + 1\tLine 1 + 2\tLine 2 + 3\tLine 3 +""" + ) + + # View the file after undo_edit + result = file_editor(command='view', path=str(random_file)) + assert ( + result.strip().split('\n') + == f"""Here's the result of running `cat -n` on {random_file}: + 1\tLine 1 + 2\tLine 2 + 3\tLine 3 +""".strip().split('\n') + ) + + +def test_file_editor_undo_edit_no_edits(tmp_path): + random_file = tmp_path / 'a.txt' + random_file.touch() + + # Test undo_edit when no edits have been made + result = file_editor(command='undo_edit', path=str(random_file)) + print(result) + assert result == f'ERROR:\nNo edit history found for {random_file}.' diff --git a/tests/unit/test_codeact_agent.py b/tests/unit/test_codeact_agent.py index 55dfa3feb7..9e3dda6c2c 100644 --- a/tests/unit/test_codeact_agent.py +++ b/tests/unit/test_codeact_agent.py @@ -24,29 +24,35 @@ def agent() -> CodeActAgent: def test_cmd_output_observation_message(agent: CodeActAgent): + agent.config.function_calling = False obs = CmdOutputObservation( command='echo hello', content='Command output', command_id=1, exit_code=0 ) - result = agent.get_observation_message(obs) + results = agent.get_observation_message(obs, tool_call_id_to_message={}) + assert len(results) == 1 + result = results[0] assert result is not None assert result.role == 'user' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) assert 'OBSERVATION:' in result.content[0].text assert 'Command output' in result.content[0].text - assert 'Command 1 finished with exit code 0' in result.content[0].text + assert 'Command finished with exit code 0' in result.content[0].text def test_ipython_run_cell_observation_message(agent: CodeActAgent): + agent.config.function_calling = False obs = IPythonRunCellObservation( code='plt.plot()', content='IPython output\n![image](data:image/png;base64,ABC123)', ) - result = agent.get_observation_message(obs) + results = agent.get_observation_message(obs, tool_call_id_to_message={}) + assert len(results) == 1 + result = results[0] assert result is not None assert result.role == 'user' assert len(result.content) == 1 @@ -61,12 +67,15 @@ def test_ipython_run_cell_observation_message(agent: CodeActAgent): def test_agent_delegate_observation_message(agent: CodeActAgent): + agent.config.function_calling = False obs = AgentDelegateObservation( content='Content', outputs={'content': 'Delegated agent output'} ) - result = agent.get_observation_message(obs) + results = agent.get_observation_message(obs, tool_call_id_to_message={}) + assert len(results) == 1 + result = results[0] assert result is not None assert result.role == 'user' assert len(result.content) == 1 @@ -76,10 +85,13 @@ def test_agent_delegate_observation_message(agent: CodeActAgent): def test_error_observation_message(agent: CodeActAgent): + agent.config.function_calling = False obs = ErrorObservation('Error message') - result = agent.get_observation_message(obs) + results = agent.get_observation_message(obs, tool_call_id_to_message={}) + assert len(results) == 1 + result = results[0] assert result is not None assert result.role == 'user' assert len(result.content) == 1 @@ -93,4 +105,4 @@ def test_unknown_observation_message(agent: CodeActAgent): obs = Mock() with pytest.raises(ValueError, match='Unknown observation type:'): - agent.get_observation_message(obs) + agent.get_observation_message(obs, tool_call_id_to_message={}) diff --git a/tests/unit/test_message_serialization.py b/tests/unit/test_message_serialization.py index 16035f0aed..aac1f8d014 100644 --- a/tests/unit/test_message_serialization.py +++ b/tests/unit/test_message_serialization.py @@ -78,7 +78,10 @@ def test_message_with_only_text_content_and_vision_disabled(): expected_serialized_message = { 'role': 'user', - 'content': 'This is a text message\nThis is another text message', + 'content': [ + {'type': 'text', 'text': 'This is a text message'}, + {'type': 'text', 'text': 'This is another text message'}, + ], } assert serialized_message == expected_serialized_message @@ -107,7 +110,10 @@ def test_message_with_mixed_content_and_vision_disabled(): # Expected serialization ignores images and concatenates text expected_serialized_message = { 'role': 'user', - 'content': 'This is a text message\nThis is another text message', + 'content': [ + {'type': 'text', 'text': 'This is a text message'}, + {'type': 'text', 'text': 'This is another text message'}, + ], } # Assert serialized message matches expectation diff --git a/tests/unit/test_prompt_caching.py b/tests/unit/test_prompt_caching.py index 41dc756187..d038f18d9c 100644 --- a/tests/unit/test_prompt_caching.py +++ b/tests/unit/test_prompt_caching.py @@ -25,12 +25,38 @@ def mock_event_stream(tmp_path): return EventStream('test_session', file_store) -@pytest.fixture -def codeact_agent(mock_llm): +@pytest.fixture(params=[False, True]) +def codeact_agent(mock_llm, request): config = AgentConfig() + config.function_calling = request.param return CodeActAgent(mock_llm, config) +def response_mock(content: str): + class MockModelResponse: + def __init__(self, content): + self.choices = [ + { + 'message': { + 'content': content, + 'tool_calls': [ + { + 'function': { + 'name': 'execute_bash', + 'arguments': '{}', + } + } + ], + } + } + ] + + def model_dump(self): + return {'choices': self.choices} + + return MockModelResponse(content) + + def test_get_messages_with_reminder(codeact_agent, mock_event_stream): # Add some events to the stream mock_event_stream.add_event(MessageAction('Initial user message'), EventSource.USER) @@ -49,9 +75,13 @@ def test_get_messages_with_reminder(codeact_agent, mock_event_stream): ) # System, initial user + user message, agent message, last user message assert messages[0].content[0].cache_prompt # system message assert messages[1].role == 'user' - assert messages[1].content[0].text.endswith("LET'S START!") - assert messages[1].content[1].text.endswith('Initial user message') - assert messages[1].content[0].cache_prompt + if not codeact_agent.config.function_calling: + assert messages[1].content[0].text.endswith("LET'S START!") + assert messages[1].content[1].text.endswith('Initial user message') + else: + assert messages[1].content[0].text.endswith('Initial user message') + # we add cache breakpoint to the last 3 user messages + assert messages[1].content[-1].cache_prompt assert messages[3].role == 'user' assert messages[3].content[0].text == ('Hello, agent!') @@ -62,13 +92,14 @@ def test_get_messages_with_reminder(codeact_agent, mock_event_stream): assert messages[5].role == 'user' assert messages[5].content[0].text.startswith('Laaaaaaaast!') assert messages[5].content[0].cache_prompt - assert ( - messages[5] - .content[1] - .text.endswith( - 'ENVIRONMENT REMINDER: You have 5 turns left to complete the task. When finished reply with .' + if not codeact_agent.config.function_calling: + assert ( + messages[5] + .content[1] + .text.endswith( + 'ENVIRONMENT REMINDER: You have 5 turns left to complete the task. When finished reply with .' + ) ) - ) def test_get_messages_prompt_caching(codeact_agent, mock_event_stream): @@ -97,9 +128,10 @@ def test_get_messages_prompt_caching(codeact_agent, mock_event_stream): ) # Including the initial system+user + 2 last user message # Verify that these are indeed the last two user messages (from start) - assert ( - cached_user_messages[0].content[0].text.startswith('A chat between') - ) # system message + if not codeact_agent.config.function_calling: + assert ( + cached_user_messages[0].content[0].text.startswith('A chat between') + ) # system message assert cached_user_messages[2].content[0].text.startswith('User message 1') assert cached_user_messages[3].content[0].text.startswith('User message 1') @@ -144,14 +176,15 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream): # Assert the presence of key elements in the messages assert ( messages[1] - .content[1] + .content[-1] .text.startswith("Let's list the contents of the current directory.") ) # user, included in the initial message - assert any( - 'List files in current directory\n\nls -l\n' - in msg.content[0].text - for msg in messages - ) # agent + if not codeact_agent.config.function_calling: + assert any( + 'List files in current directory\n\nls -l\n' + in msg.content[0].text + for msg in messages + ) # agent assert any( 'total 0\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file1.txt\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file2.txt' in msg.content[0].text @@ -160,7 +193,8 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream): assert any( "Now, let's create a new directory." in msg.content[0].text for msg in messages ) # agent - assert messages[4].content[1].text.startswith('Create a new directory') # agent + if not codeact_agent.config.function_calling: + assert messages[4].content[1].text.startswith('Create a new directory') # agent assert any( 'finished with exit code 0' in msg.content[0].text for msg in messages ) # user, observation @@ -171,16 +205,20 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream): # prompt cache is added to the system message assert messages[0].content[0].cache_prompt # and the first initial user message - assert messages[1].content[0].cache_prompt + assert messages[1].content[-1].cache_prompt # and to the last two user messages assert messages[3].content[0].cache_prompt assert messages[5].content[0].cache_prompt # reminder is added to the last user message - assert 'ENVIRONMENT REMINDER: You have 5 turns' in messages[5].content[1].text + if not codeact_agent.config.function_calling: + assert 'ENVIRONMENT REMINDER: You have 5 turns' in messages[5].content[1].text def test_prompt_caching_headers(codeact_agent, mock_event_stream): + if codeact_agent.config.function_calling: + pytest.skip('Skipping this test for function calling') + # Setup mock_event_stream.add_event(MessageAction('Hello, agent!'), EventSource.USER) mock_event_stream.add_event(MessageAction('Hello, user!'), EventSource.AGENT) diff --git a/tests/unit/test_security.py b/tests/unit/test_security.py index 4f0aa4dc39..771ccc206d 100644 --- a/tests/unit/test_security.py +++ b/tests/unit/test_security.py @@ -221,8 +221,9 @@ def test_unsafe_bash_command(temp_dir: str): name=ActionType.RUN_IPYTHON, arguments={ 'code': "print('hello')", - 'kernel_init_code': '', + 'include_extra': True, 'confirmation_state': ActionConfirmationStatus.CONFIRMED, + 'kernel_init_code': '', }, ), ),