mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
fix(backend): use run_coroutine_threadsafe for conversation update callbacks (#13134)
This commit is contained in:
@@ -234,10 +234,8 @@ class SaasConversationStore(ConversationStore):
|
||||
cls, config: OpenHandsConfig, user_id: str | None
|
||||
) -> ConversationStore:
|
||||
# user_id should not be None in SaaS, should we raise?
|
||||
# Use sync version because this method can be called from call_async_from_sync
|
||||
# (e.g., from _create_conversation_update_callback in standalone_conversation_manager.py)
|
||||
# which creates a new event loop. Using async DB operations in that context would
|
||||
# cause asyncpg connection errors since connections are tied to the original event loop.
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
# Use async version since callers now use asyncio.run_coroutine_threadsafe()
|
||||
# to dispatch to the main event loop where asyncpg connections work properly.
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
org_id = user.current_org_id if user else None
|
||||
return SaasConversationStore(str(user_id), org_id, session_maker)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
@@ -169,15 +169,14 @@ async def test_exists(session_maker):
|
||||
class TestGetInstance:
|
||||
"""Tests for SaasConversationStore.get_instance method.
|
||||
|
||||
The get_instance method uses sync UserStore.get_user_by_id (not async)
|
||||
because it can be called from call_async_from_sync contexts which create
|
||||
a new event loop. Using async DB operations in that context would cause
|
||||
asyncpg connection errors.
|
||||
The get_instance method uses async UserStore.get_user_by_id_async because
|
||||
callers now use asyncio.run_coroutine_threadsafe() to dispatch to the main
|
||||
event loop where asyncpg connections work properly.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_instance_uses_sync_get_user_by_id(self):
|
||||
"""Verify get_instance calls the sync get_user_by_id, not async version."""
|
||||
async def test_get_instance_uses_async_get_user_by_id(self):
|
||||
"""Verify get_instance calls the async get_user_by_id_async for proper event loop handling."""
|
||||
# Arrange
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
mock_user = MagicMock(spec=User)
|
||||
@@ -185,14 +184,16 @@ class TestGetInstance:
|
||||
mock_config = MagicMock(spec=OpenHandsConfig)
|
||||
|
||||
with patch(
|
||||
'storage.saas_conversation_store.UserStore.get_user_by_id',
|
||||
return_value=mock_user,
|
||||
) as mock_sync_get_user, patch('storage.saas_conversation_store.session_maker'):
|
||||
'storage.saas_conversation_store.UserStore.get_user_by_id_async',
|
||||
AsyncMock(return_value=mock_user),
|
||||
) as mock_async_get_user, patch(
|
||||
'storage.saas_conversation_store.session_maker'
|
||||
):
|
||||
# Act
|
||||
store = await SaasConversationStore.get_instance(mock_config, user_id)
|
||||
|
||||
# Assert
|
||||
mock_sync_get_user.assert_called_once_with(user_id)
|
||||
mock_async_get_user.assert_called_once_with(user_id)
|
||||
assert store.user_id == user_id
|
||||
assert store.org_id == mock_user.current_org_id
|
||||
|
||||
@@ -204,8 +205,8 @@ class TestGetInstance:
|
||||
mock_config = MagicMock(spec=OpenHandsConfig)
|
||||
|
||||
with patch(
|
||||
'storage.saas_conversation_store.UserStore.get_user_by_id',
|
||||
return_value=None,
|
||||
'storage.saas_conversation_store.UserStore.get_user_by_id_async',
|
||||
AsyncMock(return_value=None),
|
||||
), patch('storage.saas_conversation_store.session_maker'):
|
||||
# Act
|
||||
store = await SaasConversationStore.get_instance(mock_config, user_id)
|
||||
|
||||
@@ -41,8 +41,6 @@ from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.utils.async_utils import (
|
||||
GENERAL_TIMEOUT,
|
||||
call_async_from_sync,
|
||||
call_sync_from_async,
|
||||
run_in_loop,
|
||||
wait_all,
|
||||
@@ -370,7 +368,11 @@ class StandaloneConversationManager(ConversationManager):
|
||||
session.agent_session.event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER,
|
||||
self._create_conversation_update_callback(
|
||||
user_id, sid, settings, session.llm_registry
|
||||
user_id,
|
||||
sid,
|
||||
settings,
|
||||
session.llm_registry,
|
||||
asyncio.get_running_loop(),
|
||||
),
|
||||
UPDATED_AT_CALLBACK_ID,
|
||||
)
|
||||
@@ -618,23 +620,28 @@ class StandaloneConversationManager(ConversationManager):
|
||||
conversation_id: str,
|
||||
settings: Settings,
|
||||
llm_registry: LLMRegistry,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
) -> Callable:
|
||||
def callback(event, *args, **kwargs):
|
||||
call_async_from_sync(
|
||||
self._update_conversation_for_event,
|
||||
GENERAL_TIMEOUT,
|
||||
user_id,
|
||||
conversation_id,
|
||||
settings,
|
||||
llm_registry,
|
||||
event,
|
||||
)
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._update_conversation_for_event(
|
||||
user_id,
|
||||
conversation_id,
|
||||
settings,
|
||||
llm_registry,
|
||||
event,
|
||||
),
|
||||
loop,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'Error in conversation update callback: {e}')
|
||||
|
||||
return callback
|
||||
|
||||
async def _update_conversation_for_event(
|
||||
self,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
conversation_id: str,
|
||||
settings: Settings,
|
||||
llm_registry: LLMRegistry,
|
||||
|
||||
Reference in New Issue
Block a user