fix: eliminate N+1 performance bug in RemoteSandboxService with batch endpoint (#12105)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Tim O'Farrell 2025-12-19 16:24:40 -07:00 committed by GitHub
parent a873af307a
commit adff39507a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 242 additions and 77 deletions

View File

@ -477,7 +477,11 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
if sandbox.status in (None, SandboxStatus.ERROR):
raise SandboxError(f'Sandbox status: {sandbox.status}')
if sandbox.status == SandboxStatus.RUNNING:
return
# There are still bugs in the remote runtime - they report running while still just
# starting resulting in a race condition. Manually check that it is actually
# running.
if await self._check_agent_server_alive(sandbox):
return
if sandbox.status != SandboxStatus.STARTING:
raise SandboxError(f'Sandbox not startable: {sandbox.id}')
@ -490,9 +494,19 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
if sandbox.status not in (SandboxStatus.STARTING, SandboxStatus.RUNNING):
raise SandboxError(f'Sandbox not startable: {sandbox.id}')
if sandbox_info.status == SandboxStatus.RUNNING:
return
# There are still bugs in the remote runtime - they report running while still just
# starting resulting in a race condition. Manually check that it is actually
# running.
if await self._check_agent_server_alive(sandbox_info):
return
raise SandboxError(f'Sandbox failed to start: {sandbox.id}')
async def _check_agent_server_alive(self, sandbox_info: SandboxInfo) -> bool:
agent_server_url = self._get_agent_server_url(sandbox_info)
url = f'{agent_server_url.rstrip("/")}/alive'
response = await self.httpx_client.get(url)
return response.is_success
def _get_agent_server_url(self, sandbox: SandboxInfo) -> str:
"""Get agent server url for running sandbox."""
exposed_urls = sandbox.exposed_urls

View File

@ -122,18 +122,9 @@ class RemoteSandboxService(SandboxService):
_logger.error(f'HTTP error for URL {url}: {e}')
raise
async def _to_sandbox_info(
def _to_sandbox_info(
self, stored: StoredRemoteSandbox, runtime: dict[str, Any] | None = None
) -> SandboxInfo:
# If we did not get passsed runtime data, load some
if runtime is None:
try:
runtime = await self._get_runtime(stored.id)
except Exception:
_logger.exception(
f'Error getting runtime: {stored.id}', stack_info=True
)
):
status = self._get_sandbox_status_from_runtime(runtime)
# Get session_api_key and exposed urls
@ -233,6 +224,41 @@ class RemoteSandboxService(SandboxService):
runtime_data = response.json()
return runtime_data
async def _get_runtimes_batch(
self, sandbox_ids: list[str]
) -> dict[str, dict[str, Any]]:
"""Get multiple runtimes in a single batch request.
Args:
sandbox_ids: List of sandbox IDs to fetch
Returns:
Dictionary mapping sandbox_id to runtime data
"""
if not sandbox_ids:
return {}
# Build query parameters for the batch endpoint
params = [('ids', sandbox_id) for sandbox_id in sandbox_ids]
response = await self._send_runtime_api_request(
'GET',
'/sessions/batch',
params=params,
)
response.raise_for_status()
batch_data = response.json()
# The batch endpoint should return a list of runtimes
# Convert to a dictionary keyed by session_id for easy lookup
runtimes_by_id = {}
if batch_data and 'runtimes' in batch_data:
for runtime in batch_data['runtimes']:
if 'session_id' in runtime:
runtimes_by_id[runtime['session_id']] = runtime
return runtimes_by_id
async def _init_environment(
self, sandbox_spec: SandboxSpecInfo, sandbox_id: str
) -> dict[str, str]:
@ -283,13 +309,15 @@ class RemoteSandboxService(SandboxService):
if has_more:
next_page_id = str(offset + limit)
# Convert stored callbacks to domain models
items = await asyncio.gather(
*[
self._to_sandbox_info(stored_sandbox)
for stored_sandbox in stored_sandboxes
]
)
# Batch fetch runtime data for all sandboxes
sandbox_ids = [stored_sandbox.id for stored_sandbox in stored_sandboxes]
runtimes_by_id = await self._get_runtimes_batch(sandbox_ids)
# Convert stored sandboxes to domain models with runtime data
items = [
self._to_sandbox_info(stored_sandbox, runtimes_by_id.get(stored_sandbox.id))
for stored_sandbox in stored_sandboxes
]
return SandboxPage(items=items, next_page_id=next_page_id)
@ -298,7 +326,16 @@ class RemoteSandboxService(SandboxService):
stored_sandbox = await self._get_stored_sandbox(sandbox_id)
if stored_sandbox is None:
return None
return await self._to_sandbox_info(stored_sandbox)
runtime = None
try:
runtime = await self._get_runtime(stored_sandbox.id)
except Exception:
_logger.exception(
f'Error getting runtime: {stored_sandbox.id}', stack_info=True
)
return self._to_sandbox_info(stored_sandbox, runtime)
async def get_sandbox_by_session_api_key(
self, session_api_key: str
@ -323,7 +360,7 @@ class RemoteSandboxService(SandboxService):
sandbox = result.first()
if sandbox is None:
raise ValueError('sandbox_not_found')
return await self._to_sandbox_info(sandbox, runtime)
return self._to_sandbox_info(sandbox, runtime)
except Exception:
_logger.exception(
'Error getting sandbox from session_api_key', stack_info=True
@ -339,7 +376,7 @@ class RemoteSandboxService(SandboxService):
try:
runtime = await self._get_runtime(stored_sandbox.id)
if runtime and runtime.get('session_api_key') == session_api_key:
return await self._to_sandbox_info(stored_sandbox, runtime)
return self._to_sandbox_info(stored_sandbox, runtime)
except Exception:
# Continue checking other sandboxes if one fails
continue
@ -412,7 +449,7 @@ class RemoteSandboxService(SandboxService):
# Hack - result doesn't contain this
runtime_data['pod_status'] = 'pending'
return await self._to_sandbox_info(stored_sandbox, runtime_data)
return self._to_sandbox_info(stored_sandbox, runtime_data)
except httpx.HTTPError as e:
_logger.error(f'Failed to start sandbox: {e}')
@ -480,6 +517,55 @@ class RemoteSandboxService(SandboxService):
_logger.error(f'Error deleting sandbox {sandbox_id}: {e}')
return False
async def pause_old_sandboxes(self, max_num_sandboxes: int) -> list[str]:
"""Pause the oldest sandboxes if there are more than max_num_sandboxes running.
In a multi user environment, this will pause sandboxes only for the current user.
Args:
max_num_sandboxes: Maximum number of sandboxes to keep running
Returns:
List of sandbox IDs that were paused
"""
if max_num_sandboxes <= 0:
raise ValueError('max_num_sandboxes must be greater than 0')
response = await self._send_runtime_api_request(
'GET',
'/list',
)
content = response.json()
running_session_ids = [
runtime.get('session_id') for runtime in content['runtimes']
]
query = await self._secure_select()
query = query.filter(StoredRemoteSandbox.id.in_(running_session_ids)).order_by(
StoredRemoteSandbox.created_at.desc()
)
running_sandboxes = list(await self.db_session.execute(query))
# If we're within the limit, no cleanup needed
if len(running_sandboxes) <= max_num_sandboxes:
return []
# Determine how many to pause
num_to_pause = len(running_sandboxes) - max_num_sandboxes
sandboxes_to_pause = running_sandboxes[:num_to_pause]
# Stop the oldest sandboxes
paused_sandbox_ids = []
for sandbox in sandboxes_to_pause:
try:
success = await self.pause_sandbox(sandbox.id)
if success:
paused_sandbox_ids.append(sandbox.id)
except Exception:
# Continue trying to pause other sandboxes even if one fails
pass
return paused_sandbox_ids
def _build_service_url(url: str, service_name: str):
scheme, host_and_path = url.split('://')

View File

@ -72,7 +72,7 @@ class SandboxService(ABC):
"""
async def pause_old_sandboxes(self, max_num_sandboxes: int) -> list[str]:
"""Stop the oldest sandboxes if there are more than max_num_sandboxes running.
"""Pause the oldest sandboxes if there are more than max_num_sandboxes running.
In a multi user environment, this will pause sandboxes only for the current user.
Args:

View File

@ -331,7 +331,7 @@ class TestSandboxInfoConversion:
runtime_data = create_runtime_data(status='running', pod_status='ready')
# Execute
sandbox_info = await remote_sandbox_service._to_sandbox_info(
sandbox_info = remote_sandbox_service._to_sandbox_info(
stored_sandbox, runtime_data
)
@ -358,7 +358,7 @@ class TestSandboxInfoConversion:
runtime_data = create_runtime_data(status='running', pod_status='pending')
# Execute
sandbox_info = await remote_sandbox_service._to_sandbox_info(
sandbox_info = remote_sandbox_service._to_sandbox_info(
stored_sandbox, runtime_data
)
@ -367,23 +367,6 @@ class TestSandboxInfoConversion:
assert sandbox_info.session_api_key == 'test-session-key'
assert sandbox_info.exposed_urls is None
@pytest.mark.asyncio
async def test_to_sandbox_info_without_runtime(self, remote_sandbox_service):
"""Test conversion to SandboxInfo without runtime data."""
# Setup
stored_sandbox = create_stored_sandbox()
remote_sandbox_service._get_runtime = AsyncMock(
side_effect=Exception('Runtime not found')
)
# Execute
sandbox_info = await remote_sandbox_service._to_sandbox_info(stored_sandbox)
# Verify
assert sandbox_info.status == SandboxStatus.MISSING
assert sandbox_info.session_api_key is None
assert sandbox_info.exposed_urls is None
@pytest.mark.asyncio
async def test_to_sandbox_info_loads_runtime_when_none_provided(
self, remote_sandbox_service
@ -391,15 +374,12 @@ class TestSandboxInfoConversion:
"""Test that runtime data is loaded when not provided."""
# Setup
stored_sandbox = create_stored_sandbox()
runtime_data = create_runtime_data()
remote_sandbox_service._get_runtime = AsyncMock(return_value=runtime_data)
# Execute
sandbox_info = await remote_sandbox_service._to_sandbox_info(stored_sandbox)
sandbox_info = remote_sandbox_service._to_sandbox_info(stored_sandbox, None)
# Verify
remote_sandbox_service._get_runtime.assert_called_once_with('test-sandbox-123')
assert sandbox_info.status == SandboxStatus.RUNNING
assert sandbox_info.status == SandboxStatus.MISSING
class TestSandboxLifecycle:
@ -677,15 +657,18 @@ class TestSandboxSearch:
mock_result = MagicMock()
mock_result.scalars.return_value = mock_scalars
remote_sandbox_service.db_session.execute = AsyncMock(return_value=mock_result)
remote_sandbox_service._to_sandbox_info = AsyncMock(
side_effect=lambda stored: SandboxInfo(
id=stored.id,
created_by_user_id=stored.created_by_user_id,
sandbox_spec_id=stored.sandbox_spec_id,
status=SandboxStatus.RUNNING,
session_api_key='test-key',
created_at=stored.created_at,
)
# Mock the batch endpoint response
mock_batch_response = MagicMock()
mock_batch_response.raise_for_status.return_value = None
mock_batch_response.json.return_value = {
'runtimes': [
create_runtime_data('sb1'),
create_runtime_data('sb2'),
]
}
remote_sandbox_service.httpx_client.request = AsyncMock(
return_value=mock_batch_response
)
# Execute
@ -697,6 +680,14 @@ class TestSandboxSearch:
assert result.items[0].id == 'sb1'
assert result.items[1].id == 'sb2'
# Verify that the batch endpoint was called
remote_sandbox_service.httpx_client.request.assert_called_once_with(
'GET',
'https://api.example.com/sessions/batch',
headers={'X-API-Key': 'test-api-key'},
params=[('ids', 'sb1'), ('ids', 'sb2')],
)
@pytest.mark.asyncio
async def test_search_sandboxes_with_pagination(self, remote_sandbox_service):
"""Test sandbox search with pagination."""
@ -710,15 +701,15 @@ class TestSandboxSearch:
mock_result = MagicMock()
mock_result.scalars.return_value = mock_scalars
remote_sandbox_service.db_session.execute = AsyncMock(return_value=mock_result)
remote_sandbox_service._to_sandbox_info = AsyncMock(
side_effect=lambda stored: SandboxInfo(
id=stored.id,
created_by_user_id=stored.created_by_user_id,
sandbox_spec_id=stored.sandbox_spec_id,
status=SandboxStatus.RUNNING,
session_api_key='test-key',
created_at=stored.created_at,
)
# Mock the batch endpoint response
mock_batch_response = MagicMock()
mock_batch_response.raise_for_status.return_value = None
mock_batch_response.json.return_value = {
'runtimes': [create_runtime_data(f'sb{i}') for i in range(6)]
}
remote_sandbox_service.httpx_client.request = AsyncMock(
return_value=mock_batch_response
)
# Execute
@ -739,15 +730,15 @@ class TestSandboxSearch:
mock_result = MagicMock()
mock_result.scalars.return_value = mock_scalars
remote_sandbox_service.db_session.execute = AsyncMock(return_value=mock_result)
remote_sandbox_service._to_sandbox_info = AsyncMock(
side_effect=lambda stored: SandboxInfo(
id=stored.id,
created_by_user_id=stored.created_by_user_id,
sandbox_spec_id=stored.sandbox_spec_id,
status=SandboxStatus.RUNNING,
session_api_key='test-key',
created_at=stored.created_at,
)
# Mock the batch endpoint response
mock_batch_response = MagicMock()
mock_batch_response.raise_for_status.return_value = None
mock_batch_response.json.return_value = {
'runtimes': [create_runtime_data('sb1')]
}
remote_sandbox_service.httpx_client.request = AsyncMock(
return_value=mock_batch_response
)
# Execute
@ -757,6 +748,80 @@ class TestSandboxSearch:
# Note: We can't easily verify the exact SQL query, but we can verify the method was called
remote_sandbox_service.db_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_get_runtimes_batch_success(self, remote_sandbox_service):
"""Test successful batch runtime retrieval."""
# Setup
sandbox_ids = ['sb1', 'sb2', 'sb3']
mock_response = MagicMock()
mock_response.raise_for_status.return_value = None
mock_response.json.return_value = {
'runtimes': [
create_runtime_data('sb1'),
create_runtime_data('sb2'),
create_runtime_data('sb3'),
]
}
remote_sandbox_service.httpx_client.request = AsyncMock(
return_value=mock_response
)
# Execute
result = await remote_sandbox_service._get_runtimes_batch(sandbox_ids)
# Verify
assert len(result) == 3
assert 'sb1' in result
assert 'sb2' in result
assert 'sb3' in result
assert result['sb1']['session_id'] == 'sb1'
# Verify the correct API call was made
remote_sandbox_service.httpx_client.request.assert_called_once_with(
'GET',
'https://api.example.com/sessions/batch',
headers={'X-API-Key': 'test-api-key'},
params=[('ids', 'sb1'), ('ids', 'sb2'), ('ids', 'sb3')],
)
@pytest.mark.asyncio
async def test_get_runtimes_batch_empty_list(self, remote_sandbox_service):
"""Test batch runtime retrieval with empty sandbox list."""
# Execute
result = await remote_sandbox_service._get_runtimes_batch([])
# Verify
assert result == {}
# Verify no API call was made
remote_sandbox_service.httpx_client.request.assert_not_called()
@pytest.mark.asyncio
async def test_get_runtimes_batch_partial_results(self, remote_sandbox_service):
"""Test batch runtime retrieval with partial results (some sandboxes not found)."""
# Setup
sandbox_ids = ['sb1', 'sb2', 'sb3']
mock_response = MagicMock()
mock_response.raise_for_status.return_value = None
mock_response.json.return_value = {
'runtimes': [
create_runtime_data('sb1'),
create_runtime_data('sb3'),
# sb2 is missing from the response
]
}
remote_sandbox_service.httpx_client.request = AsyncMock(
return_value=mock_response
)
# Execute
result = await remote_sandbox_service._get_runtimes_batch(sandbox_ids)
# Verify
assert len(result) == 2
assert 'sb1' in result
assert 'sb2' not in result # Missing from response
assert 'sb3' in result
@pytest.mark.asyncio
async def test_get_sandbox_exists(self, remote_sandbox_service):
"""Test getting an existing sandbox."""
@ -765,7 +830,7 @@ class TestSandboxSearch:
remote_sandbox_service._get_stored_sandbox = AsyncMock(
return_value=stored_sandbox
)
remote_sandbox_service._to_sandbox_info = AsyncMock(
remote_sandbox_service._to_sandbox_info = MagicMock(
return_value=SandboxInfo(
id='test-sandbox-123',
created_by_user_id='test-user-123',