Implement basic modal sandbox support (#4133)

This commit is contained in:
Peyton Walters 2024-10-15 06:37:02 -04:00 committed by GitHub
parent 0ca66beac9
commit 9566ca4a3c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 1421 additions and 826 deletions

View File

@ -13,6 +13,10 @@
# API key for E2B # API key for E2B
#e2b_api_key = "" #e2b_api_key = ""
# API token ID and token secret for Modal
#modal_api_token_id = ""
#modal_api_token_secret = ""
# Base path for the workspace # Base path for the workspace
workspace_base = "./workspace" workspace_base = "./workspace"

View File

@ -67,6 +67,8 @@ class AppConfig:
max_iterations: int = OH_MAX_ITERATIONS max_iterations: int = OH_MAX_ITERATIONS
max_budget_per_task: float | None = None max_budget_per_task: float | None = None
e2b_api_key: str = '' e2b_api_key: str = ''
modal_api_token_id: str = ''
modal_api_token_secret: str = ''
disable_color: bool = False disable_color: bool = False
jwt_secret: str = uuid.uuid4().hex jwt_secret: str = uuid.uuid4().hex
debug: bool = False debug: bool = False
@ -142,6 +144,8 @@ class AppConfig:
'e2b_api_key', 'e2b_api_key',
'github_token', 'github_token',
'jwt_secret', 'jwt_secret',
'modal_api_token_id',
'modal_api_token_secret',
]: ]:
attr_value = '******' if attr_value else None attr_value = '******' if attr_value else None

View File

@ -92,6 +92,8 @@ class SensitiveDataFilter(logging.Filter):
'e2b_api_key', 'e2b_api_key',
'github_token', 'github_token',
'jwt_secret', 'jwt_secret',
'modal_api_token_id',
'modal_api_token_secret',
] ]
# add env var names # add env var names

View File

@ -1,3 +1,4 @@
from openhands.core.logger import openhands_logger as logger
from openhands.runtime.e2b.sandbox import E2BBox from openhands.runtime.e2b.sandbox import E2BBox
@ -15,6 +16,11 @@ def get_runtime_cls(name: str):
from openhands.runtime.remote.runtime import RemoteRuntime from openhands.runtime.remote.runtime import RemoteRuntime
return RemoteRuntime return RemoteRuntime
elif name == 'modal':
logger.info('Using ModalRuntime')
from openhands.runtime.modal.runtime import ModalRuntime
return ModalRuntime
else: else:
raise ValueError(f'Runtime {name} not supported') raise ValueError(f'Runtime {name} not supported')

View File

@ -481,7 +481,9 @@ class RuntimeClient:
logger.debug(f'{self.pwd} != {jupyter_pwd} -> reset Jupyter PWD') logger.debug(f'{self.pwd} != {jupyter_pwd} -> reset Jupyter PWD')
reset_jupyter_pwd_code = f'import os; os.chdir("{self.pwd}")' reset_jupyter_pwd_code = f'import os; os.chdir("{self.pwd}")'
_aux_action = IPythonRunCellAction(code=reset_jupyter_pwd_code) _aux_action = IPythonRunCellAction(code=reset_jupyter_pwd_code)
_reset_obs = await _jupyter_plugin.run(_aux_action) _reset_obs: IPythonRunCellObservation = await _jupyter_plugin.run(
_aux_action
)
logger.debug( logger.debug(
f'Changed working directory in IPython to: {self.pwd}. Output: {_reset_obs}' f'Changed working directory in IPython to: {self.pwd}. Output: {_reset_obs}'
) )

View File

