diff --git a/openhands/integrations/github/github_service.py b/openhands/integrations/github/github_service.py index 1ccdc641d6..c28d2f56ac 100644 --- a/openhands/integrations/github/github_service.py +++ b/openhands/integrations/github/github_service.py @@ -5,6 +5,7 @@ from typing import Any import httpx from pydantic import SecretStr +from openhands.core.logger import openhands_logger as logger from openhands.integrations.service_types import ( AuthenticationError, GitService, @@ -15,7 +16,7 @@ from openhands.integrations.service_types import ( User, ) from openhands.utils.import_utils import get_impl -from openhands.core.logger import openhands_logger as logger + class GitHubService(GitService): BASE_URL = 'https://api.github.com' @@ -25,6 +26,7 @@ class GitHubService(GitService): def __init__( self, user_id: str | None = None, + external_auth_id: str | None = None, external_auth_token: SecretStr | None = None, token: SecretStr | None = None, external_token_manager: bool = False, diff --git a/openhands/runtime/base.py b/openhands/runtime/base.py index aee48fdba6..d4a078b8ef 100644 --- a/openhands/runtime/base.py +++ b/openhands/runtime/base.py @@ -97,7 +97,7 @@ class Runtime(FileEditRuntimeMixin): status_callback: Callable | None = None, attach_to_existing: bool = False, headless_mode: bool = False, - github_user_id: str | None = None, + user_id: str | None = None, ): self.sid = sid self.event_stream = event_stream @@ -130,7 +130,7 @@ class Runtime(FileEditRuntimeMixin): self, enable_llm_editor=config.get_agent_config().codeact_enable_llm_editor ) - self.github_user_id = github_user_id + self.user_id = user_id def setup_initial_env(self) -> None: if self.attach_to_existing: @@ -220,9 +220,9 @@ class Runtime(FileEditRuntimeMixin): assert event.timeout is not None try: if isinstance(event, CmdRunAction): - if self.github_user_id and '$GITHUB_TOKEN' in event.command: + if self.user_id and '$GITHUB_TOKEN' in event.command: gh_client = GithubServiceImpl( - user_id=self.github_user_id, external_token_manager=True + external_auth_id=self.user_id, external_token_manager=True ) token = await gh_client.get_latest_token() if token: diff --git a/openhands/runtime/impl/action_execution/action_execution_client.py b/openhands/runtime/impl/action_execution/action_execution_client.py index 351fb1d03a..bcdbe2be15 100644 --- a/openhands/runtime/impl/action_execution/action_execution_client.py +++ b/openhands/runtime/impl/action_execution/action_execution_client.py @@ -59,7 +59,7 @@ class ActionExecutionClient(Runtime): status_callback: Any | None = None, attach_to_existing: bool = False, headless_mode: bool = True, - github_user_id: str | None = None, + user_id: str | None = None, ): self.session = HttpSession() self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time @@ -75,7 +75,7 @@ class ActionExecutionClient(Runtime): status_callback, attach_to_existing, headless_mode, - github_user_id, + user_id, ) @abstractmethod diff --git a/openhands/runtime/impl/remote/remote_runtime.py b/openhands/runtime/impl/remote/remote_runtime.py index 6e32e889e6..428760001e 100644 --- a/openhands/runtime/impl/remote/remote_runtime.py +++ b/openhands/runtime/impl/remote/remote_runtime.py @@ -45,7 +45,7 @@ class RemoteRuntime(ActionExecutionClient): status_callback: Callable | None = None, attach_to_existing: bool = False, headless_mode: bool = True, - github_user_id: str | None = None, + user_id: str | None = None, ): super().__init__( config, @@ -56,7 +56,7 @@ class RemoteRuntime(ActionExecutionClient): status_callback, attach_to_existing, headless_mode, - github_user_id, + user_id, ) if self.config.sandbox.api_key is None: raise ValueError( diff --git a/openhands/server/conversation_manager/conversation_manager.py b/openhands/server/conversation_manager/conversation_manager.py index 35c207de9f..d152a936f6 100644 --- a/openhands/server/conversation_manager/conversation_manager.py +++ b/openhands/server/conversation_manager/conversation_manager.py @@ -46,7 +46,12 @@ class ConversationManager(ABC): @abstractmethod async def join_conversation( - self, sid: str, connection_id: str, settings: Settings, user_id: str | None + self, + sid: str, + connection_id: str, + settings: Settings, + user_id: str | None, + github_user_id: str | None, ) -> EventStream | None: """Join a conversation and return its event stream.""" @@ -74,6 +79,7 @@ class ConversationManager(ABC): settings: Settings, user_id: str | None, initial_user_msg: MessageAction | None = None, + github_user_id: str | None = None, ) -> EventStream: """Start an event loop if one is not already running""" diff --git a/openhands/server/conversation_manager/standalone_conversation_manager.py b/openhands/server/conversation_manager/standalone_conversation_manager.py index 6f876d4095..b36f4025f9 100644 --- a/openhands/server/conversation_manager/standalone_conversation_manager.py +++ b/openhands/server/conversation_manager/standalone_conversation_manager.py @@ -106,7 +106,12 @@ class StandaloneConversationManager(ConversationManager): return c async def join_conversation( - self, sid: str, connection_id: str, settings: Settings, user_id: str | None + self, + sid: str, + connection_id: str, + settings: Settings, + user_id: str | None, + github_user_id: str | None, ): logger.info( f'join_conversation:{sid}:{connection_id}', @@ -116,7 +121,9 @@ class StandaloneConversationManager(ConversationManager): self._local_connection_id_to_session_id[connection_id] = sid event_stream = await self._get_event_stream(sid) if not event_stream: - return await self.maybe_start_agent_loop(sid, settings, user_id) + return await self.maybe_start_agent_loop( + sid, settings, user_id, github_user_id=github_user_id + ) for event in event_stream.get_events(reverse=True): if isinstance(event, AgentStateChangedObservation): if event.agent_state in ( @@ -187,14 +194,18 @@ class StandaloneConversationManager(ConversationManager): logger.error('error_cleaning_stale') await asyncio.sleep(_CLEANUP_INTERVAL) - async def _get_conversation_store(self, user_id: str | None) -> ConversationStore: + async def _get_conversation_store( + self, user_id: str | None, github_user_id: str | None + ) -> ConversationStore: conversation_store_class = self._conversation_store_class if not conversation_store_class: self._conversation_store_class = conversation_store_class = get_impl( ConversationStore, # type: ignore self.server_config.conversation_store_class, ) - store = await conversation_store_class.get_instance(self.config, user_id) + store = await conversation_store_class.get_instance( + self.config, user_id, github_user_id + ) return store async def get_running_agent_loops( @@ -243,6 +254,7 @@ class StandaloneConversationManager(ConversationManager): settings: Settings, user_id: str | None, initial_user_msg: MessageAction | None = None, + github_user_id: str | None = None, ) -> EventStream: logger.info(f'maybe_start_agent_loop:{sid}', extra={'session_id': sid}) session: Session | None = None @@ -256,7 +268,9 @@ class StandaloneConversationManager(ConversationManager): extra={'session_id': sid, 'user_id': user_id}, ) # Get the conversations sorted (oldest first) - conversation_store = await self._get_conversation_store(user_id) + conversation_store = await self._get_conversation_store( + user_id, github_user_id + ) conversations = await conversation_store.get_all_metadata(response_ids) conversations.sort(key=_last_updated_at_key, reverse=True) @@ -277,7 +291,9 @@ class StandaloneConversationManager(ConversationManager): try: session.agent_session.event_stream.subscribe( EventStreamSubscriber.SERVER, - self._create_conversation_update_callback(user_id, sid), + self._create_conversation_update_callback( + user_id, github_user_id, sid + ), UPDATED_AT_CALLBACK_ID, ) except ValueError: @@ -374,22 +390,23 @@ class StandaloneConversationManager(ConversationManager): ) def _create_conversation_update_callback( - self, user_id: str | None, conversation_id: str + self, user_id: str | None, github_user_id: str | None, conversation_id: str ) -> Callable: def callback(*args, **kwargs): call_async_from_sync( self._update_timestamp_for_conversation, GENERAL_TIMEOUT, user_id, + github_user_id, conversation_id, ) return callback async def _update_timestamp_for_conversation( - self, user_id: str, conversation_id: str + self, user_id: str, github_user_id: str, conversation_id: str ): - conversation_store = await self._get_conversation_store(user_id) + conversation_store = await self._get_conversation_store(user_id, github_user_id) conversation = await conversation_store.get_metadata(conversation_id) conversation.last_updated_at = datetime.now(timezone.utc) await conversation_store.save_metadata(conversation) diff --git a/openhands/server/listen_socket.py b/openhands/server/listen_socket.py index f89157856b..0b07c53e94 100644 --- a/openhands/server/listen_socket.py +++ b/openhands/server/listen_socket.py @@ -35,7 +35,9 @@ async def connect(connection_id: str, environ): cookies_str = environ.get('HTTP_COOKIE', '') conversation_validator = ConversationValidatorImpl() - user_id = await conversation_validator.validate(conversation_id, cookies_str) + user_id, github_user_id = await conversation_validator.validate( + conversation_id, cookies_str + ) settings_store = await SettingsStoreImpl.get_instance(config, user_id) settings = await settings_store.load() @@ -46,7 +48,7 @@ async def connect(connection_id: str, environ): ) event_stream = await conversation_manager.join_conversation( - conversation_id, connection_id, settings, user_id + conversation_id, connection_id, settings, user_id, github_user_id ) agent_state_changed = None diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index ae5c714bed..91c872bfab 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -10,7 +10,12 @@ from openhands.events.action.message import MessageAction from openhands.integrations.github.github_service import GithubServiceImpl from openhands.integrations.provider import ProviderType from openhands.runtime import get_runtime_cls -from openhands.server.auth import get_provider_tokens, get_access_token, get_github_user_id +from openhands.server.auth import ( + get_access_token, + get_github_user_id, + get_provider_tokens, + get_user_id, +) from openhands.server.data_models.conversation_info import ConversationInfo from openhands.server.data_models.conversation_info_result_set import ( ConversationInfoResultSet, @@ -73,12 +78,12 @@ async def _create_new_conversation( logger.warn('Settings not present, not starting conversation') raise MissingSettingsError('Settings not found') - session_init_args['github_token'] = token or SecretStr('') + session_init_args['provider_token'] = token session_init_args['selected_repository'] = selected_repository session_init_args['selected_branch'] = selected_branch conversation_init_data = ConversationInitData(**session_init_args) logger.info('Loading conversation store') - conversation_store = await ConversationStoreImpl.get_instance(config, user_id) + conversation_store = await ConversationStoreImpl.get_instance(config, user_id, None) logger.info('Conversation store loaded') conversation_id = uuid.uuid4().hex @@ -100,7 +105,8 @@ async def _create_new_conversation( ConversationMetadata( conversation_id=conversation_id, title=conversation_title, - github_user_id=user_id, + user_id=user_id, + github_user_id=None, selected_repository=selected_repository, selected_branch=selected_branch, ) @@ -122,7 +128,10 @@ async def _create_new_conversation( image_urls=image_urls or [], ) await conversation_manager.maybe_start_agent_loop( - conversation_id, conversation_init_data, user_id, initial_message_action + conversation_id, + conversation_init_data, + user_id, + initial_user_msg=initial_message_action, ) logger.info(f'Finished initializing conversation {conversation_id}') @@ -158,7 +167,7 @@ async def new_conversation(request: Request, data: InitSessionRequest): try: # Create conversation with initial message conversation_id = await _create_new_conversation( - user_id, + get_user_id(request), github_token, selected_repository, selected_branch, @@ -197,7 +206,7 @@ async def search_conversations( limit: int = 20, ) -> ConversationInfoResultSet: conversation_store = await ConversationStoreImpl.get_instance( - config, get_github_user_id(request) + config, get_user_id(request), get_github_user_id(request) ) conversation_metadata_result_set = await conversation_store.search(page_id, limit) @@ -216,7 +225,7 @@ async def search_conversations( conversation.conversation_id for conversation in filtered_results ) running_conversations = await conversation_manager.get_running_agent_loops( - get_github_user_id(request), set(conversation_ids) + get_user_id(request), set(conversation_ids) ) result = ConversationInfoResultSet( results=await wait_all( @@ -236,7 +245,7 @@ async def get_conversation( conversation_id: str, request: Request ) -> ConversationInfo | None: conversation_store = await ConversationStoreImpl.get_instance( - config, get_github_user_id(request) + config, get_user_id(request), get_github_user_id(request) ) try: metadata = await conversation_store.get_metadata(conversation_id) @@ -252,7 +261,7 @@ async def update_conversation( request: Request, conversation_id: str, title: str = Body(embed=True) ) -> bool: conversation_store = await ConversationStoreImpl.get_instance( - config, get_github_user_id(request) + config, get_user_id(request), get_github_user_id(request) ) metadata = await conversation_store.get_metadata(conversation_id) if not metadata: @@ -268,7 +277,7 @@ async def delete_conversation( request: Request, ) -> bool: conversation_store = await ConversationStoreImpl.get_instance( - config, get_github_user_id(request) + config, get_user_id(request), get_github_user_id(request) ) try: await conversation_store.get_metadata(conversation_id) diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py index 02466a7ad9..58518ab7a3 100644 --- a/openhands/server/session/agent_session.py +++ b/openhands/server/session/agent_session.py @@ -53,7 +53,7 @@ class AgentSession: sid: str, file_store: FileStore, status_callback: Callable | None = None, - github_user_id: str | None = None, + user_id: str | None = None, ): """Initializes a new instance of the Session class @@ -66,9 +66,9 @@ class AgentSession: self.event_stream = EventStream(sid, file_store) self.file_store = file_store self._status_callback = status_callback - self.github_user_id = github_user_id + self.user_id = user_id self.logger = OpenHandsLoggerAdapter( - extra={'session_id': sid, 'user_id': github_user_id} + extra={'session_id': sid, 'user_id': user_id} ) async def start( @@ -241,7 +241,7 @@ class AgentSession: kwargs = {} if runtime_cls == RemoteRuntime: - kwargs['github_user_id'] = self.github_user_id + kwargs['user_id'] = self.user_id self.runtime = runtime_cls( config=config, diff --git a/openhands/server/session/conversation_init_data.py b/openhands/server/session/conversation_init_data.py index 4cb6acd50f..ce7acbbe3e 100644 --- a/openhands/server/session/conversation_init_data.py +++ b/openhands/server/session/conversation_init_data.py @@ -8,6 +8,6 @@ class ConversationInitData(Settings): Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data. """ - github_token: SecretStr | None = Field(default=None) + provider_token: SecretStr | None = Field(default=None) selected_repository: str | None = Field(default=None) selected_branch: str | None = Field(default=None) diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py index 0004dcb176..1690e39d30 100644 --- a/openhands/server/session/session.py +++ b/openhands/server/session/session.py @@ -61,7 +61,7 @@ class Session: sid, file_store, status_callback=self.queue_status_message, - github_user_id=user_id, + user_id=user_id, ) self.agent_session.event_stream.subscribe( EventStreamSubscriber.SERVER, self.on_event, self.sid @@ -123,11 +123,11 @@ class Session: agent = Agent.get_cls(agent_cls)(llm, agent_config) - github_token = None + provider_token = None selected_repository = None selected_branch = None if isinstance(settings, ConversationInitData): - github_token = settings.github_token + provider_token = settings.provider_token selected_repository = settings.selected_repository selected_branch = settings.selected_branch @@ -140,7 +140,7 @@ class Session: max_budget_per_task=self.config.max_budget_per_task, agent_to_llm_config=self.config.get_agent_to_llm_config_map(), agent_configs=self.config.get_agent_configs(), - github_token=github_token, + github_token=provider_token, selected_repository=selected_repository, selected_branch=selected_branch, initial_message=initial_message, diff --git a/openhands/server/settings.py b/openhands/server/settings.py index 4059bf4886..a16ab3336c 100644 --- a/openhands/server/settings.py +++ b/openhands/server/settings.py @@ -43,7 +43,7 @@ class Settings(BaseModel): if context and context.get('expose_secrets', False): return llm_api_key.get_secret_value() - return pydantic_encoder(llm_api_key) + return pydantic_encoder(llm_api_key) if llm_api_key else None @staticmethod def _convert_token_value( diff --git a/openhands/storage/conversation/conversation_store.py b/openhands/storage/conversation/conversation_store.py index e74de6fcf7..8a7a73bdb6 100644 --- a/openhands/storage/conversation/conversation_store.py +++ b/openhands/storage/conversation/conversation_store.py @@ -12,25 +12,36 @@ from openhands.utils.async_utils import wait_all class ConversationStore(ABC): - """ - Storage for conversation metadata. May or may not support multiple users depending on the environment - """ + """Storage for conversation metadata. May or may not support multiple users depending on the environment.""" @abstractmethod async def save_metadata(self, metadata: ConversationMetadata) -> None: - """Store conversation metadata""" + """Store conversation metadata.""" @abstractmethod async def get_metadata(self, conversation_id: str) -> ConversationMetadata: - """Load conversation metadata""" + """Load conversation metadata.""" + + async def validate_metadata( + self, conversation_id: str, user_id: str, github_user_id: str + ) -> bool: + """Validate that conversation belongs to the current user.""" + # TODO: remove github_user_id after transition to Keycloak is complete. + metadata = await self.get_metadata(conversation_id) + if (not metadata.user_id and not metadata.github_user_id) or ( + metadata.user_id != user_id and metadata.github_user_id != github_user_id + ): + return False + else: + return True @abstractmethod async def delete_metadata(self, conversation_id: str) -> None: - """delete conversation metadata""" + """Delete conversation metadata.""" @abstractmethod async def exists(self, conversation_id: str) -> bool: - """Check if conversation exists""" + """Check if conversation exists.""" @abstractmethod async def search( @@ -49,6 +60,6 @@ class ConversationStore(ABC): @classmethod @abstractmethod async def get_instance( - cls, config: AppConfig, user_id: str | None + cls, config: AppConfig, user_id: str | None, github_user_id: str | None ) -> ConversationStore: """Get a store for the user represented by the token given""" diff --git a/openhands/storage/conversation/conversation_validator.py b/openhands/storage/conversation/conversation_validator.py index 51f293b395..63a177b0c2 100644 --- a/openhands/storage/conversation/conversation_validator.py +++ b/openhands/storage/conversation/conversation_validator.py @@ -7,7 +7,7 @@ class ConversationValidator: """Storage for conversation metadata. May or may not support multiple users depending on the environment.""" async def validate(self, conversation_id: str, cookies_str: str): - return None + return None, None conversation_validator_cls = os.environ.get( diff --git a/openhands/storage/conversation/file_conversation_store.py b/openhands/storage/conversation/file_conversation_store.py index ed18b7cd01..4ff56d8939 100644 --- a/openhands/storage/conversation/file_conversation_store.py +++ b/openhands/storage/conversation/file_conversation_store.py @@ -101,7 +101,7 @@ class FileConversationStore(ConversationStore): @classmethod async def get_instance( - cls, config: AppConfig, user_id: str | None + cls, config: AppConfig, user_id: str | None, github_user_id: str | None ) -> FileConversationStore: file_store = get_file_store(config.file_store, config.file_store_path) return FileConversationStore(file_store) diff --git a/openhands/storage/data_models/conversation_metadata.py b/openhands/storage/data_models/conversation_metadata.py index 15909e9b51..85f5070050 100644 --- a/openhands/storage/data_models/conversation_metadata.py +++ b/openhands/storage/data_models/conversation_metadata.py @@ -5,6 +5,7 @@ from datetime import datetime, timezone @dataclass class ConversationMetadata: conversation_id: str + user_id: str | None github_user_id: str | None selected_repository: str | None selected_branch: str | None = None diff --git a/tests/unit/test_conversation.py b/tests/unit/test_conversation.py index c38f09b2ad..0b47551048 100644 --- a/tests/unit/test_conversation.py +++ b/tests/unit/test_conversation.py @@ -32,6 +32,7 @@ def _patch_store(): 'selected_repository': 'foobar', 'conversation_id': 'some_conversation_id', 'github_user_id': '12345', + 'user_id': '12345', 'created_at': '2025-01-01T00:00:00+00:00', 'last_updated_at': '2025-01-01T00:01:00+00:00', } diff --git a/tests/unit/test_file_conversation_store.py b/tests/unit/test_file_conversation_store.py index 4e4172f452..cb1ecf1eca 100644 --- a/tests/unit/test_file_conversation_store.py +++ b/tests/unit/test_file_conversation_store.py @@ -13,7 +13,8 @@ async def test_load_store(): store = FileConversationStore(InMemoryFileStore({})) expected = ConversationMetadata( conversation_id='some-conversation-id', - github_user_id='some-user-id', + user_id='some-user-id', + github_user_id='12345', selected_repository='some-repo', title="Let's talk about trains", ) @@ -31,6 +32,7 @@ async def test_load_int_user_id(): { 'conversation_id': 'some-conversation-id', 'github_user_id': 12345, + 'user_id': '67890', 'selected_repository': 'some-repo', 'title': "Let's talk about trains", 'created_at': '2025-01-16T19:51:04.886331Z', @@ -41,6 +43,7 @@ async def test_load_int_user_id(): ) found = await store.get_metadata('some-conversation-id') assert found.github_user_id == '12345' + assert found.user_id == '67890' @pytest.mark.asyncio @@ -61,6 +64,7 @@ async def test_search_basic(): { 'conversation_id': 'conv1', 'github_user_id': '123', + 'user_id': '123', 'selected_repository': 'repo1', 'title': 'First conversation', 'created_at': '2025-01-16T19:51:04Z', @@ -70,6 +74,7 @@ async def test_search_basic(): { 'conversation_id': 'conv2', 'github_user_id': '123', + 'user_id': '123', 'selected_repository': 'repo1', 'title': 'Second conversation', 'created_at': '2025-01-17T19:51:04Z', @@ -79,6 +84,7 @@ async def test_search_basic(): { 'conversation_id': 'conv3', 'github_user_id': '123', + 'user_id': '123', 'selected_repository': 'repo1', 'title': 'Third conversation', 'created_at': '2025-01-15T19:51:04Z', @@ -107,6 +113,7 @@ async def test_search_pagination(): { 'conversation_id': f'conv{i}', 'github_user_id': '123', + 'user_id': '123', 'selected_repository': 'repo1', 'title': f'Conversation {i}', 'created_at': f'2025-01-{15+i}T19:51:04Z', @@ -148,6 +155,7 @@ async def test_search_with_invalid_conversation(): { 'conversation_id': 'conv1', 'github_user_id': '123', + 'user_id': '123', 'selected_repository': 'repo1', 'title': 'Valid conversation', 'created_at': '2025-01-16T19:51:04Z', @@ -176,6 +184,7 @@ async def test_get_all_metadata(): { 'conversation_id': 'conv1', 'github_user_id': '123', + 'user_id': '123', 'selected_repository': 'repo1', 'title': 'First conversation', 'created_at': '2025-01-16T19:51:04Z', @@ -185,6 +194,7 @@ async def test_get_all_metadata(): { 'conversation_id': 'conv2', 'github_user_id': '123', + 'user_id': '123', 'selected_repository': 'repo1', 'title': 'Second conversation', 'created_at': '2025-01-17T19:51:04Z', diff --git a/tests/unit/test_search_utils.py b/tests/unit/test_search_utils.py index cdf80d8040..8d1e634cad 100644 --- a/tests/unit/test_search_utils.py +++ b/tests/unit/test_search_utils.py @@ -49,6 +49,7 @@ async def test_iterate_single_page(): { 'conversation_id': 'conv1', 'github_user_id': '123', + 'user_id': '123', 'selected_repository': 'repo1', 'title': 'First conversation', 'created_at': '2025-01-16T19:51:04Z', @@ -58,6 +59,7 @@ async def test_iterate_single_page(): { 'conversation_id': 'conv2', 'github_user_id': '123', + 'user_id': '123', 'selected_repository': 'repo1', 'title': 'Second conversation', 'created_at': '2025-01-17T19:51:04Z', @@ -86,6 +88,7 @@ async def test_iterate_multiple_pages(): { 'conversation_id': f'conv{i}', 'github_user_id': '123', + 'user_id': '123', 'selected_repository': 'repo1', 'title': f'Conversation {i}', 'created_at': f'2025-01-{15+i}T19:51:04Z', @@ -120,6 +123,7 @@ async def test_iterate_with_invalid_conversation(): { 'conversation_id': 'conv1', 'github_user_id': '123', + 'user_id': '123', 'selected_repository': 'repo1', 'title': 'Valid conversation', 'created_at': '2025-01-16T19:51:04Z', diff --git a/tests/unit/test_standalone_conversation_manager.py b/tests/unit/test_standalone_conversation_manager.py index 2e49768aed..becc59b944 100644 --- a/tests/unit/test_standalone_conversation_manager.py +++ b/tests/unit/test_standalone_conversation_manager.py @@ -61,7 +61,7 @@ async def test_init_new_local_session(): 'new-session-id', ConversationInitData(), 1 ) await conversation_manager.join_conversation( - 'new-session-id', 'new-session-id', ConversationInitData(), 1 + 'new-session-id', 'new-session-id', ConversationInitData(), 1, '12345' ) assert session_instance.initialize_agent.call_count == 1 assert sio.enter_room.await_count == 1 @@ -93,10 +93,18 @@ async def test_join_local_session(): 'new-session-id', ConversationInitData(), None ) await conversation_manager.join_conversation( - 'new-session-id', 'new-session-id', ConversationInitData(), None + 'new-session-id', + 'new-session-id', + ConversationInitData(), + None, + '12345', ) await conversation_manager.join_conversation( - 'new-session-id', 'new-session-id', ConversationInitData(), None + 'new-session-id', + 'new-session-id', + ConversationInitData(), + None, + '12345', ) assert session_instance.initialize_agent.call_count == 1 assert sio.enter_room.await_count == 2 @@ -128,7 +136,7 @@ async def test_add_to_local_event_stream(): 'new-session-id', ConversationInitData(), 1 ) await conversation_manager.join_conversation( - 'new-session-id', 'connection-id', ConversationInitData(), 1 + 'new-session-id', 'connection-id', ConversationInitData(), 1, '12345' ) await conversation_manager.send_to_event_stream( 'connection-id', {'event_type': 'some_event'}