mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-25 21:36:52 +08:00
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: sp.wack <83104063+amanape@users.noreply.github.com> Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
430 lines
15 KiB
Python
430 lines
15 KiB
Python
import asyncio
|
|
import logging
|
|
import os
|
|
import socket
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime
|
|
from typing import AsyncGenerator
|
|
|
|
import base62
|
|
import docker
|
|
import httpx
|
|
from docker.errors import APIError, NotFound
|
|
from fastapi import Request
|
|
from pydantic import BaseModel, ConfigDict, Field
|
|
|
|
from openhands.agent_server.utils import utc_now
|
|
from openhands.app_server.errors import SandboxError
|
|
from openhands.app_server.sandbox.docker_sandbox_spec_service import get_docker_client
|
|
from openhands.app_server.sandbox.sandbox_models import (
|
|
AGENT_SERVER,
|
|
VSCODE,
|
|
ExposedUrl,
|
|
SandboxInfo,
|
|
SandboxPage,
|
|
SandboxStatus,
|
|
)
|
|
from openhands.app_server.sandbox.sandbox_service import (
|
|
SandboxService,
|
|
SandboxServiceInjector,
|
|
)
|
|
from openhands.app_server.sandbox.sandbox_spec_service import SandboxSpecService
|
|
from openhands.app_server.services.injector import InjectorState
|
|
|
|
# Name of the container environment variable that carries the session API key.
SESSION_API_KEY_VARIABLE = 'OH_SESSION_API_KEYS_0'
# Name of the container environment variable that carries the webhook base URL.
WEBHOOK_CALLBACK_VARIABLE = 'OH_WEBHOOKS_0_BASE_URL'

# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
class VolumeMount(BaseModel):
    """Mounted volume within the container.

    Instances are immutable (the model is frozen).
    """

    model_config = ConfigDict(frozen=True)

    # Path on the host side of the mount.
    host_path: str
    # Path at which the volume is bound inside the container.
    container_path: str
    # Mount mode string handed to docker; defaults to 'rw'.
    mode: str = 'rw'
|
|
|
|
|
|
class ExposedPort(BaseModel):
    """Exposed port within container to be matched to a free port on the host.

    Instances are immutable (the model is frozen).
    """

    model_config = ConfigDict(frozen=True)

    # Identifier of the exposed service (e.g. AGENT_SERVER or VSCODE).
    name: str
    # Human readable explanation of what listens on the port.
    description: str
    # Port number inside the container; defaults to 8000.
    container_port: int = 8000
