Refactor webhook endpoints to use session API key authentication (#11926)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Tim O'Farrell 2025-12-08 07:40:01 -07:00 committed by GitHub
parent ed7adb335c
commit db64abc580
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 82 additions and 14 deletions

View File

@ -60,16 +60,22 @@ _logger = logging.getLogger(__name__)
async def valid_sandbox(
sandbox_id: str,
user_context: UserContext = Depends(as_admin),
session_api_key: str = Depends(
APIKeyHeader(name='X-Session-API-Key', auto_error=False)
),
sandbox_service: SandboxService = sandbox_service_dependency,
) -> SandboxInfo:
sandbox_info = await sandbox_service.get_sandbox(sandbox_id)
if sandbox_info is None or sandbox_info.session_api_key != session_api_key:
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
if session_api_key is None:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED, detail='X-Session-API-Key header is required'
)
sandbox_info = await sandbox_service.get_sandbox_by_session_api_key(session_api_key)
if sandbox_info is None:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED, detail='Invalid session API key'
)
return sandbox_info
@ -94,7 +100,7 @@ async def valid_conversation(
return app_conversation_info
@router.post('/{sandbox_id}/conversations')
@router.post('/conversations')
async def on_conversation_update(
conversation_info: ConversationInfo,
sandbox_info: SandboxInfo = Depends(valid_sandbox),
@ -125,7 +131,7 @@ async def on_conversation_update(
return Success()
@router.post('/{sandbox_id}/events/{conversation_id}')
@router.post('/events/{conversation_id}')
async def on_event(
events: list[Event],
conversation_id: UUID,

View File

@ -260,6 +260,29 @@ class DockerSandboxService(SandboxService):
except (NotFound, APIError):
return None
async def get_sandbox_by_session_api_key(
self, session_api_key: str
) -> SandboxInfo | None:
"""Get a single sandbox by session API key."""
try:
# Get all containers with our prefix
all_containers = self.docker_client.containers.list(all=True)
for container in all_containers:
if container.name and container.name.startswith(
self.container_name_prefix
):
# Check if this container has the matching session API key
env_vars = self._get_container_env_vars(container)
container_session_key = env_vars.get(SESSION_API_KEY_VARIABLE)
if container_session_key == session_api_key:
return await self._container_to_checked_sandbox_info(container)
return None
except (NotFound, APIError):
return None
async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo:
"""Start a new sandbox."""
# Enforce sandbox limits by cleaning up old sandboxes
@ -285,8 +308,7 @@ class DockerSandboxService(SandboxService):
env_vars = sandbox_spec.initial_env.copy()
env_vars[SESSION_API_KEY_VARIABLE] = session_api_key
env_vars[WEBHOOK_CALLBACK_VARIABLE] = (
f'http://host.docker.internal:{self.host_port}'
f'/api/v1/webhooks/{container_name}'
f'http://host.docker.internal:{self.host_port}/api/v1/webhooks'
)
# Prepare port mappings and add port environment variables

View File

@ -275,6 +275,17 @@ class ProcessSandboxService(SandboxService):
return await self._process_to_sandbox_info(sandbox_id, process_info)
async def get_sandbox_by_session_api_key(
self, session_api_key: str
) -> SandboxInfo | None:
"""Get a single sandbox by session API key."""
# Search through all processes to find one with matching session_api_key
for sandbox_id, process_info in _processes.items():
if process_info.session_api_key == session_api_key:
return await self._process_to_sandbox_info(sandbox_id, process_info)
return None
async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo:
"""Start a new sandbox."""
# Get sandbox spec

View File

@ -240,9 +240,7 @@ class RemoteSandboxService(SandboxService):
# If a public facing url is defined, add a callback to the agent server environment.
if self.web_url:
environment[WEBHOOK_CALLBACK_VARIABLE] = (
f'{self.web_url}/api/v1/webhooks/{sandbox_id}'
)
environment[WEBHOOK_CALLBACK_VARIABLE] = f'{self.web_url}/api/v1/webhooks'
# We specify CORS settings only if there is a public facing url - otherwise
# we are probably in local development and the only url in use is localhost
environment[ALLOW_CORS_ORIGINS_VARIABLE] = self.web_url
@ -301,6 +299,27 @@ class RemoteSandboxService(SandboxService):
return None
return await self._to_sandbox_info(stored_sandbox)
async def get_sandbox_by_session_api_key(
self, session_api_key: str
) -> Union[SandboxInfo, None]:
"""Get a single sandbox by session API key."""
# Get all stored sandboxes for the current user
stmt = await self._secure_select()
result = await self.db_session.execute(stmt)
stored_sandboxes = result.scalars().all()
# Check each sandbox's runtime data for matching session_api_key
for stored_sandbox in stored_sandboxes:
try:
runtime = await self._get_runtime(stored_sandbox.id)
if runtime and runtime.get('session_api_key') == session_api_key:
return await self._to_sandbox_info(stored_sandbox, runtime)
except Exception:
# Continue checking other sandboxes if one fails
continue
return None
async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo:
"""Start a new sandbox by creating a remote runtime."""
try:

View File

@ -25,6 +25,12 @@ class SandboxService(ABC):
async def get_sandbox(self, sandbox_id: str) -> SandboxInfo | None:
"""Get a single sandbox. Return None if the sandbox was not found."""
@abstractmethod
async def get_sandbox_by_session_api_key(
self, session_api_key: str
) -> SandboxInfo | None:
"""Get a single sandbox by session API key. Return None if the sandbox was not found."""
async def batch_get_sandboxes(
self, sandbox_ids: list[str]
) -> list[SandboxInfo | None]:

View File

@ -291,9 +291,7 @@ class TestEnvironmentInitialization:
)
# Verify
expected_webhook_url = (
'https://web.example.com/api/v1/webhooks/test-sandbox-123'
)
expected_webhook_url = 'https://web.example.com/api/v1/webhooks'
assert environment['EXISTING_VAR'] == 'existing_value'
assert environment[WEBHOOK_CALLBACK_VARIABLE] == expected_webhook_url
assert environment[ALLOW_CORS_ORIGINS_VARIABLE] == 'https://web.example.com'

View File

@ -27,6 +27,7 @@ class MockSandboxService(SandboxService):
def __init__(self):
self.search_sandboxes_mock = AsyncMock()
self.get_sandbox_mock = AsyncMock()
self.get_sandbox_by_session_api_key_mock = AsyncMock()
self.start_sandbox_mock = AsyncMock()
self.resume_sandbox_mock = AsyncMock()
self.pause_sandbox_mock = AsyncMock()
@ -40,6 +41,11 @@ class MockSandboxService(SandboxService):
async def get_sandbox(self, sandbox_id: str) -> SandboxInfo | None:
return await self.get_sandbox_mock(sandbox_id)
async def get_sandbox_by_session_api_key(
self, session_api_key: str
) -> SandboxInfo | None:
return await self.get_sandbox_by_session_api_key_mock(session_api_key)
async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo:
return await self.start_sandbox_mock(sandbox_spec_id)