from __future__ import annotations import secrets import string from dataclasses import dataclass from datetime import UTC, datetime from sqlalchemy import select, update from storage.api_key import ApiKey from storage.database import a_session_maker from storage.user_store import UserStore from openhands.core.logger import openhands_logger as logger @dataclass class ApiKeyStore: API_KEY_PREFIX = 'sk-oh-' def generate_api_key(self, length: int = 32) -> str: """Generate a random API key with the sk-oh- prefix.""" alphabet = string.ascii_letters + string.digits random_part = ''.join(secrets.choice(alphabet) for _ in range(length)) return f'{self.API_KEY_PREFIX}{random_part}' async def create_api_key( self, user_id: str, name: str | None = None, expires_at: datetime | None = None ) -> str: """Create a new API key for a user. Args: user_id: The ID of the user to create the key for name: Optional name for the key expires_at: Optional expiration date for the key Returns: The generated API key """ api_key = self.generate_api_key() user = await UserStore.get_user_by_id_async(user_id) org_id = user.current_org_id async with a_session_maker() as session: key_record = ApiKey( key=api_key, user_id=user_id, org_id=org_id, name=name, expires_at=expires_at, ) session.add(key_record) await session.commit() return api_key async def validate_api_key(self, api_key: str) -> str | None: """Validate an API key and return the associated user_id if valid.""" now = datetime.now(UTC) async with a_session_maker() as session: result = await session.execute(select(ApiKey).filter(ApiKey.key == api_key)) key_record = result.scalars().first() if not key_record: return None # Check if the key has expired if key_record.expires_at: # Handle timezone-naive datetime from database by assuming it's UTC expires_at = key_record.expires_at if expires_at.tzinfo is None: expires_at = expires_at.replace(tzinfo=UTC) if expires_at < now: logger.info(f'API key has expired: {key_record.id}') return None # Update last_used_at timestamp await session.execute( update(ApiKey) .where(ApiKey.id == key_record.id) .values(last_used_at=now) ) await session.commit() return key_record.user_id async def delete_api_key(self, api_key: str) -> bool: """Delete an API key by the key value.""" async with a_session_maker() as session: result = await session.execute(select(ApiKey).filter(ApiKey.key == api_key)) key_record = result.scalars().first() if not key_record: return False await session.delete(key_record) await session.commit() return True async def delete_api_key_by_id(self, key_id: int) -> bool: """Delete an API key by its ID.""" async with a_session_maker() as session: result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id)) key_record = result.scalars().first() if not key_record: return False await session.delete(key_record) await session.commit() return True async def list_api_keys(self, user_id: str) -> list[ApiKey]: """List all API keys for a user.""" user = await UserStore.get_user_by_id_async(user_id) org_id = user.current_org_id async with a_session_maker() as session: result = await session.execute( select(ApiKey).filter( ApiKey.user_id == user_id, ApiKey.org_id == org_id ) ) keys = result.scalars().all() return [key for key in keys if key.name != 'MCP_API_KEY'] async def retrieve_mcp_api_key(self, user_id: str) -> str | None: user = await UserStore.get_user_by_id_async(user_id) org_id = user.current_org_id async with a_session_maker() as session: result = await session.execute( select(ApiKey).filter( ApiKey.user_id == user_id, ApiKey.org_id == org_id ) ) keys = result.scalars().all() for key in keys: if key.name == 'MCP_API_KEY': return key.key return None async def retrieve_api_key_by_name(self, user_id: str, name: str) -> str | None: """Retrieve an API key by name for a specific user.""" async with a_session_maker() as session: result = await session.execute( select(ApiKey).filter(ApiKey.user_id == user_id, ApiKey.name == name) ) key_record = result.scalars().first() return key_record.key if key_record else None async def delete_api_key_by_name(self, user_id: str, name: str) -> bool: """Delete an API key by name for a specific user.""" async with a_session_maker() as session: result = await session.execute( select(ApiKey).filter(ApiKey.user_id == user_id, ApiKey.name == name) ) key_record = result.scalars().first() if not key_record: return False await session.delete(key_record) await session.commit() return True @classmethod def get_instance(cls) -> ApiKeyStore: """Get an instance of the ApiKeyStore.""" logger.debug('api_key_store.get_instance') return ApiKeyStore()