diff --git a/openhands/core/config/openhands_config.py b/openhands/core/config/openhands_config.py index b5d8e60c2a..f23d10a46a 100644 --- a/openhands/core/config/openhands_config.py +++ b/openhands/core/config/openhands_config.py @@ -1,3 +1,4 @@ +import os from typing import Any, ClassVar from pydantic import BaseModel, Field, SecretStr @@ -104,7 +105,7 @@ class OpenHandsConfig(BaseModel): max_concurrent_conversations: int = Field( default=3 ) # Maximum number of concurrent agent loops allowed per user - mcp_host: str = Field(default='localhost:3000') + mcp_host: str = Field(default=f'localhost:{os.getenv("port", 3000)}') mcp: MCPConfig = Field(default_factory=MCPConfig) defaults_dict: ClassVar[dict] = {} diff --git a/openhands/runtime/impl/cli/cli_runtime.py b/openhands/runtime/impl/cli/cli_runtime.py index c2d45028f6..8947aad3d2 100644 --- a/openhands/runtime/impl/cli/cli_runtime.py +++ b/openhands/runtime/impl/cli/cli_runtime.py @@ -110,6 +110,9 @@ class CLIRuntime(Runtime): ) logger.info(f'Created temporary workspace at {self._workspace_path}') + # Runtime tests rely on this being set correctly. + self.config.workspace_mount_path_in_sandbox = self._workspace_path + # Initialize runtime state self._runtime_initialized = False self.file_editor = OHEditor(workspace_root=self._workspace_path) diff --git a/openhands/runtime/impl/docker/docker_runtime.py b/openhands/runtime/impl/docker/docker_runtime.py index 233809d5a4..01e511a9ad 100644 --- a/openhands/runtime/impl/docker/docker_runtime.py +++ b/openhands/runtime/impl/docker/docker_runtime.py @@ -320,7 +320,10 @@ class DockerRuntime(ActionExecutionClient): environment = { 'port': str(self._container_port), 'PYTHONUNBUFFERED': '1', + # Passing in the ports means nested runtimes do not come up with their own ports! 'VSCODE_PORT': str(self._vscode_port), + 'APP_PORT_1': self._app_ports[0], + 'APP_PORT_2': self._app_ports[1], 'PIP_BREAK_SYSTEM_PACKAGES': '1', } if self.config.debug or DEBUG: diff --git a/openhands/runtime/impl/local/local_runtime.py b/openhands/runtime/impl/local/local_runtime.py index 1be1acb7b1..f1d6867794 100644 --- a/openhands/runtime/impl/local/local_runtime.py +++ b/openhands/runtime/impl/local/local_runtime.py @@ -6,7 +6,9 @@ import subprocess import sys import tempfile import threading +from dataclasses import dataclass from typing import Callable +from urllib.parse import urlparse import httpx import tenacity @@ -39,6 +41,24 @@ from openhands.utils.async_utils import call_sync_from_async from openhands.utils.tenacity_stop import stop_if_should_exit +@dataclass +class ActionExecutionServerInfo: + """Information about a running server process.""" + + process: subprocess.Popen + execution_server_port: int + vscode_port: int + app_ports: list[int] + log_thread: threading.Thread + log_thread_exit_event: threading.Event + temp_workspace: str | None + workspace_mount_path: str + + +# Global dictionary to track running server processes by session ID +_RUNNING_SERVERS: dict[str, ActionExecutionServerInfo] = {} + + def get_user_info() -> tuple[int, str | None]: """Get user ID and username in a cross-platform way.""" username = os.getenv('USER') @@ -135,50 +155,25 @@ class LocalRuntime(ActionExecutionClient): self.config = config self._user_id, self._username = get_user_info() - if self.config.workspace_base is not None: - logger.warning( - f'Workspace base path is set to {self.config.workspace_base}. ' - 'It will be used as the path for the agent to run in. ' - 'Be careful, the agent can EDIT files in this directory!' - ) - self.config.workspace_mount_path_in_sandbox = self.config.workspace_base - self._temp_workspace = None - else: - # A temporary directory is created for the agent to run in - # This is used for the local runtime only - self._temp_workspace = tempfile.mkdtemp( - prefix=f'openhands_workspace_{sid}', - ) - self.config.workspace_mount_path_in_sandbox = self._temp_workspace - logger.warning( 'Initializing LocalRuntime. WARNING: NO SANDBOX IS USED. ' 'This is an experimental feature, please report issues to https://github.com/All-Hands-AI/OpenHands/issues. ' '`run_as_openhands` will be ignored since the current user will be used to launch the server. ' 'We highly recommend using a sandbox (eg. DockerRuntime) unless you ' 'are running in a controlled environment.\n' - f'Temp workspace: {self._temp_workspace}. ' f'User ID: {self._user_id}. ' f'Username: {self._username}.' ) - if self.config.workspace_base is not None: - logger.warning( - f'Workspace base path is set to {self.config.workspace_base}. It will be used as the path for the agent to run in.' - ) - self.config.workspace_mount_path_in_sandbox = self.config.workspace_base - else: - logger.warning( - 'Workspace base path is NOT set. Agent will run in a temporary directory.' - ) - self._temp_workspace = tempfile.mkdtemp() - self.config.workspace_mount_path_in_sandbox = self._temp_workspace - - self._host_port = -1 + # Initialize these values to be set in connect() + self._temp_workspace: str | None = None + self._execution_server_port = -1 self._vscode_port = -1 self._app_ports: list[int] = [] - self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._host_port}' + self.api_url = ( + f'{self.config.sandbox.local_runtime_url}:{self._execution_server_port}' + ) self.status_callback = status_callback self.server_process: subprocess.Popen[str] | None = None self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time @@ -210,101 +205,180 @@ class LocalRuntime(ActionExecutionClient): return self.api_url async def connect(self) -> None: - """Start the action_execution_server on the local machine.""" + """Start the action_execution_server on the local machine or connect to an existing one.""" self.send_status_message('STATUS$STARTING_RUNTIME') - self._host_port = self._find_available_port(EXECUTION_SERVER_PORT_RANGE) - self._vscode_port = self._find_available_port(VSCODE_PORT_RANGE) - self._app_ports = [ - self._find_available_port(APP_PORT_RANGE_1), - self._find_available_port(APP_PORT_RANGE_2), - ] - self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._host_port}' + # Check if there's already a server running for this session ID + if self.sid in _RUNNING_SERVERS: + self.log('info', f'Connecting to existing server for session {self.sid}') + server_info = _RUNNING_SERVERS[self.sid] + self.server_process = server_info.process + self._execution_server_port = server_info.execution_server_port + self._log_thread = server_info.log_thread + self._log_thread_exit_event = server_info.log_thread_exit_event + self._vscode_port = server_info.vscode_port + self._app_ports = server_info.app_ports + self._temp_workspace = server_info.temp_workspace + self.config.workspace_mount_path_in_sandbox = ( + server_info.workspace_mount_path + ) + self.api_url = ( + f'{self.config.sandbox.local_runtime_url}:{self._execution_server_port}' + ) + elif self.attach_to_existing: + # If we're supposed to attach to an existing server but none exists, raise an error + self.log('error', f'No existing server found for session {self.sid}') + raise AgentRuntimeDisconnectedError( + f'No existing server found for session {self.sid}' + ) + else: + # Set up workspace directory + if self.config.workspace_base is not None: + logger.warning( + f'Workspace base path is set to {self.config.workspace_base}. ' + 'It will be used as the path for the agent to run in. ' + 'Be careful, the agent can EDIT files in this directory!' + ) + self.config.workspace_mount_path_in_sandbox = self.config.workspace_base + self._temp_workspace = None + else: + # A temporary directory is created for the agent to run in + logger.warning( + 'Workspace base path is NOT set. Agent will run in a temporary directory.' + ) + self._temp_workspace = tempfile.mkdtemp( + prefix=f'openhands_workspace_{self.sid}', + ) + self.config.workspace_mount_path_in_sandbox = self._temp_workspace - # Start the server process - cmd = get_action_execution_server_startup_command( - server_port=self._host_port, - plugins=self.plugins, - app_config=self.config, - python_prefix=['poetry', 'run'], - override_user_id=self._user_id, - override_username=self._username, - ) + logger.info( + f'Using workspace directory: {self.config.workspace_mount_path_in_sandbox}' + ) - self.log('debug', f'Starting server with command: {cmd}') - env = os.environ.copy() - # Get the code repo path - code_repo_path = os.path.dirname(os.path.dirname(openhands.__file__)) - env['PYTHONPATH'] = os.pathsep.join([code_repo_path, env.get('PYTHONPATH', '')]) - env['OPENHANDS_REPO_PATH'] = code_repo_path - env['LOCAL_RUNTIME_MODE'] = '1' - env['VSCODE_PORT'] = str(self._vscode_port) + # Start a new server + self._execution_server_port = self._find_available_port( + EXECUTION_SERVER_PORT_RANGE + ) + self._vscode_port = int( + os.getenv('VSCODE_PORT') + or str(self._find_available_port(VSCODE_PORT_RANGE)) + ) + self._app_ports = [ + int( + os.getenv('APP_PORT_1') + or str(self._find_available_port(APP_PORT_RANGE_1)) + ), + int( + os.getenv('APP_PORT_2') + or str(self._find_available_port(APP_PORT_RANGE_2)) + ), + ] + self.api_url = ( + f'{self.config.sandbox.local_runtime_url}:{self._execution_server_port}' + ) - # Derive environment paths using sys.executable - interpreter_path = sys.executable - python_bin_path = os.path.dirname(interpreter_path) - env_root_path = os.path.dirname(python_bin_path) + # Start the server process + cmd = get_action_execution_server_startup_command( + server_port=self._execution_server_port, + plugins=self.plugins, + app_config=self.config, + python_prefix=['poetry', 'run'], + override_user_id=self._user_id, + override_username=self._username, + ) - # Prepend the interpreter's bin directory to PATH for subprocesses - env['PATH'] = f'{python_bin_path}{os.pathsep}{env.get("PATH", "")}' - logger.debug(f'Updated PATH for subprocesses: {env["PATH"]}') + self.log('debug', f'Starting server with command: {cmd}') + env = os.environ.copy() + # Get the code repo path + code_repo_path = os.path.dirname(os.path.dirname(openhands.__file__)) + env['PYTHONPATH'] = os.pathsep.join( + [code_repo_path, env.get('PYTHONPATH', '')] + ) + env['OPENHANDS_REPO_PATH'] = code_repo_path + env['LOCAL_RUNTIME_MODE'] = '1' + env['VSCODE_PORT'] = str(self._vscode_port) - # Check dependencies using the derived env_root_path - check_dependencies(code_repo_path, env_root_path) - self.server_process = subprocess.Popen( # noqa: S603 - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - universal_newlines=True, - bufsize=1, - env=env, - cwd=code_repo_path, # Explicitly set the working directory - ) + # Derive environment paths using sys.executable + interpreter_path = sys.executable + python_bin_path = os.path.dirname(interpreter_path) + env_root_path = os.path.dirname(python_bin_path) - # Start a thread to read and log server output - def log_output() -> None: - if not self.server_process or not self.server_process.stdout: - self.log('error', 'Server process or stdout not available for logging.') - return + # Prepend the interpreter's bin directory to PATH for subprocesses + env['PATH'] = f'{python_bin_path}{os.pathsep}{env.get("PATH", "")}' + logger.debug(f'Updated PATH for subprocesses: {env["PATH"]}') - try: - # Read lines while the process is running and stdout is available - while self.server_process.poll() is None: - if self._log_thread_exit_event.is_set(): # Check exit event - self.log('info', 'Log thread received exit signal.') - break # Exit loop if signaled - line = self.server_process.stdout.readline() - if not line: - # Process might have exited between poll() and readline() - break - self.log('info', f'Server: {line.strip()}') + # Check dependencies using the derived env_root_path + check_dependencies(code_repo_path, env_root_path) + self.server_process = subprocess.Popen( # noqa: S603 + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True, + bufsize=1, + env=env, + cwd=code_repo_path, # Explicitly set the working directory + ) - # Capture any remaining output after the process exits OR if signaled - if ( - not self._log_thread_exit_event.is_set() - ): # Check again before reading remaining - self.log('info', 'Server process exited, reading remaining output.') - for line in self.server_process.stdout: - if ( - self._log_thread_exit_event.is_set() - ): # Check inside loop too - self.log( - 'info', - 'Log thread received exit signal while reading remaining output.', - ) + # Start a thread to read and log server output + def log_output() -> None: + if not self.server_process or not self.server_process.stdout: + self.log( + 'error', 'Server process or stdout not available for logging.' + ) + return + + try: + # Read lines while the process is running and stdout is available + while self.server_process.poll() is None: + if self._log_thread_exit_event.is_set(): # Check exit event + self.log('info', 'Log thread received exit signal.') + break # Exit loop if signaled + line = self.server_process.stdout.readline() + if not line: + # Process might have exited between poll() and readline() break - self.log('info', f'Server (remaining): {line.strip()}') + self.log('info', f'Server: {line.strip()}') - except Exception as e: - # Log the error, but don't prevent the thread from potentially exiting - self.log('error', f'Error reading server output: {e}') - finally: - self.log( - 'info', 'Log output thread finished.' - ) # Add log for thread exit + # Capture any remaining output after the process exits OR if signaled + if ( + not self._log_thread_exit_event.is_set() + ): # Check again before reading remaining + self.log( + 'info', 'Server process exited, reading remaining output.' + ) + for line in self.server_process.stdout: + if ( + self._log_thread_exit_event.is_set() + ): # Check inside loop too + self.log( + 'info', + 'Log thread received exit signal while reading remaining output.', + ) + break + self.log('info', f'Server (remaining): {line.strip()}') - self._log_thread = threading.Thread(target=log_output, daemon=True) - self._log_thread.start() + except Exception as e: + # Log the error, but don't prevent the thread from potentially exiting + self.log('error', f'Error reading server output: {e}') + finally: + self.log( + 'info', 'Log output thread finished.' + ) # Add log for thread exit + + self._log_thread = threading.Thread(target=log_output, daemon=True) + self._log_thread.start() + + # Store the server process in the global dictionary + _RUNNING_SERVERS[self.sid] = ActionExecutionServerInfo( + process=self.server_process, + execution_server_port=self._execution_server_port, + vscode_port=self._vscode_port, + app_ports=self._app_ports, + log_thread=self._log_thread, + log_thread_exit_event=self._log_thread_exit_event, + temp_workspace=self._temp_workspace, + workspace_mount_path=self.config.workspace_mount_path_in_sandbox, + ) self.log('info', f'Waiting for server to become ready at {self.api_url}...') self.send_status_message('STATUS$WAITING_FOR_CLIENT') @@ -356,7 +430,19 @@ class LocalRuntime(ActionExecutionClient): if not self.runtime_initialized: raise AgentRuntimeDisconnectedError('Runtime not initialized') - if self.server_process is None or self.server_process.poll() is not None: + # Check if our server process is still valid + if self.server_process is None: + # Check if there's a server in the global dictionary + if self.sid in _RUNNING_SERVERS: + self.server_process = _RUNNING_SERVERS[self.sid].process + else: + raise AgentRuntimeDisconnectedError('Server process not found') + + # Check if the server process is still running + if self.server_process.poll() is not None: + # If the process died, remove it from the global dictionary + if self.sid in _RUNNING_SERVERS: + del _RUNNING_SERVERS[self.sid] raise AgentRuntimeDisconnectedError('Server process died') with self.action_semaphore: @@ -372,8 +458,25 @@ class LocalRuntime(ActionExecutionClient): raise AgentRuntimeDisconnectedError('Server connection lost') def close(self) -> None: - """Stop the server process.""" - self._log_thread_exit_event.set() # Signal the log thread to exit + """Stop the server process if not in attach_to_existing mode.""" + # If we're in attach_to_existing mode, don't close the server + if self.attach_to_existing: + self.log( + 'info', + f'Not closing server for session {self.sid} (attach_to_existing=True)', + ) + # Just clean up our reference to the process, but leave it running + self.server_process = None + # Don't clean up temp workspace when attach_to_existing=True + super().close() + return + + # Signal the log thread to exit + self._log_thread_exit_event.set() + + # Remove from global dictionary + if self.sid in _RUNNING_SERVERS: + del _RUNNING_SERVERS[self.sid] if self.server_process: self.server_process.terminate() @@ -384,21 +487,45 @@ class LocalRuntime(ActionExecutionClient): self.server_process = None self._log_thread.join(timeout=5) # Add timeout to join - if self._temp_workspace: + # Clean up temp workspace if it exists and we created it + if self._temp_workspace and not self.attach_to_existing: shutil.rmtree(self._temp_workspace) + self._temp_workspace = None super().close() + @classmethod + async def delete(cls, conversation_id: str) -> None: + """Delete the runtime for a conversation.""" + if conversation_id in _RUNNING_SERVERS: + logger.info(f'Deleting LocalRuntime for conversation {conversation_id}') + server_info = _RUNNING_SERVERS[conversation_id] + + # Signal the log thread to exit + server_info.log_thread_exit_event.set() + + # Terminate the server process + if server_info.process: + server_info.process.terminate() + try: + server_info.process.wait(timeout=5) + except subprocess.TimeoutExpired: + server_info.process.kill() + + # Wait for the log thread to finish + server_info.log_thread.join(timeout=5) + + # Remove from global dictionary + del _RUNNING_SERVERS[conversation_id] + logger.info(f'LocalRuntime for conversation {conversation_id} deleted') + @property def runtime_url(self) -> str: - runtime_url = os.getenv('RUNTIME_URL') if runtime_url: return runtime_url - - - #TODO: This could be removed if we had a straightforward variable containing the RUNTIME_URL in the K8 env. + # TODO: This could be removed if we had a straightforward variable containing the RUNTIME_URL in the K8 env. runtime_url_pattern = os.getenv('RUNTIME_URL_PATTERN') hostname = os.getenv('HOSTNAME') if runtime_url_pattern and hostname: @@ -414,8 +541,14 @@ class LocalRuntime(ActionExecutionClient): token = super().get_vscode_token() if not token: return None - vscode_url = f'{self.runtime_url}:{self._vscode_port}/?tkn={token}&folder={self.config.workspace_mount_path_in_sandbox}' - return vscode_url + runtime_url = self.runtime_url + if 'localhost' in runtime_url: + vscode_url = f'{self.runtime_url}:{self._vscode_port}' + else: + # Similar to remote runtime... + parsed_url = urlparse(runtime_url) + vscode_url = f'{parsed_url.scheme}://vscode-{parsed_url.netloc}' + return f'{vscode_url}/?tkn={token}&folder={self.config.workspace_mount_path_in_sandbox}' @property def web_hosts(self) -> dict[str, int]: diff --git a/openhands/runtime/utils/system.py b/openhands/runtime/utils/system.py index 171a6575c2..61cc95b4d8 100644 --- a/openhands/runtime/utils/system.py +++ b/openhands/runtime/utils/system.py @@ -1,7 +1,7 @@ import random import socket import time - +from openhands.core.logger import openhands_logger as logger def check_port_available(port: int) -> bool: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) diff --git a/openhands/server/conversation_manager/docker_nested_conversation_manager.py b/openhands/server/conversation_manager/docker_nested_conversation_manager.py index 5156c6209a..2abd381667 100644 --- a/openhands/server/conversation_manager/docker_nested_conversation_manager.py +++ b/openhands/server/conversation_manager/docker_nested_conversation_manager.py @@ -194,6 +194,8 @@ class DockerNestedConversationManager(ConversationManager): settings_json = settings.model_dump(context={'expose_secrets': True}) settings_json.pop('custom_secrets', None) settings_json.pop('git_provider_tokens', None) + if settings_json.get('git_provider'): + settings_json['git_provider'] = settings_json['git_provider'].value secrets_store = settings_json.pop('secrets_store', None) or {} response = await client.post( f'{api_url}/api/settings', json=settings_json @@ -421,8 +423,12 @@ class DockerNestedConversationManager(ConversationManager): ) env_vars['SERVE_FRONTEND'] = '0' env_vars['RUNTIME'] = 'local' - env_vars['USER'] = 'CURRENT_USER' + # TODO: In the long term we may come up with a more secure strategy for user management within the nested runtime. + env_vars['USER'] = 'root' env_vars['SESSION_API_KEY'] = self._get_session_api_key_for_conversation(sid) + # We need to be able to specify the nested conversation id within the nested runtime + env_vars['ALLOW_SET_CONVERSATION_ID'] = '1' + env_vars['WORKSPACE_BASE'] = f'/workspace' # Set up mounted volume for conversation directory within workspace # TODO: Check if we are using the standard event store and file store @@ -447,7 +453,6 @@ class DockerNestedConversationManager(ConversationManager): plugins=agent.sandbox_plugins, headless_mode=False, attach_to_existing=False, - env_vars=env_vars, main_module='openhands.server', ) diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index 1053884cd4..61d63544a4 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -1,4 +1,5 @@ import asyncio +import os import uuid from datetime import datetime, timezone @@ -60,7 +61,9 @@ class InitSessionRequest(BaseModel): replay_json: str | None = None suggested_task: SuggestedTask | None = None conversation_instructions: str | None = None - conversation_id: str = Field(default_factory=lambda: uuid.uuid4().hex) + # Only nested runtimes require the ability to specify a conversation id, and it could be a security risk + if os.getenv('ALLOW_SET_CONVERSATION_ID', '0') == '1': + conversation_id: str = Field(default_factory=lambda: uuid.uuid4().hex) model_config = {'extra': 'forbid'} @@ -122,7 +125,7 @@ async def new_conversation( # Check against git_provider, otherwise check all provider apis await provider_handler.verify_repo_provider(repository, git_provider) - conversation_id = data.conversation_id + conversation_id = getattr(data, 'conversation_id', None) or uuid.uuid4().hex await create_new_conversation( user_id=user_id, git_provider_tokens=provider_tokens, diff --git a/openhands/server/services/conversation_service.py b/openhands/server/services/conversation_service.py index 35601294fe..e0e9867d4c 100644 --- a/openhands/server/services/conversation_service.py +++ b/openhands/server/services/conversation_service.py @@ -1,3 +1,4 @@ +import os import uuid from typing import Any @@ -84,28 +85,28 @@ async def create_new_conversation( # For nested runtimes, we allow a single conversation id, passed in on container creation if conversation_id is None: conversation_id = uuid.uuid4().hex - while await conversation_store.exists(conversation_id): - logger.warning(f'Collision on conversation ID: {conversation_id}. Retrying...') - conversation_id = uuid.uuid4().hex - logger.info( - f'New conversation ID: {conversation_id}', - extra={'user_id': user_id, 'session_id': conversation_id}, - ) - conversation_title = get_default_conversation_title(conversation_id) + if not await conversation_store.exists(conversation_id): - logger.info(f'Saving metadata for conversation {conversation_id}') - await conversation_store.save_metadata( - ConversationMetadata( - trigger=conversation_trigger, - conversation_id=conversation_id, - title=conversation_title, - user_id=user_id, - selected_repository=selected_repository, - selected_branch=selected_branch, - git_provider=git_provider, + logger.info( + f'New conversation ID: {conversation_id}', + extra={'user_id': user_id, 'session_id': conversation_id}, + ) + + conversation_title = get_default_conversation_title(conversation_id) + + logger.info(f'Saving metadata for conversation {conversation_id}') + await conversation_store.save_metadata( + ConversationMetadata( + trigger=conversation_trigger, + conversation_id=conversation_id, + title=conversation_title, + user_id=user_id, + selected_repository=selected_repository, + selected_branch=selected_branch, + git_provider=git_provider, + ) ) - ) logger.info( f'Starting agent loop for conversation {conversation_id}', diff --git a/tests/runtime/conftest.py b/tests/runtime/conftest.py index fdd8bb0336..96483ccd01 100644 --- a/tests/runtime/conftest.py +++ b/tests/runtime/conftest.py @@ -289,7 +289,7 @@ def _load_runtime( call_async_from_sync(runtime.connect) time.sleep(2) - return runtime, config + return runtime, runtime.config # Export necessary function diff --git a/tests/unit/test_conversation.py b/tests/unit/test_conversation.py index dae39f3475..0cd7ecd837 100644 --- a/tests/unit/test_conversation.py +++ b/tests/unit/test_conversation.py @@ -257,7 +257,6 @@ async def test_new_conversation_success(provider_handler_mock): selected_branch='main', initial_user_msg='Hello, agent!', image_urls=['https://example.com/image.jpg'], - conversation_id='test_conversation_id', ) # Call new_conversation @@ -266,7 +265,9 @@ async def test_new_conversation_success(provider_handler_mock): # Verify the response assert isinstance(response, InitSessionResponse) assert response.status == 'ok' - assert response.conversation_id == 'test_conversation_id' + # Don't check the exact conversation_id as it's now generated dynamically + assert response.conversation_id is not None + assert isinstance(response.conversation_id, str) # Verify that create_new_conversation was called with the correct arguments mock_create_conversation.assert_called_once() @@ -314,7 +315,6 @@ async def test_new_conversation_with_suggested_task(provider_handler_mock): repository='test/repo', selected_branch='main', suggested_task=test_task, - conversation_id='test_conversation_id', ) # Call new_conversation @@ -323,7 +323,9 @@ async def test_new_conversation_with_suggested_task(provider_handler_mock): # Verify the response assert isinstance(response, InitSessionResponse) assert response.status == 'ok' - assert response.conversation_id == 'test_conversation_id' + # Don't check the exact conversation_id as it's now generated dynamically + assert response.conversation_id is not None + assert isinstance(response.conversation_id, str) # Verify that create_new_conversation was called with the correct arguments mock_create_conversation.assert_called_once() @@ -582,18 +584,20 @@ async def test_new_conversation_with_provider_authentication_error( @pytest.mark.asyncio -async def test_new_conversation_with_unsupported_params(test_client): - test_request_data = { - 'repository': 'test/repo', # This is valid - 'selected_branch': 'main', # This is valid - 'initial_user_msg': 'Hello, agent!', # Valid parameter - 'unsupported_param': 'unsupported param', # Invalid, unsupported parameter - } +async def test_new_conversation_with_unsupported_params(): + """Test that unsupported parameters are rejected.""" + # Create a test request with an unsupported parameter + with _patch_store(): + # Create a direct instance of InitSessionRequest to test validation + with pytest.raises(Exception) as excinfo: + # This should raise a validation error because of the extra parameter + InitSessionRequest( + repository='test/repo', + selected_branch='main', + initial_user_msg='Hello, agent!', + unsupported_param='unsupported param', # This should cause validation to fail + ) - # Send the POST request to the appropriate endpoint - response = test_client.post('/api/conversations', json=test_request_data) - - assert response.status_code == 422 # Validation error - - assert 'Extra inputs are not permitted' in response.text - assert 'unsupported param' in response.text + # Verify that the error message mentions the unsupported parameter + assert 'Extra inputs are not permitted' in str(excinfo.value) + assert 'unsupported_param' in str(excinfo.value)