fix: metric logging in agent controller (#4387)

This commit is contained in:
Xingyao Wang 2024-10-15 09:32:39 -05:00 committed by GitHub
parent 50c13aad98
commit 6bbd75c6e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 27 additions and 7 deletions

View File

@ -127,8 +127,8 @@ class AgentController:
self.state.local_iteration += 1
async def update_state_after_step(self):
# update metrics especially for cost
self.state.local_metrics = self.agent.llm.metrics
# update metrics especially for cost. Use deepcopy to avoid it being modified by agent.reset()
self.state.local_metrics = copy.deepcopy(self.agent.llm.metrics)
if 'llm_completions' not in self.state.extra_data:
self.state.extra_data['llm_completions'] = []
self.state.extra_data['llm_completions'].extend(self.agent.llm.llm_completions)
@ -139,12 +139,12 @@ class AgentController:
This method should be called for a particular type of errors, which have:
- a user-friendly message, which will be shown in the chat box. This should not be a raw exception message.
- an ErrorObservation that can be sent to the LLM by the agent, with the exception message, so it can self-correct next time.
- an ErrorObservation that can be sent to the LLM by the user role, with the exception message, so it can self-correct next time.
"""
self.state.last_error = message
if exception:
self.state.last_error += f': {exception}'
self.event_stream.add_event(ErrorObservation(message), EventSource.AGENT)
self.event_stream.add_event(ErrorObservation(message), EventSource.USER)
async def start_step_loop(self):
"""The main loop for the agent's step-by-step execution."""
@ -229,6 +229,11 @@ class AgentController:
observation_to_print.content, self.agent.llm.config.max_message_chars
)
logger.info(observation_to_print, extra={'msg_type': 'OBSERVATION'})
# Merge with the metrics from the LLM - it will be synced to the controller's local metrics in update_state_after_step()
if observation.llm_metrics is not None:
self.agent.llm.metrics.merge(observation.llm_metrics)
if self._pending_action and self._pending_action.id == observation.cause:
self._pending_action = None
if self.state.agent_state == AgentState.USER_CONFIRMED:
@ -450,8 +455,9 @@ class AgentController:
logger.info(action, extra={'msg_type': 'ACTION'})
if self._is_stuck():
await self.report_error('Agent got stuck in a loop')
# This needs to go BEFORE report_error to sync metrics
await self.set_agent_state_to(AgentState.ERROR)
await self.report_error('Agent got stuck in a loop')
async def _delegate_step(self):
"""Executes a single step of the delegate agent."""
@ -519,20 +525,21 @@ class AgentController:
else:
self.state.traffic_control_state = TrafficControlState.THROTTLING
if self.headless_mode:
# This needs to go BEFORE report_error to sync metrics
await self.set_agent_state_to(AgentState.ERROR)
# set to ERROR state if running in headless mode
# since user cannot resume on the web interface
await self.report_error(
f'Agent reached maximum {limit_type} in headless mode, task stopped. '
f'Current {limit_type}: {current_value:.2f}, max {limit_type}: {max_value:.2f}'
)
await self.set_agent_state_to(AgentState.ERROR)
else:
await self.set_agent_state_to(AgentState.PAUSED)
await self.report_error(
f'Agent reached maximum {limit_type}, task paused. '
f'Current {limit_type}: {current_value:.2f}, max {limit_type}: {max_value:.2f}. '
f'{TRAFFIC_CONTROL_REMINDER}'
)
await self.set_agent_state_to(AgentState.PAUSED)
stop_step = True
return stop_step

View File

@ -2,6 +2,8 @@ from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from openhands.core.metrics import Metrics
class EventSource(str, Enum):
AGENT = 'agent'
@ -58,3 +60,14 @@ class Event:
if hasattr(self, 'blocking'):
# .blocking needs to be set to True if .timeout is set
self.blocking = True
# optional metadata, LLM call cost of the edit
@property
def llm_metrics(self) -> Metrics | None:
    """Return the metrics stored on this event, or None if none were set.

    The value lives in the private ``_llm_metrics`` attribute, which only
    exists once the setter has been called.
    """
    # A missing attribute is reported as None rather than raising.
    return getattr(self, '_llm_metrics', None)

@llm_metrics.setter
def llm_metrics(self, value: Metrics) -> None:
    """Store ``value`` as this event's metrics."""
    self._llm_metrics = value