Fix: expose aggregated LLM metrics in State for evaluation scripts (#10537)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in: parent e9e2c98946, commit 91d3d1d20a
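The change below is mechanical across the evaluation scripts: each script stops reading State.metrics directly and calls the shared get_metrics helper from evaluation.utils.shared instead. A minimal before/after sketch of the repeated pattern (names taken from the hunks that follow):

    # before: read State.metrics directly, may yield None
    metrics = state.metrics.get() if state.metrics else None

    # after: shared helper that prefers ConversationStats and falls back to state.metrics
    from evaluation.utils.shared import get_metrics
    metrics = get_metrics(state)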
@@ -10,6 +10,7 @@ from evaluation.utils.shared import (
     EvalOutput,
     compatibility_for_eval_history_pairs,
     get_default_sandbox_config_for_eval,
+    get_metrics,
     make_metadata,
     prepare_dataset,
     reset_logger_for_multiprocessing,
@@ -146,7 +147,7 @@ def process_instance(

     logger.info(f'Final message: {final_message} | Ground truth: {instance["text"]}')
     test_result = game.reward()
-    metrics = state.metrics.get() if state.metrics else None
+    metrics = get_metrics(state)

     # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
     # for compatibility with the existing output format, we can remake the pairs here
@@ -18,6 +18,7 @@ from evaluation.utils.shared import (
     EvalOutput,
     compatibility_for_eval_history_pairs,
     get_default_sandbox_config_for_eval,
+    get_metrics,
     make_metadata,
     prepare_dataset,
     reset_logger_for_multiprocessing,
@@ -273,7 +274,7 @@ def process_instance(
     # remove when it becomes unnecessary
     histories = compatibility_for_eval_history_pairs(state.history)

-    metrics = state.metrics.get() if state.metrics else None
+    metrics = get_metrics(state)

     # Save the output
     output = EvalOutput(
@@ -17,6 +17,7 @@ from evaluation.utils.shared import (
     EvalOutput,
     compatibility_for_eval_history_pairs,
     get_default_sandbox_config_for_eval,
+    get_metrics,
     make_metadata,
     prepare_dataset,
     reset_logger_for_multiprocessing,
@@ -246,7 +247,7 @@ def process_instance(
     # for compatibility with the existing output format, we can remake the pairs here
     # remove when it becomes unnecessary
     histories = compatibility_for_eval_history_pairs(state.history)
-    metrics = state.metrics.get() if state.metrics else None
+    metrics = get_metrics(state)

     # Save the output
     output = EvalOutput(
@@ -15,6 +15,7 @@ from evaluation.utils.shared import (
     codeact_user_response,
     compatibility_for_eval_history_pairs,
     get_default_sandbox_config_for_eval,
+    get_metrics,
     make_metadata,
     prepare_dataset,
     reset_logger_for_multiprocessing,
@@ -294,7 +295,7 @@ def process_instance(
         raise ValueError('State should not be None.')

     test_result = complete_runtime(runtime, instance)
-    metrics = state.metrics.get() if state.metrics else None
+    metrics = get_metrics(state)
     # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
     # for compatibility with the existing output format, we can remake the pairs here
     # remove when it becomes unnecessary
@@ -18,6 +18,7 @@ from evaluation.utils.shared import (
     EvalOutput,
     compatibility_for_eval_history_pairs,
     get_default_sandbox_config_for_eval,
+    get_metrics,
     make_metadata,
     prepare_dataset,
     reset_logger_for_multiprocessing,
@@ -422,7 +423,7 @@ def process_instance(
     # You can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
     if state is None:
         raise ValueError('State should not be None.')
-    metrics = state.metrics.get() if state.metrics else None
+    metrics = get_metrics(state)

     # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
     # for compatibility with the existing output format, we can remake the pairs here
@@ -11,6 +11,7 @@ from evaluation.utils.shared import (
     EvalOutput,
     compatibility_for_eval_history_pairs,
     get_default_sandbox_config_for_eval,
+    get_metrics,
     make_metadata,
     prepare_dataset,
     reset_logger_for_multiprocessing,
@@ -88,7 +89,7 @@ def process_instance(
     if state is None:
         raise ValueError('State should not be None.')

-    metrics = state.metrics.get() if state.metrics else None
+    metrics = get_metrics(state)
     # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
     # for compatibility with the existing output format, we can remake the pairs here
     # remove when it becomes unnecessary
@@ -16,6 +16,7 @@ from evaluation.utils.shared import (
     assert_and_raise,
     codeact_user_response,
     get_default_sandbox_config_for_eval,
+    get_metrics,
     make_metadata,
     prepare_dataset,
     reset_logger_for_multiprocessing,
@@ -480,7 +481,7 @@ def process_instance(

     # NOTE: this is NO LONGER the event stream, but an agent history that includes delegate agent's events
     histories = [event_to_dict(event) for event in state.history]
-    metrics = state.metrics.get() if state.metrics else None
+    metrics = get_metrics(state)

     # Save the output
     output = EvalOutput(
@@ -17,6 +17,7 @@ from evaluation.utils.shared import (
     codeact_user_response,
     compatibility_for_eval_history_pairs,
     get_default_sandbox_config_for_eval,
+    get_metrics,
     make_metadata,
     prepare_dataset,
     reset_logger_for_multiprocessing,
@@ -294,7 +295,7 @@ def process_instance(
     if state is None:
         raise ValueError('State should not be None.')

-    metrics = state.metrics.get() if state.metrics else None
+    metrics = get_metrics(state)
     test_result = complete_runtime(state)

     # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
@@ -22,6 +22,7 @@ from evaluation.utils.shared import (
     codeact_user_response,
     compatibility_for_eval_history_pairs,
     get_default_sandbox_config_for_eval,
+    get_metrics,
     make_metadata,
     prepare_dataset,
     reset_logger_for_multiprocessing,
@@ -269,7 +270,7 @@ Here is the task:
         'model_answer': model_answer,
         'ground_truth': instance['Final answer'],
     }
-    metrics = state.metrics.get() if state.metrics else None
+    metrics = get_metrics(state)

     # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
     # for compatibility with the existing output format, we can remake the pairs here
@@ -12,6 +12,7 @@ from evaluation.utils.shared import (
     codeact_user_response,
     compatibility_for_eval_history_pairs,
     get_default_sandbox_config_for_eval,
+    get_metrics,
     make_metadata,
     prepare_dataset,
     reset_logger_for_multiprocessing,
@@ -108,7 +109,7 @@ def process_instance(
     # attempt to parse model_answer
     ast_eval_fn = instance['ast_eval']
     correct, hallucination = ast_eval_fn(instance_id, model_answer_raw)
-    metrics = state.metrics.get() if state.metrics else None
+    metrics = get_metrics(state)
     logger.info(
         f'Final message: {model_answer_raw} | Correctness: {correct} | Hallucination: {hallucination}'
     )
@@ -30,6 +30,7 @@ from evaluation.utils.shared import (
     EvalOutput,
     compatibility_for_eval_history_pairs,
     get_default_sandbox_config_for_eval,
+    get_metrics,
     make_metadata,
     prepare_dataset,
     reset_logger_for_multiprocessing,
@@ -292,7 +293,7 @@ Ok now its time to start solving the question. Good luck!
     if state is None:
         raise ValueError('State should not be None.')

-    metrics = state.metrics.get() if state.metrics else None
+    metrics = get_metrics(state)

     # Save the output
     output = EvalOutput(
@@ -23,6 +23,7 @@ from evaluation.utils.shared import (
     codeact_user_response,
     compatibility_for_eval_history_pairs,
     get_default_sandbox_config_for_eval,
+    get_metrics,
     make_metadata,
     prepare_dataset,
     reset_logger_for_multiprocessing,
@@ -248,7 +249,7 @@ def process_instance(

     if state is None:
         raise ValueError('State should not be None.')
-    metrics = state.metrics.get() if state.metrics else None
+    metrics = get_metrics(state)
     test_result = complete_runtime(runtime, instance)

     # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
@@ -22,6 +22,7 @@ from evaluation.utils.shared import (
     codeact_user_response,
     compatibility_for_eval_history_pairs,
     get_default_sandbox_config_for_eval,
+    get_metrics,
     make_metadata,
     prepare_dataset,
     reset_logger_for_multiprocessing,
@@ -335,7 +336,7 @@ Be thorough in your exploration, testing, and reasoning. It's fine if your think
         )
     )
     assert state is not None
-    metrics = state.metrics.get() if state.metrics else {}
+    metrics = get_metrics(state)

     test_result = complete_runtime(runtime, instance)

@@ -10,6 +10,7 @@ from evaluation.utils.shared import (
     codeact_user_response,
     compatibility_for_eval_history_pairs,
     get_default_sandbox_config_for_eval,
+    get_metrics,
     make_metadata,
     prepare_dataset,
     reset_logger_for_multiprocessing,
@@ -247,7 +248,7 @@ def process_instance(
     )
     test_result['final_message'] = final_message

-    metrics = state.metrics.get() if state.metrics else None
+    metrics = get_metrics(state)
     # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
     # for compatibility with the existing output format, we can remake the pairs here
     # remove when it becomes unnecessary
@@ -13,6 +13,7 @@ from evaluation.utils.shared import (
     codeact_user_response,
     compatibility_for_eval_history_pairs,
     get_default_sandbox_config_for_eval,
+    get_metrics,
     make_metadata,
     prepare_dataset,
     reset_logger_for_multiprocessing,
@@ -174,7 +175,7 @@ def process_instance(
     if state is None:
         raise ValueError('State should not be None.')

-    metrics = state.metrics.get() if state.metrics else None
+    metrics = get_metrics(state)

     # Instruction is the first message from the USER
     instruction = ''
@@ -15,6 +15,7 @@ from evaluation.utils.shared import (
     EvalOutput,
     compatibility_for_eval_history_pairs,
     get_default_sandbox_config_for_eval,
+    get_metrics,
     make_metadata,
     prepare_dataset,
     reset_logger_for_multiprocessing,
@@ -205,7 +206,7 @@ def process_instance(
     task_state = state.extra_data['task_state']
     logger.info('Task state: ' + str(task_state.to_dict()))

-    metrics = state.metrics.get() if state.metrics else None
+    metrics = get_metrics(state)

     # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
     # for compatibility with the existing output format, we can remake the pairs here
@@ -26,6 +26,7 @@ from evaluation.utils.shared import (
     codeact_user_response,
     compatibility_for_eval_history_pairs,
     get_default_sandbox_config_for_eval,
+    get_metrics,
     make_metadata,
     prepare_dataset,
     reset_logger_for_multiprocessing,
@@ -250,7 +251,7 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool =
         )
     )
     assert state is not None
-    metrics = state.metrics.get() if state.metrics else {}
+    metrics = get_metrics(state)

     test_result = complete_runtime(runtime)

@@ -12,6 +12,7 @@ from evaluation.utils.shared import (
     codeact_user_response,
     compatibility_for_eval_history_pairs,
     get_default_sandbox_config_for_eval,
+    get_metrics,
     make_metadata,
     prepare_dataset,
     reset_logger_for_multiprocessing,
@@ -218,7 +219,7 @@ If the program uses some packages that are incompatible, please figure out alter
     # You can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
     if state is None:
         raise ValueError('State should not be None.')
-    metrics = state.metrics.get() if state.metrics else None
+    metrics = get_metrics(state)

     # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
     # for compatibility with the existing output format, we can remake the pairs here
@@ -21,6 +21,7 @@ from evaluation.utils.shared import (
     EvalException,
     EvalMetadata,
     EvalOutput,
+    get_metrics,
     make_metadata,
     prepare_dataset,
     reset_logger_for_multiprocessing,
@@ -179,7 +180,7 @@ def process_instance(
         raise ValueError('State should not be None.')

     histories = [event_to_dict(event) for event in state.history]
-    metrics = state.metrics.get() if state.metrics else None
+    metrics = get_metrics(state)

     # Save the output
     instruction = message_action.content
@@ -11,6 +11,7 @@ from evaluation.utils.shared import (
     codeact_user_response,
     compatibility_for_eval_history_pairs,
     get_default_sandbox_config_for_eval,
+    get_metrics,
     make_metadata,
     prepare_dataset,
     reset_logger_for_multiprocessing,
@@ -134,7 +135,7 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool =
     correct = eval_answer(str(model_answer_raw), str(answer))
     logger.info(f'Final message: {model_answer_raw} | Correctness: {correct}')

-    metrics = state.metrics.get() if state.metrics else None
+    metrics = get_metrics(state)

     # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
     # for compatibility with the existing output format, we can remake the pairs here
@@ -12,6 +12,7 @@ from evaluation.utils.shared import (
     EvalOutput,
     compatibility_for_eval_history_pairs,
     get_default_sandbox_config_for_eval,
+    get_metrics,
     make_metadata,
     prepare_dataset,
     reset_logger_for_multiprocessing,
@@ -179,7 +180,7 @@ def process_instance(
     if state is None:
         raise ValueError('State should not be None.')

-    metrics = state.metrics.get() if state.metrics else None
+    metrics = get_metrics(state)

     # Instruction obtained from the first message from the USER
     instruction = ''
@@ -12,6 +12,7 @@ from evaluation.utils.shared import (
     EvalOutput,
     compatibility_for_eval_history_pairs,
     get_default_sandbox_config_for_eval,
+    get_metrics,
     make_metadata,
     prepare_dataset,
     reset_logger_for_multiprocessing,
@@ -163,7 +164,7 @@ def process_instance(
     if state is None:
         raise ValueError('State should not be None.')

-    metrics = state.metrics.get() if state.metrics else None
+    metrics = get_metrics(state)

     # Instruction is the first message from the USER
     instruction = ''
@@ -9,6 +9,7 @@ from evaluation.utils.shared import (
     EvalMetadata,
     EvalOutput,
     get_default_sandbox_config_for_eval,
+    get_metrics,
     make_metadata,
     prepare_dataset,
     reset_logger_for_multiprocessing,
@@ -135,7 +136,7 @@ def process_instance(
         assert len(histories) > 0, 'History should not be empty'

         test_result: TestResult = test_class.verify_result(runtime, histories)
-        metrics = state.metrics.get() if state.metrics else None
+        metrics = get_metrics(state)
     finally:
         runtime.close()

@@ -668,8 +668,23 @@ def is_fatal_runtime_error(error: str | None) -> bool:


 def get_metrics(state: State) -> dict[str, Any]:
-    """Extract metrics from the state."""
-    metrics = state.metrics.get() if state.metrics else {}
+    """Extract metrics for evaluations.
+
+    Prefer ConversationStats (source of truth) and fall back to state.metrics for
+    backward compatibility.
+    """
+    metrics: dict[str, Any]
+    try:
+        if getattr(state, 'conversation_stats', None):
+            combined = state.conversation_stats.get_combined_metrics()
+            metrics = combined.get()
+        elif getattr(state, 'metrics', None):
+            metrics = state.metrics.get()
+        else:
+            metrics = {}
+    except Exception:
+        metrics = state.metrics.get() if getattr(state, 'metrics', None) else {}
+
     metrics['condenser'] = get_condensation_metadata(state)
     return metrics
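For evaluation scripts, the reworked helper is a drop-in read of aggregated metrics; a minimal usage sketch (the accumulated_cost key is the one exercised by the new unit test below, other keys depend on what Metrics.get() returns):

    from evaluation.utils.shared import get_metrics

    metrics = get_metrics(state)  # always a dict, never None
    cost = metrics.get('accumulated_cost', 0.0)  # aggregated LLM cost, taken from ConversationStats when available
    condenser_meta = metrics['condenser']  # always present, added via get_condensation_metadata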
tests/unit/test_state_metrics_exposure.py (new file, 205 lines, all added)
@@ -0,0 +1,205 @@
import asyncio
from unittest.mock import patch

import pytest

from openhands.core.config import OpenHandsConfig
from openhands.events.action import MessageAction
from openhands.llm.metrics import Metrics


class FakeEventStream:
    def __init__(self):
        self.sid = 'test-sid'
        self.file_store = None
        self.user_id = None

    def add_event(self, *args, **kwargs):
        pass

    def subscribe(self, *args, **kwargs):
        pass

    def close(self):
        pass


class FakeRuntime:
    def __init__(self):
        self.event_stream = FakeEventStream()

    async def connect(self):
        return None

    def close(self):
        pass


class DummyState:
    def __init__(self, conversation_stats):
        self.conversation_stats = conversation_stats
        self.metrics = Metrics()
        self.history = []
        self.last_error = ''
        self.extra_data = {}


class FakeController:
    def __init__(self, state):
        self._state = state

    def get_state(self):
        return self._state

    async def close(self, set_stop_state: bool = False):
        return None

    def get_trajectory(self, include_screenshots: bool = False):
        return []


class FakeConversationStats:
    def __init__(self, cost: float = 1.23):
        self._m = Metrics()
        self._m.add_cost(cost)

    def get_combined_metrics(self) -> Metrics:
        return self._m


def test_state_tracker_save_state_consolidates_metrics(tmp_path):
    """Ensure StateTracker.save_state persists ConversationStats and does not touch State.metrics.

    Eval scripts should read from state.conversation_stats via evaluation.utils.shared.get_metrics.
    """
    from openhands.controller.state.state_tracker import StateTracker
    from openhands.server.services.conversation_stats import ConversationStats
    from openhands.storage.memory import InMemoryFileStore

    # Prepare conversation stats with one service metrics
    store = InMemoryFileStore({})
    conv_stats = ConversationStats(
        file_store=store, conversation_id='cid', user_id=None
    )
    m = Metrics()
    m.add_cost(0.5)
    conv_stats.service_to_metrics['svc'] = m

    # Create a new tracker and initialize state
    tracker = StateTracker(sid='sid', file_store=store, user_id=None)
    tracker.set_initial_state(
        id='sid',
        state=None,
        conversation_stats=conv_stats,
        max_iterations=1,
        max_budget_per_task=None,
        confirmation_mode=False,
    )

    # Preconditions
    assert tracker.state.metrics.accumulated_cost == 0.0

    # Act
    tracker.save_state()

    # Assert state.metrics unaffected (source of truth remains ConversationStats)
    assert tracker.state.metrics.accumulated_cost == 0.0
    # Persistence still called on ConversationStats (no exception)


def test_run_controller_exposes_aggregated_metrics_in_state():
    """Ensure get_metrics(state) reads from ConversationStats when available."""
    from evaluation.utils.shared import get_metrics
    from openhands.core.main import run_controller

    cfg = OpenHandsConfig()
    # Prevent run_controller from trying to persist state via DummyState
    cfg.file_store = 'memory'

    fake_conv_stats = FakeConversationStats(cost=2.5)

    def fake_create_registry_and_conversation_stats(config, sid, _):
        # return (llm_registry, conversation_stats, config)
        return (None, fake_conv_stats, config)

    def fake_create_agent(config, llm_registry):
        class _AgentCfg:
            enable_mcp = False

        class _LLMCfg:
            model = 'test-model'

        class _LLM:
            config = _LLMCfg()

        class _Agent:
            name = 'FakeAgent'
            config = _AgentCfg()
            llm = _LLM()

        return _Agent()

    def fake_create_runtime(
        config,
        llm_registry,
        sid=None,
        headless_mode=True,
        agent=None,
        git_provider_tokens=None,
    ):
        return FakeRuntime()

    def fake_create_memory(
        runtime,
        event_stream,
        sid,
        selected_repository=None,
        repo_directory=None,
        status_callback=None,
        conversation_instructions=None,
        working_dir=None,
    ):
        return object()

    def fake_create_controller(
        agent,
        runtime,
        config,
        conversation_stats,
        headless_mode=True,
        replay_events=None,
    ):
        # Return a controller that yields a DummyState with provided conversation_stats
        state = DummyState(conversation_stats)
        return (FakeController(state), None)

    # Invoke run_controller under patch context
    with (
        patch(
            'openhands.core.main.create_registry_and_conversation_stats',
            side_effect=fake_create_registry_and_conversation_stats,
        ),
        patch('openhands.core.main.create_agent', side_effect=fake_create_agent),
        patch('openhands.core.main.create_runtime', side_effect=fake_create_runtime),
        patch('openhands.core.main.create_memory', side_effect=fake_create_memory),
        patch(
            'openhands.core.main.create_controller', side_effect=fake_create_controller
        ),
        patch(
            'openhands.core.main.run_agent_until_done',
            side_effect=lambda *args, **kwargs: None,
        ),
    ):
        state = asyncio.run(
            run_controller(
                config=cfg,
                initial_user_action=MessageAction(content='hi'),
                sid='sid',
                fake_user_response_fn=None,
            )
        )

    assert state is not None
    # get_metrics must prefer conversation_stats and reflect its values
    m = get_metrics(state)
    assert pytest.approx(m.get('accumulated_cost', 0.0), rel=1e-6) == 2.5