diff --git a/openhands/utils/http_session.py b/openhands/utils/http_session.py index 9421d19dbb..cf95a35879 100644 --- a/openhands/utils/http_session.py +++ b/openhands/utils/http_session.py @@ -1,6 +1,8 @@ from dataclasses import dataclass, field +from typing import Any, cast import requests +from requests.structures import CaseInsensitiveDict from openhands.core.logger import openhands_logger as logger @@ -15,13 +17,25 @@ class HttpSession: session: requests.Session | None = field(default_factory=requests.Session) - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: if self.session is None: logger.error( 'Session is being used after close!', stack_info=True, exc_info=True ) - return object.__getattribute__(self.session, name) + raise RuntimeError('Session is being used after close!') + return getattr(self.session, name) - def close(self): + @property + def headers(self) -> CaseInsensitiveDict[str]: + if self.session is None: + logger.error( + 'Session is being used after close!', stack_info=True, exc_info=True + ) + raise RuntimeError('Session is being used after close!') + # Cast to CaseInsensitiveDict[str] since mypy doesn't know the exact type + return cast(CaseInsensitiveDict[str], self.session.headers) + + def close(self) -> None: if self.session is not None: self.session.close() + self.session = None diff --git a/openhands/utils/prompt.py b/openhands/utils/prompt.py index 7fc5d46382..47123345ad 100644 --- a/openhands/utils/prompt.py +++ b/openhands/utils/prompt.py @@ -109,7 +109,7 @@ class PromptManager: if name not in self.disabled_microagents: self.repo_microagents[name] = microagent - def load_microagents(self, microagents: list[BaseMicroAgent]): + def load_microagents(self, microagents: list[BaseMicroAgent]) -> None: """Load microagents from a list of BaseMicroAgents. This is typically used when loading microagents from inside a repo. @@ -135,7 +135,7 @@ class PromptManager: def get_system_message(self) -> str: return self.system_template.render().strip() - def set_runtime_info(self, runtime: Runtime): + def set_runtime_info(self, runtime: Runtime) -> None: self.runtime_info.available_hosts = runtime.web_hosts def set_repository_info( diff --git a/openhands/utils/shutdown_listener.py b/openhands/utils/shutdown_listener.py index eddaac54f0..ac99094ce6 100644 --- a/openhands/utils/shutdown_listener.py +++ b/openhands/utils/shutdown_listener.py @@ -19,10 +19,10 @@ _should_exit = None _shutdown_listeners: dict[UUID, Callable] = {} -def _register_signal_handler(sig: signal.Signals): +def _register_signal_handler(sig: signal.Signals) -> None: original_handler = None - def handler(sig_: int, frame: FrameType | None): + def handler(sig_: int, frame: FrameType | None) -> None: logger.debug(f'shutdown_signal:{sig_}') global _should_exit if not _should_exit: @@ -39,7 +39,7 @@ def _register_signal_handler(sig: signal.Signals): original_handler = signal.signal(sig, handler) -def _register_signal_handlers(): +def _register_signal_handlers() -> None: global _should_exit if _should_exit is not None: return @@ -66,7 +66,7 @@ def should_continue() -> bool: return not _should_exit -def sleep_if_should_continue(timeout: float): +def sleep_if_should_continue(timeout: float) -> None: if timeout <= 1: time.sleep(timeout) return @@ -75,7 +75,7 @@ def sleep_if_should_continue(timeout: float): time.sleep(1) -async def async_sleep_if_should_continue(timeout: float): +async def async_sleep_if_should_continue(timeout: float) -> None: if timeout <= 1: await asyncio.sleep(timeout) return diff --git a/openhands/utils/tenacity_stop.py b/openhands/utils/tenacity_stop.py index d9aa83a613..d877c1f27e 100644 --- a/openhands/utils/tenacity_stop.py +++ b/openhands/utils/tenacity_stop.py @@ -8,4 +8,4 @@ class stop_if_should_exit(stop_base): """Stop if the should_exit flag is set.""" def __call__(self, retry_state: 'RetryCallState') -> bool: - return should_exit() + return bool(should_exit())