diff --git a/evaluation/swe_bench/run_infer.py b/evaluation/swe_bench/run_infer.py index 36eb70554e..189f234071 100644 --- a/evaluation/swe_bench/run_infer.py +++ b/evaluation/swe_bench/run_infer.py @@ -2,7 +2,6 @@ import asyncio import json import os import tempfile -import time from typing import Any import pandas as pd @@ -32,6 +31,7 @@ from openhands.core.main import create_runtime, run_controller from openhands.events.action import CmdRunAction from openhands.events.observation import CmdOutputObservation, ErrorObservation from openhands.runtime.runtime import Runtime +from openhands.runtime.utils.shutdown_listener import sleep_if_should_continue USE_HINT_TEXT = os.environ.get('USE_HINT_TEXT', 'false').lower() == 'true' USE_INSTANCE_IMAGE = os.environ.get('USE_INSTANCE_IMAGE', 'false').lower() == 'true' @@ -316,10 +316,10 @@ def complete_runtime( break else: logger.info('Failed to get git diff, retrying...') - time.sleep(10) + sleep_if_should_continue(10) elif isinstance(obs, ErrorObservation): logger.error(f'Error occurred: {obs.content}. Retrying...') - time.sleep(10) + sleep_if_should_continue(10) else: raise ValueError(f'Unexpected observation type: {type(obs)}') diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index 09d1c02a46..ba75518828 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -37,6 +37,7 @@ from openhands.events.observation import ( Observation, ) from openhands.llm.llm import LLM +from openhands.runtime.utils.shutdown_listener import should_continue # note: RESUME is only available on web GUI TRAFFIC_CONTROL_REMINDER = ( @@ -148,7 +149,7 @@ class AgentController: """The main loop for the agent's step-by-step execution.""" logger.info(f'[Agent Controller {self.id}] Starting step loop...') - while True: + while should_continue(): try: await self._step() except asyncio.CancelledError: diff --git a/openhands/events/stream.py b/openhands/events/stream.py index 6285520658..59f7febb87 100644 --- a/openhands/events/stream.py +++ b/openhands/events/stream.py @@ -8,6 +8,7 @@ from openhands.core.logger import openhands_logger as logger from openhands.core.utils import json from openhands.events.event import Event, EventSource from openhands.events.serialization.event import event_from_dict, event_to_dict +from openhands.runtime.utils.shutdown_listener import should_continue from openhands.storage import FileStore @@ -85,7 +86,7 @@ class EventStream: event_id -= 1 else: event_id = start_id - while True: + while should_continue(): if end_id is not None and event_id > end_id: break try: diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py index 16e9eb8a13..db7a241988 100644 --- a/openhands/llm/llm.py +++ b/openhands/llm/llm.py @@ -5,6 +5,7 @@ from functools import partial from typing import Union from openhands.core.config import LLMConfig +from openhands.runtime.utils.shutdown_listener import should_continue with warnings.catch_warnings(): warnings.simplefilter('ignore') @@ -296,7 +297,7 @@ class LLM: debug_message = self._get_debug_message(messages) async def check_stopped(): - while True: + while should_continue(): if ( hasattr(self.config, 'on_cancel_requested_fn') and self.config.on_cancel_requested_fn is not None diff --git a/openhands/runtime/browser/browser_env.py b/openhands/runtime/browser/browser_env.py index 41efb9bb29..2483714c9b 100644 --- a/openhands/runtime/browser/browser_env.py +++ b/openhands/runtime/browser/browser_env.py @@ -16,6 +16,7 @@ from PIL import Image from openhands.core.exceptions import BrowserInitException from openhands.core.logger import openhands_logger as logger +from openhands.runtime.utils.shutdown_listener import should_continue, should_exit BROWSER_EVAL_GET_GOAL_ACTION = 'GET_EVAL_GOAL' BROWSER_EVAL_GET_REWARDS_ACTION = 'GET_EVAL_REWARDS' @@ -99,7 +100,7 @@ class BrowserEnv: self.eval_goal = obs['goal'] logger.info('Browser env started.') - while True: + while should_continue(): try: if self.browser_side.poll(timeout=0.01): unique_request_id, action_data = self.browser_side.recv() @@ -157,7 +158,7 @@ class BrowserEnv: self.agent_side.send((unique_request_id, {'action': action_str})) start_time = time.time() while True: - if time.time() - start_time > timeout: + if should_exit() or time.time() - start_time > timeout: raise TimeoutError('Browser environment took too long to respond.') if self.agent_side.poll(timeout=0.01): response_id, obs = self.agent_side.recv() diff --git a/openhands/runtime/builder/remote.py b/openhands/runtime/builder/remote.py index d78cce93fc..3ce34383e6 100644 --- a/openhands/runtime/builder/remote.py +++ b/openhands/runtime/builder/remote.py @@ -8,6 +8,7 @@ import requests from openhands.core.logger import openhands_logger as logger from openhands.runtime.builder import RuntimeBuilder from openhands.runtime.utils.request import send_request +from openhands.runtime.utils.shutdown_listener import should_exit, sleep_if_should_continue class RemoteRuntimeBuilder(RuntimeBuilder): @@ -57,7 +58,7 @@ class RemoteRuntimeBuilder(RuntimeBuilder): start_time = time.time() timeout = 30 * 60 # 20 minutes in seconds while True: - if time.time() - start_time > timeout: + if should_exit() or time.time() - start_time > timeout: logger.error('Build timed out after 30 minutes') raise RuntimeError('Build timed out after 30 minutes') @@ -95,7 +96,7 @@ class RemoteRuntimeBuilder(RuntimeBuilder): raise RuntimeError(error_message) # Wait before polling again - time.sleep(30) + sleep_if_should_continue(30) def image_exists(self, image_name: str) -> bool: """Checks if an image exists in the remote registry using the /image_exists endpoint.""" diff --git a/openhands/runtime/plugins/jupyter/__init__.py b/openhands/runtime/plugins/jupyter/__init__.py index 26eda35054..b46714c242 100644 --- a/openhands/runtime/plugins/jupyter/__init__.py +++ b/openhands/runtime/plugins/jupyter/__init__.py @@ -8,6 +8,7 @@ from openhands.events.observation import IPythonRunCellObservation from openhands.runtime.plugins.jupyter.execute_server import JupyterKernel from openhands.runtime.plugins.requirement import Plugin, PluginRequirement from openhands.runtime.utils import find_available_tcp_port +from openhands.runtime.utils.shutdown_listener import should_continue @dataclass @@ -38,7 +39,7 @@ class JupyterPlugin(Plugin): ) # read stdout until the kernel gateway is ready output = '' - while True and self.gateway_process.stdout is not None: + while should_continue() and self.gateway_process.stdout is not None: line = self.gateway_process.stdout.readline().decode('utf-8') output += line if 'at' in line: diff --git a/openhands/runtime/utils/request.py b/openhands/runtime/utils/request.py index 7177cda909..ee95d60120 100644 --- a/openhands/runtime/utils/request.py +++ b/openhands/runtime/utils/request.py @@ -47,6 +47,7 @@ def send_request( if retry_fns is not None: for fn in retry_fns: retry_condition |= retry_if_exception(fn) + kwargs["timeout"] = timeout @retry( stop=stop_after_delay(timeout), diff --git a/openhands/runtime/utils/shutdown_listener.py b/openhands/runtime/utils/shutdown_listener.py new file mode 100644 index 0000000000..882d532a40 --- /dev/null +++ b/openhands/runtime/utils/shutdown_listener.py @@ -0,0 +1,60 @@ +""" +This module monitors the app for shutdown signals +""" +import asyncio +import signal +import time +from types import FrameType + +from uvicorn.server import HANDLED_SIGNALS + +_should_exit = None + + +def _register_signal_handler(sig: signal.Signals): + original_handler = None + + def handler(sig_: int, frame: FrameType | None): + global _should_exit + _should_exit = True + if original_handler: + original_handler(sig_, frame) # type: ignore[unreachable] + + original_handler = signal.signal(sig, handler) + + +def _register_signal_handlers(): + global _should_exit + if _should_exit is not None: + return + _should_exit = False + for sig in HANDLED_SIGNALS: + _register_signal_handler(sig) + + +def should_exit() -> bool: + _register_signal_handlers() + return bool(_should_exit) + + +def should_continue() -> bool: + _register_signal_handlers() + return not _should_exit + + +def sleep_if_should_continue(timeout: float): + if(timeout <= 1): + time.sleep(timeout) + return + start_time = time.time() + while (time.time() - start_time) < timeout and should_continue(): + time.sleep(1) + + +async def async_sleep_if_should_continue(timeout: float): + if(timeout <= 1): + await asyncio.sleep(timeout) + return + start_time = time.time() + while time.time() - start_time < timeout and should_continue(): + await asyncio.sleep(1) diff --git a/openhands/server/mock/listen.py b/openhands/server/mock/listen.py index e170201de7..77b37be2f7 100644 --- a/openhands/server/mock/listen.py +++ b/openhands/server/mock/listen.py @@ -2,6 +2,7 @@ import uvicorn from fastapi import FastAPI, WebSocket from openhands.core.schema import ActionType +from openhands.runtime.utils.shutdown_listener import should_continue app = FastAPI() @@ -15,7 +16,7 @@ async def websocket_endpoint(websocket: WebSocket): ) try: - while True: + while should_continue(): # receive message data = await websocket.receive_json() print(f'Received message: {data}') diff --git a/openhands/server/session/manager.py b/openhands/server/session/manager.py index 074ca9b6e7..a14fdc8be1 100644 --- a/openhands/server/session/manager.py +++ b/openhands/server/session/manager.py @@ -5,6 +5,7 @@ from fastapi import WebSocket from openhands.core.config import AppConfig from openhands.core.logger import openhands_logger as logger +from openhands.runtime.utils.shutdown_listener import should_continue from openhands.server.session.session import Session from openhands.storage.files import FileStore @@ -47,7 +48,7 @@ class SessionManager: return await self.send(sid, {'message': message}) async def _cleanup_sessions(self): - while True: + while should_continue(): current_time = time.time() session_ids_to_remove = [] for sid, session in list(self._sessions.items()): diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py index 6636552a32..10c51594db 100644 --- a/openhands/server/session/session.py +++ b/openhands/server/session/session.py @@ -20,6 +20,7 @@ from openhands.events.observation import ( from openhands.events.serialization import event_from_dict, event_to_dict from openhands.events.stream import EventStreamSubscriber from openhands.llm.llm import LLM +from openhands.runtime.utils.shutdown_listener import should_continue from openhands.server.session.agent import AgentSession from openhands.storage.files import FileStore @@ -53,7 +54,7 @@ class Session: try: if self.websocket is None: return - while True: + while should_continue(): try: data = await self.websocket.receive_json() except ValueError: