mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 13:52:43 +08:00
169 lines
5.3 KiB
Python
169 lines
5.3 KiB
Python
from __future__ import annotations
|
|
|
|
import secrets
|
|
import string
|
|
from dataclasses import dataclass
|
|
from datetime import UTC, datetime
|
|
|
|
from sqlalchemy import update
|
|
from sqlalchemy.orm import sessionmaker
|
|
from storage.api_key import ApiKey
|
|
from storage.database import session_maker
|
|
|
|
from openhands.core.logger import openhands_logger as logger
|
|
|
|
|
|
@dataclass
|
|
class ApiKeyStore:
|
|
session_maker: sessionmaker
|
|
|
|
API_KEY_PREFIX = 'sk-oh-'
|
|
|
|
def generate_api_key(self, length: int = 32) -> str:
|
|
"""Generate a random API key with the sk-oh- prefix."""
|
|
alphabet = string.ascii_letters + string.digits
|
|
random_part = ''.join(secrets.choice(alphabet) for _ in range(length))
|
|
return f'{self.API_KEY_PREFIX}{random_part}'
|
|
|
|
def create_api_key(
|
|
self, user_id: str, name: str | None = None, expires_at: datetime | None = None
|
|
) -> str:
|
|
"""Create a new API key for a user.
|
|
|
|
Args:
|
|
user_id: The ID of the user to create the key for
|
|
name: Optional name for the key
|
|
expires_at: Optional expiration date for the key
|
|
|
|
Returns:
|
|
The generated API key
|
|
"""
|
|
api_key = self.generate_api_key()
|
|
|
|
with self.session_maker() as session:
|
|
key_record = ApiKey(
|
|
key=api_key, user_id=user_id, name=name, expires_at=expires_at
|
|
)
|
|
session.add(key_record)
|
|
session.commit()
|
|
|
|
return api_key
|
|
|
|
def validate_api_key(self, api_key: str) -> str | None:
|
|
"""Validate an API key and return the associated user_id if valid."""
|
|
now = datetime.now(UTC)
|
|
|
|
with self.session_maker() as session:
|
|
key_record = session.query(ApiKey).filter(ApiKey.key == api_key).first()
|
|
|
|
if not key_record:
|
|
return None
|
|
|
|
# Check if the key has expired
|
|
if key_record.expires_at:
|
|
# Handle timezone-naive datetime from database by assuming it's UTC
|
|
expires_at = key_record.expires_at
|
|
if expires_at.tzinfo is None:
|
|
expires_at = expires_at.replace(tzinfo=UTC)
|
|
|
|
if expires_at < now:
|
|
logger.info(f'API key has expired: {key_record.id}')
|
|
return None
|
|
|
|
# Update last_used_at timestamp
|
|
session.execute(
|
|
update(ApiKey)
|
|
.where(ApiKey.id == key_record.id)
|
|
.values(last_used_at=now)
|
|
)
|
|
session.commit()
|
|
|
|
return key_record.user_id
|
|
|
|
def delete_api_key(self, api_key: str) -> bool:
|
|
"""Delete an API key by the key value."""
|
|
with self.session_maker() as session:
|
|
key_record = session.query(ApiKey).filter(ApiKey.key == api_key).first()
|
|
|
|
if not key_record:
|
|
return False
|
|
|
|
session.delete(key_record)
|
|
session.commit()
|
|
|
|
return True
|
|
|
|
def delete_api_key_by_id(self, key_id: int) -> bool:
|
|
"""Delete an API key by its ID."""
|
|
with self.session_maker() as session:
|
|
key_record = session.query(ApiKey).filter(ApiKey.id == key_id).first()
|
|
|
|
if not key_record:
|
|
return False
|
|
|
|
session.delete(key_record)
|
|
session.commit()
|
|
|
|
return True
|
|
|
|
def list_api_keys(self, user_id: str) -> list[dict]:
|
|
"""List all API keys for a user."""
|
|
with self.session_maker() as session:
|
|
keys = session.query(ApiKey).filter(ApiKey.user_id == user_id).all()
|
|
|
|
return [
|
|
{
|
|
'id': key.id,
|
|
'name': key.name,
|
|
'created_at': key.created_at,
|
|
'last_used_at': key.last_used_at,
|
|
'expires_at': key.expires_at,
|
|
}
|
|
for key in keys
|
|
if 'MCP_API_KEY' != key.name
|
|
]
|
|
|
|
def retrieve_mcp_api_key(self, user_id: str) -> str | None:
|
|
with self.session_maker() as session:
|
|
keys: list[ApiKey] = (
|
|
session.query(ApiKey).filter(ApiKey.user_id == user_id).all()
|
|
)
|
|
for key in keys:
|
|
if key.name == 'MCP_API_KEY':
|
|
return key.key
|
|
|
|
return None
|
|
|
|
def retrieve_api_key_by_name(self, user_id: str, name: str) -> str | None:
|
|
"""Retrieve an API key by name for a specific user."""
|
|
with self.session_maker() as session:
|
|
key_record = (
|
|
session.query(ApiKey)
|
|
.filter(ApiKey.user_id == user_id, ApiKey.name == name)
|
|
.first()
|
|
)
|
|
return key_record.key if key_record else None
|
|
|
|
def delete_api_key_by_name(self, user_id: str, name: str) -> bool:
|
|
"""Delete an API key by name for a specific user."""
|
|
with self.session_maker() as session:
|
|
key_record = (
|
|
session.query(ApiKey)
|
|
.filter(ApiKey.user_id == user_id, ApiKey.name == name)
|
|
.first()
|
|
)
|
|
|
|
if not key_record:
|
|
return False
|
|
|
|
session.delete(key_record)
|
|
session.commit()
|
|
|
|
return True
|
|
|
|
@classmethod
|
|
def get_instance(cls) -> ApiKeyStore:
|
|
"""Get an instance of the ApiKeyStore."""
|
|
logger.debug('api_key_store.get_instance')
|
|
return ApiKeyStore(session_maker)
|