Lockup Resiliency and Asyncio Improvements (#4221)

This commit is contained in:
tofarr 2024-10-08 07:17:37 -06:00 committed by GitHub
parent 568c8ce993
commit cdd05a98db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 62 additions and 49 deletions

View File

@ -195,7 +195,7 @@ start-backend:
# Start frontend
start-frontend:
@echo "$(YELLOW)Starting frontend...$(RESET)"
@cd frontend && VITE_BACKEND_HOST=$(BACKEND_HOST_PORT) VITE_FRONTEND_PORT=$(FRONTEND_PORT) npm run start
@cd frontend && VITE_BACKEND_HOST=$(BACKEND_HOST_PORT) VITE_FRONTEND_PORT=$(FRONTEND_PORT) npm run start -- --port $(FRONTEND_PORT)
# Common setup for running the app (non-callable)
_run_setup:

View File

@ -129,6 +129,13 @@ class EventStream:
del self._subscribers[id]
def add_event(self, event: Event, source: EventSource):
    """Add an event to the stream from synchronous code.

    If an event loop is already running in the calling thread, the work is
    scheduled on that loop as a task and this method returns immediately,
    before the event has been processed. Otherwise the event is processed to
    completion on a temporary loop via ``asyncio.run``.

    Args:
        event: The event to add.
        source: The origin to record for the event.
    """
    # Keep the try body to the one call that signals "no loop in this thread",
    # so a RuntimeError raised while scheduling is not mistaken for that case.
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No event loop running...
        asyncio.run(self.async_add_event(event, source))
        return
    task = loop.create_task(self.async_add_event(event, source))
    # The event loop only holds weak references to tasks; keep a strong one
    # until the task finishes so it cannot be garbage-collected mid-flight.
    pending = getattr(self, '_pending_add_tasks', None)
    if pending is None:
        pending = self._pending_add_tasks = set()
    pending.add(task)
    task.add_done_callback(pending.discard)
async def async_add_event(self, event: Event, source: EventSource):
with self._lock:
event._id = self._cur_id # type: ignore [attr-defined]
self._cur_id += 1
@ -138,10 +145,16 @@ class EventStream:
data = event_to_dict(event)
if event.id is not None:
self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data))
tasks = []
for key in sorted(self._subscribers.keys()):
stack = self._subscribers[key]
callback = stack[-1]
asyncio.create_task(callback(event))
tasks.append(asyncio.create_task(callback(event)))
if tasks:
await asyncio.wait(tasks)
def _callback(self, callback: Callable, event: Event):
    """Call ``callback`` with ``event`` and drive the result with ``asyncio.run``.

    Args:
        callback: Callable invoked with the event; its return value is handed
            to ``asyncio.run``, so it is expected to produce a coroutine.
        event: The event passed to the callback.
    """
    awaitable = callback(event)
    asyncio.run(awaitable)
def filtered_events_by_source(self, source: EventSource):
for event in self.get_events():

View File

@ -73,7 +73,7 @@ class AsyncLLM(LLM):
and self.config.on_cancel_requested_fn is not None
and await self.config.on_cancel_requested_fn()
):
raise UserCancelledError('LLM request cancelled by user')
return
await asyncio.sleep(0.1)
stop_check_task = asyncio.create_task(check_stopped())

View File

@ -200,6 +200,9 @@ class RemoteRuntime(Runtime):
assert (
self.runtime_url is not None
), 'Runtime URL is not set. This should never happen.'
self._wait_until_alive()
self.send_status_message(' ')
self._wait_until_alive()
@ -229,7 +232,7 @@ class RemoteRuntime(Runtime):
logger.warning(msg)
raise RuntimeError(msg)
def close(self):
def close(self, timeout: int = 10):
if self.runtime_id:
try:
response = send_request(
@ -237,6 +240,7 @@ class RemoteRuntime(Runtime):
'POST',
f'{self.config.sandbox.remote_runtime_api_url}/stop',
json={'runtime_id': self.runtime_id},
timeout=timeout,
)
if response.status_code != 200:
logger.error(f'Failed to stop sandbox: {response.text}')

View File

