Add initial user msg to /new_conversation route (#6314)

This commit is contained in:
Rohit Malhotra 2025-01-17 09:43:03 -05:00 committed by GitHub
parent 2edb2337c2
commit 000055ba73
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 22 additions and 8 deletions

View File

@ -32,12 +32,14 @@ UPDATED_AT_CALLBACK_ID = 'updated_at_callback_id'
class InitSessionRequest(BaseModel):
github_token: str | None = None
selected_repository: str | None = None
initial_user_msg: str | None = None
async def _create_new_conversation(
user_id: str | None,
token: str | None,
selected_repository: str | None,
initial_user_msg: str | None,
):
logger.info('Loading settings')
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
@ -89,7 +91,7 @@ async def _create_new_conversation(
logger.info(f'Starting agent loop for conversation {conversation_id}')
event_stream = await session_manager.maybe_start_agent_loop(
conversation_id, conversation_init_data, user_id
conversation_id, conversation_init_data, user_id, initial_user_msg
)
try:
event_stream.subscribe(
@ -114,10 +116,11 @@ async def new_conversation(request: Request, data: InitSessionRequest):
user_id = get_user_id(request)
github_token = getattr(request.state, 'github_token', '') or data.github_token
selected_repository = data.selected_repository
initial_user_msg = data.initial_user_msg
try:
conversation_id = await _create_new_conversation(
user_id, github_token, selected_repository
user_id, github_token, selected_repository, initial_user_msg
)
return JSONResponse(
@ -140,6 +143,7 @@ async def new_conversation(request: Request, data: InitSessionRequest):
'message': str(e),
'msg_id': 'STATUS$ERROR_LLM_AUTHENTICATION',
},
status_code=400,
)

View File

@ -10,6 +10,7 @@ from openhands.core.exceptions import AgentRuntimeUnavailableError
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
from openhands.events.action import ChangeAgentStateAction
from openhands.events.action.message import MessageAction
from openhands.events.event import EventSource
from openhands.events.stream import EventStream
from openhands.microagent import BaseMicroAgent
@ -71,6 +72,7 @@ class AgentSession:
agent_configs: dict[str, AgentConfig] | None = None,
github_token: str | None = None,
selected_repository: str | None = None,
initial_user_msg: str | None = None,
):
"""Starts the Agent session
Parameters:
@ -112,6 +114,12 @@ class AgentSession:
self.event_stream.add_event(
ChangeAgentStateAction(AgentState.INIT), EventSource.ENVIRONMENT
)
if initial_user_msg:
self.event_stream.add_event(
MessageAction(content=initial_user_msg), EventSource.USER
)
self._starting = False
async def close(self):

View File

@ -442,7 +442,11 @@ class SessionManager:
self._connection_queries.pop(query_id, None)
async def maybe_start_agent_loop(
self, sid: str, settings: Settings, user_id: str | None
self,
sid: str,
settings: Settings,
user_id: str | None,
initial_user_msg: str | None = None,
) -> EventStream:
logger.info(f'maybe_start_agent_loop:{sid}')
session: Session | None = None
@ -462,7 +466,7 @@ class SessionManager:
user_id=user_id,
)
self._local_agent_loops_by_sid[sid] = session
asyncio.create_task(session.initialize_agent(settings))
asyncio.create_task(session.initialize_agent(settings, initial_user_msg))
event_stream = await self._get_event_stream(sid)
if not event_stream:

View File

@ -74,10 +74,7 @@ class Session:
self.is_alive = False
await self.agent_session.close()
async def initialize_agent(
self,
settings: Settings,
):
async def initialize_agent(self, settings: Settings, initial_user_msg: str | None):
self.agent_session.event_stream.add_event(
AgentStateChangedObservation('', AgentState.LOADING),
EventSource.ENVIRONMENT,
@ -122,6 +119,7 @@ class Session:
agent_configs=self.config.get_agent_configs(),
github_token=github_token,
selected_repository=selected_repository,
initial_user_msg=initial_user_msg,
)
except Exception as e:
logger.exception(f'Error creating agent_session: {e}')