fix(backend): validate API key org_id during authorization to prevent cross-org access (org project) (#13468)

This commit is contained in:
Hiep Le
2026-03-19 16:09:37 +07:00
committed by GitHub
parent 8039807c3f
commit e02dbb8974
7 changed files with 354 additions and 32 deletions

View File

@@ -35,7 +35,7 @@ Usage:
from enum import Enum from enum import Enum
from uuid import UUID from uuid import UUID
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, Request, status
from storage.org_member_store import OrgMemberStore from storage.org_member_store import OrgMemberStore
from storage.role import Role from storage.role import Role
from storage.role_store import RoleStore from storage.role_store import RoleStore
@@ -214,6 +214,19 @@ def has_permission(user_role: Role, permission: Permission) -> bool:
return permission in permissions return permission in permissions
async def get_api_key_org_id_from_request(request: Request) -> UUID | None:
"""Get the org_id bound to the API key used for authentication.
Returns None if:
- Not authenticated via API key (cookie auth)
- API key is a legacy key without org binding
"""
user_auth = getattr(request.state, 'user_auth', None)
if user_auth and hasattr(user_auth, 'get_api_key_org_id'):
return user_auth.get_api_key_org_id()
return None
def require_permission(permission: Permission): def require_permission(permission: Permission):
""" """
Factory function that creates a dependency to require a specific permission. Factory function that creates a dependency to require a specific permission.
@@ -221,8 +234,9 @@ def require_permission(permission: Permission):
This creates a FastAPI dependency that: This creates a FastAPI dependency that:
1. Extracts org_id from the path parameter 1. Extracts org_id from the path parameter
2. Gets the authenticated user_id 2. Gets the authenticated user_id
3. Checks if the user has the required permission in the organization 3. Validates API key org binding (if using API key auth)
4. Returns the user_id if authorized, raises HTTPException otherwise 4. Checks if the user has the required permission in the organization
5. Returns the user_id if authorized, raises HTTPException otherwise
Usage: Usage:
@router.get('/{org_id}/settings') @router.get('/{org_id}/settings')
@@ -240,6 +254,7 @@ def require_permission(permission: Permission):
""" """
async def permission_checker( async def permission_checker(
request: Request,
org_id: UUID | None = None, org_id: UUID | None = None,
user_id: str | None = Depends(get_user_id), user_id: str | None = Depends(get_user_id),
) -> str: ) -> str:
@@ -249,6 +264,23 @@ def require_permission(permission: Permission):
detail='User not authenticated', detail='User not authenticated',
) )
# Validate API key organization binding
api_key_org_id = await get_api_key_org_id_from_request(request)
if api_key_org_id is not None and org_id is not None:
if api_key_org_id != org_id:
logger.warning(
'API key organization mismatch',
extra={
'user_id': user_id,
'api_key_org_id': str(api_key_org_id),
'target_org_id': str(org_id),
},
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail='API key is not authorized for this organization',
)
user_role = await get_user_org_role(user_id, org_id) user_role = await get_user_org_role(user_id, org_id)
if not user_role: if not user_role:

View File

@@ -61,10 +61,19 @@ class SaasUserAuth(UserAuth):
accepted_tos: bool | None = None accepted_tos: bool | None = None
auth_type: AuthType = AuthType.COOKIE auth_type: AuthType = AuthType.COOKIE
# API key context fields - populated when authenticated via API key # API key context fields - populated when authenticated via API key
api_key_org_id: UUID | None = None api_key_org_id: UUID | None = None # Org bound to the API key used for auth
api_key_id: int | None = None api_key_id: int | None = None
api_key_name: str | None = None api_key_name: str | None = None
def get_api_key_org_id(self) -> UUID | None:
"""Get the organization ID bound to the API key used for authentication.
Returns:
The org_id if authenticated via API key with org binding, None otherwise
(cookie auth or legacy API keys without org binding).
"""
return self.api_key_org_id
async def get_user_id(self) -> str | None: async def get_user_id(self) -> str | None:
return self.user_id return self.user_id

View File

