Feature: User id propagation (#6233)

This commit is contained in:
tofarr
2025-01-13 11:10:45 -07:00
committed by GitHub
parent 0b74fd71d9
commit 5a809c9b53
9 changed files with 39 additions and 27 deletions

View File

@@ -30,7 +30,7 @@ async def connect(connection_id: str, environ, auth):
logger.error('No conversation_id in query params')
raise ConnectionRefusedError('No conversation_id in query params')
user_id = -1
user_id = None
if openhands_config.app_mode != AppMode.OSS:
cookies_str = environ.get('HTTP_COOKIE', '')
cookies = dict(cookie.split('=', 1) for cookie in cookies_str.split('; '))
@@ -63,7 +63,7 @@ async def connect(connection_id: str, environ, auth):
try:
event_stream = await session_manager.join_conversation(
conversation_id, connection_id, settings
conversation_id, connection_id, settings, user_id
)
except ConversationDoesNotExistError:
logger.error(f'Conversation {conversation_id} does not exist')

View File

@@ -42,7 +42,8 @@ async def new_conversation(request: Request, data: InitSessionRequest):
logger.info('Initializing new conversation')
logger.info('Loading settings')
settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request))
user_id = get_user_id(request)
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
settings = await settings_store.load()
logger.info('Settings loaded')
@@ -55,9 +56,7 @@ async def new_conversation(request: Request, data: InitSessionRequest):
session_init_args['selected_repository'] = data.selected_repository
conversation_init_data = ConversationInitData(**session_init_args)
logger.info('Loading conversation store')
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request)
)
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
logger.info('Conversation store loaded')
conversation_id = uuid.uuid4().hex
@@ -76,19 +75,19 @@ async def new_conversation(request: Request, data: InitSessionRequest):
ConversationMetadata(
conversation_id=conversation_id,
title=conversation_title,
github_user_id=get_user_id(request),
github_user_id=user_id,
selected_repository=data.selected_repository,
)
)
logger.info(f'Starting agent loop for conversation {conversation_id}')
event_stream = await session_manager.maybe_start_agent_loop(
conversation_id, conversation_init_data
conversation_id, conversation_init_data, user_id
)
try:
event_stream.subscribe(
EventStreamSubscriber.SERVER,
_create_conversation_update_callback(get_user_id(request), conversation_id),
_create_conversation_update_callback(user_id, conversation_id),
UPDATED_AT_CALLBACK_ID,
)
except ValueError:
@@ -113,7 +112,7 @@ async def search_conversations(
if hasattr(conversation, 'created_at')
)
running_conversations = await session_manager.get_agent_loop_running(
set(conversation_ids)
get_user_id(request), set(conversation_ids)
)
result = ConversationInfoResultSet(
results=await wait_all(

View File

@@ -209,13 +209,15 @@ class SessionManager:
self._active_conversations[sid] = (c, 1)
return c
async def join_conversation(self, sid: str, connection_id: str, settings: Settings):
async def join_conversation(
self, sid: str, connection_id: str, settings: Settings, user_id: int | None
):
logger.info(f'join_conversation:{sid}:{connection_id}')
await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
self.local_connection_id_to_session_id[connection_id] = sid
event_stream = await self._get_event_stream(sid)
if not event_stream:
return await self.maybe_start_agent_loop(sid, settings)
return await self.maybe_start_agent_loop(sid, settings, user_id)
return event_stream
async def detach_from_conversation(self, conversation: Conversation):
@@ -265,7 +267,7 @@ class SessionManager:
logger.warning('error_cleaning_detached_conversations', exc_info=True)
await asyncio.sleep(_CLEANUP_EXCEPTION_WAIT_TIME)
async def get_agent_loop_running(self, sids: set[str]) -> set[str]:
async def get_agent_loop_running(self, user_id, sids: set[str]) -> set[str]:
running_sids = set(sid for sid in sids if sid in self._local_agent_loops_by_sid)
check_cluster_sids = [sid for sid in sids if sid not in running_sids]
running_cluster_sids = await self.get_agent_loop_running_in_cluster(
@@ -346,7 +348,9 @@ class SessionManager:
finally:
self._has_remote_connections_flags.pop(sid, None)
async def maybe_start_agent_loop(self, sid: str, settings: Settings) -> EventStream:
async def maybe_start_agent_loop(
self, sid: str, settings: Settings, user_id: int | None
) -> EventStream:
logger.info(f'maybe_start_agent_loop:{sid}')
session: Session | None = None
if not await self.is_agent_loop_running(sid):

View File

@@ -37,6 +37,7 @@ class Session:
loop: asyncio.AbstractEventLoop
config: AppConfig
file_store: FileStore
user_id: int | None
def __init__(
self,
@@ -44,6 +45,7 @@ class Session:
config: AppConfig,
file_store: FileStore,
sio: socketio.AsyncServer | None,
user_id: int | None = None,
):
self.sid = sid
self.sio = sio
@@ -58,6 +60,7 @@ class Session:
# Copying this means that when we update variables they are not applied to the shared global configuration!
self.config = deepcopy(config)
self.loop = asyncio.get_event_loop()
self.user_id = user_id
def close(self):
self.is_alive = False