mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Improve agent loop tracking and make concurrent limit configurable (#6945)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: Tim O'Farrell <tofarr@gmail.com>
This commit is contained in:
parent
318fcbcfc7
commit
366fd7ab8a
@ -83,6 +83,9 @@ class AppConfig(BaseModel):
|
||||
cli_multiline_input: bool = Field(default=False)
|
||||
conversation_max_age_seconds: int = Field(default=864000) # 10 days in seconds
|
||||
enable_default_condenser: bool = Field(default=True)
|
||||
max_concurrent_conversations: int = Field(
|
||||
default=3
|
||||
) # Maximum number of concurrent agent loops allowed per user
|
||||
|
||||
defaults_dict: ClassVar[dict] = {}
|
||||
|
||||
|
||||
@ -8,7 +8,6 @@ from typing import Any, Callable
|
||||
import requests
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.utils.ensure_httpx_close import ensure_httpx_close
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore')
|
||||
@ -238,9 +237,8 @@ class LLM(RetryMixin, DebugMixin):
|
||||
|
||||
# Record start time for latency measurement
|
||||
start_time = time.time()
|
||||
with ensure_httpx_close():
|
||||
# we don't support streaming here, thus we get a ModelResponse
|
||||
resp: ModelResponse = self._completion_unwrapped(*args, **kwargs)
|
||||
# we don't support streaming here, thus we get a ModelResponse
|
||||
resp: ModelResponse = self._completion_unwrapped(*args, **kwargs)
|
||||
|
||||
# Calculate and record latency
|
||||
latency = time.time() - start_time
|
||||
|
||||
@ -7,8 +7,11 @@ import socketio
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
from openhands.server.monitoring import MonitoringListener
|
||||
from openhands.server.session.conversation import Conversation
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
|
||||
@ -23,6 +26,7 @@ class ConversationManager(ABC):
|
||||
sio: socketio.AsyncServer
|
||||
config: AppConfig
|
||||
file_store: FileStore
|
||||
conversation_store: ConversationStore
|
||||
|
||||
@abstractmethod
|
||||
async def __aenter__(self):
|
||||
@ -92,5 +96,7 @@ class ConversationManager(ABC):
|
||||
sio: socketio.AsyncServer,
|
||||
config: AppConfig,
|
||||
file_store: FileStore,
|
||||
server_config: ServerConfig,
|
||||
monitoring_listener: MonitoringListener,
|
||||
) -> ConversationManager:
|
||||
"""Get a store for the user represented by the token given"""
|
||||
"""Get a conversation manager instance"""
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterable
|
||||
from datetime import datetime, timezone
|
||||
from typing import Callable, Iterable, Type
|
||||
|
||||
import socketio
|
||||
|
||||
@ -11,20 +12,24 @@ from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.events.stream import EventStream, session_exists
|
||||
from openhands.events.stream import EventStream, EventStreamSubscriber, session_exists
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
from openhands.server.monitoring import MonitoringListener
|
||||
from openhands.server.session.agent_session import WAIT_TIME_BEFORE_CLOSE
|
||||
from openhands.server.session.conversation import Conversation
|
||||
from openhands.server.session.session import ROOM_KEY, Session
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.utils.async_utils import wait_all
|
||||
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync, wait_all
|
||||
from openhands.utils.import_utils import get_impl
|
||||
from openhands.utils.shutdown_listener import should_continue
|
||||
|
||||
from .conversation_manager import ConversationManager
|
||||
|
||||
_CLEANUP_INTERVAL = 15
|
||||
MAX_RUNNING_CONVERSATIONS = 3
|
||||
UPDATED_AT_CALLBACK_ID = 'updated_at_callback_id'
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -34,6 +39,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
sio: socketio.AsyncServer
|
||||
config: AppConfig
|
||||
file_store: FileStore
|
||||
server_config: ServerConfig
|
||||
# Defaulting monitoring_listener for temp backward compatibility.
|
||||
monitoring_listener: MonitoringListener = MonitoringListener()
|
||||
_local_agent_loops_by_sid: dict[str, Session] = field(default_factory=dict)
|
||||
@ -46,6 +52,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
)
|
||||
_conversations_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
_cleanup_task: asyncio.Task | None = None
|
||||
_conversation_store_class: Type | None = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_stale())
|
||||
@ -146,7 +153,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
sid_to_close.append(sid)
|
||||
|
||||
connections = await self.get_connections(
|
||||
filter_to_sids=set(sid_to_close)
|
||||
filter_to_sids=set(sid_to_close) # get_connections expects a set
|
||||
)
|
||||
connected_sids = {sid for _, sid in connections.items()}
|
||||
sid_to_close = [
|
||||
@ -170,15 +177,36 @@ class StandaloneConversationManager(ConversationManager):
|
||||
logger.error('error_cleaning_stale')
|
||||
await asyncio.sleep(_CLEANUP_INTERVAL)
|
||||
|
||||
async def _get_conversation_store(self, user_id: str | None) -> ConversationStore:
    """Return the conversation store instance for ``user_id``.

    The store implementation class is resolved through ``get_impl`` from
    ``self.server_config.conversation_store_class`` the first time it is
    needed, then cached on ``self._conversation_store_class``.

    Args:
        user_id: Passed through unchanged to the store class's
            ``get_instance``; may be ``None``.

    Returns:
        The result of ``get_instance(self.config, user_id)`` on the resolved
        store class.
    """
    store_cls = self._conversation_store_class
    if not store_cls:
        # First use: resolve the configured implementation and remember it.
        store_cls = get_impl(
            ConversationStore,  # type: ignore
            self.server_config.conversation_store_class,
        )
        self._conversation_store_class = store_cls
    return await store_cls.get_instance(self.config, user_id)
|
||||
|
||||
async def get_running_agent_loops(
    self, user_id: str | None = None, filter_to_sids: set[str] | None = None
) -> set[str]:
    """Get the session ids tracked in ``self._local_agent_loops_by_sid``.

    If a user is supplied, the results are limited to session ids for that
    user. If a set of filter_to_sids is supplied, the results are limited to
    these ids of interest.

    Args:
        user_id: When truthy, keep only sessions whose ``user_id`` matches.
        filter_to_sids: When not ``None``, keep only session ids in this set
            (an empty set therefore yields an empty result).

    Returns:
        A set of session IDs. A set carries no ordering, so no chronological
        order is implied by the result.
    """
    return {
        sid
        for sid, session in self._local_agent_loops_by_sid.items()
        if (filter_to_sids is None or sid in filter_to_sids)
        and (not user_id or session.user_id == user_id)
    }
|
||||
|
||||
@ -212,12 +240,16 @@ class StandaloneConversationManager(ConversationManager):
|
||||
logger.info(f'start_agent_loop:{sid}')
|
||||
|
||||
response_ids = await self.get_running_agent_loops(user_id)
|
||||
if len(response_ids) >= MAX_RUNNING_CONVERSATIONS:
|
||||
if len(response_ids) >= self.config.max_concurrent_conversations:
|
||||
logger.info('too_many_sessions_for:{user_id}')
|
||||
# Order is not guaranteed, but response_ids tend to be in descending chronological order
|
||||
# By reversing, we are likely to pick the oldest (or at least an older) conversation
|
||||
session_id = next(iter(reversed(list(response_ids))))
|
||||
await self.close_session(session_id)
|
||||
# Get the conversations sorted (oldest first)
|
||||
conversation_store = await self._get_conversation_store(user_id)
|
||||
conversations = await conversation_store.get_all_metadata(response_ids)
|
||||
conversations.sort(key=_last_updated_at_key, reverse=True)
|
||||
|
||||
while len(conversations) >= self.config.max_concurrent_conversations:
|
||||
oldest_conversation_id = conversations.pop().conversation_id
|
||||
await self.close_session(oldest_conversation_id)
|
||||
|
||||
session = Session(
|
||||
sid=sid,
|
||||
@ -229,6 +261,15 @@ class StandaloneConversationManager(ConversationManager):
|
||||
)
|
||||
self._local_agent_loops_by_sid[sid] = session
|
||||
asyncio.create_task(session.initialize_agent(settings, initial_user_msg))
|
||||
# This does not get added when resuming an existing conversation
|
||||
try:
|
||||
session.agent_session.event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER,
|
||||
self._create_conversation_update_callback(user_id, sid),
|
||||
UPDATED_AT_CALLBACK_ID,
|
||||
)
|
||||
except ValueError:
|
||||
pass # Already subscribed - take no action
|
||||
|
||||
event_stream = await self._get_event_stream(sid)
|
||||
if not event_stream:
|
||||
@ -298,8 +339,41 @@ class StandaloneConversationManager(ConversationManager):
|
||||
sio: socketio.AsyncServer,
|
||||
config: AppConfig,
|
||||
file_store: FileStore,
|
||||
monitoring_listener: MonitoringListener | None = None,
|
||||
server_config: ServerConfig,
|
||||
monitoring_listener: MonitoringListener | None,
|
||||
) -> ConversationManager:
|
||||
return StandaloneConversationManager(
|
||||
sio, config, file_store, monitoring_listener or MonitoringListener()
|
||||
sio,
|
||||
config,
|
||||
file_store,
|
||||
server_config,
|
||||
monitoring_listener or MonitoringListener(),
|
||||
)
|
||||
|
||||
def _create_conversation_update_callback(
    self, user_id: str | None, conversation_id: str
) -> Callable:
    """Build a callback that refreshes a conversation's last-updated timestamp.

    Args:
        user_id: User the conversation belongs to; may be ``None``.
        conversation_id: Id of the conversation whose metadata is updated.

    Returns:
        A callable that ignores its arguments and runs
        ``self._update_timestamp_for_conversation(user_id, conversation_id)``
        through ``call_async_from_sync`` with ``GENERAL_TIMEOUT``.
    """

    def callback(*args, **kwargs):
        # The callback is synchronous, so the async update is bridged here.
        call_async_from_sync(
            self._update_timestamp_for_conversation,
            GENERAL_TIMEOUT,
            user_id,
            conversation_id,
        )

    return callback
