merge: update settings schema branch with main and SDK

- merge latest main into the GUI settings schema branch
- keep schema-driven LLM settings page and tests after conflict resolution
- update lockfiles to SDK branch head c333aedd

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
openhands
2026-03-19 02:35:30 +00:00
48 changed files with 5942 additions and 1407 deletions

View File

@@ -2,10 +2,11 @@
This module tests the environment variable override functionality that allows
users to inject custom environment variables into sandbox environments via
OH_AGENT_SERVER_ENV_* environment variables.
OH_AGENT_SERVER_ENV environment variable and auto-forwarding of LLM_* variables.
The functionality includes:
- Parsing OH_AGENT_SERVER_ENV_* environment variables
- Auto-forwarding of LLM_* environment variables to agent-server containers
- Explicit overrides via OH_AGENT_SERVER_ENV JSON
- Merging them into sandbox specifications
- Integration across different sandbox types (Docker, Process, Remote)
"""
@@ -25,6 +26,7 @@ from openhands.app_server.sandbox.remote_sandbox_spec_service import (
get_default_sandbox_specs as get_default_remote_sandbox_specs,
)
from openhands.app_server.sandbox.sandbox_spec_service import (
AUTO_FORWARD_PREFIXES,
get_agent_server_env,
)
@@ -185,6 +187,114 @@ class TestGetAgentServerEnv:
assert result == expected
class TestLLMAutoForwarding:
    """Checks for automatic forwarding of LLM_* environment variables.

    Every test replaces ``os.environ`` wholesale (``clear=True``) so that only
    the variables declared in the test are visible to
    ``get_agent_server_env()``.
    """

    def test_auto_forward_prefixes_contains_llm(self):
        """The ``LLM_`` prefix is registered for auto-forwarding."""
        assert 'LLM_' in AUTO_FORWARD_PREFIXES

    def test_llm_timeout_auto_forwarded(self):
        """LLM_TIMEOUT is forwarded while an unrelated variable is dropped."""
        fake_environ = {
            'LLM_TIMEOUT': '3600',
            'OTHER_VAR': 'should_not_be_included',
        }
        with patch.dict(os.environ, fake_environ, clear=True):
            forwarded = get_agent_server_env()
            assert 'LLM_TIMEOUT' in forwarded
            assert forwarded['LLM_TIMEOUT'] == '3600'
            assert 'OTHER_VAR' not in forwarded

    def test_llm_num_retries_auto_forwarded(self):
        """LLM_NUM_RETRIES is forwarded with its value untouched."""
        fake_environ = {'LLM_NUM_RETRIES': '10'}
        with patch.dict(os.environ, fake_environ, clear=True):
            forwarded = get_agent_server_env()
            assert 'LLM_NUM_RETRIES' in forwarded
            assert forwarded['LLM_NUM_RETRIES'] == '10'

    def test_multiple_llm_vars_auto_forwarded(self):
        """Several LLM_* variables are all forwarded; a non-LLM one is not."""
        llm_settings = {
            'LLM_TIMEOUT': '3600',
            'LLM_NUM_RETRIES': '10',
            'LLM_MODEL': 'gpt-4',
            'LLM_BASE_URL': 'https://api.example.com',
            'LLM_API_KEY': 'secret-key',
        }
        fake_environ = {**llm_settings, 'NON_LLM_VAR': 'should_not_be_included'}
        with patch.dict(os.environ, fake_environ, clear=True):
            forwarded = get_agent_server_env()
            for name, value in llm_settings.items():
                assert forwarded[name] == value
            assert 'NON_LLM_VAR' not in forwarded

    def test_explicit_override_takes_precedence(self):
        """A value in OH_AGENT_SERVER_ENV beats the auto-forwarded one."""
        fake_environ = {
            # Value that would be auto-forwarded on its own.
            'LLM_TIMEOUT': '3600',
            # Explicit JSON override for the same key.
            'OH_AGENT_SERVER_ENV': '{"LLM_TIMEOUT": "7200"}',
        }
        with patch.dict(os.environ, fake_environ, clear=True):
            forwarded = get_agent_server_env()
            assert forwarded['LLM_TIMEOUT'] == '7200'

    def test_combined_auto_forward_and_explicit(self):
        """Auto-forwarded and explicitly configured variables are merged."""
        auto_forwarded = {'LLM_TIMEOUT': '3600', 'LLM_NUM_RETRIES': '10'}
        explicit = {'DEBUG': 'true', 'CUSTOM_VAR': 'value'}
        fake_environ = {
            **auto_forwarded,
            'OH_AGENT_SERVER_ENV': '{"DEBUG": "true", "CUSTOM_VAR": "value"}',
        }
        with patch.dict(os.environ, fake_environ, clear=True):
            forwarded = get_agent_server_env()
            for name, value in auto_forwarded.items():
                assert forwarded[name] == value
            for name, value in explicit.items():
                assert forwarded[name] == value

    def test_no_llm_vars_returns_empty_without_explicit(self):
        """With no LLM_* variables and no explicit env, the result is empty."""
        fake_environ = {
            'SOME_OTHER_VAR': 'value',
            'ANOTHER_VAR': 'another_value',
        }
        with patch.dict(os.environ, fake_environ, clear=True):
            forwarded = get_agent_server_env()
            assert forwarded == {}

    def test_llm_prefix_is_case_sensitive(self):
        """Only the upper-case ``LLM_`` prefix matches."""
        fake_environ = {
            'LLM_TIMEOUT': '3600',  # expected in the result
            'llm_timeout': 'lowercase',  # wrong case, expected to be dropped
            'Llm_Timeout': 'mixed',  # wrong case, expected to be dropped
        }
        with patch.dict(os.environ, fake_environ, clear=True):
            forwarded = get_agent_server_env()
            assert 'LLM_TIMEOUT' in forwarded
            assert forwarded['LLM_TIMEOUT'] == '3600'
            for wrong_case_name in ('llm_timeout', 'Llm_Timeout'):
                assert wrong_case_name not in forwarded
class TestDockerSandboxSpecEnvironmentOverride:
"""Test environment variable override integration in Docker sandbox specs."""
@@ -476,3 +586,311 @@ class TestEnvironmentOverrideIntegration:
# Should not have the old variables
assert 'VAR1' not in spec_2.initial_env
assert 'VAR2' not in spec_2.initial_env
class TestDockerSandboxServiceEnvIntegration:
    """Integration tests for environment variable propagation to Docker sandbox containers.

    These tests verify that environment variables are correctly propagated through
    the entire flow from the app-server environment to the agent-server container.
    The Docker client is always a ``MagicMock``; the tests inspect the keyword
    arguments passed to ``containers.run``.
    """

    @staticmethod
    def _make_docker_client(ports, network_mode):
        """Return a mocked Docker client whose ``containers.run`` yields a running container.

        Args:
            ports: Value stored under ``NetworkSettings.Ports`` in the container attrs.
            network_mode: Value stored under ``HostConfig.NetworkMode`` in the container attrs.
        """
        from unittest.mock import MagicMock

        container = MagicMock()
        container.name = 'oh-test-abc123'
        container.image.tags = ['test-image:latest']
        container.attrs = {
            'Created': '2024-01-01T00:00:00Z',
            'Config': {
                'Env': ['SESSION_API_KEY=test-key'],
                'WorkingDir': '/workspace',
            },
            'NetworkSettings': {'Ports': ports},
            'HostConfig': {'NetworkMode': network_mode},
        }
        container.status = 'running'

        docker_client = MagicMock()
        docker_client.containers.run.return_value = container
        docker_client.containers.list.return_value = []
        return docker_client

    @staticmethod
    def _start_sandbox_and_get_run_kwargs(
        docker_client, sandbox_spec, **service_kwargs
    ):
        """Start a sandbox through ``DockerSandboxService`` and return the ``containers.run`` kwargs.

        Args:
            docker_client: Mocked Docker client handed to the service.
            sandbox_spec: Spec returned by the mocked spec service's
                ``get_default_sandbox_spec``.
            **service_kwargs: Extra keyword arguments for ``DockerSandboxService``
                (for example ``use_host_network``).

        Returns:
            The keyword arguments of the last ``docker_client.containers.run`` call.
        """
        import asyncio
        from unittest.mock import AsyncMock, MagicMock

        import httpx

        from openhands.app_server.sandbox.docker_sandbox_service import (
            DockerSandboxService,
            ExposedPort,
        )

        spec_service = MagicMock()
        spec_service.get_default_sandbox_spec = AsyncMock(return_value=sandbox_spec)

        service = DockerSandboxService(
            sandbox_spec_service=spec_service,
            container_name_prefix='oh-test-',
            host_port=3000,
            container_url_pattern='http://localhost:{port}',
            mounts=[],
            exposed_ports=[
                ExposedPort(
                    name='AGENT_SERVER',
                    description='Agent server',
                    container_port=8000,
                )
            ],
            health_check_path='/health',
            httpx_client=MagicMock(spec=httpx.AsyncClient),
            max_num_sandboxes=5,
            docker_client=docker_client,
            **service_kwargs,
        )

        # Use a private event loop instead of asyncio.get_event_loop(): the
        # implicit "current loop" is deprecated in synchronous code and may be
        # missing when an earlier test has already unset it.
        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(service.start_sandbox())
        finally:
            loop.close()

        return docker_client.containers.run.call_args[1]

    def test_llm_env_vars_propagated_to_container_run(self):
        """Test that LLM_* env vars end up in the Docker sandbox spec's initial_env."""
        env_vars = {
            'LLM_TIMEOUT': '3600',
            'LLM_NUM_RETRIES': '10',
            'LLM_MODEL': 'gpt-4',
            'OTHER_VAR': 'should_not_be_forwarded',
        }

        with patch.dict(os.environ, env_vars, clear=True):
            # Build the spec with the real factory so LLM_* vars are picked up.
            sandbox_spec = get_default_docker_sandbox_specs()[0]

            # Verify the sandbox spec has the LLM_* variables
            assert 'LLM_TIMEOUT' in sandbox_spec.initial_env
            assert sandbox_spec.initial_env['LLM_TIMEOUT'] == '3600'
            assert 'LLM_NUM_RETRIES' in sandbox_spec.initial_env
            assert sandbox_spec.initial_env['LLM_NUM_RETRIES'] == '10'
            assert 'LLM_MODEL' in sandbox_spec.initial_env
            assert sandbox_spec.initial_env['LLM_MODEL'] == 'gpt-4'
            # Non-LLM_* variables should not be included
            assert 'OTHER_VAR' not in sandbox_spec.initial_env

    def test_explicit_oh_agent_server_env_overrides_llm_vars(self):
        """Test that OH_AGENT_SERVER_ENV can override auto-forwarded LLM_* variables."""
        env_vars = {
            'LLM_TIMEOUT': '3600',  # Auto-forwarded value
            'OH_AGENT_SERVER_ENV': '{"LLM_TIMEOUT": "7200"}',  # Override value
        }

        with patch.dict(os.environ, env_vars, clear=True):
            sandbox_spec = get_default_docker_sandbox_specs()[0]

            # OH_AGENT_SERVER_ENV should take precedence
            assert sandbox_spec.initial_env['LLM_TIMEOUT'] == '7200'

    def test_multiple_llm_vars_combined_with_explicit_overrides(self):
        """Test complex scenario with multiple LLM_* vars and explicit overrides."""
        env_vars = {
            'LLM_TIMEOUT': '3600',
            'LLM_NUM_RETRIES': '10',
            'LLM_MODEL': 'gpt-4',
            'LLM_TEMPERATURE': '0.7',
            'OH_AGENT_SERVER_ENV': '{"LLM_MODEL": "gpt-3.5-turbo", "CUSTOM_VAR": "custom_value"}',
        }

        with patch.dict(os.environ, env_vars, clear=True):
            sandbox_spec = get_default_docker_sandbox_specs()[0]

            # Auto-forwarded LLM_* vars that weren't overridden
            assert sandbox_spec.initial_env['LLM_TIMEOUT'] == '3600'
            assert sandbox_spec.initial_env['LLM_NUM_RETRIES'] == '10'
            assert sandbox_spec.initial_env['LLM_TEMPERATURE'] == '0.7'
            # LLM_MODEL should be overridden by OH_AGENT_SERVER_ENV
            assert sandbox_spec.initial_env['LLM_MODEL'] == 'gpt-3.5-turbo'
            # Custom variable from OH_AGENT_SERVER_ENV
            assert sandbox_spec.initial_env['CUSTOM_VAR'] == 'custom_value'

    def test_sandbox_spec_env_passed_to_docker_container_run(self):
        """Test that sandbox spec's initial_env is passed to docker container run."""
        docker_client = self._make_docker_client(
            ports={'8000/tcp': [{'HostPort': '32768'}]},
            network_mode='bridge',
        )

        env_vars = {
            'LLM_TIMEOUT': '3600',
            'LLM_NUM_RETRIES': '10',
        }
        # Only the spec is built under the patched environment; the service
        # is constructed and started afterwards.
        with patch.dict(os.environ, env_vars, clear=True):
            sandbox_spec = get_default_docker_sandbox_specs()[0]

        call_kwargs = self._start_sandbox_and_get_run_kwargs(
            docker_client, sandbox_spec
        )
        container_env = call_kwargs['environment']

        # LLM_* variables should be in the container environment
        assert 'LLM_TIMEOUT' in container_env
        assert container_env['LLM_TIMEOUT'] == '3600'
        assert 'LLM_NUM_RETRIES' in container_env
        assert container_env['LLM_NUM_RETRIES'] == '10'
        # Default variables should also be present
        assert 'OPENVSCODE_SERVER_ROOT' in container_env
        assert 'LOG_JSON' in container_env

    def test_host_network_mode_with_env_var(self):
        """Test that AGENT_SERVER_USE_HOST_NETWORK affects container network mode."""
        from openhands.app_server.sandbox.docker_sandbox_service import (
            _get_use_host_network_default,
        )

        # Test with environment variable set
        with patch.dict(
            os.environ, {'AGENT_SERVER_USE_HOST_NETWORK': 'true'}, clear=True
        ):
            assert _get_use_host_network_default() is True

        docker_client = self._make_docker_client(ports={}, network_mode='host')
        sandbox_spec = get_default_docker_sandbox_specs()[0]

        call_kwargs = self._start_sandbox_and_get_run_kwargs(
            docker_client, sandbox_spec, use_host_network=True
        )

        # Verify docker was called with host network mode
        assert call_kwargs['network_mode'] == 'host'
        # Port mappings should be None in host network mode
        assert call_kwargs['ports'] is None

    def test_bridge_network_mode_without_env_var(self):
        """Test that default (bridge) network mode is used when env var is not set."""
        from openhands.app_server.sandbox.docker_sandbox_service import (
            _get_use_host_network_default,
        )

        # Test without environment variable
        with patch.dict(os.environ, {}, clear=True):
            assert _get_use_host_network_default() is False

        docker_client = self._make_docker_client(
            ports={'8000/tcp': [{'HostPort': '32768'}]},
            network_mode='bridge',
        )
        sandbox_spec = get_default_docker_sandbox_specs()[0]

        call_kwargs = self._start_sandbox_and_get_run_kwargs(
            docker_client, sandbox_spec, use_host_network=False
        )

        # Verify docker was called with bridge network mode (network_mode=None)
        assert call_kwargs['network_mode'] is None
        # Port mappings should be present in bridge mode
        assert call_kwargs['ports'] is not None
        assert 8000 in call_kwargs['ports']

