OpenHands/enterprise/server/auth/saas_user_auth.py

349 lines
13 KiB
Python

import time
from dataclasses import dataclass
from types import MappingProxyType
import jwt
from fastapi import Request
from keycloak.exceptions import KeycloakError
from pydantic import SecretStr
from server.auth.auth_error import (
AuthError,
BearerTokenError,
CookieError,
ExpiredError,
NoCredentialsError,
)
from server.auth.domain_blocker import domain_blocker
from server.auth.token_manager import TokenManager
from server.config import get_config
from server.logger import logger
from server.rate_limit import RateLimiter, create_redis_rate_limiter
from storage.api_key_store import ApiKeyStore
from storage.auth_tokens import AuthTokens
from storage.database import session_maker
from storage.saas_secrets_store import SaasSecretsStore
from storage.saas_settings_store import SaasSettingsStore
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
from openhands.integrations.provider import (
PROVIDER_TOKEN_TYPE,
ProviderToken,
ProviderType,
)
from openhands.server.settings import Settings
from openhands.server.user_auth.user_auth import AuthType, UserAuth
from openhands.storage.data_models.secrets import Secrets
from openhands.storage.settings.settings_store import SettingsStore
token_manager = TokenManager()
rate_limiter: RateLimiter = create_redis_rate_limiter('10/second; 100/minute')
@dataclass
class SaasUserAuth(UserAuth):
refresh_token: SecretStr
user_id: str
email: str | None = None
email_verified: bool | None = None
access_token: SecretStr | None = None
provider_tokens: PROVIDER_TOKEN_TYPE | None = None
refreshed: bool = False
settings_store: SaasSettingsStore | None = None
secrets_store: SaasSecretsStore | None = None
_settings: Settings | None = None
_secrets: Secrets | None = None
accepted_tos: bool | None = None
auth_type: AuthType = AuthType.COOKIE
async def get_user_id(self) -> str | None:
return self.user_id
async def get_user_email(self) -> str | None:
return self.email
@retry(
stop=stop_after_attempt(3),
wait=wait_fixed(1),
retry=retry_if_exception_type(KeycloakError),
)
async def refresh(self):
if self._is_token_expired(self.refresh_token):
logger.debug('saas_user_auth_refresh:expired')
raise ExpiredError()
tokens = await token_manager.refresh(self.refresh_token.get_secret_value())
self.access_token = SecretStr(tokens['access_token'])
self.refresh_token = SecretStr(tokens['refresh_token'])
self.refreshed = True
def _is_token_expired(self, token: SecretStr):
logger.debug('saas_user_auth_is_token_expired')
# Decode token payload - works with both access and refresh tokens
payload = jwt.decode(
token.get_secret_value(), options={'verify_signature': False}
)
# Sanity check - make sure we refer to current user
assert payload['sub'] == self.user_id
# Check token expiration
expiration = payload.get('exp')
if expiration:
logger.debug('saas_user_auth_is_token_expired expiration is %d', expiration)
return expiration and expiration < time.time()
def get_auth_type(self) -> AuthType | None:
return self.auth_type
async def get_user_settings(self) -> Settings | None:
settings = self._settings
if settings:
return settings
settings_store = await self.get_user_settings_store()
settings = await settings_store.load()
# If load() returned None, should settings be created?
if settings:
settings.email = self.email
settings.email_verified = self.email_verified
self._settings = settings
return settings
async def get_secrets_store(self):
logger.debug('saas_user_auth_get_secrets_store')
secrets_store = self.secrets_store
if secrets_store:
return secrets_store
user_id = await self.get_user_id()
secrets_store = SaasSecretsStore(user_id, session_maker, get_config())
self.secrets_store = secrets_store
return secrets_store
async def get_secrets(self):
user_secrets = self._secrets
if user_secrets:
return user_secrets
secrets_store = await self.get_secrets_store()
user_secrets = await secrets_store.load()
self._secrets = user_secrets
return user_secrets
async def get_access_token(self) -> SecretStr | None:
logger.debug('saas_user_auth_get_access_token')
try:
if self.access_token is None or self._is_token_expired(self.access_token):
await self.refresh()
return self.access_token
except AuthError:
raise
except Exception as e:
raise AuthError() from e
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:
logger.debug('saas_user_auth_get_provider_tokens')
if self.provider_tokens is not None:
return self.provider_tokens
provider_tokens = {}
access_token = await self.get_access_token()
if not access_token:
raise AuthError()
user_secrets = await self.get_secrets()
try:
# TODO: I think we can do this in a single request if we refactor
with session_maker() as session:
tokens = session.query(AuthTokens).where(
AuthTokens.keycloak_user_id == self.user_id
)
for token in tokens:
idp_type = ProviderType(token.identity_provider)
try:
host = None
if user_secrets and idp_type in user_secrets.provider_tokens:
host = user_secrets.provider_tokens[idp_type].host
provider_token = await token_manager.get_idp_token(
access_token.get_secret_value(),
idp=idp_type,
)
# TODO: Currently we don't store the IDP user id in our refresh table. We should.
provider_tokens[idp_type] = ProviderToken(
token=SecretStr(provider_token), user_id=None, host=host
)
except Exception as e:
# If there was a problem with a refresh token we log and delete it
logger.error(
f'Error refreshing provider_token token: {e}',
extra={
'user_id': self.user_id,
'idp_type': token.identity_provider,
},
)
with session_maker() as session:
session.query(AuthTokens).filter(
AuthTokens.id == token.id
).delete()
session.commit()
raise
self.provider_tokens = MappingProxyType(provider_tokens)
return self.provider_tokens
except Exception as e:
# Any error refreshing tokens means we need to log in again
raise AuthError() from e
async def get_user_settings_store(self) -> SettingsStore:
settings_store = self.settings_store
if settings_store:
return settings_store
user_id = await self.get_user_id()
settings_store = SaasSettingsStore(user_id, session_maker, get_config())
self.settings_store = settings_store
return settings_store
async def get_mcp_api_key(self) -> str:
api_key_store = ApiKeyStore.get_instance()
mcp_api_key = api_key_store.retrieve_mcp_api_key(self.user_id)
if not mcp_api_key:
mcp_api_key = api_key_store.create_api_key(
self.user_id, 'MCP_API_KEY', None
)
return mcp_api_key
@classmethod
async def get_instance(cls, request: Request) -> UserAuth:
logger.debug('saas_user_auth_get_instance')
# First we check for for an API Key...
logger.debug('saas_user_auth_get_instance:check_bearer')
instance = await saas_user_auth_from_bearer(request)
if instance is None:
logger.debug('saas_user_auth_get_instance:check_cookie')
instance = await saas_user_auth_from_cookie(request)
if instance is None:
logger.debug('saas_user_auth_get_instance:no_credentials')
raise NoCredentialsError('failed to authenticate')
if not getattr(request.state, 'user_rate_limit_processed', False):
user_id = await instance.get_user_id()
if user_id:
# Ensure requests are only counted once
request.state.user_rate_limit_processed = True
# Will raise if rate limit is reached.
await rate_limiter.hit('auth_uid', user_id)
return instance
@classmethod
async def get_for_user(cls, user_id: str) -> UserAuth:
offline_token = await token_manager.load_offline_token(user_id)
assert offline_token is not None
return SaasUserAuth(
user_id=user_id,
refresh_token=SecretStr(offline_token),
auth_type=AuthType.BEARER,
)
def get_api_key_from_header(request: Request):
auth_header = request.headers.get('Authorization')
if auth_header and auth_header.startswith('Bearer '):
return auth_header.replace('Bearer ', '')
# This is a temp hack
# Streamable HTTP MCP Client works via redirect requests, but drops the Authorization header for reason
# We include `X-Session-API-Key` header by default due to nested runtimes, so it used as a drop in replacement here
session_api_key = request.headers.get('X-Session-API-Key')
if session_api_key:
return session_api_key
# Fallback to X-Access-Token header as an additional option
return request.headers.get('X-Access-Token')
async def saas_user_auth_from_bearer(request: Request) -> SaasUserAuth | None:
try:
api_key = get_api_key_from_header(request)
if not api_key:
return None
api_key_store = ApiKeyStore.get_instance()
user_id = api_key_store.validate_api_key(api_key)
if not user_id:
return None
offline_token = await token_manager.load_offline_token(user_id)
return SaasUserAuth(
user_id=user_id,
refresh_token=SecretStr(offline_token),
auth_type=AuthType.BEARER,
)
except Exception as exc:
raise BearerTokenError from exc
async def saas_user_auth_from_cookie(request: Request) -> SaasUserAuth | None:
try:
signed_token = request.cookies.get('keycloak_auth')
if not signed_token:
return None
return await saas_user_auth_from_signed_token(signed_token)
except Exception as exc:
raise CookieError from exc
async def saas_user_auth_from_signed_token(signed_token: str) -> SaasUserAuth:
logger.debug('saas_user_auth_from_signed_token')
jwt_secret = get_config().jwt_secret.get_secret_value()
decoded = jwt.decode(signed_token, jwt_secret, algorithms=['HS256'])
logger.debug('saas_user_auth_from_signed_token:decoded')
access_token = decoded['access_token']
refresh_token = decoded['refresh_token']
logger.debug(
'saas_user_auth_from_signed_token',
extra={
'access_token': access_token,
'refresh_token': refresh_token,
},
)
accepted_tos = decoded.get('accepted_tos')
# The access token was encoded using HS256 on keycloak. Since we signed it, we can trust is was
# created by us. So we can grab the user_id and expiration from it without going back to keycloak.
access_token_payload = jwt.decode(access_token, options={'verify_signature': False})
user_id = access_token_payload['sub']
email = access_token_payload['email']
email_verified = access_token_payload['email_verified']
# Check if email domain is blocked
if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email):
logger.warning(
f'Blocked authentication attempt for existing user with email: {email}'
)
raise AuthError(
'Access denied: Your email domain is not allowed to access this service'
)
logger.debug('saas_user_auth_from_signed_token:return')
return SaasUserAuth(
access_token=SecretStr(access_token),
refresh_token=SecretStr(refresh_token),
user_id=user_id,
email=email,
email_verified=email_verified,
accepted_tos=accepted_tos,
auth_type=AuthType.COOKIE,
)
async def get_user_auth_from_keycloak_id(keycloak_user_id: str) -> UserAuth:
offline_token = await token_manager.load_offline_token(keycloak_user_id)
if offline_token is None:
logger.info('no_offline_token_found')
user_auth = SaasUserAuth(
user_id=keycloak_user_id,
refresh_token=SecretStr(offline_token),
)
return user_auth