# Mirror of https://github.com/OpenHands/OpenHands.git
# (synced 2026-03-22 13:47:19 +08:00; 265 lines, 9.1 KiB, Python)
"""
Store class for managing organization-member relationships.
"""
|
|
|
|
from typing import Optional
|
|
from uuid import UUID
|
|
|
|
from server.routes.org_models import OrgMemberLLMSettings
|
|
from sqlalchemy import func, select, update
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import joinedload
|
|
from storage.database import a_session_maker
|
|
from storage.encrypt_utils import encrypt_value
|
|
from storage.org_member import OrgMember
|
|
from storage.user import User
|
|
from storage.user_settings import UserSettings
|
|
|
|
from openhands.storage.data_models.settings import Settings
|
|
|
|
|
|
class OrgMemberStore:
    """Store for managing organization-member relationships.

    All methods are static. Every coroutine except
    ``update_all_members_llm_settings_async`` opens its own session via
    ``a_session_maker``; that one runs on a caller-supplied session so it can
    take part in the caller's transaction.
    """

    # Escape character for LIKE patterns built from user input (the same
    # character SQLAlchemy's ``autoescape`` option uses).
    _LIKE_ESCAPE = '/'

    @staticmethod
    def _email_contains_filter(email_filter: str):
        """Return a case-insensitive *literal* substring filter on ``User.email``.

        ``%`` and ``_`` are LIKE wildcards, and ``_`` is common in email
        addresses, so both are escaped (the escape character itself first, so
        user input cannot form a new escape sequence). The search then matches
        the text exactly as typed instead of being interpreted as a pattern.
        """
        esc = OrgMemberStore._LIKE_ESCAPE
        escaped = (
            email_filter.replace(esc, esc + esc)
            .replace('%', esc + '%')
            .replace('_', esc + '_')
        )
        return User.email.ilike(f'%{escaped}%', escape=esc)

    @staticmethod
    def _kwargs_from_columns(source: object) -> dict:
        """Collect values from ``source`` for every OrgMember column it provides.

        Column names are normalized by stripping leading underscores (e.g. a
        ``_llm_api_key`` column is looked up as ``llm_api_key``). Only
        attributes that ``source`` actually has are included, in table column
        order.
        """
        kwargs = {}
        for column in OrgMember.__table__.columns:
            name = column.name.lstrip('_')
            # Skip names that are empty after stripping or absent on source.
            if name and hasattr(source, name):
                kwargs[name] = getattr(source, name)
        return kwargs

    @staticmethod
    async def add_user_to_org(
        org_id: UUID,
        user_id: UUID,
        role_id: int,
        llm_api_key: str,
        status: Optional[str] = None,
        llm_model: Optional[str] = None,
        llm_base_url: Optional[str] = None,
        max_iterations: Optional[int] = None,
    ) -> OrgMember:
        """Add a user to an organization with a specific role.

        Args:
            org_id: Organization UUID.
            user_id: User UUID.
            role_id: Role to assign within the organization.
            llm_api_key: API key handed to the ``OrgMember`` constructor
                (any encryption is up to the model; not visible here).
            status: Optional membership status.
            llm_model: Optional per-member LLM model.
            llm_base_url: Optional per-member LLM base URL.
            max_iterations: Optional per-member iteration limit.

        Returns:
            The newly persisted OrgMember, refreshed after commit.
        """
        async with a_session_maker() as session:
            org_member = OrgMember(
                org_id=org_id,
                user_id=user_id,
                role_id=role_id,
                llm_api_key=llm_api_key,
                status=status,
                llm_model=llm_model,
                llm_base_url=llm_base_url,
                max_iterations=max_iterations,
            )
            session.add(org_member)
            await session.commit()
            # Reload DB-populated fields before the session closes.
            await session.refresh(org_member)
            return org_member

    @staticmethod
    async def get_org_member(org_id: UUID, user_id: UUID) -> Optional[OrgMember]:
        """Return the membership of ``user_id`` in ``org_id``, or None."""
        async with a_session_maker() as session:
            result = await session.execute(
                select(OrgMember).filter(
                    OrgMember.org_id == org_id, OrgMember.user_id == user_id
                )
            )
            return result.scalars().first()

    @staticmethod
    async def get_org_member_for_current_org(user_id: UUID) -> Optional[OrgMember]:
        """Get the org member for a user's current organization.

        Args:
            user_id: The user's UUID.

        Returns:
            The OrgMember whose ``org_id`` equals the user's
            ``User.current_org_id``, or None if there is no such membership.
        """
        async with a_session_maker() as session:
            result = await session.execute(
                select(OrgMember)
                .join(User, User.id == OrgMember.user_id)
                .filter(
                    User.id == user_id,
                    OrgMember.org_id == User.current_org_id,
                )
            )
            return result.scalars().first()

    @staticmethod
    async def get_user_orgs(user_id: UUID) -> list[OrgMember]:
        """Return every OrgMember row (one per organization) for a user."""
        async with a_session_maker() as session:
            result = await session.execute(
                select(OrgMember).filter(OrgMember.user_id == user_id)
            )
            return list(result.scalars().all())

    @staticmethod
    async def get_org_members(org_id: UUID) -> list[OrgMember]:
        """Return every OrgMember row belonging to an organization."""
        async with a_session_maker() as session:
            result = await session.execute(
                select(OrgMember).filter(OrgMember.org_id == org_id)
            )
            return list(result.scalars().all())

    @staticmethod
    async def update_org_member(org_member: OrgMember) -> None:
        """Persist the state of ``org_member`` via ``session.merge`` and commit.

        ``merge`` copies the given instance's state onto a session-bound
        instance; the object passed in is not itself attached to the session.
        """
        async with a_session_maker() as session:
            await session.merge(org_member)
            await session.commit()

    @staticmethod
    async def update_user_role_in_org(
        org_id: UUID, user_id: UUID, role_id: int, status: Optional[str] = None
    ) -> Optional[OrgMember]:
        """Update a user's role (and optionally status) in an organization.

        Args:
            org_id: Organization UUID.
            user_id: User UUID.
            role_id: New role to assign.
            status: New status; left unchanged when None.

        Returns:
            The refreshed OrgMember, or None if the user is not a member.
        """
        async with a_session_maker() as session:
            result = await session.execute(
                select(OrgMember).filter(
                    OrgMember.org_id == org_id, OrgMember.user_id == user_id
                )
            )
            org_member = result.scalars().first()

            if not org_member:
                return None

            org_member.role_id = role_id
            if status is not None:
                org_member.status = status

            await session.commit()
            await session.refresh(org_member)
            return org_member

    @staticmethod
    async def remove_user_from_org(org_id: UUID, user_id: UUID) -> bool:
        """Delete a user's membership in an organization.

        Returns:
            True if a membership was found and deleted, False otherwise.
        """
        async with a_session_maker() as session:
            result = await session.execute(
                select(OrgMember).filter(
                    OrgMember.org_id == org_id, OrgMember.user_id == user_id
                )
            )
            org_member = result.scalars().first()

            if not org_member:
                return False

            await session.delete(org_member)
            await session.commit()
            return True

    @staticmethod
    def get_kwargs_from_settings(settings: Settings) -> dict:
        """Extract OrgMember constructor kwargs from a ``Settings`` object.

        See ``_kwargs_from_columns`` for how column names are matched.
        """
        return OrgMemberStore._kwargs_from_columns(settings)

    @staticmethod
    def get_kwargs_from_user_settings(user_settings: UserSettings) -> dict:
        """Extract OrgMember constructor kwargs from a ``UserSettings`` row.

        See ``_kwargs_from_columns`` for how column names are matched.
        """
        return OrgMemberStore._kwargs_from_columns(user_settings)

    @staticmethod
    async def get_org_members_count(
        org_id: UUID,
        email_filter: str | None = None,
    ) -> int:
        """Get total count of organization members, optionally filtered by email.

        Args:
            org_id: Organization UUID.
            email_filter: Optional case-insensitive partial email match. The
                text is matched literally (``%`` and ``_`` are not wildcards).

        Returns:
            Total count of matching members (0 if none).
        """
        async with a_session_maker() as session:
            query = select(func.count(OrgMember.user_id)).filter(
                OrgMember.org_id == org_id
            )

            if email_filter:
                # The User join is only needed to filter on email.
                query = query.join(User, User.id == OrgMember.user_id).filter(
                    OrgMemberStore._email_contains_filter(email_filter)
                )

            result = await session.execute(query)
            return result.scalar() or 0

    @staticmethod
    async def get_org_members_paginated(
        org_id: UUID,
        offset: int = 0,
        limit: int = 100,
        email_filter: str | None = None,
    ) -> tuple[list[OrgMember], bool]:
        """Get paginated list of organization members with user and role info.

        Args:
            org_id: Organization UUID.
            offset: Number of records to skip.
            limit: Maximum number of records to return.
            email_filter: Optional case-insensitive partial email match. The
                text is matched literally (``%`` and ``_`` are not wildcards).

        Returns:
            Tuple of (members_list, has_more) where has_more indicates if there
            are more results beyond this page.
        """
        async with a_session_maker() as session:
            # ``user`` and ``role`` are eager-loaded for the caller; the
            # explicit User join is what the email filter applies to.
            query = (
                select(OrgMember)
                .options(joinedload(OrgMember.user), joinedload(OrgMember.role))
                .join(User, User.id == OrgMember.user_id)
                .filter(OrgMember.org_id == org_id)
            )

            if email_filter:
                query = query.filter(
                    OrgMemberStore._email_contains_filter(email_filter)
                )

            # Order by user_id for a stable page order, and fetch one extra
            # row so we can tell whether another page exists.
            query = query.order_by(OrgMember.user_id).offset(offset).limit(limit + 1)

            result = await session.execute(query)
            members = list(result.unique().scalars().all())

            has_more = len(members) > limit
            if has_more:
                # Drop the look-ahead row.
                members = members[:limit]

            return members, has_more

    @staticmethod
    async def update_all_members_llm_settings_async(
        session: AsyncSession,
        org_id: UUID,
        member_settings: OrgMemberLLMSettings,
    ) -> None:
        """Update LLM settings for all members of an organization.

        Issues a single bulk UPDATE on the caller's session and does NOT
        commit; the caller owns the transaction. Fields that are None in
        ``member_settings`` are skipped (``exclude_none``), so this method
        cannot reset a column to NULL. If no field is set, no statement runs.

        Args:
            session: Database session (passed from caller for transaction).
            org_id: Organization ID.
            member_settings: Typed LLM settings to apply to all members.
        """
        values = member_settings.model_dump(exclude_none=True)

        # The API key is stored encrypted in the ``_llm_api_key`` column.
        # NOTE(review): assumes model_dump yields a value encrypt_value accepts
        # (e.g. a plain str rather than a SecretStr) — confirm against
        # OrgMemberLLMSettings.
        if 'llm_api_key' in values:
            raw_key = values.pop('llm_api_key')
            values['_llm_api_key'] = encrypt_value(raw_key)

        if values:
            stmt = update(OrgMember).where(OrgMember.org_id == org_id).values(**values)
            await session.execute(stmt)
|