mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
Bug fix: Metrics not accumulated across agent delegation (#3012)
* Add test to reproduce cost miscalculation bug
* Fix metrics bug
* Copy metrics upon AgentRejectAction
This commit is contained in:
@@ -123,7 +123,7 @@ class AgentController:
|
||||
|
||||
async def update_state_after_step(self):
|
||||
# update metrics especially for cost
|
||||
self.state.metrics = self.agent.llm.metrics
|
||||
self.state.local_metrics = self.agent.llm.metrics
|
||||
|
||||
async def report_error(self, message: str, exception: Exception | None = None):
|
||||
"""This error will be reported to the user and sent to the LLM next step, in the hope it can self-correct.
|
||||
@@ -174,9 +174,11 @@ class AgentController:
|
||||
self.state.root_task.set_subtask_state(event.task_id, event.state)
|
||||
elif isinstance(event, AgentFinishAction):
|
||||
self.state.outputs = event.outputs # type: ignore[attr-defined]
|
||||
self.state.metrics.merge(self.state.local_metrics)
|
||||
await self.set_agent_state_to(AgentState.FINISHED)
|
||||
elif isinstance(event, AgentRejectAction):
|
||||
self.state.outputs = event.outputs # type: ignore[attr-defined]
|
||||
self.state.metrics.merge(self.state.local_metrics)
|
||||
await self.set_agent_state_to(AgentState.REJECTED)
|
||||
elif isinstance(event, Observation):
|
||||
if (
|
||||
@@ -260,7 +262,7 @@ class AgentController:
|
||||
iteration=self.state.iteration,
|
||||
max_iterations=self.state.max_iterations,
|
||||
delegate_level=self.state.delegate_level + 1,
|
||||
# metrics should be shared between parent and child
|
||||
# global metrics should be shared between parent and child
|
||||
metrics=self.state.metrics,
|
||||
)
|
||||
logger.info(
|
||||
|
||||
@@ -98,6 +98,8 @@ class State:
|
||||
traffic_control_state: TrafficControlState = TrafficControlState.NORMAL
|
||||
# global metrics for the current task
|
||||
metrics: Metrics = Metrics()
|
||||
# local metrics for the current subtask
|
||||
local_metrics: Metrics = Metrics()
|
||||
# root agent has level 0, and every delegate increases the level by one
|
||||
delegate_level: int = 0
|
||||
# start_id and end_id track the range of events in history
|
||||
|
||||
@@ -28,6 +28,10 @@ class Metrics:
|
||||
self._accumulated_cost += value
|
||||
self._costs.append(value)
|
||||
|
||||
def merge(self, other: 'Metrics') -> None:
    """Fold the costs recorded in ``other`` into this instance.

    Adds ``other.accumulated_cost`` to this instance's running total and
    appends every entry of ``other._costs`` to this instance's cost list.
    """
    combined_total = self._accumulated_cost + other.accumulated_cost
    self._accumulated_cost = combined_total
    # Extend in place so the existing list object keeps being used.
    self._costs.extend(other._costs)
|
||||
|
||||
def get(self):
    """Return the metrics as a dictionary.

    The result has two keys: ``accumulated_cost`` (the running total) and
    ``costs`` (the instance's own cost list object, not a copy).
    """
    snapshot = {}
    snapshot['accumulated_cost'] = self._accumulated_cost
    snapshot['costs'] = self._costs
    return snapshot
|
||||
|
||||
@@ -173,6 +173,11 @@ def mock_completion(*args, test_name, **kwargs):
|
||||
return response
|
||||
|
||||
|
||||
@pytest.fixture
def current_test_name(request):
    """Return ``request.node.name``, the name of the requesting test node."""
    node = request.node
    return node.name
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_completion(monkeypatch, request):
|
||||
test_name = request.node.name
|
||||
@@ -182,6 +187,12 @@ def patch_completion(monkeypatch, request):
|
||||
partial(mock_completion, test_name=test_name),
|
||||
)
|
||||
|
||||
# Mock LLM completion cost (1 USD per conversation)
|
||||
monkeypatch.setattr(
|
||||
'opendevin.llm.llm.litellm_completion_cost',
|
||||
lambda completion_response, **extra_kwargs: 1,
|
||||
)
|
||||
|
||||
# Mock user input (only for tests that have user_responses.log)
|
||||
user_responses_str = mock_user_response(test_name=test_name)
|
||||
if user_responses_str:
|
||||
|
||||
@@ -28,10 +28,25 @@ print(f'workspace_mount_path: {workspace_mount_path}')
|
||||
print(f'workspace_mount_path_in_sandbox: {workspace_mount_path_in_sandbox}')
|
||||
|
||||
|
||||
def validate_final_state(final_state: State | None):
|
||||
def get_number_of_prompts(test_name: str):
    """Count the ``prompt_*`` entries recorded for ``test_name``.

    The directory searched is ``<SCRIPT_DIR>/mock/<DEFAULT_AGENT>/<test_name>``,
    where ``SCRIPT_DIR`` and ``DEFAULT_AGENT`` are read from the environment.
    """
    script_dir = os.environ.get('SCRIPT_DIR')
    agent_name = os.environ.get('DEFAULT_AGENT')
    mock_dir = os.path.join(script_dir, 'mock', agent_name, test_name)
    # Only entries whose names start with 'prompt_' are counted.
    return sum(1 for entry in os.listdir(mock_dir) if entry.startswith('prompt_'))
|
||||
|
||||
|
||||
def validate_final_state(final_state: State | None, test_name: str):
|
||||
assert final_state is not None
|
||||
assert final_state.agent_state == AgentState.STOPPED
|
||||
assert final_state.last_error is None
|
||||
# number of LLM conversations should be the same as number of prompt/response
|
||||
# log files under mock/[agent]/[test_name] folder. If not, it means there are
|
||||
# redundant prompt/response log files checked into the repository.
|
||||
num_of_conversations = get_number_of_prompts(test_name)
|
||||
assert num_of_conversations > 0
|
||||
# we mock the cost of every conversation to be 1 USD
|
||||
assert final_state.metrics.accumulated_cost == num_of_conversations
|
||||
if final_state.history.has_delegation():
|
||||
assert final_state.iteration > final_state.local_iteration
|
||||
else:
|
||||
@@ -55,7 +70,7 @@ def validate_final_state(final_state: State | None):
|
||||
os.getenv('DEFAULT_AGENT') == 'ManagerAgent',
|
||||
reason='Manager agent is not capable of finishing this in reasonable steps yet',
|
||||
)
|
||||
def test_write_simple_script() -> None:
|
||||
def test_write_simple_script(current_test_name) -> None:
|
||||
task = "Write a shell script 'hello.sh' that prints 'hello'. Do not ask me for confirmation at any point."
|
||||
args = parse_arguments()
|
||||
|
||||
@@ -65,9 +80,7 @@ def test_write_simple_script() -> None:
|
||||
final_state: State | None = asyncio.run(
|
||||
run_agent_controller(agent, task, exit_on_message=True)
|
||||
)
|
||||
assert final_state is not None
|
||||
assert final_state.agent_state == AgentState.STOPPED
|
||||
assert final_state.last_error is None
|
||||
validate_final_state(final_state, current_test_name)
|
||||
|
||||
# Verify the script file exists
|
||||
assert workspace_base is not None
|
||||
@@ -103,7 +116,7 @@ def test_write_simple_script() -> None:
|
||||
os.getenv('SANDBOX_BOX_TYPE') == 'local',
|
||||
reason='local sandbox shows environment-dependent absolute path for pwd command',
|
||||
)
|
||||
def test_edits():
|
||||
def test_edits(current_test_name):
|
||||
args = parse_arguments()
|
||||
# Copy workspace artifacts to workspace_base location
|
||||
source_dir = os.path.join(os.path.dirname(__file__), 'workspace/test_edits/')
|
||||
@@ -122,7 +135,7 @@ def test_edits():
|
||||
final_state: State | None = asyncio.run(
|
||||
run_agent_controller(agent, task, exit_on_message=True)
|
||||
)
|
||||
validate_final_state(final_state)
|
||||
validate_final_state(final_state, current_test_name)
|
||||
|
||||
# Verify bad.txt has been fixed
|
||||
text = """This is a stupid typo.
|
||||
@@ -144,7 +157,7 @@ Enjoy!
|
||||
os.getenv('SANDBOX_BOX_TYPE') != 'ssh',
|
||||
reason='Currently, only ssh sandbox supports stateful tasks',
|
||||
)
|
||||
def test_ipython():
|
||||
def test_ipython(current_test_name):
|
||||
args = parse_arguments()
|
||||
|
||||
# Create the agent
|
||||
@@ -155,7 +168,7 @@ def test_ipython():
|
||||
final_state: State | None = asyncio.run(
|
||||
run_agent_controller(agent, task, exit_on_message=True)
|
||||
)
|
||||
validate_final_state(final_state)
|
||||
validate_final_state(final_state, current_test_name)
|
||||
|
||||
# Verify the file exists
|
||||
file_path = os.path.join(workspace_base, 'test.txt')
|
||||
@@ -177,7 +190,7 @@ def test_ipython():
|
||||
os.getenv('SANDBOX_BOX_TYPE') == 'local',
|
||||
reason='FIXME: local sandbox does not capture stderr',
|
||||
)
|
||||
def test_simple_task_rejection():
|
||||
def test_simple_task_rejection(current_test_name):
|
||||
args = parse_arguments()
|
||||
|
||||
# Create the agent
|
||||
@@ -187,7 +200,7 @@ def test_simple_task_rejection():
|
||||
# the workspace is not a git repo
|
||||
task = 'Write a git commit message for the current staging area. Do not ask me for confirmation at any point.'
|
||||
final_state: State | None = asyncio.run(run_agent_controller(agent, task))
|
||||
validate_final_state(final_state)
|
||||
validate_final_state(final_state, current_test_name)
|
||||
assert isinstance(final_state.history.get_last_action(), AgentRejectAction)
|
||||
|
||||
|
||||
@@ -200,7 +213,7 @@ def test_simple_task_rejection():
|
||||
os.getenv('SANDBOX_BOX_TYPE') != 'ssh',
|
||||
reason='Currently, only ssh sandbox supports stateful tasks',
|
||||
)
|
||||
def test_ipython_module():
|
||||
def test_ipython_module(current_test_name):
|
||||
args = parse_arguments()
|
||||
|
||||
# Create the agent
|
||||
@@ -211,7 +224,7 @@ def test_ipython_module():
|
||||
final_state: State | None = asyncio.run(
|
||||
run_agent_controller(agent, task, exit_on_message=True)
|
||||
)
|
||||
validate_final_state(final_state)
|
||||
validate_final_state(final_state, current_test_name)
|
||||
|
||||
# Verify the file exists
|
||||
file_path = os.path.join(workspace_base, 'test.txt')
|
||||
@@ -239,7 +252,7 @@ def test_ipython_module():
|
||||
and os.getenv('SANDBOX_BOX_TYPE', '').lower() != 'ssh',
|
||||
reason='CodeActAgent/CodeActSWEAgent only supports ssh sandbox which is stateful',
|
||||
)
|
||||
def test_browse_internet(http_server):
|
||||
def test_browse_internet(http_server, current_test_name):
|
||||
args = parse_arguments()
|
||||
|
||||
# Create the agent
|
||||
@@ -250,7 +263,7 @@ def test_browse_internet(http_server):
|
||||
final_state: State | None = asyncio.run(
|
||||
run_agent_controller(agent, task, exit_on_message=True)
|
||||
)
|
||||
validate_final_state(final_state)
|
||||
validate_final_state(final_state, current_test_name)
|
||||
|
||||
# last action
|
||||
last_action = final_state.history.get_last_action()
|
||||
|
||||
Reference in New Issue
Block a user