@@ -16,10 +16,10 @@ from openhands.core.logger import openhands_logger as logger
@dataclass @dataclass
class ApiKeyValidationResult: class ApiKeyValidationResult:
"""Result of API key validation containing user and org context.""" """Result of API key validation containing user and organization info."""
user_id: str user_id: str
org_id: UUID | None org_id: UUID | None # None for legacy API keys without org binding
key_id: int key_id: int
key_name: str | None key_name: str | None
@@ -195,7 +195,12 @@ class ApiKeyStore:
return api_key return api_key
async def validate_api_key(self, api_key: str) -> ApiKeyValidationResult | None: async def validate_api_key(self, api_key: str) -> ApiKeyValidationResult | None:
"""Validate an API key and return the associated user_id and org_id if valid.""" """Validate an API key and return the associated user_id and org_id if valid.
Returns:
ApiKeyValidationResult if the key is valid, None otherwise.
The org_id may be None for legacy API keys that weren't bound to an organization.
"""
now = datetime.now(UTC) now = datetime.now(UTC)
async with a_session_maker() as session: async with a_session_maker() as session:

View File

@@ -15,13 +15,13 @@ class SaasConversationValidator(ConversationValidator):
async def _validate_api_key(self, api_key: str) -> str | None: async def _validate_api_key(self, api_key: str) -> str | None:
""" """
Validate an API key and return the user_id and github_user_id if valid. Validate an API key and return the user_id if valid.
Args: Args:
api_key: The API key to validate api_key: The API key to validate
Returns: Returns:
A tuple of (user_id, github_user_id) if the API key is valid, None otherwise The user_id if the API key is valid, None otherwise
""" """
try: try:
token_manager = TokenManager() token_manager = TokenManager()

View File

