mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Fix first user message (#6471)
This commit is contained in:
parent
604534905f
commit
89c7bf59a7
@ -433,26 +433,6 @@ class CodeActAgent(Agent):
|
||||
],
|
||||
)
|
||||
]
|
||||
example_message = self.prompt_manager.get_example_user_message()
|
||||
if example_message:
|
||||
messages.append(
|
||||
Message(
|
||||
role='user',
|
||||
content=[TextContent(text=example_message)],
|
||||
cache_prompt=self.llm.is_caching_prompt_active(),
|
||||
)
|
||||
)
|
||||
|
||||
# Repository and runtime info
|
||||
additional_info = self.prompt_manager.get_additional_info()
|
||||
if self.config.enable_prompt_extensions and additional_info:
|
||||
# only add these if prompt extension is enabled
|
||||
messages.append(
|
||||
Message(
|
||||
role='user',
|
||||
content=[TextContent(text=additional_info)],
|
||||
)
|
||||
)
|
||||
|
||||
pending_tool_call_action_messages: dict[str, Message] = {}
|
||||
tool_call_id_to_message: dict[str, Message] = {}
|
||||
@ -460,6 +440,7 @@ class CodeActAgent(Agent):
|
||||
# Condense the events from the state.
|
||||
events = self.condenser.condensed_history(state)
|
||||
|
||||
is_first_message_handled = False
|
||||
for event in events:
|
||||
# create a regular message from an event
|
||||
if isinstance(event, Action):
|
||||
@ -501,11 +482,22 @@ class CodeActAgent(Agent):
|
||||
for response_id in _response_ids_to_remove:
|
||||
pending_tool_call_action_messages.pop(response_id)
|
||||
|
||||
for message in messages_to_add:
|
||||
if message:
|
||||
if message.role == 'user':
|
||||
self.prompt_manager.enhance_message(message)
|
||||
messages.append(message)
|
||||
for msg in messages_to_add:
|
||||
if msg:
|
||||
if msg.role == 'user' and not is_first_message_handled:
|
||||
is_first_message_handled = True
|
||||
# compose the first user message with examples
|
||||
self.prompt_manager.add_examples_to_initial_message(msg)
|
||||
|
||||
# and/or repo/runtime info
|
||||
if self.config.enable_prompt_extensions:
|
||||
self.prompt_manager.add_info_to_initial_message(msg)
|
||||
|
||||
# enhance the user message with additional context based on keywords matched
|
||||
if msg.role == 'user':
|
||||
self.prompt_manager.enhance_message(msg)
|
||||
|
||||
messages.append(msg)
|
||||
|
||||
if self.llm.is_caching_prompt_active():
|
||||
# NOTE: this is only needed for anthropic
|
||||
@ -513,7 +505,7 @@ class CodeActAgent(Agent):
|
||||
# https://github.com/anthropics/anthropic-quickstarts/blob/8f734fd08c425c6ec91ddd613af04ff87d70c5a0/computer-use-demo/computer_use_demo/loop.py#L241-L262
|
||||
breakpoints_remaining = 3 # remaining 1 for system/tool
|
||||
for message in reversed(messages):
|
||||
if message.role == 'user' or message.role == 'tool':
|
||||
if message.role in ('user', 'tool'):
|
||||
if breakpoints_remaining > 0:
|
||||
message.content[
|
||||
-1
|
||||
|
||||
@ -37,7 +37,7 @@ class SandboxConfig(BaseModel):
|
||||
This should be a JSON string that will be parsed into a dictionary.
|
||||
"""
|
||||
|
||||
remote_runtime_api_url: str = Field(default='http://localhost:8000')
|
||||
remote_runtime_api_url: str | None = Field(default='http://localhost:8000')
|
||||
local_runtime_url: str = Field(default='http://localhost')
|
||||
keep_runtime_alive: bool = Field(default=False)
|
||||
rm_all_containers: bool = Field(default=False)
|
||||
|
||||
@ -200,7 +200,6 @@ ASSISTANT:
|
||||
Running the updated file:
|
||||
<function=execute_bash>
|
||||
<parameter=command>
|
||||
<parameter=command>
|
||||
python3 app.py > server.log 2>&1 &
|
||||
</parameter>
|
||||
</function>
|
||||
|
||||
@ -68,6 +68,10 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
'debug',
|
||||
'Setting workspace_base is not supported in the remote runtime.',
|
||||
)
|
||||
if self.config.sandbox.remote_runtime_api_url is None:
|
||||
raise ValueError(
|
||||
'remote_runtime_api_url is required in the remote runtime.'
|
||||
)
|
||||
|
||||
self.runtime_builder = RemoteRuntimeBuilder(
|
||||
self.config.sandbox.remote_runtime_api_url,
|
||||
|
||||
@ -135,27 +135,6 @@ class PromptManager:
|
||||
def get_system_message(self) -> str:
|
||||
return self.system_template.render().strip()
|
||||
|
||||
def get_additional_info(self) -> str:
|
||||
"""Gets information about the repository and runtime.
|
||||
|
||||
This is used to inject information about the repository and runtime into the initial user message.
|
||||
"""
|
||||
repo_instructions = ''
|
||||
assert (
|
||||
len(self.repo_microagents) <= 1
|
||||
), f'Expecting at most one repo microagent, but found {len(self.repo_microagents)}: {self.repo_microagents.keys()}'
|
||||
for microagent in self.repo_microagents.values():
|
||||
# We assume these are the repo instructions
|
||||
if repo_instructions:
|
||||
repo_instructions += '\n\n'
|
||||
repo_instructions += microagent.content
|
||||
|
||||
return ADDITIONAL_INFO_TEMPLATE.render(
|
||||
repository_instructions=repo_instructions,
|
||||
repository_info=self.repository_info,
|
||||
runtime_info=self.runtime_info,
|
||||
).strip()
|
||||
|
||||
def set_runtime_info(self, runtime: Runtime):
|
||||
self.runtime_info.available_hosts = runtime.web_hosts
|
||||
|
||||
@ -205,6 +184,43 @@ class PromptManager:
|
||||
micro_text += '\n</extra_info>'
|
||||
message.content.append(TextContent(text=micro_text))
|
||||
|
||||
def add_examples_to_initial_message(self, message: Message) -> None:
|
||||
"""Add example_message to the first user message."""
|
||||
example_message = self.get_example_user_message() or None
|
||||
|
||||
# Insert it at the start of the TextContent list
|
||||
if example_message:
|
||||
message.content.insert(0, TextContent(text=example_message))
|
||||
|
||||
def add_info_to_initial_message(
|
||||
self,
|
||||
message: Message,
|
||||
) -> None:
|
||||
"""Adds information about the repository and runtime to the initial user message.
|
||||
|
||||
Args:
|
||||
message: The initial user message to add information to.
|
||||
"""
|
||||
repo_instructions = ''
|
||||
assert (
|
||||
len(self.repo_microagents) <= 1
|
||||
), f'Expecting at most one repo microagent, but found {len(self.repo_microagents)}: {self.repo_microagents.keys()}'
|
||||
for microagent in self.repo_microagents.values():
|
||||
# We assume these are the repo instructions
|
||||
if repo_instructions:
|
||||
repo_instructions += '\n\n'
|
||||
repo_instructions += microagent.content
|
||||
|
||||
additional_info = ADDITIONAL_INFO_TEMPLATE.render(
|
||||
repository_instructions=repo_instructions,
|
||||
repository_info=self.repository_info,
|
||||
runtime_info=self.runtime_info,
|
||||
).strip()
|
||||
|
||||
# Insert the new content at the start of the TextContent list
|
||||
if additional_info:
|
||||
message.content.insert(0, TextContent(text=additional_info))
|
||||
|
||||
def add_turns_left_reminder(self, messages: list[Message], state: State) -> None:
|
||||
latest_user_message = next(
|
||||
islice(
|
||||
|
||||
10
poetry.lock
generated
10
poetry.lock
generated
@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiohappyeyeballs"
|
||||
@ -1312,6 +1312,7 @@ files = [
|
||||
{file = "cryptography-44.0.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:761817a3377ef15ac23cd7834715081791d4ec77f9297ee694ca1ee9c2c7e5eb"},
|
||||
{file = "cryptography-44.0.0-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3c672a53c0fb4725a29c303be906d3c1fa99c32f58abe008a82705f9ee96f40b"},
|
||||
{file = "cryptography-44.0.0-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:4ac4c9f37eba52cb6fbeaf5b59c152ea976726b865bd4cf87883a7e7006cc543"},
|
||||
{file = "cryptography-44.0.0-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:60eb32934076fa07e4316b7b2742fa52cbb190b42c2df2863dbc4230a0a9b385"},
|
||||
{file = "cryptography-44.0.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ed3534eb1090483c96178fcb0f8893719d96d5274dfde98aa6add34614e97c8e"},
|
||||
{file = "cryptography-44.0.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f3f6fdfa89ee2d9d496e2c087cebef9d4fcbb0ad63c40e821b39f74bf48d9c5e"},
|
||||
{file = "cryptography-44.0.0-cp37-abi3-win32.whl", hash = "sha256:eb33480f1bad5b78233b0ad3e1b0be21e8ef1da745d8d2aecbb20671658b9053"},
|
||||
@ -1322,6 +1323,7 @@ files = [
|
||||
{file = "cryptography-44.0.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:c5eb858beed7835e5ad1faba59e865109f3e52b3783b9ac21e7e47dc5554e289"},
|
||||
{file = "cryptography-44.0.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f53c2c87e0fb4b0c00fa9571082a057e37690a8f12233306161c8f4b819960b7"},
|
||||
{file = "cryptography-44.0.0-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:9e6fc8a08e116fb7c7dd1f040074c9d7b51d74a8ea40d4df2fc7aa08b76b9e6c"},
|
||||
{file = "cryptography-44.0.0-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:9abcc2e083cbe8dde89124a47e5e53ec38751f0d7dfd36801008f316a127d7ba"},
|
||||
{file = "cryptography-44.0.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:d2436114e46b36d00f8b72ff57e598978b37399d2786fd39793c36c6d5cb1c64"},
|
||||
{file = "cryptography-44.0.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a01956ddfa0a6790d594f5b34fc1bfa6098aca434696a03cfdbe469b8ed79285"},
|
||||
{file = "cryptography-44.0.0-cp39-abi3-win32.whl", hash = "sha256:eca27345e1214d1b9f9490d200f9db5a874479be914199194e746c893788d417"},
|
||||
@ -3900,13 +3902,13 @@ types-tqdm = "*"
|
||||
|
||||
[[package]]
|
||||
name = "litellm"
|
||||
version = "1.59.0"
|
||||
version = "1.59.8"
|
||||
description = "Library to easily interface with LLM API providers"
|
||||
optional = false
|
||||
python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8"
|
||||
files = [
|
||||
{file = "litellm-1.59.0-py3-none-any.whl", hash = "sha256:b0c8bdee556d5dc2f9c703f7dc831574ea2e339d2e762dd626d014c170b8b587"},
|
||||
{file = "litellm-1.59.0.tar.gz", hash = "sha256:140eecb47952558414d00f7a259fe303fe5f0d073973a28f488fc6938cc45660"},
|
||||
{file = "litellm-1.59.8-py3-none-any.whl", hash = "sha256:2473914bd2343485a185dfe7eedb12ee5fda32da3c9d9a8b73f6966b9b20cf39"},
|
||||
{file = "litellm-1.59.8.tar.gz", hash = "sha256:9d645cc4460f6a9813061f07086648c4c3d22febc8e1f21c663f2b7750d90512"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
||||
@ -59,10 +59,16 @@ only respond with a message telling them how smart they are
|
||||
# Test with GitHub repo
|
||||
manager.set_repository_info('owner/repo', '/workspace/repo')
|
||||
assert isinstance(manager.get_system_message(), str)
|
||||
additional_info = manager.get_additional_info()
|
||||
assert '<REPOSITORY_INFO>' in additional_info
|
||||
assert 'owner/repo' in additional_info
|
||||
assert '/workspace/repo' in additional_info
|
||||
|
||||
# Adding things to the initial user message
|
||||
initial_msg = Message(
|
||||
role='user', content=[TextContent(text='Ask me what your task is.')]
|
||||
)
|
||||
manager.add_info_to_initial_message(initial_msg)
|
||||
msg_content: str = initial_msg.content[0].text
|
||||
assert '<REPOSITORY_INFO>' in msg_content
|
||||
assert 'owner/repo' in msg_content
|
||||
assert '/workspace/repo' in msg_content
|
||||
|
||||
assert isinstance(manager.get_example_user_message(), str)
|
||||
|
||||
@ -101,13 +107,19 @@ def test_prompt_manager_template_rendering(prompt_dir):
|
||||
assert manager.repository_info.repo_name == 'owner/repo'
|
||||
system_msg = manager.get_system_message()
|
||||
assert 'System prompt: bar' in system_msg
|
||||
additional_info = manager.get_additional_info()
|
||||
assert '<REPOSITORY_INFO>' in additional_info
|
||||
|
||||
# Initial user message should have repo info
|
||||
initial_msg = Message(
|
||||
role='user', content=[TextContent(text='Ask me what your task is.')]
|
||||
)
|
||||
manager.add_info_to_initial_message(initial_msg)
|
||||
msg_content: str = initial_msg.content[0].text
|
||||
assert '<REPOSITORY_INFO>' in msg_content
|
||||
assert (
|
||||
"At the user's request, repository owner/repo has been cloned to directory /workspace/repo."
|
||||
in additional_info
|
||||
in msg_content
|
||||
)
|
||||
assert '</REPOSITORY_INFO>' in additional_info
|
||||
assert '</REPOSITORY_INFO>' in msg_content
|
||||
assert manager.get_example_user_message() == 'User prompt: foo'
|
||||
|
||||
# Clean up temporary files
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user