From cdd05a98dba5ec5f29ff508f591ecc4e0dc11abd Mon Sep 17 00:00:00 2001 From: tofarr Date: Tue, 8 Oct 2024 07:17:37 -0600 Subject: [PATCH] Lockup Resiliency and Asyncio Improvements (#4221) --- Makefile | 2 +- openhands/events/stream.py | 15 ++++++++- openhands/llm/async_llm.py | 2 +- openhands/runtime/remote/runtime.py | 6 +++- openhands/runtime/runtime.py | 11 +++++-- openhands/security/invariant/analyzer.py | 7 ++-- openhands/server/listen.py | 9 ++++-- openhands/server/session/agent_session.py | 39 ++++++++--------------- openhands/server/session/session.py | 7 ++-- tests/unit/test_security.py | 13 ++++---- 10 files changed, 62 insertions(+), 49 deletions(-) diff --git a/Makefile b/Makefile index d87ea72296..6c89b04586 100644 --- a/Makefile +++ b/Makefile @@ -195,7 +195,7 @@ start-backend: # Start frontend start-frontend: @echo "$(YELLOW)Starting frontend...$(RESET)" - @cd frontend && VITE_BACKEND_HOST=$(BACKEND_HOST_PORT) VITE_FRONTEND_PORT=$(FRONTEND_PORT) npm run start + @cd frontend && VITE_BACKEND_HOST=$(BACKEND_HOST_PORT) VITE_FRONTEND_PORT=$(FRONTEND_PORT) npm run start -- --port $(FRONTEND_PORT) # Common setup for running the app (non-callable) _run_setup: diff --git a/openhands/events/stream.py b/openhands/events/stream.py index 59f7febb87..b667202278 100644 --- a/openhands/events/stream.py +++ b/openhands/events/stream.py @@ -129,6 +129,13 @@ class EventStream: del self._subscribers[id] def add_event(self, event: Event, source: EventSource): + try: + asyncio.get_running_loop().create_task(self.async_add_event(event, source)) + except RuntimeError: + # No event loop running... + asyncio.run(self.async_add_event(event, source)) + + async def async_add_event(self, event: Event, source: EventSource): with self._lock: event._id = self._cur_id # type: ignore [attr-defined] self._cur_id += 1 @@ -138,10 +145,16 @@ class EventStream: data = event_to_dict(event) if event.id is not None: self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data)) + tasks = [] for key in sorted(self._subscribers.keys()): stack = self._subscribers[key] callback = stack[-1] - asyncio.create_task(callback(event)) + tasks.append(asyncio.create_task(callback(event))) + if tasks: + await asyncio.wait(tasks) + + def _callback(self, callback: Callable, event: Event): + asyncio.run(callback(event)) def filtered_events_by_source(self, source: EventSource): for event in self.get_events(): diff --git a/openhands/llm/async_llm.py b/openhands/llm/async_llm.py index a467e97b40..fec3de70c2 100644 --- a/openhands/llm/async_llm.py +++ b/openhands/llm/async_llm.py @@ -73,7 +73,7 @@ class AsyncLLM(LLM): and self.config.on_cancel_requested_fn is not None and await self.config.on_cancel_requested_fn() ): - raise UserCancelledError('LLM request cancelled by user') + return await asyncio.sleep(0.1) stop_check_task = asyncio.create_task(check_stopped()) diff --git a/openhands/runtime/remote/runtime.py b/openhands/runtime/remote/runtime.py index a4551e4885..1104270b39 100644 --- a/openhands/runtime/remote/runtime.py +++ b/openhands/runtime/remote/runtime.py @@ -200,6 +200,9 @@ class RemoteRuntime(Runtime): assert ( self.runtime_url is not None ), 'Runtime URL is not set. This should never happen.' + + self._wait_until_alive() + self.send_status_message(' ') self._wait_until_alive() @@ -229,7 +232,7 @@ class RemoteRuntime(Runtime): logger.warning(msg) raise RuntimeError(msg) - def close(self): + def close(self, timeout: int = 10): if self.runtime_id: try: response = send_request( @@ -237,6 +240,7 @@ class RemoteRuntime(Runtime): 'POST', f'{self.config.sandbox.remote_runtime_api_url}/stop', json={'runtime_id': self.runtime_id}, + timeout=timeout, ) if response.status_code != 200: logger.error(f'Failed to stop sandbox: {response.text}') diff --git a/openhands/runtime/runtime.py b/openhands/runtime/runtime.py index 8b293a3da0..efa7373ee5 100644 --- a/openhands/runtime/runtime.py +++ b/openhands/runtime/runtime.py @@ -1,3 +1,4 @@ +import asyncio import atexit import copy import json @@ -117,10 +118,10 @@ class Runtime: if event.timeout is None: event.timeout = self.config.sandbox.timeout assert event.timeout is not None - observation = self.run_action(event) + observation = await self.async_run_action(event) observation._cause = event.id # type: ignore[attr-defined] source = event.source if event.source else EventSource.AGENT - self.event_stream.add_event(observation, source) # type: ignore[arg-type] + await self.event_stream.async_add_event(observation, source) # type: ignore[arg-type] def run_action(self, action: Action) -> Observation: """Run an action and return the resulting observation. @@ -151,6 +152,12 @@ class Runtime: observation = getattr(self, action_type)(action) return observation + async def async_run_action(self, action: Action) -> Observation: + observation = await asyncio.get_event_loop().run_in_executor( + None, self.run_action, action + ) + return observation + # ==================================================================== # Context manager # ==================================================================== diff --git a/openhands/security/invariant/analyzer.py b/openhands/security/invariant/analyzer.py index 0d92f1b327..ed32325d7c 100644 --- a/openhands/security/invariant/analyzer.py +++ b/openhands/security/invariant/analyzer.py @@ -1,3 +1,4 @@ +import asyncio import re import uuid from typing import Any @@ -144,10 +145,8 @@ class InvariantAnalyzer(SecurityAnalyzer): new_event = action_from_dict( {'action': 'change_agent_state', 'args': {'agent_state': 'user_confirmed'}} ) - if event.source: - self.event_stream.add_event(new_event, event.source) - else: - self.event_stream.add_event(new_event, EventSource.AGENT) + event_source = event.source if event.source else EventSource.AGENT + await asyncio.get_event_loop().run_in_executor(None, self.event_stream.add_event, new_event, event_source) async def security_risk(self, event: Action) -> ActionSecurityRisk: logger.info('Calling security_risk on InvariantAnalyzer') diff --git a/openhands/server/listen.py b/openhands/server/listen.py index 2b1144128a..e574d94326 100644 --- a/openhands/server/listen.py +++ b/openhands/server/listen.py @@ -430,7 +430,9 @@ async def list_files(request: Request, path: str | None = None): content={'error': 'Runtime not yet initialized'}, ) runtime: Runtime = request.state.session.agent_session.runtime - file_list = runtime.list_files(path) + file_list = await asyncio.get_event_loop().run_in_executor( + None, runtime.list_files, path + ) if path: file_list = [os.path.join(path, f) for f in file_list] @@ -451,6 +453,7 @@ async def list_files(request: Request, path: str | None = None): return file_list file_list = filter_for_gitignore(file_list, '') + return file_list @@ -478,7 +481,7 @@ async def select_file(file: str, request: Request): file = os.path.join(runtime.config.workspace_mount_path_in_sandbox, file) read_action = FileReadAction(file) - observation = runtime.run_action(read_action) + observation = await runtime.async_run_action(read_action) if isinstance(observation, FileReadObservation): content = observation.content @@ -720,7 +723,7 @@ async def save_file(request: Request): runtime.config.workspace_mount_path_in_sandbox, file_path ) write_action = FileWriteAction(file_path, content) - observation = runtime.run_action(write_action) + observation = await runtime.async_run_action(write_action) if isinstance(observation, FileWriteObservation): return JSONResponse( diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py index d2db2a6c53..6eb2faa854 100644 --- a/openhands/server/session/agent_session.py +++ b/openhands/server/session/agent_session.py @@ -1,6 +1,4 @@ import asyncio -import concurrent.futures -from threading import Thread from typing import Callable, Optional from openhands.controller import AgentController @@ -32,7 +30,7 @@ class AgentSession: runtime: Runtime | None = None security_analyzer: SecurityAnalyzer | None = None _closed: bool = False - loop: asyncio.AbstractEventLoop + loop: asyncio.AbstractEventLoop | None = None def __init__(self, sid: str, file_store: FileStore): """Initializes a new instance of the Session class @@ -45,7 +43,6 @@ class AgentSession: self.sid = sid self.event_stream = EventStream(sid, file_store) self.file_store = file_store - self.loop = asyncio.new_event_loop() async def start( self, @@ -73,17 +70,9 @@ class AgentSession: 'Session already started. You need to close this session and start a new one.' ) - self.thread = Thread(target=self._run, daemon=True) - self.thread.start() - - def coro_callback(task): - fut: concurrent.futures.Future = concurrent.futures.Future() - try: - fut.set_result(task.result()) - except Exception as e: - logger.error(f'Error starting session: {e}') - - coro = self._start( + asyncio.get_event_loop().run_in_executor( + None, + self._start_thread, runtime_name, config, agent, @@ -93,9 +82,12 @@ class AgentSession: agent_configs, status_message_callback, ) - asyncio.run_coroutine_threadsafe(coro, self.loop).add_done_callback( - coro_callback - ) # type: ignore + + def _start_thread(self, *args): + try: + asyncio.run(self._start(*args), debug=True) + except RuntimeError: + logger.info('Session Finished') async def _start( self, @@ -108,6 +100,7 @@ class AgentSession: agent_configs: dict[str, AgentConfig] | None = None, status_message_callback: Optional[Callable] = None, ): + self.loop = asyncio.get_running_loop() self._create_security_analyzer(config.security.security_analyzer) self._create_runtime(runtime_name, config, agent, status_message_callback) self._create_controller( @@ -125,10 +118,6 @@ class AgentSession: self.controller.agent_task = self.controller.start_step_loop() await self.controller.agent_task # type: ignore - def _run(self): - asyncio.set_event_loop(self.loop) - self.loop.run_forever() - async def close(self): """Closes the Agent session""" @@ -143,10 +132,8 @@ class AgentSession: if self.security_analyzer is not None: await self.security_analyzer.close() - self.loop.call_soon_threadsafe(self.loop.stop) - if self.thread: - # We may be closing an agent_session that was never actually started - self.thread.join() + if self.loop: + self.loop.call_soon_threadsafe(self.loop.stop) self._closed = True diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py index f8cc2b581e..94606d085c 100644 --- a/openhands/server/session/session.py +++ b/openhands/server/session/session.py @@ -162,9 +162,10 @@ class Session: 'Model does not support image upload, change to a different model or try without an image.' ) return - asyncio.run_coroutine_threadsafe( - self._add_event(event, EventSource.USER), self.agent_session.loop - ) # type: ignore + if self.agent_session.loop: + asyncio.run_coroutine_threadsafe( + self._add_event(event, EventSource.USER), self.agent_session.loop + ) # type: ignore async def _add_event(self, event, event_source): self.agent_session.event_stream.add_event(event, EventSource.USER) diff --git a/tests/unit/test_security.py b/tests/unit/test_security.py index f4c0503f58..a56e116dd1 100644 --- a/tests/unit/test_security.py +++ b/tests/unit/test_security.py @@ -1,4 +1,3 @@ -import asyncio import pathlib import tempfile @@ -42,7 +41,7 @@ def temp_dir(monkeypatch): yield temp_dir -async def add_events(event_stream: EventStream, data: list[tuple[Event, EventSource]]): +def add_events(event_stream: EventStream, data: list[tuple[Event, EventSource]]): for event, source in data: event_stream.add_event(event, source) @@ -62,7 +61,7 @@ def test_msg(temp_dir: str): (MessageAction('Hello world!'), EventSource.USER), (MessageAction('ABC!'), EventSource.AGENT), ] - asyncio.run(add_events(event_stream, data)) + add_events(event_stream, data) for i in range(3): assert data[i][0].security_risk == ActionSecurityRisk.LOW assert data[3][0].security_risk == ActionSecurityRisk.MEDIUM @@ -86,7 +85,7 @@ def test_cmd(cmd, expected_risk, temp_dir: str): (MessageAction('Hello world!'), EventSource.USER), (CmdRunAction(cmd), EventSource.USER), ] - asyncio.run(add_events(event_stream, data)) + add_events(event_stream, data) assert data[0][0].security_risk == ActionSecurityRisk.LOW assert data[1][0].security_risk == expected_risk @@ -115,7 +114,7 @@ def test_leak_secrets(code, expected_risk, temp_dir: str): (IPythonRunCellAction(code), EventSource.AGENT), (IPythonRunCellAction('hello'), EventSource.AGENT), ] - asyncio.run(add_events(event_stream, data)) + add_events(event_stream, data) assert data[0][0].security_risk == ActionSecurityRisk.LOW assert data[1][0].security_risk == expected_risk assert data[2][0].security_risk == ActionSecurityRisk.LOW @@ -133,7 +132,7 @@ def test_unsafe_python_code(temp_dir: str): (MessageAction('Hello world!'), EventSource.USER), (IPythonRunCellAction(code), EventSource.AGENT), ] - asyncio.run(add_events(event_stream, data)) + add_events(event_stream, data) assert data[0][0].security_risk == ActionSecurityRisk.LOW # TODO: this failed but idk why and seems not deterministic to me # assert data[1][0].security_risk == ActionSecurityRisk.MEDIUM @@ -148,7 +147,7 @@ def test_unsafe_bash_command(temp_dir: str): (MessageAction('Hello world!'), EventSource.USER), (CmdRunAction(code), EventSource.AGENT), ] - asyncio.run(add_events(event_stream, data)) + add_events(event_stream, data) assert data[0][0].security_risk == ActionSecurityRisk.LOW assert data[1][0].security_risk == ActionSecurityRisk.MEDIUM