diff --git a/enterprise/storage/saas_conversation_store.py b/enterprise/storage/saas_conversation_store.py index 7eb4919ad9..eec6961d02 100644 --- a/enterprise/storage/saas_conversation_store.py +++ b/enterprise/storage/saas_conversation_store.py @@ -234,10 +234,8 @@ class SaasConversationStore(ConversationStore): cls, config: OpenHandsConfig, user_id: str | None ) -> ConversationStore: # user_id should not be None in SaaS, should we raise? - # Use sync version because this method can be called from call_async_from_sync - # (e.g., from _create_conversation_update_callback in standalone_conversation_manager.py) - # which creates a new event loop. Using async DB operations in that context would - # cause asyncpg connection errors since connections are tied to the original event loop. - user = UserStore.get_user_by_id(user_id) + # Use async version since callers now use asyncio.run_coroutine_threadsafe() + # to dispatch to the main event loop where asyncpg connections work properly. + user = await UserStore.get_user_by_id_async(user_id) org_id = user.current_org_id if user else None return SaasConversationStore(str(user_id), org_id, session_maker) diff --git a/enterprise/tests/unit/test_saas_conversation_store.py b/enterprise/tests/unit/test_saas_conversation_store.py index a7d000d0ec..4d59c1227f 100644 --- a/enterprise/tests/unit/test_saas_conversation_store.py +++ b/enterprise/tests/unit/test_saas_conversation_store.py @@ -1,5 +1,5 @@ from datetime import UTC, datetime -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch from uuid import UUID import pytest @@ -169,15 +169,14 @@ async def test_exists(session_maker): class TestGetInstance: """Tests for SaasConversationStore.get_instance method. - The get_instance method uses sync UserStore.get_user_by_id (not async) - because it can be called from call_async_from_sync contexts which create - a new event loop. Using async DB operations in that context would cause - asyncpg connection errors. + The get_instance method uses async UserStore.get_user_by_id_async because + callers now use asyncio.run_coroutine_threadsafe() to dispatch to the main + event loop where asyncpg connections work properly. """ @pytest.mark.asyncio - async def test_get_instance_uses_sync_get_user_by_id(self): - """Verify get_instance calls the sync get_user_by_id, not async version.""" + async def test_get_instance_uses_async_get_user_by_id(self): + """Verify get_instance calls the async get_user_by_id_async for proper event loop handling.""" # Arrange user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081' mock_user = MagicMock(spec=User) @@ -185,14 +184,16 @@ class TestGetInstance: mock_config = MagicMock(spec=OpenHandsConfig) with patch( - 'storage.saas_conversation_store.UserStore.get_user_by_id', - return_value=mock_user, - ) as mock_sync_get_user, patch('storage.saas_conversation_store.session_maker'): + 'storage.saas_conversation_store.UserStore.get_user_by_id_async', + AsyncMock(return_value=mock_user), + ) as mock_async_get_user, patch( + 'storage.saas_conversation_store.session_maker' + ): # Act store = await SaasConversationStore.get_instance(mock_config, user_id) # Assert - mock_sync_get_user.assert_called_once_with(user_id) + mock_async_get_user.assert_called_once_with(user_id) assert store.user_id == user_id assert store.org_id == mock_user.current_org_id @@ -204,8 +205,8 @@ class TestGetInstance: mock_config = MagicMock(spec=OpenHandsConfig) with patch( - 'storage.saas_conversation_store.UserStore.get_user_by_id', - return_value=None, + 'storage.saas_conversation_store.UserStore.get_user_by_id_async', + AsyncMock(return_value=None), ), patch('storage.saas_conversation_store.session_maker'): # Act store = await SaasConversationStore.get_instance(mock_config, user_id) diff --git a/openhands/server/conversation_manager/standalone_conversation_manager.py b/openhands/server/conversation_manager/standalone_conversation_manager.py index 5e804cdea3..a4b2efad0a 100644 --- a/openhands/server/conversation_manager/standalone_conversation_manager.py +++ b/openhands/server/conversation_manager/standalone_conversation_manager.py @@ -41,8 +41,6 @@ from openhands.storage.data_models.conversation_status import ConversationStatus from openhands.storage.data_models.settings import Settings from openhands.storage.files import FileStore from openhands.utils.async_utils import ( - GENERAL_TIMEOUT, - call_async_from_sync, call_sync_from_async, run_in_loop, wait_all, @@ -370,7 +368,11 @@ class StandaloneConversationManager(ConversationManager): session.agent_session.event_stream.subscribe( EventStreamSubscriber.SERVER, self._create_conversation_update_callback( - user_id, sid, settings, session.llm_registry + user_id, + sid, + settings, + session.llm_registry, + asyncio.get_running_loop(), ), UPDATED_AT_CALLBACK_ID, ) @@ -618,23 +620,28 @@ class StandaloneConversationManager(ConversationManager): conversation_id: str, settings: Settings, llm_registry: LLMRegistry, + loop: asyncio.AbstractEventLoop, ) -> Callable: def callback(event, *args, **kwargs): - call_async_from_sync( - self._update_conversation_for_event, - GENERAL_TIMEOUT, - user_id, - conversation_id, - settings, - llm_registry, - event, - ) + try: + asyncio.run_coroutine_threadsafe( + self._update_conversation_for_event( + user_id, + conversation_id, + settings, + llm_registry, + event, + ), + loop, + ) + except Exception as e: + logger.error(f'Error in conversation update callback: {e}') return callback async def _update_conversation_for_event( self, - user_id: str, + user_id: str | None, conversation_id: str, settings: Settings, llm_registry: LLMRegistry,