mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Nested Conversation Support (#8588)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: Robert Brennan <contact@rbren.io>
This commit is contained in:
parent
70573dcbc0
commit
a3d1a92353
@ -2,11 +2,12 @@ import ColdIcon from "./state-indicators/cold.svg?react";
|
||||
import RunningIcon from "./state-indicators/running.svg?react";
|
||||
|
||||
type SVGIcon = React.FunctionComponent<React.SVGProps<SVGSVGElement>>;
|
||||
export type ProjectStatus = "RUNNING" | "STOPPED";
|
||||
export type ProjectStatus = "RUNNING" | "STOPPED" | "STARTING";
|
||||
|
||||
const INDICATORS: Record<ProjectStatus, SVGIcon> = {
|
||||
STOPPED: ColdIcon,
|
||||
RUNNING: RunningIcon,
|
||||
STARTING: ColdIcon,
|
||||
};
|
||||
|
||||
interface ConversationStateIndicatorProps {
|
||||
|
||||
@ -269,7 +269,7 @@ export function WsClientProvider({
|
||||
if (!conversationId) {
|
||||
throw new Error("No conversation ID provided");
|
||||
}
|
||||
if (!conversation) {
|
||||
if (!conversation || conversation.status === "STARTING") {
|
||||
return () => undefined; // conversation not yet loaded
|
||||
}
|
||||
|
||||
@ -309,7 +309,7 @@ export function WsClientProvider({
|
||||
sio.off("connect_failed", handleError);
|
||||
sio.off("disconnect", handleDisconnect);
|
||||
};
|
||||
}, [conversationId, conversation?.url]);
|
||||
}, [conversationId, conversation?.url, conversation?.status]);
|
||||
|
||||
React.useEffect(
|
||||
() => () => {
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
|
||||
const FIVE_MINUTES = 1000 * 60 * 5;
|
||||
const FIFTEEN_MINUTES = 1000 * 60 * 15;
|
||||
|
||||
export const useUserConversation = (cid: string | null) =>
|
||||
useQuery({
|
||||
queryKey: ["user", "conversation", cid],
|
||||
@ -11,6 +14,12 @@ export const useUserConversation = (cid: string | null) =>
|
||||
},
|
||||
enabled: !!cid,
|
||||
retry: false,
|
||||
staleTime: 1000 * 60 * 5, // 5 minutes
|
||||
gcTime: 1000 * 60 * 15, // 15 minutes
|
||||
refetchInterval: (query) => {
|
||||
if (query.state.data?.status === "STARTING") {
|
||||
return 2000; // 2 seconds
|
||||
}
|
||||
return FIVE_MINUTES;
|
||||
},
|
||||
staleTime: FIVE_MINUTES,
|
||||
gcTime: FIFTEEN_MINUTES,
|
||||
});
|
||||
|
||||
73
openhands/events/nested_event_store.py
Normal file
73
openhands/events/nested_event_store.py
Normal file
@ -0,0 +1,73 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx # type: ignore
|
||||
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.event_filter import EventFilter
|
||||
from openhands.events.event_store_abc import EventStoreABC
|
||||
from openhands.events.serialization.event import event_from_dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class NestedEventStore(EventStoreABC):
|
||||
"""
|
||||
A stored list of events backing a conversation
|
||||
"""
|
||||
|
||||
base_url: str
|
||||
sid: str
|
||||
user_id: str | None
|
||||
|
||||
def search_events(
|
||||
self,
|
||||
start_id: int = 0,
|
||||
end_id: int | None = None,
|
||||
reverse: bool = False,
|
||||
filter: EventFilter | None = None,
|
||||
limit: int | None = None,
|
||||
) -> Iterable[Event]:
|
||||
while True:
|
||||
search_params = {
|
||||
'start_id': start_id,
|
||||
'reverse': reverse,
|
||||
}
|
||||
if limit is not None:
|
||||
search_params['limit'] = min(100, limit)
|
||||
search_str = urlencode(search_params)
|
||||
url = f'{self.base_url}/events{search_str}'
|
||||
response = httpx.get(url)
|
||||
result_set = response.json()
|
||||
for result in result_set['results']:
|
||||
event = event_from_dict(result)
|
||||
start_id = event.id
|
||||
if end_id == event.id:
|
||||
if not filter or filter.include(event):
|
||||
yield event
|
||||
return
|
||||
if filter and filter.exclude(event):
|
||||
continue
|
||||
yield event
|
||||
if limit is not None:
|
||||
limit -= 1
|
||||
if limit <= 0:
|
||||
return
|
||||
if not result_set['has_more']:
|
||||
return
|
||||
|
||||
def get_event(self, id: int) -> Event:
|
||||
events = list(self.search_events(start_id=id, limit=1))
|
||||
if not events:
|
||||
raise FileNotFoundError('no_event')
|
||||
return events[0]
|
||||
|
||||
def get_latest_event(self) -> Event:
|
||||
events = list(self.search_events(reverse=True, limit=1))
|
||||
if not events:
|
||||
raise FileNotFoundError('no_event')
|
||||
return events[0]
|
||||
|
||||
def get_latest_event_id(self) -> int:
|
||||
event = self.get_latest_event()
|
||||
return event.id
|
||||
@ -118,9 +118,10 @@ class Runtime(FileEditRuntimeMixin):
|
||||
)
|
||||
self.sid = sid
|
||||
self.event_stream = event_stream
|
||||
self.event_stream.subscribe(
|
||||
EventStreamSubscriber.RUNTIME, self.on_event, self.sid
|
||||
)
|
||||
if event_stream:
|
||||
event_stream.subscribe(
|
||||
EventStreamSubscriber.RUNTIME, self.on_event, self.sid
|
||||
)
|
||||
self.plugins = (
|
||||
copy.deepcopy(plugins) if plugins is not None and len(plugins) > 0 else []
|
||||
)
|
||||
@ -267,9 +268,10 @@ class Runtime(FileEditRuntimeMixin):
|
||||
return
|
||||
|
||||
try:
|
||||
await self.provider_handler.set_event_stream_secrets(
|
||||
self.event_stream, env_vars=env_vars
|
||||
)
|
||||
if self.event_stream:
|
||||
await self.provider_handler.set_event_stream_secrets(
|
||||
self.event_stream, env_vars=env_vars
|
||||
)
|
||||
self.add_env_vars(self.provider_handler.expose_env_vars(env_vars))
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
|
||||
@ -23,7 +23,7 @@ from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||
from openhands.runtime.impl.docker.containers import stop_all_containers
|
||||
from openhands.runtime.plugins import PluginRequirement
|
||||
from openhands.runtime.utils import find_available_tcp_port
|
||||
from openhands.runtime.utils.command import get_action_execution_server_startup_command
|
||||
from openhands.runtime.utils.command import DEFAULT_MAIN_MODULE, get_action_execution_server_startup_command
|
||||
from openhands.runtime.utils.log_streamer import LogStreamer
|
||||
from openhands.runtime.utils.runtime_build import build_runtime_image
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
@ -38,10 +38,10 @@ APP_PORT_RANGE_1 = (50000, 54999)
|
||||
APP_PORT_RANGE_2 = (55000, 59999)
|
||||
|
||||
|
||||
def _is_retryable_wait_until_alive_error(exception):
|
||||
def _is_retryablewait_until_alive_error(exception):
|
||||
if isinstance(exception, tenacity.RetryError):
|
||||
cause = exception.last_attempt.exception()
|
||||
return _is_retryable_wait_until_alive_error(cause)
|
||||
return _is_retryablewait_until_alive_error(cause)
|
||||
|
||||
return isinstance(
|
||||
exception,
|
||||
@ -51,6 +51,7 @@ def _is_retryable_wait_until_alive_error(exception):
|
||||
httpx.NetworkError,
|
||||
httpx.RemoteProtocolError,
|
||||
httpx.HTTPStatusError,
|
||||
httpx.ReadTimeout,
|
||||
),
|
||||
)
|
||||
|
||||
@ -80,6 +81,7 @@ class DockerRuntime(ActionExecutionClient):
|
||||
status_callback: Callable | None = None,
|
||||
attach_to_existing: bool = False,
|
||||
headless_mode: bool = True,
|
||||
main_module: str = DEFAULT_MAIN_MODULE,
|
||||
):
|
||||
if not DockerRuntime._shutdown_listener_id:
|
||||
DockerRuntime._shutdown_listener_id = add_shutdown_listener(
|
||||
@ -109,6 +111,7 @@ class DockerRuntime(ActionExecutionClient):
|
||||
self.runtime_container_image = self.config.sandbox.runtime_container_image
|
||||
self.container_name = CONTAINER_NAME_PREFIX + sid
|
||||
self.container: Container | None = None
|
||||
self.main_module = main_module
|
||||
|
||||
self.runtime_builder = DockerRuntimeBuilder(self.docker_client)
|
||||
|
||||
@ -148,25 +151,11 @@ class DockerRuntime(ActionExecutionClient):
|
||||
f'Container {self.container_name} not found.',
|
||||
)
|
||||
raise AgentRuntimeDisconnectedError from e
|
||||
if self.runtime_container_image is None:
|
||||
if self.base_container_image is None:
|
||||
raise ValueError(
|
||||
'Neither runtime container image nor base container image is set'
|
||||
)
|
||||
self.send_status_message('STATUS$STARTING_CONTAINER')
|
||||
self.runtime_container_image = build_runtime_image(
|
||||
self.base_container_image,
|
||||
self.runtime_builder,
|
||||
platform=self.config.sandbox.platform,
|
||||
extra_deps=self.config.sandbox.runtime_extra_deps,
|
||||
force_rebuild=self.config.sandbox.force_rebuild_runtime,
|
||||
extra_build_args=self.config.sandbox.runtime_extra_build_args,
|
||||
)
|
||||
|
||||
self.maybe_build_runtime_container_image()
|
||||
self.log(
|
||||
'info', f'Starting runtime with image: {self.runtime_container_image}'
|
||||
)
|
||||
await call_sync_from_async(self._init_container)
|
||||
await call_sync_from_async(self.init_container)
|
||||
self.log(
|
||||
'info',
|
||||
f'Container started: {self.container_name}. VSCode URL: {self.vscode_url}',
|
||||
@ -181,7 +170,7 @@ class DockerRuntime(ActionExecutionClient):
|
||||
self.log('info', f'Waiting for client to become ready at {self.api_url}...')
|
||||
self.send_status_message('STATUS$WAITING_FOR_CLIENT')
|
||||
|
||||
await call_sync_from_async(self._wait_until_alive)
|
||||
await call_sync_from_async(self.wait_until_alive)
|
||||
|
||||
if not self.attach_to_existing:
|
||||
self.log('info', 'Runtime is ready.')
|
||||
@ -197,6 +186,22 @@ class DockerRuntime(ActionExecutionClient):
|
||||
self.send_status_message(' ')
|
||||
self._runtime_initialized = True
|
||||
|
||||
def maybe_build_runtime_container_image(self):
|
||||
if self.runtime_container_image is None:
|
||||
if self.base_container_image is None:
|
||||
raise ValueError(
|
||||
'Neither runtime container image nor base container image is set'
|
||||
)
|
||||
self.send_status_message('STATUS$STARTING_CONTAINER')
|
||||
self.runtime_container_image = build_runtime_image(
|
||||
self.base_container_image,
|
||||
self.runtime_builder,
|
||||
platform=self.config.sandbox.platform,
|
||||
extra_deps=self.config.sandbox.runtime_extra_deps,
|
||||
force_rebuild=self.config.sandbox.force_rebuild_runtime,
|
||||
extra_build_args=self.config.sandbox.runtime_extra_build_args,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@lru_cache(maxsize=1)
|
||||
def _init_docker_client() -> docker.DockerClient:
|
||||
@ -256,7 +261,7 @@ class DockerRuntime(ActionExecutionClient):
|
||||
|
||||
return volumes
|
||||
|
||||
def _init_container(self):
|
||||
def init_container(self):
|
||||
self.log('debug', 'Preparing to start container...')
|
||||
self.send_status_message('STATUS$PREPARING_CONTAINER')
|
||||
self._host_port = self._find_available_port(EXECUTION_SERVER_PORT_RANGE)
|
||||
@ -336,11 +341,7 @@ class DockerRuntime(ActionExecutionClient):
|
||||
f'Sandbox workspace: {self.config.workspace_mount_path_in_sandbox}',
|
||||
)
|
||||
|
||||
command = get_action_execution_server_startup_command(
|
||||
server_port=self._container_port,
|
||||
plugins=self.plugins,
|
||||
app_config=self.config,
|
||||
)
|
||||
command = self.get_action_execution_server_startup_command()
|
||||
|
||||
try:
|
||||
self.container = self.docker_client.containers.run(
|
||||
@ -371,7 +372,7 @@ class DockerRuntime(ActionExecutionClient):
|
||||
f'Container {self.container_name} already exists. Removing...',
|
||||
)
|
||||
stop_all_containers(self.container_name)
|
||||
return self._init_container()
|
||||
return self.init_container()
|
||||
|
||||
else:
|
||||
self.log(
|
||||
@ -421,11 +422,11 @@ class DockerRuntime(ActionExecutionClient):
|
||||
|
||||
@tenacity.retry(
|
||||
stop=tenacity.stop_after_delay(120) | stop_if_should_exit(),
|
||||
retry=tenacity.retry_if_exception(_is_retryable_wait_until_alive_error),
|
||||
retry=tenacity.retry_if_exception(_is_retryablewait_until_alive_error),
|
||||
reraise=True,
|
||||
wait=tenacity.wait_fixed(2),
|
||||
)
|
||||
def _wait_until_alive(self):
|
||||
def wait_until_alive(self):
|
||||
try:
|
||||
container = self.docker_client.containers.get(self.container_name)
|
||||
if container.status == 'exited':
|
||||
@ -519,7 +520,7 @@ class DockerRuntime(ActionExecutionClient):
|
||||
self.log('debug', f'Container {self.container_name} resumed')
|
||||
|
||||
# Wait for the container to be ready
|
||||
self._wait_until_alive()
|
||||
self.wait_until_alive()
|
||||
|
||||
@classmethod
|
||||
async def delete(cls, conversation_id: str):
|
||||
@ -534,3 +535,11 @@ class DockerRuntime(ActionExecutionClient):
|
||||
pass
|
||||
finally:
|
||||
docker_client.close()
|
||||
|
||||
def get_action_execution_server_startup_command(self):
|
||||
return get_action_execution_server_startup_command(
|
||||
server_port=self._container_port,
|
||||
plugins=self.plugins,
|
||||
app_config=self.config,
|
||||
main_module=self.main_module,
|
||||
)
|
||||
|
||||
@ -24,7 +24,7 @@ from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||
ActionExecutionClient,
|
||||
)
|
||||
from openhands.runtime.plugins import PluginRequirement
|
||||
from openhands.runtime.utils.command import get_action_execution_server_startup_command
|
||||
from openhands.runtime.utils.command import DEFAULT_MAIN_MODULE, get_action_execution_server_startup_command
|
||||
from openhands.runtime.utils.request import send_request
|
||||
from openhands.runtime.utils.runtime_build import build_runtime_image
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
@ -41,6 +41,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
runtime_builder: RemoteRuntimeBuilder
|
||||
container_image: str
|
||||
available_hosts: dict[str, int]
|
||||
main_module: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -54,6 +55,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
headless_mode: bool = True,
|
||||
user_id: str | None = None,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
|
||||
main_module: str = DEFAULT_MAIN_MODULE,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
config,
|
||||
@ -85,6 +87,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
)
|
||||
|
||||
assert self.config.sandbox.remote_runtime_class in (None, 'sysbox', 'gvisor')
|
||||
self.main_module = main_module
|
||||
|
||||
self.runtime_builder = RemoteRuntimeBuilder(
|
||||
self.config.sandbox.remote_runtime_api_url,
|
||||
@ -231,11 +234,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
|
||||
def _start_runtime(self) -> None:
|
||||
# Prepare the request body for the /start endpoint
|
||||
command = get_action_execution_server_startup_command(
|
||||
server_port=self.port,
|
||||
plugins=self.plugins,
|
||||
app_config=self.config,
|
||||
)
|
||||
command = self.get_action_execution_server_startup_command()
|
||||
environment: dict[str, str] = {}
|
||||
if self.config.debug or os.environ.get('DEBUG', 'false').lower() == 'true':
|
||||
environment['DEBUG'] = 'true'
|
||||
@ -492,3 +491,11 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
|
||||
def _stop_if_closed(self, retry_state: RetryCallState) -> bool:
|
||||
return self._runtime_closed
|
||||
|
||||
def get_action_execution_server_startup_command(self):
|
||||
return get_action_execution_server_startup_command(
|
||||
server_port=self.port,
|
||||
plugins=self.plugins,
|
||||
app_config=self.config,
|
||||
main_module=self.main_module,
|
||||
)
|
||||
|
||||
@ -9,6 +9,7 @@ DEFAULT_PYTHON_PREFIX = [
|
||||
'poetry',
|
||||
'run',
|
||||
]
|
||||
DEFAULT_MAIN_MODULE = 'openhands.runtime.action_execution_server'
|
||||
|
||||
|
||||
def get_action_execution_server_startup_command(
|
||||
@ -18,6 +19,7 @@ def get_action_execution_server_startup_command(
|
||||
python_prefix: list[str] = DEFAULT_PYTHON_PREFIX,
|
||||
override_user_id: int | None = None,
|
||||
override_username: str | None = None,
|
||||
main_module: str = DEFAULT_MAIN_MODULE,
|
||||
) -> list[str]:
|
||||
sandbox_config = app_config.sandbox
|
||||
|
||||
@ -45,7 +47,7 @@ def get_action_execution_server_startup_command(
|
||||
'python',
|
||||
'-u',
|
||||
'-m',
|
||||
'openhands.runtime.action_execution_server',
|
||||
main_module,
|
||||
str(server_port),
|
||||
'--working-dir',
|
||||
app_config.workspace_mount_path_in_sandbox,
|
||||
|
||||
16
openhands/server/__main__.py
Normal file
16
openhands/server/__main__.py
Normal file
@ -0,0 +1,16 @@
|
||||
import os
|
||||
|
||||
import uvicorn
|
||||
|
||||
|
||||
def main():
|
||||
uvicorn.run(
|
||||
'openhands.server.listen:app',
|
||||
host='0.0.0.0',
|
||||
port=int(os.environ.get('port') or '3000'),
|
||||
log_level='debug' if os.environ.get('DEBUG') else 'info',
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -17,6 +17,7 @@ from openhands.server.routes.conversation import app as conversation_api_router
|
||||
from openhands.server.routes.feedback import app as feedback_api_router
|
||||
from openhands.server.routes.files import app as files_api_router
|
||||
from openhands.server.routes.git import app as git_api_router
|
||||
from openhands.server.routes.health import add_health_endpoints
|
||||
from openhands.server.routes.manage_conversations import (
|
||||
app as manage_conversation_api_router,
|
||||
)
|
||||
@ -44,11 +45,6 @@ app = FastAPI(
|
||||
)
|
||||
|
||||
|
||||
@app.get('/health')
|
||||
async def health() -> str:
|
||||
return 'OK'
|
||||
|
||||
|
||||
app.include_router(public_api_router)
|
||||
app.include_router(files_api_router)
|
||||
app.include_router(security_api_router)
|
||||
@ -59,3 +55,4 @@ app.include_router(settings_router)
|
||||
app.include_router(secrets_router)
|
||||
app.include_router(git_api_router)
|
||||
app.include_router(trajectory_router)
|
||||
add_health_endpoints(app)
|
||||
|
||||
@ -21,7 +21,10 @@ class ServerConfig(ServerConfigInterface):
|
||||
conversation_store_class: str = (
|
||||
'openhands.storage.conversation.file_conversation_store.FileConversationStore'
|
||||
)
|
||||
conversation_manager_class: str = 'openhands.server.conversation_manager.standalone_conversation_manager.StandaloneConversationManager'
|
||||
conversation_manager_class: str = os.environ.get(
|
||||
"CONVERSATION_MANAGER_CLASS",
|
||||
'openhands.server.conversation_manager.standalone_conversation_manager.StandaloneConversationManager',
|
||||
)
|
||||
monitoring_listener_class: str = 'openhands.server.monitoring.MonitoringListener'
|
||||
user_auth_class: str = (
|
||||
'openhands.server.user_auth.default_user_auth.DefaultUserAuth'
|
||||
|
||||
@ -0,0 +1,464 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import os
|
||||
from base64 import urlsafe_b64encode
|
||||
from dataclasses import dataclass, field
|
||||
from types import MappingProxyType
|
||||
from typing import Any, cast
|
||||
|
||||
import docker
|
||||
import httpx
|
||||
import socketio
|
||||
from docker.models.containers import Container
|
||||
from fastapi import status
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.nested_event_store import NestedEventStore
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.runtime.impl.docker.containers import stop_all_containers
|
||||
from openhands.runtime.impl.docker.docker_runtime import DockerRuntime
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
from openhands.server.conversation_manager.conversation_manager import (
|
||||
ConversationManager,
|
||||
)
|
||||
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
|
||||
from openhands.server.monitoring import MonitoringListener
|
||||
from openhands.server.session.conversation import Conversation
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
from openhands.server.session.session import ROOM_KEY, Session
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
|
||||
from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.storage.locations import get_conversation_dir
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
|
||||
@dataclass
|
||||
class DockerNestedConversationManager(ConversationManager):
|
||||
"""Conversation manager where the agent loops exist inside the docker containers."""
|
||||
|
||||
sio: socketio.AsyncServer
|
||||
config: AppConfig
|
||||
server_config: ServerConfig
|
||||
file_store: FileStore
|
||||
docker_client: docker.DockerClient = field(default_factory=docker.from_env)
|
||||
_conversation_store_class: type[ConversationStore] | None = None
|
||||
_starting_conversation_ids: set[str] = field(default_factory=set)
|
||||
|
||||
async def __aenter__(self):
|
||||
# No action is required on startup for this implementation
|
||||
pass
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
# No action is required on shutdown for this implementation
|
||||
pass
|
||||
|
||||
async def attach_to_conversation(
|
||||
self, sid: str, user_id: str | None = None
|
||||
) -> Conversation | None:
|
||||
# Not supported - clients should connect directly to the nested server!
|
||||
raise ValueError('unsupported_operation')
|
||||
|
||||
async def detach_from_conversation(self, conversation: Conversation):
|
||||
# Not supported - clients should connect directly to the nested server!
|
||||
raise ValueError('unsupported_operation')
|
||||
|
||||
async def join_conversation(
|
||||
self,
|
||||
sid: str,
|
||||
connection_id: str,
|
||||
settings: Settings,
|
||||
user_id: str | None,
|
||||
) -> AgentLoopInfo:
|
||||
# Not supported - clients should connect directly to the nested server!
|
||||
raise ValueError('unsupported_operation')
|
||||
|
||||
async def get_running_agent_loops(
|
||||
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
|
||||
) -> set[str]:
|
||||
"""
|
||||
Get the running agent loops directly from docker.
|
||||
"""
|
||||
names = (container.name for container in self.docker_client.containers.list())
|
||||
conversation_ids = {
|
||||
name[len('openhands-runtime-') :]
|
||||
for name in names
|
||||
if name.startswith('openhands-runtime-')
|
||||
}
|
||||
if filter_to_sids is not None:
|
||||
conversation_ids = {
|
||||
conversation_id
|
||||
for conversation_id in conversation_ids
|
||||
if conversation_id in filter_to_sids
|
||||
}
|
||||
return conversation_ids
|
||||
|
||||
async def get_connections(
|
||||
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
|
||||
) -> dict[str, str]:
|
||||
# We don't monitor connections outside the nested server, though we could introduce an API for this.
|
||||
results: dict[str, str] = {}
|
||||
return results
|
||||
|
||||
async def maybe_start_agent_loop(
|
||||
self,
|
||||
sid: str,
|
||||
settings: Settings,
|
||||
user_id: str | None,
|
||||
initial_user_msg: MessageAction | None = None,
|
||||
replay_json: str | None = None,
|
||||
) -> AgentLoopInfo:
|
||||
if not await self.is_agent_loop_running(sid):
|
||||
await self._start_agent_loop(
|
||||
sid, settings, user_id, initial_user_msg, replay_json
|
||||
)
|
||||
|
||||
nested_url = self._get_nested_url(sid)
|
||||
return AgentLoopInfo(
|
||||
conversation_id=sid,
|
||||
url=nested_url,
|
||||
session_api_key=self._get_session_api_key_for_conversation(sid),
|
||||
event_store=NestedEventStore(
|
||||
base_url=nested_url,
|
||||
sid=sid,
|
||||
user_id=user_id,
|
||||
),
|
||||
status=ConversationStatus.STARTING
|
||||
if sid in self._starting_conversation_ids
|
||||
else ConversationStatus.RUNNING,
|
||||
)
|
||||
|
||||
async def _start_agent_loop(
|
||||
self,
|
||||
sid: str,
|
||||
settings: Settings,
|
||||
user_id: str | None,
|
||||
initial_user_msg: MessageAction | None,
|
||||
replay_json: str | None,
|
||||
):
|
||||
logger.info(f'starting_agent_loop:{sid}', extra={'session_id': sid})
|
||||
await self.ensure_num_conversations_below_limit(sid, user_id)
|
||||
runtime = await self._create_runtime(sid, user_id, settings)
|
||||
self._starting_conversation_ids.add(sid)
|
||||
try:
|
||||
# Build the runtime container image if it is missing
|
||||
await call_sync_from_async(runtime.maybe_build_runtime_container_image)
|
||||
|
||||
# initialize the container but dont wait for it to start
|
||||
await call_sync_from_async(runtime.init_container)
|
||||
|
||||
# Start the conversation in a background task.
|
||||
asyncio.create_task(
|
||||
self._start_conversation(
|
||||
sid,
|
||||
settings,
|
||||
runtime,
|
||||
initial_user_msg,
|
||||
replay_json,
|
||||
runtime.api_url,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception:
|
||||
self._starting_conversation_ids.remove(sid)
|
||||
raise
|
||||
|
||||
async def _start_conversation(
|
||||
self,
|
||||
sid: str,
|
||||
settings: Settings,
|
||||
runtime: DockerRuntime,
|
||||
initial_user_msg: MessageAction | None,
|
||||
replay_json: str | None,
|
||||
api_url: str,
|
||||
):
|
||||
try:
|
||||
await call_sync_from_async(runtime.wait_until_alive)
|
||||
await call_sync_from_async(runtime.setup_initial_env)
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
'X-Session-API-Key': self._get_session_api_key_for_conversation(sid)
|
||||
}
|
||||
) as client:
|
||||
# setup the settings...
|
||||
settings_json = settings.model_dump(context={'expose_secrets': True})
|
||||
settings_json.pop('custom_secrets', None)
|
||||
settings_json.pop('git_provider_tokens', None)
|
||||
secrets_store = settings_json.pop('secrets_store', None) or {}
|
||||
response = await client.post(
|
||||
f'{api_url}/api/settings', json=settings_json
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
# Setup provider tokens
|
||||
provider_handler = self._get_provider_handler(settings)
|
||||
provider_tokens = provider_handler.provider_tokens
|
||||
if provider_tokens:
|
||||
provider_tokens_json = {
|
||||
k.value: {
|
||||
'token': v.token.get_secret_value(),
|
||||
'user_id': v.user_id,
|
||||
'host': v.host,
|
||||
}
|
||||
for k, v in provider_tokens.items()
|
||||
if v.token
|
||||
}
|
||||
response = await client.post(
|
||||
f'{api_url}/api/add-git-providers',
|
||||
json={
|
||||
'provider_tokens': provider_tokens_json,
|
||||
},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
# Setup custom secrets
|
||||
custom_secrets = secrets_store.get('custom_secrets') or {}
|
||||
if custom_secrets:
|
||||
for key, value in custom_secrets.items():
|
||||
response = await client.post(
|
||||
f'{api_url}/api/secrets',
|
||||
json={
|
||||
'name': key,
|
||||
'description': value.description,
|
||||
'value': value.value,
|
||||
},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
init_conversation: dict[str, Any] = {
|
||||
'initial_user_msg': initial_user_msg,
|
||||
'image_urls': [],
|
||||
'replay_json': replay_json,
|
||||
'conversation_id': sid,
|
||||
}
|
||||
|
||||
if isinstance(settings, ConversationInitData):
|
||||
init_conversation['repository'] = settings.selected_repository
|
||||
init_conversation['selected_branch'] = settings.selected_branch
|
||||
init_conversation['git_provider'] = (
|
||||
settings.git_provider.value if settings.git_provider else None
|
||||
)
|
||||
|
||||
# Create conversation
|
||||
response = await client.post(
|
||||
f'{api_url}/api/conversations', json=init_conversation
|
||||
)
|
||||
logger.info(
|
||||
f'_start_agent_loop:{response.status_code}:{response.json()}'
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
finally:
|
||||
self._starting_conversation_ids.remove(sid)
|
||||
|
||||
async def send_to_event_stream(self, connection_id: str, data: dict):
|
||||
# Not supported - clients should connect directly to the nested server!
|
||||
raise ValueError('unsupported_operation')
|
||||
|
||||
async def disconnect_from_session(self, connection_id: str):
|
||||
# Not supported - clients should connect directly to the nested server!
|
||||
raise ValueError('unsupported_operation')
|
||||
|
||||
async def close_session(self, sid: str):
|
||||
stop_all_containers(f'openhands-runtime-{sid}')
|
||||
|
||||
async def get_agent_loop_info(self, user_id=None, filter_to_sids=None):
|
||||
results = []
|
||||
containers = self.docker_client.containers.list()
|
||||
for container in containers:
|
||||
if not container.name.startswith('openhands-runtime-'):
|
||||
continue
|
||||
conversation_id = container.name[len('openhands-runtime-') :]
|
||||
if filter_to_sids is not None and conversation_id not in filter_to_sids:
|
||||
continue
|
||||
nested_url = self.get_nested_url_for_container(container)
|
||||
if os.getenv('NESTED_RUNTIME_BROWSER_HOST', '') != '':
|
||||
# This should be set to http://localhost if you're running OH inside a docker container
|
||||
nested_url = nested_url.replace(
|
||||
self.config.sandbox.local_runtime_url,
|
||||
os.getenv('NESTED_RUNTIME_BROWSER_HOST', ''),
|
||||
)
|
||||
agent_loop_info = AgentLoopInfo(
|
||||
conversation_id=conversation_id,
|
||||
url=nested_url,
|
||||
session_api_key=self._get_session_api_key_for_conversation(
|
||||
conversation_id
|
||||
),
|
||||
event_store=NestedEventStore(
|
||||
base_url=nested_url,
|
||||
sid=conversation_id,
|
||||
user_id=user_id,
|
||||
),
|
||||
status=ConversationStatus.STARTING
|
||||
if conversation_id in self._starting_conversation_ids
|
||||
else ConversationStatus.RUNNING,
|
||||
)
|
||||
results.append(agent_loop_info)
|
||||
return results
|
||||
|
||||
@classmethod
|
||||
def get_instance(
|
||||
cls,
|
||||
sio: socketio.AsyncServer,
|
||||
config: AppConfig,
|
||||
file_store: FileStore,
|
||||
server_config: ServerConfig,
|
||||
monitoring_listener: MonitoringListener,
|
||||
) -> ConversationManager:
|
||||
return DockerNestedConversationManager(
|
||||
sio=sio,
|
||||
config=config,
|
||||
server_config=server_config,
|
||||
file_store=file_store,
|
||||
)
|
||||
|
||||
async def _get_conversation_store(self, user_id: str | None) -> ConversationStore:
|
||||
conversation_store_class = self._conversation_store_class
|
||||
if not conversation_store_class:
|
||||
self._conversation_store_class = conversation_store_class = get_impl(
|
||||
ConversationStore, # type: ignore
|
||||
self.server_config.conversation_store_class,
|
||||
)
|
||||
store = await conversation_store_class.get_instance(self.config, user_id)
|
||||
return store
|
||||
|
||||
def _get_nested_url(self, sid: str) -> str:
|
||||
container = self.docker_client.containers.get(f'openhands-runtime-{sid}')
|
||||
return self.get_nested_url_for_container(container)
|
||||
|
||||
def get_nested_url_for_container(self, container: Container) -> str:
|
||||
env = container.attrs['Config']['Env']
|
||||
container_port = int(next(e[5:] for e in env if e.startswith('port=')))
|
||||
conversation_id = container.name[len('openhands-runtime-') :]
|
||||
nested_url = f'{self.config.sandbox.local_runtime_url}:{container_port}/api/conversations/{conversation_id}'
|
||||
return nested_url
|
||||
|
||||
def _get_session_api_key_for_conversation(self, conversation_id: str):
|
||||
jwt_secret = self.config.jwt_secret.get_secret_value() # type:ignore
|
||||
conversation_key = f'{jwt_secret}:{conversation_id}'.encode()
|
||||
session_api_key = (
|
||||
urlsafe_b64encode(hashlib.sha256(conversation_key).digest())
|
||||
.decode()
|
||||
.replace('=', '')
|
||||
)
|
||||
return session_api_key
|
||||
|
||||
async def ensure_num_conversations_below_limit(self, sid: str, user_id: str | None):
|
||||
response_ids = await self.get_running_agent_loops(user_id)
|
||||
if len(response_ids) >= self.config.max_concurrent_conversations:
|
||||
logger.info(
|
||||
f'too_many_sessions_for:{user_id or ""}',
|
||||
extra={'session_id': sid, 'user_id': user_id},
|
||||
)
|
||||
# Get the conversations sorted (oldest first)
|
||||
conversation_store = await self._get_conversation_store(user_id)
|
||||
conversations = await conversation_store.get_all_metadata(response_ids)
|
||||
conversations.sort(key=_last_updated_at_key, reverse=True)
|
||||
|
||||
while len(conversations) >= self.config.max_concurrent_conversations:
|
||||
oldest_conversation_id = conversations.pop().conversation_id
|
||||
logger.debug(
|
||||
f'closing_from_too_many_sessions:{user_id or ""}:{oldest_conversation_id}',
|
||||
extra={'session_id': oldest_conversation_id, 'user_id': user_id},
|
||||
)
|
||||
# Send status message to client and close session.
|
||||
status_update_dict = {
|
||||
'status_update': True,
|
||||
'type': 'error',
|
||||
'id': 'AGENT_ERROR$TOO_MANY_CONVERSATIONS',
|
||||
'message': 'Too many conversations at once. If you are still using this one, try reactivating it by prompting the agent to continue',
|
||||
}
|
||||
await self.sio.emit(
|
||||
'oh_event',
|
||||
status_update_dict,
|
||||
to=ROOM_KEY.format(sid=oldest_conversation_id),
|
||||
)
|
||||
await self.close_session(oldest_conversation_id)
|
||||
|
||||
def _get_provider_handler(self, settings: Settings):
|
||||
provider_tokens = None
|
||||
if isinstance(settings, ConversationInitData):
|
||||
provider_tokens = settings.git_provider_tokens
|
||||
provider_handler = ProviderHandler(
|
||||
provider_tokens=provider_tokens
|
||||
or cast(PROVIDER_TOKEN_TYPE, MappingProxyType({}))
|
||||
)
|
||||
return provider_handler
|
||||
|
||||
async def _create_runtime(self, sid: str, user_id: str | None, settings: Settings):
|
||||
# This session is created here only because it is the easiest way to get a runtime, which
|
||||
# is the easiest way to create the needed docker container
|
||||
session = Session(
|
||||
sid=sid,
|
||||
file_store=self.file_store,
|
||||
config=self.config,
|
||||
sio=self.sio,
|
||||
user_id=user_id,
|
||||
)
|
||||
agent_cls = settings.agent or self.config.default_agent
|
||||
agent_name = agent_cls if agent_cls is not None else 'agent'
|
||||
llm = LLM(
|
||||
config=self.config.get_llm_config_from_agent(agent_name),
|
||||
retry_listener=session._notify_on_llm_retry,
|
||||
)
|
||||
llm = session._create_llm(agent_cls)
|
||||
agent_config = self.config.get_agent_config(agent_cls)
|
||||
agent = Agent.get_cls(agent_cls)(llm, agent_config)
|
||||
|
||||
config = self.config.model_copy(deep=True)
|
||||
env_vars = config.sandbox.runtime_startup_env_vars
|
||||
env_vars['CONVERSATION_MANAGER_CLASS'] = (
|
||||
'openhands.server.conversation_manager.standalone_conversation_manager.StandaloneConversationManager'
|
||||
)
|
||||
env_vars['SERVE_FRONTEND'] = '0'
|
||||
env_vars['RUNTIME'] = 'local'
|
||||
env_vars['USER'] = 'CURRENT_USER'
|
||||
env_vars['SESSION_API_KEY'] = self._get_session_api_key_for_conversation(sid)
|
||||
|
||||
# Set up mounted volume for conversation directory within workspace
|
||||
# TODO: Check if we are using the standard event store and file store
|
||||
volumes = config.sandbox.volumes
|
||||
if not config.sandbox.volumes:
|
||||
volumes = []
|
||||
else:
|
||||
volumes = [v.strip() for v in config.sandbox.volumes.split(',')]
|
||||
conversation_dir = get_conversation_dir(sid, user_id)
|
||||
volumes.append(
|
||||
f'{config.file_store_path}/{conversation_dir}:{AppConfig.model_fields["file_store_path"].default}/{conversation_dir}:rw'
|
||||
)
|
||||
config.sandbox.volumes = ','.join(volumes)
|
||||
|
||||
# Currently this eventstream is never used and only exists because one is required in order to create a docker runtime
|
||||
event_stream = EventStream(sid, self.file_store, user_id)
|
||||
|
||||
runtime = DockerRuntime(
|
||||
config=config,
|
||||
event_stream=event_stream,
|
||||
sid=sid,
|
||||
plugins=agent.sandbox_plugins,
|
||||
headless_mode=False,
|
||||
attach_to_existing=False,
|
||||
env_vars=env_vars,
|
||||
main_module='openhands.server',
|
||||
)
|
||||
|
||||
# Hack - disable setting initial env.
|
||||
runtime.setup_initial_env = lambda: None # type:ignore
|
||||
|
||||
return runtime
|
||||
|
||||
|
||||
def _last_updated_at_key(conversation: ConversationMetadata) -> float:
|
||||
last_updated_at = conversation.last_updated_at
|
||||
if last_updated_at is None:
|
||||
return 0.0
|
||||
return last_updated_at.timestamp()
|
||||
@ -1,6 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from openhands.events.event_store_abc import EventStoreABC
|
||||
from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -11,4 +12,5 @@ class AgentLoopInfo:
|
||||
conversation_id: str
|
||||
url: str | None
|
||||
session_api_key: str | None
|
||||
event_store: EventStoreABC
|
||||
event_store: EventStoreABC | None
|
||||
status: ConversationStatus = field(default=ConversationStatus.RUNNING)
|
||||
|
||||
@ -204,11 +204,7 @@ class SessionApiKeyMiddleware:
|
||||
async def __call__(
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
) -> Response:
|
||||
if (
|
||||
request.method != 'OPTIONS'
|
||||
and request.url.path != '/alive'
|
||||
and request.url.path != '/server_info'
|
||||
):
|
||||
if request.method != 'OPTIONS' and request.url.path.startswith('/api'):
|
||||
if self.session_api_key != request.headers.get('X-Session-API-Key'):
|
||||
return JSONResponse(
|
||||
{'code': 'invalid_session_api_key'},
|
||||
|
||||
39
openhands/server/routes/health.py
Normal file
39
openhands/server/routes/health.py
Normal file
@ -0,0 +1,39 @@
|
||||
import time
|
||||
from fastapi import FastAPI, Request
|
||||
|
||||
from openhands.runtime.utils.system_stats import get_system_stats
|
||||
|
||||
start_time = time.time()
|
||||
last_execution_time = start_time
|
||||
|
||||
def add_health_endpoints(app: FastAPI):
|
||||
@app.get('/alive')
|
||||
async def alive():
|
||||
return {'status': 'ok'}
|
||||
|
||||
|
||||
@app.get('/health')
|
||||
async def health() -> str:
|
||||
return 'OK'
|
||||
|
||||
|
||||
@app.get('/server_info')
|
||||
async def get_server_info():
|
||||
current_time = time.time()
|
||||
uptime = current_time - start_time
|
||||
idle_time = current_time - last_execution_time
|
||||
|
||||
response = {
|
||||
'uptime': uptime,
|
||||
'idle_time': idle_time,
|
||||
'resources': get_system_stats(),
|
||||
}
|
||||
return response
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def update_last_execution_time(request: Request, call_next):
|
||||
global last_execution_time
|
||||
response = await call_next(request)
|
||||
last_execution_time = time.time()
|
||||
return response
|
||||
@ -1,15 +1,13 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.integrations.provider import (
|
||||
CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA,
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
ProviderHandler,
|
||||
)
|
||||
@ -24,10 +22,9 @@ from openhands.server.data_models.conversation_info import ConversationInfo
|
||||
from openhands.server.data_models.conversation_info_result_set import (
|
||||
ConversationInfoResultSet,
|
||||
)
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
from openhands.server.services.conversation import create_new_conversation
|
||||
from openhands.server.shared import (
|
||||
ConversationStoreImpl,
|
||||
SettingsStoreImpl,
|
||||
config,
|
||||
conversation_manager,
|
||||
)
|
||||
@ -62,6 +59,7 @@ class InitSessionRequest(BaseModel):
|
||||
replay_json: str | None = None
|
||||
suggested_task: SuggestedTask | None = None
|
||||
conversation_instructions: str | None = None
|
||||
conversation_id: str = Field(default_factory=lambda: uuid.uuid4().hex)
|
||||
|
||||
model_config = {'extra': 'forbid'}
|
||||
|
||||
@ -69,112 +67,11 @@ class InitSessionRequest(BaseModel):
|
||||
class InitSessionResponse(BaseModel):
|
||||
status: str
|
||||
conversation_id: str
|
||||
conversation_url: str
|
||||
session_api_key: str | None
|
||||
message: str | None = None
|
||||
|
||||
|
||||
async def _create_new_conversation(
|
||||
user_id: str | None,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None,
|
||||
custom_secrets: CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA | None,
|
||||
selected_repository: str | None,
|
||||
selected_branch: str | None,
|
||||
initial_user_msg: str | None,
|
||||
image_urls: list[str] | None,
|
||||
replay_json: str | None,
|
||||
conversation_instructions: str | None = None,
|
||||
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
|
||||
attach_convo_id: bool = False,
|
||||
) -> AgentLoopInfo:
|
||||
logger.info(
|
||||
'Creating conversation',
|
||||
extra={
|
||||
'signal': 'create_conversation',
|
||||
'user_id': user_id,
|
||||
'trigger': conversation_trigger.value,
|
||||
},
|
||||
)
|
||||
logger.info('Loading settings')
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
||||
settings = await settings_store.load()
|
||||
logger.info('Settings loaded')
|
||||
|
||||
session_init_args: dict[str, Any] = {}
|
||||
if settings:
|
||||
session_init_args = {**settings.__dict__, **session_init_args}
|
||||
# We could use litellm.check_valid_key for a more accurate check,
|
||||
# but that would run a tiny inference.
|
||||
if (
|
||||
not settings.llm_api_key
|
||||
or settings.llm_api_key.get_secret_value().isspace()
|
||||
):
|
||||
logger.warning(f'Missing api key for model {settings.llm_model}')
|
||||
raise LLMAuthenticationError(
|
||||
'Error authenticating with the LLM provider. Please check your API key'
|
||||
)
|
||||
|
||||
else:
|
||||
logger.warning('Settings not present, not starting conversation')
|
||||
raise MissingSettingsError('Settings not found')
|
||||
|
||||
session_init_args['git_provider_tokens'] = git_provider_tokens
|
||||
session_init_args['selected_repository'] = selected_repository
|
||||
session_init_args['custom_secrets'] = custom_secrets
|
||||
session_init_args['selected_branch'] = selected_branch
|
||||
session_init_args['conversation_instructions'] = conversation_instructions
|
||||
conversation_init_data = ConversationInitData(**session_init_args)
|
||||
logger.info('Loading conversation store')
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
logger.info('Conversation store loaded')
|
||||
|
||||
conversation_id = uuid.uuid4().hex
|
||||
while await conversation_store.exists(conversation_id):
|
||||
logger.warning(f'Collision on conversation ID: {conversation_id}. Retrying...')
|
||||
conversation_id = uuid.uuid4().hex
|
||||
logger.info(
|
||||
f'New conversation ID: {conversation_id}',
|
||||
extra={'user_id': user_id, 'session_id': conversation_id},
|
||||
)
|
||||
|
||||
conversation_title = get_default_conversation_title(conversation_id)
|
||||
|
||||
logger.info(f'Saving metadata for conversation {conversation_id}')
|
||||
await conversation_store.save_metadata(
|
||||
ConversationMetadata(
|
||||
trigger=conversation_trigger,
|
||||
conversation_id=conversation_id,
|
||||
title=conversation_title,
|
||||
user_id=user_id,
|
||||
selected_repository=selected_repository,
|
||||
selected_branch=selected_branch,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'Starting agent loop for conversation {conversation_id}',
|
||||
extra={'user_id': user_id, 'session_id': conversation_id},
|
||||
)
|
||||
initial_message_action = None
|
||||
if initial_user_msg or image_urls:
|
||||
user_msg = (
|
||||
initial_user_msg.format(conversation_id)
|
||||
if attach_convo_id and initial_user_msg
|
||||
else initial_user_msg
|
||||
)
|
||||
initial_message_action = MessageAction(
|
||||
content=user_msg or '',
|
||||
image_urls=image_urls or [],
|
||||
)
|
||||
agent_loop_info = await conversation_manager.maybe_start_agent_loop(
|
||||
conversation_id,
|
||||
conversation_init_data,
|
||||
user_id,
|
||||
initial_user_msg=initial_message_action,
|
||||
replay_json=replay_json,
|
||||
)
|
||||
logger.info(f'Finished initializing conversation {agent_loop_info.conversation_id}')
|
||||
return agent_loop_info
|
||||
# Temporary alias since the private variable was referenced publicly - delete once deploy project is updated.
|
||||
_create_new_conversation = create_new_conversation
|
||||
|
||||
|
||||
@app.post('/conversations')
|
||||
@ -190,7 +87,7 @@ async def new_conversation(
|
||||
After successful initialization, the client should connect to the WebSocket
|
||||
using the returned conversation ID.
|
||||
"""
|
||||
logger.info('Initializing new conversation')
|
||||
logger.info(f'initializing_new_conversation:{data}')
|
||||
repository = data.repository
|
||||
selected_branch = data.selected_branch
|
||||
initial_user_msg = data.initial_user_msg
|
||||
@ -215,8 +112,8 @@ async def new_conversation(
|
||||
# Check against git_provider, otherwise check all provider apis
|
||||
await provider_handler.verify_repo_provider(repository, git_provider)
|
||||
|
||||
# Create conversation with initial message
|
||||
agent_loop_info = await _create_new_conversation(
|
||||
conversation_id = data.conversation_id
|
||||
await create_new_conversation(
|
||||
user_id=user_id,
|
||||
git_provider_tokens=provider_tokens,
|
||||
custom_secrets=user_secrets.custom_secrets if user_secrets else None,
|
||||
@ -226,14 +123,14 @@ async def new_conversation(
|
||||
image_urls=image_urls,
|
||||
replay_json=replay_json,
|
||||
conversation_trigger=conversation_trigger,
|
||||
conversation_instructions=conversation_instructions
|
||||
conversation_instructions=conversation_instructions,
|
||||
git_provider=git_provider,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
return InitSessionResponse(
|
||||
status='ok',
|
||||
conversation_id=agent_loop_info.conversation_id,
|
||||
conversation_url=agent_loop_info.url,
|
||||
session_api_key=agent_loop_info.session_api_key,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
except MissingSettingsError as e:
|
||||
return JSONResponse(
|
||||
@ -270,7 +167,6 @@ async def new_conversation(
|
||||
async def search_conversations(
|
||||
page_id: str | None = None,
|
||||
limit: int = 20,
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
conversation_store: ConversationStore = Depends(get_conversation_store),
|
||||
) -> ConversationInfoResultSet:
|
||||
conversation_metadata_result_set = await conversation_store.search(page_id, limit)
|
||||
@ -289,9 +185,6 @@ async def search_conversations(
|
||||
conversation_ids = set(
|
||||
conversation.conversation_id for conversation in filtered_results
|
||||
)
|
||||
running_conversations = await conversation_manager.get_running_agent_loops(
|
||||
user_id, conversation_ids
|
||||
)
|
||||
connection_ids_to_conversation_ids = await conversation_manager.get_connections(filter_to_sids=conversation_ids)
|
||||
agent_loop_info = await conversation_manager.get_agent_loop_info(filter_to_sids=conversation_ids)
|
||||
agent_loop_info_by_conversation_id = {info.conversation_id: info for info in agent_loop_info}
|
||||
@ -299,7 +192,6 @@ async def search_conversations(
|
||||
results=await wait_all(
|
||||
_get_conversation_info(
|
||||
conversation=conversation,
|
||||
is_running=conversation.conversation_id in running_conversations,
|
||||
num_connections=sum(
|
||||
1 for conversation_id in connection_ids_to_conversation_ids.values()
|
||||
if conversation_id == conversation.conversation_id
|
||||
@ -321,11 +213,10 @@ async def get_conversation(
|
||||
) -> ConversationInfo | None:
|
||||
try:
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
is_running = await conversation_manager.is_agent_loop_running(conversation_id)
|
||||
num_connections = len(await conversation_manager.get_connections(filter_to_sids={conversation_id}))
|
||||
agent_loop_infos = await conversation_manager.get_agent_loop_info(filter_to_sids={conversation_id})
|
||||
agent_loop_info = agent_loop_infos[0] if agent_loop_infos else None
|
||||
conversation_info = await _get_conversation_info(metadata, is_running, num_connections, agent_loop_info)
|
||||
conversation_info = await _get_conversation_info(metadata, num_connections, agent_loop_info)
|
||||
return conversation_info
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
@ -352,7 +243,6 @@ async def delete_conversation(
|
||||
|
||||
async def _get_conversation_info(
|
||||
conversation: ConversationMetadata,
|
||||
is_running: bool,
|
||||
num_connections: int,
|
||||
agent_loop_info: AgentLoopInfo | None,
|
||||
) -> ConversationInfo | None:
|
||||
@ -368,7 +258,7 @@ async def _get_conversation_info(
|
||||
created_at=conversation.created_at,
|
||||
selected_repository=conversation.selected_repository,
|
||||
status=(
|
||||
ConversationStatus.RUNNING if is_running else ConversationStatus.STOPPED
|
||||
agent_loop_info.status if agent_loop_info else ConversationStatus.STOPPED
|
||||
),
|
||||
num_connections=num_connections,
|
||||
url=agent_loop_info.url if agent_loop_info else None,
|
||||
|
||||
122
openhands/server/services/conversation.py
Normal file
122
openhands/server/services/conversation.py
Normal file
@ -0,0 +1,122 @@
|
||||
|
||||
|
||||
from typing import Any
|
||||
import uuid
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.integrations.provider import CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA, PROVIDER_TOKEN_TYPE
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
from openhands.server.shared import ConversationStoreImpl, SettingsStoreImpl, config, conversation_manager
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationMetadata, ConversationTrigger
|
||||
from openhands.utils.conversation_summary import get_default_conversation_title
|
||||
|
||||
|
||||
async def create_new_conversation(
|
||||
user_id: str | None,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None,
|
||||
custom_secrets: CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA | None,
|
||||
selected_repository: str | None,
|
||||
selected_branch: str | None,
|
||||
initial_user_msg: str | None,
|
||||
image_urls: list[str] | None,
|
||||
replay_json: str | None,
|
||||
conversation_instructions: str | None = None,
|
||||
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
|
||||
attach_convo_id: bool = False,
|
||||
git_provider: ProviderType | None = None,
|
||||
conversation_id: str | None = None,
|
||||
) -> AgentLoopInfo:
|
||||
logger.info(
|
||||
'Creating conversation',
|
||||
extra={
|
||||
'signal': 'create_conversation',
|
||||
'user_id': user_id,
|
||||
'trigger': conversation_trigger.value,
|
||||
},
|
||||
)
|
||||
logger.info('Loading settings')
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
||||
settings = await settings_store.load()
|
||||
logger.info('Settings loaded')
|
||||
|
||||
session_init_args: dict[str, Any] = {}
|
||||
if settings:
|
||||
session_init_args = {**settings.__dict__, **session_init_args}
|
||||
# We could use litellm.check_valid_key for a more accurate check,
|
||||
# but that would run a tiny inference.
|
||||
if (
|
||||
not settings.llm_api_key
|
||||
or settings.llm_api_key.get_secret_value().isspace()
|
||||
):
|
||||
logger.warning(f'Missing api key for model {settings.llm_model}')
|
||||
raise LLMAuthenticationError(
|
||||
'Error authenticating with the LLM provider. Please check your API key'
|
||||
)
|
||||
|
||||
else:
|
||||
logger.warning('Settings not present, not starting conversation')
|
||||
raise MissingSettingsError('Settings not found')
|
||||
|
||||
session_init_args['git_provider_tokens'] = git_provider_tokens
|
||||
session_init_args['selected_repository'] = selected_repository
|
||||
session_init_args['custom_secrets'] = custom_secrets
|
||||
session_init_args['selected_branch'] = selected_branch
|
||||
session_init_args['git_provider'] = git_provider
|
||||
session_init_args['conversation_instructions'] = conversation_instructions
|
||||
conversation_init_data = ConversationInitData(**session_init_args)
|
||||
logger.info('Loading conversation store')
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
logger.info('Conversation store loaded')
|
||||
|
||||
# For nested runtimes, we allow a single conversation id, passed in on container creation
|
||||
if conversation_id is None:
|
||||
conversation_id = uuid.uuid4().hex
|
||||
while await conversation_store.exists(conversation_id):
|
||||
logger.warning(f'Collision on conversation ID: {conversation_id}. Retrying...')
|
||||
conversation_id = uuid.uuid4().hex
|
||||
logger.info(
|
||||
f'New conversation ID: {conversation_id}',
|
||||
extra={'user_id': user_id, 'session_id': conversation_id},
|
||||
)
|
||||
|
||||
conversation_title = get_default_conversation_title(conversation_id)
|
||||
|
||||
logger.info(f'Saving metadata for conversation {conversation_id}')
|
||||
await conversation_store.save_metadata(
|
||||
ConversationMetadata(
|
||||
trigger=conversation_trigger,
|
||||
conversation_id=conversation_id,
|
||||
title=conversation_title,
|
||||
user_id=user_id,
|
||||
selected_repository=selected_repository,
|
||||
selected_branch=selected_branch,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'Starting agent loop for conversation {conversation_id}',
|
||||
extra={'user_id': user_id, 'session_id': conversation_id},
|
||||
)
|
||||
initial_message_action = None
|
||||
if initial_user_msg or image_urls:
|
||||
user_msg = (
|
||||
initial_user_msg.format(conversation_id)
|
||||
if attach_convo_id and initial_user_msg
|
||||
else initial_user_msg
|
||||
)
|
||||
initial_message_action = MessageAction(
|
||||
content=user_msg or '',
|
||||
image_urls=image_urls or [],
|
||||
)
|
||||
agent_loop_info = await conversation_manager.maybe_start_agent_loop(
|
||||
conversation_id,
|
||||
conversation_init_data,
|
||||
user_id,
|
||||
initial_user_msg=initial_message_action,
|
||||
replay_json=replay_json,
|
||||
)
|
||||
logger.info(f'Finished initializing conversation {agent_loop_info.conversation_id}')
|
||||
return agent_loop_info
|
||||
@ -1,6 +1,7 @@
|
||||
from pydantic import Field
|
||||
|
||||
from openhands.integrations.provider import CUSTOM_SECRETS_TYPE, PROVIDER_TOKEN_TYPE
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
@ -15,6 +16,7 @@ class ConversationInitData(Settings):
|
||||
replay_json: str | None = Field(default=None)
|
||||
selected_branch: str | None = Field(default=None)
|
||||
conversation_instructions: str | None = Field(default=None)
|
||||
git_provider: ProviderType | None = Field(default=None)
|
||||
|
||||
model_config = {
|
||||
'arbitrary_types_allowed': True,
|
||||
|
||||
@ -2,5 +2,6 @@ from enum import Enum
|
||||
|
||||
|
||||
class ConversationStatus(Enum):
|
||||
STARTING = 'STARTING'
|
||||
RUNNING = 'RUNNING'
|
||||
STOPPED = 'STOPPED'
|
||||
|
||||
@ -157,7 +157,6 @@ async def test_search_conversations():
|
||||
result_set = await search_conversations(
|
||||
page_id=None,
|
||||
limit=20,
|
||||
user_id='12345',
|
||||
conversation_store=mock_store,
|
||||
)
|
||||
|
||||
@ -242,9 +241,9 @@ async def test_get_missing_conversation():
|
||||
async def test_new_conversation_success(provider_handler_mock):
|
||||
"""Test successful creation of a new conversation."""
|
||||
with _patch_store():
|
||||
# Mock the _create_new_conversation function directly
|
||||
# Mock the create_new_conversation function directly
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations._create_new_conversation'
|
||||
'openhands.server.routes.manage_conversations.create_new_conversation'
|
||||
) as mock_create_conversation:
|
||||
# Set up the mock to return a conversation ID
|
||||
mock_create_conversation.return_value = MagicMock(
|
||||
@ -258,6 +257,7 @@ async def test_new_conversation_success(provider_handler_mock):
|
||||
selected_branch='main',
|
||||
initial_user_msg='Hello, agent!',
|
||||
image_urls=['https://example.com/image.jpg'],
|
||||
conversation_id='test_conversation_id',
|
||||
)
|
||||
|
||||
# Call new_conversation
|
||||
@ -267,9 +267,8 @@ async def test_new_conversation_success(provider_handler_mock):
|
||||
assert isinstance(response, InitSessionResponse)
|
||||
assert response.status == 'ok'
|
||||
assert response.conversation_id == 'test_conversation_id'
|
||||
assert response.conversation_url == 'https://my-conversation.com'
|
||||
|
||||
# Verify that _create_new_conversation was called with the correct arguments
|
||||
# Verify that create_new_conversation was called with the correct arguments
|
||||
mock_create_conversation.assert_called_once()
|
||||
call_args = mock_create_conversation.call_args[1]
|
||||
assert call_args['user_id'] == 'test_user'
|
||||
@ -284,9 +283,9 @@ async def test_new_conversation_success(provider_handler_mock):
|
||||
async def test_new_conversation_with_suggested_task(provider_handler_mock):
|
||||
"""Test creating a new conversation with a suggested task."""
|
||||
with _patch_store():
|
||||
# Mock the _create_new_conversation function directly
|
||||
# Mock the create_new_conversation function directly
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations._create_new_conversation'
|
||||
'openhands.server.routes.manage_conversations.create_new_conversation'
|
||||
) as mock_create_conversation:
|
||||
# Set up the mock to return a conversation ID
|
||||
mock_create_conversation.return_value = MagicMock(
|
||||
@ -315,6 +314,7 @@ async def test_new_conversation_with_suggested_task(provider_handler_mock):
|
||||
repository='test/repo',
|
||||
selected_branch='main',
|
||||
suggested_task=test_task,
|
||||
conversation_id='test_conversation_id',
|
||||
)
|
||||
|
||||
# Call new_conversation
|
||||
@ -324,9 +324,8 @@ async def test_new_conversation_with_suggested_task(provider_handler_mock):
|
||||
assert isinstance(response, InitSessionResponse)
|
||||
assert response.status == 'ok'
|
||||
assert response.conversation_id == 'test_conversation_id'
|
||||
assert response.conversation_url == 'https://my-conversation.com'
|
||||
|
||||
# Verify that _create_new_conversation was called with the correct arguments
|
||||
# Verify that create_new_conversation was called with the correct arguments
|
||||
mock_create_conversation.assert_called_once()
|
||||
call_args = mock_create_conversation.call_args[1]
|
||||
assert call_args['user_id'] == 'test_user'
|
||||
@ -349,9 +348,9 @@ async def test_new_conversation_with_suggested_task(provider_handler_mock):
|
||||
async def test_new_conversation_missing_settings(provider_handler_mock):
|
||||
"""Test creating a new conversation when settings are missing."""
|
||||
with _patch_store():
|
||||
# Mock the _create_new_conversation function to raise MissingSettingsError
|
||||
# Mock the create_new_conversation function to raise MissingSettingsError
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations._create_new_conversation'
|
||||
'openhands.server.routes.manage_conversations.create_new_conversation'
|
||||
) as mock_create_conversation:
|
||||
# Set up the mock to raise MissingSettingsError
|
||||
mock_create_conversation.side_effect = MissingSettingsError(
|
||||
@ -378,9 +377,9 @@ async def test_new_conversation_missing_settings(provider_handler_mock):
|
||||
async def test_new_conversation_invalid_session_api_key(provider_handler_mock):
|
||||
"""Test creating a new conversation with an invalid API key."""
|
||||
with _patch_store():
|
||||
# Mock the _create_new_conversation function to raise LLMAuthenticationError
|
||||
# Mock the create_new_conversation function to raise LLMAuthenticationError
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations._create_new_conversation'
|
||||
'openhands.server.routes.manage_conversations.create_new_conversation'
|
||||
) as mock_create_conversation:
|
||||
# Set up the mock to raise LLMAuthenticationError
|
||||
mock_create_conversation.side_effect = LLMAuthenticationError(
|
||||
@ -469,9 +468,9 @@ async def test_delete_conversation():
|
||||
async def test_new_conversation_with_bearer_auth(provider_handler_mock):
|
||||
"""Test creating a new conversation with bearer authentication."""
|
||||
with _patch_store():
|
||||
# Mock the _create_new_conversation function
|
||||
# Mock the create_new_conversation function
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations._create_new_conversation'
|
||||
'openhands.server.routes.manage_conversations.create_new_conversation'
|
||||
) as mock_create_conversation:
|
||||
# Set up the mock to return a conversation ID
|
||||
mock_create_conversation.return_value = MagicMock(
|
||||
@ -494,7 +493,7 @@ async def test_new_conversation_with_bearer_auth(provider_handler_mock):
|
||||
assert isinstance(response, InitSessionResponse)
|
||||
assert response.status == 'ok'
|
||||
|
||||
# Verify that _create_new_conversation was called with REMOTE_API_KEY trigger
|
||||
# Verify that create_new_conversation was called with REMOTE_API_KEY trigger
|
||||
mock_create_conversation.assert_called_once()
|
||||
call_args = mock_create_conversation.call_args[1]
|
||||
assert (
|
||||
@ -506,9 +505,9 @@ async def test_new_conversation_with_bearer_auth(provider_handler_mock):
|
||||
async def test_new_conversation_with_null_repository():
|
||||
"""Test creating a new conversation with null repository."""
|
||||
with _patch_store():
|
||||
# Mock the _create_new_conversation function
|
||||
# Mock the create_new_conversation function
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations._create_new_conversation'
|
||||
'openhands.server.routes.manage_conversations.create_new_conversation'
|
||||
) as mock_create_conversation:
|
||||
# Set up the mock to return a conversation ID
|
||||
mock_create_conversation.return_value = MagicMock(
|
||||
@ -531,7 +530,7 @@ async def test_new_conversation_with_null_repository():
|
||||
assert isinstance(response, InitSessionResponse)
|
||||
assert response.status == 'ok'
|
||||
|
||||
# Verify that _create_new_conversation was called with None repository
|
||||
# Verify that create_new_conversation was called with None repository
|
||||
mock_create_conversation.assert_called_once()
|
||||
call_args = mock_create_conversation.call_args[1]
|
||||
assert call_args['selected_repository'] is None
|
||||
@ -547,9 +546,9 @@ async def test_new_conversation_with_provider_authentication_error(
|
||||
|
||||
"""Test creating a new conversation when provider authentication fails."""
|
||||
with _patch_store():
|
||||
# Mock the _create_new_conversation function
|
||||
# Mock the create_new_conversation function
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations._create_new_conversation'
|
||||
'openhands.server.routes.manage_conversations.create_new_conversation'
|
||||
) as mock_create_conversation:
|
||||
# Set up the mock to return a conversation ID
|
||||
mock_create_conversation.return_value = 'test_conversation_id'
|
||||
@ -578,7 +577,7 @@ async def test_new_conversation_with_provider_authentication_error(
|
||||
'test/repo', None
|
||||
)
|
||||
|
||||
# Verify that _create_new_conversation was not called
|
||||
# Verify that create_new_conversation was not called
|
||||
mock_create_conversation.assert_not_called()
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user