@@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from sqlalchemy import select from sqlalchemy import select
from storage.api_key import ApiKey from storage.api_key import ApiKey
from storage.api_key_store import ApiKeyStore from storage.api_key_store import ApiKeyStore, ApiKeyValidationResult
@pytest.fixture @pytest.fixture
@@ -110,8 +110,8 @@ async def test_create_api_key(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_validate_api_key_valid(api_key_store, async_session_maker): async def test_validate_api_key_valid(api_key_store, async_session_maker):
"""Test validating a valid API key.""" """Test validating a valid API key returns user_id and org_id."""
# Setup - create an API key in the database # Arrange
user_id = str(uuid.uuid4()) user_id = str(uuid.uuid4())
org_id = uuid.uuid4() org_id = uuid.uuid4()
api_key_value = 'test-api-key' api_key_value = 'test-api-key'
@@ -128,11 +128,12 @@ async def test_validate_api_key_valid(api_key_store, async_session_maker):
await session.commit() await session.commit()
key_id = key_record.id key_id = key_record.id
# Execute - patch a_session_maker to use test's async session maker # Act
with patch('storage.api_key_store.a_session_maker', async_session_maker): with patch('storage.api_key_store.a_session_maker', async_session_maker):
result = await api_key_store.validate_api_key(api_key_value) result = await api_key_store.validate_api_key(api_key_value)
# Verify - result is now ApiKeyValidationResult # Assert
assert isinstance(result, ApiKeyValidationResult)
assert result is not None assert result is not None
assert result.user_id == user_id assert result.user_id == user_id
assert result.org_id == org_id assert result.org_id == org_id
@@ -202,7 +203,7 @@ async def test_validate_api_key_valid_timezone_naive(
api_key_store, async_session_maker api_key_store, async_session_maker
): ):
"""Test validating a valid API key with timezone-naive datetime from database.""" """Test validating a valid API key with timezone-naive datetime from database."""
# Setup - create a valid API key with timezone-naive datetime (future date) # Arrange
user_id = str(uuid.uuid4()) user_id = str(uuid.uuid4())
org_id = uuid.uuid4() org_id = uuid.uuid4()
api_key_value = 'test-valid-naive-key' api_key_value = 'test-valid-naive-key'
@@ -219,13 +220,44 @@ async def test_validate_api_key_valid_timezone_naive(
session.add(key_record) session.add(key_record)
await session.commit() await session.commit()
# Execute - patch a_session_maker to use test's async session maker # Act
with patch('storage.api_key_store.a_session_maker', async_session_maker): with patch('storage.api_key_store.a_session_maker', async_session_maker):
result = await api_key_store.validate_api_key(api_key_value) result = await api_key_store.validate_api_key(api_key_value)
# Verify - result is now ApiKeyValidationResult # Assert
assert isinstance(result, ApiKeyValidationResult)
assert result.user_id == user_id
assert result.org_id == org_id
@pytest.mark.asyncio
async def test_validate_api_key_legacy_without_org_id(
api_key_store, async_session_maker
):
"""Test validating a legacy API key without org_id returns None for org_id."""
# Arrange
user_id = str(uuid.uuid4())
api_key_value = 'test-legacy-key-no-org'
async with async_session_maker() as session:
key_record = ApiKey(
key=api_key_value,
user_id=user_id,
org_id=None, # Legacy key without org binding
name='Legacy Key',
)
session.add(key_record)
await session.commit()
# Act
with patch('storage.api_key_store.a_session_maker', async_session_maker):
result = await api_key_store.validate_api_key(api_key_value)
# Assert
assert isinstance(result, ApiKeyValidationResult)
assert result is not None assert result is not None
assert result.user_id == user_id assert result.user_id == user_id
assert result.org_id is None
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@@ -13,6 +13,7 @@ from server.auth.authorization import (
ROLE_PERMISSIONS, ROLE_PERMISSIONS,
Permission, Permission,
RoleName, RoleName,
get_api_key_org_id_from_request,
get_role_permissions, get_role_permissions,
get_user_org_role, get_user_org_role,
has_permission, has_permission,
@@ -444,6 +445,15 @@ class TestGetUserOrgRole:
# ============================================================================= # =============================================================================
def _create_mock_request(api_key_org_id=None):
"""Helper to create a mock request with optional api_key_org_id."""
mock_request = MagicMock()
mock_user_auth = MagicMock()
mock_user_auth.get_api_key_org_id.return_value = api_key_org_id
mock_request.state.user_auth = mock_user_auth
return mock_request
class TestRequirePermission: class TestRequirePermission:
"""Tests for require_permission dependency factory.""" """Tests for require_permission dependency factory."""
@@ -456,6 +466,7 @@ class TestRequirePermission:
""" """
user_id = str(uuid4()) user_id = str(uuid4())
org_id = uuid4() org_id = uuid4()
mock_request = _create_mock_request()
mock_role = MagicMock() mock_role = MagicMock()
mock_role.name = 'admin' mock_role.name = 'admin'
@@ -465,7 +476,9 @@ class TestRequirePermission:
AsyncMock(return_value=mock_role), AsyncMock(return_value=mock_role),
): ):
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
result = await permission_checker(org_id=org_id, user_id=user_id) result = await permission_checker(
request=mock_request, org_id=org_id, user_id=user_id
)
assert result == user_id assert result == user_id
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -476,10 +489,11 @@ class TestRequirePermission:
THEN: 401 Unauthorized is raised THEN: 401 Unauthorized is raised
""" """
org_id = uuid4() org_id = uuid4()
mock_request = _create_mock_request()
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=org_id, user_id=None) await permission_checker(request=mock_request, org_id=org_id, user_id=None)
assert exc_info.value.status_code == 401 assert exc_info.value.status_code == 401
assert 'not authenticated' in exc_info.value.detail.lower() assert 'not authenticated' in exc_info.value.detail.lower()
@@ -493,6 +507,7 @@ class TestRequirePermission:
""" """
user_id = str(uuid4()) user_id = str(uuid4())
org_id = uuid4() org_id = uuid4()
mock_request = _create_mock_request()
with patch( with patch(
'server.auth.authorization.get_user_org_role', 'server.auth.authorization.get_user_org_role',
@@ -500,7 +515,9 @@ class TestRequirePermission:
): ):
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=org_id, user_id=user_id) await permission_checker(
request=mock_request, org_id=org_id, user_id=user_id
)
assert exc_info.value.status_code == 403 assert exc_info.value.status_code == 403
assert 'not a member' in exc_info.value.detail.lower() assert 'not a member' in exc_info.value.detail.lower()
@@ -514,6 +531,7 @@ class TestRequirePermission:
""" """
user_id = str(uuid4()) user_id = str(uuid4())
org_id = uuid4() org_id = uuid4()
mock_request = _create_mock_request()
mock_role = MagicMock() mock_role = MagicMock()
mock_role.name = 'member' mock_role.name = 'member'
@@ -524,7 +542,9 @@ class TestRequirePermission:
): ):
permission_checker = require_permission(Permission.DELETE_ORGANIZATION) permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=org_id, user_id=user_id) await permission_checker(
request=mock_request, org_id=org_id, user_id=user_id
)
assert exc_info.value.status_code == 403 assert exc_info.value.status_code == 403
assert 'delete_organization' in exc_info.value.detail.lower() assert 'delete_organization' in exc_info.value.detail.lower()
@@ -538,6 +558,7 @@ class TestRequirePermission:
""" """
user_id = str(uuid4()) user_id = str(uuid4())
org_id = uuid4() org_id = uuid4()
mock_request = _create_mock_request()
mock_role = MagicMock() mock_role = MagicMock()
mock_role.name = 'owner' mock_role.name = 'owner'
@@ -547,7 +568,9 @@ class TestRequirePermission:
AsyncMock(return_value=mock_role), AsyncMock(return_value=mock_role),
): ):
permission_checker = require_permission(Permission.DELETE_ORGANIZATION) permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
result = await permission_checker(org_id=org_id, user_id=user_id) result = await permission_checker(
request=mock_request, org_id=org_id, user_id=user_id
)
assert result == user_id assert result == user_id
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -559,6 +582,7 @@ class TestRequirePermission:
""" """
user_id = str(uuid4()) user_id = str(uuid4())
org_id = uuid4() org_id = uuid4()
mock_request = _create_mock_request()
mock_role = MagicMock() mock_role = MagicMock()
mock_role.name = 'admin' mock_role.name = 'admin'
@@ -569,7 +593,9 @@ class TestRequirePermission:
): ):
permission_checker = require_permission(Permission.DELETE_ORGANIZATION) permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=org_id, user_id=user_id) await permission_checker(
request=mock_request, org_id=org_id, user_id=user_id
)
assert exc_info.value.status_code == 403 assert exc_info.value.status_code == 403
@@ -582,6 +608,7 @@ class TestRequirePermission:
""" """
user_id = str(uuid4()) user_id = str(uuid4())
org_id = uuid4() org_id = uuid4()
mock_request = _create_mock_request()
mock_role = MagicMock() mock_role = MagicMock()
mock_role.name = 'member' mock_role.name = 'member'
@@ -595,7 +622,9 @@ class TestRequirePermission:
): ):
permission_checker = require_permission(Permission.DELETE_ORGANIZATION) permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
with pytest.raises(HTTPException): with pytest.raises(HTTPException):
await permission_checker(org_id=org_id, user_id=user_id) await permission_checker(
request=mock_request, org_id=org_id, user_id=user_id
)
mock_logger.warning.assert_called() mock_logger.warning.assert_called()
call_args = mock_logger.warning.call_args call_args = mock_logger.warning.call_args
@@ -611,6 +640,7 @@ class TestRequirePermission:
THEN: User ID is returned THEN: User ID is returned
""" """
user_id = str(uuid4()) user_id = str(uuid4())
mock_request = _create_mock_request()
mock_role = MagicMock() mock_role = MagicMock()
mock_role.name = 'admin' mock_role.name = 'admin'
@@ -620,7 +650,9 @@ class TestRequirePermission:
AsyncMock(return_value=mock_role), AsyncMock(return_value=mock_role),
) as mock_get_role: ) as mock_get_role:
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
result = await permission_checker(org_id=None, user_id=user_id) result = await permission_checker(
request=mock_request, org_id=None, user_id=user_id
)
assert result == user_id assert result == user_id
mock_get_role.assert_called_once_with(user_id, None) mock_get_role.assert_called_once_with(user_id, None)
@@ -632,6 +664,7 @@ class TestRequirePermission:
THEN: HTTPException with 403 status is raised THEN: HTTPException with 403 status is raised
""" """
user_id = str(uuid4()) user_id = str(uuid4())
mock_request = _create_mock_request()
with patch( with patch(
'server.auth.authorization.get_user_org_role', 'server.auth.authorization.get_user_org_role',
@@ -639,7 +672,9 @@ class TestRequirePermission:
): ):
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS) permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=None, user_id=user_id) await permission_checker(
request=mock_request, org_id=None, user_id=user_id
)
assert exc_info.value.status_code == 403 assert exc_info.value.status_code == 403
assert 'not a member' in exc_info.value.detail assert 'not a member' in exc_info.value.detail
@@ -662,6 +697,7 @@ class TestPermissionScenarios:
""" """
user_id = str(uuid4()) user_id = str(uuid4())
org_id = uuid4() org_id = uuid4()
mock_request = _create_mock_request()
mock_role = MagicMock() mock_role = MagicMock()
mock_role.name = 'member' mock_role.name = 'member'
@@ -671,7 +707,9 @@ class TestPermissionScenarios:
AsyncMock(return_value=mock_role), AsyncMock(return_value=mock_role),
): ):
permission_checker = require_permission(Permission.MANAGE_SECRETS) permission_checker = require_permission(Permission.MANAGE_SECRETS)
result = await permission_checker(org_id=org_id, user_id=user_id) result = await permission_checker(
request=mock_request, org_id=org_id, user_id=user_id
)
assert result == user_id assert result == user_id
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -683,6 +721,7 @@ class TestPermissionScenarios:
""" """
user_id = str(uuid4()) user_id = str(uuid4())
org_id = uuid4() org_id = uuid4()
mock_request = _create_mock_request()
mock_role = MagicMock() mock_role = MagicMock()
mock_role.name = 'member' mock_role.name = 'member'
@@ -695,7 +734,9 @@ class TestPermissionScenarios:
Permission.INVITE_USER_TO_ORGANIZATION Permission.INVITE_USER_TO_ORGANIZATION
) )
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=org_id, user_id=user_id) await permission_checker(
request=mock_request, org_id=org_id, user_id=user_id
)
assert exc_info.value.status_code == 403 assert exc_info.value.status_code == 403
@@ -708,6 +749,7 @@ class TestPermissionScenarios:
""" """
user_id = str(uuid4()) user_id = str(uuid4())
org_id = uuid4() org_id = uuid4()
mock_request = _create_mock_request()
mock_role = MagicMock() mock_role = MagicMock()
mock_role.name = 'admin' mock_role.name = 'admin'
@@ -719,7 +761,9 @@ class TestPermissionScenarios:
permission_checker = require_permission( permission_checker = require_permission(
Permission.INVITE_USER_TO_ORGANIZATION Permission.INVITE_USER_TO_ORGANIZATION
) )
result = await permission_checker(org_id=org_id, user_id=user_id) result = await permission_checker(
request=mock_request, org_id=org_id, user_id=user_id
)
assert result == user_id assert result == user_id
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -731,6 +775,7 @@ class TestPermissionScenarios:
""" """
user_id = str(uuid4()) user_id = str(uuid4())
org_id = uuid4() org_id = uuid4()
mock_request = _create_mock_request()
mock_role = MagicMock() mock_role = MagicMock()
mock_role.name = 'admin' mock_role.name = 'admin'
@@ -741,7 +786,9 @@ class TestPermissionScenarios:
): ):
permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER) permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER)
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=org_id, user_id=user_id) await permission_checker(
request=mock_request, org_id=org_id, user_id=user_id
)
assert exc_info.value.status_code == 403 assert exc_info.value.status_code == 403
@@ -754,6 +801,7 @@ class TestPermissionScenarios:
""" """
user_id = str(uuid4()) user_id = str(uuid4())
org_id = uuid4() org_id = uuid4()
mock_request = _create_mock_request()
mock_role = MagicMock() mock_role = MagicMock()
mock_role.name = 'owner' mock_role.name = 'owner'
@@ -763,5 +811,200 @@ class TestPermissionScenarios:
AsyncMock(return_value=mock_role), AsyncMock(return_value=mock_role),
): ):
permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER) permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER)
result = await permission_checker(org_id=org_id, user_id=user_id) result = await permission_checker(
request=mock_request, org_id=org_id, user_id=user_id
)
assert result == user_id assert result == user_id
# =============================================================================
# Tests for API key organization validation
# =============================================================================
class TestApiKeyOrgValidation:
"""Tests for API key organization binding validation in require_permission."""
@pytest.mark.asyncio
async def test_allows_access_when_api_key_org_matches_target_org(self):
"""
GIVEN: API key with org_id that matches the target org_id in the request
WHEN: Permission checker is called
THEN: User ID is returned (access allowed)
"""
# Arrange
user_id = str(uuid4())
org_id = uuid4()
mock_request = _create_mock_request(api_key_org_id=org_id)
mock_role = MagicMock()
mock_role.name = 'admin'
# Act & Assert
with patch(
'server.auth.authorization.get_user_org_role',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
result = await permission_checker(
request=mock_request, org_id=org_id, user_id=user_id
)
assert result == user_id
@pytest.mark.asyncio
async def test_denies_access_when_api_key_org_mismatches_target_org(self):
"""
GIVEN: API key created for Org A, but user tries to access Org B
WHEN: Permission checker is called
THEN: 403 Forbidden is raised with org mismatch message
"""
# Arrange
user_id = str(uuid4())
api_key_org_id = uuid4() # Org A - where API key was created
target_org_id = uuid4() # Org B - where user is trying to access
mock_request = _create_mock_request(api_key_org_id=api_key_org_id)
# Act & Assert
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
with pytest.raises(HTTPException) as exc_info:
await permission_checker(
request=mock_request, org_id=target_org_id, user_id=user_id
)
assert exc_info.value.status_code == 403
assert (
'API key is not authorized for this organization' in exc_info.value.detail
)
@pytest.mark.asyncio
async def test_allows_access_for_legacy_api_key_without_org_binding(self):
"""
GIVEN: Legacy API key without org_id binding (org_id is None)
WHEN: Permission checker is called
THEN: Falls through to normal permission check (backward compatible)
"""
# Arrange
user_id = str(uuid4())
org_id = uuid4()
mock_request = _create_mock_request(api_key_org_id=None)
mock_role = MagicMock()
mock_role.name = 'admin'
# Act & Assert
with patch(
'server.auth.authorization.get_user_org_role',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
result = await permission_checker(
request=mock_request, org_id=org_id, user_id=user_id
)
assert result == user_id
@pytest.mark.asyncio
async def test_allows_access_for_cookie_auth_without_api_key_org_id(self):
"""
GIVEN: Cookie-based authentication (no api_key_org_id in user_auth)
WHEN: Permission checker is called
THEN: Falls through to normal permission check
"""
# Arrange
user_id = str(uuid4())
org_id = uuid4()
mock_request = _create_mock_request(api_key_org_id=None)
mock_role = MagicMock()
mock_role.name = 'admin'
# Act & Assert
with patch(
'server.auth.authorization.get_user_org_role',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
result = await permission_checker(
request=mock_request, org_id=org_id, user_id=user_id
)
assert result == user_id
@pytest.mark.asyncio
async def test_logs_warning_on_api_key_org_mismatch(self):
"""
GIVEN: API key org_id doesn't match target org_id
WHEN: Permission checker is called
THEN: Warning is logged with org mismatch details
"""
# Arrange
user_id = str(uuid4())
api_key_org_id = uuid4()
target_org_id = uuid4()
mock_request = _create_mock_request(api_key_org_id=api_key_org_id)
# Act & Assert
with patch('server.auth.authorization.logger') as mock_logger:
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
with pytest.raises(HTTPException):
await permission_checker(
request=mock_request, org_id=target_org_id, user_id=user_id
)
mock_logger.warning.assert_called()
call_args = mock_logger.warning.call_args
assert call_args[1]['extra']['user_id'] == user_id
assert call_args[1]['extra']['api_key_org_id'] == str(api_key_org_id)
assert call_args[1]['extra']['target_org_id'] == str(target_org_id)
class TestGetApiKeyOrgIdFromRequest:
"""Tests for get_api_key_org_id_from_request helper function."""
@pytest.mark.asyncio
async def test_returns_org_id_when_user_auth_has_api_key_org_id(self):
"""
GIVEN: Request with user_auth that has api_key_org_id
WHEN: get_api_key_org_id_from_request is called
THEN: Returns the api_key_org_id
"""
# Arrange
org_id = uuid4()
mock_request = _create_mock_request(api_key_org_id=org_id)
# Act
result = await get_api_key_org_id_from_request(mock_request)
# Assert
assert result == org_id
@pytest.mark.asyncio
async def test_returns_none_when_user_auth_has_no_api_key_org_id(self):
"""
GIVEN: Request with user_auth that has no api_key_org_id (cookie auth)
WHEN: get_api_key_org_id_from_request is called
THEN: Returns None
"""
# Arrange
mock_request = _create_mock_request(api_key_org_id=None)
# Act
result = await get_api_key_org_id_from_request(mock_request)
# Assert
assert result is None
@pytest.mark.asyncio
async def test_returns_none_when_no_user_auth_in_request(self):
"""
GIVEN: Request without user_auth in state
WHEN: get_api_key_org_id_from_request is called
THEN: Returns None
"""
# Arrange
mock_request = MagicMock()
mock_request.state.user_auth = None
# Act
result = await get_api_key_org_id_from_request(mock_request)
# Assert
assert result is None

View File

@@ -459,7 +459,8 @@ async def test_get_instance_no_auth(mock_request):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_saas_user_auth_from_bearer_success(): async def test_saas_user_auth_from_bearer_success():
"""Test successful authentication from bearer token.""" """Test successful authentication from bearer token sets user_id and api_key_org_id."""
# Arrange
mock_request = MagicMock() mock_request = MagicMock()
mock_request.headers = {'Authorization': 'Bearer test_api_key'} mock_request.headers = {'Authorization': 'Bearer test_api_key'}