Clean up conversation joining (#7379)

This commit is contained in:
Robert Brennan 2025-03-21 09:18:37 -04:00 committed by GitHub
parent d9926d2491
commit 37188c7606
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 50 additions and 32 deletions

View File

@ -11,7 +11,6 @@ from openhands.core.exceptions import AgentRuntimeUnavailableError
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
from openhands.events.action import MessageAction
from openhands.events.observation.agent import AgentStateChangedObservation
from openhands.events.stream import EventStream, EventStreamSubscriber, session_exists
from openhands.server.config.server_config import ServerConfig
from openhands.server.monitoring import MonitoringListener
@ -116,27 +115,22 @@ class StandaloneConversationManager(ConversationManager):
settings: Settings,
user_id: str | None,
github_user_id: str | None,
):
) -> EventStream:
logger.info(
f'join_conversation:{sid}:{connection_id}',
extra={'session_id': sid, 'user_id': user_id},
)
await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
self._local_connection_id_to_session_id[connection_id] = sid
event_stream = await self._get_event_stream(sid, user_id)
event_stream = await self.maybe_start_agent_loop(
sid, settings, user_id, github_user_id=github_user_id
)
if not event_stream:
return await self.maybe_start_agent_loop(
sid, settings, user_id, github_user_id=github_user_id
logger.error(
f'No event stream after joining conversation: {sid}',
extra={'session_id': sid},
)
for event in event_stream.get_events(reverse=True):
if isinstance(event, AgentStateChangedObservation):
if event.agent_state in (
AgentState.STOPPED.value,
AgentState.ERROR.value,
):
await self.close_session(sid)
return await self.maybe_start_agent_loop(sid, settings, user_id)
break
raise RuntimeError(f'no_event_stream:{sid}')
return event_stream
async def detach_from_conversation(self, conversation: Conversation):

View File

@ -54,10 +54,13 @@ async def connect(connection_id: str, environ):
event_stream = await conversation_manager.join_conversation(
conversation_id, connection_id, settings, user_id, github_user_id
)
logger.info(
f'Connected to conversation {conversation_id} with connection_id {connection_id}. Replaying event stream...'
)
agent_state_changed = None
async_stream = AsyncEventStreamWrapper(event_stream, latest_event_id + 1)
async for event in async_stream:
logger.info(f'oh_event: {event.__class__.__name__}')
if isinstance(
event,
(NullAction, NullObservation, RecallAction, RecallObservation),
@ -69,6 +72,7 @@ async def connect(connection_id: str, environ):
await sio.emit('oh_event', event_to_dict(event), to=connection_id)
if agent_state_changed:
await sio.emit('oh_event', event_to_dict(agent_state_changed), to=connection_id)
logger.info(f'Finished replaying event stream for conversation {conversation_id}')
@sio.event

View File

@ -44,6 +44,8 @@ async def test_init_new_local_session():
sio = get_mock_sio()
get_running_agent_loops_mock = AsyncMock()
get_running_agent_loops_mock.return_value = set()
is_agent_loop_running_mock = AsyncMock()
is_agent_loop_running_mock.return_value = True
with (
patch(
'openhands.server.conversation_manager.standalone_conversation_manager.Session',
@ -60,9 +62,19 @@ async def test_init_new_local_session():
await conversation_manager.maybe_start_agent_loop(
'new-session-id', ConversationInitData(), 1
)
await conversation_manager.join_conversation(
'new-session-id', 'new-session-id', ConversationInitData(), 1, '12345'
)
with (
patch(
'openhands.server.conversation_manager.standalone_conversation_manager.StandaloneConversationManager.is_agent_loop_running',
is_agent_loop_running_mock,
),
):
await conversation_manager.join_conversation(
'new-session-id',
'new-session-id',
ConversationInitData(),
1,
'12345',
)
assert session_instance.initialize_agent.call_count == 1
assert sio.enter_room.await_count == 1
@ -76,6 +88,8 @@ async def test_join_local_session():
sio = get_mock_sio()
get_running_agent_loops_mock = AsyncMock()
get_running_agent_loops_mock.return_value = set()
is_agent_loop_running_mock = AsyncMock()
is_agent_loop_running_mock.return_value = True
with (
patch(
'openhands.server.conversation_manager.standalone_conversation_manager.Session',
@ -92,20 +106,26 @@ async def test_join_local_session():
await conversation_manager.maybe_start_agent_loop(
'new-session-id', ConversationInitData(), None
)
await conversation_manager.join_conversation(
'new-session-id',
'new-session-id',
ConversationInitData(),
None,
'12345',
)
await conversation_manager.join_conversation(
'new-session-id',
'new-session-id',
ConversationInitData(),
None,
'12345',
)
with (
patch(
'openhands.server.conversation_manager.standalone_conversation_manager.StandaloneConversationManager.is_agent_loop_running',
is_agent_loop_running_mock,
),
):
await conversation_manager.join_conversation(
'new-session-id',
'new-session-id',
ConversationInitData(),
None,
'12345',
)
await conversation_manager.join_conversation(
'new-session-id',
'new-session-id',
ConversationInitData(),
None,
'12345',
)
assert session_instance.initialize_agent.call_count == 1
assert sio.enter_room.await_count == 2