diff --git a/openhands/app_server/app_conversation/live_status_app_conversation_service.py b/openhands/app_server/app_conversation/live_status_app_conversation_service.py index a8d490489c..db30710f76 100644 --- a/openhands/app_server/app_conversation/live_status_app_conversation_service.py +++ b/openhands/app_server/app_conversation/live_status_app_conversation_service.py @@ -477,7 +477,11 @@ class LiveStatusAppConversationService(AppConversationServiceBase): if sandbox.status in (None, SandboxStatus.ERROR): raise SandboxError(f'Sandbox status: {sandbox.status}') if sandbox.status == SandboxStatus.RUNNING: - return + # There are still bugs in the remote runtime - they report running while still just + # starting resulting in a race condition. Manually check that it is actually + # running. + if await self._check_agent_server_alive(sandbox): + return if sandbox.status != SandboxStatus.STARTING: raise SandboxError(f'Sandbox not startable: {sandbox.id}') @@ -490,9 +494,19 @@ class LiveStatusAppConversationService(AppConversationServiceBase): if sandbox.status not in (SandboxStatus.STARTING, SandboxStatus.RUNNING): raise SandboxError(f'Sandbox not startable: {sandbox.id}') if sandbox_info.status == SandboxStatus.RUNNING: - return + # There are still bugs in the remote runtime - they report running while still just + # starting resulting in a race condition. Manually check that it is actually + # running. 
async def _check_agent_server_alive(self, sandbox_info: SandboxInfo) -> bool:
    """Return True if the sandbox's agent server answers its ``/alive`` endpoint.

    The remote runtime can report a sandbox as RUNNING while the agent server
    inside it is still starting. Callers use this probe to confirm the server
    is actually reachable before treating the sandbox as ready.

    Args:
        sandbox_info: Sandbox whose agent server URL is resolved and probed.

    Returns:
        True when the GET request to ``<agent_server_url>/alive`` completes
        with a successful status. False when the request cannot be completed
        or the response status is not successful.
    """
    import logging

    agent_server_url = self._get_agent_server_url(sandbox_info)
    url = f'{agent_server_url.rstrip("/")}/alive'
    try:
        response = await self.httpx_client.get(url)
    except Exception as exc:
        # Best-effort probe: a request that cannot be completed (for example
        # connection refused or timed out while the server is still booting)
        # means "not alive yet", not a fatal error. Propagating it would abort
        # the caller's wait loop on the very race this check is meant to absorb.
        logging.getLogger(__name__).debug(
            'Agent server liveness probe failed for %s: %s', url, exc
        )
        return False
    return response.is_success
