mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-25 21:36:52 +08:00
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: sp.wack <83104063+amanape@users.noreply.github.com> Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
100 lines
3.7 KiB
Python
100 lines
3.7 KiB
Python
from dataclasses import dataclass
|
|
from typing import Any, AsyncGenerator
|
|
|
|
from fastapi import Request
|
|
from pydantic import PrivateAttr
|
|
|
|
from openhands.app_server.errors import AuthError
|
|
from openhands.app_server.services.injector import InjectorState
|
|
from openhands.app_server.user.specifiy_user_context import USER_CONTEXT_ATTR
|
|
from openhands.app_server.user.user_context import UserContext, UserContextInjector
|
|
from openhands.app_server.user.user_models import UserInfo
|
|
from openhands.integrations.provider import ProviderHandler, ProviderType
|
|
from openhands.sdk.conversation.secret_source import SecretSource, StaticSecret
|
|
from openhands.server.user_auth.user_auth import UserAuth, get_user_auth
|
|
|
|
# Attribute-name constant holding the string 'user_auth'.
# NOTE(review): not referenced anywhere in the visible part of this module;
# presumably consumed by importers that look a UserAuth up by name - confirm.
USER_AUTH_ATTR = 'user_auth'
|
|
|
|
|
|
@dataclass
class AuthUserContext(UserContext):
    """Interface to old user settings service.

    Eventually we want to migrate this to use true database asyncio.

    Every lookup is delegated to the wrapped ``user_auth`` object. The user
    info and the provider handler are built on first request and then reused
    for the lifetime of this instance.
    """

    user_auth: UserAuth
    # Per-instance caches, filled lazily by get_user_info() and
    # get_provider_handler() respectively.
    _user_info: UserInfo | None = None
    _provider_handler: ProviderHandler | None = None

    async def get_user_id(self) -> str | None:
        """Return the user id reported by ``user_auth``."""
        # If you have an auth object here you are logged in. If user_id is None
        # it means we are in OSS mode.
        return await self.user_auth.get_user_id()

    async def get_user_info(self) -> UserInfo:
        """Return the user's info, building and caching it on first use.

        The info is assembled from the user id plus a dump of the user's
        settings (with secrets exposed so they survive the round trip).

        Raises:
            AssertionError: If ``user_auth`` returns no settings.
        """
        user_info = self._user_info
        if user_info is None:
            user_id = await self.get_user_id()
            settings = await self.user_auth.get_user_settings()
            # Raised explicitly (instead of via ``assert``) so the check is
            # not stripped under ``python -O``; the exception type is the same.
            if settings is None:
                raise AssertionError('user settings are not available')
            user_info = UserInfo(
                id=user_id,
                **settings.model_dump(context={'expose_secrets': True}),
            )
            self._user_info = user_info
        return user_info

    async def get_provider_handler(self) -> ProviderHandler:
        """Return the provider handler, building and caching it on first use.

        Raises:
            AssertionError: If ``user_auth`` returns no provider tokens.
        """
        provider_handler = self._provider_handler
        # ``is None`` matches the cache check used in get_user_info().
        if provider_handler is None:
            provider_tokens = await self.user_auth.get_provider_tokens()
            # Explicit raise so the check survives ``python -O``.
            if provider_tokens is None:
                raise AssertionError('provider tokens are not available')
            user_id = await self.get_user_id()
            provider_handler = ProviderHandler(
                provider_tokens=provider_tokens, external_auth_id=user_id
            )
            self._provider_handler = provider_handler
        return provider_handler

    async def get_authenticated_git_url(self, repository: str) -> str:
        """Return the authenticated git URL for ``repository``.

        The work is delegated to the cached provider handler.
        """
        provider_handler = await self.get_provider_handler()
        return await provider_handler.get_authenticated_git_url(repository)

    async def get_latest_token(self, provider_type: ProviderType) -> str | None:
        """Return the latest token from the service for ``provider_type``."""
        provider_handler = await self.get_provider_handler()
        service = provider_handler.get_service(provider_type)
        return await service.get_latest_token()

    async def get_secrets(self) -> dict[str, SecretSource]:
        """Return the user's custom secrets as ``StaticSecret`` objects by name.

        An empty dict is returned when ``user_auth`` reports no secrets.
        """
        # Include custom secrets...
        secrets = await self.user_auth.get_user_secrets()
        if not secrets:
            return {}
        return {
            name: StaticSecret(value=custom_secret.secret)
            for name, custom_secret in secrets.custom_secrets.items()
        }
|
|
|
|
|
|
# Attribute-name constant holding the string 'user_id'.
# NOTE(review): not referenced anywhere in the visible part of this module;
# presumably consumed by importers - confirm before removing or renaming.
USER_ID_ATTR = 'user_id'
|
|
|
|
|
|
class AuthUserContextInjector(UserContextInjector):
    # Injector that yields an AuthUserContext, storing it on the injector
    # state under USER_CONTEXT_ATTR so later calls reuse the same object.

    # Private attribute defaulting to None; not read by inject() below.
    _user_auth_class: Any = PrivateAttr(default=None)

    async def inject(
        self, state: InjectorState, request: Request | None = None
    ) -> AsyncGenerator[UserContext, None]:
        """Yield the user context for ``state``, creating it if necessary.

        A context already stored on ``state`` is yielded as-is. Otherwise one
        is built from the user auth resolved for ``request`` and stored on
        ``state`` before being yielded.

        Raises:
            AuthError: If no context is stored and ``request`` is None.
        """
        existing = getattr(state, USER_CONTEXT_ATTR, None)
        if existing is not None:
            # Already resolved for this state: hand it back unchanged.
            yield existing
            return

        # Without a request there is nothing to authenticate against.
        if request is None:
            raise AuthError()

        resolved_auth = await get_user_auth(request)
        created = AuthUserContext(user_auth=resolved_auth)
        setattr(state, USER_CONTEXT_ATTR, created)
        yield created
|