@ -70,8 +70,7 @@ class LogBuffer:
return logs return logs
def stream_logs(self): def stream_logs(self):
""" """Stream logs from the Docker container in a separate thread.
Stream logs from the Docker container in a separate thread.
This method runs in its own thread to handle the blocking This method runs in its own thread to handle the blocking
operation of reading log lines from the Docker SDK's synchronous generator. operation of reading log lines from the Docker SDK's synchronous generator.
@ -114,6 +113,22 @@ class EventStreamRuntime(Runtime):
container_name_prefix = 'openhands-runtime-' container_name_prefix = 'openhands-runtime-'
# Need to provide this method to allow inheritors to init the Runtime
# without initting the EventStreamRuntime.
def init_base_runtime(
    self,
    config: AppConfig,
    event_stream: EventStream,
    sid: str = 'default',
    plugins: list[PluginRequirement] | None = None,
    env_vars: dict[str, str] | None = None,
    status_message_callback: Callable | None = None,
    attach_to_existing: bool = False,
):
    """Run only the parent class's initializer with the given arguments.

    All parameters are forwarded positionally, in declaration order, to
    ``super().__init__``; nothing else is done here.
    """
    base_init_args = (
        config,
        event_stream,
        sid,
        plugins,
        env_vars,
        status_message_callback,
        attach_to_existing,
    )
    super().__init__(*base_init_args)
