(llm): Track accumulated token usage instead of per-request token usage (#7511)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Boxuan Li <liboxuan@connect.hku.hk>
Authored by Xingyao Wang on 2025-03-26 09:05:36 -07:00; committed by GitHub
parent 1230b229b5
commit c63d52d5e6
9 changed files with 298 additions and 38 deletions
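
In short: the backend Metrics object now keeps a running TokenUsage total (including cache read/write token counts) alongside the per-call records and exposes it through Metrics.get() as accumulated_token_usage, and the frontend reads that accumulated total from llm_metrics instead of the per-request usage on tool_call_metadata. A minimal usage sketch of the new backend API, assembled from the hunks below (the token counts mirror the unit tests; this sketch is not part of the commit itself):

from openhands.llm.metrics import Metrics

metrics = Metrics(model_name='gpt-4o')
# Each call is recorded individually and also folded into the running total.
# Arguments: prompt, completion, cache read, cache write tokens, response id.
metrics.add_token_usage(10, 5, 3, 2, 'response-1')
metrics.add_token_usage(8, 6, 2, 4, 'response-2')

data = metrics.get()
# data['accumulated_token_usage'] now sums both calls:
# 18 prompt, 11 completion, 5 cache-read and 6 cache-write tokens.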

View File

@@ -240,22 +240,67 @@ export function ConversationCard({
title="Metrics Information"
testID="metrics-modal"
>
<div className="space-y-2">
{metrics?.cost !== null && (
<p>Total Cost: ${metrics.cost.toFixed(4)}</p>
)}
{metrics?.usage !== null && (
<>
<p>Tokens Used:</p>
<ul className="list-inside space-y-1 ml-2">
<li>- Input: {metrics.usage.prompt_tokens}</li>
<li>- Output: {metrics.usage.completion_tokens}</li>
<li>- Total: {metrics.usage.total_tokens}</li>
</ul>
</>
<div className="space-y-4">
{(metrics?.cost !== null || metrics?.usage !== null) && (
<div className="rounded-md p-3">
<div className="grid gap-3">
{metrics?.cost !== null && (
<div className="flex justify-between items-center border-b border-neutral-700 pb-2">
<span className="text-lg font-semibold">
Total Cost (USD):
</span>
<span className="font-semibold">
${metrics.cost.toFixed(4)}
</span>
</div>
)}
{metrics?.usage !== null && (
<>
<div className="flex justify-between items-center pb-2">
<span>Total Input Tokens:</span>
<span className="font-semibold">
{metrics.usage.prompt_tokens.toLocaleString()}
</span>
</div>
<div className="grid grid-cols-2 gap-2 pl-4 text-sm">
<span className="text-neutral-400">Cache Hit:</span>
<span className="text-right">
{metrics.usage.cache_read_tokens.toLocaleString()}
</span>
<span className="text-neutral-400">Cache Write:</span>
<span className="text-right">
{metrics.usage.cache_write_tokens.toLocaleString()}
</span>
</div>
<div className="flex justify-between items-center border-b border-neutral-700 pb-2">
<span>Total Output Tokens:</span>
<span className="font-semibold">
{metrics.usage.completion_tokens.toLocaleString()}
</span>
</div>
<div className="flex justify-between items-center pt-1">
<span className="font-semibold">Total Tokens:</span>
<span className="font-bold">
{(
metrics.usage.prompt_tokens +
metrics.usage.completion_tokens
).toLocaleString()}
</span>
</div>
</>
)}
</div>
</div>
)}
{!metrics?.cost && !metrics?.usage && (
<p className="text-neutral-400">No metrics data available</p>
<div className="rounded-md p-4 text-center">
<p className="text-neutral-400">No metrics data available</p>
</div>
)}
</div>
</BaseModal>

View File

@@ -87,13 +87,10 @@ export function handleActionMessage(message: ActionMessage) {
}
// Update metrics if available
if (
message.llm_metrics ||
message.tool_call_metadata?.model_response?.usage
) {
if (message.llm_metrics) {
const metrics = {
cost: message.llm_metrics?.accumulated_cost ?? null,
usage: message.tool_call_metadata?.model_response?.usage ?? null,
usage: message.llm_metrics?.accumulated_token_usage ?? null,
};
store.dispatch(setMetrics(metrics));
}

View File

@@ -5,7 +5,8 @@ interface MetricsState {
usage: {
prompt_tokens: number;
completion_tokens: number;
total_tokens: number;
cache_read_tokens: number;
cache_write_tokens: number;
} | null;
}

View File

@@ -19,6 +19,12 @@ export interface ActionMessage {
// LLM metrics information
llm_metrics?: {
accumulated_cost: number;
accumulated_token_usage: {
prompt_tokens: number;
completion_tokens: number;
cache_read_tokens: number;
cache_write_tokens: number;
};
};
// Tool call metadata
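
For illustration, the shape of the llm_metrics value that the handleActionMessage handler above reads, written here as a Python literal since the backend side of this payload is Python; the numbers are invented and only the two fields typed in this interface are shown:

llm_metrics = {
    'accumulated_cost': 0.0123,  # running USD total across all calls (invented value)
    'accumulated_token_usage': {
        'prompt_tokens': 18,
        'completion_tokens': 11,
        'cache_read_tokens': 5,
        'cache_write_tokens': 6,
    },
}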

View File

@@ -42,7 +42,6 @@ class ActionType(str, Enum):
"""Delegates a task to another agent.
"""
THINK = 'think'
"""Logs a thought.
"""

View File

@@ -1,5 +1,5 @@
import re
from typing import Any
from openhands.core.exceptions import LLMMalformedActionError
from openhands.events.action.action import Action
from openhands.events.action.agent import (

View File

@@ -79,6 +79,11 @@ def event_from_dict(data: dict[str, Any]) -> 'Event':
metrics.token_usages = [
TokenUsage(**usage) for usage in value.get('token_usages', [])
]
# Set accumulated token usage if available
if 'accumulated_token_usage' in value:
metrics._accumulated_token_usage = TokenUsage(
**value.get('accumulated_token_usage', {})
)
value = metrics
setattr(evt, '_' + key, value)
return evt
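
A simplified, standalone sketch of the restore path this hunk implements, using the dict that Metrics.get() produces (see the metrics hunk below); it illustrates the logic and is not the actual event deserialization code:

from openhands.llm.metrics import Metrics, TokenUsage

source = Metrics(model_name='gpt-4o')
source.add_token_usage(10, 5, 3, 2, 'response-1')
value = source.get()  # contains 'token_usages' and 'accumulated_token_usage'

restored = Metrics(model_name='gpt-4o')
restored.token_usages = [TokenUsage(**usage) for usage in value.get('token_usages', [])]
if 'accumulated_token_usage' in value:
    # Restore the running total so it survives serialization round trips.
    restored._accumulated_token_usage = TokenUsage(**value.get('accumulated_token_usage', {}))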

View File

@@ -20,12 +20,23 @@ class ResponseLatency(BaseModel):
class TokenUsage(BaseModel):
"""Metric tracking detailed token usage per completion call."""
model: str
prompt_tokens: int
completion_tokens: int
cache_read_tokens: int
cache_write_tokens: int
response_id: str
model: str = Field(default='')
prompt_tokens: int = Field(default=0)
completion_tokens: int = Field(default=0)
cache_read_tokens: int = Field(default=0)
cache_write_tokens: int = Field(default=0)
response_id: str = Field(default='')
def __add__(self, other: 'TokenUsage') -> 'TokenUsage':
"""Add two TokenUsage instances together."""
return TokenUsage(
model=self.model,
prompt_tokens=self.prompt_tokens + other.prompt_tokens,
completion_tokens=self.completion_tokens + other.completion_tokens,
cache_read_tokens=self.cache_read_tokens + other.cache_read_tokens,
cache_write_tokens=self.cache_write_tokens + other.cache_write_tokens,
response_id=self.response_id,
)
class Metrics:
@@ -42,6 +53,14 @@ class Metrics:
self._response_latencies: list[ResponseLatency] = []
self.model_name = model_name
self._token_usages: list[TokenUsage] = []
self._accumulated_token_usage: TokenUsage = TokenUsage(
model=model_name,
prompt_tokens=0,
completion_tokens=0,
cache_read_tokens=0,
cache_write_tokens=0,
response_id='',
)
@property
def accumulated_cost(self) -> float:
@@ -99,15 +118,24 @@
response_id: str,
) -> None:
"""Add a single usage record."""
self._token_usages.append(
TokenUsage(
model=self.model_name,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
cache_read_tokens=cache_read_tokens,
cache_write_tokens=cache_write_tokens,
response_id=response_id,
)
usage = TokenUsage(
model=self.model_name,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
cache_read_tokens=cache_read_tokens,
cache_write_tokens=cache_write_tokens,
response_id=response_id,
)
self._token_usages.append(usage)
# Update accumulated token usage using the __add__ operator
self._accumulated_token_usage = self._accumulated_token_usage + TokenUsage(
model=self.model_name,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
cache_read_tokens=cache_read_tokens,
cache_write_tokens=cache_write_tokens,
response_id='',
)
def merge(self, other: 'Metrics') -> None:
@@ -118,10 +146,16 @@
self.token_usages += other.token_usages
self.response_latencies += other.response_latencies
# Merge accumulated token usage using the __add__ operator
self._accumulated_token_usage = (
self._accumulated_token_usage + other._accumulated_token_usage
)
def get(self) -> dict:
"""Return the metrics in a dictionary."""
return {
'accumulated_cost': self._accumulated_cost,
'accumulated_token_usage': self._accumulated_token_usage.model_dump(),
'costs': [cost.model_dump() for cost in self._costs],
'response_latencies': [
latency.model_dump() for latency in self._response_latencies
@@ -134,6 +168,15 @@ class Metrics:
self._costs = []
self._response_latencies = []
self._token_usages = []
# Reset accumulated token usage with a new instance
self._accumulated_token_usage = TokenUsage(
model=self.model_name,
prompt_tokens=0,
completion_tokens=0,
cache_read_tokens=0,
cache_write_tokens=0,
response_id='',
)
def log(self):
"""Log the metrics."""

View File

@@ -13,7 +13,7 @@ from openhands.core.config import LLMConfig
from openhands.core.exceptions import OperationCancelled
from openhands.core.message import Message, TextContent
from openhands.llm.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.llm.metrics import Metrics, TokenUsage
@pytest.fixture(autouse=True)
@@ -45,6 +45,84 @@ def test_llm_init_with_default_config(default_config):
assert llm.metrics.model_name == 'gpt-4o'
def test_token_usage_add():
"""Test that TokenUsage instances can be added together."""
# Create two TokenUsage instances
usage1 = TokenUsage(
model='model1',
prompt_tokens=10,
completion_tokens=5,
cache_read_tokens=3,
cache_write_tokens=2,
response_id='response-1',
)
usage2 = TokenUsage(
model='model2',
prompt_tokens=8,
completion_tokens=6,
cache_read_tokens=2,
cache_write_tokens=4,
response_id='response-2',
)
# Add them together
combined = usage1 + usage2
# Verify the result
assert combined.model == 'model1' # Should keep the model from the first instance
assert combined.prompt_tokens == 18 # 10 + 8
assert combined.completion_tokens == 11 # 5 + 6
assert combined.cache_read_tokens == 5 # 3 + 2
assert combined.cache_write_tokens == 6 # 2 + 4
assert (
combined.response_id == 'response-1'
) # Should keep the response_id from the first instance
def test_metrics_merge_accumulated_token_usage():
"""Test that accumulated token usage is properly merged between two Metrics instances."""
# Create two Metrics instances
metrics1 = Metrics(model_name='model1')
metrics2 = Metrics(model_name='model2')
# Add token usage to each
metrics1.add_token_usage(10, 5, 3, 2, 'response-1')
metrics2.add_token_usage(8, 6, 2, 4, 'response-2')
# Verify initial accumulated token usage
metrics1_data = metrics1.get()
accumulated1 = metrics1_data['accumulated_token_usage']
assert accumulated1['prompt_tokens'] == 10
assert accumulated1['completion_tokens'] == 5
assert accumulated1['cache_read_tokens'] == 3
assert accumulated1['cache_write_tokens'] == 2
metrics2_data = metrics2.get()
accumulated2 = metrics2_data['accumulated_token_usage']
assert accumulated2['prompt_tokens'] == 8
assert accumulated2['completion_tokens'] == 6
assert accumulated2['cache_read_tokens'] == 2
assert accumulated2['cache_write_tokens'] == 4
# Merge metrics2 into metrics1
metrics1.merge(metrics2)
# Verify merged accumulated token usage
merged_data = metrics1.get()
merged_accumulated = merged_data['accumulated_token_usage']
assert merged_accumulated['prompt_tokens'] == 18 # 10 + 8
assert merged_accumulated['completion_tokens'] == 11 # 5 + 6
assert merged_accumulated['cache_read_tokens'] == 5 # 3 + 2
assert merged_accumulated['cache_write_tokens'] == 6 # 2 + 4
# Verify individual token usage records are maintained
token_usages = merged_data['token_usages']
assert len(token_usages) == 2
assert token_usages[0]['response_id'] == 'response-1'
assert token_usages[1]['response_id'] == 'response-2'
@patch('openhands.llm.llm.litellm.get_model_info')
def test_llm_init_with_model_info(mock_get_model_info, default_config):
mock_get_model_info.return_value = {
@@ -140,12 +218,22 @@ def test_llm_reset():
initial_metrics = copy.deepcopy(llm.metrics)
initial_metrics.add_cost(1.0)
initial_metrics.add_response_latency(0.5, 'test-id')
initial_metrics.add_token_usage(10, 5, 3, 2, 'test-id')
llm.reset()
assert llm.metrics.accumulated_cost != initial_metrics.accumulated_cost
assert llm.metrics.costs != initial_metrics.costs
assert llm.metrics.response_latencies != initial_metrics.response_latencies
assert llm.metrics.token_usages != initial_metrics.token_usages
assert isinstance(llm.metrics, Metrics)
# Check that accumulated token usage is reset
metrics_data = llm.metrics.get()
accumulated_usage = metrics_data['accumulated_token_usage']
assert accumulated_usage['prompt_tokens'] == 0
assert accumulated_usage['completion_tokens'] == 0
assert accumulated_usage['cache_read_tokens'] == 0
assert accumulated_usage['cache_write_tokens'] == 0
@patch('openhands.llm.llm.litellm.get_model_info')
def test_llm_init_with_openrouter_model(mock_get_model_info, default_config):
@@ -493,6 +581,82 @@ def test_llm_token_usage(mock_litellm_completion, default_config):
assert usage_entry_2['response_id'] == 'test-response-usage-2'
@patch('openhands.llm.llm.litellm_completion')
def test_accumulated_token_usage(mock_litellm_completion, default_config):
"""Test that token usage is properly accumulated across multiple LLM calls."""
# Mock responses with token usage information
mock_response_1 = {
'id': 'test-response-1',
'choices': [{'message': {'content': 'First response'}}],
'usage': {
'prompt_tokens': 10,
'completion_tokens': 5,
'prompt_tokens_details': PromptTokensDetails(cached_tokens=3),
'model_extra': {'cache_creation_input_tokens': 4},
},
}
mock_response_2 = {
'id': 'test-response-2',
'choices': [{'message': {'content': 'Second response'}}],
'usage': {
'prompt_tokens': 8,
'completion_tokens': 6,
'prompt_tokens_details': PromptTokensDetails(cached_tokens=2),
'model_extra': {'cache_creation_input_tokens': 3},
},
}
# Set up the mock to return these responses in sequence
mock_litellm_completion.side_effect = [mock_response_1, mock_response_2]
# Create LLM instance
llm = LLM(config=default_config)
# First call
llm.completion(messages=[{'role': 'user', 'content': 'First message'}])
# Check accumulated token usage after first call
metrics_data = llm.metrics.get()
accumulated_usage = metrics_data['accumulated_token_usage']
assert accumulated_usage['prompt_tokens'] == 10
assert accumulated_usage['completion_tokens'] == 5
assert accumulated_usage['cache_read_tokens'] == 3
assert accumulated_usage['cache_write_tokens'] == 4
# Second call
llm.completion(messages=[{'role': 'user', 'content': 'Second message'}])
# Check accumulated token usage after second call
metrics_data = llm.metrics.get()
accumulated_usage = metrics_data['accumulated_token_usage']
# Values should be the sum of both calls
assert accumulated_usage['prompt_tokens'] == 18 # 10 + 8
assert accumulated_usage['completion_tokens'] == 11 # 5 + 6
assert accumulated_usage['cache_read_tokens'] == 5 # 3 + 2
assert accumulated_usage['cache_write_tokens'] == 7 # 4 + 3
# Verify individual token usage records are still maintained
token_usages = metrics_data['token_usages']
assert len(token_usages) == 2
# First record
assert token_usages[0]['prompt_tokens'] == 10
assert token_usages[0]['completion_tokens'] == 5
assert token_usages[0]['cache_read_tokens'] == 3
assert token_usages[0]['cache_write_tokens'] == 4
assert token_usages[0]['response_id'] == 'test-response-1'
# Second record
assert token_usages[1]['prompt_tokens'] == 8
assert token_usages[1]['completion_tokens'] == 6
assert token_usages[1]['cache_read_tokens'] == 2
assert token_usages[1]['cache_write_tokens'] == 3
assert token_usages[1]['response_id'] == 'test-response-2'
@patch('openhands.llm.llm.litellm_completion')
def test_completion_with_log_completions(mock_litellm_completion, default_config):
with tempfile.TemporaryDirectory() as temp_dir: