[Feat]: add context msg to new conversation endpoint (#8586)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
This commit is contained in:
Rohit Malhotra 2025-05-20 16:47:15 -04:00 committed by GitHub
parent 6f5bb4341f
commit 0deabd5935
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 202 additions and 14 deletions

View File

@ -876,6 +876,11 @@
"type": "string",
"nullable": true
},
"conversation_instructions": {
"type": "string",
"nullable": true,
"description": "Optional instructions the agent must follow throughout the conversation while addressing the user's initial task"
},
"image_urls": {
"type": "array",
"items": {

View File

@ -36,4 +36,9 @@ You are have access to the following environment variables
Today's date is {{ runtime_info.date }} (UTC).
{% endif %}
</RUNTIME_INFORMATION>
{% if conversation_instructions and conversation_instructions.content -%}
<CONVERSATION_INSTRUCTIONS>
{{ conversation_instructions.content }}
</CONVERSATION_INSTRUCTIONS>
{% endif %}
{% endif %}

View File

@ -105,6 +105,7 @@ async def run_session(
settings_store: FileSettingsStore,
current_dir: str,
task_content: str | None = None,
conversation_instructions: str | None = None,
session_name: str | None = None,
) -> bool:
reload_microagents = False
@ -248,6 +249,7 @@ async def run_session(
sid=sid,
selected_repository=config.sandbox.selected_repo,
repo_directory=repo_directory,
conversation_instructions=conversation_instructions,
)
# Add MCP tools to the agent

View File

@ -55,6 +55,7 @@ async def run_controller(
fake_user_response_fn: FakeUserResponseFunc | None = None,
headless_mode: bool = True,
memory: Memory | None = None,
conversation_instructions: str | None = None,
) -> State | None:
"""Main coroutine to run the agent controller with task input flexibility.
@ -126,6 +127,7 @@ async def run_controller(
sid=sid,
selected_repository=config.sandbox.selected_repo,
repo_directory=repo_directory,
conversation_instructions=conversation_instructions,
)
# Add MCP tools to the agent

View File

@ -135,6 +135,7 @@ def create_memory(
selected_repository: str | None = None,
repo_directory: str | None = None,
status_callback: Callable | None = None,
conversation_instructions: str | None = None,
) -> Memory:
"""Create a memory for the agent to use.
@ -145,6 +146,7 @@ def create_memory(
selected_repository: The repository to clone and start with, if any.
repo_directory: The repository directory, if any.
status_callback: Optional callback function to handle status updates.
conversation_instructions: Optional instructions that are passed to the agent
"""
memory = Memory(
event_stream=event_stream,
@ -152,6 +154,8 @@ def create_memory(
status_callback=status_callback,
)
memory.set_conversation_instructions(conversation_instructions)
if runtime:
# sets available hosts
memory.set_runtime_info(runtime, {})

View File

@ -75,6 +75,7 @@ class RecallObservation(Observation):
additional_agent_instructions: str = ''
date: str = ''
custom_secrets_descriptions: dict[str, str] = field(default_factory=dict)
conversation_instructions: str = ''
# knowledge
microagent_knowledge: list[MicroagentKnowledge] = field(default_factory=list)
@ -117,6 +118,7 @@ class RecallObservation(Observation):
f'additional_agent_instructions={self.additional_agent_instructions[:20]}...',
f'date={self.date}'
f'custom_secrets_descriptions={self.custom_secrets_descriptions}',
f'conversation_instructions={self.conversation_instructions[0:20]}...'
]
)
else:

View File

@ -41,7 +41,12 @@ from openhands.events.observation.error import ErrorObservation
from openhands.events.observation.mcp import MCPObservation
from openhands.events.observation.observation import Observation
from openhands.events.serialization.event import truncate_content
from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo
from openhands.utils.prompt import (
ConversationInstructions,
PromptManager,
RepositoryInfo,
RuntimeInfo,
)
class ConversationMemory:
@ -467,6 +472,13 @@ class ConversationMemory:
custom_secrets_descriptions=obs.custom_secrets_descriptions,
)
conversation_instructions = None
if obs.conversation_instructions:
conversation_instructions = ConversationInstructions(
content=obs.conversation_instructions
)
repo_instructions = (
obs.repo_instructions if obs.repo_instructions else ''
)
@ -476,10 +488,10 @@ class ConversationMemory:
repo_info.repo_name or repo_info.repo_directory
)
has_runtime_info = runtime_info is not None and (
runtime_info.available_hosts
or runtime_info.additional_agent_instructions
runtime_info.date or runtime_info.custom_secrets_descriptions
)
has_repo_instructions = bool(repo_instructions.strip())
has_conversation_instructions = conversation_instructions is not None
# Filter and process microagent knowledge
filtered_agents = []
@ -497,11 +509,17 @@ class ConversationMemory:
message_content = []
# Build the workspace context information
if has_repo_info or has_runtime_info or has_repo_instructions:
if (
has_repo_info
or has_runtime_info
or has_repo_instructions
or has_conversation_instructions
):
formatted_workspace_text = (
self.prompt_manager.build_workspace_context(
repository_info=repo_info,
runtime_info=runtime_info,
conversation_instructions=conversation_instructions,
repo_instructions=repo_instructions,
)
)

View File

@ -22,7 +22,11 @@ from openhands.microagent import (
load_microagents_from_dir,
)
from openhands.runtime.base import Runtime
from openhands.utils.prompt import RepositoryInfo, RuntimeInfo
from openhands.utils.prompt import (
ConversationInstructions,
RepositoryInfo,
RuntimeInfo,
)
GLOBAL_MICROAGENTS_DIR = os.path.join(
os.path.dirname(os.path.dirname(openhands.__file__)),
@ -65,6 +69,7 @@ class Memory:
# Store repository / runtime info to send them to the templating later
self.repository_info: RepositoryInfo | None = None
self.runtime_info: RuntimeInfo | None = None
self.conversation_instructions: ConversationInstructions | None = None
# Load global microagents (Knowledge + Repo)
# from typically OpenHands/microagents (i.e., the PUBLIC microagents)
@ -156,6 +161,7 @@ class Memory:
or self.runtime_info
or repo_instructions
or microagent_knowledge
or self.conversation_instructions
):
obs = RecallObservation(
recall_type=RecallType.WORKSPACE_CONTEXT,
@ -180,6 +186,9 @@ class Memory:
custom_secrets_descriptions=self.runtime_info.custom_secrets_descriptions
if self.runtime_info is not None
else {},
conversation_instructions=self.conversation_instructions.content
if self.conversation_instructions is not None
else '',
)
return obs
return None
@ -290,7 +299,9 @@ class Memory:
self.repository_info = None
def set_runtime_info(
self, runtime: Runtime, custom_secrets_descriptions: dict[str, str]
self,
runtime: Runtime,
custom_secrets_descriptions: dict[str, str],
) -> None:
"""Store runtime info (web hosts, ports, etc.)."""
# e.g. { '127.0.0.1': 8080 }
@ -306,9 +317,21 @@ class Memory:
)
else:
self.runtime_info = RuntimeInfo(
date=date, custom_secrets_descriptions=custom_secrets_descriptions
date=date,
custom_secrets_descriptions=custom_secrets_descriptions,
)
def set_conversation_instructions(
self, conversation_instructions: str | None
) -> None:
"""
Set contextual information for conversation
This is information the agent may require
"""
self.conversation_instructions = ConversationInstructions(
content=conversation_instructions or ''
)
def send_error_message(self, message_id: str, message: str):
"""Sends an error message if the callback function was provided."""
if self.status_callback:

View File

@ -61,6 +61,7 @@ class InitSessionRequest(BaseModel):
image_urls: list[str] | None = None
replay_json: str | None = None
suggested_task: SuggestedTask | None = None
conversation_instructions: str | None = None
model_config = {'extra': 'forbid'}
@ -82,6 +83,7 @@ async def _create_new_conversation(
initial_user_msg: str | None,
image_urls: list[str] | None,
replay_json: str | None,
conversation_instructions: str | None = None,
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
attach_convo_id: bool = False,
) -> AgentLoopInfo:
@ -120,6 +122,7 @@ async def _create_new_conversation(
session_init_args['selected_repository'] = selected_repository
session_init_args['custom_secrets'] = custom_secrets
session_init_args['selected_branch'] = selected_branch
session_init_args['conversation_instructions'] = conversation_instructions
conversation_init_data = ConversationInitData(**session_init_args)
logger.info('Loading conversation store')
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
@ -195,6 +198,7 @@ async def new_conversation(
replay_json = data.replay_json
suggested_task = data.suggested_task
git_provider = data.git_provider
conversation_instructions = data.conversation_instructions
conversation_trigger = ConversationTrigger.GUI
@ -222,6 +226,7 @@ async def new_conversation(
image_urls=image_urls,
replay_json=replay_json,
conversation_trigger=conversation_trigger,
conversation_instructions=conversation_instructions
)
return InitSessionResponse(

View File

@ -90,6 +90,7 @@ class AgentSession:
selected_repository: str | None = None,
selected_branch: str | None = None,
initial_message: MessageAction | None = None,
conversation_instructions: str | None = None,
replay_json: str | None = None,
) -> None:
"""Starts the Agent session
@ -144,6 +145,7 @@ class AgentSession:
self.memory = await self._create_memory(
selected_repository=selected_repository,
repo_directory=repo_directory,
conversation_instructions=conversation_instructions,
custom_secrets_descriptions=custom_secrets_handler.get_custom_secrets_descriptions()
)
@ -415,7 +417,11 @@ class AgentSession:
return controller
async def _create_memory(
self, selected_repository: str | None, repo_directory: str | None, custom_secrets_descriptions: dict[str, str]
self,
selected_repository: str | None,
repo_directory: str | None,
conversation_instructions: str | None,
custom_secrets_descriptions: dict[str, str]
) -> Memory:
memory = Memory(
event_stream=self.event_stream,
@ -426,6 +432,7 @@ class AgentSession:
if self.runtime:
# sets available hosts and other runtime info
memory.set_runtime_info(self.runtime, custom_secrets_descriptions)
memory.set_conversation_instructions(conversation_instructions)
# loads microagents from repo/.openhands/microagents
microagents: list[BaseMicroagent] = await call_sync_from_async(
@ -435,7 +442,10 @@ class AgentSession:
memory.load_user_workspace_microagents(microagents)
if selected_repository and repo_directory:
memory.set_repository_info(selected_repository, repo_directory)
memory.set_repository_info(
selected_repository,
repo_directory
)
return memory
def _maybe_restore_state(self) -> State | None:

View File

@ -14,6 +14,7 @@ class ConversationInitData(Settings):
selected_repository: str | None = Field(default=None)
replay_json: str | None = Field(default=None)
selected_branch: str | None = Field(default=None)
conversation_instructions: str | None = Field(default=None)
model_config = {
'arbitrary_types_allowed': True,

View File

@ -155,11 +155,13 @@ class Session:
selected_repository = None
selected_branch = None
custom_secrets = None
conversation_instructions = None
if isinstance(settings, ConversationInitData):
git_provider_tokens = settings.git_provider_tokens
selected_repository = settings.selected_repository
selected_branch = settings.selected_branch
custom_secrets = settings.custom_secrets
conversation_instructions = settings.conversation_instructions
try:
await self.agent_session.start(
@ -175,6 +177,7 @@ class Session:
selected_repository=selected_repository,
selected_branch=selected_branch,
initial_message=initial_message,
conversation_instructions=conversation_instructions,
replay_json=replay_json,
)
except MicroagentValidationError as e:

View File

@ -25,6 +25,20 @@ class RepositoryInfo:
repo_directory: str | None = None
@dataclass
class ConversationInstructions:
"""
Optional instructions the agent must follow throughout the conversation while addressing the user's initial task
Examples include
1. Resolver instructions: you're responding to GitHub issue #1234, make sure to open a PR when you are done
2. Slack instructions: make sure to check whether any of the context attached is relevant to the task <context_messages>
"""
content: str = ''
class PromptManager:
"""
Manages prompt templates and includes information from the user's workspace micro-agents and global micro-agents.
@ -74,6 +88,7 @@ class PromptManager:
self,
repository_info: RepositoryInfo | None,
runtime_info: RuntimeInfo | None,
conversation_instructions: ConversationInstructions | None,
repo_instructions: str = '',
) -> str:
"""Renders the additional info template with the stored repository/runtime info."""
@ -81,6 +96,7 @@ class PromptManager:
repository_info=repository_info,
repository_instructions=repo_instructions,
runtime_info=runtime_info,
conversation_instructions=conversation_instructions,
).strip()
def build_microagent_info(

View File

@ -2,11 +2,12 @@ import asyncio
import os
import shutil
import time
from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController
from openhands.core.config import AppConfig
from openhands.core.main import run_controller
from openhands.core.schema.agent import AgentState
@ -25,8 +26,14 @@ from openhands.memory.memory import Memory
from openhands.runtime.impl.action_execution.action_execution_client import (
ActionExecutionClient,
)
from openhands.server.session.agent_session import AgentSession
from openhands.storage.memory import InMemoryFileStore
from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo
from openhands.utils.prompt import (
ConversationInstructions,
PromptManager,
RepositoryInfo,
RuntimeInfo,
)
@pytest.fixture
@ -76,6 +83,11 @@ def mock_agent():
system_message._id = -1 # Set invalid ID to avoid the ID check
agent.get_system_message.return_value = system_message
agent.config = MagicMock()
agent.config.enable_mcp = False
return agent
@pytest.mark.asyncio
async def test_memory_on_event_exception_handling(memory, event_stream, mock_agent):
@ -443,11 +455,16 @@ def test_custom_secrets_descriptions_serialization(prompt_dir):
repo_name='test-owner/test-repo', repo_directory='/workspace/test-repo'
)
conversation_instructions = ConversationInstructions(
content='additional agent context for the task'
)
# Build the workspace context message
workspace_context = prompt_manager.build_workspace_context(
repository_info=repository_info,
runtime_info=runtime_info,
repo_instructions='Test repository instructions',
conversation_instructions=conversation_instructions,
)
# Verify that the workspace context includes the custom_secrets_descriptions
@ -456,6 +473,9 @@ def test_custom_secrets_descriptions_serialization(prompt_dir):
assert f'$**{secret_name}**' in workspace_context
assert secret_description in workspace_context
assert '<CONVERSATION_INSTRUCTIONS>' in workspace_context
assert 'additional agent context for the task' in workspace_context
def test_serialization_deserialization_with_custom_secrets():
"""Test that RecallObservation can be serialized and deserialized with custom_secrets_descriptions."""
@ -566,3 +586,54 @@ REPOSITORY INSTRUCTIONS: This is the second test repository.
# Clean up
os.remove(os.path.join(prompt_dir, 'micro', f'{repo_microagent1_name}.md'))
os.remove(os.path.join(prompt_dir, 'micro', f'{repo_microagent2_name}.md'))
@pytest.mark.asyncio
async def test_conversation_instructions_plumbed_to_memory(
mock_agent, event_stream, file_store
):
# Setup
session = AgentSession(
sid='test-session',
file_store=file_store,
)
# Create a mock runtime and set it up
mock_runtime = MagicMock(spec=ActionExecutionClient)
# Mock the runtime creation to set up the runtime attribute
async def mock_create_runtime(*args, **kwargs):
session.runtime = mock_runtime
return True
session._create_runtime = AsyncMock(side_effect=mock_create_runtime)
# Create a spy on set_initial_state
class SpyAgentController(AgentController):
set_initial_state_call_count = 0
test_initial_state = None
def set_initial_state(self, *args, state=None, **kwargs):
self.set_initial_state_call_count += 1
self.test_initial_state = state
super().set_initial_state(*args, state=state, **kwargs)
# Patch AgentController
with (
patch(
'openhands.server.session.agent_session.AgentController', SpyAgentController
),
):
await session.start(
runtime_name='test-runtime',
config=AppConfig(),
agent=mock_agent,
max_iterations=10,
conversation_instructions='instructions for conversation',
)
# Use the memory instance from the session, not the fixture
assert (
session.memory.conversation_instructions.content
== 'instructions for conversation'
)

View File

@ -248,6 +248,7 @@ def test_microagent_observation_serialization():
'custom_secrets_descriptions': {'SECRET': 'CUSTOM'},
'date': '04/12/1023',
'microagent_knowledge': [],
'conversation_instructions': 'additional_context',
},
}
serialization_deserialization(original_observation_dict, RecallObservation)
@ -266,6 +267,7 @@ def test_microagent_observation_microagent_knowledge_serialization():
'runtime_hosts': {},
'additional_agent_instructions': '',
'custom_secrets_descriptions': {},
'conversation_instructions': 'additional_context',
'date': '',
'microagent_knowledge': [
{

View File

@ -7,7 +7,12 @@ from openhands.controller.state.state import State
from openhands.core.message import Message, TextContent
from openhands.events.observation.agent import MicroagentKnowledge
from openhands.microagent import BaseMicroagent
from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo
from openhands.utils.prompt import (
ConversationInstructions,
PromptManager,
RepositoryInfo,
RuntimeInfo,
)
@pytest.fixture
@ -52,7 +57,10 @@ At the user's request, repository {{ repository_info.repo_name }} has been clone
# Test building additional info
additional_info = manager.build_workspace_context(
repository_info=repo_info, runtime_info=None, repo_instructions=''
repository_info=repo_info,
runtime_info=None,
repo_instructions='',
conversation_instructions=None,
)
assert '<REPOSITORY_INFO>' in additional_info
assert (
@ -205,7 +213,14 @@ each of which has a corresponding port:
{% if runtime_info.additional_agent_instructions %}
{{ runtime_info.additional_agent_instructions }}
{% endif %}
Today's date is {{ runtime_info.date }}
</RUNTIME_INFORMATION>
{% if conversation_instructions.content %}
<CONVERSATION_INSTRUCTIONS>
{{ conversation_instructions.content }}
</CONVERSATION_INSTRUCTIONS>
{% endif %}
{% endif %}
""")
@ -221,11 +236,14 @@ each of which has a corresponding port:
)
repo_instructions = 'This repository contains important code.'
conversation_instructions = ConversationInstructions(content='additional context')
# Build additional info
result = manager.build_workspace_context(
repository_info=repo_info,
runtime_info=runtime_info,
repo_instructions=repo_instructions,
conversation_instructions=conversation_instructions,
)
# Check that all information is included
@ -237,7 +255,8 @@ each of which has a corresponding port:
assert '<RUNTIME_INFORMATION>' in result
assert 'example.com (port 8080)' in result
assert 'You know everything about this runtime.' in result
assert "Today's date is 02/12/1232 (UTC)."
assert "Today's date is 02/12/1232" in result
assert 'additional context' in result
# Clean up
os.remove(os.path.join(prompt_dir, 'additional_info.j2'))