Add type annotations to local runtime implementation (#8376)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Graham Neubig 2025-05-13 09:42:07 -04:00 committed by GitHub
parent dea3ddfcc6
commit f3d0ae3fbf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 21 additions and 16 deletions

View File

@ -0,0 +1,5 @@
"""Local runtime implementation."""
from openhands.runtime.impl.local.local_runtime import LocalRuntime
__all__ = ['LocalRuntime']

View File

@ -1,6 +1,4 @@
"""
This runtime runs the action_execution_server directly on the local machine without Docker.
"""
"""This runtime runs the action_execution_server directly on the local machine without Docker."""
import os
import shutil
@ -41,7 +39,7 @@ from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.tenacity_stop import stop_if_should_exit
def get_user_info():
def get_user_info() -> tuple[int, str | None]:
"""Get user ID and username in a cross-platform way."""
username = os.getenv('USER')
if sys.platform == 'win32':
@ -53,7 +51,7 @@ def get_user_info():
return os.getuid(), username
def check_dependencies(code_repo_path: str, poetry_venvs_path: str):
def check_dependencies(code_repo_path: str, poetry_venvs_path: str) -> None:
ERROR_MESSAGE = 'Please follow the instructions in https://github.com/All-Hands-AI/OpenHands/blob/main/Development.md to install OpenHands.'
if not os.path.exists(code_repo_path):
raise ValueError(
@ -112,7 +110,7 @@ class LocalRuntime(ActionExecutionClient):
config (AppConfig): The application configuration.
event_stream (EventStream): The event stream to subscribe to.
sid (str, optional): The session ID. Defaults to 'default'.
plugins (list[PluginRequirement] | None, optional): List of plugin requirements. Defaults to None.
plugins (list[PluginRequirement] | None, optional): list of plugin requirements. Defaults to None.
env_vars (dict[str, str] | None, optional): Environment variables to set. Defaults to None.
"""
@ -123,10 +121,10 @@ class LocalRuntime(ActionExecutionClient):
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
env_vars: dict[str, str] | None = None,
status_callback: Callable | None = None,
status_callback: Callable[[str, str, str], None] | None = None,
attach_to_existing: bool = False,
headless_mode: bool = True,
):
) -> None:
self.is_windows = sys.platform == 'win32'
if self.is_windows:
logger.warning(
@ -203,10 +201,10 @@ class LocalRuntime(ActionExecutionClient):
)
@property
def action_execution_server_url(self):
def action_execution_server_url(self) -> str:
return self.api_url
async def connect(self):
async def connect(self) -> None:
"""Start the action_execution_server on the local machine."""
self.send_status_message('STATUS$STARTING_RUNTIME')
@ -247,7 +245,7 @@ class LocalRuntime(ActionExecutionClient):
# Check dependencies using the derived env_root_path
check_dependencies(code_repo_path, env_root_path)
self.server_process = subprocess.Popen( # noqa: ASYNC101
self.server_process = subprocess.Popen( # noqa: S603
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
@ -258,7 +256,7 @@ class LocalRuntime(ActionExecutionClient):
)
# Start a thread to read and log server output
def log_output():
def log_output() -> None:
if not self.server_process or not self.server_process.stdout:
self.log('error', 'Server process or stdout not available for logging.')
return
@ -318,7 +316,9 @@ class LocalRuntime(ActionExecutionClient):
self.send_status_message(' ')
self._runtime_initialized = True
def _find_available_port(self, port_range, max_attempts=5):
def _find_available_port(
self, port_range: tuple[int, int], max_attempts: int = 5
) -> int:
port = port_range[1]
for _ in range(max_attempts):
port = find_available_tcp_port(port_range[0], port_range[1])
@ -332,7 +332,7 @@ class LocalRuntime(ActionExecutionClient):
f'Waiting for server to be ready... (attempt {retry_state.attempt_number})'
),
)
def _wait_until_alive(self):
def _wait_until_alive(self) -> bool:
"""Wait until the server is ready to accept requests."""
if self.server_process and self.server_process.poll() is not None:
raise RuntimeError('Server process died')
@ -365,7 +365,7 @@ class LocalRuntime(ActionExecutionClient):
except httpx.NetworkError:
raise AgentRuntimeDisconnectedError('Server connection lost')
def close(self):
def close(self) -> None:
"""Stop the server process."""
self._log_thread_exit_event.set() # Signal the log thread to exit
@ -392,7 +392,7 @@ class LocalRuntime(ActionExecutionClient):
return vscode_url
@property
def web_hosts(self):
def web_hosts(self) -> dict[str, int]:
hosts: dict[str, int] = {}
for port in self._app_ports:
hosts[f'http://localhost:{port}'] = port