mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
* Replace OpenDevin with OpenHands * Update CONTRIBUTING.md * Update README.md * Update README.md * update poetry lock; move opendevin folder to openhands * fix env var * revert image references in docs * revert permissions * revert permissions --------- Co-authored-by: Xingyao Wang <xingyao6@illinois.edu>
195 lines
6.4 KiB
Python
195 lines
6.4 KiB
Python
import asyncio
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from openhands.controller.agent import Agent
|
|
from openhands.controller.agent_controller import AgentController
|
|
from openhands.controller.state.state import TrafficControlState
|
|
from openhands.core.exceptions import LLMMalformedActionError
|
|
from openhands.core.schema import AgentState
|
|
from openhands.events import EventStream
|
|
from openhands.events.action import ChangeAgentStateAction, MessageAction
|
|
|
|
|
|
@pytest.fixture
def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> str:
    """Create a fresh temporary directory and hand back its path as a string.

    The directory is created by pytest's ``tmp_path_factory`` with the
    basename prefix ``test_event_stream``.
    """
    created_dir = tmp_path_factory.mktemp('test_event_stream')
    return str(created_dir)
|
|
|
|
|
|
@pytest.fixture(scope='function')
def event_loop():
    """Yield a brand-new asyncio event loop for each test function.

    The loop is obtained from the current event-loop policy and is closed
    once the test that requested it has finished.
    """
    policy = asyncio.get_event_loop_policy()
    fresh_loop = policy.new_event_loop()
    yield fresh_loop
    # Teardown: release the loop's resources after the test completes.
    fresh_loop.close()
|
|
|
|
|
|
@pytest.fixture
def mock_agent():
    """Return a ``MagicMock`` constrained to the ``Agent`` interface.

    Using ``spec=Agent`` makes access to attributes that ``Agent`` does not
    define raise ``AttributeError`` instead of silently succeeding.
    """
    agent_double = MagicMock(spec=Agent)
    return agent_double
|
|
|
|
|
|
@pytest.fixture
def mock_event_stream():
    """Return a ``MagicMock`` constrained to the ``EventStream`` interface.

    Tests inspect its recorded calls (e.g. ``add_event``) instead of using a
    real stream.
    """
    stream_double = MagicMock(spec=EventStream)
    return stream_double
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_set_agent_state(mock_agent, mock_event_stream):
    """Each explicit ``set_agent_state_to`` call is reflected by ``get_agent_state``.

    Drives the controller RUNNING -> PAUSED and checks the reported state
    after every transition.
    """
    controller = AgentController(
        sid='test',
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        headless_mode=True,
        confirmation_mode=False,
    )

    for target_state in (AgentState.RUNNING, AgentState.PAUSED):
        await controller.set_agent_state_to(target_state)
        assert controller.get_agent_state() == target_state
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_on_event_message_action(mock_agent, mock_event_stream):
    """Dispatching a ``MessageAction`` leaves a RUNNING controller RUNNING."""
    controller = AgentController(
        sid='test',
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        headless_mode=True,
        confirmation_mode=False,
    )
    controller.state.agent_state = AgentState.RUNNING

    await controller.on_event(MessageAction(content='Test message'))

    assert controller.get_agent_state() == AgentState.RUNNING
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream):
    """A ``ChangeAgentStateAction`` event moves the controller to the requested state."""
    controller = AgentController(
        sid='test',
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        headless_mode=True,
        confirmation_mode=False,
    )
    controller.state.agent_state = AgentState.RUNNING

    # Request a pause through the event path rather than set_agent_state_to.
    await controller.on_event(ChangeAgentStateAction(agent_state=AgentState.PAUSED))

    assert controller.get_agent_state() == AgentState.PAUSED
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_report_error(mock_agent, mock_event_stream):
    """``report_error`` records the message on the state and emits exactly one event."""
    controller = AgentController(
        sid='test',
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        headless_mode=True,
        confirmation_mode=False,
    )
    expected_message = 'Test error'

    await controller.report_error(expected_message)

    assert controller.state.last_error == expected_message
    controller.event_stream.add_event.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_step_with_exception(mock_agent, mock_event_stream):
    """An ``LLMMalformedActionError`` from ``agent.step`` is reported, not raised.

    ``_step`` must complete without propagating the error. It must also
    await ``report_error`` exactly once with the exception's message.
    """
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )
    controller.state.agent_state = AgentState.RUNNING
    controller.report_error = AsyncMock()
    controller.agent.step.side_effect = LLMMalformedActionError('Malformed action')

    await controller._step()

    # report_error is a coroutine function. assert_called_* would also pass if
    # the coroutine were created but never awaited (so the error would never
    # actually be reported). Checking the await pins the real behavior.
    controller.report_error.assert_awaited_once_with('Malformed action')
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_step_max_iterations(mock_agent, mock_event_stream):
    """Outside headless mode, reaching the iteration cap throttles and pauses the agent."""
    iteration_cap = 10
    controller = AgentController(
        sid='test',
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=iteration_cap,
        headless_mode=False,
        confirmation_mode=False,
    )
    controller.state.agent_state = AgentState.RUNNING
    # Start the step already sitting at the configured limit.
    controller.state.iteration = iteration_cap
    assert controller.state.traffic_control_state == TrafficControlState.NORMAL

    await controller._step()

    assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
    assert controller.state.agent_state == AgentState.PAUSED
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_step_max_iterations_headless(mock_agent, mock_event_stream):
    """In headless mode, reaching the iteration cap throttles and errors the agent."""
    iteration_cap = 10
    controller = AgentController(
        sid='test',
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=iteration_cap,
        headless_mode=True,
        confirmation_mode=False,
    )
    controller.state.agent_state = AgentState.RUNNING
    controller.state.iteration = iteration_cap
    assert controller.state.traffic_control_state == TrafficControlState.NORMAL

    await controller._step()

    assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
    # With no user to resume a paused agent, throttling ends in ERROR instead.
    assert controller.state.agent_state == AgentState.ERROR
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_step_max_budget(mock_agent, mock_event_stream):
    """Outside headless mode, exceeding the per-task budget throttles and pauses the agent."""
    controller = AgentController(
        sid='test',
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        max_budget_per_task=10,
        headless_mode=False,
        confirmation_mode=False,
    )
    controller.state.agent_state = AgentState.RUNNING
    # Spend slightly more than the configured budget of 10.
    controller.state.metrics.accumulated_cost = 10.1
    assert controller.state.traffic_control_state == TrafficControlState.NORMAL

    await controller._step()

    assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
    assert controller.state.agent_state == AgentState.PAUSED
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_step_max_budget_headless(mock_agent, mock_event_stream):
    """In headless mode, exceeding the per-task budget throttles and errors the agent."""
    controller = AgentController(
        sid='test',
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        max_budget_per_task=10,
        headless_mode=True,
        confirmation_mode=False,
    )
    controller.state.agent_state = AgentState.RUNNING
    controller.state.metrics.accumulated_cost = 10.1
    assert controller.state.traffic_control_state == TrafficControlState.NORMAL

    await controller._step()

    assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
    # In headless mode, throttling results in an error rather than a pause.
    assert controller.state.agent_state == AgentState.ERROR
|