Update task_run_agent

This commit is contained in:
yuruo 2025-03-11 23:03:10 +08:00
parent b2d559f15a
commit 7542d73ccf
4 changed files with 113 additions and 148 deletions

View File

@@ -1,18 +1,65 @@
import json
import uuid
from anthropic.types.beta import BetaMessage, BetaUsage
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam, BetaUsage
from PIL import Image, ImageDraw
import base64
from io import BytesIO
from pydantic import BaseModel, Field
from gradio_ui.agent.base_agent import BaseAgent
from xbrain.core.chat import run
import platform
import re
class TaskRunAgent(BaseAgent):
def __call__(self, task_plan: str, screen_info):
def __init__(self):
self.OUTPUT_DIR = "./tmp/outputs"
device = self.get_device()
def __call__(self, task_plan, parsed_screen):
self.SYSTEM_PROMPT = system_prompt.format(task_plan=task_plan,
device=device,
screen_info=screen_info)
print(self.SYSTEM_PROMPT)
device=self.get_device(),
screen_info=parsed_screen["parsed_content_list"])
screen_width, screen_height = parsed_screen['width'], parsed_screen['height']
vlm_response = run([{"role": "user", "content": "next"}], user_prompt=self.SYSTEM_PROMPT, response_format=TaskRunAgentResponse)
vlm_response_json = json.loads(vlm_response)
if "box_id" in vlm_response_json:
try:
bbox = parsed_screen["parsed_content_list"][int(vlm_response_json["box_id"])]["bbox"]
vlm_response_json["box_centroid_coordinate"] = [int((bbox[0] + bbox[2]) / 2 * screen_width), int((bbox[1] + bbox[3]) / 2 * screen_height)]
# NOTE: img_to_show_base64 is not defined before this first use; the
# screenshot base64 presumably comes from parsed_screen (assumption).
img_to_show_data = base64.b64decode(img_to_show_base64)
img_to_show = Image.open(BytesIO(img_to_show_data))
draw = ImageDraw.Draw(img_to_show)
x, y = vlm_response_json["box_centroid_coordinate"]
radius = 10
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill='red')
draw.ellipse((x - radius*3, y - radius*3, x + radius*3, y + radius*3), fill=None, outline='red', width=2)
buffered = BytesIO()
img_to_show.save(buffered, format="PNG")
img_to_show_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
except Exception:
print(f"Error parsing: {vlm_response_json}")
response_content = [BetaTextBlock(text=vlm_response_json["reasoning"], type='text')]
if 'box_centroid_coordinate' in vlm_response_json:
move_cursor_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
input={'action': 'mouse_move', 'coordinate': vlm_response_json["box_centroid_coordinate"]},
name='computer', type='tool_use')
response_content.append(move_cursor_block)
if vlm_response_json["next_action"] == "None":
print("Task paused/completed.")
elif vlm_response_json["next_action"] == "type":
sim_content_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
input={'action': vlm_response_json["next_action"], 'text': vlm_response_json["value"]},
name='computer', type='tool_use')
response_content.append(sim_content_block)
else:
sim_content_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
input={'action': vlm_response_json["next_action"]},
name='computer', type='tool_use')
response_content.append(sim_content_block)
response_message = BetaMessage(id=f'toolu_{uuid.uuid4()}', content=response_content, model='', role='assistant', type='message', stop_reason='tool_use', usage=BetaUsage(input_tokens=0, output_tokens=0))
return response_message, vlm_response_json
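The agent repackages the VLM's JSON into Anthropic beta types so the rest of the pipeline can treat it like a native tool-use message. A minimal sketch of that assembly, mirroring the constructor calls above (synthetic values; requires the `anthropic` package):

```python
import uuid
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaUsage

content = [
    BetaTextBlock(text="Click the first search result.", type="text"),
    BetaToolUseBlock(id=f"toolu_{uuid.uuid4()}",
                     input={"action": "mouse_move", "coordinate": [384, 243]},
                     name="computer", type="tool_use"),
]
message = BetaMessage(id=f"toolu_{uuid.uuid4()}", content=content, model="",
                      role="assistant", type="message", stop_reason="tool_use",
                      usage=BetaUsage(input_tokens=0, output_tokens=0))
print(message.stop_reason, len(message.content))  # tool_use 2
```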
def get_device(self):
# Get information about the current operating system
@@ -27,10 +74,6 @@ class TaskRunAgent(BaseAgent):
device = system
return device
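Most of `get_device` is elided by the hunk; the visible tail shows it falling back to the raw `platform.system()` value. A hypothetical sketch of that shape (the mapping is an assumption, not the commit's actual body):

```python
import platform

def get_device() -> str:
    # Map platform.system() to a reader-friendly OS name for the prompt;
    # fall back to the raw value, as the `device = system` line above does.
    system = platform.system()  # 'Windows', 'Darwin', or 'Linux'
    return {"Darwin": "Mac"}.get(system, system)
```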
def __call__(self, task):
res = run([{"role": "user", "content": task}], user_prompt=self.SYSTEM_PROMPT, response_format=TaskRunAgentResponse)
response_message = BetaMessage(id=f'toolu_{uuid.uuid4()}', content=res, model='', role='assistant', type='message', stop_reason='tool_use', usage=BetaUsage(input_tokens=0, output_tokens=0))
return response_message
def extract_data(self, input_string, data_type):
# Regular expression to extract content starting from '```python' until the end if there are no closing backticks
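Only the signature and comment of `extract_data` are visible here. A sketch of what such a helper might look like (the regex and return behavior are assumptions):

```python
import re

def extract_data(input_string: str, data_type: str) -> str:
    # Match a fenced block for the given language tag, tolerating a missing
    # closing fence by also accepting end-of-string, per the comment above.
    pattern = rf"```{data_type}\s*(.*?)(?:```|$)"
    match = re.search(pattern, input_string, re.DOTALL)
    return match.group(1).strip() if match else input_string.strip()
```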
@@ -84,12 +127,12 @@ system_prompt = """
##########
### Output Format ###
```json
{
{{
"Reasoning": str, # 描述当前屏幕上的内容,考虑历史记录,然后描述您如何实现任务的逐步思考,一次从可用操作中选择一个操作。
"Next Action": "action_type, action description" | "None" # 一次一个操作,简短精确地描述它。
"Box ID": n,
"value": "xxx" # 仅当操作为type时提供value字段否则不包括value键
}
}}
```
next_action must be one of the following:
@@ -106,28 +149,28 @@ system_prompt = """
### Examples ###
One example:
```json
{
"Reasoning": "当前屏幕显示亚马逊的谷歌搜索结果在之前的操作中我已经在谷歌上搜索了亚马逊。然后我需要点击第一个搜索结果以转到amazon.com。",
"Next Action": "left_click",
"Box ID": m
}
{{
"reasoning": "当前屏幕显示亚马逊的谷歌搜索结果在之前的操作中我已经在谷歌上搜索了亚马逊。然后我需要点击第一个搜索结果以转到amazon.com。",
"next_action": "left_click",
"box_id": m
}}
```
Another example:
```json
{
"Reasoning": "当前屏幕显示亚马逊的首页。没有之前的操作。因此,我需要在搜索栏中输入"Apple watch"",
"Next Action": "type",
"Box ID": n,
{{
"reasoning": "当前屏幕显示亚马逊的首页。没有之前的操作。因此,我需要在搜索栏中输入"Apple watch"",
"next_action": "type",
"box_id": n,
"value": "Apple watch"
}
}}
```
Another example:
```json
{
"Reasoning": "当前屏幕没有显示'提交'按钮,我需要向下滚动以查看按钮是否可用。",
"Next Action": "scroll_down"
}
{{
"reasoning": "当前屏幕没有显示'提交'按钮,我需要向下滚动以查看按钮是否可用。",
"next_action": "scroll_down"
}}
"""

View File

@@ -67,37 +67,6 @@ class VisionAgent:
print("图像描述模型加载成功")
except Exception as e:
print(f"加载图像描述模型失败: {e}")
print("尝试使用备用方法加载...")
# 备用加载方法
try:
# Load on CPU first, then move to the target device
self.caption_model = AutoModelForCausalLM.from_pretrained(
caption_model_path,
torch_dtype=torch.float32,
trust_remote_code=True
)
# On CUDA devices, try converting to float16
if self.device.type == 'cuda':
try:
self.caption_model = self.caption_model.to(dtype=torch.float16)
except Exception:
print("float16 conversion failed; staying on float32")
# Move to the target device
self.caption_model = self.caption_model.to(self.device)
print("使用备用方法加载成功")
except Exception as e2:
print(f"备用加载方法也失败: {e2}")
print("回退到CPU模式")
self.device = torch.device("cpu")
self.dtype = torch.float32
self.caption_model = AutoModelForCausalLM.from_pretrained(
caption_model_path,
torch_dtype=torch.float32,
trust_remote_code=True
).to(self.device)
# Set the caption prompt
self.prompt = "<CAPTION>"
@@ -113,6 +82,15 @@ class VisionAgent:
self.elements: List[UIElement] = []
self.ocr_reader = easyocr.Reader(['en', 'ch_sim'])
def __call__(self, image_path: str) -> List[UIElement]:
"""Process an image from file path."""
# image = self.load_image(image_source)
image = cv2.imread(image_path)
if image is None:
raise FileNotFoundError(f"Vision agent: failed to read image: {image_path}")
return self.analyze_image(image)
def _get_optimal_device_and_dtype(self):
"""确定最佳设备和数据类型"""
if torch.cuda.is_available():
@@ -261,55 +239,7 @@ class VisionAgent:
torch.cuda.empty_cache()
except RuntimeError as e:
print(f"批次处理失败: {e}")
# 如果是CUDA错误尝试在CPU上处理
if "CUDA" in str(e) or "cuda" in str(e):
print("尝试在CPU上处理此批次...")
try:
# 临时将模型移至CPU
self.caption_model = self.caption_model.to("cpu")
# 在CPU上处理
inputs = self.caption_processor(
images=batch,
text=[self.prompt] * len(batch),
return_tensors="pt"
)
with torch.no_grad():
if 'florence' in self.caption_model.config.model_type:
generated_ids = self.caption_model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=20,
num_beams=1,
do_sample=False
)
else:
generated_ids = self.caption_model.generate(
**inputs,
max_length=50,
num_beams=3,
early_stopping=True
)
texts = self.caption_processor.batch_decode(
generated_ids,
skip_special_tokens=True
)
texts = [text.strip() for text in texts]
generated_texts.extend(texts)
# Move the model back to its original device when done
self.caption_model = self.caption_model.to(self.device)
except Exception as cpu_e:
print(f"CPU处理也失败: {cpu_e}")
generated_texts.extend(["[描述生成失败]"] * len(batch))
else:
# 非CUDA错误直接添加占位符
generated_texts.extend(["[描述生成失败]"] * len(batch))
print(f"批次处理失败: {e}")
return generated_texts
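The deleted branch above implements a general CUDA-to-CPU retry: on a CUDA runtime error, move the model to CPU, rerun the batch, then restore the original device. A compact sketch of that pattern, with a generic torch module standing in for the caption model:

```python
import torch

def run_batch_with_cpu_fallback(model: torch.nn.Module, batch: torch.Tensor, device: str):
    try:
        model.to(device)
        return model(batch.to(device))
    except RuntimeError as e:
        if "CUDA" not in str(e) and "cuda" not in str(e):
            raise  # non-CUDA errors are not retried
        # Temporarily fall back to CPU, then restore the original device.
        model.to("cpu")
        try:
            return model(batch.to("cpu"))
        finally:
            model.to(device)

model = torch.nn.Linear(4, 2)
print(run_batch_with_cpu_fallback(model, torch.randn(8, 4), device="cpu").shape)
```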
def _detect_objects(self, image: np.ndarray) -> tuple[list[np.ndarray], list]:
@@ -378,14 +308,6 @@ class VisionAgent:
raise ValueError(error_msg) from e
def __call__(self, image_path: str) -> List[UIElement]:
"""Process an image from file path."""
# image = self.load_image(image_source)
image = cv2.imread(image_path)
if image is None:
raise FileNotFoundError(f"Vision agent: failed to read image: {image_path}")
return self.analyze_image(image)
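With `__call__` moved to the top of the class (and the duplicate definition at the old location removed), `VisionAgent` is invoked directly on a screenshot path. A usage sketch under assumed names (the import path is hypothetical):

```python
from gradio_ui.agent.vision_agent import VisionAgent  # hypothetical module path

agent = VisionAgent()
elements = agent("./tmp/outputs/screenshot.png")  # raises FileNotFoundError if unreadable
for element in elements:
    print(element)  # UIElement entries feed parsed_content_list downstream
```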

View File

@@ -27,44 +27,42 @@ class AnthropicExecutor:
self.output_callback = output_callback
self.tool_output_callback = tool_output_callback
def __call__(self, response: BetaMessage, messages: list[BetaMessageParam]):
def __call__(self, response, messages: list[BetaMessageParam]):
new_message = {
"role": "assistant",
"content": cast(list[BetaContentBlockParam], response.content),
"content": cast(list[BetaContentBlockParam], response),
}
if new_message not in messages:
messages.append(new_message)
else:
print("new_message already in messages, there are duplicates.")
tool_result_content: list[BetaToolResultBlockParam] = []
for content_block in cast(list[BetaContentBlock], response.content):
self.output_callback(content_block, sender="bot")
# Execute the tool
if content_block.type == "tool_use":
# Run the asynchronous tool execution in a synchronous context
result = asyncio.run(self.tool_collection.run(
name=content_block.name,
tool_input=cast(dict[str, Any], content_block.input),
))
self.output_callback(result, sender="bot")
tool_result_content.append(
_make_api_tool_result(result, content_block.id)
)
self.tool_output_callback(result, content_block.id)
self.output_callback(response["action_type"], sender="bot")
# Execute the tool
if response["next_action"] != None:
# Run the asynchronous tool execution in a synchronous context
result = asyncio.run(self.tool_collection.run(
name=response["action_type"],
tool_input=cast(dict[str, Any], content_block.input),  # NOTE: content_block is no longer defined in this scope after the refactor
))
self.output_callback(result, sender="bot")
tool_result_content.append(
_make_api_tool_result(result, content_block.id)
)
self.tool_output_callback(result, content_block.id)
# Craft messages based on the content_block
# Note: to display the messages in the gradio, you should organize the messages in the following way (user message, bot message)
display_messages = _message_display_callback(messages)
# display_messages = []
# Send the messages to the gradio
for user_msg, bot_msg in display_messages:
# yield [user_msg, bot_msg], tool_result_content
yield [None, None], tool_result_content
# Craft messages based on the content_block
# Note: to display the messages in the gradio, you should organize the messages in the following way (user message, bot message)
display_messages = _message_display_callback(messages)
# display_messages = []
# Send the messages to the gradio
for user_msg, bot_msg in display_messages:
# yield [user_msg, bot_msg], tool_result_content
yield [None, None], tool_result_content
if not tool_result_content:
return messages
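The executor still bridges async tool execution into the synchronous generator that the Gradio loop consumes, via `asyncio.run`. A self-contained sketch of that pattern (the tool body is a stand-in for `tool_collection.run`):

```python
import asyncio
from typing import Any

async def run_tool(name: str, tool_input: dict[str, Any]) -> str:
    # Stand-in for the async tool_collection.run(...) call.
    await asyncio.sleep(0)
    return f"ran {name} with {tool_input}"

def executor_step(action: str, tool_input: dict[str, Any]):
    # Bridge the async tool into a synchronous generator, as __call__ does.
    result = asyncio.run(run_tool(action, tool_input))
    yield result

for output in executor_step("mouse_move", {"coordinate": [120, 240]}):
    print(output)
```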

View File

@@ -61,14 +61,16 @@ def sampling_loop_sync(
while True:
parsed_screen = parse_screen(vision_agent)
# tools_use_needed, vlm_response_json = actor(messages=messages, parsed_screen=parsed_screen)
tools_use_needed = task_run_agent(task_plan=plan, screen_info=parsed_screen)
tools_use_needed, vlm_response_json = task_run_agent(task_plan=plan, screen_info=parsed_screen)
for message, tool_result_content in executor(tools_use_needed, messages):
yield message
if not tool_result_content:
return messages
def parse_screen(vision_agent: VisionAgent):
_, screenshot_path = get_screenshot()
parsed_screen = vision_agent(str(screenshot_path))
return parsed_screen
screenshot, screenshot_path = get_screenshot()
response_json = {}
response_json['parsed_content_list'] = vision_agent(str(screenshot_path))
response_json['width'] = screenshot.size[0]
response_json['height'] = screenshot.size[1]
return response_json
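`parse_screen` now returns a dict bundling the parsed elements with the screenshot's pixel dimensions, which is exactly what `TaskRunAgent.__call__` needs to scale normalized bboxes. A sketch of that contract with a stub replacing `get_screenshot` and the vision agent (key names are taken from the diff):

```python
from PIL import Image

def parse_screen_stub():
    # Stand-in for get_screenshot() + vision_agent(str(screenshot_path)).
    screenshot = Image.new("RGB", (1920, 1080))
    return {
        "parsed_content_list": [
            {"bbox": [0.10, 0.20, 0.30, 0.25]},  # normalized [x1, y1, x2, y2]
        ],
        "width": screenshot.size[0],
        "height": screenshot.size[1],
    }

parsed = parse_screen_stub()
bbox = parsed["parsed_content_list"][0]["bbox"]
print(int((bbox[0] + bbox[2]) / 2 * parsed["width"]),   # 384
      int((bbox[1] + bbox[3]) / 2 * parsed["height"]))  # 243
```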