mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
Rename OpenDevin to OpenHands (#3472)
* Replace OpenDevin with OpenHands * Update CONTRIBUTING.md * Update README.md * Update README.md * update poetry lock; move opendevin folder to openhands * fix env var * revert image references in docs * revert permissions * revert permissions --------- Co-authored-by: Xingyao Wang <xingyao6@illinois.edu>
This commit is contained in:
52
openhands/README.md
Normal file
52
openhands/README.md
Normal file
@@ -0,0 +1,52 @@
|
||||
# OpenHands Architecture
|
||||
|
||||
This directory contains the core components of OpenHands.
|
||||
|
||||
This diagram provides an overview of the roles of each component and how they communicate and collaborate.
|
||||

|
||||
|
||||
## Classes
|
||||
The key classes in OpenHands are:
|
||||
|
||||
* LLM: brokers all interactions with large language models. Works with any underlying completion model, thanks to LiteLLM.
|
||||
* Agent: responsible for looking at the current State, and producing an Action that moves one step closer toward the end-goal.
|
||||
* AgentController: initializes the Agent, manages State, and drive the main loop that pushes the Agent forward, step by step
|
||||
* State: represents the current state of the Agent's task. Includes things like the current step, a history of recent events, the Agent's long-term plan, etc
|
||||
* EventStream: a central hub for Events, where any component can publish Events, or listen for Events published by other components
|
||||
* Event: an Action or Observeration
|
||||
* Action: represents a request to e.g. edit a file, run a command, or send a message
|
||||
* Observation: represents information collected from the environment, e.g. file contents or command output
|
||||
* Runtime: responsible for performing Actions, and sending back Observations
|
||||
* Sandbox: the part of the runtime responsible for running commands, e.g. inside of Docker
|
||||
* Server: brokers OpenHands sessions over HTTP, e.g. to drive the frontend
|
||||
* Session: holds a single EventStream, a single AgentController, and a single Runtime. Generally represents a single task (but potentially including several user prompts)
|
||||
* SessionManager: keeps a list of active sessions, and ensures requests are routed to the correct Session
|
||||
|
||||
## Control Flow
|
||||
Here's the basic loop (in pseudocode) that drives agents.
|
||||
```python
|
||||
while True:
|
||||
prompt = agent.generate_prompt(state)
|
||||
response = llm.completion(prompt)
|
||||
action = agent.parse_response(response)
|
||||
observation = runtime.run(action)
|
||||
state = state.update(action, observation)
|
||||
```
|
||||
|
||||
In reality, most of this is achieved through message passing, via the EventStream.
|
||||
The EventStream serves as the backbone for all communication in OpenHands.
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
Agent--Actions-->AgentController
|
||||
AgentController--State-->Agent
|
||||
AgentController--Actions-->EventStream
|
||||
EventStream--Observations-->AgentController
|
||||
Runtime--Observations-->EventStream
|
||||
EventStream--Actions-->Runtime
|
||||
Frontend--Actions-->EventStream
|
||||
```
|
||||
|
||||
## Runtime
|
||||
|
||||
Please refer to the [documentation](https://docs.all-hands.dev/modules/usage/runtime) to learn more about `Runtime`.
|
||||
0
openhands/__init__.py
Normal file
0
openhands/__init__.py
Normal file
5
openhands/controller/__init__.py
Normal file
5
openhands/controller/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .agent_controller import AgentController
|
||||
|
||||
__all__ = [
|
||||
'AgentController',
|
||||
]
|
||||
67
openhands/controller/action_parser.py
Normal file
67
openhands/controller/action_parser.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from openhands.events.action import Action
|
||||
|
||||
|
||||
class ResponseParser(ABC):
|
||||
"""This abstract base class is a general interface for an response parser dedicated to
|
||||
parsing the action from the response from the LLM.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
# Need pay attention to the item order in self.action_parsers
|
||||
self.action_parsers = []
|
||||
|
||||
@abstractmethod
|
||||
def parse(self, response: str) -> Action:
|
||||
"""Parses the action from the response from the LLM.
|
||||
|
||||
Parameters:
|
||||
- response (str): The response from the LLM.
|
||||
|
||||
Returns:
|
||||
- action (Action): The action parsed from the response.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def parse_response(self, response) -> str:
|
||||
"""Parses the action from the response from the LLM.
|
||||
|
||||
Parameters:
|
||||
- response (str): The response from the LLM.
|
||||
|
||||
Returns:
|
||||
- action_str (str): The action str parsed from the response.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def parse_action(self, action_str: str) -> Action:
|
||||
"""Parses the action from the response from the LLM.
|
||||
|
||||
Parameters:
|
||||
- action_str (str): The response from the LLM.
|
||||
|
||||
Returns:
|
||||
- action (Action): The action parsed from the response.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ActionParser(ABC):
|
||||
"""This abstract base class is a general interface for an action parser dedicated to
|
||||
parsing the action from the action str from the LLM.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def check_condition(self, action_str: str) -> bool:
|
||||
"""Check if the action string can be parsed by this parser."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def parse(self, action_str: str) -> Action:
|
||||
"""Parses the action from the action string from the LLM response."""
|
||||
pass
|
||||
109
openhands/controller/agent.py
Normal file
109
openhands/controller/agent.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import AgentConfig
|
||||
from openhands.events.action import Action
|
||||
from openhands.core.exceptions import (
|
||||
AgentAlreadyRegisteredError,
|
||||
AgentNotRegisteredError,
|
||||
)
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.runtime.plugins import PluginRequirement
|
||||
|
||||
|
||||
class Agent(ABC):
|
||||
DEPRECATED = False
|
||||
"""
|
||||
This abstract base class is an general interface for an agent dedicated to
|
||||
executing a specific instruction and allowing human interaction with the
|
||||
agent during execution.
|
||||
It tracks the execution status and maintains a history of interactions.
|
||||
"""
|
||||
|
||||
_registry: dict[str, Type['Agent']] = {}
|
||||
sandbox_plugins: list[PluginRequirement] = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: LLM,
|
||||
config: 'AgentConfig',
|
||||
):
|
||||
self.llm = llm
|
||||
self.config = config
|
||||
self._complete = False
|
||||
|
||||
@property
|
||||
def complete(self) -> bool:
|
||||
"""Indicates whether the current instruction execution is complete.
|
||||
|
||||
Returns:
|
||||
- complete (bool): True if execution is complete; False otherwise.
|
||||
"""
|
||||
return self._complete
|
||||
|
||||
@abstractmethod
|
||||
def step(self, state: 'State') -> 'Action':
|
||||
"""Starts the execution of the assigned instruction. This method should
|
||||
be implemented by subclasses to define the specific execution logic.
|
||||
"""
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Resets the agent's execution status and clears the history. This method can be used
|
||||
to prepare the agent for restarting the instruction or cleaning up before destruction.
|
||||
|
||||
"""
|
||||
# TODO clear history
|
||||
self._complete = False
|
||||
|
||||
if self.llm:
|
||||
self.llm.reset()
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str, agent_cls: Type['Agent']):
|
||||
"""Registers an agent class in the registry.
|
||||
|
||||
Parameters:
|
||||
- name (str): The name to register the class under.
|
||||
- agent_cls (Type['Agent']): The class to register.
|
||||
|
||||
Raises:
|
||||
- AgentAlreadyRegisteredError: If name already registered
|
||||
"""
|
||||
if name in cls._registry:
|
||||
raise AgentAlreadyRegisteredError(name)
|
||||
cls._registry[name] = agent_cls
|
||||
|
||||
@classmethod
|
||||
def get_cls(cls, name: str) -> Type['Agent']:
|
||||
"""Retrieves an agent class from the registry.
|
||||
|
||||
Parameters:
|
||||
- name (str): The name of the class to retrieve
|
||||
|
||||
Returns:
|
||||
- agent_cls (Type['Agent']): The class registered under the specified name.
|
||||
|
||||
Raises:
|
||||
- AgentNotRegisteredError: If name not registered
|
||||
"""
|
||||
if name not in cls._registry:
|
||||
raise AgentNotRegisteredError(name)
|
||||
return cls._registry[name]
|
||||
|
||||
@classmethod
|
||||
def list_agents(cls) -> list[str]:
|
||||
"""Retrieves the list of all agent names from the registry.
|
||||
|
||||
Raises:
|
||||
- AgentNotRegisteredError: If no agent is registered
|
||||
"""
|
||||
if not bool(cls._registry):
|
||||
raise AgentNotRegisteredError()
|
||||
return list(cls._registry.keys())
|
||||
545
openhands/controller/agent_controller.py
Normal file
545
openhands/controller/agent_controller.py
Normal file
@@ -0,0 +1,545 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import Type
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.state.state import State, TrafficControlState
|
||||
from openhands.controller.stuck import StuckDetector
|
||||
from openhands.core.config import AgentConfig, LLMConfig
|
||||
from openhands.core.exceptions import (
|
||||
LLMMalformedActionError,
|
||||
LLMNoActionError,
|
||||
LLMResponseError,
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import EventSource, EventStream, EventStreamSubscriber
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
ActionConfirmationStatus,
|
||||
AddTaskAction,
|
||||
AgentDelegateAction,
|
||||
AgentFinishAction,
|
||||
AgentRejectAction,
|
||||
ChangeAgentStateAction,
|
||||
CmdRunAction,
|
||||
IPythonRunCellAction,
|
||||
MessageAction,
|
||||
ModifyTaskAction,
|
||||
NullAction,
|
||||
)
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import (
|
||||
AgentDelegateObservation,
|
||||
AgentStateChangedObservation,
|
||||
CmdOutputObservation,
|
||||
ErrorObservation,
|
||||
Observation,
|
||||
)
|
||||
from openhands.llm.llm import LLM
|
||||
|
||||
# note: RESUME is only available on web GUI
|
||||
TRAFFIC_CONTROL_REMINDER = (
|
||||
"Please click on resume button if you'd like to continue, or start a new task."
|
||||
)
|
||||
|
||||
|
||||
class AgentController:
|
||||
id: str
|
||||
agent: Agent
|
||||
max_iterations: int
|
||||
event_stream: EventStream
|
||||
state: State
|
||||
confirmation_mode: bool
|
||||
agent_to_llm_config: dict[str, LLMConfig]
|
||||
agent_configs: dict[str, AgentConfig]
|
||||
agent_task: asyncio.Task | None = None
|
||||
parent: 'AgentController | None' = None
|
||||
delegate: 'AgentController | None' = None
|
||||
_pending_action: Action | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent: Agent,
|
||||
event_stream: EventStream,
|
||||
max_iterations: int,
|
||||
max_budget_per_task: float | None = None,
|
||||
agent_to_llm_config: dict[str, LLMConfig] | None = None,
|
||||
agent_configs: dict[str, AgentConfig] | None = None,
|
||||
sid: str = 'default',
|
||||
confirmation_mode: bool = False,
|
||||
initial_state: State | None = None,
|
||||
is_delegate: bool = False,
|
||||
headless_mode: bool = True,
|
||||
):
|
||||
"""Initializes a new instance of the AgentController class.
|
||||
|
||||
Args:
|
||||
agent: The agent instance to control.
|
||||
event_stream: The event stream to publish events to.
|
||||
max_iterations: The maximum number of iterations the agent can run.
|
||||
max_budget_per_task: The maximum budget (in USD) allowed per task, beyond which the agent will stop.
|
||||
agent_to_llm_config: A dictionary mapping agent names to LLM configurations in the case that
|
||||
we delegate to a different agent.
|
||||
agent_configs: A dictionary mapping agent names to agent configurations in the case that
|
||||
we delegate to a different agent.
|
||||
sid: The session ID of the agent.
|
||||
initial_state: The initial state of the controller.
|
||||
is_delegate: Whether this controller is a delegate.
|
||||
headless_mode: Whether the agent is run in headless mode.
|
||||
"""
|
||||
self._step_lock = asyncio.Lock()
|
||||
self.id = sid
|
||||
self.agent = agent
|
||||
self.headless_mode = headless_mode
|
||||
|
||||
# subscribe to the event stream
|
||||
self.event_stream = event_stream
|
||||
self.event_stream.subscribe(
|
||||
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, append=is_delegate
|
||||
)
|
||||
|
||||
# state from the previous session, state from a parent agent, or a fresh state
|
||||
self.set_initial_state(
|
||||
state=initial_state,
|
||||
max_iterations=max_iterations,
|
||||
confirmation_mode=confirmation_mode,
|
||||
)
|
||||
self.max_budget_per_task = max_budget_per_task
|
||||
self.agent_to_llm_config = agent_to_llm_config if agent_to_llm_config else {}
|
||||
self.agent_configs = agent_configs if agent_configs else {}
|
||||
|
||||
# stuck helper
|
||||
self._stuck_detector = StuckDetector(self.state)
|
||||
|
||||
if not is_delegate:
|
||||
self.agent_task = asyncio.create_task(self._start_step_loop())
|
||||
|
||||
async def close(self):
|
||||
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream."""
|
||||
if self.agent_task is not None:
|
||||
self.agent_task.cancel()
|
||||
await self.set_agent_state_to(AgentState.STOPPED)
|
||||
self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER)
|
||||
|
||||
def update_state_before_step(self):
|
||||
self.state.iteration += 1
|
||||
self.state.local_iteration += 1
|
||||
|
||||
async def update_state_after_step(self):
|
||||
# update metrics especially for cost
|
||||
self.state.local_metrics = self.agent.llm.metrics
|
||||
|
||||
async def report_error(self, message: str, exception: Exception | None = None):
|
||||
"""Reports an error to the user and sends the exception to the LLM next step, in the hope it can self-correct.
|
||||
|
||||
This method should be called for a particular type of errors, which have:
|
||||
- a user-friendly message, which will be shown in the chat box. This should not be a raw exception message.
|
||||
- an ErrorObservation that can be sent to the LLM by the agent, with the exception message, so it can self-correct next time.
|
||||
"""
|
||||
self.state.last_error = message
|
||||
if exception:
|
||||
self.state.last_error += f': {exception}'
|
||||
self.event_stream.add_event(ErrorObservation(message), EventSource.AGENT)
|
||||
|
||||
async def _start_step_loop(self):
|
||||
"""The main loop for the agent's step-by-step execution."""
|
||||
|
||||
logger.info(f'[Agent Controller {self.id}] Starting step loop...')
|
||||
while True:
|
||||
try:
|
||||
await self._step()
|
||||
except asyncio.CancelledError:
|
||||
logger.info('AgentController task was cancelled')
|
||||
break
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
logger.error(f'Error while running the agent: {e}')
|
||||
logger.error(traceback.format_exc())
|
||||
await self.report_error(
|
||||
'There was an unexpected error while running the agent', exception=e
|
||||
)
|
||||
await self.set_agent_state_to(AgentState.ERROR)
|
||||
break
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def on_event(self, event: Event):
|
||||
"""Callback from the event stream. Notifies the controller of incoming events.
|
||||
|
||||
Args:
|
||||
event (Event): The incoming event to process.
|
||||
"""
|
||||
if isinstance(event, ChangeAgentStateAction):
|
||||
await self.set_agent_state_to(event.agent_state) # type: ignore
|
||||
elif isinstance(event, MessageAction):
|
||||
if event.source == EventSource.USER:
|
||||
logger.info(
|
||||
event,
|
||||
extra={'msg_type': 'ACTION', 'event_source': EventSource.USER},
|
||||
)
|
||||
if self.get_agent_state() != AgentState.RUNNING:
|
||||
await self.set_agent_state_to(AgentState.RUNNING)
|
||||
elif event.source == EventSource.AGENT and event.wait_for_response:
|
||||
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
|
||||
elif isinstance(event, AgentDelegateAction):
|
||||
await self.start_delegate(event)
|
||||
elif isinstance(event, AddTaskAction):
|
||||
self.state.root_task.add_subtask(event.parent, event.goal, event.subtasks)
|
||||
elif isinstance(event, ModifyTaskAction):
|
||||
self.state.root_task.set_subtask_state(event.task_id, event.state)
|
||||
elif isinstance(event, AgentFinishAction):
|
||||
self.state.outputs = event.outputs
|
||||
self.state.metrics.merge(self.state.local_metrics)
|
||||
await self.set_agent_state_to(AgentState.FINISHED)
|
||||
elif isinstance(event, AgentRejectAction):
|
||||
self.state.outputs = event.outputs
|
||||
self.state.metrics.merge(self.state.local_metrics)
|
||||
await self.set_agent_state_to(AgentState.REJECTED)
|
||||
elif isinstance(event, Observation):
|
||||
if (
|
||||
self._pending_action
|
||||
and hasattr(self._pending_action, 'is_confirmed')
|
||||
and self._pending_action.is_confirmed
|
||||
== ActionConfirmationStatus.AWAITING_CONFIRMATION
|
||||
):
|
||||
return
|
||||
if self._pending_action and self._pending_action.id == event.cause:
|
||||
self._pending_action = None
|
||||
if self.state.agent_state == AgentState.USER_CONFIRMED:
|
||||
await self.set_agent_state_to(AgentState.RUNNING)
|
||||
if self.state.agent_state == AgentState.USER_REJECTED:
|
||||
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
|
||||
logger.info(event, extra={'msg_type': 'OBSERVATION'})
|
||||
elif isinstance(event, CmdOutputObservation):
|
||||
logger.info(event, extra={'msg_type': 'OBSERVATION'})
|
||||
elif isinstance(event, AgentDelegateObservation):
|
||||
self.state.history.on_event(event)
|
||||
logger.info(event, extra={'msg_type': 'OBSERVATION'})
|
||||
elif isinstance(event, ErrorObservation):
|
||||
logger.info(event, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
def reset_task(self):
|
||||
"""Resets the agent's task."""
|
||||
|
||||
self.almost_stuck = 0
|
||||
self.agent.reset()
|
||||
|
||||
async def set_agent_state_to(self, new_state: AgentState):
|
||||
"""Updates the agent's state and handles side effects. Can emit events to the event stream.
|
||||
|
||||
Args:
|
||||
new_state (AgentState): The new state to set for the agent.
|
||||
"""
|
||||
logger.debug(
|
||||
f'[Agent Controller {self.id}] Setting agent({self.agent.name}) state from {self.state.agent_state} to {new_state}'
|
||||
)
|
||||
|
||||
if new_state == self.state.agent_state:
|
||||
return
|
||||
|
||||
if (
|
||||
self.state.agent_state == AgentState.PAUSED
|
||||
and new_state == AgentState.RUNNING
|
||||
and self.state.traffic_control_state == TrafficControlState.THROTTLING
|
||||
):
|
||||
# user intends to interrupt traffic control and let the task resume temporarily
|
||||
self.state.traffic_control_state = TrafficControlState.PAUSED
|
||||
|
||||
self.state.agent_state = new_state
|
||||
if new_state == AgentState.STOPPED or new_state == AgentState.ERROR:
|
||||
self.reset_task()
|
||||
|
||||
if self._pending_action is not None and (
|
||||
new_state == AgentState.USER_CONFIRMED
|
||||
or new_state == AgentState.USER_REJECTED
|
||||
):
|
||||
if hasattr(self._pending_action, 'thought'):
|
||||
self._pending_action.thought = '' # type: ignore[union-attr]
|
||||
if new_state == AgentState.USER_CONFIRMED:
|
||||
self._pending_action.is_confirmed = ActionConfirmationStatus.CONFIRMED # type: ignore[attr-defined]
|
||||
else:
|
||||
self._pending_action.is_confirmed = ActionConfirmationStatus.REJECTED # type: ignore[attr-defined]
|
||||
self.event_stream.add_event(self._pending_action, EventSource.AGENT)
|
||||
|
||||
self.event_stream.add_event(
|
||||
AgentStateChangedObservation('', self.state.agent_state), EventSource.AGENT
|
||||
)
|
||||
|
||||
if new_state == AgentState.INIT and self.state.resume_state:
|
||||
await self.set_agent_state_to(self.state.resume_state)
|
||||
self.state.resume_state = None
|
||||
|
||||
def get_agent_state(self):
|
||||
"""Returns the current state of the agent.
|
||||
|
||||
Returns:
|
||||
AgentState: The current state of the agent.
|
||||
"""
|
||||
return self.state.agent_state
|
||||
|
||||
async def start_delegate(self, action: AgentDelegateAction):
|
||||
"""Start a delegate agent to handle a subtask.
|
||||
|
||||
OpenHands is a multi-agentic system. A `task` is a conversation between
|
||||
OpenHands (the whole system) and the user, which might involve one or more inputs
|
||||
from the user. It starts with an initial input (typically a task statement) from
|
||||
the user, and ends with either an `AgentFinishAction` initiated by the agent, a
|
||||
stop initiated by the user, or an error.
|
||||
|
||||
A `subtask` is a conversation between an agent and the user, or another agent. If a `task`
|
||||
is conducted by a single agent, then it's also a `subtask`. Otherwise, a `task` consists of
|
||||
multiple `subtasks`, each executed by one agent.
|
||||
|
||||
Args:
|
||||
action (AgentDelegateAction): The action containing information about the delegate agent to start.
|
||||
"""
|
||||
agent_cls: Type[Agent] = Agent.get_cls(action.agent)
|
||||
agent_config = self.agent_configs.get(action.agent, self.agent.config)
|
||||
llm_config = self.agent_to_llm_config.get(action.agent, self.agent.llm.config)
|
||||
llm = LLM(config=llm_config)
|
||||
delegate_agent = agent_cls(llm=llm, config=agent_config)
|
||||
state = State(
|
||||
inputs=action.inputs or {},
|
||||
local_iteration=0,
|
||||
iteration=self.state.iteration,
|
||||
max_iterations=self.state.max_iterations,
|
||||
delegate_level=self.state.delegate_level + 1,
|
||||
# global metrics should be shared between parent and child
|
||||
metrics=self.state.metrics,
|
||||
)
|
||||
logger.info(
|
||||
f'[Agent Controller {self.id}]: start delegate, creating agent {delegate_agent.name} using LLM {llm}'
|
||||
)
|
||||
self.delegate = AgentController(
|
||||
sid=self.id + '-delegate',
|
||||
agent=delegate_agent,
|
||||
event_stream=self.event_stream,
|
||||
max_iterations=self.state.max_iterations,
|
||||
max_budget_per_task=self.max_budget_per_task,
|
||||
agent_to_llm_config=self.agent_to_llm_config,
|
||||
agent_configs=self.agent_configs,
|
||||
initial_state=state,
|
||||
is_delegate=True,
|
||||
headless_mode=self.headless_mode,
|
||||
)
|
||||
await self.delegate.set_agent_state_to(AgentState.RUNNING)
|
||||
|
||||
async def _step(self) -> None:
|
||||
"""Executes a single step of the parent or delegate agent. Detects stuck agents and limits on the number of iterations and the task budget."""
|
||||
if self.get_agent_state() != AgentState.RUNNING:
|
||||
await asyncio.sleep(1)
|
||||
return
|
||||
|
||||
if self._pending_action:
|
||||
logger.debug(
|
||||
f'[Agent Controller {self.id}] waiting for pending action: {self._pending_action}'
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
return
|
||||
|
||||
if self.delegate is not None:
|
||||
logger.debug(f'[Agent Controller {self.id}] Delegate not none, awaiting...')
|
||||
assert self.delegate != self
|
||||
await self.delegate._step()
|
||||
logger.debug(f'[Agent Controller {self.id}] Delegate step done')
|
||||
assert self.delegate is not None
|
||||
delegate_state = self.delegate.get_agent_state()
|
||||
logger.debug(
|
||||
f'[Agent Controller {self.id}] Delegate state: {delegate_state}'
|
||||
)
|
||||
if delegate_state == AgentState.ERROR:
|
||||
# close the delegate upon error
|
||||
await self.delegate.close()
|
||||
self.delegate = None
|
||||
self.delegateAction = None
|
||||
await self.report_error('Delegator agent encounters an error')
|
||||
return
|
||||
delegate_done = delegate_state in (AgentState.FINISHED, AgentState.REJECTED)
|
||||
if delegate_done:
|
||||
logger.info(
|
||||
f'[Agent Controller {self.id}] Delegate agent has finished execution'
|
||||
)
|
||||
# retrieve delegate result
|
||||
outputs = self.delegate.state.outputs if self.delegate.state else {}
|
||||
|
||||
# update iteration that shall be shared across agents
|
||||
self.state.iteration = self.delegate.state.iteration
|
||||
|
||||
# close delegate controller: we must close the delegate controller before adding new events
|
||||
await self.delegate.close()
|
||||
|
||||
# update delegate result observation
|
||||
# TODO: replace this with AI-generated summary (#2395)
|
||||
formatted_output = ', '.join(
|
||||
f'{key}: {value}' for key, value in outputs.items()
|
||||
)
|
||||
content = (
|
||||
f'{self.delegate.agent.name} finishes task with {formatted_output}'
|
||||
)
|
||||
obs: Observation = AgentDelegateObservation(
|
||||
outputs=outputs, content=content
|
||||
)
|
||||
|
||||
# clean up delegate status
|
||||
self.delegate = None
|
||||
self.delegateAction = None
|
||||
self.event_stream.add_event(obs, EventSource.AGENT)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f'{self.agent.name} LEVEL {self.state.delegate_level} LOCAL STEP {self.state.local_iteration} GLOBAL STEP {self.state.iteration}',
|
||||
extra={'msg_type': 'STEP'},
|
||||
)
|
||||
|
||||
if self.state.iteration >= self.state.max_iterations:
|
||||
if self.state.traffic_control_state == TrafficControlState.PAUSED:
|
||||
logger.info(
|
||||
'Hitting traffic control, temporarily resume upon user request'
|
||||
)
|
||||
self.state.traffic_control_state = TrafficControlState.NORMAL
|
||||
else:
|
||||
self.state.traffic_control_state = TrafficControlState.THROTTLING
|
||||
if self.headless_mode:
|
||||
# set to ERROR state if running in headless mode
|
||||
# since user cannot resume on the web interface
|
||||
await self.report_error(
|
||||
'Agent reached maximum number of iterations in headless mode, task stopped.'
|
||||
)
|
||||
await self.set_agent_state_to(AgentState.ERROR)
|
||||
else:
|
||||
await self.report_error(
|
||||
f'Agent reached maximum number of iterations, task paused. {TRAFFIC_CONTROL_REMINDER}'
|
||||
)
|
||||
await self.set_agent_state_to(AgentState.PAUSED)
|
||||
return
|
||||
elif self.max_budget_per_task is not None:
|
||||
current_cost = self.state.metrics.accumulated_cost
|
||||
if current_cost > self.max_budget_per_task:
|
||||
if self.state.traffic_control_state == TrafficControlState.PAUSED:
|
||||
logger.info(
|
||||
'Hitting traffic control, temporarily resume upon user request'
|
||||
)
|
||||
self.state.traffic_control_state = TrafficControlState.NORMAL
|
||||
else:
|
||||
self.state.traffic_control_state = TrafficControlState.THROTTLING
|
||||
if self.headless_mode:
|
||||
# set to ERROR state if running in headless mode
|
||||
# there is no way to resume
|
||||
await self.report_error(
|
||||
f'Task budget exceeded. Current cost: {current_cost:.2f}, max budget: {self.max_budget_per_task:.2f}, task stopped.'
|
||||
)
|
||||
await self.set_agent_state_to(AgentState.ERROR)
|
||||
else:
|
||||
await self.report_error(
|
||||
f'Task budget exceeded. Current cost: {current_cost:.2f}, Max budget: {self.max_budget_per_task:.2f}, task paused. {TRAFFIC_CONTROL_REMINDER}'
|
||||
)
|
||||
await self.set_agent_state_to(AgentState.PAUSED)
|
||||
return
|
||||
|
||||
self.update_state_before_step()
|
||||
action: Action = NullAction()
|
||||
try:
|
||||
action = self.agent.step(self.state)
|
||||
if action is None:
|
||||
raise LLMNoActionError('No action was returned')
|
||||
except (LLMMalformedActionError, LLMNoActionError, LLMResponseError) as e:
|
||||
# report to the user
|
||||
# and send the underlying exception to the LLM for self-correction
|
||||
await self.report_error(str(e))
|
||||
return
|
||||
|
||||
if action.runnable:
|
||||
if self.state.confirmation_mode and (
|
||||
type(action) is CmdRunAction or type(action) is IPythonRunCellAction
|
||||
):
|
||||
action.is_confirmed = ActionConfirmationStatus.AWAITING_CONFIRMATION
|
||||
self._pending_action = action
|
||||
|
||||
if not isinstance(action, NullAction):
|
||||
if (
|
||||
hasattr(action, 'is_confirmed')
|
||||
and action.is_confirmed
|
||||
== ActionConfirmationStatus.AWAITING_CONFIRMATION
|
||||
):
|
||||
await self.set_agent_state_to(AgentState.AWAITING_USER_CONFIRMATION)
|
||||
self.event_stream.add_event(action, EventSource.AGENT)
|
||||
|
||||
await self.update_state_after_step()
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
|
||||
if self._is_stuck():
|
||||
await self.report_error('Agent got stuck in a loop')
|
||||
await self.set_agent_state_to(AgentState.ERROR)
|
||||
|
||||
def get_state(self):
|
||||
"""Returns the current running state object.
|
||||
|
||||
Returns:
|
||||
State: The current state object.
|
||||
"""
|
||||
return self.state
|
||||
|
||||
def set_initial_state(
|
||||
self,
|
||||
state: State | None,
|
||||
max_iterations: int,
|
||||
confirmation_mode: bool = False,
|
||||
):
|
||||
"""Sets the initial state for the agent, either from the previous session, or from a parent agent, or by creating a new one.
|
||||
|
||||
Args:
|
||||
state: The state to initialize with, or None to create a new state.
|
||||
max_iterations: The maximum number of iterations allowed for the task.
|
||||
confirmation_mode: Whether to enable confirmation mode.
|
||||
"""
|
||||
# state from the previous session, state from a parent agent, or a new state
|
||||
# note that this is called twice when restoring a previous session, first with state=None
|
||||
if state is None:
|
||||
self.state = State(
|
||||
inputs={},
|
||||
max_iterations=max_iterations,
|
||||
confirmation_mode=confirmation_mode,
|
||||
)
|
||||
else:
|
||||
self.state = state
|
||||
|
||||
# when restored from a previous session, the State object will have history, start_id, and end_id
|
||||
# connect it to the event stream
|
||||
self.state.history.set_event_stream(self.event_stream)
|
||||
|
||||
# if start_id was not set in State, we're starting fresh, at the top of the stream
|
||||
start_id = self.state.start_id
|
||||
if start_id == -1:
|
||||
start_id = self.event_stream.get_latest_event_id() + 1
|
||||
else:
|
||||
logger.debug(f'AgentController {self.id} restoring from event {start_id}')
|
||||
|
||||
# make sure history is in sync
|
||||
self.state.start_id = start_id
|
||||
self.state.history.start_id = start_id
|
||||
|
||||
# if there was an end_id saved in State, set it in history
|
||||
# currently not used, later useful for delegates
|
||||
if self.state.end_id > -1:
|
||||
self.state.history.end_id = self.state.end_id
|
||||
|
||||
def _is_stuck(self):
|
||||
"""Checks if the agent or its delegate is stuck in a loop.
|
||||
|
||||
Returns:
|
||||
bool: True if the agent is stuck, False otherwise.
|
||||
"""
|
||||
# check if delegate stuck
|
||||
if self.delegate and self.delegate._is_stuck():
|
||||
return True
|
||||
|
||||
return self._stuck_detector.is_stuck()
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f'AgentController(id={self.id}, agent={self.agent!r}, '
|
||||
f'event_stream={self.event_stream!r}, '
|
||||
f'state={self.state!r}, agent_task={self.agent_task!r}, '
|
||||
f'delegate={self.delegate!r}, _pending_action={self._pending_action!r})'
|
||||
)
|
||||
172
openhands/controller/state/state.py
Normal file
172
openhands/controller/state/state.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import base64
|
||||
import pickle
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from openhands.controller.state.task import RootTask
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.metrics import Metrics
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events.action import (
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.action.agent import AgentFinishAction
|
||||
from openhands.memory.history import ShortTermHistory
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
|
||||
class TrafficControlState(str, Enum):
|
||||
# default state, no rate limiting
|
||||
NORMAL = 'normal'
|
||||
|
||||
# task paused due to traffic control
|
||||
THROTTLING = 'throttling'
|
||||
|
||||
# traffic control is temporarily paused
|
||||
PAUSED = 'paused'
|
||||
|
||||
|
||||
RESUMABLE_STATES = [
|
||||
AgentState.RUNNING,
|
||||
AgentState.PAUSED,
|
||||
AgentState.AWAITING_USER_INPUT,
|
||||
AgentState.FINISHED,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
"""
|
||||
Represents the running state of an agent in the OpenHands system, saving data of its operation and memory.
|
||||
|
||||
- Multi-agent/delegate state:
|
||||
- store the task (conversation between the agent and the user)
|
||||
- the subtask (conversation between an agent and the user or another agent)
|
||||
- global and local iterations
|
||||
- delegate levels for multi-agent interactions
|
||||
- almost stuck state
|
||||
|
||||
- Running state of an agent:
|
||||
- current agent state (e.g., LOADING, RUNNING, PAUSED)
|
||||
- traffic control state for rate limiting
|
||||
- confirmation mode
|
||||
- the last error encountered
|
||||
|
||||
- Data for saving and restoring the agent:
|
||||
- save to and restore from a session
|
||||
- serialize with pickle and base64
|
||||
|
||||
- Save / restore data about message history
|
||||
- start and end IDs for events in agent's history
|
||||
- summaries and delegate summaries
|
||||
|
||||
- Metrics:
|
||||
- global metrics for the current task
|
||||
- local metrics for the current subtask
|
||||
|
||||
- Extra data:
|
||||
- additional task-specific data
|
||||
"""
|
||||
|
||||
root_task: RootTask = field(default_factory=RootTask)
|
||||
# global iteration for the current task
|
||||
iteration: int = 0
|
||||
# local iteration for the current subtask
|
||||
local_iteration: int = 0
|
||||
# max number of iterations for the current task
|
||||
max_iterations: int = 100
|
||||
confirmation_mode: bool = False
|
||||
history: ShortTermHistory = field(default_factory=ShortTermHistory)
|
||||
inputs: dict = field(default_factory=dict)
|
||||
outputs: dict = field(default_factory=dict)
|
||||
last_error: str | None = None
|
||||
agent_state: AgentState = AgentState.LOADING
|
||||
resume_state: AgentState | None = None
|
||||
traffic_control_state: TrafficControlState = TrafficControlState.NORMAL
|
||||
# global metrics for the current task
|
||||
metrics: Metrics = field(default_factory=Metrics)
|
||||
# local metrics for the current subtask
|
||||
local_metrics: Metrics = field(default_factory=Metrics)
|
||||
# root agent has level 0, and every delegate increases the level by one
|
||||
delegate_level: int = 0
|
||||
# start_id and end_id track the range of events in history
|
||||
start_id: int = -1
|
||||
end_id: int = -1
|
||||
almost_stuck: int = 0
|
||||
# NOTE: This will never be used by the controller, but it can be used by different
|
||||
# evaluation tasks to store extra data needed to track the progress/state of the task.
|
||||
extra_data: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def save_to_session(self, sid: str, file_store: FileStore):
|
||||
pickled = pickle.dumps(self)
|
||||
logger.debug(f'Saving state to session {sid}:{self.agent_state}')
|
||||
encoded = base64.b64encode(pickled).decode('utf-8')
|
||||
try:
|
||||
file_store.write(f'sessions/{sid}/agent_state.pkl', encoded)
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to save state to session: {e}')
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def restore_from_session(sid: str, file_store: FileStore) -> 'State':
|
||||
try:
|
||||
encoded = file_store.read(f'sessions/{sid}/agent_state.pkl')
|
||||
pickled = base64.b64decode(encoded)
|
||||
state = pickle.loads(pickled)
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to restore state from session: {e}')
|
||||
raise e
|
||||
|
||||
# update state
|
||||
if state.agent_state in RESUMABLE_STATES:
|
||||
state.resume_state = state.agent_state
|
||||
else:
|
||||
state.resume_state = None
|
||||
|
||||
# don't carry last_error anymore after restore
|
||||
state.last_error = None
|
||||
|
||||
# first state after restore
|
||||
state.agent_state = AgentState.LOADING
|
||||
return state
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
|
||||
# save the relevant data from recent history
|
||||
# so that we can restore it when the state is restored
|
||||
if 'history' in state:
|
||||
state['start_id'] = state['history'].start_id
|
||||
state['end_id'] = state['history'].end_id
|
||||
|
||||
# don't save history object itself
|
||||
state.pop('history', None)
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__.update(state)
|
||||
|
||||
# recreate the history object
|
||||
if not hasattr(self, 'history'):
|
||||
self.history = ShortTermHistory()
|
||||
|
||||
# restore the relevant data in history from the state
|
||||
self.history.start_id = self.start_id
|
||||
self.history.end_id = self.end_id
|
||||
|
||||
# remove the restored data from the state if any
|
||||
|
||||
def get_current_user_intent(self):
|
||||
"""Returns the latest user message and image(if provided) that appears after a FinishAction, or the first (the task) if nothing was finished yet."""
|
||||
last_user_message = None
|
||||
last_user_message_image_urls: list[str] | None = []
|
||||
for event in self.history.get_events(reverse=True):
|
||||
if isinstance(event, MessageAction) and event.source == 'user':
|
||||
last_user_message = event.content
|
||||
last_user_message_image_urls = event.images_urls
|
||||
elif isinstance(event, AgentFinishAction):
|
||||
if last_user_message is not None:
|
||||
return last_user_message
|
||||
|
||||
return last_user_message, last_user_message_image_urls
|
||||
226
openhands/controller/state/task.py
Normal file
226
openhands/controller/state/task.py
Normal file
@@ -0,0 +1,226 @@
|
||||
from openhands.core.exceptions import (
|
||||
LLMMalformedActionError,
|
||||
TaskInvalidStateError,
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
OPEN_STATE = 'open'
|
||||
COMPLETED_STATE = 'completed'
|
||||
ABANDONED_STATE = 'abandoned'
|
||||
IN_PROGRESS_STATE = 'in_progress'
|
||||
VERIFIED_STATE = 'verified'
|
||||
STATES = [
|
||||
OPEN_STATE,
|
||||
COMPLETED_STATE,
|
||||
ABANDONED_STATE,
|
||||
IN_PROGRESS_STATE,
|
||||
VERIFIED_STATE,
|
||||
]
|
||||
|
||||
|
||||
class Task:
|
||||
id: str
|
||||
goal: str
|
||||
parent: 'Task | None'
|
||||
subtasks: list['Task']
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent: 'Task',
|
||||
goal: str,
|
||||
state: str = OPEN_STATE,
|
||||
subtasks=None, # noqa: B006
|
||||
):
|
||||
"""Initializes a new instance of the Task class.
|
||||
|
||||
Args:
|
||||
parent: The parent task, or None if it is the root task.
|
||||
goal: The goal of the task.
|
||||
state: The initial state of the task.
|
||||
subtasks: A list of subtasks associated with this task.
|
||||
"""
|
||||
if subtasks is None:
|
||||
subtasks = []
|
||||
if parent.id:
|
||||
self.id = parent.id + '.' + str(len(parent.subtasks))
|
||||
else:
|
||||
self.id = str(len(parent.subtasks))
|
||||
self.parent = parent
|
||||
self.goal = goal
|
||||
logger.debug(f'Creating task {self.id} with parent={parent.id}, goal={goal}')
|
||||
self.subtasks = []
|
||||
for subtask in subtasks or []:
|
||||
if isinstance(subtask, Task):
|
||||
self.subtasks.append(subtask)
|
||||
else:
|
||||
goal = subtask.get('goal')
|
||||
state = subtask.get('state')
|
||||
subtasks = subtask.get('subtasks')
|
||||
logger.debug(f'Reading: {goal}, {state}, {subtasks}')
|
||||
self.subtasks.append(Task(self, goal, state, subtasks))
|
||||
|
||||
self.state = OPEN_STATE
|
||||
|
||||
def to_string(self, indent=''):
|
||||
"""Returns a string representation of the task and its subtasks.
|
||||
|
||||
Args:
|
||||
indent: The indentation string for formatting the output.
|
||||
|
||||
Returns:
|
||||
A string representation of the task and its subtasks.
|
||||
"""
|
||||
emoji = ''
|
||||
if self.state == VERIFIED_STATE:
|
||||
emoji = '✅'
|
||||
elif self.state == COMPLETED_STATE:
|
||||
emoji = '🟢'
|
||||
elif self.state == ABANDONED_STATE:
|
||||
emoji = '❌'
|
||||
elif self.state == IN_PROGRESS_STATE:
|
||||
emoji = '💪'
|
||||
elif self.state == OPEN_STATE:
|
||||
emoji = '🔵'
|
||||
result = indent + emoji + ' ' + self.id + ' ' + self.goal + '\n'
|
||||
for subtask in self.subtasks:
|
||||
result += subtask.to_string(indent + ' ')
|
||||
return result
|
||||
|
||||
def to_dict(self):
|
||||
"""Returns a dictionary representation of the task.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the task's attributes.
|
||||
"""
|
||||
return {
|
||||
'id': self.id,
|
||||
'goal': self.goal,
|
||||
'state': self.state,
|
||||
'subtasks': [t.to_dict() for t in self.subtasks],
|
||||
}
|
||||
|
||||
def set_state(self, state):
|
||||
"""Sets the state of the task and its subtasks.
|
||||
|
||||
Args: state: The new state of the task.
|
||||
|
||||
Raises:
|
||||
TaskInvalidStateError: If the provided state is invalid.
|
||||
"""
|
||||
if state not in STATES:
|
||||
logger.error('Invalid state: %s', state)
|
||||
raise TaskInvalidStateError(state)
|
||||
self.state = state
|
||||
if (
|
||||
state == COMPLETED_STATE
|
||||
or state == ABANDONED_STATE
|
||||
or state == VERIFIED_STATE
|
||||
):
|
||||
for subtask in self.subtasks:
|
||||
if subtask.state != ABANDONED_STATE:
|
||||
subtask.set_state(state)
|
||||
elif state == IN_PROGRESS_STATE:
|
||||
if self.parent is not None:
|
||||
self.parent.set_state(state)
|
||||
|
||||
def get_current_task(self) -> 'Task | None':
|
||||
"""Retrieves the current task in progress.
|
||||
|
||||
Returns:
|
||||
The current task in progress, or None if no task is in progress.
|
||||
"""
|
||||
for subtask in self.subtasks:
|
||||
if subtask.state == IN_PROGRESS_STATE:
|
||||
return subtask.get_current_task()
|
||||
if self.state == IN_PROGRESS_STATE:
|
||||
return self
|
||||
return None
|
||||
|
||||
|
||||
class RootTask(Task):
|
||||
"""Serves as the root node in a tree of tasks.
|
||||
Because we want the top-level of the root_task to be a list of tasks (1, 2, 3, etc.),
|
||||
the "root node" of the data structure is kind of invisible--it just
|
||||
holds references to the top-level tasks.
|
||||
|
||||
Attributes:
|
||||
id: Kept blank for root_task
|
||||
goal: Kept blank for root_task
|
||||
parent: None for root_task
|
||||
subtasks: The top-level list of tasks associated with the root_task.
|
||||
state: The state of the root_task.
|
||||
"""
|
||||
|
||||
id: str = ''
|
||||
goal: str = ''
|
||||
parent: None = None
|
||||
|
||||
def __init__(self):
|
||||
self.subtasks = []
|
||||
self.state = OPEN_STATE
|
||||
|
||||
def __str__(self):
|
||||
"""Returns a string representation of the root_task.
|
||||
|
||||
Returns:
|
||||
A string representation of the root_task.
|
||||
"""
|
||||
return self.to_string()
|
||||
|
||||
def get_task_by_id(self, id: str) -> Task:
|
||||
"""Retrieves a task by its ID.
|
||||
|
||||
Args:
|
||||
id: The ID of the task.
|
||||
|
||||
Returns:
|
||||
The task with the specified ID.
|
||||
|
||||
Raises:
|
||||
AgentMalformedActionError: If the provided task ID is invalid or does not exist.
|
||||
"""
|
||||
if id == '':
|
||||
return self
|
||||
if len(self.subtasks) == 0:
|
||||
raise LLMMalformedActionError('Task does not exist:' + id)
|
||||
try:
|
||||
parts = [int(p) for p in id.split('.')]
|
||||
except ValueError:
|
||||
raise LLMMalformedActionError('Invalid task id:' + id)
|
||||
task: Task = self
|
||||
for part in parts:
|
||||
if part >= len(task.subtasks):
|
||||
raise LLMMalformedActionError('Task does not exist:' + id)
|
||||
task = task.subtasks[part]
|
||||
return task
|
||||
|
||||
def add_subtask(self, parent_id: str, goal: str, subtasks: list | None = None):
|
||||
"""Adds a subtask to a parent task.
|
||||
|
||||
Args:
|
||||
parent_id: The ID of the parent task.
|
||||
goal: The goal of the subtask.
|
||||
subtasks: A list of subtasks associated with the new subtask.
|
||||
"""
|
||||
subtasks = subtasks or []
|
||||
parent = self.get_task_by_id(parent_id)
|
||||
child = Task(parent=parent, goal=goal, subtasks=subtasks)
|
||||
parent.subtasks.append(child)
|
||||
|
||||
def set_subtask_state(self, id: str, state: str):
|
||||
"""Sets the state of a subtask.
|
||||
|
||||
Args:
|
||||
id: The ID of the subtask.
|
||||
state: The new state of the subtask.
|
||||
"""
|
||||
task = self.get_task_by_id(id)
|
||||
logger.debug('Setting task {task.id} from state {task.state} to {state}')
|
||||
task.set_state(state)
|
||||
unfinished_tasks = [
|
||||
t
|
||||
for t in self.subtasks
|
||||
if t.state not in [COMPLETED_STATE, VERIFIED_STATE, ABANDONED_STATE]
|
||||
]
|
||||
if len(unfinished_tasks) == 0:
|
||||
self.set_state(COMPLETED_STATE)
|
||||
237
openhands/controller/stuck.py
Normal file
237
openhands/controller/stuck.py
Normal file
@@ -0,0 +1,237 @@
|
||||
from typing import cast
|
||||
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.action.empty import NullAction
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.observation.commands import (
|
||||
CmdOutputObservation,
|
||||
IPythonRunCellObservation,
|
||||
)
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.observation.error import ErrorObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
|
||||
|
||||
class StuckDetector:
|
||||
def __init__(self, state: State):
|
||||
self.state = state
|
||||
|
||||
def is_stuck(self):
|
||||
# filter out MessageAction with source='user' from history
|
||||
filtered_history = [
|
||||
event
|
||||
for event in self.state.history.get_events()
|
||||
if not (
|
||||
(isinstance(event, MessageAction) and event.source == EventSource.USER)
|
||||
or
|
||||
# there might be some NullAction or NullObservation in the history at least for now
|
||||
isinstance(event, NullAction)
|
||||
or isinstance(event, NullObservation)
|
||||
)
|
||||
]
|
||||
|
||||
# it takes 3 actions minimum to detect a loop, otherwise nothing to do here
|
||||
if len(filtered_history) < 3:
|
||||
return False
|
||||
|
||||
# the first few scenarios detect 3 or 4 repeated steps
|
||||
# prepare the last 4 actions and observations, to check them out
|
||||
last_actions: list[Event] = []
|
||||
last_observations: list[Event] = []
|
||||
|
||||
# retrieve the last four actions and observations starting from the end of history, wherever they are
|
||||
for event in reversed(filtered_history):
|
||||
if isinstance(event, Action) and len(last_actions) < 4:
|
||||
last_actions.append(event)
|
||||
elif isinstance(event, Observation) and len(last_observations) < 4:
|
||||
last_observations.append(event)
|
||||
|
||||
if len(last_actions) == 4 and len(last_observations) == 4:
|
||||
break
|
||||
|
||||
# scenario 1: same action, same observation
|
||||
if self._is_stuck_repeating_action_observation(last_actions, last_observations):
|
||||
return True
|
||||
|
||||
# scenario 2: same action, errors
|
||||
if self._is_stuck_repeating_action_error(last_actions, last_observations):
|
||||
return True
|
||||
|
||||
# scenario 3: monologue
|
||||
if self._is_stuck_monologue(filtered_history):
|
||||
return True
|
||||
|
||||
# scenario 4: action, observation pattern on the last six steps
|
||||
if len(filtered_history) < 6:
|
||||
return False
|
||||
if self._is_stuck_action_observation_pattern(filtered_history):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _is_stuck_repeating_action_observation(self, last_actions, last_observations):
|
||||
# scenario 1: same action, same observation
|
||||
# it takes 4 actions and 4 observations to detect a loop
|
||||
# assert len(last_actions) == 4 and len(last_observations) == 4
|
||||
|
||||
# reset almost_stuck reminder
|
||||
self.state.almost_stuck = 0
|
||||
|
||||
# almost stuck? if two actions, obs are the same, we're almost stuck
|
||||
if len(last_actions) >= 2 and len(last_observations) >= 2:
|
||||
actions_equal = all(
|
||||
self._eq_no_pid(last_actions[0], action) for action in last_actions[:2]
|
||||
)
|
||||
observations_equal = all(
|
||||
self._eq_no_pid(last_observations[0], observation)
|
||||
for observation in last_observations[:2]
|
||||
)
|
||||
|
||||
# the last two actions and obs are the same?
|
||||
if actions_equal and observations_equal:
|
||||
self.state.almost_stuck = 2
|
||||
|
||||
# the last three actions and observations are the same?
|
||||
if len(last_actions) >= 3 and len(last_observations) >= 3:
|
||||
if (
|
||||
actions_equal
|
||||
and observations_equal
|
||||
and self._eq_no_pid(last_actions[0], last_actions[2])
|
||||
and self._eq_no_pid(last_observations[0], last_observations[2])
|
||||
):
|
||||
self.state.almost_stuck = 1
|
||||
|
||||
if len(last_actions) == 4 and len(last_observations) == 4:
|
||||
if (
|
||||
actions_equal
|
||||
and observations_equal
|
||||
and self._eq_no_pid(last_actions[0], last_actions[3])
|
||||
and self._eq_no_pid(last_observations[0], last_observations[3])
|
||||
):
|
||||
logger.warning('Action, Observation loop detected')
|
||||
self.state.almost_stuck = 0
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _is_stuck_repeating_action_error(self, last_actions, last_observations):
|
||||
# scenario 2: same action, errors
|
||||
# it takes 4 actions and 4 observations to detect a loop
|
||||
# check if the last four actions are the same and result in errors
|
||||
|
||||
# are the last four actions the same?
|
||||
if len(last_actions) == 4 and all(
|
||||
self._eq_no_pid(last_actions[0], action) for action in last_actions
|
||||
):
|
||||
# and the last four observations all errors?
|
||||
if all(isinstance(obs, ErrorObservation) for obs in last_observations):
|
||||
logger.warning('Action, ErrorObservation loop detected')
|
||||
return True
|
||||
# or, are the last four observations all IPythonRunCellObservation with SyntaxError?
|
||||
elif all(
|
||||
isinstance(obs, IPythonRunCellObservation) for obs in last_observations
|
||||
) and all(
|
||||
cast(IPythonRunCellObservation, obs)
|
||||
.content[-100:]
|
||||
.find('SyntaxError: unterminated string literal (detected at line')
|
||||
!= -1
|
||||
and len(
|
||||
cast(IPythonRunCellObservation, obs).content.split(
|
||||
'SyntaxError: unterminated string literal (detected at line'
|
||||
)[-1]
|
||||
)
|
||||
< 10
|
||||
for obs in last_observations
|
||||
):
|
||||
logger.warning('Action, IPythonRunCellObservation loop detected')
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_stuck_monologue(self, filtered_history):
|
||||
# scenario 3: monologue
|
||||
# check for repeated MessageActions with source=AGENT
|
||||
# see if the agent is engaged in a good old monologue, telling itself the same thing over and over
|
||||
agent_message_actions = [
|
||||
(i, event)
|
||||
for i, event in enumerate(filtered_history)
|
||||
if isinstance(event, MessageAction) and event.source == EventSource.AGENT
|
||||
]
|
||||
|
||||
# last three message actions will do for this check
|
||||
if len(agent_message_actions) >= 3:
|
||||
last_agent_message_actions = agent_message_actions[-3:]
|
||||
|
||||
if all(
|
||||
(last_agent_message_actions[0][1] == action[1])
|
||||
for action in last_agent_message_actions
|
||||
):
|
||||
# check if there are any observations between the repeated MessageActions
|
||||
# then it's not yet a loop, maybe it can recover
|
||||
start_index = last_agent_message_actions[0][0]
|
||||
end_index = last_agent_message_actions[-1][0]
|
||||
|
||||
has_observation_between = False
|
||||
for event in filtered_history[start_index + 1 : end_index]:
|
||||
if isinstance(event, Observation):
|
||||
has_observation_between = True
|
||||
break
|
||||
|
||||
if not has_observation_between:
|
||||
logger.warning('Repeated MessageAction with source=AGENT detected')
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_stuck_action_observation_pattern(self, filtered_history):
|
||||
# scenario 4: action, observation pattern on the last six steps
|
||||
# check if the agent repeats the same (Action, Observation)
|
||||
# every other step in the last six steps
|
||||
last_six_actions: list[Event] = []
|
||||
last_six_observations: list[Event] = []
|
||||
|
||||
# the end of history is most interesting
|
||||
for event in reversed(filtered_history):
|
||||
if isinstance(event, Action) and len(last_six_actions) < 6:
|
||||
last_six_actions.append(event)
|
||||
elif isinstance(event, Observation) and len(last_six_observations) < 6:
|
||||
last_six_observations.append(event)
|
||||
|
||||
if len(last_six_actions) == 6 and len(last_six_observations) == 6:
|
||||
break
|
||||
|
||||
# this pattern is every other step, like:
|
||||
# (action_1, obs_1), (action_2, obs_2), (action_1, obs_1), (action_2, obs_2),...
|
||||
if len(last_six_actions) == 6 and len(last_six_observations) == 6:
|
||||
actions_equal = (
|
||||
# action_0 == action_2 == action_4
|
||||
self._eq_no_pid(last_six_actions[0], last_six_actions[2])
|
||||
and self._eq_no_pid(last_six_actions[0], last_six_actions[4])
|
||||
# action_1 == action_3 == action_5
|
||||
and self._eq_no_pid(last_six_actions[1], last_six_actions[3])
|
||||
and self._eq_no_pid(last_six_actions[1], last_six_actions[5])
|
||||
)
|
||||
observations_equal = (
|
||||
# obs_0 == obs_2 == obs_4
|
||||
self._eq_no_pid(last_six_observations[0], last_six_observations[2])
|
||||
and self._eq_no_pid(last_six_observations[0], last_six_observations[4])
|
||||
# obs_1 == obs_3 == obs_5
|
||||
and self._eq_no_pid(last_six_observations[1], last_six_observations[3])
|
||||
and self._eq_no_pid(last_six_observations[1], last_six_observations[5])
|
||||
)
|
||||
|
||||
if actions_equal and observations_equal:
|
||||
logger.warning('Action, Observation pattern detected')
|
||||
return True
|
||||
return False
|
||||
|
||||
def _eq_no_pid(self, obj1, obj2):
|
||||
if isinstance(obj1, CmdOutputObservation) and isinstance(
|
||||
obj2, CmdOutputObservation
|
||||
):
|
||||
# for loop detection, ignore command_id, which is the pid
|
||||
return obj1.command == obj2.command and obj1.exit_code == obj2.exit_code
|
||||
else:
|
||||
# this is the default comparison
|
||||
return obj1 == obj2
|
||||
762
openhands/core/config.py
Normal file
762
openhands/core/config.py
Normal file
@@ -0,0 +1,762 @@
|
||||
import argparse
|
||||
import os
|
||||
import pathlib
|
||||
import platform
|
||||
import uuid
|
||||
from dataclasses import dataclass, field, fields, is_dataclass
|
||||
from enum import Enum
|
||||
from types import UnionType
|
||||
from typing import Any, ClassVar, MutableMapping, get_args, get_origin
|
||||
|
||||
import toml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from openhands.core import logger
|
||||
from openhands.core.utils import Singleton
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
LLM_SENSITIVE_FIELDS = ['api_key', 'aws_access_key_id', 'aws_secret_access_key']
|
||||
_DEFAULT_AGENT = 'CodeActAgent'
|
||||
_MAX_ITERATIONS = 100
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMConfig:
|
||||
"""Configuration for the LLM model.
|
||||
|
||||
Attributes:
|
||||
model: The model to use.
|
||||
api_key: The API key to use.
|
||||
base_url: The base URL for the API. This is necessary for local LLMs. It is also used for Azure embeddings.
|
||||
api_version: The version of the API.
|
||||
embedding_model: The embedding model to use.
|
||||
embedding_base_url: The base URL for the embedding API.
|
||||
embedding_deployment_name: The name of the deployment for the embedding API. This is used for Azure OpenAI.
|
||||
aws_access_key_id: The AWS access key ID.
|
||||
aws_secret_access_key: The AWS secret access key.
|
||||
aws_region_name: The AWS region name.
|
||||
num_retries: The number of retries to attempt.
|
||||
retry_multiplier: The multiplier for the exponential backoff.
|
||||
retry_min_wait: The minimum time to wait between retries, in seconds. This is exponential backoff minimum. For models with very low limits, this can be set to 15-20.
|
||||
retry_max_wait: The maximum time to wait between retries, in seconds. This is exponential backoff maximum.
|
||||
timeout: The timeout for the API.
|
||||
max_message_chars: The approximate max number of characters in the content of an event included in the prompt to the LLM. Larger observations are truncated.
|
||||
temperature: The temperature for the API.
|
||||
top_p: The top p for the API.
|
||||
custom_llm_provider: The custom LLM provider to use. This is undocumented in openhands, and normally not used. It is documented on the litellm side.
|
||||
max_input_tokens: The maximum number of input tokens. Note that this is currently unused, and the value at runtime is actually the total tokens in OpenAI (e.g. 128,000 tokens for GPT-4).
|
||||
max_output_tokens: The maximum number of output tokens. This is sent to the LLM.
|
||||
input_cost_per_token: The cost per input token. This will available in logs for the user to check.
|
||||
output_cost_per_token: The cost per output token. This will available in logs for the user to check.
|
||||
ollama_base_url: The base URL for the OLLAMA API.
|
||||
drop_params: Drop any unmapped (unsupported) params without causing an exception.
|
||||
"""
|
||||
|
||||
model: str = 'gpt-4o'
|
||||
api_key: str | None = None
|
||||
base_url: str | None = None
|
||||
api_version: str | None = None
|
||||
embedding_model: str = 'local'
|
||||
embedding_base_url: str | None = None
|
||||
embedding_deployment_name: str | None = None
|
||||
aws_access_key_id: str | None = None
|
||||
aws_secret_access_key: str | None = None
|
||||
aws_region_name: str | None = None
|
||||
num_retries: int = 10
|
||||
retry_multiplier: float = 2
|
||||
retry_min_wait: int = 3
|
||||
retry_max_wait: int = 300
|
||||
timeout: int | None = None
|
||||
max_message_chars: int = 10_000 # maximum number of characters in an observation's content when sent to the llm
|
||||
temperature: float = 0
|
||||
top_p: float = 0.5
|
||||
custom_llm_provider: str | None = None
|
||||
max_input_tokens: int | None = None
|
||||
max_output_tokens: int | None = None
|
||||
input_cost_per_token: float | None = None
|
||||
output_cost_per_token: float | None = None
|
||||
ollama_base_url: str | None = None
|
||||
drop_params: bool | None = None
|
||||
|
||||
def defaults_to_dict(self) -> dict:
|
||||
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
|
||||
result = {}
|
||||
for f in fields(self):
|
||||
result[f.name] = get_field_info(f)
|
||||
return result
|
||||
|
||||
def __str__(self):
|
||||
attr_str = []
|
||||
for f in fields(self):
|
||||
attr_name = f.name
|
||||
attr_value = getattr(self, f.name)
|
||||
|
||||
if attr_name in LLM_SENSITIVE_FIELDS:
|
||||
attr_value = '******' if attr_value else None
|
||||
|
||||
attr_str.append(f'{attr_name}={repr(attr_value)}')
|
||||
|
||||
return f"LLMConfig({', '.join(attr_str)})"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
def to_safe_dict(self):
|
||||
"""Return a dict with the sensitive fields replaced with ******."""
|
||||
ret = self.__dict__.copy()
|
||||
for k, v in ret.items():
|
||||
if k in LLM_SENSITIVE_FIELDS:
|
||||
ret[k] = '******' if v else None
|
||||
return ret
|
||||
|
||||
def set_missing_attributes(self):
|
||||
"""Set any missing attributes to their default values."""
|
||||
for field_name, field_obj in self.__dataclass_fields__.items():
|
||||
if not hasattr(self, field_name):
|
||||
setattr(self, field_name, field_obj.default)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentConfig:
|
||||
"""Configuration for the agent.
|
||||
|
||||
Attributes:
|
||||
memory_enabled: Whether long-term memory (embeddings) is enabled.
|
||||
memory_max_threads: The maximum number of threads indexing at the same time for embeddings.
|
||||
llm_config: The name of the llm config to use. If specified, this will override global llm config.
|
||||
"""
|
||||
|
||||
memory_enabled: bool = False
|
||||
memory_max_threads: int = 2
|
||||
llm_config: str | None = None
|
||||
|
||||
def defaults_to_dict(self) -> dict:
|
||||
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
|
||||
result = {}
|
||||
for f in fields(self):
|
||||
result[f.name] = get_field_info(f)
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class SecurityConfig(metaclass=Singleton):
|
||||
"""Configuration for security related functionalities.
|
||||
|
||||
Attributes:
|
||||
confirmation_mode: Whether to enable confirmation mode.
|
||||
security_analyzer: The security analyzer to use.
|
||||
"""
|
||||
|
||||
confirmation_mode: bool = False
|
||||
security_analyzer: str | None = None
|
||||
|
||||
def defaults_to_dict(self) -> dict:
|
||||
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
|
||||
dict = {}
|
||||
for f in fields(self):
|
||||
dict[f.name] = get_field_info(f)
|
||||
return dict
|
||||
|
||||
def __str__(self):
|
||||
attr_str = []
|
||||
for f in fields(self):
|
||||
attr_name = f.name
|
||||
attr_value = getattr(self, f.name)
|
||||
|
||||
attr_str.append(f'{attr_name}={repr(attr_value)}')
|
||||
|
||||
return f"SecurityConfig({', '.join(attr_str)})"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SandboxConfig(metaclass=Singleton):
|
||||
"""Configuration for the sandbox.
|
||||
|
||||
Attributes:
|
||||
api_hostname: The hostname for the EventStream Runtime API.
|
||||
container_image: The container image to use for the sandbox.
|
||||
user_id: The user ID for the sandbox.
|
||||
timeout: The timeout for the sandbox.
|
||||
enable_auto_lint: Whether to enable auto-lint.
|
||||
use_host_network: Whether to use the host network.
|
||||
initialize_plugins: Whether to initialize plugins.
|
||||
od_runtime_extra_deps: The extra dependencies to install in the runtime image (typically used for evaluation).
|
||||
This will be rendered into the end of the Dockerfile that builds the runtime image.
|
||||
It can contain any valid shell commands (e.g., pip install numpy).
|
||||
The path to the interpreter is available as $OD_INTERPRETER_PATH,
|
||||
which can be used to install dependencies for the OD-specific Python interpreter.
|
||||
od_runtime_startup_env_vars: The environment variables to set at the launch of the runtime.
|
||||
This is a dictionary of key-value pairs.
|
||||
This is useful for setting environment variables that are needed by the runtime.
|
||||
For example, for specifying the base url of website for browsergym evaluation.
|
||||
browsergym_eval_env: The BrowserGym environment to use for evaluation.
|
||||
Default is None for general purpose browsing. Check evaluation/miniwob and evaluation/webarena for examples.
|
||||
"""
|
||||
|
||||
api_hostname: str = 'localhost'
|
||||
container_image: str = 'nikolaik/python-nodejs:python3.11-nodejs22' # default to nikolaik/python-nodejs:python3.11-nodejs22 for eventstream runtime
|
||||
user_id: int = os.getuid() if hasattr(os, 'getuid') else 1000
|
||||
timeout: int = 120
|
||||
enable_auto_lint: bool = (
|
||||
False # once enabled, OpenHands would lint files after editing
|
||||
)
|
||||
use_host_network: bool = False
|
||||
initialize_plugins: bool = True
|
||||
od_runtime_extra_deps: str | None = None
|
||||
od_runtime_startup_env_vars: dict[str, str] = field(default_factory=dict)
|
||||
browsergym_eval_env: str | None = None
|
||||
|
||||
def defaults_to_dict(self) -> dict:
|
||||
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
|
||||
dict = {}
|
||||
for f in fields(self):
|
||||
dict[f.name] = get_field_info(f)
|
||||
return dict
|
||||
|
||||
def __str__(self):
|
||||
attr_str = []
|
||||
for f in fields(self):
|
||||
attr_name = f.name
|
||||
attr_value = getattr(self, f.name)
|
||||
|
||||
attr_str.append(f'{attr_name}={repr(attr_value)}')
|
||||
|
||||
return f"SandboxConfig({', '.join(attr_str)})"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
class UndefinedString(str, Enum):
|
||||
UNDEFINED = 'UNDEFINED'
|
||||
|
||||
|
||||
@dataclass
|
||||
class AppConfig(metaclass=Singleton):
|
||||
"""Configuration for the app.
|
||||
|
||||
Attributes:
|
||||
llms: A dictionary of name -> LLM configuration. Default config is under 'llm' key.
|
||||
agents: A dictionary of name -> Agent configuration. Default config is under 'agent' key.
|
||||
default_agent: The name of the default agent to use.
|
||||
sandbox: The sandbox configuration.
|
||||
runtime: The runtime environment.
|
||||
file_store: The file store to use.
|
||||
file_store_path: The path to the file store.
|
||||
workspace_base: The base path for the workspace. Defaults to ./workspace as an absolute path.
|
||||
workspace_mount_path: The path to mount the workspace. This is set to the workspace base by default.
|
||||
workspace_mount_path_in_sandbox: The path to mount the workspace in the sandbox. Defaults to /workspace.
|
||||
workspace_mount_rewrite: The path to rewrite the workspace mount path to.
|
||||
cache_dir: The path to the cache directory. Defaults to /tmp/cache.
|
||||
run_as_openhands: Whether to run as openhands.
|
||||
max_iterations: The maximum number of iterations.
|
||||
max_budget_per_task: The maximum budget allowed per task, beyond which the agent will stop.
|
||||
e2b_api_key: The E2B API key.
|
||||
disable_color: Whether to disable color. For terminals that don't support color.
|
||||
debug: Whether to enable debugging.
|
||||
enable_cli_session: Whether to enable saving and restoring the session when run from CLI.
|
||||
file_uploads_max_file_size_mb: Maximum file size for uploads in megabytes. 0 means no limit.
|
||||
file_uploads_restrict_file_types: Whether to restrict file types for file uploads. Defaults to False.
|
||||
file_uploads_allowed_extensions: List of allowed file extensions for uploads. ['.*'] means all extensions are allowed.
|
||||
"""
|
||||
|
||||
llms: dict[str, LLMConfig] = field(default_factory=dict)
|
||||
agents: dict = field(default_factory=dict)
|
||||
default_agent: str = _DEFAULT_AGENT
|
||||
sandbox: SandboxConfig = field(default_factory=SandboxConfig)
|
||||
security: SecurityConfig = field(default_factory=SecurityConfig)
|
||||
runtime: str = 'eventstream'
|
||||
file_store: str = 'memory'
|
||||
file_store_path: str = '/tmp/file_store'
|
||||
# TODO: clean up workspace path after the removal of ServerRuntime
|
||||
workspace_base: str = os.path.join(os.getcwd(), 'workspace')
|
||||
workspace_mount_path: str | None = (
|
||||
UndefinedString.UNDEFINED # this path should always be set when config is fully loaded
|
||||
) # when set to None, do not mount the workspace
|
||||
workspace_mount_path_in_sandbox: str = '/workspace'
|
||||
workspace_mount_rewrite: str | None = None
|
||||
cache_dir: str = '/tmp/cache'
|
||||
run_as_openhands: bool = True
|
||||
max_iterations: int = _MAX_ITERATIONS
|
||||
max_budget_per_task: float | None = None
|
||||
e2b_api_key: str = ''
|
||||
disable_color: bool = False
|
||||
jwt_secret: str = uuid.uuid4().hex
|
||||
debug: bool = False
|
||||
enable_cli_session: bool = False
|
||||
file_uploads_max_file_size_mb: int = 0
|
||||
file_uploads_restrict_file_types: bool = False
|
||||
file_uploads_allowed_extensions: list[str] = field(default_factory=lambda: ['.*'])
|
||||
|
||||
defaults_dict: ClassVar[dict] = {}
|
||||
|
||||
def get_llm_config(self, name='llm') -> LLMConfig:
|
||||
"""Llm is the name for default config (for backward compatibility prior to 0.8)"""
|
||||
if name in self.llms:
|
||||
return self.llms[name]
|
||||
if name is not None and name != 'llm':
|
||||
logger.openhands_logger.warning(
|
||||
f'llm config group {name} not found, using default config'
|
||||
)
|
||||
if 'llm' not in self.llms:
|
||||
self.llms['llm'] = LLMConfig()
|
||||
return self.llms['llm']
|
||||
|
||||
def set_llm_config(self, value: LLMConfig, name='llm'):
|
||||
self.llms[name] = value
|
||||
|
||||
def get_agent_config(self, name='agent') -> AgentConfig:
|
||||
"""Agent is the name for default config (for backward compability prior to 0.8)"""
|
||||
if name in self.agents:
|
||||
return self.agents[name]
|
||||
if 'agent' not in self.agents:
|
||||
self.agents['agent'] = AgentConfig()
|
||||
return self.agents['agent']
|
||||
|
||||
def set_agent_config(self, value: AgentConfig, name='agent'):
|
||||
self.agents[name] = value
|
||||
|
||||
def get_agent_to_llm_config_map(self) -> dict[str, LLMConfig]:
|
||||
"""Get a map of agent names to llm configs."""
|
||||
return {name: self.get_llm_config_from_agent(name) for name in self.agents}
|
||||
|
||||
def get_llm_config_from_agent(self, name='agent') -> LLMConfig:
|
||||
agent_config: AgentConfig = self.get_agent_config(name)
|
||||
llm_config_name = agent_config.llm_config
|
||||
return self.get_llm_config(llm_config_name)
|
||||
|
||||
def get_agent_configs(self) -> dict[str, AgentConfig]:
|
||||
return self.agents
|
||||
|
||||
def __post_init__(self):
|
||||
"""Post-initialization hook, called when the instance is created with only default values."""
|
||||
AppConfig.defaults_dict = self.defaults_to_dict()
|
||||
|
||||
def defaults_to_dict(self) -> dict:
|
||||
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
|
||||
result = {}
|
||||
for f in fields(self):
|
||||
field_value = getattr(self, f.name)
|
||||
|
||||
# dataclasses compute their defaults themselves
|
||||
if is_dataclass(type(field_value)):
|
||||
result[f.name] = field_value.defaults_to_dict()
|
||||
else:
|
||||
result[f.name] = get_field_info(f)
|
||||
return result
|
||||
|
||||
def __str__(self):
|
||||
attr_str = []
|
||||
for f in fields(self):
|
||||
attr_name = f.name
|
||||
attr_value = getattr(self, f.name)
|
||||
|
||||
if attr_name in [
|
||||
'e2b_api_key',
|
||||
'github_token',
|
||||
'jwt_secret',
|
||||
]:
|
||||
attr_value = '******' if attr_value else None
|
||||
|
||||
attr_str.append(f'{attr_name}={repr(attr_value)}')
|
||||
|
||||
return f"AppConfig({', '.join(attr_str)}"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def get_field_info(f):
|
||||
"""Extract information about a dataclass field: type, optional, and default.
|
||||
|
||||
Args:
|
||||
f: The field to extract information from.
|
||||
|
||||
Returns: A dict with the field's type, whether it's optional, and its default value.
|
||||
"""
|
||||
field_type = f.type
|
||||
optional = False
|
||||
|
||||
# for types like str | None, find the non-None type and set optional to True
|
||||
# this is useful for the frontend to know if a field is optional
|
||||
# and to show the correct type in the UI
|
||||
# Note: this only works for UnionTypes with None as one of the types
|
||||
if get_origin(field_type) is UnionType:
|
||||
types = get_args(field_type)
|
||||
non_none_arg = next((t for t in types if t is not type(None)), None)
|
||||
if non_none_arg is not None:
|
||||
field_type = non_none_arg
|
||||
optional = True
|
||||
|
||||
# type name in a pretty format
|
||||
type_name = (
|
||||
field_type.__name__ if hasattr(field_type, '__name__') else str(field_type)
|
||||
)
|
||||
|
||||
# default is always present
|
||||
default = f.default
|
||||
|
||||
# return a schema with the useful info for frontend
|
||||
return {'type': type_name.lower(), 'optional': optional, 'default': default}
|
||||
|
||||
|
||||
def load_from_env(cfg: AppConfig, env_or_toml_dict: dict | MutableMapping[str, str]):
|
||||
"""Reads the env-style vars and sets config attributes based on env vars or a config.toml dict.
|
||||
Compatibility with vars like LLM_BASE_URL, AGENT_MEMORY_ENABLED, SANDBOX_TIMEOUT and others.
|
||||
|
||||
Args:
|
||||
cfg: The AppConfig object to set attributes on.
|
||||
env_or_toml_dict: The environment variables or a config.toml dict.
|
||||
"""
|
||||
|
||||
def get_optional_type(union_type: UnionType) -> Any:
|
||||
"""Returns the non-None type from a Union."""
|
||||
types = get_args(union_type)
|
||||
return next((t for t in types if t is not type(None)), None)
|
||||
|
||||
# helper function to set attributes based on env vars
|
||||
def set_attr_from_env(sub_config: Any, prefix=''):
|
||||
"""Set attributes of a config dataclass based on environment variables."""
|
||||
for field_name, field_type in sub_config.__annotations__.items():
|
||||
# compute the expected env var name from the prefix and field name
|
||||
# e.g. LLM_BASE_URL
|
||||
env_var_name = (prefix + field_name).upper()
|
||||
|
||||
if is_dataclass(field_type):
|
||||
# nested dataclass
|
||||
nested_sub_config = getattr(sub_config, field_name)
|
||||
set_attr_from_env(nested_sub_config, prefix=field_name + '_')
|
||||
elif env_var_name in env_or_toml_dict:
|
||||
# convert the env var to the correct type and set it
|
||||
value = env_or_toml_dict[env_var_name]
|
||||
|
||||
# skip empty config values (fall back to default)
|
||||
if not value:
|
||||
continue
|
||||
|
||||
try:
|
||||
# if it's an optional type, get the non-None type
|
||||
if get_origin(field_type) is UnionType:
|
||||
field_type = get_optional_type(field_type)
|
||||
|
||||
# Attempt to cast the env var to type hinted in the dataclass
|
||||
if field_type is bool:
|
||||
cast_value = str(value).lower() in ['true', '1']
|
||||
else:
|
||||
cast_value = field_type(value)
|
||||
setattr(sub_config, field_name, cast_value)
|
||||
except (ValueError, TypeError):
|
||||
logger.openhands_logger.error(
|
||||
f'Error setting env var {env_var_name}={value}: check that the value is of the right type'
|
||||
)
|
||||
|
||||
# Start processing from the root of the config object
|
||||
set_attr_from_env(cfg)
|
||||
|
||||
# load default LLM config from env
|
||||
default_llm_config = cfg.get_llm_config()
|
||||
set_attr_from_env(default_llm_config, 'LLM_')
|
||||
# load default agent config from env
|
||||
default_agent_config = cfg.get_agent_config()
|
||||
set_attr_from_env(default_agent_config, 'AGENT_')
|
||||
|
||||
|
||||
def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
|
||||
"""Load the config from the toml file. Supports both styles of config vars.
|
||||
|
||||
Args:
|
||||
cfg: The AppConfig object to update attributes of.
|
||||
toml_file: The path to the toml file. Defaults to 'config.toml'.
|
||||
"""
|
||||
# try to read the config.toml file into the config object
|
||||
try:
|
||||
with open(toml_file, 'r', encoding='utf-8') as toml_contents:
|
||||
toml_config = toml.load(toml_contents)
|
||||
except FileNotFoundError:
|
||||
return
|
||||
except toml.TomlDecodeError as e:
|
||||
logger.openhands_logger.warning(
|
||||
f'Cannot parse config from toml, toml values have not been applied.\nError: {e}',
|
||||
exc_info=False,
|
||||
)
|
||||
return
|
||||
|
||||
# if there was an exception or core is not in the toml, try to use the old-style toml
|
||||
if 'core' not in toml_config:
|
||||
# re-use the env loader to set the config from env-style vars
|
||||
load_from_env(cfg, toml_config)
|
||||
return
|
||||
|
||||
core_config = toml_config['core']
|
||||
|
||||
# load llm configs and agent configs
|
||||
for key, value in toml_config.items():
|
||||
if isinstance(value, dict):
|
||||
try:
|
||||
if key is not None and key.lower() == 'agent':
|
||||
logger.openhands_logger.info(
|
||||
'Attempt to load default agent config from config toml'
|
||||
)
|
||||
non_dict_fields = {
|
||||
k: v for k, v in value.items() if not isinstance(v, dict)
|
||||
}
|
||||
agent_config = AgentConfig(**non_dict_fields)
|
||||
cfg.set_agent_config(agent_config, 'agent')
|
||||
for nested_key, nested_value in value.items():
|
||||
if isinstance(nested_value, dict):
|
||||
logger.openhands_logger.info(
|
||||
f'Attempt to load group {nested_key} from config toml as agent config'
|
||||
)
|
||||
agent_config = AgentConfig(**nested_value)
|
||||
cfg.set_agent_config(agent_config, nested_key)
|
||||
elif key is not None and key.lower() == 'llm':
|
||||
logger.openhands_logger.info(
|
||||
'Attempt to load default LLM config from config toml'
|
||||
)
|
||||
non_dict_fields = {
|
||||
k: v for k, v in value.items() if not isinstance(v, dict)
|
||||
}
|
||||
llm_config = LLMConfig(**non_dict_fields)
|
||||
cfg.set_llm_config(llm_config, 'llm')
|
||||
for nested_key, nested_value in value.items():
|
||||
if isinstance(nested_value, dict):
|
||||
logger.openhands_logger.info(
|
||||
f'Attempt to load group {nested_key} from config toml as llm config'
|
||||
)
|
||||
llm_config = LLMConfig(**nested_value)
|
||||
cfg.set_llm_config(llm_config, nested_key)
|
||||
elif not key.startswith('sandbox') and key.lower() != 'core':
|
||||
logger.openhands_logger.warning(
|
||||
f'Unknown key in {toml_file}: "{key}"'
|
||||
)
|
||||
except (TypeError, KeyError) as e:
|
||||
logger.openhands_logger.warning(
|
||||
f'Cannot parse config from toml, toml values have not been applied.\n Error: {e}',
|
||||
exc_info=False,
|
||||
)
|
||||
else:
|
||||
logger.openhands_logger.warning(f'Unknown key in {toml_file}: "{key}')
|
||||
|
||||
try:
|
||||
# set sandbox config from the toml file
|
||||
sandbox_config = cfg.sandbox
|
||||
|
||||
# migrate old sandbox configs from [core] section to sandbox config
|
||||
keys_to_migrate = [key for key in core_config if key.startswith('sandbox_')]
|
||||
for key in keys_to_migrate:
|
||||
new_key = key.replace('sandbox_', '')
|
||||
if new_key in sandbox_config.__annotations__:
|
||||
# read the key in sandbox and remove it from core
|
||||
setattr(sandbox_config, new_key, core_config.pop(key))
|
||||
else:
|
||||
logger.openhands_logger.warning(f'Unknown sandbox config: {key}')
|
||||
|
||||
# the new style values override the old style values
|
||||
if 'sandbox' in toml_config:
|
||||
sandbox_config = SandboxConfig(**toml_config['sandbox'])
|
||||
|
||||
# update the config object with the new values
|
||||
AppConfig(sandbox=sandbox_config, **core_config)
|
||||
except (TypeError, KeyError) as e:
|
||||
logger.openhands_logger.warning(
|
||||
f'Cannot parse config from toml, toml values have not been applied.\nError: {e}',
|
||||
exc_info=False,
|
||||
)
|
||||
|
||||
|
||||
def finalize_config(cfg: AppConfig):
|
||||
"""More tweaks to the config after it's been loaded."""
|
||||
# Set workspace_mount_path if not set by the user
|
||||
if cfg.workspace_mount_path is UndefinedString.UNDEFINED:
|
||||
cfg.workspace_mount_path = os.path.abspath(cfg.workspace_base)
|
||||
cfg.workspace_base = os.path.abspath(cfg.workspace_base)
|
||||
|
||||
if cfg.workspace_mount_rewrite: # and not config.workspace_mount_path:
|
||||
# TODO why do we need to check if workspace_mount_path is None?
|
||||
base = cfg.workspace_base or os.getcwd()
|
||||
parts = cfg.workspace_mount_rewrite.split(':')
|
||||
cfg.workspace_mount_path = base.replace(parts[0], parts[1])
|
||||
|
||||
for llm in cfg.llms.values():
|
||||
if llm.embedding_base_url is None:
|
||||
llm.embedding_base_url = llm.base_url
|
||||
|
||||
if cfg.sandbox.use_host_network and platform.system() == 'Darwin':
|
||||
logger.openhands_logger.warning(
|
||||
'Please upgrade to Docker Desktop 4.29.0 or later to use host network mode on macOS. '
|
||||
'See https://github.com/docker/roadmap/issues/238#issuecomment-2044688144 for more information.'
|
||||
)
|
||||
|
||||
# make sure cache dir exists
|
||||
if cfg.cache_dir:
|
||||
pathlib.Path(cfg.cache_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# Utility function for command line --group argument
|
||||
def get_llm_config_arg(
|
||||
llm_config_arg: str, toml_file: str = 'config.toml'
|
||||
) -> LLMConfig | None:
|
||||
"""Get a group of llm settings from the config file.
|
||||
|
||||
A group in config.toml can look like this:
|
||||
|
||||
```
|
||||
[llm.gpt-3.5-for-eval]
|
||||
model = 'gpt-3.5-turbo'
|
||||
api_key = '...'
|
||||
temperature = 0.5
|
||||
num_retries = 10
|
||||
...
|
||||
```
|
||||
|
||||
The user-defined group name, like "gpt-3.5-for-eval", is the argument to this function. The function will load the LLMConfig object
|
||||
with the settings of this group, from the config file, and set it as the LLMConfig object for the app.
|
||||
|
||||
Note that the group must be under "llm" group, or in other words, the group name must start with "llm.".
|
||||
|
||||
Args:
|
||||
llm_config_arg: The group of llm settings to get from the config.toml file.
|
||||
|
||||
Returns:
|
||||
LLMConfig: The LLMConfig object with the settings from the config file.
|
||||
"""
|
||||
# keep only the name, just in case
|
||||
llm_config_arg = llm_config_arg.strip('[]')
|
||||
|
||||
# truncate the prefix, just in case
|
||||
if llm_config_arg.startswith('llm.'):
|
||||
llm_config_arg = llm_config_arg[4:]
|
||||
|
||||
logger.openhands_logger.info(f'Loading llm config from {llm_config_arg}')
|
||||
|
||||
# load the toml file
|
||||
try:
|
||||
with open(toml_file, 'r', encoding='utf-8') as toml_contents:
|
||||
toml_config = toml.load(toml_contents)
|
||||
except FileNotFoundError as e:
|
||||
logger.openhands_logger.error(f'Config file not found: {e}')
|
||||
return None
|
||||
except toml.TomlDecodeError as e:
|
||||
logger.openhands_logger.error(
|
||||
f'Cannot parse llm group from {llm_config_arg}. Exception: {e}'
|
||||
)
|
||||
return None
|
||||
|
||||
# update the llm config with the specified section
|
||||
if 'llm' in toml_config and llm_config_arg in toml_config['llm']:
|
||||
return LLMConfig(**toml_config['llm'][llm_config_arg])
|
||||
logger.openhands_logger.debug(f'Loading from toml failed for {llm_config_arg}')
|
||||
return None
|
||||
|
||||
|
||||
# Command line arguments
|
||||
def get_parser() -> argparse.ArgumentParser:
|
||||
"""Get the parser for the command line arguments."""
|
||||
parser = argparse.ArgumentParser(description='Run an agent with a specific task')
|
||||
parser.add_argument(
|
||||
'-d',
|
||||
'--directory',
|
||||
type=str,
|
||||
help='The working directory for the agent',
|
||||
)
|
||||
parser.add_argument(
|
||||
'-t',
|
||||
'--task',
|
||||
type=str,
|
||||
default='',
|
||||
help='The task for the agent to perform',
|
||||
)
|
||||
parser.add_argument(
|
||||
'-f',
|
||||
'--file',
|
||||
type=str,
|
||||
help='Path to a file containing the task. Overrides -t if both are provided.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'-c',
|
||||
'--agent-cls',
|
||||
default=_DEFAULT_AGENT,
|
||||
type=str,
|
||||
help='Name of the default agent to use',
|
||||
)
|
||||
parser.add_argument(
|
||||
'-i',
|
||||
'--max-iterations',
|
||||
default=_MAX_ITERATIONS,
|
||||
type=int,
|
||||
help='The maximum number of iterations to run the agent',
|
||||
)
|
||||
parser.add_argument(
|
||||
'-b',
|
||||
'--max-budget-per-task',
|
||||
type=float,
|
||||
help='The maximum budget allowed per task, beyond which the agent will stop.',
|
||||
)
|
||||
# --eval configs are for evaluations only
|
||||
parser.add_argument(
|
||||
'--eval-output-dir',
|
||||
default='evaluation/evaluation_outputs/outputs',
|
||||
type=str,
|
||||
help='The directory to save evaluation output',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--eval-n-limit',
|
||||
default=None,
|
||||
type=int,
|
||||
help='The number of instances to evaluate',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--eval-num-workers',
|
||||
default=4,
|
||||
type=int,
|
||||
help='The number of workers to use for evaluation',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--eval-note',
|
||||
default=None,
|
||||
type=str,
|
||||
help='The note to add to the evaluation directory',
|
||||
)
|
||||
parser.add_argument(
|
||||
'-l',
|
||||
'--llm-config',
|
||||
default=None,
|
||||
type=str,
|
||||
help='Replace default LLM ([llm] section in config.toml) config with the specified LLM config, e.g. "llama3" for [llm.llama3] section in config.toml',
|
||||
)
|
||||
parser.add_argument(
|
||||
'-n',
|
||||
'--name',
|
||||
default='default',
|
||||
type=str,
|
||||
help='Name for the session',
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def parse_arguments() -> argparse.Namespace:
|
||||
"""Parse the command line arguments."""
|
||||
parser = get_parser()
|
||||
parsed_args, _ = parser.parse_known_args()
|
||||
return parsed_args
|
||||
|
||||
|
||||
def load_app_config(set_logging_levels: bool = True) -> AppConfig:
|
||||
"""Load the configuration from the config.toml file and environment variables.
|
||||
|
||||
Args:
|
||||
set_logger_levels: Whether to set the global variables for logging levels.
|
||||
"""
|
||||
config = AppConfig()
|
||||
load_from_toml(config)
|
||||
load_from_env(config, os.environ)
|
||||
finalize_config(config)
|
||||
if set_logging_levels:
|
||||
logger.DEBUG = config.debug
|
||||
logger.DISABLE_COLOR_PRINTING = config.disable_color
|
||||
return config
|
||||
1
openhands/core/const/guide_url.py
Normal file
1
openhands/core/const/guide_url.py
Normal file
@@ -0,0 +1 @@
|
||||
TROUBLESHOOTING_URL = 'https://docs.all-hands.dev/modules/usage/troubleshooting'
|
||||
2
openhands/core/download.py
Normal file
2
openhands/core/download.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Run this file to trigger a model download
|
||||
import agenthub # noqa F401 (we import this to get the agents registered)
|
||||
74
openhands/core/exceptions.py
Normal file
74
openhands/core/exceptions.py
Normal file
@@ -0,0 +1,74 @@
|
||||
class AgentNoInstructionError(Exception):
|
||||
def __init__(self, message='Instruction must be provided'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentEventTypeError(Exception):
|
||||
def __init__(self, message='Event must be a dictionary'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentAlreadyRegisteredError(Exception):
|
||||
def __init__(self, name=None):
|
||||
if name is not None:
|
||||
message = f"Agent class already registered under '{name}'"
|
||||
else:
|
||||
message = 'Agent class already registered'
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentNotRegisteredError(Exception):
|
||||
def __init__(self, name=None):
|
||||
if name is not None:
|
||||
message = f"No agent class registered under '{name}'"
|
||||
else:
|
||||
message = 'No agent class registered'
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class TaskInvalidStateError(Exception):
|
||||
def __init__(self, state=None):
|
||||
if state is not None:
|
||||
message = f'Invalid state {state}'
|
||||
else:
|
||||
message = 'Invalid state'
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BrowserInitException(Exception):
|
||||
def __init__(self, message='Failed to initialize browser environment'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BrowserUnavailableException(Exception):
|
||||
def __init__(
|
||||
self,
|
||||
message='Browser environment is not available, please check if has been initialized',
|
||||
):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
# This exception gets sent back to the LLM
|
||||
# It might be malformed JSON
|
||||
class LLMMalformedActionError(Exception):
|
||||
def __init__(self, message='Malformed response'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
# This exception gets sent back to the LLM
|
||||
# For some reason, the agent did not return an action
|
||||
class LLMNoActionError(Exception):
|
||||
def __init__(self, message='Agent must return an action'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
# This exception gets sent back to the LLM
|
||||
# The LLM output did not include an action, or the action was not the expected type
|
||||
class LLMResponseError(Exception):
|
||||
def __init__(self, message='Failed to retrieve action from LLM response'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class UserCancelledError(Exception):
|
||||
def __init__(self, message='User cancelled the request'):
|
||||
super().__init__(message)
|
||||
255
openhands/core/logger.py
Normal file
255
openhands/core/logger.py
Normal file
@@ -0,0 +1,255 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Literal, Mapping
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
DISABLE_COLOR_PRINTING = False
|
||||
DEBUG = os.getenv('DEBUG', 'False').lower() in ['true', '1', 'yes']
|
||||
LOG_TO_FILE = os.getenv('LOG_TO_FILE', 'False').lower() in ['true', '1', 'yes']
|
||||
|
||||
ColorType = Literal[
|
||||
'red',
|
||||
'green',
|
||||
'yellow',
|
||||
'blue',
|
||||
'magenta',
|
||||
'cyan',
|
||||
'light_grey',
|
||||
'dark_grey',
|
||||
'light_red',
|
||||
'light_green',
|
||||
'light_yellow',
|
||||
'light_blue',
|
||||
'light_magenta',
|
||||
'light_cyan',
|
||||
'white',
|
||||
]
|
||||
|
||||
LOG_COLORS: Mapping[str, ColorType] = {
|
||||
'ACTION': 'green',
|
||||
'USER_ACTION': 'light_red',
|
||||
'OBSERVATION': 'yellow',
|
||||
'USER_OBSERVATION': 'light_green',
|
||||
'DETAIL': 'cyan',
|
||||
'ERROR': 'red',
|
||||
'PLAN': 'light_magenta',
|
||||
}
|
||||
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
|
||||
def format(self, record):
|
||||
msg_type = record.__dict__.get('msg_type')
|
||||
event_source = record.__dict__.get('event_source')
|
||||
if event_source:
|
||||
new_msg_type = f'{event_source.upper()}_{msg_type}'
|
||||
if new_msg_type in LOG_COLORS:
|
||||
msg_type = new_msg_type
|
||||
if msg_type in LOG_COLORS and not DISABLE_COLOR_PRINTING:
|
||||
msg_type_color = colored(msg_type, LOG_COLORS[msg_type])
|
||||
msg = colored(record.msg, LOG_COLORS[msg_type])
|
||||
time_str = colored(
|
||||
self.formatTime(record, self.datefmt), LOG_COLORS[msg_type]
|
||||
)
|
||||
name_str = colored(record.name, LOG_COLORS[msg_type])
|
||||
level_str = colored(record.levelname, LOG_COLORS[msg_type])
|
||||
if msg_type in ['ERROR'] or DEBUG:
|
||||
return f'{time_str} - {name_str}:{level_str}: {record.filename}:{record.lineno}\n{msg_type_color}\n{msg}'
|
||||
return f'{time_str} - {msg_type_color}\n{msg}'
|
||||
elif msg_type == 'STEP':
|
||||
msg = '\n\n==============\n' + record.msg + '\n'
|
||||
return f'{msg}'
|
||||
return super().format(record)
|
||||
|
||||
|
||||
console_formatter = ColoredFormatter(
|
||||
'\033[92m%(asctime)s - %(name)s:%(levelname)s\033[0m: %(filename)s:%(lineno)s - %(message)s',
|
||||
datefmt='%H:%M:%S',
|
||||
)
|
||||
|
||||
file_formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s:%(levelname)s: %(filename)s:%(lineno)s - %(message)s',
|
||||
datefmt='%H:%M:%S',
|
||||
)
|
||||
llm_formatter = logging.Formatter('%(message)s')
|
||||
|
||||
|
||||
class SensitiveDataFilter(logging.Filter):
|
||||
def filter(self, record):
|
||||
# start with attributes
|
||||
sensitive_patterns = [
|
||||
'api_key',
|
||||
'aws_access_key_id',
|
||||
'aws_secret_access_key',
|
||||
'e2b_api_key',
|
||||
'github_token',
|
||||
'jwt_secret',
|
||||
]
|
||||
|
||||
# add env var names
|
||||
env_vars = [attr.upper() for attr in sensitive_patterns]
|
||||
sensitive_patterns.extend(env_vars)
|
||||
|
||||
# and some special cases
|
||||
sensitive_patterns.append('JWT_SECRET')
|
||||
sensitive_patterns.append('LLM_API_KEY')
|
||||
sensitive_patterns.append('GITHUB_TOKEN')
|
||||
sensitive_patterns.append('SANDBOX_ENV_GITHUB_TOKEN')
|
||||
|
||||
# this also formats the message with % args
|
||||
msg = record.getMessage()
|
||||
record.args = ()
|
||||
|
||||
for attr in sensitive_patterns:
|
||||
pattern = rf"{attr}='?([\w-]+)'?"
|
||||
msg = re.sub(pattern, f"{attr}='******'", msg)
|
||||
|
||||
# passed with msg
|
||||
record.msg = msg
|
||||
return True
|
||||
|
||||
|
||||
def get_console_handler():
|
||||
"""Returns a console handler for logging."""
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.INFO)
|
||||
if DEBUG:
|
||||
console_handler.setLevel(logging.DEBUG)
|
||||
console_handler.setFormatter(console_formatter)
|
||||
return console_handler
|
||||
|
||||
|
||||
def get_file_handler(log_dir):
|
||||
"""Returns a file handler for logging."""
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
timestamp = datetime.now().strftime('%Y-%m-%d')
|
||||
file_name = f'openhands_{timestamp}.log'
|
||||
file_handler = logging.FileHandler(os.path.join(log_dir, file_name))
|
||||
if DEBUG:
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
file_handler.setFormatter(file_formatter)
|
||||
return file_handler
|
||||
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.ERROR)
|
||||
|
||||
|
||||
def log_uncaught_exceptions(ex_cls, ex, tb):
|
||||
"""Logs uncaught exceptions along with the traceback.
|
||||
|
||||
Args:
|
||||
ex_cls (type): The type of the exception.
|
||||
ex (Exception): The exception instance.
|
||||
tb (traceback): The traceback object.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
logging.error(''.join(traceback.format_tb(tb)))
|
||||
logging.error('{0}: {1}'.format(ex_cls, ex))
|
||||
|
||||
|
||||
sys.excepthook = log_uncaught_exceptions
|
||||
|
||||
openhands_logger = logging.getLogger('openhands')
|
||||
openhands_logger.setLevel(logging.INFO)
|
||||
LOG_DIR = os.path.join(
|
||||
# parent dir of openhands/core (i.e., root of the repo)
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
|
||||
'logs',
|
||||
)
|
||||
|
||||
if DEBUG:
|
||||
openhands_logger.setLevel(logging.DEBUG)
|
||||
|
||||
if LOG_TO_FILE:
|
||||
# default log to project root
|
||||
openhands_logger.info('Logging to file is enabled. Logging to %s', LOG_DIR)
|
||||
openhands_logger.addHandler(get_file_handler(LOG_DIR))
|
||||
|
||||
openhands_logger.addHandler(get_console_handler())
|
||||
openhands_logger.addFilter(SensitiveDataFilter(openhands_logger.name))
|
||||
openhands_logger.propagate = False
|
||||
openhands_logger.debug('Logging initialized')
|
||||
|
||||
|
||||
# Exclude LiteLLM from logging output
|
||||
logging.getLogger('LiteLLM').disabled = True
|
||||
logging.getLogger('LiteLLM Router').disabled = True
|
||||
logging.getLogger('LiteLLM Proxy').disabled = True
|
||||
|
||||
|
||||
class LlmFileHandler(logging.FileHandler):
|
||||
"""# LLM prompt and response logging"""
|
||||
|
||||
def __init__(self, filename, mode='a', encoding='utf-8', delay=False):
|
||||
"""Initializes an instance of LlmFileHandler.
|
||||
|
||||
Args:
|
||||
filename (str): The name of the log file.
|
||||
mode (str, optional): The file mode. Defaults to 'a'.
|
||||
encoding (str, optional): The file encoding. Defaults to None.
|
||||
delay (bool, optional): Whether to delay file opening. Defaults to False.
|
||||
"""
|
||||
self.filename = filename
|
||||
self.message_counter = 1
|
||||
if DEBUG:
|
||||
self.session = datetime.now().strftime('%y-%m-%d_%H-%M')
|
||||
else:
|
||||
self.session = 'default'
|
||||
self.log_directory = os.path.join(LOG_DIR, 'llm', self.session)
|
||||
os.makedirs(self.log_directory, exist_ok=True)
|
||||
if not DEBUG:
|
||||
# Clear the log directory if not in debug mode
|
||||
for file in os.listdir(self.log_directory):
|
||||
file_path = os.path.join(self.log_directory, file)
|
||||
try:
|
||||
os.unlink(file_path)
|
||||
except Exception as e:
|
||||
openhands_logger.error(
|
||||
'Failed to delete %s. Reason: %s', file_path, e
|
||||
)
|
||||
filename = f'{self.filename}_{self.message_counter:03}.log'
|
||||
self.baseFilename = os.path.join(self.log_directory, filename)
|
||||
super().__init__(self.baseFilename, mode, encoding, delay)
|
||||
|
||||
def emit(self, record):
|
||||
"""Emits a log record.
|
||||
|
||||
Args:
|
||||
record (logging.LogRecord): The log record to emit.
|
||||
"""
|
||||
filename = f'{self.filename}_{self.message_counter:03}.log'
|
||||
self.baseFilename = os.path.join(self.log_directory, filename)
|
||||
self.stream = self._open()
|
||||
super().emit(record)
|
||||
self.stream.close()
|
||||
openhands_logger.debug('Logging to %s', self.baseFilename)
|
||||
self.message_counter += 1
|
||||
|
||||
|
||||
def _get_llm_file_handler(name, debug_level=logging.DEBUG):
|
||||
# The 'delay' parameter, when set to True, postpones the opening of the log file
|
||||
# until the first log message is emitted.
|
||||
llm_file_handler = LlmFileHandler(name, delay=True)
|
||||
llm_file_handler.setFormatter(llm_formatter)
|
||||
llm_file_handler.setLevel(debug_level)
|
||||
return llm_file_handler
|
||||
|
||||
|
||||
def _setup_llm_logger(name, debug_level=logging.DEBUG):
|
||||
logger = logging.getLogger(name)
|
||||
logger.propagate = False
|
||||
logger.setLevel(debug_level)
|
||||
if LOG_TO_FILE:
|
||||
logger.addHandler(_get_llm_file_handler(name, debug_level))
|
||||
return logger
|
||||
|
||||
|
||||
llm_prompt_logger = _setup_llm_logger('prompt', logging.DEBUG)
|
||||
llm_response_logger = _setup_llm_logger('response', logging.DEBUG)
|
||||
259
openhands/core/main.py
Normal file
259
openhands/core/main.py
Normal file
@@ -0,0 +1,259 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import sys
|
||||
import uuid
|
||||
from typing import Callable, Protocol, Type
|
||||
|
||||
import agenthub # noqa F401 (we import this to get the agents registered)
|
||||
from openhands.controller import AgentController
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import (
|
||||
AppConfig,
|
||||
get_llm_config_arg,
|
||||
load_app_config,
|
||||
parse_arguments,
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import EventSource, EventStream, EventStreamSubscriber
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import AgentStateChangedObservation
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.runtime.runtime import Runtime
|
||||
from openhands.storage import get_file_store
|
||||
|
||||
|
||||
class FakeUserResponseFunc(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
state: State,
|
||||
encapsulate_solution: bool = ...,
|
||||
try_parse: Callable[[Action], str] = ...,
|
||||
) -> str: ...
|
||||
|
||||
|
||||
def read_task_from_file(file_path: str) -> str:
|
||||
"""Read task from the specified file."""
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
return file.read()
|
||||
|
||||
|
||||
def read_task_from_stdin() -> str:
|
||||
"""Read task from stdin."""
|
||||
return sys.stdin.read()
|
||||
|
||||
|
||||
async def create_runtime(
|
||||
config: AppConfig,
|
||||
sid: str | None = None,
|
||||
runtime_tools_config: dict | None = None,
|
||||
) -> Runtime:
|
||||
"""Create a runtime for the agent to run on.
|
||||
|
||||
config: The app config.
|
||||
sid: The session id.
|
||||
runtime_tools_config: (will be deprecated) The runtime tools config.
|
||||
"""
|
||||
# if sid is provided on the command line, use it as the name of the event stream
|
||||
# otherwise generate it on the basis of the configured jwt_secret
|
||||
# we can do this better, this is just so that the sid is retrieved when we want to restore the session
|
||||
session_id = sid or generate_sid(config)
|
||||
|
||||
# set up the event stream
|
||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
event_stream = EventStream(session_id, file_store)
|
||||
|
||||
# agent class
|
||||
agent_cls = agenthub.Agent.get_cls(config.default_agent)
|
||||
|
||||
# runtime and tools
|
||||
runtime_cls = get_runtime_cls(config.runtime)
|
||||
logger.info(f'Initializing runtime: {runtime_cls}')
|
||||
runtime: Runtime = runtime_cls(
|
||||
config=config,
|
||||
event_stream=event_stream,
|
||||
sid=session_id,
|
||||
plugins=agent_cls.sandbox_plugins,
|
||||
)
|
||||
await runtime.ainit()
|
||||
|
||||
return runtime
|
||||
|
||||
|
||||
async def run_controller(
|
||||
config: AppConfig,
|
||||
task_str: str,
|
||||
sid: str | None = None,
|
||||
runtime: Runtime | None = None,
|
||||
agent: Agent | None = None,
|
||||
exit_on_message: bool = False,
|
||||
fake_user_response_fn: FakeUserResponseFunc | None = None,
|
||||
headless_mode: bool = True,
|
||||
) -> State | None:
|
||||
"""Main coroutine to run the agent controller with task input flexibility.
|
||||
It's only used when you launch openhands backend directly via cmdline.
|
||||
|
||||
Args:
|
||||
config: The app config.
|
||||
task_str: The task to run. It can be a string.
|
||||
runtime: (optional) A runtime for the agent to run on.
|
||||
agent: (optional) A agent to run.
|
||||
exit_on_message: quit if agent asks for a message from user (optional)
|
||||
fake_user_response_fn: An optional function that receives the current state
|
||||
(could be None) and returns a fake user response.
|
||||
headless_mode: Whether the agent is run in headless mode.
|
||||
"""
|
||||
# Create the agent
|
||||
if agent is None:
|
||||
agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
|
||||
agent_config = config.get_agent_config(config.default_agent)
|
||||
llm_config = config.get_llm_config_from_agent(config.default_agent)
|
||||
agent = agent_cls(
|
||||
llm=LLM(config=llm_config),
|
||||
config=agent_config,
|
||||
)
|
||||
|
||||
# make sure the session id is set
|
||||
sid = sid or generate_sid(config)
|
||||
|
||||
if runtime is None:
|
||||
runtime = await create_runtime(config, sid=sid)
|
||||
|
||||
event_stream = runtime.event_stream
|
||||
# restore cli session if enabled
|
||||
initial_state = None
|
||||
if config.enable_cli_session:
|
||||
try:
|
||||
logger.info(f'Restoring agent state from cli session {event_stream.sid}')
|
||||
initial_state = State.restore_from_session(
|
||||
event_stream.sid, event_stream.file_store
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(f'Error restoring state: {e}')
|
||||
|
||||
# init controller with this initial state
|
||||
controller = AgentController(
|
||||
agent=agent,
|
||||
max_iterations=config.max_iterations,
|
||||
max_budget_per_task=config.max_budget_per_task,
|
||||
agent_to_llm_config=config.get_agent_to_llm_config_map(),
|
||||
event_stream=event_stream,
|
||||
initial_state=initial_state,
|
||||
headless_mode=headless_mode,
|
||||
)
|
||||
|
||||
assert isinstance(task_str, str), f'task_str must be a string, got {type(task_str)}'
|
||||
# Logging
|
||||
logger.info(
|
||||
f'Agent Controller Initialized: Running agent {agent.name}, model '
|
||||
f'{agent.llm.config.model}, with task: "{task_str}"'
|
||||
)
|
||||
|
||||
# start event is a MessageAction with the task, either resumed or new
|
||||
if config.enable_cli_session and initial_state is not None:
|
||||
# we're resuming the previous session
|
||||
event_stream.add_event(
|
||||
MessageAction(
|
||||
content=(
|
||||
"Let's get back on track. If you experienced errors before, do "
|
||||
'NOT resume your task. Ask me about it.'
|
||||
),
|
||||
),
|
||||
EventSource.USER,
|
||||
)
|
||||
elif initial_state is None:
|
||||
# init with the provided task
|
||||
event_stream.add_event(MessageAction(content=task_str), EventSource.USER)
|
||||
|
||||
async def on_event(event: Event):
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
if event.agent_state == AgentState.AWAITING_USER_INPUT:
|
||||
if exit_on_message:
|
||||
message = '/exit'
|
||||
elif fake_user_response_fn is None:
|
||||
message = input('Request user input >> ')
|
||||
else:
|
||||
message = fake_user_response_fn(controller.get_state())
|
||||
action = MessageAction(content=message)
|
||||
event_stream.add_event(action, EventSource.USER)
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event)
|
||||
while controller.state.agent_state not in [
|
||||
AgentState.FINISHED,
|
||||
AgentState.REJECTED,
|
||||
AgentState.ERROR,
|
||||
AgentState.PAUSED,
|
||||
AgentState.STOPPED,
|
||||
]:
|
||||
await asyncio.sleep(1) # Give back control for a tick, so the agent can run
|
||||
|
||||
# save session when we're about to close
|
||||
if config.enable_cli_session:
|
||||
end_state = controller.get_state()
|
||||
end_state.save_to_session(event_stream.sid, event_stream.file_store)
|
||||
|
||||
# close when done
|
||||
await controller.close()
|
||||
state = controller.get_state()
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def generate_sid(config: AppConfig, session_name: str | None = None) -> str:
|
||||
"""Generate a session id based on the session name and the jwt secret."""
|
||||
session_name = session_name or str(uuid.uuid4())
|
||||
jwt_secret = config.jwt_secret
|
||||
|
||||
hash_str = hashlib.sha256(f'{session_name}{jwt_secret}'.encode('utf-8')).hexdigest()
|
||||
return f'{session_name}_{hash_str[:16]}'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_arguments()
|
||||
|
||||
# Determine the task
|
||||
if args.file:
|
||||
task_str = read_task_from_file(args.file)
|
||||
elif args.task:
|
||||
task_str = args.task
|
||||
elif not sys.stdin.isatty():
|
||||
task_str = read_task_from_stdin()
|
||||
else:
|
||||
raise ValueError('No task provided. Please specify a task through -t, -f.')
|
||||
|
||||
# Load the app config
|
||||
# this will load config from config.toml in the current directory
|
||||
# as well as from the environment variables
|
||||
config = load_app_config()
|
||||
|
||||
# Override default LLM configs ([llm] section in config.toml)
|
||||
if args.llm_config:
|
||||
llm_config = get_llm_config_arg(args.llm_config)
|
||||
if llm_config is None:
|
||||
raise ValueError(f'Invalid toml file, cannot read {args.llm_config}')
|
||||
config.set_llm_config(llm_config)
|
||||
|
||||
# Set default agent
|
||||
config.default_agent = args.agent_cls
|
||||
|
||||
# Set session name
|
||||
session_name = args.name
|
||||
sid = generate_sid(config, session_name)
|
||||
|
||||
# if max budget per task is not sent on the command line, use the config value
|
||||
if args.max_budget_per_task is not None:
|
||||
config.max_budget_per_task = args.max_budget_per_task
|
||||
if args.max_iterations is not None:
|
||||
config.max_iterations = args.max_iterations
|
||||
|
||||
asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=task_str,
|
||||
sid=sid,
|
||||
)
|
||||
)
|
||||
59
openhands/core/message.py
Normal file
59
openhands/core/message.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field, model_serializer
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
class ContentType(Enum):
|
||||
TEXT = 'text'
|
||||
IMAGE_URL = 'image_url'
|
||||
|
||||
|
||||
class Content(BaseModel):
|
||||
type: ContentType
|
||||
|
||||
@model_serializer
|
||||
def serialize_model(self):
|
||||
raise NotImplementedError('Subclasses should implement this method.')
|
||||
|
||||
|
||||
class TextContent(Content):
|
||||
type: ContentType = ContentType.TEXT
|
||||
text: str
|
||||
|
||||
@model_serializer
|
||||
def serialize_model(self):
|
||||
return {'type': self.type.value, 'text': self.text}
|
||||
|
||||
|
||||
class ImageContent(Content):
|
||||
type: ContentType = ContentType.IMAGE_URL
|
||||
image_urls: list[str]
|
||||
|
||||
@model_serializer
|
||||
def serialize_model(self):
|
||||
images: list[dict[str, str | dict[str, str]]] = []
|
||||
for url in self.image_urls:
|
||||
images.append({'type': self.type.value, 'image_url': {'url': url}})
|
||||
return images
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
role: Literal['user', 'system', 'assistant']
|
||||
content: list[TextContent | ImageContent] = Field(default=list)
|
||||
|
||||
@property
|
||||
def contains_image(self) -> bool:
|
||||
return any(isinstance(content, ImageContent) for content in self.content)
|
||||
|
||||
@model_serializer
|
||||
def serialize_model(self) -> dict:
|
||||
content: list[dict[str, str | dict[str, str]]] = []
|
||||
|
||||
for item in self.content:
|
||||
if isinstance(item, TextContent):
|
||||
content.append(item.model_dump())
|
||||
elif isinstance(item, ImageContent):
|
||||
content.extend(item.model_dump())
|
||||
|
||||
return {'role': self.role, 'content': content}
|
||||
48
openhands/core/metrics.py
Normal file
48
openhands/core/metrics.py
Normal file
@@ -0,0 +1,48 @@
|
||||
class Metrics:
|
||||
"""Metrics class can record various metrics during running and evaluation.
|
||||
Currently, we define the following metrics:
|
||||
accumulated_cost: the total cost (USD $) of the current LLM.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._accumulated_cost: float = 0.0
|
||||
self._costs: list[float] = []
|
||||
|
||||
@property
|
||||
def accumulated_cost(self) -> float:
|
||||
return self._accumulated_cost
|
||||
|
||||
@accumulated_cost.setter
|
||||
def accumulated_cost(self, value: float) -> None:
|
||||
if value < 0:
|
||||
raise ValueError('Total cost cannot be negative.')
|
||||
self._accumulated_cost = value
|
||||
|
||||
@property
|
||||
def costs(self) -> list:
|
||||
return self._costs
|
||||
|
||||
def add_cost(self, value: float) -> None:
|
||||
if value < 0:
|
||||
raise ValueError('Added cost cannot be negative.')
|
||||
self._accumulated_cost += value
|
||||
self._costs.append(value)
|
||||
|
||||
def merge(self, other: 'Metrics') -> None:
|
||||
self._accumulated_cost += other.accumulated_cost
|
||||
self._costs += other._costs
|
||||
|
||||
def get(self):
|
||||
"""Return the metrics in a dictionary."""
|
||||
return {'accumulated_cost': self._accumulated_cost, 'costs': self._costs}
|
||||
|
||||
def log(self):
|
||||
"""Log the metrics."""
|
||||
metrics = self.get()
|
||||
logs = ''
|
||||
for key, value in metrics.items():
|
||||
logs += f'{key}: {value}\n'
|
||||
return logs
|
||||
|
||||
def __repr__(self):
|
||||
return f'Metrics({self.get()}'
|
||||
11
openhands/core/schema/__init__.py
Normal file
11
openhands/core/schema/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from .action import ActionType
|
||||
from .agent import AgentState
|
||||
from .config import ConfigType
|
||||
from .observation import ObservationType
|
||||
|
||||
__all__ = [
|
||||
'ActionType',
|
||||
'ObservationType',
|
||||
'ConfigType',
|
||||
'AgentState',
|
||||
]
|
||||
86
openhands/core/schema/action.py
Normal file
86
openhands/core/schema/action.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
__all__ = ['ActionType']
|
||||
|
||||
|
||||
class ActionTypeSchema(BaseModel):
|
||||
INIT: str = Field(default='initialize')
|
||||
"""Initializes the agent. Only sent by client.
|
||||
"""
|
||||
|
||||
MESSAGE: str = Field(default='message')
|
||||
"""Represents a message.
|
||||
"""
|
||||
|
||||
START: str = Field(default='start')
|
||||
"""Starts a new development task OR send chat from the user. Only sent by the client.
|
||||
"""
|
||||
|
||||
READ: str = Field(default='read')
|
||||
"""Reads the content of a file.
|
||||
"""
|
||||
|
||||
WRITE: str = Field(default='write')
|
||||
"""Writes the content to a file.
|
||||
"""
|
||||
|
||||
RUN: str = Field(default='run')
|
||||
"""Runs a command.
|
||||
"""
|
||||
|
||||
RUN_IPYTHON: str = Field(default='run_ipython')
|
||||
"""Runs a IPython cell.
|
||||
"""
|
||||
|
||||
BROWSE: str = Field(default='browse')
|
||||
"""Opens a web page.
|
||||
"""
|
||||
|
||||
BROWSE_INTERACTIVE: str = Field(default='browse_interactive')
|
||||
"""Interact with the browser instance.
|
||||
"""
|
||||
|
||||
DELEGATE: str = Field(default='delegate')
|
||||
"""Delegates a task to another agent.
|
||||
"""
|
||||
|
||||
FINISH: str = Field(default='finish')
|
||||
"""If you're absolutely certain that you've completed your task and have tested your work,
|
||||
use the finish action to stop working.
|
||||
"""
|
||||
|
||||
REJECT: str = Field(default='reject')
|
||||
"""If you're absolutely certain that you cannot complete the task with given requirements,
|
||||
use the reject action to stop working.
|
||||
"""
|
||||
|
||||
NULL: str = Field(default='null')
|
||||
|
||||
SUMMARIZE: str = Field(default='summarize')
|
||||
|
||||
ADD_TASK: str = Field(default='add_task')
|
||||
|
||||
MODIFY_TASK: str = Field(default='modify_task')
|
||||
|
||||
PAUSE: str = Field(default='pause')
|
||||
"""Pauses the task.
|
||||
"""
|
||||
|
||||
RESUME: str = Field(default='resume')
|
||||
"""Resumes the task.
|
||||
"""
|
||||
|
||||
STOP: str = Field(default='stop')
|
||||
"""Stops the task. Must send a start action to restart a new task.
|
||||
"""
|
||||
|
||||
CHANGE_AGENT_STATE: str = Field(default='change_agent_state')
|
||||
|
||||
PUSH: str = Field(default='push')
|
||||
"""Push a branch to github."""
|
||||
|
||||
SEND_PR: str = Field(default='send_pr')
|
||||
"""Send a PR to github."""
|
||||
|
||||
|
||||
ActionType = ActionTypeSchema()
|
||||
51
openhands/core/schema/agent.py
Normal file
51
openhands/core/schema/agent.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class AgentState(str, Enum):
|
||||
LOADING = 'loading'
|
||||
"""The agent is loading.
|
||||
"""
|
||||
|
||||
INIT = 'init'
|
||||
"""The agent is initialized.
|
||||
"""
|
||||
|
||||
RUNNING = 'running'
|
||||
"""The agent is running.
|
||||
"""
|
||||
|
||||
AWAITING_USER_INPUT = 'awaiting_user_input'
|
||||
"""The agent is awaiting user input.
|
||||
"""
|
||||
|
||||
PAUSED = 'paused'
|
||||
"""The agent is paused.
|
||||
"""
|
||||
|
||||
STOPPED = 'stopped'
|
||||
"""The agent is stopped.
|
||||
"""
|
||||
|
||||
FINISHED = 'finished'
|
||||
"""The agent is finished with the current task.
|
||||
"""
|
||||
|
||||
REJECTED = 'rejected'
|
||||
"""The agent rejects the task.
|
||||
"""
|
||||
|
||||
ERROR = 'error'
|
||||
"""An error occurred during the task.
|
||||
"""
|
||||
|
||||
AWAITING_USER_CONFIRMATION = 'awaiting_user_confirmation'
|
||||
"""The agent is awaiting user confirmation.
|
||||
"""
|
||||
|
||||
USER_CONFIRMED = 'user_confirmed'
|
||||
"""The user confirmed the agent's action.
|
||||
"""
|
||||
|
||||
USER_REJECTED = 'user_rejected'
|
||||
"""The user rejected the agent's action.
|
||||
"""
|
||||
47
openhands/core/schema/config.py
Normal file
47
openhands/core/schema/config.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ConfigType(str, Enum):
|
||||
# For frontend
|
||||
LLM_CUSTOM_LLM_PROVIDER = 'LLM_CUSTOM_LLM_PROVIDER'
|
||||
LLM_DROP_PARAMS = 'LLM_DROP_PARAMS'
|
||||
LLM_MAX_INPUT_TOKENS = 'LLM_MAX_INPUT_TOKENS'
|
||||
LLM_MAX_OUTPUT_TOKENS = 'LLM_MAX_OUTPUT_TOKENS'
|
||||
LLM_TOP_P = 'LLM_TOP_P'
|
||||
LLM_TEMPERATURE = 'LLM_TEMPERATURE'
|
||||
LLM_TIMEOUT = 'LLM_TIMEOUT'
|
||||
LLM_API_KEY = 'LLM_API_KEY'
|
||||
LLM_BASE_URL = 'LLM_BASE_URL'
|
||||
AWS_ACCESS_KEY_ID = 'AWS_ACCESS_KEY_ID'
|
||||
AWS_SECRET_ACCESS_KEY = 'AWS_SECRET_ACCESS_KEY'
|
||||
AWS_REGION_NAME = 'AWS_REGION_NAME'
|
||||
WORKSPACE_BASE = 'WORKSPACE_BASE'
|
||||
WORKSPACE_MOUNT_PATH = 'WORKSPACE_MOUNT_PATH'
|
||||
WORKSPACE_MOUNT_REWRITE = 'WORKSPACE_MOUNT_REWRITE'
|
||||
WORKSPACE_MOUNT_PATH_IN_SANDBOX = 'WORKSPACE_MOUNT_PATH_IN_SANDBOX'
|
||||
CACHE_DIR = 'CACHE_DIR'
|
||||
LLM_MODEL = 'LLM_MODEL'
|
||||
CONFIRMATION_MODE = 'CONFIRMATION_MODE'
|
||||
SANDBOX_CONTAINER_IMAGE = 'SANDBOX_CONTAINER_IMAGE'
|
||||
RUN_AS_OPENHANDS = 'RUN_AS_OPENHANDS'
|
||||
LLM_EMBEDDING_MODEL = 'LLM_EMBEDDING_MODEL'
|
||||
LLM_EMBEDDING_BASE_URL = 'LLM_EMBEDDING_BASE_URL'
|
||||
LLM_EMBEDDING_DEPLOYMENT_NAME = 'LLM_EMBEDDING_DEPLOYMENT_NAME'
|
||||
LLM_API_VERSION = 'LLM_API_VERSION'
|
||||
LLM_NUM_RETRIES = 'LLM_NUM_RETRIES'
|
||||
LLM_RETRY_MIN_WAIT = 'LLM_RETRY_MIN_WAIT'
|
||||
LLM_RETRY_MAX_WAIT = 'LLM_RETRY_MAX_WAIT'
|
||||
AGENT_MEMORY_MAX_THREADS = 'AGENT_MEMORY_MAX_THREADS'
|
||||
AGENT_MEMORY_ENABLED = 'AGENT_MEMORY_ENABLED'
|
||||
MAX_ITERATIONS = 'MAX_ITERATIONS'
|
||||
AGENT = 'AGENT'
|
||||
E2B_API_KEY = 'E2B_API_KEY'
|
||||
SECURITY_ANALYZER = 'SECURITY_ANALYZER'
|
||||
SANDBOX_USER_ID = 'SANDBOX_USER_ID'
|
||||
SANDBOX_TIMEOUT = 'SANDBOX_TIMEOUT'
|
||||
USE_HOST_NETWORK = 'USE_HOST_NETWORK'
|
||||
DISABLE_COLOR = 'DISABLE_COLOR'
|
||||
DEBUG = 'DEBUG'
|
||||
FILE_UPLOADS_MAX_FILE_SIZE_MB = 'FILE_UPLOADS_MAX_FILE_SIZE_MB'
|
||||
FILE_UPLOADS_RESTRICT_FILE_TYPES = 'FILE_UPLOADS_RESTRICT_FILE_TYPES'
|
||||
FILE_UPLOADS_ALLOWED_EXTENSIONS = 'FILE_UPLOADS_ALLOWED_EXTENSIONS'
|
||||
46
openhands/core/schema/observation.py
Normal file
46
openhands/core/schema/observation.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
__all__ = ['ObservationType']
|
||||
|
||||
|
||||
class ObservationTypeSchema(BaseModel):
|
||||
READ: str = Field(default='read')
|
||||
"""The content of a file
|
||||
"""
|
||||
|
||||
WRITE: str = Field(default='write')
|
||||
|
||||
BROWSE: str = Field(default='browse')
|
||||
"""The HTML content of a URL
|
||||
"""
|
||||
|
||||
RUN: str = Field(default='run')
|
||||
"""The output of a command
|
||||
"""
|
||||
|
||||
RUN_IPYTHON: str = Field(default='run_ipython')
|
||||
"""Runs a IPython cell.
|
||||
"""
|
||||
|
||||
CHAT: str = Field(default='chat')
|
||||
"""A message from the user
|
||||
"""
|
||||
|
||||
DELEGATE: str = Field(default='delegate')
|
||||
"""The result of a task delegated to another agent
|
||||
"""
|
||||
|
||||
MESSAGE: str = Field(default='message')
|
||||
|
||||
ERROR: str = Field(default='error')
|
||||
|
||||
SUCCESS: str = Field(default='success')
|
||||
|
||||
NULL: str = Field(default='null')
|
||||
|
||||
AGENT_STATE_CHANGED: str = Field(default='agent_state_changed')
|
||||
|
||||
USER_REJECTED: str = Field(default='user_rejected')
|
||||
|
||||
|
||||
ObservationType = ObservationTypeSchema()
|
||||
3
openhands/core/utils/__init__.py
Normal file
3
openhands/core/utils/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .singleton import Singleton
|
||||
|
||||
__all__ = ['Singleton']
|
||||
49
openhands/core/utils/json.py
Normal file
49
openhands/core/utils/json.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from json_repair import repair_json
|
||||
|
||||
from openhands.core.exceptions import LLMResponseError
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.serialization import event_to_dict
|
||||
|
||||
|
||||
def my_default_encoder(obj):
|
||||
"""Custom JSON encoder that handles datetime and event objects"""
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
if isinstance(obj, Event):
|
||||
return event_to_dict(obj)
|
||||
return json.JSONEncoder().default(obj)
|
||||
|
||||
|
||||
def dumps(obj, **kwargs):
|
||||
"""Serialize an object to str format"""
|
||||
return json.dumps(obj, default=my_default_encoder, **kwargs)
|
||||
|
||||
|
||||
def loads(json_str, **kwargs):
|
||||
"""Create a JSON object from str"""
|
||||
try:
|
||||
return json.loads(json_str, **kwargs)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
depth = 0
|
||||
start = -1
|
||||
for i, char in enumerate(json_str):
|
||||
if char == '{':
|
||||
if depth == 0:
|
||||
start = i
|
||||
depth += 1
|
||||
elif char == '}':
|
||||
depth -= 1
|
||||
if depth == 0 and start != -1:
|
||||
response = json_str[start : i + 1]
|
||||
try:
|
||||
json_str = repair_json(response)
|
||||
return json.loads(json_str, **kwargs)
|
||||
except (json.JSONDecodeError, ValueError, TypeError) as e:
|
||||
raise LLMResponseError(
|
||||
'Invalid JSON in response. Please make sure the response is a valid JSON object.'
|
||||
) from e
|
||||
raise LLMResponseError('No valid JSON object found in response.')
|
||||
37
openhands/core/utils/singleton.py
Normal file
37
openhands/core/utils/singleton.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import dataclasses
|
||||
|
||||
from openhands.core import logger
|
||||
|
||||
|
||||
class Singleton(type):
|
||||
_instances: dict = {}
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
|
||||
else:
|
||||
# allow updates, just update existing instance
|
||||
# perhaps not the most orthodox way to do it, though it simplifies client code
|
||||
# useful for pre-defined groups of settings
|
||||
instance = cls._instances[cls]
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(instance, key):
|
||||
setattr(instance, key, value)
|
||||
else:
|
||||
logger.openhands_logger.warning(
|
||||
f'Unknown key for {cls.__name__}: "{key}"'
|
||||
)
|
||||
return cls._instances[cls]
|
||||
|
||||
@classmethod
|
||||
def reset(cls):
|
||||
# used by pytest to reset the state of the singleton instances
|
||||
for instance_type, instance in cls._instances.items():
|
||||
print('resetting... ', instance_type)
|
||||
for field_info in dataclasses.fields(instance_type):
|
||||
if dataclasses.is_dataclass(field_info.type):
|
||||
setattr(instance, field_info.name, field_info.type())
|
||||
elif field_info.default_factory is not dataclasses.MISSING:
|
||||
setattr(instance, field_info.name, field_info.default_factory())
|
||||
else:
|
||||
setattr(instance, field_info.name, field_info.default)
|
||||
9
openhands/events/__init__.py
Normal file
9
openhands/events/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from .event import Event, EventSource
|
||||
from .stream import EventStream, EventStreamSubscriber
|
||||
|
||||
__all__ = [
|
||||
'Event',
|
||||
'EventSource',
|
||||
'EventStream',
|
||||
'EventStreamSubscriber',
|
||||
]
|
||||
34
openhands/events/action/__init__.py
Normal file
34
openhands/events/action/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from .action import Action, ActionConfirmationStatus
|
||||
from .agent import (
|
||||
AgentDelegateAction,
|
||||
AgentFinishAction,
|
||||
AgentRejectAction,
|
||||
AgentSummarizeAction,
|
||||
ChangeAgentStateAction,
|
||||
)
|
||||
from .browse import BrowseInteractiveAction, BrowseURLAction
|
||||
from .commands import CmdRunAction, IPythonRunCellAction
|
||||
from .empty import NullAction
|
||||
from .files import FileReadAction, FileWriteAction
|
||||
from .message import MessageAction
|
||||
from .tasks import AddTaskAction, ModifyTaskAction
|
||||
|
||||
__all__ = [
|
||||
'Action',
|
||||
'NullAction',
|
||||
'CmdRunAction',
|
||||
'BrowseURLAction',
|
||||
'BrowseInteractiveAction',
|
||||
'FileReadAction',
|
||||
'FileWriteAction',
|
||||
'AgentFinishAction',
|
||||
'AgentRejectAction',
|
||||
'AgentDelegateAction',
|
||||
'AgentSummarizeAction',
|
||||
'AddTaskAction',
|
||||
'ModifyTaskAction',
|
||||
'ChangeAgentStateAction',
|
||||
'IPythonRunCellAction',
|
||||
'MessageAction',
|
||||
'ActionConfirmationStatus',
|
||||
]
|
||||
23
openhands/events/action/action.py
Normal file
23
openhands/events/action/action.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import ClassVar
|
||||
|
||||
from openhands.events.event import Event
|
||||
|
||||
|
||||
class ActionConfirmationStatus(str, Enum):
|
||||
CONFIRMED = 'confirmed'
|
||||
REJECTED = 'rejected'
|
||||
AWAITING_CONFIRMATION = 'awaiting_confirmation'
|
||||
|
||||
|
||||
class ActionSecurityRisk(int, Enum):
|
||||
UNKNOWN = -1
|
||||
LOW = 0
|
||||
MEDIUM = 1
|
||||
HIGH = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class Action(Event):
|
||||
runnable: ClassVar[bool] = False
|
||||
81
openhands/events/action/agent.py
Normal file
81
openhands/events/action/agent.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from openhands.core.schema import ActionType
|
||||
|
||||
from .action import Action
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChangeAgentStateAction(Action):
|
||||
"""Fake action, just to notify the client that a task state has changed."""
|
||||
|
||||
agent_state: str
|
||||
thought: str = ''
|
||||
action: str = ActionType.CHANGE_AGENT_STATE
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Agent state changed to {self.agent_state}'
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentSummarizeAction(Action):
|
||||
summary: str
|
||||
action: str = ActionType.SUMMARIZE
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return self.summary
|
||||
|
||||
def __str__(self) -> str:
|
||||
ret = '**AgentSummarizeAction**\n'
|
||||
ret += f'SUMMARY: {self.summary}'
|
||||
return ret
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentFinishAction(Action):
|
||||
"""An action where the agent finishes the task.
|
||||
|
||||
Attributes:
|
||||
outputs (dict): The outputs of the agent, for instance "content".
|
||||
thought (str): The agent's explanation of its actions.
|
||||
action (str): The action type, namely ActionType.FINISH.
|
||||
"""
|
||||
|
||||
outputs: dict[str, Any] = field(default_factory=dict)
|
||||
thought: str = ''
|
||||
action: str = ActionType.FINISH
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
if self.thought != '':
|
||||
return self.thought
|
||||
return "All done! What's next on the agenda?"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentRejectAction(Action):
|
||||
outputs: dict = field(default_factory=dict)
|
||||
thought: str = ''
|
||||
action: str = ActionType.REJECT
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
msg: str = 'Task is rejected by the agent.'
|
||||
if 'reason' in self.outputs:
|
||||
msg += ' Reason: ' + self.outputs['reason']
|
||||
return msg
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentDelegateAction(Action):
|
||||
agent: str
|
||||
inputs: dict
|
||||
thought: str = ''
|
||||
action: str = ActionType.DELEGATE
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f"I'm asking {self.agent} for help with this task."
|
||||
47
openhands/events/action/browse.py
Normal file
47
openhands/events/action/browse.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
from openhands.core.schema import ActionType
|
||||
|
||||
from .action import Action, ActionSecurityRisk
|
||||
|
||||
|
||||
@dataclass
|
||||
class BrowseURLAction(Action):
|
||||
url: str
|
||||
thought: str = ''
|
||||
action: str = ActionType.BROWSE
|
||||
runnable: ClassVar[bool] = True
|
||||
security_risk: ActionSecurityRisk | None = None
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Browsing URL: {self.url}'
|
||||
|
||||
def __str__(self) -> str:
|
||||
ret = '**BrowseURLAction**\n'
|
||||
if self.thought:
|
||||
ret += f'THOUGHT: {self.thought}\n'
|
||||
ret += f'URL: {self.url}'
|
||||
return ret
|
||||
|
||||
|
||||
@dataclass
|
||||
class BrowseInteractiveAction(Action):
|
||||
browser_actions: str
|
||||
thought: str = ''
|
||||
browsergym_send_msg_to_user: str = ''
|
||||
action: str = ActionType.BROWSE_INTERACTIVE
|
||||
runnable: ClassVar[bool] = True
|
||||
security_risk: ActionSecurityRisk | None = None
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Executing browser actions: {self.browser_actions}'
|
||||
|
||||
def __str__(self) -> str:
|
||||
ret = '**BrowseInteractiveAction**\n'
|
||||
if self.thought:
|
||||
ret += f'THOUGHT: {self.thought}\n'
|
||||
ret += f'BROWSER_ACTIONS: {self.browser_actions}'
|
||||
return ret
|
||||
57
openhands/events/action/commands.py
Normal file
57
openhands/events/action/commands.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
from openhands.core.schema import ActionType
|
||||
|
||||
from .action import Action, ActionConfirmationStatus, ActionSecurityRisk
|
||||
|
||||
|
||||
@dataclass
|
||||
class CmdRunAction(Action):
|
||||
command: str
|
||||
thought: str = ''
|
||||
keep_prompt: bool = True
|
||||
# if True, the command prompt will be kept in the command output observation
|
||||
# Example of command output:
|
||||
# root@sandbox:~# ls
|
||||
# file1.txt
|
||||
# file2.txt
|
||||
# root@sandbox:~# <-- this is the command prompt
|
||||
|
||||
action: str = ActionType.RUN
|
||||
runnable: ClassVar[bool] = True
|
||||
is_confirmed: ActionConfirmationStatus = ActionConfirmationStatus.CONFIRMED
|
||||
security_risk: ActionSecurityRisk | None = None
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Running command: {self.command}'
|
||||
|
||||
def __str__(self) -> str:
|
||||
ret = f'**CmdRunAction (source={self.source})**\n'
|
||||
if self.thought:
|
||||
ret += f'THOUGHT: {self.thought}\n'
|
||||
ret += f'COMMAND:\n{self.command}'
|
||||
return ret
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPythonRunCellAction(Action):
|
||||
code: str
|
||||
thought: str = ''
|
||||
action: str = ActionType.RUN_IPYTHON
|
||||
runnable: ClassVar[bool] = True
|
||||
is_confirmed: ActionConfirmationStatus = ActionConfirmationStatus.CONFIRMED
|
||||
security_risk: ActionSecurityRisk | None = None
|
||||
kernel_init_code: str = '' # code to run in the kernel (if the kernel is restarted)
|
||||
|
||||
def __str__(self) -> str:
|
||||
ret = '**IPythonRunCellAction**\n'
|
||||
if self.thought:
|
||||
ret += f'THOUGHT: {self.thought}\n'
|
||||
ret += f'CODE:\n{self.code}'
|
||||
return ret
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Running Python code interactively: {self.code}'
|
||||
16
openhands/events/action/empty.py
Normal file
16
openhands/events/action/empty.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.schema import ActionType
|
||||
|
||||
from .action import Action
|
||||
|
||||
|
||||
@dataclass
|
||||
class NullAction(Action):
|
||||
"""An action that does nothing."""
|
||||
|
||||
action: str = ActionType.NULL
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return 'No action'
|
||||
42
openhands/events/action/files.py
Normal file
42
openhands/events/action/files.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
from openhands.core.schema import ActionType
|
||||
|
||||
from .action import Action, ActionSecurityRisk
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileReadAction(Action):
|
||||
"""Reads a file from a given path.
|
||||
Can be set to read specific lines using start and end
|
||||
Default lines 0:-1 (whole file)
|
||||
"""
|
||||
|
||||
path: str
|
||||
start: int = 0
|
||||
end: int = -1
|
||||
thought: str = ''
|
||||
action: str = ActionType.READ
|
||||
runnable: ClassVar[bool] = True
|
||||
security_risk: ActionSecurityRisk | None = None
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Reading file: {self.path}'
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileWriteAction(Action):
|
||||
path: str
|
||||
content: str
|
||||
start: int = 0
|
||||
end: int = -1
|
||||
thought: str = ''
|
||||
action: str = ActionType.WRITE
|
||||
runnable: ClassVar[bool] = True
|
||||
security_risk: ActionSecurityRisk | None = None
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Writing file: {self.path}'
|
||||
26
openhands/events/action/message.py
Normal file
26
openhands/events/action/message.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.schema import ActionType
|
||||
|
||||
from .action import Action, ActionSecurityRisk
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageAction(Action):
|
||||
content: str
|
||||
images_urls: list | None = None
|
||||
wait_for_response: bool = False
|
||||
action: str = ActionType.MESSAGE
|
||||
security_risk: ActionSecurityRisk | None = None
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return self.content
|
||||
|
||||
def __str__(self) -> str:
|
||||
ret = f'**MessageAction** (source={self.source})\n'
|
||||
ret += f'CONTENT: {self.content}'
|
||||
if self.images_urls:
|
||||
for url in self.images_urls:
|
||||
ret += f'\nIMAGE_URL: {url}'
|
||||
return ret
|
||||
30
openhands/events/action/tasks.py
Normal file
30
openhands/events/action/tasks.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from openhands.core.schema import ActionType
|
||||
|
||||
from .action import Action
|
||||
|
||||
|
||||
@dataclass
|
||||
class AddTaskAction(Action):
|
||||
parent: str
|
||||
goal: str
|
||||
subtasks: list = field(default_factory=list)
|
||||
thought: str = ''
|
||||
action: str = ActionType.ADD_TASK
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Added task: {self.goal}'
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModifyTaskAction(Action):
|
||||
task_id: str
|
||||
state: str
|
||||
thought: str = ''
|
||||
action: str = ActionType.MODIFY_TASK
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Set task {self.task_id} to {self.state}'
|
||||
51
openhands/events/event.py
Normal file
51
openhands/events/event.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import datetime
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class EventSource(str, Enum):
|
||||
AGENT = 'agent'
|
||||
USER = 'user'
|
||||
|
||||
|
||||
@dataclass
|
||||
class Event:
|
||||
@property
|
||||
def message(self) -> str | None:
|
||||
if hasattr(self, '_message'):
|
||||
return self._message # type: ignore[attr-defined]
|
||||
return ''
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
if hasattr(self, '_id'):
|
||||
return self._id # type: ignore[attr-defined]
|
||||
return -1
|
||||
|
||||
@property
|
||||
def timestamp(self) -> datetime.datetime | None:
|
||||
if hasattr(self, '_timestamp'):
|
||||
return self._timestamp # type: ignore[attr-defined]
|
||||
return None
|
||||
|
||||
@property
|
||||
def source(self) -> EventSource | None:
|
||||
if hasattr(self, '_source'):
|
||||
return self._source # type: ignore[attr-defined]
|
||||
return None
|
||||
|
||||
@property
|
||||
def cause(self) -> int | None:
|
||||
if hasattr(self, '_cause'):
|
||||
return self._cause # type: ignore[attr-defined]
|
||||
return None
|
||||
|
||||
@property
|
||||
def timeout(self) -> int | None:
|
||||
if hasattr(self, '_timeout'):
|
||||
return self._timeout # type: ignore[attr-defined]
|
||||
return None
|
||||
|
||||
@timeout.setter
|
||||
def timeout(self, value: int | None) -> None:
|
||||
self._timeout = value
|
||||
25
openhands/events/observation/__init__.py
Normal file
25
openhands/events/observation/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from .agent import AgentStateChangedObservation
|
||||
from .browse import BrowserOutputObservation
|
||||
from .commands import CmdOutputObservation, IPythonRunCellObservation
|
||||
from .delegate import AgentDelegateObservation
|
||||
from .empty import NullObservation
|
||||
from .error import ErrorObservation
|
||||
from .files import FileReadObservation, FileWriteObservation
|
||||
from .observation import Observation
|
||||
from .reject import UserRejectObservation
|
||||
from .success import SuccessObservation
|
||||
|
||||
__all__ = [
|
||||
'Observation',
|
||||
'NullObservation',
|
||||
'CmdOutputObservation',
|
||||
'IPythonRunCellObservation',
|
||||
'BrowserOutputObservation',
|
||||
'FileReadObservation',
|
||||
'FileWriteObservation',
|
||||
'ErrorObservation',
|
||||
'AgentStateChangedObservation',
|
||||
'AgentDelegateObservation',
|
||||
'SuccessObservation',
|
||||
'UserRejectObservation',
|
||||
]
|
||||
17
openhands/events/observation/agent.py
Normal file
17
openhands/events/observation/agent.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
|
||||
from .observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentStateChangedObservation(Observation):
|
||||
"""This data class represents the result from delegating to another agent"""
|
||||
|
||||
agent_state: str
|
||||
observation: str = ObservationType.AGENT_STATE_CHANGED
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return ''
|
||||
44
openhands/events/observation/browse.py
Normal file
44
openhands/events/observation/browse.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
|
||||
from .observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class BrowserOutputObservation(Observation):
|
||||
"""This data class represents the output of a browser."""
|
||||
|
||||
url: str
|
||||
screenshot: str = field(repr=False) # don't show in repr
|
||||
error: bool = False
|
||||
observation: str = ObservationType.BROWSE
|
||||
# do not include in the memory
|
||||
open_pages_urls: list = field(default_factory=list)
|
||||
active_page_index: int = -1
|
||||
dom_object: dict = field(default_factory=dict, repr=False) # don't show in repr
|
||||
axtree_object: dict = field(default_factory=dict, repr=False) # don't show in repr
|
||||
extra_element_properties: dict = field(
|
||||
default_factory=dict, repr=False
|
||||
) # don't show in repr
|
||||
last_browser_action: str = ''
|
||||
last_browser_action_error: str = ''
|
||||
focused_element_bid: str = ''
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return 'Visited ' + self.url
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
'**BrowserOutputObservation**\n'
|
||||
f'URL: {self.url}\n'
|
||||
f'Error: {self.error}\n'
|
||||
f'Open pages: {self.open_pages_urls}\n'
|
||||
f'Active page index: {self.active_page_index}\n'
|
||||
f'Last browser action: {self.last_browser_action}\n'
|
||||
f'Last browser action error: {self.last_browser_action_error}\n'
|
||||
f'Focused element bid: {self.focused_element_bid}\n'
|
||||
f'axTree: {self.axtree_object}\n'
|
||||
f'CONTENT: {self.content}\n'
|
||||
)
|
||||
45
openhands/events/observation/commands.py
Normal file
45
openhands/events/observation/commands.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
|
||||
from .observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class CmdOutputObservation(Observation):
|
||||
"""This data class represents the output of a command."""
|
||||
|
||||
command_id: int
|
||||
command: str
|
||||
exit_code: int = 0
|
||||
observation: str = ObservationType.RUN
|
||||
|
||||
@property
|
||||
def error(self) -> bool:
|
||||
return self.exit_code != 0
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Command `{self.command}` executed with exit code {self.exit_code}.'
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'**CmdOutputObservation (source={self.source}, exit code={self.exit_code})**\n{self.content}'
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPythonRunCellObservation(Observation):
|
||||
"""This data class represents the output of a IPythonRunCellAction."""
|
||||
|
||||
code: str
|
||||
observation: str = ObservationType.RUN_IPYTHON
|
||||
|
||||
@property
|
||||
def error(self) -> bool:
|
||||
return False # IPython cells do not return exit codes
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return 'Code executed in IPython cell.'
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'**IPythonRunCellObservation**\n{self.content}'
|
||||
23
openhands/events/observation/delegate.py
Normal file
23
openhands/events/observation/delegate.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
|
||||
from .observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentDelegateObservation(Observation):
|
||||
"""This data class represents the result from delegating to another agent.
|
||||
|
||||
Attributes:
|
||||
content (str): The content of the observation.
|
||||
outputs (dict): The outputs of the delegated agent.
|
||||
observation (str): The type of observation.
|
||||
"""
|
||||
|
||||
outputs: dict
|
||||
observation: str = ObservationType.DELEGATE
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return ''
|
||||
18
openhands/events/observation/empty.py
Normal file
18
openhands/events/observation/empty.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
|
||||
from .observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class NullObservation(Observation):
|
||||
"""This data class represents a null observation.
|
||||
This is used when the produced action is NOT executable.
|
||||
"""
|
||||
|
||||
observation: str = ObservationType.NULL
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return 'No observation'
|
||||
16
openhands/events/observation/error.py
Normal file
16
openhands/events/observation/error.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
|
||||
from .observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class ErrorObservation(Observation):
|
||||
"""This data class represents an error encountered by the agent."""
|
||||
|
||||
observation: str = ObservationType.ERROR
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return self.content
|
||||
29
openhands/events/observation/files.py
Normal file
29
openhands/events/observation/files.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
|
||||
from .observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileReadObservation(Observation):
|
||||
"""This data class represents the content of a file."""
|
||||
|
||||
path: str
|
||||
observation: str = ObservationType.READ
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'I read the file {self.path}.'
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileWriteObservation(Observation):
|
||||
"""This data class represents a file write operation"""
|
||||
|
||||
path: str
|
||||
observation: str = ObservationType.WRITE
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'I wrote to the file {self.path}.'
|
||||
8
openhands/events/observation/observation.py
Normal file
8
openhands/events/observation/observation.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.events.event import Event
|
||||
|
||||
|
||||
@dataclass
|
||||
class Observation(Event):
|
||||
content: str
|
||||
16
openhands/events/observation/reject.py
Normal file
16
openhands/events/observation/reject.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
|
||||
from .observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserRejectObservation(Observation):
|
||||
"""This data class represents the result of a successful action."""
|
||||
|
||||
observation: str = ObservationType.USER_REJECTED
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return self.content
|
||||
16
openhands/events/observation/success.py
Normal file
16
openhands/events/observation/success.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
|
||||
from .observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class SuccessObservation(Observation):
|
||||
"""This data class represents the result of a successful action."""
|
||||
|
||||
observation: str = ObservationType.SUCCESS
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return self.content
|
||||
19
openhands/events/serialization/__init__.py
Normal file
19
openhands/events/serialization/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from .action import (
|
||||
action_from_dict,
|
||||
)
|
||||
from .event import (
|
||||
event_from_dict,
|
||||
event_to_dict,
|
||||
event_to_memory,
|
||||
)
|
||||
from .observation import (
|
||||
observation_from_dict,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'action_from_dict',
|
||||
'event_from_dict',
|
||||
'event_to_dict',
|
||||
'event_to_memory',
|
||||
'observation_from_dict',
|
||||
]
|
||||
61
openhands/events/serialization/action.py
Normal file
61
openhands/events/serialization/action.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from openhands.core.exceptions import LLMMalformedActionError
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.action.agent import (
|
||||
AgentDelegateAction,
|
||||
AgentFinishAction,
|
||||
AgentRejectAction,
|
||||
ChangeAgentStateAction,
|
||||
)
|
||||
from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction
|
||||
from openhands.events.action.commands import (
|
||||
CmdRunAction,
|
||||
IPythonRunCellAction,
|
||||
)
|
||||
from openhands.events.action.empty import NullAction
|
||||
from openhands.events.action.files import FileReadAction, FileWriteAction
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.action.tasks import AddTaskAction, ModifyTaskAction
|
||||
|
||||
actions = (
|
||||
NullAction,
|
||||
CmdRunAction,
|
||||
IPythonRunCellAction,
|
||||
BrowseURLAction,
|
||||
BrowseInteractiveAction,
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
AgentFinishAction,
|
||||
AgentRejectAction,
|
||||
AgentDelegateAction,
|
||||
AddTaskAction,
|
||||
ModifyTaskAction,
|
||||
ChangeAgentStateAction,
|
||||
MessageAction,
|
||||
)
|
||||
|
||||
ACTION_TYPE_TO_CLASS = {action_class.action: action_class for action_class in actions} # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def action_from_dict(action: dict) -> Action:
|
||||
if not isinstance(action, dict):
|
||||
raise LLMMalformedActionError('action must be a dictionary')
|
||||
action = action.copy()
|
||||
if 'action' not in action:
|
||||
raise LLMMalformedActionError(f"'action' key is not found in {action=}")
|
||||
if not isinstance(action['action'], str):
|
||||
raise LLMMalformedActionError(
|
||||
f"'{action['action']=}' is not defined. Available actions: {ACTION_TYPE_TO_CLASS.keys()}"
|
||||
)
|
||||
action_class = ACTION_TYPE_TO_CLASS.get(action['action'])
|
||||
if action_class is None:
|
||||
raise LLMMalformedActionError(
|
||||
f"'{action['action']=}' is not defined. Available actions: {ACTION_TYPE_TO_CLASS.keys()}"
|
||||
)
|
||||
args = action.get('args', {})
|
||||
try:
|
||||
decoded_action = action_class(**args)
|
||||
if 'timeout' in action:
|
||||
decoded_action.timeout = action['timeout']
|
||||
except TypeError:
|
||||
raise LLMMalformedActionError(f'action={action} has the wrong arguments')
|
||||
return decoded_action
|
||||
101
openhands/events/serialization/event.py
Normal file
101
openhands/events/serialization/event.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from dataclasses import asdict
|
||||
from datetime import datetime
|
||||
|
||||
from openhands.events import Event, EventSource
|
||||
from openhands.events.observation.observation import Observation
|
||||
|
||||
from .action import action_from_dict
|
||||
from .observation import observation_from_dict
|
||||
from .utils import remove_fields
|
||||
|
||||
# TODO: move `content` into `extras`
|
||||
TOP_KEYS = ['id', 'timestamp', 'source', 'message', 'cause', 'action', 'observation']
|
||||
UNDERSCORE_KEYS = ['id', 'timestamp', 'source', 'cause']
|
||||
|
||||
DELETE_FROM_MEMORY_EXTRAS = {
|
||||
'screenshot',
|
||||
'dom_object',
|
||||
'axtree_object',
|
||||
'open_pages_urls',
|
||||
'active_page_index',
|
||||
'last_browser_action',
|
||||
'last_browser_action_error',
|
||||
'focused_element_bid',
|
||||
'extra_element_properties',
|
||||
}
|
||||
|
||||
|
||||
def event_from_dict(data) -> 'Event':
|
||||
evt: Event
|
||||
if 'action' in data:
|
||||
evt = action_from_dict(data)
|
||||
elif 'observation' in data:
|
||||
evt = observation_from_dict(data)
|
||||
else:
|
||||
raise ValueError('Unknown event type: ' + data)
|
||||
for key in UNDERSCORE_KEYS:
|
||||
if key in data:
|
||||
value = data[key]
|
||||
if key == 'timestamp':
|
||||
value = datetime.fromisoformat(value)
|
||||
if key == 'source':
|
||||
value = EventSource(value)
|
||||
setattr(evt, '_' + key, value)
|
||||
return evt
|
||||
|
||||
|
||||
def event_to_dict(event: 'Event') -> dict:
|
||||
props = asdict(event)
|
||||
d = {}
|
||||
for key in TOP_KEYS:
|
||||
if hasattr(event, key) and getattr(event, key) is not None:
|
||||
d[key] = getattr(event, key)
|
||||
elif hasattr(event, f'_{key}') and getattr(event, f'_{key}') is not None:
|
||||
d[key] = getattr(event, f'_{key}')
|
||||
if key == 'id' and d.get('id') == -1:
|
||||
d.pop('id', None)
|
||||
if key == 'timestamp' and 'timestamp' in d:
|
||||
d['timestamp'] = d['timestamp'].isoformat()
|
||||
if key == 'source' and 'source' in d:
|
||||
d['source'] = d['source'].value
|
||||
props.pop(key, None)
|
||||
if 'security_risk' in props and props['security_risk'] is None:
|
||||
props.pop('security_risk')
|
||||
if 'action' in d:
|
||||
d['args'] = props
|
||||
if event.timeout is not None:
|
||||
d['timeout'] = event.timeout
|
||||
elif 'observation' in d:
|
||||
d['content'] = props.pop('content', '')
|
||||
d['extras'] = props
|
||||
else:
|
||||
raise ValueError('Event must be either action or observation')
|
||||
return d
|
||||
|
||||
|
||||
def event_to_memory(event: 'Event', max_message_chars: int) -> dict:
|
||||
d = event_to_dict(event)
|
||||
d.pop('id', None)
|
||||
d.pop('cause', None)
|
||||
d.pop('timestamp', None)
|
||||
d.pop('message', None)
|
||||
d.pop('images_urls', None)
|
||||
if 'extras' in d:
|
||||
remove_fields(d['extras'], DELETE_FROM_MEMORY_EXTRAS)
|
||||
if isinstance(event, Observation) and 'content' in d:
|
||||
d['content'] = truncate_content(d['content'], max_message_chars)
|
||||
return d
|
||||
|
||||
|
||||
def truncate_content(content: str, max_chars: int) -> str:
|
||||
"""Truncate the middle of the observation content if it is too long."""
|
||||
if len(content) <= max_chars:
|
||||
return content
|
||||
|
||||
# truncate the middle and include a message to the LLM about it
|
||||
half = max_chars // 2
|
||||
return (
|
||||
content[:half]
|
||||
+ '\n[... Observation truncated due to length ...]\n'
|
||||
+ content[-half:]
|
||||
)
|
||||
48
openhands/events/serialization/observation.py
Normal file
48
openhands/events/serialization/observation.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.events.observation.browse import BrowserOutputObservation
|
||||
from openhands.events.observation.commands import (
|
||||
CmdOutputObservation,
|
||||
IPythonRunCellObservation,
|
||||
)
|
||||
from openhands.events.observation.delegate import AgentDelegateObservation
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.observation.error import ErrorObservation
|
||||
from openhands.events.observation.files import FileReadObservation, FileWriteObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.events.observation.reject import UserRejectObservation
|
||||
from openhands.events.observation.success import SuccessObservation
|
||||
|
||||
observations = (
|
||||
NullObservation,
|
||||
CmdOutputObservation,
|
||||
IPythonRunCellObservation,
|
||||
BrowserOutputObservation,
|
||||
FileReadObservation,
|
||||
FileWriteObservation,
|
||||
AgentDelegateObservation,
|
||||
SuccessObservation,
|
||||
ErrorObservation,
|
||||
AgentStateChangedObservation,
|
||||
UserRejectObservation,
|
||||
)
|
||||
|
||||
OBSERVATION_TYPE_TO_CLASS = {
|
||||
observation_class.observation: observation_class # type: ignore[attr-defined]
|
||||
for observation_class in observations
|
||||
}
|
||||
|
||||
|
||||
def observation_from_dict(observation: dict) -> Observation:
|
||||
observation = observation.copy()
|
||||
if 'observation' not in observation:
|
||||
raise KeyError(f"'observation' key is not found in {observation=}")
|
||||
observation_class = OBSERVATION_TYPE_TO_CLASS.get(observation['observation'])
|
||||
if observation_class is None:
|
||||
raise KeyError(
|
||||
f"'{observation['observation']=}' is not defined. Available observations: {OBSERVATION_TYPE_TO_CLASS.keys()}"
|
||||
)
|
||||
observation.pop('observation')
|
||||
observation.pop('message', None)
|
||||
content = observation.pop('content', '')
|
||||
extras = observation.pop('extras', {})
|
||||
return observation_class(content=content, **extras)
|
||||
20
openhands/events/serialization/utils.py
Normal file
20
openhands/events/serialization/utils.py
Normal file
@@ -0,0 +1,20 @@
|
||||
def remove_fields(obj, fields: set[str]):
|
||||
"""Remove fields from an object.
|
||||
|
||||
Parameters:
|
||||
- obj: The dictionary, or list of dictionaries to remove fields from
|
||||
- fields (set[str]): A set of field names to remove from the object
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
for field in fields:
|
||||
if field in obj:
|
||||
del obj[field]
|
||||
for _, value in obj.items():
|
||||
remove_fields(value, fields)
|
||||
elif isinstance(obj, list) or isinstance(obj, tuple):
|
||||
for item in obj:
|
||||
remove_fields(item, fields)
|
||||
elif hasattr(obj, '__dataclass_fields__'):
|
||||
raise ValueError(
|
||||
'Object must not contain dataclass, consider converting to dict first'
|
||||
)
|
||||
155
openhands/events/stream.py
Normal file
155
openhands/events/stream.py
Normal file
@@ -0,0 +1,155 @@
|
||||
import asyncio
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Callable, Iterable
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.utils import json
|
||||
from openhands.events.serialization.event import event_from_dict, event_to_dict
|
||||
from openhands.storage import FileStore
|
||||
|
||||
from .event import Event, EventSource
|
||||
|
||||
|
||||
class EventStreamSubscriber(str, Enum):
|
||||
AGENT_CONTROLLER = 'agent_controller'
|
||||
SECURITY_ANALYZER = 'security_analyzer'
|
||||
SERVER = 'server'
|
||||
RUNTIME = 'runtime'
|
||||
MAIN = 'main'
|
||||
TEST = 'test'
|
||||
|
||||
|
||||
class EventStream:
|
||||
sid: str
|
||||
file_store: FileStore
|
||||
# For each subscriber ID, there is a stack of callback functions - useful
|
||||
# when there are agent delegates
|
||||
_subscribers: dict[str, list[Callable]]
|
||||
_cur_id: int
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, sid: str, file_store: FileStore):
|
||||
self.sid = sid
|
||||
self.file_store = file_store
|
||||
self._subscribers = {}
|
||||
self._cur_id = 0
|
||||
self._lock = threading.Lock()
|
||||
self._reinitialize_from_file_store()
|
||||
|
||||
def _reinitialize_from_file_store(self) -> None:
|
||||
try:
|
||||
events = self.file_store.list(f'sessions/{self.sid}/events')
|
||||
except FileNotFoundError:
|
||||
logger.debug(f'No events found for session {self.sid}')
|
||||
self._cur_id = 0
|
||||
return
|
||||
|
||||
# if we have events, we need to find the highest id to prepare for new events
|
||||
for event_str in events:
|
||||
id = self._get_id_from_filename(event_str)
|
||||
if id >= self._cur_id:
|
||||
self._cur_id = id + 1
|
||||
|
||||
def _get_filename_for_id(self, id: int) -> str:
|
||||
return f'sessions/{self.sid}/events/{id}.json'
|
||||
|
||||
@staticmethod
|
||||
def _get_id_from_filename(filename: str) -> int:
|
||||
try:
|
||||
return int(filename.split('/')[-1].split('.')[0])
|
||||
except ValueError:
|
||||
logger.warning(f'get id from filename ({filename}) failed.')
|
||||
return -1
|
||||
|
||||
def get_events(
|
||||
self,
|
||||
start_id=0,
|
||||
end_id=None,
|
||||
reverse=False,
|
||||
filter_out_type: tuple[type[Event], ...] | None = None,
|
||||
) -> Iterable[Event]:
|
||||
if reverse:
|
||||
if end_id is None:
|
||||
end_id = self._cur_id - 1
|
||||
event_id = end_id
|
||||
while event_id >= start_id:
|
||||
try:
|
||||
event = self.get_event(event_id)
|
||||
if filter_out_type is None or not isinstance(
|
||||
event, filter_out_type
|
||||
):
|
||||
yield event
|
||||
except FileNotFoundError:
|
||||
logger.debug(f'No event found for ID {event_id}')
|
||||
event_id -= 1
|
||||
else:
|
||||
event_id = start_id
|
||||
while True:
|
||||
if end_id is not None and event_id > end_id:
|
||||
break
|
||||
try:
|
||||
event = self.get_event(event_id)
|
||||
if filter_out_type is None or not isinstance(
|
||||
event, filter_out_type
|
||||
):
|
||||
yield event
|
||||
except FileNotFoundError:
|
||||
break
|
||||
event_id += 1
|
||||
|
||||
def get_event(self, id: int) -> Event:
|
||||
filename = self._get_filename_for_id(id)
|
||||
content = self.file_store.read(filename)
|
||||
data = json.loads(content)
|
||||
return event_from_dict(data)
|
||||
|
||||
def get_latest_event(self) -> Event:
|
||||
return self.get_event(self._cur_id - 1)
|
||||
|
||||
def get_latest_event_id(self) -> int:
|
||||
return self._cur_id - 1
|
||||
|
||||
def subscribe(self, id: EventStreamSubscriber, callback: Callable, append=False):
|
||||
if id in self._subscribers:
|
||||
if append:
|
||||
self._subscribers[id].append(callback)
|
||||
else:
|
||||
raise ValueError('Subscriber already exists: ' + id)
|
||||
else:
|
||||
self._subscribers[id] = [callback]
|
||||
|
||||
def unsubscribe(self, id: EventStreamSubscriber):
|
||||
if id not in self._subscribers:
|
||||
logger.warning('Subscriber not found during unsubscribe: ' + id)
|
||||
else:
|
||||
self._subscribers[id].pop()
|
||||
if len(self._subscribers[id]) == 0:
|
||||
del self._subscribers[id]
|
||||
|
||||
def add_event(self, event: Event, source: EventSource):
|
||||
with self._lock:
|
||||
event._id = self._cur_id # type: ignore [attr-defined]
|
||||
self._cur_id += 1
|
||||
logger.debug(f'Adding {type(event).__name__} id={event.id} from {source.name}')
|
||||
event._timestamp = datetime.now() # type: ignore [attr-defined]
|
||||
event._source = source # type: ignore [attr-defined]
|
||||
data = event_to_dict(event)
|
||||
if event.id is not None:
|
||||
self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data))
|
||||
for key in sorted(self._subscribers.keys()):
|
||||
stack = self._subscribers[key]
|
||||
callback = stack[-1]
|
||||
asyncio.create_task(callback(event))
|
||||
|
||||
def filtered_events_by_source(self, source: EventSource):
|
||||
for event in self.get_events():
|
||||
if event.source == source:
|
||||
yield event
|
||||
|
||||
def clear(self):
|
||||
self.file_store.delete(f'sessions/{self.sid}')
|
||||
self._cur_id = 0
|
||||
# self._subscribers = {}
|
||||
self._reinitialize_from_file_store()
|
||||
32
openhands/llm/bedrock.py
Normal file
32
openhands/llm/bedrock.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import boto3
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
def list_foundation_models(
|
||||
aws_region_name: str, aws_access_key_id: str, aws_secret_access_key: str
|
||||
) -> list[str]:
|
||||
try:
|
||||
# The AWS bedrock model id is not queried, if no AWS parameters are configured.
|
||||
client = boto3.client(
|
||||
service_name='bedrock',
|
||||
region_name=aws_region_name,
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
)
|
||||
foundation_models_list = client.list_foundation_models(
|
||||
byOutputModality='TEXT', byInferenceType='ON_DEMAND'
|
||||
)
|
||||
model_summaries = foundation_models_list['modelSummaries']
|
||||
return ['bedrock/' + model['modelId'] for model in model_summaries]
|
||||
except Exception as err:
|
||||
logger.warning(
|
||||
'%s. Please config AWS_REGION_NAME AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY'
|
||||
' if you want use bedrock model.',
|
||||
err,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def remove_error_modelId(model_list):
|
||||
return list(filter(lambda m: not m.startswith('bedrock'), model_list))
|
||||
511
openhands/llm/llm.py
Normal file
511
openhands/llm/llm.py
Normal file
@@ -0,0 +1,511 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import warnings
|
||||
from functools import partial
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore')
|
||||
import litellm
|
||||
from litellm import completion as litellm_completion
|
||||
from litellm import completion_cost as litellm_completion_cost
|
||||
from litellm.exceptions import (
|
||||
APIConnectionError,
|
||||
ContentPolicyViolationError,
|
||||
InternalServerError,
|
||||
OpenAIError,
|
||||
RateLimitError,
|
||||
ServiceUnavailableError,
|
||||
)
|
||||
from litellm.types.utils import CostPerToken
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from openhands.core.exceptions import UserCancelledError
|
||||
from openhands.core.logger import llm_prompt_logger, llm_response_logger
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.metrics import Metrics
|
||||
|
||||
__all__ = ['LLM']
|
||||
|
||||
message_separator = '\n\n----------\n\n'
|
||||
|
||||
|
||||
class LLM:
|
||||
"""The LLM class represents a Language Model instance.
|
||||
|
||||
Attributes:
|
||||
config: an LLMConfig object specifying the configuration of the LLM.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: LLMConfig,
|
||||
metrics: Metrics | None = None,
|
||||
):
|
||||
"""Initializes the LLM. If LLMConfig is passed, its values will be the fallback.
|
||||
|
||||
Passing simple parameters always overrides config.
|
||||
|
||||
Args:
|
||||
config: The LLM configuration
|
||||
"""
|
||||
self.config = copy.deepcopy(config)
|
||||
self.metrics = metrics if metrics is not None else Metrics()
|
||||
self.cost_metric_supported = True
|
||||
|
||||
# Set up config attributes with default values to prevent AttributeError
|
||||
LLMConfig.set_missing_attributes(self.config)
|
||||
|
||||
# litellm actually uses base Exception here for unknown model
|
||||
self.model_info = None
|
||||
try:
|
||||
if self.config.model.startswith('openrouter'):
|
||||
self.model_info = litellm.get_model_info(self.config.model)
|
||||
else:
|
||||
self.model_info = litellm.get_model_info(
|
||||
self.config.model.split(':')[0]
|
||||
)
|
||||
# noinspection PyBroadException
|
||||
except Exception as e:
|
||||
logger.warning(f'Could not get model info for {config.model}:\n{e}')
|
||||
|
||||
# Set the max tokens in an LM-specific way if not set
|
||||
if self.config.max_input_tokens is None:
|
||||
if (
|
||||
self.model_info is not None
|
||||
and 'max_input_tokens' in self.model_info
|
||||
and isinstance(self.model_info['max_input_tokens'], int)
|
||||
):
|
||||
self.config.max_input_tokens = self.model_info['max_input_tokens']
|
||||
else:
|
||||
# Max input tokens for gpt3.5, so this is a safe fallback for any potentially viable model
|
||||
self.config.max_input_tokens = 4096
|
||||
|
||||
if self.config.max_output_tokens is None:
|
||||
if (
|
||||
self.model_info is not None
|
||||
and 'max_output_tokens' in self.model_info
|
||||
and isinstance(self.model_info['max_output_tokens'], int)
|
||||
):
|
||||
self.config.max_output_tokens = self.model_info['max_output_tokens']
|
||||
else:
|
||||
# Max output tokens for gpt3.5, so this is a safe fallback for any potentially viable model
|
||||
self.config.max_output_tokens = 1024
|
||||
|
||||
if self.config.drop_params:
|
||||
litellm.drop_params = self.config.drop_params
|
||||
|
||||
self._completion = partial(
|
||||
litellm_completion,
|
||||
model=self.config.model,
|
||||
api_key=self.config.api_key,
|
||||
base_url=self.config.base_url,
|
||||
api_version=self.config.api_version,
|
||||
custom_llm_provider=self.config.custom_llm_provider,
|
||||
max_tokens=self.config.max_output_tokens,
|
||||
timeout=self.config.timeout,
|
||||
temperature=self.config.temperature,
|
||||
top_p=self.config.top_p,
|
||||
)
|
||||
|
||||
completion_unwrapped = self._completion
|
||||
|
||||
def attempt_on_error(retry_state):
|
||||
logger.error(
|
||||
f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number} | You can customize these settings in the configuration.',
|
||||
exc_info=False,
|
||||
)
|
||||
return None
|
||||
|
||||
@retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(self.config.num_retries),
|
||||
wait=wait_random_exponential(
|
||||
multiplier=self.config.retry_multiplier,
|
||||
min=self.config.retry_min_wait,
|
||||
max=self.config.retry_max_wait,
|
||||
),
|
||||
retry=retry_if_exception_type(
|
||||
(
|
||||
RateLimitError,
|
||||
APIConnectionError,
|
||||
ServiceUnavailableError,
|
||||
InternalServerError,
|
||||
ContentPolicyViolationError,
|
||||
)
|
||||
),
|
||||
after=attempt_on_error,
|
||||
)
|
||||
def wrapper(*args, **kwargs):
|
||||
"""Wrapper for the litellm completion function. Logs the input and output of the completion function."""
|
||||
# some callers might just send the messages directly
|
||||
if 'messages' in kwargs:
|
||||
messages = kwargs['messages']
|
||||
else:
|
||||
messages = args[1]
|
||||
|
||||
# log the prompt
|
||||
debug_message = ''
|
||||
for message in messages:
|
||||
content = message['content']
|
||||
|
||||
if isinstance(content, list):
|
||||
for element in content:
|
||||
if isinstance(element, dict):
|
||||
if 'text' in element:
|
||||
content_str = element['text'].strip()
|
||||
elif (
|
||||
'image_url' in element and 'url' in element['image_url']
|
||||
):
|
||||
content_str = element['image_url']['url']
|
||||
else:
|
||||
content_str = str(element)
|
||||
else:
|
||||
content_str = str(element)
|
||||
|
||||
debug_message += message_separator + content_str
|
||||
else:
|
||||
content_str = str(content)
|
||||
debug_message += message_separator + content_str
|
||||
|
||||
llm_prompt_logger.debug(debug_message)
|
||||
|
||||
# skip if messages is empty (thus debug_message is empty)
|
||||
if debug_message:
|
||||
resp = completion_unwrapped(*args, **kwargs)
|
||||
else:
|
||||
resp = {'choices': [{'message': {'content': ''}}]}
|
||||
|
||||
# log the response
|
||||
message_back = resp['choices'][0]['message']['content']
|
||||
llm_response_logger.debug(message_back)
|
||||
|
||||
# post-process to log costs
|
||||
self._post_completion(resp)
|
||||
|
||||
return resp
|
||||
|
||||
self._completion = wrapper # type: ignore
|
||||
|
||||
# Async version
|
||||
self._async_completion = partial(
|
||||
self._call_acompletion,
|
||||
model=self.config.model,
|
||||
api_key=self.config.api_key,
|
||||
base_url=self.config.base_url,
|
||||
api_version=self.config.api_version,
|
||||
custom_llm_provider=self.config.custom_llm_provider,
|
||||
max_tokens=self.config.max_output_tokens,
|
||||
timeout=self.config.timeout,
|
||||
temperature=self.config.temperature,
|
||||
top_p=self.config.top_p,
|
||||
drop_params=True,
|
||||
)
|
||||
|
||||
async_completion_unwrapped = self._async_completion
|
||||
|
||||
@retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(self.config.num_retries),
|
||||
wait=wait_random_exponential(
|
||||
multiplier=self.config.retry_multiplier,
|
||||
min=self.config.retry_min_wait,
|
||||
max=self.config.retry_max_wait,
|
||||
),
|
||||
retry=retry_if_exception_type(
|
||||
(
|
||||
RateLimitError,
|
||||
APIConnectionError,
|
||||
ServiceUnavailableError,
|
||||
InternalServerError,
|
||||
ContentPolicyViolationError,
|
||||
)
|
||||
),
|
||||
after=attempt_on_error,
|
||||
)
|
||||
async def async_completion_wrapper(*args, **kwargs):
|
||||
"""Async wrapper for the litellm acompletion function."""
|
||||
# some callers might just send the messages directly
|
||||
if 'messages' in kwargs:
|
||||
messages = kwargs['messages']
|
||||
else:
|
||||
messages = args[1]
|
||||
|
||||
# log the prompt
|
||||
debug_message = ''
|
||||
for message in messages:
|
||||
content = message['content']
|
||||
|
||||
if isinstance(content, list):
|
||||
for element in content:
|
||||
if isinstance(element, dict):
|
||||
if 'text' in element:
|
||||
content_str = element['text']
|
||||
elif (
|
||||
'image_url' in element and 'url' in element['image_url']
|
||||
):
|
||||
content_str = element['image_url']['url']
|
||||
else:
|
||||
content_str = str(element)
|
||||
else:
|
||||
content_str = str(element)
|
||||
|
||||
debug_message += message_separator + content_str
|
||||
else:
|
||||
content_str = str(content)
|
||||
|
||||
debug_message += message_separator + content_str
|
||||
|
||||
llm_prompt_logger.debug(debug_message)
|
||||
|
||||
async def check_stopped():
|
||||
while True:
|
||||
if (
|
||||
hasattr(self.config, 'on_cancel_requested_fn')
|
||||
and self.config.on_cancel_requested_fn is not None
|
||||
and await self.config.on_cancel_requested_fn()
|
||||
):
|
||||
raise UserCancelledError('LLM request cancelled by user')
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
stop_check_task = asyncio.create_task(check_stopped())
|
||||
|
||||
try:
|
||||
# Directly call and await litellm_acompletion
|
||||
resp = await async_completion_unwrapped(*args, **kwargs)
|
||||
|
||||
# skip if messages is empty (thus debug_message is empty)
|
||||
if debug_message:
|
||||
message_back = resp['choices'][0]['message']['content']
|
||||
llm_response_logger.debug(message_back)
|
||||
else:
|
||||
resp = {'choices': [{'message': {'content': ''}}]}
|
||||
self._post_completion(resp)
|
||||
|
||||
# We do not support streaming in this method, thus return resp
|
||||
return resp
|
||||
|
||||
except UserCancelledError:
|
||||
logger.info('LLM request cancelled by user.')
|
||||
raise
|
||||
except OpenAIError as e:
|
||||
logger.error(f'OpenAIError occurred:\n{e}')
|
||||
raise
|
||||
except (
|
||||
RateLimitError,
|
||||
APIConnectionError,
|
||||
ServiceUnavailableError,
|
||||
InternalServerError,
|
||||
) as e:
|
||||
logger.error(f'Completion Error occurred:\n{e}')
|
||||
raise
|
||||
|
||||
finally:
|
||||
await asyncio.sleep(0.1)
|
||||
stop_check_task.cancel()
|
||||
try:
|
||||
await stop_check_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
@retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(self.config.num_retries),
|
||||
wait=wait_random_exponential(
|
||||
multiplier=self.config.retry_multiplier,
|
||||
min=self.config.retry_min_wait,
|
||||
max=self.config.retry_max_wait,
|
||||
),
|
||||
retry=retry_if_exception_type(
|
||||
(
|
||||
RateLimitError,
|
||||
APIConnectionError,
|
||||
ServiceUnavailableError,
|
||||
InternalServerError,
|
||||
ContentPolicyViolationError,
|
||||
)
|
||||
),
|
||||
after=attempt_on_error,
|
||||
)
|
||||
async def async_acompletion_stream_wrapper(*args, **kwargs):
|
||||
"""Async wrapper for the litellm acompletion with streaming function."""
|
||||
# some callers might just send the messages directly
|
||||
if 'messages' in kwargs:
|
||||
messages = kwargs['messages']
|
||||
else:
|
||||
messages = args[1]
|
||||
|
||||
# log the prompt
|
||||
debug_message = ''
|
||||
for message in messages:
|
||||
debug_message += message_separator + message['content']
|
||||
llm_prompt_logger.debug(debug_message)
|
||||
|
||||
try:
|
||||
# Directly call and await litellm_acompletion
|
||||
resp = await async_completion_unwrapped(*args, **kwargs)
|
||||
|
||||
# For streaming we iterate over the chunks
|
||||
async for chunk in resp:
|
||||
# Check for cancellation before yielding the chunk
|
||||
if (
|
||||
hasattr(self.config, 'on_cancel_requested_fn')
|
||||
and self.config.on_cancel_requested_fn is not None
|
||||
and await self.config.on_cancel_requested_fn()
|
||||
):
|
||||
raise UserCancelledError(
|
||||
'LLM request cancelled due to CANCELLED state'
|
||||
)
|
||||
# with streaming, it is "delta", not "message"!
|
||||
message_back = chunk['choices'][0]['delta']['content']
|
||||
llm_response_logger.debug(message_back)
|
||||
self._post_completion(chunk)
|
||||
|
||||
yield chunk
|
||||
|
||||
except UserCancelledError:
|
||||
logger.info('LLM request cancelled by user.')
|
||||
raise
|
||||
except OpenAIError as e:
|
||||
logger.error(f'OpenAIError occurred:\n{e}')
|
||||
raise
|
||||
except (
|
||||
RateLimitError,
|
||||
APIConnectionError,
|
||||
ServiceUnavailableError,
|
||||
InternalServerError,
|
||||
) as e:
|
||||
logger.error(f'Completion Error occurred:\n{e}')
|
||||
raise
|
||||
|
||||
finally:
|
||||
if kwargs.get('stream', False):
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
self._async_completion = async_completion_wrapper # type: ignore
|
||||
self._async_streaming_completion = async_acompletion_stream_wrapper # type: ignore
|
||||
|
||||
async def _call_acompletion(self, *args, **kwargs):
|
||||
return await litellm.acompletion(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def completion(self):
|
||||
"""Decorator for the litellm completion function.
|
||||
|
||||
Check the complete documentation at https://litellm.vercel.app/docs/completion
|
||||
"""
|
||||
return self._completion
|
||||
|
||||
@property
|
||||
def async_completion(self):
|
||||
"""Decorator for the async litellm acompletion function.
|
||||
|
||||
Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
|
||||
"""
|
||||
return self._async_completion
|
||||
|
||||
@property
|
||||
def async_streaming_completion(self):
|
||||
"""Decorator for the async litellm acompletion function with streaming.
|
||||
|
||||
Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
|
||||
"""
|
||||
return self._async_streaming_completion
|
||||
|
||||
def supports_vision(self):
|
||||
return litellm.supports_vision(self.config.model)
|
||||
|
||||
def _post_completion(self, response: str) -> None:
|
||||
"""Post-process the completion response."""
|
||||
try:
|
||||
cur_cost = self.completion_cost(response)
|
||||
except Exception:
|
||||
cur_cost = 0
|
||||
if self.cost_metric_supported:
|
||||
logger.info(
|
||||
'Cost: %.2f USD | Accumulated Cost: %.2f USD',
|
||||
cur_cost,
|
||||
self.metrics.accumulated_cost,
|
||||
)
|
||||
|
||||
def get_token_count(self, messages):
|
||||
"""Get the number of tokens in a list of messages.
|
||||
|
||||
Args:
|
||||
messages (list): A list of messages.
|
||||
|
||||
Returns:
|
||||
int: The number of tokens.
|
||||
"""
|
||||
return litellm.token_counter(model=self.config.model, messages=messages)
|
||||
|
||||
def is_local(self):
|
||||
"""Determines if the system is using a locally running LLM.
|
||||
|
||||
Returns:
|
||||
boolean: True if executing a local model.
|
||||
"""
|
||||
if self.config.base_url is not None:
|
||||
for substring in ['localhost', '127.0.0.1' '0.0.0.0']:
|
||||
if substring in self.config.base_url:
|
||||
return True
|
||||
elif self.config.model is not None:
|
||||
if self.config.model.startswith('ollama'):
|
||||
return True
|
||||
return False
|
||||
|
||||
def completion_cost(self, response):
|
||||
"""Calculate the cost of a completion response based on the model. Local models are treated as free.
|
||||
Add the current cost into total cost in metrics.
|
||||
|
||||
Args:
|
||||
response: A response from a model invocation.
|
||||
|
||||
Returns:
|
||||
number: The cost of the response.
|
||||
"""
|
||||
if not self.cost_metric_supported:
|
||||
return 0.0
|
||||
|
||||
extra_kwargs = {}
|
||||
if (
|
||||
self.config.input_cost_per_token is not None
|
||||
and self.config.output_cost_per_token is not None
|
||||
):
|
||||
cost_per_token = CostPerToken(
|
||||
input_cost_per_token=self.config.input_cost_per_token,
|
||||
output_cost_per_token=self.config.output_cost_per_token,
|
||||
)
|
||||
logger.info(f'Using custom cost per token: {cost_per_token}')
|
||||
extra_kwargs['custom_cost_per_token'] = cost_per_token
|
||||
|
||||
if not self.is_local():
|
||||
try:
|
||||
cost = litellm_completion_cost(
|
||||
completion_response=response, **extra_kwargs
|
||||
)
|
||||
self.metrics.add_cost(cost)
|
||||
return cost
|
||||
except Exception:
|
||||
self.cost_metric_supported = False
|
||||
logger.warning('Cost calculation not supported for this model.')
|
||||
return 0.0
|
||||
|
||||
def __str__(self):
|
||||
if self.config.api_version:
|
||||
return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
|
||||
elif self.config.base_url:
|
||||
return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
|
||||
return f'LLM(model={self.config.model})'
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
def reset(self):
|
||||
self.metrics = Metrics()
|
||||
23
openhands/memory/README.md
Normal file
23
openhands/memory/README.md
Normal file
@@ -0,0 +1,23 @@
|
||||
# Memory Component
|
||||
|
||||
- Short Term History
|
||||
- Memory Condenser
|
||||
- Long Term Memory
|
||||
|
||||
## Short Term History
|
||||
- Short term history filters the event stream and computes the messages that are injected into the context
|
||||
- It filters out certain events of no interest for the Agent, such as AgentChangeStateObservation or NullAction/NullObservation
|
||||
- When the context window or the token limit set by the user is exceeded, history starts condensing: chunks of messages into summaries.
|
||||
- Each summary is then injected into the context, in the place of the respective chunk it summarizes
|
||||
|
||||
## Memory Condenser
|
||||
- Memory condenser is responsible for summarizing the chunks of events
|
||||
- It summarizes the earlier events first
|
||||
- It starts with the earliest agent actions and observations between two user messages
|
||||
- Then it does the same for later chunks of events between user messages
|
||||
- If there are no more agent events, it summarizes the user messages, this time one by one, if they're large enough and not immediately after an AgentFinishAction event (we assume those are tasks, potentially important)
|
||||
- Summaries are retrieved from the LLM as AgentSummarizeAction, and are saved in State.
|
||||
|
||||
## Long Term Memory
|
||||
- Long term memory component stores embeddings for events and prompts in a vector store
|
||||
- The agent can query it when it needs detailed information about a past event or to learn new actions
|
||||
5
openhands/memory/__init__.py
Normal file
5
openhands/memory/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .condenser import MemoryCondenser
|
||||
from .history import ShortTermHistory
|
||||
from .memory import LongTermMemory
|
||||
|
||||
__all__ = ['LongTermMemory', 'ShortTermHistory', 'MemoryCondenser']
|
||||
24
openhands/memory/condenser.py
Normal file
24
openhands/memory/condenser.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.llm.llm import LLM
|
||||
|
||||
|
||||
class MemoryCondenser:
|
||||
def condense(self, summarize_prompt: str, llm: LLM):
|
||||
"""Attempts to condense the memory by using the llm
|
||||
|
||||
Parameters:
|
||||
- llm (LLM): llm to be used for summarization
|
||||
|
||||
Raises:
|
||||
- Exception: the same exception as it got from the llm or processing the response
|
||||
"""
|
||||
try:
|
||||
messages = [{'content': summarize_prompt, 'role': 'user'}]
|
||||
resp = llm.completion(messages=messages)
|
||||
summary_response = resp['choices'][0]['message']['content']
|
||||
return summary_response
|
||||
except Exception as e:
|
||||
logger.error('Error condensing thoughts: %s', str(e), exc_info=False)
|
||||
|
||||
# TODO If the llm fails with ContextWindowExceededError, we can try to condense the memory chunk by chunk
|
||||
raise
|
||||
261
openhands/memory/history.py
Normal file
261
openhands/memory/history.py
Normal file
@@ -0,0 +1,261 @@
|
||||
from typing import ClassVar, Iterable
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.action.agent import (
|
||||
AgentDelegateAction,
|
||||
ChangeAgentStateAction,
|
||||
)
|
||||
from openhands.events.action.empty import NullAction
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.events.observation.commands import CmdOutputObservation
|
||||
from openhands.events.observation.delegate import AgentDelegateObservation
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.events.stream import EventStream
|
||||
|
||||
|
||||
class ShortTermHistory(list[Event]):
|
||||
"""A list of events that represents the short-term memory of the agent.
|
||||
|
||||
This class provides methods to retrieve and filter the events in the history of the running agent from the event stream.
|
||||
"""
|
||||
|
||||
start_id: int
|
||||
end_id: int
|
||||
_event_stream: EventStream
|
||||
delegates: dict[tuple[int, int], tuple[str, str]]
|
||||
filter_out: ClassVar[tuple[type[Event], ...]] = (
|
||||
NullAction,
|
||||
NullObservation,
|
||||
ChangeAgentStateAction,
|
||||
AgentStateChangedObservation,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.start_id = -1
|
||||
self.end_id = -1
|
||||
self.delegates = {}
|
||||
|
||||
def set_event_stream(self, event_stream: EventStream):
|
||||
self._event_stream = event_stream
|
||||
|
||||
def get_events_as_list(self) -> list[Event]:
|
||||
"""Return the history as a list of Event objects."""
|
||||
return list(self.get_events())
|
||||
|
||||
def get_events(self, reverse: bool = False) -> Iterable[Event]:
|
||||
"""Return the events as a stream of Event objects."""
|
||||
# TODO handle AgentRejectAction, if it's not part of a chunk ending with an AgentDelegateObservation
|
||||
# or even if it is, because currently we don't add it to the summary
|
||||
|
||||
# iterate from start_id to end_id, or reverse
|
||||
start_id = self.start_id if self.start_id != -1 else 0
|
||||
end_id = (
|
||||
self.end_id
|
||||
if self.end_id != -1
|
||||
else self._event_stream.get_latest_event_id()
|
||||
)
|
||||
|
||||
for event in self._event_stream.get_events(
|
||||
start_id=start_id,
|
||||
end_id=end_id,
|
||||
reverse=reverse,
|
||||
filter_out_type=self.filter_out,
|
||||
):
|
||||
# TODO add summaries
|
||||
# and filter out events that were included in a summary
|
||||
|
||||
# filter out the events from a delegate of the current agent
|
||||
if not any(
|
||||
# except for the delegate action and observation themselves, currently
|
||||
# AgentDelegateAction has id = delegate_start
|
||||
# AgentDelegateObservation has id = delegate_end
|
||||
delegate_start < event.id < delegate_end
|
||||
for delegate_start, delegate_end in self.delegates.keys()
|
||||
):
|
||||
yield event
|
||||
|
||||
def get_last_action(self, end_id: int = -1) -> Action | None:
|
||||
"""Return the last action from the event stream, filtered to exclude unwanted events."""
|
||||
# from end_id in reverse, find the first action
|
||||
end_id = self._event_stream.get_latest_event_id() if end_id == -1 else end_id
|
||||
|
||||
last_action = next(
|
||||
(
|
||||
event
|
||||
for event in self._event_stream.get_events(
|
||||
end_id=end_id, reverse=True, filter_out_type=self.filter_out
|
||||
)
|
||||
if isinstance(event, Action)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
return last_action
|
||||
|
||||
def get_last_observation(self, end_id: int = -1) -> Observation | None:
|
||||
"""Return the last observation from the event stream, filtered to exclude unwanted events."""
|
||||
# from end_id in reverse, find the first observation
|
||||
end_id = self._event_stream.get_latest_event_id() if end_id == -1 else end_id
|
||||
|
||||
last_observation = next(
|
||||
(
|
||||
event
|
||||
for event in self._event_stream.get_events(
|
||||
end_id=end_id, reverse=True, filter_out_type=self.filter_out
|
||||
)
|
||||
if isinstance(event, Observation)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
return last_observation
|
||||
|
||||
def get_last_user_message(self) -> str:
|
||||
"""Return the content of the last user message from the event stream."""
|
||||
last_user_message = next(
|
||||
(
|
||||
event.content
|
||||
for event in self._event_stream.get_events(reverse=True)
|
||||
if isinstance(event, MessageAction) and event.source == EventSource.USER
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
return last_user_message if last_user_message is not None else ''
|
||||
|
||||
def get_last_agent_message(self) -> str:
|
||||
"""Return the content of the last agent message from the event stream."""
|
||||
last_agent_message = next(
|
||||
(
|
||||
event.content
|
||||
for event in self._event_stream.get_events(reverse=True)
|
||||
if isinstance(event, MessageAction)
|
||||
and event.source == EventSource.AGENT
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
return last_agent_message if last_agent_message is not None else ''
|
||||
|
||||
def get_last_events(self, n: int) -> list[Event]:
|
||||
"""Return the last n events from the event stream."""
|
||||
# dummy agent is using this
|
||||
# it should work, but it's not great to store temporary lists now just for a test
|
||||
end_id = self._event_stream.get_latest_event_id()
|
||||
start_id = max(0, end_id - n + 1)
|
||||
|
||||
return list(
|
||||
event
|
||||
for event in self._event_stream.get_events(
|
||||
start_id=start_id,
|
||||
end_id=end_id,
|
||||
filter_out_type=self.filter_out,
|
||||
)
|
||||
)
|
||||
|
||||
def has_delegation(self) -> bool:
|
||||
for event in self._event_stream.get_events():
|
||||
if isinstance(event, AgentDelegateObservation):
|
||||
return True
|
||||
return False
|
||||
|
||||
def on_event(self, event: Event):
|
||||
if not isinstance(event, AgentDelegateObservation):
|
||||
return
|
||||
|
||||
logger.debug('AgentDelegateObservation received')
|
||||
|
||||
# figure out what this delegate's actions were
|
||||
# from the last AgentDelegateAction to this AgentDelegateObservation
|
||||
# and save their ids as start and end ids
|
||||
# in order to use later to exclude them from parent stream
|
||||
# or summarize them
|
||||
delegate_end = event.id
|
||||
delegate_start = -1
|
||||
delegate_agent: str = ''
|
||||
delegate_task: str = ''
|
||||
for prev_event in self._event_stream.get_events(
|
||||
end_id=event.id - 1, reverse=True
|
||||
):
|
||||
if isinstance(prev_event, AgentDelegateAction):
|
||||
delegate_start = prev_event.id
|
||||
delegate_agent = prev_event.agent
|
||||
delegate_task = prev_event.inputs.get('task', '')
|
||||
break
|
||||
|
||||
if delegate_start == -1:
|
||||
logger.error(
|
||||
f'No AgentDelegateAction found for AgentDelegateObservation with id={delegate_end}'
|
||||
)
|
||||
return
|
||||
|
||||
self.delegates[(delegate_start, delegate_end)] = (delegate_agent, delegate_task)
|
||||
logger.debug(
|
||||
f'Delegate {delegate_agent} with task {delegate_task} ran from id={delegate_start} to id={delegate_end}'
|
||||
)
|
||||
|
||||
# TODO remove me when unnecessary
|
||||
# history is now available as a filtered stream of events, rather than list of pairs of (Action, Observation)
|
||||
# we rebuild the pairs here
|
||||
# for compatibility with the existing output format in evaluations
|
||||
def compatibility_for_eval_history_pairs(self) -> list[tuple[dict, dict]]:
|
||||
history_pairs = []
|
||||
|
||||
for action, observation in self.get_pairs():
|
||||
history_pairs.append((event_to_dict(action), event_to_dict(observation)))
|
||||
|
||||
return history_pairs
|
||||
|
||||
def get_pairs(self) -> list[tuple[Action, Observation]]:
|
||||
"""Return the history as a list of tuples (action, observation)."""
|
||||
tuples: list[tuple[Action, Observation]] = []
|
||||
action_map: dict[int, Action] = {}
|
||||
observation_map: dict[int, Observation] = {}
|
||||
|
||||
# runnable actions are set as cause of observations
|
||||
# (MessageAction, NullObservation) for source=USER
|
||||
# (MessageAction, NullObservation) for source=AGENT
|
||||
# (other_action?, NullObservation)
|
||||
# (NullAction, CmdOutputObservation) background CmdOutputObservations
|
||||
|
||||
for event in self.get_events_as_list():
|
||||
if event.id is None or event.id == -1:
|
||||
logger.debug(f'Event {event} has no ID')
|
||||
|
||||
if isinstance(event, Action):
|
||||
action_map[event.id] = event
|
||||
|
||||
if isinstance(event, Observation):
|
||||
if event.cause is None or event.cause == -1:
|
||||
logger.debug(f'Observation {event} has no cause')
|
||||
|
||||
if event.cause is None:
|
||||
# runnable actions are set as cause of observations
|
||||
# NullObservations have no cause
|
||||
continue
|
||||
|
||||
observation_map[event.cause] = event
|
||||
|
||||
for action_id, action in action_map.items():
|
||||
observation = observation_map.get(action_id)
|
||||
if observation:
|
||||
# observation with a cause
|
||||
tuples.append((action, observation))
|
||||
else:
|
||||
tuples.append((action, NullObservation('')))
|
||||
|
||||
for cause_id, observation in observation_map.items():
|
||||
if cause_id not in action_map:
|
||||
if isinstance(observation, NullObservation):
|
||||
continue
|
||||
if not isinstance(observation, CmdOutputObservation):
|
||||
logger.debug(f'Observation {observation} has no cause')
|
||||
tuples.append((NullAction(), observation))
|
||||
|
||||
return tuples.copy()
|
||||
184
openhands/memory/memory.py
Normal file
184
openhands/memory/memory.py
Normal file
@@ -0,0 +1,184 @@
|
||||
import threading
|
||||
|
||||
from openai._exceptions import APIConnectionError, InternalServerError, RateLimitError
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.utils import json
|
||||
|
||||
try:
|
||||
import chromadb
|
||||
import llama_index.embeddings.openai.base as llama_openai
|
||||
from llama_index.core import Document, VectorStoreIndex
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
|
||||
LLAMA_INDEX_AVAILABLE = True
|
||||
except ImportError:
|
||||
LLAMA_INDEX_AVAILABLE = False
|
||||
|
||||
if LLAMA_INDEX_AVAILABLE:
|
||||
# TODO: this could be made configurable
|
||||
num_retries: int = 10
|
||||
retry_min_wait: int = 3
|
||||
retry_max_wait: int = 300
|
||||
|
||||
# llama-index includes a retry decorator around openai.get_embeddings() function
|
||||
# it is initialized with hard-coded values and errors
|
||||
# this non-customizable behavior is creating issues when it's retrying faster than providers' rate limits
|
||||
# this block attempts to banish it and replace it with our decorator, to allow users to set their own limits
|
||||
|
||||
if hasattr(llama_openai.get_embeddings, '__wrapped__'):
|
||||
original_get_embeddings = llama_openai.get_embeddings.__wrapped__
|
||||
else:
|
||||
logger.warning('Cannot set custom retry limits.')
|
||||
num_retries = 1
|
||||
original_get_embeddings = llama_openai.get_embeddings
|
||||
|
||||
def attempt_on_error(retry_state):
|
||||
logger.error(
|
||||
f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number} | You can customize these settings in the configuration.',
|
||||
exc_info=False,
|
||||
)
|
||||
return None
|
||||
|
||||
@retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(num_retries),
|
||||
wait=wait_random_exponential(min=retry_min_wait, max=retry_max_wait),
|
||||
retry=retry_if_exception_type(
|
||||
(RateLimitError, APIConnectionError, InternalServerError)
|
||||
),
|
||||
after=attempt_on_error,
|
||||
)
|
||||
def wrapper_get_embeddings(*args, **kwargs):
|
||||
return original_get_embeddings(*args, **kwargs)
|
||||
|
||||
llama_openai.get_embeddings = wrapper_get_embeddings
|
||||
|
||||
class EmbeddingsLoader:
|
||||
"""Loader for embedding model initialization."""
|
||||
|
||||
@staticmethod
|
||||
def get_embedding_model(strategy: str, llm_config: LLMConfig):
|
||||
supported_ollama_embed_models = [
|
||||
'llama2',
|
||||
'mxbai-embed-large',
|
||||
'nomic-embed-text',
|
||||
'all-minilm',
|
||||
'stable-code',
|
||||
]
|
||||
if strategy in supported_ollama_embed_models:
|
||||
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
|
||||
return OllamaEmbedding(
|
||||
model_name=strategy,
|
||||
base_url=llm_config.embedding_base_url,
|
||||
ollama_additional_kwargs={'mirostat': 0},
|
||||
)
|
||||
elif strategy == 'openai':
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
return OpenAIEmbedding(
|
||||
model='text-embedding-ada-002',
|
||||
api_key=llm_config.api_key,
|
||||
)
|
||||
elif strategy == 'azureopenai':
|
||||
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
|
||||
|
||||
return AzureOpenAIEmbedding(
|
||||
model='text-embedding-ada-002',
|
||||
deployment_name=llm_config.embedding_deployment_name,
|
||||
api_key=llm_config.api_key,
|
||||
azure_endpoint=llm_config.base_url,
|
||||
api_version=llm_config.api_version,
|
||||
)
|
||||
elif (strategy is not None) and (strategy.lower() == 'none'):
|
||||
# TODO: this works but is not elegant enough. The incentive is when
|
||||
# an agent using embeddings is not used, there is no reason we need to
|
||||
# initialize an embedding model
|
||||
return None
|
||||
else:
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
|
||||
return HuggingFaceEmbedding(model_name='BAAI/bge-small-en-v1.5')
|
||||
|
||||
|
||||
class LongTermMemory:
|
||||
"""Handles storing information for the agent to access later, using chromadb."""
|
||||
|
||||
def __init__(self, llm_config: LLMConfig, memory_max_threads: int = 1):
|
||||
"""Initialize the chromadb and set up ChromaVectorStore for later use."""
|
||||
if not LLAMA_INDEX_AVAILABLE:
|
||||
raise ImportError(
|
||||
'llama_index and its dependencies are not installed. '
|
||||
'To use LongTermMemory, please run: poetry install --with llama-index'
|
||||
)
|
||||
|
||||
db = chromadb.Client(chromadb.Settings(anonymized_telemetry=False))
|
||||
self.collection = db.get_or_create_collection(name='memories')
|
||||
vector_store = ChromaVectorStore(chroma_collection=self.collection)
|
||||
embedding_strategy = llm_config.embedding_model
|
||||
embed_model = EmbeddingsLoader.get_embedding_model(
|
||||
embedding_strategy, llm_config
|
||||
)
|
||||
self.index = VectorStoreIndex.from_vector_store(vector_store, embed_model)
|
||||
self.sema = threading.Semaphore(value=memory_max_threads)
|
||||
self.thought_idx = 0
|
||||
self._add_threads: list[threading.Thread] = []
|
||||
|
||||
def add_event(self, event: dict):
|
||||
"""Adds a new event to the long term memory with a unique id.
|
||||
|
||||
Parameters:
|
||||
- event (dict): The new event to be added to memory
|
||||
"""
|
||||
id = ''
|
||||
t = ''
|
||||
if 'action' in event:
|
||||
t = 'action'
|
||||
id = event['action']
|
||||
elif 'observation' in event:
|
||||
t = 'observation'
|
||||
id = event['observation']
|
||||
doc = Document(
|
||||
text=json.dumps(event),
|
||||
doc_id=str(self.thought_idx),
|
||||
extra_info={
|
||||
'type': t,
|
||||
'id': id,
|
||||
'idx': self.thought_idx,
|
||||
},
|
||||
)
|
||||
self.thought_idx += 1
|
||||
logger.debug('Adding %s event to memory: %d', t, self.thought_idx)
|
||||
thread = threading.Thread(target=self._add_doc, args=(doc,))
|
||||
self._add_threads.append(thread)
|
||||
thread.start() # We add the doc concurrently so we don't have to wait ~500ms for the insert
|
||||
|
||||
def _add_doc(self, doc):
|
||||
with self.sema:
|
||||
self.index.insert(doc)
|
||||
|
||||
def search(self, query: str, k: int = 10):
|
||||
"""Searches through the current memory using VectorIndexRetriever
|
||||
|
||||
Parameters:
|
||||
- query (str): A query to match search results to
|
||||
- k (int): Number of top results to return
|
||||
|
||||
Returns:
|
||||
- list[str]: list of top k results found in current memory
|
||||
"""
|
||||
retriever = VectorIndexRetriever(
|
||||
index=self.index,
|
||||
similarity_top_k=k,
|
||||
)
|
||||
results = retriever.retrieve(query)
|
||||
return [r.get_text() for r in results]
|
||||
21
openhands/runtime/__init__.py
Normal file
21
openhands/runtime/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from .e2b.sandbox import E2BBox
|
||||
|
||||
|
||||
def get_runtime_cls(name: str):
|
||||
# Local imports to avoid circular imports
|
||||
if name == 'eventstream':
|
||||
from .client.runtime import EventStreamRuntime
|
||||
|
||||
return EventStreamRuntime
|
||||
elif name == 'e2b':
|
||||
from .e2b.runtime import E2BRuntime
|
||||
|
||||
return E2BRuntime
|
||||
else:
|
||||
raise ValueError(f'Runtime {name} not supported')
|
||||
|
||||
|
||||
__all__ = [
|
||||
'E2BBox',
|
||||
'get_runtime_cls',
|
||||
]
|
||||
3
openhands/runtime/browser/__init__.py
Normal file
3
openhands/runtime/browser/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .utils import browse
|
||||
|
||||
__all__ = ['browse']
|
||||
231
openhands/runtime/browser/browser_env.py
Normal file
231
openhands/runtime/browser/browser_env.py
Normal file
@@ -0,0 +1,231 @@
|
||||
import atexit
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import multiprocessing
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import browsergym.core # noqa F401 (we register the openended task as a gym environment)
|
||||
import gymnasium as gym
|
||||
import html2text
|
||||
import numpy as np
|
||||
import tenacity
|
||||
from browsergym.utils.obs import flatten_dom_to_str
|
||||
from PIL import Image
|
||||
|
||||
from openhands.core.exceptions import BrowserInitException
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
BROWSER_EVAL_GET_GOAL_ACTION = 'GET_EVAL_GOAL'
|
||||
BROWSER_EVAL_GET_REWARDS_ACTION = 'GET_EVAL_REWARDS'
|
||||
|
||||
|
||||
class BrowserEnv:
|
||||
def __init__(self, browsergym_eval_env: str | None = None):
|
||||
self.html_text_converter = self.get_html_text_converter()
|
||||
self.eval_mode = False
|
||||
self.eval_dir = ''
|
||||
|
||||
# EVAL only: browsergym_eval_env must be provided for evaluation
|
||||
self.browsergym_eval_env = browsergym_eval_env
|
||||
self.eval_mode = bool(browsergym_eval_env)
|
||||
|
||||
# Initialize browser environment process
|
||||
multiprocessing.set_start_method('spawn', force=True)
|
||||
self.browser_side, self.agent_side = multiprocessing.Pipe()
|
||||
|
||||
self.init_browser()
|
||||
atexit.register(self.close)
|
||||
|
||||
def get_html_text_converter(self):
|
||||
html_text_converter = html2text.HTML2Text()
|
||||
# ignore links and images
|
||||
html_text_converter.ignore_links = False
|
||||
html_text_converter.ignore_images = True
|
||||
# use alt text for images
|
||||
html_text_converter.images_to_alt = True
|
||||
# disable auto text wrapping
|
||||
html_text_converter.body_width = 0
|
||||
return html_text_converter
|
||||
|
||||
@tenacity.retry(
|
||||
wait=tenacity.wait_fixed(1),
|
||||
stop=tenacity.stop_after_attempt(5),
|
||||
retry=tenacity.retry_if_exception_type(BrowserInitException),
|
||||
)
|
||||
def init_browser(self):
|
||||
logger.info('Starting browser env...')
|
||||
try:
|
||||
self.process = multiprocessing.Process(target=self.browser_process)
|
||||
self.process.start()
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to start browser process: {e}')
|
||||
raise
|
||||
|
||||
if not self.check_alive():
|
||||
self.close()
|
||||
raise BrowserInitException('Failed to start browser environment.')
|
||||
|
||||
def browser_process(self):
|
||||
if self.eval_mode:
|
||||
assert self.browsergym_eval_env is not None
|
||||
logger.info('Initializing browser env for web browsing evaluation.')
|
||||
if 'webarena' in self.browsergym_eval_env:
|
||||
import browsergym.webarena # noqa F401 register webarena tasks as gym environments
|
||||
elif 'miniwob' in self.browsergym_eval_env:
|
||||
import browsergym.miniwob # noqa F401 register miniwob tasks as gym environments
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Unsupported browsergym eval env: {self.browsergym_eval_env}'
|
||||
)
|
||||
env = gym.make(self.browsergym_eval_env)
|
||||
else:
|
||||
env = gym.make(
|
||||
'browsergym/openended',
|
||||
task_kwargs={'start_url': 'about:blank', 'goal': 'PLACEHOLDER_GOAL'},
|
||||
wait_for_user_message=False,
|
||||
headless=True,
|
||||
disable_env_checker=True,
|
||||
)
|
||||
|
||||
obs, info = env.reset()
|
||||
|
||||
# EVAL ONLY: save the goal into file for evaluation
|
||||
self.eval_goal = None
|
||||
self.eval_rewards: list[float] = []
|
||||
if self.eval_mode:
|
||||
logger.info(f"Browsing goal: {obs['goal']}")
|
||||
self.eval_goal = obs['goal']
|
||||
|
||||
logger.info('Browser env started.')
|
||||
while True:
|
||||
try:
|
||||
if self.browser_side.poll(timeout=0.01):
|
||||
unique_request_id, action_data = self.browser_side.recv()
|
||||
|
||||
# shutdown the browser environment
|
||||
if unique_request_id == 'SHUTDOWN':
|
||||
logger.info('SHUTDOWN recv, shutting down browser env...')
|
||||
env.close()
|
||||
return
|
||||
elif unique_request_id == 'IS_ALIVE':
|
||||
self.browser_side.send(('ALIVE', None))
|
||||
continue
|
||||
|
||||
# EVAL ONLY: Get evaluation info
|
||||
if action_data['action'] == BROWSER_EVAL_GET_GOAL_ACTION:
|
||||
self.browser_side.send(
|
||||
(unique_request_id, {'text_content': self.eval_goal})
|
||||
)
|
||||
continue
|
||||
elif action_data['action'] == BROWSER_EVAL_GET_REWARDS_ACTION:
|
||||
self.browser_side.send(
|
||||
(
|
||||
unique_request_id,
|
||||
{'text_content': json.dumps(self.eval_rewards)},
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
action = action_data['action']
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
|
||||
# EVAL ONLY: Save the rewards into file for evaluation
|
||||
if self.eval_mode:
|
||||
self.eval_rewards.append(reward)
|
||||
|
||||
# add text content of the page
|
||||
html_str = flatten_dom_to_str(obs['dom_object'])
|
||||
obs['text_content'] = self.html_text_converter.handle(html_str)
|
||||
# make observation serializable
|
||||
obs['screenshot'] = self.image_to_png_base64_url(obs['screenshot'])
|
||||
obs['active_page_index'] = obs['active_page_index'].item()
|
||||
obs['elapsed_time'] = obs['elapsed_time'].item()
|
||||
self.browser_side.send((unique_request_id, obs))
|
||||
except KeyboardInterrupt:
|
||||
logger.info('Browser env process interrupted by user.')
|
||||
try:
|
||||
env.close()
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
|
||||
def step(self, action_str: str, timeout: float = 30) -> dict:
|
||||
"""Execute an action in the browser environment and return the observation."""
|
||||
unique_request_id = str(uuid.uuid4())
|
||||
self.agent_side.send((unique_request_id, {'action': action_str}))
|
||||
start_time = time.time()
|
||||
while True:
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError('Browser environment took too long to respond.')
|
||||
if self.agent_side.poll(timeout=0.01):
|
||||
response_id, obs = self.agent_side.recv()
|
||||
if response_id == unique_request_id:
|
||||
return obs
|
||||
|
||||
def check_alive(self, timeout: float = 60):
|
||||
self.agent_side.send(('IS_ALIVE', None))
|
||||
if self.agent_side.poll(timeout=timeout):
|
||||
response_id, _ = self.agent_side.recv()
|
||||
if response_id == 'ALIVE':
|
||||
return True
|
||||
logger.info(f'Browser env is not alive. Response ID: {response_id}')
|
||||
|
||||
def close(self):
|
||||
if not self.process.is_alive():
|
||||
return
|
||||
try:
|
||||
self.agent_side.send(('SHUTDOWN', None))
|
||||
self.process.join(5) # Wait for the process to terminate
|
||||
if self.process.is_alive():
|
||||
logger.error(
|
||||
'Browser process did not terminate, forcefully terminating...'
|
||||
)
|
||||
self.process.terminate()
|
||||
self.process.join(5) # Wait for the process to terminate
|
||||
if self.process.is_alive():
|
||||
self.process.kill()
|
||||
self.process.join(5) # Wait for the process to terminate
|
||||
self.agent_side.close()
|
||||
self.browser_side.close()
|
||||
except Exception:
|
||||
logger.error('Encountered an error when closing browser env', exc_info=True)
|
||||
|
||||
@staticmethod
|
||||
def image_to_png_base64_url(
|
||||
image: np.ndarray | Image.Image, add_data_prefix: bool = False
|
||||
):
|
||||
"""Convert a numpy array to a base64 encoded png image url."""
|
||||
if isinstance(image, np.ndarray):
|
||||
image = Image.fromarray(image)
|
||||
if image.mode in ('RGBA', 'LA'):
|
||||
image = image.convert('RGB')
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format='PNG')
|
||||
|
||||
image_base64 = base64.b64encode(buffered.getvalue()).decode()
|
||||
return (
|
||||
f'data:image/png;base64,{image_base64}'
|
||||
if add_data_prefix
|
||||
else f'{image_base64}'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def image_to_jpg_base64_url(
|
||||
image: np.ndarray | Image.Image, add_data_prefix: bool = False
|
||||
):
|
||||
"""Convert a numpy array to a base64 encoded jpeg image url."""
|
||||
if isinstance(image, np.ndarray):
|
||||
image = Image.fromarray(image)
|
||||
if image.mode in ('RGBA', 'LA'):
|
||||
image = image.convert('RGB')
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format='JPEG')
|
||||
|
||||
image_base64 = base64.b64encode(buffered.getvalue()).decode()
|
||||
return (
|
||||
f'data:image/jpeg;base64,{image_base64}'
|
||||
if add_data_prefix
|
||||
else f'{image_base64}'
|
||||
)
|
||||
60
openhands/runtime/browser/utils.py
Normal file
60
openhands/runtime/browser/utils.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import os
|
||||
|
||||
from openhands.core.exceptions import BrowserUnavailableException
|
||||
from openhands.core.schema import ActionType
|
||||
from openhands.events.action import BrowseInteractiveAction, BrowseURLAction
|
||||
from openhands.events.observation import BrowserOutputObservation
|
||||
from openhands.runtime.browser.browser_env import BrowserEnv
|
||||
|
||||
|
||||
async def browse(
|
||||
action: BrowseURLAction | BrowseInteractiveAction, browser: BrowserEnv | None
|
||||
) -> BrowserOutputObservation:
|
||||
if browser is None:
|
||||
raise BrowserUnavailableException()
|
||||
|
||||
if isinstance(action, BrowseURLAction):
|
||||
# legacy BrowseURLAction
|
||||
asked_url = action.url
|
||||
if not asked_url.startswith('http'):
|
||||
asked_url = os.path.abspath(os.curdir) + action.url
|
||||
action_str = f'goto("{asked_url}")'
|
||||
|
||||
elif isinstance(action, BrowseInteractiveAction):
|
||||
# new BrowseInteractiveAction, supports full featured BrowserGym actions
|
||||
# action in BrowserGym: see https://github.com/ServiceNow/BrowserGym/blob/main/core/src/browsergym/core/action/functions.py
|
||||
action_str = action.browser_actions
|
||||
else:
|
||||
raise ValueError(f'Invalid action type: {action.action}')
|
||||
|
||||
try:
|
||||
# obs provided by BrowserGym: see https://github.com/ServiceNow/BrowserGym/blob/main/core/src/browsergym/core/env.py#L396
|
||||
obs = browser.step(action_str)
|
||||
return BrowserOutputObservation(
|
||||
content=obs['text_content'], # text content of the page
|
||||
url=obs.get('url', ''), # URL of the page
|
||||
screenshot=obs.get('screenshot', None), # base64-encoded screenshot, png
|
||||
open_pages_urls=obs.get('open_pages_urls', []), # list of open pages
|
||||
active_page_index=obs.get(
|
||||
'active_page_index', -1
|
||||
), # index of the active page
|
||||
dom_object=obs.get('dom_object', {}), # DOM object
|
||||
axtree_object=obs.get('axtree_object', {}), # accessibility tree object
|
||||
extra_element_properties=obs.get('extra_element_properties', {}),
|
||||
focused_element_bid=obs.get(
|
||||
'focused_element_bid', None
|
||||
), # focused element bid
|
||||
last_browser_action=obs.get(
|
||||
'last_action', ''
|
||||
), # last browser env action performed
|
||||
last_browser_action_error=obs.get('last_action_error', ''),
|
||||
error=True if obs.get('last_action_error', '') else False, # error flag
|
||||
)
|
||||
except Exception as e:
|
||||
return BrowserOutputObservation(
|
||||
content=str(e),
|
||||
screenshot='',
|
||||
error=True,
|
||||
last_browser_action_error=str(e),
|
||||
url=asked_url if action.action == ActionType.BROWSE else '',
|
||||
)
|
||||
4
openhands/runtime/builder/__init__.py
Normal file
4
openhands/runtime/builder/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .base import RuntimeBuilder
|
||||
from .docker import DockerRuntimeBuilder
|
||||
|
||||
__all__ = ['RuntimeBuilder', 'DockerRuntimeBuilder']
|
||||
37
openhands/runtime/builder/base.py
Normal file
37
openhands/runtime/builder/base.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import abc
|
||||
|
||||
|
||||
class RuntimeBuilder(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def build(
|
||||
self,
|
||||
path: str,
|
||||
tags: list[str],
|
||||
) -> str:
|
||||
"""
|
||||
Build the runtime image.
|
||||
|
||||
Args:
|
||||
path (str): The path to the runtime image's build directory.
|
||||
tags (list[str]): The tags to apply to the runtime image (e.g., ["repo:my-repo", "sha:my-sha"]).
|
||||
|
||||
Returns:
|
||||
str: The name of the runtime image (e.g., "repo:sha").
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the build failed.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def image_exists(self, image_name: str) -> bool:
|
||||
"""
|
||||
Check if the runtime image exists.
|
||||
|
||||
Args:
|
||||
image_name (str): The name of the runtime image (e.g., "repo:sha").
|
||||
|
||||
Returns:
|
||||
bool: Whether the runtime image exists.
|
||||
"""
|
||||
pass
|
||||
83
openhands/runtime/builder/docker.py
Normal file
83
openhands/runtime/builder/docker.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import docker
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
from .base import RuntimeBuilder
|
||||
|
||||
|
||||
class DockerRuntimeBuilder(RuntimeBuilder):
|
||||
def __init__(self, docker_client: docker.DockerClient):
|
||||
self.docker_client = docker_client
|
||||
|
||||
def build(self, path: str, tags: list[str]) -> str:
|
||||
target_image_hash_name = tags[0]
|
||||
target_image_repo, target_image_hash_tag = target_image_hash_name.split(':')
|
||||
target_image_tag = tags[1].split(':')[1] if len(tags) > 1 else None
|
||||
|
||||
try:
|
||||
build_logs = self.docker_client.api.build(
|
||||
path=path,
|
||||
tag=target_image_hash_name,
|
||||
rm=True,
|
||||
decode=True,
|
||||
)
|
||||
except docker.errors.BuildError as e:
|
||||
logger.error(f'Sandbox image build failed: {e}')
|
||||
raise RuntimeError(f'Sandbox image build failed: {e}')
|
||||
|
||||
for log in build_logs:
|
||||
if 'stream' in log:
|
||||
print(log['stream'].strip())
|
||||
elif 'error' in log:
|
||||
logger.error(log['error'].strip())
|
||||
else:
|
||||
logger.info(str(log))
|
||||
|
||||
logger.info(f'Image [{target_image_hash_name}] build finished.')
|
||||
|
||||
assert (
|
||||
target_image_tag
|
||||
), f'Expected target image tag [{target_image_tag}] is None'
|
||||
image = self.docker_client.images.get(target_image_hash_name)
|
||||
image.tag(target_image_repo, target_image_tag)
|
||||
logger.info(
|
||||
f'Re-tagged image [{target_image_hash_name}] with more generic tag [{target_image_tag}]'
|
||||
)
|
||||
|
||||
# Check if the image is built successfully
|
||||
image = self.docker_client.images.get(target_image_hash_name)
|
||||
if image is None:
|
||||
raise RuntimeError(
|
||||
f'Build failed: Image {target_image_hash_name} not found'
|
||||
)
|
||||
|
||||
tags_str = (
|
||||
f'{target_image_hash_tag}, {target_image_tag}'
|
||||
if target_image_tag
|
||||
else target_image_hash_tag
|
||||
)
|
||||
logger.info(
|
||||
f'Image {target_image_repo} with tags [{tags_str}] built successfully'
|
||||
)
|
||||
return target_image_hash_name
|
||||
|
||||
def image_exists(self, image_name: str) -> bool:
|
||||
"""Check if the image exists in the registry (try to pull it first) AND in the local store.
|
||||
|
||||
Args:
|
||||
image_name (str): The Docker image to check (<image repo>:<image tag>)
|
||||
Returns:
|
||||
bool: Whether the Docker image exists in the registry and in the local store
|
||||
"""
|
||||
# Try to pull the Docker image from the registry
|
||||
try:
|
||||
self.docker_client.images.pull(image_name)
|
||||
except Exception:
|
||||
logger.info(f'Cannot pull image {image_name} directly')
|
||||
|
||||
images = self.docker_client.images.list()
|
||||
if images:
|
||||
for image in images:
|
||||
if image_name in image.tags:
|
||||
return True
|
||||
return False
|
||||
685
openhands/runtime/client/client.py
Normal file
685
openhands/runtime/client/client.py
Normal file
@@ -0,0 +1,685 @@
|
||||
"""
|
||||
This is the main file for the runtime client.
|
||||
It is responsible for executing actions received from OpenHands backend and producing observations.
|
||||
|
||||
NOTE: this will be executed inside the docker sandbox.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import pexpect
|
||||
from fastapi import FastAPI, HTTPException, Request, UploadFile
|
||||
from fastapi.responses import JSONResponse
|
||||
from pathspec import PathSpec
|
||||
from pathspec.patterns import GitWildMatchPattern
|
||||
from pydantic import BaseModel
|
||||
from uvicorn import run
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
BrowseInteractiveAction,
|
||||
BrowseURLAction,
|
||||
CmdRunAction,
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
IPythonRunCellAction,
|
||||
)
|
||||
from openhands.events.observation import (
|
||||
CmdOutputObservation,
|
||||
ErrorObservation,
|
||||
FileReadObservation,
|
||||
FileWriteObservation,
|
||||
IPythonRunCellObservation,
|
||||
Observation,
|
||||
)
|
||||
from openhands.events.serialization import event_from_dict, event_to_dict
|
||||
from openhands.runtime.browser import browse
|
||||
from openhands.runtime.browser.browser_env import BrowserEnv
|
||||
from openhands.runtime.plugins import (
|
||||
ALL_PLUGINS,
|
||||
JupyterPlugin,
|
||||
Plugin,
|
||||
)
|
||||
from openhands.runtime.utils import split_bash_commands
|
||||
from openhands.runtime.utils.files import insert_lines, read_lines
|
||||
|
||||
|
||||
class ActionRequest(BaseModel):
|
||||
action: dict
|
||||
|
||||
|
||||
ROOT_GID = 0
|
||||
INIT_COMMANDS = [
|
||||
'git config --global user.name "openhands"',
|
||||
'git config --global user.email "openhands@all-hands.dev"',
|
||||
"alias git='git --no-pager'",
|
||||
]
|
||||
|
||||
|
||||
class RuntimeClient:
|
||||
"""RuntimeClient is running inside docker sandbox.
|
||||
It is responsible for executing actions received from OpenHands backend and producing observations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
plugins_to_load: list[Plugin],
|
||||
work_dir: str,
|
||||
username: str,
|
||||
user_id: int,
|
||||
browsergym_eval_env: str | None,
|
||||
) -> None:
|
||||
self.plugins_to_load = plugins_to_load
|
||||
self.username = username
|
||||
self.user_id = user_id
|
||||
self.pwd = work_dir # current PWD
|
||||
self._initial_pwd = work_dir
|
||||
self._init_user(self.username, self.user_id)
|
||||
self._init_bash_shell(self.pwd, self.username)
|
||||
self.lock = asyncio.Lock()
|
||||
self.plugins: dict[str, Plugin] = {}
|
||||
self.browser = BrowserEnv(browsergym_eval_env)
|
||||
self._initial_pwd = work_dir
|
||||
|
||||
@property
|
||||
def initial_pwd(self):
|
||||
return self._initial_pwd
|
||||
|
||||
async def ainit(self):
|
||||
for plugin in self.plugins_to_load:
|
||||
await plugin.initialize(self.username)
|
||||
self.plugins[plugin.name] = plugin
|
||||
logger.info(f'Initializing plugin: {plugin.name}')
|
||||
|
||||
if isinstance(plugin, JupyterPlugin):
|
||||
await self.run_ipython(
|
||||
IPythonRunCellAction(code=f'import os; os.chdir("{self.pwd}")')
|
||||
)
|
||||
|
||||
# This is a temporary workaround
|
||||
# TODO: refactor AgentSkills to be part of JupyterPlugin
|
||||
# AFTER ServerRuntime is deprecated
|
||||
if 'agent_skills' in self.plugins and 'jupyter' in self.plugins:
|
||||
obs = await self.run_ipython(
|
||||
IPythonRunCellAction(
|
||||
code='from openhands.runtime.plugins.agent_skills.agentskills import *\n'
|
||||
)
|
||||
)
|
||||
logger.info(f'AgentSkills initialized: {obs}')
|
||||
|
||||
await self._init_bash_commands()
|
||||
|
||||
def _init_user(self, username: str, user_id: int) -> None:
|
||||
"""Create user if not exists."""
|
||||
# Skip root since it is already created
|
||||
if username == 'root':
|
||||
return
|
||||
|
||||
# Check if the username already exists
|
||||
try:
|
||||
subprocess.run(
|
||||
f'id -u {username}', shell=True, check=True, capture_output=True
|
||||
)
|
||||
logger.debug(f'User {username} already exists. Skipping creation.')
|
||||
return
|
||||
except subprocess.CalledProcessError:
|
||||
pass # User does not exist, continue with creation
|
||||
|
||||
# Add sudoer
|
||||
sudoer_line = r"echo '%sudo ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers"
|
||||
output = subprocess.run(sudoer_line, shell=True, capture_output=True)
|
||||
if output.returncode != 0:
|
||||
raise RuntimeError(f'Failed to add sudoer: {output.stderr.decode()}')
|
||||
logger.debug(f'Added sudoer successfully. Output: [{output.stdout.decode()}]')
|
||||
|
||||
# Attempt to add the user, retrying with incremented user_id if necessary
|
||||
while True:
|
||||
command = (
|
||||
f'useradd -rm -d /home/{username} -s /bin/bash '
|
||||
f'-g root -G sudo -u {user_id} {username}'
|
||||
)
|
||||
|
||||
if not os.path.exists(self.initial_pwd):
|
||||
command += f' && mkdir -p {self.initial_pwd}'
|
||||
command += f' && chown -R {username}:root {self.initial_pwd}'
|
||||
command += f' && chmod g+s {self.initial_pwd}'
|
||||
|
||||
output = subprocess.run(command, shell=True, capture_output=True)
|
||||
if output.returncode == 0:
|
||||
logger.debug(
|
||||
f'Added user {username} successfully with UID {user_id}. Output: [{output.stdout.decode()}]'
|
||||
)
|
||||
break
|
||||
elif f'UID {user_id} is not unique' in output.stderr.decode():
|
||||
logger.warning(
|
||||
f'UID {user_id} is not unique. Incrementing UID and retrying...'
|
||||
)
|
||||
user_id += 1
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f'Failed to create user {username}: {output.stderr.decode()}'
|
||||
)
|
||||
|
||||
def _init_bash_shell(self, work_dir: str, username: str) -> None:
|
||||
self.shell = pexpect.spawn(
|
||||
f'su - {username}',
|
||||
encoding='utf-8',
|
||||
echo=False,
|
||||
)
|
||||
self.__bash_PS1 = r'[PEXPECT_BEGIN] \u@\h:\w [PEXPECT_END]'
|
||||
|
||||
# This should NOT match "PS1=\u@\h:\w [PEXPECT]$" when `env` is executed
|
||||
self.__bash_expect_regex = (
|
||||
r'\[PEXPECT_BEGIN\] ([a-z0-9_-]*)@([a-zA-Z0-9.-]*):(.+) \[PEXPECT_END\]'
|
||||
)
|
||||
|
||||
self.shell.sendline(f'export PS1="{self.__bash_PS1}"; export PS2=""')
|
||||
self.shell.expect(self.__bash_expect_regex)
|
||||
|
||||
self.shell.sendline(f'cd {work_dir}')
|
||||
self.shell.expect(self.__bash_expect_regex)
|
||||
logger.debug(
|
||||
f'Bash initialized. Working directory: {work_dir}. Output: {self.shell.before}'
|
||||
)
|
||||
|
||||
async def _init_bash_commands(self):
|
||||
logger.info(f'Initializing by running {len(INIT_COMMANDS)} bash commands...')
|
||||
for command in INIT_COMMANDS:
|
||||
action = CmdRunAction(command=command)
|
||||
action.timeout = 300
|
||||
logger.debug(f'Executing init command: {command}')
|
||||
obs: CmdOutputObservation = await self.run(action)
|
||||
logger.debug(
|
||||
f'Init command outputs (exit code: {obs.exit_code}): {obs.content}'
|
||||
)
|
||||
assert obs.exit_code == 0
|
||||
|
||||
logger.info('Bash init commands completed')
|
||||
|
||||
def _get_bash_prompt_and_update_pwd(self):
|
||||
ps1 = self.shell.after
|
||||
|
||||
# begin at the last occurence of '[PEXPECT_BEGIN]'.
|
||||
# In multi-line bash commands, the prompt will be repeated
|
||||
# and the matched regex captures all of them
|
||||
# - we only want the last one (newest prompt)
|
||||
_begin_pos = ps1.rfind('[PEXPECT_BEGIN]')
|
||||
if _begin_pos != -1:
|
||||
ps1 = ps1[_begin_pos:]
|
||||
|
||||
# parse the ps1 to get username, hostname, and working directory
|
||||
matched = re.match(self.__bash_expect_regex, ps1)
|
||||
assert (
|
||||
matched is not None
|
||||
), f'Failed to parse bash prompt: {ps1}. This should not happen.'
|
||||
username, hostname, working_dir = matched.groups()
|
||||
self.pwd = os.path.expanduser(working_dir)
|
||||
|
||||
# re-assemble the prompt
|
||||
prompt = f'{username}@{hostname}:{working_dir} '
|
||||
if username == 'root':
|
||||
prompt += '#'
|
||||
else:
|
||||
prompt += '$'
|
||||
return prompt + ' '
|
||||
|
||||
def _execute_bash(
|
||||
self,
|
||||
command: str,
|
||||
timeout: int | None,
|
||||
keep_prompt: bool = True,
|
||||
) -> tuple[str, int]:
|
||||
logger.debug(f'Executing command: {command}')
|
||||
try:
|
||||
self.shell.sendline(command)
|
||||
self.shell.expect(self.__bash_expect_regex, timeout=timeout)
|
||||
|
||||
output = self.shell.before
|
||||
|
||||
# Get exit code
|
||||
self.shell.sendline('echo $?')
|
||||
logger.debug(f'Executing command for exit code: {command}')
|
||||
self.shell.expect(self.__bash_expect_regex, timeout=timeout)
|
||||
_exit_code_output = self.shell.before
|
||||
logger.debug(f'Exit code Output: {_exit_code_output}')
|
||||
exit_code = int(_exit_code_output.strip().split()[0])
|
||||
|
||||
except pexpect.TIMEOUT as e:
|
||||
self.shell.sendintr() # send SIGINT to the shell
|
||||
self.shell.expect(self.__bash_expect_regex, timeout=timeout)
|
||||
output = self.shell.before
|
||||
output += (
|
||||
'\r\n\r\n'
|
||||
+ f'[Command timed out after {timeout} seconds. SIGINT was sent to interrupt it.]'
|
||||
)
|
||||
exit_code = 130 # SIGINT
|
||||
logger.error(f'Failed to execute command: {command}. Error: {e}')
|
||||
|
||||
finally:
|
||||
bash_prompt = self._get_bash_prompt_and_update_pwd()
|
||||
if keep_prompt:
|
||||
output += '\r\n' + bash_prompt
|
||||
logger.debug(f'Command output: {output}')
|
||||
|
||||
return output, exit_code
|
||||
|
||||
async def run_action(self, action) -> Observation:
|
||||
action_type = action.action
|
||||
observation = await getattr(self, action_type)(action)
|
||||
return observation
|
||||
|
||||
async def run(self, action: CmdRunAction) -> CmdOutputObservation:
|
||||
try:
|
||||
assert (
|
||||
action.timeout is not None
|
||||
), f'Timeout argument is required for CmdRunAction: {action}'
|
||||
commands = split_bash_commands(action.command)
|
||||
all_output = ''
|
||||
for command in commands:
|
||||
output, exit_code = self._execute_bash(
|
||||
command,
|
||||
timeout=action.timeout,
|
||||
keep_prompt=action.keep_prompt,
|
||||
)
|
||||
if all_output:
|
||||
# previous output already exists with prompt "user@hostname:working_dir #""
|
||||
# we need to add the command to the previous output,
|
||||
# so model knows the following is the output of another action)
|
||||
all_output = all_output.rstrip() + ' ' + command + '\r\n'
|
||||
|
||||
all_output += str(output) + '\r\n'
|
||||
if exit_code != 0:
|
||||
break
|
||||
return CmdOutputObservation(
|
||||
command_id=-1,
|
||||
content=all_output.rstrip('\r\n'),
|
||||
command=action.command,
|
||||
exit_code=exit_code,
|
||||
)
|
||||
except UnicodeDecodeError:
|
||||
raise RuntimeError('Command output could not be decoded as utf-8')
|
||||
|
||||
async def run_ipython(self, action: IPythonRunCellAction) -> Observation:
|
||||
if 'jupyter' in self.plugins:
|
||||
_jupyter_plugin: JupyterPlugin = self.plugins['jupyter'] # type: ignore
|
||||
# This is used to make AgentSkills in Jupyter aware of the
|
||||
# current working directory in Bash
|
||||
if self.pwd != getattr(self, '_jupyter_pwd', None):
|
||||
logger.debug(
|
||||
f"{self.pwd} != {getattr(self, '_jupyter_pwd', None)} -> reset Jupyter PWD"
|
||||
)
|
||||
reset_jupyter_pwd_code = f'import os; os.chdir("{self.pwd}")'
|
||||
_aux_action = IPythonRunCellAction(code=reset_jupyter_pwd_code)
|
||||
_reset_obs = await _jupyter_plugin.run(_aux_action)
|
||||
logger.debug(
|
||||
f'Changed working directory in IPython to: {self.pwd}. Output: {_reset_obs}'
|
||||
)
|
||||
self._jupyter_pwd = self.pwd
|
||||
|
||||
obs: IPythonRunCellObservation = await _jupyter_plugin.run(action)
|
||||
obs.content = obs.content.rstrip()
|
||||
obs.content += f'\n[Jupyter current working directory: {self.pwd}]'
|
||||
return obs
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'JupyterRequirement not found. Unable to run IPython action.'
|
||||
)
|
||||
|
||||
def _get_working_directory(self):
|
||||
# NOTE: this is part of initialization, so we hard code the timeout
|
||||
result, exit_code = self._execute_bash('pwd', timeout=60, keep_prompt=False)
|
||||
if exit_code != 0:
|
||||
raise RuntimeError('Failed to get working directory')
|
||||
return result.strip()
|
||||
|
||||
def _resolve_path(self, path: str, working_dir: str) -> str:
|
||||
filepath = Path(path)
|
||||
if not filepath.is_absolute():
|
||||
return str(Path(working_dir) / filepath)
|
||||
return str(filepath)
|
||||
|
||||
async def read(self, action: FileReadAction) -> Observation:
|
||||
# NOTE: the client code is running inside the sandbox,
|
||||
# so there's no need to check permission
|
||||
working_dir = self._get_working_directory()
|
||||
filepath = self._resolve_path(action.path, working_dir)
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as file:
|
||||
lines = read_lines(file.readlines(), action.start, action.end)
|
||||
except FileNotFoundError:
|
||||
return ErrorObservation(
|
||||
f'File not found: {filepath}. Your current working directory is {working_dir}.'
|
||||
)
|
||||
except UnicodeDecodeError:
|
||||
return ErrorObservation(f'File could not be decoded as utf-8: {filepath}.')
|
||||
except IsADirectoryError:
|
||||
return ErrorObservation(
|
||||
f'Path is a directory: {filepath}. You can only read files'
|
||||
)
|
||||
|
||||
code_view = ''.join(lines)
|
||||
return FileReadObservation(path=filepath, content=code_view)
|
||||
|
||||
async def write(self, action: FileWriteAction) -> Observation:
|
||||
working_dir = self._get_working_directory()
|
||||
filepath = self._resolve_path(action.path, working_dir)
|
||||
|
||||
insert = action.content.split('\n')
|
||||
try:
|
||||
if not os.path.exists(os.path.dirname(filepath)):
|
||||
os.makedirs(os.path.dirname(filepath))
|
||||
|
||||
file_exists = os.path.exists(filepath)
|
||||
if file_exists:
|
||||
file_stat = os.stat(filepath)
|
||||
else:
|
||||
file_stat = None
|
||||
|
||||
mode = 'w' if not file_exists else 'r+'
|
||||
try:
|
||||
with open(filepath, mode, encoding='utf-8') as file:
|
||||
if mode != 'w':
|
||||
all_lines = file.readlines()
|
||||
new_file = insert_lines(
|
||||
insert, all_lines, action.start, action.end
|
||||
)
|
||||
else:
|
||||
new_file = [i + '\n' for i in insert]
|
||||
|
||||
file.seek(0)
|
||||
file.writelines(new_file)
|
||||
file.truncate()
|
||||
|
||||
# Handle file permissions
|
||||
if file_exists:
|
||||
assert file_stat is not None
|
||||
# restore the original file permissions if the file already exists
|
||||
os.chmod(filepath, file_stat.st_mode)
|
||||
os.chown(filepath, file_stat.st_uid, file_stat.st_gid)
|
||||
else:
|
||||
# set the new file permissions if the file is new
|
||||
os.chmod(filepath, 0o644)
|
||||
os.chown(filepath, self.user_id, self.user_id)
|
||||
|
||||
except FileNotFoundError:
|
||||
return ErrorObservation(f'File not found: {filepath}')
|
||||
except IsADirectoryError:
|
||||
return ErrorObservation(
|
||||
f'Path is a directory: {filepath}. You can only write to files'
|
||||
)
|
||||
except UnicodeDecodeError:
|
||||
return ErrorObservation(
|
||||
f'File could not be decoded as utf-8: {filepath}'
|
||||
)
|
||||
except PermissionError:
|
||||
return ErrorObservation(f'Malformed paths not permitted: {filepath}')
|
||||
return FileWriteObservation(content='', path=filepath)
|
||||
|
||||
async def browse(self, action: BrowseURLAction) -> Observation:
|
||||
return await browse(action, self.browser)
|
||||
|
||||
async def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
|
||||
return await browse(action, self.browser)
|
||||
|
||||
def close(self):
|
||||
self.shell.close()
|
||||
self.browser.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('port', type=int, help='Port to listen on')
|
||||
parser.add_argument('--working-dir', type=str, help='Working directory')
|
||||
parser.add_argument('--plugins', type=str, help='Plugins to initialize', nargs='+')
|
||||
parser.add_argument(
|
||||
'--username', type=str, help='User to run as', default='openhands'
|
||||
)
|
||||
parser.add_argument('--user-id', type=int, help='User ID to run as', default=1000)
|
||||
parser.add_argument(
|
||||
'--browsergym-eval-env',
|
||||
type=str,
|
||||
help='BrowserGym environment used for browser evaluation',
|
||||
default=None,
|
||||
)
|
||||
# example: python client.py 8000 --working-dir /workspace --plugins JupyterRequirement
|
||||
args = parser.parse_args()
|
||||
|
||||
plugins_to_load: list[Plugin] = []
|
||||
if args.plugins:
|
||||
for plugin in args.plugins:
|
||||
if plugin not in ALL_PLUGINS:
|
||||
raise ValueError(f'Plugin {plugin} not found')
|
||||
plugins_to_load.append(ALL_PLUGINS[plugin]()) # type: ignore
|
||||
|
||||
client: RuntimeClient | None = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global client
|
||||
client = RuntimeClient(
|
||||
plugins_to_load,
|
||||
work_dir=args.working_dir,
|
||||
username=args.username,
|
||||
user_id=args.user_id,
|
||||
browsergym_eval_env=args.browsergym_eval_env,
|
||||
)
|
||||
await client.ainit()
|
||||
yield
|
||||
# Clean up & release the resources
|
||||
client.close()
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
@app.middleware('http')
|
||||
async def one_request_at_a_time(request: Request, call_next):
|
||||
assert client is not None
|
||||
async with client.lock:
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
@app.post('/execute_action')
|
||||
async def execute_action(action_request: ActionRequest):
|
||||
assert client is not None
|
||||
try:
|
||||
action = event_from_dict(action_request.action)
|
||||
if not isinstance(action, Action):
|
||||
raise HTTPException(status_code=400, detail='Invalid action type')
|
||||
observation = await client.run_action(action)
|
||||
return event_to_dict(observation)
|
||||
except Exception as e:
|
||||
logger.error(f'Error processing command: {str(e)}')
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post('/upload_file')
|
||||
async def upload_file(
|
||||
file: UploadFile, destination: str = '/', recursive: bool = False
|
||||
):
|
||||
assert client is not None
|
||||
|
||||
try:
|
||||
# Ensure the destination directory exists
|
||||
if not os.path.isabs(destination):
|
||||
raise HTTPException(
|
||||
status_code=400, detail='Destination must be an absolute path'
|
||||
)
|
||||
|
||||
full_dest_path = destination
|
||||
if not os.path.exists(full_dest_path):
|
||||
os.makedirs(full_dest_path, exist_ok=True)
|
||||
|
||||
if recursive:
|
||||
# For recursive uploads, we expect a zip file
|
||||
if not file.filename.endswith('.zip'):
|
||||
raise HTTPException(
|
||||
status_code=400, detail='Recursive uploads must be zip files'
|
||||
)
|
||||
|
||||
zip_path = os.path.join(full_dest_path, file.filename)
|
||||
with open(zip_path, 'wb') as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
|
||||
# Extract the zip file
|
||||
shutil.unpack_archive(zip_path, full_dest_path)
|
||||
os.remove(zip_path) # Remove the zip file after extraction
|
||||
|
||||
logger.info(
|
||||
f'Uploaded file {file.filename} and extracted to {destination}'
|
||||
)
|
||||
else:
|
||||
# For single file uploads
|
||||
file_path = os.path.join(full_dest_path, file.filename)
|
||||
with open(file_path, 'wb') as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
logger.info(f'Uploaded file {file.filename} to {destination}')
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
'filename': file.filename,
|
||||
'destination': destination,
|
||||
'recursive': recursive,
|
||||
},
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get('/alive')
|
||||
async def alive():
|
||||
return {'status': 'ok'}
|
||||
|
||||
# ================================
|
||||
# File-specific operations for UI
|
||||
# ================================
|
||||
|
||||
@app.post('/list_files')
|
||||
async def list_files(request: Request):
|
||||
"""List files in the specified path.
|
||||
|
||||
This function retrieves a list of files from the agent's runtime file store,
|
||||
excluding certain system and hidden files/directories.
|
||||
|
||||
To list files:
|
||||
```sh
|
||||
curl http://localhost:3000/api/list-files
|
||||
```
|
||||
|
||||
Args:
|
||||
request (Request): The incoming request object.
|
||||
path (str, optional): The path to list files from. Defaults to '/'.
|
||||
|
||||
Returns:
|
||||
list: A list of file names in the specified path.
|
||||
|
||||
Raises:
|
||||
HTTPException: If there's an error listing the files.
|
||||
"""
|
||||
assert client is not None
|
||||
|
||||
# get request as dict
|
||||
request_dict = await request.json()
|
||||
path = request_dict.get('path', None)
|
||||
|
||||
# Get the full path of the requested directory
|
||||
if path is None:
|
||||
full_path = client.initial_pwd
|
||||
elif os.path.isabs(path):
|
||||
full_path = path
|
||||
else:
|
||||
full_path = os.path.join(client.initial_pwd, path)
|
||||
|
||||
if not os.path.exists(full_path):
|
||||
return JSONResponse(
|
||||
content={'error': f'Directory {full_path} does not exist'},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
try:
|
||||
# Check if the directory exists
|
||||
if not os.path.exists(full_path) or not os.path.isdir(full_path):
|
||||
return []
|
||||
|
||||
# Check if .gitignore exists
|
||||
gitignore_path = os.path.join(full_path, '.gitignore')
|
||||
if os.path.exists(gitignore_path):
|
||||
# Use PathSpec to parse .gitignore
|
||||
with open(gitignore_path, 'r') as f:
|
||||
spec = PathSpec.from_lines(GitWildMatchPattern, f.readlines())
|
||||
else:
|
||||
# Fallback to default exclude list if .gitignore doesn't exist
|
||||
default_exclude = [
|
||||
'.git',
|
||||
'.DS_Store',
|
||||
'.svn',
|
||||
'.hg',
|
||||
'.idea',
|
||||
'.vscode',
|
||||
'.settings',
|
||||
'.pytest_cache',
|
||||
'__pycache__',
|
||||
'node_modules',
|
||||
'vendor',
|
||||
'build',
|
||||
'dist',
|
||||
'bin',
|
||||
'logs',
|
||||
'log',
|
||||
'tmp',
|
||||
'temp',
|
||||
'coverage',
|
||||
'venv',
|
||||
'env',
|
||||
]
|
||||
spec = PathSpec.from_lines(GitWildMatchPattern, default_exclude)
|
||||
|
||||
entries = os.listdir(full_path)
|
||||
|
||||
# Filter entries using PathSpec
|
||||
filtered_entries = [
|
||||
os.path.join(full_path, entry)
|
||||
for entry in entries
|
||||
if not spec.match_file(os.path.relpath(entry, str(full_path)))
|
||||
]
|
||||
|
||||
# Separate directories and files
|
||||
directories = []
|
||||
files = []
|
||||
for entry in filtered_entries:
|
||||
# Remove leading slash and any parent directory components
|
||||
entry_relative = entry.lstrip('/').split('/')[-1]
|
||||
|
||||
# Construct the full path by joining the base path with the relative entry path
|
||||
full_entry_path = os.path.join(full_path, entry_relative)
|
||||
if os.path.exists(full_entry_path):
|
||||
is_dir = os.path.isdir(full_entry_path)
|
||||
if is_dir:
|
||||
# add trailing slash to directories
|
||||
# required by FE to differentiate directories and files
|
||||
entry = entry.rstrip('/') + '/'
|
||||
directories.append(entry)
|
||||
else:
|
||||
files.append(entry)
|
||||
|
||||
# Sort directories and files separately
|
||||
directories.sort(key=lambda s: s.lower())
|
||||
files.sort(key=lambda s: s.lower())
|
||||
|
||||
# Combine sorted directories and files
|
||||
sorted_entries = directories + files
|
||||
return sorted_entries
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Error listing files: {e}', exc_info=True)
|
||||
return []
|
||||
|
||||
logger.info(f'Starting action execution API on port {args.port}')
|
||||
print(f'Starting action execution API on port {args.port}')
|
||||
run(app, host='0.0.0.0', port=args.port)
|
||||
386
openhands/runtime/client/runtime.py
Normal file
386
openhands/runtime/client/runtime.py
Normal file
@@ -0,0 +1,386 @@
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
import uuid
|
||||
from zipfile import ZipFile
|
||||
|
||||
import aiohttp
|
||||
import docker
|
||||
import tenacity
|
||||
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events import EventStream
|
||||
from openhands.events.action import (
|
||||
BrowseInteractiveAction,
|
||||
BrowseURLAction,
|
||||
CmdRunAction,
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
IPythonRunCellAction,
|
||||
)
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.observation import (
|
||||
ErrorObservation,
|
||||
NullObservation,
|
||||
Observation,
|
||||
)
|
||||
from openhands.events.serialization import event_to_dict, observation_from_dict
|
||||
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
|
||||
from openhands.runtime.builder import DockerRuntimeBuilder
|
||||
from openhands.runtime.plugins import PluginRequirement
|
||||
from openhands.runtime.runtime import Runtime
|
||||
from openhands.runtime.utils import find_available_tcp_port
|
||||
from openhands.runtime.utils.runtime_build import build_runtime_image
|
||||
|
||||
|
||||
class EventStreamRuntime(Runtime):
|
||||
"""This runtime will subscribe the event stream.
|
||||
When receive an event, it will send the event to od-runtime-client which run inside the docker environment.
|
||||
"""
|
||||
|
||||
container_name_prefix = 'openhands-sandbox-'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: AppConfig,
|
||||
event_stream: EventStream,
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
container_image: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
config, event_stream, sid, plugins
|
||||
) # will initialize the event stream
|
||||
self._port = find_available_tcp_port()
|
||||
self.api_url = f'http://{self.config.sandbox.api_hostname}:{self._port}'
|
||||
self.session: aiohttp.ClientSession | None = None
|
||||
|
||||
self.instance_id = (
|
||||
sid + '_' + str(uuid.uuid4()) if sid is not None else str(uuid.uuid4())
|
||||
)
|
||||
# TODO: We can switch to aiodocker when `get_od_sandbox_image` is updated to use aiodocker
|
||||
self.docker_client: docker.DockerClient = self._init_docker_client()
|
||||
self.container_image = (
|
||||
self.config.sandbox.container_image
|
||||
if container_image is None
|
||||
else container_image
|
||||
)
|
||||
self.container_name = self.container_name_prefix + self.instance_id
|
||||
|
||||
self.container = None
|
||||
self.action_semaphore = asyncio.Semaphore(1) # Ensure one action at a time
|
||||
|
||||
self.runtime_builder = DockerRuntimeBuilder(self.docker_client)
|
||||
logger.debug(f'EventStreamRuntime `{sid}` config:\n{self.config}')
|
||||
|
||||
async def ainit(self, env_vars: dict[str, str] | None = None):
|
||||
if self.config.sandbox.od_runtime_extra_deps:
|
||||
logger.info(
|
||||
f'Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.od_runtime_extra_deps}'
|
||||
)
|
||||
|
||||
self.container_image = build_runtime_image(
|
||||
self.container_image,
|
||||
self.runtime_builder,
|
||||
extra_deps=self.config.sandbox.od_runtime_extra_deps,
|
||||
)
|
||||
self.container = await self._init_container(
|
||||
self.sandbox_workspace_dir,
|
||||
mount_dir=self.config.workspace_mount_path,
|
||||
plugins=self.plugins,
|
||||
)
|
||||
# MUST call super().ainit() to initialize both default env vars
|
||||
# AND the ones in env vars!
|
||||
await super().ainit(env_vars)
|
||||
|
||||
logger.info(
|
||||
f'Container initialized with plugins: {[plugin.name for plugin in self.plugins]}'
|
||||
)
|
||||
logger.info(f'Container initialized with env vars: {env_vars}')
|
||||
|
||||
@staticmethod
|
||||
def _init_docker_client() -> docker.DockerClient:
|
||||
try:
|
||||
return docker.from_env()
|
||||
except Exception as ex:
|
||||
logger.error(
|
||||
'Launch docker client failed. Please make sure you have installed docker and started the docker daemon.'
|
||||
)
|
||||
raise ex
|
||||
|
||||
@tenacity.retry(
|
||||
stop=tenacity.stop_after_attempt(5),
|
||||
wait=tenacity.wait_exponential(multiplier=1, min=4, max=60),
|
||||
)
|
||||
async def _init_container(
|
||||
self,
|
||||
sandbox_workspace_dir: str,
|
||||
mount_dir: str | None = None,
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
):
|
||||
try:
|
||||
logger.info(
|
||||
f'Starting container with image: {self.container_image} and name: {self.container_name}'
|
||||
)
|
||||
plugin_arg = ''
|
||||
if plugins is not None and len(plugins) > 0:
|
||||
plugin_arg = (
|
||||
f'--plugins {" ".join([plugin.name for plugin in plugins])} '
|
||||
)
|
||||
|
||||
network_mode: str | None = None
|
||||
port_mapping: dict[str, int] | None = None
|
||||
if self.config.sandbox.use_host_network:
|
||||
network_mode = 'host'
|
||||
logger.warn(
|
||||
'Using host network mode. If you are using MacOS, please make sure you have the latest version of Docker Desktop and enabled host network feature: https://docs.docker.com/network/drivers/host/#docker-desktop'
|
||||
)
|
||||
else:
|
||||
port_mapping = {f'{self._port}/tcp': self._port}
|
||||
|
||||
if mount_dir is not None:
|
||||
volumes = {mount_dir: {'bind': sandbox_workspace_dir, 'mode': 'rw'}}
|
||||
logger.info(f'Mount dir: {sandbox_workspace_dir}')
|
||||
else:
|
||||
logger.warn(
|
||||
'Mount dir is not set, will not mount the workspace directory to the container.'
|
||||
)
|
||||
volumes = None
|
||||
|
||||
if self.config.sandbox.browsergym_eval_env is not None:
|
||||
browsergym_arg = (
|
||||
f'--browsergym-eval-env {self.config.sandbox.browsergym_eval_env}'
|
||||
)
|
||||
else:
|
||||
browsergym_arg = ''
|
||||
container = self.docker_client.containers.run(
|
||||
self.container_image,
|
||||
command=(
|
||||
f'/openhands/miniforge3/bin/mamba run --no-capture-output -n base '
|
||||
'PYTHONUNBUFFERED=1 poetry run '
|
||||
f'python -u -m openhands.runtime.client.client {self._port} '
|
||||
f'--working-dir {sandbox_workspace_dir} '
|
||||
f'{plugin_arg}'
|
||||
f'--username {"openhands" if self.config.run_as_openhands else "root"} '
|
||||
f'--user-id {self.config.sandbox.user_id} '
|
||||
f'{browsergym_arg}'
|
||||
),
|
||||
network_mode=network_mode,
|
||||
ports=port_mapping,
|
||||
working_dir='/openhands/code/',
|
||||
name=self.container_name,
|
||||
detach=True,
|
||||
environment={'DEBUG': 'true'} if self.config.debug else None,
|
||||
volumes=volumes,
|
||||
)
|
||||
logger.info(f'Container started. Server url: {self.api_url}')
|
||||
return container
|
||||
except Exception as e:
|
||||
logger.error('Failed to start container')
|
||||
logger.exception(e)
|
||||
await self.close(close_client=False)
|
||||
raise e
|
||||
|
||||
async def _ensure_session(self):
|
||||
if self.session is None or self.session.closed:
|
||||
self.session = aiohttp.ClientSession()
|
||||
return self.session
|
||||
|
||||
@tenacity.retry(
|
||||
stop=tenacity.stop_after_attempt(10),
|
||||
wait=tenacity.wait_exponential(multiplier=2, min=10, max=60),
|
||||
)
|
||||
async def _wait_until_alive(self):
|
||||
logger.debug('Getting container logs...')
|
||||
container = self.docker_client.containers.get(self.container_name)
|
||||
# get logs
|
||||
_logs = container.logs(tail=10).decode('utf-8').split('\n')
|
||||
# add indent
|
||||
_logs = '\n'.join([f' |{log}' for log in _logs])
|
||||
logger.info(
|
||||
'\n'
|
||||
+ '-' * 30
|
||||
+ 'Container logs (last 10 lines):'
|
||||
+ '-' * 30
|
||||
+ f'\n{_logs}'
|
||||
+ '\n'
|
||||
+ '-' * 90
|
||||
)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(f'{self.api_url}/alive') as response:
|
||||
if response.status == 200:
|
||||
return
|
||||
else:
|
||||
msg = f'Action execution API is not alive. Response: {response}'
|
||||
logger.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
@property
|
||||
def sandbox_workspace_dir(self):
|
||||
return self.config.workspace_mount_path_in_sandbox
|
||||
|
||||
async def close(self, close_client: bool = True):
|
||||
if self.session is not None and not self.session.closed:
|
||||
await self.session.close()
|
||||
|
||||
containers = self.docker_client.containers.list(all=True)
|
||||
for container in containers:
|
||||
try:
|
||||
if container.name.startswith(self.container_name_prefix):
|
||||
logs = container.logs(tail=1000).decode('utf-8')
|
||||
logger.debug(
|
||||
f'==== Container logs ====\n{logs}\n==== End of container logs ===='
|
||||
)
|
||||
container.remove(force=True)
|
||||
except docker.errors.NotFound:
|
||||
pass
|
||||
if close_client:
|
||||
self.docker_client.close()
|
||||
|
||||
async def run_action(self, action: Action) -> Observation:
|
||||
# set timeout to default if not set
|
||||
if action.timeout is None:
|
||||
action.timeout = self.config.sandbox.timeout
|
||||
|
||||
async with self.action_semaphore:
|
||||
if not action.runnable:
|
||||
return NullObservation('')
|
||||
action_type = action.action # type: ignore[attr-defined]
|
||||
if action_type not in ACTION_TYPE_TO_CLASS:
|
||||
return ErrorObservation(f'Action {action_type} does not exist.')
|
||||
if not hasattr(self, action_type):
|
||||
return ErrorObservation(
|
||||
f'Action {action_type} is not supported in the current runtime.'
|
||||
)
|
||||
|
||||
logger.info('Awaiting session')
|
||||
session = await self._ensure_session()
|
||||
await self._wait_until_alive()
|
||||
|
||||
assert action.timeout is not None
|
||||
|
||||
try:
|
||||
logger.info('Executing command')
|
||||
async with session.post(
|
||||
f'{self.api_url}/execute_action',
|
||||
json={'action': event_to_dict(action)},
|
||||
timeout=action.timeout,
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
output = await response.json()
|
||||
obs = observation_from_dict(output)
|
||||
obs._cause = action.id # type: ignore[attr-defined]
|
||||
return obs
|
||||
else:
|
||||
error_message = await response.text()
|
||||
logger.error(f'Error from server: {error_message}')
|
||||
obs = ErrorObservation(
|
||||
f'Command execution failed: {error_message}'
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error('No response received within the timeout period.')
|
||||
obs = ErrorObservation('Command execution timed out')
|
||||
except Exception as e:
|
||||
logger.error(f'Error during command execution: {e}')
|
||||
obs = ErrorObservation(f'Command execution failed: {str(e)}')
|
||||
return obs
|
||||
|
||||
async def run(self, action: CmdRunAction) -> Observation:
|
||||
return await self.run_action(action)
|
||||
|
||||
async def run_ipython(self, action: IPythonRunCellAction) -> Observation:
|
||||
return await self.run_action(action)
|
||||
|
||||
async def read(self, action: FileReadAction) -> Observation:
|
||||
return await self.run_action(action)
|
||||
|
||||
async def write(self, action: FileWriteAction) -> Observation:
|
||||
return await self.run_action(action)
|
||||
|
||||
async def browse(self, action: BrowseURLAction) -> Observation:
|
||||
return await self.run_action(action)
|
||||
|
||||
async def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
|
||||
return await self.run_action(action)
|
||||
|
||||
# ====================================================================
|
||||
# Implement these methods (for file operations) in the subclass
|
||||
# ====================================================================
|
||||
|
||||
async def copy_to(
|
||||
self, host_src: str, sandbox_dest: str, recursive: bool = False
|
||||
) -> None:
|
||||
if not os.path.exists(host_src):
|
||||
raise FileNotFoundError(f'Source file {host_src} does not exist')
|
||||
|
||||
session = await self._ensure_session()
|
||||
await self._wait_until_alive()
|
||||
try:
|
||||
if recursive:
|
||||
# For recursive copy, create a zip file
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix='.zip', delete=False
|
||||
) as temp_zip:
|
||||
temp_zip_path = temp_zip.name
|
||||
|
||||
with ZipFile(temp_zip_path, 'w') as zipf:
|
||||
for root, _, files in os.walk(host_src):
|
||||
for file in files:
|
||||
file_path = os.path.join(root, file)
|
||||
arcname = os.path.relpath(
|
||||
file_path, os.path.dirname(host_src)
|
||||
)
|
||||
zipf.write(file_path, arcname)
|
||||
|
||||
upload_data = {'file': open(temp_zip_path, 'rb')}
|
||||
else:
|
||||
# For single file copy
|
||||
upload_data = {'file': open(host_src, 'rb')}
|
||||
|
||||
params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
|
||||
|
||||
async with session.post(
|
||||
f'{self.api_url}/upload_file', data=upload_data, params=params
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return
|
||||
else:
|
||||
error_message = await response.text()
|
||||
raise Exception(f'Copy operation failed: {error_message}')
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise TimeoutError('Copy operation timed out')
|
||||
except Exception as e:
|
||||
raise RuntimeError(f'Copy operation failed: {str(e)}')
|
||||
finally:
|
||||
if recursive:
|
||||
os.unlink(temp_zip_path)
|
||||
logger.info(f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}')
|
||||
|
||||
async def list_files(self, path: str | None = None) -> list[str]:
|
||||
"""List files in the sandbox.
|
||||
|
||||
If path is None, list files in the sandbox's initial working directory (e.g., /workspace).
|
||||
"""
|
||||
session = await self._ensure_session()
|
||||
await self._wait_until_alive()
|
||||
try:
|
||||
data = {}
|
||||
if path is not None:
|
||||
data['path'] = path
|
||||
|
||||
async with session.post(
|
||||
f'{self.api_url}/list_files', json=data
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
response_json = await response.json()
|
||||
assert isinstance(response_json, list)
|
||||
return response_json
|
||||
else:
|
||||
error_message = await response.text()
|
||||
raise Exception(f'List files operation failed: {error_message}')
|
||||
except asyncio.TimeoutError:
|
||||
raise TimeoutError('List files operation timed out')
|
||||
except Exception as e:
|
||||
raise RuntimeError(f'List files operation failed: {str(e)}')
|
||||
35
openhands/runtime/e2b/README.md
Normal file
35
openhands/runtime/e2b/README.md
Normal file
@@ -0,0 +1,35 @@
|
||||
# How to use E2B
|
||||
|
||||
[E2B](https://e2b.dev) is an [open-source](https://github.com/e2b-dev/e2b) secure cloud environment (sandbox) made for running AI-generated code and agents. E2B offers [Python](https://pypi.org/project/e2b/) and [JS/TS](https://www.npmjs.com/package/e2b) SDK to spawn and control these sandboxes.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. [Get your API key](https://e2b.dev/docs/getting-started/api-key)
|
||||
|
||||
1. Set your E2B API key to the `E2B_API_KEY` env var when starting the Docker container
|
||||
|
||||
1. **Optional** - Install the CLI with NPM.
|
||||
```sh
|
||||
npm install -g @e2b/cli@latest
|
||||
```
|
||||
Full CLI API is [here](https://e2b.dev/docs/cli/installation).
|
||||
|
||||
## OpenHands sandbox
|
||||
You can use the E2B CLI to create a custom sandbox with a Dockerfile. Read the full guide [here](https://e2b.dev/docs/guide/custom-sandbox). The premade OpenHands sandbox for E2B is set up in the [`containers` directory](/containers/e2b-sandbox). and it's called `openhands`.
|
||||
|
||||
## Debugging
|
||||
You can connect to a running E2B sandbox with E2B CLI in your terminal.
|
||||
|
||||
- List all running sandboxes (based on your API key)
|
||||
```sh
|
||||
e2b sandbox list
|
||||
```
|
||||
|
||||
- Connect to a running sandbox
|
||||
```sh
|
||||
e2b sandbox connect <sandbox-id>
|
||||
```
|
||||
|
||||
## Links
|
||||
- [E2B Docs](https://e2b.dev/docs)
|
||||
- [E2B GitHub](https://github.com/e2b-dev/e2b)
|
||||
18
openhands/runtime/e2b/filestore.py
Normal file
18
openhands/runtime/e2b/filestore.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
|
||||
class E2BFileStore(FileStore):
|
||||
def __init__(self, filesystem):
|
||||
self.filesystem = filesystem
|
||||
|
||||
def write(self, path: str, contents: str) -> None:
|
||||
self.filesystem.write(path, contents)
|
||||
|
||||
def read(self, path: str) -> str:
|
||||
return self.filesystem.read(path)
|
||||
|
||||
def list(self, path: str) -> list[str]:
|
||||
return self.filesystem.list(path)
|
||||
|
||||
def delete(self, path: str) -> None:
|
||||
self.filesystem.delete(path)
|
||||
57
openhands/runtime/e2b/runtime.py
Normal file
57
openhands/runtime/e2b/runtime.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.events.action import (
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
)
|
||||
from openhands.events.observation import (
|
||||
ErrorObservation,
|
||||
FileReadObservation,
|
||||
FileWriteObservation,
|
||||
Observation,
|
||||
)
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.runtime.plugins import PluginRequirement
|
||||
from openhands.runtime.runtime import Runtime
|
||||
|
||||
from ..utils.files import insert_lines, read_lines
|
||||
from .filestore import E2BFileStore
|
||||
from .sandbox import E2BSandbox
|
||||
|
||||
|
||||
class E2BRuntime(Runtime):
|
||||
def __init__(
|
||||
self,
|
||||
config: AppConfig,
|
||||
event_stream: EventStream,
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
sandbox: E2BSandbox | None = None,
|
||||
):
|
||||
super().__init__(config, event_stream, sid, plugins)
|
||||
if sandbox is None:
|
||||
self.sandbox = E2BSandbox()
|
||||
if not isinstance(self.sandbox, E2BSandbox):
|
||||
raise ValueError('E2BRuntime requires an E2BSandbox')
|
||||
self.file_store = E2BFileStore(self.sandbox.filesystem)
|
||||
|
||||
async def read(self, action: FileReadAction) -> Observation:
|
||||
content = self.file_store.read(action.path)
|
||||
lines = read_lines(content.split('\n'), action.start, action.end)
|
||||
code_view = ''.join(lines)
|
||||
return FileReadObservation(code_view, path=action.path)
|
||||
|
||||
async def write(self, action: FileWriteAction) -> Observation:
|
||||
if action.start == 0 and action.end == -1:
|
||||
self.file_store.write(action.path, action.content)
|
||||
return FileWriteObservation(content='', path=action.path)
|
||||
files = self.file_store.list(action.path)
|
||||
if action.path in files:
|
||||
all_lines = self.file_store.read(action.path).split('\n')
|
||||
new_file = insert_lines(
|
||||
action.content.split('\n'), all_lines, action.start, action.end
|
||||
)
|
||||
self.file_store.write(action.path, ''.join(new_file))
|
||||
return FileWriteObservation('', path=action.path)
|
||||
else:
|
||||
# FIXME: we should create a new file here
|
||||
return ErrorObservation(f'File not found: {action.path}')
|
||||
116
openhands/runtime/e2b/sandbox.py
Normal file
116
openhands/runtime/e2b/sandbox.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import copy
|
||||
import os
|
||||
import tarfile
|
||||
from glob import glob
|
||||
|
||||
from e2b import Sandbox as E2BSandbox
|
||||
from e2b.sandbox.exception import (
|
||||
TimeoutException,
|
||||
)
|
||||
|
||||
from openhands.core.config import SandboxConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class E2BBox:
|
||||
closed = False
|
||||
_cwd: str = '/home/user'
|
||||
_env: dict[str, str] = {}
|
||||
is_initial_session: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SandboxConfig,
|
||||
e2b_api_key: str,
|
||||
template: str = 'openhands',
|
||||
):
|
||||
self.config = copy.deepcopy(config)
|
||||
self.initialize_plugins: bool = config.initialize_plugins
|
||||
self.sandbox = E2BSandbox(
|
||||
api_key=e2b_api_key,
|
||||
template=template,
|
||||
# It's possible to stream stdout and stderr from sandbox and from each process
|
||||
on_stderr=lambda x: logger.info(f'E2B sandbox stderr: {x}'),
|
||||
on_stdout=lambda x: logger.info(f'E2B sandbox stdout: {x}'),
|
||||
cwd=self._cwd, # Default workdir inside sandbox
|
||||
)
|
||||
logger.info(f'Started E2B sandbox with ID "{self.sandbox.id}"')
|
||||
|
||||
@property
|
||||
def filesystem(self):
|
||||
return self.sandbox.filesystem
|
||||
|
||||
def _archive(self, host_src: str, recursive: bool = False):
|
||||
if recursive:
|
||||
assert os.path.isdir(
|
||||
host_src
|
||||
), 'Source must be a directory when recursive is True'
|
||||
files = glob(host_src + '/**/*', recursive=True)
|
||||
srcname = os.path.basename(host_src)
|
||||
tar_filename = os.path.join(os.path.dirname(host_src), srcname + '.tar')
|
||||
with tarfile.open(tar_filename, mode='w') as tar:
|
||||
for file in files:
|
||||
tar.add(
|
||||
file, arcname=os.path.relpath(file, os.path.dirname(host_src))
|
||||
)
|
||||
else:
|
||||
assert os.path.isfile(
|
||||
host_src
|
||||
), 'Source must be a file when recursive is False'
|
||||
srcname = os.path.basename(host_src)
|
||||
tar_filename = os.path.join(os.path.dirname(host_src), srcname + '.tar')
|
||||
with tarfile.open(tar_filename, mode='w') as tar:
|
||||
tar.add(host_src, arcname=srcname)
|
||||
return tar_filename
|
||||
|
||||
def execute(self, cmd: str, timeout: int | None = None) -> tuple[int, str]:
|
||||
timeout = timeout if timeout is not None else self.config.timeout
|
||||
process = self.sandbox.process.start(cmd, env_vars=self._env)
|
||||
try:
|
||||
process_output = process.wait(timeout=timeout)
|
||||
except TimeoutException:
|
||||
logger.info('Command timed out, killing process...')
|
||||
process.kill()
|
||||
return -1, f'Command: "{cmd}" timed out'
|
||||
|
||||
logs = [m.line for m in process_output.messages]
|
||||
logs_str = '\n'.join(logs)
|
||||
if process.exit_code is None:
|
||||
return -1, logs_str
|
||||
|
||||
assert process_output.exit_code is not None
|
||||
return process_output.exit_code, logs_str
|
||||
|
||||
def copy_to(self, host_src: str, sandbox_dest: str, recursive: bool = False):
|
||||
"""Copies a local file or directory to the sandbox."""
|
||||
tar_filename = self._archive(host_src, recursive)
|
||||
|
||||
# Prepend the sandbox destination with our sandbox cwd
|
||||
sandbox_dest = os.path.join(self._cwd, sandbox_dest.removeprefix('/'))
|
||||
|
||||
with open(tar_filename, 'rb') as tar_file:
|
||||
# Upload the archive to /home/user (default destination that always exists)
|
||||
uploaded_path = self.sandbox.upload_file(tar_file)
|
||||
|
||||
# Check if sandbox_dest exists. If not, create it.
|
||||
process = self.sandbox.process.start_and_wait(f'test -d {sandbox_dest}')
|
||||
if process.exit_code != 0:
|
||||
self.sandbox.filesystem.make_dir(sandbox_dest)
|
||||
|
||||
# Extract the archive into the destination and delete the archive
|
||||
process = self.sandbox.process.start_and_wait(
|
||||
f'sudo tar -xf {uploaded_path} -C {sandbox_dest} && sudo rm {uploaded_path}'
|
||||
)
|
||||
if process.exit_code != 0:
|
||||
raise Exception(
|
||||
f'Failed to extract {uploaded_path} to {sandbox_dest}: {process.stderr}'
|
||||
)
|
||||
|
||||
# Delete the local archive
|
||||
os.remove(tar_filename)
|
||||
|
||||
def close(self):
|
||||
self.sandbox.close()
|
||||
|
||||
def get_working_directory(self):
|
||||
return self.sandbox.cwd
|
||||
18
openhands/runtime/plugins/__init__.py
Normal file
18
openhands/runtime/plugins/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Requirements
|
||||
from .agent_skills import AgentSkillsPlugin, AgentSkillsRequirement
|
||||
from .jupyter import JupyterPlugin, JupyterRequirement
|
||||
from .requirement import Plugin, PluginRequirement
|
||||
|
||||
__all__ = [
|
||||
'Plugin',
|
||||
'PluginRequirement',
|
||||
'AgentSkillsRequirement',
|
||||
'AgentSkillsPlugin',
|
||||
'JupyterRequirement',
|
||||
'JupyterPlugin',
|
||||
]
|
||||
|
||||
ALL_PLUGINS = {
|
||||
'jupyter': JupyterPlugin,
|
||||
'agent_skills': AgentSkillsPlugin,
|
||||
}
|
||||
57
openhands/runtime/plugins/agent_skills/README.md
Normal file
57
openhands/runtime/plugins/agent_skills/README.md
Normal file
@@ -0,0 +1,57 @@
|
||||
# OpenHands Skill Sets
|
||||
|
||||
This folder implements a skill/tool set `agentskills` for OpenHands.
|
||||
|
||||
It is intended to be used by the agent **inside sandbox**.
|
||||
The skill set will be exposed as a `pip` package that can be installed as a plugin inside the sandbox.
|
||||
|
||||
The skill set can contain a bunch of wrapped tools for agent ([many examples here](https://github.com/All-Hands-AI/OpenHands/pull/1914)), for example:
|
||||
- Audio/Video to text (these are a temporary solution, and we should switch to multimodal models when they are sufficiently cheap
|
||||
- PDF to text
|
||||
- etc.
|
||||
|
||||
# Inclusion Criteria
|
||||
|
||||
We are walking a fine line here.
|
||||
We DON't want to *wrap* every possible python packages and re-teach agent their usage (e.g., LLM already knows `pandas` pretty well, so we don't really need create a skill that reads `csv` - it can just use `pandas`).
|
||||
|
||||
We ONLY want to add a new skill, when:
|
||||
- Such skill is not easily achievable for LLM to write code directly (e.g., edit code and replace certain line)
|
||||
- It involves calling an external model (e.g., you need to call a speech to text model, editor model for speculative editing)
|
||||
|
||||
# Intended functionality
|
||||
|
||||
- Tool/skill usage (through `IPythonRunAction`)
|
||||
|
||||
```python
|
||||
# In[1]
|
||||
from agentskills import open_file, edit_file
|
||||
open_file("/workspace/a.txt")
|
||||
# Out[1]
|
||||
[SWE-agent open output]
|
||||
|
||||
# In[2]
|
||||
edit_file(
|
||||
"/workspace/a.txt",
|
||||
start=1, end=3,
|
||||
content=(
|
||||
("REPLACE TEXT")
|
||||
))
|
||||
# Out[1]
|
||||
[SWE-agent edit output]
|
||||
```
|
||||
|
||||
- Tool/skill retrieval (through `IPythonRunAction`)
|
||||
|
||||
```python
|
||||
# In[1]
|
||||
from agentskills import help_me
|
||||
|
||||
help_me("I want to solve a task that involves reading a bunch of PDFs and reason about them")
|
||||
|
||||
# Out[1]
|
||||
"Here are the top skills that may be helpful to you:
|
||||
- `pdf_to_text`: [documentation about the tools]
|
||||
...
|
||||
"
|
||||
```
|
||||
15
openhands/runtime/plugins/agent_skills/__init__.py
Normal file
15
openhands/runtime/plugins/agent_skills/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.runtime.plugins.requirement import Plugin, PluginRequirement
|
||||
|
||||
from . import agentskills
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentSkillsRequirement(PluginRequirement):
|
||||
name: str = 'agent_skills'
|
||||
documentation: str = agentskills.DOCUMENTATION
|
||||
|
||||
|
||||
class AgentSkillsPlugin(Plugin):
|
||||
name: str = 'agent_skills'
|
||||
25
openhands/runtime/plugins/agent_skills/agentskills.py
Normal file
25
openhands/runtime/plugins/agent_skills/agentskills.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from inspect import signature
|
||||
|
||||
from . import file_ops, file_reader
|
||||
from .utils.dependency import import_functions
|
||||
|
||||
import_functions(
|
||||
module=file_ops, function_names=file_ops.__all__, target_globals=globals()
|
||||
)
|
||||
import_functions(
|
||||
module=file_reader, function_names=file_reader.__all__, target_globals=globals()
|
||||
)
|
||||
__all__ = file_ops.__all__ + file_reader.__all__
|
||||
|
||||
DOCUMENTATION = ''
|
||||
for func_name in __all__:
|
||||
func = globals()[func_name]
|
||||
|
||||
cur_doc = func.__doc__
|
||||
# remove indentation from docstring and extra empty lines
|
||||
cur_doc = '\n'.join(filter(None, map(lambda x: x.strip(), cur_doc.split('\n'))))
|
||||
# now add a consistent 4 indentation
|
||||
cur_doc = '\n'.join(map(lambda x: ' ' * 4 + x, cur_doc.split('\n')))
|
||||
|
||||
fn_signature = f'{func.__name__}' + str(signature(func))
|
||||
DOCUMENTATION += f'{fn_signature}:\n{cur_doc}\n\n'
|
||||
@@ -0,0 +1,7 @@
|
||||
from ..utils.dependency import import_functions
|
||||
from . import file_ops
|
||||
|
||||
import_functions(
|
||||
module=file_ops, function_names=file_ops.__all__, target_globals=globals()
|
||||
)
|
||||
__all__ = file_ops.__all__
|
||||
857
openhands/runtime/plugins/agent_skills/file_ops/file_ops.py
Normal file
857
openhands/runtime/plugins/agent_skills/file_ops/file_ops.py
Normal file
@@ -0,0 +1,857 @@
|
||||
"""file_ops.py
|
||||
|
||||
This module provides various file manipulation skills for the OpenHands agent.
|
||||
|
||||
Functions:
|
||||
- open_file(path: str, line_number: int | None = 1, context_lines: int = 100): Opens a file and optionally moves to a specific line.
|
||||
- goto_line(line_number: int): Moves the window to show the specified line number.
|
||||
- scroll_down(): Moves the window down by the number of lines specified in WINDOW.
|
||||
- scroll_up(): Moves the window up by the number of lines specified in WINDOW.
|
||||
- create_file(filename: str): Creates and opens a new file with the given name.
|
||||
- search_dir(search_term: str, dir_path: str = './'): Searches for a term in all files in the specified directory.
|
||||
- search_file(search_term: str, file_path: str | None = None): Searches for a term in the specified file or the currently open file.
|
||||
- find_file(file_name: str, dir_path: str = './'): Finds all files with the given name in the specified directory.
|
||||
- edit_file_by_replace(file_name: str, to_replace: str, new_content: str): Replaces specific content in a file with new content.
|
||||
- insert_content_at_line(file_name: str, line_number: int, content: str): Inserts given content at the specified line number in a file.
|
||||
- append_file(file_name: str, content: str): Appends the given content to the end of the specified file.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
if __package__ is None or __package__ == '':
|
||||
from aider import Linter
|
||||
else:
|
||||
from ..utils.aider import Linter
|
||||
|
||||
CURRENT_FILE: str | None = None
|
||||
CURRENT_LINE = 1
|
||||
WINDOW = 100
|
||||
|
||||
# This is also used in unit tests!
|
||||
MSG_FILE_UPDATED = '[File updated (edited at line {line_number}). Please review the changes and make sure they are correct (correct indentation, no duplicate lines, etc). Edit the file again if necessary.]'
|
||||
|
||||
# ==================================================================================================
|
||||
|
||||
|
||||
def _is_valid_filename(file_name) -> bool:
|
||||
if not file_name or not isinstance(file_name, str) or not file_name.strip():
|
||||
return False
|
||||
invalid_chars = '<>:"/\\|?*'
|
||||
if os.name == 'nt': # Windows
|
||||
invalid_chars = '<>:"/\\|?*'
|
||||
elif os.name == 'posix': # Unix-like systems
|
||||
invalid_chars = '\0'
|
||||
|
||||
for char in invalid_chars:
|
||||
if char in file_name:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _is_valid_path(path) -> bool:
|
||||
if not path or not isinstance(path, str):
|
||||
return False
|
||||
try:
|
||||
return os.path.exists(os.path.normpath(path))
|
||||
except PermissionError:
|
||||
return False
|
||||
|
||||
|
||||
def _create_paths(file_name) -> bool:
|
||||
try:
|
||||
dirname = os.path.dirname(file_name)
|
||||
if dirname:
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
return True
|
||||
except PermissionError:
|
||||
return False
|
||||
|
||||
|
||||
def _check_current_file(file_path: str | None = None) -> bool:
|
||||
global CURRENT_FILE
|
||||
if not file_path:
|
||||
file_path = CURRENT_FILE
|
||||
if not file_path or not os.path.isfile(file_path):
|
||||
raise ValueError('No file open. Use the open_file function first.')
|
||||
return True
|
||||
|
||||
|
||||
def _clamp(value, min_value, max_value):
|
||||
return max(min_value, min(value, max_value))
|
||||
|
||||
|
||||
def _lint_file(file_path: str) -> tuple[str | None, int | None]:
|
||||
"""Lint the file at the given path and return a tuple with a boolean indicating if there are errors,
|
||||
and the line number of the first error, if any.
|
||||
|
||||
Returns:
|
||||
tuple[str | None, int | None]: (lint_error, first_error_line_number)
|
||||
"""
|
||||
linter = Linter(root=os.getcwd())
|
||||
lint_error = linter.lint(file_path)
|
||||
if not lint_error:
|
||||
# Linting successful. No issues found.
|
||||
return None, None
|
||||
return 'ERRORS:\n' + lint_error.text, lint_error.lines[0]
|
||||
|
||||
|
||||
def _print_window(file_path, targeted_line, window, return_str=False):
|
||||
global CURRENT_LINE
|
||||
_check_current_file(file_path)
|
||||
with open(file_path) as file:
|
||||
content = file.read()
|
||||
|
||||
# Ensure the content ends with a newline character
|
||||
if not content.endswith('\n'):
|
||||
content += '\n'
|
||||
|
||||
lines = content.splitlines(True) # Keep all line ending characters
|
||||
total_lines = len(lines)
|
||||
|
||||
# cover edge cases
|
||||
CURRENT_LINE = _clamp(targeted_line, 1, total_lines)
|
||||
half_window = max(1, window // 2)
|
||||
|
||||
# Ensure at least one line above and below the targeted line
|
||||
start = max(1, CURRENT_LINE - half_window)
|
||||
end = min(total_lines, CURRENT_LINE + half_window)
|
||||
|
||||
# Adjust start and end to ensure at least one line above and below
|
||||
if start == 1:
|
||||
end = min(total_lines, start + window - 1)
|
||||
if end == total_lines:
|
||||
start = max(1, end - window + 1)
|
||||
|
||||
output = ''
|
||||
|
||||
# only display this when there's at least one line above
|
||||
if start > 1:
|
||||
output += f'({start - 1} more lines above)\n'
|
||||
else:
|
||||
output += '(this is the beginning of the file)\n'
|
||||
for i in range(start, end + 1):
|
||||
_new_line = f'{i}|{lines[i-1]}'
|
||||
if not _new_line.endswith('\n'):
|
||||
_new_line += '\n'
|
||||
output += _new_line
|
||||
if end < total_lines:
|
||||
output += f'({total_lines - end} more lines below)\n'
|
||||
else:
|
||||
output += '(this is the end of the file)\n'
|
||||
output = output.rstrip()
|
||||
|
||||
if return_str:
|
||||
return output
|
||||
else:
|
||||
print(output)
|
||||
|
||||
|
||||
def _cur_file_header(current_file, total_lines) -> str:
|
||||
if not current_file:
|
||||
return ''
|
||||
return f'[File: {os.path.abspath(current_file)} ({total_lines} lines total)]\n'
|
||||
|
||||
|
||||
def open_file(
|
||||
path: str, line_number: int | None = 1, context_lines: int | None = WINDOW
|
||||
) -> None:
|
||||
"""Opens the file at the given path in the editor. If line_number is provided, the window will be moved to include that line.
|
||||
It only shows the first 100 lines by default! Max `context_lines` supported is 2000, use `scroll up/down`
|
||||
to view the file if you want to see more.
|
||||
|
||||
Args:
|
||||
path: str: The path to the file to open, preferred absolute path.
|
||||
line_number: int | None = 1: The line number to move to. Defaults to 1.
|
||||
context_lines: int | None = 100: Only shows this number of lines in the context window (usually from line 1), with line_number as the center (if possible). Defaults to 100.
|
||||
"""
|
||||
global CURRENT_FILE, CURRENT_LINE, WINDOW
|
||||
|
||||
if not os.path.isfile(path):
|
||||
raise FileNotFoundError(f'File {path} not found')
|
||||
|
||||
CURRENT_FILE = os.path.abspath(path)
|
||||
with open(CURRENT_FILE) as file:
|
||||
total_lines = max(1, sum(1 for _ in file))
|
||||
|
||||
if not isinstance(line_number, int) or line_number < 1 or line_number > total_lines:
|
||||
raise ValueError(f'Line number must be between 1 and {total_lines}')
|
||||
CURRENT_LINE = line_number
|
||||
|
||||
# Override WINDOW with context_lines
|
||||
if context_lines is None or context_lines < 1:
|
||||
context_lines = WINDOW
|
||||
|
||||
output = _cur_file_header(CURRENT_FILE, total_lines)
|
||||
output += _print_window(
|
||||
CURRENT_FILE, CURRENT_LINE, _clamp(context_lines, 1, 2000), return_str=True
|
||||
)
|
||||
print(output)
|
||||
|
||||
|
||||
def goto_line(line_number: int) -> None:
|
||||
"""Moves the window to show the specified line number.
|
||||
|
||||
Args:
|
||||
line_number: int: The line number to move to.
|
||||
"""
|
||||
global CURRENT_FILE, CURRENT_LINE, WINDOW
|
||||
_check_current_file()
|
||||
|
||||
with open(str(CURRENT_FILE)) as file:
|
||||
total_lines = max(1, sum(1 for _ in file))
|
||||
if not isinstance(line_number, int) or line_number < 1 or line_number > total_lines:
|
||||
raise ValueError(f'Line number must be between 1 and {total_lines}')
|
||||
|
||||
CURRENT_LINE = _clamp(line_number, 1, total_lines)
|
||||
|
||||
output = _cur_file_header(CURRENT_FILE, total_lines)
|
||||
output += _print_window(CURRENT_FILE, CURRENT_LINE, WINDOW, return_str=True)
|
||||
print(output)
|
||||
|
||||
|
||||
def scroll_down() -> None:
|
||||
"""Moves the window down by 100 lines.
|
||||
|
||||
Args:
|
||||
None
|
||||
"""
|
||||
global CURRENT_FILE, CURRENT_LINE, WINDOW
|
||||
_check_current_file()
|
||||
|
||||
with open(str(CURRENT_FILE)) as file:
|
||||
total_lines = max(1, sum(1 for _ in file))
|
||||
CURRENT_LINE = _clamp(CURRENT_LINE + WINDOW, 1, total_lines)
|
||||
output = _cur_file_header(CURRENT_FILE, total_lines)
|
||||
output += _print_window(CURRENT_FILE, CURRENT_LINE, WINDOW, return_str=True)
|
||||
print(output)
|
||||
|
||||
|
||||
def scroll_up() -> None:
|
||||
"""Moves the window up by 100 lines.
|
||||
|
||||
Args:
|
||||
None
|
||||
"""
|
||||
global CURRENT_FILE, CURRENT_LINE, WINDOW
|
||||
_check_current_file()
|
||||
|
||||
with open(str(CURRENT_FILE)) as file:
|
||||
total_lines = max(1, sum(1 for _ in file))
|
||||
CURRENT_LINE = _clamp(CURRENT_LINE - WINDOW, 1, total_lines)
|
||||
output = _cur_file_header(CURRENT_FILE, total_lines)
|
||||
output += _print_window(CURRENT_FILE, CURRENT_LINE, WINDOW, return_str=True)
|
||||
print(output)
|
||||
|
||||
|
||||
def create_file(filename: str) -> None:
|
||||
"""Creates and opens a new file with the given name.
|
||||
|
||||
Args:
|
||||
filename: str: The name of the file to create.
|
||||
"""
|
||||
if os.path.exists(filename):
|
||||
raise FileExistsError(f"File '{filename}' already exists.")
|
||||
|
||||
with open(filename, 'w') as file:
|
||||
file.write('\n')
|
||||
|
||||
open_file(filename)
|
||||
print(f'[File {filename} created.]')
|
||||
|
||||
|
||||
LINTER_ERROR_MSG = '[Your proposed edit has introduced new syntax error(s). Please understand the errors and retry your edit command.]\n'
|
||||
|
||||
|
||||
class LineNumberError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _append_impl(lines, content):
|
||||
"""Internal method to handle appending to a file.
|
||||
|
||||
Args:
|
||||
lines: list[str]: The lines in the original file.
|
||||
content: str: The content to append to the file.
|
||||
|
||||
Returns:
|
||||
content: str: The new content of the file.
|
||||
n_added_lines: int: The number of lines added to the file.
|
||||
"""
|
||||
content_lines = content.splitlines(keepends=True)
|
||||
n_added_lines = len(content_lines)
|
||||
if lines and not (len(lines) == 1 and lines[0].strip() == ''):
|
||||
# file is not empty
|
||||
if not lines[-1].endswith('\n'):
|
||||
lines[-1] += '\n'
|
||||
new_lines = lines + content_lines
|
||||
content = ''.join(new_lines)
|
||||
else:
|
||||
# file is empty
|
||||
content = ''.join(content_lines)
|
||||
|
||||
return content, n_added_lines
|
||||
|
||||
|
||||
def _insert_impl(lines, start, content):
|
||||
"""Internal method to handle inserting to a file.
|
||||
|
||||
Args:
|
||||
lines: list[str]: The lines in the original file.
|
||||
start: int: The start line number for inserting.
|
||||
content: str: The content to insert to the file.
|
||||
|
||||
Returns:
|
||||
content: str: The new content of the file.
|
||||
n_added_lines: int: The number of lines added to the file.
|
||||
|
||||
Raises:
|
||||
LineNumberError: If the start line number is invalid.
|
||||
"""
|
||||
inserted_lines = [content + '\n' if not content.endswith('\n') else content]
|
||||
if len(lines) == 0:
|
||||
new_lines = inserted_lines
|
||||
elif start is not None:
|
||||
if len(lines) == 1 and lines[0].strip() == '':
|
||||
# if the file with only 1 line and that line is empty
|
||||
lines = []
|
||||
|
||||
if len(lines) == 0:
|
||||
new_lines = inserted_lines
|
||||
else:
|
||||
new_lines = lines[: start - 1] + inserted_lines + lines[start - 1 :]
|
||||
else:
|
||||
raise LineNumberError(
|
||||
f'Invalid line number: {start}. Line numbers must be between 1 and {len(lines)} (inclusive).'
|
||||
)
|
||||
|
||||
content = ''.join(new_lines)
|
||||
n_added_lines = len(inserted_lines)
|
||||
return content, n_added_lines
|
||||
|
||||
|
||||
def _edit_impl(lines, start, end, content):
|
||||
"""Internal method to handle editing a file.
|
||||
|
||||
REQUIRES (should be checked by caller):
|
||||
start <= end
|
||||
start and end are between 1 and len(lines) (inclusive)
|
||||
content ends with a newline
|
||||
|
||||
Args:
|
||||
lines: list[str]: The lines in the original file.
|
||||
start: int: The start line number for editing.
|
||||
end: int: The end line number for editing.
|
||||
content: str: The content to replace the lines with.
|
||||
|
||||
Returns:
|
||||
content: str: The new content of the file.
|
||||
n_added_lines: int: The number of lines added to the file.
|
||||
"""
|
||||
# Handle cases where start or end are None
|
||||
if start is None:
|
||||
start = 1 # Default to the beginning
|
||||
if end is None:
|
||||
end = len(lines) # Default to the end
|
||||
# Check arguments
|
||||
if not (1 <= start <= len(lines)):
|
||||
raise LineNumberError(
|
||||
f'Invalid start line number: {start}. Line numbers must be between 1 and {len(lines)} (inclusive).'
|
||||
)
|
||||
if not (1 <= end <= len(lines)):
|
||||
raise LineNumberError(
|
||||
f'Invalid end line number: {end}. Line numbers must be between 1 and {len(lines)} (inclusive).'
|
||||
)
|
||||
if start > end:
|
||||
raise LineNumberError(
|
||||
f'Invalid line range: {start}-{end}. Start must be less than or equal to end.'
|
||||
)
|
||||
|
||||
if not content.endswith('\n'):
|
||||
content += '\n'
|
||||
content_lines = content.splitlines(True)
|
||||
n_added_lines = len(content_lines)
|
||||
new_lines = lines[: start - 1] + content_lines + lines[end:]
|
||||
content = ''.join(new_lines)
|
||||
return content, n_added_lines
|
||||
|
||||
|
||||
def _edit_file_impl(
|
||||
file_name: str,
|
||||
start: int | None = None,
|
||||
end: int | None = None,
|
||||
content: str = '',
|
||||
is_insert: bool = False,
|
||||
is_append: bool = False,
|
||||
) -> str:
|
||||
"""Internal method to handle common logic for edit_/append_file methods.
|
||||
|
||||
Args:
|
||||
file_name: str: The name of the file to edit or append to.
|
||||
start: int | None = None: The start line number for editing. Ignored if is_append is True.
|
||||
end: int | None = None: The end line number for editing. Ignored if is_append is True.
|
||||
content: str: The content to replace the lines with or to append.
|
||||
is_insert: bool = False: Whether to insert content at the given line number instead of editing.
|
||||
is_append: bool = False: Whether to append content to the file instead of editing.
|
||||
"""
|
||||
ret_str = ''
|
||||
global CURRENT_FILE, CURRENT_LINE, WINDOW
|
||||
|
||||
ERROR_MSG = f'[Error editing file {file_name}. Please confirm the file is correct.]'
|
||||
ERROR_MSG_SUFFIX = (
|
||||
'Your changes have NOT been applied. Please fix your edit command and try again.\n'
|
||||
'You either need to 1) Open the correct file and try again or 2) Specify the correct line number arguments.\n'
|
||||
'DO NOT re-run the same failed edit command. Running it again will lead to the same error.'
|
||||
)
|
||||
|
||||
if not _is_valid_filename(file_name):
|
||||
raise FileNotFoundError('Invalid file name.')
|
||||
|
||||
if not _is_valid_path(file_name):
|
||||
raise FileNotFoundError('Invalid path or file name.')
|
||||
|
||||
if not _create_paths(file_name):
|
||||
raise PermissionError('Could not access or create directories.')
|
||||
|
||||
if not os.path.isfile(file_name):
|
||||
raise FileNotFoundError(f'File {file_name} not found.')
|
||||
|
||||
if is_insert and is_append:
|
||||
raise ValueError('Cannot insert and append at the same time.')
|
||||
|
||||
# Use a temporary file to write changes
|
||||
content = str(content or '')
|
||||
temp_file_path = ''
|
||||
src_abs_path = os.path.abspath(file_name)
|
||||
first_error_line = None
|
||||
|
||||
try:
|
||||
n_added_lines = None
|
||||
|
||||
# lint the original file
|
||||
enable_auto_lint = os.getenv('ENABLE_AUTO_LINT', 'false').lower() == 'true'
|
||||
if enable_auto_lint:
|
||||
original_lint_error, _ = _lint_file(file_name)
|
||||
|
||||
# Create a temporary file
|
||||
with tempfile.NamedTemporaryFile('w', delete=False) as temp_file:
|
||||
temp_file_path = temp_file.name
|
||||
|
||||
# Read the original file and check if empty and for a trailing newline
|
||||
with open(file_name) as original_file:
|
||||
lines = original_file.readlines()
|
||||
|
||||
if is_append:
|
||||
content, n_added_lines = _append_impl(lines, content)
|
||||
elif is_insert:
|
||||
try:
|
||||
content, n_added_lines = _insert_impl(lines, start, content)
|
||||
except LineNumberError as e:
|
||||
ret_str += (f'{ERROR_MSG}\n' f'{e}\n' f'{ERROR_MSG_SUFFIX}') + '\n'
|
||||
return ret_str
|
||||
else:
|
||||
try:
|
||||
content, n_added_lines = _edit_impl(lines, start, end, content)
|
||||
except LineNumberError as e:
|
||||
ret_str += (f'{ERROR_MSG}\n' f'{e}\n' f'{ERROR_MSG_SUFFIX}') + '\n'
|
||||
return ret_str
|
||||
|
||||
if not content.endswith('\n'):
|
||||
content += '\n'
|
||||
|
||||
# Write the new content to the temporary file
|
||||
temp_file.write(content)
|
||||
|
||||
# Replace the original file with the temporary file atomically
|
||||
shutil.move(temp_file_path, src_abs_path)
|
||||
|
||||
# Handle linting
|
||||
# NOTE: we need to get env var inside this function
|
||||
# because the env var will be set AFTER the agentskills is imported
|
||||
if enable_auto_lint:
|
||||
# BACKUP the original file
|
||||
original_file_backup_path = os.path.join(
|
||||
os.path.dirname(file_name),
|
||||
f'.backup.{os.path.basename(file_name)}',
|
||||
)
|
||||
with open(original_file_backup_path, 'w') as f:
|
||||
f.writelines(lines)
|
||||
|
||||
lint_error, first_error_line = _lint_file(file_name)
|
||||
|
||||
# Select the errors caused by the modification
|
||||
def extract_last_part(line):
|
||||
parts = line.split(':')
|
||||
if len(parts) > 1:
|
||||
return parts[-1].strip()
|
||||
return line.strip()
|
||||
|
||||
def subtract_strings(str1, str2) -> str:
|
||||
lines1 = str1.splitlines()
|
||||
lines2 = str2.splitlines()
|
||||
|
||||
last_parts1 = [extract_last_part(line) for line in lines1]
|
||||
|
||||
remaining_lines = [
|
||||
line
|
||||
for line in lines2
|
||||
if extract_last_part(line) not in last_parts1
|
||||
]
|
||||
|
||||
result = '\n'.join(remaining_lines)
|
||||
return result
|
||||
|
||||
if original_lint_error and lint_error:
|
||||
lint_error = subtract_strings(original_lint_error, lint_error)
|
||||
if lint_error == '':
|
||||
lint_error = None
|
||||
first_error_line = None
|
||||
|
||||
if lint_error is not None:
|
||||
if first_error_line is not None:
|
||||
show_line = int(first_error_line)
|
||||
elif is_append:
|
||||
# original end-of-file
|
||||
show_line = len(lines)
|
||||
# insert OR edit WILL provide meaningful line numbers
|
||||
elif start is not None and end is not None:
|
||||
show_line = int((start + end) / 2)
|
||||
else:
|
||||
raise ValueError('Invalid state. This should never happen.')
|
||||
|
||||
ret_str += LINTER_ERROR_MSG
|
||||
ret_str += lint_error + '\n'
|
||||
|
||||
editor_lines = n_added_lines + 20
|
||||
|
||||
ret_str += '[This is how your edit would have looked if applied]\n'
|
||||
ret_str += '-------------------------------------------------\n'
|
||||
ret_str += (
|
||||
_print_window(file_name, show_line, editor_lines, return_str=True)
|
||||
+ '\n'
|
||||
)
|
||||
ret_str += '-------------------------------------------------\n\n'
|
||||
|
||||
ret_str += '[This is the original code before your edit]\n'
|
||||
ret_str += '-------------------------------------------------\n'
|
||||
ret_str += (
|
||||
_print_window(
|
||||
original_file_backup_path,
|
||||
show_line,
|
||||
editor_lines,
|
||||
return_str=True,
|
||||
)
|
||||
+ '\n'
|
||||
)
|
||||
ret_str += '-------------------------------------------------\n'
|
||||
|
||||
ret_str += (
|
||||
'Your changes have NOT been applied. Please fix your edit command and try again.\n'
|
||||
'You either need to 1) Specify the correct start/end line arguments or 2) Correct your edit code.\n'
|
||||
'DO NOT re-run the same failed edit command. Running it again will lead to the same error.'
|
||||
)
|
||||
|
||||
# recover the original file
|
||||
with open(original_file_backup_path) as fin, open(
|
||||
file_name, 'w'
|
||||
) as fout:
|
||||
fout.write(fin.read())
|
||||
os.remove(original_file_backup_path)
|
||||
return ret_str
|
||||
|
||||
except FileNotFoundError as e:
|
||||
ret_str += f'File not found: {e}\n'
|
||||
except IOError as e:
|
||||
ret_str += f'An error occurred while handling the file: {e}\n'
|
||||
except ValueError as e:
|
||||
ret_str += f'Invalid input: {e}\n'
|
||||
except Exception as e:
|
||||
# Clean up the temporary file if an error occurs
|
||||
if temp_file_path and os.path.exists(temp_file_path):
|
||||
os.remove(temp_file_path)
|
||||
print(f'An unexpected error occurred: {e}')
|
||||
raise e
|
||||
|
||||
# Update the file information and print the updated content
|
||||
with open(file_name, 'r', encoding='utf-8') as file:
|
||||
n_total_lines = max(1, len(file.readlines()))
|
||||
if first_error_line is not None and int(first_error_line) > 0:
|
||||
CURRENT_LINE = first_error_line
|
||||
else:
|
||||
if is_append:
|
||||
CURRENT_LINE = max(1, len(lines)) # end of original file
|
||||
else:
|
||||
CURRENT_LINE = start or n_total_lines or 1
|
||||
ret_str += f'[File: {os.path.abspath(file_name)} ({n_total_lines} lines total after edit)]\n'
|
||||
CURRENT_FILE = file_name
|
||||
ret_str += _print_window(CURRENT_FILE, CURRENT_LINE, WINDOW, return_str=True) + '\n'
|
||||
ret_str += MSG_FILE_UPDATED.format(line_number=CURRENT_LINE)
|
||||
return ret_str
|
||||
|
||||
|
||||
def edit_file_by_replace(file_name: str, to_replace: str, new_content: str) -> None:
|
||||
"""Edit a file. This will search for `to_replace` in the given file and replace it with `new_content`.
|
||||
|
||||
Every *to_replace* must *EXACTLY MATCH* the existing source code, character for character, including all comments, docstrings, etc.
|
||||
|
||||
Include enough lines to make code in `to_replace` unique. `to_replace` should NOT be empty.
|
||||
|
||||
For example, given a file "/workspace/example.txt" with the following content:
|
||||
```
|
||||
line 1
|
||||
line 2
|
||||
line 2
|
||||
line 3
|
||||
```
|
||||
|
||||
EDITING: If you want to replace the second occurrence of "line 2", you can make `to_replace` unique:
|
||||
|
||||
edit_file_by_replace(
|
||||
'/workspace/example.txt',
|
||||
to_replace='line 2\nline 3',
|
||||
new_content='new line\nline 3',
|
||||
)
|
||||
|
||||
This will replace only the second "line 2" with "new line". The first "line 2" will remain unchanged.
|
||||
|
||||
The resulting file will be:
|
||||
```
|
||||
line 1
|
||||
line 2
|
||||
new line
|
||||
line 3
|
||||
```
|
||||
|
||||
REMOVAL: If you want to remove "line 2" and "line 3", you can set `new_content` to an empty string:
|
||||
|
||||
edit_file_by_replace(
|
||||
'/workspace/example.txt',
|
||||
to_replace='line 2\nline 3',
|
||||
new_content='',
|
||||
)
|
||||
|
||||
Args:
|
||||
file_name: str: The name of the file to edit.
|
||||
to_replace: str: The content to search for and replace.
|
||||
new_content: str: The new content to replace the old content with.
|
||||
"""
|
||||
# FIXME: support replacing *all* occurrences
|
||||
if to_replace.strip() == '':
|
||||
raise ValueError('`to_replace` must not be empty.')
|
||||
|
||||
if to_replace == new_content:
|
||||
raise ValueError('`to_replace` and `new_content` must be different.')
|
||||
|
||||
# search for `to_replace` in the file
|
||||
# if found, replace it with `new_content`
|
||||
# if not found, perform a fuzzy search to find the closest match and replace it with `new_content`
|
||||
with open(file_name, 'r') as file:
|
||||
file_content = file.read()
|
||||
|
||||
if file_content.count(to_replace) > 1:
|
||||
raise ValueError(
|
||||
'`to_replace` appears more than once, please include enough lines to make code in `to_replace` unique.'
|
||||
)
|
||||
|
||||
start = file_content.find(to_replace)
|
||||
if start != -1:
|
||||
# Convert start from index to line number
|
||||
start_line_number = file_content[:start].count('\n') + 1
|
||||
end_line_number = start_line_number + len(to_replace.splitlines()) - 1
|
||||
else:
|
||||
|
||||
def _fuzzy_transform(s: str) -> str:
|
||||
# remove all space except newline
|
||||
return re.sub(r'[^\S\n]+', '', s)
|
||||
|
||||
# perform a fuzzy search (remove all spaces except newlines)
|
||||
to_replace_fuzzy = _fuzzy_transform(to_replace)
|
||||
file_content_fuzzy = _fuzzy_transform(file_content)
|
||||
# find the closest match
|
||||
start = file_content_fuzzy.find(to_replace_fuzzy)
|
||||
if start == -1:
|
||||
print(
|
||||
f'[No exact match found in {file_name} for\n```\n{to_replace}\n```\n]'
|
||||
)
|
||||
return
|
||||
# Convert start from index to line number for fuzzy match
|
||||
start_line_number = file_content_fuzzy[:start].count('\n') + 1
|
||||
end_line_number = start_line_number + len(to_replace.splitlines()) - 1
|
||||
|
||||
ret_str = _edit_file_impl(
|
||||
file_name,
|
||||
start=start_line_number,
|
||||
end=end_line_number,
|
||||
content=new_content,
|
||||
is_insert=False,
|
||||
)
|
||||
# lint_error = bool(LINTER_ERROR_MSG in ret_str)
|
||||
# TODO: automatically tries to fix linter error (maybe involve some static analysis tools on the location near the edit to figure out indentation)
|
||||
print(ret_str)
|
||||
|
||||
|
||||
def insert_content_at_line(file_name: str, line_number: int, content: str) -> None:
|
||||
"""Insert content at the given line number in a file.
|
||||
This will NOT modify the content of the lines before OR after the given line number.
|
||||
|
||||
For example, if the file has the following content:
|
||||
```
|
||||
line 1
|
||||
line 2
|
||||
line 3
|
||||
```
|
||||
and you call `insert_content_at_line('file.txt', 2, 'new line')`, the file will be updated to:
|
||||
```
|
||||
line 1
|
||||
new line
|
||||
line 2
|
||||
line 3
|
||||
```
|
||||
|
||||
Args:
|
||||
file_name: str: The name of the file to edit.
|
||||
line_number: int: The line number (starting from 1) to insert the content after.
|
||||
content: str: The content to insert.
|
||||
"""
|
||||
ret_str = _edit_file_impl(
|
||||
file_name,
|
||||
start=line_number,
|
||||
end=line_number,
|
||||
content=content,
|
||||
is_insert=True,
|
||||
is_append=False,
|
||||
)
|
||||
print(ret_str)
|
||||
|
||||
|
||||
def append_file(file_name: str, content: str) -> None:
|
||||
"""Append content to the given file.
|
||||
It appends text `content` to the end of the specified file.
|
||||
|
||||
Args:
|
||||
file_name: str: The name of the file to edit.
|
||||
line_number: int: The line number (starting from 1) to insert the content after.
|
||||
content: str: The content to insert.
|
||||
"""
|
||||
ret_str = _edit_file_impl(
|
||||
file_name,
|
||||
start=None,
|
||||
end=None,
|
||||
content=content,
|
||||
is_insert=False,
|
||||
is_append=True,
|
||||
)
|
||||
print(ret_str)
|
||||
|
||||
|
||||
def search_dir(search_term: str, dir_path: str = './') -> None:
|
||||
"""Searches for search_term in all files in dir. If dir is not provided, searches in the current directory.
|
||||
|
||||
Args:
|
||||
search_term: str: The term to search for.
|
||||
dir_path: str: The path to the directory to search.
|
||||
"""
|
||||
if not os.path.isdir(dir_path):
|
||||
raise FileNotFoundError(f'Directory {dir_path} not found')
|
||||
matches = []
|
||||
for root, _, files in os.walk(dir_path):
|
||||
for file in files:
|
||||
if file.startswith('.'):
|
||||
continue
|
||||
file_path = os.path.join(root, file)
|
||||
with open(file_path, 'r', errors='ignore') as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
if search_term in line:
|
||||
matches.append((file_path, line_num, line.strip()))
|
||||
|
||||
if not matches:
|
||||
print(f'No matches found for "{search_term}" in {dir_path}')
|
||||
return
|
||||
|
||||
num_matches = len(matches)
|
||||
num_files = len(set(match[0] for match in matches))
|
||||
|
||||
if num_files > 100:
|
||||
print(
|
||||
f'More than {num_files} files matched for "{search_term}" in {dir_path}. Please narrow your search.'
|
||||
)
|
||||
return
|
||||
|
||||
print(f'[Found {num_matches} matches for "{search_term}" in {dir_path}]')
|
||||
for file_path, line_num, line in matches:
|
||||
print(f'{file_path} (Line {line_num}): {line}')
|
||||
print(f'[End of matches for "{search_term}" in {dir_path}]')
|
||||
|
||||
|
||||
def search_file(search_term: str, file_path: str | None = None) -> None:
|
||||
"""Searches for search_term in file. If file is not provided, searches in the current open file.
|
||||
|
||||
Args:
|
||||
search_term: str: The term to search for.
|
||||
file_path: str | None: The path to the file to search.
|
||||
"""
|
||||
global CURRENT_FILE
|
||||
if file_path is None:
|
||||
file_path = CURRENT_FILE
|
||||
if file_path is None:
|
||||
raise FileNotFoundError(
|
||||
'No file specified or open. Use the open_file function first.'
|
||||
)
|
||||
if not os.path.isfile(file_path):
|
||||
raise FileNotFoundError(f'File {file_path} not found')
|
||||
|
||||
matches = []
|
||||
with open(file_path) as file:
|
||||
for i, line in enumerate(file, 1):
|
||||
if search_term in line:
|
||||
matches.append((i, line.strip()))
|
||||
|
||||
if matches:
|
||||
print(f'[Found {len(matches)} matches for "{search_term}" in {file_path}]')
|
||||
for match in matches:
|
||||
print(f'Line {match[0]}: {match[1]}')
|
||||
print(f'[End of matches for "{search_term}" in {file_path}]')
|
||||
else:
|
||||
print(f'[No matches found for "{search_term}" in {file_path}]')
|
||||
|
||||
|
||||
def find_file(file_name: str, dir_path: str = './') -> None:
|
||||
"""Finds all files with the given name in the specified directory.
|
||||
|
||||
Args:
|
||||
file_name: str: The name of the file to find.
|
||||
dir_path: str: The path to the directory to search.
|
||||
"""
|
||||
if not os.path.isdir(dir_path):
|
||||
raise FileNotFoundError(f'Directory {dir_path} not found')
|
||||
|
||||
matches = []
|
||||
for root, _, files in os.walk(dir_path):
|
||||
for file in files:
|
||||
if file_name in file:
|
||||
matches.append(os.path.join(root, file))
|
||||
|
||||
if matches:
|
||||
print(f'[Found {len(matches)} matches for "{file_name}" in {dir_path}]')
|
||||
for match in matches:
|
||||
print(f'{match}')
|
||||
print(f'[End of matches for "{file_name}" in {dir_path}]')
|
||||
else:
|
||||
print(f'[No matches found for "{file_name}" in {dir_path}]')
|
||||
|
||||
|
||||
__all__ = [
|
||||
'open_file',
|
||||
'goto_line',
|
||||
'scroll_down',
|
||||
'scroll_up',
|
||||
'create_file',
|
||||
'edit_file_by_replace',
|
||||
'insert_content_at_line',
|
||||
'append_file',
|
||||
'search_dir',
|
||||
'search_file',
|
||||
'find_file',
|
||||
]
|
||||
@@ -0,0 +1,7 @@
|
||||
from ..utils.dependency import import_functions
|
||||
from . import file_readers
|
||||
|
||||
import_functions(
|
||||
module=file_readers, function_names=file_readers.__all__, target_globals=globals()
|
||||
)
|
||||
__all__ = file_readers.__all__
|
||||
@@ -0,0 +1,244 @@
|
||||
"""File reader skills for the OpenHands agent.
|
||||
|
||||
This module provides various functions to parse and extract content from different file types,
|
||||
including PDF, DOCX, LaTeX, audio, image, video, and PowerPoint files. It utilizes different
|
||||
libraries and APIs to process these files and output their content or descriptions.
|
||||
|
||||
Functions:
|
||||
parse_pdf(file_path: str) -> None: Parse and print content of a PDF file.
|
||||
parse_docx(file_path: str) -> None: Parse and print content of a DOCX file.
|
||||
parse_latex(file_path: str) -> None: Parse and print content of a LaTeX file.
|
||||
parse_audio(file_path: str, model: str = 'whisper-1') -> None: Transcribe and print content of an audio file.
|
||||
parse_image(file_path: str, task: str = 'Describe this image as detail as possible.') -> None: Analyze and print description of an image file.
|
||||
parse_video(file_path: str, task: str = 'Describe this image as detail as possible.', frame_interval: int = 30) -> None: Analyze and print description of video frames.
|
||||
parse_pptx(file_path: str) -> None: Parse and print content of a PowerPoint file.
|
||||
|
||||
Note:
|
||||
Some functions (parse_audio, parse_video, parse_image) require OpenAI API credentials
|
||||
and are only available if the necessary environment variables are set.
|
||||
"""
|
||||
|
||||
import base64
|
||||
|
||||
import docx
|
||||
import PyPDF2
|
||||
from pptx import Presentation
|
||||
from pylatexenc.latex2text import LatexNodes2Text
|
||||
|
||||
from ..utils.config import (
|
||||
_get_max_token,
|
||||
_get_openai_api_key,
|
||||
_get_openai_base_url,
|
||||
_get_openai_client,
|
||||
_get_openai_model,
|
||||
)
|
||||
|
||||
|
||||
def parse_pdf(file_path: str) -> None:
|
||||
"""Parses the content of a PDF file and prints it.
|
||||
|
||||
Args:
|
||||
file_path: str: The path to the file to open.
|
||||
"""
|
||||
print(f'[Reading PDF file from {file_path}]')
|
||||
content = PyPDF2.PdfReader(file_path)
|
||||
text = ''
|
||||
for page_idx in range(len(content.pages)):
|
||||
text += (
|
||||
f'@@ Page {page_idx + 1} @@\n'
|
||||
+ content.pages[page_idx].extract_text()
|
||||
+ '\n\n'
|
||||
)
|
||||
print(text.strip())
|
||||
|
||||
|
||||
def parse_docx(file_path: str) -> None:
|
||||
"""Parses the content of a DOCX file and prints it.
|
||||
|
||||
Args:
|
||||
file_path: str: The path to the file to open.
|
||||
"""
|
||||
print(f'[Reading DOCX file from {file_path}]')
|
||||
content = docx.Document(file_path)
|
||||
text = ''
|
||||
for i, para in enumerate(content.paragraphs):
|
||||
text += f'@@ Page {i + 1} @@\n' + para.text + '\n\n'
|
||||
print(text)
|
||||
|
||||
|
||||
def parse_latex(file_path: str) -> None:
|
||||
"""Parses the content of a LaTex file and prints it.
|
||||
|
||||
Args:
|
||||
file_path: str: The path to the file to open.
|
||||
"""
|
||||
print(f'[Reading LaTex file from {file_path}]')
|
||||
with open(file_path) as f:
|
||||
data = f.read()
|
||||
text = LatexNodes2Text().latex_to_text(data)
|
||||
print(text.strip())
|
||||
|
||||
|
||||
def _base64_img(file_path: str) -> str:
|
||||
with open(file_path, 'rb') as image_file:
|
||||
encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
|
||||
return encoded_image
|
||||
|
||||
|
||||
def _base64_video(file_path: str, frame_interval: int = 10) -> list[str]:
|
||||
import cv2
|
||||
|
||||
video = cv2.VideoCapture(file_path)
|
||||
base64_frames = []
|
||||
frame_count = 0
|
||||
while video.isOpened():
|
||||
success, frame = video.read()
|
||||
if not success:
|
||||
break
|
||||
if frame_count % frame_interval == 0:
|
||||
_, buffer = cv2.imencode('.jpg', frame)
|
||||
base64_frames.append(base64.b64encode(buffer).decode('utf-8'))
|
||||
frame_count += 1
|
||||
video.release()
|
||||
return base64_frames
|
||||
|
||||
|
||||
def _prepare_image_messages(task: str, base64_image: str):
|
||||
return [
|
||||
{
|
||||
'role': 'user',
|
||||
'content': [
|
||||
{'type': 'text', 'text': task},
|
||||
{
|
||||
'type': 'image_url',
|
||||
'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def parse_audio(file_path: str, model: str = 'whisper-1') -> None:
|
||||
"""Parses the content of an audio file and prints it.
|
||||
|
||||
Args:
|
||||
file_path: str: The path to the audio file to transcribe.
|
||||
model: str: The audio model to use for transcription. Defaults to 'whisper-1'.
|
||||
"""
|
||||
print(f'[Transcribing audio file from {file_path}]')
|
||||
try:
|
||||
# TODO: record the COST of the API call
|
||||
with open(file_path, 'rb') as audio_file:
|
||||
transcript = _get_openai_client().audio.translations.create(
|
||||
model=model, file=audio_file
|
||||
)
|
||||
print(transcript.text)
|
||||
|
||||
except Exception as e:
|
||||
print(f'Error transcribing audio file: {e}')
|
||||
|
||||
|
||||
def parse_image(
|
||||
file_path: str, task: str = 'Describe this image as detail as possible.'
|
||||
) -> None:
|
||||
"""Parses the content of an image file and prints the description.
|
||||
|
||||
Args:
|
||||
file_path: str: The path to the file to open.
|
||||
task: str: The task description for the API call. Defaults to 'Describe this image as detail as possible.'.
|
||||
"""
|
||||
print(f'[Reading image file from {file_path}]')
|
||||
# TODO: record the COST of the API call
|
||||
try:
|
||||
base64_image = _base64_img(file_path)
|
||||
response = _get_openai_client().chat.completions.create(
|
||||
model=_get_openai_model(),
|
||||
messages=_prepare_image_messages(task, base64_image),
|
||||
max_tokens=_get_max_token(),
|
||||
)
|
||||
content = response.choices[0].message.content
|
||||
print(content)
|
||||
|
||||
except Exception as error:
|
||||
print(f'Error with the request: {error}')
|
||||
|
||||
|
||||
def parse_video(
|
||||
file_path: str,
|
||||
task: str = 'Describe this image as detail as possible.',
|
||||
frame_interval: int = 30,
|
||||
) -> None:
|
||||
"""Parses the content of an image file and prints the description.
|
||||
|
||||
Args:
|
||||
file_path: str: The path to the video file to open.
|
||||
task: str: The task description for the API call. Defaults to 'Describe this image as detail as possible.'.
|
||||
frame_interval: int: The interval between frames to analyze. Defaults to 30.
|
||||
|
||||
"""
|
||||
print(
|
||||
f'[Processing video file from {file_path} with frame interval {frame_interval}]'
|
||||
)
|
||||
|
||||
task = task or 'This is one frame from a video, please summarize this frame.'
|
||||
base64_frames = _base64_video(file_path)
|
||||
selected_frames = base64_frames[::frame_interval]
|
||||
|
||||
if len(selected_frames) > 30:
|
||||
new_interval = len(base64_frames) // 30
|
||||
selected_frames = base64_frames[::new_interval]
|
||||
|
||||
print(f'Totally {len(selected_frames)} would be analyze...\n')
|
||||
|
||||
idx = 0
|
||||
for base64_frame in selected_frames:
|
||||
idx += 1
|
||||
print(f'Process the {file_path}, current No. {idx * frame_interval} frame...')
|
||||
# TODO: record the COST of the API call
|
||||
try:
|
||||
response = _get_openai_client().chat.completions.create(
|
||||
model=_get_openai_model(),
|
||||
messages=_prepare_image_messages(task, base64_frame),
|
||||
max_tokens=_get_max_token(),
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
current_frame_content = f"Frame {idx}'s content: {content}\n"
|
||||
print(current_frame_content)
|
||||
|
||||
except Exception as error:
|
||||
print(f'Error with the request: {error}')
|
||||
|
||||
|
||||
def parse_pptx(file_path: str) -> None:
|
||||
"""Parses the content of a pptx file and prints it.
|
||||
|
||||
Args:
|
||||
file_path: str: The path to the file to open.
|
||||
"""
|
||||
print(f'[Reading PowerPoint file from {file_path}]')
|
||||
try:
|
||||
pres = Presentation(str(file_path))
|
||||
text = []
|
||||
for slide_idx, slide in enumerate(pres.slides):
|
||||
text.append(f'@@ Slide {slide_idx + 1} @@')
|
||||
for shape in slide.shapes:
|
||||
if hasattr(shape, 'text'):
|
||||
text.append(shape.text)
|
||||
print('\n'.join(text))
|
||||
|
||||
except Exception as e:
|
||||
print(f'Error reading PowerPoint file: {e}')
|
||||
|
||||
|
||||
__all__ = [
|
||||
'parse_pdf',
|
||||
'parse_docx',
|
||||
'parse_latex',
|
||||
'parse_pptx',
|
||||
]
|
||||
|
||||
# This is called from OpenHands's side
|
||||
# If SANDBOX_ENV_OPENAI_API_KEY is set, we will be able to use these tools in the sandbox environment
|
||||
if _get_openai_api_key() and _get_openai_base_url():
|
||||
__all__ += ['parse_audio', 'parse_video', 'parse_image']
|
||||
202
openhands/runtime/plugins/agent_skills/utils/aider/LICENSE.txt
Normal file
202
openhands/runtime/plugins/agent_skills/utils/aider/LICENSE.txt
Normal file
@@ -0,0 +1,202 @@
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
@@ -0,0 +1,8 @@
|
||||
# Aider is AI pair programming in your terminal
|
||||
|
||||
Aider lets you pair program with LLMs,
|
||||
to edit code in your local git repository.
|
||||
|
||||
Please see the [original repository](https://github.com/paul-gauthier/aider) for more information.
|
||||
|
||||
OpenHands has adapted and integrated its linter module ([original code](https://github.com/paul-gauthier/aider/blob/main/aider/linter.py)).
|
||||
@@ -0,0 +1,6 @@
|
||||
if __package__ is None or __package__ == '':
|
||||
from linter import Linter, LintResult
|
||||
else:
|
||||
from .linter import Linter, LintResult
|
||||
|
||||
__all__ = ['Linter', 'LintResult']
|
||||
223
openhands/runtime/plugins/agent_skills/utils/aider/linter.py
Normal file
223
openhands/runtime/plugins/agent_skills/utils/aider/linter.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from grep_ast import TreeContext, filename_to_lang
|
||||
from tree_sitter_languages import get_parser # noqa: E402
|
||||
|
||||
# tree_sitter is throwing a FutureWarning
|
||||
warnings.simplefilter('ignore', category=FutureWarning)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LintResult:
|
||||
text: str
|
||||
lines: list
|
||||
|
||||
|
||||
class Linter:
|
||||
def __init__(self, encoding='utf-8', root=None):
|
||||
self.encoding = encoding
|
||||
self.root = root
|
||||
|
||||
self.languages = dict(
|
||||
python=self.py_lint,
|
||||
)
|
||||
self.all_lint_cmd = None
|
||||
|
||||
def set_linter(self, lang, cmd):
|
||||
if lang:
|
||||
self.languages[lang] = cmd
|
||||
return
|
||||
|
||||
self.all_lint_cmd = cmd
|
||||
|
||||
def get_rel_fname(self, fname):
|
||||
if self.root:
|
||||
return os.path.relpath(fname, self.root)
|
||||
else:
|
||||
return fname
|
||||
|
||||
def run_cmd(self, cmd, rel_fname, code):
|
||||
cmd += ' ' + rel_fname
|
||||
cmd = cmd.split()
|
||||
|
||||
process = subprocess.Popen(
|
||||
cmd, cwd=self.root, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
|
||||
)
|
||||
stdout, _ = process.communicate()
|
||||
errors = stdout.decode().strip()
|
||||
self.returncode = process.returncode
|
||||
if self.returncode == 0:
|
||||
return # zero exit status
|
||||
|
||||
cmd = ' '.join(cmd)
|
||||
res = ''
|
||||
res += errors
|
||||
line_num = extract_error_line_from(res)
|
||||
return LintResult(text=res, lines=[line_num])
|
||||
|
||||
def get_abs_fname(self, fname):
|
||||
if os.path.isabs(fname):
|
||||
return fname
|
||||
elif os.path.isfile(fname):
|
||||
rel_fname = self.get_rel_fname(fname)
|
||||
return os.path.abspath(rel_fname)
|
||||
else: # if a temp file
|
||||
return self.get_rel_fname(fname)
|
||||
|
||||
def lint(self, fname, cmd=None) -> LintResult | None:
|
||||
code = Path(fname).read_text(self.encoding)
|
||||
absolute_fname = self.get_abs_fname(fname)
|
||||
if cmd:
|
||||
cmd = cmd.strip()
|
||||
if not cmd:
|
||||
lang = filename_to_lang(fname)
|
||||
if not lang:
|
||||
return None
|
||||
if self.all_lint_cmd:
|
||||
cmd = self.all_lint_cmd
|
||||
else:
|
||||
cmd = self.languages.get(lang)
|
||||
if callable(cmd):
|
||||
linkres = cmd(fname, absolute_fname, code)
|
||||
elif cmd:
|
||||
linkres = self.run_cmd(cmd, absolute_fname, code)
|
||||
else:
|
||||
linkres = basic_lint(absolute_fname, code)
|
||||
return linkres
|
||||
|
||||
def flake_lint(self, rel_fname, code):
|
||||
fatal = 'F821,F822,F831,E112,E113,E999,E902'
|
||||
flake8 = f'flake8 --select={fatal} --isolated'
|
||||
|
||||
try:
|
||||
flake_res = self.run_cmd(flake8, rel_fname, code)
|
||||
except FileNotFoundError:
|
||||
flake_res = None
|
||||
return flake_res
|
||||
|
||||
def py_lint(self, fname, rel_fname, code):
|
||||
error = self.flake_lint(rel_fname, code)
|
||||
if not error:
|
||||
error = lint_python_compile(fname, code)
|
||||
if not error:
|
||||
error = basic_lint(rel_fname, code)
|
||||
return error
|
||||
|
||||
|
||||
def lint_python_compile(fname, code):
|
||||
try:
|
||||
compile(code, fname, 'exec') # USE TRACEBACK BELOW HERE
|
||||
return
|
||||
except IndentationError as err:
|
||||
end_lineno = getattr(err, 'end_lineno', err.lineno)
|
||||
if isinstance(end_lineno, int):
|
||||
line_numbers = list(range(end_lineno - 1, end_lineno))
|
||||
else:
|
||||
line_numbers = []
|
||||
|
||||
tb_lines = traceback.format_exception(type(err), err, err.__traceback__)
|
||||
last_file_i = 0
|
||||
|
||||
target = '# USE TRACEBACK'
|
||||
target += ' BELOW HERE'
|
||||
for i in range(len(tb_lines)):
|
||||
if target in tb_lines[i]:
|
||||
last_file_i = i
|
||||
break
|
||||
tb_lines = tb_lines[:1] + tb_lines[last_file_i + 1 :]
|
||||
|
||||
res = ''.join(tb_lines)
|
||||
return LintResult(text=res, lines=line_numbers)
|
||||
|
||||
|
||||
def basic_lint(fname, code):
|
||||
"""
|
||||
Use tree-sitter to look for syntax errors, display them with tree context.
|
||||
"""
|
||||
|
||||
lang = filename_to_lang(fname)
|
||||
if not lang:
|
||||
return
|
||||
|
||||
parser = get_parser(lang)
|
||||
tree = parser.parse(bytes(code, 'utf-8'))
|
||||
|
||||
errors = traverse_tree(tree.root_node)
|
||||
if not errors:
|
||||
return
|
||||
return LintResult(text=f'{fname}:{errors[0]}', lines=errors)
|
||||
|
||||
|
||||
def extract_error_line_from(lint_error):
|
||||
# moved from openhands.agentskills#_lint_file
|
||||
for line in lint_error.splitlines(True):
|
||||
if line.strip():
|
||||
# The format of the error message is: <filename>:<line>:<column>: <error code> <error message>
|
||||
parts = line.split(':')
|
||||
if len(parts) >= 2:
|
||||
try:
|
||||
first_error_line = int(parts[1])
|
||||
break
|
||||
except ValueError:
|
||||
continue
|
||||
return first_error_line
|
||||
|
||||
|
||||
def tree_context(fname, code, line_nums):
|
||||
context = TreeContext(
|
||||
fname,
|
||||
code,
|
||||
color=False,
|
||||
line_number=True,
|
||||
child_context=False,
|
||||
last_line=False,
|
||||
margin=0,
|
||||
mark_lois=True,
|
||||
loi_pad=3,
|
||||
# header_max=30,
|
||||
show_top_of_file_parent_scope=False,
|
||||
)
|
||||
line_nums = set(line_nums)
|
||||
context.add_lines_of_interest(line_nums)
|
||||
context.add_context()
|
||||
output = context.format()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
# Traverse the tree to find errors
|
||||
def traverse_tree(node):
|
||||
errors = []
|
||||
if node.type == 'ERROR' or node.is_missing:
|
||||
line_no = node.start_point[0] + 1
|
||||
errors.append(line_no)
|
||||
|
||||
for child in node.children:
|
||||
errors += traverse_tree(child)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main function to parse files provided as command line arguments.
|
||||
"""
|
||||
if len(sys.argv) < 2:
|
||||
print('Usage: python linter.py <file1> <file2> ...')
|
||||
sys.exit(1)
|
||||
|
||||
linter = Linter(root=os.getcwd())
|
||||
for file_path in sys.argv[1:]:
|
||||
errors = linter.lint(file_path)
|
||||
if errors:
|
||||
print(errors)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
30
openhands/runtime/plugins/agent_skills/utils/config.py
Normal file
30
openhands/runtime/plugins/agent_skills/utils/config.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import os
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
# ==================================================================================================
|
||||
# OPENAI
|
||||
# TODO: Move this to EventStream Actions when EventStreamRuntime is fully implemented
|
||||
# NOTE: we need to get env vars inside functions because they will be set in IPython
|
||||
# AFTER the agentskills is imported (the case for EventStreamRuntime)
|
||||
# ==================================================================================================
|
||||
def _get_openai_api_key():
|
||||
return os.getenv('OPENAI_API_KEY', os.getenv('SANDBOX_ENV_OPENAI_API_KEY', ''))
|
||||
|
||||
|
||||
def _get_openai_base_url():
|
||||
return os.getenv('OPENAI_BASE_URL', 'https://api.openai.com/v1')
|
||||
|
||||
|
||||
def _get_openai_model():
|
||||
return os.getenv('OPENAI_MODEL', 'gpt-4o-2024-05-13')
|
||||
|
||||
|
||||
def _get_max_token():
|
||||
return os.getenv('MAX_TOKEN', 500)
|
||||
|
||||
|
||||
def _get_openai_client():
|
||||
client = OpenAI(api_key=_get_openai_api_key(), base_url=_get_openai_base_url())
|
||||
return client
|
||||
11
openhands/runtime/plugins/agent_skills/utils/dependency.py
Normal file
11
openhands/runtime/plugins/agent_skills/utils/dependency.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from types import ModuleType
|
||||
|
||||
|
||||
def import_functions(
|
||||
module: ModuleType, function_names: list[str], target_globals: dict
|
||||
) -> None:
|
||||
for name in function_names:
|
||||
if hasattr(module, name):
|
||||
target_globals[name] = getattr(module, name)
|
||||
else:
|
||||
raise ValueError(f'Function {name} not found in {module.__name__}')
|
||||
76
openhands/runtime/plugins/jupyter/__init__.py
Normal file
76
openhands/runtime/plugins/jupyter/__init__.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import subprocess
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action import Action, IPythonRunCellAction
|
||||
from openhands.events.observation import IPythonRunCellObservation
|
||||
from openhands.runtime.plugins.requirement import Plugin, PluginRequirement
|
||||
from openhands.runtime.utils import find_available_tcp_port
|
||||
|
||||
from .execute_server import JupyterKernel
|
||||
|
||||
|
||||
@dataclass
|
||||
class JupyterRequirement(PluginRequirement):
|
||||
name: str = 'jupyter'
|
||||
|
||||
|
||||
class JupyterPlugin(Plugin):
|
||||
name: str = 'jupyter'
|
||||
|
||||
async def initialize(self, username: str, kernel_id: str = 'openhands-default'):
|
||||
self.kernel_gateway_port = find_available_tcp_port()
|
||||
self.kernel_id = kernel_id
|
||||
self.gateway_process = subprocess.Popen(
|
||||
(
|
||||
f"su - {username} -s /bin/bash << 'EOF'\n"
|
||||
'cd /openhands/code\n'
|
||||
'export POETRY_VIRTUALENVS_PATH=/openhands/poetry;\n'
|
||||
'export PYTHONPATH=/openhands/code:$PYTHONPATH;\n'
|
||||
'/openhands/miniforge3/bin/mamba run -n base '
|
||||
'poetry run jupyter kernelgateway '
|
||||
'--KernelGatewayApp.ip=0.0.0.0 '
|
||||
f'--KernelGatewayApp.port={self.kernel_gateway_port}\n'
|
||||
'EOF'
|
||||
),
|
||||
stderr=subprocess.STDOUT,
|
||||
shell=True,
|
||||
)
|
||||
# read stdout until the kernel gateway is ready
|
||||
output = ''
|
||||
while True and self.gateway_process.stdout is not None:
|
||||
line = self.gateway_process.stdout.readline().decode('utf-8')
|
||||
output += line
|
||||
if 'at' in line:
|
||||
break
|
||||
time.sleep(1)
|
||||
logger.debug('Waiting for jupyter kernel gateway to start...')
|
||||
|
||||
logger.info(
|
||||
f'Jupyter kernel gateway started at port {self.kernel_gateway_port}. Output: {output}'
|
||||
)
|
||||
|
||||
async def _run(self, action: Action) -> IPythonRunCellObservation:
|
||||
"""Internal method to run a code cell in the jupyter kernel."""
|
||||
if not isinstance(action, IPythonRunCellAction):
|
||||
raise ValueError(
|
||||
f'Jupyter plugin only supports IPythonRunCellAction, but got {action}'
|
||||
)
|
||||
|
||||
if not hasattr(self, 'kernel'):
|
||||
self.kernel = JupyterKernel(
|
||||
f'localhost:{self.kernel_gateway_port}', self.kernel_id
|
||||
)
|
||||
|
||||
if not self.kernel.initialized:
|
||||
await self.kernel.initialize()
|
||||
output = await self.kernel.execute(action.code, timeout=action.timeout)
|
||||
return IPythonRunCellObservation(
|
||||
content=output,
|
||||
code=action.code,
|
||||
)
|
||||
|
||||
async def run(self, action: Action) -> IPythonRunCellObservation:
|
||||
obs = await self._run(action)
|
||||
return obs
|
||||
285
openhands/runtime/plugins/jupyter/execute_server.py
Executable file
285
openhands/runtime/plugins/jupyter/execute_server.py
Executable file
@@ -0,0 +1,285 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from uuid import uuid4
|
||||
|
||||
import tornado
|
||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
|
||||
from tornado.escape import json_decode, json_encode, url_escape
|
||||
from tornado.httpclient import AsyncHTTPClient, HTTPRequest
|
||||
from tornado.ioloop import PeriodicCallback
|
||||
from tornado.websocket import websocket_connect
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def strip_ansi(o: str) -> str:
|
||||
"""Removes ANSI escape sequences from `o`, as defined by ECMA-048 in
|
||||
http://www.ecma-international.org/publications/files/ECMA-ST/Ecma-048.pdf
|
||||
|
||||
# https://github.com/ewen-lbh/python-strip-ansi/blob/master/strip_ansi/__init__.py
|
||||
|
||||
>>> strip_ansi("\\033[33mLorem ipsum\\033[0m")
|
||||
'Lorem ipsum'
|
||||
|
||||
>>> strip_ansi("Lorem \\033[38;25mIpsum\\033[0m sit\\namet.")
|
||||
'Lorem Ipsum sit\\namet.'
|
||||
|
||||
>>> strip_ansi("")
|
||||
''
|
||||
|
||||
>>> strip_ansi("\\x1b[0m")
|
||||
''
|
||||
|
||||
>>> strip_ansi("Lorem")
|
||||
'Lorem'
|
||||
|
||||
>>> strip_ansi('\\x1b[38;5;32mLorem ipsum\\x1b[0m')
|
||||
'Lorem ipsum'
|
||||
|
||||
>>> strip_ansi('\\x1b[1m\\x1b[46m\\x1b[31mLorem dolor sit ipsum\\x1b[0m')
|
||||
'Lorem dolor sit ipsum'
|
||||
"""
|
||||
# pattern = re.compile(r'/(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]/')
|
||||
pattern = re.compile(r'\x1B\[\d+(;\d+){0,2}m')
|
||||
stripped = pattern.sub('', o)
|
||||
return stripped
|
||||
|
||||
|
||||
class JupyterKernel:
|
||||
def __init__(self, url_suffix, convid, lang='python'):
|
||||
self.base_url = f'http://{url_suffix}'
|
||||
self.base_ws_url = f'ws://{url_suffix}'
|
||||
self.lang = lang
|
||||
self.kernel_id = None
|
||||
self.ws = None
|
||||
self.convid = convid
|
||||
logging.info(
|
||||
f'Jupyter kernel created for conversation {convid} at {url_suffix}'
|
||||
)
|
||||
|
||||
self.heartbeat_interval = 10000 # 10 seconds
|
||||
self.heartbeat_callback = None
|
||||
self.initialized = False
|
||||
|
||||
async def initialize(self):
|
||||
await self.execute(r'%colors nocolor')
|
||||
# pre-defined tools
|
||||
self.tools_to_run: list[str] = [
|
||||
# TODO: You can add code for your pre-defined tools here
|
||||
]
|
||||
for tool in self.tools_to_run:
|
||||
res = await self.execute(tool)
|
||||
logging.info(f'Tool [{tool}] initialized:\n{res}')
|
||||
self.initialized = True
|
||||
|
||||
async def _send_heartbeat(self):
|
||||
if not self.ws:
|
||||
return
|
||||
try:
|
||||
self.ws.ping()
|
||||
# logging.info('Heartbeat sent...')
|
||||
except tornado.iostream.StreamClosedError:
|
||||
# logging.info('Heartbeat failed, reconnecting...')
|
||||
try:
|
||||
await self._connect()
|
||||
except ConnectionRefusedError:
|
||||
logging.info(
|
||||
'ConnectionRefusedError: Failed to reconnect to kernel websocket - Is the kernel still running?'
|
||||
)
|
||||
|
||||
async def _connect(self):
|
||||
if self.ws:
|
||||
self.ws.close()
|
||||
self.ws = None
|
||||
|
||||
client = AsyncHTTPClient()
|
||||
if not self.kernel_id:
|
||||
n_tries = 5
|
||||
while n_tries > 0:
|
||||
try:
|
||||
response = await client.fetch(
|
||||
'{}/api/kernels'.format(self.base_url),
|
||||
method='POST',
|
||||
body=json_encode({'name': self.lang}),
|
||||
)
|
||||
kernel = json_decode(response.body)
|
||||
self.kernel_id = kernel['id']
|
||||
break
|
||||
except Exception:
|
||||
# kernels are not ready yet
|
||||
n_tries -= 1
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if n_tries == 0:
|
||||
raise ConnectionRefusedError('Failed to connect to kernel')
|
||||
|
||||
ws_req = HTTPRequest(
|
||||
url='{}/api/kernels/{}/channels'.format(
|
||||
self.base_ws_url, url_escape(self.kernel_id)
|
||||
)
|
||||
)
|
||||
self.ws = await websocket_connect(ws_req)
|
||||
logging.info('Connected to kernel websocket')
|
||||
|
||||
# Setup heartbeat
|
||||
if self.heartbeat_callback:
|
||||
self.heartbeat_callback.stop()
|
||||
self.heartbeat_callback = PeriodicCallback(
|
||||
self._send_heartbeat, self.heartbeat_interval
|
||||
)
|
||||
self.heartbeat_callback.start()
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type(ConnectionRefusedError),
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_fixed(2),
|
||||
)
|
||||
async def execute(self, code, timeout=120):
|
||||
if not self.ws:
|
||||
await self._connect()
|
||||
|
||||
msg_id = uuid4().hex
|
||||
assert self.ws is not None
|
||||
res = await self.ws.write_message(
|
||||
json_encode(
|
||||
{
|
||||
'header': {
|
||||
'username': '',
|
||||
'version': '5.0',
|
||||
'session': '',
|
||||
'msg_id': msg_id,
|
||||
'msg_type': 'execute_request',
|
||||
},
|
||||
'parent_header': {},
|
||||
'channel': 'shell',
|
||||
'content': {
|
||||
'code': code,
|
||||
'silent': False,
|
||||
'store_history': False,
|
||||
'user_expressions': {},
|
||||
'allow_stdin': False,
|
||||
},
|
||||
'metadata': {},
|
||||
'buffers': {},
|
||||
}
|
||||
)
|
||||
)
|
||||
logging.info(f'Executed code in jupyter kernel:\n{res}')
|
||||
|
||||
outputs = []
|
||||
|
||||
async def wait_for_messages():
|
||||
execution_done = False
|
||||
while not execution_done:
|
||||
assert self.ws is not None
|
||||
msg = await self.ws.read_message()
|
||||
msg = json_decode(msg)
|
||||
msg_type = msg['msg_type']
|
||||
parent_msg_id = msg['parent_header'].get('msg_id', None)
|
||||
|
||||
if parent_msg_id != msg_id:
|
||||
continue
|
||||
|
||||
if os.environ.get('DEBUG'):
|
||||
logging.info(
|
||||
f"MSG TYPE: {msg_type.upper()} DONE:{execution_done}\nCONTENT: {msg['content']}"
|
||||
)
|
||||
|
||||
if msg_type == 'error':
|
||||
traceback = '\n'.join(msg['content']['traceback'])
|
||||
outputs.append(traceback)
|
||||
execution_done = True
|
||||
elif msg_type == 'stream':
|
||||
outputs.append(msg['content']['text'])
|
||||
elif msg_type in ['execute_result', 'display_data']:
|
||||
outputs.append(msg['content']['data']['text/plain'])
|
||||
if 'image/png' in msg['content']['data']:
|
||||
# use markdone to display image (in case of large image)
|
||||
outputs.append(
|
||||
f"\n\n"
|
||||
)
|
||||
|
||||
elif msg_type == 'execute_reply':
|
||||
execution_done = True
|
||||
return execution_done
|
||||
|
||||
async def interrupt_kernel():
|
||||
client = AsyncHTTPClient()
|
||||
interrupt_response = await client.fetch(
|
||||
f'{self.base_url}/api/kernels/{self.kernel_id}/interrupt',
|
||||
method='POST',
|
||||
body=json_encode({'kernel_id': self.kernel_id}),
|
||||
)
|
||||
logging.info(f'Kernel interrupted: {interrupt_response}')
|
||||
|
||||
try:
|
||||
execution_done = await asyncio.wait_for(wait_for_messages(), timeout)
|
||||
except asyncio.TimeoutError:
|
||||
await interrupt_kernel()
|
||||
return f'[Execution timed out ({timeout} seconds).]'
|
||||
|
||||
if not outputs and execution_done:
|
||||
ret = '[Code executed successfully with no output]'
|
||||
else:
|
||||
ret = ''.join(outputs)
|
||||
|
||||
# Remove ANSI
|
||||
ret = strip_ansi(ret)
|
||||
|
||||
if os.environ.get('DEBUG'):
|
||||
logging.info(f'OUTPUT:\n{ret}')
|
||||
return ret
|
||||
|
||||
async def shutdown_async(self):
|
||||
if self.kernel_id:
|
||||
client = AsyncHTTPClient()
|
||||
await client.fetch(
|
||||
'{}/api/kernels/{}'.format(self.base_url, self.kernel_id),
|
||||
method='DELETE',
|
||||
)
|
||||
self.kernel_id = None
|
||||
if self.ws:
|
||||
self.ws.close()
|
||||
self.ws = None
|
||||
|
||||
|
||||
class ExecuteHandler(tornado.web.RequestHandler):
|
||||
def initialize(self, jupyter_kernel):
|
||||
self.jupyter_kernel = jupyter_kernel
|
||||
|
||||
async def post(self):
|
||||
data = json_decode(self.request.body)
|
||||
code = data.get('code')
|
||||
|
||||
if not code:
|
||||
self.set_status(400)
|
||||
self.write('Missing code')
|
||||
return
|
||||
|
||||
output = await self.jupyter_kernel.execute(code)
|
||||
|
||||
self.write(output)
|
||||
|
||||
|
||||
def make_app():
|
||||
jupyter_kernel = JupyterKernel(
|
||||
f"localhost:{os.environ.get('JUPYTER_GATEWAY_PORT')}",
|
||||
os.environ.get('JUPYTER_GATEWAY_KERNEL_ID'),
|
||||
)
|
||||
asyncio.get_event_loop().run_until_complete(jupyter_kernel.initialize())
|
||||
|
||||
return tornado.web.Application(
|
||||
[
|
||||
(r'/execute', ExecuteHandler, {'jupyter_kernel': jupyter_kernel}),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app = make_app()
|
||||
app.listen(os.environ.get('JUPYTER_EXEC_SERVER_PORT'))
|
||||
tornado.ioloop.IOLoop.current().start()
|
||||
31
openhands/runtime/plugins/requirement.py
Normal file
31
openhands/runtime/plugins/requirement.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.events.action import Action
|
||||
from openhands.events.observation import Observation
|
||||
|
||||
|
||||
class Plugin:
|
||||
"""Base class for a plugin.
|
||||
|
||||
This will be initialized by the runtime client, which will run inside docker.
|
||||
"""
|
||||
|
||||
name: str
|
||||
|
||||
@abstractmethod
|
||||
async def initialize(self, username: str):
|
||||
"""Initialize the plugin."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, action: Action) -> Observation:
|
||||
"""Run the plugin for a given action."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginRequirement:
|
||||
"""Requirement for a plugin."""
|
||||
|
||||
name: str
|
||||
211
openhands/runtime/runtime.py
Normal file
211
openhands/runtime/runtime.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
|
||||
from openhands.core.config import AppConfig, SandboxConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events import EventSource, EventStream, EventStreamSubscriber
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
ActionConfirmationStatus,
|
||||
BrowseInteractiveAction,
|
||||
BrowseURLAction,
|
||||
CmdRunAction,
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
IPythonRunCellAction,
|
||||
)
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import (
|
||||
CmdOutputObservation,
|
||||
ErrorObservation,
|
||||
NullObservation,
|
||||
Observation,
|
||||
UserRejectObservation,
|
||||
)
|
||||
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
|
||||
from openhands.runtime.plugins import JupyterRequirement, PluginRequirement
|
||||
|
||||
|
||||
def _default_env_vars(sandbox_config: SandboxConfig) -> dict[str, str]:
|
||||
ret = {}
|
||||
for key in os.environ:
|
||||
if key.startswith('SANDBOX_ENV_'):
|
||||
sandbox_key = key.removeprefix('SANDBOX_ENV_')
|
||||
ret[sandbox_key] = os.environ[key]
|
||||
if sandbox_config.enable_auto_lint:
|
||||
ret['ENABLE_AUTO_LINT'] = 'true'
|
||||
return ret
|
||||
|
||||
|
||||
class Runtime:
|
||||
"""The runtime is how the agent interacts with the external environment.
|
||||
This includes a bash sandbox, a browser, and filesystem interactions.
|
||||
|
||||
sid is the session id, which is used to identify the current user session.
|
||||
"""
|
||||
|
||||
sid: str
|
||||
config: AppConfig
|
||||
DEFAULT_ENV_VARS: dict[str, str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: AppConfig,
|
||||
event_stream: EventStream,
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
):
|
||||
self.sid = sid
|
||||
self.event_stream = event_stream
|
||||
self.event_stream.subscribe(EventStreamSubscriber.RUNTIME, self.on_event)
|
||||
self.plugins = plugins if plugins is not None and len(plugins) > 0 else []
|
||||
|
||||
self.config = copy.deepcopy(config)
|
||||
self.DEFAULT_ENV_VARS = _default_env_vars(config.sandbox)
|
||||
atexit.register(self.close_sync)
|
||||
logger.debug(f'Runtime `{sid}` config:\n{self.config}')
|
||||
|
||||
async def ainit(self, env_vars: dict[str, str] | None = None) -> None:
|
||||
"""
|
||||
Initialize the runtime (asynchronously).
|
||||
|
||||
This method should be called after the runtime's constructor.
|
||||
"""
|
||||
if self.DEFAULT_ENV_VARS:
|
||||
logger.debug(f'Adding default env vars: {self.DEFAULT_ENV_VARS}')
|
||||
await self.add_env_vars(self.DEFAULT_ENV_VARS)
|
||||
if env_vars is not None:
|
||||
logger.debug(f'Adding provided env vars: {env_vars}')
|
||||
await self.add_env_vars(env_vars)
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
def close_sync(self) -> None:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
# No running event loop, use asyncio.run()
|
||||
asyncio.run(self.close())
|
||||
else:
|
||||
# There is a running event loop, create a task
|
||||
if loop.is_running():
|
||||
loop.create_task(self.close())
|
||||
else:
|
||||
loop.run_until_complete(self.close())
|
||||
|
||||
# ====================================================================
|
||||
|
||||
async def add_env_vars(self, env_vars: dict[str, str]) -> None:
|
||||
# Add env vars to the IPython shell (if Jupyter is used)
|
||||
if any(isinstance(plugin, JupyterRequirement) for plugin in self.plugins):
|
||||
code = 'import os\n'
|
||||
for key, value in env_vars.items():
|
||||
# Note: json.dumps gives us nice escaping for free
|
||||
code += f'os.environ["{key}"] = {json.dumps(value)}\n'
|
||||
code += '\n'
|
||||
obs = await self.run_ipython(IPythonRunCellAction(code))
|
||||
logger.info(f'Added env vars to IPython: code={code}, obs={obs}')
|
||||
|
||||
# Add env vars to the Bash shell
|
||||
cmd = ''
|
||||
for key, value in env_vars.items():
|
||||
# Note: json.dumps gives us nice escaping for free
|
||||
cmd += f'export {key}={json.dumps(value)}; '
|
||||
if not cmd:
|
||||
return
|
||||
cmd = cmd.strip()
|
||||
logger.debug(f'Adding env var: {cmd}')
|
||||
obs = await self.run(CmdRunAction(cmd))
|
||||
if not isinstance(obs, CmdOutputObservation) or obs.exit_code != 0:
|
||||
raise RuntimeError(
|
||||
f'Failed to add env vars [{env_vars}] to environment: {obs.content}'
|
||||
)
|
||||
|
||||
async def on_event(self, event: Event) -> None:
|
||||
if isinstance(event, Action):
|
||||
# set timeout to default if not set
|
||||
if event.timeout is None:
|
||||
event.timeout = self.config.sandbox.timeout
|
||||
assert event.timeout is not None
|
||||
observation = await self.run_action(event)
|
||||
observation._cause = event.id # type: ignore[attr-defined]
|
||||
source = event.source if event.source else EventSource.AGENT
|
||||
self.event_stream.add_event(observation, source) # type: ignore[arg-type]
|
||||
|
||||
async def run_action(self, action: Action) -> Observation:
|
||||
"""Run an action and return the resulting observation.
|
||||
If the action is not runnable in any runtime, a NullObservation is returned.
|
||||
If the action is not supported by the current runtime, an ErrorObservation is returned.
|
||||
"""
|
||||
if not action.runnable:
|
||||
return NullObservation('')
|
||||
if (
|
||||
hasattr(action, 'is_confirmed')
|
||||
and action.is_confirmed == ActionConfirmationStatus.AWAITING_CONFIRMATION
|
||||
):
|
||||
return NullObservation('')
|
||||
action_type = action.action # type: ignore[attr-defined]
|
||||
if action_type not in ACTION_TYPE_TO_CLASS:
|
||||
return ErrorObservation(f'Action {action_type} does not exist.')
|
||||
if not hasattr(self, action_type):
|
||||
return ErrorObservation(
|
||||
f'Action {action_type} is not supported in the current runtime.'
|
||||
)
|
||||
if (
|
||||
hasattr(action, 'is_confirmed')
|
||||
and action.is_confirmed == ActionConfirmationStatus.REJECTED
|
||||
):
|
||||
return UserRejectObservation(
|
||||
'Action has been rejected by the user! Waiting for further user input.'
|
||||
)
|
||||
observation = await getattr(self, action_type)(action)
|
||||
return observation
|
||||
|
||||
# ====================================================================
|
||||
# Action execution
|
||||
# ====================================================================
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, action: CmdRunAction) -> Observation:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def run_ipython(self, action: IPythonRunCellAction) -> Observation:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def read(self, action: FileReadAction) -> Observation:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def write(self, action: FileWriteAction) -> Observation:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def browse(self, action: BrowseURLAction) -> Observation:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
|
||||
pass
|
||||
|
||||
# ====================================================================
|
||||
# File operations
|
||||
# ====================================================================
|
||||
|
||||
@abstractmethod
|
||||
async def copy_to(self, host_src: str, sandbox_dest: str, recursive: bool = False):
|
||||
raise NotImplementedError('This method is not implemented in the base class.')
|
||||
|
||||
@abstractmethod
|
||||
async def list_files(self, path: str | None = None) -> list[str]:
|
||||
"""List files in the sandbox.
|
||||
|
||||
If path is None, list files in the sandbox's initial working directory (e.g., /workspace).
|
||||
"""
|
||||
raise NotImplementedError('This method is not implemented in the base class.')
|
||||
4
openhands/runtime/utils/__init__.py
Normal file
4
openhands/runtime/utils/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .bash import split_bash_commands
|
||||
from .system import find_available_tcp_port
|
||||
|
||||
__all__ = ['find_available_tcp_port', 'split_bash_commands']
|
||||
52
openhands/runtime/utils/bash.py
Normal file
52
openhands/runtime/utils/bash.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import bashlex
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
def split_bash_commands(commands):
|
||||
try:
|
||||
parsed = bashlex.parse(commands)
|
||||
except bashlex.errors.ParsingError as e:
|
||||
logger.debug(
|
||||
f'Failed to parse bash commands\n'
|
||||
f'[input]: {commands}\n'
|
||||
f'[warning]: {e}\n'
|
||||
f'The original command will be returned as is.'
|
||||
)
|
||||
# If parsing fails, return the original commands
|
||||
return [commands]
|
||||
|
||||
result: list[str] = []
|
||||
last_end = 0
|
||||
|
||||
for node in parsed:
|
||||
start, end = node.pos
|
||||
|
||||
# Include any text between the last command and this one
|
||||
if start > last_end:
|
||||
between = commands[last_end:start]
|
||||
logger.debug(f'BASH PARSING between: {between}')
|
||||
if result:
|
||||
result[-1] += between.rstrip()
|
||||
elif between.strip():
|
||||
# THIS SHOULD NOT HAPPEN
|
||||
result.append(between.rstrip())
|
||||
|
||||
# Extract the command, preserving original formatting
|
||||
command = commands[start:end].rstrip()
|
||||
logger.debug(f'BASH PARSING command: {command}')
|
||||
result.append(command)
|
||||
|
||||
last_end = end
|
||||
|
||||
# Add any remaining text after the last command to the last command
|
||||
remaining = commands[last_end:].rstrip()
|
||||
logger.debug(f'BASH PARSING remaining: {remaining}')
|
||||
if last_end < len(commands) and result:
|
||||
result[-1] += remaining
|
||||
logger.debug(f'BASH PARSING result[-1] += remaining: {result[-1]}')
|
||||
elif last_end < len(commands):
|
||||
if remaining:
|
||||
result.append(remaining)
|
||||
logger.debug(f'BASH PARSING result.append(remaining): {result[-1]}')
|
||||
return result
|
||||
145
openhands/runtime/utils/files.py
Normal file
145
openhands/runtime/utils/files.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from openhands.events.observation import (
|
||||
ErrorObservation,
|
||||
FileReadObservation,
|
||||
FileWriteObservation,
|
||||
Observation,
|
||||
)
|
||||
|
||||
|
||||
def resolve_path(
|
||||
file_path: str,
|
||||
working_directory: str,
|
||||
workspace_base: str,
|
||||
workspace_mount_path_in_sandbox: str,
|
||||
):
|
||||
"""Resolve a file path to a path on the host filesystem.
|
||||
|
||||
Args:
|
||||
file_path: The path to resolve.
|
||||
working_directory: The working directory of the agent.
|
||||
workspace_mount_path_in_sandbox: The path to the workspace inside the sandbox.
|
||||
workspace_base: The base path of the workspace on the host filesystem.
|
||||
|
||||
Returns:
|
||||
The resolved path on the host filesystem.
|
||||
"""
|
||||
path_in_sandbox = Path(file_path)
|
||||
|
||||
# Apply working directory
|
||||
if not path_in_sandbox.is_absolute():
|
||||
path_in_sandbox = Path(working_directory) / path_in_sandbox
|
||||
|
||||
# Sanitize the path with respect to the root of the full sandbox
|
||||
# (deny any .. path traversal to parent directories of the sandbox)
|
||||
abs_path_in_sandbox = path_in_sandbox.resolve()
|
||||
|
||||
# If the path is outside the workspace, deny it
|
||||
if not abs_path_in_sandbox.is_relative_to(workspace_mount_path_in_sandbox):
|
||||
raise PermissionError(f'File access not permitted: {file_path}')
|
||||
|
||||
# Get path relative to the root of the workspace inside the sandbox
|
||||
path_in_workspace = abs_path_in_sandbox.relative_to(
|
||||
Path(workspace_mount_path_in_sandbox)
|
||||
)
|
||||
|
||||
# Get path relative to host
|
||||
path_in_host_workspace = Path(workspace_base) / path_in_workspace
|
||||
|
||||
return path_in_host_workspace
|
||||
|
||||
|
||||
def read_lines(all_lines: list[str], start=0, end=-1):
|
||||
start = max(start, 0)
|
||||
start = min(start, len(all_lines))
|
||||
end = -1 if end == -1 else max(end, 0)
|
||||
end = min(end, len(all_lines))
|
||||
if end == -1:
|
||||
if start == 0:
|
||||
return all_lines
|
||||
else:
|
||||
return all_lines[start:]
|
||||
else:
|
||||
num_lines = len(all_lines)
|
||||
begin = max(0, min(start, num_lines - 2))
|
||||
end = -1 if end > num_lines else max(begin + 1, end)
|
||||
return all_lines[begin:end]
|
||||
|
||||
|
||||
async def read_file(
|
||||
path, workdir, workspace_base, workspace_mount_path_in_sandbox, start=0, end=-1
|
||||
) -> Observation:
|
||||
try:
|
||||
whole_path = resolve_path(
|
||||
path, workdir, workspace_base, workspace_mount_path_in_sandbox
|
||||
)
|
||||
except PermissionError:
|
||||
return ErrorObservation(
|
||||
f"You're not allowed to access this path: {path}. You can only access paths inside the workspace."
|
||||
)
|
||||
|
||||
try:
|
||||
with open(whole_path, 'r', encoding='utf-8') as file:
|
||||
lines = read_lines(file.readlines(), start, end)
|
||||
except FileNotFoundError:
|
||||
return ErrorObservation(f'File not found: {path}')
|
||||
except UnicodeDecodeError:
|
||||
return ErrorObservation(f'File could not be decoded as utf-8: {path}')
|
||||
except IsADirectoryError:
|
||||
return ErrorObservation(f'Path is a directory: {path}. You can only read files')
|
||||
code_view = ''.join(lines)
|
||||
return FileReadObservation(path=path, content=code_view)
|
||||
|
||||
|
||||
def insert_lines(
|
||||
to_insert: list[str], original: list[str], start: int = 0, end: int = -1
|
||||
):
|
||||
"""Insert the new content to the original content based on start and end"""
|
||||
new_lines = [''] if start == 0 else original[:start]
|
||||
new_lines += [i + '\n' for i in to_insert]
|
||||
new_lines += [''] if end == -1 else original[end:]
|
||||
return new_lines
|
||||
|
||||
|
||||
async def write_file(
|
||||
path,
|
||||
workdir,
|
||||
workspace_base,
|
||||
workspace_mount_path_in_sandbox,
|
||||
content,
|
||||
start=0,
|
||||
end=-1,
|
||||
) -> Observation:
|
||||
insert = content.split('\n')
|
||||
|
||||
try:
|
||||
whole_path = resolve_path(
|
||||
path, workdir, workspace_base, workspace_mount_path_in_sandbox
|
||||
)
|
||||
if not os.path.exists(os.path.dirname(whole_path)):
|
||||
os.makedirs(os.path.dirname(whole_path))
|
||||
mode = 'w' if not os.path.exists(whole_path) else 'r+'
|
||||
try:
|
||||
with open(whole_path, mode, encoding='utf-8') as file:
|
||||
if mode != 'w':
|
||||
all_lines = file.readlines()
|
||||
new_file = insert_lines(insert, all_lines, start, end)
|
||||
else:
|
||||
new_file = [i + '\n' for i in insert]
|
||||
|
||||
file.seek(0)
|
||||
file.writelines(new_file)
|
||||
file.truncate()
|
||||
except FileNotFoundError:
|
||||
return ErrorObservation(f'File not found: {path}')
|
||||
except IsADirectoryError:
|
||||
return ErrorObservation(
|
||||
f'Path is a directory: {path}. You can only write to files'
|
||||
)
|
||||
except UnicodeDecodeError:
|
||||
return ErrorObservation(f'File could not be decoded as utf-8: {path}')
|
||||
except PermissionError:
|
||||
return ErrorObservation(f'Malformed paths not permitted: {path}')
|
||||
return FileWriteObservation(content='', path=path)
|
||||
442
openhands/runtime/utils/runtime_build.py
Normal file
442
openhands/runtime/utils/runtime_build.py
Normal file
@@ -0,0 +1,442 @@
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
import docker
|
||||
import toml
|
||||
from dirhash import dirhash
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
||||
import openhands
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.runtime.builder import DockerRuntimeBuilder, RuntimeBuilder
|
||||
|
||||
RUNTIME_IMAGE_REPO = os.getenv(
|
||||
'OD_RUNTIME_RUNTIME_IMAGE_REPO', 'ghcr.io/openhands/od_runtime'
|
||||
)
|
||||
|
||||
|
||||
def _get_package_version():
|
||||
"""Read the version from pyproject.toml.
|
||||
|
||||
Returns:
|
||||
- The version specified in pyproject.toml under [tool.poetry]
|
||||
"""
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(openhands.__file__)))
|
||||
pyproject_path = os.path.join(project_root, 'pyproject.toml')
|
||||
with open(pyproject_path, 'r') as f:
|
||||
pyproject_data = toml.load(f)
|
||||
return pyproject_data['tool']['poetry']['version']
|
||||
|
||||
|
||||
def _create_project_source_dist():
|
||||
"""Create a source distribution of the project.
|
||||
|
||||
Returns:
|
||||
- str: The path to the project tarball
|
||||
"""
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(openhands.__file__)))
|
||||
logger.info(f'Using project root: {project_root}')
|
||||
|
||||
# run "python -m build -s" on project_root to create project tarball
|
||||
result = subprocess.run(f'python -m build -s {project_root}', shell=True)
|
||||
if result.returncode != 0:
|
||||
logger.error(f'Build failed: {result}')
|
||||
raise Exception(f'Build failed: {result}')
|
||||
|
||||
# Fetch the correct version from pyproject.toml
|
||||
package_version = _get_package_version()
|
||||
tarball_path = os.path.join(
|
||||
project_root, 'dist', f'openhands-{package_version}.tar.gz'
|
||||
)
|
||||
if not os.path.exists(tarball_path):
|
||||
logger.error(f'Source distribution not found at {tarball_path}')
|
||||
raise Exception(f'Source distribution not found at {tarball_path}')
|
||||
logger.info(f'Source distribution created at {tarball_path}')
|
||||
|
||||
return tarball_path
|
||||
|
||||
|
||||
def _put_source_code_to_dir(temp_dir: str):
|
||||
"""Builds the project source tarball. Copies it to temp_dir and unpacks it.
|
||||
The OpenHands source code ends up in the temp_dir/code directory
|
||||
|
||||
Parameters:
|
||||
- temp_dir (str): The directory to put the source code in
|
||||
"""
|
||||
# Build the project source tarball
|
||||
tarball_path = _create_project_source_dist()
|
||||
filename = os.path.basename(tarball_path)
|
||||
filename = filename.removesuffix('.tar.gz')
|
||||
|
||||
# Move the project tarball to temp_dir
|
||||
_res = shutil.copy(tarball_path, os.path.join(temp_dir, 'project.tar.gz'))
|
||||
if _res:
|
||||
os.remove(tarball_path)
|
||||
logger.info(
|
||||
f'Source distribution moved to {os.path.join(temp_dir, "project.tar.gz")}'
|
||||
)
|
||||
|
||||
# Unzip the tarball
|
||||
shutil.unpack_archive(os.path.join(temp_dir, 'project.tar.gz'), temp_dir)
|
||||
# Remove the tarball
|
||||
os.remove(os.path.join(temp_dir, 'project.tar.gz'))
|
||||
# Rename the directory containing the code to 'code'
|
||||
os.rename(os.path.join(temp_dir, filename), os.path.join(temp_dir, 'code'))
|
||||
logger.info(f'Unpacked source code directory: {os.path.join(temp_dir, "code")}')
|
||||
|
||||
|
||||
def _generate_dockerfile(
|
||||
base_image: str,
|
||||
skip_init: bool = False,
|
||||
extra_deps: str | None = None,
|
||||
) -> str:
|
||||
"""Generate the Dockerfile content for the runtime image based on the base image.
|
||||
|
||||
Parameters:
|
||||
- base_image (str): The base image provided for the runtime image
|
||||
- skip_init (boolean):
|
||||
- extra_deps (str):
|
||||
|
||||
Returns:
|
||||
- str: The resulting Dockerfile content
|
||||
"""
|
||||
env = Environment(
|
||||
loader=FileSystemLoader(
|
||||
searchpath=os.path.join(os.path.dirname(__file__), 'runtime_templates')
|
||||
)
|
||||
)
|
||||
template = env.get_template('Dockerfile.j2')
|
||||
|
||||
dockerfile_content = template.render(
|
||||
base_image=base_image,
|
||||
skip_init=skip_init,
|
||||
extra_deps=extra_deps if extra_deps is not None else '',
|
||||
)
|
||||
return dockerfile_content
|
||||
|
||||
|
||||
def prep_docker_build_folder(
|
||||
dir_path: str,
|
||||
base_image: str,
|
||||
skip_init: bool = False,
|
||||
extra_deps: str | None = None,
|
||||
) -> str:
|
||||
"""Prepares a docker build folder by copying the source code and generating the Dockerfile
|
||||
|
||||
Parameters:
|
||||
- dir_path (str): The build folder to place the source code and Dockerfile
|
||||
- base_image (str): The base Docker image to use for the Dockerfile
|
||||
- skip_init (str):
|
||||
- extra_deps (str):
|
||||
|
||||
Returns:
|
||||
- str: The MD5 hash of the build folder directory (dir_path)
|
||||
"""
|
||||
# Copy the source code to directory. It will end up in dir_path/code
|
||||
_put_source_code_to_dir(dir_path)
|
||||
|
||||
# Create a Dockerfile and write it to dir_path
|
||||
dockerfile_content = _generate_dockerfile(
|
||||
base_image,
|
||||
skip_init=skip_init,
|
||||
extra_deps=extra_deps,
|
||||
)
|
||||
logger.info(
|
||||
(
|
||||
f'===== Dockerfile content start =====\n'
|
||||
f'{dockerfile_content}\n'
|
||||
f'===== Dockerfile content end ====='
|
||||
)
|
||||
)
|
||||
with open(os.path.join(dir_path, 'Dockerfile'), 'w') as file:
|
||||
file.write(dockerfile_content)
|
||||
|
||||
# Get the MD5 hash of the dir_path directory
|
||||
hash = dirhash(dir_path, 'md5')
|
||||
logger.info(
|
||||
f'Input base image: {base_image}\n'
|
||||
f'Skip init: {skip_init}\n'
|
||||
f'Extra deps: {extra_deps}\n'
|
||||
f'Hash for docker build directory [{dir_path}] (contents: {os.listdir(dir_path)}): {hash}\n'
|
||||
)
|
||||
return hash
|
||||
|
||||
|
||||
def get_runtime_image_repo_and_tag(base_image: str) -> tuple[str, str]:
|
||||
"""Retrieves the Docker repo and tag associated with the Docker image.
|
||||
|
||||
Parameters:
|
||||
- base_image (str): The name of the base Docker image
|
||||
|
||||
Returns:
|
||||
- tuple[str, str]: The Docker repo and tag of the Docker image
|
||||
"""
|
||||
|
||||
if RUNTIME_IMAGE_REPO in base_image:
|
||||
logger.info(
|
||||
f'The provided image [{base_image}] is a already a valid od_runtime image.\n'
|
||||
f'Will try to reuse it as is.'
|
||||
)
|
||||
|
||||
if ':' not in base_image:
|
||||
base_image = base_image + ':latest'
|
||||
repo, tag = base_image.split(':')
|
||||
return repo, tag
|
||||
else:
|
||||
if ':' not in base_image:
|
||||
base_image = base_image + ':latest'
|
||||
[repo, tag] = base_image.split(':')
|
||||
repo = repo.replace('/', '___')
|
||||
od_version = _get_package_version()
|
||||
return RUNTIME_IMAGE_REPO, f'od_v{od_version}_image_{repo}_tag_{tag}'
|
||||
|
||||
|
||||
def build_runtime_image(
|
||||
base_image: str,
|
||||
runtime_builder: RuntimeBuilder,
|
||||
extra_deps: str | None = None,
|
||||
docker_build_folder: str | None = None,
|
||||
dry_run: bool = False,
|
||||
force_rebuild: bool = False,
|
||||
) -> str:
|
||||
"""Prepares the final docker build folder.
|
||||
If dry_run is False, it will also build the OpenHands runtime Docker image using the docker build folder.
|
||||
|
||||
Parameters:
|
||||
- base_image (str): The name of the base Docker image to use
|
||||
- runtime_builder (RuntimeBuilder): The runtime builder to use
|
||||
- extra_deps (str):
|
||||
- docker_build_folder (str): The directory to use for the build. If not provided a temporary directory will be used
|
||||
- dry_run (bool): if True, it will only ready the build folder. It will not actually build the Docker image
|
||||
- force_rebuild (bool): if True, it will create the Dockerfile which uses the base_image
|
||||
|
||||
Returns:
|
||||
- str: <image_repo>:<MD5 hash>. Where MD5 hash is the hash of the docker build folder
|
||||
|
||||
See https://docs.all-hands.dev/modules/usage/runtime for more details.
|
||||
"""
|
||||
# Calculate the hash for the docker build folder (source code and Dockerfile)
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
from_scratch_hash = prep_docker_build_folder(
|
||||
temp_dir,
|
||||
base_image=base_image,
|
||||
skip_init=False,
|
||||
extra_deps=extra_deps,
|
||||
)
|
||||
|
||||
runtime_image_repo, runtime_image_tag = get_runtime_image_repo_and_tag(base_image)
|
||||
|
||||
# The image name in the format <image repo>:<hash>
|
||||
hash_runtime_image_name = f'{runtime_image_repo}:{from_scratch_hash}'
|
||||
|
||||
# non-hash generic image name, it could contain *similar* dependencies
|
||||
# but *might* not exactly match the state of the source code.
|
||||
# It resembles the "latest" tag in the docker image naming convention for
|
||||
# a particular {repo}:{tag} pair (e.g., ubuntu:latest -> od_runtime:ubuntu_tag_latest)
|
||||
# we will build from IT to save time if the `from_scratch_hash` is not found
|
||||
generic_runtime_image_name = f'{runtime_image_repo}:{runtime_image_tag}'
|
||||
|
||||
# Scenario 1: If we already have an image with the exact same hash, then it means the image is already built
|
||||
# with the exact same source code and Dockerfile, so we will reuse it. Building it is not required.
|
||||
if runtime_builder.image_exists(hash_runtime_image_name):
|
||||
logger.info(
|
||||
f'Image [{hash_runtime_image_name}] already exists so we will reuse it.'
|
||||
)
|
||||
return hash_runtime_image_name
|
||||
|
||||
# Scenario 2: If a Docker image with the exact hash is not found, we will FIRST try to re-build it
|
||||
# by leveraging the `generic_runtime_image_name` to save some time
|
||||
# from re-building the dependencies (e.g., poetry install, apt install)
|
||||
elif runtime_builder.image_exists(generic_runtime_image_name) and not force_rebuild:
|
||||
logger.info(
|
||||
f'Cannot find docker Image [{hash_runtime_image_name}]\n'
|
||||
f'Will try to re-build it from latest [{generic_runtime_image_name}] image to potentially save '
|
||||
f'time for dependencies installation.\n'
|
||||
)
|
||||
|
||||
cur_docker_build_folder = docker_build_folder or tempfile.mkdtemp()
|
||||
_skip_init_hash = prep_docker_build_folder(
|
||||
cur_docker_build_folder,
|
||||
# we want to use the existing generic image as base
|
||||
# so that we can leverage existing dependencies already installed in the image
|
||||
base_image=generic_runtime_image_name,
|
||||
skip_init=True, # skip init since we are re-using the existing image
|
||||
extra_deps=extra_deps,
|
||||
)
|
||||
|
||||
assert (
|
||||
_skip_init_hash != from_scratch_hash
|
||||
), f'The skip_init hash [{_skip_init_hash}] should not match the existing hash [{from_scratch_hash}]'
|
||||
|
||||
if not dry_run:
|
||||
_build_sandbox_image(
|
||||
docker_folder=cur_docker_build_folder,
|
||||
runtime_builder=runtime_builder,
|
||||
target_image_repo=runtime_image_repo,
|
||||
# NOTE: WE ALWAYS use the "from_scratch_hash" tag for the target image
|
||||
# otherwise, even if the source code is exactly the same, the image *might* be re-built
|
||||
# because the same source code will generate different hash when skip_init=True/False
|
||||
# since the Dockerfile is slightly different
|
||||
target_image_hash_tag=from_scratch_hash,
|
||||
target_image_tag=runtime_image_tag,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f'Dry run: Skipping image build for [{generic_runtime_image_name}]'
|
||||
)
|
||||
|
||||
if docker_build_folder is None:
|
||||
shutil.rmtree(cur_docker_build_folder)
|
||||
|
||||
# Scenario 3: If the Docker image with the required hash is not found AND we cannot re-use the latest
|
||||
# relevant image, we will build it completely from scratch
|
||||
else:
|
||||
if force_rebuild:
|
||||
logger.info(
|
||||
f'Force re-build: Will try to re-build image [{generic_runtime_image_name}] from scratch.\n'
|
||||
)
|
||||
|
||||
cur_docker_build_folder = docker_build_folder or tempfile.mkdtemp()
|
||||
_new_from_scratch_hash = prep_docker_build_folder(
|
||||
cur_docker_build_folder,
|
||||
base_image,
|
||||
skip_init=False,
|
||||
extra_deps=extra_deps,
|
||||
)
|
||||
assert (
|
||||
_new_from_scratch_hash == from_scratch_hash
|
||||
), f'The new from scratch hash [{_new_from_scratch_hash}] does not match the existing hash [{from_scratch_hash}]'
|
||||
|
||||
if not dry_run:
|
||||
_build_sandbox_image(
|
||||
docker_folder=cur_docker_build_folder,
|
||||
runtime_builder=runtime_builder,
|
||||
target_image_repo=runtime_image_repo,
|
||||
# NOTE: WE ALWAYS use the "from_scratch_hash" tag for the target image
|
||||
target_image_hash_tag=from_scratch_hash,
|
||||
target_image_tag=runtime_image_tag,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f'Dry run: Skipping image build for [{generic_runtime_image_name}]'
|
||||
)
|
||||
|
||||
if docker_build_folder is None:
|
||||
shutil.rmtree(cur_docker_build_folder)
|
||||
|
||||
return f'{runtime_image_repo}:{from_scratch_hash}'
|
||||
|
||||
|
||||
def _build_sandbox_image(
|
||||
docker_folder: str,
|
||||
runtime_builder: RuntimeBuilder,
|
||||
target_image_repo: str,
|
||||
target_image_hash_tag: str,
|
||||
target_image_tag: str,
|
||||
) -> str:
|
||||
"""Build and tag the sandbox image.
|
||||
The image will be tagged as both:
|
||||
- target_image_hash_tag
|
||||
- target_image_tag
|
||||
|
||||
Parameters:
|
||||
- docker_folder (str): the path to the docker build folder
|
||||
- runtime_builder (RuntimeBuilder): the runtime builder instance
|
||||
- target_image_repo (str): the repository name for the target image
|
||||
- target_image_hash_tag (str): the *hash* tag for the target image that is calculated based
|
||||
on the contents of the docker build folder (source code and Dockerfile)
|
||||
e.g. 1234567890abcdef
|
||||
-target_image_tag (str): the tag for the target image that's generic and based on the base image name
|
||||
e.g. od_v0.8.3_image_ubuntu_tag_22.04
|
||||
"""
|
||||
target_image_hash_name = f'{target_image_repo}:{target_image_hash_tag}'
|
||||
target_image_generic_name = f'{target_image_repo}:{target_image_tag}'
|
||||
|
||||
try:
|
||||
success = runtime_builder.build(
|
||||
path=docker_folder, tags=[target_image_hash_name, target_image_generic_name]
|
||||
)
|
||||
if not success:
|
||||
raise RuntimeError(f'Build failed for image {target_image_hash_name}')
|
||||
except Exception as e:
|
||||
logger.error(f'Sandbox image build failed: {e}')
|
||||
raise
|
||||
|
||||
return target_image_hash_name
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--base_image', type=str, default='nikolaik/python-nodejs:python3.11-nodejs22'
|
||||
)
|
||||
parser.add_argument('--build_folder', type=str, default=None)
|
||||
parser.add_argument('--force_rebuild', action='store_true', default=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.build_folder is not None:
|
||||
# If a build_folder is provided, we do not actually build the Docker image. We copy the necessary source code
|
||||
# and create a Dockerfile dynamically and place it in the build_folder only. This allows the Docker image to
|
||||
# then be created using the Dockerfile (most likely using the containers/build.sh script)
|
||||
build_folder = args.build_folder
|
||||
assert os.path.exists(
|
||||
build_folder
|
||||
), f'Build folder {build_folder} does not exist'
|
||||
logger.info(
|
||||
f'Copying the source code and generating the Dockerfile in the build folder: {build_folder}'
|
||||
)
|
||||
|
||||
runtime_image_repo, runtime_image_tag = get_runtime_image_repo_and_tag(
|
||||
args.base_image
|
||||
)
|
||||
logger.info(
|
||||
f'Runtime image repo: {runtime_image_repo} and runtime image tag: {runtime_image_tag}'
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# dry_run is true so we only prepare a temp_dir containing the required source code and the Dockerfile. We
|
||||
# then obtain the MD5 hash of the folder and return <image_repo>:<temp_dir_md5_hash>
|
||||
runtime_image_hash_name = build_runtime_image(
|
||||
args.base_image,
|
||||
runtime_builder=DockerRuntimeBuilder(docker.from_env()),
|
||||
docker_build_folder=temp_dir,
|
||||
dry_run=True,
|
||||
force_rebuild=args.force_rebuild,
|
||||
)
|
||||
|
||||
_runtime_image_repo, runtime_image_hash_tag = runtime_image_hash_name.split(
|
||||
':'
|
||||
)
|
||||
|
||||
# Move contents of temp_dir to build_folder
|
||||
shutil.copytree(temp_dir, build_folder, dirs_exist_ok=True)
|
||||
logger.info(
|
||||
f'Build folder [{build_folder}] is ready: {os.listdir(build_folder)}'
|
||||
)
|
||||
|
||||
# We now update the config.sh in the build_folder to contain the required values. This is used in the
|
||||
# containers/build.sh script which is called to actually build the Docker image
|
||||
with open(os.path.join(build_folder, 'config.sh'), 'a') as file:
|
||||
file.write(
|
||||
(
|
||||
f'\n'
|
||||
f'DOCKER_IMAGE_TAG={runtime_image_tag}\n'
|
||||
f'DOCKER_IMAGE_HASH_TAG={runtime_image_hash_tag}\n'
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
f'`config.sh` is updated with the image repo[{runtime_image_repo}] and tags [{runtime_image_tag}, {runtime_image_hash_tag}]'
|
||||
)
|
||||
logger.info(
|
||||
f'Dockerfile, source code and config.sh are ready in {build_folder}'
|
||||
)
|
||||
else:
|
||||
# If a build_folder is not provided, after copying the required source code and dynamically creating the
|
||||
# Dockerfile, we actually build the Docker image
|
||||
logger.info('Building image in a temporary folder')
|
||||
docker_builder = DockerRuntimeBuilder(docker.from_env())
|
||||
image_name = build_runtime_image(args.base_image, docker_builder)
|
||||
print(f'\nBUILT Image: {image_name}\n')
|
||||
68
openhands/runtime/utils/runtime_templates/Dockerfile.j2
Normal file
68
openhands/runtime/utils/runtime_templates/Dockerfile.j2
Normal file
@@ -0,0 +1,68 @@
|
||||
{% if skip_init %}
|
||||
FROM {{ base_image }}
|
||||
{% else %}
|
||||
# ================================================================
|
||||
# START: Build Runtime Image from Scratch
|
||||
# ================================================================
|
||||
FROM {{ base_image }}
|
||||
|
||||
{% if 'ubuntu' in base_image and (base_image.endswith(':latest') or base_image.endswith(':24.04')) %}
|
||||
{% set LIBGL_MESA = 'libgl1' %}
|
||||
{% else %}
|
||||
{% set LIBGL_MESA = 'libgl1-mesa-glx' %}
|
||||
{% endif %}
|
||||
|
||||
# Install necessary packages and clean up in one layer
|
||||
RUN apt-get update && \
|
||||
apt-get install -y wget sudo apt-utils {{ LIBGL_MESA }} libasound2-plugins git && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create necessary directories
|
||||
RUN mkdir -p /openhands && \
|
||||
mkdir -p /openhands/logs && \
|
||||
mkdir -p /openhands/poetry
|
||||
|
||||
ENV POETRY_VIRTUALENVS_PATH=/openhands/poetry
|
||||
|
||||
RUN if [ ! -d /openhands/miniforge3 ]; then \
|
||||
wget --progress=bar:force -O Miniforge3.sh "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh" && \
|
||||
bash Miniforge3.sh -b -p /openhands/miniforge3 && \
|
||||
rm Miniforge3.sh && \
|
||||
chmod -R g+w /openhands/miniforge3 && \
|
||||
bash -c ". /openhands/miniforge3/etc/profile.d/conda.sh && conda config --set changeps1 False && conda config --append channels conda-forge"; \
|
||||
fi
|
||||
|
||||
# Install Python and Poetry
|
||||
RUN /openhands/miniforge3/bin/mamba install conda-forge::poetry python=3.11 -y
|
||||
# ================================================================
|
||||
# END: Build Runtime Image from Scratch
|
||||
# ================================================================
|
||||
{% endif %}
|
||||
|
||||
# ================================================================
|
||||
# START: Copy Project and Install/Update Dependencies
|
||||
# ================================================================
|
||||
RUN if [ -d /openhands/code ]; then rm -rf /openhands/code; fi
|
||||
COPY ./code /openhands/code
|
||||
|
||||
# Install/Update Dependencies
|
||||
# 1. Install pyproject.toml via poetry
|
||||
# 2. Install playwright and chromium
|
||||
# 3. Clear poetry, apt, mamba caches
|
||||
RUN cd /openhands/code && \
|
||||
/openhands/miniforge3/bin/mamba run -n base poetry env use python3.11 && \
|
||||
/openhands/miniforge3/bin/mamba run -n base poetry install --only main,runtime --no-interaction --no-root && \
|
||||
apt-get update && \
|
||||
/openhands/miniforge3/bin/mamba run -n base poetry run pip install playwright && \
|
||||
/openhands/miniforge3/bin/mamba run -n base poetry run playwright install --with-deps chromium && \
|
||||
export OD_INTERPRETER_PATH=$(/openhands/miniforge3/bin/mamba run -n base poetry run python -c "import sys; print(sys.executable)") && \
|
||||
{{ extra_deps }} {% if extra_deps %} && {% endif %} \
|
||||
/openhands/miniforge3/bin/mamba run -n base poetry cache clear --all . && \
|
||||
{% if not skip_init %}chmod -R g+rws /openhands/poetry && {% endif %} \
|
||||
apt-get clean && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* && \
|
||||
/openhands/miniforge3/bin/mamba clean --all
|
||||
|
||||
# ================================================================
|
||||
# END: Copy Project and Install/Update Dependencies
|
||||
# ================================================================
|
||||
0
openhands/runtime/utils/singleton.py
Normal file
0
openhands/runtime/utils/singleton.py
Normal file
14
openhands/runtime/utils/system.py
Normal file
14
openhands/runtime/utils/system.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import socket
|
||||
|
||||
|
||||
def find_available_tcp_port() -> int:
|
||||
"""Find an available TCP port, return -1 if none available."""
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
try:
|
||||
sock.bind(('localhost', 0))
|
||||
port = sock.getsockname()[1]
|
||||
return port
|
||||
except Exception:
|
||||
return -1
|
||||
finally:
|
||||
sock.close()
|
||||
73
openhands/security/README.md
Normal file
73
openhands/security/README.md
Normal file
@@ -0,0 +1,73 @@
|
||||
# Security
|
||||
|
||||
Given the impressive capabilities of OpenHands and similar coding agents, ensuring robust security measures is essential to prevent unintended actions or security breaches. The SecurityAnalyzer framework provides a structured approach to monitor and analyze agent actions for potential security risks.
|
||||
|
||||
To enable this feature:
|
||||
* From the web interface
|
||||
* Open Configuration (by clicking the gear icon in the bottom right)
|
||||
* Select a Security Analyzer from the dropdown
|
||||
* Save settings
|
||||
* (to disable) repeat the same steps, but click the X in the Security Analyzer dropdown
|
||||
* From config.toml
|
||||
```toml
|
||||
[security]
|
||||
# Enable confirmation mode
|
||||
confirmation_mode = true
|
||||
# The security analyzer to use
|
||||
security_analyzer = "your-security-analyzer"
|
||||
```
|
||||
(to disable) remove the lines from config.toml
|
||||
|
||||
## SecurityAnalyzer Base Class
|
||||
|
||||
The `SecurityAnalyzer` class (analyzer.py) is an abstract base class designed to listen to an event stream and analyze actions for security risks and eventually act before the action is executed. Below is a detailed explanation of its components and methods:
|
||||
|
||||
### Initialization
|
||||
|
||||
- **event_stream**: An instance of `EventStream` that the analyzer will listen to for events.
|
||||
|
||||
### Event Handling
|
||||
|
||||
- **on_event(event: Event)**: Handles incoming events. If the event is an `Action`, it evaluates its security risk and acts upon it.
|
||||
|
||||
### Abstract Methods
|
||||
|
||||
- **handle_api_request(request: Request)**: Abstract method to handle API requests.
|
||||
- **log_event(event: Event)**: Logs events.
|
||||
- **act(event: Event)**: Defines actions to take based on the analyzed event.
|
||||
- **security_risk(event: Action)**: Evaluates the security risk of an action and returns the risk level.
|
||||
- **close()**: Cleanups resources used by the security analyzer.
|
||||
|
||||
In conclusion, a concrete security analyzer should evaluate the risk of each event and act accordingly (e.g. auto-confirm, send Slack message, etc).
|
||||
|
||||
For customization and decoupling from the OpenHands core logic, the security analyzer can define its own API endpoints that can then be accessed from the frontend. These API endpoints need to be secured (do not allow more capabilities than the core logic
|
||||
provides).
|
||||
|
||||
## How to implement your own Security Analyzer
|
||||
|
||||
1. Create a submodule in [security](/openhands/security/) with your analyzer's desired name
|
||||
* Have your main class inherit from [SecurityAnalyzer](/openhands/security/analyzer.py)
|
||||
* Optional: define API endpoints for `/api/security/{path:path}` to manage settings,
|
||||
2. Add your analyzer class to the [options](/openhands/security/options.py) to have it be visible from the frontend combobox
|
||||
3. Optional: implement your modal frontend (for when you click on the lock) in [security](/frontend/src/components/modals/security/) and add your component to [Security.tsx](/frontend/src/components/modals/security/Security.tsx)
|
||||
|
||||
## Implemented Security Analyzers
|
||||
|
||||
### Invariant
|
||||
|
||||
It uses the [Invariant Analyzer](https://github.com/invariantlabs-ai/invariant) to analyze traces and detect potential issues with OpenHands's workflow. It uses confirmation mode to ask for user confirmation on potentially risky actions.
|
||||
|
||||
This allows the agent to run autonomously without fear that it will inadvertently compromise security or perform unintended actions that could be harmful.
|
||||
|
||||
Features:
|
||||
|
||||
* Detects:
|
||||
* potential secret leaks by the agent
|
||||
* security issues in Python code
|
||||
* malicious bash commands
|
||||
* Logs:
|
||||
* actions and their associated risk
|
||||
* OpenHands traces in JSON format
|
||||
* Run-time settings:
|
||||
* the [invariant policy](https://github.com/invariantlabs-ai/invariant?tab=readme-ov-file#policy-language)
|
||||
* acceptable risk threshold
|
||||
7
openhands/security/__init__.py
Normal file
7
openhands/security/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .analyzer import SecurityAnalyzer
|
||||
from .invariant.analyzer import InvariantAnalyzer
|
||||
|
||||
__all__ = [
|
||||
'SecurityAnalyzer',
|
||||
'InvariantAnalyzer',
|
||||
]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user