Allow attaching to existing sessions without reinitializing the runtime (#4329)

Co-authored-by: tofarr <tofarr@gmail.com>
This commit is contained in:
Robert Brennan 2024-10-14 11:24:29 -04:00 committed by GitHub
parent 640ce0f60d
commit 63ff69fd97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 127 additions and 143 deletions

View File

@ -7,13 +7,11 @@ import {
appendSecurityAnalyzerInput,
} from "#/state/securityAnalyzerSlice";
import { setCurStatusMessage } from "#/state/statusSlice";
import { setRootTask } from "#/state/taskSlice";
import store from "#/store";
import ActionType from "#/types/ActionType";
import { ActionMessage, StatusMessage } from "#/types/Message";
import { SocketMessage } from "#/types/ResponseType";
import { handleObservationMessage } from "./observations";
import { getRootTask } from "./taskService";
const messageActions = {
[ActionType.BROWSE]: (message: ActionMessage) => {
@ -75,16 +73,6 @@ const messageActions = {
store.dispatch(appendJupyterInput(message.args.code));
}
},
[ActionType.ADD_TASK]: () => {
getRootTask().then((fetchedRootTask) =>
store.dispatch(setRootTask(fetchedRootTask)),
);
},
[ActionType.MODIFY_TASK]: () => {
getRootTask().then((fetchedRootTask) =>
store.dispatch(setRootTask(fetchedRootTask)),
);
},
};
function getRiskText(risk: ActionSecurityRisk) {

View File

@ -1,21 +0,0 @@
import { request } from "./api";
export type Task = {
id: string;
goal: string;
subtasks: Task[];
state: TaskState;
};
export enum TaskState {
OPEN_STATE = "open",
COMPLETED_STATE = "completed",
ABANDONED_STATE = "abandoned",
IN_PROGRESS_STATE = "in_progress",
VERIFIED_STATE = "verified",
}
export async function getRootTask(): Promise<Task | undefined> {
const res = await request("/api/root_task");
return res as Task;
}

View File

@ -1,23 +0,0 @@
import { createSlice } from "@reduxjs/toolkit";
import { Task, TaskState } from "#/services/taskService";
export const taskSlice = createSlice({
name: "task",
initialState: {
task: {
id: "",
goal: "",
subtasks: [],
state: TaskState.OPEN_STATE,
} as Task,
},
reducers: {
setRootTask: (state, action) => {
state.task = action.payload as Task;
},
},
});
export const { setRootTask } = taskSlice.actions;
export default taskSlice.reducer;

View File

@ -6,7 +6,6 @@ import codeReducer from "./state/codeSlice";
import fileStateReducer from "./state/file-state-slice";
import initialQueryReducer from "./state/initial-query-slice";
import commandReducer from "./state/commandSlice";
import taskReducer from "./state/taskSlice";
import jupyterReducer from "./state/jupyterSlice";
import securityAnalyzerReducer from "./state/securityAnalyzerSlice";
import statusReducer from "./state/statusSlice";
@ -18,7 +17,6 @@ export const rootReducer = combineReducers({
chat: chatReducer,
code: codeReducer,
cmd: commandReducer,
task: taskReducer,
agent: agentReducer,
jupyter: jupyterReducer,
securityAnalyzer: securityAnalyzerReducer,

View File

@ -21,6 +21,14 @@ class EventStreamSubscriber(str, Enum):
TEST = 'test'
def session_exists(sid: str, file_store: FileStore) -> bool:
try:
file_store.list(f'sessions/{sid}')
return True
except FileNotFoundError:
return False
class EventStream:
sid: str
file_store: FileStore

View File

@ -1,7 +1,6 @@
import os
import tempfile
import threading
import uuid
from typing import Callable
from zipfile import ZipFile
@ -104,7 +103,6 @@ class LogBuffer:
class EventStreamRuntime(Runtime):
"""This runtime will subscribe the event stream.
When receive an event, it will send the event to runtime-client which run inside the docker environment.
From the sid also an instance_id is generated in combination with a UID.
Args:
config (AppConfig): The application configuration.
@ -114,7 +112,7 @@ class EventStreamRuntime(Runtime):
env_vars (dict[str, str] | None, optional): Environment variables to set. Defaults to None.
"""
container_name_prefix = 'openhands-sandbox-'
container_name_prefix = 'openhands-runtime-'
def __init__(
self,
@ -124,27 +122,24 @@ class EventStreamRuntime(Runtime):
plugins: list[PluginRequirement] | None = None,
env_vars: dict[str, str] | None = None,
status_message_callback: Callable | None = None,
attach_to_existing: bool = False,
):
self.config = config
self._host_port = 30000 # initial dummy value
self._container_port = 30001 # initial dummy value
self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}'
self.session = requests.Session()
self.instance_id = (
sid + '_' + str(uuid.uuid4()) if sid is not None else str(uuid.uuid4())
)
self.status_message_callback = status_message_callback
self.send_status_message('STATUS$STARTING_RUNTIME')
self.docker_client: docker.DockerClient = self._init_docker_client()
self.base_container_image = self.config.sandbox.base_container_image
self.runtime_container_image = self.config.sandbox.runtime_container_image
self.container_name = self.container_name_prefix + self.instance_id
self.container_name = self.container_name_prefix + sid
self.container = None
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
self.runtime_builder = DockerRuntimeBuilder(self.docker_client)
logger.debug(f'EventStreamRuntime `{self.instance_id}`')
# Buffer for container logs
self.log_buffer: LogBuffer | None = None
@ -170,15 +165,25 @@ class EventStreamRuntime(Runtime):
extra_deps=self.config.sandbox.runtime_extra_deps,
force_rebuild=self.config.sandbox.force_rebuild_runtime,
)
self.container = self._init_container(
sandbox_workspace_dir=self.config.workspace_mount_path_in_sandbox, # e.g. /workspace
mount_dir=self.config.workspace_mount_path, # e.g. /opt/openhands/_test_workspace
plugins=plugins,
)
if not attach_to_existing:
self._init_container(
sandbox_workspace_dir=self.config.workspace_mount_path_in_sandbox, # e.g. /workspace
mount_dir=self.config.workspace_mount_path, # e.g. /opt/openhands/_test_workspace
plugins=plugins,
)
else:
self._attach_to_container()
# will initialize both the event stream and the env vars
super().__init__(
config, event_stream, sid, plugins, env_vars, status_message_callback
config,
event_stream,
sid,
plugins,
env_vars,
status_message_callback,
attach_to_existing,
)
logger.info('Waiting for client to become ready...')
@ -272,7 +277,7 @@ class EventStreamRuntime(Runtime):
else:
browsergym_arg = ''
container = self.docker_client.containers.run(
self.container = self.docker_client.containers.run(
self.runtime_container_image,
command=(
f'/openhands/micromamba/bin/micromamba run -n openhands '
@ -292,18 +297,34 @@ class EventStreamRuntime(Runtime):
environment=environment,
volumes=volumes,
)
self.log_buffer = LogBuffer(container)
self.log_buffer = LogBuffer(self.container)
logger.info(f'Container started. Server url: {self.api_url}')
self.send_status_message('STATUS$CONTAINER_STARTED')
return container
except Exception as e:
logger.error(
f'Error: Instance {self.instance_id} FAILED to start container!\n'
f'Error: Instance {self.container_name} FAILED to start container!\n'
)
logger.exception(e)
self.close(close_client=False)
raise e
def _attach_to_container(self):
container = self.docker_client.containers.get(self.container_name)
self.log_buffer = LogBuffer(container)
self.container = container
self._container_port = 0
for port in container.attrs['NetworkSettings']['Ports']:
self._container_port = int(port.split('/')[0])
break
self._host_port = self._container_port
self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}'
logger.info(
'attached to container:',
self.container_name,
self._container_port,
self.api_url,
)
def _refresh_logs(self):
logger.debug('Getting container logs...')

View File

@ -51,6 +51,7 @@ class RemoteRuntime(Runtime):
plugins: list[PluginRequirement] | None = None,
env_vars: dict[str, str] | None = None,
status_message_callback: Optional[Callable] = None,
attach_to_existing: bool = False,
):
self.config = config
self.status_message_callback = status_message_callback
@ -75,21 +76,31 @@ class RemoteRuntime(Runtime):
self.runtime_id: str | None = None
self.runtime_url: str | None = None
self.instance_id = sid
self.sid = sid
self._start_or_attach_to_runtime(plugins)
self._start_or_attach_to_runtime(plugins, attach_to_existing)
# Initialize the eventstream and env vars
super().__init__(
config, event_stream, sid, plugins, env_vars, status_message_callback
config,
event_stream,
sid,
plugins,
env_vars,
status_message_callback,
attach_to_existing,
)
self._wait_until_alive()
self.setup_initial_env()
def _start_or_attach_to_runtime(self, plugins: list[PluginRequirement] | None):
def _start_or_attach_to_runtime(
self, plugins: list[PluginRequirement] | None, attach_to_existing: bool = False
):
existing_runtime = self._check_existing_runtime()
if existing_runtime:
logger.info(f'Using existing runtime with ID: {self.runtime_id}')
elif attach_to_existing:
raise RuntimeError('Could not find existing runtime to attach to.')
else:
self.send_status_message('STATUS$STARTING_CONTAINER')
if self.config.sandbox.runtime_container_image is None:
@ -117,7 +128,7 @@ class RemoteRuntime(Runtime):
response = send_request_with_retry(
self.session,
'GET',
f'{self.config.sandbox.remote_runtime_api_url}/runtime/{self.instance_id}',
f'{self.config.sandbox.remote_runtime_api_url}/runtime/{self.sid}',
timeout=5,
)
except Exception as e:
@ -146,7 +157,7 @@ class RemoteRuntime(Runtime):
return False
def _build_runtime(self):
logger.debug(f'RemoteRuntime `{self.instance_id}` config:\n{self.config}')
logger.debug(f'RemoteRuntime `{self.sid}` config:\n{self.config}')
response = send_request_with_retry(
self.session,
'GET',
@ -209,7 +220,7 @@ class RemoteRuntime(Runtime):
),
'working_dir': '/openhands/code/',
'environment': {'DEBUG': 'true'} if self.config.debug else {},
'runtime_id': self.instance_id,
'runtime_id': self.sid,
}
# Start the sandbox using the /start endpoint

View File

@ -52,6 +52,7 @@ class Runtime:
sid: str
config: AppConfig
initial_env_vars: dict[str, str]
attach_to_existing: bool
def __init__(
self,
@ -61,12 +62,14 @@ class Runtime:
plugins: list[PluginRequirement] | None = None,
env_vars: dict[str, str] | None = None,
status_message_callback: Callable | None = None,
attach_to_existing: bool = False,
):
self.sid = sid
self.event_stream = event_stream
self.event_stream.subscribe(EventStreamSubscriber.RUNTIME, self.on_event)
self.plugins = plugins if plugins is not None and len(plugins) > 0 else []
self.status_message_callback = status_message_callback
self.attach_to_existing = attach_to_existing
self.config = copy.deepcopy(config)
atexit.register(self.close)
@ -76,6 +79,8 @@ class Runtime:
self.initial_env_vars.update(env_vars)
def setup_initial_env(self) -> None:
if self.attach_to_existing:
return
logger.debug(f'Adding env vars: {self.initial_env_vars}')
self.add_env_vars(self.initial_env_vars)
if self.config.sandbox.runtime_startup_env_vars:

View File

@ -25,7 +25,6 @@ from fastapi import (
FastAPI,
HTTPException,
Request,
Response,
UploadFile,
WebSocket,
status,
@ -40,7 +39,6 @@ import openhands.agenthub # noqa F401 (we import this to get the agents registe
from openhands.controller.agent import Agent
from openhands.core.config import LLMConfig, load_app_config
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema import AgentState # Add this import
from openhands.events.action import (
ChangeAgentStateAction,
FileReadAction,
@ -213,8 +211,10 @@ async def attach_session(request: Request, call_next):
content={'error': 'Invalid token'},
)
request.state.session = session_manager.get_session(request.state.sid)
if request.state.session is None:
request.state.conversation = session_manager.attach_to_conversation(
request.state.sid
)
if request.state.conversation is None:
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content={'error': 'Session not found'},
@ -434,12 +434,13 @@ async def list_files(request: Request, path: str | None = None):
Raises:
HTTPException: If there's an error listing the files.
"""
if not request.state.session.agent_session.runtime:
if not request.state.conversation.runtime:
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content={'error': 'Runtime not yet initialized'},
)
runtime: Runtime = request.state.session.agent_session.runtime
runtime: Runtime = request.state.conversation.runtime
file_list = await sync_from_async(runtime.list_files, path)
if path:
file_list = [os.path.join(path, f) for f in file_list]
@ -485,7 +486,7 @@ async def select_file(file: str, request: Request):
Raises:
HTTPException: If there's an error opening the file.
"""
runtime: Runtime = request.state.session.agent_session.runtime
runtime: Runtime = request.state.conversation.runtime
file = os.path.join(runtime.config.workspace_mount_path_in_sandbox, file)
read_action = FileReadAction(file)
@ -567,7 +568,7 @@ async def upload_file(request: Request, files: list[UploadFile]):
tmp_file.write(file_contents)
tmp_file.flush()
runtime: Runtime = request.state.session.agent_session.runtime
runtime: Runtime = request.state.conversation.runtime
runtime.copy_to(
tmp_file_path, runtime.config.workspace_mount_path_in_sandbox
)
@ -635,35 +636,6 @@ async def submit_feedback(request: Request, feedback: FeedbackDataModel):
)
@app.get('/api/root_task')
def get_root_task(request: Request):
"""Retrieve the root task of the current agent session.
To get the root_task:
```sh
curl -H "Authorization: Bearer <TOKEN>" http://localhost:3000/api/root_task
```
Args:
request (Request): The incoming request object.
Returns:
dict: The root task data if available.
Raises:
HTTPException: If the root task is not available.
"""
controller = request.state.session.agent_session.controller
if controller is not None:
state = controller.get_state()
if state:
return JSONResponse(
status_code=status.HTTP_200_OK,
content=state.root_task.to_dict(),
)
return Response(status_code=status.HTTP_204_NO_CONTENT)
@app.get('/api/defaults')
async def appconfig_defaults():
"""Retrieve the default configuration settings.
@ -700,22 +672,6 @@ async def save_file(request: Request):
- 500 error if there's an unexpected error during the save operation.
"""
try:
# Get the agent's current state
controller = request.state.session.agent_session.controller
agent_state = controller.get_agent_state()
# Check if the agent is in an allowed state for editing
if agent_state not in [
AgentState.INIT,
AgentState.PAUSED,
AgentState.FINISHED,
AgentState.AWAITING_USER_INPUT,
]:
raise HTTPException(
status_code=403,
detail='Code editing is only allowed when the agent is paused, finished, or awaiting user input',
)
# Extract file path and content from the request
data = await request.json()
file_path = data.get('filePath')
@ -726,7 +682,7 @@ async def save_file(request: Request):
raise HTTPException(status_code=400, detail='Missing filePath or content')
# Save the file to the agent's runtime file store
runtime: Runtime = request.state.session.agent_session.runtime
runtime: Runtime = request.state.conversation.runtime
file_path = os.path.join(
runtime.config.workspace_mount_path_in_sandbox, file_path
)
@ -768,13 +724,11 @@ async def security_api(request: Request):
Raises:
HTTPException: If the security analyzer is not initialized.
"""
if not request.state.session.agent_session.security_analyzer:
if not request.state.conversation.security_analyzer:
raise HTTPException(status_code=404, detail='Security analyzer not initialized')
return (
await request.state.session.agent_session.security_analyzer.handle_api_request(
request
)
return await request.state.conversation.security_analyzer.handle_api_request(
request
)
@ -782,7 +736,7 @@ async def security_api(request: Request):
async def zip_current_workspace(request: Request):
try:
logger.info('Zipping workspace')
runtime: Runtime = request.state.session.agent_session.runtime
runtime: Runtime = request.state.conversation.runtime
path = runtime.config.workspace_mount_path_in_sandbox
zip_file_bytes = runtime.copy_from(path)

View File

@ -0,0 +1,36 @@
from openhands.core.config import AppConfig
from openhands.events.stream import EventStream
from openhands.runtime import get_runtime_cls
from openhands.runtime.runtime import Runtime
from openhands.security import SecurityAnalyzer, options
from openhands.storage.files import FileStore
class Conversation:
sid: str
file_store: FileStore
event_stream: EventStream
runtime: Runtime
def __init__(
self,
sid: str,
file_store: FileStore,
config: AppConfig,
):
self.sid = sid
self.config = config
self.file_store = file_store
self.event_stream = EventStream(sid, file_store)
if config.security.security_analyzer:
self.security_analyzer = options.SecurityAnalyzers.get(
config.security.security_analyzer, SecurityAnalyzer
)(self.event_stream)
runtime_cls = get_runtime_cls(self.config.runtime)
self.runtime = runtime_cls(
config=config,
event_stream=self.event_stream,
sid=self.sid,
attach_to_existing=True,
)

View File

@ -6,7 +6,9 @@ from fastapi import WebSocket
from openhands.core.config import AppConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events.stream import session_exists
from openhands.runtime.utils.shutdown_listener import should_continue
from openhands.server.session.conversation import Conversation
from openhands.server.session.session import Session
from openhands.storage.files import FileStore
@ -44,6 +46,11 @@ class SessionManager:
return None
return self._sessions.get(sid)
def attach_to_conversation(self, sid: str) -> Conversation | None:
if not session_exists(sid, self.file_store):
return None
return Conversation(sid, file_store=self.file_store, config=self.config)
async def send(self, sid: str, data: dict[str, object]) -> bool:
"""Sends data to the client."""
session = self.get_session(sid)