mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
fix(backend): add organization filtering to V1 conversation queries (#12923)
This commit is contained in:
@@ -22,11 +22,63 @@ from openhands.app_server.app_conversation.app_conversation_models import (
|
|||||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||||
SQLAppConversationInfoService,
|
SQLAppConversationInfoService,
|
||||||
)
|
)
|
||||||
|
from openhands.app_server.errors import AuthError
|
||||||
from openhands.app_server.services.injector import InjectorState
|
from openhands.app_server.services.injector import InjectorState
|
||||||
|
|
||||||
|
|
||||||
class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||||
"""Extended SQLAppConversationInfoService with user-based filtering and SAAS metadata handling."""
|
"""Extended SQLAppConversationInfoService with user and organization-based filtering and SAAS metadata handling."""
|
||||||
|
|
||||||
|
async def _get_current_user(self) -> User | None:
|
||||||
|
"""Get the current user using the existing db_session.
|
||||||
|
|
||||||
|
Uses self.db_session to avoid opening a separate database session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User object or None if no user_id is available
|
||||||
|
"""
|
||||||
|
user_id_str = await self.user_context.get_user_id()
|
||||||
|
if not user_id_str:
|
||||||
|
return None
|
||||||
|
|
||||||
|
user_id_uuid = UUID(user_id_str)
|
||||||
|
result = await self.db_session.execute(
|
||||||
|
select(User).where(User.id == user_id_uuid)
|
||||||
|
)
|
||||||
|
return result.scalars().first()
|
||||||
|
|
||||||
|
async def _apply_user_and_org_filter(self, query):
|
||||||
|
"""Apply user_id and org_id filters to ensure conversation isolation.
|
||||||
|
|
||||||
|
Filters conversations by:
|
||||||
|
- user_id: Only show conversations belonging to the current user
|
||||||
|
- org_id: Only show conversations belonging to the user's current organization
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: SQLAlchemy query to apply filters to
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Query with user and organization filters applied
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AuthError: If no user_id is available (secure default: deny access)
|
||||||
|
"""
|
||||||
|
user_id_str = await self.user_context.get_user_id()
|
||||||
|
if not user_id_str:
|
||||||
|
# Secure default: no user means no access, not "show everything"
|
||||||
|
raise AuthError('User authentication required')
|
||||||
|
|
||||||
|
user_id_uuid = UUID(user_id_str)
|
||||||
|
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||||
|
|
||||||
|
# Filter by organization ID to ensure conversations are isolated per organization
|
||||||
|
user = await self._get_current_user()
|
||||||
|
if user and user.current_org_id is not None:
|
||||||
|
query = query.where(
|
||||||
|
StoredConversationMetadataSaas.org_id == user.current_org_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return query
|
||||||
|
|
||||||
async def _secure_select(self):
|
async def _secure_select(self):
|
||||||
query = (
|
query = (
|
||||||
@@ -38,13 +90,7 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
|||||||
)
|
)
|
||||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||||
)
|
)
|
||||||
|
return await self._apply_user_and_org_filter(query)
|
||||||
user_id_str = await self.user_context.get_user_id()
|
|
||||||
if user_id_str:
|
|
||||||
user_id_uuid = UUID(user_id_str)
|
|
||||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
|
||||||
|
|
||||||
return query
|
|
||||||
|
|
||||||
async def _secure_select_with_saas_metadata(self):
|
async def _secure_select_with_saas_metadata(self):
|
||||||
"""Select query that includes SAAS metadata for retrieving user_id."""
|
"""Select query that includes SAAS metadata for retrieving user_id."""
|
||||||
@@ -57,13 +103,7 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
|||||||
)
|
)
|
||||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||||
)
|
)
|
||||||
|
return await self._apply_user_and_org_filter(query)
|
||||||
user_id_str = await self.user_context.get_user_id()
|
|
||||||
if user_id_str:
|
|
||||||
user_id_uuid = UUID(user_id_str)
|
|
||||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
|
||||||
|
|
||||||
return query
|
|
||||||
|
|
||||||
async def search_app_conversation_info(
|
async def search_app_conversation_info(
|
||||||
self,
|
self,
|
||||||
@@ -155,21 +195,16 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
|||||||
"""Count conversations matching the given filters with SAAS metadata."""
|
"""Count conversations matching the given filters with SAAS metadata."""
|
||||||
query = (
|
query = (
|
||||||
select(func.count(StoredConversationMetadata.conversation_id))
|
select(func.count(StoredConversationMetadata.conversation_id))
|
||||||
.select_from(
|
.join(
|
||||||
StoredConversationMetadata.join(
|
StoredConversationMetadataSaas,
|
||||||
StoredConversationMetadataSaas,
|
StoredConversationMetadata.conversation_id
|
||||||
StoredConversationMetadata.conversation_id
|
== StoredConversationMetadataSaas.conversation_id,
|
||||||
== StoredConversationMetadataSaas.conversation_id,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply user filtering
|
# Apply user and organization filtering
|
||||||
user_id_str = await self.user_context.get_user_id()
|
query = await self._apply_user_and_org_filter(query)
|
||||||
if user_id_str:
|
|
||||||
user_id_uuid = UUID(user_id_str)
|
|
||||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
|
||||||
|
|
||||||
query = self._apply_filters_with_saas_metadata(
|
query = self._apply_filters_with_saas_metadata(
|
||||||
query=query,
|
query=query,
|
||||||
|
|||||||
@@ -10,8 +10,12 @@ from unittest.mock import AsyncMock, MagicMock
|
|||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.pool import StaticPool
|
from sqlalchemy.pool import StaticPool
|
||||||
|
from storage.base import Base
|
||||||
|
from storage.org import Org
|
||||||
|
from storage.user import User
|
||||||
|
|
||||||
from enterprise.server.utils.saas_app_conversation_info_injector import (
|
from enterprise.server.utils.saas_app_conversation_info_injector import (
|
||||||
SaasSQLAppConversationInfoService,
|
SaasSQLAppConversationInfoService,
|
||||||
@@ -20,10 +24,15 @@ from openhands.app_server.app_conversation.app_conversation_models import (
|
|||||||
AppConversationInfo,
|
AppConversationInfo,
|
||||||
)
|
)
|
||||||
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
|
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
|
||||||
from openhands.app_server.utils.sql_utils import Base
|
|
||||||
from openhands.integrations.service_types import ProviderType
|
from openhands.integrations.service_types import ProviderType
|
||||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||||
|
|
||||||
|
# Test UUIDs
|
||||||
|
USER1_ID = UUID('a1111111-1111-1111-1111-111111111111')
|
||||||
|
USER2_ID = UUID('b2222222-2222-2222-2222-222222222222')
|
||||||
|
ORG1_ID = UUID('c1111111-1111-1111-1111-111111111111')
|
||||||
|
ORG2_ID = UUID('d2222222-2222-2222-2222-222222222222')
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def async_engine():
|
async def async_engine():
|
||||||
@@ -55,6 +64,41 @@ async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
|||||||
yield db_session
|
yield db_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def async_session_with_users(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""Create an async session with pre-populated Org and User rows for testing."""
|
||||||
|
async_session_maker = async_sessionmaker(
|
||||||
|
async_engine, class_=AsyncSession, expire_on_commit=False
|
||||||
|
)
|
||||||
|
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
|
# Insert Orgs first (required for User foreign key)
|
||||||
|
org1 = Org(
|
||||||
|
id=ORG1_ID,
|
||||||
|
name='test-org-1',
|
||||||
|
enable_default_condenser=True,
|
||||||
|
enable_proactive_conversation_starters=True,
|
||||||
|
)
|
||||||
|
org2 = Org(
|
||||||
|
id=ORG2_ID,
|
||||||
|
name='test-org-2',
|
||||||
|
enable_default_condenser=True,
|
||||||
|
enable_proactive_conversation_starters=True,
|
||||||
|
)
|
||||||
|
db_session.add(org1)
|
||||||
|
db_session.add(org2)
|
||||||
|
await db_session.flush()
|
||||||
|
|
||||||
|
# Insert Users
|
||||||
|
user1 = User(id=USER1_ID, current_org_id=ORG1_ID)
|
||||||
|
user2 = User(id=USER2_ID, current_org_id=ORG2_ID)
|
||||||
|
db_session.add(user1)
|
||||||
|
db_session.add(user2)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def service(async_session) -> SaasSQLAppConversationInfoService:
|
def service(async_session) -> SaasSQLAppConversationInfoService:
|
||||||
"""Create a SQLAppConversationInfoService instance for testing."""
|
"""Create a SQLAppConversationInfoService instance for testing."""
|
||||||
@@ -178,15 +222,26 @@ class TestSaasSQLAppConversationInfoService:
|
|||||||
assert user1_id != user2_id
|
assert user1_id != user2_id
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_secure_select_includes_user_filtering(
|
async def test_secure_select_includes_user_and_org_filtering(
|
||||||
self,
|
self,
|
||||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
async_session_with_users: AsyncSession,
|
||||||
):
|
):
|
||||||
"""Test that _secure_select method includes user filtering."""
|
"""Test that _secure_select method includes both user_id and org_id filtering."""
|
||||||
# This test verifies that the _secure_select method exists and can be called
|
service = SaasSQLAppConversationInfoService(
|
||||||
# The actual SQL generation is tested implicitly through integration
|
db_session=async_session_with_users,
|
||||||
query = await saas_service_user1._secure_select()
|
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||||
assert query is not None
|
)
|
||||||
|
|
||||||
|
query = await service._secure_select()
|
||||||
|
|
||||||
|
# Convert query to string to verify filters are present
|
||||||
|
query_str = str(query.compile(compile_kwargs={'literal_binds': True}))
|
||||||
|
|
||||||
|
# Verify user_id filter is present
|
||||||
|
assert str(USER1_ID) in query_str or str(USER1_ID).replace('-', '') in query_str
|
||||||
|
|
||||||
|
# Verify org_id filter is present (user1 is in org1)
|
||||||
|
assert str(ORG1_ID) in query_str or str(ORG1_ID).replace('-', '') in query_str
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_to_info_with_user_id_functionality(
|
async def test_to_info_with_user_id_functionality(
|
||||||
@@ -241,100 +296,32 @@ class TestSaasSQLAppConversationInfoService:
|
|||||||
assert result.sandbox_id == 'test-sandbox'
|
assert result.sandbox_id == 'test-sandbox'
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_user_isolation(
|
async def test_user_isolation_different_users(
|
||||||
self,
|
self,
|
||||||
async_session: AsyncSession,
|
async_session_with_users: AsyncSession,
|
||||||
multiple_conversation_infos: list[AppConversationInfo],
|
|
||||||
):
|
):
|
||||||
"""Test that user isolation works correctly."""
|
"""Test that different users cannot see each other's conversations."""
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
from storage.user import User
|
|
||||||
|
|
||||||
# Mock the database session execute method to return mock users
|
|
||||||
# This mock intercepts User queries and returns a mock user object
|
|
||||||
# with user_id and org_id the same as the user_id_uuid from the query
|
|
||||||
original_execute = async_session.execute
|
|
||||||
|
|
||||||
async def mock_execute(query):
|
|
||||||
query_str = str(query)
|
|
||||||
|
|
||||||
# Check if this is a User query
|
|
||||||
if '"user"' in query_str.lower() and '"user".id' in query_str.lower():
|
|
||||||
# Extract the UUID from the query parameters
|
|
||||||
# The query will have bound parameters, we need to get the UUID value
|
|
||||||
if hasattr(query, 'compile'):
|
|
||||||
try:
|
|
||||||
compiled = query.compile(compile_kwargs={'literal_binds': True})
|
|
||||||
query_with_params = str(compiled)
|
|
||||||
|
|
||||||
# Extract UUID from the query string
|
|
||||||
import re
|
|
||||||
|
|
||||||
# Try both formats: with dashes and without dashes
|
|
||||||
uuid_pattern_with_dashes = r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}'
|
|
||||||
uuid_pattern_without_dashes = r'[a-f0-9]{32}'
|
|
||||||
|
|
||||||
uuid_match = re.search(
|
|
||||||
uuid_pattern_with_dashes, query_with_params
|
|
||||||
)
|
|
||||||
if not uuid_match:
|
|
||||||
uuid_match = re.search(
|
|
||||||
uuid_pattern_without_dashes, query_with_params
|
|
||||||
)
|
|
||||||
|
|
||||||
if uuid_match:
|
|
||||||
user_id_str = uuid_match.group(0)
|
|
||||||
# If the UUID doesn't have dashes, add them
|
|
||||||
if len(user_id_str) == 32 and '-' not in user_id_str:
|
|
||||||
# Convert from 'a1111111111111111111111111111111' to 'a1111111-1111-1111-1111-111111111111'
|
|
||||||
user_id_str = f'{user_id_str[:8]}-{user_id_str[8:12]}-{user_id_str[12:16]}-{user_id_str[16:20]}-{user_id_str[20:]}'
|
|
||||||
user_id_uuid = UUID(user_id_str)
|
|
||||||
|
|
||||||
# Create a mock user with user_id and org_id the same as user_id_uuid
|
|
||||||
mock_user = MagicMock(spec=User)
|
|
||||||
mock_user.id = user_id_uuid
|
|
||||||
mock_user.current_org_id = user_id_uuid
|
|
||||||
|
|
||||||
# Create a mock result
|
|
||||||
mock_result = MagicMock()
|
|
||||||
mock_result.scalar_one_or_none.return_value = mock_user
|
|
||||||
return mock_result
|
|
||||||
except Exception:
|
|
||||||
# If there's any error in parsing, fall back to original execute
|
|
||||||
pass
|
|
||||||
|
|
||||||
# For all other queries, use the original execute method
|
|
||||||
return await original_execute(query)
|
|
||||||
|
|
||||||
# Apply the mock
|
|
||||||
async_session.execute = mock_execute
|
|
||||||
|
|
||||||
# Create services for different users
|
# Create services for different users
|
||||||
user1_service = SaasSQLAppConversationInfoService(
|
user1_service = SaasSQLAppConversationInfoService(
|
||||||
db_session=async_session,
|
db_session=async_session_with_users,
|
||||||
user_context=SpecifyUserContext(
|
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||||
user_id='a1111111-1111-1111-1111-111111111111'
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
user2_service = SaasSQLAppConversationInfoService(
|
user2_service = SaasSQLAppConversationInfoService(
|
||||||
db_session=async_session,
|
db_session=async_session_with_users,
|
||||||
user_context=SpecifyUserContext(
|
user_context=SpecifyUserContext(user_id=str(USER2_ID)),
|
||||||
user_id='b2222222-2222-2222-2222-222222222222'
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create conversations for different users
|
# Create conversations for different users
|
||||||
user1_info = AppConversationInfo(
|
user1_info = AppConversationInfo(
|
||||||
id=uuid4(),
|
id=uuid4(),
|
||||||
created_by_user_id='a1111111-1111-1111-1111-111111111111',
|
created_by_user_id=str(USER1_ID),
|
||||||
sandbox_id='sandbox_user1',
|
sandbox_id='sandbox_user1',
|
||||||
title='User 1 Conversation',
|
title='User 1 Conversation',
|
||||||
)
|
)
|
||||||
|
|
||||||
user2_info = AppConversationInfo(
|
user2_info = AppConversationInfo(
|
||||||
id=uuid4(),
|
id=uuid4(),
|
||||||
created_by_user_id='b2222222-2222-2222-2222-222222222222',
|
created_by_user_id=str(USER2_ID),
|
||||||
sandbox_id='sandbox_user2',
|
sandbox_id='sandbox_user2',
|
||||||
title='User 2 Conversation',
|
title='User 2 Conversation',
|
||||||
)
|
)
|
||||||
@@ -346,18 +333,12 @@ class TestSaasSQLAppConversationInfoService:
|
|||||||
# User 1 should only see their conversation
|
# User 1 should only see their conversation
|
||||||
user1_page = await user1_service.search_app_conversation_info()
|
user1_page = await user1_service.search_app_conversation_info()
|
||||||
assert len(user1_page.items) == 1
|
assert len(user1_page.items) == 1
|
||||||
assert (
|
assert user1_page.items[0].created_by_user_id == str(USER1_ID)
|
||||||
user1_page.items[0].created_by_user_id
|
|
||||||
== 'a1111111-1111-1111-1111-111111111111'
|
|
||||||
)
|
|
||||||
|
|
||||||
# User 2 should only see their conversation
|
# User 2 should only see their conversation
|
||||||
user2_page = await user2_service.search_app_conversation_info()
|
user2_page = await user2_service.search_app_conversation_info()
|
||||||
assert len(user2_page.items) == 1
|
assert len(user2_page.items) == 1
|
||||||
assert (
|
assert user2_page.items[0].created_by_user_id == str(USER2_ID)
|
||||||
user2_page.items[0].created_by_user_id
|
|
||||||
== 'b2222222-2222-2222-2222-222222222222'
|
|
||||||
)
|
|
||||||
|
|
||||||
# User 1 should not be able to get user 2's conversation
|
# User 1 should not be able to get user 2's conversation
|
||||||
user2_from_user1 = await user1_service.get_app_conversation_info(user2_info.id)
|
user2_from_user1 = await user1_service.get_app_conversation_info(user2_info.id)
|
||||||
@@ -366,3 +347,142 @@ class TestSaasSQLAppConversationInfoService:
|
|||||||
# User 2 should not be able to get user 1's conversation
|
# User 2 should not be able to get user 1's conversation
|
||||||
user1_from_user2 = await user2_service.get_app_conversation_info(user1_info.id)
|
user1_from_user2 = await user2_service.get_app_conversation_info(user1_info.id)
|
||||||
assert user1_from_user2 is None
|
assert user1_from_user2 is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_same_user_org_switching_isolation(
|
||||||
|
self,
|
||||||
|
async_session_with_users: AsyncSession,
|
||||||
|
):
|
||||||
|
"""Test that the same user switching orgs cannot see conversations from other orgs.
|
||||||
|
|
||||||
|
This tests the actual bug scenario: a user creates a conversation in org1,
|
||||||
|
then switches to org2, and should NOT see org1's conversations.
|
||||||
|
"""
|
||||||
|
# Create service for user1 in org1
|
||||||
|
user1_service_org1 = SaasSQLAppConversationInfoService(
|
||||||
|
db_session=async_session_with_users,
|
||||||
|
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a conversation while user is in org1
|
||||||
|
conv_in_org1 = AppConversationInfo(
|
||||||
|
id=uuid4(),
|
||||||
|
created_by_user_id=str(USER1_ID),
|
||||||
|
sandbox_id='sandbox_org1',
|
||||||
|
title='Conversation in Org 1',
|
||||||
|
)
|
||||||
|
await user1_service_org1.save_app_conversation_info(conv_in_org1)
|
||||||
|
|
||||||
|
# Verify user can see the conversation in org1
|
||||||
|
page_in_org1 = await user1_service_org1.search_app_conversation_info()
|
||||||
|
assert len(page_in_org1.items) == 1
|
||||||
|
assert page_in_org1.items[0].title == 'Conversation in Org 1'
|
||||||
|
|
||||||
|
# Simulate user switching to org2 by updating current_org_id using ORM
|
||||||
|
result = await async_session_with_users.execute(
|
||||||
|
select(User).where(User.id == USER1_ID)
|
||||||
|
)
|
||||||
|
user_to_update = result.scalars().first()
|
||||||
|
user_to_update.current_org_id = ORG2_ID
|
||||||
|
await async_session_with_users.commit()
|
||||||
|
# Clear SQLAlchemy's identity map cache to simulate a new request
|
||||||
|
async_session_with_users.expire_all()
|
||||||
|
|
||||||
|
# Create new service instance (simulating a new request after org switch)
|
||||||
|
user1_service_org2 = SaasSQLAppConversationInfoService(
|
||||||
|
db_session=async_session_with_users,
|
||||||
|
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||||
|
)
|
||||||
|
|
||||||
|
# User should NOT see org1's conversations after switching to org2
|
||||||
|
page_in_org2 = await user1_service_org2.search_app_conversation_info()
|
||||||
|
assert (
|
||||||
|
len(page_in_org2.items) == 0
|
||||||
|
), 'User should not see conversations from org1 after switching to org2'
|
||||||
|
|
||||||
|
# User should not be able to get the specific conversation from org1
|
||||||
|
conv_from_org2 = await user1_service_org2.get_app_conversation_info(
|
||||||
|
conv_in_org1.id
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
conv_from_org2 is None
|
||||||
|
), 'User should not be able to access org1 conversation from org2'
|
||||||
|
|
||||||
|
# Now create a conversation in org2
|
||||||
|
conv_in_org2 = AppConversationInfo(
|
||||||
|
id=uuid4(),
|
||||||
|
created_by_user_id=str(USER1_ID),
|
||||||
|
sandbox_id='sandbox_org2',
|
||||||
|
title='Conversation in Org 2',
|
||||||
|
)
|
||||||
|
await user1_service_org2.save_app_conversation_info(conv_in_org2)
|
||||||
|
|
||||||
|
# User should only see org2's conversation
|
||||||
|
page_in_org2_after = await user1_service_org2.search_app_conversation_info()
|
||||||
|
assert len(page_in_org2_after.items) == 1
|
||||||
|
assert page_in_org2_after.items[0].title == 'Conversation in Org 2'
|
||||||
|
|
||||||
|
# Switch back to org1 and verify isolation works both ways
|
||||||
|
result = await async_session_with_users.execute(
|
||||||
|
select(User).where(User.id == USER1_ID)
|
||||||
|
)
|
||||||
|
user_to_update = result.scalars().first()
|
||||||
|
user_to_update.current_org_id = ORG1_ID
|
||||||
|
await async_session_with_users.commit()
|
||||||
|
async_session_with_users.expire_all()
|
||||||
|
|
||||||
|
user1_service_back_to_org1 = SaasSQLAppConversationInfoService(
|
||||||
|
db_session=async_session_with_users,
|
||||||
|
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||||
|
)
|
||||||
|
|
||||||
|
# User should only see org1's conversation now
|
||||||
|
page_back_in_org1 = (
|
||||||
|
await user1_service_back_to_org1.search_app_conversation_info()
|
||||||
|
)
|
||||||
|
assert len(page_back_in_org1.items) == 1
|
||||||
|
assert page_back_in_org1.items[0].title == 'Conversation in Org 1'
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_count_respects_org_isolation(
|
||||||
|
self,
|
||||||
|
async_session_with_users: AsyncSession,
|
||||||
|
):
|
||||||
|
"""Test that count_app_conversation_info respects org isolation."""
|
||||||
|
# Create service for user1 in org1
|
||||||
|
user1_service = SaasSQLAppConversationInfoService(
|
||||||
|
db_session=async_session_with_users,
|
||||||
|
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create conversations in org1
|
||||||
|
for i in range(3):
|
||||||
|
conv = AppConversationInfo(
|
||||||
|
id=uuid4(),
|
||||||
|
created_by_user_id=str(USER1_ID),
|
||||||
|
sandbox_id=f'sandbox_org1_{i}',
|
||||||
|
title=f'Org1 Conversation {i}',
|
||||||
|
)
|
||||||
|
await user1_service.save_app_conversation_info(conv)
|
||||||
|
|
||||||
|
# Count should be 3
|
||||||
|
count_org1 = await user1_service.count_app_conversation_info()
|
||||||
|
assert count_org1 == 3
|
||||||
|
|
||||||
|
# Switch to org2 using ORM
|
||||||
|
result = await async_session_with_users.execute(
|
||||||
|
select(User).where(User.id == USER1_ID)
|
||||||
|
)
|
||||||
|
user_to_update = result.scalars().first()
|
||||||
|
user_to_update.current_org_id = ORG2_ID
|
||||||
|
await async_session_with_users.commit()
|
||||||
|
async_session_with_users.expire_all()
|
||||||
|
|
||||||
|
user1_service_org2 = SaasSQLAppConversationInfoService(
|
||||||
|
db_session=async_session_with_users,
|
||||||
|
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count should be 0 in org2
|
||||||
|
count_org2 = await user1_service_org2.count_app_conversation_info()
|
||||||
|
assert count_org2 == 0
|
||||||
|
|||||||
Reference in New Issue
Block a user