mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Add extensive typing to controller directory (#7731)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: Ray Myers <ray.myers@gmail.com> Co-authored-by: Xingyao Wang <xingyao@all-hands.dev>
This commit is contained in:
parent
fa559ace86
commit
dc91cb263b
@ -45,7 +45,7 @@ describe("Empty state", () => {
|
||||
it("should render suggestions if empty", () => {
|
||||
const { store } = renderWithProviders(<ChatInterface />, {
|
||||
preloadedState: {
|
||||
chat: {
|
||||
chat: {
|
||||
messages: [],
|
||||
systemMessage: {
|
||||
content: "",
|
||||
@ -76,7 +76,7 @@ describe("Empty state", () => {
|
||||
it("should render the default suggestions", () => {
|
||||
renderWithProviders(<ChatInterface />, {
|
||||
preloadedState: {
|
||||
chat: {
|
||||
chat: {
|
||||
messages: [],
|
||||
systemMessage: {
|
||||
content: "",
|
||||
@ -114,7 +114,7 @@ describe("Empty state", () => {
|
||||
const user = userEvent.setup();
|
||||
const { store } = renderWithProviders(<ChatInterface />, {
|
||||
preloadedState: {
|
||||
chat: {
|
||||
chat: {
|
||||
messages: [],
|
||||
systemMessage: {
|
||||
content: "",
|
||||
@ -151,7 +151,7 @@ describe("Empty state", () => {
|
||||
const user = userEvent.setup();
|
||||
const { rerender } = renderWithProviders(<ChatInterface />, {
|
||||
preloadedState: {
|
||||
chat: {
|
||||
chat: {
|
||||
messages: [],
|
||||
systemMessage: {
|
||||
content: "",
|
||||
|
||||
@ -108,9 +108,7 @@ class CodeActAgent(Agent):
|
||||
|
||||
tools = []
|
||||
if self.config.enable_cmd:
|
||||
tools.append(
|
||||
create_cmd_run_tool(use_short_description=use_short_tool_desc)
|
||||
)
|
||||
tools.append(create_cmd_run_tool(use_short_description=use_short_tool_desc))
|
||||
if self.config.enable_think:
|
||||
tools.append(ThinkTool)
|
||||
if self.config.enable_finish:
|
||||
|
||||
@ -32,9 +32,7 @@ def create_cmd_run_tool(
|
||||
use_short_description: bool = False,
|
||||
) -> ChatCompletionToolParam:
|
||||
description = (
|
||||
_SHORT_BASH_DESCRIPTION
|
||||
if use_short_description
|
||||
else _DETAILED_BASH_DESCRIPTION
|
||||
_SHORT_BASH_DESCRIPTION if use_short_description else _DETAILED_BASH_DESCRIPTION
|
||||
)
|
||||
return ChatCompletionToolParam(
|
||||
type='function',
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from openhands.events.action import Action
|
||||
|
||||
@ -9,7 +10,7 @@ class ActionParseError(Exception):
|
||||
def __init__(self, error: str):
|
||||
self.error = error
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.error
|
||||
|
||||
|
||||
@ -20,16 +21,16 @@ class ResponseParser(ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
) -> None:
|
||||
# Need pay attention to the item order in self.action_parsers
|
||||
self.action_parsers = []
|
||||
self.action_parsers: list[ActionParser] = []
|
||||
|
||||
@abstractmethod
|
||||
def parse(self, response: str) -> Action:
|
||||
def parse(self, response: Any) -> Action:
|
||||
"""Parses the action from the response from the LLM.
|
||||
|
||||
Parameters:
|
||||
- response (str): The response from the LLM.
|
||||
- response: The response from the LLM, which can be a string or a dictionary.
|
||||
|
||||
Returns:
|
||||
- action (Action): The action parsed from the response.
|
||||
@ -37,11 +38,11 @@ class ResponseParser(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def parse_response(self, response) -> str:
|
||||
def parse_response(self, response: Any) -> str:
|
||||
"""Parses the action from the response from the LLM.
|
||||
|
||||
Parameters:
|
||||
- response (str): The response from the LLM.
|
||||
- response: The response from the LLM, which can be a string or a dictionary.
|
||||
|
||||
Returns:
|
||||
- action_str (str): The action str parsed from the response.
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Type
|
||||
|
||||
@ -106,11 +108,11 @@ class Agent(ABC):
|
||||
self.llm.reset()
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str, agent_cls: Type['Agent']):
|
||||
def register(cls, name: str, agent_cls: Type['Agent']) -> None:
|
||||
"""Registers an agent class in the registry.
|
||||
|
||||
Parameters:
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import os
|
||||
@ -190,7 +192,7 @@ class AgentController:
|
||||
self.event_stream.add_event(system_message, EventSource.AGENT)
|
||||
logger.debug(f'System message added to event stream: {system_message}')
|
||||
|
||||
async def close(self, set_stop_state=True) -> None:
|
||||
async def close(self, set_stop_state: bool = True) -> None:
|
||||
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream.
|
||||
|
||||
Note that it's fairly important that this closes properly, otherwise the state is incomplete.
|
||||
@ -242,18 +244,18 @@ class AgentController:
|
||||
extra_merged = {'session_id': self.id, **extra}
|
||||
getattr(logger, level)(message, extra=extra_merged, stacklevel=2)
|
||||
|
||||
def update_state_before_step(self):
|
||||
def update_state_before_step(self) -> None:
|
||||
self.state.iteration += 1
|
||||
self.state.local_iteration += 1
|
||||
|
||||
async def update_state_after_step(self):
|
||||
async def update_state_after_step(self) -> None:
|
||||
# update metrics especially for cost. Use deepcopy to avoid it being modified by agent._reset()
|
||||
self.state.local_metrics = copy.deepcopy(self.agent.llm.metrics)
|
||||
|
||||
async def _react_to_exception(
|
||||
self,
|
||||
e: Exception,
|
||||
):
|
||||
) -> None:
|
||||
"""React to an exception by setting the agent state to error and sending a status message."""
|
||||
# Store the error reason before setting the agent state
|
||||
self.state.last_error = f'{type(e).__name__}: {str(e)}'
|
||||
@ -293,7 +295,10 @@ class AgentController:
|
||||
# Set the agent state to ERROR after storing the reason
|
||||
await self.set_agent_state_to(AgentState.ERROR)
|
||||
|
||||
async def _step_with_exception_handling(self):
|
||||
def step(self) -> None:
|
||||
asyncio.create_task(self._step_with_exception_handling())
|
||||
|
||||
async def _step_with_exception_handling(self) -> None:
|
||||
try:
|
||||
await self._step()
|
||||
except Exception as e:
|
||||
@ -1277,7 +1282,7 @@ class AgentController:
|
||||
extra={'msg_type': 'METRICS'},
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
pending_action_info = '<none>'
|
||||
if (
|
||||
hasattr(self, '_pending_action_info')
|
||||
@ -1300,7 +1305,7 @@ class AgentController:
|
||||
f'_pending_action={pending_action_info})'
|
||||
)
|
||||
|
||||
def _is_awaiting_observation(self):
|
||||
def _is_awaiting_observation(self) -> bool:
|
||||
events = self.event_stream.get_events(reverse=True)
|
||||
for event in events:
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.action.message import MessageAction
|
||||
@ -79,7 +81,7 @@ class ReplayManager:
|
||||
return event
|
||||
|
||||
@staticmethod
|
||||
def get_replay_events(trajectory) -> list[Event]:
|
||||
def get_replay_events(trajectory: list[dict]) -> list[Event]:
|
||||
if not isinstance(trajectory, list):
|
||||
raise ValueError(
|
||||
f'Expected a list in {trajectory}, got {type(trajectory).__name__}'
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import os
|
||||
import pickle
|
||||
@ -104,7 +106,9 @@ class State:
|
||||
extra_data: dict[str, Any] = field(default_factory=dict)
|
||||
last_error: str = ''
|
||||
|
||||
def save_to_session(self, sid: str, file_store: FileStore, user_id: str | None):
|
||||
def save_to_session(
|
||||
self, sid: str, file_store: FileStore, user_id: str | None
|
||||
) -> None:
|
||||
pickled = pickle.dumps(self)
|
||||
logger.debug(f'Saving state to session {sid}:{self.agent_state}')
|
||||
encoded = base64.b64encode(pickled).decode('utf-8')
|
||||
@ -165,7 +169,7 @@ class State:
|
||||
state.agent_state = AgentState.LOADING
|
||||
return state
|
||||
|
||||
def __getstate__(self):
|
||||
def __getstate__(self) -> dict:
|
||||
# don't pickle history, it will be restored from the event stream
|
||||
state = self.__dict__.copy()
|
||||
state['history'] = []
|
||||
@ -177,7 +181,7 @@ class State:
|
||||
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
def __setstate__(self, state: dict) -> None:
|
||||
self.__dict__.update(state)
|
||||
|
||||
# make sure we always have the attribute history
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from openhands.core.exceptions import (
|
||||
LLMMalformedActionError,
|
||||
TaskInvalidStateError,
|
||||
@ -21,7 +23,7 @@ STATES = [
|
||||
class Task:
|
||||
id: str
|
||||
goal: str
|
||||
parent: 'Task | None'
|
||||
parent: 'Task' | None
|
||||
subtasks: list['Task']
|
||||
|
||||
def __init__(
|
||||
@ -29,8 +31,8 @@ class Task:
|
||||
parent: 'Task',
|
||||
goal: str,
|
||||
state: str = OPEN_STATE,
|
||||
subtasks=None, # noqa: B006
|
||||
):
|
||||
subtasks: list[dict | 'Task'] | None = None, # noqa: B006
|
||||
) -> None:
|
||||
"""Initializes a new instance of the Task class.
|
||||
|
||||
Args:
|
||||
@ -53,15 +55,15 @@ class Task:
|
||||
if isinstance(subtask, Task):
|
||||
self.subtasks.append(subtask)
|
||||
else:
|
||||
goal = subtask.get('goal')
|
||||
state = subtask.get('state')
|
||||
goal = str(subtask.get('goal', ''))
|
||||
state = str(subtask.get('state', OPEN_STATE))
|
||||
subtasks = subtask.get('subtasks')
|
||||
logger.debug(f'Reading: {goal}, {state}, {subtasks}')
|
||||
self.subtasks.append(Task(self, goal, state, subtasks))
|
||||
|
||||
self.state = OPEN_STATE
|
||||
|
||||
def to_string(self, indent=''):
|
||||
def to_string(self, indent: str = '') -> str:
|
||||
"""Returns a string representation of the task and its subtasks.
|
||||
|
||||
Args:
|
||||
@ -86,7 +88,7 @@ class Task:
|
||||
result += subtask.to_string(indent + ' ')
|
||||
return result
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> dict:
|
||||
"""Returns a dictionary representation of the task.
|
||||
|
||||
Returns:
|
||||
@ -99,10 +101,11 @@ class Task:
|
||||
'subtasks': [t.to_dict() for t in self.subtasks],
|
||||
}
|
||||
|
||||
def set_state(self, state):
|
||||
def set_state(self, state: str) -> None:
|
||||
"""Sets the state of the task and its subtasks.
|
||||
|
||||
Args: state: The new state of the task.
|
||||
Args:
|
||||
state: The new state of the task.
|
||||
|
||||
Raises:
|
||||
TaskInvalidStateError: If the provided state is invalid.
|
||||
@ -123,7 +126,7 @@ class Task:
|
||||
if self.parent is not None:
|
||||
self.parent.set_state(state)
|
||||
|
||||
def get_current_task(self) -> 'Task | None':
|
||||
def get_current_task(self) -> 'Task' | None:
|
||||
"""Retrieves the current task in progress.
|
||||
|
||||
Returns:
|
||||
@ -155,11 +158,11 @@ class RootTask(Task):
|
||||
goal: str = ''
|
||||
parent: None = None
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.subtasks = []
|
||||
self.state = OPEN_STATE
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
"""Returns a string representation of the root_task.
|
||||
|
||||
Returns:
|
||||
@ -194,7 +197,12 @@ class RootTask(Task):
|
||||
task = task.subtasks[part]
|
||||
return task
|
||||
|
||||
def add_subtask(self, parent_id: str, goal: str, subtasks: list | None = None):
|
||||
def add_subtask(
|
||||
self,
|
||||
parent_id: str,
|
||||
goal: str,
|
||||
subtasks: list[dict | Task] | None = None,
|
||||
) -> None:
|
||||
"""Adds a subtask to a parent task.
|
||||
|
||||
Args:
|
||||
@ -207,7 +215,7 @@ class RootTask(Task):
|
||||
child = Task(parent=parent, goal=goal, subtasks=subtasks)
|
||||
parent.subtasks.append(child)
|
||||
|
||||
def set_subtask_state(self, id: str, state: str):
|
||||
def set_subtask_state(self, id: str, state: str) -> None:
|
||||
"""Sets the state of a subtask.
|
||||
|
||||
Args:
|
||||
|
||||
@ -25,7 +25,7 @@ class StuckDetector:
|
||||
def __init__(self, state: State):
|
||||
self.state = state
|
||||
|
||||
def is_stuck(self, headless_mode: bool = True):
|
||||
def is_stuck(self, headless_mode: bool = True) -> bool:
|
||||
"""Checks if the agent is stuck in a loop.
|
||||
|
||||
Args:
|
||||
@ -109,7 +109,9 @@ class StuckDetector:
|
||||
|
||||
return False
|
||||
|
||||
def _is_stuck_repeating_action_observation(self, last_actions, last_observations):
|
||||
def _is_stuck_repeating_action_observation(
|
||||
self, last_actions: list[Event], last_observations: list[Event]
|
||||
) -> bool:
|
||||
# scenario 1: same action, same observation
|
||||
# it takes 4 actions and 4 observations to detect a loop
|
||||
# assert len(last_actions) == 4 and len(last_observations) == 4
|
||||
@ -130,7 +132,9 @@ class StuckDetector:
|
||||
|
||||
return False
|
||||
|
||||
def _is_stuck_repeating_action_error(self, last_actions, last_observations):
|
||||
def _is_stuck_repeating_action_error(
|
||||
self, last_actions: list[Event], last_observations: list[Event]
|
||||
) -> bool:
|
||||
# scenario 2: same action, errors
|
||||
# it takes 3 actions and 3 observations to detect a loop
|
||||
# check if the last three actions are the same and result in errors
|
||||
@ -155,7 +159,12 @@ class StuckDetector:
|
||||
'SyntaxError: unterminated string literal (detected at line'
|
||||
):
|
||||
if self._check_for_consistent_line_error(
|
||||
last_observations[:3], error_message
|
||||
[
|
||||
obs
|
||||
for obs in last_observations[:3]
|
||||
if isinstance(obs, IPythonRunCellObservation)
|
||||
],
|
||||
error_message,
|
||||
):
|
||||
logger.warning(warning)
|
||||
return True
|
||||
@ -163,13 +172,20 @@ class StuckDetector:
|
||||
'SyntaxError: invalid syntax. Perhaps you forgot a comma?',
|
||||
'SyntaxError: incomplete input',
|
||||
) and self._check_for_consistent_invalid_syntax(
|
||||
last_observations[:3], error_message
|
||||
[
|
||||
obs
|
||||
for obs in last_observations[:3]
|
||||
if isinstance(obs, IPythonRunCellObservation)
|
||||
],
|
||||
error_message,
|
||||
):
|
||||
logger.warning(warning)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _check_for_consistent_invalid_syntax(self, observations, error_message):
|
||||
def _check_for_consistent_invalid_syntax(
|
||||
self, observations: list[IPythonRunCellObservation], error_message: str
|
||||
) -> bool:
|
||||
first_lines = []
|
||||
valid_observations = []
|
||||
|
||||
@ -210,7 +226,9 @@ class StuckDetector:
|
||||
== 1
|
||||
)
|
||||
|
||||
def _check_for_consistent_line_error(self, observations, error_message):
|
||||
def _check_for_consistent_line_error(
|
||||
self, observations: list[IPythonRunCellObservation], error_message: str
|
||||
) -> bool:
|
||||
error_lines = []
|
||||
|
||||
for obs in observations:
|
||||
@ -237,7 +255,7 @@ class StuckDetector:
|
||||
# and the 3rd-to-last line is identical across all occurrences
|
||||
return len(error_lines) == 3 and len(set(error_lines)) == 1
|
||||
|
||||
def _is_stuck_monologue(self, filtered_history):
|
||||
def _is_stuck_monologue(self, filtered_history: list[Event]) -> bool:
|
||||
# scenario 3: monologue
|
||||
# check for repeated MessageActions with source=AGENT
|
||||
# see if the agent is engaged in a good old monologue, telling itself the same thing over and over
|
||||
@ -271,7 +289,9 @@ class StuckDetector:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_stuck_action_observation_pattern(self, filtered_history):
|
||||
def _is_stuck_action_observation_pattern(
|
||||
self, filtered_history: list[Event]
|
||||
) -> bool:
|
||||
# scenario 4: action, observation pattern on the last six steps
|
||||
# check if the agent repeats the same (Action, Observation)
|
||||
# every other step in the last six steps
|
||||
@ -313,7 +333,7 @@ class StuckDetector:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_stuck_context_window_error(self, filtered_history):
|
||||
def _is_stuck_context_window_error(self, filtered_history: list[Event]) -> bool:
|
||||
"""Detects if we're stuck in a loop of context window errors.
|
||||
|
||||
This happens when we repeatedly get context window errors and try to trim,
|
||||
@ -361,7 +381,7 @@ class StuckDetector:
|
||||
|
||||
return False
|
||||
|
||||
def _eq_no_pid(self, obj1, obj2):
|
||||
def _eq_no_pid(self, obj1: Event, obj2: Event) -> bool:
|
||||
if isinstance(obj1, IPythonRunCellAction) and isinstance(
|
||||
obj2, IPythonRunCellAction
|
||||
):
|
||||
|
||||
@ -46,7 +46,7 @@ class GitHubService(BaseGitService, GitService):
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
return ProviderType.GITHUB.value
|
||||
|
||||
|
||||
async def _get_github_headers(self) -> dict:
|
||||
"""Retrieve the GH Token from settings store to construct the headers."""
|
||||
if not self.token:
|
||||
|
||||
@ -22,7 +22,6 @@ class GitLabService(BaseGitService, GitService):
|
||||
GRAPHQL_URL = 'https://gitlab.com/api/graphql'
|
||||
token: SecretStr = SecretStr('')
|
||||
refresh = False
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -46,7 +45,7 @@ class GitLabService(BaseGitService, GitService):
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
return ProviderType.GITLAB.value
|
||||
|
||||
|
||||
async def _get_gitlab_headers(self) -> dict[str, Any]:
|
||||
"""
|
||||
Retrieve the GitLab Token to construct the headers
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user