Rename OpenDevin to OpenHands (#3472)

* Replace OpenDevin with OpenHands

* Update CONTRIBUTING.md

* Update README.md

* Update README.md

* update poetry lock; move opendevin folder to openhands

* fix env var

* revert image references in docs

* revert permissions

* revert permissions

---------

Co-authored-by: Xingyao Wang <xingyao6@illinois.edu>
Committed by Robert Brennan on 2024-08-19 12:44:54 -04:00 via GitHub
parent 83f36c1d66
commit 01ae22ef57
387 changed files with 1832 additions and 1824 deletions

openhands/README.md Normal file

@@ -0,0 +1,52 @@
# OpenHands Architecture
This directory contains the core components of OpenHands.
This diagram provides an overview of the roles of each component and how they communicate and collaborate.
![OpenHands System Architecture Diagram (July 4, 2024)](../docs/static/img/system_architecture_overview.png)
## Classes
The key classes in OpenHands are:
* LLM: brokers all interactions with large language models. Works with any underlying completion model, thanks to LiteLLM.
* Agent: responsible for looking at the current State, and producing an Action that moves one step closer toward the end-goal.
* AgentController: initializes the Agent, manages State, and drives the main loop that pushes the Agent forward, step by step
* State: represents the current state of the Agent's task. Includes things like the current step, a history of recent events, the Agent's long-term plan, etc
* EventStream: a central hub for Events, where any component can publish Events, or listen for Events published by other components
* Event: an Action or an Observation
* Action: represents a request to e.g. edit a file, run a command, or send a message
* Observation: represents information collected from the environment, e.g. file contents or command output
* Runtime: responsible for performing Actions, and sending back Observations
* Sandbox: the part of the runtime responsible for running commands, e.g. inside of Docker
* Server: brokers OpenHands sessions over HTTP, e.g. to drive the frontend
* Session: holds a single EventStream, a single AgentController, and a single Runtime. Generally represents a single task (but potentially including several user prompts)
* SessionManager: keeps a list of active sessions, and ensures requests are routed to the correct Session
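As a rough illustration of how these classes relate, here is a minimal sketch in Python. These are not the actual OpenHands definitions (the real classes have more fields and live in separate modules); the names and fields below are simplified stand-ins.

```python
from dataclasses import dataclass, field

# Simplified stand-ins for the core types described above; the real
# OpenHands classes differ in fields and structure.

@dataclass
class Event:
    """Base class for anything published on the EventStream."""
    id: int = -1

@dataclass
class Action(Event):
    """A request to do something, e.g. run a command."""
    command: str = ''

@dataclass
class Observation(Event):
    """Information collected from the environment, e.g. command output."""
    content: str = ''

@dataclass
class State:
    """Tracks the Agent's progress through a task."""
    iteration: int = 0
    history: list = field(default_factory=list)

    def update(self, action: Action, observation: Observation) -> 'State':
        # Record the (Action, Observation) pair and advance one step.
        self.history.append((action, observation))
        self.iteration += 1
        return self
```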
## Control Flow
Here's the basic loop (in pseudocode) that drives agents.
```python
while True:
prompt = agent.generate_prompt(state)
response = llm.completion(prompt)
action = agent.parse_response(response)
observation = runtime.run(action)
state = state.update(action, observation)
```
In reality, most of this is achieved through message passing, via the EventStream.
The EventStream serves as the backbone for all communication in OpenHands.
```mermaid
flowchart LR
Agent--Actions-->AgentController
AgentController--State-->Agent
AgentController--Actions-->EventStream
EventStream--Observations-->AgentController
Runtime--Observations-->EventStream
EventStream--Actions-->Runtime
Frontend--Actions-->EventStream
```
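The pub/sub pattern in the diagram above can be sketched as follows. This is an illustrative toy hub, not the real EventStream implementation in `openhands/events`, whose API and delivery semantics differ.

```python
from collections import defaultdict
from typing import Callable

# Minimal illustrative pub/sub hub in the spirit of EventStream.
class EventStream:
    def __init__(self):
        self._subscribers: dict[str, list[Callable]] = defaultdict(list)

    def subscribe(self, subscriber_id: str, callback: Callable) -> None:
        self._subscribers[subscriber_id].append(callback)

    def add_event(self, event, source: str) -> None:
        # Deliver the event to every registered subscriber.
        for callbacks in self._subscribers.values():
            for callback in callbacks:
                callback(event, source)

stream = EventStream()
received = []
stream.subscribe('agent_controller', lambda e, src: received.append((e, src)))
stream.add_event({'action': 'run', 'command': 'ls'}, 'AGENT')
```

Any component (Agent, Runtime, frontend) can publish, and any component can listen, which is what decouples the boxes in the flowchart from one another.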
## Runtime
Please refer to the [documentation](https://docs.all-hands.dev/modules/usage/runtime) to learn more about `Runtime`.

openhands/__init__.py Normal file

@@ -0,0 +1,5 @@
from .agent_controller import AgentController
__all__ = [
'AgentController',
]


@@ -0,0 +1,67 @@
from abc import ABC, abstractmethod
from openhands.events.action import Action
class ResponseParser(ABC):
"""This abstract base class is a general interface for an response parser dedicated to
parsing the action from the response from the LLM.
"""
def __init__(
self,
):
# Note: the order of parsers in self.action_parsers matters.
self.action_parsers = []
@abstractmethod
def parse(self, response: str) -> Action:
"""Parses the action from the response from the LLM.
Parameters:
- response (str): The response from the LLM.
Returns:
- action (Action): The action parsed from the response.
"""
pass
@abstractmethod
def parse_response(self, response) -> str:
"""Parses the action from the response from the LLM.
Parameters:
- response (str): The response from the LLM.
Returns:
- action_str (str): The action str parsed from the response.
"""
pass
@abstractmethod
def parse_action(self, action_str: str) -> Action:
"""Parses the action from the response from the LLM.
Parameters:
- action_str (str): The response from the LLM.
Returns:
- action (Action): The action parsed from the response.
"""
pass
class ActionParser(ABC):
"""This abstract base class is a general interface for an action parser dedicated to
parsing the action from the action str from the LLM.
"""
@abstractmethod
def check_condition(self, action_str: str) -> bool:
"""Check if the action string can be parsed by this parser."""
pass
@abstractmethod
def parse(self, action_str: str) -> Action:
"""Parses the action from the action string from the LLM response."""
pass
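To make the `ActionParser` interface concrete, here is a hypothetical parser for an `<execute_bash>`-style block. This is a self-contained sketch with stand-in `Action` classes; the real parsers (and the real `CmdRunAction`) live elsewhere in the codebase and differ in detail.

```python
import re

# Hypothetical stand-ins for the real Action classes.
class Action: ...

class CmdRunAction(Action):
    def __init__(self, command: str):
        self.command = command

class BashActionParser:
    """Illustrative ActionParser: matches <execute_bash>...</execute_bash>."""

    PATTERN = re.compile(r'<execute_bash>(.*?)</execute_bash>', re.DOTALL)

    def check_condition(self, action_str: str) -> bool:
        # Only claim strings that contain an execute_bash block.
        return self.PATTERN.search(action_str) is not None

    def parse(self, action_str: str) -> Action:
        match = self.PATTERN.search(action_str)
        assert match is not None
        return CmdRunAction(command=match.group(1).strip())

parser = BashActionParser()
action = parser.parse('Sure! <execute_bash>ls -la</execute_bash>')
```

A `ResponseParser` would typically walk its ordered `action_parsers` list, calling `check_condition` on each until one claims the string, then delegate to that parser's `parse`.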


@@ -0,0 +1,109 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Type
if TYPE_CHECKING:
from openhands.controller.state.state import State
from openhands.core.config import AgentConfig
from openhands.events.action import Action
from openhands.core.exceptions import (
AgentAlreadyRegisteredError,
AgentNotRegisteredError,
)
from openhands.llm.llm import LLM
from openhands.runtime.plugins import PluginRequirement
class Agent(ABC):
DEPRECATED = False
"""
This abstract base class is a general interface for an agent dedicated to
executing a specific instruction and allowing human interaction with the
agent during execution.
It tracks the execution status and maintains a history of interactions.
"""
_registry: dict[str, Type['Agent']] = {}
sandbox_plugins: list[PluginRequirement] = []
def __init__(
self,
llm: LLM,
config: 'AgentConfig',
):
self.llm = llm
self.config = config
self._complete = False
@property
def complete(self) -> bool:
"""Indicates whether the current instruction execution is complete.
Returns:
- complete (bool): True if execution is complete; False otherwise.
"""
return self._complete
@abstractmethod
def step(self, state: 'State') -> 'Action':
"""Starts the execution of the assigned instruction. This method should
be implemented by subclasses to define the specific execution logic.
"""
pass
def reset(self) -> None:
"""Resets the agent's execution status and clears the history. This method can be used
to prepare the agent for restarting the instruction or cleaning up before destruction.
"""
# TODO clear history
self._complete = False
if self.llm:
self.llm.reset()
@property
def name(self):
return self.__class__.__name__
@classmethod
def register(cls, name: str, agent_cls: Type['Agent']):
"""Registers an agent class in the registry.
Parameters:
- name (str): The name to register the class under.
- agent_cls (Type['Agent']): The class to register.
Raises:
- AgentAlreadyRegisteredError: If name already registered
"""
if name in cls._registry:
raise AgentAlreadyRegisteredError(name)
cls._registry[name] = agent_cls
@classmethod
def get_cls(cls, name: str) -> Type['Agent']:
"""Retrieves an agent class from the registry.
Parameters:
- name (str): The name of the class to retrieve
Returns:
- agent_cls (Type['Agent']): The class registered under the specified name.
Raises:
- AgentNotRegisteredError: If name not registered
"""
if name not in cls._registry:
raise AgentNotRegisteredError(name)
return cls._registry[name]
@classmethod
def list_agents(cls) -> list[str]:
"""Retrieves the list of all agent names from the registry.
Raises:
- AgentNotRegisteredError: If no agent is registered
"""
if not cls._registry:
raise AgentNotRegisteredError()
return list(cls._registry.keys())
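The registry pattern above can be exercised roughly as follows. This is a self-contained sketch that reimplements just the registry methods with a dummy subclass; the real `Agent` also requires an `LLM` and an `AgentConfig` to instantiate.

```python
# Self-contained sketch of the registry pattern used by Agent.register /
# Agent.get_cls; class names here are illustrative.

class AgentAlreadyRegisteredError(Exception): ...
class AgentNotRegisteredError(Exception): ...

class Agent:
    _registry: dict[str, type] = {}

    @classmethod
    def register(cls, name: str, agent_cls: type) -> None:
        if name in cls._registry:
            raise AgentAlreadyRegisteredError(name)
        cls._registry[name] = agent_cls

    @classmethod
    def get_cls(cls, name: str) -> type:
        if name not in cls._registry:
            raise AgentNotRegisteredError(name)
        return cls._registry[name]

class DummyAgent(Agent):
    pass

# Register under a name, then look the class back up by that name.
Agent.register('DummyAgent', DummyAgent)
agent_cls = Agent.get_cls('DummyAgent')
```

This is what lets `AgentDelegateAction` name a delegate agent as a plain string: the controller resolves the string to a class via `Agent.get_cls`.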


@@ -0,0 +1,545 @@
import asyncio
import traceback
from typing import Type
from openhands.controller.agent import Agent
from openhands.controller.state.state import State, TrafficControlState
from openhands.controller.stuck import StuckDetector
from openhands.core.config import AgentConfig, LLMConfig
from openhands.core.exceptions import (
LLMMalformedActionError,
LLMNoActionError,
LLMResponseError,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema import AgentState
from openhands.events import EventSource, EventStream, EventStreamSubscriber
from openhands.events.action import (
Action,
ActionConfirmationStatus,
AddTaskAction,
AgentDelegateAction,
AgentFinishAction,
AgentRejectAction,
ChangeAgentStateAction,
CmdRunAction,
IPythonRunCellAction,
MessageAction,
ModifyTaskAction,
NullAction,
)
from openhands.events.event import Event
from openhands.events.observation import (
AgentDelegateObservation,
AgentStateChangedObservation,
CmdOutputObservation,
ErrorObservation,
Observation,
)
from openhands.llm.llm import LLM
# note: RESUME is only available on web GUI
TRAFFIC_CONTROL_REMINDER = (
"Please click on resume button if you'd like to continue, or start a new task."
)
class AgentController:
id: str
agent: Agent
max_iterations: int
event_stream: EventStream
state: State
confirmation_mode: bool
agent_to_llm_config: dict[str, LLMConfig]
agent_configs: dict[str, AgentConfig]
agent_task: asyncio.Task | None = None
parent: 'AgentController | None' = None
delegate: 'AgentController | None' = None
_pending_action: Action | None = None
def __init__(
self,
agent: Agent,
event_stream: EventStream,
max_iterations: int,
max_budget_per_task: float | None = None,
agent_to_llm_config: dict[str, LLMConfig] | None = None,
agent_configs: dict[str, AgentConfig] | None = None,
sid: str = 'default',
confirmation_mode: bool = False,
initial_state: State | None = None,
is_delegate: bool = False,
headless_mode: bool = True,
):
"""Initializes a new instance of the AgentController class.
Args:
agent: The agent instance to control.
event_stream: The event stream to publish events to.
max_iterations: The maximum number of iterations the agent can run.
max_budget_per_task: The maximum budget (in USD) allowed per task, beyond which the agent will stop.
agent_to_llm_config: A dictionary mapping agent names to LLM configurations in the case that
we delegate to a different agent.
agent_configs: A dictionary mapping agent names to agent configurations in the case that
we delegate to a different agent.
sid: The session ID of the agent.
initial_state: The initial state of the controller.
is_delegate: Whether this controller is a delegate.
headless_mode: Whether the agent is run in headless mode.
"""
self._step_lock = asyncio.Lock()
self.id = sid
self.agent = agent
self.headless_mode = headless_mode
# subscribe to the event stream
self.event_stream = event_stream
self.event_stream.subscribe(
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, append=is_delegate
)
# state from the previous session, state from a parent agent, or a fresh state
self.set_initial_state(
state=initial_state,
max_iterations=max_iterations,
confirmation_mode=confirmation_mode,
)
self.max_budget_per_task = max_budget_per_task
self.agent_to_llm_config = agent_to_llm_config if agent_to_llm_config else {}
self.agent_configs = agent_configs if agent_configs else {}
# stuck helper
self._stuck_detector = StuckDetector(self.state)
if not is_delegate:
self.agent_task = asyncio.create_task(self._start_step_loop())
async def close(self):
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream."""
if self.agent_task is not None:
self.agent_task.cancel()
await self.set_agent_state_to(AgentState.STOPPED)
self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER)
def update_state_before_step(self):
self.state.iteration += 1
self.state.local_iteration += 1
async def update_state_after_step(self):
# update metrics especially for cost
self.state.local_metrics = self.agent.llm.metrics
async def report_error(self, message: str, exception: Exception | None = None):
"""Reports an error to the user and sends the exception to the LLM next step, in the hope it can self-correct.
This method should be called for errors that have:
- a user-friendly message, which will be shown in the chat box. This should not be a raw exception message.
- an ErrorObservation that can be sent to the LLM by the agent, with the exception message, so it can self-correct next time.
"""
self.state.last_error = message
if exception:
self.state.last_error += f': {exception}'
self.event_stream.add_event(ErrorObservation(message), EventSource.AGENT)
async def _start_step_loop(self):
"""The main loop for the agent's step-by-step execution."""
logger.info(f'[Agent Controller {self.id}] Starting step loop...')
while True:
try:
await self._step()
except asyncio.CancelledError:
logger.info('AgentController task was cancelled')
break
except Exception as e:
traceback.print_exc()
logger.error(f'Error while running the agent: {e}')
logger.error(traceback.format_exc())
await self.report_error(
'There was an unexpected error while running the agent', exception=e
)
await self.set_agent_state_to(AgentState.ERROR)
break
await asyncio.sleep(0.1)
async def on_event(self, event: Event):
"""Callback from the event stream. Notifies the controller of incoming events.
Args:
event (Event): The incoming event to process.
"""
if isinstance(event, ChangeAgentStateAction):
await self.set_agent_state_to(event.agent_state) # type: ignore
elif isinstance(event, MessageAction):
if event.source == EventSource.USER:
logger.info(
event,
extra={'msg_type': 'ACTION', 'event_source': EventSource.USER},
)
if self.get_agent_state() != AgentState.RUNNING:
await self.set_agent_state_to(AgentState.RUNNING)
elif event.source == EventSource.AGENT and event.wait_for_response:
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
elif isinstance(event, AgentDelegateAction):
await self.start_delegate(event)
elif isinstance(event, AddTaskAction):
self.state.root_task.add_subtask(event.parent, event.goal, event.subtasks)
elif isinstance(event, ModifyTaskAction):
self.state.root_task.set_subtask_state(event.task_id, event.state)
elif isinstance(event, AgentFinishAction):
self.state.outputs = event.outputs
self.state.metrics.merge(self.state.local_metrics)
await self.set_agent_state_to(AgentState.FINISHED)
elif isinstance(event, AgentRejectAction):
self.state.outputs = event.outputs
self.state.metrics.merge(self.state.local_metrics)
await self.set_agent_state_to(AgentState.REJECTED)
elif isinstance(event, Observation):
if (
self._pending_action
and hasattr(self._pending_action, 'is_confirmed')
and self._pending_action.is_confirmed
== ActionConfirmationStatus.AWAITING_CONFIRMATION
):
return
if self._pending_action and self._pending_action.id == event.cause:
self._pending_action = None
if self.state.agent_state == AgentState.USER_CONFIRMED:
await self.set_agent_state_to(AgentState.RUNNING)
if self.state.agent_state == AgentState.USER_REJECTED:
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
logger.info(event, extra={'msg_type': 'OBSERVATION'})
elif isinstance(event, CmdOutputObservation):
logger.info(event, extra={'msg_type': 'OBSERVATION'})
elif isinstance(event, AgentDelegateObservation):
self.state.history.on_event(event)
logger.info(event, extra={'msg_type': 'OBSERVATION'})
elif isinstance(event, ErrorObservation):
logger.info(event, extra={'msg_type': 'OBSERVATION'})
def reset_task(self):
"""Resets the agent's task."""
self.almost_stuck = 0
self.agent.reset()
async def set_agent_state_to(self, new_state: AgentState):
"""Updates the agent's state and handles side effects. Can emit events to the event stream.
Args:
new_state (AgentState): The new state to set for the agent.
"""
logger.debug(
f'[Agent Controller {self.id}] Setting agent({self.agent.name}) state from {self.state.agent_state} to {new_state}'
)
if new_state == self.state.agent_state:
return
if (
self.state.agent_state == AgentState.PAUSED
and new_state == AgentState.RUNNING
and self.state.traffic_control_state == TrafficControlState.THROTTLING
):
# user intends to interrupt traffic control and let the task resume temporarily
self.state.traffic_control_state = TrafficControlState.PAUSED
self.state.agent_state = new_state
if new_state == AgentState.STOPPED or new_state == AgentState.ERROR:
self.reset_task()
if self._pending_action is not None and (
new_state == AgentState.USER_CONFIRMED
or new_state == AgentState.USER_REJECTED
):
if hasattr(self._pending_action, 'thought'):
self._pending_action.thought = '' # type: ignore[union-attr]
if new_state == AgentState.USER_CONFIRMED:
self._pending_action.is_confirmed = ActionConfirmationStatus.CONFIRMED # type: ignore[attr-defined]
else:
self._pending_action.is_confirmed = ActionConfirmationStatus.REJECTED # type: ignore[attr-defined]
self.event_stream.add_event(self._pending_action, EventSource.AGENT)
self.event_stream.add_event(
AgentStateChangedObservation('', self.state.agent_state), EventSource.AGENT
)
if new_state == AgentState.INIT and self.state.resume_state:
await self.set_agent_state_to(self.state.resume_state)
self.state.resume_state = None
def get_agent_state(self):
"""Returns the current state of the agent.
Returns:
AgentState: The current state of the agent.
"""
return self.state.agent_state
async def start_delegate(self, action: AgentDelegateAction):
"""Start a delegate agent to handle a subtask.
OpenHands is a multi-agentic system. A `task` is a conversation between
OpenHands (the whole system) and the user, which might involve one or more inputs
from the user. It starts with an initial input (typically a task statement) from
the user, and ends with either an `AgentFinishAction` initiated by the agent, a
stop initiated by the user, or an error.
A `subtask` is a conversation between an agent and the user, or another agent. If a `task`
is conducted by a single agent, then it's also a `subtask`. Otherwise, a `task` consists of
multiple `subtasks`, each executed by one agent.
Args:
action (AgentDelegateAction): The action containing information about the delegate agent to start.
"""
agent_cls: Type[Agent] = Agent.get_cls(action.agent)
agent_config = self.agent_configs.get(action.agent, self.agent.config)
llm_config = self.agent_to_llm_config.get(action.agent, self.agent.llm.config)
llm = LLM(config=llm_config)
delegate_agent = agent_cls(llm=llm, config=agent_config)
state = State(
inputs=action.inputs or {},
local_iteration=0,
iteration=self.state.iteration,
max_iterations=self.state.max_iterations,
delegate_level=self.state.delegate_level + 1,
# global metrics should be shared between parent and child
metrics=self.state.metrics,
)
logger.info(
f'[Agent Controller {self.id}]: start delegate, creating agent {delegate_agent.name} using LLM {llm}'
)
self.delegate = AgentController(
sid=self.id + '-delegate',
agent=delegate_agent,
event_stream=self.event_stream,
max_iterations=self.state.max_iterations,
max_budget_per_task=self.max_budget_per_task,
agent_to_llm_config=self.agent_to_llm_config,
agent_configs=self.agent_configs,
initial_state=state,
is_delegate=True,
headless_mode=self.headless_mode,
)
await self.delegate.set_agent_state_to(AgentState.RUNNING)
async def _step(self) -> None:
"""Executes a single step of the parent or delegate agent. Detects stuck agents and limits on the number of iterations and the task budget."""
if self.get_agent_state() != AgentState.RUNNING:
await asyncio.sleep(1)
return
if self._pending_action:
logger.debug(
f'[Agent Controller {self.id}] waiting for pending action: {self._pending_action}'
)
await asyncio.sleep(1)
return
if self.delegate is not None:
logger.debug(f'[Agent Controller {self.id}] Delegate not none, awaiting...')
assert self.delegate != self
await self.delegate._step()
logger.debug(f'[Agent Controller {self.id}] Delegate step done')
assert self.delegate is not None
delegate_state = self.delegate.get_agent_state()
logger.debug(
f'[Agent Controller {self.id}] Delegate state: {delegate_state}'
)
if delegate_state == AgentState.ERROR:
# close the delegate upon error
await self.delegate.close()
self.delegate = None
self.delegateAction = None
await self.report_error('Delegate agent encountered an error')
return
delegate_done = delegate_state in (AgentState.FINISHED, AgentState.REJECTED)
if delegate_done:
logger.info(
f'[Agent Controller {self.id}] Delegate agent has finished execution'
)
# retrieve delegate result
outputs = self.delegate.state.outputs if self.delegate.state else {}
# update iteration that shall be shared across agents
self.state.iteration = self.delegate.state.iteration
# close delegate controller: we must close the delegate controller before adding new events
await self.delegate.close()
# update delegate result observation
# TODO: replace this with AI-generated summary (#2395)
formatted_output = ', '.join(
f'{key}: {value}' for key, value in outputs.items()
)
content = (
f'{self.delegate.agent.name} finishes task with {formatted_output}'
)
obs: Observation = AgentDelegateObservation(
outputs=outputs, content=content
)
# clean up delegate status
self.delegate = None
self.delegateAction = None
self.event_stream.add_event(obs, EventSource.AGENT)
return
logger.info(
f'{self.agent.name} LEVEL {self.state.delegate_level} LOCAL STEP {self.state.local_iteration} GLOBAL STEP {self.state.iteration}',
extra={'msg_type': 'STEP'},
)
if self.state.iteration >= self.state.max_iterations:
if self.state.traffic_control_state == TrafficControlState.PAUSED:
logger.info(
'Hitting traffic control, temporarily resume upon user request'
)
self.state.traffic_control_state = TrafficControlState.NORMAL
else:
self.state.traffic_control_state = TrafficControlState.THROTTLING
if self.headless_mode:
# set to ERROR state if running in headless mode
# since user cannot resume on the web interface
await self.report_error(
'Agent reached maximum number of iterations in headless mode, task stopped.'
)
await self.set_agent_state_to(AgentState.ERROR)
else:
await self.report_error(
f'Agent reached maximum number of iterations, task paused. {TRAFFIC_CONTROL_REMINDER}'
)
await self.set_agent_state_to(AgentState.PAUSED)
return
elif self.max_budget_per_task is not None:
current_cost = self.state.metrics.accumulated_cost
if current_cost > self.max_budget_per_task:
if self.state.traffic_control_state == TrafficControlState.PAUSED:
logger.info(
'Hitting traffic control, temporarily resume upon user request'
)
self.state.traffic_control_state = TrafficControlState.NORMAL
else:
self.state.traffic_control_state = TrafficControlState.THROTTLING
if self.headless_mode:
# set to ERROR state if running in headless mode
# there is no way to resume
await self.report_error(
f'Task budget exceeded. Current cost: {current_cost:.2f}, max budget: {self.max_budget_per_task:.2f}, task stopped.'
)
await self.set_agent_state_to(AgentState.ERROR)
else:
await self.report_error(
f'Task budget exceeded. Current cost: {current_cost:.2f}, Max budget: {self.max_budget_per_task:.2f}, task paused. {TRAFFIC_CONTROL_REMINDER}'
)
await self.set_agent_state_to(AgentState.PAUSED)
return
self.update_state_before_step()
action: Action = NullAction()
try:
action = self.agent.step(self.state)
if action is None:
raise LLMNoActionError('No action was returned')
except (LLMMalformedActionError, LLMNoActionError, LLMResponseError) as e:
# report to the user
# and send the underlying exception to the LLM for self-correction
await self.report_error(str(e))
return
if action.runnable:
if self.state.confirmation_mode and (
type(action) is CmdRunAction or type(action) is IPythonRunCellAction
):
action.is_confirmed = ActionConfirmationStatus.AWAITING_CONFIRMATION
self._pending_action = action
if not isinstance(action, NullAction):
if (
hasattr(action, 'is_confirmed')
and action.is_confirmed
== ActionConfirmationStatus.AWAITING_CONFIRMATION
):
await self.set_agent_state_to(AgentState.AWAITING_USER_CONFIRMATION)
self.event_stream.add_event(action, EventSource.AGENT)
await self.update_state_after_step()
logger.info(action, extra={'msg_type': 'ACTION'})
if self._is_stuck():
await self.report_error('Agent got stuck in a loop')
await self.set_agent_state_to(AgentState.ERROR)
def get_state(self):
"""Returns the current running state object.
Returns:
State: The current state object.
"""
return self.state
def set_initial_state(
self,
state: State | None,
max_iterations: int,
confirmation_mode: bool = False,
):
"""Sets the initial state for the agent, either from the previous session, or from a parent agent, or by creating a new one.
Args:
state: The state to initialize with, or None to create a new state.
max_iterations: The maximum number of iterations allowed for the task.
confirmation_mode: Whether to enable confirmation mode.
"""
# state from the previous session, state from a parent agent, or a new state
# note that this is called twice when restoring a previous session, first with state=None
if state is None:
self.state = State(
inputs={},
max_iterations=max_iterations,
confirmation_mode=confirmation_mode,
)
else:
self.state = state
# when restored from a previous session, the State object will have history, start_id, and end_id
# connect it to the event stream
self.state.history.set_event_stream(self.event_stream)
# if start_id was not set in State, we're starting fresh, at the top of the stream
start_id = self.state.start_id
if start_id == -1:
start_id = self.event_stream.get_latest_event_id() + 1
else:
logger.debug(f'AgentController {self.id} restoring from event {start_id}')
# make sure history is in sync
self.state.start_id = start_id
self.state.history.start_id = start_id
# if there was an end_id saved in State, set it in history
# currently not used, later useful for delegates
if self.state.end_id > -1:
self.state.history.end_id = self.state.end_id
def _is_stuck(self):
"""Checks if the agent or its delegate is stuck in a loop.
Returns:
bool: True if the agent is stuck, False otherwise.
"""
# check if delegate stuck
if self.delegate and self.delegate._is_stuck():
return True
return self._stuck_detector.is_stuck()
def __repr__(self):
return (
f'AgentController(id={self.id}, agent={self.agent!r}, '
f'event_stream={self.event_stream!r}, '
f'state={self.state!r}, agent_task={self.agent_task!r}, '
f'delegate={self.delegate!r}, _pending_action={self._pending_action!r})'
)


@@ -0,0 +1,172 @@
import base64
import pickle
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
from openhands.controller.state.task import RootTask
from openhands.core.logger import openhands_logger as logger
from openhands.core.metrics import Metrics
from openhands.core.schema import AgentState
from openhands.events.action import (
MessageAction,
)
from openhands.events.action.agent import AgentFinishAction
from openhands.memory.history import ShortTermHistory
from openhands.storage.files import FileStore
class TrafficControlState(str, Enum):
# default state, no rate limiting
NORMAL = 'normal'
# task paused due to traffic control
THROTTLING = 'throttling'
# traffic control is temporarily paused
PAUSED = 'paused'
RESUMABLE_STATES = [
AgentState.RUNNING,
AgentState.PAUSED,
AgentState.AWAITING_USER_INPUT,
AgentState.FINISHED,
]
@dataclass
class State:
"""
Represents the running state of an agent in the OpenHands system, saving data of its operation and memory.
- Multi-agent/delegate state:
- store the task (conversation between the agent and the user)
- the subtask (conversation between an agent and the user or another agent)
- global and local iterations
- delegate levels for multi-agent interactions
- almost stuck state
- Running state of an agent:
- current agent state (e.g., LOADING, RUNNING, PAUSED)
- traffic control state for rate limiting
- confirmation mode
- the last error encountered
- Data for saving and restoring the agent:
- save to and restore from a session
- serialize with pickle and base64
- Save / restore data about message history
- start and end IDs for events in agent's history
- summaries and delegate summaries
- Metrics:
- global metrics for the current task
- local metrics for the current subtask
- Extra data:
- additional task-specific data
"""
root_task: RootTask = field(default_factory=RootTask)
# global iteration for the current task
iteration: int = 0
# local iteration for the current subtask
local_iteration: int = 0
# max number of iterations for the current task
max_iterations: int = 100
confirmation_mode: bool = False
history: ShortTermHistory = field(default_factory=ShortTermHistory)
inputs: dict = field(default_factory=dict)
outputs: dict = field(default_factory=dict)
last_error: str | None = None
agent_state: AgentState = AgentState.LOADING
resume_state: AgentState | None = None
traffic_control_state: TrafficControlState = TrafficControlState.NORMAL
# global metrics for the current task
metrics: Metrics = field(default_factory=Metrics)
# local metrics for the current subtask
local_metrics: Metrics = field(default_factory=Metrics)
# root agent has level 0, and every delegate increases the level by one
delegate_level: int = 0
# start_id and end_id track the range of events in history
start_id: int = -1
end_id: int = -1
almost_stuck: int = 0
# NOTE: This will never be used by the controller, but it can be used by different
# evaluation tasks to store extra data needed to track the progress/state of the task.
extra_data: dict[str, Any] = field(default_factory=dict)
def save_to_session(self, sid: str, file_store: FileStore):
pickled = pickle.dumps(self)
logger.debug(f'Saving state to session {sid}:{self.agent_state}')
encoded = base64.b64encode(pickled).decode('utf-8')
try:
file_store.write(f'sessions/{sid}/agent_state.pkl', encoded)
except Exception as e:
logger.error(f'Failed to save state to session: {e}')
raise e
@staticmethod
def restore_from_session(sid: str, file_store: FileStore) -> 'State':
try:
encoded = file_store.read(f'sessions/{sid}/agent_state.pkl')
pickled = base64.b64decode(encoded)
state = pickle.loads(pickled)
except Exception as e:
logger.error(f'Failed to restore state from session: {e}')
raise e
# update state
if state.agent_state in RESUMABLE_STATES:
state.resume_state = state.agent_state
else:
state.resume_state = None
# don't carry last_error anymore after restore
state.last_error = None
# first state after restore
state.agent_state = AgentState.LOADING
return state
def __getstate__(self):
state = self.__dict__.copy()
# save the relevant data from recent history
# so that we can restore it when the state is restored
if 'history' in state:
state['start_id'] = state['history'].start_id
state['end_id'] = state['history'].end_id
# don't save history object itself
state.pop('history', None)
return state
def __setstate__(self, state):
self.__dict__.update(state)
# recreate the history object
if not hasattr(self, 'history'):
self.history = ShortTermHistory()
# restore the relevant data in history from the state
self.history.start_id = self.start_id
self.history.end_id = self.end_id
# remove the restored data from the state if any
def get_current_user_intent(self):
"""Returns the latest user message and image(if provided) that appears after a FinishAction, or the first (the task) if nothing was finished yet."""
last_user_message = None
last_user_message_image_urls: list[str] | None = []
for event in self.history.get_events(reverse=True):
if isinstance(event, MessageAction) and event.source == 'user':
last_user_message = event.content
last_user_message_image_urls = event.images_urls
elif isinstance(event, AgentFinishAction):
if last_user_message is not None:
return last_user_message
return last_user_message, last_user_message_image_urls
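The pickle-and-base64 round-trip used by `save_to_session` / `restore_from_session` can be sketched with an in-memory stand-in for `FileStore`. The real `FileStore` writes to disk or remote storage, and the real `State` customizes pickling via `__getstate__` / `__setstate__`; this toy version only shows the encode/decode path.

```python
import base64
import pickle
from dataclasses import dataclass, field

# In-memory stand-in for FileStore, for illustration only.
class InMemoryFileStore:
    def __init__(self):
        self._files: dict[str, str] = {}

    def write(self, path: str, contents: str) -> None:
        self._files[path] = contents

    def read(self, path: str) -> str:
        return self._files[path]

@dataclass
class ToyState:
    iteration: int = 0
    history: list = field(default_factory=list)

store = InMemoryFileStore()
sid = 'session-1'

# Save: pickle the state, then base64-encode so it can be stored as text.
state = ToyState(iteration=3)
encoded = base64.b64encode(pickle.dumps(state)).decode('utf-8')
store.write(f'sessions/{sid}/agent_state.pkl', encoded)

# Restore: decode the text, then unpickle back into a State object.
raw = store.read(f'sessions/{sid}/agent_state.pkl')
restored = pickle.loads(base64.b64decode(raw))
```

Base64 encoding keeps the serialized bytes safe to store in text-oriented backends; note that unpickling is only safe on data the application itself wrote.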


@@ -0,0 +1,226 @@
from openhands.core.exceptions import (
LLMMalformedActionError,
TaskInvalidStateError,
)
from openhands.core.logger import openhands_logger as logger
OPEN_STATE = 'open'
COMPLETED_STATE = 'completed'
ABANDONED_STATE = 'abandoned'
IN_PROGRESS_STATE = 'in_progress'
VERIFIED_STATE = 'verified'
STATES = [
OPEN_STATE,
COMPLETED_STATE,
ABANDONED_STATE,
IN_PROGRESS_STATE,
VERIFIED_STATE,
]
class Task:
id: str
goal: str
parent: 'Task | None'
subtasks: list['Task']
def __init__(
self,
parent: 'Task',
goal: str,
state: str = OPEN_STATE,
subtasks=None, # noqa: B006
):
"""Initializes a new instance of the Task class.
Args:
parent: The parent task, or None if it is the root task.
goal: The goal of the task.
state: The initial state of the task.
subtasks: A list of subtasks associated with this task.
"""
if subtasks is None:
subtasks = []
if parent.id:
self.id = parent.id + '.' + str(len(parent.subtasks))
else:
self.id = str(len(parent.subtasks))
self.parent = parent
self.goal = goal
logger.debug(f'Creating task {self.id} with parent={parent.id}, goal={goal}')
self.subtasks = []
for subtask in subtasks:
if isinstance(subtask, Task):
self.subtasks.append(subtask)
else:
# subtask is a dict: build a Task from it without shadowing this constructor's parameters
subtask_goal = subtask.get('goal')
subtask_state = subtask.get('state', OPEN_STATE)
nested_subtasks = subtask.get('subtasks')
logger.debug(f'Reading: {subtask_goal}, {subtask_state}, {nested_subtasks}')
self.subtasks.append(Task(self, subtask_goal, subtask_state, nested_subtasks))
self.state = state
def to_string(self, indent=''):
"""Returns a string representation of the task and its subtasks.
Args:
indent: The indentation string for formatting the output.
Returns:
A string representation of the task and its subtasks.
"""
emoji = ''
if self.state == VERIFIED_STATE:
emoji = '✅'
elif self.state == COMPLETED_STATE:
emoji = '🟢'
elif self.state == ABANDONED_STATE:
emoji = '❌'
elif self.state == IN_PROGRESS_STATE:
emoji = '💪'
elif self.state == OPEN_STATE:
emoji = '🔵'
result = indent + emoji + ' ' + self.id + ' ' + self.goal + '\n'
for subtask in self.subtasks:
result += subtask.to_string(indent + ' ')
return result
def to_dict(self):
"""Returns a dictionary representation of the task.
Returns:
A dictionary containing the task's attributes.
"""
return {
'id': self.id,
'goal': self.goal,
'state': self.state,
'subtasks': [t.to_dict() for t in self.subtasks],
}
def set_state(self, state):
"""Sets the state of the task and its subtasks.
Args:
state: The new state of the task.
Raises:
TaskInvalidStateError: If the provided state is invalid.
"""
if state not in STATES:
logger.error('Invalid state: %s', state)
raise TaskInvalidStateError(state)
self.state = state
if (
state == COMPLETED_STATE
or state == ABANDONED_STATE
or state == VERIFIED_STATE
):
for subtask in self.subtasks:
if subtask.state != ABANDONED_STATE:
subtask.set_state(state)
elif state == IN_PROGRESS_STATE:
if self.parent is not None:
self.parent.set_state(state)
def get_current_task(self) -> 'Task | None':
"""Retrieves the current task in progress.
Returns:
The current task in progress, or None if no task is in progress.
"""
for subtask in self.subtasks:
if subtask.state == IN_PROGRESS_STATE:
return subtask.get_current_task()
if self.state == IN_PROGRESS_STATE:
return self
return None
class RootTask(Task):
"""Serves as the root node in a tree of tasks.
Because we want the top-level of the root_task to be a list of tasks (1, 2, 3, etc.),
the "root node" of the data structure is kind of invisible--it just
holds references to the top-level tasks.
Attributes:
id: Kept blank for root_task
goal: Kept blank for root_task
parent: None for root_task
subtasks: The top-level list of tasks associated with the root_task.
state: The state of the root_task.
"""
id: str = ''
goal: str = ''
parent: None = None
def __init__(self):
self.subtasks = []
self.state = OPEN_STATE
def __str__(self):
"""Returns a string representation of the root_task.
Returns:
A string representation of the root_task.
"""
return self.to_string()
def get_task_by_id(self, id: str) -> Task:
"""Retrieves a task by its ID.
Args:
id: The ID of the task.
Returns:
The task with the specified ID.
Raises:
LLMMalformedActionError: If the provided task ID is invalid or does not exist.
"""
if id == '':
return self
if len(self.subtasks) == 0:
raise LLMMalformedActionError('Task does not exist: ' + id)
try:
parts = [int(p) for p in id.split('.')]
except ValueError as e:
raise LLMMalformedActionError('Invalid task id: ' + id) from e
task: Task = self
for part in parts:
if part >= len(task.subtasks):
raise LLMMalformedActionError('Task does not exist: ' + id)
task = task.subtasks[part]
return task
def add_subtask(self, parent_id: str, goal: str, subtasks: list | None = None):
"""Adds a subtask to a parent task.
Args:
parent_id: The ID of the parent task.
goal: The goal of the subtask.
subtasks: A list of subtasks associated with the new subtask.
"""
subtasks = subtasks or []
parent = self.get_task_by_id(parent_id)
child = Task(parent=parent, goal=goal, subtasks=subtasks)
parent.subtasks.append(child)
def set_subtask_state(self, id: str, state: str):
"""Sets the state of a subtask.
Args:
id: The ID of the subtask.
state: The new state of the subtask.
"""
task = self.get_task_by_id(id)
logger.debug(f'Setting task {task.id} from state {task.state} to {state}')
task.set_state(state)
unfinished_tasks = [
t
for t in self.subtasks
if t.state not in [COMPLETED_STATE, VERIFIED_STATE, ABANDONED_STATE]
]
if len(unfinished_tasks) == 0:
self.set_state(COMPLETED_STATE)
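The dotted task ids above ('0', '0.1', '0.1.2', ...) encode the path from the invisible root down to a node, and `get_task_by_id` walks that path one index at a time. A standalone sketch of the same lookup over a `to_dict()`-style tree; plain dicts stand in for `Task` objects, and the error types here are illustrative:

```python
def resolve(root: dict, task_id: str) -> dict:
    """Walk a to_dict()-style task tree by dotted id such as '1.0'."""
    if task_id == '':
        # the empty id addresses the root itself, as in RootTask
        return root
    node = root
    for part in task_id.split('.'):
        index = int(part)  # malformed ids raise ValueError here
        subtasks = node['subtasks']
        if index >= len(subtasks):
            raise KeyError(f'Task does not exist: {task_id}')
        node = subtasks[index]
    return node


root = {
    'goal': '',
    'subtasks': [
        {'goal': 'write tests', 'subtasks': []},
        {'goal': 'fix bug', 'subtasks': [{'goal': 'reproduce', 'subtasks': []}]},
    ],
}
```

Here `resolve(root, '1.0')` descends into the second top-level task and returns its first subtask, mirroring how `RootTask.get_task_by_id` splits the id on '.' and indexes into `subtasks`.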


@@ -0,0 +1,237 @@
from typing import cast
from openhands.controller.state.state import State
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.action import Action
from openhands.events.action.empty import NullAction
from openhands.events.action.message import MessageAction
from openhands.events.event import Event, EventSource
from openhands.events.observation.commands import (
CmdOutputObservation,
IPythonRunCellObservation,
)
from openhands.events.observation.empty import NullObservation
from openhands.events.observation.error import ErrorObservation
from openhands.events.observation.observation import Observation
class StuckDetector:
def __init__(self, state: State):
self.state = state
def is_stuck(self):
# filter out MessageAction with source='user' from history
filtered_history = [
event
for event in self.state.history.get_events()
if not (
(isinstance(event, MessageAction) and event.source == EventSource.USER)
or
# there might be some NullAction or NullObservation in the history at least for now
isinstance(event, NullAction)
or isinstance(event, NullObservation)
)
]
# it takes 3 actions minimum to detect a loop, otherwise nothing to do here
if len(filtered_history) < 3:
return False
# the first few scenarios detect 3 or 4 repeated steps
# prepare the last 4 actions and observations, to check them out
last_actions: list[Event] = []
last_observations: list[Event] = []
# retrieve the last four actions and observations starting from the end of history, wherever they are
for event in reversed(filtered_history):
if isinstance(event, Action) and len(last_actions) < 4:
last_actions.append(event)
elif isinstance(event, Observation) and len(last_observations) < 4:
last_observations.append(event)
if len(last_actions) == 4 and len(last_observations) == 4:
break
# scenario 1: same action, same observation
if self._is_stuck_repeating_action_observation(last_actions, last_observations):
return True
# scenario 2: same action, errors
if self._is_stuck_repeating_action_error(last_actions, last_observations):
return True
# scenario 3: monologue
if self._is_stuck_monologue(filtered_history):
return True
# scenario 4: action, observation pattern on the last six steps
if len(filtered_history) < 6:
return False
if self._is_stuck_action_observation_pattern(filtered_history):
return True
return False
def _is_stuck_repeating_action_observation(self, last_actions, last_observations):
# scenario 1: same action, same observation
# it takes 4 actions and 4 observations to detect a loop
# assert len(last_actions) == 4 and len(last_observations) == 4
# reset almost_stuck reminder
self.state.almost_stuck = 0
# almost stuck? if two actions, obs are the same, we're almost stuck
if len(last_actions) >= 2 and len(last_observations) >= 2:
actions_equal = all(
self._eq_no_pid(last_actions[0], action) for action in last_actions[:2]
)
observations_equal = all(
self._eq_no_pid(last_observations[0], observation)
for observation in last_observations[:2]
)
# the last two actions and obs are the same?
if actions_equal and observations_equal:
self.state.almost_stuck = 2
# the last three actions and observations are the same?
if len(last_actions) >= 3 and len(last_observations) >= 3:
if (
actions_equal
and observations_equal
and self._eq_no_pid(last_actions[0], last_actions[2])
and self._eq_no_pid(last_observations[0], last_observations[2])
):
self.state.almost_stuck = 1
if len(last_actions) == 4 and len(last_observations) == 4:
if (
actions_equal
and observations_equal
and self._eq_no_pid(last_actions[0], last_actions[3])
and self._eq_no_pid(last_observations[0], last_observations[3])
):
logger.warning('Action, Observation loop detected')
self.state.almost_stuck = 0
return True
return False
def _is_stuck_repeating_action_error(self, last_actions, last_observations):
# scenario 2: same action, errors
# it takes 4 actions and 4 observations to detect a loop
# check if the last four actions are the same and result in errors
# are the last four actions the same?
if len(last_actions) == 4 and all(
self._eq_no_pid(last_actions[0], action) for action in last_actions
):
# and the last four observations all errors?
if all(isinstance(obs, ErrorObservation) for obs in last_observations):
logger.warning('Action, ErrorObservation loop detected')
return True
# or, are the last four observations all IPythonRunCellObservation with SyntaxError?
elif all(
isinstance(obs, IPythonRunCellObservation) for obs in last_observations
) and all(
cast(IPythonRunCellObservation, obs)
.content[-100:]
.find('SyntaxError: unterminated string literal (detected at line')
!= -1
and len(
cast(IPythonRunCellObservation, obs).content.split(
'SyntaxError: unterminated string literal (detected at line'
)[-1]
)
< 10
for obs in last_observations
):
logger.warning('Action, IPythonRunCellObservation loop detected')
return True
return False
def _is_stuck_monologue(self, filtered_history):
# scenario 3: monologue
# check for repeated MessageActions with source=AGENT
# see if the agent is engaged in a good old monologue, telling itself the same thing over and over
agent_message_actions = [
(i, event)
for i, event in enumerate(filtered_history)
if isinstance(event, MessageAction) and event.source == EventSource.AGENT
]
# last three message actions will do for this check
if len(agent_message_actions) >= 3:
last_agent_message_actions = agent_message_actions[-3:]
if all(
(last_agent_message_actions[0][1] == action[1])
for action in last_agent_message_actions
):
# check if there are any observations between the repeated MessageActions
# then it's not yet a loop, maybe it can recover
start_index = last_agent_message_actions[0][0]
end_index = last_agent_message_actions[-1][0]
has_observation_between = False
for event in filtered_history[start_index + 1 : end_index]:
if isinstance(event, Observation):
has_observation_between = True
break
if not has_observation_between:
logger.warning('Repeated MessageAction with source=AGENT detected')
return True
return False
def _is_stuck_action_observation_pattern(self, filtered_history):
# scenario 4: action, observation pattern on the last six steps
# check if the agent repeats the same (Action, Observation)
# every other step in the last six steps
last_six_actions: list[Event] = []
last_six_observations: list[Event] = []
# the end of history is most interesting
for event in reversed(filtered_history):
if isinstance(event, Action) and len(last_six_actions) < 6:
last_six_actions.append(event)
elif isinstance(event, Observation) and len(last_six_observations) < 6:
last_six_observations.append(event)
if len(last_six_actions) == 6 and len(last_six_observations) == 6:
break
# this pattern is every other step, like:
# (action_1, obs_1), (action_2, obs_2), (action_1, obs_1), (action_2, obs_2),...
if len(last_six_actions) == 6 and len(last_six_observations) == 6:
actions_equal = (
# action_0 == action_2 == action_4
self._eq_no_pid(last_six_actions[0], last_six_actions[2])
and self._eq_no_pid(last_six_actions[0], last_six_actions[4])
# action_1 == action_3 == action_5
and self._eq_no_pid(last_six_actions[1], last_six_actions[3])
and self._eq_no_pid(last_six_actions[1], last_six_actions[5])
)
observations_equal = (
# obs_0 == obs_2 == obs_4
self._eq_no_pid(last_six_observations[0], last_six_observations[2])
and self._eq_no_pid(last_six_observations[0], last_six_observations[4])
# obs_1 == obs_3 == obs_5
and self._eq_no_pid(last_six_observations[1], last_six_observations[3])
and self._eq_no_pid(last_six_observations[1], last_six_observations[5])
)
if actions_equal and observations_equal:
logger.warning('Action, Observation pattern detected')
return True
return False
def _eq_no_pid(self, obj1, obj2):
if isinstance(obj1, CmdOutputObservation) and isinstance(
obj2, CmdOutputObservation
):
# for loop detection, ignore command_id, which is the pid
return obj1.command == obj2.command and obj1.exit_code == obj2.exit_code
else:
# this is the default comparison
return obj1 == obj2
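Scenario 4 above looks for an A, B, A, B, A, B alternation over the last six actions and observations. The core of that check, reduced to a standalone function over plain values, is a sketch of the idea only: the production code compares actions and observations separately and ignores pids via `_eq_no_pid`:

```python
def is_alternating_loop(steps: list) -> bool:
    """True when the last six steps alternate as A, B, A, B, A, B."""
    if len(steps) < 6:
        # not enough history to establish the every-other-step pattern
        return False
    a, b = steps[-6], steps[-5]
    return steps[-6:] == [a, b, a, b, a, b]


# two commands repeating back and forth looks stuck
history = ['ls', 'cat x', 'ls', 'cat x', 'ls', 'cat x']
```

Note that a run of six identical steps also satisfies the pattern (A equals B), which is consistent with the detector treating pure repetition as stuck as well (scenario 1).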

762
openhands/core/config.py Normal file

@@ -0,0 +1,762 @@
import argparse
import os
import pathlib
import platform
import uuid
from dataclasses import dataclass, field, fields, is_dataclass
from enum import Enum
from types import UnionType
from typing import Any, ClassVar, MutableMapping, get_args, get_origin
import toml
from dotenv import load_dotenv
from openhands.core import logger
from openhands.core.utils import Singleton
load_dotenv()
LLM_SENSITIVE_FIELDS = ['api_key', 'aws_access_key_id', 'aws_secret_access_key']
_DEFAULT_AGENT = 'CodeActAgent'
_MAX_ITERATIONS = 100
@dataclass
class LLMConfig:
"""Configuration for the LLM model.
Attributes:
model: The model to use.
api_key: The API key to use.
base_url: The base URL for the API. This is necessary for local LLMs. It is also used for Azure embeddings.
api_version: The version of the API.
embedding_model: The embedding model to use.
embedding_base_url: The base URL for the embedding API.
embedding_deployment_name: The name of the deployment for the embedding API. This is used for Azure OpenAI.
aws_access_key_id: The AWS access key ID.
aws_secret_access_key: The AWS secret access key.
aws_region_name: The AWS region name.
num_retries: The number of retries to attempt.
retry_multiplier: The multiplier for the exponential backoff.
retry_min_wait: The minimum time to wait between retries, in seconds. This is exponential backoff minimum. For models with very low limits, this can be set to 15-20.
retry_max_wait: The maximum time to wait between retries, in seconds. This is exponential backoff maximum.
timeout: The timeout for the API.
max_message_chars: The approximate max number of characters in the content of an event included in the prompt to the LLM. Larger observations are truncated.
temperature: The temperature for the API.
top_p: The top p for the API.
custom_llm_provider: The custom LLM provider to use. This is undocumented in openhands, and normally not used. It is documented on the litellm side.
max_input_tokens: The maximum number of input tokens. Note that this is currently unused, and the value at runtime is actually the total tokens in OpenAI (e.g. 128,000 tokens for GPT-4).
max_output_tokens: The maximum number of output tokens. This is sent to the LLM.
input_cost_per_token: The cost per input token. This will be available in logs for the user to check.
output_cost_per_token: The cost per output token. This will be available in logs for the user to check.
ollama_base_url: The base URL for the OLLAMA API.
drop_params: Drop any unmapped (unsupported) params without causing an exception.
"""
model: str = 'gpt-4o'
api_key: str | None = None
base_url: str | None = None
api_version: str | None = None
embedding_model: str = 'local'
embedding_base_url: str | None = None
embedding_deployment_name: str | None = None
aws_access_key_id: str | None = None
aws_secret_access_key: str | None = None
aws_region_name: str | None = None
num_retries: int = 10
retry_multiplier: float = 2
retry_min_wait: int = 3
retry_max_wait: int = 300
timeout: int | None = None
max_message_chars: int = 10_000 # maximum number of characters in an observation's content when sent to the llm
temperature: float = 0
top_p: float = 0.5
custom_llm_provider: str | None = None
max_input_tokens: int | None = None
max_output_tokens: int | None = None
input_cost_per_token: float | None = None
output_cost_per_token: float | None = None
ollama_base_url: str | None = None
drop_params: bool | None = None
def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
result = {}
for f in fields(self):
result[f.name] = get_field_info(f)
return result
def __str__(self):
attr_str = []
for f in fields(self):
attr_name = f.name
attr_value = getattr(self, f.name)
if attr_name in LLM_SENSITIVE_FIELDS:
attr_value = '******' if attr_value else None
attr_str.append(f'{attr_name}={repr(attr_value)}')
return f"LLMConfig({', '.join(attr_str)})"
def __repr__(self):
return self.__str__()
def to_safe_dict(self):
"""Return a dict with the sensitive fields replaced with ******."""
ret = self.__dict__.copy()
for k, v in ret.items():
if k in LLM_SENSITIVE_FIELDS:
ret[k] = '******' if v else None
return ret
def set_missing_attributes(self):
"""Set any missing attributes to their default values."""
for field_name, field_obj in self.__dataclass_fields__.items():
if not hasattr(self, field_name):
setattr(self, field_name, field_obj.default)
@dataclass
class AgentConfig:
"""Configuration for the agent.
Attributes:
memory_enabled: Whether long-term memory (embeddings) is enabled.
memory_max_threads: The maximum number of threads indexing at the same time for embeddings.
llm_config: The name of the llm config to use. If specified, this will override global llm config.
"""
memory_enabled: bool = False
memory_max_threads: int = 2
llm_config: str | None = None
def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
result = {}
for f in fields(self):
result[f.name] = get_field_info(f)
return result
@dataclass
class SecurityConfig(metaclass=Singleton):
"""Configuration for security related functionalities.
Attributes:
confirmation_mode: Whether to enable confirmation mode.
security_analyzer: The security analyzer to use.
"""
confirmation_mode: bool = False
security_analyzer: str | None = None
def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
result = {}
for f in fields(self):
result[f.name] = get_field_info(f)
return result
def __str__(self):
attr_str = []
for f in fields(self):
attr_name = f.name
attr_value = getattr(self, f.name)
attr_str.append(f'{attr_name}={repr(attr_value)}')
return f"SecurityConfig({', '.join(attr_str)})"
def __repr__(self):
return self.__str__()
@dataclass
class SandboxConfig(metaclass=Singleton):
"""Configuration for the sandbox.
Attributes:
api_hostname: The hostname for the EventStream Runtime API.
container_image: The container image to use for the sandbox.
user_id: The user ID for the sandbox.
timeout: The timeout for the sandbox.
enable_auto_lint: Whether to enable auto-lint.
use_host_network: Whether to use the host network.
initialize_plugins: Whether to initialize plugins.
od_runtime_extra_deps: The extra dependencies to install in the runtime image (typically used for evaluation).
This will be rendered into the end of the Dockerfile that builds the runtime image.
It can contain any valid shell commands (e.g., pip install numpy).
The path to the interpreter is available as $OD_INTERPRETER_PATH,
which can be used to install dependencies for the OD-specific Python interpreter.
od_runtime_startup_env_vars: The environment variables to set at the launch of the runtime.
This is a dictionary of key-value pairs.
This is useful for setting environment variables that are needed by the runtime.
For example, for specifying the base url of website for browsergym evaluation.
browsergym_eval_env: The BrowserGym environment to use for evaluation.
Default is None for general purpose browsing. Check evaluation/miniwob and evaluation/webarena for examples.
"""
api_hostname: str = 'localhost'
container_image: str = 'nikolaik/python-nodejs:python3.11-nodejs22' # default to nikolaik/python-nodejs:python3.11-nodejs22 for eventstream runtime
user_id: int = os.getuid() if hasattr(os, 'getuid') else 1000
timeout: int = 120
enable_auto_lint: bool = (
False # once enabled, OpenHands would lint files after editing
)
use_host_network: bool = False
initialize_plugins: bool = True
od_runtime_extra_deps: str | None = None
od_runtime_startup_env_vars: dict[str, str] = field(default_factory=dict)
browsergym_eval_env: str | None = None
def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
result = {}
for f in fields(self):
result[f.name] = get_field_info(f)
return result
def __str__(self):
attr_str = []
for f in fields(self):
attr_name = f.name
attr_value = getattr(self, f.name)
attr_str.append(f'{attr_name}={repr(attr_value)}')
return f"SandboxConfig({', '.join(attr_str)})"
def __repr__(self):
return self.__str__()
class UndefinedString(str, Enum):
UNDEFINED = 'UNDEFINED'
@dataclass
class AppConfig(metaclass=Singleton):
"""Configuration for the app.
Attributes:
llms: A dictionary of name -> LLM configuration. Default config is under 'llm' key.
agents: A dictionary of name -> Agent configuration. Default config is under 'agent' key.
default_agent: The name of the default agent to use.
sandbox: The sandbox configuration.
runtime: The runtime environment.
file_store: The file store to use.
file_store_path: The path to the file store.
workspace_base: The base path for the workspace. Defaults to ./workspace as an absolute path.
workspace_mount_path: The path to mount the workspace. This is set to the workspace base by default.
workspace_mount_path_in_sandbox: The path to mount the workspace in the sandbox. Defaults to /workspace.
workspace_mount_rewrite: The path to rewrite the workspace mount path to.
cache_dir: The path to the cache directory. Defaults to /tmp/cache.
run_as_openhands: Whether to run as openhands.
max_iterations: The maximum number of iterations.
max_budget_per_task: The maximum budget allowed per task, beyond which the agent will stop.
e2b_api_key: The E2B API key.
disable_color: Whether to disable color. For terminals that don't support color.
debug: Whether to enable debugging.
enable_cli_session: Whether to enable saving and restoring the session when run from CLI.
file_uploads_max_file_size_mb: Maximum file size for uploads in megabytes. 0 means no limit.
file_uploads_restrict_file_types: Whether to restrict file types for file uploads. Defaults to False.
file_uploads_allowed_extensions: List of allowed file extensions for uploads. ['.*'] means all extensions are allowed.
"""
llms: dict[str, LLMConfig] = field(default_factory=dict)
agents: dict = field(default_factory=dict)
default_agent: str = _DEFAULT_AGENT
sandbox: SandboxConfig = field(default_factory=SandboxConfig)
security: SecurityConfig = field(default_factory=SecurityConfig)
runtime: str = 'eventstream'
file_store: str = 'memory'
file_store_path: str = '/tmp/file_store'
# TODO: clean up workspace path after the removal of ServerRuntime
workspace_base: str = os.path.join(os.getcwd(), 'workspace')
workspace_mount_path: str | None = (
UndefinedString.UNDEFINED # this path should always be set when config is fully loaded
) # when set to None, do not mount the workspace
workspace_mount_path_in_sandbox: str = '/workspace'
workspace_mount_rewrite: str | None = None
cache_dir: str = '/tmp/cache'
run_as_openhands: bool = True
max_iterations: int = _MAX_ITERATIONS
max_budget_per_task: float | None = None
e2b_api_key: str = ''
disable_color: bool = False
jwt_secret: str = uuid.uuid4().hex
debug: bool = False
enable_cli_session: bool = False
file_uploads_max_file_size_mb: int = 0
file_uploads_restrict_file_types: bool = False
file_uploads_allowed_extensions: list[str] = field(default_factory=lambda: ['.*'])
defaults_dict: ClassVar[dict] = {}
def get_llm_config(self, name='llm') -> LLMConfig:
"""Llm is the name for default config (for backward compatibility prior to 0.8)"""
if name in self.llms:
return self.llms[name]
if name is not None and name != 'llm':
logger.openhands_logger.warning(
f'llm config group {name} not found, using default config'
)
if 'llm' not in self.llms:
self.llms['llm'] = LLMConfig()
return self.llms['llm']
def set_llm_config(self, value: LLMConfig, name='llm'):
self.llms[name] = value
def get_agent_config(self, name='agent') -> AgentConfig:
"""Agent is the name for default config (for backward compability prior to 0.8)"""
if name in self.agents:
return self.agents[name]
if 'agent' not in self.agents:
self.agents['agent'] = AgentConfig()
return self.agents['agent']
def set_agent_config(self, value: AgentConfig, name='agent'):
self.agents[name] = value
def get_agent_to_llm_config_map(self) -> dict[str, LLMConfig]:
"""Get a map of agent names to llm configs."""
return {name: self.get_llm_config_from_agent(name) for name in self.agents}
def get_llm_config_from_agent(self, name='agent') -> LLMConfig:
agent_config: AgentConfig = self.get_agent_config(name)
llm_config_name = agent_config.llm_config
return self.get_llm_config(llm_config_name)
def get_agent_configs(self) -> dict[str, AgentConfig]:
return self.agents
def __post_init__(self):
"""Post-initialization hook, called when the instance is created with only default values."""
AppConfig.defaults_dict = self.defaults_to_dict()
def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
result = {}
for f in fields(self):
field_value = getattr(self, f.name)
# dataclasses compute their defaults themselves
if is_dataclass(type(field_value)):
result[f.name] = field_value.defaults_to_dict()
else:
result[f.name] = get_field_info(f)
return result
def __str__(self):
attr_str = []
for f in fields(self):
attr_name = f.name
attr_value = getattr(self, f.name)
if attr_name in [
'e2b_api_key',
'github_token',
'jwt_secret',
]:
attr_value = '******' if attr_value else None
attr_str.append(f'{attr_name}={repr(attr_value)}')
return f"AppConfig({', '.join(attr_str)}"
def __repr__(self):
return self.__str__()
def get_field_info(f):
"""Extract information about a dataclass field: type, optional, and default.
Args:
f: The field to extract information from.
Returns: A dict with the field's type, whether it's optional, and its default value.
"""
field_type = f.type
optional = False
# for types like str | None, find the non-None type and set optional to True
# this is useful for the frontend to know if a field is optional
# and to show the correct type in the UI
# Note: this only works for UnionTypes with None as one of the types
if get_origin(field_type) is UnionType:
types = get_args(field_type)
non_none_arg = next((t for t in types if t is not type(None)), None)
if non_none_arg is not None:
field_type = non_none_arg
optional = True
# type name in a pretty format
type_name = (
field_type.__name__ if hasattr(field_type, '__name__') else str(field_type)
)
# default is always present
default = f.default
# return a schema with the useful info for frontend
return {'type': type_name.lower(), 'optional': optional, 'default': default}
def load_from_env(cfg: AppConfig, env_or_toml_dict: dict | MutableMapping[str, str]):
"""Reads the env-style vars and sets config attributes based on env vars or a config.toml dict.
Compatibility with vars like LLM_BASE_URL, AGENT_MEMORY_ENABLED, SANDBOX_TIMEOUT and others.
Args:
cfg: The AppConfig object to set attributes on.
env_or_toml_dict: The environment variables or a config.toml dict.
"""
def get_optional_type(union_type: UnionType) -> Any:
"""Returns the non-None type from a Union."""
types = get_args(union_type)
return next((t for t in types if t is not type(None)), None)
# helper function to set attributes based on env vars
def set_attr_from_env(sub_config: Any, prefix=''):
"""Set attributes of a config dataclass based on environment variables."""
for field_name, field_type in sub_config.__annotations__.items():
# compute the expected env var name from the prefix and field name
# e.g. LLM_BASE_URL
env_var_name = (prefix + field_name).upper()
if is_dataclass(field_type):
# nested dataclass
nested_sub_config = getattr(sub_config, field_name)
set_attr_from_env(nested_sub_config, prefix=field_name + '_')
elif env_var_name in env_or_toml_dict:
# convert the env var to the correct type and set it
value = env_or_toml_dict[env_var_name]
# skip empty config values (fall back to default)
if not value:
continue
try:
# if it's an optional type, get the non-None type
if get_origin(field_type) is UnionType:
field_type = get_optional_type(field_type)
# Attempt to cast the env var to type hinted in the dataclass
if field_type is bool:
cast_value = str(value).lower() in ['true', '1']
else:
cast_value = field_type(value)
setattr(sub_config, field_name, cast_value)
except (ValueError, TypeError):
logger.openhands_logger.error(
f'Error setting env var {env_var_name}={value}: check that the value is of the right type'
)
# Start processing from the root of the config object
set_attr_from_env(cfg)
# load default LLM config from env
default_llm_config = cfg.get_llm_config()
set_attr_from_env(default_llm_config, 'LLM_')
# load default agent config from env
default_agent_config = cfg.get_agent_config()
set_attr_from_env(default_agent_config, 'AGENT_')
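The casting step inside `set_attr_from_env` deserves a note: environment variables are always strings, so booleans are matched against the common truthy spellings ('true', '1') while every other type goes through its constructor. A standalone sketch of just that conversion; the env var names below are illustrative:

```python
def cast_env_value(value: str, field_type: type):
    """Cast an environment-variable string to the annotated field type."""
    if field_type is bool:
        # env vars are strings, so accept the common truthy spellings
        return str(value).lower() in ['true', '1']
    # e.g. int('5') -> 5, float('0.2') -> 0.2
    return field_type(value)


env = {
    'LLM_NUM_RETRIES': '5',
    'SANDBOX_USE_HOST_NETWORK': 'TRUE',
    'LLM_TEMPERATURE': '0.2',
}
retries = cast_env_value(env['LLM_NUM_RETRIES'], int)
use_host = cast_env_value(env['SANDBOX_USE_HOST_NETWORK'], bool)
temp = cast_env_value(env['LLM_TEMPERATURE'], float)
```

Note the pitfall this sidesteps: `bool('False')` is `True` in Python, which is why the boolean case is special-cased instead of calling the constructor.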
def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
"""Load the config from the toml file. Supports both styles of config vars.
Args:
cfg: The AppConfig object to update attributes of.
toml_file: The path to the toml file. Defaults to 'config.toml'.
"""
# try to read the config.toml file into the config object
try:
with open(toml_file, 'r', encoding='utf-8') as toml_contents:
toml_config = toml.load(toml_contents)
except FileNotFoundError:
return
except toml.TomlDecodeError as e:
logger.openhands_logger.warning(
f'Cannot parse config from toml, toml values have not been applied.\nError: {e}',
exc_info=False,
)
return
# if there was an exception or core is not in the toml, try to use the old-style toml
if 'core' not in toml_config:
# re-use the env loader to set the config from env-style vars
load_from_env(cfg, toml_config)
return
core_config = toml_config['core']
# load llm configs and agent configs
for key, value in toml_config.items():
if isinstance(value, dict):
try:
if key is not None and key.lower() == 'agent':
logger.openhands_logger.info(
'Attempt to load default agent config from config toml'
)
non_dict_fields = {
k: v for k, v in value.items() if not isinstance(v, dict)
}
agent_config = AgentConfig(**non_dict_fields)
cfg.set_agent_config(agent_config, 'agent')
for nested_key, nested_value in value.items():
if isinstance(nested_value, dict):
logger.openhands_logger.info(
f'Attempt to load group {nested_key} from config toml as agent config'
)
agent_config = AgentConfig(**nested_value)
cfg.set_agent_config(agent_config, nested_key)
elif key is not None and key.lower() == 'llm':
logger.openhands_logger.info(
'Attempt to load default LLM config from config toml'
)
non_dict_fields = {
k: v for k, v in value.items() if not isinstance(v, dict)
}
llm_config = LLMConfig(**non_dict_fields)
cfg.set_llm_config(llm_config, 'llm')
for nested_key, nested_value in value.items():
if isinstance(nested_value, dict):
logger.openhands_logger.info(
f'Attempt to load group {nested_key} from config toml as llm config'
)
llm_config = LLMConfig(**nested_value)
cfg.set_llm_config(llm_config, nested_key)
elif not key.startswith('sandbox') and key.lower() != 'core':
logger.openhands_logger.warning(
f'Unknown key in {toml_file}: "{key}"'
)
except (TypeError, KeyError) as e:
logger.openhands_logger.warning(
f'Cannot parse config from toml, toml values have not been applied.\nError: {e}',
exc_info=False,
)
else:
logger.openhands_logger.warning(f'Unknown key in {toml_file}: "{key}"')
try:
# set sandbox config from the toml file
sandbox_config = cfg.sandbox
# migrate old sandbox configs from [core] section to sandbox config
keys_to_migrate = [key for key in core_config if key.startswith('sandbox_')]
for key in keys_to_migrate:
new_key = key.replace('sandbox_', '')
if new_key in sandbox_config.__annotations__:
# read the key in sandbox and remove it from core
setattr(sandbox_config, new_key, core_config.pop(key))
else:
logger.openhands_logger.warning(f'Unknown sandbox config: {key}')
# the new style values override the old style values
if 'sandbox' in toml_config:
sandbox_config = SandboxConfig(**toml_config['sandbox'])
# update the config object with the new values
AppConfig(sandbox=sandbox_config, **core_config)
except (TypeError, KeyError) as e:
logger.openhands_logger.warning(
f'Cannot parse config from toml, toml values have not been applied.\nError: {e}',
exc_info=False,
)
def finalize_config(cfg: AppConfig):
"""More tweaks to the config after it's been loaded."""
# Set workspace_mount_path if not set by the user
if cfg.workspace_mount_path is UndefinedString.UNDEFINED:
cfg.workspace_mount_path = os.path.abspath(cfg.workspace_base)
cfg.workspace_base = os.path.abspath(cfg.workspace_base)
if cfg.workspace_mount_rewrite: # and not config.workspace_mount_path:
# TODO why do we need to check if workspace_mount_path is None?
base = cfg.workspace_base or os.getcwd()
parts = cfg.workspace_mount_rewrite.split(':')
cfg.workspace_mount_path = base.replace(parts[0], parts[1])
for llm in cfg.llms.values():
if llm.embedding_base_url is None:
llm.embedding_base_url = llm.base_url
if cfg.sandbox.use_host_network and platform.system() == 'Darwin':
logger.openhands_logger.warning(
'Please upgrade to Docker Desktop 4.29.0 or later to use host network mode on macOS. '
'See https://github.com/docker/roadmap/issues/238#issuecomment-2044688144 for more information.'
)
# make sure cache dir exists
if cfg.cache_dir:
pathlib.Path(cfg.cache_dir).mkdir(parents=True, exist_ok=True)
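The `workspace_mount_rewrite` handling above treats the value as an `old-prefix:new-prefix` pair applied to the base path. A small sketch of that split-and-replace (the function name and paths are illustrative):

```python
# Sketch of the workspace_mount_rewrite logic in finalize_config:
# "old:new" is split on the colon and the old prefix is substituted.
def rewrite_mount_path(base: str, rewrite: str) -> str:
    old, new = rewrite.split(':')
    return base.replace(old, new)

print(rewrite_mount_path('/host/repo/workspace', '/host:/mnt'))
# -> /mnt/repo/workspace
```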
# Utility function for command line --group argument
def get_llm_config_arg(
llm_config_arg: str, toml_file: str = 'config.toml'
) -> LLMConfig | None:
"""Get a group of llm settings from the config file.
A group in config.toml can look like this:
```
[llm.gpt-3.5-for-eval]
model = 'gpt-3.5-turbo'
api_key = '...'
temperature = 0.5
num_retries = 10
...
```
The user-defined group name, like "gpt-3.5-for-eval", is the argument to this function. The function loads the LLMConfig object
with the settings of this group from the config file and returns it.
Note that the group must be nested under the "llm" group; in other words, the full group name must start with "llm.".
Args:
llm_config_arg: The group of llm settings to get from the config.toml file.
Returns:
LLMConfig: The LLMConfig object with the settings from the config file.
"""
# keep only the name, just in case
llm_config_arg = llm_config_arg.strip('[]')
# truncate the prefix, just in case
if llm_config_arg.startswith('llm.'):
llm_config_arg = llm_config_arg[4:]
logger.openhands_logger.info(f'Loading llm config from {llm_config_arg}')
# load the toml file
try:
with open(toml_file, 'r', encoding='utf-8') as toml_contents:
toml_config = toml.load(toml_contents)
except FileNotFoundError as e:
logger.openhands_logger.error(f'Config file not found: {e}')
return None
except toml.TomlDecodeError as e:
logger.openhands_logger.error(
f'Cannot parse llm group from {llm_config_arg}. Exception: {e}'
)
return None
# update the llm config with the specified section
if 'llm' in toml_config and llm_config_arg in toml_config['llm']:
return LLMConfig(**toml_config['llm'][llm_config_arg])
logger.openhands_logger.debug(f'Loading from toml failed for {llm_config_arg}')
return None
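The argument normalization above tolerates both brackets and the `llm.` prefix, so several spellings of the same group name resolve identically. A sketch of just that normalization step:

```python
# Sketch of the name cleanup at the top of get_llm_config_arg: strip
# surrounding brackets, then drop a leading "llm." prefix if present.
def normalize_group(arg: str) -> str:
    arg = arg.strip('[]')
    if arg.startswith('llm.'):
        arg = arg[4:]
    return arg

for raw in ('gpt-3.5-for-eval', 'llm.gpt-3.5-for-eval', '[llm.gpt-3.5-for-eval]'):
    print(normalize_group(raw))  # each prints: gpt-3.5-for-eval
```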
# Command line arguments
def get_parser() -> argparse.ArgumentParser:
"""Get the parser for the command line arguments."""
parser = argparse.ArgumentParser(description='Run an agent with a specific task')
parser.add_argument(
'-d',
'--directory',
type=str,
help='The working directory for the agent',
)
parser.add_argument(
'-t',
'--task',
type=str,
default='',
help='The task for the agent to perform',
)
parser.add_argument(
'-f',
'--file',
type=str,
help='Path to a file containing the task. Overrides -t if both are provided.',
)
parser.add_argument(
'-c',
'--agent-cls',
default=_DEFAULT_AGENT,
type=str,
help='Name of the default agent to use',
)
parser.add_argument(
'-i',
'--max-iterations',
default=_MAX_ITERATIONS,
type=int,
help='The maximum number of iterations to run the agent',
)
parser.add_argument(
'-b',
'--max-budget-per-task',
type=float,
help='The maximum budget allowed per task, beyond which the agent will stop.',
)
# --eval configs are for evaluations only
parser.add_argument(
'--eval-output-dir',
default='evaluation/evaluation_outputs/outputs',
type=str,
help='The directory to save evaluation output',
)
parser.add_argument(
'--eval-n-limit',
default=None,
type=int,
help='The number of instances to evaluate',
)
parser.add_argument(
'--eval-num-workers',
default=4,
type=int,
help='The number of workers to use for evaluation',
)
parser.add_argument(
'--eval-note',
default=None,
type=str,
help='The note to add to the evaluation directory',
)
parser.add_argument(
'-l',
'--llm-config',
default=None,
type=str,
help='Replace default LLM ([llm] section in config.toml) config with the specified LLM config, e.g. "llama3" for [llm.llama3] section in config.toml',
)
parser.add_argument(
'-n',
'--name',
default='default',
type=str,
help='Name for the session',
)
return parser
def parse_arguments() -> argparse.Namespace:
"""Parse the command line arguments."""
parser = get_parser()
parsed_args, _ = parser.parse_known_args()
return parsed_args
def load_app_config(set_logging_levels: bool = True) -> AppConfig:
"""Load the configuration from the config.toml file and environment variables.
Args:
set_logging_levels: Whether to set the global variables for logging levels.
"""
config = AppConfig()
load_from_toml(config)
load_from_env(config, os.environ)
finalize_config(config)
if set_logging_levels:
logger.DEBUG = config.debug
logger.DISABLE_COLOR_PRINTING = config.disable_color
return config
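Note the call order above: `AppConfig()` defaults first, then `load_from_toml`, then `load_from_env`, so later sources override earlier ones. A minimal sketch of that precedence with illustrative keys:

```python
# Later sources win, mirroring the call order in load_app_config:
# AppConfig() defaults -> load_from_toml -> load_from_env.
defaults  = {'max_iterations': 100, 'debug': False}
from_toml = {'max_iterations': 50}
from_env  = {'debug': True}

effective = {**defaults, **from_toml, **from_env}
print(effective)  # {'max_iterations': 50, 'debug': True}
```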


@@ -0,0 +1 @@
TROUBLESHOOTING_URL = 'https://docs.all-hands.dev/modules/usage/troubleshooting'


@@ -0,0 +1,2 @@
# Run this file to trigger a model download
import agenthub # noqa F401 (we import this to get the agents registered)


@@ -0,0 +1,74 @@
class AgentNoInstructionError(Exception):
def __init__(self, message='Instruction must be provided'):
super().__init__(message)
class AgentEventTypeError(Exception):
def __init__(self, message='Event must be a dictionary'):
super().__init__(message)
class AgentAlreadyRegisteredError(Exception):
def __init__(self, name=None):
if name is not None:
message = f"Agent class already registered under '{name}'"
else:
message = 'Agent class already registered'
super().__init__(message)
class AgentNotRegisteredError(Exception):
def __init__(self, name=None):
if name is not None:
message = f"No agent class registered under '{name}'"
else:
message = 'No agent class registered'
super().__init__(message)
class TaskInvalidStateError(Exception):
def __init__(self, state=None):
if state is not None:
message = f'Invalid state {state}'
else:
message = 'Invalid state'
super().__init__(message)
class BrowserInitException(Exception):
def __init__(self, message='Failed to initialize browser environment'):
super().__init__(message)
class BrowserUnavailableException(Exception):
def __init__(
self,
message='Browser environment is not available, please check if it has been initialized',
):
super().__init__(message)
# This exception gets sent back to the LLM
# It might be malformed JSON
class LLMMalformedActionError(Exception):
def __init__(self, message='Malformed response'):
super().__init__(message)
# This exception gets sent back to the LLM
# For some reason, the agent did not return an action
class LLMNoActionError(Exception):
def __init__(self, message='Agent must return an action'):
super().__init__(message)
# This exception gets sent back to the LLM
# The LLM output did not include an action, or the action was not the expected type
class LLMResponseError(Exception):
def __init__(self, message='Failed to retrieve action from LLM response'):
super().__init__(message)
class UserCancelledError(Exception):
def __init__(self, message='User cancelled the request'):
super().__init__(message)

openhands/core/logger.py (new file, 255 lines)

@@ -0,0 +1,255 @@
import logging
import os
import re
import sys
import traceback
from datetime import datetime
from typing import Literal, Mapping
from termcolor import colored
DISABLE_COLOR_PRINTING = False
DEBUG = os.getenv('DEBUG', 'False').lower() in ['true', '1', 'yes']
LOG_TO_FILE = os.getenv('LOG_TO_FILE', 'False').lower() in ['true', '1', 'yes']
ColorType = Literal[
'red',
'green',
'yellow',
'blue',
'magenta',
'cyan',
'light_grey',
'dark_grey',
'light_red',
'light_green',
'light_yellow',
'light_blue',
'light_magenta',
'light_cyan',
'white',
]
LOG_COLORS: Mapping[str, ColorType] = {
'ACTION': 'green',
'USER_ACTION': 'light_red',
'OBSERVATION': 'yellow',
'USER_OBSERVATION': 'light_green',
'DETAIL': 'cyan',
'ERROR': 'red',
'PLAN': 'light_magenta',
}
class ColoredFormatter(logging.Formatter):
def format(self, record):
msg_type = record.__dict__.get('msg_type')
event_source = record.__dict__.get('event_source')
if event_source:
new_msg_type = f'{event_source.upper()}_{msg_type}'
if new_msg_type in LOG_COLORS:
msg_type = new_msg_type
if msg_type in LOG_COLORS and not DISABLE_COLOR_PRINTING:
msg_type_color = colored(msg_type, LOG_COLORS[msg_type])
msg = colored(record.msg, LOG_COLORS[msg_type])
time_str = colored(
self.formatTime(record, self.datefmt), LOG_COLORS[msg_type]
)
name_str = colored(record.name, LOG_COLORS[msg_type])
level_str = colored(record.levelname, LOG_COLORS[msg_type])
if msg_type in ['ERROR'] or DEBUG:
return f'{time_str} - {name_str}:{level_str}: {record.filename}:{record.lineno}\n{msg_type_color}\n{msg}'
return f'{time_str} - {msg_type_color}\n{msg}'
elif msg_type == 'STEP':
msg = '\n\n==============\n' + record.msg + '\n'
return f'{msg}'
return super().format(record)
console_formatter = ColoredFormatter(
'\033[92m%(asctime)s - %(name)s:%(levelname)s\033[0m: %(filename)s:%(lineno)s - %(message)s',
datefmt='%H:%M:%S',
)
file_formatter = logging.Formatter(
'%(asctime)s - %(name)s:%(levelname)s: %(filename)s:%(lineno)s - %(message)s',
datefmt='%H:%M:%S',
)
llm_formatter = logging.Formatter('%(message)s')
class SensitiveDataFilter(logging.Filter):
def filter(self, record):
# start with attributes
sensitive_patterns = [
'api_key',
'aws_access_key_id',
'aws_secret_access_key',
'e2b_api_key',
'github_token',
'jwt_secret',
]
# add env var names
env_vars = [attr.upper() for attr in sensitive_patterns]
sensitive_patterns.extend(env_vars)
# and some special cases
sensitive_patterns.append('JWT_SECRET')
sensitive_patterns.append('LLM_API_KEY')
sensitive_patterns.append('GITHUB_TOKEN')
sensitive_patterns.append('SANDBOX_ENV_GITHUB_TOKEN')
# this also formats the message with % args
msg = record.getMessage()
record.args = ()
for attr in sensitive_patterns:
pattern = rf"{attr}='?([\w-]+)'?"
msg = re.sub(pattern, f"{attr}='******'", msg)
# passed with msg
record.msg = msg
return True
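The masking in `SensitiveDataFilter.filter` is a plain regex substitution over the formatted message. A self-contained sketch of the same pattern (the key list here is abbreviated):

```python
import re

# The same masking pattern SensitiveDataFilter applies: key='value' or
# key=value, with the captured value replaced by asterisks.
def mask(msg: str, keys=('api_key', 'LLM_API_KEY')) -> str:
    for attr in keys:
        msg = re.sub(rf"{attr}='?([\w-]+)'?", f"{attr}='******'", msg)
    return msg

print(mask("calling with api_key='sk-abc123'"))
# -> calling with api_key='******'
```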
def get_console_handler():
"""Returns a console handler for logging."""
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
if DEBUG:
console_handler.setLevel(logging.DEBUG)
console_handler.setFormatter(console_formatter)
return console_handler
def get_file_handler(log_dir):
"""Returns a file handler for logging."""
os.makedirs(log_dir, exist_ok=True)
timestamp = datetime.now().strftime('%Y-%m-%d')
file_name = f'openhands_{timestamp}.log'
file_handler = logging.FileHandler(os.path.join(log_dir, file_name))
if DEBUG:
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(file_formatter)
return file_handler
# Set up logging
logging.basicConfig(level=logging.ERROR)
def log_uncaught_exceptions(ex_cls, ex, tb):
"""Logs uncaught exceptions along with the traceback.
Args:
ex_cls (type): The type of the exception.
ex (Exception): The exception instance.
tb (traceback): The traceback object.
Returns:
None
"""
logging.error(''.join(traceback.format_tb(tb)))
logging.error('{0}: {1}'.format(ex_cls, ex))
sys.excepthook = log_uncaught_exceptions
openhands_logger = logging.getLogger('openhands')
openhands_logger.setLevel(logging.INFO)
LOG_DIR = os.path.join(
# parent dir of openhands/core (i.e., root of the repo)
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
'logs',
)
if DEBUG:
openhands_logger.setLevel(logging.DEBUG)
if LOG_TO_FILE:
# default log to project root
openhands_logger.info('Logging to file is enabled. Logging to %s', LOG_DIR)
openhands_logger.addHandler(get_file_handler(LOG_DIR))
openhands_logger.addHandler(get_console_handler())
openhands_logger.addFilter(SensitiveDataFilter(openhands_logger.name))
openhands_logger.propagate = False
openhands_logger.debug('Logging initialized')
# Exclude LiteLLM from logging output
logging.getLogger('LiteLLM').disabled = True
logging.getLogger('LiteLLM Router').disabled = True
logging.getLogger('LiteLLM Proxy').disabled = True
class LlmFileHandler(logging.FileHandler):
"""# LLM prompt and response logging"""
def __init__(self, filename, mode='a', encoding='utf-8', delay=False):
"""Initializes an instance of LlmFileHandler.
Args:
filename (str): The name of the log file.
mode (str, optional): The file mode. Defaults to 'a'.
encoding (str, optional): The file encoding. Defaults to 'utf-8'.
delay (bool, optional): Whether to delay file opening. Defaults to False.
"""
self.filename = filename
self.message_counter = 1
if DEBUG:
self.session = datetime.now().strftime('%y-%m-%d_%H-%M')
else:
self.session = 'default'
self.log_directory = os.path.join(LOG_DIR, 'llm', self.session)
os.makedirs(self.log_directory, exist_ok=True)
if not DEBUG:
# Clear the log directory if not in debug mode
for file in os.listdir(self.log_directory):
file_path = os.path.join(self.log_directory, file)
try:
os.unlink(file_path)
except Exception as e:
openhands_logger.error(
'Failed to delete %s. Reason: %s', file_path, e
)
filename = f'{self.filename}_{self.message_counter:03}.log'
self.baseFilename = os.path.join(self.log_directory, filename)
super().__init__(self.baseFilename, mode, encoding, delay)
def emit(self, record):
"""Emits a log record.
Args:
record (logging.LogRecord): The log record to emit.
"""
filename = f'{self.filename}_{self.message_counter:03}.log'
self.baseFilename = os.path.join(self.log_directory, filename)
self.stream = self._open()
super().emit(record)
self.stream.close()
openhands_logger.debug('Logging to %s', self.baseFilename)
self.message_counter += 1
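`emit` above writes each record to its own sequentially numbered file under the session directory. A sketch of that naming scheme (the helper and paths are illustrative):

```python
import os

# Sketch of the per-message file naming used by LlmFileHandler.emit:
# each emitted record goes to its own zero-padded numbered file.
def log_paths(base: str, session: str, count: int) -> list[str]:
    return [
        os.path.join('logs', 'llm', session, f'{base}_{n:03}.log')
        for n in range(1, count + 1)
    ]

for p in log_paths('prompt', 'default', 3):
    print(p)  # ...prompt_001.log, prompt_002.log, prompt_003.log
```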
def _get_llm_file_handler(name, debug_level=logging.DEBUG):
# The 'delay' parameter, when set to True, postpones the opening of the log file
# until the first log message is emitted.
llm_file_handler = LlmFileHandler(name, delay=True)
llm_file_handler.setFormatter(llm_formatter)
llm_file_handler.setLevel(debug_level)
return llm_file_handler
def _setup_llm_logger(name, debug_level=logging.DEBUG):
logger = logging.getLogger(name)
logger.propagate = False
logger.setLevel(debug_level)
if LOG_TO_FILE:
logger.addHandler(_get_llm_file_handler(name, debug_level))
return logger
llm_prompt_logger = _setup_llm_logger('prompt', logging.DEBUG)
llm_response_logger = _setup_llm_logger('response', logging.DEBUG)

openhands/core/main.py (new file, 259 lines)

@@ -0,0 +1,259 @@
import asyncio
import hashlib
import sys
import uuid
from typing import Callable, Protocol, Type
import agenthub # noqa F401 (we import this to get the agents registered)
from openhands.controller import AgentController
from openhands.controller.agent import Agent
from openhands.controller.state.state import State
from openhands.core.config import (
AppConfig,
get_llm_config_arg,
load_app_config,
parse_arguments,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema import AgentState
from openhands.events import EventSource, EventStream, EventStreamSubscriber
from openhands.events.action import MessageAction
from openhands.events.action.action import Action
from openhands.events.event import Event
from openhands.events.observation import AgentStateChangedObservation
from openhands.llm.llm import LLM
from openhands.runtime import get_runtime_cls
from openhands.runtime.runtime import Runtime
from openhands.storage import get_file_store
class FakeUserResponseFunc(Protocol):
def __call__(
self,
state: State,
encapsulate_solution: bool = ...,
try_parse: Callable[[Action], str] = ...,
) -> str: ...
def read_task_from_file(file_path: str) -> str:
"""Read task from the specified file."""
with open(file_path, 'r', encoding='utf-8') as file:
return file.read()
def read_task_from_stdin() -> str:
"""Read task from stdin."""
return sys.stdin.read()
async def create_runtime(
config: AppConfig,
sid: str | None = None,
runtime_tools_config: dict | None = None,
) -> Runtime:
"""Create a runtime for the agent to run on.
config: The app config.
sid: The session id.
runtime_tools_config: (will be deprecated) The runtime tools config.
"""
# if sid is provided on the command line, use it as the name of the event stream
# otherwise generate it on the basis of the configured jwt_secret
# we can do this better, this is just so that the sid is retrieved when we want to restore the session
session_id = sid or generate_sid(config)
# set up the event stream
file_store = get_file_store(config.file_store, config.file_store_path)
event_stream = EventStream(session_id, file_store)
# agent class
agent_cls = agenthub.Agent.get_cls(config.default_agent)
# runtime and tools
runtime_cls = get_runtime_cls(config.runtime)
logger.info(f'Initializing runtime: {runtime_cls}')
runtime: Runtime = runtime_cls(
config=config,
event_stream=event_stream,
sid=session_id,
plugins=agent_cls.sandbox_plugins,
)
await runtime.ainit()
return runtime
async def run_controller(
config: AppConfig,
task_str: str,
sid: str | None = None,
runtime: Runtime | None = None,
agent: Agent | None = None,
exit_on_message: bool = False,
fake_user_response_fn: FakeUserResponseFunc | None = None,
headless_mode: bool = True,
) -> State | None:
"""Main coroutine to run the agent controller with task input flexibility.
It's only used when you launch the OpenHands backend directly via the command line.
Args:
config: The app config.
task_str: The task to run. It can be a string.
runtime: (optional) A runtime for the agent to run on.
agent: (optional) An agent to run.
exit_on_message: quit if agent asks for a message from user (optional)
fake_user_response_fn: An optional function that receives the current state
(could be None) and returns a fake user response.
headless_mode: Whether the agent is run in headless mode.
"""
# Create the agent
if agent is None:
agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
agent_config = config.get_agent_config(config.default_agent)
llm_config = config.get_llm_config_from_agent(config.default_agent)
agent = agent_cls(
llm=LLM(config=llm_config),
config=agent_config,
)
# make sure the session id is set
sid = sid or generate_sid(config)
if runtime is None:
runtime = await create_runtime(config, sid=sid)
event_stream = runtime.event_stream
# restore cli session if enabled
initial_state = None
if config.enable_cli_session:
try:
logger.info(f'Restoring agent state from cli session {event_stream.sid}')
initial_state = State.restore_from_session(
event_stream.sid, event_stream.file_store
)
except Exception as e:
logger.info(f'Error restoring state: {e}')
# init controller with this initial state
controller = AgentController(
agent=agent,
max_iterations=config.max_iterations,
max_budget_per_task=config.max_budget_per_task,
agent_to_llm_config=config.get_agent_to_llm_config_map(),
event_stream=event_stream,
initial_state=initial_state,
headless_mode=headless_mode,
)
assert isinstance(task_str, str), f'task_str must be a string, got {type(task_str)}'
# Logging
logger.info(
f'Agent Controller Initialized: Running agent {agent.name}, model '
f'{agent.llm.config.model}, with task: "{task_str}"'
)
# start event is a MessageAction with the task, either resumed or new
if config.enable_cli_session and initial_state is not None:
# we're resuming the previous session
event_stream.add_event(
MessageAction(
content=(
"Let's get back on track. If you experienced errors before, do "
'NOT resume your task. Ask me about it.'
),
),
EventSource.USER,
)
elif initial_state is None:
# init with the provided task
event_stream.add_event(MessageAction(content=task_str), EventSource.USER)
async def on_event(event: Event):
if isinstance(event, AgentStateChangedObservation):
if event.agent_state == AgentState.AWAITING_USER_INPUT:
if exit_on_message:
message = '/exit'
elif fake_user_response_fn is None:
message = input('Request user input >> ')
else:
message = fake_user_response_fn(controller.get_state())
action = MessageAction(content=message)
event_stream.add_event(action, EventSource.USER)
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event)
while controller.state.agent_state not in [
AgentState.FINISHED,
AgentState.REJECTED,
AgentState.ERROR,
AgentState.PAUSED,
AgentState.STOPPED,
]:
await asyncio.sleep(1) # Give back control for a tick, so the agent can run
# save session when we're about to close
if config.enable_cli_session:
end_state = controller.get_state()
end_state.save_to_session(event_stream.sid, event_stream.file_store)
# close when done
await controller.close()
state = controller.get_state()
return state
def generate_sid(config: AppConfig, session_name: str | None = None) -> str:
"""Generate a session id based on the session name and the jwt secret."""
session_name = session_name or str(uuid.uuid4())
jwt_secret = config.jwt_secret
hash_str = hashlib.sha256(f'{session_name}{jwt_secret}'.encode('utf-8')).hexdigest()
return f'{session_name}_{hash_str[:16]}'
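`generate_sid` derives a stable suffix from the session name and the JWT secret, so the same inputs always reproduce the same sid. A self-contained sketch of the hashing step:

```python
import hashlib

# Sketch of generate_sid: a deterministic 16-hex-char suffix derived
# from the session name and the jwt secret.
def make_sid(name: str, jwt_secret: str) -> str:
    digest = hashlib.sha256(f'{name}{jwt_secret}'.encode('utf-8')).hexdigest()
    return f'{name}_{digest[:16]}'

sid = make_sid('my-session', 'secret')
print(sid)  # same name + secret always yields the same sid
```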
if __name__ == '__main__':
args = parse_arguments()
# Determine the task
if args.file:
task_str = read_task_from_file(args.file)
elif args.task:
task_str = args.task
elif not sys.stdin.isatty():
task_str = read_task_from_stdin()
else:
raise ValueError('No task provided. Please specify a task through -t, -f.')
# Load the app config
# this will load config from config.toml in the current directory
# as well as from the environment variables
config = load_app_config()
# Override default LLM configs ([llm] section in config.toml)
if args.llm_config:
llm_config = get_llm_config_arg(args.llm_config)
if llm_config is None:
raise ValueError(f'Invalid toml file, cannot read {args.llm_config}')
config.set_llm_config(llm_config)
# Set default agent
config.default_agent = args.agent_cls
# Set session name
session_name = args.name
sid = generate_sid(config, session_name)
# if max budget per task is not sent on the command line, use the config value
if args.max_budget_per_task is not None:
config.max_budget_per_task = args.max_budget_per_task
if args.max_iterations is not None:
config.max_iterations = args.max_iterations
asyncio.run(
run_controller(
config=config,
task_str=task_str,
sid=sid,
)
)

openhands/core/message.py (new file, 59 lines)

@@ -0,0 +1,59 @@
from enum import Enum
from pydantic import BaseModel, Field, model_serializer
from typing_extensions import Literal
class ContentType(Enum):
TEXT = 'text'
IMAGE_URL = 'image_url'
class Content(BaseModel):
type: ContentType
@model_serializer
def serialize_model(self):
raise NotImplementedError('Subclasses should implement this method.')
class TextContent(Content):
type: ContentType = ContentType.TEXT
text: str
@model_serializer
def serialize_model(self):
return {'type': self.type.value, 'text': self.text}
class ImageContent(Content):
type: ContentType = ContentType.IMAGE_URL
image_urls: list[str]
@model_serializer
def serialize_model(self):
images: list[dict[str, str | dict[str, str]]] = []
for url in self.image_urls:
images.append({'type': self.type.value, 'image_url': {'url': url}})
return images
class Message(BaseModel):
role: Literal['user', 'system', 'assistant']
content: list[TextContent | ImageContent] = Field(default_factory=list)
@property
def contains_image(self) -> bool:
return any(isinstance(content, ImageContent) for content in self.content)
@model_serializer
def serialize_model(self) -> dict:
content: list[dict[str, str | dict[str, str]]] = []
for item in self.content:
if isinstance(item, TextContent):
content.append(item.model_dump())
elif isinstance(item, ImageContent):
content.extend(item.model_dump())
return {'role': self.role, 'content': content}
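The serializers above produce the flat message format LiteLLM expects: each text item becomes one dict, while an image item expands into one dict per URL. A pure-Python sketch of that wire shape, without the pydantic models (names here are illustrative):

```python
# Sketch of the output shape of Message.serialize_model: text parts map
# 1:1 to dicts, image parts fan out to one dict per URL.
def serialize(role, parts):
    content = []
    for kind, payload in parts:
        if kind == 'text':
            content.append({'type': 'text', 'text': payload})
        elif kind == 'image':
            content.extend(
                {'type': 'image_url', 'image_url': {'url': url}} for url in payload
            )
    return {'role': role, 'content': content}

msg = serialize('user', [('text', 'describe this'), ('image', ['https://a/1.png'])])
print(msg['content'])
```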

openhands/core/metrics.py (new file, 48 lines)

@@ -0,0 +1,48 @@
class Metrics:
"""Metrics class can record various metrics during running and evaluation.
Currently, we define the following metrics:
accumulated_cost: the total cost (USD $) of the current LLM.
"""
def __init__(self) -> None:
self._accumulated_cost: float = 0.0
self._costs: list[float] = []
@property
def accumulated_cost(self) -> float:
return self._accumulated_cost
@accumulated_cost.setter
def accumulated_cost(self, value: float) -> None:
if value < 0:
raise ValueError('Total cost cannot be negative.')
self._accumulated_cost = value
@property
def costs(self) -> list:
return self._costs
def add_cost(self, value: float) -> None:
if value < 0:
raise ValueError('Added cost cannot be negative.')
self._accumulated_cost += value
self._costs.append(value)
def merge(self, other: 'Metrics') -> None:
self._accumulated_cost += other.accumulated_cost
self._costs += other._costs
def get(self):
"""Return the metrics in a dictionary."""
return {'accumulated_cost': self._accumulated_cost, 'costs': self._costs}
def log(self):
"""Log the metrics."""
metrics = self.get()
logs = ''
for key, value in metrics.items():
logs += f'{key}: {value}\n'
return logs
def __repr__(self):
return f'Metrics({self.get()})'
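A short usage sketch of the accounting above: `add_cost` rejects negative values, and `merge` folds one accumulator into another (the stripped-down class below mirrors the relevant behavior for illustration):

```python
# Minimal sketch of the Metrics accounting: add_cost rejects negatives,
# merge sums the accumulated cost and concatenates the cost history.
class MiniMetrics:
    def __init__(self):
        self.accumulated_cost = 0.0
        self.costs = []

    def add_cost(self, value):
        if value < 0:
            raise ValueError('Added cost cannot be negative.')
        self.accumulated_cost += value
        self.costs.append(value)

    def merge(self, other):
        self.accumulated_cost += other.accumulated_cost
        self.costs += other.costs

a, b = MiniMetrics(), MiniMetrics()
a.add_cost(0.10)
b.add_cost(0.25)
a.merge(b)
print(round(a.accumulated_cost, 2))  # 0.35
```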


@@ -0,0 +1,11 @@
from .action import ActionType
from .agent import AgentState
from .config import ConfigType
from .observation import ObservationType
__all__ = [
'ActionType',
'ObservationType',
'ConfigType',
'AgentState',
]


@@ -0,0 +1,86 @@
from pydantic import BaseModel, Field
__all__ = ['ActionType']
class ActionTypeSchema(BaseModel):
INIT: str = Field(default='initialize')
"""Initializes the agent. Only sent by client.
"""
MESSAGE: str = Field(default='message')
"""Represents a message.
"""
START: str = Field(default='start')
"""Starts a new development task OR send chat from the user. Only sent by the client.
"""
READ: str = Field(default='read')
"""Reads the content of a file.
"""
WRITE: str = Field(default='write')
"""Writes the content to a file.
"""
RUN: str = Field(default='run')
"""Runs a command.
"""
RUN_IPYTHON: str = Field(default='run_ipython')
"""Runs a IPython cell.
"""
BROWSE: str = Field(default='browse')
"""Opens a web page.
"""
BROWSE_INTERACTIVE: str = Field(default='browse_interactive')
"""Interact with the browser instance.
"""
DELEGATE: str = Field(default='delegate')
"""Delegates a task to another agent.
"""
FINISH: str = Field(default='finish')
"""If you're absolutely certain that you've completed your task and have tested your work,
use the finish action to stop working.
"""
REJECT: str = Field(default='reject')
"""If you're absolutely certain that you cannot complete the task with given requirements,
use the reject action to stop working.
"""
NULL: str = Field(default='null')
SUMMARIZE: str = Field(default='summarize')
ADD_TASK: str = Field(default='add_task')
MODIFY_TASK: str = Field(default='modify_task')
PAUSE: str = Field(default='pause')
"""Pauses the task.
"""
RESUME: str = Field(default='resume')
"""Resumes the task.
"""
STOP: str = Field(default='stop')
"""Stops the task. Must send a start action to restart a new task.
"""
CHANGE_AGENT_STATE: str = Field(default='change_agent_state')
PUSH: str = Field(default='push')
"""Push a branch to github."""
SEND_PR: str = Field(default='send_pr')
"""Send a PR to github."""
ActionType = ActionTypeSchema()


@@ -0,0 +1,51 @@
from enum import Enum
class AgentState(str, Enum):
LOADING = 'loading'
"""The agent is loading.
"""
INIT = 'init'
"""The agent is initialized.
"""
RUNNING = 'running'
"""The agent is running.
"""
AWAITING_USER_INPUT = 'awaiting_user_input'
"""The agent is awaiting user input.
"""
PAUSED = 'paused'
"""The agent is paused.
"""
STOPPED = 'stopped'
"""The agent is stopped.
"""
FINISHED = 'finished'
"""The agent is finished with the current task.
"""
REJECTED = 'rejected'
"""The agent rejects the task.
"""
ERROR = 'error'
"""An error occurred during the task.
"""
AWAITING_USER_CONFIRMATION = 'awaiting_user_confirmation'
"""The agent is awaiting user confirmation.
"""
USER_CONFIRMED = 'user_confirmed'
"""The user confirmed the agent's action.
"""
USER_REJECTED = 'user_rejected'
"""The user rejected the agent's action.
"""


@@ -0,0 +1,47 @@
from enum import Enum
class ConfigType(str, Enum):
# For frontend
LLM_CUSTOM_LLM_PROVIDER = 'LLM_CUSTOM_LLM_PROVIDER'
LLM_DROP_PARAMS = 'LLM_DROP_PARAMS'
LLM_MAX_INPUT_TOKENS = 'LLM_MAX_INPUT_TOKENS'
LLM_MAX_OUTPUT_TOKENS = 'LLM_MAX_OUTPUT_TOKENS'
LLM_TOP_P = 'LLM_TOP_P'
LLM_TEMPERATURE = 'LLM_TEMPERATURE'
LLM_TIMEOUT = 'LLM_TIMEOUT'
LLM_API_KEY = 'LLM_API_KEY'
LLM_BASE_URL = 'LLM_BASE_URL'
AWS_ACCESS_KEY_ID = 'AWS_ACCESS_KEY_ID'
AWS_SECRET_ACCESS_KEY = 'AWS_SECRET_ACCESS_KEY'
AWS_REGION_NAME = 'AWS_REGION_NAME'
WORKSPACE_BASE = 'WORKSPACE_BASE'
WORKSPACE_MOUNT_PATH = 'WORKSPACE_MOUNT_PATH'
WORKSPACE_MOUNT_REWRITE = 'WORKSPACE_MOUNT_REWRITE'
WORKSPACE_MOUNT_PATH_IN_SANDBOX = 'WORKSPACE_MOUNT_PATH_IN_SANDBOX'
CACHE_DIR = 'CACHE_DIR'
LLM_MODEL = 'LLM_MODEL'
CONFIRMATION_MODE = 'CONFIRMATION_MODE'
SANDBOX_CONTAINER_IMAGE = 'SANDBOX_CONTAINER_IMAGE'
RUN_AS_OPENHANDS = 'RUN_AS_OPENHANDS'
LLM_EMBEDDING_MODEL = 'LLM_EMBEDDING_MODEL'
LLM_EMBEDDING_BASE_URL = 'LLM_EMBEDDING_BASE_URL'
LLM_EMBEDDING_DEPLOYMENT_NAME = 'LLM_EMBEDDING_DEPLOYMENT_NAME'
LLM_API_VERSION = 'LLM_API_VERSION'
LLM_NUM_RETRIES = 'LLM_NUM_RETRIES'
LLM_RETRY_MIN_WAIT = 'LLM_RETRY_MIN_WAIT'
LLM_RETRY_MAX_WAIT = 'LLM_RETRY_MAX_WAIT'
AGENT_MEMORY_MAX_THREADS = 'AGENT_MEMORY_MAX_THREADS'
AGENT_MEMORY_ENABLED = 'AGENT_MEMORY_ENABLED'
MAX_ITERATIONS = 'MAX_ITERATIONS'
AGENT = 'AGENT'
E2B_API_KEY = 'E2B_API_KEY'
SECURITY_ANALYZER = 'SECURITY_ANALYZER'
SANDBOX_USER_ID = 'SANDBOX_USER_ID'
SANDBOX_TIMEOUT = 'SANDBOX_TIMEOUT'
USE_HOST_NETWORK = 'USE_HOST_NETWORK'
DISABLE_COLOR = 'DISABLE_COLOR'
DEBUG = 'DEBUG'
FILE_UPLOADS_MAX_FILE_SIZE_MB = 'FILE_UPLOADS_MAX_FILE_SIZE_MB'
FILE_UPLOADS_RESTRICT_FILE_TYPES = 'FILE_UPLOADS_RESTRICT_FILE_TYPES'
FILE_UPLOADS_ALLOWED_EXTENSIONS = 'FILE_UPLOADS_ALLOWED_EXTENSIONS'


@@ -0,0 +1,46 @@
from pydantic import BaseModel, Field
__all__ = ['ObservationType']
class ObservationTypeSchema(BaseModel):
READ: str = Field(default='read')
"""The content of a file
"""
WRITE: str = Field(default='write')
BROWSE: str = Field(default='browse')
"""The HTML content of a URL
"""
RUN: str = Field(default='run')
"""The output of a command
"""
RUN_IPYTHON: str = Field(default='run_ipython')
"""Runs a IPython cell.
"""
CHAT: str = Field(default='chat')
"""A message from the user
"""
DELEGATE: str = Field(default='delegate')
"""The result of a task delegated to another agent
"""
MESSAGE: str = Field(default='message')
ERROR: str = Field(default='error')
SUCCESS: str = Field(default='success')
NULL: str = Field(default='null')
AGENT_STATE_CHANGED: str = Field(default='agent_state_changed')
USER_REJECTED: str = Field(default='user_rejected')
ObservationType = ObservationTypeSchema()


@@ -0,0 +1,3 @@
from .singleton import Singleton
__all__ = ['Singleton']

View File

@@ -0,0 +1,49 @@
import json
from datetime import datetime
from json_repair import repair_json
from openhands.core.exceptions import LLMResponseError
from openhands.events.event import Event
from openhands.events.serialization import event_to_dict
def my_default_encoder(obj):
"""Custom JSON encoder that handles datetime and event objects"""
if isinstance(obj, datetime):
return obj.isoformat()
if isinstance(obj, Event):
return event_to_dict(obj)
return json.JSONEncoder().default(obj)
def dumps(obj, **kwargs):
"""Serialize an object to str format"""
return json.dumps(obj, default=my_default_encoder, **kwargs)
def loads(json_str, **kwargs):
"""Create a JSON object from str"""
try:
return json.loads(json_str, **kwargs)
except json.JSONDecodeError:
pass
depth = 0
start = -1
for i, char in enumerate(json_str):
if char == '{':
if depth == 0:
start = i
depth += 1
elif char == '}':
depth -= 1
if depth == 0 and start != -1:
response = json_str[start : i + 1]
try:
json_str = repair_json(response)
return json.loads(json_str, **kwargs)
except (json.JSONDecodeError, ValueError, TypeError) as e:
raise LLMResponseError(
'Invalid JSON in response. Please make sure the response is a valid JSON object.'
) from e
raise LLMResponseError('No valid JSON object found in response.')
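
The fallback path in `loads` scans for the first balanced `{…}` region before handing it to `json_repair`. A minimal stdlib-only sketch of that brace-matching step (omitting the repair pass, and naive about braces inside string literals):

```python
import json

def extract_first_json_object(text: str):
    """Return the first balanced {...} region in `text`, parsed as JSON."""
    depth = 0
    start = -1
    for i, char in enumerate(text):
        if char == '{':
            if depth == 0:
                start = i
            depth += 1
        elif char == '}':
            depth -= 1
            if depth == 0 and start != -1:
                # found a complete top-level object; parse just that slice
                return json.loads(text[start : i + 1])
    raise ValueError('No valid JSON object found in response.')

print(extract_first_json_object(
    'Sure! Here is the plan: {"action": "run", "args": {"command": "ls"}} Done.'
))
```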

View File

@@ -0,0 +1,37 @@
import dataclasses
from openhands.core import logger
class Singleton(type):
_instances: dict = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
else:
# allow updates, just update existing instance
# perhaps not the most orthodox way to do it, though it simplifies client code
# useful for pre-defined groups of settings
instance = cls._instances[cls]
for key, value in kwargs.items():
if hasattr(instance, key):
setattr(instance, key, value)
else:
logger.openhands_logger.warning(
f'Unknown key for {cls.__name__}: "{key}"'
)
return cls._instances[cls]
@classmethod
def reset(cls):
# used by pytest to reset the state of the singleton instances
for instance_type, instance in cls._instances.items():
print('resetting... ', instance_type)
for field_info in dataclasses.fields(instance_type):
if dataclasses.is_dataclass(field_info.type):
setattr(instance, field_info.name, field_info.type())
elif field_info.default_factory is not dataclasses.MISSING:
setattr(instance, field_info.name, field_info.default_factory())
else:
setattr(instance, field_info.name, field_info.default)
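
A small self-contained demonstration of the metaclass behavior above: the second constructor call returns the same instance, with keyword arguments applied as updates (the `AppSettings` class here is hypothetical, for illustration only):

```python
class Singleton(type):
    """Metaclass: one instance per class; later calls update that instance."""
    _instances: dict = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        else:
            # update the existing instance instead of creating a new one
            instance = cls._instances[cls]
            for key, value in kwargs.items():
                if hasattr(instance, key):
                    setattr(instance, key, value)
        return cls._instances[cls]

class AppSettings(metaclass=Singleton):
    def __init__(self, debug=False):
        self.debug = debug

a = AppSettings(debug=True)
b = AppSettings(debug=False)  # updates the existing instance in place
print(a is b, a.debug)  # → True False
```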

View File

@@ -0,0 +1,9 @@
from .event import Event, EventSource
from .stream import EventStream, EventStreamSubscriber
__all__ = [
'Event',
'EventSource',
'EventStream',
'EventStreamSubscriber',
]

View File

@@ -0,0 +1,34 @@
from .action import Action, ActionConfirmationStatus
from .agent import (
AgentDelegateAction,
AgentFinishAction,
AgentRejectAction,
AgentSummarizeAction,
ChangeAgentStateAction,
)
from .browse import BrowseInteractiveAction, BrowseURLAction
from .commands import CmdRunAction, IPythonRunCellAction
from .empty import NullAction
from .files import FileReadAction, FileWriteAction
from .message import MessageAction
from .tasks import AddTaskAction, ModifyTaskAction
__all__ = [
'Action',
'NullAction',
'CmdRunAction',
'BrowseURLAction',
'BrowseInteractiveAction',
'FileReadAction',
'FileWriteAction',
'AgentFinishAction',
'AgentRejectAction',
'AgentDelegateAction',
'AgentSummarizeAction',
'AddTaskAction',
'ModifyTaskAction',
'ChangeAgentStateAction',
'IPythonRunCellAction',
'MessageAction',
'ActionConfirmationStatus',
]

View File

@@ -0,0 +1,23 @@
from dataclasses import dataclass
from enum import Enum
from typing import ClassVar
from openhands.events.event import Event
class ActionConfirmationStatus(str, Enum):
CONFIRMED = 'confirmed'
REJECTED = 'rejected'
AWAITING_CONFIRMATION = 'awaiting_confirmation'
class ActionSecurityRisk(int, Enum):
UNKNOWN = -1
LOW = 0
MEDIUM = 1
HIGH = 2
@dataclass
class Action(Event):
runnable: ClassVar[bool] = False

View File

@@ -0,0 +1,81 @@
from dataclasses import dataclass, field
from typing import Any
from openhands.core.schema import ActionType
from .action import Action
@dataclass
class ChangeAgentStateAction(Action):
"""Fake action, just to notify the client that a task state has changed."""
agent_state: str
thought: str = ''
action: str = ActionType.CHANGE_AGENT_STATE
@property
def message(self) -> str:
return f'Agent state changed to {self.agent_state}'
@dataclass
class AgentSummarizeAction(Action):
summary: str
action: str = ActionType.SUMMARIZE
@property
def message(self) -> str:
return self.summary
def __str__(self) -> str:
ret = '**AgentSummarizeAction**\n'
ret += f'SUMMARY: {self.summary}'
return ret
@dataclass
class AgentFinishAction(Action):
"""An action where the agent finishes the task.
Attributes:
outputs (dict): The outputs of the agent, for instance "content".
thought (str): The agent's explanation of its actions.
action (str): The action type, namely ActionType.FINISH.
"""
outputs: dict[str, Any] = field(default_factory=dict)
thought: str = ''
action: str = ActionType.FINISH
@property
def message(self) -> str:
if self.thought != '':
return self.thought
return "All done! What's next on the agenda?"
@dataclass
class AgentRejectAction(Action):
outputs: dict = field(default_factory=dict)
thought: str = ''
action: str = ActionType.REJECT
@property
def message(self) -> str:
msg: str = 'Task is rejected by the agent.'
if 'reason' in self.outputs:
msg += ' Reason: ' + self.outputs['reason']
return msg
@dataclass
class AgentDelegateAction(Action):
agent: str
inputs: dict
thought: str = ''
action: str = ActionType.DELEGATE
@property
def message(self) -> str:
return f"I'm asking {self.agent} for help with this task."

View File

@@ -0,0 +1,47 @@
from dataclasses import dataclass
from typing import ClassVar
from openhands.core.schema import ActionType
from .action import Action, ActionSecurityRisk
@dataclass
class BrowseURLAction(Action):
url: str
thought: str = ''
action: str = ActionType.BROWSE
runnable: ClassVar[bool] = True
security_risk: ActionSecurityRisk | None = None
@property
def message(self) -> str:
return f'Browsing URL: {self.url}'
def __str__(self) -> str:
ret = '**BrowseURLAction**\n'
if self.thought:
ret += f'THOUGHT: {self.thought}\n'
ret += f'URL: {self.url}'
return ret
@dataclass
class BrowseInteractiveAction(Action):
browser_actions: str
thought: str = ''
browsergym_send_msg_to_user: str = ''
action: str = ActionType.BROWSE_INTERACTIVE
runnable: ClassVar[bool] = True
security_risk: ActionSecurityRisk | None = None
@property
def message(self) -> str:
return f'Executing browser actions: {self.browser_actions}'
def __str__(self) -> str:
ret = '**BrowseInteractiveAction**\n'
if self.thought:
ret += f'THOUGHT: {self.thought}\n'
ret += f'BROWSER_ACTIONS: {self.browser_actions}'
return ret

View File

@@ -0,0 +1,57 @@
from dataclasses import dataclass
from typing import ClassVar
from openhands.core.schema import ActionType
from .action import Action, ActionConfirmationStatus, ActionSecurityRisk
@dataclass
class CmdRunAction(Action):
command: str
thought: str = ''
keep_prompt: bool = True
# if True, the command prompt will be kept in the command output observation
# Example of command output:
# root@sandbox:~# ls
# file1.txt
# file2.txt
# root@sandbox:~# <-- this is the command prompt
action: str = ActionType.RUN
runnable: ClassVar[bool] = True
is_confirmed: ActionConfirmationStatus = ActionConfirmationStatus.CONFIRMED
security_risk: ActionSecurityRisk | None = None
@property
def message(self) -> str:
return f'Running command: {self.command}'
def __str__(self) -> str:
ret = f'**CmdRunAction (source={self.source})**\n'
if self.thought:
ret += f'THOUGHT: {self.thought}\n'
ret += f'COMMAND:\n{self.command}'
return ret
@dataclass
class IPythonRunCellAction(Action):
code: str
thought: str = ''
action: str = ActionType.RUN_IPYTHON
runnable: ClassVar[bool] = True
is_confirmed: ActionConfirmationStatus = ActionConfirmationStatus.CONFIRMED
security_risk: ActionSecurityRisk | None = None
kernel_init_code: str = '' # code to run in the kernel (if the kernel is restarted)
def __str__(self) -> str:
ret = '**IPythonRunCellAction**\n'
if self.thought:
ret += f'THOUGHT: {self.thought}\n'
ret += f'CODE:\n{self.code}'
return ret
@property
def message(self) -> str:
return f'Running Python code interactively: {self.code}'

View File

@@ -0,0 +1,16 @@
from dataclasses import dataclass
from openhands.core.schema import ActionType
from .action import Action
@dataclass
class NullAction(Action):
"""An action that does nothing."""
action: str = ActionType.NULL
@property
def message(self) -> str:
return 'No action'

View File

@@ -0,0 +1,42 @@
from dataclasses import dataclass
from typing import ClassVar
from openhands.core.schema import ActionType
from .action import Action, ActionSecurityRisk
@dataclass
class FileReadAction(Action):
"""Reads a file from a given path.
Can be set to read specific lines using `start` and `end`.
Defaults to lines 0:-1 (the whole file).
"""
path: str
start: int = 0
end: int = -1
thought: str = ''
action: str = ActionType.READ
runnable: ClassVar[bool] = True
security_risk: ActionSecurityRisk | None = None
@property
def message(self) -> str:
return f'Reading file: {self.path}'
@dataclass
class FileWriteAction(Action):
path: str
content: str
start: int = 0
end: int = -1
thought: str = ''
action: str = ActionType.WRITE
runnable: ClassVar[bool] = True
security_risk: ActionSecurityRisk | None = None
@property
def message(self) -> str:
return f'Writing file: {self.path}'

View File

@@ -0,0 +1,26 @@
from dataclasses import dataclass
from openhands.core.schema import ActionType
from .action import Action, ActionSecurityRisk
@dataclass
class MessageAction(Action):
content: str
images_urls: list | None = None
wait_for_response: bool = False
action: str = ActionType.MESSAGE
security_risk: ActionSecurityRisk | None = None
@property
def message(self) -> str:
return self.content
def __str__(self) -> str:
ret = f'**MessageAction** (source={self.source})\n'
ret += f'CONTENT: {self.content}'
if self.images_urls:
for url in self.images_urls:
ret += f'\nIMAGE_URL: {url}'
return ret

View File

@@ -0,0 +1,30 @@
from dataclasses import dataclass, field
from openhands.core.schema import ActionType
from .action import Action
@dataclass
class AddTaskAction(Action):
parent: str
goal: str
subtasks: list = field(default_factory=list)
thought: str = ''
action: str = ActionType.ADD_TASK
@property
def message(self) -> str:
return f'Added task: {self.goal}'
@dataclass
class ModifyTaskAction(Action):
task_id: str
state: str
thought: str = ''
action: str = ActionType.MODIFY_TASK
@property
def message(self) -> str:
return f'Set task {self.task_id} to {self.state}'

51
openhands/events/event.py Normal file
View File

@@ -0,0 +1,51 @@
import datetime
from dataclasses import dataclass
from enum import Enum
class EventSource(str, Enum):
AGENT = 'agent'
USER = 'user'
@dataclass
class Event:
@property
def message(self) -> str | None:
if hasattr(self, '_message'):
return self._message # type: ignore[attr-defined]
return ''
@property
def id(self) -> int:
if hasattr(self, '_id'):
return self._id # type: ignore[attr-defined]
return -1
@property
def timestamp(self) -> datetime.datetime | None:
if hasattr(self, '_timestamp'):
return self._timestamp # type: ignore[attr-defined]
return None
@property
def source(self) -> EventSource | None:
if hasattr(self, '_source'):
return self._source # type: ignore[attr-defined]
return None
@property
def cause(self) -> int | None:
if hasattr(self, '_cause'):
return self._cause # type: ignore[attr-defined]
return None
@property
def timeout(self) -> int | None:
if hasattr(self, '_timeout'):
return self._timeout # type: ignore[attr-defined]
return None
@timeout.setter
def timeout(self, value: int | None) -> None:
self._timeout = value
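
Event metadata (`_id`, `_timestamp`, `_source`, …) is attached externally by the stream and exposed read-only through `hasattr`-guarded properties. A minimal illustration of that pattern (the `Note` class is hypothetical):

```python
from dataclasses import dataclass

@dataclass
class Note:
    text: str

    @property
    def id(self) -> int:
        # metadata is attached later (e.g. by a stream), so guard with hasattr
        if hasattr(self, '_id'):
            return self._id
        return -1

n = Note('hello')
print(n.id)   # → -1, not yet assigned
n._id = 42    # the stream would set this when persisting the event
print(n.id)   # → 42
```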

View File

@@ -0,0 +1,25 @@
from .agent import AgentStateChangedObservation
from .browse import BrowserOutputObservation
from .commands import CmdOutputObservation, IPythonRunCellObservation
from .delegate import AgentDelegateObservation
from .empty import NullObservation
from .error import ErrorObservation
from .files import FileReadObservation, FileWriteObservation
from .observation import Observation
from .reject import UserRejectObservation
from .success import SuccessObservation
__all__ = [
'Observation',
'NullObservation',
'CmdOutputObservation',
'IPythonRunCellObservation',
'BrowserOutputObservation',
'FileReadObservation',
'FileWriteObservation',
'ErrorObservation',
'AgentStateChangedObservation',
'AgentDelegateObservation',
'SuccessObservation',
'UserRejectObservation',
]

View File

@@ -0,0 +1,17 @@
from dataclasses import dataclass
from openhands.core.schema import ObservationType
from .observation import Observation
@dataclass
class AgentStateChangedObservation(Observation):
"""This data class represents the result from delegating to another agent"""
agent_state: str
observation: str = ObservationType.AGENT_STATE_CHANGED
@property
def message(self) -> str:
return ''

View File

@@ -0,0 +1,44 @@
from dataclasses import dataclass, field
from openhands.core.schema import ObservationType
from .observation import Observation
@dataclass
class BrowserOutputObservation(Observation):
"""This data class represents the output of a browser."""
url: str
screenshot: str = field(repr=False) # don't show in repr
error: bool = False
observation: str = ObservationType.BROWSE
# do not include in the memory
open_pages_urls: list = field(default_factory=list)
active_page_index: int = -1
dom_object: dict = field(default_factory=dict, repr=False) # don't show in repr
axtree_object: dict = field(default_factory=dict, repr=False) # don't show in repr
extra_element_properties: dict = field(
default_factory=dict, repr=False
) # don't show in repr
last_browser_action: str = ''
last_browser_action_error: str = ''
focused_element_bid: str = ''
@property
def message(self) -> str:
return 'Visited ' + self.url
def __str__(self) -> str:
return (
'**BrowserOutputObservation**\n'
f'URL: {self.url}\n'
f'Error: {self.error}\n'
f'Open pages: {self.open_pages_urls}\n'
f'Active page index: {self.active_page_index}\n'
f'Last browser action: {self.last_browser_action}\n'
f'Last browser action error: {self.last_browser_action_error}\n'
f'Focused element bid: {self.focused_element_bid}\n'
f'axTree: {self.axtree_object}\n'
f'CONTENT: {self.content}\n'
)

View File

@@ -0,0 +1,45 @@
from dataclasses import dataclass
from openhands.core.schema import ObservationType
from .observation import Observation
@dataclass
class CmdOutputObservation(Observation):
"""This data class represents the output of a command."""
command_id: int
command: str
exit_code: int = 0
observation: str = ObservationType.RUN
@property
def error(self) -> bool:
return self.exit_code != 0
@property
def message(self) -> str:
return f'Command `{self.command}` executed with exit code {self.exit_code}.'
def __str__(self) -> str:
return f'**CmdOutputObservation (source={self.source}, exit code={self.exit_code})**\n{self.content}'
@dataclass
class IPythonRunCellObservation(Observation):
"""This data class represents the output of a IPythonRunCellAction."""
code: str
observation: str = ObservationType.RUN_IPYTHON
@property
def error(self) -> bool:
return False # IPython cells do not return exit codes
@property
def message(self) -> str:
return 'Code executed in IPython cell.'
def __str__(self) -> str:
return f'**IPythonRunCellObservation**\n{self.content}'

View File

@@ -0,0 +1,23 @@
from dataclasses import dataclass
from openhands.core.schema import ObservationType
from .observation import Observation
@dataclass
class AgentDelegateObservation(Observation):
"""This data class represents the result from delegating to another agent.
Attributes:
content (str): The content of the observation.
outputs (dict): The outputs of the delegated agent.
observation (str): The type of observation.
"""
outputs: dict
observation: str = ObservationType.DELEGATE
@property
def message(self) -> str:
return ''

View File

@@ -0,0 +1,18 @@
from dataclasses import dataclass
from openhands.core.schema import ObservationType
from .observation import Observation
@dataclass
class NullObservation(Observation):
"""This data class represents a null observation.
This is used when the produced action is NOT executable.
"""
observation: str = ObservationType.NULL
@property
def message(self) -> str:
return 'No observation'

View File

@@ -0,0 +1,16 @@
from dataclasses import dataclass
from openhands.core.schema import ObservationType
from .observation import Observation
@dataclass
class ErrorObservation(Observation):
"""This data class represents an error encountered by the agent."""
observation: str = ObservationType.ERROR
@property
def message(self) -> str:
return self.content

View File

@@ -0,0 +1,29 @@
from dataclasses import dataclass
from openhands.core.schema import ObservationType
from .observation import Observation
@dataclass
class FileReadObservation(Observation):
"""This data class represents the content of a file."""
path: str
observation: str = ObservationType.READ
@property
def message(self) -> str:
return f'I read the file {self.path}.'
@dataclass
class FileWriteObservation(Observation):
"""This data class represents a file write operation"""
path: str
observation: str = ObservationType.WRITE
@property
def message(self) -> str:
return f'I wrote to the file {self.path}.'

View File

@@ -0,0 +1,8 @@
from dataclasses import dataclass
from openhands.events.event import Event
@dataclass
class Observation(Event):
content: str

View File

@@ -0,0 +1,16 @@
from dataclasses import dataclass
from openhands.core.schema import ObservationType
from .observation import Observation
@dataclass
class UserRejectObservation(Observation):
"""This data class represents the result of a successful action."""
observation: str = ObservationType.USER_REJECTED
@property
def message(self) -> str:
return self.content

View File

@@ -0,0 +1,16 @@
from dataclasses import dataclass
from openhands.core.schema import ObservationType
from .observation import Observation
@dataclass
class SuccessObservation(Observation):
"""This data class represents the result of a successful action."""
observation: str = ObservationType.SUCCESS
@property
def message(self) -> str:
return self.content

View File

@@ -0,0 +1,19 @@
from .action import (
action_from_dict,
)
from .event import (
event_from_dict,
event_to_dict,
event_to_memory,
)
from .observation import (
observation_from_dict,
)
__all__ = [
'action_from_dict',
'event_from_dict',
'event_to_dict',
'event_to_memory',
'observation_from_dict',
]

View File

@@ -0,0 +1,61 @@
from openhands.core.exceptions import LLMMalformedActionError
from openhands.events.action.action import Action
from openhands.events.action.agent import (
AgentDelegateAction,
AgentFinishAction,
AgentRejectAction,
ChangeAgentStateAction,
)
from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction
from openhands.events.action.commands import (
CmdRunAction,
IPythonRunCellAction,
)
from openhands.events.action.empty import NullAction
from openhands.events.action.files import FileReadAction, FileWriteAction
from openhands.events.action.message import MessageAction
from openhands.events.action.tasks import AddTaskAction, ModifyTaskAction
actions = (
NullAction,
CmdRunAction,
IPythonRunCellAction,
BrowseURLAction,
BrowseInteractiveAction,
FileReadAction,
FileWriteAction,
AgentFinishAction,
AgentRejectAction,
AgentDelegateAction,
AddTaskAction,
ModifyTaskAction,
ChangeAgentStateAction,
MessageAction,
)
ACTION_TYPE_TO_CLASS = {action_class.action: action_class for action_class in actions} # type: ignore[attr-defined]
def action_from_dict(action: dict) -> Action:
if not isinstance(action, dict):
raise LLMMalformedActionError('action must be a dictionary')
action = action.copy()
if 'action' not in action:
raise LLMMalformedActionError(f"'action' key is not found in {action=}")
if not isinstance(action['action'], str):
raise LLMMalformedActionError(
f"'{action['action']=}' is not defined. Available actions: {ACTION_TYPE_TO_CLASS.keys()}"
)
action_class = ACTION_TYPE_TO_CLASS.get(action['action'])
if action_class is None:
raise LLMMalformedActionError(
f"'{action['action']=}' is not defined. Available actions: {ACTION_TYPE_TO_CLASS.keys()}"
)
args = action.get('args', {})
try:
decoded_action = action_class(**args)
if 'timeout' in action:
decoded_action.timeout = action['timeout']
except TypeError:
raise LLMMalformedActionError(f'action={action} has the wrong arguments')
return decoded_action
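
The registry pattern above maps each action's `action` string to its class, then dispatches on that key. A toy sketch of the same mechanism with two stand-in dataclasses (the names are illustrative, not the real action set):

```python
from dataclasses import dataclass

@dataclass
class Greet:
    name: str
    action: str = 'greet'

@dataclass
class Quit:
    action: str = 'quit'

ACTIONS = (Greet, Quit)
# the default value of the `action` field doubles as the registry key
TYPE_TO_CLASS = {cls.action: cls for cls in ACTIONS}

def toy_action_from_dict(d: dict):
    cls = TYPE_TO_CLASS.get(d.get('action'))
    if cls is None:
        raise ValueError(f'unknown action: {d!r}')
    return cls(**d.get('args', {}))

print(toy_action_from_dict({'action': 'greet', 'args': {'name': 'world'}}))
```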

View File

@@ -0,0 +1,101 @@
from dataclasses import asdict
from datetime import datetime
from openhands.events import Event, EventSource
from openhands.events.observation.observation import Observation
from .action import action_from_dict
from .observation import observation_from_dict
from .utils import remove_fields
# TODO: move `content` into `extras`
TOP_KEYS = ['id', 'timestamp', 'source', 'message', 'cause', 'action', 'observation']
UNDERSCORE_KEYS = ['id', 'timestamp', 'source', 'cause']
DELETE_FROM_MEMORY_EXTRAS = {
'screenshot',
'dom_object',
'axtree_object',
'open_pages_urls',
'active_page_index',
'last_browser_action',
'last_browser_action_error',
'focused_element_bid',
'extra_element_properties',
}
def event_from_dict(data) -> 'Event':
evt: Event
if 'action' in data:
evt = action_from_dict(data)
elif 'observation' in data:
evt = observation_from_dict(data)
else:
        raise ValueError(f'Unknown event type: {data}')
for key in UNDERSCORE_KEYS:
if key in data:
value = data[key]
if key == 'timestamp':
value = datetime.fromisoformat(value)
if key == 'source':
value = EventSource(value)
setattr(evt, '_' + key, value)
return evt
def event_to_dict(event: 'Event') -> dict:
props = asdict(event)
d = {}
for key in TOP_KEYS:
if hasattr(event, key) and getattr(event, key) is not None:
d[key] = getattr(event, key)
elif hasattr(event, f'_{key}') and getattr(event, f'_{key}') is not None:
d[key] = getattr(event, f'_{key}')
if key == 'id' and d.get('id') == -1:
d.pop('id', None)
if key == 'timestamp' and 'timestamp' in d:
d['timestamp'] = d['timestamp'].isoformat()
if key == 'source' and 'source' in d:
d['source'] = d['source'].value
props.pop(key, None)
if 'security_risk' in props and props['security_risk'] is None:
props.pop('security_risk')
if 'action' in d:
d['args'] = props
if event.timeout is not None:
d['timeout'] = event.timeout
elif 'observation' in d:
d['content'] = props.pop('content', '')
d['extras'] = props
else:
raise ValueError('Event must be either action or observation')
return d
def event_to_memory(event: 'Event', max_message_chars: int) -> dict:
d = event_to_dict(event)
d.pop('id', None)
d.pop('cause', None)
d.pop('timestamp', None)
d.pop('message', None)
d.pop('images_urls', None)
if 'extras' in d:
remove_fields(d['extras'], DELETE_FROM_MEMORY_EXTRAS)
if isinstance(event, Observation) and 'content' in d:
d['content'] = truncate_content(d['content'], max_message_chars)
return d
def truncate_content(content: str, max_chars: int) -> str:
"""Truncate the middle of the observation content if it is too long."""
if len(content) <= max_chars:
return content
# truncate the middle and include a message to the LLM about it
half = max_chars // 2
return (
content[:half]
+ '\n[... Observation truncated due to length ...]\n'
+ content[-half:]
)
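
`truncate_content` keeps the head and tail of an oversized observation and drops the middle. A quick standalone check of its behavior (the function body is copied from above):

```python
def truncate_content(content: str, max_chars: int) -> str:
    """Truncate the middle of the observation content if it is too long."""
    if len(content) <= max_chars:
        return content
    half = max_chars // 2
    return (
        content[:half]
        + '\n[... Observation truncated due to length ...]\n'
        + content[-half:]
    )

long_output = 'A' * 5000 + 'B' * 5000
short = truncate_content(long_output, 200)
print(short[:5], short[-5:])  # → AAAAA BBBBB
```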

View File

@@ -0,0 +1,48 @@
from openhands.events.observation.agent import AgentStateChangedObservation
from openhands.events.observation.browse import BrowserOutputObservation
from openhands.events.observation.commands import (
CmdOutputObservation,
IPythonRunCellObservation,
)
from openhands.events.observation.delegate import AgentDelegateObservation
from openhands.events.observation.empty import NullObservation
from openhands.events.observation.error import ErrorObservation
from openhands.events.observation.files import FileReadObservation, FileWriteObservation
from openhands.events.observation.observation import Observation
from openhands.events.observation.reject import UserRejectObservation
from openhands.events.observation.success import SuccessObservation
observations = (
NullObservation,
CmdOutputObservation,
IPythonRunCellObservation,
BrowserOutputObservation,
FileReadObservation,
FileWriteObservation,
AgentDelegateObservation,
SuccessObservation,
ErrorObservation,
AgentStateChangedObservation,
UserRejectObservation,
)
OBSERVATION_TYPE_TO_CLASS = {
observation_class.observation: observation_class # type: ignore[attr-defined]
for observation_class in observations
}
def observation_from_dict(observation: dict) -> Observation:
observation = observation.copy()
if 'observation' not in observation:
raise KeyError(f"'observation' key is not found in {observation=}")
observation_class = OBSERVATION_TYPE_TO_CLASS.get(observation['observation'])
if observation_class is None:
raise KeyError(
f"'{observation['observation']=}' is not defined. Available observations: {OBSERVATION_TYPE_TO_CLASS.keys()}"
)
observation.pop('observation')
observation.pop('message', None)
content = observation.pop('content', '')
extras = observation.pop('extras', {})
return observation_class(content=content, **extras)

View File

@@ -0,0 +1,20 @@
def remove_fields(obj, fields: set[str]):
"""Remove fields from an object.
Parameters:
- obj: The dictionary, or list of dictionaries to remove fields from
- fields (set[str]): A set of field names to remove from the object
"""
if isinstance(obj, dict):
for field in fields:
if field in obj:
del obj[field]
for _, value in obj.items():
remove_fields(value, fields)
elif isinstance(obj, list) or isinstance(obj, tuple):
for item in obj:
remove_fields(item, fields)
elif hasattr(obj, '__dataclass_fields__'):
raise ValueError(
'Object must not contain dataclass, consider converting to dict first'
)
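
A standalone run of the recursive field removal above, showing that it strips matching keys at any nesting depth (the dataclass guard is omitted in this sketch):

```python
def remove_fields(obj, fields: set):
    """Recursively delete the named keys from dicts nested inside obj."""
    if isinstance(obj, dict):
        for field in fields:
            obj.pop(field, None)
        for value in obj.values():
            remove_fields(value, fields)
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            remove_fields(item, fields)

event = {
    'extras': {'screenshot': '<base64>', 'url': 'https://example.com'},
    'children': [{'screenshot': '<base64>', 'kept': True}],
}
remove_fields(event, {'screenshot'})
print(event)
```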

155
openhands/events/stream.py Normal file
View File

@@ -0,0 +1,155 @@
import asyncio
import threading
from datetime import datetime
from enum import Enum
from typing import Callable, Iterable
from openhands.core.logger import openhands_logger as logger
from openhands.core.utils import json
from openhands.events.serialization.event import event_from_dict, event_to_dict
from openhands.storage import FileStore
from .event import Event, EventSource
class EventStreamSubscriber(str, Enum):
AGENT_CONTROLLER = 'agent_controller'
SECURITY_ANALYZER = 'security_analyzer'
SERVER = 'server'
RUNTIME = 'runtime'
MAIN = 'main'
TEST = 'test'
class EventStream:
sid: str
file_store: FileStore
# For each subscriber ID, there is a stack of callback functions - useful
# when there are agent delegates
_subscribers: dict[str, list[Callable]]
_cur_id: int
_lock: threading.Lock
def __init__(self, sid: str, file_store: FileStore):
self.sid = sid
self.file_store = file_store
self._subscribers = {}
self._cur_id = 0
self._lock = threading.Lock()
self._reinitialize_from_file_store()
def _reinitialize_from_file_store(self) -> None:
try:
events = self.file_store.list(f'sessions/{self.sid}/events')
except FileNotFoundError:
logger.debug(f'No events found for session {self.sid}')
self._cur_id = 0
return
# if we have events, we need to find the highest id to prepare for new events
for event_str in events:
id = self._get_id_from_filename(event_str)
if id >= self._cur_id:
self._cur_id = id + 1
def _get_filename_for_id(self, id: int) -> str:
return f'sessions/{self.sid}/events/{id}.json'
@staticmethod
def _get_id_from_filename(filename: str) -> int:
try:
return int(filename.split('/')[-1].split('.')[0])
except ValueError:
            logger.warning(f'get id from filename {filename} failed.')
return -1
def get_events(
self,
start_id=0,
end_id=None,
reverse=False,
filter_out_type: tuple[type[Event], ...] | None = None,
) -> Iterable[Event]:
if reverse:
if end_id is None:
end_id = self._cur_id - 1
event_id = end_id
while event_id >= start_id:
try:
event = self.get_event(event_id)
if filter_out_type is None or not isinstance(
event, filter_out_type
):
yield event
except FileNotFoundError:
logger.debug(f'No event found for ID {event_id}')
event_id -= 1
else:
event_id = start_id
while True:
if end_id is not None and event_id > end_id:
break
try:
event = self.get_event(event_id)
if filter_out_type is None or not isinstance(
event, filter_out_type
):
yield event
except FileNotFoundError:
break
event_id += 1
def get_event(self, id: int) -> Event:
filename = self._get_filename_for_id(id)
content = self.file_store.read(filename)
data = json.loads(content)
return event_from_dict(data)
def get_latest_event(self) -> Event:
return self.get_event(self._cur_id - 1)
def get_latest_event_id(self) -> int:
return self._cur_id - 1
def subscribe(self, id: EventStreamSubscriber, callback: Callable, append=False):
if id in self._subscribers:
if append:
self._subscribers[id].append(callback)
else:
raise ValueError('Subscriber already exists: ' + id)
else:
self._subscribers[id] = [callback]
def unsubscribe(self, id: EventStreamSubscriber):
if id not in self._subscribers:
logger.warning('Subscriber not found during unsubscribe: ' + id)
else:
self._subscribers[id].pop()
if len(self._subscribers[id]) == 0:
del self._subscribers[id]
def add_event(self, event: Event, source: EventSource):
with self._lock:
event._id = self._cur_id # type: ignore [attr-defined]
self._cur_id += 1
logger.debug(f'Adding {type(event).__name__} id={event.id} from {source.name}')
event._timestamp = datetime.now() # type: ignore [attr-defined]
event._source = source # type: ignore [attr-defined]
data = event_to_dict(event)
if event.id is not None:
self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data))
for key in sorted(self._subscribers.keys()):
stack = self._subscribers[key]
callback = stack[-1]
asyncio.create_task(callback(event))
def filtered_events_by_source(self, source: EventSource):
for event in self.get_events():
if event.source == source:
yield event
def clear(self):
self.file_store.delete(f'sessions/{self.sid}')
self._cur_id = 0
# self._subscribers = {}
self._reinitialize_from_file_store()
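
The per-event file layout (`sessions/<sid>/events/<id>.json`) is what makes the stream resumable: on restart, `_reinitialize_from_file_store` derives the next id from the filenames on disk. The two path helpers can be exercised standalone:

```python
def filename_for_id(sid: str, id: int) -> str:
    """Where an event with this id is persisted for a given session."""
    return f'sessions/{sid}/events/{id}.json'

def id_from_filename(filename: str) -> int:
    """Recover the numeric event id from a persisted filename."""
    try:
        return int(filename.split('/')[-1].split('.')[0])
    except ValueError:
        return -1  # malformed name; mirror the stream's fallback

name = filename_for_id('abc', 7)
print(name, id_from_filename(name))  # → sessions/abc/events/7.json 7
```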

32
openhands/llm/bedrock.py Normal file
View File

@@ -0,0 +1,32 @@
import boto3
from openhands.core.logger import openhands_logger as logger
def list_foundation_models(
aws_region_name: str, aws_access_key_id: str, aws_secret_access_key: str
) -> list[str]:
try:
        # Bedrock model IDs are not queried if no AWS parameters are configured.
client = boto3.client(
service_name='bedrock',
region_name=aws_region_name,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
foundation_models_list = client.list_foundation_models(
byOutputModality='TEXT', byInferenceType='ON_DEMAND'
)
model_summaries = foundation_models_list['modelSummaries']
return ['bedrock/' + model['modelId'] for model in model_summaries]
except Exception as err:
logger.warning(
            '%s. Please configure AWS_REGION_NAME, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY'
            ' if you want to use a Bedrock model.',
err,
)
return []
def remove_error_modelId(model_list):
return list(filter(lambda m: not m.startswith('bedrock'), model_list))

511
openhands/llm/llm.py Normal file
View File

@@ -0,0 +1,511 @@
import asyncio
import copy
import warnings
from functools import partial
from openhands.core.config import LLMConfig
with warnings.catch_warnings():
warnings.simplefilter('ignore')
import litellm
from litellm import completion as litellm_completion
from litellm import completion_cost as litellm_completion_cost
from litellm.exceptions import (
APIConnectionError,
ContentPolicyViolationError,
InternalServerError,
OpenAIError,
RateLimitError,
ServiceUnavailableError,
)
from litellm.types.utils import CostPerToken
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from openhands.core.exceptions import UserCancelledError
from openhands.core.logger import llm_prompt_logger, llm_response_logger
from openhands.core.logger import openhands_logger as logger
from openhands.core.metrics import Metrics
__all__ = ['LLM']
message_separator = '\n\n----------\n\n'
class LLM:
"""The LLM class represents a Language Model instance.
Attributes:
config: an LLMConfig object specifying the configuration of the LLM.
"""
def __init__(
self,
config: LLMConfig,
metrics: Metrics | None = None,
):
"""Initializes the LLM from the given configuration.
Args:
config: The LLM configuration
metrics: optional Metrics collector; a new instance is created if omitted
"""
self.config = copy.deepcopy(config)
self.metrics = metrics if metrics is not None else Metrics()
self.cost_metric_supported = True
# Set up config attributes with default values to prevent AttributeError
LLMConfig.set_missing_attributes(self.config)
# litellm actually uses base Exception here for unknown model
self.model_info = None
try:
if self.config.model.startswith('openrouter'):
self.model_info = litellm.get_model_info(self.config.model)
else:
self.model_info = litellm.get_model_info(
self.config.model.split(':')[0]
)
# noinspection PyBroadException
except Exception as e:
logger.warning(f'Could not get model info for {config.model}:\n{e}')
# Set the max tokens in an LM-specific way if not set
if self.config.max_input_tokens is None:
if (
self.model_info is not None
and 'max_input_tokens' in self.model_info
and isinstance(self.model_info['max_input_tokens'], int)
):
self.config.max_input_tokens = self.model_info['max_input_tokens']
else:
# Max input tokens for gpt3.5, so this is a safe fallback for any potentially viable model
self.config.max_input_tokens = 4096
if self.config.max_output_tokens is None:
if (
self.model_info is not None
and 'max_output_tokens' in self.model_info
and isinstance(self.model_info['max_output_tokens'], int)
):
self.config.max_output_tokens = self.model_info['max_output_tokens']
else:
# Max output tokens for gpt3.5, so this is a safe fallback for any potentially viable model
self.config.max_output_tokens = 1024
if self.config.drop_params:
litellm.drop_params = self.config.drop_params
self._completion = partial(
litellm_completion,
model=self.config.model,
api_key=self.config.api_key,
base_url=self.config.base_url,
api_version=self.config.api_version,
custom_llm_provider=self.config.custom_llm_provider,
max_tokens=self.config.max_output_tokens,
timeout=self.config.timeout,
temperature=self.config.temperature,
top_p=self.config.top_p,
)
completion_unwrapped = self._completion
def attempt_on_error(retry_state):
logger.error(
f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number} | You can customize these settings in the configuration.',
exc_info=False,
)
return None
@retry(
reraise=True,
stop=stop_after_attempt(self.config.num_retries),
wait=wait_random_exponential(
multiplier=self.config.retry_multiplier,
min=self.config.retry_min_wait,
max=self.config.retry_max_wait,
),
retry=retry_if_exception_type(
(
RateLimitError,
APIConnectionError,
ServiceUnavailableError,
InternalServerError,
ContentPolicyViolationError,
)
),
after=attempt_on_error,
)
def wrapper(*args, **kwargs):
"""Wrapper for the litellm completion function. Logs the input and output of the completion function."""
# some callers might just send the messages directly
if 'messages' in kwargs:
messages = kwargs['messages']
else:
messages = args[1]
# log the prompt
debug_message = ''
for message in messages:
content = message['content']
if isinstance(content, list):
for element in content:
if isinstance(element, dict):
if 'text' in element:
content_str = element['text'].strip()
elif (
'image_url' in element and 'url' in element['image_url']
):
content_str = element['image_url']['url']
else:
content_str = str(element)
else:
content_str = str(element)
debug_message += message_separator + content_str
else:
content_str = str(content)
debug_message += message_separator + content_str
llm_prompt_logger.debug(debug_message)
# skip if messages is empty (thus debug_message is empty)
if debug_message:
resp = completion_unwrapped(*args, **kwargs)
else:
resp = {'choices': [{'message': {'content': ''}}]}
# log the response
message_back = resp['choices'][0]['message']['content']
llm_response_logger.debug(message_back)
# post-process to log costs
self._post_completion(resp)
return resp
self._completion = wrapper # type: ignore
# Async version
self._async_completion = partial(
self._call_acompletion,
model=self.config.model,
api_key=self.config.api_key,
base_url=self.config.base_url,
api_version=self.config.api_version,
custom_llm_provider=self.config.custom_llm_provider,
max_tokens=self.config.max_output_tokens,
timeout=self.config.timeout,
temperature=self.config.temperature,
top_p=self.config.top_p,
drop_params=True,
)
async_completion_unwrapped = self._async_completion
@retry(
reraise=True,
stop=stop_after_attempt(self.config.num_retries),
wait=wait_random_exponential(
multiplier=self.config.retry_multiplier,
min=self.config.retry_min_wait,
max=self.config.retry_max_wait,
),
retry=retry_if_exception_type(
(
RateLimitError,
APIConnectionError,
ServiceUnavailableError,
InternalServerError,
ContentPolicyViolationError,
)
),
after=attempt_on_error,
)
async def async_completion_wrapper(*args, **kwargs):
"""Async wrapper for the litellm acompletion function."""
# some callers might just send the messages directly
if 'messages' in kwargs:
messages = kwargs['messages']
else:
messages = args[1]
# log the prompt
debug_message = ''
for message in messages:
content = message['content']
if isinstance(content, list):
for element in content:
if isinstance(element, dict):
if 'text' in element:
content_str = element['text']
elif (
'image_url' in element and 'url' in element['image_url']
):
content_str = element['image_url']['url']
else:
content_str = str(element)
else:
content_str = str(element)
debug_message += message_separator + content_str
else:
content_str = str(content)
debug_message += message_separator + content_str
llm_prompt_logger.debug(debug_message)
async def check_stopped():
while True:
if (
hasattr(self.config, 'on_cancel_requested_fn')
and self.config.on_cancel_requested_fn is not None
and await self.config.on_cancel_requested_fn()
):
raise UserCancelledError('LLM request cancelled by user')
await asyncio.sleep(0.1)
stop_check_task = asyncio.create_task(check_stopped())
try:
# Directly call and await litellm_acompletion
resp = await async_completion_unwrapped(*args, **kwargs)
# skip if messages is empty (thus debug_message is empty)
if debug_message:
message_back = resp['choices'][0]['message']['content']
llm_response_logger.debug(message_back)
else:
resp = {'choices': [{'message': {'content': ''}}]}
self._post_completion(resp)
# We do not support streaming in this method, thus return resp
return resp
except UserCancelledError:
logger.info('LLM request cancelled by user.')
raise
except OpenAIError as e:
logger.error(f'OpenAIError occurred:\n{e}')
raise
except (
RateLimitError,
APIConnectionError,
ServiceUnavailableError,
InternalServerError,
) as e:
logger.error(f'Completion Error occurred:\n{e}')
raise
finally:
await asyncio.sleep(0.1)
stop_check_task.cancel()
try:
await stop_check_task
except asyncio.CancelledError:
pass
@retry(
reraise=True,
stop=stop_after_attempt(self.config.num_retries),
wait=wait_random_exponential(
multiplier=self.config.retry_multiplier,
min=self.config.retry_min_wait,
max=self.config.retry_max_wait,
),
retry=retry_if_exception_type(
(
RateLimitError,
APIConnectionError,
ServiceUnavailableError,
InternalServerError,
ContentPolicyViolationError,
)
),
after=attempt_on_error,
)
async def async_acompletion_stream_wrapper(*args, **kwargs):
"""Async wrapper for the litellm acompletion with streaming function."""
# some callers might just send the messages directly
if 'messages' in kwargs:
messages = kwargs['messages']
else:
messages = args[1]
# log the prompt
debug_message = ''
for message in messages:
debug_message += message_separator + message['content']
llm_prompt_logger.debug(debug_message)
try:
# Directly call and await litellm_acompletion
resp = await async_completion_unwrapped(*args, **kwargs)
# For streaming we iterate over the chunks
async for chunk in resp:
# Check for cancellation before yielding the chunk
if (
hasattr(self.config, 'on_cancel_requested_fn')
and self.config.on_cancel_requested_fn is not None
and await self.config.on_cancel_requested_fn()
):
raise UserCancelledError(
'LLM request cancelled due to CANCELLED state'
)
# with streaming, it is "delta", not "message"!
message_back = chunk['choices'][0]['delta']['content']
llm_response_logger.debug(message_back)
self._post_completion(chunk)
yield chunk
except UserCancelledError:
logger.info('LLM request cancelled by user.')
raise
except OpenAIError as e:
logger.error(f'OpenAIError occurred:\n{e}')
raise
except (
RateLimitError,
APIConnectionError,
ServiceUnavailableError,
InternalServerError,
) as e:
logger.error(f'Completion Error occurred:\n{e}')
raise
finally:
if kwargs.get('stream', False):
await asyncio.sleep(0.1)
self._async_completion = async_completion_wrapper # type: ignore
self._async_streaming_completion = async_acompletion_stream_wrapper # type: ignore
async def _call_acompletion(self, *args, **kwargs):
return await litellm.acompletion(*args, **kwargs)
@property
def completion(self):
"""Decorator for the litellm completion function.
Check the complete documentation at https://litellm.vercel.app/docs/completion
"""
return self._completion
@property
def async_completion(self):
"""Decorator for the async litellm acompletion function.
Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
"""
return self._async_completion
@property
def async_streaming_completion(self):
"""Decorator for the async litellm acompletion function with streaming.
Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
"""
return self._async_streaming_completion
def supports_vision(self):
return litellm.supports_vision(self.config.model)
def _post_completion(self, response) -> None:
"""Post-process the completion response."""
try:
cur_cost = self.completion_cost(response)
except Exception:
cur_cost = 0
if self.cost_metric_supported:
logger.info(
'Cost: %.2f USD | Accumulated Cost: %.2f USD',
cur_cost,
self.metrics.accumulated_cost,
)
def get_token_count(self, messages):
"""Get the number of tokens in a list of messages.
Args:
messages (list): A list of messages.
Returns:
int: The number of tokens.
"""
return litellm.token_counter(model=self.config.model, messages=messages)
def is_local(self):
"""Determines if the system is using a locally running LLM.
Returns:
boolean: True if executing a local model.
"""
if self.config.base_url is not None:
for substring in ['localhost', '127.0.0.1', '0.0.0.0']:
if substring in self.config.base_url:
return True
elif self.config.model is not None:
if self.config.model.startswith('ollama'):
return True
return False
def completion_cost(self, response):
"""Calculate the cost of a completion response based on the model. Local models are treated as free.
Add the current cost into total cost in metrics.
Args:
response: A response from a model invocation.
Returns:
number: The cost of the response.
"""
if not self.cost_metric_supported:
return 0.0
extra_kwargs = {}
if (
self.config.input_cost_per_token is not None
and self.config.output_cost_per_token is not None
):
cost_per_token = CostPerToken(
input_cost_per_token=self.config.input_cost_per_token,
output_cost_per_token=self.config.output_cost_per_token,
)
logger.info(f'Using custom cost per token: {cost_per_token}')
extra_kwargs['custom_cost_per_token'] = cost_per_token
if not self.is_local():
try:
cost = litellm_completion_cost(
completion_response=response, **extra_kwargs
)
self.metrics.add_cost(cost)
return cost
except Exception:
self.cost_metric_supported = False
logger.warning('Cost calculation not supported for this model.')
return 0.0
def __str__(self):
if self.config.api_version:
return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
elif self.config.base_url:
return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
return f'LLM(model={self.config.model})'
def __repr__(self):
return str(self)
def reset(self):
self.metrics = Metrics()
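
The partial-then-wrap layering the class uses (bind static config with `functools.partial`, then replace the callable with a logging/guard wrapper) can be reduced to a self-contained sketch. Everything below is a stub for illustration, not the real litellm API:

```python
from functools import partial

def base_completion(model, messages):
    # stand-in for litellm_completion; returns a response-shaped dict
    return {'choices': [{'message': {'content': f'{model} saw {len(messages)} message(s)'}}]}

# bind static config once, as __init__ does when building self._completion
completion_unwrapped = partial(base_completion, model='stub-model')

def wrapper(*args, **kwargs):
    # guard layer, like the wrapper() closure above: skip empty message lists
    messages = kwargs.get('messages', args[0] if args else [])
    if not messages:
        return {'choices': [{'message': {'content': ''}}]}
    return completion_unwrapped(messages=messages)

reply = wrapper(messages=[{'role': 'user', 'content': 'hi'}])
```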


@@ -0,0 +1,23 @@
# Memory Component
- Short Term History
- Memory Condenser
- Long Term Memory
## Short Term History
- Short term history filters the event stream and computes the messages that are injected into the context
- It filters out certain events of no interest for the Agent, such as AgentChangeStateObservation or NullAction/NullObservation
- When the context window or the user-configured token limit is exceeded, history starts condensing chunks of messages into summaries.
- Each summary is then injected into the context, in the place of the respective chunk it summarizes
## Memory Condenser
- Memory condenser is responsible for summarizing the chunks of events
- It summarizes the earlier events first
- It starts with the earliest agent actions and observations between two user messages
- Then it does the same for later chunks of events between user messages
- If there are no more agent events, it summarizes the user messages one by one, but only those that are large enough and not immediately after an AgentFinishAction event (we assume those are tasks, and potentially important)
- Summaries are retrieved from the LLM as AgentSummarizeAction, and are saved in State.
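
The earliest-first chunking described above can be sketched with plain strings standing in for events (purely illustrative, not the real event types):

```python
def chunks_between_user_messages(events):
    # Group non-user events into chunks delimited by user messages,
    # earliest first - the order in which the condenser summarizes them.
    chunks, current = [], []
    for event in events:
        if event.startswith('user:'):
            if current:
                chunks.append(current)
            current = []
        else:
            current.append(event)
    if current:
        chunks.append(current)
    return chunks

events = ['user: fix the bug', 'act: open file', 'obs: file contents',
          'user: now run tests', 'act: run pytest']
chunks = chunks_between_user_messages(events)
```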
## Long Term Memory
- Long term memory component stores embeddings for events and prompts in a vector store
- The agent can query it when it needs detailed information about a past event or to learn new actions


@@ -0,0 +1,5 @@
from .condenser import MemoryCondenser
from .history import ShortTermHistory
from .memory import LongTermMemory
__all__ = ['LongTermMemory', 'ShortTermHistory', 'MemoryCondenser']


@@ -0,0 +1,24 @@
from openhands.core.logger import openhands_logger as logger
from openhands.llm.llm import LLM
class MemoryCondenser:
def condense(self, summarize_prompt: str, llm: LLM):
"""Attempts to condense the memory by using the llm
Parameters:
- llm (LLM): llm to be used for summarization
Raises:
- Exception: the same exception as it got from the llm or processing the response
"""
try:
messages = [{'content': summarize_prompt, 'role': 'user'}]
resp = llm.completion(messages=messages)
summary_response = resp['choices'][0]['message']['content']
return summary_response
except Exception as e:
logger.error('Error condensing thoughts: %s', str(e), exc_info=False)
# TODO If the llm fails with ContextWindowExceededError, we can try to condense the memory chunk by chunk
raise
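
The call shape `condense` relies on can be exercised with a stub in place of the real `LLM` class (the `StubLLM` name and its canned reply are made up for this sketch):

```python
class StubLLM:
    # Minimal stand-in exposing the completion(messages=...) interface
    # that MemoryCondenser.condense expects.
    def completion(self, messages):
        prompt = messages[0]['content']
        return {'choices': [{'message': {'content': f'SUMMARY of: {prompt}'}}]}

def condense(summarize_prompt, llm):
    # same extraction path as MemoryCondenser.condense
    resp = llm.completion(messages=[{'content': summarize_prompt, 'role': 'user'}])
    return resp['choices'][0]['message']['content']

summary = condense('events 1-10', StubLLM())
```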

openhands/memory/history.py

@@ -0,0 +1,261 @@
from typing import ClassVar, Iterable
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.action import Action
from openhands.events.action.agent import (
AgentDelegateAction,
ChangeAgentStateAction,
)
from openhands.events.action.empty import NullAction
from openhands.events.action.message import MessageAction
from openhands.events.event import Event, EventSource
from openhands.events.observation.agent import AgentStateChangedObservation
from openhands.events.observation.commands import CmdOutputObservation
from openhands.events.observation.delegate import AgentDelegateObservation
from openhands.events.observation.empty import NullObservation
from openhands.events.observation.observation import Observation
from openhands.events.serialization.event import event_to_dict
from openhands.events.stream import EventStream
class ShortTermHistory(list[Event]):
"""A list of events that represents the short-term memory of the agent.
This class provides methods to retrieve and filter the events in the history of the running agent from the event stream.
"""
start_id: int
end_id: int
_event_stream: EventStream
delegates: dict[tuple[int, int], tuple[str, str]]
filter_out: ClassVar[tuple[type[Event], ...]] = (
NullAction,
NullObservation,
ChangeAgentStateAction,
AgentStateChangedObservation,
)
def __init__(self):
super().__init__()
self.start_id = -1
self.end_id = -1
self.delegates = {}
def set_event_stream(self, event_stream: EventStream):
self._event_stream = event_stream
def get_events_as_list(self) -> list[Event]:
"""Return the history as a list of Event objects."""
return list(self.get_events())
def get_events(self, reverse: bool = False) -> Iterable[Event]:
"""Return the events as a stream of Event objects."""
# TODO handle AgentRejectAction, if it's not part of a chunk ending with an AgentDelegateObservation
# or even if it is, because currently we don't add it to the summary
# iterate from start_id to end_id, or reverse
start_id = self.start_id if self.start_id != -1 else 0
end_id = (
self.end_id
if self.end_id != -1
else self._event_stream.get_latest_event_id()
)
for event in self._event_stream.get_events(
start_id=start_id,
end_id=end_id,
reverse=reverse,
filter_out_type=self.filter_out,
):
# TODO add summaries
# and filter out events that were included in a summary
# filter out the events from a delegate of the current agent
if not any(
# except for the delegate action and observation themselves, currently
# AgentDelegateAction has id = delegate_start
# AgentDelegateObservation has id = delegate_end
delegate_start < event.id < delegate_end
for delegate_start, delegate_end in self.delegates.keys()
):
yield event
def get_last_action(self, end_id: int = -1) -> Action | None:
"""Return the last action from the event stream, filtered to exclude unwanted events."""
# from end_id in reverse, find the first action
end_id = self._event_stream.get_latest_event_id() if end_id == -1 else end_id
last_action = next(
(
event
for event in self._event_stream.get_events(
end_id=end_id, reverse=True, filter_out_type=self.filter_out
)
if isinstance(event, Action)
),
None,
)
return last_action
def get_last_observation(self, end_id: int = -1) -> Observation | None:
"""Return the last observation from the event stream, filtered to exclude unwanted events."""
# from end_id in reverse, find the first observation
end_id = self._event_stream.get_latest_event_id() if end_id == -1 else end_id
last_observation = next(
(
event
for event in self._event_stream.get_events(
end_id=end_id, reverse=True, filter_out_type=self.filter_out
)
if isinstance(event, Observation)
),
None,
)
return last_observation
def get_last_user_message(self) -> str:
"""Return the content of the last user message from the event stream."""
last_user_message = next(
(
event.content
for event in self._event_stream.get_events(reverse=True)
if isinstance(event, MessageAction) and event.source == EventSource.USER
),
None,
)
return last_user_message if last_user_message is not None else ''
def get_last_agent_message(self) -> str:
"""Return the content of the last agent message from the event stream."""
last_agent_message = next(
(
event.content
for event in self._event_stream.get_events(reverse=True)
if isinstance(event, MessageAction)
and event.source == EventSource.AGENT
),
None,
)
return last_agent_message if last_agent_message is not None else ''
def get_last_events(self, n: int) -> list[Event]:
"""Return the last n events from the event stream."""
# dummy agent is using this
# it should work, but it's not great to store temporary lists now just for a test
end_id = self._event_stream.get_latest_event_id()
start_id = max(0, end_id - n + 1)
return list(
event
for event in self._event_stream.get_events(
start_id=start_id,
end_id=end_id,
filter_out_type=self.filter_out,
)
)
def has_delegation(self) -> bool:
for event in self._event_stream.get_events():
if isinstance(event, AgentDelegateObservation):
return True
return False
def on_event(self, event: Event):
if not isinstance(event, AgentDelegateObservation):
return
logger.debug('AgentDelegateObservation received')
# figure out what this delegate's actions were
# from the last AgentDelegateAction to this AgentDelegateObservation
# and save their ids as start and end ids
# in order to use later to exclude them from parent stream
# or summarize them
delegate_end = event.id
delegate_start = -1
delegate_agent: str = ''
delegate_task: str = ''
for prev_event in self._event_stream.get_events(
end_id=event.id - 1, reverse=True
):
if isinstance(prev_event, AgentDelegateAction):
delegate_start = prev_event.id
delegate_agent = prev_event.agent
delegate_task = prev_event.inputs.get('task', '')
break
if delegate_start == -1:
logger.error(
f'No AgentDelegateAction found for AgentDelegateObservation with id={delegate_end}'
)
return
self.delegates[(delegate_start, delegate_end)] = (delegate_agent, delegate_task)
logger.debug(
f'Delegate {delegate_agent} with task {delegate_task} ran from id={delegate_start} to id={delegate_end}'
)
# TODO remove me when unnecessary
# history is now available as a filtered stream of events, rather than list of pairs of (Action, Observation)
# we rebuild the pairs here
# for compatibility with the existing output format in evaluations
def compatibility_for_eval_history_pairs(self) -> list[tuple[dict, dict]]:
history_pairs = []
for action, observation in self.get_pairs():
history_pairs.append((event_to_dict(action), event_to_dict(observation)))
return history_pairs
def get_pairs(self) -> list[tuple[Action, Observation]]:
"""Return the history as a list of tuples (action, observation)."""
tuples: list[tuple[Action, Observation]] = []
action_map: dict[int, Action] = {}
observation_map: dict[int, Observation] = {}
# runnable actions are set as cause of observations
# (MessageAction, NullObservation) for source=USER
# (MessageAction, NullObservation) for source=AGENT
# (other_action?, NullObservation)
# (NullAction, CmdOutputObservation) background CmdOutputObservations
for event in self.get_events_as_list():
if event.id is None or event.id == -1:
logger.debug(f'Event {event} has no ID')
if isinstance(event, Action):
action_map[event.id] = event
if isinstance(event, Observation):
if event.cause is None or event.cause == -1:
logger.debug(f'Observation {event} has no cause')
if event.cause is None:
# runnable actions are set as cause of observations
# NullObservations have no cause
continue
observation_map[event.cause] = event
for action_id, action in action_map.items():
observation = observation_map.get(action_id)
if observation:
# observation with a cause
tuples.append((action, observation))
else:
tuples.append((action, NullObservation('')))
for cause_id, observation in observation_map.items():
if cause_id not in action_map:
if isinstance(observation, NullObservation):
continue
if not isinstance(observation, CmdOutputObservation):
logger.debug(f'Observation {observation} has no cause')
tuples.append((NullAction(), observation))
return tuples.copy()
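
The cause-id pairing logic in `get_pairs` reduces to the following sketch, with plain strings standing in for Action/Observation events (illustrative names only):

```python
# Observations are keyed by the id of the action that caused them,
# as in the observation_map built by get_pairs() above.
actions = {1: 'CmdRunAction(ls)', 3: 'MessageAction(hello)'}
observations = {1: 'CmdOutputObservation(listing)'}

pairs = []
for action_id, action in actions.items():
    # actions without an observation get a null placeholder
    pairs.append((action, observations.get(action_id, 'NullObservation')))
for cause_id, obs in observations.items():
    if cause_id not in actions:
        # observations whose cause is unknown get a null action
        pairs.append(('NullAction', obs))
```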

openhands/memory/memory.py

@@ -0,0 +1,184 @@
import threading
from openai._exceptions import APIConnectionError, InternalServerError, RateLimitError
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from openhands.core.config import LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.core.utils import json
try:
import chromadb
import llama_index.embeddings.openai.base as llama_openai
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.vector_stores.chroma import ChromaVectorStore
LLAMA_INDEX_AVAILABLE = True
except ImportError:
LLAMA_INDEX_AVAILABLE = False
if LLAMA_INDEX_AVAILABLE:
# TODO: this could be made configurable
num_retries: int = 10
retry_min_wait: int = 3
retry_max_wait: int = 300
# llama-index includes a retry decorator around openai.get_embeddings() function
# it is initialized with hard-coded values and errors
# this non-customizable behavior is creating issues when it's retrying faster than providers' rate limits
# this block attempts to banish it and replace it with our decorator, to allow users to set their own limits
if hasattr(llama_openai.get_embeddings, '__wrapped__'):
original_get_embeddings = llama_openai.get_embeddings.__wrapped__
else:
logger.warning('Cannot set custom retry limits.')
num_retries = 1
original_get_embeddings = llama_openai.get_embeddings
def attempt_on_error(retry_state):
logger.error(
f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number} | You can customize these settings in the configuration.',
exc_info=False,
)
return None
@retry(
reraise=True,
stop=stop_after_attempt(num_retries),
wait=wait_random_exponential(min=retry_min_wait, max=retry_max_wait),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, InternalServerError)
),
after=attempt_on_error,
)
def wrapper_get_embeddings(*args, **kwargs):
return original_get_embeddings(*args, **kwargs)
llama_openai.get_embeddings = wrapper_get_embeddings
class EmbeddingsLoader:
"""Loader for embedding model initialization."""
@staticmethod
def get_embedding_model(strategy: str, llm_config: LLMConfig):
supported_ollama_embed_models = [
'llama2',
'mxbai-embed-large',
'nomic-embed-text',
'all-minilm',
'stable-code',
]
if strategy in supported_ollama_embed_models:
from llama_index.embeddings.ollama import OllamaEmbedding
return OllamaEmbedding(
model_name=strategy,
base_url=llm_config.embedding_base_url,
ollama_additional_kwargs={'mirostat': 0},
)
elif strategy == 'openai':
from llama_index.embeddings.openai import OpenAIEmbedding
return OpenAIEmbedding(
model='text-embedding-ada-002',
api_key=llm_config.api_key,
)
elif strategy == 'azureopenai':
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
return AzureOpenAIEmbedding(
model='text-embedding-ada-002',
deployment_name=llm_config.embedding_deployment_name,
api_key=llm_config.api_key,
azure_endpoint=llm_config.base_url,
api_version=llm_config.api_version,
)
elif (strategy is not None) and (strategy.lower() == 'none'):
# TODO: this works but is not elegant enough. The incentive is when
# an agent using embeddings is not used, there is no reason we need to
# initialize an embedding model
return None
else:
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
return HuggingFaceEmbedding(model_name='BAAI/bge-small-en-v1.5')
class LongTermMemory:
"""Handles storing information for the agent to access later, using chromadb."""
def __init__(self, llm_config: LLMConfig, memory_max_threads: int = 1):
"""Initialize the chromadb and set up ChromaVectorStore for later use."""
if not LLAMA_INDEX_AVAILABLE:
raise ImportError(
'llama_index and its dependencies are not installed. '
'To use LongTermMemory, please run: poetry install --with llama-index'
)
db = chromadb.Client(chromadb.Settings(anonymized_telemetry=False))
self.collection = db.get_or_create_collection(name='memories')
vector_store = ChromaVectorStore(chroma_collection=self.collection)
embedding_strategy = llm_config.embedding_model
embed_model = EmbeddingsLoader.get_embedding_model(
embedding_strategy, llm_config
)
self.index = VectorStoreIndex.from_vector_store(vector_store, embed_model)
self.sema = threading.Semaphore(value=memory_max_threads)
self.thought_idx = 0
self._add_threads: list[threading.Thread] = []
def add_event(self, event: dict):
"""Adds a new event to the long term memory with a unique id.
Parameters:
- event (dict): The new event to be added to memory
"""
id = ''
t = ''
if 'action' in event:
t = 'action'
id = event['action']
elif 'observation' in event:
t = 'observation'
id = event['observation']
doc = Document(
text=json.dumps(event),
doc_id=str(self.thought_idx),
extra_info={
'type': t,
'id': id,
'idx': self.thought_idx,
},
)
self.thought_idx += 1
logger.debug('Adding %s event to memory: %d', t, self.thought_idx)
thread = threading.Thread(target=self._add_doc, args=(doc,))
self._add_threads.append(thread)
thread.start() # We add the doc concurrently so we don't have to wait ~500ms for the insert
def _add_doc(self, doc):
with self.sema:
self.index.insert(doc)
def search(self, query: str, k: int = 10):
"""Searches through the current memory using VectorIndexRetriever
Parameters:
- query (str): A query to match search results to
- k (int): Number of top results to return
Returns:
- list[str]: list of top k results found in current memory
"""
retriever = VectorIndexRetriever(
index=self.index,
similarity_top_k=k,
)
results = retriever.retrieve(query)
return [r.get_text() for r in results]
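
The way `add_event` derives document metadata from an event dict can be shown without llama-index (the `event_metadata` helper below is a hypothetical extraction of that logic, not part of the class):

```python
import json

def event_metadata(event: dict):
    # Mirrors add_event(): derive the document type and id from whichever
    # of 'action'/'observation' keys is present.
    if 'action' in event:
        t, event_id = 'action', event['action']
    elif 'observation' in event:
        t, event_id = 'observation', event['observation']
    else:
        t, event_id = '', ''
    return {'type': t, 'id': event_id, 'text': json.dumps(event)}

meta = event_metadata({'action': 'run', 'args': {'command': 'ls'}})
```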


@@ -0,0 +1,21 @@
from .e2b.sandbox import E2BBox
def get_runtime_cls(name: str):
# Local imports to avoid circular imports
if name == 'eventstream':
from .client.runtime import EventStreamRuntime
return EventStreamRuntime
elif name == 'e2b':
from .e2b.runtime import E2BRuntime
return E2BRuntime
else:
raise ValueError(f'Runtime {name} not supported')
__all__ = [
'E2BBox',
'get_runtime_cls',
]
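
The lazy-import dispatch in `get_runtime_cls` is a general pattern: import inside the branch so a backend and its dependencies are only loaded when actually requested. A miniature version with stdlib modules standing in for runtimes (names here are illustrative):

```python
def get_serializer_cls(name: str):
    # Same shape as get_runtime_cls: local imports avoid paying the
    # import cost (and any circular-import risk) for unused backends.
    if name == 'json':
        import json
        return json
    elif name == 'pickle':
        import pickle
        return pickle
    raise ValueError(f'Serializer {name} not supported')

backend = get_serializer_cls('json')
```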


@@ -0,0 +1,3 @@
from .utils import browse
__all__ = ['browse']


@@ -0,0 +1,231 @@
import atexit
import base64
import io
import json
import multiprocessing
import time
import uuid
import browsergym.core # noqa F401 (we register the openended task as a gym environment)
import gymnasium as gym
import html2text
import numpy as np
import tenacity
from browsergym.utils.obs import flatten_dom_to_str
from PIL import Image
from openhands.core.exceptions import BrowserInitException
from openhands.core.logger import openhands_logger as logger
BROWSER_EVAL_GET_GOAL_ACTION = 'GET_EVAL_GOAL'
BROWSER_EVAL_GET_REWARDS_ACTION = 'GET_EVAL_REWARDS'
class BrowserEnv:
def __init__(self, browsergym_eval_env: str | None = None):
self.html_text_converter = self.get_html_text_converter()
self.eval_mode = False
self.eval_dir = ''
# EVAL only: browsergym_eval_env must be provided for evaluation
self.browsergym_eval_env = browsergym_eval_env
self.eval_mode = bool(browsergym_eval_env)
# Initialize browser environment process
multiprocessing.set_start_method('spawn', force=True)
self.browser_side, self.agent_side = multiprocessing.Pipe()
self.init_browser()
atexit.register(self.close)
def get_html_text_converter(self):
html_text_converter = html2text.HTML2Text()
# keep links, ignore images
html_text_converter.ignore_links = False
html_text_converter.ignore_images = True
# use alt text for images
html_text_converter.images_to_alt = True
# disable auto text wrapping
html_text_converter.body_width = 0
return html_text_converter
@tenacity.retry(
wait=tenacity.wait_fixed(1),
stop=tenacity.stop_after_attempt(5),
retry=tenacity.retry_if_exception_type(BrowserInitException),
)
def init_browser(self):
logger.info('Starting browser env...')
try:
self.process = multiprocessing.Process(target=self.browser_process)
self.process.start()
except Exception as e:
logger.error(f'Failed to start browser process: {e}')
raise
if not self.check_alive():
self.close()
raise BrowserInitException('Failed to start browser environment.')
def browser_process(self):
if self.eval_mode:
assert self.browsergym_eval_env is not None
logger.info('Initializing browser env for web browsing evaluation.')
if 'webarena' in self.browsergym_eval_env:
import browsergym.webarena # noqa F401 register webarena tasks as gym environments
elif 'miniwob' in self.browsergym_eval_env:
import browsergym.miniwob # noqa F401 register miniwob tasks as gym environments
else:
raise ValueError(
f'Unsupported browsergym eval env: {self.browsergym_eval_env}'
)
env = gym.make(self.browsergym_eval_env)
else:
env = gym.make(
'browsergym/openended',
task_kwargs={'start_url': 'about:blank', 'goal': 'PLACEHOLDER_GOAL'},
wait_for_user_message=False,
headless=True,
disable_env_checker=True,
)
obs, info = env.reset()
# EVAL ONLY: save the goal into file for evaluation
self.eval_goal = None
self.eval_rewards: list[float] = []
if self.eval_mode:
logger.info(f"Browsing goal: {obs['goal']}")
self.eval_goal = obs['goal']
logger.info('Browser env started.')
while True:
try:
if self.browser_side.poll(timeout=0.01):
unique_request_id, action_data = self.browser_side.recv()
# shutdown the browser environment
if unique_request_id == 'SHUTDOWN':
logger.info('SHUTDOWN recv, shutting down browser env...')
env.close()
return
elif unique_request_id == 'IS_ALIVE':
self.browser_side.send(('ALIVE', None))
continue
# EVAL ONLY: Get evaluation info
if action_data['action'] == BROWSER_EVAL_GET_GOAL_ACTION:
self.browser_side.send(
(unique_request_id, {'text_content': self.eval_goal})
)
continue
elif action_data['action'] == BROWSER_EVAL_GET_REWARDS_ACTION:
self.browser_side.send(
(
unique_request_id,
{'text_content': json.dumps(self.eval_rewards)},
)
)
continue
action = action_data['action']
obs, reward, terminated, truncated, info = env.step(action)
# EVAL ONLY: Save the rewards into file for evaluation
if self.eval_mode:
self.eval_rewards.append(reward)
# add text content of the page
html_str = flatten_dom_to_str(obs['dom_object'])
obs['text_content'] = self.html_text_converter.handle(html_str)
# make observation serializable
obs['screenshot'] = self.image_to_png_base64_url(obs['screenshot'])
obs['active_page_index'] = obs['active_page_index'].item()
obs['elapsed_time'] = obs['elapsed_time'].item()
self.browser_side.send((unique_request_id, obs))
except KeyboardInterrupt:
logger.info('Browser env process interrupted by user.')
try:
env.close()
except Exception:
pass
return
def step(self, action_str: str, timeout: float = 30) -> dict:
"""Execute an action in the browser environment and return the observation."""
unique_request_id = str(uuid.uuid4())
self.agent_side.send((unique_request_id, {'action': action_str}))
start_time = time.time()
while True:
if time.time() - start_time > timeout:
raise TimeoutError('Browser environment took too long to respond.')
if self.agent_side.poll(timeout=0.01):
response_id, obs = self.agent_side.recv()
if response_id == unique_request_id:
return obs
def check_alive(self, timeout: float = 60):
self.agent_side.send(('IS_ALIVE', None))
if self.agent_side.poll(timeout=timeout):
response_id, _ = self.agent_side.recv()
if response_id == 'ALIVE':
return True
logger.info(f'Browser env is not alive. Response ID: {response_id}')
def close(self):
if not self.process.is_alive():
return
try:
self.agent_side.send(('SHUTDOWN', None))
self.process.join(5) # Wait for the process to terminate
if self.process.is_alive():
logger.error(
'Browser process did not terminate, forcefully terminating...'
)
self.process.terminate()
self.process.join(5) # Wait for the process to terminate
if self.process.is_alive():
self.process.kill()
self.process.join(5) # Wait for the process to terminate
self.agent_side.close()
self.browser_side.close()
except Exception:
logger.error('Encountered an error when closing browser env', exc_info=True)
@staticmethod
def image_to_png_base64_url(
image: np.ndarray | Image.Image, add_data_prefix: bool = False
):
"""Convert a numpy array to a base64 encoded png image url."""
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
if image.mode in ('RGBA', 'LA'):
image = image.convert('RGB')
buffered = io.BytesIO()
image.save(buffered, format='PNG')
image_base64 = base64.b64encode(buffered.getvalue()).decode()
return (
f'data:image/png;base64,{image_base64}'
if add_data_prefix
else f'{image_base64}'
)
@staticmethod
def image_to_jpg_base64_url(
image: np.ndarray | Image.Image, add_data_prefix: bool = False
):
"""Convert a numpy array to a base64 encoded jpeg image url."""
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
if image.mode in ('RGBA', 'LA'):
image = image.convert('RGB')
buffered = io.BytesIO()
image.save(buffered, format='JPEG')
image_base64 = base64.b64encode(buffered.getvalue()).decode()
return (
f'data:image/jpeg;base64,{image_base64}'
if add_data_prefix
else f'{image_base64}'
)


@@ -0,0 +1,60 @@
import os
from openhands.core.exceptions import BrowserUnavailableException
from openhands.core.schema import ActionType
from openhands.events.action import BrowseInteractiveAction, BrowseURLAction
from openhands.events.observation import BrowserOutputObservation
from openhands.runtime.browser.browser_env import BrowserEnv
async def browse(
action: BrowseURLAction | BrowseInteractiveAction, browser: BrowserEnv | None
) -> BrowserOutputObservation:
if browser is None:
raise BrowserUnavailableException()
if isinstance(action, BrowseURLAction):
# legacy BrowseURLAction
asked_url = action.url
if not asked_url.startswith('http'):
asked_url = os.path.abspath(os.curdir) + action.url
action_str = f'goto("{asked_url}")'
elif isinstance(action, BrowseInteractiveAction):
# new BrowseInteractiveAction, supports full featured BrowserGym actions
# action in BrowserGym: see https://github.com/ServiceNow/BrowserGym/blob/main/core/src/browsergym/core/action/functions.py
action_str = action.browser_actions
else:
raise ValueError(f'Invalid action type: {action.action}')
try:
# obs provided by BrowserGym: see https://github.com/ServiceNow/BrowserGym/blob/main/core/src/browsergym/core/env.py#L396
obs = browser.step(action_str)
return BrowserOutputObservation(
content=obs['text_content'], # text content of the page
url=obs.get('url', ''), # URL of the page
screenshot=obs.get('screenshot', None), # base64-encoded screenshot, png
open_pages_urls=obs.get('open_pages_urls', []), # list of open pages
active_page_index=obs.get(
'active_page_index', -1
), # index of the active page
dom_object=obs.get('dom_object', {}), # DOM object
axtree_object=obs.get('axtree_object', {}), # accessibility tree object
extra_element_properties=obs.get('extra_element_properties', {}),
focused_element_bid=obs.get(
'focused_element_bid', None
), # focused element bid
last_browser_action=obs.get(
'last_action', ''
), # last browser env action performed
last_browser_action_error=obs.get('last_action_error', ''),
error=True if obs.get('last_action_error', '') else False, # error flag
)
except Exception as e:
return BrowserOutputObservation(
content=str(e),
screenshot='',
error=True,
last_browser_action_error=str(e),
url=asked_url if action.action == ActionType.BROWSE else '',
)
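The legacy `BrowseURLAction` branch above reduces to a small URL-to-action-string translation. A self-contained sketch of just that mapping (the function name is ours, for illustration only):

```python
import os

def to_browsergym_action(url: str) -> str:
    # Mirror of the legacy BrowseURLAction handling above: URLs without an
    # http(s) scheme are treated as paths relative to the current directory.
    if not url.startswith('http'):
        url = os.path.abspath(os.curdir) + url
    return f'goto("{url}")'

print(to_browsergym_action('https://example.com'))  # goto("https://example.com")
```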


@@ -0,0 +1,4 @@
from .base import RuntimeBuilder
from .docker import DockerRuntimeBuilder
__all__ = ['RuntimeBuilder', 'DockerRuntimeBuilder']


@@ -0,0 +1,37 @@
import abc
class RuntimeBuilder(abc.ABC):
@abc.abstractmethod
def build(
self,
path: str,
tags: list[str],
) -> str:
"""
Build the runtime image.
Args:
path (str): The path to the runtime image's build directory.
tags (list[str]): The tags to apply to the runtime image (e.g., ["repo:my-repo", "sha:my-sha"]).
Returns:
str: The name of the runtime image (e.g., "repo:sha").
Raises:
RuntimeError: If the build failed.
"""
pass
@abc.abstractmethod
def image_exists(self, image_name: str) -> bool:
"""
Check if the runtime image exists.
Args:
image_name (str): The name of the runtime image (e.g., "repo:sha").
Returns:
bool: Whether the runtime image exists.
"""
pass
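The abstract class above only fixes the contract: `build` turns a build directory plus tags into an image name, and `image_exists` answers lookups by name. A hypothetical in-memory implementation (ours, purely to show the contract; it records tags and builds nothing):

```python
import abc

class RuntimeBuilder(abc.ABC):
    # Trimmed copy of the interface above, so this sketch is self-contained.
    @abc.abstractmethod
    def build(self, path: str, tags: list[str]) -> str: ...

    @abc.abstractmethod
    def image_exists(self, image_name: str) -> bool: ...

class InMemoryBuilder(RuntimeBuilder):
    """Hypothetical no-op builder: records tags instead of building images."""

    def __init__(self) -> None:
        self._images: set[str] = set()

    def build(self, path: str, tags: list[str]) -> str:
        if not tags:
            raise RuntimeError('at least one tag is required')
        self._images.update(tags)
        return tags[0]

    def image_exists(self, image_name: str) -> bool:
        return image_name in self._images

builder = InMemoryBuilder()
name = builder.build('/tmp/build-dir', ['repo:sha-abc', 'repo:latest'])
print(name, builder.image_exists('repo:latest'))  # repo:sha-abc True
```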


@@ -0,0 +1,83 @@
import docker
from openhands.core.logger import openhands_logger as logger
from .base import RuntimeBuilder
class DockerRuntimeBuilder(RuntimeBuilder):
def __init__(self, docker_client: docker.DockerClient):
self.docker_client = docker_client
def build(self, path: str, tags: list[str]) -> str:
target_image_hash_name = tags[0]
target_image_repo, target_image_hash_tag = target_image_hash_name.split(':')
target_image_tag = tags[1].split(':')[1] if len(tags) > 1 else None
try:
build_logs = self.docker_client.api.build(
path=path,
tag=target_image_hash_name,
rm=True,
decode=True,
)
except docker.errors.BuildError as e:
logger.error(f'Sandbox image build failed: {e}')
raise RuntimeError(f'Sandbox image build failed: {e}')
for log in build_logs:
if 'stream' in log:
print(log['stream'].strip())
elif 'error' in log:
logger.error(log['error'].strip())
else:
logger.info(str(log))
logger.info(f'Image [{target_image_hash_name}] build finished.')
assert (
target_image_tag
), f'Expected target image tag [{target_image_tag}] is None'
image = self.docker_client.images.get(target_image_hash_name)
image.tag(target_image_repo, target_image_tag)
logger.info(
f'Re-tagged image [{target_image_hash_name}] with more generic tag [{target_image_tag}]'
)
# Check if the image is built successfully
image = self.docker_client.images.get(target_image_hash_name)
if image is None:
raise RuntimeError(
f'Build failed: Image {target_image_hash_name} not found'
)
tags_str = (
f'{target_image_hash_tag}, {target_image_tag}'
if target_image_tag
else target_image_hash_tag
)
logger.info(
f'Image {target_image_repo} with tags [{tags_str}] built successfully'
)
return target_image_hash_name
def image_exists(self, image_name: str) -> bool:
"""Check if the image exists in the registry (try to pull it first) AND in the local store.
Args:
image_name (str): The Docker image to check (<image repo>:<image tag>)
Returns:
bool: Whether the Docker image exists in the registry and in the local store
"""
# Try to pull the Docker image from the registry
try:
self.docker_client.images.pull(image_name)
except Exception:
logger.info(f'Cannot pull image {image_name} directly')
images = self.docker_client.images.list()
if images:
for image in images:
if image_name in image.tags:
return True
return False


@@ -0,0 +1,685 @@
"""
This is the main file for the runtime client.
It is responsible for executing actions received from OpenHands backend and producing observations.
NOTE: this will be executed inside the docker sandbox.
"""
import argparse
import asyncio
import os
import re
import shutil
import subprocess
from contextlib import asynccontextmanager
from pathlib import Path
import pexpect
from fastapi import FastAPI, HTTPException, Request, UploadFile
from fastapi.responses import JSONResponse
from pathspec import PathSpec
from pathspec.patterns import GitWildMatchPattern
from pydantic import BaseModel
from uvicorn import run
from openhands.core.logger import openhands_logger as logger
from openhands.events.action import (
Action,
BrowseInteractiveAction,
BrowseURLAction,
CmdRunAction,
FileReadAction,
FileWriteAction,
IPythonRunCellAction,
)
from openhands.events.observation import (
CmdOutputObservation,
ErrorObservation,
FileReadObservation,
FileWriteObservation,
IPythonRunCellObservation,
Observation,
)
from openhands.events.serialization import event_from_dict, event_to_dict
from openhands.runtime.browser import browse
from openhands.runtime.browser.browser_env import BrowserEnv
from openhands.runtime.plugins import (
ALL_PLUGINS,
JupyterPlugin,
Plugin,
)
from openhands.runtime.utils import split_bash_commands
from openhands.runtime.utils.files import insert_lines, read_lines
class ActionRequest(BaseModel):
action: dict
ROOT_GID = 0
INIT_COMMANDS = [
'git config --global user.name "openhands"',
'git config --global user.email "openhands@all-hands.dev"',
"alias git='git --no-pager'",
]
class RuntimeClient:
"""RuntimeClient is running inside docker sandbox.
It is responsible for executing actions received from OpenHands backend and producing observations.
"""
def __init__(
self,
plugins_to_load: list[Plugin],
work_dir: str,
username: str,
user_id: int,
browsergym_eval_env: str | None,
) -> None:
self.plugins_to_load = plugins_to_load
self.username = username
self.user_id = user_id
self.pwd = work_dir # current PWD
self._initial_pwd = work_dir
self._init_user(self.username, self.user_id)
self._init_bash_shell(self.pwd, self.username)
self.lock = asyncio.Lock()
self.plugins: dict[str, Plugin] = {}
self.browser = BrowserEnv(browsergym_eval_env)
@property
def initial_pwd(self):
return self._initial_pwd
async def ainit(self):
for plugin in self.plugins_to_load:
await plugin.initialize(self.username)
self.plugins[plugin.name] = plugin
logger.info(f'Initializing plugin: {plugin.name}')
if isinstance(plugin, JupyterPlugin):
await self.run_ipython(
IPythonRunCellAction(code=f'import os; os.chdir("{self.pwd}")')
)
# This is a temporary workaround
# TODO: refactor AgentSkills to be part of JupyterPlugin
# AFTER ServerRuntime is deprecated
if 'agent_skills' in self.plugins and 'jupyter' in self.plugins:
obs = await self.run_ipython(
IPythonRunCellAction(
code='from openhands.runtime.plugins.agent_skills.agentskills import *\n'
)
)
logger.info(f'AgentSkills initialized: {obs}')
await self._init_bash_commands()
def _init_user(self, username: str, user_id: int) -> None:
"""Create user if not exists."""
# Skip root since it is already created
if username == 'root':
return
# Check if the username already exists
try:
subprocess.run(
f'id -u {username}', shell=True, check=True, capture_output=True
)
logger.debug(f'User {username} already exists. Skipping creation.')
return
except subprocess.CalledProcessError:
pass # User does not exist, continue with creation
# Add sudoer
sudoer_line = r"echo '%sudo ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers"
output = subprocess.run(sudoer_line, shell=True, capture_output=True)
if output.returncode != 0:
raise RuntimeError(f'Failed to add sudoer: {output.stderr.decode()}')
logger.debug(f'Added sudoer successfully. Output: [{output.stdout.decode()}]')
# Attempt to add the user, retrying with incremented user_id if necessary
while True:
command = (
f'useradd -rm -d /home/{username} -s /bin/bash '
f'-g root -G sudo -u {user_id} {username}'
)
if not os.path.exists(self.initial_pwd):
command += f' && mkdir -p {self.initial_pwd}'
command += f' && chown -R {username}:root {self.initial_pwd}'
command += f' && chmod g+s {self.initial_pwd}'
output = subprocess.run(command, shell=True, capture_output=True)
if output.returncode == 0:
logger.debug(
f'Added user {username} successfully with UID {user_id}. Output: [{output.stdout.decode()}]'
)
break
elif f'UID {user_id} is not unique' in output.stderr.decode():
logger.warning(
f'UID {user_id} is not unique. Incrementing UID and retrying...'
)
user_id += 1
else:
raise RuntimeError(
f'Failed to create user {username}: {output.stderr.decode()}'
)
def _init_bash_shell(self, work_dir: str, username: str) -> None:
self.shell = pexpect.spawn(
f'su - {username}',
encoding='utf-8',
echo=False,
)
self.__bash_PS1 = r'[PEXPECT_BEGIN] \u@\h:\w [PEXPECT_END]'
# This should NOT match "PS1=\u@\h:\w [PEXPECT]$" when `env` is executed
self.__bash_expect_regex = (
r'\[PEXPECT_BEGIN\] ([a-z0-9_-]*)@([a-zA-Z0-9.-]*):(.+) \[PEXPECT_END\]'
)
self.shell.sendline(f'export PS1="{self.__bash_PS1}"; export PS2=""')
self.shell.expect(self.__bash_expect_regex)
self.shell.sendline(f'cd {work_dir}')
self.shell.expect(self.__bash_expect_regex)
logger.debug(
f'Bash initialized. Working directory: {work_dir}. Output: {self.shell.before}'
)
async def _init_bash_commands(self):
logger.info(f'Initializing by running {len(INIT_COMMANDS)} bash commands...')
for command in INIT_COMMANDS:
action = CmdRunAction(command=command)
action.timeout = 300
logger.debug(f'Executing init command: {command}')
obs: CmdOutputObservation = await self.run(action)
logger.debug(
f'Init command outputs (exit code: {obs.exit_code}): {obs.content}'
)
assert obs.exit_code == 0
logger.info('Bash init commands completed')
def _get_bash_prompt_and_update_pwd(self):
ps1 = self.shell.after
        # begin at the last occurrence of '[PEXPECT_BEGIN]'.
# In multi-line bash commands, the prompt will be repeated
# and the matched regex captures all of them
# - we only want the last one (newest prompt)
_begin_pos = ps1.rfind('[PEXPECT_BEGIN]')
if _begin_pos != -1:
ps1 = ps1[_begin_pos:]
# parse the ps1 to get username, hostname, and working directory
matched = re.match(self.__bash_expect_regex, ps1)
assert (
matched is not None
), f'Failed to parse bash prompt: {ps1}. This should not happen.'
username, hostname, working_dir = matched.groups()
self.pwd = os.path.expanduser(working_dir)
# re-assemble the prompt
prompt = f'{username}@{hostname}:{working_dir} '
if username == 'root':
prompt += '#'
else:
prompt += '$'
return prompt + ' '
def _execute_bash(
self,
command: str,
timeout: int | None,
keep_prompt: bool = True,
) -> tuple[str, int]:
logger.debug(f'Executing command: {command}')
try:
self.shell.sendline(command)
self.shell.expect(self.__bash_expect_regex, timeout=timeout)
output = self.shell.before
# Get exit code
self.shell.sendline('echo $?')
logger.debug(f'Executing command for exit code: {command}')
self.shell.expect(self.__bash_expect_regex, timeout=timeout)
_exit_code_output = self.shell.before
logger.debug(f'Exit code Output: {_exit_code_output}')
exit_code = int(_exit_code_output.strip().split()[0])
except pexpect.TIMEOUT as e:
self.shell.sendintr() # send SIGINT to the shell
self.shell.expect(self.__bash_expect_regex, timeout=timeout)
output = self.shell.before
output += (
'\r\n\r\n'
+ f'[Command timed out after {timeout} seconds. SIGINT was sent to interrupt it.]'
)
exit_code = 130 # SIGINT
logger.error(f'Failed to execute command: {command}. Error: {e}')
finally:
bash_prompt = self._get_bash_prompt_and_update_pwd()
if keep_prompt:
output += '\r\n' + bash_prompt
logger.debug(f'Command output: {output}')
return output, exit_code
async def run_action(self, action) -> Observation:
action_type = action.action
observation = await getattr(self, action_type)(action)
return observation
async def run(self, action: CmdRunAction) -> CmdOutputObservation:
try:
assert (
action.timeout is not None
), f'Timeout argument is required for CmdRunAction: {action}'
commands = split_bash_commands(action.command)
all_output = ''
for command in commands:
output, exit_code = self._execute_bash(
command,
timeout=action.timeout,
keep_prompt=action.keep_prompt,
)
if all_output:
                    # previous output already ends with the prompt "user@hostname:working_dir #";
                    # append the command to it so the model knows the following
                    # is the output of another action
all_output = all_output.rstrip() + ' ' + command + '\r\n'
all_output += str(output) + '\r\n'
if exit_code != 0:
break
return CmdOutputObservation(
command_id=-1,
content=all_output.rstrip('\r\n'),
command=action.command,
exit_code=exit_code,
)
except UnicodeDecodeError:
raise RuntimeError('Command output could not be decoded as utf-8')
async def run_ipython(self, action: IPythonRunCellAction) -> Observation:
if 'jupyter' in self.plugins:
_jupyter_plugin: JupyterPlugin = self.plugins['jupyter'] # type: ignore
# This is used to make AgentSkills in Jupyter aware of the
# current working directory in Bash
if self.pwd != getattr(self, '_jupyter_pwd', None):
logger.debug(
f"{self.pwd} != {getattr(self, '_jupyter_pwd', None)} -> reset Jupyter PWD"
)
reset_jupyter_pwd_code = f'import os; os.chdir("{self.pwd}")'
_aux_action = IPythonRunCellAction(code=reset_jupyter_pwd_code)
_reset_obs = await _jupyter_plugin.run(_aux_action)
logger.debug(
f'Changed working directory in IPython to: {self.pwd}. Output: {_reset_obs}'
)
self._jupyter_pwd = self.pwd
obs: IPythonRunCellObservation = await _jupyter_plugin.run(action)
obs.content = obs.content.rstrip()
obs.content += f'\n[Jupyter current working directory: {self.pwd}]'
return obs
else:
raise RuntimeError(
'JupyterRequirement not found. Unable to run IPython action.'
)
def _get_working_directory(self):
# NOTE: this is part of initialization, so we hard code the timeout
result, exit_code = self._execute_bash('pwd', timeout=60, keep_prompt=False)
if exit_code != 0:
raise RuntimeError('Failed to get working directory')
return result.strip()
def _resolve_path(self, path: str, working_dir: str) -> str:
filepath = Path(path)
if not filepath.is_absolute():
return str(Path(working_dir) / filepath)
return str(filepath)
async def read(self, action: FileReadAction) -> Observation:
# NOTE: the client code is running inside the sandbox,
# so there's no need to check permission
working_dir = self._get_working_directory()
filepath = self._resolve_path(action.path, working_dir)
try:
with open(filepath, 'r', encoding='utf-8') as file:
lines = read_lines(file.readlines(), action.start, action.end)
except FileNotFoundError:
return ErrorObservation(
f'File not found: {filepath}. Your current working directory is {working_dir}.'
)
except UnicodeDecodeError:
return ErrorObservation(f'File could not be decoded as utf-8: {filepath}.')
except IsADirectoryError:
return ErrorObservation(
f'Path is a directory: {filepath}. You can only read files'
)
code_view = ''.join(lines)
return FileReadObservation(path=filepath, content=code_view)
async def write(self, action: FileWriteAction) -> Observation:
working_dir = self._get_working_directory()
filepath = self._resolve_path(action.path, working_dir)
insert = action.content.split('\n')
try:
if not os.path.exists(os.path.dirname(filepath)):
os.makedirs(os.path.dirname(filepath))
file_exists = os.path.exists(filepath)
if file_exists:
file_stat = os.stat(filepath)
else:
file_stat = None
mode = 'w' if not file_exists else 'r+'
try:
with open(filepath, mode, encoding='utf-8') as file:
if mode != 'w':
all_lines = file.readlines()
new_file = insert_lines(
insert, all_lines, action.start, action.end
)
else:
new_file = [i + '\n' for i in insert]
file.seek(0)
file.writelines(new_file)
file.truncate()
# Handle file permissions
if file_exists:
assert file_stat is not None
# restore the original file permissions if the file already exists
os.chmod(filepath, file_stat.st_mode)
os.chown(filepath, file_stat.st_uid, file_stat.st_gid)
else:
# set the new file permissions if the file is new
os.chmod(filepath, 0o644)
os.chown(filepath, self.user_id, self.user_id)
except FileNotFoundError:
return ErrorObservation(f'File not found: {filepath}')
except IsADirectoryError:
return ErrorObservation(
f'Path is a directory: {filepath}. You can only write to files'
)
except UnicodeDecodeError:
return ErrorObservation(
f'File could not be decoded as utf-8: {filepath}'
)
except PermissionError:
return ErrorObservation(f'Malformed paths not permitted: {filepath}')
return FileWriteObservation(content='', path=filepath)
async def browse(self, action: BrowseURLAction) -> Observation:
return await browse(action, self.browser)
async def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
return await browse(action, self.browser)
def close(self):
self.shell.close()
self.browser.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('port', type=int, help='Port to listen on')
parser.add_argument('--working-dir', type=str, help='Working directory')
parser.add_argument('--plugins', type=str, help='Plugins to initialize', nargs='+')
parser.add_argument(
'--username', type=str, help='User to run as', default='openhands'
)
parser.add_argument('--user-id', type=int, help='User ID to run as', default=1000)
parser.add_argument(
'--browsergym-eval-env',
type=str,
help='BrowserGym environment used for browser evaluation',
default=None,
)
# example: python client.py 8000 --working-dir /workspace --plugins JupyterRequirement
args = parser.parse_args()
plugins_to_load: list[Plugin] = []
if args.plugins:
for plugin in args.plugins:
if plugin not in ALL_PLUGINS:
raise ValueError(f'Plugin {plugin} not found')
plugins_to_load.append(ALL_PLUGINS[plugin]()) # type: ignore
client: RuntimeClient | None = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global client
client = RuntimeClient(
plugins_to_load,
work_dir=args.working_dir,
username=args.username,
user_id=args.user_id,
browsergym_eval_env=args.browsergym_eval_env,
)
await client.ainit()
yield
# Clean up & release the resources
client.close()
app = FastAPI(lifespan=lifespan)
@app.middleware('http')
async def one_request_at_a_time(request: Request, call_next):
assert client is not None
async with client.lock:
response = await call_next(request)
return response
@app.post('/execute_action')
async def execute_action(action_request: ActionRequest):
assert client is not None
try:
action = event_from_dict(action_request.action)
if not isinstance(action, Action):
raise HTTPException(status_code=400, detail='Invalid action type')
observation = await client.run_action(action)
return event_to_dict(observation)
except Exception as e:
logger.error(f'Error processing command: {str(e)}')
raise HTTPException(status_code=500, detail=str(e))
@app.post('/upload_file')
async def upload_file(
file: UploadFile, destination: str = '/', recursive: bool = False
):
assert client is not None
try:
# Ensure the destination directory exists
if not os.path.isabs(destination):
raise HTTPException(
status_code=400, detail='Destination must be an absolute path'
)
full_dest_path = destination
if not os.path.exists(full_dest_path):
os.makedirs(full_dest_path, exist_ok=True)
if recursive:
# For recursive uploads, we expect a zip file
if not file.filename.endswith('.zip'):
raise HTTPException(
status_code=400, detail='Recursive uploads must be zip files'
)
zip_path = os.path.join(full_dest_path, file.filename)
with open(zip_path, 'wb') as buffer:
shutil.copyfileobj(file.file, buffer)
# Extract the zip file
shutil.unpack_archive(zip_path, full_dest_path)
os.remove(zip_path) # Remove the zip file after extraction
logger.info(
f'Uploaded file {file.filename} and extracted to {destination}'
)
else:
# For single file uploads
file_path = os.path.join(full_dest_path, file.filename)
with open(file_path, 'wb') as buffer:
shutil.copyfileobj(file.file, buffer)
logger.info(f'Uploaded file {file.filename} to {destination}')
return JSONResponse(
content={
'filename': file.filename,
'destination': destination,
'recursive': recursive,
},
status_code=200,
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get('/alive')
async def alive():
return {'status': 'ok'}
# ================================
# File-specific operations for UI
# ================================
@app.post('/list_files')
async def list_files(request: Request):
"""List files in the specified path.
This function retrieves a list of files from the agent's runtime file store,
excluding certain system and hidden files/directories.
To list files:
```sh
curl http://localhost:3000/api/list-files
```
Args:
request (Request): The incoming request object.
path (str, optional): The path to list files from. Defaults to '/'.
Returns:
list: A list of file names in the specified path.
Raises:
HTTPException: If there's an error listing the files.
"""
assert client is not None
# get request as dict
request_dict = await request.json()
path = request_dict.get('path', None)
# Get the full path of the requested directory
if path is None:
full_path = client.initial_pwd
elif os.path.isabs(path):
full_path = path
else:
full_path = os.path.join(client.initial_pwd, path)
if not os.path.exists(full_path):
return JSONResponse(
content={'error': f'Directory {full_path} does not exist'},
status_code=400,
)
try:
# Check if the directory exists
if not os.path.exists(full_path) or not os.path.isdir(full_path):
return []
# Check if .gitignore exists
gitignore_path = os.path.join(full_path, '.gitignore')
if os.path.exists(gitignore_path):
# Use PathSpec to parse .gitignore
with open(gitignore_path, 'r') as f:
spec = PathSpec.from_lines(GitWildMatchPattern, f.readlines())
else:
# Fallback to default exclude list if .gitignore doesn't exist
default_exclude = [
'.git',
'.DS_Store',
'.svn',
'.hg',
'.idea',
'.vscode',
'.settings',
'.pytest_cache',
'__pycache__',
'node_modules',
'vendor',
'build',
'dist',
'bin',
'logs',
'log',
'tmp',
'temp',
'coverage',
'venv',
'env',
]
spec = PathSpec.from_lines(GitWildMatchPattern, default_exclude)
entries = os.listdir(full_path)
# Filter entries using PathSpec
filtered_entries = [
os.path.join(full_path, entry)
for entry in entries
if not spec.match_file(os.path.relpath(entry, str(full_path)))
]
# Separate directories and files
directories = []
files = []
for entry in filtered_entries:
# Remove leading slash and any parent directory components
entry_relative = entry.lstrip('/').split('/')[-1]
# Construct the full path by joining the base path with the relative entry path
full_entry_path = os.path.join(full_path, entry_relative)
if os.path.exists(full_entry_path):
is_dir = os.path.isdir(full_entry_path)
if is_dir:
# add trailing slash to directories
# required by FE to differentiate directories and files
entry = entry.rstrip('/') + '/'
directories.append(entry)
else:
files.append(entry)
# Sort directories and files separately
directories.sort(key=lambda s: s.lower())
files.sort(key=lambda s: s.lower())
# Combine sorted directories and files
sorted_entries = directories + files
return sorted_entries
except Exception as e:
logger.error(f'Error listing files: {e}', exc_info=True)
return []
logger.info(f'Starting action execution API on port {args.port}')
print(f'Starting action execution API on port {args.port}')
run(app, host='0.0.0.0', port=args.port)
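The prompt-recovery logic in `_get_bash_prompt_and_update_pwd` above (keep only the newest `[PEXPECT_BEGIN]` marker, then regex out the user, host, and working directory) can be exercised in isolation. A minimal sketch, reusing the same regex as the client:

```python
import re

BASH_EXPECT_REGEX = (
    r'\[PEXPECT_BEGIN\] ([a-z0-9_-]*)@([a-zA-Z0-9.-]*):(.+) \[PEXPECT_END\]'
)

def parse_prompt(ps1: str) -> tuple[str, str, str]:
    # Same recovery steps as _get_bash_prompt_and_update_pwd: in multi-line
    # commands the prompt repeats, so keep only the last occurrence, then
    # extract username, hostname, and working directory.
    begin = ps1.rfind('[PEXPECT_BEGIN]')
    if begin != -1:
        ps1 = ps1[begin:]
    matched = re.match(BASH_EXPECT_REGEX, ps1)
    if matched is None:
        raise ValueError(f'Failed to parse bash prompt: {ps1}')
    username, hostname, working_dir = matched.groups()
    return username, hostname, working_dir

print(parse_prompt('[PEXPECT_BEGIN] openhands@sandbox:/workspace [PEXPECT_END]'))
```

Setting `PS1` to a sentinel-wrapped format and re-parsing it after every command is what lets the client track the shell's working directory without issuing an extra `pwd`.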


@@ -0,0 +1,386 @@
import asyncio
import os
import tempfile
import uuid
from zipfile import ZipFile
import aiohttp
import docker
import tenacity
from openhands.core.config import AppConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventStream
from openhands.events.action import (
BrowseInteractiveAction,
BrowseURLAction,
CmdRunAction,
FileReadAction,
FileWriteAction,
IPythonRunCellAction,
)
from openhands.events.action.action import Action
from openhands.events.observation import (
ErrorObservation,
NullObservation,
Observation,
)
from openhands.events.serialization import event_to_dict, observation_from_dict
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
from openhands.runtime.builder import DockerRuntimeBuilder
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.runtime import Runtime
from openhands.runtime.utils import find_available_tcp_port
from openhands.runtime.utils.runtime_build import build_runtime_image
class EventStreamRuntime(Runtime):
"""This runtime will subscribe the event stream.
When receive an event, it will send the event to od-runtime-client which run inside the docker environment.
"""
container_name_prefix = 'openhands-sandbox-'
def __init__(
self,
config: AppConfig,
event_stream: EventStream,
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
container_image: str | None = None,
):
super().__init__(
config, event_stream, sid, plugins
) # will initialize the event stream
self._port = find_available_tcp_port()
self.api_url = f'http://{self.config.sandbox.api_hostname}:{self._port}'
self.session: aiohttp.ClientSession | None = None
self.instance_id = (
sid + '_' + str(uuid.uuid4()) if sid is not None else str(uuid.uuid4())
)
# TODO: We can switch to aiodocker when `get_od_sandbox_image` is updated to use aiodocker
self.docker_client: docker.DockerClient = self._init_docker_client()
self.container_image = (
self.config.sandbox.container_image
if container_image is None
else container_image
)
self.container_name = self.container_name_prefix + self.instance_id
self.container = None
self.action_semaphore = asyncio.Semaphore(1) # Ensure one action at a time
self.runtime_builder = DockerRuntimeBuilder(self.docker_client)
logger.debug(f'EventStreamRuntime `{sid}` config:\n{self.config}')
async def ainit(self, env_vars: dict[str, str] | None = None):
if self.config.sandbox.od_runtime_extra_deps:
logger.info(
f'Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.od_runtime_extra_deps}'
)
self.container_image = build_runtime_image(
self.container_image,
self.runtime_builder,
extra_deps=self.config.sandbox.od_runtime_extra_deps,
)
self.container = await self._init_container(
self.sandbox_workspace_dir,
mount_dir=self.config.workspace_mount_path,
plugins=self.plugins,
)
# MUST call super().ainit() to initialize both default env vars
# AND the ones in env vars!
await super().ainit(env_vars)
logger.info(
f'Container initialized with plugins: {[plugin.name for plugin in self.plugins]}'
)
logger.info(f'Container initialized with env vars: {env_vars}')
@staticmethod
def _init_docker_client() -> docker.DockerClient:
try:
return docker.from_env()
except Exception as ex:
logger.error(
'Launch docker client failed. Please make sure you have installed docker and started the docker daemon.'
)
raise ex
@tenacity.retry(
stop=tenacity.stop_after_attempt(5),
wait=tenacity.wait_exponential(multiplier=1, min=4, max=60),
)
async def _init_container(
self,
sandbox_workspace_dir: str,
mount_dir: str | None = None,
plugins: list[PluginRequirement] | None = None,
):
try:
logger.info(
f'Starting container with image: {self.container_image} and name: {self.container_name}'
)
plugin_arg = ''
if plugins is not None and len(plugins) > 0:
plugin_arg = (
f'--plugins {" ".join([plugin.name for plugin in plugins])} '
)
network_mode: str | None = None
port_mapping: dict[str, int] | None = None
if self.config.sandbox.use_host_network:
network_mode = 'host'
logger.warning(
'Using host network mode. If you are using macOS, please make sure you have the latest version of Docker Desktop and enabled the host network feature: https://docs.docker.com/network/drivers/host/#docker-desktop'
)
else:
port_mapping = {f'{self._port}/tcp': self._port}
if mount_dir is not None:
volumes = {mount_dir: {'bind': sandbox_workspace_dir, 'mode': 'rw'}}
logger.info(f'Mount dir: {sandbox_workspace_dir}')
else:
logger.warning(
'Mount dir is not set, will not mount the workspace directory to the container.'
)
volumes = None
if self.config.sandbox.browsergym_eval_env is not None:
browsergym_arg = (
f'--browsergym-eval-env {self.config.sandbox.browsergym_eval_env}'
)
else:
browsergym_arg = ''
container = self.docker_client.containers.run(
self.container_image,
command=(
f'/openhands/miniforge3/bin/mamba run --no-capture-output -n base '
'PYTHONUNBUFFERED=1 poetry run '
f'python -u -m openhands.runtime.client.client {self._port} '
f'--working-dir {sandbox_workspace_dir} '
f'{plugin_arg}'
f'--username {"openhands" if self.config.run_as_openhands else "root"} '
f'--user-id {self.config.sandbox.user_id} '
f'{browsergym_arg}'
),
network_mode=network_mode,
ports=port_mapping,
working_dir='/openhands/code/',
name=self.container_name,
detach=True,
environment={'DEBUG': 'true'} if self.config.debug else None,
volumes=volumes,
)
logger.info(f'Container started. Server url: {self.api_url}')
return container
except Exception as e:
logger.error('Failed to start container')
logger.exception(e)
await self.close(close_client=False)
raise e
async def _ensure_session(self):
if self.session is None or self.session.closed:
self.session = aiohttp.ClientSession()
return self.session
@tenacity.retry(
stop=tenacity.stop_after_attempt(10),
wait=tenacity.wait_exponential(multiplier=2, min=10, max=60),
)
async def _wait_until_alive(self):
logger.debug('Getting container logs...')
container = self.docker_client.containers.get(self.container_name)
# get logs
_logs = container.logs(tail=10).decode('utf-8').split('\n')
# add indent
_logs = '\n'.join([f' |{log}' for log in _logs])
logger.info(
'\n'
+ '-' * 30
+ 'Container logs (last 10 lines):'
+ '-' * 30
+ f'\n{_logs}'
+ '\n'
+ '-' * 90
)
async with aiohttp.ClientSession() as session:
async with session.get(f'{self.api_url}/alive') as response:
if response.status == 200:
return
else:
msg = f'Action execution API is not alive. Response: {response}'
logger.error(msg)
raise RuntimeError(msg)
@property
def sandbox_workspace_dir(self):
return self.config.workspace_mount_path_in_sandbox
async def close(self, close_client: bool = True):
if self.session is not None and not self.session.closed:
await self.session.close()
containers = self.docker_client.containers.list(all=True)
for container in containers:
try:
if container.name.startswith(self.container_name_prefix):
logs = container.logs(tail=1000).decode('utf-8')
logger.debug(
f'==== Container logs ====\n{logs}\n==== End of container logs ===='
)
container.remove(force=True)
except docker.errors.NotFound:
pass
if close_client:
self.docker_client.close()
async def run_action(self, action: Action) -> Observation:
# set timeout to default if not set
if action.timeout is None:
action.timeout = self.config.sandbox.timeout
async with self.action_semaphore:
if not action.runnable:
return NullObservation('')
action_type = action.action # type: ignore[attr-defined]
if action_type not in ACTION_TYPE_TO_CLASS:
return ErrorObservation(f'Action {action_type} does not exist.')
if not hasattr(self, action_type):
return ErrorObservation(
f'Action {action_type} is not supported in the current runtime.'
)
logger.info('Awaiting session')
session = await self._ensure_session()
await self._wait_until_alive()
assert action.timeout is not None
try:
logger.info('Executing command')
async with session.post(
f'{self.api_url}/execute_action',
json={'action': event_to_dict(action)},
timeout=action.timeout,
) as response:
if response.status == 200:
output = await response.json()
obs = observation_from_dict(output)
obs._cause = action.id # type: ignore[attr-defined]
return obs
else:
error_message = await response.text()
logger.error(f'Error from server: {error_message}')
obs = ErrorObservation(
f'Command execution failed: {error_message}'
)
except asyncio.TimeoutError:
logger.error('No response received within the timeout period.')
obs = ErrorObservation('Command execution timed out')
except Exception as e:
logger.error(f'Error during command execution: {e}')
obs = ErrorObservation(f'Command execution failed: {str(e)}')
return obs
async def run(self, action: CmdRunAction) -> Observation:
return await self.run_action(action)
async def run_ipython(self, action: IPythonRunCellAction) -> Observation:
return await self.run_action(action)
async def read(self, action: FileReadAction) -> Observation:
return await self.run_action(action)
async def write(self, action: FileWriteAction) -> Observation:
return await self.run_action(action)
async def browse(self, action: BrowseURLAction) -> Observation:
return await self.run_action(action)
async def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
return await self.run_action(action)
# ====================================================================
# Implement these methods (for file operations) in the subclass
# ====================================================================
async def copy_to(
self, host_src: str, sandbox_dest: str, recursive: bool = False
) -> None:
if not os.path.exists(host_src):
raise FileNotFoundError(f'Source file {host_src} does not exist')
session = await self._ensure_session()
await self._wait_until_alive()
try:
if recursive:
# For recursive copy, create a zip file
with tempfile.NamedTemporaryFile(
suffix='.zip', delete=False
) as temp_zip:
temp_zip_path = temp_zip.name
with ZipFile(temp_zip_path, 'w') as zipf:
for root, _, files in os.walk(host_src):
for file in files:
file_path = os.path.join(root, file)
arcname = os.path.relpath(
file_path, os.path.dirname(host_src)
)
zipf.write(file_path, arcname)
upload_data = {'file': open(temp_zip_path, 'rb')}
else:
# For single file copy
upload_data = {'file': open(host_src, 'rb')}
params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
async with session.post(
f'{self.api_url}/upload_file', data=upload_data, params=params
) as response:
if response.status == 200:
return
else:
error_message = await response.text()
raise Exception(f'Copy operation failed: {error_message}')
except asyncio.TimeoutError:
raise TimeoutError('Copy operation timed out')
except Exception as e:
raise RuntimeError(f'Copy operation failed: {str(e)}')
finally:
if recursive:
os.unlink(temp_zip_path)
logger.info(f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}')
async def list_files(self, path: str | None = None) -> list[str]:
"""List files in the sandbox.
If path is None, list files in the sandbox's initial working directory (e.g., /workspace).
"""
session = await self._ensure_session()
await self._wait_until_alive()
try:
data = {}
if path is not None:
data['path'] = path
async with session.post(
f'{self.api_url}/list_files', json=data
) as response:
if response.status == 200:
response_json = await response.json()
assert isinstance(response_json, list)
return response_json
else:
error_message = await response.text()
raise Exception(f'List files operation failed: {error_message}')
except asyncio.TimeoutError:
raise TimeoutError('List files operation timed out')
except Exception as e:
raise RuntimeError(f'List files operation failed: {str(e)}')

View File

@@ -0,0 +1,35 @@
# How to use E2B
[E2B](https://e2b.dev) is an [open-source](https://github.com/e2b-dev/e2b) secure cloud environment (sandbox) made for running AI-generated code and agents. E2B offers [Python](https://pypi.org/project/e2b/) and [JS/TS](https://www.npmjs.com/package/e2b) SDKs to spawn and control these sandboxes.
## Getting started
1. [Get your API key](https://e2b.dev/docs/getting-started/api-key)
1. Set the `E2B_API_KEY` env var to your E2B API key when starting the Docker container
1. **Optional** - Install the CLI with NPM.
```sh
npm install -g @e2b/cli@latest
```
Full CLI API is [here](https://e2b.dev/docs/cli/installation).
## OpenHands sandbox
You can use the E2B CLI to create a custom sandbox with a Dockerfile. Read the full guide [here](https://e2b.dev/docs/guide/custom-sandbox). The premade OpenHands sandbox for E2B is set up in the [`containers` directory](/containers/e2b-sandbox), and it is called `openhands`.
## Debugging
You can connect to a running E2B sandbox with E2B CLI in your terminal.
- List all running sandboxes (based on your API key)
```sh
e2b sandbox list
```
- Connect to a running sandbox
```sh
e2b sandbox connect <sandbox-id>
```
## Links
- [E2B Docs](https://e2b.dev/docs)
- [E2B GitHub](https://github.com/e2b-dev/e2b)
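When OpenHands runs a command in the sandbox, the `E2BBox` wrapper added in this commit follows a start, wait-with-timeout, kill-on-expiry pattern and returns an `(exit_code, output)` tuple, with `-1` signalling a timeout. A minimal local sketch of that same pattern, using the stdlib `subprocess` module in place of the E2B process API (the helper name `run_with_timeout` is illustrative, not part of any SDK):

```python
import subprocess

def run_with_timeout(cmd: str, timeout: float = 10.0) -> tuple[int, str]:
    """Start a process, wait up to `timeout` seconds, kill it on expiry.

    Mirrors the shape of E2BBox.execute: returns (exit_code, output),
    with -1 signalling a timeout.
    """
    proc = subprocess.Popen(
        cmd,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    try:
        output, _ = proc.communicate(timeout=timeout)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.communicate()  # reap the killed process
        return -1, f'Command: "{cmd}" timed out'
    return proc.returncode, output

print(run_with_timeout('echo hello'))  # → (0, 'hello\n')
```

The real implementation differs in that the process runs remotely inside the E2B sandbox and output is collected from streamed messages rather than a pipe.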

View File

@@ -0,0 +1,18 @@
from openhands.storage.files import FileStore
class E2BFileStore(FileStore):
def __init__(self, filesystem):
self.filesystem = filesystem
def write(self, path: str, contents: str) -> None:
self.filesystem.write(path, contents)
def read(self, path: str) -> str:
return self.filesystem.read(path)
def list(self, path: str) -> list[str]:
return self.filesystem.list(path)
def delete(self, path: str) -> None:
self.filesystem.delete(path)

View File

@@ -0,0 +1,57 @@
from openhands.core.config import AppConfig
from openhands.events.action import (
FileReadAction,
FileWriteAction,
)
from openhands.events.observation import (
ErrorObservation,
FileReadObservation,
FileWriteObservation,
Observation,
)
from openhands.events.stream import EventStream
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.runtime import Runtime
from ..utils.files import insert_lines, read_lines
from .filestore import E2BFileStore
from .sandbox import E2BSandbox
class E2BRuntime(Runtime):
def __init__(
self,
config: AppConfig,
event_stream: EventStream,
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
sandbox: E2BSandbox | None = None,
):
super().__init__(config, event_stream, sid, plugins)
if sandbox is None:
self.sandbox = E2BSandbox()
else:
self.sandbox = sandbox
if not isinstance(self.sandbox, E2BSandbox):
raise ValueError('E2BRuntime requires an E2BSandbox')
self.file_store = E2BFileStore(self.sandbox.filesystem)
async def read(self, action: FileReadAction) -> Observation:
content = self.file_store.read(action.path)
lines = read_lines(content.split('\n'), action.start, action.end)
code_view = ''.join(lines)
return FileReadObservation(code_view, path=action.path)
async def write(self, action: FileWriteAction) -> Observation:
if action.start == 0 and action.end == -1:
self.file_store.write(action.path, action.content)
return FileWriteObservation(content='', path=action.path)
files = self.file_store.list(action.path)
if action.path in files:
all_lines = self.file_store.read(action.path).split('\n')
new_file = insert_lines(
action.content.split('\n'), all_lines, action.start, action.end
)
self.file_store.write(action.path, ''.join(new_file))
return FileWriteObservation('', path=action.path)
else:
# FIXME: we should create a new file here
return ErrorObservation(f'File not found: {action.path}')

View File

@@ -0,0 +1,116 @@
import copy
import os
import tarfile
from glob import glob
from e2b import Sandbox as E2BSandbox
from e2b.sandbox.exception import (
TimeoutException,
)
from openhands.core.config import SandboxConfig
from openhands.core.logger import openhands_logger as logger
class E2BBox:
closed = False
_cwd: str = '/home/user'
_env: dict[str, str] = {}
is_initial_session: bool = True
def __init__(
self,
config: SandboxConfig,
e2b_api_key: str,
template: str = 'openhands',
):
self.config = copy.deepcopy(config)
self.initialize_plugins: bool = config.initialize_plugins
self.sandbox = E2BSandbox(
api_key=e2b_api_key,
template=template,
# It's possible to stream stdout and stderr from sandbox and from each process
on_stderr=lambda x: logger.info(f'E2B sandbox stderr: {x}'),
on_stdout=lambda x: logger.info(f'E2B sandbox stdout: {x}'),
cwd=self._cwd, # Default workdir inside sandbox
)
logger.info(f'Started E2B sandbox with ID "{self.sandbox.id}"')
@property
def filesystem(self):
return self.sandbox.filesystem
def _archive(self, host_src: str, recursive: bool = False):
if recursive:
assert os.path.isdir(
host_src
), 'Source must be a directory when recursive is True'
files = glob(host_src + '/**/*', recursive=True)
srcname = os.path.basename(host_src)
tar_filename = os.path.join(os.path.dirname(host_src), srcname + '.tar')
with tarfile.open(tar_filename, mode='w') as tar:
for file in files:
tar.add(
file, arcname=os.path.relpath(file, os.path.dirname(host_src))
)
else:
assert os.path.isfile(
host_src
), 'Source must be a file when recursive is False'
srcname = os.path.basename(host_src)
tar_filename = os.path.join(os.path.dirname(host_src), srcname + '.tar')
with tarfile.open(tar_filename, mode='w') as tar:
tar.add(host_src, arcname=srcname)
return tar_filename
def execute(self, cmd: str, timeout: int | None = None) -> tuple[int, str]:
timeout = timeout if timeout is not None else self.config.timeout
process = self.sandbox.process.start(cmd, env_vars=self._env)
try:
process_output = process.wait(timeout=timeout)
except TimeoutException:
logger.info('Command timed out, killing process...')
process.kill()
return -1, f'Command: "{cmd}" timed out'
logs = [m.line for m in process_output.messages]
logs_str = '\n'.join(logs)
if process.exit_code is None:
return -1, logs_str
assert process_output.exit_code is not None
return process_output.exit_code, logs_str
def copy_to(self, host_src: str, sandbox_dest: str, recursive: bool = False):
"""Copies a local file or directory to the sandbox."""
tar_filename = self._archive(host_src, recursive)
# Prepend the sandbox destination with our sandbox cwd
sandbox_dest = os.path.join(self._cwd, sandbox_dest.removeprefix('/'))
with open(tar_filename, 'rb') as tar_file:
# Upload the archive to /home/user (default destination that always exists)
uploaded_path = self.sandbox.upload_file(tar_file)
# Check if sandbox_dest exists. If not, create it.
process = self.sandbox.process.start_and_wait(f'test -d {sandbox_dest}')
if process.exit_code != 0:
self.sandbox.filesystem.make_dir(sandbox_dest)
# Extract the archive into the destination and delete the archive
process = self.sandbox.process.start_and_wait(
f'sudo tar -xf {uploaded_path} -C {sandbox_dest} && sudo rm {uploaded_path}'
)
if process.exit_code != 0:
raise Exception(
f'Failed to extract {uploaded_path} to {sandbox_dest}: {process.stderr}'
)
# Delete the local archive
os.remove(tar_filename)
def close(self):
self.sandbox.close()
def get_working_directory(self):
return self.sandbox.cwd

View File

@@ -0,0 +1,18 @@
# Requirements
from .agent_skills import AgentSkillsPlugin, AgentSkillsRequirement
from .jupyter import JupyterPlugin, JupyterRequirement
from .requirement import Plugin, PluginRequirement
__all__ = [
'Plugin',
'PluginRequirement',
'AgentSkillsRequirement',
'AgentSkillsPlugin',
'JupyterRequirement',
'JupyterPlugin',
]
ALL_PLUGINS = {
'jupyter': JupyterPlugin,
'agent_skills': AgentSkillsPlugin,
}

View File

@@ -0,0 +1,57 @@
# OpenHands Skill Sets
This folder implements a skill/tool set, `agentskills`, for OpenHands.
It is intended to be used by the agent **inside the sandbox**.
The skill set will be exposed as a `pip` package that can be installed as a plugin inside the sandbox.
The skill set can contain a bunch of wrapped tools for the agent ([many examples here](https://github.com/All-Hands-AI/OpenHands/pull/1914)), for example:
- Audio/Video to text (these are a temporary solution; we should switch to multimodal models when they are sufficiently cheap)
- PDF to text
- etc.
# Inclusion Criteria
We are walking a fine line here.
We DON'T want to *wrap* every possible Python package and re-teach the agent its usage (e.g., the LLM already knows `pandas` pretty well, so we don't really need to create a skill that reads `csv` files - it can just use `pandas`).
We ONLY want to add a new skill when:
- The skill is not easily achievable by having the LLM write code directly (e.g., editing code and replacing certain lines)
- It involves calling an external model (e.g., a speech-to-text model, or an editor model for speculative editing)
# Intended functionality
- Tool/skill usage (through `IPythonRunCellAction`)
```python
# In[1]
from agentskills import open_file, edit_file
open_file("/workspace/a.txt")
# Out[1]
[SWE-agent open output]
# In[2]
edit_file(
"/workspace/a.txt",
start=1, end=3,
content=(
("REPLACE TEXT")
))
# Out[2]
[SWE-agent edit output]
```
- Tool/skill retrieval (through `IPythonRunCellAction`)
```python
# In[1]
from agentskills import help_me
help_me("I want to solve a task that involves reading a bunch of PDFs and reason about them")
# Out[1]
"Here are the top skills that may be helpful to you:
- `pdf_to_text`: [documentation about the tools]
...
"
```

View File

@@ -0,0 +1,15 @@
from dataclasses import dataclass
from openhands.runtime.plugins.requirement import Plugin, PluginRequirement
from . import agentskills
@dataclass
class AgentSkillsRequirement(PluginRequirement):
name: str = 'agent_skills'
documentation: str = agentskills.DOCUMENTATION
class AgentSkillsPlugin(Plugin):
name: str = 'agent_skills'

View File

@@ -0,0 +1,25 @@
from inspect import signature
from . import file_ops, file_reader
from .utils.dependency import import_functions
import_functions(
module=file_ops, function_names=file_ops.__all__, target_globals=globals()
)
import_functions(
module=file_reader, function_names=file_reader.__all__, target_globals=globals()
)
__all__ = file_ops.__all__ + file_reader.__all__
DOCUMENTATION = ''
for func_name in __all__:
func = globals()[func_name]
cur_doc = func.__doc__
# remove indentation from docstring and extra empty lines
cur_doc = '\n'.join(filter(None, map(lambda x: x.strip(), cur_doc.split('\n'))))
# now add a consistent 4 indentation
cur_doc = '\n'.join(map(lambda x: ' ' * 4 + x, cur_doc.split('\n')))
fn_signature = f'{func.__name__}' + str(signature(func))
DOCUMENTATION += f'{fn_signature}:\n{cur_doc}\n\n'

View File

@@ -0,0 +1,7 @@
from ..utils.dependency import import_functions
from . import file_ops
import_functions(
module=file_ops, function_names=file_ops.__all__, target_globals=globals()
)
__all__ = file_ops.__all__

View File

@@ -0,0 +1,857 @@
"""file_ops.py
This module provides various file manipulation skills for the OpenHands agent.
Functions:
- open_file(path: str, line_number: int | None = 1, context_lines: int = 100): Opens a file and optionally moves to a specific line.
- goto_line(line_number: int): Moves the window to show the specified line number.
- scroll_down(): Moves the window down by the number of lines specified in WINDOW.
- scroll_up(): Moves the window up by the number of lines specified in WINDOW.
- create_file(filename: str): Creates and opens a new file with the given name.
- search_dir(search_term: str, dir_path: str = './'): Searches for a term in all files in the specified directory.
- search_file(search_term: str, file_path: str | None = None): Searches for a term in the specified file or the currently open file.
- find_file(file_name: str, dir_path: str = './'): Finds all files with the given name in the specified directory.
- edit_file_by_replace(file_name: str, to_replace: str, new_content: str): Replaces specific content in a file with new content.
- insert_content_at_line(file_name: str, line_number: int, content: str): Inserts given content at the specified line number in a file.
- append_file(file_name: str, content: str): Appends the given content to the end of the specified file.
"""
import os
import re
import shutil
import tempfile
if __package__ is None or __package__ == '':
from aider import Linter
else:
from ..utils.aider import Linter
CURRENT_FILE: str | None = None
CURRENT_LINE = 1
WINDOW = 100
# This is also used in unit tests!
MSG_FILE_UPDATED = '[File updated (edited at line {line_number}). Please review the changes and make sure they are correct (correct indentation, no duplicate lines, etc). Edit the file again if necessary.]'
# ==================================================================================================
def _is_valid_filename(file_name) -> bool:
if not file_name or not isinstance(file_name, str) or not file_name.strip():
return False
invalid_chars = '<>:"/\\|?*'
if os.name == 'nt': # Windows
invalid_chars = '<>:"/\\|?*'
elif os.name == 'posix': # Unix-like systems
invalid_chars = '\0'
for char in invalid_chars:
if char in file_name:
return False
return True
def _is_valid_path(path) -> bool:
if not path or not isinstance(path, str):
return False
try:
return os.path.exists(os.path.normpath(path))
except PermissionError:
return False
def _create_paths(file_name) -> bool:
try:
dirname = os.path.dirname(file_name)
if dirname:
os.makedirs(dirname, exist_ok=True)
return True
except PermissionError:
return False
def _check_current_file(file_path: str | None = None) -> bool:
global CURRENT_FILE
if not file_path:
file_path = CURRENT_FILE
if not file_path or not os.path.isfile(file_path):
raise ValueError('No file open. Use the open_file function first.')
return True
def _clamp(value, min_value, max_value):
return max(min_value, min(value, max_value))
def _lint_file(file_path: str) -> tuple[str | None, int | None]:
"""Lint the file at the given path and return a tuple with a boolean indicating if there are errors,
and the line number of the first error, if any.
Returns:
tuple[str | None, int | None]: (lint_error, first_error_line_number)
"""
linter = Linter(root=os.getcwd())
lint_error = linter.lint(file_path)
if not lint_error:
# Linting successful. No issues found.
return None, None
return 'ERRORS:\n' + lint_error.text, lint_error.lines[0]
def _print_window(file_path, targeted_line, window, return_str=False):
global CURRENT_LINE
_check_current_file(file_path)
with open(file_path) as file:
content = file.read()
# Ensure the content ends with a newline character
if not content.endswith('\n'):
content += '\n'
lines = content.splitlines(True) # Keep all line ending characters
total_lines = len(lines)
# cover edge cases
CURRENT_LINE = _clamp(targeted_line, 1, total_lines)
half_window = max(1, window // 2)
# Ensure at least one line above and below the targeted line
start = max(1, CURRENT_LINE - half_window)
end = min(total_lines, CURRENT_LINE + half_window)
# Adjust start and end to ensure at least one line above and below
if start == 1:
end = min(total_lines, start + window - 1)
if end == total_lines:
start = max(1, end - window + 1)
output = ''
# only display this when there's at least one line above
if start > 1:
output += f'({start - 1} more lines above)\n'
else:
output += '(this is the beginning of the file)\n'
for i in range(start, end + 1):
_new_line = f'{i}|{lines[i-1]}'
if not _new_line.endswith('\n'):
_new_line += '\n'
output += _new_line
if end < total_lines:
output += f'({total_lines - end} more lines below)\n'
else:
output += '(this is the end of the file)\n'
output = output.rstrip()
if return_str:
return output
else:
print(output)
def _cur_file_header(current_file, total_lines) -> str:
if not current_file:
return ''
return f'[File: {os.path.abspath(current_file)} ({total_lines} lines total)]\n'
def open_file(
path: str, line_number: int | None = 1, context_lines: int | None = WINDOW
) -> None:
"""Opens the file at the given path in the editor. If line_number is provided, the window will be moved to include that line.
It only shows the first 100 lines by default! The max `context_lines` supported is 2000; use `scroll_down`/`scroll_up`
to view more of the file.
Args:
path: str: The path to the file to open, preferred absolute path.
line_number: int | None = 1: The line number to move to. Defaults to 1.
context_lines: int | None = 100: Only shows this number of lines in the context window (usually from line 1), with line_number as the center (if possible). Defaults to 100.
"""
global CURRENT_FILE, CURRENT_LINE, WINDOW
if not os.path.isfile(path):
raise FileNotFoundError(f'File {path} not found')
CURRENT_FILE = os.path.abspath(path)
with open(CURRENT_FILE) as file:
total_lines = max(1, sum(1 for _ in file))
if not isinstance(line_number, int) or line_number < 1 or line_number > total_lines:
raise ValueError(f'Line number must be between 1 and {total_lines}')
CURRENT_LINE = line_number
# Override WINDOW with context_lines
if context_lines is None or context_lines < 1:
context_lines = WINDOW
output = _cur_file_header(CURRENT_FILE, total_lines)
output += _print_window(
CURRENT_FILE, CURRENT_LINE, _clamp(context_lines, 1, 2000), return_str=True
)
print(output)
def goto_line(line_number: int) -> None:
"""Moves the window to show the specified line number.
Args:
line_number: int: The line number to move to.
"""
global CURRENT_FILE, CURRENT_LINE, WINDOW
_check_current_file()
with open(str(CURRENT_FILE)) as file:
total_lines = max(1, sum(1 for _ in file))
if not isinstance(line_number, int) or line_number < 1 or line_number > total_lines:
raise ValueError(f'Line number must be between 1 and {total_lines}')
CURRENT_LINE = _clamp(line_number, 1, total_lines)
output = _cur_file_header(CURRENT_FILE, total_lines)
output += _print_window(CURRENT_FILE, CURRENT_LINE, WINDOW, return_str=True)
print(output)
def scroll_down() -> None:
"""Moves the window down by 100 lines.
Args:
None
"""
global CURRENT_FILE, CURRENT_LINE, WINDOW
_check_current_file()
with open(str(CURRENT_FILE)) as file:
total_lines = max(1, sum(1 for _ in file))
CURRENT_LINE = _clamp(CURRENT_LINE + WINDOW, 1, total_lines)
output = _cur_file_header(CURRENT_FILE, total_lines)
output += _print_window(CURRENT_FILE, CURRENT_LINE, WINDOW, return_str=True)
print(output)
def scroll_up() -> None:
"""Moves the window up by 100 lines.
Args:
None
"""
global CURRENT_FILE, CURRENT_LINE, WINDOW
_check_current_file()
with open(str(CURRENT_FILE)) as file:
total_lines = max(1, sum(1 for _ in file))
CURRENT_LINE = _clamp(CURRENT_LINE - WINDOW, 1, total_lines)
output = _cur_file_header(CURRENT_FILE, total_lines)
output += _print_window(CURRENT_FILE, CURRENT_LINE, WINDOW, return_str=True)
print(output)
def create_file(filename: str) -> None:
"""Creates and opens a new file with the given name.
Args:
filename: str: The name of the file to create.
"""
if os.path.exists(filename):
raise FileExistsError(f"File '{filename}' already exists.")
with open(filename, 'w') as file:
file.write('\n')
open_file(filename)
print(f'[File {filename} created.]')
LINTER_ERROR_MSG = '[Your proposed edit has introduced new syntax error(s). Please understand the errors and retry your edit command.]\n'
class LineNumberError(Exception):
pass
def _append_impl(lines, content):
"""Internal method to handle appending to a file.
Args:
lines: list[str]: The lines in the original file.
content: str: The content to append to the file.
Returns:
content: str: The new content of the file.
n_added_lines: int: The number of lines added to the file.
"""
content_lines = content.splitlines(keepends=True)
n_added_lines = len(content_lines)
if lines and not (len(lines) == 1 and lines[0].strip() == ''):
# file is not empty
if not lines[-1].endswith('\n'):
lines[-1] += '\n'
new_lines = lines + content_lines
content = ''.join(new_lines)
else:
# file is empty
content = ''.join(content_lines)
return content, n_added_lines
def _insert_impl(lines, start, content):
"""Internal method to handle inserting to a file.
Args:
lines: list[str]: The lines in the original file.
start: int: The start line number for inserting.
content: str: The content to insert to the file.
Returns:
content: str: The new content of the file.
n_added_lines: int: The number of lines added to the file.
Raises:
LineNumberError: If the start line number is invalid.
"""
inserted_lines = [content + '\n' if not content.endswith('\n') else content]
if len(lines) == 0:
new_lines = inserted_lines
elif start is not None:
if len(lines) == 1 and lines[0].strip() == '':
# if the file with only 1 line and that line is empty
lines = []
if len(lines) == 0:
new_lines = inserted_lines
else:
new_lines = lines[: start - 1] + inserted_lines + lines[start - 1 :]
else:
raise LineNumberError(
f'Invalid line number: {start}. Line numbers must be between 1 and {len(lines)} (inclusive).'
)
content = ''.join(new_lines)
n_added_lines = len(inserted_lines)
return content, n_added_lines
def _edit_impl(lines, start, end, content):
"""Internal method to handle editing a file.
REQUIRES (should be checked by caller):
start <= end
start and end are between 1 and len(lines) (inclusive)
content ends with a newline
Args:
lines: list[str]: The lines in the original file.
start: int: The start line number for editing.
end: int: The end line number for editing.
content: str: The content to replace the lines with.
Returns:
content: str: The new content of the file.
n_added_lines: int: The number of lines added to the file.
"""
# Handle cases where start or end are None
if start is None:
start = 1 # Default to the beginning
if end is None:
end = len(lines) # Default to the end
# Check arguments
if not (1 <= start <= len(lines)):
raise LineNumberError(
f'Invalid start line number: {start}. Line numbers must be between 1 and {len(lines)} (inclusive).'
)
if not (1 <= end <= len(lines)):
raise LineNumberError(
f'Invalid end line number: {end}. Line numbers must be between 1 and {len(lines)} (inclusive).'
)
if start > end:
raise LineNumberError(
f'Invalid line range: {start}-{end}. Start must be less than or equal to end.'
)
if not content.endswith('\n'):
content += '\n'
content_lines = content.splitlines(True)
n_added_lines = len(content_lines)
new_lines = lines[: start - 1] + content_lines + lines[end:]
content = ''.join(new_lines)
return content, n_added_lines
def _edit_file_impl(
file_name: str,
start: int | None = None,
end: int | None = None,
content: str = '',
is_insert: bool = False,
is_append: bool = False,
) -> str:
"""Internal method to handle common logic for edit_/append_file methods.
Args:
file_name: str: The name of the file to edit or append to.
start: int | None = None: The start line number for editing. Ignored if is_append is True.
end: int | None = None: The end line number for editing. Ignored if is_append is True.
content: str: The content to replace the lines with or to append.
is_insert: bool = False: Whether to insert content at the given line number instead of editing.
is_append: bool = False: Whether to append content to the file instead of editing.
"""
ret_str = ''
global CURRENT_FILE, CURRENT_LINE, WINDOW
ERROR_MSG = f'[Error editing file {file_name}. Please confirm the file is correct.]'
ERROR_MSG_SUFFIX = (
'Your changes have NOT been applied. Please fix your edit command and try again.\n'
'You either need to 1) Open the correct file and try again or 2) Specify the correct line number arguments.\n'
'DO NOT re-run the same failed edit command. Running it again will lead to the same error.'
)
if not _is_valid_filename(file_name):
raise FileNotFoundError('Invalid file name.')
if not _is_valid_path(file_name):
raise FileNotFoundError('Invalid path or file name.')
if not _create_paths(file_name):
raise PermissionError('Could not access or create directories.')
if not os.path.isfile(file_name):
raise FileNotFoundError(f'File {file_name} not found.')
if is_insert and is_append:
raise ValueError('Cannot insert and append at the same time.')
# Use a temporary file to write changes
content = str(content or '')
temp_file_path = ''
src_abs_path = os.path.abspath(file_name)
first_error_line = None
try:
n_added_lines = None
# lint the original file
enable_auto_lint = os.getenv('ENABLE_AUTO_LINT', 'false').lower() == 'true'
if enable_auto_lint:
original_lint_error, _ = _lint_file(file_name)
# Create a temporary file
with tempfile.NamedTemporaryFile('w', delete=False) as temp_file:
temp_file_path = temp_file.name
# Read the original file and check if empty and for a trailing newline
with open(file_name) as original_file:
lines = original_file.readlines()
if is_append:
content, n_added_lines = _append_impl(lines, content)
elif is_insert:
try:
content, n_added_lines = _insert_impl(lines, start, content)
except LineNumberError as e:
ret_str += (f'{ERROR_MSG}\n' f'{e}\n' f'{ERROR_MSG_SUFFIX}') + '\n'
return ret_str
else:
try:
content, n_added_lines = _edit_impl(lines, start, end, content)
except LineNumberError as e:
ret_str += (f'{ERROR_MSG}\n' f'{e}\n' f'{ERROR_MSG_SUFFIX}') + '\n'
return ret_str
if not content.endswith('\n'):
content += '\n'
# Write the new content to the temporary file
temp_file.write(content)
# Replace the original file with the temporary file atomically
shutil.move(temp_file_path, src_abs_path)
# Handle linting
# NOTE: we need to get env var inside this function
# because the env var will be set AFTER the agentskills is imported
if enable_auto_lint:
# BACKUP the original file
original_file_backup_path = os.path.join(
os.path.dirname(file_name),
f'.backup.{os.path.basename(file_name)}',
)
with open(original_file_backup_path, 'w') as f:
f.writelines(lines)
lint_error, first_error_line = _lint_file(file_name)
# Select the errors caused by the modification
def extract_last_part(line):
parts = line.split(':')
if len(parts) > 1:
return parts[-1].strip()
return line.strip()
def subtract_strings(str1, str2) -> str:
lines1 = str1.splitlines()
lines2 = str2.splitlines()
last_parts1 = [extract_last_part(line) for line in lines1]
remaining_lines = [
line
for line in lines2
if extract_last_part(line) not in last_parts1
]
result = '\n'.join(remaining_lines)
return result
if original_lint_error and lint_error:
lint_error = subtract_strings(original_lint_error, lint_error)
if lint_error == '':
lint_error = None
first_error_line = None
if lint_error is not None:
if first_error_line is not None:
show_line = int(first_error_line)
elif is_append:
# original end-of-file
show_line = len(lines)
# insert OR edit WILL provide meaningful line numbers
elif start is not None and end is not None:
show_line = int((start + end) / 2)
else:
raise ValueError('Invalid state. This should never happen.')
ret_str += LINTER_ERROR_MSG
ret_str += lint_error + '\n'
editor_lines = n_added_lines + 20
ret_str += '[This is how your edit would have looked if applied]\n'
ret_str += '-------------------------------------------------\n'
ret_str += (
_print_window(file_name, show_line, editor_lines, return_str=True)
+ '\n'
)
ret_str += '-------------------------------------------------\n\n'
ret_str += '[This is the original code before your edit]\n'
ret_str += '-------------------------------------------------\n'
ret_str += (
_print_window(
original_file_backup_path,
show_line,
editor_lines,
return_str=True,
)
+ '\n'
)
ret_str += '-------------------------------------------------\n'
ret_str += (
'Your changes have NOT been applied. Please fix your edit command and try again.\n'
'You either need to 1) Specify the correct start/end line arguments or 2) Correct your edit code.\n'
'DO NOT re-run the same failed edit command. Running it again will lead to the same error.'
)
# recover the original file
with open(original_file_backup_path) as fin, open(
file_name, 'w'
) as fout:
fout.write(fin.read())
os.remove(original_file_backup_path)
return ret_str
except FileNotFoundError as e:
ret_str += f'File not found: {e}\n'
except IOError as e:
ret_str += f'An error occurred while handling the file: {e}\n'
except ValueError as e:
ret_str += f'Invalid input: {e}\n'
except Exception as e:
# Clean up the temporary file if an error occurs
if temp_file_path and os.path.exists(temp_file_path):
os.remove(temp_file_path)
print(f'An unexpected error occurred: {e}')
raise e
# Update the file information and print the updated content
with open(file_name, 'r', encoding='utf-8') as file:
n_total_lines = max(1, len(file.readlines()))
if first_error_line is not None and int(first_error_line) > 0:
CURRENT_LINE = first_error_line
else:
if is_append:
CURRENT_LINE = max(1, len(lines)) # end of original file
else:
CURRENT_LINE = start or n_total_lines or 1
ret_str += f'[File: {os.path.abspath(file_name)} ({n_total_lines} lines total after edit)]\n'
CURRENT_FILE = file_name
ret_str += _print_window(CURRENT_FILE, CURRENT_LINE, WINDOW, return_str=True) + '\n'
ret_str += MSG_FILE_UPDATED.format(line_number=CURRENT_LINE)
return ret_str
def edit_file_by_replace(file_name: str, to_replace: str, new_content: str) -> None:
"""Edit a file. This will search for `to_replace` in the given file and replace it with `new_content`.
Every *to_replace* must *EXACTLY MATCH* the existing source code, character for character, including all comments, docstrings, etc.
Include enough lines to make code in `to_replace` unique. `to_replace` should NOT be empty.
For example, given a file "/workspace/example.txt" with the following content:
```
line 1
line 2
line 2
line 3
```
EDITING: If you want to replace the second occurrence of "line 2", you can make `to_replace` unique:
edit_file_by_replace(
'/workspace/example.txt',
to_replace='line 2\nline 3',
new_content='new line\nline 3',
)
This will replace only the second "line 2" with "new line". The first "line 2" will remain unchanged.
The resulting file will be:
```
line 1
line 2
new line
line 3
```
REMOVAL: If you want to remove "line 2" and "line 3", you can set `new_content` to an empty string:
edit_file_by_replace(
'/workspace/example.txt',
to_replace='line 2\nline 3',
new_content='',
)
Args:
file_name: str: The name of the file to edit.
to_replace: str: The content to search for and replace.
new_content: str: The new content to replace the old content with.
"""
# FIXME: support replacing *all* occurrences
if to_replace.strip() == '':
raise ValueError('`to_replace` must not be empty.')
if to_replace == new_content:
raise ValueError('`to_replace` and `new_content` must be different.')
# search for `to_replace` in the file
# if found, replace it with `new_content`
# if not found, perform a fuzzy search to find the closest match and replace it with `new_content`
with open(file_name, 'r') as file:
file_content = file.read()
if file_content.count(to_replace) > 1:
raise ValueError(
'`to_replace` appears more than once, please include enough lines to make code in `to_replace` unique.'
)
start = file_content.find(to_replace)
if start != -1:
# Convert start from index to line number
start_line_number = file_content[:start].count('\n') + 1
end_line_number = start_line_number + len(to_replace.splitlines()) - 1
else:
def _fuzzy_transform(s: str) -> str:
# remove all space except newline
return re.sub(r'[^\S\n]+', '', s)
# perform a fuzzy search (remove all spaces except newlines)
to_replace_fuzzy = _fuzzy_transform(to_replace)
file_content_fuzzy = _fuzzy_transform(file_content)
# find the closest match
start = file_content_fuzzy.find(to_replace_fuzzy)
if start == -1:
print(
f'[No exact match found in {file_name} for\n```\n{to_replace}\n```\n]'
)
return
# Convert start from index to line number for fuzzy match
start_line_number = file_content_fuzzy[:start].count('\n') + 1
end_line_number = start_line_number + len(to_replace.splitlines()) - 1
ret_str = _edit_file_impl(
file_name,
start=start_line_number,
end=end_line_number,
content=new_content,
is_insert=False,
)
# lint_error = bool(LINTER_ERROR_MSG in ret_str)
# TODO: automatically tries to fix linter error (maybe involve some static analysis tools on the location near the edit to figure out indentation)
print(ret_str)
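The whitespace-insensitive fallback used above can be demonstrated standalone: strip everything except newlines from both strings, then map the match offset back to a line number. The helper name is illustrative, not part of the module:

```python
import re

def fuzzy_find_line(file_content: str, needle: str) -> int:
    """1-indexed line where needle starts, ignoring spaces/tabs; -1 if absent."""
    def _fuzzy(s: str) -> str:
        return re.sub(r'[^\S\n]+', '', s)  # drop all whitespace except '\n'
    haystack = _fuzzy(file_content)
    idx = haystack.find(_fuzzy(needle))
    if idx == -1:
        return -1
    return haystack[:idx].count('\n') + 1

print(fuzzy_find_line('a = 1\nb  =  2\n', 'b = 2'))  # prints 2
```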
def insert_content_at_line(file_name: str, line_number: int, content: str) -> None:
"""Insert content at the given line number in a file.
This will NOT modify the content of the lines before OR after the given line number.
For example, if the file has the following content:
```
line 1
line 2
line 3
```
and you call `insert_content_at_line('file.txt', 2, 'new line')`, the file will be updated to:
```
line 1
new line
line 2
line 3
```
Args:
file_name: str: The name of the file to edit.
line_number: int: The line number (starting from 1) to insert the content after.
content: str: The content to insert.
"""
ret_str = _edit_file_impl(
file_name,
start=line_number,
end=line_number,
content=content,
is_insert=True,
is_append=False,
)
print(ret_str)
def append_file(file_name: str, content: str) -> None:
"""Append content to the given file.
It appends text `content` to the end of the specified file.
Args:
file_name: str: The name of the file to edit.
line_number: int: The line number (starting from 1) to insert the content after.
content: str: The content to insert.
"""
ret_str = _edit_file_impl(
file_name,
start=None,
end=None,
content=content,
is_insert=False,
is_append=True,
)
print(ret_str)
def search_dir(search_term: str, dir_path: str = './') -> None:
"""Searches for search_term in all files in dir. If dir is not provided, searches in the current directory.
Args:
search_term: str: The term to search for.
dir_path: str: The path to the directory to search.
"""
if not os.path.isdir(dir_path):
raise FileNotFoundError(f'Directory {dir_path} not found')
matches = []
for root, _, files in os.walk(dir_path):
for file in files:
if file.startswith('.'):
continue
file_path = os.path.join(root, file)
with open(file_path, 'r', errors='ignore') as f:
for line_num, line in enumerate(f, 1):
if search_term in line:
matches.append((file_path, line_num, line.strip()))
if not matches:
print(f'No matches found for "{search_term}" in {dir_path}')
return
num_matches = len(matches)
num_files = len(set(match[0] for match in matches))
if num_files > 100:
print(
f'Matches found in more than 100 files for "{search_term}" in {dir_path}. Please narrow your search.'
)
return
print(f'[Found {num_matches} matches for "{search_term}" in {dir_path}]')
for file_path, line_num, line in matches:
print(f'{file_path} (Line {line_num}): {line}')
print(f'[End of matches for "{search_term}" in {dir_path}]')
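`search_dir` is essentially a recursive grep over `os.walk`. A self-contained sketch of the same collection step (the names here are illustrative):

```python
import os
import tempfile

def grep_dir(search_term: str, dir_path: str) -> list:
    """Return (path, line_number, stripped_line) for every match."""
    matches = []
    for root, _, files in os.walk(dir_path):
        for name in files:
            if name.startswith('.'):  # skip hidden files, as search_dir does
                continue
            path = os.path.join(root, name)
            with open(path, errors='ignore') as f:
                for line_num, line in enumerate(f, 1):
                    if search_term in line:
                        matches.append((path, line_num, line.strip()))
    return matches

with tempfile.TemporaryDirectory() as d:
    with open(os.path.join(d, 'a.txt'), 'w') as f:
        f.write('hello\nworld\n')
    hits = grep_dir('world', d)
```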
def search_file(search_term: str, file_path: str | None = None) -> None:
"""Searches for search_term in file. If file is not provided, searches in the current open file.
Args:
search_term: str: The term to search for.
file_path: str | None: The path to the file to search.
"""
global CURRENT_FILE
if file_path is None:
file_path = CURRENT_FILE
if file_path is None:
raise FileNotFoundError(
'No file specified or open. Use the open_file function first.'
)
if not os.path.isfile(file_path):
raise FileNotFoundError(f'File {file_path} not found')
matches = []
with open(file_path) as file:
for i, line in enumerate(file, 1):
if search_term in line:
matches.append((i, line.strip()))
if matches:
print(f'[Found {len(matches)} matches for "{search_term}" in {file_path}]')
for match in matches:
print(f'Line {match[0]}: {match[1]}')
print(f'[End of matches for "{search_term}" in {file_path}]')
else:
print(f'[No matches found for "{search_term}" in {file_path}]')
def find_file(file_name: str, dir_path: str = './') -> None:
"""Finds all files with the given name in the specified directory.
Args:
file_name: str: The name of the file to find.
dir_path: str: The path to the directory to search.
"""
if not os.path.isdir(dir_path):
raise FileNotFoundError(f'Directory {dir_path} not found')
matches = []
for root, _, files in os.walk(dir_path):
for file in files:
if file_name in file:
matches.append(os.path.join(root, file))
if matches:
print(f'[Found {len(matches)} matches for "{file_name}" in {dir_path}]')
for match in matches:
print(f'{match}')
print(f'[End of matches for "{file_name}" in {dir_path}]')
else:
print(f'[No matches found for "{file_name}" in {dir_path}]')
__all__ = [
'open_file',
'goto_line',
'scroll_down',
'scroll_up',
'create_file',
'edit_file_by_replace',
'insert_content_at_line',
'append_file',
'search_dir',
'search_file',
'find_file',
]


@@ -0,0 +1,7 @@
from ..utils.dependency import import_functions
from . import file_readers
import_functions(
module=file_readers, function_names=file_readers.__all__, target_globals=globals()
)
__all__ = file_readers.__all__


@@ -0,0 +1,244 @@
"""File reader skills for the OpenHands agent.
This module provides various functions to parse and extract content from different file types,
including PDF, DOCX, LaTeX, audio, image, video, and PowerPoint files. It utilizes different
libraries and APIs to process these files and output their content or descriptions.
Functions:
parse_pdf(file_path: str) -> None: Parse and print content of a PDF file.
parse_docx(file_path: str) -> None: Parse and print content of a DOCX file.
parse_latex(file_path: str) -> None: Parse and print content of a LaTeX file.
parse_audio(file_path: str, model: str = 'whisper-1') -> None: Transcribe and print content of an audio file.
parse_image(file_path: str, task: str = 'Describe this image in as much detail as possible.') -> None: Analyze and print description of an image file.
parse_video(file_path: str, task: str = 'Describe this image in as much detail as possible.', frame_interval: int = 30) -> None: Analyze and print description of video frames.
parse_pptx(file_path: str) -> None: Parse and print content of a PowerPoint file.
Note:
Some functions (parse_audio, parse_video, parse_image) require OpenAI API credentials
and are only available if the necessary environment variables are set.
"""
import base64
import docx
import PyPDF2
from pptx import Presentation
from pylatexenc.latex2text import LatexNodes2Text
from ..utils.config import (
_get_max_token,
_get_openai_api_key,
_get_openai_base_url,
_get_openai_client,
_get_openai_model,
)
def parse_pdf(file_path: str) -> None:
"""Parses the content of a PDF file and prints it.
Args:
file_path: str: The path to the file to open.
"""
print(f'[Reading PDF file from {file_path}]')
content = PyPDF2.PdfReader(file_path)
text = ''
for page_idx in range(len(content.pages)):
text += (
f'@@ Page {page_idx + 1} @@\n'
+ content.pages[page_idx].extract_text()
+ '\n\n'
)
print(text.strip())
def parse_docx(file_path: str) -> None:
"""Parses the content of a DOCX file and prints it.
Args:
file_path: str: The path to the file to open.
"""
print(f'[Reading DOCX file from {file_path}]')
content = docx.Document(file_path)
text = ''
for i, para in enumerate(content.paragraphs):
text += f'@@ Paragraph {i + 1} @@\n' + para.text + '\n\n'
print(text)
def parse_latex(file_path: str) -> None:
"""Parses the content of a LaTex file and prints it.
Args:
file_path: str: The path to the file to open.
"""
print(f'[Reading LaTeX file from {file_path}]')
with open(file_path) as f:
data = f.read()
text = LatexNodes2Text().latex_to_text(data)
print(text.strip())
def _base64_img(file_path: str) -> str:
with open(file_path, 'rb') as image_file:
encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
return encoded_image
def _base64_video(file_path: str, frame_interval: int = 10) -> list[str]:
import cv2
video = cv2.VideoCapture(file_path)
base64_frames = []
frame_count = 0
while video.isOpened():
success, frame = video.read()
if not success:
break
if frame_count % frame_interval == 0:
_, buffer = cv2.imencode('.jpg', frame)
base64_frames.append(base64.b64encode(buffer).decode('utf-8'))
frame_count += 1
video.release()
return base64_frames
def _prepare_image_messages(task: str, base64_image: str):
return [
{
'role': 'user',
'content': [
{'type': 'text', 'text': task},
{
'type': 'image_url',
'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'},
},
],
}
]
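The payload built by `_prepare_image_messages` follows the OpenAI vision-chat shape: the image rides along inside the user message as a base64 data URL. A minimal standalone sketch:

```python
import base64

def image_message(task: str, raw_bytes: bytes) -> list:
    """Build a single-user-message payload with an inlined JPEG data URL."""
    b64 = base64.b64encode(raw_bytes).decode('utf-8')
    return [
        {
            'role': 'user',
            'content': [
                {'type': 'text', 'text': task},
                {
                    'type': 'image_url',
                    'image_url': {'url': f'data:image/jpeg;base64,{b64}'},
                },
            ],
        }
    ]

msg = image_message('Describe this image.', b'\xff\xd8\xff')
```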
def parse_audio(file_path: str, model: str = 'whisper-1') -> None:
"""Parses the content of an audio file and prints it.
Args:
file_path: str: The path to the audio file to transcribe.
model: str: The audio model to use for transcription. Defaults to 'whisper-1'.
"""
print(f'[Transcribing audio file from {file_path}]')
try:
# TODO: record the COST of the API call
with open(file_path, 'rb') as audio_file:
transcript = _get_openai_client().audio.translations.create(
model=model, file=audio_file
)
print(transcript.text)
except Exception as e:
print(f'Error transcribing audio file: {e}')
def parse_image(
file_path: str, task: str = 'Describe this image in as much detail as possible.'
) -> None:
"""Parses the content of an image file and prints the description.
Args:
file_path: str: The path to the file to open.
task: str: The task description for the API call. Defaults to 'Describe this image in as much detail as possible.'.
"""
print(f'[Reading image file from {file_path}]')
# TODO: record the COST of the API call
try:
base64_image = _base64_img(file_path)
response = _get_openai_client().chat.completions.create(
model=_get_openai_model(),
messages=_prepare_image_messages(task, base64_image),
max_tokens=_get_max_token(),
)
content = response.choices[0].message.content
print(content)
except Exception as error:
print(f'Error with the request: {error}')
def parse_video(
file_path: str,
task: str = 'Describe this image in as much detail as possible.',
frame_interval: int = 30,
) -> None:
"""Parses the content of an image file and prints the description.
Args:
file_path: str: The path to the video file to open.
task: str: The task description for the API call. Defaults to 'Describe this image in as much detail as possible.'.
frame_interval: int: The interval between frames to analyze. Defaults to 30.
"""
print(
f'[Processing video file from {file_path} with frame interval {frame_interval}]'
)
task = task or 'This is one frame from a video, please summarize this frame.'
base64_frames = _base64_video(file_path)
selected_frames = base64_frames[::frame_interval]
if len(selected_frames) > 30:
new_interval = len(base64_frames) // 30
selected_frames = base64_frames[::new_interval]
print(f'A total of {len(selected_frames)} frames will be analyzed...\n')
idx = 0
for base64_frame in selected_frames:
idx += 1
print(f'Processing {file_path}, frame No. {idx * frame_interval}...')
# TODO: record the COST of the API call
try:
response = _get_openai_client().chat.completions.create(
model=_get_openai_model(),
messages=_prepare_image_messages(task, base64_frame),
max_tokens=_get_max_token(),
)
content = response.choices[0].message.content
current_frame_content = f"Frame {idx}'s content: {content}\n"
print(current_frame_content)
except Exception as error:
print(f'Error with the request: {error}')
def parse_pptx(file_path: str) -> None:
"""Parses the content of a pptx file and prints it.
Args:
file_path: str: The path to the file to open.
"""
print(f'[Reading PowerPoint file from {file_path}]')
try:
pres = Presentation(str(file_path))
text = []
for slide_idx, slide in enumerate(pres.slides):
text.append(f'@@ Slide {slide_idx + 1} @@')
for shape in slide.shapes:
if hasattr(shape, 'text'):
text.append(shape.text)
print('\n'.join(text))
except Exception as e:
print(f'Error reading PowerPoint file: {e}')
__all__ = [
'parse_pdf',
'parse_docx',
'parse_latex',
'parse_pptx',
]
# This is called from OpenHands's side
# If SANDBOX_ENV_OPENAI_API_KEY is set, we will be able to use these tools in the sandbox environment
if _get_openai_api_key() and _get_openai_base_url():
__all__ += ['parse_audio', 'parse_video', 'parse_image']


@@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


@@ -0,0 +1,8 @@
# Aider is AI pair programming in your terminal
Aider lets you pair program with LLMs,
to edit code in your local git repository.
Please see the [original repository](https://github.com/paul-gauthier/aider) for more information.
OpenHands has adapted and integrated its linter module ([original code](https://github.com/paul-gauthier/aider/blob/main/aider/linter.py)).


@@ -0,0 +1,6 @@
if __package__ is None or __package__ == '':
from linter import Linter, LintResult
else:
from .linter import Linter, LintResult
__all__ = ['Linter', 'LintResult']


@@ -0,0 +1,223 @@
import os
import subprocess
import sys
import traceback
import warnings
from dataclasses import dataclass
from pathlib import Path
from grep_ast import TreeContext, filename_to_lang
from tree_sitter_languages import get_parser # noqa: E402
# tree_sitter is throwing a FutureWarning
warnings.simplefilter('ignore', category=FutureWarning)
@dataclass
class LintResult:
text: str
lines: list
class Linter:
def __init__(self, encoding='utf-8', root=None):
self.encoding = encoding
self.root = root
self.languages = dict(
python=self.py_lint,
)
self.all_lint_cmd = None
def set_linter(self, lang, cmd):
if lang:
self.languages[lang] = cmd
return
self.all_lint_cmd = cmd
def get_rel_fname(self, fname):
if self.root:
return os.path.relpath(fname, self.root)
else:
return fname
def run_cmd(self, cmd, rel_fname, code):
cmd += ' ' + rel_fname
cmd = cmd.split()
process = subprocess.Popen(
cmd, cwd=self.root, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
)
stdout, _ = process.communicate()
errors = stdout.decode().strip()
self.returncode = process.returncode
if self.returncode == 0:
return # zero exit status
cmd = ' '.join(cmd)
res = ''
res += errors
line_num = extract_error_line_from(res)
return LintResult(text=res, lines=[line_num])
def get_abs_fname(self, fname):
if os.path.isabs(fname):
return fname
elif os.path.isfile(fname):
rel_fname = self.get_rel_fname(fname)
return os.path.abspath(rel_fname)
else: # if a temp file
return self.get_rel_fname(fname)
def lint(self, fname, cmd=None) -> LintResult | None:
code = Path(fname).read_text(self.encoding)
absolute_fname = self.get_abs_fname(fname)
if cmd:
cmd = cmd.strip()
if not cmd:
lang = filename_to_lang(fname)
if not lang:
return None
if self.all_lint_cmd:
cmd = self.all_lint_cmd
else:
cmd = self.languages.get(lang)
if callable(cmd):
linkres = cmd(fname, absolute_fname, code)
elif cmd:
linkres = self.run_cmd(cmd, absolute_fname, code)
else:
linkres = basic_lint(absolute_fname, code)
return linkres
def flake_lint(self, rel_fname, code):
fatal = 'F821,F822,F831,E112,E113,E999,E902'
flake8 = f'flake8 --select={fatal} --isolated'
try:
flake_res = self.run_cmd(flake8, rel_fname, code)
except FileNotFoundError:
flake_res = None
return flake_res
def py_lint(self, fname, rel_fname, code):
error = self.flake_lint(rel_fname, code)
if not error:
error = lint_python_compile(fname, code)
if not error:
error = basic_lint(rel_fname, code)
return error
def lint_python_compile(fname, code):
try:
compile(code, fname, 'exec') # USE TRACEBACK BELOW HERE
return
except IndentationError as err:
end_lineno = getattr(err, 'end_lineno', err.lineno)
if isinstance(end_lineno, int):
line_numbers = list(range(end_lineno - 1, end_lineno))
else:
line_numbers = []
tb_lines = traceback.format_exception(type(err), err, err.__traceback__)
last_file_i = 0
target = '# USE TRACEBACK'
target += ' BELOW HERE'
for i in range(len(tb_lines)):
if target in tb_lines[i]:
last_file_i = i
break
tb_lines = tb_lines[:1] + tb_lines[last_file_i + 1 :]
res = ''.join(tb_lines)
return LintResult(text=res, lines=line_numbers)
def basic_lint(fname, code):
"""
Use tree-sitter to look for syntax errors, display them with tree context.
"""
lang = filename_to_lang(fname)
if not lang:
return
parser = get_parser(lang)
tree = parser.parse(bytes(code, 'utf-8'))
errors = traverse_tree(tree.root_node)
if not errors:
return
return LintResult(text=f'{fname}:{errors[0]}', lines=errors)
def extract_error_line_from(lint_error):
# moved from openhands.agentskills#_lint_file
    first_error_line = None  # avoid UnboundLocalError when no parsable line number is found
    for line in lint_error.splitlines(True):
if line.strip():
# The format of the error message is: <filename>:<line>:<column>: <error code> <error message>
parts = line.split(':')
if len(parts) >= 2:
try:
first_error_line = int(parts[1])
break
except ValueError:
continue
return first_error_line
def tree_context(fname, code, line_nums):
context = TreeContext(
fname,
code,
color=False,
line_number=True,
child_context=False,
last_line=False,
margin=0,
mark_lois=True,
loi_pad=3,
# header_max=30,
show_top_of_file_parent_scope=False,
)
line_nums = set(line_nums)
context.add_lines_of_interest(line_nums)
context.add_context()
output = context.format()
return output
# Traverse the tree to find errors
def traverse_tree(node):
errors = []
if node.type == 'ERROR' or node.is_missing:
line_no = node.start_point[0] + 1
errors.append(line_no)
for child in node.children:
errors += traverse_tree(child)
return errors
def main():
"""
Main function to parse files provided as command line arguments.
"""
if len(sys.argv) < 2:
print('Usage: python linter.py <file1> <file2> ...')
sys.exit(1)
linter = Linter(root=os.getcwd())
for file_path in sys.argv[1:]:
errors = linter.lint(file_path)
if errors:
print(errors)
if __name__ == '__main__':
main()
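The flake8-style error parsing in `extract_error_line_from` can be exercised in isolation. A minimal standalone sketch (the sample output string below is illustrative, not from a real flake8 run):

```python
def extract_error_line_from(lint_error: str):
    """Return the line number of the first reported error, or None."""
    first_error_line = None
    for line in lint_error.splitlines(True):
        if line.strip():
            # Expected format: <filename>:<line>:<column>: <error code> <error message>
            parts = line.split(':')
            if len(parts) >= 2:
                try:
                    first_error_line = int(parts[1])
                    break
                except ValueError:
                    continue
    return first_error_line

# Illustrative flake8-style output line
sample = "app.py:12:5: F821 undefined name 'foo'\n"
print(extract_error_line_from(sample))  # 12
```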


@@ -0,0 +1,30 @@
import os
from openai import OpenAI
# ==================================================================================================
# OPENAI
# TODO: Move this to EventStream Actions when EventStreamRuntime is fully implemented
# NOTE: we need to get env vars inside functions because they will be set in IPython
# AFTER the agentskills is imported (the case for EventStreamRuntime)
# ==================================================================================================
def _get_openai_api_key():
return os.getenv('OPENAI_API_KEY', os.getenv('SANDBOX_ENV_OPENAI_API_KEY', ''))
def _get_openai_base_url():
return os.getenv('OPENAI_BASE_URL', 'https://api.openai.com/v1')
def _get_openai_model():
return os.getenv('OPENAI_MODEL', 'gpt-4o-2024-05-13')
def _get_max_token():
    return int(os.getenv('MAX_TOKEN', '500'))  # always return an int, even when set via env
def _get_openai_client():
client = OpenAI(api_key=_get_openai_api_key(), base_url=_get_openai_base_url())
return client
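The `SANDBOX_ENV_`-prefixed fallback can be checked without any API call. A small standalone sketch of the lookup order (the variable names match the module above; the values are made up):

```python
import os

def get_openai_api_key() -> str:
    # Prefer OPENAI_API_KEY, fall back to the sandbox-injected variant
    return os.getenv('OPENAI_API_KEY', os.getenv('SANDBOX_ENV_OPENAI_API_KEY', ''))

os.environ.pop('OPENAI_API_KEY', None)
os.environ['SANDBOX_ENV_OPENAI_API_KEY'] = 'sk-sandbox'
print(get_openai_api_key())  # sk-sandbox
```

Note that `os.getenv`'s default argument is evaluated eagerly, which is exactly what makes this one-liner fallback work.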

View File

@@ -0,0 +1,11 @@
from types import ModuleType
def import_functions(
module: ModuleType, function_names: list[str], target_globals: dict
) -> None:
for name in function_names:
if hasattr(module, name):
target_globals[name] = getattr(module, name)
else:
raise ValueError(f'Function {name} not found in {module.__name__}')
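A quick usage sketch of `import_functions` with a throwaway module (the module and function names here are invented for illustration):

```python
from types import ModuleType

def import_functions(module, function_names, target_globals) -> None:
    for name in function_names:
        if hasattr(module, name):
            target_globals[name] = getattr(module, name)
        else:
            raise ValueError(f'Function {name} not found in {module.__name__}')

# Build a synthetic module carrying one function
fake = ModuleType('fake_skills')
fake.greet = lambda name: f'hello {name}'

scope: dict = {}
import_functions(fake, ['greet'], scope)
print(scope['greet']('world'))  # hello world
```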


@@ -0,0 +1,76 @@
import subprocess
import time
from dataclasses import dataclass
from openhands.core.logger import openhands_logger as logger
from openhands.events.action import Action, IPythonRunCellAction
from openhands.events.observation import IPythonRunCellObservation
from openhands.runtime.plugins.requirement import Plugin, PluginRequirement
from openhands.runtime.utils import find_available_tcp_port
from .execute_server import JupyterKernel
@dataclass
class JupyterRequirement(PluginRequirement):
name: str = 'jupyter'
class JupyterPlugin(Plugin):
name: str = 'jupyter'
async def initialize(self, username: str, kernel_id: str = 'openhands-default'):
self.kernel_gateway_port = find_available_tcp_port()
self.kernel_id = kernel_id
self.gateway_process = subprocess.Popen(
(
f"su - {username} -s /bin/bash << 'EOF'\n"
'cd /openhands/code\n'
'export POETRY_VIRTUALENVS_PATH=/openhands/poetry;\n'
'export PYTHONPATH=/openhands/code:$PYTHONPATH;\n'
'/openhands/miniforge3/bin/mamba run -n base '
'poetry run jupyter kernelgateway '
'--KernelGatewayApp.ip=0.0.0.0 '
f'--KernelGatewayApp.port={self.kernel_gateway_port}\n'
'EOF'
),
stderr=subprocess.STDOUT,
shell=True,
)
# read stdout until the kernel gateway is ready
output = ''
        while self.gateway_process.stdout is not None:
line = self.gateway_process.stdout.readline().decode('utf-8')
output += line
if 'at' in line:
break
time.sleep(1)
logger.debug('Waiting for jupyter kernel gateway to start...')
logger.info(
f'Jupyter kernel gateway started at port {self.kernel_gateway_port}. Output: {output}'
)
async def _run(self, action: Action) -> IPythonRunCellObservation:
"""Internal method to run a code cell in the jupyter kernel."""
if not isinstance(action, IPythonRunCellAction):
raise ValueError(
f'Jupyter plugin only supports IPythonRunCellAction, but got {action}'
)
if not hasattr(self, 'kernel'):
self.kernel = JupyterKernel(
f'localhost:{self.kernel_gateway_port}', self.kernel_id
)
if not self.kernel.initialized:
await self.kernel.initialize()
output = await self.kernel.execute(action.code, timeout=action.timeout)
return IPythonRunCellObservation(
content=output,
code=action.code,
)
async def run(self, action: Action) -> IPythonRunCellObservation:
obs = await self._run(action)
return obs
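`find_available_tcp_port` is imported from `openhands.runtime.utils`. A typical implementation (an assumption about its behavior, not necessarily the project's actual code) binds to port 0 so the OS picks a free ephemeral port:

```python
import socket

def find_available_tcp_port() -> int:
    """Ask the OS for a currently free TCP port by binding to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('127.0.0.1', 0))
        return s.getsockname()[1]

print(find_available_tcp_port())
```

Note the usual caveat: the port is only known to be free at the moment of the call, so a race is possible before the kernel gateway actually binds it.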


@@ -0,0 +1,285 @@
#!/usr/bin/env python3
import asyncio
import logging
import os
import re
from uuid import uuid4
import tornado
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
from tornado.escape import json_decode, json_encode, url_escape
from tornado.httpclient import AsyncHTTPClient, HTTPRequest
from tornado.ioloop import PeriodicCallback
from tornado.websocket import websocket_connect
logging.basicConfig(level=logging.INFO)
def strip_ansi(o: str) -> str:
"""Removes ANSI escape sequences from `o`, as defined by ECMA-048 in
http://www.ecma-international.org/publications/files/ECMA-ST/Ecma-048.pdf
# https://github.com/ewen-lbh/python-strip-ansi/blob/master/strip_ansi/__init__.py
>>> strip_ansi("\\033[33mLorem ipsum\\033[0m")
'Lorem ipsum'
>>> strip_ansi("Lorem \\033[38;25mIpsum\\033[0m sit\\namet.")
'Lorem Ipsum sit\\namet.'
>>> strip_ansi("")
''
>>> strip_ansi("\\x1b[0m")
''
>>> strip_ansi("Lorem")
'Lorem'
>>> strip_ansi('\\x1b[38;5;32mLorem ipsum\\x1b[0m')
'Lorem ipsum'
>>> strip_ansi('\\x1b[1m\\x1b[46m\\x1b[31mLorem dolor sit ipsum\\x1b[0m')
'Lorem dolor sit ipsum'
"""
# pattern = re.compile(r'/(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]/')
pattern = re.compile(r'\x1B\[\d+(;\d+){0,2}m')
stripped = pattern.sub('', o)
return stripped
class JupyterKernel:
def __init__(self, url_suffix, convid, lang='python'):
self.base_url = f'http://{url_suffix}'
self.base_ws_url = f'ws://{url_suffix}'
self.lang = lang
self.kernel_id = None
self.ws = None
self.convid = convid
logging.info(
f'Jupyter kernel created for conversation {convid} at {url_suffix}'
)
self.heartbeat_interval = 10000 # 10 seconds
self.heartbeat_callback = None
self.initialized = False
async def initialize(self):
await self.execute(r'%colors nocolor')
# pre-defined tools
self.tools_to_run: list[str] = [
# TODO: You can add code for your pre-defined tools here
]
for tool in self.tools_to_run:
res = await self.execute(tool)
logging.info(f'Tool [{tool}] initialized:\n{res}')
self.initialized = True
async def _send_heartbeat(self):
if not self.ws:
return
try:
self.ws.ping()
# logging.info('Heartbeat sent...')
except tornado.iostream.StreamClosedError:
# logging.info('Heartbeat failed, reconnecting...')
try:
await self._connect()
except ConnectionRefusedError:
logging.info(
'ConnectionRefusedError: Failed to reconnect to kernel websocket - Is the kernel still running?'
)
async def _connect(self):
if self.ws:
self.ws.close()
self.ws = None
client = AsyncHTTPClient()
if not self.kernel_id:
n_tries = 5
while n_tries > 0:
try:
response = await client.fetch(
'{}/api/kernels'.format(self.base_url),
method='POST',
body=json_encode({'name': self.lang}),
)
kernel = json_decode(response.body)
self.kernel_id = kernel['id']
break
except Exception:
# kernels are not ready yet
n_tries -= 1
await asyncio.sleep(1)
if n_tries == 0:
raise ConnectionRefusedError('Failed to connect to kernel')
ws_req = HTTPRequest(
url='{}/api/kernels/{}/channels'.format(
self.base_ws_url, url_escape(self.kernel_id)
)
)
self.ws = await websocket_connect(ws_req)
logging.info('Connected to kernel websocket')
# Setup heartbeat
if self.heartbeat_callback:
self.heartbeat_callback.stop()
self.heartbeat_callback = PeriodicCallback(
self._send_heartbeat, self.heartbeat_interval
)
self.heartbeat_callback.start()
@retry(
retry=retry_if_exception_type(ConnectionRefusedError),
stop=stop_after_attempt(3),
wait=wait_fixed(2),
)
async def execute(self, code, timeout=120):
if not self.ws:
await self._connect()
msg_id = uuid4().hex
assert self.ws is not None
res = await self.ws.write_message(
json_encode(
{
'header': {
'username': '',
'version': '5.0',
'session': '',
'msg_id': msg_id,
'msg_type': 'execute_request',
},
'parent_header': {},
'channel': 'shell',
'content': {
'code': code,
'silent': False,
'store_history': False,
'user_expressions': {},
'allow_stdin': False,
},
'metadata': {},
'buffers': {},
}
)
)
logging.info(f'Executed code in jupyter kernel:\n{res}')
outputs = []
async def wait_for_messages():
execution_done = False
while not execution_done:
assert self.ws is not None
msg = await self.ws.read_message()
msg = json_decode(msg)
msg_type = msg['msg_type']
parent_msg_id = msg['parent_header'].get('msg_id', None)
if parent_msg_id != msg_id:
continue
if os.environ.get('DEBUG'):
logging.info(
f"MSG TYPE: {msg_type.upper()} DONE:{execution_done}\nCONTENT: {msg['content']}"
)
if msg_type == 'error':
traceback = '\n'.join(msg['content']['traceback'])
outputs.append(traceback)
execution_done = True
elif msg_type == 'stream':
outputs.append(msg['content']['text'])
elif msg_type in ['execute_result', 'display_data']:
outputs.append(msg['content']['data']['text/plain'])
if 'image/png' in msg['content']['data']:
                        # use markdown to display the image (in case of a large image)
outputs.append(
f"\n![image](data:image/png;base64,{msg['content']['data']['image/png']})\n"
)
elif msg_type == 'execute_reply':
execution_done = True
return execution_done
async def interrupt_kernel():
client = AsyncHTTPClient()
interrupt_response = await client.fetch(
f'{self.base_url}/api/kernels/{self.kernel_id}/interrupt',
method='POST',
body=json_encode({'kernel_id': self.kernel_id}),
)
logging.info(f'Kernel interrupted: {interrupt_response}')
try:
execution_done = await asyncio.wait_for(wait_for_messages(), timeout)
except asyncio.TimeoutError:
await interrupt_kernel()
return f'[Execution timed out ({timeout} seconds).]'
if not outputs and execution_done:
ret = '[Code executed successfully with no output]'
else:
ret = ''.join(outputs)
# Remove ANSI
ret = strip_ansi(ret)
if os.environ.get('DEBUG'):
logging.info(f'OUTPUT:\n{ret}')
return ret
async def shutdown_async(self):
if self.kernel_id:
client = AsyncHTTPClient()
await client.fetch(
'{}/api/kernels/{}'.format(self.base_url, self.kernel_id),
method='DELETE',
)
self.kernel_id = None
if self.ws:
self.ws.close()
self.ws = None
class ExecuteHandler(tornado.web.RequestHandler):
def initialize(self, jupyter_kernel):
self.jupyter_kernel = jupyter_kernel
async def post(self):
data = json_decode(self.request.body)
code = data.get('code')
if not code:
self.set_status(400)
self.write('Missing code')
return
output = await self.jupyter_kernel.execute(code)
self.write(output)
def make_app():
jupyter_kernel = JupyterKernel(
f"localhost:{os.environ.get('JUPYTER_GATEWAY_PORT')}",
os.environ.get('JUPYTER_GATEWAY_KERNEL_ID'),
)
asyncio.get_event_loop().run_until_complete(jupyter_kernel.initialize())
return tornado.web.Application(
[
(r'/execute', ExecuteHandler, {'jupyter_kernel': jupyter_kernel}),
]
)
if __name__ == '__main__':
app = make_app()
app.listen(os.environ.get('JUPYTER_EXEC_SERVER_PORT'))
tornado.ioloop.IOLoop.current().start()


@@ -0,0 +1,31 @@
from abc import abstractmethod
from dataclasses import dataclass
from openhands.events.action import Action
from openhands.events.observation import Observation
class Plugin:
"""Base class for a plugin.
This will be initialized by the runtime client, which will run inside docker.
"""
name: str
@abstractmethod
async def initialize(self, username: str):
"""Initialize the plugin."""
pass
@abstractmethod
async def run(self, action: Action) -> Observation:
"""Run the plugin for a given action."""
pass
@dataclass
class PluginRequirement:
"""Requirement for a plugin."""
name: str


@@ -0,0 +1,211 @@
import asyncio
import atexit
import copy
import json
import os
from abc import abstractmethod
from openhands.core.config import AppConfig, SandboxConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventSource, EventStream, EventStreamSubscriber
from openhands.events.action import (
Action,
ActionConfirmationStatus,
BrowseInteractiveAction,
BrowseURLAction,
CmdRunAction,
FileReadAction,
FileWriteAction,
IPythonRunCellAction,
)
from openhands.events.event import Event
from openhands.events.observation import (
CmdOutputObservation,
ErrorObservation,
NullObservation,
Observation,
UserRejectObservation,
)
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
from openhands.runtime.plugins import JupyterRequirement, PluginRequirement
def _default_env_vars(sandbox_config: SandboxConfig) -> dict[str, str]:
ret = {}
for key in os.environ:
if key.startswith('SANDBOX_ENV_'):
sandbox_key = key.removeprefix('SANDBOX_ENV_')
ret[sandbox_key] = os.environ[key]
if sandbox_config.enable_auto_lint:
ret['ENABLE_AUTO_LINT'] = 'true'
return ret
class Runtime:
"""The runtime is how the agent interacts with the external environment.
This includes a bash sandbox, a browser, and filesystem interactions.
sid is the session id, which is used to identify the current user session.
"""
sid: str
config: AppConfig
DEFAULT_ENV_VARS: dict[str, str]
def __init__(
self,
config: AppConfig,
event_stream: EventStream,
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
):
self.sid = sid
self.event_stream = event_stream
self.event_stream.subscribe(EventStreamSubscriber.RUNTIME, self.on_event)
self.plugins = plugins if plugins is not None and len(plugins) > 0 else []
self.config = copy.deepcopy(config)
self.DEFAULT_ENV_VARS = _default_env_vars(config.sandbox)
atexit.register(self.close_sync)
logger.debug(f'Runtime `{sid}` config:\n{self.config}')
async def ainit(self, env_vars: dict[str, str] | None = None) -> None:
"""
Initialize the runtime (asynchronously).
This method should be called after the runtime's constructor.
"""
if self.DEFAULT_ENV_VARS:
logger.debug(f'Adding default env vars: {self.DEFAULT_ENV_VARS}')
await self.add_env_vars(self.DEFAULT_ENV_VARS)
if env_vars is not None:
logger.debug(f'Adding provided env vars: {env_vars}')
await self.add_env_vars(env_vars)
async def close(self) -> None:
pass
def close_sync(self) -> None:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# No running event loop, use asyncio.run()
asyncio.run(self.close())
else:
# There is a running event loop, create a task
if loop.is_running():
loop.create_task(self.close())
else:
loop.run_until_complete(self.close())
# ====================================================================
async def add_env_vars(self, env_vars: dict[str, str]) -> None:
# Add env vars to the IPython shell (if Jupyter is used)
if any(isinstance(plugin, JupyterRequirement) for plugin in self.plugins):
code = 'import os\n'
for key, value in env_vars.items():
# Note: json.dumps gives us nice escaping for free
code += f'os.environ["{key}"] = {json.dumps(value)}\n'
code += '\n'
obs = await self.run_ipython(IPythonRunCellAction(code))
logger.info(f'Added env vars to IPython: code={code}, obs={obs}')
# Add env vars to the Bash shell
cmd = ''
for key, value in env_vars.items():
# Note: json.dumps gives us nice escaping for free
cmd += f'export {key}={json.dumps(value)}; '
if not cmd:
return
cmd = cmd.strip()
logger.debug(f'Adding env var: {cmd}')
obs = await self.run(CmdRunAction(cmd))
if not isinstance(obs, CmdOutputObservation) or obs.exit_code != 0:
raise RuntimeError(
f'Failed to add env vars [{env_vars}] to environment: {obs.content}'
)
async def on_event(self, event: Event) -> None:
if isinstance(event, Action):
# set timeout to default if not set
if event.timeout is None:
event.timeout = self.config.sandbox.timeout
assert event.timeout is not None
observation = await self.run_action(event)
observation._cause = event.id # type: ignore[attr-defined]
source = event.source if event.source else EventSource.AGENT
self.event_stream.add_event(observation, source) # type: ignore[arg-type]
async def run_action(self, action: Action) -> Observation:
"""Run an action and return the resulting observation.
If the action is not runnable in any runtime, a NullObservation is returned.
If the action is not supported by the current runtime, an ErrorObservation is returned.
"""
if not action.runnable:
return NullObservation('')
if (
hasattr(action, 'is_confirmed')
and action.is_confirmed == ActionConfirmationStatus.AWAITING_CONFIRMATION
):
return NullObservation('')
action_type = action.action # type: ignore[attr-defined]
if action_type not in ACTION_TYPE_TO_CLASS:
return ErrorObservation(f'Action {action_type} does not exist.')
if not hasattr(self, action_type):
return ErrorObservation(
f'Action {action_type} is not supported in the current runtime.'
)
if (
hasattr(action, 'is_confirmed')
and action.is_confirmed == ActionConfirmationStatus.REJECTED
):
return UserRejectObservation(
'Action has been rejected by the user! Waiting for further user input.'
)
observation = await getattr(self, action_type)(action)
return observation
# ====================================================================
# Action execution
# ====================================================================
@abstractmethod
async def run(self, action: CmdRunAction) -> Observation:
pass
@abstractmethod
async def run_ipython(self, action: IPythonRunCellAction) -> Observation:
pass
@abstractmethod
async def read(self, action: FileReadAction) -> Observation:
pass
@abstractmethod
async def write(self, action: FileWriteAction) -> Observation:
pass
@abstractmethod
async def browse(self, action: BrowseURLAction) -> Observation:
pass
@abstractmethod
async def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
pass
# ====================================================================
# File operations
# ====================================================================
@abstractmethod
async def copy_to(self, host_src: str, sandbox_dest: str, recursive: bool = False):
raise NotImplementedError('This method is not implemented in the base class.')
@abstractmethod
async def list_files(self, path: str | None = None) -> list[str]:
"""List files in the sandbox.
If path is None, list files in the sandbox's initial working directory (e.g., /workspace).
"""
raise NotImplementedError('This method is not implemented in the base class.')
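The `SANDBOX_ENV_` prefix stripping performed by `_default_env_vars` is easy to verify in isolation. A standalone sketch that takes a plain dict in place of `os.environ` (the sample variable names are illustrative):

```python
def default_env_vars(environ: dict, enable_auto_lint: bool) -> dict:
    # Mirror _default_env_vars: keep SANDBOX_ENV_* entries with the prefix stripped,
    # and optionally turn on auto-lint
    ret = {k.removeprefix('SANDBOX_ENV_'): v
           for k, v in environ.items() if k.startswith('SANDBOX_ENV_')}
    if enable_auto_lint:
        ret['ENABLE_AUTO_LINT'] = 'true'
    return ret

env = {'SANDBOX_ENV_OPENAI_API_KEY': 'sk-x', 'PATH': '/usr/bin'}
print(default_env_vars(env, enable_auto_lint=True))
# {'OPENAI_API_KEY': 'sk-x', 'ENABLE_AUTO_LINT': 'true'}
```

This is also the counterpart of the `SANDBOX_ENV_OPENAI_API_KEY` fallback in the agentskills config helpers: variables injected with the prefix on the host surface without it inside the sandbox.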


@@ -0,0 +1,4 @@
from .bash import split_bash_commands
from .system import find_available_tcp_port
__all__ = ['find_available_tcp_port', 'split_bash_commands']


@@ -0,0 +1,52 @@
import bashlex
from openhands.core.logger import openhands_logger as logger
def split_bash_commands(commands):
try:
parsed = bashlex.parse(commands)
except bashlex.errors.ParsingError as e:
logger.debug(
f'Failed to parse bash commands\n'
f'[input]: {commands}\n'
f'[warning]: {e}\n'
f'The original command will be returned as is.'
)
# If parsing fails, return the original commands
return [commands]
result: list[str] = []
last_end = 0
for node in parsed:
start, end = node.pos
# Include any text between the last command and this one
if start > last_end:
between = commands[last_end:start]
logger.debug(f'BASH PARSING between: {between}')
if result:
result[-1] += between.rstrip()
elif between.strip():
# THIS SHOULD NOT HAPPEN
result.append(between.rstrip())
# Extract the command, preserving original formatting
command = commands[start:end].rstrip()
logger.debug(f'BASH PARSING command: {command}')
result.append(command)
last_end = end
# Add any remaining text after the last command to the last command
remaining = commands[last_end:].rstrip()
logger.debug(f'BASH PARSING remaining: {remaining}')
if last_end < len(commands) and result:
result[-1] += remaining
logger.debug(f'BASH PARSING result[-1] += remaining: {result[-1]}')
elif last_end < len(commands):
if remaining:
result.append(remaining)
logger.debug(f'BASH PARSING result.append(remaining): {result[-1]}')
return result


@@ -0,0 +1,145 @@
import os
from pathlib import Path
from openhands.events.observation import (
ErrorObservation,
FileReadObservation,
FileWriteObservation,
Observation,
)
def resolve_path(
file_path: str,
working_directory: str,
workspace_base: str,
workspace_mount_path_in_sandbox: str,
):
"""Resolve a file path to a path on the host filesystem.
Args:
file_path: The path to resolve.
working_directory: The working directory of the agent.
workspace_mount_path_in_sandbox: The path to the workspace inside the sandbox.
workspace_base: The base path of the workspace on the host filesystem.
Returns:
The resolved path on the host filesystem.
"""
path_in_sandbox = Path(file_path)
# Apply working directory
if not path_in_sandbox.is_absolute():
path_in_sandbox = Path(working_directory) / path_in_sandbox
# Sanitize the path with respect to the root of the full sandbox
# (deny any .. path traversal to parent directories of the sandbox)
abs_path_in_sandbox = path_in_sandbox.resolve()
# If the path is outside the workspace, deny it
if not abs_path_in_sandbox.is_relative_to(workspace_mount_path_in_sandbox):
raise PermissionError(f'File access not permitted: {file_path}')
# Get path relative to the root of the workspace inside the sandbox
path_in_workspace = abs_path_in_sandbox.relative_to(
Path(workspace_mount_path_in_sandbox)
)
# Get path relative to host
path_in_host_workspace = Path(workspace_base) / path_in_workspace
return path_in_host_workspace
def read_lines(all_lines: list[str], start=0, end=-1):
start = max(start, 0)
start = min(start, len(all_lines))
end = -1 if end == -1 else max(end, 0)
end = min(end, len(all_lines))
if end == -1:
if start == 0:
return all_lines
else:
return all_lines[start:]
else:
num_lines = len(all_lines)
begin = max(0, min(start, num_lines - 2))
end = -1 if end > num_lines else max(begin + 1, end)
return all_lines[begin:end]
async def read_file(
path, workdir, workspace_base, workspace_mount_path_in_sandbox, start=0, end=-1
) -> Observation:
try:
whole_path = resolve_path(
path, workdir, workspace_base, workspace_mount_path_in_sandbox
)
except PermissionError:
return ErrorObservation(
f"You're not allowed to access this path: {path}. You can only access paths inside the workspace."
)
try:
with open(whole_path, 'r', encoding='utf-8') as file:
lines = read_lines(file.readlines(), start, end)
except FileNotFoundError:
return ErrorObservation(f'File not found: {path}')
except UnicodeDecodeError:
return ErrorObservation(f'File could not be decoded as utf-8: {path}')
except IsADirectoryError:
return ErrorObservation(f'Path is a directory: {path}. You can only read files')
code_view = ''.join(lines)
return FileReadObservation(path=path, content=code_view)
def insert_lines(
to_insert: list[str], original: list[str], start: int = 0, end: int = -1
):
"""Insert the new content to the original content based on start and end"""
new_lines = [''] if start == 0 else original[:start]
new_lines += [i + '\n' for i in to_insert]
new_lines += [''] if end == -1 else original[end:]
return new_lines
async def write_file(
path,
workdir,
workspace_base,
workspace_mount_path_in_sandbox,
content,
start=0,
end=-1,
) -> Observation:
insert = content.split('\n')
try:
whole_path = resolve_path(
path, workdir, workspace_base, workspace_mount_path_in_sandbox
)
if not os.path.exists(os.path.dirname(whole_path)):
os.makedirs(os.path.dirname(whole_path))
mode = 'w' if not os.path.exists(whole_path) else 'r+'
try:
with open(whole_path, mode, encoding='utf-8') as file:
if mode != 'w':
all_lines = file.readlines()
new_file = insert_lines(insert, all_lines, start, end)
else:
new_file = [i + '\n' for i in insert]
file.seek(0)
file.writelines(new_file)
file.truncate()
except FileNotFoundError:
return ErrorObservation(f'File not found: {path}')
except IsADirectoryError:
return ErrorObservation(
f'Path is a directory: {path}. You can only write to files'
)
except UnicodeDecodeError:
return ErrorObservation(f'File could not be decoded as utf-8: {path}')
except PermissionError:
return ErrorObservation(f'Malformed paths not permitted: {path}')
return FileWriteObservation(content='', path=path)
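The slicing rules in `read_lines` are easiest to see with concrete calls. The logic below mirrors the function above (lightly condensed) so the example is self-contained:

```python
def read_lines(all_lines: list, start=0, end=-1):
    start = max(start, 0)
    start = min(start, len(all_lines))
    end = -1 if end == -1 else max(end, 0)
    end = min(end, len(all_lines))
    if end == -1:
        # open-ended read: everything, or everything from `start`
        return all_lines if start == 0 else all_lines[start:]
    num_lines = len(all_lines)
    begin = max(0, min(start, num_lines - 2))
    end = -1 if end > num_lines else max(begin + 1, end)
    return all_lines[begin:end]

lines = ['a\n', 'b\n', 'c\n', 'd\n']
print(read_lines(lines, 1, 3))  # ['b\n', 'c\n']
print(read_lines(lines, 2))     # ['c\n', 'd\n']
```

The clamping (`begin` capped at `num_lines - 2`, `end` forced past `begin`) guarantees at least one line is returned for any bounded range on a non-trivial file.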


@@ -0,0 +1,442 @@
import argparse
import os
import shutil
import subprocess
import tempfile
import docker
import toml
from dirhash import dirhash
from jinja2 import Environment, FileSystemLoader
import openhands
from openhands.core.logger import openhands_logger as logger
from openhands.runtime.builder import DockerRuntimeBuilder, RuntimeBuilder
RUNTIME_IMAGE_REPO = os.getenv(
'OD_RUNTIME_RUNTIME_IMAGE_REPO', 'ghcr.io/openhands/od_runtime'
)
def _get_package_version():
"""Read the version from pyproject.toml.
Returns:
- The version specified in pyproject.toml under [tool.poetry]
"""
project_root = os.path.dirname(os.path.dirname(os.path.abspath(openhands.__file__)))
pyproject_path = os.path.join(project_root, 'pyproject.toml')
with open(pyproject_path, 'r') as f:
pyproject_data = toml.load(f)
return pyproject_data['tool']['poetry']['version']
def _create_project_source_dist():
"""Create a source distribution of the project.
Returns:
- str: The path to the project tarball
"""
project_root = os.path.dirname(os.path.dirname(os.path.abspath(openhands.__file__)))
logger.info(f'Using project root: {project_root}')
# run "python -m build -s" on project_root to create project tarball
result = subprocess.run(f'python -m build -s {project_root}', shell=True)
if result.returncode != 0:
logger.error(f'Build failed: {result}')
raise Exception(f'Build failed: {result}')
# Fetch the correct version from pyproject.toml
package_version = _get_package_version()
tarball_path = os.path.join(
project_root, 'dist', f'openhands-{package_version}.tar.gz'
)
if not os.path.exists(tarball_path):
logger.error(f'Source distribution not found at {tarball_path}')
raise Exception(f'Source distribution not found at {tarball_path}')
logger.info(f'Source distribution created at {tarball_path}')
return tarball_path
def _put_source_code_to_dir(temp_dir: str):
"""Builds the project source tarball. Copies it to temp_dir and unpacks it.
The OpenHands source code ends up in the temp_dir/code directory
Parameters:
- temp_dir (str): The directory to put the source code in
"""
# Build the project source tarball
tarball_path = _create_project_source_dist()
filename = os.path.basename(tarball_path)
filename = filename.removesuffix('.tar.gz')
# Move the project tarball to temp_dir
_res = shutil.copy(tarball_path, os.path.join(temp_dir, 'project.tar.gz'))
if _res:
os.remove(tarball_path)
logger.info(
f'Source distribution moved to {os.path.join(temp_dir, "project.tar.gz")}'
)
# Unzip the tarball
shutil.unpack_archive(os.path.join(temp_dir, 'project.tar.gz'), temp_dir)
# Remove the tarball
os.remove(os.path.join(temp_dir, 'project.tar.gz'))
# Rename the directory containing the code to 'code'
os.rename(os.path.join(temp_dir, filename), os.path.join(temp_dir, 'code'))
logger.info(f'Unpacked source code directory: {os.path.join(temp_dir, "code")}')
def _generate_dockerfile(
base_image: str,
skip_init: bool = False,
extra_deps: str | None = None,
) -> str:
"""Generate the Dockerfile content for the runtime image based on the base image.
Parameters:
- base_image (str): The base image provided for the runtime image
- skip_init (boolean):
- extra_deps (str):
Returns:
- str: The resulting Dockerfile content
"""
env = Environment(
loader=FileSystemLoader(
searchpath=os.path.join(os.path.dirname(__file__), 'runtime_templates')
)
)
template = env.get_template('Dockerfile.j2')
dockerfile_content = template.render(
base_image=base_image,
skip_init=skip_init,
extra_deps=extra_deps if extra_deps is not None else '',
)
return dockerfile_content
def prep_docker_build_folder(
dir_path: str,
base_image: str,
skip_init: bool = False,
extra_deps: str | None = None,
) -> str:
"""Prepares a docker build folder by copying the source code and generating the Dockerfile
Parameters:
- dir_path (str): The build folder to place the source code and Dockerfile
- base_image (str): The base Docker image to use for the Dockerfile
- skip_init (str):
- extra_deps (str):
Returns:
- str: The MD5 hash of the build folder directory (dir_path)
"""
# Copy the source code to directory. It will end up in dir_path/code
_put_source_code_to_dir(dir_path)
# Create a Dockerfile and write it to dir_path
dockerfile_content = _generate_dockerfile(
base_image,
skip_init=skip_init,
extra_deps=extra_deps,
)
logger.info(
(
f'===== Dockerfile content start =====\n'
f'{dockerfile_content}\n'
f'===== Dockerfile content end ====='
)
)
with open(os.path.join(dir_path, 'Dockerfile'), 'w') as file:
file.write(dockerfile_content)
# Get the MD5 hash of the dir_path directory
hash = dirhash(dir_path, 'md5')
logger.info(
f'Input base image: {base_image}\n'
f'Skip init: {skip_init}\n'
f'Extra deps: {extra_deps}\n'
f'Hash for docker build directory [{dir_path}] (contents: {os.listdir(dir_path)}): {hash}\n'
)
return hash
def get_runtime_image_repo_and_tag(base_image: str) -> tuple[str, str]:
"""Retrieves the Docker repo and tag associated with the Docker image.
Parameters:
- base_image (str): The name of the base Docker image
Returns:
- tuple[str, str]: The Docker repo and tag of the Docker image
"""
if RUNTIME_IMAGE_REPO in base_image:
logger.info(
            f'The provided image [{base_image}] is already a valid od_runtime image.\n'
f'Will try to reuse it as is.'
)
if ':' not in base_image:
base_image = base_image + ':latest'
repo, tag = base_image.split(':')
return repo, tag
else:
if ':' not in base_image:
base_image = base_image + ':latest'
[repo, tag] = base_image.split(':')
repo = repo.replace('/', '___')
od_version = _get_package_version()
return RUNTIME_IMAGE_REPO, f'od_v{od_version}_image_{repo}_tag_{tag}'
def build_runtime_image(
base_image: str,
runtime_builder: RuntimeBuilder,
extra_deps: str | None = None,
docker_build_folder: str | None = None,
dry_run: bool = False,
force_rebuild: bool = False,
) -> str:
"""Prepares the final docker build folder.
If dry_run is False, it will also build the OpenHands runtime Docker image using the docker build folder.
Parameters:
- base_image (str): The name of the base Docker image to use
- runtime_builder (RuntimeBuilder): The runtime builder to use
- extra_deps (str | None): Optional shell commands that install extra dependencies, appended to the generated Dockerfile
- docker_build_folder (str): The directory to use for the build. If not provided a temporary directory will be used
- dry_run (bool): if True, it will only prepare the build folder. It will not actually build the Docker image
- force_rebuild (bool): if True, it will rebuild the image from scratch instead of reusing the existing generic runtime image
Returns:
- str: <image_repo>:<MD5 hash>. Where MD5 hash is the hash of the docker build folder
See https://docs.all-hands.dev/modules/usage/runtime for more details.
"""
# Calculate the hash for the docker build folder (source code and Dockerfile)
with tempfile.TemporaryDirectory() as temp_dir:
from_scratch_hash = prep_docker_build_folder(
temp_dir,
base_image=base_image,
skip_init=False,
extra_deps=extra_deps,
)
runtime_image_repo, runtime_image_tag = get_runtime_image_repo_and_tag(base_image)
# The image name in the format <image repo>:<hash>
hash_runtime_image_name = f'{runtime_image_repo}:{from_scratch_hash}'
# non-hash generic image name, it could contain *similar* dependencies
# but *might* not exactly match the state of the source code.
# It resembles the "latest" tag in the docker image naming convention for
# a particular {repo}:{tag} pair (e.g., ubuntu:latest -> od_runtime:ubuntu_tag_latest)
# we will build from IT to save time if the `from_scratch_hash` is not found
generic_runtime_image_name = f'{runtime_image_repo}:{runtime_image_tag}'
# Scenario 1: If we already have an image with the exact same hash, then it means the image is already built
# with the exact same source code and Dockerfile, so we will reuse it. Building it is not required.
if runtime_builder.image_exists(hash_runtime_image_name):
logger.info(
f'Image [{hash_runtime_image_name}] already exists so we will reuse it.'
)
return hash_runtime_image_name
# Scenario 2: If a Docker image with the exact hash is not found, we will FIRST try to re-build it
# by leveraging the `generic_runtime_image_name` to save some time
# from re-building the dependencies (e.g., poetry install, apt install)
elif runtime_builder.image_exists(generic_runtime_image_name) and not force_rebuild:
logger.info(
f'Cannot find docker Image [{hash_runtime_image_name}]\n'
f'Will try to re-build it from latest [{generic_runtime_image_name}] image to potentially save '
f'time for dependencies installation.\n'
)
cur_docker_build_folder = docker_build_folder or tempfile.mkdtemp()
_skip_init_hash = prep_docker_build_folder(
cur_docker_build_folder,
# we want to use the existing generic image as base
# so that we can leverage existing dependencies already installed in the image
base_image=generic_runtime_image_name,
skip_init=True, # skip init since we are re-using the existing image
extra_deps=extra_deps,
)
assert (
_skip_init_hash != from_scratch_hash
), f'The skip_init hash [{_skip_init_hash}] should not match the existing hash [{from_scratch_hash}]'
if not dry_run:
_build_sandbox_image(
docker_folder=cur_docker_build_folder,
runtime_builder=runtime_builder,
target_image_repo=runtime_image_repo,
# NOTE: WE ALWAYS use the "from_scratch_hash" tag for the target image
# otherwise, even if the source code is exactly the same, the image *might* be re-built
# because the same source code will generate different hash when skip_init=True/False
# since the Dockerfile is slightly different
target_image_hash_tag=from_scratch_hash,
target_image_tag=runtime_image_tag,
)
else:
logger.info(
f'Dry run: Skipping image build for [{generic_runtime_image_name}]'
)
if docker_build_folder is None:
shutil.rmtree(cur_docker_build_folder)
# Scenario 3: If the Docker image with the required hash is not found AND we cannot re-use the latest
# relevant image, we will build it completely from scratch
else:
if force_rebuild:
logger.info(
f'Force re-build: Will try to re-build image [{generic_runtime_image_name}] from scratch.\n'
)
cur_docker_build_folder = docker_build_folder or tempfile.mkdtemp()
_new_from_scratch_hash = prep_docker_build_folder(
cur_docker_build_folder,
base_image,
skip_init=False,
extra_deps=extra_deps,
)
assert (
_new_from_scratch_hash == from_scratch_hash
), f'The new from scratch hash [{_new_from_scratch_hash}] does not match the existing hash [{from_scratch_hash}]'
if not dry_run:
_build_sandbox_image(
docker_folder=cur_docker_build_folder,
runtime_builder=runtime_builder,
target_image_repo=runtime_image_repo,
# NOTE: WE ALWAYS use the "from_scratch_hash" tag for the target image
target_image_hash_tag=from_scratch_hash,
target_image_tag=runtime_image_tag,
)
else:
logger.info(
f'Dry run: Skipping image build for [{generic_runtime_image_name}]'
)
if docker_build_folder is None:
shutil.rmtree(cur_docker_build_folder)
return f'{runtime_image_repo}:{from_scratch_hash}'
def _build_sandbox_image(
docker_folder: str,
runtime_builder: RuntimeBuilder,
target_image_repo: str,
target_image_hash_tag: str,
target_image_tag: str,
) -> str:
"""Build and tag the sandbox image.
The image will be tagged as both:
- target_image_hash_tag
- target_image_tag
Parameters:
- docker_folder (str): the path to the docker build folder
- runtime_builder (RuntimeBuilder): the runtime builder instance
- target_image_repo (str): the repository name for the target image
- target_image_hash_tag (str): the *hash* tag for the target image that is calculated based
on the contents of the docker build folder (source code and Dockerfile)
e.g. 1234567890abcdef
- target_image_tag (str): the tag for the target image that's generic and based on the base image name
e.g. od_v0.8.3_image_ubuntu_tag_22.04
"""
target_image_hash_name = f'{target_image_repo}:{target_image_hash_tag}'
target_image_generic_name = f'{target_image_repo}:{target_image_tag}'
try:
success = runtime_builder.build(
path=docker_folder, tags=[target_image_hash_name, target_image_generic_name]
)
if not success:
raise RuntimeError(f'Build failed for image {target_image_hash_name}')
except Exception as e:
logger.error(f'Sandbox image build failed: {e}')
raise
return target_image_hash_name
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--base_image', type=str, default='nikolaik/python-nodejs:python3.11-nodejs22'
)
parser.add_argument('--build_folder', type=str, default=None)
parser.add_argument('--force_rebuild', action='store_true', default=False)
args = parser.parse_args()
if args.build_folder is not None:
# If a build_folder is provided, we do not actually build the Docker image. We copy the necessary source code
# and create a Dockerfile dynamically and place it in the build_folder only. This allows the Docker image to
# then be created using the Dockerfile (most likely using the containers/build.sh script)
build_folder = args.build_folder
assert os.path.exists(
build_folder
), f'Build folder {build_folder} does not exist'
logger.info(
f'Copying the source code and generating the Dockerfile in the build folder: {build_folder}'
)
runtime_image_repo, runtime_image_tag = get_runtime_image_repo_and_tag(
args.base_image
)
logger.info(
f'Runtime image repo: {runtime_image_repo} and runtime image tag: {runtime_image_tag}'
)
with tempfile.TemporaryDirectory() as temp_dir:
# dry_run is true so we only prepare a temp_dir containing the required source code and the Dockerfile. We
# then obtain the MD5 hash of the folder and return <image_repo>:<temp_dir_md5_hash>
runtime_image_hash_name = build_runtime_image(
args.base_image,
runtime_builder=DockerRuntimeBuilder(docker.from_env()),
docker_build_folder=temp_dir,
dry_run=True,
force_rebuild=args.force_rebuild,
)
_runtime_image_repo, runtime_image_hash_tag = runtime_image_hash_name.split(
':'
)
# Move contents of temp_dir to build_folder
shutil.copytree(temp_dir, build_folder, dirs_exist_ok=True)
logger.info(
f'Build folder [{build_folder}] is ready: {os.listdir(build_folder)}'
)
# We now update the config.sh in the build_folder to contain the required values. This is used in the
# containers/build.sh script which is called to actually build the Docker image
with open(os.path.join(build_folder, 'config.sh'), 'a') as file:
file.write(
(
f'\n'
f'DOCKER_IMAGE_TAG={runtime_image_tag}\n'
f'DOCKER_IMAGE_HASH_TAG={runtime_image_hash_tag}\n'
)
)
logger.info(
f'`config.sh` is updated with the image repo[{runtime_image_repo}] and tags [{runtime_image_tag}, {runtime_image_hash_tag}]'
)
logger.info(
f'Dockerfile, source code and config.sh are ready in {build_folder}'
)
else:
# If a build_folder is not provided, after copying the required source code and dynamically creating the
# Dockerfile, we actually build the Docker image
logger.info('Building image in a temporary folder')
docker_builder = DockerRuntimeBuilder(docker.from_env())
image_name = build_runtime_image(args.base_image, docker_builder)
print(f'\nBUILT Image: {image_name}\n')
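The repo/tag derivation performed by `get_runtime_image_repo_and_tag` can be illustrated with a standalone sketch. `RUNTIME_IMAGE_REPO` and `OD_VERSION` below are illustrative placeholders, not the repository's real constants (the `0.8.3` version echoes the example in the `_build_sandbox_image` docstring):

```python
# Standalone sketch of the repo/tag derivation in get_runtime_image_repo_and_tag.
RUNTIME_IMAGE_REPO = 'od_runtime'  # placeholder, not the real constant
OD_VERSION = '0.8.3'               # placeholder package version

def runtime_repo_and_tag(base_image: str) -> tuple[str, str]:
    if ':' not in base_image:
        base_image += ':latest'  # default tag, mirroring the code above
    repo, tag = base_image.split(':')
    repo = repo.replace('/', '___')  # '/' is not valid inside a Docker tag
    return RUNTIME_IMAGE_REPO, f'od_v{OD_VERSION}_image_{repo}_tag_{tag}'

print(runtime_repo_and_tag('ubuntu:22.04'))
# ('od_runtime', 'od_v0.8.3_image_ubuntu_tag_22.04')
```

Note how a namespaced base image like `nikolaik/python-nodejs:python3.11-nodejs22` has its `/` escaped to `___` so the whole repo name can survive inside a single tag.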


@@ -0,0 +1,68 @@
{% if skip_init %}
FROM {{ base_image }}
{% else %}
# ================================================================
# START: Build Runtime Image from Scratch
# ================================================================
FROM {{ base_image }}
{% if 'ubuntu' in base_image and (base_image.endswith(':latest') or base_image.endswith(':24.04')) %}
{% set LIBGL_MESA = 'libgl1' %}
{% else %}
{% set LIBGL_MESA = 'libgl1-mesa-glx' %}
{% endif %}
# Install necessary packages and clean up in one layer
RUN apt-get update && \
apt-get install -y wget sudo apt-utils {{ LIBGL_MESA }} libasound2-plugins git && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
# Create necessary directories
RUN mkdir -p /openhands && \
mkdir -p /openhands/logs && \
mkdir -p /openhands/poetry
ENV POETRY_VIRTUALENVS_PATH=/openhands/poetry
RUN if [ ! -d /openhands/miniforge3 ]; then \
wget --progress=bar:force -O Miniforge3.sh "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh" && \
bash Miniforge3.sh -b -p /openhands/miniforge3 && \
rm Miniforge3.sh && \
chmod -R g+w /openhands/miniforge3 && \
bash -c ". /openhands/miniforge3/etc/profile.d/conda.sh && conda config --set changeps1 False && conda config --append channels conda-forge"; \
fi
# Install Python and Poetry
RUN /openhands/miniforge3/bin/mamba install conda-forge::poetry python=3.11 -y
# ================================================================
# END: Build Runtime Image from Scratch
# ================================================================
{% endif %}
# ================================================================
# START: Copy Project and Install/Update Dependencies
# ================================================================
RUN if [ -d /openhands/code ]; then rm -rf /openhands/code; fi
COPY ./code /openhands/code
# Install/Update Dependencies
# 1. Install pyproject.toml via poetry
# 2. Install playwright and chromium
# 3. Clear poetry, apt, mamba caches
RUN cd /openhands/code && \
/openhands/miniforge3/bin/mamba run -n base poetry env use python3.11 && \
/openhands/miniforge3/bin/mamba run -n base poetry install --only main,runtime --no-interaction --no-root && \
apt-get update && \
/openhands/miniforge3/bin/mamba run -n base poetry run pip install playwright && \
/openhands/miniforge3/bin/mamba run -n base poetry run playwright install --with-deps chromium && \
export OD_INTERPRETER_PATH=$(/openhands/miniforge3/bin/mamba run -n base poetry run python -c "import sys; print(sys.executable)") && \
{{ extra_deps }} {% if extra_deps %} && {% endif %} \
/openhands/miniforge3/bin/mamba run -n base poetry cache clear --all . && \
{% if not skip_init %}chmod -R g+rws /openhands/poetry && {% endif %} \
apt-get clean && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* && \
/openhands/miniforge3/bin/mamba clean --all
# ================================================================
# END: Copy Project and Install/Update Dependencies
# ================================================================
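The template's two branches (`skip_init` true vs. false) can be exercised directly with Jinja2. The sketch below renders a trimmed-down copy of the conditionals above — only the `FROM` line and the `libgl` package selection — not the full template:

```python
from jinja2 import Template

# Trimmed-down copy of the conditional logic at the top of the template above.
TEMPLATE = Template(
    '{% if skip_init %}'
    'FROM {{ base_image }}'
    '{% else %}'
    'FROM {{ base_image }}\n'
    '{% if "ubuntu" in base_image and (base_image.endswith(":latest") or base_image.endswith(":24.04")) %}'
    'RUN apt-get install -y libgl1'
    '{% else %}'
    'RUN apt-get install -y libgl1-mesa-glx'
    '{% endif %}'
    '{% endif %}'
)

# From-scratch build on ubuntu:24.04 picks the newer libgl1 package name
print(TEMPLATE.render(skip_init=False, base_image='ubuntu:24.04'))
# Re-build on top of an existing runtime image only emits the FROM line
print(TEMPLATE.render(skip_init=True, base_image='od_runtime:ubuntu_tag_24.04'))
```

The `skip_init` re-build path deliberately skips everything between the START/END "from Scratch" markers, which is what makes Scenario 2 in `build_runtime_image` cheap.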


@@ -0,0 +1,14 @@
import socket
def find_available_tcp_port() -> int:
"""Find an available TCP port, return -1 if none available."""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
sock.bind(('localhost', 0))
port = sock.getsockname()[1]
return port
except Exception:
return -1
finally:
sock.close()
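A quick way to sanity-check the helper above is to bind a server socket to the port it returns. The function is restated here so the snippet is self-contained:

```python
import socket

def find_available_tcp_port() -> int:
    """Find an available TCP port, return -1 if none available."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Binding to port 0 asks the OS to assign any free port
        sock.bind(('localhost', 0))
        return sock.getsockname()[1]
    except Exception:
        return -1
    finally:
        sock.close()

port = find_available_tcp_port()
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Caveat: another process could grab the port between the probe and this bind (a race)
server.bind(('localhost', port))
server.close()
print(port)
```

Because the probe socket is closed before the caller binds, this pattern is inherently racy; it is fine for picking sandbox ports, but the eventual bind should still handle `EADDRINUSE`.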


@@ -0,0 +1,73 @@
# Security
Given the impressive capabilities of OpenHands and similar coding agents, ensuring robust security measures is essential to prevent unintended actions or security breaches. The SecurityAnalyzer framework provides a structured approach to monitor and analyze agent actions for potential security risks.
To enable this feature:
* From the web interface
* Open Configuration (by clicking the gear icon in the bottom right)
* Select a Security Analyzer from the dropdown
* Save settings
* (to disable) repeat the same steps, but click the X in the Security Analyzer dropdown
* From config.toml
```toml
[security]
# Enable confirmation mode
confirmation_mode = true
# The security analyzer to use
security_analyzer = "your-security-analyzer"
```
(to disable) remove the lines from config.toml
## SecurityAnalyzer Base Class
The `SecurityAnalyzer` class (analyzer.py) is an abstract base class designed to listen to an event stream, analyze actions for security risks, and, when necessary, act before the action is executed. Below is a detailed explanation of its components and methods:
### Initialization
- **event_stream**: An instance of `EventStream` that the analyzer will listen to for events.
### Event Handling
- **on_event(event: Event)**: Handles incoming events. If the event is an `Action`, it evaluates its security risk and acts upon it.
### Abstract Methods
- **handle_api_request(request: Request)**: Abstract method to handle API requests.
- **log_event(event: Event)**: Logs events.
- **act(event: Event)**: Defines actions to take based on the analyzed event.
- **security_risk(event: Action)**: Evaluates the security risk of an action and returns the risk level.
- **close()**: Cleans up resources used by the security analyzer.
In conclusion, a concrete security analyzer should evaluate the risk of each event and act accordingly (e.g. auto-confirm, send Slack message, etc).
For customization and decoupling from the OpenHands core logic, the security analyzer can define its own API endpoints that can then be accessed from the frontend. These API endpoints need to be secured (do not allow more capabilities than the core logic
provides).
## How to implement your own Security Analyzer
1. Create a submodule in [security](/openhands/security/) with your analyzer's desired name
* Have your main class inherit from [SecurityAnalyzer](/openhands/security/analyzer.py)
* Optional: define API endpoints for `/api/security/{path:path}` to manage settings
2. Add your analyzer class to the [options](/openhands/security/options.py) to have it be visible from the frontend combobox
3. Optional: implement your modal frontend (for when you click on the lock) in [security](/frontend/src/components/modals/security/) and add your component to [Security.tsx](/frontend/src/components/modals/security/Security.tsx)
## Implemented Security Analyzers
### Invariant
This analyzer uses the [Invariant Analyzer](https://github.com/invariantlabs-ai/invariant) to analyze traces and detect potential issues with OpenHands's workflow, and relies on confirmation mode to ask for user confirmation on potentially risky actions.
This allows the agent to run autonomously without fear that it will inadvertently compromise security or perform unintended actions that could be harmful.
Features:
* Detects:
* potential secret leaks by the agent
* security issues in Python code
* malicious bash commands
* Logs:
* actions and their associated risk
* OpenHands traces in JSON format
* Run-time settings:
* the [invariant policy](https://github.com/invariantlabs-ai/invariant?tab=readme-ov-file#policy-language)
* acceptable risk threshold
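Conceptually, the acceptable risk threshold is a simple gate combined with confirmation mode: any action scored at or above the threshold is held for the user. This is a conceptual sketch, not the Invariant API:

```python
from enum import IntEnum

class Risk(IntEnum):
    LOW = 0
    MEDIUM = 1
    HIGH = 2

def needs_confirmation(action_risk: Risk, threshold: Risk = Risk.MEDIUM) -> bool:
    """In confirmation mode, hold any action at or above the acceptable risk threshold."""
    return action_risk >= threshold

print(needs_confirmation(Risk.HIGH))  # True
print(needs_confirmation(Risk.LOW))   # False
```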


@@ -0,0 +1,7 @@
from .analyzer import SecurityAnalyzer
from .invariant.analyzer import InvariantAnalyzer
__all__ = [
'SecurityAnalyzer',
'InvariantAnalyzer',
]
