Feat: Allow checking multiple conversations running at the same time (#5843)
parent 69a9080480
commit 500598666e
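In short, the commit replaces the per-session cluster check (`_is_agent_loop_running_in_cluster(sid)`) with a batched query: `get_agent_loop_running(sids)` collects the conversations running locally, asks the rest of the cluster about the remainder in a single `is_session_running` message keyed by a `request_id`, and returns the set of sids that are running. A minimal usage sketch, assuming an already-initialized `SessionManager` named `session_manager` (the sid values are illustrative, not from the diff):

```python
# Hedged sketch of the new batched check; `session_manager` and the sids are assumptions.
async def report_running(session_manager, sids: set[str]) -> None:
    # One call covers sessions on this node and, via Redis pub/sub, on other nodes.
    running = await session_manager.get_agent_loop_running(sids)
    for sid in sorted(sids):
        status = 'running' if sid in running else 'not running'
        print(f'{sid}: {status}')
```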
@@ -2,6 +2,7 @@ import asyncio
import json
import time
from dataclasses import dataclass, field
from uuid import uuid4

import socketio

@@ -27,6 +28,14 @@ class ConversationDoesNotExistError(Exception):
    pass


@dataclass
class _SessionIsRunningCheck:
    request_id: str
    request_sids: list[str]
    running_sids: set[str] = field(default_factory=set)
    flag: asyncio.Event = field(default_factory=asyncio.Event)


@dataclass
class SessionManager:
    sio: socketio.AsyncServer
@@ -36,7 +45,9 @@ class SessionManager:
    local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict)
    _last_alive_timestamps: dict[str, float] = field(default_factory=dict)
    _redis_listen_task: asyncio.Task | None = None
    _session_is_running_flags: dict[str, asyncio.Event] = field(default_factory=dict)
    _session_is_running_checks: dict[str, _SessionIsRunningCheck] = field(
        default_factory=dict
    )
    _active_conversations: dict[str, tuple[Conversation, int]] = field(
        default_factory=dict
    )
@@ -97,27 +108,41 @@ class SessionManager:
    async def _process_message(self, message: dict):
        data = json.loads(message['data'])
        logger.debug(f'got_published_message:{message}')
        sid = data['sid']
        message_type = data['message_type']
        if message_type == 'event':
            sid = data['sid']
            session = self._local_agent_loops_by_sid.get(sid)
            if session:
                await session.dispatch(data['data'])
        elif message_type == 'is_session_running':
            # Another node in the cluster is asking if the current node is running the session given.
            session = self._local_agent_loops_by_sid.get(sid)
            if session:
            request_id = data['request_id']
            sids = [
                sid for sid in data['sids'] if sid in self._local_agent_loops_by_sid
            ]
            if sids:
                await self._get_redis_client().publish(
                    'oh_event',
                    json.dumps({'sid': sid, 'message_type': 'session_is_running'}),
                    json.dumps(
                        {
                            'request_id': request_id,
                            'sids': sids,
                            'message_type': 'session_is_running',
                        }
                    ),
                )
        elif message_type == 'session_is_running':
            self._last_alive_timestamps[sid] = time.time()
            flag = self._session_is_running_flags.get(sid)
            if flag:
                flag.set()
            request_id = data['request_id']
            for sid in data['sids']:
                self._last_alive_timestamps[sid] = time.time()
            check = self._session_is_running_checks.get(request_id)
            if check:
                check.running_sids.update(data['sids'])
                if len(check.request_sids) == len(check.running_sids):
                    check.flag.set()
        elif message_type == 'has_remote_connections_query':
            # Another node in the cluster is asking if the current node is connected to a session
            sid = data['sid']
            required = sid in self.local_connection_id_to_session_id.values()
            if required:
                await self._get_redis_client().publish(
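The `_process_message` handler above implements a small request/response exchange over the shared Redis `oh_event` channel: one node publishes an `is_session_running` query listing the conversation ids it wants to check, and every node running any of them replies with a `session_is_running` message carrying the same `request_id`. A sketch of the two payloads as they appear in this diff (the sid values are illustrative):

```python
import json
from uuid import uuid4

request_id = str(uuid4())

# Query published by the node that wants to know which conversations are running.
query = json.dumps(
    {
        'request_id': request_id,
        'sids': ['conversation-1', 'conversation-2'],
        'message_type': 'is_session_running',
    }
)

# Reply published by a node that is running a subset of the requested sids.
reply = json.dumps(
    {
        'request_id': request_id,
        'sids': ['conversation-1'],
        'message_type': 'session_is_running',
    }
)
```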
@@ -127,12 +152,14 @@ class SessionManager:
                ),
            )
        elif message_type == 'has_remote_connections_response':
            sid = data['sid']
            flag = self._has_remote_connections_flags.get(sid)
            if flag:
                flag.set()
        elif message_type == 'session_closing':
            # Session closing event - We only get this in the event of graceful shutdown,
            # which can't be guaranteed - nodes can simply vanish unexpectedly!
            sid = data['sid']
            logger.debug(f'session_closing:{sid}')
            for (
                connection_id,
@@ -234,33 +261,47 @@ class SessionManager:
                logger.warning('error_cleaning_detached_conversations', exc_info=True)
                await asyncio.sleep(_CLEANUP_EXCEPTION_WAIT_TIME)

    async def _is_agent_loop_running(self, sid: str) -> bool:
        if await self._is_agent_loop_running_locally(sid):
    async def get_agent_loop_running(self, sids: set[str]) -> set[str]:
        running_sids = set(sid for sid in sids if sid in self._local_agent_loops_by_sid)
        check_cluster_sids = [sid for sid in sids if sid not in running_sids]
        running_cluster_sids = await self.get_agent_loop_running_in_cluster(
            check_cluster_sids
        )
        running_sids.update(running_cluster_sids)
        return running_sids

    async def is_agent_loop_running(self, sid: str) -> bool:
        if await self.is_agent_loop_running_locally(sid):
            return True
        if await self._is_agent_loop_running_in_cluster(sid):
        if await self.is_agent_loop_running_in_cluster(sid):
            return True
        return False

    async def _is_agent_loop_running_locally(self, sid: str) -> bool:
        if self._local_agent_loops_by_sid.get(sid, None):
            return True
        return False
    async def is_agent_loop_running_locally(self, sid: str) -> bool:
        return sid in self._local_agent_loops_by_sid

    async def _is_agent_loop_running_in_cluster(self, sid: str) -> bool:
    async def is_agent_loop_running_in_cluster(self, sid: str) -> bool:
        running_sids = await self.get_agent_loop_running_in_cluster([sid])
        return bool(running_sids)

    async def get_agent_loop_running_in_cluster(self, sids: list[str]) -> set[str]:
"""As the rest of the cluster if a session is running. Wait a for a short timeout for a reply"""
|
||||
        redis_client = self._get_redis_client()
        if not redis_client:
            return False
            return set()

        flag = asyncio.Event()
        self._session_is_running_flags[sid] = flag
        request_id = str(uuid4())
        check = _SessionIsRunningCheck(request_id=request_id, request_sids=sids)
        self._session_is_running_checks[request_id] = check
        try:
            logger.debug(f'publish:is_session_running:{sid}')
            logger.debug(f'publish:is_session_running:{sids}')
            await redis_client.publish(
                'oh_event',
                json.dumps(
                    {
                        'sid': sid,
                        'request_id': request_id,
                        'sids': sids,
                        'message_type': 'is_session_running',
                    }
                ),
@@ -268,13 +309,12 @@ class SessionManager:
            async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
                await flag.wait()

            result = flag.is_set()
            return result
            return check.running_sids
        except TimeoutError:
            # Nobody replied in time
            return False
            return check.running_sids
        finally:
            self._session_is_running_flags.pop(sid, None)
            self._session_is_running_checks.pop(request_id, None)

    async def _has_remote_connections(self, sid: str) -> bool:
"""As the rest of the cluster if they still want this session running. Wait a for a short timeout for a reply"""
|
||||
@@ -307,7 +347,7 @@ class SessionManager:
    ) -> EventStream:
        logger.info(f'maybe_start_agent_loop:{sid}')
        session: Session | None = None
        if not await self._is_agent_loop_running(sid):
        if not await self.is_agent_loop_running(sid):
            logger.info(f'start_agent_loop:{sid}')
            session = Session(
                sid=sid, file_store=self.file_store, config=self.config, sio=self.sio
@@ -328,7 +368,7 @@ class SessionManager:
            logger.info(f'found_local_agent_loop:{sid}')
            return session.agent_session.event_stream

        if await self._is_agent_loop_running_in_cluster(sid):
        if await self.is_agent_loop_running_in_cluster(sid):
            logger.info(f'found_remote_agent_loop:{sid}')
            return EventStream(sid, self.file_store)

@@ -352,7 +392,7 @@ class SessionManager:
            next_alive_check = last_alive_at + _CHECK_ALIVE_INTERVAL
            if (
                next_alive_check > time.time()
                or await self._is_agent_loop_running_in_cluster(sid)
                or await self.is_agent_loop_running_in_cluster(sid)
            ):
                # Send the event to the other pod
                await redis_client.publish(

@@ -9,8 +9,8 @@ IN_MEMORY_FILES: dict = {}
class InMemoryFileStore(FileStore):
    files: dict[str, str]

    def __init__(self):
        self.files = IN_MEMORY_FILES
    def __init__(self, files: dict[str, str] = IN_MEMORY_FILES):
        self.files = files

    def write(self, path: str, contents: str) -> None:
        self.files[path] = contents

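The `InMemoryFileStore` change above lets callers supply their own backing dict instead of the shared module-level `IN_MEMORY_FILES`, which the updated tests below use (`InMemoryFileStore({})`) to keep state isolated between test cases. A small sketch of the difference, using the constructor and `write` signature shown in this diff:

```python
from openhands.storage.memory import InMemoryFileStore

# Default constructor: all instances share the module-level IN_MEMORY_FILES dict,
# so data written through one store is visible through another.
shared_a = InMemoryFileStore()
shared_b = InMemoryFileStore()
shared_a.write('greeting.txt', 'hello')

# Passing an explicit dict gives the store private state, which is what the
# updated tests do with InMemoryFileStore({}).
isolated = InMemoryFileStore({})
isolated.write('greeting.txt', 'hello')
```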
@@ -20,6 +20,7 @@ from openhands.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.runtime.base import Runtime
from openhands.storage import get_file_store
from openhands.storage.memory import InMemoryFileStore


@pytest.fixture
@@ -168,7 +169,7 @@ async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
@pytest.mark.asyncio
async def test_run_controller_stop_with_stuck():
    config = AppConfig()
    file_store = get_file_store(config.file_store, config.file_store_path)
    file_store = InMemoryFileStore({})
    event_stream = EventStream(sid='test', file_store=file_store)

    agent = MagicMock(spec=Agent)

@@ -2,6 +2,7 @@ import asyncio
import json
from dataclasses import dataclass
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4

import pytest

@@ -35,44 +36,56 @@ def get_mock_sio(get_message: GetMessageMock | None = None):
@pytest.mark.asyncio
async def test_session_not_running_in_cluster():
    sio = get_mock_sio()
    id = uuid4()
    with (
        patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
        patch('openhands.server.session.manager.uuid4', MagicMock(return_value=id)),
    ):
        async with SessionManager(
            sio, AppConfig(), InMemoryFileStore()
        ) as session_manager:
            result = await session_manager._is_agent_loop_running_in_cluster(
            result = await session_manager.is_agent_loop_running_in_cluster(
                'non-existant-session'
            )
            assert result is False
            assert sio.manager.redis.publish.await_count == 1
            sio.manager.redis.publish.assert_called_once_with(
                'oh_event',
                '{"sid": "non-existant-session", "message_type": "is_session_running"}',
                '{"request_id": "'
                + str(id)
                + '", "sids": ["non-existant-session"], "message_type": "is_session_running"}',
            )


@pytest.mark.asyncio
async def test_session_is_running_in_cluster():
    id = uuid4()
    sio = get_mock_sio(
        GetMessageMock(
            {'sid': 'existing-session', 'message_type': 'session_is_running'}
            {
                'request_id': str(id),
                'sids': ['existing-session'],
                'message_type': 'session_is_running',
            }
        )
    )
    with (
        patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1),
        patch('openhands.server.session.manager.uuid4', MagicMock(return_value=id)),
    ):
        async with SessionManager(
            sio, AppConfig(), InMemoryFileStore()
        ) as session_manager:
            result = await session_manager._is_agent_loop_running_in_cluster(
            result = await session_manager.is_agent_loop_running_in_cluster(
                'existing-session'
            )
            assert result is True
            assert sio.manager.redis.publish.await_count == 1
            sio.manager.redis.publish.assert_called_once_with(
                'oh_event',
                '{"sid": "existing-session", "message_type": "is_session_running"}',
                '{"request_id": "'
                + str(id)
                + '", "sids": ["existing-session"], "message_type": "is_session_running"}',
            )


@@ -93,7 +106,7 @@ async def test_init_new_local_session():
            AsyncMock(),
        ),
        patch(
            'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
            'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
            is_agent_loop_running_in_cluster_mock,
        ),
    ):
@@ -125,7 +138,7 @@ async def test_join_local_session():
            AsyncMock(),
        ),
        patch(
            'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
            'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
            is_agent_loop_running_in_cluster_mock,
        ),
    ):
@@ -158,7 +171,7 @@ async def test_join_cluster_session():
            AsyncMock(),
        ),
        patch(
            'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
            'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
            is_agent_loop_running_in_cluster_mock,
        ),
    ):
@@ -187,7 +200,7 @@ async def test_add_to_local_event_stream():
            AsyncMock(),
        ),
        patch(
            'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
            'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
            is_agent_loop_running_in_cluster_mock,
        ),
    ):
@@ -221,7 +234,7 @@ async def test_add_to_cluster_event_stream():
            AsyncMock(),
        ),
        patch(
            'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
            'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster',
            is_agent_loop_running_in_cluster_mock,
        ),
    ):