def __init__( def __init__(
self, self,
config: AppConfig, config: AppConfig,
@ -175,22 +190,15 @@ class EventStreamRuntime(Runtime):
else: else:
self._attach_to_container() self._attach_to_container()
# will initialize both the event stream and the env vars # Will initialize both the event stream and the env vars
super().__init__( self.init_base_runtime(
config, config, event_stream, sid, plugins, env_vars, status_message_callback, attach_to_existing
event_stream,
sid,
plugins,
env_vars,
status_message_callback,
attach_to_existing,
) )
logger.info('Waiting for client to become ready...') logger.info('Waiting for client to become ready...')
self.send_status_message('STATUS$WAITING_FOR_CLIENT') self.send_status_message('STATUS$WAITING_FOR_CLIENT')
self._wait_until_alive() self._wait_until_alive()
self.setup_initial_env() self.setup_initial_env()
logger.info( logger.info(

View File

@ -0,0 +1,273 @@
import os
import tempfile
import threading
import uuid
from typing import Callable, Generator
import modal
import requests
import tenacity
from openhands.core.config import AppConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventStream
from openhands.runtime.client.runtime import EventStreamRuntime, LogBuffer
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.utils.runtime_build import (
prep_docker_build_folder,
)
# Modal's log generator returns strings, but the upstream LogBuffer expects bytes.
def bytes_shim(string_generator) -> Generator[bytes, None, None]:
    """Lazily yield each item of *string_generator* encoded as UTF-8 bytes."""
    yield from (text.encode('utf-8') for text in string_generator)
class ModalLogBuffer(LogBuffer):
    """Log buffer fed by a Modal sandbox's stderr stream.

    Only construction is customized here; everything else (including
    ``stream_logs``, which the background thread runs) is inherited from
    ``LogBuffer``. ``LogBuffer.__init__`` is intentionally not invoked --
    presumably because it expects a Docker container; confirm against the
    parent class.
    """

    def __init__(self, sandbox: modal.Sandbox):
        # State consumed by the inherited LogBuffer methods.
        self.init_msg = 'Runtime client initialized.'
        self.client_ready = False
        self.buffer: list[str] = []
        self.lock = threading.Lock()
        self._stop_event = threading.Event()

        # The sandbox's stderr is wrapped so each line is yielded as bytes.
        self.log_generator = bytes_shim(sandbox.stderr)

        # Consume the log stream on a daemon thread, started immediately.
        self.log_stream_thread = threading.Thread(
            target=self.stream_logs, daemon=True
        )
        self.log_stream_thread.start()
class ModalRuntime(EventStreamRuntime):
    """Runtime that executes actions inside a Modal sandbox.

    This runtime subscribes to the event stream. When it receives an event,
    it sends the event to the runtime client running inside the Modal sandbox
    environment.

    Args:
        config (AppConfig): The application configuration. Must carry
            ``modal_api_token_id`` and ``modal_api_token_secret``.
        event_stream (EventStream): The event stream to subscribe to.
        sid (str, optional): The session ID. Defaults to 'default'.
        plugins (list[PluginRequirement] | None, optional): List of plugin requirements. Defaults to None.
        env_vars (dict[str, str] | None, optional): Environment variables to set. Defaults to None.
        status_message_callback (Callable | None, optional): Callback stored on
            the instance and forwarded to the base runtime. Defaults to None.
    """

    container_name_prefix = 'openhands-sandbox-'

    def __init__(
        self,
        config: AppConfig,
        event_stream: EventStream,
        sid: str = 'default',
        plugins: list[PluginRequirement] | None = None,
        env_vars: dict[str, str] | None = None,
        status_message_callback: Callable | None = None,
    ):
        # NOTE(review): asserts are stripped under `python -O`; they are kept
        # so the exception type seen by existing callers does not change.
        assert config.modal_api_token_id, 'Modal API token id is required'
        assert config.modal_api_token_secret, 'Modal API token secret is required'

        self.config = config

        self.modal_client = modal.Client.from_credentials(
            config.modal_api_token_id, config.modal_api_token_secret
        )
        self.app = modal.App.lookup(
            'openhands', create_if_missing=True, client=self.modal_client
        )

        # workspace_base cannot be used because we can't bind mount into a sandbox.
        if self.config.workspace_base is not None:
            logger.warning(
                'Setting workspace_base is not supported in the modal runtime.'
            )

        # This value is arbitrary as it's private to the container
        self.container_port = 3000

        self.session = requests.Session()
        self.instance_id = (
            sid + '_' + str(uuid.uuid4()) if sid is not None else str(uuid.uuid4())
        )
        self.status_message_callback = status_message_callback
        self.send_status_message('STATUS$STARTING_RUNTIME')
        self.base_container_image_id = self.config.sandbox.base_container_image
        self.runtime_container_image_id = self.config.sandbox.runtime_container_image
        self.action_semaphore = threading.Semaphore(1)  # Ensure one action at a time

        logger.info(f'ModalRuntime `{self.instance_id}`')

        # Buffer for container logs
        self.log_buffer: LogBuffer | None = None

        # Defined before `_init_sandbox` runs: its failure path calls
        # `close()`, which reads this attribute.
        self.sandbox: modal.Sandbox | None = None

        if self.config.sandbox.runtime_extra_deps:
            logger.debug(
                f'Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.runtime_extra_deps}'
            )

        self.image = self._get_image_definition(
            self.base_container_image_id,
            self.runtime_container_image_id,
            self.config.sandbox.runtime_extra_deps,
        )

        self.sandbox = self._init_sandbox(
            sandbox_workspace_dir=self.config.workspace_mount_path_in_sandbox,
            plugins=plugins,
        )

        # Will initialize both the event stream and the env vars
        self.init_base_runtime(
            config, event_stream, sid, plugins, env_vars, status_message_callback
        )

        logger.info('Waiting for client to become ready...')
        self.send_status_message('STATUS$WAITING_FOR_CLIENT')

        self._wait_until_alive()
        self.setup_initial_env()

        logger.info(
            f'Container initialized with plugins: {[plugin.name for plugin in self.plugins]}'
        )
        self.send_status_message(' ')

    def _get_image_definition(
        self,
        base_container_image_id: str | None,
        runtime_container_image_id: str | None,
        runtime_extra_deps: str | None,
    ) -> modal.Image:
        """Build the Modal image definition used for the sandbox.

        A runtime container image, when given, is pulled from a registry and
        takes precedence. Otherwise a Docker build folder is prepared from the
        base container image (with ``runtime_extra_deps``) and the image is
        defined from the resulting Dockerfile.

        Raises:
            ValueError: If neither image id is set.
        """
        if runtime_container_image_id:
            base_runtime_image = modal.Image.from_registry(runtime_container_image_id)
        elif base_container_image_id:
            # NOTE(review): this temporary folder is never removed. It is
            # unclear from here when Modal reads the build context, so
            # cleanup is deliberately not attempted.
            build_folder = tempfile.mkdtemp()
            prep_docker_build_folder(
                build_folder,
                base_container_image_id,
                extra_deps=runtime_extra_deps,
            )

            base_runtime_image = modal.Image.from_dockerfile(
                path=os.path.join(build_folder, 'Dockerfile'),
                context_mount=modal.Mount.from_local_dir(
                    local_path=build_folder,
                    remote_path='.',  # to current WORKDIR
                ),
            )
        else:
            raise ValueError(
                'Neither runtime container image nor base container image is set'
            )

        return base_runtime_image.run_commands(
            """
# Disable bracketed paste
# https://github.com/pexpect/pexpect/issues/669
echo "set enable-bracketed-paste off" >> /etc/inputrc && \\
echo 'export INPUTRC=/etc/inputrc' >> /etc/bash.bashrc
""".strip()
        )

    @tenacity.retry(
        stop=tenacity.stop_after_attempt(5),
        wait=tenacity.wait_exponential(multiplier=1, min=4, max=60),
    )
    def _init_sandbox(
        self,
        sandbox_workspace_dir: str,
        plugins: list[PluginRequirement] | None = None,
    ) -> modal.Sandbox:
        """Create the Modal sandbox running the runtime client.

        On success, sets ``self.api_url`` to the tunnel URL for
        ``self.container_port`` and ``self.log_buffer`` to a buffer reading
        the sandbox's stderr. The whole method is retried by tenacity (up to
        5 attempts with exponential backoff) when it raises.

        Args:
            sandbox_workspace_dir: Passed to the client as ``--working-dir``.
            plugins: Plugins whose names are passed after ``--plugins``.

        Returns:
            The created sandbox.

        Raises:
            Exception: Whatever caused the start-up to fail, re-raised after
                logging and cleanup.
        """
        # Tracked outside the `try` so a sandbox that was created but failed
        # afterwards (e.g. in `tunnels()`) can be terminated instead of leaked.
        sandbox: modal.Sandbox | None = None
        try:
            logger.info('Preparing to start container...')
            self.send_status_message('STATUS$PREPARING_CONTAINER')
            plugin_args = []
            if plugins is not None and len(plugins) > 0:
                plugin_args.append('--plugins')
                plugin_args.extend([plugin.name for plugin in plugins])

            # Combine environment variables
            environment: dict[str, str | None] = {
                'port': str(self.container_port),
                'PYTHONUNBUFFERED': '1',
            }
            if self.config.debug:
                environment['DEBUG'] = 'true'

            browsergym_args = []
            if self.config.sandbox.browsergym_eval_env is not None:
                # Double dash, consistent with the other long options passed
                # to the client below (--working-dir, --username, ...).
                browsergym_args = [
                    '--browsergym-eval-env',
                    self.config.sandbox.browsergym_eval_env,
                ]

            env_secret = modal.Secret.from_dict(environment)

            logger.debug(f'Sandbox workspace: {sandbox_workspace_dir}')
            sandbox_start_cmd: list[str] = [
                '/openhands/micromamba/bin/micromamba',
                'run',
                '-n',
                'openhands',
                'poetry',
                'run',
                'python',
                '-u',
                '-m',
                'openhands.runtime.client.client',
                str(self.container_port),
                '--working-dir',
                sandbox_workspace_dir,
                *plugin_args,
                '--username',
                'openhands' if self.config.run_as_openhands else 'root',
                '--user-id',
                str(self.config.sandbox.user_id),
                *browsergym_args,
            ]
            sandbox = modal.Sandbox.create(
                *sandbox_start_cmd,
                secrets=[env_secret],
                workdir='/openhands/code',
                encrypted_ports=[self.container_port],
                image=self.image,
                app=self.app,
                client=self.modal_client,
                timeout=60 * 60,
            )

            tunnel = sandbox.tunnels()[self.container_port]
            self.api_url = tunnel.url

            self.log_buffer = ModalLogBuffer(sandbox)
            logger.info(f'Container started. Server url: {self.api_url}')
            self.send_status_message('STATUS$CONTAINER_STARTED')
            return sandbox
        except Exception as e:
            logger.error(
                f'Error: Instance {self.instance_id} FAILED to start container!\n'
            )
            logger.exception(e)
            if sandbox is not None:
                # `self.sandbox` has not been assigned yet, so `close()` will
                # not reach this sandbox; terminate it here. A failure during
                # termination must not hide the original error.
                try:
                    sandbox.terminate()
                except Exception:
                    logger.exception(
                        'Failed to terminate partially started sandbox'
                    )
            self.close()
            raise

    def close(self):
        """Closes the ModalRuntime and associated objects.

        Safe to call on a partially initialized instance: attributes that
        have not been set yet are skipped.
        """
        log_buffer = getattr(self, 'log_buffer', None)
        if log_buffer:
            log_buffer.close()

        session = getattr(self, 'session', None)
        if session:
            session.close()

        sandbox = getattr(self, 'sandbox', None)
        if sandbox:
            sandbox.terminate()

1907
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -55,6 +55,7 @@ python-dotenv = "*"
protobuf = "^4.21.6,<5.0.0" # chromadb currently fails on 5.0+ protobuf = "^4.21.6,<5.0.0" # chromadb currently fails on 5.0+
opentelemetry-api = "1.25.0" opentelemetry-api = "1.25.0"
opentelemetry-exporter-otlp-proto-grpc = "1.25.0" opentelemetry-exporter-otlp-proto-grpc = "1.25.0"
modal = "^0.64.145"
[tool.poetry.group.llama-index.dependencies] [tool.poetry.group.llama-index.dependencies]
llama-index = "*" llama-index = "*"

View File

@ -15,7 +15,7 @@ from openhands.events.observation.delegate import AgentDelegateObservation
from openhands.runtime import get_runtime_cls from openhands.runtime import get_runtime_cls
TEST_RUNTIME = os.getenv('TEST_RUNTIME') TEST_RUNTIME = os.getenv('TEST_RUNTIME')
assert TEST_RUNTIME in ['eventstream', 'remote'] assert TEST_RUNTIME in ['eventstream', 'remote', 'modal']
_ = get_runtime_cls(TEST_RUNTIME) # make sure it does not raise an error _ = get_runtime_cls(TEST_RUNTIME) # make sure it does not raise an error
CONFIG = load_app_config() CONFIG = load_app_config()

View File

@ -477,15 +477,25 @@ def test_api_keys_repr_str():
agents={'agent': agent_config}, agents={'agent': agent_config},
e2b_api_key='my_e2b_api_key', e2b_api_key='my_e2b_api_key',
jwt_secret='my_jwt_secret', jwt_secret='my_jwt_secret',
modal_api_token_id='my_modal_api_token_id',
modal_api_token_secret='my_modal_api_token_secret',
) )
assert "e2b_api_key='******'" in repr(app_config) assert "e2b_api_key='******'" in repr(app_config)
assert "e2b_api_key='******'" in str(app_config) assert "e2b_api_key='******'" in str(app_config)
assert "jwt_secret='******'" in repr(app_config) assert "jwt_secret='******'" in repr(app_config)
assert "jwt_secret='******'" in str(app_config) assert "jwt_secret='******'" in str(app_config)
assert "modal_api_token_id='******'" in repr(app_config)
assert "modal_api_token_id='******'" in str(app_config)
assert "modal_api_token_secret='******'" in repr(app_config)
assert "modal_api_token_secret='******'" in str(app_config)
# Check that no other attrs in AppConfig have 'key' or 'token' in their name # Check that no other attrs in AppConfig have 'key' or 'token' in their name
# This will fail when new attrs are added, and attract attention # This will fail when new attrs are added, and attract attention
known_key_token_attrs_app = ['e2b_api_key'] known_key_token_attrs_app = [
'e2b_api_key',
'modal_api_token_id',
'modal_api_token_secret',
]
for attr_name in dir(AppConfig): for attr_name in dir(AppConfig):
if ( if (
not attr_name.startswith('__') not attr_name.startswith('__')