mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Remove global config from session (#2987)
* Remove global config from session * Fix double agent
This commit is contained in:
parent
9d41314d1a
commit
692fe21d60
@ -40,7 +40,9 @@ from opendevin.events.observation import (
|
||||
from opendevin.events.serialization import event_to_dict
|
||||
from opendevin.llm import bedrock
|
||||
from opendevin.server.auth import get_sid_from_token, sign_token
|
||||
from opendevin.server.session import session_manager
|
||||
from opendevin.server.session import SessionManager
|
||||
|
||||
session_manager = SessionManager(config)
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
|
||||
@ -1,6 +1,4 @@
|
||||
from .manager import SessionManager
|
||||
from .session import Session
|
||||
|
||||
session_manager = SessionManager()
|
||||
|
||||
__all__ = ['Session', 'SessionManager', 'session_manager']
|
||||
__all__ = ['Session', 'SessionManager']
|
||||
|
||||
@ -4,11 +4,9 @@ from agenthub.codeact_agent.codeact_agent import CodeActAgent
|
||||
from opendevin.controller import AgentController
|
||||
from opendevin.controller.agent import Agent
|
||||
from opendevin.controller.state.state import State
|
||||
from opendevin.core.config import config
|
||||
from opendevin.core.config import SandboxConfig
|
||||
from opendevin.core.logger import opendevin_logger as logger
|
||||
from opendevin.core.schema import ConfigType
|
||||
from opendevin.events.stream import EventStream
|
||||
from opendevin.llm.llm import LLM
|
||||
from opendevin.runtime import DockerSSHBox, get_runtime_cls
|
||||
from opendevin.runtime.runtime import Runtime
|
||||
from opendevin.runtime.server.runtime import ServerRuntime
|
||||
@ -32,7 +30,14 @@ class AgentSession:
|
||||
self.sid = sid
|
||||
self.event_stream = EventStream(sid)
|
||||
|
||||
async def start(self, start_event: dict):
|
||||
async def start(
|
||||
self,
|
||||
runtime_name: str,
|
||||
sandbox_config: SandboxConfig,
|
||||
agent: Agent,
|
||||
confirmation_mode: bool,
|
||||
max_iterations: int,
|
||||
):
|
||||
"""Starts the agent session.
|
||||
|
||||
Args:
|
||||
@ -42,8 +47,8 @@ class AgentSession:
|
||||
raise Exception(
|
||||
'Session already started. You need to close this session and start a new one.'
|
||||
)
|
||||
await self._create_runtime()
|
||||
await self._create_controller(start_event)
|
||||
await self._create_runtime(runtime_name, sandbox_config)
|
||||
await self._create_controller(agent, confirmation_mode, max_iterations)
|
||||
|
||||
async def close(self):
|
||||
if self._closed:
|
||||
@ -56,52 +61,28 @@ class AgentSession:
|
||||
await self.runtime.close()
|
||||
self._closed = True
|
||||
|
||||
async def _create_runtime(self):
|
||||
async def _create_runtime(self, runtime_name: str, sandbox_config: SandboxConfig):
|
||||
"""Creates a runtime instance."""
|
||||
if self.runtime is not None:
|
||||
raise Exception('Runtime already created')
|
||||
|
||||
logger.info(f'Using runtime: {config.runtime}')
|
||||
runtime_cls = get_runtime_cls(config.runtime)
|
||||
logger.info(f'Using runtime: {runtime_name}')
|
||||
runtime_cls = get_runtime_cls(runtime_name)
|
||||
self.runtime = runtime_cls(
|
||||
sandbox_config=config.sandbox, event_stream=self.event_stream, sid=self.sid
|
||||
sandbox_config=sandbox_config, event_stream=self.event_stream, sid=self.sid
|
||||
)
|
||||
await self.runtime.ainit()
|
||||
|
||||
async def _create_controller(self, start_event: dict):
|
||||
"""Creates an AgentController instance.
|
||||
|
||||
Args:
|
||||
start_event: The start event data.
|
||||
"""
|
||||
async def _create_controller(
|
||||
self, agent: Agent, confirmation_mode: bool, max_iterations: int
|
||||
):
|
||||
"""Creates an AgentController instance."""
|
||||
if self.controller is not None:
|
||||
raise Exception('Controller already created')
|
||||
if self.runtime is None:
|
||||
raise Exception('Runtime must be initialized before the agent controller')
|
||||
args = {
|
||||
key: value
|
||||
for key, value in start_event.get('args', {}).items()
|
||||
if value != ''
|
||||
} # remove empty values, prevent FE from sending empty strings
|
||||
agent_cls = args.get(ConfigType.AGENT, config.default_agent)
|
||||
confirmation_mode = args.get(
|
||||
ConfigType.CONFIRMATION_MODE, config.confirmation_mode
|
||||
)
|
||||
max_iterations = args.get(ConfigType.MAX_ITERATIONS, config.max_iterations)
|
||||
|
||||
# override default LLM config
|
||||
default_llm_config = config.get_llm_config()
|
||||
default_llm_config.model = args.get(
|
||||
ConfigType.LLM_MODEL, default_llm_config.model
|
||||
)
|
||||
default_llm_config.api_key = args.get(
|
||||
ConfigType.LLM_API_KEY, default_llm_config.api_key
|
||||
)
|
||||
|
||||
# TODO: override other LLM config & agent config groups (#2075)
|
||||
|
||||
llm = LLM(config=config.get_llm_config_from_agent(agent_cls))
|
||||
agent = Agent.get_cls(agent_cls)(llm)
|
||||
logger.info(f'Creating agent {agent.name} using LLM {llm}')
|
||||
logger.info(f'Creating agent {agent.name} using LLM {agent.llm.config.model}')
|
||||
if isinstance(agent, CodeActAgent):
|
||||
if not self.runtime or not (
|
||||
isinstance(self.runtime, ServerRuntime)
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Optional
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
from opendevin.core.config import AppConfig
|
||||
from opendevin.core.logger import opendevin_logger as logger
|
||||
|
||||
from .session import Session
|
||||
@ -14,13 +15,14 @@ class SessionManager:
|
||||
cleanup_interval: int = 300
|
||||
session_timeout: int = 600
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, config: AppConfig):
|
||||
asyncio.create_task(self._cleanup_sessions())
|
||||
self.config = config
|
||||
|
||||
def add_or_restart_session(self, sid: str, ws_conn: WebSocket) -> Session:
|
||||
if sid in self._sessions:
|
||||
asyncio.create_task(self._sessions[sid].close())
|
||||
self._sessions[sid] = Session(sid=sid, ws=ws_conn)
|
||||
self._sessions[sid] = Session(sid=sid, ws=ws_conn, config=self.config)
|
||||
return self._sessions[sid]
|
||||
|
||||
def get_session(self, sid: str) -> Session | None:
|
||||
|
||||
@ -3,10 +3,13 @@ import time
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
|
||||
from opendevin.controller.agent import Agent
|
||||
from opendevin.core.config import AppConfig
|
||||
from opendevin.core.const.guide_url import TROUBLESHOOTING_URL
|
||||
from opendevin.core.logger import opendevin_logger as logger
|
||||
from opendevin.core.schema import AgentState
|
||||
from opendevin.core.schema.action import ActionType
|
||||
from opendevin.core.schema.config import ConfigType
|
||||
from opendevin.events.action import Action, ChangeAgentStateAction, NullAction
|
||||
from opendevin.events.event import Event, EventSource
|
||||
from opendevin.events.observation import (
|
||||
@ -16,6 +19,7 @@ from opendevin.events.observation import (
|
||||
)
|
||||
from opendevin.events.serialization import event_from_dict, event_to_dict
|
||||
from opendevin.events.stream import EventStreamSubscriber
|
||||
from opendevin.llm.llm import LLM
|
||||
|
||||
from .agent import AgentSession
|
||||
|
||||
@ -29,7 +33,7 @@ class Session:
|
||||
is_alive: bool = True
|
||||
agent_session: AgentSession
|
||||
|
||||
def __init__(self, sid: str, ws: WebSocket | None):
|
||||
def __init__(self, sid: str, ws: WebSocket | None, config: AppConfig):
|
||||
self.sid = sid
|
||||
self.websocket = ws
|
||||
self.last_active_ts = int(time.time())
|
||||
@ -37,6 +41,7 @@ class Session:
|
||||
self.agent_session.event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER, self.on_event
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def close(self):
|
||||
self.is_alive = False
|
||||
@ -67,8 +72,38 @@ class Session:
|
||||
self.agent_session.event_stream.add_event(
|
||||
AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT
|
||||
)
|
||||
# Extract the agent-relevant arguments from the request
|
||||
args = {
|
||||
key: value for key, value in data.get('args', {}).items() if value != ''
|
||||
}
|
||||
agent_cls = args.get(ConfigType.AGENT, self.config.default_agent)
|
||||
confirmation_mode = args.get(
|
||||
ConfigType.CONFIRMATION_MODE, self.config.confirmation_mode
|
||||
)
|
||||
max_iterations = args.get(ConfigType.MAX_ITERATIONS, self.config.max_iterations)
|
||||
# override default LLM config
|
||||
default_llm_config = self.config.get_llm_config()
|
||||
default_llm_config.model = args.get(
|
||||
ConfigType.LLM_MODEL, default_llm_config.model
|
||||
)
|
||||
default_llm_config.api_key = args.get(
|
||||
ConfigType.LLM_API_KEY, default_llm_config.api_key
|
||||
)
|
||||
|
||||
# TODO: override other LLM config & agent config groups (#2075)
|
||||
|
||||
llm = LLM(config=self.config.get_llm_config_from_agent(agent_cls))
|
||||
agent = Agent.get_cls(agent_cls)(llm)
|
||||
|
||||
# Create the agent session
|
||||
try:
|
||||
await self.agent_session.start(data)
|
||||
await self.agent_session.start(
|
||||
runtime_name=self.config.runtime,
|
||||
sandbox_config=self.config.sandbox,
|
||||
agent=agent,
|
||||
confirmation_mode=confirmation_mode,
|
||||
max_iterations=max_iterations,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f'Error creating controller: {e}')
|
||||
await self.send_error(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user