mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
Add optional sandbox_id parameter to start_sandbox method (#12382)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -478,7 +478,15 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
"""Wait for sandbox to start and return info."""
|
||||
# Get or create the sandbox
|
||||
if not task.request.sandbox_id:
|
||||
sandbox = await self.sandbox_service.start_sandbox()
|
||||
# Convert conversation_id to hex string if present
|
||||
sandbox_id_str = (
|
||||
task.request.conversation_id.hex
|
||||
if task.request.conversation_id is not None
|
||||
else None
|
||||
)
|
||||
sandbox = await self.sandbox_service.start_sandbox(
|
||||
sandbox_id=sandbox_id_str
|
||||
)
|
||||
task.sandbox_id = sandbox.id
|
||||
else:
|
||||
sandbox_info = await self.sandbox_service.get_sandbox(
|
||||
|
||||
@@ -294,7 +294,9 @@ class DockerSandboxService(SandboxService):
|
||||
except (NotFound, APIError):
|
||||
return None
|
||||
|
||||
async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo:
|
||||
async def start_sandbox(
|
||||
self, sandbox_spec_id: str | None = None, sandbox_id: str | None = None
|
||||
) -> SandboxInfo:
|
||||
"""Start a new sandbox."""
|
||||
# Enforce sandbox limits by cleaning up old sandboxes
|
||||
await self.pause_old_sandboxes(self.max_num_sandboxes - 1)
|
||||
@@ -309,10 +311,12 @@ class DockerSandboxService(SandboxService):
|
||||
raise ValueError('Sandbox Spec not found')
|
||||
sandbox_spec = sandbox_spec_maybe
|
||||
|
||||
# Generate container ID and session api key
|
||||
container_name = (
|
||||
f'{self.container_name_prefix}{base62.encodebytes(os.urandom(16))}'
|
||||
)
|
||||
# Generate a sandbox id if none was provided
|
||||
if sandbox_id is None:
|
||||
sandbox_id = base62.encodebytes(os.urandom(16))
|
||||
|
||||
# Generate container name and session api key
|
||||
container_name = f'{self.container_name_prefix}{sandbox_id}'
|
||||
session_api_key = base62.encodebytes(os.urandom(32))
|
||||
|
||||
# Prepare environment variables
|
||||
|
||||
@@ -286,7 +286,9 @@ class ProcessSandboxService(SandboxService):
|
||||
|
||||
return None
|
||||
|
||||
async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo:
|
||||
async def start_sandbox(
|
||||
self, sandbox_spec_id: str | None = None, sandbox_id: str | None = None
|
||||
) -> SandboxInfo:
|
||||
"""Start a new sandbox."""
|
||||
# Get sandbox spec
|
||||
if sandbox_spec_id is None:
|
||||
@@ -300,7 +302,9 @@ class ProcessSandboxService(SandboxService):
|
||||
sandbox_spec = sandbox_spec_maybe
|
||||
|
||||
# Generate unique sandbox ID and session API key
|
||||
sandbox_id = base62.encodebytes(os.urandom(16))
|
||||
# Use provided sandbox_id if available, otherwise generate a random one
|
||||
if sandbox_id is None:
|
||||
sandbox_id = base62.encodebytes(os.urandom(16))
|
||||
session_api_key = base62.encodebytes(os.urandom(32))
|
||||
|
||||
# Find available port
|
||||
|
||||
@@ -383,7 +383,9 @@ class RemoteSandboxService(SandboxService):
|
||||
|
||||
return None
|
||||
|
||||
async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo:
|
||||
async def start_sandbox(
|
||||
self, sandbox_spec_id: str | None = None, sandbox_id: str | None = None
|
||||
) -> SandboxInfo:
|
||||
"""Start a new sandbox by creating a remote runtime."""
|
||||
try:
|
||||
# Enforce sandbox limits by cleaning up old sandboxes
|
||||
@@ -402,8 +404,9 @@ class RemoteSandboxService(SandboxService):
|
||||
raise ValueError('Sandbox Spec not found')
|
||||
sandbox_spec = sandbox_spec_maybe
|
||||
|
||||
# Create a unique id
|
||||
sandbox_id = base62.encodebytes(os.urandom(16))
|
||||
# Create a unique id, use provided sandbox_id if available
|
||||
if sandbox_id is None:
|
||||
sandbox_id = base62.encodebytes(os.urandom(16))
|
||||
|
||||
# get user id
|
||||
user_id = await self.user_context.get_user_id()
|
||||
|
||||
@@ -50,10 +50,14 @@ class SandboxService(ABC):
|
||||
return results
|
||||
|
||||
@abstractmethod
|
||||
async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo:
|
||||
async def start_sandbox(
|
||||
self, sandbox_spec_id: str | None = None, sandbox_id: str | None = None
|
||||
) -> SandboxInfo:
|
||||
"""Begin the process of starting a sandbox.
|
||||
|
||||
Return the info on the new sandbox. If no spec is selected, use the default.
|
||||
If sandbox_id is provided, it will be used as the sandbox identifier instead
|
||||
of generating a random one.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -430,6 +430,45 @@ class TestDockerSandboxService:
|
||||
):
|
||||
await service.start_sandbox(sandbox_spec_id='nonexistent')
|
||||
|
||||
@patch('openhands.app_server.sandbox.docker_sandbox_service.base62.encodebytes')
|
||||
@patch('os.urandom')
|
||||
async def test_start_sandbox_with_sandbox_id(
|
||||
self, mock_urandom, mock_encodebytes, service
|
||||
):
|
||||
"""Test starting sandbox with a specified sandbox_id."""
|
||||
# Setup - only need urandom for session key
|
||||
mock_urandom.return_value = b'session_key'
|
||||
mock_encodebytes.return_value = 'test_session_key'
|
||||
|
||||
mock_container = MagicMock()
|
||||
mock_container.name = 'oh-test-custom_sandbox_id'
|
||||
mock_container.status = 'running'
|
||||
mock_container.image.tags = ['test-image:latest']
|
||||
mock_container.attrs = {
|
||||
'Created': '2024-01-15T10:30:00.000000000Z',
|
||||
'Config': {
|
||||
'Env': ['OH_SESSION_API_KEYS_0=test_session_key', 'TEST_VAR=test_value']
|
||||
},
|
||||
'NetworkSettings': {'Ports': {}},
|
||||
}
|
||||
|
||||
service.docker_client.containers.run.return_value = mock_container
|
||||
|
||||
with (
|
||||
patch.object(service, '_find_unused_port', side_effect=[12345, 12346]),
|
||||
patch.object(service, 'pause_old_sandboxes', return_value=[]),
|
||||
):
|
||||
# Execute with custom sandbox_id
|
||||
result = await service.start_sandbox(sandbox_id='custom_sandbox_id')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result.id == 'oh-test-custom_sandbox_id'
|
||||
|
||||
# Verify container was created with the custom sandbox ID in the name
|
||||
call_args = service.docker_client.containers.run.call_args
|
||||
assert call_args[1]['name'] == 'oh-test-custom_sandbox_id'
|
||||
|
||||
async def test_start_sandbox_docker_error(self, service):
|
||||
"""Test handling of Docker errors during sandbox startup."""
|
||||
# Setup
|
||||
|
||||
@@ -185,6 +185,41 @@ class TestProcessSandboxService:
|
||||
result = await process_sandbox_service.delete_sandbox('nonexistent')
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_sandbox_with_sandbox_id(self, process_sandbox_service):
|
||||
"""Test starting a sandbox with a specified sandbox_id."""
|
||||
# Mock subprocess and waiting for server
|
||||
with (
|
||||
patch.object(
|
||||
process_sandbox_service, '_start_agent_process'
|
||||
) as mock_start_process,
|
||||
patch.object(
|
||||
process_sandbox_service, '_wait_for_server_ready', return_value=True
|
||||
),
|
||||
patch.object(
|
||||
process_sandbox_service,
|
||||
'_get_process_status',
|
||||
return_value=SandboxStatus.RUNNING,
|
||||
),
|
||||
):
|
||||
mock_process = MagicMock()
|
||||
mock_process.pid = 1234
|
||||
mock_start_process.return_value = mock_process
|
||||
|
||||
# Mock successful health check response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
process_sandbox_service.httpx_client.get.return_value = mock_response
|
||||
|
||||
# Execute with custom sandbox_id
|
||||
result = await process_sandbox_service.start_sandbox(
|
||||
sandbox_id='custom_sandbox_id'
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result.id == 'custom_sandbox_id'
|
||||
|
||||
@patch('psutil.Process')
|
||||
def test_get_process_status_paused(
|
||||
self, mock_process_class, process_sandbox_service
|
||||
|
||||
@@ -450,6 +450,34 @@ class TestSandboxLifecycle:
|
||||
with pytest.raises(ValueError, match='Sandbox Spec not found'):
|
||||
await remote_sandbox_service.start_sandbox('non-existent-spec')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_sandbox_with_sandbox_id(
|
||||
self, remote_sandbox_service, mock_sandbox_spec_service
|
||||
):
|
||||
"""Test starting sandbox with a specified sandbox_id."""
|
||||
# Setup
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = create_runtime_data(
|
||||
session_id='custom_sandbox_id'
|
||||
)
|
||||
remote_sandbox_service.httpx_client.request.return_value = mock_response
|
||||
remote_sandbox_service.pause_old_sandboxes = AsyncMock(return_value=[])
|
||||
|
||||
# Mock database operations
|
||||
remote_sandbox_service.db_session.add = MagicMock()
|
||||
remote_sandbox_service.db_session.commit = AsyncMock()
|
||||
|
||||
# Execute with custom sandbox_id - should not need base62 encoding
|
||||
sandbox_info = await remote_sandbox_service.start_sandbox(
|
||||
sandbox_id='custom_sandbox_id'
|
||||
)
|
||||
|
||||
# Verify the custom sandbox_id is used
|
||||
assert sandbox_info.id == 'custom_sandbox_id'
|
||||
# Verify the stored sandbox used the custom ID
|
||||
add_call_args = remote_sandbox_service.db_session.add.call_args[0][0]
|
||||
assert add_call_args.id == 'custom_sandbox_id'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_sandbox_http_error(self, remote_sandbox_service):
|
||||
"""Test sandbox start with HTTP error."""
|
||||
|
||||
@@ -46,8 +46,10 @@ class MockSandboxService(SandboxService):
|
||||
) -> SandboxInfo | None:
|
||||
return await self.get_sandbox_by_session_api_key_mock(session_api_key)
|
||||
|
||||
async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo:
|
||||
return await self.start_sandbox_mock(sandbox_spec_id)
|
||||
async def start_sandbox(
|
||||
self, sandbox_spec_id: str | None = None, sandbox_id: str | None = None
|
||||
) -> SandboxInfo:
|
||||
return await self.start_sandbox_mock(sandbox_spec_id, sandbox_id)
|
||||
|
||||
async def resume_sandbox(self, sandbox_id: str) -> bool:
|
||||
return await self.resume_sandbox_mock(sandbox_id)
|
||||
|
||||
Reference in New Issue
Block a user