From 0687608feb591ca1356351dfb7365af12599778b Mon Sep 17 00:00:00 2001 From: Engel Nyst Date: Thu, 31 Oct 2024 19:33:13 +0100 Subject: [PATCH] [Arch proposal] ENVIRONMENT event source (#4584) Co-authored-by: Xingyao Wang --- openhands/controller/agent_controller.py | 8 ++- openhands/core/cli.py | 4 +- openhands/core/message.py | 2 + openhands/events/event.py | 1 + openhands/runtime/base.py | 2 + openhands/security/invariant/analyzer.py | 1 + openhands/server/session/agent_session.py | 2 +- openhands/server/session/session.py | 18 ++++-- tests/unit/test_agent_controller.py | 4 +- tests/unit/test_is_stuck.py | 74 +++++++++++------------ tests/unit/test_memory.py | 2 +- tests/unit/test_prompt_caching.py | 4 +- 12 files changed, 70 insertions(+), 52 deletions(-) diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index 58d3513fd6..cde15d78d0 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -156,7 +156,7 @@ class AgentController: if exception is not None and isinstance(exception, litellm.AuthenticationError): detail = 'Please check your credentials. Is your API key correct?' self.event_stream.add_event( - ErrorObservation(f'{message}:{detail}'), EventSource.USER + ErrorObservation(f'{message}:{detail}'), EventSource.ENVIRONMENT ) async def start_step_loop(self): @@ -346,7 +346,8 @@ class AgentController: self.state.agent_state = new_state self.event_stream.add_event( - AgentStateChangedObservation('', self.state.agent_state), EventSource.AGENT + AgentStateChangedObservation('', self.state.agent_state), + EventSource.ENVIRONMENT, ) if new_state == AgentState.INIT and self.state.resume_state: @@ -423,7 +424,8 @@ class AgentController: if self._is_stuck(): # This need to go BEFORE report_error to sync metrics self.event_stream.add_event( - FatalErrorObservation('Agent got stuck in a loop'), EventSource.USER + FatalErrorObservation('Agent got stuck in a loop'), + EventSource.ENVIRONMENT, ) return diff --git a/openhands/core/cli.py b/openhands/core/cli.py index 6a2620790f..acf39a71f5 100644 --- a/openhands/core/cli.py +++ b/openhands/core/cli.py @@ -61,7 +61,7 @@ def display_event(event: Event): if hasattr(event, 'thought'): display_message(event.thought) if isinstance(event, MessageAction): - if event.source != EventSource.USER: + if event.source == EventSource.AGENT: display_message(event.content) if isinstance(event, CmdRunAction): display_command(event.command) @@ -131,7 +131,7 @@ async def main(): next_message = input('How can I help? >> ') if next_message == 'exit': event_stream.add_event( - ChangeAgentStateAction(AgentState.STOPPED), EventSource.USER + ChangeAgentStateAction(AgentState.STOPPED), EventSource.ENVIRONMENT ) return action = MessageAction(content=next_message) diff --git a/openhands/core/message.py b/openhands/core/message.py index 38568b504b..e538bec44b 100644 --- a/openhands/core/message.py +++ b/openhands/core/message.py @@ -49,6 +49,8 @@ class ImageContent(Content): class Message(BaseModel): + # NOTE: this is not the same as EventSource + # These are the roles in the LLM's APIs role: Literal['user', 'system', 'assistant', 'tool'] content: list[TextContent | ImageContent] = Field(default_factory=list) cache_enabled: bool = False diff --git a/openhands/events/event.py b/openhands/events/event.py index 6ec68acc55..126172bac7 100644 --- a/openhands/events/event.py +++ b/openhands/events/event.py @@ -9,6 +9,7 @@ from openhands.llm.metrics import Metrics class EventSource(str, Enum): AGENT = 'agent' USER = 'user' + ENVIRONMENT = 'environment' @dataclass diff --git a/openhands/runtime/base.py b/openhands/runtime/base.py index 474ba741a4..1b1b01d32f 100644 --- a/openhands/runtime/base.py +++ b/openhands/runtime/base.py @@ -136,6 +136,8 @@ class Runtime(FileEditRuntimeMixin): ) observation._cause = event.id # type: ignore[attr-defined] observation.tool_call_metadata = event.tool_call_metadata + + # this might be unnecessary, since source should be set by the event stream when we're here source = event.source if event.source else EventSource.AGENT await self.event_stream.async_add_event(observation, source) # type: ignore[arg-type] diff --git a/openhands/security/invariant/analyzer.py b/openhands/security/invariant/analyzer.py index ba7fb890eb..0ba13b4ecd 100644 --- a/openhands/security/invariant/analyzer.py +++ b/openhands/security/invariant/analyzer.py @@ -147,6 +147,7 @@ class InvariantAnalyzer(SecurityAnalyzer): new_event = action_from_dict( {'action': 'change_agent_state', 'args': {'agent_state': 'user_confirmed'}} ) + # we should confirm only on agent actions event_source = event.source if event.source else EventSource.AGENT await call_sync_from_async(self.event_stream.add_event, new_event, event_source) diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py index 41ee968891..194427b498 100644 --- a/openhands/server/session/agent_session.py +++ b/openhands/server/session/agent_session.py @@ -118,7 +118,7 @@ class AgentSession: agent_configs=agent_configs, ) self.event_stream.add_event( - ChangeAgentStateAction(AgentState.INIT), EventSource.USER + ChangeAgentStateAction(AgentState.INIT), EventSource.ENVIRONMENT ) if self.controller: self.controller.agent_task = self.controller.start_step_loop() diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py index 4e6119a185..c9e2a48388 100644 --- a/openhands/server/session/session.py +++ b/openhands/server/session/session.py @@ -73,10 +73,11 @@ class Session: async def _initialize_agent(self, data: dict): self.agent_session.event_stream.add_event( - ChangeAgentStateAction(AgentState.LOADING), EventSource.USER + ChangeAgentStateAction(AgentState.LOADING), EventSource.ENVIRONMENT ) self.agent_session.event_stream.add_event( - AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT + AgentStateChangedObservation('', AgentState.LOADING), + EventSource.ENVIRONMENT, ) # Extract the agent-relevant arguments from the request args = {key: value for key, value in data.get('args', {}).items()} @@ -138,12 +139,19 @@ class Session: return if event.source == EventSource.AGENT: await self.send(event_to_dict(event)) - elif event.source == EventSource.USER and isinstance( + # NOTE: ipython observations are not sent here currently + elif event.source == EventSource.ENVIRONMENT and isinstance( event, CmdOutputObservation ): - await self.send(event_to_dict(event)) + # feedback from the environment to agent actions is understood as agent events by the UI + event_dict = event_to_dict(event) + event_dict['source'] = EventSource.AGENT + await self.send(event_dict) elif isinstance(event, ErrorObservation): - await self.send(event_to_dict(event)) + # send error events as agent events to the UI + event_dict = event_to_dict(event) + event_dict['source'] = EventSource.AGENT + await self.send(event_dict) async def dispatch(self, data: dict): action = data.get('action', '') diff --git a/tests/unit/test_agent_controller.py b/tests/unit/test_agent_controller.py index 9b0522302f..4d63c27405 100644 --- a/tests/unit/test_agent_controller.py +++ b/tests/unit/test_agent_controller.py @@ -207,7 +207,9 @@ async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream): 'Non fatal error here to trigger loop' ) non_fatal_error_obs._cause = event.id - await event_stream.async_add_event(non_fatal_error_obs, EventSource.USER) + await event_stream.async_add_event( + non_fatal_error_obs, EventSource.ENVIRONMENT + ) event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event) runtime.event_stream = event_stream diff --git a/tests/unit/test_is_stuck.py b/tests/unit/test_is_stuck.py index 4a13307521..2fe8a683c8 100644 --- a/tests/unit/test_is_stuck.py +++ b/tests/unit/test_is_stuck.py @@ -80,7 +80,7 @@ class TestStuckDetector: code=code_snippet, ) ipython_observation._cause = ipython_action._id - event_stream.add_event(ipython_observation, EventSource.USER) + event_stream.add_event(ipython_observation, EventSource.ENVIRONMENT) def _impl_unterminated_string_error_events( self, event_stream: EventStream, random_line: bool, incidents: int = 4 @@ -96,7 +96,7 @@ class TestStuckDetector: code=code_snippet, ) ipython_observation._cause = ipython_action._id - event_stream.add_event(ipython_observation, EventSource.USER) + event_stream.add_event(ipython_observation, EventSource.ENVIRONMENT) def test_history_too_short( self, stuck_detector: StuckDetector, event_stream: EventStream @@ -106,7 +106,7 @@ class TestStuckDetector: observation = NullObservation(content='') observation._cause = message_action.id event_stream.add_event(message_action, EventSource.USER) - event_stream.add_event(observation, EventSource.USER) + event_stream.add_event(observation, EventSource.ENVIRONMENT) cmd_action = CmdRunAction(command='ls') event_stream.add_event(cmd_action, EventSource.AGENT) @@ -114,7 +114,7 @@ class TestStuckDetector: command_id=1, command='ls', content='file1.txt\nfile2.txt' ) cmd_observation._cause = cmd_action._id - event_stream.add_event(cmd_observation, EventSource.USER) + event_stream.add_event(cmd_observation, EventSource.ENVIRONMENT) # stuck_detector.state.history.set_event_stream(event_stream) @@ -131,7 +131,7 @@ class TestStuckDetector: # 2 events event_stream.add_event(hello_action, EventSource.USER) - event_stream.add_event(hello_observation, EventSource.USER) + event_stream.add_event(hello_observation, EventSource.ENVIRONMENT) cmd_action_1 = CmdRunAction(command='ls') event_stream.add_event(cmd_action_1, EventSource.AGENT) @@ -139,7 +139,7 @@ class TestStuckDetector: content='', command='ls', command_id=cmd_action_1._id ) cmd_observation_1._cause = cmd_action_1._id - event_stream.add_event(cmd_observation_1, EventSource.USER) + event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT) # 4 events cmd_action_2 = CmdRunAction(command='ls') @@ -148,13 +148,13 @@ class TestStuckDetector: content='', command='ls', command_id=cmd_action_2._id ) cmd_observation_2._cause = cmd_action_2._id - event_stream.add_event(cmd_observation_2, EventSource.USER) + event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT) # 6 events # random user message just because we can message_null_observation = NullObservation(content='') event_stream.add_event(message_action, EventSource.USER) - event_stream.add_event(message_null_observation, EventSource.USER) + event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT) # 8 events assert stuck_detector.is_stuck() is False @@ -166,7 +166,7 @@ class TestStuckDetector: content='', command='ls', command_id=cmd_action_3._id ) cmd_observation_3._cause = cmd_action_3._id - event_stream.add_event(cmd_observation_3, EventSource.USER) + event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT) # 10 events assert len(collect_events(event_stream)) == 10 @@ -191,7 +191,7 @@ class TestStuckDetector: content='', command='ls', command_id=cmd_action_4._id ) cmd_observation_4._cause = cmd_action_4._id - event_stream.add_event(cmd_observation_4, EventSource.USER) + event_stream.add_event(cmd_observation_4, EventSource.ENVIRONMENT) # 12 events assert len(collect_events(event_stream)) == 12 @@ -223,14 +223,14 @@ class TestStuckDetector: hello_observation = NullObservation(content='') event_stream.add_event(hello_action, EventSource.USER) hello_observation._cause = hello_action._id - event_stream.add_event(hello_observation, EventSource.USER) + event_stream.add_event(hello_observation, EventSource.ENVIRONMENT) # 2 events cmd_action_1 = CmdRunAction(command='invalid_command') event_stream.add_event(cmd_action_1, EventSource.AGENT) error_observation_1 = ErrorObservation(content='Command not found') error_observation_1._cause = cmd_action_1._id - event_stream.add_event(error_observation_1, EventSource.USER) + event_stream.add_event(error_observation_1, EventSource.ENVIRONMENT) # 4 events cmd_action_2 = CmdRunAction(command='invalid_command') @@ -239,26 +239,26 @@ class TestStuckDetector: content='Command still not found or another error' ) error_observation_2._cause = cmd_action_2._id - event_stream.add_event(error_observation_2, EventSource.USER) + event_stream.add_event(error_observation_2, EventSource.ENVIRONMENT) # 6 events message_null_observation = NullObservation(content='') event_stream.add_event(message_action, EventSource.USER) - event_stream.add_event(message_null_observation, EventSource.USER) + event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT) # 8 events cmd_action_3 = CmdRunAction(command='invalid_command') event_stream.add_event(cmd_action_3, EventSource.AGENT) error_observation_3 = ErrorObservation(content='Different error') error_observation_3._cause = cmd_action_3._id - event_stream.add_event(error_observation_3, EventSource.USER) + event_stream.add_event(error_observation_3, EventSource.ENVIRONMENT) # 10 events cmd_action_4 = CmdRunAction(command='invalid_command') event_stream.add_event(cmd_action_4, EventSource.AGENT) error_observation_4 = ErrorObservation(content='Command not found') error_observation_4._cause = cmd_action_4._id - event_stream.add_event(error_observation_4, EventSource.USER) + event_stream.add_event(error_observation_4, EventSource.ENVIRONMENT) # 12 events with patch('logging.Logger.warning') as mock_warning: @@ -366,7 +366,7 @@ class TestStuckDetector: code='print("hello', ) ipython_observation_1._cause = ipython_action_1._id - event_stream.add_event(ipython_observation_1, EventSource.USER) + event_stream.add_event(ipython_observation_1, EventSource.ENVIRONMENT) ipython_action_2 = IPythonRunCellAction(code='print("hello') event_stream.add_event(ipython_action_2, EventSource.AGENT) @@ -375,7 +375,7 @@ class TestStuckDetector: code='print("hello', ) ipython_observation_2._cause = ipython_action_2._id - event_stream.add_event(ipython_observation_2, EventSource.USER) + event_stream.add_event(ipython_observation_2, EventSource.ENVIRONMENT) ipython_action_3 = IPythonRunCellAction(code='print("hello') event_stream.add_event(ipython_action_3, EventSource.AGENT) @@ -384,7 +384,7 @@ class TestStuckDetector: code='print("hello', ) ipython_observation_3._cause = ipython_action_3._id - event_stream.add_event(ipython_observation_3, EventSource.USER) + event_stream.add_event(ipython_observation_3, EventSource.ENVIRONMENT) ipython_action_4 = IPythonRunCellAction(code='print("hello') event_stream.add_event(ipython_action_4, EventSource.AGENT) @@ -393,7 +393,7 @@ class TestStuckDetector: code='print("hello', ) ipython_observation_4._cause = ipython_action_4._id - event_stream.add_event(ipython_observation_4, EventSource.USER) + event_stream.add_event(ipython_observation_4, EventSource.ENVIRONMENT) with patch('logging.Logger.warning') as mock_warning: assert stuck_detector.is_stuck() is False @@ -406,7 +406,7 @@ class TestStuckDetector: message_action._source = EventSource.USER event_stream.add_event(message_action, EventSource.USER) message_observation = NullObservation(content='') - event_stream.add_event(message_observation, EventSource.USER) + event_stream.add_event(message_observation, EventSource.ENVIRONMENT) cmd_action_1 = CmdRunAction(command='ls') event_stream.add_event(cmd_action_1, EventSource.AGENT) @@ -414,7 +414,7 @@ class TestStuckDetector: command_id=1, command='ls', content='file1.txt\nfile2.txt' ) cmd_observation_1._cause = cmd_action_1._id - event_stream.add_event(cmd_observation_1, EventSource.USER) + event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT) read_action_1 = FileReadAction(path='file1.txt') event_stream.add_event(read_action_1, EventSource.AGENT) @@ -422,7 +422,7 @@ class TestStuckDetector: content='File content', path='file1.txt' ) read_observation_1._cause = read_action_1._id - event_stream.add_event(read_observation_1, EventSource.USER) + event_stream.add_event(read_observation_1, EventSource.ENVIRONMENT) cmd_action_2 = CmdRunAction(command='ls') event_stream.add_event(cmd_action_2, EventSource.AGENT) @@ -430,7 +430,7 @@ class TestStuckDetector: command_id=2, command='ls', content='file1.txt\nfile2.txt' ) cmd_observation_2._cause = cmd_action_2._id - event_stream.add_event(cmd_observation_2, EventSource.USER) + event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT) read_action_2 = FileReadAction(path='file1.txt') event_stream.add_event(read_action_2, EventSource.AGENT) @@ -438,12 +438,12 @@ class TestStuckDetector: content='File content', path='file1.txt' ) read_observation_2._cause = read_action_2._id - event_stream.add_event(read_observation_2, EventSource.USER) + event_stream.add_event(read_observation_2, EventSource.ENVIRONMENT) # one more message to break the pattern message_null_observation = NullObservation(content='') event_stream.add_event(message_action, EventSource.USER) - event_stream.add_event(message_null_observation, EventSource.USER) + event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT) cmd_action_3 = CmdRunAction(command='ls') event_stream.add_event(cmd_action_3, EventSource.AGENT) @@ -451,7 +451,7 @@ class TestStuckDetector: command_id=3, command='ls', content='file1.txt\nfile2.txt' ) cmd_observation_3._cause = cmd_action_3._id - event_stream.add_event(cmd_observation_3, EventSource.USER) + event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT) read_action_3 = FileReadAction(path='file1.txt') event_stream.add_event(read_action_3, EventSource.AGENT) @@ -459,7 +459,7 @@ class TestStuckDetector: content='File content', path='file1.txt' ) read_observation_3._cause = read_action_3._id - event_stream.add_event(read_observation_3, EventSource.USER) + event_stream.add_event(read_observation_3, EventSource.ENVIRONMENT) with patch('logging.Logger.warning') as mock_warning: assert stuck_detector.is_stuck() is True @@ -475,7 +475,7 @@ class TestStuckDetector: event_stream.add_event(hello_action, EventSource.USER) hello_observation = NullObservation(content='') hello_observation._cause = hello_action._id - event_stream.add_event(hello_observation, EventSource.USER) + event_stream.add_event(hello_observation, EventSource.ENVIRONMENT) cmd_action_1 = CmdRunAction(command='ls') event_stream.add_event(cmd_action_1, EventSource.AGENT) @@ -483,7 +483,7 @@ class TestStuckDetector: command_id=cmd_action_1.id, command='ls', content='file1.txt\nfile2.txt' ) cmd_observation_1._cause = cmd_action_1._id - event_stream.add_event(cmd_observation_1, EventSource.USER) + event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT) read_action_1 = FileReadAction(path='file1.txt') event_stream.add_event(read_action_1, EventSource.AGENT) @@ -491,7 +491,7 @@ class TestStuckDetector: content='File content', path='file1.txt' ) read_observation_1._cause = read_action_1._id - event_stream.add_event(read_observation_1, EventSource.USER) + event_stream.add_event(read_observation_1, EventSource.ENVIRONMENT) cmd_action_2 = CmdRunAction(command='pwd') event_stream.add_event(cmd_action_2, EventSource.AGENT) @@ -499,7 +499,7 @@ class TestStuckDetector: command_id=2, command='pwd', content='/home/user' ) cmd_observation_2._cause = cmd_action_2._id - event_stream.add_event(cmd_observation_2, EventSource.USER) + event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT) read_action_2 = FileReadAction(path='file2.txt') event_stream.add_event(read_action_2, EventSource.AGENT) @@ -507,11 +507,11 @@ class TestStuckDetector: content='Another file content', path='file2.txt' ) read_observation_2._cause = read_action_2._id - event_stream.add_event(read_observation_2, EventSource.USER) + event_stream.add_event(read_observation_2, EventSource.ENVIRONMENT) message_null_observation = NullObservation(content='') event_stream.add_event(message_action, EventSource.USER) - event_stream.add_event(message_null_observation, EventSource.USER) + event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT) cmd_action_3 = CmdRunAction(command='pwd') event_stream.add_event(cmd_action_3, EventSource.AGENT) @@ -519,7 +519,7 @@ class TestStuckDetector: command_id=cmd_action_3.id, command='pwd', content='/home/user' ) cmd_observation_3._cause = cmd_action_3._id - event_stream.add_event(cmd_observation_3, EventSource.USER) + event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT) read_action_3 = FileReadAction(path='file2.txt') event_stream.add_event(read_action_3, EventSource.AGENT) @@ -527,7 +527,7 @@ class TestStuckDetector: content='Another file content', path='file2.txt' ) read_observation_3._cause = read_action_3._id - event_stream.add_event(read_observation_3, EventSource.USER) + event_stream.add_event(read_observation_3, EventSource.ENVIRONMENT) assert stuck_detector.is_stuck() is False @@ -572,7 +572,7 @@ class TestStuckDetector: exit_code=0, ) cmd_output_observation._cause = cmd_kill_action._id - event_stream.add_event(cmd_output_observation, EventSource.USER) + event_stream.add_event(cmd_output_observation, EventSource.ENVIRONMENT) message_action_7 = MessageAction(content="I'm doing well, thanks for asking.") event_stream.add_event(message_action_7, EventSource.AGENT) diff --git a/tests/unit/test_memory.py b/tests/unit/test_memory.py index 10991ca27d..96c06e0fd4 100644 --- a/tests/unit/test_memory.py +++ b/tests/unit/test_memory.py @@ -88,7 +88,7 @@ def _create_observation_event(observation: str) -> Event: event = Event() event._id = -1 event._timestamp = datetime.now(timezone.utc).isoformat() - event._source = EventSource.USER + event._source = EventSource.ENVIRONMENT event.observation = observation return event diff --git a/tests/unit/test_prompt_caching.py b/tests/unit/test_prompt_caching.py index 50c42bf662..1bd1828573 100644 --- a/tests/unit/test_prompt_caching.py +++ b/tests/unit/test_prompt_caching.py @@ -155,7 +155,7 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream): command='ls -l', exit_code=0, ) - mock_event_stream.add_event(cmd_observation_1, EventSource.USER) + mock_event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT) message_action_2 = MessageAction("Now, let's create a new directory.") mock_event_stream.add_event(message_action_2, EventSource.AGENT) @@ -169,7 +169,7 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream): command='mkdir new_directory', exit_code=0, ) - mock_event_stream.add_event(cmd_observation_2, EventSource.USER) + mock_event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT) codeact_agent.reset() messages = codeact_agent._get_messages(