mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Fix typing in server directory (#8375)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: Rohit Malhotra <rohitvinodmalhotra@gmail.com>
This commit is contained in:
parent
f9b0fcd76e
commit
154eed148f
@ -1,5 +1,6 @@
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncIterator
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore')
|
||||
@ -26,7 +27,7 @@ from openhands.server.shared import conversation_manager
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _lifespan(app: FastAPI):
|
||||
async def _lifespan(app: FastAPI) -> AsyncIterator[None]:
|
||||
async with conversation_manager:
|
||||
yield
|
||||
|
||||
@ -40,7 +41,7 @@ app = FastAPI(
|
||||
|
||||
|
||||
@app.get('/health')
|
||||
async def health():
|
||||
async def health() -> str:
|
||||
return 'OK'
|
||||
|
||||
|
||||
|
||||
@ -203,7 +203,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
conversation_store_class = self._conversation_store_class
|
||||
if not conversation_store_class:
|
||||
self._conversation_store_class = conversation_store_class = get_impl(
|
||||
ConversationStore, # type: ignore
|
||||
ConversationStore,
|
||||
self.server_config.conversation_store_class,
|
||||
)
|
||||
store = await conversation_store_class.get_instance(self.config, user_id)
|
||||
|
||||
@ -40,6 +40,6 @@ def store_feedback(feedback: FeedbackDataModel) -> dict[str, str]:
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f'Failed to store feedback: {response.text}')
|
||||
response_data = json.loads(response.text)
|
||||
response_data: dict[str, str] = json.loads(response.text)
|
||||
logger.debug(f'Stored feedback: {response.text}')
|
||||
return response_data
|
||||
|
||||
@ -15,7 +15,7 @@ FILES_TO_IGNORE = [
|
||||
]
|
||||
|
||||
|
||||
def sanitize_filename(filename):
|
||||
def sanitize_filename(filename: str) -> str:
|
||||
"""Sanitize the filename to prevent directory traversal"""
|
||||
# Remove any directory components
|
||||
filename = os.path.basename(filename)
|
||||
@ -90,7 +90,7 @@ def load_file_upload_config(
|
||||
MAX_FILE_SIZE_MB, RESTRICT_FILE_TYPES, ALLOWED_EXTENSIONS = load_file_upload_config()
|
||||
|
||||
|
||||
def is_extension_allowed(filename):
|
||||
def is_extension_allowed(filename: str) -> bool:
|
||||
"""Check if the file extension is allowed based on the current configuration.
|
||||
|
||||
This function supports wildcards and files without extensions.
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from socketio.exceptions import ConnectionRefusedError
|
||||
@ -38,7 +39,7 @@ from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
def create_provider_tokens_object(
|
||||
providers_set: list[ProviderType],
|
||||
) -> PROVIDER_TOKEN_TYPE:
|
||||
provider_information = {}
|
||||
provider_information: dict[ProviderType, ProviderToken] = {}
|
||||
|
||||
for provider in providers_set:
|
||||
provider_information[provider] = ProviderToken(token=None, user_id=None)
|
||||
@ -47,7 +48,7 @@ def create_provider_tokens_object(
|
||||
|
||||
|
||||
@sio.event
|
||||
async def connect(connection_id: str, environ):
|
||||
async def connect(connection_id: str, environ: dict) -> None:
|
||||
try:
|
||||
logger.info(f'sio:connect: {connection_id}')
|
||||
query_params = parse_qs(environ.get('QUERY_STRING', ''))
|
||||
@ -141,18 +142,18 @@ async def connect(connection_id: str, environ):
|
||||
|
||||
|
||||
@sio.event
|
||||
async def oh_user_action(connection_id: str, data: dict):
|
||||
async def oh_user_action(connection_id: str, data: dict[str, Any]) -> None:
|
||||
await conversation_manager.send_to_event_stream(connection_id, data)
|
||||
|
||||
|
||||
@sio.event
|
||||
async def oh_action(connection_id: str, data: dict):
|
||||
async def oh_action(connection_id: str, data: dict[str, Any]) -> None:
|
||||
# TODO: Remove this handler once all clients are updated to use oh_user_action
|
||||
# Keeping for backward compatibility with in-progress sessions
|
||||
await conversation_manager.send_to_event_stream(connection_id, data)
|
||||
|
||||
|
||||
@sio.event
|
||||
async def disconnect(connection_id: str):
|
||||
async def disconnect(connection_id: str) -> None:
|
||||
logger.info(f'sio:disconnect:{connection_id}')
|
||||
await conversation_manager.disconnect_from_session(connection_id)
|
||||
|
||||
@ -1,14 +1,15 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Callable
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import Request, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from starlette.requests import Request as StarletteRequest
|
||||
from starlette.responses import Response
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from openhands.server import shared
|
||||
@ -22,7 +23,7 @@ class LocalhostCORSMiddleware(CORSMiddleware):
|
||||
while using standard CORS rules for other origins.
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp, **kwargs) -> None:
|
||||
def __init__(self, app: ASGIApp, **kwargs: Any) -> None:
|
||||
super().__init__(app, **kwargs)
|
||||
|
||||
def is_allowed_origin(self, origin: str) -> bool:
|
||||
@ -35,7 +36,8 @@ class LocalhostCORSMiddleware(CORSMiddleware):
|
||||
return True
|
||||
|
||||
# For missing origin or other origins, use the parent class's logic
|
||||
return super().is_allowed_origin(origin)
|
||||
result: bool = super().is_allowed_origin(origin)
|
||||
return result
|
||||
|
||||
|
||||
class CacheControlMiddleware(BaseHTTPMiddleware):
|
||||
@ -43,7 +45,9 @@ class CacheControlMiddleware(BaseHTTPMiddleware):
|
||||
Middleware to disable caching for all routes by adding appropriate headers
|
||||
"""
|
||||
|
||||
async def dispatch(self, request, call_next):
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
) -> Response:
|
||||
response = await call_next(request)
|
||||
if request.url.path.startswith('/assets'):
|
||||
# The content of the assets directory has fingerprinted file names so we cache aggressively
|
||||
@ -58,7 +62,7 @@ class CacheControlMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
|
||||
class InMemoryRateLimiter:
|
||||
history: dict
|
||||
history: dict[str, list[datetime]]
|
||||
requests: int
|
||||
seconds: int
|
||||
sleep_seconds: int
|
||||
@ -100,7 +104,9 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
super().__init__(app)
|
||||
self.rate_limiter = rate_limiter
|
||||
|
||||
async def dispatch(self, request: StarletteRequest, call_next):
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
) -> Response:
|
||||
if not self.is_rate_limited_request(request):
|
||||
return await call_next(request)
|
||||
ok = await self.rate_limiter(request)
|
||||
@ -112,7 +118,7 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
)
|
||||
return await call_next(request)
|
||||
|
||||
def is_rate_limited_request(self, request: StarletteRequest):
|
||||
def is_rate_limited_request(self, request: StarletteRequest) -> bool:
|
||||
if request.url.path.startswith('/assets'):
|
||||
return False
|
||||
# Put Other non rate limited checks here
|
||||
@ -120,10 +126,10 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
|
||||
class AttachConversationMiddleware(SessionMiddlewareInterface):
|
||||
def __init__(self, app):
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
def _should_attach(self, request) -> bool:
|
||||
def _should_attach(self, request: Request) -> bool:
|
||||
"""
|
||||
Determine if the middleware should attach a session for the given request.
|
||||
"""
|
||||
@ -168,7 +174,9 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
|
||||
request.state.conversation
|
||||
)
|
||||
|
||||
async def __call__(self, request: Request, call_next: Callable):
|
||||
async def __call__(
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
) -> Response:
|
||||
if not self._should_attach(request):
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ app = FastAPI()
|
||||
|
||||
|
||||
@app.websocket('/ws')
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
async def websocket_endpoint(websocket: WebSocket) -> None:
|
||||
await websocket.accept()
|
||||
|
||||
try:
|
||||
@ -26,12 +26,12 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
|
||||
|
||||
@app.get('/')
|
||||
def read_root():
|
||||
def read_root() -> dict[str, str]:
|
||||
return {'message': 'This is a mock server'}
|
||||
|
||||
|
||||
@app.get('/api/options/models')
|
||||
def read_llm_models():
|
||||
def read_llm_models() -> list[str]:
|
||||
return [
|
||||
'gpt-4',
|
||||
'gpt-4-turbo-preview',
|
||||
@ -41,24 +41,24 @@ def read_llm_models():
|
||||
|
||||
|
||||
@app.get('/api/options/agents')
|
||||
def read_llm_agents():
|
||||
def read_llm_agents() -> list[str]:
|
||||
return [
|
||||
'CodeActAgent',
|
||||
]
|
||||
|
||||
|
||||
@app.get('/api/list-files')
|
||||
def refresh_files():
|
||||
def refresh_files() -> list[str]:
|
||||
return ['hello_world.py']
|
||||
|
||||
|
||||
@app.get('/api/options/config')
|
||||
def get_config():
|
||||
def get_config() -> dict[str, str]:
|
||||
return {'APP_MODE': 'oss'}
|
||||
|
||||
|
||||
@app.get('/api/options/security-analyzers')
|
||||
def get_analyzers():
|
||||
def get_analyzers() -> list[str]:
|
||||
return []
|
||||
|
||||
|
||||
|
||||
@ -39,10 +39,7 @@ async def get_user_repositories(
|
||||
)
|
||||
|
||||
try:
|
||||
repos: list[Repository] = await client.get_repositories(
|
||||
sort, server_config.app_mode
|
||||
)
|
||||
return repos
|
||||
return await client.get_repositories(sort, server_config.app_mode)
|
||||
|
||||
except AuthenticationError as e:
|
||||
return JSONResponse(
|
||||
|
||||
@ -56,7 +56,7 @@ async def invalidate_legacy_secrets_store(
|
||||
|
||||
def process_token_validation_result(
|
||||
confirmed_token_type: ProviderType | None, token_type: ProviderType
|
||||
):
|
||||
) -> str:
|
||||
if not confirmed_token_type or confirmed_token_type != token_type:
|
||||
return (
|
||||
f'Invalid token. Please make sure it is a valid {token_type.value} token.'
|
||||
|
||||
@ -4,22 +4,29 @@ import socketio
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from openhands.core.config import load_app_config
|
||||
from openhands.server.config.server_config import load_server_config
|
||||
from openhands.core.config.app_config import AppConfig
|
||||
from openhands.server.config.server_config import ServerConfig, load_server_config
|
||||
from openhands.server.conversation_manager.conversation_manager import (
|
||||
ConversationManager,
|
||||
)
|
||||
from openhands.server.monitoring import MonitoringListener
|
||||
from openhands.server.types import ServerConfigInterface
|
||||
from openhands.storage import get_file_store
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.storage.secrets.secrets_store import SecretsStore
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
load_dotenv()
|
||||
|
||||
config = load_app_config()
|
||||
server_config = load_server_config()
|
||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
config: AppConfig = load_app_config()
|
||||
server_config_interface: ServerConfigInterface = load_server_config()
|
||||
assert isinstance(server_config_interface, ServerConfig), (
|
||||
'Loaded server config interface is not a ServerConfig, despite this being assumed'
|
||||
)
|
||||
server_config: ServerConfig = server_config_interface
|
||||
file_store: FileStore = get_file_store(config.file_store, config.file_store_path)
|
||||
|
||||
client_manager = None
|
||||
redis_host = os.environ.get('REDIS_HOST')
|
||||
@ -42,19 +49,19 @@ MonitoringListenerImpl = get_impl(
|
||||
monitoring_listener = MonitoringListenerImpl.get_instance(config)
|
||||
|
||||
ConversationManagerImpl = get_impl(
|
||||
ConversationManager, # type: ignore
|
||||
ConversationManager,
|
||||
server_config.conversation_manager_class,
|
||||
)
|
||||
|
||||
conversation_manager = ConversationManagerImpl.get_instance( # type: ignore
|
||||
conversation_manager = ConversationManagerImpl.get_instance(
|
||||
sio, config, file_store, server_config, monitoring_listener
|
||||
)
|
||||
|
||||
SettingsStoreImpl = get_impl(SettingsStore, server_config.settings_store_class) # type: ignore
|
||||
SettingsStoreImpl = get_impl(SettingsStore, server_config.settings_store_class)
|
||||
|
||||
SecretsStoreImpl = get_impl(SecretsStore, server_config.secret_store_class)
|
||||
|
||||
ConversationStoreImpl = get_impl(
|
||||
ConversationStore, # type: ignore
|
||||
ConversationStore,
|
||||
server_config.conversation_store_class,
|
||||
)
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from starlette.responses import Response
|
||||
from starlette.types import Scope
|
||||
|
||||
|
||||
class SPAStaticFiles(StaticFiles):
|
||||
async def get_response(self, path: str, scope):
|
||||
async def get_response(self, path: str, scope: Scope) -> Response:
|
||||
try:
|
||||
return await super().get_response(path, scope)
|
||||
except Exception:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import ClassVar, Protocol
|
||||
from typing import Any, ClassVar, Protocol
|
||||
|
||||
|
||||
class AppMode(Enum):
|
||||
@ -27,7 +27,7 @@ class ServerConfigInterface(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_config(self) -> dict[str, str]:
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Configure attributes for frontend"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -29,7 +29,7 @@ class DefaultUserAuth(UserAuth):
|
||||
"""The default implementation does not support multi tenancy, so access_token is always None"""
|
||||
return None
|
||||
|
||||
async def get_user_settings_store(self):
|
||||
async def get_user_settings_store(self) -> SettingsStore:
|
||||
settings_store = self._settings_store
|
||||
if settings_store:
|
||||
return settings_store
|
||||
@ -37,6 +37,8 @@ class DefaultUserAuth(UserAuth):
|
||||
settings_store = await shared.SettingsStoreImpl.get_instance(
|
||||
shared.config, user_id
|
||||
)
|
||||
if settings_store is None:
|
||||
raise ValueError('Failed to get settings store instance')
|
||||
self._settings_store = settings_store
|
||||
return settings_store
|
||||
|
||||
@ -49,7 +51,7 @@ class DefaultUserAuth(UserAuth):
|
||||
self._settings = settings
|
||||
return settings
|
||||
|
||||
async def get_secrets_store(self):
|
||||
async def get_secrets_store(self) -> SecretsStore:
|
||||
secrets_store = self._secrets_store
|
||||
if secrets_store:
|
||||
return secrets_store
|
||||
@ -57,6 +59,8 @@ class DefaultUserAuth(UserAuth):
|
||||
secret_store = await shared.SecretsStoreImpl.get_instance(
|
||||
shared.config, user_id
|
||||
)
|
||||
if secret_store is None:
|
||||
raise ValueError('Failed to get secrets store instance')
|
||||
self._secrets_store = secret_store
|
||||
return secret_store
|
||||
|
||||
@ -70,9 +74,10 @@ class DefaultUserAuth(UserAuth):
|
||||
return user_secrets
|
||||
|
||||
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:
|
||||
secrets_store = await self.get_user_secrets()
|
||||
provider_tokens = getattr(secrets_store, 'provider_tokens', None)
|
||||
return provider_tokens
|
||||
user_secrets = await self.get_user_secrets()
|
||||
if user_secrets is None:
|
||||
return None
|
||||
return user_secrets.provider_tokens
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls, request: Request) -> UserAuth:
|
||||
|
||||
@ -7,8 +7,8 @@ from fastapi import Request
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.server.shared import server_config
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
from openhands.storage.secrets.secrets_store import SecretsStore
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
@ -38,7 +38,7 @@ class UserAuth(ABC):
|
||||
"""Get the provider tokens for the current user."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_settings_store(self) -> SettingsStore | None:
|
||||
async def get_user_settings_store(self) -> SettingsStore:
|
||||
"""Get the settings store for the current user."""
|
||||
|
||||
async def get_user_settings(self) -> Settings | None:
|
||||
@ -47,8 +47,6 @@ class UserAuth(ABC):
|
||||
if settings:
|
||||
return settings
|
||||
settings_store = await self.get_user_settings_store()
|
||||
if settings_store is None:
|
||||
return None
|
||||
settings = await settings_store.load()
|
||||
self._settings = settings
|
||||
return settings
|
||||
@ -71,11 +69,13 @@ class UserAuth(ABC):
|
||||
|
||||
|
||||
async def get_user_auth(request: Request) -> UserAuth:
|
||||
user_auth = getattr(request.state, 'user_auth', None)
|
||||
user_auth: UserAuth | None = getattr(request.state, 'user_auth', None)
|
||||
if user_auth:
|
||||
return user_auth
|
||||
impl_name = server_config.user_auth_class
|
||||
impl = get_impl(UserAuth, impl_name)
|
||||
user_auth = await impl.get_instance(request)
|
||||
if user_auth is None:
|
||||
raise ValueError('Failed to get user auth instance')
|
||||
request.state.user_auth = user_auth
|
||||
return user_auth
|
||||
|
||||
@ -6,7 +6,9 @@ from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
|
||||
|
||||
async def get_conversation_store(request: Request) -> ConversationStore | None:
|
||||
conversation_store = getattr(request.state, 'conversation_store', None)
|
||||
conversation_store: ConversationStore | None = getattr(
|
||||
request.state, 'conversation_store', None
|
||||
)
|
||||
if conversation_store:
|
||||
return conversation_store
|
||||
user_auth = await get_user_auth(request)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user