Bug fix: Metrics not accumulated across agent delegation (#3012)

* Add test to reproduce cost miscalculation bug

* Fix metrics bug

* Copy metrics upon AgentRejectAction
This commit is contained in:
Boxuan Li
2024-07-19 21:05:05 -07:00
committed by GitHub
parent 6b16a5da0b
commit be6e6e3add
5 changed files with 49 additions and 17 deletions

View File

@@ -123,7 +123,7 @@ class AgentController:
async def update_state_after_step(self):
# update metrics especially for cost
self.state.metrics = self.agent.llm.metrics
self.state.local_metrics = self.agent.llm.metrics
async def report_error(self, message: str, exception: Exception | None = None):
"""This error will be reported to the user and sent to the LLM next step, in the hope it can self-correct.
@@ -174,9 +174,11 @@ class AgentController:
self.state.root_task.set_subtask_state(event.task_id, event.state)
elif isinstance(event, AgentFinishAction):
self.state.outputs = event.outputs # type: ignore[attr-defined]
self.state.metrics.merge(self.state.local_metrics)
await self.set_agent_state_to(AgentState.FINISHED)
elif isinstance(event, AgentRejectAction):
self.state.outputs = event.outputs # type: ignore[attr-defined]
self.state.metrics.merge(self.state.local_metrics)
await self.set_agent_state_to(AgentState.REJECTED)
elif isinstance(event, Observation):
if (
@@ -260,7 +262,7 @@ class AgentController:
iteration=self.state.iteration,
max_iterations=self.state.max_iterations,
delegate_level=self.state.delegate_level + 1,
# metrics should be shared between parent and child
# global metrics should be shared between parent and child
metrics=self.state.metrics,
)
logger.info(

View File

@@ -98,6 +98,8 @@ class State:
traffic_control_state: TrafficControlState = TrafficControlState.NORMAL
# global metrics for the current task
metrics: Metrics = Metrics()
# local metrics for the current subtask
local_metrics: Metrics = Metrics()
# root agent has level 0, and every delegate increases the level by one
delegate_level: int = 0
# start_id and end_id track the range of events in history

View File

@@ -28,6 +28,10 @@ class Metrics:
self._accumulated_cost += value
self._costs.append(value)
def merge(self, other: 'Metrics') -> None:
self._accumulated_cost += other.accumulated_cost
self._costs += other._costs
def get(self):
"""Return the metrics in a dictionary."""
return {'accumulated_cost': self._accumulated_cost, 'costs': self._costs}

View File

@@ -173,6 +173,11 @@ def mock_completion(*args, test_name, **kwargs):
return response
@pytest.fixture
def current_test_name(request):
return request.node.name
@pytest.fixture(autouse=True)
def patch_completion(monkeypatch, request):
test_name = request.node.name
@@ -182,6 +187,12 @@ def patch_completion(monkeypatch, request):
partial(mock_completion, test_name=test_name),
)
# Mock LLM completion cost (1 USD per conversation)
monkeypatch.setattr(
'opendevin.llm.llm.litellm_completion_cost',
lambda completion_response, **extra_kwargs: 1,
)
# Mock user input (only for tests that have user_responses.log)
user_responses_str = mock_user_response(test_name=test_name)
if user_responses_str:

View File

@@ -28,10 +28,25 @@ print(f'workspace_mount_path: {workspace_mount_path}')
print(f'workspace_mount_path_in_sandbox: {workspace_mount_path_in_sandbox}')
def validate_final_state(final_state: State | None):
def get_number_of_prompts(test_name: str):
mock_dir = os.path.join(
os.environ.get('SCRIPT_DIR'), 'mock', os.environ.get('DEFAULT_AGENT'), test_name
)
prompt_files = [file for file in os.listdir(mock_dir) if file.startswith('prompt_')]
return len(prompt_files)
def validate_final_state(final_state: State | None, test_name: str):
assert final_state is not None
assert final_state.agent_state == AgentState.STOPPED
assert final_state.last_error is None
# number of LLM conversations should be the same as number of prompt/response
# log files under mock/[agent]/[test_name] folder. If not, it means there are
# redundant prompt/response log files checked into the repository.
num_of_conversations = get_number_of_prompts(test_name)
assert num_of_conversations > 0
# we mock the cost of every conversation to be 1 USD
assert final_state.metrics.accumulated_cost == num_of_conversations
if final_state.history.has_delegation():
assert final_state.iteration > final_state.local_iteration
else:
@@ -55,7 +70,7 @@ def validate_final_state(final_state: State | None):
os.getenv('DEFAULT_AGENT') == 'ManagerAgent',
reason='Manager agent is not capable of finishing this in reasonable steps yet',
)
def test_write_simple_script() -> None:
def test_write_simple_script(current_test_name) -> None:
task = "Write a shell script 'hello.sh' that prints 'hello'. Do not ask me for confirmation at any point."
args = parse_arguments()
@@ -65,9 +80,7 @@ def test_write_simple_script() -> None:
final_state: State | None = asyncio.run(
run_agent_controller(agent, task, exit_on_message=True)
)
assert final_state is not None
assert final_state.agent_state == AgentState.STOPPED
assert final_state.last_error is None
validate_final_state(final_state, current_test_name)
# Verify the script file exists
assert workspace_base is not None
@@ -103,7 +116,7 @@ def test_write_simple_script() -> None:
os.getenv('SANDBOX_BOX_TYPE') == 'local',
reason='local sandbox shows environment-dependent absolute path for pwd command',
)
def test_edits():
def test_edits(current_test_name):
args = parse_arguments()
# Copy workspace artifacts to workspace_base location
source_dir = os.path.join(os.path.dirname(__file__), 'workspace/test_edits/')
@@ -122,7 +135,7 @@ def test_edits():
final_state: State | None = asyncio.run(
run_agent_controller(agent, task, exit_on_message=True)
)
validate_final_state(final_state)
validate_final_state(final_state, current_test_name)
# Verify bad.txt has been fixed
text = """This is a stupid typo.
@@ -144,7 +157,7 @@ Enjoy!
os.getenv('SANDBOX_BOX_TYPE') != 'ssh',
reason='Currently, only ssh sandbox supports stateful tasks',
)
def test_ipython():
def test_ipython(current_test_name):
args = parse_arguments()
# Create the agent
@@ -155,7 +168,7 @@ def test_ipython():
final_state: State | None = asyncio.run(
run_agent_controller(agent, task, exit_on_message=True)
)
validate_final_state(final_state)
validate_final_state(final_state, current_test_name)
# Verify the file exists
file_path = os.path.join(workspace_base, 'test.txt')
@@ -177,7 +190,7 @@ def test_ipython():
os.getenv('SANDBOX_BOX_TYPE') == 'local',
reason='FIXME: local sandbox does not capture stderr',
)
def test_simple_task_rejection():
def test_simple_task_rejection(current_test_name):
args = parse_arguments()
# Create the agent
@@ -187,7 +200,7 @@ def test_simple_task_rejection():
# the workspace is not a git repo
task = 'Write a git commit message for the current staging area. Do not ask me for confirmation at any point.'
final_state: State | None = asyncio.run(run_agent_controller(agent, task))
validate_final_state(final_state)
validate_final_state(final_state, current_test_name)
assert isinstance(final_state.history.get_last_action(), AgentRejectAction)
@@ -200,7 +213,7 @@ def test_simple_task_rejection():
os.getenv('SANDBOX_BOX_TYPE') != 'ssh',
reason='Currently, only ssh sandbox supports stateful tasks',
)
def test_ipython_module():
def test_ipython_module(current_test_name):
args = parse_arguments()
# Create the agent
@@ -211,7 +224,7 @@ def test_ipython_module():
final_state: State | None = asyncio.run(
run_agent_controller(agent, task, exit_on_message=True)
)
validate_final_state(final_state)
validate_final_state(final_state, current_test_name)
# Verify the file exists
file_path = os.path.join(workspace_base, 'test.txt')
@@ -239,7 +252,7 @@ def test_ipython_module():
and os.getenv('SANDBOX_BOX_TYPE', '').lower() != 'ssh',
reason='CodeActAgent/CodeActSWEAgent only supports ssh sandbox which is stateful',
)
def test_browse_internet(http_server):
def test_browse_internet(http_server, current_test_name):
args = parse_arguments()
# Create the agent
@@ -250,7 +263,7 @@ def test_browse_internet(http_server):
final_state: State | None = asyncio.run(
run_agent_controller(agent, task, exit_on_message=True)
)
validate_final_state(final_state)
validate_final_state(final_state, current_test_name)
# last action
last_action = final_state.history.get_last_action()