mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Lockup Resiliency and Asyncio Improvements (#4221)
This commit is contained in:
parent
568c8ce993
commit
cdd05a98db
2
Makefile
2
Makefile
@ -195,7 +195,7 @@ start-backend:
|
||||
# Start frontend
|
||||
start-frontend:
|
||||
@echo "$(YELLOW)Starting frontend...$(RESET)"
|
||||
@cd frontend && VITE_BACKEND_HOST=$(BACKEND_HOST_PORT) VITE_FRONTEND_PORT=$(FRONTEND_PORT) npm run start
|
||||
@cd frontend && VITE_BACKEND_HOST=$(BACKEND_HOST_PORT) VITE_FRONTEND_PORT=$(FRONTEND_PORT) npm run start -- --port $(FRONTEND_PORT)
|
||||
|
||||
# Common setup for running the app (non-callable)
|
||||
_run_setup:
|
||||
|
||||
@ -129,6 +129,13 @@ class EventStream:
|
||||
del self._subscribers[id]
|
||||
|
||||
def add_event(self, event: Event, source: EventSource):
    """Add ``event`` to the stream from synchronous code.

    If an event loop is already running in the calling thread, the work is
    scheduled on that loop as a background task and this method returns
    immediately. Otherwise ``async_add_event`` is driven to completion on a
    temporary event loop before this method returns.

    Args:
        event: The event to add.
        source: The source the event is attributed to.
    """
    try:
        # Only this lookup signals "no running loop"; keep the try minimal so
        # unrelated RuntimeErrors are not mistaken for that condition.
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No event loop running...
        asyncio.run(self.async_add_event(event, source))
        return

    task = loop.create_task(self.async_add_event(event, source))
    # The event loop only keeps weak references to tasks, so hold a strong
    # reference until the task finishes; otherwise it may be garbage-collected
    # before the event has been fully processed.
    pending = getattr(self, '_pending_add_event_tasks', None)
    if pending is None:
        pending = self._pending_add_event_tasks = set()
    pending.add(task)
    task.add_done_callback(pending.discard)
|
||||
|
||||
async def async_add_event(self, event: Event, source: EventSource):
|
||||
with self._lock:
|
||||
event._id = self._cur_id # type: ignore [attr-defined]
|
||||
self._cur_id += 1
|
||||
@ -138,10 +145,16 @@ class EventStream:
|
||||
data = event_to_dict(event)
|
||||
if event.id is not None:
|
||||
self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data))
|
||||
tasks = []
|
||||
for key in sorted(self._subscribers.keys()):
|
||||
stack = self._subscribers[key]
|
||||
callback = stack[-1]
|
||||
asyncio.create_task(callback(event))
|
||||
tasks.append(asyncio.create_task(callback(event)))
|
||||
if tasks:
|
||||
await asyncio.wait(tasks)
|
||||
|
||||
def _callback(self, callback: Callable, event: Event):
|
||||
asyncio.run(callback(event))
|
||||
|
||||
def filtered_events_by_source(self, source: EventSource):
|
||||
for event in self.get_events():
|
||||
|
||||
@ -73,7 +73,7 @@ class AsyncLLM(LLM):
|
||||
and self.config.on_cancel_requested_fn is not None
|
||||
and await self.config.on_cancel_requested_fn()
|
||||
):
|
||||
raise UserCancelledError('LLM request cancelled by user')
|
||||
return
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
stop_check_task = asyncio.create_task(check_stopped())
|
||||
|
||||
@ -200,6 +200,9 @@ class RemoteRuntime(Runtime):
|
||||
assert (
|
||||
self.runtime_url is not None
|
||||
), 'Runtime URL is not set. This should never happen.'
|
||||
|
||||
self._wait_until_alive()
|
||||
|
||||
self.send_status_message(' ')
|
||||
|
||||
self._wait_until_alive()
|
||||
@ -229,7 +232,7 @@ class RemoteRuntime(Runtime):
|
||||
logger.warning(msg)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
def close(self):
|
||||
def close(self, timeout: int = 10):
|
||||
if self.runtime_id:
|
||||
try:
|
||||
response = send_request(
|
||||
@ -237,6 +240,7 @@ class RemoteRuntime(Runtime):
|
||||
'POST',
|
||||
f'{self.config.sandbox.remote_runtime_api_url}/stop',
|
||||
json={'runtime_id': self.runtime_id},
|
||||
timeout=timeout,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
logger.error(f'Failed to stop sandbox: {response.text}')
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
import copy
|
||||
import json
|
||||
@ -117,10 +118,10 @@ class Runtime:
|
||||
if event.timeout is None:
|
||||
event.timeout = self.config.sandbox.timeout
|
||||
assert event.timeout is not None
|
||||
observation = self.run_action(event)
|
||||
observation = await self.async_run_action(event)
|
||||
observation._cause = event.id # type: ignore[attr-defined]
|
||||
source = event.source if event.source else EventSource.AGENT
|
||||
self.event_stream.add_event(observation, source) # type: ignore[arg-type]
|
||||
await self.event_stream.async_add_event(observation, source) # type: ignore[arg-type]
|
||||
|
||||
def run_action(self, action: Action) -> Observation:
|
||||
"""Run an action and return the resulting observation.
|
||||
@ -151,6 +152,12 @@ class Runtime:
|
||||
observation = getattr(self, action_type)(action)
|
||||
return observation
|
||||
|
||||
async def async_run_action(self, action: Action) -> Observation:
    """Run ``action`` through the blocking ``run_action`` without blocking the event loop.

    The synchronous ``run_action`` call is handed to the running loop's
    default executor and its result is awaited.

    Args:
        action: The action to execute.

    Returns:
        Whatever ``run_action`` returns for ``action``.
    """
    # Inside a coroutine a loop is always running; get_running_loop() is the
    # documented preferred accessor and never creates a loop implicitly.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.run_action, action)
|
||||
|
||||
# ====================================================================
|
||||
# Context manager
|
||||
# ====================================================================
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
@ -144,10 +145,8 @@ class InvariantAnalyzer(SecurityAnalyzer):
|
||||
new_event = action_from_dict(
|
||||
{'action': 'change_agent_state', 'args': {'agent_state': 'user_confirmed'}}
|
||||
)
|
||||
if event.source:
|
||||
self.event_stream.add_event(new_event, event.source)
|
||||
else:
|
||||
self.event_stream.add_event(new_event, EventSource.AGENT)
|
||||
event_source = event.source if event.source else EventSource.AGENT
|
||||
await asyncio.get_event_loop().run_in_executor(None, self.event_stream.add_event, new_event, event_source)
|
||||
|
||||
async def security_risk(self, event: Action) -> ActionSecurityRisk:
|
||||
logger.info('Calling security_risk on InvariantAnalyzer')
|
||||
|
||||
@ -430,7 +430,9 @@ async def list_files(request: Request, path: str | None = None):
|
||||
content={'error': 'Runtime not yet initialized'},
|
||||
)
|
||||
runtime: Runtime = request.state.session.agent_session.runtime
|
||||
file_list = runtime.list_files(path)
|
||||
file_list = await asyncio.get_event_loop().run_in_executor(
|
||||
None, runtime.list_files, path
|
||||
)
|
||||
if path:
|
||||
file_list = [os.path.join(path, f) for f in file_list]
|
||||
|
||||
@ -451,6 +453,7 @@ async def list_files(request: Request, path: str | None = None):
|
||||
return file_list
|
||||
|
||||
file_list = filter_for_gitignore(file_list, '')
|
||||
|
||||
return file_list
|
||||
|
||||
|
||||
@ -478,7 +481,7 @@ async def select_file(file: str, request: Request):
|
||||
|
||||
file = os.path.join(runtime.config.workspace_mount_path_in_sandbox, file)
|
||||
read_action = FileReadAction(file)
|
||||
observation = runtime.run_action(read_action)
|
||||
observation = await runtime.async_run_action(read_action)
|
||||
|
||||
if isinstance(observation, FileReadObservation):
|
||||
content = observation.content
|
||||
@ -720,7 +723,7 @@ async def save_file(request: Request):
|
||||
runtime.config.workspace_mount_path_in_sandbox, file_path
|
||||
)
|
||||
write_action = FileWriteAction(file_path, content)
|
||||
observation = runtime.run_action(write_action)
|
||||
observation = await runtime.async_run_action(write_action)
|
||||
|
||||
if isinstance(observation, FileWriteObservation):
|
||||
return JSONResponse(
|
||||
|
||||
@ -1,6 +1,4 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
from threading import Thread
|
||||
from typing import Callable, Optional
|
||||
|
||||
from openhands.controller import AgentController
|
||||
@ -32,7 +30,7 @@ class AgentSession:
|
||||
runtime: Runtime | None = None
|
||||
security_analyzer: SecurityAnalyzer | None = None
|
||||
_closed: bool = False
|
||||
loop: asyncio.AbstractEventLoop
|
||||
loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
def __init__(self, sid: str, file_store: FileStore):
|
||||
"""Initializes a new instance of the Session class
|
||||
@ -45,7 +43,6 @@ class AgentSession:
|
||||
self.sid = sid
|
||||
self.event_stream = EventStream(sid, file_store)
|
||||
self.file_store = file_store
|
||||
self.loop = asyncio.new_event_loop()
|
||||
|
||||
async def start(
|
||||
self,
|
||||
@ -73,17 +70,9 @@ class AgentSession:
|
||||
'Session already started. You need to close this session and start a new one.'
|
||||
)
|
||||
|
||||
self.thread = Thread(target=self._run, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
def coro_callback(task):
|
||||
fut: concurrent.futures.Future = concurrent.futures.Future()
|
||||
try:
|
||||
fut.set_result(task.result())
|
||||
except Exception as e:
|
||||
logger.error(f'Error starting session: {e}')
|
||||
|
||||
coro = self._start(
|
||||
asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
self._start_thread,
|
||||
runtime_name,
|
||||
config,
|
||||
agent,
|
||||
@ -93,9 +82,12 @@ class AgentSession:
|
||||
agent_configs,
|
||||
status_message_callback,
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(coro, self.loop).add_done_callback(
|
||||
coro_callback
|
||||
) # type: ignore
|
||||
|
||||
def _start_thread(self, *args):
    """Run the ``_start`` coroutine to completion on a fresh event loop.

    This is a blocking entry point meant to be executed off the caller's
    event loop (``start`` submits it via ``run_in_executor``).

    Args:
        *args: Positional arguments forwarded unchanged to ``_start``.
    """
    try:
        # NOTE(review): debug=True turns on asyncio debug mode (slow-callback
        # warnings, extra checks) -- confirm this is wanted outside development.
        asyncio.run(self._start(*args), debug=True)
    except RuntimeError as e:
        # Presumably this is the "loop stopped" error raised when close() stops
        # the session loop. Include the message so an unrelated RuntimeError is
        # not hidden behind a generic line.
        logger.info(f'Session Finished: {e}')
|
||||
|
||||
async def _start(
|
||||
self,
|
||||
@ -108,6 +100,7 @@ class AgentSession:
|
||||
agent_configs: dict[str, AgentConfig] | None = None,
|
||||
status_message_callback: Optional[Callable] = None,
|
||||
):
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self._create_security_analyzer(config.security.security_analyzer)
|
||||
self._create_runtime(runtime_name, config, agent, status_message_callback)
|
||||
self._create_controller(
|
||||
@ -125,10 +118,6 @@ class AgentSession:
|
||||
self.controller.agent_task = self.controller.start_step_loop()
|
||||
await self.controller.agent_task # type: ignore
|
||||
|
||||
def _run(self):
|
||||
asyncio.set_event_loop(self.loop)
|
||||
self.loop.run_forever()
|
||||
|
||||
async def close(self):
|
||||
"""Closes the Agent session"""
|
||||
|
||||
@ -143,10 +132,8 @@ class AgentSession:
|
||||
if self.security_analyzer is not None:
|
||||
await self.security_analyzer.close()
|
||||
|
||||
self.loop.call_soon_threadsafe(self.loop.stop)
|
||||
if self.thread:
|
||||
# We may be closing an agent_session that was never actually started
|
||||
self.thread.join()
|
||||
if self.loop:
|
||||
self.loop.call_soon_threadsafe(self.loop.stop)
|
||||
|
||||
self._closed = True
|
||||
|
||||
|
||||
@ -162,9 +162,10 @@ class Session:
|
||||
'Model does not support image upload, change to a different model or try without an image.'
|
||||
)
|
||||
return
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._add_event(event, EventSource.USER), self.agent_session.loop
|
||||
) # type: ignore
|
||||
if self.agent_session.loop:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._add_event(event, EventSource.USER), self.agent_session.loop
|
||||
) # type: ignore
|
||||
|
||||
async def _add_event(self, event, event_source):
|
||||
self.agent_session.event_stream.add_event(event, EventSource.USER)
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import pathlib
|
||||
import tempfile
|
||||
|
||||
@ -42,7 +41,7 @@ def temp_dir(monkeypatch):
|
||||
yield temp_dir
|
||||
|
||||
|
||||
async def add_events(event_stream: EventStream, data: list[tuple[Event, EventSource]]):
|
||||
def add_events(event_stream: EventStream, data: list[tuple[Event, EventSource]]):
|
||||
for event, source in data:
|
||||
event_stream.add_event(event, source)
|
||||
|
||||
@ -62,7 +61,7 @@ def test_msg(temp_dir: str):
|
||||
(MessageAction('Hello world!'), EventSource.USER),
|
||||
(MessageAction('ABC!'), EventSource.AGENT),
|
||||
]
|
||||
asyncio.run(add_events(event_stream, data))
|
||||
add_events(event_stream, data)
|
||||
for i in range(3):
|
||||
assert data[i][0].security_risk == ActionSecurityRisk.LOW
|
||||
assert data[3][0].security_risk == ActionSecurityRisk.MEDIUM
|
||||
@ -86,7 +85,7 @@ def test_cmd(cmd, expected_risk, temp_dir: str):
|
||||
(MessageAction('Hello world!'), EventSource.USER),
|
||||
(CmdRunAction(cmd), EventSource.USER),
|
||||
]
|
||||
asyncio.run(add_events(event_stream, data))
|
||||
add_events(event_stream, data)
|
||||
assert data[0][0].security_risk == ActionSecurityRisk.LOW
|
||||
assert data[1][0].security_risk == expected_risk
|
||||
|
||||
@ -115,7 +114,7 @@ def test_leak_secrets(code, expected_risk, temp_dir: str):
|
||||
(IPythonRunCellAction(code), EventSource.AGENT),
|
||||
(IPythonRunCellAction('hello'), EventSource.AGENT),
|
||||
]
|
||||
asyncio.run(add_events(event_stream, data))
|
||||
add_events(event_stream, data)
|
||||
assert data[0][0].security_risk == ActionSecurityRisk.LOW
|
||||
assert data[1][0].security_risk == expected_risk
|
||||
assert data[2][0].security_risk == ActionSecurityRisk.LOW
|
||||
@ -133,7 +132,7 @@ def test_unsafe_python_code(temp_dir: str):
|
||||
(MessageAction('Hello world!'), EventSource.USER),
|
||||
(IPythonRunCellAction(code), EventSource.AGENT),
|
||||
]
|
||||
asyncio.run(add_events(event_stream, data))
|
||||
add_events(event_stream, data)
|
||||
assert data[0][0].security_risk == ActionSecurityRisk.LOW
|
||||
# TODO: the assertion below failed intermittently for an unknown reason (appears non-deterministic); investigate before re-enabling.
|
||||
# assert data[1][0].security_risk == ActionSecurityRisk.MEDIUM
|
||||
@ -148,7 +147,7 @@ def test_unsafe_bash_command(temp_dir: str):
|
||||
(MessageAction('Hello world!'), EventSource.USER),
|
||||
(CmdRunAction(code), EventSource.AGENT),
|
||||
]
|
||||
asyncio.run(add_events(event_stream, data))
|
||||
add_events(event_stream, data)
|
||||
assert data[0][0].security_risk == ActionSecurityRisk.LOW
|
||||
assert data[1][0].security_risk == ActionSecurityRisk.MEDIUM
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user