From f5f988e552856276956d5cad911df3f55eb07a1b Mon Sep 17 00:00:00 2001 From: Xingyao Wang Date: Sat, 4 Jan 2025 20:08:47 -0500 Subject: [PATCH] fix(agent controller): state.metrics is missing on exception (#6036) --- openhands/controller/agent_controller.py | 5 +- tests/unit/test_agent_controller.py | 70 +++++++++++++++++++++++- tests/unit/test_llm.py | 6 +- 3 files changed, 76 insertions(+), 5 deletions(-) diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index 743decc4df..625d64280c 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -191,6 +191,7 @@ class AgentController: self, e: Exception, ): + """React to an exception by setting the agent state to error and sending a status message.""" await self.set_agent_state_to(AgentState.ERROR) if self.status_callback is not None: err_id = '' @@ -348,7 +349,6 @@ class AgentController: def _reset(self) -> None: """Resets the agent controller""" - # make sure there is an Observation with the tool call metadata to be recognized by the agent # otherwise the pending action is found in history, but it's incomplete without an obs with tool result if self._pending_action and hasattr(self._pending_action, 'tool_call_metadata'): @@ -389,6 +389,9 @@ class AgentController: return if new_state in (AgentState.STOPPED, AgentState.ERROR): + # sync existing metrics BEFORE resetting the agent + self.update_state_after_step() + self.state.metrics.merge(self.state.local_metrics) self._reset() elif ( new_state == AgentState.RUNNING diff --git a/tests/unit/test_agent_controller.py b/tests/unit/test_agent_controller.py index a2136c2393..0eba74edd1 100644 --- a/tests/unit/test_agent_controller.py +++ b/tests/unit/test_agent_controller.py @@ -39,7 +39,11 @@ def event_loop(): def mock_agent(): agent = MagicMock(spec=Agent) agent.llm = MagicMock(spec=LLM) - agent.llm.metrics = MagicMock(spec=Metrics) + metrics = MagicMock(spec=Metrics) + metrics.costs = [] + metrics.accumulated_cost = 0.0 + metrics.response_latencies = [] + agent.llm.metrics = metrics return agent @@ -292,6 +296,14 @@ async def test_delegate_step_different_states( async def test_max_iterations_extension(mock_agent, mock_event_stream): # Test with headless_mode=False - should extend max_iterations initial_state = State(max_iterations=10) + + # Set up proper metrics mock with required attributes + metrics = MagicMock(spec=Metrics) + metrics._costs = [] + metrics._response_latencies = [] + metrics.accumulated_cost = 0.0 + mock_agent.llm.metrics = metrics + controller = AgentController( agent=mock_agent, event_stream=mock_event_stream, @@ -544,3 +556,59 @@ async def test_reset_with_pending_action_no_metadata( # Verify that agent.reset() was called mock_agent.reset.assert_called_once() await controller.close() + + +@pytest.mark.asyncio +async def test_run_controller_max_iterations_has_metrics(): + config = AppConfig( + max_iterations=3, + ) + file_store = InMemoryFileStore({}) + event_stream = EventStream(sid='test', file_store=file_store) + + agent = MagicMock(spec=Agent) + agent.llm = MagicMock(spec=LLM) + agent.llm.metrics = Metrics() + agent.llm.config = config.get_llm_config() + + def agent_step_fn(state): + print(f'agent_step_fn received state: {state}') + # Mock the cost of the LLM + agent.llm.metrics.add_cost(10.0) + print( + f'agent.llm.metrics.accumulated_cost: {agent.llm.metrics.accumulated_cost}' + ) + return CmdRunAction(command='ls') + + agent.step = agent_step_fn + + runtime = MagicMock(spec=Runtime) + + def on_event(event: Event): + if isinstance(event, CmdRunAction): + non_fatal_error_obs = ErrorObservation( + 'Non fatal error. event id: ' + str(event.id) + ) + non_fatal_error_obs._cause = event.id + event_stream.add_event(non_fatal_error_obs, EventSource.ENVIRONMENT) + + event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4())) + runtime.event_stream = event_stream + + state = await run_controller( + config=config, + initial_user_action=MessageAction(content='Test message'), + runtime=runtime, + sid='test', + agent=agent, + fake_user_response_fn=lambda _: 'repeat', + ) + assert state.iteration == 3 + assert state.agent_state == AgentState.ERROR + assert ( + state.last_error + == 'RuntimeError: Agent reached maximum iteration in headless mode. Current iteration: 3, max iteration: 3' + ) + assert ( + state.metrics.accumulated_cost == 10.0 * 3 + ), f'Expected accumulated cost to be 30.0, but got {state.metrics.accumulated_cost}' diff --git a/tests/unit/test_llm.py b/tests/unit/test_llm.py index 46923aa217..edf82d8aa4 100644 --- a/tests/unit/test_llm.py +++ b/tests/unit/test_llm.py @@ -141,9 +141,9 @@ def test_llm_reset(): initial_metrics.add_cost(1.0) initial_metrics.add_response_latency(0.5, 'test-id') llm.reset() - assert llm.metrics._accumulated_cost != initial_metrics._accumulated_cost - assert llm.metrics._costs != initial_metrics._costs - assert llm.metrics._response_latencies != initial_metrics._response_latencies + assert llm.metrics.accumulated_cost != initial_metrics.accumulated_cost + assert llm.metrics.costs != initial_metrics.costs + assert llm.metrics.response_latencies != initial_metrics.response_latencies assert isinstance(llm.metrics, Metrics)