mirror of
https://github.com/yuruotong1/autoMate.git
synced 2026-03-22 13:07:17 +08:00
Merge branch 'dev-test'
This commit is contained in:
13
README.md
13
README.md
@@ -77,10 +77,15 @@ For supported vendors and models, please refer to this [link](./SUPPORT_MODEL.md
|
||||
### 🔧 CUDA Version Mismatch
|
||||
If you see the error: "GPU driver incompatible, please install appropriate torch version according to readme", it indicates a driver incompatibility. You can either:
|
||||
|
||||
1. Run with CPU only (slower but functional)
|
||||
2. Check your torch version with `pip list`
|
||||
3. Check supported CUDA versions on the [official website](https://pytorch.org/get-started/locally/)
|
||||
4. Reinstall Nvidia drivers
|
||||
1. Run pip list to check the torch version;
|
||||
2. Check supported CUDA versions on the official website;
|
||||
3. Copy the official torch installation command and reinstall torch for your CUDA version.
|
||||
|
||||
For example, if your CUDA version is 12.4, install torch using this command:
|
||||
|
||||
```bash
|
||||
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
|
||||
```
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
|
||||
14
README_CN.md
14
README_CN.md
@@ -78,8 +78,20 @@ python main.py
|
||||
|
||||
1. 运行`pip list`查看torch版本;
|
||||
2. 从[官网](https://pytorch.org/get-started/locally/)查看支持的cuda版本;
|
||||
3. 重新安装Nvidia驱动。
|
||||
3. 卸载已安装的 torch 和 torchvision;
|
||||
3. 复制官方的 torch 安装命令,重新安装适合自己 cuda 版本的 torch。
|
||||
|
||||
比如我的 cuda 版本为 12.4,需要按照如下命令来安装 torch;
|
||||
|
||||
```bash
|
||||
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
|
||||
```
|
||||
|
||||
### 模型无法下载
|
||||
多半是被墙了,可以从百度网盘直接下载模型。
|
||||
|
||||
通过网盘分享的文件:weights.zip
|
||||
链接: https://pan.baidu.com/s/1Tj8sZZK9_QI7whZV93vb0w?pwd=dyeu 提取码: dyeu
|
||||
|
||||
## 🤝 参与共建
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
| Vendor-en | Vendor-ch | Model | base-url |
|
||||
| --- | --- | --- | --- |
|
||||
| Alibaba Cloud Bailian | 阿里云百炼 | deepseek-r1 | https://dashscope.aliyuncs.com/compatible-mode/v1 |
|
||||
| Alibaba Cloud Bailian | 阿里云百炼 | deepseek-v3 | https://dashscope.aliyuncs.com/compatible-mode/v1 |
|
||||
| deepseek | deepseek官方 | deepseek-chat | https://api.deepseek.com |
|
||||
| openainext | openainext | gpt-4o-2024-11-20 | https://api.openai-next.com/v1 |
|
||||
|
||||
8
gradio_ui/agent/base_agent.py
Normal file
8
gradio_ui/agent/base_agent.py
Normal file
@@ -0,0 +1,8 @@
|
||||
class BaseAgent:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.SYSTEM_PROMPT = ""
|
||||
|
||||
|
||||
def chat(self, messages):
|
||||
pass
|
||||
|
||||
@@ -19,26 +19,31 @@ class OmniParserClient:
|
||||
response_json = response.json()
|
||||
print('omniparser latency:', response_json['latency'])
|
||||
|
||||
som_image_data = base64.b64decode(response_json['som_image_base64'])
|
||||
screenshot_path_uuid = Path(screenshot_path).stem.replace("screenshot_", "")
|
||||
som_screenshot_path = f"{OUTPUT_DIR}/screenshot_som_{screenshot_path_uuid}.png"
|
||||
with open(som_screenshot_path, "wb") as f:
|
||||
f.write(som_image_data)
|
||||
# som_image_data = base64.b64decode(response_json['som_image_base64'])
|
||||
# screenshot_path_uuid = Path(screenshot_path).stem.replace("screenshot_", "")
|
||||
# som_screenshot_path = f"{OUTPUT_DIR}/screenshot_som_{screenshot_path_uuid}.png"
|
||||
# with open(som_screenshot_path, "wb") as f:
|
||||
# f.write(som_image_data)
|
||||
|
||||
response_json['width'] = screenshot.size[0]
|
||||
response_json['height'] = screenshot.size[1]
|
||||
response_json['original_screenshot_base64'] = image_base64
|
||||
response_json['screenshot_uuid'] = screenshot_path_uuid
|
||||
# response_json['screenshot_uuid'] = screenshot_path_uuid
|
||||
response_json = self.reformat_messages(response_json)
|
||||
return response_json
|
||||
|
||||
def reformat_messages(self, response_json: dict):
|
||||
screen_info = ""
|
||||
for idx, element in enumerate(response_json["parsed_content_list"]):
|
||||
element['idx'] = idx
|
||||
if element['type'] == 'text':
|
||||
screen_info += f'ID: {idx}, Text: {element["content"]}\n'
|
||||
elif element['type'] == 'icon':
|
||||
screen_info += f'ID: {idx}, Icon: {element["content"]}\n'
|
||||
# element['idx'] = idx
|
||||
# if element['type'] == 'text':
|
||||
# screen_info += f'ID: {idx}, Text: {element["content"]}\n'
|
||||
# elif element['type'] == 'icon':
|
||||
# screen_info += f'ID: {idx}, Icon: {element["content"]}\n'
|
||||
screen_info += f'ID: {element.element_id}, '
|
||||
screen_info += f'Coordinates: {element.coordinates}, '
|
||||
screen_info += f'Text: {element.text if len(element.text) else " "}, '
|
||||
screen_info += f'Caption: {element.caption}. '
|
||||
screen_info += "\n"
|
||||
response_json['screen_info'] = screen_info
|
||||
return response_json
|
||||
40
gradio_ui/agent/task_plan_agent.py
Normal file
40
gradio_ui/agent/task_plan_agent.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from gradio_ui.agent.base_agent import BaseAgent
|
||||
from xbrain.core.chat import run
|
||||
|
||||
class TaskPlanAgent(BaseAgent):
|
||||
def __init__(self, output_callback):
|
||||
self.output_callback = output_callback
|
||||
|
||||
def __call__(self, user_task: str):
|
||||
self.output_callback("正在规划任务中...", sender="bot")
|
||||
response = run([{"role": "user", "content": user_task}], user_prompt=system_prompt)
|
||||
self.output_callback(response, sender="bot")
|
||||
return response
|
||||
|
||||
system_prompt = """
|
||||
### 目标 ###
|
||||
你是电脑任务规划专家,根据用户的需求,规划出要执行的任务。
|
||||
##########
|
||||
### 输入 ###
|
||||
用户的需求,通常是一个文本描述。
|
||||
##########
|
||||
### 输出 ###
|
||||
一系列任务,包括任务名称
|
||||
##########
|
||||
### 例子 ###
|
||||
(案例1)
|
||||
输入:获取AI新闻
|
||||
输出:
|
||||
1. 打开浏览器
|
||||
2. 打开百度首页
|
||||
3. 搜索“AI”相关内容
|
||||
4. 浏览搜索结果,记录搜索结果
|
||||
5. 返回搜索内容
|
||||
(案例2)
|
||||
输入:删除桌面的txt文件
|
||||
输出:
|
||||
1. 进入桌面
|
||||
2. 寻找所有txt文件
|
||||
3. 右键txt文件,选择删除
|
||||
"""
|
||||
|
||||
198
gradio_ui/agent/task_run_agent.py
Normal file
198
gradio_ui/agent/task_run_agent.py
Normal file
@@ -0,0 +1,198 @@
|
||||
import json
|
||||
import uuid
|
||||
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam, BetaUsage
|
||||
from PIL import ImageDraw
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from pydantic import BaseModel, Field
|
||||
from gradio_ui.agent.base_agent import BaseAgent
|
||||
from xbrain.core.chat import run
|
||||
import platform
|
||||
import re
|
||||
class TaskRunAgent(BaseAgent):
|
||||
def __init__(self, output_callback):
|
||||
self.output_callback = output_callback
|
||||
self.OUTPUT_DIR = "./tmp/outputs"
|
||||
|
||||
def __call__(self, task_plan, parsed_screen):
|
||||
screen_info = str(parsed_screen['parsed_content_list'])
|
||||
self.SYSTEM_PROMPT = system_prompt.format(task_plan=task_plan,
|
||||
device=self.get_device(),
|
||||
screen_info=screen_info)
|
||||
|
||||
screen_width, screen_height = parsed_screen['width'], parsed_screen['height']
|
||||
img_to_show = parsed_screen["image"]
|
||||
buffered = BytesIO()
|
||||
img_to_show.save(buffered, format="PNG")
|
||||
img_to_show_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
vlm_response = run([
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "图片是当前屏幕的截图,请根据图片以及解析出来的元素,确定下一步操作"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{img_to_show_base64}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
], user_prompt=self.SYSTEM_PROMPT, response_format=TaskRunAgentResponse)
|
||||
vlm_response_json = json.loads(vlm_response)
|
||||
if "box_id" in vlm_response_json:
|
||||
try:
|
||||
bbox = parsed_screen["parsed_content_list"][int(vlm_response_json["box_id"])].coordinates
|
||||
vlm_response_json["box_centroid_coordinate"] = [int((bbox[0] + bbox[2]) / 2 ), int((bbox[1] + bbox[3]) / 2 )]
|
||||
x, y = vlm_response_json["box_centroid_coordinate"]
|
||||
radius = 10
|
||||
draw = ImageDraw.Draw(img_to_show)
|
||||
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill='red')
|
||||
draw.ellipse((x - radius*3, y - radius*3, x + radius*3, y + radius*3), fill=None, outline='red', width=2)
|
||||
buffered = BytesIO()
|
||||
img_to_show.save(buffered, format="PNG")
|
||||
img_to_show_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
except Exception as e:
|
||||
print(f"Error parsing: {vlm_response_json}")
|
||||
print(f"Error: {e}")
|
||||
self.output_callback(f'<img src="data:image/png;base64,{img_to_show_base64}">', sender="bot")
|
||||
self.output_callback(
|
||||
f'<details>'
|
||||
f' <summary>Parsed Screen elemetns by OmniParser</summary>'
|
||||
f' <pre>{screen_info}</pre>'
|
||||
f'</details>',
|
||||
sender="bot"
|
||||
)
|
||||
response_content = [BetaTextBlock(text=vlm_response_json["reasoning"], type='text')]
|
||||
if 'box_centroid_coordinate' in vlm_response_json:
|
||||
move_cursor_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
|
||||
input={'action': 'mouse_move', 'coordinate': vlm_response_json["box_centroid_coordinate"]},
|
||||
name='computer', type='tool_use')
|
||||
response_content.append(move_cursor_block)
|
||||
|
||||
if vlm_response_json["next_action"] == "None":
|
||||
print("Task paused/completed.")
|
||||
elif vlm_response_json["next_action"] == "type":
|
||||
sim_content_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
|
||||
input={'action': vlm_response_json["next_action"], 'text': vlm_response_json["value"]},
|
||||
name='computer', type='tool_use')
|
||||
response_content.append(sim_content_block)
|
||||
else:
|
||||
sim_content_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
|
||||
input={'action': vlm_response_json["next_action"]},
|
||||
name='computer', type='tool_use')
|
||||
response_content.append(sim_content_block)
|
||||
response_message = BetaMessage(id=f'toolu_{uuid.uuid4()}', content=response_content, model='', role='assistant', type='message', stop_reason='tool_use', usage=BetaUsage(input_tokens=0, output_tokens=0))
|
||||
return response_message, vlm_response_json
|
||||
|
||||
|
||||
def get_device(self):
|
||||
# 获取当前操作系统信息
|
||||
system = platform.system()
|
||||
if system == "Windows":
|
||||
device = f"Windows {platform.release()}"
|
||||
elif system == "Darwin":
|
||||
device = f"Mac OS {platform.mac_ver()[0]}"
|
||||
elif system == "Linux":
|
||||
device = f"Linux {platform.release()}"
|
||||
else:
|
||||
device = system
|
||||
return device
|
||||
|
||||
|
||||
def extract_data(self, input_string, data_type):
|
||||
# Regular expression to extract content starting from '```python' until the end if there are no closing backticks
|
||||
pattern = f"```{data_type}" + r"(.*?)(```|$)"
|
||||
# Extract content
|
||||
# re.DOTALL allows '.' to match newlines as well
|
||||
matches = re.findall(pattern, input_string, re.DOTALL)
|
||||
# Return the first match if exists, trimming whitespace and ignoring potential closing backticks
|
||||
return matches[0][0].strip() if matches else input_string
|
||||
|
||||
class TaskRunAgentResponse(BaseModel):
|
||||
reasoning: str = Field(description="描述当前屏幕上的内容,考虑历史记录,然后描述您如何实现任务的逐步思考,一次从可用操作中选择一个操作。")
|
||||
next_action: str = Field(
|
||||
description="选择一个操作类型,如果找不到合适的操作,请选择None",
|
||||
json_schema_extra={
|
||||
"enum": ["type", "left_click", "right_click", "double_click",
|
||||
"hover", "scroll_up", "scroll_down", "wait", "None"]
|
||||
}
|
||||
)
|
||||
box_id: int = Field(description="要操作的框ID,当next_action为left_click、right_click、double_click、hover时提供,否则为None", default=None)
|
||||
value: str = Field(description="仅当next_action为type时提供,否则为None", default=None)
|
||||
|
||||
system_prompt = """
|
||||
### 目标 ###
|
||||
你是一个自动化规划师,需要完成用户的任务。请你根据屏幕信息确定【下一步操作】,以完成任务:
|
||||
|
||||
你当前的任务是:
|
||||
{task_plan}
|
||||
|
||||
以下是用yolo检测的当前屏幕上的所有元素:
|
||||
|
||||
{screen_info}
|
||||
##########
|
||||
|
||||
### 注意 ###
|
||||
1. 每次应该只给出一个操作。
|
||||
2. 应该对当前屏幕进行分析,通过查看历史记录反思已完成的工作,然后描述您如何实现任务的逐步思考。
|
||||
3. 在"Next Action"中附上下一步操作预测。
|
||||
4. 不应包括其他操作,例如键盘快捷键。
|
||||
5. 当任务完成时,不要完成额外的操作。你应该在json字段中说"Next Action": "None"。
|
||||
6. 任务涉及购买多个产品或浏览多个页面。你应该将其分解为子目标,并按照说明的顺序一个一个地完成每个子目标。
|
||||
7. 避免连续多次选择相同的操作/元素,如果发生这种情况,反思自己,可能出了什么问题,并预测不同的操作。
|
||||
8. 如果您收到登录信息页面或验证码页面的提示,或者您认为下一步操作需要用户许可,您应该在json字段中说"Next Action": "None"。
|
||||
9. 你只能使用鼠标和键盘与计算机进行交互。
|
||||
10. 你只能与桌面图形用户界面交互(无法访问终端或应用程序菜单)。
|
||||
11. 如果当前屏幕没有显示任何可操作的元素,并且当前屏幕不能下滑,请返回None。
|
||||
|
||||
##########
|
||||
### 输出格式 ###
|
||||
```json
|
||||
{{
|
||||
"reasoning": str, # 描述当前屏幕上的内容,考虑历史记录,然后描述您如何实现任务的逐步思考,一次从可用操作中选择一个操作。
|
||||
"next_action": "action_type, action description" | "None" # 一次一个操作,简短精确地描述它。
|
||||
"box_id": n,
|
||||
"value": "xxx" # 仅当操作为type时提供value字段,否则不包括value键
|
||||
}}
|
||||
```
|
||||
|
||||
【next_action】仅包括下面之一:
|
||||
- type:输入一串文本。
|
||||
- left_click:将鼠标移动到框ID并左键单击。
|
||||
- right_click:将鼠标移动到框ID并右键单击。
|
||||
- double_click:将鼠标移动到框ID并双击。
|
||||
- hover:将鼠标移动到框ID。
|
||||
- scroll_up:向上滚动屏幕以查看之前的内容。
|
||||
- scroll_down:当所需按钮不可见或您需要查看更多内容时,向下滚动屏幕。
|
||||
- wait:等待1秒钟让设备加载或响应。
|
||||
|
||||
##########
|
||||
### 案例 ###
|
||||
一个例子:
|
||||
```json
|
||||
{{
|
||||
"reasoning": "当前屏幕显示亚马逊的谷歌搜索结果,在之前的操作中,我已经在谷歌上搜索了亚马逊。然后我需要点击第一个搜索结果以转到amazon.com。",
|
||||
"next_action": "left_click",
|
||||
"box_id": m
|
||||
}}
|
||||
```
|
||||
|
||||
另一个例子:
|
||||
```json
|
||||
{{
|
||||
"reasoning": "当前屏幕显示亚马逊的首页。没有之前的操作。因此,我需要在搜索栏中输入"Apple watch"。",
|
||||
"next_action": "type",
|
||||
"box_id": n,
|
||||
"value": "Apple watch"
|
||||
}}
|
||||
```
|
||||
|
||||
另一个例子:
|
||||
```json
|
||||
{{
|
||||
"reasoning": "当前屏幕没有显示'提交'按钮,我需要向下滚动以查看按钮是否可用。",
|
||||
"next_action": "scroll_down"
|
||||
}}
|
||||
"""
|
||||
|
||||
299
gradio_ui/agent/vision_agent.py
Normal file
299
gradio_ui/agent/vision_agent.py
Normal file
@@ -0,0 +1,299 @@
|
||||
from typing import List, Optional
|
||||
import cv2
|
||||
import torch
|
||||
from ultralytics import YOLO
|
||||
from transformers import AutoModelForCausalLM, AutoProcessor
|
||||
import easyocr
|
||||
import supervision as sv
|
||||
import numpy as np
|
||||
import time
|
||||
from pydantic import BaseModel
|
||||
import base64
|
||||
from PIL import Image
|
||||
|
||||
class UIElement(BaseModel):
|
||||
element_id: int
|
||||
coordinates: list[float]
|
||||
caption: Optional[str] = None
|
||||
text: Optional[str] = None
|
||||
|
||||
class VisionAgent:
|
||||
def __init__(self, yolo_model_path: str, caption_model_path: str = 'microsoft/Florence-2-base-ft'):
|
||||
"""
|
||||
Initialize the vision agent
|
||||
|
||||
Parameters:
|
||||
yolo_model_path: Path to YOLO model
|
||||
caption_model_path: Path to image caption model, default is Florence-2
|
||||
"""
|
||||
# determine the available device and the best dtype
|
||||
self.device, self.dtype = self._get_optimal_device_and_dtype()
|
||||
# load the YOLO model
|
||||
self.yolo_model = YOLO(yolo_model_path)
|
||||
|
||||
# load the image caption model and processor
|
||||
self.caption_processor = AutoProcessor.from_pretrained(
|
||||
"microsoft/Florence-2-base",
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# load the model according to the device type
|
||||
try:
|
||||
if self.device.type == 'cuda':
|
||||
# CUDA device uses float16
|
||||
self.caption_model = AutoModelForCausalLM.from_pretrained(
|
||||
caption_model_path,
|
||||
torch_dtype=torch.float16,
|
||||
trust_remote_code=True
|
||||
).to(self.device)
|
||||
elif self.device.type == 'mps':
|
||||
# MPS device uses float32 (MPS has limited support for float16)
|
||||
self.caption_model = AutoModelForCausalLM.from_pretrained(
|
||||
caption_model_path,
|
||||
torch_dtype=torch.float32,
|
||||
trust_remote_code=True
|
||||
).to(self.device)
|
||||
else:
|
||||
# CPU uses float32
|
||||
self.caption_model = AutoModelForCausalLM.from_pretrained(
|
||||
caption_model_path,
|
||||
torch_dtype=torch.float32,
|
||||
trust_remote_code=True
|
||||
).to(self.device)
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
self.prompt = "<CAPTION>"
|
||||
|
||||
# set the batch size
|
||||
if self.device.type == 'cuda':
|
||||
self.batch_size = 128
|
||||
elif self.device.type == 'mps':
|
||||
self.batch_size = 32
|
||||
else:
|
||||
self.batch_size = 16
|
||||
|
||||
self.elements: List[UIElement] = []
|
||||
self.ocr_reader = easyocr.Reader(['en', 'ch_sim'])
|
||||
|
||||
def __call__(self, image_path: str) -> List[UIElement]:
|
||||
"""Process an image from file path."""
|
||||
# image = self.load_image(image_source)
|
||||
image = cv2.imread(image_path)
|
||||
if image is None:
|
||||
raise FileNotFoundError(f"Vision agent: Failed to read image")
|
||||
return self.analyze_image(image)
|
||||
|
||||
def _get_optimal_device_and_dtype(self):
|
||||
"""determine the optimal device and dtype"""
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
# check if the GPU is suitable for using float16
|
||||
capability = torch.cuda.get_device_capability()
|
||||
# only use float16 on newer GPUs
|
||||
if capability[0] >= 7:
|
||||
dtype = torch.float16
|
||||
else:
|
||||
dtype = torch.float32
|
||||
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||
device = torch.device("mps")
|
||||
dtype = torch.float32
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
return device, dtype
|
||||
|
||||
def _reset_state(self):
|
||||
"""Clear previous analysis results"""
|
||||
self.elements = []
|
||||
|
||||
def analyze_image(self, image: np.ndarray) -> List[UIElement]:
|
||||
"""
|
||||
Process an image through all computer vision pipelines.
|
||||
|
||||
Args:
|
||||
image: Input image in BGR format (OpenCV default)
|
||||
|
||||
Returns:
|
||||
List of detected UI elements with annotations
|
||||
"""
|
||||
self._reset_state()
|
||||
|
||||
element_crops, boxes = self._detect_objects(image)
|
||||
start = time.time()
|
||||
element_texts = self._extract_text(element_crops)
|
||||
end = time.time()
|
||||
ocr_time = (end-start) * 10 ** 3
|
||||
print(f"Speed: {ocr_time:.2f} ms OCR of {len(element_texts)} icons.")
|
||||
start = time.time()
|
||||
element_captions = self._get_caption(element_crops, 5)
|
||||
end = time.time()
|
||||
caption_time = (end-start) * 10 ** 3
|
||||
print(f"Speed: {caption_time:.2f} ms captioning of {len(element_captions)} icons.")
|
||||
for idx in range(len(element_crops)):
|
||||
new_element = UIElement(element_id=idx,
|
||||
coordinates=boxes[idx],
|
||||
text=element_texts[idx][0] if len(element_texts[idx]) > 0 else '',
|
||||
caption=element_captions[idx]
|
||||
)
|
||||
self.elements.append(new_element)
|
||||
|
||||
return self.elements
|
||||
|
||||
def _extract_text(self, images: np.ndarray) -> list[str]:
|
||||
"""
|
||||
Run OCR in sequential mode
|
||||
TODO: It is possible to run in batch mode for a speed up, but the result quality needs test.
|
||||
https://github.com/JaidedAI/EasyOCR/pull/458
|
||||
"""
|
||||
texts = []
|
||||
for image in images:
|
||||
text = self.ocr_reader.readtext(image, detail=0, paragraph=True, text_threshold=0.85)
|
||||
texts.append(text)
|
||||
# print(texts)
|
||||
return texts
|
||||
|
||||
def _get_caption(self, element_crops, batch_size=None):
|
||||
"""get the caption of the element crops"""
|
||||
if not element_crops:
|
||||
return []
|
||||
|
||||
# if batch_size is not specified, use the instance's default value
|
||||
if batch_size is None:
|
||||
batch_size = self.batch_size
|
||||
|
||||
# resize the image to 64x64
|
||||
resized_crops = []
|
||||
for img in element_crops:
|
||||
# convert to numpy array, resize, then convert back to PIL
|
||||
img_np = np.array(img)
|
||||
resized_np = cv2.resize(img_np, (64, 64))
|
||||
resized_crops.append(Image.fromarray(resized_np))
|
||||
|
||||
generated_texts = []
|
||||
device = self.device
|
||||
|
||||
# process in batches
|
||||
for i in range(0, len(resized_crops), batch_size):
|
||||
batch = resized_crops[i:i+batch_size]
|
||||
try:
|
||||
# select the dtype according to the device type
|
||||
if device.type == 'cuda':
|
||||
inputs = self.caption_processor(
|
||||
images=batch,
|
||||
text=[self.prompt] * len(batch),
|
||||
return_tensors="pt",
|
||||
do_resize=False
|
||||
).to(device=device, dtype=torch.float16)
|
||||
else:
|
||||
# MPS and CPU use float32
|
||||
inputs = self.caption_processor(
|
||||
images=batch,
|
||||
text=[self.prompt] * len(batch),
|
||||
return_tensors="pt"
|
||||
).to(device=device)
|
||||
|
||||
# special treatment for Florence-2
|
||||
with torch.no_grad():
|
||||
if 'florence' in self.caption_model.config.model_type:
|
||||
generated_ids = self.caption_model.generate(
|
||||
input_ids=inputs["input_ids"],
|
||||
pixel_values=inputs["pixel_values"],
|
||||
max_new_tokens=20,
|
||||
num_beams=5,
|
||||
do_sample=False
|
||||
)
|
||||
else:
|
||||
generated_ids = self.caption_model.generate(
|
||||
**inputs,
|
||||
max_length=50,
|
||||
num_beams=3,
|
||||
early_stopping=True
|
||||
)
|
||||
|
||||
# decode the generated IDs
|
||||
texts = self.caption_processor.batch_decode(
|
||||
generated_ids,
|
||||
skip_special_tokens=True
|
||||
)
|
||||
texts = [text.strip() for text in texts]
|
||||
generated_texts.extend(texts)
|
||||
|
||||
# clean the cache
|
||||
if device.type == 'cuda' and torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
except RuntimeError as e:
|
||||
raise e
|
||||
return generated_texts
|
||||
|
||||
def _detect_objects(self, image: np.ndarray) -> tuple[list[np.ndarray], list]:
|
||||
"""Run object detection pipeline"""
|
||||
results = self.yolo_model(image)[0]
|
||||
detections = sv.Detections.from_ultralytics(results)
|
||||
boxes = detections.xyxy
|
||||
|
||||
if len(boxes) == 0:
|
||||
return []
|
||||
|
||||
# Filter out boxes contained by others
|
||||
areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
||||
sorted_indices = np.argsort(-areas) # Sort descending by area
|
||||
sorted_boxes = boxes[sorted_indices]
|
||||
|
||||
keep_sorted = []
|
||||
for i in range(len(sorted_boxes)):
|
||||
contained = False
|
||||
for j in keep_sorted:
|
||||
box_b = sorted_boxes[j]
|
||||
box_a = sorted_boxes[i]
|
||||
if (box_b[0] <= box_a[0] and box_b[1] <= box_a[1] and
|
||||
box_b[2] >= box_a[2] and box_b[3] >= box_a[3]):
|
||||
contained = True
|
||||
break
|
||||
if not contained:
|
||||
keep_sorted.append(i)
|
||||
|
||||
# Map back to original indices
|
||||
keep_indices = sorted_indices[keep_sorted]
|
||||
filtered_boxes = boxes[keep_indices]
|
||||
|
||||
# Extract element crops
|
||||
element_crops = []
|
||||
for box in filtered_boxes:
|
||||
x1, y1, x2, y2 = map(int, map(round, box))
|
||||
element = image[y1:y2, x1:x2]
|
||||
element_crops.append(np.array(element))
|
||||
|
||||
return element_crops, filtered_boxes
|
||||
|
||||
def load_image(self, image_source: str) -> np.ndarray:
|
||||
try:
|
||||
# Handle potential Data URL prefix (like "data:image/png;base64,")
|
||||
if ',' in image_source:
|
||||
_, payload = image_source.split(',', 1)
|
||||
else:
|
||||
payload = image_source
|
||||
|
||||
# Base64 decode -> bytes -> numpy array
|
||||
image_bytes = base64.b64decode(payload)
|
||||
np_array = np.frombuffer(image_bytes, dtype=np.uint8)
|
||||
|
||||
# OpenCV decode image
|
||||
image = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
|
||||
|
||||
if image is None:
|
||||
raise ValueError("Failed to decode image: Invalid image data")
|
||||
|
||||
return self.analyze_image(image)
|
||||
|
||||
except (base64.binascii.Error, ValueError) as e:
|
||||
# Generate clearer error message
|
||||
error_msg = f"Input is neither a valid file path nor valid Base64 image data"
|
||||
raise ValueError(error_msg) from e
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,299 +0,0 @@
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from typing import Callable
|
||||
import uuid
|
||||
from PIL import Image, ImageDraw
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from anthropic import APIResponse
|
||||
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam, BetaUsage
|
||||
from gradio_ui.agent.llm_utils.oaiclient import run_oai_interleaved
|
||||
from gradio_ui.agent.llm_utils.utils import is_image_path
|
||||
import time
|
||||
import re
|
||||
|
||||
OUTPUT_DIR = "./tmp/outputs"
|
||||
|
||||
def extract_data(input_string, data_type):
|
||||
# Regular expression to extract content starting from '```python' until the end if there are no closing backticks
|
||||
pattern = f"```{data_type}" + r"(.*?)(```|$)"
|
||||
# Extract content
|
||||
# re.DOTALL allows '.' to match newlines as well
|
||||
matches = re.findall(pattern, input_string, re.DOTALL)
|
||||
# Return the first match if exists, trimming whitespace and ignoring potential closing backticks
|
||||
return matches[0][0].strip() if matches else input_string
|
||||
|
||||
class VLMAgent:
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
api_key: str,
|
||||
output_callback: Callable,
|
||||
api_response_callback: Callable,
|
||||
max_tokens: int = 4096,
|
||||
base_url:str = "",
|
||||
only_n_most_recent_images: int | None = None,
|
||||
print_usage: bool = True,
|
||||
):
|
||||
self.base_url = base_url
|
||||
self.api_key = api_key
|
||||
self.api_response_callback = api_response_callback
|
||||
self.max_tokens = max_tokens
|
||||
self.only_n_most_recent_images = only_n_most_recent_images
|
||||
self.output_callback = output_callback
|
||||
self.model = model
|
||||
self.print_usage = print_usage
|
||||
self.total_token_usage = 0
|
||||
self.total_cost = 0
|
||||
self.step_count = 0
|
||||
|
||||
self.system = ''
|
||||
|
||||
def __call__(self, messages: list, parsed_screen: list[str, list, dict]):
|
||||
self.step_count += 1
|
||||
image_base64 = parsed_screen['original_screenshot_base64']
|
||||
latency_omniparser = parsed_screen['latency']
|
||||
self.output_callback(f'-- Step {self.step_count}: --', sender="bot")
|
||||
screen_info = str(parsed_screen['screen_info'])
|
||||
screenshot_uuid = parsed_screen['screenshot_uuid']
|
||||
screen_width, screen_height = parsed_screen['width'], parsed_screen['height']
|
||||
|
||||
boxids_and_labels = parsed_screen["screen_info"]
|
||||
system = self._get_system_prompt(boxids_and_labels)
|
||||
|
||||
# drop looping actions msg, byte image etc
|
||||
planner_messages = messages
|
||||
_remove_som_images(planner_messages)
|
||||
_maybe_filter_to_n_most_recent_images(planner_messages, self.only_n_most_recent_images)
|
||||
|
||||
if isinstance(planner_messages[-1], dict):
|
||||
if not isinstance(planner_messages[-1]["content"], list):
|
||||
planner_messages[-1]["content"] = [planner_messages[-1]["content"]]
|
||||
planner_messages[-1]["content"].append(f"{OUTPUT_DIR}/screenshot_{screenshot_uuid}.png")
|
||||
planner_messages[-1]["content"].append(f"{OUTPUT_DIR}/screenshot_som_{screenshot_uuid}.png")
|
||||
|
||||
start = time.time()
|
||||
vlm_response, token_usage = run_oai_interleaved(
|
||||
messages=planner_messages,
|
||||
system=system,
|
||||
model_name=self.model,
|
||||
api_key=self.api_key,
|
||||
max_tokens=self.max_tokens,
|
||||
provider_base_url=self.base_url,
|
||||
temperature=0,
|
||||
)
|
||||
latency_vlm = time.time() - start
|
||||
self.output_callback(f"LLM: {latency_vlm:.2f}s, OmniParser: {latency_omniparser:.2f}s", sender="bot")
|
||||
|
||||
print(f"llm_response: {vlm_response}")
|
||||
if self.print_usage:
|
||||
print(f"Total token so far: {token_usage}. Total cost so far: $USD{self.total_cost:.5f}")
|
||||
|
||||
vlm_response_json = extract_data(vlm_response, "json")
|
||||
vlm_response_json = json.loads(vlm_response_json)
|
||||
|
||||
img_to_show_base64 = parsed_screen["som_image_base64"]
|
||||
if "Box ID" in vlm_response_json:
|
||||
try:
|
||||
bbox = parsed_screen["parsed_content_list"][int(vlm_response_json["Box ID"])]["bbox"]
|
||||
vlm_response_json["box_centroid_coordinate"] = [int((bbox[0] + bbox[2]) / 2 * screen_width), int((bbox[1] + bbox[3]) / 2 * screen_height)]
|
||||
img_to_show_data = base64.b64decode(img_to_show_base64)
|
||||
img_to_show = Image.open(BytesIO(img_to_show_data))
|
||||
|
||||
draw = ImageDraw.Draw(img_to_show)
|
||||
x, y = vlm_response_json["box_centroid_coordinate"]
|
||||
radius = 10
|
||||
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill='red')
|
||||
draw.ellipse((x - radius*3, y - radius*3, x + radius*3, y + radius*3), fill=None, outline='red', width=2)
|
||||
|
||||
buffered = BytesIO()
|
||||
img_to_show.save(buffered, format="PNG")
|
||||
img_to_show_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
except:
|
||||
print(f"Error parsing: {vlm_response_json}")
|
||||
pass
|
||||
self.output_callback(f'<img src="data:image/png;base64,{img_to_show_base64}">', sender="bot")
|
||||
self.output_callback(
|
||||
f'<details>'
|
||||
f' <summary>Parsed Screen elemetns by OmniParser</summary>'
|
||||
f' <pre>{screen_info}</pre>'
|
||||
f'</details>',
|
||||
sender="bot"
|
||||
)
|
||||
vlm_plan_str = ""
|
||||
for key, value in vlm_response_json.items():
|
||||
if key == "Reasoning":
|
||||
vlm_plan_str += f'{value}'
|
||||
else:
|
||||
vlm_plan_str += f'\n{key}: {value}'
|
||||
|
||||
# construct the response so that anthropicExcutor can execute the tool
|
||||
response_content = [BetaTextBlock(text=vlm_plan_str, type='text')]
|
||||
if 'box_centroid_coordinate' in vlm_response_json:
|
||||
move_cursor_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
|
||||
input={'action': 'mouse_move', 'coordinate': vlm_response_json["box_centroid_coordinate"]},
|
||||
name='computer', type='tool_use')
|
||||
response_content.append(move_cursor_block)
|
||||
|
||||
if vlm_response_json["Next Action"] == "None":
|
||||
print("Task paused/completed.")
|
||||
elif vlm_response_json["Next Action"] == "type":
|
||||
sim_content_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
|
||||
input={'action': vlm_response_json["Next Action"], 'text': vlm_response_json["value"]},
|
||||
name='computer', type='tool_use')
|
||||
response_content.append(sim_content_block)
|
||||
else:
|
||||
sim_content_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
|
||||
input={'action': vlm_response_json["Next Action"]},
|
||||
name='computer', type='tool_use')
|
||||
response_content.append(sim_content_block)
|
||||
response_message = BetaMessage(id=f'toolu_{uuid.uuid4()}', content=response_content, model='', role='assistant', type='message', stop_reason='tool_use', usage=BetaUsage(input_tokens=0, output_tokens=0))
|
||||
return response_message, vlm_response_json
|
||||
|
||||
def _api_response_callback(self, response: APIResponse):
|
||||
self.api_response_callback(response)
|
||||
|
||||
def _get_system_prompt(self, screen_info: str = ""):
|
||||
main_section = f"""
|
||||
You are using a Windows device.
|
||||
You are able to use a mouse and keyboard to interact with the computer based on the given task and screenshot.
|
||||
You can only interact with the desktop GUI (no terminal or application menu access).
|
||||
You may be given some history plan and actions, this is the response from the previous loop.
|
||||
You should carefully consider your plan base on the task, screenshot, and history actions.
|
||||
|
||||
|
||||
Here is the list of all detected bounding boxes by IDs on the screen and their description:{screen_info}
|
||||
|
||||
Your available "Next Action" only include:
|
||||
- type: types a string of text.
|
||||
- left_click: move mouse to box id and left clicks.
|
||||
- right_click: move mouse to box id and right clicks.
|
||||
- double_click: move mouse to box id and double clicks.
|
||||
- hover: move mouse to box id.
|
||||
- scroll_up: scrolls the screen up to view previous content.
|
||||
- scroll_down: scrolls the screen down, when the desired button is not visible, or you need to see more content.
|
||||
- wait: waits for 1 second for the device to load or respond.
|
||||
|
||||
Based on the visual information from the screenshot image and the detected bounding boxes, please determine the next action, the Box ID you should operate on (if action is one of 'type', 'hover', 'scroll_up', 'scroll_down', 'wait', there should be no Box ID field), and the value (if the action is 'type') in order to complete the task.
|
||||
|
||||
Output format:
|
||||
```json
|
||||
{{
|
||||
"Reasoning": str, # describe what is in the current screen, taking into account the history, then describe your step-by-step thoughts on how to achieve the task, choose one action from available actions at a time.
|
||||
"Next Action": "action_type, action description" | "None" # one action at a time, describe it in short and precisely.
|
||||
"Box ID": n,
|
||||
"value": "xxx" # only provide value field if the action is type, else don't include value key
|
||||
}}
|
||||
```
|
||||
|
||||
One Example:
|
||||
```json
|
||||
{{
|
||||
"Reasoning": "The current screen shows google result of amazon, in previous action I have searched amazon on google. Then I need to click on the first search results to go to amazon.com.",
|
||||
"Next Action": "left_click",
|
||||
"Box ID": m
|
||||
}}
|
||||
```
|
||||
|
||||
Another Example:
|
||||
```json
|
||||
{{
|
||||
"Reasoning": "The current screen shows the front page of amazon. There is no previous action. Therefore I need to type "Apple watch" in the search bar.",
|
||||
"Next Action": "type",
|
||||
"Box ID": n,
|
||||
"value": "Apple watch"
|
||||
}}
|
||||
```
|
||||
|
||||
Another Example:
|
||||
```json
|
||||
{{
|
||||
"Reasoning": "The current screen does not show 'submit' button, I need to scroll down to see if the button is available.",
|
||||
"Next Action": "scroll_down",
|
||||
}}
|
||||
```
|
||||
|
||||
IMPORTANT NOTES:
|
||||
1. You should only give a single action at a time.
|
||||
|
||||
"""
|
||||
thinking_model = "r1" in self.model
|
||||
if not thinking_model:
|
||||
main_section += """
|
||||
2. You should give an analysis to the current screen, and reflect on what has been done by looking at the history, then describe your step-by-step thoughts on how to achieve the task.
|
||||
|
||||
"""
|
||||
else:
|
||||
main_section += """
|
||||
2. In <think> XML tags give an analysis to the current screen, and reflect on what has been done by looking at the history, then describe your step-by-step thoughts on how to achieve the task. In <output> XML tags put the next action prediction JSON.
|
||||
|
||||
"""
|
||||
main_section += """
|
||||
3. Attach the next action prediction in the "Next Action".
|
||||
4. You should not include other actions, such as keyboard shortcuts.
|
||||
5. When the task is completed, don't complete additional actions. You should say "Next Action": "None" in the json field.
|
||||
6. The tasks involve buying multiple products or navigating through multiple pages. You should break it into subgoals and complete each subgoal one by one in the order of the instructions.
|
||||
7. avoid choosing the same action/elements multiple times in a row, if it happens, reflect to yourself, what may have gone wrong, and predict a different action.
|
||||
8. If you are prompted with login information page or captcha page, or you think it need user's permission to do the next action, you should say "Next Action": "None" in the json field.
|
||||
"""
|
||||
|
||||
return main_section
|
||||
|
||||
def _remove_som_images(messages):
|
||||
for msg in messages:
|
||||
msg_content = msg["content"]
|
||||
if isinstance(msg_content, list):
|
||||
msg["content"] = [
|
||||
cnt for cnt in msg_content
|
||||
if not (isinstance(cnt, str) and 'som' in cnt and is_image_path(cnt))
|
||||
]
|
||||
|
||||
|
||||
def _maybe_filter_to_n_most_recent_images(
|
||||
messages: list[BetaMessageParam],
|
||||
images_to_keep: int,
|
||||
min_removal_threshold: int = 10,
|
||||
):
|
||||
"""
|
||||
With the assumption that images are screenshots that are of diminishing value as
|
||||
the conversation progresses, remove all but the final `images_to_keep` tool_result
|
||||
images in place
|
||||
"""
|
||||
if images_to_keep is None:
|
||||
return messages
|
||||
|
||||
total_images = 0
|
||||
for msg in messages:
|
||||
for cnt in msg.get("content", []):
|
||||
if isinstance(cnt, str) and is_image_path(cnt):
|
||||
total_images += 1
|
||||
elif isinstance(cnt, dict) and cnt.get("type") == "tool_result":
|
||||
for content in cnt.get("content", []):
|
||||
if isinstance(content, dict) and content.get("type") == "image":
|
||||
total_images += 1
|
||||
|
||||
images_to_remove = total_images - images_to_keep
|
||||
|
||||
for msg in messages:
|
||||
msg_content = msg["content"]
|
||||
if isinstance(msg_content, list):
|
||||
new_content = []
|
||||
for cnt in msg_content:
|
||||
# Remove images from SOM or screenshot as needed
|
||||
if isinstance(cnt, str) and is_image_path(cnt):
|
||||
if images_to_remove > 0:
|
||||
images_to_remove -= 1
|
||||
continue
|
||||
# VLM shouldn't use anthropic screenshot tool so shouldn't have these but in case it does, remove as needed
|
||||
elif isinstance(cnt, dict) and cnt.get("type") == "tool_result":
|
||||
new_tool_result_content = []
|
||||
for tool_result_entry in cnt.get("content", []):
|
||||
if isinstance(tool_result_entry, dict) and tool_result_entry.get("type") == "image":
|
||||
if images_to_remove > 0:
|
||||
images_to_remove -= 1
|
||||
continue
|
||||
new_tool_result_content.append(tool_result_entry)
|
||||
cnt["content"] = new_tool_result_content
|
||||
# Append fixed content to current message's content list
|
||||
new_content.append(cnt)
|
||||
msg["content"] = new_content
|
||||
@@ -14,12 +14,15 @@ from anthropic import APIResponse
|
||||
from anthropic.types import TextBlock
|
||||
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
|
||||
from anthropic.types.tool_use_block import ToolUseBlock
|
||||
from gradio_ui.agent.vision_agent import VisionAgent
|
||||
from gradio_ui.loop import (
|
||||
sampling_loop_sync,
|
||||
)
|
||||
from gradio_ui.tools import ToolResult
|
||||
import base64
|
||||
from xbrain.utils.config import Config
|
||||
|
||||
from util.download_weights import MODEL_DIR
|
||||
CONFIG_DIR = Path("~/.anthropic").expanduser()
|
||||
API_KEY_FILE = CONFIG_DIR / "api_key"
|
||||
|
||||
@@ -41,12 +44,23 @@ class Sender(StrEnum):
|
||||
BOT = "assistant"
|
||||
TOOL = "tool"
|
||||
def setup_state(state):
|
||||
# 如果存在config,则从config中加载数据
|
||||
config = Config()
|
||||
if config.OPENAI_API_KEY:
|
||||
state["api_key"] = config.OPENAI_API_KEY
|
||||
else:
|
||||
state["api_key"] = ""
|
||||
if config.OPENAI_BASE_URL:
|
||||
state["base_url"] = config.OPENAI_BASE_URL
|
||||
else:
|
||||
state["base_url"] = "https://api.openai.com/v1"
|
||||
if config.OPENAI_MODEL:
|
||||
state["model"] = config.OPENAI_MODEL
|
||||
else:
|
||||
state["model"] = "gpt-4o"
|
||||
|
||||
if "messages" not in state:
|
||||
state["messages"] = []
|
||||
if "model" not in state:
|
||||
state["model"] = "gpt-4o"
|
||||
if "api_key" not in state:
|
||||
state["api_key"] = ""
|
||||
if "auth_validated" not in state:
|
||||
state["auth_validated"] = False
|
||||
if "responses" not in state:
|
||||
@@ -59,8 +73,6 @@ def setup_state(state):
|
||||
state['chatbot_messages'] = []
|
||||
if 'stop' not in state:
|
||||
state['stop'] = False
|
||||
if 'base_url' not in state:
|
||||
state['base_url'] = "https://api.openai-next.com/v1"
|
||||
|
||||
async def main(state):
|
||||
"""Render loop for Gradio"""
|
||||
@@ -156,11 +168,12 @@ def chatbot_output_callback(message, chatbot_state, hide_images=False, sender="b
|
||||
# print(f"chatbot_output_callback chatbot_state: {concise_state} (truncated)")
|
||||
|
||||
|
||||
def process_input(user_input, state):
|
||||
def process_input(user_input, state, vision_agent_state):
|
||||
# Reset the stop flag
|
||||
if state["stop"]:
|
||||
state["stop"] = False
|
||||
|
||||
config = Config()
|
||||
config.set_openai_config(base_url=state["base_url"], api_key=state["api_key"], model=state["model"])
|
||||
# Append the user message to state["messages"]
|
||||
state["messages"].append(
|
||||
{
|
||||
@@ -173,17 +186,15 @@ def process_input(user_input, state):
|
||||
state['chatbot_messages'].append((user_input, None)) # 确保格式正确
|
||||
yield state['chatbot_messages'] # Yield to update the chatbot UI with the user's message
|
||||
# Run sampling_loop_sync with the chatbot_output_callback
|
||||
agent = vision_agent_state["agent"]
|
||||
for loop_msg in sampling_loop_sync(
|
||||
model=state["model"],
|
||||
messages=state["messages"],
|
||||
output_callback=partial(chatbot_output_callback, chatbot_state=state['chatbot_messages'], hide_images=False),
|
||||
tool_output_callback=partial(_tool_output_callback, tool_state=state["tools"]),
|
||||
api_response_callback=partial(_api_response_callback, response_state=state["responses"]),
|
||||
api_key=state["api_key"],
|
||||
only_n_most_recent_images=state["only_n_most_recent_images"],
|
||||
max_tokens=8000,
|
||||
omniparser_url=args.omniparser_server_url,
|
||||
base_url = state["base_url"]
|
||||
vision_agent = agent
|
||||
):
|
||||
if loop_msg is None or state.get("stop"):
|
||||
yield state['chatbot_messages']
|
||||
@@ -244,14 +255,14 @@ def run():
|
||||
with gr.Column():
|
||||
model = gr.Textbox(
|
||||
label="Model",
|
||||
value="gpt-4o",
|
||||
value=state.value["model"],
|
||||
placeholder="输入模型名称",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column():
|
||||
base_url = gr.Textbox(
|
||||
label="Base URL",
|
||||
value="https://api.openai-next.com/v1",
|
||||
value=state.value["base_url"],
|
||||
placeholder="输入基础 URL",
|
||||
interactive=True
|
||||
)
|
||||
@@ -268,7 +279,7 @@ def run():
|
||||
api_key = gr.Textbox(
|
||||
label="API Key",
|
||||
type="password",
|
||||
value=state.value.get("api_key", ""),
|
||||
value=state.value["api_key"],
|
||||
placeholder="Paste your API key here",
|
||||
interactive=True,
|
||||
)
|
||||
@@ -285,15 +296,11 @@ def run():
|
||||
chatbot = gr.Chatbot(
|
||||
label="Chatbot History",
|
||||
autoscroll=True,
|
||||
height=580 )
|
||||
height=580
|
||||
)
|
||||
|
||||
def update_model(model_selection, state):
|
||||
state["model"] = model_selection
|
||||
api_key_update = gr.update(
|
||||
placeholder="API Key",
|
||||
value=state["api_key"]
|
||||
)
|
||||
return api_key_update
|
||||
def update_model(model, state):
|
||||
state["model"] = model
|
||||
|
||||
def update_api_key(api_key_value, state):
|
||||
state["api_key"] = api_key_value
|
||||
@@ -309,10 +316,13 @@ def run():
|
||||
state['chatbot_messages'] = []
|
||||
return state['chatbot_messages']
|
||||
|
||||
model.change(fn=update_model, inputs=[model, state], outputs=[api_key])
|
||||
model.change(fn=update_model, inputs=[model, state], outputs=None)
|
||||
api_key.change(fn=update_api_key, inputs=[api_key, state], outputs=None)
|
||||
chatbot.clear(fn=clear_chat, inputs=[state], outputs=[chatbot])
|
||||
submit_button.click(process_input, [chat_input, state], chatbot)
|
||||
vision_agent = VisionAgent(yolo_model_path=os.path.join(MODEL_DIR, "icon_detect", "model.pt"),
|
||||
caption_model_path=os.path.join(MODEL_DIR, "icon_caption"))
|
||||
vision_agent_state = gr.State({"agent": vision_agent})
|
||||
submit_button.click(process_input, [chat_input, state, vision_agent_state], chatbot)
|
||||
stop_button.click(stop_app, [state], None)
|
||||
base_url.change(fn=update_base_url, inputs=[base_url, state], outputs=None)
|
||||
demo.launch(server_name="0.0.0.0", server_port=7888)
|
||||
|
||||
@@ -2,22 +2,24 @@
|
||||
Agentic sampling loop that calls the Anthropic API and local implenmentation of anthropic-defined computer use tools.
|
||||
"""
|
||||
from collections.abc import Callable
|
||||
from enum import StrEnum
|
||||
|
||||
from time import sleep
|
||||
import cv2
|
||||
from gradio_ui.agent.vision_agent import VisionAgent
|
||||
from gradio_ui.tools.screen_capture import get_screenshot
|
||||
from anthropic import APIResponse
|
||||
from anthropic.types import (
|
||||
TextBlock,
|
||||
)
|
||||
from anthropic.types.beta import (
|
||||
BetaContentBlock,
|
||||
BetaMessage,
|
||||
BetaMessageParam
|
||||
)
|
||||
from gradio_ui.agent.task_plan_agent import TaskPlanAgent
|
||||
from gradio_ui.agent.task_run_agent import TaskRunAgent
|
||||
from gradio_ui.tools import ToolResult
|
||||
|
||||
from gradio_ui.agent.llm_utils.omniparserclient import OmniParserClient
|
||||
from gradio_ui.agent.vlm_agent import VLMAgent
|
||||
from gradio_ui.executor.anthropic_executor import AnthropicExecutor
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
OUTPUT_DIR = "./tmp/outputs"
|
||||
|
||||
def sampling_loop_sync(
|
||||
*,
|
||||
@@ -26,39 +28,69 @@ def sampling_loop_sync(
|
||||
output_callback: Callable[[BetaContentBlock], None],
|
||||
tool_output_callback: Callable[[ToolResult, str], None],
|
||||
api_response_callback: Callable[[APIResponse[BetaMessage]], None],
|
||||
api_key: str,
|
||||
only_n_most_recent_images: int | None = 2,
|
||||
max_tokens: int = 4096,
|
||||
omniparser_url: str,
|
||||
base_url: str
|
||||
only_n_most_recent_images: int | None = 0,
|
||||
vision_agent: VisionAgent
|
||||
):
|
||||
"""
|
||||
Synchronous agentic sampling loop for the assistant/tool interaction of computer use.
|
||||
"""
|
||||
print('in sampling_loop_sync, model:', model)
|
||||
omniparser_client = OmniParserClient(url=f"http://{omniparser_url}/parse/")
|
||||
actor = VLMAgent(
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
api_response_callback=api_response_callback,
|
||||
output_callback=output_callback,
|
||||
max_tokens=max_tokens,
|
||||
only_n_most_recent_images=only_n_most_recent_images
|
||||
)
|
||||
task_plan_agent = TaskPlanAgent(output_callback=output_callback)
|
||||
executor = AnthropicExecutor(
|
||||
output_callback=output_callback,
|
||||
tool_output_callback=tool_output_callback,
|
||||
)
|
||||
|
||||
tool_result_content = None
|
||||
|
||||
print(f"Start the message loop. User messages: {messages}")
|
||||
plan = task_plan_agent(user_task = messages[-1]["content"][0].text)
|
||||
task_run_agent = TaskRunAgent(output_callback=output_callback)
|
||||
|
||||
|
||||
while True:
|
||||
parsed_screen = omniparser_client()
|
||||
tools_use_needed, vlm_response_json = actor(messages=messages, parsed_screen=parsed_screen)
|
||||
parsed_screen = parse_screen(vision_agent)
|
||||
tools_use_needed, __ = task_run_agent(task_plan=plan, parsed_screen=parsed_screen)
|
||||
sleep(2)
|
||||
for message, tool_result_content in executor(tools_use_needed, messages):
|
||||
yield message
|
||||
if not tool_result_content:
|
||||
return messages
|
||||
return messages
|
||||
|
||||
def parse_screen(vision_agent: VisionAgent):
|
||||
screenshot, screenshot_path = get_screenshot()
|
||||
response_json = {}
|
||||
response_json['parsed_content_list'] = vision_agent(str(screenshot_path))
|
||||
response_json['width'] = screenshot.size[0]
|
||||
response_json['height'] = screenshot.size[1]
|
||||
response_json['image'] = draw_elements(screenshot, response_json['parsed_content_list'])
|
||||
return response_json
|
||||
|
||||
def draw_elements(screenshot, parsed_content_list):
|
||||
"""
|
||||
将PIL图像转换为OpenCV兼容格式并绘制边界框
|
||||
|
||||
Args:
|
||||
screenshot: PIL Image对象
|
||||
parsed_content_list: 包含边界框信息的列表
|
||||
|
||||
Returns:
|
||||
带有绘制边界框的PIL图像
|
||||
"""
|
||||
# 将PIL图像转换为opencv格式
|
||||
opencv_image = np.array(screenshot)
|
||||
opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGB2BGR)
|
||||
# 绘制边界框
|
||||
for idx, element in enumerate(parsed_content_list):
|
||||
bbox = element.coordinates
|
||||
x1, y1, x2, y2 = bbox
|
||||
# 转换坐标为整数
|
||||
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
|
||||
# 绘制矩形
|
||||
cv2.rectangle(opencv_image, (x1, y1), (x2, y2), (0, 0, 255), 2)
|
||||
# 在矩形边框左上角绘制序号
|
||||
cv2.putText(opencv_image, str(idx+1), (x1, y1-10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
|
||||
|
||||
# 将OpenCV图像格式转换回PIL格式
|
||||
opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_BGR2RGB)
|
||||
pil_image = Image.fromarray(opencv_image)
|
||||
|
||||
return pil_image
|
||||
@@ -1,6 +1,5 @@
|
||||
import base64
|
||||
import time
|
||||
from enum import StrEnum
|
||||
from typing import Literal, TypedDict
|
||||
from PIL import Image
|
||||
from util import tool
|
||||
@@ -175,6 +174,8 @@ class ComputerTool(BaseAnthropicTool):
|
||||
pyautogui.click()
|
||||
elif action == "right_click":
|
||||
pyautogui.rightClick()
|
||||
# 等待5秒,等待菜单弹出
|
||||
time.sleep(5)
|
||||
elif action == "middle_click":
|
||||
pyautogui.middleClick()
|
||||
elif action == "double_click":
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
from PIL import Image
|
||||
from .base import BaseAnthropicTool, ToolError
|
||||
from io import BytesIO
|
||||
from .base import ToolError
|
||||
from util import tool
|
||||
|
||||
OUTPUT_DIR = "./tmp/outputs"
|
||||
|
||||
74
main.py
74
main.py
@@ -1,78 +1,20 @@
|
||||
import subprocess
|
||||
from threading import Thread
|
||||
import time
|
||||
import requests
|
||||
from gradio_ui import app
|
||||
from util import download_weights
|
||||
import torch
|
||||
import socket
|
||||
|
||||
def run():
|
||||
try:
|
||||
print("cuda is_available: ", torch.cuda.is_available()) # 应该返回True
|
||||
print("cuda is_available: ", torch.cuda.is_available())
|
||||
print("MPS is_available: ", torch.backends.mps.is_available())
|
||||
print("cuda device_count", torch.cuda.device_count()) # 应该至少返回1
|
||||
print("cuda device_name", torch.cuda.get_device_name(0)) # 应该显示您的GPU名称
|
||||
print("cuda device_count", torch.cuda.device_count())
|
||||
print("cuda device_name", torch.cuda.get_device_name(0))
|
||||
except Exception:
|
||||
print("显卡驱动不适配,请根据readme安装合适版本的 torch!")
|
||||
print("GPU driver is not compatible, please install the appropriate version of torch according to the readme!")
|
||||
|
||||
# download the weight files
|
||||
download_weights.download()
|
||||
app.run()
|
||||
|
||||
|
||||
server_process = subprocess.Popen(
|
||||
["python", "./omniserver.py"],
|
||||
stdout=subprocess.PIPE, # 捕获标准输出
|
||||
stderr=subprocess.PIPE,
|
||||
text=True
|
||||
)
|
||||
|
||||
stdout_thread = Thread(
|
||||
target=stream_reader,
|
||||
args=(server_process.stdout, "SERVER-OUT")
|
||||
)
|
||||
|
||||
stderr_thread = Thread(
|
||||
target=stream_reader,
|
||||
args=(server_process.stderr, "SERVER-ERR")
|
||||
)
|
||||
stdout_thread.daemon = True
|
||||
stderr_thread.daemon = True
|
||||
stdout_thread.start()
|
||||
stderr_thread.start()
|
||||
|
||||
|
||||
try:
|
||||
# 下载权重文件
|
||||
download_weights.download()
|
||||
print("启动Omniserver服务中,因为加载模型真的超级慢,请耐心等待!")
|
||||
while True:
|
||||
try:
|
||||
res = requests.get("http://127.0.0.1:8000/probe/", timeout=5)
|
||||
if res.status_code == 200:
|
||||
print("Omniparser服务启动成功...")
|
||||
break
|
||||
except (requests.ConnectionError, requests.Timeout):
|
||||
pass
|
||||
if server_process.poll() is not None:
|
||||
raise RuntimeError(f"服务器进程报错退出:{server_process.returncode}")
|
||||
print("等待服务启动...")
|
||||
time.sleep(10)
|
||||
|
||||
app.run()
|
||||
finally:
|
||||
if server_process.poll() is None: # 如果进程还在运行
|
||||
server_process.terminate() # 发送终止信号
|
||||
server_process.wait(timeout=8) # 等待进程结束
|
||||
|
||||
def stream_reader(pipe, prefix):
|
||||
for line in pipe:
|
||||
print(f"[{prefix}]", line, end="", flush=True)
|
||||
|
||||
def is_port_occupied(port):
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
return s.connect_ex(('localhost', port)) == 0
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 检测8000端口是否被占用
|
||||
if is_port_occupied(8000):
|
||||
print("8000端口被占用,请先关闭占用该端口的进程")
|
||||
exit()
|
||||
run()
|
||||
@@ -11,7 +11,8 @@ import argparse
|
||||
import uvicorn
|
||||
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(root_dir)
|
||||
from util.omniparser import Omniparser
|
||||
# from util.omniparser import Omniparser
|
||||
from gradio_ui.agent.vision_agent import VisionAgent
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description='Omniparser API')
|
||||
@@ -29,8 +30,10 @@ args = parse_arguments()
|
||||
config = vars(args)
|
||||
|
||||
app = FastAPI()
|
||||
omniparser = Omniparser(config)
|
||||
|
||||
# omniparser = Omniparser(config)
|
||||
yolo_model_path = config['som_model_path']
|
||||
caption_model_path = config['caption_model_path']
|
||||
vision_agent = VisionAgent(yolo_model_path=yolo_model_path, caption_model_path=caption_model_path)
|
||||
class ParseRequest(BaseModel):
|
||||
base64_image: str
|
||||
|
||||
@@ -38,10 +41,12 @@ class ParseRequest(BaseModel):
|
||||
async def parse(parse_request: ParseRequest):
|
||||
print('start parsing...')
|
||||
start = time.time()
|
||||
dino_labled_img, parsed_content_list = omniparser.parse(parse_request.base64_image)
|
||||
# dino_labled_img, parsed_content_list = omniparser.parse(parse_request.base64_image)
|
||||
parsed_content_list = vision_agent(parse_request.base64_image)
|
||||
|
||||
latency = time.time() - start
|
||||
print('time:', latency)
|
||||
return {"som_image_base64": dino_labled_img, "parsed_content_list": parsed_content_list, 'latency': latency}
|
||||
return {"parsed_content_list": parsed_content_list, 'latency': latency}
|
||||
|
||||
@app.get("/probe/")
|
||||
async def root():
|
||||
|
||||
@@ -2,31 +2,13 @@ torch
|
||||
easyocr
|
||||
torchvision
|
||||
supervision==0.18.0
|
||||
openai==1.3.5
|
||||
transformers
|
||||
ultralytics==8.3.70
|
||||
azure-identity
|
||||
numpy==1.26.4
|
||||
opencv-python
|
||||
opencv-python-headless
|
||||
gradio
|
||||
dill
|
||||
accelerate
|
||||
pyautogui==0.9.54
|
||||
anthropic[bedrock,vertex]>=0.37.1
|
||||
pyxbrain==1.1.31
|
||||
timm
|
||||
einops==0.8.0
|
||||
paddlepaddle
|
||||
paddleocr
|
||||
ruff==0.6.7
|
||||
pre-commit==3.8.0
|
||||
pytest==8.3.3
|
||||
pytest-asyncio==0.23.6
|
||||
pyautogui==0.9.54
|
||||
streamlit>=1.38.0
|
||||
anthropic[bedrock,vertex]>=0.37.1
|
||||
jsonschema==4.22.0
|
||||
boto3>=1.28.57
|
||||
google-auth<3,>=2
|
||||
screeninfo
|
||||
uiautomation
|
||||
dashscope
|
||||
groq
|
||||
modelscope
|
||||
@@ -1,262 +0,0 @@
|
||||
from typing import List, Optional, Union, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from supervision.detection.core import Detections
|
||||
from supervision.draw.color import Color, ColorPalette
|
||||
|
||||
|
||||
class BoxAnnotator:
|
||||
"""
|
||||
A class for drawing bounding boxes on an image using detections provided.
|
||||
|
||||
Attributes:
|
||||
color (Union[Color, ColorPalette]): The color to draw the bounding box,
|
||||
can be a single color or a color palette
|
||||
thickness (int): The thickness of the bounding box lines, default is 2
|
||||
text_color (Color): The color of the text on the bounding box, default is white
|
||||
text_scale (float): The scale of the text on the bounding box, default is 0.5
|
||||
text_thickness (int): The thickness of the text on the bounding box,
|
||||
default is 1
|
||||
text_padding (int): The padding around the text on the bounding box,
|
||||
default is 5
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
|
||||
thickness: int = 3, # 1 for seeclick 2 for mind2web and 3 for demo
|
||||
text_color: Color = Color.BLACK,
|
||||
text_scale: float = 0.5, # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
|
||||
text_thickness: int = 2, #1, # 2 for demo
|
||||
text_padding: int = 10,
|
||||
avoid_overlap: bool = True,
|
||||
):
|
||||
self.color: Union[Color, ColorPalette] = color
|
||||
self.thickness: int = thickness
|
||||
self.text_color: Color = text_color
|
||||
self.text_scale: float = text_scale
|
||||
self.text_thickness: int = text_thickness
|
||||
self.text_padding: int = text_padding
|
||||
self.avoid_overlap: bool = avoid_overlap
|
||||
|
||||
def annotate(
|
||||
self,
|
||||
scene: np.ndarray,
|
||||
detections: Detections,
|
||||
labels: Optional[List[str]] = None,
|
||||
skip_label: bool = False,
|
||||
image_size: Optional[Tuple[int, int]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Draws bounding boxes on the frame using the detections provided.
|
||||
|
||||
Args:
|
||||
scene (np.ndarray): The image on which the bounding boxes will be drawn
|
||||
detections (Detections): The detections for which the
|
||||
bounding boxes will be drawn
|
||||
labels (Optional[List[str]]): An optional list of labels
|
||||
corresponding to each detection. If `labels` are not provided,
|
||||
corresponding `class_id` will be used as label.
|
||||
skip_label (bool): Is set to `True`, skips bounding box label annotation.
|
||||
Returns:
|
||||
np.ndarray: The image with the bounding boxes drawn on it
|
||||
|
||||
Example:
|
||||
```python
|
||||
import supervision as sv
|
||||
|
||||
classes = ['person', ...]
|
||||
image = ...
|
||||
detections = sv.Detections(...)
|
||||
|
||||
box_annotator = sv.BoxAnnotator()
|
||||
labels = [
|
||||
f"{classes[class_id]} {confidence:0.2f}"
|
||||
for _, _, confidence, class_id, _ in detections
|
||||
]
|
||||
annotated_frame = box_annotator.annotate(
|
||||
scene=image.copy(),
|
||||
detections=detections,
|
||||
labels=labels
|
||||
)
|
||||
```
|
||||
"""
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
for i in range(len(detections)):
|
||||
x1, y1, x2, y2 = detections.xyxy[i].astype(int)
|
||||
class_id = (
|
||||
detections.class_id[i] if detections.class_id is not None else None
|
||||
)
|
||||
idx = class_id if class_id is not None else i
|
||||
color = (
|
||||
self.color.by_idx(idx)
|
||||
if isinstance(self.color, ColorPalette)
|
||||
else self.color
|
||||
)
|
||||
cv2.rectangle(
|
||||
img=scene,
|
||||
pt1=(x1, y1),
|
||||
pt2=(x2, y2),
|
||||
color=color.as_bgr(),
|
||||
thickness=self.thickness,
|
||||
)
|
||||
if skip_label:
|
||||
continue
|
||||
|
||||
text = (
|
||||
f"{class_id}"
|
||||
if (labels is None or len(detections) != len(labels))
|
||||
else labels[i]
|
||||
)
|
||||
|
||||
text_width, text_height = cv2.getTextSize(
|
||||
text=text,
|
||||
fontFace=font,
|
||||
fontScale=self.text_scale,
|
||||
thickness=self.text_thickness,
|
||||
)[0]
|
||||
|
||||
if not self.avoid_overlap:
|
||||
text_x = x1 + self.text_padding
|
||||
text_y = y1 - self.text_padding
|
||||
|
||||
text_background_x1 = x1
|
||||
text_background_y1 = y1 - 2 * self.text_padding - text_height
|
||||
|
||||
text_background_x2 = x1 + 2 * self.text_padding + text_width
|
||||
text_background_y2 = y1
|
||||
# text_x = x1 - self.text_padding - text_width
|
||||
# text_y = y1 + self.text_padding + text_height
|
||||
# text_background_x1 = x1 - 2 * self.text_padding - text_width
|
||||
# text_background_y1 = y1
|
||||
# text_background_x2 = x1
|
||||
# text_background_y2 = y1 + 2 * self.text_padding + text_height
|
||||
else:
|
||||
text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 = get_optimal_label_pos(self.text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size)
|
||||
|
||||
cv2.rectangle(
|
||||
img=scene,
|
||||
pt1=(text_background_x1, text_background_y1),
|
||||
pt2=(text_background_x2, text_background_y2),
|
||||
color=color.as_bgr(),
|
||||
thickness=cv2.FILLED,
|
||||
)
|
||||
# import pdb; pdb.set_trace()
|
||||
box_color = color.as_rgb()
|
||||
luminance = 0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
|
||||
text_color = (0,0,0) if luminance > 160 else (255,255,255)
|
||||
cv2.putText(
|
||||
img=scene,
|
||||
text=text,
|
||||
org=(text_x, text_y),
|
||||
fontFace=font,
|
||||
fontScale=self.text_scale,
|
||||
# color=self.text_color.as_rgb(),
|
||||
color=text_color,
|
||||
thickness=self.text_thickness,
|
||||
lineType=cv2.LINE_AA,
|
||||
)
|
||||
return scene
|
||||
|
||||
|
||||
def box_area(box):
|
||||
return (box[2] - box[0]) * (box[3] - box[1])
|
||||
|
||||
def intersection_area(box1, box2):
|
||||
x1 = max(box1[0], box2[0])
|
||||
y1 = max(box1[1], box2[1])
|
||||
x2 = min(box1[2], box2[2])
|
||||
y2 = min(box1[3], box2[3])
|
||||
return max(0, x2 - x1) * max(0, y2 - y1)
|
||||
|
||||
def IoU(box1, box2, return_max=True):
|
||||
intersection = intersection_area(box1, box2)
|
||||
union = box_area(box1) + box_area(box2) - intersection
|
||||
if box_area(box1) > 0 and box_area(box2) > 0:
|
||||
ratio1 = intersection / box_area(box1)
|
||||
ratio2 = intersection / box_area(box2)
|
||||
else:
|
||||
ratio1, ratio2 = 0, 0
|
||||
if return_max:
|
||||
return max(intersection / union, ratio1, ratio2)
|
||||
else:
|
||||
return intersection / union
|
||||
|
||||
|
||||
def get_optimal_label_pos(text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size):
|
||||
""" check overlap of text and background detection box, and get_optimal_label_pos,
|
||||
pos: str, position of the text, must be one of 'top left', 'top right', 'outer left', 'outer right' TODO: if all are overlapping, return the last one, i.e. outer right
|
||||
Threshold: default to 0.3
|
||||
"""
|
||||
|
||||
def get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size):
|
||||
is_overlap = False
|
||||
for i in range(len(detections)):
|
||||
detection = detections.xyxy[i].astype(int)
|
||||
if IoU([text_background_x1, text_background_y1, text_background_x2, text_background_y2], detection) > 0.3:
|
||||
is_overlap = True
|
||||
break
|
||||
# check if the text is out of the image
|
||||
if text_background_x1 < 0 or text_background_x2 > image_size[0] or text_background_y1 < 0 or text_background_y2 > image_size[1]:
|
||||
is_overlap = True
|
||||
return is_overlap
|
||||
|
||||
# if pos == 'top left':
|
||||
text_x = x1 + text_padding
|
||||
text_y = y1 - text_padding
|
||||
|
||||
text_background_x1 = x1
|
||||
text_background_y1 = y1 - 2 * text_padding - text_height
|
||||
|
||||
text_background_x2 = x1 + 2 * text_padding + text_width
|
||||
text_background_y2 = y1
|
||||
is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
|
||||
if not is_overlap:
|
||||
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
|
||||
|
||||
# elif pos == 'outer left':
|
||||
text_x = x1 - text_padding - text_width
|
||||
text_y = y1 + text_padding + text_height
|
||||
|
||||
text_background_x1 = x1 - 2 * text_padding - text_width
|
||||
text_background_y1 = y1
|
||||
|
||||
text_background_x2 = x1
|
||||
text_background_y2 = y1 + 2 * text_padding + text_height
|
||||
is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
|
||||
if not is_overlap:
|
||||
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
|
||||
|
||||
|
||||
# elif pos == 'outer right':
|
||||
text_x = x2 + text_padding
|
||||
text_y = y1 + text_padding + text_height
|
||||
|
||||
text_background_x1 = x2
|
||||
text_background_y1 = y1
|
||||
|
||||
text_background_x2 = x2 + 2 * text_padding + text_width
|
||||
text_background_y2 = y1 + 2 * text_padding + text_height
|
||||
|
||||
is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
|
||||
if not is_overlap:
|
||||
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
|
||||
|
||||
# elif pos == 'top right':
|
||||
text_x = x2 - text_padding - text_width
|
||||
text_y = y1 - text_padding
|
||||
|
||||
text_background_x1 = x2 - 2 * text_padding - text_width
|
||||
text_background_y1 = y1 - 2 * text_padding - text_height
|
||||
|
||||
text_background_x2 = x2
|
||||
text_background_y2 = y1
|
||||
|
||||
is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
|
||||
if not is_overlap:
|
||||
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
|
||||
|
||||
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
|
||||
@@ -1,12 +1,14 @@
|
||||
import subprocess
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from modelscope import snapshot_download
|
||||
__WEIGHTS_DIR = Path("weights")
|
||||
MODEL_DIR = os.path.join(__WEIGHTS_DIR, "AI-ModelScope", "OmniParser-v2___0")
|
||||
def download():
|
||||
# 创建权重目录
|
||||
weights_dir = Path("weights")
|
||||
weights_dir.mkdir(exist_ok=True)
|
||||
# Create weights directory
|
||||
|
||||
__WEIGHTS_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# 需要下载的文件列表
|
||||
# List of files to download
|
||||
files = [
|
||||
"icon_detect/train_args.yaml",
|
||||
"icon_detect/model.pt",
|
||||
@@ -16,42 +18,24 @@ def download():
|
||||
"icon_caption/model.safetensors"
|
||||
]
|
||||
|
||||
# 检查并下载缺失的文件
|
||||
# Check and download missing files
|
||||
missing_files = []
|
||||
for file in files:
|
||||
file_path = weights_dir / file
|
||||
if not file_path.exists():
|
||||
file_path = os.path.join(MODEL_DIR, file)
|
||||
if not os.path.exists(file_path):
|
||||
missing_files.append(file)
|
||||
break
|
||||
|
||||
if not missing_files:
|
||||
print("已经检测到模型文件!")
|
||||
print("Model files already detected!")
|
||||
return
|
||||
|
||||
print(f"未检测到模型文件,需要下载 {len(missing_files)} 个文件")
|
||||
# 下载缺失的文件
|
||||
max_retries = 3 # 最大重试次数
|
||||
for file in missing_files:
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
print(f"正在下载: {file} (尝试 {attempt + 1}/{max_retries})")
|
||||
cmd = [
|
||||
"huggingface-cli",
|
||||
"download",
|
||||
"microsoft/OmniParser-v2.0",
|
||||
file,
|
||||
"--local-dir",
|
||||
"weights"
|
||||
]
|
||||
subprocess.run(cmd, check=True)
|
||||
break # 下载成功,跳出重试循环
|
||||
except subprocess.CalledProcessError as e:
|
||||
if attempt == max_retries - 1: # 最后一次尝试
|
||||
print(f"下载失败: {file},已达到最大重试次数")
|
||||
raise # 重新抛出异常
|
||||
print(f"下载失败: {file},正在重试...")
|
||||
continue
|
||||
|
||||
print("下载完成")
|
||||
snapshot_download(
|
||||
'AI-ModelScope/OmniParser-v2.0',
|
||||
cache_dir='weights'
|
||||
)
|
||||
|
||||
print("Download complete")
|
||||
|
||||
if __name__ == "__main__":
|
||||
download()
|
||||
@@ -1,31 +0,0 @@
|
||||
from util.utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box
|
||||
import torch
|
||||
from PIL import Image
|
||||
import io
|
||||
import base64
|
||||
from typing import Dict
|
||||
class Omniparser(object):
|
||||
def __init__(self, config: Dict):
|
||||
self.config = config
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
self.som_model = get_yolo_model(model_path=config['som_model_path'])
|
||||
self.caption_model_processor = get_caption_model_processor(model_name=config['caption_model_name'], model_name_or_path=config['caption_model_path'], device=device)
|
||||
print('Server initialized!')
|
||||
|
||||
def parse(self, image_base64: str):
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
print('image size:', image.size)
|
||||
|
||||
box_overlay_ratio = max(image.size) / 3200
|
||||
draw_bbox_config = {
|
||||
'text_scale': 0.8 * box_overlay_ratio,
|
||||
'text_thickness': max(int(2 * box_overlay_ratio), 1),
|
||||
'text_padding': max(int(3 * box_overlay_ratio), 1),
|
||||
'thickness': max(int(3 * box_overlay_ratio), 1),
|
||||
}
|
||||
|
||||
(text, ocr_bbox), _ = check_ocr_box(image, display_img=False, output_bb_format='xyxy', easyocr_args={'text_threshold': 0.8}, use_paddleocr=False)
|
||||
dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image, self.som_model, BOX_TRESHOLD = self.config['BOX_TRESHOLD'], output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text,use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128)
|
||||
|
||||
return dino_labled_img, parsed_content_list
|
||||
@@ -2,7 +2,6 @@ import os
|
||||
import shlex
|
||||
import subprocess
|
||||
import threading
|
||||
import traceback
|
||||
import pyautogui
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
540
util/utils.py
540
util/utils.py
@@ -1,540 +0,0 @@
|
||||
# from ultralytics import YOLO
|
||||
import os
|
||||
import io
|
||||
import base64
|
||||
import time
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import json
|
||||
import requests
|
||||
# utility function
|
||||
import os
|
||||
from openai import AzureOpenAI
|
||||
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
# %matplotlib inline
|
||||
from matplotlib import pyplot as plt
|
||||
import easyocr
|
||||
from paddleocr import PaddleOCR
|
||||
reader = easyocr.Reader(['en', 'ch_sim'])
|
||||
paddle_ocr = PaddleOCR(
|
||||
lang='ch', # other lang also available
|
||||
use_angle_cls=False,
|
||||
use_gpu=False, # using cuda will conflict with pytorch in the same process
|
||||
show_log=False,
|
||||
max_batch_size=1024,
|
||||
use_dilation=True, # improves accuracy
|
||||
det_db_score_mode='slow', # improves accuracy
|
||||
rec_batch_num=1024)
|
||||
import time
|
||||
import base64
|
||||
|
||||
import os
|
||||
import ast
|
||||
import torch
|
||||
from typing import Tuple, List, Union
|
||||
from torchvision.ops import box_convert
|
||||
import re
|
||||
from torchvision.transforms import ToPILImage
|
||||
import supervision as sv
|
||||
import torchvision.transforms as T
|
||||
from util.box_annotator import BoxAnnotator
|
||||
|
||||
|
||||
def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
|
||||
if not device:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if model_name == "blip2":
|
||||
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
||||
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||
if device == 'cpu':
|
||||
model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
model_name_or_path, device_map=None, torch_dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
model_name_or_path, device_map=None, torch_dtype=torch.float16
|
||||
).to(device)
|
||||
elif model_name == "florence2":
|
||||
from transformers import AutoProcessor, AutoModelForCausalLM
|
||||
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
|
||||
if device == 'cpu':
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True).to(device)
|
||||
return {'model': model.to(device), 'processor': processor}
|
||||
|
||||
|
||||
def get_yolo_model(model_path):
|
||||
from ultralytics import YOLO
|
||||
# Load the model.
|
||||
model = YOLO(model_path)
|
||||
return model
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=128):
|
||||
# Number of samples per batch, --> 128 roughly takes 4 GB of GPU memory for florence v2 model
|
||||
to_pil = ToPILImage()
|
||||
if starting_idx:
|
||||
non_ocr_boxes = filtered_boxes[starting_idx:]
|
||||
else:
|
||||
non_ocr_boxes = filtered_boxes
|
||||
croped_pil_image = []
|
||||
for i, coord in enumerate(non_ocr_boxes):
|
||||
try:
|
||||
xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
|
||||
ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
|
||||
cropped_image = image_source[ymin:ymax, xmin:xmax, :]
|
||||
cropped_image = cv2.resize(cropped_image, (64, 64))
|
||||
croped_pil_image.append(to_pil(cropped_image))
|
||||
except:
|
||||
continue
|
||||
|
||||
model, processor = caption_model_processor['model'], caption_model_processor['processor']
|
||||
if not prompt:
|
||||
if 'florence' in model.config.model_type:
|
||||
prompt = "<CAPTION>"
|
||||
else:
|
||||
prompt = "The image shows"
|
||||
|
||||
generated_texts = []
|
||||
device = model.device
|
||||
for i in range(0, len(croped_pil_image), batch_size):
|
||||
start = time.time()
|
||||
batch = croped_pil_image[i:i+batch_size]
|
||||
t1 = time.time()
|
||||
if model.device.type == 'cuda':
|
||||
inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16)
|
||||
else:
|
||||
inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
|
||||
if 'florence' in model.config.model_type:
|
||||
generated_ids = model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=20,num_beams=1, do_sample=False)
|
||||
else:
|
||||
generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1) # temperature=0.01, do_sample=True,
|
||||
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
generated_text = [gen.strip() for gen in generated_text]
|
||||
generated_texts.extend(generated_text)
|
||||
|
||||
return generated_texts
|
||||
|
||||
|
||||
|
||||
def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
|
||||
to_pil = ToPILImage()
|
||||
if ocr_bbox:
|
||||
non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
|
||||
else:
|
||||
non_ocr_boxes = filtered_boxes
|
||||
croped_pil_image = []
|
||||
for i, coord in enumerate(non_ocr_boxes):
|
||||
xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
|
||||
ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
|
||||
cropped_image = image_source[ymin:ymax, xmin:xmax, :]
|
||||
croped_pil_image.append(to_pil(cropped_image))
|
||||
|
||||
model, processor = caption_model_processor['model'], caption_model_processor['processor']
|
||||
device = model.device
|
||||
messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}]
|
||||
prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
|
||||
batch_size = 5 # Number of samples per batch
|
||||
generated_texts = []
|
||||
|
||||
for i in range(0, len(croped_pil_image), batch_size):
|
||||
images = croped_pil_image[i:i+batch_size]
|
||||
image_inputs = [processor.image_processor(x, return_tensors="pt") for x in images]
|
||||
inputs ={'input_ids': [], 'attention_mask': [], 'pixel_values': [], 'image_sizes': []}
|
||||
texts = [prompt] * len(images)
|
||||
for i, txt in enumerate(texts):
|
||||
input = processor._convert_images_texts_to_inputs(image_inputs[i], txt, return_tensors="pt")
|
||||
inputs['input_ids'].append(input['input_ids'])
|
||||
inputs['attention_mask'].append(input['attention_mask'])
|
||||
inputs['pixel_values'].append(input['pixel_values'])
|
||||
inputs['image_sizes'].append(input['image_sizes'])
|
||||
max_len = max([x.shape[1] for x in inputs['input_ids']])
|
||||
for i, v in enumerate(inputs['input_ids']):
|
||||
inputs['input_ids'][i] = torch.cat([processor.tokenizer.pad_token_id * torch.ones(1, max_len - v.shape[1], dtype=torch.long), v], dim=1)
|
||||
inputs['attention_mask'][i] = torch.cat([torch.zeros(1, max_len - v.shape[1], dtype=torch.long), inputs['attention_mask'][i]], dim=1)
|
||||
inputs_cat = {k: torch.concatenate(v).to(device) for k, v in inputs.items()}
|
||||
|
||||
generation_args = {
|
||||
"max_new_tokens": 25,
|
||||
"temperature": 0.01,
|
||||
"do_sample": False,
|
||||
}
|
||||
generate_ids = model.generate(**inputs_cat, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
|
||||
# # remove input tokens
|
||||
generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:]
|
||||
response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
response = [res.strip('\n').strip() for res in response]
|
||||
generated_texts.extend(response)
|
||||
|
||||
return generated_texts
|
||||
|
||||
def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
|
||||
assert ocr_bbox is None or isinstance(ocr_bbox, List)
|
||||
|
||||
def box_area(box):
|
||||
return (box[2] - box[0]) * (box[3] - box[1])
|
||||
|
||||
def intersection_area(box1, box2):
|
||||
x1 = max(box1[0], box2[0])
|
||||
y1 = max(box1[1], box2[1])
|
||||
x2 = min(box1[2], box2[2])
|
||||
y2 = min(box1[3], box2[3])
|
||||
return max(0, x2 - x1) * max(0, y2 - y1)
|
||||
|
||||
def IoU(box1, box2):
|
||||
intersection = intersection_area(box1, box2)
|
||||
union = box_area(box1) + box_area(box2) - intersection + 1e-6
|
||||
if box_area(box1) > 0 and box_area(box2) > 0:
|
||||
ratio1 = intersection / box_area(box1)
|
||||
ratio2 = intersection / box_area(box2)
|
||||
else:
|
||||
ratio1, ratio2 = 0, 0
|
||||
return max(intersection / union, ratio1, ratio2)
|
||||
|
||||
def is_inside(box1, box2):
|
||||
# return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
|
||||
intersection = intersection_area(box1, box2)
|
||||
ratio1 = intersection / box_area(box1)
|
||||
return ratio1 > 0.95
|
||||
|
||||
boxes = boxes.tolist()
|
||||
filtered_boxes = []
|
||||
if ocr_bbox:
|
||||
filtered_boxes.extend(ocr_bbox)
|
||||
# print('ocr_bbox!!!', ocr_bbox)
|
||||
for i, box1 in enumerate(boxes):
|
||||
# if not any(IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2) for j, box2 in enumerate(boxes) if i != j):
|
||||
is_valid_box = True
|
||||
for j, box2 in enumerate(boxes):
|
||||
# keep the smaller box
|
||||
if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
|
||||
is_valid_box = False
|
||||
break
|
||||
if is_valid_box:
|
||||
# add the following 2 lines to include ocr bbox
|
||||
if ocr_bbox:
|
||||
# only add the box if it does not overlap with any ocr bbox
|
||||
if not any(IoU(box1, box3) > iou_threshold and not is_inside(box1, box3) for k, box3 in enumerate(ocr_bbox)):
|
||||
filtered_boxes.append(box1)
|
||||
else:
|
||||
filtered_boxes.append(box1)
|
||||
return torch.tensor(filtered_boxes)
|
||||
|
||||
|
||||
def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
|
||||
'''
|
||||
ocr_bbox format: [{'type': 'text', 'bbox':[x,y], 'interactivity':False, 'content':str }, ...]
|
||||
boxes format: [{'type': 'icon', 'bbox':[x,y], 'interactivity':True, 'content':None }, ...]
|
||||
|
||||
'''
|
||||
assert ocr_bbox is None or isinstance(ocr_bbox, List)
|
||||
|
||||
def box_area(box):
|
||||
return (box[2] - box[0]) * (box[3] - box[1])
|
||||
|
||||
def intersection_area(box1, box2):
|
||||
x1 = max(box1[0], box2[0])
|
||||
y1 = max(box1[1], box2[1])
|
||||
x2 = min(box1[2], box2[2])
|
||||
y2 = min(box1[3], box2[3])
|
||||
return max(0, x2 - x1) * max(0, y2 - y1)
|
||||
|
||||
def IoU(box1, box2):
|
||||
intersection = intersection_area(box1, box2)
|
||||
union = box_area(box1) + box_area(box2) - intersection + 1e-6
|
||||
if box_area(box1) > 0 and box_area(box2) > 0:
|
||||
ratio1 = intersection / box_area(box1)
|
||||
ratio2 = intersection / box_area(box2)
|
||||
else:
|
||||
ratio1, ratio2 = 0, 0
|
||||
return max(intersection / union, ratio1, ratio2)
|
||||
|
||||
def is_inside(box1, box2):
|
||||
# return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
|
||||
intersection = intersection_area(box1, box2)
|
||||
ratio1 = intersection / box_area(box1)
|
||||
return ratio1 > 0.80
|
||||
|
||||
# boxes = boxes.tolist()
|
||||
filtered_boxes = []
|
||||
if ocr_bbox:
|
||||
filtered_boxes.extend(ocr_bbox)
|
||||
# print('ocr_bbox!!!', ocr_bbox)
|
||||
for i, box1_elem in enumerate(boxes):
|
||||
box1 = box1_elem['bbox']
|
||||
is_valid_box = True
|
||||
for j, box2_elem in enumerate(boxes):
|
||||
# keep the smaller box
|
||||
box2 = box2_elem['bbox']
|
||||
if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
|
||||
is_valid_box = False
|
||||
break
|
||||
if is_valid_box:
|
||||
if ocr_bbox:
|
||||
# keep yolo boxes + prioritize ocr label
|
||||
box_added = False
|
||||
ocr_labels = ''
|
||||
for box3_elem in ocr_bbox:
|
||||
if not box_added:
|
||||
box3 = box3_elem['bbox']
|
||||
if is_inside(box3, box1): # ocr inside icon
|
||||
# box_added = True
|
||||
# delete the box3_elem from ocr_bbox
|
||||
try:
|
||||
# gather all ocr labels
|
||||
ocr_labels += box3_elem['content'] + ' '
|
||||
filtered_boxes.remove(box3_elem)
|
||||
except:
|
||||
continue
|
||||
# break
|
||||
elif is_inside(box1, box3): # icon inside ocr, don't added this icon box, no need to check other ocr bbox bc no overlap between ocr bbox, icon can only be in one ocr box
|
||||
box_added = True
|
||||
break
|
||||
else:
|
||||
continue
|
||||
if not box_added:
|
||||
if ocr_labels:
|
||||
filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': ocr_labels, 'source':'box_yolo_content_ocr'})
|
||||
else:
|
||||
filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None, 'source':'box_yolo_content_yolo'})
|
||||
else:
|
||||
filtered_boxes.append(box1)
|
||||
return filtered_boxes # torch.tensor(filtered_boxes)
|
||||
|
||||
|
||||
def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
|
||||
transform = T.Compose(
|
||||
[
|
||||
T.RandomResize([800], max_size=1333),
|
||||
T.ToTensor(),
|
||||
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||
]
|
||||
)
|
||||
image_source = Image.open(image_path).convert("RGB")
|
||||
image = np.asarray(image_source)
|
||||
image_transformed, _ = transform(image_source, None)
|
||||
return image, image_transformed
|
||||
|
||||
|
||||
def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str], text_scale: float,
|
||||
text_padding=5, text_thickness=2, thickness=3) -> np.ndarray:
|
||||
"""
|
||||
This function annotates an image with bounding boxes and labels.
|
||||
|
||||
Parameters:
|
||||
image_source (np.ndarray): The source image to be annotated.
|
||||
boxes (torch.Tensor): A tensor containing bounding box coordinates. in cxcywh format, pixel scale
|
||||
logits (torch.Tensor): A tensor containing confidence scores for each bounding box.
|
||||
phrases (List[str]): A list of labels for each bounding box.
|
||||
text_scale (float): The scale of the text to be displayed. 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
|
||||
|
||||
Returns:
|
||||
np.ndarray: The annotated image.
|
||||
"""
|
||||
h, w, _ = image_source.shape
|
||||
boxes = boxes * torch.Tensor([w, h, w, h])
|
||||
xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
|
||||
xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
|
||||
detections = sv.Detections(xyxy=xyxy)
|
||||
|
||||
labels = [f"{phrase}" for phrase in range(boxes.shape[0])]
|
||||
|
||||
box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,text_thickness=text_thickness,thickness=thickness) # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
|
||||
annotated_frame = image_source.copy()
|
||||
annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w,h))
|
||||
|
||||
label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
|
||||
return annotated_frame, label_coordinates
|
||||
|
||||
|
||||
def predict(model, image, caption, box_threshold, text_threshold):
|
||||
""" Use huggingface model to replace the original model
|
||||
"""
|
||||
model, processor = model['model'], model['processor']
|
||||
device = model.device
|
||||
|
||||
inputs = processor(images=image, text=caption, return_tensors="pt").to(device)
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
results = processor.post_process_grounded_object_detection(
|
||||
outputs,
|
||||
inputs.input_ids,
|
||||
box_threshold=box_threshold, # 0.4,
|
||||
text_threshold=text_threshold, # 0.3,
|
||||
target_sizes=[image.size[::-1]]
|
||||
)[0]
|
||||
boxes, logits, phrases = results["boxes"], results["scores"], results["labels"]
|
||||
return boxes, logits, phrases
|
||||
|
||||
|
||||
def predict_yolo(model, image, box_threshold, imgsz, scale_img, iou_threshold=0.7):
|
||||
""" Use huggingface model to replace the original model
|
||||
"""
|
||||
# model = model['model']
|
||||
if scale_img:
|
||||
result = model.predict(
|
||||
source=image,
|
||||
conf=box_threshold,
|
||||
imgsz=imgsz,
|
||||
iou=iou_threshold, # default 0.7
|
||||
)
|
||||
else:
|
||||
result = model.predict(
|
||||
source=image,
|
||||
conf=box_threshold,
|
||||
iou=iou_threshold, # default 0.7
|
||||
)
|
||||
boxes = result[0].boxes.xyxy#.tolist() # in pixel space
|
||||
conf = result[0].boxes.conf
|
||||
phrases = [str(i) for i in range(len(boxes))]
|
||||
|
||||
return boxes, conf, phrases
|
||||
|
||||
def int_box_area(box, w, h):
|
||||
x1, y1, x2, y2 = box
|
||||
int_box = [int(x1*w), int(y1*h), int(x2*w), int(y2*h)]
|
||||
area = (int_box[2] - int_box[0]) * (int_box[3] - int_box[1])
|
||||
return area
|
||||
|
||||
def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=128):
|
||||
"""Process either an image path or Image object
|
||||
|
||||
Args:
|
||||
image_source: Either a file path (str) or PIL Image object
|
||||
...
|
||||
"""
|
||||
if isinstance(image_source, str):
|
||||
image_source = Image.open(image_source)
|
||||
image_source = image_source.convert("RGB") # for CLIP
|
||||
w, h = image_source.size
|
||||
if not imgsz:
|
||||
imgsz = (h, w)
|
||||
# print('image size:', w, h)
|
||||
xyxy, logits, phrases = predict_yolo(model=model, image=image_source, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1)
|
||||
xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
|
||||
image_source = np.asarray(image_source)
|
||||
phrases = [str(i) for i in range(len(phrases))]
|
||||
|
||||
# annotate the image with labels
|
||||
if ocr_bbox:
|
||||
ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
|
||||
ocr_bbox=ocr_bbox.tolist()
|
||||
else:
|
||||
print('no ocr bbox!!!')
|
||||
ocr_bbox = None
|
||||
|
||||
ocr_bbox_elem = [{'type': 'text', 'bbox':box, 'interactivity':False, 'content':txt, 'source': 'box_ocr_content_ocr'} for box, txt in zip(ocr_bbox, ocr_text) if int_box_area(box, w, h) > 0]
|
||||
xyxy_elem = [{'type': 'icon', 'bbox':box, 'interactivity':True, 'content':None} for box in xyxy.tolist() if int_box_area(box, w, h) > 0]
|
||||
filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)
|
||||
|
||||
# sort the filtered_boxes so that the one with 'content': None is at the end, and get the index of the first 'content': None
|
||||
filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None)
|
||||
# get the index of the first 'content': None
|
||||
starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
|
||||
filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem])
|
||||
print('len(filtered_boxes):', len(filtered_boxes), starting_idx)
|
||||
|
||||
# get parsed icon local semantics
|
||||
time1 = time.time()
|
||||
if use_local_semantics:
|
||||
caption_model = caption_model_processor['model']
|
||||
if 'phi3_v' in caption_model.config.model_type:
|
||||
parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor)
|
||||
else:
|
||||
parsed_content_icon = get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=prompt,batch_size=batch_size)
|
||||
ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
|
||||
icon_start = len(ocr_text)
|
||||
parsed_content_icon_ls = []
|
||||
# fill the filtered_boxes_elem None content with parsed_content_icon in order
|
||||
for i, box in enumerate(filtered_boxes_elem):
|
||||
if box['content'] is None:
|
||||
box['content'] = parsed_content_icon.pop(0)
|
||||
for i, txt in enumerate(parsed_content_icon):
|
||||
parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
|
||||
parsed_content_merged = ocr_text + parsed_content_icon_ls
|
||||
else:
|
||||
ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
|
||||
parsed_content_merged = ocr_text
|
||||
print('time to get parsed content:', time.time()-time1)
|
||||
|
||||
filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh")
|
||||
|
||||
phrases = [i for i in range(len(filtered_boxes))]
|
||||
|
||||
# draw boxes
|
||||
if draw_bbox_config:
|
||||
annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config)
|
||||
else:
|
||||
annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, text_scale=text_scale, text_padding=text_padding)
|
||||
|
||||
pil_img = Image.fromarray(annotated_frame)
|
||||
buffered = io.BytesIO()
|
||||
pil_img.save(buffered, format="PNG")
|
||||
encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
|
||||
if output_coord_in_ratio:
|
||||
label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()}
|
||||
assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]
|
||||
|
||||
return encoded_image, label_coordinates, filtered_boxes_elem
|
||||
|
||||
|
||||
def get_xywh(input):
|
||||
x, y, w, h = input[0][0], input[0][1], input[2][0] - input[0][0], input[2][1] - input[0][1]
|
||||
x, y, w, h = int(x), int(y), int(w), int(h)
|
||||
return x, y, w, h
|
||||
|
||||
def get_xyxy(input):
|
||||
x, y, xp, yp = input[0][0], input[0][1], input[2][0], input[2][1]
|
||||
x, y, xp, yp = int(x), int(y), int(xp), int(yp)
|
||||
return x, y, xp, yp
|
||||
|
||||
def get_xywh_yolo(input):
|
||||
x, y, w, h = input[0], input[1], input[2] - input[0], input[3] - input[1]
|
||||
x, y, w, h = int(x), int(y), int(w), int(h)
|
||||
return x, y, w, h
|
||||
|
||||
def check_ocr_box(image_source: Union[str, Image.Image], display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
|
||||
if isinstance(image_source, str):
|
||||
image_source = Image.open(image_source)
|
||||
if image_source.mode == 'RGBA':
|
||||
# Convert RGBA to RGB to avoid alpha channel issues
|
||||
image_source = image_source.convert('RGB')
|
||||
image_np = np.array(image_source)
|
||||
w, h = image_source.size
|
||||
if use_paddleocr:
|
||||
if easyocr_args is None:
|
||||
text_threshold = 0.5
|
||||
else:
|
||||
text_threshold = easyocr_args['text_threshold']
|
||||
result = paddle_ocr.ocr(image_np, cls=False)[0]
|
||||
coord = [item[0] for item in result if item[1][1] > text_threshold]
|
||||
text = [item[1][0] for item in result if item[1][1] > text_threshold]
|
||||
else: # EasyOCR
|
||||
if easyocr_args is None:
|
||||
easyocr_args = {}
|
||||
result = reader.readtext(image_np, **easyocr_args)
|
||||
coord = [item[0] for item in result]
|
||||
text = [item[1] for item in result]
|
||||
if display_img:
|
||||
opencv_img = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
|
||||
bb = []
|
||||
for item in coord:
|
||||
x, y, a, b = get_xywh(item)
|
||||
bb.append((x, y, a, b))
|
||||
cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2)
|
||||
# matplotlib expects RGB
|
||||
plt.imshow(cv2.cvtColor(opencv_img, cv2.COLOR_BGR2RGB))
|
||||
else:
|
||||
if output_bb_format == 'xywh':
|
||||
bb = [get_xywh(item) for item in coord]
|
||||
elif output_bb_format == 'xyxy':
|
||||
bb = [get_xyxy(item) for item in coord]
|
||||
return (text, bb), goal_filtering
|
||||
Reference in New Issue
Block a user