From db41148396188a7566f412f1053211a655aaea7d Mon Sep 17 00:00:00 2001 From: Hiep Le <69354317+hieptl@users.noreply.github.com> Date: Thu, 19 Mar 2026 01:46:23 +0700 Subject: [PATCH] feat(backend): expose API key org_id via new GET /api/keys/current endpoint (org project) (#13469) --- enterprise/server/auth/saas_user_auth.py | 18 +++- enterprise/server/routes/api_keys.py | 57 ++++++++++++- enterprise/storage/api_key_store.py | 22 ++++- .../storage/saas_conversation_validator.py | 6 +- .../tests/unit/server/routes/test_api_keys.py | 85 +++++++++++++++++++ enterprise/tests/unit/test_api_key_store.py | 14 ++- enterprise/tests/unit/test_saas_user_auth.py | 17 +++- 7 files changed, 203 insertions(+), 16 deletions(-) diff --git a/enterprise/server/auth/saas_user_auth.py b/enterprise/server/auth/saas_user_auth.py index c2b3e1fbe9..6c8aefea7a 100644 --- a/enterprise/server/auth/saas_user_auth.py +++ b/enterprise/server/auth/saas_user_auth.py @@ -1,6 +1,7 @@ import time from dataclasses import dataclass from types import MappingProxyType +from uuid import UUID import jwt from fastapi import Request @@ -59,6 +60,10 @@ class SaasUserAuth(UserAuth): _secrets: Secrets | None = None accepted_tos: bool | None = None auth_type: AuthType = AuthType.COOKIE + # API key context fields - populated when authenticated via API key + api_key_org_id: UUID | None = None + api_key_id: int | None = None + api_key_name: str | None = None async def get_user_id(self) -> str | None: return self.user_id @@ -283,14 +288,19 @@ async def saas_user_auth_from_bearer(request: Request) -> SaasUserAuth | None: return None api_key_store = ApiKeyStore.get_instance() - user_id = await api_key_store.validate_api_key(api_key) - if not user_id: + validation_result = await api_key_store.validate_api_key(api_key) + if not validation_result: return None - offline_token = await token_manager.load_offline_token(user_id) + offline_token = await token_manager.load_offline_token( + validation_result.user_id + ) saas_user_auth = SaasUserAuth( - user_id=user_id, + user_id=validation_result.user_id, refresh_token=SecretStr(offline_token), auth_type=AuthType.BEARER, + api_key_org_id=validation_result.org_id, + api_key_id=validation_result.key_id, + api_key_name=validation_result.key_name, ) await saas_user_auth.refresh() return saas_user_auth diff --git a/enterprise/server/routes/api_keys.py b/enterprise/server/routes/api_keys.py index d5f30f87cf..31320966da 100644 --- a/enterprise/server/routes/api_keys.py +++ b/enterprise/server/routes/api_keys.py @@ -1,7 +1,9 @@ from datetime import UTC, datetime +from typing import cast -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, Request, status from pydantic import BaseModel, field_validator +from server.auth.saas_user_auth import SaasUserAuth from storage.api_key import ApiKey from storage.api_key_store import ApiKeyStore from storage.lite_llm_manager import LiteLlmManager @@ -11,7 +13,8 @@ from storage.org_service import OrgService from storage.user_store import UserStore from openhands.core.logger import openhands_logger as logger -from openhands.server.user_auth import get_user_id +from openhands.server.user_auth import get_user_auth, get_user_id +from openhands.server.user_auth.user_auth import AuthType # Helper functions for BYOR API key management @@ -150,6 +153,16 @@ class MessageResponse(BaseModel): message: str +class CurrentApiKeyResponse(BaseModel): + """Response model for the current API key endpoint.""" + + id: int + name: str | None + org_id: str + user_id: str + auth_type: str + + def api_key_to_response(key: ApiKey) -> ApiKeyResponse: """Convert an ApiKey model to an ApiKeyResponse.""" return ApiKeyResponse( @@ -262,6 +275,46 @@ async def delete_api_key( ) +@api_router.get('/current', tags=['Keys']) +async def get_current_api_key( + request: Request, + user_id: str = Depends(get_user_id), +) -> CurrentApiKeyResponse: + """Get information about the currently authenticated API key. + + This endpoint returns metadata about the API key used for the current request, + including the org_id associated with the key. This is useful for API key + callers who need to know which organization context their key operates in. + + Returns 400 if not authenticated via API key (e.g., using cookie auth). + """ + user_auth = await get_user_auth(request) + + # Check if authenticated via API key + if user_auth.get_auth_type() != AuthType.BEARER: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail='This endpoint requires API key authentication. Not available for cookie-based auth.', + ) + + # In SaaS context, bearer auth always produces SaasUserAuth + saas_user_auth = cast(SaasUserAuth, user_auth) + + if saas_user_auth.api_key_org_id is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail='This API key was created before organization support. Please regenerate your API key to use this endpoint.', + ) + + return CurrentApiKeyResponse( + id=saas_user_auth.api_key_id, + name=saas_user_auth.api_key_name, + org_id=str(saas_user_auth.api_key_org_id), + user_id=user_id, + auth_type=saas_user_auth.auth_type.value, + ) + + @api_router.get('/llm/byor', tags=['Keys']) async def get_llm_api_key_for_byor( user_id: str = Depends(get_user_id), diff --git a/enterprise/storage/api_key_store.py b/enterprise/storage/api_key_store.py index 74a2d3d73e..3090b8da07 100644 --- a/enterprise/storage/api_key_store.py +++ b/enterprise/storage/api_key_store.py @@ -4,6 +4,7 @@ import secrets import string from dataclasses import dataclass from datetime import UTC, datetime +from uuid import UUID from sqlalchemy import select, update from storage.api_key import ApiKey @@ -13,6 +14,16 @@ from storage.user_store import UserStore from openhands.core.logger import openhands_logger as logger +@dataclass +class ApiKeyValidationResult: + """Result of API key validation containing user and org context.""" + + user_id: str + org_id: UUID | None + key_id: int + key_name: str | None + + @dataclass class ApiKeyStore: API_KEY_PREFIX = 'sk-oh-' @@ -60,8 +71,8 @@ class ApiKeyStore: return api_key - async def validate_api_key(self, api_key: str) -> str | None: - """Validate an API key and return the associated user_id if valid.""" + async def validate_api_key(self, api_key: str) -> ApiKeyValidationResult | None: + """Validate an API key and return the associated user_id and org_id if valid.""" now = datetime.now(UTC) async with a_session_maker() as session: @@ -89,7 +100,12 @@ class ApiKeyStore: ) await session.commit() - return key_record.user_id + return ApiKeyValidationResult( + user_id=key_record.user_id, + org_id=key_record.org_id, + key_id=key_record.id, + key_name=key_record.name, + ) async def delete_api_key(self, api_key: str) -> bool: """Delete an API key by the key value.""" diff --git a/enterprise/storage/saas_conversation_validator.py b/enterprise/storage/saas_conversation_validator.py index bff4468011..51a5302dfc 100644 --- a/enterprise/storage/saas_conversation_validator.py +++ b/enterprise/storage/saas_conversation_validator.py @@ -28,12 +28,14 @@ class SaasConversationValidator(ConversationValidator): # Validate the API key and get the user_id api_key_store = ApiKeyStore.get_instance() - user_id = await api_key_store.validate_api_key(api_key) + validation_result = await api_key_store.validate_api_key(api_key) - if not user_id: + if not validation_result: logger.warning('Invalid API key') return None + user_id = validation_result.user_id + # Get the offline token for the user offline_token = await token_manager.load_offline_token(user_id) if not offline_token: diff --git a/enterprise/tests/unit/server/routes/test_api_keys.py b/enterprise/tests/unit/server/routes/test_api_keys.py index 57a9cb465d..4c35e9d5be 100644 --- a/enterprise/tests/unit/server/routes/test_api_keys.py +++ b/enterprise/tests/unit/server/routes/test_api_keys.py @@ -1,19 +1,26 @@ """Unit tests for API keys routes, focusing on BYOR key validation and retrieval.""" +import uuid from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest from fastapi import HTTPException +from pydantic import SecretStr +from server.auth.saas_user_auth import SaasUserAuth from server.routes.api_keys import ( ByorPermittedResponse, + CurrentApiKeyResponse, LlmApiKeyResponse, check_byor_permitted, delete_byor_key_from_litellm, + get_current_api_key, get_llm_api_key_for_byor, ) from storage.lite_llm_manager import LiteLlmManager +from openhands.server.user_auth.user_auth import AuthType + class TestVerifyByorKeyInLitellm: """Test the verify_byor_key_in_litellm function.""" @@ -512,3 +519,81 @@ class TestCheckByorPermitted: assert exc_info.value.status_code == 500 assert 'Failed to check BYOR export permission' in exc_info.value.detail + + +class TestGetCurrentApiKey: + """Test the get_current_api_key endpoint.""" + + @pytest.mark.asyncio + @patch('server.routes.api_keys.get_user_auth') + async def test_returns_api_key_info_for_bearer_auth(self, mock_get_user_auth): + """Test that API key metadata including org_id is returned for bearer token auth.""" + # Arrange + user_id = 'user-123' + org_id = uuid.uuid4() + mock_request = MagicMock() + + user_auth = SaasUserAuth( + refresh_token=SecretStr('mock-token'), + user_id=user_id, + auth_type=AuthType.BEARER, + api_key_org_id=org_id, + api_key_id=42, + api_key_name='My Production Key', + ) + mock_get_user_auth.return_value = user_auth + + # Act + result = await get_current_api_key(request=mock_request, user_id=user_id) + + # Assert + assert isinstance(result, CurrentApiKeyResponse) + assert result.org_id == str(org_id) + assert result.id == 42 + assert result.name == 'My Production Key' + assert result.user_id == user_id + assert result.auth_type == 'bearer' + + @pytest.mark.asyncio + @patch('server.routes.api_keys.get_user_auth') + async def test_returns_400_for_cookie_auth(self, mock_get_user_auth): + """Test that 400 Bad Request is returned when using cookie authentication.""" + # Arrange + user_id = 'user-123' + mock_request = MagicMock() + + mock_user_auth = MagicMock() + mock_user_auth.get_auth_type.return_value = AuthType.COOKIE + mock_get_user_auth.return_value = mock_user_auth + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await get_current_api_key(request=mock_request, user_id=user_id) + + assert exc_info.value.status_code == 400 + assert 'API key authentication' in exc_info.value.detail + + @pytest.mark.asyncio + @patch('server.routes.api_keys.get_user_auth') + async def test_returns_400_when_api_key_org_id_is_none(self, mock_get_user_auth): + """Test that 400 is returned when API key has no org_id (legacy key).""" + # Arrange + user_id = 'user-123' + mock_request = MagicMock() + + user_auth = SaasUserAuth( + refresh_token=SecretStr('mock-token'), + user_id=user_id, + auth_type=AuthType.BEARER, + api_key_org_id=None, # No org_id - legacy key + api_key_id=42, + api_key_name='Legacy Key', + ) + mock_get_user_auth.return_value = user_auth + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await get_current_api_key(request=mock_request, user_id=user_id) + + assert exc_info.value.status_code == 400 + assert 'created before organization support' in exc_info.value.detail diff --git a/enterprise/tests/unit/test_api_key_store.py b/enterprise/tests/unit/test_api_key_store.py index d3a2d13d1e..baffe5893c 100644 --- a/enterprise/tests/unit/test_api_key_store.py +++ b/enterprise/tests/unit/test_api_key_store.py @@ -126,13 +126,18 @@ async def test_validate_api_key_valid(api_key_store, async_session_maker): ) session.add(key_record) await session.commit() + key_id = key_record.id # Execute - patch a_session_maker to use test's async session maker with patch('storage.api_key_store.a_session_maker', async_session_maker): result = await api_key_store.validate_api_key(api_key_value) - # Verify - assert result == user_id + # Verify - result is now ApiKeyValidationResult + assert result is not None + assert result.user_id == user_id + assert result.org_id == org_id + assert result.key_id == key_id + assert result.key_name == 'Test Key' @pytest.mark.asyncio @@ -218,8 +223,9 @@ async def test_validate_api_key_valid_timezone_naive( with patch('storage.api_key_store.a_session_maker', async_session_maker): result = await api_key_store.validate_api_key(api_key_value) - # Verify - assert result == user_id + # Verify - result is now ApiKeyValidationResult + assert result is not None + assert result.user_id == user_id @pytest.mark.asyncio diff --git a/enterprise/tests/unit/test_saas_user_auth.py b/enterprise/tests/unit/test_saas_user_auth.py index 92552de3ad..2fb1b68445 100644 --- a/enterprise/tests/unit/test_saas_user_auth.py +++ b/enterprise/tests/unit/test_saas_user_auth.py @@ -1,4 +1,5 @@ import time +import uuid from unittest.mock import AsyncMock, MagicMock, patch import jwt @@ -18,6 +19,7 @@ from server.auth.saas_user_auth import ( saas_user_auth_from_cookie, saas_user_auth_from_signed_token, ) +from storage.api_key_store import ApiKeyValidationResult from storage.user_authorization import UserAuthorizationType from openhands.integrations.provider import ProviderToken, ProviderType @@ -468,12 +470,22 @@ async def test_saas_user_auth_from_bearer_success(): algorithm='HS256', ) + mock_org_id = uuid.uuid4() + mock_validation_result = ApiKeyValidationResult( + user_id='test_user_id', + org_id=mock_org_id, + key_id=42, + key_name='Test Key', + ) + with ( patch('server.auth.saas_user_auth.ApiKeyStore') as mock_api_key_store_cls, patch('server.auth.saas_user_auth.token_manager') as mock_token_manager, ): mock_api_key_store = MagicMock() - mock_api_key_store.validate_api_key = AsyncMock(return_value='test_user_id') + mock_api_key_store.validate_api_key = AsyncMock( + return_value=mock_validation_result + ) mock_api_key_store_cls.get_instance.return_value = mock_api_key_store mock_token_manager.load_offline_token = AsyncMock(return_value=offline_token) @@ -485,6 +497,9 @@ async def test_saas_user_auth_from_bearer_success(): assert isinstance(result, SaasUserAuth) assert result.user_id == 'test_user_id' + assert result.api_key_org_id == mock_org_id + assert result.api_key_id == 42 + assert result.api_key_name == 'Test Key' mock_api_key_store.validate_api_key.assert_called_once_with('test_api_key') mock_token_manager.load_offline_token.assert_called_once_with('test_user_id') mock_token_manager.refresh.assert_called_once_with(offline_token)