Fix Docker runtimes not stopping (#6470)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
tofarr 2025-01-27 11:09:09 -07:00 committed by GitHub
parent 12dd23ba1c
commit ffdab28abc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 149 additions and 16 deletions

View File

@ -1,6 +1,6 @@
import atexit
from functools import lru_cache
from typing import Callable
from uuid import UUID
import docker
import requests
@ -26,6 +26,7 @@ from openhands.runtime.utils.command import get_action_execution_server_startup_
from openhands.runtime.utils.log_streamer import LogStreamer
from openhands.runtime.utils.runtime_build import build_runtime_image
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.shutdown_listener import add_shutdown_listener
from openhands.utils.tenacity_stop import stop_if_should_exit
CONTAINER_NAME_PREFIX = 'openhands-runtime-'
@ -36,13 +37,6 @@ APP_PORT_RANGE_1 = (50000, 54999)
APP_PORT_RANGE_2 = (55000, 59999)
def stop_all_runtime_containers():
stop_all_containers(CONTAINER_NAME_PREFIX)
_atexit_registered = False
class DockerRuntime(ActionExecutionClient):
"""This runtime will subscribe the event stream.
When receive an event, it will send the event to runtime-client which run inside the docker environment.
@ -55,6 +49,8 @@ class DockerRuntime(ActionExecutionClient):
env_vars (dict[str, str] | None, optional): Environment variables to set. Defaults to None.
"""
_shutdown_listener_id: UUID | None = None
def __init__(
self,
config: AppConfig,
@ -66,10 +62,10 @@ class DockerRuntime(ActionExecutionClient):
attach_to_existing: bool = False,
headless_mode: bool = True,
):
global _atexit_registered
if not _atexit_registered:
_atexit_registered = True
atexit.register(stop_all_runtime_containers)
if not DockerRuntime._shutdown_listener_id:
DockerRuntime._shutdown_listener_id = add_shutdown_listener(
lambda: stop_all_containers(CONTAINER_NAME_PREFIX)
)
self.config = config
self._runtime_initialized: bool = False

View File

@ -1,5 +1,6 @@
"""
This module monitors the app for shutdown signals
This module monitors the app for shutdown signals. This exists because the atexit module
does not play nocely with stareltte / uvicorn shutdown signals.
"""
import asyncio
@ -7,12 +8,15 @@ import signal
import threading
import time
from types import FrameType
from typing import Callable
from uuid import UUID, uuid4
from uvicorn.server import HANDLED_SIGNALS
from openhands.core.logger import openhands_logger as logger
_should_exit = None
_shutdown_listeners: dict[UUID, Callable] = {}
def _register_signal_handler(sig: signal.Signals):
@ -21,9 +25,16 @@ def _register_signal_handler(sig: signal.Signals):
def handler(sig_: int, frame: FrameType | None):
logger.debug(f'shutdown_signal:{sig_}')
global _should_exit
_should_exit = True
if original_handler:
original_handler(sig_, frame) # type: ignore[unreachable]
if not _should_exit:
_should_exit = True
listeners = list(_shutdown_listeners.values())
for callable in listeners:
try:
callable()
except Exception:
logger.exception('Error calling shutdown listener')
if original_handler:
original_handler(sig_, frame) # type: ignore[unreachable]
original_handler = signal.signal(sig, handler)
@ -71,3 +82,13 @@ async def async_sleep_if_should_continue(timeout: float):
start_time = time.time()
while time.time() - start_time < timeout and should_continue():
await asyncio.sleep(1)
def add_shutdown_listener(callable: Callable) -> UUID:
id_ = uuid4()
_shutdown_listeners[id_] = callable
return id_
def remove_shutdown_listener(id_: UUID) -> bool:
return _shutdown_listeners.pop(id_, None) is not None

View File

@ -0,0 +1,116 @@
import signal
from dataclasses import dataclass, field
from signal import Signals
from typing import Callable
from unittest.mock import MagicMock, patch
from uuid import UUID
import pytest
from openhands.utils import shutdown_listener
from openhands.utils.shutdown_listener import (
add_shutdown_listener,
remove_shutdown_listener,
should_continue,
)
@pytest.fixture(autouse=True)
def cleanup_listeners():
shutdown_listener._shutdown_listeners.clear()
shutdown_listener._should_exit = False
@dataclass
class MockSignal:
handlers: dict[Signals, Callable] = field(default_factory=dict)
def signal(self, signalnum: Signals, handler: Callable):
result = self.handlers.get(signalnum)
self.handlers[signalnum] = handler
return result
def trigger(self, signalnum: Signals):
handler = self.handlers.get(signalnum)
if handler:
handler(signalnum.value, None)
def test_add_shutdown_listener():
mock_callable = MagicMock()
listener_id = add_shutdown_listener(mock_callable)
assert isinstance(listener_id, UUID)
assert listener_id in shutdown_listener._shutdown_listeners
assert shutdown_listener._shutdown_listeners[listener_id] == mock_callable
def test_remove_shutdown_listener():
mock_callable = MagicMock()
listener_id = add_shutdown_listener(mock_callable)
# Test successful removal
assert remove_shutdown_listener(listener_id) is True
assert listener_id not in shutdown_listener._shutdown_listeners
# Test removing non-existent listener
assert remove_shutdown_listener(listener_id) is False
def test_signal_handler_calls_listeners():
mock_signal = MockSignal()
with patch('openhands.utils.shutdown_listener.signal', mock_signal):
mock_callable1 = MagicMock()
mock_callable2 = MagicMock()
add_shutdown_listener(mock_callable1)
add_shutdown_listener(mock_callable2)
# Register and trigger signal handler
shutdown_listener._register_signal_handler(signal.SIGTERM)
mock_signal.trigger(signal.SIGTERM)
# Verify both listeners were called
mock_callable1.assert_called_once()
mock_callable2.assert_called_once()
# Verify should_continue returns False after shutdown
assert should_continue() is False
def test_listeners_called_only_once():
mock_signal = MockSignal()
with patch('openhands.utils.shutdown_listener.signal', mock_signal):
mock_callable = MagicMock()
add_shutdown_listener(mock_callable)
# Register and trigger signal handler multiple times
shutdown_listener._register_signal_handler(signal.SIGTERM)
mock_signal.trigger(signal.SIGTERM)
mock_signal.trigger(signal.SIGTERM)
# Verify listener was called only once
assert mock_callable.call_count == 1
def test_remove_listener_during_shutdown():
mock_signal = MockSignal()
with patch('openhands.utils.shutdown_listener.signal', mock_signal):
mock_callable1 = MagicMock()
mock_callable2 = MagicMock()
# Second listener removes the first listener when called
listener1_id = add_shutdown_listener(mock_callable1)
def remove_other_listener():
remove_shutdown_listener(listener1_id)
mock_callable2()
add_shutdown_listener(remove_other_listener)
# Register and trigger signal handler
shutdown_listener._register_signal_handler(signal.SIGTERM)
mock_signal.trigger(signal.SIGTERM)
# Both listeners should still be called
assert mock_callable1.call_count == 1
assert mock_callable2.call_count == 1