diff --git a/evaluation/integration_tests/run_infer.py b/evaluation/integration_tests/run_infer.py
index a530041f92..621a0fa91c 100644
--- a/evaluation/integration_tests/run_infer.py
+++ b/evaluation/integration_tests/run_infer.py
@@ -16,6 +16,7 @@ from evaluation.utils.shared import (
)
from openhands.controller.state.state import State
from openhands.core.config import (
+ AgentConfig,
AppConfig,
SandboxConfig,
get_llm_config_arg,
@@ -24,6 +25,7 @@ from openhands.core.config import (
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import MessageAction
+from openhands.events.serialization.event import event_to_dict
from openhands.runtime.base import Runtime
from openhands.utils.async_utils import call_async_from_sync
@@ -60,6 +62,12 @@ def get_config(
f'{metadata.llm_config.log_completions_folder}'
)
config.set_llm_config(metadata.llm_config)
+ agent_config = AgentConfig(
+ codeact_enable_jupyter=True,
+ codeact_enable_browsing_delegate=True,
+ codeact_enable_llm_editor=False,
+ )
+ config.set_agent_config(agent_config)
return config
@@ -122,7 +130,7 @@ def process_instance(
# # result evaluation
# # =============================================
- histories = state.history.get_events()
+ histories = [event_to_dict(event) for event in state.history.get_events()]
test_result: TestResult = test_class.verify_result(runtime, histories)
metrics = state.metrics.get() if state.metrics else None
diff --git a/evaluation/swe_bench/run_infer.py b/evaluation/swe_bench/run_infer.py
index 9ac1e0cf66..8b8b45a463 100644
--- a/evaluation/swe_bench/run_infer.py
+++ b/evaluation/swe_bench/run_infer.py
@@ -23,6 +23,7 @@ from evaluation.utils.shared import (
)
from openhands.controller.state.state import State
from openhands.core.config import (
+ AgentConfig,
AppConfig,
SandboxConfig,
get_llm_config_arg,
@@ -45,11 +46,6 @@ AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
'CodeActSWEAgent': codeact_user_response,
}
-AGENT_CLS_TO_INST_SUFFIX = {
- 'CodeActAgent': 'When you think you have fixed the issue through code changes, please run the following command: <execute_bash> exit </execute_bash>.\n',
- 'CodeActSWEAgent': 'When you think you have fixed the issue through code changes, please run the following command: <execute_bash> exit </execute_bash>.\n',
-}
-
def _get_swebench_workspace_dir_name(instance: pd.Series) -> str:
return f'{instance.repo}__{instance.version}'.replace('/', '__')
@@ -71,25 +67,27 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
)
instruction += CODEACT_SWE_PROMPT.format(workspace_dir_name=workspace_dir_name)
else:
- # Testing general agents
+ # Instruction based on Anthropic's official trajectory
+ # https://github.com/eschluntz/swe-bench-experiments/tree/main/evaluation/verified/20241022_tools_claude-3-5-sonnet-updated/trajs
instruction = (
- f'Please fix the following issue for the repository in /workspace/{workspace_dir_name}.\n'
- 'Environment has been set up for you to start working. You may assume all necessary tools are installed.\n\n'
- '# Problem Statement\n'
- f'{instance.problem_statement}\n\n'
+ '<uploaded_files>\n'
+ f'/workspace/{workspace_dir_name}\n'
+ '</uploaded_files>\n'
+ f"I've uploaded a python code repository in the directory {workspace_dir_name}. Consider the following PR description:\n\n"
+ f'<pr_description>\n'
+ f'{instance.problem_statement}\n'
+ '</pr_description>\n\n'
+ 'Can you help me implement the necessary changes to the repository so that the requirements specified in the <pr_description> are met?\n'
+ "I've already taken care of all changes to any of the test files described in the <pr_description>. This means you DON'T have to modify the testing logic or any of the tests in any way!\n"
+ 'Your task is to make the minimal changes to non-tests files in the /repo directory to ensure the <pr_description> is satisfied.\n'
+ 'Follow these steps to resolve the issue:\n'
+ '1. As a first step, it might be a good idea to explore the repo to familiarize yourself with its structure.\n'
+ '2. Create a script to reproduce the error and execute it with `python <filename.py>` using the BashTool, to confirm the error\n'
+ '3. Edit the sourcecode of the repo to resolve the issue\n'
+ '4. Rerun your reproduce script and confirm that the error is fixed!\n'
+ '5. Think about edgecases and make sure your fix handles them as well\n'
+ "Your thinking should be thorough and so it's fine if it's very long.\n"
)
- if USE_HINT_TEXT and instance.hints_text:
- instruction += f'# Hints\n{instance.hints_text}\n\n'
- instruction += (
- 'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
- 'You should NOT modify any existing test case files. You SHOULD add new test in a NEW file to reproduce the issue.\n'
- 'You should verify that the issue is resolved and any new tests you create pass successfully.\n'
- 'You should NEVER use web browsing or any other web-based tools.\n'
- 'You should ALWAYS use the default Python interpreter available in the environment to run code related to the provided issue and/or repository.\n'
- )
-
- # NOTE: You can actually set slightly different instruction for different agents
- instruction += AGENT_CLS_TO_INST_SUFFIX[metadata.agent_class]
return instruction
@@ -153,6 +151,12 @@ def get_config(
f'{metadata.llm_config.log_completions_folder}'
)
config.set_llm_config(metadata.llm_config)
+ agent_config = AgentConfig(
+ codeact_enable_jupyter=False,
+ codeact_enable_browsing_delegate=False,
+ codeact_enable_llm_editor=False,
+ )
+ config.set_agent_config(agent_config)
return config
@@ -312,7 +316,7 @@ def complete_runtime(
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(
- obs.exit_code == 0,
+ isinstance(obs, CmdOutputObservation) and obs.exit_code == 0,
f'Failed to cd to /workspace/{workspace_dir_name}: {str(obs)}',
)
@@ -322,7 +326,7 @@ def complete_runtime(
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(
- obs.exit_code == 0,
+ isinstance(obs, CmdOutputObservation) and obs.exit_code == 0,
f'Failed to git config --global core.pager "": {str(obs)}',
)
@@ -331,7 +335,10 @@ def complete_runtime(
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
- assert_and_raise(obs.exit_code == 0, f'Failed to git add -A: {str(obs)}')
+ assert_and_raise(
+ isinstance(obs, CmdOutputObservation) and obs.exit_code == 0,
+ f'Failed to git add -A: {str(obs)}',
+ )
n_retries = 0
git_patch = None
@@ -404,6 +411,7 @@ def process_instance(
if (
state.last_error
and 'fatal error during agent execution' in state.last_error
+ and 'stuck in a loop' not in state.last_error
):
raise EvalException('Fatal error detected: ' + state.last_error)
@@ -488,6 +496,7 @@ if __name__ == '__main__':
llm_config = None
if args.llm_config:
llm_config = get_llm_config_arg(args.llm_config)
+ llm_config.log_completions = True
if llm_config is None:
raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')
diff --git a/evaluation/swe_bench/scripts/eval/convert_oh_folder_to_swebench_submission.sh b/evaluation/swe_bench/scripts/eval/convert_oh_folder_to_swebench_submission.sh
new file mode 100755
index 0000000000..8bbaa6ddce
--- /dev/null
+++ b/evaluation/swe_bench/scripts/eval/convert_oh_folder_to_swebench_submission.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+
+FOLDER_PATH=$1
+NEW_FOLDER_PATH=${FOLDER_PATH}.swebench_submission
+mkdir -p $NEW_FOLDER_PATH
+
+# Build all_preds.jsonl
+poetry run python evaluation/swe_bench/scripts/eval/convert_oh_output_to_swe_json.py $FOLDER_PATH/output.jsonl
+mv $FOLDER_PATH/output.swebench.jsonl $NEW_FOLDER_PATH/all_preds.jsonl
+
+# Build trajs/
+mkdir -p $NEW_FOLDER_PATH/trajs
+for instance_dir in $FOLDER_PATH/llm_completions/*/; do
+ instance_id=$(basename "$instance_dir")
+ latest_json=$(ls -t "$instance_dir"/*.json | head -n1)
+ if [ -n "$latest_json" ]; then
+ cat "$latest_json" | jq -r '.messages' > "$NEW_FOLDER_PATH/trajs/$instance_id.json"
+ fi
+done
+
+# Build logs/
+# check if $FOLDER_PATH/eval_outputs exists, if so copy over - else raise error
+if [ -d "$FOLDER_PATH/eval_outputs" ]; then
+ cp -r $FOLDER_PATH/eval_outputs $NEW_FOLDER_PATH/logs
+else
+ echo "Error: $FOLDER_PATH/eval_outputs does not exist. You should run the local docker eval_infer.sh first."
+ exit 1
+fi
diff --git a/evaluation/utils/shared.py b/evaluation/utils/shared.py
index d33658f339..b8d2ad281a 100644
--- a/evaluation/utils/shared.py
+++ b/evaluation/utils/shared.py
@@ -104,7 +104,7 @@ def codeact_user_response(
)
msg = (
'Please continue working on the task on whatever approach you think is suitable.\n'
- 'If you think you have solved the task, please first send your answer to user through message and then <execute_bash> exit </execute_bash>.\n'
+ 'If you think you have solved the task, please first send your answer to user through message and then finish the interaction.\n'
f'{encaps_str}'
'IMPORTANT: YOU SHOULD NEVER ASK FOR HUMAN HELP.\n'
)
diff --git a/openhands/agenthub/codeact_agent/action_parser.py b/openhands/agenthub/codeact_agent/action_parser.py
index bf926a6739..893effd825 100644
--- a/openhands/agenthub/codeact_agent/action_parser.py
+++ b/openhands/agenthub/codeact_agent/action_parser.py
@@ -64,6 +64,23 @@ class CodeActResponseParser(ResponseParser):
return action_parser.parse(action_str)
return self.default_parser.parse(action_str)
+ def action_to_str(self, action: Action) -> str:
+ if isinstance(action, CmdRunAction):
+ return (
+ f'{action.thought}\n<execute_bash>\n{action.command}\n</execute_bash>'
+ )
+ elif isinstance(action, IPythonRunCellAction):
+ return f'{action.thought}\n\n{action.code}\n'
+ elif isinstance(action, AgentDelegateAction):
+ return f'{action.thought}\n\n{action.inputs["task"]}\n'
+ elif isinstance(action, FileEditAction):
+ return f'{action.thought}\n\n{action.content}\n'
+ elif isinstance(action, MessageAction):
+ return action.content
+ elif isinstance(action, AgentFinishAction) and action.source == 'agent':
+ return action.thought
+ return ''
+
class CodeActActionParserFinish(ActionParser):
"""Parser action:
diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py
index c8342ca11f..997d424c51 100644
--- a/openhands/agenthub/codeact_agent/codeact_agent.py
+++ b/openhands/agenthub/codeact_agent/codeact_agent.py
@@ -1,10 +1,16 @@
+import json
import os
+from collections import deque
from itertools import islice
+from litellm import ModelResponse
+
+import openhands.agenthub.codeact_agent.function_calling as codeact_function_calling
from openhands.agenthub.codeact_agent.action_parser import CodeActResponseParser
from openhands.controller.agent import Agent
from openhands.controller.state.state import State
from openhands.core.config import AgentConfig
+from openhands.core.logger import openhands_logger as logger
from openhands.core.message import ImageContent, Message, TextContent
from openhands.events.action import (
Action,
@@ -36,7 +42,7 @@ from openhands.utils.prompt import PromptManager
class CodeActAgent(Agent):
- VERSION = '2.0'
+ VERSION = '2.1'
"""
The Code Act Agent is a minimalist agent.
The agent works by passing the model a list of action-observation pairs and prompting the model to take the next step.
@@ -63,8 +69,7 @@ class CodeActAgent(Agent):
AgentSkillsRequirement(),
JupyterRequirement(),
]
-
- action_parser = CodeActResponseParser()
+ obs_prefix = 'OBSERVATION:\n'
def __init__(
self,
@@ -88,66 +93,165 @@ class CodeActAgent(Agent):
if config.micro_agent_name
else None
)
-
- self.prompt_manager = PromptManager(
- prompt_dir=os.path.join(os.path.dirname(__file__)),
- agent_skills_docs=AgentSkillsRequirement.documentation,
- micro_agent=self.micro_agent,
- )
-
- def action_to_str(self, action: Action) -> str:
- if isinstance(action, CmdRunAction):
- return (
- f'{action.thought}\n<execute_bash>\n{action.command}\n</execute_bash>'
+ if (
+ self.config.function_calling
+ and not self.llm.config.supports_function_calling
+ ):
+ logger.warning(
+ f'Function calling not supported for model {self.llm.config.model}. '
+ 'Disabling function calling.'
)
- elif isinstance(action, IPythonRunCellAction):
- return f'{action.thought}\n<execute_ipython>\n{action.code}\n</execute_ipython>'
- elif isinstance(action, AgentDelegateAction):
- return f'{action.thought}\n<execute_browse>\n{action.inputs["task"]}\n</execute_browse>'
- elif isinstance(action, FileEditAction):
- return f'{action.thought}\n<file_edit path={action.path}>\n{action.content}\n</file_edit>'
- elif isinstance(action, MessageAction):
- return action.content
- elif isinstance(action, AgentFinishAction) and action.source == 'agent':
- return action.thought
- return ''
+ self.config.function_calling = False
- def get_action_message(self, action: Action) -> Message | None:
+ if self.config.function_calling:
+ # Function calling mode
+ self.tools = codeact_function_calling.get_tools(
+ codeact_enable_browsing_delegate=self.config.codeact_enable_browsing_delegate,
+ codeact_enable_jupyter=self.config.codeact_enable_jupyter,
+ codeact_enable_llm_editor=self.config.codeact_enable_llm_editor,
+ )
+ logger.info(
+ f'TOOLS loaded for CodeActAgent: {json.dumps(self.tools, indent=2)}'
+ )
+ self.system_prompt = codeact_function_calling.SYSTEM_PROMPT
+ self.initial_user_message = None
+ else:
+ # Non-function-calling mode
+ self.action_parser = CodeActResponseParser()
+ self.prompt_manager = PromptManager(
+ prompt_dir=os.path.join(os.path.dirname(__file__)),
+ agent_skills_docs=AgentSkillsRequirement.documentation,
+ micro_agent=self.micro_agent,
+ )
+ self.system_prompt = self.prompt_manager.system_message
+ self.initial_user_message = self.prompt_manager.initial_user_message
+
+ self.pending_actions: deque[Action] = deque()
+
+ def get_action_message(
+ self,
+ action: Action,
+ pending_tool_call_action_messages: dict[str, Message],
+ ) -> list[Message]:
+ """Converts an action into a message format that can be sent to the LLM.
+
+ This method handles different types of actions and formats them appropriately:
+ 1. For tool-based actions (AgentDelegate, CmdRun, IPythonRunCell, FileEdit) and agent-sourced AgentFinish:
+ - In function calling mode: Stores the LLM's response in pending_tool_call_action_messages
+ - In non-function calling mode: Creates a message with the action string
+ 2. For MessageActions: Creates a message with the text content and optional image content
+
+ Args:
+ action (Action): The action to convert. Can be one of:
+ - AgentDelegateAction: For delegating tasks to other agents
+ - CmdRunAction: For executing bash commands
+ - IPythonRunCellAction: For running IPython code
+ - FileEditAction: For editing files
+ - AgentFinishAction: For ending the interaction
+ - MessageAction: For sending messages
+ pending_tool_call_action_messages (dict[str, Message]): Dictionary mapping response IDs
+ to their corresponding messages. Used in function calling mode to track tool calls
+ that are waiting for their results.
+
+ Returns:
+ list[Message]: A list containing the formatted message(s) for the action.
+ May be empty if the action is handled as a tool call in function calling mode.
+
+ Note:
+ In function calling mode, tool-based actions are stored in pending_tool_call_action_messages
+ rather than being returned immediately. They will be processed later when all corresponding
+ tool call results are available.
+ """
+ # create a regular message from an event
if isinstance(
action,
(
AgentDelegateAction,
CmdRunAction,
IPythonRunCellAction,
- MessageAction,
FileEditAction,
),
) or (isinstance(action, AgentFinishAction) and action.source == 'agent'):
- content = [TextContent(text=self.action_to_str(action))]
+ if self.config.function_calling:
+ tool_metadata = action.tool_call_metadata
+ assert tool_metadata is not None, (
+ 'Tool call metadata should NOT be None when function calling is enabled. Action: '
+ + str(action)
+ )
- if (
- self.llm.vision_is_active()
- and isinstance(action, MessageAction)
- and action.images_urls
- ):
+ llm_response: ModelResponse = tool_metadata.model_response
+ assistant_msg = llm_response.choices[0].message
+ # Add the LLM message (assistant) that initiated the tool calls
+ # (overwrites any previous message with the same response_id)
+ pending_tool_call_action_messages[llm_response.id] = Message(
+ role=assistant_msg.role,
+ # tool call content SHOULD BE a string
+ content=[TextContent(text=assistant_msg.content)]
+ if assistant_msg.content is not None
+ else [],
+ tool_calls=assistant_msg.tool_calls,
+ )
+ return []
+ else:
+ content = [TextContent(text=self.action_parser.action_to_str(action))]
+ return [
+ Message(
+ role='user' if action.source == 'user' else 'assistant',
+ content=content,
+ )
+ ]
+ elif isinstance(action, MessageAction):
+ role = 'user' if action.source == 'user' else 'assistant'
+ content = [TextContent(text=action.content)]
+ if self.llm.vision_is_active() and action.images_urls:
content.append(ImageContent(image_urls=action.images_urls))
+ return [
+ Message(
+ role=role,
+ content=content,
+ )
+ ]
+ return []
- return Message(
- role='user' if action.source == 'user' else 'assistant', content=content
- )
- return None
+ def get_observation_message(
+ self,
+ obs: Observation,
+ tool_call_id_to_message: dict[str, Message],
+ ) -> list[Message]:
+ """Converts an observation into a message format that can be sent to the LLM.
- def get_observation_message(self, obs: Observation) -> Message | None:
+ This method handles different types of observations and formats them appropriately:
+ - CmdOutputObservation: Formats command execution results with exit codes
+ - IPythonRunCellObservation: Formats IPython cell execution results, replacing base64 images
+ - FileEditObservation: Formats file editing results
+ - AgentDelegateObservation: Formats results from delegated agent tasks
+ - ErrorObservation: Formats error messages from failed actions
+ - UserRejectObservation: Formats user rejection messages
+
+ In function calling mode, observations with tool_call_metadata are stored in
+ tool_call_id_to_message for later processing instead of being returned immediately.
+
+ Args:
+ obs (Observation): The observation to convert
+ tool_call_id_to_message (dict[str, Message]): Dictionary mapping tool call IDs
+ to their corresponding messages (used in function calling mode)
+
+ Returns:
+ list[Message]: A list containing the formatted message(s) for the observation.
+ May be empty if the observation is handled as a tool response in function calling mode.
+
+ Raises:
+ ValueError: If the observation type is unknown
+ """
+ message: Message
max_message_chars = self.llm.config.max_message_chars
obs_prefix = 'OBSERVATION:\n'
if isinstance(obs, CmdOutputObservation):
text = obs_prefix + truncate_content(
obs.content + obs.interpreter_details, max_message_chars
)
- text += (
- f'\n[Command {obs.command_id} finished with exit code {obs.exit_code}]'
- )
- return Message(role='user', content=[TextContent(text=text)])
+ text += f'\n[Command finished with exit code {obs.exit_code}]'
+ message = Message(role='user', content=[TextContent(text=text)])
elif isinstance(obs, IPythonRunCellObservation):
text = obs_prefix + obs.content
# replace base64 images with a placeholder
@@ -159,29 +263,45 @@ class CodeActAgent(Agent):
)
text = '\n'.join(splitted)
text = truncate_content(text, max_message_chars)
- return Message(role='user', content=[TextContent(text=text)])
+ message = Message(role='user', content=[TextContent(text=text)])
elif isinstance(obs, FileEditObservation):
text = obs_prefix + truncate_content(str(obs), max_message_chars)
- return Message(role='user', content=[TextContent(text=text)])
+ message = Message(role='user', content=[TextContent(text=text)])
elif isinstance(obs, AgentDelegateObservation):
text = obs_prefix + truncate_content(
obs.outputs['content'] if 'content' in obs.outputs else '',
max_message_chars,
)
- return Message(role='user', content=[TextContent(text=text)])
+ message = Message(role='user', content=[TextContent(text=text)])
elif isinstance(obs, ErrorObservation):
text = obs_prefix + truncate_content(obs.content, max_message_chars)
text += '\n[Error occurred in processing last action]'
- return Message(role='user', content=[TextContent(text=text)])
+ message = Message(role='user', content=[TextContent(text=text)])
elif isinstance(obs, UserRejectObservation):
text = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars)
text += '\n[Last action has been rejected by the user]'
- return Message(role='user', content=[TextContent(text=text)])
+ message = Message(role='user', content=[TextContent(text=text)])
else:
# If an observation message is not returned, it will cause an error
# when the LLM tries to return the next message
raise ValueError(f'Unknown observation type: {type(obs)}')
+ if self.config.function_calling:
+ # Update the message as tool response properly
+ if (tool_call_metadata := obs.tool_call_metadata) is not None:
+ tool_call_id_to_message[tool_call_metadata.tool_call_id] = Message(
+ role='tool',
+ content=message.content,
+ tool_call_id=tool_call_metadata.tool_call_id,
+ name=tool_call_metadata.function_name,
+ )
+ # No need to return the observation message
+ # because it will be added by get_action_message when all the corresponding
+ # tool calls in the SAME request are processed
+ return []
+
+ return [message]
+
def reset(self) -> None:
"""Resets the CodeAct Agent."""
super().reset()
@@ -200,6 +320,10 @@ class CodeActAgent(Agent):
- MessageAction(content) - Message action to run (e.g. ask for clarification)
- AgentFinishAction() - end the interaction
"""
+ # Continue with pending actions if any
+ if self.pending_actions:
+ return self.pending_actions.popleft()
+
# if we're done, go back
latest_user_message = state.history.get_last_user_message()
if latest_user_message and latest_user_message.strip() == '/exit':
@@ -207,87 +331,172 @@ class CodeActAgent(Agent):
# prepare what we want to send to the LLM
messages = self._get_messages(state)
- params = {
+ params: dict = {
'messages': self.llm.format_messages_for_llm(messages),
- 'stop': [
+ }
+ if self.config.function_calling:
+ params['tools'] = self.tools
+ else:
+ params['stop'] = [
'</execute_ipython>',
'</execute_bash>',
'</execute_browse>',
'</file_edit>',
- ],
- }
-
+ ]
response = self.llm.completion(**params)
- return self.action_parser.parse(response)
+ if self.config.function_calling:
+ actions = codeact_function_calling.response_to_actions(response)
+ for action in actions:
+ self.pending_actions.append(action)
+ return self.pending_actions.popleft()
+ else:
+ return self.action_parser.parse(response)
def _get_messages(self, state: State) -> list[Message]:
+ """Constructs the message history for the LLM conversation.
+
+ This method builds a structured conversation history by processing events from the state
+ and formatting them into messages that the LLM can understand. It handles both regular
+ message flow and function-calling scenarios.
+
+ The method performs the following steps:
+ 1. Initializes with system prompt and optional initial user message
+ 2. Processes events (Actions and Observations) into messages
+ 3. Handles tool calls and their responses in function-calling mode
+ 4. Manages message role alternation (user/assistant/tool)
+ 5. Applies caching for specific LLM providers (e.g., Anthropic)
+ 6. Adds environment reminders for non-function-calling mode
+
+ Args:
+ state (State): The current state object containing conversation history and other metadata
+
+ Returns:
+ list[Message]: A list of formatted messages ready for LLM consumption, including:
+ - System message with prompt
+ - Initial user message (if configured)
+ - Action messages (from both user and assistant)
+ - Observation messages (including tool responses)
+ - Environment reminders (in non-function-calling mode)
+
+ Note:
+ - In function-calling mode, tool calls and their responses are carefully tracked
+ to maintain proper conversation flow
+ - Messages from the same role are combined to prevent consecutive same-role messages
+ - For Anthropic models, specific messages are cached according to their documentation
+ """
messages: list[Message] = [
Message(
role='system',
content=[
TextContent(
- text=self.prompt_manager.system_message,
+ text=self.system_prompt,
cache_prompt=self.llm.is_caching_prompt_active(), # Cache system prompt
)
],
- ),
- Message(
- role='user',
- content=[
- TextContent(
- text=self.prompt_manager.initial_user_message,
- cache_prompt=self.llm.is_caching_prompt_active(), # if the user asks the same query,
- )
- ],
- ),
+ )
]
+ if self.initial_user_message:
+ messages.append(
+ Message(
+ role='user',
+ content=[TextContent(text=self.initial_user_message)],
+ )
+ )
- for event in state.history.get_events():
+ pending_tool_call_action_messages: dict[str, Message] = {}
+ tool_call_id_to_message: dict[str, Message] = {}
+ events = list(state.history.get_events())
+ for event in events:
# create a regular message from an event
if isinstance(event, Action):
- message = self.get_action_message(event)
+ messages_to_add = self.get_action_message(
+ action=event,
+ pending_tool_call_action_messages=pending_tool_call_action_messages,
+ )
elif isinstance(event, Observation):
- message = self.get_observation_message(event)
+ messages_to_add = self.get_observation_message(
+ obs=event,
+ tool_call_id_to_message=tool_call_id_to_message,
+ )
else:
raise ValueError(f'Unknown event type: {type(event)}')
- # add regular message
- if message:
- # handle error if the message is the SAME role as the previous message
- # litellm.exceptions.BadRequestError: litellm.BadRequestError: OpenAIException - Error code: 400 - {'detail': 'Only supports u/a/u/a/u...'}
- # there shouldn't be two consecutive messages from the same role
- if messages and messages[-1].role == message.role:
- messages[-1].content.extend(message.content)
- else:
- messages.append(message)
+ # Check pending tool call action messages and see if they are complete
+ _response_ids_to_remove = []
+ for (
+ response_id,
+ pending_message,
+ ) in pending_tool_call_action_messages.items():
+ assert pending_message.tool_calls is not None, (
+ 'Tool calls should NOT be None when function calling is enabled & the message is considered pending tool call. '
+ f'Pending message: {pending_message}'
+ )
+ if all(
+ tool_call.id in tool_call_id_to_message
+ for tool_call in pending_message.tool_calls
+ ):
+ # If complete:
+ # -- 1. Add the message that **initiated** the tool calls
+ messages_to_add.append(pending_message)
+ # -- 2. Add the tool calls **results***
+ for tool_call in pending_message.tool_calls:
+ messages_to_add.append(tool_call_id_to_message[tool_call.id])
+ tool_call_id_to_message.pop(tool_call.id)
+ _response_ids_to_remove.append(response_id)
+ # Cleanup the processed pending tool messages
+ for response_id in _response_ids_to_remove:
+ pending_tool_call_action_messages.pop(response_id)
+
+ for message in messages_to_add:
+ # add regular message
+ if message:
+ # handle error if the message is the SAME role as the previous message
+ # litellm.exceptions.BadRequestError: litellm.BadRequestError: OpenAIException - Error code: 400 - {'detail': 'Only supports u/a/u/a/u...'}
+ # there shouldn't be two consecutive messages from the same role
+ # NOTE: we shouldn't combine tool messages because each of them has a different tool_call_id
+ if (
+ messages
+ and messages[-1].role == message.role
+ and message.role != 'tool'
+ ):
+ messages[-1].content.extend(message.content)
+ else:
+ messages.append(message)
- # Add caching to the last 2 user messages
if self.llm.is_caching_prompt_active():
- user_turns_processed = 0
+ # NOTE: this is only needed for anthropic
+ # following logic here:
+ # https://github.com/anthropics/anthropic-quickstarts/blob/8f734fd08c425c6ec91ddd613af04ff87d70c5a0/computer-use-demo/computer_use_demo/loop.py#L241-L262
+ breakpoints_remaining = 3 # remaining 1 for system/tool
for message in reversed(messages):
- if message.role == 'user' and user_turns_processed < 2:
- message.content[
- -1
- ].cache_prompt = True # Last item inside the message content
- user_turns_processed += 1
+ if message.role == 'user' or message.role == 'tool':
+ if breakpoints_remaining > 0:
+ message.content[
+ -1
+ ].cache_prompt = True # Last item inside the message content
+ breakpoints_remaining -= 1
+ else:
+ break
- # The latest user message is important:
- # we want to remind the agent of the environment constraints
- latest_user_message = next(
- islice(
- (
- m
- for m in reversed(messages)
- if m.role == 'user'
- and any(isinstance(c, TextContent) for c in m.content)
+ if not self.config.function_calling:
+ # The latest user message is important:
+ # we want to remind the agent of the environment constraints
+ latest_user_message = next(
+ islice(
+ (
+ m
+ for m in reversed(messages)
+ if m.role == 'user'
+ and any(isinstance(c, TextContent) for c in m.content)
+ ),
+ 1,
),
- 1,
- ),
- None,
- )
- if latest_user_message:
- reminder_text = f'\n\nENVIRONMENT REMINDER: You have {state.max_iterations - state.iteration} turns left to complete the task. When finished reply with <finish></finish>.'
- latest_user_message.content.append(TextContent(text=reminder_text))
+ None,
+ )
+ # do not add this for function calling
+ if latest_user_message:
+ reminder_text = f'\n\nENVIRONMENT REMINDER: You have {state.max_iterations - state.iteration} turns left to complete the task. When finished reply with <finish></finish>.'
+ latest_user_message.content.append(TextContent(text=reminder_text))
return messages
diff --git a/openhands/agenthub/codeact_agent/function_calling.py b/openhands/agenthub/codeact_agent/function_calling.py
new file mode 100644
index 0000000000..f5519124ac
--- /dev/null
+++ b/openhands/agenthub/codeact_agent/function_calling.py
@@ -0,0 +1,397 @@
+"""This file contains the function calling implementation for different actions.
+
+This is similar to the functionality of `CodeActResponseParser`.
+"""
+
+import json
+
+from litellm import (
+ ChatCompletionToolParam,
+ ChatCompletionToolParamFunctionChunk,
+ ModelResponse,
+)
+
+from openhands.core.logger import openhands_logger as logger
+from openhands.events.action import (
+ Action,
+ AgentDelegateAction,
+ AgentFinishAction,
+ CmdRunAction,
+ FileEditAction,
+ IPythonRunCellAction,
+ MessageAction,
+)
+from openhands.events.tool import ToolCallMetadata
+
+SYSTEM_PROMPT = """You are a helpful assistant that can interact with a computer to solve tasks.
+<IMPORTANT>
+* If user provides a path, you should NOT assume it's relative to the current working directory. Instead, you should explore the file system to find the file before working on it.
+</IMPORTANT>
+"""
+
+_BASH_DESCRIPTION = """Execute a bash command in the terminal.
+* Long running commands: For commands that may run indefinitely, it should be run in the background and the output should be redirected to a file, e.g. command = `python3 app.py > server.log 2>&1 &`.
+* Interactive: If a bash command returns exit code `-1`, this means the process is not yet finished. The assistant must then send a second call to terminal with an empty `command` (which will retrieve any additional logs), or it can send additional text (set `command` to the text) to STDIN of the running process, or it can send command=`ctrl+c` to interrupt the process.
+* Timeout: If a command execution result says "Command timed out. Sending SIGINT to the process", the assistant should retry running the command in the background.
+"""
+
+CmdRunTool = ChatCompletionToolParam(
+ type='function',
+ function=ChatCompletionToolParamFunctionChunk(
+ name='execute_bash',
+ description=_BASH_DESCRIPTION,
+ parameters={
+ 'type': 'object',
+ 'properties': {
+ 'command': {
+ 'type': 'string',
+ 'description': 'The bash command to execute. Can be empty to view additional logs when previous exit code is `-1`. Can be `ctrl+c` to interrupt the currently running process.',
+ },
+ },
+ 'required': ['command'],
+ },
+ ),
+)
+
+_IPYTHON_DESCRIPTION = """Run a cell of Python code in an IPython environment.
+* The assistant should define variables and import packages before using them.
+* The variable defined in the IPython environment will not be available outside the IPython environment (e.g., in terminal).
+"""
+# We are not using agentskills's file_ops for viewing files now because StrReplaceEditorTool already supports viewing files
+# """* Apart from the standard Python library, the assistant can also use the following functions (already imported):
+# {AgentSkillsRequirement.documentation}"""
+
+IPythonTool = ChatCompletionToolParam(
+ type='function',
+ function=ChatCompletionToolParamFunctionChunk(
+ name='execute_ipython_cell',
+ description=_IPYTHON_DESCRIPTION,
+ parameters={
+ 'type': 'object',
+ 'properties': {
+ 'code': {
+ 'type': 'string',
+ 'description': 'The Python code to execute. Supports magic commands like %pip.',
+ },
+ },
+ 'required': ['code'],
+ },
+ ),
+)
+
+_FILE_EDIT_DESCRIPTION = """Edit a file.
+* The assistant can edit files by specifying the file path and providing a draft of the new file content.
+* The draft content doesn't need to be exactly the same as the existing file; the assistant may skip unchanged lines using comments like `# unchanged` to indicate unchanged sections.
+* IMPORTANT: For large files (e.g., > 300 lines), specify the range of lines to edit using `start` and `end` (1-indexed, inclusive). The range should be smaller than 300 lines.
+* To append to a file, set both `start` and `end` to `-1`.
+* If the file doesn't exist, a new file will be created with the provided content.
+
+**Example 1: general edit for short files**
+For example, given an existing file `/path/to/file.py` that looks like this:
+(this is the end of the file)
+1|class MyClass:
+2| def __init__(self):
+3| self.x = 1
+4| self.y = 2
+5| self.z = 3
+6|
+7|print(MyClass().z)
+8|print(MyClass().x)
+(this is the end of the file)
+
+The assistant wants to edit the file to look like this:
+(this is the end of the file)
+1|class MyClass:
+2| def __init__(self):
+3| self.x = 1
+4| self.y = 2
+5|
+6|print(MyClass().y)
+(this is the end of the file)
+
+The assistant may produce an edit action like this:
+path="/path/to/file.txt" start=1 end=-1
+content=```
+class MyClass:
+ def __init__(self):
+ # no changes before
+ self.y = 2
+ # self.z is removed
+
+# MyClass().z is removed
+print(MyClass().y)
+```
+
+**Example 2: append to file for short files**
+For example, given an existing file `/path/to/file.py` that looks like this:
+(this is the end of the file)
+1|class MyClass:
+2| def __init__(self):
+3| self.x = 1
+4| self.y = 2
+5| self.z = 3
+6|
+7|print(MyClass().z)
+8|print(MyClass().x)
+(this is the end of the file)
+
+To append the following lines to the file:
+```python
+print(MyClass().y)
+```
+
+The assistant may produce an edit action like this:
+path="/path/to/file.txt" start=-1 end=-1
+content=```
+print(MyClass().y)
+```
+
+**Example 3: edit for long files**
+
+Given an existing file `/path/to/file.py` that looks like this:
+(1000 more lines above)
+1001|class MyClass:
+1002| def __init__(self):
+1003| self.x = 1
+1004| self.y = 2
+1005| self.z = 3
+1006|
+1007|print(MyClass().z)
+1008|print(MyClass().x)
+(2000 more lines below)
+
+The assistant wants to edit the file to look like this:
+
+(1000 more lines above)
+1001|class MyClass:
+1002| def __init__(self):
+1003| self.x = 1
+1004| self.y = 2
+1005|
+1006|print(MyClass().y)
+(2000 more lines below)
+
+The assistant may produce an edit action like this:
+path="/path/to/file.txt" start=1001 end=1008
+content=```
+class MyClass:
+ def __init__(self):
+ # no changes before
+ self.y = 2
+ # self.z is removed
+
+# MyClass().z is removed
+print(MyClass().y)
+```
+"""
+
+LLMBasedFileEditTool = ChatCompletionToolParam(
+ type='function',
+ function=ChatCompletionToolParamFunctionChunk(
+ name='edit_file',
+ description=_FILE_EDIT_DESCRIPTION,
+ parameters={
+ 'type': 'object',
+ 'properties': {
+ 'path': {
+ 'type': 'string',
+ 'description': 'The absolute path to the file to be edited.',
+ },
+ 'new_content_draft': {
+ 'type': 'string',
+ 'description': 'A draft of the new content for the file being edited. Note that the assistant may skip unchanged lines.',
+ },
+ 'start': {
+ 'type': 'integer',
+ 'description': 'The starting line number for the edit (1-indexed, inclusive). Default is 1.',
+ },
+ 'end': {
+ 'type': 'integer',
+ 'description': 'The ending line number for the edit (1-indexed, inclusive). Default is -1 (end of file).',
+ },
+ },
+            'required': ['path', 'new_content_draft'],
+ },
+ ),
+)
+
+_STR_REPLACE_EDITOR_DESCRIPTION = """Custom editing tool for viewing, creating and editing files
+* State is persistent across command calls and discussions with the user
+* If `path` is a file, `view` displays the result of applying `cat -n`. If `path` is a directory, `view` lists non-hidden files and directories up to 2 levels deep
+* The `create` command cannot be used if the specified `path` already exists as a file
+* If a `command` generates a long output, it will be truncated and marked with `<response clipped>`
+* The `undo_edit` command will revert the last edit made to the file at `path`
+
+Notes for using the `str_replace` command:
+* The `old_str` parameter should match EXACTLY one or more consecutive lines from the original file. Be mindful of whitespaces!
+* If the `old_str` parameter is not unique in the file, the replacement will not be performed. Make sure to include enough context in `old_str` to make it unique
+* The `new_str` parameter should contain the edited lines that should replace the `old_str`
+"""
+
+StrReplaceEditorTool = ChatCompletionToolParam(
+ type='function',
+ function=ChatCompletionToolParamFunctionChunk(
+ name='str_replace_editor',
+ description=_STR_REPLACE_EDITOR_DESCRIPTION,
+ parameters={
+ 'type': 'object',
+ 'properties': {
+ 'command': {
+ 'description': 'The commands to run. Allowed options are: `view`, `create`, `str_replace`, `insert`, `undo_edit`.',
+ 'enum': ['view', 'create', 'str_replace', 'insert', 'undo_edit'],
+ 'type': 'string',
+ },
+ 'path': {
+ 'description': 'Absolute path to file or directory, e.g. `/repo/file.py` or `/repo`.',
+ 'type': 'string',
+ },
+ 'file_text': {
+ 'description': 'Required parameter of `create` command, with the content of the file to be created.',
+ 'type': 'string',
+ },
+ 'old_str': {
+ 'description': 'Required parameter of `str_replace` command containing the string in `path` to replace.',
+ 'type': 'string',
+ },
+ 'new_str': {
+ 'description': 'Optional parameter of `str_replace` command containing the new string (if not given, no string will be added). Required parameter of `insert` command containing the string to insert.',
+ 'type': 'string',
+ },
+ 'insert_line': {
+ 'description': 'Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`.',
+ 'type': 'integer',
+ },
+ 'view_range': {
+ 'description': 'Optional parameter of `view` command when `path` points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting `[start_line, -1]` shows all lines from `start_line` to the end of the file.',
+ 'items': {'type': 'integer'},
+ 'type': 'array',
+ },
+ },
+ 'required': ['command', 'path'],
+ },
+ ),
+)
+
+_BROWSER_DELEGATION = """Delegate the task to another browsing agent.
+The assistant should delegate the task if it needs to browse the Internet.
+"""
+
+BrowserDelegationTool = ChatCompletionToolParam(
+ type='function',
+ function=ChatCompletionToolParamFunctionChunk(
+ name='delegate_to_browsing_agent',
+ description=_BROWSER_DELEGATION,
+ parameters={
+ 'type': 'object',
+ 'properties': {
+ 'task': {
+ 'type': 'string',
+ 'description': 'The task for the browsing agent to execute. It should include all the necessary context and specify what information the browsing agent should return.',
+ },
+ },
+ 'required': ['task'],
+ },
+ ),
+)
+
+_FINISH_DESCRIPTION = """Finish the interaction when the task is complete OR if the assistant cannot proceed further with the task."""
+
+FinishTool = ChatCompletionToolParam(
+ type='function',
+ function=ChatCompletionToolParamFunctionChunk(
+ name='finish',
+ description=_FINISH_DESCRIPTION,
+ ),
+)
+
+
+def combine_thought(action: Action, thought: str) -> Action:
+ if not hasattr(action, 'thought'):
+ return action
+ if thought:
+ action.thought = thought
+ return action
+
+
+def response_to_actions(response: ModelResponse) -> list[Action]:
+ actions: list[Action] = []
+ assert len(response.choices) == 1, 'Only one choice is supported for now'
+ assistant_msg = response.choices[0].message
+ if assistant_msg.tool_calls:
+ # Check if there's assistant_msg.content. If so, add it to the thought
+ thought = ''
+ if isinstance(assistant_msg.content, str):
+ thought = assistant_msg.content
+ elif isinstance(assistant_msg.content, list):
+ for msg in assistant_msg.content:
+ if msg['type'] == 'text':
+ thought += msg['text']
+
+ # Process each tool call to OpenHands action
+ for i, tool_call in enumerate(assistant_msg.tool_calls):
+ action: Action
+ try:
+ arguments = json.loads(tool_call.function.arguments)
+ except json.decoder.JSONDecodeError as e:
+ raise RuntimeError(
+ f'Failed to parse tool call arguments: {tool_call.function.arguments}'
+ ) from e
+ if tool_call.function.name == 'execute_bash':
+ action = CmdRunAction(**arguments)
+ elif tool_call.function.name == 'execute_ipython_cell':
+ action = IPythonRunCellAction(**arguments)
+ elif tool_call.function.name == 'delegate_to_browsing_agent':
+ action = AgentDelegateAction(
+ agent='BrowsingAgent',
+ inputs=arguments,
+ )
+ elif tool_call.function.name == 'finish':
+ action = AgentFinishAction()
+ elif tool_call.function.name == 'edit_file':
+ action = FileEditAction(**arguments)
+ elif tool_call.function.name == 'str_replace_editor':
+ # We implement this in agent_skills, which can be used via Jupyter
+ # convert tool_call.function.arguments to kwargs that can be passed to file_editor
+ code = f'print(file_editor(**{arguments}))'
+ logger.debug(
+ f'TOOL CALL: str_replace_editor -> file_editor with code: {code}'
+ )
+ action = IPythonRunCellAction(code=code, include_extra=False)
+ else:
+ raise RuntimeError(f'Unknown tool call: {tool_call.function.name}')
+
+ # We only add thought to the first action
+ if i == 0:
+ action = combine_thought(action, thought)
+ # Add metadata for tool calling
+ action.tool_call_metadata = ToolCallMetadata(
+ tool_call_id=tool_call.id,
+ function_name=tool_call.function.name,
+ model_response=response,
+ total_calls_in_response=len(assistant_msg.tool_calls),
+ )
+ actions.append(action)
+ else:
+ actions.append(
+ MessageAction(content=assistant_msg.content, wait_for_response=True)
+ )
+
+ assert len(actions) >= 1
+ return actions
+
+
+def get_tools(
+ codeact_enable_browsing_delegate: bool = False,
+ codeact_enable_llm_editor: bool = False,
+ codeact_enable_jupyter: bool = False,
+) -> list[ChatCompletionToolParam]:
+ tools = [CmdRunTool, FinishTool]
+ if codeact_enable_browsing_delegate:
+ tools.append(BrowserDelegationTool)
+ if codeact_enable_jupyter:
+ tools.append(IPythonTool)
+ if codeact_enable_llm_editor:
+ tools.append(LLMBasedFileEditTool)
+ else:
+ tools.append(StrReplaceEditorTool)
+ return tools
diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py
index 0660cd0309..7b7522e33a 100644
--- a/openhands/controller/agent_controller.py
+++ b/openhands/controller/agent_controller.py
@@ -453,6 +453,12 @@ class AgentController:
# and send the underlying exception to the LLM for self-correction
await self.report_error(str(e))
return
+ # FIXME: more graceful handling of litellm.exceptions.ContextWindowExceededError
+ # e.g. try to condense the memory and try again
+ except litellm.exceptions.ContextWindowExceededError as e:
+ self.state.last_error = str(e)
+ await self.set_agent_state_to(AgentState.ERROR)
+ return
if action.runnable:
if self.state.confirmation_mode and (
diff --git a/openhands/core/config/agent_config.py b/openhands/core/config/agent_config.py
index 839d09277e..db6225be8c 100644
--- a/openhands/core/config/agent_config.py
+++ b/openhands/core/config/agent_config.py
@@ -8,12 +8,20 @@ class AgentConfig:
"""Configuration for the agent.
Attributes:
+ function_calling: Whether function calling is enabled. Default is True.
+ codeact_enable_browsing_delegate: Whether browsing delegate is enabled in the action space. Default is False. Only works with function calling.
+ codeact_enable_llm_editor: Whether LLM editor is enabled in the action space. Default is False. Only works with function calling.
+ codeact_enable_jupyter: Whether Jupyter is enabled in the action space. Default is False.
micro_agent_name: The name of the micro agent to use for this agent.
memory_enabled: Whether long-term memory (embeddings) is enabled.
memory_max_threads: The maximum number of threads indexing at the same time for embeddings.
llm_config: The name of the llm config to use. If specified, this will override global llm config.
"""
+ function_calling: bool = True
+ codeact_enable_browsing_delegate: bool = False
+ codeact_enable_llm_editor: bool = False
+ codeact_enable_jupyter: bool = False
micro_agent_name: str | None = None
memory_enabled: bool = False
memory_max_threads: int = 3
diff --git a/openhands/core/config/llm_config.py b/openhands/core/config/llm_config.py
index ac07b70e0b..7ad6476d7c 100644
--- a/openhands/core/config/llm_config.py
+++ b/openhands/core/config/llm_config.py
@@ -42,6 +42,7 @@ class LLMConfig:
log_completions: Whether to log LLM completions to the state.
log_completions_folder: The folder to log LLM completions to. Required if log_completions is True.
draft_editor: A more efficient LLM to use for file editing. Introduced in [PR 3985](https://github.com/All-Hands-AI/OpenHands/pull/3985).
+ supports_function_calling: Whether the model supports function calling.
"""
model: str = 'gpt-4o'
@@ -61,7 +62,7 @@ class LLMConfig:
retry_min_wait: int = 15
retry_max_wait: int = 120
timeout: int | None = None
- max_message_chars: int = 10_000 # maximum number of characters in an observation's content when sent to the llm
+ max_message_chars: int = 30_000 # maximum number of characters in an observation's content when sent to the llm
temperature: float = 0.0
top_p: float = 1.0
custom_llm_provider: str | None = None
@@ -76,6 +77,7 @@ class LLMConfig:
log_completions: bool = False
log_completions_folder: str | None = None
draft_editor: Optional['LLMConfig'] = None
+ supports_function_calling: bool = False
def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
diff --git a/openhands/core/logger.py b/openhands/core/logger.py
index 13b91e451e..af3ba1e35f 100644
--- a/openhands/core/logger.py
+++ b/openhands/core/logger.py
@@ -12,6 +12,9 @@ LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO').upper()
DEBUG = os.getenv('DEBUG', 'False').lower() in ['true', '1', 'yes']
if DEBUG:
LOG_LEVEL = 'DEBUG'
+ import litellm
+
+ litellm.set_verbose = True
LOG_TO_FILE = os.getenv('LOG_TO_FILE', 'False').lower() in ['true', '1', 'yes']
DISABLE_COLOR_PRINTING = False
@@ -70,11 +73,6 @@ class ColoredFormatter(logging.Formatter):
return super().format(record)
-console_formatter = ColoredFormatter(
- '\033[92m%(asctime)s - %(name)s:%(levelname)s\033[0m: %(filename)s:%(lineno)s - %(message)s',
- datefmt='%H:%M:%S',
-)
-
file_formatter = logging.Formatter(
'%(asctime)s - %(name)s:%(levelname)s: %(filename)s:%(lineno)s - %(message)s',
datefmt='%H:%M:%S',
@@ -123,10 +121,10 @@ def get_console_handler(log_level=logging.INFO, extra_info: str | None = None):
"""Returns a console handler for logging."""
console_handler = logging.StreamHandler()
console_handler.setLevel(log_level)
- formatter_str = '%(asctime)s - %(levelname)s - %(message)s'
+ formatter_str = '\033[92m%(asctime)s - %(name)s:%(levelname)s\033[0m: %(filename)s:%(lineno)s - %(message)s'
if extra_info:
formatter_str = f'{extra_info} - ' + formatter_str
- console_handler.setFormatter(logging.Formatter(formatter_str))
+ console_handler.setFormatter(ColoredFormatter(formatter_str, datefmt='%H:%M:%S'))
return console_handler
diff --git a/openhands/core/message.py b/openhands/core/message.py
index 57fadabde7..042028498b 100644
--- a/openhands/core/message.py
+++ b/openhands/core/message.py
@@ -1,6 +1,7 @@
from enum import Enum
from typing import Literal
+from litellm import ChatCompletionMessageToolCall
from pydantic import BaseModel, Field, model_serializer
@@ -48,10 +49,16 @@ class ImageContent(Content):
class Message(BaseModel):
- role: Literal['user', 'system', 'assistant']
- content: list[TextContent | ImageContent] = Field(default=list)
+ role: Literal['user', 'system', 'assistant', 'tool']
+ content: list[TextContent | ImageContent] = Field(default_factory=list)
cache_enabled: bool = False
vision_enabled: bool = False
+ # function calling
+ # - tool calls (from LLM)
+ tool_calls: list[ChatCompletionMessageToolCall] | None = None
+ # - tool execution result (to LLM)
+ tool_call_id: str | None = None
+ name: str | None = None # name of the tool
@property
def contains_image(self) -> bool:
@@ -59,23 +66,31 @@ class Message(BaseModel):
@model_serializer
def serialize_model(self) -> dict:
- content: list[dict] | str
- # two kinds of serializer:
- # 1. vision serializer: when prompt caching or vision is enabled
- # 2. single text serializer: for other cases
- # remove this when liteLLM or providers support this format translation
- if self.cache_enabled or self.vision_enabled:
- # when prompt caching or vision is enabled, use vision serializer
- content = []
- for item in self.content:
- if isinstance(item, TextContent):
- content.append(item.model_dump())
- elif isinstance(item, ImageContent):
- content.extend(item.model_dump())
- else:
- # for other cases, concatenate all text content
- # into a single string per message
- content = '\n'.join(
- item.text for item in self.content if isinstance(item, TextContent)
- )
- return {'content': content, 'role': self.role}
+ content: list[dict] = []
+ role_tool_with_prompt_caching = False
+ for item in self.content:
+ d = item.model_dump()
+ # We have to remove cache_prompt for tool content and move it up to the message level
+ # See discussion here for details: https://github.com/BerriAI/litellm/issues/6422#issuecomment-2438765472
+ if self.role == 'tool' and item.cache_prompt:
+ role_tool_with_prompt_caching = True
+ d.pop('cache_control')
+ if isinstance(item, TextContent):
+ content.append(d)
+ elif isinstance(item, ImageContent) and self.vision_enabled:
+ content.extend(d)
+
+ ret: dict = {'content': content, 'role': self.role}
+
+ if role_tool_with_prompt_caching:
+ ret['cache_control'] = {'type': 'ephemeral'}
+
+ if self.tool_call_id is not None:
+ assert (
+ self.name is not None
+ ), 'name is required when tool_call_id is not None'
+ ret['tool_call_id'] = self.tool_call_id
+ ret['name'] = self.name
+ if self.tool_calls:
+ ret['tool_calls'] = self.tool_calls
+ return ret
diff --git a/openhands/core/utils/json.py b/openhands/core/utils/json.py
index 859cab1450..c0b22740be 100644
--- a/openhands/core/utils/json.py
+++ b/openhands/core/utils/json.py
@@ -2,6 +2,7 @@ import json
from datetime import datetime
from json_repair import repair_json
+from litellm.types.utils import ModelResponse
from openhands.core.exceptions import LLMResponseError
from openhands.events.event import Event
@@ -17,6 +18,8 @@ def my_default_encoder(obj):
return event_to_dict(obj)
if isinstance(obj, Metrics):
return obj.get()
+ if isinstance(obj, ModelResponse):
+ return obj.model_dump()
return json.JSONEncoder().default(obj)
diff --git a/openhands/events/action/commands.py b/openhands/events/action/commands.py
index 5a7e1fef52..83dd19f9d1 100644
--- a/openhands/events/action/commands.py
+++ b/openhands/events/action/commands.py
@@ -47,6 +47,9 @@ class CmdRunAction(Action):
class IPythonRunCellAction(Action):
code: str
thought: str = ''
+ include_extra: bool = (
+ True # whether to include CWD & Python interpreter in the output
+ )
action: str = ActionType.RUN_IPYTHON
runnable: ClassVar[bool] = True
confirmation_state: ActionConfirmationStatus = ActionConfirmationStatus.CONFIRMED
diff --git a/openhands/events/event.py b/openhands/events/event.py
index 43a8840671..6ec68acc55 100644
--- a/openhands/events/event.py
+++ b/openhands/events/event.py
@@ -2,6 +2,7 @@ from dataclasses import dataclass
from datetime import datetime
from enum import Enum
+from openhands.events.tool import ToolCallMetadata
from openhands.llm.metrics import Metrics
@@ -71,3 +72,14 @@ class Event:
@llm_metrics.setter
def llm_metrics(self, value: Metrics) -> None:
self._llm_metrics = value
+
+ # optional field
+ @property
+ def tool_call_metadata(self) -> ToolCallMetadata | None:
+ if hasattr(self, '_tool_call_metadata'):
+ return self._tool_call_metadata # type: ignore[attr-defined]
+ return None
+
+ @tool_call_metadata.setter
+ def tool_call_metadata(self, value: ToolCallMetadata) -> None:
+ self._tool_call_metadata = value
diff --git a/openhands/events/serialization/event.py b/openhands/events/serialization/event.py
index ee15ab6955..f0430b4cdd 100644
--- a/openhands/events/serialization/event.py
+++ b/openhands/events/serialization/event.py
@@ -6,10 +6,20 @@ from openhands.events.observation.observation import Observation
from openhands.events.serialization.action import action_from_dict
from openhands.events.serialization.observation import observation_from_dict
from openhands.events.serialization.utils import remove_fields
+from openhands.events.tool import ToolCallMetadata
# TODO: move `content` into `extras`
-TOP_KEYS = ['id', 'timestamp', 'source', 'message', 'cause', 'action', 'observation']
-UNDERSCORE_KEYS = ['id', 'timestamp', 'source', 'cause']
+TOP_KEYS = [
+ 'id',
+ 'timestamp',
+ 'source',
+ 'message',
+ 'cause',
+ 'action',
+ 'observation',
+ 'tool_call_metadata',
+]
+UNDERSCORE_KEYS = ['id', 'timestamp', 'source', 'cause', 'tool_call_metadata']
DELETE_FROM_TRAJECTORY_EXTRAS = {
'screenshot',
@@ -40,6 +50,8 @@ def event_from_dict(data) -> 'Event':
value = value.isoformat()
if key == 'source':
value = EventSource(value)
+ if key == 'tool_call_metadata':
+ value = ToolCallMetadata(**value)
setattr(evt, '_' + key, value)
return evt
@@ -59,6 +71,8 @@ def event_to_dict(event: 'Event') -> dict:
d['timestamp'] = d['timestamp'].isoformat()
if key == 'source' and 'source' in d:
d['source'] = d['source'].value
+ if key == 'tool_call_metadata' and 'tool_call_metadata' in d:
+ d['tool_call_metadata'] = d['tool_call_metadata'].model_dump()
props.pop(key, None)
if 'security_risk' in props and props['security_risk'] is None:
props.pop('security_risk')
diff --git a/openhands/events/tool.py b/openhands/events/tool.py
new file mode 100644
index 0000000000..30e288dc2f
--- /dev/null
+++ b/openhands/events/tool.py
@@ -0,0 +1,11 @@
+from litellm import ModelResponse
+from pydantic import BaseModel
+
+
+class ToolCallMetadata(BaseModel):
+ # See https://docs.litellm.ai/docs/completion/function_call#step-3---second-litellmcompletion-call
+ function_name: str # Name of the function that was called
+ tool_call_id: str # ID of the tool call
+
+ model_response: ModelResponse
+ total_calls_in_response: int
diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py
index 9eb3a08aa9..1696fd13b9 100644
--- a/openhands/llm/llm.py
+++ b/openhands/llm/llm.py
@@ -1,11 +1,12 @@
import copy
-import json
import os
import time
import warnings
from functools import partial
from typing import Any
+import requests
+
from openhands.core.config import LLMConfig
with warnings.catch_warnings():
@@ -81,16 +82,47 @@ class LLM(RetryMixin, DebugMixin):
# litellm actually uses base Exception here for unknown model
self.model_info: ModelInfo | None = None
- try:
- if self.config.model.startswith('openrouter'):
- self.model_info = litellm.get_model_info(self.config.model)
- else:
+
+ if self.config.model.startswith('openrouter'):
+ self.model_info = litellm.get_model_info(self.config.model)
+ elif self.config.model.startswith('litellm_proxy/'):
+ # IF we are using LiteLLM proxy, get model info from LiteLLM proxy
+ # GET {base_url}/v1/model/info with litellm_model_id as path param
+ response = requests.get(
+ f'{self.config.base_url}/v1/model/info',
+ headers={'Authorization': f'Bearer {self.config.api_key}'},
+ )
+ all_model_info = response.json()['data']
+ current_model_info = next(
+ (
+ info
+ for info in all_model_info
+ if info['model_name']
+ == self.config.model.removeprefix('litellm_proxy/')
+ ),
+ None,
+ )
+ if current_model_info:
+ self.model_info = current_model_info['model_info']
+
+ # Last two attempts to get model info from NAME
+ if not self.model_info:
+ try:
self.model_info = litellm.get_model_info(
self.config.model.split(':')[0]
)
- # noinspection PyBroadException
- except Exception as e:
- logger.warning(f'Could not get model info for {config.model}:\n{e}')
+ # noinspection PyBroadException
+ except Exception:
+ pass
+ if not self.model_info:
+ try:
+ self.model_info = litellm.get_model_info(
+ self.config.model.split('/')[-1]
+ )
+ # noinspection PyBroadException
+ except Exception:
+ pass
+ logger.info(f'Model info: {self.model_info}')
if self.config.log_completions:
if self.config.log_completions_folder is None:
@@ -126,6 +158,11 @@ class LLM(RetryMixin, DebugMixin):
):
self.config.max_output_tokens = self.model_info['max_tokens']
+ self.config.supports_function_calling = (
+ self.model_info is not None
+ and self.model_info.get('supports_function_calling', False)
+ )
+
self._completion = partial(
litellm_completion,
model=self.config.model,
@@ -144,6 +181,8 @@ class LLM(RetryMixin, DebugMixin):
logger.debug('LLM: model has vision enabled')
if self.is_caching_prompt_active():
logger.debug('LLM: caching prompt enabled')
+ if self.config.supports_function_calling:
+ logger.debug('LLM: model supports function calling')
completion_unwrapped = self._completion
@@ -195,26 +234,32 @@ class LLM(RetryMixin, DebugMixin):
try:
# we don't support streaming here, thus we get a ModelResponse
resp: ModelResponse = completion_unwrapped(*args, **kwargs)
-
# log for evals or other scripts that need the raw completion
if self.config.log_completions:
assert self.config.log_completions_folder is not None
log_file = os.path.join(
self.config.log_completions_folder,
# use the metric model name (for draft editor)
- f'{self.metrics.model_name}-{time.time()}.json',
+ f'{self.metrics.model_name.replace("/", "__")}-{time.time()}.json',
)
+ from openhands.core.utils import json
+
with open(log_file, 'w') as f:
- json.dump(
- {
- 'messages': messages,
- 'response': resp,
- 'args': args,
- 'kwargs': kwargs,
- 'timestamp': time.time(),
- 'cost': self._completion_cost(resp),
- },
- f,
+ f.write(
+ json.dumps(
+ {
+ 'messages': messages,
+ 'response': resp,
+ 'args': args,
+ 'kwargs': {
+ k: v
+ for k, v in kwargs.items()
+ if k != 'messages'
+ },
+ 'timestamp': time.time(),
+ 'cost': self._completion_cost(resp),
+ },
+ )
)
message_back: str = resp['choices'][0]['message']['content']
@@ -390,16 +435,18 @@ class LLM(RetryMixin, DebugMixin):
logger.info(f'Using custom cost per token: {cost_per_token}')
extra_kwargs['custom_cost_per_token'] = cost_per_token
- if not self._is_local():
- try:
+ try:
+ # try directly get response_cost from response
+ cost = getattr(response, '_hidden_params', {}).get('response_cost', None)
+ if cost is None:
cost = litellm_completion_cost(
completion_response=response, **extra_kwargs
)
- self.metrics.add_cost(cost)
- return cost
- except Exception:
- self.cost_metric_supported = False
- logger.warning('Cost calculation not supported for this model.')
+ self.metrics.add_cost(cost)
+ return cost
+ except Exception:
+ self.cost_metric_supported = False
+ logger.warning('Cost calculation not supported for this model.')
return 0.0
def __str__(self):
diff --git a/openhands/runtime/action_execution_server.py b/openhands/runtime/action_execution_server.py
index 2da1372532..1cec1b7083 100644
--- a/openhands/runtime/action_execution_server.py
+++ b/openhands/runtime/action_execution_server.py
@@ -194,10 +194,11 @@ class ActionExecutor:
obs: IPythonRunCellObservation = await _jupyter_plugin.run(action)
obs.content = obs.content.rstrip()
- obs.content += (
- f'\n[Jupyter current working directory: {self.bash_session.pwd}]'
- )
- obs.content += f'\n[Jupyter Python interpreter: {_jupyter_plugin.python_interpreter_path}]'
+ if action.include_extra:
+ obs.content += (
+ f'\n[Jupyter current working directory: {self.bash_session.pwd}]'
+ )
+ obs.content += f'\n[Jupyter Python interpreter: {_jupyter_plugin.python_interpreter_path}]'
return obs
else:
raise RuntimeError(
diff --git a/openhands/runtime/base.py b/openhands/runtime/base.py
index 0d752d71be..81e0366665 100644
--- a/openhands/runtime/base.py
+++ b/openhands/runtime/base.py
@@ -127,8 +127,11 @@ class Runtime(FileEditRuntimeMixin):
if event.timeout is None:
event.timeout = self.config.sandbox.timeout
assert event.timeout is not None
- observation = await call_sync_from_async(self.run_action, event)
+ observation: Observation = await call_sync_from_async(
+ self.run_action, event
+ )
observation._cause = event.id # type: ignore[attr-defined]
+ observation.tool_call_metadata = event.tool_call_metadata
source = event.source if event.source else EventSource.AGENT
await self.event_stream.async_add_event(observation, source) # type: ignore[arg-type]
diff --git a/openhands/runtime/impl/eventstream/eventstream_runtime.py b/openhands/runtime/impl/eventstream/eventstream_runtime.py
index 1cf4096922..5121e9ede9 100644
--- a/openhands/runtime/impl/eventstream/eventstream_runtime.py
+++ b/openhands/runtime/impl/eventstream/eventstream_runtime.py
@@ -1,6 +1,7 @@
import os
import tempfile
import threading
+import traceback
from functools import lru_cache
from typing import Callable
from zipfile import ZipFile
@@ -496,7 +497,9 @@ class EventStreamRuntime(Runtime):
)
except Exception as e:
logger.error(f'Error during action execution: {e}')
- obs = FatalErrorObservation(f'Action execution failed: {str(e)}')
+ obs = FatalErrorObservation(
+ f'Action execution failed: {str(e)}.\n{traceback.format_exc()}'
+ )
self._refresh_logs()
return obs
diff --git a/openhands/runtime/plugins/agent_skills/agentskills.py b/openhands/runtime/plugins/agent_skills/agentskills.py
index dd34e3878d..046f8af20c 100644
--- a/openhands/runtime/plugins/agent_skills/agentskills.py
+++ b/openhands/runtime/plugins/agent_skills/agentskills.py
@@ -23,3 +23,9 @@ for func_name in __all__:
fn_signature = f'{func.__name__}' + str(signature(func))
DOCUMENTATION += f'{fn_signature}:\n{cur_doc}\n\n'
+
+
+# Add file_editor (a function)
+from openhands.runtime.plugins.agent_skills.file_editor import file_editor # noqa: E402
+
+__all__ += ['file_editor']
diff --git a/openhands/runtime/plugins/agent_skills/file_editor/README.md b/openhands/runtime/plugins/agent_skills/file_editor/README.md
new file mode 100644
index 0000000000..37c6c2818a
--- /dev/null
+++ b/openhands/runtime/plugins/agent_skills/file_editor/README.md
@@ -0,0 +1,3 @@
+# File Editor
+
+This file editor is largely based on the Anthropic-released [`str_replace_editor`](https://github.com/anthropics/anthropic-quickstarts/tree/main/computer-use-demo/computer_use_demo/tools/edit.py). The original code was released under the [MIT license](https://github.com/anthropics/anthropic-quickstarts/blob/e373524f07594d48c3f9563248ea282a4c306c0c/LICENSE).
diff --git a/openhands/runtime/plugins/agent_skills/file_editor/__init__.py b/openhands/runtime/plugins/agent_skills/file_editor/__init__.py
new file mode 100644
index 0000000000..f6d3eb39a0
--- /dev/null
+++ b/openhands/runtime/plugins/agent_skills/file_editor/__init__.py
@@ -0,0 +1,60 @@
+"""This file contains a global singleton of the `EditTool` class as well as raw functions that expose its __call__."""
+
+from .base import CLIResult, ToolError, ToolResult
+from .impl import Command, EditTool
+
+_GLOBAL_EDITOR = EditTool()
+
+
+def _make_api_tool_result(
+ result: ToolResult,
+) -> str:
+ """Convert an agent ToolResult to an API ToolResultBlockParam."""
+ tool_result_content: str = ''
+ is_error = False
+ if result.error:
+ is_error = True
+ tool_result_content = _maybe_prepend_system_tool_result(result, result.error)
+ else:
+ assert result.output, 'Expecting output in file_editor'
+ tool_result_content = _maybe_prepend_system_tool_result(result, result.output)
+ assert (
+ not result.base64_image
+ ), 'Not expecting base64_image as output in file_editor'
+ if is_error:
+ return f'ERROR:\n{tool_result_content}'
+ else:
+ return tool_result_content
+
+
+def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str) -> str:
+ if result.system:
+ result_text = f'{result.system}\n{result_text}'
+ return result_text
+
+
+def file_editor(
+ command: Command,
+ path: str,
+ file_text: str | None = None,
+ view_range: list[int] | None = None,
+ old_str: str | None = None,
+ new_str: str | None = None,
+ insert_line: int | None = None,
+) -> str:
+ try:
+ result: CLIResult = _GLOBAL_EDITOR(
+ command=command,
+ path=path,
+ file_text=file_text,
+ view_range=view_range,
+ old_str=old_str,
+ new_str=new_str,
+ insert_line=insert_line,
+ )
+ except ToolError as e:
+ return _make_api_tool_result(ToolResult(error=e.message))
+ return _make_api_tool_result(result)
+
+
+__all__ = ['file_editor']
diff --git a/openhands/runtime/plugins/agent_skills/file_editor/base.py b/openhands/runtime/plugins/agent_skills/file_editor/base.py
new file mode 100644
index 0000000000..6ad2a4b5b6
--- /dev/null
+++ b/openhands/runtime/plugins/agent_skills/file_editor/base.py
@@ -0,0 +1,50 @@
+from dataclasses import dataclass, fields, replace
+
+
+@dataclass(kw_only=True, frozen=True)
+class ToolResult:
+ """Represents the result of a tool execution."""
+
+ output: str | None = None
+ error: str | None = None
+ base64_image: str | None = None
+ system: str | None = None
+
+ def __bool__(self):
+ return any(getattr(self, field.name) for field in fields(self))
+
+ def __add__(self, other: 'ToolResult'):
+ def combine_fields(
+ field: str | None, other_field: str | None, concatenate: bool = True
+ ):
+ if field and other_field:
+ if concatenate:
+ return field + other_field
+ raise ValueError('Cannot combine tool results')
+ return field or other_field
+
+ return ToolResult(
+ output=combine_fields(self.output, other.output),
+ error=combine_fields(self.error, other.error),
+ base64_image=combine_fields(self.base64_image, other.base64_image, False),
+ system=combine_fields(self.system, other.system),
+ )
+
+ def replace(self, **kwargs):
+ """Returns a new ToolResult with the given fields replaced."""
+ return replace(self, **kwargs)
+
+
+class CLIResult(ToolResult):
+ """A ToolResult that can be rendered as a CLI output."""
+
+
+class ToolFailure(ToolResult):
+ """A ToolResult that represents a failure."""
+
+
+class ToolError(Exception):
+ """Raised when a tool encounters an error."""
+
+ def __init__(self, message):
+ self.message = message
diff --git a/openhands/runtime/plugins/agent_skills/file_editor/impl.py b/openhands/runtime/plugins/agent_skills/file_editor/impl.py
new file mode 100644
index 0000000000..e0944ab659
--- /dev/null
+++ b/openhands/runtime/plugins/agent_skills/file_editor/impl.py
@@ -0,0 +1,279 @@
+from collections import defaultdict
+from pathlib import Path
+from typing import Literal, get_args
+
+from .base import CLIResult, ToolError, ToolResult
+from .run import maybe_truncate, run
+
+Command = Literal[
+ 'view',
+ 'create',
+ 'str_replace',
+ 'insert',
+ 'undo_edit',
+]
+SNIPPET_LINES: int = 4
+
+
+class EditTool:
+ """
+    A filesystem editor tool that allows the agent to view, create, and edit files.
+ The tool parameters are defined by Anthropic and are not editable.
+
+ Original implementation: https://github.com/anthropics/anthropic-quickstarts/blob/main/computer-use-demo/computer_use_demo/tools/edit.py
+ """
+
+ _file_history: dict[Path, list[str]]
+
+ def __init__(self):
+ self._file_history = defaultdict(list)
+ super().__init__()
+
+ def __call__(
+ self,
+ *,
+ command: Command,
+ path: str,
+ file_text: str | None = None,
+ view_range: list[int] | None = None,
+ old_str: str | None = None,
+ new_str: str | None = None,
+ insert_line: int | None = None,
+ **kwargs,
+ ):
+ _path = Path(path)
+ self.validate_path(command, _path)
+ if command == 'view':
+ return self.view(_path, view_range)
+ elif command == 'create':
+ if not file_text:
+ raise ToolError('Parameter `file_text` is required for command: create')
+ self.write_file(_path, file_text)
+ self._file_history[_path].append(file_text)
+ return ToolResult(output=f'File created successfully at: {_path}')
+ elif command == 'str_replace':
+ if not old_str:
+ raise ToolError(
+ 'Parameter `old_str` is required for command: str_replace'
+ )
+ return self.str_replace(_path, old_str, new_str)
+ elif command == 'insert':
+ if insert_line is None:
+ raise ToolError(
+ 'Parameter `insert_line` is required for command: insert'
+ )
+ if not new_str:
+ raise ToolError('Parameter `new_str` is required for command: insert')
+ return self.insert(_path, insert_line, new_str)
+ elif command == 'undo_edit':
+ return self.undo_edit(_path)
+ raise ToolError(
+ f'Unrecognized command {command}. The allowed commands for the {self.name} tool are: {", ".join(get_args(Command))}'
+ )
+
+ def validate_path(self, command: str, path: Path):
+ """
+ Check that the path/command combination is valid.
+ """
+ # Check if its an absolute path
+ if not path.is_absolute():
+ suggested_path = Path('') / path
+ raise ToolError(
+ f'The path {path} is not an absolute path, it should start with `/`. Maybe you meant {suggested_path}?'
+ )
+ # Check if path exists
+ if not path.exists() and command != 'create':
+ raise ToolError(
+ f'The path {path} does not exist. Please provide a valid path.'
+ )
+ if path.exists() and command == 'create':
+ raise ToolError(
+ f'File already exists at: {path}. Cannot overwrite files using command `create`.'
+ )
+ # Check if the path points to a directory
+ if path.is_dir():
+ if command != 'view':
+ raise ToolError(
+ f'The path {path} is a directory and only the `view` command can be used on directories'
+ )
+
+ def view(self, path: Path, view_range: list[int] | None = None):
+ """Implement the view command"""
+ if path.is_dir():
+ if view_range:
+ raise ToolError(
+ 'The `view_range` parameter is not allowed when `path` points to a directory.'
+ )
+
+ _, stdout, stderr = run(rf"find {path} -maxdepth 2 -not -path '*/\.*'")
+ if not stderr:
+ stdout = f"Here's the files and directories up to 2 levels deep in {path}, excluding hidden items:\n{stdout}\n"
+ return CLIResult(output=stdout, error=stderr)
+
+ file_content = self.read_file(path)
+ init_line = 1
+ if view_range:
+ if len(view_range) != 2 or not all(isinstance(i, int) for i in view_range):
+ raise ToolError(
+ 'Invalid `view_range`. It should be a list of two integers.'
+ )
+ file_lines = file_content.split('\n')
+ n_lines_file = len(file_lines)
+ init_line, final_line = view_range
+ if init_line < 1 or init_line > n_lines_file:
+ raise ToolError(
+ f"Invalid `view_range`: {view_range}. It's first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}"
+ )
+ if final_line > n_lines_file:
+ raise ToolError(
+ f"Invalid `view_range`: {view_range}. It's second element `{final_line}` should be smaller than the number of lines in the file: `{n_lines_file}`"
+ )
+ if final_line != -1 and final_line < init_line:
+ raise ToolError(
+ f"Invalid `view_range`: {view_range}. It's second element `{final_line}` should be larger or equal than its first `{init_line}`"
+ )
+
+ if final_line == -1:
+ file_content = '\n'.join(file_lines[init_line - 1 :])
+ else:
+ file_content = '\n'.join(file_lines[init_line - 1 : final_line])
+
+ return CLIResult(
+ output=self._make_output(file_content, str(path), init_line=init_line)
+ )
+
+ def str_replace(self, path: Path, old_str: str, new_str: str | None):
+ """Implement the str_replace command, which replaces old_str with new_str in the file content"""
+ # Read the file content
+ file_content = self.read_file(path).expandtabs()
+ old_str = old_str.expandtabs()
+ new_str = new_str.expandtabs() if new_str is not None else ''
+
+ # Check if old_str is unique in the file
+ occurrences = file_content.count(old_str)
+ if occurrences == 0:
+ raise ToolError(
+ f'No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}.'
+ )
+ elif occurrences > 1:
+ file_content_lines = file_content.split('\n')
+ lines = [
+ idx + 1
+ for idx, line in enumerate(file_content_lines)
+ if old_str in line
+ ]
+ raise ToolError(
+ f'No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique'
+ )
+
+ # Replace old_str with new_str
+ new_file_content = file_content.replace(old_str, new_str)
+
+ # Write the new content to the file
+ self.write_file(path, new_file_content)
+
+ # Save the content to history
+ self._file_history[path].append(file_content)
+
+ # Create a snippet of the edited section
+ replacement_line = file_content.split(old_str)[0].count('\n')
+ start_line = max(0, replacement_line - SNIPPET_LINES)
+ end_line = replacement_line + SNIPPET_LINES + new_str.count('\n')
+ snippet = '\n'.join(new_file_content.split('\n')[start_line : end_line + 1])
+
+ # Prepare the success message
+ success_msg = f'The file {path} has been edited. '
+ success_msg += self._make_output(
+ snippet, f'a snippet of {path}', start_line + 1
+ )
+ success_msg += 'Review the changes and make sure they are as expected. Edit the file again if necessary.'
+
+ return CLIResult(output=success_msg)
+
+ def insert(self, path: Path, insert_line: int, new_str: str):
+ """Implement the insert command, which inserts new_str at the specified line in the file content."""
+ file_text = self.read_file(path).expandtabs()
+ new_str = new_str.expandtabs()
+ file_text_lines = file_text.split('\n')
+ n_lines_file = len(file_text_lines)
+
+ if insert_line < 0 or insert_line > n_lines_file:
+ raise ToolError(
+ f'Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}'
+ )
+
+ new_str_lines = new_str.split('\n')
+ new_file_text_lines = (
+ file_text_lines[:insert_line]
+ + new_str_lines
+ + file_text_lines[insert_line:]
+ )
+ snippet_lines = (
+ file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
+ + new_str_lines
+ + file_text_lines[insert_line : insert_line + SNIPPET_LINES]
+ )
+
+ new_file_text = '\n'.join(new_file_text_lines)
+ snippet = '\n'.join(snippet_lines)
+
+ self.write_file(path, new_file_text)
+ self._file_history[path].append(file_text)
+
+ success_msg = f'The file {path} has been edited. '
+ success_msg += self._make_output(
+ snippet,
+ 'a snippet of the edited file',
+ max(1, insert_line - SNIPPET_LINES + 1),
+ )
+ success_msg += 'Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary.'
+ return CLIResult(output=success_msg)
+
+ def undo_edit(self, path: Path):
+ """Implement the undo_edit command."""
+ if not self._file_history[path]:
+ raise ToolError(f'No edit history found for {path}.')
+
+ old_text = self._file_history[path].pop()
+ self.write_file(path, old_text)
+
+ return CLIResult(
+ output=f'Last edit to {path} undone successfully. {self._make_output(old_text, str(path))}'
+ )
+
+ def read_file(self, path: Path):
+ """Read the content of a file from a given path; raise a ToolError if an error occurs."""
+ try:
+ return path.read_text()
+ except Exception as e:
+ raise ToolError(f'Ran into {e} while trying to read {path}') from None
+
+ def write_file(self, path: Path, file: str):
+ """Write the content of a file to a given path; raise a ToolError if an error occurs."""
+ try:
+ path.write_text(file)
+ except Exception as e:
+ raise ToolError(f'Ran into {e} while trying to write to {path}') from None
+
+ def _make_output(
+ self,
+ file_content: str,
+ file_descriptor: str,
+ init_line: int = 1,
+ expand_tabs: bool = True,
+ ):
+ """Generate output for the CLI based on the content of a file."""
+ file_content = maybe_truncate(file_content)
+ if expand_tabs:
+ file_content = file_content.expandtabs()
+ file_content = '\n'.join(
+ [
+ f'{i + init_line:6}\t{line}'
+ for i, line in enumerate(file_content.split('\n'))
+ ]
+ )
+ return (
+ f"Here's the result of running `cat -n` on {file_descriptor}:\n"
+ + file_content
+ + '\n'
+ )
diff --git a/openhands/runtime/plugins/agent_skills/file_editor/run.py b/openhands/runtime/plugins/agent_skills/file_editor/run.py
new file mode 100644
index 0000000000..29c604256f
--- /dev/null
+++ b/openhands/runtime/plugins/agent_skills/file_editor/run.py
@@ -0,0 +1,44 @@
+"""Utility to run shell commands synchronously with a timeout."""
+
+import subprocess
+import time
+
+TRUNCATED_MESSAGE: str = 'To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for.'
+MAX_RESPONSE_LEN: int = 16000
+
+
+def maybe_truncate(content: str, truncate_after: int | None = MAX_RESPONSE_LEN):
+ """Truncate content and append a notice if content exceeds the specified length."""
+ return (
+ content
+ if not truncate_after or len(content) <= truncate_after
+ else content[:truncate_after] + TRUNCATED_MESSAGE
+ )
+
+
+def run(
+ cmd: str,
+ timeout: float | None = 120.0, # seconds
+ truncate_after: int | None = MAX_RESPONSE_LEN,
+):
+ """Run a shell command synchronously with a timeout."""
+ start_time = time.time()
+
+ try:
+ process = subprocess.Popen(
+ cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
+ )
+
+ stdout, stderr = process.communicate(timeout=timeout)
+
+ return (
+ process.returncode or 0,
+ maybe_truncate(stdout, truncate_after=truncate_after),
+ maybe_truncate(stderr, truncate_after=truncate_after),
+ )
+ except subprocess.TimeoutExpired:
+ process.kill()
+ elapsed_time = time.time() - start_time
+ raise TimeoutError(
+ f"Command '{cmd}' timed out after {elapsed_time:.2f} seconds"
+ )
diff --git a/poetry.lock b/poetry.lock
index 20316ac24f..6ef4a43322 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "aenum"
@@ -3901,22 +3901,20 @@ name = "litellm"
version = "1.50.4"
description = "Library to easily interface with LLM API providers"
optional = false
-python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8"
-files = [
- {file = "litellm-1.50.4-py3-none-any.whl", hash = "sha256:cc6992275e24a0bbb4a3b377e6842d45a8510fc85d7f255930a64bb872980a36"},
- {file = "litellm-1.50.4.tar.gz", hash = "sha256:a7e68ef614f631b58969c2c7a5154a565ba5974558d437c8cd6c8623654880ea"},
-]
+python-versions = ">=3.8.1,<4.0, !=3.9.7"
+files = []
+develop = false
[package.dependencies]
aiohttp = "*"
click = "*"
importlib-metadata = ">=6.8.0"
-jinja2 = ">=3.1.2,<4.0.0"
-jsonschema = ">=4.22.0,<5.0.0"
+jinja2 = "^3.1.2"
+jsonschema = "^4.22.0"
openai = ">=1.52.0"
-pydantic = ">=2.0.0,<3.0.0"
+pydantic = "^2.0.0"
python-dotenv = ">=0.2.0"
-requests = ">=2.31.0,<3.0.0"
+requests = "^2.31.0"
tiktoken = ">=0.7.0"
tokenizers = "*"
@@ -3924,6 +3922,12 @@ tokenizers = "*"
extra-proxy = ["azure-identity (>=1.15.0,<2.0.0)", "azure-keyvault-secrets (>=4.8.0,<5.0.0)", "google-cloud-kms (>=2.21.3,<3.0.0)", "prisma (==0.11.0)", "resend (>=0.8.0,<0.9.0)"]
proxy = ["PyJWT (>=2.8.0,<3.0.0)", "apscheduler (>=3.10.4,<4.0.0)", "backoff", "cryptography (>=42.0.5,<43.0.0)", "fastapi (>=0.111.0,<0.112.0)", "fastapi-sso (>=0.10.0,<0.11.0)", "gunicorn (>=22.0.0,<23.0.0)", "orjson (>=3.9.7,<4.0.0)", "pynacl (>=1.5.0,<2.0.0)", "python-multipart (>=0.0.9,<0.0.10)", "pyyaml (>=6.0.1,<7.0.0)", "rq", "uvicorn (>=0.22.0,<0.23.0)"]
+[package.source]
+type = "git"
+url = "https://github.com/BerriAI/litellm.git"
+reference = "58fe6610601297c5a90367fd4583469e2df3fcf9"
+resolved_reference = "58fe6610601297c5a90367fd4583469e2df3fcf9"
+
[[package]]
name = "llama-index"
version = "0.10.45.post1"
@@ -10095,4 +10099,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
[metadata]
lock-version = "2.0"
python-versions = "^3.12"
-content-hash = "aeb09e429a789c3f8ced605e7e1a5932fd6cce7f7f4ce30a960da77fba18b9a3"
+content-hash = "880b0251ec1ac83a7a8f6b1637b0860f75553d1f9c7c67313af9f3bb686c166a"
diff --git a/pyproject.toml b/pyproject.toml
index 393bab9594..a7daebd37d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,7 @@ packages = [
python = "^3.12"
datasets = "*"
pandas = "*"
-litellm = "*"
+litellm = { git = "https://github.com/BerriAI/litellm.git", rev = "58fe6610601297c5a90367fd4583469e2df3fcf9" }
google-generativeai = "*" # To use litellm with Gemini Pro API
termcolor = "*"
seaborn = "*"
@@ -89,6 +89,7 @@ reportlab = "*"
[tool.coverage.run]
concurrency = ["gevent"]
+
[tool.poetry.group.runtime.dependencies]
jupyterlab = "*"
notebook = "*"
@@ -119,6 +120,7 @@ ignore = ["D1"]
[tool.ruff.lint.pydocstyle]
convention = "google"
+
[tool.poetry.group.evaluation.dependencies]
streamlit = "*"
whatthepatch = "*"
diff --git a/tests/unit/test_agent_skill.py b/tests/unit/test_agent_skill.py
index 63745f4dd2..f619bc00bf 100644
--- a/tests/unit/test_agent_skill.py
+++ b/tests/unit/test_agent_skill.py
@@ -5,6 +5,7 @@ import sys
import docx
import pytest
+from openhands.runtime.plugins.agent_skills.agentskills import file_editor
from openhands.runtime.plugins.agent_skills.file_ops.file_ops import (
WINDOW,
_print_window,
@@ -715,3 +716,302 @@ def test_parse_pptx(tmp_path):
'Hello, this is the second test PPTX slide.\n\n'
)
assert output == expected_output, f'Expected output does not match. Got: {output}'
+
+
+# =============================================================================
+
+
+def test_file_editor_view(tmp_path):
+ # generate a random directory
+ random_dir = tmp_path / 'dir_1'
+ random_dir.mkdir()
+ # create a file in the directory
+ random_file = random_dir / 'a.txt'
+ random_file.write_text('Line 1\nLine 2\nLine 3\nLine 4\nLine 5')
+ random_dir_2 = tmp_path / 'dir_2'
+ random_dir_2.mkdir()
+ random_file_2 = random_dir_2 / 'b.txt'
+ random_file_2.write_text('Line 1\nLine 2\nLine 3\nLine 4\nLine 5')
+
+ from openhands.runtime.plugins.agent_skills.agentskills import file_editor
+
+ # view the file
+ result = file_editor(command='view', path=str(random_file))
+ print('\n', result)
+ assert result is not None
+ assert (
+ result.split('\n')
+ == f"""Here's the result of running `cat -n` on {random_file}:
+ 1\tLine 1
+ 2\tLine 2
+ 3\tLine 3
+ 4\tLine 4
+ 5\tLine 5
+""".split('\n')
+ )
+
+ # view the directory
+ result = file_editor(command='view', path=str(tmp_path))
+ print('\n', result)
+ assert result is not None
+ assert (
+ result.strip().split('\n')
+ == f"""Here's the files and directories up to 2 levels deep in {tmp_path}, excluding hidden items:
+{tmp_path}
+{tmp_path}/dir_2
+{tmp_path}/dir_2/b.txt
+{tmp_path}/dir_1
+{tmp_path}/dir_1/a.txt
+""".strip().split('\n')
+ )
+
+
+def test_file_editor_create(tmp_path):
+ # generate a random directory
+ random_dir = tmp_path / 'dir_1'
+ random_dir.mkdir()
+ # create a file in the directory
+ random_file = random_dir / 'a.txt'
+
+ from openhands.runtime.plugins.agent_skills.agentskills import file_editor
+
+    # view a nonexistent file
+ result = file_editor(command='view', path=str(random_file))
+ print(result)
+ assert result is not None
+ assert (
+ result
+ == f'ERROR:\nThe path {random_file} does not exist. Please provide a valid path.'
+ )
+
+ # create a file
+ result = file_editor(command='create', path=str(random_file), file_text='Line 6')
+ print(result)
+ assert result is not None
+ assert result == f'File created successfully at: {random_file}'
+
+ # view again
+ result = file_editor(command='view', path=str(random_file))
+ print(result)
+ assert result is not None
+ assert (
+ result.strip().split('\n')
+ == f"""Here's the result of running `cat -n` on {random_file}:
+ 1\tLine 6
+""".strip().split('\n')
+ )
+
+
+@pytest.fixture
+def setup_file(tmp_path):
+ random_dir = tmp_path / 'dir_1'
+ random_dir.mkdir()
+ random_file = random_dir / 'a.txt'
+ return random_file
+
+
+def test_file_editor_create_and_view(setup_file):
+ random_file = setup_file
+
+ # Test create command
+ result = file_editor(
+ command='create', path=str(random_file), file_text='Line 1\nLine 2\nLine 3'
+ )
+ print(result)
+ assert result == f'File created successfully at: {random_file}'
+
+ # Test view command for file
+ result = file_editor(command='view', path=str(random_file))
+ print(result)
+ assert (
+ result.strip().split('\n')
+ == f"""Here's the result of running `cat -n` on {random_file}:
+ 1\tLine 1
+ 2\tLine 2
+ 3\tLine 3
+""".strip().split('\n')
+ )
+
+ # Test view command for directory
+ result = file_editor(command='view', path=str(random_file.parent))
+ assert f'{random_file.parent}' in result
+ assert f'{random_file.name}' in result
+
+
+def test_file_editor_view_nonexistent(setup_file):
+ random_file = setup_file
+
+ # Test view command for non-existent file
+ result = file_editor(command='view', path=str(random_file))
+ assert (
+ result
+ == f'ERROR:\nThe path {random_file} does not exist. Please provide a valid path.'
+ )
+
+
+def test_file_editor_str_replace(setup_file):
+ random_file = setup_file
+ file_editor(
+ command='create', path=str(random_file), file_text='Line 1\nLine 2\nLine 3'
+ )
+
+ # Test str_replace command
+ result = file_editor(
+ command='str_replace',
+ path=str(random_file),
+ old_str='Line 2',
+ new_str='New Line 2',
+ )
+ print(result)
+ assert (
+ result
+ == f"""The file {random_file} has been edited. Here's the result of running `cat -n` on a snippet of {random_file}:
+ 1\tLine 1
+ 2\tNew Line 2
+ 3\tLine 3
+Review the changes and make sure they are as expected. Edit the file again if necessary."""
+ )
+
+ # View the file after str_replace
+ result = file_editor(command='view', path=str(random_file))
+ print(result)
+ assert (
+ result.strip().split('\n')
+ == f"""Here's the result of running `cat -n` on {random_file}:
+ 1\tLine 1
+ 2\tNew Line 2
+ 3\tLine 3
+""".strip().split('\n')
+ )
+
+
+def test_file_editor_str_replace_non_existent(setup_file):
+ random_file = setup_file
+ file_editor(
+ command='create', path=str(random_file), file_text='Line 1\nLine 2\nLine 3'
+ )
+
+ # Test str_replace with non-existent string
+ result = file_editor(
+ command='str_replace',
+ path=str(random_file),
+ old_str='Non-existent Line',
+ new_str='New Line',
+ )
+ print(result)
+ assert (
+ result
+ == f'ERROR:\nNo replacement was performed, old_str `Non-existent Line` did not appear verbatim in {random_file}.'
+ )
+
+
+def test_file_editor_insert(setup_file):
+ random_file = setup_file
+ file_editor(
+ command='create', path=str(random_file), file_text='Line 1\nLine 2\nLine 3'
+ )
+
+ # Test insert command
+ result = file_editor(
+ command='insert', path=str(random_file), insert_line=2, new_str='Inserted Line'
+ )
+ print(result)
+ assert (
+ result
+ == f"""The file {random_file} has been edited. Here's the result of running `cat -n` on a snippet of the edited file:
+ 1\tLine 1
+ 2\tLine 2
+ 3\tInserted Line
+ 4\tLine 3
+Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary."""
+ )
+
+ # View the file after insert
+ result = file_editor(command='view', path=str(random_file))
+ assert (
+ result.strip().split('\n')
+ == f"""Here's the result of running `cat -n` on {random_file}:
+ 1\tLine 1
+ 2\tLine 2
+ 3\tInserted Line
+ 4\tLine 3
+""".strip().split('\n')
+ )
+
+
+def test_file_editor_insert_invalid_line(setup_file):
+ random_file = setup_file
+ file_editor(
+ command='create', path=str(random_file), file_text='Line 1\nLine 2\nLine 3'
+ )
+
+ # Test insert with invalid line number
+ result = file_editor(
+ command='insert',
+ path=str(random_file),
+ insert_line=10,
+ new_str='Invalid Insert',
+ )
+ assert (
+ result
+ == 'ERROR:\nInvalid `insert_line` parameter: 10. It should be within the range of lines of the file: [0, 3]'
+ )
+
+
+def test_file_editor_undo_edit(setup_file):
+ random_file = setup_file
+ result = file_editor(
+ command='create', path=str(random_file), file_text='Line 1\nLine 2\nLine 3'
+ )
+ print(result)
+ assert result == f"""File created successfully at: {random_file}"""
+
+ # Make an edit
+ result = file_editor(
+ command='str_replace',
+ path=str(random_file),
+ old_str='Line 2',
+ new_str='New Line 2',
+ )
+ print(result)
+ assert (
+ result
+ == f"""The file {random_file} has been edited. Here's the result of running `cat -n` on a snippet of {random_file}:
+ 1\tLine 1
+ 2\tNew Line 2
+ 3\tLine 3
+Review the changes and make sure they are as expected. Edit the file again if necessary."""
+ )
+
+ # Test undo_edit command
+ result = file_editor(command='undo_edit', path=str(random_file))
+ print(result)
+ assert (
+ result
+ == f"""Last edit to {random_file} undone successfully. Here's the result of running `cat -n` on {random_file}:
+ 1\tLine 1
+ 2\tLine 2
+ 3\tLine 3
+"""
+ )
+
+ # View the file after undo_edit
+ result = file_editor(command='view', path=str(random_file))
+ assert (
+ result.strip().split('\n')
+ == f"""Here's the result of running `cat -n` on {random_file}:
+ 1\tLine 1
+ 2\tLine 2
+ 3\tLine 3
+""".strip().split('\n')
+ )
+
+
+def test_file_editor_undo_edit_no_edits(tmp_path):
+ random_file = tmp_path / 'a.txt'
+ random_file.touch()
+
+ # Test undo_edit when no edits have been made
+ result = file_editor(command='undo_edit', path=str(random_file))
+ print(result)
+ assert result == f'ERROR:\nNo edit history found for {random_file}.'
diff --git a/tests/unit/test_codeact_agent.py b/tests/unit/test_codeact_agent.py
index 55dfa3feb7..9e3dda6c2c 100644
--- a/tests/unit/test_codeact_agent.py
+++ b/tests/unit/test_codeact_agent.py
@@ -24,29 +24,35 @@ def agent() -> CodeActAgent:
def test_cmd_output_observation_message(agent: CodeActAgent):
+ agent.config.function_calling = False
obs = CmdOutputObservation(
command='echo hello', content='Command output', command_id=1, exit_code=0
)
- result = agent.get_observation_message(obs)
+ results = agent.get_observation_message(obs, tool_call_id_to_message={})
+ assert len(results) == 1
+ result = results[0]
assert result is not None
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
assert 'OBSERVATION:' in result.content[0].text
assert 'Command output' in result.content[0].text
- assert 'Command 1 finished with exit code 0' in result.content[0].text
+ assert 'Command finished with exit code 0' in result.content[0].text
def test_ipython_run_cell_observation_message(agent: CodeActAgent):
+ agent.config.function_calling = False
obs = IPythonRunCellObservation(
code='plt.plot()',
content='IPython output\n',
)
- result = agent.get_observation_message(obs)
+ results = agent.get_observation_message(obs, tool_call_id_to_message={})
+ assert len(results) == 1
+ result = results[0]
assert result is not None
assert result.role == 'user'
assert len(result.content) == 1
@@ -61,12 +67,15 @@ def test_ipython_run_cell_observation_message(agent: CodeActAgent):
def test_agent_delegate_observation_message(agent: CodeActAgent):
+ agent.config.function_calling = False
obs = AgentDelegateObservation(
content='Content', outputs={'content': 'Delegated agent output'}
)
- result = agent.get_observation_message(obs)
+ results = agent.get_observation_message(obs, tool_call_id_to_message={})
+ assert len(results) == 1
+ result = results[0]
assert result is not None
assert result.role == 'user'
assert len(result.content) == 1
@@ -76,10 +85,13 @@ def test_agent_delegate_observation_message(agent: CodeActAgent):
def test_error_observation_message(agent: CodeActAgent):
+ agent.config.function_calling = False
obs = ErrorObservation('Error message')
- result = agent.get_observation_message(obs)
+ results = agent.get_observation_message(obs, tool_call_id_to_message={})
+ assert len(results) == 1
+ result = results[0]
assert result is not None
assert result.role == 'user'
assert len(result.content) == 1
@@ -93,4 +105,4 @@ def test_unknown_observation_message(agent: CodeActAgent):
obs = Mock()
with pytest.raises(ValueError, match='Unknown observation type:'):
- agent.get_observation_message(obs)
+ agent.get_observation_message(obs, tool_call_id_to_message={})
diff --git a/tests/unit/test_message_serialization.py b/tests/unit/test_message_serialization.py
index 16035f0aed..aac1f8d014 100644
--- a/tests/unit/test_message_serialization.py
+++ b/tests/unit/test_message_serialization.py
@@ -78,7 +78,10 @@ def test_message_with_only_text_content_and_vision_disabled():
expected_serialized_message = {
'role': 'user',
- 'content': 'This is a text message\nThis is another text message',
+ 'content': [
+ {'type': 'text', 'text': 'This is a text message'},
+ {'type': 'text', 'text': 'This is another text message'},
+ ],
}
assert serialized_message == expected_serialized_message
@@ -107,7 +110,10 @@ def test_message_with_mixed_content_and_vision_disabled():
# Expected serialization ignores images and concatenates text
expected_serialized_message = {
'role': 'user',
- 'content': 'This is a text message\nThis is another text message',
+ 'content': [
+ {'type': 'text', 'text': 'This is a text message'},
+ {'type': 'text', 'text': 'This is another text message'},
+ ],
}
# Assert serialized message matches expectation
diff --git a/tests/unit/test_prompt_caching.py b/tests/unit/test_prompt_caching.py
index 41dc756187..d038f18d9c 100644
--- a/tests/unit/test_prompt_caching.py
+++ b/tests/unit/test_prompt_caching.py
@@ -25,12 +25,38 @@ def mock_event_stream(tmp_path):
return EventStream('test_session', file_store)
-@pytest.fixture
-def codeact_agent(mock_llm):
+@pytest.fixture(params=[False, True])
+def codeact_agent(mock_llm, request):
config = AgentConfig()
+ config.function_calling = request.param
return CodeActAgent(mock_llm, config)
+def response_mock(content: str):
+ class MockModelResponse:
+ def __init__(self, content):
+ self.choices = [
+ {
+ 'message': {
+ 'content': content,
+ 'tool_calls': [
+ {
+ 'function': {
+ 'name': 'execute_bash',
+ 'arguments': '{}',
+ }
+ }
+ ],
+ }
+ }
+ ]
+
+ def model_dump(self):
+ return {'choices': self.choices}
+
+ return MockModelResponse(content)
+
+
def test_get_messages_with_reminder(codeact_agent, mock_event_stream):
# Add some events to the stream
mock_event_stream.add_event(MessageAction('Initial user message'), EventSource.USER)
@@ -49,9 +75,13 @@ def test_get_messages_with_reminder(codeact_agent, mock_event_stream):
) # System, initial user + user message, agent message, last user message
assert messages[0].content[0].cache_prompt # system message
assert messages[1].role == 'user'
- assert messages[1].content[0].text.endswith("LET'S START!")
- assert messages[1].content[1].text.endswith('Initial user message')
- assert messages[1].content[0].cache_prompt
+ if not codeact_agent.config.function_calling:
+ assert messages[1].content[0].text.endswith("LET'S START!")
+ assert messages[1].content[1].text.endswith('Initial user message')
+ else:
+ assert messages[1].content[0].text.endswith('Initial user message')
+ # we add cache breakpoint to the last 3 user messages
+ assert messages[1].content[-1].cache_prompt
assert messages[3].role == 'user'
assert messages[3].content[0].text == ('Hello, agent!')
@@ -62,13 +92,14 @@ def test_get_messages_with_reminder(codeact_agent, mock_event_stream):
assert messages[5].role == 'user'
assert messages[5].content[0].text.startswith('Laaaaaaaast!')
assert messages[5].content[0].cache_prompt
- assert (
- messages[5]
- .content[1]
- .text.endswith(
- 'ENVIRONMENT REMINDER: You have 5 turns left to complete the task. When finished reply with .'
+ if not codeact_agent.config.function_calling:
+ assert (
+ messages[5]
+ .content[1]
+ .text.endswith(
+ 'ENVIRONMENT REMINDER: You have 5 turns left to complete the task. When finished reply with .'
+ )
)
- )
def test_get_messages_prompt_caching(codeact_agent, mock_event_stream):
@@ -97,9 +128,10 @@ def test_get_messages_prompt_caching(codeact_agent, mock_event_stream):
) # Including the initial system+user + 2 last user message
# Verify that these are indeed the last two user messages (from start)
- assert (
- cached_user_messages[0].content[0].text.startswith('A chat between')
- ) # system message
+ if not codeact_agent.config.function_calling:
+ assert (
+ cached_user_messages[0].content[0].text.startswith('A chat between')
+ ) # system message
assert cached_user_messages[2].content[0].text.startswith('User message 1')
assert cached_user_messages[3].content[0].text.startswith('User message 1')
@@ -144,14 +176,15 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
# Assert the presence of key elements in the messages
assert (
messages[1]
- .content[1]
+ .content[-1]
.text.startswith("Let's list the contents of the current directory.")
) # user, included in the initial message
- assert any(
- 'List files in current directory\n\nls -l\n'
- in msg.content[0].text
- for msg in messages
- ) # agent
+ if not codeact_agent.config.function_calling:
+ assert any(
+ 'List files in current directory\n\nls -l\n'
+ in msg.content[0].text
+ for msg in messages
+ ) # agent
assert any(
'total 0\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file1.txt\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file2.txt'
in msg.content[0].text
@@ -160,7 +193,8 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
assert any(
"Now, let's create a new directory." in msg.content[0].text for msg in messages
) # agent
- assert messages[4].content[1].text.startswith('Create a new directory') # agent
+ if not codeact_agent.config.function_calling:
+ assert messages[4].content[1].text.startswith('Create a new directory') # agent
assert any(
'finished with exit code 0' in msg.content[0].text for msg in messages
) # user, observation
@@ -171,16 +205,20 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
# prompt cache is added to the system message
assert messages[0].content[0].cache_prompt
# and the first initial user message
- assert messages[1].content[0].cache_prompt
+ assert messages[1].content[-1].cache_prompt
# and to the last two user messages
assert messages[3].content[0].cache_prompt
assert messages[5].content[0].cache_prompt
# reminder is added to the last user message
- assert 'ENVIRONMENT REMINDER: You have 5 turns' in messages[5].content[1].text
+ if not codeact_agent.config.function_calling:
+ assert 'ENVIRONMENT REMINDER: You have 5 turns' in messages[5].content[1].text
def test_prompt_caching_headers(codeact_agent, mock_event_stream):
+ if codeact_agent.config.function_calling:
+ pytest.skip('Skipping this test for function calling')
+
# Setup
mock_event_stream.add_event(MessageAction('Hello, agent!'), EventSource.USER)
mock_event_stream.add_event(MessageAction('Hello, user!'), EventSource.AGENT)
diff --git a/tests/unit/test_security.py b/tests/unit/test_security.py
index 4f0aa4dc39..771ccc206d 100644
--- a/tests/unit/test_security.py
+++ b/tests/unit/test_security.py
@@ -221,8 +221,9 @@ def test_unsafe_bash_command(temp_dir: str):
name=ActionType.RUN_IPYTHON,
arguments={
'code': "print('hello')",
- 'kernel_init_code': '',
+ 'include_extra': True,
'confirmation_state': ActionConfirmationStatus.CONFIRMED,
+ 'kernel_init_code': '',
},
),
),