mirror of
https://github.com/OpenHands/OpenHands.git
synced 2025-12-26 05:48:36 +08:00
Add initial user msg to /new_conversation route (#6314)
This commit is contained in:
parent
2edb2337c2
commit
000055ba73
@ -32,12 +32,14 @@ UPDATED_AT_CALLBACK_ID = 'updated_at_callback_id'
|
||||
class InitSessionRequest(BaseModel):
|
||||
github_token: str | None = None
|
||||
selected_repository: str | None = None
|
||||
initial_user_msg: str | None = None
|
||||
|
||||
|
||||
async def _create_new_conversation(
|
||||
user_id: str | None,
|
||||
token: str | None,
|
||||
selected_repository: str | None,
|
||||
initial_user_msg: str | None,
|
||||
):
|
||||
logger.info('Loading settings')
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
||||
@ -89,7 +91,7 @@ async def _create_new_conversation(
|
||||
|
||||
logger.info(f'Starting agent loop for conversation {conversation_id}')
|
||||
event_stream = await session_manager.maybe_start_agent_loop(
|
||||
conversation_id, conversation_init_data, user_id
|
||||
conversation_id, conversation_init_data, user_id, initial_user_msg
|
||||
)
|
||||
try:
|
||||
event_stream.subscribe(
|
||||
@ -114,10 +116,11 @@ async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
user_id = get_user_id(request)
|
||||
github_token = getattr(request.state, 'github_token', '') or data.github_token
|
||||
selected_repository = data.selected_repository
|
||||
initial_user_msg = data.initial_user_msg
|
||||
|
||||
try:
|
||||
conversation_id = await _create_new_conversation(
|
||||
user_id, github_token, selected_repository
|
||||
user_id, github_token, selected_repository, initial_user_msg
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
@ -140,6 +143,7 @@ async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
'message': str(e),
|
||||
'msg_id': 'STATUS$ERROR_LLM_AUTHENTICATION',
|
||||
},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ from openhands.core.exceptions import AgentRuntimeUnavailableError
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import ChangeAgentStateAction
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.microagent import BaseMicroAgent
|
||||
@ -71,6 +72,7 @@ class AgentSession:
|
||||
agent_configs: dict[str, AgentConfig] | None = None,
|
||||
github_token: str | None = None,
|
||||
selected_repository: str | None = None,
|
||||
initial_user_msg: str | None = None,
|
||||
):
|
||||
"""Starts the Agent session
|
||||
Parameters:
|
||||
@ -112,6 +114,12 @@ class AgentSession:
|
||||
self.event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.INIT), EventSource.ENVIRONMENT
|
||||
)
|
||||
|
||||
if initial_user_msg:
|
||||
self.event_stream.add_event(
|
||||
MessageAction(content=initial_user_msg), EventSource.USER
|
||||
)
|
||||
|
||||
self._starting = False
|
||||
|
||||
async def close(self):
|
||||
|
||||
@ -442,7 +442,11 @@ class SessionManager:
|
||||
self._connection_queries.pop(query_id, None)
|
||||
|
||||
async def maybe_start_agent_loop(
|
||||
self, sid: str, settings: Settings, user_id: str | None
|
||||
self,
|
||||
sid: str,
|
||||
settings: Settings,
|
||||
user_id: str | None,
|
||||
initial_user_msg: str | None = None,
|
||||
) -> EventStream:
|
||||
logger.info(f'maybe_start_agent_loop:{sid}')
|
||||
session: Session | None = None
|
||||
@ -462,7 +466,7 @@ class SessionManager:
|
||||
user_id=user_id,
|
||||
)
|
||||
self._local_agent_loops_by_sid[sid] = session
|
||||
asyncio.create_task(session.initialize_agent(settings))
|
||||
asyncio.create_task(session.initialize_agent(settings, initial_user_msg))
|
||||
|
||||
event_stream = await self._get_event_stream(sid)
|
||||
if not event_stream:
|
||||
|
||||
@ -74,10 +74,7 @@ class Session:
|
||||
self.is_alive = False
|
||||
await self.agent_session.close()
|
||||
|
||||
async def initialize_agent(
|
||||
self,
|
||||
settings: Settings,
|
||||
):
|
||||
async def initialize_agent(self, settings: Settings, initial_user_msg: str | None):
|
||||
self.agent_session.event_stream.add_event(
|
||||
AgentStateChangedObservation('', AgentState.LOADING),
|
||||
EventSource.ENVIRONMENT,
|
||||
@ -122,6 +119,7 @@ class Session:
|
||||
agent_configs=self.config.get_agent_configs(),
|
||||
github_token=github_token,
|
||||
selected_repository=selected_repository,
|
||||
initial_user_msg=initial_user_msg,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f'Error creating agent_session: {e}')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user