diff --git a/frontend/src/services/actions.ts b/frontend/src/services/actions.ts index 04a7029fb8..26068c2388 100644 --- a/frontend/src/services/actions.ts +++ b/frontend/src/services/actions.ts @@ -7,13 +7,11 @@ import { appendSecurityAnalyzerInput, } from "#/state/securityAnalyzerSlice"; import { setCurStatusMessage } from "#/state/statusSlice"; -import { setRootTask } from "#/state/taskSlice"; import store from "#/store"; import ActionType from "#/types/ActionType"; import { ActionMessage, StatusMessage } from "#/types/Message"; import { SocketMessage } from "#/types/ResponseType"; import { handleObservationMessage } from "./observations"; -import { getRootTask } from "./taskService"; const messageActions = { [ActionType.BROWSE]: (message: ActionMessage) => { @@ -75,16 +73,6 @@ const messageActions = { store.dispatch(appendJupyterInput(message.args.code)); } }, - [ActionType.ADD_TASK]: () => { - getRootTask().then((fetchedRootTask) => - store.dispatch(setRootTask(fetchedRootTask)), - ); - }, - [ActionType.MODIFY_TASK]: () => { - getRootTask().then((fetchedRootTask) => - store.dispatch(setRootTask(fetchedRootTask)), - ); - }, }; function getRiskText(risk: ActionSecurityRisk) { diff --git a/frontend/src/services/taskService.ts b/frontend/src/services/taskService.ts deleted file mode 100644 index fa84444e29..0000000000 --- a/frontend/src/services/taskService.ts +++ /dev/null @@ -1,21 +0,0 @@ -import { request } from "./api"; - -export type Task = { - id: string; - goal: string; - subtasks: Task[]; - state: TaskState; -}; - -export enum TaskState { - OPEN_STATE = "open", - COMPLETED_STATE = "completed", - ABANDONED_STATE = "abandoned", - IN_PROGRESS_STATE = "in_progress", - VERIFIED_STATE = "verified", -} - -export async function getRootTask(): Promise { - const res = await request("/api/root_task"); - return res as Task; -} diff --git a/frontend/src/state/taskSlice.ts b/frontend/src/state/taskSlice.ts deleted file mode 100644 index 9726318c56..0000000000 --- a/frontend/src/state/taskSlice.ts +++ /dev/null @@ -1,23 +0,0 @@ -import { createSlice } from "@reduxjs/toolkit"; -import { Task, TaskState } from "#/services/taskService"; - -export const taskSlice = createSlice({ - name: "task", - initialState: { - task: { - id: "", - goal: "", - subtasks: [], - state: TaskState.OPEN_STATE, - } as Task, - }, - reducers: { - setRootTask: (state, action) => { - state.task = action.payload as Task; - }, - }, -}); - -export const { setRootTask } = taskSlice.actions; - -export default taskSlice.reducer; diff --git a/frontend/src/store.ts b/frontend/src/store.ts index 971750b685..77d5dfc6fb 100644 --- a/frontend/src/store.ts +++ b/frontend/src/store.ts @@ -6,7 +6,6 @@ import codeReducer from "./state/codeSlice"; import fileStateReducer from "./state/file-state-slice"; import initialQueryReducer from "./state/initial-query-slice"; import commandReducer from "./state/commandSlice"; -import taskReducer from "./state/taskSlice"; import jupyterReducer from "./state/jupyterSlice"; import securityAnalyzerReducer from "./state/securityAnalyzerSlice"; import statusReducer from "./state/statusSlice"; @@ -18,7 +17,6 @@ export const rootReducer = combineReducers({ chat: chatReducer, code: codeReducer, cmd: commandReducer, - task: taskReducer, agent: agentReducer, jupyter: jupyterReducer, securityAnalyzer: securityAnalyzerReducer, diff --git a/openhands/events/stream.py b/openhands/events/stream.py index b667202278..8cff229c05 100644 --- a/openhands/events/stream.py +++ b/openhands/events/stream.py @@ -21,6 +21,14 @@ class EventStreamSubscriber(str, Enum): TEST = 'test' +def session_exists(sid: str, file_store: FileStore) -> bool: + try: + file_store.list(f'sessions/{sid}') + return True + except FileNotFoundError: + return False + + class EventStream: sid: str file_store: FileStore diff --git a/openhands/runtime/client/runtime.py b/openhands/runtime/client/runtime.py index 7292d221f9..e841e32cb5 100644 --- a/openhands/runtime/client/runtime.py +++ b/openhands/runtime/client/runtime.py @@ -1,7 +1,6 @@ import os import tempfile import threading -import uuid from typing import Callable from zipfile import ZipFile @@ -104,7 +103,6 @@ class LogBuffer: class EventStreamRuntime(Runtime): """This runtime will subscribe the event stream. When receive an event, it will send the event to runtime-client which run inside the docker environment. - From the sid also an instance_id is generated in combination with a UID. Args: config (AppConfig): The application configuration. @@ -114,7 +112,7 @@ class EventStreamRuntime(Runtime): env_vars (dict[str, str] | None, optional): Environment variables to set. Defaults to None. """ - container_name_prefix = 'openhands-sandbox-' + container_name_prefix = 'openhands-runtime-' def __init__( self, @@ -124,27 +122,24 @@ class EventStreamRuntime(Runtime): plugins: list[PluginRequirement] | None = None, env_vars: dict[str, str] | None = None, status_message_callback: Callable | None = None, + attach_to_existing: bool = False, ): self.config = config self._host_port = 30000 # initial dummy value self._container_port = 30001 # initial dummy value self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}' self.session = requests.Session() - self.instance_id = ( - sid + '_' + str(uuid.uuid4()) if sid is not None else str(uuid.uuid4()) - ) self.status_message_callback = status_message_callback self.send_status_message('STATUS$STARTING_RUNTIME') self.docker_client: docker.DockerClient = self._init_docker_client() self.base_container_image = self.config.sandbox.base_container_image self.runtime_container_image = self.config.sandbox.runtime_container_image - self.container_name = self.container_name_prefix + self.instance_id + self.container_name = self.container_name_prefix + sid self.container = None self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time self.runtime_builder = DockerRuntimeBuilder(self.docker_client) - logger.debug(f'EventStreamRuntime `{self.instance_id}`') # Buffer for container logs self.log_buffer: LogBuffer | None = None @@ -170,15 +165,25 @@ class EventStreamRuntime(Runtime): extra_deps=self.config.sandbox.runtime_extra_deps, force_rebuild=self.config.sandbox.force_rebuild_runtime, ) - self.container = self._init_container( - sandbox_workspace_dir=self.config.workspace_mount_path_in_sandbox, # e.g. /workspace - mount_dir=self.config.workspace_mount_path, # e.g. /opt/openhands/_test_workspace - plugins=plugins, - ) + + if not attach_to_existing: + self._init_container( + sandbox_workspace_dir=self.config.workspace_mount_path_in_sandbox, # e.g. /workspace + mount_dir=self.config.workspace_mount_path, # e.g. /opt/openhands/_test_workspace + plugins=plugins, + ) + else: + self._attach_to_container() # will initialize both the event stream and the env vars super().__init__( - config, event_stream, sid, plugins, env_vars, status_message_callback + config, + event_stream, + sid, + plugins, + env_vars, + status_message_callback, + attach_to_existing, ) logger.info('Waiting for client to become ready...') @@ -272,7 +277,7 @@ class EventStreamRuntime(Runtime): else: browsergym_arg = '' - container = self.docker_client.containers.run( + self.container = self.docker_client.containers.run( self.runtime_container_image, command=( f'/openhands/micromamba/bin/micromamba run -n openhands ' @@ -292,18 +297,34 @@ class EventStreamRuntime(Runtime): environment=environment, volumes=volumes, ) - self.log_buffer = LogBuffer(container) + self.log_buffer = LogBuffer(self.container) logger.info(f'Container started. Server url: {self.api_url}') self.send_status_message('STATUS$CONTAINER_STARTED') - return container except Exception as e: logger.error( - f'Error: Instance {self.instance_id} FAILED to start container!\n' + f'Error: Instance {self.container_name} FAILED to start container!\n' ) logger.exception(e) self.close(close_client=False) raise e + def _attach_to_container(self): + container = self.docker_client.containers.get(self.container_name) + self.log_buffer = LogBuffer(container) + self.container = container + self._container_port = 0 + for port in container.attrs['NetworkSettings']['Ports']: + self._container_port = int(port.split('/')[0]) + break + self._host_port = self._container_port + self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}' + logger.info( + 'attached to container:', + self.container_name, + self._container_port, + self.api_url, + ) + def _refresh_logs(self): logger.debug('Getting container logs...') diff --git a/openhands/runtime/remote/runtime.py b/openhands/runtime/remote/runtime.py index c259f16a2d..10d340bb55 100644 --- a/openhands/runtime/remote/runtime.py +++ b/openhands/runtime/remote/runtime.py @@ -51,6 +51,7 @@ class RemoteRuntime(Runtime): plugins: list[PluginRequirement] | None = None, env_vars: dict[str, str] | None = None, status_message_callback: Optional[Callable] = None, + attach_to_existing: bool = False, ): self.config = config self.status_message_callback = status_message_callback @@ -75,21 +76,31 @@ class RemoteRuntime(Runtime): self.runtime_id: str | None = None self.runtime_url: str | None = None - self.instance_id = sid + self.sid = sid - self._start_or_attach_to_runtime(plugins) + self._start_or_attach_to_runtime(plugins, attach_to_existing) # Initialize the eventstream and env vars super().__init__( - config, event_stream, sid, plugins, env_vars, status_message_callback + config, + event_stream, + sid, + plugins, + env_vars, + status_message_callback, + attach_to_existing, ) self._wait_until_alive() self.setup_initial_env() - def _start_or_attach_to_runtime(self, plugins: list[PluginRequirement] | None): + def _start_or_attach_to_runtime( + self, plugins: list[PluginRequirement] | None, attach_to_existing: bool = False + ): existing_runtime = self._check_existing_runtime() if existing_runtime: logger.info(f'Using existing runtime with ID: {self.runtime_id}') + elif attach_to_existing: + raise RuntimeError('Could not find existing runtime to attach to.') else: self.send_status_message('STATUS$STARTING_CONTAINER') if self.config.sandbox.runtime_container_image is None: @@ -117,7 +128,7 @@ class RemoteRuntime(Runtime): response = send_request_with_retry( self.session, 'GET', - f'{self.config.sandbox.remote_runtime_api_url}/runtime/{self.instance_id}', + f'{self.config.sandbox.remote_runtime_api_url}/runtime/{self.sid}', timeout=5, ) except Exception as e: @@ -146,7 +157,7 @@ class RemoteRuntime(Runtime): return False def _build_runtime(self): - logger.debug(f'RemoteRuntime `{self.instance_id}` config:\n{self.config}') + logger.debug(f'RemoteRuntime `{self.sid}` config:\n{self.config}') response = send_request_with_retry( self.session, 'GET', @@ -209,7 +220,7 @@ class RemoteRuntime(Runtime): ), 'working_dir': '/openhands/code/', 'environment': {'DEBUG': 'true'} if self.config.debug else {}, - 'runtime_id': self.instance_id, + 'runtime_id': self.sid, } # Start the sandbox using the /start endpoint diff --git a/openhands/runtime/runtime.py b/openhands/runtime/runtime.py index e81169aaf7..7e420643c3 100644 --- a/openhands/runtime/runtime.py +++ b/openhands/runtime/runtime.py @@ -52,6 +52,7 @@ class Runtime: sid: str config: AppConfig initial_env_vars: dict[str, str] + attach_to_existing: bool def __init__( self, @@ -61,12 +62,14 @@ class Runtime: plugins: list[PluginRequirement] | None = None, env_vars: dict[str, str] | None = None, status_message_callback: Callable | None = None, + attach_to_existing: bool = False, ): self.sid = sid self.event_stream = event_stream self.event_stream.subscribe(EventStreamSubscriber.RUNTIME, self.on_event) self.plugins = plugins if plugins is not None and len(plugins) > 0 else [] self.status_message_callback = status_message_callback + self.attach_to_existing = attach_to_existing self.config = copy.deepcopy(config) atexit.register(self.close) @@ -76,6 +79,8 @@ class Runtime: self.initial_env_vars.update(env_vars) def setup_initial_env(self) -> None: + if self.attach_to_existing: + return logger.debug(f'Adding env vars: {self.initial_env_vars}') self.add_env_vars(self.initial_env_vars) if self.config.sandbox.runtime_startup_env_vars: diff --git a/openhands/server/listen.py b/openhands/server/listen.py index 040888d13b..d7f1777349 100644 --- a/openhands/server/listen.py +++ b/openhands/server/listen.py @@ -25,7 +25,6 @@ from fastapi import ( FastAPI, HTTPException, Request, - Response, UploadFile, WebSocket, status, @@ -40,7 +39,6 @@ import openhands.agenthub # noqa F401 (we import this to get the agents registe from openhands.controller.agent import Agent from openhands.core.config import LLMConfig, load_app_config from openhands.core.logger import openhands_logger as logger -from openhands.core.schema import AgentState # Add this import from openhands.events.action import ( ChangeAgentStateAction, FileReadAction, @@ -213,8 +211,10 @@ async def attach_session(request: Request, call_next): content={'error': 'Invalid token'}, ) - request.state.session = session_manager.get_session(request.state.sid) - if request.state.session is None: + request.state.conversation = session_manager.attach_to_conversation( + request.state.sid + ) + if request.state.conversation is None: return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, content={'error': 'Session not found'}, @@ -434,12 +434,13 @@ async def list_files(request: Request, path: str | None = None): Raises: HTTPException: If there's an error listing the files. """ - if not request.state.session.agent_session.runtime: + if not request.state.conversation.runtime: return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, content={'error': 'Runtime not yet initialized'}, ) - runtime: Runtime = request.state.session.agent_session.runtime + + runtime: Runtime = request.state.conversation.runtime file_list = await sync_from_async(runtime.list_files, path) if path: file_list = [os.path.join(path, f) for f in file_list] @@ -485,7 +486,7 @@ async def select_file(file: str, request: Request): Raises: HTTPException: If there's an error opening the file. """ - runtime: Runtime = request.state.session.agent_session.runtime + runtime: Runtime = request.state.conversation.runtime file = os.path.join(runtime.config.workspace_mount_path_in_sandbox, file) read_action = FileReadAction(file) @@ -567,7 +568,7 @@ async def upload_file(request: Request, files: list[UploadFile]): tmp_file.write(file_contents) tmp_file.flush() - runtime: Runtime = request.state.session.agent_session.runtime + runtime: Runtime = request.state.conversation.runtime runtime.copy_to( tmp_file_path, runtime.config.workspace_mount_path_in_sandbox ) @@ -635,35 +636,6 @@ async def submit_feedback(request: Request, feedback: FeedbackDataModel): ) -@app.get('/api/root_task') -def get_root_task(request: Request): - """Retrieve the root task of the current agent session. - - To get the root_task: - ```sh - curl -H "Authorization: Bearer " http://localhost:3000/api/root_task - ``` - - Args: - request (Request): The incoming request object. - - Returns: - dict: The root task data if available. - - Raises: - HTTPException: If the root task is not available. - """ - controller = request.state.session.agent_session.controller - if controller is not None: - state = controller.get_state() - if state: - return JSONResponse( - status_code=status.HTTP_200_OK, - content=state.root_task.to_dict(), - ) - return Response(status_code=status.HTTP_204_NO_CONTENT) - - @app.get('/api/defaults') async def appconfig_defaults(): """Retrieve the default configuration settings. @@ -700,22 +672,6 @@ async def save_file(request: Request): - 500 error if there's an unexpected error during the save operation. """ try: - # Get the agent's current state - controller = request.state.session.agent_session.controller - agent_state = controller.get_agent_state() - - # Check if the agent is in an allowed state for editing - if agent_state not in [ - AgentState.INIT, - AgentState.PAUSED, - AgentState.FINISHED, - AgentState.AWAITING_USER_INPUT, - ]: - raise HTTPException( - status_code=403, - detail='Code editing is only allowed when the agent is paused, finished, or awaiting user input', - ) - # Extract file path and content from the request data = await request.json() file_path = data.get('filePath') @@ -726,7 +682,7 @@ async def save_file(request: Request): raise HTTPException(status_code=400, detail='Missing filePath or content') # Save the file to the agent's runtime file store - runtime: Runtime = request.state.session.agent_session.runtime + runtime: Runtime = request.state.conversation.runtime file_path = os.path.join( runtime.config.workspace_mount_path_in_sandbox, file_path ) @@ -768,13 +724,11 @@ async def security_api(request: Request): Raises: HTTPException: If the security analyzer is not initialized. """ - if not request.state.session.agent_session.security_analyzer: + if not request.state.conversation.security_analyzer: raise HTTPException(status_code=404, detail='Security analyzer not initialized') - return ( - await request.state.session.agent_session.security_analyzer.handle_api_request( - request - ) + return await request.state.conversation.security_analyzer.handle_api_request( + request ) @@ -782,7 +736,7 @@ async def security_api(request: Request): async def zip_current_workspace(request: Request): try: logger.info('Zipping workspace') - runtime: Runtime = request.state.session.agent_session.runtime + runtime: Runtime = request.state.conversation.runtime path = runtime.config.workspace_mount_path_in_sandbox zip_file_bytes = runtime.copy_from(path) diff --git a/openhands/server/session/conversation.py b/openhands/server/session/conversation.py new file mode 100644 index 0000000000..dc2ca833ff --- /dev/null +++ b/openhands/server/session/conversation.py @@ -0,0 +1,36 @@ +from openhands.core.config import AppConfig +from openhands.events.stream import EventStream +from openhands.runtime import get_runtime_cls +from openhands.runtime.runtime import Runtime +from openhands.security import SecurityAnalyzer, options +from openhands.storage.files import FileStore + + +class Conversation: + sid: str + file_store: FileStore + event_stream: EventStream + runtime: Runtime + + def __init__( + self, + sid: str, + file_store: FileStore, + config: AppConfig, + ): + self.sid = sid + self.config = config + self.file_store = file_store + self.event_stream = EventStream(sid, file_store) + if config.security.security_analyzer: + self.security_analyzer = options.SecurityAnalyzers.get( + config.security.security_analyzer, SecurityAnalyzer + )(self.event_stream) + + runtime_cls = get_runtime_cls(self.config.runtime) + self.runtime = runtime_cls( + config=config, + event_stream=self.event_stream, + sid=self.sid, + attach_to_existing=True, + ) diff --git a/openhands/server/session/manager.py b/openhands/server/session/manager.py index 5c81ec1a55..5cc46fff44 100644 --- a/openhands/server/session/manager.py +++ b/openhands/server/session/manager.py @@ -6,7 +6,9 @@ from fastapi import WebSocket from openhands.core.config import AppConfig from openhands.core.logger import openhands_logger as logger +from openhands.events.stream import session_exists from openhands.runtime.utils.shutdown_listener import should_continue +from openhands.server.session.conversation import Conversation from openhands.server.session.session import Session from openhands.storage.files import FileStore @@ -44,6 +46,11 @@ class SessionManager: return None return self._sessions.get(sid) + def attach_to_conversation(self, sid: str) -> Conversation | None: + if not session_exists(sid, self.file_store): + return None + return Conversation(sid, file_store=self.file_store, config=self.config) + async def send(self, sid: str, data: dict[str, object]) -> bool: """Sends data to the client.""" session = self.get_session(sid)