Add more extensive typing to openhands/core directory (#7728)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Graham Neubig 2025-04-08 17:38:44 -04:00 committed by GitHub
parent d48e2a4cf1
commit e698a393b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 48 additions and 41 deletions

View File

@ -43,7 +43,7 @@ from openhands.io import read_task
prompt_session = PromptSession()
def display_message(message: str):
def display_message(message: str) -> None:
print_formatted_text(
FormattedText(
[
@ -55,7 +55,7 @@ def display_message(message: str):
)
def display_command(command: str):
def display_command(command: str) -> None:
print_formatted_text(
FormattedText(
[
@ -67,7 +67,7 @@ def display_command(command: str):
)
def display_confirmation(confirmation_state: ActionConfirmationStatus):
def display_confirmation(confirmation_state: ActionConfirmationStatus) -> None:
if confirmation_state == ActionConfirmationStatus.CONFIRMED:
print_formatted_text(
FormattedText(
@ -100,7 +100,7 @@ def display_confirmation(confirmation_state: ActionConfirmationStatus):
)
def display_command_output(output: str):
def display_command_output(output: str) -> None:
lines = output.split('\n')
for line in lines:
if line.startswith('[Python Interpreter') or line.startswith('openhands@'):
@ -110,7 +110,7 @@ def display_command_output(output: str):
print_formatted_text('')
def display_file_edit(event: FileEditAction | FileEditObservation):
def display_file_edit(event: FileEditAction | FileEditObservation) -> None:
print_formatted_text(
FormattedText(
[
@ -121,7 +121,7 @@ def display_file_edit(event: FileEditAction | FileEditObservation):
)
def display_event(event: Event, config: AppConfig):
def display_event(event: Event, config: AppConfig) -> None:
if isinstance(event, Action):
if hasattr(event, 'thought'):
display_message(event.thought)
@ -175,7 +175,7 @@ async def read_confirmation_input():
return False
async def main(loop: asyncio.AbstractEventLoop):
async def main(loop: asyncio.AbstractEventLoop) -> None:
"""Runs the agent in CLI mode."""
args = parse_arguments()
@ -207,7 +207,7 @@ async def main(loop: asyncio.AbstractEventLoop):
event_stream = runtime.event_stream
async def prompt_for_next_task():
async def prompt_for_next_task() -> None:
next_message = await read_prompt_input(config.cli_multiline_input)
if not next_message.strip():
await prompt_for_next_task()
@ -219,7 +219,7 @@ async def main(loop: asyncio.AbstractEventLoop):
action = MessageAction(content=next_message)
event_stream.add_event(action, EventSource.USER)
async def on_event_async(event: Event):
async def on_event_async(event: Event) -> None:
display_event(event, config)
if isinstance(event, AgentStateChangedObservation):
if event.agent_state in [

View File

@ -1,4 +1,4 @@
from typing import ClassVar
from typing import Any, ClassVar
from pydantic import BaseModel, Field, SecretStr
@ -50,7 +50,7 @@ class AppConfig(BaseModel):
"""
llms: dict[str, LLMConfig] = Field(default_factory=dict)
agents: dict = Field(default_factory=dict)
agents: dict[str, AgentConfig] = Field(default_factory=dict)
default_agent: str = Field(default=OH_DEFAULT_AGENT)
sandbox: SandboxConfig = Field(default_factory=SandboxConfig)
security: SecurityConfig = Field(default_factory=SecurityConfig)
@ -93,7 +93,7 @@ class AppConfig(BaseModel):
model_config = {'extra': 'forbid'}
def get_llm_config(self, name='llm') -> LLMConfig:
def get_llm_config(self, name: str = 'llm') -> LLMConfig:
"""'llm' is the name for default config (for backward compatibility prior to 0.8)."""
if name in self.llms:
return self.llms[name]
@ -105,10 +105,10 @@ class AppConfig(BaseModel):
self.llms['llm'] = LLMConfig()
return self.llms['llm']
def set_llm_config(self, value: LLMConfig, name='llm') -> None:
def set_llm_config(self, value: LLMConfig, name: str = 'llm') -> None:
self.llms[name] = value
def get_agent_config(self, name='agent') -> AgentConfig:
def get_agent_config(self, name: str = 'agent') -> AgentConfig:
"""'agent' is the name for default config (for backward compatibility prior to 0.8)."""
if name in self.agents:
return self.agents[name]
@ -116,22 +116,24 @@ class AppConfig(BaseModel):
self.agents['agent'] = AgentConfig()
return self.agents['agent']
def set_agent_config(self, value: AgentConfig, name='agent') -> None:
def set_agent_config(self, value: AgentConfig, name: str = 'agent') -> None:
self.agents[name] = value
def get_agent_to_llm_config_map(self) -> dict[str, LLMConfig]:
"""Get a map of agent names to llm configs."""
return {name: self.get_llm_config_from_agent(name) for name in self.agents}
def get_llm_config_from_agent(self, name='agent') -> LLMConfig:
def get_llm_config_from_agent(self, name: str = 'agent') -> LLMConfig:
agent_config: AgentConfig = self.get_agent_config(name)
llm_config_name = agent_config.llm_config
llm_config_name = (
agent_config.llm_config if agent_config.llm_config is not None else 'llm'
)
return self.get_llm_config(llm_config_name)
def get_agent_configs(self) -> dict[str, AgentConfig]:
return self.agents
def model_post_init(self, __context):
def model_post_init(self, __context: Any) -> None:
"""Post-initialization hook, called when the instance is created with only default values."""
super().model_post_init(__context)
if not AppConfig.defaults_dict: # Only set defaults_dict if it's empty

View File

@ -151,7 +151,7 @@ class LLMConfig(BaseModel):
return llm_mapping
def model_post_init(self, __context: Any):
def model_post_init(self, __context: Any) -> None:
"""Post-initialization hook to assign OpenRouter-related variables to environment variables.
This ensures that these values are accessible to litellm at runtime.

View File

@ -58,7 +58,7 @@ def load_from_env(
return None
# helper function to set attributes based on env vars
def set_attr_from_env(sub_config: BaseModel, prefix='') -> None:
def set_attr_from_env(sub_config: BaseModel, prefix: str = '') -> None:
"""Set attributes of a config model based on environment variables."""
for field_name, field_info in sub_config.model_fields.items():
field_value = getattr(sub_config, field_name)
@ -275,7 +275,7 @@ def get_or_create_jwt_secret(file_store: FileStore) -> str:
return new_secret
def finalize_config(cfg: AppConfig):
def finalize_config(cfg: AppConfig) -> None:
"""More tweaks to the config after it's been loaded."""
if cfg.workspace_base is not None:
cfg.workspace_base = os.path.abspath(cfg.workspace_base)

View File

@ -6,7 +6,7 @@ import sys
import traceback
from datetime import datetime
from types import TracebackType
from typing import Any, Literal, Mapping, TextIO
from typing import Any, Literal, Mapping, MutableMapping, TextIO
import litellm
from pythonjsonlogger.json import JsonFormatter
@ -304,7 +304,7 @@ def get_file_handler(
return file_handler
def json_formatter():
def json_formatter() -> JsonFormatter:
return JsonFormatter(
'{message}{levelname}',
style='{',
@ -471,11 +471,15 @@ llm_response_logger = _setup_llm_logger('response', current_log_level)
class OpenHandsLoggerAdapter(logging.LoggerAdapter):
extra: dict
def __init__(self, logger=openhands_logger, extra=None):
def __init__(
self, logger: logging.Logger = openhands_logger, extra: dict | None = None
) -> None:
self.logger = logger
self.extra = extra or {}
def process(self, msg, kwargs):
def process(
self, msg: str, kwargs: MutableMapping[str, Any]
) -> tuple[str, MutableMapping[str, Any]]:
"""
If 'extra' is supplied in kwargs, merge it with the adapter's 'extra' dict
Starting in Python 3.13, LoggerAdapter's merge_extra option will do this.

View File

@ -12,14 +12,14 @@ async def run_agent_until_done(
runtime: Runtime,
memory: Memory,
end_states: list[AgentState],
):
) -> None:
"""
run_agent_until_done takes a controller and a runtime, and will run
the agent until it reaches a terminal state.
Note that runtime must be connected before being passed in here.
"""
def status_callback(msg_type, msg_id, msg):
def status_callback(msg_type: str, msg_id: str, msg: str) -> None:
if msg_type == 'error':
logger.error(msg)
if controller:

View File

@ -163,7 +163,7 @@ async def run_controller(
# init with the provided actions
event_stream.add_event(initial_user_action, EventSource.USER)
def on_event(event: Event):
def on_event(event: Event) -> None:
if isinstance(event, AgentStateChangedObservation):
if event.agent_state == AgentState.AWAITING_USER_INPUT:
if exit_on_message:

View File

@ -1,5 +1,5 @@
from enum import Enum
from typing import Literal
from typing import Any, Literal
from litellm import ChatCompletionMessageToolCall
from pydantic import BaseModel, Field, model_serializer
@ -14,7 +14,7 @@ class Content(BaseModel):
type: str
cache_prompt: bool = False
@model_serializer
@model_serializer(mode='plain')
def serialize_model(
self,
) -> dict[str, str | dict[str, str]] | list[dict[str, str | dict[str, str]]]:
@ -25,7 +25,7 @@ class TextContent(Content):
type: str = ContentType.TEXT.value
text: str
@model_serializer
@model_serializer(mode='plain')
def serialize_model(self) -> dict[str, str | dict[str, str]]:
data: dict[str, str | dict[str, str]] = {
'type': self.type,
@ -40,7 +40,7 @@ class ImageContent(Content):
type: str = ContentType.IMAGE_URL.value
image_urls: list[str]
@model_serializer
@model_serializer(mode='plain')
def serialize_model(self) -> list[dict[str, str | dict[str, str]]]:
images: list[dict[str, str | dict[str, str]]] = []
for url in self.image_urls:
@ -71,8 +71,8 @@ class Message(BaseModel):
def contains_image(self) -> bool:
return any(isinstance(content, ImageContent) for content in self.content)
@model_serializer
def serialize_model(self) -> dict:
@model_serializer(mode='plain')
def serialize_model(self) -> dict[str, Any]:
# We need two kinds of serializations:
# - into a single string: for providers that don't support list of content items (e.g. no vision, no tool calls)
# - into a list of content items: the new APIs of providers with vision/prompt caching/tool calls
@ -84,18 +84,18 @@ class Message(BaseModel):
# some providers, like HF and Groq/llama, don't support a list here, but a single string
return self._string_serializer()
def _string_serializer(self) -> dict:
def _string_serializer(self) -> dict[str, Any]:
# convert content to a single string
content = '\n'.join(
item.text for item in self.content if isinstance(item, TextContent)
)
message_dict: dict = {'content': content, 'role': self.role}
message_dict: dict[str, Any] = {'content': content, 'role': self.role}
# add tool call keys if we have a tool call or response
return self._add_tool_call_keys(message_dict)
def _list_serializer(self) -> dict:
content: list[dict] = []
def _list_serializer(self) -> dict[str, Any]:
content: list[dict[str, Any]] = []
role_tool_with_prompt_caching = False
for item in self.content:
d = item.model_dump()
@ -120,7 +120,7 @@ class Message(BaseModel):
# We know d is a list for ImageContent
content.extend([d] if isinstance(d, dict) else d)
message_dict: dict = {'content': content, 'role': self.role}
message_dict: dict[str, Any] = {'content': content, 'role': self.role}
if role_tool_with_prompt_caching:
message_dict['cache_control'] = {'type': 'ephemeral'}
@ -128,7 +128,7 @@ class Message(BaseModel):
# add tool call keys if we have a tool call or response
return self._add_tool_call_keys(message_dict)
def _add_tool_call_keys(self, message_dict: dict) -> dict:
def _add_tool_call_keys(self, message_dict: dict[str, Any]) -> dict[str, Any]:
"""Add tool call keys if we have a tool call or response.
NOTE: this is necessary for both native and non-native tool calling

View File

@ -167,8 +167,9 @@ class Session:
"""
Initialize LLM, extracted for testing.
"""
agent_name = agent_cls if agent_cls is not None else 'agent'
return LLM(
config=self.config.get_llm_config_from_agent(agent_cls),
config=self.config.get_llm_config_from_agent(agent_name),
retry_listener=self._notify_on_llm_retry,
)