Enforce modern Python typing annotations with Ruff (#8296)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Graham Neubig 2025-05-06 07:58:33 -04:00 committed by GitHub
parent 4c1ae6fd8d
commit adfa510b5f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 79 additions and 72 deletions

View File

@ -26,7 +26,7 @@ repos:
- id: ruff
entry: ruff check --config dev_config/python/ruff.toml
types_or: [python, pyi, jupyter]
args: [--fix]
args: [--fix, --unsafe-fixes]
# Run the formatter.
- id: ruff-format
entry: ruff format --config dev_config/python/ruff.toml

View File

@ -7,6 +7,9 @@ select = [
"Q",
"B",
"ASYNC",
"UP006", # Use `list` instead of `List` for annotations
"UP007", # Use `X | Y` instead of `Union[X, Y]`
"UP008", # Use `X | None` instead of `Optional[X]`
]
ignore = [

View File

@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from litellm import ChatCompletionToolParam
from openhands.events.action import Action
from openhands.llm.llm import ModelResponse

View File

@ -3,11 +3,11 @@ ReadOnlyAgent - A specialized version of CodeActAgent that only uses read-only t
"""
import os
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from litellm import ChatCompletionToolParam
from openhands.events.action import Action
from openhands.llm.llm import ModelResponse

View File

@ -1,5 +1,4 @@
from pathlib import Path
from typing import Dict, List
import toml
@ -10,7 +9,7 @@ from openhands.events.event import Event
from openhands.llm.metrics import Metrics
_LOCAL_CONFIG_FILE_PATH = Path.home() / '.openhands' / 'config.toml'
_DEFAULT_CONFIG: Dict[str, Dict[str, List[str]]] = {'sandbox': {'trusted_dirs': []}}
_DEFAULT_CONFIG: dict[str, dict[str, list[str]]] = {'sandbox': {'trusted_dirs': []}}
def get_local_config_trusted_dirs() -> list[str]:

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Type
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from openhands.controller.state.state import State
@ -30,7 +30,7 @@ class Agent(ABC):
It tracks the execution status and maintains a history of interactions.
"""
_registry: dict[str, Type['Agent']] = {}
_registry: dict[str, type['Agent']] = {}
sandbox_plugins: list[PluginRequirement] = []
def __init__(
@ -118,7 +118,7 @@ class Agent(ABC):
return self.__class__.__name__
@classmethod
def register(cls, name: str, agent_cls: Type['Agent']) -> None:
def register(cls, name: str, agent_cls: type['Agent']) -> None:
"""Registers an agent class in the registry.
Parameters:
@ -133,7 +133,7 @@ class Agent(ABC):
cls._registry[name] = agent_cls
@classmethod
def get_cls(cls, name: str) -> Type['Agent']:
def get_cls(cls, name: str) -> type['Agent']:
"""Retrieves an agent class from the registry.
Parameters:

View File

@ -5,7 +5,7 @@ import copy
import os
import time
import traceback
from typing import Callable, ClassVar, Tuple, Type
from typing import Callable, ClassVar
import litellm # noqa
from litellm.exceptions import ( # noqa
@ -91,7 +91,7 @@ class AgentController:
agent_configs: dict[str, AgentConfig]
parent: 'AgentController | None' = None
delegate: 'AgentController | None' = None
_pending_action_info: Tuple[Action, float] | None = None # (action, timestamp)
_pending_action_info: tuple[Action, float] | None = None # (action, timestamp)
_closed: bool = False
filter_out: ClassVar[tuple[type[Event], ...]] = (
NullAction,
@ -675,7 +675,7 @@ class AgentController:
Args:
action (AgentDelegateAction): The action containing information about the delegate agent to start.
"""
agent_cls: Type[Agent] = Agent.get_cls(action.agent)
agent_cls: type[Agent] = Agent.get_cls(action.agent)
agent_config = self.agent_configs.get(action.agent, self.agent.config)
llm_config = self.agent_to_llm_config.get(action.agent, self.agent.llm.config)
llm = LLM(config=llm_config, retry_listener=self._notify_on_llm_retry)

View File

@ -1,7 +1,7 @@
import hashlib
import os
import uuid
from typing import Callable, Tuple, Type
from typing import Callable
from pydantic import SecretStr
@ -173,7 +173,7 @@ def create_memory(
def create_agent(config: AppConfig) -> Agent:
agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
agent_cls: type[Agent] = Agent.get_cls(config.default_agent)
agent_config = config.get_agent_config(config.default_agent)
llm_config = config.get_llm_config_from_agent(config.default_agent)
@ -191,7 +191,7 @@ def create_controller(
config: AppConfig,
headless_mode: bool = True,
replay_events: list[Event] | None = None,
) -> Tuple[AgentController, State | None]:
) -> tuple[AgentController, State | None]:
event_stream = runtime.event_stream
initial_state = None
try:

View File

@ -6,6 +6,11 @@ from typing import Any
import httpx
from pydantic import SecretStr
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.github.queries import (
suggested_task_issue_graphql_query,
suggested_task_pr_graphql_query,
)
from openhands.integrations.service_types import (
BaseGitService,
Branch,
@ -20,9 +25,6 @@ from openhands.integrations.service_types import (
)
from openhands.server.types import AppMode
from openhands.utils.import_utils import get_impl
from openhands.integrations.github.queries import suggested_task_pr_graphql_query, suggested_task_issue_graphql_query
from datetime import datetime
from openhands.core.logger import openhands_logger as logger
class GitHubService(BaseGitService, GitService):
@ -301,7 +303,9 @@ class GitHubService(BaseGitService, GitService):
variables = {'login': login}
try:
pr_response = await self.execute_graphql_query(suggested_task_pr_graphql_query, variables)
pr_response = await self.execute_graphql_query(
suggested_task_pr_graphql_query, variables
)
pr_data = pr_response['data']['user']
# Process pull requests
@ -341,14 +345,20 @@ class GitHubService(BaseGitService, GitService):
)
)
except Exception as e:
logger.info(f"Error fetching suggested task for PRs: {e}",
extra={'signal': 'github_suggested_tasks', 'user_id': self.external_auth_id})
logger.info(
f'Error fetching suggested task for PRs: {e}',
extra={
'signal': 'github_suggested_tasks',
'user_id': self.external_auth_id,
},
)
try:
# Execute issue query
issue_response = await self.execute_graphql_query(suggested_task_issue_graphql_query, variables)
issue_response = await self.execute_graphql_query(
suggested_task_issue_graphql_query, variables
)
issue_data = issue_response['data']['user']
# Process issues
@ -367,8 +377,13 @@ class GitHubService(BaseGitService, GitService):
return tasks
except Exception as e:
logger.info(f"Error fetching suggested task for issues: {e}",
extra={'signal': 'github_suggested_tasks', 'user_id': self.external_auth_id})
logger.info(
f'Error fetching suggested task for issues: {e}',
extra={
'signal': 'github_suggested_tasks',
'user_id': self.external_auth_id,
},
)
return tasks

View File

@ -1,6 +1,6 @@
import asyncio
from contextlib import AsyncExitStack
from typing import Dict, List, Optional
from typing import Optional
from mcp import ClientSession
from mcp.client.sse import sse_client
@ -18,8 +18,8 @@ class MCPClient(BaseModel):
session: Optional[ClientSession] = None
exit_stack: AsyncExitStack = AsyncExitStack()
description: str = 'MCP client tools for server interaction'
tools: List[MCPClientTool] = Field(default_factory=list)
tool_map: Dict[str, MCPClientTool] = Field(default_factory=dict)
tools: list[MCPClientTool] = Field(default_factory=list)
tool_map: dict[str, MCPClientTool] = Field(default_factory=dict)
class Config:
arbitrary_types_allowed = True
@ -91,7 +91,7 @@ class MCPClient(BaseModel):
f'Connected to server with tools: {[tool.name for tool in response.tools]}'
)
async def call_tool(self, tool_name: str, args: Dict):
async def call_tool(self, tool_name: str, args: dict):
"""Call a tool on the MCP server."""
if tool_name not in self.tool_map:
raise ValueError(f'Tool {tool_name} not found.')

View File

@ -1,5 +1,3 @@
from typing import Dict
from mcp.types import Tool
@ -14,7 +12,7 @@ class MCPClientTool(Tool):
class Config:
arbitrary_types_allowed = True
def to_param(self) -> Dict:
def to_param(self) -> dict:
"""Convert tool to function call format."""
return {
'type': 'function',

View File

@ -6,11 +6,9 @@ class HunkException(PatchingException):
def __init__(self, msg: str, hunk: int | None = None) -> None:
self.hunk = hunk
if hunk is not None:
super(HunkException, self).__init__(
'{msg}, in hunk #{n}'.format(msg=msg, n=hunk)
)
super().__init__('{msg}, in hunk #{n}'.format(msg=msg, n=hunk))
else:
super(HunkException, self).__init__(msg)
super().__init__(msg)
class ApplyException(PatchingException):
@ -19,7 +17,7 @@ class ApplyException(PatchingException):
class SubprocessException(ApplyException):
def __init__(self, msg: str, code: int) -> None:
super(SubprocessException, self).__init__(msg)
super().__init__(msg)
self.code = code

View File

@ -1,5 +1,3 @@
from typing import Type
from openhands.runtime.base import Runtime
from openhands.runtime.impl.daytona.daytona_runtime import DaytonaRuntime
from openhands.runtime.impl.docker.docker_runtime import (
@ -13,7 +11,7 @@ from openhands.runtime.impl.runloop.runloop_runtime import RunloopRuntime
from openhands.utils.import_utils import get_impl
# mypy: disable-error-code="type-abstract"
_DEFAULT_RUNTIME_CLASSES: dict[str, Type[Runtime]] = {
_DEFAULT_RUNTIME_CLASSES: dict[str, type[Runtime]] = {
'eventstream': DockerRuntime,
'docker': DockerRuntime,
'e2b': E2BRuntime,
@ -25,7 +23,7 @@ _DEFAULT_RUNTIME_CLASSES: dict[str, Type[Runtime]] = {
}
def get_runtime_cls(name: str) -> Type[Runtime]:
def get_runtime_cls(name: str) -> type[Runtime]:
"""
If name is one of the predefined runtime names (e.g. 'docker'), return its class.
Otherwise attempt to resolve name as subclass of Runtime and return it.

View File

@ -5,7 +5,6 @@ This server has no authentication and only listens to localhost traffic.
import os
import threading
from typing import Tuple
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse
@ -75,7 +74,7 @@ def create_app() -> FastAPI:
return app
def start_file_viewer_server(port: int) -> Tuple[str, threading.Thread]:
def start_file_viewer_server(port: int) -> tuple[str, threading.Thread]:
"""Start the file viewer server on the specified port or find an available one.
Args:

View File

@ -6,7 +6,6 @@ import string
import tempfile
from enum import Enum
from pathlib import Path
from typing import List
import docker
from dirhash import dirhash # type: ignore
@ -111,7 +110,7 @@ def build_runtime_image(
build_folder: str | None = None,
dry_run: bool = False,
force_rebuild: bool = False,
extra_build_args: List[str] | None = None,
extra_build_args: list[str] | None = None,
) -> str:
"""Prepares the final docker build folder.
@ -167,7 +166,7 @@ def build_runtime_image_in_folder(
dry_run: bool,
force_rebuild: bool,
platform: str | None = None,
extra_build_args: List[str] | None = None,
extra_build_args: list[str] | None = None,
) -> str:
runtime_image_repo, _ = get_runtime_image_repo_and_tag(base_image)
lock_tag = f'oh_v{oh_version}_{get_hash_for_lock_files(base_image)}'
@ -294,7 +293,7 @@ _ALPHABET = string.digits + string.ascii_lowercase
def truncate_hash(hash: str) -> str:
"""Convert the base16 hash to base36 and truncate at 16 characters."""
value = int(hash, 16)
result: List[str] = []
result: list[str] = []
while value > 0 and len(result) < 16:
value, remainder = divmod(value, len(_ALPHABET))
result.append(_ALPHABET[remainder])
@ -347,7 +346,7 @@ def _build_sandbox_image(
lock_tag: str,
versioned_tag: str | None,
platform: str | None = None,
extra_build_args: List[str] | None = None,
extra_build_args: list[str] | None = None,
) -> str:
"""Build and tag the sandbox image. The image will be tagged with all tags that do not yet exist."""
names = [

View File

@ -1,8 +1,6 @@
from typing import Type
from openhands.security.analyzer import SecurityAnalyzer
from openhands.security.invariant.analyzer import InvariantAnalyzer
SecurityAnalyzers: dict[str, Type[SecurityAnalyzer]] = {
SecurityAnalyzers: dict[str, type[SecurityAnalyzer]] = {
'invariant': InvariantAnalyzer,
}

View File

@ -2,7 +2,7 @@ import asyncio
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Callable, Iterable, Type
from typing import Callable, Iterable
import socketio
@ -52,7 +52,7 @@ class StandaloneConversationManager(ConversationManager):
)
_conversations_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
_cleanup_task: asyncio.Task | None = None
_conversation_store_class: Type | None = None
_conversation_store_class: type[ConversationStore] | None = None
async def __aenter__(self):
self._cleanup_task = asyncio.create_task(self._cleanup_stale())

View File

@ -1,5 +1,5 @@
import os
from typing import Any, List, TypedDict
from typing import Any, TypedDict
import boto3
import botocore
@ -16,7 +16,7 @@ class GetObjectOutputDict(TypedDict):
class ListObjectsV2OutputDict(TypedDict):
Contents: List[S3ObjectDict] | None
Contents: list[S3ObjectDict] | None
class S3FileStore(FileStore):

View File

@ -1,7 +1,7 @@
import asyncio
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Coroutine, Iterable, List
from typing import Callable, Coroutine, Iterable
GENERAL_TIMEOUT: int = 15
EXECUTOR = ThreadPoolExecutor()
@ -64,7 +64,7 @@ async def call_coro_in_bg_thread(
async def wait_all(
iterable: Iterable[Coroutine], timeout: int = GENERAL_TIMEOUT
) -> List:
) -> list:
"""
Shorthand for waiting for all the coroutines in the iterable given in parallel. Creates
a task for each coroutine.

View File

@ -1,6 +1,6 @@
import importlib
from functools import lru_cache
from typing import Type, TypeVar
from typing import TypeVar
T = TypeVar('T')
@ -15,7 +15,7 @@ def import_from(qual_name: str):
@lru_cache()
def get_impl(cls: Type[T], impl_name: str | None) -> Type[T]:
def get_impl(cls: type[T], impl_name: str | None) -> type[T]:
"""Import a named implementation of the specified class"""
if impl_name is None:
return cls

View File

@ -5,7 +5,6 @@ import shutil
from abc import ABC
from dataclasses import dataclass, field
from io import BytesIO, StringIO
from typing import Dict, List
from unittest import TestCase
from unittest.mock import patch
@ -143,12 +142,12 @@ class _MockGoogleCloudClient:
@dataclass
class _MockGoogleCloudBucket:
blobs_by_path: Dict[str, _MockGoogleCloudBlob] = field(default_factory=dict)
blobs_by_path: dict[str, _MockGoogleCloudBlob] = field(default_factory=dict)
def blob(self, path: str | None = None) -> _MockGoogleCloudBlob:
return self.blobs_by_path.get(path) or _MockGoogleCloudBlob(self, path)
def list_blobs(self, prefix: str | None = None) -> List[_MockGoogleCloudBlob]:
def list_blobs(self, prefix: str | None = None) -> list[_MockGoogleCloudBlob]:
blobs = list(self.blobs_by_path.values())
if prefix and prefix != '/':
blobs = [blob for blob in blobs if blob.name.startswith(prefix)]
@ -197,14 +196,14 @@ class _MockGoogleCloudBlobWriter:
class _MockS3Client:
def __init__(self):
self.objects_by_bucket: Dict[str, Dict[str, _MockS3Object]] = {}
self.objects_by_bucket: dict[str, dict[str, _MockS3Object]] = {}
def put_object(self, Bucket: str, Key: str, Body: str | bytes) -> None:
if Bucket not in self.objects_by_bucket:
self.objects_by_bucket[Bucket] = {}
self.objects_by_bucket[Bucket][Key] = _MockS3Object(Key, Body)
def get_object(self, Bucket: str, Key: str) -> Dict:
def get_object(self, Bucket: str, Key: str) -> dict:
if Bucket not in self.objects_by_bucket:
raise botocore.exceptions.ClientError(
{
@ -230,7 +229,7 @@ class _MockS3Client:
return {'Body': BytesIO(content)}
return {'Body': StringIO(content)}
def list_objects_v2(self, Bucket: str, Prefix: str = '') -> Dict:
def list_objects_v2(self, Bucket: str, Prefix: str = '') -> dict:
if Bucket not in self.objects_by_bucket:
raise botocore.exceptions.ClientError(
{