更新agent逻辑

This commit is contained in:
yuruo
2025-03-14 10:54:55 +08:00
parent 0f09774bef
commit e9fa89e4a0
9 changed files with 138 additions and 143 deletions

View File

@@ -4,8 +4,19 @@ from gradio_ui.agent.base_agent import BaseAgent
from xbrain.core.chat import run
class TaskPlanAgent(BaseAgent):
def __call__(self, messages):
response = run(messages, user_prompt=system_prompt, response_format=TaskPlanResponse)
def __call__(self, messages, parsed_screen_result):
screen_info = str(parsed_screen_result['parsed_content_list'])
messages[-1] = {"role": "user",
"content": [
{"type": "text", "text": messages[-1]["content"]},
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{parsed_screen_result['base64_image']}"}
}
]
}
response = run(messages, user_prompt=system_prompt.format(screen_info=screen_info), response_format=TaskPlanResponse)
messages.append({"role": "assistant", "content": response})
return json.loads(response)
class Plan(BaseModel):
@@ -16,6 +27,7 @@ class Plan(BaseModel):
class TaskPlanResponse(BaseModel):
reasoning: str = Field(description="描述您规划任务的逻辑")
task_plan: list[Plan] = Field(description="具体的操作步骤序列")
@@ -23,6 +35,9 @@ system_prompt = """
### 目标 ###
你是自动化操作规划专家,根据屏幕内容和用户需求,规划精确可执行的操作序列。
当前屏幕内容如下:
{screen_info}
### 输入 ###
1. 用户需求:文本描述形式的任务目标
2. 当前环境:屏幕上可见的元素和状态
@@ -30,13 +45,13 @@ system_prompt = """
### 输出格式 ###
操作序列应采用以下JSON格式
[
{
{{
"操作类型": "点击/输入/拖拽/等待/判断...",
"目标元素": "元素描述或坐标",
"参数": "具体参数,如文本内容",
"预期结果": "操作后的预期状态",
"错误处理": "操作失败时的替代方案"
},
}}
]
### 操作类型说明 ###
@@ -50,41 +65,41 @@ system_prompt = """
输入获取AI新闻
输出:
[
{
{{
"操作类型": "点击",
"目标元素": "浏览器图标",
"参数": "",
"预期结果": "浏览器打开",
"错误处理": "如未找到浏览器图标,尝试通过开始菜单搜索浏览器"
},
{
}},
{{
"操作类型": "输入",
"目标元素": "地址栏",
"参数": "https://www.baidu.com",
"预期结果": "百度首页加载完成",
"错误处理": "如连接失败,重试或尝试其他搜索引擎"
},
{
}},
{{
"操作类型": "输入",
"目标元素": "搜索框",
"参数": "AI最新新闻",
"预期结果": "搜索框填充完成",
"错误处理": "如搜索框不可用,尝试刷新页面"
},
{
}},
{{
"操作类型": "点击",
"目标元素": "搜索按钮",
"参数": "",
"预期结果": "显示搜索结果页",
"错误处理": "如点击无反应,尝试按回车键"
},
{
}},
{{
"操作类型": "判断",
"目标元素": "搜索结果列表",
"参数": "包含AI相关内容",
"预期结果": "找到相关新闻",
"错误处理": "如无相关结果,尝试修改搜索关键词"
}
}}
]
"""

View File

@@ -1,9 +1,6 @@
import json
import uuid
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam, BetaUsage
from PIL import ImageDraw
import base64
from io import BytesIO
from pydantic import BaseModel, Field
from gradio_ui.agent.base_agent import BaseAgent
from xbrain.core.chat import run
@@ -13,22 +10,18 @@ class TaskRunAgent(BaseAgent):
def __init__(self):
self.OUTPUT_DIR = "./tmp/outputs"
def __call__(self, task_plan, parsed_screen, messages):
screen_info = str(parsed_screen['parsed_content_list'])
def __call__(self, task_plan, parsed_screen_result, messages):
screen_info = str(parsed_screen_result['parsed_content_list'])
self.SYSTEM_PROMPT = system_prompt.format(task_plan=str(task_plan),
device=self.get_device(),
screen_info=screen_info)
img_to_show = parsed_screen["image"]
buffered = BytesIO()
img_to_show.save(buffered, format="PNG")
img_to_show_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
messages.append(
{"role": "user",
"content": [
{"type": "text", "text": "图片是当前屏幕的截图"},
{"type": "text", "text": "Image is the screenshot of the current screen"},
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{img_to_show_base64}"}
"image_url": {"url": f"data:image/png;base64,{parsed_screen_result['base64_image']}"}
}
]
}
@@ -40,21 +33,6 @@ class TaskRunAgent(BaseAgent):
)
messages.append({"role": "assistant", "content": vlm_response})
vlm_response_json = json.loads(vlm_response)
if "box_id" in vlm_response_json:
try:
bbox = parsed_screen["parsed_content_list"][int(vlm_response_json["box_id"])].coordinates
vlm_response_json["box_centroid_coordinate"] = [int((bbox[0] + bbox[2]) / 2 ), int((bbox[1] + bbox[3]) / 2 )]
x, y = vlm_response_json["box_centroid_coordinate"]
radius = 10
draw = ImageDraw.Draw(img_to_show)
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill='red')
draw.ellipse((x - radius*3, y - radius*3, x + radius*3, y + radius*3), fill=None, outline='red', width=2)
buffered = BytesIO()
img_to_show.save(buffered, format="PNG")
img_to_show_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
except Exception as e:
print(f"Error parsing: {vlm_response_json}")
print(f"Error: {e}")
response_content = [BetaTextBlock(text=vlm_response_json["reasoning"], type='text')]
if 'box_centroid_coordinate' in vlm_response_json:
move_cursor_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',

