From ffdab28abc5c222ae40451a07ef9c11cf5920379 Mon Sep 17 00:00:00 2001 From: tofarr Date: Mon, 27 Jan 2025 11:09:09 -0700 Subject: [PATCH] Fix Docker runtimes not stopping (#6470) Co-authored-by: openhands --- .../runtime/impl/docker/docker_runtime.py | 20 ++- openhands/utils/shutdown_listener.py | 29 ++++- tests/unit/test_shutdown_listener.py | 116 ++++++++++++++++++ 3 files changed, 149 insertions(+), 16 deletions(-) create mode 100644 tests/unit/test_shutdown_listener.py diff --git a/openhands/runtime/impl/docker/docker_runtime.py b/openhands/runtime/impl/docker/docker_runtime.py index 91621c2490..02ad760a0f 100644 --- a/openhands/runtime/impl/docker/docker_runtime.py +++ b/openhands/runtime/impl/docker/docker_runtime.py @@ -1,6 +1,6 @@ -import atexit from functools import lru_cache from typing import Callable +from uuid import UUID import docker import requests @@ -26,6 +26,7 @@ from openhands.runtime.utils.command import get_action_execution_server_startup_ from openhands.runtime.utils.log_streamer import LogStreamer from openhands.runtime.utils.runtime_build import build_runtime_image from openhands.utils.async_utils import call_sync_from_async +from openhands.utils.shutdown_listener import add_shutdown_listener from openhands.utils.tenacity_stop import stop_if_should_exit CONTAINER_NAME_PREFIX = 'openhands-runtime-' @@ -36,13 +37,6 @@ APP_PORT_RANGE_1 = (50000, 54999) APP_PORT_RANGE_2 = (55000, 59999) -def stop_all_runtime_containers(): - stop_all_containers(CONTAINER_NAME_PREFIX) - - -_atexit_registered = False - - class DockerRuntime(ActionExecutionClient): """This runtime will subscribe the event stream. When receive an event, it will send the event to runtime-client which run inside the docker environment. @@ -55,6 +49,8 @@ class DockerRuntime(ActionExecutionClient): env_vars (dict[str, str] | None, optional): Environment variables to set. Defaults to None. """ + _shutdown_listener_id: UUID | None = None + def __init__( self, config: AppConfig, @@ -66,10 +62,10 @@ class DockerRuntime(ActionExecutionClient): attach_to_existing: bool = False, headless_mode: bool = True, ): - global _atexit_registered - if not _atexit_registered: - _atexit_registered = True - atexit.register(stop_all_runtime_containers) + if not DockerRuntime._shutdown_listener_id: + DockerRuntime._shutdown_listener_id = add_shutdown_listener( + lambda: stop_all_containers(CONTAINER_NAME_PREFIX) + ) self.config = config self._runtime_initialized: bool = False diff --git a/openhands/utils/shutdown_listener.py b/openhands/utils/shutdown_listener.py index 5cf84a309c..eddaac54f0 100644 --- a/openhands/utils/shutdown_listener.py +++ b/openhands/utils/shutdown_listener.py @@ -1,5 +1,6 @@ """ -This module monitors the app for shutdown signals +This module monitors the app for shutdown signals. This exists because the atexit module +does not play nocely with stareltte / uvicorn shutdown signals. """ import asyncio @@ -7,12 +8,15 @@ import signal import threading import time from types import FrameType +from typing import Callable +from uuid import UUID, uuid4 from uvicorn.server import HANDLED_SIGNALS from openhands.core.logger import openhands_logger as logger _should_exit = None +_shutdown_listeners: dict[UUID, Callable] = {} def _register_signal_handler(sig: signal.Signals): @@ -21,9 +25,16 @@ def _register_signal_handler(sig: signal.Signals): def handler(sig_: int, frame: FrameType | None): logger.debug(f'shutdown_signal:{sig_}') global _should_exit - _should_exit = True - if original_handler: - original_handler(sig_, frame) # type: ignore[unreachable] + if not _should_exit: + _should_exit = True + listeners = list(_shutdown_listeners.values()) + for callable in listeners: + try: + callable() + except Exception: + logger.exception('Error calling shutdown listener') + if original_handler: + original_handler(sig_, frame) # type: ignore[unreachable] original_handler = signal.signal(sig, handler) @@ -71,3 +82,13 @@ async def async_sleep_if_should_continue(timeout: float): start_time = time.time() while time.time() - start_time < timeout and should_continue(): await asyncio.sleep(1) + + +def add_shutdown_listener(callable: Callable) -> UUID: + id_ = uuid4() + _shutdown_listeners[id_] = callable + return id_ + + +def remove_shutdown_listener(id_: UUID) -> bool: + return _shutdown_listeners.pop(id_, None) is not None diff --git a/tests/unit/test_shutdown_listener.py b/tests/unit/test_shutdown_listener.py new file mode 100644 index 0000000000..a4317c6b76 --- /dev/null +++ b/tests/unit/test_shutdown_listener.py @@ -0,0 +1,116 @@ +import signal +from dataclasses import dataclass, field +from signal import Signals +from typing import Callable +from unittest.mock import MagicMock, patch +from uuid import UUID + +import pytest + +from openhands.utils import shutdown_listener +from openhands.utils.shutdown_listener import ( + add_shutdown_listener, + remove_shutdown_listener, + should_continue, +) + + +@pytest.fixture(autouse=True) +def cleanup_listeners(): + shutdown_listener._shutdown_listeners.clear() + shutdown_listener._should_exit = False + + +@dataclass +class MockSignal: + handlers: dict[Signals, Callable] = field(default_factory=dict) + + def signal(self, signalnum: Signals, handler: Callable): + result = self.handlers.get(signalnum) + self.handlers[signalnum] = handler + return result + + def trigger(self, signalnum: Signals): + handler = self.handlers.get(signalnum) + if handler: + handler(signalnum.value, None) + + +def test_add_shutdown_listener(): + mock_callable = MagicMock() + listener_id = add_shutdown_listener(mock_callable) + + assert isinstance(listener_id, UUID) + assert listener_id in shutdown_listener._shutdown_listeners + assert shutdown_listener._shutdown_listeners[listener_id] == mock_callable + + +def test_remove_shutdown_listener(): + mock_callable = MagicMock() + listener_id = add_shutdown_listener(mock_callable) + + # Test successful removal + assert remove_shutdown_listener(listener_id) is True + assert listener_id not in shutdown_listener._shutdown_listeners + + # Test removing non-existent listener + assert remove_shutdown_listener(listener_id) is False + + +def test_signal_handler_calls_listeners(): + mock_signal = MockSignal() + with patch('openhands.utils.shutdown_listener.signal', mock_signal): + mock_callable1 = MagicMock() + mock_callable2 = MagicMock() + add_shutdown_listener(mock_callable1) + add_shutdown_listener(mock_callable2) + + # Register and trigger signal handler + shutdown_listener._register_signal_handler(signal.SIGTERM) + mock_signal.trigger(signal.SIGTERM) + + # Verify both listeners were called + mock_callable1.assert_called_once() + mock_callable2.assert_called_once() + + # Verify should_continue returns False after shutdown + assert should_continue() is False + + +def test_listeners_called_only_once(): + mock_signal = MockSignal() + with patch('openhands.utils.shutdown_listener.signal', mock_signal): + mock_callable = MagicMock() + add_shutdown_listener(mock_callable) + + # Register and trigger signal handler multiple times + shutdown_listener._register_signal_handler(signal.SIGTERM) + mock_signal.trigger(signal.SIGTERM) + mock_signal.trigger(signal.SIGTERM) + + # Verify listener was called only once + assert mock_callable.call_count == 1 + + +def test_remove_listener_during_shutdown(): + mock_signal = MockSignal() + with patch('openhands.utils.shutdown_listener.signal', mock_signal): + mock_callable1 = MagicMock() + mock_callable2 = MagicMock() + + # Second listener removes the first listener when called + listener1_id = add_shutdown_listener(mock_callable1) + + def remove_other_listener(): + remove_shutdown_listener(listener1_id) + mock_callable2() + + add_shutdown_listener(remove_other_listener) + + # Register and trigger signal handler + shutdown_listener._register_signal_handler(signal.SIGTERM) + mock_signal.trigger(signal.SIGTERM) + + # Both listeners should still be called + assert mock_callable1.call_count == 1 + assert mock_callable2.call_count == 1