SAAS: dedup fetching user settings from keycloak id (#11480)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Rohit Malhotra 2025-10-23 09:56:55 -04:00 committed by GitHub
parent 134c122026
commit f3d9faef34
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 89 additions and 61 deletions

View File

@ -24,7 +24,7 @@ from server.config import get_config
from storage.database import session_maker
from storage.proactive_conversation_store import ProactiveConversationStore
from storage.saas_secrets_store import SaasSecretsStore
from storage.user_settings import UserSettings
from storage.saas_settings_store import SaasSettingsStore
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.github.github_service import GithubServiceImpl
@ -61,20 +61,19 @@ async def get_user_proactive_conversation_setting(user_id: str | None) -> bool:
if not user_id:
return False
def _get_setting():
with session_maker() as session:
settings = (
session.query(UserSettings)
.filter(UserSettings.keycloak_user_id == user_id)
.first()
)
config = get_config()
settings_store = SaasSettingsStore(
user_id=user_id, session_maker=session_maker, config=config
)
if not settings or settings.enable_proactive_conversation_starters is None:
return False
settings = await call_sync_from_async(
settings_store.get_user_settings_by_keycloak_id, user_id
)
return settings.enable_proactive_conversation_starters
if not settings or settings.enable_proactive_conversation_starters is None:
return False
return await call_sync_from_async(_get_setting)
return settings.enable_proactive_conversation_starters
# =================================================

View File

@ -3,10 +3,11 @@ from datetime import UTC, datetime
import httpx
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, field_validator
from server.config import get_config
from server.constants import LITE_LLM_API_KEY, LITE_LLM_API_URL
from storage.api_key_store import ApiKeyStore
from storage.database import session_maker
from storage.user_settings import UserSettings
from storage.saas_settings_store import SaasSettingsStore
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
@ -16,30 +17,30 @@ from openhands.utils.async_utils import call_sync_from_async
# Helper functions for BYOR API key management
async def get_byor_key_from_db(user_id: str) -> str | None:
"""Get the BYOR key from the database for a user."""
config = get_config()
settings_store = SaasSettingsStore(
user_id=user_id, session_maker=session_maker, config=config
)
def _get_byor_key():
with session_maker() as session:
user_db_settings = (
session.query(UserSettings)
.filter(UserSettings.keycloak_user_id == user_id)
.first()
)
if user_db_settings and user_db_settings.llm_api_key_for_byor:
return user_db_settings.llm_api_key_for_byor
return None
return await call_sync_from_async(_get_byor_key)
user_db_settings = await call_sync_from_async(
settings_store.get_user_settings_by_keycloak_id, user_id
)
if user_db_settings and user_db_settings.llm_api_key_for_byor:
return user_db_settings.llm_api_key_for_byor
return None
async def store_byor_key_in_db(user_id: str, key: str) -> None:
"""Store the BYOR key in the database for a user."""
config = get_config()
settings_store = SaasSettingsStore(
user_id=user_id, session_maker=session_maker, config=config
)
def _update_user_settings():
with session_maker() as session:
user_db_settings = (
session.query(UserSettings)
.filter(UserSettings.keycloak_user_id == user_id)
.first()
user_db_settings = settings_store.get_user_settings_by_keycloak_id(
user_id, session
)
if user_db_settings:
user_db_settings.llm_api_key_for_byor = key

View File

@ -16,10 +16,11 @@ from server.auth.constants import (
from server.auth.gitlab_sync import schedule_gitlab_repo_sync
from server.auth.saas_user_auth import SaasUserAuth
from server.auth.token_manager import TokenManager
from server.config import sign_token
from server.config import get_config, sign_token
from server.constants import IS_FEATURE_ENV
from server.routes.event_webhook import _get_session_api_key, _get_user_id
from storage.database import session_maker
from storage.saas_settings_store import SaasSettingsStore
from storage.user_settings import UserSettings
from openhands.core.logger import openhands_logger as logger
@ -212,16 +213,14 @@ async def keycloak_callback(
f'&state={state}'
)
has_accepted_tos = False
with session_maker() as session:
user_settings = (
session.query(UserSettings)
.filter(UserSettings.keycloak_user_id == user_id)
.first()
)
has_accepted_tos = (
user_settings is not None and user_settings.accepted_tos is not None
)
config = get_config()
settings_store = SaasSettingsStore(
user_id=user_id, session_maker=session_maker, config=config
)
user_settings = settings_store.get_user_settings_by_keycloak_id(user_id)
has_accepted_tos = (
user_settings is not None and user_settings.accepted_tos is not None
)
# If the user hasn't accepted the TOS, redirect to the TOS page
if not has_accepted_tos:

View File

@ -11,6 +11,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import JSONResponse, RedirectResponse
from integrations import stripe_service
from pydantic import BaseModel
from server.config import get_config
from server.constants import (
LITE_LLM_API_KEY,
LITE_LLM_API_URL,
@ -22,8 +23,8 @@ from server.constants import (
from server.logger import logger
from storage.billing_session import BillingSession
from storage.database import session_maker
from storage.saas_settings_store import SaasSettingsStore
from storage.subscription_access import SubscriptionAccess
from storage.user_settings import UserSettings
from openhands.server.user_auth import get_user_id
@ -617,11 +618,14 @@ async def stripe_webhook(request: Request) -> JSONResponse:
def reset_user_to_free_tier_settings(user_id: str) -> None:
"""Reset user settings to free tier defaults when subscription ends."""
config = get_config()
settings_store = SaasSettingsStore(
user_id=user_id, session_maker=session_maker, config=config
)
with session_maker() as session:
user_settings = (
session.query(UserSettings)
.filter(UserSettings.keycloak_user_id == user_id)
.first()
user_settings = settings_store.get_user_settings_by_keycloak_id(
user_id, session
)
if user_settings:

View File

@ -39,15 +39,46 @@ class SaasSettingsStore(SettingsStore):
session_maker: sessionmaker
config: OpenHandsConfig
def get_user_settings_by_keycloak_id(
self, keycloak_user_id: str, session=None
) -> UserSettings | None:
"""
Get UserSettings by keycloak_user_id.
Args:
keycloak_user_id: The keycloak user ID to search for
session: Optional existing database session. If not provided, creates a new one.
Returns:
UserSettings object if found, None otherwise
"""
if not keycloak_user_id:
return None
def _get_settings():
if session:
# Use provided session
return (
session.query(UserSettings)
.filter(UserSettings.keycloak_user_id == keycloak_user_id)
.first()
)
else:
# Create new session
with self.session_maker() as new_session:
return (
new_session.query(UserSettings)
.filter(UserSettings.keycloak_user_id == keycloak_user_id)
.first()
)
return _get_settings()
async def load(self) -> Settings | None:
if not self.user_id:
return None
with self.session_maker() as session:
settings = (
session.query(UserSettings)
.filter(UserSettings.keycloak_user_id == self.user_id)
.first()
)
settings = self.get_user_settings_by_keycloak_id(self.user_id, session)
if not settings or settings.user_version != CURRENT_USER_SETTINGS_VERSION:
logger.info(
@ -71,12 +102,8 @@ class SaasSettingsStore(SettingsStore):
if item:
kwargs = item.model_dump(context={'expose_secrets': True})
self._encrypt_kwargs(kwargs)
query = session.query(UserSettings).filter(
UserSettings.keycloak_user_id == self.user_id
)
# First check if we have an existing entry in the new table
existing = query.first()
existing = self.get_user_settings_by_keycloak_id(self.user_id, session)
kwargs = {
key: value
@ -207,10 +234,8 @@ class SaasSettingsStore(SettingsStore):
spend = user_info.get('spend') or 0
with session_maker() as session:
user_settings = (
session.query(UserSettings)
.filter(UserSettings.keycloak_user_id == self.user_id)
.first()
user_settings = self.get_user_settings_by_keycloak_id(
self.user_id, session
)
# In upgrade to V4, we no longer use billing margin, but instead apply this directly
# in litellm. The default billing marign was 2 before this (hence the magic numbers below)

View File

@ -8,8 +8,8 @@ pytestmark = pytest.mark.asyncio
# Mock the call_sync_from_async function to return the result of the function directly
def mock_call_sync_from_async(func):
return func()
def mock_call_sync_from_async(func, *args, **kwargs):
return func(*args, **kwargs)
@pytest.fixture