diff --git a/docs/static/openapi.json b/docs/static/openapi.json index a3e6c0ccda..954c289578 100644 --- a/docs/static/openapi.json +++ b/docs/static/openapi.json @@ -876,6 +876,11 @@ "type": "string", "nullable": true }, + "conversation_instructions": { + "type": "string", + "nullable": true, + "description": "Optional instructions the agent must follow throughout the conversation while addressing the user's initial task" + }, "image_urls": { "type": "array", "items": { diff --git a/openhands/agenthub/codeact_agent/prompts/additional_info.j2 b/openhands/agenthub/codeact_agent/prompts/additional_info.j2 index 46b0290888..3d08e0e657 100644 --- a/openhands/agenthub/codeact_agent/prompts/additional_info.j2 +++ b/openhands/agenthub/codeact_agent/prompts/additional_info.j2 @@ -36,4 +36,9 @@ You are have access to the following environment variables Today's date is {{ runtime_info.date }} (UTC). {% endif %} +{% if conversation_instructions and conversation_instructions.content -%} + +{{ conversation_instructions.content }} + +{% endif %} {% endif %} diff --git a/openhands/cli/main.py b/openhands/cli/main.py index 8ba7ef532f..7d160336da 100644 --- a/openhands/cli/main.py +++ b/openhands/cli/main.py @@ -105,6 +105,7 @@ async def run_session( settings_store: FileSettingsStore, current_dir: str, task_content: str | None = None, + conversation_instructions: str | None = None, session_name: str | None = None, ) -> bool: reload_microagents = False @@ -248,6 +249,7 @@ async def run_session( sid=sid, selected_repository=config.sandbox.selected_repo, repo_directory=repo_directory, + conversation_instructions=conversation_instructions, ) # Add MCP tools to the agent diff --git a/openhands/core/main.py b/openhands/core/main.py index e9cbc45b26..0484ce14fc 100644 --- a/openhands/core/main.py +++ b/openhands/core/main.py @@ -55,6 +55,7 @@ async def run_controller( fake_user_response_fn: FakeUserResponseFunc | None = None, headless_mode: bool = True, memory: Memory | None = None, + conversation_instructions: str | None = None, ) -> State | None: """Main coroutine to run the agent controller with task input flexibility. @@ -126,6 +127,7 @@ async def run_controller( sid=sid, selected_repository=config.sandbox.selected_repo, repo_directory=repo_directory, + conversation_instructions=conversation_instructions, ) # Add MCP tools to the agent diff --git a/openhands/core/setup.py b/openhands/core/setup.py index 08fb45bc83..be54a9b093 100644 --- a/openhands/core/setup.py +++ b/openhands/core/setup.py @@ -135,6 +135,7 @@ def create_memory( selected_repository: str | None = None, repo_directory: str | None = None, status_callback: Callable | None = None, + conversation_instructions: str | None = None, ) -> Memory: """Create a memory for the agent to use. @@ -145,6 +146,7 @@ def create_memory( selected_repository: The repository to clone and start with, if any. repo_directory: The repository directory, if any. status_callback: Optional callback function to handle status updates. + conversation_instructions: Optional instructions that are passed to the agent """ memory = Memory( event_stream=event_stream, @@ -152,6 +154,8 @@ def create_memory( status_callback=status_callback, ) + memory.set_conversation_instructions(conversation_instructions) + if runtime: # sets available hosts memory.set_runtime_info(runtime, {}) diff --git a/openhands/events/observation/agent.py b/openhands/events/observation/agent.py index c057efa8df..f97ea2bf43 100644 --- a/openhands/events/observation/agent.py +++ b/openhands/events/observation/agent.py @@ -75,6 +75,7 @@ class RecallObservation(Observation): additional_agent_instructions: str = '' date: str = '' custom_secrets_descriptions: dict[str, str] = field(default_factory=dict) + conversation_instructions: str = '' # knowledge microagent_knowledge: list[MicroagentKnowledge] = field(default_factory=list) @@ -117,6 +118,7 @@ class RecallObservation(Observation): f'additional_agent_instructions={self.additional_agent_instructions[:20]}...', f'date={self.date}' f'custom_secrets_descriptions={self.custom_secrets_descriptions}', + f'conversation_instructions={self.conversation_instructions[0:20]}...' ] ) else: diff --git a/openhands/memory/conversation_memory.py b/openhands/memory/conversation_memory.py index 17e39fb759..c0de1877b1 100644 --- a/openhands/memory/conversation_memory.py +++ b/openhands/memory/conversation_memory.py @@ -41,7 +41,12 @@ from openhands.events.observation.error import ErrorObservation from openhands.events.observation.mcp import MCPObservation from openhands.events.observation.observation import Observation from openhands.events.serialization.event import truncate_content -from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo +from openhands.utils.prompt import ( + ConversationInstructions, + PromptManager, + RepositoryInfo, + RuntimeInfo, +) class ConversationMemory: @@ -467,6 +472,13 @@ class ConversationMemory: custom_secrets_descriptions=obs.custom_secrets_descriptions, ) + conversation_instructions = None + + if obs.conversation_instructions: + conversation_instructions = ConversationInstructions( + content=obs.conversation_instructions + ) + repo_instructions = ( obs.repo_instructions if obs.repo_instructions else '' ) @@ -476,10 +488,10 @@ class ConversationMemory: repo_info.repo_name or repo_info.repo_directory ) has_runtime_info = runtime_info is not None and ( - runtime_info.available_hosts - or runtime_info.additional_agent_instructions + runtime_info.date or runtime_info.custom_secrets_descriptions ) has_repo_instructions = bool(repo_instructions.strip()) + has_conversation_instructions = conversation_instructions is not None # Filter and process microagent knowledge filtered_agents = [] @@ -497,11 +509,17 @@ class ConversationMemory: message_content = [] # Build the workspace context information - if has_repo_info or has_runtime_info or has_repo_instructions: + if ( + has_repo_info + or has_runtime_info + or has_repo_instructions + or has_conversation_instructions + ): formatted_workspace_text = ( self.prompt_manager.build_workspace_context( repository_info=repo_info, runtime_info=runtime_info, + conversation_instructions=conversation_instructions, repo_instructions=repo_instructions, ) ) diff --git a/openhands/memory/memory.py b/openhands/memory/memory.py index b34e263196..488dc795d6 100644 --- a/openhands/memory/memory.py +++ b/openhands/memory/memory.py @@ -22,7 +22,11 @@ from openhands.microagent import ( load_microagents_from_dir, ) from openhands.runtime.base import Runtime -from openhands.utils.prompt import RepositoryInfo, RuntimeInfo +from openhands.utils.prompt import ( + ConversationInstructions, + RepositoryInfo, + RuntimeInfo, +) GLOBAL_MICROAGENTS_DIR = os.path.join( os.path.dirname(os.path.dirname(openhands.__file__)), @@ -65,6 +69,7 @@ class Memory: # Store repository / runtime info to send them to the templating later self.repository_info: RepositoryInfo | None = None self.runtime_info: RuntimeInfo | None = None + self.conversation_instructions: ConversationInstructions | None = None # Load global microagents (Knowledge + Repo) # from typically OpenHands/microagents (i.e., the PUBLIC microagents) @@ -156,6 +161,7 @@ class Memory: or self.runtime_info or repo_instructions or microagent_knowledge + or self.conversation_instructions ): obs = RecallObservation( recall_type=RecallType.WORKSPACE_CONTEXT, @@ -180,6 +186,9 @@ class Memory: custom_secrets_descriptions=self.runtime_info.custom_secrets_descriptions if self.runtime_info is not None else {}, + conversation_instructions=self.conversation_instructions.content + if self.conversation_instructions is not None + else '', ) return obs return None @@ -290,7 +299,9 @@ class Memory: self.repository_info = None def set_runtime_info( - self, runtime: Runtime, custom_secrets_descriptions: dict[str, str] + self, + runtime: Runtime, + custom_secrets_descriptions: dict[str, str], ) -> None: """Store runtime info (web hosts, ports, etc.).""" # e.g. { '127.0.0.1': 8080 } @@ -306,9 +317,21 @@ class Memory: ) else: self.runtime_info = RuntimeInfo( - date=date, custom_secrets_descriptions=custom_secrets_descriptions + date=date, + custom_secrets_descriptions=custom_secrets_descriptions, ) + def set_conversation_instructions( + self, conversation_instructions: str | None + ) -> None: + """ + Set contextual information for conversation + This is information the agent may require + """ + self.conversation_instructions = ConversationInstructions( + content=conversation_instructions or '' + ) + def send_error_message(self, message_id: str, message: str): """Sends an error message if the callback function was provided.""" if self.status_callback: diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index 73376f21b6..6a088b141c 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -61,6 +61,7 @@ class InitSessionRequest(BaseModel): image_urls: list[str] | None = None replay_json: str | None = None suggested_task: SuggestedTask | None = None + conversation_instructions: str | None = None model_config = {'extra': 'forbid'} @@ -82,6 +83,7 @@ async def _create_new_conversation( initial_user_msg: str | None, image_urls: list[str] | None, replay_json: str | None, + conversation_instructions: str | None = None, conversation_trigger: ConversationTrigger = ConversationTrigger.GUI, attach_convo_id: bool = False, ) -> AgentLoopInfo: @@ -120,6 +122,7 @@ async def _create_new_conversation( session_init_args['selected_repository'] = selected_repository session_init_args['custom_secrets'] = custom_secrets session_init_args['selected_branch'] = selected_branch + session_init_args['conversation_instructions'] = conversation_instructions conversation_init_data = ConversationInitData(**session_init_args) logger.info('Loading conversation store') conversation_store = await ConversationStoreImpl.get_instance(config, user_id) @@ -195,6 +198,7 @@ async def new_conversation( replay_json = data.replay_json suggested_task = data.suggested_task git_provider = data.git_provider + conversation_instructions = data.conversation_instructions conversation_trigger = ConversationTrigger.GUI @@ -222,6 +226,7 @@ async def new_conversation( image_urls=image_urls, replay_json=replay_json, conversation_trigger=conversation_trigger, + conversation_instructions=conversation_instructions ) return InitSessionResponse( diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py index e388b9c37a..5d71977cce 100644 --- a/openhands/server/session/agent_session.py +++ b/openhands/server/session/agent_session.py @@ -90,6 +90,7 @@ class AgentSession: selected_repository: str | None = None, selected_branch: str | None = None, initial_message: MessageAction | None = None, + conversation_instructions: str | None = None, replay_json: str | None = None, ) -> None: """Starts the Agent session @@ -144,6 +145,7 @@ class AgentSession: self.memory = await self._create_memory( selected_repository=selected_repository, repo_directory=repo_directory, + conversation_instructions=conversation_instructions, custom_secrets_descriptions=custom_secrets_handler.get_custom_secrets_descriptions() ) @@ -415,7 +417,11 @@ class AgentSession: return controller async def _create_memory( - self, selected_repository: str | None, repo_directory: str | None, custom_secrets_descriptions: dict[str, str] + self, + selected_repository: str | None, + repo_directory: str | None, + conversation_instructions: str | None, + custom_secrets_descriptions: dict[str, str] ) -> Memory: memory = Memory( event_stream=self.event_stream, @@ -426,6 +432,7 @@ class AgentSession: if self.runtime: # sets available hosts and other runtime info memory.set_runtime_info(self.runtime, custom_secrets_descriptions) + memory.set_conversation_instructions(conversation_instructions) # loads microagents from repo/.openhands/microagents microagents: list[BaseMicroagent] = await call_sync_from_async( @@ -435,7 +442,10 @@ class AgentSession: memory.load_user_workspace_microagents(microagents) if selected_repository and repo_directory: - memory.set_repository_info(selected_repository, repo_directory) + memory.set_repository_info( + selected_repository, + repo_directory + ) return memory def _maybe_restore_state(self) -> State | None: diff --git a/openhands/server/session/conversation_init_data.py b/openhands/server/session/conversation_init_data.py index 859c9179f3..18acbb3806 100644 --- a/openhands/server/session/conversation_init_data.py +++ b/openhands/server/session/conversation_init_data.py @@ -14,6 +14,7 @@ class ConversationInitData(Settings): selected_repository: str | None = Field(default=None) replay_json: str | None = Field(default=None) selected_branch: str | None = Field(default=None) + conversation_instructions: str | None = Field(default=None) model_config = { 'arbitrary_types_allowed': True, diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py index c5239eafe6..9a690f4e87 100644 --- a/openhands/server/session/session.py +++ b/openhands/server/session/session.py @@ -155,11 +155,13 @@ class Session: selected_repository = None selected_branch = None custom_secrets = None + conversation_instructions = None if isinstance(settings, ConversationInitData): git_provider_tokens = settings.git_provider_tokens selected_repository = settings.selected_repository selected_branch = settings.selected_branch custom_secrets = settings.custom_secrets + conversation_instructions = settings.conversation_instructions try: await self.agent_session.start( @@ -175,6 +177,7 @@ class Session: selected_repository=selected_repository, selected_branch=selected_branch, initial_message=initial_message, + conversation_instructions=conversation_instructions, replay_json=replay_json, ) except MicroagentValidationError as e: diff --git a/openhands/utils/prompt.py b/openhands/utils/prompt.py index d760419ce1..ab01db86de 100644 --- a/openhands/utils/prompt.py +++ b/openhands/utils/prompt.py @@ -25,6 +25,20 @@ class RepositoryInfo: repo_directory: str | None = None +@dataclass +class ConversationInstructions: + """ + Optional instructions the agent must follow throughout the conversation while addressing the user's initial task + + Examples include + + 1. Resolver instructions: you're responding to GitHub issue #1234, make sure to open a PR when you are done + 2. Slack instructions: make sure to check whether any of the context attached is relevant to the task + """ + + content: str = '' + + class PromptManager: """ Manages prompt templates and includes information from the user's workspace micro-agents and global micro-agents. @@ -74,6 +88,7 @@ class PromptManager: self, repository_info: RepositoryInfo | None, runtime_info: RuntimeInfo | None, + conversation_instructions: ConversationInstructions | None, repo_instructions: str = '', ) -> str: """Renders the additional info template with the stored repository/runtime info.""" @@ -81,6 +96,7 @@ class PromptManager: repository_info=repository_info, repository_instructions=repo_instructions, runtime_info=runtime_info, + conversation_instructions=conversation_instructions, ).strip() def build_microagent_info( diff --git a/tests/unit/test_memory.py b/tests/unit/test_memory.py index 3589e4c142..c4897a4475 100644 --- a/tests/unit/test_memory.py +++ b/tests/unit/test_memory.py @@ -2,11 +2,12 @@ import asyncio import os import shutil import time -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from openhands.controller.agent import Agent +from openhands.controller.agent_controller import AgentController from openhands.core.config import AppConfig from openhands.core.main import run_controller from openhands.core.schema.agent import AgentState @@ -25,8 +26,14 @@ from openhands.memory.memory import Memory from openhands.runtime.impl.action_execution.action_execution_client import ( ActionExecutionClient, ) +from openhands.server.session.agent_session import AgentSession from openhands.storage.memory import InMemoryFileStore -from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo +from openhands.utils.prompt import ( + ConversationInstructions, + PromptManager, + RepositoryInfo, + RuntimeInfo, +) @pytest.fixture @@ -76,6 +83,11 @@ def mock_agent(): system_message._id = -1 # Set invalid ID to avoid the ID check agent.get_system_message.return_value = system_message + agent.config = MagicMock() + agent.config.enable_mcp = False + + return agent + @pytest.mark.asyncio async def test_memory_on_event_exception_handling(memory, event_stream, mock_agent): @@ -443,11 +455,16 @@ def test_custom_secrets_descriptions_serialization(prompt_dir): repo_name='test-owner/test-repo', repo_directory='/workspace/test-repo' ) + conversation_instructions = ConversationInstructions( + content='additional agent context for the task' + ) + # Build the workspace context message workspace_context = prompt_manager.build_workspace_context( repository_info=repository_info, runtime_info=runtime_info, repo_instructions='Test repository instructions', + conversation_instructions=conversation_instructions, ) # Verify that the workspace context includes the custom_secrets_descriptions @@ -456,6 +473,9 @@ def test_custom_secrets_descriptions_serialization(prompt_dir): assert f'$**{secret_name}**' in workspace_context assert secret_description in workspace_context + assert '' in workspace_context + assert 'additional agent context for the task' in workspace_context + def test_serialization_deserialization_with_custom_secrets(): """Test that RecallObservation can be serialized and deserialized with custom_secrets_descriptions.""" @@ -566,3 +586,54 @@ REPOSITORY INSTRUCTIONS: This is the second test repository. # Clean up os.remove(os.path.join(prompt_dir, 'micro', f'{repo_microagent1_name}.md')) os.remove(os.path.join(prompt_dir, 'micro', f'{repo_microagent2_name}.md')) + + +@pytest.mark.asyncio +async def test_conversation_instructions_plumbed_to_memory( + mock_agent, event_stream, file_store +): + # Setup + session = AgentSession( + sid='test-session', + file_store=file_store, + ) + + # Create a mock runtime and set it up + mock_runtime = MagicMock(spec=ActionExecutionClient) + + # Mock the runtime creation to set up the runtime attribute + async def mock_create_runtime(*args, **kwargs): + session.runtime = mock_runtime + return True + + session._create_runtime = AsyncMock(side_effect=mock_create_runtime) + + # Create a spy on set_initial_state + class SpyAgentController(AgentController): + set_initial_state_call_count = 0 + test_initial_state = None + + def set_initial_state(self, *args, state=None, **kwargs): + self.set_initial_state_call_count += 1 + self.test_initial_state = state + super().set_initial_state(*args, state=state, **kwargs) + + # Patch AgentController + with ( + patch( + 'openhands.server.session.agent_session.AgentController', SpyAgentController + ), + ): + await session.start( + runtime_name='test-runtime', + config=AppConfig(), + agent=mock_agent, + max_iterations=10, + conversation_instructions='instructions for conversation', + ) + + # Use the memory instance from the session, not the fixture + assert ( + session.memory.conversation_instructions.content + == 'instructions for conversation' + ) diff --git a/tests/unit/test_observation_serialization.py b/tests/unit/test_observation_serialization.py index b729a60c80..74c91e5cc9 100644 --- a/tests/unit/test_observation_serialization.py +++ b/tests/unit/test_observation_serialization.py @@ -248,6 +248,7 @@ def test_microagent_observation_serialization(): 'custom_secrets_descriptions': {'SECRET': 'CUSTOM'}, 'date': '04/12/1023', 'microagent_knowledge': [], + 'conversation_instructions': 'additional_context', }, } serialization_deserialization(original_observation_dict, RecallObservation) @@ -266,6 +267,7 @@ def test_microagent_observation_microagent_knowledge_serialization(): 'runtime_hosts': {}, 'additional_agent_instructions': '', 'custom_secrets_descriptions': {}, + 'conversation_instructions': 'additional_context', 'date': '', 'microagent_knowledge': [ { diff --git a/tests/unit/test_prompt_manager.py b/tests/unit/test_prompt_manager.py index 7a2751c02f..17b15e1129 100644 --- a/tests/unit/test_prompt_manager.py +++ b/tests/unit/test_prompt_manager.py @@ -7,7 +7,12 @@ from openhands.controller.state.state import State from openhands.core.message import Message, TextContent from openhands.events.observation.agent import MicroagentKnowledge from openhands.microagent import BaseMicroagent -from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo +from openhands.utils.prompt import ( + ConversationInstructions, + PromptManager, + RepositoryInfo, + RuntimeInfo, +) @pytest.fixture @@ -52,7 +57,10 @@ At the user's request, repository {{ repository_info.repo_name }} has been clone # Test building additional info additional_info = manager.build_workspace_context( - repository_info=repo_info, runtime_info=None, repo_instructions='' + repository_info=repo_info, + runtime_info=None, + repo_instructions='', + conversation_instructions=None, ) assert '' in additional_info assert ( @@ -205,7 +213,14 @@ each of which has a corresponding port: {% if runtime_info.additional_agent_instructions %} {{ runtime_info.additional_agent_instructions }} {% endif %} + +Today's date is {{ runtime_info.date }} +{% if conversation_instructions.content %} + +{{ conversation_instructions.content }} + +{% endif %} {% endif %} """) @@ -221,11 +236,14 @@ each of which has a corresponding port: ) repo_instructions = 'This repository contains important code.' + conversation_instructions = ConversationInstructions(content='additional context') + # Build additional info result = manager.build_workspace_context( repository_info=repo_info, runtime_info=runtime_info, repo_instructions=repo_instructions, + conversation_instructions=conversation_instructions, ) # Check that all information is included @@ -237,7 +255,8 @@ each of which has a corresponding port: assert '' in result assert 'example.com (port 8080)' in result assert 'You know everything about this runtime.' in result - assert "Today's date is 02/12/1232 (UTC)." + assert "Today's date is 02/12/1232" in result + assert 'additional context' in result # Clean up os.remove(os.path.join(prompt_dir, 'additional_info.j2'))