View File

@@ -15,10 +15,10 @@ class VerificationAgent(BaseAgent):
messages.append({"role": "assistant", "content": response})
return json.loads(response)
class VerificationResponse(BaseModel):
class VerificationResponse(BaseModel):
verification_status: str = Field(description="验证状态", json_schema_extra={"enum": ["success", "error"]})
verification_method: str = Field(description="验证方法")
evidence: str = Field(description="证据")
reasoning: str = Field(description="描述您验证的逻辑")
failure_reason: str = Field(description="失败原因")
remedy_measures: list[str] = Field(description="补救措施")
@@ -72,7 +72,7 @@ prompt = """
{
"verification_status": "success",
"verification_method": "视觉验证+内容验证",
"evidence": "1. 检测到欢迎消息'你好,用户名' 2. 导航栏显示用户头像 3. URL已变更为首页地址",
"reasoning": "1. 检测到欢迎消息'你好,用户名' 2. 导航栏显示用户头像 3. URL已变更为首页地址",
"failure_reason": "",
"remedy_measures": [],
}

View File

@@ -2,6 +2,7 @@
python app.py --windows_host_url localhost:8006 --omniparser_server_url localhost:8000
"""
import json
import os
from pathlib import Path
import argparse
@@ -48,6 +49,8 @@ def setup_state(state):
if "messages" not in state:
state["messages"] = []
if "chatbox_messages" not in state:
state["chatbox_messages"] = []
if "auth_validated" not in state:
state["auth_validated"] = False
if "responses" not in state:
@@ -99,13 +102,49 @@ def process_input(user_input, state, vision_agent_state):
"content": user_input
}
)
yield state['messages']
state["chatbox_messages"].append(
{
"role": "user",
"content": user_input
}
)
yield state['chatbox_messages']
agent = vision_agent_state["agent"]
for _ in sampling_loop_sync(
model=state["model"],
messages=state["messages"],
vision_agent = agent
): yield state['messages']
):
state['chatbox_messages'] = []
for message in state['messages']:
# convert message["content"] to gradio chatbox format
if type(message["content"]) is list:
gradio_chatbox_content = ""
for content in message["content"]:
# convert image_url to gradio image format
if content["type"] == "image_url":
gradio_chatbox_content += f'<br/><img src="{content['image_url']["url"]}">'
# convert text to gradio text format
elif content["type"] == "text":
# agent response is JSON format and must contain reasoning
if message["role"] == "assistant":
content_json = json.loads(content["text"])
gradio_chatbox_content += f'<br/><h3>{content_json["reasoning"]}</h3>'
gradio_chatbox_content += f'<br/> <details> <summary>Detail</summary> <pre>{json.dumps(content_json, indent=4)}</pre> </details>'
else:
gradio_chatbox_content += content["text"]
state['chatbox_messages'].append({
"role": message["role"],
"content": gradio_chatbox_content
})
else:
state['chatbox_messages'].append({
"role": message["role"],
"content": message["content"]
})
yield state['chatbox_messages']
def stop_app(state):
state["stop"] = True
@@ -202,7 +241,7 @@ def run():
autoscroll=True,
height=580,
type="messages")
def update_model(model, state):
state["model"] = model
@@ -215,9 +254,10 @@ def run():
def clear_chat(state):
# Reset message-related state
state["messages"] = []
state["chatbox_messages"] = []
state["responses"] = {}
state["tools"] = {}
return state['messages']
return state["chatbox_messages"]
model.change(fn=update_model, inputs=[model, state], outputs=None)
api_key.change(fn=update_api_key, inputs=[api_key, state], outputs=None)

View File

