diff --git a/openhands/app_server/app_conversation/live_status_app_conversation_service.py b/openhands/app_server/app_conversation/live_status_app_conversation_service.py index f970a60e7e..1cfd122c07 100644 --- a/openhands/app_server/app_conversation/live_status_app_conversation_service.py +++ b/openhands/app_server/app_conversation/live_status_app_conversation_service.py @@ -478,7 +478,15 @@ class LiveStatusAppConversationService(AppConversationServiceBase): """Wait for sandbox to start and return info.""" # Get or create the sandbox if not task.request.sandbox_id: - sandbox = await self.sandbox_service.start_sandbox() + # Convert conversation_id to hex string if present + sandbox_id_str = ( + task.request.conversation_id.hex + if task.request.conversation_id is not None + else None + ) + sandbox = await self.sandbox_service.start_sandbox( + sandbox_id=sandbox_id_str + ) task.sandbox_id = sandbox.id else: sandbox_info = await self.sandbox_service.get_sandbox( diff --git a/openhands/app_server/sandbox/docker_sandbox_service.py b/openhands/app_server/sandbox/docker_sandbox_service.py index 82857ab098..a85d3c843d 100644 --- a/openhands/app_server/sandbox/docker_sandbox_service.py +++ b/openhands/app_server/sandbox/docker_sandbox_service.py @@ -294,7 +294,9 @@ class DockerSandboxService(SandboxService): except (NotFound, APIError): return None - async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo: + async def start_sandbox( + self, sandbox_spec_id: str | None = None, sandbox_id: str | None = None + ) -> SandboxInfo: """Start a new sandbox.""" # Enforce sandbox limits by cleaning up old sandboxes await self.pause_old_sandboxes(self.max_num_sandboxes - 1) @@ -309,10 +311,12 @@ class DockerSandboxService(SandboxService): raise ValueError('Sandbox Spec not found') sandbox_spec = sandbox_spec_maybe - # Generate container ID and session api key - container_name = ( - f'{self.container_name_prefix}{base62.encodebytes(os.urandom(16))}' - ) + # Generate a sandbox id if none was provided + if sandbox_id is None: + sandbox_id = base62.encodebytes(os.urandom(16)) + + # Generate container name and session api key + container_name = f'{self.container_name_prefix}{sandbox_id}' session_api_key = base62.encodebytes(os.urandom(32)) # Prepare environment variables diff --git a/openhands/app_server/sandbox/process_sandbox_service.py b/openhands/app_server/sandbox/process_sandbox_service.py index 200bf62c44..328c400d10 100644 --- a/openhands/app_server/sandbox/process_sandbox_service.py +++ b/openhands/app_server/sandbox/process_sandbox_service.py @@ -286,7 +286,9 @@ class ProcessSandboxService(SandboxService): return None - async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo: + async def start_sandbox( + self, sandbox_spec_id: str | None = None, sandbox_id: str | None = None + ) -> SandboxInfo: """Start a new sandbox.""" # Get sandbox spec if sandbox_spec_id is None: @@ -300,7 +302,9 @@ class ProcessSandboxService(SandboxService): sandbox_spec = sandbox_spec_maybe # Generate unique sandbox ID and session API key - sandbox_id = base62.encodebytes(os.urandom(16)) + # Use provided sandbox_id if available, otherwise generate a random one + if sandbox_id is None: + sandbox_id = base62.encodebytes(os.urandom(16)) session_api_key = base62.encodebytes(os.urandom(32)) # Find available port diff --git a/openhands/app_server/sandbox/remote_sandbox_service.py b/openhands/app_server/sandbox/remote_sandbox_service.py index f8850c3a56..9e5b7740fb 100644 --- a/openhands/app_server/sandbox/remote_sandbox_service.py +++ b/openhands/app_server/sandbox/remote_sandbox_service.py @@ -383,7 +383,9 @@ class RemoteSandboxService(SandboxService): return None - async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo: + async def start_sandbox( + self, sandbox_spec_id: str | None = None, sandbox_id: str | None = None + ) -> SandboxInfo: """Start a new sandbox by creating a remote runtime.""" try: # Enforce sandbox limits by cleaning up old sandboxes @@ -402,8 +404,9 @@ class RemoteSandboxService(SandboxService): raise ValueError('Sandbox Spec not found') sandbox_spec = sandbox_spec_maybe - # Create a unique id - sandbox_id = base62.encodebytes(os.urandom(16)) + # Create a unique id, use provided sandbox_id if available + if sandbox_id is None: + sandbox_id = base62.encodebytes(os.urandom(16)) # get user id user_id = await self.user_context.get_user_id() diff --git a/openhands/app_server/sandbox/sandbox_service.py b/openhands/app_server/sandbox/sandbox_service.py index 7319e4f5c7..efe8a9120c 100644 --- a/openhands/app_server/sandbox/sandbox_service.py +++ b/openhands/app_server/sandbox/sandbox_service.py @@ -50,10 +50,14 @@ class SandboxService(ABC): return results @abstractmethod - async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo: + async def start_sandbox( + self, sandbox_spec_id: str | None = None, sandbox_id: str | None = None + ) -> SandboxInfo: """Begin the process of starting a sandbox. Return the info on the new sandbox. If no spec is selected, use the default. + If sandbox_id is provided, it will be used as the sandbox identifier instead + of generating a random one. """ @abstractmethod diff --git a/tests/unit/app_server/test_docker_sandbox_service.py b/tests/unit/app_server/test_docker_sandbox_service.py index 714d030685..5a1e582a19 100644 --- a/tests/unit/app_server/test_docker_sandbox_service.py +++ b/tests/unit/app_server/test_docker_sandbox_service.py @@ -430,6 +430,45 @@ class TestDockerSandboxService: ): await service.start_sandbox(sandbox_spec_id='nonexistent') + @patch('openhands.app_server.sandbox.docker_sandbox_service.base62.encodebytes') + @patch('os.urandom') + async def test_start_sandbox_with_sandbox_id( + self, mock_urandom, mock_encodebytes, service + ): + """Test starting sandbox with a specified sandbox_id.""" + # Setup - only need urandom for session key + mock_urandom.return_value = b'session_key' + mock_encodebytes.return_value = 'test_session_key' + + mock_container = MagicMock() + mock_container.name = 'oh-test-custom_sandbox_id' + mock_container.status = 'running' + mock_container.image.tags = ['test-image:latest'] + mock_container.attrs = { + 'Created': '2024-01-15T10:30:00.000000000Z', + 'Config': { + 'Env': ['OH_SESSION_API_KEYS_0=test_session_key', 'TEST_VAR=test_value'] + }, + 'NetworkSettings': {'Ports': {}}, + } + + service.docker_client.containers.run.return_value = mock_container + + with ( + patch.object(service, '_find_unused_port', side_effect=[12345, 12346]), + patch.object(service, 'pause_old_sandboxes', return_value=[]), + ): + # Execute with custom sandbox_id + result = await service.start_sandbox(sandbox_id='custom_sandbox_id') + + # Verify + assert result is not None + assert result.id == 'oh-test-custom_sandbox_id' + + # Verify container was created with the custom sandbox ID in the name + call_args = service.docker_client.containers.run.call_args + assert call_args[1]['name'] == 'oh-test-custom_sandbox_id' + async def test_start_sandbox_docker_error(self, service): """Test handling of Docker errors during sandbox startup.""" # Setup diff --git a/tests/unit/app_server/test_process_sandbox_service.py b/tests/unit/app_server/test_process_sandbox_service.py index f39384241d..8cb66225ba 100644 --- a/tests/unit/app_server/test_process_sandbox_service.py +++ b/tests/unit/app_server/test_process_sandbox_service.py @@ -185,6 +185,41 @@ class TestProcessSandboxService: result = await process_sandbox_service.delete_sandbox('nonexistent') assert result is False + @pytest.mark.asyncio + async def test_start_sandbox_with_sandbox_id(self, process_sandbox_service): + """Test starting a sandbox with a specified sandbox_id.""" + # Mock subprocess and waiting for server + with ( + patch.object( + process_sandbox_service, '_start_agent_process' + ) as mock_start_process, + patch.object( + process_sandbox_service, '_wait_for_server_ready', return_value=True + ), + patch.object( + process_sandbox_service, + '_get_process_status', + return_value=SandboxStatus.RUNNING, + ), + ): + mock_process = MagicMock() + mock_process.pid = 1234 + mock_start_process.return_value = mock_process + + # Mock successful health check response + mock_response = MagicMock() + mock_response.status_code = 200 + process_sandbox_service.httpx_client.get.return_value = mock_response + + # Execute with custom sandbox_id + result = await process_sandbox_service.start_sandbox( + sandbox_id='custom_sandbox_id' + ) + + # Verify + assert result is not None + assert result.id == 'custom_sandbox_id' + @patch('psutil.Process') def test_get_process_status_paused( self, mock_process_class, process_sandbox_service diff --git a/tests/unit/app_server/test_remote_sandbox_service.py b/tests/unit/app_server/test_remote_sandbox_service.py index c70ad7d324..1bdcf87d59 100644 --- a/tests/unit/app_server/test_remote_sandbox_service.py +++ b/tests/unit/app_server/test_remote_sandbox_service.py @@ -450,6 +450,34 @@ class TestSandboxLifecycle: with pytest.raises(ValueError, match='Sandbox Spec not found'): await remote_sandbox_service.start_sandbox('non-existent-spec') + @pytest.mark.asyncio + async def test_start_sandbox_with_sandbox_id( + self, remote_sandbox_service, mock_sandbox_spec_service + ): + """Test starting sandbox with a specified sandbox_id.""" + # Setup + mock_response = MagicMock() + mock_response.json.return_value = create_runtime_data( + session_id='custom_sandbox_id' + ) + remote_sandbox_service.httpx_client.request.return_value = mock_response + remote_sandbox_service.pause_old_sandboxes = AsyncMock(return_value=[]) + + # Mock database operations + remote_sandbox_service.db_session.add = MagicMock() + remote_sandbox_service.db_session.commit = AsyncMock() + + # Execute with custom sandbox_id - should not need base62 encoding + sandbox_info = await remote_sandbox_service.start_sandbox( + sandbox_id='custom_sandbox_id' + ) + + # Verify the custom sandbox_id is used + assert sandbox_info.id == 'custom_sandbox_id' + # Verify the stored sandbox used the custom ID + add_call_args = remote_sandbox_service.db_session.add.call_args[0][0] + assert add_call_args.id == 'custom_sandbox_id' + @pytest.mark.asyncio async def test_start_sandbox_http_error(self, remote_sandbox_service): """Test sandbox start with HTTP error.""" diff --git a/tests/unit/app_server/test_sandbox_service.py b/tests/unit/app_server/test_sandbox_service.py index f3eea1d2ea..c07367c91c 100644 --- a/tests/unit/app_server/test_sandbox_service.py +++ b/tests/unit/app_server/test_sandbox_service.py @@ -46,8 +46,10 @@ class MockSandboxService(SandboxService): ) -> SandboxInfo | None: return await self.get_sandbox_by_session_api_key_mock(session_api_key) - async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo: - return await self.start_sandbox_mock(sandbox_spec_id) + async def start_sandbox( + self, sandbox_spec_id: str | None = None, sandbox_id: str | None = None + ) -> SandboxInfo: + return await self.start_sandbox_mock(sandbox_spec_id, sandbox_id) async def resume_sandbox(self, sandbox_id: str) -> bool: return await self.resume_sandbox_mock(sandbox_id)