mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
Keycloak changes (#6986)
This commit is contained in:
@@ -174,7 +174,7 @@ class OpenHands {
|
||||
code: string,
|
||||
): Promise<GitHubAccessTokenResponse> {
|
||||
const { data } = await openHands.post<GitHubAccessTokenResponse>(
|
||||
"/api/github/callback",
|
||||
"/api/keycloak/callback",
|
||||
{
|
||||
code,
|
||||
},
|
||||
|
||||
@@ -5,7 +5,13 @@
|
||||
* @returns The URL to redirect to for GitHub OAuth
|
||||
*/
|
||||
export const generateGitHubAuthUrl = (clientId: string, requestUrl: URL) => {
|
||||
const redirectUri = `${requestUrl.origin}/oauth/github/callback`;
|
||||
const scope = "repo,user,workflow,offline_access";
|
||||
return `https://github.com/login/oauth/authorize?client_id=${clientId}&redirect_uri=${encodeURIComponent(redirectUri)}&scope=${encodeURIComponent(scope)}`;
|
||||
const redirectUri = `${requestUrl.origin}/oauth/keycloak/callback`;
|
||||
const baseUrl = `${requestUrl.origin}`
|
||||
.replace("https://", "")
|
||||
.replace("http://", "");
|
||||
const authUrl = baseUrl
|
||||
.replace(/(^|\.)staging\.all-hands\.dev$/, ".auth.staging.all-hands.dev")
|
||||
.replace(/(^|\.)app\.all-hands\.dev$/, "auth.app.all-hands.dev");
|
||||
const scope = "openid email profile";
|
||||
return `https://${authUrl}/realms/allhands/protocol/openid-connect/auth?client_id=github&response_type=code&redirect_uri=${encodeURIComponent(redirectUri)}&scope=${encodeURIComponent(scope)}`;
|
||||
};
|
||||
|
||||
@@ -21,7 +21,12 @@ class GitHubService:
|
||||
token: SecretStr = SecretStr('')
|
||||
refresh = False
|
||||
|
||||
def __init__(self, user_id: str | None = None, token: SecretStr | None = None):
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
idp_token: SecretStr | None = None,
|
||||
token: SecretStr | None = None,
|
||||
):
|
||||
self.user_id = user_id
|
||||
|
||||
if token:
|
||||
|
||||
@@ -8,3 +8,7 @@ def get_github_token(request: Request) -> SecretStr | None:
|
||||
|
||||
def get_user_id(request: Request) -> str | None:
|
||||
return getattr(request.state, 'github_user_id', None)
|
||||
|
||||
|
||||
def get_idp_token(request: Request) -> SecretStr | None:
|
||||
return getattr(request.state, 'idp_token', None)
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import jwt
|
||||
from pydantic import SecretStr
|
||||
from socketio.exceptions import ConnectionRefusedError
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -15,18 +13,18 @@ from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.events.serialization import event_to_dict
|
||||
from openhands.events.stream import AsyncEventStreamWrapper
|
||||
from openhands.server.shared import (
|
||||
ConversationStoreImpl,
|
||||
SettingsStoreImpl,
|
||||
config,
|
||||
conversation_manager,
|
||||
server_config,
|
||||
sio,
|
||||
)
|
||||
from openhands.server.types import AppMode
|
||||
from openhands.storage.conversation.conversation_validator import (
|
||||
ConversationValidatorImpl,
|
||||
)
|
||||
|
||||
|
||||
@sio.event
|
||||
async def connect(connection_id: str, environ, auth):
|
||||
async def connect(connection_id: str, environ):
|
||||
logger.info(f'sio:connect: {connection_id}')
|
||||
query_params = parse_qs(environ.get('QUERY_STRING', ''))
|
||||
latest_event_id = int(query_params.get('latest_event_id', [-1])[0])
|
||||
@@ -35,37 +33,9 @@ async def connect(connection_id: str, environ, auth):
|
||||
logger.error('No conversation_id in query params')
|
||||
raise ConnectionRefusedError('No conversation_id in query params')
|
||||
|
||||
user_id = None
|
||||
if server_config.app_mode != AppMode.OSS:
|
||||
cookies_str = environ.get('HTTP_COOKIE', '')
|
||||
cookies = dict(cookie.split('=', 1) for cookie in cookies_str.split('; '))
|
||||
signed_token = cookies.get('openhands_auth', '')
|
||||
if not signed_token:
|
||||
logger.error('No openhands_auth cookie')
|
||||
raise ConnectionRefusedError('No openhands_auth cookie')
|
||||
if not config.jwt_secret:
|
||||
raise RuntimeError('JWT secret not found')
|
||||
|
||||
jwt_secret = (
|
||||
config.jwt_secret.get_secret_value()
|
||||
if isinstance(config.jwt_secret, SecretStr)
|
||||
else config.jwt_secret
|
||||
)
|
||||
decoded = jwt.decode(signed_token, jwt_secret, algorithms=['HS256'])
|
||||
user_id = decoded['github_user_id']
|
||||
|
||||
logger.info(f'User {user_id} is connecting to conversation {conversation_id}')
|
||||
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
|
||||
if metadata.github_user_id != str(user_id):
|
||||
logger.error(
|
||||
f'User {user_id} is not allowed to join conversation {conversation_id}'
|
||||
)
|
||||
raise ConnectionRefusedError(
|
||||
f'User {user_id} is not allowed to join conversation {conversation_id}'
|
||||
)
|
||||
cookies_str = environ.get('HTTP_COOKIE', '')
|
||||
conversation_validator = ConversationValidatorImpl()
|
||||
user_id = await conversation_validator.validate(conversation_id, cookies_str)
|
||||
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
||||
settings = await settings_store.load()
|
||||
|
||||
@@ -10,7 +10,7 @@ from openhands.integrations.github.github_types import (
|
||||
GitHubUser,
|
||||
SuggestedTask,
|
||||
)
|
||||
from openhands.server.auth import get_github_token, get_user_id
|
||||
from openhands.server.auth import get_github_token, get_idp_token, get_user_id
|
||||
|
||||
app = APIRouter(prefix='/api/github')
|
||||
|
||||
@@ -23,8 +23,11 @@ async def get_github_repositories(
|
||||
installation_id: int | None = None,
|
||||
github_user_id: str | None = Depends(get_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
idp_token: SecretStr | None = Depends(get_idp_token),
|
||||
):
|
||||
client = GithubServiceImpl(user_id=github_user_id, token=github_user_token)
|
||||
client = GithubServiceImpl(
|
||||
user_id=github_user_id, idp_token=idp_token, token=github_user_token
|
||||
)
|
||||
try:
|
||||
repos: list[GitHubRepository] = await client.get_repositories(
|
||||
page, per_page, sort, installation_id
|
||||
@@ -48,8 +51,11 @@ async def get_github_repositories(
|
||||
async def get_github_user(
|
||||
github_user_id: str | None = Depends(get_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
idp_token: SecretStr | None = Depends(get_idp_token),
|
||||
):
|
||||
client = GithubServiceImpl(user_id=github_user_id, token=github_user_token)
|
||||
client = GithubServiceImpl(
|
||||
user_id=github_user_id, idp_token=idp_token, token=github_user_token
|
||||
)
|
||||
try:
|
||||
user: GitHubUser = await client.get_user()
|
||||
return user
|
||||
@@ -71,8 +77,11 @@ async def get_github_user(
|
||||
async def get_github_installation_ids(
|
||||
github_user_id: str | None = Depends(get_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
idp_token: SecretStr | None = Depends(get_idp_token),
|
||||
):
|
||||
client = GithubServiceImpl(user_id=github_user_id, token=github_user_token)
|
||||
client = GithubServiceImpl(
|
||||
user_id=github_user_id, idp_token=idp_token, token=github_user_token
|
||||
)
|
||||
try:
|
||||
installations_ids: list[int] = await client.get_installation_ids()
|
||||
return installations_ids
|
||||
@@ -98,8 +107,11 @@ async def search_github_repositories(
|
||||
order: str = 'desc',
|
||||
github_user_id: str | None = Depends(get_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
idp_token: SecretStr | None = Depends(get_idp_token),
|
||||
):
|
||||
client = GithubServiceImpl(user_id=github_user_id, token=github_user_token)
|
||||
client = GithubServiceImpl(
|
||||
user_id=github_user_id, idp_token=idp_token, token=github_user_token
|
||||
)
|
||||
try:
|
||||
repos: list[GitHubRepository] = await client.search_repositories(
|
||||
query, per_page, sort, order
|
||||
@@ -123,14 +135,17 @@ async def search_github_repositories(
|
||||
async def get_suggested_tasks(
|
||||
github_user_id: str | None = Depends(get_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
idp_token: SecretStr | None = Depends(get_idp_token),
|
||||
):
|
||||
"""
|
||||
Get suggested tasks for the authenticated user across their most recently pushed repositories.
|
||||
"""Get suggested tasks for the authenticated user across their most recently pushed repositories.
|
||||
|
||||
Returns:
|
||||
- PRs owned by the user
|
||||
- Issues assigned to the user
|
||||
- Issues assigned to the user.
|
||||
"""
|
||||
client = GithubServiceImpl(user_id=github_user_id, token=github_user_token)
|
||||
client = GithubServiceImpl(
|
||||
user_id=github_user_id, idp_token=idp_token, token=github_user_token
|
||||
)
|
||||
try:
|
||||
tasks: list[SuggestedTask] = await client.get_suggested_tasks()
|
||||
return tasks
|
||||
|
||||
@@ -11,7 +11,7 @@ from openhands.events.action.message import MessageAction
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.server.auth import get_github_token, get_user_id
|
||||
from openhands.server.auth import get_github_token, get_idp_token, get_user_id
|
||||
from openhands.server.data_models.conversation_info import ConversationInfo
|
||||
from openhands.server.data_models.conversation_info_result_set import (
|
||||
ConversationInfoResultSet,
|
||||
@@ -138,7 +138,11 @@ async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
"""
|
||||
logger.info('Initializing new conversation')
|
||||
user_id = get_user_id(request)
|
||||
gh_client = GithubServiceImpl(user_id=user_id, token=get_github_token(request))
|
||||
gh_client = GithubServiceImpl(
|
||||
user_id=user_id,
|
||||
idp_token=get_idp_token(request),
|
||||
token=get_github_token(request),
|
||||
)
|
||||
github_token = await gh_client.get_latest_token()
|
||||
|
||||
selected_repository = data.selected_repository
|
||||
|
||||
@@ -52,7 +52,7 @@ async def store_settings(
|
||||
# We check if the token is valid by getting the user
|
||||
# If the token is invalid, this will raise an exception
|
||||
github = GithubServiceImpl(
|
||||
user_id=None, token=SecretStr(settings.github_token)
|
||||
user_id=None, idp_token=None, token=SecretStr(settings.github_token)
|
||||
)
|
||||
await github.get_user()
|
||||
|
||||
|
||||
17
openhands/storage/conversation/conversation_validator.py
Normal file
17
openhands/storage/conversation/conversation_validator.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import os
|
||||
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
|
||||
class ConversationValidator:
|
||||
"""Storage for conversation metadata. May or may not support multiple users depending on the environment."""
|
||||
|
||||
async def validate(self, conversation_id: str, cookies_str: str):
|
||||
return None
|
||||
|
||||
|
||||
conversation_validator_cls = os.environ.get(
|
||||
'OPENHANDS_CONVERSATION_VALIDATOR_CLS',
|
||||
'openhands.storage.conversation.conversation_validator.ConversationValidator',
|
||||
)
|
||||
ConversationValidatorImpl = get_impl(ConversationValidator, conversation_validator_cls)
|
||||
Reference in New Issue
Block a user