diff --git a/enterprise/server/middleware.py b/enterprise/server/middleware.py index 3124e1ce67..a312d501e6 100644 --- a/enterprise/server/middleware.py +++ b/enterprise/server/middleware.py @@ -103,11 +103,13 @@ class SetAuthCookieMiddleware: keycloak_auth_cookie = request.cookies.get('keycloak_auth') auth_header = request.headers.get('Authorization') mcp_auth_header = request.headers.get('X-Session-API-Key') + api_auth_header = request.headers.get('X-Access-Token') accepted_tos: bool | None = False if ( keycloak_auth_cookie is None and (auth_header is None or not auth_header.startswith('Bearer ')) and mcp_auth_header is None + and api_auth_header is None ): raise NoCredentialsError diff --git a/enterprise/server/routes/billing.py b/enterprise/server/routes/billing.py index fccbdd3f1b..4b015be5ac 100644 --- a/enterprise/server/routes/billing.py +++ b/enterprise/server/routes/billing.py @@ -23,7 +23,7 @@ from openhands.app_server.config import get_global_config from openhands.server.user_auth import get_user_id stripe.api_key = STRIPE_API_KEY -billing_router = APIRouter(prefix='/api/billing') +billing_router = APIRouter(prefix='/api/billing', tags=['Billing']) async def validate_billing_enabled() -> None: diff --git a/enterprise/server/routes/feedback.py b/enterprise/server/routes/feedback.py index 4a3ddc2eb8..7cfbf05fac 100644 --- a/enterprise/server/routes/feedback.py +++ b/enterprise/server/routes/feedback.py @@ -8,11 +8,18 @@ from storage.feedback import ConversationFeedback from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas from openhands.events.event_store import EventStore +from openhands.server.dependencies import get_dependencies from openhands.server.shared import file_store from openhands.server.user_auth import get_user_id from openhands.utils.async_utils import call_sync_from_async -router = APIRouter(prefix='/feedback', tags=['feedback']) +# We use the get_dependencies method here to signal to the OpenAPI docs that this endpoint +# is protected. The actual protection is provided by SetAuthCookieMiddleware +# TODO: It may be an error by you can actually post feedback to a conversation you don't +# own right now - maybe this is useful in the context of public shared conversations? +router = APIRouter( + prefix='/feedback', tags=['feedback'], dependencies=get_dependencies() +) async def get_event_ids(conversation_id: str, user_id: str) -> List[int]: diff --git a/enterprise/server/routes/orgs.py b/enterprise/server/routes/orgs.py index 9a1f348189..9e1931715f 100644 --- a/enterprise/server/routes/orgs.py +++ b/enterprise/server/routes/orgs.py @@ -38,7 +38,7 @@ from openhands.core.logger import openhands_logger as logger from openhands.server.user_auth import get_user_id # Initialize API router -org_router = APIRouter(prefix='/api/organizations') +org_router = APIRouter(prefix='/api/organizations', tags=['Orgs']) @org_router.get('', response_model=OrgPage) diff --git a/openhands/app_server/app_conversation/app_conversation_router.py b/openhands/app_server/app_conversation/app_conversation_router.py index 8eae724a0f..5a08d1d193 100644 --- a/openhands/app_server/app_conversation/app_conversation_router.py +++ b/openhands/app_server/app_conversation/app_conversation_router.py @@ -18,6 +18,7 @@ from openhands.app_server.services.httpx_client_injector import ( from openhands.app_server.services.injector import InjectorState from openhands.app_server.user.specifiy_user_context import USER_CONTEXT_ATTR from openhands.app_server.user.user_context import UserContext +from openhands.server.dependencies import get_dependencies # Handle anext compatibility for Python < 3.10 if sys.version_info >= (3, 10): @@ -74,7 +75,11 @@ from openhands.app_server.utils.docker_utils import ( from openhands.sdk.context.skills import KeywordTrigger, TaskTrigger from openhands.sdk.workspace.remote.async_remote_workspace import AsyncRemoteWorkspace -router = APIRouter(prefix='/app-conversations', tags=['Conversations']) +# We use the get_dependencies method here to signal to the OpenAPI docs that this endpoint +# is protected. The actual protection is provided by SetAuthCookieMiddleware +router = APIRouter( + prefix='/app-conversations', tags=['Conversations'], dependencies=get_dependencies() +) logger = logging.getLogger(__name__) app_conversation_service_dependency = depends_app_conversation_service() app_conversation_start_task_service_dependency = ( diff --git a/openhands/app_server/event/event_router.py b/openhands/app_server/event/event_router.py index 980b3ab47a..522a53c273 100644 --- a/openhands/app_server/event/event_router.py +++ b/openhands/app_server/event/event_router.py @@ -11,8 +11,15 @@ from openhands.app_server.config import depends_event_service from openhands.app_server.event.event_service import EventService from openhands.app_server.event_callback.event_callback_models import EventKind from openhands.sdk import Event +from openhands.server.dependencies import get_dependencies -router = APIRouter(prefix='/conversation/{conversation_id}/events', tags=['Events']) +# We use the get_dependencies method here to signal to the OpenAPI docs that this endpoint +# is protected. The actual protection is provided by SetAuthCookieMiddleware +router = APIRouter( + prefix='/conversation/{conversation_id}/events', + tags=['Events'], + dependencies=get_dependencies(), +) event_service_dependency = depends_event_service() diff --git a/openhands/app_server/sandbox/sandbox_router.py b/openhands/app_server/sandbox/sandbox_router.py index 4acb2b943c..79a3ef6b82 100644 --- a/openhands/app_server/sandbox/sandbox_router.py +++ b/openhands/app_server/sandbox/sandbox_router.py @@ -10,8 +10,13 @@ from openhands.app_server.sandbox.sandbox_models import SandboxInfo, SandboxPage from openhands.app_server.sandbox.sandbox_service import ( SandboxService, ) +from openhands.server.dependencies import get_dependencies -router = APIRouter(prefix='/sandboxes', tags=['Sandbox']) +# We use the get_dependencies method here to signal to the OpenAPI docs that this endpoint +# is protected. The actual protection is provided by SetAuthCookieMiddleware +router = APIRouter( + prefix='/sandboxes', tags=['Sandbox'], dependencies=get_dependencies() +) sandbox_service_dependency = depends_sandbox_service() # Read methods diff --git a/openhands/app_server/sandbox/sandbox_spec_router.py b/openhands/app_server/sandbox/sandbox_spec_router.py index f7f15e9dc7..6da3353f39 100644 --- a/openhands/app_server/sandbox/sandbox_spec_router.py +++ b/openhands/app_server/sandbox/sandbox_spec_router.py @@ -12,8 +12,15 @@ from openhands.app_server.sandbox.sandbox_spec_models import ( from openhands.app_server.sandbox.sandbox_spec_service import ( SandboxSpecService, ) +from openhands.server.dependencies import get_dependencies -router = APIRouter(prefix='/sandbox-specs', tags=['Sandbox']) +# We use the get_dependencies method here to signal to the OpenAPI docs that this endpoint +# is protected. The actual protection is provided by SetAuthCookieMiddleware +# Sandboxes specs share a single immutable list for the server right now, but that is likely to +# change in the future +router = APIRouter( + prefix='/sandbox-specs', tags=['Sandbox'], dependencies=get_dependencies() +) sandbox_spec_service_dependency = depends_sandbox_spec_service() diff --git a/openhands/app_server/user/user_router.py b/openhands/app_server/user/user_router.py index 0d2ff1ab97..2926c8495d 100644 --- a/openhands/app_server/user/user_router.py +++ b/openhands/app_server/user/user_router.py @@ -5,8 +5,11 @@ from fastapi import APIRouter, HTTPException, status from openhands.app_server.config import depends_user_context from openhands.app_server.user.user_context import UserContext from openhands.app_server.user.user_models import UserInfo +from openhands.server.dependencies import get_dependencies -router = APIRouter(prefix='/users', tags=['User']) +# We use the get_dependencies method here to signal to the OpenAPI docs that this endpoint +# is protected. The actual protection is provided by SetAuthCookieMiddleware +router = APIRouter(prefix='/users', tags=['User'], dependencies=get_dependencies()) user_dependency = depends_user_context() # Read methods diff --git a/openhands/server/dependencies.py b/openhands/server/dependencies.py index c8eef9c847..c4141c7c0b 100644 --- a/openhands/server/dependencies.py +++ b/openhands/server/dependencies.py @@ -11,6 +11,9 @@ import os from fastapi import Depends, HTTPException, status from fastapi.security import APIKeyHeader +from openhands.app_server.config import get_global_config +from openhands.server.types import AppMode + _SESSION_API_KEY = os.getenv('SESSION_API_KEY') _SESSION_API_KEY_HEADER = APIKeyHeader(name='X-Session-API-Key', auto_error=False) @@ -29,4 +32,9 @@ def get_dependencies() -> list[Depends]: result = [] if _SESSION_API_KEY: result.append(Depends(check_session_api_key)) + elif get_global_config().app_mode == AppMode.SAAS: + # This merely lets the OpenAPI Docs know that an X-Session-API-Key can be + # used for security - it does not fail if the header is not provided + # (Allowing cookies to also be used) + result.append(Depends(APIKeyHeader(name='X-Access-Token', auto_error=False))) return result diff --git a/openhands/server/user_auth/__init__.py b/openhands/server/user_auth/__init__.py index acd4ca0b49..d6a4c58e3f 100644 --- a/openhands/server/user_auth/__init__.py +++ b/openhands/server/user_auth/__init__.py @@ -1,4 +1,5 @@ -from fastapi import Request +from fastapi import Depends, Request +from fastapi.security import APIKeyHeader from pydantic import SecretStr from openhands.integrations.provider import PROVIDER_TOKEN_TYPE @@ -21,7 +22,15 @@ async def get_access_token(request: Request) -> SecretStr | None: return access_token -async def get_user_id(request: Request) -> str | None: +async def get_user_id( + request: Request, + api_key_header: str | None = Depends( + APIKeyHeader(name='X-Access-Token', auto_error=False) + ), +) -> str | None: + """Get the current user_id. Used for dependency injection - the + api key header is used here to signal the requirement in OpenAPI + docs""" user_auth = await get_user_auth(request) user_id = await user_auth.get_user_id() return user_id