Mirror of https://github.com/OpenHands/OpenHands.git, synced 2025-12-26 05:48:36 +08:00
Expose accumulate and llm_metric from eventstream (backend) (#7082)
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
This commit is contained in:
parent 2cb5b91300
commit 5e521a4a6e
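In short: each action the agent emits now carries a trimmed copy of the LLM metrics (the accumulated cost plus the most recent token usage), and that copy survives event (de)serialization. Below is a minimal sketch of how a consumer of the event stream could read it back, based on the serialization helpers and tests in this diff; the message content and cost values are illustrative only.

from openhands.events.action import MessageAction
from openhands.events.serialization import event_from_dict, event_to_dict
from openhands.llm.metrics import Metrics

# Build an action with a minimal metrics copy attached, the way the
# controller does after this change (values are made up for illustration).
action = MessageAction(content='hello')
metrics = Metrics(model_name='some-model')
metrics.accumulated_cost = 0.04
action.llm_metrics = metrics

# Round-trip through the event serialization touched by this commit.
payload = event_to_dict(action)
print(payload['llm_metrics']['accumulated_cost'])  # 0.04

restored = event_from_dict(payload)
assert restored.llm_metrics.accumulated_cost == 0.04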
@@ -53,6 +53,7 @@ from openhands.events.observation import (
)
from openhands.events.serialization.event import event_to_trajectory, truncate_content
from openhands.llm.llm import LLM
from openhands.llm.metrics import Metrics, TokenUsage

# note: RESUME is only available on web GUI
TRAFFIC_CONTROL_REMINDER = (
@@ -733,6 +734,10 @@ class AgentController:
                == ActionConfirmationStatus.AWAITING_CONFIRMATION
            ):
                await self.set_agent_state_to(AgentState.AWAITING_USER_CONFIRMATION)

            # Create and log metrics for frontend display
            self._prepare_metrics_for_frontend(action)

            self.event_stream.add_event(action, action._source)  # type: ignore [attr-defined]

            await self.update_state_after_step()
@@ -1086,6 +1091,44 @@ class AgentController:

        return self._stuck_detector.is_stuck(self.headless_mode)

    def _prepare_metrics_for_frontend(self, action: Action) -> None:
        """Create a minimal metrics object for frontend display and log it.

        To avoid performance issues with long conversations, we only keep:
        - accumulated_cost: The current total cost
        - latest token_usage: Token statistics from the most recent API call

        Args:
            action: The action to attach metrics to
        """
        metrics = Metrics(model_name=self.agent.llm.metrics.model_name)
        metrics.accumulated_cost = self.agent.llm.metrics.accumulated_cost
        if self.agent.llm.metrics.token_usages:
            latest_usage = self.agent.llm.metrics.token_usages[-1]
            metrics.add_token_usage(
                prompt_tokens=latest_usage.prompt_tokens,
                completion_tokens=latest_usage.completion_tokens,
                cache_read_tokens=latest_usage.cache_read_tokens,
                cache_write_tokens=latest_usage.cache_write_tokens,
                response_id=latest_usage.response_id,
            )
        action.llm_metrics = metrics

        # Log the metrics information for frontend display
        log_usage: TokenUsage | None = (
            metrics.token_usages[-1] if metrics.token_usages else None
        )
        self.log(
            'debug',
            f'Action metrics - accumulated_cost: {metrics.accumulated_cost}, '
            f'tokens (prompt/completion/cache_read/cache_write): '
            f'{log_usage.prompt_tokens if log_usage else 0}/'
            f'{log_usage.completion_tokens if log_usage else 0}/'
            f'{log_usage.cache_read_tokens if log_usage else 0}/'
            f'{log_usage.cache_write_tokens if log_usage else 0}',
            extra={'msg_type': 'METRICS'},
        )

    def __repr__(self):
        return (
            f'AgentController(id={getattr(self, "id", "<uninitialized>")}, '
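A side note on `_prepare_metrics_for_frontend` above: the action receives a fresh `Metrics` object seeded only with the current totals, not a reference to the agent's live metrics, so later cost updates cannot retroactively change events that were already emitted (the new test further down checks exactly this). A minimal sketch of that copy behaviour, with made-up numbers:

from openhands.llm.metrics import Metrics

live = Metrics(model_name='some-model')
live.accumulated_cost = 0.05

# Snapshot only the scalar total into a new object, as the method above does.
snapshot = Metrics(model_name=live.model_name)
snapshot.accumulated_cost = live.accumulated_cost

# Later updates to the live metrics do not leak into the snapshot.
live.accumulated_cost = 0.10
assert snapshot.accumulated_cost == 0.05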
@@ -9,6 +9,7 @@ from openhands.events.serialization.action import action_from_dict
from openhands.events.serialization.observation import observation_from_dict
from openhands.events.serialization.utils import remove_fields
from openhands.events.tool import ToolCallMetadata
from openhands.llm.metrics import Cost, Metrics, ResponseLatency, TokenUsage

# TODO: move `content` into `extras`
TOP_KEYS = [
@@ -20,8 +21,9 @@ TOP_KEYS = [
    'action',
    'observation',
    'tool_call_metadata',
    'llm_metrics',
]
-UNDERSCORE_KEYS = ['id', 'timestamp', 'source', 'cause', 'tool_call_metadata']
+UNDERSCORE_KEYS = ['id', 'timestamp', 'source', 'cause', 'tool_call_metadata', 'llm_metrics']

DELETE_FROM_TRAJECTORY_EXTRAS = {
    'screenshot',
@@ -54,6 +56,15 @@ def event_from_dict(data) -> 'Event':
                value = EventSource(value)
            if key == 'tool_call_metadata':
                value = ToolCallMetadata(**value)
            if key == 'llm_metrics':
                metrics = Metrics()
                if isinstance(value, dict):
                    metrics.accumulated_cost = value.get('accumulated_cost', 0.0)
                    for cost in value.get('costs', []):
                        metrics._costs.append(Cost(**cost))
                    metrics.response_latencies = [ResponseLatency(**latency) for latency in value.get('response_latencies', [])]
                    metrics.token_usages = [TokenUsage(**usage) for usage in value.get('token_usages', [])]
                value = metrics
            setattr(evt, '_' + key, value)
    return evt
@@ -81,6 +92,8 @@ def event_to_dict(event: 'Event') -> dict:
            d['source'] = d['source'].value
        if key == 'tool_call_metadata' and 'tool_call_metadata' in d:
            d['tool_call_metadata'] = d['tool_call_metadata'].model_dump()
        if key == 'llm_metrics' and 'llm_metrics' in d:
            d['llm_metrics'] = d['llm_metrics'].get()
        props.pop(key, None)
    if 'security_risk' in props and props['security_risk'] is None:
        props.pop('security_risk')
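For reference, the dict that `d['llm_metrics'].get()` above produces, and that `event_from_dict` rebuilds field by field, has roughly the shape sketched below. The keys are taken from the test assertions further down; the values are illustrative, and individual entries may carry additional fields not shown here.

# Illustrative shape only; see the serialization tests below for the asserted keys.
llm_metrics = {
    'accumulated_cost': 0.03,
    'costs': [{'model': 'test-model', 'cost': 0.02}],
    'response_latencies': [{'model': 'test-model', 'latency': 0.5, 'response_id': 'test-id'}],
    'token_usages': [
        {
            'model': 'test-model',
            'prompt_tokens': 10,
            'completion_tokens': 20,
            'cache_read_tokens': 0,
            'cache_write_tokens': 0,
            'response_id': 'test-id',
        }
    ],
}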
@@ -19,7 +19,7 @@ from openhands.events.observation import (
)
from openhands.events.serialization import event_to_dict
from openhands.llm import LLM
-from openhands.llm.metrics import Metrics
+from openhands.llm.metrics import Metrics, TokenUsage
from openhands.runtime.base import Runtime
from openhands.storage.memory import InMemoryFileStore
@@ -749,3 +749,105 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(

    # Check that the context window exceeded error was raised during the run
    assert step_state.has_errored


@pytest.mark.asyncio
async def test_action_metrics_copy():
    # Setup
    file_store = InMemoryFileStore({})
    event_stream = EventStream(sid='test', file_store=file_store)

    # Create agent with metrics
    agent = MagicMock(spec=Agent)
    agent.llm = MagicMock(spec=LLM)
    metrics = Metrics(model_name='test-model')
    metrics.accumulated_cost = 0.05

    # Add multiple token usages - we should get the last one in the action
    usage1 = TokenUsage(
        model='test-model',
        prompt_tokens=5,
        completion_tokens=10,
        cache_read_tokens=2,
        cache_write_tokens=2,
        response_id='test-id-1',
    )

    usage2 = TokenUsage(
        model='test-model',
        prompt_tokens=10,
        completion_tokens=20,
        cache_read_tokens=5,
        cache_write_tokens=5,
        response_id='test-id-2',
    )

    metrics.token_usages = [usage1, usage2]

    # Add a cost instance - should not be included in action metrics
    # This will increase accumulated_cost by 0.02
    metrics.add_cost(0.02)

    # Add a response latency - should not be included in action metrics
    metrics.add_response_latency(0.5, 'test-id-2')

    agent.llm.metrics = metrics

    # Mock agent step to return an action
    action = MessageAction(content='Test message')

    def agent_step_fn(state):
        return action

    agent.step = agent_step_fn

    # Create controller with correct parameters
    controller = AgentController(
        agent=agent,
        event_stream=event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    # Execute one step
    controller.state.agent_state = AgentState.RUNNING
    await controller._step()

    # Get the last event from event stream
    events = list(event_stream.get_events())
    assert len(events) > 0
    last_action = events[-1]

    # Verify metrics were copied correctly
    assert last_action.llm_metrics is not None
    assert (
        last_action.llm_metrics.accumulated_cost == 0.07
    )  # 0.05 initial + 0.02 from add_cost

    # Should include the last token usage
    assert len(last_action.llm_metrics.token_usages) == 1
    assert last_action.llm_metrics.token_usages[0].prompt_tokens == 10
    assert last_action.llm_metrics.token_usages[0].completion_tokens == 20
    assert last_action.llm_metrics.token_usages[0].cache_read_tokens == 5
    assert last_action.llm_metrics.token_usages[0].cache_write_tokens == 5
    assert last_action.llm_metrics.token_usages[0].response_id == 'test-id-2'

    # Should not include the cost history
    assert len(last_action.llm_metrics.costs) == 0

    # Should not include the response latency history
    assert len(last_action.llm_metrics.response_latencies) == 0

    # Verify that there's no latency information in the action's metrics
    # Either directly or as a calculated property
    assert not hasattr(last_action.llm_metrics, 'latency')
    assert not hasattr(last_action.llm_metrics, 'total_latency')
    assert not hasattr(last_action.llm_metrics, 'average_latency')

    # Verify it's a deep copy by modifying the original
    agent.llm.metrics.accumulated_cost = 0.1
    assert last_action.llm_metrics.accumulated_cost == 0.07

    await controller.close()
@@ -1,5 +1,7 @@
from openhands.events.action import MessageAction
from openhands.events.observation import CmdOutputMetadata, CmdOutputObservation
-from openhands.events.serialization import event_to_dict
+from openhands.events.serialization import event_from_dict, event_to_dict
+from openhands.llm.metrics import Cost, Metrics, ResponseLatency, TokenUsage


def test_command_output_success_serialization():
@@ -20,3 +22,102 @@ def test_command_output_success_serialization():
    )
    serialized = event_to_dict(obs)
    assert serialized['success'] is False


def test_metrics_basic_serialization():
    # Create a basic action with only accumulated_cost
    action = MessageAction(content='Hello, world!')
    metrics = Metrics()
    metrics.accumulated_cost = 0.03
    action._llm_metrics = metrics

    # Test serialization
    serialized = event_to_dict(action)
    assert 'llm_metrics' in serialized
    assert serialized['llm_metrics']['accumulated_cost'] == 0.03
    assert serialized['llm_metrics']['costs'] == []
    assert serialized['llm_metrics']['response_latencies'] == []
    assert serialized['llm_metrics']['token_usages'] == []

    # Test deserialization
    deserialized = event_from_dict(serialized)
    assert deserialized.llm_metrics is not None
    assert deserialized.llm_metrics.accumulated_cost == 0.03
    assert len(deserialized.llm_metrics.costs) == 0
    assert len(deserialized.llm_metrics.response_latencies) == 0
    assert len(deserialized.llm_metrics.token_usages) == 0


def test_metrics_full_serialization():
    # Create an observation with all metrics fields
    obs = CmdOutputObservation(
        command='ls',
        content='test.txt',
        metadata=CmdOutputMetadata(exit_code=0),
    )
    metrics = Metrics(model_name='test-model')
    metrics.accumulated_cost = 0.03

    # Add a cost
    cost = Cost(model='test-model', cost=0.02)
    metrics._costs.append(cost)

    # Add a response latency
    latency = ResponseLatency(model='test-model', latency=0.5, response_id='test-id')
    metrics.response_latencies = [latency]

    # Add token usage
    usage = TokenUsage(
        model='test-model',
        prompt_tokens=10,
        completion_tokens=20,
        cache_read_tokens=0,
        cache_write_tokens=0,
        response_id='test-id',
    )
    metrics.token_usages = [usage]

    obs._llm_metrics = metrics

    # Test serialization
    serialized = event_to_dict(obs)
    assert 'llm_metrics' in serialized
    metrics_dict = serialized['llm_metrics']
    assert metrics_dict['accumulated_cost'] == 0.03
    assert len(metrics_dict['costs']) == 1
    assert metrics_dict['costs'][0]['cost'] == 0.02
    assert len(metrics_dict['response_latencies']) == 1
    assert metrics_dict['response_latencies'][0]['latency'] == 0.5
    assert len(metrics_dict['token_usages']) == 1
    assert metrics_dict['token_usages'][0]['prompt_tokens'] == 10
    assert metrics_dict['token_usages'][0]['completion_tokens'] == 20

    # Test deserialization
    deserialized = event_from_dict(serialized)
    assert deserialized.llm_metrics is not None
    assert deserialized.llm_metrics.accumulated_cost == 0.03
    assert len(deserialized.llm_metrics.costs) == 1
    assert deserialized.llm_metrics.costs[0].cost == 0.02
    assert len(deserialized.llm_metrics.response_latencies) == 1
    assert deserialized.llm_metrics.response_latencies[0].latency == 0.5
    assert len(deserialized.llm_metrics.token_usages) == 1
    assert deserialized.llm_metrics.token_usages[0].prompt_tokens == 10
    assert deserialized.llm_metrics.token_usages[0].completion_tokens == 20


def test_metrics_none_serialization():
    # Test when metrics is None
    obs = CmdOutputObservation(
        command='ls',
        content='test.txt',
        metadata=CmdOutputMetadata(exit_code=0),
    )
    obs._llm_metrics = None

    # Test serialization
    serialized = event_to_dict(obs)
    assert 'llm_metrics' not in serialized

    # Test deserialization
    deserialized = event_from_dict(serialized)
    assert deserialized.llm_metrics is None