diff --git a/gradio_ui/agent/task_plan_agent.py b/gradio_ui/agent/task_plan_agent.py index 0490576..46d3d2a 100644 --- a/gradio_ui/agent/task_plan_agent.py +++ b/gradio_ui/agent/task_plan_agent.py @@ -4,8 +4,19 @@ from gradio_ui.agent.base_agent import BaseAgent from xbrain.core.chat import run class TaskPlanAgent(BaseAgent): - def __call__(self, messages): - response = run(messages, user_prompt=system_prompt, response_format=TaskPlanResponse) + def __call__(self, messages, parsed_screen_result): + screen_info = str(parsed_screen_result['parsed_content_list']) + messages[-1] = {"role": "user", + "content": [ + {"type": "text", "text": messages[-1]["content"]}, + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{parsed_screen_result['base64_image']}"} + } + ] + } + response = run(messages, user_prompt=system_prompt.format(screen_info=screen_info), response_format=TaskPlanResponse) + messages.append({"role": "assistant", "content": response}) return json.loads(response) class Plan(BaseModel): @@ -16,6 +27,7 @@ class Plan(BaseModel): class TaskPlanResponse(BaseModel): + reasoning: str = Field(description="描述您规划任务的逻辑") task_plan: list[Plan] = Field(description="具体的操作步骤序列") @@ -23,6 +35,9 @@ system_prompt = """ ### 目标 ### 你是自动化操作规划专家,根据屏幕内容和用户需求,规划精确可执行的操作序列。 +当前屏幕内容如下: +{screen_info} + ### 输入 ### 1. 用户需求:文本描述形式的任务目标 2. 当前环境:屏幕上可见的元素和状态 @@ -30,13 +45,13 @@ system_prompt = """ ### 输出格式 ### 操作序列应采用以下JSON格式: [ - { + {{ "操作类型": "点击/输入/拖拽/等待/判断...", "目标元素": "元素描述或坐标", "参数": "具体参数,如文本内容", "预期结果": "操作后的预期状态", "错误处理": "操作失败时的替代方案" - }, + }} ] ### 操作类型说明 ### @@ -50,41 +65,41 @@ system_prompt = """ 输入:获取AI新闻 输出: [ - { + {{ "操作类型": "点击", "目标元素": "浏览器图标", "参数": "无", "预期结果": "浏览器打开", "错误处理": "如未找到浏览器图标,尝试通过开始菜单搜索浏览器" - }, - { + }}, + {{ "操作类型": "输入", "目标元素": "地址栏", "参数": "https://www.baidu.com", "预期结果": "百度首页加载完成", "错误处理": "如连接失败,重试或尝试其他搜索引擎" - }, - { + }}, + {{ "操作类型": "输入", "目标元素": "搜索框", "参数": "AI最新新闻", "预期结果": "搜索框填充完成", "错误处理": "如搜索框不可用,尝试刷新页面" - }, - { + }}, + {{ "操作类型": "点击", "目标元素": "搜索按钮", "参数": "无", "预期结果": "显示搜索结果页", "错误处理": "如点击无反应,尝试按回车键" - }, - { + }}, + {{ "操作类型": "判断", "目标元素": "搜索结果列表", "参数": "包含AI相关内容", "预期结果": "找到相关新闻", "错误处理": "如无相关结果,尝试修改搜索关键词" - } + }} ] """ diff --git a/gradio_ui/agent/task_run_agent.py b/gradio_ui/agent/task_run_agent.py index 564901d..387ae91 100644 --- a/gradio_ui/agent/task_run_agent.py +++ b/gradio_ui/agent/task_run_agent.py @@ -1,9 +1,6 @@ import json import uuid from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam, BetaUsage -from PIL import ImageDraw -import base64 -from io import BytesIO from pydantic import BaseModel, Field from gradio_ui.agent.base_agent import BaseAgent from xbrain.core.chat import run @@ -13,22 +10,18 @@ class TaskRunAgent(BaseAgent): def __init__(self): self.OUTPUT_DIR = "./tmp/outputs" - def __call__(self, task_plan, parsed_screen, messages): - screen_info = str(parsed_screen['parsed_content_list']) + def __call__(self, task_plan, parsed_screen_result, messages): + screen_info = str(parsed_screen_result['parsed_content_list']) self.SYSTEM_PROMPT = system_prompt.format(task_plan=str(task_plan), device=self.get_device(), screen_info=screen_info) - img_to_show = parsed_screen["image"] - buffered = BytesIO() - img_to_show.save(buffered, format="PNG") - img_to_show_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8") messages.append( {"role": "user", "content": [ - {"type": "text", "text": "图片是当前屏幕的截图"}, + {"type": "text", "text": "Image is the screenshot of the current screen"}, { "type": "image_url", - "image_url": {"url": f"data:image/png;base64,{img_to_show_base64}"} + "image_url": {"url": f"data:image/png;base64,{parsed_screen_result['base64_image']}"} } ] } @@ -40,21 +33,6 @@ class TaskRunAgent(BaseAgent): ) messages.append({"role": "assistant", "content": vlm_response}) vlm_response_json = json.loads(vlm_response) - if "box_id" in vlm_response_json: - try: - bbox = parsed_screen["parsed_content_list"][int(vlm_response_json["box_id"])].coordinates - vlm_response_json["box_centroid_coordinate"] = [int((bbox[0] + bbox[2]) / 2 ), int((bbox[1] + bbox[3]) / 2 )] - x, y = vlm_response_json["box_centroid_coordinate"] - radius = 10 - draw = ImageDraw.Draw(img_to_show) - draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill='red') - draw.ellipse((x - radius*3, y - radius*3, x + radius*3, y + radius*3), fill=None, outline='red', width=2) - buffered = BytesIO() - img_to_show.save(buffered, format="PNG") - img_to_show_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8") - except Exception as e: - print(f"Error parsing: {vlm_response_json}") - print(f"Error: {e}") response_content = [BetaTextBlock(text=vlm_response_json["reasoning"], type='text')] if 'box_centroid_coordinate' in vlm_response_json: move_cursor_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}', diff --git a/gradio_ui/agent/verification_agent.py b/gradio_ui/agent/verification_agent.py index bce9009..05922cf 100644 --- a/gradio_ui/agent/verification_agent.py +++ b/gradio_ui/agent/verification_agent.py @@ -15,10 +15,10 @@ class VerificationAgent(BaseAgent): messages.append({"role": "assistant", "content": response}) return json.loads(response) -class VerificationResponse(BaseModel): +class VerificationResponse(BaseModel): verification_status: str = Field(description="验证状态", json_schema_extra={"enum": ["success", "error"]}) verification_method: str = Field(description="验证方法") - evidence: str = Field(description="证据") + reasoning: str = Field(description="描述您验证的逻辑") failure_reason: str = Field(description="失败原因") remedy_measures: list[str] = Field(description="补救措施") @@ -72,7 +72,7 @@ prompt = """ { "verification_status": "success", "verification_method": "视觉验证+内容验证", - "evidence": "1. 检测到欢迎消息'你好,用户名' 2. 导航栏显示用户头像 3. URL已变更为首页地址", + "reasoning": "1. 检测到欢迎消息'你好,用户名' 2. 导航栏显示用户头像 3. URL已变更为首页地址", "failure_reason": "无", "remedy_measures": [], } diff --git a/gradio_ui/app.py b/gradio_ui/app.py index c8b4a67..399cc21 100644 --- a/gradio_ui/app.py +++ b/gradio_ui/app.py @@ -2,6 +2,7 @@ python app.py --windows_host_url localhost:8006 --omniparser_server_url localhost:8000 """ +import json import os from pathlib import Path import argparse @@ -48,6 +49,8 @@ def setup_state(state): if "messages" not in state: state["messages"] = [] + if "chatbox_messages" not in state: + state["chatbox_messages"] = [] if "auth_validated" not in state: state["auth_validated"] = False if "responses" not in state: @@ -99,13 +102,49 @@ def process_input(user_input, state, vision_agent_state): "content": user_input } ) - yield state['messages'] + state["chatbox_messages"].append( + { + "role": "user", + "content": user_input + } + ) + yield state['chatbox_messages'] agent = vision_agent_state["agent"] for _ in sampling_loop_sync( model=state["model"], messages=state["messages"], vision_agent = agent - ): yield state['messages'] + ): + state['chatbox_messages'] = [] + for message in state['messages']: + # convert message["content"] to gradio chatbox format + if type(message["content"]) is list: + gradio_chatbox_content = "" + for content in message["content"]: + # convert image_url to gradio image format + if content["type"] == "image_url": + gradio_chatbox_content += f'
' + # convert text to gradio text format + elif content["type"] == "text": + # agent response is json format and must contains reasoning + if message["role"] == "assistant": + content_json = json.loads(content["text"]) + gradio_chatbox_content += f'

{content_json["reasoning"]}

' + gradio_chatbox_content += f'
Detail
{json.dumps(content_json, indent=4)}
' + else: + gradio_chatbox_content += content["text"] + + state['chatbox_messages'].append({ + "role": message["role"], + "content": gradio_chatbox_content + }) + else: + state['chatbox_messages'].append({ + "role": message["role"], + "content": message["content"] + }) + yield state['chatbox_messages'] + def stop_app(state): state["stop"] = True @@ -202,7 +241,7 @@ def run(): autoscroll=True, height=580, type="messages") - + def update_model(model, state): state["model"] = model @@ -215,9 +254,10 @@ def run(): def clear_chat(state): # Reset message-related state state["messages"] = [] + state["chatbox_messages"] = [] state["responses"] = {} state["tools"] = {} - return state['messages'] + return state["chatbox_messages"] model.change(fn=update_model, inputs=[model, state], outputs=None) api_key.change(fn=update_api_key, inputs=[api_key, state], outputs=None) diff --git a/gradio_ui/executor/anthropic_executor.py b/gradio_ui/executor/anthropic_executor.py index 3a415dc..5c68de4 100644 --- a/gradio_ui/executor/anthropic_executor.py +++ b/gradio_ui/executor/anthropic_executor.py @@ -2,7 +2,10 @@ import asyncio import json from typing import Any, cast from anthropic.types.beta import ( - BetaMessageParam + BetaMessageParam, + BetaContentBlockParam, + BetaToolResultBlockParam, + BetaContentBlock ) from gradio_ui.tools import ComputerTool, ToolCollection @@ -13,13 +16,21 @@ class AnthropicExecutor: ComputerTool() ) - def __call__(self,messages: list[BetaMessageParam]): - content = json.loads(messages[-1]["content"]) - if content["next_action"] is not None: - # Run the asynchronous tool execution in a synchronous context - result = asyncio.run(self.tool_collection.run( - name=content["next_action"], - tool_input=cast(dict[str, Any], content["value"]), - )) - messages.append({"role": "assistant", "content": "tool result:\n"+str(result)}) - return messages \ No newline at end of file + def __call__(self, response, messages): + tool_result_content: list[str] = [] + for content_block in cast(list[BetaContentBlock], response.content): + # Execute the tool + if content_block.type == "tool_use": + # Run the asynchronous tool execution in a synchronous context + result = asyncio.run(self.tool_collection.run( + name=content_block.name, + tool_input=cast(dict[str, Any], content_block.input), + )) + tool_result_content.append( + str(result) + ) + messages.append({"role": "assistant", "content": "Run tool result:\n"+str(tool_result_content)}) + if not tool_result_content: + return messages + + return tool_result_content diff --git a/gradio_ui/loop.py b/gradio_ui/loop.py index 7cfd3b9..eca8e07 100644 --- a/gradio_ui/loop.py +++ b/gradio_ui/loop.py @@ -1,18 +1,16 @@ """ Agentic sampling loop that calls the Anthropic API and local implenmentation of anthropic-defined computer use tools. """ -from collections.abc import Callable +import base64 +from io import BytesIO from time import sleep import cv2 from gradio_ui.agent.verification_agent import VerificationAgent from gradio_ui.agent.vision_agent import VisionAgent from gradio_ui.tools.screen_capture import get_screenshot -from anthropic.types.beta import ( - BetaMessageParam -) +from anthropic.types.beta import (BetaMessageParam) from gradio_ui.agent.task_plan_agent import TaskPlanAgent from gradio_ui.agent.task_run_agent import TaskRunAgent -from gradio_ui.tools import ToolResult from gradio_ui.executor.anthropic_executor import AnthropicExecutor import numpy as np from PIL import Image @@ -31,72 +29,76 @@ def sampling_loop_sync( print('in sampling_loop_sync, model:', model) task_plan_agent = TaskPlanAgent() executor = AnthropicExecutor() - plan_list = task_plan_agent(messages=messages) - yield - task_run_agent = TaskRunAgent() verification_agent = VerificationAgent() + task_run_agent = TaskRunAgent() + parsed_screen_result = parsed_screen(vision_agent) + plan_list = task_plan_agent(messages=messages, parsed_screen_result=parsed_screen_result) + yield for plan in plan_list: execute_task_plan(plan, vision_agent, task_run_agent, executor, messages) yield sleep(2) - verification_loop(vision_agent, plan, verification_agent, executor, task_run_agent, messages) + verification_loop(vision_agent, verification_agent, executor, task_run_agent, messages) yield -def verification_loop(vision_agent, plan, verification_agent, executor, task_run_agent, messages): +def verification_loop(vision_agent, verification_agent, executor, task_run_agent, messages): """verification agent will be called in the loop""" while True: - # 验证结果 - verification_result = verification_agent(plan["expected_result"], messages) + # verification result + verification_result = verification_agent( messages) yield - # 如果验证成功,返回结果 + # if verification success, return result if verification_result["verification_status"] == "success": return - # 如果验证失败,执行补救措施 + # if verification failed, execute remedy measures elif verification_result["verification_status"] == "error": execute_task_plan(verification_result["remedy_measures"], vision_agent, task_run_agent, executor, messages) yield def execute_task_plan(plan, vision_agent, task_run_agent, executor, messages): - parsed_screen = parse_screen(vision_agent) - task_run_agent(task_plan=plan, parsed_screen=parsed_screen, messages=messages) - executor(messages) + parsed_screen_result = parsed_screen(vision_agent) + tools_use_needed, __ = task_run_agent(task_plan=plan, parsed_screen_result=parsed_screen_result, messages=messages) + executor(tools_use_needed, messages) -def parse_screen(vision_agent: VisionAgent): +def parsed_screen(vision_agent: VisionAgent): screenshot, screenshot_path = get_screenshot() response_json = {} response_json['parsed_content_list'] = vision_agent(str(screenshot_path)) response_json['width'] = screenshot.size[0] response_json['height'] = screenshot.size[1] response_json['image'] = draw_elements(screenshot, response_json['parsed_content_list']) + buffered = BytesIO() + response_json['image'].save(buffered, format="PNG") + response_json['base64_image'] = base64.b64encode(buffered.getvalue()).decode("utf-8") return response_json def draw_elements(screenshot, parsed_content_list): """ - 将PIL图像转换为OpenCV兼容格式并绘制边界框 + Convert PIL image to OpenCV compatible format and draw bounding boxes Args: - screenshot: PIL Image对象 - parsed_content_list: 包含边界框信息的列表 + screenshot: PIL Image object + parsed_content_list: list containing bounding box information Returns: - 带有绘制边界框的PIL图像 + PIL image with drawn bounding boxes """ - # 将PIL图像转换为opencv格式 + # convert PIL image to opencv format opencv_image = np.array(screenshot) opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGB2BGR) - # 绘制边界框 + # draw bounding boxes for idx, element in enumerate(parsed_content_list): bbox = element.coordinates x1, y1, x2, y2 = bbox - # 转换坐标为整数 + # convert coordinates to integers x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) - # 绘制矩形 + # draw rectangle cv2.rectangle(opencv_image, (x1, y1), (x2, y2), (0, 0, 255), 2) - # 在矩形边框左上角绘制序号 + # draw index number cv2.putText(opencv_image, str(idx+1), (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2) - # 将OpenCV图像格式转换回PIL格式 + # convert opencv image format back to PIL format opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_BGR2RGB) pil_image = Image.fromarray(opencv_image) diff --git a/gradio_ui/tools/computer.py b/gradio_ui/tools/computer.py index 88d79cc..f802a87 100644 --- a/gradio_ui/tools/computer.py +++ b/gradio_ui/tools/computer.py @@ -2,12 +2,9 @@ import base64 import time from typing import Literal, TypedDict from PIL import Image -from util import tool from anthropic.types.beta import BetaToolComputerUse20241022Param from .base import BaseAnthropicTool, ToolError, ToolResult from .screen_capture import get_screenshot -import requests -import re import pyautogui import pyperclip import platform @@ -29,7 +26,9 @@ Action = Literal[ "screenshot", "cursor_position", "hover", - "wait" + "wait", + "scroll_up", + "scroll_down" ] class Resolution(TypedDict): @@ -65,7 +64,6 @@ class ComputerTool(BaseAnthropicTool): @property def options(self) -> ComputerToolOptions: - # 直接使用原始尺寸,不进行缩放 return { "display_width_px": self.width, "display_height_px": self.height, @@ -76,14 +74,12 @@ class ComputerTool(BaseAnthropicTool): return {"name": self.name, "type": self.api_type, **self.options} - def __init__(self, is_scaling: bool = False): + def __init__(self): super().__init__() - # Get screen width and height using Windows command self.display_num = None self.offset_x = 0 self.offset_y = 0 - self.width, self.height = self.get_screen_size() - print(f"screen size: {self.width}, {self.height}") + self.width, self.height = pyautogui.size() self.key_conversion = {"Page_Down": "pagedown", "Page_Up": "pageup", "Super_L": "win", @@ -216,19 +212,3 @@ class ComputerTool(BaseAnthropicTool): # padding to top left padding_image.paste(screenshot, (0, 0)) return padding_image - - - def get_screen_size(self): - """Return width and height of the screen""" - try: - response = tool.execute_command( - ["python", "-c", "import pyautogui; print(pyautogui.size())"] - ) - output = response['output'].strip() - match = re.search(r'Size\(width=(\d+),\s*height=(\d+)\)', output) - if not match: - raise ToolError(f"Could not parse screen size from output: {output}") - width, height = map(int, match.groups()) - return width, height - except requests.exceptions.RequestException as e: - raise ToolError(f"An error occurred while trying to get screen size: {str(e)}") \ No newline at end of file diff --git a/gradio_ui/tools/screen_capture.py b/gradio_ui/tools/screen_capture.py index 430aee2..2642741 100644 --- a/gradio_ui/tools/screen_capture.py +++ b/gradio_ui/tools/screen_capture.py @@ -13,10 +13,8 @@ def get_screenshot(resize: bool = False, target_width: int = 1920, target_height path = output_dir / f"screenshot_{uuid4().hex}.png" try: - # 使用 tool.capture_screen_with_cursor 替代 requests.get img_io = tool.capture_screen_with_cursor() - screenshot = Image.open(img_io) - + screenshot = Image.open(img_io) if resize and screenshot.size != (target_width, target_height): screenshot = screenshot.resize((target_width, target_height)) screenshot.save(path) diff --git a/util/tool.py b/util/tool.py index 309f32f..9c88577 100644 --- a/util/tool.py +++ b/util/tool.py @@ -1,38 +1,9 @@ import os -import shlex -import subprocess -import threading import pyautogui from PIL import Image from io import BytesIO -computer_control_lock = threading.Lock() -def execute_command(command, shell=False): - """Local function to execute a command.""" - with computer_control_lock: - if isinstance(command, str) and not shell: - command = shlex.split(command) - - # Expand user directory - for i, arg in enumerate(command): - if arg.startswith("~/"): - command[i] = os.path.expanduser(arg) - - try: - result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell, text=True, timeout=120) - return { - 'status': 'success', - 'output': result.stdout, - 'error': result.stderr, - 'returncode': result.returncode - } - except Exception as e: - return { - 'status': 'error', - 'message': str(e) - } - def capture_screen_with_cursor(): """Local function to capture the screen with cursor.""" cursor_path = os.path.join(os.path.dirname(__file__),"..","resources", "cursor.png")