mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
* CodeActAgent: fix message prep if prompt caching is not supported * fix python version in regen tests workflow * fix in conftest "mock_completion" method * add disable_vision to LLMConfig; revert change in message parsing in llm.py * format messages in several files for completion * refactored message(s) formatting (llm.py); added vision_is_active() * fix a unit test * regenerate: added LOG_TO_FILE and FORCE_REGENERATE env flags * try to fix path to logs folder in workflow * llm: prevent index error * try FORCE_USE_LLM in regenerate * tweaks everywhere... * fix 2 random unit test errors :( * added FORCE_REGENERATE_TESTS=true to regenerate CLI * fix test_lint_file_fail_typescript again * double-quotes for env vars in workflow; llm logger set to debug * fix typo in regenerate * regenerate iterations now 20; applied iteration counter fix by Li * regenerate: pass FORCE_REGENERATE flag into env * fixes for int tests. several mock files updated. * browsing_agent: fix response_parser.py adding ) to empty response * test_browse_internet: fix skipif and revert obsolete mock files * regenerate: fi bracketing for http server start/kill conditions * disable test_browse_internet for CodeAct*Agents; mock files updated after merge * missed to include more mock files earlier * reverts after review feedback from Li * forgot one * browsing agent test, partial fixes and updated mock files * test_browse_internet works in my WSL now! * adapt unit test test_prompt_caching.py * add DEBUG to regenerate workflow command * convert regenerate workflow params to inputs * more integration test mock files updated * more files * test_prompt_caching: restored test_prompt_caching_headers purpose * file_ops: fix potential exception, like "cross device copy"; fixed mock files accordingly * reverts/changes wrt feedback from xingyao * updated docs and config template * code cleanup wrt review feedback
199 lines
7.6 KiB
Python
199 lines
7.6 KiB
Python
import asyncio
|
|
import time
|
|
|
|
from fastapi import WebSocket, WebSocketDisconnect
|
|
|
|
from openhands.controller.agent import Agent
|
|
from openhands.core.config import AppConfig
|
|
from openhands.core.const.guide_url import TROUBLESHOOTING_URL
|
|
from openhands.core.logger import openhands_logger as logger
|
|
from openhands.core.schema import AgentState
|
|
from openhands.core.schema.action import ActionType
|
|
from openhands.core.schema.config import ConfigType
|
|
from openhands.events.action import ChangeAgentStateAction, MessageAction, NullAction
|
|
from openhands.events.event import Event, EventSource
|
|
from openhands.events.observation import (
|
|
AgentStateChangedObservation,
|
|
CmdOutputObservation,
|
|
NullObservation,
|
|
)
|
|
from openhands.events.serialization import event_from_dict, event_to_dict
|
|
from openhands.events.stream import EventStreamSubscriber
|
|
from openhands.llm.llm import LLM
|
|
from openhands.server.session.agent import AgentSession
|
|
from openhands.storage.files import FileStore
|
|
|
|
# Maximum idle age, in seconds (5 hours): Session.load_from_data refuses to
# restore a session whose last recorded activity is older than this.
DEL_DELT_SEC = 5 * 60 * 60
|
|
|
|
|
|
class Session:
    """One client websocket connection bound to an ``AgentSession``.

    Incoming JSON messages from the websocket are dispatched either to agent
    initialization (``ActionType.INIT``) or onto the agent session's event
    stream; events published on that stream are forwarded back to the client
    by ``on_event``.
    """

    sid: str
    websocket: WebSocket | None
    # Unix timestamp (seconds) of the last successful send / (re)connection.
    last_active_ts: int = 0
    # Set to False once the connection is known to be gone or the session closed.
    is_alive: bool = True
    agent_session: AgentSession

    def __init__(
        self, sid: str, ws: WebSocket | None, config: AppConfig, file_store: FileStore
    ):
        """Create the session and subscribe it to its agent session's events.

        Args:
            sid: Session identifier, also passed to ``AgentSession``.
            ws: The client websocket, or None if no client is attached yet.
            config: Application config; mutated in place by ``_initialize_agent``.
            file_store: Storage backend handed to ``AgentSession``.
        """
        self.sid = sid
        self.websocket = ws
        self.last_active_ts = int(time.time())
        self.agent_session = AgentSession(sid, file_store)
        # Every event on the stream is routed to on_event, which decides what
        # gets forwarded to the client.
        self.agent_session.event_stream.subscribe(
            EventStreamSubscriber.SERVER, self.on_event
        )
        self.config = config

    async def close(self):
        """Mark the session as dead and close the underlying agent session."""
        self.is_alive = False
        await self.agent_session.close()

    async def loop_recv(self):
        """Receive and dispatch client messages until the websocket goes away.

        Malformed JSON is reported back to the client and skipped. On
        disconnect (or a ``RuntimeError`` from the websocket layer) the session
        is closed.
        """
        try:
            if self.websocket is None:
                return
            while True:
                try:
                    data = await self.websocket.receive_json()
                except ValueError:
                    await self.send_error('Invalid JSON')
                    continue
                await self.dispatch(data)
        except WebSocketDisconnect:
            await self.close()
            logger.info('WebSocket disconnected, sid: %s', self.sid)
        except RuntimeError as e:
            await self.close()
            logger.exception('Error in loop_recv: %s', e)

    async def _initialize_agent(self, data: dict):
        """Configure and start the agent session from an INIT message.

        Args:
            data: The INIT message. Its optional ``args`` mapping may override
                the agent class, confirmation mode, security analyzer, max
                iterations, and the default LLM's model / API key / base URL.
                Values equal to ``''`` (except for the security analyzer) are
                treated as "not set" and fall back to ``self.config``.
        """
        self.agent_session.event_stream.add_event(
            ChangeAgentStateAction(AgentState.LOADING), EventSource.USER
        )
        self.agent_session.event_stream.add_event(
            AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT
        )

        # Tolerate a missing *or null* ``args`` field (``"args": null`` would
        # otherwise crash on ``.items()``).
        raw_args = data.get('args') or {}
        # Extract the agent-relevant arguments from the request
        args = {key: value for key, value in raw_args.items() if value != ''}

        agent_cls = args.get(ConfigType.AGENT, self.config.default_agent)
        self.config.security.confirmation_mode = args.get(
            ConfigType.CONFIRMATION_MODE, self.config.security.confirmation_mode
        )
        # NOTE(review): unlike the other settings, this reads the unfiltered
        # args, so an explicit '' replaces the configured analyzer (presumably
        # so a client can switch it off) — confirm this is intended.
        self.config.security.security_analyzer = raw_args.get(
            ConfigType.SECURITY_ANALYZER, self.config.security.security_analyzer
        )
        max_iterations = args.get(ConfigType.MAX_ITERATIONS, self.config.max_iterations)

        # override default LLM config
        # NOTE(review): this assigns onto whatever get_llm_config() returns;
        # presumably that is the shared config object, so the override is
        # visible to get_llm_config_from_agent below — verify.
        default_llm_config = self.config.get_llm_config()
        default_llm_config.model = args.get(
            ConfigType.LLM_MODEL, default_llm_config.model
        )
        default_llm_config.api_key = args.get(
            ConfigType.LLM_API_KEY, default_llm_config.api_key
        )
        default_llm_config.base_url = args.get(
            ConfigType.LLM_BASE_URL, default_llm_config.base_url
        )

        # TODO: override other LLM config & agent config groups (#2075)

        llm = LLM(config=self.config.get_llm_config_from_agent(agent_cls))
        agent_config = self.config.get_agent_config(agent_cls)
        agent = Agent.get_cls(agent_cls)(llm, agent_config)

        # Create the agent session
        try:
            await self.agent_session.start(
                runtime_name=self.config.runtime,
                config=self.config,
                agent=agent,
                max_iterations=max_iterations,
                max_budget_per_task=self.config.max_budget_per_task,
                agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
                agent_configs=self.config.get_agent_configs(),
            )
        except Exception as e:
            # Boundary handler: log the full traceback, tell the client, and
            # leave the agent in LOADING rather than advancing to INIT.
            logger.exception(f'Error creating controller: {e}')
            await self.send_error(
                f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
            )
            return
        self.agent_session.event_stream.add_event(
            ChangeAgentStateAction(AgentState.INIT), EventSource.USER
        )

    async def on_event(self, event: Event):
        """Callback function for agent events.

        Forwards agent-originated events to the client, plus user-originated
        ``CmdOutputObservation`` events. Null actions/observations are dropped.

        Args:
            event: The agent event (Observation or Action).
        """
        if isinstance(event, NullAction):
            return
        if isinstance(event, NullObservation):
            return
        if event.source == EventSource.AGENT:
            logger.info('Server event')
            await self.send(event_to_dict(event))
        elif event.source == EventSource.USER and isinstance(
            event, CmdOutputObservation
        ):
            await self.send(event_to_dict(event))

    async def dispatch(self, data: dict):
        """Route one decoded client message.

        INIT messages start the agent; anything else is deserialized into an
        event and added to the stream as a USER event. Messages carrying images
        are rejected (with an error sent to the client) when the running
        agent's LLM has vision disabled or not active.
        """
        action = data.get('action', '')
        if action == ActionType.INIT:
            await self._initialize_agent(data)
            return
        # Copy so deserialization cannot mutate the caller's dict.
        event = event_from_dict(data.copy())
        # This checks if the model supports images
        if isinstance(event, MessageAction) and event.images_urls:
            controller = self.agent_session.controller
            if controller:
                if controller.agent.llm.config.disable_vision:
                    await self.send_error(
                        'Support for images is disabled for this model, try without an image.'
                    )
                    return
                if not controller.agent.llm.vision_is_active():
                    await self.send_error(
                        'Model does not support image upload, change to a different model or try without an image.'
                    )
                    return
        self.agent_session.event_stream.add_event(event, EventSource.USER)

    async def send(self, data: dict[str, object]) -> bool:
        """Send ``data`` to the client as JSON.

        Returns:
            True if the message was sent. False if there is no websocket, the
            session is not alive, or the connection failed while sending — in
            the last case the session is also marked as not alive.
        """
        try:
            if self.websocket is None or not self.is_alive:
                return False
            await self.websocket.send_json(data)
            await asyncio.sleep(0.001)  # This flushes the data to the client
            self.last_active_ts = int(time.time())
            return True
        except (RuntimeError, WebSocketDisconnect):
            # The websocket layer can raise RuntimeError (e.g. sending after the
            # socket was closed) instead of WebSocketDisconnect; loop_recv
            # already treats RuntimeError as a dead connection, so do the same.
            self.is_alive = False
            return False

    async def send_error(self, message: str) -> bool:
        """Sends an error message to the client."""
        return await self.send({'error': True, 'message': message})

    async def send_message(self, message: str) -> bool:
        """Sends a message to the client."""
        return await self.send({'message': message})

    def update_connection(self, ws: WebSocket):
        """Attach a new websocket (e.g. on reconnect) and mark the session alive."""
        self.websocket = ws
        self.is_alive = True
        self.last_active_ts = int(time.time())

    def load_from_data(self, data: dict) -> bool:
        """Restore activity state from a previously saved dict.

        ``last_active_ts`` is always overwritten (defaulting to 0). If it is
        older than ``DEL_DELT_SEC`` seconds, the session is considered expired
        and False is returned without touching ``is_alive``; otherwise
        ``is_alive`` is restored (defaulting to False) and True is returned.
        """
        self.last_active_ts = data.get('last_active_ts', 0)
        if self.last_active_ts < int(time.time()) - DEL_DELT_SEC:
            return False
        self.is_alive = data.get('is_alive', False)
        return True