diff --git a/enterprise/server/auth/authorization.py b/enterprise/server/auth/authorization.py index c8d72021c6..203f74f112 100644 --- a/enterprise/server/auth/authorization.py +++ b/enterprise/server/auth/authorization.py @@ -35,7 +35,7 @@ Usage: from enum import Enum from uuid import UUID -from fastapi import Depends, HTTPException, status +from fastapi import Depends, HTTPException, Request, status from storage.org_member_store import OrgMemberStore from storage.role import Role from storage.role_store import RoleStore @@ -214,6 +214,19 @@ def has_permission(user_role: Role, permission: Permission) -> bool: return permission in permissions +async def get_api_key_org_id_from_request(request: Request) -> UUID | None: + """Get the org_id bound to the API key used for authentication. + + Returns None if: + - Not authenticated via API key (cookie auth) + - API key is a legacy key without org binding + """ + user_auth = getattr(request.state, 'user_auth', None) + if user_auth and hasattr(user_auth, 'get_api_key_org_id'): + return user_auth.get_api_key_org_id() + return None + + def require_permission(permission: Permission): """ Factory function that creates a dependency to require a specific permission. @@ -221,8 +234,9 @@ def require_permission(permission: Permission): This creates a FastAPI dependency that: 1. Extracts org_id from the path parameter 2. Gets the authenticated user_id - 3. Checks if the user has the required permission in the organization - 4. Returns the user_id if authorized, raises HTTPException otherwise + 3. Validates API key org binding (if using API key auth) + 4. Checks if the user has the required permission in the organization + 5. Returns the user_id if authorized, raises HTTPException otherwise Usage: @router.get('/{org_id}/settings') @@ -240,6 +254,7 @@ def require_permission(permission: Permission): """ async def permission_checker( + request: Request, org_id: UUID | None = None, user_id: str | None = Depends(get_user_id), ) -> str: @@ -249,6 +264,23 @@ def require_permission(permission: Permission): detail='User not authenticated', ) + # Validate API key organization binding + api_key_org_id = await get_api_key_org_id_from_request(request) + if api_key_org_id is not None and org_id is not None: + if api_key_org_id != org_id: + logger.warning( + 'API key organization mismatch', + extra={ + 'user_id': user_id, + 'api_key_org_id': str(api_key_org_id), + 'target_org_id': str(org_id), + }, + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail='API key is not authorized for this organization', + ) + user_role = await get_user_org_role(user_id, org_id) if not user_role: diff --git a/enterprise/server/auth/saas_user_auth.py b/enterprise/server/auth/saas_user_auth.py index 6c8aefea7a..15f50cc40b 100644 --- a/enterprise/server/auth/saas_user_auth.py +++ b/enterprise/server/auth/saas_user_auth.py @@ -61,10 +61,19 @@ class SaasUserAuth(UserAuth): accepted_tos: bool | None = None auth_type: AuthType = AuthType.COOKIE # API key context fields - populated when authenticated via API key - api_key_org_id: UUID | None = None + api_key_org_id: UUID | None = None # Org bound to the API key used for auth api_key_id: int | None = None api_key_name: str | None = None + def get_api_key_org_id(self) -> UUID | None: + """Get the organization ID bound to the API key used for authentication. + + Returns: + The org_id if authenticated via API key with org binding, None otherwise + (cookie auth or legacy API keys without org binding). + """ + return self.api_key_org_id + async def get_user_id(self) -> str | None: return self.user_id diff --git a/enterprise/storage/api_key_store.py b/enterprise/storage/api_key_store.py index ecbb375592..30e5b242e8 100644 --- a/enterprise/storage/api_key_store.py +++ b/enterprise/storage/api_key_store.py @@ -16,10 +16,10 @@ from openhands.core.logger import openhands_logger as logger @dataclass class ApiKeyValidationResult: - """Result of API key validation containing user and org context.""" + """Result of API key validation containing user and organization info.""" user_id: str - org_id: UUID | None + org_id: UUID | None # None for legacy API keys without org binding key_id: int key_name: str | None @@ -195,7 +195,12 @@ class ApiKeyStore: return api_key async def validate_api_key(self, api_key: str) -> ApiKeyValidationResult | None: - """Validate an API key and return the associated user_id and org_id if valid.""" + """Validate an API key and return the associated user_id and org_id if valid. + + Returns: + ApiKeyValidationResult if the key is valid, None otherwise. + The org_id may be None for legacy API keys that weren't bound to an organization. + """ now = datetime.now(UTC) async with a_session_maker() as session: diff --git a/enterprise/storage/saas_conversation_validator.py b/enterprise/storage/saas_conversation_validator.py index 51a5302dfc..6493fdd602 100644 --- a/enterprise/storage/saas_conversation_validator.py +++ b/enterprise/storage/saas_conversation_validator.py @@ -15,13 +15,13 @@ class SaasConversationValidator(ConversationValidator): async def _validate_api_key(self, api_key: str) -> str | None: """ - Validate an API key and return the user_id and github_user_id if valid. + Validate an API key and return the user_id if valid. Args: api_key: The API key to validate Returns: - A tuple of (user_id, github_user_id) if the API key is valid, None otherwise + The user_id if the API key is valid, None otherwise """ try: token_manager = TokenManager() diff --git a/enterprise/tests/unit/test_api_key_store.py b/enterprise/tests/unit/test_api_key_store.py index baffe5893c..c57f63f2ae 100644 --- a/enterprise/tests/unit/test_api_key_store.py +++ b/enterprise/tests/unit/test_api_key_store.py @@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from sqlalchemy import select from storage.api_key import ApiKey -from storage.api_key_store import ApiKeyStore +from storage.api_key_store import ApiKeyStore, ApiKeyValidationResult @pytest.fixture @@ -110,8 +110,8 @@ async def test_create_api_key( @pytest.mark.asyncio async def test_validate_api_key_valid(api_key_store, async_session_maker): - """Test validating a valid API key.""" - # Setup - create an API key in the database + """Test validating a valid API key returns user_id and org_id.""" + # Arrange user_id = str(uuid.uuid4()) org_id = uuid.uuid4() api_key_value = 'test-api-key' @@ -128,11 +128,12 @@ async def test_validate_api_key_valid(api_key_store, async_session_maker): await session.commit() key_id = key_record.id - # Execute - patch a_session_maker to use test's async session maker + # Act with patch('storage.api_key_store.a_session_maker', async_session_maker): result = await api_key_store.validate_api_key(api_key_value) - # Verify - result is now ApiKeyValidationResult + # Assert + assert isinstance(result, ApiKeyValidationResult) assert result is not None assert result.user_id == user_id assert result.org_id == org_id @@ -202,7 +203,7 @@ async def test_validate_api_key_valid_timezone_naive( api_key_store, async_session_maker ): """Test validating a valid API key with timezone-naive datetime from database.""" - # Setup - create a valid API key with timezone-naive datetime (future date) + # Arrange user_id = str(uuid.uuid4()) org_id = uuid.uuid4() api_key_value = 'test-valid-naive-key' @@ -219,13 +220,44 @@ async def test_validate_api_key_valid_timezone_naive( session.add(key_record) await session.commit() - # Execute - patch a_session_maker to use test's async session maker + # Act with patch('storage.api_key_store.a_session_maker', async_session_maker): result = await api_key_store.validate_api_key(api_key_value) - # Verify - result is now ApiKeyValidationResult + # Assert + assert isinstance(result, ApiKeyValidationResult) + assert result.user_id == user_id + assert result.org_id == org_id + + +@pytest.mark.asyncio +async def test_validate_api_key_legacy_without_org_id( + api_key_store, async_session_maker +): + """Test validating a legacy API key without org_id returns None for org_id.""" + # Arrange + user_id = str(uuid.uuid4()) + api_key_value = 'test-legacy-key-no-org' + + async with async_session_maker() as session: + key_record = ApiKey( + key=api_key_value, + user_id=user_id, + org_id=None, # Legacy key without org binding + name='Legacy Key', + ) + session.add(key_record) + await session.commit() + + # Act + with patch('storage.api_key_store.a_session_maker', async_session_maker): + result = await api_key_store.validate_api_key(api_key_value) + + # Assert + assert isinstance(result, ApiKeyValidationResult) assert result is not None assert result.user_id == user_id + assert result.org_id is None @pytest.mark.asyncio diff --git a/enterprise/tests/unit/test_authorization.py b/enterprise/tests/unit/test_authorization.py index c751e6454a..a4051b4824 100644 --- a/enterprise/tests/unit/test_authorization.py +++ b/enterprise/tests/unit/test_authorization.py @@ -13,6 +13,7 @@ from server.auth.authorization import ( ROLE_PERMISSIONS, Permission, RoleName, + get_api_key_org_id_from_request, get_role_permissions, get_user_org_role, has_permission, @@ -444,6 +445,15 @@ class TestGetUserOrgRole: # ============================================================================= +def _create_mock_request(api_key_org_id=None): + """Helper to create a mock request with optional api_key_org_id.""" + mock_request = MagicMock() + mock_user_auth = MagicMock() + mock_user_auth.get_api_key_org_id.return_value = api_key_org_id + mock_request.state.user_auth = mock_user_auth + return mock_request + + class TestRequirePermission: """Tests for require_permission dependency factory.""" @@ -456,6 +466,7 @@ class TestRequirePermission: """ user_id = str(uuid4()) org_id = uuid4() + mock_request = _create_mock_request() mock_role = MagicMock() mock_role.name = 'admin' @@ -465,7 +476,9 @@ class TestRequirePermission: AsyncMock(return_value=mock_role), ): permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) - result = await permission_checker(org_id=org_id, user_id=user_id) + result = await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) assert result == user_id @pytest.mark.asyncio @@ -476,10 +489,11 @@ class TestRequirePermission: THEN: 401 Unauthorized is raised """ org_id = uuid4() + mock_request = _create_mock_request() permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) with pytest.raises(HTTPException) as exc_info: - await permission_checker(org_id=org_id, user_id=None) + await permission_checker(request=mock_request, org_id=org_id, user_id=None) assert exc_info.value.status_code == 401 assert 'not authenticated' in exc_info.value.detail.lower() @@ -493,6 +507,7 @@ class TestRequirePermission: """ user_id = str(uuid4()) org_id = uuid4() + mock_request = _create_mock_request() with patch( 'server.auth.authorization.get_user_org_role', @@ -500,7 +515,9 @@ class TestRequirePermission: ): permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) with pytest.raises(HTTPException) as exc_info: - await permission_checker(org_id=org_id, user_id=user_id) + await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) assert exc_info.value.status_code == 403 assert 'not a member' in exc_info.value.detail.lower() @@ -514,6 +531,7 @@ class TestRequirePermission: """ user_id = str(uuid4()) org_id = uuid4() + mock_request = _create_mock_request() mock_role = MagicMock() mock_role.name = 'member' @@ -524,7 +542,9 @@ class TestRequirePermission: ): permission_checker = require_permission(Permission.DELETE_ORGANIZATION) with pytest.raises(HTTPException) as exc_info: - await permission_checker(org_id=org_id, user_id=user_id) + await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) assert exc_info.value.status_code == 403 assert 'delete_organization' in exc_info.value.detail.lower() @@ -538,6 +558,7 @@ class TestRequirePermission: """ user_id = str(uuid4()) org_id = uuid4() + mock_request = _create_mock_request() mock_role = MagicMock() mock_role.name = 'owner' @@ -547,7 +568,9 @@ class TestRequirePermission: AsyncMock(return_value=mock_role), ): permission_checker = require_permission(Permission.DELETE_ORGANIZATION) - result = await permission_checker(org_id=org_id, user_id=user_id) + result = await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) assert result == user_id @pytest.mark.asyncio @@ -559,6 +582,7 @@ class TestRequirePermission: """ user_id = str(uuid4()) org_id = uuid4() + mock_request = _create_mock_request() mock_role = MagicMock() mock_role.name = 'admin' @@ -569,7 +593,9 @@ class TestRequirePermission: ): permission_checker = require_permission(Permission.DELETE_ORGANIZATION) with pytest.raises(HTTPException) as exc_info: - await permission_checker(org_id=org_id, user_id=user_id) + await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) assert exc_info.value.status_code == 403 @@ -582,6 +608,7 @@ class TestRequirePermission: """ user_id = str(uuid4()) org_id = uuid4() + mock_request = _create_mock_request() mock_role = MagicMock() mock_role.name = 'member' @@ -595,7 +622,9 @@ class TestRequirePermission: ): permission_checker = require_permission(Permission.DELETE_ORGANIZATION) with pytest.raises(HTTPException): - await permission_checker(org_id=org_id, user_id=user_id) + await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) mock_logger.warning.assert_called() call_args = mock_logger.warning.call_args @@ -611,6 +640,7 @@ class TestRequirePermission: THEN: User ID is returned """ user_id = str(uuid4()) + mock_request = _create_mock_request() mock_role = MagicMock() mock_role.name = 'admin' @@ -620,7 +650,9 @@ class TestRequirePermission: AsyncMock(return_value=mock_role), ) as mock_get_role: permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) - result = await permission_checker(org_id=None, user_id=user_id) + result = await permission_checker( + request=mock_request, org_id=None, user_id=user_id + ) assert result == user_id mock_get_role.assert_called_once_with(user_id, None) @@ -632,6 +664,7 @@ class TestRequirePermission: THEN: HTTPException with 403 status is raised """ user_id = str(uuid4()) + mock_request = _create_mock_request() with patch( 'server.auth.authorization.get_user_org_role', @@ -639,7 +672,9 @@ class TestRequirePermission: ): permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) with pytest.raises(HTTPException) as exc_info: - await permission_checker(org_id=None, user_id=user_id) + await permission_checker( + request=mock_request, org_id=None, user_id=user_id + ) assert exc_info.value.status_code == 403 assert 'not a member' in exc_info.value.detail @@ -662,6 +697,7 @@ class TestPermissionScenarios: """ user_id = str(uuid4()) org_id = uuid4() + mock_request = _create_mock_request() mock_role = MagicMock() mock_role.name = 'member' @@ -671,7 +707,9 @@ class TestPermissionScenarios: AsyncMock(return_value=mock_role), ): permission_checker = require_permission(Permission.MANAGE_SECRETS) - result = await permission_checker(org_id=org_id, user_id=user_id) + result = await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) assert result == user_id @pytest.mark.asyncio @@ -683,6 +721,7 @@ class TestPermissionScenarios: """ user_id = str(uuid4()) org_id = uuid4() + mock_request = _create_mock_request() mock_role = MagicMock() mock_role.name = 'member' @@ -695,7 +734,9 @@ class TestPermissionScenarios: Permission.INVITE_USER_TO_ORGANIZATION ) with pytest.raises(HTTPException) as exc_info: - await permission_checker(org_id=org_id, user_id=user_id) + await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) assert exc_info.value.status_code == 403 @@ -708,6 +749,7 @@ class TestPermissionScenarios: """ user_id = str(uuid4()) org_id = uuid4() + mock_request = _create_mock_request() mock_role = MagicMock() mock_role.name = 'admin' @@ -719,7 +761,9 @@ class TestPermissionScenarios: permission_checker = require_permission( Permission.INVITE_USER_TO_ORGANIZATION ) - result = await permission_checker(org_id=org_id, user_id=user_id) + result = await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) assert result == user_id @pytest.mark.asyncio @@ -731,6 +775,7 @@ class TestPermissionScenarios: """ user_id = str(uuid4()) org_id = uuid4() + mock_request = _create_mock_request() mock_role = MagicMock() mock_role.name = 'admin' @@ -741,7 +786,9 @@ class TestPermissionScenarios: ): permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER) with pytest.raises(HTTPException) as exc_info: - await permission_checker(org_id=org_id, user_id=user_id) + await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) assert exc_info.value.status_code == 403 @@ -754,6 +801,7 @@ class TestPermissionScenarios: """ user_id = str(uuid4()) org_id = uuid4() + mock_request = _create_mock_request() mock_role = MagicMock() mock_role.name = 'owner' @@ -763,5 +811,200 @@ class TestPermissionScenarios: AsyncMock(return_value=mock_role), ): permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER) - result = await permission_checker(org_id=org_id, user_id=user_id) + result = await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) assert result == user_id + + +# ============================================================================= +# Tests for API key organization validation +# ============================================================================= + + +class TestApiKeyOrgValidation: + """Tests for API key organization binding validation in require_permission.""" + + @pytest.mark.asyncio + async def test_allows_access_when_api_key_org_matches_target_org(self): + """ + GIVEN: API key with org_id that matches the target org_id in the request + WHEN: Permission checker is called + THEN: User ID is returned (access allowed) + """ + # Arrange + user_id = str(uuid4()) + org_id = uuid4() + mock_request = _create_mock_request(api_key_org_id=org_id) + + mock_role = MagicMock() + mock_role.name = 'admin' + + # Act & Assert + with patch( + 'server.auth.authorization.get_user_org_role', + AsyncMock(return_value=mock_role), + ): + permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) + result = await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) + assert result == user_id + + @pytest.mark.asyncio + async def test_denies_access_when_api_key_org_mismatches_target_org(self): + """ + GIVEN: API key created for Org A, but user tries to access Org B + WHEN: Permission checker is called + THEN: 403 Forbidden is raised with org mismatch message + """ + # Arrange + user_id = str(uuid4()) + api_key_org_id = uuid4() # Org A - where API key was created + target_org_id = uuid4() # Org B - where user is trying to access + mock_request = _create_mock_request(api_key_org_id=api_key_org_id) + + # Act & Assert + permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) + with pytest.raises(HTTPException) as exc_info: + await permission_checker( + request=mock_request, org_id=target_org_id, user_id=user_id + ) + + assert exc_info.value.status_code == 403 + assert ( + 'API key is not authorized for this organization' in exc_info.value.detail + ) + + @pytest.mark.asyncio + async def test_allows_access_for_legacy_api_key_without_org_binding(self): + """ + GIVEN: Legacy API key without org_id binding (org_id is None) + WHEN: Permission checker is called + THEN: Falls through to normal permission check (backward compatible) + """ + # Arrange + user_id = str(uuid4()) + org_id = uuid4() + mock_request = _create_mock_request(api_key_org_id=None) + + mock_role = MagicMock() + mock_role.name = 'admin' + + # Act & Assert + with patch( + 'server.auth.authorization.get_user_org_role', + AsyncMock(return_value=mock_role), + ): + permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) + result = await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) + assert result == user_id + + @pytest.mark.asyncio + async def test_allows_access_for_cookie_auth_without_api_key_org_id(self): + """ + GIVEN: Cookie-based authentication (no api_key_org_id in user_auth) + WHEN: Permission checker is called + THEN: Falls through to normal permission check + """ + # Arrange + user_id = str(uuid4()) + org_id = uuid4() + mock_request = _create_mock_request(api_key_org_id=None) + + mock_role = MagicMock() + mock_role.name = 'admin' + + # Act & Assert + with patch( + 'server.auth.authorization.get_user_org_role', + AsyncMock(return_value=mock_role), + ): + permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) + result = await permission_checker( + request=mock_request, org_id=org_id, user_id=user_id + ) + assert result == user_id + + @pytest.mark.asyncio + async def test_logs_warning_on_api_key_org_mismatch(self): + """ + GIVEN: API key org_id doesn't match target org_id + WHEN: Permission checker is called + THEN: Warning is logged with org mismatch details + """ + # Arrange + user_id = str(uuid4()) + api_key_org_id = uuid4() + target_org_id = uuid4() + mock_request = _create_mock_request(api_key_org_id=api_key_org_id) + + # Act & Assert + with patch('server.auth.authorization.logger') as mock_logger: + permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) + with pytest.raises(HTTPException): + await permission_checker( + request=mock_request, org_id=target_org_id, user_id=user_id + ) + + mock_logger.warning.assert_called() + call_args = mock_logger.warning.call_args + assert call_args[1]['extra']['user_id'] == user_id + assert call_args[1]['extra']['api_key_org_id'] == str(api_key_org_id) + assert call_args[1]['extra']['target_org_id'] == str(target_org_id) + + +class TestGetApiKeyOrgIdFromRequest: + """Tests for get_api_key_org_id_from_request helper function.""" + + @pytest.mark.asyncio + async def test_returns_org_id_when_user_auth_has_api_key_org_id(self): + """ + GIVEN: Request with user_auth that has api_key_org_id + WHEN: get_api_key_org_id_from_request is called + THEN: Returns the api_key_org_id + """ + # Arrange + org_id = uuid4() + mock_request = _create_mock_request(api_key_org_id=org_id) + + # Act + result = await get_api_key_org_id_from_request(mock_request) + + # Assert + assert result == org_id + + @pytest.mark.asyncio + async def test_returns_none_when_user_auth_has_no_api_key_org_id(self): + """ + GIVEN: Request with user_auth that has no api_key_org_id (cookie auth) + WHEN: get_api_key_org_id_from_request is called + THEN: Returns None + """ + # Arrange + mock_request = _create_mock_request(api_key_org_id=None) + + # Act + result = await get_api_key_org_id_from_request(mock_request) + + # Assert + assert result is None + + @pytest.mark.asyncio + async def test_returns_none_when_no_user_auth_in_request(self): + """ + GIVEN: Request without user_auth in state + WHEN: get_api_key_org_id_from_request is called + THEN: Returns None + """ + # Arrange + mock_request = MagicMock() + mock_request.state.user_auth = None + + # Act + result = await get_api_key_org_id_from_request(mock_request) + + # Assert + assert result is None diff --git a/enterprise/tests/unit/test_saas_user_auth.py b/enterprise/tests/unit/test_saas_user_auth.py index 2fb1b68445..726702f310 100644 --- a/enterprise/tests/unit/test_saas_user_auth.py +++ b/enterprise/tests/unit/test_saas_user_auth.py @@ -459,7 +459,8 @@ async def test_get_instance_no_auth(mock_request): @pytest.mark.asyncio async def test_saas_user_auth_from_bearer_success(): - """Test successful authentication from bearer token.""" + """Test successful authentication from bearer token sets user_id and api_key_org_id.""" + # Arrange mock_request = MagicMock() mock_request.headers = {'Authorization': 'Bearer test_api_key'}