mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
fix(agent controller): state.metrics is missing on exception (#6036)
This commit is contained in:
parent
0c58f469b4
commit
f5f988e552
@ -191,6 +191,7 @@ class AgentController:
|
||||
self,
|
||||
e: Exception,
|
||||
):
|
||||
"""React to an exception by setting the agent state to error and sending a status message."""
|
||||
await self.set_agent_state_to(AgentState.ERROR)
|
||||
if self.status_callback is not None:
|
||||
err_id = ''
|
||||
@ -348,7 +349,6 @@ class AgentController:
|
||||
|
||||
def _reset(self) -> None:
|
||||
"""Resets the agent controller"""
|
||||
|
||||
# make sure there is an Observation with the tool call metadata to be recognized by the agent
|
||||
# otherwise the pending action is found in history, but it's incomplete without an obs with tool result
|
||||
if self._pending_action and hasattr(self._pending_action, 'tool_call_metadata'):
|
||||
@ -389,6 +389,9 @@ class AgentController:
|
||||
return
|
||||
|
||||
if new_state in (AgentState.STOPPED, AgentState.ERROR):
|
||||
# sync existing metrics BEFORE resetting the agent
|
||||
self.update_state_after_step()
|
||||
self.state.metrics.merge(self.state.local_metrics)
|
||||
self._reset()
|
||||
elif (
|
||||
new_state == AgentState.RUNNING
|
||||
|
||||
@ -39,7 +39,11 @@ def event_loop():
|
||||
def mock_agent():
    """Return a mocked ``Agent`` whose LLM carries a pre-populated metrics mock.

    The metrics mock is spec'd against ``Metrics`` and given concrete values
    for ``costs``, ``accumulated_cost`` and ``response_latencies`` instead of
    auto-generated child mocks.
    """
    agent = MagicMock(spec=Agent)
    agent.llm = MagicMock(spec=LLM)

    # NOTE(review): concrete list/float values are presumably needed because
    # the controller merges these metrics into its state -- confirm against
    # AgentController.
    metrics = MagicMock(spec=Metrics)
    metrics.costs = []
    metrics.accumulated_cost = 0.0
    metrics.response_latencies = []

    # Single assignment: an earlier bare ``MagicMock(spec=Metrics)`` assigned
    # to ``agent.llm.metrics`` was always overwritten here, so it was removed.
    agent.llm.metrics = metrics
    return agent
|
||||
|
||||
|
||||
@ -292,6 +296,14 @@ async def test_delegate_step_different_states(
|
||||
async def test_max_iterations_extension(mock_agent, mock_event_stream):
|
||||
# Test with headless_mode=False - should extend max_iterations
|
||||
initial_state = State(max_iterations=10)
|
||||
|
||||
# Set up proper metrics mock with required attributes
|
||||
metrics = MagicMock(spec=Metrics)
|
||||
metrics._costs = []
|
||||
metrics._response_latencies = []
|
||||
metrics.accumulated_cost = 0.0
|
||||
mock_agent.llm.metrics = metrics
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
@ -544,3 +556,59 @@ async def test_reset_with_pending_action_no_metadata(
|
||||
# Verify that agent.reset() was called
|
||||
mock_agent.reset.assert_called_once()
|
||||
await controller.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_controller_max_iterations_has_metrics():
    """Check that the returned state keeps its metrics after a max-iteration error.

    The agent's ``step`` adds a cost of 10.0 on every call and always returns
    ``CmdRunAction(command='ls')``; a runtime subscriber answers each such
    action with an ``ErrorObservation``. With ``max_iterations=3`` the run is
    expected to end in ``AgentState.ERROR`` at iteration 3 with an accumulated
    cost of 30.0.
    """
    config = AppConfig(
        max_iterations=3,
    )
    event_stream = EventStream(sid='test', file_store=InMemoryFileStore({}))

    # Mocked agent backed by a real Metrics object so costs actually accumulate.
    agent = MagicMock(spec=Agent)
    agent.llm = MagicMock(spec=LLM)
    agent.llm.metrics = Metrics()
    agent.llm.config = config.get_llm_config()

    def agent_step_fn(state):
        print(f'agent_step_fn received state: {state}')
        # Mock the cost of the LLM
        agent.llm.metrics.add_cost(10.0)
        print(
            f'agent.llm.metrics.accumulated_cost: {agent.llm.metrics.accumulated_cost}'
        )
        return CmdRunAction(command='ls')

    agent.step = agent_step_fn

    def on_event(event: Event):
        # Only command actions get a reply; every other event is ignored.
        if not isinstance(event, CmdRunAction):
            return
        reply = ErrorObservation('Non fatal error. event id: ' + str(event.id))
        reply._cause = event.id
        event_stream.add_event(reply, EventSource.ENVIRONMENT)

    event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))

    runtime = MagicMock(spec=Runtime)
    runtime.event_stream = event_stream

    state = await run_controller(
        config=config,
        initial_user_action=MessageAction(content='Test message'),
        runtime=runtime,
        sid='test',
        agent=agent,
        fake_user_response_fn=lambda _: 'repeat',
    )

    expected_error = 'RuntimeError: Agent reached maximum iteration in headless mode. Current iteration: 3, max iteration: 3'
    assert state.iteration == 3
    assert state.agent_state == AgentState.ERROR
    assert state.last_error == expected_error
    assert (
        state.metrics.accumulated_cost == 10.0 * 3
    ), f'Expected accumulated cost to be 30.0, but got {state.metrics.accumulated_cost}'
|
||||
|
||||
@ -141,9 +141,9 @@ def test_llm_reset():
|
||||
initial_metrics.add_cost(1.0)
|
||||
initial_metrics.add_response_latency(0.5, 'test-id')
|
||||
llm.reset()
|
||||
assert llm.metrics._accumulated_cost != initial_metrics._accumulated_cost
|
||||
assert llm.metrics._costs != initial_metrics._costs
|
||||
assert llm.metrics._response_latencies != initial_metrics._response_latencies
|
||||
assert llm.metrics.accumulated_cost != initial_metrics.accumulated_cost
|
||||
assert llm.metrics.costs != initial_metrics.costs
|
||||
assert llm.metrics.response_latencies != initial_metrics.response_latencies
|
||||
assert isinstance(llm.metrics, Metrics)
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user