Fix mypy errors in core/config directory (#7113)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Graham Neubig
2025-03-25 05:57:00 -07:00
committed by GitHub
parent 8b473397d1
commit 0efe4feb2a
5 changed files with 32 additions and 20 deletions

View File

@@ -30,7 +30,9 @@ class AgentConfig(BaseModel):
disabled_microagents: list[str] = Field(default_factory=list)
enable_history_truncation: bool = Field(default=True)
enable_som_visual_browsing: bool = Field(default=False)
condenser: CondenserConfig = Field(default_factory=NoOpCondenserConfig)
condenser: CondenserConfig = Field(
default_factory=lambda: NoOpCondenserConfig(type='noop')
)
model_config = {'extra': 'forbid'}

View File

@@ -134,4 +134,5 @@ class AppConfig(BaseModel):
def model_post_init(self, __context):
"""Post-initialization hook, called when the instance is created with only default values."""
super().model_post_init(__context)
AppConfig.defaults_dict = model_defaults_to_dict(self)
if not AppConfig.defaults_dict: # Only set defaults_dict if it's empty
AppConfig.defaults_dict = model_defaults_to_dict(self)

View File

@@ -200,7 +200,7 @@ def condenser_config_from_toml_section(
f'Invalid condenser configuration: {e}. Using NoOpCondenserConfig.'
)
# Default to NoOpCondenserConfig if config fails
config = NoOpCondenserConfig()
config = NoOpCondenserConfig(type='noop')
condenser_mapping['condenser'] = config
return condenser_mapping

View File

@@ -1,7 +1,9 @@
from typing import Any
from pydantic import RootModel
class ExtendedConfig(RootModel[dict]):
class ExtendedConfig(RootModel[dict[str, Any]]):
"""Configuration for extended functionalities.
This is implemented as a root model so that the entire input is stored
@@ -9,31 +11,30 @@ class ExtendedConfig(RootModel[dict]):
accessed via attribute or dictionary-style access.
"""
@property
def root(self) -> dict: # type annotation to help mypy
return super().root
def __str__(self) -> str:
# Use the root dict to build a string representation.
attr_str = [f'{k}={repr(v)}' for k, v in self.root.items()]
return f"ExtendedConfig({', '.join(attr_str)})"
root_dict: dict[str, Any] = self.model_dump()
attr_str = [f'{k}={repr(v)}' for k, v in root_dict.items()]
return f'ExtendedConfig({", ".join(attr_str)})'
def __repr__(self) -> str:
return self.__str__()
@classmethod
def from_dict(cls, data: dict) -> 'ExtendedConfig':
def from_dict(cls, data: dict[str, Any]) -> 'ExtendedConfig':
# Create an instance directly by wrapping the input dict.
return cls(data)
def __getitem__(self, key: str) -> object:
def __getitem__(self, key: str) -> Any:
# Provide dictionary-like access via the root dict.
return self.root[key]
root_dict: dict[str, Any] = self.model_dump()
return root_dict[key]
def __getattr__(self, key: str) -> object:
def __getattr__(self, key: str) -> Any:
# Fallback for attribute access using the root dict.
try:
return self.root[key]
root_dict: dict[str, Any] = self.model_dump()
return root_dict[key]
except KeyError as e:
raise AttributeError(
f"'ExtendedConfig' object has no attribute '{key}'"

View File

@@ -5,7 +5,7 @@ import platform
import sys
from ast import literal_eval
from types import UnionType
from typing import Any, MutableMapping, get_args, get_origin
from typing import MutableMapping, get_args, get_origin
from uuid import uuid4
import toml
@@ -46,10 +46,16 @@ def load_from_env(
env_or_toml_dict: The environment variables or a config.toml dict.
"""
def get_optional_type(union_type: UnionType) -> Any:
def get_optional_type(union_type: UnionType | type | None) -> type | None:
"""Returns the non-None type from a Union."""
types = get_args(union_type)
return next((t for t in types if t is not type(None)), None)
if union_type is None:
return None
if get_origin(union_type) is UnionType:
types = get_args(union_type)
return next((t for t in types if t is not type(None)), None)
if isinstance(union_type, type):
return union_type
return None
# helper function to set attributes based on env vars
def set_attr_from_env(sub_config: BaseModel, prefix='') -> None:
@@ -85,7 +91,8 @@ def load_from_env(
elif get_origin(field_type) is dict:
cast_value = literal_eval(value)
else:
cast_value = field_type(value)
if field_type is not None:
cast_value = field_type(value)
setattr(sub_config, field_name, cast_value)
except (ValueError, TypeError):
logger.openhands_logger.error(
@@ -225,6 +232,7 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml') -> None:
# Create default LLM summarizing condenser config
default_condenser = LLMSummarizingCondenserConfig(
llm_config=cfg.get_llm_config(), # Use default LLM config
type='llm',
)
# Set as default condenser