# File: OpenHands/enterprise/storage/api_key_store.py
# 2026-03-03 17:51:53 -07:00
# 183 lines, 6.2 KiB, Python

from __future__ import annotations
import secrets
import string
from dataclasses import dataclass
from datetime import UTC, datetime
from sqlalchemy import select, update
from storage.api_key import ApiKey
from storage.database import a_session_maker
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
@dataclass
class ApiKeyStore:
API_KEY_PREFIX = 'sk-oh-'
def generate_api_key(self, length: int = 32) -> str:
"""Generate a random API key with the sk-oh- prefix."""
alphabet = string.ascii_letters + string.digits
random_part = ''.join(secrets.choice(alphabet) for _ in range(length))
return f'{self.API_KEY_PREFIX}{random_part}'
async def create_api_key(
self, user_id: str, name: str | None = None, expires_at: datetime | None = None
) -> str:
"""Create a new API key for a user.
Args:
user_id: The ID of the user to create the key for
name: Optional name for the key
expires_at: Optional expiration date for the key
Returns:
The generated API key
"""
api_key = self.generate_api_key()
user = await UserStore.get_user_by_id(user_id)
if user is None:
raise ValueError(f'User not found: {user_id}')
org_id = user.current_org_id
async with a_session_maker() as session:
key_record = ApiKey(
key=api_key,
user_id=user_id,
org_id=org_id,
name=name,
expires_at=expires_at,
)
session.add(key_record)
await session.commit()
return api_key
async def validate_api_key(self, api_key: str) -> str | None:
"""Validate an API key and return the associated user_id if valid."""
now = datetime.now(UTC)
async with a_session_maker() as session:
result = await session.execute(select(ApiKey).filter(ApiKey.key == api_key))
key_record = result.scalars().first()
if not key_record:
return None
# Check if the key has expired
if key_record.expires_at:
# Handle timezone-naive datetime from database by assuming it's UTC
expires_at = key_record.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=UTC)
if expires_at < now:
logger.info(f'API key has expired: {key_record.id}')
return None
# Update last_used_at timestamp
await session.execute(
update(ApiKey)
.where(ApiKey.id == key_record.id)
.values(last_used_at=now.replace(tzinfo=None))
)
await session.commit()
return key_record.user_id
async def delete_api_key(self, api_key: str) -> bool:
"""Delete an API key by the key value."""
async with a_session_maker() as session:
result = await session.execute(select(ApiKey).filter(ApiKey.key == api_key))
key_record = result.scalars().first()
if not key_record:
return False
await session.delete(key_record)
await session.commit()
return True
async def delete_api_key_by_id(self, key_id: int) -> bool:
"""Delete an API key by its ID."""
async with a_session_maker() as session:
result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id))
key_record = result.scalars().first()
if not key_record:
return False
await session.delete(key_record)
await session.commit()
return True
async def list_api_keys(self, user_id: str) -> list[ApiKey]:
"""List all API keys for a user."""
user = await UserStore.get_user_by_id(user_id)
if user is None:
raise ValueError(f'User not found: {user_id}')
org_id = user.current_org_id
async with a_session_maker() as session:
result = await session.execute(
select(ApiKey).filter(
ApiKey.user_id == user_id, ApiKey.org_id == org_id
)
)
keys = result.scalars().all()
return [key for key in keys if key.name != 'MCP_API_KEY']
async def retrieve_mcp_api_key(self, user_id: str) -> str | None:
user = await UserStore.get_user_by_id(user_id)
if user is None:
raise ValueError(f'User not found: {user_id}')
org_id = user.current_org_id
async with a_session_maker() as session:
result = await session.execute(
select(ApiKey).filter(
ApiKey.user_id == user_id, ApiKey.org_id == org_id
)
)
keys = result.scalars().all()
for key in keys:
if key.name == 'MCP_API_KEY':
return key.key
return None
async def retrieve_api_key_by_name(self, user_id: str, name: str) -> str | None:
"""Retrieve an API key by name for a specific user."""
async with a_session_maker() as session:
result = await session.execute(
select(ApiKey).filter(ApiKey.user_id == user_id, ApiKey.name == name)
)
key_record = result.scalars().first()
return key_record.key if key_record else None
async def delete_api_key_by_name(self, user_id: str, name: str) -> bool:
"""Delete an API key by name for a specific user."""
async with a_session_maker() as session:
result = await session.execute(
select(ApiKey).filter(ApiKey.user_id == user_id, ApiKey.name == name)
)
key_record = result.scalars().first()
if not key_record:
return False
await session.delete(key_record)
await session.commit()
return True
@classmethod
def get_instance(cls) -> ApiKeyStore:
"""Get an instance of the ApiKeyStore."""
logger.debug('api_key_store.get_instance')
return ApiKeyStore()