Fix typing (#7083)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Rohit Malhotra 2025-03-03 15:41:11 -05:00 committed by GitHub
parent 4e4f4d64f8
commit 5ffb1ef704
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 31 additions and 25 deletions

View File

@@ -7,7 +7,7 @@ import os
import re
from dataclasses import dataclass
from enum import Enum, auto
from typing import Dict, List, Optional, Union
from typing import Dict, List, Union
from openhands.core.logger import openhands_logger as logger
from openhands.events.action import BrowseInteractiveAction
@@ -133,7 +133,7 @@ def parse_content_to_elements(content: str) -> Dict[str, str]:
return elements
def find_matching_anchor(content: str, selector: str) -> Optional[str]:
def find_matching_anchor(content: str, selector: str) -> str | None:
"""Find the anchor ID that matches the given selector description"""
elements = parse_content_to_elements(content)

View File

@@ -53,4 +53,3 @@ To verify Docker is working correctly, run the hello-world container:
```bash
sudo docker run hello-world
```

View File

@@ -7,6 +7,7 @@ OpenHands uses its own `Message` class (`openhands/core/message.py`) which provi
## Class Structure
Our `Message` class (`openhands/core/message.py`):
```python
class Message(BaseModel):
role: Literal['user', 'system', 'assistant', 'tool']
@@ -22,13 +23,14 @@ class Message(BaseModel):
```
litellm's `Message` class (`litellm/types/utils.py`):
```python
class Message(OpenAIObject):
content: Optional[str]
content: str | None
role: Literal["assistant", "user", "system", "tool", "function"]
tool_calls: Optional[List[ChatCompletionMessageToolCall]]
function_call: Optional[FunctionCall]
audio: Optional[ChatCompletionAudioResponse] = None
tool_calls: List[ChatCompletionMessageToolCall] | None
function_call: FunctionCall | None
audio: ChatCompletionAudioResponse | None = None
```
## How It Works
@@ -36,6 +38,7 @@ class Message(OpenAIObject):
1. **Message Creation**: Our `Message` class is a Pydantic model that supports rich content (text and images) through its `content` field.
2. **Serialization**: The class uses Pydantic's `@model_serializer` to convert messages into dictionaries that litellm can understand. We have two serialization methods:
```python
def _string_serializer(self) -> dict:
# convert content to a single string
@@ -55,6 +58,7 @@ class Message(OpenAIObject):
```
The appropriate serializer is chosen based on the message's capabilities:
```python
@model_serializer
def serialize_model(self) -> dict:
@@ -64,11 +68,13 @@ class Message(OpenAIObject):
```
3. **Tool Call Handling**: Tool calls require special attention in serialization because:
- They need to work with litellm's API calls (which accept both dicts and objects)
- They need to be properly serialized for token counting
- They need to maintain compatibility with different LLM providers' formats
4. **litellm Integration**: When we pass our messages to `litellm.completion()`, litellm doesn't care about the message class type - it works with the dictionary representation. This works because:
- litellm's transformation code (e.g., `litellm/llms/anthropic/chat/transformation.py`) processes messages based on their structure, not their type
- our serialization produces dictionaries that match litellm's expected format
- litellm handles rich content by looking at the message structure, supporting both simple string content and lists of content items
@@ -78,6 +84,7 @@ class Message(OpenAIObject):
### Token Counting
To use litellm's token counter, we need to make sure that all message components (including tool calls) are properly serialized to dictionaries. This is because:
- litellm's token counter expects dictionary structures
- Tool calls need to be included in the token count
- Different providers may count tokens differently for structured content

View File

@@ -4,7 +4,7 @@ import multiprocessing as mp
import os
import re
from enum import Enum
from typing import Callable, Optional
from typing import Callable
import pandas as pd
import requests
@@ -22,7 +22,7 @@ class Platform(Enum):
GITLAB = 2
def identify_token(token: str, repo: Optional[str] = None) -> Platform:
def identify_token(token: str, repo: str | None = None) -> Platform:
"""
Identifies whether a token belongs to GitHub or GitLab.

View File

@@ -1,4 +1,4 @@
from typing import Callable, Optional
from typing import Callable
from openhands.core.config import AppConfig
from openhands.events.action import (
@@ -27,7 +27,7 @@ class E2BRuntime(Runtime):
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
sandbox: E2BSandbox | None = None,
status_callback: Optional[Callable] = None,
status_callback: Callable | None = None,
):
super().__init__(
config,

View File

@@ -7,7 +7,7 @@ import shutil
import subprocess
import tempfile
import threading
from typing import Callable, Optional
from typing import Callable
import requests
import tenacity
@@ -155,7 +155,7 @@ class LocalRuntime(ActionExecutionClient):
self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._host_port}'
self.status_callback = status_callback
self.server_process: Optional[subprocess.Popen[str]] = None
self.server_process: subprocess.Popen[str] | None = None
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
# Update env vars

View File

@@ -1,5 +1,5 @@
import os
from typing import Callable, Optional
from typing import Callable
from urllib.parse import urlparse
import requests
@@ -42,7 +42,7 @@ class RemoteRuntime(ActionExecutionClient):
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
env_vars: dict[str, str] | None = None,
status_callback: Optional[Callable] = None,
status_callback: Callable | None = None,
attach_to_existing: bool = False,
headless_mode: bool = True,
github_user_id: str | None = None,

View File

@@ -1,5 +1,5 @@
import json
from typing import Any, Literal, Optional
from typing import Any, Literal
import requests
from pydantic import BaseModel
@@ -15,7 +15,7 @@ class FeedbackDataModel(BaseModel):
'positive', 'negative'
] # TODO: remove this, its here for backward compatibility
permissions: Literal['public', 'private']
trajectory: Optional[list[dict[str, Any]]]
trajectory: list[dict[str, Any]] | None
FEEDBACK_URL = 'https://share-od-trajectory-3u9bw9tx.uc.gateway.dev/share_od_trajectory'

View File

@@ -1,6 +1,6 @@
import asyncio
import time
from typing import Callable, Optional
from typing import Callable
from pydantic import SecretStr
@@ -52,7 +52,7 @@ class AgentSession:
sid: str,
file_store: FileStore,
monitoring_listener: MonitoringListener,
status_callback: Optional[Callable] = None,
status_callback: Callable | None = None,
github_user_id: str | None = None,
):
"""Initializes a new instance of the Session class

View File

@@ -1,5 +1,5 @@
import os
from typing import List, Optional
from typing import List
from google.api_core.exceptions import NotFound
from google.cloud import storage
@@ -8,7 +8,7 @@ from openhands.storage.files import FileStore
class GoogleCloudFileStore(FileStore):
def __init__(self, bucket_name: Optional[str] = None) -> None:
def __init__(self, bucket_name: str | None = None) -> None:
"""
Create a new FileStore. If GOOGLE_APPLICATION_CREDENTIALS is defined in the
environment it will be used for authentication. Otherwise access will be

View File

@@ -5,7 +5,7 @@ import shutil
from abc import ABC
from dataclasses import dataclass, field
from io import BytesIO, StringIO
from typing import Dict, List, Optional
from typing import Dict, List
from unittest import TestCase
from unittest.mock import patch
@@ -145,10 +145,10 @@ class _MockGoogleCloudClient:
class _MockGoogleCloudBucket:
blobs_by_path: Dict[str, _MockGoogleCloudBlob] = field(default_factory=dict)
def blob(self, path: Optional[str] = None) -> _MockGoogleCloudBlob:
def blob(self, path: str | None = None) -> _MockGoogleCloudBlob:
return self.blobs_by_path.get(path) or _MockGoogleCloudBlob(self, path)
def list_blobs(self, prefix: Optional[str] = None) -> List[_MockGoogleCloudBlob]:
def list_blobs(self, prefix: str | None = None) -> List[_MockGoogleCloudBlob]:
blobs = list(self.blobs_by_path.values())
if prefix and prefix != '/':
blobs = [blob for blob in blobs if blob.name.startswith(prefix)]
@@ -159,7 +159,7 @@ class _MockGoogleCloudBucket:
class _MockGoogleCloudBlob:
bucket: _MockGoogleCloudBucket
name: str
content: Optional[str | bytes] = None
content: str | bytes | None = None
def open(self, op: str):
if op == 'r':