Fix closing sessions (again) (#6322)

Co-authored-by: Robert Brennan <accounts@rbren.io>
This commit is contained in:
tofarr
2025-01-16 15:03:38 -07:00
committed by GitHub
parent eff9e07272
commit 313c8eca20
11 changed files with 336 additions and 221 deletions

View File

@@ -158,7 +158,7 @@ async def search_conversations(
for conversation in conversation_metadata_result_set.results
if hasattr(conversation, 'created_at')
)
running_conversations = await session_manager.get_agent_loop_running(
running_conversations = await session_manager.get_running_agent_loops(
get_user_id(request), set(conversation_ids)
)
result = ConversationInfoResultSet(

View File

@@ -1,4 +1,5 @@
import asyncio
import time
from typing import Callable, Optional
from openhands.controller import AgentController
@@ -16,10 +17,10 @@ from openhands.runtime import get_runtime_cls
from openhands.runtime.base import Runtime
from openhands.security import SecurityAnalyzer, options
from openhands.storage.files import FileStore
from openhands.utils.async_utils import call_async_from_sync, call_sync_from_async
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.shutdown_listener import should_continue
WAIT_TIME_BEFORE_CLOSE = 300
WAIT_TIME_BEFORE_CLOSE = 90
WAIT_TIME_BEFORE_CLOSE_INTERVAL = 5
@@ -36,7 +37,8 @@ class AgentSession:
controller: AgentController | None = None
runtime: Runtime | None = None
security_analyzer: SecurityAnalyzer | None = None
_initializing: bool = False
_starting: bool = False
_started_at: float = 0
_closed: bool = False
loop: asyncio.AbstractEventLoop | None = None
@@ -88,7 +90,8 @@ class AgentSession:
if self._closed:
logger.warning('Session closed before starting')
return
self._initializing = True
self._starting = True
self._started_at = time.time()
self._create_security_analyzer(config.security.security_analyzer)
await self._create_runtime(
runtime_name=runtime_name,
@@ -109,24 +112,19 @@ class AgentSession:
self.event_stream.add_event(
ChangeAgentStateAction(AgentState.INIT), EventSource.ENVIRONMENT
)
self._initializing = False
self._starting = False
def close(self):
async def close(self):
"""Closes the Agent session"""
if self._closed:
return
self._closed = True
call_async_from_sync(self._close)
async def _close(self):
seconds_waited = 0
while self._initializing and should_continue():
while self._starting and should_continue():
logger.debug(
f'Waiting for initialization to finish before closing session {self.sid}'
)
await asyncio.sleep(WAIT_TIME_BEFORE_CLOSE_INTERVAL)
seconds_waited += WAIT_TIME_BEFORE_CLOSE_INTERVAL
if seconds_waited > WAIT_TIME_BEFORE_CLOSE:
if time.time() <= self._started_at + WAIT_TIME_BEFORE_CLOSE:
logger.error(
f'Waited too long for initialization to finish before closing session {self.sid}'
)
@@ -311,3 +309,12 @@ class AgentSession:
else:
logger.debug('No events found, no state to restore')
return restored_state
def get_state(self) -> AgentState | None:
controller = self.controller
if controller:
return controller.state.agent_state
if time.time() > self._started_at + WAIT_TIME_BEFORE_CLOSE:
# If 5 minutes have elapsed and we still don't have a controller, something has gone wrong
return AgentState.ERROR
return None

View File

@@ -2,6 +2,7 @@ import asyncio
import json
import time
from dataclasses import dataclass, field
from typing import Generic, Iterable, TypeVar
from uuid import uuid4
import socketio
@@ -9,26 +10,29 @@ import socketio
from openhands.core.config import AppConfig
from openhands.core.exceptions import AgentRuntimeUnavailableError
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
from openhands.events.stream import EventStream, session_exists
from openhands.server.session.agent_session import WAIT_TIME_BEFORE_CLOSE
from openhands.server.session.conversation import Conversation
from openhands.server.session.session import ROOM_KEY, Session
from openhands.server.settings import Settings
from openhands.storage.files import FileStore
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.async_utils import wait_all
from openhands.utils.shutdown_listener import should_continue
_REDIS_POLL_TIMEOUT = 1.5
_CHECK_ALIVE_INTERVAL = 15
_CLEANUP_INTERVAL = 15
_CLEANUP_EXCEPTION_WAIT_TIME = 15
MAX_RUNNING_CONVERSATIONS = 3
T = TypeVar('T')
@dataclass
class _SessionIsRunningCheck:
request_id: str
request_sids: list[str]
running_sids: set[str] = field(default_factory=set)
class _ClusterQuery(Generic[T]):
query_id: str
request_ids: set[str] | None
result: T
flag: asyncio.Event = field(default_factory=asyncio.Event)
@@ -38,10 +42,10 @@ class SessionManager:
config: AppConfig
file_store: FileStore
_local_agent_loops_by_sid: dict[str, Session] = field(default_factory=dict)
local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict)
_local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict)
_last_alive_timestamps: dict[str, float] = field(default_factory=dict)
_redis_listen_task: asyncio.Task | None = None
_session_is_running_checks: dict[str, _SessionIsRunningCheck] = field(
_running_sid_queries: dict[str, _ClusterQuery[set[str]]] = field(
default_factory=dict
)
_active_conversations: dict[str, tuple[Conversation, int]] = field(
@@ -52,7 +56,7 @@ class SessionManager:
)
_conversations_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
_cleanup_task: asyncio.Task | None = None
_has_remote_connections_flags: dict[str, asyncio.Event] = field(
_connection_queries: dict[str, _ClusterQuery[dict[str, str]]] = field(
default_factory=dict
)
@@ -60,7 +64,7 @@ class SessionManager:
redis_client = self._get_redis_client()
if redis_client:
self._redis_listen_task = asyncio.create_task(self._redis_subscribe())
self._cleanup_task = asyncio.create_task(self._cleanup_detached_conversations())
self._cleanup_task = asyncio.create_task(self._cleanup_stale())
return self
async def __aexit__(self, exc_type, exc_value, traceback):
@@ -82,7 +86,7 @@ class SessionManager:
logger.debug('_redis_subscribe')
redis_client = self._get_redis_client()
pubsub = redis_client.pubsub()
await pubsub.subscribe('oh_event')
await pubsub.subscribe('session_msg')
while should_continue():
try:
message = await pubsub.get_message(
@@ -108,59 +112,71 @@ class SessionManager:
session = self._local_agent_loops_by_sid.get(sid)
if session:
await session.dispatch(data['data'])
elif message_type == 'is_session_running':
elif message_type == 'running_agent_loops_query':
# Another node in the cluster is asking if the current node is running the session given.
request_id = data['request_id']
sids = [
sid for sid in data['sids'] if sid in self._local_agent_loops_by_sid
]
query_id = data['query_id']
sids = self._get_running_agent_loops_locally(
data.get('user_id'), data.get('filter_to_sids')
)
if sids:
await self._get_redis_client().publish(
'oh_event',
'session_msg',
json.dumps(
{
'request_id': request_id,
'sids': sids,
'message_type': 'session_is_running',
'query_id': query_id,
'sids': list(sids),
'message_type': 'running_agent_loops_response',
}
),
)
elif message_type == 'session_is_running':
request_id = data['request_id']
elif message_type == 'running_agent_loops_response':
query_id = data['query_id']
for sid in data['sids']:
self._last_alive_timestamps[sid] = time.time()
check = self._session_is_running_checks.get(request_id)
if check:
check.running_sids.update(data['sids'])
if len(check.request_sids) == len(check.running_sids):
check.flag.set()
elif message_type == 'has_remote_connections_query':
running_query = self._running_sid_queries.get(query_id)
if running_query:
running_query.result.update(data['sids'])
if running_query.request_ids is not None and len(
running_query.request_ids
) == len(running_query.result):
running_query.flag.set()
elif message_type == 'connections_query':
# Another node in the cluster is asking if the current node is connected to a session
sid = data['sid']
required = sid in self.local_connection_id_to_session_id.values()
if required:
query_id = data['query_id']
connections = self._get_connections_locally(
data.get('user_id'), data.get('filter_to_sids')
)
if connections:
await self._get_redis_client().publish(
'oh_event',
'session_msg',
json.dumps(
{'sid': sid, 'message_type': 'has_remote_connections_response'}
{
'query_id': query_id,
'connections': connections,
'message_type': 'connections_response',
}
),
)
elif message_type == 'has_remote_connections_response':
sid = data['sid']
flag = self._has_remote_connections_flags.get(sid)
if flag:
flag.set()
elif message_type == 'connections_response':
query_id = data['query_id']
connection_query = self._connection_queries.get(query_id)
if connection_query:
connection_query.result.update(**data['connections'])
if connection_query.request_ids is not None and len(
connection_query.request_ids
) == len(connection_query.result):
connection_query.flag.set()
elif message_type == 'close_session':
sid = data['sid']
if sid in self._local_agent_loops_by_sid:
await self._on_close_session(sid)
await self._close_session(sid)
elif message_type == 'session_closing':
# Session closing event - We only get this in the event of graceful shutdown,
# which can't be guaranteed - nodes can simply vanish unexpectedly!
sid = data['sid']
logger.debug(f'session_closing:{sid}')
# Create a list of items to process to avoid modifying dict during iteration
items = list(self.local_connection_id_to_session_id.items())
items = list(self._local_connection_id_to_session_id.items())
for connection_id, local_sid in items:
if sid == local_sid:
logger.warning(
@@ -208,7 +224,7 @@ class SessionManager:
):
logger.info(f'join_conversation:{sid}:{connection_id}')
await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
self.local_connection_id_to_session_id[connection_id] = sid
self._local_connection_id_to_session_id[connection_id] = sid
event_stream = await self._get_event_stream(sid)
if not event_stream:
return await self.maybe_start_agent_loop(sid, settings, user_id)
@@ -226,7 +242,7 @@ class SessionManager:
self._active_conversations.pop(sid)
self._detached_conversations[sid] = (conversation, time.time())
async def _cleanup_detached_conversations(self):
async def _cleanup_stale(self):
while should_continue():
if self._get_redis_client():
# Debug info for HA envs
@@ -240,7 +256,7 @@ class SessionManager:
f'Running agent loops: {len(self._local_agent_loops_by_sid)}'
)
logger.info(
f'Local connections: {len(self.local_connection_id_to_session_id)}'
f'Local connections: {len(self._local_connection_id_to_session_id)}'
)
try:
async with self._conversations_lock:
@@ -250,97 +266,180 @@ class SessionManager:
await conversation.disconnect()
self._detached_conversations.pop(sid, None)
close_threshold = time.time() - self.config.sandbox.close_delay
running_loops = list(self._local_agent_loops_by_sid.items())
running_loops.sort(key=lambda item: item[1].last_active_ts)
sid_to_close: list[str] = []
for sid, session in running_loops:
state = session.agent_session.get_state()
if session.last_active_ts < close_threshold and state not in [
AgentState.RUNNING,
None,
]:
sid_to_close.append(sid)
connections = self._get_connections_locally(
filter_to_sids=set(sid_to_close)
)
connected_sids = {sid for _, sid in connections.items()}
sid_to_close = [
sid for sid in sid_to_close if sid not in connected_sids
]
if sid_to_close:
connections = await self._get_connections_remotely(
filter_to_sids=set(sid_to_close)
)
connected_sids = {sid for _, sid in connections.items()}
sid_to_close = [
sid for sid in sid_to_close if sid not in connected_sids
]
await wait_all(self._close_session(sid) for sid in sid_to_close)
await asyncio.sleep(_CLEANUP_INTERVAL)
except asyncio.CancelledError:
async with self._conversations_lock:
for conversation, _ in self._detached_conversations.values():
await conversation.disconnect()
self._detached_conversations.clear()
await wait_all(
(
self._close_session(sid)
for sid in self._local_agent_loops_by_sid
),
timeout=WAIT_TIME_BEFORE_CLOSE,
)
return
except Exception as e:
logger.warning(f'error_cleaning_detached_conversations: {str(e)}')
await asyncio.sleep(_CLEANUP_EXCEPTION_WAIT_TIME)
async def get_agent_loop_running(self, user_id, sids: set[str]) -> set[str]:
running_sids = set(sid for sid in sids if sid in self._local_agent_loops_by_sid)
check_cluster_sids = [sid for sid in sids if sid not in running_sids]
running_cluster_sids = await self.get_agent_loop_running_in_cluster(
check_cluster_sids
)
running_sids.union(running_cluster_sids)
return running_sids
except Exception:
logger.warning('error_cleaning_stale', exc_info=True, stack_info=True)
await asyncio.sleep(_CLEANUP_INTERVAL)
async def is_agent_loop_running(self, sid: str) -> bool:
if await self.is_agent_loop_running_locally(sid):
return True
if await self.is_agent_loop_running_in_cluster(sid):
return True
return False
sids = await self.get_running_agent_loops(filter_to_sids={sid})
return bool(sids)
async def is_agent_loop_running_locally(self, sid: str) -> bool:
return sid in self._local_agent_loops_by_sid
async def get_running_agent_loops(
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
) -> set[str]:
"""Get the running session ids. If a user is supplied, then the results are limited to session ids for that user. If a set of filter_to_sids is supplied, then results are limited to these ids of interest."""
sids = self._get_running_agent_loops_locally(user_id, filter_to_sids)
remote_sids = await self._get_running_agent_loops_remotely(
user_id, filter_to_sids
)
return sids.union(remote_sids)
async def is_agent_loop_running_in_cluster(self, sid: str) -> bool:
running_sids = await self.get_agent_loop_running_in_cluster([sid])
return bool(running_sids)
def _get_running_agent_loops_locally(
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
) -> set[str]:
items: Iterable[tuple[str, Session]] = self._local_agent_loops_by_sid.items()
if filter_to_sids is not None:
items = (item for item in items if item[0] in filter_to_sids)
if user_id:
items = (item for item in items if item[1].user_id == user_id)
sids = {sid for sid, _ in items}
return sids
async def get_agent_loop_running_in_cluster(self, sids: list[str]) -> set[str]:
async def _get_running_agent_loops_remotely(
self,
user_id: str | None = None,
filter_to_sids: set[str] | None = None,
) -> set[str]:
"""As the rest of the cluster if a session is running. Wait a for a short timeout for a reply"""
redis_client = self._get_redis_client()
if not redis_client:
return set()
flag = asyncio.Event()
request_id = str(uuid4())
check = _SessionIsRunningCheck(request_id=request_id, request_sids=sids)
self._session_is_running_checks[request_id] = check
query_id = str(uuid4())
query = _ClusterQuery[set[str]](
query_id=query_id, request_ids=filter_to_sids, result=set()
)
self._running_sid_queries[query_id] = query
try:
logger.debug(f'publish:is_session_running:{sids}')
await redis_client.publish(
'oh_event',
json.dumps(
{
'request_id': request_id,
'sids': sids,
'message_type': 'is_session_running',
}
),
logger.debug(
f'publish:_get_running_agent_loops_remotely_query:{user_id}:{filter_to_sids}'
)
data: dict = {
'query_id': query_id,
'message_type': 'running_agent_loops_query',
}
if user_id:
data['user_id'] = user_id
if filter_to_sids:
data['filter_to_sids'] = list(filter_to_sids)
await redis_client.publish('session_msg', json.dumps(data))
async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
await flag.wait()
return check.running_sids
return query.result
except TimeoutError:
# Nobody replied in time
return check.running_sids
return query.result
finally:
self._session_is_running_checks.pop(request_id, None)
self._running_sid_queries.pop(query_id, None)
async def get_connections(
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
) -> dict[str, str]:
connection_ids = self._get_connections_locally(user_id, filter_to_sids)
remote_connection_ids = await self._get_connections_remotely(
user_id, filter_to_sids
)
connection_ids.update(**remote_connection_ids)
return connection_ids
def _get_connections_locally(
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
) -> dict[str, str]:
connections = dict(**self._local_connection_id_to_session_id)
if filter_to_sids is not None:
connections = {
connection_id: sid
for connection_id, sid in connections.items()
if sid in filter_to_sids
}
if user_id:
for connection_id, sid in list(connections.items()):
session = self._local_agent_loops_by_sid.get(sid)
if not session or session.user_id != user_id:
connections.pop(connection_id)
return connections
async def _get_connections_remotely(
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
) -> dict[str, str]:
redis_client = self._get_redis_client()
if not redis_client:
return {}
async def _has_remote_connections(self, sid: str) -> bool:
"""As the rest of the cluster if they still want this session running. Wait a for a short timeout for a reply"""
# Create a flag for the callback
flag = asyncio.Event()
self._has_remote_connections_flags[sid] = flag
query_id = str(uuid4())
query = _ClusterQuery[dict[str, str]](
query_id=query_id, request_ids=filter_to_sids, result={}
)
self._connection_queries[query_id] = query
try:
await self._get_redis_client().publish(
'oh_event',
json.dumps(
{
'sid': sid,
'message_type': 'has_remote_connections_query',
}
),
logger.debug(
f'publish:get_connections_remotely_query:{user_id}:{filter_to_sids}'
)
data: dict = {
'query_id': query_id,
'message_type': 'connections_query',
}
if user_id:
data['user_id'] = user_id
if filter_to_sids:
data['filter_to_sids'] = list(filter_to_sids)
await redis_client.publish('session_msg', json.dumps(data))
async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
await flag.wait()
result = flag.is_set()
return result
return query.result
except TimeoutError:
# Nobody replied in time
return False
return query.result
finally:
self._has_remote_connections_flags.pop(sid, None)
self._connection_queries.pop(query_id, None)
async def maybe_start_agent_loop(
self, sid: str, settings: Settings, user_id: str | None
@@ -349,8 +448,18 @@ class SessionManager:
session: Session | None = None
if not await self.is_agent_loop_running(sid):
logger.info(f'start_agent_loop:{sid}')
response_ids = await self.get_running_agent_loops(user_id)
if len(response_ids) >= MAX_RUNNING_CONVERSATIONS:
logger.info('too_many_sessions_for:{user_id}')
await self.close_session(next(iter(response_ids)))
session = Session(
sid=sid, file_store=self.file_store, config=self.config, sio=self.sio
sid=sid,
file_store=self.file_store,
config=self.config,
sio=self.sio,
user_id=user_id,
)
self._local_agent_loops_by_sid[sid] = session
asyncio.create_task(session.initialize_agent(settings))
@@ -359,7 +468,6 @@ class SessionManager:
if not event_stream:
logger.error(f'No event stream after starting agent loop: {sid}')
raise RuntimeError(f'no_event_stream:{sid}')
asyncio.create_task(self._cleanup_session_later(sid))
return event_stream
async def _get_event_stream(self, sid: str) -> EventStream | None:
@@ -369,7 +477,7 @@ class SessionManager:
logger.info(f'found_local_agent_loop:{sid}')
return session.agent_session.event_stream
if await self.is_agent_loop_running_in_cluster(sid):
if await self._get_running_agent_loops_remotely(filter_to_sids={sid}):
logger.info(f'found_remote_agent_loop:{sid}')
return EventStream(sid, self.file_store)
@@ -377,7 +485,7 @@ class SessionManager:
async def send_to_event_stream(self, connection_id: str, data: dict):
# If there is a local session running, send to that
sid = self.local_connection_id_to_session_id.get(connection_id)
sid = self._local_connection_id_to_session_id.get(connection_id)
if not sid:
raise RuntimeError(f'no_connected_session:{connection_id}')
@@ -393,11 +501,11 @@ class SessionManager:
next_alive_check = last_alive_at + _CHECK_ALIVE_INTERVAL
if (
next_alive_check > time.time()
or await self.is_agent_loop_running_in_cluster(sid)
or await self._get_running_agent_loops_remotely(filter_to_sids={sid})
):
# Send the event to the other pod
await redis_client.publish(
'oh_event',
'session_msg',
json.dumps(
{
'sid': sid,
@@ -411,75 +519,37 @@ class SessionManager:
raise RuntimeError(f'no_connected_session:{connection_id}:{sid}')
async def disconnect_from_session(self, connection_id: str):
sid = self.local_connection_id_to_session_id.pop(connection_id, None)
sid = self._local_connection_id_to_session_id.pop(connection_id, None)
logger.info(f'disconnect_from_session:{connection_id}:{sid}')
if not sid:
# This can occur if the init action was never run.
logger.warning(f'disconnect_from_uninitialized_session:{connection_id}')
return
if should_continue():
asyncio.create_task(self._cleanup_session_later(sid))
else:
await self._on_close_session(sid)
async def _cleanup_session_later(self, sid: str):
# Once there have been no connections to a session for a reasonable period, we close it
try:
await asyncio.sleep(self.config.sandbox.close_delay)
finally:
# If the sleep was cancelled, we still want to close these
await self._cleanup_session(sid)
async def _cleanup_session(self, sid: str) -> bool:
# Get local connections
logger.info(f'_cleanup_session:{sid}')
has_local_connections = next(
(True for v in self.local_connection_id_to_session_id.values() if v == sid),
False,
)
if has_local_connections:
return False
# If no local connections, get connections through redis
redis_client = self._get_redis_client()
if redis_client and await self._has_remote_connections(sid):
return False
# We alert the cluster in case they are interested
if redis_client:
await redis_client.publish(
'oh_event',
json.dumps({'sid': sid, 'message_type': 'session_closing'}),
)
await self._on_close_session(sid)
return True
async def close_session(self, sid: str):
session = self._local_agent_loops_by_sid.get(sid)
if session:
await self._on_close_session(sid)
await self._close_session(sid)
redis_client = self._get_redis_client()
if redis_client:
await redis_client.publish(
'oh_event',
'session_msg',
json.dumps({'sid': sid, 'message_type': 'close_session'}),
)
async def _on_close_session(self, sid: str):
async def _close_session(self, sid: str):
logger.info(f'_close_session:{sid}')
# Clear up local variables
connection_ids_to_remove = list(
connection_id
for connection_id, conn_sid in self.local_connection_id_to_session_id.items()
for connection_id, conn_sid in self._local_connection_id_to_session_id.items()
if sid == conn_sid
)
logger.info(f'removing connections: {connection_ids_to_remove}')
for connnnection_id in connection_ids_to_remove:
self.local_connection_id_to_session_id.pop(connnnection_id, None)
self._local_connection_id_to_session_id.pop(connnnection_id, None)
session = self._local_agent_loops_by_sid.pop(sid, None)
if not session:
@@ -488,12 +558,17 @@ class SessionManager:
logger.info(f'closing_session:{session.sid}')
# We alert the cluster in case they are interested
redis_client = self._get_redis_client()
if redis_client:
await redis_client.publish(
'oh_event',
json.dumps({'sid': session.sid, 'message_type': 'session_closing'}),
try:
redis_client = self._get_redis_client()
if redis_client:
await redis_client.publish(
'session_msg',
json.dumps({'sid': session.sid, 'message_type': 'session_closing'}),
)
except Exception:
logger.info(
'error_publishing_close_session_event', exc_info=True, stack_info=True
)
await call_sync_from_async(session.close)
await session.close()
logger.info(f'closed_session:{session.sid}')

View File

@@ -62,9 +62,17 @@ class Session:
self.loop = asyncio.get_event_loop()
self.user_id = user_id
def close(self):
async def close(self):
if self.sio:
await self.sio.emit(
'oh_event',
event_to_dict(
AgentStateChangedObservation('', AgentState.STOPPED.value)
),
to=ROOM_KEY.format(sid=self.sid),
)
self.is_alive = False
self.agent_session.close()
await self.agent_session.close()
async def initialize_agent(
self,