Remove global config from session (#2987)

* Remove global config from session

* Fix double agent
This commit is contained in:
Graham Neubig 2024-07-18 11:39:38 -04:00 committed by GitHub
parent 9d41314d1a
commit 692fe21d60
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 66 additions and 48 deletions

View File

@ -40,7 +40,9 @@ from opendevin.events.observation import (
from opendevin.events.serialization import event_to_dict
from opendevin.llm import bedrock
from opendevin.server.auth import get_sid_from_token, sign_token
from opendevin.server.session import session_manager
from opendevin.server.session import SessionManager
session_manager = SessionManager(config)
app = FastAPI()
app.add_middleware(

View File

@ -1,6 +1,4 @@
from .manager import SessionManager
from .session import Session
session_manager = SessionManager()
__all__ = ['Session', 'SessionManager', 'session_manager']
__all__ = ['Session', 'SessionManager']

View File

@ -4,11 +4,9 @@ from agenthub.codeact_agent.codeact_agent import CodeActAgent
from opendevin.controller import AgentController
from opendevin.controller.agent import Agent
from opendevin.controller.state.state import State
from opendevin.core.config import config
from opendevin.core.config import SandboxConfig
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.schema import ConfigType
from opendevin.events.stream import EventStream
from opendevin.llm.llm import LLM
from opendevin.runtime import DockerSSHBox, get_runtime_cls
from opendevin.runtime.runtime import Runtime
from opendevin.runtime.server.runtime import ServerRuntime
@ -32,7 +30,14 @@ class AgentSession:
self.sid = sid
self.event_stream = EventStream(sid)
async def start(self, start_event: dict):
async def start(
self,
runtime_name: str,
sandbox_config: SandboxConfig,
agent: Agent,
confirmation_mode: bool,
max_iterations: int,
):
"""Starts the agent session.
Args:
@ -42,8 +47,8 @@ class AgentSession:
raise Exception(
'Session already started. You need to close this session and start a new one.'
)
await self._create_runtime()
await self._create_controller(start_event)
await self._create_runtime(runtime_name, sandbox_config)
await self._create_controller(agent, confirmation_mode, max_iterations)
async def close(self):
if self._closed:
@ -56,52 +61,28 @@ class AgentSession:
await self.runtime.close()
self._closed = True
async def _create_runtime(self):
async def _create_runtime(self, runtime_name: str, sandbox_config: SandboxConfig):
"""Creates a runtime instance."""
if self.runtime is not None:
raise Exception('Runtime already created')
logger.info(f'Using runtime: {config.runtime}')
runtime_cls = get_runtime_cls(config.runtime)
logger.info(f'Using runtime: {runtime_name}')
runtime_cls = get_runtime_cls(runtime_name)
self.runtime = runtime_cls(
sandbox_config=config.sandbox, event_stream=self.event_stream, sid=self.sid
sandbox_config=sandbox_config, event_stream=self.event_stream, sid=self.sid
)
await self.runtime.ainit()
async def _create_controller(self, start_event: dict):
"""Creates an AgentController instance.
Args:
start_event: The start event data.
"""
async def _create_controller(
self, agent: Agent, confirmation_mode: bool, max_iterations: int
):
"""Creates an AgentController instance."""
if self.controller is not None:
raise Exception('Controller already created')
if self.runtime is None:
raise Exception('Runtime must be initialized before the agent controller')
args = {
key: value
for key, value in start_event.get('args', {}).items()
if value != ''
} # remove empty values, prevent FE from sending empty strings
agent_cls = args.get(ConfigType.AGENT, config.default_agent)
confirmation_mode = args.get(
ConfigType.CONFIRMATION_MODE, config.confirmation_mode
)
max_iterations = args.get(ConfigType.MAX_ITERATIONS, config.max_iterations)
# override default LLM config
default_llm_config = config.get_llm_config()
default_llm_config.model = args.get(
ConfigType.LLM_MODEL, default_llm_config.model
)
default_llm_config.api_key = args.get(
ConfigType.LLM_API_KEY, default_llm_config.api_key
)
# TODO: override other LLM config & agent config groups (#2075)
llm = LLM(config=config.get_llm_config_from_agent(agent_cls))
agent = Agent.get_cls(agent_cls)(llm)
logger.info(f'Creating agent {agent.name} using LLM {llm}')
logger.info(f'Creating agent {agent.name} using LLM {agent.llm.config.model}')
if isinstance(agent, CodeActAgent):
if not self.runtime or not (
isinstance(self.runtime, ServerRuntime)

View File

@ -4,6 +4,7 @@ from typing import Optional
from fastapi import WebSocket
from opendevin.core.config import AppConfig
from opendevin.core.logger import opendevin_logger as logger
from .session import Session
@ -14,13 +15,14 @@ class SessionManager:
cleanup_interval: int = 300
session_timeout: int = 600
def __init__(self):
def __init__(self, config: AppConfig):
asyncio.create_task(self._cleanup_sessions())
self.config = config
def add_or_restart_session(self, sid: str, ws_conn: WebSocket) -> Session:
if sid in self._sessions:
asyncio.create_task(self._sessions[sid].close())
self._sessions[sid] = Session(sid=sid, ws=ws_conn)
self._sessions[sid] = Session(sid=sid, ws=ws_conn, config=self.config)
return self._sessions[sid]
def get_session(self, sid: str) -> Session | None:

View File

@ -3,10 +3,13 @@ import time
from fastapi import WebSocket, WebSocketDisconnect
from opendevin.controller.agent import Agent
from opendevin.core.config import AppConfig
from opendevin.core.const.guide_url import TROUBLESHOOTING_URL
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.schema import AgentState
from opendevin.core.schema.action import ActionType
from opendevin.core.schema.config import ConfigType
from opendevin.events.action import Action, ChangeAgentStateAction, NullAction
from opendevin.events.event import Event, EventSource
from opendevin.events.observation import (
@ -16,6 +19,7 @@ from opendevin.events.observation import (
)
from opendevin.events.serialization import event_from_dict, event_to_dict
from opendevin.events.stream import EventStreamSubscriber
from opendevin.llm.llm import LLM
from .agent import AgentSession
@ -29,7 +33,7 @@ class Session:
is_alive: bool = True
agent_session: AgentSession
def __init__(self, sid: str, ws: WebSocket | None):
def __init__(self, sid: str, ws: WebSocket | None, config: AppConfig):
self.sid = sid
self.websocket = ws
self.last_active_ts = int(time.time())
@ -37,6 +41,7 @@ class Session:
self.agent_session.event_stream.subscribe(
EventStreamSubscriber.SERVER, self.on_event
)
self.config = config
async def close(self):
self.is_alive = False
@ -67,8 +72,38 @@ class Session:
self.agent_session.event_stream.add_event(
AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT
)
# Extract the agent-relevant arguments from the request
args = {
key: value for key, value in data.get('args', {}).items() if value != ''
}
agent_cls = args.get(ConfigType.AGENT, self.config.default_agent)
confirmation_mode = args.get(
ConfigType.CONFIRMATION_MODE, self.config.confirmation_mode
)
max_iterations = args.get(ConfigType.MAX_ITERATIONS, self.config.max_iterations)
# override default LLM config
default_llm_config = self.config.get_llm_config()
default_llm_config.model = args.get(
ConfigType.LLM_MODEL, default_llm_config.model
)
default_llm_config.api_key = args.get(
ConfigType.LLM_API_KEY, default_llm_config.api_key
)
# TODO: override other LLM config & agent config groups (#2075)
llm = LLM(config=self.config.get_llm_config_from_agent(agent_cls))
agent = Agent.get_cls(agent_cls)(llm)
# Create the agent session
try:
await self.agent_session.start(data)
await self.agent_session.start(
runtime_name=self.config.runtime,
sandbox_config=self.config.sandbox,
agent=agent,
confirmation_mode=confirmation_mode,
max_iterations=max_iterations,
)
except Exception as e:
logger.exception(f'Error creating controller: {e}')
await self.send_error(