diff --git a/enterprise/integrations/jira/jira_payload.py b/enterprise/integrations/jira/jira_payload.py index 86032ecd6b..9b0d75ecd0 100644 --- a/enterprise/integrations/jira/jira_payload.py +++ b/enterprise/integrations/jira/jira_payload.py @@ -212,8 +212,6 @@ class JiraPayloadParser: missing.append('issue.id') if not issue_key: missing.append('issue.key') - if not user_email: - missing.append('user.emailAddress') if not display_name: missing.append('user.displayName') if not account_id: diff --git a/enterprise/server/routes/integration/jira.py b/enterprise/server/routes/integration/jira.py index 060baa7224..3096734f5d 100644 --- a/enterprise/server/routes/integration/jira.py +++ b/enterprise/server/routes/integration/jira.py @@ -308,10 +308,11 @@ async def jira_events( logger.info(f'Processing new Jira webhook event: {signature}') redis_client.setex(key, 300, '1') - # Process the webhook + # Process the webhook in background after returning response. + # Note: For async functions, BackgroundTasks runs them in the same event loop + # (not a thread pool), so asyncpg connections work correctly. message_payload = {'payload': payload} message = Message(source=SourceType.JIRA, message=message_payload) - background_tasks.add_task(jira_manager.receive_message, message) return JSONResponse({'success': True}) diff --git a/enterprise/server/utils/conversation_callback_utils.py b/enterprise/server/utils/conversation_callback_utils.py index 9224e686bf..5593448616 100644 --- a/enterprise/server/utils/conversation_callback_utils.py +++ b/enterprise/server/utils/conversation_callback_utils.py @@ -4,13 +4,14 @@ import pickle from datetime import datetime from server.logger import logger +from sqlalchemy import and_, select from storage.conversation_callback import ( CallbackStatus, ConversationCallback, ConversationCallbackProcessor, ) from storage.conversation_work import ConversationWork -from storage.database import session_maker +from storage.database import a_session_maker, session_maker from storage.stored_conversation_metadata import StoredConversationMetadata from openhands.core.config import load_openhands_config @@ -79,15 +80,16 @@ async def invoke_conversation_callbacks( conversation_id: The conversation ID to process callbacks for observation: The AgentStateChangedObservation that triggered the callback """ - with session_maker() as session: - callbacks = ( - session.query(ConversationCallback) - .filter( - ConversationCallback.conversation_id == conversation_id, - ConversationCallback.status == CallbackStatus.ACTIVE, + async with a_session_maker() as session: + result = await session.execute( + select(ConversationCallback).filter( + and_( + ConversationCallback.conversation_id == conversation_id, + ConversationCallback.status == CallbackStatus.ACTIVE, + ) ) - .all() ) + callbacks = result.scalars().all() for callback in callbacks: try: @@ -115,7 +117,7 @@ async def invoke_conversation_callbacks( callback.status = CallbackStatus.ERROR callback.updated_at = datetime.now() - session.commit() + await session.commit() def update_conversation_metadata(conversation_id: str, content: dict): diff --git a/enterprise/storage/jira_integration_store.py b/enterprise/storage/jira_integration_store.py index db353732bb..e1680c94d1 100644 --- a/enterprise/storage/jira_integration_store.py +++ b/enterprise/storage/jira_integration_store.py @@ -3,7 +3,8 @@ from __future__ import annotations from dataclasses import dataclass from typing import Optional -from storage.database import session_maker +from sqlalchemy import and_, select +from storage.database import a_session_maker from storage.jira_conversation import JiraConversation from storage.jira_user import JiraUser from storage.jira_workspace import JiraWorkspace @@ -35,10 +36,10 @@ class JiraIntegrationStore: status=status, ) - with session_maker() as session: + async with a_session_maker() as session: session.add(workspace) - session.commit() - session.refresh(workspace) + await session.commit() + await session.refresh(workspace) logger.info(f'[Jira] Created workspace {workspace.name}') return workspace @@ -53,11 +54,12 @@ class JiraIntegrationStore: status: Optional[str] = None, ) -> JiraWorkspace: """Update an existing Jira workspace with encrypted sensitive data.""" - with session_maker() as session: + async with a_session_maker() as session: # Find existing workspace by ID - workspace = ( - session.query(JiraWorkspace).filter(JiraWorkspace.id == id).first() + result = await session.execute( + select(JiraWorkspace).filter(JiraWorkspace.id == id) ) + workspace = result.scalars().first() if not workspace: raise ValueError(f'Workspace with ID "{id}" not found') @@ -77,11 +79,11 @@ class JiraIntegrationStore: if status is not None: workspace.status = status - session.commit() - session.refresh(workspace) + await session.commit() + await session.refresh(workspace) - logger.info(f'[Jira] Updated workspace {workspace.name}') - return workspace + logger.info(f'[Jira] Updated workspace {workspace.name}') + return workspace async def create_workspace_link( self, @@ -99,10 +101,10 @@ class JiraIntegrationStore: status=status, ) - with session_maker() as session: + async with a_session_maker() as session: session.add(jira_user) - session.commit() - session.refresh(jira_user) + await session.commit() + await session.refresh(jira_user) logger.info( f'[Jira] Created user {jira_user.id} for workspace {jira_workspace_id}' @@ -111,75 +113,77 @@ class JiraIntegrationStore: async def get_workspace_by_id(self, workspace_id: int) -> Optional[JiraWorkspace]: """Retrieve workspace by ID.""" - with session_maker() as session: - return ( - session.query(JiraWorkspace) - .filter(JiraWorkspace.id == workspace_id) - .first() + async with a_session_maker() as session: + result = await session.execute( + select(JiraWorkspace).filter(JiraWorkspace.id == workspace_id) ) + return result.scalars().first() async def get_workspace_by_name(self, workspace_name: str) -> JiraWorkspace | None: """Retrieve workspace by name.""" - with session_maker() as session: - return ( - session.query(JiraWorkspace) - .filter(JiraWorkspace.name == workspace_name.lower()) - .first() + async with a_session_maker() as session: + result = await session.execute( + select(JiraWorkspace).filter( + JiraWorkspace.name == workspace_name.lower() + ) ) + return result.scalars().first() async def get_user_by_active_workspace( self, keycloak_user_id: str ) -> Optional[JiraUser]: """Get Jira user by Keycloak user ID.""" - with session_maker() as session: - return ( - session.query(JiraUser) - .filter( - JiraUser.keycloak_user_id == keycloak_user_id, - JiraUser.status == 'active', + async with a_session_maker() as session: + result = await session.execute( + select(JiraUser).filter( + and_( + JiraUser.keycloak_user_id == keycloak_user_id, + JiraUser.status == 'active', + ) ) - .first() ) + return result.scalars().first() async def get_user_by_keycloak_id_and_workspace( self, keycloak_user_id: str, jira_workspace_id: int ) -> Optional[JiraUser]: """Get Jira user by Keycloak user ID and workspace ID.""" - with session_maker() as session: - return ( - session.query(JiraUser) - .filter( - JiraUser.keycloak_user_id == keycloak_user_id, - JiraUser.jira_workspace_id == jira_workspace_id, + async with a_session_maker() as session: + result = await session.execute( + select(JiraUser).filter( + and_( + JiraUser.keycloak_user_id == keycloak_user_id, + JiraUser.jira_workspace_id == jira_workspace_id, + ) ) - .first() ) + return result.scalars().first() async def get_active_user( self, jira_user_id: str, jira_workspace_id: int ) -> Optional[JiraUser]: """Get Jira user by Keycloak user ID and workspace ID.""" - with session_maker() as session: - return ( - session.query(JiraUser) - .filter( - JiraUser.jira_user_id == jira_user_id, - JiraUser.jira_workspace_id == jira_workspace_id, - JiraUser.status == 'active', + async with a_session_maker() as session: + result = await session.execute( + select(JiraUser).filter( + and_( + JiraUser.jira_user_id == jira_user_id, + JiraUser.jira_workspace_id == jira_workspace_id, + JiraUser.status == 'active', + ) ) - .first() ) + return result.scalars().first() async def update_user_integration_status( self, keycloak_user_id: str, status: str ) -> JiraUser: """Update Jira user integration status.""" - with session_maker() as session: - jira_user = ( - session.query(JiraUser) - .filter(JiraUser.keycloak_user_id == keycloak_user_id) - .first() + async with a_session_maker() as session: + result = await session.execute( + select(JiraUser).filter(JiraUser.keycloak_user_id == keycloak_user_id) ) + jira_user = result.scalars().first() if not jira_user: raise ValueError( @@ -187,60 +191,61 @@ class JiraIntegrationStore: ) jira_user.status = status - session.commit() - session.refresh(jira_user) + await session.commit() + await session.refresh(jira_user) logger.info(f'[Jira] Updated user {keycloak_user_id} status to {status}') return jira_user async def deactivate_workspace(self, workspace_id: int): """Deactivate the workspace and all user links for a given workspace.""" - with session_maker() as session: - users = ( - session.query(JiraUser) - .filter( - JiraUser.jira_workspace_id == workspace_id, - JiraUser.status == 'active', + async with a_session_maker() as session: + result = await session.execute( + select(JiraUser).filter( + and_( + JiraUser.jira_workspace_id == workspace_id, + JiraUser.status == 'active', + ) ) - .all() ) + users = result.scalars().all() for user in users: user.status = 'inactive' session.add(user) - workspace = ( - session.query(JiraWorkspace) - .filter(JiraWorkspace.id == workspace_id) - .first() + result = await session.execute( + select(JiraWorkspace).filter(JiraWorkspace.id == workspace_id) ) + workspace = result.scalars().first() if workspace: workspace.status = 'inactive' session.add(workspace) - session.commit() + await session.commit() logger.info(f'[Jira] Deactivated all user links for workspace {workspace_id}') async def create_conversation(self, jira_conversation: JiraConversation) -> None: """Create a new Jira conversation record.""" - with session_maker() as session: + async with a_session_maker() as session: session.add(jira_conversation) - session.commit() + await session.commit() async def get_user_conversations_by_issue_id( self, issue_id: str, jira_user_id: int ) -> JiraConversation | None: """Get a Jira conversation by issue ID and jira user ID.""" - with session_maker() as session: - return ( - session.query(JiraConversation) - .filter( - JiraConversation.issue_id == issue_id, - JiraConversation.jira_user_id == jira_user_id, + async with a_session_maker() as session: + result = await session.execute( + select(JiraConversation).filter( + and_( + JiraConversation.issue_id == issue_id, + JiraConversation.jira_user_id == jira_user_id, + ) ) - .first() ) + return result.scalars().first() @classmethod def get_instance(cls) -> JiraIntegrationStore: diff --git a/enterprise/storage/saas_conversation_store.py b/enterprise/storage/saas_conversation_store.py index 7e43dfe471..7eb4919ad9 100644 --- a/enterprise/storage/saas_conversation_store.py +++ b/enterprise/storage/saas_conversation_store.py @@ -234,6 +234,10 @@ class SaasConversationStore(ConversationStore): cls, config: OpenHandsConfig, user_id: str | None ) -> ConversationStore: # user_id should not be None in SaaS, should we raise? - user = await UserStore.get_user_by_id_async(user_id) + # Use sync version because this method can be called from call_async_from_sync + # (e.g., from _create_conversation_update_callback in standalone_conversation_manager.py) + # which creates a new event loop. Using async DB operations in that context would + # cause asyncpg connection errors since connections are tied to the original event loop. + user = UserStore.get_user_by_id(user_id) org_id = user.current_org_id if user else None return SaasConversationStore(str(user_id), org_id, session_maker) diff --git a/enterprise/tests/unit/integrations/jira/test_jira_payload.py b/enterprise/tests/unit/integrations/jira/test_jira_payload.py new file mode 100644 index 0000000000..2f1a8deb7b --- /dev/null +++ b/enterprise/tests/unit/integrations/jira/test_jira_payload.py @@ -0,0 +1,268 @@ +""" +Tests for JiraPayloadParser. + +These tests verify the parsing behavior of Jira webhook payloads, +including the handling of optional fields like user_email which +may not be present in webhook payloads from Jira. +""" + +import pytest +from integrations.jira.jira_payload import ( + JiraEventType, + JiraPayloadError, + JiraPayloadParser, + JiraPayloadSkipped, + JiraPayloadSuccess, +) + + +@pytest.fixture +def parser(): + """Create a JiraPayloadParser with standard OpenHands labels.""" + return JiraPayloadParser(oh_label='openhands', inline_oh_label='@openhands') + + +@pytest.fixture +def valid_label_payload(): + """Create a valid jira:issue_updated payload with OpenHands label.""" + return { + 'webhookEvent': 'jira:issue_updated', + 'issue': { + 'id': '12345', + 'key': 'TEST-123', + 'self': 'https://test.atlassian.net/rest/api/2/issue/12345', + }, + 'user': { + 'displayName': 'Test User', + 'accountId': 'account-123', + 'emailAddress': 'test@example.com', + }, + 'changelog': { + 'items': [ + { + 'field': 'labels', + 'toString': 'openhands', + } + ] + }, + } + + +@pytest.fixture +def valid_comment_payload(): + """Create a valid comment_created payload with OpenHands mention.""" + return { + 'webhookEvent': 'comment_created', + 'issue': { + 'id': '12345', + 'key': 'TEST-123', + 'self': 'https://test.atlassian.net/rest/api/2/issue/12345', + }, + 'comment': { + 'body': '@openhands please fix this bug', + 'author': { + 'displayName': 'Test User', + 'accountId': 'account-123', + 'emailAddress': 'test@example.com', + }, + }, + } + + +class TestUserEmailOptional: + """Tests verifying user_email is optional in webhook payloads. + + Jira webhooks may not include emailAddress in the user data. + The parser should accept payloads without this field. + """ + + def test_label_event_succeeds_without_email_address( + self, parser, valid_label_payload + ): + """Verify label event parsing succeeds when emailAddress is missing.""" + # Arrange - remove emailAddress from user data + del valid_label_payload['user']['emailAddress'] + + # Act + result = parser.parse(valid_label_payload) + + # Assert + assert isinstance(result, JiraPayloadSuccess) + assert result.payload.user_email == '' + assert result.payload.display_name == 'Test User' + assert result.payload.account_id == 'account-123' + + def test_comment_event_succeeds_without_email_address( + self, parser, valid_comment_payload + ): + """Verify comment event parsing succeeds when emailAddress is missing.""" + # Arrange - remove emailAddress from author data + del valid_comment_payload['comment']['author']['emailAddress'] + + # Act + result = parser.parse(valid_comment_payload) + + # Assert + assert isinstance(result, JiraPayloadSuccess) + assert result.payload.user_email == '' + assert result.payload.display_name == 'Test User' + assert result.payload.account_id == 'account-123' + + def test_user_email_preserved_when_present(self, parser, valid_label_payload): + """Verify user_email is captured when emailAddress is present.""" + # Act + result = parser.parse(valid_label_payload) + + # Assert + assert isinstance(result, JiraPayloadSuccess) + assert result.payload.user_email == 'test@example.com' + + +class TestRequiredFieldValidation: + """Tests verifying required fields are still validated.""" + + def test_missing_issue_id_returns_error(self, parser, valid_label_payload): + """Verify parsing fails when issue.id is missing.""" + # Arrange + del valid_label_payload['issue']['id'] + + # Act + result = parser.parse(valid_label_payload) + + # Assert + assert isinstance(result, JiraPayloadError) + assert 'issue.id' in result.error + + def test_missing_issue_key_returns_error(self, parser, valid_label_payload): + """Verify parsing fails when issue.key is missing.""" + # Arrange + del valid_label_payload['issue']['key'] + + # Act + result = parser.parse(valid_label_payload) + + # Assert + assert isinstance(result, JiraPayloadError) + assert 'issue.key' in result.error + + def test_missing_display_name_returns_error(self, parser, valid_label_payload): + """Verify parsing fails when user.displayName is missing.""" + # Arrange + del valid_label_payload['user']['displayName'] + + # Act + result = parser.parse(valid_label_payload) + + # Assert + assert isinstance(result, JiraPayloadError) + assert 'displayName' in result.error + + def test_missing_account_id_returns_error(self, parser, valid_label_payload): + """Verify parsing fails when user.accountId is missing.""" + # Arrange + del valid_label_payload['user']['accountId'] + + # Act + result = parser.parse(valid_label_payload) + + # Assert + assert isinstance(result, JiraPayloadError) + assert 'accountId' in result.error + + def test_missing_issue_self_url_returns_error(self, parser, valid_label_payload): + """Verify parsing fails when issue.self URL is missing.""" + # Arrange + del valid_label_payload['issue']['self'] + + # Act + result = parser.parse(valid_label_payload) + + # Assert + assert isinstance(result, JiraPayloadError) + assert 'workspace_name' in result.error or 'base_api_url' in result.error + + +class TestEventTypeDetection: + """Tests for webhook event type detection.""" + + def test_issue_updated_with_label_returns_labeled_ticket( + self, parser, valid_label_payload + ): + """Verify jira:issue_updated with label is detected as LABELED_TICKET.""" + # Act + result = parser.parse(valid_label_payload) + + # Assert + assert isinstance(result, JiraPayloadSuccess) + assert result.payload.event_type == JiraEventType.LABELED_TICKET + + def test_comment_created_with_mention_returns_comment_mention( + self, parser, valid_comment_payload + ): + """Verify comment_created with mention is detected as COMMENT_MENTION.""" + # Act + result = parser.parse(valid_comment_payload) + + # Assert + assert isinstance(result, JiraPayloadSuccess) + assert result.payload.event_type == JiraEventType.COMMENT_MENTION + + def test_unhandled_event_type_returns_skipped(self, parser): + """Verify unknown event types are skipped.""" + # Arrange + payload = {'webhookEvent': 'jira:issue_deleted'} + + # Act + result = parser.parse(payload) + + # Assert + assert isinstance(result, JiraPayloadSkipped) + assert 'Unhandled' in result.skip_reason + + +class TestLabelFiltering: + """Tests for OpenHands label filtering.""" + + def test_label_event_without_openhands_label_skipped( + self, parser, valid_label_payload + ): + """Verify label events without OpenHands label are skipped.""" + # Arrange - change label to something else + valid_label_payload['changelog']['items'][0]['toString'] = 'other-label' + + # Act + result = parser.parse(valid_label_payload) + + # Assert + assert isinstance(result, JiraPayloadSkipped) + assert 'openhands' in result.skip_reason + + +class TestCommentFiltering: + """Tests for OpenHands comment mention filtering.""" + + def test_comment_without_mention_skipped(self, parser, valid_comment_payload): + """Verify comments without OpenHands mention are skipped.""" + # Arrange - remove mention from comment body + valid_comment_payload['comment']['body'] = 'Please fix this bug' + + # Act + result = parser.parse(valid_comment_payload) + + # Assert + assert isinstance(result, JiraPayloadSkipped) + assert '@openhands' in result.skip_reason + + +class TestWorkspaceExtraction: + """Tests for workspace name extraction from issue URL.""" + + def test_workspace_name_extracted_from_self_url(self, parser, valid_label_payload): + """Verify workspace name is extracted from issue self URL.""" + # Act + result = parser.parse(valid_label_payload) + + # Assert + assert isinstance(result, JiraPayloadSuccess) + assert result.payload.workspace_name == 'test.atlassian.net' + assert result.payload.base_api_url == 'https://test.atlassian.net' diff --git a/enterprise/tests/unit/server/test_conversation_callback_utils.py b/enterprise/tests/unit/server/test_conversation_callback_utils.py index 128f2d82d2..2445b6a7da 100644 --- a/enterprise/tests/unit/server/test_conversation_callback_utils.py +++ b/enterprise/tests/unit/server/test_conversation_callback_utils.py @@ -399,3 +399,135 @@ class TestUpdateActiveWorkingSeconds: assert conversation_work.seconds == 23.0 assert conversation_work.conversation_id == conversation_id assert conversation_work.user_id == user_id + + +class TestInvokeConversationCallbacks: + """Tests for invoke_conversation_callbacks function. + + This function uses async database sessions (a_session_maker) to query + and invoke callbacks for a conversation. + """ + + @pytest.fixture + def mock_observation(self): + """Create a mock AgentStateChangedObservation.""" + + observation = Mock(spec=AgentStateChangedObservation) + observation.agent_state = AgentState.FINISHED + return observation + + @pytest.fixture + def create_mock_async_session(self): + """Factory to create properly mocked async session context manager.""" + from contextlib import asynccontextmanager + from unittest.mock import AsyncMock + + def _create(callbacks_list): + mock_session = Mock() + mock_result = Mock() + mock_result.scalars.return_value.all.return_value = callbacks_list + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.commit = AsyncMock(return_value=None) + + @asynccontextmanager + async def mock_context_manager(): + yield mock_session + + return mock_context_manager, mock_session + + return _create + + @pytest.mark.asyncio + async def test_invoke_callbacks_with_active_callbacks( + self, mock_observation, create_mock_async_session + ): + """Test that active callbacks are invoked successfully.""" + from unittest.mock import AsyncMock + + # Arrange + conversation_id = 'test_conversation_callbacks' + mock_processor = AsyncMock(return_value=None) + + # Create a mock callback + mock_callback = Mock() + mock_callback.id = 1 + mock_callback.processor_type = 'test_processor' + mock_callback.get_processor.return_value = mock_processor + + mock_context_manager, mock_session = create_mock_async_session([mock_callback]) + + # Act + with patch( + 'server.utils.conversation_callback_utils.a_session_maker', + mock_context_manager, + ): + from server.utils.conversation_callback_utils import ( + invoke_conversation_callbacks, + ) + + await invoke_conversation_callbacks(conversation_id, mock_observation) + + # Assert + mock_callback.get_processor.assert_called_once() + mock_processor.assert_called_once_with(mock_callback, mock_observation) + + @pytest.mark.asyncio + async def test_invoke_callbacks_with_no_active_callbacks( + self, mock_observation, create_mock_async_session + ): + """Test behavior when no active callbacks exist.""" + # Arrange + conversation_id = 'test_no_callbacks' + + mock_context_manager, mock_session = create_mock_async_session([]) + + # Act + with patch( + 'server.utils.conversation_callback_utils.a_session_maker', + mock_context_manager, + ): + from server.utils.conversation_callback_utils import ( + invoke_conversation_callbacks, + ) + + await invoke_conversation_callbacks(conversation_id, mock_observation) + + # Assert - should complete without errors + mock_session.commit.assert_called_once() + + @pytest.mark.asyncio + async def test_invoke_callbacks_handles_processor_exception( + self, mock_observation, create_mock_async_session + ): + """Test that processor exceptions are caught and callback status is updated.""" + from unittest.mock import AsyncMock + + # Arrange + conversation_id = 'test_callback_error' + mock_processor = AsyncMock(side_effect=Exception('Processor error')) + + mock_callback = Mock() + mock_callback.id = 1 + mock_callback.processor_type = 'failing_processor' + mock_callback.get_processor.return_value = mock_processor + mock_callback.status = 'active' + + mock_context_manager, mock_session = create_mock_async_session([mock_callback]) + + # Act + with patch( + 'server.utils.conversation_callback_utils.a_session_maker', + mock_context_manager, + ), patch('server.utils.conversation_callback_utils.logger') as mock_logger: + from server.utils.conversation_callback_utils import ( + invoke_conversation_callbacks, + ) + from storage.conversation_callback import CallbackStatus + + await invoke_conversation_callbacks(conversation_id, mock_observation) + + # Assert - callback status should be set to ERROR + assert mock_callback.status == CallbackStatus.ERROR + mock_logger.error.assert_called_once() + error_call = mock_logger.error.call_args + assert error_call[0][0] == 'callback_invocation_failed' diff --git a/enterprise/tests/unit/storage/test_jira_integration_store.py b/enterprise/tests/unit/storage/test_jira_integration_store.py new file mode 100644 index 0000000000..a3420c3f05 --- /dev/null +++ b/enterprise/tests/unit/storage/test_jira_integration_store.py @@ -0,0 +1,232 @@ +""" +Tests for JiraIntegrationStore async methods. + +The store uses async database sessions (a_session_maker) for all operations, +which is critical for avoiding asyncpg event loop issues when called from +FastAPI async endpoints. +""" + +from contextlib import asynccontextmanager +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from storage.jira_integration_store import JiraIntegrationStore +from storage.jira_user import JiraUser +from storage.jira_workspace import JiraWorkspace + + +@pytest.fixture +def store(): + """Create a JiraIntegrationStore instance.""" + return JiraIntegrationStore() + + +@pytest.fixture +def create_mock_async_session(): + """Factory to create properly mocked async session context manager.""" + + def _create(query_result=None, all_results=None): + mock_session = Mock() + mock_result = Mock() + + if all_results is not None: + mock_result.scalars.return_value.all.return_value = all_results + else: + mock_result.scalars.return_value.first.return_value = query_result + + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.add = Mock() + mock_session.commit = AsyncMock() + mock_session.refresh = AsyncMock() + + @asynccontextmanager + async def mock_context_manager(): + yield mock_session + + return mock_context_manager, mock_session + + return _create + + +class TestJiraIntegrationStoreAsyncMethods: + """Tests verifying JiraIntegrationStore methods use async sessions correctly.""" + + @pytest.mark.asyncio + async def test_get_workspace_by_id_returns_workspace( + self, store, create_mock_async_session + ): + """Test get_workspace_by_id returns workspace when found.""" + # Arrange + mock_workspace = Mock(spec=JiraWorkspace) + mock_workspace.id = 1 + mock_workspace.name = 'test-workspace' + + mock_context_manager, mock_session = create_mock_async_session(mock_workspace) + + # Act + with patch( + 'storage.jira_integration_store.a_session_maker', mock_context_manager + ): + result = await store.get_workspace_by_id(1) + + # Assert + assert result == mock_workspace + mock_session.execute.assert_called_once() + + @pytest.mark.asyncio + async def test_get_workspace_by_id_returns_none_when_not_found( + self, store, create_mock_async_session + ): + """Test get_workspace_by_id returns None when workspace not found.""" + # Arrange + mock_context_manager, mock_session = create_mock_async_session(None) + + # Act + with patch( + 'storage.jira_integration_store.a_session_maker', mock_context_manager + ): + result = await store.get_workspace_by_id(999) + + # Assert + assert result is None + + @pytest.mark.asyncio + async def test_get_workspace_by_name_normalizes_to_lowercase( + self, store, create_mock_async_session + ): + """Test get_workspace_by_name converts name to lowercase for query.""" + # Arrange + mock_workspace = Mock(spec=JiraWorkspace) + mock_workspace.name = 'test-workspace' + + mock_context_manager, mock_session = create_mock_async_session(mock_workspace) + + # Act + with patch( + 'storage.jira_integration_store.a_session_maker', mock_context_manager + ): + result = await store.get_workspace_by_name('TEST-WORKSPACE') + + # Assert + assert result == mock_workspace + # Verify the query was executed (filter includes lowercase conversion) + mock_session.execute.assert_called_once() + + @pytest.mark.asyncio + async def test_get_active_user_filters_by_status( + self, store, create_mock_async_session + ): + """Test get_active_user only returns users with active status.""" + # Arrange + mock_user = Mock(spec=JiraUser) + mock_user.jira_user_id = 'jira-123' + mock_user.jira_workspace_id = 1 + mock_user.status = 'active' + + mock_context_manager, mock_session = create_mock_async_session(mock_user) + + # Act + with patch( + 'storage.jira_integration_store.a_session_maker', mock_context_manager + ): + result = await store.get_active_user('jira-123', 1) + + # Assert + assert result == mock_user + mock_session.execute.assert_called_once() + + @pytest.mark.asyncio + async def test_create_workspace_adds_and_commits( + self, store, create_mock_async_session + ): + """Test create_workspace properly adds, commits, and refreshes.""" + # Arrange + mock_context_manager, mock_session = create_mock_async_session(None) + + # Act + with patch( + 'storage.jira_integration_store.a_session_maker', mock_context_manager + ): + await store.create_workspace( + name='TEST-WORKSPACE', + jira_cloud_id='cloud-123', + admin_user_id='admin-user', + encrypted_webhook_secret='encrypted-secret', + svc_acc_email='svc@test.com', + encrypted_svc_acc_api_key='encrypted-key', + status='active', + ) + + # Assert + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + mock_session.refresh.assert_called_once() + + # Verify workspace was created with lowercase name + added_workspace = mock_session.add.call_args[0][0] + assert added_workspace.name == 'test-workspace' + + @pytest.mark.asyncio + async def test_update_user_integration_status_raises_if_not_found( + self, store, create_mock_async_session + ): + """Test update_user_integration_status raises ValueError if user not found.""" + # Arrange + mock_context_manager, mock_session = create_mock_async_session(None) + + # Act & Assert + with patch( + 'storage.jira_integration_store.a_session_maker', mock_context_manager + ): + with pytest.raises(ValueError) as exc_info: + await store.update_user_integration_status('unknown-user', 'inactive') + + assert 'Jira user not found' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_deactivate_workspace_deactivates_all_users( + self, store, create_mock_async_session + ): + """Test deactivate_workspace sets all users and workspace to inactive.""" + # Arrange + mock_user1 = Mock(spec=JiraUser) + mock_user1.status = 'active' + mock_user2 = Mock(spec=JiraUser) + mock_user2.status = 'active' + + mock_workspace = Mock(spec=JiraWorkspace) + mock_workspace.status = 'active' + + mock_session = Mock() + + # First execute returns users, second returns workspace + call_count = [0] + + def execute_side_effect(*args, **kwargs): + result = Mock() + if call_count[0] == 0: + result.scalars.return_value.all.return_value = [mock_user1, mock_user2] + else: + result.scalars.return_value.first.return_value = mock_workspace + call_count[0] += 1 + return result + + mock_session.execute = AsyncMock(side_effect=execute_side_effect) + mock_session.add = Mock() + mock_session.commit = AsyncMock() + + @asynccontextmanager + async def mock_context_manager(): + yield mock_session + + # Act + with patch( + 'storage.jira_integration_store.a_session_maker', mock_context_manager + ): + await store.deactivate_workspace(1) + + # Assert + assert mock_user1.status == 'inactive' + assert mock_user2.status == 'inactive' + assert mock_workspace.status == 'inactive' + mock_session.commit.assert_called_once() diff --git a/enterprise/tests/unit/test_saas_conversation_store.py b/enterprise/tests/unit/test_saas_conversation_store.py index f4f9a7afe6..a7d000d0ec 100644 --- a/enterprise/tests/unit/test_saas_conversation_store.py +++ b/enterprise/tests/unit/test_saas_conversation_store.py @@ -3,14 +3,12 @@ from unittest.mock import MagicMock, patch from uuid import UUID import pytest +from storage.saas_conversation_store import SaasConversationStore +from storage.user import User +from openhands.core.config.openhands_config import OpenHandsConfig from openhands.storage.data_models.conversation_metadata import ConversationMetadata -# Mock the database module before importing -with patch('storage.database.engine'), patch('storage.database.a_engine'): - from storage.saas_conversation_store import SaasConversationStore - from storage.user import User - @pytest.fixture(autouse=True) def mock_call_sync_from_async(): @@ -166,3 +164,52 @@ async def test_exists(session_maker): assert not await store.exists('exists-test') await store.save_metadata(metadata) assert await store.exists('exists-test') + + +class TestGetInstance: + """Tests for SaasConversationStore.get_instance method. + + The get_instance method uses sync UserStore.get_user_by_id (not async) + because it can be called from call_async_from_sync contexts which create + a new event loop. Using async DB operations in that context would cause + asyncpg connection errors. + """ + + @pytest.mark.asyncio + async def test_get_instance_uses_sync_get_user_by_id(self): + """Verify get_instance calls the sync get_user_by_id, not async version.""" + # Arrange + user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081' + mock_user = MagicMock(spec=User) + mock_user.current_org_id = UUID(user_id) + mock_config = MagicMock(spec=OpenHandsConfig) + + with patch( + 'storage.saas_conversation_store.UserStore.get_user_by_id', + return_value=mock_user, + ) as mock_sync_get_user, patch('storage.saas_conversation_store.session_maker'): + # Act + store = await SaasConversationStore.get_instance(mock_config, user_id) + + # Assert + mock_sync_get_user.assert_called_once_with(user_id) + assert store.user_id == user_id + assert store.org_id == mock_user.current_org_id + + @pytest.mark.asyncio + async def test_get_instance_handles_none_user(self): + """Verify get_instance handles case when user is not found.""" + # Arrange + user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081' + mock_config = MagicMock(spec=OpenHandsConfig) + + with patch( + 'storage.saas_conversation_store.UserStore.get_user_by_id', + return_value=None, + ), patch('storage.saas_conversation_store.session_maker'): + # Act + store = await SaasConversationStore.get_instance(mock_config, user_id) + + # Assert + assert store.user_id == user_id + assert store.org_id is None