mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Add more extensive typing to openhands/core directory (#7728)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
parent
d48e2a4cf1
commit
e698a393b2
@ -43,7 +43,7 @@ from openhands.io import read_task
|
||||
prompt_session = PromptSession()
|
||||
|
||||
|
||||
def display_message(message: str):
|
||||
def display_message(message: str) -> None:
|
||||
print_formatted_text(
|
||||
FormattedText(
|
||||
[
|
||||
@ -55,7 +55,7 @@ def display_message(message: str):
|
||||
)
|
||||
|
||||
|
||||
def display_command(command: str):
|
||||
def display_command(command: str) -> None:
|
||||
print_formatted_text(
|
||||
FormattedText(
|
||||
[
|
||||
@ -67,7 +67,7 @@ def display_command(command: str):
|
||||
)
|
||||
|
||||
|
||||
def display_confirmation(confirmation_state: ActionConfirmationStatus):
|
||||
def display_confirmation(confirmation_state: ActionConfirmationStatus) -> None:
|
||||
if confirmation_state == ActionConfirmationStatus.CONFIRMED:
|
||||
print_formatted_text(
|
||||
FormattedText(
|
||||
@ -100,7 +100,7 @@ def display_confirmation(confirmation_state: ActionConfirmationStatus):
|
||||
)
|
||||
|
||||
|
||||
def display_command_output(output: str):
|
||||
def display_command_output(output: str) -> None:
|
||||
lines = output.split('\n')
|
||||
for line in lines:
|
||||
if line.startswith('[Python Interpreter') or line.startswith('openhands@'):
|
||||
@ -110,7 +110,7 @@ def display_command_output(output: str):
|
||||
print_formatted_text('')
|
||||
|
||||
|
||||
def display_file_edit(event: FileEditAction | FileEditObservation):
|
||||
def display_file_edit(event: FileEditAction | FileEditObservation) -> None:
|
||||
print_formatted_text(
|
||||
FormattedText(
|
||||
[
|
||||
@ -121,7 +121,7 @@ def display_file_edit(event: FileEditAction | FileEditObservation):
|
||||
)
|
||||
|
||||
|
||||
def display_event(event: Event, config: AppConfig):
|
||||
def display_event(event: Event, config: AppConfig) -> None:
|
||||
if isinstance(event, Action):
|
||||
if hasattr(event, 'thought'):
|
||||
display_message(event.thought)
|
||||
@ -175,7 +175,7 @@ async def read_confirmation_input():
|
||||
return False
|
||||
|
||||
|
||||
async def main(loop: asyncio.AbstractEventLoop):
|
||||
async def main(loop: asyncio.AbstractEventLoop) -> None:
|
||||
"""Runs the agent in CLI mode."""
|
||||
|
||||
args = parse_arguments()
|
||||
@ -207,7 +207,7 @@ async def main(loop: asyncio.AbstractEventLoop):
|
||||
|
||||
event_stream = runtime.event_stream
|
||||
|
||||
async def prompt_for_next_task():
|
||||
async def prompt_for_next_task() -> None:
|
||||
next_message = await read_prompt_input(config.cli_multiline_input)
|
||||
if not next_message.strip():
|
||||
await prompt_for_next_task()
|
||||
@ -219,7 +219,7 @@ async def main(loop: asyncio.AbstractEventLoop):
|
||||
action = MessageAction(content=next_message)
|
||||
event_stream.add_event(action, EventSource.USER)
|
||||
|
||||
async def on_event_async(event: Event):
|
||||
async def on_event_async(event: Event) -> None:
|
||||
display_event(event, config)
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
if event.agent_state in [
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import ClassVar
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
@ -50,7 +50,7 @@ class AppConfig(BaseModel):
|
||||
"""
|
||||
|
||||
llms: dict[str, LLMConfig] = Field(default_factory=dict)
|
||||
agents: dict = Field(default_factory=dict)
|
||||
agents: dict[str, AgentConfig] = Field(default_factory=dict)
|
||||
default_agent: str = Field(default=OH_DEFAULT_AGENT)
|
||||
sandbox: SandboxConfig = Field(default_factory=SandboxConfig)
|
||||
security: SecurityConfig = Field(default_factory=SecurityConfig)
|
||||
@ -93,7 +93,7 @@ class AppConfig(BaseModel):
|
||||
|
||||
model_config = {'extra': 'forbid'}
|
||||
|
||||
def get_llm_config(self, name='llm') -> LLMConfig:
|
||||
def get_llm_config(self, name: str = 'llm') -> LLMConfig:
|
||||
"""'llm' is the name for default config (for backward compatibility prior to 0.8)."""
|
||||
if name in self.llms:
|
||||
return self.llms[name]
|
||||
@ -105,10 +105,10 @@ class AppConfig(BaseModel):
|
||||
self.llms['llm'] = LLMConfig()
|
||||
return self.llms['llm']
|
||||
|
||||
def set_llm_config(self, value: LLMConfig, name='llm') -> None:
|
||||
def set_llm_config(self, value: LLMConfig, name: str = 'llm') -> None:
|
||||
self.llms[name] = value
|
||||
|
||||
def get_agent_config(self, name='agent') -> AgentConfig:
|
||||
def get_agent_config(self, name: str = 'agent') -> AgentConfig:
|
||||
"""'agent' is the name for default config (for backward compatibility prior to 0.8)."""
|
||||
if name in self.agents:
|
||||
return self.agents[name]
|
||||
@ -116,22 +116,24 @@ class AppConfig(BaseModel):
|
||||
self.agents['agent'] = AgentConfig()
|
||||
return self.agents['agent']
|
||||
|
||||
def set_agent_config(self, value: AgentConfig, name='agent') -> None:
|
||||
def set_agent_config(self, value: AgentConfig, name: str = 'agent') -> None:
|
||||
self.agents[name] = value
|
||||
|
||||
def get_agent_to_llm_config_map(self) -> dict[str, LLMConfig]:
|
||||
"""Get a map of agent names to llm configs."""
|
||||
return {name: self.get_llm_config_from_agent(name) for name in self.agents}
|
||||
|
||||
def get_llm_config_from_agent(self, name='agent') -> LLMConfig:
|
||||
def get_llm_config_from_agent(self, name: str = 'agent') -> LLMConfig:
|
||||
agent_config: AgentConfig = self.get_agent_config(name)
|
||||
llm_config_name = agent_config.llm_config
|
||||
llm_config_name = (
|
||||
agent_config.llm_config if agent_config.llm_config is not None else 'llm'
|
||||
)
|
||||
return self.get_llm_config(llm_config_name)
|
||||
|
||||
def get_agent_configs(self) -> dict[str, AgentConfig]:
|
||||
return self.agents
|
||||
|
||||
def model_post_init(self, __context):
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Post-initialization hook, called when the instance is created with only default values."""
|
||||
super().model_post_init(__context)
|
||||
if not AppConfig.defaults_dict: # Only set defaults_dict if it's empty
|
||||
|
||||
@ -151,7 +151,7 @@ class LLMConfig(BaseModel):
|
||||
|
||||
return llm_mapping
|
||||
|
||||
def model_post_init(self, __context: Any):
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Post-initialization hook to assign OpenRouter-related variables to environment variables.
|
||||
|
||||
This ensures that these values are accessible to litellm at runtime.
|
||||
|
||||
@ -58,7 +58,7 @@ def load_from_env(
|
||||
return None
|
||||
|
||||
# helper function to set attributes based on env vars
|
||||
def set_attr_from_env(sub_config: BaseModel, prefix='') -> None:
|
||||
def set_attr_from_env(sub_config: BaseModel, prefix: str = '') -> None:
|
||||
"""Set attributes of a config model based on environment variables."""
|
||||
for field_name, field_info in sub_config.model_fields.items():
|
||||
field_value = getattr(sub_config, field_name)
|
||||
@ -275,7 +275,7 @@ def get_or_create_jwt_secret(file_store: FileStore) -> str:
|
||||
return new_secret
|
||||
|
||||
|
||||
def finalize_config(cfg: AppConfig):
|
||||
def finalize_config(cfg: AppConfig) -> None:
|
||||
"""More tweaks to the config after it's been loaded."""
|
||||
if cfg.workspace_base is not None:
|
||||
cfg.workspace_base = os.path.abspath(cfg.workspace_base)
|
||||
|
||||
@ -6,7 +6,7 @@ import sys
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from types import TracebackType
|
||||
from typing import Any, Literal, Mapping, TextIO
|
||||
from typing import Any, Literal, Mapping, MutableMapping, TextIO
|
||||
|
||||
import litellm
|
||||
from pythonjsonlogger.json import JsonFormatter
|
||||
@ -304,7 +304,7 @@ def get_file_handler(
|
||||
return file_handler
|
||||
|
||||
|
||||
def json_formatter():
|
||||
def json_formatter() -> JsonFormatter:
|
||||
return JsonFormatter(
|
||||
'{message}{levelname}',
|
||||
style='{',
|
||||
@ -471,11 +471,15 @@ llm_response_logger = _setup_llm_logger('response', current_log_level)
|
||||
class OpenHandsLoggerAdapter(logging.LoggerAdapter):
|
||||
extra: dict
|
||||
|
||||
def __init__(self, logger=openhands_logger, extra=None):
|
||||
def __init__(
|
||||
self, logger: logging.Logger = openhands_logger, extra: dict | None = None
|
||||
) -> None:
|
||||
self.logger = logger
|
||||
self.extra = extra or {}
|
||||
|
||||
def process(self, msg, kwargs):
|
||||
def process(
|
||||
self, msg: str, kwargs: MutableMapping[str, Any]
|
||||
) -> tuple[str, MutableMapping[str, Any]]:
|
||||
"""
|
||||
If 'extra' is supplied in kwargs, merge it with the adapters 'extra' dict
|
||||
Starting in Python 3.13, LoggerAdapter's merge_extra option will do this.
|
||||
|
||||
@ -12,14 +12,14 @@ async def run_agent_until_done(
|
||||
runtime: Runtime,
|
||||
memory: Memory,
|
||||
end_states: list[AgentState],
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
run_agent_until_done takes a controller and a runtime, and will run
|
||||
the agent until it reaches a terminal state.
|
||||
Note that runtime must be connected before being passed in here.
|
||||
"""
|
||||
|
||||
def status_callback(msg_type, msg_id, msg):
|
||||
def status_callback(msg_type: str, msg_id: str, msg: str) -> None:
|
||||
if msg_type == 'error':
|
||||
logger.error(msg)
|
||||
if controller:
|
||||
|
||||
@ -163,7 +163,7 @@ async def run_controller(
|
||||
# init with the provided actions
|
||||
event_stream.add_event(initial_user_action, EventSource.USER)
|
||||
|
||||
def on_event(event: Event):
|
||||
def on_event(event: Event) -> None:
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
if event.agent_state == AgentState.AWAITING_USER_INPUT:
|
||||
if exit_on_message:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
from litellm import ChatCompletionMessageToolCall
|
||||
from pydantic import BaseModel, Field, model_serializer
|
||||
@ -14,7 +14,7 @@ class Content(BaseModel):
|
||||
type: str
|
||||
cache_prompt: bool = False
|
||||
|
||||
@model_serializer
|
||||
@model_serializer(mode='plain')
|
||||
def serialize_model(
|
||||
self,
|
||||
) -> dict[str, str | dict[str, str]] | list[dict[str, str | dict[str, str]]]:
|
||||
@ -25,7 +25,7 @@ class TextContent(Content):
|
||||
type: str = ContentType.TEXT.value
|
||||
text: str
|
||||
|
||||
@model_serializer
|
||||
@model_serializer(mode='plain')
|
||||
def serialize_model(self) -> dict[str, str | dict[str, str]]:
|
||||
data: dict[str, str | dict[str, str]] = {
|
||||
'type': self.type,
|
||||
@ -40,7 +40,7 @@ class ImageContent(Content):
|
||||
type: str = ContentType.IMAGE_URL.value
|
||||
image_urls: list[str]
|
||||
|
||||
@model_serializer
|
||||
@model_serializer(mode='plain')
|
||||
def serialize_model(self) -> list[dict[str, str | dict[str, str]]]:
|
||||
images: list[dict[str, str | dict[str, str]]] = []
|
||||
for url in self.image_urls:
|
||||
@ -71,8 +71,8 @@ class Message(BaseModel):
|
||||
def contains_image(self) -> bool:
|
||||
return any(isinstance(content, ImageContent) for content in self.content)
|
||||
|
||||
@model_serializer
|
||||
def serialize_model(self) -> dict:
|
||||
@model_serializer(mode='plain')
|
||||
def serialize_model(self) -> dict[str, Any]:
|
||||
# We need two kinds of serializations:
|
||||
# - into a single string: for providers that don't support list of content items (e.g. no vision, no tool calls)
|
||||
# - into a list of content items: the new APIs of providers with vision/prompt caching/tool calls
|
||||
@ -84,18 +84,18 @@ class Message(BaseModel):
|
||||
# some providers, like HF and Groq/llama, don't support a list here, but a single string
|
||||
return self._string_serializer()
|
||||
|
||||
def _string_serializer(self) -> dict:
|
||||
def _string_serializer(self) -> dict[str, Any]:
|
||||
# convert content to a single string
|
||||
content = '\n'.join(
|
||||
item.text for item in self.content if isinstance(item, TextContent)
|
||||
)
|
||||
message_dict: dict = {'content': content, 'role': self.role}
|
||||
message_dict: dict[str, Any] = {'content': content, 'role': self.role}
|
||||
|
||||
# add tool call keys if we have a tool call or response
|
||||
return self._add_tool_call_keys(message_dict)
|
||||
|
||||
def _list_serializer(self) -> dict:
|
||||
content: list[dict] = []
|
||||
def _list_serializer(self) -> dict[str, Any]:
|
||||
content: list[dict[str, Any]] = []
|
||||
role_tool_with_prompt_caching = False
|
||||
for item in self.content:
|
||||
d = item.model_dump()
|
||||
@ -120,7 +120,7 @@ class Message(BaseModel):
|
||||
# We know d is a list for ImageContent
|
||||
content.extend([d] if isinstance(d, dict) else d)
|
||||
|
||||
message_dict: dict = {'content': content, 'role': self.role}
|
||||
message_dict: dict[str, Any] = {'content': content, 'role': self.role}
|
||||
|
||||
if role_tool_with_prompt_caching:
|
||||
message_dict['cache_control'] = {'type': 'ephemeral'}
|
||||
@ -128,7 +128,7 @@ class Message(BaseModel):
|
||||
# add tool call keys if we have a tool call or response
|
||||
return self._add_tool_call_keys(message_dict)
|
||||
|
||||
def _add_tool_call_keys(self, message_dict: dict) -> dict:
|
||||
def _add_tool_call_keys(self, message_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Add tool call keys if we have a tool call or response.
|
||||
|
||||
NOTE: this is necessary for both native and non-native tool calling
|
||||
|
||||
@ -167,8 +167,9 @@ class Session:
|
||||
"""
|
||||
Initialize LLM, extracted for testing.
|
||||
"""
|
||||
agent_name = agent_cls if agent_cls is not None else 'agent'
|
||||
return LLM(
|
||||
config=self.config.get_llm_config_from_agent(agent_cls),
|
||||
config=self.config.get_llm_config_from_agent(agent_name),
|
||||
retry_listener=self._notify_on_llm_retry,
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user