From 5fb431bcc549190065ecdc9ef205769d06cca08e Mon Sep 17 00:00:00 2001 From: Rohit Malhotra Date: Thu, 8 Jan 2026 12:08:11 -0800 Subject: [PATCH] feat: Implement Slack V1 integration following GitHub V1 pattern (#11825) Co-authored-by: openhands Co-authored-by: Tim O'Farrell --- enterprise/integrations/__init__.py | 0 enterprise/integrations/github/github_view.py | 42 +- .../integrations/slack/slack_manager.py | 37 +- enterprise/integrations/slack/slack_types.py | 14 +- .../slack/slack_v1_callback_processor.py | 273 +++++++++++ enterprise/integrations/slack/slack_view.py | 282 ++++++++++-- enterprise/integrations/types.py | 4 - enterprise/integrations/utils.py | 39 +- ...086_add_v1_column_to_slack_conversation.py | 30 ++ .../server/clustered_conversation_manager.py | 3 +- enterprise/server/routes/integration/slack.py | 4 +- enterprise/storage/slack_conversation.py | 3 +- .../storage/slack_conversation_store.py | 3 +- .../slack/test_slack_v1_callback_processor.py | 431 ++++++++++++++++++ .../integrations/slack/test_slack_view.py | 341 ++++++++++++++ .../unit/test_get_user_v1_enabled_setting.py | 108 +++-- .../app_conversation_models.py | 3 +- .../live_status_app_conversation_service.py | 50 +- .../app_server/sandbox/sandbox_service.py | 98 ++++ 19 files changed, 1567 insertions(+), 198 deletions(-) delete mode 100644 enterprise/integrations/__init__.py create mode 100644 enterprise/integrations/slack/slack_v1_callback_processor.py create mode 100644 enterprise/migrations/versions/086_add_v1_column_to_slack_conversation.py create mode 100644 enterprise/tests/unit/integrations/slack/test_slack_v1_callback_processor.py create mode 100644 enterprise/tests/unit/integrations/slack/test_slack_view.py diff --git a/enterprise/integrations/__init__.py b/enterprise/integrations/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/enterprise/integrations/github/github_view.py b/enterprise/integrations/github/github_view.py index 8a0e686f8d..54ce1f4be7 100644 --- a/enterprise/integrations/github/github_view.py +++ b/enterprise/integrations/github/github_view.py @@ -17,6 +17,7 @@ from integrations.utils import ( HOST, HOST_URL, get_oh_labels, + get_user_v1_enabled_setting, has_exact_mention, ) from jinja2 import Environment @@ -55,6 +56,10 @@ from openhands.utils.async_utils import call_sync_from_async OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST) +async def is_v1_enabled_for_github_resolver(user_id: str) -> bool: + return await get_user_v1_enabled_setting(user_id) and ENABLE_V1_GITHUB_RESOLVER + + async def get_user_proactive_conversation_setting(user_id: str | None) -> bool: """Get the user's proactive conversation setting. @@ -88,38 +93,6 @@ async def get_user_proactive_conversation_setting(user_id: str | None) -> bool: return settings.enable_proactive_conversation_starters -async def get_user_v1_enabled_setting(user_id: str) -> bool: - """Get the user's V1 conversation API setting. - - Args: - user_id: The keycloak user ID - - Returns: - True if V1 conversations are enabled for this user, False otherwise - - Note: - This function checks both the global environment variable kill switch AND - the user's individual setting. Both must be true for the function to return true. - """ - # Check the global environment variable first - if not ENABLE_V1_GITHUB_RESOLVER: - return False - - config = get_config() - settings_store = SaasSettingsStore( - user_id=user_id, session_maker=session_maker, config=config - ) - - settings = await call_sync_from_async( - settings_store.get_user_settings_by_keycloak_id, user_id - ) - - if not settings or settings.v1_enabled is None: - return False - - return settings.v1_enabled - - # ================================================= # SECTION: Github view types # ================================================= @@ -191,9 +164,10 @@ class GithubIssue(ResolverViewInterface): async def initialize_new_conversation(self) -> ConversationMetadata: # FIXME: Handle if initialize_conversation returns None - self.v1_enabled = await get_user_v1_enabled_setting( + self.v1_enabled = await is_v1_enabled_for_github_resolver( self.user_info.keycloak_user_id ) + logger.info( f'[GitHub V1]: User flag found for {self.user_info.keycloak_user_id} is {self.v1_enabled}' ) @@ -438,7 +412,7 @@ class GithubInlinePRComment(GithubPRComment): def _create_github_v1_callback_processor(self): """Create a V1 callback processor for GitHub integration.""" - from openhands.app_server.event_callback.github_v1_callback_processor import ( + from integrations.github.github_v1_callback_processor import ( GithubV1CallbackProcessor, ) diff --git a/enterprise/integrations/slack/slack_manager.py b/enterprise/integrations/slack/slack_manager.py index 858c7c98f0..8d309c57ad 100644 --- a/enterprise/integrations/slack/slack_manager.py +++ b/enterprise/integrations/slack/slack_manager.py @@ -16,9 +16,8 @@ from integrations.utils import ( OPENHANDS_RESOLVER_TEMPLATES_DIR, get_session_expired_message, ) +from integrations.v1_utils import get_saas_user_auth from jinja2 import Environment, FileSystemLoader -from pydantic import SecretStr -from server.auth.saas_user_auth import SaasUserAuth from server.constants import SLACK_CLIENT_ID from server.utils.conversation_callback_utils import register_callback_processor from slack_sdk.oauth import AuthorizeUrlGenerator @@ -59,17 +58,6 @@ class SlackManager(Manager): if message.source != SourceType.SLACK: raise ValueError(f'Unexpected message source {message.source}') - async def _get_user_auth(self, keycloak_user_id: str) -> UserAuth: - offline_token = await self.token_manager.load_offline_token(keycloak_user_id) - if offline_token is None: - logger.info('no_offline_token_found') - - user_auth = SaasUserAuth( - user_id=keycloak_user_id, - refresh_token=SecretStr(offline_token), - ) - return user_auth - async def authenticate_user( self, slack_user_id: str ) -> tuple[SlackUser | None, UserAuth | None]: @@ -86,7 +74,9 @@ class SlackManager(Manager): saas_user_auth = None if slack_user: - saas_user_auth = await self._get_user_auth(slack_user.keycloak_user_id) + saas_user_auth = await get_saas_user_auth( + slack_user.keycloak_user_id, self.token_manager + ) # slack_view.saas_user_auth = await self._get_user_auth(slack_view.slack_to_openhands_user.keycloak_user_id) return slack_user, saas_user_auth @@ -249,13 +239,11 @@ class SlackManager(Manager): async def is_job_requested( self, message: Message, slack_view: SlackViewInterface ) -> bool: - """ - A job is always request we only receive webhooks for events associated with the slack bot + """A job is always request we only receive webhooks for events associated with the slack bot This method really just checks 1. Is the user is authenticated 2. Do we have the necessary information to start a job (either by inferring the selected repo, otherwise asking the user) """ - # Infer repo from user message is not needed; user selected repo from the form or is updating existing convo if isinstance(slack_view, SlackUpdateExistingConversationView): return True @@ -322,10 +310,15 @@ class SlackManager(Manager): f'[Slack] Created conversation {conversation_id} for user {user_info.slack_display_name}' ) - if not isinstance(slack_view, SlackUpdateExistingConversationView): + # Only add SlackCallbackProcessor for new conversations (not updates) and non-v1 conversations + if ( + not isinstance(slack_view, SlackUpdateExistingConversationView) + and not slack_view.v1_enabled + ): # We don't re-subscribe for follow up messages from slack. # Summaries are generated for every messages anyways, we only need to do # this subscription once for the event which kicked off the job. + processor = SlackCallbackProcessor( slack_user_id=slack_view.slack_user_id, channel_id=slack_view.channel_id, @@ -340,6 +333,14 @@ class SlackManager(Manager): logger.info( f'[Slack] Created callback processor for conversation {conversation_id}' ) + elif isinstance(slack_view, SlackUpdateExistingConversationView): + logger.info( + f'[Slack] Skipping callback processor for existing conversation update {conversation_id}' + ) + elif slack_view.v1_enabled: + logger.info( + f'[Slack] Skipping callback processor for v1 conversation {conversation_id}' + ) msg_info = slack_view.get_response_msg() diff --git a/enterprise/integrations/slack/slack_types.py b/enterprise/integrations/slack/slack_types.py index c07ca9e770..c6b11922df 100644 --- a/enterprise/integrations/slack/slack_types.py +++ b/enterprise/integrations/slack/slack_types.py @@ -21,20 +21,16 @@ class SlackViewInterface(SummaryExtractionTracker, ABC): send_summary_instruction: bool conversation_id: str team_id: str + v1_enabled: bool @abstractmethod def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]: - "Instructions passed when conversation is first initialized" + """Instructions passed when conversation is first initialized""" pass @abstractmethod async def create_or_update_conversation(self, jinja_env: Environment): - "Create a new conversation" - pass - - @abstractmethod - def get_callback_id(self) -> str: - "Unique callback id for subscribription made to EventStream for fetching agent summary" + """Create a new conversation""" pass @abstractmethod @@ -43,6 +39,4 @@ class SlackViewInterface(SummaryExtractionTracker, ABC): class StartingConvoException(Exception): - """ - Raised when trying to send message to a conversation that's is still starting up - """ + """Raised when trying to send message to a conversation that's is still starting up""" diff --git a/enterprise/integrations/slack/slack_v1_callback_processor.py b/enterprise/integrations/slack/slack_v1_callback_processor.py new file mode 100644 index 0000000000..a20ef5fd52 --- /dev/null +++ b/enterprise/integrations/slack/slack_v1_callback_processor.py @@ -0,0 +1,273 @@ +import logging +from uuid import UUID + +import httpx +from integrations.utils import CONVERSATION_URL, get_summary_instruction +from pydantic import Field +from slack_sdk import WebClient +from storage.slack_team_store import SlackTeamStore + +from openhands.agent_server.models import AskAgentRequest, AskAgentResponse +from openhands.app_server.event_callback.event_callback_models import ( + EventCallback, + EventCallbackProcessor, +) +from openhands.app_server.event_callback.event_callback_result_models import ( + EventCallbackResult, + EventCallbackResultStatus, +) +from openhands.app_server.event_callback.util import ( + ensure_conversation_found, + ensure_running_sandbox, + get_agent_server_url_from_sandbox, +) +from openhands.sdk import Event +from openhands.sdk.event import ConversationStateUpdateEvent + +_logger = logging.getLogger(__name__) + + +class SlackV1CallbackProcessor(EventCallbackProcessor): + """Callback processor for Slack V1 integrations.""" + + slack_view_data: dict[str, str | None] = Field(default_factory=dict) + + async def __call__( + self, + conversation_id: UUID, + callback: EventCallback, + event: Event, + ) -> EventCallbackResult | None: + """Process events for Slack V1 integration.""" + + # Only handle ConversationStateUpdateEvent + if not isinstance(event, ConversationStateUpdateEvent): + return None + + # Only act when execution has finished + if not (event.key == 'execution_status' and event.value == 'finished'): + return None + + _logger.info('[Slack V1] Callback agent state was %s', event) + + try: + summary = await self._request_summary(conversation_id) + await self._post_summary_to_slack(summary) + + return EventCallbackResult( + status=EventCallbackResultStatus.SUCCESS, + event_callback_id=callback.id, + event_id=event.id, + conversation_id=conversation_id, + detail=summary, + ) + except Exception as e: + _logger.exception('[Slack V1] Error processing callback: %s', e) + + # Only try to post error to Slack if we have basic requirements + try: + await self._post_summary_to_slack( + f'OpenHands encountered an error: **{str(e)}**.\n\n' + f'[See the conversation]({CONVERSATION_URL.format(conversation_id)})' + 'for more information.' + ) + except Exception as post_error: + _logger.warning( + '[Slack V1] Failed to post error message to Slack: %s', post_error + ) + + return EventCallbackResult( + status=EventCallbackResultStatus.ERROR, + event_callback_id=callback.id, + event_id=event.id, + conversation_id=conversation_id, + detail=str(e), + ) + + # ------------------------------------------------------------------------- + # Slack helpers + # ------------------------------------------------------------------------- + + def _get_bot_access_token(self): + slack_team_store = SlackTeamStore.get_instance() + bot_access_token = slack_team_store.get_team_bot_token( + self.slack_view_data['team_id'] + ) + + return bot_access_token + + async def _post_summary_to_slack(self, summary: str) -> None: + """Post a summary message to the configured Slack channel.""" + bot_access_token = self._get_bot_access_token() + if not bot_access_token: + raise RuntimeError('Missing Slack bot access token') + + channel_id = self.slack_view_data['channel_id'] + thread_ts = self.slack_view_data.get('thread_ts') or self.slack_view_data.get( + 'message_ts' + ) + + client = WebClient(token=bot_access_token) + + try: + # Post the summary as a threaded reply + response = client.chat_postMessage( + channel=channel_id, + text=summary, + thread_ts=thread_ts, + unfurl_links=False, + unfurl_media=False, + ) + + if not response['ok']: + raise RuntimeError( + f"Slack API error: {response.get('error', 'Unknown error')}" + ) + + _logger.info( + '[Slack V1] Successfully posted summary to channel %s', channel_id + ) + + except Exception as e: + _logger.error('[Slack V1] Failed to post message to Slack: %s', e) + raise + + # ------------------------------------------------------------------------- + # Agent / sandbox helpers + # ------------------------------------------------------------------------- + + async def _ask_question( + self, + httpx_client: httpx.AsyncClient, + agent_server_url: str, + conversation_id: UUID, + session_api_key: str, + message_content: str, + ) -> str: + """Send a message to the agent server via the V1 API and return response text.""" + send_message_request = AskAgentRequest(question=message_content) + + url = ( + f'{agent_server_url.rstrip("/")}' + f'/api/conversations/{conversation_id}/ask_agent' + ) + headers = {'X-Session-API-Key': session_api_key} + payload = send_message_request.model_dump() + + try: + response = await httpx_client.post( + url, + json=payload, + headers=headers, + timeout=30.0, + ) + response.raise_for_status() + + agent_response = AskAgentResponse.model_validate(response.json()) + return agent_response.response + + except httpx.HTTPStatusError as e: + error_detail = f'HTTP {e.response.status_code} error' + try: + error_body = e.response.text + if error_body: + error_detail += f': {error_body}' + except Exception: # noqa: BLE001 + pass + + _logger.error( + '[Slack V1] HTTP error sending message to %s: %s. ' + 'Request payload: %s. Response headers: %s', + url, + error_detail, + payload, + dict(e.response.headers), + exc_info=True, + ) + raise Exception(f'Failed to send message to agent server: {error_detail}') + + except httpx.TimeoutException: + error_detail = f'Request timeout after 30 seconds to {url}' + _logger.error( + '[Slack V1] %s. Request payload: %s', + error_detail, + payload, + exc_info=True, + ) + raise Exception(error_detail) + + except httpx.RequestError as e: + error_detail = f'Request error to {url}: {str(e)}' + _logger.error( + '[Slack V1] %s. Request payload: %s', + error_detail, + payload, + exc_info=True, + ) + raise Exception(error_detail) + + # ------------------------------------------------------------------------- + # Summary orchestration + # ------------------------------------------------------------------------- + + async def _request_summary(self, conversation_id: UUID) -> str: + """ + Ask the agent to produce a summary of its work and return the agent response. + + NOTE: This method now returns a string (the agent server's response text) + and raises exceptions on errors. The wrapping into EventCallbackResult + is handled by __call__. + """ + # Import services within the method to avoid circular imports + from openhands.app_server.config import ( + get_app_conversation_info_service, + get_httpx_client, + get_sandbox_service, + ) + from openhands.app_server.services.injector import InjectorState + from openhands.app_server.user.specifiy_user_context import ( + ADMIN, + USER_CONTEXT_ATTR, + ) + + # Create injector state for dependency injection + state = InjectorState() + setattr(state, USER_CONTEXT_ATTR, ADMIN) + + async with ( + get_app_conversation_info_service(state) as app_conversation_info_service, + get_sandbox_service(state) as sandbox_service, + get_httpx_client(state) as httpx_client, + ): + # 1. Conversation lookup + app_conversation_info = ensure_conversation_found( + await app_conversation_info_service.get_app_conversation_info( + conversation_id + ), + conversation_id, + ) + + # 2. Sandbox lookup + validation + sandbox = ensure_running_sandbox( + await sandbox_service.get_sandbox(app_conversation_info.sandbox_id), + app_conversation_info.sandbox_id, + ) + + assert ( + sandbox.session_api_key is not None + ), f'No session API key for sandbox: {sandbox.id}' + + # 3. URL + instruction + agent_server_url = get_agent_server_url_from_sandbox(sandbox) + + # Prepare message based on agent state + message_content = get_summary_instruction() + + # Ask the agent and return the response text + return await self._ask_question( + httpx_client=httpx_client, + agent_server_url=agent_server_url, + conversation_id=conversation_id, + session_api_key=sandbox.session_api_key, + message_content=message_content, + ) diff --git a/enterprise/integrations/slack/slack_view.py b/enterprise/integrations/slack/slack_view.py index b270dfb2ca..9fc0952130 100644 --- a/enterprise/integrations/slack/slack_view.py +++ b/enterprise/integrations/slack/slack_view.py @@ -1,8 +1,16 @@ from dataclasses import dataclass +from uuid import UUID, uuid4 from integrations.models import Message +from integrations.resolver_context import ResolverUserContext from integrations.slack.slack_types import SlackViewInterface, StartingConvoException -from integrations.utils import CONVERSATION_URL, get_final_agent_observation +from integrations.slack.slack_v1_callback_processor import SlackV1CallbackProcessor +from integrations.utils import ( + CONVERSATION_URL, + ENABLE_V1_SLACK_RESOLVER, + get_final_agent_observation, + get_user_v1_enabled_setting, +) from jinja2 import Environment from slack_sdk import WebClient from storage.slack_conversation import SlackConversation @@ -10,22 +18,34 @@ from storage.slack_conversation_store import SlackConversationStore from storage.slack_team_store import SlackTeamStore from storage.slack_user import SlackUser +from openhands.app_server.app_conversation.app_conversation_models import ( + AppConversationStartRequest, + AppConversationStartTaskStatus, + SendMessageRequest, +) +from openhands.app_server.config import get_app_conversation_service +from openhands.app_server.sandbox.sandbox_models import SandboxStatus +from openhands.app_server.services.injector import InjectorState +from openhands.app_server.user.specifiy_user_context import USER_CONTEXT_ATTR from openhands.core.logger import openhands_logger as logger from openhands.core.schema.agent import AgentState from openhands.events.action import MessageAction from openhands.events.serialization.event import event_to_dict -from openhands.integrations.provider import ProviderHandler +from openhands.integrations.provider import ProviderHandler, ProviderType +from openhands.sdk import TextContent from openhands.server.services.conversation_service import ( create_new_conversation, setup_init_conversation_settings, ) from openhands.server.shared import ConversationStoreImpl, config, conversation_manager from openhands.server.user_auth.user_auth import UserAuth -from openhands.storage.data_models.conversation_metadata import ConversationTrigger +from openhands.storage.data_models.conversation_metadata import ( + ConversationTrigger, +) from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync # ================================================= -# SECTION: Github view types +# SECTION: Slack view types # ================================================= @@ -34,6 +54,10 @@ slack_conversation_store = SlackConversationStore.get_instance() slack_team_store = SlackTeamStore.get_instance() +async def is_v1_enabled_for_slack_resolver(user_id: str) -> bool: + return await get_user_v1_enabled_setting(user_id) and ENABLE_V1_SLACK_RESOLVER + + @dataclass class SlackUnkownUserView(SlackViewInterface): bot_access_token: str @@ -49,6 +73,7 @@ class SlackUnkownUserView(SlackViewInterface): send_summary_instruction: bool conversation_id: str team_id: str + v1_enabled: bool def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]: raise NotImplementedError @@ -56,9 +81,6 @@ class SlackUnkownUserView(SlackViewInterface): async def create_or_update_conversation(self, jinja_env: Environment): raise NotImplementedError - def get_callback_id(self) -> str: - raise NotImplementedError - def get_response_msg(self) -> str: raise NotImplementedError @@ -78,6 +100,7 @@ class SlackNewConversationView(SlackViewInterface): send_summary_instruction: bool conversation_id: str team_id: str + v1_enabled: bool def _get_initial_prompt(self, text: str, blocks: list[dict]): bot_id = self._get_bot_id(blocks) @@ -96,8 +119,7 @@ class SlackNewConversationView(SlackViewInterface): return '' def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]: - "Instructions passed when conversation is first initialized" - + """Instructions passed when conversation is first initialized""" user_info: SlackUser = self.slack_to_openhands_user messages = [] @@ -157,7 +179,7 @@ class SlackNewConversationView(SlackViewInterface): 'Attempting to start conversation without confirming selected repo from user' ) - async def save_slack_convo(self): + async def save_slack_convo(self, v1_enabled: bool = False): if self.slack_to_openhands_user: user_info: SlackUser = self.slack_to_openhands_user @@ -168,6 +190,7 @@ class SlackNewConversationView(SlackViewInterface): 'conversation_id': self.conversation_id, 'keycloak_user_id': user_info.keycloak_user_id, 'parent_id': self.thread_ts or self.message_ts, + 'v1_enabled': v1_enabled, }, ) slack_conversation = SlackConversation( @@ -176,17 +199,47 @@ class SlackNewConversationView(SlackViewInterface): keycloak_user_id=user_info.keycloak_user_id, parent_id=self.thread_ts or self.message_ts, # conversations can start in a thread reply as well; we should always references the parent's (root level msg's) message ID + v1_enabled=v1_enabled, ) await slack_conversation_store.create_slack_conversation(slack_conversation) + def _create_slack_v1_callback_processor(self) -> SlackV1CallbackProcessor: + """Create a SlackV1CallbackProcessor for V1 conversation handling.""" + return SlackV1CallbackProcessor( + slack_view_data={ + 'channel_id': self.channel_id, + 'message_ts': self.message_ts, + 'thread_ts': self.thread_ts, + 'team_id': self.team_id, + 'slack_user_id': self.slack_user_id, + } + ) + async def create_or_update_conversation(self, jinja: Environment) -> str: - """ - Only creates a new conversation - """ + """Only creates a new conversation""" self._verify_necessary_values_are_set() provider_tokens = await self.saas_user_auth.get_provider_tokens() user_secrets = await self.saas_user_auth.get_secrets() + + # Check if V1 conversations are enabled for this user + self.v1_enabled = await is_v1_enabled_for_slack_resolver( + self.slack_to_openhands_user.keycloak_user_id + ) + + if self.v1_enabled: + # Use V1 app conversation service + await self._create_v1_conversation(jinja) + return self.conversation_id + else: + # Use existing V0 conversation service + await self._create_v0_conversation(jinja, provider_tokens, user_secrets) + return self.conversation_id + + async def _create_v0_conversation( + self, jinja: Environment, provider_tokens, user_secrets + ) -> None: + """Create conversation using the legacy V0 system.""" user_instructions, conversation_instructions = self._get_instructions(jinja) # Determine git provider from repository @@ -213,11 +266,65 @@ class SlackNewConversationView(SlackViewInterface): ) self.conversation_id = agent_loop_info.conversation_id - await self.save_slack_convo() - return self.conversation_id + logger.info(f'[Slack]: Created V0 conversation: {self.conversation_id}') + await self.save_slack_convo(v1_enabled=False) - def get_callback_id(self) -> str: - return f'slack_{self.channel_id}_{self.message_ts}' + async def _create_v1_conversation(self, jinja: Environment) -> None: + """Create conversation using the new V1 app conversation system.""" + user_instructions, conversation_instructions = self._get_instructions(jinja) + + # Create the initial message request + initial_message = SendMessageRequest( + role='user', content=[TextContent(text=user_instructions)] + ) + + # Create the Slack V1 callback processor + slack_callback_processor = self._create_slack_v1_callback_processor() + + # Determine git provider from repository + git_provider = None + provider_tokens = await self.saas_user_auth.get_provider_tokens() + if self.selected_repo and provider_tokens: + provider_handler = ProviderHandler(provider_tokens) + repository = await provider_handler.verify_repo_provider(self.selected_repo) + git_provider = ProviderType(repository.git_provider.value) + + # Get the app conversation service and start the conversation + injector_state = InjectorState() + + # Create the V1 conversation start request with the callback processor + self.conversation_id = uuid4().hex + start_request = AppConversationStartRequest( + conversation_id=UUID(self.conversation_id), + system_message_suffix=conversation_instructions, + initial_message=initial_message, + selected_repository=self.selected_repo, + git_provider=git_provider, + title=f'Slack conversation from {self.slack_to_openhands_user.slack_display_name}', + trigger=ConversationTrigger.SLACK, + processors=[ + slack_callback_processor + ], # Pass the callback processor directly + ) + + # Set up the Slack user context for the V1 system + slack_user_context = ResolverUserContext(saas_user_auth=self.saas_user_auth) + setattr(injector_state, USER_CONTEXT_ATTR, slack_user_context) + + async with get_app_conversation_service( + injector_state + ) as app_conversation_service: + async for task in app_conversation_service.start_app_conversation( + start_request + ): + if task.status == AppConversationStartTaskStatus.ERROR: + logger.error(f'Failed to start V1 conversation: {task.detail}') + raise RuntimeError( + f'Failed to start V1 conversation: {task.detail}' + ) + + logger.info(f'[Slack V1]: Created new conversation: {self.conversation_id}') + await self.save_slack_convo(v1_enabled=True) def get_response_msg(self) -> str: user_info: SlackUser = self.slack_to_openhands_user @@ -254,32 +361,20 @@ class SlackUpdateExistingConversationView(SlackNewConversationView): return user_message, '' - async def create_or_update_conversation(self, jinja: Environment) -> str: - """ - Send new user message to converation - """ + async def send_message_to_v0_conversation(self, jinja: Environment): user_info: SlackUser = self.slack_to_openhands_user - saas_user_auth: UserAuth = self.saas_user_auth user_id = user_info.keycloak_user_id - - # Org management in the future will get rid of this - # For now, only user that created the conversation can send follow up messages to it - if user_id != self.slack_conversation.keycloak_user_id: - raise StartingConvoException( - f'{user_info.slack_display_name} is not authorized to send messages to this conversation.' - ) - - # Check if conversation has been deleted - # Update logic when soft delete is implemented - conversation_store = await ConversationStoreImpl.get_instance(config, user_id) + saas_user_auth: UserAuth = self.saas_user_auth + provider_tokens = await saas_user_auth.get_provider_tokens() try: + conversation_store = await ConversationStoreImpl.get_instance( + config, user_id + ) await conversation_store.get_metadata(self.conversation_id) except FileNotFoundError: raise StartingConvoException('Conversation no longer exists.') - provider_tokens = await saas_user_auth.get_provider_tokens() - # Should we raise here if there are no provider tokens? providers_set = list(provider_tokens.keys()) if provider_tokens else [] @@ -310,6 +405,117 @@ class SlackUpdateExistingConversationView(SlackNewConversationView): self.conversation_id, event_to_dict(user_msg_action) ) + async def send_message_to_v1_conversation(self, jinja: Environment): + """Send a message to a v1 conversation using the agent server API.""" + # Import services within the method to avoid circular imports + from openhands.agent_server.models import SendMessageRequest + from openhands.app_server.config import ( + get_app_conversation_info_service, + get_httpx_client, + get_sandbox_service, + ) + from openhands.app_server.event_callback.util import ( + ensure_conversation_found, + get_agent_server_url_from_sandbox, + ) + from openhands.app_server.services.injector import InjectorState + from openhands.app_server.user.specifiy_user_context import ( + ADMIN, + USER_CONTEXT_ATTR, + ) + + # Create injector state for dependency injection + state = InjectorState() + setattr(state, USER_CONTEXT_ATTR, ADMIN) + + async with ( + get_app_conversation_info_service(state) as app_conversation_info_service, + get_sandbox_service(state) as sandbox_service, + get_httpx_client(state) as httpx_client, + ): + # 1. Conversation lookup + app_conversation_info = ensure_conversation_found( + await app_conversation_info_service.get_app_conversation_info( + UUID(self.conversation_id) + ), + UUID(self.conversation_id), + ) + + # 2. Sandbox lookup + validation + sandbox = await sandbox_service.get_sandbox( + app_conversation_info.sandbox_id + ) + + if sandbox and sandbox.status == SandboxStatus.PAUSED: + # Resume paused sandbox and wait for it to be running + logger.info('[Slack V1]: Attempting to resume paused sandbox') + await sandbox_service.resume_sandbox(app_conversation_info.sandbox_id) + + # Wait for sandbox to be running (handles both fresh start and resume) + running_sandbox = await sandbox_service.wait_for_sandbox_running( + app_conversation_info.sandbox_id, + timeout=120, + poll_interval=2, + httpx_client=httpx_client, + ) + + assert ( + running_sandbox.session_api_key is not None + ), f'No session API key for sandbox: {running_sandbox.id}' + + # 3. Get the agent server URL + agent_server_url = get_agent_server_url_from_sandbox(running_sandbox) + + # 4. Prepare the message content + user_msg, _ = self._get_instructions(jinja) + + # 5. Create the message request + send_message_request = SendMessageRequest( + role='user', content=[TextContent(text=user_msg)], run=True + ) + + # 6. Send the message to the agent server + url = f'{agent_server_url.rstrip("/")}/api/conversations/{UUID(self.conversation_id)}/events' + + headers = {'X-Session-API-Key': running_sandbox.session_api_key} + payload = send_message_request.model_dump() + + try: + response = await httpx_client.post( + url, + json=payload, + headers=headers, + timeout=30.0, + ) + response.raise_for_status() + + except Exception as e: + logger.error( + '[Slack V1] Failed to send message to conversation %s: %s', + self.conversation_id, + str(e), + exc_info=True, + ) + raise Exception(f'Failed to send message to v1 conversation: {str(e)}') + + async def create_or_update_conversation(self, jinja: Environment) -> str: + """Send new user message to converation""" + user_info: SlackUser = self.slack_to_openhands_user + + user_id = user_info.keycloak_user_id + + # Org management in the future will get rid of this + # For now, only user that created the conversation can send follow up messages to it + if user_id != self.slack_conversation.keycloak_user_id: + raise StartingConvoException( + f'{user_info.slack_display_name} is not authorized to send messages to this conversation.' + ) + + if self.slack_conversation.v1_enabled: + await self.send_message_to_v1_conversation(jinja) + else: + await self.send_message_to_v0_conversation(jinja) + return self.conversation_id def get_response_msg(self): @@ -361,7 +567,7 @@ class SlackFactory: 'channel_id': channel_id, }, ) - raise Exception('Did not slack team') + raise Exception('Did not find slack team') # Determine if this is a known slack user by openhands if not slack_user or not saas_user_auth or not channel_id: @@ -379,6 +585,7 @@ class SlackFactory: send_summary_instruction=False, conversation_id='', team_id=team_id, + v1_enabled=False, ) conversation: SlackConversation | None = call_async_from_sync( @@ -409,6 +616,7 @@ class SlackFactory: conversation_id=conversation.conversation_id, slack_conversation=conversation, team_id=team_id, + v1_enabled=False, ) elif SlackFactory.did_user_select_repo_from_form(message): @@ -426,6 +634,7 @@ class SlackFactory: send_summary_instruction=True, conversation_id='', team_id=team_id, + v1_enabled=False, ) else: @@ -443,4 +652,5 @@ class SlackFactory: send_summary_instruction=True, conversation_id='', team_id=team_id, + v1_enabled=False, ) diff --git a/enterprise/integrations/types.py b/enterprise/integrations/types.py index 0b8d79228c..c18acb25f6 100644 --- a/enterprise/integrations/types.py +++ b/enterprise/integrations/types.py @@ -45,7 +45,3 @@ class ResolverViewInterface(SummaryExtractionTracker): async def create_new_conversation(self, jinja_env: Environment, token: str): "Create a new conversation" raise NotImplementedError() - - def get_callback_id(self) -> str: - "Unique callback id for subscribription made to EventStream for fetching agent summary" - raise NotImplementedError() diff --git a/enterprise/integrations/utils.py b/enterprise/integrations/utils.py index 577f819724..d4ee8102bb 100644 --- a/enterprise/integrations/utils.py +++ b/enterprise/integrations/utils.py @@ -6,7 +6,9 @@ import re from typing import TYPE_CHECKING from jinja2 import Environment, FileSystemLoader +from server.config import get_config from server.constants import WEB_HOST +from storage.database import session_maker from storage.repository_store import RepositoryStore from storage.stored_repository import StoredRepository from storage.user_repo_map import UserRepositoryMap @@ -25,6 +27,7 @@ from openhands.events.event_store_abc import EventStoreABC from openhands.events.observation.agent import AgentStateChangedObservation from openhands.integrations.service_types import Repository from openhands.storage.data_models.conversation_status import ConversationStatus +from openhands.utils.async_utils import call_sync_from_async if TYPE_CHECKING: from openhands.server.conversation_manager.conversation_manager import ( @@ -36,7 +39,7 @@ if TYPE_CHECKING: HOST = WEB_HOST # ---- DO NOT REMOVE ---- -HOST_URL = f'https://{HOST}' +HOST_URL = f'https://{HOST}' if 'localhost' not in HOST else f'http://{HOST}' GITHUB_WEBHOOK_URL = f'{HOST_URL}/integration/github/events' GITLAB_WEBHOOK_URL = f'{HOST_URL}/integration/gitlab/events' conversation_prefix = 'conversations/{}' @@ -78,6 +81,9 @@ ENABLE_V1_GITHUB_RESOLVER = ( os.getenv('ENABLE_V1_GITHUB_RESOLVER', 'false').lower() == 'true' ) +ENABLE_V1_SLACK_RESOLVER = ( + os.getenv('ENABLE_V1_SLACK_RESOLVER', 'false').lower() == 'true' +) OPENHANDS_RESOLVER_TEMPLATES_DIR = ( os.getenv('OPENHANDS_RESOLVER_TEMPLATES_DIR') @@ -110,6 +116,37 @@ def get_summary_instruction(): return summary_instruction +async def get_user_v1_enabled_setting(user_id: str | None) -> bool: + """Get the user's V1 conversation API setting. + + Args: + user_id: The keycloak user ID + + Returns: + True if V1 conversations are enabled for this user, False otherwise + """ + + # If no user ID is provided, we can't check user settings + if not user_id: + return False + + from storage.saas_settings_store import SaasSettingsStore + + config = get_config() + settings_store = SaasSettingsStore( + user_id=user_id, session_maker=session_maker, config=config + ) + + settings = await call_sync_from_async( + settings_store.get_user_settings_by_keycloak_id, user_id + ) + + if not settings or settings.v1_enabled is None: + return False + + return settings.v1_enabled + + def has_exact_mention(text: str, mention: str) -> bool: """Check if the text contains an exact mention (not part of a larger word). diff --git a/enterprise/migrations/versions/086_add_v1_column_to_slack_conversation.py b/enterprise/migrations/versions/086_add_v1_column_to_slack_conversation.py new file mode 100644 index 0000000000..2d44f534af --- /dev/null +++ b/enterprise/migrations/versions/086_add_v1_column_to_slack_conversation.py @@ -0,0 +1,30 @@ +"""add v1 column to slack conversation table + +Revision ID: 086 +Revises: 085 +Create Date: 2025-12-02 15:30:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '086' +down_revision: Union[str, None] = '085' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Add v1 column + op.add_column( + 'slack_conversation', sa.Column('v1_enabled', sa.Boolean(), nullable=True) + ) + + +def downgrade() -> None: + # Drop v1 column + op.drop_column('slack_conversation', 'v1_enabled') diff --git a/enterprise/server/clustered_conversation_manager.py b/enterprise/server/clustered_conversation_manager.py index 1eae6d19da..3129849b1a 100644 --- a/enterprise/server/clustered_conversation_manager.py +++ b/enterprise/server/clustered_conversation_manager.py @@ -8,7 +8,6 @@ import socketio from server.logger import logger from server.utils.conversation_callback_utils import invoke_conversation_callbacks from storage.database import session_maker -from storage.saas_settings_store import SaasSettingsStore from storage.stored_conversation_metadata import StoredConversationMetadata from openhands.core.config import LLMConfig @@ -743,6 +742,8 @@ class ClusteredConversationManager(StandaloneConversationManager): return # Restart the agent loop + from storage.saas_settings_store import SaasSettingsStore + config = load_openhands_config() settings_store = await SaasSettingsStore.get_instance(config, user_id) settings = await settings_store.load() diff --git a/enterprise/server/routes/integration/slack.py b/enterprise/server/routes/integration/slack.py index 22afc64659..9e16995040 100644 --- a/enterprise/server/routes/integration/slack.py +++ b/enterprise/server/routes/integration/slack.py @@ -107,7 +107,7 @@ async def install_callback( # Redirect into keycloak scope = quote('openid email profile offline_access') - redirect_uri = quote(f'{HOST_URL}/slack/keycloak-callback') + redirect_uri = f'{HOST_URL}/slack/keycloak-callback' auth_url = ( f'{KEYCLOAK_SERVER_URL_EXT}/realms/{KEYCLOAK_REALM_NAME}/protocol/openid-connect/auth' f'?client_id={KEYCLOAK_CLIENT_ID}&response_type=code' @@ -158,7 +158,7 @@ async def keycloak_callback( team_id = payload['team_id'] # Retrieve the keycloak_user_id - redirect_uri = f'https://{request.url.netloc}{request.url.path}' + redirect_uri = f'{HOST_URL}{request.url.path}' ( keycloak_access_token, keycloak_refresh_token, diff --git a/enterprise/storage/slack_conversation.py b/enterprise/storage/slack_conversation.py index d2cea4e7a5..616ed5d9e6 100644 --- a/enterprise/storage/slack_conversation.py +++ b/enterprise/storage/slack_conversation.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, Identity, Integer, String +from sqlalchemy import Boolean, Column, Identity, Integer, String from storage.base import Base @@ -9,3 +9,4 @@ class SlackConversation(Base): # type: ignore channel_id = Column(String, nullable=False) keycloak_user_id = Column(String, nullable=False) parent_id = Column(String, nullable=True, index=True) + v1_enabled = Column(Boolean, nullable=True) diff --git a/enterprise/storage/slack_conversation_store.py b/enterprise/storage/slack_conversation_store.py index 2d859ee62c..7ac156c082 100644 --- a/enterprise/storage/slack_conversation_store.py +++ b/enterprise/storage/slack_conversation_store.py @@ -14,8 +14,7 @@ class SlackConversationStore: async def get_slack_conversation( self, channel_id: str, parent_id: str ) -> SlackConversation | None: - """ - Get a slack conversation by channel_id and message_ts. + """Get a slack conversation by channel_id and message_ts. Both parameters are required to match for a conversation to be returned. """ with session_maker() as session: diff --git a/enterprise/tests/unit/integrations/slack/test_slack_v1_callback_processor.py b/enterprise/tests/unit/integrations/slack/test_slack_v1_callback_processor.py new file mode 100644 index 0000000000..6aa03c408d --- /dev/null +++ b/enterprise/tests/unit/integrations/slack/test_slack_v1_callback_processor.py @@ -0,0 +1,431 @@ +"""Tests for the SlackV1CallbackProcessor. + +Focuses on high-impact scenarios: +- Double callback processing (main requirement) +- Event filtering +- Error handling for critical failures +- Successful end-to-end flow +""" + +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import httpx +import pytest +from integrations.slack.slack_v1_callback_processor import ( + SlackV1CallbackProcessor, +) + +from openhands.app_server.app_conversation.app_conversation_models import ( + AppConversationInfo, +) +from openhands.app_server.event_callback.event_callback_models import EventCallback +from openhands.app_server.event_callback.event_callback_result_models import ( + EventCallbackResultStatus, +) +from openhands.app_server.sandbox.sandbox_models import ( + ExposedUrl, + SandboxInfo, + SandboxStatus, +) +from openhands.events.action.message import MessageAction +from openhands.sdk.event import ConversationStateUpdateEvent + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def slack_callback_processor(): + return SlackV1CallbackProcessor( + slack_view_data={ + 'channel_id': 'C1234567890', + 'message_ts': '1234567890.123456', + 'team_id': 'T1234567890', + } + ) + + +@pytest.fixture +def finish_event(): + return ConversationStateUpdateEvent(key='execution_status', value='finished') + + +@pytest.fixture +def event_callback(): + return EventCallback( + id=uuid4(), + conversation_id=uuid4(), + processor=SlackV1CallbackProcessor(), + event_kind='ConversationStateUpdateEvent', + ) + + +@pytest.fixture +def mock_app_conversation_info(): + return AppConversationInfo( + id=uuid4(), + created_by_user_id='test-user-123', + sandbox_id=str(uuid4()), + title='Test Conversation', + ) + + +@pytest.fixture +def mock_sandbox_info(): + return SandboxInfo( + id=str(uuid4()), + created_by_user_id='test-user-123', + sandbox_spec_id='test-spec-123', + status=SandboxStatus.RUNNING, + session_api_key='test-session-key', + exposed_urls=[ + ExposedUrl( + url='http://localhost:8000', + name='AGENT_SERVER', + port=8000, + ) + ], + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestSlackV1CallbackProcessor: + """Test the SlackV1CallbackProcessor class with focus on high-impact scenarios.""" + + # ------------------------------------------------------------------------- + # Event filtering tests (parameterized) + # ------------------------------------------------------------------------- + + @pytest.mark.parametrize( + 'event,expected_result', + [ + # Wrong event types should be ignored + (MessageAction(content='Hello world'), None), + # Wrong state values should be ignored + ( + ConversationStateUpdateEvent(key='execution_status', value='running'), + None, + ), + ( + ConversationStateUpdateEvent(key='execution_status', value='started'), + None, + ), + (ConversationStateUpdateEvent(key='other_key', value='finished'), None), + ], + ) + async def test_event_filtering( + self, slack_callback_processor, event_callback, event, expected_result + ): + """Test that processor correctly filters events.""" + result = await slack_callback_processor(uuid4(), event_callback, event) + assert result == expected_result + + # ------------------------------------------------------------------------- + # Double callback processing (main requirement) + # ------------------------------------------------------------------------- + + @patch('storage.slack_team_store.SlackTeamStore.get_instance') + @patch('integrations.slack.slack_v1_callback_processor.WebClient') + @patch.object(SlackV1CallbackProcessor, '_request_summary') + async def test_double_callback_processing( + self, + mock_request_summary, + mock_web_client, + mock_slack_team_store, + slack_callback_processor, + finish_event, + event_callback, + ): + """Test that processor handles double callback correctly and processes both times.""" + conversation_id = uuid4() + + # Mock SlackTeamStore + mock_store = MagicMock() + mock_store.get_team_bot_token.return_value = 'xoxb-test-token' + mock_slack_team_store.return_value = mock_store + + # Mock successful summary generation + mock_request_summary.return_value = 'Test summary from agent' + + # Mock Slack WebClient + mock_slack_client = MagicMock() + mock_slack_client.chat_postMessage.return_value = {'ok': True} + mock_web_client.return_value = mock_slack_client + + # First callback + result1 = await slack_callback_processor( + conversation_id, event_callback, finish_event + ) + + # Second callback (should not exit, should process again) + result2 = await slack_callback_processor( + conversation_id, event_callback, finish_event + ) + + # Verify both callbacks succeeded + assert result1 is not None + assert result1.status == EventCallbackResultStatus.SUCCESS + assert result1.detail == 'Test summary from agent' + + assert result2 is not None + assert result2.status == EventCallbackResultStatus.SUCCESS + assert result2.detail == 'Test summary from agent' + + # Verify both callbacks triggered summary requests and Slack posts + assert mock_request_summary.call_count == 2 + assert mock_slack_client.chat_postMessage.call_count == 2 + + # ------------------------------------------------------------------------- + # Successful end-to-end flow + # ------------------------------------------------------------------------- + + @patch('storage.slack_team_store.SlackTeamStore.get_instance') + @patch('openhands.app_server.config.get_httpx_client') + @patch('openhands.app_server.config.get_sandbox_service') + @patch('openhands.app_server.config.get_app_conversation_info_service') + @patch('integrations.slack.slack_v1_callback_processor.get_summary_instruction') + @patch('integrations.slack.slack_v1_callback_processor.WebClient') + async def test_successful_end_to_end_flow( + self, + mock_web_client, + mock_get_summary_instruction, + mock_get_app_conversation_info_service, + mock_get_sandbox_service, + mock_get_httpx_client, + mock_slack_team_store, + slack_callback_processor, + finish_event, + event_callback, + mock_app_conversation_info, + mock_sandbox_info, + ): + """Test successful end-to-end callback execution.""" + conversation_id = uuid4() + + # Mock SlackTeamStore + mock_store = MagicMock() + mock_store.get_team_bot_token.return_value = 'xoxb-test-token' + mock_slack_team_store.return_value = mock_store + + # Mock summary instruction + mock_get_summary_instruction.return_value = 'Please provide a summary' + + # Mock services + mock_app_conversation_info_service = AsyncMock() + mock_app_conversation_info_service.get_app_conversation_info.return_value = ( + mock_app_conversation_info + ) + mock_get_app_conversation_info_service.return_value.__aenter__.return_value = ( + mock_app_conversation_info_service + ) + + mock_sandbox_service = AsyncMock() + mock_sandbox_service.get_sandbox.return_value = mock_sandbox_info + mock_get_sandbox_service.return_value.__aenter__.return_value = ( + mock_sandbox_service + ) + + mock_httpx_client = AsyncMock() + mock_response = MagicMock() + mock_response.json.return_value = {'response': 'Test summary from agent'} + mock_response.raise_for_status = MagicMock() + mock_httpx_client.post.return_value = mock_response + mock_get_httpx_client.return_value.__aenter__.return_value = mock_httpx_client + + # Mock Slack WebClient + mock_slack_client = MagicMock() + mock_slack_client.chat_postMessage.return_value = {'ok': True} + mock_web_client.return_value = mock_slack_client + + # Execute + result = await slack_callback_processor( + conversation_id, event_callback, finish_event + ) + + # Verify result + assert result is not None + assert result.status == EventCallbackResultStatus.SUCCESS + assert result.conversation_id == conversation_id + assert result.detail == 'Test summary from agent' + + # Verify Slack posting + mock_slack_client.chat_postMessage.assert_called_once_with( + channel='C1234567890', + text='Test summary from agent', + thread_ts='1234567890.123456', + unfurl_links=False, + unfurl_media=False, + ) + + # ------------------------------------------------------------------------- + # Error handling tests (parameterized) + # ------------------------------------------------------------------------- + + @pytest.mark.parametrize( + 'bot_token,expected_error', + [ + (None, 'Missing Slack bot access token'), + ('', 'Missing Slack bot access token'), + ], + ) + @patch('storage.slack_team_store.SlackTeamStore.get_instance') + @patch.object(SlackV1CallbackProcessor, '_request_summary') + async def test_missing_bot_token_scenarios( + self, + mock_request_summary, + mock_slack_team_store, + slack_callback_processor, + finish_event, + event_callback, + bot_token, + expected_error, + ): + """Test error handling when bot access token is missing or empty.""" + # Mock SlackTeamStore to return the test token + mock_store = MagicMock() + mock_store.get_team_bot_token.return_value = bot_token + mock_slack_team_store.return_value = mock_store + + # Mock successful summary generation + mock_request_summary.return_value = 'Test summary' + + result = await slack_callback_processor(uuid4(), event_callback, finish_event) + + assert result is not None + assert result.status == EventCallbackResultStatus.ERROR + assert expected_error in result.detail + + @pytest.mark.parametrize( + 'slack_response,expected_error', + [ + ( + {'ok': False, 'error': 'channel_not_found'}, + 'Slack API error: channel_not_found', + ), + ({'ok': False, 'error': 'invalid_auth'}, 'Slack API error: invalid_auth'), + ({'ok': False}, 'Slack API error: Unknown error'), + ], + ) + @patch('storage.slack_team_store.SlackTeamStore.get_instance') + @patch('integrations.slack.slack_v1_callback_processor.WebClient') + @patch.object(SlackV1CallbackProcessor, '_request_summary') + async def test_slack_api_error_scenarios( + self, + mock_request_summary, + mock_web_client, + mock_slack_team_store, + slack_callback_processor, + finish_event, + event_callback, + slack_response, + expected_error, + ): + """Test error handling for various Slack API errors.""" + # Mock SlackTeamStore + mock_store = MagicMock() + mock_store.get_team_bot_token.return_value = 'xoxb-test-token' + mock_slack_team_store.return_value = mock_store + + # Mock successful summary generation + mock_request_summary.return_value = 'Test summary' + + # Mock Slack WebClient with error response + mock_slack_client = MagicMock() + mock_slack_client.chat_postMessage.return_value = slack_response + mock_web_client.return_value = mock_slack_client + + result = await slack_callback_processor(uuid4(), event_callback, finish_event) + + assert result is not None + assert result.status == EventCallbackResultStatus.ERROR + assert expected_error in result.detail + + @pytest.mark.parametrize( + 'exception,expected_error_fragment', + [ + ( + httpx.TimeoutException('Request timeout'), + 'Request timeout after 30 seconds', + ), + ( + httpx.HTTPStatusError( + 'Server error', + request=MagicMock(), + response=MagicMock( + status_code=500, text='Internal Server Error', headers={} + ), + ), + 'Failed to send message to agent server', + ), + ( + httpx.RequestError('Connection error'), + 'Request error', + ), + ], + ) + @patch('storage.slack_team_store.SlackTeamStore.get_instance') + @patch('openhands.app_server.config.get_httpx_client') + @patch('openhands.app_server.config.get_sandbox_service') + @patch('openhands.app_server.config.get_app_conversation_info_service') + @patch('integrations.slack.slack_v1_callback_processor.get_summary_instruction') + async def test_agent_server_error_scenarios( + self, + mock_get_summary_instruction, + mock_get_app_conversation_info_service, + mock_get_sandbox_service, + mock_get_httpx_client, + mock_slack_team_store, + slack_callback_processor, + finish_event, + event_callback, + mock_app_conversation_info, + mock_sandbox_info, + exception, + expected_error_fragment, + ): + """Test error handling for various agent server errors.""" + conversation_id = uuid4() + + # Mock SlackTeamStore + mock_store = MagicMock() + mock_store.get_team_bot_token.return_value = 'xoxb-test-token' + mock_slack_team_store.return_value = mock_store + + # Mock summary instruction + mock_get_summary_instruction.return_value = 'Please provide a summary' + + # Mock services + mock_app_conversation_info_service = AsyncMock() + mock_app_conversation_info_service.get_app_conversation_info.return_value = ( + mock_app_conversation_info + ) + mock_get_app_conversation_info_service.return_value.__aenter__.return_value = ( + mock_app_conversation_info_service + ) + + mock_sandbox_service = AsyncMock() + mock_sandbox_service.get_sandbox.return_value = mock_sandbox_info + mock_get_sandbox_service.return_value.__aenter__.return_value = ( + mock_sandbox_service + ) + + # Mock HTTP client with the specified exception + mock_httpx_client = AsyncMock() + mock_httpx_client.post.side_effect = exception + mock_get_httpx_client.return_value.__aenter__.return_value = mock_httpx_client + + # Execute + result = await slack_callback_processor( + conversation_id, event_callback, finish_event + ) + + # Verify error result + assert result is not None + assert result.status == EventCallbackResultStatus.ERROR + assert expected_error_fragment in result.detail diff --git a/enterprise/tests/unit/integrations/slack/test_slack_view.py b/enterprise/tests/unit/integrations/slack/test_slack_view.py new file mode 100644 index 0000000000..99aa60009a --- /dev/null +++ b/enterprise/tests/unit/integrations/slack/test_slack_view.py @@ -0,0 +1,341 @@ +"""Tests for the Slack view classes and their v1 vs v0 conversation handling. + +Focuses on the 3 essential scenarios: +1. V1 vs V0 decision logic based on user setting +2. Message routing to correct method based on conversation v1 flag +3. Paused sandbox resumption for V1 conversations +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from integrations.slack.slack_view import ( + SlackNewConversationView, + SlackUpdateExistingConversationView, +) +from jinja2 import DictLoader, Environment +from storage.slack_conversation import SlackConversation +from storage.slack_user import SlackUser + +from openhands.app_server.sandbox.sandbox_models import SandboxStatus +from openhands.server.user_auth.user_auth import UserAuth + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_jinja_env(): + """Create a mock Jinja environment with test templates.""" + templates = { + 'user_message_conversation_instructions.j2': 'Previous messages: {{ messages|join(", ") }}\nUser: {{ username }}\nURL: {{ conversation_url }}' + } + return Environment(loader=DictLoader(templates)) + + +@pytest.fixture +def mock_slack_user(): + """Create a mock SlackUser.""" + user = SlackUser() + user.slack_user_id = 'U1234567890' + user.keycloak_user_id = 'test-user-123' + user.slack_display_name = 'Test User' + return user + + +@pytest.fixture +def mock_user_auth(): + """Create a mock UserAuth.""" + auth = MagicMock(spec=UserAuth) + auth.get_provider_tokens = AsyncMock(return_value={}) + auth.get_secrets = AsyncMock(return_value=MagicMock(custom_secrets={})) + return auth + + +@pytest.fixture +def slack_new_conversation_view(mock_slack_user, mock_user_auth): + """Create a SlackNewConversationView instance.""" + return SlackNewConversationView( + bot_access_token='xoxb-test-token', + user_msg='Hello OpenHands!', + slack_user_id='U1234567890', + slack_to_openhands_user=mock_slack_user, + saas_user_auth=mock_user_auth, + channel_id='C1234567890', + message_ts='1234567890.123456', + thread_ts=None, + selected_repo='owner/repo', + should_extract=True, + send_summary_instruction=True, + conversation_id='', + team_id='T1234567890', + v1_enabled=False, + ) + + +@pytest.fixture +def slack_update_conversation_view_v0(mock_slack_user, mock_user_auth): + """Create a SlackUpdateExistingConversationView instance for V0.""" + conversation_id = '87654321-4321-8765-4321-876543218765' + mock_conversation = SlackConversation( + conversation_id=conversation_id, + channel_id='C1234567890', + keycloak_user_id='test-user-123', + parent_id='1234567890.123456', + v1_enabled=False, + ) + return SlackUpdateExistingConversationView( + bot_access_token='xoxb-test-token', + user_msg='Follow up message', + slack_user_id='U1234567890', + slack_to_openhands_user=mock_slack_user, + saas_user_auth=mock_user_auth, + channel_id='C1234567890', + message_ts='1234567890.123457', + thread_ts='1234567890.123456', + selected_repo=None, + should_extract=True, + send_summary_instruction=True, + conversation_id=conversation_id, + slack_conversation=mock_conversation, + team_id='T1234567890', + v1_enabled=False, + ) + + +@pytest.fixture +def slack_update_conversation_view_v1(mock_slack_user, mock_user_auth): + """Create a SlackUpdateExistingConversationView instance for V1.""" + conversation_id = '12345678-1234-5678-1234-567812345678' + mock_conversation = SlackConversation( + conversation_id=conversation_id, + channel_id='C1234567890', + keycloak_user_id='test-user-123', + parent_id='1234567890.123456', + v1_enabled=True, + ) + return SlackUpdateExistingConversationView( + bot_access_token='xoxb-test-token', + user_msg='Follow up message', + slack_user_id='U1234567890', + slack_to_openhands_user=mock_slack_user, + saas_user_auth=mock_user_auth, + channel_id='C1234567890', + message_ts='1234567890.123457', + thread_ts='1234567890.123456', + selected_repo=None, + should_extract=True, + send_summary_instruction=True, + conversation_id=conversation_id, + slack_conversation=mock_conversation, + team_id='T1234567890', + v1_enabled=True, + ) + + +# --------------------------------------------------------------------------- +# Test 1: V1 vs V0 Decision Logic Based on User Setting +# --------------------------------------------------------------------------- + + +class TestV1V0DecisionLogic: + """Test the decision logic for choosing between V1 and V0 conversations based on user setting.""" + + @pytest.mark.parametrize( + 'v1_enabled,expected_v1_flag', + [ + (True, True), # V1 enabled, use V1 + (False, False), # V1 disabled, use V0 + ], + ) + @patch('integrations.slack.slack_view.is_v1_enabled_for_slack_resolver') + @patch.object(SlackNewConversationView, '_create_v1_conversation') + @patch.object(SlackNewConversationView, '_create_v0_conversation') + async def test_v1_v0_decision_logic( + self, + mock_create_v0, + mock_create_v1, + mock_is_v1_enabled, + slack_new_conversation_view, + mock_jinja_env, + v1_enabled, + expected_v1_flag, + ): + """Test the decision logic for V1 vs V0 conversation creation based on user setting.""" + # Setup mocks + mock_is_v1_enabled.return_value = v1_enabled + mock_create_v1.return_value = None + mock_create_v0.return_value = None + + # Execute + result = await slack_new_conversation_view.create_or_update_conversation( + mock_jinja_env + ) + + # Verify + assert result == slack_new_conversation_view.conversation_id + assert slack_new_conversation_view.v1_enabled == expected_v1_flag + + if v1_enabled: + mock_create_v1.assert_called_once() + mock_create_v0.assert_not_called() + else: + mock_create_v1.assert_not_called() + mock_create_v0.assert_called_once() + + +# --------------------------------------------------------------------------- +# Test 2: Message Routing Based on Conversation V1 Flag +# --------------------------------------------------------------------------- + + +class TestMessageRouting: + """Test that message sending routes to correct method based on conversation v1 flag.""" + + @patch.object( + SlackUpdateExistingConversationView, 'send_message_to_v1_conversation' + ) + @patch.object( + SlackUpdateExistingConversationView, 'send_message_to_v0_conversation' + ) + async def test_message_routing_to_v1( + self, + mock_send_v0, + mock_send_v1, + slack_update_conversation_view_v1, + mock_jinja_env, + ): + """Test that V1 conversations route to V1 message sending method.""" + # Setup + mock_send_v0.return_value = None + mock_send_v1.return_value = None + + # Execute + result = await slack_update_conversation_view_v1.create_or_update_conversation( + mock_jinja_env + ) + + # Verify + assert result == slack_update_conversation_view_v1.conversation_id + mock_send_v1.assert_called_once_with(mock_jinja_env) + mock_send_v0.assert_not_called() + + @patch.object( + SlackUpdateExistingConversationView, 'send_message_to_v1_conversation' + ) + @patch.object( + SlackUpdateExistingConversationView, 'send_message_to_v0_conversation' + ) + async def test_message_routing_to_v0( + self, + mock_send_v0, + mock_send_v1, + slack_update_conversation_view_v0, + mock_jinja_env, + ): + """Test that V0 conversations route to V0 message sending method.""" + # Setup + mock_send_v0.return_value = None + mock_send_v1.return_value = None + + # Execute + result = await slack_update_conversation_view_v0.create_or_update_conversation( + mock_jinja_env + ) + + # Verify + assert result == slack_update_conversation_view_v0.conversation_id + mock_send_v0.assert_called_once_with(mock_jinja_env) + mock_send_v1.assert_not_called() + + +# --------------------------------------------------------------------------- +# Test 3: Paused Sandbox Resumption for V1 Conversations +# --------------------------------------------------------------------------- + + +class TestPausedSandboxResumption: + """Test that paused sandboxes are resumed when sending messages to V1 conversations.""" + + @patch('openhands.app_server.config.get_sandbox_service') + @patch('openhands.app_server.config.get_app_conversation_info_service') + @patch('openhands.app_server.config.get_httpx_client') + @patch('openhands.app_server.event_callback.util.ensure_running_sandbox') + @patch('openhands.app_server.event_callback.util.get_agent_server_url_from_sandbox') + @patch.object(SlackUpdateExistingConversationView, '_get_instructions') + async def test_paused_sandbox_resumption( + self, + mock_get_instructions, + mock_get_agent_server_url, + mock_ensure_running_sandbox, + mock_get_httpx_client, + mock_get_app_info_service, + mock_get_sandbox_service, + slack_update_conversation_view_v1, + mock_jinja_env, + ): + """Test that paused sandboxes are resumed when sending messages to V1 conversations.""" + # Setup mocks + mock_get_instructions.return_value = ('User message', '') + + # Mock app conversation info service + mock_app_info_service = AsyncMock() + mock_app_info = MagicMock() + mock_app_info.sandbox_id = 'sandbox-123' + mock_app_info_service.get_app_conversation_info.return_value = mock_app_info + mock_get_app_info_service.return_value.__aenter__.return_value = ( + mock_app_info_service + ) + + # Mock sandbox service with paused sandbox that gets resumed + mock_sandbox_service = AsyncMock() + mock_paused_sandbox = MagicMock() + mock_paused_sandbox.status = SandboxStatus.PAUSED + mock_paused_sandbox.session_api_key = 'test-api-key' + mock_paused_sandbox.exposed_urls = [ + MagicMock(name='AGENT_SERVER', url='http://localhost:8000') + ] + + # After resume, sandbox becomes running + mock_running_sandbox = MagicMock() + mock_running_sandbox.status = SandboxStatus.RUNNING + mock_running_sandbox.session_api_key = 'test-api-key' + mock_running_sandbox.exposed_urls = [ + MagicMock(name='AGENT_SERVER', url='http://localhost:8000') + ] + + mock_sandbox_service.get_sandbox.side_effect = [ + mock_paused_sandbox, + mock_running_sandbox, + ] + mock_sandbox_service.resume_sandbox = AsyncMock() + mock_get_sandbox_service.return_value.__aenter__.return_value = ( + mock_sandbox_service + ) + + # Mock ensure_running_sandbox to first raise RuntimeError, then return running sandbox + mock_ensure_running_sandbox.side_effect = [ + RuntimeError('Sandbox not running: sandbox-123'), + mock_running_sandbox, + ] + + # Mock agent server URL + mock_get_agent_server_url.return_value = 'http://localhost:8000' + + # Mock HTTP client + mock_httpx_client = AsyncMock() + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_httpx_client.post.return_value = mock_response + mock_get_httpx_client.return_value.__aenter__.return_value = mock_httpx_client + + # Execute + await slack_update_conversation_view_v1.send_message_to_v1_conversation( + mock_jinja_env + ) + + # Verify sandbox was resumed + mock_sandbox_service.resume_sandbox.assert_called_once_with('sandbox-123') + mock_httpx_client.post.assert_called_once() + mock_response.raise_for_status.assert_called_once() diff --git a/enterprise/tests/unit/test_get_user_v1_enabled_setting.py b/enterprise/tests/unit/test_get_user_v1_enabled_setting.py index bbed7b8ba0..5fd2fe14a1 100644 --- a/enterprise/tests/unit/test_get_user_v1_enabled_setting.py +++ b/enterprise/tests/unit/test_get_user_v1_enabled_setting.py @@ -1,10 +1,9 @@ """Unit tests for get_user_v1_enabled_setting function.""" -import os -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest -from integrations.github.github_view import get_user_v1_enabled_setting +from integrations.utils import get_user_v1_enabled_setting @pytest.fixture @@ -16,10 +15,9 @@ def mock_user_settings(): @pytest.fixture -def mock_settings_store(mock_user_settings): +def mock_settings_store(): """Create a mock settings store.""" store = MagicMock() - store.get_user_settings_by_keycloak_id = AsyncMock(return_value=mock_user_settings) return store @@ -40,15 +38,16 @@ def mock_dependencies( mock_settings_store, mock_config, mock_session_maker, mock_user_settings ): """Fixture that patches all the common dependencies.""" + # Patch at the source module since SaasSettingsStore is imported inside the function with patch( - 'integrations.github.github_view.SaasSettingsStore', + 'storage.saas_settings_store.SaasSettingsStore', return_value=mock_settings_store, ) as mock_store_class, patch( - 'integrations.github.github_view.get_config', return_value=mock_config + 'integrations.utils.get_config', return_value=mock_config ) as mock_get_config, patch( - 'integrations.github.github_view.session_maker', mock_session_maker + 'integrations.utils.session_maker', mock_session_maker ), patch( - 'integrations.github.github_view.call_sync_from_async', + 'integrations.utils.call_sync_from_async', return_value=mock_user_settings, ) as mock_call_sync: yield { @@ -66,67 +65,64 @@ class TestGetUserV1EnabledSetting: @pytest.mark.asyncio @pytest.mark.parametrize( - 'env_var_enabled,user_setting_enabled,expected_result', + 'user_setting_enabled,expected_result', [ - (False, True, False), # Env var disabled, user enabled -> False - (True, False, False), # Env var enabled, user disabled -> False - (True, True, True), # Both enabled -> True - (False, False, False), # Both disabled -> False + (True, True), # User enabled -> True + (False, False), # User disabled -> False ], ) - async def test_v1_enabled_combinations( - self, mock_dependencies, env_var_enabled, user_setting_enabled, expected_result + async def test_v1_enabled_user_setting( + self, mock_dependencies, user_setting_enabled, expected_result ): - """Test all combinations of environment variable and user setting values.""" + """Test that the function returns the user's v1_enabled setting.""" mock_dependencies['user_settings'].v1_enabled = user_setting_enabled - with patch( - 'integrations.github.github_view.ENABLE_V1_GITHUB_RESOLVER', env_var_enabled - ): - result = await get_user_v1_enabled_setting('test_user_id') - assert result is expected_result + result = await get_user_v1_enabled_setting('test_user_id') + assert result is expected_result @pytest.mark.asyncio - @pytest.mark.parametrize( - 'env_var_value,env_var_bool,expected_result', - [ - ('false', False, False), # Environment variable 'false' -> False - ('true', True, True), # Environment variable 'true' -> True - ], - ) - async def test_environment_variable_integration( - self, mock_dependencies, env_var_value, env_var_bool, expected_result - ): - """Test that the function properly reads the ENABLE_V1_GITHUB_RESOLVER environment variable.""" - mock_dependencies['user_settings'].v1_enabled = True + async def test_returns_false_when_no_user_id(self): + """Test that the function returns False when no user_id is provided.""" + result = await get_user_v1_enabled_setting(None) + assert result is False - with patch.dict( - os.environ, {'ENABLE_V1_GITHUB_RESOLVER': env_var_value} - ), patch('integrations.utils.os.getenv', return_value=env_var_value), patch( - 'integrations.github.github_view.ENABLE_V1_GITHUB_RESOLVER', env_var_bool - ): - result = await get_user_v1_enabled_setting('test_user_id') - assert result is expected_result + result = await get_user_v1_enabled_setting('') + assert result is False + + @pytest.mark.asyncio + async def test_returns_false_when_settings_is_none(self, mock_dependencies): + """Test that the function returns False when settings is None.""" + mock_dependencies['call_sync'].return_value = None + + result = await get_user_v1_enabled_setting('test_user_id') + assert result is False + + @pytest.mark.asyncio + async def test_returns_false_when_v1_enabled_is_none(self, mock_dependencies): + """Test that the function returns False when v1_enabled is None.""" + mock_dependencies['user_settings'].v1_enabled = None + + result = await get_user_v1_enabled_setting('test_user_id') + assert result is False @pytest.mark.asyncio async def test_function_calls_correct_methods(self, mock_dependencies): """Test that the function calls the correct methods with correct parameters.""" mock_dependencies['user_settings'].v1_enabled = True - with patch('integrations.github.github_view.ENABLE_V1_GITHUB_RESOLVER', True): - result = await get_user_v1_enabled_setting('test_user_123') + result = await get_user_v1_enabled_setting('test_user_123') - # Verify the result - assert result is True + # Verify the result + assert result is True - # Verify correct methods were called with correct parameters - mock_dependencies['get_config'].assert_called_once() - mock_dependencies['store_class'].assert_called_once_with( - user_id='test_user_123', - session_maker=mock_dependencies['session_maker'], - config=mock_dependencies['get_config'].return_value, - ) - mock_dependencies['call_sync'].assert_called_once_with( - mock_dependencies['settings_store'].get_user_settings_by_keycloak_id, - 'test_user_123', - ) + # Verify correct methods were called with correct parameters + mock_dependencies['get_config'].assert_called_once() + mock_dependencies['store_class'].assert_called_once_with( + user_id='test_user_123', + session_maker=mock_dependencies['session_maker'], + config=mock_dependencies['get_config'].return_value, + ) + mock_dependencies['call_sync'].assert_called_once_with( + mock_dependencies['settings_store'].get_user_settings_by_keycloak_id, + 'test_user_123', + ) diff --git a/openhands/app_server/app_conversation/app_conversation_models.py b/openhands/app_server/app_conversation/app_conversation_models.py index 34d9c46031..de0c8db769 100644 --- a/openhands/app_server/app_conversation/app_conversation_models.py +++ b/openhands/app_server/app_conversation/app_conversation_models.py @@ -5,7 +5,7 @@ from uuid import UUID, uuid4 from pydantic import BaseModel, Field -from openhands.agent_server.models import SendMessageRequest +from openhands.agent_server.models import OpenHandsModel, SendMessageRequest from openhands.agent_server.utils import OpenHandsUUID, utc_now from openhands.app_server.event_callback.event_callback_models import ( EventCallbackProcessor, @@ -14,7 +14,6 @@ from openhands.app_server.sandbox.sandbox_models import SandboxStatus from openhands.integrations.service_types import ProviderType from openhands.sdk.conversation.state import ConversationExecutionStatus from openhands.sdk.llm import MetricsSnapshot -from openhands.sdk.utils.models import OpenHandsModel from openhands.storage.data_models.conversation_metadata import ConversationTrigger diff --git a/openhands/app_server/app_conversation/live_status_app_conversation_service.py b/openhands/app_server/app_conversation/live_status_app_conversation_service.py index 0074477bed..25b8052f27 100644 --- a/openhands/app_server/app_conversation/live_status_app_conversation_service.py +++ b/openhands/app_server/app_conversation/live_status_app_conversation_service.py @@ -7,7 +7,6 @@ import zipfile from collections import defaultdict from dataclasses import dataclass from datetime import datetime, timedelta -from time import time from typing import Any, AsyncGenerator, Sequence from uuid import UUID, uuid4 @@ -477,7 +476,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase): self, task: AppConversationStartTask ) -> AsyncGenerator[AppConversationStartTask, None]: """Wait for sandbox to start and return info.""" - # Get the sandbox + # Get or create the sandbox if not task.request.sandbox_id: sandbox = await self.sandbox_service.start_sandbox() task.sandbox_id = sandbox.id @@ -489,45 +488,34 @@ class LiveStatusAppConversationService(AppConversationServiceBase): raise SandboxError(f'Sandbox not found: {task.request.sandbox_id}') sandbox = sandbox_info - # Update the listener + # Update the listener with sandbox info task.status = AppConversationStartTaskStatus.WAITING_FOR_SANDBOX task.sandbox_id = sandbox.id yield task + # Resume if paused if sandbox.status == SandboxStatus.PAUSED: await self.sandbox_service.resume_sandbox(sandbox.id) + + # Check for immediate error states if sandbox.status in (None, SandboxStatus.ERROR): raise SandboxError(f'Sandbox status: {sandbox.status}') - if sandbox.status == SandboxStatus.RUNNING: - # There are still bugs in the remote runtime - they report running while still just - # starting resulting in a race condition. Manually check that it is actually - # running. - if await self._check_agent_server_alive(sandbox): - return - if sandbox.status != SandboxStatus.STARTING: + + # For non-STARTING/RUNNING states (except PAUSED which we just resumed), fail fast + if sandbox.status not in ( + SandboxStatus.STARTING, + SandboxStatus.RUNNING, + SandboxStatus.PAUSED, + ): raise SandboxError(f'Sandbox not startable: {sandbox.id}') - start = time() - while time() - start <= self.sandbox_startup_timeout: - await asyncio.sleep(self.sandbox_startup_poll_frequency) - sandbox_info = await self.sandbox_service.get_sandbox(sandbox.id) - if sandbox_info is None: - raise SandboxError(f'Sandbox not found: {sandbox.id}') - if sandbox.status not in (SandboxStatus.STARTING, SandboxStatus.RUNNING): - raise SandboxError(f'Sandbox not startable: {sandbox.id}') - if sandbox_info.status == SandboxStatus.RUNNING: - # There are still bugs in the remote runtime - they report running while still just - # starting resulting in a race condition. Manually check that it is actually - # running. - if await self._check_agent_server_alive(sandbox_info): - return - raise SandboxError(f'Sandbox failed to start: {sandbox.id}') - - async def _check_agent_server_alive(self, sandbox_info: SandboxInfo) -> bool: - agent_server_url = self._get_agent_server_url(sandbox_info) - url = f'{agent_server_url.rstrip("/")}/alive' - response = await self.httpx_client.get(url) - return response.is_success + # Use shared wait_for_sandbox_running utility to poll for ready state + await self.sandbox_service.wait_for_sandbox_running( + sandbox.id, + timeout=self.sandbox_startup_timeout, + poll_interval=self.sandbox_startup_poll_frequency, + httpx_client=self.httpx_client, + ) def _get_agent_server_url(self, sandbox: SandboxInfo) -> str: """Get agent server url for running sandbox.""" diff --git a/openhands/app_server/sandbox/sandbox_service.py b/openhands/app_server/sandbox/sandbox_service.py index 45274975d7..7319e4f5c7 100644 --- a/openhands/app_server/sandbox/sandbox_service.py +++ b/openhands/app_server/sandbox/sandbox_service.py @@ -1,12 +1,20 @@ import asyncio +import time from abc import ABC, abstractmethod +import httpx + +from openhands.app_server.errors import SandboxError from openhands.app_server.sandbox.sandbox_models import ( + AGENT_SERVER, SandboxInfo, SandboxPage, SandboxStatus, ) from openhands.app_server.services.injector import Injector +from openhands.app_server.utils.docker_utils import ( + replace_localhost_hostname_for_docker, +) from openhands.sdk.utils.models import DiscriminatedUnionMixin from openhands.sdk.utils.paging import page_iterator @@ -56,6 +64,96 @@ class SandboxService(ABC): Return False if the sandbox did not exist. """ + async def wait_for_sandbox_running( + self, + sandbox_id: str, + timeout: int = 120, + poll_interval: int = 2, + httpx_client: httpx.AsyncClient | None = None, + ) -> SandboxInfo: + """Wait for a sandbox to reach RUNNING status with an alive agent server. + + This method polls the sandbox status until it reaches RUNNING state and + optionally verifies the agent server is responding to health checks. + + Args: + sandbox_id: The sandbox ID to wait for + timeout: Maximum time to wait in seconds (default: 120) + poll_interval: Time between status checks in seconds (default: 2) + httpx_client: Optional httpx client for agent server health checks. + If provided, will verify the agent server /alive endpoint responds + before returning. + + Returns: + SandboxInfo with RUNNING status and verified agent server + + Raises: + SandboxError: If sandbox not found, enters ERROR state, or times out + """ + start = time.time() + while time.time() - start <= timeout: + sandbox = await self.get_sandbox(sandbox_id) + if sandbox is None: + raise SandboxError(f'Sandbox not found: {sandbox_id}') + + if sandbox.status == SandboxStatus.ERROR: + raise SandboxError(f'Sandbox entered error state: {sandbox_id}') + + if sandbox.status == SandboxStatus.RUNNING: + # Optionally verify agent server is alive to avoid race conditions + # where sandbox reports RUNNING but agent server isn't ready yet + if httpx_client and sandbox.exposed_urls: + if await self._check_agent_server_alive(sandbox, httpx_client): + return sandbox + # Agent server not ready yet, continue polling + else: + return sandbox + + await asyncio.sleep(poll_interval) + + raise SandboxError(f'Sandbox failed to start within {timeout}s: {sandbox_id}') + + async def _check_agent_server_alive( + self, sandbox: SandboxInfo, httpx_client: httpx.AsyncClient + ) -> bool: + """Check if the agent server is responding to health checks. + + Args: + sandbox: The sandbox info containing exposed URLs + httpx_client: HTTP client to make the health check request + + Returns: + True if agent server is alive, False otherwise + """ + try: + agent_server_url = self._get_agent_server_url(sandbox) + url = f'{agent_server_url.rstrip("/")}/alive' + response = await httpx_client.get(url, timeout=5.0) + return response.is_success + except Exception: + return False + + def _get_agent_server_url(self, sandbox: SandboxInfo) -> str: + """Get agent server URL from sandbox exposed URLs. + + Args: + sandbox: The sandbox info containing exposed URLs + + Returns: + The agent server URL + + Raises: + SandboxError: If no agent server URL is found + """ + if not sandbox.exposed_urls: + raise SandboxError(f'No exposed URLs for sandbox: {sandbox.id}') + + for exposed_url in sandbox.exposed_urls: + if exposed_url.name == AGENT_SERVER: + return replace_localhost_hostname_for_docker(exposed_url.url) + + raise SandboxError(f'No agent server URL found for sandbox: {sandbox.id}') + @abstractmethod async def pause_sandbox(self, sandbox_id: str) -> bool: """Begin the process of pausing a sandbox.