fix: Add missing type annotations in utils/ directory (#6687)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Graham Neubig 2025-02-21 08:27:57 -05:00 committed by GitHub
parent 35bab5070d
commit 9d3a0a02b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 25 additions and 11 deletions

View File

@ -1,6 +1,8 @@
from dataclasses import dataclass, field
from typing import Any, cast
import requests
from requests.structures import CaseInsensitiveDict
from openhands.core.logger import openhands_logger as logger
@ -15,13 +17,25 @@ class HttpSession:
session: requests.Session | None = field(default_factory=requests.Session)
def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
if self.session is None:
logger.error(
'Session is being used after close!', stack_info=True, exc_info=True
)
return object.__getattribute__(self.session, name)
raise RuntimeError('Session is being used after close!')
return getattr(self.session, name)
def close(self):
@property
def headers(self) -> CaseInsensitiveDict[str]:
if self.session is None:
logger.error(
'Session is being used after close!', stack_info=True, exc_info=True
)
raise RuntimeError('Session is being used after close!')
# Cast to CaseInsensitiveDict[str] since mypy doesn't know the exact type
return cast(CaseInsensitiveDict[str], self.session.headers)
def close(self) -> None:
if self.session is not None:
self.session.close()
self.session = None

View File

@ -109,7 +109,7 @@ class PromptManager:
if name not in self.disabled_microagents:
self.repo_microagents[name] = microagent
def load_microagents(self, microagents: list[BaseMicroAgent]):
def load_microagents(self, microagents: list[BaseMicroAgent]) -> None:
"""Load microagents from a list of BaseMicroAgents.
This is typically used when loading microagents from inside a repo.
@ -135,7 +135,7 @@ class PromptManager:
def get_system_message(self) -> str:
return self.system_template.render().strip()
def set_runtime_info(self, runtime: Runtime):
def set_runtime_info(self, runtime: Runtime) -> None:
self.runtime_info.available_hosts = runtime.web_hosts
def set_repository_info(

View File

@ -19,10 +19,10 @@ _should_exit = None
_shutdown_listeners: dict[UUID, Callable] = {}
def _register_signal_handler(sig: signal.Signals):
def _register_signal_handler(sig: signal.Signals) -> None:
original_handler = None
def handler(sig_: int, frame: FrameType | None):
def handler(sig_: int, frame: FrameType | None) -> None:
logger.debug(f'shutdown_signal:{sig_}')
global _should_exit
if not _should_exit:
@ -39,7 +39,7 @@ def _register_signal_handler(sig: signal.Signals):
original_handler = signal.signal(sig, handler)
def _register_signal_handlers():
def _register_signal_handlers() -> None:
global _should_exit
if _should_exit is not None:
return
@ -66,7 +66,7 @@ def should_continue() -> bool:
return not _should_exit
def sleep_if_should_continue(timeout: float):
def sleep_if_should_continue(timeout: float) -> None:
if timeout <= 1:
time.sleep(timeout)
return
@ -75,7 +75,7 @@ def sleep_if_should_continue(timeout: float):
time.sleep(1)
async def async_sleep_if_should_continue(timeout: float):
async def async_sleep_if_should_continue(timeout: float) -> None:
if timeout <= 1:
await asyncio.sleep(timeout)
return

View File

@ -8,4 +8,4 @@ class stop_if_should_exit(stop_base):
"""Stop if the should_exit flag is set."""
def __call__(self, retry_state: 'RetryCallState') -> bool:
return should_exit()
return bool(should_exit())