mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Feat: Introduce class for SessionInitData rather than using a dict (#5406)
This commit is contained in:
parent
1146b6248b
commit
de81020a8d
@ -95,10 +95,10 @@ workspace_base = "./workspace"
|
||||
# AWS secret access key
|
||||
#aws_secret_access_key = ""
|
||||
|
||||
# API key to use
|
||||
# API key to use (For Headless / CLI only - In Web this is overridden by Session Init)
|
||||
api_key = "your-api-key"
|
||||
|
||||
# API base URL
|
||||
# API base URL (For Headless / CLI only - In Web this is overridden by Session Init)
|
||||
#base_url = ""
|
||||
|
||||
# API version
|
||||
@ -131,7 +131,7 @@ embedding_model = "local"
|
||||
# Maximum number of output tokens
|
||||
#max_output_tokens = 0
|
||||
|
||||
# Model to use
|
||||
# Model to use. (For Headless / CLI only - In Web this is overridden by Session Init)
|
||||
model = "gpt-4o"
|
||||
|
||||
# Number of retries to attempt when an operation fails with the LLM.
|
||||
@ -237,10 +237,10 @@ llm_config = 'gpt3'
|
||||
##############################################################################
|
||||
[security]
|
||||
|
||||
# Enable confirmation mode
|
||||
# Enable confirmation mode (For Headless / CLI only - In Web this is overridden by Session Init)
|
||||
#confirmation_mode = false
|
||||
|
||||
# The security analyzer to use
|
||||
# The security analyzer to use (For Headless / CLI only - In Web this is overridden by Session Init)
|
||||
#security_analyzer = ""
|
||||
|
||||
#################################### Eval ####################################
|
||||
|
||||
@ -11,6 +11,7 @@ from openhands.events.stream import EventStream, session_exists
|
||||
from openhands.runtime.base import RuntimeUnavailableError
|
||||
from openhands.server.session.conversation import Conversation
|
||||
from openhands.server.session.session import ROOM_KEY, Session
|
||||
from openhands.server.session.session_init_data import SessionInitData
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.utils.shutdown_listener import should_continue
|
||||
|
||||
@ -141,7 +142,7 @@ class SessionManager:
|
||||
async def detach_from_conversation(self, conversation: Conversation):
|
||||
await conversation.disconnect()
|
||||
|
||||
async def init_or_join_session(self, sid: str, connection_id: str, data: dict):
|
||||
async def init_or_join_session(self, sid: str, connection_id: str, session_init_data: SessionInitData):
|
||||
await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
|
||||
self.local_connection_id_to_session_id[connection_id] = sid
|
||||
|
||||
@ -156,7 +157,7 @@ class SessionManager:
|
||||
if redis_client and await self._is_session_running_in_cluster(sid):
|
||||
return EventStream(sid, self.file_store)
|
||||
|
||||
return await self.start_local_session(sid, data)
|
||||
return await self.start_local_session(sid, session_init_data)
|
||||
|
||||
async def _is_session_running_in_cluster(self, sid: str) -> bool:
|
||||
"""As the rest of the cluster if a session is running. Wait a for a short timeout for a reply"""
|
||||
@ -210,14 +211,14 @@ class SessionManager:
|
||||
finally:
|
||||
self._has_remote_connections_flags.pop(sid)
|
||||
|
||||
async def start_local_session(self, sid: str, data: dict):
|
||||
async def start_local_session(self, sid: str, session_init_data: SessionInitData):
|
||||
# Start a new local session
|
||||
logger.info(f'start_new_local_session:{sid}')
|
||||
session = Session(
|
||||
sid=sid, file_store=self.file_store, config=self.config, sio=self.sio
|
||||
)
|
||||
self.local_sessions_by_sid[sid] = session
|
||||
await session.initialize_agent(data)
|
||||
await session.initialize_agent(session_init_data)
|
||||
return session.agent_session.event_stream
|
||||
|
||||
async def send_to_event_stream(self, connection_id: str, data: dict):
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
from copy import deepcopy
|
||||
import time
|
||||
|
||||
import socketio
|
||||
@ -21,6 +22,7 @@ from openhands.events.serialization import event_from_dict, event_to_dict
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.server.session.agent_session import AgentSession
|
||||
from openhands.server.session.session_init_data import SessionInitData
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
ROOM_KEY = 'room:{sid}'
|
||||
@ -34,7 +36,6 @@ class Session:
|
||||
agent_session: AgentSession
|
||||
loop: asyncio.AbstractEventLoop
|
||||
config: AppConfig
|
||||
settings: dict | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -52,41 +53,31 @@ class Session:
|
||||
self.agent_session.event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER, self.on_event, self.sid
|
||||
)
|
||||
self.config = config
|
||||
# Copying this means that when we update variables they are not applied to the shared global configuration!
|
||||
self.config = deepcopy(config)
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.settings = None
|
||||
|
||||
def close(self):
|
||||
self.is_alive = False
|
||||
self.agent_session.close()
|
||||
|
||||
async def initialize_agent(self, data: dict):
|
||||
self.settings = data
|
||||
async def initialize_agent(self, session_init_data: SessionInitData):
|
||||
self.agent_session.event_stream.add_event(
|
||||
AgentStateChangedObservation('', AgentState.LOADING),
|
||||
EventSource.ENVIRONMENT,
|
||||
)
|
||||
# Extract the agent-relevant arguments from the request
|
||||
args = {key: value for key, value in data.get('args', {}).items()}
|
||||
agent_cls = args.get(ConfigType.AGENT, self.config.default_agent)
|
||||
self.config.security.confirmation_mode = args.get(
|
||||
ConfigType.CONFIRMATION_MODE, self.config.security.confirmation_mode
|
||||
)
|
||||
self.config.security.security_analyzer = data.get('args', {}).get(
|
||||
ConfigType.SECURITY_ANALYZER, self.config.security.security_analyzer
|
||||
)
|
||||
max_iterations = args.get(ConfigType.MAX_ITERATIONS, self.config.max_iterations)
|
||||
agent_cls = session_init_data.agent or self.config.default_agent
|
||||
self.config.security.confirmation_mode = self.config.security.confirmation_mode if session_init_data.confirmation_mode is None else session_init_data.confirmation_mode
|
||||
self.config.security.security_analyzer = session_init_data.security_analyzer or self.config.security.security_analyzer
|
||||
max_iterations = session_init_data.max_iterations or self.config.max_iterations
|
||||
# override default LLM config
|
||||
|
||||
|
||||
default_llm_config = self.config.get_llm_config()
|
||||
default_llm_config.model = args.get(
|
||||
ConfigType.LLM_MODEL, default_llm_config.model
|
||||
)
|
||||
default_llm_config.api_key = args.get(
|
||||
ConfigType.LLM_API_KEY, default_llm_config.api_key
|
||||
)
|
||||
default_llm_config.base_url = args.get(
|
||||
ConfigType.LLM_BASE_URL, default_llm_config.base_url
|
||||
)
|
||||
default_llm_config.model = session_init_data.llm_model or default_llm_config.model
|
||||
default_llm_config.api_key = session_init_data.llm_api_key or default_llm_config.api_key
|
||||
default_llm_config.base_url = session_init_data.llm_base_url or default_llm_config.base_url
|
||||
|
||||
# TODO: override other LLM config & agent config groups (#2075)
|
||||
|
||||
|
||||
18
openhands/server/session/session_init_data.py
Normal file
18
openhands/server/session/session_init_data.py
Normal file
@ -0,0 +1,18 @@
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionInitData:
|
||||
"""
|
||||
Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
|
||||
"""
|
||||
language: str | None = None
|
||||
agent: str | None = None
|
||||
max_iterations: int | None = None
|
||||
security_analyzer: str | None = None
|
||||
confirmation_mode: bool | None = None
|
||||
llm_model: str | None = None
|
||||
llm_api_key: str | None = None
|
||||
llm_base_url: str | None = None
|
||||
@ -13,6 +13,7 @@ from openhands.events.serialization import event_to_dict
|
||||
from openhands.events.stream import AsyncEventStreamWrapper
|
||||
from openhands.server.auth import get_sid_from_token, sign_token
|
||||
from openhands.server.github_utils import authenticate_github_user
|
||||
from openhands.server.session.session_init_data import SessionInitData
|
||||
from openhands.server.shared import config, session_manager, sio
|
||||
|
||||
|
||||
@ -26,19 +27,30 @@ async def oh_action(connection_id: str, data: dict):
|
||||
# If it's an init, we do it here.
|
||||
action = data.get('action', '')
|
||||
if action == ActionType.INIT:
|
||||
await init_connection(connection_id, data)
|
||||
token = data.pop('token', None)
|
||||
github_token = data.pop('github_token', None)
|
||||
latest_event_id = int(data.pop('latest_event_id', -1))
|
||||
kwargs = {k.lower(): v for k, v in (data.get('args') or {}).items()}
|
||||
session_init_data = SessionInitData(**kwargs)
|
||||
await init_connection(
|
||||
connection_id, token, github_token, session_init_data, latest_event_id
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(f'sio:oh_action:{connection_id}')
|
||||
await session_manager.send_to_event_stream(connection_id, data)
|
||||
|
||||
|
||||
async def init_connection(connection_id: str, data: dict):
|
||||
gh_token = data.pop('github_token', None)
|
||||
async def init_connection(
|
||||
connection_id: str,
|
||||
token: str | None,
|
||||
gh_token: str | None,
|
||||
session_init_data: SessionInitData,
|
||||
latest_event_id: int,
|
||||
):
|
||||
if not await authenticate_github_user(gh_token):
|
||||
raise RuntimeError(status.WS_1008_POLICY_VIOLATION)
|
||||
|
||||
token = data.pop('token', None)
|
||||
if token:
|
||||
sid = get_sid_from_token(token, config.jwt_secret)
|
||||
if sid == '':
|
||||
@ -52,10 +64,10 @@ async def init_connection(connection_id: str, data: dict):
|
||||
token = sign_token({'sid': sid}, config.jwt_secret)
|
||||
await sio.emit('oh_event', {'token': token, 'status': 'ok'}, to=connection_id)
|
||||
|
||||
latest_event_id = int(data.pop('latest_event_id', -1))
|
||||
|
||||
# The session in question should exist, but may not actually be running locally...
|
||||
event_stream = await session_manager.init_or_join_session(sid, connection_id, data)
|
||||
event_stream = await session_manager.init_or_join_session(
|
||||
sid, connection_id, session_init_data
|
||||
)
|
||||
|
||||
# Send events
|
||||
agent_state_changed = None
|
||||
|
||||
@ -7,6 +7,7 @@ import pytest
|
||||
|
||||
from openhands.core.config.app_config import AppConfig
|
||||
from openhands.server.session.manager import SessionManager
|
||||
from openhands.server.session.session_init_data import SessionInitData
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
@ -100,7 +101,7 @@ async def test_init_new_local_session():
|
||||
sio, AppConfig(), InMemoryFileStore()
|
||||
) as session_manager:
|
||||
await session_manager.init_or_join_session(
|
||||
'new-session-id', 'new-session-id', {'type': 'mock-settings'}
|
||||
'new-session-id', 'new-session-id', SessionInitData()
|
||||
)
|
||||
assert session_instance.initialize_agent.call_count == 1
|
||||
assert sio.enter_room.await_count == 1
|
||||
@ -132,11 +133,11 @@ async def test_join_local_session():
|
||||
) as session_manager:
|
||||
# First call initializes
|
||||
await session_manager.init_or_join_session(
|
||||
'new-session-id', 'new-session-id', {'type': 'mock-settings'}
|
||||
'new-session-id', 'new-session-id', SessionInitData()
|
||||
)
|
||||
# Second call joins
|
||||
await session_manager.init_or_join_session(
|
||||
'new-session-id', 'extra-connection-id', {'type': 'mock-settings'}
|
||||
'new-session-id', 'extra-connection-id', SessionInitData()
|
||||
)
|
||||
assert session_instance.initialize_agent.call_count == 1
|
||||
assert sio.enter_room.await_count == 2
|
||||
@ -168,7 +169,7 @@ async def test_join_cluster_session():
|
||||
) as session_manager:
|
||||
# First call initializes
|
||||
await session_manager.init_or_join_session(
|
||||
'new-session-id', 'new-session-id', {'type': 'mock-settings'}
|
||||
'new-session-id', 'new-session-id', SessionInitData()
|
||||
)
|
||||
assert session_instance.initialize_agent.call_count == 0
|
||||
assert sio.enter_room.await_count == 1
|
||||
@ -199,7 +200,7 @@ async def test_add_to_local_event_stream():
|
||||
sio, AppConfig(), InMemoryFileStore()
|
||||
) as session_manager:
|
||||
await session_manager.init_or_join_session(
|
||||
'new-session-id', 'connection-id', {'type': 'mock-settings'}
|
||||
'new-session-id', 'connection-id', SessionInitData()
|
||||
)
|
||||
await session_manager.send_to_event_stream(
|
||||
'connection-id', {'event_type': 'some_event'}
|
||||
@ -232,7 +233,7 @@ async def test_add_to_cluster_event_stream():
|
||||
sio, AppConfig(), InMemoryFileStore()
|
||||
) as session_manager:
|
||||
await session_manager.init_or_join_session(
|
||||
'new-session-id', 'connection-id', {'type': 'mock-settings'}
|
||||
'new-session-id', 'connection-id', SessionInitData()
|
||||
)
|
||||
await session_manager.send_to_event_stream(
|
||||
'connection-id', {'event_type': 'some_event'}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user