From 5e521a4a6ecf793ee5e7fe45b5653598d201c830 Mon Sep 17 00:00:00 2001
From: AutoLTX
Date: Tue, 11 Mar 2025 19:18:08 +0800
Subject: [PATCH] Expose accumulate and llm_metric from eventstream (backend)
 (#7082)

Co-authored-by: Engel Nyst
---
 openhands/controller/agent_controller.py |  43 ++++++++++
 openhands/events/serialization/event.py  |  15 +++-
 tests/unit/test_agent_controller.py      | 104 ++++++++++++++++++++++-
 tests/unit/test_event_serialization.py   | 103 +++++++++++++++++++++-
 4 files changed, 262 insertions(+), 3 deletions(-)

diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py
index 59f17804de..61bebacfc9 100644
--- a/openhands/controller/agent_controller.py
+++ b/openhands/controller/agent_controller.py
@@ -53,6 +53,7 @@ from openhands.events.observation import (
 )
 from openhands.events.serialization.event import event_to_trajectory, truncate_content
 from openhands.llm.llm import LLM
+from openhands.llm.metrics import Metrics, TokenUsage
 
 # note: RESUME is only available on web GUI
 TRAFFIC_CONTROL_REMINDER = (
@@ -733,6 +734,10 @@ class AgentController:
                 == ActionConfirmationStatus.AWAITING_CONFIRMATION
             ):
                 await self.set_agent_state_to(AgentState.AWAITING_USER_CONFIRMATION)
+
+        # Create and log metrics for frontend display
+        self._prepare_metrics_for_frontend(action)
+
         self.event_stream.add_event(action, action._source)  # type: ignore [attr-defined]
 
         await self.update_state_after_step()
@@ -1086,6 +1091,44 @@ class AgentController:
 
         return self._stuck_detector.is_stuck(self.headless_mode)
 
+    def _prepare_metrics_for_frontend(self, action: Action) -> None:
+        """Create a minimal metrics object for frontend display and log it.
+
+        To avoid performance issues with long conversations, we only keep:
+        - accumulated_cost: The current total cost
+        - latest token_usage: Token statistics from the most recent API call
+
+        Args:
+            action: The action to attach metrics to
+        """
+        metrics = Metrics(model_name=self.agent.llm.metrics.model_name)
+        metrics.accumulated_cost = self.agent.llm.metrics.accumulated_cost
+        if self.agent.llm.metrics.token_usages:
+            latest_usage = self.agent.llm.metrics.token_usages[-1]
+            metrics.add_token_usage(
+                prompt_tokens=latest_usage.prompt_tokens,
+                completion_tokens=latest_usage.completion_tokens,
+                cache_read_tokens=latest_usage.cache_read_tokens,
+                cache_write_tokens=latest_usage.cache_write_tokens,
+                response_id=latest_usage.response_id,
+            )
+        action.llm_metrics = metrics
+
+        # Log the metrics information for frontend display
+        log_usage: TokenUsage | None = (
+            metrics.token_usages[-1] if metrics.token_usages else None
+        )
+        self.log(
+            'debug',
+            f'Action metrics - accumulated_cost: {metrics.accumulated_cost}, '
+            f'tokens (prompt/completion/cache_read/cache_write): '
+            f'{log_usage.prompt_tokens if log_usage else 0}/'
+            f'{log_usage.completion_tokens if log_usage else 0}/'
+            f'{log_usage.cache_read_tokens if log_usage else 0}/'
+            f'{log_usage.cache_write_tokens if log_usage else 0}',
+            extra={'msg_type': 'METRICS'},
+        )
+
     def __repr__(self):
         return (
             f'AgentController(id={getattr(self, "id", "")}, '
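
The controller change above intentionally snapshots only two things per action: the running accumulated_cost and the most recent TokenUsage; the full cost and latency histories are dropped so long conversations stay cheap to serialize. A rough sketch of the attached object follows (illustrative only: the model name, numbers, and response id are placeholders, not values from this patch):

from openhands.llm.metrics import Metrics

snapshot = Metrics(model_name='gpt-4o')  # placeholder model name
snapshot.accumulated_cost = 0.07         # running total across the whole conversation
snapshot.add_token_usage(                # stats for the latest LLM call only
    prompt_tokens=10,
    completion_tokens=20,
    cache_read_tokens=5,
    cache_write_tokens=5,
    response_id='resp-123',              # placeholder response id
)
# _prepare_metrics_for_frontend then attaches it: action.llm_metrics = snapshot
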
diff --git a/openhands/events/serialization/event.py b/openhands/events/serialization/event.py
index 71f591fd7d..420c355943 100644
--- a/openhands/events/serialization/event.py
+++ b/openhands/events/serialization/event.py
@@ -9,6 +9,7 @@ from openhands.events.serialization.action import action_from_dict
 from openhands.events.serialization.observation import observation_from_dict
 from openhands.events.serialization.utils import remove_fields
 from openhands.events.tool import ToolCallMetadata
+from openhands.llm.metrics import Cost, Metrics, ResponseLatency, TokenUsage
 
 # TODO: move `content` into `extras`
 TOP_KEYS = [
@@ -20,8 +21,9 @@ TOP_KEYS = [
     'action',
     'observation',
     'tool_call_metadata',
+    'llm_metrics',
 ]
-UNDERSCORE_KEYS = ['id', 'timestamp', 'source', 'cause', 'tool_call_metadata']
+UNDERSCORE_KEYS = ['id', 'timestamp', 'source', 'cause', 'tool_call_metadata', 'llm_metrics']
 
 DELETE_FROM_TRAJECTORY_EXTRAS = {
     'screenshot',
@@ -54,6 +56,15 @@ def event_from_dict(data) -> 'Event':
            value = EventSource(value)
        if key == 'tool_call_metadata':
            value = ToolCallMetadata(**value)
+       if key == 'llm_metrics':
+           metrics = Metrics()
+           if isinstance(value, dict):
+               metrics.accumulated_cost = value.get('accumulated_cost', 0.0)
+               for cost in value.get('costs', []):
+                   metrics._costs.append(Cost(**cost))
+               metrics.response_latencies = [ResponseLatency(**latency) for latency in value.get('response_latencies', [])]
+               metrics.token_usages = [TokenUsage(**usage) for usage in value.get('token_usages', [])]
+           value = metrics
        setattr(evt, '_' + key, value)
    return evt
 
@@ -81,6 +92,8 @@ def event_to_dict(event: 'Event') -> dict:
            d['source'] = d['source'].value
        if key == 'tool_call_metadata' and 'tool_call_metadata' in d:
            d['tool_call_metadata'] = d['tool_call_metadata'].model_dump()
+       if key == 'llm_metrics' and 'llm_metrics' in d:
+           d['llm_metrics'] = d['llm_metrics'].get()
        props.pop(key, None)
    if 'security_risk' in props and props['security_risk'] is None:
        props.pop('security_risk')
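
With the two serialization hooks above, llm_metrics survives the event_to_dict / event_from_dict round trip. A minimal sketch, mirroring test_metrics_basic_serialization below:

from openhands.events.action import MessageAction
from openhands.events.serialization import event_from_dict, event_to_dict
from openhands.llm.metrics import Metrics

action = MessageAction(content='Hello, world!')
metrics = Metrics()
metrics.accumulated_cost = 0.03
action._llm_metrics = metrics

data = event_to_dict(action)      # data['llm_metrics'] is the plain dict from Metrics.get()
restored = event_from_dict(data)  # restored.llm_metrics is a Metrics instance again
assert restored.llm_metrics.accumulated_cost == 0.03
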
diff --git a/tests/unit/test_agent_controller.py b/tests/unit/test_agent_controller.py
index a0a350d4d8..1d8c7bb8d5 100644
--- a/tests/unit/test_agent_controller.py
+++ b/tests/unit/test_agent_controller.py
@@ -19,7 +19,7 @@ from openhands.events.observation import (
 )
 from openhands.events.serialization import event_to_dict
 from openhands.llm import LLM
-from openhands.llm.metrics import Metrics
+from openhands.llm.metrics import Metrics, TokenUsage
 from openhands.runtime.base import Runtime
 from openhands.storage.memory import InMemoryFileStore
 
@@ -749,3 +749,105 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
 
     # Check that the context window exceeded error was raised during the run
     assert step_state.has_errored
+
+
+@pytest.mark.asyncio
+async def test_action_metrics_copy():
+    # Setup
+    file_store = InMemoryFileStore({})
+    event_stream = EventStream(sid='test', file_store=file_store)
+
+    # Create agent with metrics
+    agent = MagicMock(spec=Agent)
+    agent.llm = MagicMock(spec=LLM)
+    metrics = Metrics(model_name='test-model')
+    metrics.accumulated_cost = 0.05
+
+    # Add multiple token usages - we should get the last one in the action
+    usage1 = TokenUsage(
+        model='test-model',
+        prompt_tokens=5,
+        completion_tokens=10,
+        cache_read_tokens=2,
+        cache_write_tokens=2,
+        response_id='test-id-1',
+    )
+
+    usage2 = TokenUsage(
+        model='test-model',
+        prompt_tokens=10,
+        completion_tokens=20,
+        cache_read_tokens=5,
+        cache_write_tokens=5,
+        response_id='test-id-2',
+    )
+
+    metrics.token_usages = [usage1, usage2]
+
+    # Add a cost instance - should not be included in action metrics
+    # This will increase accumulated_cost by 0.02
+    metrics.add_cost(0.02)
+
+    # Add a response latency - should not be included in action metrics
+    metrics.add_response_latency(0.5, 'test-id-2')
+
+    agent.llm.metrics = metrics
+
+    # Mock agent step to return an action
+    action = MessageAction(content='Test message')
+
+    def agent_step_fn(state):
+        return action
+
+    agent.step = agent_step_fn
+
+    # Create controller with correct parameters
+    controller = AgentController(
+        agent=agent,
+        event_stream=event_stream,
+        max_iterations=10,
+        sid='test',
+        confirmation_mode=False,
+        headless_mode=True,
+    )
+
+    # Execute one step
+    controller.state.agent_state = AgentState.RUNNING
+    await controller._step()
+
+    # Get the last event from event stream
+    events = list(event_stream.get_events())
+    assert len(events) > 0
+    last_action = events[-1]
+
+    # Verify metrics were copied correctly
+    assert last_action.llm_metrics is not None
+    assert (
+        last_action.llm_metrics.accumulated_cost == 0.07
+    )  # 0.05 initial + 0.02 from add_cost
+
+    # Should include the last token usage
+    assert len(last_action.llm_metrics.token_usages) == 1
+    assert last_action.llm_metrics.token_usages[0].prompt_tokens == 10
+    assert last_action.llm_metrics.token_usages[0].completion_tokens == 20
+    assert last_action.llm_metrics.token_usages[0].cache_read_tokens == 5
+    assert last_action.llm_metrics.token_usages[0].cache_write_tokens == 5
+    assert last_action.llm_metrics.token_usages[0].response_id == 'test-id-2'
+
+    # Should not include the cost history
+    assert len(last_action.llm_metrics.costs) == 0
+
+    # Should not include the response latency history
+    assert len(last_action.llm_metrics.response_latencies) == 0
+
+    # Verify that there's no latency information in the action's metrics
+    # Either directly or as a calculated property
+    assert not hasattr(last_action.llm_metrics, 'latency')
+    assert not hasattr(last_action.llm_metrics, 'total_latency')
+    assert not hasattr(last_action.llm_metrics, 'average_latency')
+
+    # Verify it's a deep copy by modifying the original
+    agent.llm.metrics.accumulated_cost = 0.1
+    assert last_action.llm_metrics.accumulated_cost == 0.07
+
+    await controller.close()
diff --git a/tests/unit/test_event_serialization.py b/tests/unit/test_event_serialization.py
index 52df3ef1d0..2d6b837660 100644
--- a/tests/unit/test_event_serialization.py
+++ b/tests/unit/test_event_serialization.py
@@ -1,5 +1,7 @@
+from openhands.events.action import MessageAction
 from openhands.events.observation import CmdOutputMetadata, CmdOutputObservation
-from openhands.events.serialization import event_to_dict
+from openhands.events.serialization import event_from_dict, event_to_dict
+from openhands.llm.metrics import Cost, Metrics, ResponseLatency, TokenUsage
 
 
 def test_command_output_success_serialization():
@@ -20,3 +22,102 @@ def test_command_output_success_serialization():
     )
     serialized = event_to_dict(obs)
     assert serialized['success'] is False
+
+
+def test_metrics_basic_serialization():
+    # Create a basic action with only accumulated_cost
+    action = MessageAction(content='Hello, world!')
+    metrics = Metrics()
+    metrics.accumulated_cost = 0.03
+    action._llm_metrics = metrics
+
+    # Test serialization
+    serialized = event_to_dict(action)
+    assert 'llm_metrics' in serialized
+    assert serialized['llm_metrics']['accumulated_cost'] == 0.03
+    assert serialized['llm_metrics']['costs'] == []
+    assert serialized['llm_metrics']['response_latencies'] == []
+    assert serialized['llm_metrics']['token_usages'] == []
+
+    # Test deserialization
+    deserialized = event_from_dict(serialized)
+    assert deserialized.llm_metrics is not None
+    assert deserialized.llm_metrics.accumulated_cost == 0.03
+    assert len(deserialized.llm_metrics.costs) == 0
+    assert len(deserialized.llm_metrics.response_latencies) == 0
+    assert len(deserialized.llm_metrics.token_usages) == 0
+
+
+def test_metrics_full_serialization():
+    # Create an observation with all metrics fields
+    obs = CmdOutputObservation(
+        command='ls',
+        content='test.txt',
+        metadata=CmdOutputMetadata(exit_code=0),
+    )
+    metrics = Metrics(model_name='test-model')
+    metrics.accumulated_cost = 0.03
+
+    # Add a cost
+    cost = Cost(model='test-model', cost=0.02)
+    metrics._costs.append(cost)
+
+    # Add a response latency
+    latency = ResponseLatency(model='test-model', latency=0.5, response_id='test-id')
+    metrics.response_latencies = [latency]
+
+    # Add token usage
+    usage = TokenUsage(
+        model='test-model',
+        prompt_tokens=10,
+        completion_tokens=20,
+        cache_read_tokens=0,
+        cache_write_tokens=0,
+        response_id='test-id',
+    )
+    metrics.token_usages = [usage]
+
+    obs._llm_metrics = metrics
+
+    # Test serialization
+    serialized = event_to_dict(obs)
+    assert 'llm_metrics' in serialized
+    metrics_dict = serialized['llm_metrics']
+    assert metrics_dict['accumulated_cost'] == 0.03
+    assert len(metrics_dict['costs']) == 1
+    assert metrics_dict['costs'][0]['cost'] == 0.02
+    assert len(metrics_dict['response_latencies']) == 1
+    assert metrics_dict['response_latencies'][0]['latency'] == 0.5
+    assert len(metrics_dict['token_usages']) == 1
+    assert metrics_dict['token_usages'][0]['prompt_tokens'] == 10
+    assert metrics_dict['token_usages'][0]['completion_tokens'] == 20
+
+    # Test deserialization
+    deserialized = event_from_dict(serialized)
+    assert deserialized.llm_metrics is not None
+    assert deserialized.llm_metrics.accumulated_cost == 0.03
+    assert len(deserialized.llm_metrics.costs) == 1
+    assert deserialized.llm_metrics.costs[0].cost == 0.02
+    assert len(deserialized.llm_metrics.response_latencies) == 1
+    assert deserialized.llm_metrics.response_latencies[0].latency == 0.5
+    assert len(deserialized.llm_metrics.token_usages) == 1
+    assert deserialized.llm_metrics.token_usages[0].prompt_tokens == 10
+    assert deserialized.llm_metrics.token_usages[0].completion_tokens == 20
+
+
+def test_metrics_none_serialization():
+    # Test when metrics is None
+    obs = CmdOutputObservation(
+        command='ls',
+        content='test.txt',
+        metadata=CmdOutputMetadata(exit_code=0),
+    )
+    obs._llm_metrics = None
+
+    # Test serialization
+    serialized = event_to_dict(obs)
+    assert 'llm_metrics' not in serialized
+
+    # Test deserialization
+    deserialized = event_from_dict(serialized)
+    assert deserialized.llm_metrics is None
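
The new coverage can be exercised in isolation with pytest's usual selection, e.g. `pytest tests/unit/test_agent_controller.py::test_action_metrics_copy` and `pytest tests/unit/test_event_serialization.py -k metrics` (exact invocation depends on the local dev setup).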