import asyncio
import json
import time
from dataclasses import dataclass, field
import socketio
from openhands.core.config import AppConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events.stream import EventStream, session_exists
from openhands.runtime.base import RuntimeUnavailableError
from openhands.server.session.conversation import Conversation
from openhands.server.session.session import ROOM_KEY, Session
from openhands.server.session.session_init_data import SessionInitData
from openhands.storage.files import FileStore
from openhands.utils.shutdown_listener import should_continue
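
# Seconds to wait for a reply to a query published on the Redis backchannel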
_REDIS_POLL_TIMEOUT = 1.5
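# Seconds for which a 'session_is_running' report from another node is considered fresh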
_CHECK_ALIVE_INTERVAL = 15
@dataclass
class SessionManager:
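    """Tracks socket.io connections and agent sessions for this server node.

    Sessions may run locally or on another node in the cluster; when Redis is
    configured, the 'oh_event' pub/sub channel is used as a backchannel to
    route events and queries between nodes. Typically used as an async context
    manager so the Redis listener task is started and stopped cleanly, e.g.:

        async with SessionManager(sio=sio, config=config, file_store=file_store) as manager:
            ...
    """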
sio: socketio.AsyncServer
config: AppConfig
file_store: FileStore
local_sessions_by_sid: dict[str, Session] = field(default_factory=dict)
local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict)
_last_alive_timestamps: dict[str, float] = field(default_factory=dict)
_redis_listen_task: asyncio.Task | None = None
_session_is_running_flags: dict[str, asyncio.Event] = field(default_factory=dict)
_has_remote_connections_flags: dict[str, asyncio.Event] = field(
default_factory=dict
)
async def __aenter__(self):
redis_client = self._get_redis_client()
if redis_client:
self._redis_listen_task = asyncio.create_task(self._redis_subscribe())
return self
async def __aexit__(self, exc_type, exc_value, traceback):
if self._redis_listen_task:
self._redis_listen_task.cancel()
self._redis_listen_task = None
def _get_redis_client(self):
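        """Return the Redis client exposed by the socket.io manager, or None if Redis is not configured."""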
redis_client = getattr(self.sio.manager, 'redis', None)
return redis_client
async def _redis_subscribe(self):
"""
We use a redis backchannel to send actions between server nodes
"""
logger.debug('_redis_subscribe')
redis_client = self._get_redis_client()
pubsub = redis_client.pubsub()
await pubsub.subscribe('oh_event')
while should_continue():
try:
message = await pubsub.get_message(
ignore_subscribe_messages=True, timeout=5
)
if message:
await self._process_message(message)
except asyncio.CancelledError:
return
except Exception:
try:
asyncio.get_running_loop()
logger.warning(
'error_reading_from_redis', exc_info=True, stack_info=True
)
except RuntimeError:
return # Loop has been shut down
async def _process_message(self, message: dict):
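        """Handle a message published on the 'oh_event' Redis channel.

        Each message is a JSON envelope of the form
        {'sid': ..., 'message_type': ..., 'data': ...}, where 'data' is only
        present for 'event' messages. Supported message_type values are
        'event', 'is_session_running', 'session_is_running',
        'has_remote_connections_query', 'has_remote_connections_response'
        and 'session_closing'.
        """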
data = json.loads(message['data'])
logger.debug(f'got_published_message:{message}')
sid = data['sid']
message_type = data['message_type']
if message_type == 'event':
session = self.local_sessions_by_sid.get(sid)
if session:
await session.dispatch(data['data'])
elif message_type == 'is_session_running':
            # Another node in the cluster is asking if the current node is running the given session.
session = self.local_sessions_by_sid.get(sid)
if session:
await self._get_redis_client().publish(
'oh_event',
json.dumps({'sid': sid, 'message_type': 'session_is_running'}),
)
elif message_type == 'session_is_running':
self._last_alive_timestamps[sid] = time.time()
flag = self._session_is_running_flags.get(sid)
if flag:
flag.set()
elif message_type == 'has_remote_connections_query':
            # Another node in the cluster is asking if the current node has clients connected to the given session.
required = sid in self.local_connection_id_to_session_id.values()
if required:
await self._get_redis_client().publish(
'oh_event',
json.dumps(
{'sid': sid, 'message_type': 'has_remote_connections_response'}
),
)
elif message_type == 'has_remote_connections_response':
flag = self._has_remote_connections_flags.get(sid)
if flag:
flag.set()
elif message_type == 'session_closing':
            # Session closing event - we only receive this on a graceful shutdown,
            # which cannot be guaranteed; nodes can simply vanish unexpectedly!
logger.debug(f'session_closing:{sid}')
            # Snapshot the items, since disconnect handlers may mutate the mapping
            for (
                connection_id,
                local_sid,
            ) in list(self.local_connection_id_to_session_id.items()):
                if sid == local_sid:
                    logger.warning(
                        f'local_connection_to_closing_session:{connection_id}:{sid}'
                    )
                    await self.sio.disconnect(connection_id)
async def attach_to_conversation(self, sid: str) -> Conversation | None:
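        """Connect to an existing conversation, or return None if it does not exist or its runtime is unavailable."""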
start_time = time.time()
if not await session_exists(sid, self.file_store):
return None
c = Conversation(sid, file_store=self.file_store, config=self.config)
try:
await c.connect()
except RuntimeUnavailableError as e:
logger.error(f'Error connecting to conversation {c.sid}: {e}')
return None
end_time = time.time()
logger.info(
f'Conversation {c.sid} connected in {end_time - start_time} seconds'
)
return c
async def detach_from_conversation(self, conversation: Conversation):
await conversation.disconnect()
async def init_or_join_session(
self, sid: str, connection_id: str, session_init_data: SessionInitData
):
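        """Register the connection and return an event stream for the session.

        Prefers a session already running on this node, then one running
        elsewhere in the cluster (checked over the Redis backchannel), and
        otherwise starts a new local session.
        """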
await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
self.local_connection_id_to_session_id[connection_id] = sid
# If we have a local session running, use that
session = self.local_sessions_by_sid.get(sid)
if session:
logger.info(f'found_local_session:{sid}')
return session.agent_session.event_stream
# If there is a remote session running, retrieve existing events for that
redis_client = self._get_redis_client()
if redis_client and await self._is_session_running_in_cluster(sid):
return EventStream(sid, self.file_store)
return await self.start_local_session(sid, session_init_data)
async def _is_session_running_in_cluster(self, sid: str) -> bool:
"""As the rest of the cluster if a session is running. Wait a for a short timeout for a reply"""
# Create a flag for the callback
flag = asyncio.Event()
self._session_is_running_flags[sid] = flag
try:
logger.debug(f'publish:is_session_running:{sid}')
await self._get_redis_client().publish(
'oh_event',
json.dumps(
{
'sid': sid,
'message_type': 'is_session_running',
}
),
)
async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
await flag.wait()
result = flag.is_set()
return result
except TimeoutError:
# Nobody replied in time
return False
finally:
self._session_is_running_flags.pop(sid)
async def _has_remote_connections(self, sid: str) -> bool:
"""As the rest of the cluster if they still want this session running. Wait a for a short timeout for a reply"""
# Create a flag for the callback
flag = asyncio.Event()
self._has_remote_connections_flags[sid] = flag
try:
await self._get_redis_client().publish(
'oh_event',
json.dumps(
{
'sid': sid,
'message_type': 'has_remote_connections_query',
}
),
)
async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
await flag.wait()
result = flag.is_set()
return result
except TimeoutError:
# Nobody replied in time
return False
finally:
self._has_remote_connections_flags.pop(sid)
async def start_local_session(self, sid: str, session_init_data: SessionInitData):
# Start a new local session
logger.info(f'start_new_local_session:{sid}')
session = Session(
sid=sid, file_store=self.file_store, config=self.config, sio=self.sio
)
self.local_sessions_by_sid[sid] = session
await session.initialize_agent(session_init_data)
return session.agent_session.event_stream
async def send_to_event_stream(self, connection_id: str, data: dict):
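        """Dispatch an event from a connected client to its session.

        Sends to the local session when one exists; otherwise, if Redis is
        configured and the session appears to be alive on another node,
        publishes the event on the backchannel.
        """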
# If there is a local session running, send to that
sid = self.local_connection_id_to_session_id.get(connection_id)
if not sid:
raise RuntimeError(f'no_connected_session:{connection_id}')
session = self.local_sessions_by_sid.get(sid)
if session:
await session.dispatch(data)
return
redis_client = self._get_redis_client()
if redis_client:
# If we have a recent report that the session is alive in another pod
last_alive_at = self._last_alive_timestamps.get(sid) or 0
next_alive_check = last_alive_at + _CHECK_ALIVE_INTERVAL
            if (
                next_alive_check > time.time()
                or await self._is_session_running_in_cluster(sid)
            ):
# Send the event to the other pod
await redis_client.publish(
'oh_event',
json.dumps(
{
'sid': sid,
'message_type': 'event',
'data': data,
}
),
)
return
raise RuntimeError(f'no_connected_session:{connection_id}:{sid}')
async def disconnect_from_session(self, connection_id: str):
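        """Drop the connection and schedule cleanup of its session, or close it immediately on shutdown."""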
sid = self.local_connection_id_to_session_id.pop(connection_id, None)
if not sid:
# This can occur if the init action was never run.
logger.warning(f'disconnect_from_uninitialized_session:{connection_id}')
return
session = self.local_sessions_by_sid.get(sid)
if session:
logger.info(f'close_session:{connection_id}:{sid}')
if should_continue():
asyncio.create_task(self._cleanup_session_later(session))
else:
await self._close_session(session)
async def _cleanup_session_later(self, session: Session):
# Once there have been no connections to a session for a reasonable period, we close it
try:
await asyncio.sleep(self.config.sandbox.close_delay)
finally:
# If the sleep was cancelled, we still want to close these
await self._cleanup_session(session)
async def _cleanup_session(self, session: Session):
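        """Close the session if no local or remote connections to it remain."""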
# Get local connections
        has_local_connections = any(
            v == session.sid
            for v in self.local_connection_id_to_session_id.values()
        )
if has_local_connections:
return False
# If no local connections, get connections through redis
redis_client = self._get_redis_client()
if redis_client and await self._has_remote_connections(session.sid):
return False
# We alert the cluster in case they are interested
if redis_client:
await redis_client.publish(
'oh_event',
json.dumps({'sid': session.sid, 'message_type': 'session_closing'}),
)
await self._close_session(session)
async def _close_session(self, session: Session):
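        """Remove local state for the session, notify the cluster that it is closing, and close it."""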
logger.info(f'_close_session:{session.sid}')
# Clear up local variables
        connection_ids_to_remove = [
            connection_id
            for connection_id, sid in self.local_connection_id_to_session_id.items()
            if sid == session.sid
        ]
        for connection_id in connection_ids_to_remove:
            self.local_connection_id_to_session_id.pop(connection_id, None)
self.local_sessions_by_sid.pop(session.sid, None)
# We alert the cluster in case they are interested
redis_client = self._get_redis_client()
if redis_client:
await redis_client.publish(
'oh_event',
json.dumps({'sid': session.sid, 'message_type': 'session_closing'}),
)
session.close()