Pydantic-based configuration and setting objects (#6321)
Co-authored-by: Calvin Smith <calvin@all-hands.dev>
Co-authored-by: Graham Neubig <neubig@gmail.com>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
parent 899c1f8360, commit a12087243a
@@ -80,7 +80,7 @@ def load_dependencies(runtime: Runtime) -> List[str]:
 def init_task_env(runtime: Runtime, hostname: str, env_llm_config: LLMConfig):
     command = (
         f'SERVER_HOSTNAME={hostname} '
-        f'LITELLM_API_KEY={env_llm_config.api_key} '
+        f'LITELLM_API_KEY={env_llm_config.api_key.get_secret_value() if env_llm_config.api_key else None} '
         f'LITELLM_BASE_URL={env_llm_config.base_url} '
         f'LITELLM_MODEL={env_llm_config.model} '
         'bash /utils/init.sh'

@@ -165,7 +165,7 @@ def run_evaluator(
     runtime: Runtime, env_llm_config: LLMConfig, trajectory_path: str, result_path: str
 ):
     command = (
-        f'LITELLM_API_KEY={env_llm_config.api_key} '
+        f'LITELLM_API_KEY={env_llm_config.api_key.get_secret_value() if env_llm_config.api_key else None} '
         f'LITELLM_BASE_URL={env_llm_config.base_url} '
         f'LITELLM_MODEL={env_llm_config.model} '
         f"DECRYPTION_KEY='theagentcompany is all you need' "  # Hardcoded Key
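The two hunks above follow from `api_key` becoming a pydantic `SecretStr`: interpolating it directly into a shell command now embeds the masked placeholder, not the key. A minimal sketch of that behavior (`DemoConfig` and the key value are made up for illustration):

```python
from pydantic import BaseModel, SecretStr

class DemoConfig(BaseModel):
    api_key: SecretStr | None = None

cfg = DemoConfig(api_key='sk-demo')  # 'sk-demo' is a hypothetical key
print(f'LITELLM_API_KEY={cfg.api_key}')
# LITELLM_API_KEY=**********   <- the masked SecretStr, not the key
print(f'LITELLM_API_KEY={cfg.api_key.get_secret_value() if cfg.api_key else None}')
# LITELLM_API_KEY=sk-demo
```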
@@ -52,30 +52,6 @@ class EvalMetadata(BaseModel):
     details: dict[str, Any] | None = None
     condenser_config: CondenserConfig | None = None
 
-    def model_dump(self, *args, **kwargs):
-        dumped_dict = super().model_dump(*args, **kwargs)
-        # avoid leaking sensitive information
-        dumped_dict['llm_config'] = self.llm_config.to_safe_dict()
-        if hasattr(self.condenser_config, 'llm_config'):
-            dumped_dict['condenser_config']['llm_config'] = (
-                self.condenser_config.llm_config.to_safe_dict()
-            )
-
-        return dumped_dict
-
-    def model_dump_json(self, *args, **kwargs):
-        dumped = super().model_dump_json(*args, **kwargs)
-        dumped_dict = json.loads(dumped)
-        # avoid leaking sensitive information
-        dumped_dict['llm_config'] = self.llm_config.to_safe_dict()
-        if hasattr(self.condenser_config, 'llm_config'):
-            dumped_dict['condenser_config']['llm_config'] = (
-                self.condenser_config.llm_config.to_safe_dict()
-            )
-
-        logger.debug(f'Dumped metadata: {dumped_dict}')
-        return json.dumps(dumped_dict)
-
 
 class EvalOutput(BaseModel):
     # NOTE: User-specified

@@ -98,23 +74,6 @@ class EvalOutput(BaseModel):
     # Optionally save the input test instance
     instance: dict[str, Any] | None = None
 
-    def model_dump(self, *args, **kwargs):
-        dumped_dict = super().model_dump(*args, **kwargs)
-        # Remove None values
-        dumped_dict = {k: v for k, v in dumped_dict.items() if v is not None}
-        # Apply custom serialization for metadata (to avoid leaking sensitive information)
-        if self.metadata is not None:
-            dumped_dict['metadata'] = self.metadata.model_dump()
-        return dumped_dict
-
-    def model_dump_json(self, *args, **kwargs):
-        dumped = super().model_dump_json(*args, **kwargs)
-        dumped_dict = json.loads(dumped)
-        # Apply custom serialization for metadata (to avoid leaking sensitive information)
-        if 'metadata' in dumped_dict:
-            dumped_dict['metadata'] = json.loads(self.metadata.model_dump_json())
-        return json.dumps(dumped_dict)
-
 
 class EvalException(Exception):
     pass

@@ -314,7 +273,7 @@ def update_progress(
     logger.info(
         f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
    )
-    output_fp.write(json.dumps(result.model_dump()) + '\n')
+    output_fp.write(result.model_dump_json() + '\n')
     output_fp.flush()
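The hand-rolled `model_dump`/`model_dump_json` overrides deleted above existed only to scrub secrets out of nested configs. Once the sensitive fields are `SecretStr`, pydantic's stock serializers mask them on their own, including through nesting. A minimal sketch (both models here are hypothetical stand-ins):

```python
from pydantic import BaseModel, SecretStr

class DemoLLMConfig(BaseModel):
    model: str = 'gpt-4o'
    api_key: SecretStr | None = None

class DemoMetadata(BaseModel):
    llm_config: DemoLLMConfig

meta = DemoMetadata(llm_config=DemoLLMConfig(api_key='sk-demo'))
# nested secrets are masked without any custom model_dump override
print(meta.model_dump_json())
# {"llm_config":{"model":"gpt-4o","api_key":"**********"}}
```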
@@ -37,21 +37,17 @@ export SANDBOX_TIMEOUT='300'
 
 ## Type Handling
 
-The `load_from_env` function attempts to cast environment variable values to the types specified in the dataclasses. It handles:
+The `load_from_env` function attempts to cast environment variable values to the types specified in the models. It handles:
 
 - Basic types (str, int, bool)
 - Optional types (e.g., `str | None`)
-- Nested dataclasses
+- Nested models
 
 If type casting fails, an error is logged, and the default value is retained.
 
 ## Default Values
 
-If an environment variable is not set, the default value specified in the dataclass is used.
-
-## Nested Configurations
-
-The `AppConfig` class contains nested configurations like `LLMConfig` and `AgentConfig`. The `load_from_env` function handles these by recursively processing nested dataclasses with updated prefixes.
+If an environment variable is not set, the default value specified in the model is used.
 
 ## Security Considerations
@@ -1,11 +1,9 @@
-from dataclasses import dataclass, field, fields
+from pydantic import BaseModel, Field
 
 from openhands.core.config.condenser_config import CondenserConfig, NoOpCondenserConfig
-from openhands.core.config.config_utils import get_field_info
 
 
-@dataclass
-class AgentConfig:
+class AgentConfig(BaseModel):
     """Configuration for the agent.
 
     Attributes:

@@ -22,20 +20,13 @@ class AgentConfig:
         condenser: Configuration for the memory condenser. Default is NoOpCondenserConfig.
     """
 
-    codeact_enable_browsing: bool = True
-    codeact_enable_llm_editor: bool = False
-    codeact_enable_jupyter: bool = True
-    micro_agent_name: str | None = None
-    memory_enabled: bool = False
-    memory_max_threads: int = 3
-    llm_config: str | None = None
-    enable_prompt_extensions: bool = True
-    disabled_microagents: list[str] | None = None
-    condenser: CondenserConfig = field(default_factory=NoOpCondenserConfig)  # type: ignore
-
-    def defaults_to_dict(self) -> dict:
-        """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
-        result = {}
-        for f in fields(self):
-            result[f.name] = get_field_info(f)
-        return result
+    codeact_enable_browsing: bool = Field(default=True)
+    codeact_enable_llm_editor: bool = Field(default=False)
+    codeact_enable_jupyter: bool = Field(default=True)
+    micro_agent_name: str | None = Field(default=None)
+    memory_enabled: bool = Field(default=False)
+    memory_max_threads: int = Field(default=3)
+    llm_config: str | None = Field(default=None)
+    enable_prompt_extensions: bool = Field(default=False)
+    disabled_microagents: list[str] | None = Field(default=None)
+    condenser: CondenserConfig = Field(default_factory=NoOpCondenserConfig)
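Moving from `@dataclass` to `BaseModel` buys validation and type coercion at construction time, which the old dataclasses never had. A rough illustration with a pared-down model (not the real `AgentConfig`):

```python
from pydantic import BaseModel, Field, ValidationError

class DemoAgentConfig(BaseModel):
    memory_enabled: bool = Field(default=False)
    memory_max_threads: int = Field(default=3)

# pydantic coerces compatible inputs and rejects bad ones at construction time
print(DemoAgentConfig(memory_max_threads='5').memory_max_threads)  # 5 (an int)
try:
    DemoAgentConfig(memory_max_threads='not-a-number')
except ValidationError as e:
    print(e.error_count(), 'validation error')  # 1 validation error
```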
@@ -1,20 +1,20 @@
-from dataclasses import dataclass, field, fields, is_dataclass
 from typing import ClassVar
 
+from pydantic import BaseModel, Field, SecretStr
+
 from openhands.core import logger
 from openhands.core.config.agent_config import AgentConfig
 from openhands.core.config.config_utils import (
     OH_DEFAULT_AGENT,
     OH_MAX_ITERATIONS,
-    get_field_info,
+    model_defaults_to_dict,
 )
 from openhands.core.config.llm_config import LLMConfig
 from openhands.core.config.sandbox_config import SandboxConfig
 from openhands.core.config.security_config import SecurityConfig
 
 
-@dataclass
-class AppConfig:
+class AppConfig(BaseModel):
     """Configuration for the app.
 
     Attributes:

@@ -46,37 +46,39 @@ class AppConfig:
         input is read line by line. When enabled, input continues until /exit command.
     """
 
-    llms: dict[str, LLMConfig] = field(default_factory=dict)
-    agents: dict = field(default_factory=dict)
-    default_agent: str = OH_DEFAULT_AGENT
-    sandbox: SandboxConfig = field(default_factory=SandboxConfig)
-    security: SecurityConfig = field(default_factory=SecurityConfig)
-    runtime: str = 'docker'
-    file_store: str = 'local'
-    file_store_path: str = '/tmp/openhands_file_store'
-    save_trajectory_path: str | None = None
-    workspace_base: str | None = None
-    workspace_mount_path: str | None = None
-    workspace_mount_path_in_sandbox: str = '/workspace'
-    workspace_mount_rewrite: str | None = None
-    cache_dir: str = '/tmp/cache'
-    run_as_openhands: bool = True
-    max_iterations: int = OH_MAX_ITERATIONS
-    max_budget_per_task: float | None = None
-    e2b_api_key: str = ''
-    modal_api_token_id: str = ''
-    modal_api_token_secret: str = ''
-    disable_color: bool = False
-    jwt_secret: str = ''
-    debug: bool = False
-    file_uploads_max_file_size_mb: int = 0
-    file_uploads_restrict_file_types: bool = False
-    file_uploads_allowed_extensions: list[str] = field(default_factory=lambda: ['.*'])
-    runloop_api_key: str | None = None
-    cli_multiline_input: bool = False
+    llms: dict[str, LLMConfig] = Field(default_factory=dict)
+    agents: dict = Field(default_factory=dict)
+    default_agent: str = Field(default=OH_DEFAULT_AGENT)
+    sandbox: SandboxConfig = Field(default_factory=SandboxConfig)
+    security: SecurityConfig = Field(default_factory=SecurityConfig)
+    runtime: str = Field(default='docker')
+    file_store: str = Field(default='local')
+    file_store_path: str = Field(default='/tmp/openhands_file_store')
+    save_trajectory_path: str | None = Field(default=None)
+    workspace_base: str | None = Field(default=None)
+    workspace_mount_path: str | None = Field(default=None)
+    workspace_mount_path_in_sandbox: str = Field(default='/workspace')
+    workspace_mount_rewrite: str | None = Field(default=None)
+    cache_dir: str = Field(default='/tmp/cache')
+    run_as_openhands: bool = Field(default=True)
+    max_iterations: int = Field(default=OH_MAX_ITERATIONS)
+    max_budget_per_task: float | None = Field(default=None)
+    e2b_api_key: SecretStr | None = Field(default=None)
+    modal_api_token_id: SecretStr | None = Field(default=None)
+    modal_api_token_secret: SecretStr | None = Field(default=None)
+    disable_color: bool = Field(default=False)
+    jwt_secret: SecretStr | None = Field(default=None)
+    debug: bool = Field(default=False)
+    file_uploads_max_file_size_mb: int = Field(default=0)
+    file_uploads_restrict_file_types: bool = Field(default=False)
+    file_uploads_allowed_extensions: list[str] = Field(default_factory=lambda: ['.*'])
+    runloop_api_key: SecretStr | None = Field(default=None)
+    cli_multiline_input: bool = Field(default=False)
 
     defaults_dict: ClassVar[dict] = {}
 
+    model_config = {'extra': 'forbid'}
+
     def get_llm_config(self, name='llm') -> LLMConfig:
         """'llm' is the name for default config (for backward compatibility prior to 0.8)."""
         if name in self.llms:

@@ -115,42 +117,7 @@ class AppConfig:
     def get_agent_configs(self) -> dict[str, AgentConfig]:
         return self.agents
 
-    def __post_init__(self):
+    def model_post_init(self, __context):
         """Post-initialization hook, called when the instance is created with only default values."""
-        AppConfig.defaults_dict = self.defaults_to_dict()
-
-    def defaults_to_dict(self) -> dict:
-        """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
-        result = {}
-        for f in fields(self):
-            field_value = getattr(self, f.name)
-
-            # dataclasses compute their defaults themselves
-            if is_dataclass(type(field_value)):
-                result[f.name] = field_value.defaults_to_dict()
-            else:
-                result[f.name] = get_field_info(f)
-        return result
-
-    def __str__(self):
-        attr_str = []
-        for f in fields(self):
-            attr_name = f.name
-            attr_value = getattr(self, f.name)
-
-            if attr_name in [
-                'e2b_api_key',
-                'github_token',
-                'jwt_secret',
-                'modal_api_token_id',
-                'modal_api_token_secret',
-                'runloop_api_key',
-            ]:
-                attr_value = '******' if attr_value else None
-
-            attr_str.append(f'{attr_name}={repr(attr_value)}')
-
-        return f"AppConfig({', '.join(attr_str)}"
-
-    def __repr__(self):
-        return self.__str__()
+        super().model_post_init(__context)
+        AppConfig.defaults_dict = model_defaults_to_dict(self)
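With the secret fields typed as `SecretStr`, the hand-written `__str__`/`__repr__` masking that `AppConfig` used to carry becomes unnecessary: pydantic's default representations never print the raw value. A sketch with a cut-down, hypothetical model:

```python
from pydantic import BaseModel, Field, SecretStr

class DemoAppConfig(BaseModel):
    jwt_secret: SecretStr | None = Field(default=None)
    debug: bool = Field(default=False)

cfg = DemoAppConfig(jwt_secret='my_jwt_secret')
print(repr(cfg))  # DemoAppConfig(jwt_secret=SecretStr('**********'), debug=False)
assert 'my_jwt_secret' not in repr(cfg)
assert 'my_jwt_secret' not in str(cfg)
```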
@@ -1,19 +1,22 @@
 from types import UnionType
-from typing import get_args, get_origin
+from typing import Any, get_args, get_origin
+
+from pydantic import BaseModel
+from pydantic.fields import FieldInfo
 
 OH_DEFAULT_AGENT = 'CodeActAgent'
 OH_MAX_ITERATIONS = 500
 
 
-def get_field_info(f):
+def get_field_info(field: FieldInfo) -> dict[str, Any]:
     """Extract information about a dataclass field: type, optional, and default.
 
     Args:
-        f: The field to extract information from.
+        field: The field to extract information from.
 
     Returns: A dict with the field's type, whether it's optional, and its default value.
     """
-    field_type = f.type
+    field_type = field.annotation
     optional = False
 
     # for types like str | None, find the non-None type and set optional to True

@@ -33,7 +36,21 @@ def get_field_info(f):
     )
 
     # default is always present
-    default = f.default
+    default = field.default
 
     # return a schema with the useful info for frontend
     return {'type': type_name.lower(), 'optional': optional, 'default': default}
 
 
+def model_defaults_to_dict(model: BaseModel) -> dict[str, Any]:
+    """Serialize field information in a dict for the frontend, including type hints, defaults, and whether it's optional."""
+    result = {}
+    for name, field in model.model_fields.items():
+        field_value = getattr(model, name)
+
+        if isinstance(field_value, BaseModel):
+            result[name] = model_defaults_to_dict(field_value)
+        else:
+            result[name] = get_field_info(field)
+
+    return result
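How `model_defaults_to_dict` composes with `get_field_info`: recurse into nested model instances, describe leaf fields from their `FieldInfo`. A self-contained approximation of the same traversal (the `demo_*` helpers and models stand in for the real functions and only report a subset of the schema):

```python
from typing import Any
from pydantic import BaseModel, Field
from pydantic.fields import FieldInfo

def demo_field_info(field: FieldInfo) -> dict[str, Any]:
    # simplified stand-in for get_field_info: type name and default only
    return {'type': getattr(field.annotation, '__name__', str(field.annotation)),
            'default': field.default}

def demo_defaults_to_dict(model: BaseModel) -> dict[str, Any]:
    # same shape as model_defaults_to_dict above: recurse into nested models
    result: dict[str, Any] = {}
    for name, field in type(model).model_fields.items():
        value = getattr(model, name)
        result[name] = (demo_defaults_to_dict(value)
                        if isinstance(value, BaseModel)
                        else demo_field_info(field))
    return result

class DemoLLM(BaseModel):
    model: str = Field(default='gpt-4o')

class DemoApp(BaseModel):
    debug: bool = Field(default=False)
    llm: DemoLLM = Field(default_factory=DemoLLM)

print(demo_defaults_to_dict(DemoApp()))
# {'debug': {'type': 'bool', 'default': False},
#  'llm': {'model': {'type': 'str', 'default': 'gpt-4o'}}}
```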
@@ -1,14 +1,14 @@
+from __future__ import annotations
+
 import os
-from dataclasses import dataclass, fields
 
-from openhands.core.config.config_utils import get_field_info
+from typing import Any
+
+from pydantic import BaseModel, Field, SecretStr
+
 from openhands.core.logger import LOG_DIR
 
 LLM_SENSITIVE_FIELDS = ['api_key', 'aws_access_key_id', 'aws_secret_access_key']
 
 
-@dataclass
-class LLMConfig:
+class LLMConfig(BaseModel):
     """Configuration for the LLM model.
 
     Attributes:

@@ -47,99 +47,56 @@ class LLMConfig:
         native_tool_calling: Whether to use native tool calling if supported by the model. Can be True, False, or not set.
     """
 
-    model: str = 'claude-3-5-sonnet-20241022'
-    api_key: str | None = None
-    base_url: str | None = None
-    api_version: str | None = None
-    embedding_model: str = 'local'
-    embedding_base_url: str | None = None
-    embedding_deployment_name: str | None = None
-    aws_access_key_id: str | None = None
-    aws_secret_access_key: str | None = None
-    aws_region_name: str | None = None
-    openrouter_site_url: str = 'https://docs.all-hands.dev/'
-    openrouter_app_name: str = 'OpenHands'
-    num_retries: int = 8
-    retry_multiplier: float = 2
-    retry_min_wait: int = 15
-    retry_max_wait: int = 120
-    timeout: int | None = None
-    max_message_chars: int = 30_000  # maximum number of characters in an observation's content when sent to the llm
-    temperature: float = 0.0
-    top_p: float = 1.0
-    custom_llm_provider: str | None = None
-    max_input_tokens: int | None = None
-    max_output_tokens: int | None = None
-    input_cost_per_token: float | None = None
-    output_cost_per_token: float | None = None
-    ollama_base_url: str | None = None
+    model: str = Field(default='claude-3-5-sonnet-20241022')
+    api_key: SecretStr | None = Field(default=None)
+    base_url: str | None = Field(default=None)
+    api_version: str | None = Field(default=None)
+    embedding_model: str = Field(default='local')
+    embedding_base_url: str | None = Field(default=None)
+    embedding_deployment_name: str | None = Field(default=None)
+    aws_access_key_id: SecretStr | None = Field(default=None)
+    aws_secret_access_key: SecretStr | None = Field(default=None)
+    aws_region_name: str | None = Field(default=None)
+    openrouter_site_url: str = Field(default='https://docs.all-hands.dev/')
+    openrouter_app_name: str = Field(default='OpenHands')
+    num_retries: int = Field(default=8)
+    retry_multiplier: float = Field(default=2)
+    retry_min_wait: int = Field(default=15)
+    retry_max_wait: int = Field(default=120)
+    timeout: int | None = Field(default=None)
+    max_message_chars: int = Field(
+        default=30_000
+    )  # maximum number of characters in an observation's content when sent to the llm
+    temperature: float = Field(default=0.0)
+    top_p: float = Field(default=1.0)
+    custom_llm_provider: str | None = Field(default=None)
+    max_input_tokens: int | None = Field(default=None)
+    max_output_tokens: int | None = Field(default=None)
+    input_cost_per_token: float | None = Field(default=None)
+    output_cost_per_token: float | None = Field(default=None)
+    ollama_base_url: str | None = Field(default=None)
     # This setting can be sent in each call to litellm
-    drop_params: bool = True
+    drop_params: bool = Field(default=True)
     # Note: this setting is actually global, unlike drop_params
-    modify_params: bool = True
-    disable_vision: bool | None = None
-    reasoning_effort: str | None = None
-    caching_prompt: bool = True
-    log_completions: bool = False
-    log_completions_folder: str = os.path.join(LOG_DIR, 'completions')
-    custom_tokenizer: str | None = None
-    native_tool_calling: bool | None = None
+    modify_params: bool = Field(default=True)
+    disable_vision: bool | None = Field(default=None)
+    caching_prompt: bool = Field(default=True)
+    log_completions: bool = Field(default=False)
+    log_completions_folder: str = Field(default=os.path.join(LOG_DIR, 'completions'))
+    custom_tokenizer: str | None = Field(default=None)
+    native_tool_calling: bool | None = Field(default=None)
 
+    model_config = {'extra': 'forbid'}
+
-    def defaults_to_dict(self) -> dict:
-        """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
-        result = {}
-        for f in fields(self):
-            result[f.name] = get_field_info(f)
-        return result
-
-    def __post_init__(self):
-        """
-        Post-initialization hook to assign OpenRouter-related variables to environment variables.
-        This ensures that these values are accessible to litellm at runtime.
-        """
+    def model_post_init(self, __context: Any):
+        """Post-initialization hook to assign OpenRouter-related variables to environment variables.
+
+        This ensures that these values are accessible to litellm at runtime.
+        """
+        super().model_post_init(__context)
 
         # Assign OpenRouter-specific variables to environment variables
         if self.openrouter_site_url:
            os.environ['OR_SITE_URL'] = self.openrouter_site_url
         if self.openrouter_app_name:
            os.environ['OR_APP_NAME'] = self.openrouter_app_name
-
-    def __str__(self):
-        attr_str = []
-        for f in fields(self):
-            attr_name = f.name
-            attr_value = getattr(self, f.name)
-
-            if attr_name in LLM_SENSITIVE_FIELDS:
-                attr_value = '******' if attr_value else None
-
-            attr_str.append(f'{attr_name}={repr(attr_value)}')
-
-        return f"LLMConfig({', '.join(attr_str)})"
-
-    def __repr__(self):
-        return self.__str__()
-
-    def to_safe_dict(self):
-        """Return a dict with the sensitive fields replaced with ******."""
-        ret = self.__dict__.copy()
-        for k, v in ret.items():
-            if k in LLM_SENSITIVE_FIELDS:
-                ret[k] = '******' if v else None
-            elif isinstance(v, LLMConfig):
-                ret[k] = v.to_safe_dict()
-        return ret
-
-    @classmethod
-    def from_dict(cls, llm_config_dict: dict) -> 'LLMConfig':
-        """Create an LLMConfig object from a dictionary.
-
-        This function is used to create an LLMConfig object from a dictionary.
-        """
-        # Keep None values to preserve defaults, filter out other dicts
-        args = {
-            k: v
-            for k, v in llm_config_dict.items()
-            if not isinstance(v, dict) or v is None
-        }
-        return cls(**args)
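`model_post_init` is pydantic's replacement for the dataclass-era `__post_init__` hook; it runs after validation on every construction. A runnable sketch of the OpenRouter env-var assignment shown above, using a cut-down stand-in model rather than the real `LLMConfig`:

```python
import os
from typing import Any
from pydantic import BaseModel, Field

class DemoLLMConfig(BaseModel):
    openrouter_site_url: str = Field(default='https://docs.all-hands.dev/')
    openrouter_app_name: str = Field(default='OpenHands')

    def model_post_init(self, __context: Any) -> None:
        # runs after validation, replacing the dataclass __post_init__ hook
        super().model_post_init(__context)
        if self.openrouter_site_url:
            os.environ['OR_SITE_URL'] = self.openrouter_site_url
        if self.openrouter_app_name:
            os.environ['OR_APP_NAME'] = self.openrouter_app_name

DemoLLMConfig()
assert os.environ['OR_APP_NAME'] == 'OpenHands'
```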
@@ -1,11 +1,9 @@
 import os
-from dataclasses import dataclass, field, fields
 
-from openhands.core.config.config_utils import get_field_info
+from pydantic import BaseModel, Field
 
 
-@dataclass
-class SandboxConfig:
+class SandboxConfig(BaseModel):
     """Configuration for the sandbox.
 
     Attributes:

@@ -39,48 +37,32 @@ class SandboxConfig:
         This should be a JSON string that will be parsed into a dictionary.
     """
 
-    remote_runtime_api_url: str = 'http://localhost:8000'
-    local_runtime_url: str = 'http://localhost'
-    keep_runtime_alive: bool = False
-    rm_all_containers: bool = False
-    api_key: str | None = None
-    base_container_image: str = 'nikolaik/python-nodejs:python3.12-nodejs22'  # default to nikolaik/python-nodejs:python3.12-nodejs22 for eventstream runtime
-    runtime_container_image: str | None = None
-    user_id: int = os.getuid() if hasattr(os, 'getuid') else 1000
-    timeout: int = 120
-    remote_runtime_init_timeout: int = 180
-    enable_auto_lint: bool = (
-        False  # once enabled, OpenHands would lint files after editing
-    )
-    use_host_network: bool = False
-    runtime_extra_build_args: list[str] | None = None
-    initialize_plugins: bool = True
-    force_rebuild_runtime: bool = False
-    runtime_extra_deps: str | None = None
-    runtime_startup_env_vars: dict[str, str] = field(default_factory=dict)
-    browsergym_eval_env: str | None = None
-    platform: str | None = None
-    close_delay: int = 15
-    remote_runtime_resource_factor: int = 1
-    enable_gpu: bool = False
-    docker_runtime_kwargs: str | None = None
+    remote_runtime_api_url: str = Field(default='http://localhost:8000')
+    local_runtime_url: str = Field(default='http://localhost')
+    keep_runtime_alive: bool = Field(default=False)
+    rm_all_containers: bool = Field(default=False)
+    api_key: str | None = Field(default=None)
+    base_container_image: str = Field(
+        default='nikolaik/python-nodejs:python3.12-nodejs22'
+    )
+    runtime_container_image: str | None = Field(default=None)
+    user_id: int = Field(default=os.getuid() if hasattr(os, 'getuid') else 1000)
+    timeout: int = Field(default=120)
+    remote_runtime_init_timeout: int = Field(default=180)
+    enable_auto_lint: bool = Field(
+        default=False  # once enabled, OpenHands would lint files after editing
+    )
+    use_host_network: bool = Field(default=False)
+    runtime_extra_build_args: list[str] | None = Field(default=None)
+    initialize_plugins: bool = Field(default=True)
+    force_rebuild_runtime: bool = Field(default=False)
+    runtime_extra_deps: str | None = Field(default=None)
+    runtime_startup_env_vars: dict[str, str] = Field(default_factory=dict)
+    browsergym_eval_env: str | None = Field(default=None)
+    platform: str | None = Field(default=None)
+    close_delay: int = Field(default=900)
+    remote_runtime_resource_factor: int = Field(default=1)
+    enable_gpu: bool = Field(default=False)
+    docker_runtime_kwargs: str | None = Field(default=None)
 
-    def defaults_to_dict(self) -> dict:
-        """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
-        dict = {}
-        for f in fields(self):
-            dict[f.name] = get_field_info(f)
-        return dict
-
-    def __str__(self):
-        attr_str = []
-        for f in fields(self):
-            attr_name = f.name
-            attr_value = getattr(self, f.name)
-
-            attr_str.append(f'{attr_name}={repr(attr_value)}')
-
-        return f"SandboxConfig({', '.join(attr_str)})"
-
-    def __repr__(self):
-        return self.__str__()
+    model_config = {'extra': 'forbid'}
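`model_config = {'extra': 'forbid'}` is what turns unknown keys into `ValidationError`s instead of silently ignored input; the reworked toml-parsing tests later in this diff rely on exactly that. Sketch with a hypothetical model:

```python
from pydantic import BaseModel, Field, ValidationError

class DemoSandboxConfig(BaseModel):
    timeout: int = Field(default=120)
    model_config = {'extra': 'forbid'}

try:
    DemoSandboxConfig(timeout=60, invalid_field_in_sandbox='test')
except ValidationError as e:
    # "1 validation error for DemoSandboxConfig ... Extra inputs are not permitted"
    print(e)
```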
@@ -1,10 +1,7 @@
-from dataclasses import dataclass, fields
-
-from openhands.core.config.config_utils import get_field_info
+from pydantic import BaseModel, Field
 
 
-@dataclass
-class SecurityConfig:
+class SecurityConfig(BaseModel):
     """Configuration for security related functionalities.
 
     Attributes:

@@ -12,29 +9,5 @@ class SecurityConfig:
         security_analyzer: The security analyzer to use.
     """
 
-    confirmation_mode: bool = False
-    security_analyzer: str | None = None
-
-    def defaults_to_dict(self) -> dict:
-        """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
-        dict = {}
-        for f in fields(self):
-            dict[f.name] = get_field_info(f)
-        return dict
-
-    def __str__(self):
-        attr_str = []
-        for f in fields(self):
-            attr_name = f.name
-            attr_value = getattr(self, f.name)
-
-            attr_str.append(f'{attr_name}={repr(attr_value)}')
-
-        return f"SecurityConfig({', '.join(attr_str)})"
-
-    @classmethod
-    def from_dict(cls, security_config_dict: dict) -> 'SecurityConfig':
-        return cls(**security_config_dict)
-
-    def __repr__(self):
-        return self.__str__()
+    confirmation_mode: bool = Field(default=False)
+    security_analyzer: str | None = Field(default=None)
@@ -3,13 +3,13 @@ import os
 import pathlib
 import platform
 import sys
-from dataclasses import is_dataclass
 from types import UnionType
 from typing import Any, MutableMapping, get_args, get_origin
 from uuid import uuid4
 
 import toml
 from dotenv import load_dotenv
+from pydantic import BaseModel, ValidationError
 
 from openhands.core import logger
 from openhands.core.config.agent_config import AgentConfig

@@ -43,17 +43,19 @@ def load_from_env(cfg: AppConfig, env_or_toml_dict: dict | MutableMapping[str, str]):
         return next((t for t in types if t is not type(None)), None)
 
     # helper function to set attributes based on env vars
-    def set_attr_from_env(sub_config: Any, prefix=''):
-        """Set attributes of a config dataclass based on environment variables."""
-        for field_name, field_type in sub_config.__annotations__.items():
+    def set_attr_from_env(sub_config: BaseModel, prefix=''):
+        """Set attributes of a config model based on environment variables."""
+        for field_name, field_info in sub_config.model_fields.items():
+            field_value = getattr(sub_config, field_name)
+            field_type = field_info.annotation
 
             # compute the expected env var name from the prefix and field name
             # e.g. LLM_BASE_URL
             env_var_name = (prefix + field_name).upper()
 
-            if is_dataclass(field_type):
-                # nested dataclass
-                nested_sub_config = getattr(sub_config, field_name)
-                set_attr_from_env(nested_sub_config, prefix=field_name + '_')
+            if isinstance(field_value, BaseModel):
+                set_attr_from_env(field_value, prefix=field_name + '_')
 
             elif env_var_name in env_or_toml_dict:
                 # convert the env var to the correct type and set it
                 value = env_or_toml_dict[env_var_name]
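The traversal above maps environment variables onto nested models by prefixing field names, so `SANDBOX_TIMEOUT` lands on `cfg.sandbox.timeout`. A simplified, self-contained version of the idea (not the real implementation; the casting here is deliberately naive and the models are hypothetical):

```python
from pydantic import BaseModel, Field

class DemoSandbox(BaseModel):
    timeout: int = Field(default=120)

class DemoApp(BaseModel):
    debug: bool = Field(default=False)
    sandbox: DemoSandbox = Field(default_factory=DemoSandbox)

def demo_load_from_env(cfg: BaseModel, env: dict[str, str], prefix: str = '') -> None:
    for name, info in type(cfg).model_fields.items():
        value = getattr(cfg, name)
        if isinstance(value, BaseModel):
            # nested model: recurse with the field name as prefix
            demo_load_from_env(value, env, prefix=name + '_')
            continue
        env_name = (prefix + name).upper()
        if env_name in env:
            raw = env[env_name]
            # naive casting; the real loader also handles optional/union types
            cast = (lambda s: s.lower() == 'true') if info.annotation is bool else info.annotation
            setattr(cfg, name, cast(raw))

cfg = DemoApp()
demo_load_from_env(cfg, {'DEBUG': 'true', 'SANDBOX_TIMEOUT': '300'})
assert cfg.debug is True
assert cfg.sandbox.timeout == 300
```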
@@ -125,22 +127,40 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
         if isinstance(value, dict):
             try:
                 if key is not None and key.lower() == 'agent':
+                    # Every entry here is either a field for the default `agent` config group, or itself a group
+                    # The best way to tell the difference is to try to parse it as an AgentConfig object
+                    agent_group_ids: set[str] = set()
+                    for nested_key, nested_value in value.items():
+                        if isinstance(nested_value, dict):
+                            try:
+                                agent_config = AgentConfig(**nested_value)
+                            except ValidationError:
+                                continue
+                            agent_group_ids.add(nested_key)
+                            cfg.set_agent_config(agent_config, nested_key)
+
                     logger.openhands_logger.debug(
                         'Attempt to load default agent config from config toml'
                     )
-                    non_dict_fields = {
-                        k: v for k, v in value.items() if not isinstance(v, dict)
+                    value_without_groups = {
+                        k: v for k, v in value.items() if k not in agent_group_ids
                     }
-                    agent_config = AgentConfig(**non_dict_fields)
+                    agent_config = AgentConfig(**value_without_groups)
                     cfg.set_agent_config(agent_config, 'agent')
 
-                    for nested_key, nested_value in value.items():
-                        if isinstance(nested_value, dict):
-                            logger.openhands_logger.debug(
-                                f'Attempt to load group {nested_key} from config toml as agent config'
-                            )
-                            agent_config = AgentConfig(**nested_value)
-                            cfg.set_agent_config(agent_config, nested_key)
                 elif key is not None and key.lower() == 'llm':
+                    # Every entry here is either a field for the default `llm` config group, or itself a group
+                    # The best way to tell the difference is to try to parse it as an LLMConfig object
+                    llm_group_ids: set[str] = set()
+                    for nested_key, nested_value in value.items():
+                        if isinstance(nested_value, dict):
+                            try:
+                                llm_config = LLMConfig(**nested_value)
+                            except ValidationError:
+                                continue
+                            llm_group_ids.add(nested_key)
+                            cfg.set_llm_config(llm_config, nested_key)
+
                     logger.openhands_logger.debug(
                         'Attempt to load default LLM config from config toml'
                     )

@@ -150,7 +170,7 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
             for k, v in value.items():
                 if not isinstance(v, dict):
                     generic_llm_fields[k] = v
-            generic_llm_config = LLMConfig.from_dict(generic_llm_fields)
+            generic_llm_config = LLMConfig(**generic_llm_fields)
             cfg.set_llm_config(generic_llm_config, 'llm')
 
             # Process custom named LLM configs

@@ -170,22 +190,23 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
                     for k, v in nested_value.items():
                         if not isinstance(v, dict):
                             custom_fields[k] = v
-                    merged_llm_dict = generic_llm_config.__dict__.copy()
+                    merged_llm_dict = generic_llm_fields.copy()
                     merged_llm_dict.update(custom_fields)
 
-                    custom_llm_config = LLMConfig.from_dict(merged_llm_dict)
+                    custom_llm_config = LLMConfig(**merged_llm_dict)
                     cfg.set_llm_config(custom_llm_config, nested_key)
 
         elif key is not None and key.lower() == 'security':
             logger.openhands_logger.debug(
                 'Attempt to load security config from config toml'
             )
-            security_config = SecurityConfig.from_dict(value)
+            security_config = SecurityConfig(**value)
             cfg.security = security_config
         elif not key.startswith('sandbox') and key.lower() != 'core':
             logger.openhands_logger.warning(
                 f'Unknown key in {toml_file}: "{key}"'
             )
-    except (TypeError, KeyError) as e:
+    except (TypeError, KeyError, ValidationError) as e:
         logger.openhands_logger.warning(
             f'Cannot parse [{key}] config from toml, values have not been applied.\nError: {e}',
         )

@@ -221,7 +242,7 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
             logger.openhands_logger.warning(
                 f'Unknown config key "{key}" in [core] section'
             )
-    except (TypeError, KeyError) as e:
+    except (TypeError, KeyError, ValidationError) as e:
         logger.openhands_logger.warning(
             f'Cannot parse [sandbox] config from toml, values have not been applied.\nError: {e}',
         )

@@ -324,7 +345,7 @@ def get_llm_config_arg(
 
     # update the llm config with the specified section
     if 'llm' in toml_config and llm_config_arg in toml_config['llm']:
-        return LLMConfig.from_dict(toml_config['llm'][llm_config_arg])
+        return LLMConfig(**toml_config['llm'][llm_config_arg])
     logger.openhands_logger.debug(f'Loading from toml failed for {llm_config_arg}')
     return None
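The loader's new group-vs-field discrimination: a nested toml table is treated as a named config group only if it validates as the target model; otherwise it is left to be handled as a plain field of the default group. A compact sketch using Python 3.11's `tomllib` (the model and table contents are hypothetical):

```python
import tomllib
from pydantic import BaseModel, Field, ValidationError

class DemoLLMConfig(BaseModel):
    model: str = Field(default='gpt-4o')
    num_retries: int = Field(default=8)
    model_config = {'extra': 'forbid'}

raw = tomllib.loads("""
[llm]
model = "base-model"
[llm.custom1]
model = "custom-model-1"
""")

groups: dict[str, DemoLLMConfig] = {}
group_ids: set[str] = set()
for key, val in raw['llm'].items():
    if isinstance(val, dict):
        try:
            groups[key] = DemoLLMConfig(**val)  # parses -> it is a named group
            group_ids.add(key)
        except ValidationError:
            continue                            # not a group; stays a field
default_fields = {k: v for k, v in raw['llm'].items() if k not in group_ids}
groups['llm'] = DemoLLMConfig(**default_fields)
print(sorted(groups))  # ['custom1', 'llm']
```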
@@ -23,7 +23,9 @@ class AsyncLLM(LLM):
         self._async_completion = partial(
             self._call_acompletion,
             model=self.config.model,
-            api_key=self.config.api_key,
+            api_key=self.config.api_key.get_secret_value()
+            if self.config.api_key
+            else None,
             base_url=self.config.base_url,
             api_version=self.config.api_version,
             custom_llm_provider=self.config.custom_llm_provider,

@@ -141,7 +141,9 @@ class LLM(RetryMixin, DebugMixin):
         self._completion = partial(
             litellm_completion,
             model=self.config.model,
-            api_key=self.config.api_key,
+            api_key=self.config.api_key.get_secret_value()
+            if self.config.api_key
+            else None,
             base_url=self.config.base_url,
             api_version=self.config.api_version,
             custom_llm_provider=self.config.custom_llm_provider,

@@ -331,7 +333,9 @@ class LLM(RetryMixin, DebugMixin):
         # GET {base_url}/v1/model/info with litellm_model_id as path param
         response = requests.get(
             f'{self.config.base_url}/v1/model/info',
-            headers={'Authorization': f'Bearer {self.config.api_key}'},
+            headers={
+                'Authorization': f'Bearer {self.config.api_key.get_secret_value() if self.config.api_key else None}'
+            },
         )
         resp_json = response.json()
         if 'data' not in resp_json:

@@ -17,7 +17,9 @@ class StreamingLLM(AsyncLLM):
         self._async_streaming_completion = partial(
             self._call_acompletion,
             model=self.config.model,
-            api_key=self.config.api_key,
+            api_key=self.config.api_key.get_secret_value()
+            if self.config.api_key
+            else None,
             base_url=self.config.base_url,
             api_version=self.config.api_version,
             custom_llm_provider=self.config.custom_llm_provider,
@@ -59,7 +59,8 @@ class ModalRuntime(ActionExecutionClient):
         self.sandbox = None
 
         self.modal_client = modal.Client.from_credentials(
-            config.modal_api_token_id, config.modal_api_token_secret
+            config.modal_api_token_id.get_secret_value(),
+            config.modal_api_token_secret.get_secret_value(),
         )
         self.app = modal.App.lookup(
             'openhands', create_if_missing=True, client=self.modal_client

@@ -40,7 +40,7 @@ class RunloopRuntime(ActionExecutionClient):
         self.devbox: DevboxView | None = None
         self.config = config
         self.runloop_api_client = Runloop(
-            bearer_token=config.runloop_api_key,
+            bearer_token=config.runloop_api_key.get_secret_value(),
         )
         self.container_name = CONTAINER_NAME_PREFIX + sid
         super().__init__(
@@ -51,8 +51,8 @@ async def get_litellm_models() -> list[str]:
     ):
         bedrock_model_list = bedrock.list_foundation_models(
             llm_config.aws_region_name,
-            llm_config.aws_access_key_id,
-            llm_config.aws_secret_access_key,
+            llm_config.aws_access_key_id.get_secret_value(),
+            llm_config.aws_secret_access_key.get_secret_value(),
         )
     model_list = litellm_model_list_without_bedrock + bedrock_model_list
     for llm_config in config.llms.values():

@@ -30,9 +30,6 @@ async def load_settings(request: Request) -> Settings | None:
             status_code=status.HTTP_404_NOT_FOUND,
             content={'error': 'Settings not found'},
         )
-
-        # For security reasons we don't ever send the api key to the client
-        settings.llm_api_key = 'SET' if settings.llm_api_key else None
         return settings
     except Exception as e:
         logger.warning(f'Invalid token: {e}')
@@ -1,13 +1,12 @@
-from dataclasses import dataclass
+from pydantic import Field
 
 from openhands.server.settings import Settings
 
 
-@dataclass
 class ConversationInitData(Settings):
     """
     Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
     """
 
-    github_token: str | None = None
-    selected_repository: str | None = None
+    github_token: str | None = Field(default=None)
+    selected_repository: str | None = Field(default=None)
@@ -91,6 +91,9 @@ class Session:
         )
         max_iterations = settings.max_iterations or self.config.max_iterations
 
+        # This is a shallow copy of the default LLM config, so changes here will
+        # persist if we retrieve the default LLM config again when constructing
+        # the agent
         default_llm_config = self.config.get_llm_config()
         default_llm_config.model = settings.llm_model or ''
         default_llm_config.api_key = settings.llm_api_key
@@ -1,8 +1,8 @@
-from dataclasses import dataclass
+from pydantic import BaseModel, SecretStr, SerializationInfo, field_serializer
+from pydantic.json import pydantic_encoder
 
 
-@dataclass
-class Settings:
+class Settings(BaseModel):
     """
     Persisted settings for OpenHands sessions
     """

@@ -13,6 +13,20 @@ class Settings:
     security_analyzer: str | None = None
     confirmation_mode: bool | None = None
     llm_model: str | None = None
-    llm_api_key: str | None = None
+    llm_api_key: SecretStr | None = None
     llm_base_url: str | None = None
     remote_runtime_resource_factor: int | None = None
 
+    @field_serializer('llm_api_key')
+    def llm_api_key_serializer(self, llm_api_key: SecretStr, info: SerializationInfo):
+        """Custom serializer for the LLM API key.
+
+        To serialize the API key instead of `"********"`, set `expose_secrets` to True in the serialization context. For example::
+
+            settings.model_dump_json(context={'expose_secrets': True})
+        """
+        context = info.context
+        if context and context.get('expose_secrets', False):
+            return llm_api_key.get_secret_value()
+
+        return pydantic_encoder(llm_api_key)
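The `field_serializer` above gives two serialization modes from one model: masked by default, raw only when the caller opts in through the serialization context (supported in recent pydantic v2 releases). A minimal demonstration with a hypothetical model, including the round trip that `FileSettingsStore` performs in the next hunk:

```python
from pydantic import BaseModel, SecretStr, SerializationInfo, field_serializer

class DemoSettings(BaseModel):
    llm_api_key: SecretStr | None = None

    @field_serializer('llm_api_key')
    def _serialize_key(self, v: SecretStr | None, info: SerializationInfo):
        # expose the raw value only when the caller opts in via context
        if info.context and info.context.get('expose_secrets', False):
            return v.get_secret_value() if v else None
        return str(v) if v else None  # masked form: '**********'

s = DemoSettings(llm_api_key='test-key')
print(s.model_dump_json())                                 # {"llm_api_key":"**********"}
exposed = s.model_dump_json(context={'expose_secrets': True})
print(exposed)                                             # {"llm_api_key":"test-key"}

# round trip, as the settings store now does on write/read:
loaded = DemoSettings.model_validate_json(exposed)
assert loaded.llm_api_key.get_secret_value() == 'test-key'
```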
@@ -26,7 +26,7 @@ class FileSettingsStore(SettingsStore):
             return None
 
     async def store(self, settings: Settings):
-        json_str = json.dumps(settings.__dict__)
+        json_str = settings.model_dump_json(context={'expose_secrets': True})
         await call_sync_from_async(self.file_store.write, self.path, json_str)
 
     @classmethod
@@ -90,7 +90,9 @@ class EmbeddingsLoader:
 
             return OpenAIEmbedding(
                 model='text-embedding-ada-002',
-                api_key=llm_config.api_key,
+                api_key=llm_config.api_key.get_secret_value()
+                if llm_config.api_key
+                else None,
             )
         elif strategy == 'azureopenai':
             from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
@@ -109,9 +109,6 @@ async def test_async_completion_with_user_cancellation(cancel_delay):
         print(f'Cancel requested: {is_set}')
         return is_set
 
-    config = load_app_config()
-    config.on_cancel_requested_fn = mock_on_cancel_requested
-
     async def mock_acompletion(*args, **kwargs):
         print('Starting mock_acompletion')
         for i in range(20):  # Increased iterations for longer running task

@@ -153,13 +150,6 @@ async def test_async_streaming_completion_with_user_cancellation(cancel_after_chunks):
     cancel_requested = False
 
-    async def mock_on_cancel_requested():
-        nonlocal cancel_requested
-        return cancel_requested
-
-    config = load_app_config()
-    config.on_cancel_requested_fn = mock_on_cancel_requested
-
     test_messages = [
         'This is ',
         'a test ',
@@ -60,7 +60,6 @@ def mock_state() -> State:
 
 
 def test_cmd_output_observation_message(agent: CodeActAgent):
-    agent.config.function_calling = False
     obs = CmdOutputObservation(
         command='echo hello',
         content='Command output',

@@ -82,7 +81,6 @@ def test_cmd_output_observation_message(agent: CodeActAgent):
 
 
 def test_ipython_run_cell_observation_message(agent: CodeActAgent):
-    agent.config.function_calling = False
     obs = IPythonRunCellObservation(
         code='plt.plot()',
         content='IPython output\n',

@@ -105,7 +103,6 @@ def test_ipython_run_cell_observation_message(agent: CodeActAgent):
 
 
 def test_agent_delegate_observation_message(agent: CodeActAgent):
-    agent.config.function_calling = False
     obs = AgentDelegateObservation(
         content='Content', outputs={'content': 'Delegated agent output'}
     )

@@ -122,7 +119,6 @@ def test_agent_delegate_observation_message(agent: CodeActAgent):
 
 
 def test_error_observation_message(agent: CodeActAgent):
-    agent.config.function_calling = False
     obs = ErrorObservation('Error message')
 
     results = agent.get_observation_message(obs, tool_call_id_to_message={})

@@ -145,7 +141,6 @@ def test_unknown_observation_message(agent: CodeActAgent):
 
 
 def test_file_edit_observation_message(agent: CodeActAgent):
-    agent.config.function_calling = False
     obs = FileEditObservation(
         path='/test/file.txt',
         prev_exist=True,

@@ -167,7 +162,6 @@ def test_file_edit_observation_message(agent: CodeActAgent):
 
 
 def test_file_read_observation_message(agent: CodeActAgent):
-    agent.config.function_calling = False
     obs = FileReadObservation(
         path='/test/file.txt',
         content='File content',

@@ -186,7 +180,6 @@ def test_file_read_observation_message(agent: CodeActAgent):
 
 
 def test_browser_output_observation_message(agent: CodeActAgent):
-    agent.config.function_calling = False
     obs = BrowserOutputObservation(
         url='http://example.com',
         trigger_by_action='browse',

@@ -207,7 +200,6 @@ def test_browser_output_observation_message(agent: CodeActAgent):
 
 
 def test_user_reject_observation_message(agent: CodeActAgent):
-    agent.config.function_calling = False
     obs = UserRejectObservation('Action rejected')
 
     results = agent.get_observation_message(obs, tool_call_id_to_message={})

@@ -223,7 +215,6 @@ def test_user_reject_observation_message(agent: CodeActAgent):
 
 
 def test_function_calling_observation_message(agent: CodeActAgent):
-    agent.config.function_calling = True
     mock_response = {
         'id': 'mock_id',
         'total_calls_in_response': 1,
@@ -226,7 +226,7 @@ def test_llm_condenser_from_config():
 
     assert isinstance(condenser, LLMSummarizingCondenser)
     assert condenser.llm.config.model == 'gpt-4o'
-    assert condenser.llm.config.api_key == 'test_key'
+    assert condenser.llm.config.api_key.get_secret_value() == 'test_key'
 
 
 def test_llm_condenser(mock_llm, mock_state):

@@ -381,7 +381,7 @@ def test_llm_attention_condenser_from_config():
 
     assert isinstance(condenser, LLMAttentionCondenser)
     assert condenser.llm.config.model == 'gpt-4o'
-    assert condenser.llm.config.api_key == 'test_key'
+    assert condenser.llm.config.api_key.get_secret_value() == 'test_key'
     assert condenser.max_size == 50
     assert condenser.keep_first == 10
@@ -63,7 +63,7 @@ def test_compat_env_to_config(monkeypatch, setup_env):
 
     assert config.workspace_base == '/repos/openhands/workspace'
     assert isinstance(config.get_llm_config(), LLMConfig)
-    assert config.get_llm_config().api_key == 'sk-proj-rgMV0...'
+    assert config.get_llm_config().api_key.get_secret_value() == 'sk-proj-rgMV0...'
     assert config.get_llm_config().model == 'gpt-4o'
     assert isinstance(config.get_agent_config(), AgentConfig)
     assert isinstance(config.get_agent_config().memory_max_threads, int)

@@ -83,7 +83,7 @@ def test_load_from_old_style_env(monkeypatch, default_config):
 
     load_from_env(default_config, os.environ)
 
-    assert default_config.get_llm_config().api_key == 'test-api-key'
+    assert default_config.get_llm_config().api_key.get_secret_value() == 'test-api-key'
     assert default_config.get_agent_config().memory_enabled is True
     assert default_config.default_agent == 'BrowsingAgent'
     assert default_config.workspace_base == '/opt/files/workspace'

@@ -126,7 +126,7 @@ default_agent = "TestAgent"
     # default llm & agent configs
     assert default_config.default_agent == 'TestAgent'
     assert default_config.get_llm_config().model == 'test-model'
-    assert default_config.get_llm_config().api_key == 'toml-api-key'
+    assert default_config.get_llm_config().api_key.get_secret_value() == 'toml-api-key'
     assert default_config.get_agent_config().memory_enabled is True
 
     # undefined agent config inherits default ones

@@ -291,7 +291,7 @@ sandbox_user_id = 1001
     assert default_config.get_llm_config().model == 'test-model'
     assert default_config.get_llm_config('llm').model == 'test-model'
     assert default_config.get_llm_config_from_agent().model == 'test-model'
-    assert default_config.get_llm_config().api_key == 'env-api-key'
+    assert default_config.get_llm_config().api_key.get_secret_value() == 'env-api-key'
 
     # after we set workspace_base to 'UNDEFINED' in the environment,
     # workspace_base should be set to that

@@ -336,7 +336,7 @@ user_id = 1001
     assert default_config.workspace_mount_path is None
 
     # before load_from_env, values are set to the values from the toml file
-    assert default_config.get_llm_config().api_key == 'toml-api-key'
+    assert default_config.get_llm_config().api_key.get_secret_value() == 'toml-api-key'
     assert default_config.sandbox.timeout == 500
     assert default_config.sandbox.user_id == 1001

@@ -345,7 +345,7 @@ user_id = 1001
     # values from env override values from toml
     assert os.environ.get('LLM_MODEL') is None
     assert default_config.get_llm_config().model == 'test-model'
-    assert default_config.get_llm_config().api_key == 'env-api-key'
+    assert default_config.get_llm_config().api_key.get_secret_value() == 'env-api-key'
 
     assert default_config.sandbox.timeout == 1000
     assert default_config.sandbox.user_id == 1002

@@ -412,7 +412,7 @@ def test_security_config_from_dict():
     # Test with all fields
     config_dict = {'confirmation_mode': True, 'security_analyzer': 'some_analyzer'}
 
-    security_config = SecurityConfig.from_dict(config_dict)
+    security_config = SecurityConfig(**config_dict)
 
     # Verify all fields are correctly set
     assert security_config.confirmation_mode is True

@@ -560,10 +560,7 @@ invalid_field_in_sandbox = "test"
     assert 'Cannot parse [llm] config from toml' in log_content
     assert 'values have not been applied' in log_content
-    # Error: LLMConfig.__init__() got an unexpected keyword argume
-    assert (
-        'Error: LLMConfig.__init__() got an unexpected keyword argume'
-        in log_content
-    )
+    assert 'Error: 1 validation error for LLMConfig' in log_content
+    assert 'invalid_field' in log_content
 
     # invalid [sandbox] config

@@ -635,12 +632,14 @@ def test_api_keys_repr_str():
         aws_access_key_id='my_access_key',
         aws_secret_access_key='my_secret_key',
     )
-    assert "api_key='******'" in repr(llm_config)
-    assert "aws_access_key_id='******'" in repr(llm_config)
-    assert "aws_secret_access_key='******'" in repr(llm_config)
-    assert "api_key='******'" in str(llm_config)
-    assert "aws_access_key_id='******'" in str(llm_config)
-    assert "aws_secret_access_key='******'" in str(llm_config)
+    # Check that no secret keys are emitted in representations of the config object
+    assert 'my_api_key' not in repr(llm_config)
+    assert 'my_api_key' not in str(llm_config)
+    assert 'my_access_key' not in repr(llm_config)
+    assert 'my_access_key' not in str(llm_config)
+    assert 'my_secret_key' not in repr(llm_config)
+    assert 'my_secret_key' not in str(llm_config)
 
     # Check that no other attrs in LLMConfig have 'key' or 'token' in their name
     # This will fail when new attrs are added, and attract attention

@@ -652,7 +651,7 @@ def test_api_keys_repr_str():
         'output_cost_per_token',
         'custom_tokenizer',
     ]
-    for attr_name in dir(LLMConfig):
+    for attr_name in LLMConfig.model_fields.keys():
         if (
             not attr_name.startswith('__')
             and attr_name not in known_key_token_attrs_llm

@@ -667,7 +666,7 @@ def test_api_keys_repr_str():
     # Test AgentConfig
     # No attrs in AgentConfig have 'key' or 'token' in their name
     agent_config = AgentConfig(memory_enabled=True, memory_max_threads=4)
-    for attr_name in dir(AgentConfig):
+    for attr_name in AgentConfig.model_fields.keys():
         if not attr_name.startswith('__'):
             assert (
                 'key' not in attr_name.lower()

@@ -686,16 +685,16 @@ def test_api_keys_repr_str():
         modal_api_token_secret='my_modal_api_token_secret',
         runloop_api_key='my_runloop_api_key',
     )
-    assert "e2b_api_key='******'" in repr(app_config)
-    assert "e2b_api_key='******'" in str(app_config)
-    assert "jwt_secret='******'" in repr(app_config)
-    assert "jwt_secret='******'" in str(app_config)
-    assert "modal_api_token_id='******'" in repr(app_config)
-    assert "modal_api_token_id='******'" in str(app_config)
-    assert "modal_api_token_secret='******'" in repr(app_config)
-    assert "modal_api_token_secret='******'" in str(app_config)
-    assert "runloop_api_key='******'" in repr(app_config)
-    assert "runloop_api_key='******'" in str(app_config)
+    assert 'my_e2b_api_key' not in repr(app_config)
+    assert 'my_e2b_api_key' not in str(app_config)
+    assert 'my_jwt_secret' not in repr(app_config)
+    assert 'my_jwt_secret' not in str(app_config)
+    assert 'my_modal_api_token_id' not in repr(app_config)
+    assert 'my_modal_api_token_id' not in str(app_config)
+    assert 'my_modal_api_token_secret' not in repr(app_config)
+    assert 'my_modal_api_token_secret' not in str(app_config)
+    assert 'my_runloop_api_key' not in repr(app_config)
+    assert 'my_runloop_api_key' not in str(app_config)
 
     # Check that no other attrs in AppConfig have 'key' or 'token' in their name
     # This will fail when new attrs are added, and attract attention

@@ -705,7 +704,7 @@ def test_api_keys_repr_str():
         'modal_api_token_secret',
         'runloop_api_key',
     ]
-    for attr_name in dir(AppConfig):
+    for attr_name in AppConfig.model_fields.keys():
         if (
             not attr_name.startswith('__')
             and attr_name not in known_key_token_attrs_app
@@ -1,4 +1,3 @@
-import json
 from unittest.mock import MagicMock, patch
 
 import pytest

@@ -43,7 +42,7 @@ async def test_store_and_load_data(file_settings_store):
     await file_settings_store.store(init_data)
 
     # Verify store called with correct JSON
-    expected_json = json.dumps(init_data.__dict__)
+    expected_json = init_data.model_dump_json(context={'expose_secrets': True})
     file_settings_store.file_store.write.assert_called_once_with(
         'settings.json', expected_json
     )

@@ -60,7 +59,12 @@ async def test_store_and_load_data(file_settings_store):
     assert loaded_data.security_analyzer == init_data.security_analyzer
     assert loaded_data.confirmation_mode == init_data.confirmation_mode
     assert loaded_data.llm_model == init_data.llm_model
-    assert loaded_data.llm_api_key == init_data.llm_api_key
+    assert loaded_data.llm_api_key
+    assert init_data.llm_api_key
+    assert (
+        loaded_data.llm_api_key.get_secret_value()
+        == init_data.llm_api_key.get_secret_value()
+    )
     assert loaded_data.llm_base_url == init_data.llm_base_url
@@ -40,7 +40,7 @@ def default_config():
 def test_llm_init_with_default_config(default_config):
     llm = LLM(default_config)
     assert llm.config.model == 'gpt-4o'
-    assert llm.config.api_key == 'test_key'
+    assert llm.config.api_key.get_secret_value() == 'test_key'
     assert isinstance(llm.metrics, Metrics)
     assert llm.metrics.model_name == 'gpt-4o'

@@ -77,7 +77,7 @@ def test_llm_init_with_custom_config():
     )
     llm = LLM(custom_config)
     assert llm.config.model == 'custom-model'
-    assert llm.config.api_key == 'custom_key'
+    assert llm.config.api_key.get_secret_value() == 'custom_key'
     assert llm.config.max_input_tokens == 5000
     assert llm.config.max_output_tokens == 1500
     assert llm.config.temperature == 0.8
@@ -59,28 +59,28 @@ def test_load_from_toml_llm_with_fallback(
     # Verify generic LLM configuration
     generic_llm = default_config.get_llm_config('llm')
     assert generic_llm.model == 'base-model'
-    assert generic_llm.api_key == 'base-api-key'
+    assert generic_llm.api_key.get_secret_value() == 'base-api-key'
     assert generic_llm.embedding_model == 'base-embedding'
     assert generic_llm.num_retries == 3
 
     # Verify custom1 LLM falls back 'num_retries' from base
     custom1 = default_config.get_llm_config('custom1')
     assert custom1.model == 'custom-model-1'
-    assert custom1.api_key == 'custom-api-key-1'
+    assert custom1.api_key.get_secret_value() == 'custom-api-key-1'
     assert custom1.embedding_model == 'base-embedding'
     assert custom1.num_retries == 3  # from [llm]
 
     # Verify custom2 LLM overrides 'num_retries'
     custom2 = default_config.get_llm_config('custom2')
     assert custom2.model == 'custom-model-2'
-    assert custom2.api_key == 'custom-api-key-2'
+    assert custom2.api_key.get_secret_value() == 'custom-api-key-2'
     assert custom2.embedding_model == 'base-embedding'
     assert custom2.num_retries == 5  # overridden value
 
     # Verify custom3 LLM inherits all attributes except 'model' and 'api_key'
     custom3 = default_config.get_llm_config('custom3')
     assert custom3.model == 'custom-model-3'
-    assert custom3.api_key == 'custom-api-key-3'
+    assert custom3.api_key.get_secret_value() == 'custom-api-key-3'
     assert custom3.embedding_model == 'base-embedding'
     assert custom3.num_retries == 3  # from [llm]

@@ -113,14 +113,14 @@ num_retries = 10
     # Verify generic LLM configuration remains unchanged
     generic_llm = default_config.get_llm_config('llm')
     assert generic_llm.model == 'base-model'
-    assert generic_llm.api_key == 'base-api-key'
+    assert generic_llm.api_key.get_secret_value() == 'base-api-key'
     assert generic_llm.embedding_model == 'base-embedding'
     assert generic_llm.num_retries == 3
 
     # Verify custom_full LLM overrides all attributes
     custom_full = default_config.get_llm_config('custom_full')
     assert custom_full.model == 'full-custom-model'
-    assert custom_full.api_key == 'full-custom-api-key'
+    assert custom_full.api_key.get_secret_value() == 'full-custom-api-key'
     assert custom_full.embedding_model == 'full-custom-embedding'
     assert custom_full.num_retries == 10  # overridden value

@@ -136,14 +136,14 @@ def test_load_from_toml_llm_custom_partial_override(
     # Verify custom1 LLM overrides 'model' and 'api_key' but inherits 'num_retries'
     custom1 = default_config.get_llm_config('custom1')
     assert custom1.model == 'custom-model-1'
-    assert custom1.api_key == 'custom-api-key-1'
+    assert custom1.api_key.get_secret_value() == 'custom-api-key-1'
     assert custom1.embedding_model == 'base-embedding'
     assert custom1.num_retries == 3  # from [llm]
 
     # Verify custom2 LLM overrides 'model', 'api_key', and 'num_retries'
     custom2 = default_config.get_llm_config('custom2')
     assert custom2.model == 'custom-model-2'
-    assert custom2.api_key == 'custom-api-key-2'
+    assert custom2.api_key.get_secret_value() == 'custom-api-key-2'
     assert custom2.embedding_model == 'base-embedding'
     assert custom2.num_retries == 5  # Overridden value

@@ -159,7 +159,7 @@ def test_load_from_toml_llm_custom_no_override(
     # Verify custom3 LLM inherits 'embedding_model' and 'num_retries' from generic
     custom3 = default_config.get_llm_config('custom3')
     assert custom3.model == 'custom-model-3'
-    assert custom3.api_key == 'custom-api-key-3'
+    assert custom3.api_key.get_secret_value() == 'custom-api-key-3'
     assert custom3.embedding_model == 'base-embedding'
     assert custom3.num_retries == 3  # from [llm]

@@ -186,7 +186,7 @@ api_key = "custom-only-api-key"
     # Verify custom_only LLM uses its own attributes and defaults for others
     custom_only = default_config.get_llm_config('custom_only')
     assert custom_only.model == 'custom-only-model'
-    assert custom_only.api_key == 'custom-only-api-key'
+    assert custom_only.api_key.get_secret_value() == 'custom-only-api-key'
     assert custom_only.embedding_model == 'local'  # default value
     assert custom_only.num_retries == 8  # default value

@@ -217,12 +217,12 @@ unknown_attr = "should_not_exist"
     # Verify generic LLM is loaded correctly
     generic_llm = default_config.get_llm_config('llm')
     assert generic_llm.model == 'base-model'
-    assert generic_llm.api_key == 'base-api-key'
+    assert generic_llm.api_key.get_secret_value() == 'base-api-key'
     assert generic_llm.num_retries == 3
 
     # Verify invalid_custom LLM does not override generic attributes
     custom_invalid = default_config.get_llm_config('invalid_custom')
     assert custom_invalid.model == 'base-model'
-    assert custom_invalid.api_key == 'base-api-key'
+    assert custom_invalid.api_key.get_secret_value() == 'base-api-key'
     assert custom_invalid.num_retries == 3  # default value
     assert custom_invalid.embedding_model == 'local'  # default value

@@ -86,7 +86,7 @@ def test_draft_editor_as_named_llm(config_toml_with_draft_editor):
     draft_llm = config.get_llm_config('draft_editor')
     assert draft_llm is not None
     assert draft_llm.model == 'draft-model'
-    assert draft_llm.api_key == 'draft-api-key'
+    assert draft_llm.api_key.get_secret_value() == 'draft-api-key'
 
 
 def test_draft_editor_fallback(config_toml_with_draft_editor):
@@ -2,6 +2,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 from fastapi.testclient import TestClient
+from pydantic import SecretStr
 
 from openhands.core.config.sandbox_config import SandboxConfig
 from openhands.server.app import app

@@ -50,7 +51,7 @@ async def test_settings_api_runtime_factor(test_client, mock_settings_store):
         'security_analyzer': 'default',
         'confirmation_mode': True,
         'llm_model': 'test-model',
-        'llm_api_key': None,
+        'llm_api_key': 'test-key',
         'llm_base_url': 'https://test.com',
         'remote_runtime_resource_factor': 2,
     }

@@ -83,3 +84,36 @@ async def test_settings_api_runtime_factor(test_client, mock_settings_store):
     mock_settings_store.store.assert_called()
     stored_settings = mock_settings_store.store.call_args[0][0]
     assert stored_settings.remote_runtime_resource_factor == 2
+
+    assert isinstance(stored_settings.llm_api_key, SecretStr)
+    assert stored_settings.llm_api_key.get_secret_value() == 'test-key'
+
+
+@pytest.mark.asyncio
+async def test_settings_llm_api_key(test_client, mock_settings_store):
+    # Mock the settings store to return None initially (no existing settings)
+    mock_settings_store.load.return_value = None
+
+    # Test data with remote_runtime_resource_factor
+    settings_data = {'llm_api_key': 'test-key'}
+
+    # The test_client fixture already handles authentication
+
+    # Make the POST request to store settings
+    response = test_client.post('/api/settings', json=settings_data)
+    assert response.status_code == 200
+
+    # Verify the settings were stored with the correct secret API key
+    stored_settings = mock_settings_store.store.call_args[0][0]
+    assert isinstance(stored_settings.llm_api_key, SecretStr)
+    assert stored_settings.llm_api_key.get_secret_value() == 'test-key'
+
+    # Mock settings store to return our settings for the GET request
+    mock_settings_store.load.return_value = Settings(**settings_data)
+
+    # Make a GET request to retrieve settings
+    response = test_client.get('/api/settings')
+    assert response.status_code == 200
+
+    # We should never expose the API key in the response
+    assert 'test-key' not in response.json()