Fix server lock up on session init (#4007)

This commit is contained in:
tofarr 2024-09-24 15:49:30 -06:00 committed by GitHub
parent 1b1d8f0b02
commit ee284bae8f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 33 additions and 13 deletions

View File

@ -54,7 +54,7 @@ class AgentController:
confirmation_mode: bool
agent_to_llm_config: dict[str, LLMConfig]
agent_configs: dict[str, AgentConfig]
agent_task: asyncio.Task | None = None
agent_task: asyncio.Future | None = None
parent: 'AgentController | None' = None
delegate: 'AgentController | None' = None
_pending_action: Action | None = None
@ -115,9 +115,6 @@ class AgentController:
# stuck helper
self._stuck_detector = StuckDetector(self.state)
if not is_delegate:
self.agent_task = asyncio.create_task(self._start_step_loop())
async def close(self):
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream."""
if self.agent_task is not None:
@ -149,7 +146,7 @@ class AgentController:
self.state.last_error += f': {exception}'
self.event_stream.add_event(ErrorObservation(message), EventSource.AGENT)
async def _start_step_loop(self):
async def start_step_loop(self):
"""The main loop for the agent's step-by-step execution."""
logger.info(f'[Agent Controller {self.id}] Starting step loop...')

View File

@ -121,6 +121,9 @@ async def main():
event_stream=event_stream,
)
if controller is not None:
controller.agent_task = asyncio.create_task(controller.start_step_loop())
async def prompt_for_next_task():
next_message = input('How can I help? >> ')
if next_message == 'exit':

View File

@ -143,6 +143,9 @@ async def run_controller(
headless_mode=headless_mode,
)
if controller is not None:
controller.agent_task = asyncio.create_task(controller.start_step_loop())
assert isinstance(task_str, str), f'task_str must be a string, got {type(task_str)}'
# Logging
logger.info(

View File

@ -1,4 +1,6 @@
import asyncio
from threading import Thread
from typing import Callable, Optional
from openhands.controller import AgentController
@ -65,9 +67,14 @@ class AgentSession:
raise RuntimeError(
'Session already started. You need to close this session and start a new one.'
)
await self._create_security_analyzer(config.security.security_analyzer)
await self._create_runtime(runtime_name, config, agent, status_message_callback)
await self._create_controller(
self.loop = asyncio.new_event_loop()
self.thread = Thread(target=self._run, daemon=True)
self.thread.start()
self._create_security_analyzer(config.security.security_analyzer)
self._create_runtime(runtime_name, config, agent, status_message_callback)
self._create_controller(
agent,
config.security.confirmation_mode,
max_iterations,
@ -75,6 +82,13 @@ class AgentSession:
agent_to_llm_config=agent_to_llm_config,
agent_configs=agent_configs,
)
if self.controller is not None:
self.controller.agent_task = asyncio.run_coroutine_threadsafe(self.controller.start_step_loop(), self.loop) # type: ignore
def _run(self):
asyncio.set_event_loop(self.loop)
self.loop.run_forever()
async def close(self):
"""Closes the Agent session"""
@ -89,9 +103,13 @@ class AgentSession:
self.runtime.close()
if self.security_analyzer is not None:
await self.security_analyzer.close()
self.loop.call_soon_threadsafe(self.loop.stop)
self.thread.join()
self._closed = True
async def _create_security_analyzer(self, security_analyzer: str | None):
def _create_security_analyzer(self, security_analyzer: str | None):
"""Creates a SecurityAnalyzer instance that will be used to analyze the agent actions
Parameters:
@ -104,7 +122,7 @@ class AgentSession:
security_analyzer, SecurityAnalyzer
)(self.event_stream)
async def _create_runtime(
def _create_runtime(
self,
runtime_name: str,
config: AppConfig,
@ -125,8 +143,7 @@ class AgentSession:
logger.info(f'Initializing runtime `{runtime_name}` now...')
runtime_cls = get_runtime_cls(runtime_name)
self.runtime = await asyncio.to_thread(
runtime_cls,
self.runtime = runtime_cls(
config=config,
event_stream=self.event_stream,
sid=self.sid,
@ -141,7 +158,7 @@ class AgentSession:
else:
logger.warning('Runtime initialization failed')
async def _create_controller(
def _create_controller(
self,
agent: Agent,
confirmation_mode: bool,