mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
refactor: use SQL filtering and pagination in VerifiedModelStore (#13068)
Co-authored-by: bittoby <brianwhitedev1996@gmail.com> Co-authored-by: statxc <statxc@user.noreply.github.com> Co-authored-by: bittoby <bittoby@users.noreply.github.com> Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -48,15 +48,18 @@ from server.routes.orgs import org_router # noqa: E402
|
||||
from server.routes.readiness import readiness_router # noqa: E402
|
||||
from server.routes.user import saas_user_router # noqa: E402
|
||||
from server.routes.user_app_settings import user_app_settings_router # noqa: E402
|
||||
from server.routes.verified_models import ( # noqa: E402
|
||||
api_router as verified_models_router,
|
||||
)
|
||||
from server.sharing.shared_conversation_router import ( # noqa: E402
|
||||
router as shared_conversation_router,
|
||||
)
|
||||
from server.sharing.shared_event_router import ( # noqa: E402
|
||||
router as shared_event_router,
|
||||
)
|
||||
from server.verified_models.verified_model_router import ( # noqa: E402
|
||||
api_router as verified_models_router,
|
||||
)
|
||||
from server.verified_models.verified_model_router import ( # noqa: E402
|
||||
override_llm_models_dependency,
|
||||
)
|
||||
|
||||
from openhands.server.app import app as base_app # noqa: E402
|
||||
from openhands.server.listen_socket import sio # noqa: E402
|
||||
@@ -113,6 +116,11 @@ base_app.include_router(org_router) # Add routes for organization management
|
||||
base_app.include_router(
|
||||
verified_models_router
|
||||
) # Add routes for verified models management
|
||||
|
||||
# Override the default LLM models implementation with SaaS version
|
||||
# This must happen after all routers are included
|
||||
override_llm_models_dependency(base_app)
|
||||
|
||||
base_app.include_router(invitation_router) # Add routes for org invitation management
|
||||
base_app.include_router(invitation_accept_router) # Add route for accepting invitations
|
||||
add_github_proxy_routes(base_app)
|
||||
|
||||
@@ -1,184 +0,0 @@
|
||||
"""API routes for managing verified LLM models (admin only)."""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from pydantic import BaseModel, field_validator
|
||||
from server.email_validation import get_admin_user_id
|
||||
from storage.verified_model_store import VerifiedModelStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
api_router = APIRouter(prefix='/api/admin/verified-models', tags=['Verified Models'])
|
||||
|
||||
|
||||
class VerifiedModelCreate(BaseModel):
|
||||
model_name: str
|
||||
provider: str
|
||||
is_enabled: bool = True
|
||||
|
||||
@field_validator('model_name')
|
||||
@classmethod
|
||||
def validate_model_name(cls, v: str) -> str:
|
||||
v = v.strip()
|
||||
if not v or len(v) > 255:
|
||||
raise ValueError('model_name must be 1-255 characters')
|
||||
return v
|
||||
|
||||
@field_validator('provider')
|
||||
@classmethod
|
||||
def validate_provider(cls, v: str) -> str:
|
||||
v = v.strip()
|
||||
if not v or len(v) > 100:
|
||||
raise ValueError('provider must be 1-100 characters')
|
||||
return v
|
||||
|
||||
|
||||
class VerifiedModelUpdate(BaseModel):
|
||||
is_enabled: bool | None = None
|
||||
|
||||
|
||||
class VerifiedModelResponse(BaseModel):
|
||||
id: int
|
||||
model_name: str
|
||||
provider: str
|
||||
is_enabled: bool
|
||||
|
||||
|
||||
class VerifiedModelPage(BaseModel):
|
||||
"""Paginated response model for verified model list."""
|
||||
|
||||
items: list[VerifiedModelResponse]
|
||||
next_page_id: str | None = None
|
||||
|
||||
|
||||
def _to_response(model) -> VerifiedModelResponse:
|
||||
return VerifiedModelResponse(
|
||||
id=model.id,
|
||||
model_name=model.model_name,
|
||||
provider=model.provider,
|
||||
is_enabled=model.is_enabled,
|
||||
)
|
||||
|
||||
|
||||
@api_router.get('', response_model=VerifiedModelPage)
|
||||
async def list_verified_models(
|
||||
provider: str | None = None,
|
||||
page_id: Annotated[
|
||||
str | None,
|
||||
Query(title='Optional next_page_id from the previously returned page'),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int, Query(title='The max number of results in the page', gt=0, le=100)
|
||||
] = 100,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
):
|
||||
"""List all verified models, optionally filtered by provider."""
|
||||
try:
|
||||
if provider:
|
||||
all_models = VerifiedModelStore.get_models_by_provider(provider)
|
||||
else:
|
||||
all_models = VerifiedModelStore.get_all_models()
|
||||
|
||||
try:
|
||||
offset = int(page_id) if page_id else 0
|
||||
except ValueError:
|
||||
offset = 0
|
||||
page = all_models[offset : offset + limit + 1]
|
||||
has_more = len(page) > limit
|
||||
if has_more:
|
||||
page = page[:limit]
|
||||
|
||||
return VerifiedModelPage(
|
||||
items=[_to_response(m) for m in page],
|
||||
next_page_id=str(offset + limit) if has_more else None,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception('Error listing verified models')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to list verified models',
|
||||
)
|
||||
|
||||
|
||||
@api_router.post('', response_model=VerifiedModelResponse, status_code=201)
|
||||
async def create_verified_model(
|
||||
data: VerifiedModelCreate,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
):
|
||||
"""Create a new verified model."""
|
||||
try:
|
||||
model = VerifiedModelStore.create_model(
|
||||
model_name=data.model_name,
|
||||
provider=data.provider,
|
||||
is_enabled=data.is_enabled,
|
||||
)
|
||||
return _to_response(model)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception('Error creating verified model')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to create verified model',
|
||||
)
|
||||
|
||||
|
||||
@api_router.put('/{provider}/{model_name:path}', response_model=VerifiedModelResponse)
|
||||
async def update_verified_model(
|
||||
provider: str,
|
||||
model_name: str,
|
||||
data: VerifiedModelUpdate,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
):
|
||||
"""Update a verified model by provider and model name."""
|
||||
try:
|
||||
model = VerifiedModelStore.update_model(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
is_enabled=data.is_enabled,
|
||||
)
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'Model {provider}/{model_name} not found',
|
||||
)
|
||||
return _to_response(model)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception(f'Error updating verified model: {provider}/{model_name}')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to update verified model',
|
||||
)
|
||||
|
||||
|
||||
@api_router.delete('/{provider}/{model_name:path}')
|
||||
async def delete_verified_model(
|
||||
provider: str,
|
||||
model_name: str,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
):
|
||||
"""Delete a verified model by provider and model name."""
|
||||
try:
|
||||
success = VerifiedModelStore.delete_model(
|
||||
model_name=model_name, provider=provider
|
||||
)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'Model {provider}/{model_name} not found',
|
||||
)
|
||||
return {'message': f'Model {provider}/{model_name} deleted'}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception(f'Error deleting verified model: {provider}/{model_name}')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to delete verified model',
|
||||
)
|
||||
33
enterprise/server/verified_models/verified_model_models.py
Normal file
33
enterprise/server/verified_models/verified_model_models.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel, StringConstraints
|
||||
|
||||
|
||||
class VerifiedModelCreate(BaseModel):
|
||||
model_name: Annotated[
|
||||
str,
|
||||
StringConstraints(max_length=255),
|
||||
]
|
||||
provider: Annotated[
|
||||
str,
|
||||
StringConstraints(max_length=100),
|
||||
]
|
||||
is_enabled: bool = True
|
||||
|
||||
|
||||
class VerifiedModel(VerifiedModelCreate):
|
||||
id: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class VerifiedModelUpdate(BaseModel):
|
||||
is_enabled: bool | None = None
|
||||
|
||||
|
||||
class VerifiedModelPage(BaseModel):
|
||||
"""Paginated response model for verified model list."""
|
||||
|
||||
items: list[VerifiedModel]
|
||||
next_page_id: str | None = None
|
||||
143
enterprise/server/verified_models/verified_model_router.py
Normal file
143
enterprise/server/verified_models/verified_model_router.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""API routes for managing verified LLM models (admin only)."""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from server.email_validation import get_admin_user_id
|
||||
from server.verified_models.verified_model_models import (
|
||||
VerifiedModel,
|
||||
VerifiedModelCreate,
|
||||
VerifiedModelPage,
|
||||
VerifiedModelUpdate,
|
||||
)
|
||||
from server.verified_models.verified_model_service import (
|
||||
VerifiedModelService,
|
||||
verified_model_store_dependency,
|
||||
)
|
||||
|
||||
from openhands.app_server.config import get_db_session
|
||||
from openhands.server.routes import public
|
||||
from openhands.utils.llm import get_supported_llm_models
|
||||
|
||||
api_router = APIRouter(prefix='/api/admin/verified-models', tags=['Verified Models'])
|
||||
|
||||
|
||||
@api_router.get('')
|
||||
async def search_verified_models(
|
||||
provider: str | None = None,
|
||||
page_id: Annotated[
|
||||
str | None,
|
||||
Query(title='Optional next_page_id from the previously returned page'),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int, Query(title='The max number of results in the page', gt=0, le=100)
|
||||
] = 100,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
verified_model_service: VerifiedModelService = Depends(
|
||||
verified_model_store_dependency
|
||||
),
|
||||
) -> VerifiedModelPage:
|
||||
"""List all verified models, optionally filtered by provider."""
|
||||
# Use SQL-level filtering and pagination
|
||||
result = await verified_model_service.search_verified_models(
|
||||
provider=provider,
|
||||
enabled_only=False, # Admin sees all models including disabled
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@api_router.post('', status_code=201)
|
||||
async def create_verified_model(
|
||||
data: VerifiedModelCreate,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
verified_model_service: VerifiedModelService = Depends(
|
||||
verified_model_store_dependency
|
||||
),
|
||||
) -> VerifiedModel:
|
||||
"""Create a new verified model."""
|
||||
try:
|
||||
model = await verified_model_service.create_verified_model(
|
||||
model_name=data.model_name,
|
||||
provider=data.provider,
|
||||
is_enabled=data.is_enabled,
|
||||
)
|
||||
return model
|
||||
except ValueError as ex:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(ex),
|
||||
)
|
||||
|
||||
|
||||
@api_router.put('/{provider}/{model_name:path}')
|
||||
async def update_verified_model(
|
||||
provider: str,
|
||||
model_name: str,
|
||||
data: VerifiedModelUpdate,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
verified_model_service: VerifiedModelService = Depends(
|
||||
verified_model_store_dependency
|
||||
),
|
||||
) -> VerifiedModel:
|
||||
"""Update a verified model by provider and model name."""
|
||||
model = await verified_model_service.update_verified_model(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
is_enabled=data.is_enabled,
|
||||
)
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'Model {provider}/{model_name} not found',
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
@api_router.delete('/{provider}/{model_name:path}')
|
||||
async def delete_verified_model(
|
||||
provider: str,
|
||||
model_name: str,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
verified_model_service: VerifiedModelService = Depends(
|
||||
verified_model_store_dependency
|
||||
),
|
||||
) -> bool:
|
||||
"""Delete a verified model by provider and model name."""
|
||||
try:
|
||||
await verified_model_service.delete_verified_model(
|
||||
model_name=model_name, provider=provider
|
||||
)
|
||||
return True
|
||||
except ValueError as ex:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(ex),
|
||||
)
|
||||
|
||||
|
||||
async def get_saas_llm_models_dependency(request: Request) -> list[str]:
|
||||
"""SaaS implementation for the LLM models endpoint."""
|
||||
async with get_db_session(request.state, request) as db_session:
|
||||
# Prevent circular import
|
||||
from openhands.server.shared import config
|
||||
|
||||
verified_model_service = VerifiedModelService(db_session)
|
||||
page = await verified_model_service.search_verified_models(enabled_only=True)
|
||||
if page.next_page_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Too many models defined in database',
|
||||
)
|
||||
verified_models = [f'{m.provider}/{m.model_name}' for m in page.items]
|
||||
return get_supported_llm_models(config, verified_models)
|
||||
|
||||
|
||||
# Override the default implementation with SaaS implementation
|
||||
# This must be called after the app is created in saas_server.py
|
||||
def override_llm_models_dependency(app):
|
||||
"""Override the default LLM models implementation with SaaS version."""
|
||||
app.dependency_overrides[public.get_llm_models_dependency] = (
|
||||
get_saas_llm_models_dependency
|
||||
)
|
||||
242
enterprise/server/verified_models/verified_model_service.py
Normal file
242
enterprise/server/verified_models/verified_model_service.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""Store for managing verified LLM models in the database."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Identity,
|
||||
Integer,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
and_,
|
||||
func,
|
||||
select,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from storage.base import Base
|
||||
|
||||
from enterprise.server.verified_models.verified_model_models import (
|
||||
VerifiedModel,
|
||||
VerifiedModelPage,
|
||||
)
|
||||
from openhands.app_server.config import depends_db_session
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class StoredVerifiedModel(Base): # type: ignore
|
||||
"""A verified LLM model available in the model selector.
|
||||
|
||||
The composite unique constraint on (model_name, provider) allows the same
|
||||
model name to exist under different providers (e.g. 'claude-sonnet' under
|
||||
both 'openhands' and 'anthropic').
|
||||
"""
|
||||
|
||||
__tablename__ = 'verified_models'
|
||||
__table_args__ = (
|
||||
UniqueConstraint('model_name', 'provider', name='uq_verified_model_provider'),
|
||||
)
|
||||
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
model_name = Column(String(255), nullable=False)
|
||||
provider = Column(String(100), nullable=False, index=True)
|
||||
is_enabled = Column(
|
||||
Boolean, nullable=False, default=True, server_default=text('true')
|
||||
)
|
||||
created_at = Column(DateTime, nullable=False, server_default=func.now())
|
||||
updated_at = Column(
|
||||
DateTime, nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
def verified_model(result: StoredVerifiedModel) -> VerifiedModel:
|
||||
return VerifiedModel(
|
||||
id=result.id,
|
||||
model_name=result.model_name,
|
||||
provider=result.provider,
|
||||
is_enabled=result.is_enabled,
|
||||
created_at=result.created_at,
|
||||
updated_at=result.updated_at,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VerifiedModelService:
|
||||
"""Store for CRUD operations on verified models.
|
||||
|
||||
Follows the async pattern with db_session as an attribute.
|
||||
"""
|
||||
|
||||
db_session: AsyncSession
|
||||
|
||||
async def search_verified_models(
|
||||
self,
|
||||
provider: str | None = None,
|
||||
enabled_only: bool = True,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> VerifiedModelPage:
|
||||
"""Search for verified models with optional filtering and pagination.
|
||||
|
||||
Args:
|
||||
provider: Optional provider name to filter by (e.g., 'openhands', 'anthropic')
|
||||
enabled_only: If True, only return enabled models (default: True)
|
||||
page_id: Page id for pagination
|
||||
limit: Maximum number of records to return
|
||||
|
||||
Returns:
|
||||
SearchModelsResult containing items list and has_more flag
|
||||
"""
|
||||
query = select(StoredVerifiedModel)
|
||||
|
||||
# Build filters
|
||||
filters = []
|
||||
if provider:
|
||||
filters.append(StoredVerifiedModel.provider == provider)
|
||||
if enabled_only:
|
||||
filters.append(StoredVerifiedModel.is_enabled.is_(True))
|
||||
|
||||
if filters:
|
||||
query = query.where(and_(*filters))
|
||||
|
||||
# Order by provider, then model_name
|
||||
query = query.order_by(
|
||||
StoredVerifiedModel.provider, StoredVerifiedModel.model_name
|
||||
)
|
||||
|
||||
# Fetch limit + 1 to check if there are more results
|
||||
offset = int(page_id or '0')
|
||||
query = query.offset(offset).limit(limit + 1)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
results = list(result.scalars().all())
|
||||
has_more = len(results) > limit
|
||||
next_page_id = None
|
||||
|
||||
# Return only the requested number of results
|
||||
if has_more:
|
||||
next_page_id = str(offset + limit)
|
||||
results.pop()
|
||||
|
||||
items = [verified_model(result) for result in results]
|
||||
return VerifiedModelPage(items=items, next_page_id=next_page_id)
|
||||
|
||||
async def get_model(self, model_name: str, provider: str) -> VerifiedModel | None:
|
||||
"""Get a model by its composite key (model_name, provider).
|
||||
|
||||
Args:
|
||||
model_name: The model identifier
|
||||
provider: The provider name
|
||||
"""
|
||||
query = select(StoredVerifiedModel).where(
|
||||
and_(
|
||||
StoredVerifiedModel.model_name == model_name,
|
||||
StoredVerifiedModel.provider == provider,
|
||||
)
|
||||
)
|
||||
result = await self.db_session.execute(query)
|
||||
return result.scalars().first()
|
||||
|
||||
async def create_verified_model(
|
||||
self,
|
||||
model_name: str,
|
||||
provider: str,
|
||||
is_enabled: bool = True,
|
||||
) -> VerifiedModel:
|
||||
"""Create a new verified model.
|
||||
|
||||
Args:
|
||||
model_name: The model identifier
|
||||
provider: The provider name
|
||||
is_enabled: Whether the model is enabled (default True)
|
||||
|
||||
Raises:
|
||||
ValueError: If a model with the same (model_name, provider) already exists
|
||||
"""
|
||||
existing_query = select(StoredVerifiedModel).where(
|
||||
and_(
|
||||
StoredVerifiedModel.model_name == model_name,
|
||||
StoredVerifiedModel.provider == provider,
|
||||
)
|
||||
)
|
||||
result = await self.db_session.execute(existing_query)
|
||||
existing = result.scalars().first()
|
||||
if existing:
|
||||
raise ValueError(f'Model {provider}/{model_name} already exists')
|
||||
|
||||
model = StoredVerifiedModel(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
is_enabled=is_enabled,
|
||||
)
|
||||
self.db_session.add(model)
|
||||
await self.db_session.commit()
|
||||
await self.db_session.refresh(model)
|
||||
logger.info(f'Created verified model: {provider}/{model_name}')
|
||||
return verified_model(model)
|
||||
|
||||
async def update_verified_model(
|
||||
self,
|
||||
model_name: str,
|
||||
provider: str,
|
||||
is_enabled: bool | None = None,
|
||||
) -> VerifiedModel | None:
|
||||
"""Update an existing verified model.
|
||||
|
||||
Args:
|
||||
model_name: The model name to update
|
||||
provider: The provider name
|
||||
is_enabled: New enabled state (optional)
|
||||
|
||||
Returns:
|
||||
The updated model if found, None otherwise
|
||||
"""
|
||||
query = select(StoredVerifiedModel).where(
|
||||
and_(
|
||||
StoredVerifiedModel.model_name == model_name,
|
||||
StoredVerifiedModel.provider == provider,
|
||||
)
|
||||
)
|
||||
result = await self.db_session.execute(query)
|
||||
model = result.scalars().first()
|
||||
if not model:
|
||||
return None
|
||||
|
||||
if is_enabled is not None:
|
||||
model.is_enabled = is_enabled
|
||||
|
||||
await self.db_session.commit()
|
||||
await self.db_session.refresh(model)
|
||||
logger.info(f'Updated verified model: {provider}/{model_name}')
|
||||
return verified_model(model)
|
||||
|
||||
async def delete_verified_model(self, model_name: str, provider: str):
|
||||
"""Delete a verified model.
|
||||
|
||||
Args:
|
||||
model_name: The model name to delete
|
||||
provider: The provider name
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
query = select(StoredVerifiedModel).where(
|
||||
and_(
|
||||
StoredVerifiedModel.model_name == model_name,
|
||||
StoredVerifiedModel.provider == provider,
|
||||
)
|
||||
)
|
||||
result = await self.db_session.execute(query)
|
||||
model = result.scalars().first()
|
||||
if not model:
|
||||
raise ValueError('Unknown model')
|
||||
|
||||
await self.db_session.delete(model)
|
||||
await self.db_session.commit()
|
||||
logger.info(f'Deleted verified model: {provider}/{model_name}')
|
||||
|
||||
|
||||
def verified_model_store_dependency(db_session: AsyncSession = depends_db_session()):
|
||||
return VerifiedModelService(db_session)
|
||||
@@ -1,39 +0,0 @@
|
||||
"""SQLAlchemy model for verified LLM models."""
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Identity,
|
||||
Integer,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
func,
|
||||
text,
|
||||
)
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class VerifiedModel(Base): # type: ignore
|
||||
"""A verified LLM model available in the model selector.
|
||||
|
||||
The composite unique constraint on (model_name, provider) allows the same
|
||||
model name to exist under different providers (e.g. 'claude-sonnet' under
|
||||
both 'openhands' and 'anthropic').
|
||||
"""
|
||||
|
||||
__tablename__ = 'verified_models'
|
||||
__table_args__ = (
|
||||
UniqueConstraint('model_name', 'provider', name='uq_verified_model_provider'),
|
||||
)
|
||||
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
model_name = Column(String(255), nullable=False)
|
||||
provider = Column(String(100), nullable=False, index=True)
|
||||
is_enabled = Column(
|
||||
Boolean, nullable=False, default=True, server_default=text('true')
|
||||
)
|
||||
created_at = Column(DateTime, nullable=False, server_default=func.now())
|
||||
updated_at = Column(
|
||||
DateTime, nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
@@ -1,187 +0,0 @@
|
||||
"""Store for managing verified LLM models in the database."""
|
||||
|
||||
from sqlalchemy import and_
|
||||
from storage.database import session_maker
|
||||
from storage.verified_model import VerifiedModel
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class VerifiedModelStore:
|
||||
"""Store for CRUD operations on verified models.
|
||||
|
||||
Follows the project convention of static methods with session_maker()
|
||||
(see UserStore, OrgMemberStore for reference).
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_enabled_models() -> list[VerifiedModel]:
|
||||
"""Get all enabled models.
|
||||
|
||||
Returns:
|
||||
list[VerifiedModel]: All models where is_enabled is True
|
||||
"""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(VerifiedModel)
|
||||
.filter(VerifiedModel.is_enabled.is_(True))
|
||||
.order_by(VerifiedModel.provider, VerifiedModel.model_name)
|
||||
.all()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_models_by_provider(provider: str) -> list[VerifiedModel]:
|
||||
"""Get all enabled models for a specific provider.
|
||||
|
||||
Args:
|
||||
provider: The provider name (e.g., 'openhands', 'anthropic')
|
||||
"""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(VerifiedModel)
|
||||
.filter(
|
||||
and_(
|
||||
VerifiedModel.provider == provider,
|
||||
VerifiedModel.is_enabled.is_(True),
|
||||
)
|
||||
)
|
||||
.order_by(VerifiedModel.model_name)
|
||||
.all()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_all_models() -> list[VerifiedModel]:
|
||||
"""Get all models (including disabled)."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(VerifiedModel)
|
||||
.order_by(VerifiedModel.provider, VerifiedModel.model_name)
|
||||
.all()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_model(model_name: str, provider: str) -> VerifiedModel | None:
|
||||
"""Get a model by its composite key (model_name, provider).
|
||||
|
||||
Args:
|
||||
model_name: The model identifier
|
||||
provider: The provider name
|
||||
"""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(VerifiedModel)
|
||||
.filter(
|
||||
and_(
|
||||
VerifiedModel.model_name == model_name,
|
||||
VerifiedModel.provider == provider,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_model(
|
||||
model_name: str, provider: str, is_enabled: bool = True
|
||||
) -> VerifiedModel:
|
||||
"""Create a new verified model.
|
||||
|
||||
Args:
|
||||
model_name: The model identifier
|
||||
provider: The provider name
|
||||
is_enabled: Whether the model is enabled (default True)
|
||||
|
||||
Raises:
|
||||
ValueError: If a model with the same (model_name, provider) already exists
|
||||
"""
|
||||
with session_maker() as session:
|
||||
existing = (
|
||||
session.query(VerifiedModel)
|
||||
.filter(
|
||||
and_(
|
||||
VerifiedModel.model_name == model_name,
|
||||
VerifiedModel.provider == provider,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ValueError(f'Model {provider}/{model_name} already exists')
|
||||
|
||||
model = VerifiedModel(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
is_enabled=is_enabled,
|
||||
)
|
||||
session.add(model)
|
||||
session.commit()
|
||||
session.refresh(model)
|
||||
logger.info(f'Created verified model: {provider}/{model_name}')
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def update_model(
|
||||
model_name: str,
|
||||
provider: str,
|
||||
is_enabled: bool | None = None,
|
||||
) -> VerifiedModel | None:
|
||||
"""Update an existing verified model.
|
||||
|
||||
Args:
|
||||
model_name: The model name to update
|
||||
provider: The provider name
|
||||
is_enabled: New enabled state (optional)
|
||||
|
||||
Returns:
|
||||
The updated model if found, None otherwise
|
||||
"""
|
||||
with session_maker() as session:
|
||||
model = (
|
||||
session.query(VerifiedModel)
|
||||
.filter(
|
||||
and_(
|
||||
VerifiedModel.model_name == model_name,
|
||||
VerifiedModel.provider == provider,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not model:
|
||||
return None
|
||||
|
||||
if is_enabled is not None:
|
||||
model.is_enabled = is_enabled
|
||||
|
||||
session.commit()
|
||||
session.refresh(model)
|
||||
logger.info(f'Updated verified model: {provider}/{model_name}')
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def delete_model(model_name: str, provider: str) -> bool:
|
||||
"""Delete a verified model.
|
||||
|
||||
Args:
|
||||
model_name: The model name to delete
|
||||
provider: The provider name
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
with session_maker() as session:
|
||||
model = (
|
||||
session.query(VerifiedModel)
|
||||
.filter(
|
||||
and_(
|
||||
VerifiedModel.model_name == model_name,
|
||||
VerifiedModel.provider == provider,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not model:
|
||||
return False
|
||||
|
||||
session.delete(model)
|
||||
session.commit()
|
||||
logger.info(f'Deleted verified model: {provider}/{model_name}')
|
||||
return True
|
||||
@@ -4,6 +4,9 @@ from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from server.constants import ORG_SETTINGS_VERSION
|
||||
from server.verified_models.verified_model_service import (
|
||||
StoredVerifiedModel, # noqa: F401
|
||||
)
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.base import Base
|
||||
@@ -25,7 +28,6 @@ from storage.stored_conversation_metadata_saas import (
|
||||
from storage.stored_offline_token import StoredOfflineToken
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
from storage.user import User
|
||||
from storage.verified_model import VerifiedModel # noqa: F401
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -1,123 +0,0 @@
|
||||
"""Unit tests for VerifiedModelStore."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.base import Base
|
||||
from storage.verified_model_store import VerifiedModelStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _mock_session_maker():
|
||||
"""Create an in-memory SQLite database and patch session_maker."""
|
||||
engine = create_engine('sqlite:///:memory:')
|
||||
Base.metadata.create_all(engine)
|
||||
session_factory = sessionmaker(bind=engine)
|
||||
|
||||
with patch(
|
||||
'storage.verified_model_store.session_maker',
|
||||
side_effect=lambda **kwargs: session_factory(**kwargs),
|
||||
):
|
||||
yield
|
||||
|
||||
Base.metadata.drop_all(engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _seed_models(_mock_session_maker):
|
||||
"""Seed the database with test models."""
|
||||
VerifiedModelStore.create_model(model_name='claude-sonnet', provider='openhands')
|
||||
VerifiedModelStore.create_model(model_name='claude-sonnet', provider='anthropic')
|
||||
VerifiedModelStore.create_model(
|
||||
model_name='gpt-4o', provider='openhands', is_enabled=False
|
||||
)
|
||||
|
||||
|
||||
class TestCreateModel:
|
||||
def test_create_model(self, _mock_session_maker):
|
||||
model = VerifiedModelStore.create_model(
|
||||
model_name='test-model', provider='test-provider'
|
||||
)
|
||||
assert model.model_name == 'test-model'
|
||||
assert model.provider == 'test-provider'
|
||||
assert model.is_enabled is True
|
||||
assert model.id is not None
|
||||
|
||||
def test_create_duplicate_raises(self, _mock_session_maker):
|
||||
VerifiedModelStore.create_model(model_name='test-model', provider='test')
|
||||
with pytest.raises(ValueError, match='test/test-model already exists'):
|
||||
VerifiedModelStore.create_model(model_name='test-model', provider='test')
|
||||
|
||||
def test_same_name_different_provider_allowed(self, _mock_session_maker):
|
||||
VerifiedModelStore.create_model(model_name='claude', provider='openhands')
|
||||
model = VerifiedModelStore.create_model(
|
||||
model_name='claude', provider='anthropic'
|
||||
)
|
||||
assert model.provider == 'anthropic'
|
||||
|
||||
|
||||
class TestGetModel:
|
||||
def test_get_model(self, _seed_models):
|
||||
model = VerifiedModelStore.get_model('claude-sonnet', 'openhands')
|
||||
assert model is not None
|
||||
assert model.provider == 'openhands'
|
||||
|
||||
def test_get_model_not_found(self, _seed_models):
|
||||
assert VerifiedModelStore.get_model('nonexistent', 'openhands') is None
|
||||
|
||||
def test_get_model_wrong_provider(self, _seed_models):
|
||||
assert VerifiedModelStore.get_model('claude-sonnet', 'openai') is None
|
||||
|
||||
|
||||
class TestGetModels:
|
||||
def test_get_all_models(self, _seed_models):
|
||||
models = VerifiedModelStore.get_all_models()
|
||||
assert len(models) == 3
|
||||
|
||||
def test_get_enabled_models(self, _seed_models):
|
||||
models = VerifiedModelStore.get_enabled_models()
|
||||
assert len(models) == 2
|
||||
names = {m.model_name for m in models}
|
||||
assert 'gpt-4o' not in names
|
||||
|
||||
def test_get_models_by_provider(self, _seed_models):
|
||||
models = VerifiedModelStore.get_models_by_provider('openhands')
|
||||
assert len(models) == 1
|
||||
assert models[0].model_name == 'claude-sonnet'
|
||||
|
||||
|
||||
class TestUpdateModel:
|
||||
def test_update_model(self, _seed_models):
|
||||
updated = VerifiedModelStore.update_model(
|
||||
model_name='claude-sonnet', provider='openhands', is_enabled=False
|
||||
)
|
||||
assert updated is not None
|
||||
assert updated.is_enabled is False
|
||||
|
||||
def test_update_not_found(self, _seed_models):
|
||||
assert (
|
||||
VerifiedModelStore.update_model(
|
||||
model_name='nonexistent', provider='openhands', is_enabled=False
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_update_no_change(self, _seed_models):
|
||||
updated = VerifiedModelStore.update_model(
|
||||
model_name='claude-sonnet', provider='openhands'
|
||||
)
|
||||
assert updated is not None
|
||||
assert updated.is_enabled is True
|
||||
|
||||
|
||||
class TestDeleteModel:
|
||||
def test_delete_model(self, _seed_models):
|
||||
assert VerifiedModelStore.delete_model('claude-sonnet', 'openhands') is True
|
||||
assert VerifiedModelStore.get_model('claude-sonnet', 'openhands') is None
|
||||
# Other provider's version should still exist
|
||||
assert VerifiedModelStore.get_model('claude-sonnet', 'anthropic') is not None
|
||||
|
||||
def test_delete_not_found(self, _seed_models):
|
||||
assert VerifiedModelStore.delete_model('nonexistent', 'openhands') is False
|
||||
@@ -0,0 +1,225 @@
|
||||
"""Unit tests for VerifiedModelService."""
|
||||
|
||||
import pytest
|
||||
from server.verified_models.verified_model_service import (
|
||||
VerifiedModelService,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session_maker(async_engine):
|
||||
"""Create an async session maker for testing."""
|
||||
return async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def _seed_models(async_session_maker):
|
||||
"""Seed the database with test models."""
|
||||
async with async_session_maker() as session:
|
||||
service = VerifiedModelService(session)
|
||||
await service.create_verified_model(
|
||||
model_name='claude-sonnet', provider='openhands'
|
||||
)
|
||||
await service.create_verified_model(
|
||||
model_name='claude-sonnet', provider='anthropic'
|
||||
)
|
||||
await service.create_verified_model(
|
||||
model_name='gpt-4o', provider='openhands', is_enabled=False
|
||||
)
|
||||
|
||||
|
||||
class TestCreateVerifiedModel:
|
||||
async def test_create_model(self, async_session_maker):
|
||||
async with async_session_maker() as session:
|
||||
service = VerifiedModelService(session)
|
||||
model = await service.create_verified_model(
|
||||
model_name='test-model', provider='test-provider'
|
||||
)
|
||||
assert model.model_name == 'test-model'
|
||||
assert model.provider == 'test-provider'
|
||||
assert model.is_enabled is True
|
||||
assert model.id is not None
|
||||
|
||||
async def test_create_duplicate_raises(self, async_session_maker):
|
||||
async with async_session_maker() as session:
|
||||
service = VerifiedModelService(session)
|
||||
await service.create_verified_model(
|
||||
model_name='test-model', provider='test'
|
||||
)
|
||||
with pytest.raises(ValueError, match='test/test-model already exists'):
|
||||
await service.create_verified_model(
|
||||
model_name='test-model', provider='test'
|
||||
)
|
||||
|
||||
async def test_same_name_different_provider_allowed(self, async_session_maker):
|
||||
async with async_session_maker() as session:
|
||||
service = VerifiedModelService(session)
|
||||
await service.create_verified_model(
|
||||
model_name='claude', provider='openhands'
|
||||
)
|
||||
model = await service.create_verified_model(
|
||||
model_name='claude', provider='anthropic'
|
||||
)
|
||||
assert model.provider == 'anthropic'
|
||||
|
||||
|
||||
class TestGetModel:
|
||||
async def test_get_model(self, _seed_models, async_session_maker):
|
||||
async with async_session_maker() as session:
|
||||
service = VerifiedModelService(session)
|
||||
model = await service.get_model('claude-sonnet', 'openhands')
|
||||
assert model is not None
|
||||
assert model.provider == 'openhands'
|
||||
|
||||
async def test_get_model_not_found(self, _seed_models, async_session_maker):
|
||||
async with async_session_maker() as session:
|
||||
service = VerifiedModelService(session)
|
||||
assert await service.get_model('nonexistent', 'openhands') is None
|
||||
|
||||
async def test_get_model_wrong_provider(self, _seed_models, async_session_maker):
|
||||
async with async_session_maker() as session:
|
||||
service = VerifiedModelService(session)
|
||||
assert await service.get_model('claude-sonnet', 'openai') is None
|
||||
|
||||
|
||||
class TestSearchVerifiedModels:
|
||||
async def test_search_models_no_filters(self, _seed_models, async_session_maker):
|
||||
async with async_session_maker() as session:
|
||||
service = VerifiedModelService(session)
|
||||
result = await service.search_verified_models()
|
||||
assert len(result.items) == 2 # Only enabled models
|
||||
assert result.next_page_id is None
|
||||
|
||||
async def test_search_models_enabled_only_true(
|
||||
self, _seed_models, async_session_maker
|
||||
):
|
||||
async with async_session_maker() as session:
|
||||
service = VerifiedModelService(session)
|
||||
result = await service.search_verified_models(enabled_only=True)
|
||||
assert len(result.items) == 2
|
||||
names = {m.model_name for m in result.items}
|
||||
assert 'gpt-4o' not in names # Disabled model not included
|
||||
|
||||
async def test_search_models_enabled_only_false(
|
||||
self, _seed_models, async_session_maker
|
||||
):
|
||||
async with async_session_maker() as session:
|
||||
service = VerifiedModelService(session)
|
||||
result = await service.search_verified_models(enabled_only=False)
|
||||
assert len(result.items) == 3 # All models including disabled
|
||||
|
||||
async def test_search_models_by_provider(self, _seed_models, async_session_maker):
|
||||
async with async_session_maker() as session:
|
||||
service = VerifiedModelService(session)
|
||||
result = await service.search_verified_models(provider='openhands')
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].model_name == 'claude-sonnet'
|
||||
|
||||
async def test_search_models_pagination(self, _seed_models, async_session_maker):
|
||||
async with async_session_maker() as session:
|
||||
service = VerifiedModelService(session)
|
||||
# Create more models for pagination testing
|
||||
await service.create_verified_model(model_name='model-1', provider='test')
|
||||
await service.create_verified_model(model_name='model-2', provider='test')
|
||||
await service.create_verified_model(model_name='model-3', provider='test')
|
||||
await service.create_verified_model(model_name='model-4', provider='test')
|
||||
|
||||
# Total: 7 models (3 initial + 4 new)
|
||||
# First page
|
||||
async with async_session_maker() as session:
|
||||
service = VerifiedModelService(session)
|
||||
result = await service.search_verified_models(
|
||||
enabled_only=False, page_id='0', limit=3
|
||||
)
|
||||
assert len(result.items) == 3
|
||||
assert result.next_page_id == '3' # 4 more items after position 2
|
||||
|
||||
# Second page (page_id 3)
|
||||
async with async_session_maker() as session:
|
||||
service = VerifiedModelService(session)
|
||||
result = await service.search_verified_models(
|
||||
enabled_only=False, page_id='3', limit=3
|
||||
)
|
||||
assert len(result.items) == 3
|
||||
# There are 4 items total starting at offset 3 (positions 3,4,5,6), so next_page_id exists
|
||||
assert result.next_page_id == '6'
|
||||
|
||||
# Third page (page_id 6) - last item
|
||||
async with async_session_maker() as session:
|
||||
service = VerifiedModelService(session)
|
||||
result = await service.search_verified_models(
|
||||
enabled_only=False, page_id='6', limit=3
|
||||
)
|
||||
assert len(result.items) == 1
|
||||
assert result.next_page_id is None # No more items after position 6
|
||||
|
||||
|
||||
class TestUpdateVerifiedModel:
|
||||
async def test_update_model(self, _seed_models, async_session_maker):
|
||||
async with async_session_maker() as session:
|
||||
service = VerifiedModelService(session)
|
||||
updated = await service.update_verified_model(
|
||||
model_name='claude-sonnet', provider='openhands', is_enabled=False
|
||||
)
|
||||
assert updated is not None
|
||||
assert updated.is_enabled is False
|
||||
|
||||
async def test_update_not_found(self, _seed_models, async_session_maker):
|
||||
async with async_session_maker() as session:
|
||||
service = VerifiedModelService(session)
|
||||
assert (
|
||||
await service.update_verified_model(
|
||||
model_name='nonexistent', provider='openhands', is_enabled=False
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
async def test_update_no_change(self, _seed_models, async_session_maker):
|
||||
async with async_session_maker() as session:
|
||||
service = VerifiedModelService(session)
|
||||
updated = await service.update_verified_model(
|
||||
model_name='claude-sonnet', provider='openhands'
|
||||
)
|
||||
assert updated is not None
|
||||
assert updated.is_enabled is True
|
||||
|
||||
|
||||
class TestDeleteVerifiedModel:
|
||||
async def test_delete_model(self, _seed_models, async_session_maker):
|
||||
async with async_session_maker() as session:
|
||||
service = VerifiedModelService(session)
|
||||
await service.delete_verified_model('claude-sonnet', 'openhands')
|
||||
|
||||
async with async_session_maker() as session:
|
||||
service = VerifiedModelService(session)
|
||||
assert await service.get_model('claude-sonnet', 'openhands') is None
|
||||
# Other provider's version should still exist
|
||||
assert await service.get_model('claude-sonnet', 'anthropic') is not None
|
||||
|
||||
async def test_delete_not_found(self, _seed_models, async_session_maker):
|
||||
async with async_session_maker() as session:
|
||||
service = VerifiedModelService(session)
|
||||
with pytest.raises(ValueError):
|
||||
assert await service.delete_verified_model('nonexistent', 'openhands')
|
||||
@@ -6,9 +6,10 @@
|
||||
# Unless you are working on deprecation, please avoid extending this legacy file and consult the V1 codepaths above.
|
||||
# Tag: Legacy-V0
|
||||
# This module belongs to the old V0 web server. The V1 application server lives under openhands/app_server/.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.security.options import SecurityAnalyzers
|
||||
@@ -19,45 +20,21 @@ from openhands.utils.llm import get_supported_llm_models
|
||||
app = APIRouter(prefix='/api/options', dependencies=get_dependencies())
|
||||
|
||||
|
||||
@app.get('/models', response_model=list[str])
|
||||
async def get_litellm_models() -> list[str]:
|
||||
"""Get all models supported by LiteLLM.
|
||||
async def get_llm_models_dependency(request: Request) -> list[str]:
|
||||
"""Returns a callable that provides the LLM models implementation.
|
||||
|
||||
This function combines models from litellm and Bedrock, removing any
|
||||
error-prone Bedrock models. In SaaS mode, it uses database-backed
|
||||
verified models for dynamic updates without code deployments.
|
||||
|
||||
To get the models:
|
||||
```sh
|
||||
curl http://localhost:3000/api/litellm-models
|
||||
```
|
||||
|
||||
Returns:
|
||||
list[str]: A sorted list of unique model names.
|
||||
Returns a factory that produces the actual implementation function.
|
||||
Override this in enterprise/saas mode via app.dependency_overrides.
|
||||
"""
|
||||
verified_models = _load_verified_models_from_db()
|
||||
return get_supported_llm_models(config, verified_models)
|
||||
|
||||
return get_supported_llm_models(config, [])
|
||||
|
||||
|
||||
def _load_verified_models_from_db() -> list[str] | None:
|
||||
"""Try to load verified models from the database (SaaS mode only).
|
||||
|
||||
Returns:
|
||||
List of model strings like 'provider/model_name' if available, None otherwise.
|
||||
"""
|
||||
try:
|
||||
from storage.verified_model_store import VerifiedModelStore
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
try:
|
||||
db_models = VerifiedModelStore.get_enabled_models()
|
||||
return [f'{m.provider}/{m.model_name}' for m in db_models]
|
||||
except Exception:
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
logger.exception('Failed to load verified models from database')
|
||||
return None
|
||||
@app.get('/models')
|
||||
async def get_litellm_models(
|
||||
models: list[str] = Depends(get_llm_models_dependency),
|
||||
) -> list[str]:
|
||||
return models
|
||||
|
||||
|
||||
@app.get('/agents', response_model=list[str])
|
||||
|
||||
Reference in New Issue
Block a user