refactor the logic in agent_controller to improve readability (#3873)

Signed-off-by: Yi Lin <teroincn@gmail.com>
This commit is contained in:
niliy01 2024-09-17 02:13:52 +08:00 committed by GitHub
parent 41a54378dc
commit 804674bb9f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 226 additions and 146 deletions

View File

@ -172,56 +172,83 @@ class AgentController:
Args:
event (Event): The incoming event to process.
"""
if isinstance(event, ChangeAgentStateAction):
await self.set_agent_state_to(event.agent_state) # type: ignore
elif isinstance(event, MessageAction):
if event.source == EventSource.USER:
logger.info(
event,
extra={'msg_type': 'ACTION', 'event_source': EventSource.USER},
)
if self.get_agent_state() != AgentState.RUNNING:
await self.set_agent_state_to(AgentState.RUNNING)
elif event.source == EventSource.AGENT and event.wait_for_response:
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
elif isinstance(event, AgentDelegateAction):
await self.start_delegate(event)
elif isinstance(event, AddTaskAction):
self.state.root_task.add_subtask(event.parent, event.goal, event.subtasks)
elif isinstance(event, ModifyTaskAction):
self.state.root_task.set_subtask_state(event.task_id, event.state)
elif isinstance(event, AgentFinishAction):
self.state.outputs = event.outputs
if isinstance(event, Action):
await self._handle_action(event)
elif isinstance(event, Observation):
await self._handle_observation(event)
async def _handle_action(self, action: Action):
"""Handles actions from the event stream.
Args:
action (Action): The action to handle.
"""
if isinstance(action, ChangeAgentStateAction):
await self.set_agent_state_to(action.agent_state) # type: ignore
elif isinstance(action, MessageAction):
await self._handle_message_action(action)
elif isinstance(action, AgentDelegateAction):
await self.start_delegate(action)
elif isinstance(action, AddTaskAction):
self.state.root_task.add_subtask(
action.parent, action.goal, action.subtasks
)
elif isinstance(action, ModifyTaskAction):
self.state.root_task.set_subtask_state(action.task_id, action.state)
elif isinstance(action, AgentFinishAction):
self.state.outputs = action.outputs
self.state.metrics.merge(self.state.local_metrics)
await self.set_agent_state_to(AgentState.FINISHED)
elif isinstance(event, AgentRejectAction):
self.state.outputs = event.outputs
elif isinstance(action, AgentRejectAction):
self.state.outputs = action.outputs
self.state.metrics.merge(self.state.local_metrics)
await self.set_agent_state_to(AgentState.REJECTED)
elif isinstance(event, Observation):
if (
self._pending_action
and hasattr(self._pending_action, 'is_confirmed')
and self._pending_action.is_confirmed
== ActionConfirmationStatus.AWAITING_CONFIRMATION
):
return
if self._pending_action and self._pending_action.id == event.cause:
self._pending_action = None
if self.state.agent_state == AgentState.USER_CONFIRMED:
await self.set_agent_state_to(AgentState.RUNNING)
if self.state.agent_state == AgentState.USER_REJECTED:
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
logger.info(event, extra={'msg_type': 'OBSERVATION'})
elif isinstance(event, CmdOutputObservation):
logger.info(event, extra={'msg_type': 'OBSERVATION'})
elif isinstance(event, AgentDelegateObservation):
self.state.history.on_event(event)
logger.info(event, extra={'msg_type': 'OBSERVATION'})
elif isinstance(event, ErrorObservation):
logger.info(event, extra={'msg_type': 'OBSERVATION'})
if self.state.agent_state == AgentState.ERROR:
self.state.metrics.merge(self.state.local_metrics)
async def _handle_observation(self, observation: Observation):
    """Process an observation received from the event stream.

    While the pending action is still awaiting user confirmation the
    observation is ignored entirely (it is not even logged). Otherwise it
    is logged, and:

    * if it was caused by the pending action, that action is cleared and a
      USER_CONFIRMED state moves to RUNNING, a USER_REJECTED state moves to
      AWAITING_USER_INPUT;
    * a delegate observation is recorded in the state history;
    * an error observation seen while the agent is in the ERROR state
      merges the local metrics into the global metrics.

    Args:
        observation (Observation): The observation to handle.
    """
    pending = self._pending_action
    if (
        pending
        and getattr(pending, 'is_confirmed', None)
        == ActionConfirmationStatus.AWAITING_CONFIRMATION
    ):
        # The user has not answered yet; nothing may be processed.
        return

    logger.info(observation, extra={'msg_type': 'OBSERVATION'})

    if pending and pending.id == observation.cause:
        # This observation is the result of the action we were waiting on.
        self._pending_action = None
        # The state is re-read for each check because the first transition
        # may itself change it.
        if self.state.agent_state == AgentState.USER_CONFIRMED:
            await self.set_agent_state_to(AgentState.RUNNING)
        if self.state.agent_state == AgentState.USER_REJECTED:
            await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
        return

    if isinstance(observation, CmdOutputObservation):
        # Command output needs no handling beyond the log line above.
        return

    if isinstance(observation, AgentDelegateObservation):
        self.state.history.on_event(observation)
        return

    if (
        isinstance(observation, ErrorObservation)
        and self.state.agent_state == AgentState.ERROR
    ):
        self.state.metrics.merge(self.state.local_metrics)
async def _handle_message_action(self, action: MessageAction):
    """React to a message action taken from the event stream.

    A message from the user is logged and puts the agent into RUNNING
    unless it is already running. A message from the agent that asks for a
    response puts the agent into AWAITING_USER_INPUT. Any other message is
    left untouched.

    Args:
        action (MessageAction): The message action to handle.
    """
    if action.source == EventSource.USER:
        logger.info(
            action, extra={'msg_type': 'ACTION', 'event_source': EventSource.USER}
        )
        already_running = self.get_agent_state() == AgentState.RUNNING
        if not already_running:
            await self.set_agent_state_to(AgentState.RUNNING)
        return

    if action.source == EventSource.AGENT and action.wait_for_response:
        await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
def reset_task(self):
"""Resets the agent's task."""
@ -242,9 +269,11 @@ class AgentController:
if new_state == self.state.agent_state:
return
if (
self.state.agent_state == AgentState.PAUSED
and new_state == AgentState.RUNNING
if new_state == AgentState.STOPPED or new_state == AgentState.ERROR:
self.reset_task()
elif (
new_state == AgentState.RUNNING
and self.state.agent_state == AgentState.PAUSED
and self.state.traffic_control_state == TrafficControlState.THROTTLING
):
# user intends to interrupt traffic control and let the task resume temporarily
@ -257,6 +286,7 @@ class AgentController:
):
if self.state.iteration >= self.state.max_iterations:
self.state.max_iterations += self._initial_max_iterations
if (
self.state.metrics.accumulated_cost is not None
and self.max_budget_per_task is not None
@ -264,12 +294,7 @@ class AgentController:
):
if self.state.metrics.accumulated_cost >= self.max_budget_per_task:
self.max_budget_per_task += self._initial_max_budget_per_task
self.state.agent_state = new_state
if new_state == AgentState.STOPPED or new_state == AgentState.ERROR:
self.reset_task()
if self._pending_action is not None and (
elif self._pending_action is not None and (
new_state == AgentState.USER_CONFIRMED
or new_state == AgentState.USER_REJECTED
):
@ -281,6 +306,7 @@ class AgentController:
self._pending_action.is_confirmed = ActionConfirmationStatus.REJECTED # type: ignore[attr-defined]
self.event_stream.add_event(self._pending_action, EventSource.AGENT)
self.state.agent_state = new_state
self.event_stream.add_event(
AgentStateChangedObservation('', self.state.agent_state), EventSource.AGENT
)
@ -355,56 +381,8 @@ class AgentController:
return
if self.delegate is not None:
logger.debug(f'[Agent Controller {self.id}] Delegate not none, awaiting...')
assert self.delegate != self
await self.delegate._step()
logger.debug(f'[Agent Controller {self.id}] Delegate step done')
assert self.delegate is not None
delegate_state = self.delegate.get_agent_state()
logger.debug(
f'[Agent Controller {self.id}] Delegate state: {delegate_state}'
)
if delegate_state == AgentState.ERROR:
# update iteration that shall be shared across agents
self.state.iteration = self.delegate.state.iteration
# close the delegate upon error
await self.delegate.close()
self.delegate = None
self.delegateAction = None
await self.report_error('Delegator agent encounters an error')
return
delegate_done = delegate_state in (AgentState.FINISHED, AgentState.REJECTED)
if delegate_done:
logger.info(
f'[Agent Controller {self.id}] Delegate agent has finished execution'
)
# retrieve delegate result
outputs = self.delegate.state.outputs if self.delegate.state else {}
# update iteration that shall be shared across agents
self.state.iteration = self.delegate.state.iteration
# close delegate controller: we must close the delegate controller before adding new events
await self.delegate.close()
# update delegate result observation
# TODO: replace this with AI-generated summary (#2395)
formatted_output = ', '.join(
f'{key}: {value}' for key, value in outputs.items()
)
content = (
f'{self.delegate.agent.name} finishes task with {formatted_output}'
)
obs: Observation = AgentDelegateObservation(
outputs=outputs, content=content
)
# clean up delegate status
self.delegate = None
self.delegateAction = None
self.event_stream.add_event(obs, EventSource.AGENT)
await self._delegate_step()
return
logger.info(
@ -412,50 +390,20 @@ class AgentController:
extra={'msg_type': 'STEP'},
)
# check if agent hit the resources limit
stop_step = False
if self.state.iteration >= self.state.max_iterations:
if self.state.traffic_control_state == TrafficControlState.PAUSED:
logger.info(
'Hitting traffic control, temporarily resume upon user request'
)
self.state.traffic_control_state = TrafficControlState.NORMAL
else:
self.state.traffic_control_state = TrafficControlState.THROTTLING
if self.headless_mode:
# set to ERROR state if running in headless mode
# since user cannot resume on the web interface
await self.report_error(
'Agent reached maximum number of iterations in headless mode, task stopped.'
)
await self.set_agent_state_to(AgentState.ERROR)
else:
await self.report_error(
f'Agent reached maximum number of iterations, task paused. {TRAFFIC_CONTROL_REMINDER}'
)
await self.set_agent_state_to(AgentState.PAUSED)
return
elif self.max_budget_per_task is not None:
stop_step = await self._handle_traffic_control(
'iteration', self.state.iteration, self.state.max_iterations
)
if self.max_budget_per_task is not None:
current_cost = self.state.metrics.accumulated_cost
if current_cost > self.max_budget_per_task:
if self.state.traffic_control_state == TrafficControlState.PAUSED:
logger.info(
'Hitting traffic control, temporarily resume upon user request'
)
self.state.traffic_control_state = TrafficControlState.NORMAL
else:
self.state.traffic_control_state = TrafficControlState.THROTTLING
if self.headless_mode:
# set to ERROR state if running in headless mode
# there is no way to resume
await self.report_error(
f'Task budget exceeded. Current cost: {current_cost:.2f}, max budget: {self.max_budget_per_task:.2f}, task stopped.'
)
await self.set_agent_state_to(AgentState.ERROR)
else:
await self.report_error(
f'Task budget exceeded. Current cost: {current_cost:.2f}, Max budget: {self.max_budget_per_task:.2f}, task paused. {TRAFFIC_CONTROL_REMINDER}'
)
await self.set_agent_state_to(AgentState.PAUSED)
return
stop_step = await self._handle_traffic_control(
'budget', current_cost, self.max_budget_per_task
)
if stop_step:
return
self.update_state_before_step()
action: Action = NullAction()
@ -492,6 +440,89 @@ class AgentController:
await self.report_error('Agent got stuck in a loop')
await self.set_agent_state_to(AgentState.ERROR)
async def _delegate_step(self):
    """Run one step of the delegate agent and react to its resulting state.

    * ERROR: the shared iteration counter is copied back, the delegate is
      closed and detached, and an error is reported.
    * FINISHED / REJECTED: the delegate's outputs are collected, the shared
      iteration counter is copied back, the delegate is closed and
      detached, and an AgentDelegateObservation summarising the outputs is
      added to the event stream.
    * Any other state: the delegate is left in place.
    """
    logger.debug(f'[Agent Controller {self.id}] Delegate not none, awaiting...')
    await self.delegate._step()  # type: ignore[union-attr]
    logger.debug(f'[Agent Controller {self.id}] Delegate step done')
    assert self.delegate is not None

    delegate_state = self.delegate.get_agent_state()
    logger.debug(f'[Agent Controller {self.id}] Delegate state: {delegate_state}')

    if delegate_state == AgentState.ERROR:
        # The iteration count is shared across agents, so take it over
        # before the delegate is closed upon error.
        self.state.iteration = self.delegate.state.iteration
        await self.delegate.close()
        self.delegate = None
        self.delegateAction = None
        await self.report_error('Delegator agent encounters an error')
        return

    if delegate_state not in (AgentState.FINISHED, AgentState.REJECTED):
        # Delegate is still working; nothing to collect yet.
        return

    logger.info(
        f'[Agent Controller {self.id}] Delegate agent has finished execution'
    )
    # Retrieve the delegate result.
    outputs = self.delegate.state.outputs if self.delegate.state else {}
    # The iteration count is shared across agents.
    self.state.iteration = self.delegate.state.iteration
    # The delegate controller must be closed before adding new events.
    await self.delegate.close()

    # Build the delegate result observation.
    # TODO: replace this with AI-generated summary (#2395)
    formatted_output = ', '.join(f'{key}: {value}' for key, value in outputs.items())
    obs: Observation = AgentDelegateObservation(
        outputs=outputs,
        content=f'{self.delegate.agent.name} finishes task with {formatted_output}',
    )

    # Clean up delegate status, then publish the result.
    self.delegate = None
    self.delegateAction = None
    self.event_stream.add_event(obs, EventSource.AGENT)
async def _handle_traffic_control(
    self, limit_type: str, current_value: float, max_value: float
):
    """Update the agent state after a traffic control limit has been hit.

    If traffic control is currently PAUSED it is reset to NORMAL and the
    step is allowed to continue. Otherwise traffic control is switched to
    THROTTLING, an error is reported, and the agent is moved to ERROR in
    headless mode or to PAUSED otherwise.

    Args:
        limit_type (str): The type of limit that was hit.
        current_value (float): The current value of the limit.
        max_value (float): The maximum value of the limit.

    Returns:
        bool: True when the caller should stop the current step, False when
        the step may continue.
    """
    if self.state.traffic_control_state == TrafficControlState.PAUSED:
        logger.info('Hitting traffic control, temporarily resume upon user request')
        self.state.traffic_control_state = TrafficControlState.NORMAL
        return False

    self.state.traffic_control_state = TrafficControlState.THROTTLING
    usage = (
        f'Current {limit_type}: {current_value:.2f}, '
        f'max {limit_type}: {max_value:.2f}'
    )
    if self.headless_mode:
        # ERROR rather than PAUSED in headless mode, since the user cannot
        # resume on the web interface.
        message = (
            f'Agent reached maximum {limit_type} in headless mode, task stopped. '
            f'{usage}'
        )
        next_state = AgentState.ERROR
    else:
        message = (
            f'Agent reached maximum {limit_type}, task paused. '
            f'{usage}. '
            f'{TRAFFIC_CONTROL_REMINDER}'
        )
        next_state = AgentState.PAUSED

    await self.report_error(message)
    await self.set_agent_state_to(next_state)
    return True
def get_state(self):
"""Returns the current running state object.

View File

@ -1,5 +1,5 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, Mock
import pytest
@ -123,6 +123,55 @@ async def test_step_with_exception(mock_agent, mock_event_stream):
await controller.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'delegate_state',
    [
        AgentState.RUNNING,
        AgentState.FINISHED,
        AgentState.ERROR,
        AgentState.REJECTED,
    ],
)
async def test_delegate_step_different_states(
    mock_agent, mock_event_stream, delegate_state
):
    """Check how `_delegate_step` treats the delegate for each resulting state.

    A delegate that is still RUNNING must stay attached and untouched; for
    FINISHED, ERROR and REJECTED it must be closed and detached, and its
    iteration count copied to the parent controller.
    """
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    # Configure the fake delegate completely before attaching it.
    delegate = AsyncMock()
    delegate.state.iteration = 5
    delegate.state.outputs = {'result': 'test'}
    delegate.agent.name = 'TestDelegate'
    # get_agent_state is called synchronously, so it needs a plain Mock.
    delegate.get_agent_state = Mock(return_value=delegate_state)
    delegate._step = AsyncMock()
    delegate.close = AsyncMock()
    controller.delegate = delegate

    await controller._delegate_step()

    delegate._step.assert_called_once()

    still_running = delegate_state == AgentState.RUNNING
    if still_running:
        assert controller.delegate is not None
        assert controller.state.iteration == 0
        delegate.close.assert_not_called()
    else:
        assert controller.delegate is None
        assert controller.state.iteration == 5
        delegate.close.assert_called_once()

    await controller.close()
@pytest.mark.asyncio
async def test_step_max_iterations(mock_agent, mock_event_stream):
controller = AgentController(