Retrieve GitHub IDs more efficiently (#6074)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Robert Brennan 2025-01-06 14:22:52 -05:00 committed by GitHub
parent 09734467c0
commit 343b86429e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 58 additions and 64 deletions

View File

@ -62,16 +62,13 @@ const WsClientContext = React.createContext<UseWsClient>({
interface WsClientProviderProps {
conversationId: string;
ghToken: string | null;
}
export function WsClientProvider({
ghToken,
conversationId,
children,
}: React.PropsWithChildren<WsClientProviderProps>) {
const sioRef = React.useRef<Socket | null>(null);
const ghTokenRef = React.useRef<string | null>(ghToken);
const [status, setStatus] = React.useState(
WsClientProviderStatus.DISCONNECTED,
);
@ -141,9 +138,6 @@ export function WsClientProvider({
sio = io(baseUrl, {
transports: ["websocket"],
auth: {
github_token: ghToken || undefined,
},
query,
});
sio.on("connect", handleConnect);
@ -153,7 +147,6 @@ export function WsClientProvider({
sio.on("disconnect", handleDisconnect);
sioRef.current = sio;
ghTokenRef.current = ghToken;
return () => {
sio.off("connect", handleConnect);
@ -162,7 +155,7 @@ export function WsClientProvider({
sio.off("connect_failed", handleError);
sio.off("disconnect", handleDisconnect);
};
}, [ghToken, conversationId]);
}, [conversationId]);
React.useEffect(
() => () => {

View File

@ -175,7 +175,7 @@ function AppContent() {
}
return (
<WsClientProvider ghToken={gitHubToken} conversationId={conversationId}>
<WsClientProvider conversationId={conversationId}>
<EventHandler>
<div data-testid="app-route" className="flex flex-col h-full gap-3">
<div className="flex h-full overflow-auto">{renderMain()}</div>

View File

@ -1,9 +1,14 @@
import jwt
from fastapi import Request
from jwt.exceptions import InvalidTokenError
from openhands.core.logger import openhands_logger as logger
def get_user_id(request: Request) -> int:
return getattr(request.state, 'github_user_id', 0)
def get_sid_from_token(token: str, jwt_secret: str) -> str:
"""Retrieves the session id from a JWT token.

View File

@ -1,6 +1,6 @@
from urllib.parse import parse_qs
from github import Github
import jwt
from socketio.exceptions import ConnectionRefusedError
from openhands.core.logger import openhands_logger as logger
@ -18,7 +18,6 @@ from openhands.server.routes.settings import ConversationStoreImpl, SettingsStor
from openhands.server.session.manager import ConversationDoesNotExistError
from openhands.server.shared import config, openhands_config, session_manager, sio
from openhands.server.types import AppMode
from openhands.utils.async_utils import call_sync_from_async
@sio.event
@ -31,20 +30,20 @@ async def connect(connection_id: str, environ, auth):
logger.error('No conversation_id in query params')
raise ConnectionRefusedError('No conversation_id in query params')
github_token = ''
user_id = -1
if openhands_config.app_mode != AppMode.OSS:
user_id = ''
if auth and 'github_token' in auth:
github_token = auth['github_token']
with Github(github_token) as g:
gh_user = await call_sync_from_async(g.get_user)
user_id = gh_user.id
cookies_str = environ.get('HTTP_COOKIE', '')
cookies = dict(cookie.split('=', 1) for cookie in cookies_str.split('; '))
signed_token = cookies.get('github_auth', '')
if not signed_token:
logger.error('No github_auth cookie')
raise ConnectionRefusedError('No github_auth cookie')
decoded = jwt.decode(signed_token, config.jwt_secret, algorithms=['HS256'])
user_id = decoded['github_user_id']
logger.info(f'User {user_id} is connecting to conversation {conversation_id}')
conversation_store = await ConversationStoreImpl.get_instance(
config, github_token
)
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
metadata = await conversation_store.get_metadata(conversation_id)
if metadata.github_user_id != user_id:
logger.error(
@ -54,7 +53,7 @@ async def connect(connection_id: str, environ, auth):
f'User {user_id} is not allowed to join conversation {conversation_id}'
)
settings_store = await SettingsStoreImpl.get_instance(config, github_token)
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
settings = await settings_store.load()
if not settings:

View File

@ -4,11 +4,11 @@ from typing import Callable
from fastapi import APIRouter, Body, Request
from fastapi.responses import JSONResponse
from github import Github
from pydantic import BaseModel
from openhands.core.logger import openhands_logger as logger
from openhands.events.stream import EventStreamSubscriber
from openhands.server.auth import get_user_id
from openhands.server.routes.settings import ConversationStoreImpl, SettingsStoreImpl
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.server.shared import config, session_manager
@ -21,7 +21,6 @@ from openhands.storage.data_models.conversation_status import ConversationStatus
from openhands.utils.async_utils import (
GENERAL_TIMEOUT,
call_async_from_sync,
call_sync_from_async,
wait_all,
)
@ -43,10 +42,9 @@ async def new_conversation(request: Request, data: InitSessionRequest):
using the returned conversation ID
"""
logger.info('Initializing new conversation')
github_token = data.github_token or ''
logger.info('Loading settings')
settings_store = await SettingsStoreImpl.get_instance(config, github_token)
settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request))
settings = await settings_store.load()
logger.info('Settings loaded')
@ -54,11 +52,14 @@ async def new_conversation(request: Request, data: InitSessionRequest):
if settings:
session_init_args = {**settings.__dict__, **session_init_args}
github_token = getattr(request.state, 'github_token', '')
session_init_args['github_token'] = github_token
session_init_args['selected_repository'] = data.selected_repository
conversation_init_data = ConversationInitData(**session_init_args)
logger.info('Loading conversation store')
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request)
)
logger.info('Conversation store loaded')
conversation_id = uuid.uuid4().hex
@ -67,18 +68,11 @@ async def new_conversation(request: Request, data: InitSessionRequest):
conversation_id = uuid.uuid4().hex
logger.info(f'New conversation ID: {conversation_id}')
user_id = ''
if data.github_token:
logger.info('Fetching Github user ID')
with Github(data.github_token) as g:
gh_user = await call_sync_from_async(g.get_user)
user_id = gh_user.id
logger.info(f'Saving metadata for conversation {conversation_id}')
await conversation_store.save_metadata(
ConversationMetadata(
conversation_id=conversation_id,
github_user_id=user_id,
github_user_id=get_user_id(request),
selected_repository=data.selected_repository,
)
)
@ -90,9 +84,7 @@ async def new_conversation(request: Request, data: InitSessionRequest):
try:
event_stream.subscribe(
EventStreamSubscriber.SERVER,
_create_conversation_update_callback(
data.github_token or '', conversation_id
),
_create_conversation_update_callback(get_user_id(request), conversation_id),
UPDATED_AT_CALLBACK_ID,
)
except ValueError:
@ -107,8 +99,9 @@ async def search_conversations(
page_id: str | None = None,
limit: int = 20,
) -> ConversationInfoResultSet:
github_token = getattr(request.state, 'github_token', '') or ''
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request)
)
conversation_metadata_result_set = await conversation_store.search(page_id, limit)
conversation_ids = set(
conversation.conversation_id
@ -134,8 +127,9 @@ async def search_conversations(
async def get_conversation(
conversation_id: str, request: Request
) -> ConversationInfo | None:
github_token = getattr(request.state, 'github_token', '') or ''
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request)
)
try:
metadata = await conversation_store.get_metadata(conversation_id)
is_running = await session_manager.is_agent_loop_running(conversation_id)
@ -149,8 +143,9 @@ async def get_conversation(
async def update_conversation(
request: Request, conversation_id: str, title: str = Body(embed=True)
) -> bool:
github_token = getattr(request.state, 'github_token', '') or ''
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request)
)
metadata = await conversation_store.get_metadata(conversation_id)
if not metadata:
return False
@ -164,8 +159,9 @@ async def delete_conversation(
conversation_id: str,
request: Request,
) -> bool:
github_token = getattr(request.state, 'github_token', '') or ''
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request)
)
try:
await conversation_store.get_metadata(conversation_id)
except FileNotFoundError:
@ -205,21 +201,21 @@ async def _get_conversation_info(
def _create_conversation_update_callback(
github_token: str, conversation_id: str
user_id: int, conversation_id: str
) -> Callable:
def callback(*args, **kwargs):
call_async_from_sync(
_update_timestamp_for_conversation,
GENERAL_TIMEOUT,
github_token,
user_id,
conversation_id,
)
return callback
async def _update_timestamp_for_conversation(github_token: str, conversation_id: str):
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
async def _update_timestamp_for_conversation(user_id: int, conversation_id: str):
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
conversation = await conversation_store.get_metadata(conversation_id)
conversation.last_updated_at = datetime.now()
await conversation_store.save_metadata(conversation)

View File

@ -2,6 +2,7 @@ from fastapi import APIRouter, Request, status
from fastapi.responses import JSONResponse
from openhands.core.logger import openhands_logger as logger
from openhands.server.auth import get_user_id
from openhands.server.settings import Settings
from openhands.server.shared import config, openhands_config
from openhands.storage.conversation.conversation_store import ConversationStore
@ -19,9 +20,10 @@ ConversationStoreImpl = get_impl(
@app.get('/settings')
async def load_settings(request: Request) -> Settings | None:
github_token = getattr(request.state, 'github_token', '') or ''
try:
settings_store = await SettingsStoreImpl.get_instance(config, github_token)
settings_store = await SettingsStoreImpl.get_instance(
config, get_user_id(request)
)
settings = await settings_store.load()
if not settings:
return JSONResponse(
@ -45,11 +47,10 @@ async def store_settings(
request: Request,
settings: Settings,
) -> JSONResponse:
github_token = ''
if hasattr(request.state, 'github_token'):
github_token = request.state.github_token
try:
settings_store = await SettingsStoreImpl.get_instance(config, github_token)
settings_store = await SettingsStoreImpl.get_instance(
config, get_user_id(request)
)
existing_settings = await settings_store.load()
if existing_settings:

View File

@ -40,7 +40,5 @@ class ConversationStore(ABC):
@classmethod
@abstractmethod
async def get_instance(
cls, config: AppConfig, token: str | None
) -> ConversationStore:
async def get_instance(cls, config: AppConfig, user_id: int) -> ConversationStore:
"""Get a store for the user represented by the token given"""

View File

@ -90,7 +90,9 @@ class FileConversationStore(ConversationStore):
return get_conversation_metadata_filename(conversation_id)
@classmethod
async def get_instance(cls, config: AppConfig, token: str | None):
async def get_instance(
cls, config: AppConfig, user_id: int
) -> FileConversationStore:
file_store = get_file_store(config.file_store, config.file_store_path)
return FileConversationStore(file_store)

View File

@ -5,7 +5,7 @@ from datetime import datetime
@dataclass
class ConversationMetadata:
conversation_id: str
github_user_id: int | str
github_user_id: int
selected_repository: str | None
title: str | None = None
last_updated_at: datetime | None = None

View File

@ -30,6 +30,6 @@ class FileSettingsStore(SettingsStore):
await call_sync_from_async(self.file_store.write, self.path, json_str)
@classmethod
async def get_instance(cls, config: AppConfig, token: str | None):
async def get_instance(cls, config: AppConfig, user_id: int) -> FileSettingsStore:
file_store = get_file_store(config.file_store, config.file_store_path)
return FileSettingsStore(file_store)

View File

@ -21,5 +21,5 @@ class SettingsStore(ABC):
@classmethod
@abstractmethod
async def get_instance(cls, config: AppConfig, token: str | None) -> SettingsStore:
async def get_instance(cls, config: AppConfig, user_id: int) -> SettingsStore:
"""Get a store for the user represented by the token given"""

View File

@ -28,7 +28,7 @@ def _patch_store():
'title': 'Some Conversation',
'selected_repository': 'foobar',
'conversation_id': 'some_conversation_id',
'github_user_id': 'github_user',
'github_user_id': 12345,
'created_at': '2025-01-01T00:00:00',
'last_updated_at': '2025-01-01T00:01:00',
}