mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
refactor the logic in agent_controller to imporve readability (#3873)
Signed-off-by: Yi Lin <teroincn@gmail.com>
This commit is contained in:
parent
41a54378dc
commit
804674bb9f
@ -172,56 +172,83 @@ class AgentController:
|
||||
Args:
|
||||
event (Event): The incoming event to process.
|
||||
"""
|
||||
if isinstance(event, ChangeAgentStateAction):
|
||||
await self.set_agent_state_to(event.agent_state) # type: ignore
|
||||
elif isinstance(event, MessageAction):
|
||||
if event.source == EventSource.USER:
|
||||
logger.info(
|
||||
event,
|
||||
extra={'msg_type': 'ACTION', 'event_source': EventSource.USER},
|
||||
)
|
||||
if self.get_agent_state() != AgentState.RUNNING:
|
||||
await self.set_agent_state_to(AgentState.RUNNING)
|
||||
elif event.source == EventSource.AGENT and event.wait_for_response:
|
||||
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
|
||||
elif isinstance(event, AgentDelegateAction):
|
||||
await self.start_delegate(event)
|
||||
elif isinstance(event, AddTaskAction):
|
||||
self.state.root_task.add_subtask(event.parent, event.goal, event.subtasks)
|
||||
elif isinstance(event, ModifyTaskAction):
|
||||
self.state.root_task.set_subtask_state(event.task_id, event.state)
|
||||
elif isinstance(event, AgentFinishAction):
|
||||
self.state.outputs = event.outputs
|
||||
if isinstance(event, Action):
|
||||
await self._handle_action(event)
|
||||
elif isinstance(event, Observation):
|
||||
await self._handle_observation(event)
|
||||
|
||||
async def _handle_action(self, action: Action):
|
||||
"""Handles actions from the event stream.
|
||||
|
||||
Args:
|
||||
action (Action): The action to handle.
|
||||
"""
|
||||
if isinstance(action, ChangeAgentStateAction):
|
||||
await self.set_agent_state_to(action.agent_state) # type: ignore
|
||||
elif isinstance(action, MessageAction):
|
||||
await self._handle_message_action(action)
|
||||
elif isinstance(action, AgentDelegateAction):
|
||||
await self.start_delegate(action)
|
||||
elif isinstance(action, AddTaskAction):
|
||||
self.state.root_task.add_subtask(
|
||||
action.parent, action.goal, action.subtasks
|
||||
)
|
||||
elif isinstance(action, ModifyTaskAction):
|
||||
self.state.root_task.set_subtask_state(action.task_id, action.state)
|
||||
elif isinstance(action, AgentFinishAction):
|
||||
self.state.outputs = action.outputs
|
||||
self.state.metrics.merge(self.state.local_metrics)
|
||||
await self.set_agent_state_to(AgentState.FINISHED)
|
||||
elif isinstance(event, AgentRejectAction):
|
||||
self.state.outputs = event.outputs
|
||||
elif isinstance(action, AgentRejectAction):
|
||||
self.state.outputs = action.outputs
|
||||
self.state.metrics.merge(self.state.local_metrics)
|
||||
await self.set_agent_state_to(AgentState.REJECTED)
|
||||
elif isinstance(event, Observation):
|
||||
if (
|
||||
self._pending_action
|
||||
and hasattr(self._pending_action, 'is_confirmed')
|
||||
and self._pending_action.is_confirmed
|
||||
== ActionConfirmationStatus.AWAITING_CONFIRMATION
|
||||
):
|
||||
return
|
||||
if self._pending_action and self._pending_action.id == event.cause:
|
||||
self._pending_action = None
|
||||
if self.state.agent_state == AgentState.USER_CONFIRMED:
|
||||
await self.set_agent_state_to(AgentState.RUNNING)
|
||||
if self.state.agent_state == AgentState.USER_REJECTED:
|
||||
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
|
||||
logger.info(event, extra={'msg_type': 'OBSERVATION'})
|
||||
elif isinstance(event, CmdOutputObservation):
|
||||
logger.info(event, extra={'msg_type': 'OBSERVATION'})
|
||||
elif isinstance(event, AgentDelegateObservation):
|
||||
self.state.history.on_event(event)
|
||||
logger.info(event, extra={'msg_type': 'OBSERVATION'})
|
||||
elif isinstance(event, ErrorObservation):
|
||||
logger.info(event, extra={'msg_type': 'OBSERVATION'})
|
||||
if self.state.agent_state == AgentState.ERROR:
|
||||
self.state.metrics.merge(self.state.local_metrics)
|
||||
|
||||
async def _handle_observation(self, observation: Observation):
|
||||
"""Handles observation from the event stream.
|
||||
|
||||
Args:
|
||||
observation (observation): The observation to handle.
|
||||
"""
|
||||
if (
|
||||
self._pending_action
|
||||
and hasattr(self._pending_action, 'is_confirmed')
|
||||
and self._pending_action.is_confirmed
|
||||
== ActionConfirmationStatus.AWAITING_CONFIRMATION
|
||||
):
|
||||
return
|
||||
|
||||
logger.info(observation, extra={'msg_type': 'OBSERVATION'})
|
||||
if self._pending_action and self._pending_action.id == observation.cause:
|
||||
self._pending_action = None
|
||||
if self.state.agent_state == AgentState.USER_CONFIRMED:
|
||||
await self.set_agent_state_to(AgentState.RUNNING)
|
||||
if self.state.agent_state == AgentState.USER_REJECTED:
|
||||
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
|
||||
return
|
||||
|
||||
if isinstance(observation, CmdOutputObservation):
|
||||
return
|
||||
elif isinstance(observation, AgentDelegateObservation):
|
||||
self.state.history.on_event(observation)
|
||||
elif isinstance(observation, ErrorObservation):
|
||||
if self.state.agent_state == AgentState.ERROR:
|
||||
self.state.metrics.merge(self.state.local_metrics)
|
||||
|
||||
async def _handle_message_action(self, action: MessageAction):
|
||||
"""Handles message actions from the event stream.
|
||||
|
||||
Args:
|
||||
action (MessageAction): The message action to handle.
|
||||
"""
|
||||
if action.source == EventSource.USER:
|
||||
logger.info(
|
||||
action, extra={'msg_type': 'ACTION', 'event_source': EventSource.USER}
|
||||
)
|
||||
if self.get_agent_state() != AgentState.RUNNING:
|
||||
await self.set_agent_state_to(AgentState.RUNNING)
|
||||
elif action.source == EventSource.AGENT and action.wait_for_response:
|
||||
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
|
||||
|
||||
def reset_task(self):
|
||||
"""Resets the agent's task."""
|
||||
@ -242,9 +269,11 @@ class AgentController:
|
||||
if new_state == self.state.agent_state:
|
||||
return
|
||||
|
||||
if (
|
||||
self.state.agent_state == AgentState.PAUSED
|
||||
and new_state == AgentState.RUNNING
|
||||
if new_state == AgentState.STOPPED or new_state == AgentState.ERROR:
|
||||
self.reset_task()
|
||||
elif (
|
||||
new_state == AgentState.RUNNING
|
||||
and self.state.agent_state == AgentState.PAUSED
|
||||
and self.state.traffic_control_state == TrafficControlState.THROTTLING
|
||||
):
|
||||
# user intends to interrupt traffic control and let the task resume temporarily
|
||||
@ -257,6 +286,7 @@ class AgentController:
|
||||
):
|
||||
if self.state.iteration >= self.state.max_iterations:
|
||||
self.state.max_iterations += self._initial_max_iterations
|
||||
|
||||
if (
|
||||
self.state.metrics.accumulated_cost is not None
|
||||
and self.max_budget_per_task is not None
|
||||
@ -264,12 +294,7 @@ class AgentController:
|
||||
):
|
||||
if self.state.metrics.accumulated_cost >= self.max_budget_per_task:
|
||||
self.max_budget_per_task += self._initial_max_budget_per_task
|
||||
|
||||
self.state.agent_state = new_state
|
||||
if new_state == AgentState.STOPPED or new_state == AgentState.ERROR:
|
||||
self.reset_task()
|
||||
|
||||
if self._pending_action is not None and (
|
||||
elif self._pending_action is not None and (
|
||||
new_state == AgentState.USER_CONFIRMED
|
||||
or new_state == AgentState.USER_REJECTED
|
||||
):
|
||||
@ -281,6 +306,7 @@ class AgentController:
|
||||
self._pending_action.is_confirmed = ActionConfirmationStatus.REJECTED # type: ignore[attr-defined]
|
||||
self.event_stream.add_event(self._pending_action, EventSource.AGENT)
|
||||
|
||||
self.state.agent_state = new_state
|
||||
self.event_stream.add_event(
|
||||
AgentStateChangedObservation('', self.state.agent_state), EventSource.AGENT
|
||||
)
|
||||
@ -355,56 +381,8 @@ class AgentController:
|
||||
return
|
||||
|
||||
if self.delegate is not None:
|
||||
logger.debug(f'[Agent Controller {self.id}] Delegate not none, awaiting...')
|
||||
assert self.delegate != self
|
||||
await self.delegate._step()
|
||||
logger.debug(f'[Agent Controller {self.id}] Delegate step done')
|
||||
assert self.delegate is not None
|
||||
delegate_state = self.delegate.get_agent_state()
|
||||
logger.debug(
|
||||
f'[Agent Controller {self.id}] Delegate state: {delegate_state}'
|
||||
)
|
||||
if delegate_state == AgentState.ERROR:
|
||||
# update iteration that shall be shared across agents
|
||||
self.state.iteration = self.delegate.state.iteration
|
||||
|
||||
# close the delegate upon error
|
||||
await self.delegate.close()
|
||||
self.delegate = None
|
||||
self.delegateAction = None
|
||||
|
||||
await self.report_error('Delegator agent encounters an error')
|
||||
return
|
||||
delegate_done = delegate_state in (AgentState.FINISHED, AgentState.REJECTED)
|
||||
if delegate_done:
|
||||
logger.info(
|
||||
f'[Agent Controller {self.id}] Delegate agent has finished execution'
|
||||
)
|
||||
# retrieve delegate result
|
||||
outputs = self.delegate.state.outputs if self.delegate.state else {}
|
||||
|
||||
# update iteration that shall be shared across agents
|
||||
self.state.iteration = self.delegate.state.iteration
|
||||
|
||||
# close delegate controller: we must close the delegate controller before adding new events
|
||||
await self.delegate.close()
|
||||
|
||||
# update delegate result observation
|
||||
# TODO: replace this with AI-generated summary (#2395)
|
||||
formatted_output = ', '.join(
|
||||
f'{key}: {value}' for key, value in outputs.items()
|
||||
)
|
||||
content = (
|
||||
f'{self.delegate.agent.name} finishes task with {formatted_output}'
|
||||
)
|
||||
obs: Observation = AgentDelegateObservation(
|
||||
outputs=outputs, content=content
|
||||
)
|
||||
|
||||
# clean up delegate status
|
||||
self.delegate = None
|
||||
self.delegateAction = None
|
||||
self.event_stream.add_event(obs, EventSource.AGENT)
|
||||
await self._delegate_step()
|
||||
return
|
||||
|
||||
logger.info(
|
||||
@ -412,50 +390,20 @@ class AgentController:
|
||||
extra={'msg_type': 'STEP'},
|
||||
)
|
||||
|
||||
# check if agent hit the resources limit
|
||||
stop_step = False
|
||||
if self.state.iteration >= self.state.max_iterations:
|
||||
if self.state.traffic_control_state == TrafficControlState.PAUSED:
|
||||
logger.info(
|
||||
'Hitting traffic control, temporarily resume upon user request'
|
||||
)
|
||||
self.state.traffic_control_state = TrafficControlState.NORMAL
|
||||
else:
|
||||
self.state.traffic_control_state = TrafficControlState.THROTTLING
|
||||
if self.headless_mode:
|
||||
# set to ERROR state if running in headless mode
|
||||
# since user cannot resume on the web interface
|
||||
await self.report_error(
|
||||
'Agent reached maximum number of iterations in headless mode, task stopped.'
|
||||
)
|
||||
await self.set_agent_state_to(AgentState.ERROR)
|
||||
else:
|
||||
await self.report_error(
|
||||
f'Agent reached maximum number of iterations, task paused. {TRAFFIC_CONTROL_REMINDER}'
|
||||
)
|
||||
await self.set_agent_state_to(AgentState.PAUSED)
|
||||
return
|
||||
elif self.max_budget_per_task is not None:
|
||||
stop_step = await self._handle_traffic_control(
|
||||
'iteration', self.state.iteration, self.state.max_iterations
|
||||
)
|
||||
if self.max_budget_per_task is not None:
|
||||
current_cost = self.state.metrics.accumulated_cost
|
||||
if current_cost > self.max_budget_per_task:
|
||||
if self.state.traffic_control_state == TrafficControlState.PAUSED:
|
||||
logger.info(
|
||||
'Hitting traffic control, temporarily resume upon user request'
|
||||
)
|
||||
self.state.traffic_control_state = TrafficControlState.NORMAL
|
||||
else:
|
||||
self.state.traffic_control_state = TrafficControlState.THROTTLING
|
||||
if self.headless_mode:
|
||||
# set to ERROR state if running in headless mode
|
||||
# there is no way to resume
|
||||
await self.report_error(
|
||||
f'Task budget exceeded. Current cost: {current_cost:.2f}, max budget: {self.max_budget_per_task:.2f}, task stopped.'
|
||||
)
|
||||
await self.set_agent_state_to(AgentState.ERROR)
|
||||
else:
|
||||
await self.report_error(
|
||||
f'Task budget exceeded. Current cost: {current_cost:.2f}, Max budget: {self.max_budget_per_task:.2f}, task paused. {TRAFFIC_CONTROL_REMINDER}'
|
||||
)
|
||||
await self.set_agent_state_to(AgentState.PAUSED)
|
||||
return
|
||||
stop_step = await self._handle_traffic_control(
|
||||
'budget', current_cost, self.max_budget_per_task
|
||||
)
|
||||
if stop_step:
|
||||
return
|
||||
|
||||
self.update_state_before_step()
|
||||
action: Action = NullAction()
|
||||
@ -492,6 +440,89 @@ class AgentController:
|
||||
await self.report_error('Agent got stuck in a loop')
|
||||
await self.set_agent_state_to(AgentState.ERROR)
|
||||
|
||||
async def _delegate_step(self):
|
||||
"""Executes a single step of the delegate agent."""
|
||||
logger.debug(f'[Agent Controller {self.id}] Delegate not none, awaiting...')
|
||||
await self.delegate._step() # type: ignore[union-attr]
|
||||
logger.debug(f'[Agent Controller {self.id}] Delegate step done')
|
||||
assert self.delegate is not None
|
||||
delegate_state = self.delegate.get_agent_state()
|
||||
logger.debug(f'[Agent Controller {self.id}] Delegate state: {delegate_state}')
|
||||
if delegate_state == AgentState.ERROR:
|
||||
# update iteration that shall be shared across agents
|
||||
self.state.iteration = self.delegate.state.iteration
|
||||
|
||||
# close the delegate upon error
|
||||
await self.delegate.close()
|
||||
self.delegate = None
|
||||
self.delegateAction = None
|
||||
|
||||
await self.report_error('Delegator agent encounters an error')
|
||||
elif delegate_state in (AgentState.FINISHED, AgentState.REJECTED):
|
||||
logger.info(
|
||||
f'[Agent Controller {self.id}] Delegate agent has finished execution'
|
||||
)
|
||||
# retrieve delegate result
|
||||
outputs = self.delegate.state.outputs if self.delegate.state else {}
|
||||
|
||||
# update iteration that shall be shared across agents
|
||||
self.state.iteration = self.delegate.state.iteration
|
||||
|
||||
# close delegate controller: we must close the delegate controller before adding new events
|
||||
await self.delegate.close()
|
||||
|
||||
# update delegate result observation
|
||||
# TODO: replace this with AI-generated summary (#2395)
|
||||
formatted_output = ', '.join(
|
||||
f'{key}: {value}' for key, value in outputs.items()
|
||||
)
|
||||
content = (
|
||||
f'{self.delegate.agent.name} finishes task with {formatted_output}'
|
||||
)
|
||||
obs: Observation = AgentDelegateObservation(
|
||||
outputs=outputs, content=content
|
||||
)
|
||||
|
||||
# clean up delegate status
|
||||
self.delegate = None
|
||||
self.delegateAction = None
|
||||
self.event_stream.add_event(obs, EventSource.AGENT)
|
||||
return
|
||||
|
||||
async def _handle_traffic_control(
|
||||
self, limit_type: str, current_value: float, max_value: float
|
||||
):
|
||||
"""Handles agent state after hitting the traffic control limit.
|
||||
|
||||
Args:
|
||||
limit_type (str): The type of limit that was hit.
|
||||
current_value (float): The current value of the limit.
|
||||
max_value (float): The maximum value of the limit.
|
||||
"""
|
||||
stop_step = False
|
||||
if self.state.traffic_control_state == TrafficControlState.PAUSED:
|
||||
logger.info('Hitting traffic control, temporarily resume upon user request')
|
||||
self.state.traffic_control_state = TrafficControlState.NORMAL
|
||||
else:
|
||||
self.state.traffic_control_state = TrafficControlState.THROTTLING
|
||||
if self.headless_mode:
|
||||
# set to ERROR state if running in headless mode
|
||||
# since user cannot resume on the web interface
|
||||
await self.report_error(
|
||||
f'Agent reached maximum {limit_type} in headless mode, task stopped. '
|
||||
f'Current {limit_type}: {current_value:.2f}, max {limit_type}: {max_value:.2f}'
|
||||
)
|
||||
await self.set_agent_state_to(AgentState.ERROR)
|
||||
else:
|
||||
await self.report_error(
|
||||
f'Agent reached maximum {limit_type}, task paused. '
|
||||
f'Current {limit_type}: {current_value:.2f}, max {limit_type}: {max_value:.2f}. '
|
||||
f'{TRAFFIC_CONTROL_REMINDER}'
|
||||
)
|
||||
await self.set_agent_state_to(AgentState.PAUSED)
|
||||
stop_step = True
|
||||
return stop_step
|
||||
|
||||
def get_state(self):
|
||||
"""Returns the current running state object.
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
@ -123,6 +123,55 @@ async def test_step_with_exception(mock_agent, mock_event_stream):
|
||||
await controller.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'delegate_state',
|
||||
[
|
||||
AgentState.RUNNING,
|
||||
AgentState.FINISHED,
|
||||
AgentState.ERROR,
|
||||
AgentState.REJECTED,
|
||||
],
|
||||
)
|
||||
async def test_delegate_step_different_states(
|
||||
mock_agent, mock_event_stream, delegate_state
|
||||
):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
mock_delegate = AsyncMock()
|
||||
controller.delegate = mock_delegate
|
||||
|
||||
mock_delegate.state.iteration = 5
|
||||
mock_delegate.state.outputs = {'result': 'test'}
|
||||
mock_delegate.agent.name = 'TestDelegate'
|
||||
|
||||
mock_delegate.get_agent_state = Mock(return_value=delegate_state)
|
||||
mock_delegate._step = AsyncMock()
|
||||
mock_delegate.close = AsyncMock()
|
||||
|
||||
await controller._delegate_step()
|
||||
|
||||
mock_delegate._step.assert_called_once()
|
||||
|
||||
if delegate_state == AgentState.RUNNING:
|
||||
assert controller.delegate is not None
|
||||
assert controller.state.iteration == 0
|
||||
mock_delegate.close.assert_not_called()
|
||||
else:
|
||||
assert controller.delegate is None
|
||||
assert controller.state.iteration == 5
|
||||
mock_delegate.close.assert_called_once()
|
||||
|
||||
await controller.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_max_iterations(mock_agent, mock_event_stream):
|
||||
controller = AgentController(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user