From f3d9faef349a5ff0ba3c4c0404d7fd15da95fe67 Mon Sep 17 00:00:00 2001 From: Rohit Malhotra Date: Thu, 23 Oct 2025 09:56:55 -0400 Subject: [PATCH] SAAS: dedup fetching user settings from keycloak id (#11480) Co-authored-by: openhands --- enterprise/integrations/github/github_view.py | 23 ++++---- enterprise/server/routes/api_keys.py | 35 ++++++------ enterprise/server/routes/auth.py | 21 ++++---- enterprise/server/routes/billing.py | 14 +++-- enterprise/storage/saas_settings_store.py | 53 ++++++++++++++----- .../test_proactive_conversation_starters.py | 4 +- 6 files changed, 89 insertions(+), 61 deletions(-) diff --git a/enterprise/integrations/github/github_view.py b/enterprise/integrations/github/github_view.py index 208ad12365..435dec8b3f 100644 --- a/enterprise/integrations/github/github_view.py +++ b/enterprise/integrations/github/github_view.py @@ -24,7 +24,7 @@ from server.config import get_config from storage.database import session_maker from storage.proactive_conversation_store import ProactiveConversationStore from storage.saas_secrets_store import SaasSecretsStore -from storage.user_settings import UserSettings +from storage.saas_settings_store import SaasSettingsStore from openhands.core.logger import openhands_logger as logger from openhands.integrations.github.github_service import GithubServiceImpl @@ -61,20 +61,19 @@ async def get_user_proactive_conversation_setting(user_id: str | None) -> bool: if not user_id: return False - def _get_setting(): - with session_maker() as session: - settings = ( - session.query(UserSettings) - .filter(UserSettings.keycloak_user_id == user_id) - .first() - ) + config = get_config() + settings_store = SaasSettingsStore( + user_id=user_id, session_maker=session_maker, config=config + ) - if not settings or settings.enable_proactive_conversation_starters is None: - return False + settings = await call_sync_from_async( + settings_store.get_user_settings_by_keycloak_id, user_id + ) - return settings.enable_proactive_conversation_starters + if not settings or settings.enable_proactive_conversation_starters is None: + return False - return await call_sync_from_async(_get_setting) + return settings.enable_proactive_conversation_starters # ================================================= diff --git a/enterprise/server/routes/api_keys.py b/enterprise/server/routes/api_keys.py index 95ea8e4ec6..defa82c7d6 100644 --- a/enterprise/server/routes/api_keys.py +++ b/enterprise/server/routes/api_keys.py @@ -3,10 +3,11 @@ from datetime import UTC, datetime import httpx from fastapi import APIRouter, Depends, HTTPException, status from pydantic import BaseModel, field_validator +from server.config import get_config from server.constants import LITE_LLM_API_KEY, LITE_LLM_API_URL from storage.api_key_store import ApiKeyStore from storage.database import session_maker -from storage.user_settings import UserSettings +from storage.saas_settings_store import SaasSettingsStore from openhands.core.logger import openhands_logger as logger from openhands.server.user_auth import get_user_id @@ -16,30 +17,30 @@ from openhands.utils.async_utils import call_sync_from_async # Helper functions for BYOR API key management async def get_byor_key_from_db(user_id: str) -> str | None: """Get the BYOR key from the database for a user.""" + config = get_config() + settings_store = SaasSettingsStore( + user_id=user_id, session_maker=session_maker, config=config + ) - def _get_byor_key(): - with session_maker() as session: - user_db_settings = ( - session.query(UserSettings) - .filter(UserSettings.keycloak_user_id == user_id) - .first() - ) - if user_db_settings and user_db_settings.llm_api_key_for_byor: - return user_db_settings.llm_api_key_for_byor - return None - - return await call_sync_from_async(_get_byor_key) + user_db_settings = await call_sync_from_async( + settings_store.get_user_settings_by_keycloak_id, user_id + ) + if user_db_settings and user_db_settings.llm_api_key_for_byor: + return user_db_settings.llm_api_key_for_byor + return None async def store_byor_key_in_db(user_id: str, key: str) -> None: """Store the BYOR key in the database for a user.""" + config = get_config() + settings_store = SaasSettingsStore( + user_id=user_id, session_maker=session_maker, config=config + ) def _update_user_settings(): with session_maker() as session: - user_db_settings = ( - session.query(UserSettings) - .filter(UserSettings.keycloak_user_id == user_id) - .first() + user_db_settings = settings_store.get_user_settings_by_keycloak_id( + user_id, session ) if user_db_settings: user_db_settings.llm_api_key_for_byor = key diff --git a/enterprise/server/routes/auth.py b/enterprise/server/routes/auth.py index e6fa3e7254..d5f5cbd1ed 100644 --- a/enterprise/server/routes/auth.py +++ b/enterprise/server/routes/auth.py @@ -16,10 +16,11 @@ from server.auth.constants import ( from server.auth.gitlab_sync import schedule_gitlab_repo_sync from server.auth.saas_user_auth import SaasUserAuth from server.auth.token_manager import TokenManager -from server.config import sign_token +from server.config import get_config, sign_token from server.constants import IS_FEATURE_ENV from server.routes.event_webhook import _get_session_api_key, _get_user_id from storage.database import session_maker +from storage.saas_settings_store import SaasSettingsStore from storage.user_settings import UserSettings from openhands.core.logger import openhands_logger as logger @@ -212,16 +213,14 @@ async def keycloak_callback( f'&state={state}' ) - has_accepted_tos = False - with session_maker() as session: - user_settings = ( - session.query(UserSettings) - .filter(UserSettings.keycloak_user_id == user_id) - .first() - ) - has_accepted_tos = ( - user_settings is not None and user_settings.accepted_tos is not None - ) + config = get_config() + settings_store = SaasSettingsStore( + user_id=user_id, session_maker=session_maker, config=config + ) + user_settings = settings_store.get_user_settings_by_keycloak_id(user_id) + has_accepted_tos = ( + user_settings is not None and user_settings.accepted_tos is not None + ) # If the user hasn't accepted the TOS, redirect to the TOS page if not has_accepted_tos: diff --git a/enterprise/server/routes/billing.py b/enterprise/server/routes/billing.py index b1a6fc96fb..2ab046eeb0 100644 --- a/enterprise/server/routes/billing.py +++ b/enterprise/server/routes/billing.py @@ -11,6 +11,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import JSONResponse, RedirectResponse from integrations import stripe_service from pydantic import BaseModel +from server.config import get_config from server.constants import ( LITE_LLM_API_KEY, LITE_LLM_API_URL, @@ -22,8 +23,8 @@ from server.constants import ( from server.logger import logger from storage.billing_session import BillingSession from storage.database import session_maker +from storage.saas_settings_store import SaasSettingsStore from storage.subscription_access import SubscriptionAccess -from storage.user_settings import UserSettings from openhands.server.user_auth import get_user_id @@ -617,11 +618,14 @@ async def stripe_webhook(request: Request) -> JSONResponse: def reset_user_to_free_tier_settings(user_id: str) -> None: """Reset user settings to free tier defaults when subscription ends.""" + config = get_config() + settings_store = SaasSettingsStore( + user_id=user_id, session_maker=session_maker, config=config + ) + with session_maker() as session: - user_settings = ( - session.query(UserSettings) - .filter(UserSettings.keycloak_user_id == user_id) - .first() + user_settings = settings_store.get_user_settings_by_keycloak_id( + user_id, session ) if user_settings: diff --git a/enterprise/storage/saas_settings_store.py b/enterprise/storage/saas_settings_store.py index 0b5d40fe2c..719a45c49d 100644 --- a/enterprise/storage/saas_settings_store.py +++ b/enterprise/storage/saas_settings_store.py @@ -39,15 +39,46 @@ class SaasSettingsStore(SettingsStore): session_maker: sessionmaker config: OpenHandsConfig + def get_user_settings_by_keycloak_id( + self, keycloak_user_id: str, session=None + ) -> UserSettings | None: + """ + Get UserSettings by keycloak_user_id. + + Args: + keycloak_user_id: The keycloak user ID to search for + session: Optional existing database session. If not provided, creates a new one. + + Returns: + UserSettings object if found, None otherwise + """ + if not keycloak_user_id: + return None + + def _get_settings(): + if session: + # Use provided session + return ( + session.query(UserSettings) + .filter(UserSettings.keycloak_user_id == keycloak_user_id) + .first() + ) + else: + # Create new session + with self.session_maker() as new_session: + return ( + new_session.query(UserSettings) + .filter(UserSettings.keycloak_user_id == keycloak_user_id) + .first() + ) + + return _get_settings() + async def load(self) -> Settings | None: if not self.user_id: return None with self.session_maker() as session: - settings = ( - session.query(UserSettings) - .filter(UserSettings.keycloak_user_id == self.user_id) - .first() - ) + settings = self.get_user_settings_by_keycloak_id(self.user_id, session) if not settings or settings.user_version != CURRENT_USER_SETTINGS_VERSION: logger.info( @@ -71,12 +102,8 @@ class SaasSettingsStore(SettingsStore): if item: kwargs = item.model_dump(context={'expose_secrets': True}) self._encrypt_kwargs(kwargs) - query = session.query(UserSettings).filter( - UserSettings.keycloak_user_id == self.user_id - ) - # First check if we have an existing entry in the new table - existing = query.first() + existing = self.get_user_settings_by_keycloak_id(self.user_id, session) kwargs = { key: value @@ -207,10 +234,8 @@ class SaasSettingsStore(SettingsStore): spend = user_info.get('spend') or 0 with session_maker() as session: - user_settings = ( - session.query(UserSettings) - .filter(UserSettings.keycloak_user_id == self.user_id) - .first() + user_settings = self.get_user_settings_by_keycloak_id( + self.user_id, session ) # In upgrade to V4, we no longer use billing margin, but instead apply this directly # in litellm. The default billing marign was 2 before this (hence the magic numbers below) diff --git a/enterprise/tests/unit/test_proactive_conversation_starters.py b/enterprise/tests/unit/test_proactive_conversation_starters.py index a6ffea764b..b9c7b6539d 100644 --- a/enterprise/tests/unit/test_proactive_conversation_starters.py +++ b/enterprise/tests/unit/test_proactive_conversation_starters.py @@ -8,8 +8,8 @@ pytestmark = pytest.mark.asyncio # Mock the call_sync_from_async function to return the result of the function directly -def mock_call_sync_from_async(func): - return func() +def mock_call_sync_from_async(func, *args, **kwargs): + return func(*args, **kwargs) @pytest.fixture