mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 05:37:20 +08:00
SAAS: Introducing orgs (phase 1) (#11265)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: rohitvinodmalhotra@gmail.com <rohitvinodmalhotra@gmail.com> Co-authored-by: Hiep Le <69354317+hieptl@users.noreply.github.com> Co-authored-by: Tim O'Farrell <tofarr@gmail.com>
This commit is contained in:
@@ -2,7 +2,7 @@ BACKEND_HOST ?= "127.0.0.1"
|
||||
BACKEND_PORT = 3000
|
||||
BACKEND_HOST_PORT = "$(BACKEND_HOST):$(BACKEND_PORT)"
|
||||
FRONTEND_PORT = 3001
|
||||
OPENHANDS_PATH ?= "../../OpenHands"
|
||||
OPENHANDS_PATH ?= ".."
|
||||
OPENHANDS := $(OPENHANDS_PATH)
|
||||
OPENHANDS_FRONTEND_PATH = $(OPENHANDS)/frontend/build
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
|
||||
from integrations.utils import store_repositories_in_db
|
||||
from integrations.store_repo_utils import store_repositories_in_db
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
|
||||
|
||||
@@ -25,9 +25,9 @@ from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from storage.database import session_maker
|
||||
from storage.org_store import OrgStore
|
||||
from storage.proactive_conversation_store import ProactiveConversationStore
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
|
||||
from openhands.agent_server.models import SendMessageRequest
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
@@ -78,19 +78,17 @@ async def get_user_proactive_conversation_setting(user_id: str | None) -> bool:
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
settings = await call_sync_from_async(
|
||||
settings_store.get_user_settings_by_keycloak_id, user_id
|
||||
)
|
||||
|
||||
if not settings or settings.enable_proactive_conversation_starters is None:
|
||||
# Check global setting first - if disabled globally, return False
|
||||
if not ENABLE_PROACTIVE_CONVERSATION_STARTERS:
|
||||
return False
|
||||
|
||||
return settings.enable_proactive_conversation_starters
|
||||
def _get_setting():
|
||||
org = OrgStore.get_current_org_from_keycloak_user_id(user_id)
|
||||
if not org:
|
||||
return False
|
||||
return bool(org.enable_proactive_conversation_starters)
|
||||
|
||||
return await call_sync_from_async(_get_setting)
|
||||
|
||||
|
||||
# =================================================
|
||||
@@ -151,6 +149,7 @@ class GithubIssue(ResolverViewInterface):
|
||||
issue_body=self.description,
|
||||
previous_comments=self.previous_comments,
|
||||
)
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def _get_user_secrets(self):
|
||||
@@ -189,6 +188,7 @@ class GithubIssue(ResolverViewInterface):
|
||||
conversation_trigger=ConversationTrigger.RESOLVER,
|
||||
git_provider=ProviderType.GITHUB,
|
||||
)
|
||||
|
||||
self.conversation_id = conversation_metadata.conversation_id
|
||||
return conversation_metadata
|
||||
|
||||
@@ -327,7 +327,6 @@ class GithubIssueComment(GithubIssue):
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'issue_conversation_instructions.j2'
|
||||
)
|
||||
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
issue_number=self.issue_number,
|
||||
issue_title=self.title,
|
||||
@@ -397,7 +396,6 @@ class GithubInlinePRComment(GithubPRComment):
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'pr_update_conversation_instructions.j2'
|
||||
)
|
||||
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
pr_number=self.issue_number,
|
||||
pr_title=self.title,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
|
||||
from integrations.store_repo_utils import store_repositories_in_db
|
||||
from integrations.types import GitLabResourceType
|
||||
from integrations.utils import store_repositories_in_db
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
|
||||
|
||||
@@ -189,6 +189,7 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
'channel_id': self.channel_id,
|
||||
'conversation_id': self.conversation_id,
|
||||
'keycloak_user_id': user_info.keycloak_user_id,
|
||||
'org_id': user_info.org_id,
|
||||
'parent_id': self.thread_ts or self.message_ts,
|
||||
'v1_enabled': v1_enabled,
|
||||
},
|
||||
@@ -197,6 +198,7 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
conversation_id=self.conversation_id,
|
||||
channel_id=self.channel_id,
|
||||
keycloak_user_id=user_info.keycloak_user_id,
|
||||
org_id=user_info.org_id,
|
||||
parent_id=self.thread_ts
|
||||
or self.message_ts, # conversations can start in a thread reply as well; we should always references the parent's (root level msg's) message ID
|
||||
v1_enabled=v1_enabled,
|
||||
@@ -399,10 +401,10 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
user_msg, _ = self._get_instructions(jinja)
|
||||
user_msg_action = MessageAction(content=user_msg)
|
||||
instructions, _ = self._get_instructions(jinja)
|
||||
user_msg = MessageAction(content=instructions)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_msg_action)
|
||||
self.conversation_id, event_to_dict(user_msg)
|
||||
)
|
||||
|
||||
async def send_message_to_v1_conversation(self, jinja: Environment):
|
||||
|
||||
53
enterprise/integrations/store_repo_utils.py
Normal file
53
enterprise/integrations/store_repo_utils.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from storage.repository_store import RepositoryStore
|
||||
from storage.stored_repository import StoredRepository
|
||||
from storage.user_repo_map import UserRepositoryMap
|
||||
from storage.user_repo_map_store import UserRepositoryMapStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.service_types import Repository
|
||||
|
||||
|
||||
async def store_repositories_in_db(repos: list[Repository], user_id: str) -> None:
|
||||
"""
|
||||
Store repositories in DB and create user-repository mappings
|
||||
|
||||
Args:
|
||||
repos: List of Repository objects to store
|
||||
user_id: User ID associated with these repositories
|
||||
"""
|
||||
|
||||
# Convert Repository objects to StoredRepository objects
|
||||
# Convert Repository objects to UserRepositoryMap objects
|
||||
stored_repos = []
|
||||
user_repos = []
|
||||
for repo in repos:
|
||||
repo_id = f'{repo.git_provider.value}##{str(repo.id)}'
|
||||
stored_repo = StoredRepository(
|
||||
repo_name=repo.full_name,
|
||||
repo_id=repo_id,
|
||||
is_public=repo.is_public,
|
||||
# Optional fields set to None by default
|
||||
has_microagent=None,
|
||||
has_setup_script=None,
|
||||
)
|
||||
stored_repos.append(stored_repo)
|
||||
user_repo_map = UserRepositoryMap(user_id=user_id, repo_id=repo_id, admin=None)
|
||||
|
||||
user_repos.append(user_repo_map)
|
||||
|
||||
# Get config instance
|
||||
config = OpenHandsConfig()
|
||||
|
||||
try:
|
||||
# Store repositories in the repos table
|
||||
repo_store = RepositoryStore.get_instance(config)
|
||||
repo_store.store_projects(stored_repos)
|
||||
|
||||
# Store user-repository mappings in the user-repos table
|
||||
user_repo_store = UserRepositoryMapStore.get_instance(config)
|
||||
user_repo_store.store_user_repo_mappings(user_repos)
|
||||
|
||||
logger.info(f'Saved repos for user {user_id}')
|
||||
except Exception:
|
||||
logger.warning('Failed to save repos', exc_info=True)
|
||||
@@ -1,19 +1,24 @@
|
||||
from uuid import UUID
|
||||
|
||||
import stripe
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import STRIPE_API_KEY
|
||||
from server.logger import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from storage.database import session_maker
|
||||
from storage.org import Org
|
||||
from storage.org_store import OrgStore
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
stripe.api_key = STRIPE_API_KEY
|
||||
|
||||
|
||||
async def find_customer_id_by_user_id(user_id: str) -> str | None:
|
||||
# First search our own DB...
|
||||
async def find_customer_id_by_org_id(org_id: UUID) -> str | None:
|
||||
with session_maker() as session:
|
||||
stripe_customer = (
|
||||
session.query(StripeCustomer)
|
||||
.filter(StripeCustomer.keycloak_user_id == user_id)
|
||||
.filter(StripeCustomer.org_id == org_id)
|
||||
.first()
|
||||
)
|
||||
if stripe_customer:
|
||||
@@ -21,46 +26,76 @@ async def find_customer_id_by_user_id(user_id: str) -> str | None:
|
||||
|
||||
# If that fails, fallback to stripe
|
||||
search_result = await stripe.Customer.search_async(
|
||||
query=f"metadata['user_id']:'{user_id}'",
|
||||
query=f"metadata['org_id']:'{str(org_id)}'",
|
||||
)
|
||||
data = search_result.data
|
||||
if not data:
|
||||
logger.info('no_customer_for_user_id', extra={'user_id': user_id})
|
||||
logger.info(
|
||||
'no_customer_for_org_id',
|
||||
extra={'org_id': str(org_id)},
|
||||
)
|
||||
return None
|
||||
return data[0].id # type: ignore [attr-defined]
|
||||
|
||||
|
||||
async def find_or_create_customer(user_id: str) -> str:
|
||||
customer_id = await find_customer_id_by_user_id(user_id)
|
||||
if customer_id:
|
||||
return customer_id
|
||||
logger.info('creating_customer', extra={'user_id': user_id})
|
||||
async def find_customer_id_by_user_id(user_id: str) -> str | None:
|
||||
# First search our own DB...
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
)
|
||||
if not org:
|
||||
logger.warning(f'Org not found for user {user_id}')
|
||||
return None
|
||||
customer_id = await find_customer_id_by_org_id(org.id)
|
||||
return customer_id
|
||||
|
||||
# Get the user info from keycloak
|
||||
token_manager = TokenManager()
|
||||
user_info = await token_manager.get_user_info_from_user_id(user_id) or {}
|
||||
|
||||
async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
|
||||
# Get the current org for the user
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
)
|
||||
if not org:
|
||||
logger.warning(f'Org not found for user {user_id}')
|
||||
return None
|
||||
|
||||
customer_id = await find_customer_id_by_org_id(org.id)
|
||||
if customer_id:
|
||||
return {'customer_id': customer_id, 'org_id': str(org.id)}
|
||||
logger.info(
|
||||
'creating_customer',
|
||||
extra={'user_id': user_id, 'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
# Create the customer in stripe
|
||||
customer = await stripe.Customer.create_async(
|
||||
email=str(user_info.get('email', '')),
|
||||
metadata={'user_id': user_id},
|
||||
email=org.contact_email,
|
||||
metadata={'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
# Save the stripe customer in the local db
|
||||
with session_maker() as session:
|
||||
session.add(
|
||||
StripeCustomer(keycloak_user_id=user_id, stripe_customer_id=customer.id)
|
||||
StripeCustomer(
|
||||
keycloak_user_id=user_id,
|
||||
org_id=org.id,
|
||||
stripe_customer_id=customer.id,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'created_customer',
|
||||
extra={'user_id': user_id, 'stripe_customer_id': customer.id},
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
)
|
||||
return customer.id
|
||||
return {'customer_id': customer.id, 'org_id': str(org.id)}
|
||||
|
||||
|
||||
async def has_payment_method(user_id: str) -> bool:
|
||||
async def has_payment_method_by_user_id(user_id: str) -> bool:
|
||||
customer_id = await find_customer_id_by_user_id(user_id)
|
||||
if customer_id is None:
|
||||
return False
|
||||
@@ -71,3 +106,28 @@ async def has_payment_method(user_id: str) -> bool:
|
||||
f'has_payment_method:{user_id}:{customer_id}:{bool(payment_methods.data)}'
|
||||
)
|
||||
return bool(payment_methods.data)
|
||||
|
||||
|
||||
async def migrate_customer(session: Session, user_id: str, org: Org):
|
||||
stripe_customer = (
|
||||
session.query(StripeCustomer)
|
||||
.filter(StripeCustomer.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if stripe_customer is None:
|
||||
return
|
||||
stripe_customer.org_id = org.id
|
||||
customer = await stripe.Customer.modify_async(
|
||||
id=stripe_customer.stripe_customer_id,
|
||||
email=org.contact_email,
|
||||
metadata={'user_id': '', 'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'migrated_customer',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -6,15 +6,9 @@ import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from server.config import get_config
|
||||
from server.constants import WEB_HOST
|
||||
from storage.database import session_maker
|
||||
from storage.repository_store import RepositoryStore
|
||||
from storage.stored_repository import StoredRepository
|
||||
from storage.user_repo_map import UserRepositoryMap
|
||||
from storage.user_repo_map_store import UserRepositoryMapStore
|
||||
from storage.org_store import OrgStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events import Event, EventSource
|
||||
@@ -125,26 +119,17 @@ async def get_user_v1_enabled_setting(user_id: str | None) -> bool:
|
||||
Returns:
|
||||
True if V1 conversations are enabled for this user, False otherwise
|
||||
"""
|
||||
|
||||
# If no user ID is provided, we can't check user settings
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
)
|
||||
|
||||
settings = await call_sync_from_async(
|
||||
settings_store.get_user_settings_by_keycloak_id, user_id
|
||||
)
|
||||
|
||||
if not settings or settings.v1_enabled is None:
|
||||
if not org or org.v1_enabled is None:
|
||||
return False
|
||||
|
||||
return settings.v1_enabled
|
||||
return org.v1_enabled
|
||||
|
||||
|
||||
def has_exact_mention(text: str, mention: str) -> bool:
|
||||
@@ -409,51 +394,6 @@ def append_conversation_footer(message: str, conversation_id: str) -> str:
|
||||
return message + footer
|
||||
|
||||
|
||||
async def store_repositories_in_db(repos: list[Repository], user_id: str) -> None:
|
||||
"""
|
||||
Store repositories in DB and create user-repository mappings
|
||||
|
||||
Args:
|
||||
repos: List of Repository objects to store
|
||||
user_id: User ID associated with these repositories
|
||||
"""
|
||||
|
||||
# Convert Repository objects to StoredRepository objects
|
||||
# Convert Repository objects to UserRepositoryMap objects
|
||||
stored_repos = []
|
||||
user_repos = []
|
||||
for repo in repos:
|
||||
repo_id = f'{repo.git_provider.value}##{str(repo.id)}'
|
||||
stored_repo = StoredRepository(
|
||||
repo_name=repo.full_name,
|
||||
repo_id=repo_id,
|
||||
is_public=repo.is_public,
|
||||
# Optional fields set to None by default
|
||||
has_microagent=None,
|
||||
has_setup_script=None,
|
||||
)
|
||||
stored_repos.append(stored_repo)
|
||||
user_repo_map = UserRepositoryMap(user_id=user_id, repo_id=repo_id, admin=None)
|
||||
|
||||
user_repos.append(user_repo_map)
|
||||
|
||||
# Get config instance
|
||||
config = OpenHandsConfig()
|
||||
|
||||
try:
|
||||
# Store repositories in the repos table
|
||||
repo_store = RepositoryStore.get_instance(config)
|
||||
repo_store.store_projects(stored_repos)
|
||||
|
||||
# Store user-repository mappings in the user-repos table
|
||||
user_repo_store = UserRepositoryMapStore.get_instance(config)
|
||||
user_repo_store.store_user_repo_mappings(user_repos)
|
||||
|
||||
logger.info(f'Saved repos for user {user_id}')
|
||||
except Exception:
|
||||
logger.warning('Failed to save repos', exc_info=True)
|
||||
|
||||
|
||||
def infer_repo_from_message(user_msg: str) -> list[str]:
|
||||
"""
|
||||
Extract all repository names in the format 'owner/repo' from various Git provider URLs
|
||||
|
||||
@@ -20,6 +20,8 @@ down_revision = '059'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# TODO: decide whether to modify this for orgs or users
|
||||
|
||||
|
||||
def upgrade():
|
||||
"""
|
||||
@@ -28,8 +30,10 @@ def upgrade():
|
||||
|
||||
This replaces the functionality of the removed admin maintenance endpoint.
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from server.constants import CURRENT_USER_SETTINGS_VERSION
|
||||
|
||||
# Hardcoded value to prevent migration failures when constant is removed from codebase
|
||||
# This migration has already run in production, so we use the value that was current at the time
|
||||
CURRENT_USER_SETTINGS_VERSION = 4
|
||||
|
||||
# Create a connection and bind it to a session
|
||||
connection = op.get_bind()
|
||||
|
||||
272
enterprise/migrations/versions/089_create_org_tables.py
Normal file
272
enterprise/migrations/versions/089_create_org_tables.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""create org tables from pgerd schema
|
||||
|
||||
Revision ID: 089
|
||||
Revises: 088
|
||||
Create Date: 2025-01-07 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '089'
|
||||
down_revision: Union[str, None] = '088'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Remove current settings table
|
||||
op.execute('DROP TABLE IF EXISTS settings')
|
||||
|
||||
# Add already_migrated column to user_settings table
|
||||
op.add_column(
|
||||
'user_settings',
|
||||
sa.Column(
|
||||
'already_migrated',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
)
|
||||
|
||||
# Create role table
|
||||
op.create_table(
|
||||
'role',
|
||||
sa.Column('id', sa.Integer, sa.Identity(), primary_key=True),
|
||||
sa.Column('name', sa.String, nullable=False),
|
||||
sa.Column('rank', sa.Integer, nullable=False),
|
||||
sa.UniqueConstraint('name', name='role_name_unique'),
|
||||
)
|
||||
|
||||
# 1. Create default roles
|
||||
op.execute(
|
||||
sa.text("""
|
||||
INSERT INTO role (name, rank) VALUES ('owner', 10), ('admin', 20), ('user', 1000)
|
||||
ON CONFLICT (name) DO NOTHING;
|
||||
""")
|
||||
)
|
||||
|
||||
# Create org table with settings fields
|
||||
op.create_table(
|
||||
'org',
|
||||
sa.Column(
|
||||
'id',
|
||||
postgresql.UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('name', sa.String, nullable=False),
|
||||
sa.Column('contact_name', sa.String, nullable=True),
|
||||
sa.Column('contact_email', sa.String, nullable=True),
|
||||
sa.Column('conversation_expiration', sa.Integer, nullable=True),
|
||||
# Settings fields moved to org table
|
||||
sa.Column('agent', sa.String, nullable=True),
|
||||
sa.Column('default_max_iterations', sa.Integer, nullable=True),
|
||||
sa.Column('security_analyzer', sa.String, nullable=True),
|
||||
sa.Column(
|
||||
'confirmation_mode',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
sa.Column('default_llm_model', sa.String, nullable=True),
|
||||
sa.Column('default_llm_base_url', sa.String, nullable=True),
|
||||
sa.Column('remote_runtime_resource_factor', sa.Integer, nullable=True),
|
||||
sa.Column(
|
||||
'enable_default_condenser',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('true'),
|
||||
),
|
||||
sa.Column('billing_margin', sa.Float, nullable=True),
|
||||
sa.Column(
|
||||
'enable_proactive_conversation_starters',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('true'),
|
||||
),
|
||||
sa.Column('sandbox_base_container_image', sa.String, nullable=True),
|
||||
sa.Column('sandbox_runtime_container_image', sa.String, nullable=True),
|
||||
sa.Column(
|
||||
'org_version', sa.Integer, nullable=False, server_default=sa.text('0')
|
||||
),
|
||||
sa.Column('mcp_config', sa.JSON, nullable=True),
|
||||
sa.Column('_search_api_key', sa.String, nullable=True),
|
||||
sa.Column('_sandbox_api_key', sa.String, nullable=True),
|
||||
sa.Column('max_budget_per_task', sa.Float, nullable=True),
|
||||
sa.Column(
|
||||
'enable_solvability_analysis',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
sa.Column('v1_enabled', sa.Boolean, nullable=True),
|
||||
sa.Column('condenser_max_size', sa.Integer, nullable=True),
|
||||
sa.UniqueConstraint('name', name='org_name_unique'),
|
||||
)
|
||||
|
||||
# Create user table with user-specific settings fields
|
||||
op.create_table(
|
||||
'user',
|
||||
sa.Column(
|
||||
'id',
|
||||
postgresql.UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('current_org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('role_id', sa.Integer, nullable=True),
|
||||
sa.Column('accepted_tos', sa.DateTime, nullable=True),
|
||||
sa.Column(
|
||||
'enable_sound_notifications',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
sa.Column('language', sa.String, nullable=True),
|
||||
sa.Column('user_consents_to_analytics', sa.Boolean, nullable=True),
|
||||
sa.Column('email', sa.String, nullable=True),
|
||||
sa.Column('email_verified', sa.Boolean, nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
['current_org_id'], ['org.id'], name='current_org_fkey'
|
||||
),
|
||||
sa.ForeignKeyConstraint(['role_id'], ['role.id'], name='user_role_fkey'),
|
||||
)
|
||||
|
||||
# Create org_member table (junction table for many-to-many relationship)
|
||||
op.create_table(
|
||||
'org_member',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('role_id', sa.Integer, nullable=False),
|
||||
sa.Column('_llm_api_key', sa.String, nullable=False),
|
||||
sa.Column('max_iterations', sa.Integer, nullable=True),
|
||||
sa.Column('llm_model', sa.String, nullable=True),
|
||||
sa.Column('_llm_api_key_for_byor', sa.String, nullable=True),
|
||||
sa.Column('llm_base_url', sa.String, nullable=True),
|
||||
sa.Column('status', sa.String, nullable=True),
|
||||
sa.ForeignKeyConstraint(['org_id'], ['org.id'], name='om_org_fkey'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], name='om_user_fkey'),
|
||||
sa.ForeignKeyConstraint(['role_id'], ['role.id'], name='om_role_fkey'),
|
||||
sa.PrimaryKeyConstraint('org_id', 'user_id'),
|
||||
)
|
||||
|
||||
# Add org_id column to existing tables
|
||||
# billing_sessions
|
||||
op.add_column(
|
||||
'billing_sessions',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'billing_sessions_org_fkey', 'billing_sessions', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# Create conversation_metadata_saas table
|
||||
op.create_table(
|
||||
'conversation_metadata_saas',
|
||||
sa.Column('conversation_id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
['user_id'], ['user.id'], name='conversation_metadata_saas_user_fkey'
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['org_id'], ['org.id'], name='conversation_metadata_saas_org_fkey'
|
||||
),
|
||||
sa.PrimaryKeyConstraint('conversation_id'),
|
||||
)
|
||||
|
||||
# custom_secrets
|
||||
op.add_column(
|
||||
'custom_secrets',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'custom_secrets_org_fkey', 'custom_secrets', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# api_keys
|
||||
op.add_column(
|
||||
'api_keys', sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True)
|
||||
)
|
||||
op.create_foreign_key('api_keys_org_fkey', 'api_keys', 'org', ['org_id'], ['id'])
|
||||
|
||||
# slack_conversation
|
||||
op.add_column(
|
||||
'slack_conversation',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'slack_conversation_org_fkey', 'slack_conversation', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# slack_users
|
||||
op.add_column(
|
||||
'slack_users', sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True)
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'slack_users_org_fkey', 'slack_users', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# stripe_customers
|
||||
op.alter_column(
|
||||
'stripe_customers',
|
||||
'keycloak_user_id',
|
||||
existing_type=sa.String(),
|
||||
nullable=True,
|
||||
)
|
||||
op.add_column(
|
||||
'stripe_customers',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'stripe_customers_org_fkey', 'stripe_customers', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop already_migrated column from user_settings table
|
||||
op.drop_column('user_settings', 'already_migrated')
|
||||
|
||||
# Drop foreign keys and columns added to existing tables
|
||||
op.drop_constraint(
|
||||
'stripe_customers_org_fkey', 'stripe_customers', type_='foreignkey'
|
||||
)
|
||||
op.drop_column('stripe_customers', 'org_id')
|
||||
op.alter_column(
|
||||
'stripe_customers',
|
||||
'keycloak_user_id',
|
||||
existing_type=sa.String(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
op.drop_constraint('slack_users_org_fkey', 'slack_users', type_='foreignkey')
|
||||
op.drop_column('slack_users', 'org_id')
|
||||
|
||||
op.drop_constraint(
|
||||
'slack_conversation_org_fkey', 'slack_conversation', type_='foreignkey'
|
||||
)
|
||||
op.drop_column('slack_conversation', 'org_id')
|
||||
|
||||
op.drop_constraint('api_keys_org_fkey', 'api_keys', type_='foreignkey')
|
||||
op.drop_column('api_keys', 'org_id')
|
||||
|
||||
op.drop_constraint('custom_secrets_org_fkey', 'custom_secrets', type_='foreignkey')
|
||||
op.drop_column('custom_secrets', 'org_id')
|
||||
|
||||
# Drop conversation_metadata_saas table
|
||||
op.drop_table('conversation_metadata_saas')
|
||||
|
||||
op.drop_constraint(
|
||||
'billing_sessions_org_fkey', 'billing_sessions', type_='foreignkey'
|
||||
)
|
||||
op.drop_column('billing_sessions', 'org_id')
|
||||
|
||||
# Drop tables in reverse order due to foreign key constraints
|
||||
op.drop_table('org_member')
|
||||
op.drop_table('user')
|
||||
op.drop_table('org')
|
||||
op.drop_table('role')
|
||||
11589
enterprise/poetry.lock
generated
11589
enterprise/poetry.lock
generated
File diff suppressed because one or more lines are too long
@@ -4,6 +4,10 @@ from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Ensure SAAS configuration is used
|
||||
if not os.getenv('OPENHANDS_CONFIG_CLS'):
|
||||
os.environ['OPENHANDS_CONFIG_CLS'] = 'server.config.SaaSServerConfig'
|
||||
|
||||
import socketio # noqa: E402
|
||||
from fastapi import Request, status # noqa: E402
|
||||
from fastapi.middleware.cors import CORSMiddleware # noqa: E402
|
||||
|
||||
@@ -103,7 +103,6 @@ class SaasUserAuth(UserAuth):
|
||||
return settings
|
||||
settings_store = await self.get_user_settings_store()
|
||||
settings = await settings_store.load()
|
||||
# If load() returned None, should settings be created?
|
||||
if settings:
|
||||
settings.email = self.email
|
||||
settings.email_verified = self.email_verified
|
||||
|
||||
@@ -8,7 +8,7 @@ import socketio
|
||||
from server.logger import logger
|
||||
from server.utils.conversation_callback_utils import invoke_conversation_callbacks
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
@@ -524,16 +524,18 @@ class ClusteredConversationManager(StandaloneConversationManager):
|
||||
)
|
||||
# Look up the user_id from the database
|
||||
with session_maker() as session:
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadata.conversation_id
|
||||
StoredConversationMetadataSaas.conversation_id
|
||||
== conversation_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
user_id = (
|
||||
conversation_metadata.user_id if conversation_metadata else None
|
||||
str(conversation_metadata_saas.user_id)
|
||||
if conversation_metadata_saas
|
||||
else None
|
||||
)
|
||||
# Handle the stopped conversation asynchronously
|
||||
asyncio.create_task(
|
||||
|
||||
@@ -19,8 +19,8 @@ IS_LOCAL_ENV = bool(HOST == 'localhost')
|
||||
DEFAULT_BILLING_MARGIN = float(os.environ.get('DEFAULT_BILLING_MARGIN', '1.0'))
|
||||
|
||||
# Map of user settings versions to their corresponding default LLM models
|
||||
# This ensures that CURRENT_USER_SETTINGS_VERSION and LITELLM_DEFAULT_MODEL stay in sync
|
||||
USER_SETTINGS_VERSION_TO_MODEL = {
|
||||
# This ensures that PERSONAL_WORKSPACE_VERSION_TO_MODEL and LITELLM_DEFAULT_MODEL stay in sync
|
||||
PERSONAL_WORKSPACE_VERSION_TO_MODEL = {
|
||||
1: 'claude-3-5-sonnet-20241022',
|
||||
2: 'claude-3-7-sonnet-20250219',
|
||||
3: 'claude-sonnet-4-20250514',
|
||||
@@ -31,7 +31,8 @@ USER_SETTINGS_VERSION_TO_MODEL = {
|
||||
LITELLM_DEFAULT_MODEL = os.getenv('LITELLM_DEFAULT_MODEL')
|
||||
|
||||
# Current user settings version - this should be the latest key in USER_SETTINGS_VERSION_TO_MODEL
|
||||
CURRENT_USER_SETTINGS_VERSION = max(USER_SETTINGS_VERSION_TO_MODEL.keys())
|
||||
ORG_SETTINGS_VERSION = max(PERSONAL_WORKSPACE_VERSION_TO_MODEL.keys())
|
||||
PERSONAL_WORKSPACE_VERSION = max(PERSONAL_WORKSPACE_VERSION_TO_MODEL.keys())
|
||||
|
||||
LITE_LLM_API_URL = os.environ.get(
|
||||
'LITE_LLM_API_URL', 'https://llm-proxy.app.all-hands.dev'
|
||||
@@ -55,7 +56,6 @@ SUBSCRIPTION_PRICE_DATA = {
|
||||
|
||||
DEFAULT_INITIAL_BUDGET = float(os.environ.get('DEFAULT_INITIAL_BUDGET', '10'))
|
||||
STRIPE_API_KEY = os.environ.get('STRIPE_API_KEY', None)
|
||||
STRIPE_WEBHOOK_SECRET = os.environ.get('STRIPE_WEBHOOK_SECRET', None)
|
||||
REQUIRE_PAYMENT = os.environ.get('REQUIRE_PAYMENT', '0') in ('1', 'true')
|
||||
|
||||
SLACK_CLIENT_ID = os.environ.get('SLACK_CLIENT_ID', None)
|
||||
@@ -91,5 +91,5 @@ def get_default_litellm_model():
|
||||
"""Construct proxy for litellm model based on user settings if not set explicitly."""
|
||||
if LITELLM_DEFAULT_MODEL:
|
||||
return LITELLM_DEFAULT_MODEL
|
||||
model = USER_SETTINGS_VERSION_TO_MODEL[CURRENT_USER_SETTINGS_VERSION]
|
||||
model = PERSONAL_WORKSPACE_VERSION_TO_MODEL[PERSONAL_WORKSPACE_VERSION]
|
||||
return build_litellm_proxy_model_path(model)
|
||||
|
||||
@@ -44,11 +44,13 @@ class MyProcessor(MaintenanceTaskProcessor):
|
||||
### UserVersionUpgradeProcessor
|
||||
|
||||
Located in `user_version_upgrade_processor.py`, this processor:
|
||||
|
||||
- Handles up to 100 user IDs per task
|
||||
- Upgrades users with `user_version < CURRENT_USER_SETTINGS_VERSION`
|
||||
- Upgrades users with `user_version < ORG_SETTINGS_VERSION`
|
||||
- Uses `SaasSettingsStore.create_default_settings()` for upgrades
|
||||
|
||||
**Usage:**
|
||||
|
||||
```python
|
||||
from server.maintenance_task_processor.user_version_upgrade_processor import UserVersionUpgradeProcessor
|
||||
|
||||
@@ -144,22 +146,26 @@ task = create_maintenance_task(
|
||||
## Best Practices
|
||||
|
||||
### Processor Design
|
||||
|
||||
- Keep tasks short-running (under 1 minute)
|
||||
- Handle errors gracefully and return meaningful error information
|
||||
- Use batch processing for large datasets
|
||||
- Include progress information in the return dict
|
||||
|
||||
### Error Handling
|
||||
|
||||
- Always wrap your processor logic in try-catch blocks
|
||||
- Return structured error information
|
||||
- Log important events for debugging
|
||||
|
||||
### Performance
|
||||
|
||||
- Limit batch sizes to avoid long-running tasks
|
||||
- Use database sessions efficiently
|
||||
- Consider memory usage for large datasets
|
||||
|
||||
### Testing
|
||||
|
||||
- Create unit tests for your processors
|
||||
- Test error conditions
|
||||
- Verify the processor serialization/deserialization works correctly
|
||||
@@ -167,6 +173,7 @@ task = create_maintenance_task(
|
||||
## Database Patterns
|
||||
|
||||
The maintenance task system follows the repository's established patterns:
|
||||
|
||||
- Uses `session_maker()` for database operations
|
||||
- Wraps sync database operations in `call_sync_from_async` for async routes
|
||||
- Follows proper SQLAlchemy query patterns
|
||||
@@ -174,15 +181,18 @@ The maintenance task system follows the repository's established patterns:
|
||||
## Integration with Existing Systems
|
||||
|
||||
### User Management
|
||||
|
||||
- Integrates with the existing `UserSettings` model
|
||||
- Uses the current user versioning system (`CURRENT_USER_SETTINGS_VERSION`)
|
||||
- Uses the current user versioning system (`ORG_SETTINGS_VERSION`)
|
||||
- Maintains compatibility with existing user management workflows
|
||||
|
||||
### Authentication
|
||||
|
||||
- Admin endpoints use the existing SaaS authentication system
|
||||
- Requires users to have `admin = True` in their UserSettings
|
||||
|
||||
### Monitoring
|
||||
|
||||
- Tasks are logged with structured information
|
||||
- Status updates are tracked in the database
|
||||
- Error information is preserved for debugging
|
||||
@@ -206,6 +216,7 @@ The maintenance task system follows the repository's established patterns:
|
||||
## Future Enhancements
|
||||
|
||||
Potential improvements that could be added:
|
||||
|
||||
- Task dependencies and scheduling
|
||||
- Retry mechanisms for failed tasks
|
||||
- Real-time progress updates
|
||||
|
||||
@@ -1,155 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
from server.constants import CURRENT_USER_SETTINGS_VERSION
|
||||
from server.logger import logger
|
||||
from storage.database import session_maker
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskProcessor
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.core.config import load_openhands_config
|
||||
|
||||
|
||||
class UserVersionUpgradeProcessor(MaintenanceTaskProcessor):
|
||||
"""
|
||||
Processor for upgrading user settings to the current version.
|
||||
|
||||
This processor takes a list of user IDs and upgrades any users
|
||||
whose user_version is less than CURRENT_USER_SETTINGS_VERSION.
|
||||
"""
|
||||
|
||||
user_ids: List[str]
|
||||
|
||||
async def __call__(self, task: MaintenanceTask) -> dict:
|
||||
"""
|
||||
Process user version upgrades for the specified user IDs.
|
||||
|
||||
Args:
|
||||
task: The maintenance task being processed
|
||||
|
||||
Returns:
|
||||
dict: Results containing successful and failed user IDs
|
||||
"""
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:start',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_count': len(self.user_ids),
|
||||
'current_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
if len(self.user_ids) > 100:
|
||||
raise ValueError(
|
||||
f'Too many user IDs: {len(self.user_ids)}. Maximum is 100.'
|
||||
)
|
||||
|
||||
config = load_openhands_config()
|
||||
|
||||
# Track results
|
||||
successful_upgrades = []
|
||||
failed_upgrades = []
|
||||
users_already_current = []
|
||||
|
||||
# Find users that need upgrading
|
||||
with session_maker() as session:
|
||||
users_to_upgrade = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id.in_(self.user_ids),
|
||||
UserSettings.user_version < CURRENT_USER_SETTINGS_VERSION,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Track users that are already current
|
||||
users_needing_upgrade_ids = {u.keycloak_user_id for u in users_to_upgrade}
|
||||
users_already_current = [
|
||||
uid for uid in self.user_ids if uid not in users_needing_upgrade_ids
|
||||
]
|
||||
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:found_users',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'users_to_upgrade': len(users_to_upgrade),
|
||||
'users_already_current': len(users_already_current),
|
||||
'total_requested': len(self.user_ids),
|
||||
},
|
||||
)
|
||||
|
||||
# Process each user that needs upgrading
|
||||
for user_settings in users_to_upgrade:
|
||||
user_id = user_settings.keycloak_user_id
|
||||
old_version = user_settings.user_version
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:upgrading_user',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'new_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
# Create SaasSettingsStore instance and upgrade
|
||||
settings_store = await SaasSettingsStore.get_instance(config, user_id)
|
||||
await settings_store.create_default_settings(user_settings)
|
||||
|
||||
successful_upgrades.append(
|
||||
{
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'new_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:user_upgraded',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'new_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
failed_upgrades.append(
|
||||
{'user_id': user_id, 'old_version': old_version, 'error': str(e)}
|
||||
)
|
||||
|
||||
logger.error(
|
||||
'user_version_upgrade_processor:user_upgrade_failed',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
|
||||
# Create result summary
|
||||
result = {
|
||||
'total_users': len(self.user_ids),
|
||||
'users_already_current': users_already_current,
|
||||
'successful_upgrades': successful_upgrades,
|
||||
'failed_upgrades': failed_upgrades,
|
||||
'summary': (
|
||||
f'Processed {len(self.user_ids)} users: '
|
||||
f'{len(successful_upgrades)} upgraded, '
|
||||
f'{len(users_already_current)} already current, '
|
||||
f'{len(failed_upgrades)} errors'
|
||||
),
|
||||
}
|
||||
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:completed',
|
||||
extra={'task_id': task.id, 'result': result},
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -1,7 +1,5 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
|
||||
@@ -36,6 +34,7 @@ class SaaSOpenHandsMCPConfig(OpenHandsMCPConfig):
|
||||
Returns:
|
||||
A tuple containing the default SSE server configuration and a list of MCP stdio server configurations
|
||||
"""
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
if user_id:
|
||||
|
||||
@@ -1,113 +1,95 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, field_validator
|
||||
from server.config import get_config
|
||||
from server.constants import (
|
||||
BYOR_KEY_VERIFICATION_TIMEOUT,
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
)
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
|
||||
# Helper functions for BYOR API key management
|
||||
async def get_byor_key_from_db(user_id: str) -> str | None:
|
||||
"""Get the BYOR key from the database for a user."""
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
user_db_settings = await call_sync_from_async(
|
||||
settings_store.get_user_settings_by_keycloak_id, user_id
|
||||
)
|
||||
if user_db_settings and user_db_settings.llm_api_key_for_byor:
|
||||
return user_db_settings.llm_api_key_for_byor
|
||||
return None
|
||||
def _get_byor_key():
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
current_org_id = user.current_org_id
|
||||
current_org_member: OrgMember = None
|
||||
for org_member in user.org_members:
|
||||
if org_member.org_id == current_org_id:
|
||||
current_org_member = org_member
|
||||
break
|
||||
if not current_org_member:
|
||||
return None
|
||||
if current_org_member.llm_api_key_for_byor:
|
||||
return current_org_member.llm_api_key_for_byor.get_secret_value()
|
||||
return None
|
||||
|
||||
return await call_sync_from_async(_get_byor_key)
|
||||
|
||||
|
||||
async def store_byor_key_in_db(user_id: str, key: str) -> None:
|
||||
"""Store the BYOR key in the database for a user."""
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
def _update_user_settings():
|
||||
with session_maker() as session:
|
||||
user_db_settings = settings_store.get_user_settings_by_keycloak_id(
|
||||
user_id, session
|
||||
)
|
||||
if user_db_settings:
|
||||
user_db_settings.llm_api_key_for_byor = key
|
||||
session.commit()
|
||||
logger.info(
|
||||
'Successfully stored BYOR key in user settings',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
'User settings not found when trying to store BYOR key',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
current_org_id = user.current_org_id
|
||||
current_org_member: OrgMember = None
|
||||
for org_member in user.org_members:
|
||||
if org_member.org_id == current_org_id:
|
||||
current_org_member = org_member
|
||||
break
|
||||
if not current_org_member:
|
||||
return None
|
||||
current_org_member.llm_api_key_for_byor = key
|
||||
OrgMemberStore.update_org_member(current_org_member)
|
||||
|
||||
await call_sync_from_async(_update_user_settings)
|
||||
|
||||
|
||||
async def generate_byor_key(user_id: str) -> str | None:
|
||||
"""Generate a new BYOR key for a user."""
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'LiteLLM API configuration not found', extra={'user_id': user_id}
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
) as client:
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/generate',
|
||||
json={
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, user_id)
|
||||
current_org_id = str(user.current_org_id)
|
||||
if not user:
|
||||
return None
|
||||
key = await LiteLlmManager.generate_key(
|
||||
user_id,
|
||||
current_org_id,
|
||||
f'BYOR Key - user {user_id}, org {current_org_id}',
|
||||
{'type': 'byor'},
|
||||
)
|
||||
|
||||
if key:
|
||||
logger.info(
|
||||
'Successfully generated new BYOR key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'metadata': {'type': 'byor'},
|
||||
'key_alias': f'BYOR Key - user {user_id}',
|
||||
'key_length': len(key) if key else 0,
|
||||
'key_prefix': key[:10] + '...' if key and len(key) > 10 else key,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
key = response_json.get('key')
|
||||
|
||||
if key:
|
||||
logger.info(
|
||||
'Successfully generated new BYOR key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'key_length': len(key) if key else 0,
|
||||
'key_prefix': key[:10] + '...'
|
||||
if key and len(key) > 10
|
||||
else key,
|
||||
},
|
||||
)
|
||||
return key
|
||||
else:
|
||||
logger.error(
|
||||
'Failed to generate BYOR LLM API key - no key in response',
|
||||
extra={'user_id': user_id, 'response_json': response_json},
|
||||
)
|
||||
return None
|
||||
return key
|
||||
else:
|
||||
logger.error(
|
||||
'Failed to generate BYOR LLM API key - no key in response',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error generating BYOR key',
|
||||
@@ -116,96 +98,16 @@ async def generate_byor_key(user_id: str) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
async def verify_byor_key_in_litellm(byor_key: str, user_id: str) -> bool:
|
||||
"""Verify that a BYOR key is valid in LiteLLM by making a lightweight API call.
|
||||
|
||||
Args:
|
||||
byor_key: The BYOR key to verify
|
||||
user_id: The user ID for logging purposes
|
||||
|
||||
Returns:
|
||||
True if the key is verified as valid, False if verification fails or key is invalid.
|
||||
Returns False on network errors/timeouts to ensure we don't return potentially invalid keys.
|
||||
"""
|
||||
if not (LITE_LLM_API_URL and byor_key):
|
||||
return False
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
timeout=BYOR_KEY_VERIFICATION_TIMEOUT,
|
||||
) as client:
|
||||
# Make a lightweight request to verify the key
|
||||
# Using /v1/models endpoint as it's lightweight and requires authentication
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/v1/models',
|
||||
headers={
|
||||
'Authorization': f'Bearer {byor_key}',
|
||||
},
|
||||
)
|
||||
|
||||
# Only 200 status code indicates valid key
|
||||
if response.status_code == 200:
|
||||
logger.debug(
|
||||
'BYOR key verification successful',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True
|
||||
|
||||
# All other status codes (401, 403, 500, etc.) are treated as invalid
|
||||
# This includes authentication errors and server errors
|
||||
logger.warning(
|
||||
'BYOR key verification failed - treating as invalid',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'status_code': response.status_code,
|
||||
'key_prefix': byor_key[:10] + '...'
|
||||
if len(byor_key) > 10
|
||||
else byor_key,
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
except (httpx.TimeoutException, Exception) as e:
|
||||
# Any exception (timeout, network error, etc.) means we can't verify
|
||||
# Return False to trigger regeneration rather than returning potentially invalid key
|
||||
logger.warning(
|
||||
'BYOR key verification error - treating as invalid to ensure key validity',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
'error_type': type(e).__name__,
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def delete_byor_key_from_litellm(user_id: str, byor_key: str) -> bool:
|
||||
"""Delete the BYOR key from LiteLLM using the key directly."""
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'LiteLLM API configuration not found', extra={'user_id': user_id}
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
) as client:
|
||||
# Delete the key directly using the key value
|
||||
delete_url = f'{LITE_LLM_API_URL}/key/delete'
|
||||
delete_payload = {'keys': [byor_key]}
|
||||
|
||||
delete_response = await client.post(delete_url, json=delete_payload)
|
||||
delete_response.raise_for_status()
|
||||
logger.info(
|
||||
'Successfully deleted BYOR key from LiteLLM',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True
|
||||
await LiteLlmManager.delete_key(byor_key)
|
||||
logger.info(
|
||||
'Successfully deleted BYOR key from LiteLLM',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error deleting BYOR key from LiteLLM',
|
||||
@@ -357,7 +259,7 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
byor_key = await get_byor_key_from_db(user_id)
|
||||
if byor_key:
|
||||
# Validate that the key is actually registered in LiteLLM
|
||||
is_valid = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
is_valid = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
if is_valid:
|
||||
return {'key': byor_key}
|
||||
else:
|
||||
@@ -412,15 +314,6 @@ async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
logger.info('Starting BYOR LLM API key refresh', extra={'user_id': user_id})
|
||||
|
||||
try:
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'LiteLLM API configuration not found', extra={'user_id': user_id}
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='LiteLLM API configuration not found',
|
||||
)
|
||||
|
||||
# Get the existing BYOR key from the database
|
||||
existing_byor_key = await get_byor_key_from_db(user_id)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
import warnings
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Literal, Optional
|
||||
@@ -22,12 +23,12 @@ from server.auth.gitlab_sync import schedule_gitlab_repo_sync
|
||||
from server.auth.recaptcha_service import recaptcha_service
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config, sign_token
|
||||
from server.config import sign_token
|
||||
from server.constants import IS_FEATURE_ENV
|
||||
from server.routes.event_webhook import _get_session_api_key, _get_user_id
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user import User
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
@@ -36,6 +37,7 @@ from openhands.server.services.conversation_service import create_provider_token
|
||||
from openhands.server.shared import config
|
||||
from openhands.server.user_auth import get_access_token
|
||||
from openhands.server.user_auth.user_auth import get_user_auth
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore')
|
||||
@@ -87,7 +89,8 @@ def get_cookie_domain(request: Request) -> str | None:
|
||||
# for now just use the full hostname except for staging stacks.
|
||||
return (
|
||||
None
|
||||
if (request.url.hostname or '').endswith('staging.all-hand.dev')
|
||||
if not request.url.hostname
|
||||
or request.url.hostname.endswith('staging.all-hands.dev')
|
||||
else request.url.hostname
|
||||
)
|
||||
|
||||
@@ -174,6 +177,20 @@ async def keycloak_callback(
|
||||
|
||||
email = user_info.get('email')
|
||||
user_id = user_info['sub']
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, user_id)
|
||||
if not user:
|
||||
user = await UserStore.create_user(user_id, user_info)
|
||||
|
||||
if not user:
|
||||
logger.error(f'Failed to authenticate user {user_info["preferred_username"]}')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
'error': f'Failed to authenticate user {user_info["preferred_username"]}'
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f'Logging in user {str(user.id)} in org {user.current_org_id}')
|
||||
|
||||
# reCAPTCHA verification with Account Defender
|
||||
if RECAPTCHA_SITE_KEY:
|
||||
@@ -360,15 +377,7 @@ async def keycloak_callback(
|
||||
f'&state={state}'
|
||||
)
|
||||
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
user_settings = settings_store.get_user_settings_by_keycloak_id(user_id)
|
||||
has_accepted_tos = (
|
||||
user_settings is not None and user_settings.accepted_tos is not None
|
||||
)
|
||||
|
||||
has_accepted_tos = user.accepted_tos is not None
|
||||
# If the user hasn't accepted the TOS, redirect to the TOS page
|
||||
if not has_accepted_tos:
|
||||
encoded_redirect_url = quote(redirect_url, safe='')
|
||||
@@ -486,28 +495,20 @@ async def accept_tos(request: Request):
|
||||
redirect_url = body.get('redirect_url', str(request.base_url))
|
||||
|
||||
# Update user settings with TOS acceptance
|
||||
accepted_tos: datetime = datetime.now(timezone.utc)
|
||||
with session_maker() as session:
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if user_settings:
|
||||
user_settings.accepted_tos = datetime.now(timezone.utc)
|
||||
session.merge(user_settings)
|
||||
else:
|
||||
# Create user settings if they don't exist
|
||||
user_settings = UserSettings(
|
||||
keycloak_user_id=user_id,
|
||||
accepted_tos=datetime.now(timezone.utc),
|
||||
user_version=0, # This will trigger a migration to the latest version on next load
|
||||
user = session.query(User).filter(User.id == uuid.UUID(user_id)).first()
|
||||
if not user:
|
||||
session.rollback()
|
||||
logger.error('User for {user_id} not found.')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': 'User does not exist'},
|
||||
)
|
||||
session.add(user_settings)
|
||||
|
||||
user.accepted_tos = accepted_tos
|
||||
session.commit()
|
||||
|
||||
logger.info(f'User {user_id} accepted TOS')
|
||||
logger.info(f'User {user_id} accepted TOS')
|
||||
|
||||
response = JSONResponse(
|
||||
status_code=status.HTTP_200_OK, content={'redirect_url': redirect_url}
|
||||
|
||||
@@ -4,30 +4,23 @@ from datetime import UTC, datetime
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
|
||||
import httpx
|
||||
import stripe
|
||||
from dateutil.relativedelta import relativedelta # type: ignore
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi.responses import RedirectResponse
|
||||
from integrations import stripe_service
|
||||
from pydantic import BaseModel
|
||||
from server.config import get_config
|
||||
from server.constants import (
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
STRIPE_API_KEY,
|
||||
STRIPE_WEBHOOK_SECRET,
|
||||
SUBSCRIPTION_PRICE_DATA,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
stripe.api_key = STRIPE_API_KEY
|
||||
billing_router = APIRouter(prefix='/api/billing')
|
||||
@@ -106,31 +99,20 @@ def calculate_credits(user_info: LiteLlmUserInfo) -> float:
|
||||
return max(max_budget - spend, 0.0)
|
||||
|
||||
|
||||
# Endpoint to retrieve user's current credit balance
|
||||
# Endpoint to retrieve the current organization's credit balance
|
||||
@billing_router.get('/credits')
|
||||
async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse:
|
||||
if not stripe_service.STRIPE_API_KEY:
|
||||
return GetCreditsResponse()
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(), timeout=15.0
|
||||
) as client:
|
||||
user_json = await _get_litellm_user(client, user_id)
|
||||
credits = calculate_credits(user_json['user_info'])
|
||||
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(
|
||||
f'litellm_get_user_failed: {type(e).__name__}: {e}',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'status_code': e.response.status_code,
|
||||
},
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to retrieve credit balance from billing service',
|
||||
)
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, user_id)
|
||||
user_team_info = await LiteLlmManager.get_user_team_info(
|
||||
user_id, str(user.current_org_id)
|
||||
)
|
||||
# Update to use calculate_credits
|
||||
spend = user_team_info.get('spend', 0)
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get('max_budget', 0)
|
||||
credits = max(max_budget - spend, 0)
|
||||
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
|
||||
|
||||
|
||||
# Endpoint to retrieve user's current subscription access
|
||||
@@ -165,79 +147,7 @@ async def get_subscription_access(
|
||||
async def has_payment_method(user_id: str = Depends(get_user_id)) -> bool:
|
||||
if not user_id:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
return await stripe_service.has_payment_method(user_id)
|
||||
|
||||
|
||||
# Endpoint to cancel user's subscription
|
||||
@billing_router.post('/cancel-subscription')
|
||||
async def cancel_subscription(user_id: str = Depends(get_user_id)) -> JSONResponse:
|
||||
"""Cancel user's active subscription at the end of the current billing period."""
|
||||
if not user_id:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
with session_maker() as session:
|
||||
# Find the user's active subscription
|
||||
now = datetime.now(UTC)
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.filter(SubscriptionAccess.user_id == user_id)
|
||||
.filter(SubscriptionAccess.start_at <= now)
|
||||
.filter(SubscriptionAccess.end_at >= now)
|
||||
.filter(SubscriptionAccess.cancelled_at.is_(None)) # Not already cancelled
|
||||
.first()
|
||||
)
|
||||
|
||||
if not subscription_access:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='No active subscription found',
|
||||
)
|
||||
|
||||
if not subscription_access.stripe_subscription_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Cannot cancel subscription: missing Stripe subscription ID',
|
||||
)
|
||||
|
||||
try:
|
||||
# Cancel the subscription in Stripe at period end
|
||||
await stripe.Subscription.modify_async(
|
||||
subscription_access.stripe_subscription_id, cancel_at_period_end=True
|
||||
)
|
||||
|
||||
# Update local database
|
||||
subscription_access.cancelled_at = datetime.now(UTC)
|
||||
session.merge(subscription_access)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'subscription_cancelled',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'stripe_subscription_id': subscription_access.stripe_subscription_id,
|
||||
'subscription_access_id': subscription_access.id,
|
||||
'end_at': subscription_access.end_at,
|
||||
},
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
{'status': 'success', 'message': 'Subscription cancelled successfully'}
|
||||
)
|
||||
|
||||
except stripe.StripeError as e:
|
||||
logger.error(
|
||||
'stripe_cancellation_failed',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'stripe_subscription_id': subscription_access.stripe_subscription_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f'Failed to cancel subscription: {str(e)}',
|
||||
)
|
||||
return await stripe_service.has_payment_method_by_user_id(user_id)
|
||||
|
||||
|
||||
# Endpoint to create a new setup intent in stripe
|
||||
@@ -246,16 +156,15 @@ async def create_customer_setup_session(
|
||||
request: Request, user_id: str = Depends(get_user_id)
|
||||
) -> CreateBillingSessionResponse:
|
||||
validate_saas_environment(request)
|
||||
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_id,
|
||||
customer=customer_info['customer_id'],
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url=f'{request.base_url}?free_credits=success',
|
||||
cancel_url=f'{request.base_url}',
|
||||
)
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url) # type: ignore[arg-type]
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
|
||||
|
||||
|
||||
# Endpoint to create a new Stripe checkout session for credit purchase
|
||||
@@ -267,9 +176,9 @@ async def create_checkout_session(
|
||||
) -> CreateBillingSessionResponse:
|
||||
validate_saas_environment(request)
|
||||
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_id,
|
||||
customer=customer_info['customer_id'],
|
||||
line_items=[
|
||||
{
|
||||
'price_data': {
|
||||
@@ -282,7 +191,7 @@ async def create_checkout_session(
|
||||
'tax_behavior': 'exclusive',
|
||||
},
|
||||
'quantity': 1,
|
||||
}
|
||||
},
|
||||
],
|
||||
mode='payment',
|
||||
payment_method_types=['card'],
|
||||
@@ -295,8 +204,9 @@ async def create_checkout_session(
|
||||
logger.info(
|
||||
'created_stripe_checkout_session',
|
||||
extra={
|
||||
'stripe_customer_id': customer_id,
|
||||
'stripe_customer_id': customer_info['customer_id'],
|
||||
'user_id': user_id,
|
||||
'org_id': customer_info['org_id'],
|
||||
'amount': body.amount,
|
||||
'checkout_session_id': checkout_session.id,
|
||||
},
|
||||
@@ -305,105 +215,14 @@ async def create_checkout_session(
|
||||
billing_session = BillingSession(
|
||||
id=checkout_session.id,
|
||||
user_id=user_id,
|
||||
org_id=customer_info['org_id'],
|
||||
price=body.amount,
|
||||
price_code='NA',
|
||||
billing_session_type=BillingSessionType.DIRECT_PAYMENT.value,
|
||||
)
|
||||
session.add(billing_session)
|
||||
session.commit()
|
||||
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@billing_router.post('/subscription-checkout-session')
|
||||
async def create_subscription_checkout_session(
|
||||
request: Request,
|
||||
billing_session_type: BillingSessionType = BillingSessionType.MONTHLY_SUBSCRIPTION,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> CreateBillingSessionResponse:
|
||||
validate_saas_environment(request)
|
||||
|
||||
# Prevent duplicate subscriptions for the same user
|
||||
with session_maker() as session:
|
||||
now = datetime.now(UTC)
|
||||
existing_active_subscription = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.filter(SubscriptionAccess.user_id == user_id)
|
||||
.filter(SubscriptionAccess.start_at <= now)
|
||||
.filter(SubscriptionAccess.end_at >= now)
|
||||
.filter(SubscriptionAccess.cancelled_at.is_(None)) # Not cancelled
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_active_subscription:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Cannot create subscription: User already has an active subscription that has not been cancelled',
|
||||
)
|
||||
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
subscription_price_data = SUBSCRIPTION_PRICE_DATA[billing_session_type.value]
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_id,
|
||||
line_items=[
|
||||
{
|
||||
'price_data': subscription_price_data,
|
||||
'quantity': 1,
|
||||
}
|
||||
],
|
||||
mode='subscription',
|
||||
payment_method_types=['card'],
|
||||
saved_payment_method_options={
|
||||
'payment_method_save': 'enabled',
|
||||
},
|
||||
success_url=f'{request.base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
cancel_url=f'{request.base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
subscription_data={
|
||||
'metadata': {
|
||||
'user_id': user_id,
|
||||
'billing_session_type': billing_session_type.value,
|
||||
}
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
'created_stripe_subscription_checkout_session',
|
||||
extra={
|
||||
'stripe_customer_id': customer_id,
|
||||
'user_id': user_id,
|
||||
'checkout_session_id': checkout_session.id,
|
||||
'billing_session_type': billing_session_type.value,
|
||||
},
|
||||
)
|
||||
with session_maker() as session:
|
||||
billing_session = BillingSession(
|
||||
id=checkout_session.id,
|
||||
user_id=user_id,
|
||||
price=subscription_price_data['unit_amount'],
|
||||
price_code='NA',
|
||||
billing_session_type=billing_session_type.value,
|
||||
)
|
||||
session.add(billing_session)
|
||||
session.commit()
|
||||
|
||||
return CreateBillingSessionResponse(
|
||||
redirect_url=typing.cast(str, checkout_session.url)
|
||||
)
|
||||
|
||||
|
||||
@billing_router.get('/create-subscription-checkout-session')
|
||||
async def create_subscription_checkout_session_via_get(
|
||||
request: Request,
|
||||
billing_session_type: BillingSessionType = BillingSessionType.MONTHLY_SUBSCRIPTION,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> RedirectResponse:
|
||||
"""Create a subscription checkout session using a GET request (For easier copy / paste to URL bar)."""
|
||||
validate_saas_environment(request)
|
||||
|
||||
response = await create_subscription_checkout_session(
|
||||
request, billing_session_type, user_id
|
||||
)
|
||||
return RedirectResponse(response.redirect_url)
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
|
||||
|
||||
|
||||
# Callback endpoint for successful Stripe payments - updates user credits and billing session status
|
||||
@@ -425,15 +244,6 @@ async def success_callback(session_id: str, request: Request):
|
||||
)
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
# Any non direct payment (Subscription) is processed in the invoice_payment.paid by the webhook
|
||||
if (
|
||||
billing_session.billing_session_type
|
||||
!= BillingSessionType.DIRECT_PAYMENT.value
|
||||
):
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings?checkout=success', status_code=302
|
||||
)
|
||||
|
||||
stripe_session = stripe.checkout.Session.retrieve(session_id)
|
||||
if stripe_session.status != 'complete':
|
||||
# Hopefully this never happens - we get a redirect from stripe where the payment is not yet complete
|
||||
@@ -447,31 +257,39 @@ async def success_callback(session_id: str, request: Request):
|
||||
)
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
# Update max budget in litellm
|
||||
user_json = await _get_litellm_user(client, billing_session.user_id)
|
||||
amount_subtotal = stripe_session.amount_subtotal or 0
|
||||
add_credits = amount_subtotal / 100
|
||||
new_max_budget = (
|
||||
(user_json.get('user_info') or {}).get('max_budget') or 0
|
||||
) + add_credits
|
||||
await _upsert_litellm_user(client, billing_session.user_id, new_max_budget)
|
||||
user = await call_sync_from_async(
|
||||
UserStore.get_user_by_id, billing_session.user_id
|
||||
)
|
||||
user_team_info = await LiteLlmManager.get_user_team_info(
|
||||
billing_session.user_id, str(user.current_org_id)
|
||||
)
|
||||
amount_subtotal = stripe_session.amount_subtotal or 0
|
||||
add_credits = amount_subtotal / 100
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
|
||||
'max_budget', 0
|
||||
)
|
||||
new_max_budget = max_budget + add_credits
|
||||
|
||||
# Store transaction status
|
||||
billing_session.status = 'completed'
|
||||
billing_session.price = amount_subtotal
|
||||
billing_session.updated_at = datetime.now(UTC)
|
||||
session.merge(billing_session)
|
||||
logger.info(
|
||||
'stripe_checkout_success',
|
||||
extra={
|
||||
'amount_subtotal': stripe_session.amount_subtotal,
|
||||
'user_id': billing_session.user_id,
|
||||
'checkout_session_id': billing_session.id,
|
||||
'stripe_customer_id': stripe_session.customer,
|
||||
},
|
||||
)
|
||||
session.commit()
|
||||
await LiteLlmManager.update_team_and_users_budget(
|
||||
str(user.current_org_id), new_max_budget
|
||||
)
|
||||
|
||||
# Store transaction status
|
||||
billing_session.status = 'completed'
|
||||
billing_session.price = add_credits
|
||||
billing_session.updated_at = datetime.now(UTC)
|
||||
session.merge(billing_session)
|
||||
logger.info(
|
||||
'stripe_checkout_success',
|
||||
extra={
|
||||
'amount_subtotal': stripe_session.amount_subtotal,
|
||||
'user_id': billing_session.user_id,
|
||||
'org_id': str(user.current_org_id),
|
||||
'checkout_session_id': billing_session.id,
|
||||
'stripe_customer_id': stripe_session.customer,
|
||||
},
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings/billing?checkout=success', status_code=302
|
||||
@@ -501,206 +319,6 @@ async def cancel_callback(session_id: str, request: Request):
|
||||
session.merge(billing_session)
|
||||
session.commit()
|
||||
|
||||
# Redirect credit purchases to billing screen, subscriptions to LLM settings
|
||||
if (
|
||||
billing_session.billing_session_type
|
||||
== BillingSessionType.DIRECT_PAYMENT.value
|
||||
):
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings/billing?checkout=cancel',
|
||||
status_code=302,
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings?checkout=cancel', status_code=302
|
||||
)
|
||||
|
||||
# If no billing session found, default to LLM settings (subscription flow)
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings?checkout=cancel', status_code=302
|
||||
f'{request.base_url}settings/billing?checkout=cancel', status_code=302
|
||||
)
|
||||
|
||||
|
||||
@billing_router.post('/stripe-webhook')
|
||||
async def stripe_webhook(request: Request) -> JSONResponse:
|
||||
"""Endpoint for stripe webhooks."""
|
||||
payload = await request.body()
|
||||
sig_header = request.headers.get('stripe-signature')
|
||||
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, STRIPE_WEBHOOK_SECRET
|
||||
)
|
||||
except ValueError as e:
|
||||
# Invalid payload
|
||||
raise HTTPException(status_code=400, detail=f'Invalid payload: {e}')
|
||||
except stripe.SignatureVerificationError as e:
|
||||
# Invalid signature
|
||||
raise HTTPException(status_code=400, detail=f'Invalid signature: {e}')
|
||||
|
||||
# Handle the event
|
||||
logger.info('stripe_webhook_event', extra={'event': event})
|
||||
event_type = event['type']
|
||||
if event_type == 'invoice.paid':
|
||||
invoice = event['data']['object']
|
||||
amount_paid = invoice.amount_paid
|
||||
metadata = invoice.parent.subscription_details.metadata # type: ignore
|
||||
billing_session_type = metadata.billing_session_type
|
||||
assert (
|
||||
amount_paid == SUBSCRIPTION_PRICE_DATA[billing_session_type]['unit_amount']
|
||||
)
|
||||
user_id = metadata.user_id
|
||||
|
||||
start_at = datetime.now(UTC)
|
||||
if billing_session_type == BillingSessionType.MONTHLY_SUBSCRIPTION.value:
|
||||
end_at = start_at + relativedelta(months=1)
|
||||
else:
|
||||
raise ValueError(f'unknown_billing_session_type:{billing_session_type}')
|
||||
|
||||
with session_maker() as session:
|
||||
subscription_access = SubscriptionAccess(
|
||||
status='ACTIVE',
|
||||
user_id=user_id,
|
||||
start_at=start_at,
|
||||
end_at=end_at,
|
||||
amount_paid=amount_paid,
|
||||
stripe_invoice_payment_id=invoice.payment_intent,
|
||||
stripe_subscription_id=invoice.subscription, # Store Stripe subscription ID
|
||||
)
|
||||
session.add(subscription_access)
|
||||
session.commit()
|
||||
elif event_type == 'customer.subscription.updated':
|
||||
subscription = event['data']['object']
|
||||
subscription_id = subscription['id']
|
||||
|
||||
# Handle subscription cancellation
|
||||
if subscription.get('cancel_at_period_end') is True:
|
||||
with session_maker() as session:
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(
|
||||
SubscriptionAccess.stripe_subscription_id == subscription_id
|
||||
)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.first()
|
||||
)
|
||||
|
||||
if subscription_access and not subscription_access.cancelled_at:
|
||||
subscription_access.cancelled_at = datetime.now(UTC)
|
||||
session.merge(subscription_access)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'subscription_cancelled_via_webhook',
|
||||
extra={
|
||||
'stripe_subscription_id': subscription_id,
|
||||
'user_id': subscription_access.user_id,
|
||||
'subscription_access_id': subscription_access.id,
|
||||
},
|
||||
)
|
||||
elif event_type == 'customer.subscription.deleted':
|
||||
subscription = event['data']['object']
|
||||
subscription_id = subscription['id']
|
||||
|
||||
with session_maker() as session:
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.stripe_subscription_id == subscription_id)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.first()
|
||||
)
|
||||
|
||||
if subscription_access:
|
||||
subscription_access.status = 'DISABLED'
|
||||
subscription_access.updated_at = datetime.now(UTC)
|
||||
session.merge(subscription_access)
|
||||
session.commit()
|
||||
|
||||
# Reset user settings to free tier defaults
|
||||
reset_user_to_free_tier_settings(subscription_access.user_id)
|
||||
|
||||
logger.info(
|
||||
'subscription_expired_reset_to_free_tier',
|
||||
extra={
|
||||
'stripe_subscription_id': subscription_id,
|
||||
'user_id': subscription_access.user_id,
|
||||
'subscription_access_id': subscription_access.id,
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.info('stripe_webhook_unhandled_event_type', extra={'type': event_type})
|
||||
|
||||
return JSONResponse({'status': 'success'})
|
||||
|
||||
|
||||
def reset_user_to_free_tier_settings(user_id: str) -> None:
|
||||
"""Reset user settings to free tier defaults when subscription ends."""
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
with session_maker() as session:
|
||||
user_settings = settings_store.get_user_settings_by_keycloak_id(
|
||||
user_id, session
|
||||
)
|
||||
|
||||
if user_settings:
|
||||
user_settings.llm_model = get_default_litellm_model()
|
||||
user_settings.llm_api_key = None
|
||||
user_settings.llm_api_key_for_byor = None
|
||||
user_settings.llm_base_url = LITE_LLM_API_URL
|
||||
user_settings.max_budget_per_task = None
|
||||
user_settings.confirmation_mode = False
|
||||
user_settings.enable_solvability_analysis = False
|
||||
user_settings.security_analyzer = 'llm'
|
||||
user_settings.agent = 'CodeActAgent'
|
||||
user_settings.language = 'en'
|
||||
user_settings.enable_default_condenser = True
|
||||
user_settings.enable_sound_notifications = False
|
||||
user_settings.enable_proactive_conversation_starters = True
|
||||
user_settings.user_consents_to_analytics = False
|
||||
|
||||
session.merge(user_settings)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'user_settings_reset_to_free_tier',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'reset_timestamp': datetime.now(UTC).isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _get_litellm_user(client: httpx.AsyncClient, user_id: str) -> dict:
|
||||
"""Get a user from litellm with the id matching that given.
|
||||
|
||||
If no such user exists, returns a dummy user in the format:
|
||||
`{'user_id': '<USER_ID>', 'user_info': {'spend': 0}, 'keys': [], 'teams': []}`
|
||||
"""
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={user_id}',
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
async def _upsert_litellm_user(
|
||||
client: httpx.AsyncClient, user_id: str, max_budget: float
|
||||
):
|
||||
"""Insert / Update a user in litellm."""
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/update',
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
json={
|
||||
'user_id': user_id,
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -6,7 +6,7 @@ from threading import Thread
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from sqlalchemy import func, select
|
||||
from storage.database import a_session_maker, engine, session_maker
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user import User
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.utils.async_utils import wait_all
|
||||
@@ -127,7 +127,7 @@ def _db_check(delay: int):
|
||||
delay: Number of seconds to hold the database connection
|
||||
"""
|
||||
with session_maker() as session:
|
||||
num_users = session.query(UserSettings).count()
|
||||
num_users = session.query(User).count()
|
||||
time.sleep(delay)
|
||||
logger.info(
|
||||
'check',
|
||||
@@ -155,7 +155,7 @@ async def _a_db_check(delay: int):
|
||||
delay: Number of seconds to hold the database connection
|
||||
"""
|
||||
async with a_session_maker() as a_session:
|
||||
stmt = select(func.count(UserSettings.id))
|
||||
stmt = select(func.count(User.id))
|
||||
num_users = await a_session.execute(stmt)
|
||||
await asyncio.sleep(delay)
|
||||
logger.info(f'a_num_users:{num_users.scalar_one()}')
|
||||
|
||||
@@ -21,7 +21,7 @@ from server.utils.conversation_callback_utils import (
|
||||
update_conversation_stats,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.server.shared import conversation_manager
|
||||
|
||||
@@ -226,12 +226,12 @@ def _parse_conversation_id_and_subpath(path: str) -> Tuple[str, str]:
|
||||
|
||||
def _get_user_id(conversation_id: str) -> str:
|
||||
with session_maker() as session:
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(StoredConversationMetadataSaas.conversation_id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
return conversation_metadata.user_id
|
||||
return str(conversation_metadata_saas.user_id)
|
||||
|
||||
|
||||
async def _get_session_api_key(user_id: str, conversation_id: str) -> str | None:
|
||||
|
||||
@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
|
||||
from sqlalchemy.future import select
|
||||
from storage.database import session_maker
|
||||
from storage.feedback import ConversationFeedback
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.events.event_store import EventStore
|
||||
from openhands.server.shared import file_store
|
||||
@@ -33,10 +33,10 @@ async def get_event_ids(conversation_id: str, user_id: str) -> List[int]:
|
||||
def _verify_conversation():
|
||||
with session_maker() as session:
|
||||
metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadata.conversation_id == conversation_id,
|
||||
StoredConversationMetadata.user_id == user_id,
|
||||
StoredConversationMetadataSaas.conversation_id == conversation_id,
|
||||
StoredConversationMetadataSaas.user_id == user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
@@ -15,7 +15,6 @@ from integrations.slack.slack_manager import SlackManager
|
||||
from integrations.utils import (
|
||||
HOST_URL,
|
||||
)
|
||||
from pydantic import SecretStr
|
||||
from server.auth.constants import (
|
||||
KEYCLOAK_CLIENT_ID,
|
||||
KEYCLOAK_REALM_NAME,
|
||||
@@ -35,9 +34,11 @@ from slack_sdk.web.async_client import AsyncWebClient
|
||||
from storage.database import session_maker
|
||||
from storage.slack_team_store import SlackTeamStore
|
||||
from storage.slack_user import SlackUser
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.server.shared import config, sio
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
signature_verifier = SignatureVerifier(signing_secret=SLACK_SIGNING_SECRET)
|
||||
slack_router = APIRouter(prefix='/slack')
|
||||
@@ -79,6 +80,14 @@ async def install_callback(
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
if not config.jwt_secret:
|
||||
logger.error('slack_install_callback_error JWT not configured.')
|
||||
return _html_response(
|
||||
title='Error',
|
||||
description=html.escape('JWT not configured'),
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
try:
|
||||
client = AsyncWebClient() # no prepared token needed for this
|
||||
# Complete the installation by calling oauth.v2.access API method
|
||||
@@ -94,16 +103,17 @@ async def install_callback(
|
||||
|
||||
# Create a state variable for keycloak oauth
|
||||
payload = {}
|
||||
jwt_secret: SecretStr = config.jwt_secret # type: ignore[assignment]
|
||||
if state:
|
||||
payload = jwt.decode(
|
||||
state, jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
state, config.jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
)
|
||||
payload['slack_user_id'] = authed_user.get('id')
|
||||
payload['bot_access_token'] = bot_access_token
|
||||
payload['team_id'] = team_id
|
||||
|
||||
state = jwt.encode(payload, jwt_secret.get_secret_value(), algorithm='HS256')
|
||||
state = jwt.encode(
|
||||
payload, config.jwt_secret.get_secret_value(), algorithm='HS256'
|
||||
)
|
||||
|
||||
# Redirect into keycloak
|
||||
scope = quote('openid email profile offline_access')
|
||||
@@ -149,9 +159,16 @@ async def keycloak_callback(
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
jwt_secret: SecretStr = config.jwt_secret # type: ignore[assignment]
|
||||
if not config.jwt_secret:
|
||||
logger.error('problem_retrieving_keycloak_tokens JWT not configured.')
|
||||
return _html_response(
|
||||
title='Error',
|
||||
description=html.escape('JWT not configured'),
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
payload: dict[str, str] = jwt.decode(
|
||||
state, jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
state, config.jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
)
|
||||
slack_user_id = payload['slack_user_id']
|
||||
bot_access_token = payload['bot_access_token']
|
||||
@@ -180,6 +197,13 @@ async def keycloak_callback(
|
||||
|
||||
user_info = await token_manager.get_user_info(keycloak_access_token)
|
||||
keycloak_user_id = user_info['sub']
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, keycloak_user_id)
|
||||
if not user:
|
||||
return _html_response(
|
||||
title='Failed to authenticate.',
|
||||
description=f'Please re-login into <a href="{HOST_URL}" style="color:#ecedee;text-decoration:underline;">OpenHands Cloud</a>. Then try <a href="https://docs.all-hands.dev/usage/cloud/slack-installation" style="color:#ecedee;text-decoration:underline;">installing the OpenHands Slack App</a> again',
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# These tokens are offline access tokens - store them!
|
||||
await token_manager.store_offline_token(keycloak_user_id, keycloak_refresh_token)
|
||||
@@ -211,6 +235,7 @@ async def keycloak_callback(
|
||||
slack_display_name = slack_user_info.data['user']['profile']['display_name']
|
||||
slack_user = SlackUser(
|
||||
keycloak_user_id=keycloak_user_id,
|
||||
org_id=user.current_org_id,
|
||||
slack_user_id=slack_user_id,
|
||||
slack_display_name=slack_display_name,
|
||||
)
|
||||
@@ -305,7 +330,7 @@ async def on_form_interaction(request: Request, background_tasks: BackgroundTask
|
||||
|
||||
body = await request.body()
|
||||
form = await request.form()
|
||||
payload = json.loads(form.get('payload')) # type: ignore[arg-type]
|
||||
payload = json.loads(form.get('payload'))
|
||||
|
||||
logger.info('slack_on_form_interaction', extra={'payload': payload})
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ from sqlalchemy import orm
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.core.config import LLMConfig, OpenHandsConfig
|
||||
@@ -646,16 +647,18 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
"""
|
||||
|
||||
with session_maker() as session:
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id == conversation_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation_metadata:
|
||||
if not conversation_metadata_saas:
|
||||
raise ValueError(f'No conversation found {conversation_id}')
|
||||
|
||||
return conversation_metadata.user_id
|
||||
return str(conversation_metadata_saas.user_id)
|
||||
|
||||
async def _get_runtime_status_from_nested_runtime(
|
||||
self, session_api_key: Any | None, nested_url: str, conversation_id: str
|
||||
@@ -994,9 +997,17 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
with session_maker() as session:
|
||||
# Only include conversations updated in the past week
|
||||
one_week_ago = datetime.now(UTC) - timedelta(days=7)
|
||||
query = session.query(StoredConversationMetadata.conversation_id).filter(
|
||||
StoredConversationMetadata.user_id == user_id,
|
||||
StoredConversationMetadata.last_updated_at >= one_week_ago,
|
||||
query = (
|
||||
session.query(StoredConversationMetadata.conversation_id)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.user_id == user_id,
|
||||
StoredConversationMetadata.last_updated_at >= one_week_ago,
|
||||
)
|
||||
)
|
||||
user_conversation_ids = set(query)
|
||||
return user_conversation_ids
|
||||
@@ -1070,11 +1081,16 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
if conversation_metadata is None:
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(StoredConversationMetadataSaas.conversation_id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
if conversation_metadata is None or conversation_metadata_saas is None:
|
||||
# Conversation is running in different server
|
||||
return
|
||||
|
||||
user_id = conversation_metadata.user_id
|
||||
user_id = conversation_metadata_saas.user_id
|
||||
|
||||
# Get the id of the next event which is not present
|
||||
events_dir = get_conversation_events_dir(
|
||||
|
||||
@@ -241,7 +241,7 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
|
||||
return SharedConversation(
|
||||
id=UUID(stored.conversation_id),
|
||||
created_by_user_id=stored.user_id if stored.user_id else None,
|
||||
created_by_user_id=None, # user_id is no longer stored in conversation metadata
|
||||
sandbox_id=stored.sandbox_id,
|
||||
selected_repository=stored.selected_repository,
|
||||
selected_branch=stored.selected_branch,
|
||||
|
||||
350
enterprise/server/utils/saas_app_conversation_info_injector.py
Normal file
350
enterprise/server/utils/saas_app_conversation_info_injector.py
Normal file
@@ -0,0 +1,350 @@
|
||||
"""Enterprise injector for SQLAppConversationInfoService with SAAS filtering."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Request
|
||||
from sqlalchemy import func, select
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.user import User
|
||||
|
||||
from openhands.app_server.app_conversation.app_conversation_info_service import (
|
||||
AppConversationInfoService,
|
||||
AppConversationInfoServiceInjector,
|
||||
)
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
AppConversationInfoPage,
|
||||
AppConversationSortOrder,
|
||||
)
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
SQLAppConversationInfoService,
|
||||
)
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
|
||||
|
||||
class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
"""Extended SQLAppConversationInfoService with user-based filtering and SAAS metadata handling."""
|
||||
|
||||
async def _secure_select(self):
|
||||
query = (
|
||||
select(StoredConversationMetadata)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
return query
|
||||
|
||||
async def _secure_select_with_saas_metadata(self):
|
||||
"""Select query that includes SAAS metadata for retrieving user_id."""
|
||||
query = (
|
||||
select(StoredConversationMetadata, StoredConversationMetadataSaas)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
return query
|
||||
|
||||
async def search_app_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
sort_order: AppConversationSortOrder = AppConversationSortOrder.CREATED_AT_DESC,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
include_sub_conversations: bool = False,
|
||||
) -> AppConversationInfoPage:
|
||||
"""Search for conversations with user_id from SAAS metadata."""
|
||||
query = await self._secure_select_with_saas_metadata()
|
||||
|
||||
# Conditionally exclude sub-conversations based on the parameter
|
||||
if not include_sub_conversations:
|
||||
# Exclude sub-conversations (only include top-level conversations)
|
||||
query = query.where(
|
||||
StoredConversationMetadata.parent_conversation_id.is_(None)
|
||||
)
|
||||
|
||||
query = self._apply_filters_with_saas_metadata(
|
||||
query=query,
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
)
|
||||
|
||||
# Add sort order
|
||||
if sort_order == AppConversationSortOrder.CREATED_AT:
|
||||
query = query.order_by(StoredConversationMetadata.created_at)
|
||||
elif sort_order == AppConversationSortOrder.CREATED_AT_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.created_at.desc())
|
||||
elif sort_order == AppConversationSortOrder.UPDATED_AT:
|
||||
query = query.order_by(StoredConversationMetadata.last_updated_at)
|
||||
elif sort_order == AppConversationSortOrder.UPDATED_AT_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.last_updated_at.desc())
|
||||
elif sort_order == AppConversationSortOrder.TITLE:
|
||||
query = query.order_by(StoredConversationMetadata.title)
|
||||
elif sort_order == AppConversationSortOrder.TITLE_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.title.desc())
|
||||
|
||||
# Apply pagination
|
||||
if page_id is not None:
|
||||
try:
|
||||
offset = int(page_id)
|
||||
query = query.offset(offset)
|
||||
except ValueError:
|
||||
# If page_id is not a valid integer, start from beginning
|
||||
offset = 0
|
||||
else:
|
||||
offset = 0
|
||||
|
||||
# Apply limit and get one extra to check if there are more results
|
||||
query = query.limit(limit + 1)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(rows) > limit
|
||||
if has_more:
|
||||
rows = rows[:limit]
|
||||
|
||||
items = [
|
||||
self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
for stored_metadata, saas_metadata in rows
|
||||
]
|
||||
|
||||
# Calculate next page ID
|
||||
next_page_id = None
|
||||
if has_more:
|
||||
next_page_id = str(offset + limit)
|
||||
|
||||
return AppConversationInfoPage(items=items, next_page_id=next_page_id)
|
||||
|
||||
async def count_app_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
) -> int:
|
||||
"""Count conversations matching the given filters with SAAS metadata."""
|
||||
query = (
|
||||
select(func.count(StoredConversationMetadata.conversation_id))
|
||||
.select_from(
|
||||
StoredConversationMetadata.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
# Apply user filtering
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
query = self._apply_filters_with_saas_metadata(
|
||||
query=query,
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
count = result.scalar()
|
||||
return count or 0
|
||||
|
||||
def _apply_filters_with_saas_metadata(
|
||||
self,
|
||||
query,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
):
|
||||
"""Apply filters to query that includes SAAS metadata."""
|
||||
# Apply the same filters as the base class
|
||||
conditions = []
|
||||
if title__contains is not None:
|
||||
conditions.append(
|
||||
StoredConversationMetadata.title.like(f'%{title__contains}%')
|
||||
)
|
||||
|
||||
if created_at__gte is not None:
|
||||
conditions.append(StoredConversationMetadata.created_at >= created_at__gte)
|
||||
|
||||
if created_at__lt is not None:
|
||||
conditions.append(StoredConversationMetadata.created_at < created_at__lt)
|
||||
|
||||
if updated_at__gte is not None:
|
||||
conditions.append(
|
||||
StoredConversationMetadata.last_updated_at >= updated_at__gte
|
||||
)
|
||||
|
||||
if updated_at__lt is not None:
|
||||
conditions.append(
|
||||
StoredConversationMetadata.last_updated_at < updated_at__lt
|
||||
)
|
||||
|
||||
if conditions:
|
||||
query = query.where(*conditions)
|
||||
return query
|
||||
|
||||
async def get_app_conversation_info(
|
||||
self, conversation_id: UUID
|
||||
) -> AppConversationInfo | None:
|
||||
"""Get conversation info with user_id from SAAS metadata."""
|
||||
query = await self._secure_select_with_saas_metadata()
|
||||
query = query.where(
|
||||
StoredConversationMetadata.conversation_id == str(conversation_id)
|
||||
)
|
||||
result_set = await self.db_session.execute(query)
|
||||
result = result_set.first()
|
||||
if result:
|
||||
stored_metadata, saas_metadata = result
|
||||
return self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
return None
|
||||
|
||||
async def batch_get_app_conversation_info(
|
||||
self, conversation_ids: list[UUID]
|
||||
) -> list[AppConversationInfo | None]:
|
||||
"""Batch get conversation info with user_id from SAAS metadata."""
|
||||
conversation_id_strs = [
|
||||
str(conversation_id) for conversation_id in conversation_ids
|
||||
]
|
||||
query = await self._secure_select_with_saas_metadata()
|
||||
query = query.where(
|
||||
StoredConversationMetadata.conversation_id.in_(conversation_id_strs)
|
||||
)
|
||||
result = await self.db_session.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
# Create a mapping of conversation_id to (metadata, saas_metadata)
|
||||
info_by_id = {}
|
||||
for stored_metadata, saas_metadata in rows:
|
||||
info_by_id[stored_metadata.conversation_id] = (
|
||||
stored_metadata,
|
||||
saas_metadata,
|
||||
)
|
||||
|
||||
results: list[AppConversationInfo | None] = []
|
||||
for conversation_id in conversation_id_strs:
|
||||
if conversation_id in info_by_id:
|
||||
stored_metadata, saas_metadata = info_by_id[conversation_id]
|
||||
results.append(
|
||||
self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
)
|
||||
else:
|
||||
results.append(None)
|
||||
|
||||
return results
|
||||
|
||||
async def save_app_conversation_info(
|
||||
self, info: AppConversationInfo
|
||||
) -> AppConversationInfo:
|
||||
"""Save conversation info and create/update SAAS metadata with user_id and org_id."""
|
||||
# Save the base conversation metadata
|
||||
await super().save_app_conversation_info(info)
|
||||
|
||||
# Get current user_id for SAAS metadata
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
# Convert string user_id to UUID
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
user_query = select(User).where(User.id == user_id_uuid)
|
||||
result = await self.db_session.execute(user_query)
|
||||
user = result.scalar_one_or_none()
|
||||
assert user
|
||||
|
||||
# Check if SAAS metadata already exists
|
||||
saas_query = select(StoredConversationMetadataSaas).where(
|
||||
StoredConversationMetadataSaas.conversation_id == str(info.id)
|
||||
)
|
||||
result = await self.db_session.execute(saas_query)
|
||||
existing_saas_metadata = result.scalar_one_or_none()
|
||||
assert existing_saas_metadata is None or (
|
||||
existing_saas_metadata.user_id == user_id_uuid
|
||||
and existing_saas_metadata.org_id == user.current_org_id
|
||||
)
|
||||
|
||||
if not existing_saas_metadata:
|
||||
# Create new SAAS metadata
|
||||
# Set org_id to user_id as specified in requirements
|
||||
saas_metadata = StoredConversationMetadataSaas(
|
||||
conversation_id=str(info.id),
|
||||
user_id=user_id_uuid,
|
||||
org_id=user.current_org_id,
|
||||
)
|
||||
self.db_session.add(saas_metadata)
|
||||
|
||||
await self.db_session.commit()
|
||||
|
||||
return info
|
||||
|
||||
def _to_info_with_user_id(
|
||||
self,
|
||||
stored: StoredConversationMetadata,
|
||||
saas_metadata: StoredConversationMetadataSaas,
|
||||
) -> AppConversationInfo:
|
||||
"""Convert stored metadata to AppConversationInfo with user_id from SAAS metadata."""
|
||||
# Use the base _to_info method to get the basic info
|
||||
info = self._to_info(stored)
|
||||
|
||||
# Override the created_by_user_id with the user_id from SAAS metadata
|
||||
info.created_by_user_id = (
|
||||
str(saas_metadata.user_id) if saas_metadata.user_id else None
|
||||
)
|
||||
|
||||
return info
|
||||
|
||||
|
||||
class SaasAppConversationInfoServiceInjector(AppConversationInfoServiceInjector):
|
||||
"""Enterprise injector for SQLAppConversationInfoService with SAAS filtering."""
|
||||
|
||||
async def inject(
|
||||
self, state: InjectorState, request: Request | None = None
|
||||
) -> AsyncGenerator[AppConversationInfoService, None]:
|
||||
from openhands.app_server.config import (
|
||||
get_db_session,
|
||||
get_user_context,
|
||||
)
|
||||
|
||||
async with (
|
||||
get_user_context(state, request) as user_context,
|
||||
get_db_session(state, request) as db_session,
|
||||
):
|
||||
service = SaasSQLAppConversationInfoService(
|
||||
db_session=db_session, user_context=user_context
|
||||
)
|
||||
yield service
|
||||
@@ -0,0 +1,85 @@
|
||||
from storage.api_key import ApiKey
|
||||
from storage.auth_tokens import AuthTokens
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.billing_session_type import BillingSessionType
|
||||
from storage.conversation_callback import CallbackStatus, ConversationCallback
|
||||
from storage.conversation_work import ConversationWork
|
||||
from storage.experiment_assignment import ExperimentAssignment
|
||||
from storage.feedback import ConversationFeedback, Feedback
|
||||
from storage.github_app_installation import GithubAppInstallation
|
||||
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
|
||||
from storage.jira_conversation import JiraConversation
|
||||
from storage.jira_dc_conversation import JiraDcConversation
|
||||
from storage.jira_dc_user import JiraDcUser
|
||||
from storage.jira_dc_workspace import JiraDcWorkspace
|
||||
from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
from storage.linear_conversation import LinearConversation
|
||||
from storage.linear_user import LinearUser
|
||||
from storage.linear_workspace import LinearWorkspace
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
from storage.openhands_pr import OpenhandsPR
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.proactive_convos import ProactiveConversation
|
||||
from storage.role import Role
|
||||
from storage.slack_conversation import SlackConversation
|
||||
from storage.slack_team import SlackTeam
|
||||
from storage.slack_user import SlackUser
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.stored_custom_secrets import StoredCustomSecrets
|
||||
from storage.stored_offline_token import StoredOfflineToken
|
||||
from storage.stored_repository import StoredRepository
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
from storage.subscription_access_status import SubscriptionAccessStatus
|
||||
from storage.user import User
|
||||
from storage.user_repo_map import UserRepositoryMap
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
__all__ = [
|
||||
'ApiKey',
|
||||
'AuthTokens',
|
||||
'BillingSession',
|
||||
'BillingSessionType',
|
||||
'CallbackStatus',
|
||||
'ConversationCallback',
|
||||
'ConversationFeedback',
|
||||
'StoredConversationMetadataSaas',
|
||||
'ConversationWork',
|
||||
'ExperimentAssignment',
|
||||
'Feedback',
|
||||
'GithubAppInstallation',
|
||||
'GitlabWebhook',
|
||||
'JiraConversation',
|
||||
'JiraDcConversation',
|
||||
'JiraDcUser',
|
||||
'JiraDcWorkspace',
|
||||
'JiraUser',
|
||||
'JiraWorkspace',
|
||||
'LinearConversation',
|
||||
'LinearUser',
|
||||
'LinearWorkspace',
|
||||
'MaintenanceTask',
|
||||
'MaintenanceTaskStatus',
|
||||
'OpenhandsPR',
|
||||
'Org',
|
||||
'OrgMember',
|
||||
'ProactiveConversation',
|
||||
'Role',
|
||||
'SlackConversation',
|
||||
'SlackTeam',
|
||||
'SlackUser',
|
||||
'StoredConversationMetadata',
|
||||
'StoredOfflineToken',
|
||||
'StoredRepository',
|
||||
'StoredCustomSecrets',
|
||||
'StripeCustomer',
|
||||
'SubscriptionAccess',
|
||||
'SubscriptionAccessStatus',
|
||||
'User',
|
||||
'UserRepositoryMap',
|
||||
'UserSettings',
|
||||
'WebhookStatus',
|
||||
]
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from sqlalchemy import Column, DateTime, Integer, String, text
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -11,9 +13,13 @@ class ApiKey(Base):
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
key = Column(String(255), nullable=False, unique=True, index=True)
|
||||
user_id = Column(String(255), nullable=False, index=True)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
name = Column(String(255), nullable=True)
|
||||
created_at = Column(
|
||||
DateTime, server_default=text('CURRENT_TIMESTAMP'), nullable=False
|
||||
)
|
||||
last_used_at = Column(DateTime, nullable=True)
|
||||
expires_at = Column(DateTime, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='api_keys')
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy import update
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.api_key import ApiKey
|
||||
from storage.database import session_maker
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
@@ -39,10 +40,15 @@ class ApiKeyStore:
|
||||
The generated API key
|
||||
"""
|
||||
api_key = self.generate_api_key()
|
||||
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key=api_key, user_id=user_id, name=name, expires_at=expires_at
|
||||
key=api_key,
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=name,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
session.add(key_record)
|
||||
session.commit()
|
||||
@@ -108,8 +114,15 @@ class ApiKeyStore:
|
||||
|
||||
def list_api_keys(self, user_id: str) -> list[dict]:
|
||||
"""List all API keys for a user."""
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
keys = session.query(ApiKey).filter(ApiKey.user_id == user_id).all()
|
||||
keys = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id)
|
||||
.filter(ApiKey.org_id == org_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
@@ -124,9 +137,14 @@ class ApiKeyStore:
|
||||
]
|
||||
|
||||
def retrieve_mcp_api_key(self, user_id: str) -> str | None:
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
keys: list[ApiKey] = (
|
||||
session.query(ApiKey).filter(ApiKey.user_id == user_id).all()
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id)
|
||||
.filter(ApiKey.org_id == org_id)
|
||||
.all()
|
||||
)
|
||||
for key in keys:
|
||||
if key.name == 'MCP_API_KEY':
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import DECIMAL, Column, DateTime, Enum, String
|
||||
from sqlalchemy import DECIMAL, Column, DateTime, Enum, ForeignKey, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -11,9 +13,9 @@ class BillingSession(Base): # type: ignore
|
||||
"""
|
||||
|
||||
__tablename__ = 'billing_sessions'
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, nullable=False)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
status = Column(
|
||||
Enum(
|
||||
'in_progress',
|
||||
@@ -24,15 +26,6 @@ class BillingSession(Base): # type: ignore
|
||||
),
|
||||
default='in_progress',
|
||||
)
|
||||
billing_session_type = Column(
|
||||
Enum(
|
||||
'DIRECT_PAYMENT',
|
||||
'MONTHLY_SUBSCRIPTION',
|
||||
name='billing_session_type_enum',
|
||||
),
|
||||
nullable=False,
|
||||
default='DIRECT_PAYMENT',
|
||||
)
|
||||
price = Column(DECIMAL(19, 4), nullable=False)
|
||||
price_code = Column(String, nullable=False)
|
||||
created_at = Column(
|
||||
@@ -43,3 +36,6 @@ class BillingSession(Base): # type: ignore
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC), # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='billing_sessions')
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
@@ -10,7 +15,6 @@ from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Text, text
|
||||
from sqlalchemy import Enum as SQLEnum
|
||||
from storage.base import Base
|
||||
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
|
||||
@@ -33,7 +37,7 @@ class ConversationCallbackProcessor(BaseModel, ABC):
|
||||
async def __call__(
|
||||
self,
|
||||
callback: ConversationCallback,
|
||||
observation: AgentStateChangedObservation,
|
||||
observation: 'AgentStateChangedObservation',
|
||||
) -> None:
|
||||
"""
|
||||
Process a conversation event.
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
@@ -7,6 +8,9 @@ from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
from sqlalchemy.util import await_only
|
||||
|
||||
# Check if we're running in a test environment
|
||||
IS_TESTING = 'pytest' in sys.modules
|
||||
|
||||
DB_HOST = os.environ.get('DB_HOST', 'localhost') # for non-GCP environments
|
||||
DB_PORT = os.environ.get('DB_PORT', '5432') # for non-GCP environments
|
||||
DB_USER = os.environ.get('DB_USER', 'postgres')
|
||||
|
||||
114
enterprise/storage/encrypt_utils.py
Normal file
114
enterprise/storage/encrypt_utils.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import binascii
|
||||
import hashlib
|
||||
from base64 import b64decode, b64encode
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
from pydantic import SecretStr
|
||||
from server.config import get_config
|
||||
|
||||
_jwt_service = None
|
||||
_fernet = None
|
||||
|
||||
|
||||
def encrypt_model(encrypt_keys: list, model_instance) -> dict:
|
||||
return encrypt_kwargs(encrypt_keys, model_to_kwargs(model_instance))
|
||||
|
||||
|
||||
def decrypt_model(decrypt_keys: list, model_instance) -> dict:
|
||||
return decrypt_kwargs(decrypt_keys, model_to_kwargs(model_instance))
|
||||
|
||||
|
||||
def encrypt_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
|
||||
for key, value in kwargs.items():
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
if isinstance(value, dict):
|
||||
encrypt_kwargs(encrypt_keys, value)
|
||||
continue
|
||||
|
||||
if key in encrypt_keys:
|
||||
value = encrypt_value(value)
|
||||
kwargs[key] = value
|
||||
return kwargs
|
||||
|
||||
|
||||
def decrypt_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
|
||||
for key, value in kwargs.items():
|
||||
try:
|
||||
if value is None:
|
||||
continue
|
||||
if key in encrypt_keys:
|
||||
value = decrypt_value(value)
|
||||
kwargs[key] = value
|
||||
except binascii.Error:
|
||||
pass # Key is in legacy format...
|
||||
return kwargs
|
||||
|
||||
|
||||
def encrypt_value(value: str | SecretStr) -> str:
|
||||
return get_jwt_service().create_jwe_token(
|
||||
{'v': value.get_secret_value() if isinstance(value, SecretStr) else value}
|
||||
)
|
||||
|
||||
|
||||
def decrypt_value(value: str | SecretStr) -> str:
|
||||
token = get_jwt_service().decrypt_jwe_token(
|
||||
value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
)
|
||||
return token['v']
|
||||
|
||||
|
||||
def get_jwt_service():
|
||||
from openhands.app_server.config import get_global_config
|
||||
|
||||
global _jwt_service
|
||||
if _jwt_service is None:
|
||||
jwt_service_injector = get_global_config().jwt
|
||||
assert jwt_service_injector is not None
|
||||
_jwt_service = jwt_service_injector.get_jwt_service()
|
||||
return _jwt_service
|
||||
|
||||
|
||||
def decrypt_legacy_model(decrypt_keys: list, model_instance) -> dict:
|
||||
return decrypt_legacy_kwargs(decrypt_keys, model_to_kwargs(model_instance))
|
||||
|
||||
|
||||
def decrypt_legacy_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
|
||||
for key, value in kwargs.items():
|
||||
try:
|
||||
if value is None:
|
||||
continue
|
||||
if key in encrypt_keys:
|
||||
value = decrypt_legacy_value(value)
|
||||
kwargs[key] = value
|
||||
except binascii.Error:
|
||||
pass # Key is in legacy format...
|
||||
except InvalidToken:
|
||||
pass # Key not encrypted...
|
||||
return kwargs
|
||||
|
||||
|
||||
def decrypt_legacy_value(value: str | SecretStr) -> str:
|
||||
if isinstance(value, SecretStr):
|
||||
return (
|
||||
get_fernet().decrypt(b64decode(value.get_secret_value().encode())).decode()
|
||||
)
|
||||
else:
|
||||
return get_fernet().decrypt(b64decode(value.encode())).decode()
|
||||
|
||||
|
||||
def get_fernet():
|
||||
global _fernet
|
||||
if _fernet is None:
|
||||
jwt_secret = get_config().jwt_secret.get_secret_value()
|
||||
fernet_key = b64encode(hashlib.sha256(jwt_secret.encode()).digest())
|
||||
_fernet = Fernet(fernet_key)
|
||||
return _fernet
|
||||
|
||||
|
||||
def model_to_kwargs(model_instance):
|
||||
return {
|
||||
column.name: getattr(model_instance, column.name)
|
||||
for column in model_instance.__table__.columns
|
||||
}
|
||||
@@ -1,7 +1,16 @@
|
||||
import sys
|
||||
from enum import IntEnum
|
||||
|
||||
from sqlalchemy import ARRAY, Boolean, Column, DateTime, Integer, String, Text, text
|
||||
from sqlalchemy import (
|
||||
ARRAY,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
text,
|
||||
)
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
|
||||
760
enterprise/storage/lite_llm_manager.py
Normal file
760
enterprise/storage/lite_llm_manager.py
Normal file
@@ -0,0 +1,760 @@
|
||||
"""
|
||||
Store class for managing organizational settings.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import os
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
import httpx
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import (
|
||||
DEFAULT_INITIAL_BUDGET,
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
LITE_LLM_TEAM_ID,
|
||||
ORG_SETTINGS_VERSION,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
# Timeout in seconds for BYOR key verification requests to LiteLLM
|
||||
BYOR_KEY_VERIFICATION_TIMEOUT = 5.0
|
||||
|
||||
# A very large number to represent "unlimited" until LiteLLM fixes their unlimited update bug.
|
||||
UNLIMITED_BUDGET_SETTING = 1000000000.0
|
||||
|
||||
|
||||
class LiteLlmManager:
|
||||
"""Manage LiteLLM interactions."""
|
||||
|
||||
@staticmethod
|
||||
async def create_entries(
|
||||
org_id: str,
|
||||
keycloak_user_id: str,
|
||||
oss_settings: Settings,
|
||||
) -> Settings | None:
|
||||
logger.info(
|
||||
'SettingsStore:update_settings_with_litellm_default:start',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
|
||||
key = LITE_LLM_API_KEY
|
||||
if not local_deploy:
|
||||
# Get user info to add to litellm
|
||||
token_manager = TokenManager()
|
||||
keycloak_user_info = (
|
||||
await token_manager.get_user_info_from_user_id(keycloak_user_id) or {}
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
await LiteLlmManager._create_team(
|
||||
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
|
||||
)
|
||||
|
||||
await LiteLlmManager._create_user(
|
||||
client, keycloak_user_info.get('email'), keycloak_user_id
|
||||
)
|
||||
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
|
||||
)
|
||||
|
||||
key = await LiteLlmManager._generate_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
org_id,
|
||||
f'OpenHands Cloud - user {keycloak_user_id}',
|
||||
None,
|
||||
)
|
||||
|
||||
oss_settings.agent = 'CodeActAgent'
|
||||
# Use the model corresponding to the current user settings version
|
||||
oss_settings.llm_model = get_default_litellm_model()
|
||||
oss_settings.llm_api_key = SecretStr(key)
|
||||
oss_settings.llm_base_url = LITE_LLM_API_URL
|
||||
return oss_settings
|
||||
|
||||
@staticmethod
|
||||
async def migrate_entries(
|
||||
org_id: str,
|
||||
keycloak_user_id: str,
|
||||
user_settings: UserSettings,
|
||||
) -> UserSettings | None:
|
||||
logger.info(
|
||||
'SettingsStore:umigrate_lite_llm_entries:start',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
|
||||
if not local_deploy:
|
||||
# Get user info to add to litellm
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
user_json = await LiteLlmManager._get_user(client, keycloak_user_id)
|
||||
if not user_json:
|
||||
return None
|
||||
user_info = user_json['user_info']
|
||||
max_budget = user_info.get('max_budget', 0.0)
|
||||
spend = user_info.get('spend', 0.0)
|
||||
# In upgrade to V4, we no longer use billing margin, but instead apply this directly
|
||||
# in litellm. The default billing marign was 2 before this (hence the magic numbers below)
|
||||
if (
|
||||
user_settings
|
||||
and user_settings.user_version < 4
|
||||
and user_settings.billing_margin
|
||||
and user_settings.billing_margin != 1.0
|
||||
):
|
||||
billing_margin = user_settings.billing_margin
|
||||
logger.info(
|
||||
'user_settings_v4_budget_upgrade',
|
||||
extra={
|
||||
'max_budget': max_budget,
|
||||
'billing_margin': billing_margin,
|
||||
'spend': spend,
|
||||
},
|
||||
)
|
||||
max_budget *= billing_margin
|
||||
spend *= billing_margin
|
||||
|
||||
if not max_budget:
|
||||
# if max_budget is None, then we've already migrated the User
|
||||
return None
|
||||
credits = max(max_budget - spend, 0.0)
|
||||
|
||||
await LiteLlmManager._create_team(
|
||||
client, keycloak_user_id, org_id, credits
|
||||
)
|
||||
|
||||
await LiteLlmManager._update_user(
|
||||
client, keycloak_user_id, max_budget=UNLIMITED_BUDGET_SETTING
|
||||
)
|
||||
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
client, keycloak_user_id, org_id, credits
|
||||
)
|
||||
|
||||
if user_settings.llm_api_key:
|
||||
await LiteLlmManager._update_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
user_settings.llm_api_key,
|
||||
team_id=org_id,
|
||||
)
|
||||
|
||||
if user_settings.llm_api_key_for_byor:
|
||||
await LiteLlmManager._update_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
user_settings.llm_api_key_for_byor,
|
||||
team_id=org_id,
|
||||
)
|
||||
|
||||
return user_settings
|
||||
|
||||
@staticmethod
|
||||
async def update_team_and_users_budget(
|
||||
team_id: str,
|
||||
max_budget: float,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
await LiteLlmManager._update_team(client, team_id, None, max_budget)
|
||||
team_info = await LiteLlmManager._get_team(client, team_id)
|
||||
if not team_info:
|
||||
return None
|
||||
# TODO: change to use bulk update endpoint
|
||||
for membership in team_info.get('team_memberships', []):
|
||||
user_id = membership.get('user_id')
|
||||
if not user_id:
|
||||
continue
|
||||
await LiteLlmManager._update_user_in_team(
|
||||
client, user_id, team_id, max_budget
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _create_team(
|
||||
client: httpx.AsyncClient,
|
||||
team_alias: str,
|
||||
team_id: str,
|
||||
max_budget: float,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/new',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'team_alias': team_alias,
|
||||
'models': [],
|
||||
'max_budget': max_budget,
|
||||
'spend': 0,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
},
|
||||
)
|
||||
# Team failed to create in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
if (
|
||||
response.status_code == 400
|
||||
and 'already exists. Please use a different team id' in response.text
|
||||
):
|
||||
# team already exists, so update, then return
|
||||
await LiteLlmManager._update_team(
|
||||
client, team_id, team_alias, max_budget
|
||||
)
|
||||
return
|
||||
logger.error(
|
||||
'error_creating_litellm_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'team_id': team_id,
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _get_team(client: httpx.AsyncClient, team_id: str) -> dict | None:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
"""Get a team from litellm with the id matching that given."""
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/team/info?team_id={team_id}',
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
async def _update_team(
|
||||
client: httpx.AsyncClient,
|
||||
team_id: str,
|
||||
team_alias: str | None,
|
||||
max_budget: float | None,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
json_data: dict[str, Any] = {
|
||||
'team_id': team_id,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
}
|
||||
|
||||
if max_budget is not None:
|
||||
json_data['max_budget'] = max_budget
|
||||
|
||||
if team_alias is not None:
|
||||
json_data['team_alias'] = team_alias
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/update',
|
||||
json=json_data,
|
||||
)
|
||||
|
||||
# Team failed to update in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_updating_litellm_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'team_id': [team_id],
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _create_user(
|
||||
client: httpx.AsyncClient,
|
||||
email: str | None,
|
||||
keycloak_user_id: str,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/new',
|
||||
json={
|
||||
'user_email': email,
|
||||
'models': [],
|
||||
'user_id': keycloak_user_id,
|
||||
'teams': [LITE_LLM_TEAM_ID],
|
||||
'auto_create_key': False,
|
||||
'send_invite_email': False,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
},
|
||||
)
|
||||
if not response.is_success:
|
||||
logger.warning(
|
||||
'duplicate_user_email',
|
||||
extra={
|
||||
'user_id': keycloak_user_id,
|
||||
'email': email,
|
||||
},
|
||||
)
|
||||
# Litellm insists on unique email addresses - it is possible the email address was registered with a different user.
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/new',
|
||||
json={
|
||||
'user_email': None,
|
||||
'models': [],
|
||||
'user_id': keycloak_user_id,
|
||||
'teams': [LITE_LLM_TEAM_ID],
|
||||
'auto_create_key': False,
|
||||
'send_invite_email': False,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# User failed to create in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
if response.status_code == 400 and 'already exists' in response.text:
|
||||
# user already exists, just return
|
||||
return
|
||||
logger.error(
|
||||
'error_creating_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
'email': None,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _get_user(client: httpx.AsyncClient, user_id: str) -> dict | None:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
"""Get a user from litellm with the id matching that given."""
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={user_id}',
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
async def _update_user(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
**kwargs,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
|
||||
payload = {
|
||||
'user_id': keycloak_user_id,
|
||||
}
|
||||
payload.update(kwargs)
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/update',
|
||||
json=payload,
|
||||
)
|
||||
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_updating_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': keycloak_user_id,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _update_key(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
key: str,
|
||||
**kwargs,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
|
||||
payload = {
|
||||
'key': key,
|
||||
}
|
||||
payload.update(kwargs)
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/update',
|
||||
json=payload,
|
||||
)
|
||||
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_updating_litellm_key',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': keycloak_user_id,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _delete_user(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/delete', json={'user_ids': [keycloak_user_id]}
|
||||
)
|
||||
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_deleting_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _add_user_to_team(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str,
|
||||
max_budget: float,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/member_add',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'member': {'user_id': keycloak_user_id, 'role': 'user'},
|
||||
'max_budget_in_team': max_budget,
|
||||
},
|
||||
)
|
||||
# Failed to add user to team - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_adding_litellm_user_to_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
'team_id': [team_id],
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _get_user_team_info(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str,
|
||||
) -> dict | None:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
team_info = await LiteLlmManager._get_team(client, team_id)
|
||||
if not team_info:
|
||||
return None
|
||||
|
||||
# Filter team_memberships based on team_id and keycloak_user_id
|
||||
user_membership = next(
|
||||
(
|
||||
membership
|
||||
for membership in team_info.get('team_memberships', [])
|
||||
if membership.get('user_id') == keycloak_user_id
|
||||
and membership.get('team_id') == team_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
return user_membership
|
||||
|
||||
@staticmethod
|
||||
async def _update_user_in_team(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str,
|
||||
max_budget: float,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/member_update',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'max_budget_in_team': max_budget,
|
||||
},
|
||||
)
|
||||
# Failed to update user in team - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_updating_litellm_user_in_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
'team_id': [team_id],
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _generate_key(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str | None,
|
||||
key_alias: str | None,
|
||||
metadata: dict | None,
|
||||
) -> str | None:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
json_data: dict[str, Any] = {
|
||||
'user_id': keycloak_user_id,
|
||||
'models': [],
|
||||
}
|
||||
|
||||
if team_id is not None:
|
||||
json_data['team_id'] = team_id
|
||||
|
||||
if key_alias is not None:
|
||||
json_data['key_alias'] = key_alias
|
||||
|
||||
if metadata is not None:
|
||||
json_data['metadata'] = metadata
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/generate',
|
||||
json=json_data,
|
||||
)
|
||||
# Failed to generate user key for team - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_generate_user_team_key',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': keycloak_user_id,
|
||||
'team_id': team_id,
|
||||
'key_alias': key_alias,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
key = response_json['key']
|
||||
logger.info(
|
||||
'LiteLlmManager:_lite_llm_generate_user_team_key:key_created',
|
||||
extra={
|
||||
'user_id': keycloak_user_id,
|
||||
'team_id': team_id,
|
||||
'key_alias': key_alias,
|
||||
},
|
||||
)
|
||||
return key
|
||||
|
||||
@staticmethod
|
||||
async def verify_key(key: str, user_id: str) -> bool:
|
||||
"""Verify that a key is valid in LiteLLM by making a lightweight API call.
|
||||
|
||||
Args:
|
||||
key: The key to verify
|
||||
user_id: The user ID for logging purposes
|
||||
|
||||
Returns:
|
||||
True if the key is verified as valid, False if verification fails or key is invalid.
|
||||
Returns False on network errors/timeouts to ensure we don't return potentially invalid keys.
|
||||
"""
|
||||
if not (LITE_LLM_API_URL and key):
|
||||
return False
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
timeout=BYOR_KEY_VERIFICATION_TIMEOUT,
|
||||
) as client:
|
||||
# Make a lightweight request to verify the key
|
||||
# Using /v1/models endpoint as it's lightweight and requires authentication
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/v1/models',
|
||||
headers={
|
||||
'Authorization': f'Bearer {key}',
|
||||
},
|
||||
)
|
||||
|
||||
# Only 200 status code indicates valid key
|
||||
if response.status_code == 200:
|
||||
logger.debug(
|
||||
'BYOR key verification successful',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True
|
||||
|
||||
# All other status codes (401, 403, 500, etc.) are treated as invalid
|
||||
# This includes authentication errors and server errors
|
||||
logger.warning(
|
||||
'BYOR key verification failed - treating as invalid',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'status_code': response.status_code,
|
||||
'key_prefix': key[:10] + '...' if len(key) > 10 else key,
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
except (httpx.TimeoutException, Exception) as e:
|
||||
# Any exception (timeout, network error, etc.) means we can't verify
|
||||
# Return False to trigger regeneration rather than returning potentially invalid key
|
||||
logger.warning(
|
||||
'BYOR key verification error - treating as invalid to ensure key validity',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
'error_type': type(e).__name__,
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def _get_key_info(
|
||||
client: httpx.AsyncClient,
|
||||
org_id: str,
|
||||
keycloak_user_id: str,
|
||||
) -> dict | None:
|
||||
from storage.user_store import UserStore
|
||||
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, keycloak_user_id)
|
||||
if not user:
|
||||
return {}
|
||||
|
||||
org_member = None
|
||||
for om in user.org_members:
|
||||
if om.org_id == org_id:
|
||||
org_member = om
|
||||
break
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return {}
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/key/info?key={org_member.llm_api_key}'
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
key_info = response_json.get('info')
|
||||
if not key_info:
|
||||
return {}
|
||||
return {
|
||||
'key_max_budget': key_info.get('max_budget'),
|
||||
'key_spend': key_info.get('spend'),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def _delete_key(
|
||||
client: httpx.AsyncClient,
|
||||
key_id: str,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/delete',
|
||||
json={
|
||||
'keys': [key_id],
|
||||
},
|
||||
)
|
||||
# Failed to key...
|
||||
if not response.is_success:
|
||||
if response.status_code == 404:
|
||||
# key doesn't exist, just return
|
||||
return
|
||||
logger.error(
|
||||
'error_deleting_key',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.info(
|
||||
'LiteLlmManager:_delete_key:key_deleted',
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def with_http_client(
|
||||
internal_fn: Callable[..., Awaitable[Any]],
|
||||
) -> Callable[..., Awaitable[Any]]:
|
||||
@functools.wraps(internal_fn)
|
||||
async def wrapper(*args, **kwargs):
|
||||
async with httpx.AsyncClient(
|
||||
headers={'x-goog-api-key': LITE_LLM_API_KEY}
|
||||
) as client:
|
||||
return await internal_fn(client, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
# Public methods with injected client
|
||||
create_team = staticmethod(with_http_client(_create_team))
|
||||
get_team = staticmethod(with_http_client(_get_team))
|
||||
update_team = staticmethod(with_http_client(_update_team))
|
||||
create_user = staticmethod(with_http_client(_create_user))
|
||||
get_user = staticmethod(with_http_client(_get_user))
|
||||
update_user = staticmethod(with_http_client(_update_user))
|
||||
delete_user = staticmethod(with_http_client(_delete_user))
|
||||
add_user_to_team = staticmethod(with_http_client(_add_user_to_team))
|
||||
get_user_team_info = staticmethod(with_http_client(_get_user_team_info))
|
||||
update_user_in_team = staticmethod(with_http_client(_update_user_in_team))
|
||||
generate_key = staticmethod(with_http_client(_generate_key))
|
||||
get_key_info = staticmethod(with_http_client(_get_key_info))
|
||||
delete_key = staticmethod(with_http_client(_delete_key))
|
||||
100
enterprise/storage/org.py
Normal file
100
enterprise/storage/org.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
SQLAlchemy model for Organization.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import SecretStr
|
||||
from server.constants import DEFAULT_BILLING_MARGIN
|
||||
from sqlalchemy import JSON, UUID, Boolean, Column, Float, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
from storage.encrypt_utils import decrypt_value, encrypt_value
|
||||
|
||||
|
||||
class Org(Base): # type: ignore
|
||||
"""Organization model."""
|
||||
|
||||
__tablename__ = 'org'
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
|
||||
name = Column(String, nullable=False, unique=True)
|
||||
contact_name = Column(String, nullable=True)
|
||||
contact_email = Column(String, nullable=True)
|
||||
agent = Column(String, nullable=True)
|
||||
default_max_iterations = Column(Integer, nullable=True)
|
||||
security_analyzer = Column(String, nullable=True)
|
||||
confirmation_mode = Column(Boolean, nullable=True, default=False)
|
||||
default_llm_model = Column(String, nullable=True)
|
||||
default_llm_base_url = Column(String, nullable=True)
|
||||
remote_runtime_resource_factor = Column(Integer, nullable=True)
|
||||
enable_default_condenser = Column(Boolean, nullable=False, default=True)
|
||||
billing_margin = Column(Float, nullable=True, default=DEFAULT_BILLING_MARGIN)
|
||||
enable_proactive_conversation_starters = Column(
|
||||
Boolean, nullable=False, default=True
|
||||
)
|
||||
sandbox_base_container_image = Column(String, nullable=True)
|
||||
sandbox_runtime_container_image = Column(String, nullable=True)
|
||||
org_version = Column(Integer, nullable=False, default=0)
|
||||
mcp_config = Column(JSON, nullable=True)
|
||||
# encrypted column, don't set directly, set without the underscore
|
||||
_search_api_key = Column(String, nullable=True)
|
||||
# encrypted column, don't set directly, set without the underscore
|
||||
_sandbox_api_key = Column(String, nullable=True)
|
||||
max_budget_per_task = Column(Float, nullable=True)
|
||||
enable_solvability_analysis = Column(Boolean, nullable=True, default=False)
|
||||
v1_enabled = Column(Boolean, nullable=True)
|
||||
conversation_expiration = Column(Integer, nullable=True)
|
||||
condenser_max_size = Column(Integer, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org_members = relationship('OrgMember', back_populates='org')
|
||||
current_users = relationship('User', back_populates='current_org')
|
||||
billing_sessions = relationship('BillingSession', back_populates='org')
|
||||
stored_conversation_metadata_saas = relationship(
|
||||
'StoredConversationMetadataSaas', back_populates='org'
|
||||
)
|
||||
user_secrets = relationship('StoredCustomSecrets', back_populates='org')
|
||||
api_keys = relationship('ApiKey', back_populates='org')
|
||||
slack_conversations = relationship('SlackConversation', back_populates='org')
|
||||
slack_users = relationship('SlackUser', back_populates='org')
|
||||
stripe_customers = relationship('StripeCustomer', back_populates='org')
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Handle known SQLAlchemy columns directly
|
||||
for key in list(kwargs):
|
||||
if hasattr(self.__class__, key):
|
||||
setattr(self, key, kwargs.pop(key))
|
||||
|
||||
# Handle custom property-style fields
|
||||
if 'search_api_key' in kwargs:
|
||||
self.search_api_key = kwargs.pop('search_api_key')
|
||||
if 'sandbox_api_key' in kwargs:
|
||||
self.sandbox_api_key = kwargs.pop('sandbox_api_key')
|
||||
|
||||
if kwargs:
|
||||
raise TypeError(f'Unexpected keyword arguments: {list(kwargs.keys())}')
|
||||
|
||||
@property
|
||||
def search_api_key(self) -> SecretStr | None:
|
||||
if self._search_api_key:
|
||||
decrypted = decrypt_value(self._search_api_key)
|
||||
return SecretStr(decrypted)
|
||||
return None
|
||||
|
||||
@search_api_key.setter
|
||||
def search_api_key(self, value: str | SecretStr | None):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._search_api_key = encrypt_value(raw) if raw else None
|
||||
|
||||
@property
|
||||
def sandbox_api_key(self) -> SecretStr | None:
|
||||
if self._sandbox_api_key:
|
||||
decrypted = decrypt_value(self._sandbox_api_key)
|
||||
return SecretStr(decrypted)
|
||||
return None
|
||||
|
||||
@sandbox_api_key.setter
|
||||
def sandbox_api_key(self, value: str | SecretStr | None):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._sandbox_api_key = encrypt_value(raw) if raw else None
|
||||
67
enterprise/storage/org_member.py
Normal file
67
enterprise/storage/org_member.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
SQLAlchemy model for Organization-Member relationship.
|
||||
"""
|
||||
|
||||
from pydantic import SecretStr
|
||||
from sqlalchemy import UUID, Column, ForeignKey, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
from storage.encrypt_utils import decrypt_value, encrypt_value
|
||||
|
||||
|
||||
class OrgMember(Base): # type: ignore
|
||||
"""Junction table for organization-member relationships with roles."""
|
||||
|
||||
__tablename__ = 'org_member'
|
||||
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), primary_key=True)
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey('user.id'), primary_key=True)
|
||||
role_id = Column(Integer, ForeignKey('role.id'), nullable=False)
|
||||
_llm_api_key = Column(String, nullable=False)
|
||||
max_iterations = Column(Integer, nullable=True)
|
||||
llm_model = Column(String, nullable=True)
|
||||
_llm_api_key_for_byor = Column(String, nullable=True)
|
||||
llm_base_url = Column(String, nullable=True)
|
||||
status = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='org_members')
|
||||
user = relationship('User', back_populates='org_members')
|
||||
role = relationship('Role', back_populates='org_members')
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Handle known SQLAlchemy columns directly
|
||||
for key in list(kwargs):
|
||||
if hasattr(self.__class__, key):
|
||||
setattr(self, key, kwargs.pop(key))
|
||||
|
||||
# Handle custom property-style fields
|
||||
if 'llm_api_key' in kwargs:
|
||||
self.llm_api_key = kwargs.pop('llm_api_key')
|
||||
if 'llm_api_key_for_byor' in kwargs:
|
||||
self.llm_api_key_for_byor = kwargs.pop('llm_api_key_for_byor')
|
||||
|
||||
if kwargs:
|
||||
raise TypeError(f'Unexpected keyword arguments: {list(kwargs.keys())}')
|
||||
|
||||
@property
|
||||
def llm_api_key(self) -> SecretStr:
|
||||
decrypted = decrypt_value(self._llm_api_key)
|
||||
return SecretStr(decrypted)
|
||||
|
||||
@llm_api_key.setter
|
||||
def llm_api_key(self, value: str | SecretStr):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._llm_api_key = encrypt_value(raw)
|
||||
|
||||
@property
|
||||
def llm_api_key_for_byor(self) -> SecretStr | None:
|
||||
if self._llm_api_key_for_byor:
|
||||
decrypted = decrypt_value(self._llm_api_key_for_byor)
|
||||
return SecretStr(decrypted)
|
||||
return None
|
||||
|
||||
@llm_api_key_for_byor.setter
|
||||
def llm_api_key_for_byor(self, value: str | SecretStr | None):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._llm_api_key_for_byor = encrypt_value(raw) if raw else None
|
||||
125
enterprise/storage/org_member_store.py
Normal file
125
enterprise/storage/org_member_store.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
Store class for managing organization-member relationships.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from storage.database import session_maker
|
||||
from storage.org_member import OrgMember
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
class OrgMemberStore:
|
||||
"""Store for managing organization-member relationships."""
|
||||
|
||||
@staticmethod
|
||||
def add_user_to_org(
|
||||
org_id: UUID,
|
||||
user_id: UUID,
|
||||
role_id: int,
|
||||
llm_api_key: str,
|
||||
status: Optional[str] = None,
|
||||
) -> OrgMember:
|
||||
"""Add a user to an organization with a specific role."""
|
||||
with session_maker() as session:
|
||||
org_member = OrgMember(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
llm_api_key=llm_api_key,
|
||||
status=status,
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
session.refresh(org_member)
|
||||
return org_member
|
||||
|
||||
@staticmethod
|
||||
def get_org_member(org_id: UUID, user_id: int) -> Optional[OrgMember]:
|
||||
"""Get organization-user relationship."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(OrgMember)
|
||||
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs(user_id: int) -> list[OrgMember]:
|
||||
"""Get all organizations for a user."""
|
||||
with session_maker() as session:
|
||||
return session.query(OrgMember).filter(OrgMember.user_id == user_id).all()
|
||||
|
||||
@staticmethod
|
||||
def get_org_members(org_id: UUID) -> list[OrgMember]:
|
||||
"""Get all users in an organization."""
|
||||
with session_maker() as session:
|
||||
return session.query(OrgMember).filter(OrgMember.org_id == org_id).all()
|
||||
|
||||
@staticmethod
|
||||
def update_org_member(org_member: OrgMember) -> None:
|
||||
"""Update an organization-member relationship."""
|
||||
with session_maker() as session:
|
||||
session.merge(org_member)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def update_user_role_in_org(
|
||||
org_id: UUID, user_id: int, role_id: int, status: Optional[str] = None
|
||||
) -> Optional[OrgMember]:
|
||||
"""Update user's role in an organization."""
|
||||
with session_maker() as session:
|
||||
org_member = (
|
||||
session.query(OrgMember)
|
||||
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not org_member:
|
||||
return None
|
||||
|
||||
org_member.role_id = role_id
|
||||
if status is not None:
|
||||
org_member.status = status
|
||||
|
||||
session.commit()
|
||||
session.refresh(org_member)
|
||||
return org_member
|
||||
|
||||
@staticmethod
|
||||
def remove_user_from_org(org_id: UUID, user_id: int) -> bool:
|
||||
"""Remove a user from an organization."""
|
||||
with session_maker() as session:
|
||||
org_member = (
|
||||
session.query(OrgMember)
|
||||
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not org_member:
|
||||
return False
|
||||
|
||||
session.delete(org_member)
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_settings(settings: Settings):
|
||||
kwargs = {
|
||||
normalized: getattr(settings, normalized)
|
||||
for c in OrgMember.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_user_settings(user_settings: UserSettings):
|
||||
kwargs = {
|
||||
normalized: getattr(user_settings, normalized)
|
||||
for c in OrgMember.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(user_settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
162
enterprise/storage/org_store.py
Normal file
162
enterprise/storage/org_store.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
Store class for managing organizations.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from server.constants import (
|
||||
LITE_LLM_API_URL,
|
||||
ORG_SETTINGS_VERSION,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import session_maker
|
||||
from storage.org import Org
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
class OrgStore:
|
||||
"""Store for managing organizations."""
|
||||
|
||||
@staticmethod
|
||||
def create_org(
|
||||
kwargs: dict,
|
||||
) -> Org:
|
||||
"""Create a new organization."""
|
||||
with session_maker() as session:
|
||||
org = Org(**kwargs)
|
||||
org.org_version = ORG_SETTINGS_VERSION
|
||||
org.default_llm_model = get_default_litellm_model()
|
||||
session.add(org)
|
||||
session.commit()
|
||||
session.refresh(org)
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def get_org_by_id(org_id: UUID) -> Org | None:
|
||||
"""Get organization by ID."""
|
||||
org = None
|
||||
with session_maker() as session:
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
return OrgStore._validate_org_version(org)
|
||||
|
||||
@staticmethod
|
||||
def get_current_org_from_keycloak_user_id(keycloak_user_id: str) -> Org | None:
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == UUID(keycloak_user_id))
|
||||
.first()
|
||||
)
|
||||
if not user:
|
||||
logger.warning(f'User not found for ID {keycloak_user_id}')
|
||||
return None
|
||||
org_id = user.current_org_id
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
logger.warning(
|
||||
f'Org not found for ID {org_id} as the current org for user {keycloak_user_id}'
|
||||
)
|
||||
return None
|
||||
return OrgStore._validate_org_version(org)
|
||||
|
||||
@staticmethod
|
||||
def get_org_by_name(name: str) -> Org | None:
|
||||
"""Get organization by name."""
|
||||
org = None
|
||||
with session_maker() as session:
|
||||
org = session.query(Org).filter(Org.name == name).first()
|
||||
return OrgStore._validate_org_version(org)
|
||||
|
||||
@staticmethod
|
||||
def _validate_org_version(org: Org) -> Org | None:
|
||||
"""Check if we need to update org version."""
|
||||
if org and org.org_version < ORG_SETTINGS_VERSION:
|
||||
org = OrgStore.update_org(
|
||||
org.id,
|
||||
{
|
||||
'org_version': ORG_SETTINGS_VERSION,
|
||||
'default_llm_model': get_default_litellm_model(),
|
||||
'llm_base_url': LITE_LLM_API_URL,
|
||||
},
|
||||
)
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def list_orgs() -> list[Org]:
|
||||
"""List all organizations."""
|
||||
with session_maker() as session:
|
||||
orgs = session.query(Org).all()
|
||||
return orgs
|
||||
|
||||
@staticmethod
|
||||
def update_org(
|
||||
org_id: UUID,
|
||||
kwargs: dict,
|
||||
) -> Optional[Org]:
|
||||
"""Update organization details."""
|
||||
with session_maker() as session:
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
return None
|
||||
|
||||
if 'id' in kwargs:
|
||||
kwargs.pop('id')
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(org, key):
|
||||
setattr(org, key, value)
|
||||
|
||||
session.commit()
|
||||
session.refresh(org)
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_settings(settings: Settings):
|
||||
kwargs = {}
|
||||
|
||||
for c in Org.__table__.columns:
|
||||
# Normalize for lookup
|
||||
normalized = (
|
||||
c.name.removeprefix('_default_').removeprefix('default_').lstrip('_')
|
||||
)
|
||||
|
||||
if not hasattr(settings, normalized):
|
||||
continue
|
||||
|
||||
# ---- FIX: Output key should drop *only* leading "_" but preserve "default" ----
|
||||
key = c.name
|
||||
if key.startswith('_'):
|
||||
key = key[1:] # remove only the very first leading underscore
|
||||
|
||||
kwargs[key] = getattr(settings, normalized)
|
||||
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_user_settings(user_settings: UserSettings):
|
||||
kwargs = {}
|
||||
|
||||
for c in Org.__table__.columns:
|
||||
# Normalize for lookup
|
||||
normalized = (
|
||||
c.name.removeprefix('_default_').removeprefix('default_').lstrip('_')
|
||||
)
|
||||
|
||||
if not hasattr(user_settings, normalized):
|
||||
continue
|
||||
|
||||
# ---- FIX: Output key should drop *only* leading "_" but preserve "default" ----
|
||||
key = c.name
|
||||
if key.startswith('_'):
|
||||
key = key[1:] # remove only the very first leading underscore
|
||||
|
||||
kwargs[key] = getattr(user_settings, normalized)
|
||||
|
||||
kwargs['org_version'] = user_settings.user_version
|
||||
return kwargs
|
||||
21
enterprise/storage/role.py
Normal file
21
enterprise/storage/role.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
SQLAlchemy model for Role.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, Identity, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class Role(Base): # type: ignore
|
||||
"""Role model for user permissions."""
|
||||
|
||||
__tablename__ = 'role'
|
||||
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
name = Column(String, nullable=False, unique=True)
|
||||
rank = Column(Integer, nullable=False)
|
||||
|
||||
# Relationships
|
||||
users = relationship('User', back_populates='role')
|
||||
org_members = relationship('OrgMember', back_populates='role')
|
||||
40
enterprise/storage/role_store.py
Normal file
40
enterprise/storage/role_store.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
Store class for managing roles.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from storage.database import session_maker
|
||||
from storage.role import Role
|
||||
|
||||
|
||||
class RoleStore:
|
||||
"""Store for managing roles."""
|
||||
|
||||
@staticmethod
|
||||
def create_role(name: str, rank: int) -> Role:
|
||||
"""Create a new role."""
|
||||
with session_maker() as session:
|
||||
role = Role(name=name, rank=rank)
|
||||
session.add(role)
|
||||
session.commit()
|
||||
session.refresh(role)
|
||||
return role
|
||||
|
||||
@staticmethod
|
||||
def get_role_by_id(role_id: int) -> Optional[Role]:
|
||||
"""Get role by ID."""
|
||||
with session_maker() as session:
|
||||
return session.query(Role).filter(Role.id == role_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_role_by_name(name: str) -> Optional[Role]:
|
||||
"""Get role by name."""
|
||||
with session_maker() as session:
|
||||
return session.query(Role).filter(Role.name == name).first()
|
||||
|
||||
@staticmethod
|
||||
def list_roles() -> List[Role]:
|
||||
"""List all roles."""
|
||||
with session_maker() as session:
|
||||
return session.query(Role).order_by(Role.rank).all()
|
||||
@@ -4,10 +4,13 @@ import dataclasses
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.integrations.provider import ProviderType
|
||||
@@ -29,20 +32,37 @@ logger = logging.getLogger(__name__)
|
||||
class SaasConversationStore(ConversationStore):
|
||||
user_id: str
|
||||
session_maker: sessionmaker
|
||||
org_id: UUID | None = None # will be fetched automatically
|
||||
|
||||
def __init__(self, user_id: str, session_maker: sessionmaker):
|
||||
self.user_id = user_id
|
||||
self.session_maker = session_maker
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
self.org_id = user.current_org_id if user else None
|
||||
|
||||
def _select_by_id(self, session, conversation_id: str):
|
||||
return (
|
||||
# Join StoredConversationMetadata with ConversationMetadataSaas to filter by user/org
|
||||
query = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.user_id == self.user_id)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.filter(StoredConversationMetadataSaas.user_id == UUID(self.user_id))
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
.filter(StoredConversationMetadata.conversation_version == 'V0')
|
||||
)
|
||||
|
||||
if self.org_id is not None:
|
||||
query = query.filter(StoredConversationMetadataSaas.org_id == self.org_id)
|
||||
|
||||
return query
|
||||
|
||||
def _to_external_model(self, conversation_metadata: StoredConversationMetadata):
|
||||
kwargs = {
|
||||
c.name: getattr(conversation_metadata, c.name)
|
||||
for c in StoredConversationMetadata.__table__.columns
|
||||
if c.name != 'github_user_id' # Skip github_user_id field
|
||||
}
|
||||
# TODO: I'm not sure why the timezone is not set on the dates coming back out of the db
|
||||
kwargs['created_at'] = kwargs['created_at'].replace(tzinfo=UTC)
|
||||
@@ -53,6 +73,8 @@ class SaasConversationStore(ConversationStore):
|
||||
# Convert string to ProviderType enum
|
||||
kwargs['git_provider'] = ProviderType(kwargs['git_provider'])
|
||||
|
||||
kwargs['user_id'] = self.user_id
|
||||
|
||||
# Remove V1 attributes
|
||||
kwargs.pop('max_budget_per_task', None)
|
||||
kwargs.pop('cache_read_tokens', None)
|
||||
@@ -67,7 +89,10 @@ class SaasConversationStore(ConversationStore):
|
||||
|
||||
async def save_metadata(self, metadata: ConversationMetadata):
|
||||
kwargs = dataclasses.asdict(metadata)
|
||||
kwargs['user_id'] = self.user_id
|
||||
|
||||
# Remove user_id and org_id from kwargs since they're no longer in StoredConversationMetadata
|
||||
kwargs.pop('user_id', None)
|
||||
kwargs.pop('org_id', None)
|
||||
|
||||
# Convert ProviderType enum to string for storage
|
||||
if kwargs.get('git_provider') is not None:
|
||||
@@ -81,7 +106,41 @@ class SaasConversationStore(ConversationStore):
|
||||
|
||||
def _save_metadata():
|
||||
with self.session_maker() as session:
|
||||
# Save the main conversation metadata
|
||||
session.merge(stored_metadata)
|
||||
|
||||
# Create or update the SaaS metadata record
|
||||
saas_metadata = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id
|
||||
== stored_metadata.conversation_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not saas_metadata:
|
||||
saas_metadata = StoredConversationMetadataSaas(
|
||||
conversation_id=stored_metadata.conversation_id,
|
||||
user_id=UUID(self.user_id),
|
||||
org_id=self.org_id,
|
||||
)
|
||||
session.add(saas_metadata)
|
||||
else:
|
||||
# Validate
|
||||
expected_user_id = UUID(self.user_id)
|
||||
expected_org_id = self.org_id
|
||||
|
||||
if saas_metadata.user_id != expected_user_id:
|
||||
raise ValueError(
|
||||
f'Existing user_id ({saas_metadata.user_id}) does not match expected value ({expected_user_id}).'
|
||||
)
|
||||
|
||||
if expected_org_id and saas_metadata.org_id != expected_org_id:
|
||||
raise ValueError(
|
||||
f'Existing org_id ({saas_metadata.org_id}) does not match expected value ({expected_org_id}).'
|
||||
)
|
||||
|
||||
session.commit()
|
||||
|
||||
await call_sync_from_async(_save_metadata)
|
||||
@@ -101,8 +160,29 @@ class SaasConversationStore(ConversationStore):
|
||||
async def delete_metadata(self, conversation_id: str) -> None:
|
||||
def _delete_metadata():
|
||||
with self.session_maker() as session:
|
||||
self._select_by_id(session, conversation_id).delete()
|
||||
session.commit()
|
||||
saas_record = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id
|
||||
== conversation_id,
|
||||
StoredConversationMetadataSaas.user_id == UUID(self.user_id),
|
||||
StoredConversationMetadataSaas.org_id == self.org_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if saas_record:
|
||||
# Delete both records, but only if the SaaS one exists
|
||||
session.query(StoredConversationMetadata).filter(
|
||||
StoredConversationMetadata.conversation_id == conversation_id,
|
||||
).delete()
|
||||
|
||||
session.delete(saas_record)
|
||||
|
||||
session.commit()
|
||||
else:
|
||||
# No SaaS record found → skip deleting main metadata
|
||||
session.rollback()
|
||||
|
||||
await call_sync_from_async(_delete_metadata)
|
||||
|
||||
@@ -125,7 +205,15 @@ class SaasConversationStore(ConversationStore):
|
||||
with self.session_maker() as session:
|
||||
conversations = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.user_id == self.user_id)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.user_id == UUID(self.user_id)
|
||||
)
|
||||
.filter(StoredConversationMetadataSaas.org_id == self.org_id)
|
||||
.filter(StoredConversationMetadata.conversation_version == 'V0')
|
||||
.order_by(StoredConversationMetadata.created_at.desc())
|
||||
.offset(offset)
|
||||
|
||||
@@ -8,11 +8,13 @@ from cryptography.fernet import Fernet
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.stored_custom_secrets import StoredCustomSecrets
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.storage.data_models.secrets import Secrets
|
||||
from openhands.storage.secrets.secrets_store import SecretsStore
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -24,14 +26,17 @@ class SaasSecretsStore(SecretsStore):
|
||||
async def load(self) -> Secrets | None:
|
||||
if not self.user_id:
|
||||
return None
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, self.user_id)
|
||||
org_id = user.current_org_id if user else None
|
||||
|
||||
with self.session_maker() as session:
|
||||
# Fetch all secrets for the given user ID
|
||||
settings = (
|
||||
session.query(StoredCustomSecrets)
|
||||
.filter(StoredCustomSecrets.keycloak_user_id == self.user_id)
|
||||
.all()
|
||||
query = session.query(StoredCustomSecrets).filter(
|
||||
StoredCustomSecrets.keycloak_user_id == self.user_id
|
||||
)
|
||||
if org_id is not None:
|
||||
query = query.filter(StoredCustomSecrets.org_id == org_id)
|
||||
settings = query.all()
|
||||
|
||||
if not settings:
|
||||
return Secrets()
|
||||
@@ -48,6 +53,8 @@ class SaasSecretsStore(SecretsStore):
|
||||
return Secrets(custom_secrets=kwargs) # type: ignore[arg-type]
|
||||
|
||||
async def store(self, item: Secrets):
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, self.user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
# Incoming secrets are always the most updated ones
|
||||
# Delete all existing records and override with incoming ones
|
||||
@@ -76,6 +83,7 @@ class SaasSecretsStore(SecretsStore):
|
||||
for secret_name, secret_value, description in secret_tuples:
|
||||
new_secret = StoredCustomSecrets(
|
||||
keycloak_user_id=self.user_id,
|
||||
org_id=org_id,
|
||||
secret_name=secret_name,
|
||||
secret_value=secret_value,
|
||||
description=description,
|
||||
|
||||
@@ -1,46 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import binascii
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from base64 import b64decode, b64encode
|
||||
from dataclasses import dataclass
|
||||
|
||||
import httpx
|
||||
from cryptography.fernet import Fernet
|
||||
from integrations import stripe_service
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import (
|
||||
CURRENT_USER_SETTINGS_VERSION,
|
||||
DEFAULT_INITIAL_BUDGET,
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
LITE_LLM_TEAM_ID,
|
||||
REQUIRE_PAYMENT,
|
||||
USER_SETTINGS_VERSION_TO_MODEL,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import joinedload, sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_store import OrgStore
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.storage import get_file_store
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
# The max possible time to wait for another process to finish creating a user before retrying
|
||||
_REDIS_CREATE_TIMEOUT_SECONDS = 30
|
||||
# The delay to wait for another process to finish creating a user before trying to load again
|
||||
_RETRY_LOAD_DELAY_SECONDS = 2
|
||||
# Redis key prefix for user creation locks
|
||||
_REDIS_USER_CREATION_KEY_PREFIX = 'create_user:'
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -48,8 +30,9 @@ class SaasSettingsStore(SettingsStore):
|
||||
user_id: str
|
||||
session_maker: sessionmaker
|
||||
config: OpenHandsConfig
|
||||
ENCRYPT_VALUES = ['llm_api_key', 'llm_api_key_for_byor', 'search_api_key']
|
||||
|
||||
def get_user_settings_by_keycloak_id(
|
||||
def _get_user_settings_by_keycloak_id(
|
||||
self, keycloak_user_id: str, session=None
|
||||
) -> UserSettings | None:
|
||||
"""
|
||||
@@ -85,354 +68,105 @@ class SaasSettingsStore(SettingsStore):
|
||||
return _get_settings()
|
||||
|
||||
async def load(self) -> Settings | None:
|
||||
if not self.user_id:
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, self.user_id)
|
||||
if not user:
|
||||
logger.error(f'User not found for ID {self.user_id}')
|
||||
return None
|
||||
with self.session_maker() as session:
|
||||
settings = self.get_user_settings_by_keycloak_id(self.user_id, session)
|
||||
|
||||
if not settings or settings.user_version != CURRENT_USER_SETTINGS_VERSION:
|
||||
logger.info(
|
||||
'saas_settings_store:load:triggering_migration',
|
||||
extra={'user_id': self.user_id},
|
||||
org_id = user.current_org_id
|
||||
org_member: OrgMember = None
|
||||
for om in user.org_members:
|
||||
if om.org_id == org_id:
|
||||
org_member = om
|
||||
break
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return None
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
logger.error(
|
||||
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
|
||||
)
|
||||
return None
|
||||
kwargs = {
|
||||
**{
|
||||
normalized: getattr(org, c.name)
|
||||
for c in Org.__table__.columns
|
||||
if (
|
||||
normalized := c.name.removeprefix('_default_')
|
||||
.removeprefix('default_')
|
||||
.lstrip('_')
|
||||
)
|
||||
return await self.create_default_settings(settings)
|
||||
kwargs = {
|
||||
c.name: getattr(settings, c.name)
|
||||
for c in UserSettings.__table__.columns
|
||||
if c.name in Settings.model_fields
|
||||
}
|
||||
self._decrypt_kwargs(kwargs)
|
||||
settings = Settings(**kwargs)
|
||||
in Settings.model_fields
|
||||
},
|
||||
**{
|
||||
normalized: getattr(user, c.name)
|
||||
for c in User.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) in Settings.model_fields
|
||||
},
|
||||
}
|
||||
kwargs['llm_api_key'] = org_member.llm_api_key
|
||||
if org_member.max_iterations:
|
||||
kwargs['max_iterations'] = org_member.max_iterations
|
||||
if org_member.llm_model:
|
||||
kwargs['llm_model'] = org_member.llm_model
|
||||
if org_member.llm_api_key_for_byor:
|
||||
kwargs['llm_api_key_for_byor'] = org_member.llm_api_key_for_byor
|
||||
if org_member.llm_base_url:
|
||||
kwargs['llm_base_url'] = org_member.llm_base_url
|
||||
if org.v1_enabled is None:
|
||||
kwargs['v1_enabled'] = False
|
||||
|
||||
return settings
|
||||
settings = Settings(**kwargs)
|
||||
return settings
|
||||
|
||||
async def store(self, item: Settings):
|
||||
# Check if provider is OpenHands and generate API key if needed
|
||||
if item and self._is_openhands_provider(item):
|
||||
await self._ensure_openhands_api_key(item)
|
||||
|
||||
with self.session_maker() as session:
|
||||
existing = None
|
||||
kwargs = {}
|
||||
if item:
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
self._encrypt_kwargs(kwargs)
|
||||
# First check if we have an existing entry in the new table
|
||||
existing = self.get_user_settings_by_keycloak_id(self.user_id, session)
|
||||
|
||||
kwargs = {
|
||||
key: value
|
||||
for key, value in kwargs.items()
|
||||
if key in UserSettings.__table__.columns
|
||||
}
|
||||
if existing:
|
||||
# Update existing entry
|
||||
for key, value in kwargs.items():
|
||||
setattr(existing, key, value)
|
||||
existing.user_version = CURRENT_USER_SETTINGS_VERSION
|
||||
session.merge(existing)
|
||||
else:
|
||||
kwargs['keycloak_user_id'] = self.user_id
|
||||
kwargs['user_version'] = CURRENT_USER_SETTINGS_VERSION
|
||||
kwargs.pop('secrets_store', None) # Don't save secrets_store to db
|
||||
settings = UserSettings(**kwargs)
|
||||
session.add(settings)
|
||||
session.commit()
|
||||
|
||||
def _get_redis_client(self):
|
||||
"""Get the Redis client from the Socket.IO manager."""
|
||||
from openhands.server.shared import sio
|
||||
|
||||
return getattr(sio.manager, 'redis', None)
|
||||
|
||||
async def _acquire_user_creation_lock(self) -> bool:
|
||||
"""Attempt to acquire a distributed lock for user creation.
|
||||
|
||||
Returns True if the lock was acquired or if Redis is unavailable (fallback to no locking).
|
||||
Returns False if another process holds the lock.
|
||||
"""
|
||||
redis_client = self._get_redis_client()
|
||||
if redis_client is None:
|
||||
logger.warning(
|
||||
'saas_settings_store:_acquire_user_creation_lock:no_redis_client',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return True # Proceed without locking if Redis is unavailable
|
||||
|
||||
user_key = f'{_REDIS_USER_CREATION_KEY_PREFIX}{self.user_id}'
|
||||
lock_acquired = await redis_client.set(
|
||||
user_key, 1, nx=True, ex=_REDIS_CREATE_TIMEOUT_SECONDS
|
||||
)
|
||||
return bool(lock_acquired)
|
||||
|
||||
async def create_default_settings(self, user_settings: UserSettings | None):
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:start',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
# You must log in before you get default settings
|
||||
if not self.user_id:
|
||||
return None
|
||||
|
||||
# Prevent duplicate settings creation using distributed lock
|
||||
if not await self._acquire_user_creation_lock():
|
||||
# The user is already being created in another thread / process
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:waiting_for_lock',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
await asyncio.sleep(_RETRY_LOAD_DELAY_SECONDS)
|
||||
return await self.load()
|
||||
|
||||
# Only users that have specified a payment method get default settings
|
||||
if REQUIRE_PAYMENT and not await stripe_service.has_payment_method(
|
||||
self.user_id
|
||||
):
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:no_payment',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return None
|
||||
settings: Settings | None = None
|
||||
if user_settings is None:
|
||||
settings = Settings(
|
||||
language='en',
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
elif isinstance(user_settings, UserSettings):
|
||||
# Convert UserSettings (SQLAlchemy model) to Settings (Pydantic model)
|
||||
kwargs = {
|
||||
c.name: getattr(user_settings, c.name)
|
||||
for c in UserSettings.__table__.columns
|
||||
if c.name in Settings.model_fields
|
||||
}
|
||||
self._decrypt_kwargs(kwargs)
|
||||
settings = Settings(**kwargs)
|
||||
|
||||
if settings:
|
||||
settings = await self.update_settings_with_litellm_default(settings)
|
||||
if settings is None:
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:litellm_update_failed',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
await self.store(settings)
|
||||
return settings
|
||||
|
||||
async def load_legacy_file_store_settings(self, github_user_id: str):
|
||||
if not github_user_id:
|
||||
return None
|
||||
|
||||
file_store = get_file_store(self.config.file_store, self.config.file_store_path)
|
||||
path = f'users/github/{github_user_id}/settings.json'
|
||||
|
||||
try:
|
||||
json_str = await call_sync_from_async(file_store.read, path)
|
||||
logger.info(
|
||||
'saas_settings_store:load_legacy_file_store_settings:found',
|
||||
extra={'github_user_id': github_user_id},
|
||||
)
|
||||
kwargs = json.loads(json_str)
|
||||
self._decrypt_kwargs(kwargs)
|
||||
settings = Settings(**kwargs)
|
||||
return settings
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'saas_settings_store:load_legacy_file_store_settings:error',
|
||||
extra={'github_user_id': github_user_id, 'error': str(e)},
|
||||
)
|
||||
return None
|
||||
|
||||
def _has_custom_settings(
|
||||
self, settings: Settings, old_user_version: int | None
|
||||
) -> bool:
|
||||
"""
|
||||
Check if user has custom LLM settings that should be preserved.
|
||||
Returns True if user customized either model or base_url.
|
||||
|
||||
Args:
|
||||
settings: The user's current settings
|
||||
old_user_version: The user's old settings version, if any
|
||||
|
||||
Returns:
|
||||
True if user has custom settings, False if using old defaults
|
||||
"""
|
||||
# Normalize values
|
||||
user_model = (
|
||||
settings.llm_model.strip()
|
||||
if settings.llm_model and settings.llm_model.strip()
|
||||
else None
|
||||
)
|
||||
user_base_url = (
|
||||
settings.llm_base_url.strip()
|
||||
if settings.llm_base_url and settings.llm_base_url.strip()
|
||||
else None
|
||||
)
|
||||
|
||||
# Custom base_url = definitely custom settings (BYOK)
|
||||
if user_base_url and user_base_url != LITE_LLM_API_URL:
|
||||
return True
|
||||
|
||||
# No model set = using defaults
|
||||
if not user_model:
|
||||
return False
|
||||
|
||||
# Check if model matches old version's default
|
||||
if (
|
||||
old_user_version
|
||||
and old_user_version < CURRENT_USER_SETTINGS_VERSION
|
||||
and old_user_version in USER_SETTINGS_VERSION_TO_MODEL
|
||||
):
|
||||
old_default_base = USER_SETTINGS_VERSION_TO_MODEL[old_user_version]
|
||||
user_model_base = user_model.split('/')[-1]
|
||||
if user_model_base == old_default_base:
|
||||
return False # Matches old default
|
||||
|
||||
return True # Custom model
|
||||
|
||||
async def update_settings_with_litellm_default(
|
||||
self, settings: Settings
|
||||
) -> Settings | None:
|
||||
logger.info(
|
||||
'saas_settings_store:update_settings_with_litellm_default:start',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
return None
|
||||
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
|
||||
key = LITE_LLM_API_KEY
|
||||
|
||||
# Check if user has custom settings
|
||||
has_custom = self._has_custom_settings(settings, settings.user_version)
|
||||
|
||||
# Determine model to use (needed before LiteLLM user creation)
|
||||
llm_model_to_use = (
|
||||
settings.llm_model
|
||||
if has_custom and settings.llm_model
|
||||
else get_default_litellm_model()
|
||||
)
|
||||
|
||||
if not local_deploy:
|
||||
# Get user info to add to litellm
|
||||
token_manager = TokenManager()
|
||||
keycloak_user_info = (
|
||||
await token_manager.get_user_info_from_user_id(self.user_id) or {}
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
) as client:
|
||||
# Get the previous max budget to prevent accidental loss.
|
||||
#
|
||||
# LiteLLM v1.80+ returns 404 for non-existent users (previously returned empty user_info)
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={self.user_id}'
|
||||
)
|
||||
user_info: dict
|
||||
if response.status_code == 404:
|
||||
# New user - doesn't exist in LiteLLM yet (v1.80+ behavior)
|
||||
user_info = {}
|
||||
else:
|
||||
# For any other status, use standard error handling
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
user_info = response_json.get('user_info') or {}
|
||||
logger.info(
|
||||
f'creating_litellm_user: {self.user_id}; prev_max_budget: {user_info.get("max_budget")}; prev_metadata: {user_info.get("metadata")}'
|
||||
)
|
||||
max_budget = user_info.get('max_budget') or DEFAULT_INITIAL_BUDGET
|
||||
spend = user_info.get('spend') or 0
|
||||
if not item:
|
||||
return None
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(self.user_id))
|
||||
).first()
|
||||
|
||||
if not user:
|
||||
# Check if we need to migrate from user_settings
|
||||
user_settings = None
|
||||
with session_maker() as session:
|
||||
user_settings = self.get_user_settings_by_keycloak_id(
|
||||
user_settings = self._get_user_settings_by_keycloak_id(
|
||||
self.user_id, session
|
||||
)
|
||||
# In upgrade to V4, we no longer use billing margin, but instead apply this directly
|
||||
# in litellm. The default billing marign was 2 before this (hence the magic numbers below)
|
||||
if (
|
||||
user_settings
|
||||
and user_settings.user_version < 4
|
||||
and user_settings.billing_margin
|
||||
and user_settings.billing_margin != 1.0
|
||||
):
|
||||
billing_margin = user_settings.billing_margin
|
||||
logger.info(
|
||||
'user_settings_v4_budget_upgrade',
|
||||
extra={
|
||||
'max_budget': max_budget,
|
||||
'billing_margin': billing_margin,
|
||||
'spend': spend,
|
||||
},
|
||||
)
|
||||
max_budget *= billing_margin
|
||||
spend *= billing_margin
|
||||
user_settings.billing_margin = 1.0
|
||||
session.commit()
|
||||
|
||||
email = keycloak_user_info.get('email')
|
||||
|
||||
# We explicitly delete here to guard against odd inherited settings on upgrade.
|
||||
# We don't care if this fails with a 404
|
||||
await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/delete', json={'user_ids': [self.user_id]}
|
||||
)
|
||||
|
||||
# Create the new litellm user
|
||||
response = await self._create_user_in_lite_llm(
|
||||
client, email, max_budget, spend, llm_model_to_use
|
||||
)
|
||||
if not response.is_success:
|
||||
logger.warning(
|
||||
'duplicate_user_email',
|
||||
extra={'user_id': self.user_id, 'email': email},
|
||||
)
|
||||
# Litellm insists on unique email addresses - it is possible the email address was registered with a different user.
|
||||
response = await self._create_user_in_lite_llm(
|
||||
client, None, max_budget, spend, llm_model_to_use
|
||||
)
|
||||
|
||||
# User failed to create in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_creating_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [self.user_id],
|
||||
'email': email,
|
||||
'max_budget': max_budget,
|
||||
'spend': spend,
|
||||
},
|
||||
)
|
||||
if user_settings:
|
||||
user = await UserStore.migrate_user(self.user_id, user_settings)
|
||||
else:
|
||||
logger.error(f'User not found for ID {self.user_id}')
|
||||
return None
|
||||
|
||||
response_json = response.json()
|
||||
key = response_json['key']
|
||||
|
||||
logger.info(
|
||||
'saas_settings_store:update_settings_with_litellm_default:user_created',
|
||||
extra={'user_id': self.user_id},
|
||||
org_id = user.current_org_id
|
||||
# Check if provider is OpenHands and generate API key if needed
|
||||
if self._is_openhands_provider(item):
|
||||
await self._ensure_openhands_api_key(item, str(org_id))
|
||||
org_member = None
|
||||
for om in user.org_members:
|
||||
if om.org_id == org_id:
|
||||
org_member = om
|
||||
break
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return None
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
logger.error(
|
||||
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
|
||||
)
|
||||
return None
|
||||
|
||||
if has_custom:
|
||||
settings.llm_model = settings.llm_model or get_default_litellm_model()
|
||||
settings.llm_base_url = settings.llm_base_url or LITE_LLM_API_URL
|
||||
settings.llm_api_key = settings.llm_api_key or SecretStr(key)
|
||||
else:
|
||||
settings.llm_model = get_default_litellm_model()
|
||||
settings.llm_base_url = LITE_LLM_API_URL
|
||||
settings.llm_api_key = SecretStr(key)
|
||||
for model in (user, org, org_member):
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(model, key):
|
||||
setattr(model, key, value)
|
||||
|
||||
settings.agent = 'CodeActAgent'
|
||||
|
||||
return settings
|
||||
session.commit()
|
||||
|
||||
@classmethod
|
||||
async def get_instance(
|
||||
@@ -443,6 +177,9 @@ class SaasSettingsStore(SettingsStore):
|
||||
logger.debug(f'saas_settings_store.get_instance::{user_id}')
|
||||
return SaasSettingsStore(user_id, session_maker, config)
|
||||
|
||||
def _should_encrypt(self, key):
|
||||
return key in self.ENCRYPT_VALUES
|
||||
|
||||
def _decrypt_kwargs(self, kwargs: dict):
|
||||
fernet = self._fernet()
|
||||
for key, value in kwargs.items():
|
||||
@@ -486,21 +223,24 @@ class SaasSettingsStore(SettingsStore):
|
||||
fernet_key = b64encode(hashlib.sha256(jwt_secret.encode()).digest())
|
||||
return Fernet(fernet_key)
|
||||
|
||||
def _should_encrypt(self, key: str) -> bool:
|
||||
return key in ('llm_api_key', 'llm_api_key_for_byor', 'search_api_key')
|
||||
|
||||
def _is_openhands_provider(self, item: Settings) -> bool:
|
||||
"""Check if the settings use the OpenHands provider."""
|
||||
return bool(item.llm_model and item.llm_model.startswith('openhands/'))
|
||||
|
||||
async def _ensure_openhands_api_key(self, item: Settings) -> None:
|
||||
async def _ensure_openhands_api_key(self, item: Settings, org_id: str) -> None:
|
||||
"""Generate and set the OpenHands API key for the given settings.
|
||||
|
||||
First checks if an existing key with the OpenHands alias exists,
|
||||
and reuses it if found. Otherwise, generates a new key.
|
||||
"""
|
||||
# Generate new key if none exists
|
||||
generated_key = await self._generate_openhands_key()
|
||||
generated_key = await LiteLlmManager.generate_key(
|
||||
self.user_id,
|
||||
org_id,
|
||||
None,
|
||||
{'type': 'openhands'},
|
||||
)
|
||||
|
||||
if generated_key:
|
||||
item.llm_api_key = SecretStr(generated_key)
|
||||
logger.info(
|
||||
@@ -512,83 +252,3 @@ class SaasSettingsStore(SettingsStore):
|
||||
'saas_settings_store:store:failed_to_generate_openhands_key',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
|
||||
async def _create_user_in_lite_llm(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
email: str | None,
|
||||
max_budget: int,
|
||||
spend: int,
|
||||
llm_model: str,
|
||||
):
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/new',
|
||||
json={
|
||||
'user_email': email,
|
||||
'models': [],
|
||||
'max_budget': max_budget,
|
||||
'spend': spend,
|
||||
'user_id': str(self.user_id),
|
||||
'teams': [LITE_LLM_TEAM_ID],
|
||||
'auto_create_key': True,
|
||||
'send_invite_email': False,
|
||||
'metadata': {
|
||||
'version': CURRENT_USER_SETTINGS_VERSION,
|
||||
'model': llm_model,
|
||||
},
|
||||
'key_alias': f'OpenHands Cloud - user {self.user_id}',
|
||||
},
|
||||
)
|
||||
return response
|
||||
|
||||
async def _generate_openhands_key(self) -> str | None:
|
||||
"""Generate a new OpenHands provider key for a user."""
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'saas_settings_store:_generate_openhands_key:litellm_config_not_found',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
) as client:
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/generate',
|
||||
json={
|
||||
'user_id': self.user_id,
|
||||
'metadata': {'type': 'openhands'},
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
key = response_json.get('key')
|
||||
|
||||
if key:
|
||||
logger.info(
|
||||
'saas_settings_store:_generate_openhands_key:success',
|
||||
extra={
|
||||
'user_id': self.user_id,
|
||||
'key_length': len(key) if key else 0,
|
||||
'key_prefix': (
|
||||
key[:10] + '...' if key and len(key) > 10 else key
|
||||
),
|
||||
},
|
||||
)
|
||||
return key
|
||||
else:
|
||||
logger.error(
|
||||
'saas_settings_store:_generate_openhands_key:no_key_in_response',
|
||||
extra={'user_id': self.user_id, 'response_json': response_json},
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'saas_settings_store:_generate_openhands_key:error',
|
||||
extra={'user_id': self.user_id, 'error': str(e)},
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from sqlalchemy import Boolean, Column, Identity, Integer, String
|
||||
from sqlalchemy import Boolean, Column, ForeignKey, Identity, Integer, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -8,5 +10,9 @@ class SlackConversation(Base): # type: ignore
|
||||
conversation_id = Column(String, nullable=False, index=True)
|
||||
channel_id = Column(String, nullable=False)
|
||||
keycloak_user_id = Column(String, nullable=False)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
parent_id = Column(String, nullable=True, index=True)
|
||||
v1_enabled = Column(Boolean, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='slack_conversations')
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from sqlalchemy import Column, DateTime, Identity, Integer, String, text
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Identity, Integer, String, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -6,6 +8,7 @@ class SlackUser(Base): # type: ignore
|
||||
__tablename__ = 'slack_users'
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
keycloak_user_id = Column(String, nullable=False, index=True)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
slack_user_id = Column(String, nullable=False, index=True)
|
||||
slack_display_name = Column(String, nullable=False)
|
||||
created_at = Column(
|
||||
@@ -13,3 +16,6 @@ class SlackUser(Base): # type: ignore
|
||||
server_default=text('CURRENT_TIMESTAMP'),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='slack_users')
|
||||
|
||||
@@ -4,5 +4,4 @@ from openhands.app_server.app_conversation.sql_app_conversation_info_service imp
|
||||
|
||||
StoredConversationMetadata = _StoredConversationMetadata
|
||||
|
||||
|
||||
__all__ = ['StoredConversationMetadata']
|
||||
|
||||
28
enterprise/storage/stored_conversation_metadata_saas.py
Normal file
28
enterprise/storage/stored_conversation_metadata_saas.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
SQLAlchemy model for ConversationMetadataSaas.
|
||||
|
||||
This model stores the SaaS-specific metadata for conversations,
|
||||
containing only the conversation_id, user_id, and org_id.
|
||||
"""
|
||||
|
||||
from sqlalchemy import UUID as SQL_UUID
|
||||
from sqlalchemy import Column, ForeignKey, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class StoredConversationMetadataSaas(Base): # type: ignore
|
||||
"""SaaS conversation metadata model containing user and org associations."""
|
||||
|
||||
__tablename__ = 'conversation_metadata_saas'
|
||||
|
||||
conversation_id = Column(String, primary_key=True)
|
||||
user_id = Column(SQL_UUID(as_uuid=True), ForeignKey('user.id'), nullable=False)
|
||||
org_id = Column(SQL_UUID(as_uuid=True), ForeignKey('org.id'), nullable=False)
|
||||
|
||||
# Relationships
|
||||
user = relationship('User', back_populates='stored_conversation_metadata_saas')
|
||||
org = relationship('Org', back_populates='stored_conversation_metadata_saas')
|
||||
|
||||
|
||||
__all__ = ['StoredConversationMetadataSaas']
|
||||
@@ -1,4 +1,6 @@
|
||||
from sqlalchemy import Column, Identity, Integer, String
|
||||
from sqlalchemy import Column, ForeignKey, Identity, Integer, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -6,6 +8,10 @@ class StoredCustomSecrets(Base): # type: ignore
|
||||
__tablename__ = 'custom_secrets'
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
keycloak_user_id = Column(String, nullable=True, index=True)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
secret_name = Column(String, nullable=False)
|
||||
secret_value = Column(String, nullable=False)
|
||||
description = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='user_secrets')
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from sqlalchemy import Column, DateTime, Integer, String, text
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -13,6 +15,7 @@ class StripeCustomer(Base): # type: ignore
|
||||
__tablename__ = 'stripe_customers'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
keycloak_user_id = Column(String, nullable=False)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
stripe_customer_id = Column(String, nullable=False)
|
||||
created_at = Column(
|
||||
DateTime, server_default=text('CURRENT_TIMESTAMP'), nullable=False
|
||||
@@ -23,3 +26,6 @@ class StripeCustomer(Base): # type: ignore
|
||||
onupdate=text('CURRENT_TIMESTAMP'),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='stripe_customers')
|
||||
|
||||
41
enterprise/storage/user.py
Normal file
41
enterprise/storage/user.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
SQLAlchemy model for User.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import (
|
||||
UUID,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class User(Base): # type: ignore
|
||||
"""User model with organizational relationships."""
|
||||
|
||||
__tablename__ = 'user'
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
|
||||
current_org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=False)
|
||||
role_id = Column(Integer, ForeignKey('role.id'), nullable=True)
|
||||
accepted_tos = Column(DateTime, nullable=True)
|
||||
enable_sound_notifications = Column(Boolean, nullable=True)
|
||||
language = Column(String, nullable=True)
|
||||
user_consents_to_analytics = Column(Boolean, nullable=True)
|
||||
email = Column(String, nullable=True)
|
||||
email_verified = Column(Boolean, nullable=True)
|
||||
|
||||
# Relationships
|
||||
role = relationship('Role', back_populates='users')
|
||||
org_members = relationship('OrgMember', back_populates='user')
|
||||
current_org = relationship('Org', back_populates='current_users')
|
||||
stored_conversation_metadata_saas = relationship(
|
||||
'StoredConversationMetadataSaas', back_populates='user'
|
||||
)
|
||||
@@ -39,3 +39,6 @@ class UserSettings(Base): # type: ignore
|
||||
git_user_name = Column(String, nullable=True)
|
||||
git_user_email = Column(String, nullable=True)
|
||||
v1_enabled = Column(Boolean, nullable=True)
|
||||
already_migrated = Column(
|
||||
Boolean, nullable=True, default=False
|
||||
) # False = not migrated, True = migrated
|
||||
|
||||
466
enterprise/storage/user_store.py
Normal file
466
enterprise/storage/user_store.py
Normal file
@@ -0,0 +1,466 @@
|
||||
"""
|
||||
Store class for managing users.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from server.constants import (
|
||||
LITE_LLM_API_URL,
|
||||
ORG_SETTINGS_VERSION,
|
||||
PERSONAL_WORKSPACE_VERSION_TO_MODEL,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import session_maker
|
||||
from storage.encrypt_utils import decrypt_legacy_model
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role_store import RoleStore
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync
|
||||
|
||||
# The max possible time to wait for another process to finish creating a user before retrying
|
||||
_REDIS_CREATE_TIMEOUT_SECONDS = 30
|
||||
# The delay to wait for another process to finish creating a user before trying to load again
|
||||
_RETRY_LOAD_DELAY_SECONDS = 2
|
||||
# Redis key prefix for user creation locks
|
||||
_REDIS_USER_CREATION_KEY_PREFIX = 'create_user:'
|
||||
|
||||
|
||||
class UserStore:
|
||||
"""Store for managing users."""
|
||||
|
||||
@staticmethod
|
||||
async def create_user(
|
||||
user_id: str,
|
||||
user_info: dict,
|
||||
role_id: Optional[int] = None,
|
||||
) -> User | None:
|
||||
"""Create a new user."""
|
||||
with session_maker() as session:
|
||||
# create personal org
|
||||
org = Org(
|
||||
id=uuid.UUID(user_id),
|
||||
name=f'user_{user_id}_org',
|
||||
contact_name=user_info['preferred_username'],
|
||||
contact_email=user_info['email'],
|
||||
v1_enabled=True,
|
||||
)
|
||||
session.add(org)
|
||||
|
||||
settings = await UserStore.create_default_settings(
|
||||
org_id=str(org.id), user_id=user_id
|
||||
)
|
||||
|
||||
if not settings:
|
||||
return None
|
||||
|
||||
from storage.org_store import OrgStore
|
||||
|
||||
org_kwargs = OrgStore.get_kwargs_from_settings(settings)
|
||||
for key, value in org_kwargs.items():
|
||||
if hasattr(org, key):
|
||||
setattr(org, key, value)
|
||||
|
||||
user_kwargs = UserStore.get_kwargs_from_settings(settings)
|
||||
user = User(
|
||||
id=uuid.UUID(user_id),
|
||||
current_org_id=org.id,
|
||||
role_id=role_id,
|
||||
**user_kwargs,
|
||||
)
|
||||
session.add(user)
|
||||
|
||||
role = RoleStore.get_role_by_name('owner')
|
||||
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
|
||||
org_member_kwargs = OrgMemberStore.get_kwargs_from_settings(settings)
|
||||
# avoid setting org member llm fields to use org defaults on user creation
|
||||
del org_member_kwargs['llm_model']
|
||||
del org_member_kwargs['llm_base_url']
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id, # owner of your own org.
|
||||
status='active',
|
||||
**org_member_kwargs,
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
user.org_members # load org_members
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def _get_redis_client():
|
||||
"""Get the Redis client from the Socket.IO manager."""
|
||||
from openhands.server.shared import sio
|
||||
|
||||
return getattr(sio.manager, 'redis', None)
|
||||
|
||||
@staticmethod
|
||||
async def _acquire_user_creation_lock(user_id: str) -> bool:
|
||||
"""Attempt to acquire a distributed lock for user creation.
|
||||
|
||||
Returns True if the lock was acquired or if Redis is unavailable (fallback to no locking).
|
||||
Returns False if another process holds the lock.
|
||||
"""
|
||||
redis_client = UserStore._get_redis_client()
|
||||
if redis_client is None:
|
||||
logger.warning(
|
||||
'saas_settings_store:_acquire_user_creation_lock:no_redis_client',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True # Proceed without locking if Redis is unavailable
|
||||
|
||||
user_key = f'{_REDIS_USER_CREATION_KEY_PREFIX}{user_id}'
|
||||
lock_acquired = await redis_client.set(
|
||||
user_key, 1, nx=True, ex=_REDIS_CREATE_TIMEOUT_SECONDS
|
||||
)
|
||||
return bool(lock_acquired)
|
||||
|
||||
@staticmethod
|
||||
async def migrate_user(
|
||||
user_id: str,
|
||||
user_settings: UserSettings,
|
||||
user_info: dict,
|
||||
) -> User:
|
||||
if not user_id or not user_settings:
|
||||
return None
|
||||
|
||||
kwargs = decrypt_legacy_model(
|
||||
[
|
||||
'llm_api_key',
|
||||
'llm_api_key_for_byor',
|
||||
'search_api_key',
|
||||
'sandbox_api_key',
|
||||
],
|
||||
user_settings,
|
||||
)
|
||||
decrypted_user_settings = UserSettings(**kwargs)
|
||||
with session_maker() as session:
|
||||
# create personal org
|
||||
org = Org(
|
||||
id=uuid.UUID(user_id),
|
||||
name=f'user_{user_id}_org',
|
||||
org_version=user_settings.user_version,
|
||||
contact_name=user_info['username'],
|
||||
contact_email=user_info['email'],
|
||||
)
|
||||
session.add(org)
|
||||
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
await LiteLlmManager.migrate_entries(
|
||||
str(org.id),
|
||||
user_id,
|
||||
decrypted_user_settings,
|
||||
)
|
||||
|
||||
custom_settings = UserStore._has_custom_settings(
|
||||
decrypted_user_settings, user_settings.user_version
|
||||
)
|
||||
|
||||
# avoids circular reference. This migrate method is temprorary until all users are migrated.
|
||||
from integrations.stripe_service import migrate_customer
|
||||
|
||||
await migrate_customer(session, user_id, org)
|
||||
|
||||
from storage.org_store import OrgStore
|
||||
|
||||
org_kwargs = OrgStore.get_kwargs_from_user_settings(decrypted_user_settings)
|
||||
org_kwargs.pop('id', None)
|
||||
|
||||
# if user has custom settings, set org defaults to current version
|
||||
if custom_settings:
|
||||
org_kwargs['default_llm_model'] = get_default_litellm_model()
|
||||
org_kwargs['llm_base_url'] = LITE_LLM_API_URL
|
||||
org_kwargs['org_version'] = ORG_SETTINGS_VERSION
|
||||
|
||||
for key, value in org_kwargs.items():
|
||||
if hasattr(org, key):
|
||||
setattr(org, key, value)
|
||||
|
||||
user_kwargs = UserStore.get_kwargs_from_user_settings(
|
||||
decrypted_user_settings
|
||||
)
|
||||
user_kwargs.pop('id', None)
|
||||
user = User(
|
||||
id=uuid.UUID(user_id),
|
||||
current_org_id=org.id,
|
||||
role_id=None,
|
||||
**user_kwargs,
|
||||
)
|
||||
session.add(user)
|
||||
|
||||
role = RoleStore.get_role_by_name('owner')
|
||||
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
|
||||
org_member_kwargs = OrgMemberStore.get_kwargs_from_user_settings(
|
||||
decrypted_user_settings
|
||||
)
|
||||
|
||||
# if the user did not have custom settings in the old model,
|
||||
# then use the org defaults by not setting org_member fields
|
||||
if not custom_settings:
|
||||
del org_member_kwargs['llm_model']
|
||||
del org_member_kwargs['llm_base_url']
|
||||
del org_member_kwargs['llm_api_key_for_byor']
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id, # owner of your own org.
|
||||
status='active',
|
||||
**org_member_kwargs,
|
||||
)
|
||||
session.add(org_member)
|
||||
|
||||
# Mark the old user_settings as migrated instead of deleting
|
||||
user_settings.already_migrated = True
|
||||
session.merge(user_settings)
|
||||
session.flush()
|
||||
|
||||
# need to migrate conversation metadata
|
||||
session.execute(
|
||||
text("""
|
||||
INSERT INTO conversation_metadata_saas (conversation_id, user_id, org_id)
|
||||
SELECT
|
||||
conversation_id,
|
||||
:user_id,
|
||||
:user_id
|
||||
FROM conversation_metadata
|
||||
WHERE user_id = :user_id
|
||||
"""),
|
||||
{'user_id': user_id},
|
||||
)
|
||||
|
||||
# Update org_id for tables that had org_id added
|
||||
user_uuid = uuid.UUID(user_id)
|
||||
|
||||
# Update stripe_customers
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE stripe_customers SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update slack_users
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE slack_users SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update slack_conversation
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE slack_conversation SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update api_keys
|
||||
session.execute(
|
||||
text('UPDATE api_keys SET org_id = :org_id WHERE user_id = :user_id'),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update custom_secrets
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE custom_secrets SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update billing_sessions
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE billing_sessions SET org_id = :org_id WHERE user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
user.org_members # load org_members
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_id(user_id: str) -> Optional[User]:
|
||||
"""Get user by Keycloak user ID."""
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
if user:
|
||||
return user
|
||||
|
||||
# Check if we need to migrate from user_settings
|
||||
while not call_async_from_sync(
|
||||
UserStore._acquire_user_creation_lock, GENERAL_TIMEOUT, user_id
|
||||
):
|
||||
# The user is already being created in another thread / process
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:waiting_for_lock',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
call_async_from_sync(
|
||||
asyncio.sleep, GENERAL_TIMEOUT, _RETRY_LOAD_DELAY_SECONDS
|
||||
)
|
||||
|
||||
# Check for user again as migration could have happened while trying to get the lock.
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
if user:
|
||||
return user
|
||||
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id == user_id,
|
||||
UserSettings.already_migrated.is_(False),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if user_settings:
|
||||
from server.auth.token_manager import TokenManager
|
||||
|
||||
token_manager = TokenManager()
|
||||
user_info = call_async_from_sync(
|
||||
token_manager.get_user_info_from_user_id,
|
||||
GENERAL_TIMEOUT,
|
||||
user_id,
|
||||
)
|
||||
user = call_async_from_sync(
|
||||
UserStore.migrate_user,
|
||||
GENERAL_TIMEOUT,
|
||||
user_id,
|
||||
user_settings,
|
||||
user_info,
|
||||
)
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def list_users() -> list[User]:
|
||||
"""List all users."""
|
||||
with session_maker() as session:
|
||||
return session.query(User).all()
|
||||
|
||||
# Prevent circular imports
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
@staticmethod
|
||||
async def create_default_settings(
|
||||
org_id: str, user_id: str
|
||||
) -> Optional['Settings']:
|
||||
logger.info(
|
||||
'UserStore:create_default_settings:start',
|
||||
extra={'org_id': org_id, 'user_id': user_id},
|
||||
)
|
||||
# You must log in before you get default settings
|
||||
if not org_id:
|
||||
return None
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
settings = Settings(language='en', enable_proactive_conversation_starters=True)
|
||||
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
settings = await LiteLlmManager.create_entries(org_id, user_id, settings)
|
||||
if not settings:
|
||||
logger.info(
|
||||
'UserStore:create_default_settings:litellm_create_failed',
|
||||
extra={'org_id': org_id},
|
||||
)
|
||||
return None
|
||||
|
||||
return settings
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_settings(settings: 'Settings'):
|
||||
kwargs = {
|
||||
normalized: getattr(settings, normalized)
|
||||
for c in User.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_user_settings(user_settings: UserSettings):
|
||||
kwargs = {
|
||||
normalized: getattr(user_settings, normalized)
|
||||
for c in User.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(user_settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def _has_custom_settings(
|
||||
user_settings: UserSettings, old_user_version: int | None
|
||||
) -> bool:
|
||||
"""
|
||||
Check if user has custom LLM settings that should be preserved.
|
||||
Returns True if user customized either model or base_url.
|
||||
|
||||
Args:
|
||||
settings: The user's current settings
|
||||
old_user_version: The user's old settings version, if any
|
||||
|
||||
Returns:
|
||||
True if user has custom settings, False if using old defaults
|
||||
"""
|
||||
# Normalize values
|
||||
user_model = (
|
||||
user_settings.llm_model.strip() or None if user_settings.llm_model else None
|
||||
)
|
||||
user_base_url = (
|
||||
user_settings.llm_base_url.strip() or None
|
||||
if user_settings.llm_base_url
|
||||
else None
|
||||
)
|
||||
|
||||
# Custom base_url = definitely custom settings (BYOK)
|
||||
if user_base_url and user_base_url != LITE_LLM_API_URL:
|
||||
return True
|
||||
|
||||
# No model set = using defaults
|
||||
if not user_model:
|
||||
return False
|
||||
|
||||
# Check if model matches old version's default
|
||||
if (
|
||||
old_user_version
|
||||
and old_user_version <= ORG_SETTINGS_VERSION
|
||||
and old_user_version in PERSONAL_WORKSPACE_VERSION_TO_MODEL
|
||||
):
|
||||
old_default_base = PERSONAL_WORKSPACE_VERSION_TO_MODEL[old_user_version]
|
||||
user_model_base = user_model.split('/')[-1]
|
||||
if user_model_base == old_default_base:
|
||||
return False # Matches old default
|
||||
|
||||
return True # Custom model
|
||||
@@ -1,10 +1,9 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from server.constants import CURRENT_USER_SETTINGS_VERSION
|
||||
from server.maintenance_task_processor.user_version_upgrade_processor import (
|
||||
UserVersionUpgradeProcessor,
|
||||
)
|
||||
from server.constants import ORG_SETTINGS_VERSION
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.base import Base
|
||||
@@ -15,11 +14,16 @@ from storage.conversation_work import ConversationWork
|
||||
from storage.device_code import DeviceCode # noqa: F401
|
||||
from storage.feedback import Feedback
|
||||
from storage.github_app_installation import GithubAppInstallation
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role import Role
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import (
|
||||
StoredConversationMetadataSaas,
|
||||
)
|
||||
from storage.stored_offline_token import StoredOfflineToken
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -68,7 +72,6 @@ def add_minimal_fixtures(session_maker):
|
||||
session.add(
|
||||
StoredConversationMetadata(
|
||||
conversation_id='mock-conversation-id',
|
||||
user_id='mock-user-id',
|
||||
created_at=datetime.fromisoformat('2025-03-07'),
|
||||
last_updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
accumulated_cost=5.25,
|
||||
@@ -77,6 +80,13 @@ def add_minimal_fixtures(session_maker):
|
||||
total_tokens=750,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
StoredConversationMetadataSaas(
|
||||
conversation_id='mock-conversation-id',
|
||||
user_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
org_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
StoredOfflineToken(
|
||||
user_id='mock-user-id',
|
||||
@@ -85,7 +95,38 @@ def add_minimal_fixtures(session_maker):
|
||||
updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
)
|
||||
)
|
||||
|
||||
session.add(
|
||||
Org(
|
||||
id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
name='mock-org',
|
||||
org_version=ORG_SETTINGS_VERSION,
|
||||
enable_default_condenser=True,
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
Role(
|
||||
id=1,
|
||||
name='admin',
|
||||
rank=1,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
User(
|
||||
id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
current_org_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
user_consents_to_analytics=True,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
OrgMember(
|
||||
org_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
user_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
role_id=1,
|
||||
llm_api_key='mock-api-key',
|
||||
status='active',
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
StripeCustomer(
|
||||
keycloak_user_id='mock-user-id',
|
||||
@@ -94,13 +135,6 @@ def add_minimal_fixtures(session_maker):
|
||||
updated_at=datetime.fromisoformat('2025-03-10'),
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
UserSettings(
|
||||
keycloak_user_id='mock-user-id',
|
||||
user_consents_to_analytics=True,
|
||||
user_version=CURRENT_USER_SETTINGS_VERSION,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
ConversationWork(
|
||||
conversation_id='mock-conversation-id',
|
||||
@@ -109,17 +143,6 @@ def add_minimal_fixtures(session_maker):
|
||||
updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
)
|
||||
)
|
||||
maintenance_task = MaintenanceTask(
|
||||
status=MaintenanceTaskStatus.PENDING,
|
||||
)
|
||||
maintenance_task.set_processor(
|
||||
UserVersionUpgradeProcessor(
|
||||
user_ids=['mock-user-id'],
|
||||
created_at=datetime.fromisoformat('2025-03-07'),
|
||||
updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
)
|
||||
)
|
||||
session.add(maintenance_task)
|
||||
session.commit()
|
||||
|
||||
|
||||
|
||||
@@ -7,16 +7,16 @@ import pytest
|
||||
from fastapi import HTTPException
|
||||
from server.routes.api_keys import (
|
||||
get_llm_api_key_for_byor,
|
||||
verify_byor_key_in_litellm,
|
||||
)
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
|
||||
class TestVerifyByorKeyInLitellm:
|
||||
"""Test the verify_byor_key_in_litellm function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('storage.lite_llm_manager.httpx.AsyncClient')
|
||||
async def test_verify_valid_key_returns_true(self, mock_client_class):
|
||||
"""Test that a valid key (200 response) returns True."""
|
||||
# Arrange
|
||||
@@ -32,7 +32,7 @@ class TestVerifyByorKeyInLitellm:
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Act
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
@@ -42,8 +42,8 @@ class TestVerifyByorKeyInLitellm:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('storage.lite_llm_manager.httpx.AsyncClient')
|
||||
async def test_verify_invalid_key_401_returns_false(self, mock_client_class):
|
||||
"""Test that an invalid key (401 response) returns False."""
|
||||
# Arrange
|
||||
@@ -58,14 +58,14 @@ class TestVerifyByorKeyInLitellm:
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Act
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('storage.lite_llm_manager.httpx.AsyncClient')
|
||||
async def test_verify_invalid_key_403_returns_false(self, mock_client_class):
|
||||
"""Test that an invalid key (403 response) returns False."""
|
||||
# Arrange
|
||||
@@ -80,14 +80,14 @@ class TestVerifyByorKeyInLitellm:
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Act
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('storage.lite_llm_manager.httpx.AsyncClient')
|
||||
async def test_verify_server_error_returns_false(self, mock_client_class):
|
||||
"""Test that a server error (500) returns False to ensure key validity."""
|
||||
# Arrange
|
||||
@@ -103,14 +103,14 @@ class TestVerifyByorKeyInLitellm:
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Act
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('storage.lite_llm_manager.httpx.AsyncClient')
|
||||
async def test_verify_timeout_returns_false(self, mock_client_class):
|
||||
"""Test that a timeout returns False to ensure key validity."""
|
||||
# Arrange
|
||||
@@ -123,14 +123,14 @@ class TestVerifyByorKeyInLitellm:
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Act
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('storage.lite_llm_manager.httpx.AsyncClient')
|
||||
async def test_verify_network_error_returns_false(self, mock_client_class):
|
||||
"""Test that a network error returns False to ensure key validity."""
|
||||
# Arrange
|
||||
@@ -143,13 +143,13 @@ class TestVerifyByorKeyInLitellm:
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Act
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', None)
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', None)
|
||||
async def test_verify_missing_api_url_returns_false(self):
|
||||
"""Test that missing LITE_LLM_API_URL returns False."""
|
||||
# Arrange
|
||||
@@ -157,13 +157,13 @@ class TestVerifyByorKeyInLitellm:
|
||||
user_id = 'user-123'
|
||||
|
||||
# Act
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
async def test_verify_empty_key_returns_false(self):
|
||||
"""Test that empty key returns False."""
|
||||
# Arrange
|
||||
@@ -171,7 +171,7 @@ class TestVerifyByorKeyInLitellm:
|
||||
user_id = 'user-123'
|
||||
|
||||
# Act
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
@@ -205,7 +205,7 @@ class TestGetLlmApiKeyForByor:
|
||||
mock_store_key.assert_called_once_with(user_id, new_key)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.verify_byor_key_in_litellm')
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.verify_key')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_valid_key_in_database_returns_key(
|
||||
self, mock_get_key, mock_verify_key
|
||||
@@ -229,7 +229,7 @@ class TestGetLlmApiKeyForByor:
|
||||
@patch('server.routes.api_keys.store_byor_key_in_db')
|
||||
@patch('server.routes.api_keys.generate_byor_key')
|
||||
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
|
||||
@patch('server.routes.api_keys.verify_byor_key_in_litellm')
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.verify_key')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_invalid_key_in_database_regenerates(
|
||||
self,
|
||||
@@ -265,7 +265,7 @@ class TestGetLlmApiKeyForByor:
|
||||
@patch('server.routes.api_keys.store_byor_key_in_db')
|
||||
@patch('server.routes.api_keys.generate_byor_key')
|
||||
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
|
||||
@patch('server.routes.api_keys.verify_byor_key_in_litellm')
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.verify_key')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_invalid_key_deletion_failure_still_regenerates(
|
||||
self,
|
||||
|
||||
@@ -82,7 +82,7 @@ class TestGetUserId:
|
||||
session_maker_with_minimal_fixtures,
|
||||
):
|
||||
user_id = _get_user_id('mock-conversation-id')
|
||||
assert user_id == 'mock-user-id'
|
||||
assert user_id == '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
|
||||
def test_get_user_id_conversation_not_found(self, session_maker):
|
||||
"""Test getting user ID when conversation doesn't exist."""
|
||||
@@ -105,10 +105,12 @@ class TestGetSessionApiKey:
|
||||
return_value=[mock_agent_loop_info]
|
||||
)
|
||||
|
||||
api_key = await _get_session_api_key('user-123', 'conv-456')
|
||||
api_key = await _get_session_api_key(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081', 'conv-456'
|
||||
)
|
||||
assert api_key == 'test-api-key'
|
||||
mock_manager.get_agent_loop_info.assert_called_once_with(
|
||||
'user-123', filter_to_sids={'conv-456'}
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081', filter_to_sids={'conv-456'}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -118,7 +120,9 @@ class TestGetSessionApiKey:
|
||||
mock_manager.get_agent_loop_info = AsyncMock(return_value=[])
|
||||
|
||||
with pytest.raises(IndexError):
|
||||
await _get_session_api_key('user-123', 'conv-456')
|
||||
await _get_session_api_key(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081', 'conv-456'
|
||||
)
|
||||
|
||||
|
||||
class TestProcessEvent:
|
||||
@@ -142,10 +146,15 @@ class TestProcessEvent:
|
||||
mock_event = MagicMock()
|
||||
mock_event_from_dict.return_value = mock_event
|
||||
|
||||
await process_event('user-123', 'conv-456', 'events/event-1.json', content)
|
||||
await process_event(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
'conv-456',
|
||||
'events/event-1.json',
|
||||
content,
|
||||
)
|
||||
|
||||
mock_file_store.write.assert_called_once_with(
|
||||
'users/user-123/conversations/conv-456/events/event-1.json',
|
||||
'users/5594c7b6-f959-4b81-92e9-b09c206f5081/conversations/conv-456/events/event-1.json',
|
||||
json.dumps(content),
|
||||
)
|
||||
mock_event_from_dict.assert_called_once_with(content)
|
||||
@@ -177,14 +186,19 @@ class TestProcessEvent:
|
||||
)
|
||||
mock_event_from_dict.return_value = mock_event
|
||||
|
||||
await process_event('user-123', 'conv-456', 'events/event-1.json', content)
|
||||
await process_event(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
'conv-456',
|
||||
'events/event-1.json',
|
||||
content,
|
||||
)
|
||||
|
||||
mock_file_store.write.assert_called_once()
|
||||
mock_event_from_dict.assert_called_once_with(content)
|
||||
mock_invoke_callbacks.assert_called_once_with('conv-456', mock_event)
|
||||
mock_update_working_seconds.assert_called_once()
|
||||
mock_event_store_class.assert_called_once_with(
|
||||
'conv-456', mock_file_store, 'user-123'
|
||||
'conv-456', mock_file_store, '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -212,7 +226,12 @@ class TestProcessEvent:
|
||||
mock_event.agent_state = 'running' # Set RUNNING state to skip the update
|
||||
mock_event_from_dict.return_value = mock_event
|
||||
|
||||
await process_event('user-123', 'conv-456', 'events/event-1.json', content)
|
||||
await process_event(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
'conv-456',
|
||||
'events/event-1.json',
|
||||
content,
|
||||
)
|
||||
|
||||
mock_file_store.write.assert_called_once()
|
||||
mock_event_from_dict.assert_called_once_with(content)
|
||||
|
||||
@@ -0,0 +1,368 @@
|
||||
"""Tests for SaasSQLAppConversationInfoService.
|
||||
|
||||
This module tests the SAAS implementation of SQLAppConversationInfoService,
|
||||
focusing on user isolation, SAAS metadata handling, and multi-tenant functionality.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from enterprise.server.utils.saas_app_conversation_info_injector import (
|
||||
SaasSQLAppConversationInfoService,
|
||||
)
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
)
|
||||
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
|
||||
from openhands.app_server.utils.sql_utils import Base
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create an async session for testing."""
|
||||
async_session_maker = async_sessionmaker(
|
||||
async_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with async_session_maker() as db_session:
|
||||
yield db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(async_session) -> SaasSQLAppConversationInfoService:
|
||||
"""Create a SQLAppConversationInfoService instance for testing."""
|
||||
return SaasSQLAppConversationInfoService(
|
||||
db_session=async_session, user_context=SpecifyUserContext(user_id=None)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service_with_user(async_session) -> SaasSQLAppConversationInfoService:
|
||||
"""Create a SQLAppConversationInfoService instance with a user_id for testing."""
|
||||
return SaasSQLAppConversationInfoService(
|
||||
db_session=async_session,
|
||||
user_context=SpecifyUserContext(user_id='a1111111-1111-1111-1111-111111111111'),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_conversation_info() -> AppConversationInfo:
|
||||
"""Create a sample AppConversationInfo for testing."""
|
||||
return AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='a1111111-1111-1111-1111-111111111111',
|
||||
sandbox_id='sandbox_123',
|
||||
selected_repository='https://github.com/test/repo',
|
||||
selected_branch='main',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
title='Test Conversation',
|
||||
trigger=ConversationTrigger.GUI,
|
||||
pr_number=[123, 456],
|
||||
llm_model='gpt-4',
|
||||
metrics=None,
|
||||
created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def multiple_conversation_infos() -> list[AppConversationInfo]:
|
||||
"""Create multiple AppConversationInfo instances for testing."""
|
||||
base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
return [
|
||||
AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=None,
|
||||
sandbox_id=f'sandbox_{i}',
|
||||
selected_repository=f'https://github.com/test/repo{i}',
|
||||
selected_branch='main',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
title=f'Test Conversation {i}',
|
||||
trigger=ConversationTrigger.GUI,
|
||||
pr_number=[i * 100],
|
||||
llm_model='gpt-4',
|
||||
metrics=None,
|
||||
created_at=base_time.replace(hour=12 + i),
|
||||
updated_at=base_time.replace(hour=12 + i, minute=30),
|
||||
)
|
||||
for i in range(1, 6) # Create 5 conversations
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user1_context():
|
||||
"""Create user context for user1."""
|
||||
return SpecifyUserContext(user_id='a1111111-1111-1111-1111-111111111111')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user2_context():
|
||||
"""Create user context for user2."""
|
||||
return SpecifyUserContext(user_id='b2222222-2222-2222-2222-222222222222')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def saas_service_user1(mock_db_session, user1_context):
|
||||
"""Create a SaasSQLAppConversationInfoService instance for user1."""
|
||||
return SaasSQLAppConversationInfoService(
|
||||
db_session=mock_db_session, user_context=user1_context
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def saas_service_user2(mock_db_session, user2_context):
|
||||
"""Create a SaasSQLAppConversationInfoService instance for user2."""
|
||||
return SaasSQLAppConversationInfoService(
|
||||
db_session=mock_db_session, user_context=user2_context
|
||||
)
|
||||
|
||||
|
||||
class TestSaasSQLAppConversationInfoService:
|
||||
"""Test suite for SaasSQLAppConversationInfoService."""
|
||||
|
||||
def test_service_initialization(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
user1_context: SpecifyUserContext,
|
||||
):
|
||||
"""Test that the SAAS service is properly initialized."""
|
||||
assert saas_service_user1.user_context == user1_context
|
||||
assert saas_service_user1.db_session is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_context_isolation(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
saas_service_user2: SaasSQLAppConversationInfoService,
|
||||
):
|
||||
"""Test that different service instances have different user contexts."""
|
||||
user1_id = await saas_service_user1.user_context.get_user_id()
|
||||
user2_id = await saas_service_user2.user_context.get_user_id()
|
||||
|
||||
assert user1_id == 'a1111111-1111-1111-1111-111111111111'
|
||||
assert user2_id == 'b2222222-2222-2222-2222-222222222222'
|
||||
assert user1_id != user2_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_secure_select_includes_user_filtering(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
):
|
||||
"""Test that _secure_select method includes user filtering."""
|
||||
# This test verifies that the _secure_select method exists and can be called
|
||||
# The actual SQL generation is tested implicitly through integration
|
||||
query = await saas_service_user1._secure_select()
|
||||
assert query is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_to_info_with_user_id_functionality(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
):
|
||||
"""Test that _to_info_with_user_id properly sets user_id from SAAS metadata."""
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import (
|
||||
StoredConversationMetadataSaas,
|
||||
)
|
||||
|
||||
# Create mock metadata objects
|
||||
stored_metadata = MagicMock(spec=StoredConversationMetadata)
|
||||
stored_metadata.conversation_id = '12345678-1234-5678-1234-567812345678'
|
||||
stored_metadata.parent_conversation_id = None
|
||||
stored_metadata.title = 'Test Conversation'
|
||||
stored_metadata.sandbox_id = 'test-sandbox'
|
||||
stored_metadata.selected_repository = None
|
||||
stored_metadata.selected_branch = None
|
||||
stored_metadata.git_provider = None
|
||||
stored_metadata.trigger = None
|
||||
stored_metadata.pr_number = []
|
||||
stored_metadata.llm_model = None
|
||||
from datetime import datetime, timezone
|
||||
|
||||
stored_metadata.created_at = datetime.now(timezone.utc)
|
||||
stored_metadata.last_updated_at = datetime.now(timezone.utc)
|
||||
stored_metadata.accumulated_cost = 0.0
|
||||
stored_metadata.prompt_tokens = 0
|
||||
stored_metadata.completion_tokens = 0
|
||||
stored_metadata.total_tokens = 0
|
||||
stored_metadata.max_budget_per_task = None
|
||||
stored_metadata.cache_read_tokens = 0
|
||||
stored_metadata.cache_write_tokens = 0
|
||||
stored_metadata.reasoning_tokens = 0
|
||||
stored_metadata.context_window = 0
|
||||
stored_metadata.per_turn_token = 0
|
||||
|
||||
saas_metadata = MagicMock(spec=StoredConversationMetadataSaas)
|
||||
saas_metadata.user_id = UUID('a1111111-1111-1111-1111-111111111111')
|
||||
saas_metadata.org_id = UUID('a1111111-1111-1111-1111-111111111111')
|
||||
|
||||
# Test the _to_info_with_user_id method
|
||||
result = saas_service_user1._to_info_with_user_id(
|
||||
stored_metadata, saas_metadata
|
||||
)
|
||||
|
||||
# Verify that the user_id from SAAS metadata is used
|
||||
assert result.created_by_user_id == 'a1111111-1111-1111-1111-111111111111'
|
||||
assert result.title == 'Test Conversation'
|
||||
assert result.sandbox_id == 'test-sandbox'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_isolation(
|
||||
self,
|
||||
async_session: AsyncSession,
|
||||
multiple_conversation_infos: list[AppConversationInfo],
|
||||
):
|
||||
"""Test that user isolation works correctly."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from storage.user import User
|
||||
|
||||
# Mock the database session execute method to return mock users
|
||||
# This mock intercepts User queries and returns a mock user object
|
||||
# with user_id and org_id the same as the user_id_uuid from the query
|
||||
original_execute = async_session.execute
|
||||
|
||||
async def mock_execute(query):
|
||||
query_str = str(query)
|
||||
|
||||
# Check if this is a User query
|
||||
if '"user"' in query_str.lower() and '"user".id' in query_str.lower():
|
||||
# Extract the UUID from the query parameters
|
||||
# The query will have bound parameters, we need to get the UUID value
|
||||
if hasattr(query, 'compile'):
|
||||
try:
|
||||
compiled = query.compile(compile_kwargs={'literal_binds': True})
|
||||
query_with_params = str(compiled)
|
||||
|
||||
# Extract UUID from the query string
|
||||
import re
|
||||
|
||||
# Try both formats: with dashes and without dashes
|
||||
uuid_pattern_with_dashes = r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}'
|
||||
uuid_pattern_without_dashes = r'[a-f0-9]{32}'
|
||||
|
||||
uuid_match = re.search(
|
||||
uuid_pattern_with_dashes, query_with_params
|
||||
)
|
||||
if not uuid_match:
|
||||
uuid_match = re.search(
|
||||
uuid_pattern_without_dashes, query_with_params
|
||||
)
|
||||
|
||||
if uuid_match:
|
||||
user_id_str = uuid_match.group(0)
|
||||
# If the UUID doesn't have dashes, add them
|
||||
if len(user_id_str) == 32 and '-' not in user_id_str:
|
||||
# Convert from 'a1111111111111111111111111111111' to 'a1111111-1111-1111-1111-111111111111'
|
||||
user_id_str = f'{user_id_str[:8]}-{user_id_str[8:12]}-{user_id_str[12:16]}-{user_id_str[16:20]}-{user_id_str[20:]}'
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
|
||||
# Create a mock user with user_id and org_id the same as user_id_uuid
|
||||
mock_user = MagicMock(spec=User)
|
||||
mock_user.id = user_id_uuid
|
||||
mock_user.current_org_id = user_id_uuid
|
||||
|
||||
# Create a mock result
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = mock_user
|
||||
return mock_result
|
||||
except Exception:
|
||||
# If there's any error in parsing, fall back to original execute
|
||||
pass
|
||||
|
||||
# For all other queries, use the original execute method
|
||||
return await original_execute(query)
|
||||
|
||||
# Apply the mock
|
||||
async_session.execute = mock_execute
|
||||
|
||||
# Create services for different users
|
||||
user1_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session,
|
||||
user_context=SpecifyUserContext(
|
||||
user_id='a1111111-1111-1111-1111-111111111111'
|
||||
),
|
||||
)
|
||||
user2_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session,
|
||||
user_context=SpecifyUserContext(
|
||||
user_id='b2222222-2222-2222-2222-222222222222'
|
||||
),
|
||||
)
|
||||
|
||||
# Create conversations for different users
|
||||
user1_info = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='a1111111-1111-1111-1111-111111111111',
|
||||
sandbox_id='sandbox_user1',
|
||||
title='User 1 Conversation',
|
||||
)
|
||||
|
||||
user2_info = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='b2222222-2222-2222-2222-222222222222',
|
||||
sandbox_id='sandbox_user2',
|
||||
title='User 2 Conversation',
|
||||
)
|
||||
|
||||
# Save conversations
|
||||
await user1_service.save_app_conversation_info(user1_info)
|
||||
await user2_service.save_app_conversation_info(user2_info)
|
||||
|
||||
# User 1 should only see their conversation
|
||||
user1_page = await user1_service.search_app_conversation_info()
|
||||
assert len(user1_page.items) == 1
|
||||
assert (
|
||||
user1_page.items[0].created_by_user_id
|
||||
== 'a1111111-1111-1111-1111-111111111111'
|
||||
)
|
||||
|
||||
# User 2 should only see their conversation
|
||||
user2_page = await user2_service.search_app_conversation_info()
|
||||
assert len(user2_page.items) == 1
|
||||
assert (
|
||||
user2_page.items[0].created_by_user_id
|
||||
== 'b2222222-2222-2222-2222-222222222222'
|
||||
)
|
||||
|
||||
# User 1 should not be able to get user 2's conversation
|
||||
user2_from_user1 = await user1_service.get_app_conversation_info(user2_info.id)
|
||||
assert user2_from_user1 is None
|
||||
|
||||
# User 2 should not be able to get user 1's conversation
|
||||
user1_from_user2 = await user2_service.get_app_conversation_info(user1_info.id)
|
||||
assert user1_from_user2 is None
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
@@ -19,6 +19,14 @@ def mock_session_maker(mock_session):
|
||||
return session_maker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
"""Mock user with org_id."""
|
||||
user = MagicMock()
|
||||
user.current_org_id = 'test-org-123'
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key_store(mock_session_maker):
|
||||
return ApiKeyStore(mock_session_maker)
|
||||
@@ -33,11 +41,13 @@ def test_generate_api_key(api_key_store):
|
||||
assert len(key) == len('sk-oh-') + 32
|
||||
|
||||
|
||||
def test_create_api_key(api_key_store, mock_session):
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_create_api_key(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
"""Test creating an API key."""
|
||||
# Setup
|
||||
user_id = 'test-user-123'
|
||||
name = 'Test Key'
|
||||
mock_get_user.return_value = mock_user
|
||||
api_key_store.generate_api_key = MagicMock(return_value='test-api-key')
|
||||
|
||||
# Execute
|
||||
@@ -45,10 +55,15 @@ def test_create_api_key(api_key_store, mock_session):
|
||||
|
||||
# Verify
|
||||
assert result == 'test-api-key'
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
api_key_store.generate_api_key.assert_called_once()
|
||||
|
||||
# Verify the ApiKey was created with the correct org_id
|
||||
added_api_key = mock_session.add.call_args[0][0]
|
||||
assert added_api_key.org_id == mock_user.current_org_id
|
||||
|
||||
|
||||
def test_validate_api_key_valid(api_key_store, mock_session):
|
||||
"""Test validating a valid API key."""
|
||||
@@ -204,10 +219,12 @@ def test_delete_api_key_by_id(api_key_store, mock_session):
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_list_api_keys(api_key_store, mock_session):
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_list_api_keys(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
"""Test listing API keys for a user."""
|
||||
# Setup
|
||||
user_id = 'test-user-123'
|
||||
mock_get_user.return_value = mock_user
|
||||
now = datetime.now(UTC)
|
||||
mock_key1 = MagicMock()
|
||||
mock_key1.id = 1
|
||||
@@ -223,15 +240,17 @@ def test_list_api_keys(api_key_store, mock_session):
|
||||
mock_key2.last_used_at = None
|
||||
mock_key2.expires_at = None
|
||||
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = [
|
||||
mock_key1,
|
||||
mock_key2,
|
||||
]
|
||||
# Mock the chained query calls for filtering by user_id and org_id
|
||||
mock_query = mock_session.query.return_value
|
||||
mock_filter_user = mock_query.filter.return_value
|
||||
mock_filter_org = mock_filter_user.filter.return_value
|
||||
mock_filter_org.all.return_value = [mock_key1, mock_key2]
|
||||
|
||||
# Execute
|
||||
result = api_key_store.list_api_keys(user_id)
|
||||
|
||||
# Verify
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
assert len(result) == 2
|
||||
assert result[0]['id'] == 1
|
||||
assert result[0]['name'] == 'Key 1'
|
||||
@@ -244,3 +263,59 @@ def test_list_api_keys(api_key_store, mock_session):
|
||||
assert result[1]['created_at'] == now
|
||||
assert result[1]['last_used_at'] is None
|
||||
assert result[1]['expires_at'] is None
|
||||
|
||||
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_retrieve_mcp_api_key(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
"""Test retrieving MCP API key for a user."""
|
||||
# Setup
|
||||
user_id = 'test-user-123'
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
mock_mcp_key = MagicMock()
|
||||
mock_mcp_key.name = 'MCP_API_KEY'
|
||||
mock_mcp_key.key = 'mcp-test-key'
|
||||
|
||||
mock_other_key = MagicMock()
|
||||
mock_other_key.name = 'Other Key'
|
||||
mock_other_key.key = 'other-test-key'
|
||||
|
||||
# Mock the chained query calls for filtering by user_id and org_id
|
||||
mock_query = mock_session.query.return_value
|
||||
mock_filter_user = mock_query.filter.return_value
|
||||
mock_filter_org = mock_filter_user.filter.return_value
|
||||
mock_filter_org.all.return_value = [mock_other_key, mock_mcp_key]
|
||||
|
||||
# Execute
|
||||
result = api_key_store.retrieve_mcp_api_key(user_id)
|
||||
|
||||
# Verify
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
assert result == 'mcp-test-key'
|
||||
|
||||
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_retrieve_mcp_api_key_not_found(
|
||||
mock_get_user, api_key_store, mock_session, mock_user
|
||||
):
|
||||
"""Test retrieving MCP API key when none exists."""
|
||||
# Setup
|
||||
user_id = 'test-user-123'
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
mock_other_key = MagicMock()
|
||||
mock_other_key.name = 'Other Key'
|
||||
mock_other_key.key = 'other-test-key'
|
||||
|
||||
# Mock the chained query calls for filtering by user_id and org_id
|
||||
mock_query = mock_session.query.return_value
|
||||
mock_filter_user = mock_query.filter.return_value
|
||||
mock_filter_org = mock_filter_user.filter.return_value
|
||||
mock_filter_org.all.return_value = [mock_other_key]
|
||||
|
||||
# Execute
|
||||
result = api_key_store.retrieve_mcp_api_key(user_id)
|
||||
|
||||
# Verify
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
assert result is None
|
||||
|
||||
@@ -130,6 +130,7 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -144,6 +145,15 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = None
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = False
|
||||
|
||||
@@ -165,20 +175,19 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.set_response_cookie') as mock_set_cookie,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
patch('server.routes.auth.posthog') as mock_posthog,
|
||||
):
|
||||
# Mock the session and query results
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
# Mock user with accepted_tos
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
|
||||
# Mock user settings with accepted_tos
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
# Setup UserStore mocks
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -228,6 +237,7 @@ async def test_keycloak_callback_email_not_verified(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.email.verify_email', mock_verify_email),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -243,6 +253,13 @@ async def test_keycloak_callback_email_not_verified(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_verifier.is_active.return_value = False
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
@@ -267,6 +284,7 @@ async def test_keycloak_callback_email_not_verified_missing_field(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.email.verify_email', mock_verify_email),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -282,6 +300,13 @@ async def test_keycloak_callback_email_not_verified_missing_field(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_verifier.is_active.return_value = False
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
@@ -309,20 +334,20 @@ async def test_keycloak_callback_success_without_offline_token(mock_request):
|
||||
),
|
||||
patch('server.routes.auth.KEYCLOAK_REALM_NAME', 'test-realm'),
|
||||
patch('server.routes.auth.KEYCLOAK_CLIENT_ID', 'test-client'),
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
patch('server.routes.auth.posthog') as mock_posthog,
|
||||
):
|
||||
# Mock the session and query results
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
# Mock user with accepted_tos
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
# Mock user settings with accepted_tos
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
@@ -535,6 +560,7 @@ async def test_keycloak_callback_blocked_email_domain(mock_request):
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -549,6 +575,14 @@ async def test_keycloak_callback_blocked_email_domain(mock_request):
|
||||
)
|
||||
mock_token_manager.disable_keycloak_user = AsyncMock()
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
@@ -576,6 +610,7 @@ async def test_keycloak_callback_allowed_email_domain(mock_request):
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -602,6 +637,15 @@ async def test_keycloak_callback_allowed_email_domain(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
@@ -629,6 +673,7 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request):
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -655,6 +700,15 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = False
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
@@ -680,6 +734,7 @@ async def test_keycloak_callback_missing_email(mock_request):
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -706,6 +761,16 @@ async def test_keycloak_callback_missing_email(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -725,6 +790,7 @@ async def test_keycloak_callback_duplicate_email_detected(mock_request):
|
||||
"""Test keycloak_callback when duplicate email is detected."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
# Arrange
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
@@ -741,6 +807,13 @@ async def test_keycloak_callback_duplicate_email_detected(mock_request):
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True)
|
||||
mock_token_manager.delete_keycloak_user = AsyncMock(return_value=True)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
@@ -761,6 +834,7 @@ async def test_keycloak_callback_duplicate_email_deletion_fails(mock_request):
|
||||
"""Test keycloak_callback when duplicate is detected but deletion fails."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
# Arrange
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
@@ -777,6 +851,13 @@ async def test_keycloak_callback_duplicate_email_deletion_fails(mock_request):
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True)
|
||||
mock_token_manager.delete_keycloak_user = AsyncMock(return_value=False)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
@@ -796,6 +877,7 @@ async def test_keycloak_callback_duplicate_check_exception(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
@@ -825,6 +907,14 @@ async def test_keycloak_callback_duplicate_check_exception(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -846,6 +936,7 @@ async def test_keycloak_callback_no_duplicate_email(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
@@ -873,6 +964,14 @@ async def test_keycloak_callback_no_duplicate_email(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -898,6 +997,7 @@ async def test_keycloak_callback_no_email_in_user_info(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
@@ -924,6 +1024,14 @@ async def test_keycloak_callback_no_email_in_user_info(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1043,6 +1151,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1071,6 +1180,14 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1112,6 +1229,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.recaptcha_service') as mock_recaptcha_service,
|
||||
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -1127,6 +1245,13 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
# Patch the module-level recaptcha_service instance
|
||||
@@ -1172,6 +1297,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1200,6 +1326,14 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1250,6 +1384,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1278,6 +1413,14 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1325,6 +1468,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1353,6 +1497,14 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1399,6 +1551,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1427,6 +1580,14 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1470,6 +1631,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1498,6 +1660,14 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1527,6 +1697,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1555,6 +1726,14 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1590,6 +1769,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.auth.logger') as mock_logger,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1618,6 +1798,14 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1665,6 +1853,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.logger') as mock_logger,
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -1680,6 +1869,13 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
# Patch the module-level recaptcha_service instance
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
@@ -5,22 +6,21 @@ import pytest
|
||||
import stripe
|
||||
from fastapi import HTTPException, Request, status
|
||||
from httpx import Response
|
||||
from integrations.stripe_service import has_payment_method
|
||||
from server.routes import billing
|
||||
from server.routes.billing import (
|
||||
CreateBillingSessionResponse,
|
||||
CreateCheckoutSessionRequest,
|
||||
GetCreditsResponse,
|
||||
cancel_callback,
|
||||
cancel_subscription,
|
||||
create_checkout_session,
|
||||
create_subscription_checkout_session,
|
||||
create_customer_setup_session,
|
||||
get_credits,
|
||||
has_payment_method,
|
||||
success_callback,
|
||||
)
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from starlette.datastructures import URL
|
||||
from storage.billing_session_type import BillingSessionType
|
||||
from storage.stripe_customer import Base as StripeCustomerBase
|
||||
|
||||
|
||||
@@ -78,28 +78,31 @@ def mock_subscription_request():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credits_lite_llm_error():
|
||||
mock_response = Response(
|
||||
status_code=500, json={'error': 'Internal Server Error'}, request=MagicMock()
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get.return_value = mock_response
|
||||
|
||||
with patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'):
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_credits('mock_user')
|
||||
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert (
|
||||
exc_info.value.detail
|
||||
== 'Failed to retrieve credit balance from billing service'
|
||||
)
|
||||
with (
|
||||
patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'),
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id',
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
side_effect=Exception('LiteLLM API Error'),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Exception, match='LiteLLM API Error'):
|
||||
await get_credits('mock_user')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credits_success():
|
||||
mock_response = Response(
|
||||
status_code=200,
|
||||
json={'user_info': {'max_budget': 100.00, 'spend': 25.50}},
|
||||
json={
|
||||
'user_info': {
|
||||
'spend': 25.50,
|
||||
'litellm_budget_table': {'max_budget': 100.00},
|
||||
}
|
||||
},
|
||||
request=MagicMock(),
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
@@ -108,24 +111,22 @@ async def test_get_credits_success():
|
||||
with (
|
||||
patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'),
|
||||
patch('httpx.AsyncClient', return_value=mock_client),
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id',
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
return_value={
|
||||
'spend': 25.50,
|
||||
'litellm_budget_table': {'max_budget': 100.00},
|
||||
},
|
||||
),
|
||||
):
|
||||
with patch('server.routes.billing.session_maker') as mock_session_maker:
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.first.return_value = MagicMock(
|
||||
billing_margin=4
|
||||
)
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
result = await get_credits('mock_user')
|
||||
|
||||
result = await get_credits('mock_user')
|
||||
|
||||
assert isinstance(result, GetCreditsResponse)
|
||||
assert result.credits == Decimal(
|
||||
'74.50'
|
||||
) # 100.00 - 25.50 = 74.50 (no billing margin applied)
|
||||
mock_client.__aenter__.return_value.get.assert_called_once_with(
|
||||
'https://llm-proxy.app.all-hands.dev/user/info?user_id=mock_user',
|
||||
headers={'x-goog-api-key': None},
|
||||
)
|
||||
assert isinstance(result, GetCreditsResponse)
|
||||
assert result.credits == Decimal('74.50') # 100.00 - 25.50 = 74.50
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -138,6 +139,9 @@ async def test_create_checkout_session_stripe_error(
|
||||
id='mock-customer', metadata={'user_id': 'mock-user'}
|
||||
)
|
||||
mock_customer_create = AsyncMock(return_value=mock_customer)
|
||||
mock_org = MagicMock()
|
||||
mock_org.id = uuid.uuid4()
|
||||
mock_org.contact_email = 'testy@tester.com'
|
||||
with (
|
||||
pytest.raises(Exception, match='Stripe API Error'),
|
||||
patch('stripe.Customer.create_async', mock_customer_create),
|
||||
@@ -149,6 +153,10 @@ async def test_create_checkout_session_stripe_error(
|
||||
AsyncMock(side_effect=Exception('Stripe API Error')),
|
||||
),
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
@@ -174,6 +182,10 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
|
||||
id='mock-customer', metadata={'user_id': 'mock-user'}
|
||||
)
|
||||
mock_customer_create = AsyncMock(return_value=mock_customer)
|
||||
mock_org = MagicMock()
|
||||
mock_org_id = uuid.uuid4()
|
||||
mock_org.id = mock_org_id
|
||||
mock_org.contact_email = 'testy@tester.com'
|
||||
with (
|
||||
patch('stripe.Customer.create_async', mock_customer_create),
|
||||
patch(
|
||||
@@ -182,6 +194,10 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
|
||||
patch('stripe.checkout.Session.create_async', mock_create),
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
@@ -253,7 +269,6 @@ async def test_success_callback_stripe_incomplete():
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
mock_billing_session.billing_session_type = BillingSessionType.DIRECT_PAYMENT.value
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
@@ -281,44 +296,33 @@ async def test_success_callback_success():
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
mock_billing_session.billing_session_type = BillingSessionType.DIRECT_PAYMENT.value
|
||||
|
||||
mock_lite_llm_response = Response(
|
||||
status_code=200,
|
||||
json={'user_info': {'max_budget': 100.00, 'spend': 25.50}},
|
||||
request=MagicMock(),
|
||||
)
|
||||
mock_lite_llm_update_response = Response(
|
||||
status_code=200, json={}, request=MagicMock()
|
||||
)
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch('httpx.AsyncClient') as mock_client,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id',
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
return_value={
|
||||
'spend': 25.50,
|
||||
'litellm_budget_table': {'max_budget': 100.00},
|
||||
},
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget'
|
||||
) as mock_update_budget,
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_user_settings = MagicMock(billing_margin=None)
|
||||
mock_db_session.query.return_value.filter.return_value.first.return_value = (
|
||||
mock_user_settings
|
||||
)
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(
|
||||
status='complete',
|
||||
amount_subtotal=2500,
|
||||
status='complete', amount_subtotal=2500, customer='mock_customer_id'
|
||||
) # $25.00 in cents
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.__aenter__.return_value.get.return_value = (
|
||||
mock_lite_llm_response
|
||||
)
|
||||
mock_client_instance.__aenter__.return_value.post.return_value = (
|
||||
mock_lite_llm_update_response
|
||||
)
|
||||
mock_client.return_value = mock_client_instance
|
||||
|
||||
response = await success_callback('test_session_id', mock_request)
|
||||
|
||||
assert response.status_code == 302
|
||||
@@ -328,18 +332,14 @@ async def test_success_callback_success():
|
||||
)
|
||||
|
||||
# Verify LiteLLM API calls
|
||||
mock_client_instance.__aenter__.return_value.get.assert_called_once()
|
||||
mock_client_instance.__aenter__.return_value.post.assert_called_once_with(
|
||||
'https://llm-proxy.app.all-hands.dev/user/update',
|
||||
headers={'x-goog-api-key': None},
|
||||
json={
|
||||
'user_id': 'mock_user',
|
||||
'max_budget': 125,
|
||||
}, # 100 + (25.00 from Stripe)
|
||||
mock_update_budget.assert_called_once_with(
|
||||
'mock_org_id',
|
||||
125.0, # 100 + (25.00 from Stripe)
|
||||
)
|
||||
|
||||
# Verify database updates
|
||||
assert mock_billing_session.status == 'completed'
|
||||
assert mock_billing_session.price == 25.0
|
||||
mock_db_session.merge.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@@ -353,27 +353,27 @@ async def test_success_callback_lite_llm_error():
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
mock_billing_session.billing_session_type = BillingSessionType.DIRECT_PAYMENT.value
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch('httpx.AsyncClient') as mock_client,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id',
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
side_effect=Exception('LiteLLM API Error'),
|
||||
),
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(
|
||||
status='complete', amount_total=2500
|
||||
status='complete', amount_subtotal=2500
|
||||
)
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.__aenter__.return_value.get.side_effect = Exception(
|
||||
'LiteLLM API Error'
|
||||
)
|
||||
mock_client.return_value = mock_client_instance
|
||||
|
||||
with pytest.raises(Exception, match='LiteLLM API Error'):
|
||||
await success_callback('test_session_id', mock_request)
|
||||
|
||||
@@ -397,7 +397,8 @@ async def test_cancel_callback_session_not_found():
|
||||
response = await cancel_callback('test_session_id', mock_request)
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
response.headers['location'] == 'http://test.com/settings?checkout=cancel'
|
||||
response.headers['location']
|
||||
== 'http://test.com/settings/billing?checkout=cancel'
|
||||
)
|
||||
|
||||
# Verify no database updates occurred
|
||||
@@ -423,7 +424,8 @@ async def test_cancel_callback_success():
|
||||
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
response.headers['location'] == 'http://test.com/settings?checkout=cancel'
|
||||
response.headers['location']
|
||||
== 'http://test.com/settings/billing?checkout=cancel'
|
||||
)
|
||||
|
||||
# Verify database updates
|
||||
@@ -435,314 +437,67 @@ async def test_cancel_callback_success():
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_payment_method_with_payment_method():
|
||||
"""Test has_payment_method returns True when user has a payment method."""
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'stripe.Customer.list_payment_methods_async',
|
||||
AsyncMock(return_value=MagicMock(data=[MagicMock()])),
|
||||
) as mock_list_payment_methods,
|
||||
):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
MagicMock(stripe_customer_id='cus_test123')
|
||||
)
|
||||
|
||||
mock_has_payment_method = AsyncMock(return_value=True)
|
||||
with patch(
|
||||
'server.routes.billing.stripe_service.has_payment_method_by_user_id',
|
||||
mock_has_payment_method,
|
||||
):
|
||||
result = await has_payment_method('mock_user')
|
||||
assert result is True
|
||||
mock_list_payment_methods.assert_called_once_with('cus_test123')
|
||||
mock_has_payment_method.assert_called_once_with('mock_user')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_payment_method_without_payment_method():
|
||||
"""Test has_payment_method returns False when user has no payment method."""
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'stripe.Customer.list_payment_methods_async',
|
||||
AsyncMock(return_value=MagicMock(data=[])),
|
||||
) as mock_list_payment_methods,
|
||||
mock_has_payment_method = AsyncMock(return_value=False)
|
||||
with patch(
|
||||
'server.routes.billing.stripe_service.has_payment_method_by_user_id',
|
||||
mock_has_payment_method,
|
||||
):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
MagicMock(stripe_customer_id='cus_test123')
|
||||
)
|
||||
|
||||
mock_has_payment_method.return_value = False
|
||||
result = await has_payment_method('mock_user')
|
||||
assert result is False
|
||||
mock_list_payment_methods.assert_called_once_with('cus_test123')
|
||||
mock_has_payment_method.assert_called_once_with('mock_user')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_subscription_success():
|
||||
"""Test successful subscription cancellation."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
|
||||
# Mock active subscription
|
||||
mock_subscription_access = SubscriptionAccess(
|
||||
id=1,
|
||||
status='ACTIVE',
|
||||
user_id='test_user',
|
||||
start_at=datetime.now(UTC),
|
||||
end_at=datetime.now(UTC),
|
||||
amount_paid=2000,
|
||||
stripe_invoice_payment_id='pi_test',
|
||||
stripe_subscription_id='sub_test123',
|
||||
cancelled_at=None,
|
||||
async def test_create_customer_setup_session_success():
|
||||
"""Test successful creation of customer setup session."""
|
||||
mock_request = Request(
|
||||
scope={
|
||||
'type': 'http',
|
||||
'path': '/api/billing/create-customer-setup-session',
|
||||
'server': ('test.com', 80),
|
||||
'headers': [],
|
||||
}
|
||||
)
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
# Mock Stripe subscription response
|
||||
mock_stripe_subscription = MagicMock()
|
||||
mock_stripe_subscription.cancel_at_period_end = True
|
||||
mock_customer_info = {'customer_id': 'mock-customer-id', 'org_id': 'mock-org-id'}
|
||||
mock_session = MagicMock()
|
||||
mock_session.url = 'https://checkout.stripe.com/test-session'
|
||||
mock_create = AsyncMock(return_value=mock_session)
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'stripe.Subscription.modify_async',
|
||||
AsyncMock(return_value=mock_stripe_subscription),
|
||||
) as mock_stripe_modify,
|
||||
):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
|
||||
|
||||
# Call the function
|
||||
result = await cancel_subscription('test_user')
|
||||
|
||||
# Verify Stripe API was called
|
||||
mock_stripe_modify.assert_called_once_with(
|
||||
'sub_test123', cancel_at_period_end=True
|
||||
)
|
||||
|
||||
# Verify database was updated
|
||||
assert mock_subscription_access.cancelled_at is not None
|
||||
mock_session.merge.assert_called_once_with(mock_subscription_access)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
# Verify response
|
||||
assert result.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_subscription_no_active_subscription():
|
||||
"""Test cancellation when no active subscription exists."""
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Setup mock session with no subscription found
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
# Call the function and expect HTTPException
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await cancel_subscription('test_user')
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
assert 'No active subscription found' in str(exc_info.value.detail)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_subscription_missing_stripe_id():
|
||||
"""Test cancellation when subscription has no Stripe ID."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
|
||||
# Mock subscription without Stripe ID
|
||||
mock_subscription_access = SubscriptionAccess(
|
||||
id=1,
|
||||
status='ACTIVE',
|
||||
user_id='test_user',
|
||||
start_at=datetime.now(UTC),
|
||||
end_at=datetime.now(UTC),
|
||||
amount_paid=2000,
|
||||
stripe_invoice_payment_id='pi_test',
|
||||
stripe_subscription_id=None, # Missing Stripe ID
|
||||
cancelled_at=None,
|
||||
)
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
|
||||
|
||||
# Call the function and expect HTTPException
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await cancel_subscription('test_user')
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert 'missing Stripe subscription ID' in str(exc_info.value.detail)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_subscription_stripe_error():
|
||||
"""Test cancellation when Stripe API fails."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
|
||||
# Mock active subscription
|
||||
mock_subscription_access = SubscriptionAccess(
|
||||
id=1,
|
||||
status='ACTIVE',
|
||||
user_id='test_user',
|
||||
start_at=datetime.now(UTC),
|
||||
end_at=datetime.now(UTC),
|
||||
amount_paid=2000,
|
||||
stripe_invoice_payment_id='pi_test',
|
||||
stripe_subscription_id='sub_test123',
|
||||
cancelled_at=None,
|
||||
)
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'stripe.Subscription.modify_async',
|
||||
AsyncMock(side_effect=stripe.StripeError('API Error')),
|
||||
'integrations.stripe_service.find_or_create_customer_by_user_id',
|
||||
AsyncMock(return_value=mock_customer_info),
|
||||
),
|
||||
):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
|
||||
|
||||
# Call the function and expect HTTPException
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await cancel_subscription('test_user')
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to cancel subscription' in str(exc_info.value.detail)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_session_duplicate_prevention(
|
||||
mock_subscription_request,
|
||||
):
|
||||
"""Test that creating a subscription when user already has active subscription raises error."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
|
||||
# Mock active subscription
|
||||
mock_subscription_access = SubscriptionAccess(
|
||||
id=1,
|
||||
status='ACTIVE',
|
||||
user_id='test_user',
|
||||
start_at=datetime.now(UTC),
|
||||
end_at=datetime.now(UTC),
|
||||
amount_paid=2000,
|
||||
stripe_invoice_payment_id='pi_test',
|
||||
stripe_subscription_id='sub_test123',
|
||||
cancelled_at=None,
|
||||
)
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.create_async', mock_create),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
# Setup mock session to return existing active subscription
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
|
||||
result = await create_customer_setup_session(mock_request, 'mock_user')
|
||||
|
||||
# Call the function and expect HTTPException
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await create_subscription_checkout_session(
|
||||
mock_subscription_request, user_id='test_user'
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert (
|
||||
'user already has an active subscription'
|
||||
in str(exc_info.value.detail).lower()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_session_allows_after_cancellation(
|
||||
mock_subscription_request,
|
||||
):
|
||||
"""Test that creating a subscription is allowed when previous subscription was cancelled."""
|
||||
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.url = 'https://checkout.stripe.com/test-session'
|
||||
mock_session_obj.id = 'test_session_id'
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'integrations.stripe_service.find_or_create_customer',
|
||||
AsyncMock(return_value='cus_test123'),
|
||||
),
|
||||
patch(
|
||||
'stripe.checkout.Session.create_async',
|
||||
AsyncMock(return_value=mock_session_obj),
|
||||
),
|
||||
patch(
|
||||
'server.routes.billing.SUBSCRIPTION_PRICE_DATA',
|
||||
{'MONTHLY_SUBSCRIPTION': {'unit_amount': 2000}},
|
||||
),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
# Setup mock session - the query should return None because cancelled subscriptions are filtered out
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
# Should succeed
|
||||
result = await create_subscription_checkout_session(
|
||||
mock_subscription_request, user_id='test_user'
|
||||
)
|
||||
|
||||
assert isinstance(result, CreateBillingSessionResponse)
|
||||
assert isinstance(result, billing.CreateBillingSessionResponse)
|
||||
assert result.redirect_url == 'https://checkout.stripe.com/test-session'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_session_success_no_existing(
|
||||
mock_subscription_request,
|
||||
):
|
||||
"""Test successful subscription creation when no existing subscription."""
|
||||
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.url = 'https://checkout.stripe.com/test-session'
|
||||
mock_session_obj.id = 'test_session_id'
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'integrations.stripe_service.find_or_create_customer',
|
||||
AsyncMock(return_value='cus_test123'),
|
||||
),
|
||||
patch(
|
||||
'stripe.checkout.Session.create_async',
|
||||
AsyncMock(return_value=mock_session_obj),
|
||||
),
|
||||
patch(
|
||||
'server.routes.billing.SUBSCRIPTION_PRICE_DATA',
|
||||
{'MONTHLY_SUBSCRIPTION': {'unit_amount': 2000}},
|
||||
),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
# Setup mock session to return no existing subscription
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
# Should succeed
|
||||
result = await create_subscription_checkout_session(
|
||||
mock_subscription_request, user_id='test_user'
|
||||
# Verify Stripe session creation parameters
|
||||
mock_create.assert_called_once_with(
|
||||
customer='mock-customer-id',
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url='http://test.com/?free_credits=success',
|
||||
cancel_url='http://test.com/',
|
||||
)
|
||||
|
||||
assert isinstance(result, CreateBillingSessionResponse)
|
||||
assert result.redirect_url == 'https://checkout.stripe.com/test-session'
|
||||
|
||||
@@ -3,6 +3,7 @@ Tests for ConversationCallbackProcessor and ConversationCallback models.
|
||||
"""
|
||||
|
||||
import json
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from storage.conversation_callback import (
|
||||
@@ -11,6 +12,9 @@ from storage.conversation_callback import (
|
||||
ConversationCallbackProcessor,
|
||||
)
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import (
|
||||
StoredConversationMetadataSaas,
|
||||
)
|
||||
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
|
||||
@@ -80,15 +84,22 @@ class TestConversationCallback:
|
||||
"""Create a test conversation metadata record."""
|
||||
with session_maker() as session:
|
||||
metadata = StoredConversationMetadata(
|
||||
conversation_id='test_conversation_123', user_id='test_user_456'
|
||||
conversation_id='test_conversation_123'
|
||||
)
|
||||
metadata_saas = StoredConversationMetadataSaas(
|
||||
conversation_id='test_conversation_123',
|
||||
user_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
org_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
)
|
||||
session.add(metadata)
|
||||
session.add(metadata_saas)
|
||||
session.commit()
|
||||
session.refresh(metadata)
|
||||
yield metadata
|
||||
|
||||
# Cleanup
|
||||
session.delete(metadata)
|
||||
session.delete(metadata_saas)
|
||||
session.commit()
|
||||
|
||||
def test_callback_creation(self, conversation_metadata, session_maker):
|
||||
|
||||
@@ -1,114 +1,100 @@
|
||||
"""Unit tests for get_user_v1_enabled_setting function."""
|
||||
"""Unit tests for get_user_v1_enabled_setting and is_v1_enabled_for_github_resolver functions."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from integrations.utils import get_user_v1_enabled_setting
|
||||
from integrations.github.github_view import (
|
||||
get_user_v1_enabled_setting,
|
||||
is_v1_enabled_for_github_resolver,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_settings():
|
||||
"""Create a mock user settings object."""
|
||||
settings = MagicMock()
|
||||
settings.v1_enabled = True # Default to True, can be overridden in tests
|
||||
return settings
|
||||
def mock_org():
|
||||
"""Create a mock org object."""
|
||||
org = MagicMock()
|
||||
org.v1_enabled = True # Default to True, can be overridden in tests
|
||||
return org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings_store():
|
||||
"""Create a mock settings store."""
|
||||
store = MagicMock()
|
||||
return store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config():
|
||||
"""Create a mock config object."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker():
|
||||
"""Create a mock session maker."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(
|
||||
mock_settings_store, mock_config, mock_session_maker, mock_user_settings
|
||||
):
|
||||
def mock_dependencies(mock_org):
|
||||
"""Fixture that patches all the common dependencies."""
|
||||
# Patch at the source module since SaasSettingsStore is imported inside the function
|
||||
with patch(
|
||||
'storage.saas_settings_store.SaasSettingsStore',
|
||||
return_value=mock_settings_store,
|
||||
) as mock_store_class, patch(
|
||||
'integrations.utils.get_config', return_value=mock_config
|
||||
) as mock_get_config, patch(
|
||||
'integrations.utils.session_maker', mock_session_maker
|
||||
), patch(
|
||||
'integrations.utils.call_sync_from_async',
|
||||
return_value=mock_user_settings,
|
||||
) as mock_call_sync:
|
||||
return_value=mock_org,
|
||||
) as mock_call_sync, patch('integrations.utils.OrgStore') as mock_org_store:
|
||||
yield {
|
||||
'store_class': mock_store_class,
|
||||
'get_config': mock_get_config,
|
||||
'session_maker': mock_session_maker,
|
||||
'call_sync': mock_call_sync,
|
||||
'settings_store': mock_settings_store,
|
||||
'user_settings': mock_user_settings,
|
||||
'org_store': mock_org_store,
|
||||
'org': mock_org,
|
||||
}
|
||||
|
||||
|
||||
class TestGetUserV1EnabledSetting:
|
||||
"""Test cases for get_user_v1_enabled_setting function."""
|
||||
class TestIsV1EnabledForGithubResolver:
|
||||
"""Test cases for is_v1_enabled_for_github_resolver function.
|
||||
|
||||
This function returns True only if BOTH the environment variable
|
||||
ENABLE_V1_GITHUB_RESOLVER is true AND the user's org has v1_enabled=True.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'user_setting_enabled,expected_result',
|
||||
'env_var_enabled,user_setting_enabled,expected_result',
|
||||
[
|
||||
(True, True), # User enabled -> True
|
||||
(False, False), # User disabled -> False
|
||||
(False, True, False), # Env var disabled, user enabled -> False
|
||||
(True, False, False), # Env var enabled, user disabled -> False
|
||||
(True, True, True), # Both enabled -> True
|
||||
(False, False, False), # Both disabled -> False
|
||||
],
|
||||
)
|
||||
async def test_v1_enabled_user_setting(
|
||||
self, mock_dependencies, user_setting_enabled, expected_result
|
||||
async def test_v1_enabled_combinations(
|
||||
self, mock_dependencies, env_var_enabled, user_setting_enabled, expected_result
|
||||
):
|
||||
"""Test that the function returns the user's v1_enabled setting."""
|
||||
mock_dependencies['user_settings'].v1_enabled = user_setting_enabled
|
||||
"""Test all combinations of environment variable and user setting values."""
|
||||
mock_dependencies['org'].v1_enabled = user_setting_enabled
|
||||
|
||||
result = await get_user_v1_enabled_setting('test_user_id')
|
||||
assert result is expected_result
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_V1_GITHUB_RESOLVER', env_var_enabled
|
||||
):
|
||||
result = await is_v1_enabled_for_github_resolver('test_user_id')
|
||||
assert result is expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_false_when_no_user_id(self):
|
||||
"""Test that the function returns False when no user_id is provided."""
|
||||
result = await get_user_v1_enabled_setting(None)
|
||||
assert result is False
|
||||
@pytest.mark.parametrize(
|
||||
'env_var_value,env_var_bool,expected_result',
|
||||
[
|
||||
('false', False, False), # Environment variable 'false' -> False
|
||||
('true', True, True), # Environment variable 'true' -> True
|
||||
],
|
||||
)
|
||||
async def test_environment_variable_integration(
|
||||
self, mock_dependencies, env_var_value, env_var_bool, expected_result
|
||||
):
|
||||
"""Test that the function properly reads the ENABLE_V1_GITHUB_RESOLVER environment variable."""
|
||||
mock_dependencies['org'].v1_enabled = True
|
||||
|
||||
result = await get_user_v1_enabled_setting('')
|
||||
assert result is False
|
||||
with patch.dict(
|
||||
os.environ, {'ENABLE_V1_GITHUB_RESOLVER': env_var_value}
|
||||
), patch('integrations.utils.os.getenv', return_value=env_var_value), patch(
|
||||
'integrations.github.github_view.ENABLE_V1_GITHUB_RESOLVER', env_var_bool
|
||||
):
|
||||
result = await is_v1_enabled_for_github_resolver('test_user_id')
|
||||
assert result is expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_false_when_settings_is_none(self, mock_dependencies):
|
||||
"""Test that the function returns False when settings is None."""
|
||||
mock_dependencies['call_sync'].return_value = None
|
||||
|
||||
result = await get_user_v1_enabled_setting('test_user_id')
|
||||
assert result is False
|
||||
class TestGetUserV1EnabledSetting:
|
||||
"""Test cases for get_user_v1_enabled_setting function.
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_false_when_v1_enabled_is_none(self, mock_dependencies):
|
||||
"""Test that the function returns False when v1_enabled is None."""
|
||||
mock_dependencies['user_settings'].v1_enabled = None
|
||||
|
||||
result = await get_user_v1_enabled_setting('test_user_id')
|
||||
assert result is False
|
||||
This function only returns the user's org v1_enabled setting.
|
||||
It does NOT check the ENABLE_V1_GITHUB_RESOLVER environment variable.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_calls_correct_methods(self, mock_dependencies):
|
||||
"""Test that the function calls the correct methods with correct parameters."""
|
||||
mock_dependencies['user_settings'].v1_enabled = True
|
||||
mock_dependencies['org'].v1_enabled = True
|
||||
|
||||
result = await get_user_v1_enabled_setting('test_user_123')
|
||||
|
||||
@@ -116,13 +102,38 @@ class TestGetUserV1EnabledSetting:
|
||||
assert result is True
|
||||
|
||||
# Verify correct methods were called with correct parameters
|
||||
mock_dependencies['get_config'].assert_called_once()
|
||||
mock_dependencies['store_class'].assert_called_once_with(
|
||||
user_id='test_user_123',
|
||||
session_maker=mock_dependencies['session_maker'],
|
||||
config=mock_dependencies['get_config'].return_value,
|
||||
)
|
||||
mock_dependencies['call_sync'].assert_called_once_with(
|
||||
mock_dependencies['settings_store'].get_user_settings_by_keycloak_id,
|
||||
mock_dependencies['org_store'].get_current_org_from_keycloak_user_id,
|
||||
'test_user_123',
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_user_setting_true(self, mock_dependencies):
|
||||
"""Test that the function returns True when org.v1_enabled is True."""
|
||||
mock_dependencies['org'].v1_enabled = True
|
||||
result = await get_user_v1_enabled_setting('test_user_123')
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_user_setting_false(self, mock_dependencies):
|
||||
"""Test that the function returns False when org.v1_enabled is False."""
|
||||
mock_dependencies['org'].v1_enabled = False
|
||||
result = await get_user_v1_enabled_setting('test_user_123')
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_org_returns_false(self, mock_dependencies):
|
||||
"""Test that the function returns False when no org is found."""
|
||||
# Mock call_sync_from_async to return None (no org found)
|
||||
mock_dependencies['call_sync'].return_value = None
|
||||
|
||||
result = await get_user_v1_enabled_setting('test_user_123')
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_org_v1_enabled_none_returns_false(self, mock_dependencies):
|
||||
"""Test that the function returns False when org.v1_enabled is None."""
|
||||
mock_dependencies['org'].v1_enabled = None
|
||||
|
||||
result = await get_user_v1_enabled_setting('test_user_123')
|
||||
assert result is False
|
||||
|
||||
653
enterprise/tests/unit/test_lite_llm_manager.py
Normal file
653
enterprise/tests/unit/test_lite_llm_manager.py
Normal file
@@ -0,0 +1,653 @@
|
||||
"""
|
||||
Unit tests for LiteLlmManager class.
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
from server.constants import (
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.server.settings import Settings
|
||||
|
||||
|
||||
class TestLiteLlmManager:
|
||||
"""Test cases for LiteLlmManager class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings(self):
|
||||
"""Create a mock Settings object."""
|
||||
settings = Settings()
|
||||
settings.agent = 'TestAgent'
|
||||
settings.llm_model = 'test-model'
|
||||
settings.llm_api_key = SecretStr('test-key')
|
||||
settings.llm_base_url = 'http://test.com'
|
||||
return settings
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_settings(self):
|
||||
"""Create a mock UserSettings object."""
|
||||
user_settings = UserSettings()
|
||||
user_settings.agent = 'TestAgent'
|
||||
user_settings.llm_model = 'test-model'
|
||||
user_settings.llm_api_key = SecretStr('test-key')
|
||||
user_settings.llm_base_url = 'http://test.com'
|
||||
user_settings.user_version = 4 # Set version to avoid None comparison
|
||||
return user_settings
|
||||
|
||||
@pytest.fixture
|
||||
def mock_http_client(self):
|
||||
"""Create a mock HTTP client."""
|
||||
client = AsyncMock(spec=httpx.AsyncClient)
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response(self):
|
||||
"""Create a mock HTTP response."""
|
||||
response = MagicMock()
|
||||
response.is_success = True
|
||||
response.status_code = 200
|
||||
response.text = 'Success'
|
||||
response.json.return_value = {'key': 'test-api-key'}
|
||||
response.raise_for_status = MagicMock()
|
||||
return response
|
||||
|
||||
@pytest.fixture
|
||||
def mock_team_response(self):
|
||||
"""Create a mock team response."""
|
||||
response = MagicMock()
|
||||
response.is_success = True
|
||||
response.status_code = 200
|
||||
response.json.return_value = {
|
||||
'team_memberships': [
|
||||
{
|
||||
'user_id': 'test-user-id',
|
||||
'team_id': 'test-org-id',
|
||||
'max_budget': 100.0,
|
||||
}
|
||||
]
|
||||
}
|
||||
response.raise_for_status = MagicMock()
|
||||
return response
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_response(self):
|
||||
"""Create a mock user response."""
|
||||
response = MagicMock()
|
||||
response.is_success = True
|
||||
response.status_code = 200
|
||||
response.json.return_value = {
|
||||
'user_info': {
|
||||
'max_budget': 50.0,
|
||||
'spend': 10.0,
|
||||
}
|
||||
}
|
||||
response.raise_for_status = MagicMock()
|
||||
return response
|
||||
|
||||
@pytest.fixture
|
||||
def mock_key_info_response(self):
|
||||
"""Create a mock key info response."""
|
||||
response = MagicMock()
|
||||
response.is_success = True
|
||||
response.status_code = 200
|
||||
response.json.return_value = {
|
||||
'info': {
|
||||
'max_budget': 100.0,
|
||||
'spend': 25.0,
|
||||
}
|
||||
}
|
||||
response.raise_for_status = MagicMock()
|
||||
return response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_entries_missing_config(self, mock_settings):
|
||||
"""Test create_entries when LiteLLM config is missing."""
|
||||
with patch.dict(os.environ, {'LITE_LLM_API_KEY': '', 'LITE_LLM_API_URL': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
|
||||
result = await LiteLlmManager.create_entries(
|
||||
'test-org-id', 'test-user-id', mock_settings
|
||||
)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_entries_local_deployment(self, mock_settings):
|
||||
"""Test create_entries in local deployment mode."""
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': '1'}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
result = await LiteLlmManager.create_entries(
|
||||
'test-org-id', 'test-user-id', mock_settings
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.agent == 'CodeActAgent'
|
||||
assert result.llm_model == get_default_litellm_model()
|
||||
assert result.llm_api_key.get_secret_value() == 'test-key'
|
||||
assert result.llm_base_url == 'http://test.com'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_entries_cloud_deployment(self, mock_settings, mock_response):
|
||||
"""Test create_entries in cloud deployment mode."""
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.TokenManager'
|
||||
) as mock_token_manager:
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = (
|
||||
AsyncMock(return_value={'email': 'test@example.com'})
|
||||
)
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = (
|
||||
mock_client
|
||||
)
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
result = await LiteLlmManager.create_entries(
|
||||
'test-org-id', 'test-user-id', mock_settings
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.agent == 'CodeActAgent'
|
||||
assert result.llm_model == get_default_litellm_model()
|
||||
assert (
|
||||
result.llm_api_key.get_secret_value() == 'test-api-key'
|
||||
)
|
||||
assert result.llm_base_url == 'http://test.com'
|
||||
|
||||
# Verify API calls were made
|
||||
assert (
|
||||
mock_client.post.call_count == 4
|
||||
) # create_team, create_user, add_user_to_team, generate_key
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_entries_missing_config(self, mock_user_settings):
|
||||
"""Test migrate_entries when LiteLLM config is missing."""
|
||||
with patch.dict(os.environ, {'LITE_LLM_API_KEY': '', 'LITE_LLM_API_URL': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
|
||||
result = await LiteLlmManager.migrate_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_entries_local_deployment(self, mock_user_settings):
|
||||
"""Test migrate_entries in local deployment mode."""
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': '1'}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
result = await LiteLlmManager.migrate_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
|
||||
# migrate_entries returns the user_settings unchanged
|
||||
assert result is not None
|
||||
assert result.agent == 'TestAgent'
|
||||
assert result.llm_model == 'test-model'
|
||||
assert result.llm_api_key.get_secret_value() == 'test-key'
|
||||
assert result.llm_base_url == 'http://test.com'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_entries_no_user_found(self, mock_user_settings):
|
||||
"""Test migrate_entries when user is not found."""
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.TokenManager'
|
||||
) as mock_token_manager:
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = (
|
||||
AsyncMock(return_value={'email': 'test@example.com'})
|
||||
)
|
||||
|
||||
# Mock the _get_user method directly to return None
|
||||
with patch.object(
|
||||
LiteLlmManager, '_get_user', new_callable=AsyncMock
|
||||
) as mock_get_user:
|
||||
mock_get_user.return_value = None
|
||||
|
||||
result = await LiteLlmManager.migrate_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_entries_already_migrated(
|
||||
self, mock_user_settings, mock_user_response
|
||||
):
|
||||
"""Test migrate_entries when user is already migrated (no max_budget)."""
|
||||
mock_user_response.json.return_value = {
|
||||
'user_info': {
|
||||
'max_budget': None, # Already migrated
|
||||
'spend': 10.0,
|
||||
}
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.TokenManager'
|
||||
) as mock_token_manager:
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = (
|
||||
AsyncMock(return_value={'email': 'test@example.com'})
|
||||
)
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = (
|
||||
mock_client
|
||||
)
|
||||
mock_client.get.return_value = mock_user_response
|
||||
|
||||
result = await LiteLlmManager.migrate_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_entries_successful_migration(
|
||||
self, mock_user_settings, mock_user_response, mock_response
|
||||
):
|
||||
"""Test successful migrate_entries operation."""
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.TokenManager'
|
||||
) as mock_token_manager:
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = (
|
||||
AsyncMock(return_value={'email': 'test@example.com'})
|
||||
)
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = (
|
||||
mock_client
|
||||
)
|
||||
mock_client.get.return_value = mock_user_response
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
result = await LiteLlmManager.migrate_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
|
||||
# migrate_entries returns the user_settings unchanged
|
||||
assert result is not None
|
||||
assert result.agent == 'TestAgent'
|
||||
assert result.llm_model == 'test-model'
|
||||
assert result.llm_api_key.get_secret_value() == 'test-key'
|
||||
assert result.llm_base_url == 'http://test.com'
|
||||
|
||||
# Verify migration steps were called
|
||||
assert (
|
||||
mock_client.post.call_count == 4
|
||||
) # create_team, update_user, add_user_to_team, update_key
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_team_and_users_budget_missing_config(self):
|
||||
"""Test update_team_and_users_budget when LiteLLM config is missing."""
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
|
||||
# Should not raise an exception, just return early
|
||||
await LiteLlmManager.update_team_and_users_budget('test-team-id', 100.0)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_team_and_users_budget_successful(
|
||||
self, mock_team_response, mock_response
|
||||
):
|
||||
"""Test successful update_team_and_users_budget operation."""
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
mock_client.post.return_value = mock_response
|
||||
mock_client.get.return_value = mock_team_response
|
||||
|
||||
await LiteLlmManager.update_team_and_users_budget(
|
||||
'test-team-id', 100.0
|
||||
)
|
||||
|
||||
# Verify update_team and update_user_in_team were called
|
||||
assert (
|
||||
mock_client.post.call_count == 2
|
||||
) # update_team, update_user_in_team
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_team_success(self, mock_http_client, mock_response):
|
||||
"""Test successful _create_team operation."""
|
||||
mock_http_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._create_team(
|
||||
mock_http_client, 'test-alias', 'test-team-id', 100.0
|
||||
)
|
||||
|
||||
mock_http_client.post.assert_called_once()
|
||||
call_args = mock_http_client.post.call_args
|
||||
assert 'http://test.com/team/new' in call_args[0]
|
||||
assert call_args[1]['json']['team_id'] == 'test-team-id'
|
||||
assert call_args[1]['json']['team_alias'] == 'test-alias'
|
||||
assert call_args[1]['json']['max_budget'] == 100.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_team_already_exists(self, mock_http_client):
|
||||
"""Test _create_team when team already exists."""
|
||||
error_response = MagicMock()
|
||||
error_response.is_success = False
|
||||
error_response.status_code = 400
|
||||
error_response.text = 'Team already exists. Please use a different team id'
|
||||
mock_http_client.post.return_value = error_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
with patch.object(
|
||||
LiteLlmManager, '_update_team', new_callable=AsyncMock
|
||||
) as mock_update:
|
||||
await LiteLlmManager._create_team(
|
||||
mock_http_client, 'test-alias', 'test-team-id', 100.0
|
||||
)
|
||||
|
||||
mock_update.assert_called_once_with(
|
||||
mock_http_client, 'test-team-id', 'test-alias', 100.0
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_team_error(self, mock_http_client):
|
||||
"""Test _create_team with unexpected error."""
|
||||
error_response = MagicMock()
|
||||
error_response.is_success = False
|
||||
error_response.status_code = 500
|
||||
error_response.text = 'Internal server error'
|
||||
error_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
'Server error', request=MagicMock(), response=error_response
|
||||
)
|
||||
mock_http_client.post.return_value = error_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
await LiteLlmManager._create_team(
|
||||
mock_http_client, 'test-alias', 'test-team-id', 100.0
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_team_success(self, mock_http_client, mock_team_response):
|
||||
"""Test successful _get_team operation."""
|
||||
mock_http_client.get.return_value = mock_team_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
result = await LiteLlmManager._get_team(
|
||||
mock_http_client, 'test-team-id'
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert 'team_memberships' in result
|
||||
mock_http_client.get.assert_called_once_with(
|
||||
'http://test.com/team/info?team_id=test-team-id'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_success(self, mock_http_client, mock_response):
|
||||
"""Test successful _create_user operation."""
|
||||
mock_http_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._create_user(
|
||||
mock_http_client, 'test@example.com', 'test-user-id'
|
||||
)
|
||||
|
||||
mock_http_client.post.assert_called_once()
|
||||
call_args = mock_http_client.post.call_args
|
||||
assert 'http://test.com/user/new' in call_args[0]
|
||||
assert call_args[1]['json']['user_email'] == 'test@example.com'
|
||||
assert call_args[1]['json']['user_id'] == 'test-user-id'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_duplicate_email(self, mock_http_client, mock_response):
|
||||
"""Test _create_user with duplicate email handling."""
|
||||
# First call fails with duplicate email
|
||||
error_response = MagicMock()
|
||||
error_response.is_success = False
|
||||
error_response.status_code = 400
|
||||
error_response.text = 'duplicate email'
|
||||
|
||||
# Second call succeeds
|
||||
mock_http_client.post.side_effect = [error_response, mock_response]
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._create_user(
|
||||
mock_http_client, 'test@example.com', 'test-user-id'
|
||||
)
|
||||
|
||||
assert mock_http_client.post.call_count == 2
|
||||
# Second call should have None email
|
||||
second_call_args = mock_http_client.post.call_args_list[1]
|
||||
assert second_call_args[1]['json']['user_email'] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_key_success(self, mock_http_client, mock_response):
|
||||
"""Test successful _generate_key operation."""
|
||||
mock_http_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
result = await LiteLlmManager._generate_key(
|
||||
mock_http_client,
|
||||
'test-user-id',
|
||||
'test-team-id',
|
||||
'test-alias',
|
||||
{'test': 'metadata'},
|
||||
)
|
||||
|
||||
assert result == 'test-api-key'
|
||||
mock_http_client.post.assert_called_once()
|
||||
call_args = mock_http_client.post.call_args
|
||||
assert 'http://test.com/key/generate' in call_args[0]
|
||||
assert call_args[1]['json']['user_id'] == 'test-user-id'
|
||||
assert call_args[1]['json']['team_id'] == 'test-team-id'
|
||||
assert call_args[1]['json']['key_alias'] == 'test-alias'
|
||||
assert call_args[1]['json']['metadata'] == {'test': 'metadata'}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_key_info_success(self, mock_http_client, mock_key_info_response):
|
||||
"""Test successful _get_key_info operation."""
|
||||
mock_http_client.get.return_value = mock_key_info_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
with patch('storage.user_store.UserStore') as mock_user_store:
|
||||
# Mock user with org member
|
||||
mock_user = MagicMock()
|
||||
mock_org_member = MagicMock()
|
||||
mock_org_member.org_id = 'test-ord-id'
|
||||
mock_org_member.llm_api_key = 'test-api-key'
|
||||
mock_user.org_members = [mock_org_member]
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
|
||||
result = await LiteLlmManager._get_key_info(
|
||||
mock_http_client, 'test-ord-id', 'test-user-id'
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result['key_max_budget'] == 100.0
|
||||
assert result['key_spend'] == 25.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_key_info_no_user(self, mock_http_client):
|
||||
"""Test _get_key_info when user is not found."""
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
with patch('storage.user_store.UserStore') as mock_user_store:
|
||||
mock_user_store.get_user_by_id.return_value = None
|
||||
|
||||
result = await LiteLlmManager._get_key_info(
|
||||
mock_http_client, 'test-ord-id', 'test-user-id'
|
||||
)
|
||||
|
||||
assert result == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_key_success(self, mock_http_client, mock_response):
|
||||
"""Test successful _delete_key operation."""
|
||||
mock_http_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._delete_key(mock_http_client, 'test-key-id')
|
||||
|
||||
mock_http_client.post.assert_called_once()
|
||||
call_args = mock_http_client.post.call_args
|
||||
assert 'http://test.com/key/delete' in call_args[0]
|
||||
assert call_args[1]['json']['keys'] == ['test-key-id']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_key_not_found(self, mock_http_client):
|
||||
"""Test _delete_key when key is not found (404 error)."""
|
||||
error_response = MagicMock()
|
||||
error_response.is_success = False
|
||||
error_response.status_code = 404
|
||||
error_response.text = 'Key not found'
|
||||
mock_http_client.post.return_value = error_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
# Should not raise an exception for 404
|
||||
await LiteLlmManager._delete_key(mock_http_client, 'test-key-id')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_with_http_client_decorator(self):
|
||||
"""Test the with_http_client decorator functionality."""
|
||||
|
||||
# Create a mock internal function
|
||||
async def mock_internal_fn(client, arg1, arg2, kwarg1=None):
|
||||
return f'client={type(client).__name__}, arg1={arg1}, arg2={arg2}, kwarg1={kwarg1}'
|
||||
|
||||
# Apply the decorator
|
||||
decorated_fn = LiteLlmManager.with_http_client(mock_internal_fn)
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
result = await decorated_fn('test1', 'test2', kwarg1='test3')
|
||||
|
||||
# Verify the client was injected as the first argument
|
||||
assert 'client=AsyncMock' in result
|
||||
assert 'arg1=test1' in result
|
||||
assert 'arg2=test2' in result
|
||||
assert 'kwarg1=test3' in result
|
||||
|
||||
def test_public_methods_exist(self):
|
||||
"""Test that all public wrapper methods exist and are properly decorated."""
|
||||
public_methods = [
|
||||
'create_team',
|
||||
'get_team',
|
||||
'update_team',
|
||||
'create_user',
|
||||
'get_user',
|
||||
'update_user',
|
||||
'delete_user',
|
||||
'add_user_to_team',
|
||||
'get_user_team_info',
|
||||
'update_user_in_team',
|
||||
'generate_key',
|
||||
'get_key_info',
|
||||
'delete_key',
|
||||
]
|
||||
|
||||
for method_name in public_methods:
|
||||
assert hasattr(LiteLlmManager, method_name)
|
||||
method = getattr(LiteLlmManager, method_name)
|
||||
assert callable(method)
|
||||
# The methods are created by the with_http_client decorator, so they're functions
|
||||
# We can verify they exist and are callable, which is the important part
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_missing_config_all_methods(self):
|
||||
"""Test that all methods handle missing configuration gracefully."""
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
|
||||
mock_client = AsyncMock()
|
||||
|
||||
# Test all private methods that check for config
|
||||
await LiteLlmManager._create_team(
|
||||
mock_client, 'alias', 'team_id', 100.0
|
||||
)
|
||||
await LiteLlmManager._update_team(
|
||||
mock_client, 'team_id', 'alias', 100.0
|
||||
)
|
||||
await LiteLlmManager._create_user(mock_client, 'email', 'user_id')
|
||||
await LiteLlmManager._update_user(mock_client, 'user_id')
|
||||
await LiteLlmManager._delete_user(mock_client, 'user_id')
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
mock_client, 'user_id', 'team_id', 100.0
|
||||
)
|
||||
await LiteLlmManager._update_user_in_team(
|
||||
mock_client, 'user_id', 'team_id', 100.0
|
||||
)
|
||||
await LiteLlmManager._delete_key(mock_client, 'key_id')
|
||||
|
||||
result1 = await LiteLlmManager._get_team(mock_client, 'team_id')
|
||||
result2 = await LiteLlmManager._get_user(mock_client, 'user_id')
|
||||
result3 = await LiteLlmManager._generate_key(
|
||||
mock_client, 'user_id', 'team_id', 'alias', {}
|
||||
)
|
||||
result4 = await LiteLlmManager._get_user_team_info(
|
||||
mock_client, 'user_id', 'team_id'
|
||||
)
|
||||
result5 = await LiteLlmManager._get_key_info(
|
||||
mock_client, 'test-ord-id', 'user_id'
|
||||
)
|
||||
|
||||
# Methods that return None when config is missing
|
||||
assert result1 is None
|
||||
assert result2 is None
|
||||
assert result3 is None
|
||||
assert result4 is None
|
||||
assert result5 is None
|
||||
|
||||
# Verify no HTTP calls were made
|
||||
mock_client.get.assert_not_called()
|
||||
mock_client.post.assert_not_called()
|
||||
70
enterprise/tests/unit/test_models.py
Normal file
70
enterprise/tests/unit/test_models.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
Test that the models are correctly defined.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine():
|
||||
engine = create_engine('sqlite:///:memory:')
|
||||
Base.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_maker(engine):
|
||||
return sessionmaker(bind=engine)
|
||||
|
||||
|
||||
def test_user_model(session_maker):
|
||||
"""Test that the User model works correctly."""
|
||||
with session_maker() as session:
|
||||
# Create a test org
|
||||
org = Org(name='test_org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
# Create a test user
|
||||
test_user_id = uuid4()
|
||||
user = User(id=test_user_id, current_org_id=org.id, language='en')
|
||||
session.add(user)
|
||||
session.flush()
|
||||
|
||||
# Create org_member relationship
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=1,
|
||||
llm_api_key='test-api-key',
|
||||
status='active',
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
|
||||
# Query the user
|
||||
queried_user = session.query(User).filter(User.id == test_user_id).first()
|
||||
assert queried_user is not None
|
||||
assert queried_user.language == 'en'
|
||||
|
||||
# Query the org
|
||||
queried_org = session.query(Org).filter(Org.id == org.id).first()
|
||||
assert queried_org is not None
|
||||
assert queried_org.name == 'test_org'
|
||||
|
||||
# Query the org_member relationship
|
||||
queried_org_member = (
|
||||
session.query(OrgMember)
|
||||
.filter(OrgMember.org_id == org.id, OrgMember.user_id == user.id)
|
||||
.first()
|
||||
)
|
||||
assert queried_org_member is not None
|
||||
assert queried_org_member.llm_api_key.get_secret_value() == 'test-api-key'
|
||||
253
enterprise/tests/unit/test_org_member_store.py
Normal file
253
enterprise/tests/unit/test_org_member_store.py
Normal file
@@ -0,0 +1,253 @@
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
# Mock the database module before importing OrgMemberStore
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.role import Role
|
||||
from storage.user import User
|
||||
|
||||
|
||||
def test_get_org_members(session_maker):
|
||||
# Test getting org_members by org ID
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
user1 = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
user2 = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add_all([user1, user2, role])
|
||||
session.flush()
|
||||
|
||||
org_member1 = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user1.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key-1',
|
||||
status='active',
|
||||
)
|
||||
org_member2 = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user2.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key-2',
|
||||
status='active',
|
||||
)
|
||||
session.add_all([org_member1, org_member2])
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
org_members = OrgMemberStore.get_org_members(org_id)
|
||||
assert len(org_members) == 2
|
||||
api_keys = [om.llm_api_key.get_secret_value() for om in org_members]
|
||||
assert 'test-key-1' in api_keys
|
||||
assert 'test-key-2' in api_keys
|
||||
|
||||
|
||||
def test_get_user_orgs(session_maker):
|
||||
# Test getting org_members by user ID
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org1 = Org(name='test-org-1')
|
||||
org2 = Org(name='test-org-2')
|
||||
session.add_all([org1, org2])
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org1.id)
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
org_member1 = OrgMember(
|
||||
org_id=org1.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key-1',
|
||||
status='active',
|
||||
)
|
||||
org_member2 = OrgMember(
|
||||
org_id=org2.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key-2',
|
||||
status='active',
|
||||
)
|
||||
session.add_all([org_member1, org_member2])
|
||||
session.commit()
|
||||
user_id = user.id
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
org_members = OrgMemberStore.get_user_orgs(user_id)
|
||||
assert len(org_members) == 2
|
||||
api_keys = [ou.llm_api_key.get_secret_value() for ou in org_members]
|
||||
assert 'test-key-1' in api_keys
|
||||
assert 'test-key-2' in api_keys
|
||||
|
||||
|
||||
def test_get_org_member(session_maker):
|
||||
# Test getting org_member by org and user ID
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key',
|
||||
status='active',
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
user_id = user.id
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
retrieved_org_member = OrgMemberStore.get_org_member(org_id, user_id)
|
||||
assert retrieved_org_member is not None
|
||||
assert retrieved_org_member.org_id == org_id
|
||||
assert retrieved_org_member.user_id == user_id
|
||||
assert retrieved_org_member.llm_api_key.get_secret_value() == 'test-key'
|
||||
|
||||
|
||||
def test_add_user_to_org(session_maker):
|
||||
# Test adding a user to an org
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add_all([user, role])
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
user_id = user.id
|
||||
role_id = role.id
|
||||
|
||||
# Test creation
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
org_member = OrgMemberStore.add_user_to_org(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
llm_api_key='new-test-key',
|
||||
status='active',
|
||||
)
|
||||
|
||||
assert org_member is not None
|
||||
assert org_member.org_id == org_id
|
||||
assert org_member.user_id == user_id
|
||||
assert org_member.role_id == role_id
|
||||
assert org_member.llm_api_key.get_secret_value() == 'new-test-key'
|
||||
assert org_member.status == 'active'
|
||||
|
||||
|
||||
def test_update_user_role_in_org(session_maker):
|
||||
# Test updating user role in org
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
role1 = Role(name='admin', rank=1)
|
||||
role2 = Role(name='user', rank=2)
|
||||
session.add_all([user, role1, role2])
|
||||
session.flush()
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role1.id,
|
||||
llm_api_key='test-key',
|
||||
status='active',
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
user_id = user.id
|
||||
role2_id = role2.id
|
||||
|
||||
# Test update
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
updated_org_member = OrgMemberStore.update_user_role_in_org(
|
||||
org_id=org_id, user_id=user_id, role_id=role2_id, status='inactive'
|
||||
)
|
||||
|
||||
assert updated_org_member is not None
|
||||
assert updated_org_member.role_id == role2_id
|
||||
assert updated_org_member.status == 'inactive'
|
||||
|
||||
|
||||
def test_update_user_role_in_org_not_found(session_maker):
|
||||
# Test updating org_member that doesn't exist
|
||||
from uuid import uuid4
|
||||
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
updated_org_member = OrgMemberStore.update_user_role_in_org(
|
||||
org_id=uuid4(), user_id=99999, role_id=1
|
||||
)
|
||||
assert updated_org_member is None
|
||||
|
||||
|
||||
def test_remove_user_from_org(session_maker):
|
||||
# Test removing a user from an org
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key',
|
||||
status='active',
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
user_id = user.id
|
||||
|
||||
# Test removal
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
result = OrgMemberStore.remove_user_from_org(org_id, user_id)
|
||||
assert result is True
|
||||
|
||||
# Verify it's removed
|
||||
retrieved_org_member = OrgMemberStore.get_org_member(org_id, user_id)
|
||||
assert retrieved_org_member is None
|
||||
|
||||
|
||||
def test_remove_user_from_org_not_found(session_maker):
|
||||
# Test removing user from org that doesn't exist
|
||||
from uuid import uuid4
|
||||
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
result = OrgMemberStore.remove_user_from_org(uuid4(), 99999)
|
||||
assert result is False
|
||||
197
enterprise/tests/unit/test_org_store.py
Normal file
197
enterprise/tests/unit/test_org_store.py
Normal file
@@ -0,0 +1,197 @@
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
# Mock the database module before importing OrgStore
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from storage.org import Org
|
||||
from storage.org_store import OrgStore
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm_api():
|
||||
api_key_patch = patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test_key')
|
||||
api_url_patch = patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'
|
||||
)
|
||||
team_id_patch = patch('storage.lite_llm_manager.LITE_LLM_TEAM_ID', 'test_team')
|
||||
client_patch = patch('httpx.AsyncClient')
|
||||
|
||||
with api_key_patch, api_url_patch, team_id_patch, client_patch as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.json = MagicMock(return_value={'key': 'test_api_key'})
|
||||
mock_client.return_value.__aenter__.return_value.post.return_value = (
|
||||
mock_response
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = (
|
||||
mock_response
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value.patch.return_value = (
|
||||
mock_response
|
||||
)
|
||||
yield mock_client
|
||||
|
||||
|
||||
def test_get_org_by_id(session_maker, mock_litellm_api):
|
||||
# Test getting org by ID
|
||||
with session_maker() as session:
|
||||
# Create a test org
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Test retrieval
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
retrieved_org = OrgStore.get_org_by_id(org_id)
|
||||
assert retrieved_org is not None
|
||||
assert retrieved_org.id == org_id
|
||||
assert retrieved_org.name == 'test-org'
|
||||
|
||||
|
||||
def test_get_org_by_id_not_found(session_maker):
|
||||
# Test getting org by ID when it doesn't exist
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
non_existent_id = uuid.uuid4()
|
||||
retrieved_org = OrgStore.get_org_by_id(non_existent_id)
|
||||
assert retrieved_org is None
|
||||
|
||||
|
||||
def test_list_orgs(session_maker, mock_litellm_api):
|
||||
# Test listing all orgs
|
||||
with session_maker() as session:
|
||||
# Create test orgs
|
||||
org1 = Org(name='test-org-1')
|
||||
org2 = Org(name='test-org-2')
|
||||
session.add_all([org1, org2])
|
||||
session.commit()
|
||||
|
||||
# Test listing
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
orgs = OrgStore.list_orgs()
|
||||
assert len(orgs) >= 2
|
||||
org_names = [org.name for org in orgs]
|
||||
assert 'test-org-1' in org_names
|
||||
assert 'test-org-2' in org_names
|
||||
|
||||
|
||||
def test_update_org(session_maker, mock_litellm_api):
|
||||
# Test updating org details
|
||||
with session_maker() as session:
|
||||
# Create a test org
|
||||
org = Org(name='test-org', agent='CodeActAgent')
|
||||
session.add(org)
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Test update
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
updated_org = OrgStore.update_org(
|
||||
org_id=org_id, kwargs={'name': 'updated-org', 'agent': 'PlannerAgent'}
|
||||
)
|
||||
|
||||
assert updated_org is not None
|
||||
assert updated_org.name == 'updated-org'
|
||||
assert updated_org.agent == 'PlannerAgent'
|
||||
|
||||
|
||||
def test_update_org_not_found(session_maker):
|
||||
# Test updating org that doesn't exist
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
from uuid import uuid4
|
||||
|
||||
updated_org = OrgStore.update_org(
|
||||
org_id=uuid4(), kwargs={'name': 'updated-org'}
|
||||
)
|
||||
assert updated_org is None
|
||||
|
||||
|
||||
def test_create_org(session_maker, mock_litellm_api):
|
||||
# Test creating a new org
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
org = OrgStore.create_org(kwargs={'name': 'new-org', 'agent': 'CodeActAgent'})
|
||||
|
||||
assert org is not None
|
||||
assert org.name == 'new-org'
|
||||
assert org.agent == 'CodeActAgent'
|
||||
assert org.id is not None
|
||||
|
||||
|
||||
def test_get_org_by_name(session_maker, mock_litellm_api):
|
||||
# Test getting org by name
|
||||
with session_maker() as session:
|
||||
# Create a test org
|
||||
org = Org(name='test-org-by-name')
|
||||
session.add(org)
|
||||
session.commit()
|
||||
|
||||
# Test retrieval
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
retrieved_org = OrgStore.get_org_by_name('test-org-by-name')
|
||||
assert retrieved_org is not None
|
||||
assert retrieved_org.name == 'test-org-by-name'
|
||||
|
||||
|
||||
def test_get_current_org_from_keycloak_user_id(session_maker, mock_litellm_api):
|
||||
# Test getting current org from user ID
|
||||
test_user_id = uuid.uuid4()
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
from storage.user import User
|
||||
|
||||
user = User(id=test_user_id, current_org_id=org.id)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
# Test retrieval
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
retrieved_org = OrgStore.get_current_org_from_keycloak_user_id(
|
||||
str(test_user_id)
|
||||
)
|
||||
assert retrieved_org is not None
|
||||
assert retrieved_org.name == 'test-org'
|
||||
|
||||
|
||||
def test_get_kwargs_from_settings():
|
||||
# Test extracting org kwargs from settings
|
||||
settings = Settings(
|
||||
language='es',
|
||||
agent='CodeActAgent',
|
||||
llm_model='gpt-4',
|
||||
llm_api_key=SecretStr('test-key'),
|
||||
enable_sound_notifications=True,
|
||||
)
|
||||
|
||||
kwargs = OrgStore.get_kwargs_from_settings(settings)
|
||||
|
||||
# Should only include fields that exist in Org model
|
||||
assert 'agent' in kwargs
|
||||
assert 'default_llm_model' in kwargs
|
||||
assert kwargs['agent'] == 'CodeActAgent'
|
||||
assert kwargs['default_llm_model'] == 'gpt-4'
|
||||
# Should not include fields that don't exist in Org model
|
||||
assert 'language' not in kwargs # language is not in Org model
|
||||
assert 'llm_api_key' not in kwargs
|
||||
assert 'llm_model' not in kwargs
|
||||
assert 'enable_sound_notifications' not in kwargs
|
||||
@@ -1,32 +1,15 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from integrations.github.github_view import get_user_proactive_conversation_setting
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
# Mock the database module before importing
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from integrations.github.github_view import get_user_proactive_conversation_setting
|
||||
from storage.org import Org
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
# Mock the call_sync_from_async function to return the result of the function directly
|
||||
def mock_call_sync_from_async(func, *args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
filter = MagicMock()
|
||||
|
||||
# Mock the context manager behavior
|
||||
session.__enter__.return_value = session
|
||||
|
||||
session.query.return_value = query
|
||||
query.filter.return_value = filter
|
||||
|
||||
return session, query, filter
|
||||
|
||||
|
||||
async def test_get_user_proactive_conversation_setting_no_user_id():
|
||||
"""Test that the function returns False when no user ID is provided."""
|
||||
with patch(
|
||||
@@ -42,75 +25,82 @@ async def test_get_user_proactive_conversation_setting_no_user_id():
|
||||
assert await get_user_proactive_conversation_setting(None) is False
|
||||
|
||||
|
||||
async def test_get_user_proactive_conversation_setting_user_not_found(mock_session):
|
||||
async def test_get_user_proactive_conversation_setting_user_not_found():
|
||||
"""Test that False is returned when the user is not found."""
|
||||
session, query, filter = mock_session
|
||||
filter.first.return_value = None
|
||||
|
||||
with patch('integrations.github.github_view.session_maker', return_value=session):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=None,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.call_sync_from_async',
|
||||
side_effect=mock_call_sync_from_async,
|
||||
):
|
||||
assert await get_user_proactive_conversation_setting('user-id') is False
|
||||
assert (
|
||||
await get_user_proactive_conversation_setting(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_none(mock_session):
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_none():
|
||||
"""Test that False is returned when the user setting is None."""
|
||||
session, query, filter = mock_session
|
||||
user_settings = MagicMock(spec=UserSettings)
|
||||
user_settings.enable_proactive_conversation_starters = None
|
||||
filter.first.return_value = user_settings
|
||||
mock_org = MagicMock(spec=Org)
|
||||
mock_org.enable_proactive_conversation_starters = None
|
||||
|
||||
with patch('integrations.github.github_view.session_maker', return_value=session):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.call_sync_from_async',
|
||||
side_effect=mock_call_sync_from_async,
|
||||
):
|
||||
assert await get_user_proactive_conversation_setting('user-id') is False
|
||||
assert (
|
||||
await get_user_proactive_conversation_setting(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_true(mock_session):
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_true():
|
||||
"""Test that True is returned when the user setting is True and the global setting is True."""
|
||||
session, query, filter = mock_session
|
||||
user_settings = MagicMock(spec=UserSettings)
|
||||
user_settings.enable_proactive_conversation_starters = True
|
||||
filter.first.return_value = user_settings
|
||||
mock_org = MagicMock(spec=Org)
|
||||
mock_org.enable_proactive_conversation_starters = True
|
||||
|
||||
with patch('integrations.github.github_view.session_maker', return_value=session):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.call_sync_from_async',
|
||||
side_effect=mock_call_sync_from_async,
|
||||
):
|
||||
assert await get_user_proactive_conversation_setting('user-id') is True
|
||||
assert (
|
||||
await get_user_proactive_conversation_setting(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_false(mock_session):
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_false():
|
||||
"""Test that False is returned when the user setting is False, regardless of global setting."""
|
||||
session, query, filter = mock_session
|
||||
user_settings = MagicMock(spec=UserSettings)
|
||||
user_settings.enable_proactive_conversation_starters = False
|
||||
filter.first.return_value = user_settings
|
||||
mock_org = MagicMock(spec=Org)
|
||||
mock_org.enable_proactive_conversation_starters = False
|
||||
|
||||
with patch('integrations.github.github_view.session_maker', return_value=session):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.call_sync_from_async',
|
||||
side_effect=mock_call_sync_from_async,
|
||||
):
|
||||
assert await get_user_proactive_conversation_setting('user-id') is False
|
||||
assert (
|
||||
await get_user_proactive_conversation_setting(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
83
enterprise/tests/unit/test_role_store.py
Normal file
83
enterprise/tests/unit/test_role_store.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
# Mock the database module before importing RoleStore
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from storage.role import Role
|
||||
from storage.role_store import RoleStore
|
||||
|
||||
|
||||
def test_get_role_by_id(session_maker):
|
||||
# Test getting role by ID
|
||||
with session_maker() as session:
|
||||
# Create a test role
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add(role)
|
||||
session.commit()
|
||||
role_id = role.id
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
retrieved_role = RoleStore.get_role_by_id(role_id)
|
||||
assert retrieved_role is not None
|
||||
assert retrieved_role.id == role_id
|
||||
assert retrieved_role.name == 'admin'
|
||||
|
||||
|
||||
def test_get_role_by_id_not_found(session_maker):
|
||||
# Test getting role by ID when it doesn't exist
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
retrieved_role = RoleStore.get_role_by_id(99999)
|
||||
assert retrieved_role is None
|
||||
|
||||
|
||||
def test_get_role_by_name(session_maker):
|
||||
# Test getting role by name
|
||||
with session_maker() as session:
|
||||
# Create a test role
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add(role)
|
||||
session.commit()
|
||||
role_id = role.id
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
retrieved_role = RoleStore.get_role_by_name('admin')
|
||||
assert retrieved_role is not None
|
||||
assert retrieved_role.id == role_id
|
||||
assert retrieved_role.name == 'admin'
|
||||
|
||||
|
||||
def test_get_role_by_name_not_found(session_maker):
|
||||
# Test getting role by name when it doesn't exist
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
retrieved_role = RoleStore.get_role_by_name('nonexistent')
|
||||
assert retrieved_role is None
|
||||
|
||||
|
||||
def test_list_roles(session_maker):
|
||||
# Test listing all roles
|
||||
with session_maker() as session:
|
||||
# Create test roles
|
||||
role1 = Role(name='admin', rank=1)
|
||||
role2 = Role(name='user', rank=2)
|
||||
session.add_all([role1, role2])
|
||||
session.commit()
|
||||
|
||||
# Test listing
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
roles = RoleStore.list_roles()
|
||||
assert len(roles) >= 2
|
||||
role_names = [role.name for role in roles]
|
||||
assert 'admin' in role_names
|
||||
assert 'user' in role_names
|
||||
|
||||
|
||||
def test_create_role(session_maker):
|
||||
# Test creating a new role
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
role = RoleStore.create_role(name='moderator', rank=2)
|
||||
|
||||
assert role is not None
|
||||
assert role.name == 'moderator'
|
||||
assert role.rank == 2
|
||||
assert role.id is not None
|
||||
@@ -1,11 +1,16 @@
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from storage.saas_conversation_store import SaasConversationStore
|
||||
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
|
||||
|
||||
# Mock the database module before importing
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from storage.saas_conversation_store import SaasConversationStore
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_call_sync_from_async():
|
||||
@@ -20,12 +25,22 @@ def mock_call_sync_from_async():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_user_store():
|
||||
"""Mock UserStore.get_user_by_id to return a mock user"""
|
||||
mock_user = MagicMock(spec=User)
|
||||
mock_user.current_org_id = UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
with patch('storage.user_store.UserStore.get_user_by_id', return_value=mock_user):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_and_get(session_maker):
|
||||
store = SaasConversationStore('12345', session_maker)
|
||||
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
|
||||
metadata = ConversationMetadata(
|
||||
conversation_id='my-conversation-id',
|
||||
user_id='12345',
|
||||
user_id='5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
selected_repository='my-repo',
|
||||
selected_branch=None,
|
||||
created_at=datetime.now(UTC),
|
||||
@@ -47,13 +62,13 @@ async def test_save_and_get(session_maker):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search(session_maker):
|
||||
store = SaasConversationStore('12345', session_maker)
|
||||
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
|
||||
|
||||
# Create test conversations with different timestamps
|
||||
conversations = [
|
||||
ConversationMetadata(
|
||||
conversation_id=f'conv-{i}',
|
||||
user_id='12345',
|
||||
user_id='5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
selected_repository='repo',
|
||||
selected_branch=None,
|
||||
created_at=datetime(2024, 1, i + 1, tzinfo=UTC),
|
||||
@@ -92,10 +107,10 @@ async def test_search(session_maker):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_metadata(session_maker):
|
||||
store = SaasConversationStore('12345', session_maker)
|
||||
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
|
||||
metadata = ConversationMetadata(
|
||||
conversation_id='to-delete',
|
||||
user_id='12345',
|
||||
user_id='5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
selected_repository='repo',
|
||||
selected_branch=None,
|
||||
created_at=datetime.now(UTC),
|
||||
@@ -112,17 +127,17 @@ async def test_delete_metadata(session_maker):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_metadata(session_maker):
|
||||
store = SaasConversationStore('12345', session_maker)
|
||||
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
|
||||
with pytest.raises(FileNotFoundError):
|
||||
await store.get_metadata('nonexistent-id')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exists(session_maker):
|
||||
store = SaasConversationStore('12345', session_maker)
|
||||
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
|
||||
metadata = ConversationMetadata(
|
||||
conversation_id='exists-test',
|
||||
user_id='12345',
|
||||
user_id='5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
selected_repository='repo',
|
||||
selected_branch='test-branch',
|
||||
created_at=datetime.now(UTC),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
@@ -19,6 +20,14 @@ def mock_config():
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
"""Mock user with org_id."""
|
||||
user = MagicMock()
|
||||
user.current_org_id = UUID('a1111111-1111-1111-1111-111111111111')
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def secrets_store(session_maker, mock_config):
|
||||
return SaasSecretsStore('user-id', session_maker, mock_config)
|
||||
@@ -26,7 +35,11 @@ def secrets_store(session_maker, mock_config):
|
||||
|
||||
class TestSaasSecretsStore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_and_load(self, secrets_store):
|
||||
@patch('storage.saas_secrets_store.UserStore.get_user_by_id')
|
||||
async def test_store_and_load(self, mock_get_user, secrets_store, mock_user):
|
||||
# Setup mock
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
# Create a Secrets object with some test data
|
||||
user_secrets = Secrets(
|
||||
custom_secrets=MappingProxyType(
|
||||
@@ -59,7 +72,10 @@ class TestSaasSecretsStore:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_encryption_decryption(self, secrets_store):
|
||||
@patch('storage.saas_secrets_store.UserStore.get_user_by_id')
|
||||
async def test_encryption_decryption(self, mock_get_user, secrets_store, mock_user):
|
||||
# Setup mock
|
||||
mock_get_user.return_value = mock_user
|
||||
# Create a Secrets object with sensitive data
|
||||
user_secrets = Secrets(
|
||||
custom_secrets=MappingProxyType(
|
||||
@@ -89,6 +105,7 @@ class TestSaasSecretsStore:
|
||||
stored = (
|
||||
session.query(StoredCustomSecrets)
|
||||
.filter(StoredCustomSecrets.keycloak_user_id == 'user-id')
|
||||
.filter(StoredCustomSecrets.org_id == mock_user.current_org_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@@ -152,7 +169,12 @@ class TestSaasSecretsStore:
|
||||
assert await secrets_store.load() is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_existing_secrets(self, secrets_store):
|
||||
@patch('storage.saas_secrets_store.UserStore.get_user_by_id')
|
||||
async def test_update_existing_secrets(
|
||||
self, mock_get_user, secrets_store, mock_user
|
||||
):
|
||||
# Setup mock
|
||||
mock_get_user.return_value = mock_user
|
||||
# Create and store initial secrets
|
||||
initial_secrets = Secrets(
|
||||
custom_secrets=MappingProxyType(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -143,7 +143,6 @@ def sample_private_conversation_info():
|
||||
class TestSharedConversationInfoService:
|
||||
"""Test cases for SharedConversationInfoService."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_shared_conversation_info_returns_public_conversation(
|
||||
self,
|
||||
@@ -165,7 +164,8 @@ class TestSharedConversationInfoService:
|
||||
assert result is not None
|
||||
assert result.id == sample_conversation_info.id
|
||||
assert result.title == sample_conversation_info.title
|
||||
assert result.created_by_user_id == sample_conversation_info.created_by_user_id
|
||||
# Note: created_by_user_id is no longer stored in shared conversation metadata
|
||||
assert result.created_by_user_id is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_shared_conversation_info_returns_none_for_private_conversation(
|
||||
|
||||
@@ -3,27 +3,30 @@ This test file verifies that the stripe_service functions properly use the datab
|
||||
to store and retrieve customer IDs.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import stripe
|
||||
from integrations.stripe_service import (
|
||||
find_customer_id_by_user_id,
|
||||
find_or_create_customer,
|
||||
find_or_create_customer_by_user_id,
|
||||
)
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.stripe_customer import Base as StripeCustomerBase
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role import Role
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
from storage.user_settings import Base as UserBase
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine():
|
||||
engine = create_engine('sqlite:///:memory:')
|
||||
|
||||
UserBase.metadata.create_all(engine)
|
||||
StripeCustomerBase.metadata.create_all(engine)
|
||||
# Create all tables using the unified Base
|
||||
Base.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
|
||||
@@ -32,79 +35,158 @@ def session_maker(engine):
|
||||
return sessionmaker(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_org_and_user(session_maker):
|
||||
"""Create a test org and user for use in tests."""
|
||||
test_user_id = uuid.uuid4()
|
||||
test_org_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
# Create role first
|
||||
role = Role(name='test-role', rank=1)
|
||||
session.add(role)
|
||||
session.flush()
|
||||
|
||||
# Create org
|
||||
org = Org(id=test_org_id, name='test-org', contact_email='testy@tester.com')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
# Create user with current_org_id
|
||||
user = User(id=test_user_id, current_org_id=test_org_id, role_id=role.id)
|
||||
session.add(user)
|
||||
session.flush()
|
||||
|
||||
# Create org member relationship
|
||||
org_member = OrgMember(
|
||||
org_id=test_org_id,
|
||||
user_id=test_user_id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key',
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
|
||||
return test_user_id, test_org_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_customer_id_by_user_id_checks_db_first(session_maker):
|
||||
async def test_find_customer_id_by_user_id_checks_db_first(
|
||||
session_maker, test_org_and_user
|
||||
):
|
||||
"""Test that find_customer_id_by_user_id checks the database first"""
|
||||
|
||||
test_user_id, test_org_id = test_org_and_user
|
||||
|
||||
# Set up the mock for the database query result
|
||||
with session_maker() as session:
|
||||
# Create stripe customer
|
||||
session.add(
|
||||
StripeCustomer(
|
||||
keycloak_user_id='test-user-id',
|
||||
keycloak_user_id=str(test_user_id),
|
||||
org_id=test_org_id,
|
||||
stripe_customer_id='cus_test123',
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
with patch('integrations.stripe_service.session_maker', session_maker):
|
||||
# Create a mock org object to return from OrgStore
|
||||
mock_org = MagicMock()
|
||||
mock_org.id = test_org_id
|
||||
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('integrations.stripe_service.call_sync_from_async') as mock_call_sync,
|
||||
):
|
||||
# Mock the call_sync_from_async to return the org
|
||||
mock_call_sync.return_value = mock_org
|
||||
|
||||
# Call the function
|
||||
result = await find_customer_id_by_user_id('test-user-id')
|
||||
result = await find_customer_id_by_user_id(str(test_user_id))
|
||||
|
||||
# Verify the result
|
||||
assert result == 'cus_test123'
|
||||
|
||||
# Verify that call_sync_from_async was called with the correct function
|
||||
mock_call_sync.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_customer_id_by_user_id_falls_back_to_stripe(session_maker):
|
||||
async def test_find_customer_id_by_user_id_falls_back_to_stripe(
|
||||
session_maker, test_org_and_user
|
||||
):
|
||||
"""Test that find_customer_id_by_user_id falls back to Stripe if not found in the database"""
|
||||
|
||||
test_user_id, test_org_id = test_org_and_user
|
||||
|
||||
# Set up the mock for stripe.Customer.search_async
|
||||
mock_customer = stripe.Customer(id='cus_test123')
|
||||
mock_search = AsyncMock(return_value=MagicMock(data=[mock_customer]))
|
||||
|
||||
# Create a mock org object to return from OrgStore
|
||||
mock_org = MagicMock()
|
||||
mock_org.id = test_org_id
|
||||
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('stripe.Customer.search_async', mock_search),
|
||||
patch('integrations.stripe_service.call_sync_from_async') as mock_call_sync,
|
||||
):
|
||||
# Mock the call_sync_from_async to return the org
|
||||
mock_call_sync.return_value = mock_org
|
||||
|
||||
# Call the function
|
||||
result = await find_customer_id_by_user_id('test-user-id')
|
||||
result = await find_customer_id_by_user_id(str(test_user_id))
|
||||
|
||||
# Verify the result
|
||||
assert result == 'cus_test123'
|
||||
|
||||
# Verify that Stripe was searched
|
||||
# Verify that Stripe was searched with the org_id
|
||||
mock_search.assert_called_once()
|
||||
assert "metadata['user_id']:'test-user-id'" in mock_search.call_args[1]['query']
|
||||
assert (
|
||||
f"metadata['org_id']:'{str(test_org_id)}'" in mock_search.call_args[1]['query']
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_customer_stores_id_in_db(session_maker):
|
||||
async def test_create_customer_stores_id_in_db(session_maker, test_org_and_user):
|
||||
"""Test that create_customer stores the customer ID in the database"""
|
||||
|
||||
# Set up the mock for stripe.Customer.search_async
|
||||
test_user_id, test_org_id = test_org_and_user
|
||||
|
||||
# Set up the mock for stripe.Customer.search_async and create_async
|
||||
mock_search = AsyncMock(return_value=MagicMock(data=[]))
|
||||
mock_create_async = AsyncMock(return_value=stripe.Customer(id='cus_test123'))
|
||||
|
||||
# Create a mock org object to return from OrgStore
|
||||
mock_org = MagicMock()
|
||||
mock_org.id = test_org_id
|
||||
mock_org.contact_email = 'testy@tester.com'
|
||||
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('stripe.Customer.search_async', mock_search),
|
||||
patch('stripe.Customer.create_async', mock_create_async),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
),
|
||||
patch('integrations.stripe_service.call_sync_from_async') as mock_call_sync,
|
||||
):
|
||||
# Mock the call_sync_from_async to return the org
|
||||
mock_call_sync.return_value = mock_org
|
||||
|
||||
# Call the function
|
||||
result = await find_or_create_customer('test-user-id')
|
||||
result = await find_or_create_customer_by_user_id(str(test_user_id))
|
||||
|
||||
# Verify the result
|
||||
assert result == 'cus_test123'
|
||||
assert result == {'customer_id': 'cus_test123', 'org_id': str(test_org_id)}
|
||||
|
||||
# Verify that the stripe customer was stored in the db
|
||||
with session_maker() as session:
|
||||
customer = session.query(StripeCustomer).first()
|
||||
assert customer.id > 0
|
||||
assert customer.keycloak_user_id == 'test-user-id'
|
||||
assert customer.keycloak_user_id == str(test_user_id)
|
||||
assert customer.org_id == test_org_id
|
||||
assert customer.stripe_customer_id == 'cus_test123'
|
||||
assert customer.created_at is not None
|
||||
assert customer.updated_at is not None
|
||||
|
||||
164
enterprise/tests/unit/test_user_store.py
Normal file
164
enterprise/tests/unit/test_user_store.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
# Mock the database module before importing UserStore
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from storage.user import User
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from sqlalchemy.orm import configure_mappers
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope='session')
|
||||
def load_all_models():
|
||||
configure_mappers() # fail fast if anything’s missing
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm_api():
|
||||
api_key_patch = patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test_key')
|
||||
api_url_patch = patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'
|
||||
)
|
||||
team_id_patch = patch('storage.lite_llm_manager.LITE_LLM_TEAM_ID', 'test_team')
|
||||
client_patch = patch('httpx.AsyncClient')
|
||||
|
||||
with api_key_patch, api_url_patch, team_id_patch, client_patch as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.json = MagicMock(return_value={'key': 'test_api_key'})
|
||||
mock_client.return_value.__aenter__.return_value.post.return_value = (
|
||||
mock_response
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = (
|
||||
mock_response
|
||||
)
|
||||
yield mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stripe():
|
||||
search_patch = patch(
|
||||
'stripe.Customer.search_async',
|
||||
AsyncMock(return_value=MagicMock(id='mock-customer-id')),
|
||||
)
|
||||
payment_patch = patch(
|
||||
'stripe.Customer.list_payment_methods_async',
|
||||
AsyncMock(return_value=MagicMock(data=[{}])),
|
||||
)
|
||||
with search_patch, payment_patch:
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_settings_no_org_id():
|
||||
# Test UserStore.create_default_settings with empty org_id
|
||||
settings = await UserStore.create_default_settings('', 'test-user-id')
|
||||
assert settings is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_settings_require_org(session_maker, mock_stripe):
|
||||
# Mock stripe_service.has_payment_method to return False
|
||||
with (
|
||||
patch(
|
||||
'stripe.Customer.list_payment_methods_async',
|
||||
AsyncMock(return_value=MagicMock(data=[])),
|
||||
),
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
):
|
||||
settings = await UserStore.create_default_settings(
|
||||
'test-org-id', 'test-user-id'
|
||||
)
|
||||
assert settings is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_settings_with_litellm(session_maker, mock_litellm_api):
|
||||
# Test that UserStore.create_default_settings works with LiteLLM
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch('storage.user_store.session_maker', session_maker),
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'attributes': {'github_id': ['12345']}}),
|
||||
),
|
||||
):
|
||||
settings = await UserStore.create_default_settings(
|
||||
'test-org-id', 'test-user-id'
|
||||
)
|
||||
assert settings is not None
|
||||
assert settings.llm_api_key.get_secret_value() == 'test_api_key'
|
||||
assert settings.llm_base_url == 'http://test.url'
|
||||
assert settings.agent == 'CodeActAgent'
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Complex integration test with session isolation issues')
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user(session_maker, mock_litellm_api):
|
||||
# Test creating a new user - skipped due to complex session isolation issues
|
||||
pass
|
||||
|
||||
|
||||
def test_get_user_by_id(session_maker):
|
||||
# Test getting user by ID
|
||||
test_org_id = uuid.uuid4()
|
||||
test_user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
with session_maker() as session:
|
||||
# Create a test user
|
||||
user = User(id=uuid.UUID(test_user_id), current_org_id=test_org_id)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
user_id = user.id
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.user_store.session_maker', session_maker):
|
||||
retrieved_user = UserStore.get_user_by_id(test_user_id)
|
||||
assert retrieved_user is not None
|
||||
assert retrieved_user.id == user_id
|
||||
|
||||
|
||||
def test_list_users(session_maker):
|
||||
# Test listing all users
|
||||
test_org_id1 = uuid.uuid4()
|
||||
test_org_id2 = uuid.uuid4()
|
||||
test_user_id1 = uuid.uuid4()
|
||||
test_user_id2 = uuid.uuid4()
|
||||
with session_maker() as session:
|
||||
# Create test users
|
||||
user1 = User(id=test_user_id1, current_org_id=test_org_id1)
|
||||
user2 = User(id=test_user_id2, current_org_id=test_org_id2)
|
||||
session.add_all([user1, user2])
|
||||
session.commit()
|
||||
|
||||
# Test listing
|
||||
with patch('storage.user_store.session_maker', session_maker):
|
||||
users = UserStore.list_users()
|
||||
assert len(users) >= 2
|
||||
user_ids = [user.id for user in users]
|
||||
assert test_user_id1 in user_ids
|
||||
assert test_user_id2 in user_ids
|
||||
|
||||
|
||||
def test_get_kwargs_from_settings():
|
||||
# Test extracting user kwargs from settings
|
||||
settings = Settings(
|
||||
language='es',
|
||||
enable_sound_notifications=True,
|
||||
llm_api_key=SecretStr('test-key'),
|
||||
)
|
||||
|
||||
kwargs = UserStore.get_kwargs_from_settings(settings)
|
||||
|
||||
# Should only include fields that exist in User model
|
||||
assert 'language' in kwargs
|
||||
assert 'enable_sound_notifications' in kwargs
|
||||
# Should not include fields that don't exist in User model
|
||||
assert 'llm_api_key' not in kwargs
|
||||
@@ -69,8 +69,6 @@ class StoredConversationMetadata(Base): # type: ignore
|
||||
conversation_id = Column(
|
||||
String, primary_key=True, default=lambda: str(uuid.uuid4())
|
||||
)
|
||||
github_user_id = Column(String, nullable=True) # The GitHub user ID
|
||||
user_id = Column(String, nullable=False) # The Keycloak User ID
|
||||
selected_repository = Column(String, nullable=True)
|
||||
selected_branch = Column(String, nullable=True)
|
||||
git_provider = Column(
|
||||
@@ -199,10 +197,9 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
updated_at__lt: datetime | None = None,
|
||||
) -> int:
|
||||
"""Count sandboxed conversations matching the given filters."""
|
||||
query = select(func.count(StoredConversationMetadata.conversation_id))
|
||||
user_id = await self.user_context.get_user_id()
|
||||
if user_id:
|
||||
query = query.where(StoredConversationMetadata.user_id == user_id)
|
||||
query = select(func.count(StoredConversationMetadata.conversation_id)).where(
|
||||
StoredConversationMetadata.conversation_version == 'V1'
|
||||
)
|
||||
|
||||
query = self._apply_filters(
|
||||
query=query,
|
||||
@@ -319,22 +316,11 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
async def save_app_conversation_info(
|
||||
self, info: AppConversationInfo
|
||||
) -> AppConversationInfo:
|
||||
user_id = await self.user_context.get_user_id()
|
||||
if user_id:
|
||||
query = select(StoredConversationMetadata).where(
|
||||
StoredConversationMetadata.conversation_id == str(info.id)
|
||||
)
|
||||
result = await self.db_session.execute(query)
|
||||
existing = result.scalar_one_or_none()
|
||||
assert existing is None or existing.user_id == user_id
|
||||
|
||||
metrics = info.metrics or MetricsSnapshot()
|
||||
usage = metrics.accumulated_token_usage or TokenUsage()
|
||||
|
||||
stored = StoredConversationMetadata(
|
||||
conversation_id=str(info.id),
|
||||
github_user_id=None, # TODO: Should we add this to the conversation info?
|
||||
user_id=info.created_by_user_id or '',
|
||||
selected_repository=info.selected_repository,
|
||||
selected_branch=info.selected_branch,
|
||||
git_provider=info.git_provider.value if info.git_provider else None,
|
||||
@@ -342,7 +328,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
last_updated_at=info.updated_at,
|
||||
created_at=info.created_at,
|
||||
trigger=info.trigger.value if info.trigger else None,
|
||||
pr_number=info.pr_number,
|
||||
pr_number=info.pr_number or [],
|
||||
# Cost and token metrics
|
||||
accumulated_cost=metrics.accumulated_cost,
|
||||
prompt_tokens=usage.prompt_tokens,
|
||||
@@ -496,11 +482,6 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
query = select(StoredConversationMetadata).where(
|
||||
StoredConversationMetadata.conversation_version == 'V1'
|
||||
)
|
||||
user_id = await self.user_context.get_user_id()
|
||||
if user_id:
|
||||
query = query.where(
|
||||
StoredConversationMetadata.user_id == user_id,
|
||||
)
|
||||
return query
|
||||
|
||||
def _to_info(
|
||||
@@ -535,7 +516,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
|
||||
return AppConversationInfo(
|
||||
id=UUID(stored.conversation_id),
|
||||
created_by_user_id=stored.user_id if stored.user_id else None,
|
||||
created_by_user_id=None, # User ID is now stored in ConversationMetadataSaas
|
||||
sandbox_id=stored.sandbox_id,
|
||||
selected_repository=stored.selected_repository,
|
||||
selected_branch=stored.selected_branch,
|
||||
@@ -580,13 +561,6 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
StoredConversationMetadata.conversation_id == str(conversation_id)
|
||||
)
|
||||
|
||||
# Apply user security filtering - only allow deletion of conversations owned by the current user
|
||||
user_id = await self.user_context.get_user_id()
|
||||
if user_id:
|
||||
delete_query = delete_query.where(
|
||||
StoredConversationMetadata.user_id == user_id
|
||||
)
|
||||
|
||||
# Execute the secure delete query
|
||||
result = await self.db_session.execute(delete_query)
|
||||
|
||||
|
||||
46
openhands/app_server/app_lifespan/alembic/versions/005.py
Normal file
46
openhands/app_server/app_lifespan/alembic/versions/005.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Update conversation_metadata table to match StoredConversationMetadata dataclass
|
||||
|
||||
Revision ID: 005
|
||||
Revises: 004
|
||||
Create Date: 2025-11-11 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '005'
|
||||
down_revision: Union[str, Sequence[str], None] = '004'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
with op.batch_alter_table('conversation_metadata') as batch_op:
|
||||
# Drop columns not in StoredConversationMetadata dataclass
|
||||
batch_op.drop_column('github_user_id')
|
||||
|
||||
# Alter user_id to become nullable
|
||||
batch_op.alter_column(
|
||||
'user_id',
|
||||
existing_type=sa.String(),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
with op.batch_alter_table('conversation_metadata') as batch_op:
|
||||
# Add back removed column
|
||||
batch_op.add_column(sa.Column('github_user_id', sa.String(), nullable=True))
|
||||
|
||||
# Restore NOT NULL constraint
|
||||
batch_op.alter_column(
|
||||
'user_id',
|
||||
existing_type=sa.String(),
|
||||
nullable=False,
|
||||
)
|
||||
@@ -233,7 +233,20 @@ def config_from_env() -> AppServerConfig:
|
||||
config.sandbox_spec = DockerSandboxSpecServiceInjector()
|
||||
|
||||
if config.app_conversation_info is None:
|
||||
config.app_conversation_info = SQLAppConversationInfoServiceInjector()
|
||||
# Use enterprise injector if running in SAAS mode
|
||||
if 'saas' in (os.getenv('OPENHANDS_CONFIG_CLS') or '').lower():
|
||||
try:
|
||||
# Import enterprise injector dynamically
|
||||
from enterprise.server.utils.saas_app_conversation_info_injector import (
|
||||
SaasAppConversationInfoServiceInjector,
|
||||
)
|
||||
|
||||
config.app_conversation_info = SaasAppConversationInfoServiceInjector()
|
||||
except ImportError:
|
||||
# Fallback to OSS injector if enterprise module is not available
|
||||
config.app_conversation_info = SQLAppConversationInfoServiceInjector()
|
||||
else:
|
||||
config.app_conversation_info = SQLAppConversationInfoServiceInjector()
|
||||
|
||||
if config.app_conversation_start_task is None:
|
||||
config.app_conversation_start_task = (
|
||||
|
||||
@@ -153,15 +153,15 @@ class JwtService:
|
||||
|
||||
# Add standard JWT claims
|
||||
now = utc_now()
|
||||
if expires_in is None:
|
||||
expires_in = timedelta(hours=1)
|
||||
|
||||
jwt_payload = {
|
||||
**payload,
|
||||
'iat': int(now.timestamp()),
|
||||
'exp': int((now + expires_in).timestamp()),
|
||||
}
|
||||
|
||||
# Only add exp if expires_in is provided
|
||||
if expires_in is not None:
|
||||
jwt_payload['exp'] = int((now + expires_in).timestamp())
|
||||
|
||||
# Get the raw key for JWE encryption and derive a 256-bit key
|
||||
secret_key = self._keys[key_id].key.get_secret_value()
|
||||
key_bytes = secret_key.encode() if isinstance(secret_key, str) else secret_key
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from openhands.events.event import Event, EventSource, RecallType
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.recall_type import RecallType
|
||||
from openhands.events.stream import EventStream, EventStreamSubscriber
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any
|
||||
|
||||
from openhands.core.schema import ActionType
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.recall_type import RecallType
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -22,16 +22,6 @@ class FileReadSource(str, Enum):
|
||||
DEFAULT = 'default'
|
||||
|
||||
|
||||
class RecallType(str, Enum):
|
||||
"""The type of information that can be retrieved from microagents."""
|
||||
|
||||
WORKSPACE_CONTEXT = 'workspace_context'
|
||||
"""Workspace context (repo instructions, runtime, etc.)"""
|
||||
|
||||
KNOWLEDGE = 'knowledge'
|
||||
"""A knowledge microagent."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class Event:
|
||||
INVALID_ID = -1
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.observation.agent import (
|
||||
AgentCondensationObservation,
|
||||
AgentStateChangedObservation,
|
||||
@@ -28,6 +27,7 @@ from openhands.events.observation.observation import Observation
|
||||
from openhands.events.observation.reject import UserRejectObservation
|
||||
from openhands.events.observation.success import SuccessObservation
|
||||
from openhands.events.observation.task_tracking import TaskTrackingObservation
|
||||
from openhands.events.recall_type import RecallType
|
||||
|
||||
__all__ = [
|
||||
'Observation',
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.events.recall_type import RecallType
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
11
openhands/events/recall_type.py
Normal file
11
openhands/events/recall_type.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class RecallType(str, Enum):
|
||||
"""The type of information that can be retrieved from microagents."""
|
||||
|
||||
WORKSPACE_CONTEXT = 'workspace_context'
|
||||
"""Workspace context (repo instructions, runtime, etc.)"""
|
||||
|
||||
KNOWLEDGE = 'knowledge'
|
||||
"""A knowledge microagent."""
|
||||
@@ -1,7 +1,6 @@
|
||||
import copy
|
||||
from typing import Any
|
||||
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.observation.agent import (
|
||||
AgentCondensationObservation,
|
||||
AgentStateChangedObservation,
|
||||
@@ -32,6 +31,7 @@ from openhands.events.observation.observation import Observation
|
||||
from openhands.events.observation.reject import UserRejectObservation
|
||||
from openhands.events.observation.success import SuccessObservation
|
||||
from openhands.events.observation.task_tracking import TaskTrackingObservation
|
||||
from openhands.events.recall_type import RecallType
|
||||
|
||||
observations = (
|
||||
NullObservation,
|
||||
|
||||
@@ -28,7 +28,7 @@ from openhands.integrations.service_types import (
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
|
||||
class AzureDevOpsServiceImpl(
|
||||
class AzureDevOpsService(
|
||||
AzureDevOpsResolverMixin,
|
||||
AzureDevOpsReposMixin,
|
||||
AzureDevOpsBranchesMixin,
|
||||
@@ -242,8 +242,34 @@ class AzureDevOpsServiceImpl(
|
||||
# Dynamic class loading to support custom implementations (e.g., SaaS)
|
||||
azure_devops_service_cls = os.environ.get(
|
||||
'OPENHANDS_AZURE_DEVOPS_SERVICE_CLS',
|
||||
'openhands.integrations.azure_devops.azure_devops_service.AzureDevOpsServiceImpl',
|
||||
)
|
||||
AzureDevOpsServiceImpl = get_impl( # type: ignore[misc]
|
||||
AzureDevOpsServiceImpl, azure_devops_service_cls
|
||||
'openhands.integrations.azure_devops.azure_devops_service.AzureDevOpsService',
|
||||
)
|
||||
|
||||
# Lazy loading to avoid circular imports
|
||||
_azure_devops_service_impl = None
|
||||
|
||||
|
||||
def get_azure_devops_service_impl():
|
||||
"""Get the Azure DevOps service implementation with lazy loading."""
|
||||
global _azure_devops_service_impl
|
||||
if _azure_devops_service_impl is None:
|
||||
_azure_devops_service_impl = get_impl( # type: ignore[misc]
|
||||
AzureDevOpsService, azure_devops_service_cls
|
||||
)
|
||||
return _azure_devops_service_impl
|
||||
|
||||
|
||||
# For backward compatibility, provide the implementation as a property
|
||||
class _AzureDevOpsServiceImplProxy:
|
||||
"""Proxy class to provide lazy loading for AzureDevOpsServiceImpl."""
|
||||
|
||||
def __getattr__(self, name):
|
||||
impl = get_azure_devops_service_impl()
|
||||
return getattr(impl, name)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
impl = get_azure_devops_service_impl()
|
||||
return impl(*args, **kwargs)
|
||||
|
||||
|
||||
AzureDevOpsServiceImpl: type[AzureDevOpsService] = _AzureDevOpsServiceImplProxy() # type: ignore[assignment]
|
||||
|
||||
@@ -64,4 +64,30 @@ bitbucket_service_cls = os.environ.get(
|
||||
'OPENHANDS_BITBUCKET_SERVICE_CLS',
|
||||
'openhands.integrations.bitbucket.bitbucket_service.BitBucketService',
|
||||
)
|
||||
BitBucketServiceImpl = get_impl(BitBucketService, bitbucket_service_cls)
|
||||
|
||||
# Lazy loading to avoid circular imports
|
||||
_bitbucket_service_impl = None
|
||||
|
||||
|
||||
def get_bitbucket_service_impl():
|
||||
"""Get the BitBucket service implementation with lazy loading."""
|
||||
global _bitbucket_service_impl
|
||||
if _bitbucket_service_impl is None:
|
||||
_bitbucket_service_impl = get_impl(BitBucketService, bitbucket_service_cls)
|
||||
return _bitbucket_service_impl
|
||||
|
||||
|
||||
# For backward compatibility, provide the implementation as a property
|
||||
class _BitBucketServiceImplProxy:
|
||||
"""Proxy class to provide lazy loading for BitBucketServiceImpl."""
|
||||
|
||||
def __getattr__(self, name):
|
||||
impl = get_bitbucket_service_impl()
|
||||
return getattr(impl, name)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
impl = get_bitbucket_service_impl()
|
||||
return impl(*args, **kwargs)
|
||||
|
||||
|
||||
BitBucketServiceImpl: type[BitBucketService] = _BitBucketServiceImplProxy() # type: ignore[assignment]
|
||||
|
||||
@@ -75,4 +75,30 @@ github_service_cls = os.environ.get(
|
||||
'OPENHANDS_GITHUB_SERVICE_CLS',
|
||||
'openhands.integrations.github.github_service.GitHubService',
|
||||
)
|
||||
GithubServiceImpl = get_impl(GitHubService, github_service_cls)
|
||||
|
||||
# Lazy loading to avoid circular imports
|
||||
_github_service_impl = None
|
||||
|
||||
|
||||
def get_github_service_impl():
|
||||
"""Get the GitHub service implementation with lazy loading."""
|
||||
global _github_service_impl
|
||||
if _github_service_impl is None:
|
||||
_github_service_impl = get_impl(GitHubService, github_service_cls)
|
||||
return _github_service_impl
|
||||
|
||||
|
||||
# For backward compatibility, provide the implementation as a property
|
||||
class _GitHubServiceImplProxy:
|
||||
"""Proxy class to provide lazy loading for GithubServiceImpl."""
|
||||
|
||||
def __getattr__(self, name):
|
||||
impl = get_github_service_impl()
|
||||
return getattr(impl, name)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
impl = get_github_service_impl()
|
||||
return impl(*args, **kwargs)
|
||||
|
||||
|
||||
GithubServiceImpl: type[GitHubService] = _GitHubServiceImplProxy() # type: ignore[assignment]
|
||||
|
||||
@@ -79,4 +79,30 @@ gitlab_service_cls = os.environ.get(
|
||||
'OPENHANDS_GITLAB_SERVICE_CLS',
|
||||
'openhands.integrations.gitlab.gitlab_service.GitLabService',
|
||||
)
|
||||
GitLabServiceImpl = get_impl(GitLabService, gitlab_service_cls)
|
||||
|
||||
# Lazy loading to avoid circular imports
|
||||
_gitlab_service_impl = None
|
||||
|
||||
|
||||
def get_gitlab_service_impl():
|
||||
"""Get the GitLab service implementation with lazy loading."""
|
||||
global _gitlab_service_impl
|
||||
if _gitlab_service_impl is None:
|
||||
_gitlab_service_impl = get_impl(GitLabService, gitlab_service_cls)
|
||||
return _gitlab_service_impl
|
||||
|
||||
|
||||
# For backward compatibility, provide the implementation as a property
|
||||
class _GitLabServiceImplProxy:
|
||||
"""Proxy class to provide lazy loading for GitLabServiceImpl."""
|
||||
|
||||
def __getattr__(self, name):
|
||||
impl = get_gitlab_service_impl()
|
||||
return getattr(impl, name)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
impl = get_gitlab_service_impl()
|
||||
return impl(*args, **kwargs)
|
||||
|
||||
|
||||
GitLabServiceImpl: type[GitLabService] = _GitLabServiceImplProxy() # type: ignore[assignment]
|
||||
|
||||
@@ -22,7 +22,7 @@ from openhands.events.action import (
|
||||
)
|
||||
from openhands.events.action.mcp import MCPAction
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
from openhands.events.event import Event, RecallType
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import (
|
||||
AgentCondensationObservation,
|
||||
AgentDelegateObservation,
|
||||
@@ -44,6 +44,7 @@ from openhands.events.observation.agent import (
|
||||
from openhands.events.observation.error import ErrorObservation
|
||||
from openhands.events.observation.mcp import MCPObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.events.recall_type import RecallType
|
||||
from openhands.events.serialization.event import truncate_content
|
||||
from openhands.utils.prompt import (
|
||||
ConversationInstructions,
|
||||
|
||||
@@ -9,12 +9,13 @@ import openhands
|
||||
from openhands.core.config.mcp_config import MCPConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.agent import RecallAction
|
||||
from openhands.events.event import Event, EventSource, RecallType
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.observation.agent import (
|
||||
MicroagentKnowledge,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.recall_type import RecallType
|
||||
from openhands.events.stream import EventStream, EventStreamSubscriber
|
||||
from openhands.microagent import (
|
||||
BaseMicroagent,
|
||||
|
||||
@@ -159,7 +159,7 @@ class AgentSession:
|
||||
await provider_handler.set_event_stream_secrets(self.event_stream)
|
||||
|
||||
if custom_secrets:
|
||||
custom_secrets_handler.set_event_stream_secrets(self.event_stream)
|
||||
self.event_stream.set_secrets(custom_secrets_handler.get_env_vars())
|
||||
|
||||
self.memory = await self._create_memory(
|
||||
selected_repository=selected_repository,
|
||||
|
||||
@@ -13,7 +13,6 @@ from pydantic import (
|
||||
)
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.integrations.provider import (
|
||||
CUSTOM_SECRETS_TYPE,
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
@@ -144,14 +143,6 @@ class Secrets(BaseModel):
|
||||
|
||||
return new_data
|
||||
|
||||
def set_event_stream_secrets(self, event_stream: EventStream) -> None:
|
||||
"""This ensures that provider tokens and custom secrets masked from the event stream
|
||||
Args:
|
||||
event_stream: Agent session's event stream
|
||||
"""
|
||||
secrets = self.get_env_vars()
|
||||
event_stream.set_secrets(secrets)
|
||||
|
||||
def get_env_vars(self) -> dict[str, str]:
|
||||
secret_store = self.model_dump(context={'expose_secrets': True})
|
||||
custom_secrets = secret_store.get('custom_secrets', {})
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
@@ -51,7 +49,7 @@ class Settings(BaseModel):
|
||||
email_verified: bool | None = None
|
||||
git_user_name: str | None = None
|
||||
git_user_email: str | None = None
|
||||
v1_enabled: bool | None = Field(default=bool(os.getenv('V1_ENABLED') == '1'))
|
||||
v1_enabled: bool = True
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
|
||||
14025
poetry.lock
generated
14025
poetry.lock
generated
File diff suppressed because one or more lines are too long
@@ -233,7 +233,7 @@ libtmux = ">=0.46.2"
|
||||
pygithub = "^2.5.0"
|
||||
joblib = "*"
|
||||
openhands-aci = "0.3.2"
|
||||
python-socketio = "^5.11.4"
|
||||
python-socketio = "5.13.0"
|
||||
sse-starlette = "^3.0.2"
|
||||
psutil = "*"
|
||||
python-json-logger = "^3.2.1"
|
||||
|
||||
@@ -199,9 +199,8 @@ class TestJwtService:
|
||||
|
||||
# Check that standard JWT claims are added
|
||||
assert 'iat' in decrypted_payload
|
||||
assert 'exp' in decrypted_payload
|
||||
assert 'exp' not in decrypted_payload
|
||||
assert isinstance(decrypted_payload['iat'], int) # JWE uses timestamp integers
|
||||
assert isinstance(decrypted_payload['exp'], int)
|
||||
|
||||
def test_jwe_token_round_trip_specific_key(self, jwt_service):
|
||||
"""Test JWE token creation and decryption with specific key."""
|
||||
@@ -420,7 +419,7 @@ class TestJwtService:
|
||||
|
||||
# Should still have standard claims
|
||||
assert 'iat' in jwe_decrypted
|
||||
assert 'exp' in jwe_decrypted
|
||||
assert 'exp' not in jwe_decrypted
|
||||
|
||||
def test_unicode_and_special_characters(self, jwt_service):
|
||||
"""Test JWS and JWE with unicode and special characters."""
|
||||
|
||||
@@ -27,6 +27,8 @@ from openhands.sdk.llm import MetricsSnapshot
|
||||
from openhands.sdk.llm.utils.metrics import TokenUsage
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
# Note: org_id column exists but foreign key constraint is not enforced in tests
|
||||
|
||||
# Note: MetricsSnapshot from SDK is not available in test environment
|
||||
# We'll use None for metrics field in tests since it's optional
|
||||
|
||||
@@ -106,7 +108,7 @@ def multiple_conversation_infos() -> list[AppConversationInfo]:
|
||||
return [
|
||||
AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user_123',
|
||||
created_by_user_id=None,
|
||||
sandbox_id=f'sandbox_{i}',
|
||||
selected_repository=f'https://github.com/test/repo{i}',
|
||||
selected_branch='main',
|
||||
@@ -151,10 +153,6 @@ class TestSQLAppConversationInfoService:
|
||||
# Verify the retrieved info matches the original
|
||||
assert retrieved_info is not None
|
||||
assert retrieved_info.id == sample_conversation_info.id
|
||||
assert (
|
||||
retrieved_info.created_by_user_id
|
||||
== sample_conversation_info.created_by_user_id
|
||||
)
|
||||
assert retrieved_info.sandbox_id == sample_conversation_info.sandbox_id
|
||||
assert (
|
||||
retrieved_info.selected_repository
|
||||
@@ -206,7 +204,6 @@ class TestSQLAppConversationInfoService:
|
||||
# Verify all fields
|
||||
assert retrieved_info is not None
|
||||
assert retrieved_info.id == original_info.id
|
||||
assert retrieved_info.created_by_user_id == original_info.created_by_user_id
|
||||
assert retrieved_info.sandbox_id == original_info.sandbox_id
|
||||
assert retrieved_info.selected_repository == original_info.selected_repository
|
||||
assert retrieved_info.selected_branch == original_info.selected_branch
|
||||
@@ -235,7 +232,6 @@ class TestSQLAppConversationInfoService:
|
||||
# Verify required fields
|
||||
assert retrieved_info is not None
|
||||
assert retrieved_info.id == minimal_info.id
|
||||
assert retrieved_info.created_by_user_id == minimal_info.created_by_user_id
|
||||
assert retrieved_info.sandbox_id == minimal_info.sandbox_id
|
||||
|
||||
# Verify optional fields are None or default values
|
||||
@@ -486,58 +482,6 @@ class TestSQLAppConversationInfoService:
|
||||
count = await service.count_app_conversation_info(title__contains='nonexistent')
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_isolation(
|
||||
self,
|
||||
async_session: AsyncSession,
|
||||
multiple_conversation_infos: list[AppConversationInfo],
|
||||
):
|
||||
"""Test that user isolation works correctly."""
|
||||
# Create services for different users
|
||||
user1_service = SQLAppConversationInfoService(
|
||||
db_session=async_session, user_context=SpecifyUserContext(user_id='user1')
|
||||
)
|
||||
user2_service = SQLAppConversationInfoService(
|
||||
db_session=async_session, user_context=SpecifyUserContext(user_id='user2')
|
||||
)
|
||||
|
||||
# Create conversations for different users
|
||||
user1_info = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='user1',
|
||||
sandbox_id='sandbox_user1',
|
||||
title='User 1 Conversation',
|
||||
)
|
||||
|
||||
user2_info = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='user2',
|
||||
sandbox_id='sandbox_user2',
|
||||
title='User 2 Conversation',
|
||||
)
|
||||
|
||||
# Save conversations
|
||||
await user1_service.save_app_conversation_info(user1_info)
|
||||
await user2_service.save_app_conversation_info(user2_info)
|
||||
|
||||
# User 1 should only see their conversation
|
||||
user1_page = await user1_service.search_app_conversation_info()
|
||||
assert len(user1_page.items) == 1
|
||||
assert user1_page.items[0].created_by_user_id == 'user1'
|
||||
|
||||
# User 2 should only see their conversation
|
||||
user2_page = await user2_service.search_app_conversation_info()
|
||||
assert len(user2_page.items) == 1
|
||||
assert user2_page.items[0].created_by_user_id == 'user2'
|
||||
|
||||
# User 1 should not be able to get user 2's conversation
|
||||
user2_from_user1 = await user1_service.get_app_conversation_info(user2_info.id)
|
||||
assert user2_from_user1 is None
|
||||
|
||||
# User 2 should not be able to get user 1's conversation
|
||||
user1_from_user2 = await user2_service.get_app_conversation_info(user1_info.id)
|
||||
assert user1_from_user2 is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_conversation_info(
|
||||
self,
|
||||
@@ -567,10 +511,6 @@ class TestSQLAppConversationInfoService:
|
||||
assert retrieved_info.pr_number == [789]
|
||||
|
||||
# Verify other fields remain unchanged
|
||||
assert (
|
||||
retrieved_info.created_by_user_id
|
||||
== sample_conversation_info.created_by_user_id
|
||||
)
|
||||
assert retrieved_info.sandbox_id == sample_conversation_info.sandbox_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user