mirror of
https://github.com/OpenHands/OpenHands.git
synced 2026-03-22 13:47:19 +08:00
Refactor of error handling (#4575)
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com> Co-authored-by: Xingyao Wang <xingyao@all-hands.dev> Co-authored-by: Xingyao Wang <xingyao6@illinois.edu>
This commit is contained in:
@@ -128,14 +128,14 @@ describe.skip("ChatInterface", () => {
|
||||
timestamp: new Date().toISOString(),
|
||||
},
|
||||
{
|
||||
error: "Woops!",
|
||||
error: true,
|
||||
id: "",
|
||||
message: "Something went wrong",
|
||||
},
|
||||
];
|
||||
renderChatInterface(messages);
|
||||
|
||||
const error = screen.getByTestId("error-message");
|
||||
expect(within(error).getByText("Woops!")).toBeInTheDocument();
|
||||
expect(within(error).getByText("Something went wrong")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import React, { useEffect } from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { useSelector } from "react-redux";
|
||||
import toast from "react-hot-toast";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { RootState } from "#/store";
|
||||
import AgentState from "#/types/AgentState";
|
||||
@@ -16,7 +17,7 @@ enum IndicatorColor {
|
||||
}
|
||||
|
||||
function AgentStatusBar() {
|
||||
const { t } = useTranslation();
|
||||
const { t, i18n } = useTranslation();
|
||||
const { curAgentState } = useSelector((state: RootState) => state.agent);
|
||||
const { curStatusMessage } = useSelector((state: RootState) => state.status);
|
||||
|
||||
@@ -94,15 +95,27 @@ function AgentStatusBar() {
|
||||
const [statusMessage, setStatusMessage] = React.useState<string>("");
|
||||
|
||||
React.useEffect(() => {
|
||||
if (curAgentState === AgentState.LOADING) {
|
||||
const trimmedCustomMessage = curStatusMessage.status.trim();
|
||||
if (trimmedCustomMessage) {
|
||||
setStatusMessage(t(trimmedCustomMessage));
|
||||
return;
|
||||
let message = curStatusMessage.message || "";
|
||||
if (curStatusMessage?.id) {
|
||||
const id = curStatusMessage.id.trim();
|
||||
if (i18n.exists(id)) {
|
||||
message = t(curStatusMessage.id.trim()) || message;
|
||||
}
|
||||
}
|
||||
if (curStatusMessage?.type === "error") {
|
||||
toast.error(message);
|
||||
return;
|
||||
}
|
||||
if (curAgentState === AgentState.LOADING && message.trim()) {
|
||||
setStatusMessage(message);
|
||||
} else {
|
||||
setStatusMessage(AgentStatusMap[curAgentState].message);
|
||||
}
|
||||
}, [curStatusMessage.id]);
|
||||
|
||||
React.useEffect(() => {
|
||||
setStatusMessage(AgentStatusMap[curAgentState].message);
|
||||
}, [curAgentState, curStatusMessage.status]);
|
||||
}, [curAgentState]);
|
||||
|
||||
return (
|
||||
<div className="flex flex-col items-center">
|
||||
|
||||
@@ -73,7 +73,7 @@ export function ChatInterface() {
|
||||
isErrorMessage(message) ? (
|
||||
<ErrorMessage
|
||||
key={index}
|
||||
error={message.error}
|
||||
id={message.id}
|
||||
message={message.message}
|
||||
/>
|
||||
) : (
|
||||
|
||||
3
frontend/src/components/chat/message.d.ts
vendored
3
frontend/src/components/chat/message.d.ts
vendored
@@ -6,6 +6,7 @@ type Message = {
|
||||
};
|
||||
|
||||
type ErrorMessage = {
|
||||
error: string;
|
||||
error: boolean;
|
||||
id?: string;
|
||||
message: string;
|
||||
};
|
||||
|
||||
@@ -1,14 +1,41 @@
|
||||
import { useState, useEffect } from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
|
||||
interface ErrorMessageProps {
|
||||
error: string;
|
||||
id?: string;
|
||||
message: string;
|
||||
}
|
||||
|
||||
export function ErrorMessage({ error, message }: ErrorMessageProps) {
|
||||
export function ErrorMessage({ id, message }: ErrorMessageProps) {
|
||||
const { t, i18n } = useTranslation();
|
||||
const [showDetails, setShowDetails] = useState(true);
|
||||
const [headline, setHeadline] = useState("");
|
||||
const [details, setDetails] = useState(message);
|
||||
|
||||
useEffect(() => {
|
||||
if (id && i18n.exists(id)) {
|
||||
setHeadline(t(id));
|
||||
setDetails(message);
|
||||
setShowDetails(false);
|
||||
}
|
||||
}, [id, message, i18n.language]);
|
||||
|
||||
return (
|
||||
<div className="flex gap-2 items-center justify-start border-l-2 border-danger pl-2 my-2 py-2">
|
||||
<div className="text-sm leading-4 flex flex-col gap-2">
|
||||
<p className="text-danger font-bold">{error}</p>
|
||||
<p className="text-neutral-300">{message}</p>
|
||||
{headline && <p className="text-danger font-bold">{headline}</p>}
|
||||
{headline && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setShowDetails(!showDetails)}
|
||||
className="cursor-pointer text-left"
|
||||
>
|
||||
{showDetails
|
||||
? t("ERROR_MESSAGE$HIDE_DETAILS")
|
||||
: t("ERROR_MESSAGE$SHOW_DETAILS")}
|
||||
</button>
|
||||
)}
|
||||
{showDetails && <p className="text-neutral-300">{details}</p>}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -1441,6 +1441,12 @@
|
||||
"fr": "Privé",
|
||||
"tr": "Özel"
|
||||
},
|
||||
"ERROR_MESSAGE$SHOW_DETAILS": {
|
||||
"en": "Show details"
|
||||
},
|
||||
"ERROR_MESSAGE$HIDE_DETAILS": {
|
||||
"en": "Hide details"
|
||||
},
|
||||
"STATUS$STARTING_RUNTIME": {
|
||||
"en": "Starting Runtime...",
|
||||
"zh-CN": "启动运行时...",
|
||||
@@ -1510,5 +1516,17 @@
|
||||
"ar": "في انتظار جاهزية العميل...",
|
||||
"fr": "En attente que le client soit prêt...",
|
||||
"tr": "İstemcinin hazır olması bekleniyor..."
|
||||
},
|
||||
"STATUS$ERROR_LLM_AUTHENTICATION": {
|
||||
"en": "Error authenticating with the LLM provider. Please check your API key"
|
||||
},
|
||||
"STATUS$ERROR_RUNTIME_DISCONNECTED": {
|
||||
"en": "There was an error while connecting to the runtime. Please refresh the page."
|
||||
},
|
||||
"AGENT_ERROR$BAD_ACTION": {
|
||||
"en": "Agent tried to execute a malformed action."
|
||||
},
|
||||
"AGENT_ERROR$ACTION_TIMEOUT": {
|
||||
"en": "Action timed out."
|
||||
}
|
||||
}
|
||||
|
||||
@@ -184,21 +184,6 @@ function App() {
|
||||
if (q) addIntialQueryToChat(q, files);
|
||||
}, [settings]);
|
||||
|
||||
const handleError = (message: string) => {
|
||||
const [error, ...rest] = message.split(":");
|
||||
const details = rest.join(":");
|
||||
if (!details) {
|
||||
dispatch(
|
||||
addErrorMessage({
|
||||
error: "An error has occured",
|
||||
message: error,
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
dispatch(addErrorMessage({ error, message: details }));
|
||||
}
|
||||
};
|
||||
|
||||
const handleMessage = React.useCallback(
|
||||
(message: MessageEvent<WebSocket.Data>) => {
|
||||
// set token received from the server
|
||||
@@ -224,7 +209,12 @@ function App() {
|
||||
return;
|
||||
}
|
||||
if (isErrorObservation(parsed)) {
|
||||
handleError(parsed.message);
|
||||
dispatch(
|
||||
addErrorMessage({
|
||||
id: parsed.extras?.error_id,
|
||||
message: parsed.message,
|
||||
}),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
import { addAssistantMessage, addUserMessage } from "#/state/chatSlice";
|
||||
import {
|
||||
addAssistantMessage,
|
||||
addUserMessage,
|
||||
addErrorMessage,
|
||||
} from "#/state/chatSlice";
|
||||
import { setCode, setActiveFilepath } from "#/state/codeSlice";
|
||||
import { appendJupyterInput } from "#/state/jupyterSlice";
|
||||
import {
|
||||
@@ -119,13 +123,19 @@ export function handleActionMessage(message: ActionMessage) {
|
||||
}
|
||||
|
||||
export function handleStatusMessage(message: StatusMessage) {
|
||||
const msg = message.status == null ? "" : message.status.trim();
|
||||
store.dispatch(
|
||||
setCurStatusMessage({
|
||||
...message,
|
||||
status: msg,
|
||||
}),
|
||||
);
|
||||
if (message.type === "info") {
|
||||
store.dispatch(
|
||||
setCurStatusMessage({
|
||||
...message,
|
||||
}),
|
||||
);
|
||||
} else if (message.type === "error") {
|
||||
store.dispatch(
|
||||
addErrorMessage({
|
||||
...message,
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
export function handleAssistantMessage(data: string | SocketMessage) {
|
||||
@@ -139,9 +149,11 @@ export function handleAssistantMessage(data: string | SocketMessage) {
|
||||
|
||||
if ("action" in socketMessage) {
|
||||
handleActionMessage(socketMessage);
|
||||
} else if ("status" in socketMessage) {
|
||||
} else if ("observation" in socketMessage) {
|
||||
handleObservationMessage(socketMessage);
|
||||
} else if ("status_update" in socketMessage) {
|
||||
handleStatusMessage(socketMessage);
|
||||
} else {
|
||||
handleObservationMessage(socketMessage);
|
||||
console.error("Unknown message type", socketMessage);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,10 +39,10 @@ export const chatSlice = createSlice({
|
||||
|
||||
addErrorMessage(
|
||||
state,
|
||||
action: PayloadAction<{ error: string; message: string }>,
|
||||
action: PayloadAction<{ id?: string; message: string }>,
|
||||
) {
|
||||
const { error, message } = action.payload;
|
||||
state.messages.push({ error, message });
|
||||
const { id, message } = action.payload;
|
||||
state.messages.push({ id, message, error: true });
|
||||
},
|
||||
|
||||
clearMessages(state) {
|
||||
|
||||
@@ -2,8 +2,10 @@ import { createSlice, PayloadAction } from "@reduxjs/toolkit";
|
||||
import { StatusMessage } from "#/types/Message";
|
||||
|
||||
const initialStatusMessage: StatusMessage = {
|
||||
status: "",
|
||||
is_error: false,
|
||||
status_update: true,
|
||||
type: "info",
|
||||
id: "",
|
||||
message: "",
|
||||
};
|
||||
|
||||
export const statusSlice = createSlice({
|
||||
|
||||
@@ -33,10 +33,8 @@ export interface ObservationMessage {
|
||||
}
|
||||
|
||||
export interface StatusMessage {
|
||||
// TODO not implemented yet
|
||||
// Whether the status is an error, default is false
|
||||
is_error: boolean;
|
||||
|
||||
// A status message to display to the user
|
||||
status: string;
|
||||
status_update: true;
|
||||
type: string;
|
||||
id: string;
|
||||
message: string;
|
||||
}
|
||||
|
||||
@@ -54,6 +54,9 @@ export interface BrowseObservation extends OpenHandsObservationEvent<"browse"> {
|
||||
|
||||
export interface ErrorObservation extends OpenHandsObservationEvent<"error"> {
|
||||
source: "user";
|
||||
extras: {
|
||||
error_id?: string;
|
||||
};
|
||||
}
|
||||
|
||||
export type OpenHandsObservation =
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import traceback
|
||||
from typing import Type
|
||||
from typing import Callable, Type
|
||||
|
||||
import litellm
|
||||
|
||||
@@ -35,9 +35,7 @@ from openhands.events.event import Event
|
||||
from openhands.events.observation import (
|
||||
AgentDelegateObservation,
|
||||
AgentStateChangedObservation,
|
||||
CmdOutputObservation,
|
||||
ErrorObservation,
|
||||
FatalErrorObservation,
|
||||
Observation,
|
||||
)
|
||||
from openhands.events.serialization.event import truncate_content
|
||||
@@ -77,6 +75,7 @@ class AgentController:
|
||||
initial_state: State | None = None,
|
||||
is_delegate: bool = False,
|
||||
headless_mode: bool = True,
|
||||
status_callback: Callable | None = None,
|
||||
):
|
||||
"""Initializes a new instance of the AgentController class.
|
||||
|
||||
@@ -119,6 +118,7 @@ class AgentController:
|
||||
|
||||
# stuck helper
|
||||
self._stuck_detector = StuckDetector(self.state)
|
||||
self.status_callback = status_callback
|
||||
|
||||
async def close(self):
|
||||
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream."""
|
||||
@@ -132,7 +132,7 @@ class AgentController:
|
||||
message (str): The message to log.
|
||||
"""
|
||||
message = f'[Agent Controller {self.id}] {message}'
|
||||
getattr(logger, level)(message, extra=extra)
|
||||
getattr(logger, level)(message, extra=extra, stacklevel=2)
|
||||
|
||||
def update_state_before_step(self):
|
||||
self.state.iteration += 1
|
||||
@@ -142,22 +142,16 @@ class AgentController:
|
||||
# update metrics especially for cost. Use deepcopy to avoid it being modified by agent.reset()
|
||||
self.state.local_metrics = copy.deepcopy(self.agent.llm.metrics)
|
||||
|
||||
async def report_error(self, message: str, exception: Exception | None = None):
|
||||
"""Reports an error to the user and sends the exception to the LLM next step, in the hope it can self-correct.
|
||||
|
||||
This method should be called for a particular type of errors, which have:
|
||||
- a user-friendly message, which will be shown in the chat box. This should not be a raw exception message.
|
||||
- an ErrorObservation that can be sent to the LLM by the user role, with the exception message, so it can self-correct next time.
|
||||
"""
|
||||
self.state.last_error = message
|
||||
if exception:
|
||||
self.state.last_error += f': {exception}'
|
||||
detail = str(exception) if exception is not None else ''
|
||||
if exception is not None and isinstance(exception, litellm.AuthenticationError):
|
||||
detail = 'Please check your credentials. Is your API key correct?'
|
||||
self.event_stream.add_event(
|
||||
ErrorObservation(f'{message}:{detail}'), EventSource.ENVIRONMENT
|
||||
)
|
||||
async def _react_to_exception(
|
||||
self,
|
||||
e: Exception,
|
||||
):
|
||||
await self.set_agent_state_to(AgentState.ERROR)
|
||||
if self.status_callback is not None:
|
||||
err_id = ''
|
||||
if isinstance(e, litellm.AuthenticationError):
|
||||
err_id = 'STATUS$ERROR_LLM_AUTHENTICATION'
|
||||
self.status_callback('error', err_id, str(e))
|
||||
|
||||
async def start_step_loop(self):
|
||||
"""The main loop for the agent's step-by-step execution."""
|
||||
@@ -172,12 +166,7 @@ class AgentController:
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
self.log('error', f'Error while running the agent: {e}')
|
||||
self.log('error', traceback.format_exc())
|
||||
await self.report_error(
|
||||
'There was an unexpected error while running the agent', exception=e
|
||||
)
|
||||
await self.set_agent_state_to(AgentState.ERROR)
|
||||
break
|
||||
await self._react_to_exception(e)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
@@ -227,15 +216,6 @@ class AgentController:
|
||||
Args:
|
||||
observation (observation): The observation to handle.
|
||||
"""
|
||||
if (
|
||||
self._pending_action
|
||||
and hasattr(self._pending_action, 'confirmation_state')
|
||||
and self._pending_action.confirmation_state
|
||||
== ActionConfirmationStatus.AWAITING_CONFIRMATION
|
||||
):
|
||||
return
|
||||
|
||||
# Make sure we print the observation in the same way as the LLM sees it
|
||||
observation_to_print = copy.deepcopy(observation)
|
||||
if len(observation_to_print.content) > self.agent.llm.config.max_message_chars:
|
||||
observation_to_print.content = truncate_content(
|
||||
@@ -243,7 +223,6 @@ class AgentController:
|
||||
)
|
||||
self.log('debug', str(observation_to_print), extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
# Merge with the metrics from the LLM - it will to synced to the controller's local metrics in update_state_after_step()
|
||||
if observation.llm_metrics is not None:
|
||||
self.agent.llm.metrics.merge(observation.llm_metrics)
|
||||
|
||||
@@ -255,19 +234,11 @@ class AgentController:
|
||||
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
|
||||
return
|
||||
|
||||
if isinstance(observation, CmdOutputObservation):
|
||||
return
|
||||
elif isinstance(observation, AgentDelegateObservation):
|
||||
if isinstance(observation, AgentDelegateObservation):
|
||||
self.state.history.on_event(observation)
|
||||
elif isinstance(observation, ErrorObservation):
|
||||
if self.state.agent_state == AgentState.ERROR:
|
||||
self.state.metrics.merge(self.state.local_metrics)
|
||||
elif isinstance(observation, FatalErrorObservation):
|
||||
self.state.last_error = (
|
||||
f'There was a fatal error during agent execution: {str(observation)}'
|
||||
)
|
||||
self.state.metrics.merge(self.state.local_metrics)
|
||||
await self.set_agent_state_to(AgentState.ERROR)
|
||||
|
||||
async def _handle_message_action(self, action: MessageAction):
|
||||
"""Handles message actions from the event stream.
|
||||
@@ -420,13 +391,8 @@ class AgentController:
|
||||
await asyncio.sleep(1)
|
||||
return
|
||||
|
||||
# check if agent got stuck before taking any action
|
||||
if self._is_stuck():
|
||||
# This need to go BEFORE report_error to sync metrics
|
||||
self.event_stream.add_event(
|
||||
FatalErrorObservation('Agent got stuck in a loop'),
|
||||
EventSource.ENVIRONMENT,
|
||||
)
|
||||
await self._react_to_exception(RuntimeError('Agent got stuck in a loop'))
|
||||
return
|
||||
|
||||
if self.delegate is not None:
|
||||
@@ -465,15 +431,12 @@ class AgentController:
|
||||
if action is None:
|
||||
raise LLMNoActionError('No action was returned')
|
||||
except (LLMMalformedActionError, LLMNoActionError, LLMResponseError) as e:
|
||||
# report to the user
|
||||
# and send the underlying exception to the LLM for self-correction
|
||||
await self.report_error(str(e))
|
||||
return
|
||||
# FIXME: more graceful handling of litellm.exceptions.ContextWindowExceededError
|
||||
# e.g. try to condense the memory and try again
|
||||
except litellm.exceptions.ContextWindowExceededError as e:
|
||||
self.state.last_error = str(e)
|
||||
await self.set_agent_state_to(AgentState.ERROR)
|
||||
self.event_stream.add_event(
|
||||
ErrorObservation(
|
||||
content=str(e),
|
||||
),
|
||||
EventSource.AGENT,
|
||||
)
|
||||
return
|
||||
|
||||
if action.runnable:
|
||||
@@ -495,6 +458,7 @@ class AgentController:
|
||||
self.event_stream.add_event(action, EventSource.AGENT)
|
||||
|
||||
await self.update_state_after_step()
|
||||
|
||||
self.log('debug', str(action), extra={'msg_type': 'ACTION'})
|
||||
|
||||
async def _delegate_step(self):
|
||||
@@ -524,7 +488,10 @@ class AgentController:
|
||||
self.delegate = None
|
||||
self.delegateAction = None
|
||||
|
||||
await self.report_error('Delegator agent encountered an error')
|
||||
self.event_stream.add_event(
|
||||
ErrorObservation('Delegate agent encountered an error'),
|
||||
EventSource.AGENT,
|
||||
)
|
||||
elif delegate_state in (AgentState.FINISHED, AgentState.REJECTED):
|
||||
self.log('debug', 'Delegate agent has finished execution')
|
||||
# retrieve delegate result
|
||||
@@ -571,21 +538,18 @@ class AgentController:
|
||||
else:
|
||||
self.state.traffic_control_state = TrafficControlState.THROTTLING
|
||||
if self.headless_mode:
|
||||
# This need to go BEFORE report_error to sync metrics
|
||||
await self.set_agent_state_to(AgentState.ERROR)
|
||||
# set to ERROR state if running in headless mode
|
||||
# since user cannot resume on the web interface
|
||||
await self.report_error(
|
||||
f'Agent reached maximum {limit_type} in headless mode, task stopped. '
|
||||
e = RuntimeError(
|
||||
f'Agent reached maximum {limit_type} in headless mode. '
|
||||
f'Current {limit_type}: {current_value:.2f}, max {limit_type}: {max_value:.2f}'
|
||||
)
|
||||
await self._react_to_exception(e)
|
||||
else:
|
||||
await self.set_agent_state_to(AgentState.PAUSED)
|
||||
await self.report_error(
|
||||
f'Agent reached maximum {limit_type}, task paused. '
|
||||
e = RuntimeError(
|
||||
f'Agent reached maximum {limit_type}. '
|
||||
f'Current {limit_type}: {current_value:.2f}, max {limit_type}: {max_value:.2f}. '
|
||||
f'{TRAFFIC_CONTROL_REMINDER}'
|
||||
)
|
||||
# FIXME: this isn't really an exception--we should have a different path
|
||||
await self._react_to_exception(e)
|
||||
stop_step = True
|
||||
return stop_step
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from openhands.events.action import (
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.action.agent import AgentFinishAction
|
||||
from openhands.events.observation import ErrorObservation
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.memory.history import ShortTermHistory
|
||||
from openhands.storage.files import FileStore
|
||||
@@ -80,7 +81,6 @@ class State:
|
||||
history: ShortTermHistory = field(default_factory=ShortTermHistory)
|
||||
inputs: dict = field(default_factory=dict)
|
||||
outputs: dict = field(default_factory=dict)
|
||||
last_error: str | None = None
|
||||
agent_state: AgentState = AgentState.LOADING
|
||||
resume_state: AgentState | None = None
|
||||
traffic_control_state: TrafficControlState = TrafficControlState.NORMAL
|
||||
@@ -97,6 +97,7 @@ class State:
|
||||
# NOTE: This will never be used by the controller, but it can be used by different
|
||||
# evaluation tasks to store extra data needed to track the progress/state of the task.
|
||||
extra_data: dict[str, Any] = field(default_factory=dict)
|
||||
last_error: str = ''
|
||||
|
||||
def save_to_session(self, sid: str, file_store: FileStore):
|
||||
pickled = pickle.dumps(self)
|
||||
@@ -124,9 +125,6 @@ class State:
|
||||
else:
|
||||
state.resume_state = None
|
||||
|
||||
# don't carry last_error anymore after restore
|
||||
state.last_error = None
|
||||
|
||||
# first state after restore
|
||||
state.agent_state = AgentState.LOADING
|
||||
return state
|
||||
@@ -151,11 +149,9 @@ class State:
|
||||
if not hasattr(self, 'history'):
|
||||
self.history = ShortTermHistory()
|
||||
|
||||
# restore the relevant data in history from the state
|
||||
self.history.start_id = self.start_id
|
||||
self.history.end_id = self.end_id
|
||||
|
||||
# remove the restored data from the state if any
|
||||
|
||||
def get_current_user_intent(self):
|
||||
"""Returns the latest user message and image(if provided) that appears after a FinishAction, or the first (the task) if nothing was finished yet."""
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from typing import Type
|
||||
|
||||
from termcolor import colored
|
||||
@@ -13,6 +14,7 @@ from openhands.core.config import (
|
||||
load_app_config,
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.loop import run_agent_until_done
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import EventSource, EventStream, EventStreamSubscriber
|
||||
from openhands.events.action import (
|
||||
@@ -114,7 +116,6 @@ async def main():
|
||||
sid=sid,
|
||||
plugins=agent_cls.sandbox_plugins,
|
||||
)
|
||||
await runtime.connect()
|
||||
|
||||
controller = AgentController(
|
||||
agent=agent,
|
||||
@@ -124,11 +125,14 @@ async def main():
|
||||
event_stream=event_stream,
|
||||
)
|
||||
|
||||
if controller is not None:
|
||||
controller.agent_task = asyncio.create_task(controller.start_step_loop())
|
||||
|
||||
async def prompt_for_next_task():
|
||||
next_message = input('How can I help? >> ')
|
||||
# Run input() in a thread pool to avoid blocking the event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
next_message = await loop.run_in_executor(
|
||||
None, lambda: input('How can I help? >> ')
|
||||
)
|
||||
if not next_message.strip():
|
||||
await prompt_for_next_task()
|
||||
if next_message == 'exit':
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.STOPPED), EventSource.ENVIRONMENT
|
||||
@@ -140,31 +144,45 @@ async def main():
|
||||
async def on_event(event: Event):
|
||||
display_event(event)
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
if event.agent_state == AgentState.ERROR:
|
||||
print('An error occurred. Please try again.')
|
||||
if event.agent_state in [
|
||||
AgentState.AWAITING_USER_INPUT,
|
||||
AgentState.FINISHED,
|
||||
AgentState.ERROR,
|
||||
]:
|
||||
await prompt_for_next_task()
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event)
|
||||
|
||||
await prompt_for_next_task()
|
||||
await runtime.connect()
|
||||
|
||||
while controller.state.agent_state not in [
|
||||
AgentState.STOPPED,
|
||||
]:
|
||||
await asyncio.sleep(1) # Give back control for a tick, so the agent can run
|
||||
asyncio.create_task(prompt_for_next_task())
|
||||
|
||||
print('Exiting...')
|
||||
await controller.close()
|
||||
await run_agent_until_done(
|
||||
controller, runtime, [AgentState.STOPPED, AgentState.ERROR]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(main())
|
||||
except KeyboardInterrupt:
|
||||
print('Received keyboard interrupt, shutting down...')
|
||||
except ConnectionRefusedError as e:
|
||||
print(f'Connection refused: {e}')
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f'An error occurred: {e}')
|
||||
sys.exit(1)
|
||||
finally:
|
||||
pass
|
||||
try:
|
||||
# Cancel all running tasks
|
||||
pending = asyncio.all_tasks(loop)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
# Wait for all tasks to complete with a timeout
|
||||
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
|
||||
loop.close()
|
||||
except Exception as e:
|
||||
print(f'Error during cleanup: {e}')
|
||||
sys.exit(1)
|
||||
|
||||
50
openhands/core/loop.py
Normal file
50
openhands/core/loop.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import asyncio
|
||||
|
||||
from openhands.controller import AgentController
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.runtime.base import Runtime
|
||||
|
||||
|
||||
async def run_agent_until_done(
|
||||
controller: AgentController,
|
||||
runtime: Runtime,
|
||||
end_states: list[AgentState],
|
||||
):
|
||||
"""
|
||||
run_agent_until_done takes a controller and a runtime, and will run
|
||||
the agent until it reaches a terminal state.
|
||||
Note that runtime must be connected before being passed in here.
|
||||
"""
|
||||
controller.agent_task = asyncio.create_task(controller.start_step_loop())
|
||||
|
||||
def status_callback(msg_type, msg_id, msg):
|
||||
if msg_type == 'error':
|
||||
logger.error(msg)
|
||||
if controller:
|
||||
controller.state.last_error = msg
|
||||
asyncio.create_task(controller.set_agent_state_to(AgentState.ERROR))
|
||||
else:
|
||||
logger.info(msg)
|
||||
|
||||
if hasattr(runtime, 'status_callback') and runtime.status_callback:
|
||||
raise ValueError(
|
||||
'Runtime status_callback was set, but run_agent_until_done will override it'
|
||||
)
|
||||
if hasattr(controller, 'status_callback') and controller.status_callback:
|
||||
raise ValueError(
|
||||
'Controller status_callback was set, but run_agent_until_done will override it'
|
||||
)
|
||||
|
||||
runtime.status_callback = status_callback
|
||||
controller.status_callback = status_callback
|
||||
|
||||
while controller.state.agent_state not in end_states:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if not controller.agent_task.done():
|
||||
controller.agent_task.cancel()
|
||||
try:
|
||||
await controller.agent_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
@@ -17,6 +17,7 @@ from openhands.core.config import (
|
||||
parse_arguments,
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.loop import run_agent_until_done
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import EventSource, EventStream, EventStreamSubscriber
|
||||
from openhands.events.action import MessageAction
|
||||
@@ -122,7 +123,6 @@ async def run_controller(
|
||||
|
||||
if runtime is None:
|
||||
runtime = create_runtime(config, sid=sid)
|
||||
await runtime.connect()
|
||||
|
||||
event_stream = runtime.event_stream
|
||||
# restore cli session if enabled
|
||||
@@ -147,9 +147,6 @@ async def run_controller(
|
||||
headless_mode=headless_mode,
|
||||
)
|
||||
|
||||
if controller is not None:
|
||||
controller.agent_task = asyncio.create_task(controller.start_step_loop())
|
||||
|
||||
assert isinstance(
|
||||
initial_user_action, Action
|
||||
), f'initial user actions must be an Action, got {type(initial_user_action)}'
|
||||
@@ -188,22 +185,27 @@ async def run_controller(
|
||||
event_stream.add_event(action, EventSource.USER)
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event)
|
||||
while controller.state.agent_state not in [
|
||||
|
||||
await runtime.connect()
|
||||
|
||||
end_states = [
|
||||
AgentState.FINISHED,
|
||||
AgentState.REJECTED,
|
||||
AgentState.ERROR,
|
||||
AgentState.PAUSED,
|
||||
AgentState.STOPPED,
|
||||
]:
|
||||
await asyncio.sleep(1) # Give back control for a tick, so the agent can run
|
||||
]
|
||||
|
||||
try:
|
||||
await run_agent_until_done(controller, runtime, end_states)
|
||||
except Exception as e:
|
||||
logger.error(f'Exception in main loop: {e}')
|
||||
|
||||
# save session when we're about to close
|
||||
if config.enable_cli_session:
|
||||
end_state = controller.get_state()
|
||||
end_state.save_to_session(event_stream.sid, event_stream.file_store)
|
||||
|
||||
# close when done
|
||||
await controller.close()
|
||||
state = controller.get_state()
|
||||
|
||||
# save trajectories if applicable
|
||||
|
||||
@@ -6,7 +6,7 @@ from openhands.events.observation.commands import (
|
||||
)
|
||||
from openhands.events.observation.delegate import AgentDelegateObservation
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.observation.error import ErrorObservation, FatalErrorObservation
|
||||
from openhands.events.observation.error import ErrorObservation
|
||||
from openhands.events.observation.files import (
|
||||
FileEditObservation,
|
||||
FileReadObservation,
|
||||
@@ -26,7 +26,6 @@ __all__ = [
|
||||
'FileWriteObservation',
|
||||
'FileEditObservation',
|
||||
'ErrorObservation',
|
||||
'FatalErrorObservation',
|
||||
'AgentStateChangedObservation',
|
||||
'AgentDelegateObservation',
|
||||
'SuccessObservation',
|
||||
|
||||
@@ -13,6 +13,7 @@ class ErrorObservation(Observation):
|
||||
"""
|
||||
|
||||
observation: str = ObservationType.ERROR
|
||||
error_id: str = ''
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
@@ -20,17 +21,3 @@ class ErrorObservation(Observation):
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'**ErrorObservation**\n{self.content}'
|
||||
|
||||
|
||||
@dataclass
|
||||
class FatalErrorObservation(Observation):
|
||||
"""This data class represents a fatal error encountered by the agent.
|
||||
|
||||
This is the type of error that LLM CANNOT recover from, and the agent controller should stop the execution and report the error to the user.
|
||||
E.g., Remote runtime action execution failure: 503 Server Error: Service Unavailable for url OR 404 Not Found.
|
||||
"""
|
||||
|
||||
observation: str = ObservationType.ERROR
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'**FatalErrorObservation**\n{self.content}'
|
||||
|
||||
@@ -152,12 +152,16 @@ class EventStream:
|
||||
|
||||
def add_event(self, event: Event, source: EventSource):
|
||||
try:
|
||||
asyncio.get_running_loop().create_task(self.async_add_event(event, source))
|
||||
asyncio.get_running_loop().create_task(self._async_add_event(event, source))
|
||||
except RuntimeError:
|
||||
# No event loop running...
|
||||
asyncio.run(self.async_add_event(event, source))
|
||||
asyncio.run(self._async_add_event(event, source))
|
||||
|
||||
async def async_add_event(self, event: Event, source: EventSource):
|
||||
async def _async_add_event(self, event: Event, source: EventSource):
|
||||
if hasattr(event, '_id') and event.id is not None:
|
||||
raise ValueError(
|
||||
'Event already has an ID. It was probably added back to the EventStream from inside a handler, trigging a loop.'
|
||||
)
|
||||
with self._lock:
|
||||
event._id = self._cur_id # type: ignore [attr-defined]
|
||||
self._cur_id += 1
|
||||
|
||||
@@ -12,7 +12,6 @@ from openhands.events.event import Event, EventSource
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.events.observation.delegate import AgentDelegateObservation
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.observation.error import FatalErrorObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.events.stream import EventStream
|
||||
@@ -34,7 +33,6 @@ class ShortTermHistory(list[Event]):
|
||||
NullObservation,
|
||||
ChangeAgentStateAction,
|
||||
AgentStateChangedObservation,
|
||||
FatalErrorObservation,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -37,7 +37,6 @@ from openhands.events.action import (
|
||||
from openhands.events.observation import (
|
||||
CmdOutputObservation,
|
||||
ErrorObservation,
|
||||
FatalErrorObservation,
|
||||
FileReadObservation,
|
||||
FileWriteObservation,
|
||||
IPythonRunCellObservation,
|
||||
@@ -168,7 +167,7 @@ class ActionExecutor:
|
||||
|
||||
async def run(
|
||||
self, action: CmdRunAction
|
||||
) -> CmdOutputObservation | FatalErrorObservation:
|
||||
) -> CmdOutputObservation | ErrorObservation:
|
||||
return self.bash_session.run(action)
|
||||
|
||||
async def run_ipython(self, action: IPythonRunCellAction) -> Observation:
|
||||
|
||||
@@ -5,6 +5,8 @@ import os
|
||||
from abc import abstractmethod
|
||||
from typing import Callable
|
||||
|
||||
from requests.exceptions import ConnectionError
|
||||
|
||||
from openhands.core.config import AppConfig, SandboxConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events import EventSource, EventStream, EventStreamSubscriber
|
||||
@@ -31,6 +33,22 @@ from openhands.runtime.plugins import JupyterRequirement, PluginRequirement
|
||||
from openhands.runtime.utils.edit import FileEditRuntimeMixin
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
STATUS_MESSAGES = {
|
||||
'STATUS$STARTING_RUNTIME': 'Starting runtime...',
|
||||
'STATUS$STARTING_CONTAINER': 'Starting container...',
|
||||
'STATUS$PREPARING_CONTAINER': 'Preparing container...',
|
||||
'STATUS$CONTAINER_STARTED': 'Container started.',
|
||||
'STATUS$WAITING_FOR_CLIENT': 'Waiting for client...',
|
||||
}
|
||||
|
||||
|
||||
class RuntimeNotReadyError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class RuntimeDisconnectedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _default_env_vars(sandbox_config: SandboxConfig) -> dict[str, str]:
|
||||
ret = {}
|
||||
@@ -54,6 +72,7 @@ class Runtime(FileEditRuntimeMixin):
|
||||
config: AppConfig
|
||||
initial_env_vars: dict[str, str]
|
||||
attach_to_existing: bool
|
||||
status_callback: Callable | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -62,14 +81,14 @@ class Runtime(FileEditRuntimeMixin):
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
env_vars: dict[str, str] | None = None,
|
||||
status_message_callback: Callable | None = None,
|
||||
status_callback: Callable | None = None,
|
||||
attach_to_existing: bool = False,
|
||||
):
|
||||
self.sid = sid
|
||||
self.event_stream = event_stream
|
||||
self.event_stream.subscribe(EventStreamSubscriber.RUNTIME, self.on_event)
|
||||
self.plugins = plugins if plugins is not None and len(plugins) > 0 else []
|
||||
self.status_message_callback = status_message_callback
|
||||
self.status_callback = status_callback
|
||||
self.attach_to_existing = attach_to_existing
|
||||
|
||||
self.config = copy.deepcopy(config)
|
||||
@@ -95,7 +114,17 @@ class Runtime(FileEditRuntimeMixin):
|
||||
|
||||
def log(self, level: str, message: str) -> None:
|
||||
message = f'[runtime {self.sid}] {message}'
|
||||
getattr(logger, level)(message)
|
||||
getattr(logger, level)(message, stacklevel=2)
|
||||
|
||||
def send_status_message(self, message_id: str):
|
||||
"""Sends a status message if the callback function was provided."""
|
||||
if self.status_callback:
|
||||
msg = STATUS_MESSAGES.get(message_id, '')
|
||||
self.status_callback('info', message_id, msg)
|
||||
|
||||
def send_error_message(self, message_id: str, message: str):
|
||||
if self.status_callback:
|
||||
self.status_callback('error', message_id, message)
|
||||
|
||||
# ====================================================================
|
||||
|
||||
@@ -131,15 +160,28 @@ class Runtime(FileEditRuntimeMixin):
|
||||
if event.timeout is None:
|
||||
event.timeout = self.config.sandbox.timeout
|
||||
assert event.timeout is not None
|
||||
observation: Observation = await call_sync_from_async(
|
||||
self.run_action, event
|
||||
)
|
||||
try:
|
||||
observation: Observation = await call_sync_from_async(
|
||||
self.run_action, event
|
||||
)
|
||||
except Exception as e:
|
||||
err_id = ''
|
||||
if isinstance(e, ConnectionError) or isinstance(
|
||||
e, RuntimeDisconnectedError
|
||||
):
|
||||
err_id = 'STATUS$ERROR_RUNTIME_DISCONNECTED'
|
||||
self.log('error', f'Unexpected error while running action {e}')
|
||||
self.log('error', f'Problematic action: {str(event)}')
|
||||
self.send_error_message(err_id, str(e))
|
||||
self.close()
|
||||
return
|
||||
|
||||
observation._cause = event.id # type: ignore[attr-defined]
|
||||
observation.tool_call_metadata = event.tool_call_metadata
|
||||
|
||||
# this might be unnecessary, since source should be set by the event stream when we're here
|
||||
source = event.source if event.source else EventSource.AGENT
|
||||
await self.event_stream.async_add_event(observation, source) # type: ignore[arg-type]
|
||||
self.event_stream.add_event(observation, source) # type: ignore[arg-type]
|
||||
|
||||
def run_action(self, action: Action) -> Observation:
|
||||
"""Run an action and return the resulting observation.
|
||||
|
||||
@@ -7,7 +7,7 @@ import requests
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.runtime.builder import RuntimeBuilder
|
||||
from openhands.runtime.utils.request import is_429_error, send_request_with_retry
|
||||
from openhands.runtime.utils.request import send_request
|
||||
from openhands.runtime.utils.shutdown_listener import (
|
||||
should_continue,
|
||||
sleep_if_should_continue,
|
||||
@@ -45,18 +45,21 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
|
||||
files.append(('tags', (None, tag)))
|
||||
|
||||
# Send the POST request to /build (Begins the build process)
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
'POST',
|
||||
f'{self.api_url}/build',
|
||||
files=files,
|
||||
timeout=30,
|
||||
retry_fns=[is_429_error],
|
||||
)
|
||||
|
||||
if response.status_code != 202:
|
||||
logger.error(f'Build initiation failed: {response.text}')
|
||||
raise RuntimeError(f'Build initiation failed: {response.text}')
|
||||
try:
|
||||
response = send_request(
|
||||
self.session,
|
||||
'POST',
|
||||
f'{self.api_url}/build',
|
||||
files=files,
|
||||
timeout=30,
|
||||
)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response.status_code == 429:
|
||||
logger.warning('Build was rate limited. Retrying in 30 seconds.')
|
||||
time.sleep(30)
|
||||
return self.build(path, tags, platform)
|
||||
else:
|
||||
raise e
|
||||
|
||||
build_data = response.json()
|
||||
build_id = build_data['build_id']
|
||||
@@ -70,12 +73,11 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
|
||||
logger.error('Build timed out after 30 minutes')
|
||||
raise RuntimeError('Build timed out after 30 minutes')
|
||||
|
||||
status_response = send_request_with_retry(
|
||||
status_response = send_request(
|
||||
self.session,
|
||||
'GET',
|
||||
f'{self.api_url}/build_status',
|
||||
params={'build_id': build_id},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
if status_response.status_code != 200:
|
||||
@@ -112,12 +114,11 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
|
||||
def image_exists(self, image_name: str, pull_from_repo: bool = True) -> bool:
|
||||
"""Checks if an image exists in the remote registry using the /image_exists endpoint."""
|
||||
params = {'image': image_name}
|
||||
response = send_request_with_retry(
|
||||
response = send_request(
|
||||
self.session,
|
||||
'GET',
|
||||
f'{self.api_url}/image_exists',
|
||||
params=params,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
|
||||
@@ -27,14 +27,14 @@ class E2BRuntime(Runtime):
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
sandbox: E2BSandbox | None = None,
|
||||
status_message_callback: Optional[Callable] = None,
|
||||
status_callback: Optional[Callable] = None,
|
||||
):
|
||||
super().__init__(
|
||||
config,
|
||||
event_stream,
|
||||
sid,
|
||||
plugins,
|
||||
status_message_callback=status_message_callback,
|
||||
status_callback=status_callback,
|
||||
)
|
||||
if sandbox is None:
|
||||
self.sandbox = E2BSandbox()
|
||||
|
||||
@@ -25,7 +25,7 @@ from openhands.events.action import (
|
||||
)
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.observation import (
|
||||
FatalErrorObservation,
|
||||
ErrorObservation,
|
||||
NullObservation,
|
||||
Observation,
|
||||
UserRejectObservation,
|
||||
@@ -36,8 +36,9 @@ from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.builder import DockerRuntimeBuilder
|
||||
from openhands.runtime.plugins import PluginRequirement
|
||||
from openhands.runtime.utils import find_available_tcp_port
|
||||
from openhands.runtime.utils.request import send_request_with_retry
|
||||
from openhands.runtime.utils.request import send_request
|
||||
from openhands.runtime.utils.runtime_build import build_runtime_image
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.tenacity_stop import stop_if_should_exit
|
||||
|
||||
|
||||
@@ -123,7 +124,7 @@ class EventStreamRuntime(Runtime):
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
env_vars: dict[str, str] | None = None,
|
||||
status_message_callback: Callable | None = None,
|
||||
status_callback: Callable | None = None,
|
||||
attach_to_existing: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
@@ -132,7 +133,7 @@ class EventStreamRuntime(Runtime):
|
||||
sid,
|
||||
plugins,
|
||||
env_vars,
|
||||
status_message_callback,
|
||||
status_callback,
|
||||
attach_to_existing,
|
||||
)
|
||||
|
||||
@@ -143,7 +144,7 @@ class EventStreamRuntime(Runtime):
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
env_vars: dict[str, str] | None = None,
|
||||
status_message_callback: Callable | None = None,
|
||||
status_callback: Callable | None = None,
|
||||
attach_to_existing: bool = False,
|
||||
):
|
||||
self.config = config
|
||||
@@ -151,7 +152,7 @@ class EventStreamRuntime(Runtime):
|
||||
self._container_port = 30001 # initial dummy value
|
||||
self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}'
|
||||
self.session = requests.Session()
|
||||
self.status_message_callback = status_message_callback
|
||||
self.status_callback = status_callback
|
||||
|
||||
self.docker_client: docker.DockerClient = self._init_docker_client()
|
||||
self.base_container_image = self.config.sandbox.base_container_image
|
||||
@@ -181,7 +182,7 @@ class EventStreamRuntime(Runtime):
|
||||
sid,
|
||||
plugins,
|
||||
env_vars,
|
||||
status_message_callback,
|
||||
status_callback,
|
||||
attach_to_existing,
|
||||
)
|
||||
|
||||
@@ -205,21 +206,21 @@ class EventStreamRuntime(Runtime):
|
||||
self.log(
|
||||
'info', f'Starting runtime with image: {self.runtime_container_image}'
|
||||
)
|
||||
self._init_container()
|
||||
await call_sync_from_async(self._init_container)
|
||||
self.log('info', f'Container started: {self.container_name}')
|
||||
|
||||
else:
|
||||
self._attach_to_container()
|
||||
await call_sync_from_async(self._attach_to_container)
|
||||
|
||||
if not self.attach_to_existing:
|
||||
self.log('info', f'Waiting for client to become ready at {self.api_url}...')
|
||||
self.send_status_message('STATUS$WAITING_FOR_CLIENT')
|
||||
self._wait_until_alive()
|
||||
await call_sync_from_async(self._wait_until_alive)
|
||||
if not self.attach_to_existing:
|
||||
self.log('info', 'Runtime is ready.')
|
||||
|
||||
if not self.attach_to_existing:
|
||||
self.setup_initial_env()
|
||||
await call_sync_from_async(self.setup_initial_env)
|
||||
|
||||
self.log(
|
||||
'debug',
|
||||
@@ -238,82 +239,74 @@ class EventStreamRuntime(Runtime):
|
||||
)
|
||||
raise ex
|
||||
|
||||
@tenacity.retry(
|
||||
stop=tenacity.stop_after_attempt(5) | stop_if_should_exit(),
|
||||
wait=tenacity.wait_fixed(5),
|
||||
)
|
||||
def _init_container(self):
|
||||
try:
|
||||
self.log('debug', 'Preparing to start container...')
|
||||
self.send_status_message('STATUS$PREPARING_CONTAINER')
|
||||
plugin_arg = ''
|
||||
if self.plugins is not None and len(self.plugins) > 0:
|
||||
plugin_arg = (
|
||||
f'--plugins {" ".join([plugin.name for plugin in self.plugins])} '
|
||||
)
|
||||
|
||||
self._host_port = self._find_available_port()
|
||||
self._container_port = (
|
||||
self._host_port
|
||||
) # in future this might differ from host port
|
||||
self.api_url = (
|
||||
f'{self.config.sandbox.local_runtime_url}:{self._container_port}'
|
||||
self.log('debug', 'Preparing to start container...')
|
||||
self.send_status_message('STATUS$PREPARING_CONTAINER')
|
||||
plugin_arg = ''
|
||||
if self.plugins is not None and len(self.plugins) > 0:
|
||||
plugin_arg = (
|
||||
f'--plugins {" ".join([plugin.name for plugin in self.plugins])} '
|
||||
)
|
||||
|
||||
use_host_network = self.config.sandbox.use_host_network
|
||||
network_mode: str | None = 'host' if use_host_network else None
|
||||
port_mapping: dict[str, list[dict[str, str]]] | None = (
|
||||
None
|
||||
if use_host_network
|
||||
else {
|
||||
f'{self._container_port}/tcp': [{'HostPort': str(self._host_port)}]
|
||||
}
|
||||
)
|
||||
self._host_port = self._find_available_port()
|
||||
self._container_port = (
|
||||
self._host_port
|
||||
) # in future this might differ from host port
|
||||
self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}'
|
||||
|
||||
if use_host_network:
|
||||
self.log(
|
||||
'warn',
|
||||
'Using host network mode. If you are using MacOS, please make sure you have the latest version of Docker Desktop and enabled host network feature: https://docs.docker.com/network/drivers/host/#docker-desktop',
|
||||
)
|
||||
use_host_network = self.config.sandbox.use_host_network
|
||||
network_mode: str | None = 'host' if use_host_network else None
|
||||
port_mapping: dict[str, list[dict[str, str]]] | None = (
|
||||
None
|
||||
if use_host_network
|
||||
else {f'{self._container_port}/tcp': [{'HostPort': str(self._host_port)}]}
|
||||
)
|
||||
|
||||
# Combine environment variables
|
||||
environment = {
|
||||
'port': str(self._container_port),
|
||||
'PYTHONUNBUFFERED': 1,
|
||||
}
|
||||
if self.config.debug or DEBUG:
|
||||
environment['DEBUG'] = 'true'
|
||||
|
||||
self.log('debug', f'Workspace Base: {self.config.workspace_base}')
|
||||
if (
|
||||
self.config.workspace_mount_path is not None
|
||||
and self.config.workspace_mount_path_in_sandbox is not None
|
||||
):
|
||||
# e.g. result would be: {"/home/user/openhands/workspace": {'bind': "/workspace", 'mode': 'rw'}}
|
||||
volumes = {
|
||||
self.config.workspace_mount_path: {
|
||||
'bind': self.config.workspace_mount_path_in_sandbox,
|
||||
'mode': 'rw',
|
||||
}
|
||||
}
|
||||
logger.debug(f'Mount dir: {self.config.workspace_mount_path}')
|
||||
else:
|
||||
logger.debug(
|
||||
'Mount dir is not set, will not mount the workspace directory to the container'
|
||||
)
|
||||
volumes = None
|
||||
if use_host_network:
|
||||
self.log(
|
||||
'debug',
|
||||
f'Sandbox workspace: {self.config.workspace_mount_path_in_sandbox}',
|
||||
'warn',
|
||||
'Using host network mode. If you are using MacOS, please make sure you have the latest version of Docker Desktop and enabled host network feature: https://docs.docker.com/network/drivers/host/#docker-desktop',
|
||||
)
|
||||
|
||||
if self.config.sandbox.browsergym_eval_env is not None:
|
||||
browsergym_arg = (
|
||||
f'--browsergym-eval-env {self.config.sandbox.browsergym_eval_env}'
|
||||
)
|
||||
else:
|
||||
browsergym_arg = ''
|
||||
# Combine environment variables
|
||||
environment = {
|
||||
'port': str(self._container_port),
|
||||
'PYTHONUNBUFFERED': 1,
|
||||
}
|
||||
if self.config.debug or DEBUG:
|
||||
environment['DEBUG'] = 'true'
|
||||
|
||||
self.log('debug', f'Workspace Base: {self.config.workspace_base}')
|
||||
if (
|
||||
self.config.workspace_mount_path is not None
|
||||
and self.config.workspace_mount_path_in_sandbox is not None
|
||||
):
|
||||
# e.g. result would be: {"/home/user/openhands/workspace": {'bind': "/workspace", 'mode': 'rw'}}
|
||||
volumes = {
|
||||
self.config.workspace_mount_path: {
|
||||
'bind': self.config.workspace_mount_path_in_sandbox,
|
||||
'mode': 'rw',
|
||||
}
|
||||
}
|
||||
logger.debug(f'Mount dir: {self.config.workspace_mount_path}')
|
||||
else:
|
||||
logger.debug(
|
||||
'Mount dir is not set, will not mount the workspace directory to the container'
|
||||
)
|
||||
volumes = None
|
||||
self.log(
|
||||
'debug',
|
||||
f'Sandbox workspace: {self.config.workspace_mount_path_in_sandbox}',
|
||||
)
|
||||
|
||||
if self.config.sandbox.browsergym_eval_env is not None:
|
||||
browsergym_arg = (
|
||||
f'--browsergym-eval-env {self.config.sandbox.browsergym_eval_env}'
|
||||
)
|
||||
else:
|
||||
browsergym_arg = ''
|
||||
|
||||
try:
|
||||
self.container = self.docker_client.containers.run(
|
||||
self.runtime_container_image,
|
||||
command=(
|
||||
@@ -337,6 +330,21 @@ class EventStreamRuntime(Runtime):
|
||||
self.log_buffer = LogBuffer(self.container, self.log)
|
||||
self.log('debug', f'Container started. Server url: {self.api_url}')
|
||||
self.send_status_message('STATUS$CONTAINER_STARTED')
|
||||
except docker.errors.APIError as e:
|
||||
# check 409 error
|
||||
if '409' in str(e):
|
||||
self.log(
|
||||
'warning',
|
||||
f'Container {self.container_name} already exists. Removing...',
|
||||
)
|
||||
self._close_containers(rm_all_containers=True)
|
||||
return self._init_container()
|
||||
|
||||
else:
|
||||
self.log(
|
||||
'error',
|
||||
f'Error: Instance {self.container_name} FAILED to start container!\n',
|
||||
)
|
||||
except Exception as e:
|
||||
self.log(
|
||||
'error',
|
||||
@@ -384,27 +392,20 @@ class EventStreamRuntime(Runtime):
|
||||
|
||||
@tenacity.retry(
|
||||
stop=tenacity.stop_after_delay(120) | stop_if_should_exit(),
|
||||
wait=tenacity.wait_exponential(multiplier=2, min=1, max=20),
|
||||
reraise=(ConnectionRefusedError,),
|
||||
wait=tenacity.wait_fixed(2),
|
||||
)
|
||||
def _wait_until_alive(self):
|
||||
self._refresh_logs()
|
||||
if not self.log_buffer:
|
||||
raise RuntimeError('Runtime client is not ready.')
|
||||
|
||||
response = send_request_with_retry(
|
||||
send_request(
|
||||
self.session,
|
||||
'GET',
|
||||
f'{self.api_url}/alive',
|
||||
retry_exceptions=[ConnectionRefusedError],
|
||||
timeout=300, # 5 minutes gives the container time to be alive 🧟♂️
|
||||
timeout=5,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return
|
||||
else:
|
||||
msg = f'Action execution API is not alive. Response: {response}'
|
||||
self.log('error', msg)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
def close(self, rm_all_containers: bool = True):
|
||||
"""Closes the EventStreamRuntime and associated objects
|
||||
@@ -421,7 +422,9 @@ class EventStreamRuntime(Runtime):
|
||||
|
||||
if self.attach_to_existing:
|
||||
return
|
||||
self._close_containers(rm_all_containers)
|
||||
|
||||
def _close_containers(self, rm_all_containers: bool = True):
|
||||
try:
|
||||
containers = self.docker_client.containers.list(all=True)
|
||||
for container in containers:
|
||||
@@ -466,10 +469,11 @@ class EventStreamRuntime(Runtime):
|
||||
return NullObservation('')
|
||||
action_type = action.action # type: ignore[attr-defined]
|
||||
if action_type not in ACTION_TYPE_TO_CLASS:
|
||||
return FatalErrorObservation(f'Action {action_type} does not exist.')
|
||||
raise ValueError(f'Action {action_type} does not exist.')
|
||||
if not hasattr(self, action_type):
|
||||
return FatalErrorObservation(
|
||||
f'Action {action_type} is not supported in the current runtime.'
|
||||
return ErrorObservation(
|
||||
f'Action {action_type} is not supported in the current runtime.',
|
||||
error_id='AGENT_ERROR$BAD_ACTION',
|
||||
)
|
||||
if (
|
||||
getattr(action, 'confirmation_state', None)
|
||||
@@ -484,33 +488,21 @@ class EventStreamRuntime(Runtime):
|
||||
assert action.timeout is not None
|
||||
|
||||
try:
|
||||
response = send_request_with_retry(
|
||||
response = send_request(
|
||||
self.session,
|
||||
'POST',
|
||||
f'{self.api_url}/execute_action',
|
||||
json={'action': event_to_dict(action)},
|
||||
timeout=action.timeout,
|
||||
# wait a few more seconds to get the timeout error from client side
|
||||
timeout=action.timeout + 5,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
output = response.json()
|
||||
obs = observation_from_dict(output)
|
||||
obs._cause = action.id # type: ignore[attr-defined]
|
||||
else:
|
||||
self.log('debug', f'action: {action}')
|
||||
self.log('debug', f'response: {response}')
|
||||
error_message = response.text
|
||||
self.log('error', f'Error from server: {error_message}')
|
||||
obs = FatalErrorObservation(
|
||||
f'Action execution failed: {error_message}'
|
||||
)
|
||||
output = response.json()
|
||||
obs = observation_from_dict(output)
|
||||
obs._cause = action.id # type: ignore[attr-defined]
|
||||
except requests.Timeout:
|
||||
self.log('error', 'No response received within the timeout period.')
|
||||
obs = FatalErrorObservation(
|
||||
f'Action execution timed out after {action.timeout} seconds.'
|
||||
raise RuntimeError(
|
||||
f'Runtime failed to return execute_action before the requested timeout of {action.timeout}s'
|
||||
)
|
||||
except Exception as e:
|
||||
self.log('error', f'Error during action execution: {e}')
|
||||
obs = FatalErrorObservation(f'Action execution failed: {str(e)}')
|
||||
self._refresh_logs()
|
||||
return obs
|
||||
|
||||
@@ -567,7 +559,7 @@ class EventStreamRuntime(Runtime):
|
||||
|
||||
params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
|
||||
|
||||
response = send_request_with_retry(
|
||||
send_request(
|
||||
self.session,
|
||||
'POST',
|
||||
f'{self.api_url}/upload_file',
|
||||
@@ -575,11 +567,6 @@ class EventStreamRuntime(Runtime):
|
||||
params=params,
|
||||
timeout=300,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return
|
||||
else:
|
||||
error_message = response.text
|
||||
raise Exception(f'Copy operation failed: {error_message}')
|
||||
|
||||
except requests.Timeout:
|
||||
raise TimeoutError('Copy operation timed out')
|
||||
@@ -604,31 +591,25 @@ class EventStreamRuntime(Runtime):
|
||||
if path is not None:
|
||||
data['path'] = path
|
||||
|
||||
response = send_request_with_retry(
|
||||
response = send_request(
|
||||
self.session,
|
||||
'POST',
|
||||
f'{self.api_url}/list_files',
|
||||
json=data,
|
||||
timeout=30, # 30 seconds because the container should already be alive
|
||||
timeout=10,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
response_json = response.json()
|
||||
assert isinstance(response_json, list)
|
||||
return response_json
|
||||
else:
|
||||
error_message = response.text
|
||||
raise Exception(f'List files operation failed: {error_message}')
|
||||
response_json = response.json()
|
||||
assert isinstance(response_json, list)
|
||||
return response_json
|
||||
except requests.Timeout:
|
||||
raise TimeoutError('List files operation timed out')
|
||||
except Exception as e:
|
||||
raise RuntimeError(f'List files operation failed: {str(e)}')
|
||||
|
||||
def copy_from(self, path: str) -> bytes:
|
||||
"""Zip all files in the sandbox and return as a stream of bytes."""
|
||||
self._refresh_logs()
|
||||
try:
|
||||
params = {'path': path}
|
||||
response = send_request_with_retry(
|
||||
response = send_request(
|
||||
self.session,
|
||||
'GET',
|
||||
f'{self.api_url}/download_files',
|
||||
@@ -636,16 +617,10 @@ class EventStreamRuntime(Runtime):
|
||||
stream=True,
|
||||
timeout=30,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
data = response.content
|
||||
return data
|
||||
else:
|
||||
error_message = response.text
|
||||
raise Exception(f'Copy operation failed: {error_message}')
|
||||
data = response.content
|
||||
return data
|
||||
except requests.Timeout:
|
||||
raise TimeoutError('Copy operation timed out')
|
||||
except Exception as e:
|
||||
raise RuntimeError(f'Copy operation failed: {str(e)}')
|
||||
|
||||
def _is_port_in_use_docker(self, port):
|
||||
containers = self.docker_client.containers.list()
|
||||
@@ -663,8 +638,3 @@ class EventStreamRuntime(Runtime):
|
||||
return port
|
||||
# If no port is found after max_attempts, return the last tried port
|
||||
return port
|
||||
|
||||
def send_status_message(self, message: str):
|
||||
"""Sends a status message if the callback function was provided."""
|
||||
if self.status_message_callback:
|
||||
self.status_message_callback(message)
|
||||
|
||||
@@ -75,7 +75,7 @@ class ModalRuntime(EventStreamRuntime):
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
env_vars: dict[str, str] | None = None,
|
||||
status_message_callback: Callable | None = None,
|
||||
status_callback: Callable | None = None,
|
||||
attach_to_existing: bool = False,
|
||||
):
|
||||
assert config.modal_api_token_id, 'Modal API token id is required'
|
||||
@@ -102,7 +102,7 @@ class ModalRuntime(EventStreamRuntime):
|
||||
self.container_port = 3000
|
||||
|
||||
self.session = requests.Session()
|
||||
self.status_message_callback = status_message_callback
|
||||
self.status_callback = status_callback
|
||||
self.base_container_image_id = self.config.sandbox.base_container_image
|
||||
self.runtime_container_image_id = self.config.sandbox.runtime_container_image
|
||||
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
|
||||
@@ -122,7 +122,7 @@ class ModalRuntime(EventStreamRuntime):
|
||||
sid,
|
||||
plugins,
|
||||
env_vars,
|
||||
status_message_callback,
|
||||
status_callback,
|
||||
attach_to_existing,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
from zipfile import ZipFile
|
||||
|
||||
import requests
|
||||
from requests.exceptions import Timeout
|
||||
import tenacity
|
||||
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.events import EventStream
|
||||
@@ -21,22 +20,26 @@ from openhands.events.action import (
|
||||
)
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.observation import (
|
||||
FatalErrorObservation,
|
||||
ErrorObservation,
|
||||
NullObservation,
|
||||
Observation,
|
||||
)
|
||||
from openhands.events.serialization import event_to_dict, observation_from_dict
|
||||
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.base import (
|
||||
Runtime,
|
||||
RuntimeDisconnectedError,
|
||||
RuntimeNotReadyError,
|
||||
)
|
||||
from openhands.runtime.builder.remote import RemoteRuntimeBuilder
|
||||
from openhands.runtime.plugins import PluginRequirement
|
||||
from openhands.runtime.utils.command import get_remote_startup_command
|
||||
from openhands.runtime.utils.request import (
|
||||
is_404_error,
|
||||
is_503_error,
|
||||
send_request_with_retry,
|
||||
send_request,
|
||||
)
|
||||
from openhands.runtime.utils.runtime_build import build_runtime_image
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.tenacity_stop import stop_if_should_exit
|
||||
|
||||
|
||||
class RemoteRuntime(Runtime):
|
||||
@@ -51,31 +54,32 @@ class RemoteRuntime(Runtime):
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
env_vars: dict[str, str] | None = None,
|
||||
status_message_callback: Optional[Callable] = None,
|
||||
status_callback: Optional[Callable] = None,
|
||||
attach_to_existing: bool = False,
|
||||
):
|
||||
# We need to set session and action_semaphore before the __init__ below, or we get odd errors
|
||||
self.session = requests.Session()
|
||||
self.action_semaphore = threading.Semaphore(1)
|
||||
|
||||
super().__init__(
|
||||
config,
|
||||
event_stream,
|
||||
sid,
|
||||
plugins,
|
||||
env_vars,
|
||||
status_message_callback,
|
||||
status_callback,
|
||||
attach_to_existing,
|
||||
)
|
||||
|
||||
if self.config.sandbox.api_key is None:
|
||||
raise ValueError(
|
||||
'API key is required to use the remote runtime. '
|
||||
'Please set the API key in the config (config.toml) or as an environment variable (SANDBOX_API_KEY).'
|
||||
)
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update({'X-API-Key': self.config.sandbox.api_key})
|
||||
self.action_semaphore = threading.Semaphore(1)
|
||||
|
||||
if self.config.workspace_base is not None:
|
||||
self.log(
|
||||
'warning',
|
||||
'debug',
|
||||
'Setting workspace_base is not supported in the remote runtime.',
|
||||
)
|
||||
|
||||
@@ -86,9 +90,13 @@ class RemoteRuntime(Runtime):
|
||||
self.runtime_url: str | None = None
|
||||
|
||||
async def connect(self):
|
||||
self._start_or_attach_to_runtime()
|
||||
self._wait_until_alive()
|
||||
self.setup_initial_env()
|
||||
await call_sync_from_async(self._start_or_attach_to_runtime)
|
||||
try:
|
||||
await call_sync_from_async(self._wait_until_alive)
|
||||
except RuntimeNotReadyError:
|
||||
self.log('error', 'Runtime failed to start, timed out before ready')
|
||||
raise
|
||||
await call_sync_from_async(self.setup_initial_env)
|
||||
|
||||
def _start_or_attach_to_runtime(self):
|
||||
existing_runtime = self._check_existing_runtime()
|
||||
@@ -127,44 +135,40 @@ class RemoteRuntime(Runtime):
|
||||
|
||||
def _check_existing_runtime(self) -> bool:
|
||||
try:
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
response = self._send_request(
|
||||
'GET',
|
||||
f'{self.config.sandbox.remote_runtime_api_url}/runtime/{self.sid}',
|
||||
timeout=5,
|
||||
)
|
||||
except Exception as e:
|
||||
except requests.HTTPError as e:
|
||||
if e.response.status_code == 404:
|
||||
return False
|
||||
self.log('debug', f'Error while looking for remote runtime: {e}')
|
||||
return False
|
||||
raise
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
status = data.get('status')
|
||||
if status == 'running':
|
||||
self._parse_runtime_response(response)
|
||||
return True
|
||||
elif status == 'stopped':
|
||||
self.log('debug', 'Found existing remote runtime, but it is stopped')
|
||||
return False
|
||||
elif status == 'paused':
|
||||
self.log('debug', 'Found existing remote runtime, but it is paused')
|
||||
self._parse_runtime_response(response)
|
||||
self._resume_runtime()
|
||||
return True
|
||||
else:
|
||||
self.log('error', f'Invalid response from runtime API: {data}')
|
||||
return False
|
||||
data = response.json()
|
||||
status = data.get('status')
|
||||
if status == 'running':
|
||||
self._parse_runtime_response(response)
|
||||
return True
|
||||
elif status == 'stopped':
|
||||
self.log('debug', 'Found existing remote runtime, but it is stopped')
|
||||
return False
|
||||
elif status == 'paused':
|
||||
self.log('debug', 'Found existing remote runtime, but it is paused')
|
||||
self._parse_runtime_response(response)
|
||||
self._resume_runtime()
|
||||
return True
|
||||
else:
|
||||
self.log('debug', 'Could not find existing remote runtime')
|
||||
self.log('error', f'Invalid response from runtime API: {data}')
|
||||
return False
|
||||
|
||||
def _build_runtime(self):
|
||||
self.log('debug', f'Building RemoteRuntime config:\n{self.config}')
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
response = self._send_request(
|
||||
'GET',
|
||||
f'{self.config.sandbox.remote_runtime_api_url}/registry_prefix',
|
||||
timeout=30,
|
||||
timeout=10,
|
||||
)
|
||||
response_json = response.json()
|
||||
registry_prefix = response_json['registry_prefix']
|
||||
@@ -191,14 +195,13 @@ class RemoteRuntime(Runtime):
|
||||
force_rebuild=self.config.sandbox.force_rebuild_runtime,
|
||||
)
|
||||
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
response = self._send_request(
|
||||
'GET',
|
||||
f'{self.config.sandbox.remote_runtime_api_url}/image_exists',
|
||||
params={'image': self.container_image},
|
||||
timeout=30,
|
||||
timeout=10,
|
||||
)
|
||||
if response.status_code != 200 or not response.json()['exists']:
|
||||
if not response.json()['exists']:
|
||||
raise RuntimeError(f'Container image {self.container_image} does not exist')
|
||||
|
||||
def _start_runtime(self):
|
||||
@@ -228,17 +231,11 @@ class RemoteRuntime(Runtime):
|
||||
}
|
||||
|
||||
# Start the sandbox using the /start endpoint
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
response = self._send_request(
|
||||
'POST',
|
||||
f'{self.config.sandbox.remote_runtime_api_url}/start',
|
||||
json=start_request,
|
||||
timeout=300,
|
||||
)
|
||||
if response.status_code != 201:
|
||||
raise RuntimeError(
|
||||
f'[Runtime (ID={self.runtime_id})] Failed to start runtime: {response.text}'
|
||||
)
|
||||
self._parse_runtime_response(response)
|
||||
self.log(
|
||||
'debug',
|
||||
@@ -246,17 +243,12 @@ class RemoteRuntime(Runtime):
|
||||
)
|
||||
|
||||
def _resume_runtime(self):
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
self._send_request(
|
||||
'POST',
|
||||
f'{self.config.sandbox.remote_runtime_api_url}/resume',
|
||||
json={'runtime_id': self.runtime_id},
|
||||
timeout=30,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f'[Runtime (ID={self.runtime_id})] Failed to resume runtime: {response.text}'
|
||||
)
|
||||
self.log('debug', 'Runtime resumed.')
|
||||
|
||||
def _parse_runtime_response(self, response: requests.Response):
|
||||
@@ -268,72 +260,57 @@ class RemoteRuntime(Runtime):
|
||||
{'X-Session-API-Key': start_response['session_api_key']}
|
||||
)
|
||||
|
||||
@tenacity.retry(
|
||||
stop=tenacity.stop_after_delay(180) | stop_if_should_exit(),
|
||||
reraise=True,
|
||||
retry=tenacity.retry_if_exception_type(RuntimeNotReadyError),
|
||||
wait=tenacity.wait_fixed(2),
|
||||
)
|
||||
def _wait_until_alive(self):
|
||||
self.log('debug', f'Waiting for runtime to be alive at url: {self.runtime_url}')
|
||||
# send GET request to /runtime/<id>
|
||||
pod_running = False
|
||||
max_not_found_count = 12 # 2 minutes
|
||||
not_found_count = 0
|
||||
while not pod_running:
|
||||
runtime_info_response = send_request_with_retry(
|
||||
self.session,
|
||||
'GET',
|
||||
f'{self.config.sandbox.remote_runtime_api_url}/runtime/{self.runtime_id}',
|
||||
timeout=5,
|
||||
)
|
||||
if runtime_info_response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f'Failed to get runtime status: {runtime_info_response.status_code}. Response: {runtime_info_response.text}'
|
||||
)
|
||||
runtime_data = runtime_info_response.json()
|
||||
assert runtime_data['runtime_id'] == self.runtime_id
|
||||
pod_status = runtime_data['pod_status']
|
||||
self.log(
|
||||
'debug',
|
||||
f'Waiting for runtime pod to be active. Current status: {pod_status}',
|
||||
)
|
||||
if pod_status == 'Ready':
|
||||
pod_running = True
|
||||
break
|
||||
elif pod_status == 'Not Found' and not_found_count < max_not_found_count:
|
||||
not_found_count += 1
|
||||
self.log(
|
||||
'debug',
|
||||
f'Runtime pod not found. Count: {not_found_count} / {max_not_found_count}',
|
||||
)
|
||||
elif pod_status in ('Failed', 'Unknown', 'Not Found'):
|
||||
# clean up the runtime
|
||||
self.close()
|
||||
raise RuntimeError(
|
||||
f'Runtime (ID={self.runtime_id}) failed to start. Current status: {pod_status}'
|
||||
)
|
||||
# Pending otherwise - add proper sleep
|
||||
time.sleep(10)
|
||||
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
runtime_info_response = self._send_request(
|
||||
'GET',
|
||||
f'{self.runtime_url}/alive',
|
||||
# Retry 404 & 503 errors for the /alive endpoint
|
||||
# because the runtime might just be starting up
|
||||
# and have not registered the endpoint yet
|
||||
retry_fns=[is_404_error, is_503_error],
|
||||
# leave enough time for the runtime to start up
|
||||
timeout=600,
|
||||
f'{self.config.sandbox.remote_runtime_api_url}/runtime/{self.runtime_id}',
|
||||
)
|
||||
if response.status_code != 200:
|
||||
msg = f'Runtime (ID={self.runtime_id}) is not alive yet. Status: {response.status_code}.'
|
||||
self.log('warning', msg)
|
||||
raise RuntimeError(msg)
|
||||
runtime_data = runtime_info_response.json()
|
||||
assert 'runtime_id' in runtime_data
|
||||
assert runtime_data['runtime_id'] == self.runtime_id
|
||||
assert 'pod_status' in runtime_data
|
||||
pod_status = runtime_data['pod_status']
|
||||
if pod_status == 'Ready':
|
||||
try:
|
||||
self._send_request(
|
||||
'GET',
|
||||
f'{self.runtime_url}/alive',
|
||||
) # will raise exception if we don't get 200 back.
|
||||
except requests.HTTPError as e:
|
||||
self.log(
|
||||
'warning', f"Runtime /alive failed, but pod says it's ready: {e}"
|
||||
)
|
||||
raise RuntimeNotReadyError(
|
||||
f'Runtime /alive failed to respond with 200: {e}'
|
||||
)
|
||||
return
|
||||
if pod_status in ('Failed', 'Unknown', 'Not Found'):
|
||||
# clean up the runtime
|
||||
self.close()
|
||||
raise RuntimeError(
|
||||
f'Runtime (ID={self.runtime_id}) failed to start. Current status: {pod_status}'
|
||||
)
|
||||
|
||||
self.log(
|
||||
'debug',
|
||||
f'Waiting for runtime pod to be active. Current status: {pod_status}',
|
||||
)
|
||||
raise RuntimeNotReadyError()
|
||||
|
||||
def close(self, timeout: int = 10):
|
||||
if self.config.sandbox.keep_remote_runtime_alive or self.attach_to_existing:
|
||||
self.session.close()
|
||||
return
|
||||
if self.runtime_id:
|
||||
if self.runtime_id and self.session:
|
||||
try:
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
response = self._send_request(
|
||||
'POST',
|
||||
f'{self.config.sandbox.remote_runtime_api_url}/stop',
|
||||
json={'runtime_id': self.runtime_id},
|
||||
@@ -361,12 +338,11 @@ class RemoteRuntime(Runtime):
|
||||
return NullObservation('')
|
||||
action_type = action.action # type: ignore[attr-defined]
|
||||
if action_type not in ACTION_TYPE_TO_CLASS:
|
||||
return FatalErrorObservation(
|
||||
f'[Runtime (ID={self.runtime_id})] Action {action_type} does not exist.'
|
||||
)
|
||||
raise ValueError(f'Action {action_type} does not exist.')
|
||||
if not hasattr(self, action_type):
|
||||
return FatalErrorObservation(
|
||||
f'[Runtime (ID={self.runtime_id})] Action {action_type} is not supported in the current runtime.'
|
||||
return ErrorObservation(
|
||||
f'[Runtime (ID={self.runtime_id})] Action {action_type} is not supported in the current runtime.',
|
||||
error_id='AGENT_ERROR$BAD_ACTION',
|
||||
)
|
||||
|
||||
assert action.timeout is not None
|
||||
@@ -374,36 +350,37 @@ class RemoteRuntime(Runtime):
|
||||
try:
|
||||
request_body = {'action': event_to_dict(action)}
|
||||
self.log('debug', f'Request body: {request_body}')
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
response = self._send_request(
|
||||
'POST',
|
||||
f'{self.runtime_url}/execute_action',
|
||||
json=request_body,
|
||||
timeout=action.timeout,
|
||||
# wait a few more seconds to get the timeout error from client side
|
||||
timeout=action.timeout + 5,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
output = response.json()
|
||||
obs = observation_from_dict(output)
|
||||
obs._cause = action.id # type: ignore[attr-defined]
|
||||
return obs
|
||||
else:
|
||||
error_message = response.text
|
||||
self.log('error', f'Error from server: {error_message}')
|
||||
obs = FatalErrorObservation(
|
||||
f'Action execution failed: {error_message}'
|
||||
)
|
||||
except Timeout:
|
||||
self.log('error', 'No response received within the timeout period.')
|
||||
obs = FatalErrorObservation(
|
||||
f'[Runtime (ID={self.runtime_id})] Action execution timed out'
|
||||
)
|
||||
except Exception as e:
|
||||
self.log('error', f'Error during action execution: {e}')
|
||||
obs = FatalErrorObservation(
|
||||
f'[Runtime (ID={self.runtime_id})] Action execution failed: {str(e)}'
|
||||
output = response.json()
|
||||
obs = observation_from_dict(output)
|
||||
obs._cause = action.id # type: ignore[attr-defined]
|
||||
except requests.Timeout:
|
||||
raise RuntimeError(
|
||||
f'Runtime failed to return execute_action before the requested timeout of {action.timeout}s'
|
||||
)
|
||||
return obs
|
||||
|
||||
def _send_request(self, method, url, **kwargs):
|
||||
is_runtime_request = self.runtime_url and self.runtime_url in url
|
||||
try:
|
||||
return send_request(self.session, method, url, **kwargs)
|
||||
except requests.Timeout:
|
||||
self.log('error', 'No response received within the timeout period.')
|
||||
raise
|
||||
except requests.HTTPError as e:
|
||||
if is_runtime_request and e.response.status_code == 404:
|
||||
raise RuntimeDisconnectedError(
|
||||
f'404 error while connecting to {self.runtime_url}'
|
||||
)
|
||||
else:
|
||||
raise e
|
||||
|
||||
def run(self, action: CmdRunAction) -> Observation:
|
||||
return self.run_action(action)
|
||||
|
||||
@@ -450,32 +427,16 @@ class RemoteRuntime(Runtime):
|
||||
|
||||
params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
|
||||
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
response = self._send_request(
|
||||
'POST',
|
||||
f'{self.runtime_url}/upload_file',
|
||||
files=upload_data,
|
||||
params=params,
|
||||
timeout=300,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
self.log(
|
||||
'debug',
|
||||
f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}. Response: {response.text}',
|
||||
)
|
||||
return
|
||||
else:
|
||||
error_message = response.text
|
||||
raise Exception(
|
||||
f'[Runtime (ID={self.runtime_id})] Copy operation failed: {error_message}'
|
||||
)
|
||||
except TimeoutError:
|
||||
raise TimeoutError(
|
||||
f'[Runtime (ID={self.runtime_id})] Copy operation timed out'
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f'[Runtime (ID={self.runtime_id})] Copy operation failed: {str(e)}'
|
||||
self.log(
|
||||
'debug',
|
||||
f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}. Response: {response.text}',
|
||||
)
|
||||
finally:
|
||||
if recursive:
|
||||
@@ -485,64 +446,27 @@ class RemoteRuntime(Runtime):
|
||||
)
|
||||
|
||||
def list_files(self, path: str | None = None) -> list[str]:
|
||||
try:
|
||||
data = {}
|
||||
if path is not None:
|
||||
data['path'] = path
|
||||
data = {}
|
||||
if path is not None:
|
||||
data['path'] = path
|
||||
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
'POST',
|
||||
f'{self.runtime_url}/list_files',
|
||||
json=data,
|
||||
timeout=30,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
response_json = response.json()
|
||||
assert isinstance(response_json, list)
|
||||
return response_json
|
||||
else:
|
||||
error_message = response.text
|
||||
raise Exception(
|
||||
f'[Runtime (ID={self.runtime_id})] List files operation failed: {error_message}'
|
||||
)
|
||||
except TimeoutError:
|
||||
raise TimeoutError(
|
||||
f'[Runtime (ID={self.runtime_id})] List files operation timed out'
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f'[Runtime (ID={self.runtime_id})] List files operation failed: {str(e)}'
|
||||
)
|
||||
response = self._send_request(
|
||||
'POST',
|
||||
f'{self.runtime_url}/list_files',
|
||||
json=data,
|
||||
timeout=30,
|
||||
)
|
||||
response_json = response.json()
|
||||
assert isinstance(response_json, list)
|
||||
return response_json
|
||||
|
||||
def copy_from(self, path: str) -> bytes:
|
||||
"""Zip all files in the sandbox and return as a stream of bytes."""
|
||||
try:
|
||||
params = {'path': path}
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
'GET',
|
||||
f'{self.runtime_url}/download_files',
|
||||
params=params,
|
||||
timeout=30,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return response.content
|
||||
else:
|
||||
error_message = response.text
|
||||
raise Exception(
|
||||
f'[Runtime (ID={self.runtime_id})] Copy operation failed: {error_message}'
|
||||
)
|
||||
except requests.Timeout:
|
||||
raise TimeoutError(
|
||||
f'[Runtime (ID={self.runtime_id})] Copy operation timed out'
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f'[Runtime (ID={self.runtime_id})] Copy operation failed: {str(e)}'
|
||||
)
|
||||
|
||||
def send_status_message(self, message: str):
|
||||
"""Sends a status message if the callback function was provided."""
|
||||
if self.status_message_callback:
|
||||
self.status_message_callback(message)
|
||||
params = {'path': path}
|
||||
response = self._send_request(
|
||||
'GET',
|
||||
f'{self.runtime_url}/download_files',
|
||||
params=params,
|
||||
timeout=30,
|
||||
)
|
||||
return response.content
|
||||
|
||||
@@ -9,7 +9,7 @@ from openhands.events.action import CmdRunAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.observation import (
|
||||
CmdOutputObservation,
|
||||
FatalErrorObservation,
|
||||
ErrorObservation,
|
||||
)
|
||||
|
||||
SOFT_TIMEOUT_SECONDS = 5
|
||||
@@ -275,7 +275,7 @@ class BashSession:
|
||||
output += '\r\n' + bash_prompt
|
||||
return output, exit_code
|
||||
|
||||
def run(self, action: CmdRunAction) -> CmdOutputObservation | FatalErrorObservation:
|
||||
def run(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservation:
|
||||
try:
|
||||
assert (
|
||||
action.timeout is not None
|
||||
@@ -329,6 +329,6 @@ class BashSession:
|
||||
interpreter_details=python_interpreter,
|
||||
)
|
||||
except UnicodeDecodeError as e:
|
||||
return FatalErrorObservation(
|
||||
f'Runtime bash execution failed: Command output could not be decoded as utf-8. {str(e)}'
|
||||
return ErrorObservation(
|
||||
f'Runtime bash execution failed: Command output could not be decoded as utf-8. {str(e)}',
|
||||
)
|
||||
|
||||
@@ -13,7 +13,6 @@ from openhands.events.action import (
|
||||
)
|
||||
from openhands.events.observation import (
|
||||
ErrorObservation,
|
||||
FatalErrorObservation,
|
||||
FileEditObservation,
|
||||
FileReadObservation,
|
||||
FileWriteObservation,
|
||||
@@ -214,9 +213,7 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface):
|
||||
if isinstance(obs, ErrorObservation):
|
||||
return obs
|
||||
if not isinstance(obs, FileWriteObservation):
|
||||
return FatalErrorObservation(
|
||||
f'Fatal Runtime in editing: Expected FileWriteObservation, got {type(obs)}: {str(obs)}'
|
||||
)
|
||||
raise ValueError(f'Expected FileWriteObservation, got {type(obs)}: {str(obs)}')
|
||||
return FileEditObservation(
|
||||
content=get_diff('', action.content, action.path),
|
||||
path=action.path,
|
||||
@@ -225,9 +222,7 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface):
|
||||
new_content=action.content,
|
||||
)
|
||||
if not isinstance(obs, FileReadObservation):
|
||||
return FatalErrorObservation(
|
||||
f'Fatal Runtime in editing: Expected FileReadObservation, got {type(obs)}: {str(obs)}'
|
||||
)
|
||||
raise ValueError(f'Expected FileReadObservation, got {type(obs)}: {str(obs)}')
|
||||
|
||||
original_file_content = obs.content
|
||||
old_file_lines = original_file_content.split('\n')
|
||||
|
||||
@@ -1,22 +1,12 @@
|
||||
from typing import Any, Callable, Type
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from requests.exceptions import (
|
||||
ChunkedEncodingError,
|
||||
ConnectionError,
|
||||
)
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception,
|
||||
retry_if_exception_type,
|
||||
stop_after_delay,
|
||||
wait_exponential,
|
||||
)
|
||||
from urllib3.exceptions import IncompleteRead
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.utils.tenacity_stop import stop_if_should_exit
|
||||
|
||||
|
||||
def is_server_error(exception):
|
||||
return (
|
||||
@@ -60,37 +50,13 @@ DEFAULT_RETRY_EXCEPTIONS = [
|
||||
]
|
||||
|
||||
|
||||
def send_request_with_retry(
|
||||
def send_request(
|
||||
session: requests.Session,
|
||||
method: str,
|
||||
url: str,
|
||||
timeout: int,
|
||||
retry_exceptions: list[Type[Exception]] | None = None,
|
||||
retry_fns: list[Callable[[Exception], bool]] | None = None,
|
||||
timeout: int = 10,
|
||||
**kwargs: Any,
|
||||
) -> requests.Response:
|
||||
exceptions_to_catch = retry_exceptions or DEFAULT_RETRY_EXCEPTIONS
|
||||
retry_condition = retry_if_exception_type(
|
||||
tuple(exceptions_to_catch)
|
||||
) | retry_if_exception(is_502_error)
|
||||
if retry_fns is not None:
|
||||
for fn in retry_fns:
|
||||
retry_condition |= retry_if_exception(fn)
|
||||
# wait a few more seconds to get the timeout error from client side
|
||||
kwargs['timeout'] = timeout + 10
|
||||
|
||||
@retry(
|
||||
stop=stop_after_delay(timeout) | stop_if_should_exit(),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=20),
|
||||
retry=retry_condition,
|
||||
reraise=True,
|
||||
before_sleep=lambda retry_state: logger.debug(
|
||||
f'Retrying {method} request to {url} due to {retry_state.outcome.exception()}. Attempt {retry_state.attempt_number}'
|
||||
),
|
||||
)
|
||||
def _send_request_with_retry():
|
||||
response = session.request(method, url, **kwargs)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
return _send_request_with_retry()
|
||||
response = session.request(method, url, **kwargs)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
@@ -32,7 +32,12 @@ class AgentSession:
|
||||
_closed: bool = False
|
||||
loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
def __init__(self, sid: str, file_store: FileStore):
|
||||
def __init__(
|
||||
self,
|
||||
sid: str,
|
||||
file_store: FileStore,
|
||||
status_callback: Optional[Callable] = None,
|
||||
):
|
||||
"""Initializes a new instance of the Session class
|
||||
|
||||
Parameters:
|
||||
@@ -43,6 +48,7 @@ class AgentSession:
|
||||
self.sid = sid
|
||||
self.event_stream = EventStream(sid, file_store)
|
||||
self.file_store = file_store
|
||||
self._status_callback = status_callback
|
||||
|
||||
async def start(
|
||||
self,
|
||||
@@ -53,7 +59,6 @@ class AgentSession:
|
||||
max_budget_per_task: float | None = None,
|
||||
agent_to_llm_config: dict[str, LLMConfig] | None = None,
|
||||
agent_configs: dict[str, AgentConfig] | None = None,
|
||||
status_message_callback: Optional[Callable] = None,
|
||||
):
|
||||
"""Starts the Agent session
|
||||
Parameters:
|
||||
@@ -80,7 +85,6 @@ class AgentSession:
|
||||
max_budget_per_task,
|
||||
agent_to_llm_config,
|
||||
agent_configs,
|
||||
status_message_callback,
|
||||
)
|
||||
|
||||
def _start_thread(self, *args):
|
||||
@@ -99,14 +103,12 @@ class AgentSession:
|
||||
max_budget_per_task: float | None = None,
|
||||
agent_to_llm_config: dict[str, LLMConfig] | None = None,
|
||||
agent_configs: dict[str, AgentConfig] | None = None,
|
||||
status_message_callback: Optional[Callable] = None,
|
||||
):
|
||||
self._create_security_analyzer(config.security.security_analyzer)
|
||||
await self._create_runtime(
|
||||
runtime_name=runtime_name,
|
||||
config=config,
|
||||
agent=agent,
|
||||
status_message_callback=status_message_callback,
|
||||
)
|
||||
self._create_controller(
|
||||
agent,
|
||||
@@ -132,6 +134,10 @@ class AgentSession:
|
||||
|
||||
asyncio.get_event_loop().run_in_executor(None, inner_close)
|
||||
|
||||
async def stop_agent_loop_for_error(self):
|
||||
if self.controller is not None:
|
||||
await self.controller.set_agent_state_to(AgentState.ERROR)
|
||||
|
||||
async def _close(self):
|
||||
if self._closed:
|
||||
return
|
||||
@@ -162,7 +168,6 @@ class AgentSession:
|
||||
runtime_name: str,
|
||||
config: AppConfig,
|
||||
agent: Agent,
|
||||
status_message_callback: Optional[Callable] = None,
|
||||
):
|
||||
"""Creates a runtime instance
|
||||
|
||||
@@ -182,13 +187,17 @@ class AgentSession:
|
||||
event_stream=self.event_stream,
|
||||
sid=self.sid,
|
||||
plugins=agent.sandbox_plugins,
|
||||
status_message_callback=status_message_callback,
|
||||
status_callback=self._status_callback,
|
||||
)
|
||||
|
||||
try:
|
||||
await self.runtime.connect()
|
||||
except Exception as e:
|
||||
logger.error(f'Runtime initialization failed: {e}', exc_info=True)
|
||||
if self._status_callback:
|
||||
self._status_callback(
|
||||
'error', 'STATUS$ERROR_RUNTIME_DISCONNECTED', str(e)
|
||||
)
|
||||
raise
|
||||
|
||||
if self.runtime is not None:
|
||||
@@ -252,9 +261,8 @@ class AgentSession:
|
||||
agent_to_llm_config=agent_to_llm_config,
|
||||
agent_configs=agent_configs,
|
||||
confirmation_mode=confirmation_mode,
|
||||
# AgentSession is designed to communicate with the frontend, so we don't want to
|
||||
# run the agent in headless mode.
|
||||
headless_mode=False,
|
||||
status_callback=self._status_callback,
|
||||
)
|
||||
try:
|
||||
agent_state = State.restore_from_session(self.sid, self.file_store)
|
||||
|
||||
@@ -40,7 +40,9 @@ class Session:
|
||||
self.sid = sid
|
||||
self.websocket = ws
|
||||
self.last_active_ts = int(time.time())
|
||||
self.agent_session = AgentSession(sid, file_store)
|
||||
self.agent_session = AgentSession(
|
||||
sid, file_store, status_callback=self.queue_status_message
|
||||
)
|
||||
self.agent_session.event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER, self.on_event
|
||||
)
|
||||
@@ -115,7 +117,6 @@ class Session:
|
||||
max_budget_per_task=self.config.max_budget_per_task,
|
||||
agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
|
||||
agent_configs=self.config.get_agent_configs(),
|
||||
status_message_callback=self.queue_status_message,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f'Error creating controller: {e}')
|
||||
@@ -171,14 +172,6 @@ class Session:
|
||||
'Model does not support image upload, change to a different model or try without an image.'
|
||||
)
|
||||
return
|
||||
if self.loop:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._add_event(event, EventSource.USER), self.loop
|
||||
) # type: ignore
|
||||
else:
|
||||
raise RuntimeError('No event loop found')
|
||||
|
||||
async def _add_event(self, event, event_source):
|
||||
self.agent_session.event_stream.add_event(event, EventSource.USER)
|
||||
|
||||
async def send(self, data: dict[str, object]) -> bool:
|
||||
@@ -200,11 +193,17 @@ class Session:
|
||||
"""Sends an error message to the client."""
|
||||
return await self.send({'error': True, 'message': message})
|
||||
|
||||
async def send_status_message(self, message: str) -> bool:
|
||||
async def _send_status_message(self, msg_type: str, id: str, message: str) -> bool:
|
||||
"""Sends a status message to the client."""
|
||||
return await self.send({'status': message})
|
||||
if msg_type == 'error':
|
||||
await self.agent_session.stop_agent_loop_for_error()
|
||||
|
||||
def queue_status_message(self, message: str):
|
||||
return await self.send(
|
||||
{'status_update': True, 'type': msg_type, 'id': id, 'message': message}
|
||||
)
|
||||
|
||||
def queue_status_message(self, msg_type: str, id: str, message: str):
|
||||
"""Queues a status message to be sent asynchronously."""
|
||||
# Ensure the coroutine runs in the main event loop
|
||||
asyncio.run_coroutine_threadsafe(self.send_status_message(message), self.loop)
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._send_status_message(msg_type, id, message), self.loop
|
||||
)
|
||||
|
||||
231
tests/runtime/test_stress_remote_runtime.py
Normal file
231
tests/runtime/test_stress_remote_runtime.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""Bash-related tests for the EventStreamRuntime, which connects to the ActionExecutor running in the sandbox."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from conftest import TEST_IN_CI
|
||||
|
||||
from evaluation.utils.shared import (
|
||||
EvalException,
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
assert_and_raise,
|
||||
codeact_user_response,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
run_evaluation,
|
||||
)
|
||||
from openhands.agenthub import Agent
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import (
|
||||
AgentConfig,
|
||||
AppConfig,
|
||||
LLMConfig,
|
||||
SandboxConfig,
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.main import create_runtime, run_controller
|
||||
from openhands.events.action import CmdRunAction, MessageAction
|
||||
from openhands.events.observation import CmdOutputObservation
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.llm import LLM
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.utils.async_utils import call_async_from_sync
|
||||
|
||||
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
|
||||
'CodeActAgent': codeact_user_response,
|
||||
}
|
||||
|
||||
|
||||
def get_config(
|
||||
metadata: EvalMetadata,
|
||||
) -> AppConfig:
|
||||
assert (
|
||||
os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL') is not None
|
||||
), 'SANDBOX_REMOTE_RUNTIME_API_URL must be set.'
|
||||
assert (
|
||||
os.environ.get('ALLHANDS_API_KEY') is not None
|
||||
), 'ALLHANDS_API_KEY must be set.'
|
||||
config = AppConfig(
|
||||
default_agent=metadata.agent_class,
|
||||
run_as_openhands=False,
|
||||
max_iterations=metadata.max_iterations,
|
||||
runtime='remote',
|
||||
sandbox=SandboxConfig(
|
||||
base_container_image='python:3.11-bookworm',
|
||||
enable_auto_lint=True,
|
||||
use_host_network=False,
|
||||
# large enough timeout, since some testcases take very long to run
|
||||
timeout=300,
|
||||
api_key=os.environ.get('ALLHANDS_API_KEY', None),
|
||||
remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
|
||||
keep_remote_runtime_alive=False,
|
||||
),
|
||||
# do not mount workspace
|
||||
workspace_base=None,
|
||||
workspace_mount_path=None,
|
||||
)
|
||||
agent_config = AgentConfig(
|
||||
codeact_enable_jupyter=False,
|
||||
codeact_enable_browsing=False,
|
||||
codeact_enable_llm_editor=False,
|
||||
)
|
||||
config.set_agent_config(agent_config)
|
||||
return config
|
||||
|
||||
|
||||
def initialize_runtime(
|
||||
runtime: Runtime,
|
||||
):
|
||||
"""Initialize the runtime for the agent.
|
||||
|
||||
This function is called before the runtime is used to run the agent.
|
||||
"""
|
||||
logger.info('-' * 30)
|
||||
logger.info('BEGIN Runtime Initialization Fn')
|
||||
logger.info('-' * 30)
|
||||
obs: CmdOutputObservation
|
||||
|
||||
action = CmdRunAction(command="""export USER=$(whoami); echo USER=${USER} """)
|
||||
action.timeout = 600
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(obs.exit_code == 0, f'Failed to export USER: {str(obs)}')
|
||||
|
||||
action = CmdRunAction(command='mkdir -p /dummy_dir')
|
||||
action.timeout = 600
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(
|
||||
obs.exit_code == 0,
|
||||
f'Failed to create /dummy_dir: {str(obs)}',
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Construct the full path for the desired file name within the temporary directory
|
||||
temp_file_path = os.path.join(temp_dir, 'dummy_file')
|
||||
# Write to the file with the desired name within the temporary directory
|
||||
with open(temp_file_path, 'w') as f:
|
||||
f.write('dummy content')
|
||||
|
||||
# Copy the file to the desired location
|
||||
runtime.copy_to(temp_file_path, '/dummy_dir/')
|
||||
|
||||
logger.info('-' * 30)
|
||||
logger.info('END Runtime Initialization Fn')
|
||||
logger.info('-' * 30)
|
||||
|
||||
|
||||
def process_instance(
|
||||
instance: pd.Series,
|
||||
metadata: EvalMetadata,
|
||||
reset_logger: bool = True,
|
||||
) -> EvalOutput:
|
||||
config = get_config(metadata)
|
||||
|
||||
# Setup the logger properly, so you can run multi-processing to parallelize the evaluation
|
||||
if reset_logger:
|
||||
log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
|
||||
reset_logger_for_multiprocessing(logger, instance.instance_id, log_dir)
|
||||
else:
|
||||
logger.info(f'Starting evaluation for instance {instance.instance_id}.')
|
||||
|
||||
runtime = create_runtime(config)
|
||||
call_async_from_sync(runtime.connect)
|
||||
|
||||
try:
|
||||
initialize_runtime(runtime)
|
||||
|
||||
instruction = 'dummy instruction'
|
||||
agent = Agent.get_cls(metadata.agent_class)(
|
||||
llm=LLM(config=metadata.llm_config),
|
||||
config=config.get_agent_config(metadata.agent_class),
|
||||
)
|
||||
|
||||
def next_command(*args, **kwargs):
|
||||
return CmdRunAction(command='ls -lah')
|
||||
|
||||
agent.step = MagicMock(side_effect=next_command)
|
||||
|
||||
# Here's how you can run the agent (similar to the `main` function) and get the final task state
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
initial_user_action=MessageAction(content=instruction),
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
|
||||
metadata.agent_class
|
||||
],
|
||||
agent=agent,
|
||||
)
|
||||
)
|
||||
|
||||
# if fatal error, throw EvalError to trigger re-run
|
||||
if (
|
||||
state.last_error
|
||||
and 'fatal error during agent execution' in state.last_error
|
||||
and 'stuck in a loop' not in state.last_error
|
||||
):
|
||||
raise EvalException('Fatal error detected: ' + state.last_error)
|
||||
|
||||
finally:
|
||||
runtime.close()
|
||||
|
||||
test_result = {}
|
||||
if state is None:
|
||||
raise ValueError('State should not be None.')
|
||||
histories = [event_to_dict(event) for event in state.history.get_events()]
|
||||
metrics = state.metrics.get() if state.metrics else None
|
||||
|
||||
# Save the output
|
||||
output = EvalOutput(
|
||||
instance_id=instance.instance_id,
|
||||
instruction=instruction,
|
||||
instance=instance.to_dict(), # SWE Bench specific
|
||||
test_result=test_result,
|
||||
metadata=metadata,
|
||||
history=histories,
|
||||
metrics=metrics,
|
||||
error=state.last_error if state and state.last_error else None,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
TEST_IN_CI,
|
||||
reason='This test should only be run locally, not in CI.',
|
||||
)
|
||||
def test_stress_remote_runtime(n_eval_workers: int = 64):
|
||||
"""Mimic evaluation setting to test remote runtime in a multi-processing setting."""
|
||||
|
||||
llm_config = LLMConfig()
|
||||
metadata = make_metadata(
|
||||
llm_config,
|
||||
'dummy_dataset_descrption',
|
||||
'CodeActAgent',
|
||||
max_iterations=10,
|
||||
eval_note='dummy_eval_note',
|
||||
eval_output_dir='./dummy_eval_output_dir',
|
||||
details={},
|
||||
)
|
||||
|
||||
# generate 300 random dummy instances
|
||||
dummy_instance = pd.DataFrame(
|
||||
{
|
||||
'instance_id': [f'dummy_instance_{i}' for i in range(300)],
|
||||
}
|
||||
)
|
||||
|
||||
output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
|
||||
instances = prepare_dataset(
|
||||
dummy_instance, output_file, eval_n_limit=len(dummy_instance)
|
||||
)
|
||||
|
||||
run_evaluation(instances, metadata, output_file, n_eval_workers, process_instance)
|
||||
@@ -7,14 +7,12 @@ from openhands.controller.agent import Agent
|
||||
from openhands.controller.agent_controller import AgentController
|
||||
from openhands.controller.state.state import TrafficControlState
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.exceptions import LLMMalformedActionError
|
||||
from openhands.core.main import run_controller
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
|
||||
from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
|
||||
from openhands.events.observation import (
|
||||
ErrorObservation,
|
||||
FatalErrorObservation,
|
||||
)
|
||||
from openhands.events.serialization import event_to_dict
|
||||
from openhands.llm import LLM
|
||||
@@ -45,6 +43,11 @@ def mock_event_stream():
|
||||
return MagicMock(spec=EventStream)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_status_callback():
|
||||
return AsyncMock()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_agent_state(mock_agent, mock_event_stream):
|
||||
controller = AgentController(
|
||||
@@ -98,39 +101,19 @@ async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_error(mock_agent, mock_event_stream):
|
||||
async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_callback):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
status_callback=mock_status_callback,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
error_message = 'Test error'
|
||||
await controller.report_error(error_message)
|
||||
assert controller.state.last_error == error_message
|
||||
controller.event_stream.add_event.assert_called_once()
|
||||
await controller.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_with_exception(mock_agent, mock_event_stream):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
controller.state.agent_state = AgentState.RUNNING
|
||||
controller.report_error = AsyncMock()
|
||||
controller.agent.step.side_effect = LLMMalformedActionError('Malformed action')
|
||||
await controller._step()
|
||||
|
||||
# Verify that report_error was called with the correct error message
|
||||
controller.report_error.assert_called_once_with('Malformed action')
|
||||
await controller._react_to_exception(RuntimeError(error_message))
|
||||
controller.status_callback.assert_called_once()
|
||||
await controller.close()
|
||||
|
||||
|
||||
@@ -141,21 +124,24 @@ async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
|
||||
event_stream = EventStream(sid='test', file_store=file_store)
|
||||
|
||||
agent = MagicMock(spec=Agent)
|
||||
# a random message to send to the runtime
|
||||
event = CmdRunAction(command='ls')
|
||||
agent.step.return_value = event
|
||||
agent = MagicMock(spec=Agent)
|
||||
|
||||
def agent_step_fn(state):
|
||||
print(f'agent_step_fn received state: {state}')
|
||||
return CmdRunAction(command='ls')
|
||||
|
||||
agent.step = agent_step_fn
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = config.get_llm_config()
|
||||
|
||||
fatal_error_obs = FatalErrorObservation('Fatal error detected')
|
||||
fatal_error_obs._cause = event.id
|
||||
|
||||
runtime = MagicMock(spec=Runtime)
|
||||
|
||||
async def on_event(event: Event):
|
||||
if isinstance(event, CmdRunAction):
|
||||
await event_stream.async_add_event(fatal_error_obs, EventSource.USER)
|
||||
error_obs = ErrorObservation('You messed around with Jim')
|
||||
error_obs._cause = event.id
|
||||
event_stream.add_event(error_obs, EventSource.USER)
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
|
||||
runtime.event_stream = event_stream
|
||||
@@ -170,30 +156,23 @@ async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream):
|
||||
)
|
||||
print(f'state: {state}')
|
||||
print(f'event_stream: {list(event_stream.get_events())}')
|
||||
assert state.iteration == 1
|
||||
# it will first become AgentState.ERROR, then become AgentState.STOPPED
|
||||
# in side run_controller (since the while loop + sleep no longer loop)
|
||||
assert state.agent_state == AgentState.STOPPED
|
||||
assert (
|
||||
state.last_error
|
||||
== 'There was a fatal error during agent execution: **FatalErrorObservation**\nFatal error detected'
|
||||
)
|
||||
assert len(list(event_stream.get_events())) == 5
|
||||
assert state.iteration == 4
|
||||
assert state.agent_state == AgentState.ERROR
|
||||
assert state.last_error == 'Agent got stuck in a loop'
|
||||
assert len(list(event_stream.get_events())) == 11
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream):
|
||||
async def test_run_controller_stop_with_stuck():
|
||||
config = AppConfig()
|
||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
event_stream = EventStream(sid='test', file_store=file_store)
|
||||
|
||||
agent = MagicMock(spec=Agent)
|
||||
# a random message to send to the runtime
|
||||
event = CmdRunAction(command='ls')
|
||||
|
||||
def agent_step_fn(state):
|
||||
print(f'agent_step_fn received state: {state}')
|
||||
return event
|
||||
return CmdRunAction(command='ls')
|
||||
|
||||
agent.step = agent_step_fn
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
@@ -207,9 +186,7 @@ async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream):
|
||||
'Non fatal error here to trigger loop'
|
||||
)
|
||||
non_fatal_error_obs._cause = event.id
|
||||
await event_stream.async_add_event(
|
||||
non_fatal_error_obs, EventSource.ENVIRONMENT
|
||||
)
|
||||
event_stream.add_event(non_fatal_error_obs, EventSource.ENVIRONMENT)
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
|
||||
runtime.event_stream = event_stream
|
||||
@@ -228,7 +205,7 @@ async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream):
|
||||
print(f'event {i}: {event_to_dict(event)}')
|
||||
|
||||
assert state.iteration == 4
|
||||
assert len(events) == 12
|
||||
assert len(events) == 11
|
||||
# check the eventstream have 4 pairs of repeated actions and observations
|
||||
repeating_actions_and_observations = events[2:10]
|
||||
for action, observation in zip(
|
||||
@@ -246,13 +223,8 @@ async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream):
|
||||
assert last_event['extras']['agent_state'] == 'error'
|
||||
assert last_event['observation'] == 'agent_state_changed'
|
||||
|
||||
# it will first become AgentState.ERROR, then become AgentState.STOPPED
|
||||
# in side run_controller (since the while loop + sleep no longer loop)
|
||||
assert state.agent_state == AgentState.STOPPED
|
||||
assert (
|
||||
state.last_error
|
||||
== 'There was a fatal error during agent execution: **FatalErrorObservation**\nAgent got stuck in a loop'
|
||||
)
|
||||
assert state.agent_state == AgentState.ERROR
|
||||
assert state.last_error == 'Agent got stuck in a loop'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -319,7 +291,7 @@ async def test_step_max_iterations(mock_agent, mock_event_stream):
|
||||
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
|
||||
await controller._step()
|
||||
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
|
||||
assert controller.state.agent_state == AgentState.PAUSED
|
||||
assert controller.state.agent_state == AgentState.ERROR
|
||||
await controller.close()
|
||||
|
||||
|
||||
@@ -359,7 +331,7 @@ async def test_step_max_budget(mock_agent, mock_event_stream):
|
||||
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
|
||||
await controller._step()
|
||||
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
|
||||
assert controller.state.agent_state == AgentState.PAUSED
|
||||
assert controller.state.agent_state == AgentState.ERROR
|
||||
await controller.close()
|
||||
|
||||
|
||||
|
||||
@@ -440,9 +440,10 @@ class TestStuckDetector:
|
||||
read_observation_2._cause = read_action_2._id
|
||||
event_stream.add_event(read_observation_2, EventSource.ENVIRONMENT)
|
||||
|
||||
# one more message to break the pattern
|
||||
message_null_observation = NullObservation(content='')
|
||||
message_action = MessageAction(content='Come on', wait_for_response=False)
|
||||
event_stream.add_event(message_action, EventSource.USER)
|
||||
|
||||
message_null_observation = NullObservation(content='')
|
||||
event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
|
||||
|
||||
cmd_action_3 = CmdRunAction(command='ls')
|
||||
|
||||
Reference in New Issue
Block a user