OpenHands/tests/unit/test_agent_controller.py
AutoLTX 5e521a4a6e
Expose accumulate and llm_metric from eventstream (backend) (#7082)
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
2025-03-11 12:18:08 +01:00

854 lines
28 KiB
Python

import asyncio
from unittest.mock import ANY, AsyncMock, MagicMock
from uuid import uuid4
import pytest
from litellm import ContextWindowExceededError
from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController
from openhands.controller.state.state import State, TrafficControlState
from openhands.core.config import AppConfig
from openhands.core.config.agent_config import AgentConfig
from openhands.core.main import run_controller
from openhands.core.schema import AgentState
from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
from openhands.events.observation import (
ErrorObservation,
)
from openhands.events.serialization import event_to_dict
from openhands.llm import LLM
from openhands.llm.metrics import Metrics, TokenUsage
from openhands.runtime.base import Runtime
from openhands.storage.memory import InMemoryFileStore
@pytest.fixture
def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> str:
return str(tmp_path_factory.mktemp('test_event_stream'))
@pytest.fixture(scope='function')
def event_loop():
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture
def mock_agent():
agent = MagicMock(spec=Agent)
agent.llm = MagicMock(spec=LLM)
agent.llm.metrics = Metrics()
agent.llm.config = AppConfig().get_llm_config()
return agent
@pytest.fixture
def mock_event_stream():
mock = MagicMock(spec=EventStream)
mock.get_latest_event_id.return_value = 0
return mock
@pytest.fixture
def mock_runtime() -> Runtime:
return MagicMock(
spec=Runtime,
event_stream=EventStream(sid='test', file_store=InMemoryFileStore({})),
)
@pytest.fixture
def mock_status_callback():
return AsyncMock()
async def send_event_to_controller(controller, event):
await controller._on_event(event)
await asyncio.sleep(0.1)
@pytest.mark.asyncio
async def test_set_agent_state(mock_agent, mock_event_stream):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)
await controller.set_agent_state_to(AgentState.RUNNING)
assert controller.get_agent_state() == AgentState.RUNNING
await controller.set_agent_state_to(AgentState.PAUSED)
assert controller.get_agent_state() == AgentState.PAUSED
await controller.close()
@pytest.mark.asyncio
async def test_on_event_message_action(mock_agent, mock_event_stream):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)
controller.state.agent_state = AgentState.RUNNING
message_action = MessageAction(content='Test message')
await send_event_to_controller(controller, message_action)
assert controller.get_agent_state() == AgentState.RUNNING
await controller.close()
@pytest.mark.asyncio
async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)
controller.state.agent_state = AgentState.RUNNING
change_state_action = ChangeAgentStateAction(agent_state=AgentState.PAUSED)
await send_event_to_controller(controller, change_state_action)
assert controller.get_agent_state() == AgentState.PAUSED
await controller.close()
@pytest.mark.asyncio
async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_callback):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
status_callback=mock_status_callback,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)
error_message = 'Test error'
await controller._react_to_exception(RuntimeError(error_message))
controller.status_callback.assert_called_once()
await controller.close()
@pytest.mark.asyncio
async def test_run_controller_with_fatal_error():
config = AppConfig()
file_store = InMemoryFileStore({})
event_stream = EventStream(sid='test', file_store=file_store)
agent = MagicMock(spec=Agent)
agent = MagicMock(spec=Agent)
def agent_step_fn(state):
print(f'agent_step_fn received state: {state}')
return CmdRunAction(command='ls')
agent.step = agent_step_fn
agent.llm = MagicMock(spec=LLM)
agent.llm.metrics = Metrics()
agent.llm.config = config.get_llm_config()
runtime = MagicMock(spec=Runtime)
def on_event(event: Event):
if isinstance(event, CmdRunAction):
error_obs = ErrorObservation('You messed around with Jim')
error_obs._cause = event.id
event_stream.add_event(error_obs, EventSource.USER)
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
runtime.event_stream = event_stream
state = await run_controller(
config=config,
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
agent=agent,
fake_user_response_fn=lambda _: 'repeat',
)
print(f'state: {state}')
events = list(event_stream.get_events())
print(f'event_stream: {events}')
assert state.iteration == 4
assert state.agent_state == AgentState.ERROR
assert state.last_error == 'AgentStuckInLoopError: Agent got stuck in a loop'
assert len(events) == 11
@pytest.mark.asyncio
async def test_run_controller_stop_with_stuck():
config = AppConfig()
file_store = InMemoryFileStore({})
event_stream = EventStream(sid='test', file_store=file_store)
agent = MagicMock(spec=Agent)
def agent_step_fn(state):
print(f'agent_step_fn received state: {state}')
return CmdRunAction(command='ls')
agent.step = agent_step_fn
agent.llm = MagicMock(spec=LLM)
agent.llm.metrics = Metrics()
agent.llm.config = config.get_llm_config()
runtime = MagicMock(spec=Runtime)
def on_event(event: Event):
if isinstance(event, CmdRunAction):
non_fatal_error_obs = ErrorObservation(
'Non fatal error here to trigger loop'
)
non_fatal_error_obs._cause = event.id
event_stream.add_event(non_fatal_error_obs, EventSource.ENVIRONMENT)
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
runtime.event_stream = event_stream
state = await run_controller(
config=config,
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
agent=agent,
fake_user_response_fn=lambda _: 'repeat',
)
events = list(event_stream.get_events())
print(f'state: {state}')
for i, event in enumerate(events):
print(f'event {i}: {event_to_dict(event)}')
assert state.iteration == 4
assert len(events) == 11
# check the eventstream have 4 pairs of repeated actions and observations
repeating_actions_and_observations = events[2:10]
for action, observation in zip(
repeating_actions_and_observations[0::2],
repeating_actions_and_observations[1::2],
):
action_dict = event_to_dict(action)
observation_dict = event_to_dict(observation)
assert action_dict['action'] == 'run' and action_dict['args']['command'] == 'ls'
assert (
observation_dict['observation'] == 'error'
and observation_dict['content'] == 'Non fatal error here to trigger loop'
)
last_event = event_to_dict(events[-1])
assert last_event['extras']['agent_state'] == 'error'
assert last_event['observation'] == 'agent_state_changed'
assert state.agent_state == AgentState.ERROR
assert state.last_error == 'AgentStuckInLoopError: Agent got stuck in a loop'
@pytest.mark.asyncio
async def test_max_iterations_extension(mock_agent, mock_event_stream):
# Test with headless_mode=False - should extend max_iterations
initial_state = State(max_iterations=10)
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=False,
initial_state=initial_state,
)
controller.state.agent_state = AgentState.RUNNING
controller.state.iteration = 10
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
# Trigger throttling by calling _step() when we hit max_iterations
await controller._step()
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
assert controller.state.agent_state == AgentState.ERROR
# Simulate a new user message
message_action = MessageAction(content='Test message')
message_action._source = EventSource.USER
await send_event_to_controller(controller, message_action)
# Max iterations should be extended to current iteration + initial max_iterations
assert (
controller.state.max_iterations == 20
) # Current iteration (10 initial because _step() should not have been executed) + initial max_iterations (10)
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
assert controller.state.agent_state == AgentState.RUNNING
# Close the controller to clean up
await controller.close()
# Test with headless_mode=True - should NOT extend max_iterations
initial_state = State(max_iterations=10)
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
initial_state=initial_state,
)
controller.state.agent_state = AgentState.RUNNING
controller.state.iteration = 10
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
# Simulate a new user message
message_action = MessageAction(content='Test message')
message_action._source = EventSource.USER
await send_event_to_controller(controller, message_action)
# Max iterations should NOT be extended in headless mode
assert controller.state.max_iterations == 10 # Original value unchanged
# Trigger throttling by calling _step() when we hit max_iterations
await controller._step()
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
assert controller.state.agent_state == AgentState.ERROR
await controller.close()
@pytest.mark.asyncio
async def test_step_max_budget(mock_agent, mock_event_stream):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
max_budget_per_task=10,
sid='test',
confirmation_mode=False,
headless_mode=False,
)
controller.state.agent_state = AgentState.RUNNING
controller.state.metrics.accumulated_cost = 10.1
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
await controller._step()
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
assert controller.state.agent_state == AgentState.ERROR
await controller.close()
@pytest.mark.asyncio
async def test_step_max_budget_headless(mock_agent, mock_event_stream):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
max_budget_per_task=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)
controller.state.agent_state = AgentState.RUNNING
controller.state.metrics.accumulated_cost = 10.1
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
await controller._step()
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
# In headless mode, throttling results in an error
assert controller.state.agent_state == AgentState.ERROR
await controller.close()
@pytest.mark.asyncio
async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_stream):
"""Test reset() when there's a pending action with tool call metadata but no observation."""
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)
# Create a pending action with tool call metadata
pending_action = CmdRunAction(command='test')
pending_action.tool_call_metadata = {
'function': 'test_function',
'args': {'arg1': 'value1'},
}
controller._pending_action = pending_action
# Call reset
controller._reset()
# Verify that an ErrorObservation was added to the event stream
mock_event_stream.add_event.assert_called_once()
args, kwargs = mock_event_stream.add_event.call_args
error_obs, source = args
assert isinstance(error_obs, ErrorObservation)
assert error_obs.content == 'The action has not been executed.'
assert error_obs.tool_call_metadata == pending_action.tool_call_metadata
assert error_obs._cause == pending_action.id
assert source == EventSource.AGENT
# Verify that pending action was reset
assert controller._pending_action is None
# Verify that agent.reset() was called
mock_agent.reset.assert_called_once()
await controller.close()
@pytest.mark.asyncio
async def test_reset_with_pending_action_existing_observation(
mock_agent, mock_event_stream
):
"""Test reset() when there's a pending action with tool call metadata and an existing observation."""
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)
# Create a pending action with tool call metadata
pending_action = CmdRunAction(command='test')
pending_action.tool_call_metadata = {
'function': 'test_function',
'args': {'arg1': 'value1'},
}
controller._pending_action = pending_action
# Add an existing observation to the history
existing_obs = ErrorObservation(content='Previous error')
existing_obs.tool_call_metadata = pending_action.tool_call_metadata
controller.state.history.append(existing_obs)
# Call reset
controller._reset()
# Verify that no new ErrorObservation was added to the event stream
mock_event_stream.add_event.assert_not_called()
# Verify that pending action was reset
assert controller._pending_action is None
# Verify that agent.reset() was called
mock_agent.reset.assert_called_once()
await controller.close()
@pytest.mark.asyncio
async def test_reset_without_pending_action(mock_agent, mock_event_stream):
"""Test reset() when there's no pending action."""
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)
# Call reset
controller._reset()
# Verify that no ErrorObservation was added to the event stream
mock_event_stream.add_event.assert_not_called()
# Verify that pending action is None
assert controller._pending_action is None
# Verify that agent.reset() was called
mock_agent.reset.assert_called_once()
await controller.close()
@pytest.mark.asyncio
async def test_reset_with_pending_action_no_metadata(
mock_agent, mock_event_stream, monkeypatch
):
"""Test reset() when there's a pending action without tool call metadata."""
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)
# Create a pending action without tool call metadata
pending_action = CmdRunAction(command='test')
# Mock hasattr to return False for tool_call_metadata
original_hasattr = hasattr
def mock_hasattr(obj, name):
if obj == pending_action and name == 'tool_call_metadata':
return False
return original_hasattr(obj, name)
monkeypatch.setattr('builtins.hasattr', mock_hasattr)
controller._pending_action = pending_action
# Call reset
controller._reset()
# Verify that no ErrorObservation was added to the event stream
mock_event_stream.add_event.assert_not_called()
# Verify that pending action was reset
assert controller._pending_action is None
# Verify that agent.reset() was called
mock_agent.reset.assert_called_once()
await controller.close()
@pytest.mark.asyncio
async def test_run_controller_max_iterations_has_metrics():
config = AppConfig(
max_iterations=3,
)
file_store = InMemoryFileStore({})
event_stream = EventStream(sid='test', file_store=file_store)
agent = MagicMock(spec=Agent)
agent.llm = MagicMock(spec=LLM)
agent.llm.metrics = Metrics()
agent.llm.config = config.get_llm_config()
def agent_step_fn(state):
print(f'agent_step_fn received state: {state}')
# Mock the cost of the LLM
agent.llm.metrics.add_cost(10.0)
print(
f'agent.llm.metrics.accumulated_cost: {agent.llm.metrics.accumulated_cost}'
)
return CmdRunAction(command='ls')
agent.step = agent_step_fn
runtime = MagicMock(spec=Runtime)
def on_event(event: Event):
if isinstance(event, CmdRunAction):
non_fatal_error_obs = ErrorObservation(
'Non fatal error. event id: ' + str(event.id)
)
non_fatal_error_obs._cause = event.id
event_stream.add_event(non_fatal_error_obs, EventSource.ENVIRONMENT)
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
runtime.event_stream = event_stream
state = await run_controller(
config=config,
initial_user_action=MessageAction(content='Test message'),
runtime=runtime,
sid='test',
agent=agent,
fake_user_response_fn=lambda _: 'repeat',
)
assert state.iteration == 3
assert state.agent_state == AgentState.ERROR
assert (
state.last_error
== 'RuntimeError: Agent reached maximum iteration in headless mode. Current iteration: 3, max iteration: 3'
)
assert (
state.metrics.accumulated_cost == 10.0 * 3
), f'Expected accumulated cost to be 30.0, but got {state.metrics.accumulated_cost}'
@pytest.mark.asyncio
async def test_notify_on_llm_retry(mock_agent, mock_event_stream, mock_status_callback):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
status_callback=mock_status_callback,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)
controller._notify_on_llm_retry(1, 2)
controller.status_callback.assert_called_once_with('info', 'STATUS$LLM_RETRY', ANY)
await controller.close()
@pytest.mark.asyncio
async def test_context_window_exceeded_error_handling(mock_agent, mock_event_stream):
"""Test that context window exceeded errors are handled correctly by truncating history."""
class StepState:
def __init__(self):
self.has_errored = False
def step(self, state: State):
# Append a few messages to the history -- these will be truncated when we throw the error
state.history = [
MessageAction(content='Test message 0'),
MessageAction(content='Test message 1'),
]
error = ContextWindowExceededError(
message='prompt is too long: 233885 tokens > 200000 maximum',
model='',
llm_provider='',
)
self.has_errored = True
raise error
state = StepState()
mock_agent.step = state.step
mock_agent.config = AgentConfig()
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)
# Set the agent running and take a step in the controller -- this is similar
# to taking a single step using `run_controller`, but much easier to control
# termination for testing purposes
controller.state.agent_state = AgentState.RUNNING
await controller._step()
# Check that the error was thrown and the history has been truncated
assert state.has_errored
assert controller.state.history == [MessageAction(content='Test message 1')]
@pytest.mark.asyncio
async def test_run_controller_with_context_window_exceeded_with_truncation(
mock_agent, mock_runtime
):
"""Tests that the controller can make progress after handling context window exceeded errors, as long as enable_history_truncation is ON"""
class StepState:
def __init__(self):
self.has_errored = False
def step(self, state: State):
# If the state has more than one message and we haven't errored yet,
# throw the context window exceeded error
if len(state.history) > 1 and not self.has_errored:
error = ContextWindowExceededError(
message='prompt is too long: 233885 tokens > 200000 maximum',
model='',
llm_provider='',
)
self.has_errored = True
raise error
return MessageAction(content=f'STEP {len(state.history)}')
step_state = StepState()
mock_agent.step = step_state.step
mock_agent.config = AgentConfig()
try:
state = await asyncio.wait_for(
run_controller(
config=AppConfig(max_iterations=3),
initial_user_action=MessageAction(content='INITIAL'),
runtime=mock_runtime,
sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
),
timeout=10,
)
# A timeout error indicates the run_controller entrypoint is not making
# progress
except asyncio.TimeoutError as e:
raise AssertionError(
'The run_controller function did not complete in time.'
) from e
# Hitting the iteration limit indicates the controller is failing for the
# expected reason
assert state.iteration == 3
assert state.agent_state == AgentState.ERROR
assert (
state.last_error
== 'RuntimeError: Agent reached maximum iteration in headless mode. Current iteration: 3, max iteration: 3'
)
# Check that the context window exceeded error was raised during the run
assert step_state.has_errored
@pytest.mark.asyncio
async def test_run_controller_with_context_window_exceeded_without_truncation(
mock_agent, mock_runtime
):
"""Tests that the controller would quit upon context window exceeded errors without enable_history_truncation ON."""
class StepState:
def __init__(self):
self.has_errored = False
def step(self, state: State):
# If the state has more than one message and we haven't errored yet,
# throw the context window exceeded error
if len(state.history) > 1 and not self.has_errored:
error = ContextWindowExceededError(
message='prompt is too long: 233885 tokens > 200000 maximum',
model='',
llm_provider='',
)
self.has_errored = True
raise error
return MessageAction(content=f'STEP {len(state.history)}')
step_state = StepState()
mock_agent.step = step_state.step
mock_agent.config = AgentConfig()
mock_agent.config.enable_history_truncation = False
try:
state = await asyncio.wait_for(
run_controller(
config=AppConfig(max_iterations=3),
initial_user_action=MessageAction(content='INITIAL'),
runtime=mock_runtime,
sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
),
timeout=10,
)
# A timeout error indicates the run_controller entrypoint is not making
# progress
except asyncio.TimeoutError as e:
raise AssertionError(
'The run_controller function did not complete in time.'
) from e
# Hitting the iteration limit indicates the controller is failing for the
# expected reason
assert state.iteration == 2
assert state.agent_state == AgentState.ERROR
assert (
state.last_error
== 'LLMContextWindowExceedError: Conversation history longer than LLM context window limit. Consider turning on enable_history_truncation config to avoid this error'
)
# Check that the context window exceeded error was raised during the run
assert step_state.has_errored
@pytest.mark.asyncio
async def test_action_metrics_copy():
# Setup
file_store = InMemoryFileStore({})
event_stream = EventStream(sid='test', file_store=file_store)
# Create agent with metrics
agent = MagicMock(spec=Agent)
agent.llm = MagicMock(spec=LLM)
metrics = Metrics(model_name='test-model')
metrics.accumulated_cost = 0.05
# Add multiple token usages - we should get the last one in the action
usage1 = TokenUsage(
model='test-model',
prompt_tokens=5,
completion_tokens=10,
cache_read_tokens=2,
cache_write_tokens=2,
response_id='test-id-1',
)
usage2 = TokenUsage(
model='test-model',
prompt_tokens=10,
completion_tokens=20,
cache_read_tokens=5,
cache_write_tokens=5,
response_id='test-id-2',
)
metrics.token_usages = [usage1, usage2]
# Add a cost instance - should not be included in action metrics
# This will increase accumulated_cost by 0.02
metrics.add_cost(0.02)
# Add a response latency - should not be included in action metrics
metrics.add_response_latency(0.5, 'test-id-2')
agent.llm.metrics = metrics
# Mock agent step to return an action
action = MessageAction(content='Test message')
def agent_step_fn(state):
return action
agent.step = agent_step_fn
# Create controller with correct parameters
controller = AgentController(
agent=agent,
event_stream=event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)
# Execute one step
controller.state.agent_state = AgentState.RUNNING
await controller._step()
# Get the last event from event stream
events = list(event_stream.get_events())
assert len(events) > 0
last_action = events[-1]
# Verify metrics were copied correctly
assert last_action.llm_metrics is not None
assert (
last_action.llm_metrics.accumulated_cost == 0.07
) # 0.05 initial + 0.02 from add_cost
# Should include the last token usage
assert len(last_action.llm_metrics.token_usages) == 1
assert last_action.llm_metrics.token_usages[0].prompt_tokens == 10
assert last_action.llm_metrics.token_usages[0].completion_tokens == 20
assert last_action.llm_metrics.token_usages[0].cache_read_tokens == 5
assert last_action.llm_metrics.token_usages[0].cache_write_tokens == 5
assert last_action.llm_metrics.token_usages[0].response_id == 'test-id-2'
# Should not include the cost history
assert len(last_action.llm_metrics.costs) == 0
# Should not include the response latency history
assert len(last_action.llm_metrics.response_latencies) == 0
# Verify that there's no latency information in the action's metrics
# Either directly or as a calculated property
assert not hasattr(last_action.llm_metrics, 'latency')
assert not hasattr(last_action.llm_metrics, 'total_latency')
assert not hasattr(last_action.llm_metrics, 'average_latency')
# Verify it's a deep copy by modifying the original
agent.llm.metrics.accumulated_cost = 0.1
assert last_action.llm_metrics.accumulated_cost == 0.07
await controller.close()