diff --git a/gradio_ui/agent/task_run_agent.py b/gradio_ui/agent/task_run_agent.py
index d86fdc8..b3856b4 100644
--- a/gradio_ui/agent/task_run_agent.py
+++ b/gradio_ui/agent/task_run_agent.py
@@ -1,18 +1,65 @@
+import json
 import uuid
-from anthropic.types.beta import BetaMessage, BetaUsage
+from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam, BetaUsage
+from PIL import Image, ImageDraw
+import base64
+from io import BytesIO
 from pydantic import BaseModel, Field
 from gradio_ui.agent.base_agent import BaseAgent
 from xbrain.core.chat import run
 import platform
 import re
 
 class TaskRunAgent(BaseAgent):
-    def __call__(self, task_plan: str, screen_info):
+    def __init__(self):
         self.OUTPUT_DIR = "./tmp/outputs"
-        device = self.get_device()
+
+    def __call__(self, task_plan, parsed_screen):
         self.SYSTEM_PROMPT = system_prompt.format(task_plan=task_plan,
-                                                  device=device,
-                                                  screen_info=screen_info)
-        print(self.SYSTEM_PROMPT)
+                                                  device=self.get_device(),
+                                                  screen_info=parsed_screen["parsed_content_list"])
+        screen_width, screen_height = parsed_screen['width'], parsed_screen['height']
+        vlm_response = run([{"role": "user", "content": "next"}], user_prompt=self.SYSTEM_PROMPT, response_format=TaskRunAgentResponse)
+        vlm_response_json = json.loads(vlm_response)
+        if "box_id" in vlm_response_json:
+            try:
+                bbox = parsed_screen["parsed_content_list"][int(vlm_response_json["box_id"])]["bbox"]
+                vlm_response_json["box_centroid_coordinate"] = [int((bbox[0] + bbox[2]) / 2 * screen_width), int((bbox[1] + bbox[3]) / 2 * screen_height)]
+                # decode the current screenshot (expected as a base64 string) and mark the click target on it
+                img_to_show_data = base64.b64decode(img_to_show_base64)
+                img_to_show = Image.open(BytesIO(img_to_show_data))
+                draw = ImageDraw.Draw(img_to_show)
+                x, y = vlm_response_json["box_centroid_coordinate"]
+                radius = 10
+                draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill='red')
+                draw.ellipse((x - radius*3, y - radius*3, x + radius*3, y + radius*3), fill=None, outline='red', width=2)
+                buffered = BytesIO()
+                img_to_show.save(buffered, format="PNG")
+                img_to_show_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
+            except Exception as e:
+                print(f"Error parsing {vlm_response_json}: {e}")
+        response_content = [BetaTextBlock(text=vlm_response_json["reasoning"], type='text')]
+        if 'box_centroid_coordinate' in vlm_response_json:
+            move_cursor_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
+                                                 input={'action': 'mouse_move', 'coordinate': vlm_response_json["box_centroid_coordinate"]},
+                                                 name='computer', type='tool_use')
+            response_content.append(move_cursor_block)
+
+        if vlm_response_json["next_action"] == "None":
+            print("Task paused/completed.")
+        elif vlm_response_json["next_action"] == "type":
+            sim_content_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
+                                                 input={'action': vlm_response_json["next_action"], 'text': vlm_response_json["value"]},
+                                                 name='computer', type='tool_use')
+            response_content.append(sim_content_block)
+        else:
+            sim_content_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
+                                                 input={'action': vlm_response_json["next_action"]},
+                                                 name='computer', type='tool_use')
+            response_content.append(sim_content_block)
+        response_message = BetaMessage(id=f'toolu_{uuid.uuid4()}', content=response_content, model='', role='assistant', type='message', stop_reason='tool_use', usage=BetaUsage(input_tokens=0, output_tokens=0))
+        return response_message, vlm_response_json
 
     def get_device(self):
         # Get information about the current operating system
@@ -27,10 +74,6 @@ class TaskRunAgent(BaseAgent):
             device = system
         return device
 
-    def __call__(self, task):
-        res = run([{"role": "user", "content": task}], user_prompt=self.SYSTEM_PROMPT, response_format=TaskRunAgentResponse)
-        response_message = BetaMessage(id=f'toolu_{uuid.uuid4()}', content=res, model='', role='assistant', type='message', stop_reason='tool_use', usage=BetaUsage(input_tokens=0, output_tokens=0))
-        return response_message
 
     def extract_data(self, input_string, data_type):
         # Regular expression to extract content starting from '```python' until the end if there are no closing backticks
@@ -84,12 +127,12 @@ system_prompt = """
 ##########
 ### Output format ###
 ```json
-{
-    "Reasoning": str, # Describe what is on the current screen, take the history into account, then describe your step-by-step thinking for completing the task, choosing one action at a time from the available actions.
-    "Next Action": "action_type, action description" | "None" # One action at a time; describe it briefly and precisely.
-    "Box ID": n,
+{{
+    "reasoning": str, # Describe what is on the current screen, take the history into account, then describe your step-by-step thinking for completing the task, choosing one action at a time from the available actions.
+    "next_action": "action_type, action description" | "None" # One action at a time; describe it briefly and precisely.
+    "box_id": n,
     "value": "xxx" # Provide the value field only when the action is type; otherwise omit the value key.
-}
+}}
 ```
 
 [Next Action] must be exactly one of the following:
@@ -106,28 +149,28 @@ system_prompt = """
 ### Examples ###
 One example:
 ```json
-{
-    "Reasoning": "The current screen shows the Google search results for Amazon. In a previous action I already searched Google for Amazon. I now need to click the first search result to go to amazon.com.",
-    "Next Action": "left_click",
-    "Box ID": m
-}
+{{
+    "reasoning": "The current screen shows the Google search results for Amazon. In a previous action I already searched Google for Amazon. I now need to click the first search result to go to amazon.com.",
+    "next_action": "left_click",
+    "box_id": m
+}}
 ```
 
 Another example:
 ```json
-{
-    "Reasoning": "The current screen shows the Amazon homepage. There are no previous actions, so I need to type 'Apple watch' into the search bar.",
-    "Next Action": "type",
-    "Box ID": n,
+{{
+    "reasoning": "The current screen shows the Amazon homepage. There are no previous actions, so I need to type 'Apple watch' into the search bar.",
+    "next_action": "type",
+    "box_id": n,
     "value": "Apple watch"
-}
+}}
 ```
 
 Another example:
 ```json
-{
-    "Reasoning": "The current screen does not show a 'Submit' button. I need to scroll down to check whether the button is available.",
-    "Next Action": "scroll_down"
-}
+{{
+    "reasoning": "The current screen does not show a 'Submit' button. I need to scroll down to check whether the button is available.",
+    "next_action": "scroll_down"
+}}
+```
 """
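# ---------------------------------------------------------------------------
# A standalone sketch (not part of the patch) of the centroid math used in
# TaskRunAgent.__call__ above, assuming OmniParser-style normalized
# [x1, y1, x2, y2] boxes; the helper name is illustrative. Note also that the
# annotation block reads `img_to_show_base64` before it is ever assigned, so
# the screenshot would need to be supplied as a base64 string by the caller.

def bbox_centroid(bbox: list, screen_width: int, screen_height: int) -> list:
    """Return the absolute pixel centre of a normalized [x1, y1, x2, y2] box."""
    return [
        int((bbox[0] + bbox[2]) / 2 * screen_width),
        int((bbox[1] + bbox[3]) / 2 * screen_height),
    ]

# e.g. bbox_centroid([0.25, 0.10, 0.35, 0.14], 1920, 1080) == [576, 129]
# ---------------------------------------------------------------------------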
diff --git a/gradio_ui/agent/vision_agent.py b/gradio_ui/agent/vision_agent.py
index ca44088..e42f3f6 100644
--- a/gradio_ui/agent/vision_agent.py
+++ b/gradio_ui/agent/vision_agent.py
@@ -67,37 +67,6 @@ class VisionAgent:
             print("Image caption model loaded successfully")
         except Exception as e:
             print(f"Failed to load image caption model: {e}")
-            print("Trying fallback loading method...")
-
-            # Fallback loading method
-            try:
-                # Load on the CPU first, then move to the target device
-                self.caption_model = AutoModelForCausalLM.from_pretrained(
-                    caption_model_path,
-                    torch_dtype=torch.float32,
-                    trust_remote_code=True
-                )
-
-                # On CUDA devices, try converting to float16
-                if self.device.type == 'cuda':
-                    try:
-                        self.caption_model = self.caption_model.to(dtype=torch.float16)
-                    except:
-                        print("Conversion to float16 failed; using float32")
-
-                # Move to the target device
-                self.caption_model = self.caption_model.to(self.device)
-                print("Fallback loading method succeeded")
-            except Exception as e2:
-                print(f"Fallback loading method also failed: {e2}")
-                print("Falling back to CPU mode")
-                self.device = torch.device("cpu")
-                self.dtype = torch.float32
-                self.caption_model = AutoModelForCausalLM.from_pretrained(
-                    caption_model_path,
-                    torch_dtype=torch.float32,
-                    trust_remote_code=True
-                ).to(self.device)
 
         # Set the prompt
         self.prompt = ""
@@ -113,6 +82,15 @@ class VisionAgent:
         self.elements: List[UIElement] = []
         self.ocr_reader = easyocr.Reader(['en', 'ch_sim'])
 
+    def __call__(self, image_path: str) -> List[UIElement]:
+        """Process an image from a file path."""
+        image = cv2.imread(image_path)
+        if image is None:
+            raise FileNotFoundError(f"Vision agent: failed to read image: {image_path}")
+
+        return self.analyze_image(image)
+
     def _get_optimal_device_and_dtype(self):
         """Determine the optimal device and dtype."""
         if torch.cuda.is_available():
@@ -261,55 +239,7 @@ class VisionAgent:
                     torch.cuda.empty_cache()
 
             except RuntimeError as e:
-                print(f"Batch processing failed: {e}")
-
-                # If it is a CUDA error, try processing the batch on the CPU
-                if "CUDA" in str(e) or "cuda" in str(e):
-                    print("Trying to process this batch on the CPU...")
-                    try:
-                        # Temporarily move the model to the CPU
-                        self.caption_model = self.caption_model.to("cpu")
-
-                        # Process on the CPU
-                        inputs = self.caption_processor(
-                            images=batch,
-                            text=[self.prompt] * len(batch),
-                            return_tensors="pt"
-                        )
-
-                        with torch.no_grad():
-                            if 'florence' in self.caption_model.config.model_type:
-                                generated_ids = self.caption_model.generate(
-                                    input_ids=inputs["input_ids"],
-                                    pixel_values=inputs["pixel_values"],
-                                    max_new_tokens=20,
-                                    num_beams=1,
-                                    do_sample=False
-                                )
-                            else:
-                                generated_ids = self.caption_model.generate(
-                                    **inputs,
-                                    max_length=50,
-                                    num_beams=3,
-                                    early_stopping=True
-                                )
-
-                        texts = self.caption_processor.batch_decode(
-                            generated_ids,
-                            skip_special_tokens=True
-                        )
-                        texts = [text.strip() for text in texts]
-                        generated_texts.extend(texts)
-
-                        # Move the model back to the original device afterwards
-                        self.caption_model = self.caption_model.to(self.device)
-                    except Exception as cpu_e:
-                        print(f"CPU processing also failed: {cpu_e}")
-                        generated_texts.extend(["[caption generation failed]"] * len(batch))
-                else:
-                    # Not a CUDA error; just add placeholders
-                    generated_texts.extend(["[caption generation failed]"] * len(batch))
-
+                print(f"Batch processing failed: {e}")
         return generated_texts
 
     def _detect_objects(self, image: np.ndarray) -> tuple[list[np.ndarray], list]:
@@ -378,14 +308,6 @@ class VisionAgent:
 
             raise ValueError(error_msg) from e
 
-    def __call__(self, image_path: str) -> List[UIElement]:
-        """Process an image from a file path."""
-        # image = self.load_image(image_source)
-        image = cv2.imread(image_path)
-        if image is None:
-            raise FileNotFoundError("Vision agent: failed to read image")
-
-        return self.analyze_image(image)
\ No newline at end of file
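# ---------------------------------------------------------------------------
# Usage sketch (not part of the patch) for the relocated VisionAgent.__call__,
# assuming a no-argument constructor and a readable screenshot on disk; the
# path is illustrative.
from gradio_ui.agent.vision_agent import VisionAgent

vision_agent = VisionAgent()
elements = vision_agent("./tmp/outputs/screenshot.png")  # List[UIElement]; raises FileNotFoundError if unreadable
for element in elements:
    print(element)  # one detected UI element (bbox plus OCR text or icon caption)
# ---------------------------------------------------------------------------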
diff --git a/gradio_ui/executor/anthropic_executor.py b/gradio_ui/executor/anthropic_executor.py
index 7c92deb..05bf7e7 100644
--- a/gradio_ui/executor/anthropic_executor.py
+++ b/gradio_ui/executor/anthropic_executor.py
@@ -27,44 +27,42 @@ class AnthropicExecutor:
         self.output_callback = output_callback
         self.tool_output_callback = tool_output_callback
 
-    def __call__(self, response: BetaMessage, messages: list[BetaMessageParam]):
+    def __call__(self, response, messages: list[BetaMessageParam]):
         new_message = {
             "role": "assistant",
-            "content": cast(list[BetaContentBlockParam], response.content),
+            "content": cast(list[BetaContentBlockParam], response),
         }
         if new_message not in messages:
             messages.append(new_message)
         else:
             print("new_message already in messages, there are duplicates.")
 
         tool_result_content: list[BetaToolResultBlockParam] = []
-        for content_block in cast(list[BetaContentBlock], response.content):
-            self.output_callback(content_block, sender="bot")
-            # Execute the tool
-            if content_block.type == "tool_use":
-                # Run the asynchronous tool execution in a synchronous context
-                result = asyncio.run(self.tool_collection.run(
-                    name=content_block.name,
-                    tool_input=cast(dict[str, Any], content_block.input),
-                ))
-
-                self.output_callback(result, sender="bot")
-
-                tool_result_content.append(
-                    _make_api_tool_result(result, content_block.id)
-                )
-                self.tool_output_callback(result, content_block.id)
+        self.output_callback(response["action_type"], sender="bot")
+        # Execute the tool
+        if response["next_action"] != "None":
+            # Run the asynchronous tool execution in a synchronous context
+            result = asyncio.run(self.tool_collection.run(
+                name=response["action_type"],
+                tool_input=cast(dict[str, Any], content_block.input),
+            ))
+
+            self.output_callback(result, sender="bot")
+
+            tool_result_content.append(
+                _make_api_tool_result(result, content_block.id)
+            )
+            self.tool_output_callback(result, content_block.id)
 
         # Craft messages based on the content_block
         # Note: to display the messages in gradio, you should organize the messages as (user message, bot message) pairs
 
         display_messages = _message_display_callback(messages)
         # display_messages = []
 
         # Send the messages to gradio
         for user_msg, bot_msg in display_messages:
             # yield [user_msg, bot_msg], tool_result_content
             yield [None, None], tool_result_content
 
         if not tool_result_content:
             return messages
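# ---------------------------------------------------------------------------
# Sketch (not part of the patch): the rewritten __call__ above still refers to
# content_block.input, content_block.id and response["action_type"], none of
# which exist in its new scope. Since TaskRunAgent now returns a BetaMessage
# whose tool_use blocks already carry the action payload, one self-consistent
# variant keeps iterating response.content, as below; `executor` stands in for
# the AnthropicExecutor instance and the helper name is illustrative.
import asyncio
from typing import Any, cast

from anthropic.types.beta import BetaMessage

def run_tool_blocks(executor, response: BetaMessage) -> list:
    """Execute every tool_use block carried by the TaskRunAgent message."""
    tool_results = []
    for content_block in response.content:
        executor.output_callback(content_block, sender="bot")
        if content_block.type == "tool_use":
            # bridge the async tool call into this synchronous flow
            result = asyncio.run(executor.tool_collection.run(
                name=content_block.name,
                tool_input=cast(dict[str, Any], content_block.input),
            ))
            executor.output_callback(result, sender="bot")
            executor.tool_output_callback(result, content_block.id)
            tool_results.append(result)  # the real executor would wrap this with _make_api_tool_result
    return tool_results
# ---------------------------------------------------------------------------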
diff --git a/gradio_ui/loop.py b/gradio_ui/loop.py
index 2a70422..3232334 100644
--- a/gradio_ui/loop.py
+++ b/gradio_ui/loop.py
@@ -61,14 +61,16 @@ def sampling_loop_sync(
     while True:
         parsed_screen = parse_screen(vision_agent)
-        # tools_use_needed, vlm_response_json = actor(messages=messages, parsed_screen=parsed_screen)
-        tools_use_needed = task_run_agent(task_plan=plan, screen_info=parsed_screen)
+        tools_use_needed, vlm_response_json = task_run_agent(task_plan=plan, parsed_screen=parsed_screen)
         for message, tool_result_content in executor(tools_use_needed, messages):
             yield message
         if not tool_result_content:
             return messages
 
 def parse_screen(vision_agent: VisionAgent):
-    _, screenshot_path = get_screenshot()
-    parsed_screen = vision_agent(str(screenshot_path))
-    return parsed_screen
+    screenshot, screenshot_path = get_screenshot()
+    response_json = {}
+    response_json['parsed_content_list'] = vision_agent(str(screenshot_path))
+    response_json['width'] = screenshot.size[0]
+    response_json['height'] = screenshot.size[1]
+    return response_json
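# ---------------------------------------------------------------------------
# Shape sketch (not part of the patch) of the dict parse_screen() now returns
# and TaskRunAgent consumes; the element values are made up.
parsed_screen = {
    "parsed_content_list": [
        {"bbox": [0.25, 0.10, 0.35, 0.14], "content": "Search"},  # normalized x1, y1, x2, y2
    ],
    "width": 1920,   # screenshot.size[0]
    "height": 1080,  # screenshot.size[1]
}
# ---------------------------------------------------------------------------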