Pydantic-based configuration and setting objects (#6321)

Co-authored-by: Calvin Smith <calvin@all-hands.dev>
Co-authored-by: Graham Neubig <neubig@gmail.com>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
This commit is contained in:
Calvin Smith 2025-01-17 12:33:22 -07:00 committed by GitHub
parent 899c1f8360
commit a12087243a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
31 changed files with 337 additions and 432 deletions

View File

@ -80,7 +80,7 @@ def load_dependencies(runtime: Runtime) -> List[str]:
def init_task_env(runtime: Runtime, hostname: str, env_llm_config: LLMConfig):
command = (
f'SERVER_HOSTNAME={hostname} '
f'LITELLM_API_KEY={env_llm_config.api_key} '
f'LITELLM_API_KEY={env_llm_config.api_key.get_secret_value() if env_llm_config.api_key else None} '
f'LITELLM_BASE_URL={env_llm_config.base_url} '
f'LITELLM_MODEL={env_llm_config.model} '
'bash /utils/init.sh'
@ -165,7 +165,7 @@ def run_evaluator(
runtime: Runtime, env_llm_config: LLMConfig, trajectory_path: str, result_path: str
):
command = (
f'LITELLM_API_KEY={env_llm_config.api_key} '
f'LITELLM_API_KEY={env_llm_config.api_key.get_secret_value() if env_llm_config.api_key else None} '
f'LITELLM_BASE_URL={env_llm_config.base_url} '
f'LITELLM_MODEL={env_llm_config.model} '
f"DECRYPTION_KEY='theagentcompany is all you need' " # Hardcoded Key

View File

@ -52,30 +52,6 @@ class EvalMetadata(BaseModel):
details: dict[str, Any] | None = None
condenser_config: CondenserConfig | None = None
def model_dump(self, *args, **kwargs):
dumped_dict = super().model_dump(*args, **kwargs)
# avoid leaking sensitive information
dumped_dict['llm_config'] = self.llm_config.to_safe_dict()
if hasattr(self.condenser_config, 'llm_config'):
dumped_dict['condenser_config']['llm_config'] = (
self.condenser_config.llm_config.to_safe_dict()
)
return dumped_dict
def model_dump_json(self, *args, **kwargs):
dumped = super().model_dump_json(*args, **kwargs)
dumped_dict = json.loads(dumped)
# avoid leaking sensitive information
dumped_dict['llm_config'] = self.llm_config.to_safe_dict()
if hasattr(self.condenser_config, 'llm_config'):
dumped_dict['condenser_config']['llm_config'] = (
self.condenser_config.llm_config.to_safe_dict()
)
logger.debug(f'Dumped metadata: {dumped_dict}')
return json.dumps(dumped_dict)
class EvalOutput(BaseModel):
# NOTE: User-specified
@ -98,23 +74,6 @@ class EvalOutput(BaseModel):
# Optionally save the input test instance
instance: dict[str, Any] | None = None
def model_dump(self, *args, **kwargs):
dumped_dict = super().model_dump(*args, **kwargs)
# Remove None values
dumped_dict = {k: v for k, v in dumped_dict.items() if v is not None}
# Apply custom serialization for metadata (to avoid leaking sensitive information)
if self.metadata is not None:
dumped_dict['metadata'] = self.metadata.model_dump()
return dumped_dict
def model_dump_json(self, *args, **kwargs):
dumped = super().model_dump_json(*args, **kwargs)
dumped_dict = json.loads(dumped)
# Apply custom serialization for metadata (to avoid leaking sensitive information)
if 'metadata' in dumped_dict:
dumped_dict['metadata'] = json.loads(self.metadata.model_dump_json())
return json.dumps(dumped_dict)
class EvalException(Exception):
pass
@ -314,7 +273,7 @@ def update_progress(
logger.info(
f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
)
output_fp.write(json.dumps(result.model_dump()) + '\n')
output_fp.write(result.model_dump_json() + '\n')
output_fp.flush()

View File

@ -37,21 +37,17 @@ export SANDBOX_TIMEOUT='300'
## Type Handling
The `load_from_env` function attempts to cast environment variable values to the types specified in the dataclasses. It handles:
The `load_from_env` function attempts to cast environment variable values to the types specified in the models. It handles:
- Basic types (str, int, bool)
- Optional types (e.g., `str | None`)
- Nested dataclasses
- Nested models
If type casting fails, an error is logged, and the default value is retained.
## Default Values
If an environment variable is not set, the default value specified in the dataclass is used.
## Nested Configurations
The `AppConfig` class contains nested configurations like `LLMConfig` and `AgentConfig`. The `load_from_env` function handles these by recursively processing nested dataclasses with updated prefixes.
If an environment variable is not set, the default value specified in the model is used.
## Security Considerations

View File

@ -1,11 +1,9 @@
from dataclasses import dataclass, field, fields
from pydantic import BaseModel, Field
from openhands.core.config.condenser_config import CondenserConfig, NoOpCondenserConfig
from openhands.core.config.config_utils import get_field_info
@dataclass
class AgentConfig:
class AgentConfig(BaseModel):
"""Configuration for the agent.
Attributes:
@ -22,20 +20,13 @@ class AgentConfig:
condenser: Configuration for the memory condenser. Default is NoOpCondenserConfig.
"""
codeact_enable_browsing: bool = True
codeact_enable_llm_editor: bool = False
codeact_enable_jupyter: bool = True
micro_agent_name: str | None = None
memory_enabled: bool = False
memory_max_threads: int = 3
llm_config: str | None = None
enable_prompt_extensions: bool = True
disabled_microagents: list[str] | None = None
condenser: CondenserConfig = field(default_factory=NoOpCondenserConfig) # type: ignore
def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
result = {}
for f in fields(self):
result[f.name] = get_field_info(f)
return result
codeact_enable_browsing: bool = Field(default=True)
codeact_enable_llm_editor: bool = Field(default=False)
codeact_enable_jupyter: bool = Field(default=True)
micro_agent_name: str | None = Field(default=None)
memory_enabled: bool = Field(default=False)
memory_max_threads: int = Field(default=3)
llm_config: str | None = Field(default=None)
enable_prompt_extensions: bool = Field(default=False)
disabled_microagents: list[str] | None = Field(default=None)
condenser: CondenserConfig = Field(default_factory=NoOpCondenserConfig)

View File

@ -1,20 +1,20 @@
from dataclasses import dataclass, field, fields, is_dataclass
from typing import ClassVar
from pydantic import BaseModel, Field, SecretStr
from openhands.core import logger
from openhands.core.config.agent_config import AgentConfig
from openhands.core.config.config_utils import (
OH_DEFAULT_AGENT,
OH_MAX_ITERATIONS,
get_field_info,
model_defaults_to_dict,
)
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.sandbox_config import SandboxConfig
from openhands.core.config.security_config import SecurityConfig
@dataclass
class AppConfig:
class AppConfig(BaseModel):
"""Configuration for the app.
Attributes:
@ -46,37 +46,39 @@ class AppConfig:
input is read line by line. When enabled, input continues until /exit command.
"""
llms: dict[str, LLMConfig] = field(default_factory=dict)
agents: dict = field(default_factory=dict)
default_agent: str = OH_DEFAULT_AGENT
sandbox: SandboxConfig = field(default_factory=SandboxConfig)
security: SecurityConfig = field(default_factory=SecurityConfig)
runtime: str = 'docker'
file_store: str = 'local'
file_store_path: str = '/tmp/openhands_file_store'
save_trajectory_path: str | None = None
workspace_base: str | None = None
workspace_mount_path: str | None = None
workspace_mount_path_in_sandbox: str = '/workspace'
workspace_mount_rewrite: str | None = None
cache_dir: str = '/tmp/cache'
run_as_openhands: bool = True
max_iterations: int = OH_MAX_ITERATIONS
max_budget_per_task: float | None = None
e2b_api_key: str = ''
modal_api_token_id: str = ''
modal_api_token_secret: str = ''
disable_color: bool = False
jwt_secret: str = ''
debug: bool = False
file_uploads_max_file_size_mb: int = 0
file_uploads_restrict_file_types: bool = False
file_uploads_allowed_extensions: list[str] = field(default_factory=lambda: ['.*'])
runloop_api_key: str | None = None
cli_multiline_input: bool = False
llms: dict[str, LLMConfig] = Field(default_factory=dict)
agents: dict = Field(default_factory=dict)
default_agent: str = Field(default=OH_DEFAULT_AGENT)
sandbox: SandboxConfig = Field(default_factory=SandboxConfig)
security: SecurityConfig = Field(default_factory=SecurityConfig)
runtime: str = Field(default='docker')
file_store: str = Field(default='local')
file_store_path: str = Field(default='/tmp/openhands_file_store')
save_trajectory_path: str | None = Field(default=None)
workspace_base: str | None = Field(default=None)
workspace_mount_path: str | None = Field(default=None)
workspace_mount_path_in_sandbox: str = Field(default='/workspace')
workspace_mount_rewrite: str | None = Field(default=None)
cache_dir: str = Field(default='/tmp/cache')
run_as_openhands: bool = Field(default=True)
max_iterations: int = Field(default=OH_MAX_ITERATIONS)
max_budget_per_task: float | None = Field(default=None)
e2b_api_key: SecretStr | None = Field(default=None)
modal_api_token_id: SecretStr | None = Field(default=None)
modal_api_token_secret: SecretStr | None = Field(default=None)
disable_color: bool = Field(default=False)
jwt_secret: SecretStr | None = Field(default=None)
debug: bool = Field(default=False)
file_uploads_max_file_size_mb: int = Field(default=0)
file_uploads_restrict_file_types: bool = Field(default=False)
file_uploads_allowed_extensions: list[str] = Field(default_factory=lambda: ['.*'])
runloop_api_key: SecretStr | None = Field(default=None)
cli_multiline_input: bool = Field(default=False)
defaults_dict: ClassVar[dict] = {}
model_config = {'extra': 'forbid'}
def get_llm_config(self, name='llm') -> LLMConfig:
"""'llm' is the name for default config (for backward compatibility prior to 0.8)."""
if name in self.llms:
@ -115,42 +117,7 @@ class AppConfig:
def get_agent_configs(self) -> dict[str, AgentConfig]:
return self.agents
def __post_init__(self):
def model_post_init(self, __context):
"""Post-initialization hook, called when the instance is created with only default values."""
AppConfig.defaults_dict = self.defaults_to_dict()
def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
result = {}
for f in fields(self):
field_value = getattr(self, f.name)
# dataclasses compute their defaults themselves
if is_dataclass(type(field_value)):
result[f.name] = field_value.defaults_to_dict()
else:
result[f.name] = get_field_info(f)
return result
def __str__(self):
attr_str = []
for f in fields(self):
attr_name = f.name
attr_value = getattr(self, f.name)
if attr_name in [
'e2b_api_key',
'github_token',
'jwt_secret',
'modal_api_token_id',
'modal_api_token_secret',
'runloop_api_key',
]:
attr_value = '******' if attr_value else None
attr_str.append(f'{attr_name}={repr(attr_value)}')
return f"AppConfig({', '.join(attr_str)}"
def __repr__(self):
return self.__str__()
super().model_post_init(__context)
AppConfig.defaults_dict = model_defaults_to_dict(self)

View File

@ -1,19 +1,22 @@
from types import UnionType
from typing import get_args, get_origin
from typing import Any, get_args, get_origin
from pydantic import BaseModel
from pydantic.fields import FieldInfo
OH_DEFAULT_AGENT = 'CodeActAgent'
OH_MAX_ITERATIONS = 500
def get_field_info(f):
def get_field_info(field: FieldInfo) -> dict[str, Any]:
"""Extract information about a dataclass field: type, optional, and default.
Args:
f: The field to extract information from.
field: The field to extract information from.
Returns: A dict with the field's type, whether it's optional, and its default value.
"""
field_type = f.type
field_type = field.annotation
optional = False
# for types like str | None, find the non-None type and set optional to True
@ -33,7 +36,21 @@ def get_field_info(f):
)
# default is always present
default = f.default
default = field.default
# return a schema with the useful info for frontend
return {'type': type_name.lower(), 'optional': optional, 'default': default}
def model_defaults_to_dict(model: BaseModel) -> dict[str, Any]:
"""Serialize field information in a dict for the frontend, including type hints, defaults, and whether it's optional."""
result = {}
for name, field in model.model_fields.items():
field_value = getattr(model, name)
if isinstance(field_value, BaseModel):
result[name] = model_defaults_to_dict(field_value)
else:
result[name] = get_field_info(field)
return result

View File

@ -1,14 +1,14 @@
import os
from dataclasses import dataclass, fields
from __future__ import annotations
import os
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from openhands.core.config.config_utils import get_field_info
from openhands.core.logger import LOG_DIR
LLM_SENSITIVE_FIELDS = ['api_key', 'aws_access_key_id', 'aws_secret_access_key']
@dataclass
class LLMConfig:
class LLMConfig(BaseModel):
"""Configuration for the LLM model.
Attributes:
@ -47,99 +47,56 @@ class LLMConfig:
native_tool_calling: Whether to use native tool calling if supported by the model. Can be True, False, or not set.
"""
model: str = 'claude-3-5-sonnet-20241022'
api_key: str | None = None
base_url: str | None = None
api_version: str | None = None
embedding_model: str = 'local'
embedding_base_url: str | None = None
embedding_deployment_name: str | None = None
aws_access_key_id: str | None = None
aws_secret_access_key: str | None = None
aws_region_name: str | None = None
openrouter_site_url: str = 'https://docs.all-hands.dev/'
openrouter_app_name: str = 'OpenHands'
num_retries: int = 8
retry_multiplier: float = 2
retry_min_wait: int = 15
retry_max_wait: int = 120
timeout: int | None = None
max_message_chars: int = 30_000 # maximum number of characters in an observation's content when sent to the llm
temperature: float = 0.0
top_p: float = 1.0
custom_llm_provider: str | None = None
max_input_tokens: int | None = None
max_output_tokens: int | None = None
input_cost_per_token: float | None = None
output_cost_per_token: float | None = None
ollama_base_url: str | None = None
model: str = Field(default='claude-3-5-sonnet-20241022')
api_key: SecretStr | None = Field(default=None)
base_url: str | None = Field(default=None)
api_version: str | None = Field(default=None)
embedding_model: str = Field(default='local')
embedding_base_url: str | None = Field(default=None)
embedding_deployment_name: str | None = Field(default=None)
aws_access_key_id: SecretStr | None = Field(default=None)
aws_secret_access_key: SecretStr | None = Field(default=None)
aws_region_name: str | None = Field(default=None)
openrouter_site_url: str = Field(default='https://docs.all-hands.dev/')
openrouter_app_name: str = Field(default='OpenHands')
num_retries: int = Field(default=8)
retry_multiplier: float = Field(default=2)
retry_min_wait: int = Field(default=15)
retry_max_wait: int = Field(default=120)
timeout: int | None = Field(default=None)
max_message_chars: int = Field(
default=30_000
) # maximum number of characters in an observation's content when sent to the llm
temperature: float = Field(default=0.0)
top_p: float = Field(default=1.0)
custom_llm_provider: str | None = Field(default=None)
max_input_tokens: int | None = Field(default=None)
max_output_tokens: int | None = Field(default=None)
input_cost_per_token: float | None = Field(default=None)
output_cost_per_token: float | None = Field(default=None)
ollama_base_url: str | None = Field(default=None)
# This setting can be sent in each call to litellm
drop_params: bool = True
drop_params: bool = Field(default=True)
# Note: this setting is actually global, unlike drop_params
modify_params: bool = True
disable_vision: bool | None = None
reasoning_effort: str | None = None
caching_prompt: bool = True
log_completions: bool = False
log_completions_folder: str = os.path.join(LOG_DIR, 'completions')
custom_tokenizer: str | None = None
native_tool_calling: bool | None = None
modify_params: bool = Field(default=True)
disable_vision: bool | None = Field(default=None)
caching_prompt: bool = Field(default=True)
log_completions: bool = Field(default=False)
log_completions_folder: str = Field(default=os.path.join(LOG_DIR, 'completions'))
custom_tokenizer: str | None = Field(default=None)
native_tool_calling: bool | None = Field(default=None)
model_config = {'extra': 'forbid'}
def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
result = {}
for f in fields(self):
result[f.name] = get_field_info(f)
return result
def model_post_init(self, __context: Any):
"""Post-initialization hook to assign OpenRouter-related variables to environment variables.
def __post_init__(self):
"""
Post-initialization hook to assign OpenRouter-related variables to environment variables.
This ensures that these values are accessible to litellm at runtime.
"""
super().model_post_init(__context)
# Assign OpenRouter-specific variables to environment variables
if self.openrouter_site_url:
os.environ['OR_SITE_URL'] = self.openrouter_site_url
if self.openrouter_app_name:
os.environ['OR_APP_NAME'] = self.openrouter_app_name
def __str__(self):
attr_str = []
for f in fields(self):
attr_name = f.name
attr_value = getattr(self, f.name)
if attr_name in LLM_SENSITIVE_FIELDS:
attr_value = '******' if attr_value else None
attr_str.append(f'{attr_name}={repr(attr_value)}')
return f"LLMConfig({', '.join(attr_str)})"
def __repr__(self):
return self.__str__()
def to_safe_dict(self):
"""Return a dict with the sensitive fields replaced with ******."""
ret = self.__dict__.copy()
for k, v in ret.items():
if k in LLM_SENSITIVE_FIELDS:
ret[k] = '******' if v else None
elif isinstance(v, LLMConfig):
ret[k] = v.to_safe_dict()
return ret
@classmethod
def from_dict(cls, llm_config_dict: dict) -> 'LLMConfig':
"""Create an LLMConfig object from a dictionary.
This function is used to create an LLMConfig object from a dictionary.
"""
# Keep None values to preserve defaults, filter out other dicts
args = {
k: v
for k, v in llm_config_dict.items()
if not isinstance(v, dict) or v is None
}
return cls(**args)

View File

@ -1,11 +1,9 @@
import os
from dataclasses import dataclass, field, fields
from openhands.core.config.config_utils import get_field_info
from pydantic import BaseModel, Field
@dataclass
class SandboxConfig:
class SandboxConfig(BaseModel):
"""Configuration for the sandbox.
Attributes:
@ -39,48 +37,32 @@ class SandboxConfig:
This should be a JSON string that will be parsed into a dictionary.
"""
remote_runtime_api_url: str = 'http://localhost:8000'
local_runtime_url: str = 'http://localhost'
keep_runtime_alive: bool = False
rm_all_containers: bool = False
api_key: str | None = None
base_container_image: str = 'nikolaik/python-nodejs:python3.12-nodejs22' # default to nikolaik/python-nodejs:python3.12-nodejs22 for eventstream runtime
runtime_container_image: str | None = None
user_id: int = os.getuid() if hasattr(os, 'getuid') else 1000
timeout: int = 120
remote_runtime_init_timeout: int = 180
enable_auto_lint: bool = (
False # once enabled, OpenHands would lint files after editing
remote_runtime_api_url: str = Field(default='http://localhost:8000')
local_runtime_url: str = Field(default='http://localhost')
keep_runtime_alive: bool = Field(default=False)
rm_all_containers: bool = Field(default=False)
api_key: str | None = Field(default=None)
base_container_image: str = Field(
default='nikolaik/python-nodejs:python3.12-nodejs22'
)
use_host_network: bool = False
runtime_extra_build_args: list[str] | None = None
initialize_plugins: bool = True
force_rebuild_runtime: bool = False
runtime_extra_deps: str | None = None
runtime_startup_env_vars: dict[str, str] = field(default_factory=dict)
browsergym_eval_env: str | None = None
platform: str | None = None
close_delay: int = 15
remote_runtime_resource_factor: int = 1
enable_gpu: bool = False
docker_runtime_kwargs: str | None = None
runtime_container_image: str | None = Field(default=None)
user_id: int = Field(default=os.getuid() if hasattr(os, 'getuid') else 1000)
timeout: int = Field(default=120)
remote_runtime_init_timeout: int = Field(default=180)
enable_auto_lint: bool = Field(
default=False # once enabled, OpenHands would lint files after editing
)
use_host_network: bool = Field(default=False)
runtime_extra_build_args: list[str] | None = Field(default=None)
initialize_plugins: bool = Field(default=True)
force_rebuild_runtime: bool = Field(default=False)
runtime_extra_deps: str | None = Field(default=None)
runtime_startup_env_vars: dict[str, str] = Field(default_factory=dict)
browsergym_eval_env: str | None = Field(default=None)
platform: str | None = Field(default=None)
close_delay: int = Field(default=900)
remote_runtime_resource_factor: int = Field(default=1)
enable_gpu: bool = Field(default=False)
docker_runtime_kwargs: str | None = Field(default=None)
def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
dict = {}
for f in fields(self):
dict[f.name] = get_field_info(f)
return dict
def __str__(self):
attr_str = []
for f in fields(self):
attr_name = f.name
attr_value = getattr(self, f.name)
attr_str.append(f'{attr_name}={repr(attr_value)}')
return f"SandboxConfig({', '.join(attr_str)})"
def __repr__(self):
return self.__str__()
model_config = {'extra': 'forbid'}

View File

@ -1,10 +1,7 @@
from dataclasses import dataclass, fields
from openhands.core.config.config_utils import get_field_info
from pydantic import BaseModel, Field
@dataclass
class SecurityConfig:
class SecurityConfig(BaseModel):
"""Configuration for security related functionalities.
Attributes:
@ -12,29 +9,5 @@ class SecurityConfig:
security_analyzer: The security analyzer to use.
"""
confirmation_mode: bool = False
security_analyzer: str | None = None
def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
dict = {}
for f in fields(self):
dict[f.name] = get_field_info(f)
return dict
def __str__(self):
attr_str = []
for f in fields(self):
attr_name = f.name
attr_value = getattr(self, f.name)
attr_str.append(f'{attr_name}={repr(attr_value)}')
return f"SecurityConfig({', '.join(attr_str)})"
@classmethod
def from_dict(cls, security_config_dict: dict) -> 'SecurityConfig':
return cls(**security_config_dict)
def __repr__(self):
return self.__str__()
confirmation_mode: bool = Field(default=False)
security_analyzer: str | None = Field(default=None)

View File

@ -3,13 +3,13 @@ import os
import pathlib
import platform
import sys
from dataclasses import is_dataclass
from types import UnionType
from typing import Any, MutableMapping, get_args, get_origin
from uuid import uuid4
import toml
from dotenv import load_dotenv
from pydantic import BaseModel, ValidationError
from openhands.core import logger
from openhands.core.config.agent_config import AgentConfig
@ -43,17 +43,19 @@ def load_from_env(cfg: AppConfig, env_or_toml_dict: dict | MutableMapping[str, s
return next((t for t in types if t is not type(None)), None)
# helper function to set attributes based on env vars
def set_attr_from_env(sub_config: Any, prefix=''):
"""Set attributes of a config dataclass based on environment variables."""
for field_name, field_type in sub_config.__annotations__.items():
def set_attr_from_env(sub_config: BaseModel, prefix=''):
"""Set attributes of a config model based on environment variables."""
for field_name, field_info in sub_config.model_fields.items():
field_value = getattr(sub_config, field_name)
field_type = field_info.annotation
# compute the expected env var name from the prefix and field name
# e.g. LLM_BASE_URL
env_var_name = (prefix + field_name).upper()
if is_dataclass(field_type):
# nested dataclass
nested_sub_config = getattr(sub_config, field_name)
set_attr_from_env(nested_sub_config, prefix=field_name + '_')
if isinstance(field_value, BaseModel):
set_attr_from_env(field_value, prefix=field_name + '_')
elif env_var_name in env_or_toml_dict:
# convert the env var to the correct type and set it
value = env_or_toml_dict[env_var_name]
@ -125,22 +127,40 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
if isinstance(value, dict):
try:
if key is not None and key.lower() == 'agent':
# Every entry here is either a field for the default `agent` config group, or itself a group
# The best way to tell the difference is to try to parse it as an AgentConfig object
agent_group_ids: set[str] = set()
for nested_key, nested_value in value.items():
if isinstance(nested_value, dict):
try:
agent_config = AgentConfig(**nested_value)
except ValidationError:
continue
agent_group_ids.add(nested_key)
cfg.set_agent_config(agent_config, nested_key)
logger.openhands_logger.debug(
'Attempt to load default agent config from config toml'
)
non_dict_fields = {
k: v for k, v in value.items() if not isinstance(v, dict)
value_without_groups = {
k: v for k, v in value.items() if k not in agent_group_ids
}
agent_config = AgentConfig(**non_dict_fields)
agent_config = AgentConfig(**value_without_groups)
cfg.set_agent_config(agent_config, 'agent')
elif key is not None and key.lower() == 'llm':
# Every entry here is either a field for the default `llm` config group, or itself a group
# The best way to tell the difference is to try to parse it as an LLMConfig object
llm_group_ids: set[str] = set()
for nested_key, nested_value in value.items():
if isinstance(nested_value, dict):
logger.openhands_logger.debug(
f'Attempt to load group {nested_key} from config toml as agent config'
)
agent_config = AgentConfig(**nested_value)
cfg.set_agent_config(agent_config, nested_key)
elif key is not None and key.lower() == 'llm':
try:
llm_config = LLMConfig(**nested_value)
except ValidationError:
continue
llm_group_ids.add(nested_key)
cfg.set_llm_config(llm_config, nested_key)
logger.openhands_logger.debug(
'Attempt to load default LLM config from config toml'
)
@ -150,7 +170,7 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
for k, v in value.items():
if not isinstance(v, dict):
generic_llm_fields[k] = v
generic_llm_config = LLMConfig.from_dict(generic_llm_fields)
generic_llm_config = LLMConfig(**generic_llm_fields)
cfg.set_llm_config(generic_llm_config, 'llm')
# Process custom named LLM configs
@ -170,22 +190,23 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
for k, v in nested_value.items():
if not isinstance(v, dict):
custom_fields[k] = v
merged_llm_dict = generic_llm_config.__dict__.copy()
merged_llm_dict = generic_llm_fields.copy()
merged_llm_dict.update(custom_fields)
custom_llm_config = LLMConfig.from_dict(merged_llm_dict)
custom_llm_config = LLMConfig(**merged_llm_dict)
cfg.set_llm_config(custom_llm_config, nested_key)
elif key is not None and key.lower() == 'security':
logger.openhands_logger.debug(
'Attempt to load security config from config toml'
)
security_config = SecurityConfig.from_dict(value)
security_config = SecurityConfig(**value)
cfg.security = security_config
elif not key.startswith('sandbox') and key.lower() != 'core':
logger.openhands_logger.warning(
f'Unknown key in {toml_file}: "{key}"'
)
except (TypeError, KeyError) as e:
except (TypeError, KeyError, ValidationError) as e:
logger.openhands_logger.warning(
f'Cannot parse [{key}] config from toml, values have not been applied.\nError: {e}',
)
@ -221,7 +242,7 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
logger.openhands_logger.warning(
f'Unknown config key "{key}" in [core] section'
)
except (TypeError, KeyError) as e:
except (TypeError, KeyError, ValidationError) as e:
logger.openhands_logger.warning(
f'Cannot parse [sandbox] config from toml, values have not been applied.\nError: {e}',
)
@ -324,7 +345,7 @@ def get_llm_config_arg(
# update the llm config with the specified section
if 'llm' in toml_config and llm_config_arg in toml_config['llm']:
return LLMConfig.from_dict(toml_config['llm'][llm_config_arg])
return LLMConfig(**toml_config['llm'][llm_config_arg])
logger.openhands_logger.debug(f'Loading from toml failed for {llm_config_arg}')
return None

View File

@ -23,7 +23,9 @@ class AsyncLLM(LLM):
self._async_completion = partial(
self._call_acompletion,
model=self.config.model,
api_key=self.config.api_key,
api_key=self.config.api_key.get_secret_value()
if self.config.api_key
else None,
base_url=self.config.base_url,
api_version=self.config.api_version,
custom_llm_provider=self.config.custom_llm_provider,

View File

@ -141,7 +141,9 @@ class LLM(RetryMixin, DebugMixin):
self._completion = partial(
litellm_completion,
model=self.config.model,
api_key=self.config.api_key,
api_key=self.config.api_key.get_secret_value()
if self.config.api_key
else None,
base_url=self.config.base_url,
api_version=self.config.api_version,
custom_llm_provider=self.config.custom_llm_provider,
@ -331,7 +333,9 @@ class LLM(RetryMixin, DebugMixin):
# GET {base_url}/v1/model/info with litellm_model_id as path param
response = requests.get(
f'{self.config.base_url}/v1/model/info',
headers={'Authorization': f'Bearer {self.config.api_key}'},
headers={
'Authorization': f'Bearer {self.config.api_key.get_secret_value() if self.config.api_key else None}'
},
)
resp_json = response.json()
if 'data' not in resp_json:

View File

@ -17,7 +17,9 @@ class StreamingLLM(AsyncLLM):
self._async_streaming_completion = partial(
self._call_acompletion,
model=self.config.model,
api_key=self.config.api_key,
api_key=self.config.api_key.get_secret_value()
if self.config.api_key
else None,
base_url=self.config.base_url,
api_version=self.config.api_version,
custom_llm_provider=self.config.custom_llm_provider,

View File

@ -59,7 +59,8 @@ class ModalRuntime(ActionExecutionClient):
self.sandbox = None
self.modal_client = modal.Client.from_credentials(
config.modal_api_token_id, config.modal_api_token_secret
config.modal_api_token_id.get_secret_value(),
config.modal_api_token_secret.get_secret_value(),
)
self.app = modal.App.lookup(
'openhands', create_if_missing=True, client=self.modal_client

View File

@ -40,7 +40,7 @@ class RunloopRuntime(ActionExecutionClient):
self.devbox: DevboxView | None = None
self.config = config
self.runloop_api_client = Runloop(
bearer_token=config.runloop_api_key,
bearer_token=config.runloop_api_key.get_secret_value(),
)
self.container_name = CONTAINER_NAME_PREFIX + sid
super().__init__(

View File

@ -51,8 +51,8 @@ async def get_litellm_models() -> list[str]:
):
bedrock_model_list = bedrock.list_foundation_models(
llm_config.aws_region_name,
llm_config.aws_access_key_id,
llm_config.aws_secret_access_key,
llm_config.aws_access_key_id.get_secret_value(),
llm_config.aws_secret_access_key.get_secret_value(),
)
model_list = litellm_model_list_without_bedrock + bedrock_model_list
for llm_config in config.llms.values():

View File

@ -30,9 +30,6 @@ async def load_settings(request: Request) -> Settings | None:
status_code=status.HTTP_404_NOT_FOUND,
content={'error': 'Settings not found'},
)
# For security reasons we don't ever send the api key to the client
settings.llm_api_key = 'SET' if settings.llm_api_key else None
return settings
except Exception as e:
logger.warning(f'Invalid token: {e}')

View File

@ -1,13 +1,12 @@
from dataclasses import dataclass
from pydantic import Field
from openhands.server.settings import Settings
@dataclass
class ConversationInitData(Settings):
"""
Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
"""
github_token: str | None = None
selected_repository: str | None = None
github_token: str | None = Field(default=None)
selected_repository: str | None = Field(default=None)

View File

@ -91,6 +91,9 @@ class Session:
)
max_iterations = settings.max_iterations or self.config.max_iterations
# This is a shallow copy of the default LLM config, so changes here will
# persist if we retrieve the default LLM config again when constructing
# the agent
default_llm_config = self.config.get_llm_config()
default_llm_config.model = settings.llm_model or ''
default_llm_config.api_key = settings.llm_api_key

View File

@ -1,8 +1,8 @@
from dataclasses import dataclass
from pydantic import BaseModel, SecretStr, SerializationInfo, field_serializer
from pydantic.json import pydantic_encoder
@dataclass
class Settings:
class Settings(BaseModel):
"""
Persisted settings for OpenHands sessions
"""
@ -13,6 +13,20 @@ class Settings:
security_analyzer: str | None = None
confirmation_mode: bool | None = None
llm_model: str | None = None
llm_api_key: str | None = None
llm_api_key: SecretStr | None = None
llm_base_url: str | None = None
remote_runtime_resource_factor: int | None = None
@field_serializer('llm_api_key')
def llm_api_key_serializer(self, llm_api_key: SecretStr, info: SerializationInfo):
"""Custom serializer for the LLM API key.
To serialize the API key instead of `"********"`, set `expose_secrets` to True in the serialization context. For example::
settings.model_dump_json(context={'expose_secrets': True})
"""
context = info.context
if context and context.get('expose_secrets', False):
return llm_api_key.get_secret_value()
return pydantic_encoder(llm_api_key)

View File

@ -26,7 +26,7 @@ class FileSettingsStore(SettingsStore):
return None
async def store(self, settings: Settings):
json_str = json.dumps(settings.__dict__)
json_str = settings.model_dump_json(context={'expose_secrets': True})
await call_sync_from_async(self.file_store.write, self.path, json_str)
@classmethod

View File

@ -90,7 +90,9 @@ class EmbeddingsLoader:
return OpenAIEmbedding(
model='text-embedding-ada-002',
api_key=llm_config.api_key,
api_key=llm_config.api_key.get_secret_value()
if llm_config.api_key
else None,
)
elif strategy == 'azureopenai':
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding

View File

@ -109,9 +109,6 @@ async def test_async_completion_with_user_cancellation(cancel_delay):
print(f'Cancel requested: {is_set}')
return is_set
config = load_app_config()
config.on_cancel_requested_fn = mock_on_cancel_requested
async def mock_acompletion(*args, **kwargs):
print('Starting mock_acompletion')
for i in range(20): # Increased iterations for longer running task
@ -153,13 +150,6 @@ async def test_async_completion_with_user_cancellation(cancel_delay):
async def test_async_streaming_completion_with_user_cancellation(cancel_after_chunks):
cancel_requested = False
async def mock_on_cancel_requested():
nonlocal cancel_requested
return cancel_requested
config = load_app_config()
config.on_cancel_requested_fn = mock_on_cancel_requested
test_messages = [
'This is ',
'a test ',

View File

@ -60,7 +60,6 @@ def mock_state() -> State:
def test_cmd_output_observation_message(agent: CodeActAgent):
agent.config.function_calling = False
obs = CmdOutputObservation(
command='echo hello',
content='Command output',
@ -82,7 +81,6 @@ def test_cmd_output_observation_message(agent: CodeActAgent):
def test_ipython_run_cell_observation_message(agent: CodeActAgent):
agent.config.function_calling = False
obs = IPythonRunCellObservation(
code='plt.plot()',
content='IPython output\n![image]()',
@ -105,7 +103,6 @@ def test_ipython_run_cell_observation_message(agent: CodeActAgent):
def test_agent_delegate_observation_message(agent: CodeActAgent):
agent.config.function_calling = False
obs = AgentDelegateObservation(
content='Content', outputs={'content': 'Delegated agent output'}
)
@ -122,7 +119,6 @@ def test_agent_delegate_observation_message(agent: CodeActAgent):
def test_error_observation_message(agent: CodeActAgent):
agent.config.function_calling = False
obs = ErrorObservation('Error message')
results = agent.get_observation_message(obs, tool_call_id_to_message={})
@ -145,7 +141,6 @@ def test_unknown_observation_message(agent: CodeActAgent):
def test_file_edit_observation_message(agent: CodeActAgent):
agent.config.function_calling = False
obs = FileEditObservation(
path='/test/file.txt',
prev_exist=True,
@ -167,7 +162,6 @@ def test_file_edit_observation_message(agent: CodeActAgent):
def test_file_read_observation_message(agent: CodeActAgent):
agent.config.function_calling = False
obs = FileReadObservation(
path='/test/file.txt',
content='File content',
@ -186,7 +180,6 @@ def test_file_read_observation_message(agent: CodeActAgent):
def test_browser_output_observation_message(agent: CodeActAgent):
agent.config.function_calling = False
obs = BrowserOutputObservation(
url='http://example.com',
trigger_by_action='browse',
@ -207,7 +200,6 @@ def test_browser_output_observation_message(agent: CodeActAgent):
def test_user_reject_observation_message(agent: CodeActAgent):
agent.config.function_calling = False
obs = UserRejectObservation('Action rejected')
results = agent.get_observation_message(obs, tool_call_id_to_message={})
@ -223,7 +215,6 @@ def test_user_reject_observation_message(agent: CodeActAgent):
def test_function_calling_observation_message(agent: CodeActAgent):
agent.config.function_calling = True
mock_response = {
'id': 'mock_id',
'total_calls_in_response': 1,

View File

@ -226,7 +226,7 @@ def test_llm_condenser_from_config():
assert isinstance(condenser, LLMSummarizingCondenser)
assert condenser.llm.config.model == 'gpt-4o'
assert condenser.llm.config.api_key == 'test_key'
assert condenser.llm.config.api_key.get_secret_value() == 'test_key'
def test_llm_condenser(mock_llm, mock_state):
@ -381,7 +381,7 @@ def test_llm_attention_condenser_from_config():
assert isinstance(condenser, LLMAttentionCondenser)
assert condenser.llm.config.model == 'gpt-4o'
assert condenser.llm.config.api_key == 'test_key'
assert condenser.llm.config.api_key.get_secret_value() == 'test_key'
assert condenser.max_size == 50
assert condenser.keep_first == 10

View File

@ -63,7 +63,7 @@ def test_compat_env_to_config(monkeypatch, setup_env):
assert config.workspace_base == '/repos/openhands/workspace'
assert isinstance(config.get_llm_config(), LLMConfig)
assert config.get_llm_config().api_key == 'sk-proj-rgMV0...'
assert config.get_llm_config().api_key.get_secret_value() == 'sk-proj-rgMV0...'
assert config.get_llm_config().model == 'gpt-4o'
assert isinstance(config.get_agent_config(), AgentConfig)
assert isinstance(config.get_agent_config().memory_max_threads, int)
@ -83,7 +83,7 @@ def test_load_from_old_style_env(monkeypatch, default_config):
load_from_env(default_config, os.environ)
assert default_config.get_llm_config().api_key == 'test-api-key'
assert default_config.get_llm_config().api_key.get_secret_value() == 'test-api-key'
assert default_config.get_agent_config().memory_enabled is True
assert default_config.default_agent == 'BrowsingAgent'
assert default_config.workspace_base == '/opt/files/workspace'
@ -126,7 +126,7 @@ default_agent = "TestAgent"
# default llm & agent configs
assert default_config.default_agent == 'TestAgent'
assert default_config.get_llm_config().model == 'test-model'
assert default_config.get_llm_config().api_key == 'toml-api-key'
assert default_config.get_llm_config().api_key.get_secret_value() == 'toml-api-key'
assert default_config.get_agent_config().memory_enabled is True
# undefined agent config inherits default ones
@ -291,7 +291,7 @@ sandbox_user_id = 1001
assert default_config.get_llm_config().model == 'test-model'
assert default_config.get_llm_config('llm').model == 'test-model'
assert default_config.get_llm_config_from_agent().model == 'test-model'
assert default_config.get_llm_config().api_key == 'env-api-key'
assert default_config.get_llm_config().api_key.get_secret_value() == 'env-api-key'
# after we set workspace_base to 'UNDEFINED' in the environment,
# workspace_base should be set to that
@ -336,7 +336,7 @@ user_id = 1001
assert default_config.workspace_mount_path is None
# before load_from_env, values are set to the values from the toml file
assert default_config.get_llm_config().api_key == 'toml-api-key'
assert default_config.get_llm_config().api_key.get_secret_value() == 'toml-api-key'
assert default_config.sandbox.timeout == 500
assert default_config.sandbox.user_id == 1001
@ -345,7 +345,7 @@ user_id = 1001
# values from env override values from toml
assert os.environ.get('LLM_MODEL') is None
assert default_config.get_llm_config().model == 'test-model'
assert default_config.get_llm_config().api_key == 'env-api-key'
assert default_config.get_llm_config().api_key.get_secret_value() == 'env-api-key'
assert default_config.sandbox.timeout == 1000
assert default_config.sandbox.user_id == 1002
@ -412,7 +412,7 @@ def test_security_config_from_dict():
# Test with all fields
config_dict = {'confirmation_mode': True, 'security_analyzer': 'some_analyzer'}
security_config = SecurityConfig.from_dict(config_dict)
security_config = SecurityConfig(**config_dict)
# Verify all fields are correctly set
assert security_config.confirmation_mode is True
@ -560,10 +560,7 @@ invalid_field_in_sandbox = "test"
assert 'Cannot parse [llm] config from toml' in log_content
assert 'values have not been applied' in log_content
# Error: LLMConfig.__init__() got an unexpected keyword argume
assert (
'Error: LLMConfig.__init__() got an unexpected keyword argume'
in log_content
)
assert 'Error: 1 validation error for LLMConfig' in log_content
assert 'invalid_field' in log_content
# invalid [sandbox] config
@ -635,12 +632,14 @@ def test_api_keys_repr_str():
aws_access_key_id='my_access_key',
aws_secret_access_key='my_secret_key',
)
assert "api_key='******'" in repr(llm_config)
assert "aws_access_key_id='******'" in repr(llm_config)
assert "aws_secret_access_key='******'" in repr(llm_config)
assert "api_key='******'" in str(llm_config)
assert "aws_access_key_id='******'" in str(llm_config)
assert "aws_secret_access_key='******'" in str(llm_config)
# Check that no secret keys are emitted in representations of the config object
assert 'my_api_key' not in repr(llm_config)
assert 'my_api_key' not in str(llm_config)
assert 'my_access_key' not in repr(llm_config)
assert 'my_access_key' not in str(llm_config)
assert 'my_secret_key' not in repr(llm_config)
assert 'my_secret_key' not in str(llm_config)
# Check that no other attrs in LLMConfig have 'key' or 'token' in their name
# This will fail when new attrs are added, and attract attention
@ -652,7 +651,7 @@ def test_api_keys_repr_str():
'output_cost_per_token',
'custom_tokenizer',
]
for attr_name in dir(LLMConfig):
for attr_name in LLMConfig.model_fields.keys():
if (
not attr_name.startswith('__')
and attr_name not in known_key_token_attrs_llm
@ -667,7 +666,7 @@ def test_api_keys_repr_str():
# Test AgentConfig
# No attrs in AgentConfig have 'key' or 'token' in their name
agent_config = AgentConfig(memory_enabled=True, memory_max_threads=4)
for attr_name in dir(AgentConfig):
for attr_name in AgentConfig.model_fields.keys():
if not attr_name.startswith('__'):
assert (
'key' not in attr_name.lower()
@ -686,16 +685,16 @@ def test_api_keys_repr_str():
modal_api_token_secret='my_modal_api_token_secret',
runloop_api_key='my_runloop_api_key',
)
assert "e2b_api_key='******'" in repr(app_config)
assert "e2b_api_key='******'" in str(app_config)
assert "jwt_secret='******'" in repr(app_config)
assert "jwt_secret='******'" in str(app_config)
assert "modal_api_token_id='******'" in repr(app_config)
assert "modal_api_token_id='******'" in str(app_config)
assert "modal_api_token_secret='******'" in repr(app_config)
assert "modal_api_token_secret='******'" in str(app_config)
assert "runloop_api_key='******'" in repr(app_config)
assert "runloop_api_key='******'" in str(app_config)
assert 'my_e2b_api_key' not in repr(app_config)
assert 'my_e2b_api_key' not in str(app_config)
assert 'my_jwt_secret' not in repr(app_config)
assert 'my_jwt_secret' not in str(app_config)
assert 'my_modal_api_token_id' not in repr(app_config)
assert 'my_modal_api_token_id' not in str(app_config)
assert 'my_modal_api_token_secret' not in repr(app_config)
assert 'my_modal_api_token_secret' not in str(app_config)
assert 'my_runloop_api_key' not in repr(app_config)
assert 'my_runloop_api_key' not in str(app_config)
# Check that no other attrs in AppConfig have 'key' or 'token' in their name
# This will fail when new attrs are added, and attract attention
@ -705,7 +704,7 @@ def test_api_keys_repr_str():
'modal_api_token_secret',
'runloop_api_key',
]
for attr_name in dir(AppConfig):
for attr_name in AppConfig.model_fields.keys():
if (
not attr_name.startswith('__')
and attr_name not in known_key_token_attrs_app

View File

@ -1,4 +1,3 @@
import json
from unittest.mock import MagicMock, patch
import pytest
@ -43,7 +42,7 @@ async def test_store_and_load_data(file_settings_store):
await file_settings_store.store(init_data)
# Verify store called with correct JSON
expected_json = json.dumps(init_data.__dict__)
expected_json = init_data.model_dump_json(context={'expose_secrets': True})
file_settings_store.file_store.write.assert_called_once_with(
'settings.json', expected_json
)
@ -60,7 +59,12 @@ async def test_store_and_load_data(file_settings_store):
assert loaded_data.security_analyzer == init_data.security_analyzer
assert loaded_data.confirmation_mode == init_data.confirmation_mode
assert loaded_data.llm_model == init_data.llm_model
assert loaded_data.llm_api_key == init_data.llm_api_key
assert loaded_data.llm_api_key
assert init_data.llm_api_key
assert (
loaded_data.llm_api_key.get_secret_value()
== init_data.llm_api_key.get_secret_value()
)
assert loaded_data.llm_base_url == init_data.llm_base_url

View File

@ -40,7 +40,7 @@ def default_config():
def test_llm_init_with_default_config(default_config):
llm = LLM(default_config)
assert llm.config.model == 'gpt-4o'
assert llm.config.api_key == 'test_key'
assert llm.config.api_key.get_secret_value() == 'test_key'
assert isinstance(llm.metrics, Metrics)
assert llm.metrics.model_name == 'gpt-4o'
@ -77,7 +77,7 @@ def test_llm_init_with_custom_config():
)
llm = LLM(custom_config)
assert llm.config.model == 'custom-model'
assert llm.config.api_key == 'custom_key'
assert llm.config.api_key.get_secret_value() == 'custom_key'
assert llm.config.max_input_tokens == 5000
assert llm.config.max_output_tokens == 1500
assert llm.config.temperature == 0.8

View File

@ -59,28 +59,28 @@ def test_load_from_toml_llm_with_fallback(
# Verify generic LLM configuration
generic_llm = default_config.get_llm_config('llm')
assert generic_llm.model == 'base-model'
assert generic_llm.api_key == 'base-api-key'
assert generic_llm.api_key.get_secret_value() == 'base-api-key'
assert generic_llm.embedding_model == 'base-embedding'
assert generic_llm.num_retries == 3
# Verify custom1 LLM falls back 'num_retries' from base
custom1 = default_config.get_llm_config('custom1')
assert custom1.model == 'custom-model-1'
assert custom1.api_key == 'custom-api-key-1'
assert custom1.api_key.get_secret_value() == 'custom-api-key-1'
assert custom1.embedding_model == 'base-embedding'
assert custom1.num_retries == 3 # from [llm]
# Verify custom2 LLM overrides 'num_retries'
custom2 = default_config.get_llm_config('custom2')
assert custom2.model == 'custom-model-2'
assert custom2.api_key == 'custom-api-key-2'
assert custom2.api_key.get_secret_value() == 'custom-api-key-2'
assert custom2.embedding_model == 'base-embedding'
assert custom2.num_retries == 5 # overridden value
# Verify custom3 LLM inherits all attributes except 'model' and 'api_key'
custom3 = default_config.get_llm_config('custom3')
assert custom3.model == 'custom-model-3'
assert custom3.api_key == 'custom-api-key-3'
assert custom3.api_key.get_secret_value() == 'custom-api-key-3'
assert custom3.embedding_model == 'base-embedding'
assert custom3.num_retries == 3 # from [llm]
@ -113,14 +113,14 @@ num_retries = 10
# Verify generic LLM configuration remains unchanged
generic_llm = default_config.get_llm_config('llm')
assert generic_llm.model == 'base-model'
assert generic_llm.api_key == 'base-api-key'
assert generic_llm.api_key.get_secret_value() == 'base-api-key'
assert generic_llm.embedding_model == 'base-embedding'
assert generic_llm.num_retries == 3
# Verify custom_full LLM overrides all attributes
custom_full = default_config.get_llm_config('custom_full')
assert custom_full.model == 'full-custom-model'
assert custom_full.api_key == 'full-custom-api-key'
assert custom_full.api_key.get_secret_value() == 'full-custom-api-key'
assert custom_full.embedding_model == 'full-custom-embedding'
assert custom_full.num_retries == 10 # overridden value
@ -136,14 +136,14 @@ def test_load_from_toml_llm_custom_partial_override(
# Verify custom1 LLM overrides 'model' and 'api_key' but inherits 'num_retries'
custom1 = default_config.get_llm_config('custom1')
assert custom1.model == 'custom-model-1'
assert custom1.api_key == 'custom-api-key-1'
assert custom1.api_key.get_secret_value() == 'custom-api-key-1'
assert custom1.embedding_model == 'base-embedding'
assert custom1.num_retries == 3 # from [llm]
# Verify custom2 LLM overrides 'model', 'api_key', and 'num_retries'
custom2 = default_config.get_llm_config('custom2')
assert custom2.model == 'custom-model-2'
assert custom2.api_key == 'custom-api-key-2'
assert custom2.api_key.get_secret_value() == 'custom-api-key-2'
assert custom2.embedding_model == 'base-embedding'
assert custom2.num_retries == 5 # Overridden value
@ -159,7 +159,7 @@ def test_load_from_toml_llm_custom_no_override(
# Verify custom3 LLM inherits 'embedding_model' and 'num_retries' from generic
custom3 = default_config.get_llm_config('custom3')
assert custom3.model == 'custom-model-3'
assert custom3.api_key == 'custom-api-key-3'
assert custom3.api_key.get_secret_value() == 'custom-api-key-3'
assert custom3.embedding_model == 'base-embedding'
assert custom3.num_retries == 3 # from [llm]
@ -186,7 +186,7 @@ api_key = "custom-only-api-key"
# Verify custom_only LLM uses its own attributes and defaults for others
custom_only = default_config.get_llm_config('custom_only')
assert custom_only.model == 'custom-only-model'
assert custom_only.api_key == 'custom-only-api-key'
assert custom_only.api_key.get_secret_value() == 'custom-only-api-key'
assert custom_only.embedding_model == 'local' # default value
assert custom_only.num_retries == 8 # default value
@ -217,12 +217,12 @@ unknown_attr = "should_not_exist"
# Verify generic LLM is loaded correctly
generic_llm = default_config.get_llm_config('llm')
assert generic_llm.model == 'base-model'
assert generic_llm.api_key == 'base-api-key'
assert generic_llm.api_key.get_secret_value() == 'base-api-key'
assert generic_llm.num_retries == 3
# Verify invalid_custom LLM does not override generic attributes
custom_invalid = default_config.get_llm_config('invalid_custom')
assert custom_invalid.model == 'base-model'
assert custom_invalid.api_key == 'base-api-key'
assert custom_invalid.api_key.get_secret_value() == 'base-api-key'
assert custom_invalid.num_retries == 3 # default value
assert custom_invalid.embedding_model == 'local' # default value

View File

@ -86,7 +86,7 @@ def test_draft_editor_as_named_llm(config_toml_with_draft_editor):
draft_llm = config.get_llm_config('draft_editor')
assert draft_llm is not None
assert draft_llm.model == 'draft-model'
assert draft_llm.api_key == 'draft-api-key'
assert draft_llm.api_key.get_secret_value() == 'draft-api-key'
def test_draft_editor_fallback(config_toml_with_draft_editor):

View File

@ -2,6 +2,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from pydantic import SecretStr
from openhands.core.config.sandbox_config import SandboxConfig
from openhands.server.app import app
@ -50,7 +51,7 @@ async def test_settings_api_runtime_factor(test_client, mock_settings_store):
'security_analyzer': 'default',
'confirmation_mode': True,
'llm_model': 'test-model',
'llm_api_key': None,
'llm_api_key': 'test-key',
'llm_base_url': 'https://test.com',
'remote_runtime_resource_factor': 2,
}
@ -83,3 +84,36 @@ async def test_settings_api_runtime_factor(test_client, mock_settings_store):
mock_settings_store.store.assert_called()
stored_settings = mock_settings_store.store.call_args[0][0]
assert stored_settings.remote_runtime_resource_factor == 2
assert isinstance(stored_settings.llm_api_key, SecretStr)
assert stored_settings.llm_api_key.get_secret_value() == 'test-key'
@pytest.mark.asyncio
async def test_settings_llm_api_key(test_client, mock_settings_store):
# Mock the settings store to return None initially (no existing settings)
mock_settings_store.load.return_value = None
# Test data with remote_runtime_resource_factor
settings_data = {'llm_api_key': 'test-key'}
# The test_client fixture already handles authentication
# Make the POST request to store settings
response = test_client.post('/api/settings', json=settings_data)
assert response.status_code == 200
# Verify the settings were stored with the correct secret API key
stored_settings = mock_settings_store.store.call_args[0][0]
assert isinstance(stored_settings.llm_api_key, SecretStr)
assert stored_settings.llm_api_key.get_secret_value() == 'test-key'
# Mock settings store to return our settings for the GET request
mock_settings_store.load.return_value = Settings(**settings_data)
# Make a GET request to retrieve settings
response = test_client.get('/api/settings')
assert response.status_code == 200
# We should never expose the API key in the response
assert 'test-key' not in response.json()