Files
OpenHands/enterprise/storage/saas_settings_store.py
2026-03-16 05:19:31 -06:00

325 lines
12 KiB
Python

from __future__ import annotations
import binascii
import hashlib
import uuid
from base64 import b64decode, b64encode
from dataclasses import dataclass
from cryptography.fernet import Fernet
from pydantic import SecretStr
from server.auth.token_manager import TokenManager
from server.constants import LITE_LLM_API_URL
from server.logger import logger
from sqlalchemy import select, update
from sqlalchemy.orm import joinedload
from storage.database import a_session_maker
from storage.encrypt_utils import encrypt_value
from storage.lite_llm_manager import LiteLlmManager, get_openhands_cloud_key_alias
from storage.org import Org
from storage.org_member import OrgMember
from storage.org_store import OrgStore
from storage.user import User
from storage.user_settings import UserSettings
from storage.user_store import UserStore
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.server.settings import Settings
from openhands.storage.settings.settings_store import SettingsStore
from openhands.utils.llm import is_openhands_model
@dataclass
class SaasSettingsStore(SettingsStore):
    """Settings store backed by the SaaS ``User`` / ``Org`` / ``OrgMember`` tables.

    A user's ``Settings`` are assembled from their current org, their user row
    and their membership row in that org. Saving writes matching fields back
    to all three and propagates the LLM configuration to every member of the
    org.
    """

    # User identifier: parsed as a UUID to match ``User.id`` in ``store`` and
    # also used as the keycloak user ID for legacy ``UserSettings`` lookups.
    user_id: str
    # Server config; ``jwt_secret`` is used to derive the Fernet key.
    config: OpenHandsConfig
    # Not annotated, so this is a class attribute rather than a dataclass
    # field: keys whose values the (de)crypt helpers transform.
    ENCRYPT_VALUES = ['llm_api_key', 'llm_api_key_for_byor', 'search_api_key']

    async def _get_user_settings_by_keycloak_id_async(
        self, keycloak_user_id: str, session=None
    ) -> UserSettings | None:
        """
        Get UserSettings by keycloak_user_id (async version).

        Args:
            keycloak_user_id: The keycloak user ID to search for
            session: Optional existing async database session. If not provided,
                a new one is created and closed before returning.

        Returns:
            UserSettings object if found (or None for an empty ID), None otherwise
        """
        if not keycloak_user_id:
            return None
        stmt = select(UserSettings).filter(
            UserSettings.keycloak_user_id == keycloak_user_id
        )
        if session:
            result = await session.execute(stmt)
            return result.scalars().first()
        async with a_session_maker() as new_session:
            result = await new_session.execute(stmt)
            return result.scalars().first()

    @staticmethod
    def _find_org_member(user: User, org_id) -> OrgMember | None:
        """Return ``user``'s membership row for ``org_id``, or None if absent."""
        return next((om for om in user.org_members if om.org_id == org_id), None)

    async def _load_user_with_members(self, session) -> User | None:
        """Fetch this store's ``User`` in ``session`` with ``org_members`` eagerly loaded."""
        result = await session.execute(
            select(User)
            .options(joinedload(User.org_members))
            .filter(User.id == uuid.UUID(self.user_id))
        )
        # SQLAlchemy requires unique() on 2.0-style results that joined-eager-load
        # a collection; it only de-duplicates the parent User rows.
        return result.unique().scalars().first()

    async def _migrate_legacy_user(self, session) -> User | None:
        """Migrate a user that only exists in the legacy ``user_settings`` table.

        Returns the migrated ``User`` (re-loaded into ``session`` when visible
        there) or None after logging why migration was not possible.
        """
        # No session is passed: the helper uses its own short-lived session, so
        # the returned row is not bound to ``session`` when handed to migrate_user.
        user_settings = await self._get_user_settings_by_keycloak_id_async(
            self.user_id
        )
        if not user_settings:
            logger.error(f'User not found for ID {self.user_id}')
            return None
        token_manager = TokenManager()
        user_info = await token_manager.get_user_info_from_user_id(self.user_id)
        if not user_info:
            logger.error(f'User info not found for ID {self.user_id}')
            return None
        migrated = await UserStore.migrate_user(self.user_id, user_settings, user_info)
        if not migrated:
            logger.error(f'Failed to migrate user {self.user_id}')
            return None
        # The migrated object was not loaded through ``session``, so attribute
        # changes on it would not be flushed by ``session.commit()``. Re-load it
        # here; fall back to the migrated object if it is not visible yet.
        reloaded = await self._load_user_with_members(session)
        return reloaded or migrated

    async def load(self) -> Settings | None:
        """Build ``Settings`` for the user's current org.

        Values come from the ``Org`` columns (``default_``/``_default_``/``_``
        prefixes stripped), overlaid by matching ``User`` columns, overlaid by
        non-empty per-member values on the ``OrgMember`` row.

        Returns:
            The settings, or None when the user is not found (logged), has no
            membership in their current org or that membership has no
            ``llm_api_key`` (not logged), or the org is not found (logged).
        """
        user = await UserStore.get_user_by_id(self.user_id)
        if not user:
            logger.error(f'User not found for ID {self.user_id}')
            return None
        org_id = user.current_org_id
        org_member = self._find_org_member(user, org_id)
        if not org_member or not org_member.llm_api_key:
            return None
        org = await OrgStore.get_org_by_id_async(org_id)
        if not org:
            logger.error(
                f'Org not found for ID {org_id} as the current org for user {self.user_id}'
            )
            return None
        kwargs = {
            **{
                normalized: getattr(org, c.name)
                for c in Org.__table__.columns
                if (
                    normalized := c.name.removeprefix('_default_')
                    .removeprefix('default_')
                    .lstrip('_')
                )
                in Settings.model_fields
            },
            # User columns are merged second, so they win over org columns.
            **{
                normalized: getattr(user, c.name)
                for c in User.__table__.columns
                if (normalized := c.name.lstrip('_')) in Settings.model_fields
            },
        }
        kwargs['llm_api_key'] = org_member.llm_api_key
        # Per-member values override org/user values only when truthy.
        for field in (
            'max_iterations',
            'llm_model',
            'llm_api_key_for_byor',
            'llm_base_url',
        ):
            member_value = getattr(org_member, field)
            if member_value:
                kwargs[field] = member_value
        if org.v1_enabled is None:
            kwargs['v1_enabled'] = True
        # Drop a NULL sandbox_grouping_strategy so the Settings default applies.
        if kwargs.get('sandbox_grouping_strategy') is None:
            kwargs.pop('sandbox_grouping_strategy', None)
        return Settings(**kwargs)

    async def store(self, item: Settings):
        """Persist ``item`` for the user, their current org and its members.

        If no ``User`` row exists, a migration from the legacy ``user_settings``
        table is attempted first. Nothing is written (and None is returned) when
        ``item`` is falsy, the user cannot be found or migrated, the user has no
        membership in their current org or it has no ``llm_api_key``, or the org
        is not found.

        Otherwise: a LiteLLM key may be (re)generated when ``llm_base_url`` is
        the LiteLLM proxy URL. Matching fields are copied onto the ``User``,
        ``Org`` and ``OrgMember`` objects. Explicitly-set LLM fields are mirrored
        to the org's ``default_*`` columns. The non-None LLM fields are bulk
        updated on every ``OrgMember`` of the org. The session is then committed.
        """
        if not item:
            return None
        async with a_session_maker() as session:
            user = await self._load_user_with_members(session)
            if not user:
                user = await self._migrate_legacy_user(session)
                if not user:
                    return None
            org_id = user.current_org_id
            org_member = self._find_org_member(user, org_id)
            if not org_member or not org_member.llm_api_key:
                return None
            result = await session.execute(select(Org).filter(Org.id == org_id))
            org = result.scalars().first()
            if not org:
                logger.error(
                    f'Org not found for ID {org_id} as the current org for user {self.user_id}'
                )
                return None

            # Check if we need to generate an LLM key.
            if item.llm_base_url == LITE_LLM_API_URL:
                await self._ensure_api_key(
                    item, str(org_id), openhands_type=is_openhands_model(item.llm_model)
                )

            kwargs = item.model_dump(context={'expose_secrets': True})
            for model in (user, org, org_member):
                for key, value in kwargs.items():
                    if hasattr(model, key):
                        setattr(model, key, value)

            # Org stores these as 'default_*' columns, which the generic copy
            # above does not reach. exclude_unset limits this to fields the
            # caller set explicitly (so an explicit null clears the default).
            settings_data = item.model_dump(exclude_unset=True)
            for settings_field, org_field in (
                ('llm_model', 'default_llm_model'),
                ('llm_base_url', 'default_llm_base_url'),
                ('max_iterations', 'default_max_iterations'),
            ):
                if settings_field in settings_data:
                    setattr(org, org_field, settings_data[settings_field])

            # Propagate LLM settings to all org members so every member sees the
            # same LLM configuration when an admin saves.
            # Note: concurrent saves by multiple admins are last-write-wins;
            # consider optimistic locking if this becomes a problem.
            member_update_values: dict = {}
            for field in ('llm_model', 'llm_base_url', 'max_iterations'):
                value = getattr(item, field)
                if value is not None:
                    member_update_values[field] = value
            if item.llm_api_key is not None:
                # Bulk UPDATE bypasses ORM attribute handling, so write the
                # encrypted column directly.
                member_update_values['_llm_api_key'] = encrypt_value(
                    item.llm_api_key.get_secret_value()
                )
            if member_update_values:
                await session.execute(
                    update(OrgMember)
                    .where(OrgMember.org_id == org_id)
                    .values(**member_update_values)
                )
            await session.commit()

    @classmethod
    async def get_instance(
        cls,
        config: OpenHandsConfig,
        user_id: str,  # type: ignore[override]
    ) -> SaasSettingsStore:
        """Create a store bound to ``user_id``."""
        logger.debug(f'saas_settings_store.get_instance::{user_id}')
        return SaasSettingsStore(user_id, config)

    def _should_encrypt(self, key):
        """Return True if ``key`` names a value that is stored encrypted."""
        return key in self.ENCRYPT_VALUES

    def _decrypt_kwargs(self, kwargs: dict):
        """Decrypt, in place, the values of ``kwargs`` listed in ENCRYPT_VALUES.

        Values are expected as base64-encoded Fernet tokens (the format written
        by ``_encrypt_kwargs``); ``SecretStr`` values are unwrapped first. A
        value that fails base64 decoding is left unchanged (presumably a legacy,
        unencrypted value — confirm). Other Fernet errors propagate.
        """
        fernet = self._fernet()
        for key, value in kwargs.items():
            if value is None or not self._should_encrypt(key):
                continue
            raw = value.get_secret_value() if isinstance(value, SecretStr) else value
            try:
                kwargs[key] = fernet.decrypt(b64decode(raw.encode())).decode()
            except binascii.Error:
                pass  # Key is in legacy format...

    def _encrypt_kwargs(self, kwargs: dict):
        """Encrypt, in place, the values of ``kwargs`` listed in ENCRYPT_VALUES.

        Nested dicts are processed recursively. Each value (``SecretStr``
        unwrapped) becomes a base64-encoded Fernet token string.
        """
        fernet = self._fernet()
        for key, value in kwargs.items():
            if value is None:
                continue
            if isinstance(value, dict):
                self._encrypt_kwargs(value)
                continue
            if not self._should_encrypt(key):
                continue
            raw = value.get_secret_value() if isinstance(value, SecretStr) else value
            kwargs[key] = b64encode(fernet.encrypt(raw.encode())).decode()

    def _fernet(self):
        """Build a Fernet cipher keyed by SHA-256 of ``config.jwt_secret``.

        Raises:
            ValueError: if ``config.jwt_secret`` is not set.
        """
        if not self.config.jwt_secret:
            raise ValueError('jwt_secret must be defined on config')
        jwt_secret = self.config.jwt_secret.get_secret_value()
        # Fernet needs a url-safe base64 32-byte key; a SHA-256 digest is 32 bytes.
        fernet_key = b64encode(hashlib.sha256(jwt_secret.encode()).digest())
        return Fernet(fernet_key)

    async def _ensure_api_key(
        self, item: Settings, org_id: str, openhands_type: bool = False
    ) -> None:
        """Replace ``item.llm_api_key`` with a new LiteLLM key if the current one is invalid.

        When ``item.llm_api_key`` is set but ``LiteLlmManager.verify_existing_key``
        rejects it, a new key is generated and stored on ``item``. A key that
        verifies is kept as-is. NOTE(review): when ``item.llm_api_key`` is empty
        or None nothing is generated — confirm this is intended.

        For OpenHands-type models the key is generated without an alias and
        with ``{'type': 'openhands'}``; otherwise any existing key with this
        user/org's cloud alias is deleted first and the key is created under
        that alias.
        """
        current_key = item.llm_api_key
        if not current_key:
            return
        if await LiteLlmManager.verify_existing_key(
            current_key.get_secret_value(),
            self.user_id,
            org_id,
            openhands_type=openhands_type,
        ):
            return
        if openhands_type:
            generated_key = await LiteLlmManager.generate_key(
                self.user_id,
                org_id,
                None,
                {'type': 'openhands'},
            )
        else:
            # Must delete any existing key with the same alias first
            key_alias = get_openhands_cloud_key_alias(self.user_id, org_id)
            await LiteLlmManager.delete_key_by_alias(key_alias=key_alias)
            generated_key = await LiteLlmManager.generate_key(
                self.user_id,
                org_id,
                key_alias,
                None,
            )
        item.llm_api_key = SecretStr(generated_key)
        logger.info(
            'saas_settings_store:store:generated_openhands_key',
            extra={'user_id': self.user_id},
        )