Add timeout to Keycloak operations and convert OfflineTokenStore to async (#13096)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Tim O'Farrell
2026-03-02 03:48:45 -05:00
committed by GitHub
parent d6b8d80026
commit e1408f7b15
45 changed files with 1577 additions and 1341 deletions

View File

@@ -1,5 +1,4 @@
from storage.blocked_email_domain_store import BlockedEmailDomainStore
from storage.database import session_maker
from openhands.core.logger import openhands_logger as logger
@@ -23,7 +22,7 @@ class DomainBlocker:
logger.debug(f'Error extracting domain from email: {email}', exc_info=True)
return None
def is_domain_blocked(self, email: str) -> bool:
async def is_domain_blocked(self, email: str) -> bool:
"""Check if email domain is blocked by querying the database directly via SQL.
Supports blocking:
@@ -45,7 +44,7 @@ class DomainBlocker:
try:
# Query database directly via SQL to check if domain is blocked
is_blocked = self.store.is_domain_blocked(domain)
is_blocked = await self.store.is_domain_blocked(domain)
if is_blocked:
logger.warning(f'Email domain {domain} is blocked for email: {email}')
@@ -63,5 +62,5 @@ class DomainBlocker:
# Initialize store and domain blocker
_store = BlockedEmailDomainStore(session_maker=session_maker)
_store = BlockedEmailDomainStore()
domain_blocker = DomainBlocker(store=_store)

View File

@@ -1,7 +1,7 @@
from integrations.github.github_service import SaaSGitHubService
from pydantic import SecretStr
from server.auth.auth_utils import user_verifier
from enterprise.server.auth.auth_utils import user_verifier
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.github.github_types import GitHubUser

View File

@@ -18,9 +18,10 @@ from server.auth.token_manager import TokenManager
from server.config import get_config
from server.logger import logger
from server.rate_limit import RateLimiter, create_redis_rate_limiter
from sqlalchemy import delete, select
from storage.api_key_store import ApiKeyStore
from storage.auth_tokens import AuthTokens
from storage.database import session_maker
from storage.database import a_session_maker
from storage.saas_secrets_store import SaasSecretsStore
from storage.saas_settings_store import SaasSettingsStore
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
@@ -124,7 +125,7 @@ class SaasUserAuth(UserAuth):
if secrets_store:
return secrets_store
user_id = await self.get_user_id()
secrets_store = SaasSecretsStore(user_id, session_maker, get_config())
secrets_store = SaasSecretsStore(user_id, get_config())
self.secrets_store = secrets_store
return secrets_store
@@ -161,12 +162,13 @@ class SaasUserAuth(UserAuth):
try:
# TODO: I think we can do this in a single request if we refactor
with session_maker() as session:
tokens = (
session.query(AuthTokens)
.where(AuthTokens.keycloak_user_id == self.user_id)
.all()
async with a_session_maker() as session:
result = await session.execute(
select(AuthTokens).where(
AuthTokens.keycloak_user_id == self.user_id
)
)
tokens = result.scalars().all()
for token in tokens:
idp_type = ProviderType(token.identity_provider)
@@ -192,11 +194,11 @@ class SaasUserAuth(UserAuth):
'idp_type': token.identity_provider,
},
)
with session_maker() as session:
session.query(AuthTokens).filter(
AuthTokens.id == token.id
).delete()
session.commit()
async with a_session_maker() as session:
await session.execute(
delete(AuthTokens).where(AuthTokens.id == token.id)
)
await session.commit()
raise
self.provider_tokens = MappingProxyType(provider_tokens)
@@ -210,7 +212,7 @@ class SaasUserAuth(UserAuth):
if settings_store:
return settings_store
user_id = await self.get_user_id()
settings_store = SaasSettingsStore(user_id, session_maker, get_config())
settings_store = SaasSettingsStore(user_id, get_config())
self.settings_store = settings_store
return settings_store
@@ -278,7 +280,7 @@ async def saas_user_auth_from_bearer(request: Request) -> SaasUserAuth | None:
return None
api_key_store = ApiKeyStore.get_instance()
user_id = api_key_store.validate_api_key(api_key)
user_id = await api_key_store.validate_api_key(api_key)
if not user_id:
return None
offline_token = await token_manager.load_offline_token(user_id)
@@ -327,7 +329,7 @@ async def saas_user_auth_from_signed_token(signed_token: str) -> SaasUserAuth:
email_verified = access_token_payload['email_verified']
# Check if email domain is blocked
if email and domain_blocker.is_domain_blocked(email):
if email and await domain_blocker.is_domain_blocked(email):
logger.warning(
f'Blocked authentication attempt for existing user with email: {email}'
)

View File

@@ -251,7 +251,7 @@ async def delete_api_key(
)
# Delete the key
success = api_key_store.delete_api_key_by_id(key_id)
success = await api_key_store.delete_api_key_by_id(key_id)
if not success:
raise HTTPException(

View File

@@ -270,7 +270,7 @@ async def keycloak_callback(
# Fail open - continue with login if reCAPTCHA service unavailable
# Check if email domain is blocked
if email and domain_blocker.is_domain_blocked(email):
if email and await domain_blocker.is_domain_blocked(email):
logger.warning(
f'Blocked authentication attempt for email: {email}, user_id: {user_id}'
)

View File

@@ -181,7 +181,7 @@ async def device_token(device_code: str = Form(...)):
# Retrieve the specific API key for this device using the user_code
api_key_store = ApiKeyStore.get_instance()
device_key_name = f'{API_KEY_NAME} ({device_code_entry.user_code})'
device_api_key = api_key_store.retrieve_api_key_by_name(
device_api_key = await api_key_store.retrieve_api_key_by_name(
device_code_entry.keycloak_user_id, device_key_name
)

View File

@@ -388,5 +388,4 @@ async def _check_idp(
access_token.get_secret_value(), ProviderType(idp)
):
return default_value
return None

View File

@@ -2,6 +2,10 @@
from dataclasses import dataclass
from server.verified_models.verified_model_models import (
VerifiedModel,
VerifiedModelPage,
)
from sqlalchemy import (
Boolean,
Column,
@@ -18,10 +22,6 @@ from sqlalchemy import (
from sqlalchemy.ext.asyncio import AsyncSession
from storage.base import Base
from enterprise.server.verified_models.verified_model_models import (
VerifiedModel,
VerifiedModelPage,
)
from openhands.app_server.config import depends_db_session
from openhands.core.logger import openhands_logger as logger