更新agent

This commit is contained in:
yuruo 2025-03-11 16:06:20 +08:00
parent c89517a3b9
commit 3ca1bd6cca
4 changed files with 26 additions and 19 deletions

View File

@@ -1,6 +1,5 @@
class BaseAgent:
def __init__(self, config, *args, **kwargs):
self.config = config
def __init__(self, *args, **kwargs):
self.SYSTEM_PROMPT = ""

View File

@@ -6,7 +6,7 @@ class TaskPlanAgent(BaseAgent):
self.SYSTEM_PROMPT = system_prompt
def run(self, user_task: str):
def __call__(self, user_task: str):
return self.chat([{"role": "user", "content": user_task}])
system_prompt = """

View File

@@ -1,13 +1,12 @@
import json
import uuid
from anthropic.types.beta import BetaMessage, BetaUsage
from pydantic import BaseModel, Field
from gradio_ui.agent.base_agent import BaseAgent
from xbrain.core.chat import run
import platform
import re
class TaskRunAgent(BaseAgent):
def __init__(self, config, task_plan: str, screen_info):
super().__init__(config)
def __init__(self,task_plan: str, screen_info):
self.OUTPUT_DIR = "./tmp/outputs"
device = self.get_device()
self.SYSTEM_PROMPT = system_prompt.format(task_plan=task_plan,
@@ -27,9 +26,11 @@ class TaskRunAgent(BaseAgent):
device = system
return device
def chat(self, task):
def __call__(self, task):
res = run([{"role": "user", "content": task}], user_prompt=self.SYSTEM_PROMPT, response_format=TaskRunAgentResponse)
return res
response_message = BetaMessage(id=f'toolu_{uuid.uuid4()}', content=res, model='', role='assistant', type='message', stop_reason='tool_use', usage=BetaUsage(input_tokens=0, output_tokens=0))
vlm_response_json = self.extract_data(res, "json")
return response_message, vlm_response_json
def extract_data(self, input_string, data_type):
# Regular expression to extract content starting from '```python' until the end if there are no closing backticks

View File

@@ -13,6 +13,8 @@ from anthropic.types.beta import (
BetaMessage,
BetaMessageParam
)
from gradio_ui.agent.task_plan_agent import TaskPlanAgent
from gradio_ui.agent.task_run_agent import TaskRunAgent
from gradio_ui.tools import ToolResult
from gradio_ui.agent.llm_utils.omniparserclient import OmniParserClient
@@ -37,15 +39,17 @@ def sampling_loop_sync(
"""
print('in sampling_loop_sync, model:', model)
omniparser_client = OmniParserClient(url=f"http://{omniparser_url}/parse/")
actor = VLMAgent(
model=model,
api_key=api_key,
base_url=base_url,
api_response_callback=api_response_callback,
output_callback=output_callback,
max_tokens=max_tokens,
only_n_most_recent_images=only_n_most_recent_images
)
task_plan_agent = TaskPlanAgent()
task_plan_agent()
# actor = VLMAgent(
# model=model,
# api_key=api_key,
# base_url=base_url,
# api_response_callback=api_response_callback,
# output_callback=output_callback,
# max_tokens=max_tokens,
# only_n_most_recent_images=only_n_most_recent_images
# )
executor = AnthropicExecutor(
output_callback=output_callback,
tool_output_callback=tool_output_callback,
@@ -54,10 +58,13 @@ def sampling_loop_sync(
tool_result_content = None
print(f"Start the message loop. User messages: {messages}")
plan = task_plan_agent(messages[-1]["content"][0])
task_run_agent = TaskRunAgent()
while True:
parsed_screen = omniparser_client()
tools_use_needed, vlm_response_json = actor(messages=messages, parsed_screen=parsed_screen)
# tools_use_needed, vlm_response_json = actor(messages=messages, parsed_screen=parsed_screen)
task_run_agent(plan, parsed_screen)
for message, tool_result_content in executor(tools_use_needed, messages):
yield message
if not tool_result_content: