tobitege bc31fb15fe
(fix) CodeActAgent: fix issues with vision support in prompts (#3665)
* CodeActAgent: fix message prep if prompt caching is not supported

* fix python version in regen tests workflow

* fix in conftest "mock_completion" method

* add disable_vision to LLMConfig; revert change in message parsing in llm.py

* format messages in several files for completion

* refactored message(s) formatting (llm.py); added vision_is_active()

* fix a unit test

* regenerate: added LOG_TO_FILE and FORCE_REGENERATE env flags

* try to fix path to logs folder in workflow

* llm: prevent index error

* try FORCE_USE_LLM in regenerate

* tweaks everywhere...

* fix 2 random unit test errors :(

* added FORCE_REGENERATE_TESTS=true to regenerate CLI

* fix test_lint_file_fail_typescript again

* double-quotes for env vars in workflow; llm logger set to debug

* fix typo in regenerate

* regenerate iterations now 20; applied iteration counter fix by Li

* regenerate: pass FORCE_REGENERATE flag into env

* fixes for int tests. several mock files updated.

* browsing_agent: fix response_parser.py adding ) to empty response

* test_browse_internet: fix skipif and revert obsolete mock files

* regenerate: fix bracketing for http server start/kill conditions

* disable test_browse_internet for CodeAct*Agents; mock files updated after merge

* missed to include more mock files earlier

* reverts after review feedback from Li

* forgot one

* browsing agent test, partial fixes and updated mock files

* test_browse_internet works in my WSL now!

* adapt unit test test_prompt_caching.py

* add DEBUG to regenerate workflow command

* convert regenerate workflow params to inputs

* more integration test mock files updated

* more files

* test_prompt_caching: restored test_prompt_caching_headers purpose

* file_ops: fix potential exception, like "cross device copy"; fixed mock files accordingly

* reverts/changes wrt feedback from xingyao

* updated docs and config template

* code cleanup wrt review feedback
2024-09-04 17:58:30 +02:00

203 lines
7.3 KiB
Python

import atexit
import copy
import json
import os
import shlex
from abc import abstractmethod

from openhands.core.config import AppConfig, SandboxConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventSource, EventStream, EventStreamSubscriber
from openhands.events.action import (
    Action,
    ActionConfirmationStatus,
    BrowseInteractiveAction,
    BrowseURLAction,
    CmdRunAction,
    FileReadAction,
    FileWriteAction,
    IPythonRunCellAction,
)
from openhands.events.event import Event
from openhands.events.observation import (
    CmdOutputObservation,
    ErrorObservation,
    NullObservation,
    Observation,
    UserRejectObservation,
)
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
from openhands.runtime.plugins import JupyterRequirement, PluginRequirement
def _default_env_vars(sandbox_config: SandboxConfig) -> dict[str, str]:
    """Build the environment variables applied to a sandbox by default.

    Every host variable whose name starts with ``SANDBOX_ENV_`` is forwarded
    with that prefix removed. ``ENABLE_AUTO_LINT`` is set to ``'true'`` when
    ``sandbox_config.enable_auto_lint`` is truthy.

    Args:
        sandbox_config: Sandbox settings; only ``enable_auto_lint`` is read.

    Returns:
        A new dict mapping variable names to their values.
    """
    prefix = 'SANDBOX_ENV_'
    forwarded = {
        name.removeprefix(prefix): value
        for name, value in os.environ.items()
        if name.startswith(prefix)
    }
    if sandbox_config.enable_auto_lint:
        forwarded['ENABLE_AUTO_LINT'] = 'true'
    return forwarded
class Runtime:
    """The runtime is how the agent interacts with the external environment.

    This includes a bash sandbox, a browser, and filesystem interactions.

    sid is the session id, which is used to identify the current user session.

    NOTE(review): the action and file-operation methods are decorated with
    ``@abstractmethod``, but this class does not use ``ABCMeta``, so Python
    does not enforce that subclasses override them.
    """

    sid: str
    config: AppConfig
    DEFAULT_ENV_VARS: dict[str, str]

    def __init__(
        self,
        config: AppConfig,
        event_stream: EventStream,
        sid: str = 'default',
        plugins: list[PluginRequirement] | None = None,
        env_vars: dict[str, str] | None = None,
    ):
        """Subscribe to the event stream and push the initial env vars.

        Args:
            config: Application config; a deep copy is stored on the instance.
            event_stream: Stream the runtime subscribes to and publishes
                observations on.
            sid: Session id.
            plugins: Plugin requirements; ``None`` or empty becomes ``[]``.
            env_vars: Extra environment variables added after the defaults.

        Raises:
            RuntimeError: If exporting the environment variables fails
                (see ``add_env_vars``).
        """
        self.sid = sid
        self.event_stream = event_stream
        self.event_stream.subscribe(EventStreamSubscriber.RUNTIME, self.on_event)
        self.plugins = plugins or []
        self.config = copy.deepcopy(config)
        self.DEFAULT_ENV_VARS = _default_env_vars(config.sandbox)
        atexit.register(self.close)
        logger.debug(f'Runtime `{sid}`')

        if self.DEFAULT_ENV_VARS:
            logger.debug(f'Adding default env vars: {self.DEFAULT_ENV_VARS}')
            self.add_env_vars(self.DEFAULT_ENV_VARS)
        if env_vars is not None:
            logger.debug(f'Adding provided env vars: {env_vars}')
            self.add_env_vars(env_vars)

    def close(self) -> None:
        """Release runtime resources. The base implementation does nothing."""
        pass

    # ====================================================================

    @staticmethod
    def _ipython_env_code(env_vars: dict[str, str]) -> str:
        """Return Python source that assigns each env var into ``os.environ``."""
        lines = ['import os\n']
        for key, value in env_vars.items():
            # repr() yields a valid Python string literal for both the key
            # and the value, whatever quotes or non-ASCII text they contain.
            lines.append(f'os.environ[{str(key)!r}] = {str(value)!r}\n')
        lines.append('\n')
        return ''.join(lines)

    @staticmethod
    def _bash_export_command(env_vars: dict[str, str]) -> str:
        """Return a single-line bash command exporting every env var."""
        # shlex.quote produces a literal shell word, so `$`, backticks,
        # quotes, newlines and non-ASCII text reach the variable unchanged.
        # NOTE(review): keys are inserted as-is and must be valid shell
        # identifiers; bash rejects the export otherwise.
        return ' '.join(
            f'export {key}={shlex.quote(str(value))};'
            for key, value in env_vars.items()
        )

    def add_env_vars(self, env_vars: dict[str, str]) -> None:
        """Export environment variables into the runtime's shells.

        The variables are set in the IPython shell when a
        ``JupyterRequirement`` plugin is present, and always exported in the
        bash shell. An empty mapping is a no-op.

        Args:
            env_vars: Mapping of variable name to value.

        Raises:
            RuntimeError: If the bash export does not yield a
                ``CmdOutputObservation`` with exit code 0.
        """
        if not env_vars:
            return

        # Add env vars to the IPython shell (if Jupyter is used)
        if any(isinstance(plugin, JupyterRequirement) for plugin in self.plugins):
            code = self._ipython_env_code(env_vars)
            obs = self.run_ipython(IPythonRunCellAction(code))
            logger.info(f'Added env vars to IPython: code={code}, obs={obs}')

        # Add env vars to the Bash shell
        cmd = self._bash_export_command(env_vars)
        logger.debug(f'Adding env var: {cmd}')
        obs = self.run(CmdRunAction(cmd))
        if not isinstance(obs, CmdOutputObservation) or obs.exit_code != 0:
            raise RuntimeError(
                f'Failed to add env vars [{env_vars}] to environment: {obs.content}'
            )

    async def on_event(self, event: Event) -> None:
        """Run an incoming ``Action`` and publish its observation.

        Events that are not ``Action`` instances are ignored.
        """
        if isinstance(event, Action):
            # set timeout to default if not set
            if event.timeout is None:
                event.timeout = self.config.sandbox.timeout
            assert event.timeout is not None
            observation = self.run_action(event)
            # Link the observation back to the action that produced it.
            observation._cause = event.id  # type: ignore[attr-defined]
            source = event.source if event.source else EventSource.AGENT
            self.event_stream.add_event(observation, source)  # type: ignore[arg-type]

    def run_action(self, action: Action) -> Observation:
        """Run an action and return the resulting observation.

        If the action is not runnable in any runtime, a NullObservation is returned.
        If the action is not supported by the current runtime, an ErrorObservation is returned.
        An action still awaiting confirmation yields a NullObservation, and a
        rejected one yields a UserRejectObservation.
        """
        if not action.runnable:
            return NullObservation('')

        # Actions without an `is_confirmed` attribute skip both confirmation checks.
        confirmation = getattr(action, 'is_confirmed', None)
        if confirmation == ActionConfirmationStatus.AWAITING_CONFIRMATION:
            return NullObservation('')

        action_type = action.action  # type: ignore[attr-defined]
        if action_type not in ACTION_TYPE_TO_CLASS:
            return ErrorObservation(f'Action {action_type} does not exist.')
        if not hasattr(self, action_type):
            return ErrorObservation(
                f'Action {action_type} is not supported in the current runtime.'
            )
        if confirmation == ActionConfirmationStatus.REJECTED:
            return UserRejectObservation(
                'Action has been rejected by the user! Waiting for further user input.'
            )
        # Dispatch to the method named after the action type (e.g. `run`).
        observation = getattr(self, action_type)(action)
        return observation

    # ====================================================================
    # Context manager
    # ====================================================================

    def __enter__(self) -> 'Runtime':
        """Return the runtime itself for use in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the runtime when the ``with`` block exits."""
        self.close()

    # ====================================================================
    # Action execution
    # ====================================================================

    @abstractmethod
    def run(self, action: CmdRunAction) -> Observation:
        """Handle a ``CmdRunAction``; used by ``add_env_vars`` for bash exports."""
        pass

    @abstractmethod
    def run_ipython(self, action: IPythonRunCellAction) -> Observation:
        """Handle an ``IPythonRunCellAction``."""
        pass

    @abstractmethod
    def read(self, action: FileReadAction) -> Observation:
        """Handle a ``FileReadAction``."""
        pass

    @abstractmethod
    def write(self, action: FileWriteAction) -> Observation:
        """Handle a ``FileWriteAction``."""
        pass

    @abstractmethod
    def browse(self, action: BrowseURLAction) -> Observation:
        """Handle a ``BrowseURLAction``."""
        pass

    @abstractmethod
    def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
        """Handle a ``BrowseInteractiveAction``."""
        pass

    # ====================================================================
    # File operations
    # ====================================================================

    @abstractmethod
    def copy_to(self, host_src: str, sandbox_dest: str, recursive: bool = False):
        """Copy ``host_src`` to ``sandbox_dest``; not implemented in the base class."""
        raise NotImplementedError('This method is not implemented in the base class.')

    @abstractmethod
    def list_files(self, path: str | None = None) -> list[str]:
        """List files in the sandbox.

        If path is None, list files in the sandbox's initial working directory (e.g., /workspace).
        """
        raise NotImplementedError('This method is not implemented in the base class.')