View File

@@ -263,7 +263,7 @@ class TestGetConversationHooks:
assert response.status_code == status.HTTP_404_NOT_FOUND
async def test_get_hooks_returns_404_when_sandbox_not_running(self):
async def test_get_hooks_returns_404_when_sandbox_not_found(self):
conversation_id = uuid4()
sandbox_id = str(uuid4())
@@ -291,3 +291,44 @@ class TestGetConversationHooks:
)
assert response.status_code == status.HTTP_404_NOT_FOUND
async def test_get_hooks_returns_empty_list_when_sandbox_paused(self):
    """A conversation whose sandbox is PAUSED yields HTTP 200 with no hooks."""
    import json

    conversation_id = uuid4()
    sandbox_id = str(uuid4())

    # Conversation lookup returns a conversation bound to the paused sandbox.
    conversation_lookup = MagicMock()
    conversation_lookup.get_app_conversation = AsyncMock(
        return_value=AppConversation(
            id=conversation_id,
            created_by_user_id='test-user',
            sandbox_id=sandbox_id,
            sandbox_status=SandboxStatus.PAUSED,
        )
    )

    # Sandbox lookup returns the sandbox itself, also in PAUSED state.
    sandbox_lookup = MagicMock()
    sandbox_lookup.get_sandbox = AsyncMock(
        return_value=SandboxInfo(
            id=sandbox_id,
            created_by_user_id='test-user',
            status=SandboxStatus.PAUSED,
            sandbox_spec_id=str(uuid4()),
            session_api_key='test-api-key',
        )
    )

    response = await get_conversation_hooks(
        conversation_id=conversation_id,
        app_conversation_service=conversation_lookup,
        sandbox_service=sandbox_lookup,
        sandbox_spec_service=MagicMock(),
        httpx_client=AsyncMock(spec=httpx.AsyncClient),
    )

    assert response.status_code == status.HTTP_200_OK
    payload = json.loads(response.body.decode('utf-8'))
    assert payload == {'hooks': []}

