mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Feature: User id propagation (#6233)
This commit is contained in:
parent
0b74fd71d9
commit
5a809c9b53
@ -30,7 +30,7 @@ async def connect(connection_id: str, environ, auth):
|
||||
logger.error('No conversation_id in query params')
|
||||
raise ConnectionRefusedError('No conversation_id in query params')
|
||||
|
||||
user_id = -1
|
||||
user_id = None
|
||||
if openhands_config.app_mode != AppMode.OSS:
|
||||
cookies_str = environ.get('HTTP_COOKIE', '')
|
||||
cookies = dict(cookie.split('=', 1) for cookie in cookies_str.split('; '))
|
||||
@ -63,7 +63,7 @@ async def connect(connection_id: str, environ, auth):
|
||||
|
||||
try:
|
||||
event_stream = await session_manager.join_conversation(
|
||||
conversation_id, connection_id, settings
|
||||
conversation_id, connection_id, settings, user_id
|
||||
)
|
||||
except ConversationDoesNotExistError:
|
||||
logger.error(f'Conversation {conversation_id} does not exist')
|
||||
|
||||
@ -42,7 +42,8 @@ async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
logger.info('Initializing new conversation')
|
||||
|
||||
logger.info('Loading settings')
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request))
|
||||
user_id = get_user_id(request)
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
||||
settings = await settings_store.load()
|
||||
logger.info('Settings loaded')
|
||||
|
||||
@ -55,9 +56,7 @@ async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
session_init_args['selected_repository'] = data.selected_repository
|
||||
conversation_init_data = ConversationInitData(**session_init_args)
|
||||
logger.info('Loading conversation store')
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
)
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
logger.info('Conversation store loaded')
|
||||
|
||||
conversation_id = uuid.uuid4().hex
|
||||
@ -76,19 +75,19 @@ async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
ConversationMetadata(
|
||||
conversation_id=conversation_id,
|
||||
title=conversation_title,
|
||||
github_user_id=get_user_id(request),
|
||||
github_user_id=user_id,
|
||||
selected_repository=data.selected_repository,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f'Starting agent loop for conversation {conversation_id}')
|
||||
event_stream = await session_manager.maybe_start_agent_loop(
|
||||
conversation_id, conversation_init_data
|
||||
conversation_id, conversation_init_data, user_id
|
||||
)
|
||||
try:
|
||||
event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER,
|
||||
_create_conversation_update_callback(get_user_id(request), conversation_id),
|
||||
_create_conversation_update_callback(user_id, conversation_id),
|
||||
UPDATED_AT_CALLBACK_ID,
|
||||
)
|
||||
except ValueError:
|
||||
@ -113,7 +112,7 @@ async def search_conversations(
|
||||
if hasattr(conversation, 'created_at')
|
||||
)
|
||||
running_conversations = await session_manager.get_agent_loop_running(
|
||||
set(conversation_ids)
|
||||
get_user_id(request), set(conversation_ids)
|
||||
)
|
||||
result = ConversationInfoResultSet(
|
||||
results=await wait_all(
|
||||
|
||||
@ -209,13 +209,15 @@ class SessionManager:
|
||||
self._active_conversations[sid] = (c, 1)
|
||||
return c
|
||||
|
||||
async def join_conversation(self, sid: str, connection_id: str, settings: Settings):
|
||||
async def join_conversation(
|
||||
self, sid: str, connection_id: str, settings: Settings, user_id: int | None
|
||||
):
|
||||
logger.info(f'join_conversation:{sid}:{connection_id}')
|
||||
await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
|
||||
self.local_connection_id_to_session_id[connection_id] = sid
|
||||
event_stream = await self._get_event_stream(sid)
|
||||
if not event_stream:
|
||||
return await self.maybe_start_agent_loop(sid, settings)
|
||||
return await self.maybe_start_agent_loop(sid, settings, user_id)
|
||||
return event_stream
|
||||
|
||||
async def detach_from_conversation(self, conversation: Conversation):
|
||||
@ -265,7 +267,7 @@ class SessionManager:
|
||||
logger.warning('error_cleaning_detached_conversations', exc_info=True)
|
||||
await asyncio.sleep(_CLEANUP_EXCEPTION_WAIT_TIME)
|
||||
|
||||
async def get_agent_loop_running(self, sids: set[str]) -> set[str]:
|
||||
async def get_agent_loop_running(self, user_id, sids: set[str]) -> set[str]:
|
||||
running_sids = set(sid for sid in sids if sid in self._local_agent_loops_by_sid)
|
||||
check_cluster_sids = [sid for sid in sids if sid not in running_sids]
|
||||
running_cluster_sids = await self.get_agent_loop_running_in_cluster(
|
||||
@ -346,7 +348,9 @@ class SessionManager:
|
||||
finally:
|
||||
self._has_remote_connections_flags.pop(sid, None)
|
||||
|
||||
async def maybe_start_agent_loop(self, sid: str, settings: Settings) -> EventStream:
|
||||
async def maybe_start_agent_loop(
|
||||
self, sid: str, settings: Settings, user_id: int | None
|
||||
) -> EventStream:
|
||||
logger.info(f'maybe_start_agent_loop:{sid}')
|
||||
session: Session | None = None
|
||||
if not await self.is_agent_loop_running(sid):
|
||||
|
||||
@ -37,6 +37,7 @@ class Session:
|
||||
loop: asyncio.AbstractEventLoop
|
||||
config: AppConfig
|
||||
file_store: FileStore
|
||||
user_id: int | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -44,6 +45,7 @@ class Session:
|
||||
config: AppConfig,
|
||||
file_store: FileStore,
|
||||
sio: socketio.AsyncServer | None,
|
||||
user_id: int | None = None,
|
||||
):
|
||||
self.sid = sid
|
||||
self.sio = sio
|
||||
@ -58,6 +60,7 @@ class Session:
|
||||
# Copying this means that when we update variables they are not applied to the shared global configuration!
|
||||
self.config = deepcopy(config)
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.user_id = user_id
|
||||
|
||||
def close(self):
|
||||
self.is_alive = False
|
||||
|
||||
@ -40,5 +40,7 @@ class ConversationStore(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def get_instance(cls, config: AppConfig, user_id: int) -> ConversationStore:
|
||||
async def get_instance(
|
||||
cls, config: AppConfig, user_id: int | None
|
||||
) -> ConversationStore:
|
||||
"""Get a store for the user represented by the token given"""
|
||||
|
||||
@ -94,7 +94,7 @@ class FileConversationStore(ConversationStore):
|
||||
|
||||
@classmethod
|
||||
async def get_instance(
|
||||
cls, config: AppConfig, user_id: int
|
||||
cls, config: AppConfig, user_id: int | None
|
||||
) -> FileConversationStore:
|
||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
return FileConversationStore(file_store)
|
||||
|
||||
@ -30,6 +30,8 @@ class FileSettingsStore(SettingsStore):
|
||||
await call_sync_from_async(self.file_store.write, self.path, json_str)
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls, config: AppConfig, user_id: int) -> FileSettingsStore:
|
||||
async def get_instance(
|
||||
cls, config: AppConfig, user_id: int | None
|
||||
) -> FileSettingsStore:
|
||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
return FileSettingsStore(file_store)
|
||||
|
||||
@ -21,5 +21,7 @@ class SettingsStore(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def get_instance(cls, config: AppConfig, user_id: int) -> SettingsStore:
|
||||
async def get_instance(
|
||||
cls, config: AppConfig, user_id: int | None
|
||||
) -> SettingsStore:
|
||||
"""Get a store for the user represented by the token given"""
|
||||
|
||||
@ -114,10 +114,10 @@ async def test_init_new_local_session():
|
||||
sio, AppConfig(), InMemoryFileStore()
|
||||
) as session_manager:
|
||||
await session_manager.maybe_start_agent_loop(
|
||||
'new-session-id', ConversationInitData()
|
||||
'new-session-id', ConversationInitData(), 1
|
||||
)
|
||||
await session_manager.join_conversation(
|
||||
'new-session-id', 'new-session-id', ConversationInitData()
|
||||
'new-session-id', 'new-session-id', ConversationInitData(), 1
|
||||
)
|
||||
assert session_instance.initialize_agent.call_count == 1
|
||||
assert sio.enter_room.await_count == 1
|
||||
@ -148,13 +148,13 @@ async def test_join_local_session():
|
||||
sio, AppConfig(), InMemoryFileStore()
|
||||
) as session_manager:
|
||||
await session_manager.maybe_start_agent_loop(
|
||||
'new-session-id', ConversationInitData()
|
||||
'new-session-id', ConversationInitData(), None
|
||||
)
|
||||
await session_manager.join_conversation(
|
||||
'new-session-id', 'new-session-id', ConversationInitData()
|
||||
'new-session-id', 'new-session-id', ConversationInitData(), None
|
||||
)
|
||||
await session_manager.join_conversation(
|
||||
'new-session-id', 'new-session-id', ConversationInitData()
|
||||
'new-session-id', 'new-session-id', ConversationInitData(), None
|
||||
)
|
||||
assert session_instance.initialize_agent.call_count == 1
|
||||
assert sio.enter_room.await_count == 2
|
||||
@ -185,7 +185,7 @@ async def test_join_cluster_session():
|
||||
sio, AppConfig(), InMemoryFileStore()
|
||||
) as session_manager:
|
||||
await session_manager.join_conversation(
|
||||
'new-session-id', 'new-session-id', ConversationInitData()
|
||||
'new-session-id', 'new-session-id', ConversationInitData(), 1
|
||||
)
|
||||
assert session_instance.initialize_agent.call_count == 0
|
||||
assert sio.enter_room.await_count == 1
|
||||
@ -216,10 +216,10 @@ async def test_add_to_local_event_stream():
|
||||
sio, AppConfig(), InMemoryFileStore()
|
||||
) as session_manager:
|
||||
await session_manager.maybe_start_agent_loop(
|
||||
'new-session-id', ConversationInitData()
|
||||
'new-session-id', ConversationInitData(), 1
|
||||
)
|
||||
await session_manager.join_conversation(
|
||||
'new-session-id', 'connection-id', ConversationInitData()
|
||||
'new-session-id', 'connection-id', ConversationInitData(), 1
|
||||
)
|
||||
await session_manager.send_to_event_stream(
|
||||
'connection-id', {'event_type': 'some_event'}
|
||||
@ -252,7 +252,7 @@ async def test_add_to_cluster_event_stream():
|
||||
sio, AppConfig(), InMemoryFileStore()
|
||||
) as session_manager:
|
||||
await session_manager.join_conversation(
|
||||
'new-session-id', 'connection-id', ConversationInitData()
|
||||
'new-session-id', 'connection-id', ConversationInitData(), 1
|
||||
)
|
||||
await session_manager.send_to_event_stream(
|
||||
'connection-id', {'event_type': 'some_event'}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user