diff --git a/frontend/__tests__/components/chat/chat-interface.test.tsx b/frontend/__tests__/components/chat/chat-interface.test.tsx index 69c67406b6..5b48db05eb 100644 --- a/frontend/__tests__/components/chat/chat-interface.test.tsx +++ b/frontend/__tests__/components/chat/chat-interface.test.tsx @@ -45,7 +45,7 @@ describe("Empty state", () => { it("should render suggestions if empty", () => { const { store } = renderWithProviders(, { preloadedState: { - chat: { + chat: { messages: [], systemMessage: { content: "", @@ -76,7 +76,7 @@ describe("Empty state", () => { it("should render the default suggestions", () => { renderWithProviders(, { preloadedState: { - chat: { + chat: { messages: [], systemMessage: { content: "", @@ -114,7 +114,7 @@ describe("Empty state", () => { const user = userEvent.setup(); const { store } = renderWithProviders(, { preloadedState: { - chat: { + chat: { messages: [], systemMessage: { content: "", @@ -151,7 +151,7 @@ describe("Empty state", () => { const user = userEvent.setup(); const { rerender } = renderWithProviders(, { preloadedState: { - chat: { + chat: { messages: [], systemMessage: { content: "", diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py index 61ab0ac1c6..e246e77467 100644 --- a/openhands/agenthub/codeact_agent/codeact_agent.py +++ b/openhands/agenthub/codeact_agent/codeact_agent.py @@ -108,9 +108,7 @@ class CodeActAgent(Agent): tools = [] if self.config.enable_cmd: - tools.append( - create_cmd_run_tool(use_short_description=use_short_tool_desc) - ) + tools.append(create_cmd_run_tool(use_short_description=use_short_tool_desc)) if self.config.enable_think: tools.append(ThinkTool) if self.config.enable_finish: diff --git a/openhands/agenthub/codeact_agent/tools/bash.py b/openhands/agenthub/codeact_agent/tools/bash.py index af557778d0..9adbeaa716 100644 --- a/openhands/agenthub/codeact_agent/tools/bash.py +++ b/openhands/agenthub/codeact_agent/tools/bash.py @@ -32,9 +32,7 @@ def create_cmd_run_tool( use_short_description: bool = False, ) -> ChatCompletionToolParam: description = ( - _SHORT_BASH_DESCRIPTION - if use_short_description - else _DETAILED_BASH_DESCRIPTION + _SHORT_BASH_DESCRIPTION if use_short_description else _DETAILED_BASH_DESCRIPTION ) return ChatCompletionToolParam( type='function', diff --git a/openhands/controller/action_parser.py b/openhands/controller/action_parser.py index fdd4f864b4..4250f46f38 100644 --- a/openhands/controller/action_parser.py +++ b/openhands/controller/action_parser.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import Any from openhands.events.action import Action @@ -9,7 +10,7 @@ class ActionParseError(Exception): def __init__(self, error: str): self.error = error - def __str__(self): + def __str__(self) -> str: return self.error @@ -20,16 +21,16 @@ class ResponseParser(ABC): def __init__( self, - ): + ) -> None: # Need pay attention to the item order in self.action_parsers - self.action_parsers = [] + self.action_parsers: list[ActionParser] = [] @abstractmethod - def parse(self, response: str) -> Action: + def parse(self, response: Any) -> Action: """Parses the action from the response from the LLM. Parameters: - - response (str): The response from the LLM. + - response: The response from the LLM, which can be a string or a dictionary. Returns: - action (Action): The action parsed from the response. @@ -37,11 +38,11 @@ class ResponseParser(ABC): pass @abstractmethod - def parse_response(self, response) -> str: + def parse_response(self, response: Any) -> str: """Parses the action from the response from the LLM. Parameters: - - response (str): The response from the LLM. + - response: The response from the LLM, which can be a string or a dictionary. Returns: - action_str (str): The action str parsed from the response. diff --git a/openhands/controller/agent.py b/openhands/controller/agent.py index 034cacf481..ca5908adfc 100644 --- a/openhands/controller/agent.py +++ b/openhands/controller/agent.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Type @@ -106,11 +108,11 @@ class Agent(ABC): self.llm.reset() @property - def name(self): + def name(self) -> str: return self.__class__.__name__ @classmethod - def register(cls, name: str, agent_cls: Type['Agent']): + def register(cls, name: str, agent_cls: Type['Agent']) -> None: """Registers an agent class in the registry. Parameters: diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index bc88f97170..795aad1909 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import copy import os @@ -190,7 +192,7 @@ class AgentController: self.event_stream.add_event(system_message, EventSource.AGENT) logger.debug(f'System message added to event stream: {system_message}') - async def close(self, set_stop_state=True) -> None: + async def close(self, set_stop_state: bool = True) -> None: """Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream. Note that it's fairly important that this closes properly, otherwise the state is incomplete. @@ -242,18 +244,18 @@ class AgentController: extra_merged = {'session_id': self.id, **extra} getattr(logger, level)(message, extra=extra_merged, stacklevel=2) - def update_state_before_step(self): + def update_state_before_step(self) -> None: self.state.iteration += 1 self.state.local_iteration += 1 - async def update_state_after_step(self): + async def update_state_after_step(self) -> None: # update metrics especially for cost. Use deepcopy to avoid it being modified by agent._reset() self.state.local_metrics = copy.deepcopy(self.agent.llm.metrics) async def _react_to_exception( self, e: Exception, - ): + ) -> None: """React to an exception by setting the agent state to error and sending a status message.""" # Store the error reason before setting the agent state self.state.last_error = f'{type(e).__name__}: {str(e)}' @@ -293,7 +295,10 @@ class AgentController: # Set the agent state to ERROR after storing the reason await self.set_agent_state_to(AgentState.ERROR) - async def _step_with_exception_handling(self): + def step(self) -> None: + asyncio.create_task(self._step_with_exception_handling()) + + async def _step_with_exception_handling(self) -> None: try: await self._step() except Exception as e: @@ -1277,7 +1282,7 @@ class AgentController: extra={'msg_type': 'METRICS'}, ) - def __repr__(self): + def __repr__(self) -> str: pending_action_info = '' if ( hasattr(self, '_pending_action_info') @@ -1300,7 +1305,7 @@ class AgentController: f'_pending_action={pending_action_info})' ) - def _is_awaiting_observation(self): + def _is_awaiting_observation(self) -> bool: events = self.event_stream.get_events(reverse=True) for event in events: if isinstance(event, AgentStateChangedObservation): diff --git a/openhands/controller/replay.py b/openhands/controller/replay.py index 81b17ee542..0eb8009435 100644 --- a/openhands/controller/replay.py +++ b/openhands/controller/replay.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from openhands.core.logger import openhands_logger as logger from openhands.events.action.action import Action from openhands.events.action.message import MessageAction @@ -79,7 +81,7 @@ class ReplayManager: return event @staticmethod - def get_replay_events(trajectory) -> list[Event]: + def get_replay_events(trajectory: list[dict]) -> list[Event]: if not isinstance(trajectory, list): raise ValueError( f'Expected a list in {trajectory}, got {type(trajectory).__name__}' diff --git a/openhands/controller/state/state.py b/openhands/controller/state/state.py index 71a4ce1a4a..847b513b5e 100644 --- a/openhands/controller/state/state.py +++ b/openhands/controller/state/state.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import os import pickle @@ -104,7 +106,9 @@ class State: extra_data: dict[str, Any] = field(default_factory=dict) last_error: str = '' - def save_to_session(self, sid: str, file_store: FileStore, user_id: str | None): + def save_to_session( + self, sid: str, file_store: FileStore, user_id: str | None + ) -> None: pickled = pickle.dumps(self) logger.debug(f'Saving state to session {sid}:{self.agent_state}') encoded = base64.b64encode(pickled).decode('utf-8') @@ -165,7 +169,7 @@ class State: state.agent_state = AgentState.LOADING return state - def __getstate__(self): + def __getstate__(self) -> dict: # don't pickle history, it will be restored from the event stream state = self.__dict__.copy() state['history'] = [] @@ -177,7 +181,7 @@ class State: return state - def __setstate__(self, state): + def __setstate__(self, state: dict) -> None: self.__dict__.update(state) # make sure we always have the attribute history diff --git a/openhands/controller/state/task.py b/openhands/controller/state/task.py index 456ae0f0a2..e866c15a4b 100644 --- a/openhands/controller/state/task.py +++ b/openhands/controller/state/task.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from openhands.core.exceptions import ( LLMMalformedActionError, TaskInvalidStateError, @@ -21,7 +23,7 @@ STATES = [ class Task: id: str goal: str - parent: 'Task | None' + parent: 'Task' | None subtasks: list['Task'] def __init__( @@ -29,8 +31,8 @@ class Task: parent: 'Task', goal: str, state: str = OPEN_STATE, - subtasks=None, # noqa: B006 - ): + subtasks: list[dict | 'Task'] | None = None, # noqa: B006 + ) -> None: """Initializes a new instance of the Task class. Args: @@ -53,15 +55,15 @@ class Task: if isinstance(subtask, Task): self.subtasks.append(subtask) else: - goal = subtask.get('goal') - state = subtask.get('state') + goal = str(subtask.get('goal', '')) + state = str(subtask.get('state', OPEN_STATE)) subtasks = subtask.get('subtasks') logger.debug(f'Reading: {goal}, {state}, {subtasks}') self.subtasks.append(Task(self, goal, state, subtasks)) self.state = OPEN_STATE - def to_string(self, indent=''): + def to_string(self, indent: str = '') -> str: """Returns a string representation of the task and its subtasks. Args: @@ -86,7 +88,7 @@ class Task: result += subtask.to_string(indent + ' ') return result - def to_dict(self): + def to_dict(self) -> dict: """Returns a dictionary representation of the task. Returns: @@ -99,10 +101,11 @@ class Task: 'subtasks': [t.to_dict() for t in self.subtasks], } - def set_state(self, state): + def set_state(self, state: str) -> None: """Sets the state of the task and its subtasks. - Args: state: The new state of the task. + Args: + state: The new state of the task. Raises: TaskInvalidStateError: If the provided state is invalid. @@ -123,7 +126,7 @@ class Task: if self.parent is not None: self.parent.set_state(state) - def get_current_task(self) -> 'Task | None': + def get_current_task(self) -> 'Task' | None: """Retrieves the current task in progress. Returns: @@ -155,11 +158,11 @@ class RootTask(Task): goal: str = '' parent: None = None - def __init__(self): + def __init__(self) -> None: self.subtasks = [] self.state = OPEN_STATE - def __str__(self): + def __str__(self) -> str: """Returns a string representation of the root_task. Returns: @@ -194,7 +197,12 @@ class RootTask(Task): task = task.subtasks[part] return task - def add_subtask(self, parent_id: str, goal: str, subtasks: list | None = None): + def add_subtask( + self, + parent_id: str, + goal: str, + subtasks: list[dict | Task] | None = None, + ) -> None: """Adds a subtask to a parent task. Args: @@ -207,7 +215,7 @@ class RootTask(Task): child = Task(parent=parent, goal=goal, subtasks=subtasks) parent.subtasks.append(child) - def set_subtask_state(self, id: str, state: str): + def set_subtask_state(self, id: str, state: str) -> None: """Sets the state of a subtask. Args: diff --git a/openhands/controller/stuck.py b/openhands/controller/stuck.py index 0fc85d0f97..e0772a4150 100644 --- a/openhands/controller/stuck.py +++ b/openhands/controller/stuck.py @@ -25,7 +25,7 @@ class StuckDetector: def __init__(self, state: State): self.state = state - def is_stuck(self, headless_mode: bool = True): + def is_stuck(self, headless_mode: bool = True) -> bool: """Checks if the agent is stuck in a loop. Args: @@ -109,7 +109,9 @@ class StuckDetector: return False - def _is_stuck_repeating_action_observation(self, last_actions, last_observations): + def _is_stuck_repeating_action_observation( + self, last_actions: list[Event], last_observations: list[Event] + ) -> bool: # scenario 1: same action, same observation # it takes 4 actions and 4 observations to detect a loop # assert len(last_actions) == 4 and len(last_observations) == 4 @@ -130,7 +132,9 @@ class StuckDetector: return False - def _is_stuck_repeating_action_error(self, last_actions, last_observations): + def _is_stuck_repeating_action_error( + self, last_actions: list[Event], last_observations: list[Event] + ) -> bool: # scenario 2: same action, errors # it takes 3 actions and 3 observations to detect a loop # check if the last three actions are the same and result in errors @@ -155,7 +159,12 @@ class StuckDetector: 'SyntaxError: unterminated string literal (detected at line' ): if self._check_for_consistent_line_error( - last_observations[:3], error_message + [ + obs + for obs in last_observations[:3] + if isinstance(obs, IPythonRunCellObservation) + ], + error_message, ): logger.warning(warning) return True @@ -163,13 +172,20 @@ class StuckDetector: 'SyntaxError: invalid syntax. Perhaps you forgot a comma?', 'SyntaxError: incomplete input', ) and self._check_for_consistent_invalid_syntax( - last_observations[:3], error_message + [ + obs + for obs in last_observations[:3] + if isinstance(obs, IPythonRunCellObservation) + ], + error_message, ): logger.warning(warning) return True return False - def _check_for_consistent_invalid_syntax(self, observations, error_message): + def _check_for_consistent_invalid_syntax( + self, observations: list[IPythonRunCellObservation], error_message: str + ) -> bool: first_lines = [] valid_observations = [] @@ -210,7 +226,9 @@ class StuckDetector: == 1 ) - def _check_for_consistent_line_error(self, observations, error_message): + def _check_for_consistent_line_error( + self, observations: list[IPythonRunCellObservation], error_message: str + ) -> bool: error_lines = [] for obs in observations: @@ -237,7 +255,7 @@ class StuckDetector: # and the 3rd-to-last line is identical across all occurrences return len(error_lines) == 3 and len(set(error_lines)) == 1 - def _is_stuck_monologue(self, filtered_history): + def _is_stuck_monologue(self, filtered_history: list[Event]) -> bool: # scenario 3: monologue # check for repeated MessageActions with source=AGENT # see if the agent is engaged in a good old monologue, telling itself the same thing over and over @@ -271,7 +289,9 @@ class StuckDetector: return True return False - def _is_stuck_action_observation_pattern(self, filtered_history): + def _is_stuck_action_observation_pattern( + self, filtered_history: list[Event] + ) -> bool: # scenario 4: action, observation pattern on the last six steps # check if the agent repeats the same (Action, Observation) # every other step in the last six steps @@ -313,7 +333,7 @@ class StuckDetector: return True return False - def _is_stuck_context_window_error(self, filtered_history): + def _is_stuck_context_window_error(self, filtered_history: list[Event]) -> bool: """Detects if we're stuck in a loop of context window errors. This happens when we repeatedly get context window errors and try to trim, @@ -361,7 +381,7 @@ class StuckDetector: return False - def _eq_no_pid(self, obj1, obj2): + def _eq_no_pid(self, obj1: Event, obj2: Event) -> bool: if isinstance(obj1, IPythonRunCellAction) and isinstance( obj2, IPythonRunCellAction ): diff --git a/openhands/integrations/github/github_service.py b/openhands/integrations/github/github_service.py index 7813c781d7..0550d9a797 100644 --- a/openhands/integrations/github/github_service.py +++ b/openhands/integrations/github/github_service.py @@ -46,7 +46,7 @@ class GitHubService(BaseGitService, GitService): @property def provider(self) -> str: return ProviderType.GITHUB.value - + async def _get_github_headers(self) -> dict: """Retrieve the GH Token from settings store to construct the headers.""" if not self.token: diff --git a/openhands/integrations/gitlab/gitlab_service.py b/openhands/integrations/gitlab/gitlab_service.py index c21c068f60..7efc58e639 100644 --- a/openhands/integrations/gitlab/gitlab_service.py +++ b/openhands/integrations/gitlab/gitlab_service.py @@ -22,7 +22,6 @@ class GitLabService(BaseGitService, GitService): GRAPHQL_URL = 'https://gitlab.com/api/graphql' token: SecretStr = SecretStr('') refresh = False - def __init__( self, @@ -46,7 +45,7 @@ class GitLabService(BaseGitService, GitService): @property def provider(self) -> str: return ProviderType.GITLAB.value - + async def _get_gitlab_headers(self) -> dict[str, Any]: """ Retrieve the GitLab Token to construct the headers