mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-25 21:36:52 +08:00
fix: eliminate N+1 performance bug in RemoteSandboxService with batch endpoint (#12105)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
a873af307a
commit
adff39507a
@ -477,7 +477,11 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
if sandbox.status in (None, SandboxStatus.ERROR):
|
||||
raise SandboxError(f'Sandbox status: {sandbox.status}')
|
||||
if sandbox.status == SandboxStatus.RUNNING:
|
||||
return
|
||||
# There are still bugs in the remote runtime - they report running while still just
|
||||
# starting resulting in a race condition. Manually check that it is actually
|
||||
# running.
|
||||
if await self._check_agent_server_alive(sandbox):
|
||||
return
|
||||
if sandbox.status != SandboxStatus.STARTING:
|
||||
raise SandboxError(f'Sandbox not startable: {sandbox.id}')
|
||||
|
||||
@ -490,9 +494,19 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
if sandbox.status not in (SandboxStatus.STARTING, SandboxStatus.RUNNING):
|
||||
raise SandboxError(f'Sandbox not startable: {sandbox.id}')
|
||||
if sandbox_info.status == SandboxStatus.RUNNING:
|
||||
return
|
||||
# There are still bugs in the remote runtime - they report running while still just
|
||||
# starting resulting in a race condition. Manually check that it is actually
|
||||
# running.
|
||||
if await self._check_agent_server_alive(sandbox_info):
|
||||
return
|
||||
raise SandboxError(f'Sandbox failed to start: {sandbox.id}')
|
||||
|
||||
async def _check_agent_server_alive(self, sandbox_info: SandboxInfo) -> bool:
|
||||
agent_server_url = self._get_agent_server_url(sandbox_info)
|
||||
url = f'{agent_server_url.rstrip("/")}/alive'
|
||||
response = await self.httpx_client.get(url)
|
||||
return response.is_success
|
||||
|
||||
def _get_agent_server_url(self, sandbox: SandboxInfo) -> str:
|
||||
"""Get agent server url for running sandbox."""
|
||||
exposed_urls = sandbox.exposed_urls
|
||||
|
||||
@ -122,18 +122,9 @@ class RemoteSandboxService(SandboxService):
|
||||
_logger.error(f'HTTP error for URL {url}: {e}')
|
||||
raise
|
||||
|
||||
async def _to_sandbox_info(
|
||||
def _to_sandbox_info(
|
||||
self, stored: StoredRemoteSandbox, runtime: dict[str, Any] | None = None
|
||||
) -> SandboxInfo:
|
||||
# If we did not get passsed runtime data, load some
|
||||
if runtime is None:
|
||||
try:
|
||||
runtime = await self._get_runtime(stored.id)
|
||||
except Exception:
|
||||
_logger.exception(
|
||||
f'Error getting runtime: {stored.id}', stack_info=True
|
||||
)
|
||||
|
||||
):
|
||||
status = self._get_sandbox_status_from_runtime(runtime)
|
||||
|
||||
# Get session_api_key and exposed urls
|
||||
@ -233,6 +224,41 @@ class RemoteSandboxService(SandboxService):
|
||||
runtime_data = response.json()
|
||||
return runtime_data
|
||||
|
||||
async def _get_runtimes_batch(
|
||||
self, sandbox_ids: list[str]
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Get multiple runtimes in a single batch request.
|
||||
|
||||
Args:
|
||||
sandbox_ids: List of sandbox IDs to fetch
|
||||
|
||||
Returns:
|
||||
Dictionary mapping sandbox_id to runtime data
|
||||
"""
|
||||
if not sandbox_ids:
|
||||
return {}
|
||||
|
||||
# Build query parameters for the batch endpoint
|
||||
params = [('ids', sandbox_id) for sandbox_id in sandbox_ids]
|
||||
|
||||
response = await self._send_runtime_api_request(
|
||||
'GET',
|
||||
'/sessions/batch',
|
||||
params=params,
|
||||
)
|
||||
response.raise_for_status()
|
||||
batch_data = response.json()
|
||||
|
||||
# The batch endpoint should return a list of runtimes
|
||||
# Convert to a dictionary keyed by session_id for easy lookup
|
||||
runtimes_by_id = {}
|
||||
if batch_data and 'runtimes' in batch_data:
|
||||
for runtime in batch_data['runtimes']:
|
||||
if 'session_id' in runtime:
|
||||
runtimes_by_id[runtime['session_id']] = runtime
|
||||
|
||||
return runtimes_by_id
|
||||
|
||||
async def _init_environment(
|
||||
self, sandbox_spec: SandboxSpecInfo, sandbox_id: str
|
||||
) -> dict[str, str]:
|
||||
@ -283,13 +309,15 @@ class RemoteSandboxService(SandboxService):
|
||||
if has_more:
|
||||
next_page_id = str(offset + limit)
|
||||
|
||||
# Convert stored callbacks to domain models
|
||||
items = await asyncio.gather(
|
||||
*[
|
||||
self._to_sandbox_info(stored_sandbox)
|
||||
for stored_sandbox in stored_sandboxes
|
||||
]
|
||||
)
|
||||
# Batch fetch runtime data for all sandboxes
|
||||
sandbox_ids = [stored_sandbox.id for stored_sandbox in stored_sandboxes]
|
||||
runtimes_by_id = await self._get_runtimes_batch(sandbox_ids)
|
||||
|
||||
# Convert stored sandboxes to domain models with runtime data
|
||||
items = [
|
||||
self._to_sandbox_info(stored_sandbox, runtimes_by_id.get(stored_sandbox.id))
|
||||
for stored_sandbox in stored_sandboxes
|
||||
]
|
||||
|
||||
return SandboxPage(items=items, next_page_id=next_page_id)
|
||||
|
||||
@ -298,7 +326,16 @@ class RemoteSandboxService(SandboxService):
|
||||
stored_sandbox = await self._get_stored_sandbox(sandbox_id)
|
||||
if stored_sandbox is None:
|
||||
return None
|
||||
return await self._to_sandbox_info(stored_sandbox)
|
||||
|
||||
runtime = None
|
||||
try:
|
||||
runtime = await self._get_runtime(stored_sandbox.id)
|
||||
except Exception:
|
||||
_logger.exception(
|
||||
f'Error getting runtime: {stored_sandbox.id}', stack_info=True
|
||||
)
|
||||
|
||||
return self._to_sandbox_info(stored_sandbox, runtime)
|
||||
|
||||
async def get_sandbox_by_session_api_key(
|
||||
self, session_api_key: str
|
||||
@ -323,7 +360,7 @@ class RemoteSandboxService(SandboxService):
|
||||
sandbox = result.first()
|
||||
if sandbox is None:
|
||||
raise ValueError('sandbox_not_found')
|
||||
return await self._to_sandbox_info(sandbox, runtime)
|
||||
return self._to_sandbox_info(sandbox, runtime)
|
||||
except Exception:
|
||||
_logger.exception(
|
||||
'Error getting sandbox from session_api_key', stack_info=True
|
||||
@ -339,7 +376,7 @@ class RemoteSandboxService(SandboxService):
|
||||
try:
|
||||
runtime = await self._get_runtime(stored_sandbox.id)
|
||||
if runtime and runtime.get('session_api_key') == session_api_key:
|
||||
return await self._to_sandbox_info(stored_sandbox, runtime)
|
||||
return self._to_sandbox_info(stored_sandbox, runtime)
|
||||
except Exception:
|
||||
# Continue checking other sandboxes if one fails
|
||||
continue
|
||||
@ -412,7 +449,7 @@ class RemoteSandboxService(SandboxService):
|
||||
# Hack - result doesn't contain this
|
||||
runtime_data['pod_status'] = 'pending'
|
||||
|
||||
return await self._to_sandbox_info(stored_sandbox, runtime_data)
|
||||
return self._to_sandbox_info(stored_sandbox, runtime_data)
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
_logger.error(f'Failed to start sandbox: {e}')
|
||||
@ -480,6 +517,55 @@ class RemoteSandboxService(SandboxService):
|
||||
_logger.error(f'Error deleting sandbox {sandbox_id}: {e}')
|
||||
return False
|
||||
|
||||
async def pause_old_sandboxes(self, max_num_sandboxes: int) -> list[str]:
|
||||
"""Pause the oldest sandboxes if there are more than max_num_sandboxes running.
|
||||
In a multi user environment, this will pause sandboxes only for the current user.
|
||||
|
||||
Args:
|
||||
max_num_sandboxes: Maximum number of sandboxes to keep running
|
||||
|
||||
Returns:
|
||||
List of sandbox IDs that were paused
|
||||
"""
|
||||
if max_num_sandboxes <= 0:
|
||||
raise ValueError('max_num_sandboxes must be greater than 0')
|
||||
|
||||
response = await self._send_runtime_api_request(
|
||||
'GET',
|
||||
'/list',
|
||||
)
|
||||
content = response.json()
|
||||
running_session_ids = [
|
||||
runtime.get('session_id') for runtime in content['runtimes']
|
||||
]
|
||||
|
||||
query = await self._secure_select()
|
||||
query = query.filter(StoredRemoteSandbox.id.in_(running_session_ids)).order_by(
|
||||
StoredRemoteSandbox.created_at.desc()
|
||||
)
|
||||
running_sandboxes = list(await self.db_session.execute(query))
|
||||
|
||||
# If we're within the limit, no cleanup needed
|
||||
if len(running_sandboxes) <= max_num_sandboxes:
|
||||
return []
|
||||
|
||||
# Determine how many to pause
|
||||
num_to_pause = len(running_sandboxes) - max_num_sandboxes
|
||||
sandboxes_to_pause = running_sandboxes[:num_to_pause]
|
||||
|
||||
# Stop the oldest sandboxes
|
||||
paused_sandbox_ids = []
|
||||
for sandbox in sandboxes_to_pause:
|
||||
try:
|
||||
success = await self.pause_sandbox(sandbox.id)
|
||||
if success:
|
||||
paused_sandbox_ids.append(sandbox.id)
|
||||
except Exception:
|
||||
# Continue trying to pause other sandboxes even if one fails
|
||||
pass
|
||||
|
||||
return paused_sandbox_ids
|
||||
|
||||
|
||||
def _build_service_url(url: str, service_name: str):
|
||||
scheme, host_and_path = url.split('://')
|
||||
|
||||
@ -72,7 +72,7 @@ class SandboxService(ABC):
|
||||
"""
|
||||
|
||||
async def pause_old_sandboxes(self, max_num_sandboxes: int) -> list[str]:
|
||||
"""Stop the oldest sandboxes if there are more than max_num_sandboxes running.
|
||||
"""Pause the oldest sandboxes if there are more than max_num_sandboxes running.
|
||||
In a multi user environment, this will pause sandboxes only for the current user.
|
||||
|
||||
Args:
|
||||
|
||||
@ -331,7 +331,7 @@ class TestSandboxInfoConversion:
|
||||
runtime_data = create_runtime_data(status='running', pod_status='ready')
|
||||
|
||||
# Execute
|
||||
sandbox_info = await remote_sandbox_service._to_sandbox_info(
|
||||
sandbox_info = remote_sandbox_service._to_sandbox_info(
|
||||
stored_sandbox, runtime_data
|
||||
)
|
||||
|
||||
@ -358,7 +358,7 @@ class TestSandboxInfoConversion:
|
||||
runtime_data = create_runtime_data(status='running', pod_status='pending')
|
||||
|
||||
# Execute
|
||||
sandbox_info = await remote_sandbox_service._to_sandbox_info(
|
||||
sandbox_info = remote_sandbox_service._to_sandbox_info(
|
||||
stored_sandbox, runtime_data
|
||||
)
|
||||
|
||||
@ -367,23 +367,6 @@ class TestSandboxInfoConversion:
|
||||
assert sandbox_info.session_api_key == 'test-session-key'
|
||||
assert sandbox_info.exposed_urls is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_to_sandbox_info_without_runtime(self, remote_sandbox_service):
|
||||
"""Test conversion to SandboxInfo without runtime data."""
|
||||
# Setup
|
||||
stored_sandbox = create_stored_sandbox()
|
||||
remote_sandbox_service._get_runtime = AsyncMock(
|
||||
side_effect=Exception('Runtime not found')
|
||||
)
|
||||
|
||||
# Execute
|
||||
sandbox_info = await remote_sandbox_service._to_sandbox_info(stored_sandbox)
|
||||
|
||||
# Verify
|
||||
assert sandbox_info.status == SandboxStatus.MISSING
|
||||
assert sandbox_info.session_api_key is None
|
||||
assert sandbox_info.exposed_urls is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_to_sandbox_info_loads_runtime_when_none_provided(
|
||||
self, remote_sandbox_service
|
||||
@ -391,15 +374,12 @@ class TestSandboxInfoConversion:
|
||||
"""Test that runtime data is loaded when not provided."""
|
||||
# Setup
|
||||
stored_sandbox = create_stored_sandbox()
|
||||
runtime_data = create_runtime_data()
|
||||
remote_sandbox_service._get_runtime = AsyncMock(return_value=runtime_data)
|
||||
|
||||
# Execute
|
||||
sandbox_info = await remote_sandbox_service._to_sandbox_info(stored_sandbox)
|
||||
sandbox_info = remote_sandbox_service._to_sandbox_info(stored_sandbox, None)
|
||||
|
||||
# Verify
|
||||
remote_sandbox_service._get_runtime.assert_called_once_with('test-sandbox-123')
|
||||
assert sandbox_info.status == SandboxStatus.RUNNING
|
||||
assert sandbox_info.status == SandboxStatus.MISSING
|
||||
|
||||
|
||||
class TestSandboxLifecycle:
|
||||
@ -677,15 +657,18 @@ class TestSandboxSearch:
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value = mock_scalars
|
||||
remote_sandbox_service.db_session.execute = AsyncMock(return_value=mock_result)
|
||||
remote_sandbox_service._to_sandbox_info = AsyncMock(
|
||||
side_effect=lambda stored: SandboxInfo(
|
||||
id=stored.id,
|
||||
created_by_user_id=stored.created_by_user_id,
|
||||
sandbox_spec_id=stored.sandbox_spec_id,
|
||||
status=SandboxStatus.RUNNING,
|
||||
session_api_key='test-key',
|
||||
created_at=stored.created_at,
|
||||
)
|
||||
|
||||
# Mock the batch endpoint response
|
||||
mock_batch_response = MagicMock()
|
||||
mock_batch_response.raise_for_status.return_value = None
|
||||
mock_batch_response.json.return_value = {
|
||||
'runtimes': [
|
||||
create_runtime_data('sb1'),
|
||||
create_runtime_data('sb2'),
|
||||
]
|
||||
}
|
||||
remote_sandbox_service.httpx_client.request = AsyncMock(
|
||||
return_value=mock_batch_response
|
||||
)
|
||||
|
||||
# Execute
|
||||
@ -697,6 +680,14 @@ class TestSandboxSearch:
|
||||
assert result.items[0].id == 'sb1'
|
||||
assert result.items[1].id == 'sb2'
|
||||
|
||||
# Verify that the batch endpoint was called
|
||||
remote_sandbox_service.httpx_client.request.assert_called_once_with(
|
||||
'GET',
|
||||
'https://api.example.com/sessions/batch',
|
||||
headers={'X-API-Key': 'test-api-key'},
|
||||
params=[('ids', 'sb1'), ('ids', 'sb2')],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_sandboxes_with_pagination(self, remote_sandbox_service):
|
||||
"""Test sandbox search with pagination."""
|
||||
@ -710,15 +701,15 @@ class TestSandboxSearch:
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value = mock_scalars
|
||||
remote_sandbox_service.db_session.execute = AsyncMock(return_value=mock_result)
|
||||
remote_sandbox_service._to_sandbox_info = AsyncMock(
|
||||
side_effect=lambda stored: SandboxInfo(
|
||||
id=stored.id,
|
||||
created_by_user_id=stored.created_by_user_id,
|
||||
sandbox_spec_id=stored.sandbox_spec_id,
|
||||
status=SandboxStatus.RUNNING,
|
||||
session_api_key='test-key',
|
||||
created_at=stored.created_at,
|
||||
)
|
||||
|
||||
# Mock the batch endpoint response
|
||||
mock_batch_response = MagicMock()
|
||||
mock_batch_response.raise_for_status.return_value = None
|
||||
mock_batch_response.json.return_value = {
|
||||
'runtimes': [create_runtime_data(f'sb{i}') for i in range(6)]
|
||||
}
|
||||
remote_sandbox_service.httpx_client.request = AsyncMock(
|
||||
return_value=mock_batch_response
|
||||
)
|
||||
|
||||
# Execute
|
||||
@ -739,15 +730,15 @@ class TestSandboxSearch:
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value = mock_scalars
|
||||
remote_sandbox_service.db_session.execute = AsyncMock(return_value=mock_result)
|
||||
remote_sandbox_service._to_sandbox_info = AsyncMock(
|
||||
side_effect=lambda stored: SandboxInfo(
|
||||
id=stored.id,
|
||||
created_by_user_id=stored.created_by_user_id,
|
||||
sandbox_spec_id=stored.sandbox_spec_id,
|
||||
status=SandboxStatus.RUNNING,
|
||||
session_api_key='test-key',
|
||||
created_at=stored.created_at,
|
||||
)
|
||||
|
||||
# Mock the batch endpoint response
|
||||
mock_batch_response = MagicMock()
|
||||
mock_batch_response.raise_for_status.return_value = None
|
||||
mock_batch_response.json.return_value = {
|
||||
'runtimes': [create_runtime_data('sb1')]
|
||||
}
|
||||
remote_sandbox_service.httpx_client.request = AsyncMock(
|
||||
return_value=mock_batch_response
|
||||
)
|
||||
|
||||
# Execute
|
||||
@ -757,6 +748,80 @@ class TestSandboxSearch:
|
||||
# Note: We can't easily verify the exact SQL query, but we can verify the method was called
|
||||
remote_sandbox_service.db_session.execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_runtimes_batch_success(self, remote_sandbox_service):
|
||||
"""Test successful batch runtime retrieval."""
|
||||
# Setup
|
||||
sandbox_ids = ['sb1', 'sb2', 'sb3']
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_response.json.return_value = {
|
||||
'runtimes': [
|
||||
create_runtime_data('sb1'),
|
||||
create_runtime_data('sb2'),
|
||||
create_runtime_data('sb3'),
|
||||
]
|
||||
}
|
||||
remote_sandbox_service.httpx_client.request = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
# Execute
|
||||
result = await remote_sandbox_service._get_runtimes_batch(sandbox_ids)
|
||||
|
||||
# Verify
|
||||
assert len(result) == 3
|
||||
assert 'sb1' in result
|
||||
assert 'sb2' in result
|
||||
assert 'sb3' in result
|
||||
assert result['sb1']['session_id'] == 'sb1'
|
||||
|
||||
# Verify the correct API call was made
|
||||
remote_sandbox_service.httpx_client.request.assert_called_once_with(
|
||||
'GET',
|
||||
'https://api.example.com/sessions/batch',
|
||||
headers={'X-API-Key': 'test-api-key'},
|
||||
params=[('ids', 'sb1'), ('ids', 'sb2'), ('ids', 'sb3')],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_runtimes_batch_empty_list(self, remote_sandbox_service):
|
||||
"""Test batch runtime retrieval with empty sandbox list."""
|
||||
# Execute
|
||||
result = await remote_sandbox_service._get_runtimes_batch([])
|
||||
|
||||
# Verify
|
||||
assert result == {}
|
||||
# Verify no API call was made
|
||||
remote_sandbox_service.httpx_client.request.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_runtimes_batch_partial_results(self, remote_sandbox_service):
|
||||
"""Test batch runtime retrieval with partial results (some sandboxes not found)."""
|
||||
# Setup
|
||||
sandbox_ids = ['sb1', 'sb2', 'sb3']
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_response.json.return_value = {
|
||||
'runtimes': [
|
||||
create_runtime_data('sb1'),
|
||||
create_runtime_data('sb3'),
|
||||
# sb2 is missing from the response
|
||||
]
|
||||
}
|
||||
remote_sandbox_service.httpx_client.request = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
# Execute
|
||||
result = await remote_sandbox_service._get_runtimes_batch(sandbox_ids)
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
assert 'sb1' in result
|
||||
assert 'sb2' not in result # Missing from response
|
||||
assert 'sb3' in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_sandbox_exists(self, remote_sandbox_service):
|
||||
"""Test getting an existing sandbox."""
|
||||
@ -765,7 +830,7 @@ class TestSandboxSearch:
|
||||
remote_sandbox_service._get_stored_sandbox = AsyncMock(
|
||||
return_value=stored_sandbox
|
||||
)
|
||||
remote_sandbox_service._to_sandbox_info = AsyncMock(
|
||||
remote_sandbox_service._to_sandbox_info = MagicMock(
|
||||
return_value=SandboxInfo(
|
||||
id='test-sandbox-123',
|
||||
created_by_user_id='test-user-123',
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user