Add extensive typing to controller directory (#7731)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Ray Myers <ray.myers@gmail.com>
Co-authored-by: Xingyao Wang <xingyao@all-hands.dev>
This commit is contained in:
Graham Neubig 2025-04-23 11:33:17 -04:00 committed by GitHub
parent fa559ace86
commit dc91cb263b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 95 additions and 58 deletions

View File

@ -45,7 +45,7 @@ describe("Empty state", () => {
it("should render suggestions if empty", () => {
const { store } = renderWithProviders(<ChatInterface />, {
preloadedState: {
chat: {
chat: {
messages: [],
systemMessage: {
content: "",
@ -76,7 +76,7 @@ describe("Empty state", () => {
it("should render the default suggestions", () => {
renderWithProviders(<ChatInterface />, {
preloadedState: {
chat: {
chat: {
messages: [],
systemMessage: {
content: "",
@ -114,7 +114,7 @@ describe("Empty state", () => {
const user = userEvent.setup();
const { store } = renderWithProviders(<ChatInterface />, {
preloadedState: {
chat: {
chat: {
messages: [],
systemMessage: {
content: "",
@ -151,7 +151,7 @@ describe("Empty state", () => {
const user = userEvent.setup();
const { rerender } = renderWithProviders(<ChatInterface />, {
preloadedState: {
chat: {
chat: {
messages: [],
systemMessage: {
content: "",

View File

@ -108,9 +108,7 @@ class CodeActAgent(Agent):
tools = []
if self.config.enable_cmd:
tools.append(
create_cmd_run_tool(use_short_description=use_short_tool_desc)
)
tools.append(create_cmd_run_tool(use_short_description=use_short_tool_desc))
if self.config.enable_think:
tools.append(ThinkTool)
if self.config.enable_finish:

View File

@ -32,9 +32,7 @@ def create_cmd_run_tool(
use_short_description: bool = False,
) -> ChatCompletionToolParam:
description = (
_SHORT_BASH_DESCRIPTION
if use_short_description
else _DETAILED_BASH_DESCRIPTION
_SHORT_BASH_DESCRIPTION if use_short_description else _DETAILED_BASH_DESCRIPTION
)
return ChatCompletionToolParam(
type='function',

View File

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any
from openhands.events.action import Action
@ -9,7 +10,7 @@ class ActionParseError(Exception):
def __init__(self, error: str):
self.error = error
def __str__(self):
def __str__(self) -> str:
return self.error
@ -20,16 +21,16 @@ class ResponseParser(ABC):
def __init__(
self,
):
) -> None:
# Need pay attention to the item order in self.action_parsers
self.action_parsers = []
self.action_parsers: list[ActionParser] = []
@abstractmethod
def parse(self, response: str) -> Action:
def parse(self, response: Any) -> Action:
"""Parses the action from the response from the LLM.
Parameters:
- response (str): The response from the LLM.
- response: The response from the LLM, which can be a string or a dictionary.
Returns:
- action (Action): The action parsed from the response.
@ -37,11 +38,11 @@ class ResponseParser(ABC):
pass
@abstractmethod
def parse_response(self, response) -> str:
def parse_response(self, response: Any) -> str:
"""Parses the action from the response from the LLM.
Parameters:
- response (str): The response from the LLM.
- response: The response from the LLM, which can be a string or a dictionary.
Returns:
- action_str (str): The action str parsed from the response.

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Type
@ -106,11 +108,11 @@ class Agent(ABC):
self.llm.reset()
@property
def name(self):
def name(self) -> str:
return self.__class__.__name__
@classmethod
def register(cls, name: str, agent_cls: Type['Agent']):
def register(cls, name: str, agent_cls: Type['Agent']) -> None:
"""Registers an agent class in the registry.
Parameters:

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import asyncio
import copy
import os
@ -190,7 +192,7 @@ class AgentController:
self.event_stream.add_event(system_message, EventSource.AGENT)
logger.debug(f'System message added to event stream: {system_message}')
async def close(self, set_stop_state=True) -> None:
async def close(self, set_stop_state: bool = True) -> None:
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream.
Note that it's fairly important that this closes properly, otherwise the state is incomplete.
@ -242,18 +244,18 @@ class AgentController:
extra_merged = {'session_id': self.id, **extra}
getattr(logger, level)(message, extra=extra_merged, stacklevel=2)
def update_state_before_step(self):
def update_state_before_step(self) -> None:
self.state.iteration += 1
self.state.local_iteration += 1
async def update_state_after_step(self):
async def update_state_after_step(self) -> None:
# update metrics especially for cost. Use deepcopy to avoid it being modified by agent._reset()
self.state.local_metrics = copy.deepcopy(self.agent.llm.metrics)
async def _react_to_exception(
self,
e: Exception,
):
) -> None:
"""React to an exception by setting the agent state to error and sending a status message."""
# Store the error reason before setting the agent state
self.state.last_error = f'{type(e).__name__}: {str(e)}'
@ -293,7 +295,10 @@ class AgentController:
# Set the agent state to ERROR after storing the reason
await self.set_agent_state_to(AgentState.ERROR)
async def _step_with_exception_handling(self):
def step(self) -> None:
asyncio.create_task(self._step_with_exception_handling())
async def _step_with_exception_handling(self) -> None:
try:
await self._step()
except Exception as e:
@ -1277,7 +1282,7 @@ class AgentController:
extra={'msg_type': 'METRICS'},
)
def __repr__(self):
def __repr__(self) -> str:
pending_action_info = '<none>'
if (
hasattr(self, '_pending_action_info')
@ -1300,7 +1305,7 @@ class AgentController:
f'_pending_action={pending_action_info})'
)
def _is_awaiting_observation(self):
def _is_awaiting_observation(self) -> bool:
events = self.event_stream.get_events(reverse=True)
for event in events:
if isinstance(event, AgentStateChangedObservation):

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.action import Action
from openhands.events.action.message import MessageAction
@ -79,7 +81,7 @@ class ReplayManager:
return event
@staticmethod
def get_replay_events(trajectory) -> list[Event]:
def get_replay_events(trajectory: list[dict]) -> list[Event]:
if not isinstance(trajectory, list):
raise ValueError(
f'Expected a list in {trajectory}, got {type(trajectory).__name__}'

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import base64
import os
import pickle
@ -104,7 +106,9 @@ class State:
extra_data: dict[str, Any] = field(default_factory=dict)
last_error: str = ''
def save_to_session(self, sid: str, file_store: FileStore, user_id: str | None):
def save_to_session(
self, sid: str, file_store: FileStore, user_id: str | None
) -> None:
pickled = pickle.dumps(self)
logger.debug(f'Saving state to session {sid}:{self.agent_state}')
encoded = base64.b64encode(pickled).decode('utf-8')
@ -165,7 +169,7 @@ class State:
state.agent_state = AgentState.LOADING
return state
def __getstate__(self):
def __getstate__(self) -> dict:
# don't pickle history, it will be restored from the event stream
state = self.__dict__.copy()
state['history'] = []
@ -177,7 +181,7 @@ class State:
return state
def __setstate__(self, state):
def __setstate__(self, state: dict) -> None:
self.__dict__.update(state)
# make sure we always have the attribute history

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from openhands.core.exceptions import (
LLMMalformedActionError,
TaskInvalidStateError,
@ -21,7 +23,7 @@ STATES = [
class Task:
id: str
goal: str
parent: 'Task | None'
parent: 'Task' | None
subtasks: list['Task']
def __init__(
@ -29,8 +31,8 @@ class Task:
parent: 'Task',
goal: str,
state: str = OPEN_STATE,
subtasks=None, # noqa: B006
):
subtasks: list[dict | 'Task'] | None = None, # noqa: B006
) -> None:
"""Initializes a new instance of the Task class.
Args:
@ -53,15 +55,15 @@ class Task:
if isinstance(subtask, Task):
self.subtasks.append(subtask)
else:
goal = subtask.get('goal')
state = subtask.get('state')
goal = str(subtask.get('goal', ''))
state = str(subtask.get('state', OPEN_STATE))
subtasks = subtask.get('subtasks')
logger.debug(f'Reading: {goal}, {state}, {subtasks}')
self.subtasks.append(Task(self, goal, state, subtasks))
self.state = OPEN_STATE
def to_string(self, indent=''):
def to_string(self, indent: str = '') -> str:
"""Returns a string representation of the task and its subtasks.
Args:
@ -86,7 +88,7 @@ class Task:
result += subtask.to_string(indent + ' ')
return result
def to_dict(self):
def to_dict(self) -> dict:
"""Returns a dictionary representation of the task.
Returns:
@ -99,10 +101,11 @@ class Task:
'subtasks': [t.to_dict() for t in self.subtasks],
}
def set_state(self, state):
def set_state(self, state: str) -> None:
"""Sets the state of the task and its subtasks.
Args: state: The new state of the task.
Args:
state: The new state of the task.
Raises:
TaskInvalidStateError: If the provided state is invalid.
@ -123,7 +126,7 @@ class Task:
if self.parent is not None:
self.parent.set_state(state)
def get_current_task(self) -> 'Task | None':
def get_current_task(self) -> 'Task' | None:
"""Retrieves the current task in progress.
Returns:
@ -155,11 +158,11 @@ class RootTask(Task):
goal: str = ''
parent: None = None
def __init__(self):
def __init__(self) -> None:
self.subtasks = []
self.state = OPEN_STATE
def __str__(self):
def __str__(self) -> str:
"""Returns a string representation of the root_task.
Returns:
@ -194,7 +197,12 @@ class RootTask(Task):
task = task.subtasks[part]
return task
def add_subtask(self, parent_id: str, goal: str, subtasks: list | None = None):
def add_subtask(
self,
parent_id: str,
goal: str,
subtasks: list[dict | Task] | None = None,
) -> None:
"""Adds a subtask to a parent task.
Args:
@ -207,7 +215,7 @@ class RootTask(Task):
child = Task(parent=parent, goal=goal, subtasks=subtasks)
parent.subtasks.append(child)
def set_subtask_state(self, id: str, state: str):
def set_subtask_state(self, id: str, state: str) -> None:
"""Sets the state of a subtask.
Args:

View File

@ -25,7 +25,7 @@ class StuckDetector:
def __init__(self, state: State):
self.state = state
def is_stuck(self, headless_mode: bool = True):
def is_stuck(self, headless_mode: bool = True) -> bool:
"""Checks if the agent is stuck in a loop.
Args:
@ -109,7 +109,9 @@ class StuckDetector:
return False
def _is_stuck_repeating_action_observation(self, last_actions, last_observations):
def _is_stuck_repeating_action_observation(
self, last_actions: list[Event], last_observations: list[Event]
) -> bool:
# scenario 1: same action, same observation
# it takes 4 actions and 4 observations to detect a loop
# assert len(last_actions) == 4 and len(last_observations) == 4
@ -130,7 +132,9 @@ class StuckDetector:
return False
def _is_stuck_repeating_action_error(self, last_actions, last_observations):
def _is_stuck_repeating_action_error(
self, last_actions: list[Event], last_observations: list[Event]
) -> bool:
# scenario 2: same action, errors
# it takes 3 actions and 3 observations to detect a loop
# check if the last three actions are the same and result in errors
@ -155,7 +159,12 @@ class StuckDetector:
'SyntaxError: unterminated string literal (detected at line'
):
if self._check_for_consistent_line_error(
last_observations[:3], error_message
[
obs
for obs in last_observations[:3]
if isinstance(obs, IPythonRunCellObservation)
],
error_message,
):
logger.warning(warning)
return True
@ -163,13 +172,20 @@ class StuckDetector:
'SyntaxError: invalid syntax. Perhaps you forgot a comma?',
'SyntaxError: incomplete input',
) and self._check_for_consistent_invalid_syntax(
last_observations[:3], error_message
[
obs
for obs in last_observations[:3]
if isinstance(obs, IPythonRunCellObservation)
],
error_message,
):
logger.warning(warning)
return True
return False
def _check_for_consistent_invalid_syntax(self, observations, error_message):
def _check_for_consistent_invalid_syntax(
self, observations: list[IPythonRunCellObservation], error_message: str
) -> bool:
first_lines = []
valid_observations = []
@ -210,7 +226,9 @@ class StuckDetector:
== 1
)
def _check_for_consistent_line_error(self, observations, error_message):
def _check_for_consistent_line_error(
self, observations: list[IPythonRunCellObservation], error_message: str
) -> bool:
error_lines = []
for obs in observations:
@ -237,7 +255,7 @@ class StuckDetector:
# and the 3rd-to-last line is identical across all occurrences
return len(error_lines) == 3 and len(set(error_lines)) == 1
def _is_stuck_monologue(self, filtered_history):
def _is_stuck_monologue(self, filtered_history: list[Event]) -> bool:
# scenario 3: monologue
# check for repeated MessageActions with source=AGENT
# see if the agent is engaged in a good old monologue, telling itself the same thing over and over
@ -271,7 +289,9 @@ class StuckDetector:
return True
return False
def _is_stuck_action_observation_pattern(self, filtered_history):
def _is_stuck_action_observation_pattern(
self, filtered_history: list[Event]
) -> bool:
# scenario 4: action, observation pattern on the last six steps
# check if the agent repeats the same (Action, Observation)
# every other step in the last six steps
@ -313,7 +333,7 @@ class StuckDetector:
return True
return False
def _is_stuck_context_window_error(self, filtered_history):
def _is_stuck_context_window_error(self, filtered_history: list[Event]) -> bool:
"""Detects if we're stuck in a loop of context window errors.
This happens when we repeatedly get context window errors and try to trim,
@ -361,7 +381,7 @@ class StuckDetector:
return False
def _eq_no_pid(self, obj1, obj2):
def _eq_no_pid(self, obj1: Event, obj2: Event) -> bool:
if isinstance(obj1, IPythonRunCellAction) and isinstance(
obj2, IPythonRunCellAction
):

View File

@ -46,7 +46,7 @@ class GitHubService(BaseGitService, GitService):
@property
def provider(self) -> str:
return ProviderType.GITHUB.value
async def _get_github_headers(self) -> dict:
"""Retrieve the GH Token from settings store to construct the headers."""
if not self.token:

View File

@ -22,7 +22,6 @@ class GitLabService(BaseGitService, GitService):
GRAPHQL_URL = 'https://gitlab.com/api/graphql'
token: SecretStr = SecretStr('')
refresh = False
def __init__(
self,
@ -46,7 +45,7 @@ class GitLabService(BaseGitService, GitService):
@property
def provider(self) -> str:
return ProviderType.GITLAB.value
async def _get_gitlab_headers(self) -> dict[str, Any]:
"""
Retrieve the GitLab Token to construct the headers