Keycloak changes (#6986)

This commit is contained in:
chuckbutkus
2025-02-28 15:29:15 -05:00
committed by GitHub
parent de4cf07d4d
commit 17644fedd7
9 changed files with 75 additions and 54 deletions

View File

@@ -174,7 +174,7 @@ class OpenHands {
code: string,
): Promise<GitHubAccessTokenResponse> {
const { data } = await openHands.post<GitHubAccessTokenResponse>(
"/api/github/callback",
"/api/keycloak/callback",
{
code,
},

View File

@@ -5,7 +5,13 @@
* @returns The URL to redirect to for GitHub OAuth
*/
export const generateGitHubAuthUrl = (clientId: string, requestUrl: URL) => {
const redirectUri = `${requestUrl.origin}/oauth/github/callback`;
const scope = "repo,user,workflow,offline_access";
return `https://github.com/login/oauth/authorize?client_id=${clientId}&redirect_uri=${encodeURIComponent(redirectUri)}&scope=${encodeURIComponent(scope)}`;
const redirectUri = `${requestUrl.origin}/oauth/keycloak/callback`;
const baseUrl = `${requestUrl.origin}`
.replace("https://", "")
.replace("http://", "");
const authUrl = baseUrl
.replace(/(^|\.)staging\.all-hands\.dev$/, ".auth.staging.all-hands.dev")
.replace(/(^|\.)app\.all-hands\.dev$/, "auth.app.all-hands.dev");
const scope = "openid email profile";
return `https://${authUrl}/realms/allhands/protocol/openid-connect/auth?client_id=github&response_type=code&redirect_uri=${encodeURIComponent(redirectUri)}&scope=${encodeURIComponent(scope)}`;
};

View File

@@ -21,7 +21,12 @@ class GitHubService:
token: SecretStr = SecretStr('')
refresh = False
def __init__(self, user_id: str | None = None, token: SecretStr | None = None):
def __init__(
self,
user_id: str | None = None,
idp_token: SecretStr | None = None,
token: SecretStr | None = None,
):
self.user_id = user_id
if token:

View File

@@ -8,3 +8,7 @@ def get_github_token(request: Request) -> SecretStr | None:
def get_user_id(request: Request) -> str | None:
return getattr(request.state, 'github_user_id', None)
def get_idp_token(request: Request) -> SecretStr | None:
return getattr(request.state, 'idp_token', None)

View File

@@ -1,7 +1,5 @@
from urllib.parse import parse_qs
import jwt
from pydantic import SecretStr
from socketio.exceptions import ConnectionRefusedError
from openhands.core.logger import openhands_logger as logger
@@ -15,18 +13,18 @@ from openhands.events.observation.agent import AgentStateChangedObservation
from openhands.events.serialization import event_to_dict
from openhands.events.stream import AsyncEventStreamWrapper
from openhands.server.shared import (
ConversationStoreImpl,
SettingsStoreImpl,
config,
conversation_manager,
server_config,
sio,
)
from openhands.server.types import AppMode
from openhands.storage.conversation.conversation_validator import (
ConversationValidatorImpl,
)
@sio.event
async def connect(connection_id: str, environ, auth):
async def connect(connection_id: str, environ):
logger.info(f'sio:connect: {connection_id}')
query_params = parse_qs(environ.get('QUERY_STRING', ''))
latest_event_id = int(query_params.get('latest_event_id', [-1])[0])
@@ -35,37 +33,9 @@ async def connect(connection_id: str, environ, auth):
logger.error('No conversation_id in query params')
raise ConnectionRefusedError('No conversation_id in query params')
user_id = None
if server_config.app_mode != AppMode.OSS:
cookies_str = environ.get('HTTP_COOKIE', '')
cookies = dict(cookie.split('=', 1) for cookie in cookies_str.split('; '))
signed_token = cookies.get('openhands_auth', '')
if not signed_token:
logger.error('No openhands_auth cookie')
raise ConnectionRefusedError('No openhands_auth cookie')
if not config.jwt_secret:
raise RuntimeError('JWT secret not found')
jwt_secret = (
config.jwt_secret.get_secret_value()
if isinstance(config.jwt_secret, SecretStr)
else config.jwt_secret
)
decoded = jwt.decode(signed_token, jwt_secret, algorithms=['HS256'])
user_id = decoded['github_user_id']
logger.info(f'User {user_id} is connecting to conversation {conversation_id}')
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
metadata = await conversation_store.get_metadata(conversation_id)
if metadata.github_user_id != str(user_id):
logger.error(
f'User {user_id} is not allowed to join conversation {conversation_id}'
)
raise ConnectionRefusedError(
f'User {user_id} is not allowed to join conversation {conversation_id}'
)
cookies_str = environ.get('HTTP_COOKIE', '')
conversation_validator = ConversationValidatorImpl()
user_id = await conversation_validator.validate(conversation_id, cookies_str)
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
settings = await settings_store.load()

View File

@@ -10,7 +10,7 @@ from openhands.integrations.github.github_types import (
GitHubUser,
SuggestedTask,
)
from openhands.server.auth import get_github_token, get_user_id
from openhands.server.auth import get_github_token, get_idp_token, get_user_id
app = APIRouter(prefix='/api/github')
@@ -23,8 +23,11 @@ async def get_github_repositories(
installation_id: int | None = None,
github_user_id: str | None = Depends(get_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
idp_token: SecretStr | None = Depends(get_idp_token),
):
client = GithubServiceImpl(user_id=github_user_id, token=github_user_token)
client = GithubServiceImpl(
user_id=github_user_id, idp_token=idp_token, token=github_user_token
)
try:
repos: list[GitHubRepository] = await client.get_repositories(
page, per_page, sort, installation_id
@@ -48,8 +51,11 @@ async def get_github_repositories(
async def get_github_user(
github_user_id: str | None = Depends(get_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
idp_token: SecretStr | None = Depends(get_idp_token),
):
client = GithubServiceImpl(user_id=github_user_id, token=github_user_token)
client = GithubServiceImpl(
user_id=github_user_id, idp_token=idp_token, token=github_user_token
)
try:
user: GitHubUser = await client.get_user()
return user
@@ -71,8 +77,11 @@ async def get_github_user(
async def get_github_installation_ids(
github_user_id: str | None = Depends(get_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
idp_token: SecretStr | None = Depends(get_idp_token),
):
client = GithubServiceImpl(user_id=github_user_id, token=github_user_token)
client = GithubServiceImpl(
user_id=github_user_id, idp_token=idp_token, token=github_user_token
)
try:
installations_ids: list[int] = await client.get_installation_ids()
return installations_ids
@@ -98,8 +107,11 @@ async def search_github_repositories(
order: str = 'desc',
github_user_id: str | None = Depends(get_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
idp_token: SecretStr | None = Depends(get_idp_token),
):
client = GithubServiceImpl(user_id=github_user_id, token=github_user_token)
client = GithubServiceImpl(
user_id=github_user_id, idp_token=idp_token, token=github_user_token
)
try:
repos: list[GitHubRepository] = await client.search_repositories(
query, per_page, sort, order
@@ -123,14 +135,17 @@ async def search_github_repositories(
async def get_suggested_tasks(
github_user_id: str | None = Depends(get_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
idp_token: SecretStr | None = Depends(get_idp_token),
):
"""
Get suggested tasks for the authenticated user across their most recently pushed repositories.
"""Get suggested tasks for the authenticated user across their most recently pushed repositories.
Returns:
- PRs owned by the user
- Issues assigned to the user
- Issues assigned to the user.
"""
client = GithubServiceImpl(user_id=github_user_id, token=github_user_token)
client = GithubServiceImpl(
user_id=github_user_id, idp_token=idp_token, token=github_user_token
)
try:
tasks: list[SuggestedTask] = await client.get_suggested_tasks()
return tasks

View File

@@ -11,7 +11,7 @@ from openhands.events.action.message import MessageAction
from openhands.events.stream import EventStreamSubscriber
from openhands.integrations.github.github_service import GithubServiceImpl
from openhands.runtime import get_runtime_cls
from openhands.server.auth import get_github_token, get_user_id
from openhands.server.auth import get_github_token, get_idp_token, get_user_id
from openhands.server.data_models.conversation_info import ConversationInfo
from openhands.server.data_models.conversation_info_result_set import (
ConversationInfoResultSet,
@@ -138,7 +138,11 @@ async def new_conversation(request: Request, data: InitSessionRequest):
"""
logger.info('Initializing new conversation')
user_id = get_user_id(request)
gh_client = GithubServiceImpl(user_id=user_id, token=get_github_token(request))
gh_client = GithubServiceImpl(
user_id=user_id,
idp_token=get_idp_token(request),
token=get_github_token(request),
)
github_token = await gh_client.get_latest_token()
selected_repository = data.selected_repository

View File

@@ -52,7 +52,7 @@ async def store_settings(
# We check if the token is valid by getting the user
# If the token is invalid, this will raise an exception
github = GithubServiceImpl(
user_id=None, token=SecretStr(settings.github_token)
user_id=None, idp_token=None, token=SecretStr(settings.github_token)
)
await github.get_user()

View File

@@ -0,0 +1,17 @@
import os
from openhands.utils.import_utils import get_impl
class ConversationValidator:
"""Storage for conversation metadata. May or may not support multiple users depending on the environment."""
async def validate(self, conversation_id: str, cookies_str: str):
return None
conversation_validator_cls = os.environ.get(
'OPENHANDS_CONVERSATION_VALIDATOR_CLS',
'openhands.storage.conversation.conversation_validator.ConversationValidator',
)
ConversationValidatorImpl = get_impl(ConversationValidator, conversation_validator_cls)