diff --git a/openhands/runtime/runtime.py b/openhands/runtime/runtime.py index 7e420643c3..44614ee0a3 100644 --- a/openhands/runtime/runtime.py +++ b/openhands/runtime/runtime.py @@ -28,7 +28,7 @@ from openhands.events.observation import ( ) from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS from openhands.runtime.plugins import JupyterRequirement, PluginRequirement -from openhands.utils.async_utils import sync_from_async +from openhands.utils.async_utils import call_sync_from_async def _default_env_vars(sandbox_config: SandboxConfig) -> dict[str, str]: @@ -123,7 +123,7 @@ class Runtime: if event.timeout is None: event.timeout = self.config.sandbox.timeout assert event.timeout is not None - observation = await sync_from_async(self.run_action, event) + observation = await call_sync_from_async(self.run_action, event) observation._cause = event.id # type: ignore[attr-defined] source = event.source if event.source else EventSource.AGENT await self.event_stream.async_add_event(observation, source) # type: ignore[arg-type] diff --git a/openhands/security/invariant/analyzer.py b/openhands/security/invariant/analyzer.py index 275888bb41..9d8b280716 100644 --- a/openhands/security/invariant/analyzer.py +++ b/openhands/security/invariant/analyzer.py @@ -19,7 +19,7 @@ from openhands.runtime.utils import find_available_tcp_port from openhands.security.analyzer import SecurityAnalyzer from openhands.security.invariant.client import InvariantClient from openhands.security.invariant.parser import TraceElement, parse_element -from openhands.utils.async_utils import sync_from_async +from openhands.utils.async_utils import call_sync_from_async class InvariantAnalyzer(SecurityAnalyzer): @@ -146,7 +146,7 @@ class InvariantAnalyzer(SecurityAnalyzer): {'action': 'change_agent_state', 'args': {'agent_state': 'user_confirmed'}} ) event_source = event.source if event.source else EventSource.AGENT - await sync_from_async(self.event_stream.add_event, new_event, event_source) + await call_sync_from_async(self.event_stream.add_event, new_event, event_source) async def security_risk(self, event: Action) -> ActionSecurityRisk: logger.info('Calling security_risk on InvariantAnalyzer') diff --git a/openhands/server/listen.py b/openhands/server/listen.py index d7f1777349..32c93a117e 100644 --- a/openhands/server/listen.py +++ b/openhands/server/listen.py @@ -14,7 +14,7 @@ from pathspec.patterns import GitWildMatchPattern from openhands.security.options import SecurityAnalyzers from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback from openhands.storage import get_file_store -from openhands.utils.async_utils import sync_from_async +from openhands.utils.async_utils import call_sync_from_async with warnings.catch_warnings(): warnings.simplefilter('ignore') @@ -211,8 +211,8 @@ async def attach_session(request: Request, call_next): content={'error': 'Invalid token'}, ) - request.state.conversation = session_manager.attach_to_conversation( - request.state.sid + request.state.conversation = await call_sync_from_async( + session_manager.attach_to_conversation, request.state.sid ) if request.state.conversation is None: return JSONResponse( @@ -441,7 +441,9 @@ async def list_files(request: Request, path: str | None = None): ) runtime: Runtime = request.state.conversation.runtime - file_list = await sync_from_async(runtime.list_files, path) + file_list = await asyncio.create_task( + call_sync_from_async(runtime.list_files, path) + ) if path: file_list = [os.path.join(path, f) for f in file_list] @@ -490,7 +492,7 @@ async def select_file(file: str, request: Request): file = os.path.join(runtime.config.workspace_mount_path_in_sandbox, file) read_action = FileReadAction(file) - observation = await sync_from_async(runtime.run_action, read_action) + observation = await call_sync_from_async(runtime.run_action, read_action) if isinstance(observation, FileReadObservation): content = observation.content @@ -687,7 +689,7 @@ async def save_file(request: Request): runtime.config.workspace_mount_path_in_sandbox, file_path ) write_action = FileWriteAction(file_path, content) - observation = await sync_from_async(runtime.run_action, write_action) + observation = await call_sync_from_async(runtime.run_action, write_action) if isinstance(observation, FileWriteObservation): return JSONResponse( diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py index f172021a37..6bc442ac73 100644 --- a/openhands/server/session/agent_session.py +++ b/openhands/server/session/agent_session.py @@ -14,6 +14,7 @@ from openhands.runtime import get_runtime_cls from openhands.runtime.runtime import Runtime from openhands.security import SecurityAnalyzer, options from openhands.storage.files import FileStore +from openhands.utils.async_utils import call_sync_from_async class AgentSession: @@ -102,7 +103,13 @@ class AgentSession: ): self.loop = asyncio.get_running_loop() self._create_security_analyzer(config.security.security_analyzer) - self._create_runtime(runtime_name, config, agent, status_message_callback) + await call_sync_from_async( + self._create_runtime, + runtime_name=runtime_name, + config=config, + agent=agent, + status_message_callback=status_message_callback, + ) self._create_controller( agent, config.security.confirmation_mode, diff --git a/openhands/utils/async_utils.py b/openhands/utils/async_utils.py index 7da8d05ff5..2a3b73f5da 100644 --- a/openhands/utils/async_utils.py +++ b/openhands/utils/async_utils.py @@ -7,7 +7,7 @@ GENERAL_TIMEOUT: int = 15 EXECUTOR = ThreadPoolExecutor() -async def sync_from_async(fn: Callable, *args, **kwargs): +async def call_sync_from_async(fn: Callable, *args, **kwargs): """ Shorthand for running a function in the default background thread pool executor and awaiting the result. The nature of synchronous code is that the future @@ -19,7 +19,7 @@ async def sync_from_async(fn: Callable, *args, **kwargs): return result -def async_from_sync( +def call_async_from_sync( corofn: Callable, timeout: float = GENERAL_TIMEOUT, *args, **kwargs ): """ @@ -27,6 +27,11 @@ def async_from_sync( and awaiting the result """ + if corofn is None: + raise ValueError('corofn is None') + if not asyncio.iscoroutinefunction(corofn): + raise ValueError('corofn is not a coroutine function') + async def arun(): coro = corofn(*args, **kwargs) result = await coro @@ -46,6 +51,13 @@ def async_from_sync( return result +async def call_coro_in_bg_thread( + corofn: Callable, timeout: float = GENERAL_TIMEOUT, *args, **kwargs +): + """Function for running a coroutine in a background thread.""" + await call_sync_from_async(call_async_from_sync, corofn, timeout, *args, **kwargs) + + async def wait_all( iterable: Iterable[Coroutine], timeout: int = GENERAL_TIMEOUT ) -> List: diff --git a/tests/unit/test_async_utils.py b/tests/unit/test_async_utils.py index 89dd1e0f69..3dc9943896 100644 --- a/tests/unit/test_async_utils.py +++ b/tests/unit/test_async_utils.py @@ -1,11 +1,13 @@ import asyncio +import time import pytest from openhands.utils.async_utils import ( AsyncException, - async_from_sync, - sync_from_async, + call_async_from_sync, + call_coro_in_bg_thread, + call_sync_from_async, wait_all, ) @@ -80,44 +82,44 @@ async def test_await_all_timeout(): @pytest.mark.asyncio -async def test_sync_from_async(): +async def test_call_sync_from_async(): def dummy(value: int = 2): return value * 2 - result = await sync_from_async(dummy) + result = await call_sync_from_async(dummy) assert result == 4 - result = await sync_from_async(dummy, 3) + result = await call_sync_from_async(dummy, 3) assert result == 6 - result = await sync_from_async(dummy, value=5) + result = await call_sync_from_async(dummy, value=5) assert result == 10 @pytest.mark.asyncio -async def test_sync_from_async_error(): +async def test_call_sync_from_async_error(): def dummy(): raise ValueError() with pytest.raises(ValueError): - await sync_from_async(dummy) + await call_sync_from_async(dummy) -def test_async_from_sync(): +def test_call_async_from_sync(): async def dummy(value: int): return value * 2 - result = async_from_sync(dummy, 0, 3) + result = call_async_from_sync(dummy, 0, 3) assert result == 6 -def test_async_from_sync_error(): +def test_call_async_from_sync_error(): async def dummy(value: int): raise ValueError() with pytest.raises(ValueError): - async_from_sync(dummy, 0, 3) + call_async_from_sync(dummy, 0, 3) -def test_async_from_sync_background_tasks(): +def test_call_async_from_sync_background_tasks(): events = [] async def bg_task(): @@ -132,9 +134,33 @@ def test_async_from_sync_background_tasks(): asyncio.create_task(bg_task()) events.append('dummy_started') - async_from_sync(dummy, 0, 3) + call_async_from_sync(dummy, 0, 3) # We check that the function did not return until all coroutines completed # (Even though some of these were started as background tasks) expected = ['dummy_started', 'dummy_started', 'bg_started', 'bg_finished'] assert expected == events + + +@pytest.mark.asyncio +async def test_call_coro_in_bg_thread(): + times = {} + + async def bad_async(id_): + # Dummy demonstrating some bad async function that does not cede control + time.sleep(0.1) + times[id_] = time.time() + + async def curve_ball(): + # A curve ball - an async function that wants to run while the bad async functions are in progress + await asyncio.sleep(0.05) + times['curve_ball'] = time.time() + + start = time.time() + asyncio.create_task(curve_ball()) + await wait_all( + call_coro_in_bg_thread(bad_async, id_=f'bad_async_{id_}') for id_ in range(5) + ) + assert (times['curve_ball'] - start) == pytest.approx(0.05, abs=0.1) + for id_ in range(5): + assert (times[f'bad_async_{id_}'] - start) == pytest.approx(0.1, abs=0.1)