(fix) actions.ts: restored handleAssistantMessage handling order (#4074)

This commit is contained in:
tobitege
2024-09-26 21:56:12 +02:00
committed by GitHub
parent c919086e25
commit 29c34e0b6a
6 changed files with 29 additions and 16 deletions

View File

@@ -1,5 +1,4 @@
import asyncio
from threading import Thread
from typing import Callable, Optional
@@ -76,10 +75,20 @@ class AgentSession:
self.thread = Thread(target=self._run, daemon=True)
self.thread.start()
coro = self._start(runtime_name, config, agent, max_iterations, max_budget_per_task, agent_to_llm_config, agent_configs, status_message_callback)
asyncio.run_coroutine_threadsafe(coro, self.loop) # type: ignore
coro = self._start(
runtime_name,
config,
agent,
max_iterations,
max_budget_per_task,
agent_to_llm_config,
agent_configs,
status_message_callback,
)
asyncio.run_coroutine_threadsafe(coro, self.loop) # type: ignore
async def _start(self,
async def _start(
self,
runtime_name: str,
config: AppConfig,
agent: Agent,
@@ -103,8 +112,8 @@ class AgentSession:
ChangeAgentStateAction(AgentState.INIT), EventSource.USER
)
if self.controller:
self.controller.agent_task = self.controller.start_step_loop()
await self.controller.agent_task # type: ignore
self.controller.agent_task = self.controller.start_step_loop()
await self.controller.agent_task # type: ignore
def _run(self):
asyncio.set_event_loop(self.loop)

View File

@@ -187,6 +187,10 @@ class Session:
"""Sends a message to the client."""
return await self.send({'message': message})
async def send_status_message(self, message: str) -> bool:
"""Sends a status message to the client."""
return await self.send({'status': message})
def update_connection(self, ws: WebSocket):
self.websocket = ws
self.is_alive = True
@@ -202,4 +206,4 @@ class Session:
def queue_status_message(self, message: str):
"""Queues a status message to be sent asynchronously."""
# Ensure the coroutine runs in the main event loop
asyncio.run_coroutine_threadsafe(self.send_message(message), self.loop)
asyncio.run_coroutine_threadsafe(self.send_status_message(message), self.loop)