Feat Tightening up Timeouts and interrupt conditions. (#3926)

This commit is contained in:
tofarr
2024-09-18 14:50:42 -06:00
committed by GitHub
parent 47f60b8275
commit ad0b549d8b
12 changed files with 84 additions and 14 deletions

View File

@@ -2,7 +2,6 @@ import asyncio
import json
import os
import tempfile
import time
from typing import Any
import pandas as pd
@@ -32,6 +31,7 @@ from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction
from openhands.events.observation import CmdOutputObservation, ErrorObservation
from openhands.runtime.runtime import Runtime
from openhands.runtime.utils.shutdown_listener import sleep_if_should_continue
USE_HINT_TEXT = os.environ.get('USE_HINT_TEXT', 'false').lower() == 'true'
USE_INSTANCE_IMAGE = os.environ.get('USE_INSTANCE_IMAGE', 'false').lower() == 'true'
@@ -316,10 +316,10 @@ def complete_runtime(
break
else:
logger.info('Failed to get git diff, retrying...')
time.sleep(10)
sleep_if_should_continue(10)
elif isinstance(obs, ErrorObservation):
logger.error(f'Error occurred: {obs.content}. Retrying...')
time.sleep(10)
sleep_if_should_continue(10)
else:
raise ValueError(f'Unexpected observation type: {type(obs)}')

View File

@@ -37,6 +37,7 @@ from openhands.events.observation import (
Observation,
)
from openhands.llm.llm import LLM
from openhands.runtime.utils.shutdown_listener import should_continue
# note: RESUME is only available on web GUI
TRAFFIC_CONTROL_REMINDER = (
@@ -148,7 +149,7 @@ class AgentController:
"""The main loop for the agent's step-by-step execution."""
logger.info(f'[Agent Controller {self.id}] Starting step loop...')
while True:
while should_continue():
try:
await self._step()
except asyncio.CancelledError:

View File

@@ -8,6 +8,7 @@ from openhands.core.logger import openhands_logger as logger
from openhands.core.utils import json
from openhands.events.event import Event, EventSource
from openhands.events.serialization.event import event_from_dict, event_to_dict
from openhands.runtime.utils.shutdown_listener import should_continue
from openhands.storage import FileStore
@@ -85,7 +86,7 @@ class EventStream:
event_id -= 1
else:
event_id = start_id
while True:
while should_continue():
if end_id is not None and event_id > end_id:
break
try:

View File

@@ -5,6 +5,7 @@ from functools import partial
from typing import Union
from openhands.core.config import LLMConfig
from openhands.runtime.utils.shutdown_listener import should_continue
with warnings.catch_warnings():
warnings.simplefilter('ignore')
@@ -296,7 +297,7 @@ class LLM:
debug_message = self._get_debug_message(messages)
async def check_stopped():
while True:
while should_continue():
if (
hasattr(self.config, 'on_cancel_requested_fn')
and self.config.on_cancel_requested_fn is not None

View File

@@ -16,6 +16,7 @@ from PIL import Image
from openhands.core.exceptions import BrowserInitException
from openhands.core.logger import openhands_logger as logger
from openhands.runtime.utils.shutdown_listener import should_continue, should_exit
BROWSER_EVAL_GET_GOAL_ACTION = 'GET_EVAL_GOAL'
BROWSER_EVAL_GET_REWARDS_ACTION = 'GET_EVAL_REWARDS'
@@ -99,7 +100,7 @@ class BrowserEnv:
self.eval_goal = obs['goal']
logger.info('Browser env started.')
while True:
while should_continue():
try:
if self.browser_side.poll(timeout=0.01):
unique_request_id, action_data = self.browser_side.recv()
@@ -157,7 +158,7 @@ class BrowserEnv:
self.agent_side.send((unique_request_id, {'action': action_str}))
start_time = time.time()
while True:
if time.time() - start_time > timeout:
if should_exit() or time.time() - start_time > timeout:
raise TimeoutError('Browser environment took too long to respond.')
if self.agent_side.poll(timeout=0.01):
response_id, obs = self.agent_side.recv()

View File

@@ -8,6 +8,7 @@ import requests
from openhands.core.logger import openhands_logger as logger
from openhands.runtime.builder import RuntimeBuilder
from openhands.runtime.utils.request import send_request
from openhands.runtime.utils.shutdown_listener import should_exit, sleep_if_should_continue
class RemoteRuntimeBuilder(RuntimeBuilder):
@@ -57,7 +58,7 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
start_time = time.time()
timeout = 30 * 60 # 30 minutes in seconds
while True:
if time.time() - start_time > timeout:
if should_exit() or time.time() - start_time > timeout:
logger.error('Build timed out after 30 minutes')
raise RuntimeError('Build timed out after 30 minutes')
@@ -95,7 +96,7 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
raise RuntimeError(error_message)
# Wait before polling again
time.sleep(30)
sleep_if_should_continue(30)
def image_exists(self, image_name: str) -> bool:
"""Checks if an image exists in the remote registry using the /image_exists endpoint."""

View File

@@ -8,6 +8,7 @@ from openhands.events.observation import IPythonRunCellObservation
from openhands.runtime.plugins.jupyter.execute_server import JupyterKernel
from openhands.runtime.plugins.requirement import Plugin, PluginRequirement
from openhands.runtime.utils import find_available_tcp_port
from openhands.runtime.utils.shutdown_listener import should_continue
@dataclass
@@ -38,7 +39,7 @@ class JupyterPlugin(Plugin):
)
# read stdout until the kernel gateway is ready
output = ''
while True and self.gateway_process.stdout is not None:
while should_continue() and self.gateway_process.stdout is not None:
line = self.gateway_process.stdout.readline().decode('utf-8')
output += line
if 'at' in line:

View File

@@ -47,6 +47,7 @@ def send_request(
if retry_fns is not None:
for fn in retry_fns:
retry_condition |= retry_if_exception(fn)
kwargs["timeout"] = timeout
@retry(
stop=stop_after_delay(timeout),

View File

@@ -0,0 +1,60 @@
"""
This module monitors the app for shutdown signals
"""
import asyncio
import signal
import time
from types import FrameType
from uvicorn.server import HANDLED_SIGNALS
_should_exit = None
def _register_signal_handler(sig: signal.Signals):
original_handler = None
def handler(sig_: int, frame: FrameType | None):
global _should_exit
_should_exit = True
if original_handler:
original_handler(sig_, frame) # type: ignore[unreachable]
original_handler = signal.signal(sig, handler)
def _register_signal_handlers():
    """Register the shutdown handlers once, on first use.

    Does nothing if registration has already run (``_should_exit`` is no
    longer None). Otherwise marks the flag as False and installs a handler
    for every signal in ``HANDLED_SIGNALS``.

    Registration is only attempted from the main thread, because
    ``signal.signal`` raises ValueError anywhere else. When called from
    another thread the flag is left as None so that a later call from the
    main thread can still perform the registration.
    """
    import threading  # local import so this block does not depend on file-level imports

    global _should_exit
    if _should_exit is not None:
        return
    if threading.current_thread() is not threading.main_thread():
        return
    _should_exit = False
    for sig in HANDLED_SIGNALS:
        _register_signal_handler(sig)
def should_exit() -> bool:
    """Return True once a handled shutdown signal has been received.

    Triggers the one-time handler registration before reading the flag.
    """
    _register_signal_handlers()
    return True if _should_exit else False
def should_continue() -> bool:
    """Return True while no handled shutdown signal has been received.

    Triggers the one-time handler registration before reading the flag.
    """
    _register_signal_handlers()
    if _should_exit:
        return False
    return True
def sleep_if_should_continue(timeout: float):
    """Sleep for up to *timeout* seconds, returning early on shutdown.

    Waits of one second or less are passed straight to ``time.sleep``.
    Longer waits are split into slices of at most one second, and
    ``should_continue()`` is checked before each slice so that a shutdown
    request ends the wait early.

    Args:
        timeout: Maximum number of seconds to sleep.
    """
    if timeout <= 1:
        time.sleep(timeout)
        return
    # Monotonic clock: unaffected by wall-clock adjustments during the wait.
    deadline = time.monotonic() + timeout
    while should_continue():
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Cap the last slice at the time left; fixed 1s slices would overshoot
        # a non-integer timeout by up to a second.
        time.sleep(min(1.0, remaining))
async def async_sleep_if_should_continue(timeout: float):
    """Asynchronously sleep for up to *timeout* seconds, returning early on shutdown.

    Waits of one second or less are passed straight to ``asyncio.sleep``.
    Longer waits are split into slices of at most one second, and
    ``should_continue()`` is checked before each slice so that a shutdown
    request ends the wait early.

    Args:
        timeout: Maximum number of seconds to sleep.
    """
    if timeout <= 1:
        await asyncio.sleep(timeout)
        return
    # Monotonic clock: unaffected by wall-clock adjustments during the wait.
    deadline = time.monotonic() + timeout
    while should_continue():
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Cap the last slice at the time left; fixed 1s slices would overshoot
        # a non-integer timeout by up to a second.
        await asyncio.sleep(min(1.0, remaining))

View File

@@ -2,6 +2,7 @@ import uvicorn
from fastapi import FastAPI, WebSocket
from openhands.core.schema import ActionType
from openhands.runtime.utils.shutdown_listener import should_continue
app = FastAPI()
@@ -15,7 +16,7 @@ async def websocket_endpoint(websocket: WebSocket):
)
try:
while True:
while should_continue():
# receive message
data = await websocket.receive_json()
print(f'Received message: {data}')

View File

@@ -5,6 +5,7 @@ from fastapi import WebSocket
from openhands.core.config import AppConfig
from openhands.core.logger import openhands_logger as logger
from openhands.runtime.utils.shutdown_listener import should_continue
from openhands.server.session.session import Session
from openhands.storage.files import FileStore
@@ -47,7 +48,7 @@ class SessionManager:
return await self.send(sid, {'message': message})
async def _cleanup_sessions(self):
while True:
while should_continue():
current_time = time.time()
session_ids_to_remove = []
for sid, session in list(self._sessions.items()):

View File

@@ -20,6 +20,7 @@ from openhands.events.observation import (
from openhands.events.serialization import event_from_dict, event_to_dict
from openhands.events.stream import EventStreamSubscriber
from openhands.llm.llm import LLM
from openhands.runtime.utils.shutdown_listener import should_continue
from openhands.server.session.agent import AgentSession
from openhands.storage.files import FileStore
@@ -53,7 +54,7 @@ class Session:
try:
if self.websocket is None:
return
while True:
while should_continue():
try:
data = await self.websocket.receive_json()
except ValueError: