From e698a393b25177b05c6bce529bc3cb2a84a25b91 Mon Sep 17 00:00:00 2001 From: Graham Neubig Date: Tue, 8 Apr 2025 17:38:44 -0400 Subject: [PATCH] Add more extensive typing to openhands/core directory (#7728) Co-authored-by: openhands --- openhands/core/cli.py | 18 +++++++++--------- openhands/core/config/app_config.py | 20 +++++++++++--------- openhands/core/config/llm_config.py | 2 +- openhands/core/config/utils.py | 4 ++-- openhands/core/logger.py | 12 ++++++++---- openhands/core/loop.py | 4 ++-- openhands/core/main.py | 2 +- openhands/core/message.py | 24 ++++++++++++------------ openhands/server/session/session.py | 3 ++- 9 files changed, 48 insertions(+), 41 deletions(-) diff --git a/openhands/core/cli.py b/openhands/core/cli.py index 4d381984b3..bf53f2de7d 100644 --- a/openhands/core/cli.py +++ b/openhands/core/cli.py @@ -43,7 +43,7 @@ from openhands.io import read_task prompt_session = PromptSession() -def display_message(message: str): +def display_message(message: str) -> None: print_formatted_text( FormattedText( [ @@ -55,7 +55,7 @@ def display_message(message: str): ) -def display_command(command: str): +def display_command(command: str) -> None: print_formatted_text( FormattedText( [ @@ -67,7 +67,7 @@ def display_command(command: str): ) -def display_confirmation(confirmation_state: ActionConfirmationStatus): +def display_confirmation(confirmation_state: ActionConfirmationStatus) -> None: if confirmation_state == ActionConfirmationStatus.CONFIRMED: print_formatted_text( FormattedText( @@ -100,7 +100,7 @@ def display_confirmation(confirmation_state: ActionConfirmationStatus): ) -def display_command_output(output: str): +def display_command_output(output: str) -> None: lines = output.split('\n') for line in lines: if line.startswith('[Python Interpreter') or line.startswith('openhands@'): @@ -110,7 +110,7 @@ def display_command_output(output: str): print_formatted_text('') -def display_file_edit(event: FileEditAction | FileEditObservation): +def display_file_edit(event: FileEditAction | FileEditObservation) -> None: print_formatted_text( FormattedText( [ @@ -121,7 +121,7 @@ def display_file_edit(event: FileEditAction | FileEditObservation): ) -def display_event(event: Event, config: AppConfig): +def display_event(event: Event, config: AppConfig) -> None: if isinstance(event, Action): if hasattr(event, 'thought'): display_message(event.thought) @@ -175,7 +175,7 @@ async def read_confirmation_input(): return False -async def main(loop: asyncio.AbstractEventLoop): +async def main(loop: asyncio.AbstractEventLoop) -> None: """Runs the agent in CLI mode.""" args = parse_arguments() @@ -207,7 +207,7 @@ async def main(loop: asyncio.AbstractEventLoop): event_stream = runtime.event_stream - async def prompt_for_next_task(): + async def prompt_for_next_task() -> None: next_message = await read_prompt_input(config.cli_multiline_input) if not next_message.strip(): await prompt_for_next_task() @@ -219,7 +219,7 @@ async def main(loop: asyncio.AbstractEventLoop): action = MessageAction(content=next_message) event_stream.add_event(action, EventSource.USER) - async def on_event_async(event: Event): + async def on_event_async(event: Event) -> None: display_event(event, config) if isinstance(event, AgentStateChangedObservation): if event.agent_state in [ diff --git a/openhands/core/config/app_config.py b/openhands/core/config/app_config.py index fd1e9752f7..48f2b870ba 100644 --- a/openhands/core/config/app_config.py +++ b/openhands/core/config/app_config.py @@ -1,4 +1,4 @@ -from typing import ClassVar +from typing import Any, ClassVar from pydantic import BaseModel, Field, SecretStr @@ -50,7 +50,7 @@ class AppConfig(BaseModel): """ llms: dict[str, LLMConfig] = Field(default_factory=dict) - agents: dict = Field(default_factory=dict) + agents: dict[str, AgentConfig] = Field(default_factory=dict) default_agent: str = Field(default=OH_DEFAULT_AGENT) sandbox: SandboxConfig = Field(default_factory=SandboxConfig) security: SecurityConfig = Field(default_factory=SecurityConfig) @@ -93,7 +93,7 @@ class AppConfig(BaseModel): model_config = {'extra': 'forbid'} - def get_llm_config(self, name='llm') -> LLMConfig: + def get_llm_config(self, name: str = 'llm') -> LLMConfig: """'llm' is the name for default config (for backward compatibility prior to 0.8).""" if name in self.llms: return self.llms[name] @@ -105,10 +105,10 @@ class AppConfig(BaseModel): self.llms['llm'] = LLMConfig() return self.llms['llm'] - def set_llm_config(self, value: LLMConfig, name='llm') -> None: + def set_llm_config(self, value: LLMConfig, name: str = 'llm') -> None: self.llms[name] = value - def get_agent_config(self, name='agent') -> AgentConfig: + def get_agent_config(self, name: str = 'agent') -> AgentConfig: """'agent' is the name for default config (for backward compatibility prior to 0.8).""" if name in self.agents: return self.agents[name] @@ -116,22 +116,24 @@ class AppConfig(BaseModel): self.agents['agent'] = AgentConfig() return self.agents['agent'] - def set_agent_config(self, value: AgentConfig, name='agent') -> None: + def set_agent_config(self, value: AgentConfig, name: str = 'agent') -> None: self.agents[name] = value def get_agent_to_llm_config_map(self) -> dict[str, LLMConfig]: """Get a map of agent names to llm configs.""" return {name: self.get_llm_config_from_agent(name) for name in self.agents} - def get_llm_config_from_agent(self, name='agent') -> LLMConfig: + def get_llm_config_from_agent(self, name: str = 'agent') -> LLMConfig: agent_config: AgentConfig = self.get_agent_config(name) - llm_config_name = agent_config.llm_config + llm_config_name = ( + agent_config.llm_config if agent_config.llm_config is not None else 'llm' + ) return self.get_llm_config(llm_config_name) def get_agent_configs(self) -> dict[str, AgentConfig]: return self.agents - def model_post_init(self, __context): + def model_post_init(self, __context: Any) -> None: """Post-initialization hook, called when the instance is created with only default values.""" super().model_post_init(__context) if not AppConfig.defaults_dict: # Only set defaults_dict if it's empty diff --git a/openhands/core/config/llm_config.py b/openhands/core/config/llm_config.py index 00d440fe4d..0ff8a48060 100644 --- a/openhands/core/config/llm_config.py +++ b/openhands/core/config/llm_config.py @@ -151,7 +151,7 @@ class LLMConfig(BaseModel): return llm_mapping - def model_post_init(self, __context: Any): + def model_post_init(self, __context: Any) -> None: """Post-initialization hook to assign OpenRouter-related variables to environment variables. This ensures that these values are accessible to litellm at runtime. diff --git a/openhands/core/config/utils.py b/openhands/core/config/utils.py index 1f8452ec1e..0a2bbbe165 100644 --- a/openhands/core/config/utils.py +++ b/openhands/core/config/utils.py @@ -58,7 +58,7 @@ def load_from_env( return None # helper function to set attributes based on env vars - def set_attr_from_env(sub_config: BaseModel, prefix='') -> None: + def set_attr_from_env(sub_config: BaseModel, prefix: str = '') -> None: """Set attributes of a config model based on environment variables.""" for field_name, field_info in sub_config.model_fields.items(): field_value = getattr(sub_config, field_name) @@ -275,7 +275,7 @@ def get_or_create_jwt_secret(file_store: FileStore) -> str: return new_secret -def finalize_config(cfg: AppConfig): +def finalize_config(cfg: AppConfig) -> None: """More tweaks to the config after it's been loaded.""" if cfg.workspace_base is not None: cfg.workspace_base = os.path.abspath(cfg.workspace_base) diff --git a/openhands/core/logger.py b/openhands/core/logger.py index 0f1619d958..30f547f058 100644 --- a/openhands/core/logger.py +++ b/openhands/core/logger.py @@ -6,7 +6,7 @@ import sys import traceback from datetime import datetime from types import TracebackType -from typing import Any, Literal, Mapping, TextIO +from typing import Any, Literal, Mapping, MutableMapping, TextIO import litellm from pythonjsonlogger.json import JsonFormatter @@ -304,7 +304,7 @@ def get_file_handler( return file_handler -def json_formatter(): +def json_formatter() -> JsonFormatter: return JsonFormatter( '{message}{levelname}', style='{', @@ -471,11 +471,15 @@ llm_response_logger = _setup_llm_logger('response', current_log_level) class OpenHandsLoggerAdapter(logging.LoggerAdapter): extra: dict - def __init__(self, logger=openhands_logger, extra=None): + def __init__( + self, logger: logging.Logger = openhands_logger, extra: dict | None = None + ) -> None: self.logger = logger self.extra = extra or {} - def process(self, msg, kwargs): + def process( + self, msg: str, kwargs: MutableMapping[str, Any] + ) -> tuple[str, MutableMapping[str, Any]]: """ If 'extra' is supplied in kwargs, merge it with the adapters 'extra' dict Starting in Python 3.13, LoggerAdapter's merge_extra option will do this. diff --git a/openhands/core/loop.py b/openhands/core/loop.py index daf95d2f01..9163373c14 100644 --- a/openhands/core/loop.py +++ b/openhands/core/loop.py @@ -12,14 +12,14 @@ async def run_agent_until_done( runtime: Runtime, memory: Memory, end_states: list[AgentState], -): +) -> None: """ run_agent_until_done takes a controller and a runtime, and will run the agent until it reaches a terminal state. Note that runtime must be connected before being passed in here. """ - def status_callback(msg_type, msg_id, msg): + def status_callback(msg_type: str, msg_id: str, msg: str) -> None: if msg_type == 'error': logger.error(msg) if controller: diff --git a/openhands/core/main.py b/openhands/core/main.py index 5df5c66be9..2c5ebdbdef 100644 --- a/openhands/core/main.py +++ b/openhands/core/main.py @@ -163,7 +163,7 @@ async def run_controller( # init with the provided actions event_stream.add_event(initial_user_action, EventSource.USER) - def on_event(event: Event): + def on_event(event: Event) -> None: if isinstance(event, AgentStateChangedObservation): if event.agent_state == AgentState.AWAITING_USER_INPUT: if exit_on_message: diff --git a/openhands/core/message.py b/openhands/core/message.py index 73dd300e28..539df52b5e 100644 --- a/openhands/core/message.py +++ b/openhands/core/message.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Literal +from typing import Any, Literal from litellm import ChatCompletionMessageToolCall from pydantic import BaseModel, Field, model_serializer @@ -14,7 +14,7 @@ class Content(BaseModel): type: str cache_prompt: bool = False - @model_serializer + @model_serializer(mode='plain') def serialize_model( self, ) -> dict[str, str | dict[str, str]] | list[dict[str, str | dict[str, str]]]: @@ -25,7 +25,7 @@ class TextContent(Content): type: str = ContentType.TEXT.value text: str - @model_serializer + @model_serializer(mode='plain') def serialize_model(self) -> dict[str, str | dict[str, str]]: data: dict[str, str | dict[str, str]] = { 'type': self.type, @@ -40,7 +40,7 @@ class ImageContent(Content): type: str = ContentType.IMAGE_URL.value image_urls: list[str] - @model_serializer + @model_serializer(mode='plain') def serialize_model(self) -> list[dict[str, str | dict[str, str]]]: images: list[dict[str, str | dict[str, str]]] = [] for url in self.image_urls: @@ -71,8 +71,8 @@ class Message(BaseModel): def contains_image(self) -> bool: return any(isinstance(content, ImageContent) for content in self.content) - @model_serializer - def serialize_model(self) -> dict: + @model_serializer(mode='plain') + def serialize_model(self) -> dict[str, Any]: # We need two kinds of serializations: # - into a single string: for providers that don't support list of content items (e.g. no vision, no tool calls) # - into a list of content items: the new APIs of providers with vision/prompt caching/tool calls @@ -84,18 +84,18 @@ class Message(BaseModel): # some providers, like HF and Groq/llama, don't support a list here, but a single string return self._string_serializer() - def _string_serializer(self) -> dict: + def _string_serializer(self) -> dict[str, Any]: # convert content to a single string content = '\n'.join( item.text for item in self.content if isinstance(item, TextContent) ) - message_dict: dict = {'content': content, 'role': self.role} + message_dict: dict[str, Any] = {'content': content, 'role': self.role} # add tool call keys if we have a tool call or response return self._add_tool_call_keys(message_dict) - def _list_serializer(self) -> dict: - content: list[dict] = [] + def _list_serializer(self) -> dict[str, Any]: + content: list[dict[str, Any]] = [] role_tool_with_prompt_caching = False for item in self.content: d = item.model_dump() @@ -120,7 +120,7 @@ class Message(BaseModel): # We know d is a list for ImageContent content.extend([d] if isinstance(d, dict) else d) - message_dict: dict = {'content': content, 'role': self.role} + message_dict: dict[str, Any] = {'content': content, 'role': self.role} if role_tool_with_prompt_caching: message_dict['cache_control'] = {'type': 'ephemeral'} @@ -128,7 +128,7 @@ class Message(BaseModel): # add tool call keys if we have a tool call or response return self._add_tool_call_keys(message_dict) - def _add_tool_call_keys(self, message_dict: dict) -> dict: + def _add_tool_call_keys(self, message_dict: dict[str, Any]) -> dict[str, Any]: """Add tool call keys if we have a tool call or response. NOTE: this is necessary for both native and non-native tool calling diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py index c2ce16eca9..0603c7d9d0 100644 --- a/openhands/server/session/session.py +++ b/openhands/server/session/session.py @@ -167,8 +167,9 @@ class Session: """ Initialize LLM, extracted for testing. """ + agent_name = agent_cls if agent_cls is not None else 'agent' return LLM( - config=self.config.get_llm_config_from_agent(agent_cls), + config=self.config.get_llm_config_from_agent(agent_name), retry_listener=self._notify_on_llm_retry, )