From 366fd7ab8abbcf3177c0c561f5cc66661bc06bfe Mon Sep 17 00:00:00 2001 From: Robert Brennan Date: Fri, 7 Mar 2025 11:19:50 -0500 Subject: [PATCH] Improve agent loop tracking and make concurrent limit configurable (#6945) Co-authored-by: openhands Co-authored-by: Tim O'Farrell --- openhands/core/config/app_config.py | 3 + openhands/llm/llm.py | 6 +- .../conversation_manager.py | 8 +- .../standalone_conversation_manager.py | 100 +++++++++++++++--- .../server/routes/manage_conversations.py | 40 +------ openhands/server/session/agent_session.py | 25 +++-- openhands/server/session/session.py | 3 +- openhands/server/shared.py | 13 +-- .../conversation/conversation_store.py | 8 ++ tests/unit/test_file_conversation_store.py | 35 ++++++ 10 files changed, 164 insertions(+), 77 deletions(-) diff --git a/openhands/core/config/app_config.py b/openhands/core/config/app_config.py index 3f12bb4c30..b0cf1a8037 100644 --- a/openhands/core/config/app_config.py +++ b/openhands/core/config/app_config.py @@ -83,6 +83,9 @@ class AppConfig(BaseModel): cli_multiline_input: bool = Field(default=False) conversation_max_age_seconds: int = Field(default=864000) # 10 days in seconds enable_default_condenser: bool = Field(default=True) + max_concurrent_conversations: int = Field( + default=3 + ) # Maximum number of concurrent agent loops allowed per user defaults_dict: ClassVar[dict] = {} diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py index c392060c25..bed6424278 100644 --- a/openhands/llm/llm.py +++ b/openhands/llm/llm.py @@ -8,7 +8,6 @@ from typing import Any, Callable import requests from openhands.core.config import LLMConfig -from openhands.utils.ensure_httpx_close import ensure_httpx_close with warnings.catch_warnings(): warnings.simplefilter('ignore') @@ -238,9 +237,8 @@ class LLM(RetryMixin, DebugMixin): # Record start time for latency measurement start_time = time.time() - with ensure_httpx_close(): - # we don't support streaming here, thus we get a ModelResponse - resp: 
ModelResponse = self._completion_unwrapped(*args, **kwargs) + # we don't support streaming here, thus we get a ModelResponse + resp: ModelResponse = self._completion_unwrapped(*args, **kwargs) # Calculate and record latency latency = time.time() - start_time diff --git a/openhands/server/conversation_manager/conversation_manager.py b/openhands/server/conversation_manager/conversation_manager.py index e96458c5eb..35c207de9f 100644 --- a/openhands/server/conversation_manager/conversation_manager.py +++ b/openhands/server/conversation_manager/conversation_manager.py @@ -7,8 +7,11 @@ import socketio from openhands.core.config import AppConfig from openhands.events.action import MessageAction from openhands.events.stream import EventStream +from openhands.server.config.server_config import ServerConfig +from openhands.server.monitoring import MonitoringListener from openhands.server.session.conversation import Conversation from openhands.server.settings import Settings +from openhands.storage.conversation.conversation_store import ConversationStore from openhands.storage.files import FileStore @@ -23,6 +26,7 @@ class ConversationManager(ABC): sio: socketio.AsyncServer config: AppConfig file_store: FileStore + conversation_store: ConversationStore @abstractmethod async def __aenter__(self): @@ -92,5 +96,7 @@ class ConversationManager(ABC): sio: socketio.AsyncServer, config: AppConfig, file_store: FileStore, + server_config: ServerConfig, + monitoring_listener: MonitoringListener, ) -> ConversationManager: - """Get a store for the user represented by the token given""" + """Get a conversation manager instance""" diff --git a/openhands/server/conversation_manager/standalone_conversation_manager.py b/openhands/server/conversation_manager/standalone_conversation_manager.py index b5efd0f812..5fae0d7d9a 100644 --- a/openhands/server/conversation_manager/standalone_conversation_manager.py +++ b/openhands/server/conversation_manager/standalone_conversation_manager.py @@ -1,7 
+1,8 @@ import asyncio import time from dataclasses import dataclass, field -from typing import Iterable +from datetime import datetime, timezone +from typing import Callable, Iterable, Type import socketio @@ -11,20 +12,24 @@ from openhands.core.logger import openhands_logger as logger from openhands.core.schema.agent import AgentState from openhands.events.action import MessageAction from openhands.events.observation.agent import AgentStateChangedObservation -from openhands.events.stream import EventStream, session_exists +from openhands.events.stream import EventStream, EventStreamSubscriber, session_exists +from openhands.server.config.server_config import ServerConfig from openhands.server.monitoring import MonitoringListener from openhands.server.session.agent_session import WAIT_TIME_BEFORE_CLOSE from openhands.server.session.conversation import Conversation from openhands.server.session.session import ROOM_KEY, Session from openhands.server.settings import Settings +from openhands.storage.conversation.conversation_store import ConversationStore +from openhands.storage.data_models.conversation_metadata import ConversationMetadata from openhands.storage.files import FileStore -from openhands.utils.async_utils import wait_all +from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync, wait_all +from openhands.utils.import_utils import get_impl from openhands.utils.shutdown_listener import should_continue from .conversation_manager import ConversationManager _CLEANUP_INTERVAL = 15 -MAX_RUNNING_CONVERSATIONS = 3 +UPDATED_AT_CALLBACK_ID = 'updated_at_callback_id' @dataclass @@ -34,6 +39,7 @@ class StandaloneConversationManager(ConversationManager): sio: socketio.AsyncServer config: AppConfig file_store: FileStore + server_config: ServerConfig # Defaulting monitoring_listener for temp backward compatibility. 
monitoring_listener: MonitoringListener = MonitoringListener() _local_agent_loops_by_sid: dict[str, Session] = field(default_factory=dict) @@ -46,6 +52,7 @@ class StandaloneConversationManager(ConversationManager): ) _conversations_lock: asyncio.Lock = field(default_factory=asyncio.Lock) _cleanup_task: asyncio.Task | None = None + _conversation_store_class: Type | None = None async def __aenter__(self): self._cleanup_task = asyncio.create_task(self._cleanup_stale()) @@ -146,7 +153,7 @@ class StandaloneConversationManager(ConversationManager): sid_to_close.append(sid) connections = await self.get_connections( - filter_to_sids=set(sid_to_close) + filter_to_sids=set(sid_to_close) # get_connections expects a set ) connected_sids = {sid for _, sid in connections.items()} sid_to_close = [ @@ -170,15 +177,36 @@ class StandaloneConversationManager(ConversationManager): logger.error('error_cleaning_stale') await asyncio.sleep(_CLEANUP_INTERVAL) + async def _get_conversation_store(self, user_id: str | None) -> ConversationStore: + conversation_store_class = self._conversation_store_class + if not conversation_store_class: + self._conversation_store_class = conversation_store_class = get_impl( + ConversationStore, # type: ignore + self.server_config.conversation_store_class, + ) + store = await conversation_store_class.get_instance(self.config, user_id) + return store + async def get_running_agent_loops( self, user_id: str | None = None, filter_to_sids: set[str] | None = None ) -> set[str]: - """Get the running session ids. If a user is supplied, then the results are limited to session ids for that user. If a set of filter_to_sids is supplied, then results are limited to these ids of interest.""" + """Get the ids of the running sessions as an unordered set. + + If a user is supplied, then the results are limited to session ids for that user. + If a set of filter_to_sids is supplied, then results are limited to these ids of interest. 
+ + Returns: + A set of session IDs + """ + # Start from every locally tracked agent loop items: Iterable[tuple[str, Session]] = self._local_agent_loops_by_sid.items() + + # Filter items if needed if filter_to_sids is not None: items = (item for item in items if item[0] in filter_to_sids) if user_id: items = (item for item in items if item[1].user_id == user_id) + sids = {sid for sid, _ in items} return sids @@ -212,12 +240,16 @@ class StandaloneConversationManager(ConversationManager): logger.info(f'start_agent_loop:{sid}') response_ids = await self.get_running_agent_loops(user_id) - if len(response_ids) >= MAX_RUNNING_CONVERSATIONS: + if len(response_ids) >= self.config.max_concurrent_conversations: logger.info('too_many_sessions_for:{user_id}') - # Order is not guaranteed, but response_ids tend to be in descending chronological order - # By reversing, we are likely to pick the oldest (or at least an older) conversation - session_id = next(iter(reversed(list(response_ids)))) - await self.close_session(session_id) + # Sort the conversations newest first so that pop() removes the oldest + conversation_store = await self._get_conversation_store(user_id) + conversations = await conversation_store.get_all_metadata(response_ids) + conversations.sort(key=_last_updated_at_key, reverse=True) + + while len(conversations) >= self.config.max_concurrent_conversations: + oldest_conversation_id = conversations.pop().conversation_id + await self.close_session(oldest_conversation_id) session = Session( sid=sid, @@ -229,6 +261,15 @@ class StandaloneConversationManager(ConversationManager): ) self._local_agent_loops_by_sid[sid] = session asyncio.create_task(session.initialize_agent(settings, initial_user_msg)) + # This does not get added when resuming an existing conversation + try: + session.agent_session.event_stream.subscribe( + EventStreamSubscriber.SERVER, + self._create_conversation_update_callback(user_id, sid), + UPDATED_AT_CALLBACK_ID, + ) + except ValueError: + pass # Already subscribed - take 
no action event_stream = await self._get_event_stream(sid) if not event_stream: @@ -298,8 +339,41 @@ class StandaloneConversationManager(ConversationManager): sio: socketio.AsyncServer, config: AppConfig, file_store: FileStore, - monitoring_listener: MonitoringListener | None = None, + server_config: ServerConfig, + monitoring_listener: MonitoringListener | None, ) -> ConversationManager: return StandaloneConversationManager( - sio, config, file_store, monitoring_listener or MonitoringListener() + sio, + config, + file_store, + server_config, + monitoring_listener or MonitoringListener(), ) + + def _create_conversation_update_callback( + self, user_id: str | None, conversation_id: str + ) -> Callable: + def callback(*args, **kwargs): + call_async_from_sync( + self._update_timestamp_for_conversation, + GENERAL_TIMEOUT, + user_id, + conversation_id, + ) + + return callback + + async def _update_timestamp_for_conversation( + self, user_id: str, conversation_id: str + ): + conversation_store = await self._get_conversation_store(user_id) + conversation = await conversation_store.get_metadata(conversation_id) + conversation.last_updated_at = datetime.now(timezone.utc) + await conversation_store.save_metadata(conversation) + + +def _last_updated_at_key(conversation: ConversationMetadata) -> float: + last_updated_at = conversation.last_updated_at + if last_updated_at is None: + return 0.0 + return last_updated_at.timestamp() diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index e74c6019b2..d218ae5213 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -1,6 +1,5 @@ import uuid from datetime import datetime, timezone -from typing import Callable from fastapi import APIRouter, Body, Request, status from fastapi.responses import JSONResponse @@ -8,7 +7,6 @@ from pydantic import BaseModel, SecretStr from openhands.core.logger import openhands_logger 
as logger from openhands.events.action.message import MessageAction -from openhands.events.stream import EventStreamSubscriber from openhands.integrations.github.github_service import GithubServiceImpl from openhands.runtime import get_runtime_cls from openhands.server.auth import get_github_token, get_idp_token, get_user_id @@ -27,14 +25,9 @@ from openhands.server.shared import ( from openhands.server.types import LLMAuthenticationError, MissingSettingsError from openhands.storage.data_models.conversation_metadata import ConversationMetadata from openhands.storage.data_models.conversation_status import ConversationStatus -from openhands.utils.async_utils import ( - GENERAL_TIMEOUT, - call_async_from_sync, - wait_all, -) +from openhands.utils.async_utils import wait_all app = APIRouter(prefix='/api') -UPDATED_AT_CALLBACK_ID = 'updated_at_callback_id' class InitSessionRequest(BaseModel): @@ -119,17 +112,9 @@ async def _create_new_conversation( content=user_msg or '', image_urls=image_urls or [], ) - event_stream = await conversation_manager.maybe_start_agent_loop( + await conversation_manager.maybe_start_agent_loop( conversation_id, conversation_init_data, user_id, initial_message_action ) - try: - event_stream.subscribe( - EventStreamSubscriber.SERVER, - _create_conversation_update_callback(user_id, conversation_id), - UPDATED_AT_CALLBACK_ID, - ) - except ValueError: - pass # Already subscribed - take no action logger.info(f'Finished initializing conversation {conversation_id}') return conversation_id @@ -307,24 +292,3 @@ async def _get_conversation_info( f'Error loading conversation {conversation.conversation_id}: {str(e)}', ) return None - - -def _create_conversation_update_callback( - user_id: str | None, conversation_id: str -) -> Callable: - def callback(*args, **kwargs): - call_async_from_sync( - _update_timestamp_for_conversation, - GENERAL_TIMEOUT, - user_id, - conversation_id, - ) - - return callback - - -async def 
_update_timestamp_for_conversation(user_id: str, conversation_id: str): - conversation_store = await ConversationStoreImpl.get_instance(config, user_id) - conversation = await conversation_store.get_metadata(conversation_id) - conversation.last_updated_at = datetime.now(timezone.utc) - await conversation_store.save_metadata(conversation) diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py index 8f382f7dc5..cb100e2814 100644 --- a/openhands/server/session/agent_session.py +++ b/openhands/server/session/agent_session.py @@ -131,16 +131,18 @@ class AgentSession: 'github_token': github_token.get_secret_value(), } ) - if initial_message: - self.event_stream.add_event(initial_message, EventSource.USER) - self.event_stream.add_event( - ChangeAgentStateAction(AgentState.RUNNING), EventSource.ENVIRONMENT - ) - else: - self.event_stream.add_event( - ChangeAgentStateAction(AgentState.AWAITING_USER_INPUT), - EventSource.ENVIRONMENT, - ) + if not self._closed: + if initial_message: + self.event_stream.add_event(initial_message, EventSource.USER) + self.event_stream.add_event( + ChangeAgentStateAction(AgentState.RUNNING), + EventSource.ENVIRONMENT, + ) + else: + self.event_stream.add_event( + ChangeAgentStateAction(AgentState.AWAITING_USER_INPUT), + EventSource.ENVIRONMENT, + ) finished = True finally: self._starting = False @@ -360,3 +362,6 @@ class AgentSession: # If 5 minutes have elapsed and we still don't have a controller, something has gone wrong return AgentState.ERROR return None + + def is_closed(self) -> bool: + return self._closed diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py index 72b123b41d..720a60f0ee 100644 --- a/openhands/server/session/session.py +++ b/openhands/server/session/session.py @@ -247,8 +247,9 @@ class Session: async def _send_status_message(self, msg_type: str, id: str, message: str): """Sends a status message to the client.""" if msg_type == 'error': + 
agent_session = self.agent_session controller = self.agent_session.controller - if controller is not None: + if controller is not None and not agent_session.is_closed(): await controller.set_agent_state_to(AgentState.ERROR) await self.send( {'status_update': True, 'type': msg_type, 'id': id, 'message': message} diff --git a/openhands/server/shared.py b/openhands/server/shared.py index 2a472e4121..c53e73fb45 100644 --- a/openhands/server/shared.py +++ b/openhands/server/shared.py @@ -1,4 +1,3 @@ -import inspect import os import socketio @@ -46,15 +45,9 @@ ConversationManagerImpl = get_impl( server_config.conversation_manager_class, ) -if len(inspect.signature(ConversationManagerImpl.get_instance).parameters) == 3: - # This conditional prevents a breaking change in February 2025. - # It should be safe to remove by April. - conversation_manager = ConversationManagerImpl.get_instance(sio, config, file_store) -else: - # This is the new signature. - conversation_manager = ConversationManagerImpl.get_instance( # type: ignore - sio, config, file_store, monitoring_listener - ) +conversation_manager = ConversationManagerImpl.get_instance( # type: ignore + sio, config, file_store, server_config, monitoring_listener +) SettingsStoreImpl = get_impl(SettingsStore, server_config.settings_store_class) # type: ignore diff --git a/openhands/storage/conversation/conversation_store.py b/openhands/storage/conversation/conversation_store.py index efe95b1b69..e74de6fcf7 100644 --- a/openhands/storage/conversation/conversation_store.py +++ b/openhands/storage/conversation/conversation_store.py @@ -1,12 +1,14 @@ from __future__ import annotations from abc import ABC, abstractmethod +from typing import Iterable from openhands.core.config.app_config import AppConfig from openhands.storage.data_models.conversation_metadata import ConversationMetadata from openhands.storage.data_models.conversation_metadata_result_set import ( ConversationMetadataResultSet, ) +from openhands.utils.async_utils 
import wait_all class ConversationStore(ABC): @@ -38,6 +40,12 @@ class ConversationStore(ABC): ) -> ConversationMetadataResultSet: """Search conversations""" + async def get_all_metadata( + self, conversation_ids: Iterable[str] + ) -> list[ConversationMetadata]: + """Get metadata for multiple conversations in parallel""" + return await wait_all([self.get_metadata(cid) for cid in conversation_ids]) + @classmethod @abstractmethod async def get_instance( diff --git a/tests/unit/test_file_conversation_store.py b/tests/unit/test_file_conversation_store.py index 80c391daca..fe49cd5fef 100644 --- a/tests/unit/test_file_conversation_store.py +++ b/tests/unit/test_file_conversation_store.py @@ -162,3 +162,38 @@ async def test_search_with_invalid_conversation(): assert len(result.results) == 1 assert result.results[0].conversation_id == 'conv1' assert result.next_page_id is None + + +@pytest.mark.asyncio +async def test_get_all_metadata(): + store = FileConversationStore( + InMemoryFileStore( + { + 'sessions/conv1/metadata.json': json.dumps( + { + 'conversation_id': 'conv1', + 'github_user_id': '123', + 'selected_repository': 'repo1', + 'title': 'First conversation', + 'created_at': '2025-01-16T19:51:04Z', + } + ), + 'sessions/conv2/metadata.json': json.dumps( + { + 'conversation_id': 'conv2', + 'github_user_id': '123', + 'selected_repository': 'repo1', + 'title': 'Second conversation', + 'created_at': '2025-01-17T19:51:04Z', + } + ), + } + ) + ) + + results = await store.get_all_metadata(['conv1', 'conv2']) + assert len(results) == 2 + assert results[0].conversation_id == 'conv1' + assert results[0].title == 'First conversation' + assert results[1].conversation_id == 'conv2' + assert results[1].title == 'Second conversation'