# Origin: commit dfc36eb861ea1da09bfeee6cf0a2164290df8f4f
#   Author:  Graham Neubig
#   Date:    Mon, 12 Aug 2024 21:26:36 -0400
#   Subject: [PATCH] Add tests for agent controller
#   Adds new file tests/unit/test_agent_controller.py
"""Unit tests for ``AgentController``.

Every test drives a controller that is wired to a mocked ``Agent`` and a
mocked ``EventStream``; the tests then inspect ``controller.state`` or the
mocks to check how the controller reacted.
"""

import asyncio
from unittest.mock import AsyncMock, MagicMock

import pytest

from opendevin.controller.agent import Agent
from opendevin.controller.agent_controller import AgentController
from opendevin.controller.state.state import TrafficControlState
from opendevin.core.exceptions import LLMMalformedActionError
from opendevin.core.schema import AgentState
from opendevin.events import EventStream
from opendevin.events.action import ChangeAgentStateAction, MessageAction


def _build_controller(agent, event_stream, **overrides):
    """Create an ``AgentController`` with the settings shared by all tests.

    Args:
        agent: The (mock) agent handed to the controller.
        event_stream: The (mock) event stream handed to the controller.
        **overrides: Keyword arguments that replace or extend the shared
            settings, e.g. ``headless_mode=False`` or
            ``max_budget_per_task=10``.

    Returns:
        The newly constructed controller.
    """
    settings = {
        'agent': agent,
        'event_stream': event_stream,
        'max_iterations': 10,
        'sid': 'test',
        'confirmation_mode': False,
        'headless_mode': True,
    }
    settings.update(overrides)
    return AgentController(**settings)


async def _step_and_expect_throttle(ctrl, expected_agent_state):
    """Run one ``_step()`` and check that it switched to throttling.

    Traffic control must be ``NORMAL`` before the step and ``THROTTLING``
    after it, and the agent must end up in ``expected_agent_state``.
    """
    assert ctrl.state.traffic_control_state == TrafficControlState.NORMAL
    await ctrl._step()
    assert ctrl.state.traffic_control_state == TrafficControlState.THROTTLING
    assert ctrl.state.agent_state == expected_agent_state


@pytest.fixture
def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> str:
    """Return the path of a fresh temporary directory as a string."""
    directory = tmp_path_factory.mktemp('test_event_stream')
    return str(directory)


@pytest.fixture(scope='function')
def event_loop():
    """Provide a new event loop for each test and close it afterwards."""
    policy = asyncio.get_event_loop_policy()
    fresh_loop = policy.new_event_loop()
    yield fresh_loop
    fresh_loop.close()


@pytest.fixture
def mock_agent():
    """Return a ``MagicMock`` restricted to the ``Agent`` interface."""
    agent_double = MagicMock(spec=Agent)
    return agent_double


# NOTE: the original module carried a disabled ``event_stream`` fixture that
# built a real EventStream on a local file store under ``temp_dir``; the tests
# rely on the mock below instead.


@pytest.fixture
def mock_event_stream():
    """Return a ``MagicMock`` restricted to the ``EventStream`` interface."""
    stream_double = MagicMock(spec=EventStream)
    return stream_double


@pytest.fixture
async def agent_controller(mock_agent, mock_event_stream):
    """Return a headless controller built on the two mock fixtures.

    NOTE: the original left a yield / ``await controller.close()`` teardown
    disabled, so the controller is returned without any cleanup.
    """
    return _build_controller(mock_agent, mock_event_stream)


@pytest.mark.asyncio
async def test_set_agent_state(mock_agent, mock_event_stream):
    """``set_agent_state_to`` is reflected by ``get_agent_state``."""
    ctrl = _build_controller(mock_agent, mock_event_stream)

    await ctrl.set_agent_state_to(AgentState.RUNNING)
    assert ctrl.get_agent_state() == AgentState.RUNNING

    await ctrl.set_agent_state_to(AgentState.PAUSED)
    assert ctrl.get_agent_state() == AgentState.PAUSED


