From 154eed148f0d3fa31789ee0af91c970bc942bf53 Mon Sep 17 00:00:00 2001 From: Graham Neubig Date: Tue, 13 May 2025 17:27:59 -0400 Subject: [PATCH] Fix typing in server directory (#8375) Co-authored-by: openhands Co-authored-by: Rohit Malhotra --- openhands/server/app.py | 5 ++-- .../standalone_conversation_manager.py | 2 +- openhands/server/data_models/feedback.py | 2 +- openhands/server/file_config.py | 4 +-- openhands/server/listen_socket.py | 11 +++---- openhands/server/middleware.py | 30 ++++++++++++------- openhands/server/mock/listen.py | 14 ++++----- openhands/server/routes/git.py | 5 +--- openhands/server/routes/secrets.py | 2 +- openhands/server/shared.py | 23 +++++++++----- openhands/server/static.py | 4 ++- openhands/server/types.py | 4 +-- .../server/user_auth/default_user_auth.py | 15 ++++++---- openhands/server/user_auth/user_auth.py | 10 +++---- openhands/server/utils.py | 4 ++- 15 files changed, 79 insertions(+), 56 deletions(-) diff --git a/openhands/server/app.py b/openhands/server/app.py index c9eea3afb8..1823888611 100644 --- a/openhands/server/app.py +++ b/openhands/server/app.py @@ -1,5 +1,6 @@ import warnings from contextlib import asynccontextmanager +from typing import AsyncIterator with warnings.catch_warnings(): warnings.simplefilter('ignore') @@ -26,7 +27,7 @@ from openhands.server.shared import conversation_manager @asynccontextmanager -async def _lifespan(app: FastAPI): +async def _lifespan(app: FastAPI) -> AsyncIterator[None]: async with conversation_manager: yield @@ -40,7 +41,7 @@ app = FastAPI( @app.get('/health') -async def health(): +async def health() -> str: return 'OK' diff --git a/openhands/server/conversation_manager/standalone_conversation_manager.py b/openhands/server/conversation_manager/standalone_conversation_manager.py index 5db9826f8c..32485b650d 100644 --- a/openhands/server/conversation_manager/standalone_conversation_manager.py +++ b/openhands/server/conversation_manager/standalone_conversation_manager.py @@ -203,7 +203,7 @@ class StandaloneConversationManager(ConversationManager): conversation_store_class = self._conversation_store_class if not conversation_store_class: self._conversation_store_class = conversation_store_class = get_impl( - ConversationStore, # type: ignore + ConversationStore, self.server_config.conversation_store_class, ) store = await conversation_store_class.get_instance(self.config, user_id) diff --git a/openhands/server/data_models/feedback.py b/openhands/server/data_models/feedback.py index 512b2b76fb..59f8a1dcc3 100644 --- a/openhands/server/data_models/feedback.py +++ b/openhands/server/data_models/feedback.py @@ -40,6 +40,6 @@ def store_feedback(feedback: FeedbackDataModel) -> dict[str, str]: ) if response.status_code != 200: raise ValueError(f'Failed to store feedback: {response.text}') - response_data = json.loads(response.text) + response_data: dict[str, str] = json.loads(response.text) logger.debug(f'Stored feedback: {response.text}') return response_data diff --git a/openhands/server/file_config.py b/openhands/server/file_config.py index 1e4d6c5314..4c59e7f366 100644 --- a/openhands/server/file_config.py +++ b/openhands/server/file_config.py @@ -15,7 +15,7 @@ FILES_TO_IGNORE = [ ] -def sanitize_filename(filename): +def sanitize_filename(filename: str) -> str: """Sanitize the filename to prevent directory traversal""" # Remove any directory components filename = os.path.basename(filename) @@ -90,7 +90,7 @@ def load_file_upload_config( MAX_FILE_SIZE_MB, RESTRICT_FILE_TYPES, ALLOWED_EXTENSIONS = load_file_upload_config() -def is_extension_allowed(filename): +def is_extension_allowed(filename: str) -> bool: """Check if the file extension is allowed based on the current configuration. This function supports wildcards and files without extensions. diff --git a/openhands/server/listen_socket.py b/openhands/server/listen_socket.py index 4d4dcb6129..fe0414bfcc 100644 --- a/openhands/server/listen_socket.py +++ b/openhands/server/listen_socket.py @@ -1,5 +1,6 @@ import asyncio from types import MappingProxyType +from typing import Any from urllib.parse import parse_qs from socketio.exceptions import ConnectionRefusedError @@ -38,7 +39,7 @@ from openhands.storage.data_models.user_secrets import UserSecrets def create_provider_tokens_object( providers_set: list[ProviderType], ) -> PROVIDER_TOKEN_TYPE: - provider_information = {} + provider_information: dict[ProviderType, ProviderToken] = {} for provider in providers_set: provider_information[provider] = ProviderToken(token=None, user_id=None) @@ -47,7 +48,7 @@ def create_provider_tokens_object( @sio.event -async def connect(connection_id: str, environ): +async def connect(connection_id: str, environ: dict) -> None: try: logger.info(f'sio:connect: {connection_id}') query_params = parse_qs(environ.get('QUERY_STRING', '')) @@ -141,18 +142,18 @@ async def connect(connection_id: str, environ): @sio.event -async def oh_user_action(connection_id: str, data: dict): +async def oh_user_action(connection_id: str, data: dict[str, Any]) -> None: await conversation_manager.send_to_event_stream(connection_id, data) @sio.event -async def oh_action(connection_id: str, data: dict): +async def oh_action(connection_id: str, data: dict[str, Any]) -> None: # TODO: Remove this handler once all clients are updated to use oh_user_action # Keeping for backward compatibility with in-progress sessions await conversation_manager.send_to_event_stream(connection_id, data) @sio.event -async def disconnect(connection_id: str): +async def disconnect(connection_id: str) -> None: logger.info(f'sio:disconnect:{connection_id}') await conversation_manager.disconnect_from_session(connection_id) diff --git a/openhands/server/middleware.py b/openhands/server/middleware.py index 3bf8b3eac9..4ce91593ac 100644 --- a/openhands/server/middleware.py +++ b/openhands/server/middleware.py @@ -1,14 +1,15 @@ import asyncio from collections import defaultdict from datetime import datetime, timedelta -from typing import Callable +from typing import Any from urllib.parse import urlparse from fastapi import Request, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse -from starlette.middleware.base import BaseHTTPMiddleware +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request as StarletteRequest +from starlette.responses import Response from starlette.types import ASGIApp from openhands.server import shared @@ -22,7 +23,7 @@ class LocalhostCORSMiddleware(CORSMiddleware): while using standard CORS rules for other origins. """ - def __init__(self, app: ASGIApp, **kwargs) -> None: + def __init__(self, app: ASGIApp, **kwargs: Any) -> None: super().__init__(app, **kwargs) def is_allowed_origin(self, origin: str) -> bool: @@ -35,7 +36,8 @@ class LocalhostCORSMiddleware(CORSMiddleware): return True # For missing origin or other origins, use the parent class's logic - return super().is_allowed_origin(origin) + result: bool = super().is_allowed_origin(origin) + return result class CacheControlMiddleware(BaseHTTPMiddleware): @@ -43,7 +45,9 @@ class CacheControlMiddleware(BaseHTTPMiddleware): Middleware to disable caching for all routes by adding appropriate headers """ - async def dispatch(self, request, call_next): + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: response = await call_next(request) if request.url.path.startswith('/assets'): # The content of the assets directory has fingerprinted file names so we cache aggressively @@ -58,7 +62,7 @@ class CacheControlMiddleware(BaseHTTPMiddleware): class InMemoryRateLimiter: - history: dict + history: dict[str, list[datetime]] requests: int seconds: int sleep_seconds: int @@ -100,7 +104,9 @@ class RateLimitMiddleware(BaseHTTPMiddleware): super().__init__(app) self.rate_limiter = rate_limiter - async def dispatch(self, request: StarletteRequest, call_next): + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: if not self.is_rate_limited_request(request): return await call_next(request) ok = await self.rate_limiter(request) @@ -112,7 +118,7 @@ class RateLimitMiddleware(BaseHTTPMiddleware): ) return await call_next(request) - def is_rate_limited_request(self, request: StarletteRequest): + def is_rate_limited_request(self, request: StarletteRequest) -> bool: if request.url.path.startswith('/assets'): return False # Put Other non rate limited checks here @@ -120,10 +126,10 @@ class RateLimitMiddleware(BaseHTTPMiddleware): class AttachConversationMiddleware(SessionMiddlewareInterface): - def __init__(self, app): + def __init__(self, app: ASGIApp) -> None: self.app = app - def _should_attach(self, request) -> bool: + def _should_attach(self, request: Request) -> bool: """ Determine if the middleware should attach a session for the given request. """ @@ -168,7 +174,9 @@ class AttachConversationMiddleware(SessionMiddlewareInterface): request.state.conversation ) - async def __call__(self, request: Request, call_next: Callable): + async def __call__( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: if not self._should_attach(request): return await call_next(request) diff --git a/openhands/server/mock/listen.py b/openhands/server/mock/listen.py index ba5cb3e7fb..f85ee6f96a 100644 --- a/openhands/server/mock/listen.py +++ b/openhands/server/mock/listen.py @@ -8,7 +8,7 @@ app = FastAPI() @app.websocket('/ws') -async def websocket_endpoint(websocket: WebSocket): +async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() try: @@ -26,12 +26,12 @@ async def websocket_endpoint(websocket: WebSocket): @app.get('/') -def read_root(): +def read_root() -> dict[str, str]: return {'message': 'This is a mock server'} @app.get('/api/options/models') -def read_llm_models(): +def read_llm_models() -> list[str]: return [ 'gpt-4', 'gpt-4-turbo-preview', @@ -41,24 +41,24 @@ def read_llm_models(): @app.get('/api/options/agents') -def read_llm_agents(): +def read_llm_agents() -> list[str]: return [ 'CodeActAgent', ] @app.get('/api/list-files') -def refresh_files(): +def refresh_files() -> list[str]: return ['hello_world.py'] @app.get('/api/options/config') -def get_config(): +def get_config() -> dict[str, str]: return {'APP_MODE': 'oss'} @app.get('/api/options/security-analyzers') -def get_analyzers(): +def get_analyzers() -> list[str]: return [] diff --git a/openhands/server/routes/git.py b/openhands/server/routes/git.py index ba53ac822f..6878b49fef 100644 --- a/openhands/server/routes/git.py +++ b/openhands/server/routes/git.py @@ -39,10 +39,7 @@ async def get_user_repositories( ) try: - repos: list[Repository] = await client.get_repositories( - sort, server_config.app_mode - ) - return repos + return await client.get_repositories(sort, server_config.app_mode) except AuthenticationError as e: return JSONResponse( diff --git a/openhands/server/routes/secrets.py b/openhands/server/routes/secrets.py index 7896b8aba6..86719295c0 100644 --- a/openhands/server/routes/secrets.py +++ b/openhands/server/routes/secrets.py @@ -56,7 +56,7 @@ async def invalidate_legacy_secrets_store( def process_token_validation_result( confirmed_token_type: ProviderType | None, token_type: ProviderType -): +) -> str: if not confirmed_token_type or confirmed_token_type != token_type: return ( f'Invalid token. Please make sure it is a valid {token_type.value} token.' diff --git a/openhands/server/shared.py b/openhands/server/shared.py index 454bd95de1..b91170c27f 100644 --- a/openhands/server/shared.py +++ b/openhands/server/shared.py @@ -4,22 +4,29 @@ import socketio from dotenv import load_dotenv from openhands.core.config import load_app_config -from openhands.server.config.server_config import load_server_config +from openhands.core.config.app_config import AppConfig +from openhands.server.config.server_config import ServerConfig, load_server_config from openhands.server.conversation_manager.conversation_manager import ( ConversationManager, ) from openhands.server.monitoring import MonitoringListener +from openhands.server.types import ServerConfigInterface from openhands.storage import get_file_store from openhands.storage.conversation.conversation_store import ConversationStore +from openhands.storage.files import FileStore from openhands.storage.secrets.secrets_store import SecretsStore from openhands.storage.settings.settings_store import SettingsStore from openhands.utils.import_utils import get_impl load_dotenv() -config = load_app_config() -server_config = load_server_config() -file_store = get_file_store(config.file_store, config.file_store_path) +config: AppConfig = load_app_config() +server_config_interface: ServerConfigInterface = load_server_config() +assert isinstance(server_config_interface, ServerConfig), ( + 'Loaded server config interface is not a ServerConfig, despite this being assumed' +) +server_config: ServerConfig = server_config_interface +file_store: FileStore = get_file_store(config.file_store, config.file_store_path) client_manager = None redis_host = os.environ.get('REDIS_HOST') @@ -42,19 +49,19 @@ MonitoringListenerImpl = get_impl( monitoring_listener = MonitoringListenerImpl.get_instance(config) ConversationManagerImpl = get_impl( - ConversationManager, # type: ignore + ConversationManager, server_config.conversation_manager_class, ) -conversation_manager = ConversationManagerImpl.get_instance( # type: ignore +conversation_manager = ConversationManagerImpl.get_instance( sio, config, file_store, server_config, monitoring_listener ) -SettingsStoreImpl = get_impl(SettingsStore, server_config.settings_store_class) # type: ignore +SettingsStoreImpl = get_impl(SettingsStore, server_config.settings_store_class) SecretsStoreImpl = get_impl(SecretsStore, server_config.secret_store_class) ConversationStoreImpl = get_impl( - ConversationStore, # type: ignore + ConversationStore, server_config.conversation_store_class, ) diff --git a/openhands/server/static.py b/openhands/server/static.py index ca7eb36c9b..15557f9d76 100644 --- a/openhands/server/static.py +++ b/openhands/server/static.py @@ -1,8 +1,10 @@ from fastapi.staticfiles import StaticFiles +from starlette.responses import Response +from starlette.types import Scope class SPAStaticFiles(StaticFiles): - async def get_response(self, path: str, scope): + async def get_response(self, path: str, scope: Scope) -> Response: try: return await super().get_response(path, scope) except Exception: diff --git a/openhands/server/types.py b/openhands/server/types.py index 4c8c1dc96a..717a0a41ac 100644 --- a/openhands/server/types.py +++ b/openhands/server/types.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import ClassVar, Protocol +from typing import Any, ClassVar, Protocol class AppMode(Enum): @@ -27,7 +27,7 @@ class ServerConfigInterface(ABC): raise NotImplementedError @abstractmethod - async def get_config(self) -> dict[str, str]: + def get_config(self) -> dict[str, Any]: """Configure attributes for frontend""" raise NotImplementedError diff --git a/openhands/server/user_auth/default_user_auth.py b/openhands/server/user_auth/default_user_auth.py index d2472e110d..87a0d5fb11 100644 --- a/openhands/server/user_auth/default_user_auth.py +++ b/openhands/server/user_auth/default_user_auth.py @@ -29,7 +29,7 @@ class DefaultUserAuth(UserAuth): """The default implementation does not support multi tenancy, so access_token is always None""" return None - async def get_user_settings_store(self): + async def get_user_settings_store(self) -> SettingsStore: settings_store = self._settings_store if settings_store: return settings_store @@ -37,6 +37,8 @@ class DefaultUserAuth(UserAuth): settings_store = await shared.SettingsStoreImpl.get_instance( shared.config, user_id ) + if settings_store is None: + raise ValueError('Failed to get settings store instance') self._settings_store = settings_store return settings_store @@ -49,7 +51,7 @@ class DefaultUserAuth(UserAuth): self._settings = settings return settings - async def get_secrets_store(self): + async def get_secrets_store(self) -> SecretsStore: secrets_store = self._secrets_store if secrets_store: return secrets_store @@ -57,6 +59,8 @@ class DefaultUserAuth(UserAuth): secret_store = await shared.SecretsStoreImpl.get_instance( shared.config, user_id ) + if secret_store is None: + raise ValueError('Failed to get secrets store instance') self._secrets_store = secret_store return secret_store @@ -70,9 +74,10 @@ class DefaultUserAuth(UserAuth): return user_secrets async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None: - secrets_store = await self.get_user_secrets() - provider_tokens = getattr(secrets_store, 'provider_tokens', None) - return provider_tokens + user_secrets = await self.get_user_secrets() + if user_secrets is None: + return None + return user_secrets.provider_tokens @classmethod async def get_instance(cls, request: Request) -> UserAuth: diff --git a/openhands/server/user_auth/user_auth.py b/openhands/server/user_auth/user_auth.py index 6ba1e7adff..6c40600b89 100644 --- a/openhands/server/user_auth/user_auth.py +++ b/openhands/server/user_auth/user_auth.py @@ -7,8 +7,8 @@ from fastapi import Request from pydantic import SecretStr from openhands.integrations.provider import PROVIDER_TOKEN_TYPE -from openhands.server.settings import Settings from openhands.server.shared import server_config +from openhands.server.settings import Settings from openhands.storage.data_models.user_secrets import UserSecrets from openhands.storage.secrets.secrets_store import SecretsStore from openhands.storage.settings.settings_store import SettingsStore @@ -38,7 +38,7 @@ class UserAuth(ABC): """Get the provider tokens for the current user.""" @abstractmethod - async def get_user_settings_store(self) -> SettingsStore | None: + async def get_user_settings_store(self) -> SettingsStore: """Get the settings store for the current user.""" async def get_user_settings(self) -> Settings | None: @@ -47,8 +47,6 @@ class UserAuth(ABC): if settings: return settings settings_store = await self.get_user_settings_store() - if settings_store is None: - return None settings = await settings_store.load() self._settings = settings return settings @@ -71,11 +69,13 @@ class UserAuth(ABC): async def get_user_auth(request: Request) -> UserAuth: - user_auth = getattr(request.state, 'user_auth', None) + user_auth: UserAuth | None = getattr(request.state, 'user_auth', None) if user_auth: return user_auth impl_name = server_config.user_auth_class impl = get_impl(UserAuth, impl_name) user_auth = await impl.get_instance(request) + if user_auth is None: + raise ValueError('Failed to get user auth instance') request.state.user_auth = user_auth return user_auth diff --git a/openhands/server/utils.py b/openhands/server/utils.py index 429958d003..f24d250e4e 100644 --- a/openhands/server/utils.py +++ b/openhands/server/utils.py @@ -6,7 +6,9 @@ from openhands.storage.conversation.conversation_store import ConversationStore async def get_conversation_store(request: Request) -> ConversationStore | None: - conversation_store = getattr(request.state, 'conversation_store', None) + conversation_store: ConversationStore | None = getattr( + request.state, 'conversation_store', None + ) if conversation_store: return conversation_store user_auth = await get_user_auth(request)