mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Retrieve GitHub IDs more efficiently (#6074)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
09734467c0
commit
343b86429e
@ -62,16 +62,13 @@ const WsClientContext = React.createContext<UseWsClient>({
|
||||
|
||||
interface WsClientProviderProps {
|
||||
conversationId: string;
|
||||
ghToken: string | null;
|
||||
}
|
||||
|
||||
export function WsClientProvider({
|
||||
ghToken,
|
||||
conversationId,
|
||||
children,
|
||||
}: React.PropsWithChildren<WsClientProviderProps>) {
|
||||
const sioRef = React.useRef<Socket | null>(null);
|
||||
const ghTokenRef = React.useRef<string | null>(ghToken);
|
||||
const [status, setStatus] = React.useState(
|
||||
WsClientProviderStatus.DISCONNECTED,
|
||||
);
|
||||
@ -141,9 +138,6 @@ export function WsClientProvider({
|
||||
|
||||
sio = io(baseUrl, {
|
||||
transports: ["websocket"],
|
||||
auth: {
|
||||
github_token: ghToken || undefined,
|
||||
},
|
||||
query,
|
||||
});
|
||||
sio.on("connect", handleConnect);
|
||||
@ -153,7 +147,6 @@ export function WsClientProvider({
|
||||
sio.on("disconnect", handleDisconnect);
|
||||
|
||||
sioRef.current = sio;
|
||||
ghTokenRef.current = ghToken;
|
||||
|
||||
return () => {
|
||||
sio.off("connect", handleConnect);
|
||||
@ -162,7 +155,7 @@ export function WsClientProvider({
|
||||
sio.off("connect_failed", handleError);
|
||||
sio.off("disconnect", handleDisconnect);
|
||||
};
|
||||
}, [ghToken, conversationId]);
|
||||
}, [conversationId]);
|
||||
|
||||
React.useEffect(
|
||||
() => () => {
|
||||
|
||||
@ -175,7 +175,7 @@ function AppContent() {
|
||||
}
|
||||
|
||||
return (
|
||||
<WsClientProvider ghToken={gitHubToken} conversationId={conversationId}>
|
||||
<WsClientProvider conversationId={conversationId}>
|
||||
<EventHandler>
|
||||
<div data-testid="app-route" className="flex flex-col h-full gap-3">
|
||||
<div className="flex h-full overflow-auto">{renderMain()}</div>
|
||||
|
||||
@ -1,9 +1,14 @@
|
||||
import jwt
|
||||
from fastapi import Request
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
def get_user_id(request: Request) -> int:
|
||||
return getattr(request.state, 'github_user_id', 0)
|
||||
|
||||
|
||||
def get_sid_from_token(token: str, jwt_secret: str) -> str:
|
||||
"""Retrieves the session id from a JWT token.
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from github import Github
|
||||
import jwt
|
||||
from socketio.exceptions import ConnectionRefusedError
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@ -18,7 +18,6 @@ from openhands.server.routes.settings import ConversationStoreImpl, SettingsStor
|
||||
from openhands.server.session.manager import ConversationDoesNotExistError
|
||||
from openhands.server.shared import config, openhands_config, session_manager, sio
|
||||
from openhands.server.types import AppMode
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
@sio.event
|
||||
@ -31,20 +30,20 @@ async def connect(connection_id: str, environ, auth):
|
||||
logger.error('No conversation_id in query params')
|
||||
raise ConnectionRefusedError('No conversation_id in query params')
|
||||
|
||||
github_token = ''
|
||||
user_id = -1
|
||||
if openhands_config.app_mode != AppMode.OSS:
|
||||
user_id = ''
|
||||
if auth and 'github_token' in auth:
|
||||
github_token = auth['github_token']
|
||||
with Github(github_token) as g:
|
||||
gh_user = await call_sync_from_async(g.get_user)
|
||||
user_id = gh_user.id
|
||||
cookies_str = environ.get('HTTP_COOKIE', '')
|
||||
cookies = dict(cookie.split('=', 1) for cookie in cookies_str.split('; '))
|
||||
signed_token = cookies.get('github_auth', '')
|
||||
if not signed_token:
|
||||
logger.error('No github_auth cookie')
|
||||
raise ConnectionRefusedError('No github_auth cookie')
|
||||
decoded = jwt.decode(signed_token, config.jwt_secret, algorithms=['HS256'])
|
||||
user_id = decoded['github_user_id']
|
||||
|
||||
logger.info(f'User {user_id} is connecting to conversation {conversation_id}')
|
||||
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, github_token
|
||||
)
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
if metadata.github_user_id != user_id:
|
||||
logger.error(
|
||||
@ -54,7 +53,7 @@ async def connect(connection_id: str, environ, auth):
|
||||
f'User {user_id} is not allowed to join conversation {conversation_id}'
|
||||
)
|
||||
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, github_token)
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
||||
settings = await settings_store.load()
|
||||
|
||||
if not settings:
|
||||
|
||||
@ -4,11 +4,11 @@ from typing import Callable
|
||||
|
||||
from fastapi import APIRouter, Body, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from github import Github
|
||||
from pydantic import BaseModel
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
from openhands.server.auth import get_user_id
|
||||
from openhands.server.routes.settings import ConversationStoreImpl, SettingsStoreImpl
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
from openhands.server.shared import config, session_manager
|
||||
@ -21,7 +21,6 @@ from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||
from openhands.utils.async_utils import (
|
||||
GENERAL_TIMEOUT,
|
||||
call_async_from_sync,
|
||||
call_sync_from_async,
|
||||
wait_all,
|
||||
)
|
||||
|
||||
@ -43,10 +42,9 @@ async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
using the returned conversation ID
|
||||
"""
|
||||
logger.info('Initializing new conversation')
|
||||
github_token = data.github_token or ''
|
||||
|
||||
logger.info('Loading settings')
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, github_token)
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request))
|
||||
settings = await settings_store.load()
|
||||
logger.info('Settings loaded')
|
||||
|
||||
@ -54,11 +52,14 @@ async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
if settings:
|
||||
session_init_args = {**settings.__dict__, **session_init_args}
|
||||
|
||||
github_token = getattr(request.state, 'github_token', '')
|
||||
session_init_args['github_token'] = github_token
|
||||
session_init_args['selected_repository'] = data.selected_repository
|
||||
conversation_init_data = ConversationInitData(**session_init_args)
|
||||
logger.info('Loading conversation store')
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
)
|
||||
logger.info('Conversation store loaded')
|
||||
|
||||
conversation_id = uuid.uuid4().hex
|
||||
@ -67,18 +68,11 @@ async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
conversation_id = uuid.uuid4().hex
|
||||
logger.info(f'New conversation ID: {conversation_id}')
|
||||
|
||||
user_id = ''
|
||||
if data.github_token:
|
||||
logger.info('Fetching Github user ID')
|
||||
with Github(data.github_token) as g:
|
||||
gh_user = await call_sync_from_async(g.get_user)
|
||||
user_id = gh_user.id
|
||||
|
||||
logger.info(f'Saving metadata for conversation {conversation_id}')
|
||||
await conversation_store.save_metadata(
|
||||
ConversationMetadata(
|
||||
conversation_id=conversation_id,
|
||||
github_user_id=user_id,
|
||||
github_user_id=get_user_id(request),
|
||||
selected_repository=data.selected_repository,
|
||||
)
|
||||
)
|
||||
@ -90,9 +84,7 @@ async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
try:
|
||||
event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER,
|
||||
_create_conversation_update_callback(
|
||||
data.github_token or '', conversation_id
|
||||
),
|
||||
_create_conversation_update_callback(get_user_id(request), conversation_id),
|
||||
UPDATED_AT_CALLBACK_ID,
|
||||
)
|
||||
except ValueError:
|
||||
@ -107,8 +99,9 @@ async def search_conversations(
|
||||
page_id: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> ConversationInfoResultSet:
|
||||
github_token = getattr(request.state, 'github_token', '') or ''
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
)
|
||||
conversation_metadata_result_set = await conversation_store.search(page_id, limit)
|
||||
conversation_ids = set(
|
||||
conversation.conversation_id
|
||||
@ -134,8 +127,9 @@ async def search_conversations(
|
||||
async def get_conversation(
|
||||
conversation_id: str, request: Request
|
||||
) -> ConversationInfo | None:
|
||||
github_token = getattr(request.state, 'github_token', '') or ''
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
)
|
||||
try:
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
is_running = await session_manager.is_agent_loop_running(conversation_id)
|
||||
@ -149,8 +143,9 @@ async def get_conversation(
|
||||
async def update_conversation(
|
||||
request: Request, conversation_id: str, title: str = Body(embed=True)
|
||||
) -> bool:
|
||||
github_token = getattr(request.state, 'github_token', '') or ''
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
)
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
if not metadata:
|
||||
return False
|
||||
@ -164,8 +159,9 @@ async def delete_conversation(
|
||||
conversation_id: str,
|
||||
request: Request,
|
||||
) -> bool:
|
||||
github_token = getattr(request.state, 'github_token', '') or ''
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
)
|
||||
try:
|
||||
await conversation_store.get_metadata(conversation_id)
|
||||
except FileNotFoundError:
|
||||
@ -205,21 +201,21 @@ async def _get_conversation_info(
|
||||
|
||||
|
||||
def _create_conversation_update_callback(
|
||||
github_token: str, conversation_id: str
|
||||
user_id: int, conversation_id: str
|
||||
) -> Callable:
|
||||
def callback(*args, **kwargs):
|
||||
call_async_from_sync(
|
||||
_update_timestamp_for_conversation,
|
||||
GENERAL_TIMEOUT,
|
||||
github_token,
|
||||
user_id,
|
||||
conversation_id,
|
||||
)
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
async def _update_timestamp_for_conversation(github_token: str, conversation_id: str):
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
|
||||
async def _update_timestamp_for_conversation(user_id: int, conversation_id: str):
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
conversation = await conversation_store.get_metadata(conversation_id)
|
||||
conversation.last_updated_at = datetime.now()
|
||||
await conversation_store.save_metadata(conversation)
|
||||
|
||||
@ -2,6 +2,7 @@ from fastapi import APIRouter, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.auth import get_user_id
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.server.shared import config, openhands_config
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
@ -19,9 +20,10 @@ ConversationStoreImpl = get_impl(
|
||||
|
||||
@app.get('/settings')
|
||||
async def load_settings(request: Request) -> Settings | None:
|
||||
github_token = getattr(request.state, 'github_token', '') or ''
|
||||
try:
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, github_token)
|
||||
settings_store = await SettingsStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
)
|
||||
settings = await settings_store.load()
|
||||
if not settings:
|
||||
return JSONResponse(
|
||||
@ -45,11 +47,10 @@ async def store_settings(
|
||||
request: Request,
|
||||
settings: Settings,
|
||||
) -> JSONResponse:
|
||||
github_token = ''
|
||||
if hasattr(request.state, 'github_token'):
|
||||
github_token = request.state.github_token
|
||||
try:
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, github_token)
|
||||
settings_store = await SettingsStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
)
|
||||
existing_settings = await settings_store.load()
|
||||
|
||||
if existing_settings:
|
||||
|
||||
@ -40,7 +40,5 @@ class ConversationStore(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def get_instance(
|
||||
cls, config: AppConfig, token: str | None
|
||||
) -> ConversationStore:
|
||||
async def get_instance(cls, config: AppConfig, user_id: int) -> ConversationStore:
|
||||
"""Get a store for the user represented by the token given"""
|
||||
|
||||
@ -90,7 +90,9 @@ class FileConversationStore(ConversationStore):
|
||||
return get_conversation_metadata_filename(conversation_id)
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls, config: AppConfig, token: str | None):
|
||||
async def get_instance(
|
||||
cls, config: AppConfig, user_id: int
|
||||
) -> FileConversationStore:
|
||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
return FileConversationStore(file_store)
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ from datetime import datetime
|
||||
@dataclass
|
||||
class ConversationMetadata:
|
||||
conversation_id: str
|
||||
github_user_id: int | str
|
||||
github_user_id: int
|
||||
selected_repository: str | None
|
||||
title: str | None = None
|
||||
last_updated_at: datetime | None = None
|
||||
|
||||
@ -30,6 +30,6 @@ class FileSettingsStore(SettingsStore):
|
||||
await call_sync_from_async(self.file_store.write, self.path, json_str)
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls, config: AppConfig, token: str | None):
|
||||
async def get_instance(cls, config: AppConfig, user_id: int) -> FileSettingsStore:
|
||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
return FileSettingsStore(file_store)
|
||||
|
||||
@ -21,5 +21,5 @@ class SettingsStore(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def get_instance(cls, config: AppConfig, token: str | None) -> SettingsStore:
|
||||
async def get_instance(cls, config: AppConfig, user_id: int) -> SettingsStore:
|
||||
"""Get a store for the user represented by the token given"""
|
||||
|
||||
@ -28,7 +28,7 @@ def _patch_store():
|
||||
'title': 'Some Conversation',
|
||||
'selected_repository': 'foobar',
|
||||
'conversation_id': 'some_conversation_id',
|
||||
'github_user_id': 'github_user',
|
||||
'github_user_id': 12345,
|
||||
'created_at': '2025-01-01T00:00:00',
|
||||
'last_updated_at': '2025-01-01T00:01:00',
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user