OpenHands/opendevin/controller/agent_controller.py
Boxuan Li dd32fa6f4a
Unify linter behaviour across CI and pre-commit-hook (#1071)
* CI: Add autopep8 linter

Currently, we have autopep8 as part of pre-commit-hook. To ensure
consistent behaviour, we should have it in CI as well.

Moreover, pre-commit-hook contains a double-quote-string-fixer hook
which changes all double quotes to single quotes, but I have observed
some PRs with massive changes that go the opposite way. I suspect
that these authors 1) disable or circumvent the pre-commit-hook,
and 2) have other linters such as black in their IDE, which
automatically change all single quotes to double quotes. This
has caused a lot of unnecessary diffs, made reviews really hard,
and led to a lot of conflicts.

* Use -diff for autopep8

* autopep8: Freeze version in CI

* Ultimate fix

* Remove pep8 long line disable workaround

* Fix lint.yml

* Fix all files under opendevin and agenthub
2024-04-14 00:19:56 -04:00

213 lines
6.8 KiB
Python

import asyncio
import inspect
import traceback
from typing import List, Callable, Literal, Mapping, Awaitable, Any, cast
from termcolor import colored
from opendevin import config
from opendevin.action import (
Action,
NullAction,
AgentFinishAction,
AddTaskAction,
ModifyTaskAction,
)
from opendevin.agent import Agent
from opendevin.logger import opendevin_logger as logger
from opendevin.exceptions import MaxCharsExceedError
from opendevin.observation import Observation, AgentErrorObservation, NullObservation
from opendevin.plan import Plan
from opendevin.state import State
from .command_manager import CommandManager
# Colour names that print_with_color may pass to termcolor's ``colored``.
# Members are listed in their original order.
ColorType = Literal[
    'red', 'green', 'yellow', 'blue', 'magenta', 'cyan',
    'light_grey', 'dark_grey',
    'light_red', 'light_green', 'light_yellow',
    'light_blue', 'light_magenta', 'light_cyan',
    'white',
]
# Import-time settings read once from opendevin.config.
#
# Colour output is disabled only when DISABLE_COLOR is the string 'true'
# (case-insensitive). str() guards against config.get returning None or a
# non-string value (e.g. a boolean) -- assumed possible, verify against
# opendevin.config -- which previously made `.lower()` raise AttributeError
# and broke importing this module. String values behave exactly as before.
DISABLE_COLOR_PRINTING = (
    str(config.get('DISABLE_COLOR')).lower() == 'true'
)
# Defaults for AgentController's `max_iterations` / `max_chars` arguments.
MAX_ITERATIONS = config.get('MAX_ITERATIONS')
MAX_CHARS = config.get('MAX_CHARS')
def print_with_color(text: Any, print_type: str = 'INFO'):
    """Print *text* under an upper-cased ``print_type`` heading, flushing at once.

    The colour is picked from ``print_type`` (case-insensitive); unknown
    types fall back to the colour used for 'INFO' (cyan). The heading is
    bold. When DISABLE_COLOR_PRINTING is set, the same heading and text
    are printed without any colouring.
    """
    palette: Mapping[str, ColorType] = {
        'BACKGROUND LOG': 'blue',
        'ACTION': 'green',
        'OBSERVATION': 'yellow',
        'INFO': 'cyan',
        'ERROR': 'red',
        'PLAN': 'light_magenta',
    }
    label = print_type.upper()
    body = str(text)
    if not DISABLE_COLOR_PRINTING:
        tint = palette.get(label, palette['INFO'])
        heading = colored(f'\n{label}:\n', tint, attrs=['bold'])
        print(heading + colored(body, tint), flush=True)
        return
    print(f'\n{label}:\n{body}', flush=True)