|
||||
|
||||
async def _update_timestamp_for_conversation(
    self, user_id: str | None, conversation_id: str
):
    """Set a conversation's ``last_updated_at`` to the current UTC time and save it.

    Args:
        user_id: Forwarded to ``self._get_conversation_store``, which accepts
            ``None``, so ``None`` is accepted here as well.
        conversation_id: Id of the conversation whose metadata is loaded,
            stamped and saved back.
    """
    conversation_store = await self._get_conversation_store(user_id)
    conversation = await conversation_store.get_metadata(conversation_id)
    # Timezone-aware timestamp, so it is unambiguous wherever it is read back.
    conversation.last_updated_at = datetime.now(timezone.utc)
    await conversation_store.save_metadata(conversation)
|
||||
|
||||
|
||||
def _last_updated_at_key(conversation: ConversationMetadata) -> float:
    """Sort key: the conversation's last-update time as a POSIX timestamp.

    Args:
        conversation: Metadata object exposing ``last_updated_at``.

    Returns:
        ``last_updated_at.timestamp()``, or ``0.0`` when ``last_updated_at``
        is ``None`` so such entries sort before any post-epoch timestamp.
    """
    stamp = conversation.last_updated_at
    return 0.0 if stamp is None else stamp.timestamp()
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Callable
|
||||
|
||||
from fastapi import APIRouter, Body, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
@ -8,7 +7,6 @@ from pydantic import BaseModel, SecretStr
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.server.auth import get_github_token, get_idp_token, get_user_id
|
||||
@ -27,14 +25,9 @@ from openhands.server.shared import (
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
|
||||
from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||
from openhands.utils.async_utils import (
|
||||
GENERAL_TIMEOUT,
|
||||
call_async_from_sync,
|
||||
wait_all,
|
||||
)
|
||||
from openhands.utils.async_utils import wait_all
|
||||
|
||||
app = APIRouter(prefix='/api')
|
||||
UPDATED_AT_CALLBACK_ID = 'updated_at_callback_id'
|
||||
|
||||
|
||||
class InitSessionRequest(BaseModel):
|
||||
@ -119,17 +112,9 @@ async def _create_new_conversation(
|
||||
content=user_msg or '',
|
||||
image_urls=image_urls or [],
|
||||
)
|
||||
event_stream = await conversation_manager.maybe_start_agent_loop(
|
||||
await conversation_manager.maybe_start_agent_loop(
|
||||
conversation_id, conversation_init_data, user_id, initial_message_action
|
||||
)
|
||||
try:
|
||||
event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER,
|
||||
_create_conversation_update_callback(user_id, conversation_id),
|
||||
UPDATED_AT_CALLBACK_ID,
|
||||
)
|
||||
except ValueError:
|
||||
pass # Already subscribed - take no action
|
||||
logger.info(f'Finished initializing conversation {conversation_id}')
|
||||
|
||||
return conversation_id
|
||||
@ -307,24 +292,3 @@ async def _get_conversation_info(
|
||||
f'Error loading conversation {conversation.conversation_id}: {str(e)}',
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _create_conversation_update_callback(
|
||||
user_id: str | None, conversation_id: str
|
||||
) -> Callable:
|
||||
def callback(*args, **kwargs):
|
||||
call_async_from_sync(
|
||||
_update_timestamp_for_conversation,
|
||||
GENERAL_TIMEOUT,
|
||||
user_id,
|
||||
conversation_id,
|
||||
)
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
async def _update_timestamp_for_conversation(user_id: str, conversation_id: str):
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
conversation = await conversation_store.get_metadata(conversation_id)
|
||||
conversation.last_updated_at = datetime.now(timezone.utc)
|
||||
await conversation_store.save_metadata(conversation)
|
||||
|
||||
@ -131,16 +131,18 @@ class AgentSession:
|
||||
'github_token': github_token.get_secret_value(),
|
||||
}
|
||||
)
|
||||
if initial_message:
|
||||
self.event_stream.add_event(initial_message, EventSource.USER)
|
||||
self.event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.RUNNING), EventSource.ENVIRONMENT
|
||||
)
|
||||
else:
|
||||
self.event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.AWAITING_USER_INPUT),
|
||||
EventSource.ENVIRONMENT,
|
||||
)
|
||||
if not self._closed:
|
||||
if initial_message:
|
||||
self.event_stream.add_event(initial_message, EventSource.USER)
|
||||
self.event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.RUNNING),
|
||||
EventSource.ENVIRONMENT,
|
||||
)
|
||||
else:
|
||||
self.event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.AWAITING_USER_INPUT),
|
||||
EventSource.ENVIRONMENT,
|
||||
)
|
||||
finished = True
|
||||
finally:
|
||||
self._starting = False
|
||||
@ -360,3 +362,6 @@ class AgentSession:
|
||||
# If 5 minutes have elapsed and we still don't have a controller, something has gone wrong
|
||||
return AgentState.ERROR
|
||||
return None
|
||||
|
||||
def is_closed(self) -> bool:
    """Return the current value of this session's ``_closed`` flag."""
    return self._closed
