diff --git a/gradio_ui/agent/task_plan_agent.py b/gradio_ui/agent/task_plan_agent.py
index 0490576..46d3d2a 100644
--- a/gradio_ui/agent/task_plan_agent.py
+++ b/gradio_ui/agent/task_plan_agent.py
@@ -4,8 +4,19 @@ from gradio_ui.agent.base_agent import BaseAgent
from xbrain.core.chat import run
class TaskPlanAgent(BaseAgent):
- def __call__(self, messages):
- response = run(messages, user_prompt=system_prompt, response_format=TaskPlanResponse)
+ def __call__(self, messages, parsed_screen_result):
+ screen_info = str(parsed_screen_result['parsed_content_list'])
+ messages[-1] = {"role": "user",
+ "content": [
+ {"type": "text", "text": messages[-1]["content"]},
+ {
+ "type": "image_url",
+ "image_url": {"url": f"data:image/png;base64,{parsed_screen_result['base64_image']}"}
+ }
+ ]
+ }
+ response = run(messages, user_prompt=system_prompt.format(screen_info=screen_info), response_format=TaskPlanResponse)
+ messages.append({"role": "assistant", "content": response})
return json.loads(response)
class Plan(BaseModel):
@@ -16,6 +27,7 @@ class Plan(BaseModel):
class TaskPlanResponse(BaseModel):
+ reasoning: str = Field(description="描述您规划任务的逻辑")
task_plan: list[Plan] = Field(description="具体的操作步骤序列")
@@ -23,6 +35,9 @@ system_prompt = """
### 目标 ###
你是自动化操作规划专家,根据屏幕内容和用户需求,规划精确可执行的操作序列。
+当前屏幕内容如下:
+{screen_info}
+
### 输入 ###
1. 用户需求:文本描述形式的任务目标
2. 当前环境:屏幕上可见的元素和状态
@@ -30,13 +45,13 @@ system_prompt = """
### 输出格式 ###
操作序列应采用以下JSON格式:
[
- {
+ {{
"操作类型": "点击/输入/拖拽/等待/判断...",
"目标元素": "元素描述或坐标",
"参数": "具体参数,如文本内容",
"预期结果": "操作后的预期状态",
"错误处理": "操作失败时的替代方案"
- },
+ }}
]
### 操作类型说明 ###
@@ -50,41 +65,41 @@ system_prompt = """
输入:获取AI新闻
输出:
[
- {
+ {{
"操作类型": "点击",
"目标元素": "浏览器图标",
"参数": "无",
"预期结果": "浏览器打开",
"错误处理": "如未找到浏览器图标,尝试通过开始菜单搜索浏览器"
- },
- {
+ }},
+ {{
"操作类型": "输入",
"目标元素": "地址栏",
"参数": "https://www.baidu.com",
"预期结果": "百度首页加载完成",
"错误处理": "如连接失败,重试或尝试其他搜索引擎"
- },
- {
+ }},
+ {{
"操作类型": "输入",
"目标元素": "搜索框",
"参数": "AI最新新闻",
"预期结果": "搜索框填充完成",
"错误处理": "如搜索框不可用,尝试刷新页面"
- },
- {
+ }},
+ {{
"操作类型": "点击",
"目标元素": "搜索按钮",
"参数": "无",
"预期结果": "显示搜索结果页",
"错误处理": "如点击无反应,尝试按回车键"
- },
- {
+ }},
+ {{
"操作类型": "判断",
"目标元素": "搜索结果列表",
"参数": "包含AI相关内容",
"预期结果": "找到相关新闻",
"错误处理": "如无相关结果,尝试修改搜索关键词"
- }
+ }}
]
"""
diff --git a/gradio_ui/agent/task_run_agent.py b/gradio_ui/agent/task_run_agent.py
index 564901d..387ae91 100644
--- a/gradio_ui/agent/task_run_agent.py
+++ b/gradio_ui/agent/task_run_agent.py
@@ -1,9 +1,6 @@
import json
import uuid
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam, BetaUsage
-from PIL import ImageDraw
-import base64
-from io import BytesIO
from pydantic import BaseModel, Field
from gradio_ui.agent.base_agent import BaseAgent
from xbrain.core.chat import run
@@ -13,22 +10,18 @@ class TaskRunAgent(BaseAgent):
def __init__(self):
self.OUTPUT_DIR = "./tmp/outputs"
- def __call__(self, task_plan, parsed_screen, messages):
- screen_info = str(parsed_screen['parsed_content_list'])
+ def __call__(self, task_plan, parsed_screen_result, messages):
+ screen_info = str(parsed_screen_result['parsed_content_list'])
self.SYSTEM_PROMPT = system_prompt.format(task_plan=str(task_plan),
device=self.get_device(),
screen_info=screen_info)
- img_to_show = parsed_screen["image"]
- buffered = BytesIO()
- img_to_show.save(buffered, format="PNG")
- img_to_show_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
messages.append(
{"role": "user",
"content": [
- {"type": "text", "text": "图片是当前屏幕的截图"},
+ {"type": "text", "text": "Image is the screenshot of the current screen"},
{
"type": "image_url",
- "image_url": {"url": f"data:image/png;base64,{img_to_show_base64}"}
+ "image_url": {"url": f"data:image/png;base64,{parsed_screen_result['base64_image']}"}
}
]
}
@@ -40,21 +33,6 @@ class TaskRunAgent(BaseAgent):
)
messages.append({"role": "assistant", "content": vlm_response})
vlm_response_json = json.loads(vlm_response)
- if "box_id" in vlm_response_json:
- try:
- bbox = parsed_screen["parsed_content_list"][int(vlm_response_json["box_id"])].coordinates
- vlm_response_json["box_centroid_coordinate"] = [int((bbox[0] + bbox[2]) / 2 ), int((bbox[1] + bbox[3]) / 2 )]
- x, y = vlm_response_json["box_centroid_coordinate"]
- radius = 10
- draw = ImageDraw.Draw(img_to_show)
- draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill='red')
- draw.ellipse((x - radius*3, y - radius*3, x + radius*3, y + radius*3), fill=None, outline='red', width=2)
- buffered = BytesIO()
- img_to_show.save(buffered, format="PNG")
- img_to_show_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
- except Exception as e:
- print(f"Error parsing: {vlm_response_json}")
- print(f"Error: {e}")
response_content = [BetaTextBlock(text=vlm_response_json["reasoning"], type='text')]
if 'box_centroid_coordinate' in vlm_response_json:
move_cursor_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
diff --git a/gradio_ui/agent/verification_agent.py b/gradio_ui/agent/verification_agent.py
index bce9009..05922cf 100644
--- a/gradio_ui/agent/verification_agent.py
+++ b/gradio_ui/agent/verification_agent.py
@@ -15,10 +15,10 @@ class VerificationAgent(BaseAgent):
messages.append({"role": "assistant", "content": response})
return json.loads(response)
-class VerificationResponse(BaseModel):
+class VerificationResponse(BaseModel):
verification_status: str = Field(description="验证状态", json_schema_extra={"enum": ["success", "error"]})
verification_method: str = Field(description="验证方法")
- evidence: str = Field(description="证据")
+ reasoning: str = Field(description="描述您验证的逻辑")
failure_reason: str = Field(description="失败原因")
remedy_measures: list[str] = Field(description="补救措施")
@@ -72,7 +72,7 @@ prompt = """
{
"verification_status": "success",
"verification_method": "视觉验证+内容验证",
- "evidence": "1. 检测到欢迎消息'你好,用户名' 2. 导航栏显示用户头像 3. URL已变更为首页地址",
+ "reasoning": "1. 检测到欢迎消息'你好,用户名' 2. 导航栏显示用户头像 3. URL已变更为首页地址",
"failure_reason": "无",
"remedy_measures": [],
}
diff --git a/gradio_ui/app.py b/gradio_ui/app.py
index c8b4a67..399cc21 100644
--- a/gradio_ui/app.py
+++ b/gradio_ui/app.py
@@ -2,6 +2,7 @@
python app.py --windows_host_url localhost:8006 --omniparser_server_url localhost:8000
"""
+import json
import os
from pathlib import Path
import argparse
@@ -48,6 +49,8 @@ def setup_state(state):
if "messages" not in state:
state["messages"] = []
+ if "chatbox_messages" not in state:
+ state["chatbox_messages"] = []
if "auth_validated" not in state:
state["auth_validated"] = False
if "responses" not in state:
@@ -99,13 +102,49 @@ def process_input(user_input, state, vision_agent_state):
"content": user_input
}
)
- yield state['messages']
+ state["chatbox_messages"].append(
+ {
+ "role": "user",
+ "content": user_input
+ }
+ )
+ yield state['chatbox_messages']
agent = vision_agent_state["agent"]
for _ in sampling_loop_sync(
model=state["model"],
messages=state["messages"],
vision_agent = agent
- ): yield state['messages']
+ ):
+ state['chatbox_messages'] = []
+ for message in state['messages']:
+ # convert message["content"] to gradio chatbox format
+ if type(message["content"]) is list:
+ gradio_chatbox_content = ""
+ for content in message["content"]:
+ # convert image_url to gradio image format
+ if content["type"] == "image_url":
+                        gradio_chatbox_content += f'<img src="{content["image_url"]["url"]}">'
+ # convert text to gradio text format
+ elif content["type"] == "text":
+ # agent response is json format and must contains reasoning
+ if message["role"] == "assistant":
+ content_json = json.loads(content["text"])
+                            gradio_chatbox_content += f'<p>{content_json["reasoning"]}</p>'
+                            gradio_chatbox_content += f'<details><summary>Detail</summary><pre>{json.dumps(content_json, indent=4)}</pre></details>'
+ else:
+ gradio_chatbox_content += content["text"]
+
+ state['chatbox_messages'].append({
+ "role": message["role"],
+ "content": gradio_chatbox_content
+ })
+ else:
+ state['chatbox_messages'].append({
+ "role": message["role"],
+ "content": message["content"]
+ })
+ yield state['chatbox_messages']
+
def stop_app(state):
state["stop"] = True
@@ -202,7 +241,7 @@ def run():
autoscroll=True,
height=580,
type="messages")
-
+
def update_model(model, state):
state["model"] = model
@@ -215,9 +254,10 @@ def run():
def clear_chat(state):
# Reset message-related state
state["messages"] = []
+ state["chatbox_messages"] = []
state["responses"] = {}
state["tools"] = {}
- return state['messages']
+ return state["chatbox_messages"]
model.change(fn=update_model, inputs=[model, state], outputs=None)
api_key.change(fn=update_api_key, inputs=[api_key, state], outputs=None)
diff --git a/gradio_ui/executor/anthropic_executor.py b/gradio_ui/executor/anthropic_executor.py
index 3a415dc..5c68de4 100644
--- a/gradio_ui/executor/anthropic_executor.py
+++ b/gradio_ui/executor/anthropic_executor.py
@@ -2,7 +2,10 @@ import asyncio
import json
from typing import Any, cast
from anthropic.types.beta import (
- BetaMessageParam
+ BetaMessageParam,
+ BetaContentBlockParam,
+ BetaToolResultBlockParam,
+ BetaContentBlock
)
from gradio_ui.tools import ComputerTool, ToolCollection
@@ -13,13 +16,21 @@ class AnthropicExecutor:
ComputerTool()
)
- def __call__(self,messages: list[BetaMessageParam]):
- content = json.loads(messages[-1]["content"])
- if content["next_action"] is not None:
- # Run the asynchronous tool execution in a synchronous context
- result = asyncio.run(self.tool_collection.run(
- name=content["next_action"],
- tool_input=cast(dict[str, Any], content["value"]),
- ))
- messages.append({"role": "assistant", "content": "tool result:\n"+str(result)})
- return messages
\ No newline at end of file
+ def __call__(self, response, messages):
+ tool_result_content: list[str] = []
+ for content_block in cast(list[BetaContentBlock], response.content):
+ # Execute the tool
+ if content_block.type == "tool_use":
+ # Run the asynchronous tool execution in a synchronous context
+ result = asyncio.run(self.tool_collection.run(
+ name=content_block.name,
+ tool_input=cast(dict[str, Any], content_block.input),
+ ))
+ tool_result_content.append(
+ str(result)
+ )
+ messages.append({"role": "assistant", "content": "Run tool result:\n"+str(tool_result_content)})
+ if not tool_result_content:
+ return messages
+
+ return tool_result_content
diff --git a/gradio_ui/loop.py b/gradio_ui/loop.py
index 7cfd3b9..eca8e07 100644
--- a/gradio_ui/loop.py
+++ b/gradio_ui/loop.py
@@ -1,18 +1,16 @@
"""
Agentic sampling loop that calls the Anthropic API and local implenmentation of anthropic-defined computer use tools.
"""
-from collections.abc import Callable
+import base64
+from io import BytesIO
from time import sleep
import cv2
from gradio_ui.agent.verification_agent import VerificationAgent
from gradio_ui.agent.vision_agent import VisionAgent
from gradio_ui.tools.screen_capture import get_screenshot
-from anthropic.types.beta import (
- BetaMessageParam
-)
+from anthropic.types.beta import (BetaMessageParam)
from gradio_ui.agent.task_plan_agent import TaskPlanAgent
from gradio_ui.agent.task_run_agent import TaskRunAgent
-from gradio_ui.tools import ToolResult
from gradio_ui.executor.anthropic_executor import AnthropicExecutor
import numpy as np
from PIL import Image
@@ -31,72 +29,76 @@ def sampling_loop_sync(
print('in sampling_loop_sync, model:', model)
task_plan_agent = TaskPlanAgent()
executor = AnthropicExecutor()
- plan_list = task_plan_agent(messages=messages)
- yield
- task_run_agent = TaskRunAgent()
verification_agent = VerificationAgent()
+ task_run_agent = TaskRunAgent()
+ parsed_screen_result = parsed_screen(vision_agent)
+ plan_list = task_plan_agent(messages=messages, parsed_screen_result=parsed_screen_result)
+ yield
for plan in plan_list:
execute_task_plan(plan, vision_agent, task_run_agent, executor, messages)
yield
sleep(2)
- verification_loop(vision_agent, plan, verification_agent, executor, task_run_agent, messages)
+ verification_loop(vision_agent, verification_agent, executor, task_run_agent, messages)
yield
-def verification_loop(vision_agent, plan, verification_agent, executor, task_run_agent, messages):
+def verification_loop(vision_agent, verification_agent, executor, task_run_agent, messages):
"""verification agent will be called in the loop"""
while True:
- # 验证结果
- verification_result = verification_agent(plan["expected_result"], messages)
+ # verification result
+        verification_result = verification_agent(messages)
yield
- # 如果验证成功,返回结果
+ # if verification success, return result
if verification_result["verification_status"] == "success":
return
- # 如果验证失败,执行补救措施
+ # if verification failed, execute remedy measures
elif verification_result["verification_status"] == "error":
execute_task_plan(verification_result["remedy_measures"], vision_agent, task_run_agent, executor, messages)
yield
def execute_task_plan(plan, vision_agent, task_run_agent, executor, messages):
- parsed_screen = parse_screen(vision_agent)
- task_run_agent(task_plan=plan, parsed_screen=parsed_screen, messages=messages)
- executor(messages)
+ parsed_screen_result = parsed_screen(vision_agent)
+ tools_use_needed, __ = task_run_agent(task_plan=plan, parsed_screen_result=parsed_screen_result, messages=messages)
+ executor(tools_use_needed, messages)
-def parse_screen(vision_agent: VisionAgent):
+def parsed_screen(vision_agent: VisionAgent):
screenshot, screenshot_path = get_screenshot()
response_json = {}
response_json['parsed_content_list'] = vision_agent(str(screenshot_path))
response_json['width'] = screenshot.size[0]
response_json['height'] = screenshot.size[1]
response_json['image'] = draw_elements(screenshot, response_json['parsed_content_list'])
+ buffered = BytesIO()
+ response_json['image'].save(buffered, format="PNG")
+ response_json['base64_image'] = base64.b64encode(buffered.getvalue()).decode("utf-8")
return response_json
def draw_elements(screenshot, parsed_content_list):
"""
- 将PIL图像转换为OpenCV兼容格式并绘制边界框
+ Convert PIL image to OpenCV compatible format and draw bounding boxes
Args:
- screenshot: PIL Image对象
- parsed_content_list: 包含边界框信息的列表
+ screenshot: PIL Image object
+ parsed_content_list: list containing bounding box information
Returns:
- 带有绘制边界框的PIL图像
+ PIL image with drawn bounding boxes
"""
- # 将PIL图像转换为opencv格式
+ # convert PIL image to opencv format
opencv_image = np.array(screenshot)
opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGB2BGR)
- # 绘制边界框
+ # draw bounding boxes
for idx, element in enumerate(parsed_content_list):
bbox = element.coordinates
x1, y1, x2, y2 = bbox
- # 转换坐标为整数
+ # convert coordinates to integers
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
- # 绘制矩形
+ # draw rectangle
cv2.rectangle(opencv_image, (x1, y1), (x2, y2), (0, 0, 255), 2)
- # 在矩形边框左上角绘制序号
+ # draw index number
cv2.putText(opencv_image, str(idx+1), (x1, y1-10),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
- # 将OpenCV图像格式转换回PIL格式
+ # convert opencv image format back to PIL format
opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(opencv_image)
diff --git a/gradio_ui/tools/computer.py b/gradio_ui/tools/computer.py
index 88d79cc..f802a87 100644
--- a/gradio_ui/tools/computer.py
+++ b/gradio_ui/tools/computer.py
@@ -2,12 +2,9 @@ import base64
import time
from typing import Literal, TypedDict
from PIL import Image
-from util import tool
from anthropic.types.beta import BetaToolComputerUse20241022Param
from .base import BaseAnthropicTool, ToolError, ToolResult
from .screen_capture import get_screenshot
-import requests
-import re
import pyautogui
import pyperclip
import platform
@@ -29,7 +26,9 @@ Action = Literal[
"screenshot",
"cursor_position",
"hover",
- "wait"
+ "wait",
+ "scroll_up",
+ "scroll_down"
]
class Resolution(TypedDict):
@@ -65,7 +64,6 @@ class ComputerTool(BaseAnthropicTool):
@property
def options(self) -> ComputerToolOptions:
- # 直接使用原始尺寸,不进行缩放
return {
"display_width_px": self.width,
"display_height_px": self.height,
@@ -76,14 +74,12 @@ class ComputerTool(BaseAnthropicTool):
return {"name": self.name, "type": self.api_type, **self.options}
- def __init__(self, is_scaling: bool = False):
+ def __init__(self):
super().__init__()
- # Get screen width and height using Windows command
self.display_num = None
self.offset_x = 0
self.offset_y = 0
- self.width, self.height = self.get_screen_size()
- print(f"screen size: {self.width}, {self.height}")
+ self.width, self.height = pyautogui.size()
self.key_conversion = {"Page_Down": "pagedown",
"Page_Up": "pageup",
"Super_L": "win",
@@ -216,19 +212,3 @@ class ComputerTool(BaseAnthropicTool):
# padding to top left
padding_image.paste(screenshot, (0, 0))
return padding_image
-
-
- def get_screen_size(self):
- """Return width and height of the screen"""
- try:
- response = tool.execute_command(
- ["python", "-c", "import pyautogui; print(pyautogui.size())"]
- )
- output = response['output'].strip()
- match = re.search(r'Size\(width=(\d+),\s*height=(\d+)\)', output)
- if not match:
- raise ToolError(f"Could not parse screen size from output: {output}")
- width, height = map(int, match.groups())
- return width, height
- except requests.exceptions.RequestException as e:
- raise ToolError(f"An error occurred while trying to get screen size: {str(e)}")
\ No newline at end of file
diff --git a/gradio_ui/tools/screen_capture.py b/gradio_ui/tools/screen_capture.py
index 430aee2..2642741 100644
--- a/gradio_ui/tools/screen_capture.py
+++ b/gradio_ui/tools/screen_capture.py
@@ -13,10 +13,8 @@ def get_screenshot(resize: bool = False, target_width: int = 1920, target_height
path = output_dir / f"screenshot_{uuid4().hex}.png"
try:
- # 使用 tool.capture_screen_with_cursor 替代 requests.get
img_io = tool.capture_screen_with_cursor()
- screenshot = Image.open(img_io)
-
+ screenshot = Image.open(img_io)
if resize and screenshot.size != (target_width, target_height):
screenshot = screenshot.resize((target_width, target_height))
screenshot.save(path)
diff --git a/util/tool.py b/util/tool.py
index 309f32f..9c88577 100644
--- a/util/tool.py
+++ b/util/tool.py
@@ -1,38 +1,9 @@
import os
-import shlex
-import subprocess
-import threading
import pyautogui
from PIL import Image
from io import BytesIO
-computer_control_lock = threading.Lock()
-def execute_command(command, shell=False):
- """Local function to execute a command."""
- with computer_control_lock:
- if isinstance(command, str) and not shell:
- command = shlex.split(command)
-
- # Expand user directory
- for i, arg in enumerate(command):
- if arg.startswith("~/"):
- command[i] = os.path.expanduser(arg)
-
- try:
- result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell, text=True, timeout=120)
- return {
- 'status': 'success',
- 'output': result.stdout,
- 'error': result.stderr,
- 'returncode': result.returncode
- }
- except Exception as e:
- return {
- 'status': 'error',
- 'message': str(e)
- }
-
def capture_screen_with_cursor():
"""Local function to capture the screen with cursor."""
cursor_path = os.path.join(os.path.dirname(__file__),"..","resources", "cursor.png")