mirror of https://github.com/yuruotong1/autoMate.git
synced 2025-12-26 05:16:21 +08:00

Optimize the task planning feature

This commit is contained in:
parent 73a5c991c8
commit 15d18b0c7d
@@ -2,8 +2,14 @@ from gradio_ui.agent.base_agent import BaseAgent
 from xbrain.core.chat import run
 
 
 class TaskPlanAgent(BaseAgent):
+    def __init__(self, output_callback):
+        self.output_callback = output_callback
+
     def __call__(self, user_task: str):
-        return run([{"role": "user", "content": user_task}], user_prompt=self.SYSTEM_PROMPT)
+        self.output_callback("Planning the task...", sender="bot")
+        response = run([{"role": "user", "content": user_task}], user_prompt=system_prompt)
+        self.output_callback(response, sender="bot")
+        return response
 
 
 system_prompt = """
 ### Goal ###
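This turns TaskPlanAgent from a silent one-shot call into one that streams status through the injected callback. A minimal sketch of the new flow, with a hypothetical fake_run standing in for xbrain.core.chat.run and a print-based callback standing in for the Gradio one:

def fake_run(messages, user_prompt=None):
    # Hypothetical stand-in for xbrain.core.chat.run; the real call hits an LLM.
    return f"plan for: {messages[0]['content']}"

class TaskPlanAgentSketch:
    def __init__(self, output_callback):
        self.output_callback = output_callback

    def __call__(self, user_task: str):
        # Report progress before and after the model call, as in the diff.
        self.output_callback("Planning the task...", sender="bot")
        response = fake_run([{"role": "user", "content": user_task}], user_prompt="...")
        self.output_callback(response, sender="bot")
        return response

def print_callback(message, sender="bot"):
    # The real callback renders into the Gradio chat; printing suffices here.
    print(f"[{sender}] {message}")

agent = TaskPlanAgentSketch(output_callback=print_callback)
agent("open the browser and search for autoMate")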
@@ -22,28 +22,40 @@ class TaskRunAgent(BaseAgent):
             screen_info=screen_info)
 
-        screen_width, screen_height = parsed_screen['width'], parsed_screen['height']
-        vlm_response = run([{"role": "user", "content": "next"}], user_prompt=self.SYSTEM_PROMPT, response_format=TaskRunAgentResponse)
+        img_to_show = parsed_screen["image"]
+        buffered = BytesIO()
+        img_to_show.save(buffered, format="PNG")
+        img_to_show_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
+        vlm_response = run([
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "The image is a screenshot of the current screen. Based on the image and the parsed elements, determine the next action."},
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/png;base64,{img_to_show_base64}"
+                        }
+                    }
+                ]
+            }
+        ], user_prompt=self.SYSTEM_PROMPT, response_format=TaskRunAgentResponse)
         vlm_response_json = json.loads(vlm_response)
         if "box_id" in vlm_response_json:
             try:
                 bbox = parsed_screen["parsed_content_list"][int(vlm_response_json["box_id"])].coordinates
-                # vlm_response_json["box_centroid_coordinate"] = [int((bbox[0] + bbox[2]) / 2 * screen_width), int((bbox[1] + bbox[3]) / 2 * screen_height)]
+                vlm_response_json["box_centroid_coordinate"] = [int((bbox[0] + bbox[2]) / 2), int((bbox[1] + bbox[3]) / 2)]
 
-                # img_to_show_data = base64.b64decode(img_to_show_base64)
-                # img_to_show = Image.open(BytesIO(img_to_show_data))
+                img_to_show = parsed_screen["image"]
+                draw = ImageDraw.Draw(img_to_show)
                 x, y = vlm_response_json["box_centroid_coordinate"]
                 radius = 10
-                draw = ImageDraw.Draw(img_to_show)
                 draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill='red')
                 draw.ellipse((x - radius*3, y - radius*3, x + radius*3, y + radius*3), fill=None, outline='red', width=2)
                 buffered = BytesIO()
                 img_to_show.save(buffered, format="PNG")
                 img_to_show_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
-            except:
+            except Exception as e:
                 print(f"Error parsing: {vlm_response_json}")
-                pass
+                print(f"Error: {e}")
         self.output_callback(f'<img src="data:image/png;base64,{img_to_show_base64}">', sender="bot")
         self.output_callback(
             f'<details>'
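Two mechanics of this hunk are worth isolating: the screenshot is serialized into an OpenAI-style image_url content part as a base64 data URL, and the box centroid is now the plain midpoint of the bbox, since the parsed coordinates are already in absolute pixels (which is why the screen_width/screen_height scaling was removed). A self-contained Pillow-only sketch of both steps; the bbox value is invented for illustration:

import base64
from io import BytesIO

from PIL import Image, ImageDraw

# Assume a parsed element whose bbox is already in absolute pixels,
# as the new code expects.
screenshot = Image.new("RGB", (1920, 1080), "white")
bbox = (100, 200, 300, 260)  # x1, y1, x2, y2 in pixels

# (a) PIL image -> base64 data URL, as sent in the image_url content part
buffered = BytesIO()
screenshot.save(buffered, format="PNG")
image_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "The image is a screenshot of the current screen."},
        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
    ],
}

# (b) centroid as the plain midpoint of the pixel-space box
x, y = int((bbox[0] + bbox[2]) / 2), int((bbox[1] + bbox[3]) / 2)

# Mark the click target the same way the diff does: a filled dot plus a ring
draw = ImageDraw.Draw(screenshot)
radius = 10
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill="red")
draw.ellipse((x - radius * 3, y - radius * 3, x + radius * 3, y + radius * 3),
             outline="red", width=2)
print(message["content"][0], (x, y))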
@@ -112,14 +124,12 @@ class TaskRunAgentResponse(BaseModel):
 
 system_prompt = """
 ### Goal ###
-You are using a {device} device. Based on the [overall task], the [operation history], and the [current screen information], determine the [next action]:
+You are an automation planner and need to complete the user's task. Based on the screen information, determine the [next action] that completes the task:
 
-1. Combine the [current screen information] and the [operation history] to work out which stage of the [overall task] you are in, then determine the [next action].
 
-Your current [overall task] is:
+Your current task is:
 {task_plan}
 
-The following is the [current screen information] detected on the current screen:
+The following are all the elements on the current screen, detected with YOLO:
 
 {screen_info}
 ##########
@@ -135,7 +145,7 @@ system_prompt = """
 8. If you are shown a login page or a captcha page, or if you believe the next step requires the user's permission, you should answer "Next Action": "None" in the json field.
 9. You can only interact with the computer through the mouse and keyboard.
 10. You can only interact with the desktop GUI (you cannot access the terminal or the application menu).
-11. If the current screen shows no actionable element and the screen cannot be scrolled down, choose None and stop.
+11. If the current screen shows no actionable element and the screen cannot be scrolled down, return None.
 
 ##########
 ### Output format ###
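The template above is an ordinary Python string left for str.format-style filling: the {task_plan} and {screen_info} placeholders come from the diff ({device} was dropped in this commit), while the concrete values below are invented for illustration. A hedged sketch of how the prompt is presumably assembled before being passed as user_prompt:

# Hypothetical filling of the prompt template; only the placeholder names
# are taken from the diff, the values are made up.
system_prompt = """### Goal ###
You are an automation planner and need to complete the user's task.

Your current task is:
{task_plan}

The following are all the elements on the current screen, detected with YOLO:
{screen_info}
"""

filled = system_prompt.format(
    task_plan="1. Open the browser\n2. Search for autoMate",
    screen_info="1: button 'Search' (100, 200, 300, 260)",
)
print(filled)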
@@ -88,7 +88,6 @@ class VisionAgent:
         image = cv2.imread(image_path)
         if image is None:
             raise FileNotFoundError("Vision agent: failed to read the image")
-
         return self.analyze_image(image)
 
     def _get_optimal_device_and_dtype(self):
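Worth noting for the context above: cv2.imread does not raise when a path is missing or unreadable; it silently returns None, which is why the explicit guard exists. A two-line demonstration:

import cv2

# cv2.imread returns None instead of raising on a bad path, so the caller
# must check explicitly, as VisionAgent does above.
image = cv2.imread("definitely_missing.png")
print(image is None)  # True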
@@ -3,6 +3,9 @@ Agentic sampling loop that calls the Anthropic API and local implementation of
 """
+import base64
 from collections.abc import Callable
+import time
 
+import cv2
 from gradio_ui.agent.vision_agent import VisionAgent
 from gradio_ui.tools.screen_capture import get_screenshot
 from anthropic import APIResponse
@@ -19,6 +22,9 @@ from gradio_ui.agent.llm_utils.omniparserclient import OmniParserClient
 from gradio_ui.agent.vlm_agent import VLMAgent
 from gradio_ui.executor.anthropic_executor import AnthropicExecutor
+from pathlib import Path
+import numpy as np
+from PIL import Image
 
 OUTPUT_DIR = "./tmp/outputs"
 
 def sampling_loop_sync(
@@ -35,7 +41,7 @@
     Synchronous agentic sampling loop for the assistant/tool interaction of computer use.
     """
     print('in sampling_loop_sync, model:', model)
-    task_plan_agent = TaskPlanAgent()
+    task_plan_agent = TaskPlanAgent(output_callback=output_callback)
     executor = AnthropicExecutor(
         output_callback=output_callback,
         tool_output_callback=tool_output_callback,
@@ -48,6 +54,7 @@
     while True:
         parsed_screen = parse_screen(vision_agent)
         tools_use_needed, __ = task_run_agent(task_plan=plan, parsed_screen=parsed_screen)
+        time.sleep(1)
         for message, tool_result_content in executor(tools_use_needed, messages):
             yield message
         if not tool_result_content:
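The only functional change here is the added time.sleep(1), which throttles the act-observe cycle so the screen can settle between actions. A runnable skeleton of the loop with the collaborators stubbed out (parse_screen_stub, task_run_agent_stub, and executor_stub are hypothetical stand-ins for the real components):

import time

def parse_screen_stub():
    # Stand-in for parse_screen(vision_agent).
    return {"parsed_content_list": [], "width": 1920, "height": 1080}

def task_run_agent_stub(task_plan, parsed_screen):
    # Stand-in for task_run_agent; returns (tools_use_needed, extra).
    return ({"action": "None"}, None)

def executor_stub(tools_use_needed, messages):
    # Stand-in for the AnthropicExecutor; yields (message, tool_result_content).
    yield "done", None

messages = []
plan = "open the browser"
while True:
    parsed_screen = parse_screen_stub()
    tools_use_needed, _ = task_run_agent_stub(task_plan=plan, parsed_screen=parsed_screen)
    time.sleep(1)  # throttle, as added in this commit
    for message, tool_result_content in executor_stub(tools_use_needed, messages):
        print(message)
    if not tool_result_content:
        break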
@@ -59,5 +66,37 @@ def parse_screen(vision_agent: VisionAgent):
     response_json['parsed_content_list'] = vision_agent(str(screenshot_path))
     response_json['width'] = screenshot.size[0]
     response_json['height'] = screenshot.size[1]
-    response_json['image'] = screenshot
+    response_json['image'] = draw_elements(screenshot, response_json['parsed_content_list'])
     return response_json
+
+def draw_elements(screenshot, parsed_content_list):
+    """
+    Convert a PIL image to an OpenCV-compatible format and draw the bounding boxes
+
+    Args:
+        screenshot: PIL Image object
+        parsed_content_list: list carrying the bounding-box information
+
+    Returns:
+        PIL image with the bounding boxes drawn on it
+    """
+    # Convert the PIL image to OpenCV format
+    opencv_image = np.array(screenshot)
+    opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGB2BGR)
+    # Draw the bounding boxes
+    for idx, element in enumerate(parsed_content_list):
+        bbox = element.coordinates
+        x1, y1, x2, y2 = bbox
+        # Convert the coordinates to integers
+        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+        # Draw the rectangle
+        cv2.rectangle(opencv_image, (x1, y1), (x2, y2), (0, 0, 255), 2)
+        # Draw the index at the top-left corner of the box
+        cv2.putText(opencv_image, str(idx + 1), (x1, y1 - 10),
+                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
+
+    # Convert the OpenCV image back to PIL format
+    opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_BGR2RGB)
+    pil_image = Image.fromarray(opencv_image)
+
+    return pil_image
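A quick way to exercise draw_elements in isolation, assuming the function above is in scope (with numpy, cv2, and PIL imported at module level as in the diff); ParsedElement is a hypothetical stand-in, since the function only reads a .coordinates attribute:

from dataclasses import dataclass

from PIL import Image

@dataclass
class ParsedElement:
    # Stand-in for the parser's element type.
    coordinates: tuple

screenshot = Image.new("RGB", (800, 600), "white")
elements = [ParsedElement((50, 50, 200, 120)), ParsedElement((300, 200, 480, 260))]
annotated = draw_elements(screenshot, elements)
annotated.save("annotated.png")  # each box outlined in red and numbered from 1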