diff --git a/enterprise/saas_server.py b/enterprise/saas_server.py index 8bb576a55b..434652befd 100644 --- a/enterprise/saas_server.py +++ b/enterprise/saas_server.py @@ -46,6 +46,7 @@ from server.routes.org_invitations import ( # noqa: E402 ) from server.routes.orgs import org_router # noqa: E402 from server.routes.readiness import readiness_router # noqa: E402 +from server.routes.service import service_router # noqa: E402 from server.routes.user import saas_user_router # noqa: E402 from server.routes.user_app_settings import user_app_settings_router # noqa: E402 from server.sharing.shared_conversation_router import ( # noqa: E402 @@ -112,6 +113,7 @@ if GITLAB_APP_CLIENT_ID: base_app.include_router(gitlab_integration_router) base_app.include_router(api_keys_router) # Add routes for API key management +base_app.include_router(service_router) # Add routes for internal service API base_app.include_router(org_router) # Add routes for organization management base_app.include_router( verified_models_router diff --git a/enterprise/server/middleware.py b/enterprise/server/middleware.py index 659a66046a..c014864b0b 100644 --- a/enterprise/server/middleware.py +++ b/enterprise/server/middleware.py @@ -182,6 +182,10 @@ class SetAuthCookieMiddleware: if path.startswith('/api/v1/webhooks/'): return False + # Service API uses its own authentication (X-Service-API-Key header) + if path.startswith('/api/service/'): + return False + is_mcp = path.startswith('/mcp') is_api_route = path.startswith('/api') return is_api_route or is_mcp diff --git a/enterprise/server/routes/service.py b/enterprise/server/routes/service.py new file mode 100644 index 0000000000..87e470dd7c --- /dev/null +++ b/enterprise/server/routes/service.py @@ -0,0 +1,270 @@ +""" +Service API routes for internal service-to-service communication. + +This module provides endpoints for trusted internal services (e.g., automations service) +to perform privileged operations like creating API keys on behalf of users. + +Authentication is via a shared secret (X-Service-API-Key header) configured +through the AUTOMATIONS_SERVICE_API_KEY environment variable. +""" + +import os +from uuid import UUID + +from fastapi import APIRouter, Header, HTTPException, status +from pydantic import BaseModel, field_validator +from storage.api_key_store import ApiKeyStore +from storage.org_member_store import OrgMemberStore +from storage.user_store import UserStore + +from openhands.core.logger import openhands_logger as logger + +# Environment variable for the service API key +AUTOMATIONS_SERVICE_API_KEY = os.getenv('AUTOMATIONS_SERVICE_API_KEY', '').strip() + +service_router = APIRouter(prefix='/api/service', tags=['Service']) + + +class CreateUserApiKeyRequest(BaseModel): + """Request model for creating an API key on behalf of a user.""" + + name: str # Required - used to identify the key + + @field_validator('name') + @classmethod + def validate_name(cls, v: str) -> str: + if not v or not v.strip(): + raise ValueError('name is required and cannot be empty') + return v.strip() + + +class CreateUserApiKeyResponse(BaseModel): + """Response model for created API key.""" + + key: str + user_id: str + org_id: str + name: str + + +class ServiceInfoResponse(BaseModel): + """Response model for service info endpoint.""" + + service: str + authenticated: bool + + +async def validate_service_api_key( + x_service_api_key: str | None = Header(default=None, alias='X-Service-API-Key'), +) -> str: + """ + Validate the service API key from the request header. + + Args: + x_service_api_key: The service API key from the X-Service-API-Key header + + Returns: + str: Service identifier for audit logging + + Raises: + HTTPException: 401 if key is missing or invalid + HTTPException: 503 if service auth is not configured + """ + if not AUTOMATIONS_SERVICE_API_KEY: + logger.warning( + 'Service authentication not configured (AUTOMATIONS_SERVICE_API_KEY not set)' + ) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail='Service authentication not configured', + ) + + if not x_service_api_key: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail='X-Service-API-Key header is required', + ) + + if x_service_api_key != AUTOMATIONS_SERVICE_API_KEY: + logger.warning('Invalid service API key attempted') + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail='Invalid service API key', + ) + + return 'automations-service' + + +@service_router.get('/health') +async def service_health() -> dict: + """Health check endpoint for the service API. + + This endpoint does not require authentication and can be used + to verify the service routes are accessible. + """ + return { + 'status': 'ok', + 'service_auth_configured': bool(AUTOMATIONS_SERVICE_API_KEY), + } + + +@service_router.post('/users/{user_id}/orgs/{org_id}/api-keys') +async def get_or_create_api_key_for_user( + user_id: str, + org_id: UUID, + request: CreateUserApiKeyRequest, + x_service_api_key: str | None = Header(default=None, alias='X-Service-API-Key'), +) -> CreateUserApiKeyResponse: + """ + Get or create an API key for a user on behalf of the automations service. + + If a key with the given name already exists for the user/org and is not expired, + returns the existing key. Otherwise, creates a new key. + + The created/returned keys are system keys and are: + - Not visible to the user in their API keys list + - Not deletable by the user + - Never expire + + Args: + user_id: The user ID + org_id: The organization ID + request: Request body containing name (required) + x_service_api_key: Service API key header for authentication + + Returns: + CreateUserApiKeyResponse: The API key and metadata + + Raises: + HTTPException: 401 if service key is invalid + HTTPException: 404 if user not found + HTTPException: 403 if user is not a member of the specified org + """ + # Validate service API key + service_id = await validate_service_api_key(x_service_api_key) + + # Verify user exists + user = await UserStore.get_user_by_id(user_id) + if not user: + logger.warning( + 'Service attempted to create key for non-existent user', + extra={'user_id': user_id}, + ) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f'User {user_id} not found', + ) + + # Verify user is a member of the specified org + org_member = await OrgMemberStore.get_org_member(org_id, UUID(user_id)) + if not org_member: + logger.warning( + 'Service attempted to create key for user not in org', + extra={ + 'user_id': user_id, + 'org_id': str(org_id), + }, + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f'User {user_id} is not a member of org {org_id}', + ) + + # Get or create the system API key + api_key_store = ApiKeyStore.get_instance() + + try: + api_key = await api_key_store.get_or_create_system_api_key( + user_id=user_id, + org_id=org_id, + name=request.name, + ) + except Exception as e: + logger.exception( + 'Failed to get or create system API key', + extra={ + 'user_id': user_id, + 'org_id': str(org_id), + 'error': str(e), + }, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail='Failed to get or create API key', + ) + + logger.info( + 'Service created API key for user', + extra={ + 'service_id': service_id, + 'user_id': user_id, + 'org_id': str(org_id), + 'key_name': request.name, + }, + ) + + return CreateUserApiKeyResponse( + key=api_key, + user_id=user_id, + org_id=str(org_id), + name=request.name, + ) + + +@service_router.delete('/users/{user_id}/orgs/{org_id}/api-keys/{key_name}') +async def delete_user_api_key( + user_id: str, + org_id: UUID, + key_name: str, + x_service_api_key: str | None = Header(default=None, alias='X-Service-API-Key'), +) -> dict: + """ + Delete a system API key created by the service. + + This endpoint allows the automations service to clean up API keys + it previously created for users. + + Args: + user_id: The user ID + org_id: The organization ID + key_name: The name of the key to delete (without __SYSTEM__: prefix) + x_service_api_key: Service API key header for authentication + + Returns: + dict: Success message + + Raises: + HTTPException: 401 if service key is invalid + HTTPException: 404 if key not found + """ + # Validate service API key + service_id = await validate_service_api_key(x_service_api_key) + + api_key_store = ApiKeyStore.get_instance() + + # Delete the key by name (wrap with system key prefix since service creates system keys) + system_key_name = api_key_store.make_system_key_name(key_name) + success = await api_key_store.delete_api_key_by_name( + user_id=user_id, + org_id=org_id, + name=system_key_name, + allow_system=True, + ) + + if not success: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f'API key with name "{key_name}" not found for user {user_id} in org {org_id}', + ) + + logger.info( + 'Service deleted API key for user', + extra={ + 'service_id': service_id, + 'user_id': user_id, + 'org_id': str(org_id), + 'key_name': key_name, + }, + ) + + return {'message': 'API key deleted successfully'} diff --git a/enterprise/storage/api_key_store.py b/enterprise/storage/api_key_store.py index 3090b8da07..ecbb375592 100644 --- a/enterprise/storage/api_key_store.py +++ b/enterprise/storage/api_key_store.py @@ -27,6 +27,9 @@ class ApiKeyValidationResult: @dataclass class ApiKeyStore: API_KEY_PREFIX = 'sk-oh-' + # Prefix for system keys created by internal services (e.g., automations) + # Keys with this prefix are hidden from users and cannot be deleted by users + SYSTEM_KEY_NAME_PREFIX = '__SYSTEM__:' def generate_api_key(self, length: int = 32) -> str: """Generate a random API key with the sk-oh- prefix.""" @@ -34,6 +37,19 @@ class ApiKeyStore: random_part = ''.join(secrets.choice(alphabet) for _ in range(length)) return f'{self.API_KEY_PREFIX}{random_part}' + @classmethod + def is_system_key_name(cls, name: str | None) -> bool: + """Check if a key name indicates a system key.""" + return name is not None and name.startswith(cls.SYSTEM_KEY_NAME_PREFIX) + + @classmethod + def make_system_key_name(cls, name: str) -> str: + """Create a system key name with the appropriate prefix. + + Format: __SYSTEM__: + """ + return f'{cls.SYSTEM_KEY_NAME_PREFIX}{name}' + async def create_api_key( self, user_id: str, name: str | None = None, expires_at: datetime | None = None ) -> str: @@ -71,6 +87,113 @@ class ApiKeyStore: return api_key + async def get_or_create_system_api_key( + self, + user_id: str, + org_id: UUID, + name: str, + ) -> str: + """Get or create a system API key for a user on behalf of an internal service. + + If a key with the given name already exists for this user/org and is not expired, + returns the existing key. Otherwise, creates a new key (and deletes any expired one). + + System keys are: + - Not visible to users in their API keys list (filtered by name prefix) + - Not deletable by users (protected by name prefix check) + - Associated with a specific org (not the user's current org) + - Never expire (no expiration date) + + Args: + user_id: The ID of the user to create the key for + org_id: The organization ID to associate the key with + name: Required name for the key (will be prefixed with __SYSTEM__:) + + Returns: + The API key (existing or newly created) + """ + # Create system key name with prefix + system_key_name = self.make_system_key_name(name) + + async with a_session_maker() as session: + # Check if key already exists for this user/org/name + result = await session.execute( + select(ApiKey).filter( + ApiKey.user_id == user_id, + ApiKey.org_id == org_id, + ApiKey.name == system_key_name, + ) + ) + existing_key = result.scalars().first() + + if existing_key: + # Check if expired + if existing_key.expires_at: + now = datetime.now(UTC) + expires_at = existing_key.expires_at + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=UTC) + + if expires_at < now: + # Key is expired, delete it and create new one + logger.info( + 'System API key expired, re-issuing', + extra={ + 'user_id': user_id, + 'org_id': str(org_id), + 'key_name': system_key_name, + }, + ) + await session.delete(existing_key) + await session.commit() + else: + # Key exists and is not expired, return it + logger.debug( + 'Returning existing system API key', + extra={ + 'user_id': user_id, + 'org_id': str(org_id), + 'key_name': system_key_name, + }, + ) + return existing_key.key + else: + # Key exists and has no expiration, return it + logger.debug( + 'Returning existing system API key', + extra={ + 'user_id': user_id, + 'org_id': str(org_id), + 'key_name': system_key_name, + }, + ) + return existing_key.key + + # Create new key (no expiration) + api_key = self.generate_api_key() + + async with a_session_maker() as session: + key_record = ApiKey( + key=api_key, + user_id=user_id, + org_id=org_id, + name=system_key_name, + expires_at=None, # System keys never expire + ) + session.add(key_record) + await session.commit() + + logger.info( + 'Created system API key', + extra={ + 'user_id': user_id, + 'org_id': str(org_id), + 'key_name': system_key_name, + }, + ) + + return api_key + async def validate_api_key(self, api_key: str) -> ApiKeyValidationResult | None: """Validate an API key and return the associated user_id and org_id if valid.""" now = datetime.now(UTC) @@ -121,8 +244,18 @@ class ApiKeyStore: return True - async def delete_api_key_by_id(self, key_id: int) -> bool: - """Delete an API key by its ID.""" + async def delete_api_key_by_id( + self, key_id: int, allow_system: bool = False + ) -> bool: + """Delete an API key by its ID. + + Args: + key_id: The ID of the key to delete + allow_system: If False (default), system keys cannot be deleted + + Returns: + True if the key was deleted, False if not found or is a protected system key + """ async with a_session_maker() as session: result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id)) key_record = result.scalars().first() @@ -130,13 +263,26 @@ class ApiKeyStore: if not key_record: return False + # Protect system keys from deletion unless explicitly allowed + if self.is_system_key_name(key_record.name) and not allow_system: + logger.warning( + 'Attempted to delete system API key', + extra={'key_id': key_id, 'user_id': key_record.user_id}, + ) + return False + await session.delete(key_record) await session.commit() return True async def list_api_keys(self, user_id: str) -> list[ApiKey]: - """List all API keys for a user.""" + """List all user-visible API keys for a user. + + This excludes: + - System keys (name starts with __SYSTEM__:) - created by internal services + - MCP_API_KEY - internal MCP key + """ user = await UserStore.get_user_by_id(user_id) if user is None: raise ValueError(f'User not found: {user_id}') @@ -145,11 +291,17 @@ class ApiKeyStore: async with a_session_maker() as session: result = await session.execute( select(ApiKey).filter( - ApiKey.user_id == user_id, ApiKey.org_id == org_id + ApiKey.user_id == user_id, + ApiKey.org_id == org_id, ) ) keys = result.scalars().all() - return [key for key in keys if key.name != 'MCP_API_KEY'] + # Filter out system keys and MCP_API_KEY + return [ + key + for key in keys + if key.name != 'MCP_API_KEY' and not self.is_system_key_name(key.name) + ] async def retrieve_mcp_api_key(self, user_id: str) -> str | None: user = await UserStore.get_user_by_id(user_id) @@ -179,17 +331,44 @@ class ApiKeyStore: key_record = result.scalars().first() return key_record.key if key_record else None - async def delete_api_key_by_name(self, user_id: str, name: str) -> bool: - """Delete an API key by name for a specific user.""" + async def delete_api_key_by_name( + self, + user_id: str, + name: str, + org_id: UUID | None = None, + allow_system: bool = False, + ) -> bool: + """Delete an API key by name for a specific user. + + Args: + user_id: The ID of the user whose key to delete + name: The name of the key to delete + org_id: Optional organization ID to filter by (required for system keys) + allow_system: If False (default), system keys cannot be deleted + + Returns: + True if the key was deleted, False if not found or is a protected system key + """ async with a_session_maker() as session: - result = await session.execute( - select(ApiKey).filter(ApiKey.user_id == user_id, ApiKey.name == name) - ) + # Build the query filters + filters = [ApiKey.user_id == user_id, ApiKey.name == name] + if org_id is not None: + filters.append(ApiKey.org_id == org_id) + + result = await session.execute(select(ApiKey).filter(*filters)) key_record = result.scalars().first() if not key_record: return False + # Protect system keys from deletion unless explicitly allowed + if self.is_system_key_name(key_record.name) and not allow_system: + logger.warning( + 'Attempted to delete system API key', + extra={'user_id': user_id, 'key_name': name}, + ) + return False + await session.delete(key_record) await session.commit() diff --git a/enterprise/tests/unit/routes/__init__.py b/enterprise/tests/unit/routes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/enterprise/tests/unit/routes/test_service.py b/enterprise/tests/unit/routes/test_service.py new file mode 100644 index 0000000000..a7156ec117 --- /dev/null +++ b/enterprise/tests/unit/routes/test_service.py @@ -0,0 +1,331 @@ +"""Unit tests for service API routes.""" + +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import HTTPException +from server.routes.service import ( + CreateUserApiKeyRequest, + delete_user_api_key, + get_or_create_api_key_for_user, + validate_service_api_key, +) + + +class TestValidateServiceApiKey: + """Test cases for validate_service_api_key.""" + + @pytest.mark.asyncio + async def test_valid_service_key(self): + """Test validation with valid service API key.""" + with patch( + 'server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-service-key' + ): + result = await validate_service_api_key('test-service-key') + assert result == 'automations-service' + + @pytest.mark.asyncio + async def test_missing_service_key(self): + """Test validation with missing service API key header.""" + with patch( + 'server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-service-key' + ): + with pytest.raises(HTTPException) as exc_info: + await validate_service_api_key(None) + assert exc_info.value.status_code == 401 + assert 'X-Service-API-Key header is required' in exc_info.value.detail + + @pytest.mark.asyncio + async def test_invalid_service_key(self): + """Test validation with invalid service API key.""" + with patch( + 'server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-service-key' + ): + with pytest.raises(HTTPException) as exc_info: + await validate_service_api_key('wrong-key') + assert exc_info.value.status_code == 401 + assert 'Invalid service API key' in exc_info.value.detail + + @pytest.mark.asyncio + async def test_service_auth_not_configured(self): + """Test validation when service auth is not configured.""" + with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', ''): + with pytest.raises(HTTPException) as exc_info: + await validate_service_api_key('any-key') + assert exc_info.value.status_code == 503 + assert 'Service authentication not configured' in exc_info.value.detail + + +class TestCreateUserApiKeyRequest: + """Test cases for CreateUserApiKeyRequest validation.""" + + def test_valid_request(self): + """Test valid request with all fields.""" + request = CreateUserApiKeyRequest( + name='automation', + ) + assert request.name == 'automation' + + def test_name_is_required(self): + """Test that name field is required.""" + with pytest.raises(ValueError): + CreateUserApiKeyRequest( + name='', # Empty name should fail + ) + + def test_name_is_stripped(self): + """Test that name field is stripped of whitespace.""" + request = CreateUserApiKeyRequest( + name=' automation ', + ) + assert request.name == 'automation' + + def test_whitespace_only_name_fails(self): + """Test that whitespace-only name fails validation.""" + with pytest.raises(ValueError): + CreateUserApiKeyRequest( + name=' ', + ) + + +class TestGetOrCreateApiKeyForUser: + """Test cases for get_or_create_api_key_for_user endpoint.""" + + @pytest.fixture + def valid_user_id(self): + """Return a valid user ID.""" + return '5594c7b6-f959-4b81-92e9-b09c206f5081' + + @pytest.fixture + def valid_org_id(self): + """Return a valid org ID.""" + return uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081') + + @pytest.fixture + def valid_request(self): + """Create a valid request object.""" + return CreateUserApiKeyRequest( + name='automation', + ) + + @pytest.mark.asyncio + async def test_user_not_found(self, valid_user_id, valid_org_id, valid_request): + """Test error when user doesn't exist.""" + with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'): + with patch( + 'server.routes.service.UserStore.get_user_by_id', new_callable=AsyncMock + ) as mock_get_user: + mock_get_user.return_value = None + with pytest.raises(HTTPException) as exc_info: + await get_or_create_api_key_for_user( + user_id=valid_user_id, + org_id=valid_org_id, + request=valid_request, + x_service_api_key='test-key', + ) + assert exc_info.value.status_code == 404 + assert 'not found' in exc_info.value.detail + + @pytest.mark.asyncio + async def test_user_not_in_org(self, valid_user_id, valid_org_id, valid_request): + """Test error when user is not a member of the org.""" + mock_user = MagicMock() + + with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'): + with patch( + 'server.routes.service.UserStore.get_user_by_id', new_callable=AsyncMock + ) as mock_get_user: + with patch( + 'server.routes.service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, + ) as mock_get_member: + mock_get_user.return_value = mock_user + mock_get_member.return_value = None + with pytest.raises(HTTPException) as exc_info: + await get_or_create_api_key_for_user( + user_id=valid_user_id, + org_id=valid_org_id, + request=valid_request, + x_service_api_key='test-key', + ) + assert exc_info.value.status_code == 403 + assert 'not a member of org' in exc_info.value.detail + + @pytest.mark.asyncio + async def test_successful_key_creation( + self, valid_user_id, valid_org_id, valid_request + ): + """Test successful API key creation.""" + mock_user = MagicMock() + mock_org_member = MagicMock() + mock_api_key_store = MagicMock() + mock_api_key_store.get_or_create_system_api_key = AsyncMock( + return_value='sk-oh-test-key-12345678901234567890' + ) + + with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'): + with patch( + 'server.routes.service.UserStore.get_user_by_id', new_callable=AsyncMock + ) as mock_get_user: + with patch( + 'server.routes.service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, + ) as mock_get_member: + with patch( + 'server.routes.service.ApiKeyStore.get_instance' + ) as mock_get_store: + mock_get_user.return_value = mock_user + mock_get_member.return_value = mock_org_member + mock_get_store.return_value = mock_api_key_store + + response = await get_or_create_api_key_for_user( + user_id=valid_user_id, + org_id=valid_org_id, + request=valid_request, + x_service_api_key='test-key', + ) + + assert response.key == 'sk-oh-test-key-12345678901234567890' + assert response.user_id == valid_user_id + assert response.org_id == str(valid_org_id) + assert response.name == 'automation' + + # Verify the store was called with correct arguments + mock_api_key_store.get_or_create_system_api_key.assert_called_once_with( + user_id=valid_user_id, + org_id=valid_org_id, + name='automation', + ) + + @pytest.mark.asyncio + async def test_store_exception_handling( + self, valid_user_id, valid_org_id, valid_request + ): + """Test error handling when store raises exception.""" + mock_user = MagicMock() + mock_org_member = MagicMock() + mock_api_key_store = MagicMock() + mock_api_key_store.get_or_create_system_api_key = AsyncMock( + side_effect=Exception('Database error') + ) + + with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'): + with patch( + 'server.routes.service.UserStore.get_user_by_id', new_callable=AsyncMock + ) as mock_get_user: + with patch( + 'server.routes.service.OrgMemberStore.get_org_member', + new_callable=AsyncMock, + ) as mock_get_member: + with patch( + 'server.routes.service.ApiKeyStore.get_instance' + ) as mock_get_store: + mock_get_user.return_value = mock_user + mock_get_member.return_value = mock_org_member + mock_get_store.return_value = mock_api_key_store + + with pytest.raises(HTTPException) as exc_info: + await get_or_create_api_key_for_user( + user_id=valid_user_id, + org_id=valid_org_id, + request=valid_request, + x_service_api_key='test-key', + ) + + assert exc_info.value.status_code == 500 + assert 'Failed to get or create API key' in exc_info.value.detail + + +class TestDeleteUserApiKey: + """Test cases for delete_user_api_key endpoint.""" + + @pytest.fixture + def valid_org_id(self): + """Return a valid org ID.""" + return uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081') + + @pytest.mark.asyncio + async def test_successful_delete(self, valid_org_id): + """Test successful deletion of a system API key.""" + mock_api_key_store = MagicMock() + mock_api_key_store.make_system_key_name.return_value = '__SYSTEM__:automation' + mock_api_key_store.delete_api_key_by_name = AsyncMock(return_value=True) + + with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'): + with patch( + 'server.routes.service.ApiKeyStore.get_instance' + ) as mock_get_store: + mock_get_store.return_value = mock_api_key_store + + response = await delete_user_api_key( + user_id='user-123', + org_id=valid_org_id, + key_name='automation', + x_service_api_key='test-key', + ) + + assert response == {'message': 'API key deleted successfully'} + + # Verify the store was called with correct arguments + mock_api_key_store.make_system_key_name.assert_called_once_with('automation') + mock_api_key_store.delete_api_key_by_name.assert_called_once_with( + user_id='user-123', + org_id=valid_org_id, + name='__SYSTEM__:automation', + allow_system=True, + ) + + @pytest.mark.asyncio + async def test_delete_key_not_found(self, valid_org_id): + """Test error when key to delete is not found.""" + mock_api_key_store = MagicMock() + mock_api_key_store.make_system_key_name.return_value = '__SYSTEM__:nonexistent' + mock_api_key_store.delete_api_key_by_name = AsyncMock(return_value=False) + + with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'): + with patch( + 'server.routes.service.ApiKeyStore.get_instance' + ) as mock_get_store: + mock_get_store.return_value = mock_api_key_store + + with pytest.raises(HTTPException) as exc_info: + await delete_user_api_key( + user_id='user-123', + org_id=valid_org_id, + key_name='nonexistent', + x_service_api_key='test-key', + ) + + assert exc_info.value.status_code == 404 + assert 'not found' in exc_info.value.detail + + @pytest.mark.asyncio + async def test_delete_invalid_service_key(self, valid_org_id): + """Test error when service API key is invalid.""" + with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'): + with pytest.raises(HTTPException) as exc_info: + await delete_user_api_key( + user_id='user-123', + org_id=valid_org_id, + key_name='automation', + x_service_api_key='wrong-key', + ) + + assert exc_info.value.status_code == 401 + assert 'Invalid service API key' in exc_info.value.detail + + @pytest.mark.asyncio + async def test_delete_missing_service_key(self, valid_org_id): + """Test error when service API key header is missing.""" + with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'): + with pytest.raises(HTTPException) as exc_info: + await delete_user_api_key( + user_id='user-123', + org_id=valid_org_id, + key_name='automation', + x_service_api_key=None, + ) + + assert exc_info.value.status_code == 401 + assert 'X-Service-API-Key header is required' in exc_info.value.detail diff --git a/enterprise/tests/unit/storage/test_api_key_store.py b/enterprise/tests/unit/storage/test_api_key_store.py new file mode 100644 index 0000000000..0db2d8bb96 --- /dev/null +++ b/enterprise/tests/unit/storage/test_api_key_store.py @@ -0,0 +1,314 @@ +"""Unit tests for ApiKeyStore system key functionality.""" + +import uuid +from datetime import UTC, datetime, timedelta +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from sqlalchemy import select +from storage.api_key import ApiKey +from storage.api_key_store import ApiKeyStore + + +@pytest.fixture +def api_key_store(): + """Create ApiKeyStore instance.""" + return ApiKeyStore() + + +class TestApiKeyStoreSystemKeys: + """Test cases for system API key functionality.""" + + def test_is_system_key_name_with_prefix(self, api_key_store): + """Test that names with __SYSTEM__: prefix are identified as system keys.""" + assert api_key_store.is_system_key_name('__SYSTEM__:automation') is True + assert api_key_store.is_system_key_name('__SYSTEM__:test-key') is True + assert api_key_store.is_system_key_name('__SYSTEM__:') is True + + def test_is_system_key_name_without_prefix(self, api_key_store): + """Test that names without __SYSTEM__: prefix are not system keys.""" + assert api_key_store.is_system_key_name('my-key') is False + assert api_key_store.is_system_key_name('automation') is False + assert api_key_store.is_system_key_name('MCP_API_KEY') is False + assert api_key_store.is_system_key_name('') is False + + def test_is_system_key_name_none(self, api_key_store): + """Test that None is not a system key.""" + assert api_key_store.is_system_key_name(None) is False + + def test_make_system_key_name(self, api_key_store): + """Test system key name generation.""" + assert ( + api_key_store.make_system_key_name('automation') == '__SYSTEM__:automation' + ) + assert api_key_store.make_system_key_name('test-key') == '__SYSTEM__:test-key' + + @pytest.mark.asyncio + async def test_get_or_create_system_api_key_creates_new( + self, api_key_store, async_session_maker + ): + """Test creating a new system API key when none exists.""" + user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081' + org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081') + key_name = 'automation' + + with patch('storage.api_key_store.a_session_maker', async_session_maker): + api_key = await api_key_store.get_or_create_system_api_key( + user_id=user_id, + org_id=org_id, + name=key_name, + ) + + assert api_key.startswith('sk-oh-') + assert len(api_key) == len('sk-oh-') + 32 + + # Verify the key was created in the database + async with async_session_maker() as session: + result = await session.execute(select(ApiKey).filter(ApiKey.key == api_key)) + key_record = result.scalars().first() + assert key_record is not None + assert key_record.user_id == user_id + assert key_record.org_id == org_id + assert key_record.name == '__SYSTEM__:automation' + assert key_record.expires_at is None # System keys never expire + + @pytest.mark.asyncio + async def test_get_or_create_system_api_key_returns_existing( + self, api_key_store, async_session_maker + ): + """Test that existing valid system key is returned.""" + user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081' + org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081') + key_name = 'automation' + + with patch('storage.api_key_store.a_session_maker', async_session_maker): + # Create the first key + first_key = await api_key_store.get_or_create_system_api_key( + user_id=user_id, + org_id=org_id, + name=key_name, + ) + + # Request again - should return the same key + second_key = await api_key_store.get_or_create_system_api_key( + user_id=user_id, + org_id=org_id, + name=key_name, + ) + + assert first_key == second_key + + @pytest.mark.asyncio + async def test_get_or_create_system_api_key_different_names( + self, api_key_store, async_session_maker + ): + """Test that different names create different keys.""" + user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081' + org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081') + + with patch('storage.api_key_store.a_session_maker', async_session_maker): + key1 = await api_key_store.get_or_create_system_api_key( + user_id=user_id, + org_id=org_id, + name='automation-1', + ) + + key2 = await api_key_store.get_or_create_system_api_key( + user_id=user_id, + org_id=org_id, + name='automation-2', + ) + + assert key1 != key2 + + @pytest.mark.asyncio + async def test_get_or_create_system_api_key_reissues_expired( + self, api_key_store, async_session_maker + ): + """Test that expired system key is replaced with a new one.""" + user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081' + org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081') + key_name = 'automation' + system_key_name = '__SYSTEM__:automation' + + # First, manually create an expired key + expired_time = datetime.now(UTC) - timedelta(hours=1) + async with async_session_maker() as session: + expired_key = ApiKey( + key='sk-oh-expired-key-12345678901234567890', + user_id=user_id, + org_id=org_id, + name=system_key_name, + expires_at=expired_time.replace(tzinfo=None), + ) + session.add(expired_key) + await session.commit() + + with patch('storage.api_key_store.a_session_maker', async_session_maker): + # Request the key - should create a new one + new_key = await api_key_store.get_or_create_system_api_key( + user_id=user_id, + org_id=org_id, + name=key_name, + ) + + assert new_key != 'sk-oh-expired-key-12345678901234567890' + assert new_key.startswith('sk-oh-') + + # Verify old key was deleted and new key exists + async with async_session_maker() as session: + result = await session.execute( + select(ApiKey).filter(ApiKey.name == system_key_name) + ) + keys = result.scalars().all() + assert len(keys) == 1 + assert keys[0].key == new_key + assert keys[0].expires_at is None + + @pytest.mark.asyncio + async def test_list_api_keys_excludes_system_keys( + self, api_key_store, async_session_maker + ): + """Test that list_api_keys excludes system keys.""" + user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081' + org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081') + + # Create a user key and a system key + async with async_session_maker() as session: + user_key = ApiKey( + key='sk-oh-user-key-123456789012345678901', + user_id=user_id, + org_id=org_id, + name='my-user-key', + ) + system_key = ApiKey( + key='sk-oh-system-key-12345678901234567890', + user_id=user_id, + org_id=org_id, + name='__SYSTEM__:automation', + ) + mcp_key = ApiKey( + key='sk-oh-mcp-key-1234567890123456789012', + user_id=user_id, + org_id=org_id, + name='MCP_API_KEY', + ) + session.add(user_key) + session.add(system_key) + session.add(mcp_key) + await session.commit() + + # Mock UserStore.get_user_by_id to return a user with the correct org + mock_user = MagicMock() + mock_user.current_org_id = org_id + + with patch('storage.api_key_store.a_session_maker', async_session_maker): + with patch( + 'storage.api_key_store.UserStore.get_user_by_id', new_callable=AsyncMock + ) as mock_get_user: + mock_get_user.return_value = mock_user + keys = await api_key_store.list_api_keys(user_id) + + # Should only return the user key + assert len(keys) == 1 + assert keys[0].name == 'my-user-key' + + @pytest.mark.asyncio + async def test_delete_api_key_by_id_protects_system_keys( + self, api_key_store, async_session_maker + ): + """Test that system keys cannot be deleted by users.""" + user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081' + org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081') + + # Create a system key + async with async_session_maker() as session: + system_key = ApiKey( + key='sk-oh-system-key-12345678901234567890', + user_id=user_id, + org_id=org_id, + name='__SYSTEM__:automation', + ) + session.add(system_key) + await session.commit() + key_id = system_key.id + + with patch('storage.api_key_store.a_session_maker', async_session_maker): + # Attempt to delete without allow_system flag + result = await api_key_store.delete_api_key_by_id( + key_id, allow_system=False + ) + + assert result is False + + # Verify the key still exists + async with async_session_maker() as session: + result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id)) + key_record = result.scalars().first() + assert key_record is not None + + @pytest.mark.asyncio + async def test_delete_api_key_by_id_allows_system_with_flag( + self, api_key_store, async_session_maker + ): + """Test that system keys can be deleted with allow_system=True.""" + user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081' + org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081') + + # Create a system key + async with async_session_maker() as session: + system_key = ApiKey( + key='sk-oh-system-key-12345678901234567890', + user_id=user_id, + org_id=org_id, + name='__SYSTEM__:automation', + ) + session.add(system_key) + await session.commit() + key_id = system_key.id + + with patch('storage.api_key_store.a_session_maker', async_session_maker): + # Delete with allow_system=True + result = await api_key_store.delete_api_key_by_id(key_id, allow_system=True) + + assert result is True + + # Verify the key was deleted + async with async_session_maker() as session: + result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id)) + key_record = result.scalars().first() + assert key_record is None + + @pytest.mark.asyncio + async def test_delete_api_key_by_id_allows_regular_keys( + self, api_key_store, async_session_maker + ): + """Test that regular keys can be deleted normally.""" + user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081' + org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081') + + # Create a regular key + async with async_session_maker() as session: + regular_key = ApiKey( + key='sk-oh-regular-key-1234567890123456789', + user_id=user_id, + org_id=org_id, + name='my-regular-key', + ) + session.add(regular_key) + await session.commit() + key_id = regular_key.id + + with patch('storage.api_key_store.a_session_maker', async_session_maker): + # Delete without allow_system flag - should work for regular keys + result = await api_key_store.delete_api_key_by_id( + key_id, allow_system=False + ) + + assert result is True + + # Verify the key was deleted + async with async_session_maker() as session: + result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id)) + key_record = result.scalars().first() + assert key_record is None