@pytest.mark.asyncio
async def test_on_event_message_action(mock_agent, mock_event_stream):
    """A ``MessageAction`` leaves a running agent in the RUNNING state."""
    ctrl = _build_controller(mock_agent, mock_event_stream)
    ctrl.state.agent_state = AgentState.RUNNING

    incoming = MessageAction(content='Test message')
    await ctrl.on_event(incoming)

    assert ctrl.get_agent_state() == AgentState.RUNNING


@pytest.mark.asyncio
async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream):
    """A ``ChangeAgentStateAction`` moves the agent to the requested state."""
    ctrl = _build_controller(mock_agent, mock_event_stream)
    ctrl.state.agent_state = AgentState.RUNNING

    pause_request = ChangeAgentStateAction(agent_state=AgentState.PAUSED)
    await ctrl.on_event(pause_request)

    assert ctrl.get_agent_state() == AgentState.PAUSED


@pytest.mark.asyncio
async def test_report_error(mock_agent, mock_event_stream):
    """``report_error`` stores the message and emits exactly one event."""
    ctrl = _build_controller(mock_agent, mock_event_stream)
    message = 'Test error'

    await ctrl.report_error(message)

    assert ctrl.state.last_error == message
    ctrl.event_stream.add_event.assert_called_once()


@pytest.mark.asyncio
async def test_step_with_exception(mock_agent, mock_event_stream):
    """A malformed-action error raised by the agent is passed to ``report_error``."""
    ctrl = _build_controller(mock_agent, mock_event_stream)
    ctrl.state.agent_state = AgentState.RUNNING
    # Replace report_error so the call made during the step can be inspected.
    ctrl.report_error = AsyncMock()
    ctrl.agent.step.side_effect = LLMMalformedActionError('Malformed action')

    await ctrl._step()

    # The error text must be forwarded unchanged.
    ctrl.report_error.assert_called_once_with('Malformed action')


@pytest.mark.asyncio
async def test_step_max_iterations(mock_agent, mock_event_stream):
    """At the iteration limit, a non-headless controller throttles and pauses."""
    ctrl = _build_controller(mock_agent, mock_event_stream, headless_mode=False)
    ctrl.state.agent_state = AgentState.RUNNING
    ctrl.state.iteration = 10

    await _step_and_expect_throttle(ctrl, AgentState.PAUSED)


@pytest.mark.asyncio
async def test_step_max_iterations_headless(mock_agent, mock_event_stream):
    """At the iteration limit, a headless controller throttles and errors."""
    ctrl = _build_controller(mock_agent, mock_event_stream, headless_mode=True)
    ctrl.state.agent_state = AgentState.RUNNING
    ctrl.state.iteration = 10

    # In headless mode, throttling results in an error
    await _step_and_expect_throttle(ctrl, AgentState.ERROR)


@pytest.mark.asyncio
async def test_step_max_budget(mock_agent, mock_event_stream):
    """Over budget, a non-headless controller throttles and pauses."""
    ctrl = _build_controller(
        mock_agent,
        mock_event_stream,
        max_budget_per_task=10,
        headless_mode=False,
    )
    ctrl.state.agent_state = AgentState.RUNNING
    ctrl.state.metrics.accumulated_cost = 10.1

    await _step_and_expect_throttle(ctrl, AgentState.PAUSED)


@pytest.mark.asyncio
async def test_step_max_budget_headless(mock_agent, mock_event_stream):
    """Over budget, a headless controller throttles and errors."""
    ctrl = _build_controller(
        mock_agent,
        mock_event_stream,
        max_budget_per_task=10,
        headless_mode=True,
    )
    ctrl.state.agent_state = AgentState.RUNNING
    ctrl.state.metrics.accumulated_cost = 10.1

    # In headless mode, throttling results in an error
    await _step_and_expect_throttle(ctrl, AgentState.ERROR)