Add Concurrency Limits to SandboxService (#11399)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Tim O'Farrell
2025-10-20 14:22:12 -06:00
committed by GitHub
parent 9efe6eb776
commit 44578664ed
5 changed files with 489 additions and 16 deletions

View File

@@ -74,6 +74,7 @@ class DockerSandboxService(SandboxService):
exposed_ports: list[ExposedPort]
health_check_path: str | None
httpx_client: httpx.AsyncClient
max_num_sandboxes: int
docker_client: docker.DockerClient = field(default_factory=get_docker_client)
def _find_unused_port(self) -> int:
@@ -251,6 +252,9 @@ class DockerSandboxService(SandboxService):
async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo:
"""Start a new sandbox."""
# Enforce the sandbox limit: pause the oldest running sandboxes, leaving one slot free for the new sandbox
await self.pause_old_sandboxes(self.max_num_sandboxes - 1)
if sandbox_spec_id is None:
sandbox_spec = await self.sandbox_spec_service.get_default_sandbox_spec()
else:
@@ -321,6 +325,9 @@ class DockerSandboxService(SandboxService):
async def resume_sandbox(self, sandbox_id: str) -> bool:
"""Resume a paused sandbox."""
# Enforce the sandbox limit: pause the oldest running sandboxes, leaving one slot free for the resumed sandbox
await self.pause_old_sandboxes(self.max_num_sandboxes - 1)
try:
if not sandbox_id.startswith(self.container_name_prefix):
return False
@@ -383,6 +390,10 @@ class DockerSandboxServiceInjector(SandboxServiceInjector):
container_url_pattern: str = 'http://localhost:{port}'
host_port: int = 3000
container_name_prefix: str = 'oh-agent-server-'
max_num_sandboxes: int = Field(
default=5,
description='Maximum number of sandboxes allowed to run simultaneously',
)
mounts: list[VolumeMount] = Field(default_factory=list)
exposed_ports: list[ExposedPort] = Field(
default_factory=lambda: [
@@ -446,4 +457,5 @@ class DockerSandboxServiceInjector(SandboxServiceInjector):
exposed_ports=self.exposed_ports,
health_check_path=self.health_check_path,
httpx_client=httpx_client,
max_num_sandboxes=self.max_num_sandboxes,
)

View File

@@ -95,6 +95,7 @@ class RemoteSandboxService(SandboxService):
resource_factor: int
runtime_class: str | None
start_sandbox_timeout: int
max_num_sandboxes: int
user_context: UserContext
httpx_client: httpx.AsyncClient
db_session: AsyncSession
@@ -268,6 +269,9 @@ class RemoteSandboxService(SandboxService):
async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo:
"""Start a new sandbox by creating a remote runtime."""
try:
# Enforce the sandbox limit: pause the oldest running sandboxes, leaving one slot free for the new sandbox
await self.pause_old_sandboxes(self.max_num_sandboxes - 1)
# Get sandbox spec
if sandbox_spec_id is None:
sandbox_spec = (
@@ -338,6 +342,9 @@ class RemoteSandboxService(SandboxService):
async def resume_sandbox(self, sandbox_id: str) -> bool:
"""Resume a paused sandbox."""
# Enforce the sandbox limit: pause the oldest running sandboxes, leaving one slot free for the resumed sandbox
await self.pause_old_sandboxes(self.max_num_sandboxes - 1)
try:
if not await self._get_stored_sandbox(sandbox_id):
return False
@@ -588,6 +595,10 @@ class RemoteSandboxServiceInjector(SandboxServiceInjector):
'be in an error state.'
),
)
max_num_sandboxes: int = Field(
default=10,
description='Maximum number of sandboxes allowed to run simultaneously',
)
async def inject(
self, state: InjectorState, request: Request | None = None
@@ -628,6 +639,7 @@ class RemoteSandboxServiceInjector(SandboxServiceInjector):
resource_factor=self.resource_factor,
runtime_class=self.runtime_class,
start_sandbox_timeout=self.start_sandbox_timeout,
max_num_sandboxes=self.max_num_sandboxes,
user_context=user_context,
httpx_client=httpx_client,
db_session=db_session,

View File

@@ -1,7 +1,11 @@
import asyncio
from abc import ABC, abstractmethod
from openhands.app_server.sandbox.sandbox_models import SandboxInfo, SandboxPage
from openhands.app_server.sandbox.sandbox_models import (
SandboxInfo,
SandboxPage,
SandboxStatus,
)
from openhands.app_server.services.injector import Injector
from openhands.sdk.utils.models import DiscriminatedUnionMixin
@@ -60,6 +64,61 @@ class SandboxService(ABC):
Return False if the sandbox did not exist.
"""
async def pause_old_sandboxes(self, max_num_sandboxes: int) -> list[str]:
    """Pause the oldest running sandboxes so at most max_num_sandboxes keep running.

    Every page returned by ``search_sandboxes`` is scanned, the sandboxes whose
    status is ``SandboxStatus.RUNNING`` are ordered by ``created_at`` and the
    oldest surplus ones are handed to ``pause_sandbox`` one at a time. A pause
    that fails (returns a falsy value or raises) is logged and skipped so the
    remaining sandboxes are still attempted.

    Args:
        max_num_sandboxes: Maximum number of sandboxes to keep running

    Returns:
        List of sandbox IDs that were paused, oldest first

    Raises:
        ValueError: If max_num_sandboxes is zero or negative
    """
    import logging

    if max_num_sandboxes <= 0:
        raise ValueError('max_num_sandboxes must be greater than 0')

    logger = logging.getLogger(__name__)

    # Walk through all pages, keeping only the running sandboxes
    running_sandboxes = []
    page_id = None
    while True:
        page = await self.search_sandboxes(page_id=page_id, limit=100)
        running_sandboxes.extend(
            sandbox
            for sandbox in page.items
            if sandbox.status == SandboxStatus.RUNNING
        )
        page_id = page.next_page_id
        if page_id is None:
            break

    # If we're within the limit, no cleanup needed
    num_to_pause = len(running_sandboxes) - max_num_sandboxes
    if num_to_pause <= 0:
        return []

    # Sort by creation time (oldest first) so the surplus is taken from the front
    running_sandboxes.sort(key=lambda sandbox: sandbox.created_at)

    paused_sandbox_ids = []
    for sandbox in running_sandboxes[:num_to_pause]:
        try:
            success = await self.pause_sandbox(sandbox.id)
        except Exception:
            # Best effort: record the failure and keep pausing the others
            logger.warning(
                'Failed to pause sandbox %s', sandbox.id, exc_info=True
            )
            continue
        if success:
            paused_sandbox_ids.append(sandbox.id)
        else:
            logger.warning('Sandbox %s could not be paused', sandbox.id)

    return paused_sandbox_ids
# Abstract base for injectors that provide a SandboxService. It adds no members
# of its own; it only combines DiscriminatedUnionMixin with
# Injector[SandboxService] so concrete injectors share one base type.
class SandboxServiceInjector(DiscriminatedUnionMixin, Injector[SandboxService], ABC):
    pass

View File

@@ -80,6 +80,7 @@ def service(mock_sandbox_spec_service, mock_httpx_client, mock_docker_client):
],
health_check_path='/health',
httpx_client=mock_httpx_client,
max_num_sandboxes=3,
docker_client=mock_docker_client,
)
@@ -354,7 +355,12 @@ class TestDockerSandboxService:
service.docker_client.containers.run.return_value = mock_container
with patch.object(service, '_find_unused_port', side_effect=[12345, 12346]):
with (
patch.object(service, '_find_unused_port', side_effect=[12345, 12346]),
patch.object(
service, 'pause_old_sandboxes', return_value=[]
) as mock_cleanup,
):
# Execute
result = await service.start_sandbox()
@@ -362,6 +368,9 @@ class TestDockerSandboxService:
assert result is not None
assert result.id == 'oh-test-test_container_id'
# Verify cleanup was called with the correct limit
mock_cleanup.assert_called_once_with(2)
# Verify container was created with correct parameters
service.docker_client.containers.run.assert_called_once()
call_args = service.docker_client.containers.run.call_args
@@ -395,7 +404,10 @@ class TestDockerSandboxService:
}
service.docker_client.containers.run.return_value = mock_container
with patch.object(service, '_find_unused_port', return_value=12345):
with (
patch.object(service, '_find_unused_port', return_value=12345),
patch.object(service, 'pause_old_sandboxes', return_value=[]),
):
# Execute
await service.start_sandbox(sandbox_spec_id='custom-spec')
@@ -412,7 +424,10 @@ class TestDockerSandboxService:
mock_sandbox_spec_service.get_sandbox_spec.return_value = None
# Execute & Verify
with pytest.raises(ValueError, match='Sandbox Spec not found'):
with (
patch.object(service, 'pause_old_sandboxes', return_value=[]),
pytest.raises(ValueError, match='Sandbox Spec not found'),
):
await service.start_sandbox(sandbox_spec_id='nonexistent')
async def test_start_sandbox_docker_error(self, service):
@@ -422,10 +437,12 @@ class TestDockerSandboxService:
'Failed to create container'
)
with patch.object(service, '_find_unused_port', return_value=12345):
# Execute & Verify
with pytest.raises(SandboxError, match='Failed to start container'):
await service.start_sandbox()
with (
patch.object(service, '_find_unused_port', return_value=12345),
patch.object(service, 'pause_old_sandboxes', return_value=[]),
pytest.raises(SandboxError, match='Failed to start container'),
):
await service.start_sandbox()
async def test_resume_sandbox_from_paused(self, service):
"""Test resuming a paused sandbox."""
@@ -434,13 +451,18 @@ class TestDockerSandboxService:
mock_container.status = 'paused'
service.docker_client.containers.get.return_value = mock_container
# Execute
result = await service.resume_sandbox('oh-test-abc123')
with patch.object(
service, 'pause_old_sandboxes', return_value=[]
) as mock_cleanup:
# Execute
result = await service.resume_sandbox('oh-test-abc123')
# Verify
assert result is True
mock_container.unpause.assert_called_once()
mock_container.start.assert_not_called()
# Verify cleanup was called with the correct limit
mock_cleanup.assert_called_once_with(2)
async def test_resume_sandbox_from_exited(self, service):
"""Test resuming an exited sandbox."""
@@ -449,22 +471,32 @@ class TestDockerSandboxService:
mock_container.status = 'exited'
service.docker_client.containers.get.return_value = mock_container
# Execute
result = await service.resume_sandbox('oh-test-abc123')
with patch.object(
service, 'pause_old_sandboxes', return_value=[]
) as mock_cleanup:
# Execute
result = await service.resume_sandbox('oh-test-abc123')
# Verify
assert result is True
mock_container.start.assert_called_once()
mock_container.unpause.assert_not_called()
# Verify cleanup was called with the correct limit
mock_cleanup.assert_called_once_with(2)
async def test_resume_sandbox_wrong_prefix(self, service):
"""Test resuming sandbox with wrong prefix."""
# Execute
result = await service.resume_sandbox('wrong-prefix-abc123')
with patch.object(
service, 'pause_old_sandboxes', return_value=[]
) as mock_cleanup:
# Execute
result = await service.resume_sandbox('wrong-prefix-abc123')
# Verify
assert result is False
service.docker_client.containers.get.assert_not_called()
# Verify cleanup was still called
mock_cleanup.assert_called_once_with(2)
async def test_resume_sandbox_not_found(self, service):
"""Test resuming non-existent sandbox."""
@@ -473,11 +505,16 @@ class TestDockerSandboxService:
'Container not found'
)
# Execute
result = await service.resume_sandbox('oh-test-abc123')
with patch.object(
service, 'pause_old_sandboxes', return_value=[]
) as mock_cleanup:
# Execute
result = await service.resume_sandbox('oh-test-abc123')
# Verify
assert result is False
# Verify cleanup was still called
mock_cleanup.assert_called_once_with(2)
async def test_pause_sandbox_success(self, service):
"""Test pausing a running sandbox."""

View File

@@ -0,0 +1,353 @@
"""Tests for SandboxService base class.
This module tests the SandboxService base class implementation, focusing on:
- pause_old_sandboxes method functionality
- Proper handling of pagination when searching sandboxes
- Correct filtering of running vs non-running sandboxes
- Proper sorting by creation time (oldest first)
- Error handling and edge cases
"""
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock
import pytest
from openhands.app_server.sandbox.sandbox_models import (
SandboxInfo,
SandboxPage,
SandboxStatus,
)
from openhands.app_server.sandbox.sandbox_service import SandboxService
class MockSandboxService(SandboxService):
    """SandboxService test double backed by AsyncMock objects.

    Every overridden method forwards its arguments to a matching ``*_mock``
    attribute, so tests can program return values / side effects and inspect
    the recorded calls while exercising the base-class logic.
    """

    def __init__(self):
        # One AsyncMock per overridden service method
        self.search_sandboxes_mock = AsyncMock()
        self.get_sandbox_mock = AsyncMock()
        self.start_sandbox_mock = AsyncMock()
        self.resume_sandbox_mock = AsyncMock()
        self.pause_sandbox_mock = AsyncMock()
        self.delete_sandbox_mock = AsyncMock()

    async def search_sandboxes(
        self, page_id: str | None = None, limit: int = 100
    ) -> SandboxPage:
        # Forwarded as keyword arguments so tests can assert on them by name
        page = await self.search_sandboxes_mock(page_id=page_id, limit=limit)
        return page

    async def get_sandbox(self, sandbox_id: str) -> SandboxInfo | None:
        info = await self.get_sandbox_mock(sandbox_id)
        return info

    async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo:
        info = await self.start_sandbox_mock(sandbox_spec_id)
        return info

    async def resume_sandbox(self, sandbox_id: str) -> bool:
        resumed = await self.resume_sandbox_mock(sandbox_id)
        return resumed

    async def pause_sandbox(self, sandbox_id: str) -> bool:
        paused = await self.pause_sandbox_mock(sandbox_id)
        return paused

    async def delete_sandbox(self, sandbox_id: str) -> bool:
        deleted = await self.delete_sandbox_mock(sandbox_id)
        return deleted
def create_sandbox_info(
    sandbox_id: str,
    status: SandboxStatus,
    created_at: datetime,
    created_by_user_id: str | None = None,
    sandbox_spec_id: str = 'test-spec',
) -> SandboxInfo:
    """Build a SandboxInfo for tests.

    Only sandboxes created with ``SandboxStatus.RUNNING`` receive a session
    API key; every other status gets ``None``.
    """
    if status == SandboxStatus.RUNNING:
        session_api_key = 'test-api-key'
    else:
        session_api_key = None

    return SandboxInfo(
        id=sandbox_id,
        created_by_user_id=created_by_user_id,
        sandbox_spec_id=sandbox_spec_id,
        status=status,
        session_api_key=session_api_key,
        created_at=created_at,
    )
@pytest.fixture
def mock_sandbox_service():
    """Return a newly constructed MockSandboxService."""
    service = MockSandboxService()
    return service
class TestCleanupOldSandboxes:
    """Test cases for the pause_old_sandboxes method."""

    @pytest.mark.asyncio
    async def test_cleanup_with_no_sandboxes(self, mock_sandbox_service):
        """Test cleanup when there are no sandboxes."""
        # Setup: No sandboxes
        mock_sandbox_service.search_sandboxes_mock.return_value = SandboxPage(
            items=[], next_page_id=None
        )

        # Execute
        result = await mock_sandbox_service.pause_old_sandboxes(max_num_sandboxes=5)

        # Verify
        assert result == []
        mock_sandbox_service.search_sandboxes_mock.assert_called_once_with(
            page_id=None, limit=100
        )
        mock_sandbox_service.pause_sandbox_mock.assert_not_called()

    @pytest.mark.asyncio
    async def test_cleanup_within_limit(self, mock_sandbox_service):
        """Test cleanup when sandbox count is within the limit."""
        # Setup: 3 running sandboxes, limit is 5
        now = datetime.now(timezone.utc)
        sandboxes = [
            create_sandbox_info('sb1', SandboxStatus.RUNNING, now - timedelta(hours=3)),
            create_sandbox_info('sb2', SandboxStatus.RUNNING, now - timedelta(hours=2)),
            create_sandbox_info('sb3', SandboxStatus.RUNNING, now - timedelta(hours=1)),
        ]
        mock_sandbox_service.search_sandboxes_mock.return_value = SandboxPage(
            items=sandboxes, next_page_id=None
        )

        # Execute
        result = await mock_sandbox_service.pause_old_sandboxes(max_num_sandboxes=5)

        # Verify
        assert result == []
        mock_sandbox_service.pause_sandbox_mock.assert_not_called()

    @pytest.mark.asyncio
    async def test_cleanup_exceeds_limit(self, mock_sandbox_service):
        """Test cleanup when sandbox count exceeds the limit."""
        # Setup: 5 running sandboxes, limit is 2
        now = datetime.now(timezone.utc)
        sandboxes = [
            create_sandbox_info(
                'sb1', SandboxStatus.RUNNING, now - timedelta(hours=5)
            ),  # oldest, should be paused
            create_sandbox_info(
                'sb2', SandboxStatus.RUNNING, now - timedelta(hours=4)
            ),  # second oldest, should be paused
            create_sandbox_info(
                'sb3', SandboxStatus.RUNNING, now - timedelta(hours=3)
            ),  # third oldest, should be paused
            create_sandbox_info(
                'sb4', SandboxStatus.RUNNING, now - timedelta(hours=2)
            ),  # should remain
            create_sandbox_info(
                'sb5', SandboxStatus.RUNNING, now - timedelta(hours=1)
            ),  # newest, should remain
        ]
        mock_sandbox_service.search_sandboxes_mock.return_value = SandboxPage(
            items=sandboxes, next_page_id=None
        )
        mock_sandbox_service.pause_sandbox_mock.return_value = True

        # Execute
        result = await mock_sandbox_service.pause_old_sandboxes(max_num_sandboxes=2)

        # Verify: Should pause the 3 oldest sandboxes, oldest first
        assert result == ['sb1', 'sb2', 'sb3']
        assert mock_sandbox_service.pause_sandbox_mock.call_count == 3

    @pytest.mark.asyncio
    async def test_cleanup_filters_non_running_sandboxes(self, mock_sandbox_service):
        """Test that cleanup only considers running sandboxes."""
        # Setup: Mix of running and non-running sandboxes
        now = datetime.now(timezone.utc)
        sandboxes = [
            create_sandbox_info('sb1', SandboxStatus.RUNNING, now - timedelta(hours=5)),
            create_sandbox_info(
                'sb2', SandboxStatus.PAUSED, now - timedelta(hours=4)
            ),  # should be ignored
            create_sandbox_info('sb3', SandboxStatus.RUNNING, now - timedelta(hours=3)),
            create_sandbox_info(
                'sb4', SandboxStatus.ERROR, now - timedelta(hours=2)
            ),  # should be ignored
            create_sandbox_info('sb5', SandboxStatus.RUNNING, now - timedelta(hours=1)),
        ]
        mock_sandbox_service.search_sandboxes_mock.return_value = SandboxPage(
            items=sandboxes, next_page_id=None
        )
        mock_sandbox_service.pause_sandbox_mock.return_value = True

        # Execute: Limit is 2, and only 3 of the 5 sandboxes are running
        result = await mock_sandbox_service.pause_old_sandboxes(max_num_sandboxes=2)

        # Verify: Should pause only 1 sandbox (the oldest running one)
        assert result == ['sb1']
        mock_sandbox_service.pause_sandbox_mock.assert_called_once_with('sb1')

    @pytest.mark.asyncio
    async def test_cleanup_with_pagination(self, mock_sandbox_service):
        """Test cleanup handles pagination correctly."""
        # Setup: Multiple pages of sandboxes
        now = datetime.now(timezone.utc)

        # First page
        page1_sandboxes = [
            create_sandbox_info('sb1', SandboxStatus.RUNNING, now - timedelta(hours=3)),
            create_sandbox_info('sb2', SandboxStatus.RUNNING, now - timedelta(hours=2)),
        ]

        # Second page
        page2_sandboxes = [
            create_sandbox_info('sb3', SandboxStatus.RUNNING, now - timedelta(hours=1)),
        ]

        def search_side_effect(page_id=None, limit=100):
            if page_id is None:
                return SandboxPage(items=page1_sandboxes, next_page_id='page2')
            elif page_id == 'page2':
                return SandboxPage(items=page2_sandboxes, next_page_id=None)

        mock_sandbox_service.search_sandboxes_mock.side_effect = search_side_effect
        mock_sandbox_service.pause_sandbox_mock.return_value = True

        # Execute: Limit is 2, total is 3
        result = await mock_sandbox_service.pause_old_sandboxes(max_num_sandboxes=2)

        # Verify: Should pause the oldest sandbox
        assert result == ['sb1']

        # Verify: Both pages were requested, following next_page_id
        search_calls = mock_sandbox_service.search_sandboxes_mock.call_args_list
        assert [c.kwargs['page_id'] for c in search_calls] == [None, 'page2']

    @pytest.mark.asyncio
    async def test_cleanup_handles_pause_failures(self, mock_sandbox_service):
        """Test cleanup continues when some pause operations fail."""
        # Setup: 4 running sandboxes, limit is 2
        now = datetime.now(timezone.utc)
        sandboxes = [
            create_sandbox_info('sb1', SandboxStatus.RUNNING, now - timedelta(hours=4)),
            create_sandbox_info('sb2', SandboxStatus.RUNNING, now - timedelta(hours=3)),
            create_sandbox_info('sb3', SandboxStatus.RUNNING, now - timedelta(hours=2)),
            create_sandbox_info('sb4', SandboxStatus.RUNNING, now - timedelta(hours=1)),
        ]
        mock_sandbox_service.search_sandboxes_mock.return_value = SandboxPage(
            items=sandboxes, next_page_id=None
        )

        # Setup: First pause fails, second succeeds
        def pause_side_effect(sandbox_id):
            if sandbox_id == 'sb1':
                return False  # Simulate failure
            return True

        mock_sandbox_service.pause_sandbox_mock.side_effect = pause_side_effect

        # Execute
        result = await mock_sandbox_service.pause_old_sandboxes(max_num_sandboxes=2)

        # Verify: Should only include successfully paused sandbox
        assert result == ['sb2']

        # Verify: Both of the oldest sandboxes were attempted, oldest first
        pause_calls = mock_sandbox_service.pause_sandbox_mock.call_args_list
        assert [c.args[0] for c in pause_calls] == ['sb1', 'sb2']

    @pytest.mark.asyncio
    async def test_cleanup_handles_pause_exceptions(self, mock_sandbox_service):
        """Test cleanup continues when pause operations raise exceptions."""
        # Setup: 3 running sandboxes, limit is 1
        now = datetime.now(timezone.utc)
        sandboxes = [
            create_sandbox_info('sb1', SandboxStatus.RUNNING, now - timedelta(hours=3)),
            create_sandbox_info('sb2', SandboxStatus.RUNNING, now - timedelta(hours=2)),
            create_sandbox_info('sb3', SandboxStatus.RUNNING, now - timedelta(hours=1)),
        ]
        mock_sandbox_service.search_sandboxes_mock.return_value = SandboxPage(
            items=sandboxes, next_page_id=None
        )

        # Setup: First pause raises exception, second succeeds
        def pause_side_effect(sandbox_id):
            if sandbox_id == 'sb1':
                raise Exception('Pause failed')
            return True

        mock_sandbox_service.pause_sandbox_mock.side_effect = pause_side_effect

        # Execute
        result = await mock_sandbox_service.pause_old_sandboxes(max_num_sandboxes=1)

        # Verify: Should only include successfully paused sandbox
        assert result == ['sb2']

        # Verify: The exception did not stop the next pause attempt
        pause_calls = mock_sandbox_service.pause_sandbox_mock.call_args_list
        assert [c.args[0] for c in pause_calls] == ['sb1', 'sb2']

    @pytest.mark.asyncio
    @pytest.mark.parametrize('invalid_limit', [0, -1])
    async def test_cleanup_invalid_max_num_sandboxes(
        self, mock_sandbox_service, invalid_limit
    ):
        """Test cleanup raises ValueError for zero and negative max_num_sandboxes."""
        with pytest.raises(
            ValueError, match='max_num_sandboxes must be greater than 0'
        ):
            await mock_sandbox_service.pause_old_sandboxes(
                max_num_sandboxes=invalid_limit
            )

        # Verify: Validation happens before any sandbox lookup
        mock_sandbox_service.search_sandboxes_mock.assert_not_called()

    @pytest.mark.asyncio
    async def test_cleanup_sorts_by_creation_time(self, mock_sandbox_service):
        """Test that cleanup properly sorts sandboxes by creation time."""
        # Setup: Sandboxes in random order by creation time
        now = datetime.now(timezone.utc)
        sandboxes = [
            create_sandbox_info('sb_newest', SandboxStatus.RUNNING, now),  # newest
            create_sandbox_info(
                'sb_oldest', SandboxStatus.RUNNING, now - timedelta(hours=5)
            ),  # oldest
            create_sandbox_info(
                'sb_middle', SandboxStatus.RUNNING, now - timedelta(hours=2)
            ),  # middle
        ]
        mock_sandbox_service.search_sandboxes_mock.return_value = SandboxPage(
            items=sandboxes, next_page_id=None
        )
        mock_sandbox_service.pause_sandbox_mock.return_value = True

        # Execute: Keep only 1 sandbox
        result = await mock_sandbox_service.pause_old_sandboxes(max_num_sandboxes=1)

        # Verify: Should pause the 2 oldest sandboxes in order
        assert result == ['sb_oldest', 'sb_middle']

        # Verify pause was called in the correct order (oldest first)
        pause_calls = mock_sandbox_service.pause_sandbox_mock.call_args_list
        assert [c.args[0] for c in pause_calls] == ['sb_oldest', 'sb_middle']

    @pytest.mark.asyncio
    async def test_cleanup_exact_limit(self, mock_sandbox_service):
        """Test cleanup when sandbox count exactly equals the limit."""
        # Setup: Exactly 3 running sandboxes, limit is 3
        now = datetime.now(timezone.utc)
        sandboxes = [
            create_sandbox_info('sb1', SandboxStatus.RUNNING, now - timedelta(hours=3)),
            create_sandbox_info('sb2', SandboxStatus.RUNNING, now - timedelta(hours=2)),
            create_sandbox_info('sb3', SandboxStatus.RUNNING, now - timedelta(hours=1)),
        ]
        mock_sandbox_service.search_sandboxes_mock.return_value = SandboxPage(
            items=sandboxes, next_page_id=None
        )

        # Execute
        result = await mock_sandbox_service.pause_old_sandboxes(max_num_sandboxes=3)

        # Verify: No sandboxes should be paused
        assert result == []
        mock_sandbox_service.pause_sandbox_mock.assert_not_called()