@@ -2,7 +2,10 @@ import asyncio
import json
from typing import Any, cast
from anthropic.types.beta import (
BetaMessageParam
BetaMessageParam,
BetaContentBlockParam,
BetaToolResultBlockParam,
BetaContentBlock
)
from gradio_ui.tools import ComputerTool, ToolCollection
@@ -13,13 +16,21 @@ class AnthropicExecutor:
ComputerTool()
)
def __call__(self,messages: list[BetaMessageParam]):
content = json.loads(messages[-1]["content"])
if content["next_action"] is not None:
# Run the asynchronous tool execution in a synchronous context
result = asyncio.run(self.tool_collection.run(
name=content["next_action"],
tool_input=cast(dict[str, Any], content["value"]),
))
messages.append({"role": "assistant", "content": "tool result:\n"+str(result)})
return messages
def __call__(self, response, messages):
tool_result_content: list[str] = []
for content_block in cast(list[BetaContentBlock], response.content):
# Execute the tool
if content_block.type == "tool_use":
# Run the asynchronous tool execution in a synchronous context
result = asyncio.run(self.tool_collection.run(
name=content_block.name,
tool_input=cast(dict[str, Any], content_block.input),
))
tool_result_content.append(
str(result)
)
messages.append({"role": "assistant", "content": "Run tool result:\n"+str(tool_result_content)})
if not tool_result_content:
return messages
return tool_result_content

View File

