Mirror of https://github.com/OpenHands/OpenHands.git, synced 2025-12-26 05:48:36 +08:00
Fix LocalRuntime to properly handle existing subprocesses (#8821)
Co-authored-by: openhands <openhands@all-hands.dev>
parent ab1cdb5b5f
commit 72c24b461c
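The change below replaces per-instance server bookkeeping in LocalRuntime with a module-level registry keyed by session ID, so a runtime created for a session that already has a running action execution server reattaches to that subprocess instead of spawning a new one, and close() leaves the process alive when attach_to_existing is set. The following is a minimal, illustrative sketch of that pattern only; ServerInfo, connect_or_reuse, and close are stand-in names, not the ActionExecutionServerInfo / _RUNNING_SERVERS identifiers used in the actual diff.

import subprocess
import sys
from dataclasses import dataclass


@dataclass
class ServerInfo:
    process: subprocess.Popen
    port: int


_RUNNING: dict[str, ServerInfo] = {}


def connect_or_reuse(sid: str, port: int) -> ServerInfo:
    """Return the server for `sid`, starting one only if none is alive."""
    info = _RUNNING.get(sid)
    if info is not None and info.process.poll() is None:
        return info  # reattach to the existing subprocess
    # Placeholder command; the real code launches the action_execution_server.
    process = subprocess.Popen(
        [sys.executable, '-c', 'import time; time.sleep(3600)']
    )
    info = ServerInfo(process=process, port=port)
    _RUNNING[sid] = info
    return info


def close(sid: str, attach_to_existing: bool) -> None:
    """Tear the server down unless the caller only attached to it."""
    if attach_to_existing:
        return  # leave the shared process running for other attachments
    info = _RUNNING.pop(sid, None)
    if info is not None:
        info.process.terminate()
        try:
            info.process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            info.process.kill()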
@@ -1,3 +1,4 @@
import os
from typing import Any, ClassVar

from pydantic import BaseModel, Field, SecretStr
@@ -104,7 +105,7 @@ class OpenHandsConfig(BaseModel):
max_concurrent_conversations: int = Field(
default=3
) # Maximum number of concurrent agent loops allowed per user
mcp_host: str = Field(default='localhost:3000')
mcp_host: str = Field(default=f'localhost:{os.getenv("port", 3000)}')
mcp: MCPConfig = Field(default_factory=MCPConfig)

defaults_dict: ClassVar[dict] = {}

@@ -110,6 +110,9 @@ class CLIRuntime(Runtime):
)
logger.info(f'Created temporary workspace at {self._workspace_path}')

# Runtime tests rely on this being set correctly.
self.config.workspace_mount_path_in_sandbox = self._workspace_path

# Initialize runtime state
self._runtime_initialized = False
self.file_editor = OHEditor(workspace_root=self._workspace_path)

@@ -320,7 +320,10 @@ class DockerRuntime(ActionExecutionClient):
environment = {
'port': str(self._container_port),
'PYTHONUNBUFFERED': '1',
# Passing in the ports means nested runtimes do not come up with their own ports!
'VSCODE_PORT': str(self._vscode_port),
'APP_PORT_1': self._app_ports[0],
'APP_PORT_2': self._app_ports[1],
'PIP_BREAK_SYSTEM_PACKAGES': '1',
}
if self.config.debug or DEBUG:

@@ -6,7 +6,9 @@ import subprocess
import sys
import tempfile
import threading
from dataclasses import dataclass
from typing import Callable
from urllib.parse import urlparse

import httpx
import tenacity
@@ -39,6 +41,24 @@ from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.tenacity_stop import stop_if_should_exit


@dataclass
class ActionExecutionServerInfo:
"""Information about a running server process."""

process: subprocess.Popen
execution_server_port: int
vscode_port: int
app_ports: list[int]
log_thread: threading.Thread
log_thread_exit_event: threading.Event
temp_workspace: str | None
workspace_mount_path: str


# Global dictionary to track running server processes by session ID
_RUNNING_SERVERS: dict[str, ActionExecutionServerInfo] = {}


def get_user_info() -> tuple[int, str | None]:
"""Get user ID and username in a cross-platform way."""
username = os.getenv('USER')
@@ -135,50 +155,25 @@ class LocalRuntime(ActionExecutionClient):
self.config = config
self._user_id, self._username = get_user_info()

if self.config.workspace_base is not None:
logger.warning(
f'Workspace base path is set to {self.config.workspace_base}. '
'It will be used as the path for the agent to run in. '
'Be careful, the agent can EDIT files in this directory!'
)
self.config.workspace_mount_path_in_sandbox = self.config.workspace_base
self._temp_workspace = None
else:
# A temporary directory is created for the agent to run in
# This is used for the local runtime only
self._temp_workspace = tempfile.mkdtemp(
prefix=f'openhands_workspace_{sid}',
)
self.config.workspace_mount_path_in_sandbox = self._temp_workspace

logger.warning(
'Initializing LocalRuntime. WARNING: NO SANDBOX IS USED. '
'This is an experimental feature, please report issues to https://github.com/All-Hands-AI/OpenHands/issues. '
'`run_as_openhands` will be ignored since the current user will be used to launch the server. '
'We highly recommend using a sandbox (eg. DockerRuntime) unless you '
'are running in a controlled environment.\n'
f'Temp workspace: {self._temp_workspace}. '
f'User ID: {self._user_id}. '
f'Username: {self._username}.'
)

if self.config.workspace_base is not None:
logger.warning(
f'Workspace base path is set to {self.config.workspace_base}. It will be used as the path for the agent to run in.'
)
self.config.workspace_mount_path_in_sandbox = self.config.workspace_base
else:
logger.warning(
'Workspace base path is NOT set. Agent will run in a temporary directory.'
)
self._temp_workspace = tempfile.mkdtemp()
self.config.workspace_mount_path_in_sandbox = self._temp_workspace

self._host_port = -1
# Initialize these values to be set in connect()
self._temp_workspace: str | None = None
self._execution_server_port = -1
self._vscode_port = -1
self._app_ports: list[int] = []

self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._host_port}'
self.api_url = (
f'{self.config.sandbox.local_runtime_url}:{self._execution_server_port}'
)
self.status_callback = status_callback
self.server_process: subprocess.Popen[str] | None = None
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
@@ -210,101 +205,180 @@ class LocalRuntime(ActionExecutionClient):
return self.api_url

async def connect(self) -> None:
"""Start the action_execution_server on the local machine."""
"""Start the action_execution_server on the local machine or connect to an existing one."""
self.send_status_message('STATUS$STARTING_RUNTIME')

self._host_port = self._find_available_port(EXECUTION_SERVER_PORT_RANGE)
self._vscode_port = self._find_available_port(VSCODE_PORT_RANGE)
self._app_ports = [
self._find_available_port(APP_PORT_RANGE_1),
self._find_available_port(APP_PORT_RANGE_2),
]
self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._host_port}'
# Check if there's already a server running for this session ID
if self.sid in _RUNNING_SERVERS:
self.log('info', f'Connecting to existing server for session {self.sid}')
server_info = _RUNNING_SERVERS[self.sid]
self.server_process = server_info.process
self._execution_server_port = server_info.execution_server_port
self._log_thread = server_info.log_thread
self._log_thread_exit_event = server_info.log_thread_exit_event
self._vscode_port = server_info.vscode_port
self._app_ports = server_info.app_ports
self._temp_workspace = server_info.temp_workspace
self.config.workspace_mount_path_in_sandbox = (
server_info.workspace_mount_path
)
self.api_url = (
f'{self.config.sandbox.local_runtime_url}:{self._execution_server_port}'
)
elif self.attach_to_existing:
# If we're supposed to attach to an existing server but none exists, raise an error
self.log('error', f'No existing server found for session {self.sid}')
raise AgentRuntimeDisconnectedError(
f'No existing server found for session {self.sid}'
)
else:
# Set up workspace directory
if self.config.workspace_base is not None:
logger.warning(
f'Workspace base path is set to {self.config.workspace_base}. '
'It will be used as the path for the agent to run in. '
'Be careful, the agent can EDIT files in this directory!'
)
self.config.workspace_mount_path_in_sandbox = self.config.workspace_base
self._temp_workspace = None
else:
# A temporary directory is created for the agent to run in
logger.warning(
'Workspace base path is NOT set. Agent will run in a temporary directory.'
)
self._temp_workspace = tempfile.mkdtemp(
prefix=f'openhands_workspace_{self.sid}',
)
self.config.workspace_mount_path_in_sandbox = self._temp_workspace

# Start the server process
cmd = get_action_execution_server_startup_command(
server_port=self._host_port,
plugins=self.plugins,
app_config=self.config,
python_prefix=['poetry', 'run'],
override_user_id=self._user_id,
override_username=self._username,
)
logger.info(
f'Using workspace directory: {self.config.workspace_mount_path_in_sandbox}'
)

self.log('debug', f'Starting server with command: {cmd}')
env = os.environ.copy()
# Get the code repo path
code_repo_path = os.path.dirname(os.path.dirname(openhands.__file__))
env['PYTHONPATH'] = os.pathsep.join([code_repo_path, env.get('PYTHONPATH', '')])
env['OPENHANDS_REPO_PATH'] = code_repo_path
env['LOCAL_RUNTIME_MODE'] = '1'
env['VSCODE_PORT'] = str(self._vscode_port)
# Start a new server
self._execution_server_port = self._find_available_port(
EXECUTION_SERVER_PORT_RANGE
)
self._vscode_port = int(
os.getenv('VSCODE_PORT')
or str(self._find_available_port(VSCODE_PORT_RANGE))
)
self._app_ports = [
int(
os.getenv('APP_PORT_1')
or str(self._find_available_port(APP_PORT_RANGE_1))
),
int(
os.getenv('APP_PORT_2')
or str(self._find_available_port(APP_PORT_RANGE_2))
),
]
self.api_url = (
f'{self.config.sandbox.local_runtime_url}:{self._execution_server_port}'
)

# Derive environment paths using sys.executable
interpreter_path = sys.executable
python_bin_path = os.path.dirname(interpreter_path)
env_root_path = os.path.dirname(python_bin_path)
# Start the server process
cmd = get_action_execution_server_startup_command(
server_port=self._execution_server_port,
plugins=self.plugins,
app_config=self.config,
python_prefix=['poetry', 'run'],
override_user_id=self._user_id,
override_username=self._username,
)

# Prepend the interpreter's bin directory to PATH for subprocesses
env['PATH'] = f'{python_bin_path}{os.pathsep}{env.get("PATH", "")}'
logger.debug(f'Updated PATH for subprocesses: {env["PATH"]}')
self.log('debug', f'Starting server with command: {cmd}')
env = os.environ.copy()
# Get the code repo path
code_repo_path = os.path.dirname(os.path.dirname(openhands.__file__))
env['PYTHONPATH'] = os.pathsep.join(
[code_repo_path, env.get('PYTHONPATH', '')]
)
env['OPENHANDS_REPO_PATH'] = code_repo_path
env['LOCAL_RUNTIME_MODE'] = '1'
env['VSCODE_PORT'] = str(self._vscode_port)

# Check dependencies using the derived env_root_path
check_dependencies(code_repo_path, env_root_path)
self.server_process = subprocess.Popen( # noqa: S603
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True,
bufsize=1,
env=env,
cwd=code_repo_path, # Explicitly set the working directory
)
# Derive environment paths using sys.executable
interpreter_path = sys.executable
python_bin_path = os.path.dirname(interpreter_path)
env_root_path = os.path.dirname(python_bin_path)

# Start a thread to read and log server output
def log_output() -> None:
if not self.server_process or not self.server_process.stdout:
self.log('error', 'Server process or stdout not available for logging.')
return
# Prepend the interpreter's bin directory to PATH for subprocesses
env['PATH'] = f'{python_bin_path}{os.pathsep}{env.get("PATH", "")}'
logger.debug(f'Updated PATH for subprocesses: {env["PATH"]}')

try:
# Read lines while the process is running and stdout is available
while self.server_process.poll() is None:
if self._log_thread_exit_event.is_set(): # Check exit event
self.log('info', 'Log thread received exit signal.')
break # Exit loop if signaled
line = self.server_process.stdout.readline()
if not line:
# Process might have exited between poll() and readline()
break
self.log('info', f'Server: {line.strip()}')
# Check dependencies using the derived env_root_path
check_dependencies(code_repo_path, env_root_path)
self.server_process = subprocess.Popen( # noqa: S603
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True,
bufsize=1,
env=env,
cwd=code_repo_path, # Explicitly set the working directory
)

# Capture any remaining output after the process exits OR if signaled
if (
not self._log_thread_exit_event.is_set()
): # Check again before reading remaining
self.log('info', 'Server process exited, reading remaining output.')
for line in self.server_process.stdout:
if (
self._log_thread_exit_event.is_set()
): # Check inside loop too
self.log(
'info',
'Log thread received exit signal while reading remaining output.',
)
# Start a thread to read and log server output
def log_output() -> None:
if not self.server_process or not self.server_process.stdout:
self.log(
'error', 'Server process or stdout not available for logging.'
)
return

try:
# Read lines while the process is running and stdout is available
while self.server_process.poll() is None:
if self._log_thread_exit_event.is_set(): # Check exit event
self.log('info', 'Log thread received exit signal.')
break # Exit loop if signaled
line = self.server_process.stdout.readline()
if not line:
# Process might have exited between poll() and readline()
break
self.log('info', f'Server (remaining): {line.strip()}')
self.log('info', f'Server: {line.strip()}')

except Exception as e:
# Log the error, but don't prevent the thread from potentially exiting
self.log('error', f'Error reading server output: {e}')
finally:
self.log(
'info', 'Log output thread finished.'
) # Add log for thread exit
# Capture any remaining output after the process exits OR if signaled
if (
not self._log_thread_exit_event.is_set()
): # Check again before reading remaining
self.log(
'info', 'Server process exited, reading remaining output.'
)
for line in self.server_process.stdout:
if (
self._log_thread_exit_event.is_set()
): # Check inside loop too
self.log(
'info',
'Log thread received exit signal while reading remaining output.',
)
break
self.log('info', f'Server (remaining): {line.strip()}')

self._log_thread = threading.Thread(target=log_output, daemon=True)
self._log_thread.start()
except Exception as e:
# Log the error, but don't prevent the thread from potentially exiting
self.log('error', f'Error reading server output: {e}')
finally:
self.log(
'info', 'Log output thread finished.'
) # Add log for thread exit

self._log_thread = threading.Thread(target=log_output, daemon=True)
self._log_thread.start()

# Store the server process in the global dictionary
_RUNNING_SERVERS[self.sid] = ActionExecutionServerInfo(
process=self.server_process,
execution_server_port=self._execution_server_port,
vscode_port=self._vscode_port,
app_ports=self._app_ports,
log_thread=self._log_thread,
log_thread_exit_event=self._log_thread_exit_event,
temp_workspace=self._temp_workspace,
workspace_mount_path=self.config.workspace_mount_path_in_sandbox,
)

self.log('info', f'Waiting for server to become ready at {self.api_url}...')
self.send_status_message('STATUS$WAITING_FOR_CLIENT')
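The hunk above drains the server's stdout on a daemon thread that watches a threading.Event, so close() and delete() can stop logging cleanly. Below is a compact, self-contained sketch of that stoppable-reader pattern under generic names (drain_output, stop_event); it is illustrative only and not the diff's log_output closure.

import subprocess
import sys
import threading


def drain_output(process: subprocess.Popen, stop_event: threading.Event) -> None:
    """Relay child stdout lines until the process exits or we are told to stop."""
    assert process.stdout is not None
    while process.poll() is None and not stop_event.is_set():
        line = process.stdout.readline()
        if not line:
            break  # process exited between poll() and readline()
        print(f'Server: {line.strip()}')
    # Flush whatever is left unless we were asked to stop.
    if not stop_event.is_set():
        for line in process.stdout:
            print(f'Server (remaining): {line.strip()}')


proc = subprocess.Popen(
    [sys.executable, '-c', 'print("hello"); print("bye")'],
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    universal_newlines=True,
    bufsize=1,
)
stop = threading.Event()
reader = threading.Thread(target=drain_output, args=(proc, stop), daemon=True)
reader.start()
proc.wait()
stop.set()
reader.join(timeout=5)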
@@ -356,7 +430,19 @@ class LocalRuntime(ActionExecutionClient):
if not self.runtime_initialized:
raise AgentRuntimeDisconnectedError('Runtime not initialized')

if self.server_process is None or self.server_process.poll() is not None:
# Check if our server process is still valid
if self.server_process is None:
# Check if there's a server in the global dictionary
if self.sid in _RUNNING_SERVERS:
self.server_process = _RUNNING_SERVERS[self.sid].process
else:
raise AgentRuntimeDisconnectedError('Server process not found')

# Check if the server process is still running
if self.server_process.poll() is not None:
# If the process died, remove it from the global dictionary
if self.sid in _RUNNING_SERVERS:
del _RUNNING_SERVERS[self.sid]
raise AgentRuntimeDisconnectedError('Server process died')

with self.action_semaphore:
@@ -372,8 +458,25 @@ class LocalRuntime(ActionExecutionClient):
raise AgentRuntimeDisconnectedError('Server connection lost')

def close(self) -> None:
"""Stop the server process."""
self._log_thread_exit_event.set() # Signal the log thread to exit
"""Stop the server process if not in attach_to_existing mode."""
# If we're in attach_to_existing mode, don't close the server
if self.attach_to_existing:
self.log(
'info',
f'Not closing server for session {self.sid} (attach_to_existing=True)',
)
# Just clean up our reference to the process, but leave it running
self.server_process = None
# Don't clean up temp workspace when attach_to_existing=True
super().close()
return

# Signal the log thread to exit
self._log_thread_exit_event.set()

# Remove from global dictionary
if self.sid in _RUNNING_SERVERS:
del _RUNNING_SERVERS[self.sid]

if self.server_process:
self.server_process.terminate()
@@ -384,21 +487,45 @@ class LocalRuntime(ActionExecutionClient):
self.server_process = None
self._log_thread.join(timeout=5) # Add timeout to join

if self._temp_workspace:
# Clean up temp workspace if it exists and we created it
if self._temp_workspace and not self.attach_to_existing:
shutil.rmtree(self._temp_workspace)
self._temp_workspace = None

super().close()

@classmethod
async def delete(cls, conversation_id: str) -> None:
"""Delete the runtime for a conversation."""
if conversation_id in _RUNNING_SERVERS:
logger.info(f'Deleting LocalRuntime for conversation {conversation_id}')
server_info = _RUNNING_SERVERS[conversation_id]

# Signal the log thread to exit
server_info.log_thread_exit_event.set()

# Terminate the server process
if server_info.process:
server_info.process.terminate()
try:
server_info.process.wait(timeout=5)
except subprocess.TimeoutExpired:
server_info.process.kill()

# Wait for the log thread to finish
server_info.log_thread.join(timeout=5)

# Remove from global dictionary
del _RUNNING_SERVERS[conversation_id]
logger.info(f'LocalRuntime for conversation {conversation_id} deleted')

@property
def runtime_url(self) -> str:

runtime_url = os.getenv('RUNTIME_URL')
if runtime_url:
return runtime_url



#TODO: This could be removed if we had a straightforward variable containing the RUNTIME_URL in the K8 env.
# TODO: This could be removed if we had a straightforward variable containing the RUNTIME_URL in the K8 env.
runtime_url_pattern = os.getenv('RUNTIME_URL_PATTERN')
hostname = os.getenv('HOSTNAME')
if runtime_url_pattern and hostname:
@@ -414,8 +541,14 @@ class LocalRuntime(ActionExecutionClient):
token = super().get_vscode_token()
if not token:
return None
vscode_url = f'{self.runtime_url}:{self._vscode_port}/?tkn={token}&folder={self.config.workspace_mount_path_in_sandbox}'
return vscode_url
runtime_url = self.runtime_url
if 'localhost' in runtime_url:
vscode_url = f'{self.runtime_url}:{self._vscode_port}'
else:
# Similar to remote runtime...
parsed_url = urlparse(runtime_url)
vscode_url = f'{parsed_url.scheme}://vscode-{parsed_url.netloc}'
return f'{vscode_url}/?tkn={token}&folder={self.config.workspace_mount_path_in_sandbox}'

@property
def web_hosts(self) -> dict[str, int]:

@@ -1,7 +1,7 @@
import random
import socket
import time

from openhands.core.logger import openhands_logger as logger

def check_port_available(port: int) -> bool:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

@@ -194,6 +194,8 @@ class DockerNestedConversationManager(ConversationManager):
settings_json = settings.model_dump(context={'expose_secrets': True})
settings_json.pop('custom_secrets', None)
settings_json.pop('git_provider_tokens', None)
if settings_json.get('git_provider'):
settings_json['git_provider'] = settings_json['git_provider'].value
secrets_store = settings_json.pop('secrets_store', None) or {}
response = await client.post(
f'{api_url}/api/settings', json=settings_json
@@ -421,8 +423,12 @@ class DockerNestedConversationManager(ConversationManager):
)
env_vars['SERVE_FRONTEND'] = '0'
env_vars['RUNTIME'] = 'local'
env_vars['USER'] = 'CURRENT_USER'
# TODO: In the long term we may come up with a more secure strategy for user management within the nested runtime.
env_vars['USER'] = 'root'
env_vars['SESSION_API_KEY'] = self._get_session_api_key_for_conversation(sid)
# We need to be able to specify the nested conversation id within the nested runtime
env_vars['ALLOW_SET_CONVERSATION_ID'] = '1'
env_vars['WORKSPACE_BASE'] = f'/workspace'

# Set up mounted volume for conversation directory within workspace
# TODO: Check if we are using the standard event store and file store
@@ -447,7 +453,6 @@ class DockerNestedConversationManager(ConversationManager):
plugins=agent.sandbox_plugins,
headless_mode=False,
attach_to_existing=False,
env_vars=env_vars,
main_module='openhands.server',
)


@@ -1,4 +1,5 @@
import asyncio
import os
import uuid
from datetime import datetime, timezone

@@ -60,7 +61,9 @@ class InitSessionRequest(BaseModel):
replay_json: str | None = None
suggested_task: SuggestedTask | None = None
conversation_instructions: str | None = None
conversation_id: str = Field(default_factory=lambda: uuid.uuid4().hex)
# Only nested runtimes require the ability to specify a conversation id, and it could be a security risk
if os.getenv('ALLOW_SET_CONVERSATION_ID', '0') == '1':
conversation_id: str = Field(default_factory=lambda: uuid.uuid4().hex)

model_config = {'extra': 'forbid'}

@@ -122,7 +125,7 @@ async def new_conversation(
# Check against git_provider, otherwise check all provider apis
await provider_handler.verify_repo_provider(repository, git_provider)

conversation_id = data.conversation_id
conversation_id = getattr(data, 'conversation_id', None) or uuid.uuid4().hex
await create_new_conversation(
user_id=user_id,
git_provider_tokens=provider_tokens,

@@ -1,3 +1,4 @@
import os
import uuid
from typing import Any

@@ -84,28 +85,28 @@ async def create_new_conversation(
# For nested runtimes, we allow a single conversation id, passed in on container creation
if conversation_id is None:
conversation_id = uuid.uuid4().hex
while await conversation_store.exists(conversation_id):
logger.warning(f'Collision on conversation ID: {conversation_id}. Retrying...')
conversation_id = uuid.uuid4().hex
logger.info(
f'New conversation ID: {conversation_id}',
extra={'user_id': user_id, 'session_id': conversation_id},
)

conversation_title = get_default_conversation_title(conversation_id)
if not await conversation_store.exists(conversation_id):

logger.info(f'Saving metadata for conversation {conversation_id}')
await conversation_store.save_metadata(
ConversationMetadata(
trigger=conversation_trigger,
conversation_id=conversation_id,
title=conversation_title,
user_id=user_id,
selected_repository=selected_repository,
selected_branch=selected_branch,
git_provider=git_provider,
logger.info(
f'New conversation ID: {conversation_id}',
extra={'user_id': user_id, 'session_id': conversation_id},
)

conversation_title = get_default_conversation_title(conversation_id)

logger.info(f'Saving metadata for conversation {conversation_id}')
await conversation_store.save_metadata(
ConversationMetadata(
trigger=conversation_trigger,
conversation_id=conversation_id,
title=conversation_title,
user_id=user_id,
selected_repository=selected_repository,
selected_branch=selected_branch,
git_provider=git_provider,
)
)
)

logger.info(
f'Starting agent loop for conversation {conversation_id}',

@@ -289,7 +289,7 @@ def _load_runtime(

call_async_from_sync(runtime.connect)
time.sleep(2)
return runtime, config
return runtime, runtime.config


# Export necessary function

@@ -257,7 +257,6 @@ async def test_new_conversation_success(provider_handler_mock):
selected_branch='main',
initial_user_msg='Hello, agent!',
image_urls=['https://example.com/image.jpg'],
conversation_id='test_conversation_id',
)

# Call new_conversation
@@ -266,7 +265,9 @@
# Verify the response
assert isinstance(response, InitSessionResponse)
assert response.status == 'ok'
assert response.conversation_id == 'test_conversation_id'
# Don't check the exact conversation_id as it's now generated dynamically
assert response.conversation_id is not None
assert isinstance(response.conversation_id, str)

# Verify that create_new_conversation was called with the correct arguments
mock_create_conversation.assert_called_once()
@@ -314,7 +315,6 @@ async def test_new_conversation_with_suggested_task(provider_handler_mock):
repository='test/repo',
selected_branch='main',
suggested_task=test_task,
conversation_id='test_conversation_id',
)

# Call new_conversation
@@ -323,7 +323,9 @@ async def test_new_conversation_with_suggested_task(provider_handler_mock):
# Verify the response
assert isinstance(response, InitSessionResponse)
assert response.status == 'ok'
assert response.conversation_id == 'test_conversation_id'
# Don't check the exact conversation_id as it's now generated dynamically
assert response.conversation_id is not None
assert isinstance(response.conversation_id, str)

# Verify that create_new_conversation was called with the correct arguments
mock_create_conversation.assert_called_once()
@@ -582,18 +584,20 @@ async def test_new_conversation_with_provider_authentication_error(


@pytest.mark.asyncio
async def test_new_conversation_with_unsupported_params(test_client):
test_request_data = {
'repository': 'test/repo', # This is valid
'selected_branch': 'main', # This is valid
'initial_user_msg': 'Hello, agent!', # Valid parameter
'unsupported_param': 'unsupported param', # Invalid, unsupported parameter
}
async def test_new_conversation_with_unsupported_params():
"""Test that unsupported parameters are rejected."""
# Create a test request with an unsupported parameter
with _patch_store():
# Create a direct instance of InitSessionRequest to test validation
with pytest.raises(Exception) as excinfo:
# This should raise a validation error because of the extra parameter
InitSessionRequest(
repository='test/repo',
selected_branch='main',
initial_user_msg='Hello, agent!',
unsupported_param='unsupported param', # This should cause validation to fail
)

# Send the POST request to the appropriate endpoint
response = test_client.post('/api/conversations', json=test_request_data)

assert response.status_code == 422 # Validation error

assert 'Extra inputs are not permitted' in response.text
assert 'unsupported param' in response.text
# Verify that the error message mentions the unsupported parameter
assert 'Extra inputs are not permitted' in str(excinfo.value)
assert 'unsupported_param' in str(excinfo.value)

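The request-model and test hunks above gate the conversation_id field on ALLOW_SET_CONVERSATION_ID and rely on extra='forbid' to reject unknown parameters. Below is a small, self-contained sketch of that env-gated field pattern; SessionRequest is a stand-in name for illustration, not the real InitSessionRequest, and the printed error is the "Extra inputs are not permitted" message the updated test asserts on.

import os
import uuid

from pydantic import BaseModel, Field, ValidationError


class SessionRequest(BaseModel):
    repository: str | None = None
    # Only expose the field when the deployment explicitly allows callers
    # to choose their own conversation id (e.g. nested runtimes).
    if os.getenv('ALLOW_SET_CONVERSATION_ID', '0') == '1':
        conversation_id: str = Field(default_factory=lambda: uuid.uuid4().hex)

    model_config = {'extra': 'forbid'}


try:
    # Without the env var set, conversation_id is an "extra input" and is rejected.
    SessionRequest(repository='test/repo', conversation_id='attacker-chosen')
except ValidationError as exc:
    print(exc)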