Enforce modern Python typing annotations with Ruff (#8296)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Graham Neubig 2025-05-06 07:58:33 -04:00 committed by GitHub
parent 4c1ae6fd8d
commit adfa510b5f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 79 additions and 72 deletions

View File

@ -26,7 +26,7 @@ repos:
- id: ruff
entry: ruff check --config dev_config/python/ruff.toml
types_or: [python, pyi, jupyter]
args: [--fix]
args: [--fix, --unsafe-fixes]
# Run the formatter.
- id: ruff-format
entry: ruff format --config dev_config/python/ruff.toml

View File

@ -7,6 +7,9 @@ select = [
"Q",
"B",
"ASYNC",
"UP006", # Use `list` instead of `List` for annotations
"UP007", # Use `X | Y` instead of `Union[X, Y]`
"UP008", # Use `X | None` instead of `Optional[X]`
]
ignore = [

View File

@ -57,7 +57,7 @@ describe("Browser", () => {
screenshotSrc: "",
};
});
it("renders a message if no screenshotSrc is provided", () => {
// Set the mock state for this test
mockBrowserState = {

View File

@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from litellm import ChatCompletionToolParam
from openhands.events.action import Action
from openhands.llm.llm import ModelResponse

View File

@ -3,11 +3,11 @@ ReadOnlyAgent - A specialized version of CodeActAgent that only uses read-only t
"""
import os
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from litellm import ChatCompletionToolParam
from openhands.events.action import Action
from openhands.llm.llm import ModelResponse

View File

@ -1,5 +1,4 @@
from pathlib import Path
from typing import Dict, List
import toml
@ -10,7 +9,7 @@ from openhands.events.event import Event
from openhands.llm.metrics import Metrics
_LOCAL_CONFIG_FILE_PATH = Path.home() / '.openhands' / 'config.toml'
_DEFAULT_CONFIG: Dict[str, Dict[str, List[str]]] = {'sandbox': {'trusted_dirs': []}}
_DEFAULT_CONFIG: dict[str, dict[str, list[str]]] = {'sandbox': {'trusted_dirs': []}}
def get_local_config_trusted_dirs() -> list[str]:

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Type
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from openhands.controller.state.state import State
@ -30,7 +30,7 @@ class Agent(ABC):
It tracks the execution status and maintains a history of interactions.
"""
_registry: dict[str, Type['Agent']] = {}
_registry: dict[str, type['Agent']] = {}
sandbox_plugins: list[PluginRequirement] = []
def __init__(
@ -118,7 +118,7 @@ class Agent(ABC):
return self.__class__.__name__
@classmethod
def register(cls, name: str, agent_cls: Type['Agent']) -> None:
def register(cls, name: str, agent_cls: type['Agent']) -> None:
"""Registers an agent class in the registry.
Parameters:
@ -133,7 +133,7 @@ class Agent(ABC):
cls._registry[name] = agent_cls
@classmethod
def get_cls(cls, name: str) -> Type['Agent']:
def get_cls(cls, name: str) -> type['Agent']:
"""Retrieves an agent class from the registry.
Parameters:

View File

@ -5,7 +5,7 @@ import copy
import os
import time
import traceback
from typing import Callable, ClassVar, Tuple, Type
from typing import Callable, ClassVar
import litellm # noqa
from litellm.exceptions import ( # noqa
@ -91,7 +91,7 @@ class AgentController:
agent_configs: dict[str, AgentConfig]
parent: 'AgentController | None' = None
delegate: 'AgentController | None' = None
_pending_action_info: Tuple[Action, float] | None = None # (action, timestamp)
_pending_action_info: tuple[Action, float] | None = None # (action, timestamp)
_closed: bool = False
filter_out: ClassVar[tuple[type[Event], ...]] = (
NullAction,
@ -675,7 +675,7 @@ class AgentController:
Args:
action (AgentDelegateAction): The action containing information about the delegate agent to start.
"""
agent_cls: Type[Agent] = Agent.get_cls(action.agent)
agent_cls: type[Agent] = Agent.get_cls(action.agent)
agent_config = self.agent_configs.get(action.agent, self.agent.config)
llm_config = self.agent_to_llm_config.get(action.agent, self.agent.llm.config)
llm = LLM(config=llm_config, retry_listener=self._notify_on_llm_retry)

View File

@ -1,7 +1,7 @@
import hashlib
import os
import uuid
from typing import Callable, Tuple, Type
from typing import Callable
from pydantic import SecretStr
@ -173,7 +173,7 @@ def create_memory(
def create_agent(config: AppConfig) -> Agent:
agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
agent_cls: type[Agent] = Agent.get_cls(config.default_agent)
agent_config = config.get_agent_config(config.default_agent)
llm_config = config.get_llm_config_from_agent(config.default_agent)
@ -191,7 +191,7 @@ def create_controller(
config: AppConfig,
headless_mode: bool = True,
replay_events: list[Event] | None = None,
) -> Tuple[AgentController, State | None]:
) -> tuple[AgentController, State | None]:
event_stream = runtime.event_stream
initial_state = None
try:

View File

@ -6,6 +6,11 @@ from typing import Any
import httpx
from pydantic import SecretStr
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.github.queries import (
suggested_task_issue_graphql_query,
suggested_task_pr_graphql_query,
)
from openhands.integrations.service_types import (
BaseGitService,
Branch,
@ -20,9 +25,6 @@ from openhands.integrations.service_types import (
)
from openhands.server.types import AppMode
from openhands.utils.import_utils import get_impl
from openhands.integrations.github.queries import suggested_task_pr_graphql_query, suggested_task_issue_graphql_query
from datetime import datetime
from openhands.core.logger import openhands_logger as logger
class GitHubService(BaseGitService, GitService):
@ -291,7 +293,7 @@ class GitHubService(BaseGitService, GitService):
Returns:
- PRs authored by the user.
- Issues assigned to the user.
Note: Queries are split to avoid timeout issues.
"""
# Get user info to use in queries
@ -301,9 +303,11 @@ class GitHubService(BaseGitService, GitService):
variables = {'login': login}
try:
pr_response = await self.execute_graphql_query(suggested_task_pr_graphql_query, variables)
pr_response = await self.execute_graphql_query(
suggested_task_pr_graphql_query, variables
)
pr_data = pr_response['data']['user']
# Process pull requests
for pr in pr_data['pullRequests']['nodes']:
repo_name = pr['repository']['nameWithOwner']
@ -341,16 +345,22 @@ class GitHubService(BaseGitService, GitService):
)
)
except Exception as e:
logger.info(f"Error fetching suggested task for PRs: {e}",
extra={'signal': 'github_suggested_tasks', 'user_id': self.external_auth_id})
logger.info(
f'Error fetching suggested task for PRs: {e}',
extra={
'signal': 'github_suggested_tasks',
'user_id': self.external_auth_id,
},
)
try:
# Execute issue query
issue_response = await self.execute_graphql_query(suggested_task_issue_graphql_query, variables)
issue_response = await self.execute_graphql_query(
suggested_task_issue_graphql_query, variables
)
issue_data = issue_response['data']['user']
# Process issues
for issue in issue_data['issues']['nodes']:
repo_name = issue['repository']['nameWithOwner']
@ -365,10 +375,15 @@ class GitHubService(BaseGitService, GitService):
)
return tasks
except Exception as e:
logger.info(f"Error fetching suggested task for issues: {e}",
extra={'signal': 'github_suggested_tasks', 'user_id': self.external_auth_id})
logger.info(
f'Error fetching suggested task for issues: {e}',
extra={
'signal': 'github_suggested_tasks',
'user_id': self.external_auth_id,
},
)
return tasks

View File

@ -29,7 +29,7 @@ suggested_task_pr_graphql_query = """
}
"""
suggested_task_issue_graphql_query = """
query GetUserIssues($login: String!) {
user(login: $login) {

View File

@ -1,5 +1,5 @@
Please summarize your work.
If you answered a question, please re-state the answer to the question
If you made changes, please create a concise overview on whether the request has been addressed successfully or if there are were issues with the attempt.
If successful, make sure your changes are pushed to the remote branch.

View File

@ -1,6 +1,6 @@
import asyncio
from contextlib import AsyncExitStack
from typing import Dict, List, Optional
from typing import Optional
from mcp import ClientSession
from mcp.client.sse import sse_client
@ -18,8 +18,8 @@ class MCPClient(BaseModel):
session: Optional[ClientSession] = None
exit_stack: AsyncExitStack = AsyncExitStack()
description: str = 'MCP client tools for server interaction'
tools: List[MCPClientTool] = Field(default_factory=list)
tool_map: Dict[str, MCPClientTool] = Field(default_factory=dict)
tools: list[MCPClientTool] = Field(default_factory=list)
tool_map: dict[str, MCPClientTool] = Field(default_factory=dict)
class Config:
arbitrary_types_allowed = True
@ -91,7 +91,7 @@ class MCPClient(BaseModel):
f'Connected to server with tools: {[tool.name for tool in response.tools]}'
)
async def call_tool(self, tool_name: str, args: Dict):
async def call_tool(self, tool_name: str, args: dict):
"""Call a tool on the MCP server."""
if tool_name not in self.tool_map:
raise ValueError(f'Tool {tool_name} not found.')

View File

@ -1,5 +1,3 @@
from typing import Dict
from mcp.types import Tool
@ -14,7 +12,7 @@ class MCPClientTool(Tool):
class Config:
arbitrary_types_allowed = True
def to_param(self) -> Dict:
def to_param(self) -> dict:
"""Convert tool to function call format."""
return {
'type': 'function',

View File

@ -6,11 +6,9 @@ class HunkException(PatchingException):
def __init__(self, msg: str, hunk: int | None = None) -> None:
self.hunk = hunk
if hunk is not None:
super(HunkException, self).__init__(
'{msg}, in hunk #{n}'.format(msg=msg, n=hunk)
)
super().__init__('{msg}, in hunk #{n}'.format(msg=msg, n=hunk))
else:
super(HunkException, self).__init__(msg)
super().__init__(msg)
class ApplyException(PatchingException):
@ -19,7 +17,7 @@ class ApplyException(PatchingException):
class SubprocessException(ApplyException):
def __init__(self, msg: str, code: int) -> None:
super(SubprocessException, self).__init__(msg)
super().__init__(msg)
self.code = code

View File

@ -1,5 +1,3 @@
from typing import Type
from openhands.runtime.base import Runtime
from openhands.runtime.impl.daytona.daytona_runtime import DaytonaRuntime
from openhands.runtime.impl.docker.docker_runtime import (
@ -13,7 +11,7 @@ from openhands.runtime.impl.runloop.runloop_runtime import RunloopRuntime
from openhands.utils.import_utils import get_impl
# mypy: disable-error-code="type-abstract"
_DEFAULT_RUNTIME_CLASSES: dict[str, Type[Runtime]] = {
_DEFAULT_RUNTIME_CLASSES: dict[str, type[Runtime]] = {
'eventstream': DockerRuntime,
'docker': DockerRuntime,
'e2b': E2BRuntime,
@ -25,7 +23,7 @@ _DEFAULT_RUNTIME_CLASSES: dict[str, Type[Runtime]] = {
}
def get_runtime_cls(name: str) -> Type[Runtime]:
def get_runtime_cls(name: str) -> type[Runtime]:
"""
If name is one of the predefined runtime names (e.g. 'docker'), return its class.
Otherwise attempt to resolve name as subclass of Runtime and return it.

View File

@ -5,7 +5,6 @@ This server has no authentication and only listens to localhost traffic.
import os
import threading
from typing import Tuple
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse
@ -75,7 +74,7 @@ def create_app() -> FastAPI:
return app
def start_file_viewer_server(port: int) -> Tuple[str, threading.Thread]:
def start_file_viewer_server(port: int) -> tuple[str, threading.Thread]:
"""Start the file viewer server on the specified port or find an available one.
Args:

View File

@ -6,7 +6,6 @@ import string
import tempfile
from enum import Enum
from pathlib import Path
from typing import List
import docker
from dirhash import dirhash # type: ignore
@ -111,7 +110,7 @@ def build_runtime_image(
build_folder: str | None = None,
dry_run: bool = False,
force_rebuild: bool = False,
extra_build_args: List[str] | None = None,
extra_build_args: list[str] | None = None,
) -> str:
"""Prepares the final docker build folder.
@ -167,7 +166,7 @@ def build_runtime_image_in_folder(
dry_run: bool,
force_rebuild: bool,
platform: str | None = None,
extra_build_args: List[str] | None = None,
extra_build_args: list[str] | None = None,
) -> str:
runtime_image_repo, _ = get_runtime_image_repo_and_tag(base_image)
lock_tag = f'oh_v{oh_version}_{get_hash_for_lock_files(base_image)}'
@ -294,7 +293,7 @@ _ALPHABET = string.digits + string.ascii_lowercase
def truncate_hash(hash: str) -> str:
"""Convert the base16 hash to base36 and truncate at 16 characters."""
value = int(hash, 16)
result: List[str] = []
result: list[str] = []
while value > 0 and len(result) < 16:
value, remainder = divmod(value, len(_ALPHABET))
result.append(_ALPHABET[remainder])
@ -347,7 +346,7 @@ def _build_sandbox_image(
lock_tag: str,
versioned_tag: str | None,
platform: str | None = None,
extra_build_args: List[str] | None = None,
extra_build_args: list[str] | None = None,
) -> str:
"""Build and tag the sandbox image. The image will be tagged with all tags that do not yet exist."""
names = [

View File

@ -1,8 +1,6 @@
from typing import Type
from openhands.security.analyzer import SecurityAnalyzer
from openhands.security.invariant.analyzer import InvariantAnalyzer
SecurityAnalyzers: dict[str, Type[SecurityAnalyzer]] = {
SecurityAnalyzers: dict[str, type[SecurityAnalyzer]] = {
'invariant': InvariantAnalyzer,
}

View File

@ -2,7 +2,7 @@ import asyncio
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Callable, Iterable, Type
from typing import Callable, Iterable
import socketio
@ -52,7 +52,7 @@ class StandaloneConversationManager(ConversationManager):
)
_conversations_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
_cleanup_task: asyncio.Task | None = None
_conversation_store_class: Type | None = None
_conversation_store_class: type[ConversationStore] | None = None
async def __aenter__(self):
self._cleanup_task = asyncio.create_task(self._cleanup_stale())

View File

@ -1,5 +1,5 @@
import os
from typing import Any, List, TypedDict
from typing import Any, TypedDict
import boto3
import botocore
@ -16,7 +16,7 @@ class GetObjectOutputDict(TypedDict):
class ListObjectsV2OutputDict(TypedDict):
Contents: List[S3ObjectDict] | None
Contents: list[S3ObjectDict] | None
class S3FileStore(FileStore):

View File

@ -1,7 +1,7 @@
import asyncio
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Coroutine, Iterable, List
from typing import Callable, Coroutine, Iterable
GENERAL_TIMEOUT: int = 15
EXECUTOR = ThreadPoolExecutor()
@ -64,7 +64,7 @@ async def call_coro_in_bg_thread(
async def wait_all(
iterable: Iterable[Coroutine], timeout: int = GENERAL_TIMEOUT
) -> List:
) -> list:
"""
Shorthand for waiting for all the coroutines in the iterable given in parallel. Creates
a task for each coroutine.

View File

@ -1,6 +1,6 @@
import importlib
from functools import lru_cache
from typing import Type, TypeVar
from typing import TypeVar
T = TypeVar('T')
@ -15,7 +15,7 @@ def import_from(qual_name: str):
@lru_cache()
def get_impl(cls: Type[T], impl_name: str | None) -> Type[T]:
def get_impl(cls: type[T], impl_name: str | None) -> type[T]:
"""Import a named implementation of the specified class"""
if impl_name is None:
return cls

View File

@ -5,7 +5,6 @@ import shutil
from abc import ABC
from dataclasses import dataclass, field
from io import BytesIO, StringIO
from typing import Dict, List
from unittest import TestCase
from unittest.mock import patch
@ -143,12 +142,12 @@ class _MockGoogleCloudClient:
@dataclass
class _MockGoogleCloudBucket:
blobs_by_path: Dict[str, _MockGoogleCloudBlob] = field(default_factory=dict)
blobs_by_path: dict[str, _MockGoogleCloudBlob] = field(default_factory=dict)
def blob(self, path: str | None = None) -> _MockGoogleCloudBlob:
return self.blobs_by_path.get(path) or _MockGoogleCloudBlob(self, path)
def list_blobs(self, prefix: str | None = None) -> List[_MockGoogleCloudBlob]:
def list_blobs(self, prefix: str | None = None) -> list[_MockGoogleCloudBlob]:
blobs = list(self.blobs_by_path.values())
if prefix and prefix != '/':
blobs = [blob for blob in blobs if blob.name.startswith(prefix)]
@ -197,14 +196,14 @@ class _MockGoogleCloudBlobWriter:
class _MockS3Client:
def __init__(self):
self.objects_by_bucket: Dict[str, Dict[str, _MockS3Object]] = {}
self.objects_by_bucket: dict[str, dict[str, _MockS3Object]] = {}
def put_object(self, Bucket: str, Key: str, Body: str | bytes) -> None:
if Bucket not in self.objects_by_bucket:
self.objects_by_bucket[Bucket] = {}
self.objects_by_bucket[Bucket][Key] = _MockS3Object(Key, Body)
def get_object(self, Bucket: str, Key: str) -> Dict:
def get_object(self, Bucket: str, Key: str) -> dict:
if Bucket not in self.objects_by_bucket:
raise botocore.exceptions.ClientError(
{
@ -230,7 +229,7 @@ class _MockS3Client:
return {'Body': BytesIO(content)}
return {'Body': StringIO(content)}
def list_objects_v2(self, Bucket: str, Prefix: str = '') -> Dict:
def list_objects_v2(self, Bucket: str, Prefix: str = '') -> dict:
if Bucket not in self.objects_by_bucket:
raise botocore.exceptions.ClientError(
{