From db64abc580c583f5381889311bb84f74dc307faa Mon Sep 17 00:00:00 2001 From: Tim O'Farrell Date: Mon, 8 Dec 2025 07:40:01 -0700 Subject: [PATCH] Refactor webhook endpoints to use session API key authentication (#11926) Co-authored-by: openhands --- .../event_callback/webhook_router.py | 18 ++++++++----- .../sandbox/docker_sandbox_service.py | 26 +++++++++++++++++-- .../sandbox/process_sandbox_service.py | 11 ++++++++ .../sandbox/remote_sandbox_service.py | 25 +++++++++++++++--- .../app_server/sandbox/sandbox_service.py | 6 +++++ .../app_server/test_remote_sandbox_service.py | 4 +-- tests/unit/app_server/test_sandbox_service.py | 6 +++++ 7 files changed, 82 insertions(+), 14 deletions(-) diff --git a/openhands/app_server/event_callback/webhook_router.py b/openhands/app_server/event_callback/webhook_router.py index ac9812764d..28236b7325 100644 --- a/openhands/app_server/event_callback/webhook_router.py +++ b/openhands/app_server/event_callback/webhook_router.py @@ -60,16 +60,22 @@ _logger = logging.getLogger(__name__) async def valid_sandbox( - sandbox_id: str, user_context: UserContext = Depends(as_admin), session_api_key: str = Depends( APIKeyHeader(name='X-Session-API-Key', auto_error=False) ), sandbox_service: SandboxService = sandbox_service_dependency, ) -> SandboxInfo: - sandbox_info = await sandbox_service.get_sandbox(sandbox_id) - if sandbox_info is None or sandbox_info.session_api_key != session_api_key: - raise HTTPException(status.HTTP_401_UNAUTHORIZED) + if session_api_key is None: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, detail='X-Session-API-Key header is required' + ) + + sandbox_info = await sandbox_service.get_sandbox_by_session_api_key(session_api_key) + if sandbox_info is None: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, detail='Invalid session API key' + ) return sandbox_info @@ -94,7 +100,7 @@ async def valid_conversation( return app_conversation_info -@router.post('/{sandbox_id}/conversations') +@router.post('/conversations') async def on_conversation_update( conversation_info: ConversationInfo, sandbox_info: SandboxInfo = Depends(valid_sandbox), @@ -125,7 +131,7 @@ async def on_conversation_update( return Success() -@router.post('/{sandbox_id}/events/{conversation_id}') +@router.post('/events/{conversation_id}') async def on_event( events: list[Event], conversation_id: UUID, diff --git a/openhands/app_server/sandbox/docker_sandbox_service.py b/openhands/app_server/sandbox/docker_sandbox_service.py index d7fe0b726d..a0aeddc0e6 100644 --- a/openhands/app_server/sandbox/docker_sandbox_service.py +++ b/openhands/app_server/sandbox/docker_sandbox_service.py @@ -260,6 +260,29 @@ class DockerSandboxService(SandboxService): except (NotFound, APIError): return None + async def get_sandbox_by_session_api_key( + self, session_api_key: str + ) -> SandboxInfo | None: + """Get a single sandbox by session API key.""" + try: + # Get all containers with our prefix + all_containers = self.docker_client.containers.list(all=True) + + for container in all_containers: + if container.name and container.name.startswith( + self.container_name_prefix + ): + # Check if this container has the matching session API key + env_vars = self._get_container_env_vars(container) + container_session_key = env_vars.get(SESSION_API_KEY_VARIABLE) + + if container_session_key == session_api_key: + return await self._container_to_checked_sandbox_info(container) + + return None + except (NotFound, APIError): + return None + async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo: """Start a new sandbox.""" # Enforce sandbox limits by cleaning up old sandboxes @@ -285,8 +308,7 @@ class DockerSandboxService(SandboxService): env_vars = sandbox_spec.initial_env.copy() env_vars[SESSION_API_KEY_VARIABLE] = session_api_key env_vars[WEBHOOK_CALLBACK_VARIABLE] = ( - f'http://host.docker.internal:{self.host_port}' - f'/api/v1/webhooks/{container_name}' + f'http://host.docker.internal:{self.host_port}/api/v1/webhooks' ) # Prepare port mappings and add port environment variables diff --git a/openhands/app_server/sandbox/process_sandbox_service.py b/openhands/app_server/sandbox/process_sandbox_service.py index 716c2e1b19..200bf62c44 100644 --- a/openhands/app_server/sandbox/process_sandbox_service.py +++ b/openhands/app_server/sandbox/process_sandbox_service.py @@ -275,6 +275,17 @@ class ProcessSandboxService(SandboxService): return await self._process_to_sandbox_info(sandbox_id, process_info) + async def get_sandbox_by_session_api_key( + self, session_api_key: str + ) -> SandboxInfo | None: + """Get a single sandbox by session API key.""" + # Search through all processes to find one with matching session_api_key + for sandbox_id, process_info in _processes.items(): + if process_info.session_api_key == session_api_key: + return await self._process_to_sandbox_info(sandbox_id, process_info) + + return None + async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo: """Start a new sandbox.""" # Get sandbox spec diff --git a/openhands/app_server/sandbox/remote_sandbox_service.py b/openhands/app_server/sandbox/remote_sandbox_service.py index dfa029462e..5ee42218dc 100644 --- a/openhands/app_server/sandbox/remote_sandbox_service.py +++ b/openhands/app_server/sandbox/remote_sandbox_service.py @@ -240,9 +240,7 @@ class RemoteSandboxService(SandboxService): # If a public facing url is defined, add a callback to the agent server environment. if self.web_url: - environment[WEBHOOK_CALLBACK_VARIABLE] = ( - f'{self.web_url}/api/v1/webhooks/{sandbox_id}' - ) + environment[WEBHOOK_CALLBACK_VARIABLE] = f'{self.web_url}/api/v1/webhooks' # We specify CORS settings only if there is a public facing url - otherwise # we are probably in local development and the only url in use is localhost environment[ALLOW_CORS_ORIGINS_VARIABLE] = self.web_url @@ -301,6 +299,27 @@ class RemoteSandboxService(SandboxService): return None return await self._to_sandbox_info(stored_sandbox) + async def get_sandbox_by_session_api_key( + self, session_api_key: str + ) -> Union[SandboxInfo, None]: + """Get a single sandbox by session API key.""" + # Get all stored sandboxes for the current user + stmt = await self._secure_select() + result = await self.db_session.execute(stmt) + stored_sandboxes = result.scalars().all() + + # Check each sandbox's runtime data for matching session_api_key + for stored_sandbox in stored_sandboxes: + try: + runtime = await self._get_runtime(stored_sandbox.id) + if runtime and runtime.get('session_api_key') == session_api_key: + return await self._to_sandbox_info(stored_sandbox, runtime) + except Exception: + # Continue checking other sandboxes if one fails + continue + + return None + async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo: """Start a new sandbox by creating a remote runtime.""" try: diff --git a/openhands/app_server/sandbox/sandbox_service.py b/openhands/app_server/sandbox/sandbox_service.py index 43393dfcf7..b1144a47cc 100644 --- a/openhands/app_server/sandbox/sandbox_service.py +++ b/openhands/app_server/sandbox/sandbox_service.py @@ -25,6 +25,12 @@ class SandboxService(ABC): async def get_sandbox(self, sandbox_id: str) -> SandboxInfo | None: """Get a single sandbox. Return None if the sandbox was not found.""" + @abstractmethod + async def get_sandbox_by_session_api_key( + self, session_api_key: str + ) -> SandboxInfo | None: + """Get a single sandbox by session API key. Return None if the sandbox was not found.""" + async def batch_get_sandboxes( self, sandbox_ids: list[str] ) -> list[SandboxInfo | None]: diff --git a/tests/unit/app_server/test_remote_sandbox_service.py b/tests/unit/app_server/test_remote_sandbox_service.py index 567ecad2e3..5802e46ecb 100644 --- a/tests/unit/app_server/test_remote_sandbox_service.py +++ b/tests/unit/app_server/test_remote_sandbox_service.py @@ -291,9 +291,7 @@ class TestEnvironmentInitialization: ) # Verify - expected_webhook_url = ( - 'https://web.example.com/api/v1/webhooks/test-sandbox-123' - ) + expected_webhook_url = 'https://web.example.com/api/v1/webhooks' assert environment['EXISTING_VAR'] == 'existing_value' assert environment[WEBHOOK_CALLBACK_VARIABLE] == expected_webhook_url assert environment[ALLOW_CORS_ORIGINS_VARIABLE] == 'https://web.example.com' diff --git a/tests/unit/app_server/test_sandbox_service.py b/tests/unit/app_server/test_sandbox_service.py index 9a65131821..f3eea1d2ea 100644 --- a/tests/unit/app_server/test_sandbox_service.py +++ b/tests/unit/app_server/test_sandbox_service.py @@ -27,6 +27,7 @@ class MockSandboxService(SandboxService): def __init__(self): self.search_sandboxes_mock = AsyncMock() self.get_sandbox_mock = AsyncMock() + self.get_sandbox_by_session_api_key_mock = AsyncMock() self.start_sandbox_mock = AsyncMock() self.resume_sandbox_mock = AsyncMock() self.pause_sandbox_mock = AsyncMock() @@ -40,6 +41,11 @@ class MockSandboxService(SandboxService): async def get_sandbox(self, sandbox_id: str) -> SandboxInfo | None: return await self.get_sandbox_mock(sandbox_id) + async def get_sandbox_by_session_api_key( + self, session_api_key: str + ) -> SandboxInfo | None: + return await self.get_sandbox_by_session_api_key_mock(session_api_key) + async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo: return await self.start_sandbox_mock(sandbox_spec_id)