Delegation fixes (#6165)

This commit is contained in:
Engel Nyst 2025-01-15 04:24:39 +01:00 committed by GitHub
parent 082d0b25c5
commit b9a70c8d5c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 411 additions and 150 deletions

View File

@ -56,6 +56,7 @@ jobs:
LLM_MODEL: "litellm_proxy/claude-3-5-haiku-20241022"
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_BASE_URL: ${{ secrets.LLM_BASE_URL }}
MAX_ITERATIONS: 10
run: |
echo "[llm.eval]" > config.toml
echo "model = \"$LLM_MODEL\"" >> config.toml
@ -70,7 +71,7 @@ jobs:
env:
SANDBOX_FORCE_REBUILD_RUNTIME: True
run: |
poetry run ./evaluation/integration_tests/scripts/run_infer.sh llm.eval HEAD CodeActAgent '' $N_PROCESSES '' 'haiku_run'
poetry run ./evaluation/integration_tests/scripts/run_infer.sh llm.eval HEAD CodeActAgent '' 10 $N_PROCESSES '' 'haiku_run'
# get integration tests report
REPORT_FILE_HAIKU=$(find evaluation/evaluation_outputs/outputs/integration_tests/CodeActAgent/*haiku*_maxiter_10_N* -name "report.md" -type f | head -n 1)
@ -88,6 +89,7 @@ jobs:
LLM_MODEL: "litellm_proxy/deepseek-chat"
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_BASE_URL: ${{ secrets.LLM_BASE_URL }}
MAX_ITERATIONS: 10
run: |
echo "[llm.eval]" > config.toml
echo "model = \"$LLM_MODEL\"" >> config.toml
@ -99,7 +101,7 @@ jobs:
env:
SANDBOX_FORCE_REBUILD_RUNTIME: True
run: |
poetry run ./evaluation/integration_tests/scripts/run_infer.sh llm.eval HEAD CodeActAgent '' $N_PROCESSES '' 'deepseek_run'
poetry run ./evaluation/integration_tests/scripts/run_infer.sh llm.eval HEAD CodeActAgent '' 10 $N_PROCESSES '' 'deepseek_run'
# get integration tests report
REPORT_FILE_DEEPSEEK=$(find evaluation/evaluation_outputs/outputs/integration_tests/CodeActAgent/deepseek*_maxiter_10_N* -name "report.md" -type f | head -n 1)
@ -109,11 +111,75 @@ jobs:
echo >> $GITHUB_ENV
echo "EOF" >> $GITHUB_ENV
# -------------------------------------------------------------
# Run DelegatorAgent tests for Haiku, limited to t01 and t02
- name: Wait a little bit (again)
run: sleep 5
- name: Configure config.toml for testing DelegatorAgent (Haiku)
env:
LLM_MODEL: "litellm_proxy/claude-3-5-haiku-20241022"
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_BASE_URL: ${{ secrets.LLM_BASE_URL }}
MAX_ITERATIONS: 30
run: |
echo "[llm.eval]" > config.toml
echo "model = \"$LLM_MODEL\"" >> config.toml
echo "api_key = \"$LLM_API_KEY\"" >> config.toml
echo "base_url = \"$LLM_BASE_URL\"" >> config.toml
echo "temperature = 0.0" >> config.toml
- name: Run integration test evaluation for DelegatorAgent (Haiku)
env:
SANDBOX_FORCE_REBUILD_RUNTIME: True
run: |
poetry run ./evaluation/integration_tests/scripts/run_infer.sh llm.eval HEAD DelegatorAgent '' 30 $N_PROCESSES "t01_fix_simple_typo,t02_add_bash_hello" 'delegator_haiku_run'
# Find and export the delegator test results
REPORT_FILE_DELEGATOR_HAIKU=$(find evaluation/evaluation_outputs/outputs/integration_tests/DelegatorAgent/*haiku*_maxiter_30_N* -name "report.md" -type f | head -n 1)
echo "REPORT_FILE_DELEGATOR_HAIKU: $REPORT_FILE_DELEGATOR_HAIKU"
echo "INTEGRATION_TEST_REPORT_DELEGATOR_HAIKU<<EOF" >> $GITHUB_ENV
cat $REPORT_FILE_DELEGATOR_HAIKU >> $GITHUB_ENV
echo >> $GITHUB_ENV
echo "EOF" >> $GITHUB_ENV
# -------------------------------------------------------------
# Run DelegatorAgent tests for DeepSeek, limited to t01 and t02
- name: Wait a little bit (again)
run: sleep 5
- name: Configure config.toml for testing DelegatorAgent (DeepSeek)
env:
LLM_MODEL: "litellm_proxy/deepseek-chat"
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_BASE_URL: ${{ secrets.LLM_BASE_URL }}
MAX_ITERATIONS: 30
run: |
echo "[llm.eval]" > config.toml
echo "model = \"$LLM_MODEL\"" >> config.toml
echo "api_key = \"$LLM_API_KEY\"" >> config.toml
echo "base_url = \"$LLM_BASE_URL\"" >> config.toml
echo "temperature = 0.0" >> config.toml
- name: Run integration test evaluation for DelegatorAgent (DeepSeek)
env:
SANDBOX_FORCE_REBUILD_RUNTIME: True
run: |
poetry run ./evaluation/integration_tests/scripts/run_infer.sh llm.eval HEAD DelegatorAgent '' 30 $N_PROCESSES "t01_fix_simple_typo,t02_add_bash_hello" 'delegator_deepseek_run'
# Find and export the delegator test results
REPORT_FILE_DELEGATOR_DEEPSEEK=$(find evaluation/evaluation_outputs/outputs/integration_tests/DelegatorAgent/deepseek*_maxiter_30_N* -name "report.md" -type f | head -n 1)
echo "REPORT_FILE_DELEGATOR_DEEPSEEK: $REPORT_FILE_DELEGATOR_DEEPSEEK"
echo "INTEGRATION_TEST_REPORT_DELEGATOR_DEEPSEEK<<EOF" >> $GITHUB_ENV
cat $REPORT_FILE_DELEGATOR_DEEPSEEK >> $GITHUB_ENV
echo >> $GITHUB_ENV
echo "EOF" >> $GITHUB_ENV
- name: Create archive of evaluation outputs
run: |
TIMESTAMP=$(date +'%y-%m-%d-%H-%M')
cd evaluation/evaluation_outputs/outputs # Change to the outputs directory
tar -czvf ../../../integration_tests_${TIMESTAMP}.tar.gz integration_tests/CodeActAgent/* # Only include the actual result directories
tar -czvf ../../../integration_tests_${TIMESTAMP}.tar.gz integration_tests/CodeActAgent/* integration_tests/DelegatorAgent/* # Only include the actual result directories
- name: Upload evaluation results as artifact
uses: actions/upload-artifact@v4
@ -154,5 +220,11 @@ jobs:
**Integration Tests Report (DeepSeek)**
DeepSeek LLM Test Results:
${{ env.INTEGRATION_TEST_REPORT_DEEPSEEK }}
---
**Integration Tests Report Delegator (Haiku)**
${{ env.INTEGRATION_TEST_REPORT_DELEGATOR_HAIKU }}
---
**Integration Tests Report Delegator (DeepSeek)**
${{ env.INTEGRATION_TEST_REPORT_DELEGATOR_DEEPSEEK }}
---
Download testing outputs (includes both Haiku and DeepSeek results): [Download](${{ steps.upload_results_artifact.outputs.artifact-url }})

View File

@ -8,13 +8,15 @@ from evaluation.integration_tests.tests.base import BaseIntegrationTest, TestRes
from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
codeact_user_response,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
run_evaluation,
update_llm_config_for_completions_logging,
)
from evaluation.utils.shared import (
codeact_user_response as fake_user_response,
)
from openhands.controller.state.state import State
from openhands.core.config import (
AgentConfig,
@ -31,7 +33,8 @@ from openhands.runtime.base import Runtime
from openhands.utils.async_utils import call_async_from_sync
FAKE_RESPONSES = {
'CodeActAgent': codeact_user_response,
'CodeActAgent': fake_user_response,
'DelegatorAgent': fake_user_response,
}
@ -219,7 +222,7 @@ if __name__ == '__main__':
df = pd.read_json(output_file, lines=True, orient='records')
# record success and reason for failure for the final report
# record success and reason
df['success'] = df['test_result'].apply(lambda x: x['success'])
df['reason'] = df['test_result'].apply(lambda x: x['reason'])
logger.info('-' * 100)
@ -234,15 +237,27 @@ if __name__ == '__main__':
logger.info('-' * 100)
# record cost for each instance, with 3 decimal places
df['cost'] = df['metrics'].apply(lambda x: round(x['accumulated_cost'], 3))
# we sum up all the "costs" from the metrics array
df['cost'] = df['metrics'].apply(
lambda m: round(sum(c['cost'] for c in m['costs']), 3)
if m and 'costs' in m
else 0.0
)
# capture the top-level error if present, per instance
df['error_message'] = df.get('error', None)
logger.info(f'Total cost: USD {df["cost"].sum():.2f}')
report_file = os.path.join(metadata.eval_output_dir, 'report.md')
with open(report_file, 'w') as f:
f.write(
f'Success rate: {df["success"].mean():.2%} ({df["success"].sum()}/{len(df)})\n'
f'Success rate: {df["success"].mean():.2%}'
f' ({df["success"].sum()}/{len(df)})\n'
)
f.write(f'\nTotal cost: USD {df["cost"].sum():.2f}\n')
f.write(
df[['instance_id', 'success', 'reason', 'cost']].to_markdown(index=False)
df[
['instance_id', 'success', 'reason', 'cost', 'error_message']
].to_markdown(index=False)
)

View File

@ -7,8 +7,9 @@ MODEL_CONFIG=$1
COMMIT_HASH=$2
AGENT=$3
EVAL_LIMIT=$4
NUM_WORKERS=$5
EVAL_IDS=$6
MAX_ITERATIONS=$5
NUM_WORKERS=$6
EVAL_IDS=$7
if [ -z "$NUM_WORKERS" ]; then
NUM_WORKERS=1
@ -43,7 +44,7 @@ fi
COMMAND="poetry run python evaluation/integration_tests/run_infer.py \
--agent-cls $AGENT \
--llm-config $MODEL_CONFIG \
--max-iterations 10 \
--max-iterations ${MAX_ITERATIONS:-10} \
--eval-num-workers $NUM_WORKERS \
--eval-note $EVAL_NOTE"

View File

@ -50,6 +50,10 @@ class MicroAgent(Agent):
# history is in reverse order, let's fix it
processed_history.reverse()
# everything starts with a message
# TODO: the first message is already in the prompt as the task,
# so we don't need to include it in the history
return json.dumps(processed_history, **kwargs)
def __init__(self, llm: LLM, config: AgentConfig):

View File

@ -112,12 +112,16 @@ class AgentController:
self.id = sid
self.agent = agent
self.headless_mode = headless_mode
self.is_delegate = is_delegate
# subscribe to the event stream
# the event stream must be set before maybe subscribing to it
self.event_stream = event_stream
self.event_stream.subscribe(
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
)
# subscribe to the event stream if this is not a delegate
if not self.is_delegate:
self.event_stream.subscribe(
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
)
# state from the previous session, state from a parent agent, or a fresh state
self.set_initial_state(
@ -165,7 +169,11 @@ class AgentController:
)
# unsubscribe from the event stream
self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER, self.id)
# only the root parent controller subscribes to the event stream
if not self.is_delegate:
self.event_stream.unsubscribe(
EventStreamSubscriber.AGENT_CONTROLLER, self.id
)
self._closed = True
def log(self, level: str, message: str, extra: dict | None = None) -> None:
@ -226,9 +234,21 @@ class AgentController:
await self._react_to_exception(reported)
def should_step(self, event: Event) -> bool:
# it might be the delegate's day in the sun
if self.delegate is not None:
return False
if isinstance(event, Action):
if isinstance(event, MessageAction) and event.source == EventSource.USER:
return True
if (
isinstance(event, MessageAction)
and self.get_agent_state() != AgentState.AWAITING_USER_INPUT
):
# TODO: this is fragile, but how else to check if eligible?
return True
if isinstance(event, AgentDelegateAction):
return True
return False
if isinstance(event, Observation):
if isinstance(event, NullObservation) or isinstance(
@ -244,12 +264,35 @@ class AgentController:
Args:
event (Event): The incoming event to process.
"""
# If we have a delegate that is not finished or errored, forward events to it
if self.delegate is not None:
delegate_state = self.delegate.get_agent_state()
if delegate_state not in (
AgentState.FINISHED,
AgentState.ERROR,
AgentState.REJECTED,
):
# Forward the event to delegate and skip parent processing
asyncio.get_event_loop().run_until_complete(
self.delegate._on_event(event)
)
return
else:
# delegate is done or errored, so end it
self.end_delegate()
return
# continue parent processing only if there's no active delegate
asyncio.get_event_loop().run_until_complete(self._on_event(event))
async def _on_event(self, event: Event) -> None:
if hasattr(event, 'hidden') and event.hidden:
return
# Give others a little chance
await asyncio.sleep(0.01)
# if the event is not filtered out, add it to the history
if not any(isinstance(event, filter_type) for filter_type in self.filter_out):
self.state.history.append(event)
@ -263,17 +306,22 @@ class AgentController:
self.step()
async def _handle_action(self, action: Action) -> None:
"""Handles actions from the event stream.
Args:
action (Action): The action to handle.
"""
"""Handles an Action from the agent or delegate."""
if isinstance(action, ChangeAgentStateAction):
await self.set_agent_state_to(action.agent_state) # type: ignore
elif isinstance(action, MessageAction):
await self._handle_message_action(action)
elif isinstance(action, AgentDelegateAction):
await self.start_delegate(action)
assert self.delegate is not None
# Post a MessageAction with the task for the delegate
if 'task' in action.inputs:
self.event_stream.add_event(
MessageAction(content='TASK: ' + action.inputs['task']),
EventSource.USER,
)
await self.delegate.set_agent_state_to(AgentState.RUNNING)
return
elif isinstance(action, AgentFinishAction):
self.state.outputs = action.outputs
@ -491,7 +539,7 @@ class AgentController:
f'start delegate, creating agent {delegate_agent.name} using LLM {llm}',
)
self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER, self.id)
# Create the delegate with is_delegate=True so it does NOT subscribe directly
self.delegate = AgentController(
sid=self.id + '-delegate',
agent=delegate_agent,
@ -504,7 +552,57 @@ class AgentController:
is_delegate=True,
headless_mode=self.headless_mode,
)
await self.delegate.set_agent_state_to(AgentState.RUNNING)
def end_delegate(self) -> None:
    """Tear down the active delegate and report its outcome on the event stream.

    Does nothing when there is no delegate. Otherwise, in this order:

    1. the delegate's iteration counter is copied into this controller's state,
    2. the delegate controller is closed,
    3. one ``AgentDelegateObservation`` carrying the delegate's outputs is added
       to the event stream with ``EventSource.AGENT``,
    4. ``delegate`` and ``delegateAction`` are reset to ``None`` so this
       controller can resume normal operation.

    A delegate in ``FINISHED`` or ``REJECTED`` state is reported with a summary
    of its outputs; any other state is reported as an error.
    """
    delegate = self.delegate
    if delegate is None:
        return

    # read the state before closing, since closing may change it
    delegate_state = delegate.get_agent_state()

    # update iteration that is shared across agents
    self.state.iteration = delegate.state.iteration

    # close the delegate controller before adding new events
    asyncio.get_event_loop().run_until_complete(delegate.close())

    # both outcomes report the same outputs; only the message text differs
    delegate_outputs = delegate.state.outputs if delegate.state else {}

    if delegate_state in (AgentState.FINISHED, AgentState.REJECTED):
        # TODO: replace this with AI-generated summary (#2395)
        formatted_output = ', '.join(
            f'{key}: {value}' for key, value in delegate_outputs.items()
        )
        content = f'{delegate.agent.name} finishes task with {formatted_output}'
    else:
        # every other state is reported as an error
        content = f'{delegate.agent.name} encountered an error during execution.'

    # emit the delegate result observation
    obs = AgentDelegateObservation(outputs=delegate_outputs, content=content)
    self.event_stream.add_event(obs, EventSource.AGENT)

    # unset delegate so parent can resume normal handling
    self.delegate = None
    self.delegateAction = None
async def _step(self) -> None:
"""Executes a single step of the parent or delegate agent. Detects stuck agents and limits on the number of iterations and the task budget."""
@ -514,14 +612,6 @@ class AgentController:
if self._pending_action:
return
if self.delegate is not None:
assert self.delegate != self
# TODO this conditional will always be false, because the parent controllers are unsubscribed
# remove if it's still useless when delegation is reworked
if self.delegate.get_agent_state() != AgentState.PAUSED:
await self._delegate_step()
return
self.log(
'info',
f'LEVEL {self.state.delegate_level} LOCAL STEP {self.state.local_iteration} GLOBAL STEP {self.state.iteration}',
@ -611,68 +701,6 @@ class AgentController:
log_level = 'info' if LOG_ALL_EVENTS else 'debug'
self.log(log_level, str(action), extra={'msg_type': 'ACTION'})
async def _delegate_step(self) -> None:
"""Executes a single step of the delegate agent."""
await self.delegate._step() # type: ignore[union-attr]
assert self.delegate is not None
delegate_state = self.delegate.get_agent_state()
self.log('debug', f'Delegate state: {delegate_state}')
if delegate_state == AgentState.ERROR:
# update iteration that shall be shared across agents
self.state.iteration = self.delegate.state.iteration
# emit AgentDelegateObservation to mark delegate termination due to error
delegate_outputs = (
self.delegate.state.outputs if self.delegate.state else {}
)
content = (
f'{self.delegate.agent.name} encountered an error during execution.'
)
obs = AgentDelegateObservation(outputs=delegate_outputs, content=content)
self.event_stream.add_event(obs, EventSource.AGENT)
# close the delegate upon error
await self.delegate.close()
# resubscribe parent when delegate is finished
self.event_stream.subscribe(
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
)
self.delegate = None
self.delegateAction = None
elif delegate_state in (AgentState.FINISHED, AgentState.REJECTED):
self.log('debug', 'Delegate agent has finished execution')
# retrieve delegate result
outputs = self.delegate.state.outputs if self.delegate.state else {}
# update iteration that shall be shared across agents
self.state.iteration = self.delegate.state.iteration
# close delegate controller: we must close the delegate controller before adding new events
await self.delegate.close()
# resubscribe parent when delegate is finished
self.event_stream.subscribe(
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
)
# update delegate result observation
# TODO: replace this with AI-generated summary (#2395)
formatted_output = ', '.join(
f'{key}: {value}' for key, value in outputs.items()
)
content = (
f'{self.delegate.agent.name} finishes task with {formatted_output}'
)
obs = AgentDelegateObservation(outputs=outputs, content=content)
# clean up delegate status
self.delegate = None
self.delegateAction = None
self.event_stream.add_event(obs, EventSource.AGENT)
return
async def _handle_traffic_control(
self, limit_type: str, current_value: float, max_value: float
) -> bool:

View File

@ -65,6 +65,7 @@ class EventStream:
_queue: queue.Queue[Event]
_queue_thread: threading.Thread
_queue_loop: asyncio.AbstractEventLoop | None
_thread_pools: dict[str, dict[str, ThreadPoolExecutor]]
_thread_loops: dict[str, dict[str, asyncio.AbstractEventLoop]]
def __init__(self, sid: str, file_store: FileStore):
@ -72,8 +73,8 @@ class EventStream:
self.file_store = file_store
self._stop_flag = threading.Event()
self._queue: queue.Queue[Event] = queue.Queue()
self._thread_pools: dict[str, dict[str, ThreadPoolExecutor]] = {}
self._thread_loops: dict[str, dict[str, asyncio.AbstractEventLoop]] = {}
self._thread_pools = {}
self._thread_loops = {}
self._queue_loop = None
self._queue_thread = threading.Thread(target=self._run_queue_loop)
self._queue_thread.daemon = True
@ -257,7 +258,7 @@ class EventStream:
def add_event(self, event: Event, source: EventSource):
if hasattr(event, '_id') and event.id is not None:
raise ValueError(
'Event already has an ID. It was probably added back to the EventStream from inside a handler, trigging a loop.'
f'Event already has an ID:{event.id}. It was probably added back to the EventStream from inside a handler, triggering a loop.'
)
with self._lock:
event._id = self._cur_id # type: ignore [attr-defined]
@ -285,6 +286,8 @@ class EventStream:
event = self._queue.get(timeout=0.1)
except queue.Empty:
continue
# pass each event to each callback in order
for key in sorted(self._subscribers.keys()):
callbacks = self._subscribers[key]
for callback_id in callbacks:

View File

@ -125,7 +125,7 @@ class Runtime(FileEditRuntimeMixin):
def setup_initial_env(self) -> None:
if self.attach_to_existing:
return
logger.debug(f'Adding env vars: {self.initial_env_vars}')
logger.debug(f'Adding env vars: {self.initial_env_vars.keys()}')
self.add_env_vars(self.initial_env_vars)
if self.config.sandbox.runtime_startup_env_vars:
self.add_env_vars(self.config.sandbox.runtime_startup_env_vars)
@ -172,7 +172,7 @@ class Runtime(FileEditRuntimeMixin):
obs = self.run(CmdRunAction(cmd))
if not isinstance(obs, CmdOutputObservation) or obs.exit_code != 0:
raise RuntimeError(
f'Failed to add env vars [{env_vars}] to environment: {obs.content}'
f'Failed to add env vars [{env_vars.keys()}] to environment: {obs.content}'
)
def on_event(self, event: Event) -> None:

View File

@ -1,5 +1,5 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock, Mock
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
import pytest
@ -130,7 +130,7 @@ async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_cal
@pytest.mark.asyncio
async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
async def test_run_controller_with_fatal_error():
config = AppConfig()
file_store = get_file_store(config.file_store, config.file_store_path)
event_stream = EventStream(sid='test', file_store=file_store)
@ -239,55 +239,6 @@ async def test_run_controller_stop_with_stuck():
assert state.last_error == 'AgentStuckInLoopError: Agent got stuck in a loop'
@pytest.mark.asyncio
@pytest.mark.parametrize(
'delegate_state',
[
AgentState.RUNNING,
AgentState.FINISHED,
AgentState.ERROR,
AgentState.REJECTED,
],
)
async def test_delegate_step_different_states(
mock_agent, mock_event_stream, delegate_state
):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)
mock_delegate = AsyncMock()
controller.delegate = mock_delegate
mock_delegate.state.iteration = 5
mock_delegate.state.outputs = {'result': 'test'}
mock_delegate.agent.name = 'TestDelegate'
mock_delegate.get_agent_state = Mock(return_value=delegate_state)
mock_delegate._step = AsyncMock()
mock_delegate.close = AsyncMock()
await controller._delegate_step()
mock_delegate._step.assert_called_once()
if delegate_state == AgentState.RUNNING:
assert controller.delegate is not None
assert controller.state.iteration == 0
mock_delegate.close.assert_not_called()
else:
assert controller.delegate is None
assert controller.state.iteration == 5
mock_delegate.close.assert_called_once()
await controller.close()
@pytest.mark.asyncio
async def test_max_iterations_extension(mock_agent, mock_event_stream):
# Test with headless_mode=False - should extend max_iterations

View File

@ -0,0 +1,187 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import AsyncMock, MagicMock, Mock
from uuid import uuid4
import pytest
from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController
from openhands.controller.state.state import State
from openhands.core.config import LLMConfig
from openhands.core.config.agent_config import AgentConfig
from openhands.core.schema import AgentState
from openhands.events import EventSource, EventStream
from openhands.events.action import (
AgentDelegateAction,
AgentFinishAction,
MessageAction,
)
from openhands.llm.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.storage.memory import InMemoryFileStore
@pytest.fixture
def mock_event_stream():
    """Provide an EventStream with a unique session id and an in-memory file store."""
    return EventStream(
        sid=f'test-{uuid4()}',
        file_store=InMemoryFileStore({}),
    )
@pytest.fixture
def mock_parent_agent():
    """Build the stand-in parent agent used by the delegation tests."""
    parent = MagicMock(spec=Agent)
    parent.name = 'ParentAgent'

    # the LLM is a spec'd mock, but its metrics and config are real objects
    parent_llm = MagicMock(spec=LLM)
    parent.llm = parent_llm
    parent_llm.metrics = Metrics()
    parent_llm.config = LLMConfig()

    parent.config = AgentConfig()
    return parent
@pytest.fixture
def mock_child_agent():
    """Build the stand-in child agent that the parent delegates to."""
    child = MagicMock(spec=Agent)
    child.name = 'ChildAgent'

    # attach a spec'd LLM mock, then give it real metrics and config objects
    child.llm = MagicMock(spec=LLM)
    llm = child.llm
    llm.metrics = Metrics()
    llm.config = LLMConfig()

    child.config = AgentConfig()
    return child
@pytest.mark.asyncio
async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_stream):
    """
    Test that when the parent agent delegates to a child, the parent's delegate
    is set, and once the child finishes, the parent is cleaned up properly.
    """
    # Imported locally so this test does not depend on the file-level import list.
    from unittest.mock import patch

    # Mock the agent class resolution so that AgentController can instantiate
    # mock_child_agent. patch.object restores the real Agent.get_cls on exit;
    # a bare assignment would leak the mock into every later test.
    fake_get_cls = Mock(return_value=lambda llm, config: mock_child_agent)
    with patch.object(Agent, 'get_cls', fake_get_cls):
        # Create parent controller
        parent_state = State(max_iterations=10)
        parent_controller = AgentController(
            agent=mock_parent_agent,
            event_stream=mock_event_stream,
            max_iterations=10,
            sid='parent',
            confirmation_mode=False,
            headless_mode=True,
            initial_state=parent_state,
        )

        try:
            # Setup a delegate action from the parent
            delegate_action = AgentDelegateAction(
                agent='ChildAgent', inputs={'test': True}
            )
            mock_parent_agent.step.return_value = delegate_action

            # Simulate a user message event to cause parent.step() to run
            message_action = MessageAction(content='please delegate now')
            message_action._source = EventSource.USER
            await parent_controller._on_event(message_action)

            # Give time for the async step() to execute
            await asyncio.sleep(1)

            # The parent should receive step() from that event
            # Verify that a delegate agent controller is created
            assert (
                parent_controller.delegate is not None
            ), "Parent's delegate controller was not set."

            # The parent's iteration should have incremented
            assert (
                parent_controller.state.iteration == 1
            ), 'Parent iteration should be incremented after step.'

            # Now simulate that the child increments local iteration and
            # finishes its subtask
            delegate_controller = parent_controller.delegate
            delegate_controller.state.iteration = 5  # child had some steps
            delegate_controller.state.outputs = {'delegate_result': 'done'}

            # The child is done, so we simulate it finishing:
            child_finish_action = AgentFinishAction()
            await delegate_controller._on_event(child_finish_action)
            await asyncio.sleep(0.5)

            # Now the parent's delegate is None
            assert (
                parent_controller.delegate is None
            ), 'Parent delegate should be None after child finishes.'

            # Parent's global iteration is updated from the child
            assert (
                parent_controller.state.iteration == 6
            ), "Parent iteration should be the child's iteration + 1 after child is done."
        finally:
            # Cleanup even when an assertion above fails, so the controller
            # is not left open for the rest of the session
            await parent_controller.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'delegate_state',
    [
        AgentState.RUNNING,
        AgentState.FINISHED,
        AgentState.ERROR,
        AgentState.REJECTED,
    ],
)
async def test_delegate_step_different_states(
    mock_parent_agent, mock_event_stream, delegate_state
):
    """Check that an incoming event keeps or tears down the delegate depending on its state."""
    controller = AgentController(
        agent=mock_parent_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    # Install a fully mocked delegate that reports the parametrized state.
    stub_delegate = AsyncMock()
    controller.delegate = stub_delegate
    stub_delegate.state.iteration = 5
    stub_delegate.state.outputs = {'result': 'test'}
    stub_delegate.agent.name = 'TestDelegate'
    stub_delegate.get_agent_state = Mock(return_value=delegate_state)
    stub_delegate._step = AsyncMock()
    stub_delegate.close = AsyncMock()

    def deliver_event_in_worker():
        """
        controller.on_event(...) calls run_until_complete(), so the calling
        thread needs its own event loop: create one here, make it current,
        deliver a user message, and close the loop afterwards.
        """
        worker_loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(worker_loop)
            user_message = MessageAction(content='Test message')
            user_message._source = EventSource.USER
            controller.on_event(user_message)
        finally:
            worker_loop.close()

    # Run the synchronous on_event off the test's own running loop.
    running_loop = asyncio.get_running_loop()
    with ThreadPoolExecutor() as pool:
        await running_loop.run_in_executor(pool, deliver_event_in_worker)

    if delegate_state != AgentState.RUNNING:
        # finished / errored / rejected delegates are ended by the parent
        assert controller.delegate is None
        assert controller.state.iteration == 5
        stub_delegate.close.assert_called_once()
    else:
        # a running delegate stays attached and the parent state is untouched
        assert controller.delegate is not None
        assert controller.state.iteration == 0
        stub_delegate.close.assert_not_called()

    await controller.close()