|
||||
|
||||
@ -247,8 +247,9 @@ class Session:
|
||||
async def _send_status_message(self, msg_type: str, id: str, message: str):
|
||||
"""Sends a status message to the client."""
|
||||
if msg_type == 'error':
|
||||
agent_session = self.agent_session
|
||||
controller = self.agent_session.controller
|
||||
if controller is not None:
|
||||
if controller is not None and not agent_session.is_closed():
|
||||
await controller.set_agent_state_to(AgentState.ERROR)
|
||||
await self.send(
|
||||
{'status_update': True, 'type': msg_type, 'id': id, 'message': message}
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import inspect
|
||||
import os
|
||||
|
||||
import socketio
|
||||
@ -46,15 +45,9 @@ ConversationManagerImpl = get_impl(
|
||||
server_config.conversation_manager_class,
|
||||
)
|
||||
|
||||
if len(inspect.signature(ConversationManagerImpl.get_instance).parameters) == 3:
|
||||
# This conditional prevents a breaking change in February 2025.
|
||||
# It should be safe to remove by April.
|
||||
conversation_manager = ConversationManagerImpl.get_instance(sio, config, file_store)
|
||||
else:
|
||||
# This is the new signature.
|
||||
conversation_manager = ConversationManagerImpl.get_instance( # type: ignore
|
||||
sio, config, file_store, monitoring_listener
|
||||
)
|
||||
conversation_manager = ConversationManagerImpl.get_instance( # type: ignore
|
||||
sio, config, file_store, server_config, monitoring_listener
|
||||
)
|
||||
|
||||
SettingsStoreImpl = get_impl(SettingsStore, server_config.settings_store_class) # type: ignore
|
||||
|
||||
|
||||
@ -1,12 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Iterable
|
||||
|
||||
from openhands.core.config.app_config import AppConfig
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
|
||||
from openhands.storage.data_models.conversation_metadata_result_set import (
|
||||
ConversationMetadataResultSet,
|
||||
)
|
||||
from openhands.utils.async_utils import wait_all
|
||||
|
||||
|
||||
class ConversationStore(ABC):
|
||||
@ -38,6 +40,12 @@ class ConversationStore(ABC):
|
||||
) -> ConversationMetadataResultSet:
|
||||
"""Search conversations"""
|
||||
|
||||
async def get_all_metadata(
    self, conversation_ids: Iterable[str]
) -> list[ConversationMetadata]:
    """Get metadata for multiple conversations in parallel.

    One ``self.get_metadata`` call is created per id and the whole batch is
    awaited through ``wait_all``.

    Args:
        conversation_ids: Ids of the conversations to load.

    Returns:
        The value produced by ``wait_all`` for the batch of lookups.
        NOTE(review): presumably one entry per id in input order — confirm
        against ``wait_all``.
    """
    lookups = [self.get_metadata(cid) for cid in conversation_ids]
    return await wait_all(lookups)
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def get_instance(
|
||||
|
||||
@ -162,3 +162,38 @@ async def test_search_with_invalid_conversation():
|
||||
assert len(result.results) == 1
|
||||
assert result.results[0].conversation_id == 'conv1'
|
||||
assert result.next_page_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_all_metadata():
    """get_all_metadata returns one entry per requested id, in request order."""
    store = FileConversationStore(
        InMemoryFileStore(
            {
                'sessions/conv1/metadata.json': json.dumps(
                    {
                        'conversation_id': 'conv1',
                        'github_user_id': '123',
                        'selected_repository': 'repo1',
                        'title': 'First conversation',
                        'created_at': '2025-01-16T19:51:04Z',
                    }
                ),
                'sessions/conv2/metadata.json': json.dumps(
                    {
                        'conversation_id': 'conv2',
                        'github_user_id': '123',
                        'selected_repository': 'repo1',
                        'title': 'Second conversation',
                        'created_at': '2025-01-17T19:51:04Z',
                    }
                ),
            }
        )
    )

    results = await store.get_all_metadata(['conv1', 'conv2'])

    assert len(results) == 2
    assert results[0].conversation_id == 'conv1'
    assert results[0].title == 'First conversation'
    assert results[1].conversation_id == 'conv2'
    assert results[1].title == 'Second conversation'
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user