@ -1,3 +1,4 @@
import asyncio
import atexit
import copy
import json
@ -117,10 +118,10 @@ class Runtime:
if event.timeout is None:
event.timeout = self.config.sandbox.timeout
assert event.timeout is not None
observation = self.run_action(event)
observation = await self.async_run_action(event)
observation._cause = event.id # type: ignore[attr-defined]
source = event.source if event.source else EventSource.AGENT
self.event_stream.add_event(observation, source) # type: ignore[arg-type]
await self.event_stream.async_add_event(observation, source) # type: ignore[arg-type]
def run_action(self, action: Action) -> Observation:
"""Run an action and return the resulting observation.
@ -151,6 +152,12 @@ class Runtime:
observation = getattr(self, action_type)(action)
return observation
async def async_run_action(self, action: Action) -> Observation:
    """Run ``run_action`` in the loop's default executor and await its result.

    Offloading the synchronous ``run_action`` call keeps the event loop free
    to service other tasks while the action executes.

    Args:
        action: The action to pass to ``run_action``.

    Returns:
        Whatever ``run_action`` returns for the action.
    """
    # Inside a coroutine a loop is always running; get_running_loop() is the
    # documented accessor and never creates a loop implicitly.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.run_action, action)
# ====================================================================
# Context manager
# ====================================================================

View File

@ -1,3 +1,4 @@
import asyncio
import re
import uuid
from typing import Any
@ -144,10 +145,8 @@ class InvariantAnalyzer(SecurityAnalyzer):
new_event = action_from_dict(
{'action': 'change_agent_state', 'args': {'agent_state': 'user_confirmed'}}
)
if event.source:
self.event_stream.add_event(new_event, event.source)
else:
self.event_stream.add_event(new_event, EventSource.AGENT)
event_source = event.source if event.source else EventSource.AGENT
await asyncio.get_event_loop().run_in_executor(None, self.event_stream.add_event, new_event, event_source)
async def security_risk(self, event: Action) -> ActionSecurityRisk:
logger.info('Calling security_risk on InvariantAnalyzer')

View File

@ -430,7 +430,9 @@ async def list_files(request: Request, path: str | None = None):
content={'error': 'Runtime not yet initialized'},
)
runtime: Runtime = request.state.session.agent_session.runtime
file_list = runtime.list_files(path)
file_list = await asyncio.get_event_loop().run_in_executor(
None, runtime.list_files, path
)
if path:
file_list = [os.path.join(path, f) for f in file_list]
@ -451,6 +453,7 @@ async def list_files(request: Request, path: str | None = None):
return file_list
file_list = filter_for_gitignore(file_list, '')
return file_list
@ -478,7 +481,7 @@ async def select_file(file: str, request: Request):
file = os.path.join(runtime.config.workspace_mount_path_in_sandbox, file)
read_action = FileReadAction(file)
observation = runtime.run_action(read_action)
observation = await runtime.async_run_action(read_action)
if isinstance(observation, FileReadObservation):
content = observation.content
@ -720,7 +723,7 @@ async def save_file(request: Request):
runtime.config.workspace_mount_path_in_sandbox, file_path
)
write_action = FileWriteAction(file_path, content)
observation = runtime.run_action(write_action)
observation = await runtime.async_run_action(write_action)
if isinstance(observation, FileWriteObservation):
return JSONResponse(

View File

@ -1,6 +1,4 @@
import asyncio
import concurrent.futures
from threading import Thread
from typing import Callable, Optional
from openhands.controller import AgentController
@ -32,7 +30,7 @@ class AgentSession:
runtime: Runtime | None = None
security_analyzer: SecurityAnalyzer | None = None
_closed: bool = False
loop: asyncio.AbstractEventLoop
loop: asyncio.AbstractEventLoop | None = None
def __init__(self, sid: str, file_store: FileStore):
"""Initializes a new instance of the Session class
@ -45,7 +43,6 @@ class AgentSession:
self.sid = sid
self.event_stream = EventStream(sid, file_store)
self.file_store = file_store
self.loop = asyncio.new_event_loop()
async def start(
self,
@ -73,17 +70,9 @@ class AgentSession:
'Session already started. You need to close this session and start a new one.'
)
self.thread = Thread(target=self._run, daemon=True)
self.thread.start()
def coro_callback(task):
fut: concurrent.futures.Future = concurrent.futures.Future()
try:
fut.set_result(task.result())
except Exception as e:
logger.error(f'Error starting session: {e}')
coro = self._start(
asyncio.get_event_loop().run_in_executor(
None,
self._start_thread,
runtime_name,
config,
agent,
@ -93,9 +82,12 @@ class AgentSession:
agent_configs,
status_message_callback,
)
asyncio.run_coroutine_threadsafe(coro, self.loop).add_done_callback(
coro_callback
) # type: ignore
def _start_thread(self, *args):
    """Run ``_start`` to completion on a fresh event loop in the calling thread.

    All positional arguments are forwarded unchanged to ``_start``. The loop is
    run in asyncio debug mode. A ``RuntimeError`` escaping the run is logged at
    info level as the session finishing, and is not re-raised.
    """
    try:
        session_coro = self._start(*args)
        asyncio.run(session_coro, debug=True)
    except RuntimeError:
        logger.info('Session Finished')
async def _start(
self,
@ -108,6 +100,7 @@ class AgentSession:
agent_configs: dict[str, AgentConfig] | None = None,
status_message_callback: Optional[Callable] = None,
):
self.loop = asyncio.get_running_loop()
self._create_security_analyzer(config.security.security_analyzer)
self._create_runtime(runtime_name, config, agent, status_message_callback)
self._create_controller(
@ -125,10 +118,6 @@ class AgentSession:
self.controller.agent_task = self.controller.start_step_loop()
await self.controller.agent_task # type: ignore
def _run(self):
    """Install ``self.loop`` as the current thread's event loop and run it until stopped."""
    session_loop = self.loop
    asyncio.set_event_loop(session_loop)
    session_loop.run_forever()
async def close(self):
"""Closes the Agent session"""
@ -143,10 +132,8 @@ class AgentSession:
if self.security_analyzer is not None:
await self.security_analyzer.close()
self.loop.call_soon_threadsafe(self.loop.stop)
if self.thread:
# We may be closing an agent_session that was never actually started
self.thread.join()
if self.loop:
self.loop.call_soon_threadsafe(self.loop.stop)
self._closed = True

View File

@ -162,9 +162,10 @@ class Session:
'Model does not support image upload, change to a different model or try without an image.'
)
return
asyncio.run_coroutine_threadsafe(
self._add_event(event, EventSource.USER), self.agent_session.loop
) # type: ignore
if self.agent_session.loop:
asyncio.run_coroutine_threadsafe(
self._add_event(event, EventSource.USER), self.agent_session.loop
) # type: ignore
async def _add_event(self, event, event_source):
    """Add ``event`` to the agent session's event stream.

    Args:
        event: The event to add.
        event_source: The source to attribute the event to.
    """
    # Honour the caller-supplied source rather than hard-coding
    # EventSource.USER, which silently ignored the parameter.
    self.agent_session.event_stream.add_event(event, event_source)

View File

@ -1,4 +1,3 @@
import asyncio
import pathlib
import tempfile
@ -42,7 +41,7 @@ def temp_dir(monkeypatch):
yield temp_dir
async def add_events(event_stream: EventStream, data: list[tuple[Event, EventSource]]):
def add_events(event_stream: EventStream, data: list[tuple[Event, EventSource]]):
    """Feed every ``(event, source)`` pair in ``data`` to the stream, in order."""
    for pair in data:
        evt, origin = pair
        event_stream.add_event(evt, origin)
@ -62,7 +61,7 @@ def test_msg(temp_dir: str):
(MessageAction('Hello world!'), EventSource.USER),
(MessageAction('ABC!'), EventSource.AGENT),
]
asyncio.run(add_events(event_stream, data))
add_events(event_stream, data)
for i in range(3):
assert data[i][0].security_risk == ActionSecurityRisk.LOW
assert data[3][0].security_risk == ActionSecurityRisk.MEDIUM
@ -86,7 +85,7 @@ def test_cmd(cmd, expected_risk, temp_dir: str):
(MessageAction('Hello world!'), EventSource.USER),
(CmdRunAction(cmd), EventSource.USER),
]
asyncio.run(add_events(event_stream, data))
add_events(event_stream, data)
assert data[0][0].security_risk == ActionSecurityRisk.LOW
assert data[1][0].security_risk == expected_risk
@ -115,7 +114,7 @@ def test_leak_secrets(code, expected_risk, temp_dir: str):
(IPythonRunCellAction(code), EventSource.AGENT),
(IPythonRunCellAction('hello'), EventSource.AGENT),
]
asyncio.run(add_events(event_stream, data))
add_events(event_stream, data)
assert data[0][0].security_risk == ActionSecurityRisk.LOW
assert data[1][0].security_risk == expected_risk
assert data[2][0].security_risk == ActionSecurityRisk.LOW
@ -133,7 +132,7 @@ def test_unsafe_python_code(temp_dir: str):
(MessageAction('Hello world!'), EventSource.USER),
(IPythonRunCellAction(code), EventSource.AGENT),
]
asyncio.run(add_events(event_stream, data))
add_events(event_stream, data)
assert data[0][0].security_risk == ActionSecurityRisk.LOW
# TODO: this assertion has failed intermittently; the cause is unknown and the result appears to be non-deterministic
# assert data[1][0].security_risk == ActionSecurityRisk.MEDIUM
@ -148,7 +147,7 @@ def test_unsafe_bash_command(temp_dir: str):
(MessageAction('Hello world!'), EventSource.USER),
(CmdRunAction(code), EventSource.AGENT),
]
asyncio.run(add_events(event_stream, data))
add_events(event_stream, data)
assert data[0][0].security_risk == ActionSecurityRisk.LOW
assert data[1][0].security_risk == ActionSecurityRisk.MEDIUM