OpenHands/opendevin/controller/agent_controller.py
Leo 1356da8795
feat: support controlling agent task state. (#1094)
* feat: support controlling agent task state.

* feat: add agent task state to agent status bar.

* feat: add agent task control bar to FE.

* Remove stop agent task action.

* Merge pause and resume buttons into one button; Add loading and disabled status for action buttons.

* Apply suggestions from code review

---------

Co-authored-by: Robert Brennan <accounts@rbren.io>
2024-04-18 11:09:00 +00:00

219 lines
7.4 KiB
Python

import asyncio
import time
from typing import List, Callable
from opendevin.plan import Plan
from opendevin.state import State
from opendevin.agent import Agent
from opendevin.observation import Observation, AgentErrorObservation, NullObservation
from litellm.exceptions import APIConnectionError
from openai import AuthenticationError
from opendevin import config
from opendevin.logger import opendevin_logger as logger
from opendevin.exceptions import MaxCharsExceedError
from .action_manager import ActionManager
from opendevin.action import (
Action,
NullAction,
AgentFinishAction,
)
from opendevin.exceptions import AgentNoActionError
from ..action.tasks import TaskStateChangedAction
from ..schema import TaskState
# Limits read from the project config once, at import time. They serve as the
# default `max_iterations` / `max_chars` arguments of AgentController.__init__.
# NOTE(review): presumably ints (the parameters they default are annotated
# `int`) -- confirm against what `config.get` actually returns.
MAX_ITERATIONS = config.get('MAX_ITERATIONS')
MAX_CHARS = config.get('MAX_CHARS')
class AgentController:
id: str
agent: Agent
max_iterations: int
action_manager: ActionManager
callbacks: List[Callable]
state: State | None = None
_task_state: TaskState = TaskState.INIT
_finished: bool = False
_cur_step: int = 0
def __init__(
self,
agent: Agent,
sid: str = '',
max_iterations: int = MAX_ITERATIONS,
max_chars: int = MAX_CHARS,
container_image: str | None = None,
callbacks: List[Callable] = [],
):
self.id = sid
self.agent = agent
self.max_iterations = max_iterations
self.action_manager = ActionManager(self.id, container_image)
self.max_chars = max_chars
self.callbacks = callbacks
def update_state_for_step(self, i):
if self.state is None:
return
self.state.iteration = i
self.state.background_commands_obs = self.action_manager.get_background_obs()
def update_state_after_step(self):
if self.state is None:
return
self.state.updated_info = []
def add_history(self, action: Action, observation: Observation):
if self.state is None:
return
if not isinstance(action, Action):
raise TypeError(
f'action must be an instance of Action, got {type(action).__name__} instead'
)
if not isinstance(observation, Observation):
raise TypeError(
f'observation must be an instance of Observation, got {type(observation).__name__} instead'
)
self.state.history.append((action, observation))
self.state.updated_info.append((action, observation))
async def _run(self):
if self.state is None:
return
if self._task_state != TaskState.RUNNING:
raise ValueError('Task is not in running state')
for i in range(self._cur_step, self.max_iterations):
try:
self._finished = await self.step(i)
except Exception as e:
logger.error('Error in loop', exc_info=True)
raise e
match self._task_state:
case TaskState.FINISHED, TaskState.STOPPED:
await self.reset_task() # type: ignore[unreachable]
break
case TaskState.PAUSED:
# save current state for resuming
self._cur_step = i + 1 # type: ignore[unreachable]
await self.notify_task_state_changed()
break
if not self._finished:
logger.info('Exited before finishing the task.')
self.agent.reset()
async def start(self, task: str):
"""Starts the agent controller with a task.
If task already run before, it will continue from the last step.
"""
self._task_state = TaskState.RUNNING
await self.notify_task_state_changed()
self.state = State(Plan(task))
await self._run()
async def resume(self):
if self.state is None:
raise ValueError('No task to resume')
self._task_state = TaskState.RUNNING
await self.notify_task_state_changed()
await self._run()
async def reset_task(self):
self.state = None
self._cur_step = 0
self._task_state = TaskState.INIT
await self.notify_task_state_changed()
async def set_task_state_to(self, state: TaskState):
self._task_state = state
if state == TaskState.STOPPED:
await self.reset_task()
logger.info(f'Task state set to {state}')
def get_task_state(self):
"""Returns the current state of the agent task."""
return self._task_state
async def notify_task_state_changed(self):
await self._run_callbacks(TaskStateChangedAction(self._task_state))
async def step(self, i: int):
if self.state is None:
return
logger.info(f'STEP {i}', extra={'msg_type': 'STEP'})
logger.info(self.state.plan.main_goal, extra={'msg_type': 'PLAN'})
if self.state.num_of_chars > self.max_chars:
raise MaxCharsExceedError(self.state.num_of_chars, self.max_chars)
log_obs = self.action_manager.get_background_obs()
for obs in log_obs:
self.add_history(NullAction(), obs)
await self._run_callbacks(obs)
logger.info(obs, extra={'msg_type': 'BACKGROUND LOG'})
self.update_state_for_step(i)
action: Action = NullAction()
observation: Observation = NullObservation('')
try:
action = self.agent.step(self.state)
if action is None:
raise AgentNoActionError()
logger.info(action, extra={'msg_type': 'ACTION'})
except Exception as e:
observation = AgentErrorObservation(str(e))
logger.error(e)
if isinstance(e, APIConnectionError):
time.sleep(3)
# raise specific exceptions that need to be handled outside
# note: we are using AuthenticationError class from openai rather than
# litellm because:
# 1) litellm.exceptions.AuthenticationError is a subclass of openai.AuthenticationError
# 2) embeddings call, initiated by llama-index, has no wrapper for authentication
# errors. This means we have to catch individual authentication errors
# from different providers, and OpenAI is one of these.
if isinstance(e, (AuthenticationError, AgentNoActionError)):
raise
self.update_state_after_step()
await self._run_callbacks(action)
finished = isinstance(action, AgentFinishAction)
if finished:
logger.info(action, extra={'msg_type': 'INFO'})
return True
if isinstance(observation, NullObservation):
observation = await self.action_manager.run_action(action, self)
if not isinstance(observation, NullObservation):
logger.info(observation, extra={'msg_type': 'OBSERVATION'})
self.add_history(action, observation)
await self._run_callbacks(observation)
async def _run_callbacks(self, event):
if event is None:
return
for callback in self.callbacks:
idx = self.callbacks.index(callback)
try:
await callback(event)
except Exception as e:
logger.exception(f'Callback error: {e}, idx: {idx}')
await asyncio.sleep(
0.001
) # Give back control for a tick, so we can await in callbacks