Refactor of error handling (#4575)

Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
Co-authored-by: Xingyao Wang <xingyao@all-hands.dev>
Co-authored-by: Xingyao Wang <xingyao6@illinois.edu>
This commit is contained in:
Robert Brennan
2024-11-04 18:30:53 -05:00
committed by GitHub
parent 24117143ae
commit 98751a3ee2
36 changed files with 894 additions and 704 deletions

View File

@@ -128,14 +128,14 @@ describe.skip("ChatInterface", () => {
timestamp: new Date().toISOString(),
},
{
error: "Woops!",
error: true,
id: "",
message: "Something went wrong",
},
];
renderChatInterface(messages);
const error = screen.getByTestId("error-message");
expect(within(error).getByText("Woops!")).toBeInTheDocument();
expect(within(error).getByText("Something went wrong")).toBeInTheDocument();
});

View File

@@ -1,6 +1,7 @@
import React, { useEffect } from "react";
import { useTranslation } from "react-i18next";
import { useSelector } from "react-redux";
import toast from "react-hot-toast";
import { I18nKey } from "#/i18n/declaration";
import { RootState } from "#/store";
import AgentState from "#/types/AgentState";
@@ -16,7 +17,7 @@ enum IndicatorColor {
}
function AgentStatusBar() {
const { t } = useTranslation();
const { t, i18n } = useTranslation();
const { curAgentState } = useSelector((state: RootState) => state.agent);
const { curStatusMessage } = useSelector((state: RootState) => state.status);
@@ -94,15 +95,27 @@ function AgentStatusBar() {
const [statusMessage, setStatusMessage] = React.useState<string>("");
React.useEffect(() => {
if (curAgentState === AgentState.LOADING) {
const trimmedCustomMessage = curStatusMessage.status.trim();
if (trimmedCustomMessage) {
setStatusMessage(t(trimmedCustomMessage));
return;
let message = curStatusMessage.message || "";
if (curStatusMessage?.id) {
const id = curStatusMessage.id.trim();
if (i18n.exists(id)) {
message = t(curStatusMessage.id.trim()) || message;
}
}
if (curStatusMessage?.type === "error") {
toast.error(message);
return;
}
if (curAgentState === AgentState.LOADING && message.trim()) {
setStatusMessage(message);
} else {
setStatusMessage(AgentStatusMap[curAgentState].message);
}
}, [curStatusMessage.id]);
React.useEffect(() => {
setStatusMessage(AgentStatusMap[curAgentState].message);
}, [curAgentState, curStatusMessage.status]);
}, [curAgentState]);
return (
<div className="flex flex-col items-center">

View File

@@ -73,7 +73,7 @@ export function ChatInterface() {
isErrorMessage(message) ? (
<ErrorMessage
key={index}
error={message.error}
id={message.id}
message={message.message}
/>
) : (

View File

@@ -6,6 +6,7 @@ type Message = {
};
type ErrorMessage = {
error: string;
error: boolean;
id?: string;
message: string;
};

View File

@@ -1,14 +1,41 @@
import { useState, useEffect } from "react";
import { useTranslation } from "react-i18next";
interface ErrorMessageProps {
error: string;
id?: string;
message: string;
}
export function ErrorMessage({ error, message }: ErrorMessageProps) {
export function ErrorMessage({ id, message }: ErrorMessageProps) {
const { t, i18n } = useTranslation();
const [showDetails, setShowDetails] = useState(true);
const [headline, setHeadline] = useState("");
const [details, setDetails] = useState(message);
useEffect(() => {
if (id && i18n.exists(id)) {
setHeadline(t(id));
setDetails(message);
setShowDetails(false);
}
}, [id, message, i18n.language]);
return (
<div className="flex gap-2 items-center justify-start border-l-2 border-danger pl-2 my-2 py-2">
<div className="text-sm leading-4 flex flex-col gap-2">
<p className="text-danger font-bold">{error}</p>
<p className="text-neutral-300">{message}</p>
{headline && <p className="text-danger font-bold">{headline}</p>}
{headline && (
<button
type="button"
onClick={() => setShowDetails(!showDetails)}
className="cursor-pointer text-left"
>
{showDetails
? t("ERROR_MESSAGE$HIDE_DETAILS")
: t("ERROR_MESSAGE$SHOW_DETAILS")}
</button>
)}
{showDetails && <p className="text-neutral-300">{details}</p>}
</div>
</div>
);

View File

@@ -1441,6 +1441,12 @@
"fr": "Privé",
"tr": "Özel"
},
"ERROR_MESSAGE$SHOW_DETAILS": {
"en": "Show details"
},
"ERROR_MESSAGE$HIDE_DETAILS": {
"en": "Hide details"
},
"STATUS$STARTING_RUNTIME": {
"en": "Starting Runtime...",
"zh-CN": "启动运行时...",
@@ -1510,5 +1516,17 @@
"ar": "في انتظار جاهزية العميل...",
"fr": "En attente que le client soit prêt...",
"tr": "İstemcinin hazır olması bekleniyor..."
},
"STATUS$ERROR_LLM_AUTHENTICATION": {
"en": "Error authenticating with the LLM provider. Please check your API key"
},
"STATUS$ERROR_RUNTIME_DISCONNECTED": {
"en": "There was an error while connecting to the runtime. Please refresh the page."
},
"AGENT_ERROR$BAD_ACTION": {
"en": "Agent tried to execute a malformed action."
},
"AGENT_ERROR$ACTION_TIMEOUT": {
"en": "Action timed out."
}
}

View File

@@ -184,21 +184,6 @@ function App() {
if (q) addIntialQueryToChat(q, files);
}, [settings]);
const handleError = (message: string) => {
const [error, ...rest] = message.split(":");
const details = rest.join(":");
if (!details) {
dispatch(
addErrorMessage({
error: "An error has occured",
message: error,
}),
);
} else {
dispatch(addErrorMessage({ error, message: details }));
}
};
const handleMessage = React.useCallback(
(message: MessageEvent<WebSocket.Data>) => {
// set token received from the server
@@ -224,7 +209,12 @@ function App() {
return;
}
if (isErrorObservation(parsed)) {
handleError(parsed.message);
dispatch(
addErrorMessage({
id: parsed.extras?.error_id,
message: parsed.message,
}),
);
return;
}

View File

@@ -1,4 +1,8 @@
import { addAssistantMessage, addUserMessage } from "#/state/chatSlice";
import {
addAssistantMessage,
addUserMessage,
addErrorMessage,
} from "#/state/chatSlice";
import { setCode, setActiveFilepath } from "#/state/codeSlice";
import { appendJupyterInput } from "#/state/jupyterSlice";
import {
@@ -119,13 +123,19 @@ export function handleActionMessage(message: ActionMessage) {
}
export function handleStatusMessage(message: StatusMessage) {
const msg = message.status == null ? "" : message.status.trim();
store.dispatch(
setCurStatusMessage({
...message,
status: msg,
}),
);
if (message.type === "info") {
store.dispatch(
setCurStatusMessage({
...message,
}),
);
} else if (message.type === "error") {
store.dispatch(
addErrorMessage({
...message,
}),
);
}
}
export function handleAssistantMessage(data: string | SocketMessage) {
@@ -139,9 +149,11 @@ export function handleAssistantMessage(data: string | SocketMessage) {
if ("action" in socketMessage) {
handleActionMessage(socketMessage);
} else if ("status" in socketMessage) {
} else if ("observation" in socketMessage) {
handleObservationMessage(socketMessage);
} else if ("status_update" in socketMessage) {
handleStatusMessage(socketMessage);
} else {
handleObservationMessage(socketMessage);
console.error("Unknown message type", socketMessage);
}
}

View File

@@ -39,10 +39,10 @@ export const chatSlice = createSlice({
addErrorMessage(
state,
action: PayloadAction<{ error: string; message: string }>,
action: PayloadAction<{ id?: string; message: string }>,
) {
const { error, message } = action.payload;
state.messages.push({ error, message });
const { id, message } = action.payload;
state.messages.push({ id, message, error: true });
},
clearMessages(state) {

View File

@@ -2,8 +2,10 @@ import { createSlice, PayloadAction } from "@reduxjs/toolkit";
import { StatusMessage } from "#/types/Message";
const initialStatusMessage: StatusMessage = {
status: "",
is_error: false,
status_update: true,
type: "info",
id: "",
message: "",
};
export const statusSlice = createSlice({

View File

@@ -33,10 +33,8 @@ export interface ObservationMessage {
}
export interface StatusMessage {
// TODO not implemented yet
// Whether the status is an error, default is false
is_error: boolean;
// A status message to display to the user
status: string;
status_update: true;
type: string;
id: string;
message: string;
}

View File

@@ -54,6 +54,9 @@ export interface BrowseObservation extends OpenHandsObservationEvent<"browse"> {
export interface ErrorObservation extends OpenHandsObservationEvent<"error"> {
source: "user";
extras: {
error_id?: string;
};
}
export type OpenHandsObservation =

View File

@@ -1,7 +1,7 @@
import asyncio
import copy
import traceback
from typing import Type
from typing import Callable, Type
import litellm
@@ -35,9 +35,7 @@ from openhands.events.event import Event
from openhands.events.observation import (
AgentDelegateObservation,
AgentStateChangedObservation,
CmdOutputObservation,
ErrorObservation,
FatalErrorObservation,
Observation,
)
from openhands.events.serialization.event import truncate_content
@@ -77,6 +75,7 @@ class AgentController:
initial_state: State | None = None,
is_delegate: bool = False,
headless_mode: bool = True,
status_callback: Callable | None = None,
):
"""Initializes a new instance of the AgentController class.
@@ -119,6 +118,7 @@ class AgentController:
# stuck helper
self._stuck_detector = StuckDetector(self.state)
self.status_callback = status_callback
async def close(self):
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream."""
@@ -132,7 +132,7 @@ class AgentController:
message (str): The message to log.
"""
message = f'[Agent Controller {self.id}] {message}'
getattr(logger, level)(message, extra=extra)
getattr(logger, level)(message, extra=extra, stacklevel=2)
def update_state_before_step(self):
self.state.iteration += 1
@@ -142,22 +142,16 @@ class AgentController:
# update metrics especially for cost. Use deepcopy to avoid it being modified by agent.reset()
self.state.local_metrics = copy.deepcopy(self.agent.llm.metrics)
async def report_error(self, message: str, exception: Exception | None = None):
"""Reports an error to the user and sends the exception to the LLM next step, in the hope it can self-correct.
This method should be called for a particular type of errors, which have:
- a user-friendly message, which will be shown in the chat box. This should not be a raw exception message.
- an ErrorObservation that can be sent to the LLM by the user role, with the exception message, so it can self-correct next time.
"""
self.state.last_error = message
if exception:
self.state.last_error += f': {exception}'
detail = str(exception) if exception is not None else ''
if exception is not None and isinstance(exception, litellm.AuthenticationError):
detail = 'Please check your credentials. Is your API key correct?'
self.event_stream.add_event(
ErrorObservation(f'{message}:{detail}'), EventSource.ENVIRONMENT
)
async def _react_to_exception(
self,
e: Exception,
):
await self.set_agent_state_to(AgentState.ERROR)
if self.status_callback is not None:
err_id = ''
if isinstance(e, litellm.AuthenticationError):
err_id = 'STATUS$ERROR_LLM_AUTHENTICATION'
self.status_callback('error', err_id, str(e))
async def start_step_loop(self):
"""The main loop for the agent's step-by-step execution."""
@@ -172,12 +166,7 @@ class AgentController:
except Exception as e:
traceback.print_exc()
self.log('error', f'Error while running the agent: {e}')
self.log('error', traceback.format_exc())
await self.report_error(
'There was an unexpected error while running the agent', exception=e
)
await self.set_agent_state_to(AgentState.ERROR)
break
await self._react_to_exception(e)
await asyncio.sleep(0.1)
@@ -227,15 +216,6 @@ class AgentController:
Args:
observation (observation): The observation to handle.
"""
if (
self._pending_action
and hasattr(self._pending_action, 'confirmation_state')
and self._pending_action.confirmation_state
== ActionConfirmationStatus.AWAITING_CONFIRMATION
):
return
# Make sure we print the observation in the same way as the LLM sees it
observation_to_print = copy.deepcopy(observation)
if len(observation_to_print.content) > self.agent.llm.config.max_message_chars:
observation_to_print.content = truncate_content(
@@ -243,7 +223,6 @@ class AgentController:
)
self.log('debug', str(observation_to_print), extra={'msg_type': 'OBSERVATION'})
# Merge with the metrics from the LLM - it will to synced to the controller's local metrics in update_state_after_step()
if observation.llm_metrics is not None:
self.agent.llm.metrics.merge(observation.llm_metrics)
@@ -255,19 +234,11 @@ class AgentController:
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
return
if isinstance(observation, CmdOutputObservation):
return
elif isinstance(observation, AgentDelegateObservation):
if isinstance(observation, AgentDelegateObservation):
self.state.history.on_event(observation)
elif isinstance(observation, ErrorObservation):
if self.state.agent_state == AgentState.ERROR:
self.state.metrics.merge(self.state.local_metrics)
elif isinstance(observation, FatalErrorObservation):
self.state.last_error = (
f'There was a fatal error during agent execution: {str(observation)}'
)
self.state.metrics.merge(self.state.local_metrics)
await self.set_agent_state_to(AgentState.ERROR)
async def _handle_message_action(self, action: MessageAction):
"""Handles message actions from the event stream.
@@ -420,13 +391,8 @@ class AgentController:
await asyncio.sleep(1)
return
# check if agent got stuck before taking any action
if self._is_stuck():
# This need to go BEFORE report_error to sync metrics
self.event_stream.add_event(
FatalErrorObservation('Agent got stuck in a loop'),
EventSource.ENVIRONMENT,
)
await self._react_to_exception(RuntimeError('Agent got stuck in a loop'))
return
if self.delegate is not None:
@@ -465,15 +431,12 @@ class AgentController:
if action is None:
raise LLMNoActionError('No action was returned')
except (LLMMalformedActionError, LLMNoActionError, LLMResponseError) as e:
# report to the user
# and send the underlying exception to the LLM for self-correction
await self.report_error(str(e))
return
# FIXME: more graceful handling of litellm.exceptions.ContextWindowExceededError
# e.g. try to condense the memory and try again
except litellm.exceptions.ContextWindowExceededError as e:
self.state.last_error = str(e)
await self.set_agent_state_to(AgentState.ERROR)
self.event_stream.add_event(
ErrorObservation(
content=str(e),
),
EventSource.AGENT,
)
return
if action.runnable:
@@ -495,6 +458,7 @@ class AgentController:
self.event_stream.add_event(action, EventSource.AGENT)
await self.update_state_after_step()
self.log('debug', str(action), extra={'msg_type': 'ACTION'})
async def _delegate_step(self):
@@ -524,7 +488,10 @@ class AgentController:
self.delegate = None
self.delegateAction = None
await self.report_error('Delegator agent encountered an error')
self.event_stream.add_event(
ErrorObservation('Delegate agent encountered an error'),
EventSource.AGENT,
)
elif delegate_state in (AgentState.FINISHED, AgentState.REJECTED):
self.log('debug', 'Delegate agent has finished execution')
# retrieve delegate result
@@ -571,21 +538,18 @@ class AgentController:
else:
self.state.traffic_control_state = TrafficControlState.THROTTLING
if self.headless_mode:
# This need to go BEFORE report_error to sync metrics
await self.set_agent_state_to(AgentState.ERROR)
# set to ERROR state if running in headless mode
# since user cannot resume on the web interface
await self.report_error(
f'Agent reached maximum {limit_type} in headless mode, task stopped. '
e = RuntimeError(
f'Agent reached maximum {limit_type} in headless mode. '
f'Current {limit_type}: {current_value:.2f}, max {limit_type}: {max_value:.2f}'
)
await self._react_to_exception(e)
else:
await self.set_agent_state_to(AgentState.PAUSED)
await self.report_error(
f'Agent reached maximum {limit_type}, task paused. '
e = RuntimeError(
f'Agent reached maximum {limit_type}. '
f'Current {limit_type}: {current_value:.2f}, max {limit_type}: {max_value:.2f}. '
f'{TRAFFIC_CONTROL_REMINDER}'
)
# FIXME: this isn't really an exception--we should have a different path
await self._react_to_exception(e)
stop_step = True
return stop_step

View File

@@ -11,6 +11,7 @@ from openhands.events.action import (
MessageAction,
)
from openhands.events.action.agent import AgentFinishAction
from openhands.events.observation import ErrorObservation
from openhands.llm.metrics import Metrics
from openhands.memory.history import ShortTermHistory
from openhands.storage.files import FileStore
@@ -80,7 +81,6 @@ class State:
history: ShortTermHistory = field(default_factory=ShortTermHistory)
inputs: dict = field(default_factory=dict)
outputs: dict = field(default_factory=dict)
last_error: str | None = None
agent_state: AgentState = AgentState.LOADING
resume_state: AgentState | None = None
traffic_control_state: TrafficControlState = TrafficControlState.NORMAL
@@ -97,6 +97,7 @@ class State:
# NOTE: This will never be used by the controller, but it can be used by different
# evaluation tasks to store extra data needed to track the progress/state of the task.
extra_data: dict[str, Any] = field(default_factory=dict)
last_error: str = ''
def save_to_session(self, sid: str, file_store: FileStore):
pickled = pickle.dumps(self)
@@ -124,9 +125,6 @@ class State:
else:
state.resume_state = None
# don't carry last_error anymore after restore
state.last_error = None
# first state after restore
state.agent_state = AgentState.LOADING
return state
@@ -151,11 +149,9 @@ class State:
if not hasattr(self, 'history'):
self.history = ShortTermHistory()
# restore the relevant data in history from the state
self.history.start_id = self.start_id
self.history.end_id = self.end_id
# remove the restored data from the state if any
def get_current_user_intent(self):
"""Returns the latest user message and image(if provided) that appears after a FinishAction, or the first (the task) if nothing was finished yet."""

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
import sys
from typing import Type
from termcolor import colored
@@ -13,6 +14,7 @@ from openhands.core.config import (
load_app_config,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.loop import run_agent_until_done
from openhands.core.schema import AgentState
from openhands.events import EventSource, EventStream, EventStreamSubscriber
from openhands.events.action import (
@@ -114,7 +116,6 @@ async def main():
sid=sid,
plugins=agent_cls.sandbox_plugins,
)
await runtime.connect()
controller = AgentController(
agent=agent,
@@ -124,11 +125,14 @@ async def main():
event_stream=event_stream,
)
if controller is not None:
controller.agent_task = asyncio.create_task(controller.start_step_loop())
async def prompt_for_next_task():
next_message = input('How can I help? >> ')
# Run input() in a thread pool to avoid blocking the event loop
loop = asyncio.get_event_loop()
next_message = await loop.run_in_executor(
None, lambda: input('How can I help? >> ')
)
if not next_message.strip():
await prompt_for_next_task()
if next_message == 'exit':
event_stream.add_event(
ChangeAgentStateAction(AgentState.STOPPED), EventSource.ENVIRONMENT
@@ -140,31 +144,45 @@ async def main():
async def on_event(event: Event):
display_event(event)
if isinstance(event, AgentStateChangedObservation):
if event.agent_state == AgentState.ERROR:
print('An error occurred. Please try again.')
if event.agent_state in [
AgentState.AWAITING_USER_INPUT,
AgentState.FINISHED,
AgentState.ERROR,
]:
await prompt_for_next_task()
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event)
await prompt_for_next_task()
await runtime.connect()
while controller.state.agent_state not in [
AgentState.STOPPED,
]:
await asyncio.sleep(1) # Give back control for a tick, so the agent can run
asyncio.create_task(prompt_for_next_task())
print('Exiting...')
await controller.close()
await run_agent_until_done(
controller, runtime, [AgentState.STOPPED, AgentState.ERROR]
)
if __name__ == '__main__':
loop = asyncio.get_event_loop()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main())
except KeyboardInterrupt:
print('Received keyboard interrupt, shutting down...')
except ConnectionRefusedError as e:
print(f'Connection refused: {e}')
sys.exit(1)
except Exception as e:
print(f'An error occurred: {e}')
sys.exit(1)
finally:
pass
try:
# Cancel all running tasks
pending = asyncio.all_tasks(loop)
for task in pending:
task.cancel()
# Wait for all tasks to complete with a timeout
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
loop.close()
except Exception as e:
print(f'Error during cleanup: {e}')
sys.exit(1)

50
openhands/core/loop.py Normal file
View File

@@ -0,0 +1,50 @@
import asyncio
from openhands.controller import AgentController
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema import AgentState
from openhands.runtime.base import Runtime
async def run_agent_until_done(
controller: AgentController,
runtime: Runtime,
end_states: list[AgentState],
):
"""
run_agent_until_done takes a controller and a runtime, and will run
the agent until it reaches a terminal state.
Note that runtime must be connected before being passed in here.
"""
controller.agent_task = asyncio.create_task(controller.start_step_loop())
def status_callback(msg_type, msg_id, msg):
if msg_type == 'error':
logger.error(msg)
if controller:
controller.state.last_error = msg
asyncio.create_task(controller.set_agent_state_to(AgentState.ERROR))
else:
logger.info(msg)
if hasattr(runtime, 'status_callback') and runtime.status_callback:
raise ValueError(
'Runtime status_callback was set, but run_agent_until_done will override it'
)
if hasattr(controller, 'status_callback') and controller.status_callback:
raise ValueError(
'Controller status_callback was set, but run_agent_until_done will override it'
)
runtime.status_callback = status_callback
controller.status_callback = status_callback
while controller.state.agent_state not in end_states:
await asyncio.sleep(1)
if not controller.agent_task.done():
controller.agent_task.cancel()
try:
await controller.agent_task
except asyncio.CancelledError:
pass

View File

@@ -17,6 +17,7 @@ from openhands.core.config import (
parse_arguments,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.loop import run_agent_until_done
from openhands.core.schema import AgentState
from openhands.events import EventSource, EventStream, EventStreamSubscriber
from openhands.events.action import MessageAction
@@ -122,7 +123,6 @@ async def run_controller(
if runtime is None:
runtime = create_runtime(config, sid=sid)
await runtime.connect()
event_stream = runtime.event_stream
# restore cli session if enabled
@@ -147,9 +147,6 @@ async def run_controller(
headless_mode=headless_mode,
)
if controller is not None:
controller.agent_task = asyncio.create_task(controller.start_step_loop())
assert isinstance(
initial_user_action, Action
), f'initial user actions must be an Action, got {type(initial_user_action)}'
@@ -188,22 +185,27 @@ async def run_controller(
event_stream.add_event(action, EventSource.USER)
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event)
while controller.state.agent_state not in [
await runtime.connect()
end_states = [
AgentState.FINISHED,
AgentState.REJECTED,
AgentState.ERROR,
AgentState.PAUSED,
AgentState.STOPPED,
]:
await asyncio.sleep(1) # Give back control for a tick, so the agent can run
]
try:
await run_agent_until_done(controller, runtime, end_states)
except Exception as e:
logger.error(f'Exception in main loop: {e}')
# save session when we're about to close
if config.enable_cli_session:
end_state = controller.get_state()
end_state.save_to_session(event_stream.sid, event_stream.file_store)
# close when done
await controller.close()
state = controller.get_state()
# save trajectories if applicable

View File

@@ -6,7 +6,7 @@ from openhands.events.observation.commands import (
)
from openhands.events.observation.delegate import AgentDelegateObservation
from openhands.events.observation.empty import NullObservation
from openhands.events.observation.error import ErrorObservation, FatalErrorObservation
from openhands.events.observation.error import ErrorObservation
from openhands.events.observation.files import (
FileEditObservation,
FileReadObservation,
@@ -26,7 +26,6 @@ __all__ = [
'FileWriteObservation',
'FileEditObservation',
'ErrorObservation',
'FatalErrorObservation',
'AgentStateChangedObservation',
'AgentDelegateObservation',
'SuccessObservation',

View File

@@ -13,6 +13,7 @@ class ErrorObservation(Observation):
"""
observation: str = ObservationType.ERROR
error_id: str = ''
@property
def message(self) -> str:
@@ -20,17 +21,3 @@ class ErrorObservation(Observation):
def __str__(self) -> str:
return f'**ErrorObservation**\n{self.content}'
@dataclass
class FatalErrorObservation(Observation):
"""This data class represents a fatal error encountered by the agent.
This is the type of error that LLM CANNOT recover from, and the agent controller should stop the execution and report the error to the user.
E.g., Remote runtime action execution failure: 503 Server Error: Service Unavailable for url OR 404 Not Found.
"""
observation: str = ObservationType.ERROR
def __str__(self) -> str:
return f'**FatalErrorObservation**\n{self.content}'

View File

@@ -152,12 +152,16 @@ class EventStream:
def add_event(self, event: Event, source: EventSource):
try:
asyncio.get_running_loop().create_task(self.async_add_event(event, source))
asyncio.get_running_loop().create_task(self._async_add_event(event, source))
except RuntimeError:
# No event loop running...
asyncio.run(self.async_add_event(event, source))
asyncio.run(self._async_add_event(event, source))
async def async_add_event(self, event: Event, source: EventSource):
async def _async_add_event(self, event: Event, source: EventSource):
if hasattr(event, '_id') and event.id is not None:
raise ValueError(
'Event already has an ID. It was probably added back to the EventStream from inside a handler, trigging a loop.'
)
with self._lock:
event._id = self._cur_id # type: ignore [attr-defined]
self._cur_id += 1

View File

@@ -12,7 +12,6 @@ from openhands.events.event import Event, EventSource
from openhands.events.observation.agent import AgentStateChangedObservation
from openhands.events.observation.delegate import AgentDelegateObservation
from openhands.events.observation.empty import NullObservation
from openhands.events.observation.error import FatalErrorObservation
from openhands.events.observation.observation import Observation
from openhands.events.serialization.event import event_to_dict
from openhands.events.stream import EventStream
@@ -34,7 +33,6 @@ class ShortTermHistory(list[Event]):
NullObservation,
ChangeAgentStateAction,
AgentStateChangedObservation,
FatalErrorObservation,
)
def __init__(self):

View File

@@ -37,7 +37,6 @@ from openhands.events.action import (
from openhands.events.observation import (
CmdOutputObservation,
ErrorObservation,
FatalErrorObservation,
FileReadObservation,
FileWriteObservation,
IPythonRunCellObservation,
@@ -168,7 +167,7 @@ class ActionExecutor:
async def run(
self, action: CmdRunAction
) -> CmdOutputObservation | FatalErrorObservation:
) -> CmdOutputObservation | ErrorObservation:
return self.bash_session.run(action)
async def run_ipython(self, action: IPythonRunCellAction) -> Observation:

View File

@@ -5,6 +5,8 @@ import os
from abc import abstractmethod
from typing import Callable
from requests.exceptions import ConnectionError
from openhands.core.config import AppConfig, SandboxConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventSource, EventStream, EventStreamSubscriber
@@ -31,6 +33,22 @@ from openhands.runtime.plugins import JupyterRequirement, PluginRequirement
from openhands.runtime.utils.edit import FileEditRuntimeMixin
from openhands.utils.async_utils import call_sync_from_async
STATUS_MESSAGES = {
'STATUS$STARTING_RUNTIME': 'Starting runtime...',
'STATUS$STARTING_CONTAINER': 'Starting container...',
'STATUS$PREPARING_CONTAINER': 'Preparing container...',
'STATUS$CONTAINER_STARTED': 'Container started.',
'STATUS$WAITING_FOR_CLIENT': 'Waiting for client...',
}
class RuntimeNotReadyError(Exception):
pass
class RuntimeDisconnectedError(Exception):
pass
def _default_env_vars(sandbox_config: SandboxConfig) -> dict[str, str]:
ret = {}
@@ -54,6 +72,7 @@ class Runtime(FileEditRuntimeMixin):
config: AppConfig
initial_env_vars: dict[str, str]
attach_to_existing: bool
status_callback: Callable | None
def __init__(
self,
@@ -62,14 +81,14 @@ class Runtime(FileEditRuntimeMixin):
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
env_vars: dict[str, str] | None = None,
status_message_callback: Callable | None = None,
status_callback: Callable | None = None,
attach_to_existing: bool = False,
):
self.sid = sid
self.event_stream = event_stream
self.event_stream.subscribe(EventStreamSubscriber.RUNTIME, self.on_event)
self.plugins = plugins if plugins is not None and len(plugins) > 0 else []
self.status_message_callback = status_message_callback
self.status_callback = status_callback
self.attach_to_existing = attach_to_existing
self.config = copy.deepcopy(config)
@@ -95,7 +114,17 @@ class Runtime(FileEditRuntimeMixin):
def log(self, level: str, message: str) -> None:
message = f'[runtime {self.sid}] {message}'
getattr(logger, level)(message)
getattr(logger, level)(message, stacklevel=2)
def send_status_message(self, message_id: str):
"""Sends a status message if the callback function was provided."""
if self.status_callback:
msg = STATUS_MESSAGES.get(message_id, '')
self.status_callback('info', message_id, msg)
def send_error_message(self, message_id: str, message: str):
if self.status_callback:
self.status_callback('error', message_id, message)
# ====================================================================
@@ -131,15 +160,28 @@ class Runtime(FileEditRuntimeMixin):
if event.timeout is None:
event.timeout = self.config.sandbox.timeout
assert event.timeout is not None
observation: Observation = await call_sync_from_async(
self.run_action, event
)
try:
observation: Observation = await call_sync_from_async(
self.run_action, event
)
except Exception as e:
err_id = ''
if isinstance(e, ConnectionError) or isinstance(
e, RuntimeDisconnectedError
):
err_id = 'STATUS$ERROR_RUNTIME_DISCONNECTED'
self.log('error', f'Unexpected error while running action {e}')
self.log('error', f'Problematic action: {str(event)}')
self.send_error_message(err_id, str(e))
self.close()
return
observation._cause = event.id # type: ignore[attr-defined]
observation.tool_call_metadata = event.tool_call_metadata
# this might be unnecessary, since source should be set by the event stream when we're here
source = event.source if event.source else EventSource.AGENT
await self.event_stream.async_add_event(observation, source) # type: ignore[arg-type]
self.event_stream.add_event(observation, source) # type: ignore[arg-type]
def run_action(self, action: Action) -> Observation:
"""Run an action and return the resulting observation.

View File

@@ -7,7 +7,7 @@ import requests
from openhands.core.logger import openhands_logger as logger
from openhands.runtime.builder import RuntimeBuilder
from openhands.runtime.utils.request import is_429_error, send_request_with_retry
from openhands.runtime.utils.request import send_request
from openhands.runtime.utils.shutdown_listener import (
should_continue,
sleep_if_should_continue,
@@ -45,18 +45,21 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
files.append(('tags', (None, tag)))
# Send the POST request to /build (Begins the build process)
response = send_request_with_retry(
self.session,
'POST',
f'{self.api_url}/build',
files=files,
timeout=30,
retry_fns=[is_429_error],
)
if response.status_code != 202:
logger.error(f'Build initiation failed: {response.text}')
raise RuntimeError(f'Build initiation failed: {response.text}')
try:
response = send_request(
self.session,
'POST',
f'{self.api_url}/build',
files=files,
timeout=30,
)
except requests.exceptions.HTTPError as e:
if e.response.status_code == 429:
logger.warning('Build was rate limited. Retrying in 30 seconds.')
time.sleep(30)
return self.build(path, tags, platform)
else:
raise e
build_data = response.json()
build_id = build_data['build_id']
@@ -70,12 +73,11 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
logger.error('Build timed out after 30 minutes')
raise RuntimeError('Build timed out after 30 minutes')
status_response = send_request_with_retry(
status_response = send_request(
self.session,
'GET',
f'{self.api_url}/build_status',
params={'build_id': build_id},
timeout=30,
)
if status_response.status_code != 200:
@@ -112,12 +114,11 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
def image_exists(self, image_name: str, pull_from_repo: bool = True) -> bool:
"""Checks if an image exists in the remote registry using the /image_exists endpoint."""
params = {'image': image_name}
response = send_request_with_retry(
response = send_request(
self.session,
'GET',
f'{self.api_url}/image_exists',
params=params,
timeout=30,
)
if response.status_code != 200:

View File

@@ -27,14 +27,14 @@ class E2BRuntime(Runtime):
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
sandbox: E2BSandbox | None = None,
status_message_callback: Optional[Callable] = None,
status_callback: Optional[Callable] = None,
):
super().__init__(
config,
event_stream,
sid,
plugins,
status_message_callback=status_message_callback,
status_callback=status_callback,
)
if sandbox is None:
self.sandbox = E2BSandbox()

View File

@@ -25,7 +25,7 @@ from openhands.events.action import (
)
from openhands.events.action.action import Action
from openhands.events.observation import (
FatalErrorObservation,
ErrorObservation,
NullObservation,
Observation,
UserRejectObservation,
@@ -36,8 +36,9 @@ from openhands.runtime.base import Runtime
from openhands.runtime.builder import DockerRuntimeBuilder
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.utils import find_available_tcp_port
from openhands.runtime.utils.request import send_request_with_retry
from openhands.runtime.utils.request import send_request
from openhands.runtime.utils.runtime_build import build_runtime_image
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.tenacity_stop import stop_if_should_exit
@@ -123,7 +124,7 @@ class EventStreamRuntime(Runtime):
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
env_vars: dict[str, str] | None = None,
status_message_callback: Callable | None = None,
status_callback: Callable | None = None,
attach_to_existing: bool = False,
):
super().__init__(
@@ -132,7 +133,7 @@ class EventStreamRuntime(Runtime):
sid,
plugins,
env_vars,
status_message_callback,
status_callback,
attach_to_existing,
)
@@ -143,7 +144,7 @@ class EventStreamRuntime(Runtime):
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
env_vars: dict[str, str] | None = None,
status_message_callback: Callable | None = None,
status_callback: Callable | None = None,
attach_to_existing: bool = False,
):
self.config = config
@@ -151,7 +152,7 @@ class EventStreamRuntime(Runtime):
self._container_port = 30001 # initial dummy value
self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}'
self.session = requests.Session()
self.status_message_callback = status_message_callback
self.status_callback = status_callback
self.docker_client: docker.DockerClient = self._init_docker_client()
self.base_container_image = self.config.sandbox.base_container_image
@@ -181,7 +182,7 @@ class EventStreamRuntime(Runtime):
sid,
plugins,
env_vars,
status_message_callback,
status_callback,
attach_to_existing,
)
@@ -205,21 +206,21 @@ class EventStreamRuntime(Runtime):
self.log(
'info', f'Starting runtime with image: {self.runtime_container_image}'
)
self._init_container()
await call_sync_from_async(self._init_container)
self.log('info', f'Container started: {self.container_name}')
else:
self._attach_to_container()
await call_sync_from_async(self._attach_to_container)
if not self.attach_to_existing:
self.log('info', f'Waiting for client to become ready at {self.api_url}...')
self.send_status_message('STATUS$WAITING_FOR_CLIENT')
self._wait_until_alive()
await call_sync_from_async(self._wait_until_alive)
if not self.attach_to_existing:
self.log('info', 'Runtime is ready.')
if not self.attach_to_existing:
self.setup_initial_env()
await call_sync_from_async(self.setup_initial_env)
self.log(
'debug',
@@ -238,82 +239,74 @@ class EventStreamRuntime(Runtime):
)
raise ex
@tenacity.retry(
stop=tenacity.stop_after_attempt(5) | stop_if_should_exit(),
wait=tenacity.wait_fixed(5),
)
def _init_container(self):
try:
self.log('debug', 'Preparing to start container...')
self.send_status_message('STATUS$PREPARING_CONTAINER')
plugin_arg = ''
if self.plugins is not None and len(self.plugins) > 0:
plugin_arg = (
f'--plugins {" ".join([plugin.name for plugin in self.plugins])} '
)
self._host_port = self._find_available_port()
self._container_port = (
self._host_port
) # in future this might differ from host port
self.api_url = (
f'{self.config.sandbox.local_runtime_url}:{self._container_port}'
self.log('debug', 'Preparing to start container...')
self.send_status_message('STATUS$PREPARING_CONTAINER')
plugin_arg = ''
if self.plugins is not None and len(self.plugins) > 0:
plugin_arg = (
f'--plugins {" ".join([plugin.name for plugin in self.plugins])} '
)
use_host_network = self.config.sandbox.use_host_network
network_mode: str | None = 'host' if use_host_network else None
port_mapping: dict[str, list[dict[str, str]]] | None = (
None
if use_host_network
else {
f'{self._container_port}/tcp': [{'HostPort': str(self._host_port)}]
}
)
self._host_port = self._find_available_port()
self._container_port = (
self._host_port
) # in future this might differ from host port
self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}'
if use_host_network:
self.log(
'warn',
'Using host network mode. If you are using MacOS, please make sure you have the latest version of Docker Desktop and enabled host network feature: https://docs.docker.com/network/drivers/host/#docker-desktop',
)
use_host_network = self.config.sandbox.use_host_network
network_mode: str | None = 'host' if use_host_network else None
port_mapping: dict[str, list[dict[str, str]]] | None = (
None
if use_host_network
else {f'{self._container_port}/tcp': [{'HostPort': str(self._host_port)}]}
)
# Combine environment variables
environment = {
'port': str(self._container_port),
'PYTHONUNBUFFERED': 1,
}
if self.config.debug or DEBUG:
environment['DEBUG'] = 'true'
self.log('debug', f'Workspace Base: {self.config.workspace_base}')
if (
self.config.workspace_mount_path is not None
and self.config.workspace_mount_path_in_sandbox is not None
):
# e.g. result would be: {"/home/user/openhands/workspace": {'bind': "/workspace", 'mode': 'rw'}}
volumes = {
self.config.workspace_mount_path: {
'bind': self.config.workspace_mount_path_in_sandbox,
'mode': 'rw',
}
}
logger.debug(f'Mount dir: {self.config.workspace_mount_path}')
else:
logger.debug(
'Mount dir is not set, will not mount the workspace directory to the container'
)
volumes = None
if use_host_network:
self.log(
'debug',
f'Sandbox workspace: {self.config.workspace_mount_path_in_sandbox}',
'warn',
'Using host network mode. If you are using MacOS, please make sure you have the latest version of Docker Desktop and enabled host network feature: https://docs.docker.com/network/drivers/host/#docker-desktop',
)
if self.config.sandbox.browsergym_eval_env is not None:
browsergym_arg = (
f'--browsergym-eval-env {self.config.sandbox.browsergym_eval_env}'
)
else:
browsergym_arg = ''
# Combine environment variables
environment = {
'port': str(self._container_port),
'PYTHONUNBUFFERED': 1,
}
if self.config.debug or DEBUG:
environment['DEBUG'] = 'true'
self.log('debug', f'Workspace Base: {self.config.workspace_base}')
if (
self.config.workspace_mount_path is not None
and self.config.workspace_mount_path_in_sandbox is not None
):
# e.g. result would be: {"/home/user/openhands/workspace": {'bind': "/workspace", 'mode': 'rw'}}
volumes = {
self.config.workspace_mount_path: {
'bind': self.config.workspace_mount_path_in_sandbox,
'mode': 'rw',
}
}
logger.debug(f'Mount dir: {self.config.workspace_mount_path}')
else:
logger.debug(
'Mount dir is not set, will not mount the workspace directory to the container'
)
volumes = None
self.log(
'debug',
f'Sandbox workspace: {self.config.workspace_mount_path_in_sandbox}',
)
if self.config.sandbox.browsergym_eval_env is not None:
browsergym_arg = (
f'--browsergym-eval-env {self.config.sandbox.browsergym_eval_env}'
)
else:
browsergym_arg = ''
try:
self.container = self.docker_client.containers.run(
self.runtime_container_image,
command=(
@@ -337,6 +330,21 @@ class EventStreamRuntime(Runtime):
self.log_buffer = LogBuffer(self.container, self.log)
self.log('debug', f'Container started. Server url: {self.api_url}')
self.send_status_message('STATUS$CONTAINER_STARTED')
except docker.errors.APIError as e:
# check 409 error
if '409' in str(e):
self.log(
'warning',
f'Container {self.container_name} already exists. Removing...',
)
self._close_containers(rm_all_containers=True)
return self._init_container()
else:
self.log(
'error',
f'Error: Instance {self.container_name} FAILED to start container!\n',
)
except Exception as e:
self.log(
'error',
@@ -384,27 +392,20 @@ class EventStreamRuntime(Runtime):
@tenacity.retry(
stop=tenacity.stop_after_delay(120) | stop_if_should_exit(),
wait=tenacity.wait_exponential(multiplier=2, min=1, max=20),
reraise=(ConnectionRefusedError,),
wait=tenacity.wait_fixed(2),
)
def _wait_until_alive(self):
self._refresh_logs()
if not self.log_buffer:
raise RuntimeError('Runtime client is not ready.')
response = send_request_with_retry(
send_request(
self.session,
'GET',
f'{self.api_url}/alive',
retry_exceptions=[ConnectionRefusedError],
timeout=300, # 5 minutes gives the container time to be alive 🧟‍♂️
timeout=5,
)
if response.status_code == 200:
return
else:
msg = f'Action execution API is not alive. Response: {response}'
self.log('error', msg)
raise RuntimeError(msg)
def close(self, rm_all_containers: bool = True):
"""Closes the EventStreamRuntime and associated objects
@@ -421,7 +422,9 @@ class EventStreamRuntime(Runtime):
if self.attach_to_existing:
return
self._close_containers(rm_all_containers)
def _close_containers(self, rm_all_containers: bool = True):
try:
containers = self.docker_client.containers.list(all=True)
for container in containers:
@@ -466,10 +469,11 @@ class EventStreamRuntime(Runtime):
return NullObservation('')
action_type = action.action # type: ignore[attr-defined]
if action_type not in ACTION_TYPE_TO_CLASS:
return FatalErrorObservation(f'Action {action_type} does not exist.')
raise ValueError(f'Action {action_type} does not exist.')
if not hasattr(self, action_type):
return FatalErrorObservation(
f'Action {action_type} is not supported in the current runtime.'
return ErrorObservation(
f'Action {action_type} is not supported in the current runtime.',
error_id='AGENT_ERROR$BAD_ACTION',
)
if (
getattr(action, 'confirmation_state', None)
@@ -484,33 +488,21 @@ class EventStreamRuntime(Runtime):
assert action.timeout is not None
try:
response = send_request_with_retry(
response = send_request(
self.session,
'POST',
f'{self.api_url}/execute_action',
json={'action': event_to_dict(action)},
timeout=action.timeout,
# wait a few more seconds to get the timeout error from client side
timeout=action.timeout + 5,
)
if response.status_code == 200:
output = response.json()
obs = observation_from_dict(output)
obs._cause = action.id # type: ignore[attr-defined]
else:
self.log('debug', f'action: {action}')
self.log('debug', f'response: {response}')
error_message = response.text
self.log('error', f'Error from server: {error_message}')
obs = FatalErrorObservation(
f'Action execution failed: {error_message}'
)
output = response.json()
obs = observation_from_dict(output)
obs._cause = action.id # type: ignore[attr-defined]
except requests.Timeout:
self.log('error', 'No response received within the timeout period.')
obs = FatalErrorObservation(
f'Action execution timed out after {action.timeout} seconds.'
raise RuntimeError(
f'Runtime failed to return execute_action before the requested timeout of {action.timeout}s'
)
except Exception as e:
self.log('error', f'Error during action execution: {e}')
obs = FatalErrorObservation(f'Action execution failed: {str(e)}')
self._refresh_logs()
return obs
@@ -567,7 +559,7 @@ class EventStreamRuntime(Runtime):
params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
response = send_request_with_retry(
send_request(
self.session,
'POST',
f'{self.api_url}/upload_file',
@@ -575,11 +567,6 @@ class EventStreamRuntime(Runtime):
params=params,
timeout=300,
)
if response.status_code == 200:
return
else:
error_message = response.text
raise Exception(f'Copy operation failed: {error_message}')
except requests.Timeout:
raise TimeoutError('Copy operation timed out')
@@ -604,31 +591,25 @@ class EventStreamRuntime(Runtime):
if path is not None:
data['path'] = path
response = send_request_with_retry(
response = send_request(
self.session,
'POST',
f'{self.api_url}/list_files',
json=data,
timeout=30, # 30 seconds because the container should already be alive
timeout=10,
)
if response.status_code == 200:
response_json = response.json()
assert isinstance(response_json, list)
return response_json
else:
error_message = response.text
raise Exception(f'List files operation failed: {error_message}')
response_json = response.json()
assert isinstance(response_json, list)
return response_json
except requests.Timeout:
raise TimeoutError('List files operation timed out')
except Exception as e:
raise RuntimeError(f'List files operation failed: {str(e)}')
def copy_from(self, path: str) -> bytes:
"""Zip all files in the sandbox and return as a stream of bytes."""
self._refresh_logs()
try:
params = {'path': path}
response = send_request_with_retry(
response = send_request(
self.session,
'GET',
f'{self.api_url}/download_files',
@@ -636,16 +617,10 @@ class EventStreamRuntime(Runtime):
stream=True,
timeout=30,
)
if response.status_code == 200:
data = response.content
return data
else:
error_message = response.text
raise Exception(f'Copy operation failed: {error_message}')
data = response.content
return data
except requests.Timeout:
raise TimeoutError('Copy operation timed out')
except Exception as e:
raise RuntimeError(f'Copy operation failed: {str(e)}')
def _is_port_in_use_docker(self, port):
containers = self.docker_client.containers.list()
@@ -663,8 +638,3 @@ class EventStreamRuntime(Runtime):
return port
# If no port is found after max_attempts, return the last tried port
return port
def send_status_message(self, message: str):
"""Sends a status message if the callback function was provided."""
if self.status_message_callback:
self.status_message_callback(message)

View File

@@ -75,7 +75,7 @@ class ModalRuntime(EventStreamRuntime):
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
env_vars: dict[str, str] | None = None,
status_message_callback: Callable | None = None,
status_callback: Callable | None = None,
attach_to_existing: bool = False,
):
assert config.modal_api_token_id, 'Modal API token id is required'
@@ -102,7 +102,7 @@ class ModalRuntime(EventStreamRuntime):
self.container_port = 3000
self.session = requests.Session()
self.status_message_callback = status_message_callback
self.status_callback = status_callback
self.base_container_image_id = self.config.sandbox.base_container_image
self.runtime_container_image_id = self.config.sandbox.runtime_container_image
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
@@ -122,7 +122,7 @@ class ModalRuntime(EventStreamRuntime):
sid,
plugins,
env_vars,
status_message_callback,
status_callback,
attach_to_existing,
)

View File

@@ -1,12 +1,11 @@
import os
import tempfile
import threading
import time
from typing import Callable, Optional
from zipfile import ZipFile
import requests
from requests.exceptions import Timeout
import tenacity
from openhands.core.config import AppConfig
from openhands.events import EventStream
@@ -21,22 +20,26 @@ from openhands.events.action import (
)
from openhands.events.action.action import Action
from openhands.events.observation import (
FatalErrorObservation,
ErrorObservation,
NullObservation,
Observation,
)
from openhands.events.serialization import event_to_dict, observation_from_dict
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
from openhands.runtime.base import Runtime
from openhands.runtime.base import (
Runtime,
RuntimeDisconnectedError,
RuntimeNotReadyError,
)
from openhands.runtime.builder.remote import RemoteRuntimeBuilder
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.utils.command import get_remote_startup_command
from openhands.runtime.utils.request import (
is_404_error,
is_503_error,
send_request_with_retry,
send_request,
)
from openhands.runtime.utils.runtime_build import build_runtime_image
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.tenacity_stop import stop_if_should_exit
class RemoteRuntime(Runtime):
@@ -51,31 +54,32 @@ class RemoteRuntime(Runtime):
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
env_vars: dict[str, str] | None = None,
status_message_callback: Optional[Callable] = None,
status_callback: Optional[Callable] = None,
attach_to_existing: bool = False,
):
# We need to set session and action_semaphore before the __init__ below, or we get odd errors
self.session = requests.Session()
self.action_semaphore = threading.Semaphore(1)
super().__init__(
config,
event_stream,
sid,
plugins,
env_vars,
status_message_callback,
status_callback,
attach_to_existing,
)
if self.config.sandbox.api_key is None:
raise ValueError(
'API key is required to use the remote runtime. '
'Please set the API key in the config (config.toml) or as an environment variable (SANDBOX_API_KEY).'
)
self.session = requests.Session()
self.session.headers.update({'X-API-Key': self.config.sandbox.api_key})
self.action_semaphore = threading.Semaphore(1)
if self.config.workspace_base is not None:
self.log(
'warning',
'debug',
'Setting workspace_base is not supported in the remote runtime.',
)
@@ -86,9 +90,13 @@ class RemoteRuntime(Runtime):
self.runtime_url: str | None = None
async def connect(self):
self._start_or_attach_to_runtime()
self._wait_until_alive()
self.setup_initial_env()
await call_sync_from_async(self._start_or_attach_to_runtime)
try:
await call_sync_from_async(self._wait_until_alive)
except RuntimeNotReadyError:
self.log('error', 'Runtime failed to start, timed out before ready')
raise
await call_sync_from_async(self.setup_initial_env)
def _start_or_attach_to_runtime(self):
existing_runtime = self._check_existing_runtime()
@@ -127,44 +135,40 @@ class RemoteRuntime(Runtime):
def _check_existing_runtime(self) -> bool:
try:
response = send_request_with_retry(
self.session,
response = self._send_request(
'GET',
f'{self.config.sandbox.remote_runtime_api_url}/runtime/{self.sid}',
timeout=5,
)
except Exception as e:
except requests.HTTPError as e:
if e.response.status_code == 404:
return False
self.log('debug', f'Error while looking for remote runtime: {e}')
return False
raise
if response.status_code == 200:
data = response.json()
status = data.get('status')
if status == 'running':
self._parse_runtime_response(response)
return True
elif status == 'stopped':
self.log('debug', 'Found existing remote runtime, but it is stopped')
return False
elif status == 'paused':
self.log('debug', 'Found existing remote runtime, but it is paused')
self._parse_runtime_response(response)
self._resume_runtime()
return True
else:
self.log('error', f'Invalid response from runtime API: {data}')
return False
data = response.json()
status = data.get('status')
if status == 'running':
self._parse_runtime_response(response)
return True
elif status == 'stopped':
self.log('debug', 'Found existing remote runtime, but it is stopped')
return False
elif status == 'paused':
self.log('debug', 'Found existing remote runtime, but it is paused')
self._parse_runtime_response(response)
self._resume_runtime()
return True
else:
self.log('debug', 'Could not find existing remote runtime')
self.log('error', f'Invalid response from runtime API: {data}')
return False
def _build_runtime(self):
self.log('debug', f'Building RemoteRuntime config:\n{self.config}')
response = send_request_with_retry(
self.session,
response = self._send_request(
'GET',
f'{self.config.sandbox.remote_runtime_api_url}/registry_prefix',
timeout=30,
timeout=10,
)
response_json = response.json()
registry_prefix = response_json['registry_prefix']
@@ -191,14 +195,13 @@ class RemoteRuntime(Runtime):
force_rebuild=self.config.sandbox.force_rebuild_runtime,
)
response = send_request_with_retry(
self.session,
response = self._send_request(
'GET',
f'{self.config.sandbox.remote_runtime_api_url}/image_exists',
params={'image': self.container_image},
timeout=30,
timeout=10,
)
if response.status_code != 200 or not response.json()['exists']:
if not response.json()['exists']:
raise RuntimeError(f'Container image {self.container_image} does not exist')
def _start_runtime(self):
@@ -228,17 +231,11 @@ class RemoteRuntime(Runtime):
}
# Start the sandbox using the /start endpoint
response = send_request_with_retry(
self.session,
response = self._send_request(
'POST',
f'{self.config.sandbox.remote_runtime_api_url}/start',
json=start_request,
timeout=300,
)
if response.status_code != 201:
raise RuntimeError(
f'[Runtime (ID={self.runtime_id})] Failed to start runtime: {response.text}'
)
self._parse_runtime_response(response)
self.log(
'debug',
@@ -246,17 +243,12 @@ class RemoteRuntime(Runtime):
)
def _resume_runtime(self):
response = send_request_with_retry(
self.session,
self._send_request(
'POST',
f'{self.config.sandbox.remote_runtime_api_url}/resume',
json={'runtime_id': self.runtime_id},
timeout=30,
)
if response.status_code != 200:
raise RuntimeError(
f'[Runtime (ID={self.runtime_id})] Failed to resume runtime: {response.text}'
)
self.log('debug', 'Runtime resumed.')
def _parse_runtime_response(self, response: requests.Response):
@@ -268,72 +260,57 @@ class RemoteRuntime(Runtime):
{'X-Session-API-Key': start_response['session_api_key']}
)
@tenacity.retry(
stop=tenacity.stop_after_delay(180) | stop_if_should_exit(),
reraise=True,
retry=tenacity.retry_if_exception_type(RuntimeNotReadyError),
wait=tenacity.wait_fixed(2),
)
def _wait_until_alive(self):
self.log('debug', f'Waiting for runtime to be alive at url: {self.runtime_url}')
# send GET request to /runtime/<id>
pod_running = False
max_not_found_count = 12 # 2 minutes
not_found_count = 0
while not pod_running:
runtime_info_response = send_request_with_retry(
self.session,
'GET',
f'{self.config.sandbox.remote_runtime_api_url}/runtime/{self.runtime_id}',
timeout=5,
)
if runtime_info_response.status_code != 200:
raise RuntimeError(
f'Failed to get runtime status: {runtime_info_response.status_code}. Response: {runtime_info_response.text}'
)
runtime_data = runtime_info_response.json()
assert runtime_data['runtime_id'] == self.runtime_id
pod_status = runtime_data['pod_status']
self.log(
'debug',
f'Waiting for runtime pod to be active. Current status: {pod_status}',
)
if pod_status == 'Ready':
pod_running = True
break
elif pod_status == 'Not Found' and not_found_count < max_not_found_count:
not_found_count += 1
self.log(
'debug',
f'Runtime pod not found. Count: {not_found_count} / {max_not_found_count}',
)
elif pod_status in ('Failed', 'Unknown', 'Not Found'):
# clean up the runtime
self.close()
raise RuntimeError(
f'Runtime (ID={self.runtime_id}) failed to start. Current status: {pod_status}'
)
# Pending otherwise - add proper sleep
time.sleep(10)
response = send_request_with_retry(
self.session,
runtime_info_response = self._send_request(
'GET',
f'{self.runtime_url}/alive',
# Retry 404 & 503 errors for the /alive endpoint
# because the runtime might just be starting up
# and have not registered the endpoint yet
retry_fns=[is_404_error, is_503_error],
# leave enough time for the runtime to start up
timeout=600,
f'{self.config.sandbox.remote_runtime_api_url}/runtime/{self.runtime_id}',
)
if response.status_code != 200:
msg = f'Runtime (ID={self.runtime_id}) is not alive yet. Status: {response.status_code}.'
self.log('warning', msg)
raise RuntimeError(msg)
runtime_data = runtime_info_response.json()
assert 'runtime_id' in runtime_data
assert runtime_data['runtime_id'] == self.runtime_id
assert 'pod_status' in runtime_data
pod_status = runtime_data['pod_status']
if pod_status == 'Ready':
try:
self._send_request(
'GET',
f'{self.runtime_url}/alive',
) # will raise exception if we don't get 200 back.
except requests.HTTPError as e:
self.log(
'warning', f"Runtime /alive failed, but pod says it's ready: {e}"
)
raise RuntimeNotReadyError(
f'Runtime /alive failed to respond with 200: {e}'
)
return
if pod_status in ('Failed', 'Unknown', 'Not Found'):
# clean up the runtime
self.close()
raise RuntimeError(
f'Runtime (ID={self.runtime_id}) failed to start. Current status: {pod_status}'
)
self.log(
'debug',
f'Waiting for runtime pod to be active. Current status: {pod_status}',
)
raise RuntimeNotReadyError()
def close(self, timeout: int = 10):
if self.config.sandbox.keep_remote_runtime_alive or self.attach_to_existing:
self.session.close()
return
if self.runtime_id:
if self.runtime_id and self.session:
try:
response = send_request_with_retry(
self.session,
response = self._send_request(
'POST',
f'{self.config.sandbox.remote_runtime_api_url}/stop',
json={'runtime_id': self.runtime_id},
@@ -361,12 +338,11 @@ class RemoteRuntime(Runtime):
return NullObservation('')
action_type = action.action # type: ignore[attr-defined]
if action_type not in ACTION_TYPE_TO_CLASS:
return FatalErrorObservation(
f'[Runtime (ID={self.runtime_id})] Action {action_type} does not exist.'
)
raise ValueError(f'Action {action_type} does not exist.')
if not hasattr(self, action_type):
return FatalErrorObservation(
f'[Runtime (ID={self.runtime_id})] Action {action_type} is not supported in the current runtime.'
return ErrorObservation(
f'[Runtime (ID={self.runtime_id})] Action {action_type} is not supported in the current runtime.',
error_id='AGENT_ERROR$BAD_ACTION',
)
assert action.timeout is not None
@@ -374,36 +350,37 @@ class RemoteRuntime(Runtime):
try:
request_body = {'action': event_to_dict(action)}
self.log('debug', f'Request body: {request_body}')
response = send_request_with_retry(
self.session,
response = self._send_request(
'POST',
f'{self.runtime_url}/execute_action',
json=request_body,
timeout=action.timeout,
# wait a few more seconds to get the timeout error from client side
timeout=action.timeout + 5,
)
if response.status_code == 200:
output = response.json()
obs = observation_from_dict(output)
obs._cause = action.id # type: ignore[attr-defined]
return obs
else:
error_message = response.text
self.log('error', f'Error from server: {error_message}')
obs = FatalErrorObservation(
f'Action execution failed: {error_message}'
)
except Timeout:
self.log('error', 'No response received within the timeout period.')
obs = FatalErrorObservation(
f'[Runtime (ID={self.runtime_id})] Action execution timed out'
)
except Exception as e:
self.log('error', f'Error during action execution: {e}')
obs = FatalErrorObservation(
f'[Runtime (ID={self.runtime_id})] Action execution failed: {str(e)}'
output = response.json()
obs = observation_from_dict(output)
obs._cause = action.id # type: ignore[attr-defined]
except requests.Timeout:
raise RuntimeError(
f'Runtime failed to return execute_action before the requested timeout of {action.timeout}s'
)
return obs
def _send_request(self, method, url, **kwargs):
is_runtime_request = self.runtime_url and self.runtime_url in url
try:
return send_request(self.session, method, url, **kwargs)
except requests.Timeout:
self.log('error', 'No response received within the timeout period.')
raise
except requests.HTTPError as e:
if is_runtime_request and e.response.status_code == 404:
raise RuntimeDisconnectedError(
f'404 error while connecting to {self.runtime_url}'
)
else:
raise e
def run(self, action: CmdRunAction) -> Observation:
return self.run_action(action)
@@ -450,32 +427,16 @@ class RemoteRuntime(Runtime):
params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
response = send_request_with_retry(
self.session,
response = self._send_request(
'POST',
f'{self.runtime_url}/upload_file',
files=upload_data,
params=params,
timeout=300,
)
if response.status_code == 200:
self.log(
'debug',
f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}. Response: {response.text}',
)
return
else:
error_message = response.text
raise Exception(
f'[Runtime (ID={self.runtime_id})] Copy operation failed: {error_message}'
)
except TimeoutError:
raise TimeoutError(
f'[Runtime (ID={self.runtime_id})] Copy operation timed out'
)
except Exception as e:
raise RuntimeError(
f'[Runtime (ID={self.runtime_id})] Copy operation failed: {str(e)}'
self.log(
'debug',
f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}. Response: {response.text}',
)
finally:
if recursive:
@@ -485,64 +446,27 @@ class RemoteRuntime(Runtime):
)
def list_files(self, path: str | None = None) -> list[str]:
try:
data = {}
if path is not None:
data['path'] = path
data = {}
if path is not None:
data['path'] = path
response = send_request_with_retry(
self.session,
'POST',
f'{self.runtime_url}/list_files',
json=data,
timeout=30,
)
if response.status_code == 200:
response_json = response.json()
assert isinstance(response_json, list)
return response_json
else:
error_message = response.text
raise Exception(
f'[Runtime (ID={self.runtime_id})] List files operation failed: {error_message}'
)
except TimeoutError:
raise TimeoutError(
f'[Runtime (ID={self.runtime_id})] List files operation timed out'
)
except Exception as e:
raise RuntimeError(
f'[Runtime (ID={self.runtime_id})] List files operation failed: {str(e)}'
)
response = self._send_request(
'POST',
f'{self.runtime_url}/list_files',
json=data,
timeout=30,
)
response_json = response.json()
assert isinstance(response_json, list)
return response_json
def copy_from(self, path: str) -> bytes:
"""Zip all files in the sandbox and return as a stream of bytes."""
try:
params = {'path': path}
response = send_request_with_retry(
self.session,
'GET',
f'{self.runtime_url}/download_files',
params=params,
timeout=30,
)
if response.status_code == 200:
return response.content
else:
error_message = response.text
raise Exception(
f'[Runtime (ID={self.runtime_id})] Copy operation failed: {error_message}'
)
except requests.Timeout:
raise TimeoutError(
f'[Runtime (ID={self.runtime_id})] Copy operation timed out'
)
except Exception as e:
raise RuntimeError(
f'[Runtime (ID={self.runtime_id})] Copy operation failed: {str(e)}'
)
def send_status_message(self, message: str):
"""Sends a status message if the callback function was provided."""
if self.status_message_callback:
self.status_message_callback(message)
params = {'path': path}
response = self._send_request(
'GET',
f'{self.runtime_url}/download_files',
params=params,
timeout=30,
)
return response.content

View File

@@ -9,7 +9,7 @@ from openhands.events.action import CmdRunAction
from openhands.events.event import EventSource
from openhands.events.observation import (
CmdOutputObservation,
FatalErrorObservation,
ErrorObservation,
)
SOFT_TIMEOUT_SECONDS = 5
@@ -275,7 +275,7 @@ class BashSession:
output += '\r\n' + bash_prompt
return output, exit_code
def run(self, action: CmdRunAction) -> CmdOutputObservation | FatalErrorObservation:
def run(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservation:
try:
assert (
action.timeout is not None
@@ -329,6 +329,6 @@ class BashSession:
interpreter_details=python_interpreter,
)
except UnicodeDecodeError as e:
return FatalErrorObservation(
f'Runtime bash execution failed: Command output could not be decoded as utf-8. {str(e)}'
return ErrorObservation(
f'Runtime bash execution failed: Command output could not be decoded as utf-8. {str(e)}',
)

View File

@@ -13,7 +13,6 @@ from openhands.events.action import (
)
from openhands.events.observation import (
ErrorObservation,
FatalErrorObservation,
FileEditObservation,
FileReadObservation,
FileWriteObservation,
@@ -214,9 +213,7 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface):
if isinstance(obs, ErrorObservation):
return obs
if not isinstance(obs, FileWriteObservation):
return FatalErrorObservation(
f'Fatal Runtime in editing: Expected FileWriteObservation, got {type(obs)}: {str(obs)}'
)
raise ValueError(f'Expected FileWriteObservation, got {type(obs)}: {str(obs)}')
return FileEditObservation(
content=get_diff('', action.content, action.path),
path=action.path,
@@ -225,9 +222,7 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface):
new_content=action.content,
)
if not isinstance(obs, FileReadObservation):
return FatalErrorObservation(
f'Fatal Runtime in editing: Expected FileReadObservation, got {type(obs)}: {str(obs)}'
)
raise ValueError(f'Expected FileReadObservation, got {type(obs)}: {str(obs)}')
original_file_content = obs.content
old_file_lines = original_file_content.split('\n')

View File

@@ -1,22 +1,12 @@
from typing import Any, Callable, Type
from typing import Any
import requests
from requests.exceptions import (
ChunkedEncodingError,
ConnectionError,
)
from tenacity import (
retry,
retry_if_exception,
retry_if_exception_type,
stop_after_delay,
wait_exponential,
)
from urllib3.exceptions import IncompleteRead
from openhands.core.logger import openhands_logger as logger
from openhands.utils.tenacity_stop import stop_if_should_exit
def is_server_error(exception):
return (
@@ -60,37 +50,13 @@ DEFAULT_RETRY_EXCEPTIONS = [
]
def send_request_with_retry(
def send_request(
session: requests.Session,
method: str,
url: str,
timeout: int,
retry_exceptions: list[Type[Exception]] | None = None,
retry_fns: list[Callable[[Exception], bool]] | None = None,
timeout: int = 10,
**kwargs: Any,
) -> requests.Response:
exceptions_to_catch = retry_exceptions or DEFAULT_RETRY_EXCEPTIONS
retry_condition = retry_if_exception_type(
tuple(exceptions_to_catch)
) | retry_if_exception(is_502_error)
if retry_fns is not None:
for fn in retry_fns:
retry_condition |= retry_if_exception(fn)
# wait a few more seconds to get the timeout error from client side
kwargs['timeout'] = timeout + 10
@retry(
stop=stop_after_delay(timeout) | stop_if_should_exit(),
wait=wait_exponential(multiplier=1, min=4, max=20),
retry=retry_condition,
reraise=True,
before_sleep=lambda retry_state: logger.debug(
f'Retrying {method} request to {url} due to {retry_state.outcome.exception()}. Attempt {retry_state.attempt_number}'
),
)
def _send_request_with_retry():
response = session.request(method, url, **kwargs)
response.raise_for_status()
return response
return _send_request_with_retry()
response = session.request(method, url, **kwargs)
response.raise_for_status()
return response

View File

@@ -32,7 +32,12 @@ class AgentSession:
_closed: bool = False
loop: asyncio.AbstractEventLoop | None = None
def __init__(self, sid: str, file_store: FileStore):
def __init__(
self,
sid: str,
file_store: FileStore,
status_callback: Optional[Callable] = None,
):
"""Initializes a new instance of the Session class
Parameters:
@@ -43,6 +48,7 @@ class AgentSession:
self.sid = sid
self.event_stream = EventStream(sid, file_store)
self.file_store = file_store
self._status_callback = status_callback
async def start(
self,
@@ -53,7 +59,6 @@ class AgentSession:
max_budget_per_task: float | None = None,
agent_to_llm_config: dict[str, LLMConfig] | None = None,
agent_configs: dict[str, AgentConfig] | None = None,
status_message_callback: Optional[Callable] = None,
):
"""Starts the Agent session
Parameters:
@@ -80,7 +85,6 @@ class AgentSession:
max_budget_per_task,
agent_to_llm_config,
agent_configs,
status_message_callback,
)
def _start_thread(self, *args):
@@ -99,14 +103,12 @@ class AgentSession:
max_budget_per_task: float | None = None,
agent_to_llm_config: dict[str, LLMConfig] | None = None,
agent_configs: dict[str, AgentConfig] | None = None,
status_message_callback: Optional[Callable] = None,
):
self._create_security_analyzer(config.security.security_analyzer)
await self._create_runtime(
runtime_name=runtime_name,
config=config,
agent=agent,
status_message_callback=status_message_callback,
)
self._create_controller(
agent,
@@ -132,6 +134,10 @@ class AgentSession:
asyncio.get_event_loop().run_in_executor(None, inner_close)
async def stop_agent_loop_for_error(self):
if self.controller is not None:
await self.controller.set_agent_state_to(AgentState.ERROR)
async def _close(self):
if self._closed:
return
@@ -162,7 +168,6 @@ class AgentSession:
runtime_name: str,
config: AppConfig,
agent: Agent,
status_message_callback: Optional[Callable] = None,
):
"""Creates a runtime instance
@@ -182,13 +187,17 @@ class AgentSession:
event_stream=self.event_stream,
sid=self.sid,
plugins=agent.sandbox_plugins,
status_message_callback=status_message_callback,
status_callback=self._status_callback,
)
try:
await self.runtime.connect()
except Exception as e:
logger.error(f'Runtime initialization failed: {e}', exc_info=True)
if self._status_callback:
self._status_callback(
'error', 'STATUS$ERROR_RUNTIME_DISCONNECTED', str(e)
)
raise
if self.runtime is not None:
@@ -252,9 +261,8 @@ class AgentSession:
agent_to_llm_config=agent_to_llm_config,
agent_configs=agent_configs,
confirmation_mode=confirmation_mode,
# AgentSession is designed to communicate with the frontend, so we don't want to
# run the agent in headless mode.
headless_mode=False,
status_callback=self._status_callback,
)
try:
agent_state = State.restore_from_session(self.sid, self.file_store)

View File

@@ -40,7 +40,9 @@ class Session:
self.sid = sid
self.websocket = ws
self.last_active_ts = int(time.time())
self.agent_session = AgentSession(sid, file_store)
self.agent_session = AgentSession(
sid, file_store, status_callback=self.queue_status_message
)
self.agent_session.event_stream.subscribe(
EventStreamSubscriber.SERVER, self.on_event
)
@@ -115,7 +117,6 @@ class Session:
max_budget_per_task=self.config.max_budget_per_task,
agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
agent_configs=self.config.get_agent_configs(),
status_message_callback=self.queue_status_message,
)
except Exception as e:
logger.exception(f'Error creating controller: {e}')
@@ -171,14 +172,6 @@ class Session:
'Model does not support image upload, change to a different model or try without an image.'
)
return
if self.loop:
asyncio.run_coroutine_threadsafe(
self._add_event(event, EventSource.USER), self.loop
) # type: ignore
else:
raise RuntimeError('No event loop found')
async def _add_event(self, event, event_source):
self.agent_session.event_stream.add_event(event, EventSource.USER)
async def send(self, data: dict[str, object]) -> bool:
@@ -200,11 +193,17 @@ class Session:
"""Sends an error message to the client."""
return await self.send({'error': True, 'message': message})
async def send_status_message(self, message: str) -> bool:
async def _send_status_message(self, msg_type: str, id: str, message: str) -> bool:
"""Sends a status message to the client."""
return await self.send({'status': message})
if msg_type == 'error':
await self.agent_session.stop_agent_loop_for_error()
def queue_status_message(self, message: str):
return await self.send(
{'status_update': True, 'type': msg_type, 'id': id, 'message': message}
)
def queue_status_message(self, msg_type: str, id: str, message: str):
"""Queues a status message to be sent asynchronously."""
# Ensure the coroutine runs in the main event loop
asyncio.run_coroutine_threadsafe(self.send_status_message(message), self.loop)
asyncio.run_coroutine_threadsafe(
self._send_status_message(msg_type, id, message), self.loop
)

View File

@@ -0,0 +1,231 @@
"""Bash-related tests for the EventStreamRuntime, which connects to the ActionExecutor running in the sandbox."""
import asyncio
import os
import tempfile
from unittest.mock import MagicMock
import pandas as pd
import pytest
from conftest import TEST_IN_CI
from evaluation.utils.shared import (
EvalException,
EvalMetadata,
EvalOutput,
assert_and_raise,
codeact_user_response,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
run_evaluation,
)
from openhands.agenthub import Agent
from openhands.controller.state.state import State
from openhands.core.config import (
AgentConfig,
AppConfig,
LLMConfig,
SandboxConfig,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation
from openhands.events.serialization.event import event_to_dict
from openhands.llm import LLM
from openhands.runtime.base import Runtime
from openhands.utils.async_utils import call_async_from_sync
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
'CodeActAgent': codeact_user_response,
}
def get_config(
metadata: EvalMetadata,
) -> AppConfig:
assert (
os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL') is not None
), 'SANDBOX_REMOTE_RUNTIME_API_URL must be set.'
assert (
os.environ.get('ALLHANDS_API_KEY') is not None
), 'ALLHANDS_API_KEY must be set.'
config = AppConfig(
default_agent=metadata.agent_class,
run_as_openhands=False,
max_iterations=metadata.max_iterations,
runtime='remote',
sandbox=SandboxConfig(
base_container_image='python:3.11-bookworm',
enable_auto_lint=True,
use_host_network=False,
# large enough timeout, since some testcases take very long to run
timeout=300,
api_key=os.environ.get('ALLHANDS_API_KEY', None),
remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
keep_remote_runtime_alive=False,
),
# do not mount workspace
workspace_base=None,
workspace_mount_path=None,
)
agent_config = AgentConfig(
codeact_enable_jupyter=False,
codeact_enable_browsing=False,
codeact_enable_llm_editor=False,
)
config.set_agent_config(agent_config)
return config
def initialize_runtime(
runtime: Runtime,
):
"""Initialize the runtime for the agent.
This function is called before the runtime is used to run the agent.
"""
logger.info('-' * 30)
logger.info('BEGIN Runtime Initialization Fn')
logger.info('-' * 30)
obs: CmdOutputObservation
action = CmdRunAction(command="""export USER=$(whoami); echo USER=${USER} """)
action.timeout = 600
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(obs.exit_code == 0, f'Failed to export USER: {str(obs)}')
action = CmdRunAction(command='mkdir -p /dummy_dir')
action.timeout = 600
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(
obs.exit_code == 0,
f'Failed to create /dummy_dir: {str(obs)}',
)
with tempfile.TemporaryDirectory() as temp_dir:
# Construct the full path for the desired file name within the temporary directory
temp_file_path = os.path.join(temp_dir, 'dummy_file')
# Write to the file with the desired name within the temporary directory
with open(temp_file_path, 'w') as f:
f.write('dummy content')
# Copy the file to the desired location
runtime.copy_to(temp_file_path, '/dummy_dir/')
logger.info('-' * 30)
logger.info('END Runtime Initialization Fn')
logger.info('-' * 30)
def process_instance(
instance: pd.Series,
metadata: EvalMetadata,
reset_logger: bool = True,
) -> EvalOutput:
config = get_config(metadata)
# Setup the logger properly, so you can run multi-processing to parallelize the evaluation
if reset_logger:
log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
reset_logger_for_multiprocessing(logger, instance.instance_id, log_dir)
else:
logger.info(f'Starting evaluation for instance {instance.instance_id}.')
runtime = create_runtime(config)
call_async_from_sync(runtime.connect)
try:
initialize_runtime(runtime)
instruction = 'dummy instruction'
agent = Agent.get_cls(metadata.agent_class)(
llm=LLM(config=metadata.llm_config),
config=config.get_agent_config(metadata.agent_class),
)
def next_command(*args, **kwargs):
return CmdRunAction(command='ls -lah')
agent.step = MagicMock(side_effect=next_command)
# Here's how you can run the agent (similar to the `main` function) and get the final task state
state: State | None = asyncio.run(
run_controller(
config=config,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
metadata.agent_class
],
agent=agent,
)
)
# if fatal error, throw EvalError to trigger re-run
if (
state.last_error
and 'fatal error during agent execution' in state.last_error
and 'stuck in a loop' not in state.last_error
):
raise EvalException('Fatal error detected: ' + state.last_error)
finally:
runtime.close()
test_result = {}
if state is None:
raise ValueError('State should not be None.')
histories = [event_to_dict(event) for event in state.history.get_events()]
metrics = state.metrics.get() if state.metrics else None
# Save the output
output = EvalOutput(
instance_id=instance.instance_id,
instruction=instruction,
instance=instance.to_dict(), # SWE Bench specific
test_result=test_result,
metadata=metadata,
history=histories,
metrics=metrics,
error=state.last_error if state and state.last_error else None,
)
return output
@pytest.mark.skipif(
TEST_IN_CI,
reason='This test should only be run locally, not in CI.',
)
def test_stress_remote_runtime(n_eval_workers: int = 64):
"""Mimic evaluation setting to test remote runtime in a multi-processing setting."""
llm_config = LLMConfig()
metadata = make_metadata(
llm_config,
'dummy_dataset_descrption',
'CodeActAgent',
max_iterations=10,
eval_note='dummy_eval_note',
eval_output_dir='./dummy_eval_output_dir',
details={},
)
# generate 300 random dummy instances
dummy_instance = pd.DataFrame(
{
'instance_id': [f'dummy_instance_{i}' for i in range(300)],
}
)
output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
instances = prepare_dataset(
dummy_instance, output_file, eval_n_limit=len(dummy_instance)
)
run_evaluation(instances, metadata, output_file, n_eval_workers, process_instance)

View File

@@ -7,14 +7,12 @@ from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController
from openhands.controller.state.state import TrafficControlState
from openhands.core.config import AppConfig
from openhands.core.exceptions import LLMMalformedActionError
from openhands.core.main import run_controller
from openhands.core.schema import AgentState
from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
from openhands.events.observation import (
ErrorObservation,
FatalErrorObservation,
)
from openhands.events.serialization import event_to_dict
from openhands.llm import LLM
@@ -45,6 +43,11 @@ def mock_event_stream():
return MagicMock(spec=EventStream)
@pytest.fixture
def mock_status_callback():
return AsyncMock()
@pytest.mark.asyncio
async def test_set_agent_state(mock_agent, mock_event_stream):
controller = AgentController(
@@ -98,39 +101,19 @@ async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream)
@pytest.mark.asyncio
async def test_report_error(mock_agent, mock_event_stream):
async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_callback):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
status_callback=mock_status_callback,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)
error_message = 'Test error'
await controller.report_error(error_message)
assert controller.state.last_error == error_message
controller.event_stream.add_event.assert_called_once()
await controller.close()
@pytest.mark.asyncio
async def test_step_with_exception(mock_agent, mock_event_stream):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)
controller.state.agent_state = AgentState.RUNNING
controller.report_error = AsyncMock()
controller.agent.step.side_effect = LLMMalformedActionError('Malformed action')
await controller._step()
# Verify that report_error was called with the correct error message
controller.report_error.assert_called_once_with('Malformed action')
await controller._react_to_exception(RuntimeError(error_message))
controller.status_callback.assert_called_once()
await controller.close()
@@ -141,21 +124,24 @@ async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
event_stream = EventStream(sid='test', file_store=file_store)
agent = MagicMock(spec=Agent)
# a random message to send to the runtime
event = CmdRunAction(command='ls')
agent.step.return_value = event
agent = MagicMock(spec=Agent)
def agent_step_fn(state):
print(f'agent_step_fn received state: {state}')
return CmdRunAction(command='ls')
agent.step = agent_step_fn
agent.llm = MagicMock(spec=LLM)
agent.llm.metrics = Metrics()
agent.llm.config = config.get_llm_config()
fatal_error_obs = FatalErrorObservation('Fatal error detected')
fatal_error_obs._cause = event.id
runtime = MagicMock(spec=Runtime)
async def on_event(event: Event):
if isinstance(event, CmdRunAction):
await event_stream.async_add_event(fatal_error_obs, EventSource.USER)
error_obs = ErrorObservation('You messed around with Jim')
error_obs._cause = event.id
event_stream.add_event(error_obs, EventSource.USER)
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
runtime.event_stream = event_stream
@@ -170,30 +156,23 @@ async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
)
print(f'state: {state}')
print(f'event_stream: {list(event_stream.get_events())}')
assert state.iteration == 1
# it will first become AgentState.ERROR, then become AgentState.STOPPED
# in side run_controller (since the while loop + sleep no longer loop)
assert state.agent_state == AgentState.STOPPED
assert (
state.last_error
== 'There was a fatal error during agent execution: **FatalErrorObservation**\nFatal error detected'
)
assert len(list(event_stream.get_events())) == 5
assert state.iteration == 4
assert state.agent_state == AgentState.ERROR
assert state.last_error == 'Agent got stuck in a loop'
assert len(list(event_stream.get_events())) == 11
@pytest.mark.asyncio
async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream):
async def test_run_controller_stop_with_stuck():
config = AppConfig()
file_store = get_file_store(config.file_store, config.file_store_path)
event_stream = EventStream(sid='test', file_store=file_store)
agent = MagicMock(spec=Agent)
# a random message to send to the runtime
event = CmdRunAction(command='ls')
def agent_step_fn(state):
print(f'agent_step_fn received state: {state}')
return event
return CmdRunAction(command='ls')
agent.step = agent_step_fn
agent.llm = MagicMock(spec=LLM)
@@ -207,9 +186,7 @@ async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream):
'Non fatal error here to trigger loop'
)
non_fatal_error_obs._cause = event.id
await event_stream.async_add_event(
non_fatal_error_obs, EventSource.ENVIRONMENT
)
event_stream.add_event(non_fatal_error_obs, EventSource.ENVIRONMENT)
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
runtime.event_stream = event_stream
@@ -228,7 +205,7 @@ async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream):
print(f'event {i}: {event_to_dict(event)}')
assert state.iteration == 4
assert len(events) == 12
assert len(events) == 11
# check the eventstream have 4 pairs of repeated actions and observations
repeating_actions_and_observations = events[2:10]
for action, observation in zip(
@@ -246,13 +223,8 @@ async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream):
assert last_event['extras']['agent_state'] == 'error'
assert last_event['observation'] == 'agent_state_changed'
# it will first become AgentState.ERROR, then become AgentState.STOPPED
# in side run_controller (since the while loop + sleep no longer loop)
assert state.agent_state == AgentState.STOPPED
assert (
state.last_error
== 'There was a fatal error during agent execution: **FatalErrorObservation**\nAgent got stuck in a loop'
)
assert state.agent_state == AgentState.ERROR
assert state.last_error == 'Agent got stuck in a loop'
@pytest.mark.asyncio
@@ -319,7 +291,7 @@ async def test_step_max_iterations(mock_agent, mock_event_stream):
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
await controller._step()
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
assert controller.state.agent_state == AgentState.PAUSED
assert controller.state.agent_state == AgentState.ERROR
await controller.close()
@@ -359,7 +331,7 @@ async def test_step_max_budget(mock_agent, mock_event_stream):
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
await controller._step()
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
assert controller.state.agent_state == AgentState.PAUSED
assert controller.state.agent_state == AgentState.ERROR
await controller.close()

View File

@@ -440,9 +440,10 @@ class TestStuckDetector:
read_observation_2._cause = read_action_2._id
event_stream.add_event(read_observation_2, EventSource.ENVIRONMENT)
# one more message to break the pattern
message_null_observation = NullObservation(content='')
message_action = MessageAction(content='Come on', wait_for_response=False)
event_stream.add_event(message_action, EventSource.USER)
message_null_observation = NullObservation(content='')
event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
cmd_action_3 = CmdRunAction(command='ls')