mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
fix(backend): add organization filtering to V1 conversation queries (#12923)
This commit is contained in:
@@ -10,8 +10,12 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.user import User
|
||||
|
||||
from enterprise.server.utils.saas_app_conversation_info_injector import (
|
||||
SaasSQLAppConversationInfoService,
|
||||
@@ -20,10 +24,15 @@ from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
)
|
||||
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
|
||||
from openhands.app_server.utils.sql_utils import Base
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
# Test UUIDs
|
||||
USER1_ID = UUID('a1111111-1111-1111-1111-111111111111')
|
||||
USER2_ID = UUID('b2222222-2222-2222-2222-222222222222')
|
||||
ORG1_ID = UUID('c1111111-1111-1111-1111-111111111111')
|
||||
ORG2_ID = UUID('d2222222-2222-2222-2222-222222222222')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
@@ -55,6 +64,41 @@ async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
yield db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session_with_users(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create an async session with pre-populated Org and User rows for testing."""
|
||||
async_session_maker = async_sessionmaker(
|
||||
async_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with async_session_maker() as db_session:
|
||||
# Insert Orgs first (required for User foreign key)
|
||||
org1 = Org(
|
||||
id=ORG1_ID,
|
||||
name='test-org-1',
|
||||
enable_default_condenser=True,
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
org2 = Org(
|
||||
id=ORG2_ID,
|
||||
name='test-org-2',
|
||||
enable_default_condenser=True,
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
db_session.add(org1)
|
||||
db_session.add(org2)
|
||||
await db_session.flush()
|
||||
|
||||
# Insert Users
|
||||
user1 = User(id=USER1_ID, current_org_id=ORG1_ID)
|
||||
user2 = User(id=USER2_ID, current_org_id=ORG2_ID)
|
||||
db_session.add(user1)
|
||||
db_session.add(user2)
|
||||
await db_session.commit()
|
||||
|
||||
yield db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(async_session) -> SaasSQLAppConversationInfoService:
|
||||
"""Create a SQLAppConversationInfoService instance for testing."""
|
||||
@@ -178,15 +222,26 @@ class TestSaasSQLAppConversationInfoService:
|
||||
assert user1_id != user2_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_secure_select_includes_user_filtering(
|
||||
async def test_secure_select_includes_user_and_org_filtering(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that _secure_select method includes user filtering."""
|
||||
# This test verifies that the _secure_select method exists and can be called
|
||||
# The actual SQL generation is tested implicitly through integration
|
||||
query = await saas_service_user1._secure_select()
|
||||
assert query is not None
|
||||
"""Test that _secure_select method includes both user_id and org_id filtering."""
|
||||
service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
query = await service._secure_select()
|
||||
|
||||
# Convert query to string to verify filters are present
|
||||
query_str = str(query.compile(compile_kwargs={'literal_binds': True}))
|
||||
|
||||
# Verify user_id filter is present
|
||||
assert str(USER1_ID) in query_str or str(USER1_ID).replace('-', '') in query_str
|
||||
|
||||
# Verify org_id filter is present (user1 is in org1)
|
||||
assert str(ORG1_ID) in query_str or str(ORG1_ID).replace('-', '') in query_str
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_to_info_with_user_id_functionality(
|
||||
@@ -241,100 +296,32 @@ class TestSaasSQLAppConversationInfoService:
|
||||
assert result.sandbox_id == 'test-sandbox'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_isolation(
|
||||
async def test_user_isolation_different_users(
|
||||
self,
|
||||
async_session: AsyncSession,
|
||||
multiple_conversation_infos: list[AppConversationInfo],
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that user isolation works correctly."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from storage.user import User
|
||||
|
||||
# Mock the database session execute method to return mock users
|
||||
# This mock intercepts User queries and returns a mock user object
|
||||
# with user_id and org_id the same as the user_id_uuid from the query
|
||||
original_execute = async_session.execute
|
||||
|
||||
async def mock_execute(query):
|
||||
query_str = str(query)
|
||||
|
||||
# Check if this is a User query
|
||||
if '"user"' in query_str.lower() and '"user".id' in query_str.lower():
|
||||
# Extract the UUID from the query parameters
|
||||
# The query will have bound parameters, we need to get the UUID value
|
||||
if hasattr(query, 'compile'):
|
||||
try:
|
||||
compiled = query.compile(compile_kwargs={'literal_binds': True})
|
||||
query_with_params = str(compiled)
|
||||
|
||||
# Extract UUID from the query string
|
||||
import re
|
||||
|
||||
# Try both formats: with dashes and without dashes
|
||||
uuid_pattern_with_dashes = r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}'
|
||||
uuid_pattern_without_dashes = r'[a-f0-9]{32}'
|
||||
|
||||
uuid_match = re.search(
|
||||
uuid_pattern_with_dashes, query_with_params
|
||||
)
|
||||
if not uuid_match:
|
||||
uuid_match = re.search(
|
||||
uuid_pattern_without_dashes, query_with_params
|
||||
)
|
||||
|
||||
if uuid_match:
|
||||
user_id_str = uuid_match.group(0)
|
||||
# If the UUID doesn't have dashes, add them
|
||||
if len(user_id_str) == 32 and '-' not in user_id_str:
|
||||
# Convert from 'a1111111111111111111111111111111' to 'a1111111-1111-1111-1111-111111111111'
|
||||
user_id_str = f'{user_id_str[:8]}-{user_id_str[8:12]}-{user_id_str[12:16]}-{user_id_str[16:20]}-{user_id_str[20:]}'
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
|
||||
# Create a mock user with user_id and org_id the same as user_id_uuid
|
||||
mock_user = MagicMock(spec=User)
|
||||
mock_user.id = user_id_uuid
|
||||
mock_user.current_org_id = user_id_uuid
|
||||
|
||||
# Create a mock result
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = mock_user
|
||||
return mock_result
|
||||
except Exception:
|
||||
# If there's any error in parsing, fall back to original execute
|
||||
pass
|
||||
|
||||
# For all other queries, use the original execute method
|
||||
return await original_execute(query)
|
||||
|
||||
# Apply the mock
|
||||
async_session.execute = mock_execute
|
||||
|
||||
"""Test that different users cannot see each other's conversations."""
|
||||
# Create services for different users
|
||||
user1_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session,
|
||||
user_context=SpecifyUserContext(
|
||||
user_id='a1111111-1111-1111-1111-111111111111'
|
||||
),
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
user2_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session,
|
||||
user_context=SpecifyUserContext(
|
||||
user_id='b2222222-2222-2222-2222-222222222222'
|
||||
),
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER2_ID)),
|
||||
)
|
||||
|
||||
# Create conversations for different users
|
||||
user1_info = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='a1111111-1111-1111-1111-111111111111',
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_user1',
|
||||
title='User 1 Conversation',
|
||||
)
|
||||
|
||||
user2_info = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='b2222222-2222-2222-2222-222222222222',
|
||||
created_by_user_id=str(USER2_ID),
|
||||
sandbox_id='sandbox_user2',
|
||||
title='User 2 Conversation',
|
||||
)
|
||||
@@ -346,18 +333,12 @@ class TestSaasSQLAppConversationInfoService:
|
||||
# User 1 should only see their conversation
|
||||
user1_page = await user1_service.search_app_conversation_info()
|
||||
assert len(user1_page.items) == 1
|
||||
assert (
|
||||
user1_page.items[0].created_by_user_id
|
||||
== 'a1111111-1111-1111-1111-111111111111'
|
||||
)
|
||||
assert user1_page.items[0].created_by_user_id == str(USER1_ID)
|
||||
|
||||
# User 2 should only see their conversation
|
||||
user2_page = await user2_service.search_app_conversation_info()
|
||||
assert len(user2_page.items) == 1
|
||||
assert (
|
||||
user2_page.items[0].created_by_user_id
|
||||
== 'b2222222-2222-2222-2222-222222222222'
|
||||
)
|
||||
assert user2_page.items[0].created_by_user_id == str(USER2_ID)
|
||||
|
||||
# User 1 should not be able to get user 2's conversation
|
||||
user2_from_user1 = await user1_service.get_app_conversation_info(user2_info.id)
|
||||
@@ -366,3 +347,142 @@ class TestSaasSQLAppConversationInfoService:
|
||||
# User 2 should not be able to get user 1's conversation
|
||||
user1_from_user2 = await user2_service.get_app_conversation_info(user1_info.id)
|
||||
assert user1_from_user2 is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_user_org_switching_isolation(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that the same user switching orgs cannot see conversations from other orgs.
|
||||
|
||||
This tests the actual bug scenario: a user creates a conversation in org1,
|
||||
then switches to org2, and should NOT see org1's conversations.
|
||||
"""
|
||||
# Create service for user1 in org1
|
||||
user1_service_org1 = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# Create a conversation while user is in org1
|
||||
conv_in_org1 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_org1',
|
||||
title='Conversation in Org 1',
|
||||
)
|
||||
await user1_service_org1.save_app_conversation_info(conv_in_org1)
|
||||
|
||||
# Verify user can see the conversation in org1
|
||||
page_in_org1 = await user1_service_org1.search_app_conversation_info()
|
||||
assert len(page_in_org1.items) == 1
|
||||
assert page_in_org1.items[0].title == 'Conversation in Org 1'
|
||||
|
||||
# Simulate user switching to org2 by updating current_org_id using ORM
|
||||
result = await async_session_with_users.execute(
|
||||
select(User).where(User.id == USER1_ID)
|
||||
)
|
||||
user_to_update = result.scalars().first()
|
||||
user_to_update.current_org_id = ORG2_ID
|
||||
await async_session_with_users.commit()
|
||||
# Clear SQLAlchemy's identity map cache to simulate a new request
|
||||
async_session_with_users.expire_all()
|
||||
|
||||
# Create new service instance (simulating a new request after org switch)
|
||||
user1_service_org2 = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# User should NOT see org1's conversations after switching to org2
|
||||
page_in_org2 = await user1_service_org2.search_app_conversation_info()
|
||||
assert (
|
||||
len(page_in_org2.items) == 0
|
||||
), 'User should not see conversations from org1 after switching to org2'
|
||||
|
||||
# User should not be able to get the specific conversation from org1
|
||||
conv_from_org2 = await user1_service_org2.get_app_conversation_info(
|
||||
conv_in_org1.id
|
||||
)
|
||||
assert (
|
||||
conv_from_org2 is None
|
||||
), 'User should not be able to access org1 conversation from org2'
|
||||
|
||||
# Now create a conversation in org2
|
||||
conv_in_org2 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_org2',
|
||||
title='Conversation in Org 2',
|
||||
)
|
||||
await user1_service_org2.save_app_conversation_info(conv_in_org2)
|
||||
|
||||
# User should only see org2's conversation
|
||||
page_in_org2_after = await user1_service_org2.search_app_conversation_info()
|
||||
assert len(page_in_org2_after.items) == 1
|
||||
assert page_in_org2_after.items[0].title == 'Conversation in Org 2'
|
||||
|
||||
# Switch back to org1 and verify isolation works both ways
|
||||
result = await async_session_with_users.execute(
|
||||
select(User).where(User.id == USER1_ID)
|
||||
)
|
||||
user_to_update = result.scalars().first()
|
||||
user_to_update.current_org_id = ORG1_ID
|
||||
await async_session_with_users.commit()
|
||||
async_session_with_users.expire_all()
|
||||
|
||||
user1_service_back_to_org1 = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# User should only see org1's conversation now
|
||||
page_back_in_org1 = (
|
||||
await user1_service_back_to_org1.search_app_conversation_info()
|
||||
)
|
||||
assert len(page_back_in_org1.items) == 1
|
||||
assert page_back_in_org1.items[0].title == 'Conversation in Org 1'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_respects_org_isolation(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that count_app_conversation_info respects org isolation."""
|
||||
# Create service for user1 in org1
|
||||
user1_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# Create conversations in org1
|
||||
for i in range(3):
|
||||
conv = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id=f'sandbox_org1_{i}',
|
||||
title=f'Org1 Conversation {i}',
|
||||
)
|
||||
await user1_service.save_app_conversation_info(conv)
|
||||
|
||||
# Count should be 3
|
||||
count_org1 = await user1_service.count_app_conversation_info()
|
||||
assert count_org1 == 3
|
||||
|
||||
# Switch to org2 using ORM
|
||||
result = await async_session_with_users.execute(
|
||||
select(User).where(User.id == USER1_ID)
|
||||
)
|
||||
user_to_update = result.scalars().first()
|
||||
user_to_update.current_org_id = ORG2_ID
|
||||
await async_session_with_users.commit()
|
||||
async_session_with_users.expire_all()
|
||||
|
||||
user1_service_org2 = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# Count should be 0 in org2
|
||||
count_org2 = await user1_service_org2.count_app_conversation_info()
|
||||
assert count_org2 == 0
|
||||
|
||||
Reference in New Issue
Block a user