mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
202 lines
6.9 KiB
Python
202 lines
6.9 KiB
Python
from openhands.core.message_utils import (
|
|
get_token_usage_for_event,
|
|
get_token_usage_for_event_id,
|
|
)
|
|
from openhands.events.event import Event
|
|
from openhands.events.tool import ToolCallMetadata
|
|
from openhands.llm.metrics import Metrics, TokenUsage
|
|
|
|
|
|
def test_get_token_usage_for_event():
    """Look up a usage record through an event's tool-call metadata.

    A usage entry is registered in ``Metrics`` under ``'test-response-id'``.
    The lookup must return that entry while the event's
    ``tool_call_metadata.model_response.id`` carries the same id, and must
    return ``None`` once the id is changed or the metadata is removed.
    """
    tracker = Metrics(model_name='test-model')
    expected = TokenUsage(
        model='test-model',
        prompt_tokens=10,
        completion_tokens=5,
        cache_read_tokens=2,
        cache_write_tokens=1,
        response_id='test-response-id',
    )
    tracker.add_token_usage(
        prompt_tokens=expected.prompt_tokens,
        completion_tokens=expected.completion_tokens,
        cache_read_tokens=expected.cache_read_tokens,
        cache_write_tokens=expected.cache_write_tokens,
        context_window=1000,
        response_id=expected.response_id,
    )

    # Tool-call metadata whose model response points at the registered id.
    call_meta = ToolCallMetadata(
        tool_call_id='test-tool-call',
        function_name='fake_function',
        model_response={'id': 'test-response-id'},
        total_calls_in_response=1,
    )
    action_event = Event()
    # The private attribute is written directly here; regular code would
    # assign through event.tool_call_metadata instead.
    action_event._tool_call_metadata = call_meta

    # Matching response id -> the registered usage record comes back.
    match = get_token_usage_for_event(action_event, tracker)
    assert match is not None
    assert match.prompt_tokens == 10
    assert match.response_id == 'test-response-id'

    # A response id that no usage record carries -> nothing is found.
    call_meta.model_response.id = 'some-other-id'
    assert get_token_usage_for_event(action_event, tracker) is None

    # No tool-call metadata at all -> nothing is found either.
    action_event._tool_call_metadata = None
    assert get_token_usage_for_event(action_event, tracker) is None
|
|
|
|
|
|
def test_get_token_usage_for_event_id():
    """Search backward from a given event id for the nearest usage record.

    Five events (ids 0-4) are built; event 1 references ``'resp-1'`` and
    event 3 references ``'resp-2'``. Starting the search at an event returns
    the usage linked to that event or, failing that, to the closest earlier
    event; ``None`` is returned when no such event exists.
    """
    tracker = Metrics(model_name='test-model')
    first_usage = TokenUsage(
        model='test-model',
        prompt_tokens=12,
        completion_tokens=3,
        cache_read_tokens=2,
        cache_write_tokens=5,
        response_id='resp-1',
    )
    second_usage = TokenUsage(
        model='test-model',
        prompt_tokens=7,
        completion_tokens=2,
        cache_read_tokens=1,
        cache_write_tokens=3,
        response_id='resp-2',
    )
    for record in (first_usage, second_usage):
        tracker._token_usages.append(record)

    # event id -> (tool_call_id, function_name, response id); every other
    # event is left without tool-call metadata.
    linked_calls = {
        1: ('tid1', 'fn1', 'resp-1'),
        3: ('tid2', 'fn2', 'resp-2'),
    }
    history = []
    for idx in range(5):
        evt = Event()
        evt._id = idx
        if idx in linked_calls:
            call_id, fn_name, resp_id = linked_calls[idx]
            evt._tool_call_metadata = ToolCallMetadata(
                tool_call_id=call_id,
                function_name=fn_name,
                model_response={'id': resp_id},
                total_calls_in_response=1,
            )
        history.append(evt)

    # Event 3 carries 'resp-2' itself, so that record is returned directly.
    from_third = get_token_usage_for_event_id(history, 3, tracker)
    assert from_third is not None
    assert from_third.response_id == 'resp-2'

    # Event 2 has no usage of its own; the search falls back to event 1.
    from_second = get_token_usage_for_event_id(history, 2, tracker)
    assert from_second is not None
    assert from_second.response_id == 'resp-1'

    # Neither event 0 nor anything before it has usage -> None.
    assert get_token_usage_for_event_id(history, 0, tracker) is None
|
|
|
|
|
|
def test_get_token_usage_for_event_fallback():
    """Check the fallback from tool-call metadata to ``event.response_id``.

    The event's ``tool_call_metadata.model_response.id`` matches no usage
    record, but the event's own response id is ``'fallback-response-id'``,
    which does. The lookup is expected to return that record.
    """
    tracker = Metrics(model_name='fallback-test')
    expected = TokenUsage(
        model='fallback-test',
        prompt_tokens=22,
        completion_tokens=8,
        cache_read_tokens=3,
        cache_write_tokens=2,
        response_id='fallback-response-id',
    )
    tracker.add_token_usage(
        prompt_tokens=expected.prompt_tokens,
        completion_tokens=expected.completion_tokens,
        cache_read_tokens=expected.cache_read_tokens,
        cache_write_tokens=expected.cache_write_tokens,
        context_window=1000,
        response_id=expected.response_id,
    )

    # Tool-call metadata that deliberately points at an unknown response id.
    stale_meta = ToolCallMetadata(
        tool_call_id='irrelevant-tool-call',
        function_name='fake_function',
        model_response={'id': 'not-matching-any-usage'},
        total_calls_in_response=1,
    )
    action_event = Event()
    action_event._tool_call_metadata = stale_meta
    # The event-level response id is the one that actually has usage.
    action_event._response_id = 'fallback-response-id'

    match = get_token_usage_for_event(action_event, tracker)
    assert match is not None
    assert match.prompt_tokens == 22
    assert match.response_id == 'fallback-response-id'
|
|
|
|
|
|
def test_get_token_usage_for_event_id_fallback():
    """Check that the id-based search also falls back to ``event.response_id``.

    Event 1 has tool-call metadata whose model response id matches nothing,
    while its own response id is ``'resp-fallback'``. Searching backward from
    event 2 is expected to return the usage stored under that id.
    """
    # NOTE: a mismatch like this is not supposed to occur, but a comment in
    # message_utils.py (line 166: "(overwrites any previous message with the
    # same response_id)") hints that it might, so it is handled gracefully.
    tracker = Metrics(model_name='fallback-test')
    expected = TokenUsage(
        model='fallback-test',
        prompt_tokens=15,
        completion_tokens=4,
        cache_read_tokens=1,
        cache_write_tokens=0,
        response_id='resp-fallback',
    )
    tracker.token_usages.append(expected)

    # Three plain events with ids 0, 1 and 2.
    history = []
    for idx in range(3):
        evt = Event()
        evt._id = idx
        history.append(evt)

    # Only the middle event is decorated: its tool-call metadata is
    # mismatched, but its top-level response id is the correct one.
    middle = history[1]
    middle._tool_call_metadata = ToolCallMetadata(
        tool_call_id='tool-123',
        function_name='whatever',
        model_response={'id': 'no-such-response'},
        total_calls_in_response=1,
    )
    middle._response_id = 'resp-fallback'

    # Starting at event 2, the search walks back to event 1 and uses its
    # fallback response id.
    match = get_token_usage_for_event_id(history, 2, tracker)
    assert match is not None
    assert match.response_id == 'resp-fallback'
    assert match.prompt_tokens == 15
|