View File

@@ -203,7 +203,7 @@ class TestGetConversationSkills:
Arrange: Setup conversation but no sandbox
Act: Call get_conversation_skills endpoint
Assert: Response is 404 with sandbox error message
Assert: Response is 404
"""
# Arrange
conversation_id = uuid4()
@@ -237,19 +237,13 @@ class TestGetConversationSkills:
# Assert
assert response.status_code == status.HTTP_404_NOT_FOUND
content = response.body.decode('utf-8')
import json
data = json.loads(content)
assert 'error' in data
assert 'Sandbox not found' in data['error']
async def test_get_skills_returns_empty_list_when_sandbox_paused(self):
"""Test endpoint returns empty skills when sandbox is PAUSED (closed conversation).
async def test_get_skills_returns_404_when_sandbox_not_running(self):
"""Test endpoint returns 404 when sandbox is not in RUNNING state.
Arrange: Setup conversation with stopped sandbox
Arrange: Setup conversation with paused sandbox
Act: Call get_conversation_skills endpoint
Assert: Response is 404 with sandbox not running message
Assert: Response is 200 with empty skills list
"""
# Arrange
conversation_id = uuid4()
@@ -290,13 +284,12 @@ class TestGetConversationSkills:
)
# Assert
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.status_code == status.HTTP_200_OK
content = response.body.decode('utf-8')
import json
data = json.loads(content)
assert 'error' in data
assert 'not running' in data['error']
assert data == {'skills': []}
async def test_get_skills_handles_task_trigger_skills(self):
"""Test endpoint correctly handles skills with TaskTrigger.

View File

@@ -245,6 +245,61 @@ class TestDockerSandboxService:
assert len(result.items) == 0
assert result.next_page_id is None
async def test_search_sandboxes_skips_containers_with_no_image_tags(
    self, service, mock_running_container
):
    """Containers whose image has no tags are skipped instead of crashing.

    Regression test: when a container's image has been rebuilt with the same
    tag, the old container's image loses its tags, so ``container.image.tags``
    is an empty list. Previously this caused an IndexError.
    """
    # A container whose image was retagged/rebuilt and now carries no tags.
    untagged = MagicMock()
    untagged.attrs = {
        'Created': '2024-01-15T10:30:00.000000000Z',
        'Config': {'Env': []},
        'NetworkSettings': {'Ports': {}},
    }
    untagged.image.id = 'sha256:abc123def456'
    untagged.image.tags = []
    untagged.status = 'paused'
    untagged.name = 'oh-test-tagless'

    listed_containers = [mock_running_container, untagged]
    service.docker_client.containers.list.return_value = listed_containers
    service.httpx_client.get.return_value.raise_for_status.return_value = None

    # Must complete without raising IndexError.
    page = await service.search_sandboxes()

    # Only the properly tagged container survives.
    assert isinstance(page, SandboxPage)
    assert [item.id for item in page.items] == ['oh-test-abc123']
async def test_get_sandbox_returns_none_for_tagless_image(self, service):
    """``get_sandbox`` returns None for a container whose image has no tags."""
    untagged = MagicMock()
    untagged.attrs = {
        'Created': '2024-01-15T10:30:00.000000000Z',
        'Config': {'Env': []},
        'NetworkSettings': {'Ports': {}},
    }
    untagged.image.id = 'sha256:abc123def456'
    untagged.image.tags = []
    untagged.status = 'paused'
    untagged.name = 'oh-test-tagless'
    service.docker_client.containers.get.return_value = untagged

    # Must complete without raising IndexError.
    sandbox = await service.get_sandbox('oh-test-tagless')

    assert sandbox is None
async def test_search_sandboxes_filters_by_prefix(self, service):
"""Test that search filters containers by name prefix."""
# Setup
@@ -1199,6 +1254,59 @@ class TestDockerSandboxServiceInjector:
injector = DockerSandboxServiceInjector(use_host_network=True)
assert injector.use_host_network is True
def test_use_host_network_from_agent_server_env_var(self):
    """Setting AGENT_SERVER_USE_HOST_NETWORK=true turns on host network mode."""
    import os
    from unittest.mock import patch

    from openhands.app_server.sandbox.docker_sandbox_service import (
        DockerSandboxServiceInjector,
    )

    patched_environ = {'AGENT_SERVER_USE_HOST_NETWORK': 'true'}
    with patch.dict(os.environ, patched_environ, clear=True):
        created = DockerSandboxServiceInjector()
        assert created.use_host_network is True
def test_use_host_network_env_var_accepts_various_true_values(self):
    """Each accepted spelling of "true" enables ``use_host_network``."""
    import os
    from unittest.mock import patch

    from openhands.app_server.sandbox.docker_sandbox_service import (
        DockerSandboxServiceInjector,
    )

    truthy_spellings = ('true', 'TRUE', 'True', '1', 'yes', 'YES', 'Yes')
    for true_value in truthy_spellings:
        patched_environ = {'AGENT_SERVER_USE_HOST_NETWORK': true_value}
        with patch.dict(os.environ, patched_environ, clear=True):
            created = DockerSandboxServiceInjector()
            assert created.use_host_network is True, (
                f'Failed for value: {true_value}'
            )
def test_use_host_network_env_var_defaults_to_false(self):
    """An unset or empty AGENT_SERVER_USE_HOST_NETWORK leaves host networking off."""
    import os
    from unittest.mock import patch

    from openhands.app_server.sandbox.docker_sandbox_service import (
        DockerSandboxServiceInjector,
    )

    falsy_environments = (
        {},  # variable absent
        {'AGENT_SERVER_USE_HOST_NETWORK': ''},  # variable present but empty
    )
    for patched_environ in falsy_environments:
        with patch.dict(os.environ, patched_environ, clear=True):
            created = DockerSandboxServiceInjector()
            assert created.use_host_network is False
class TestDockerSandboxServiceInjectorFromEnv:
"""Test cases for DockerSandboxServiceInjector environment variable configuration."""

View File

@@ -0,0 +1,294 @@
from __future__ import annotations
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, patch
from uuid import uuid4
import httpx
import pytest
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversation,
)
from openhands.app_server.event_callback.event_callback_models import (
EventCallback,
EventCallbackStatus,
)
from openhands.app_server.event_callback.set_title_callback_processor import (
SetTitleCallbackProcessor,
)
from openhands.app_server.utils.docker_utils import (
replace_localhost_hostname_for_docker,
)
from openhands.sdk import Message, MessageEvent, TextContent
class _FakeHttpxClient:
    """Async stand-in for an httpx client that replays a canned list of titles.

    Every ``get`` call is recorded as ``(url, headers)`` in ``calls`` and is
    answered with a 200 JSON response of the form ``{'title': ...}``. The n-th
    call returns the n-th configured title; calls beyond the end of the list
    keep returning the last title.
    """

    def __init__(self, titles: list[str | None]):
        self.calls: list[tuple[str, dict[str, str] | None]] = []
        self._titles = titles

    async def get(self, url: str, headers: dict[str, str] | None = None):
        """Record the call and return the title response for this call number."""
        self.calls.append((url, headers))
        # Clamp to the final entry once more calls than titles have been made.
        position = min(len(self.calls), len(self._titles)) - 1
        outgoing_request = httpx.Request('GET', url)
        return httpx.Response(
            200,
            json={'title': self._titles[position]},
            request=outgoing_request,
        )
class _FailingHttpxClient:
    """Async stand-in for an httpx client whose ``get`` always raises.

    Each call is still recorded as ``(url, headers)`` in ``calls`` before the
    configured error is raised.
    """

    def __init__(self, error: httpx.HTTPError):
        self.calls: list[tuple[str, dict[str, str] | None]] = []
        self._error = error

    async def get(self, url: str, headers: dict[str, str] | None = None):
        """Record the call, then raise the error given at construction."""
        recorded_call = (url, headers)
        self.calls.append(recorded_call)
        raise self._error
@asynccontextmanager
async def _ctx(obj):
    """Wrap *obj* in an async context manager that yields it unchanged.

    Lets a plain object (for example a mock) be used with ``async with``.
    """
    yield obj
def _make_app_conversation(conversation_id, conversation_url, session_api_key):
    """Build the AppConversation returned by the mocked conversation service."""
    return AppConversation(
        id=conversation_id,
        created_by_user_id='user',
        sandbox_id='sandbox',
        title=f'Conversation {conversation_id.hex[:5]}',
        conversation_url=conversation_url,
        session_api_key=session_api_key,
    )


def _make_callback_and_event(conversation_id):
    """Build the EventCallback and the user MessageEvent passed to the processor."""
    callback = EventCallback(
        conversation_id=conversation_id, processor=SetTitleCallbackProcessor()
    )
    event = MessageEvent(
        source='user',
        llm_message=Message(role='user', content=[TextContent(text='hi')]),
    )
    return callback, event


@asynccontextmanager
async def _patched_processor_services(app_conversation, httpx_client):
    """Patch the ``openhands.app_server.config`` getters and the processor's sleep.

    While active, each patched getter accepts one argument and returns an async
    context manager yielding the matching fake, and ``asyncio.sleep`` as seen by
    ``set_title_callback_processor`` is an ``AsyncMock``.

    Args:
        app_conversation: Value returned by the mocked
            ``get_app_conversation``.
        httpx_client: Fake client yielded by the patched ``get_httpx_client``.

    Yields:
        ``(app_conversation_info_service, event_callback_service)``: the two
        ``AsyncMock`` services the tests assert against.
    """
    app_conversation_service = AsyncMock()
    app_conversation_service.get_app_conversation.return_value = app_conversation
    app_conversation_info_service = AsyncMock()
    event_callback_service = AsyncMock()

    def getter_for(fake):
        # Bind ``fake`` through the parameter so each getter keeps its own value.
        return lambda _state: _ctx(fake)

    with (
        patch(
            'openhands.app_server.config.get_app_conversation_service',
            getter_for(app_conversation_service),
        ),
        patch(
            'openhands.app_server.config.get_app_conversation_info_service',
            getter_for(app_conversation_info_service),
        ),
        patch(
            'openhands.app_server.config.get_event_callback_service',
            getter_for(event_callback_service),
        ),
        patch(
            'openhands.app_server.config.get_httpx_client',
            getter_for(httpx_client),
        ),
        patch(
            'openhands.app_server.event_callback.'
            'set_title_callback_processor.asyncio.sleep',
            new=AsyncMock(),
        ),
    ):
        yield app_conversation_info_service, event_callback_service


@pytest.mark.asyncio
async def test_set_title_callback_processor_fetches_title_from_conversation():
    """A title appearing on the fourth poll is saved and the callback is disabled."""
    conversation_id = uuid4()
    session_api_key = 'test-session-key'
    conversation_url = f'http://localhost:8000/api/conversations/{conversation_id.hex}'
    app_conversation = _make_app_conversation(
        conversation_id, conversation_url, session_api_key
    )

    # The first three polls report no title; the fourth returns one.
    httpx_client = _FakeHttpxClient(titles=[None, None, None, 'Generated Title'])

    callback, event = _make_callback_and_event(conversation_id)
    processor = SetTitleCallbackProcessor()

    async with _patched_processor_services(app_conversation, httpx_client) as (
        app_conversation_info_service,
        event_callback_service,
    ):
        result = await processor(conversation_id, callback, event)

    assert result is not None

    assert len(httpx_client.calls) == 4
    expected_url = replace_localhost_hostname_for_docker(conversation_url)
    assert httpx_client.calls[0][0] == expected_url
    assert httpx_client.calls[0][1] == {'X-Session-API-Key': session_api_key}

    app_conversation_info_service.save_app_conversation_info.assert_called_once()
    saved_info = app_conversation_info_service.save_app_conversation_info.call_args[0][
        0
    ]
    assert saved_info.title == 'Generated Title'

    assert callback.status == EventCallbackStatus.DISABLED
    event_callback_service.save_event_callback.assert_called_once()


@pytest.mark.asyncio
async def test_set_title_callback_processor_no_title_yet_returns_none():
    """When no poll returns a title, nothing is saved and the callback stays active."""
    conversation_id = uuid4()
    session_api_key = 'test-session-key'
    conversation_url = f'http://localhost:8000/api/conversations/{conversation_id.hex}'
    app_conversation = _make_app_conversation(
        conversation_id, conversation_url, session_api_key
    )

    httpx_client = _FakeHttpxClient(titles=[None])

    callback, event = _make_callback_and_event(conversation_id)
    processor = SetTitleCallbackProcessor()

    async with _patched_processor_services(app_conversation, httpx_client) as (
        app_conversation_info_service,
        event_callback_service,
    ):
        result = await processor(conversation_id, callback, event)

    assert result is None
    app_conversation_info_service.save_app_conversation_info.assert_not_called()
    event_callback_service.save_event_callback.assert_not_called()
    assert callback.status == EventCallbackStatus.ACTIVE


@pytest.mark.asyncio
async def test_set_title_callback_processor_request_errors_return_none():
    """Request errors on every poll are logged at debug level and yield None."""
    conversation_id = uuid4()
    session_api_key = 'test-session-key'
    conversation_url = f'http://localhost:8000/api/conversations/{conversation_id.hex}'
    app_conversation = _make_app_conversation(
        conversation_id, conversation_url, session_api_key
    )

    httpx_client = _FailingHttpxClient(
        httpx.RequestError(
            'boom',
            request=httpx.Request(
                'GET', replace_localhost_hostname_for_docker(conversation_url)
            ),
        )
    )

    callback, event = _make_callback_and_event(conversation_id)
    processor = SetTitleCallbackProcessor()

    async with _patched_processor_services(app_conversation, httpx_client) as (
        app_conversation_info_service,
        event_callback_service,
    ):
        with patch(
            'openhands.app_server.event_callback.'
            'set_title_callback_processor._logger.debug'
        ) as logger_debug:
            result = await processor(conversation_id, callback, event)

    assert result is None
    assert len(httpx_client.calls) == 4
    assert logger_debug.call_count == 4
    app_conversation_info_service.save_app_conversation_info.assert_not_called()
    event_callback_service.save_event_callback.assert_not_called()
    assert callback.status == EventCallbackStatus.ACTIVE

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import logging
import shutil
import tempfile
import threading
from abc import ABC
from dataclasses import dataclass, field
from io import BytesIO, StringIO
@@ -122,6 +123,57 @@ class TestLocalFileStore(TestCase, _StorageTest):
f'Failed to remove temporary directory {self.temp_dir}: {e}'
)
def test_concurrent_writes_no_corruption(self):
    """Concurrent writes to one file must leave exactly one writer's content.

    Nine threads are released together by a barrier and each writes a
    different prefix of "123456789" (lengths 9 down to 1) to the same file.
    If writes were not atomic, a shorter write landing on top of a longer one
    could leave trailing bytes from the longer content behind.

    The test passes when no writer raised and the final file content equals
    one of the nine strings that were written.
    """
    filename = 'concurrent_test.txt'
    digits = '123456789'
    # Longest first: "123456789", "12345678", ..., "1"
    valid_contents = [digits[:length] for length in range(len(digits), 0, -1)]
    errors: list[Exception] = []
    start_gate = threading.Barrier(len(valid_contents))

    def write_content(content: str):
        try:
            # Hold every writer here until all of them are ready.
            start_gate.wait()
            self.store.write(filename, content)
        except Exception as failure:
            errors.append(failure)

    writers = []
    for payload in valid_contents:
        writers.append(threading.Thread(target=write_content, args=(payload,)))
    for writer in writers:
        writer.start()
    for writer in writers:
        writer.join()

    # No writer may have failed.
    self.assertEqual(
        errors, [], f'Errors occurred during concurrent writes: {errors}'
    )

    # The surviving content must be exactly one of the written strings.
    final_content = self.store.read(filename)
    self.assertIn(
        final_content,
        valid_contents,
        f"File content '{final_content}' is not one of the valid strings. "
        f'Length: {len(final_content)}. This indicates file corruption from '
        f'concurrent writes (e.g., shorter write did not fully replace longer write).',
    )
class TestInMemoryFileStore(TestCase, _StorageTest):
def setUp(self):