import asyncio
import json
import time
from dataclasses import dataclass, field
from uuid import uuid4

import socketio
from server.logger import logger
from server.utils.conversation_callback_utils import invoke_conversation_callbacks
from storage.database import session_maker
from storage.saas_settings_store import SaasSettingsStore
from storage.stored_conversation_metadata import StoredConversationMetadata

from openhands.core.config import LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.core.config.utils import load_openhands_config
from openhands.core.schema.agent import AgentState
from openhands.events.action import MessageAction
from openhands.events.event_store import EventStore
from openhands.events.event_store_abc import EventStoreABC
from openhands.events.observation import AgentStateChangedObservation
from openhands.events.stream import EventStreamSubscriber
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.runtime_status import RuntimeStatus
from openhands.server.config.server_config import ServerConfig
from openhands.server.conversation_manager.conversation_manager import (
    ConversationManager,
)
from openhands.server.conversation_manager.standalone_conversation_manager import (
    StandaloneConversationManager,
)
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
from openhands.server.monitoring import MonitoringListener
from openhands.server.session.agent_session import WAIT_TIME_BEFORE_CLOSE
from openhands.server.session.session import Session
from openhands.server.settings import Settings
from openhands.storage.files import FileStore
from openhands.utils.async_utils import call_sync_from_async, wait_all
from openhands.utils.shutdown_listener import should_continue

# Time in seconds between cleanup operations for stale conversations
_CLEANUP_INTERVAL_SECONDS = 15

# Time in seconds before a Redis entry is considered expired if not refreshed
_REDIS_ENTRY_TIMEOUT_SECONDS = 15

# Time in seconds between updates to Redis entries
_REDIS_UPDATE_INTERVAL_SECONDS = 5

# Time in seconds to wait for a reply on the Redis pub/sub channel
# (used when requesting an LLM completion from a remote node)
_REDIS_POLL_TIMEOUT = 0.15


@dataclass
class _LLMResponseRequest:
    query_id: str
    response: str | None
    flag: asyncio.Event


@dataclass
class ClusteredConversationManager(StandaloneConversationManager):
    """Manages conversations in clustered mode (multiple server instances with Redis).

    This class extends StandaloneConversationManager to provide distributed
    conversation management across multiple server instances, using Redis as a
    communication channel and state store. It handles:

    - Cross-server message passing via Redis pub/sub
    - Tracking of conversations and connections across the cluster
    - Graceful recovery from server failures
    - Enforcement of conversation limits across the cluster
    - Cleanup of stale conversations and connections

    The Redis communication uses several key patterns:
    - ohcnv:{user_id}:{conversation_id} - Marks a conversation as active
    - ohcnct:{user_id}:{conversation_id}:{connection_id} - Tracks connections to conversations
    """

    _redis_listen_task: asyncio.Task | None = field(default=None)
    _redis_update_task: asyncio.Task | None = field(default=None)

    _llm_responses: dict[str, _LLMResponseRequest] = field(default_factory=dict)
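
    # Illustrative key layout (IDs here are hypothetical), matching the
    # patterns documented above. Each key expires after
    # _REDIS_ENTRY_TIMEOUT_SECONDS unless refreshed by the heartbeat task:
    #   ohcnv:user-123:conv-abc            -> '1'
    #   ohcnct:user-123:conv-abc:conn-xyz  -> '1'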

    def __post_init__(self):
        # We increment max_concurrent_conversations by 1 because this class
        # marks the conversation as started in Redis before checking the number
        # of running conversations. This prevents race conditions where multiple
        # servers might simultaneously start new conversations.
        self.config.max_concurrent_conversations += 1

    async def __aenter__(self):
        await super().__aenter__()
        self._redis_update_task = asyncio.create_task(
            self._update_state_in_redis_task()
        )
        self._redis_listen_task = asyncio.create_task(self._redis_subscribe())
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        if self._redis_update_task:
            self._redis_update_task.cancel()
            self._redis_update_task = None
        if self._redis_listen_task:
            self._redis_listen_task.cancel()
            self._redis_listen_task = None
        await super().__aexit__(exc_type, exc_value, traceback)

    async def _redis_subscribe(self):
        """Subscribe to Redis messages for cross-server communication.

        This method creates a Redis pub/sub subscription to receive messages
        from other server instances. It runs in a continuous loop until
        cancelled.
        """
        logger.debug('_redis_subscribe')
        redis_client = self._get_redis_client()
        pubsub = redis_client.pubsub()
        await pubsub.subscribe('session_msg')
        while should_continue():
            try:
                message = await pubsub.get_message(
                    ignore_subscribe_messages=True, timeout=5
                )
                if message:
                    await self._process_message(message)
            except asyncio.CancelledError:
                logger.debug('redis_subscribe_cancelled')
                return
            except Exception as e:
                try:
                    asyncio.get_running_loop()
                    logger.exception(f'error_reading_from_redis:{str(e)}')
                except RuntimeError:
                    # Loop has been shut down, exit gracefully
                    return
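
    # All cross-node traffic flows over the single 'session_msg' pub/sub
    # channel as JSON envelopes keyed by 'message_type'. A sketch of the
    # shapes handled below (IDs and values are hypothetical):
    #   {'message_type': 'event', 'sid': 'conv-abc', 'data': {...}}
    #   {'message_type': 'close_session', 'sid': 'conv-abc'}
    #   {'message_type': 'session_closing', 'sid': 'conv-abc', 'user_id': 'user-1'}
    #   {'message_type': 'llm_completion', 'sid': 'conv-abc', 'service_id': '...',
    #    'llm_config': {...}, 'messages': [...], 'query_id': '<uuid4>'}
    #   {'message_type': 'llm_completion_response', 'query_id': '<uuid4>',
    #    'response': '...'}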

    async def _process_message(self, message: dict):
        """Process messages received from Redis pub/sub.

        Handles five types of messages:
        - 'event': Forward an event to a local session
        - 'close_session': Close a local session
        - 'session_closing': Handle remote session closure
        - 'llm_completion': Serve an LLM completion request for a local session
        - 'llm_completion_response': Resolve a pending completion request

        Args:
            message: The Redis pub/sub message containing the action to perform
        """
        data = json.loads(message['data'])
        logger.debug(f'got_published_message:{message}')
        message_type = data['message_type']

        if message_type == 'event':
            # Forward an event to a local session if it exists
            sid = data['sid']
            session = self._local_agent_loops_by_sid.get(sid)
            if session:
                await session.dispatch(data['data'])
        elif message_type == 'close_session':
            # Close a local session if it exists
            sid = data['sid']
            if sid in self._local_agent_loops_by_sid:
                await self._close_session(sid)
        elif message_type == 'session_closing':
            # Handle connections to a session that is closing on another node.
            # We only get this in the event of a graceful shutdown, which can't
            # be guaranteed - nodes can simply vanish unexpectedly!
            sid = data['sid']
            user_id = data['user_id']
            logger.debug(f'session_closing:{sid}')

            # Create a list of items to process to avoid modifying the dict during iteration
            items = list(self._local_connection_id_to_session_id.items())
            for connection_id, local_sid in items:
                if sid == local_sid:
                    logger.warning(
                        f'local_connection_to_closing_session:{connection_id}:{sid}'
                    )
                    await self._handle_remote_conversation_stopped(
                        user_id, connection_id
                    )
        elif message_type == 'llm_completion':
            # Request an extraneous LLM completion from the session's LLM registry
            sid = data['sid']
            service_id = data['service_id']
            messages = data['messages']
            # The config was dumped to plain JSON for transport; rebuild the
            # model (assumes LLMConfig is a Pydantic model)
            llm_config = LLMConfig(**data['llm_config'])
            query_id = data['query_id']

            session = self._local_agent_loops_by_sid.get(sid)
            if session:
                llm_registry: LLMRegistry = session.llm_registry
                response = await call_sync_from_async(
                    llm_registry.request_extraneous_completion,
                    service_id,
                    llm_config,
                    messages,
                )
                await self._get_redis_client().publish(
                    'session_msg',
                    json.dumps(
                        {
                            'query_id': query_id,
                            'response': response,
                            'message_type': 'llm_completion_response',
                        }
                    ),
                )
        elif message_type == 'llm_completion_response':
            query_id = data['query_id']
            llm_response = self._llm_responses.get(query_id)
            if llm_response:
                llm_response.response = data['response']
                llm_response.flag.set()
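
    # A note on the last two branches above: 'llm_completion' is served by the
    # node hosting the session, which publishes an 'llm_completion_response'
    # carrying the same query_id; request_llm_completion() below uses that id
    # to find the pending _LLMResponseRequest and wake the waiting caller.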

    def _get_redis_client(self):
        return getattr(self.sio.manager, 'redis', None)

    def _get_redis_conversation_key(self, user_id: str | None, conversation_id: str):
        return f'ohcnv:{user_id}:{conversation_id}'

    def _get_redis_connection_key(
        self, user_id: str, conversation_id: str, connection_id: str
    ):
        return f'ohcnct:{user_id}:{conversation_id}:{connection_id}'
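
    # Example (hypothetical IDs): _get_redis_conversation_key('user-1', '*')
    # yields the glob 'ohcnv:user-1:*', which the scan_iter calls below use to
    # enumerate live conversations. Parsing these keys with split(':') assumes
    # user and conversation IDs never contain ':' themselves.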

    async def _get_event_store(
        self, sid: str, user_id: str | None
    ) -> EventStoreABC | None:
        session = self._local_agent_loops_by_sid.get(sid)
        if session:
            logger.debug('found_local_agent_loop', extra={'sid': sid})
            return session.agent_session.event_stream

        redis = self._get_redis_client()
        key = self._get_redis_conversation_key(user_id, sid)
        value = await redis.get(key)
        if value:
            logger.debug('found_remote_agent_loop', extra={'sid': sid})
            return EventStore(sid, self.file_store, user_id)

        return None

    async def get_running_agent_loops(
        self, user_id: str | None = None, filter_to_sids: set[str] | None = None
    ) -> set[str]:
        sids = await self.get_running_agent_loops_locally(user_id, filter_to_sids)
        if not filter_to_sids or len(sids) != len(filter_to_sids):
            remote_sids = await self._get_running_agent_loops_remotely(
                user_id, filter_to_sids
            )
            sids = sids.union(remote_sids)
        return sids

    async def get_running_agent_loops_locally(
        self, user_id: str | None = None, filter_to_sids: set[str] | None = None
    ) -> set[str]:
        sids = await super().get_running_agent_loops(user_id, filter_to_sids)
        return sids

    async def _get_running_agent_loops_remotely(
        self,
        user_id: str | None = None,
        filter_to_sids: set[str] | None = None,
    ) -> set[str]:
        """Get the set of conversation IDs running on remote servers.

        Args:
            user_id: Optional user ID to filter conversations by
            filter_to_sids: Optional set of conversation IDs to filter by

        Returns:
            A set of conversation IDs running on remote servers
        """
        if filter_to_sids is not None and not filter_to_sids:
            return set()
        if user_id:
            pattern = self._get_redis_conversation_key(user_id, '*')
        else:
            pattern = self._get_redis_conversation_key('*', '*')
        redis = self._get_redis_client()
        result = set()
        async for key in redis.scan_iter(pattern):
            conversation_id = key.decode().split(':')[2]
            if filter_to_sids is None or conversation_id in filter_to_sids:
                result.add(conversation_id)
        return result
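
    # "Running" here means the conversation's Redis key has not yet expired,
    # so results can lag reality by up to _REDIS_ENTRY_TIMEOUT_SECONDS after
    # a remote node dies without cleaning up.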

    async def get_connections(
        self, user_id: str | None = None, filter_to_sids: set[str] | None = None
    ) -> dict[str, str]:
        connections = await super().get_connections(user_id, filter_to_sids)
        if not filter_to_sids or len(connections) != len(filter_to_sids):
            remote_connections = await self._get_connections_remotely(
                user_id, filter_to_sids
            )
            connections.update(remote_connections)
        return connections

    async def _get_connections_remotely(
        self,
        user_id: str | None = None,
        filter_to_sids: set[str] | None = None,
    ) -> dict[str, str]:
        if filter_to_sids is not None and not filter_to_sids:
            return {}
        if user_id:
            pattern = self._get_redis_connection_key(user_id, '*', '*')
        else:
            pattern = self._get_redis_connection_key('*', '*', '*')
        redis = self._get_redis_client()
        result = {}
        async for key in redis.scan_iter(pattern):
            parts = key.decode().split(':')
            conversation_id = parts[2]
            connection_id = parts[3]
            if filter_to_sids is None or conversation_id in filter_to_sids:
                result[connection_id] = conversation_id
        return result

    async def send_to_event_stream(self, connection_id: str, data: dict) -> None:
        sid = self._local_connection_id_to_session_id.get(connection_id)
        if sid:
            await self.send_event_to_conversation(sid, data)

    async def request_llm_completion(
        self,
        sid: str,
        service_id: str,
        llm_config: LLMConfig,
        messages: list[dict[str, str]],
    ) -> str:
        session = self._local_agent_loops_by_sid.get(sid)
        if session:
            llm_registry = session.llm_registry
            return llm_registry.request_extraneous_completion(
                service_id, llm_config, messages
            )

        flag = asyncio.Event()
        query_id = str(uuid4())
        query = _LLMResponseRequest(query_id=query_id, response=None, flag=flag)
        self._llm_responses[query_id] = query

        try:
            redis_client = self._get_redis_client()
            await redis_client.publish(
                'session_msg',
                json.dumps(
                    {
                        'message_type': 'llm_completion',
                        'query_id': query_id,
                        'sid': sid,
                        'service_id': service_id,
                        # Assumes LLMConfig is a Pydantic model whose fields
                        # survive a JSON round trip; json.dumps cannot
                        # serialize the model object itself.
                        'llm_config': llm_config.model_dump(mode='json'),
                        # The handler in _process_message reads 'messages'
                        'messages': messages,
                    }
                ),
            )

            async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
                await flag.wait()

            if query.response:
                return query.response

            raise Exception('Failed to perform LLM completion')
        except TimeoutError:
            raise Exception('Timeout occurred')
        finally:
            # Drop the pending entry so timed-out queries don't accumulate
            self._llm_responses.pop(query_id, None)
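
    # Usage sketch (hypothetical values): if the session lives on another
    # node, this round-trips through Redis and raises if no node answers
    # within _REDIS_POLL_TIMEOUT:
    #   text = await manager.request_llm_completion(
    #       sid='conv-abc',
    #       service_id='my-service',
    #       llm_config=llm_config,
    #       messages=[{'role': 'user', 'content': 'Summarize this diff'}],
    #   )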

    async def send_event_to_conversation(self, sid: str, data: dict):
        if not sid:
            return
        session = self._local_agent_loops_by_sid.get(sid)
        if session:
            await session.dispatch(data)
        else:
            # The session is running on another node
            redis_client = self._get_redis_client()
            await redis_client.publish(
                'session_msg',
                json.dumps({'message_type': 'event', 'sid': sid, 'data': data}),
            )

    async def close_session(self, sid: str):
        # Send a message telling other nodes to close this session if they
        # have the agent loop, and to close any connections to it
        redis_client = self._get_redis_client()
        await redis_client.publish(
            'session_msg',
            json.dumps({'message_type': 'close_session', 'sid': sid}),
        )
        await self._close_session(sid)
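
    # Note: this node is itself subscribed to 'session_msg', so it also
    # receives the 'close_session' message it just published. The duplicate is
    # harmless: _close_session pops the session and returns early when it is
    # already gone.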

    async def maybe_start_agent_loop(
        self,
        sid: str,
        settings: Settings,
        user_id: str | None,
        initial_user_msg: MessageAction | None = None,
        replay_json: str | None = None,
    ) -> AgentLoopInfo:
        # If we can set the key in Redis, no other worker is running this conversation
        redis = self._get_redis_client()
        key = self._get_redis_conversation_key(user_id, sid)
        created = await redis.set(key, 1, nx=True, ex=_REDIS_ENTRY_TIMEOUT_SECONDS)
        if created:
            await self._start_agent_loop(
                sid, settings, user_id, initial_user_msg, replay_json
            )

        event_store = await self._get_event_store(sid, user_id)
        if not event_store:
            logger.error(
                f'No event stream after starting agent loop: {sid}',
                extra={'sid': sid},
            )
            raise RuntimeError(f'no_event_stream:{sid}')

        return AgentLoopInfo(
            conversation_id=sid,
            url=self._get_conversation_url(sid),
            session_api_key=None,
            event_store=event_store,
        )
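
    # The SET NX + TTL in maybe_start_agent_loop acts as a lightweight
    # distributed lock: exactly one worker wins the right to host a given
    # conversation, and the key expires on its own if that worker dies before
    # _update_state_in_redis refreshes it.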

    async def _update_state_in_redis_task(self):
        while should_continue():
            try:
                await self._update_state_in_redis()
                await asyncio.sleep(_REDIS_UPDATE_INTERVAL_SECONDS)
            except asyncio.CancelledError:
                return
            except Exception:
                try:
                    asyncio.get_running_loop()
                    logger.exception('error_updating_redis_state')
                except RuntimeError:
                    return  # Loop has been shut down

    async def _update_state_in_redis(self):
        """Refresh all entries in Redis to maintain conversation state across the cluster.

        This method:
        1. Scans Redis for all conversation keys to build a mapping of conversation IDs to user IDs
        2. Updates Redis entries for all local conversations to prevent them from expiring
        3. Updates Redis entries for all local connections to prevent them from expiring

        This is critical for maintaining the distributed state and allowing
        other servers to detect when a server has gone down unexpectedly.
        """
        redis = self._get_redis_client()

        # Build a mapping of conversation_id -> user_id from existing Redis keys
        pattern = self._get_redis_conversation_key('*', '*')
        conversation_user_ids = {}
        async for key in redis.scan_iter(pattern):
            parts = key.decode().split(':')
            conversation_user_ids[parts[2]] = parts[1]

        pipe = redis.pipeline()

        # Add multiple commands to the pipeline
        # First, update all local agent loops
        for sid, session in self._local_agent_loops_by_sid.items():
            if sid:
                await pipe.set(
                    self._get_redis_conversation_key(session.user_id, sid),
                    1,
                    ex=_REDIS_ENTRY_TIMEOUT_SECONDS,
                )

        # Then, update all local connections
        for (
            connection_id,
            conversation_id,
        ) in self._local_connection_id_to_session_id.items():
            user_id = conversation_user_ids.get(conversation_id)
            if user_id:
                await pipe.set(
                    self._get_redis_connection_key(
                        user_id, conversation_id, connection_id
                    ),
                    1,
                    ex=_REDIS_ENTRY_TIMEOUT_SECONDS,
                )

        # Execute all commands in the pipeline
        await pipe.execute()
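
    # Refreshing every _REDIS_UPDATE_INTERVAL_SECONDS (5s) against a TTL of
    # _REDIS_ENTRY_TIMEOUT_SECONDS (15s) tolerates up to two missed refresh
    # cycles before the rest of the cluster considers this node's
    # conversations dead.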

    async def _disconnect_from_stopped(self):
        """Handle connections to conversations that have stopped unexpectedly.

        This method detects when a local connection points to a conversation
        that was running on another server that has crashed or been terminated
        without proper cleanup. It:

        1. Identifies local connections to remote conversations
        2. Checks which remote conversations are still running in Redis
        3. Disconnects from conversations that are no longer running
        4. Attempts to restart the conversation locally if possible
        """
        # Get the remote sessions with local connections
        connected_to_remote_sids = set(
            self._local_connection_id_to_session_id.values()
        ) - set(self._local_agent_loops_by_sid.keys())
        if not connected_to_remote_sids:
            return

        # Get the set of sessions which are actually running
        redis = self._get_redis_client()
        pattern = self._get_redis_conversation_key('*', '*')
        running_remote = set()
        async for key in redis.scan_iter(pattern):
            parts = key.decode().split(':')
            running_remote.add(parts[2])

        # Find the local connections whose remote agent loop has died
        stopped_conversation_ids = connected_to_remote_sids - running_remote
        if not stopped_conversation_ids:
            return

        # Process each connection to a stopped conversation
        items = list(self._local_connection_id_to_session_id.items())
        for connection_id, conversation_id in items:
            if conversation_id in stopped_conversation_ids:
                logger.warning(
                    f'local_connection_to_stopped_conversation:{connection_id}:{conversation_id}'
                )
                # Look up the user_id from the database
                with session_maker() as session:
                    conversation_metadata = (
                        session.query(StoredConversationMetadata)
                        .filter(
                            StoredConversationMetadata.conversation_id
                            == conversation_id
                        )
                        .first()
                    )
                    user_id = (
                        conversation_metadata.user_id
                        if conversation_metadata
                        else None
                    )
                # Handle the stopped conversation asynchronously
                asyncio.create_task(
                    self._handle_remote_conversation_stopped(user_id, connection_id)  # type: ignore
                )

    async def _close_disconnected(self):
        async with self._conversations_lock:
            # Create a list of items to process to avoid modifying the dict during iteration
            items = list(self._detached_conversations.items())
            for sid, (conversation, detach_time) in items:
                await conversation.disconnect()
                self._detached_conversations.pop(sid, None)

        close_threshold = time.time() - self.config.sandbox.close_delay
        running_loops = list(self._local_agent_loops_by_sid.items())
        running_loops.sort(key=lambda item: item[1].last_active_ts)
        sid_to_close: list[str] = []
        for sid, session in running_loops:
            state = session.agent_session.get_state()
            if session.last_active_ts < close_threshold and state not in [
                AgentState.RUNNING,
                None,
            ]:
                sid_to_close.append(sid)

        # First, filter out any conversation that has local connections
        connections = await super().get_connections(filter_to_sids=set(sid_to_close))
        connected_sids = set(connections.values())
        sid_to_close = [sid for sid in sid_to_close if sid not in connected_sids]

        # Next, filter out any conversation that has remote connections
        if sid_to_close:
            connections = await self._get_connections_remotely(
                filter_to_sids=set(sid_to_close)
            )
            connected_sids = set(connections.values())
            sid_to_close = [sid for sid in sid_to_close if sid not in connected_sids]

        await wait_all(
            (self._close_session(sid) for sid in sid_to_close),
            timeout=WAIT_TIME_BEFORE_CLOSE,
        )

    async def _cleanup_stale(self):
        while should_continue():
            try:
                logger.info(
                    'conversation_manager',
                    extra={
                        'attached': len(self._active_conversations),
                        'detached': len(self._detached_conversations),
                        'running': len(self._local_agent_loops_by_sid),
                        'local_conn': len(self._local_connection_id_to_session_id),
                    },
                )
                await self._disconnect_from_stopped()
                await self._close_disconnected()
                await asyncio.sleep(_CLEANUP_INTERVAL_SECONDS)
            except asyncio.CancelledError:
                async with self._conversations_lock:
                    for conversation, _ in self._detached_conversations.values():
                        await conversation.disconnect()
                    self._detached_conversations.clear()
                await wait_all(
                    (
                        self._close_session(sid)
                        for sid in self._local_agent_loops_by_sid
                    ),
                    timeout=WAIT_TIME_BEFORE_CLOSE,
                )
                return
            except Exception:
                logger.warning('error_cleaning_stale', exc_info=True, stack_info=True)
                await asyncio.sleep(_CLEANUP_INTERVAL_SECONDS)

    async def _close_session(self, sid: str):
        logger.info(f'_close_session:{sid}')
        redis = self._get_redis_client()

        # Keys to delete from Redis
        to_delete = []

        # Remove connections
        connection_ids_to_remove = [
            connection_id
            for connection_id, conn_sid in self._local_connection_id_to_session_id.items()
            if sid == conn_sid
        ]

        if connection_ids_to_remove:
            pattern = self._get_redis_connection_key('*', sid, '*')
            async for key in redis.scan_iter(pattern):
                parts = key.decode().split(':')
                connection_id = parts[3]
                if connection_id in connection_ids_to_remove:
                    to_delete.append(key)

        logger.info(f'removing connections: {connection_ids_to_remove}')
        for connection_id in connection_ids_to_remove:
            await self.sio.disconnect(connection_id)
            self._local_connection_id_to_session_id.pop(connection_id, None)

        # Delete the conversation key if running locally
        session = self._local_agent_loops_by_sid.pop(sid, None)
        if not session:
            logger.info(f'no_session_to_close:{sid}')
            if to_delete:
                # delete() is a coroutine on the async client and must be awaited
                await redis.delete(*to_delete)
            return

        to_delete.append(self._get_redis_conversation_key(session.user_id, sid))
        await redis.delete(*to_delete)
        try:
            redis_client = self._get_redis_client()
            if redis_client:
                await redis_client.publish(
                    'session_msg',
                    json.dumps(
                        {
                            'sid': session.sid,
                            'message_type': 'session_closing',
                            'user_id': session.user_id,
                        }
                    ),
                )
        except Exception:
            logger.info(
                'error_publishing_close_session_event', exc_info=True, stack_info=True
            )

        await session.close()
        logger.info(f'closed_session:{session.sid}')

    async def get_agent_loop_info(
        self, user_id: str | None = None, filter_to_sids: set[str] | None = None
    ) -> list[AgentLoopInfo]:
        redis = self._get_redis_client()
        results = []
        if user_id:
            pattern = self._get_redis_conversation_key(user_id, '*')
        else:
            pattern = self._get_redis_conversation_key('*', '*')

        async for key in redis.scan_iter(pattern):
            uid, conversation_id = key.decode().split(':')[1:]
            if filter_to_sids is None or conversation_id in filter_to_sids:
                results.append(
                    AgentLoopInfo(
                        conversation_id,
                        url=self._get_conversation_url(conversation_id),
                        session_api_key=None,
                        event_store=EventStore(conversation_id, self.file_store, uid),
                        runtime_status=RuntimeStatus.READY,
                    )
                )
        return results
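
    # Note: get_agent_loop_info reports RuntimeStatus.READY for every key it
    # finds; "ready" here only means the heartbeat key has not expired, not
    # that the runtime itself has been health-checked.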

    @classmethod
    def get_instance(
        cls,
        sio: socketio.AsyncServer,
        config: OpenHandsConfig,
        file_store: FileStore,
        server_config: ServerConfig,
        monitoring_listener: MonitoringListener | None,
    ) -> ConversationManager:
        return ClusteredConversationManager(
            sio,
            config,
            file_store,
            server_config,
            monitoring_listener,  # type: ignore[arg-type]
        )

    async def _handle_remote_conversation_stopped(
        self, user_id: str, connection_id: str
    ):
        """Handle a situation where a remote conversation has stopped unexpectedly.

        When a server hosting a conversation crashes or is terminated without
        proper cleanup, this method attempts to recover by:
        1. Verifying the connection and conversation still exist
        2. Checking whether we can start a new conversation (within limits)
        3. Restarting the conversation locally if possible
        4. Disconnecting the client if recovery isn't possible

        Args:
            user_id: The user ID associated with the conversation
            connection_id: The connection ID to handle
        """
        conversation_id = self._local_connection_id_to_session_id.get(connection_id)

        # Not finding a user_id or a conversation_id indicates we are in some
        # unknown state, so we disconnect
        if not user_id or not conversation_id:
            await self.sio.disconnect(connection_id)
            return

        # Wait a second for connections to stabilize
        await asyncio.sleep(1)

        # Check if there are too many loops running - if so, disconnect
        response_ids = await self.get_running_agent_loops(user_id)
        if len(response_ids) > self.config.max_concurrent_conversations:
            await self.sio.disconnect(connection_id)
            return

        # Restart the agent loop
        config = load_openhands_config()
        settings_store = await SaasSettingsStore.get_instance(config, user_id)
        settings = await settings_store.load()
        await self.maybe_start_agent_loop(conversation_id, settings, user_id)

    async def _start_agent_loop(
        self,
        sid: str,
        settings: Settings,
        user_id: str | None,
        initial_user_msg: MessageAction | None = None,
        replay_json: str | None = None,
    ) -> Session:
        """Start an agent loop and add a conversation callback subscriber.

        This method calls the parent implementation and then adds a subscriber
        to the event stream that will invoke conversation callbacks when events
        occur.
        """
        # Call the parent method to start the agent loop
        session = await super()._start_agent_loop(
            sid, settings, user_id, initial_user_msg, replay_json
        )

        # Subscribers run in a different thread - if we are going to access
        # socketio, Redis, or anything else bound to the main event loop, we
        # need to pass callbacks back to the main event loop.
        loop = asyncio.get_running_loop()

        # Add a subscriber for conversation callbacks
        def conversation_callback_handler(event):
            """Handle events by invoking conversation callbacks."""
            try:
                if isinstance(event, AgentStateChangedObservation):
                    asyncio.run_coroutine_threadsafe(
                        invoke_conversation_callbacks(sid, event), loop
                    )
            except Exception as e:
                logger.error(
                    f'Error invoking conversation callbacks for {sid}: {str(e)}',
                    extra={'session_id': sid, 'error': str(e)},
                    exc_info=True,
                )

        # Subscribe to the event stream with our callback handler
        try:
            session.agent_session.event_stream.subscribe(
                EventStreamSubscriber.SERVER,
                conversation_callback_handler,
                'conversation_callbacks',
            )
        except ValueError:
            # Already subscribed - this can happen if the method is called multiple times
            pass

        return session

    def get_local_session(self, sid: str) -> Session:
        return self._local_agent_loops_by_sid[sid]