+ + Args: + sandbox_ids: List of sandbox IDs to fetch + + Returns: + Dictionary mapping sandbox_id to runtime data + """ + if not sandbox_ids: + return {} + + # Build query parameters for the batch endpoint + params = [('ids', sandbox_id) for sandbox_id in sandbox_ids] + + response = await self._send_runtime_api_request( + 'GET', + '/sessions/batch', + params=params, + ) + response.raise_for_status() + batch_data = response.json() + + # The batch endpoint should return a list of runtimes + # Convert to a dictionary keyed by session_id for easy lookup + runtimes_by_id = {} + if batch_data and 'runtimes' in batch_data: + for runtime in batch_data['runtimes']: + if 'session_id' in runtime: + runtimes_by_id[runtime['session_id']] = runtime + + return runtimes_by_id + async def _init_environment( self, sandbox_spec: SandboxSpecInfo, sandbox_id: str ) -> dict[str, str]: @@ -283,13 +309,15 @@ class RemoteSandboxService(SandboxService): if has_more: next_page_id = str(offset + limit) - # Convert stored callbacks to domain models - items = await asyncio.gather( - *[ - self._to_sandbox_info(stored_sandbox) - for stored_sandbox in stored_sandboxes - ] - ) + # Batch fetch runtime data for all sandboxes + sandbox_ids = [stored_sandbox.id for stored_sandbox in stored_sandboxes] + runtimes_by_id = await self._get_runtimes_batch(sandbox_ids) + + # Convert stored sandboxes to domain models with runtime data + items = [ + self._to_sandbox_info(stored_sandbox, runtimes_by_id.get(stored_sandbox.id)) + for stored_sandbox in stored_sandboxes + ] return SandboxPage(items=items, next_page_id=next_page_id) @@ -298,7 +326,16 @@ class RemoteSandboxService(SandboxService): stored_sandbox = await self._get_stored_sandbox(sandbox_id) if stored_sandbox is None: return None - return await self._to_sandbox_info(stored_sandbox) + + runtime = None + try: + runtime = await self._get_runtime(stored_sandbox.id) + except Exception: + _logger.exception( + f'Error getting runtime: {stored_sandbox.id}', 
async def pause_old_sandboxes(self, max_num_sandboxes: int) -> list[str]:
    """Pause the oldest sandboxes if there are more than max_num_sandboxes running.
    In a multi user environment, this will pause sandboxes only for the current user.

    The runtime API ``/list`` endpoint supplies the session ids, which are then
    matched against stored sandboxes selected through ``_secure_select()``.

    Args:
        max_num_sandboxes: Maximum number of sandboxes to keep running

    Returns:
        List of sandbox IDs that were paused

    Raises:
        ValueError: If ``max_num_sandboxes`` is not greater than 0.
    """
    if max_num_sandboxes <= 0:
        raise ValueError('max_num_sandboxes must be greater than 0')

    response = await self._send_runtime_api_request(
        'GET',
        '/list',
    )
    # Surface HTTP failures explicitly instead of trying to parse an error body.
    response.raise_for_status()
    content = response.json()
    runtimes = content.get('runtimes', []) if content else []
    # NOTE(review): every listed runtime is treated as running, as before -
    # confirm whether /list can also return paused runtimes.
    running_session_ids = [
        runtime['session_id'] for runtime in runtimes if runtime.get('session_id')
    ]
    if not running_session_ids:
        return []

    query = await self._secure_select()
    # Oldest first, so the slice below selects the oldest sandboxes to pause.
    query = query.filter(StoredRemoteSandbox.id.in_(running_session_ids)).order_by(
        StoredRemoteSandbox.created_at.asc()
    )
    result = await self.db_session.execute(query)
    # scalars() unwraps each Row into the stored sandbox entity so `.id` is
    # available (assumes _secure_select() selects the entity - TODO confirm).
    running_sandboxes = result.scalars().all()

    # If we're within the limit, no cleanup needed
    if len(running_sandboxes) <= max_num_sandboxes:
        return []

    # Determine how many to pause
    num_to_pause = len(running_sandboxes) - max_num_sandboxes
    sandboxes_to_pause = running_sandboxes[:num_to_pause]

    # Pause the oldest sandboxes
    paused_sandbox_ids = []
    for sandbox in sandboxes_to_pause:
        try:
            if await self.pause_sandbox(sandbox.id):
                paused_sandbox_ids.append(sandbox.id)
        except Exception:
            # Keep pausing the remaining sandboxes, but leave a trace of the
            # failure rather than swallowing it silently.
            _logger.exception(f'Error pausing sandbox: {sandbox.id}')

    return paused_sandbox_ids
Args: diff --git a/tests/unit/app_server/test_remote_sandbox_service.py b/tests/unit/app_server/test_remote_sandbox_service.py index 5802e46ecb..bb950c732c 100644 --- a/tests/unit/app_server/test_remote_sandbox_service.py +++ b/tests/unit/app_server/test_remote_sandbox_service.py @@ -331,7 +331,7 @@ class TestSandboxInfoConversion: runtime_data = create_runtime_data(status='running', pod_status='ready') # Execute - sandbox_info = await remote_sandbox_service._to_sandbox_info( + sandbox_info = remote_sandbox_service._to_sandbox_info( stored_sandbox, runtime_data ) @@ -358,7 +358,7 @@ class TestSandboxInfoConversion: runtime_data = create_runtime_data(status='running', pod_status='pending') # Execute - sandbox_info = await remote_sandbox_service._to_sandbox_info( + sandbox_info = remote_sandbox_service._to_sandbox_info( stored_sandbox, runtime_data ) @@ -367,23 +367,6 @@ class TestSandboxInfoConversion: assert sandbox_info.session_api_key == 'test-session-key' assert sandbox_info.exposed_urls is None - @pytest.mark.asyncio - async def test_to_sandbox_info_without_runtime(self, remote_sandbox_service): - """Test conversion to SandboxInfo without runtime data.""" - # Setup - stored_sandbox = create_stored_sandbox() - remote_sandbox_service._get_runtime = AsyncMock( - side_effect=Exception('Runtime not found') - ) - - # Execute - sandbox_info = await remote_sandbox_service._to_sandbox_info(stored_sandbox) - - # Verify - assert sandbox_info.status == SandboxStatus.MISSING - assert sandbox_info.session_api_key is None - assert sandbox_info.exposed_urls is None - @pytest.mark.asyncio async def test_to_sandbox_info_loads_runtime_when_none_provided( self, remote_sandbox_service @@ -391,15 +374,12 @@ class TestSandboxInfoConversion: """Test that runtime data is loaded when not provided.""" # Setup stored_sandbox = create_stored_sandbox() - runtime_data = create_runtime_data() - remote_sandbox_service._get_runtime = AsyncMock(return_value=runtime_data) # Execute - sandbox_info 
= await remote_sandbox_service._to_sandbox_info(stored_sandbox) + sandbox_info = remote_sandbox_service._to_sandbox_info(stored_sandbox, None) # Verify - remote_sandbox_service._get_runtime.assert_called_once_with('test-sandbox-123') - assert sandbox_info.status == SandboxStatus.RUNNING + assert sandbox_info.status == SandboxStatus.MISSING class TestSandboxLifecycle: @@ -677,15 +657,18 @@ class TestSandboxSearch: mock_result = MagicMock() mock_result.scalars.return_value = mock_scalars remote_sandbox_service.db_session.execute = AsyncMock(return_value=mock_result) - remote_sandbox_service._to_sandbox_info = AsyncMock( - side_effect=lambda stored: SandboxInfo( - id=stored.id, - created_by_user_id=stored.created_by_user_id, - sandbox_spec_id=stored.sandbox_spec_id, - status=SandboxStatus.RUNNING, - session_api_key='test-key', - created_at=stored.created_at, - ) + + # Mock the batch endpoint response + mock_batch_response = MagicMock() + mock_batch_response.raise_for_status.return_value = None + mock_batch_response.json.return_value = { + 'runtimes': [ + create_runtime_data('sb1'), + create_runtime_data('sb2'), + ] + } + remote_sandbox_service.httpx_client.request = AsyncMock( + return_value=mock_batch_response ) # Execute @@ -697,6 +680,14 @@ class TestSandboxSearch: assert result.items[0].id == 'sb1' assert result.items[1].id == 'sb2' + # Verify that the batch endpoint was called + remote_sandbox_service.httpx_client.request.assert_called_once_with( + 'GET', + 'https://api.example.com/sessions/batch', + headers={'X-API-Key': 'test-api-key'}, + params=[('ids', 'sb1'), ('ids', 'sb2')], + ) + @pytest.mark.asyncio async def test_search_sandboxes_with_pagination(self, remote_sandbox_service): """Test sandbox search with pagination.""" @@ -710,15 +701,15 @@ class TestSandboxSearch: mock_result = MagicMock() mock_result.scalars.return_value = mock_scalars remote_sandbox_service.db_session.execute = AsyncMock(return_value=mock_result) - 
remote_sandbox_service._to_sandbox_info = AsyncMock( - side_effect=lambda stored: SandboxInfo( - id=stored.id, - created_by_user_id=stored.created_by_user_id, - sandbox_spec_id=stored.sandbox_spec_id, - status=SandboxStatus.RUNNING, - session_api_key='test-key', - created_at=stored.created_at, - ) + + # Mock the batch endpoint response + mock_batch_response = MagicMock() + mock_batch_response.raise_for_status.return_value = None + mock_batch_response.json.return_value = { + 'runtimes': [create_runtime_data(f'sb{i}') for i in range(6)] + } + remote_sandbox_service.httpx_client.request = AsyncMock( + return_value=mock_batch_response ) # Execute @@ -739,15 +730,15 @@ class TestSandboxSearch: mock_result = MagicMock() mock_result.scalars.return_value = mock_scalars remote_sandbox_service.db_session.execute = AsyncMock(return_value=mock_result) - remote_sandbox_service._to_sandbox_info = AsyncMock( - side_effect=lambda stored: SandboxInfo( - id=stored.id, - created_by_user_id=stored.created_by_user_id, - sandbox_spec_id=stored.sandbox_spec_id, - status=SandboxStatus.RUNNING, - session_api_key='test-key', - created_at=stored.created_at, - ) + + # Mock the batch endpoint response + mock_batch_response = MagicMock() + mock_batch_response.raise_for_status.return_value = None + mock_batch_response.json.return_value = { + 'runtimes': [create_runtime_data('sb1')] + } + remote_sandbox_service.httpx_client.request = AsyncMock( + return_value=mock_batch_response ) # Execute @@ -757,6 +748,80 @@ class TestSandboxSearch: # Note: We can't easily verify the exact SQL query, but we can verify the method was called remote_sandbox_service.db_session.execute.assert_called_once() + @pytest.mark.asyncio + async def test_get_runtimes_batch_success(self, remote_sandbox_service): + """Test successful batch runtime retrieval.""" + # Setup + sandbox_ids = ['sb1', 'sb2', 'sb3'] + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = 
{ + 'runtimes': [ + create_runtime_data('sb1'), + create_runtime_data('sb2'), + create_runtime_data('sb3'), + ] + } + remote_sandbox_service.httpx_client.request = AsyncMock( + return_value=mock_response + ) + + # Execute + result = await remote_sandbox_service._get_runtimes_batch(sandbox_ids) + + # Verify + assert len(result) == 3 + assert 'sb1' in result + assert 'sb2' in result + assert 'sb3' in result + assert result['sb1']['session_id'] == 'sb1' + + # Verify the correct API call was made + remote_sandbox_service.httpx_client.request.assert_called_once_with( + 'GET', + 'https://api.example.com/sessions/batch', + headers={'X-API-Key': 'test-api-key'}, + params=[('ids', 'sb1'), ('ids', 'sb2'), ('ids', 'sb3')], + ) + + @pytest.mark.asyncio + async def test_get_runtimes_batch_empty_list(self, remote_sandbox_service): + """Test batch runtime retrieval with empty sandbox list.""" + # Execute + result = await remote_sandbox_service._get_runtimes_batch([]) + + # Verify + assert result == {} + # Verify no API call was made + remote_sandbox_service.httpx_client.request.assert_not_called() + + @pytest.mark.asyncio + async def test_get_runtimes_batch_partial_results(self, remote_sandbox_service): + """Test batch runtime retrieval with partial results (some sandboxes not found).""" + # Setup + sandbox_ids = ['sb1', 'sb2', 'sb3'] + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = { + 'runtimes': [ + create_runtime_data('sb1'), + create_runtime_data('sb3'), + # sb2 is missing from the response + ] + } + remote_sandbox_service.httpx_client.request = AsyncMock( + return_value=mock_response + ) + + # Execute + result = await remote_sandbox_service._get_runtimes_batch(sandbox_ids) + + # Verify + assert len(result) == 2 + assert 'sb1' in result + assert 'sb2' not in result # Missing from response + assert 'sb3' in result + @pytest.mark.asyncio async def test_get_sandbox_exists(self, remote_sandbox_service): 
"""Test getting an existing sandbox.""" @@ -765,7 +830,7 @@ class TestSandboxSearch: remote_sandbox_service._get_stored_sandbox = AsyncMock( return_value=stored_sandbox ) - remote_sandbox_service._to_sandbox_info = AsyncMock( + remote_sandbox_service._to_sandbox_info = MagicMock( return_value=SandboxInfo( id='test-sandbox-123', created_by_user_id='test-user-123',