Revert "Config objects as Pydantic BaseModels (#6176)" (#6214)

This commit is contained in:
tofarr 2025-01-13 07:36:25 -07:00 committed by GitHub
parent 63133c0ba9
commit 23473070b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 406 additions and 257 deletions

View File

@ -80,7 +80,7 @@ def load_dependencies(runtime: Runtime) -> List[str]:
def init_task_env(runtime: Runtime, hostname: str, env_llm_config: LLMConfig):
command = (
f'SERVER_HOSTNAME={hostname} '
f'LITELLM_API_KEY={env_llm_config.api_key.get_secret_value() if env_llm_config.api_key else None} '
f'LITELLM_API_KEY={env_llm_config.api_key} '
f'LITELLM_BASE_URL={env_llm_config.base_url} '
f'LITELLM_MODEL={env_llm_config.model} '
'bash /utils/init.sh'
@ -165,7 +165,7 @@ def run_evaluator(
runtime: Runtime, env_llm_config: LLMConfig, trajectory_path: str, result_path: str
):
command = (
f'LITELLM_API_KEY={env_llm_config.api_key.get_secret_value() if env_llm_config.api_key else None} '
f'LITELLM_API_KEY={env_llm_config.api_key} '
f'LITELLM_BASE_URL={env_llm_config.base_url} '
f'LITELLM_MODEL={env_llm_config.model} '
f"DECRYPTION_KEY='theagentcompany is all you need' " # Hardcoded Key

View File

@ -52,6 +52,30 @@ class EvalMetadata(BaseModel):
details: dict[str, Any] | None = None
condenser_config: CondenserConfig | None = None
def model_dump(self, *args, **kwargs):
dumped_dict = super().model_dump(*args, **kwargs)
# avoid leaking sensitive information
dumped_dict['llm_config'] = self.llm_config.to_safe_dict()
if hasattr(self.condenser_config, 'llm_config'):
dumped_dict['condenser_config']['llm_config'] = (
self.condenser_config.llm_config.to_safe_dict()
)
return dumped_dict
def model_dump_json(self, *args, **kwargs):
dumped = super().model_dump_json(*args, **kwargs)
dumped_dict = json.loads(dumped)
# avoid leaking sensitive information
dumped_dict['llm_config'] = self.llm_config.to_safe_dict()
if hasattr(self.condenser_config, 'llm_config'):
dumped_dict['condenser_config']['llm_config'] = (
self.condenser_config.llm_config.to_safe_dict()
)
logger.debug(f'Dumped metadata: {dumped_dict}')
return json.dumps(dumped_dict)
class EvalOutput(BaseModel):
# NOTE: User-specified
@ -74,6 +98,23 @@ class EvalOutput(BaseModel):
# Optionally save the input test instance
instance: dict[str, Any] | None = None
def model_dump(self, *args, **kwargs):
dumped_dict = super().model_dump(*args, **kwargs)
# Remove None values
dumped_dict = {k: v for k, v in dumped_dict.items() if v is not None}
# Apply custom serialization for metadata (to avoid leaking sensitive information)
if self.metadata is not None:
dumped_dict['metadata'] = self.metadata.model_dump()
return dumped_dict
def model_dump_json(self, *args, **kwargs):
dumped = super().model_dump_json(*args, **kwargs)
dumped_dict = json.loads(dumped)
# Apply custom serialization for metadata (to avoid leaking sensitive information)
if 'metadata' in dumped_dict:
dumped_dict['metadata'] = json.loads(self.metadata.model_dump_json())
return json.dumps(dumped_dict)
class EvalException(Exception):
pass
@ -273,7 +314,7 @@ def update_progress(
logger.info(
f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
)
output_fp.write(result.model_dump_json() + '\n')
output_fp.write(json.dumps(result.model_dump()) + '\n')
output_fp.flush()

View File

@ -37,17 +37,21 @@ export SANDBOX_TIMEOUT='300'
## Type Handling
The `load_from_env` function attempts to cast environment variable values to the types specified in the models. It handles:
The `load_from_env` function attempts to cast environment variable values to the types specified in the dataclasses. It handles:
- Basic types (str, int, bool)
- Optional types (e.g., `str | None`)
- Nested models
- Nested dataclasses
If type casting fails, an error is logged, and the default value is retained.
## Default Values
If an environment variable is not set, the default value specified in the model is used.
If an environment variable is not set, the default value specified in the dataclass is used.
## Nested Configurations
The `AppConfig` class contains nested configurations like `LLMConfig` and `AgentConfig`. The `load_from_env` function handles these by recursively processing nested dataclasses with updated prefixes.
## Security Considerations

View File

@ -1,9 +1,11 @@
from pydantic import BaseModel, Field
from dataclasses import dataclass, field, fields
from openhands.core.config.condenser_config import CondenserConfig, NoOpCondenserConfig
from openhands.core.config.config_utils import get_field_info
class AgentConfig(BaseModel):
@dataclass
class AgentConfig:
"""Configuration for the agent.
Attributes:
@ -20,13 +22,20 @@ class AgentConfig(BaseModel):
condenser: Configuration for the memory condenser. Default is NoOpCondenserConfig.
"""
codeact_enable_browsing: bool = Field(default=True)
codeact_enable_llm_editor: bool = Field(default=False)
codeact_enable_jupyter: bool = Field(default=True)
micro_agent_name: str | None = Field(default=None)
memory_enabled: bool = Field(default=False)
memory_max_threads: int = Field(default=3)
llm_config: str | None = Field(default=None)
use_microagents: bool = Field(default=True)
disabled_microagents: list[str] | None = Field(default=None)
condenser: CondenserConfig = Field(default_factory=NoOpCondenserConfig)
codeact_enable_browsing: bool = True
codeact_enable_llm_editor: bool = False
codeact_enable_jupyter: bool = True
micro_agent_name: str | None = None
memory_enabled: bool = False
memory_max_threads: int = 3
llm_config: str | None = None
use_microagents: bool = True
disabled_microagents: list[str] | None = None
condenser: CondenserConfig = field(default_factory=NoOpCondenserConfig) # type: ignore
def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
result = {}
for f in fields(self):
result[f.name] = get_field_info(f)
return result

View File

@ -1,20 +1,20 @@
from dataclasses import dataclass, field, fields, is_dataclass
from typing import ClassVar
from pydantic import BaseModel, Field, SecretStr
from openhands.core import logger
from openhands.core.config.agent_config import AgentConfig
from openhands.core.config.config_utils import (
OH_DEFAULT_AGENT,
OH_MAX_ITERATIONS,
model_defaults_to_dict,
get_field_info,
)
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.sandbox_config import SandboxConfig
from openhands.core.config.security_config import SecurityConfig
class AppConfig(BaseModel):
@dataclass
class AppConfig:
"""Configuration for the app.
Attributes:
@ -46,39 +46,37 @@ class AppConfig(BaseModel):
input is read line by line. When enabled, input continues until /exit command.
"""
llms: dict[str, LLMConfig] = Field(default_factory=dict)
agents: dict = Field(default_factory=dict)
default_agent: str = Field(default=OH_DEFAULT_AGENT)
sandbox: SandboxConfig = Field(default_factory=SandboxConfig)
security: SecurityConfig = Field(default_factory=SecurityConfig)
runtime: str = Field(default='docker')
file_store: str = Field(default='local')
file_store_path: str = Field(default='/tmp/openhands_file_store')
trajectories_path: str | None = Field(default=None)
workspace_base: str | None = Field(default=None)
workspace_mount_path: str | None = Field(default=None)
workspace_mount_path_in_sandbox: str = Field(default='/workspace')
workspace_mount_rewrite: str | None = Field(default=None)
cache_dir: str = Field(default='/tmp/cache')
run_as_openhands: bool = Field(default=True)
max_iterations: int = Field(default=OH_MAX_ITERATIONS)
max_budget_per_task: float | None = Field(default=None)
e2b_api_key: SecretStr | None = Field(default=None)
modal_api_token_id: SecretStr | None = Field(default=None)
modal_api_token_secret: SecretStr | None = Field(default=None)
disable_color: bool = Field(default=False)
jwt_secret: SecretStr | None = Field(default=None)
debug: bool = Field(default=False)
file_uploads_max_file_size_mb: int = Field(default=0)
file_uploads_restrict_file_types: bool = Field(default=False)
file_uploads_allowed_extensions: list[str] = Field(default_factory=lambda: ['.*'])
runloop_api_key: SecretStr | None = Field(default=None)
cli_multiline_input: bool = Field(default=False)
llms: dict[str, LLMConfig] = field(default_factory=dict)
agents: dict = field(default_factory=dict)
default_agent: str = OH_DEFAULT_AGENT
sandbox: SandboxConfig = field(default_factory=SandboxConfig)
security: SecurityConfig = field(default_factory=SecurityConfig)
runtime: str = 'docker'
file_store: str = 'local'
file_store_path: str = '/tmp/openhands_file_store'
trajectories_path: str | None = None
workspace_base: str | None = None
workspace_mount_path: str | None = None
workspace_mount_path_in_sandbox: str = '/workspace'
workspace_mount_rewrite: str | None = None
cache_dir: str = '/tmp/cache'
run_as_openhands: bool = True
max_iterations: int = OH_MAX_ITERATIONS
max_budget_per_task: float | None = None
e2b_api_key: str = ''
modal_api_token_id: str = ''
modal_api_token_secret: str = ''
disable_color: bool = False
jwt_secret: str = ''
debug: bool = False
file_uploads_max_file_size_mb: int = 0
file_uploads_restrict_file_types: bool = False
file_uploads_allowed_extensions: list[str] = field(default_factory=lambda: ['.*'])
runloop_api_key: str | None = None
cli_multiline_input: bool = False
defaults_dict: ClassVar[dict] = {}
model_config = {'extra': 'forbid'}
def get_llm_config(self, name='llm') -> LLMConfig:
"""'llm' is the name for default config (for backward compatibility prior to 0.8)."""
if name in self.llms:
@ -117,7 +115,42 @@ class AppConfig(BaseModel):
def get_agent_configs(self) -> dict[str, AgentConfig]:
return self.agents
def model_post_init(self, __context):
def __post_init__(self):
"""Post-initialization hook, called when the instance is created with only default values."""
super().model_post_init(__context)
AppConfig.defaults_dict = model_defaults_to_dict(self)
AppConfig.defaults_dict = self.defaults_to_dict()
def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
result = {}
for f in fields(self):
field_value = getattr(self, f.name)
# dataclasses compute their defaults themselves
if is_dataclass(type(field_value)):
result[f.name] = field_value.defaults_to_dict()
else:
result[f.name] = get_field_info(f)
return result
def __str__(self):
attr_str = []
for f in fields(self):
attr_name = f.name
attr_value = getattr(self, f.name)
if attr_name in [
'e2b_api_key',
'github_token',
'jwt_secret',
'modal_api_token_id',
'modal_api_token_secret',
'runloop_api_key',
]:
attr_value = '******' if attr_value else None
attr_str.append(f'{attr_name}={repr(attr_value)}')
return f"AppConfig({', '.join(attr_str)}"
def __repr__(self):
return self.__str__()

View File

@ -1,22 +1,19 @@
from types import UnionType
from typing import Any, get_args, get_origin
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from typing import get_args, get_origin
OH_DEFAULT_AGENT = 'CodeActAgent'
OH_MAX_ITERATIONS = 500
def get_field_info(field: FieldInfo) -> dict[str, Any]:
def get_field_info(f):
"""Extract information about a dataclass field: type, optional, and default.
Args:
field: The field to extract information from.
f: The field to extract information from.
Returns: A dict with the field's type, whether it's optional, and its default value.
"""
field_type = field.annotation
field_type = f.type
optional = False
# for types like str | None, find the non-None type and set optional to True
@ -36,21 +33,7 @@ def get_field_info(field: FieldInfo) -> dict[str, Any]:
)
# default is always present
default = field.default
default = f.default
# return a schema with the useful info for frontend
return {'type': type_name.lower(), 'optional': optional, 'default': default}
def model_defaults_to_dict(model: BaseModel) -> dict[str, Any]:
"""Serialize field information in a dict for the frontend, including type hints, defaults, and whether it's optional."""
result = {}
for name, field in model.model_fields.items():
field_value = getattr(model, name)
if isinstance(field_value, BaseModel):
result[name] = model_defaults_to_dict(field_value)
else:
result[name] = get_field_info(field)
return result

View File

@ -1,14 +1,15 @@
from __future__ import annotations
import os
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from dataclasses import dataclass, fields
from typing import Optional
from openhands.core.config.config_utils import get_field_info
from openhands.core.logger import LOG_DIR
LLM_SENSITIVE_FIELDS = ['api_key', 'aws_access_key_id', 'aws_secret_access_key']
class LLMConfig(BaseModel):
@dataclass
class LLMConfig:
"""Configuration for the LLM model.
Attributes:
@ -47,57 +48,98 @@ class LLMConfig(BaseModel):
native_tool_calling: Whether to use native tool calling if supported by the model. Can be True, False, or not set.
"""
model: str = Field(default='claude-3-5-sonnet-20241022')
api_key: SecretStr | None = Field(default=None)
base_url: str | None = Field(default=None)
api_version: str | None = Field(default=None)
embedding_model: str = Field(default='local')
embedding_base_url: str | None = Field(default=None)
embedding_deployment_name: str | None = Field(default=None)
aws_access_key_id: SecretStr | None = Field(default=None)
aws_secret_access_key: SecretStr | None = Field(default=None)
aws_region_name: str | None = Field(default=None)
openrouter_site_url: str = Field(default='https://docs.all-hands.dev/')
openrouter_app_name: str = Field(default='OpenHands')
num_retries: int = Field(default=8)
retry_multiplier: float = Field(default=2)
retry_min_wait: int = Field(default=15)
retry_max_wait: int = Field(default=120)
timeout: int | None = Field(default=None)
max_message_chars: int = Field(
default=30_000
) # maximum number of characters in an observation's content when sent to the llm
temperature: float = Field(default=0.0)
top_p: float = Field(default=1.0)
custom_llm_provider: str | None = Field(default=None)
max_input_tokens: int | None = Field(default=None)
max_output_tokens: int | None = Field(default=None)
input_cost_per_token: float | None = Field(default=None)
output_cost_per_token: float | None = Field(default=None)
ollama_base_url: str | None = Field(default=None)
model: str = 'claude-3-5-sonnet-20241022'
api_key: str | None = None
base_url: str | None = None
api_version: str | None = None
embedding_model: str = 'local'
embedding_base_url: str | None = None
embedding_deployment_name: str | None = None
aws_access_key_id: str | None = None
aws_secret_access_key: str | None = None
aws_region_name: str | None = None
openrouter_site_url: str = 'https://docs.all-hands.dev/'
openrouter_app_name: str = 'OpenHands'
num_retries: int = 8
retry_multiplier: float = 2
retry_min_wait: int = 15
retry_max_wait: int = 120
timeout: int | None = None
max_message_chars: int = 30_000 # maximum number of characters in an observation's content when sent to the llm
temperature: float = 0.0
top_p: float = 1.0
custom_llm_provider: str | None = None
max_input_tokens: int | None = None
max_output_tokens: int | None = None
input_cost_per_token: float | None = None
output_cost_per_token: float | None = None
ollama_base_url: str | None = None
# This setting can be sent in each call to litellm
drop_params: bool = Field(default=True)
drop_params: bool = True
# Note: this setting is actually global, unlike drop_params
modify_params: bool = Field(default=True)
disable_vision: bool | None = Field(default=None)
caching_prompt: bool = Field(default=True)
log_completions: bool = Field(default=False)
log_completions_folder: str = Field(default=os.path.join(LOG_DIR, 'completions'))
draft_editor: LLMConfig | None = Field(default=None)
custom_tokenizer: str | None = Field(default=None)
native_tool_calling: bool | None = Field(default=None)
modify_params: bool = True
disable_vision: bool | None = None
caching_prompt: bool = True
log_completions: bool = False
log_completions_folder: str = os.path.join(LOG_DIR, 'completions')
draft_editor: Optional['LLMConfig'] = None
custom_tokenizer: str | None = None
native_tool_calling: bool | None = None
model_config = {'extra': 'forbid'}
def model_post_init(self, __context: Any):
"""Post-initialization hook to assign OpenRouter-related variables to environment variables.
def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
result = {}
for f in fields(self):
result[f.name] = get_field_info(f)
return result
def __post_init__(self):
"""
Post-initialization hook to assign OpenRouter-related variables to environment variables.
This ensures that these values are accessible to litellm at runtime.
"""
super().model_post_init(__context)
# Assign OpenRouter-specific variables to environment variables
if self.openrouter_site_url:
os.environ['OR_SITE_URL'] = self.openrouter_site_url
if self.openrouter_app_name:
os.environ['OR_APP_NAME'] = self.openrouter_app_name
def __str__(self):
attr_str = []
for f in fields(self):
attr_name = f.name
attr_value = getattr(self, f.name)
if attr_name in LLM_SENSITIVE_FIELDS:
attr_value = '******' if attr_value else None
attr_str.append(f'{attr_name}={repr(attr_value)}')
return f"LLMConfig({', '.join(attr_str)})"
def __repr__(self):
return self.__str__()
def to_safe_dict(self):
"""Return a dict with the sensitive fields replaced with ******."""
ret = self.__dict__.copy()
for k, v in ret.items():
if k in LLM_SENSITIVE_FIELDS:
ret[k] = '******' if v else None
elif isinstance(v, LLMConfig):
ret[k] = v.to_safe_dict()
return ret
@classmethod
def from_dict(cls, llm_config_dict: dict) -> 'LLMConfig':
"""Create an LLMConfig object from a dictionary.
This function is used to create an LLMConfig object from a dictionary,
with the exception of the 'draft_editor' key, which is a nested LLMConfig object.
"""
args = {k: v for k, v in llm_config_dict.items() if not isinstance(v, dict)}
if 'draft_editor' in llm_config_dict:
draft_editor_config = LLMConfig(**llm_config_dict['draft_editor'])
args['draft_editor'] = draft_editor_config
return cls(**args)

View File

@ -1,9 +1,11 @@
import os
from dataclasses import dataclass, field, fields
from pydantic import BaseModel, Field
from openhands.core.config.config_utils import get_field_info
class SandboxConfig(BaseModel):
@dataclass
class SandboxConfig:
"""Configuration for the sandbox.
Attributes:
@ -37,32 +39,48 @@ class SandboxConfig(BaseModel):
This should be a JSON string that will be parsed into a dictionary.
"""
remote_runtime_api_url: str = Field(default='http://localhost:8000')
local_runtime_url: str = Field(default='http://localhost')
keep_runtime_alive: bool = Field(default=True)
rm_all_containers: bool = Field(default=False)
api_key: str | None = Field(default=None)
base_container_image: str = Field(
default='nikolaik/python-nodejs:python3.12-nodejs22'
remote_runtime_api_url: str = 'http://localhost:8000'
local_runtime_url: str = 'http://localhost'
keep_runtime_alive: bool = True
rm_all_containers: bool = False
api_key: str | None = None
base_container_image: str = 'nikolaik/python-nodejs:python3.12-nodejs22' # default to nikolaik/python-nodejs:python3.12-nodejs22 for eventstream runtime
runtime_container_image: str | None = None
user_id: int = os.getuid() if hasattr(os, 'getuid') else 1000
timeout: int = 120
remote_runtime_init_timeout: int = 180
enable_auto_lint: bool = (
False # once enabled, OpenHands would lint files after editing
)
runtime_container_image: str | None = Field(default=None)
user_id: int = Field(default=os.getuid() if hasattr(os, 'getuid') else 1000)
timeout: int = Field(default=120)
remote_runtime_init_timeout: int = Field(default=180)
enable_auto_lint: bool = Field(
default=False # once enabled, OpenHands would lint files after editing
)
use_host_network: bool = Field(default=False)
runtime_extra_build_args: list[str] | None = Field(default=None)
initialize_plugins: bool = Field(default=True)
force_rebuild_runtime: bool = Field(default=False)
runtime_extra_deps: str | None = Field(default=None)
runtime_startup_env_vars: dict[str, str] = Field(default_factory=dict)
browsergym_eval_env: str | None = Field(default=None)
platform: str | None = Field(default=None)
close_delay: int = Field(default=900)
remote_runtime_resource_factor: int = Field(default=1)
enable_gpu: bool = Field(default=False)
docker_runtime_kwargs: str | None = Field(default=None)
use_host_network: bool = False
runtime_extra_build_args: list[str] | None = None
initialize_plugins: bool = True
force_rebuild_runtime: bool = False
runtime_extra_deps: str | None = None
runtime_startup_env_vars: dict[str, str] = field(default_factory=dict)
browsergym_eval_env: str | None = None
platform: str | None = None
close_delay: int = 900
remote_runtime_resource_factor: int = 1
enable_gpu: bool = False
docker_runtime_kwargs: str | None = None
model_config = {'extra': 'forbid'}
def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
dict = {}
for f in fields(self):
dict[f.name] = get_field_info(f)
return dict
def __str__(self):
attr_str = []
for f in fields(self):
attr_name = f.name
attr_value = getattr(self, f.name)
attr_str.append(f'{attr_name}={repr(attr_value)}')
return f"SandboxConfig({', '.join(attr_str)})"
def __repr__(self):
return self.__str__()

View File

@ -1,7 +1,10 @@
from pydantic import BaseModel, Field
from dataclasses import dataclass, fields
from openhands.core.config.config_utils import get_field_info
class SecurityConfig(BaseModel):
@dataclass
class SecurityConfig:
"""Configuration for security related functionalities.
Attributes:
@ -9,5 +12,29 @@ class SecurityConfig(BaseModel):
security_analyzer: The security analyzer to use.
"""
confirmation_mode: bool = Field(default=False)
security_analyzer: str | None = Field(default=None)
confirmation_mode: bool = False
security_analyzer: str | None = None
def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
dict = {}
for f in fields(self):
dict[f.name] = get_field_info(f)
return dict
def __str__(self):
attr_str = []
for f in fields(self):
attr_name = f.name
attr_value = getattr(self, f.name)
attr_str.append(f'{attr_name}={repr(attr_value)}')
return f"SecurityConfig({', '.join(attr_str)})"
@classmethod
def from_dict(cls, security_config_dict: dict) -> 'SecurityConfig':
return cls(**security_config_dict)
def __repr__(self):
return self.__str__()

View File

@ -3,13 +3,13 @@ import os
import pathlib
import platform
import sys
from dataclasses import is_dataclass
from types import UnionType
from typing import Any, MutableMapping, get_args, get_origin
from uuid import uuid4
import toml
from dotenv import load_dotenv
from pydantic import BaseModel, ValidationError
from openhands.core import logger
from openhands.core.config.agent_config import AgentConfig
@ -43,19 +43,17 @@ def load_from_env(cfg: AppConfig, env_or_toml_dict: dict | MutableMapping[str, s
return next((t for t in types if t is not type(None)), None)
# helper function to set attributes based on env vars
def set_attr_from_env(sub_config: BaseModel, prefix=''):
"""Set attributes of a config model based on environment variables."""
for field_name, field_info in sub_config.model_fields.items():
field_value = getattr(sub_config, field_name)
field_type = field_info.annotation
def set_attr_from_env(sub_config: Any, prefix=''):
"""Set attributes of a config dataclass based on environment variables."""
for field_name, field_type in sub_config.__annotations__.items():
# compute the expected env var name from the prefix and field name
# e.g. LLM_BASE_URL
env_var_name = (prefix + field_name).upper()
if isinstance(field_value, BaseModel):
set_attr_from_env(field_value, prefix=field_name + '_')
if is_dataclass(field_type):
# nested dataclass
nested_sub_config = getattr(sub_config, field_name)
set_attr_from_env(nested_sub_config, prefix=field_name + '_')
elif env_var_name in env_or_toml_dict:
# convert the env var to the correct type and set it
value = env_or_toml_dict[env_var_name]
@ -128,60 +126,45 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
if isinstance(value, dict):
try:
if key is not None and key.lower() == 'agent':
# Every entry here is either a field for the default `agent` config group, or itself a group
# The best way to tell the difference is to try to parse it as an AgentConfig object
agent_group_ids: set[str] = set()
for nested_key, nested_value in value.items():
if isinstance(nested_value, dict):
try:
agent_config = AgentConfig(**nested_value)
except ValidationError:
continue
agent_group_ids.add(nested_key)
cfg.set_agent_config(agent_config, nested_key)
logger.openhands_logger.debug(
'Attempt to load default agent config from config toml'
)
value_without_groups = {
k: v for k, v in value.items() if k not in agent_group_ids
non_dict_fields = {
k: v for k, v in value.items() if not isinstance(v, dict)
}
agent_config = AgentConfig(**value_without_groups)
agent_config = AgentConfig(**non_dict_fields)
cfg.set_agent_config(agent_config, 'agent')
elif key is not None and key.lower() == 'llm':
# Every entry here is either a field for the default `llm` config group, or itself a group
# The best way to tell the difference is to try to parse it as an LLMConfig object
llm_group_ids: set[str] = set()
for nested_key, nested_value in value.items():
if isinstance(nested_value, dict):
try:
llm_config = LLMConfig(**nested_value)
except ValidationError:
continue
llm_group_ids.add(nested_key)
cfg.set_llm_config(llm_config, nested_key)
logger.openhands_logger.debug(
f'Attempt to load group {nested_key} from config toml as agent config'
)
agent_config = AgentConfig(**nested_value)
cfg.set_agent_config(agent_config, nested_key)
elif key is not None and key.lower() == 'llm':
logger.openhands_logger.debug(
'Attempt to load default LLM config from config toml'
)
value_without_groups = {
k: v for k, v in value.items() if k not in llm_group_ids
}
llm_config = LLMConfig(**value_without_groups)
llm_config = LLMConfig.from_dict(value)
cfg.set_llm_config(llm_config, 'llm')
for nested_key, nested_value in value.items():
if isinstance(nested_value, dict):
logger.openhands_logger.debug(
f'Attempt to load group {nested_key} from config toml as llm config'
)
llm_config = LLMConfig.from_dict(nested_value)
cfg.set_llm_config(llm_config, nested_key)
elif key is not None and key.lower() == 'security':
logger.openhands_logger.debug(
'Attempt to load security config from config toml'
)
security_config = SecurityConfig(**value)
security_config = SecurityConfig.from_dict(value)
cfg.security = security_config
elif not key.startswith('sandbox') and key.lower() != 'core':
logger.openhands_logger.warning(
f'Unknown key in {toml_file}: "{key}"'
)
except (TypeError, KeyError, ValidationError) as e:
except (TypeError, KeyError) as e:
logger.openhands_logger.warning(
f'Cannot parse [{key}] config from toml, values have not been applied.\nError: {e}',
exc_info=False,
@ -218,7 +201,7 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
logger.openhands_logger.warning(
f'Unknown config key "{key}" in [core] section'
)
except (TypeError, KeyError, ValidationError) as e:
except (TypeError, KeyError) as e:
logger.openhands_logger.warning(
f'Cannot parse [sandbox] config from toml, values have not been applied.\nError: {e}',
exc_info=False,
@ -322,7 +305,7 @@ def get_llm_config_arg(
# update the llm config with the specified section
if 'llm' in toml_config and llm_config_arg in toml_config['llm']:
return LLMConfig(**toml_config['llm'][llm_config_arg])
return LLMConfig.from_dict(toml_config['llm'][llm_config_arg])
logger.openhands_logger.debug(f'Loading from toml failed for {llm_config_arg}')
return None

View File

@ -19,9 +19,7 @@ class AsyncLLM(LLM):
self._async_completion = partial(
self._call_acompletion,
model=self.config.model,
api_key=self.config.api_key.get_secret_value()
if self.config.api_key
else None,
api_key=self.config.api_key,
base_url=self.config.base_url,
api_version=self.config.api_version,
custom_llm_provider=self.config.custom_llm_provider,

View File

@ -132,9 +132,7 @@ class LLM(RetryMixin, DebugMixin):
self._completion = partial(
litellm_completion,
model=self.config.model,
api_key=self.config.api_key.get_secret_value()
if self.config.api_key
else None,
api_key=self.config.api_key,
base_url=self.config.base_url,
api_version=self.config.api_version,
custom_llm_provider=self.config.custom_llm_provider,
@ -320,9 +318,7 @@ class LLM(RetryMixin, DebugMixin):
# GET {base_url}/v1/model/info with litellm_model_id as path param
response = requests.get(
f'{self.config.base_url}/v1/model/info',
headers={
'Authorization': f'Bearer {self.config.api_key.get_secret_value() if self.config.api_key else None}'
},
headers={'Authorization': f'Bearer {self.config.api_key}'},
)
resp_json = response.json()
if 'data' not in resp_json:

View File

@ -16,9 +16,7 @@ class StreamingLLM(AsyncLLM):
self._async_streaming_completion = partial(
self._call_acompletion,
model=self.config.model,
api_key=self.config.api_key.get_secret_value()
if self.config.api_key
else None,
api_key=self.config.api_key,
base_url=self.config.base_url,
api_version=self.config.api_version,
custom_llm_provider=self.config.custom_llm_provider,

View File

@ -59,8 +59,7 @@ class ModalRuntime(ActionExecutionClient):
self.sandbox = None
self.modal_client = modal.Client.from_credentials(
config.modal_api_token_id.get_secret_value(),
config.modal_api_token_secret.get_secret_value(),
config.modal_api_token_id, config.modal_api_token_secret
)
self.app = modal.App.lookup(
'openhands', create_if_missing=True, client=self.modal_client

View File

@ -40,7 +40,7 @@ class RunloopRuntime(ActionExecutionClient):
self.devbox: DevboxView | None = None
self.config = config
self.runloop_api_client = Runloop(
bearer_token=config.runloop_api_key.get_secret_value(),
bearer_token=config.runloop_api_key,
)
self.container_name = CONTAINER_NAME_PREFIX + sid
super().__init__(

View File

@ -51,8 +51,8 @@ async def get_litellm_models() -> list[str]:
):
bedrock_model_list = bedrock.list_foundation_models(
llm_config.aws_region_name,
llm_config.aws_access_key_id.get_secret_value(),
llm_config.aws_secret_access_key.get_secret_value(),
llm_config.aws_access_key_id,
llm_config.aws_secret_access_key,
)
model_list = litellm_model_list_without_bedrock + bedrock_model_list
for llm_config in config.llms.values():

View File

@ -90,9 +90,7 @@ class EmbeddingsLoader:
return OpenAIEmbedding(
model='text-embedding-ada-002',
api_key=llm_config.api_key.get_secret_value()
if llm_config.api_key
else None,
api_key=llm_config.api_key,
)
elif strategy == 'azureopenai':
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding

View File

@ -109,6 +109,9 @@ async def test_async_completion_with_user_cancellation(cancel_delay):
print(f'Cancel requested: {is_set}')
return is_set
config = load_app_config()
config.on_cancel_requested_fn = mock_on_cancel_requested
async def mock_acompletion(*args, **kwargs):
print('Starting mock_acompletion')
for i in range(20): # Increased iterations for longer running task
@ -150,6 +153,13 @@ async def test_async_completion_with_user_cancellation(cancel_delay):
async def test_async_streaming_completion_with_user_cancellation(cancel_after_chunks):
cancel_requested = False
async def mock_on_cancel_requested():
nonlocal cancel_requested
return cancel_requested
config = load_app_config()
config.on_cancel_requested_fn = mock_on_cancel_requested
test_messages = [
'This is ',
'a test ',

View File

@ -60,6 +60,7 @@ def mock_state() -> State:
def test_cmd_output_observation_message(agent: CodeActAgent):
agent.config.function_calling = False
obs = CmdOutputObservation(
command='echo hello',
content='Command output',
@ -81,6 +82,7 @@ def test_cmd_output_observation_message(agent: CodeActAgent):
def test_ipython_run_cell_observation_message(agent: CodeActAgent):
agent.config.function_calling = False
obs = IPythonRunCellObservation(
code='plt.plot()',
content='IPython output\n![image]()',
@ -103,6 +105,7 @@ def test_ipython_run_cell_observation_message(agent: CodeActAgent):
def test_agent_delegate_observation_message(agent: CodeActAgent):
agent.config.function_calling = False
obs = AgentDelegateObservation(
content='Content', outputs={'content': 'Delegated agent output'}
)
@ -119,6 +122,7 @@ def test_agent_delegate_observation_message(agent: CodeActAgent):
def test_error_observation_message(agent: CodeActAgent):
agent.config.function_calling = False
obs = ErrorObservation('Error message')
results = agent.get_observation_message(obs, tool_call_id_to_message={})
@ -141,6 +145,7 @@ def test_unknown_observation_message(agent: CodeActAgent):
def test_file_edit_observation_message(agent: CodeActAgent):
agent.config.function_calling = False
obs = FileEditObservation(
path='/test/file.txt',
prev_exist=True,
@ -162,6 +167,7 @@ def test_file_edit_observation_message(agent: CodeActAgent):
def test_file_read_observation_message(agent: CodeActAgent):
agent.config.function_calling = False
obs = FileReadObservation(
path='/test/file.txt',
content='File content',
@ -180,6 +186,7 @@ def test_file_read_observation_message(agent: CodeActAgent):
def test_browser_output_observation_message(agent: CodeActAgent):
agent.config.function_calling = False
obs = BrowserOutputObservation(
url='http://example.com',
trigger_by_action='browse',
@ -200,6 +207,7 @@ def test_browser_output_observation_message(agent: CodeActAgent):
def test_user_reject_observation_message(agent: CodeActAgent):
agent.config.function_calling = False
obs = UserRejectObservation('Action rejected')
results = agent.get_observation_message(obs, tool_call_id_to_message={})
@ -215,6 +223,7 @@ def test_user_reject_observation_message(agent: CodeActAgent):
def test_function_calling_observation_message(agent: CodeActAgent):
agent.config.function_calling = True
mock_response = {
'id': 'mock_id',
'total_calls_in_response': 1,

View File

@ -226,7 +226,7 @@ def test_llm_condenser_from_config():
assert isinstance(condenser, LLMSummarizingCondenser)
assert condenser.llm.config.model == 'gpt-4o'
assert condenser.llm.config.api_key.get_secret_value() == 'test_key'
assert condenser.llm.config.api_key == 'test_key'
def test_llm_condenser(mock_llm, mock_state):
@ -381,7 +381,7 @@ def test_llm_attention_condenser_from_config():
assert isinstance(condenser, LLMAttentionCondenser)
assert condenser.llm.config.model == 'gpt-4o'
assert condenser.llm.config.api_key.get_secret_value() == 'test_key'
assert condenser.llm.config.api_key == 'test_key'
assert condenser.max_size == 50
assert condenser.keep_first == 10

View File

@ -63,7 +63,7 @@ def test_compat_env_to_config(monkeypatch, setup_env):
assert config.workspace_base == '/repos/openhands/workspace'
assert isinstance(config.get_llm_config(), LLMConfig)
assert config.get_llm_config().api_key.get_secret_value() == 'sk-proj-rgMV0...'
assert config.get_llm_config().api_key == 'sk-proj-rgMV0...'
assert config.get_llm_config().model == 'gpt-4o'
assert isinstance(config.get_agent_config(), AgentConfig)
assert isinstance(config.get_agent_config().memory_max_threads, int)
@ -83,7 +83,7 @@ def test_load_from_old_style_env(monkeypatch, default_config):
load_from_env(default_config, os.environ)
assert default_config.get_llm_config().api_key.get_secret_value() == 'test-api-key'
assert default_config.get_llm_config().api_key == 'test-api-key'
assert default_config.get_agent_config().memory_enabled is True
assert default_config.default_agent == 'BrowsingAgent'
assert default_config.workspace_base == '/opt/files/workspace'
@ -126,7 +126,7 @@ default_agent = "TestAgent"
# default llm & agent configs
assert default_config.default_agent == 'TestAgent'
assert default_config.get_llm_config().model == 'test-model'
assert default_config.get_llm_config().api_key.get_secret_value() == 'toml-api-key'
assert default_config.get_llm_config().api_key == 'toml-api-key'
assert default_config.get_agent_config().memory_enabled is True
# undefined agent config inherits default ones
@ -291,7 +291,7 @@ sandbox_user_id = 1001
assert default_config.get_llm_config().model == 'test-model'
assert default_config.get_llm_config('llm').model == 'test-model'
assert default_config.get_llm_config_from_agent().model == 'test-model'
assert default_config.get_llm_config().api_key.get_secret_value() == 'env-api-key'
assert default_config.get_llm_config().api_key == 'env-api-key'
# after we set workspace_base to 'UNDEFINED' in the environment,
# workspace_base should be set to that
@ -336,7 +336,7 @@ user_id = 1001
assert default_config.workspace_mount_path is None
# before load_from_env, values are set to the values from the toml file
assert default_config.get_llm_config().api_key.get_secret_value() == 'toml-api-key'
assert default_config.get_llm_config().api_key == 'toml-api-key'
assert default_config.sandbox.timeout == 500
assert default_config.sandbox.user_id == 1001
@ -345,7 +345,7 @@ user_id = 1001
# values from env override values from toml
assert os.environ.get('LLM_MODEL') is None
assert default_config.get_llm_config().model == 'test-model'
assert default_config.get_llm_config().api_key.get_secret_value() == 'env-api-key'
assert default_config.get_llm_config().api_key == 'env-api-key'
assert default_config.sandbox.timeout == 1000
assert default_config.sandbox.user_id == 1002
@ -412,7 +412,7 @@ def test_security_config_from_dict():
# Test with all fields
config_dict = {'confirmation_mode': True, 'security_analyzer': 'some_analyzer'}
security_config = SecurityConfig(**config_dict)
security_config = SecurityConfig.from_dict(config_dict)
# Verify all fields are correctly set
assert security_config.confirmation_mode is True
@ -560,7 +560,10 @@ invalid_field_in_sandbox = "test"
assert 'Cannot parse [llm] config from toml' in log_content
assert 'values have not been applied' in log_content
# Error: LLMConfig.__init__() got an unexpected keyword argume
assert 'Error: 1 validation error for LLMConfig' in log_content
assert (
'Error: LLMConfig.__init__() got an unexpected keyword argume'
in log_content
)
assert 'invalid_field' in log_content
# invalid [sandbox] config
@ -632,14 +635,12 @@ def test_api_keys_repr_str():
aws_access_key_id='my_access_key',
aws_secret_access_key='my_secret_key',
)
# Check that no secret keys are emitted in representations of the config object
assert 'my_api_key' not in repr(llm_config)
assert 'my_api_key' not in str(llm_config)
assert 'my_access_key' not in repr(llm_config)
assert 'my_access_key' not in str(llm_config)
assert 'my_secret_key' not in repr(llm_config)
assert 'my_secret_key' not in str(llm_config)
assert "api_key='******'" in repr(llm_config)
assert "aws_access_key_id='******'" in repr(llm_config)
assert "aws_secret_access_key='******'" in repr(llm_config)
assert "api_key='******'" in str(llm_config)
assert "aws_access_key_id='******'" in str(llm_config)
assert "aws_secret_access_key='******'" in str(llm_config)
# Check that no other attrs in LLMConfig have 'key' or 'token' in their name
# This will fail when new attrs are added, and attract attention
@ -651,7 +652,7 @@ def test_api_keys_repr_str():
'output_cost_per_token',
'custom_tokenizer',
]
for attr_name in LLMConfig.model_fields.keys():
for attr_name in dir(LLMConfig):
if (
not attr_name.startswith('__')
and attr_name not in known_key_token_attrs_llm
@ -666,7 +667,7 @@ def test_api_keys_repr_str():
# Test AgentConfig
# No attrs in AgentConfig have 'key' or 'token' in their name
agent_config = AgentConfig(memory_enabled=True, memory_max_threads=4)
for attr_name in AgentConfig.model_fields.keys():
for attr_name in dir(AgentConfig):
if not attr_name.startswith('__'):
assert (
'key' not in attr_name.lower()
@ -685,16 +686,16 @@ def test_api_keys_repr_str():
modal_api_token_secret='my_modal_api_token_secret',
runloop_api_key='my_runloop_api_key',
)
assert 'my_e2b_api_key' not in repr(app_config)
assert 'my_e2b_api_key' not in str(app_config)
assert 'my_jwt_secret' not in repr(app_config)
assert 'my_jwt_secret' not in str(app_config)
assert 'my_modal_api_token_id' not in repr(app_config)
assert 'my_modal_api_token_id' not in str(app_config)
assert 'my_modal_api_token_secret' not in repr(app_config)
assert 'my_modal_api_token_secret' not in str(app_config)
assert 'my_runloop_api_key' not in repr(app_config)
assert 'my_runloop_api_key' not in str(app_config)
assert "e2b_api_key='******'" in repr(app_config)
assert "e2b_api_key='******'" in str(app_config)
assert "jwt_secret='******'" in repr(app_config)
assert "jwt_secret='******'" in str(app_config)
assert "modal_api_token_id='******'" in repr(app_config)
assert "modal_api_token_id='******'" in str(app_config)
assert "modal_api_token_secret='******'" in repr(app_config)
assert "modal_api_token_secret='******'" in str(app_config)
assert "runloop_api_key='******'" in repr(app_config)
assert "runloop_api_key='******'" in str(app_config)
# Check that no other attrs in AppConfig have 'key' or 'token' in their name
# This will fail when new attrs are added, and attract attention
@ -704,7 +705,7 @@ def test_api_keys_repr_str():
'modal_api_token_secret',
'runloop_api_key',
]
for attr_name in AppConfig.model_fields.keys():
for attr_name in dir(AppConfig):
if (
not attr_name.startswith('__')
and attr_name not in known_key_token_attrs_app

View File

@ -40,7 +40,7 @@ def default_config():
def test_llm_init_with_default_config(default_config):
llm = LLM(default_config)
assert llm.config.model == 'gpt-4o'
assert llm.config.api_key.get_secret_value() == 'test_key'
assert llm.config.api_key == 'test_key'
assert isinstance(llm.metrics, Metrics)
assert llm.metrics.model_name == 'gpt-4o'
@ -77,7 +77,7 @@ def test_llm_init_with_custom_config():
)
llm = LLM(custom_config)
assert llm.config.model == 'custom-model'
assert llm.config.api_key.get_secret_value() == 'custom_key'
assert llm.config.api_key == 'custom_key'
assert llm.config.max_input_tokens == 5000
assert llm.config.max_output_tokens == 1500
assert llm.config.temperature == 0.8