@@ -1,18 +1,16 @@
"""
Agentic sampling loop that calls the Anthropic API and a local implementation of Anthropic-defined computer use tools.
"""
from collections.abc import Callable
import base64
from io import BytesIO
from time import sleep
import cv2
from gradio_ui.agent.verification_agent import VerificationAgent
from gradio_ui.agent.vision_agent import VisionAgent
from gradio_ui.tools.screen_capture import get_screenshot
from anthropic.types.beta import (
BetaMessageParam
)
from anthropic.types.beta import (BetaMessageParam)
from gradio_ui.agent.task_plan_agent import TaskPlanAgent
from gradio_ui.agent.task_run_agent import TaskRunAgent
from gradio_ui.tools import ToolResult
from gradio_ui.executor.anthropic_executor import AnthropicExecutor
import numpy as np
from PIL import Image
@@ -31,72 +29,76 @@ def sampling_loop_sync(
print('in sampling_loop_sync, model:', model)
task_plan_agent = TaskPlanAgent()
executor = AnthropicExecutor()
plan_list = task_plan_agent(messages=messages)
yield
task_run_agent = TaskRunAgent()
verification_agent = VerificationAgent()
task_run_agent = TaskRunAgent()
parsed_screen_result = parsed_screen(vision_agent)
plan_list = task_plan_agent(messages=messages, parsed_screen_result=parsed_screen_result)
yield
for plan in plan_list:
execute_task_plan(plan, vision_agent, task_run_agent, executor, messages)
yield
sleep(2)
verification_loop(vision_agent, plan, verification_agent, executor, task_run_agent, messages)
verification_loop(vision_agent, verification_agent, executor, task_run_agent, messages)
yield
def verification_loop(vision_agent, plan, verification_agent, executor, task_run_agent, messages):
def verification_loop(vision_agent, verification_agent, executor, task_run_agent, messages):
"""verification agent will be called in the loop"""
while True:
# 验证结果
verification_result = verification_agent(plan["expected_result"], messages)
# verification result
verification_result = verification_agent( messages)
yield
# 如果验证成功,返回结果
# if verification success, return result
if verification_result["verification_status"] == "success":
return
# 如果验证失败,执行补救措施
# if verification failed, execute remedy measures
elif verification_result["verification_status"] == "error":
execute_task_plan(verification_result["remedy_measures"], vision_agent, task_run_agent, executor, messages)
yield
def execute_task_plan(plan, vision_agent, task_run_agent, executor, messages):
parsed_screen = parse_screen(vision_agent)
task_run_agent(task_plan=plan, parsed_screen=parsed_screen, messages=messages)
executor(messages)
parsed_screen_result = parsed_screen(vision_agent)
tools_use_needed, __ = task_run_agent(task_plan=plan, parsed_screen_result=parsed_screen_result, messages=messages)
executor(tools_use_needed, messages)
def parse_screen(vision_agent: VisionAgent):
def parsed_screen(vision_agent: VisionAgent):
screenshot, screenshot_path = get_screenshot()
response_json = {}
response_json['parsed_content_list'] = vision_agent(str(screenshot_path))
response_json['width'] = screenshot.size[0]
response_json['height'] = screenshot.size[1]
response_json['image'] = draw_elements(screenshot, response_json['parsed_content_list'])
buffered = BytesIO()
response_json['image'].save(buffered, format="PNG")
response_json['base64_image'] = base64.b64encode(buffered.getvalue()).decode("utf-8")
return response_json
def draw_elements(screenshot, parsed_content_list):
"""
将PIL图像转换为OpenCV兼容格式并绘制边界框
Convert PIL image to OpenCV compatible format and draw bounding boxes
Args:
screenshot: PIL Image对象
parsed_content_list: 包含边界框信息的列表
screenshot: PIL Image object
parsed_content_list: list containing bounding box information
Returns:
带有绘制边界框的PIL图像
PIL image with drawn bounding boxes
"""
# 将PIL图像转换为opencv格式
# convert PIL image to opencv format
opencv_image = np.array(screenshot)
opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGB2BGR)
# 绘制边界框
# draw bounding boxes
for idx, element in enumerate(parsed_content_list):
bbox = element.coordinates
x1, y1, x2, y2 = bbox
# 转换坐标为整数
# convert coordinates to integers
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
# 绘制矩形
# draw rectangle
cv2.rectangle(opencv_image, (x1, y1), (x2, y2), (0, 0, 255), 2)
# 在矩形边框左上角绘制序号
# draw index number
cv2.putText(opencv_image, str(idx+1), (x1, y1-10),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
# 将OpenCV图像格式转换回PIL格式
# convert opencv image format back to PIL format
opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(opencv_image)

View File

@@ -2,12 +2,9 @@ import base64
import time
from typing import Literal, TypedDict
from PIL import Image
from util import tool
from anthropic.types.beta import BetaToolComputerUse20241022Param
from .base import BaseAnthropicTool, ToolError, ToolResult
from .screen_capture import get_screenshot
import requests
import re
import pyautogui
import pyperclip
import platform
@@ -29,7 +26,9 @@ Action = Literal[
"screenshot",
"cursor_position",
"hover",
"wait"
"wait",
"scroll_up",
"scroll_down"
]
class Resolution(TypedDict):
@@ -65,7 +64,6 @@ class ComputerTool(BaseAnthropicTool):
@property
def options(self) -> ComputerToolOptions:
# 直接使用原始尺寸,不进行缩放
return {
"display_width_px": self.width,
"display_height_px": self.height,
@@ -76,14 +74,12 @@ class ComputerTool(BaseAnthropicTool):
return {"name": self.name, "type": self.api_type, **self.options}
def __init__(self, is_scaling: bool = False):
def __init__(self):
super().__init__()
# Get screen width and height using Windows command
self.display_num = None
self.offset_x = 0
self.offset_y = 0
self.width, self.height = self.get_screen_size()
print(f"screen size: {self.width}, {self.height}")
self.width, self.height = pyautogui.size()
self.key_conversion = {"Page_Down": "pagedown",
"Page_Up": "pageup",
"Super_L": "win",
@@ -216,19 +212,3 @@ class ComputerTool(BaseAnthropicTool):
# padding to top left
padding_image.paste(screenshot, (0, 0))
return padding_image
def get_screen_size(self):
"""Return width and height of the screen"""
try:
response = tool.execute_command(
["python", "-c", "import pyautogui; print(pyautogui.size())"]
)
output = response['output'].strip()
match = re.search(r'Size\(width=(\d+),\s*height=(\d+)\)', output)
if not match:
raise ToolError(f"Could not parse screen size from output: {output}")
width, height = map(int, match.groups())
return width, height
except requests.exceptions.RequestException as e:
raise ToolError(f"An error occurred while trying to get screen size: {str(e)}")

View File

@@ -13,10 +13,8 @@ def get_screenshot(resize: bool = False, target_width: int = 1920, target_height
path = output_dir / f"screenshot_{uuid4().hex}.png"
try:
# 使用 tool.capture_screen_with_cursor 替代 requests.get
img_io = tool.capture_screen_with_cursor()
screenshot = Image.open(img_io)
screenshot = Image.open(img_io)
if resize and screenshot.size != (target_width, target_height):
screenshot = screenshot.resize((target_width, target_height))
screenshot.save(path)

View File

@@ -1,38 +1,9 @@
import os
import shlex
import subprocess
import threading
import pyautogui
from PIL import Image
from io import BytesIO
computer_control_lock = threading.Lock()
def execute_command(command, shell=False):
"""Local function to execute a command."""
with computer_control_lock:
if isinstance(command, str) and not shell:
command = shlex.split(command)
# Expand user directory
for i, arg in enumerate(command):
if arg.startswith("~/"):
command[i] = os.path.expanduser(arg)
try:
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell, text=True, timeout=120)
return {
'status': 'success',
'output': result.stdout,
'error': result.stderr,
'returncode': result.returncode
}
except Exception as e:
return {
'status': 'error',
'message': str(e)
}
def capture_screen_with_cursor():
"""Local function to capture the screen with cursor."""
cursor_path = os.path.join(os.path.dirname(__file__),"..","resources", "cursor.png")