diff --git a/frontend/src/api/open-hands.ts b/frontend/src/api/open-hands.ts index bb70ffa4ef..afb939dd20 100644 --- a/frontend/src/api/open-hands.ts +++ b/frontend/src/api/open-hands.ts @@ -174,7 +174,7 @@ class OpenHands { code: string, ): Promise { const { data } = await openHands.post( - "/api/github/callback", + "/api/keycloak/callback", { code, }, diff --git a/frontend/src/utils/generate-github-auth-url.ts b/frontend/src/utils/generate-github-auth-url.ts index 87370b9784..ed2fa38f41 100644 --- a/frontend/src/utils/generate-github-auth-url.ts +++ b/frontend/src/utils/generate-github-auth-url.ts @@ -5,7 +5,13 @@ * @returns The URL to redirect to for GitHub OAuth */ export const generateGitHubAuthUrl = (clientId: string, requestUrl: URL) => { - const redirectUri = `${requestUrl.origin}/oauth/github/callback`; - const scope = "repo,user,workflow,offline_access"; - return `https://github.com/login/oauth/authorize?client_id=${clientId}&redirect_uri=${encodeURIComponent(redirectUri)}&scope=${encodeURIComponent(scope)}`; + const redirectUri = `${requestUrl.origin}/oauth/keycloak/callback`; + const baseUrl = `${requestUrl.origin}` + .replace("https://", "") + .replace("http://", ""); + const authUrl = baseUrl + .replace(/(^|\.)staging\.all-hands\.dev$/, ".auth.staging.all-hands.dev") + .replace(/(^|\.)app\.all-hands\.dev$/, "auth.app.all-hands.dev"); + const scope = "openid email profile"; + return `https://${authUrl}/realms/allhands/protocol/openid-connect/auth?client_id=github&response_type=code&redirect_uri=${encodeURIComponent(redirectUri)}&scope=${encodeURIComponent(scope)}`; }; diff --git a/openhands/integrations/github/github_service.py b/openhands/integrations/github/github_service.py index 1637d42e38..e9049460ba 100644 --- a/openhands/integrations/github/github_service.py +++ b/openhands/integrations/github/github_service.py @@ -21,7 +21,12 @@ class GitHubService: token: SecretStr = SecretStr('') refresh = False - def __init__(self, user_id: str | None = None, token: SecretStr | None = None): + def __init__( + self, + user_id: str | None = None, + idp_token: SecretStr | None = None, + token: SecretStr | None = None, + ): self.user_id = user_id if token: diff --git a/openhands/server/auth.py b/openhands/server/auth.py index fa28dafbf4..470834f8d0 100644 --- a/openhands/server/auth.py +++ b/openhands/server/auth.py @@ -8,3 +8,7 @@ def get_github_token(request: Request) -> SecretStr | None: def get_user_id(request: Request) -> str | None: return getattr(request.state, 'github_user_id', None) + + +def get_idp_token(request: Request) -> SecretStr | None: + return getattr(request.state, 'idp_token', None) diff --git a/openhands/server/listen_socket.py b/openhands/server/listen_socket.py index ad45a33d26..f89157856b 100644 --- a/openhands/server/listen_socket.py +++ b/openhands/server/listen_socket.py @@ -1,7 +1,5 @@ from urllib.parse import parse_qs -import jwt -from pydantic import SecretStr from socketio.exceptions import ConnectionRefusedError from openhands.core.logger import openhands_logger as logger @@ -15,18 +13,18 @@ from openhands.events.observation.agent import AgentStateChangedObservation from openhands.events.serialization import event_to_dict from openhands.events.stream import AsyncEventStreamWrapper from openhands.server.shared import ( - ConversationStoreImpl, SettingsStoreImpl, config, conversation_manager, - server_config, sio, ) -from openhands.server.types import AppMode +from openhands.storage.conversation.conversation_validator import ( + ConversationValidatorImpl, +) @sio.event -async def connect(connection_id: str, environ, auth): +async def connect(connection_id: str, environ): logger.info(f'sio:connect: {connection_id}') query_params = parse_qs(environ.get('QUERY_STRING', '')) latest_event_id = int(query_params.get('latest_event_id', [-1])[0]) @@ -35,37 +33,9 @@ async def connect(connection_id: str, environ, auth): logger.error('No conversation_id in query params') raise ConnectionRefusedError('No conversation_id in query params') - user_id = None - if server_config.app_mode != AppMode.OSS: - cookies_str = environ.get('HTTP_COOKIE', '') - cookies = dict(cookie.split('=', 1) for cookie in cookies_str.split('; ')) - signed_token = cookies.get('openhands_auth', '') - if not signed_token: - logger.error('No openhands_auth cookie') - raise ConnectionRefusedError('No openhands_auth cookie') - if not config.jwt_secret: - raise RuntimeError('JWT secret not found') - - jwt_secret = ( - config.jwt_secret.get_secret_value() - if isinstance(config.jwt_secret, SecretStr) - else config.jwt_secret - ) - decoded = jwt.decode(signed_token, jwt_secret, algorithms=['HS256']) - user_id = decoded['github_user_id'] - - logger.info(f'User {user_id} is connecting to conversation {conversation_id}') - - conversation_store = await ConversationStoreImpl.get_instance(config, user_id) - metadata = await conversation_store.get_metadata(conversation_id) - - if metadata.github_user_id != str(user_id): - logger.error( - f'User {user_id} is not allowed to join conversation {conversation_id}' - ) - raise ConnectionRefusedError( - f'User {user_id} is not allowed to join conversation {conversation_id}' - ) + cookies_str = environ.get('HTTP_COOKIE', '') + conversation_validator = ConversationValidatorImpl() + user_id = await conversation_validator.validate(conversation_id, cookies_str) settings_store = await SettingsStoreImpl.get_instance(config, user_id) settings = await settings_store.load() diff --git a/openhands/server/routes/github.py b/openhands/server/routes/github.py index b3e4e4e1f3..0255f6327b 100644 --- a/openhands/server/routes/github.py +++ b/openhands/server/routes/github.py @@ -10,7 +10,7 @@ from openhands.integrations.github.github_types import ( GitHubUser, SuggestedTask, ) -from openhands.server.auth import get_github_token, get_user_id +from openhands.server.auth import get_github_token, get_idp_token, get_user_id app = APIRouter(prefix='/api/github') @@ -23,8 +23,11 @@ async def get_github_repositories( installation_id: int | None = None, github_user_id: str | None = Depends(get_user_id), github_user_token: SecretStr | None = Depends(get_github_token), + idp_token: SecretStr | None = Depends(get_idp_token), ): - client = GithubServiceImpl(user_id=github_user_id, token=github_user_token) + client = GithubServiceImpl( + user_id=github_user_id, idp_token=idp_token, token=github_user_token + ) try: repos: list[GitHubRepository] = await client.get_repositories( page, per_page, sort, installation_id @@ -48,8 +51,11 @@ async def get_github_repositories( async def get_github_user( github_user_id: str | None = Depends(get_user_id), github_user_token: SecretStr | None = Depends(get_github_token), + idp_token: SecretStr | None = Depends(get_idp_token), ): - client = GithubServiceImpl(user_id=github_user_id, token=github_user_token) + client = GithubServiceImpl( + user_id=github_user_id, idp_token=idp_token, token=github_user_token + ) try: user: GitHubUser = await client.get_user() return user @@ -71,8 +77,11 @@ async def get_github_user( async def get_github_installation_ids( github_user_id: str | None = Depends(get_user_id), github_user_token: SecretStr | None = Depends(get_github_token), + idp_token: SecretStr | None = Depends(get_idp_token), ): - client = GithubServiceImpl(user_id=github_user_id, token=github_user_token) + client = GithubServiceImpl( + user_id=github_user_id, idp_token=idp_token, token=github_user_token + ) try: installations_ids: list[int] = await client.get_installation_ids() return installations_ids @@ -98,8 +107,11 @@ async def search_github_repositories( order: str = 'desc', github_user_id: str | None = Depends(get_user_id), github_user_token: SecretStr | None = Depends(get_github_token), + idp_token: SecretStr | None = Depends(get_idp_token), ): - client = GithubServiceImpl(user_id=github_user_id, token=github_user_token) + client = GithubServiceImpl( + user_id=github_user_id, idp_token=idp_token, token=github_user_token + ) try: repos: list[GitHubRepository] = await client.search_repositories( query, per_page, sort, order @@ -123,14 +135,17 @@ async def search_github_repositories( async def get_suggested_tasks( github_user_id: str | None = Depends(get_user_id), github_user_token: SecretStr | None = Depends(get_github_token), + idp_token: SecretStr | None = Depends(get_idp_token), ): - """ - Get suggested tasks for the authenticated user across their most recently pushed repositories. + """Get suggested tasks for the authenticated user across their most recently pushed repositories. + Returns: - PRs owned by the user - - Issues assigned to the user + - Issues assigned to the user. """ - client = GithubServiceImpl(user_id=github_user_id, token=github_user_token) + client = GithubServiceImpl( + user_id=github_user_id, idp_token=idp_token, token=github_user_token + ) try: tasks: list[SuggestedTask] = await client.get_suggested_tasks() return tasks diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index 0389bbd90c..33c55a32d1 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -11,7 +11,7 @@ from openhands.events.action.message import MessageAction from openhands.events.stream import EventStreamSubscriber from openhands.integrations.github.github_service import GithubServiceImpl from openhands.runtime import get_runtime_cls -from openhands.server.auth import get_github_token, get_user_id +from openhands.server.auth import get_github_token, get_idp_token, get_user_id from openhands.server.data_models.conversation_info import ConversationInfo from openhands.server.data_models.conversation_info_result_set import ( ConversationInfoResultSet, @@ -138,7 +138,11 @@ async def new_conversation(request: Request, data: InitSessionRequest): """ logger.info('Initializing new conversation') user_id = get_user_id(request) - gh_client = GithubServiceImpl(user_id=user_id, token=get_github_token(request)) + gh_client = GithubServiceImpl( + user_id=user_id, + idp_token=get_idp_token(request), + token=get_github_token(request), + ) github_token = await gh_client.get_latest_token() selected_repository = data.selected_repository diff --git a/openhands/server/routes/settings.py b/openhands/server/routes/settings.py index 66ed76a23e..a63c84fa7f 100644 --- a/openhands/server/routes/settings.py +++ b/openhands/server/routes/settings.py @@ -52,7 +52,7 @@ async def store_settings( # We check if the token is valid by getting the user # If the token is invalid, this will raise an exception github = GithubServiceImpl( - user_id=None, token=SecretStr(settings.github_token) + user_id=None, idp_token=None, token=SecretStr(settings.github_token) ) await github.get_user() diff --git a/openhands/storage/conversation/conversation_validator.py b/openhands/storage/conversation/conversation_validator.py new file mode 100644 index 0000000000..51f293b395 --- /dev/null +++ b/openhands/storage/conversation/conversation_validator.py @@ -0,0 +1,17 @@ +import os + +from openhands.utils.import_utils import get_impl + + +class ConversationValidator: + """Storage for conversation metadata. May or may not support multiple users depending on the environment.""" + + async def validate(self, conversation_id: str, cookies_str: str): + return None + + +conversation_validator_cls = os.environ.get( + 'OPENHANDS_CONVERSATION_VALIDATOR_CLS', + 'openhands.storage.conversation.conversation_validator.ConversationValidator', +) +ConversationValidatorImpl = get_impl(ConversationValidator, conversation_validator_cls)