mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
API Updates to facilitate nested runtimes. (#8525)
This commit is contained in:
parent
21d0990be4
commit
033788c2d0
@ -8,6 +8,8 @@ from openhands.core.config import AppConfig
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.event_store import EventStore
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
|
||||
from openhands.server.data_models.conversation_info import ConversationInfo
|
||||
from openhands.server.monitoring import MonitoringListener
|
||||
from openhands.server.session.conversation import Conversation
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
@ -53,7 +55,7 @@ class ConversationManager(ABC):
|
||||
connection_id: str,
|
||||
settings: Settings,
|
||||
user_id: str | None,
|
||||
) -> EventStore | None:
|
||||
) -> AgentLoopInfo | None:
|
||||
"""Join a conversation and return its event stream."""
|
||||
|
||||
async def is_agent_loop_running(self, sid: str) -> bool:
|
||||
@ -81,7 +83,7 @@ class ConversationManager(ABC):
|
||||
user_id: str | None,
|
||||
initial_user_msg: MessageAction | None = None,
|
||||
replay_json: str | None = None,
|
||||
) -> EventStore:
|
||||
) -> AgentLoopInfo:
|
||||
"""Start an event loop if one is not already running"""
|
||||
|
||||
@abstractmethod
|
||||
@ -96,6 +98,12 @@ class ConversationManager(ABC):
|
||||
async def close_session(self, sid: str):
|
||||
"""Close a session."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_agent_loop_info(
|
||||
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
|
||||
) -> list[AgentLoopInfo]:
|
||||
"""Get the AgentLoopInfo for conversations."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_instance(
|
||||
|
||||
@ -11,9 +11,9 @@ from openhands.core.exceptions import AgentRuntimeUnavailableError
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.event_store import EventStore
|
||||
from openhands.events.stream import EventStreamSubscriber, session_exists
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
|
||||
from openhands.server.monitoring import MonitoringListener
|
||||
from openhands.server.session.agent_session import WAIT_TIME_BEFORE_CLOSE
|
||||
from openhands.server.session.conversation import Conversation
|
||||
@ -119,21 +119,15 @@ class StandaloneConversationManager(ConversationManager):
|
||||
connection_id: str,
|
||||
settings: Settings,
|
||||
user_id: str | None,
|
||||
) -> EventStore:
|
||||
) -> AgentLoopInfo:
|
||||
logger.info(
|
||||
f'join_conversation:{sid}:{connection_id}',
|
||||
extra={'session_id': sid, 'user_id': user_id},
|
||||
)
|
||||
await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
|
||||
self._local_connection_id_to_session_id[connection_id] = sid
|
||||
event_stream = await self.maybe_start_agent_loop(sid, settings, user_id)
|
||||
if not event_stream:
|
||||
logger.error(
|
||||
f'No event stream after joining conversation: {sid}',
|
||||
extra={'session_id': sid},
|
||||
)
|
||||
raise RuntimeError(f'no_event_stream:{sid}')
|
||||
return event_stream
|
||||
agent_loop_info = await self.maybe_start_agent_loop(sid, settings, user_id)
|
||||
return agent_loop_info
|
||||
|
||||
async def detach_from_conversation(self, conversation: Conversation):
|
||||
sid = conversation.sid
|
||||
@ -251,21 +245,14 @@ class StandaloneConversationManager(ConversationManager):
|
||||
user_id: str | None,
|
||||
initial_user_msg: MessageAction | None = None,
|
||||
replay_json: str | None = None,
|
||||
) -> EventStore:
|
||||
) -> AgentLoopInfo:
|
||||
logger.info(f'maybe_start_agent_loop:{sid}', extra={'session_id': sid})
|
||||
if not await self.is_agent_loop_running(sid):
|
||||
await self._start_agent_loop(
|
||||
session = self._local_agent_loops_by_sid.get(sid)
|
||||
if not session:
|
||||
session = await self._start_agent_loop(
|
||||
sid, settings, user_id, initial_user_msg, replay_json
|
||||
)
|
||||
|
||||
event_store = await self._get_event_store(sid, user_id)
|
||||
if not event_store:
|
||||
logger.error(
|
||||
f'No event stream after starting agent loop: {sid}',
|
||||
extra={'session_id': sid},
|
||||
)
|
||||
raise RuntimeError(f'no_event_stream:{sid}')
|
||||
return event_store
|
||||
return self._agent_loop_info_from_session(session)
|
||||
|
||||
async def _start_agent_loop(
|
||||
self,
|
||||
@ -330,22 +317,6 @@ class StandaloneConversationManager(ConversationManager):
|
||||
pass # Already subscribed - take no action
|
||||
return session
|
||||
|
||||
async def _get_event_store(
|
||||
self, sid: str, user_id: str | None
|
||||
) -> EventStore | None:
|
||||
logger.info(f'_get_event_store:{sid}', extra={'session_id': sid})
|
||||
session = self._local_agent_loops_by_sid.get(sid)
|
||||
if session:
|
||||
logger.info(f'found_local_agent_loop:{sid}', extra={'session_id': sid})
|
||||
event_stream = session.agent_session.event_stream
|
||||
return EventStore(
|
||||
event_stream.sid,
|
||||
event_stream.file_store,
|
||||
event_stream.user_id,
|
||||
event_stream.cur_id,
|
||||
)
|
||||
return None
|
||||
|
||||
async def send_to_event_stream(self, connection_id: str, data: dict):
|
||||
# If there is a local session running, send to that
|
||||
sid = self._local_connection_id_to_session_id.get(connection_id)
|
||||
@ -493,6 +464,29 @@ class StandaloneConversationManager(ConversationManager):
|
||||
|
||||
await conversation_store.save_metadata(conversation)
|
||||
|
||||
async def get_agent_loop_info(
|
||||
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
|
||||
):
|
||||
results = []
|
||||
for session in self._local_agent_loops_by_sid.values():
|
||||
if user_id and session.user_id != user_id:
|
||||
continue
|
||||
if filter_to_sids and session.sid not in filter_to_sids:
|
||||
continue
|
||||
results.append(self._agent_loop_info_from_session(session))
|
||||
return results
|
||||
|
||||
def _agent_loop_info_from_session(self, session: Session):
|
||||
return AgentLoopInfo(
|
||||
conversation_id=session.sid,
|
||||
url=self._get_conversation_url(session.sid),
|
||||
api_key=None,
|
||||
event_store=session.agent_session.event_stream,
|
||||
)
|
||||
|
||||
def _get_conversation_url(self, conversation_id: str):
|
||||
return f"/conversations/{conversation_id}"
|
||||
|
||||
|
||||
def _last_updated_at_key(conversation: ConversationMetadata) -> float:
|
||||
last_updated_at = conversation.last_updated_at
|
||||
|
||||
14
openhands/server/data_models/agent_loop_info.py
Normal file
14
openhands/server/data_models/agent_loop_info.py
Normal file
@ -0,0 +1,14 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.events.event_store_abc import EventStoreABC
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentLoopInfo:
|
||||
"""
|
||||
Information about an agent loop - the URL on which to locate it and the event store
|
||||
"""
|
||||
conversation_id: str
|
||||
url: str | None
|
||||
api_key: str | None
|
||||
event_store: EventStoreABC
|
||||
@ -19,4 +19,6 @@ class ConversationInfo:
|
||||
selected_repository: str | None = None
|
||||
trigger: ConversationTrigger | None = None
|
||||
num_connections: int = 0
|
||||
url: str | None = None
|
||||
api_key: str | None = None
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
@ -103,7 +103,7 @@ async def connect(connection_id: str, environ: dict) -> None:
|
||||
|
||||
conversation_init_data = ConversationInitData(**session_init_args)
|
||||
|
||||
event_stream = await conversation_manager.join_conversation(
|
||||
agent_loop_info = await conversation_manager.join_conversation(
|
||||
conversation_id,
|
||||
connection_id,
|
||||
conversation_init_data,
|
||||
@ -113,9 +113,11 @@ async def connect(connection_id: str, environ: dict) -> None:
|
||||
f'Connected to conversation {conversation_id} with connection_id {connection_id}. Replaying event stream...'
|
||||
)
|
||||
agent_state_changed = None
|
||||
if event_stream is None:
|
||||
if agent_loop_info is None:
|
||||
raise ConnectionRefusedError('Failed to join conversation')
|
||||
async_store = AsyncEventStoreWrapper(event_stream, latest_event_id + 1)
|
||||
async_store = AsyncEventStoreWrapper(
|
||||
agent_loop_info.event_store, latest_event_id + 1
|
||||
)
|
||||
async for event in async_store:
|
||||
logger.debug(f'oh_event: {event.__class__.__name__}')
|
||||
if isinstance(
|
||||
|
||||
@ -18,6 +18,7 @@ from openhands.integrations.service_types import (
|
||||
SuggestedTask,
|
||||
)
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
|
||||
from openhands.server.data_models.conversation_info import ConversationInfo
|
||||
from openhands.server.data_models.conversation_info_result_set import (
|
||||
ConversationInfoResultSet,
|
||||
@ -61,6 +62,14 @@ class InitSessionRequest(BaseModel):
|
||||
model_config = {'extra': 'forbid'}
|
||||
|
||||
|
||||
class InitSessionResponse(BaseModel):
|
||||
status: str
|
||||
conversation_id: str
|
||||
conversation_url: str
|
||||
api_key: str | None
|
||||
message: str | None = None
|
||||
|
||||
|
||||
async def _create_new_conversation(
|
||||
user_id: str | None,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None,
|
||||
@ -71,7 +80,7 @@ async def _create_new_conversation(
|
||||
replay_json: str | None,
|
||||
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
|
||||
attach_convo_id: bool = False,
|
||||
) -> str:
|
||||
) -> AgentLoopInfo:
|
||||
logger.info(
|
||||
'Creating conversation',
|
||||
extra={
|
||||
@ -149,15 +158,15 @@ async def _create_new_conversation(
|
||||
content=user_msg or '',
|
||||
image_urls=image_urls or [],
|
||||
)
|
||||
await conversation_manager.maybe_start_agent_loop(
|
||||
agent_loop_info = await conversation_manager.maybe_start_agent_loop(
|
||||
conversation_id,
|
||||
conversation_init_data,
|
||||
user_id,
|
||||
initial_user_msg=initial_message_action,
|
||||
replay_json=replay_json,
|
||||
)
|
||||
logger.info(f'Finished initializing conversation {conversation_id}')
|
||||
return conversation_id
|
||||
logger.info(f'Finished initializing conversation {agent_loop_info.conversation_id}')
|
||||
return agent_loop_info
|
||||
|
||||
|
||||
@app.post('/conversations')
|
||||
@ -166,7 +175,7 @@ async def new_conversation(
|
||||
user_id: str = Depends(get_user_id),
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE = Depends(get_provider_tokens),
|
||||
auth_type: AuthType | None = Depends(get_auth_type),
|
||||
) -> JSONResponse:
|
||||
) -> InitSessionResponse:
|
||||
"""Initialize a new session or join an existing one.
|
||||
|
||||
After successful initialization, the client should connect to the WebSocket
|
||||
@ -197,7 +206,7 @@ async def new_conversation(
|
||||
await provider_handler.verify_repo_provider(repository, git_provider)
|
||||
|
||||
# Create conversation with initial message
|
||||
conversation_id = await _create_new_conversation(
|
||||
agent_loop_info = await _create_new_conversation(
|
||||
user_id=user_id,
|
||||
git_provider_tokens=provider_tokens,
|
||||
selected_repository=repository,
|
||||
@ -208,8 +217,11 @@ async def new_conversation(
|
||||
conversation_trigger=conversation_trigger,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content={'status': 'ok', 'conversation_id': conversation_id}
|
||||
return InitSessionResponse(
|
||||
status='ok',
|
||||
conversation_id=agent_loop_info.conversation_id,
|
||||
conversation_url=agent_loop_info.url,
|
||||
api_key=agent_loop_info.api_key,
|
||||
)
|
||||
except MissingSettingsError as e:
|
||||
return JSONResponse(
|
||||
@ -269,6 +281,8 @@ async def search_conversations(
|
||||
user_id, conversation_ids
|
||||
)
|
||||
connection_ids_to_conversation_ids = await conversation_manager.get_connections(filter_to_sids=conversation_ids)
|
||||
agent_loop_info = await conversation_manager.get_agent_loop_info(filter_to_sids=conversation_ids)
|
||||
urls_by_conversation_id = {info.conversation_id: info.url for info in agent_loop_info}
|
||||
result = ConversationInfoResultSet(
|
||||
results=await wait_all(
|
||||
_get_conversation_info(
|
||||
@ -277,7 +291,8 @@ async def search_conversations(
|
||||
num_connections=sum(
|
||||
1 for conversation_id in connection_ids_to_conversation_ids.values()
|
||||
if conversation_id == conversation.conversation_id
|
||||
)
|
||||
),
|
||||
url=urls_by_conversation_id.get(conversation.conversation_id),
|
||||
)
|
||||
for conversation in filtered_results
|
||||
),
|
||||
@ -295,7 +310,9 @@ async def get_conversation(
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
is_running = await conversation_manager.is_agent_loop_running(conversation_id)
|
||||
num_connections = len(await conversation_manager.get_connections(filter_to_sids={conversation_id}))
|
||||
conversation_info = await _get_conversation_info(metadata, is_running, num_connections)
|
||||
agent_loop_info = await conversation_manager.get_agent_loop_info(filter_to_sids={conversation_id})
|
||||
url = agent_loop_info[0].url if agent_loop_info else None
|
||||
conversation_info = await _get_conversation_info(metadata, is_running, num_connections, url)
|
||||
return conversation_info
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
@ -323,7 +340,8 @@ async def delete_conversation(
|
||||
async def _get_conversation_info(
|
||||
conversation: ConversationMetadata,
|
||||
is_running: bool,
|
||||
num_connections: int
|
||||
num_connections: int,
|
||||
url: str | None,
|
||||
) -> ConversationInfo | None:
|
||||
try:
|
||||
title = conversation.title
|
||||
@ -339,7 +357,8 @@ async def _get_conversation_info(
|
||||
status=(
|
||||
ConversationStatus.RUNNING if is_running else ConversationStatus.STOPPED
|
||||
),
|
||||
num_connections=num_connections
|
||||
num_connections=num_connections,
|
||||
url=url,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
|
||||
@ -21,6 +21,7 @@ from openhands.server.data_models.conversation_info_result_set import (
|
||||
)
|
||||
from openhands.server.routes.manage_conversations import (
|
||||
InitSessionRequest,
|
||||
InitSessionResponse,
|
||||
delete_conversation,
|
||||
get_conversation,
|
||||
new_conversation,
|
||||
@ -112,8 +113,12 @@ async def test_search_conversations():
|
||||
async def mock_get_connections(*args, **kwargs):
|
||||
return {}
|
||||
|
||||
async def get_agent_loop_info(*args, **kwargs):
|
||||
return []
|
||||
|
||||
mock_manager.get_running_agent_loops = mock_get_running_agent_loops
|
||||
mock_manager.get_connections = mock_get_connections
|
||||
mock_manager.get_agent_loop_info = get_agent_loop_info
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations.datetime'
|
||||
) as mock_datetime:
|
||||
@ -165,6 +170,7 @@ async def test_search_conversations():
|
||||
status=ConversationStatus.STOPPED,
|
||||
selected_repository='foobar',
|
||||
num_connections=0,
|
||||
url=None,
|
||||
)
|
||||
]
|
||||
)
|
||||
@ -193,6 +199,7 @@ async def test_get_conversation():
|
||||
) as mock_manager:
|
||||
mock_manager.is_agent_loop_running = AsyncMock(return_value=False)
|
||||
mock_manager.get_connections = AsyncMock(return_value={})
|
||||
mock_manager.get_agent_loop_info = AsyncMock(return_value=[])
|
||||
|
||||
conversation = await get_conversation(
|
||||
'some_conversation_id', conversation_store=mock_store
|
||||
@ -206,6 +213,7 @@ async def test_get_conversation():
|
||||
status=ConversationStatus.STOPPED,
|
||||
selected_repository='foobar',
|
||||
num_connections=0,
|
||||
url=None,
|
||||
)
|
||||
assert conversation == expected
|
||||
|
||||
@ -234,7 +242,11 @@ async def test_new_conversation_success(provider_handler_mock):
|
||||
'openhands.server.routes.manage_conversations._create_new_conversation'
|
||||
) as mock_create_conversation:
|
||||
# Set up the mock to return a conversation ID
|
||||
mock_create_conversation.return_value = 'test_conversation_id'
|
||||
mock_create_conversation.return_value = MagicMock(
|
||||
conversation_id='test_conversation_id',
|
||||
url='https://my-conversation.com',
|
||||
api_key=None,
|
||||
)
|
||||
|
||||
test_request = InitSessionRequest(
|
||||
repository='test/repo',
|
||||
@ -247,12 +259,10 @@ async def test_new_conversation_success(provider_handler_mock):
|
||||
response = await create_new_test_conversation(test_request)
|
||||
|
||||
# Verify the response
|
||||
assert isinstance(response, JSONResponse)
|
||||
assert response.status_code == 200
|
||||
assert (
|
||||
response.body.decode('utf-8')
|
||||
== '{"status":"ok","conversation_id":"test_conversation_id"}'
|
||||
)
|
||||
assert isinstance(response, InitSessionResponse)
|
||||
assert response.status == 'ok'
|
||||
assert response.conversation_id == 'test_conversation_id'
|
||||
assert response.conversation_url == 'https://my-conversation.com'
|
||||
|
||||
# Verify that _create_new_conversation was called with the correct arguments
|
||||
mock_create_conversation.assert_called_once()
|
||||
@ -274,7 +284,11 @@ async def test_new_conversation_with_suggested_task(provider_handler_mock):
|
||||
'openhands.server.routes.manage_conversations._create_new_conversation'
|
||||
) as mock_create_conversation:
|
||||
# Set up the mock to return a conversation ID
|
||||
mock_create_conversation.return_value = 'test_conversation_id'
|
||||
mock_create_conversation.return_value = MagicMock(
|
||||
conversation_id='test_conversation_id',
|
||||
url='https://my-conversation.com',
|
||||
api_key=None,
|
||||
)
|
||||
|
||||
# Mock SuggestedTask.get_prompt_for_task
|
||||
with patch(
|
||||
@ -302,12 +316,10 @@ async def test_new_conversation_with_suggested_task(provider_handler_mock):
|
||||
response = await create_new_test_conversation(test_request)
|
||||
|
||||
# Verify the response
|
||||
assert isinstance(response, JSONResponse)
|
||||
assert response.status_code == 200
|
||||
assert (
|
||||
response.body.decode('utf-8')
|
||||
== '{"status":"ok","conversation_id":"test_conversation_id"}'
|
||||
)
|
||||
assert isinstance(response, InitSessionResponse)
|
||||
assert response.status == 'ok'
|
||||
assert response.conversation_id == 'test_conversation_id'
|
||||
assert response.conversation_url == 'https://my-conversation.com'
|
||||
|
||||
# Verify that _create_new_conversation was called with the correct arguments
|
||||
mock_create_conversation.assert_called_once()
|
||||
@ -457,7 +469,11 @@ async def test_new_conversation_with_bearer_auth(provider_handler_mock):
|
||||
'openhands.server.routes.manage_conversations._create_new_conversation'
|
||||
) as mock_create_conversation:
|
||||
# Set up the mock to return a conversation ID
|
||||
mock_create_conversation.return_value = 'test_conversation_id'
|
||||
mock_create_conversation.return_value = MagicMock(
|
||||
conversation_id='test_conversation_id',
|
||||
url='https://my-conversation.com',
|
||||
api_key=None,
|
||||
)
|
||||
|
||||
# Create the request object
|
||||
test_request = InitSessionRequest(
|
||||
@ -470,8 +486,8 @@ async def test_new_conversation_with_bearer_auth(provider_handler_mock):
|
||||
response = await create_new_test_conversation(test_request, AuthType.BEARER)
|
||||
|
||||
# Verify the response
|
||||
assert isinstance(response, JSONResponse)
|
||||
assert response.status_code == 200
|
||||
assert isinstance(response, InitSessionResponse)
|
||||
assert response.status == 'ok'
|
||||
|
||||
# Verify that _create_new_conversation was called with REMOTE_API_KEY trigger
|
||||
mock_create_conversation.assert_called_once()
|
||||
@ -490,7 +506,11 @@ async def test_new_conversation_with_null_repository():
|
||||
'openhands.server.routes.manage_conversations._create_new_conversation'
|
||||
) as mock_create_conversation:
|
||||
# Set up the mock to return a conversation ID
|
||||
mock_create_conversation.return_value = 'test_conversation_id'
|
||||
mock_create_conversation.return_value = MagicMock(
|
||||
conversation_id='test_conversation_id',
|
||||
url='https://my-conversation.com',
|
||||
api_key=None,
|
||||
)
|
||||
|
||||
# Create the request object with null repository
|
||||
test_request = InitSessionRequest(
|
||||
@ -503,8 +523,8 @@ async def test_new_conversation_with_null_repository():
|
||||
response = await create_new_test_conversation(test_request)
|
||||
|
||||
# Verify the response
|
||||
assert isinstance(response, JSONResponse)
|
||||
assert response.status_code == 200
|
||||
assert isinstance(response, InitSessionResponse)
|
||||
assert response.status == 'ok'
|
||||
|
||||
# Verify that _create_new_conversation was called with None repository
|
||||
mock_create_conversation.assert_called_once()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user