diff --git a/openhands/app_server/sandbox/remote_sandbox_service.py b/openhands/app_server/sandbox/remote_sandbox_service.py index d1e7f86d75..076c478478 100644 --- a/openhands/app_server/sandbox/remote_sandbox_service.py +++ b/openhands/app_server/sandbox/remote_sandbox_service.py @@ -252,10 +252,9 @@ class RemoteSandboxService(SandboxService): # The batch endpoint should return a list of runtimes # Convert to a dictionary keyed by session_id for easy lookup runtimes_by_id = {} - if batch_data and 'runtimes' in batch_data: - for runtime in batch_data['runtimes']: - if 'session_id' in runtime: - runtimes_by_id[runtime['session_id']] = runtime + for runtime in batch_data: + if runtime and 'session_id' in runtime: + runtimes_by_id[runtime['session_id']] = runtime return runtimes_by_id @@ -566,6 +565,32 @@ class RemoteSandboxService(SandboxService): return paused_sandbox_ids + async def batch_get_sandboxes( + self, sandbox_ids: list[str] + ) -> list[SandboxInfo | None]: + """Get a batch of sandboxes, returning None for any which were not found.""" + if not sandbox_ids: + return [] + query = await self._secure_select() + query = query.filter(StoredRemoteSandbox.id.in_(sandbox_ids)) + stored_remote_sandboxes = await self.db_session.execute(query) + stored_remote_sandboxes_by_id = { + stored_remote_sandbox[0].id: stored_remote_sandbox[0] + for stored_remote_sandbox in stored_remote_sandboxes + } + runtimes_by_id = await self._get_runtimes_batch( + list(stored_remote_sandboxes_by_id) + ) + results = [] + for sandbox_id in sandbox_ids: + stored_remote_sandbox = stored_remote_sandboxes_by_id.get(sandbox_id) + result = None + if stored_remote_sandbox: + runtime = runtimes_by_id.get(sandbox_id) + result = self._to_sandbox_info(stored_remote_sandbox, runtime) + results.append(result) + return results + def _build_service_url(url: str, service_name: str): scheme, host_and_path = url.split('://') diff --git a/tests/unit/app_server/test_remote_sandbox_service.py b/tests/unit/app_server/test_remote_sandbox_service.py index bb950c732c..c70ad7d324 100644 --- a/tests/unit/app_server/test_remote_sandbox_service.py +++ b/tests/unit/app_server/test_remote_sandbox_service.py @@ -755,13 +755,11 @@ class TestSandboxSearch: sandbox_ids = ['sb1', 'sb2', 'sb3'] mock_response = MagicMock() mock_response.raise_for_status.return_value = None - mock_response.json.return_value = { - 'runtimes': [ - create_runtime_data('sb1'), - create_runtime_data('sb2'), - create_runtime_data('sb3'), - ] - } + mock_response.json.return_value = [ + create_runtime_data('sb1'), + create_runtime_data('sb2'), + create_runtime_data('sb3'), + ] remote_sandbox_service.httpx_client.request = AsyncMock( return_value=mock_response ) @@ -802,13 +800,11 @@ class TestSandboxSearch: sandbox_ids = ['sb1', 'sb2', 'sb3'] mock_response = MagicMock() mock_response.raise_for_status.return_value = None - mock_response.json.return_value = { - 'runtimes': [ - create_runtime_data('sb1'), - create_runtime_data('sb3'), - # sb2 is missing from the response - ] - } + mock_response.json.return_value = [ + create_runtime_data('sb1'), + create_runtime_data('sb3'), + # sb2 is missing from the response + ] remote_sandbox_service.httpx_client.request = AsyncMock( return_value=mock_response )