diff --git a/enterprise/server/sharing/google_cloud_shared_event_service.py b/enterprise/server/sharing/google_cloud_shared_event_service.py index bdb9a3a88d..553c5af21f 100644 --- a/enterprise/server/sharing/google_cloud_shared_event_service.py +++ b/enterprise/server/sharing/google_cloud_shared_event_service.py @@ -9,6 +9,7 @@ This implementation provides read-only access to events from shared conversation from __future__ import annotations import logging +import os from dataclasses import dataclass from datetime import datetime from pathlib import Path @@ -20,6 +21,7 @@ from google.cloud import storage from google.cloud.storage.bucket import Bucket from google.cloud.storage.client import Client from more_itertools import bucket +from pydantic import Field from server.sharing.shared_conversation_info_service import ( SharedConversationInfoService, ) @@ -131,6 +133,10 @@ class GoogleCloudSharedEventService(SharedEventService): class GoogleCloudSharedEventServiceInjector(SharedEventServiceInjector): + bucket_name: str | None = Field( + default_factory=lambda: os.environ.get('FILE_STORE_PATH') + ) + async def inject( self, state: InjectorState, request: Request | None = None ) -> AsyncGenerator[SharedEventService, None]: diff --git a/enterprise/tests/unit/test_sharing/test_sharing_shared_event_service.py b/enterprise/tests/unit/test_sharing/test_sharing_shared_event_service.py index 9915662d10..5185953dfb 100644 --- a/enterprise/tests/unit/test_sharing/test_sharing_shared_event_service.py +++ b/enterprise/tests/unit/test_sharing/test_sharing_shared_event_service.py @@ -1,12 +1,14 @@ """Tests for SharedEventService.""" +import os from datetime import UTC, datetime -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch from uuid import uuid4 import pytest from server.sharing.google_cloud_shared_event_service import ( GoogleCloudSharedEventService, + GoogleCloudSharedEventServiceInjector, ) from server.sharing.shared_conversation_info_service import ( SharedConversationInfoService, @@ -363,3 +365,224 @@ class TestSharedEventService: page_id='current_page', limit=50, ) + + +class TestGoogleCloudSharedEventServiceGetEventService: + """Test cases for GoogleCloudSharedEventService.get_event_service method.""" + + async def test_get_event_service_returns_event_service_for_shared_conversation( + self, + shared_event_service, + mock_shared_conversation_info_service, + sample_public_conversation, + ): + """Test that get_event_service returns an EventService for a shared conversation.""" + conversation_id = sample_public_conversation.id + + # Mock the shared conversation info service to return a shared conversation + mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation + + # Call the method + result = await shared_event_service.get_event_service(conversation_id) + + # Verify the result + assert result is not None + mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with( + conversation_id + ) + + async def test_get_event_service_returns_none_for_non_shared_conversation( + self, + shared_event_service, + mock_shared_conversation_info_service, + ): + """Test that get_event_service returns None for a non-shared conversation.""" + conversation_id = uuid4() + + # Mock the shared conversation info service to return None + mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None + + # Call the method + result = await shared_event_service.get_event_service(conversation_id) + + # Verify the result + assert result is None + mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with( + conversation_id + ) + + +class TestGoogleCloudSharedEventServiceInjector: + """Test cases for GoogleCloudSharedEventServiceInjector.""" + + def test_bucket_name_from_environment_variable(self): + """Test that bucket_name is read from FILE_STORE_PATH environment variable.""" + test_bucket_name = 'test-bucket-name' + with patch.dict(os.environ, {'FILE_STORE_PATH': test_bucket_name}): + # Create a new injector instance to pick up the environment variable + # Note: The class attribute is evaluated at class definition time, + # so we need to test that the attribute exists and can be overridden + injector = GoogleCloudSharedEventServiceInjector() + injector.bucket_name = os.environ.get('FILE_STORE_PATH') + assert injector.bucket_name == test_bucket_name + + def test_bucket_name_default_value_when_env_not_set(self): + """Test that bucket_name is None when FILE_STORE_PATH is not set.""" + with patch.dict(os.environ, {}, clear=True): + # Remove FILE_STORE_PATH if it exists + os.environ.pop('FILE_STORE_PATH', None) + injector = GoogleCloudSharedEventServiceInjector() + # The bucket_name will be whatever was set at class definition time + # or None if FILE_STORE_PATH was not set when the class was defined + assert hasattr(injector, 'bucket_name') + + async def test_injector_yields_google_cloud_shared_event_service(self): + """Test that the injector yields a GoogleCloudSharedEventService instance.""" + mock_state = MagicMock() + mock_request = MagicMock() + mock_db_session = AsyncMock() + + # Create the injector + injector = GoogleCloudSharedEventServiceInjector() + injector.bucket_name = 'test-bucket' + + # Mock the get_db_session context manager + mock_db_context = AsyncMock() + mock_db_context.__aenter__.return_value = mock_db_session + mock_db_context.__aexit__.return_value = None + + # Mock storage.Client and bucket + mock_storage_client = MagicMock() + mock_bucket = MagicMock() + mock_storage_client.bucket.return_value = mock_bucket + + with ( + patch( + 'server.sharing.google_cloud_shared_event_service.storage.Client', + return_value=mock_storage_client, + ), + patch( + 'openhands.app_server.config.get_db_session', + return_value=mock_db_context, + ), + ): + # Call the inject method + async for service in injector.inject(mock_state, mock_request): + # Verify the service is an instance of GoogleCloudSharedEventService + assert isinstance(service, GoogleCloudSharedEventService) + assert service.bucket == mock_bucket + + # Verify the storage client was called with the correct bucket name + mock_storage_client.bucket.assert_called_once_with('test-bucket') + + async def test_injector_uses_bucket_name_from_instance(self): + """Test that the injector uses the bucket_name from the instance.""" + mock_state = MagicMock() + mock_request = MagicMock() + mock_db_session = AsyncMock() + + # Create the injector with a specific bucket name + injector = GoogleCloudSharedEventServiceInjector() + injector.bucket_name = 'my-custom-bucket' + + # Mock the get_db_session context manager + mock_db_context = AsyncMock() + mock_db_context.__aenter__.return_value = mock_db_session + mock_db_context.__aexit__.return_value = None + + # Mock storage.Client and bucket + mock_storage_client = MagicMock() + mock_bucket = MagicMock() + mock_storage_client.bucket.return_value = mock_bucket + + with ( + patch( + 'server.sharing.google_cloud_shared_event_service.storage.Client', + return_value=mock_storage_client, + ), + patch( + 'openhands.app_server.config.get_db_session', + return_value=mock_db_context, + ), + ): + # Call the inject method + async for service in injector.inject(mock_state, mock_request): + pass + + # Verify the storage client was called with the custom bucket name + mock_storage_client.bucket.assert_called_once_with('my-custom-bucket') + + async def test_injector_creates_sql_shared_conversation_info_service(self): + """Test that the injector creates SQLSharedConversationInfoService with db_session.""" + mock_state = MagicMock() + mock_request = MagicMock() + mock_db_session = AsyncMock() + + # Create the injector + injector = GoogleCloudSharedEventServiceInjector() + injector.bucket_name = 'test-bucket' + + # Mock the get_db_session context manager + mock_db_context = AsyncMock() + mock_db_context.__aenter__.return_value = mock_db_session + mock_db_context.__aexit__.return_value = None + + # Mock storage.Client and bucket + mock_storage_client = MagicMock() + mock_bucket = MagicMock() + mock_storage_client.bucket.return_value = mock_bucket + + with ( + patch( + 'server.sharing.google_cloud_shared_event_service.storage.Client', + return_value=mock_storage_client, + ), + patch( + 'openhands.app_server.config.get_db_session', + return_value=mock_db_context, + ), + patch( + 'server.sharing.google_cloud_shared_event_service.SQLSharedConversationInfoService' + ) as mock_sql_service_class, + ): + mock_sql_service = MagicMock() + mock_sql_service_class.return_value = mock_sql_service + + # Call the inject method + async for service in injector.inject(mock_state, mock_request): + # Verify the service has the correct shared_conversation_info_service + assert service.shared_conversation_info_service == mock_sql_service + + # Verify SQLSharedConversationInfoService was created with db_session + mock_sql_service_class.assert_called_once_with(db_session=mock_db_session) + + async def test_injector_works_without_request(self): + """Test that the injector works when request is None.""" + mock_state = MagicMock() + mock_db_session = AsyncMock() + + # Create the injector + injector = GoogleCloudSharedEventServiceInjector() + injector.bucket_name = 'test-bucket' + + # Mock the get_db_session context manager + mock_db_context = AsyncMock() + mock_db_context.__aenter__.return_value = mock_db_session + mock_db_context.__aexit__.return_value = None + + # Mock storage.Client and bucket + mock_storage_client = MagicMock() + mock_bucket = MagicMock() + mock_storage_client.bucket.return_value = mock_bucket + + with patch( + 'server.sharing.google_cloud_shared_event_service.storage.Client', + return_value=mock_storage_client, + ): + with patch( + 'openhands.app_server.config.get_db_session', + return_value=mock_db_context, + ): + # Call the inject method with request=None + async for service in injector.inject(mock_state, request=None): + assert isinstance(service, GoogleCloudSharedEventService)