fix(agent controller): state.metrics is missing on exception (#6036)

This commit is contained in:
Xingyao Wang 2025-01-04 20:08:47 -05:00 committed by GitHub
parent 0c58f469b4
commit f5f988e552
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 76 additions and 5 deletions

View File

@ -191,6 +191,7 @@ class AgentController:
self,
e: Exception,
):
"""React to an exception by setting the agent state to error and sending a status message."""
await self.set_agent_state_to(AgentState.ERROR)
if self.status_callback is not None:
err_id = ''
@ -348,7 +349,6 @@ class AgentController:
def _reset(self) -> None:
"""Resets the agent controller"""
# make sure there is an Observation with the tool call metadata to be recognized by the agent
# otherwise the pending action is found in history, but it's incomplete without an obs with tool result
if self._pending_action and hasattr(self._pending_action, 'tool_call_metadata'):
@ -389,6 +389,9 @@ class AgentController:
return
if new_state in (AgentState.STOPPED, AgentState.ERROR):
# sync existing metrics BEFORE resetting the agent
self.update_state_after_step()
self.state.metrics.merge(self.state.local_metrics)
self._reset()
elif (
new_state == AgentState.RUNNING

View File

@ -39,7 +39,11 @@ def event_loop():
def mock_agent():
agent = MagicMock(spec=Agent)
agent.llm = MagicMock(spec=LLM)
agent.llm.metrics = MagicMock(spec=Metrics)
metrics = MagicMock(spec=Metrics)
metrics.costs = []
metrics.accumulated_cost = 0.0
metrics.response_latencies = []
agent.llm.metrics = metrics
return agent
@ -292,6 +296,14 @@ async def test_delegate_step_different_states(
async def test_max_iterations_extension(mock_agent, mock_event_stream):
# Test with headless_mode=False - should extend max_iterations
initial_state = State(max_iterations=10)
# Set up proper metrics mock with required attributes
metrics = MagicMock(spec=Metrics)
metrics._costs = []
metrics._response_latencies = []
metrics.accumulated_cost = 0.0
mock_agent.llm.metrics = metrics
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
@ -544,3 +556,59 @@ async def test_reset_with_pending_action_no_metadata(
# Verify that agent.reset() was called
mock_agent.reset.assert_called_once()
await controller.close()
@pytest.mark.asyncio
async def test_run_controller_max_iterations_has_metrics():
config = AppConfig(
max_iterations=3,
)
file_store = InMemoryFileStore({})
event_stream = EventStream(sid='test', file_store=file_store)
agent = MagicMock(spec=Agent)
agent.llm = MagicMock(spec=LLM)
agent.llm.metrics = Metrics()
agent.llm.config = config.get_llm_config()
def agent_step_fn(state):
print(f'agent_step_fn received state: {state}')
# Mock the cost of the LLM
agent.llm.metrics.add_cost(10.0)
print(
f'agent.llm.metrics.accumulated_cost: {agent.llm.metrics.accumulated_cost}'
)
return CmdRunAction(command='ls')
agent.step = agent_step_fn
runtime = MagicMock(spec=Runtime)
def on_event(event: Event):
if isinstance(event, CmdRunAction):
non_fatal_error_obs = ErrorObservation(
'Non fatal error. event id: ' + str(event.id)
)
non_fatal_error_obs._cause = event.id
event_stream.add_event(non_fatal_error_obs, EventSource.ENVIRONMENT)
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
runtime.event_stream = event_stream
state = await run_controller(
config=config,
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
agent=agent,
fake_user_response_fn=lambda _: 'repeat',
)
assert state.iteration == 3
assert state.agent_state == AgentState.ERROR
assert (
state.last_error
== 'RuntimeError: Agent reached maximum iteration in headless mode. Current iteration: 3, max iteration: 3'
)
assert (
state.metrics.accumulated_cost == 10.0 * 3
), f'Expected accumulated cost to be 30.0, but got {state.metrics.accumulated_cost}'

View File

@ -141,9 +141,9 @@ def test_llm_reset():
initial_metrics.add_cost(1.0)
initial_metrics.add_response_latency(0.5, 'test-id')
llm.reset()
assert llm.metrics._accumulated_cost != initial_metrics._accumulated_cost
assert llm.metrics._costs != initial_metrics._costs
assert llm.metrics._response_latencies != initial_metrics._response_latencies
assert llm.metrics.accumulated_cost != initial_metrics.accumulated_cost
assert llm.metrics.costs != initial_metrics.costs
assert llm.metrics.response_latencies != initial_metrics.response_latencies
assert isinstance(llm.metrics, Metrics)