Fix for dying sessions/runtimes (#5755)

This commit is contained in:
Robert Brennan 2024-12-23 11:00:05 -05:00 committed by GitHub
parent d62cf7e731
commit faf8b5829c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 136 additions and 70 deletions

View File

@ -38,7 +38,7 @@ class SandboxConfig:
remote_runtime_api_url: str = 'http://localhost:8000'
local_runtime_url: str = 'http://localhost'
keep_runtime_alive: bool = False
keep_runtime_alive: bool = True
rm_all_containers: bool = False
api_key: str | None = None
base_container_image: str = 'nikolaik/python-nodejs:python3.12-nodejs22' # default to nikolaik/python-nodejs:python3.12-nodejs22 for eventstream runtime

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel
from openhands.core.logger import openhands_logger as logger
from openhands.server.routes.settings import SettingsStoreImpl
from openhands.server.session.session_init_data import SessionInitData
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.server.shared import config, session_manager
from openhands.storage.conversation.conversation_store import (
ConversationMetadata,
@ -47,7 +47,7 @@ async def new_conversation(request: Request, data: InitSessionRequest):
session_init_args['github_token'] = github_token
session_init_args['selected_repository'] = data.selected_repository
session_init_data = SessionInitData(**session_init_args)
conversation_init_data = ConversationInitData(**session_init_args)
conversation_store = await ConversationStore.get_instance(config)
@ -70,5 +70,7 @@ async def new_conversation(request: Request, data: InitSessionRequest):
)
)
await session_manager.start_agent_loop(conversation_id, session_init_data)
await session_manager.maybe_start_agent_loop(
conversation_id, conversation_init_data
)
return JSONResponse(content={'status': 'ok', 'conversation_id': conversation_id})

View File

@ -4,7 +4,7 @@ from openhands.server.settings import Settings
@dataclass
class SessionInitData(Settings):
class ConversationInitData(Settings):
"""
Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
"""

View File