|
|
|
|
|
|
@dataclass
class DockerSandboxService(SandboxService):
    """Sandbox service built on docker.

    The Docker API does not currently support async operations, so some of these
    operations will block. Given that the docker API is intended for local use on
    a single machine, this is probably acceptable.
    """

    sandbox_spec_service: SandboxSpecService
    # Only containers whose name starts with this prefix are handled by the service.
    container_name_prefix: str
    # Host port embedded in the webhook callback URL given to new containers.
    host_port: int
    # Format string with a `{port}` placeholder used to build exposed URLs.
    container_url_pattern: str
    mounts: list[VolumeMount]
    exposed_ports: list[ExposedPort]
    # Path requested on the agent server as a health check; None disables the check.
    health_check_path: str | None
    httpx_client: httpx.AsyncClient
    docker_client: docker.DockerClient = field(default_factory=get_docker_client)

    def _find_unused_port(self) -> int:
        """Find an unused port on the host machine.

        The probing socket is closed before this method returns, so the port is
        only known to have been free at the moment of the call.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(('', 0))
            s.listen(1)
            return s.getsockname()[1]

    def _docker_status_to_sandbox_status(self, docker_status: str) -> SandboxStatus:
        """Convert Docker container status to SandboxStatus.

        The lookup is case-insensitive; unknown statuses map to ERROR.
        """
        status_mapping = {
            'running': SandboxStatus.RUNNING,
            'paused': SandboxStatus.PAUSED,
            'exited': SandboxStatus.MISSING,
            'created': SandboxStatus.STARTING,
            'restarting': SandboxStatus.STARTING,
            'removing': SandboxStatus.MISSING,
            'dead': SandboxStatus.ERROR,
        }
        return status_mapping.get(docker_status.lower(), SandboxStatus.ERROR)

    def _get_container_env_vars(self, container) -> dict[str, str | None]:
        """Return the container's configured environment as a dict.

        Entries of the form ``KEY=VALUE`` are split on the first ``=``; entries
        without ``=`` are stored with a value of ``None``. A missing or null
        ``Config``/``Env`` section yields an empty dict instead of raising.
        """
        config = container.attrs.get('Config') or {}
        result: dict[str, str | None] = {}
        for env_var in config.get('Env') or []:
            key, separator, value = env_var.partition('=')
            # No separator means the variable was declared without a value.
            result[key] = value if separator else None
        return result

    def _get_sandbox_spec_id(self, container) -> str:
        """Determine the sandbox spec id a container was created from.

        The ``sandbox_spec_id`` label written by ``start_sandbox`` is preferred.
        For containers without it, the first tag of the container image is used,
        then the image reference recorded in the container config. Failures while
        looking up the image are tolerated so that a single container with a
        deleted or untagged image cannot break listing.
        """
        config = container.attrs.get('Config') or {}
        labels = config.get('Labels') or {}
        spec_id = labels.get('sandbox_spec_id')
        if spec_id:
            return spec_id

        try:
            image = container.image
        except APIError:
            # NotFound is a subclass of APIError, so a removed image lands here.
            image = None
        tags = image.tags if image is not None else None
        if tags:
            return tags[0]
        return config.get('Image') or ''

    @staticmethod
    def _mark_unreachable(sandbox_info: SandboxInfo) -> None:
        """Flag a sandbox as errored and withhold its connection details."""
        sandbox_info.status = SandboxStatus.ERROR
        sandbox_info.exposed_urls = None
        sandbox_info.session_api_key = None

    async def _container_to_sandbox_info(self, container) -> SandboxInfo | None:
        """Convert Docker container to SandboxInfo.

        Exposed URLs and the session API key are only populated for containers
        whose status maps to RUNNING; otherwise both are ``None``.
        """
        status = self._docker_status_to_sandbox_status(container.status)

        # Parse creation time, falling back to "now" when it is absent/unparseable.
        created_str = container.attrs.get('Created', '')
        try:
            created_at = datetime.fromisoformat(created_str.replace('Z', '+00:00'))
        except (ValueError, AttributeError):
            created_at = utc_now()

        exposed_urls = None
        session_api_key = None

        if status == SandboxStatus.RUNNING:
            # Index configured ports by docker's "<port>/tcp" key. setdefault keeps
            # the first configured entry when two share a container port.
            ports_by_key: dict[str, ExposedPort] = {}
            for exposed_port in self.exposed_ports:
                ports_by_key.setdefault(
                    f'{exposed_port.container_port}/tcp', exposed_port
                )

            exposed_urls = []
            network_settings = container.attrs.get('NetworkSettings') or {}
            port_bindings = network_settings.get('Ports') or {}
            for container_port, host_bindings in port_bindings.items():
                if not host_bindings:
                    continue
                exposed_port = ports_by_key.get(container_port)
                if exposed_port is None:
                    continue
                # Only the first host binding of each port is reported.
                bound_host_port = host_bindings[0]['HostPort']
                exposed_urls.append(
                    ExposedUrl(
                        name=exposed_port.name,
                        url=self.container_url_pattern.format(port=bound_host_port),
                    )
                )

            # A running container without the variable yields None, not KeyError.
            env = self._get_container_env_vars(container)
            session_api_key = env.get(SESSION_API_KEY_VARIABLE)

        return SandboxInfo(
            id=container.name,
            created_by_user_id=None,
            sandbox_spec_id=self._get_sandbox_spec_id(container),
            status=status,
            session_api_key=session_api_key,
            exposed_urls=exposed_urls,
            created_at=created_at,
        )

    async def _container_to_checked_sandbox_info(self, container) -> SandboxInfo | None:
        """Convert a container to SandboxInfo and health check its agent server.

        When a health check path is configured and the sandbox exposes URLs, the
        agent server URL is requested. If the request fails, returns an error
        status, or no agent server URL is exposed at all, the sandbox is reported
        with status ERROR and without URLs or session API key.
        """
        sandbox_info = await self._container_to_sandbox_info(container)
        if (
            not sandbox_info
            or self.health_check_path is None
            or not sandbox_info.exposed_urls
        ):
            return sandbox_info

        # A default is required: a bare next() would raise StopIteration, which
        # surfaces as RuntimeError when raised inside a coroutine.
        app_server_url = next(
            (
                exposed_url.url
                for exposed_url in sandbox_info.exposed_urls
                if exposed_url.name == AGENT_SERVER
            ),
            None,
        )
        if app_server_url is None:
            _logger.info(
                'Sandbox %s does not expose a %s url', sandbox_info.id, AGENT_SERVER
            )
            self._mark_unreachable(sandbox_info)
            return sandbox_info

        try:
            response = await self.httpx_client.get(
                f'{app_server_url}{self.health_check_path}'
            )
            response.raise_for_status()
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            # Best effort: any failure just means the server is not usable.
            _logger.info('Sandbox server not running: %s', exc)
            self._mark_unreachable(sandbox_info)
        return sandbox_info

    async def search_sandboxes(
        self,
        page_id: str | None = None,
        limit: int = 100,
    ) -> SandboxPage:
        """Search for sandboxes.

        Args:
            page_id: Start offset into the result list, as a string. Missing,
                non-numeric or negative values start from the beginning.
            limit: Maximum number of items in the returned page.

        Returns:
            A page of sandboxes sorted newest first. An empty page is returned
            when the Docker API reports an error.
        """
        try:
            sandboxes = []
            for container in self.docker_client.containers.list(all=True):
                # Only containers carrying our prefix belong to this service.
                if not container.name.startswith(self.container_name_prefix):
                    continue
                sandbox_info = await self._container_to_checked_sandbox_info(
                    container
                )
                if sandbox_info:
                    sandboxes.append(sandbox_info)
        except APIError:
            return SandboxPage(items=[], next_page_id=None)

        # Sort by creation time (newest first)
        sandboxes.sort(key=lambda x: x.created_at, reverse=True)

        # Apply pagination; a negative offset would otherwise slice from the end.
        start_idx = 0
        if page_id:
            try:
                start_idx = max(int(page_id), 0)
            except ValueError:
                start_idx = 0
        end_idx = start_idx + limit

        next_page_id = str(end_idx) if end_idx < len(sandboxes) else None
        return SandboxPage(
            items=sandboxes[start_idx:end_idx], next_page_id=next_page_id
        )

    async def get_sandbox(self, sandbox_id: str) -> SandboxInfo | None:
        """Get a single sandbox info.

        Returns ``None`` when the id does not carry the configured prefix, the
        container does not exist, or the Docker API reports an error.
        """
        if not sandbox_id.startswith(self.container_name_prefix):
            return None
        try:
            container = self.docker_client.containers.get(sandbox_id)
            return await self._container_to_checked_sandbox_info(container)
        except (NotFound, APIError):
            return None

    async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo:
        """Start a new sandbox.

        Args:
            sandbox_spec_id: Id of the spec to start from; the default spec is
                used when omitted.

        Returns:
            Info about the newly created sandbox.

        Raises:
            ValueError: If ``sandbox_spec_id`` does not resolve to a spec.
            SandboxError: If Docker fails to create or start the container.
        """
        if sandbox_spec_id is None:
            sandbox_spec = await self.sandbox_spec_service.get_default_sandbox_spec()
        else:
            sandbox_spec_maybe = await self.sandbox_spec_service.get_sandbox_spec(
                sandbox_spec_id
            )
            if sandbox_spec_maybe is None:
                raise ValueError('Sandbox Spec not found')
            sandbox_spec = sandbox_spec_maybe

        # Generate container ID and session api key
        container_name = (
            f'{self.container_name_prefix}{base62.encodebytes(os.urandom(16))}'
        )
        session_api_key = base62.encodebytes(os.urandom(32))

        # Prepare environment variables (copy so the spec itself is not mutated)
        env_vars = sandbox_spec.initial_env.copy()
        env_vars[SESSION_API_KEY_VARIABLE] = session_api_key
        env_vars[WEBHOOK_CALLBACK_VARIABLE] = (
            f'http://host.docker.internal:{self.host_port}'
            f'/api/v1/webhooks/{container_name}'
        )

        # Map every exposed container port to a free host port and publish the
        # chosen host port to the container under the exposed port's name.
        port_mappings = {}
        for exposed_port in self.exposed_ports:
            allocated_port = self._find_unused_port()
            port_mappings[exposed_port.container_port] = allocated_port
            env_vars[exposed_port.name] = str(allocated_port)

        # The label records which spec the container was created from.
        labels = {'sandbox_spec_id': sandbox_spec.id}

        volumes = {
            mount.host_path: {'bind': mount.container_path, 'mode': mount.mode}
            for mount in self.mounts
        }

        try:
            # Create and start the container
            container = self.docker_client.containers.run(  # type: ignore[call-overload]
                image=sandbox_spec.id,
                command=sandbox_spec.command,
                remove=False,
                name=container_name,
                environment=env_vars,
                ports=port_mappings,
                volumes=volumes,
                working_dir=sandbox_spec.working_dir,
                labels=labels,
                detach=True,
            )
            sandbox_info = await self._container_to_sandbox_info(container)
        except APIError as e:
            raise SandboxError(f'Failed to start container: {e}') from e

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if sandbox_info is None:
            raise SandboxError(
                f'Failed to start container: no info for {container_name}'
            )
        return sandbox_info

    async def resume_sandbox(self, sandbox_id: str) -> bool:
        """Resume a paused sandbox.

        Paused containers are unpaused and exited containers are started; any
        other state is left untouched. Returns ``False`` when the id lacks the
        configured prefix or Docker reports an error, ``True`` otherwise.
        """
        if not sandbox_id.startswith(self.container_name_prefix):
            return False
        try:
            container = self.docker_client.containers.get(sandbox_id)
            if container.status == 'paused':
                container.unpause()
            elif container.status == 'exited':
                container.start()
            return True
        except (NotFound, APIError):
            return False

    async def pause_sandbox(self, sandbox_id: str) -> bool:
        """Pause a running sandbox.

        Containers that are not running are left untouched. Returns ``False``
        when the id lacks the configured prefix or Docker reports an error,
        ``True`` otherwise.
        """
        if not sandbox_id.startswith(self.container_name_prefix):
            return False
        try:
            container = self.docker_client.containers.get(sandbox_id)
            if container.status == 'running':
                container.pause()
            return True
        except (NotFound, APIError):
            return False

    async def delete_sandbox(self, sandbox_id: str) -> bool:
        """Delete a sandbox.

        Stops the container when it is running or paused, removes it, then makes
        a best-effort attempt to remove the ``openhands-workspace-<id>`` volume.
        Returns ``False`` when the id lacks the configured prefix or Docker
        reports an error for the container itself, ``True`` otherwise.
        """
        if not sandbox_id.startswith(self.container_name_prefix):
            return False
        try:
            container = self.docker_client.containers.get(sandbox_id)

            # Stop the container if it's running
            if container.status in ('running', 'paused'):
                container.stop(timeout=10)

            container.remove()
        except (NotFound, APIError):
            return False

        # Remove associated volume
        try:
            volume = self.docker_client.volumes.get(
                f'openhands-workspace-{sandbox_id}'
            )
            volume.remove()
        except (NotFound, APIError):
            # Volume might not exist or already removed
            pass

        return True
|
|
|
|
|
|
class DockerSandboxServiceInjector(SandboxServiceInjector):
    """Dependency injector for docker sandbox services.

    The fields below are the settings copied into every ``DockerSandboxService``
    this injector yields.
    """

    # Format string with a `{port}` placeholder used to build exposed URLs.
    container_url_pattern: str = 'http://localhost:{port}'
    host_port: int = 3000
    container_name_prefix: str = 'oh-agent-server-'
    mounts: list[VolumeMount] = Field(default_factory=list)
    # By default the agent server (8000) and VSCode server (8001) are exposed.
    exposed_ports: list[ExposedPort] = Field(
        default_factory=lambda: [
            ExposedPort(
                name=AGENT_SERVER,
                description=(
                    'The port on which the agent server runs within the container'
                ),
                container_port=8000,
            ),
            ExposedPort(
                name=VSCODE,
                description=(
                    'The port on which the VSCode server runs within the container'
                ),
                container_port=8001,
            ),
        ]
    )
    health_check_path: str | None = Field(
        default='/health',
        description=(
            'The url path in the sandbox agent server to check to '
            'determine whether the server is running'
        ),
    )

    async def inject(
        self, state: InjectorState, request: Request | None = None
    ) -> AsyncGenerator[SandboxService, None]:
        """Yield a ``DockerSandboxService`` configured from this injector.

        The httpx client and sandbox spec service are obtained from the app
        server config for the given state and stay open while the caller uses
        the yielded service.
        """
        # Define inline to prevent circular lookup
        from openhands.app_server.config import (
            get_httpx_client,
            get_sandbox_spec_service,
        )

        async with get_httpx_client(state) as httpx_client:
            async with get_sandbox_spec_service(state) as sandbox_spec_service:
                service = DockerSandboxService(
                    sandbox_spec_service=sandbox_spec_service,
                    container_name_prefix=self.container_name_prefix,
                    host_port=self.host_port,
                    container_url_pattern=self.container_url_pattern,
                    mounts=self.mounts,
                    exposed_ports=self.exposed_ports,
                    health_check_path=self.health_check_path,
                    httpx_client=httpx_client,
                )
                yield service
|