Mirror of https://github.com/yuruotong1/autoMate.git (synced 2025-12-26 05:16:21 +08:00)
Update task_run_agent
This commit is contained in: parent b2d559f15a, commit 7542d73ccf
@@ -1,18 +1,65 @@
import json
import uuid
from anthropic.types.beta import BetaMessage, BetaUsage
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam, BetaUsage
from PIL import Image, ImageDraw
import base64
from gradio import Image
from io import BytesIO
from pydantic import BaseModel, Field
from gradio_ui.agent.base_agent import BaseAgent
from xbrain.core.chat import run
import platform
import re

class TaskRunAgent(BaseAgent):
    def __call__(self, task_plan: str, screen_info):
    def __init__(self):
        self.OUTPUT_DIR = "./tmp/outputs"
        device = self.get_device()

    def __call__(self, task_plan, parsed_screen):
        self.SYSTEM_PROMPT = system_prompt.format(task_plan=task_plan,
            device=device,
            screen_info=screen_info)
        print(self.SYSTEM_PROMPT)
            device=self.get_device(),
            screen_info=parsed_screen["parsed_content_list"])
        screen_width, screen_height = parsed_screen['width'], parsed_screen['height']
        vlm_response = run([{"role": "user", "content": "next"}], user_prompt=self.SYSTEM_PROMPT, response_format=TaskRunAgentResponse)
        vlm_response_json = json.loads(vlm_response)
        if "box_id" in vlm_response_json:
            try:
                bbox = parsed_screen["parsed_content_list"][int(vlm_response_json["box_id"])]["bbox"]
                vlm_response_json["box_centroid_coordinate"] = [int((bbox[0] + bbox[2]) / 2 * screen_width), int((bbox[1] + bbox[3]) / 2 * screen_height)]
                img_to_show_data = base64.b64decode(img_to_show_base64)
                img_to_show = Image.open(BytesIO(img_to_show_data))
                draw = ImageDraw.Draw(img_to_show)
                x, y = vlm_response_json["box_centroid_coordinate"]
                radius = 10
                draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill='red')
                draw.ellipse((x - radius*3, y - radius*3, x + radius*3, y + radius*3), fill=None, outline='red', width=2)
                buffered = BytesIO()
                img_to_show.save(buffered, format="PNG")
                img_to_show_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
            except:
                print(f"Error parsing: {vlm_response_json}")
                pass
        response_content = [BetaTextBlock(text=vlm_response_json["reasoning"], type='text')]
        if 'box_centroid_coordinate' in vlm_response_json:
            move_cursor_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
                input={'action': 'mouse_move', 'coordinate': vlm_response_json["box_centroid_coordinate"]},
                name='computer', type='tool_use')
            response_content.append(move_cursor_block)

        if vlm_response_json["next_action"] == "None":
            print("Task paused/completed.")
        elif vlm_response_json["next_action"] == "type":
            sim_content_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
                input={'action': vlm_response_json["next_action"], 'text': vlm_response_json["value"]},
                name='computer', type='tool_use')
            response_content.append(sim_content_block)
        else:
            sim_content_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
                input={'action': vlm_response_json["next_action"]},
                name='computer', type='tool_use')
            response_content.append(sim_content_block)
        response_message = BetaMessage(id=f'toolu_{uuid.uuid4()}', content=response_content, model='', role='assistant', type='message', stop_reason='tool_use', usage=BetaUsage(input_tokens=0, output_tokens=0))
        return response_message, vlm_response_json

    def get_device(self):
        # Get information about the current operating system
@@ -27,10 +74,6 @@ class TaskRunAgent(BaseAgent):
            device = system
        return device

    def __call__(self, task):
        res = run([{"role": "user", "content": task}], user_prompt=self.SYSTEM_PROMPT, response_format=TaskRunAgentResponse)
        response_message = BetaMessage(id=f'toolu_{uuid.uuid4()}', content=res, model='', role='assistant', type='message', stop_reason='tool_use', usage=BetaUsage(input_tokens=0, output_tokens=0))
        return response_message

    def extract_data(self, input_string, data_type):
        # Regular expression to extract content starting from '```python' until the end if there are no closing backticks
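Reviewer note on the coordinate handling in the new `__call__` above: the model returns a `box_id`, and the code maps that element's normalized bounding box to a pixel-space click target. A minimal sketch of that arithmetic, assuming `bbox` is `[x1, y1, x2, y2]` in the 0-1 range (which the multiplication by screen size implies); the helper name is hypothetical:

```python
# Sketch of the centroid math used in TaskRunAgent.__call__, assuming a
# normalized [x1, y1, x2, y2] bbox. Hypothetical helper for illustration.
def box_centroid(bbox, screen_width, screen_height):
    x = int((bbox[0] + bbox[2]) / 2 * screen_width)
    y = int((bbox[1] + bbox[3]) / 2 * screen_height)
    return [x, y]

# A box covering the top-left quarter of a 1920x1080 screen centers at (480, 270).
assert box_centroid([0.0, 0.0, 0.5, 0.5], 1920, 1080) == [480, 270]
```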
@@ -84,12 +127,12 @@ system_prompt = """
##########
### Output Format ###
```json
{
{{
    "Reasoning": str, # Describe what is on the current screen, take the history into account, and then lay out your step-by-step thinking for accomplishing the task, choosing one action at a time from the available actions.
    "Next Action": "action_type, action description" | "None" # One action at a time; describe it briefly and precisely.
    "Box ID": n,
    "value": "xxx" # Provide the value field only when the action is type; otherwise do not include the value key.
}
}}
```

[Next Action] must be exactly one of the following:
@@ -106,28 +149,28 @@ system_prompt = """
### Examples ###
One example:
```json
{
    "Reasoning": "The current screen shows the Google search results for Amazon; in a previous action I already searched Google for Amazon. I now need to click the first search result to go to amazon.com.",
    "Next Action": "left_click",
    "Box ID": m
}
{{
    "reasoning": "The current screen shows the Google search results for Amazon; in a previous action I already searched Google for Amazon. I now need to click the first search result to go to amazon.com.",
    "next_action": "left_click",
    "box_id": m
}}
```

Another example:
```json
{
    "Reasoning": "The current screen shows the Amazon home page. There are no previous actions, so I need to type 'Apple watch' into the search bar.",
    "Next Action": "type",
    "Box ID": n,
{{
    "reasoning": "The current screen shows the Amazon home page. There are no previous actions, so I need to type 'Apple watch' into the search bar.",
    "next_action": "type",
    "box_id": n,
    "value": "Apple watch"
}
}}
```

Another example:
```json
{
    "Reasoning": "The current screen does not show a 'Submit' button; I need to scroll down to see whether the button is available.",
    "Next Action": "scroll_down"
}
{{
    "reasoning": "The current screen does not show a 'Submit' button; I need to scroll down to see whether the button is available.",
    "next_action": "scroll_down"
}}
"""
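The `{` to `{{` and `}` to `}}` edits above are brace escaping for `str.format`: single braces are placeholder delimiters, so literal braces in the JSON examples must be doubled or `format` raises an error. A small demonstration:

```python
# Literal braces in a str.format template must be doubled.
template = 'Task: {task_plan}. Reply with JSON such as {{"next_action": "left_click", "box_id": 3}}.'
print(template.format(task_plan="open amazon.com"))
# -> Task: open amazon.com. Reply with JSON such as {"next_action": "left_click", "box_id": 3}.
```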
@@ -67,37 +67,6 @@ class VisionAgent:
            print("Caption model loaded successfully")
        except Exception as e:
            print(f"Failed to load the caption model: {e}")
            print("Trying a fallback loading method...")

            # Fallback loading method
            try:
                # Load on the CPU first, then move to the target device
                self.caption_model = AutoModelForCausalLM.from_pretrained(
                    caption_model_path,
                    torch_dtype=torch.float32,
                    trust_remote_code=True
                )

                # On a CUDA device, try converting to float16
                if self.device.type == 'cuda':
                    try:
                        self.caption_model = self.caption_model.to(dtype=torch.float16)
                    except:
                        print("Conversion to float16 failed; using float32")

                # Move to the target device
                self.caption_model = self.caption_model.to(self.device)
                print("Fallback loading succeeded")
            except Exception as e2:
                print(f"The fallback loading method also failed: {e2}")
                print("Falling back to CPU mode")
                self.device = torch.device("cpu")
                self.dtype = torch.float32
                self.caption_model = AutoModelForCausalLM.from_pretrained(
                    caption_model_path,
                    torch_dtype=torch.float32,
                    trust_remote_code=True
                ).to(self.device)

        # Set the prompt
        self.prompt = "<CAPTION>"
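The removed fallback path follows a common loading pattern: materialize the weights in float32 first (safe on any device), then opportunistically cast to float16 once on CUDA. A generic sketch of that pattern, using the same `AutoModelForCausalLM` API as the diff; the function name is a placeholder, not repo code:

```python
import torch
from transformers import AutoModelForCausalLM

# Generic sketch: load in float32, then try a cheaper dtype on CUDA and
# keep float32 if the cast fails. Placeholder helper for illustration.
def load_caption_model(path: str, device: torch.device):
    model = AutoModelForCausalLM.from_pretrained(
        path, torch_dtype=torch.float32, trust_remote_code=True
    )
    if device.type == "cuda":
        try:
            model = model.to(dtype=torch.float16)
        except Exception:
            pass  # unsupported cast: stay in float32
    return model.to(device)
```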
@@ -113,6 +82,15 @@ class VisionAgent:
        self.elements: List[UIElement] = []
        self.ocr_reader = easyocr.Reader(['en', 'ch_sim'])

    def __call__(self, image_path: str) -> List[UIElement]:
        """Process an image from file path."""
        # image = self.load_image(image_source)
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Vision agent: failed to read the image")

        return self.analyze_image(image)

    def _get_optimal_device_and_dtype(self):
        """Determine the optimal device and data type"""
        if torch.cuda.is_available():
@@ -261,55 +239,7 @@ class VisionAgent:
                torch.cuda.empty_cache()

            except RuntimeError as e:
                print(f"Batch processing failed: {e}")

                # If it is a CUDA error, try processing on the CPU
                if "CUDA" in str(e) or "cuda" in str(e):
                    print("Trying to process this batch on the CPU...")
                    try:
                        # Temporarily move the model to the CPU
                        self.caption_model = self.caption_model.to("cpu")

                        # Process on the CPU
                        inputs = self.caption_processor(
                            images=batch,
                            text=[self.prompt] * len(batch),
                            return_tensors="pt"
                        )

                        with torch.no_grad():
                            if 'florence' in self.caption_model.config.model_type:
                                generated_ids = self.caption_model.generate(
                                    input_ids=inputs["input_ids"],
                                    pixel_values=inputs["pixel_values"],
                                    max_new_tokens=20,
                                    num_beams=1,
                                    do_sample=False
                                )
                            else:
                                generated_ids = self.caption_model.generate(
                                    **inputs,
                                    max_length=50,
                                    num_beams=3,
                                    early_stopping=True
                                )

                        texts = self.caption_processor.batch_decode(
                            generated_ids,
                            skip_special_tokens=True
                        )
                        texts = [text.strip() for text in texts]
                        generated_texts.extend(texts)

                        # Move the model back to its original device afterwards
                        self.caption_model = self.caption_model.to(self.device)
                    except Exception as cpu_e:
                        print(f"CPU processing also failed: {cpu_e}")
                        generated_texts.extend(["[caption generation failed]"] * len(batch))
                else:
                    # Not a CUDA error; just append placeholders
                    generated_texts.extend(["[caption generation failed]"] * len(batch))

                print(f"Batch processing failed: {e}")
        return generated_texts

    def _detect_objects(self, image: np.ndarray) -> tuple[list[np.ndarray], list]:
@@ -378,14 +308,6 @@ class VisionAgent:
            raise ValueError(error_msg) from e


    def __call__(self, image_path: str) -> List[UIElement]:
        """Process an image from file path."""
        # image = self.load_image(image_source)
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Vision agent: failed to read the image")

        return self.analyze_image(image)
@@ -27,44 +27,42 @@ class AnthropicExecutor:
        self.output_callback = output_callback
        self.tool_output_callback = tool_output_callback

    def __call__(self, response: BetaMessage, messages: list[BetaMessageParam]):
    def __call__(self, response, messages: list[BetaMessageParam]):
        new_message = {
            "role": "assistant",
            "content": cast(list[BetaContentBlockParam], response.content),
            "content": cast(list[BetaContentBlockParam], response),
        }
        if new_message not in messages:
            messages.append(new_message)
        else:
            print("new_message already in messages, there are duplicates.")

        tool_result_content: list[BetaToolResultBlockParam] = []
        for content_block in cast(list[BetaContentBlock], response.content):
            self.output_callback(content_block, sender="bot")
            # Execute the tool
            if content_block.type == "tool_use":
                # Run the asynchronous tool execution in a synchronous context
                result = asyncio.run(self.tool_collection.run(
                    name=content_block.name,
                    tool_input=cast(dict[str, Any], content_block.input),
                ))

                self.output_callback(result, sender="bot")

                tool_result_content.append(
                    _make_api_tool_result(result, content_block.id)
                )
                self.tool_output_callback(result, content_block.id)
        self.output_callback(response["action_type"], sender="bot")
        # Execute the tool
        if response["next_action"] != None:
            # Run the asynchronous tool execution in a synchronous context
            result = asyncio.run(self.tool_collection.run(
                name=response["action_type"],
                tool_input=cast(dict[str, Any], content_block.input),
            ))

            self.output_callback(result, sender="bot")

            tool_result_content.append(
                _make_api_tool_result(result, content_block.id)
            )
            self.tool_output_callback(result, content_block.id)

            # Craft messages based on the content_block
            # Note: to display the messages in the gradio, you should organize the messages in the following way (user message, bot message)

            display_messages = _message_display_callback(messages)
            # display_messages = []

            # Send the messages to the gradio
            for user_msg, bot_msg in display_messages:
                # yield [user_msg, bot_msg], tool_result_content
                yield [None, None], tool_result_content
        # Craft messages based on the content_block
        # Note: to display the messages in the gradio, you should organize the messages in the following way (user message, bot message)

        display_messages = _message_display_callback(messages)
        # display_messages = []

        # Send the messages to the gradio
        for user_msg, bot_msg in display_messages:
            # yield [user_msg, bot_msg], tool_result_content
            yield [None, None], tool_result_content

        if not tool_result_content:
            return messages
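Since the executor now receives TaskRunAgent's dict rather than a `BetaMessage`, the dispatch it implies looks roughly like the sketch below. This is an assumption-laden illustration: `run_tool` is a hypothetical stand-in for the real tool collection, and only the `next_action`, `value`, and `box_centroid_coordinate` keys are taken from the diff.

```python
# Hypothetical sketch of dict-driven dispatch; run_tool is a stand-in.
def dispatch(response: dict, run_tool):
    action = response.get("next_action")
    if action in (None, "None"):
        return None  # task paused or completed; nothing to execute
    tool_input = {"action": action}
    if action == "type":
        tool_input["text"] = response["value"]
    if "box_centroid_coordinate" in response:
        tool_input["coordinate"] = response["box_centroid_coordinate"]
    return run_tool(name="computer", tool_input=tool_input)
```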
@@ -61,14 +61,16 @@ def sampling_loop_sync(

    while True:
        parsed_screen = parse_screen(vision_agent)
        # tools_use_needed, vlm_response_json = actor(messages=messages, parsed_screen=parsed_screen)
        tools_use_needed = task_run_agent(task_plan=plan, screen_info=parsed_screen)
        tools_use_needed, vlm_response_json = task_run_agent(task_plan=plan, screen_info=parsed_screen)
        for message, tool_result_content in executor(tools_use_needed, messages):
            yield message
        if not tool_result_content:
            return messages

def parse_screen(vision_agent: VisionAgent):
    _, screenshot_path = get_screenshot()
    parsed_screen = vision_agent(str(screenshot_path))
    return parsed_screen
    screenshot, screenshot_path = get_screenshot()
    response_json = {}
    response_json['parsed_content_list'] = vision_agent(str(screenshot_path))
    response_json['width'] = screenshot.size[0]
    response_json['height'] = screenshot.size[1]
    return response_json
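For reference, the return shape of the new `parse_screen` (and the input `TaskRunAgent.__call__` expects) is a plain dict; the concrete values below are illustrative only:

```python
# Shape implied by the diff; values are made up for illustration.
parsed_screen = {
    "parsed_content_list": [
        {"bbox": [0.10, 0.20, 0.30, 0.25]},  # normalized [x1, y1, x2, y2]
        # ... one entry per detected UI element
    ],
    "width": 1920,   # screenshot.size[0]
    "height": 1080,  # screenshot.size[1]
}
```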