mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
fix(backend): validate API key org_id during authorization to prevent cross-org access (org project) (#13468)
This commit is contained in:
@@ -35,7 +35,7 @@ Usage:
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, status
|
from fastapi import Depends, HTTPException, Request, status
|
||||||
from storage.org_member_store import OrgMemberStore
|
from storage.org_member_store import OrgMemberStore
|
||||||
from storage.role import Role
|
from storage.role import Role
|
||||||
from storage.role_store import RoleStore
|
from storage.role_store import RoleStore
|
||||||
@@ -214,6 +214,19 @@ def has_permission(user_role: Role, permission: Permission) -> bool:
|
|||||||
return permission in permissions
|
return permission in permissions
|
||||||
|
|
||||||
|
|
||||||
|
async def get_api_key_org_id_from_request(request: Request) -> UUID | None:
|
||||||
|
"""Get the org_id bound to the API key used for authentication.
|
||||||
|
|
||||||
|
Returns None if:
|
||||||
|
- Not authenticated via API key (cookie auth)
|
||||||
|
- API key is a legacy key without org binding
|
||||||
|
"""
|
||||||
|
user_auth = getattr(request.state, 'user_auth', None)
|
||||||
|
if user_auth and hasattr(user_auth, 'get_api_key_org_id'):
|
||||||
|
return user_auth.get_api_key_org_id()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def require_permission(permission: Permission):
|
def require_permission(permission: Permission):
|
||||||
"""
|
"""
|
||||||
Factory function that creates a dependency to require a specific permission.
|
Factory function that creates a dependency to require a specific permission.
|
||||||
@@ -221,8 +234,9 @@ def require_permission(permission: Permission):
|
|||||||
This creates a FastAPI dependency that:
|
This creates a FastAPI dependency that:
|
||||||
1. Extracts org_id from the path parameter
|
1. Extracts org_id from the path parameter
|
||||||
2. Gets the authenticated user_id
|
2. Gets the authenticated user_id
|
||||||
3. Checks if the user has the required permission in the organization
|
3. Validates API key org binding (if using API key auth)
|
||||||
4. Returns the user_id if authorized, raises HTTPException otherwise
|
4. Checks if the user has the required permission in the organization
|
||||||
|
5. Returns the user_id if authorized, raises HTTPException otherwise
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
@router.get('/{org_id}/settings')
|
@router.get('/{org_id}/settings')
|
||||||
@@ -240,6 +254,7 @@ def require_permission(permission: Permission):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
async def permission_checker(
|
async def permission_checker(
|
||||||
|
request: Request,
|
||||||
org_id: UUID | None = None,
|
org_id: UUID | None = None,
|
||||||
user_id: str | None = Depends(get_user_id),
|
user_id: str | None = Depends(get_user_id),
|
||||||
) -> str:
|
) -> str:
|
||||||
@@ -249,6 +264,23 @@ def require_permission(permission: Permission):
|
|||||||
detail='User not authenticated',
|
detail='User not authenticated',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Validate API key organization binding
|
||||||
|
api_key_org_id = await get_api_key_org_id_from_request(request)
|
||||||
|
if api_key_org_id is not None and org_id is not None:
|
||||||
|
if api_key_org_id != org_id:
|
||||||
|
logger.warning(
|
||||||
|
'API key organization mismatch',
|
||||||
|
extra={
|
||||||
|
'user_id': user_id,
|
||||||
|
'api_key_org_id': str(api_key_org_id),
|
||||||
|
'target_org_id': str(org_id),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail='API key is not authorized for this organization',
|
||||||
|
)
|
||||||
|
|
||||||
user_role = await get_user_org_role(user_id, org_id)
|
user_role = await get_user_org_role(user_id, org_id)
|
||||||
|
|
||||||
if not user_role:
|
if not user_role:
|
||||||
|
|||||||
@@ -61,10 +61,19 @@ class SaasUserAuth(UserAuth):
|
|||||||
accepted_tos: bool | None = None
|
accepted_tos: bool | None = None
|
||||||
auth_type: AuthType = AuthType.COOKIE
|
auth_type: AuthType = AuthType.COOKIE
|
||||||
# API key context fields - populated when authenticated via API key
|
# API key context fields - populated when authenticated via API key
|
||||||
api_key_org_id: UUID | None = None
|
api_key_org_id: UUID | None = None # Org bound to the API key used for auth
|
||||||
api_key_id: int | None = None
|
api_key_id: int | None = None
|
||||||
api_key_name: str | None = None
|
api_key_name: str | None = None
|
||||||
|
|
||||||
|
def get_api_key_org_id(self) -> UUID | None:
|
||||||
|
"""Get the organization ID bound to the API key used for authentication.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The org_id if authenticated via API key with org binding, None otherwise
|
||||||
|
(cookie auth or legacy API keys without org binding).
|
||||||
|
"""
|
||||||
|
return self.api_key_org_id
|
||||||
|
|
||||||
async def get_user_id(self) -> str | None:
|
async def get_user_id(self) -> str | None:
|
||||||
return self.user_id
|
return self.user_id
|
||||||
|
|
||||||
|
|||||||
@@ -16,10 +16,10 @@ from openhands.core.logger import openhands_logger as logger
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ApiKeyValidationResult:
|
class ApiKeyValidationResult:
|
||||||
"""Result of API key validation containing user and org context."""
|
"""Result of API key validation containing user and organization info."""
|
||||||
|
|
||||||
user_id: str
|
user_id: str
|
||||||
org_id: UUID | None
|
org_id: UUID | None # None for legacy API keys without org binding
|
||||||
key_id: int
|
key_id: int
|
||||||
key_name: str | None
|
key_name: str | None
|
||||||
|
|
||||||
@@ -195,7 +195,12 @@ class ApiKeyStore:
|
|||||||
return api_key
|
return api_key
|
||||||
|
|
||||||
async def validate_api_key(self, api_key: str) -> ApiKeyValidationResult | None:
|
async def validate_api_key(self, api_key: str) -> ApiKeyValidationResult | None:
|
||||||
"""Validate an API key and return the associated user_id and org_id if valid."""
|
"""Validate an API key and return the associated user_id and org_id if valid.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiKeyValidationResult if the key is valid, None otherwise.
|
||||||
|
The org_id may be None for legacy API keys that weren't bound to an organization.
|
||||||
|
"""
|
||||||
now = datetime.now(UTC)
|
now = datetime.now(UTC)
|
||||||
|
|
||||||
async with a_session_maker() as session:
|
async with a_session_maker() as session:
|
||||||
|
|||||||
@@ -15,13 +15,13 @@ class SaasConversationValidator(ConversationValidator):
|
|||||||
|
|
||||||
async def _validate_api_key(self, api_key: str) -> str | None:
|
async def _validate_api_key(self, api_key: str) -> str | None:
|
||||||
"""
|
"""
|
||||||
Validate an API key and return the user_id and github_user_id if valid.
|
Validate an API key and return the user_id if valid.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
api_key: The API key to validate
|
api_key: The API key to validate
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of (user_id, github_user_id) if the API key is valid, None otherwise
|
The user_id if the API key is valid, None otherwise
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
token_manager = TokenManager()
|
token_manager = TokenManager()
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from storage.api_key import ApiKey
|
from storage.api_key import ApiKey
|
||||||
from storage.api_key_store import ApiKeyStore
|
from storage.api_key_store import ApiKeyStore, ApiKeyValidationResult
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -110,8 +110,8 @@ async def test_create_api_key(
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_validate_api_key_valid(api_key_store, async_session_maker):
|
async def test_validate_api_key_valid(api_key_store, async_session_maker):
|
||||||
"""Test validating a valid API key."""
|
"""Test validating a valid API key returns user_id and org_id."""
|
||||||
# Setup - create an API key in the database
|
# Arrange
|
||||||
user_id = str(uuid.uuid4())
|
user_id = str(uuid.uuid4())
|
||||||
org_id = uuid.uuid4()
|
org_id = uuid.uuid4()
|
||||||
api_key_value = 'test-api-key'
|
api_key_value = 'test-api-key'
|
||||||
@@ -128,11 +128,12 @@ async def test_validate_api_key_valid(api_key_store, async_session_maker):
|
|||||||
await session.commit()
|
await session.commit()
|
||||||
key_id = key_record.id
|
key_id = key_record.id
|
||||||
|
|
||||||
# Execute - patch a_session_maker to use test's async session maker
|
# Act
|
||||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||||
result = await api_key_store.validate_api_key(api_key_value)
|
result = await api_key_store.validate_api_key(api_key_value)
|
||||||
|
|
||||||
# Verify - result is now ApiKeyValidationResult
|
# Assert
|
||||||
|
assert isinstance(result, ApiKeyValidationResult)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.user_id == user_id
|
assert result.user_id == user_id
|
||||||
assert result.org_id == org_id
|
assert result.org_id == org_id
|
||||||
@@ -202,7 +203,7 @@ async def test_validate_api_key_valid_timezone_naive(
|
|||||||
api_key_store, async_session_maker
|
api_key_store, async_session_maker
|
||||||
):
|
):
|
||||||
"""Test validating a valid API key with timezone-naive datetime from database."""
|
"""Test validating a valid API key with timezone-naive datetime from database."""
|
||||||
# Setup - create a valid API key with timezone-naive datetime (future date)
|
# Arrange
|
||||||
user_id = str(uuid.uuid4())
|
user_id = str(uuid.uuid4())
|
||||||
org_id = uuid.uuid4()
|
org_id = uuid.uuid4()
|
||||||
api_key_value = 'test-valid-naive-key'
|
api_key_value = 'test-valid-naive-key'
|
||||||
@@ -219,13 +220,44 @@ async def test_validate_api_key_valid_timezone_naive(
|
|||||||
session.add(key_record)
|
session.add(key_record)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
# Execute - patch a_session_maker to use test's async session maker
|
# Act
|
||||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||||
result = await api_key_store.validate_api_key(api_key_value)
|
result = await api_key_store.validate_api_key(api_key_value)
|
||||||
|
|
||||||
# Verify - result is now ApiKeyValidationResult
|
# Assert
|
||||||
|
assert isinstance(result, ApiKeyValidationResult)
|
||||||
|
assert result.user_id == user_id
|
||||||
|
assert result.org_id == org_id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_api_key_legacy_without_org_id(
|
||||||
|
api_key_store, async_session_maker
|
||||||
|
):
|
||||||
|
"""Test validating a legacy API key without org_id returns None for org_id."""
|
||||||
|
# Arrange
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
api_key_value = 'test-legacy-key-no-org'
|
||||||
|
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
key_record = ApiKey(
|
||||||
|
key=api_key_value,
|
||||||
|
user_id=user_id,
|
||||||
|
org_id=None, # Legacy key without org binding
|
||||||
|
name='Legacy Key',
|
||||||
|
)
|
||||||
|
session.add(key_record)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||||
|
result = await api_key_store.validate_api_key(api_key_value)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert isinstance(result, ApiKeyValidationResult)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.user_id == user_id
|
assert result.user_id == user_id
|
||||||
|
assert result.org_id is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from server.auth.authorization import (
|
|||||||
ROLE_PERMISSIONS,
|
ROLE_PERMISSIONS,
|
||||||
Permission,
|
Permission,
|
||||||
RoleName,
|
RoleName,
|
||||||
|
get_api_key_org_id_from_request,
|
||||||
get_role_permissions,
|
get_role_permissions,
|
||||||
get_user_org_role,
|
get_user_org_role,
|
||||||
has_permission,
|
has_permission,
|
||||||
@@ -444,6 +445,15 @@ class TestGetUserOrgRole:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def _create_mock_request(api_key_org_id=None):
|
||||||
|
"""Helper to create a mock request with optional api_key_org_id."""
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_user_auth = MagicMock()
|
||||||
|
mock_user_auth.get_api_key_org_id.return_value = api_key_org_id
|
||||||
|
mock_request.state.user_auth = mock_user_auth
|
||||||
|
return mock_request
|
||||||
|
|
||||||
|
|
||||||
class TestRequirePermission:
|
class TestRequirePermission:
|
||||||
"""Tests for require_permission dependency factory."""
|
"""Tests for require_permission dependency factory."""
|
||||||
|
|
||||||
@@ -456,6 +466,7 @@ class TestRequirePermission:
|
|||||||
"""
|
"""
|
||||||
user_id = str(uuid4())
|
user_id = str(uuid4())
|
||||||
org_id = uuid4()
|
org_id = uuid4()
|
||||||
|
mock_request = _create_mock_request()
|
||||||
|
|
||||||
mock_role = MagicMock()
|
mock_role = MagicMock()
|
||||||
mock_role.name = 'admin'
|
mock_role.name = 'admin'
|
||||||
@@ -465,7 +476,9 @@ class TestRequirePermission:
|
|||||||
AsyncMock(return_value=mock_role),
|
AsyncMock(return_value=mock_role),
|
||||||
):
|
):
|
||||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
result = await permission_checker(
|
||||||
|
request=mock_request, org_id=org_id, user_id=user_id
|
||||||
|
)
|
||||||
assert result == user_id
|
assert result == user_id
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -476,10 +489,11 @@ class TestRequirePermission:
|
|||||||
THEN: 401 Unauthorized is raised
|
THEN: 401 Unauthorized is raised
|
||||||
"""
|
"""
|
||||||
org_id = uuid4()
|
org_id = uuid4()
|
||||||
|
mock_request = _create_mock_request()
|
||||||
|
|
||||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
await permission_checker(org_id=org_id, user_id=None)
|
await permission_checker(request=mock_request, org_id=org_id, user_id=None)
|
||||||
|
|
||||||
assert exc_info.value.status_code == 401
|
assert exc_info.value.status_code == 401
|
||||||
assert 'not authenticated' in exc_info.value.detail.lower()
|
assert 'not authenticated' in exc_info.value.detail.lower()
|
||||||
@@ -493,6 +507,7 @@ class TestRequirePermission:
|
|||||||
"""
|
"""
|
||||||
user_id = str(uuid4())
|
user_id = str(uuid4())
|
||||||
org_id = uuid4()
|
org_id = uuid4()
|
||||||
|
mock_request = _create_mock_request()
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
'server.auth.authorization.get_user_org_role',
|
'server.auth.authorization.get_user_org_role',
|
||||||
@@ -500,7 +515,9 @@ class TestRequirePermission:
|
|||||||
):
|
):
|
||||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
await permission_checker(org_id=org_id, user_id=user_id)
|
await permission_checker(
|
||||||
|
request=mock_request, org_id=org_id, user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
assert exc_info.value.status_code == 403
|
assert exc_info.value.status_code == 403
|
||||||
assert 'not a member' in exc_info.value.detail.lower()
|
assert 'not a member' in exc_info.value.detail.lower()
|
||||||
@@ -514,6 +531,7 @@ class TestRequirePermission:
|
|||||||
"""
|
"""
|
||||||
user_id = str(uuid4())
|
user_id = str(uuid4())
|
||||||
org_id = uuid4()
|
org_id = uuid4()
|
||||||
|
mock_request = _create_mock_request()
|
||||||
|
|
||||||
mock_role = MagicMock()
|
mock_role = MagicMock()
|
||||||
mock_role.name = 'member'
|
mock_role.name = 'member'
|
||||||
@@ -524,7 +542,9 @@ class TestRequirePermission:
|
|||||||
):
|
):
|
||||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
await permission_checker(org_id=org_id, user_id=user_id)
|
await permission_checker(
|
||||||
|
request=mock_request, org_id=org_id, user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
assert exc_info.value.status_code == 403
|
assert exc_info.value.status_code == 403
|
||||||
assert 'delete_organization' in exc_info.value.detail.lower()
|
assert 'delete_organization' in exc_info.value.detail.lower()
|
||||||
@@ -538,6 +558,7 @@ class TestRequirePermission:
|
|||||||
"""
|
"""
|
||||||
user_id = str(uuid4())
|
user_id = str(uuid4())
|
||||||
org_id = uuid4()
|
org_id = uuid4()
|
||||||
|
mock_request = _create_mock_request()
|
||||||
|
|
||||||
mock_role = MagicMock()
|
mock_role = MagicMock()
|
||||||
mock_role.name = 'owner'
|
mock_role.name = 'owner'
|
||||||
@@ -547,7 +568,9 @@ class TestRequirePermission:
|
|||||||
AsyncMock(return_value=mock_role),
|
AsyncMock(return_value=mock_role),
|
||||||
):
|
):
|
||||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
result = await permission_checker(
|
||||||
|
request=mock_request, org_id=org_id, user_id=user_id
|
||||||
|
)
|
||||||
assert result == user_id
|
assert result == user_id
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -559,6 +582,7 @@ class TestRequirePermission:
|
|||||||
"""
|
"""
|
||||||
user_id = str(uuid4())
|
user_id = str(uuid4())
|
||||||
org_id = uuid4()
|
org_id = uuid4()
|
||||||
|
mock_request = _create_mock_request()
|
||||||
|
|
||||||
mock_role = MagicMock()
|
mock_role = MagicMock()
|
||||||
mock_role.name = 'admin'
|
mock_role.name = 'admin'
|
||||||
@@ -569,7 +593,9 @@ class TestRequirePermission:
|
|||||||
):
|
):
|
||||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
await permission_checker(org_id=org_id, user_id=user_id)
|
await permission_checker(
|
||||||
|
request=mock_request, org_id=org_id, user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
assert exc_info.value.status_code == 403
|
assert exc_info.value.status_code == 403
|
||||||
|
|
||||||
@@ -582,6 +608,7 @@ class TestRequirePermission:
|
|||||||
"""
|
"""
|
||||||
user_id = str(uuid4())
|
user_id = str(uuid4())
|
||||||
org_id = uuid4()
|
org_id = uuid4()
|
||||||
|
mock_request = _create_mock_request()
|
||||||
|
|
||||||
mock_role = MagicMock()
|
mock_role = MagicMock()
|
||||||
mock_role.name = 'member'
|
mock_role.name = 'member'
|
||||||
@@ -595,7 +622,9 @@ class TestRequirePermission:
|
|||||||
):
|
):
|
||||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||||
with pytest.raises(HTTPException):
|
with pytest.raises(HTTPException):
|
||||||
await permission_checker(org_id=org_id, user_id=user_id)
|
await permission_checker(
|
||||||
|
request=mock_request, org_id=org_id, user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
mock_logger.warning.assert_called()
|
mock_logger.warning.assert_called()
|
||||||
call_args = mock_logger.warning.call_args
|
call_args = mock_logger.warning.call_args
|
||||||
@@ -611,6 +640,7 @@ class TestRequirePermission:
|
|||||||
THEN: User ID is returned
|
THEN: User ID is returned
|
||||||
"""
|
"""
|
||||||
user_id = str(uuid4())
|
user_id = str(uuid4())
|
||||||
|
mock_request = _create_mock_request()
|
||||||
|
|
||||||
mock_role = MagicMock()
|
mock_role = MagicMock()
|
||||||
mock_role.name = 'admin'
|
mock_role.name = 'admin'
|
||||||
@@ -620,7 +650,9 @@ class TestRequirePermission:
|
|||||||
AsyncMock(return_value=mock_role),
|
AsyncMock(return_value=mock_role),
|
||||||
) as mock_get_role:
|
) as mock_get_role:
|
||||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||||
result = await permission_checker(org_id=None, user_id=user_id)
|
result = await permission_checker(
|
||||||
|
request=mock_request, org_id=None, user_id=user_id
|
||||||
|
)
|
||||||
assert result == user_id
|
assert result == user_id
|
||||||
mock_get_role.assert_called_once_with(user_id, None)
|
mock_get_role.assert_called_once_with(user_id, None)
|
||||||
|
|
||||||
@@ -632,6 +664,7 @@ class TestRequirePermission:
|
|||||||
THEN: HTTPException with 403 status is raised
|
THEN: HTTPException with 403 status is raised
|
||||||
"""
|
"""
|
||||||
user_id = str(uuid4())
|
user_id = str(uuid4())
|
||||||
|
mock_request = _create_mock_request()
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
'server.auth.authorization.get_user_org_role',
|
'server.auth.authorization.get_user_org_role',
|
||||||
@@ -639,7 +672,9 @@ class TestRequirePermission:
|
|||||||
):
|
):
|
||||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
await permission_checker(org_id=None, user_id=user_id)
|
await permission_checker(
|
||||||
|
request=mock_request, org_id=None, user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
assert exc_info.value.status_code == 403
|
assert exc_info.value.status_code == 403
|
||||||
assert 'not a member' in exc_info.value.detail
|
assert 'not a member' in exc_info.value.detail
|
||||||
@@ -662,6 +697,7 @@ class TestPermissionScenarios:
|
|||||||
"""
|
"""
|
||||||
user_id = str(uuid4())
|
user_id = str(uuid4())
|
||||||
org_id = uuid4()
|
org_id = uuid4()
|
||||||
|
mock_request = _create_mock_request()
|
||||||
|
|
||||||
mock_role = MagicMock()
|
mock_role = MagicMock()
|
||||||
mock_role.name = 'member'
|
mock_role.name = 'member'
|
||||||
@@ -671,7 +707,9 @@ class TestPermissionScenarios:
|
|||||||
AsyncMock(return_value=mock_role),
|
AsyncMock(return_value=mock_role),
|
||||||
):
|
):
|
||||||
permission_checker = require_permission(Permission.MANAGE_SECRETS)
|
permission_checker = require_permission(Permission.MANAGE_SECRETS)
|
||||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
result = await permission_checker(
|
||||||
|
request=mock_request, org_id=org_id, user_id=user_id
|
||||||
|
)
|
||||||
assert result == user_id
|
assert result == user_id
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -683,6 +721,7 @@ class TestPermissionScenarios:
|
|||||||
"""
|
"""
|
||||||
user_id = str(uuid4())
|
user_id = str(uuid4())
|
||||||
org_id = uuid4()
|
org_id = uuid4()
|
||||||
|
mock_request = _create_mock_request()
|
||||||
|
|
||||||
mock_role = MagicMock()
|
mock_role = MagicMock()
|
||||||
mock_role.name = 'member'
|
mock_role.name = 'member'
|
||||||
@@ -695,7 +734,9 @@ class TestPermissionScenarios:
|
|||||||
Permission.INVITE_USER_TO_ORGANIZATION
|
Permission.INVITE_USER_TO_ORGANIZATION
|
||||||
)
|
)
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
await permission_checker(org_id=org_id, user_id=user_id)
|
await permission_checker(
|
||||||
|
request=mock_request, org_id=org_id, user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
assert exc_info.value.status_code == 403
|
assert exc_info.value.status_code == 403
|
||||||
|
|
||||||
@@ -708,6 +749,7 @@ class TestPermissionScenarios:
|
|||||||
"""
|
"""
|
||||||
user_id = str(uuid4())
|
user_id = str(uuid4())
|
||||||
org_id = uuid4()
|
org_id = uuid4()
|
||||||
|
mock_request = _create_mock_request()
|
||||||
|
|
||||||
mock_role = MagicMock()
|
mock_role = MagicMock()
|
||||||
mock_role.name = 'admin'
|
mock_role.name = 'admin'
|
||||||
@@ -719,7 +761,9 @@ class TestPermissionScenarios:
|
|||||||
permission_checker = require_permission(
|
permission_checker = require_permission(
|
||||||
Permission.INVITE_USER_TO_ORGANIZATION
|
Permission.INVITE_USER_TO_ORGANIZATION
|
||||||
)
|
)
|
||||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
result = await permission_checker(
|
||||||
|
request=mock_request, org_id=org_id, user_id=user_id
|
||||||
|
)
|
||||||
assert result == user_id
|
assert result == user_id
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -731,6 +775,7 @@ class TestPermissionScenarios:
|
|||||||
"""
|
"""
|
||||||
user_id = str(uuid4())
|
user_id = str(uuid4())
|
||||||
org_id = uuid4()
|
org_id = uuid4()
|
||||||
|
mock_request = _create_mock_request()
|
||||||
|
|
||||||
mock_role = MagicMock()
|
mock_role = MagicMock()
|
||||||
mock_role.name = 'admin'
|
mock_role.name = 'admin'
|
||||||
@@ -741,7 +786,9 @@ class TestPermissionScenarios:
|
|||||||
):
|
):
|
||||||
permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER)
|
permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER)
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
await permission_checker(org_id=org_id, user_id=user_id)
|
await permission_checker(
|
||||||
|
request=mock_request, org_id=org_id, user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
assert exc_info.value.status_code == 403
|
assert exc_info.value.status_code == 403
|
||||||
|
|
||||||
@@ -754,6 +801,7 @@ class TestPermissionScenarios:
|
|||||||
"""
|
"""
|
||||||
user_id = str(uuid4())
|
user_id = str(uuid4())
|
||||||
org_id = uuid4()
|
org_id = uuid4()
|
||||||
|
mock_request = _create_mock_request()
|
||||||
|
|
||||||
mock_role = MagicMock()
|
mock_role = MagicMock()
|
||||||
mock_role.name = 'owner'
|
mock_role.name = 'owner'
|
||||||
@@ -763,5 +811,200 @@ class TestPermissionScenarios:
|
|||||||
AsyncMock(return_value=mock_role),
|
AsyncMock(return_value=mock_role),
|
||||||
):
|
):
|
||||||
permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER)
|
permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER)
|
||||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
result = await permission_checker(
|
||||||
|
request=mock_request, org_id=org_id, user_id=user_id
|
||||||
|
)
|
||||||
assert result == user_id
|
assert result == user_id
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Tests for API key organization validation
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestApiKeyOrgValidation:
|
||||||
|
"""Tests for API key organization binding validation in require_permission."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_allows_access_when_api_key_org_matches_target_org(self):
|
||||||
|
"""
|
||||||
|
GIVEN: API key with org_id that matches the target org_id in the request
|
||||||
|
WHEN: Permission checker is called
|
||||||
|
THEN: User ID is returned (access allowed)
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
user_id = str(uuid4())
|
||||||
|
org_id = uuid4()
|
||||||
|
mock_request = _create_mock_request(api_key_org_id=org_id)
|
||||||
|
|
||||||
|
mock_role = MagicMock()
|
||||||
|
mock_role.name = 'admin'
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with patch(
|
||||||
|
'server.auth.authorization.get_user_org_role',
|
||||||
|
AsyncMock(return_value=mock_role),
|
||||||
|
):
|
||||||
|
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||||
|
result = await permission_checker(
|
||||||
|
request=mock_request, org_id=org_id, user_id=user_id
|
||||||
|
)
|
||||||
|
assert result == user_id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_denies_access_when_api_key_org_mismatches_target_org(self):
|
||||||
|
"""
|
||||||
|
GIVEN: API key created for Org A, but user tries to access Org B
|
||||||
|
WHEN: Permission checker is called
|
||||||
|
THEN: 403 Forbidden is raised with org mismatch message
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
user_id = str(uuid4())
|
||||||
|
api_key_org_id = uuid4() # Org A - where API key was created
|
||||||
|
target_org_id = uuid4() # Org B - where user is trying to access
|
||||||
|
mock_request = _create_mock_request(api_key_org_id=api_key_org_id)
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await permission_checker(
|
||||||
|
request=mock_request, org_id=target_org_id, user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 403
|
||||||
|
assert (
|
||||||
|
'API key is not authorized for this organization' in exc_info.value.detail
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_allows_access_for_legacy_api_key_without_org_binding(self):
|
||||||
|
"""
|
||||||
|
GIVEN: Legacy API key without org_id binding (org_id is None)
|
||||||
|
WHEN: Permission checker is called
|
||||||
|
THEN: Falls through to normal permission check (backward compatible)
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
user_id = str(uuid4())
|
||||||
|
org_id = uuid4()
|
||||||
|
mock_request = _create_mock_request(api_key_org_id=None)
|
||||||
|
|
||||||
|
mock_role = MagicMock()
|
||||||
|
mock_role.name = 'admin'
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with patch(
|
||||||
|
'server.auth.authorization.get_user_org_role',
|
||||||
|
AsyncMock(return_value=mock_role),
|
||||||
|
):
|
||||||
|
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||||
|
result = await permission_checker(
|
||||||
|
request=mock_request, org_id=org_id, user_id=user_id
|
||||||
|
)
|
||||||
|
assert result == user_id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_allows_access_for_cookie_auth_without_api_key_org_id(self):
|
||||||
|
"""
|
||||||
|
GIVEN: Cookie-based authentication (no api_key_org_id in user_auth)
|
||||||
|
WHEN: Permission checker is called
|
||||||
|
THEN: Falls through to normal permission check
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
user_id = str(uuid4())
|
||||||
|
org_id = uuid4()
|
||||||
|
mock_request = _create_mock_request(api_key_org_id=None)
|
||||||
|
|
||||||
|
mock_role = MagicMock()
|
||||||
|
mock_role.name = 'admin'
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with patch(
|
||||||
|
'server.auth.authorization.get_user_org_role',
|
||||||
|
AsyncMock(return_value=mock_role),
|
||||||
|
):
|
||||||
|
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||||
|
result = await permission_checker(
|
||||||
|
request=mock_request, org_id=org_id, user_id=user_id
|
||||||
|
)
|
||||||
|
assert result == user_id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_logs_warning_on_api_key_org_mismatch(self):
|
||||||
|
"""
|
||||||
|
GIVEN: API key org_id doesn't match target org_id
|
||||||
|
WHEN: Permission checker is called
|
||||||
|
THEN: Warning is logged with org mismatch details
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
user_id = str(uuid4())
|
||||||
|
api_key_org_id = uuid4()
|
||||||
|
target_org_id = uuid4()
|
||||||
|
mock_request = _create_mock_request(api_key_org_id=api_key_org_id)
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with patch('server.auth.authorization.logger') as mock_logger:
|
||||||
|
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||||
|
with pytest.raises(HTTPException):
|
||||||
|
await permission_checker(
|
||||||
|
request=mock_request, org_id=target_org_id, user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_logger.warning.assert_called()
|
||||||
|
call_args = mock_logger.warning.call_args
|
||||||
|
assert call_args[1]['extra']['user_id'] == user_id
|
||||||
|
assert call_args[1]['extra']['api_key_org_id'] == str(api_key_org_id)
|
||||||
|
assert call_args[1]['extra']['target_org_id'] == str(target_org_id)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetApiKeyOrgIdFromRequest:
|
||||||
|
"""Tests for get_api_key_org_id_from_request helper function."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_org_id_when_user_auth_has_api_key_org_id(self):
|
||||||
|
"""
|
||||||
|
GIVEN: Request with user_auth that has api_key_org_id
|
||||||
|
WHEN: get_api_key_org_id_from_request is called
|
||||||
|
THEN: Returns the api_key_org_id
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
org_id = uuid4()
|
||||||
|
mock_request = _create_mock_request(api_key_org_id=org_id)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await get_api_key_org_id_from_request(mock_request)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result == org_id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_none_when_user_auth_has_no_api_key_org_id(self):
|
||||||
|
"""
|
||||||
|
GIVEN: Request with user_auth that has no api_key_org_id (cookie auth)
|
||||||
|
WHEN: get_api_key_org_id_from_request is called
|
||||||
|
THEN: Returns None
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
mock_request = _create_mock_request(api_key_org_id=None)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await get_api_key_org_id_from_request(mock_request)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_none_when_no_user_auth_in_request(self):
|
||||||
|
"""
|
||||||
|
GIVEN: Request without user_auth in state
|
||||||
|
WHEN: get_api_key_org_id_from_request is called
|
||||||
|
THEN: Returns None
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.state.user_auth = None
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await get_api_key_org_id_from_request(mock_request)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is None
|
||||||
|
|||||||
@@ -459,7 +459,8 @@ async def test_get_instance_no_auth(mock_request):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_saas_user_auth_from_bearer_success():
|
async def test_saas_user_auth_from_bearer_success():
|
||||||
"""Test successful authentication from bearer token."""
|
"""Test successful authentication from bearer token sets user_id and api_key_org_id."""
|
||||||
|
# Arrange
|
||||||
mock_request = MagicMock()
|
mock_request = MagicMock()
|
||||||
mock_request.headers = {'Authorization': 'Bearer test_api_key'}
|
mock_request.headers = {'Authorization': 'Bearer test_api_key'}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user