mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Refactor webhook endpoints to use session API key authentication (#11926)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
ed7adb335c
commit
db64abc580
@ -60,16 +60,22 @@ _logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def valid_sandbox(
|
||||
sandbox_id: str,
|
||||
user_context: UserContext = Depends(as_admin),
|
||||
session_api_key: str = Depends(
|
||||
APIKeyHeader(name='X-Session-API-Key', auto_error=False)
|
||||
),
|
||||
sandbox_service: SandboxService = sandbox_service_dependency,
|
||||
) -> SandboxInfo:
|
||||
sandbox_info = await sandbox_service.get_sandbox(sandbox_id)
|
||||
if sandbox_info is None or sandbox_info.session_api_key != session_api_key:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
if session_api_key is None:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED, detail='X-Session-API-Key header is required'
|
||||
)
|
||||
|
||||
sandbox_info = await sandbox_service.get_sandbox_by_session_api_key(session_api_key)
|
||||
if sandbox_info is None:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED, detail='Invalid session API key'
|
||||
)
|
||||
return sandbox_info
|
||||
|
||||
|
||||
@ -94,7 +100,7 @@ async def valid_conversation(
|
||||
return app_conversation_info
|
||||
|
||||
|
||||
@router.post('/{sandbox_id}/conversations')
|
||||
@router.post('/conversations')
|
||||
async def on_conversation_update(
|
||||
conversation_info: ConversationInfo,
|
||||
sandbox_info: SandboxInfo = Depends(valid_sandbox),
|
||||
@ -125,7 +131,7 @@ async def on_conversation_update(
|
||||
return Success()
|
||||
|
||||
|
||||
@router.post('/{sandbox_id}/events/{conversation_id}')
|
||||
@router.post('/events/{conversation_id}')
|
||||
async def on_event(
|
||||
events: list[Event],
|
||||
conversation_id: UUID,
|
||||
|
||||
@ -260,6 +260,29 @@ class DockerSandboxService(SandboxService):
|
||||
except (NotFound, APIError):
|
||||
return None
|
||||
|
||||
async def get_sandbox_by_session_api_key(
|
||||
self, session_api_key: str
|
||||
) -> SandboxInfo | None:
|
||||
"""Get a single sandbox by session API key."""
|
||||
try:
|
||||
# Get all containers with our prefix
|
||||
all_containers = self.docker_client.containers.list(all=True)
|
||||
|
||||
for container in all_containers:
|
||||
if container.name and container.name.startswith(
|
||||
self.container_name_prefix
|
||||
):
|
||||
# Check if this container has the matching session API key
|
||||
env_vars = self._get_container_env_vars(container)
|
||||
container_session_key = env_vars.get(SESSION_API_KEY_VARIABLE)
|
||||
|
||||
if container_session_key == session_api_key:
|
||||
return await self._container_to_checked_sandbox_info(container)
|
||||
|
||||
return None
|
||||
except (NotFound, APIError):
|
||||
return None
|
||||
|
||||
async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo:
|
||||
"""Start a new sandbox."""
|
||||
# Enforce sandbox limits by cleaning up old sandboxes
|
||||
@ -285,8 +308,7 @@ class DockerSandboxService(SandboxService):
|
||||
env_vars = sandbox_spec.initial_env.copy()
|
||||
env_vars[SESSION_API_KEY_VARIABLE] = session_api_key
|
||||
env_vars[WEBHOOK_CALLBACK_VARIABLE] = (
|
||||
f'http://host.docker.internal:{self.host_port}'
|
||||
f'/api/v1/webhooks/{container_name}'
|
||||
f'http://host.docker.internal:{self.host_port}/api/v1/webhooks'
|
||||
)
|
||||
|
||||
# Prepare port mappings and add port environment variables
|
||||
|
||||
@ -275,6 +275,17 @@ class ProcessSandboxService(SandboxService):
|
||||
|
||||
return await self._process_to_sandbox_info(sandbox_id, process_info)
|
||||
|
||||
async def get_sandbox_by_session_api_key(
|
||||
self, session_api_key: str
|
||||
) -> SandboxInfo | None:
|
||||
"""Get a single sandbox by session API key."""
|
||||
# Search through all processes to find one with matching session_api_key
|
||||
for sandbox_id, process_info in _processes.items():
|
||||
if process_info.session_api_key == session_api_key:
|
||||
return await self._process_to_sandbox_info(sandbox_id, process_info)
|
||||
|
||||
return None
|
||||
|
||||
async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo:
|
||||
"""Start a new sandbox."""
|
||||
# Get sandbox spec
|
||||
|
||||
@ -240,9 +240,7 @@ class RemoteSandboxService(SandboxService):
|
||||
|
||||
# If a public facing url is defined, add a callback to the agent server environment.
|
||||
if self.web_url:
|
||||
environment[WEBHOOK_CALLBACK_VARIABLE] = (
|
||||
f'{self.web_url}/api/v1/webhooks/{sandbox_id}'
|
||||
)
|
||||
environment[WEBHOOK_CALLBACK_VARIABLE] = f'{self.web_url}/api/v1/webhooks'
|
||||
# We specify CORS settings only if there is a public facing url - otherwise
|
||||
# we are probably in local development and the only url in use is localhost
|
||||
environment[ALLOW_CORS_ORIGINS_VARIABLE] = self.web_url
|
||||
@ -301,6 +299,27 @@ class RemoteSandboxService(SandboxService):
|
||||
return None
|
||||
return await self._to_sandbox_info(stored_sandbox)
|
||||
|
||||
async def get_sandbox_by_session_api_key(
|
||||
self, session_api_key: str
|
||||
) -> Union[SandboxInfo, None]:
|
||||
"""Get a single sandbox by session API key."""
|
||||
# Get all stored sandboxes for the current user
|
||||
stmt = await self._secure_select()
|
||||
result = await self.db_session.execute(stmt)
|
||||
stored_sandboxes = result.scalars().all()
|
||||
|
||||
# Check each sandbox's runtime data for matching session_api_key
|
||||
for stored_sandbox in stored_sandboxes:
|
||||
try:
|
||||
runtime = await self._get_runtime(stored_sandbox.id)
|
||||
if runtime and runtime.get('session_api_key') == session_api_key:
|
||||
return await self._to_sandbox_info(stored_sandbox, runtime)
|
||||
except Exception:
|
||||
# Continue checking other sandboxes if one fails
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo:
|
||||
"""Start a new sandbox by creating a remote runtime."""
|
||||
try:
|
||||
|
||||
@ -25,6 +25,12 @@ class SandboxService(ABC):
|
||||
async def get_sandbox(self, sandbox_id: str) -> SandboxInfo | None:
|
||||
"""Get a single sandbox. Return None if the sandbox was not found."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_sandbox_by_session_api_key(
|
||||
self, session_api_key: str
|
||||
) -> SandboxInfo | None:
|
||||
"""Get a single sandbox by session API key. Return None if the sandbox was not found."""
|
||||
|
||||
async def batch_get_sandboxes(
|
||||
self, sandbox_ids: list[str]
|
||||
) -> list[SandboxInfo | None]:
|
||||
|
||||
@ -291,9 +291,7 @@ class TestEnvironmentInitialization:
|
||||
)
|
||||
|
||||
# Verify
|
||||
expected_webhook_url = (
|
||||
'https://web.example.com/api/v1/webhooks/test-sandbox-123'
|
||||
)
|
||||
expected_webhook_url = 'https://web.example.com/api/v1/webhooks'
|
||||
assert environment['EXISTING_VAR'] == 'existing_value'
|
||||
assert environment[WEBHOOK_CALLBACK_VARIABLE] == expected_webhook_url
|
||||
assert environment[ALLOW_CORS_ORIGINS_VARIABLE] == 'https://web.example.com'
|
||||
|
||||
@ -27,6 +27,7 @@ class MockSandboxService(SandboxService):
|
||||
def __init__(self):
|
||||
self.search_sandboxes_mock = AsyncMock()
|
||||
self.get_sandbox_mock = AsyncMock()
|
||||
self.get_sandbox_by_session_api_key_mock = AsyncMock()
|
||||
self.start_sandbox_mock = AsyncMock()
|
||||
self.resume_sandbox_mock = AsyncMock()
|
||||
self.pause_sandbox_mock = AsyncMock()
|
||||
@ -40,6 +41,11 @@ class MockSandboxService(SandboxService):
|
||||
async def get_sandbox(self, sandbox_id: str) -> SandboxInfo | None:
|
||||
return await self.get_sandbox_mock(sandbox_id)
|
||||
|
||||
async def get_sandbox_by_session_api_key(
|
||||
self, session_api_key: str
|
||||
) -> SandboxInfo | None:
|
||||
return await self.get_sandbox_by_session_api_key_mock(session_api_key)
|
||||
|
||||
async def start_sandbox(self, sandbox_spec_id: str | None = None) -> SandboxInfo:
|
||||
return await self.start_sandbox_mock(sandbox_spec_id)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user