@ -10,8 +10,8 @@ from openhands.core.exceptions import AgentRuntimeUnavailableError
from openhands.core.logger import openhands_logger as logger
from openhands.events.stream import EventStream, session_exists
from openhands.server.session.conversation import Conversation
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.server.session.session import ROOM_KEY, Session
from openhands.server.session.session_init_data import SessionInitData
from openhands.storage.files import FileStore
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.shutdown_listener import should_continue
@ -32,7 +32,7 @@ class SessionManager:
sio: socketio.AsyncServer
config: AppConfig
file_store: FileStore
local_sessions_by_sid: dict[str, Session] = field(default_factory=dict)
_local_agent_loops_by_sid: dict[str, Session] = field(default_factory=dict)
local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict)
_last_alive_timestamps: dict[str, float] = field(default_factory=dict)
_redis_listen_task: asyncio.Task | None = None
@ -100,12 +100,12 @@ class SessionManager:
sid = data['sid']
message_type = data['message_type']
if message_type == 'event':
session = self.local_sessions_by_sid.get(sid)
session = self._local_agent_loops_by_sid.get(sid)
if session:
await session.dispatch(data['data'])
elif message_type == 'is_session_running':
# Another node in the cluster is asking if the current node is running the session given.
session = self.local_sessions_by_sid.get(sid)
session = self._local_agent_loops_by_sid.get(sid)
if session:
await self._get_redis_client().publish(
'oh_event',
@ -183,19 +183,15 @@ class SessionManager:
self.local_connection_id_to_session_id[connection_id] = sid
# If we have a local session running, use that
session = self.local_sessions_by_sid.get(sid)
session = self._local_agent_loops_by_sid.get(sid)
if session:
logger.info(f'found_local_session:{sid}')
return session.agent_session.event_stream
# If there is a remote session running, retrieve existing events for that
redis_client = self._get_redis_client()
if redis_client and await self._is_session_running_in_cluster(sid):
if await self._is_agent_loop_running_in_cluster(sid):
return EventStream(sid, self.file_store)
raise ConversationDoesNotExistError(
f'no_conversation_for_id:{connection_id}:{sid}'
)
return await self.maybe_start_agent_loop(sid)
async def detach_from_conversation(self, conversation: Conversation):
sid = conversation.sid
@ -232,14 +228,29 @@ class SessionManager:
logger.warning('error_cleaning_detached_conversations', exc_info=True)
await asyncio.sleep(_CLEANUP_EXCEPTION_WAIT_TIME)
async def _is_session_running_in_cluster(self, sid: str) -> bool:
async def _is_agent_loop_running(self, sid: str) -> bool:
    """Report whether an agent loop for ``sid`` is running on any node.

    The local registry is consulted first; the rest of the cluster is only
    queried when no local loop is found.

    Args:
        sid: Identifier of the session/conversation to look up.

    Returns:
        True if the loop is running locally or elsewhere in the cluster,
        False otherwise.
    """
    # Guard clause: a local hit short-circuits the cluster round trip.
    if await self._is_agent_loop_running_locally(sid):
        return True
    return await self._is_agent_loop_running_in_cluster(sid)
async def _is_agent_loop_running_locally(self, sid: str) -> bool:
    """Report whether this node has a session registered for ``sid``.

    Args:
        sid: Identifier of the session/conversation to look up.

    Returns:
        True if ``_local_agent_loops_by_sid`` holds a truthy entry for
        ``sid``; False when the key is missing or maps to a falsy value.
    """
    return bool(self._local_agent_loops_by_sid.get(sid))
async def _is_agent_loop_running_in_cluster(self, sid: str) -> bool:
"""As the rest of the cluster if a session is running. Wait a for a short timeout for a reply"""
# Create a flag for the callback
redis_client = self._get_redis_client()
if not redis_client:
return False
flag = asyncio.Event()
self._session_is_running_flags[sid] = flag
try:
logger.debug(f'publish:is_session_running:{sid}')
await self._get_redis_client().publish(
await redis_client.publish(
'oh_event',
json.dumps(
{
@ -285,14 +296,24 @@ class SessionManager:
finally:
self._has_remote_connections_flags.pop(sid, None)
async def start_agent_loop(self, sid: str, session_init_data: SessionInitData):
logger.info(f'start_agent_loop:{sid}')
session = Session(
sid=sid, file_store=self.file_store, config=self.config, sio=self.sio
)
self.local_sessions_by_sid[sid] = session
await session.initialize_agent(session_init_data)
return session.agent_session.event_stream
async def maybe_start_agent_loop(
self, sid: str, conversation_init_data: ConversationInitData | None = None
) -> EventStream:
# Return the event stream of the local Session for `sid`, creating and
# registering a new Session first when none is registered on this node.
#
# sid: key used in self._local_agent_loops_by_sid.
# conversation_init_data: forwarded unchanged to Session.initialize_agent;
# may be None.
# Raises RuntimeError('no_session:<sid>') when no Session is registered for
# `sid` at the final lookup.
logger.info(f'maybe_start_agent_loop:{sid}')
session: Session | None = None
if not await self._is_agent_loop_running_locally(sid):
session = Session(
sid=sid, file_store=self.file_store, config=self.config, sio=self.sio
)
# The Session is registered before the cluster check and before
# initialize_agent is awaited.
# NOTE(review): if initialize_agent raises, nothing in this method removes
# the entry again - confirm cleanup happens elsewhere.
self._local_agent_loops_by_sid[sid] = session
# NOTE(review): indentation is not visible in this view; presumably this
# check is nested under the branch above, since it uses the `session`
# created there - confirm.
# The agent is only initialized when no other node reports a running loop.
if not await self._is_agent_loop_running_in_cluster(sid):
logger.info(f'start_agent_loop:{sid}')
await session.initialize_agent(conversation_init_data)
# Re-read the registry so the "already running locally" path (where
# `session` is still None) returns the existing Session's stream.
session = self._local_agent_loops_by_sid.get(sid)
if session is not None:
return session.agent_session.event_stream
raise RuntimeError(f'no_session:{sid}')
async def send_to_event_stream(self, connection_id: str, data: dict):
# If there is a local session running, send to that
@ -300,7 +321,7 @@ class SessionManager:
if not sid:
raise RuntimeError(f'no_connected_session:{connection_id}')
session = self.local_sessions_by_sid.get(sid)
session = self._local_agent_loops_by_sid.get(sid)
if session:
await session.dispatch(data)
return
@ -310,7 +331,7 @@ class SessionManager:
# If we have a recent report that the session is alive in another pod
last_alive_at = self._last_alive_timestamps.get(sid) or 0
next_alive_check = last_alive_at + _CHECK_ALIVE_INTERVAL
if next_alive_check > time.time() or self._is_session_running_in_cluster(
if next_alive_check > time.time() or self._is_agent_loop_running_in_cluster(
sid
):
# Send the event to the other pod
@ -387,7 +408,7 @@ class SessionManager:
for connnnection_id in connection_ids_to_remove:
self.local_connection_id_to_session_id.pop(connnnection_id, None)
session = self.local_sessions_by_sid.pop(sid, None)
session = self._local_agent_loops_by_sid.pop(sid, None)
if not session:
logger.warning(f'no_session_to_close:{sid}')
return

View File

@ -1,4 +1,5 @@
import asyncio
import json
import time
from copy import deepcopy
@ -21,8 +22,10 @@ from openhands.events.serialization import event_from_dict, event_to_dict
from openhands.events.stream import EventStreamSubscriber
from openhands.llm.llm import LLM
from openhands.server.session.agent_session import AgentSession
from openhands.server.session.session_init_data import SessionInitData
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.storage.files import FileStore
from openhands.storage.locations import get_conversation_init_data_filename
from openhands.utils.async_utils import call_sync_from_async
ROOM_KEY = 'room:{sid}'
@ -35,6 +38,7 @@ class Session:
agent_session: AgentSession
loop: asyncio.AbstractEventLoop
config: AppConfig
file_store: FileStore
def __init__(
self,
@ -46,6 +50,7 @@ class Session:
self.sid = sid
self.sio = sio
self.last_active_ts = int(time.time())
self.file_store = file_store
self.agent_session = AgentSession(
sid, file_store, status_callback=self.queue_status_message
)
@ -60,35 +65,63 @@ class Session:
self.is_alive = False
self.agent_session.close()
async def initialize_agent(self, session_init_data: SessionInitData):
async def _restore_init_data(self, sid: str) -> ConversationInitData:
    """Load the persisted init data for conversation ``sid``.

    Reads the JSON document stored at
    ``get_conversation_init_data_filename(sid)`` through ``self.file_store``
    and passes its top-level keys to ``ConversationInitData`` as keyword
    arguments.
    """
    # FIXME: we should not store/restore this data once we have server-side
    # LLM configs. Should be done by 1/1/2025
    init_path = get_conversation_init_data_filename(sid)
    raw_json = await call_sync_from_async(self.file_store.read, init_path)
    fields = json.loads(raw_json)
    return ConversationInitData(**fields)
async def _save_init_data(self, sid: str, init_data: ConversationInitData):
    """Persist ``init_data`` for conversation ``sid``.

    Serializes ``init_data.__dict__`` to JSON and writes it through
    ``self.file_store`` at ``get_conversation_init_data_filename(sid)``.
    """
    # FIXME: we should not store/restore this data once we have server-side
    # LLM configs. Should be done by 1/1/2025
    payload = json.dumps(init_data.__dict__)
    init_path = get_conversation_init_data_filename(sid)
    await call_sync_from_async(self.file_store.write, init_path, payload)
async def initialize_agent(
self, conversation_init_data: ConversationInitData | None = None
):
self.agent_session.event_stream.add_event(
AgentStateChangedObservation('', AgentState.LOADING),
EventSource.ENVIRONMENT,
)
# Extract the agent-relevant arguments from the request
agent_cls = session_init_data.agent or self.config.default_agent
if conversation_init_data is None:
try:
conversation_init_data = await self._restore_init_data(self.sid)
except FileNotFoundError:
logger.error(f'User settings not found for session {self.sid}')
raise RuntimeError('User settings not found')
agent_cls = conversation_init_data.agent or self.config.default_agent
self.config.security.confirmation_mode = (
self.config.security.confirmation_mode
if session_init_data.confirmation_mode is None
else session_init_data.confirmation_mode
if conversation_init_data.confirmation_mode is None
else conversation_init_data.confirmation_mode
)
self.config.security.security_analyzer = (
session_init_data.security_analyzer
conversation_init_data.security_analyzer
or self.config.security.security_analyzer
)
max_iterations = session_init_data.max_iterations or self.config.max_iterations
max_iterations = (
conversation_init_data.max_iterations or self.config.max_iterations
)
# override default LLM config
default_llm_config = self.config.get_llm_config()
default_llm_config.model = (
session_init_data.llm_model or default_llm_config.model
conversation_init_data.llm_model or default_llm_config.model
)
default_llm_config.api_key = (
session_init_data.llm_api_key or default_llm_config.api_key
conversation_init_data.llm_api_key or default_llm_config.api_key
)
default_llm_config.base_url = (
session_init_data.llm_base_url or default_llm_config.base_url
conversation_init_data.llm_base_url or default_llm_config.base_url
)
await self._save_init_data(self.sid, conversation_init_data)
# TODO: override other LLM config & agent config groups (#2075)
@ -105,8 +138,8 @@ class Session:
max_budget_per_task=self.config.max_budget_per_task,
agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
agent_configs=self.config.get_agent_configs(),
github_token=session_init_data.github_token,
selected_repository=session_init_data.selected_repository,
github_token=conversation_init_data.github_token,
selected_repository=conversation_init_data.selected_repository,
)
except Exception as e:
logger.exception(f'Error creating controller: {e}')

View File

@ -15,3 +15,7 @@ def get_conversation_event_filename(sid: str, id: int) -> str:
def get_conversation_metadata_filename(sid: str) -> str:
    """Return the path of ``metadata.json`` under the conversation directory for ``sid``."""
    conversation_dir = get_conversation_dir(sid)
    return f'{conversation_dir}metadata.json'
def get_conversation_init_data_filename(sid: str) -> str:
    """Return the path of ``init.json`` under the conversation directory for ``sid``."""
    conversation_dir = get_conversation_dir(sid)
    return f'{conversation_dir}init.json'

View File

@ -8,7 +8,7 @@ from openhands.server.settings import Settings
class SettingsStore(ABC):
"""
Storage for SessionInitData. May or may not support multiple users depending on the environment
Storage for ConversationInitData. May or may not support multiple users depending on the environment
"""
@abstractmethod

View File

@ -6,8 +6,8 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from openhands.core.config.app_config import AppConfig
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.server.session.manager import SessionManager
from openhands.server.session.session_init_data import SessionInitData
from openhands.storage.memory import InMemoryFileStore
@ -41,7 +41,7 @@ async def test_session_not_running_in_cluster():
async with SessionManager(
sio, AppConfig(), InMemoryFileStore()
) as session_manager:
result = await session_manager._is_session_running_in_cluster(
result = await session_manager._is_agent_loop_running_in_cluster(
'non-existant-session'
)
assert result is False
@ -65,7 +65,7 @@ async def test_session_is_running_in_cluster():
async with SessionManager(
sio, AppConfig(), InMemoryFileStore()
) as session_manager:
result = await session_manager._is_session_running_in_cluster(
result = await session_manager._is_agent_loop_running_in_cluster(
'existing-session'
)
assert result is True
@ -83,8 +83,8 @@ async def test_init_new_local_session():
mock_session = MagicMock()
mock_session.return_value = session_instance
sio = get_mock_sio()
is_session_running_in_cluster_mock = AsyncMock()
is_session_running_in_cluster_mock.return_value = False
is_agent_loop_running_in_cluster_mock = AsyncMock()
is_agent_loop_running_in_cluster_mock.return_value = False
with (
patch('openhands.server.session.manager.Session', mock_session),
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1),
@ -93,14 +93,16 @@ async def test_init_new_local_session():
AsyncMock(),
),
patch(
'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
is_session_running_in_cluster_mock,
'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
is_agent_loop_running_in_cluster_mock,
),
):
async with SessionManager(
sio, AppConfig(), InMemoryFileStore()
) as session_manager:
await session_manager.start_agent_loop('new-session-id', SessionInitData())
await session_manager.maybe_start_agent_loop(
'new-session-id', ConversationInitData()
)
await session_manager.join_conversation('new-session-id', 'new-session-id')
assert session_instance.initialize_agent.call_count == 1
assert sio.enter_room.await_count == 1
@ -113,8 +115,8 @@ async def test_join_local_session():
mock_session = MagicMock()
mock_session.return_value = session_instance
sio = get_mock_sio()
is_session_running_in_cluster_mock = AsyncMock()
is_session_running_in_cluster_mock.return_value = False
is_agent_loop_running_in_cluster_mock = AsyncMock()
is_agent_loop_running_in_cluster_mock.return_value = False
with (
patch('openhands.server.session.manager.Session', mock_session),
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
@ -123,14 +125,16 @@ async def test_join_local_session():
AsyncMock(),
),
patch(
'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
is_session_running_in_cluster_mock,
'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
is_agent_loop_running_in_cluster_mock,
),
):
async with SessionManager(
sio, AppConfig(), InMemoryFileStore()
) as session_manager:
await session_manager.start_agent_loop('new-session-id', SessionInitData())
await session_manager.maybe_start_agent_loop(
'new-session-id', ConversationInitData()
)
await session_manager.join_conversation('new-session-id', 'new-session-id')
await session_manager.join_conversation('new-session-id', 'new-session-id')
assert session_instance.initialize_agent.call_count == 1
@ -144,8 +148,8 @@ async def test_join_cluster_session():
mock_session = MagicMock()
mock_session.return_value = session_instance
sio = get_mock_sio()
is_session_running_in_cluster_mock = AsyncMock()
is_session_running_in_cluster_mock.return_value = True
is_agent_loop_running_in_cluster_mock = AsyncMock()
is_agent_loop_running_in_cluster_mock.return_value = True
with (
patch('openhands.server.session.manager.Session', mock_session),
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
@ -154,8 +158,8 @@ async def test_join_cluster_session():
AsyncMock(),
),
patch(
'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
is_session_running_in_cluster_mock,
'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
is_agent_loop_running_in_cluster_mock,
),
):
async with SessionManager(
@ -173,8 +177,8 @@ async def test_add_to_local_event_stream():
mock_session = MagicMock()
mock_session.return_value = session_instance
sio = get_mock_sio()
is_session_running_in_cluster_mock = AsyncMock()
is_session_running_in_cluster_mock.return_value = False
is_agent_loop_running_in_cluster_mock = AsyncMock()
is_agent_loop_running_in_cluster_mock.return_value = False
with (
patch('openhands.server.session.manager.Session', mock_session),
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
@ -183,14 +187,16 @@ async def test_add_to_local_event_stream():
AsyncMock(),
),
patch(
'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
is_session_running_in_cluster_mock,
'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
is_agent_loop_running_in_cluster_mock,
),
):
async with SessionManager(
sio, AppConfig(), InMemoryFileStore()
) as session_manager:
await session_manager.start_agent_loop('new-session-id', SessionInitData())
await session_manager.maybe_start_agent_loop(
'new-session-id', ConversationInitData()
)
await session_manager.join_conversation('new-session-id', 'connection-id')
await session_manager.send_to_event_stream(
'connection-id', {'event_type': 'some_event'}
@ -205,8 +211,8 @@ async def test_add_to_cluster_event_stream():
mock_session = MagicMock()
mock_session.return_value = session_instance
sio = get_mock_sio()
is_session_running_in_cluster_mock = AsyncMock()
is_session_running_in_cluster_mock.return_value = True
is_agent_loop_running_in_cluster_mock = AsyncMock()
is_agent_loop_running_in_cluster_mock.return_value = True
with (
patch('openhands.server.session.manager.Session', mock_session),
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
@ -215,8 +221,8 @@ async def test_add_to_cluster_event_stream():
AsyncMock(),
),
patch(
'openhands.server.session.manager.SessionManager._is_session_running_in_cluster',
is_session_running_in_cluster_mock,
'openhands.server.session.manager.SessionManager._is_agent_loop_running_in_cluster',
is_agent_loop_running_in_cluster_mock,
),
):
async with SessionManager(