wol 03f2c49b7e
Some checks failed
Remove old artifacts / remove-old-artifacts (push) Has been cancelled
init
2025-08-04 00:20:29 +08:00

361 lines
10 KiB
Python

from typing_extensions import Any, Literal, TypedDict
from pydantic import BaseModel
from app.exception.exception import ProgramException
from app.model.chat import McpServers, Status, SupplementChat, Chat, UpdateData
import asyncio
from enum import Enum
from camel.tasks import Task
from contextlib import contextmanager
from contextvars import ContextVar
from datetime import datetime, timedelta
import weakref
from loguru import logger
class Action(str, Enum):
improve = "improve" # user -> backend
update_task = "update_task" # user -> backend
task_state = "task_state" # backend -> user
start = "start" # user -> backend
create_agent = "create_agent" # backend -> user
activate_agent = "activate_agent" # backend -> user
deactivate_agent = "deactivate_agent" # backend -> user
assign_task = "assign_task" # backend -> user
activate_toolkit = "activate_toolkit" # backend -> user
deactivate_toolkit = "deactivate_toolkit" # backend -> user
write_file = "write_file" # backend -> user
ask = "ask" # backend -> user
notice = "notice" # backend -> user
search_mcp = "search_mcp" # backend -> user
install_mcp = "install_mcp" # backend -> user
terminal = "terminal" # backend -> user
end = "end" # backend -> user
stop = "stop" # user -> backend
supplement = "supplement" # user -> backend
pause = "pause" # user -> backend user take control
resume = "resume" # user -> backend user take control
new_agent = "new_agent" # user -> backend
budget_not_enough = "budget_not_enough" # backend -> user
class ActionImproveData(BaseModel):
action: Literal[Action.improve] = Action.improve
data: str
class ActionStartData(BaseModel):
action: Literal[Action.start] = Action.start
class ActionUpdateTaskData(BaseModel):
action: Literal[Action.update_task] = Action.update_task
data: UpdateData
class ActionTaskStateData(BaseModel):
action: Literal[Action.task_state] = Action.task_state
data: dict[Literal["task_id", "content", "state", "result", "failure_count"], str | int]
class ActionAskData(BaseModel):
action: Literal[Action.ask] = Action.ask
data: dict[Literal["question", "agent"], str]
class AgentDataDict(TypedDict):
agent_name: str
agent_id: str
tools: list[str]
class ActionCreateAgentData(BaseModel):
action: Literal[Action.create_agent] = Action.create_agent
data: AgentDataDict
class ActionActivateAgentData(BaseModel):
action: Literal[Action.activate_agent] = Action.activate_agent
data: dict[Literal["agent_name", "process_task_id", "agent_id", "message"], str]
class DataDict(TypedDict):
agent_name: str
agent_id: str
process_task_id: str
message: str
tokens: int
class ActionDeactivateAgentData(BaseModel):
action: Literal[Action.deactivate_agent] = Action.deactivate_agent
data: DataDict
class ActionAssignTaskData(BaseModel):
action: Literal[Action.assign_task] = Action.assign_task
data: dict[Literal["assignee_id", "task_id", "content", "state"], str]
class ActionActivateToolkitData(BaseModel):
action: Literal[Action.activate_toolkit] = Action.activate_toolkit
data: dict[
Literal["agent_name", "toolkit_name", "process_task_id", "method_name", "message"],
str,
]
class ActionDeactivateToolkitData(BaseModel):
action: Literal[Action.deactivate_toolkit] = Action.deactivate_toolkit
data: dict[
Literal["agent_name", "toolkit_name", "process_task_id", "method_name", "message"],
str,
]
class ActionWriteFileData(BaseModel):
action: Literal[Action.write_file] = Action.write_file
process_task_id: str
data: str
class ActionNoticeData(BaseModel):
action: Literal[Action.notice] = Action.notice
process_task_id: str
data: str
class ActionSearchMcpData(BaseModel):
action: Literal[Action.search_mcp] = Action.search_mcp
data: Any
class ActionInstallMcpData(BaseModel):
action: Literal[Action.install_mcp] = Action.install_mcp
data: McpServers
class ActionTerminalData(BaseModel):
action: Literal[Action.terminal] = Action.terminal
process_task_id: str
data: str
class ActionStopData(BaseModel):
action: Literal[Action.stop] = Action.stop
class ActionEndData(BaseModel):
action: Literal[Action.end] = Action.end
class ActionSupplementData(BaseModel):
action: Literal[Action.supplement] = Action.supplement
data: SupplementChat
class ActionTakeControl(BaseModel):
action: Literal[Action.pause, Action.resume]
class ActionNewAgent(BaseModel):
action: Literal[Action.new_agent] = Action.new_agent
name: str
description: str
tools: list[str]
mcp_tools: McpServers | None
class ActionBudgetNotEnough(BaseModel):
action: Literal[Action.budget_not_enough] = Action.budget_not_enough
ActionData = (
ActionImproveData
| ActionStartData
| ActionUpdateTaskData
| ActionTaskStateData
| ActionAskData
| ActionCreateAgentData
| ActionActivateAgentData
| ActionDeactivateAgentData
| ActionAssignTaskData
| ActionActivateToolkitData
| ActionDeactivateToolkitData
| ActionWriteFileData
| ActionNoticeData
| ActionSearchMcpData
| ActionInstallMcpData
| ActionTerminalData
| ActionStopData
| ActionEndData
| ActionSupplementData
| ActionTakeControl
| ActionNewAgent
| ActionBudgetNotEnough
)
class Agents(str, Enum):
task_agent = "task_agent"
coordinator_agent = "coordinator_agent"
new_worker_agent = "new_worker_agent"
developer_agent = "developer_agent"
search_agent = "search_agent"
document_agent = "document_agent"
multi_modal_agent = "multi_modal_agent"
social_medium_agent = "social_medium_agent"
mcp_agent = "mcp_agent"
class TaskLock:
id: str
status: Status = Status.confirming
active_agent: str = ""
mcp: list[str]
queue: asyncio.Queue[ActionData]
"""Queue monitoring for SSE response"""
human_input: dict[str, asyncio.Queue[str]]
"""After receiving user's reply, put the reply into the corresponding agent's queue"""
created_at: datetime
last_accessed: datetime
background_tasks: set[asyncio.Task]
"""Track all background tasks for cleanup"""
def __init__(self, id: str, queue: asyncio.Queue, human_input: dict) -> None:
self.id = id
self.queue = queue
self.human_input = human_input
self.created_at = datetime.now()
self.last_accessed = datetime.now()
self.background_tasks = set()
async def put_queue(self, data: ActionData):
self.last_accessed = datetime.now()
await self.queue.put(data)
async def get_queue(self):
self.last_accessed = datetime.now()
return await self.queue.get()
async def put_human_input(self, agent: str, data: Any = None):
await self.human_input[agent].put(data)
async def get_human_input(self, agent: str):
return await self.human_input[agent].get()
def add_human_input_listen(self, agent: str):
self.human_input[agent] = asyncio.Queue(1)
def add_background_task(self, task: asyncio.Task) -> None:
r"""Add a task to track and clean up weak references"""
self.background_tasks.add(task)
task.add_done_callback(lambda t: self.background_tasks.discard(t))
async def cleanup(self):
r"""Cancel all background tasks and clean up resources"""
for task in list(self.background_tasks):
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
self.background_tasks.clear()
task_locks = dict[str, TaskLock]()
# Cleanup task for removing stale task locks
_cleanup_task: asyncio.Task | None = None
task_index: dict[str, weakref.ref[Task]] = {}
def get_task_lock(id: str) -> TaskLock:
if id not in task_locks:
raise ProgramException("Task not found")
return task_locks[id]
def create_task_lock(id: str) -> TaskLock:
if id in task_locks:
raise ProgramException("Task already exists")
task_locks[id] = TaskLock(id=id, queue=asyncio.Queue(), human_input={})
# Start cleanup task if not running
global _cleanup_task
if _cleanup_task is None or _cleanup_task.done():
_cleanup_task = asyncio.create_task(_periodic_cleanup())
return task_locks[id]
async def delete_task_lock(id: str):
if id not in task_locks:
raise ProgramException("Task not found")
# Clean up background tasks before deletion
task_lock = task_locks[id]
await task_lock.cleanup()
del task_locks[id]
logger.debug(f"Deleted task lock {id}, remaining locks: {len(task_locks)}")
def get_camel_task(id: str, tasks: list[Task]) -> None | Task:
if id in task_index:
task_ref = task_index[id]
task = task_ref()
if task is not None:
return task
else:
# Weak reference died, remove from index
del task_index[id]
# Fallback to search and rebuild index
for item in tasks:
# Add to index
task_index[item.id] = weakref.ref(item)
if item.id == id:
return item
else:
task = get_camel_task(id, item.subtasks)
if task is not None:
return task
return None
async def _periodic_cleanup():
r"""Periodically clean up stale task locks"""
while True:
try:
await asyncio.sleep(300) # Run every 5 minutes
current_time = datetime.now()
stale_timeout = timedelta(hours=2) # Consider tasks stale after 2 hours
stale_ids = []
for task_id, task_lock in task_locks.items():
if current_time - task_lock.last_accessed > stale_timeout:
stale_ids.append(task_id)
for task_id in stale_ids:
logger.warning(f"Cleaning up stale task lock: {task_id}")
await delete_task_lock(task_id)
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error in periodic cleanup: {e}")
process_task = ContextVar[str]("id")
@contextmanager
def set_process_task(process_task_id: str):
origin = process_task.set(process_task_id)
try:
yield
finally:
process_task.reset(origin)