class AgentController:
    """Drive an ``Agent`` through a task, one step at a time.

    Each step asks the agent for an ``Action``, applies plan-editing
    actions (AddTaskAction / ModifyTaskAction) to ``self.state.plan``,
    runs executable actions, records the resulting ``(action, observation)``
    pair in ``self.state.history``, and forwards every action and
    observation to the registered callbacks.
    """

    id: str
    agent: Agent
    max_iterations: int
    max_chars: int
    workdir: str
    command_manager: CommandManager
    callbacks: List[Callable]
    # Created by start_loop(); step() and the state helpers require it.
    state: State

    def __init__(
        self,
        agent: Agent,
        workdir: str,
        sid: str = '',
        max_iterations: int = MAX_ITERATIONS,
        max_chars: int = MAX_CHARS,
        container_image: str | None = None,
        callbacks: List[Callable] | None = None,
    ):
        """Set up the controller and its CommandManager.

        Args:
            agent: agent whose ``step`` is called once per iteration.
            workdir: working directory, passed to ``CommandManager``.
            sid: session id; stored as ``self.id`` and passed to
                ``CommandManager``.
            max_iterations: maximum number of steps ``start_loop`` runs.
            max_chars: ``step`` raises ``MaxCharsExceedError`` once
                ``state.num_of_chars`` exceeds this value.
            container_image: passed through to ``CommandManager``.
            callbacks: callables invoked with every action/observation.
                ``None`` (the default) means a fresh empty list.
        """
        self.id = sid
        self.agent = agent
        self.max_iterations = max_iterations
        self.max_chars = max_chars
        self.workdir = workdir
        self.command_manager = CommandManager(self.id, workdir, container_image)
        # A fresh list per controller: the old shared `[]` default leaked
        # callbacks registered on one controller into every other one.
        self.callbacks = [] if callbacks is None else callbacks

    def update_state_for_step(self, i):
        """Record iteration *i* and refresh background command output on the state."""
        self.state.iteration = i
        # NOTE(review): step() calls get_background_obs() just before this;
        # if that call consumes pending output, this one may see little or
        # nothing -- verify against CommandManager.
        self.state.background_commands_obs = self.command_manager.get_background_obs()

    def update_state_after_step(self):
        """Clear ``state.updated_info`` once the agent has consumed it.

        Called right after ``agent.step``, so the list only holds pairs
        added since the agent's previous step.
        """
        self.state.updated_info = []

    def add_history(self, action: Action, observation: Observation):
        """Append ``(action, observation)`` to the history and the pending updates.

        Raises:
            ValueError: if either argument has the wrong type.
        """
        if not isinstance(action, Action):
            raise ValueError('action must be an instance of Action')
        if not isinstance(observation, Observation):
            raise ValueError('observation must be an instance of Observation')
        self.state.history.append((action, observation))
        self.state.updated_info.append((action, observation))

    async def start_loop(self, task: str):
        """Run the agent on *task* until it finishes or ``max_iterations`` is hit.

        A fresh ``State`` built from ``Plan(task)`` replaces any previous
        state. Errors raised by ``step`` are logged and re-raised.
        """
        finished = False
        plan = Plan(task)
        self.state = State(plan)
        for i in range(self.max_iterations):
            try:
                finished = await self.step(i)
            except Exception:
                logger.error('Error in loop', exc_info=True)
                raise  # bare raise keeps the original traceback untouched
            if finished:
                break
        if not finished:
            logger.info('Exited before finishing the task.')

    def _report_error(self, error: Exception) -> AgentErrorObservation:
        """Wrap *error* in an observation, print it and the current traceback.

        Must be called from inside an ``except`` clause so that
        ``traceback.print_exc`` has an exception to print.
        """
        observation = AgentErrorObservation(str(error))
        print_with_color(observation, 'ERROR')
        traceback.print_exc()
        return observation

    async def step(self, i: int) -> bool:
        """Run iteration *i*: query the agent, apply/execute its action, record it.

        Returns:
            True if the agent returned an ``AgentFinishAction``, else False.

        Raises:
            MaxCharsExceedError: if ``state.num_of_chars`` exceeds ``max_chars``.
            Exception: errors from ``agent.step`` whose message matches one of
                the API-key errors below are re-raised; other agent errors are
                turned into an ``AgentErrorObservation``.
        """
        print('\n\n==============', flush=True)
        print('STEP', i, flush=True)
        print_with_color(self.state.plan.main_goal, 'PLAN')

        if self.state.num_of_chars > self.max_chars:
            raise MaxCharsExceedError(self.state.num_of_chars, self.max_chars)

        log_obs = self.command_manager.get_background_obs()
        for obs in log_obs:
            self.add_history(NullAction(), obs)
            await self._run_callbacks(obs)
            print_with_color(obs, 'BACKGROUND LOG')

        self.update_state_for_step(i)
        action: Action = NullAction()
        observation: Observation = NullObservation('')
        try:
            # Bind to a temporary first: if the agent returns None, `action`
            # must stay a NullAction -- otherwise `action.executable` and
            # add_history() below would fail on None.
            next_action = self.agent.step(self.state)
            if next_action is None:
                raise ValueError('Agent must return an action')
            action = next_action
            print_with_color(action, 'ACTION')
        except Exception as e:
            observation = self._report_error(e)
            # TODO Change to more robust error handling
            if (
                'The api_key client option must be set' in observation.content
                or 'Incorrect API key provided:' in observation.content
            ):
                raise

        self.update_state_after_step()

        await self._run_callbacks(action)

        if isinstance(action, AgentFinishAction):
            print_with_color(action, 'INFO')
            return True

        if isinstance(action, AddTaskAction):
            try:
                self.state.plan.add_subtask(action.parent, action.goal, action.subtasks)
            except Exception as e:
                observation = self._report_error(e)
        elif isinstance(action, ModifyTaskAction):
            try:
                self.state.plan.set_subtask_state(action.id, action.state)
            except Exception as e:
                observation = self._report_error(e)

        if action.executable:
            try:
                # Call run() exactly once: the old code called it a second
                # time after the isawaitable() check, executing sync actions
                # twice and leaking an un-awaited coroutine for async ones.
                result = action.run(self)
                if inspect.isawaitable(result):
                    observation = await cast(Awaitable[Observation], result)
                else:
                    observation = result
            except Exception as e:
                observation = self._report_error(e)

        if not isinstance(observation, NullObservation):
            print_with_color(observation, 'OBSERVATION')

        self.add_history(action, observation)
        await self._run_callbacks(observation)
        return False

    async def _run_callbacks(self, event):
        """Invoke every callback with *event*; a failing callback is logged and skipped.

        ``None`` events are ignored.
        """
        if event is None:
            return
        for idx, callback in enumerate(self.callbacks):
            try:
                callback(event)
            except Exception as e:
                logger.exception(f'Callback error: {e}, idx: {idx}')
        # Give back control for a tick, so we can await in callbacks
        await asyncio.sleep(0.001)