Fix typing in server directory (#8375)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Rohit Malhotra <rohitvinodmalhotra@gmail.com>
This commit is contained in:
Graham Neubig 2025-05-13 17:27:59 -04:00 committed by GitHub
parent f9b0fcd76e
commit 154eed148f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 79 additions and 56 deletions

View File

@ -1,5 +1,6 @@
import warnings
from contextlib import asynccontextmanager
from typing import AsyncIterator
with warnings.catch_warnings():
warnings.simplefilter('ignore')
@ -26,7 +27,7 @@ from openhands.server.shared import conversation_manager
@asynccontextmanager
async def _lifespan(app: FastAPI):
async def _lifespan(app: FastAPI) -> AsyncIterator[None]:
async with conversation_manager:
yield
@ -40,7 +41,7 @@ app = FastAPI(
@app.get('/health')
async def health():
async def health() -> str:
return 'OK'

View File

@ -203,7 +203,7 @@ class StandaloneConversationManager(ConversationManager):
conversation_store_class = self._conversation_store_class
if not conversation_store_class:
self._conversation_store_class = conversation_store_class = get_impl(
ConversationStore, # type: ignore
ConversationStore,
self.server_config.conversation_store_class,
)
store = await conversation_store_class.get_instance(self.config, user_id)

View File

@ -40,6 +40,6 @@ def store_feedback(feedback: FeedbackDataModel) -> dict[str, str]:
)
if response.status_code != 200:
raise ValueError(f'Failed to store feedback: {response.text}')
response_data = json.loads(response.text)
response_data: dict[str, str] = json.loads(response.text)
logger.debug(f'Stored feedback: {response.text}')
return response_data

View File

@ -15,7 +15,7 @@ FILES_TO_IGNORE = [
]
def sanitize_filename(filename):
def sanitize_filename(filename: str) -> str:
"""Sanitize the filename to prevent directory traversal"""
# Remove any directory components
filename = os.path.basename(filename)
@ -90,7 +90,7 @@ def load_file_upload_config(
MAX_FILE_SIZE_MB, RESTRICT_FILE_TYPES, ALLOWED_EXTENSIONS = load_file_upload_config()
def is_extension_allowed(filename):
def is_extension_allowed(filename: str) -> bool:
"""Check if the file extension is allowed based on the current configuration.
This function supports wildcards and files without extensions.

View File

@ -1,5 +1,6 @@
import asyncio
from types import MappingProxyType
from typing import Any
from urllib.parse import parse_qs
from socketio.exceptions import ConnectionRefusedError
@ -38,7 +39,7 @@ from openhands.storage.data_models.user_secrets import UserSecrets
def create_provider_tokens_object(
providers_set: list[ProviderType],
) -> PROVIDER_TOKEN_TYPE:
provider_information = {}
provider_information: dict[ProviderType, ProviderToken] = {}
for provider in providers_set:
provider_information[provider] = ProviderToken(token=None, user_id=None)
@ -47,7 +48,7 @@ def create_provider_tokens_object(
@sio.event
async def connect(connection_id: str, environ):
async def connect(connection_id: str, environ: dict) -> None:
try:
logger.info(f'sio:connect: {connection_id}')
query_params = parse_qs(environ.get('QUERY_STRING', ''))
@ -141,18 +142,18 @@ async def connect(connection_id: str, environ):
@sio.event
async def oh_user_action(connection_id: str, data: dict):
async def oh_user_action(connection_id: str, data: dict[str, Any]) -> None:
await conversation_manager.send_to_event_stream(connection_id, data)
@sio.event
async def oh_action(connection_id: str, data: dict):
async def oh_action(connection_id: str, data: dict[str, Any]) -> None:
# TODO: Remove this handler once all clients are updated to use oh_user_action
# Keeping for backward compatibility with in-progress sessions
await conversation_manager.send_to_event_stream(connection_id, data)
@sio.event
async def disconnect(connection_id: str):
async def disconnect(connection_id: str) -> None:
logger.info(f'sio:disconnect:{connection_id}')
await conversation_manager.disconnect_from_session(connection_id)

View File

@ -1,14 +1,15 @@
import asyncio
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Callable
from typing import Any
from urllib.parse import urlparse
from fastapi import Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request as StarletteRequest
from starlette.responses import Response
from starlette.types import ASGIApp
from openhands.server import shared
@ -22,7 +23,7 @@ class LocalhostCORSMiddleware(CORSMiddleware):
while using standard CORS rules for other origins.
"""
def __init__(self, app: ASGIApp, **kwargs) -> None:
def __init__(self, app: ASGIApp, **kwargs: Any) -> None:
super().__init__(app, **kwargs)
def is_allowed_origin(self, origin: str) -> bool:
@ -35,7 +36,8 @@ class LocalhostCORSMiddleware(CORSMiddleware):
return True
# For missing origin or other origins, use the parent class's logic
return super().is_allowed_origin(origin)
result: bool = super().is_allowed_origin(origin)
return result
class CacheControlMiddleware(BaseHTTPMiddleware):
@ -43,7 +45,9 @@ class CacheControlMiddleware(BaseHTTPMiddleware):
Middleware to disable caching for all routes by adding appropriate headers
"""
async def dispatch(self, request, call_next):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
response = await call_next(request)
if request.url.path.startswith('/assets'):
# The content of the assets directory has fingerprinted file names so we cache aggressively
@ -58,7 +62,7 @@ class CacheControlMiddleware(BaseHTTPMiddleware):
class InMemoryRateLimiter:
history: dict
history: dict[str, list[datetime]]
requests: int
seconds: int
sleep_seconds: int
@ -100,7 +104,9 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
super().__init__(app)
self.rate_limiter = rate_limiter
async def dispatch(self, request: StarletteRequest, call_next):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
if not self.is_rate_limited_request(request):
return await call_next(request)
ok = await self.rate_limiter(request)
@ -112,7 +118,7 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
)
return await call_next(request)
def is_rate_limited_request(self, request: StarletteRequest):
def is_rate_limited_request(self, request: StarletteRequest) -> bool:
if request.url.path.startswith('/assets'):
return False
# Put Other non rate limited checks here
@ -120,10 +126,10 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
class AttachConversationMiddleware(SessionMiddlewareInterface):
def __init__(self, app):
def __init__(self, app: ASGIApp) -> None:
self.app = app
def _should_attach(self, request) -> bool:
def _should_attach(self, request: Request) -> bool:
"""
Determine if the middleware should attach a session for the given request.
"""
@ -168,7 +174,9 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
request.state.conversation
)
async def __call__(self, request: Request, call_next: Callable):
async def __call__(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
if not self._should_attach(request):
return await call_next(request)

View File

@ -8,7 +8,7 @@ app = FastAPI()
@app.websocket('/ws')
async def websocket_endpoint(websocket: WebSocket):
async def websocket_endpoint(websocket: WebSocket) -> None:
await websocket.accept()
try:
@ -26,12 +26,12 @@ async def websocket_endpoint(websocket: WebSocket):
@app.get('/')
def read_root():
def read_root() -> dict[str, str]:
return {'message': 'This is a mock server'}
@app.get('/api/options/models')
def read_llm_models():
def read_llm_models() -> list[str]:
return [
'gpt-4',
'gpt-4-turbo-preview',
@ -41,24 +41,24 @@ def read_llm_models():
@app.get('/api/options/agents')
def read_llm_agents():
def read_llm_agents() -> list[str]:
return [
'CodeActAgent',
]
@app.get('/api/list-files')
def refresh_files():
def refresh_files() -> list[str]:
return ['hello_world.py']
@app.get('/api/options/config')
def get_config():
def get_config() -> dict[str, str]:
return {'APP_MODE': 'oss'}
@app.get('/api/options/security-analyzers')
def get_analyzers():
def get_analyzers() -> list[str]:
return []

View File

@ -39,10 +39,7 @@ async def get_user_repositories(
)
try:
repos: list[Repository] = await client.get_repositories(
sort, server_config.app_mode
)
return repos
return await client.get_repositories(sort, server_config.app_mode)
except AuthenticationError as e:
return JSONResponse(

View File

@ -56,7 +56,7 @@ async def invalidate_legacy_secrets_store(
def process_token_validation_result(
confirmed_token_type: ProviderType | None, token_type: ProviderType
):
) -> str:
if not confirmed_token_type or confirmed_token_type != token_type:
return (
f'Invalid token. Please make sure it is a valid {token_type.value} token.'

View File

@ -4,22 +4,29 @@ import socketio
from dotenv import load_dotenv
from openhands.core.config import load_app_config
from openhands.server.config.server_config import load_server_config
from openhands.core.config.app_config import AppConfig
from openhands.server.config.server_config import ServerConfig, load_server_config
from openhands.server.conversation_manager.conversation_manager import (
ConversationManager,
)
from openhands.server.monitoring import MonitoringListener
from openhands.server.types import ServerConfigInterface
from openhands.storage import get_file_store
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.storage.files import FileStore
from openhands.storage.secrets.secrets_store import SecretsStore
from openhands.storage.settings.settings_store import SettingsStore
from openhands.utils.import_utils import get_impl
load_dotenv()
config = load_app_config()
server_config = load_server_config()
file_store = get_file_store(config.file_store, config.file_store_path)
config: AppConfig = load_app_config()
server_config_interface: ServerConfigInterface = load_server_config()
assert isinstance(server_config_interface, ServerConfig), (
'Loaded server config interface is not a ServerConfig, despite this being assumed'
)
server_config: ServerConfig = server_config_interface
file_store: FileStore = get_file_store(config.file_store, config.file_store_path)
client_manager = None
redis_host = os.environ.get('REDIS_HOST')
@ -42,19 +49,19 @@ MonitoringListenerImpl = get_impl(
monitoring_listener = MonitoringListenerImpl.get_instance(config)
ConversationManagerImpl = get_impl(
ConversationManager, # type: ignore
ConversationManager,
server_config.conversation_manager_class,
)
conversation_manager = ConversationManagerImpl.get_instance( # type: ignore
conversation_manager = ConversationManagerImpl.get_instance(
sio, config, file_store, server_config, monitoring_listener
)
SettingsStoreImpl = get_impl(SettingsStore, server_config.settings_store_class) # type: ignore
SettingsStoreImpl = get_impl(SettingsStore, server_config.settings_store_class)
SecretsStoreImpl = get_impl(SecretsStore, server_config.secret_store_class)
ConversationStoreImpl = get_impl(
ConversationStore, # type: ignore
ConversationStore,
server_config.conversation_store_class,
)

View File

@ -1,8 +1,10 @@
from fastapi.staticfiles import StaticFiles
from starlette.responses import Response
from starlette.types import Scope
class SPAStaticFiles(StaticFiles):
async def get_response(self, path: str, scope):
async def get_response(self, path: str, scope: Scope) -> Response:
try:
return await super().get_response(path, scope)
except Exception:

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import ClassVar, Protocol
from typing import Any, ClassVar, Protocol
class AppMode(Enum):
@ -27,7 +27,7 @@ class ServerConfigInterface(ABC):
raise NotImplementedError
@abstractmethod
async def get_config(self) -> dict[str, str]:
def get_config(self) -> dict[str, Any]:
"""Configure attributes for frontend"""
raise NotImplementedError

View File

@ -29,7 +29,7 @@ class DefaultUserAuth(UserAuth):
"""The default implementation does not support multi tenancy, so access_token is always None"""
return None
async def get_user_settings_store(self):
async def get_user_settings_store(self) -> SettingsStore:
settings_store = self._settings_store
if settings_store:
return settings_store
@ -37,6 +37,8 @@ class DefaultUserAuth(UserAuth):
settings_store = await shared.SettingsStoreImpl.get_instance(
shared.config, user_id
)
if settings_store is None:
raise ValueError('Failed to get settings store instance')
self._settings_store = settings_store
return settings_store
@ -49,7 +51,7 @@ class DefaultUserAuth(UserAuth):
self._settings = settings
return settings
async def get_secrets_store(self):
async def get_secrets_store(self) -> SecretsStore:
secrets_store = self._secrets_store
if secrets_store:
return secrets_store
@ -57,6 +59,8 @@ class DefaultUserAuth(UserAuth):
secret_store = await shared.SecretsStoreImpl.get_instance(
shared.config, user_id
)
if secret_store is None:
raise ValueError('Failed to get secrets store instance')
self._secrets_store = secret_store
return secret_store
@ -70,9 +74,10 @@ class DefaultUserAuth(UserAuth):
return user_secrets
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:
secrets_store = await self.get_user_secrets()
provider_tokens = getattr(secrets_store, 'provider_tokens', None)
return provider_tokens
user_secrets = await self.get_user_secrets()
if user_secrets is None:
return None
return user_secrets.provider_tokens
@classmethod
async def get_instance(cls, request: Request) -> UserAuth:

View File

@ -7,8 +7,8 @@ from fastapi import Request
from pydantic import SecretStr
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.server.settings import Settings
from openhands.server.shared import server_config
from openhands.server.settings import Settings
from openhands.storage.data_models.user_secrets import UserSecrets
from openhands.storage.secrets.secrets_store import SecretsStore
from openhands.storage.settings.settings_store import SettingsStore
@ -38,7 +38,7 @@ class UserAuth(ABC):
"""Get the provider tokens for the current user."""
@abstractmethod
async def get_user_settings_store(self) -> SettingsStore | None:
async def get_user_settings_store(self) -> SettingsStore:
"""Get the settings store for the current user."""
async def get_user_settings(self) -> Settings | None:
@ -47,8 +47,6 @@ class UserAuth(ABC):
if settings:
return settings
settings_store = await self.get_user_settings_store()
if settings_store is None:
return None
settings = await settings_store.load()
self._settings = settings
return settings
@ -71,11 +69,13 @@ class UserAuth(ABC):
async def get_user_auth(request: Request) -> UserAuth:
user_auth = getattr(request.state, 'user_auth', None)
user_auth: UserAuth | None = getattr(request.state, 'user_auth', None)
if user_auth:
return user_auth
impl_name = server_config.user_auth_class
impl = get_impl(UserAuth, impl_name)
user_auth = await impl.get_instance(request)
if user_auth is None:
raise ValueError('Failed to get user auth instance')
request.state.user_auth = user_auth
return user_auth

View File

@ -6,7 +6,9 @@ from openhands.storage.conversation.conversation_store import ConversationStore
async def get_conversation_store(request: Request) -> ConversationStore | None:
conversation_store = getattr(request.state, 'conversation_store', None)
conversation_store: ConversationStore | None = getattr(
request.state, 'conversation_store', None
)
if conversation_store:
return conversation_store
user_auth = await get_user_auth(request)