Merge branch 'dev-test'

This commit is contained in:
yuruo
2025-03-13 17:20:32 +08:00
21 changed files with 724 additions and 1337 deletions

View File

@@ -77,10 +77,15 @@ For supported vendors and models, please refer to this [link](./SUPPORT_MODEL.md
### 🔧 CUDA Version Mismatch
If you see the error: "GPU driver incompatible, please install appropriate torch version according to readme", it indicates a driver incompatibility. You can either:
1. Run with CPU only (slower but functional)
2. Check your torch version with `pip list`
3. Check supported CUDA versions on the [official website](https://pytorch.org/get-started/locally/)
4. Reinstall Nvidia drivers
1. Run pip list to check the torch version;
2. Check supported CUDA versions on the official website;
3. Copy the official torch installation command and reinstall torch for your CUDA version.
For example, if your CUDA version is 12.4, install torch using this command:
```bash
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
```
## 🤝 Contributing

View File

@@ -78,8 +78,20 @@ python main.py
1. 运行`pip list`查看torch版本
2. 从[官网](https://pytorch.org/get-started/locally/)查看支持的cuda版本
3. 重新安装Nvidia驱动。
3. 卸载已安装的 torch 和 torchvision
3. 复制官方的 torch 安装命令,重新安装适合自己 cuda 版本的 torch。
比如我的 cuda 版本为 12.4,需要按照如下命令来安装 torch
```bash
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
```
### 模型无法下载
多半是被墙了,可以从百度网盘直接下载模型。
通过网盘分享的文件weights.zip
链接: https://pan.baidu.com/s/1Tj8sZZK9_QI7whZV93vb0w?pwd=dyeu 提取码: dyeu
## 🤝 参与共建

View File

@@ -1,5 +1,3 @@
| Vendor-en | Vendor-ch | Model | base-url |
| --- | --- | --- | --- |
| Alibaba Cloud Bailian | 阿里云百炼 | deepseek-r1 | https://dashscope.aliyuncs.com/compatible-mode/v1 |
| Alibaba Cloud Bailian | 阿里云百炼 | deepseek-v3 | https://dashscope.aliyuncs.com/compatible-mode/v1 |
| deepseek | deepseek官方 | deepseek-chat | https://api.deepseek.com |
| openainext | openainext | gpt-4o-2024-11-20 | https://api.openai-next.com/v1 |

View File

@@ -0,0 +1,8 @@
class BaseAgent:
    """Common base for chat-driven agents.

    Concrete agents are expected to fill in SYSTEM_PROMPT and override chat().
    """

    def __init__(self, *args, **kwargs):
        # Subclasses replace this with their task-specific prompt text.
        self.SYSTEM_PROMPT = ""

    def chat(self, messages):
        """Handle a list of chat messages; the base implementation is a no-op."""
        return None

View File

@@ -19,26 +19,31 @@ class OmniParserClient:
response_json = response.json()
print('omniparser latency:', response_json['latency'])
som_image_data = base64.b64decode(response_json['som_image_base64'])
screenshot_path_uuid = Path(screenshot_path).stem.replace("screenshot_", "")
som_screenshot_path = f"{OUTPUT_DIR}/screenshot_som_{screenshot_path_uuid}.png"
with open(som_screenshot_path, "wb") as f:
f.write(som_image_data)
# som_image_data = base64.b64decode(response_json['som_image_base64'])
# screenshot_path_uuid = Path(screenshot_path).stem.replace("screenshot_", "")
# som_screenshot_path = f"{OUTPUT_DIR}/screenshot_som_{screenshot_path_uuid}.png"
# with open(som_screenshot_path, "wb") as f:
# f.write(som_image_data)
response_json['width'] = screenshot.size[0]
response_json['height'] = screenshot.size[1]
response_json['original_screenshot_base64'] = image_base64
response_json['screenshot_uuid'] = screenshot_path_uuid
# response_json['screenshot_uuid'] = screenshot_path_uuid
response_json = self.reformat_messages(response_json)
return response_json
def reformat_messages(self, response_json: dict):
    """Render the parsed UI elements into a readable screen description.

    Builds a ``screen_info`` string with one line per detected element
    (ID, bounding-box coordinates, OCR text and caption) and stores it back
    on the response dict. Dead commented-out code from the old dict-based
    element format has been removed, along with the unused enumerate index.

    Args:
        response_json: parser response containing 'parsed_content_list', a
            list of element objects exposing element_id/coordinates/text/caption.

    Returns:
        The same dict, with 'screen_info' added.
    """
    screen_info = ""
    for element in response_json["parsed_content_list"]:
        screen_info += f'ID: {element.element_id}, '
        screen_info += f'Coordinates: {element.coordinates}, '
        # An empty OCR result is rendered as a single space placeholder.
        screen_info += f'Text: {element.text if len(element.text) else " "}, '
        screen_info += f'Caption: {element.caption}. '
        screen_info += "\n"
    response_json['screen_info'] = screen_info
    return response_json

View File

@@ -0,0 +1,40 @@
from gradio_ui.agent.base_agent import BaseAgent
from xbrain.core.chat import run
class TaskPlanAgent(BaseAgent):
    """Turns a free-form user request into an ordered task plan via the LLM."""

    def __init__(self, output_callback):
        # Callback used to stream progress and results back to the chat UI.
        self.output_callback = output_callback

    def __call__(self, user_task: str):
        """Ask the model to decompose *user_task* and return the plan text."""
        self.output_callback("正在规划任务中...", sender="bot")
        messages = [{"role": "user", "content": user_task}]
        plan = run(messages, user_prompt=system_prompt)
        self.output_callback(plan, sender="bot")
        return plan
# System prompt for the planner call above: instructs the model to decompose a
# user request into a numbered list of desktop-operation steps, with two
# few-shot examples. Content is a runtime string and must not be altered here.
system_prompt = """
### 目标 ###
你是电脑任务规划专家,根据用户的需求,规划出要执行的任务。
##########
### 输入 ###
用户的需求,通常是一个文本描述。
##########
### 输出 ###
一系列任务,包括任务名称
##########
### 例子 ###
案例1
输入获取AI新闻
输出:
1. 打开浏览器
2. 打开百度首页
3. 搜索“AI”相关内容
4. 浏览搜索结果,记录搜索结果
5. 返回搜索内容
案例2
输入删除桌面的txt文件
输出:
1. 进入桌面
2. 寻找所有txt文件
3. 右键txt文件选择删除
"""

View File

@@ -0,0 +1,198 @@
import json
import uuid
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam, BetaUsage
from PIL import ImageDraw
import base64
from io import BytesIO
from pydantic import BaseModel, Field
from gradio_ui.agent.base_agent import BaseAgent
from xbrain.core.chat import run
import platform
import re
class TaskRunAgent(BaseAgent):
    """Decides the next GUI action for the current task from a parsed screen.

    Each call sends the screenshot plus the parsed element list to the VLM
    (with a structured response schema) and converts the reply into
    Anthropic-style tool_use blocks that the executor can run.
    """

    def __init__(self, output_callback):
        # Initialize BaseAgent state (SYSTEM_PROMPT) for consistency.
        super().__init__()
        # Callback used to stream images/markdown back to the chat UI.
        self.output_callback = output_callback
        self.OUTPUT_DIR = "./tmp/outputs"

    def __call__(self, task_plan, parsed_screen):
        """Run one planning step.

        Args:
            task_plan: text description of the current task.
            parsed_screen: dict with 'parsed_content_list' (UI element objects),
                'image' (PIL image of the screen), 'width' and 'height'.

        Returns:
            Tuple of (BetaMessage with text + tool_use blocks, raw response dict).
        """
        screen_info = str(parsed_screen['parsed_content_list'])
        self.SYSTEM_PROMPT = system_prompt.format(task_plan=task_plan,
                                                  device=self.get_device(),
                                                  screen_info=screen_info)
        # Encode the current screenshot so it can be sent inline to the VLM.
        img_to_show = parsed_screen["image"]
        buffered = BytesIO()
        img_to_show.save(buffered, format="PNG")
        img_to_show_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
        vlm_response = run([
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "图片是当前屏幕的截图,请根据图片以及解析出来的元素,确定下一步操作"},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{img_to_show_base64}"
                        }
                    }
                ]
            }
        ], user_prompt=self.SYSTEM_PROMPT, response_format=TaskRunAgentResponse)
        vlm_response_json = json.loads(vlm_response)
        if "box_id" in vlm_response_json:
            try:
                # Convert the chosen box to its centroid and mark it on the image.
                bbox = parsed_screen["parsed_content_list"][int(vlm_response_json["box_id"])].coordinates
                vlm_response_json["box_centroid_coordinate"] = [int((bbox[0] + bbox[2]) / 2), int((bbox[1] + bbox[3]) / 2)]
                x, y = vlm_response_json["box_centroid_coordinate"]
                radius = 10
                draw = ImageDraw.Draw(img_to_show)
                draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill='red')
                draw.ellipse((x - radius*3, y - radius*3, x + radius*3, y + radius*3), fill=None, outline='red', width=2)
                buffered = BytesIO()
                img_to_show.save(buffered, format="PNG")
                img_to_show_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
            except Exception as e:
                # box_id may be None or out of range; fall back to the
                # unannotated screenshot instead of aborting the step.
                print(f"Error parsing: {vlm_response_json}")
                print(f"Error: {e}")
        self.output_callback(f'<img src="data:image/png;base64,{img_to_show_base64}">', sender="bot")
        self.output_callback(
            f'<details>'
            f' <summary>Parsed Screen elements by OmniParser</summary>'
            f' <pre>{screen_info}</pre>'
            f'</details>',
            sender="bot"
        )
        # Translate the structured reply into executor tool_use blocks.
        response_content = [BetaTextBlock(text=vlm_response_json["reasoning"], type='text')]
        if 'box_centroid_coordinate' in vlm_response_json:
            move_cursor_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
                                                 input={'action': 'mouse_move', 'coordinate': vlm_response_json["box_centroid_coordinate"]},
                                                 name='computer', type='tool_use')
            response_content.append(move_cursor_block)
        if vlm_response_json["next_action"] == "None":
            print("Task paused/completed.")
        elif vlm_response_json["next_action"] == "type":
            sim_content_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
                                                 input={'action': vlm_response_json["next_action"], 'text': vlm_response_json["value"]},
                                                 name='computer', type='tool_use')
            response_content.append(sim_content_block)
        else:
            sim_content_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
                                                 input={'action': vlm_response_json["next_action"]},
                                                 name='computer', type='tool_use')
            response_content.append(sim_content_block)
        response_message = BetaMessage(id=f'toolu_{uuid.uuid4()}', content=response_content, model='', role='assistant', type='message', stop_reason='tool_use', usage=BetaUsage(input_tokens=0, output_tokens=0))
        return response_message, vlm_response_json

    def get_device(self):
        """Return a human-readable OS name/version string for the prompt."""
        system = platform.system()
        if system == "Windows":
            device = f"Windows {platform.release()}"
        elif system == "Darwin":
            device = f"Mac OS {platform.mac_ver()[0]}"
        elif system == "Linux":
            device = f"Linux {platform.release()}"
        else:
            device = system
        return device

    def extract_data(self, input_string, data_type):
        """Extract the body of the first ```<data_type> fenced code block.

        Returns the input unchanged when no fence is found; an unterminated
        fence (no closing ```) is tolerated via the `|$` alternative.
        """
        pattern = f"```{data_type}" + r"(.*?)(```|$)"
        # re.DOTALL allows '.' to match newlines as well
        matches = re.findall(pattern, input_string, re.DOTALL)
        return matches[0][0].strip() if matches else input_string
class TaskRunAgentResponse(BaseModel):
    """Structured output schema for TaskRunAgent's VLM call (response_format).

    Field descriptions are in Chinese because they are shipped to the model
    as part of the JSON schema.
    """
    # Model's step-by-step analysis of the screen and task history.
    reasoning: str = Field(description="描述当前屏幕上的内容,考虑历史记录,然后描述您如何实现任务的逐步思考,一次从可用操作中选择一个操作。")
    # One action name from the enum below; "None" means pause/finish.
    next_action: str = Field(
        description="选择一个操作类型如果找不到合适的操作请选择None",
        json_schema_extra={
            "enum": ["type", "left_click", "right_click", "double_click",
                     "hover", "scroll_up", "scroll_down", "wait", "None"]
        }
    )
    # NOTE(review): annotated `int` but defaults to None — likely should be
    # Optional[int]; confirm against pydantic validation behavior.
    box_id: int = Field(description="要操作的框ID当next_action为left_click、right_click、double_click、hover时提供否则为None", default=None)
    # Text to type when next_action == "type"; None otherwise.
    value: str = Field(description="仅当next_action为type时提供否则为None", default=None)
# System prompt for TaskRunAgent: asks the VLM for exactly one next action in
# JSON form. Fixes relative to the previous version: the last example's closing
# ``` fence was missing, and notes 3/5/8 referred to the old "Next Action" key
# while the schema and examples use "next_action".
system_prompt = """
### 目标 ###
你是一个自动化规划师,需要完成用户的任务。请你根据屏幕信息确定【下一步操作】,以完成任务:
你当前的任务是:
{task_plan}
以下是用yolo检测的当前屏幕上的所有元素
{screen_info}
##########
### 注意 ###
1. 每次应该只给出一个操作。
2. 应该对当前屏幕进行分析,通过查看历史记录反思已完成的工作,然后描述您如何实现任务的逐步思考。
3. 在"next_action"中附上下一步操作预测。
4. 不应包括其他操作,例如键盘快捷键。
5. 当任务完成时不要完成额外的操作。你应该在json字段中说"next_action": "None"
6. 任务涉及购买多个产品或浏览多个页面。你应该将其分解为子目标,并按照说明的顺序一个一个地完成每个子目标。
7. 避免连续多次选择相同的操作/元素,如果发生这种情况,反思自己,可能出了什么问题,并预测不同的操作。
8. 如果您收到登录信息页面或验证码页面的提示或者您认为下一步操作需要用户许可您应该在json字段中说"next_action": "None"
9. 你只能使用鼠标和键盘与计算机进行交互。
10. 你只能与桌面图形用户界面交互(无法访问终端或应用程序菜单)。
11. 如果当前屏幕没有显示任何可操作的元素并且当前屏幕不能下滑请返回None。
##########
### 输出格式 ###
```json
{{
    "reasoning": str, # 描述当前屏幕上的内容,考虑历史记录,然后描述您如何实现任务的逐步思考,一次从可用操作中选择一个操作。
    "next_action": "action_type, action description" | "None" # 一次一个操作,简短精确地描述它。
    "box_id": n,
    "value": "xxx" # 仅当操作为type时提供value字段否则不包括value键
}}
```
【next_action】仅包括下面之一
- type:输入一串文本。
- left_click:将鼠标移动到框ID并左键单击。
- right_click:将鼠标移动到框ID并右键单击。
- double_click:将鼠标移动到框ID并双击。
- hover:将鼠标移动到框ID。
- scroll_up:向上滚动屏幕以查看之前的内容。
- scroll_down:当所需按钮不可见或您需要查看更多内容时向下滚动屏幕。
- wait:等待1秒钟让设备加载或响应。
##########
### 案例 ###
一个例子:
```json
{{
    "reasoning": "当前屏幕显示亚马逊的谷歌搜索结果在之前的操作中我已经在谷歌上搜索了亚马逊。然后我需要点击第一个搜索结果以转到amazon.com。",
    "next_action": "left_click",
    "box_id": m
}}
```
另一个例子:
```json
{{
    "reasoning": "当前屏幕显示亚马逊的首页。没有之前的操作。因此,我需要在搜索栏中输入"Apple watch"",
    "next_action": "type",
    "box_id": n,
    "value": "Apple watch"
}}
```
另一个例子:
```json
{{
    "reasoning": "当前屏幕没有显示'提交'按钮,我需要向下滚动以查看按钮是否可用。",
    "next_action": "scroll_down"
}}
```
"""

View File

@@ -0,0 +1,299 @@
from typing import List, Optional
import cv2
import torch
from ultralytics import YOLO
from transformers import AutoModelForCausalLM, AutoProcessor
import easyocr
import supervision as sv
import numpy as np
import time
from pydantic import BaseModel
import base64
from PIL import Image
class UIElement(BaseModel):
    """A single UI element detected on a screenshot."""
    element_id: int  # index of the element in detection order
    coordinates: list[float]  # bounding box [x1, y1, x2, y2] in pixels (YOLO xyxy)
    caption: Optional[str] = None  # caption-model description of the element
    text: Optional[str] = None  # OCR-extracted text, if any
class VisionAgent:
    """Detects, OCRs and captions UI elements in a screenshot.

    Pipeline: YOLO detection -> containment filtering -> EasyOCR text
    extraction -> Florence-2 captioning, producing a list of UIElement.

    Fix: _detect_objects previously returned a bare [] when no boxes were
    found, which crashed the two-value unpack in analyze_image; it now
    returns an empty (crops, boxes) pair.
    """

    def __init__(self, yolo_model_path: str, caption_model_path: str = 'microsoft/Florence-2-base-ft'):
        """
        Initialize the vision agent.

        Parameters:
            yolo_model_path: Path to YOLO model
            caption_model_path: Path to image caption model, default is Florence-2
        """
        # determine the available device and the best dtype
        self.device, self.dtype = self._get_optimal_device_and_dtype()
        # load the YOLO model
        self.yolo_model = YOLO(yolo_model_path)
        # load the image caption model and processor
        self.caption_processor = AutoProcessor.from_pretrained(
            "microsoft/Florence-2-base",
            trust_remote_code=True
        )
        # CUDA uses float16; MPS and CPU use float32 (MPS float16 support is limited).
        load_dtype = torch.float16 if self.device.type == 'cuda' else torch.float32
        self.caption_model = AutoModelForCausalLM.from_pretrained(
            caption_model_path,
            torch_dtype=load_dtype,
            trust_remote_code=True
        ).to(self.device)
        self.prompt = "<CAPTION>"
        # captioning batch size scaled to the device's expected throughput
        if self.device.type == 'cuda':
            self.batch_size = 128
        elif self.device.type == 'mps':
            self.batch_size = 32
        else:
            self.batch_size = 16
        self.elements: List[UIElement] = []
        self.ocr_reader = easyocr.Reader(['en', 'ch_sim'])

    def __call__(self, image_path: str) -> List[UIElement]:
        """Process an image from a file path."""
        image = cv2.imread(image_path)
        if image is None:
            # include the offending path so failures are diagnosable
            raise FileNotFoundError(f"Vision agent: Failed to read image: {image_path}")
        return self.analyze_image(image)

    def _get_optimal_device_and_dtype(self):
        """Determine the optimal torch device and dtype for this machine."""
        if torch.cuda.is_available():
            device = torch.device("cuda")
            # only use float16 on newer GPUs (compute capability >= 7)
            capability = torch.cuda.get_device_capability()
            dtype = torch.float16 if capability[0] >= 7 else torch.float32
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            device = torch.device("mps")
            dtype = torch.float32
        else:
            device = torch.device("cpu")
            dtype = torch.float32
        return device, dtype

    def _reset_state(self):
        """Clear previous analysis results."""
        self.elements = []

    def analyze_image(self, image: np.ndarray) -> List[UIElement]:
        """
        Process an image through all computer vision pipelines.

        Args:
            image: Input image in BGR format (OpenCV default)

        Returns:
            List of detected UI elements with annotations
        """
        self._reset_state()
        element_crops, boxes = self._detect_objects(image)
        start = time.time()
        element_texts = self._extract_text(element_crops)
        end = time.time()
        ocr_time = (end - start) * 10 ** 3
        print(f"Speed: {ocr_time:.2f} ms OCR of {len(element_texts)} icons.")
        start = time.time()
        element_captions = self._get_caption(element_crops, 5)
        end = time.time()
        caption_time = (end - start) * 10 ** 3
        print(f"Speed: {caption_time:.2f} ms captioning of {len(element_captions)} icons.")
        for idx in range(len(element_crops)):
            new_element = UIElement(element_id=idx,
                                    coordinates=boxes[idx],
                                    text=element_texts[idx][0] if len(element_texts[idx]) > 0 else '',
                                    caption=element_captions[idx])
            self.elements.append(new_element)
        return self.elements

    def _extract_text(self, images: list) -> list[list[str]]:
        """
        Run OCR in sequential mode.

        TODO: It is possible to run in batch mode for a speed up, but the result
        quality needs test. https://github.com/JaidedAI/EasyOCR/pull/458
        """
        texts = []
        for image in images:
            # detail=0 returns plain strings; paragraph=True merges nearby lines
            text = self.ocr_reader.readtext(image, detail=0, paragraph=True, text_threshold=0.85)
            texts.append(text)
        return texts

    def _get_caption(self, element_crops, batch_size=None):
        """Generate a caption for each element crop, processing in batches."""
        if not element_crops:
            return []
        # if batch_size is not specified, use the instance's default value
        if batch_size is None:
            batch_size = self.batch_size
        # resize every crop to 64x64 before captioning
        resized_crops = []
        for img in element_crops:
            # convert to numpy array, resize, then convert back to PIL
            img_np = np.array(img)
            resized_np = cv2.resize(img_np, (64, 64))
            resized_crops.append(Image.fromarray(resized_np))
        generated_texts = []
        device = self.device
        for i in range(0, len(resized_crops), batch_size):
            batch = resized_crops[i:i + batch_size]
            # select the dtype according to the device type
            if device.type == 'cuda':
                inputs = self.caption_processor(
                    images=batch,
                    text=[self.prompt] * len(batch),
                    return_tensors="pt",
                    do_resize=False
                ).to(device=device, dtype=torch.float16)
            else:
                # MPS and CPU use float32
                inputs = self.caption_processor(
                    images=batch,
                    text=[self.prompt] * len(batch),
                    return_tensors="pt"
                ).to(device=device)
            # special treatment for Florence-2's generate signature
            with torch.no_grad():
                if 'florence' in self.caption_model.config.model_type:
                    generated_ids = self.caption_model.generate(
                        input_ids=inputs["input_ids"],
                        pixel_values=inputs["pixel_values"],
                        max_new_tokens=20,
                        num_beams=5,
                        do_sample=False
                    )
                else:
                    generated_ids = self.caption_model.generate(
                        **inputs,
                        max_length=50,
                        num_beams=3,
                        early_stopping=True
                    )
            # decode the generated IDs
            texts = self.caption_processor.batch_decode(
                generated_ids,
                skip_special_tokens=True
            )
            generated_texts.extend(text.strip() for text in texts)
            # free per-batch GPU memory
            if device.type == 'cuda' and torch.cuda.is_available():
                torch.cuda.empty_cache()
        return generated_texts

    def _detect_objects(self, image: np.ndarray) -> tuple[list[np.ndarray], list]:
        """Run object detection and return (element crops, kept boxes)."""
        results = self.yolo_model(image)[0]
        detections = sv.Detections.from_ultralytics(results)
        boxes = detections.xyxy
        if len(boxes) == 0:
            # BUGFIX: must return a 2-tuple; a bare [] broke callers that
            # unpack `element_crops, boxes = self._detect_objects(image)`.
            return [], []
        # Filter out boxes fully contained by larger ones (keep the outermost).
        areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        sorted_indices = np.argsort(-areas)  # sort descending by area
        sorted_boxes = boxes[sorted_indices]
        keep_sorted = []
        for i in range(len(sorted_boxes)):
            contained = False
            for j in keep_sorted:
                box_b = sorted_boxes[j]
                box_a = sorted_boxes[i]
                if (box_b[0] <= box_a[0] and box_b[1] <= box_a[1] and
                        box_b[2] >= box_a[2] and box_b[3] >= box_a[3]):
                    contained = True
                    break
            if not contained:
                keep_sorted.append(i)
        # Map back to original indices
        keep_indices = sorted_indices[keep_sorted]
        filtered_boxes = boxes[keep_indices]
        # Extract element crops
        element_crops = []
        for box in filtered_boxes:
            x1, y1, x2, y2 = map(int, map(round, box))
            element = image[y1:y2, x1:x2]
            element_crops.append(np.array(element))
        return element_crops, filtered_boxes

    def load_image(self, image_source: str) -> List[UIElement]:
        """Decode a base64-encoded image (optionally a data URL) and analyze it."""
        try:
            # Handle potential Data URL prefix (like "data:image/png;base64,")
            if ',' in image_source:
                _, payload = image_source.split(',', 1)
            else:
                payload = image_source
            # Base64 decode -> bytes -> numpy array
            image_bytes = base64.b64decode(payload)
            np_array = np.frombuffer(image_bytes, dtype=np.uint8)
            # OpenCV decode image
            image = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
            if image is None:
                raise ValueError("Failed to decode image: Invalid image data")
            return self.analyze_image(image)
        except (base64.binascii.Error, ValueError) as e:
            error_msg = "Input is neither a valid file path nor valid Base64 image data"
            raise ValueError(error_msg) from e

View File

@@ -1,299 +0,0 @@
import json
from collections.abc import Callable
from typing import Callable
import uuid
from PIL import Image, ImageDraw
import base64
from io import BytesIO
from anthropic import APIResponse
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam, BetaUsage
from gradio_ui.agent.llm_utils.oaiclient import run_oai_interleaved
from gradio_ui.agent.llm_utils.utils import is_image_path
import time
import re
OUTPUT_DIR = "./tmp/outputs"
def extract_data(input_string, data_type):
    """Return the body of the first ```<data_type> fenced block in *input_string*.

    An unterminated fence (missing closing ```) is tolerated: the match runs
    to the end of the string. When no fence is present, the input is returned
    unchanged.
    """
    fence_pattern = rf"```{data_type}(.*?)(```|$)"
    # re.DOTALL lets '.' span newlines inside the fenced block
    match = re.search(fence_pattern, input_string, re.DOTALL)
    return match.group(1).strip() if match else input_string
class VLMAgent:
def __init__(
self,
model: str,
api_key: str,
output_callback: Callable,
api_response_callback: Callable,
max_tokens: int = 4096,
base_url:str = "",
only_n_most_recent_images: int | None = None,
print_usage: bool = True,
):
self.base_url = base_url
self.api_key = api_key
self.api_response_callback = api_response_callback
self.max_tokens = max_tokens
self.only_n_most_recent_images = only_n_most_recent_images
self.output_callback = output_callback
self.model = model
self.print_usage = print_usage
self.total_token_usage = 0
self.total_cost = 0
self.step_count = 0
self.system = ''
def __call__(self, messages: list, parsed_screen: list[str, list, dict]):
self.step_count += 1
image_base64 = parsed_screen['original_screenshot_base64']
latency_omniparser = parsed_screen['latency']
self.output_callback(f'-- Step {self.step_count}: --', sender="bot")
screen_info = str(parsed_screen['screen_info'])
screenshot_uuid = parsed_screen['screenshot_uuid']
screen_width, screen_height = parsed_screen['width'], parsed_screen['height']
boxids_and_labels = parsed_screen["screen_info"]
system = self._get_system_prompt(boxids_and_labels)
# drop looping actions msg, byte image etc
planner_messages = messages
_remove_som_images(planner_messages)
_maybe_filter_to_n_most_recent_images(planner_messages, self.only_n_most_recent_images)
if isinstance(planner_messages[-1], dict):
if not isinstance(planner_messages[-1]["content"], list):
planner_messages[-1]["content"] = [planner_messages[-1]["content"]]
planner_messages[-1]["content"].append(f"{OUTPUT_DIR}/screenshot_{screenshot_uuid}.png")
planner_messages[-1]["content"].append(f"{OUTPUT_DIR}/screenshot_som_{screenshot_uuid}.png")
start = time.time()
vlm_response, token_usage = run_oai_interleaved(
messages=planner_messages,
system=system,
model_name=self.model,
api_key=self.api_key,
max_tokens=self.max_tokens,
provider_base_url=self.base_url,
temperature=0,
)
latency_vlm = time.time() - start
self.output_callback(f"LLM: {latency_vlm:.2f}s, OmniParser: {latency_omniparser:.2f}s", sender="bot")
print(f"llm_response: {vlm_response}")
if self.print_usage:
print(f"Total token so far: {token_usage}. Total cost so far: $USD{self.total_cost:.5f}")
vlm_response_json = extract_data(vlm_response, "json")
vlm_response_json = json.loads(vlm_response_json)
img_to_show_base64 = parsed_screen["som_image_base64"]
if "Box ID" in vlm_response_json:
try:
bbox = parsed_screen["parsed_content_list"][int(vlm_response_json["Box ID"])]["bbox"]
vlm_response_json["box_centroid_coordinate"] = [int((bbox[0] + bbox[2]) / 2 * screen_width), int((bbox[1] + bbox[3]) / 2 * screen_height)]
img_to_show_data = base64.b64decode(img_to_show_base64)
img_to_show = Image.open(BytesIO(img_to_show_data))
draw = ImageDraw.Draw(img_to_show)
x, y = vlm_response_json["box_centroid_coordinate"]
radius = 10
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill='red')
draw.ellipse((x - radius*3, y - radius*3, x + radius*3, y + radius*3), fill=None, outline='red', width=2)
buffered = BytesIO()
img_to_show.save(buffered, format="PNG")
img_to_show_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
except:
print(f"Error parsing: {vlm_response_json}")
pass
self.output_callback(f'<img src="data:image/png;base64,{img_to_show_base64}">', sender="bot")
self.output_callback(
f'<details>'
f' <summary>Parsed Screen elemetns by OmniParser</summary>'
f' <pre>{screen_info}</pre>'
f'</details>',
sender="bot"
)
vlm_plan_str = ""
for key, value in vlm_response_json.items():
if key == "Reasoning":
vlm_plan_str += f'{value}'
else:
vlm_plan_str += f'\n{key}: {value}'
# construct the response so that anthropicExcutor can execute the tool
response_content = [BetaTextBlock(text=vlm_plan_str, type='text')]
if 'box_centroid_coordinate' in vlm_response_json:
move_cursor_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
input={'action': 'mouse_move', 'coordinate': vlm_response_json["box_centroid_coordinate"]},
name='computer', type='tool_use')
response_content.append(move_cursor_block)
if vlm_response_json["Next Action"] == "None":
print("Task paused/completed.")
elif vlm_response_json["Next Action"] == "type":
sim_content_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
input={'action': vlm_response_json["Next Action"], 'text': vlm_response_json["value"]},
name='computer', type='tool_use')
response_content.append(sim_content_block)
else:
sim_content_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
input={'action': vlm_response_json["Next Action"]},
name='computer', type='tool_use')
response_content.append(sim_content_block)
response_message = BetaMessage(id=f'toolu_{uuid.uuid4()}', content=response_content, model='', role='assistant', type='message', stop_reason='tool_use', usage=BetaUsage(input_tokens=0, output_tokens=0))
return response_message, vlm_response_json
def _api_response_callback(self, response: APIResponse):
self.api_response_callback(response)
def _get_system_prompt(self, screen_info: str = ""):
main_section = f"""
You are using a Windows device.
You are able to use a mouse and keyboard to interact with the computer based on the given task and screenshot.
You can only interact with the desktop GUI (no terminal or application menu access).
You may be given some history plan and actions, this is the response from the previous loop.
You should carefully consider your plan base on the task, screenshot, and history actions.
Here is the list of all detected bounding boxes by IDs on the screen and their description:{screen_info}
Your available "Next Action" only include:
- type: types a string of text.
- left_click: move mouse to box id and left clicks.
- right_click: move mouse to box id and right clicks.
- double_click: move mouse to box id and double clicks.
- hover: move mouse to box id.
- scroll_up: scrolls the screen up to view previous content.
- scroll_down: scrolls the screen down, when the desired button is not visible, or you need to see more content.
- wait: waits for 1 second for the device to load or respond.
Based on the visual information from the screenshot image and the detected bounding boxes, please determine the next action, the Box ID you should operate on (if action is one of 'type', 'hover', 'scroll_up', 'scroll_down', 'wait', there should be no Box ID field), and the value (if the action is 'type') in order to complete the task.
Output format:
```json
{{
"Reasoning": str, # describe what is in the current screen, taking into account the history, then describe your step-by-step thoughts on how to achieve the task, choose one action from available actions at a time.
"Next Action": "action_type, action description" | "None" # one action at a time, describe it in short and precisely.
"Box ID": n,
"value": "xxx" # only provide value field if the action is type, else don't include value key
}}
```
One Example:
```json
{{
"Reasoning": "The current screen shows google result of amazon, in previous action I have searched amazon on google. Then I need to click on the first search results to go to amazon.com.",
"Next Action": "left_click",
"Box ID": m
}}
```
Another Example:
```json
{{
"Reasoning": "The current screen shows the front page of amazon. There is no previous action. Therefore I need to type "Apple watch" in the search bar.",
"Next Action": "type",
"Box ID": n,
"value": "Apple watch"
}}
```
Another Example:
```json
{{
"Reasoning": "The current screen does not show 'submit' button, I need to scroll down to see if the button is available.",
"Next Action": "scroll_down",
}}
```
IMPORTANT NOTES:
1. You should only give a single action at a time.
"""
thinking_model = "r1" in self.model
if not thinking_model:
main_section += """
2. You should give an analysis to the current screen, and reflect on what has been done by looking at the history, then describe your step-by-step thoughts on how to achieve the task.
"""
else:
main_section += """
2. In <think> XML tags give an analysis to the current screen, and reflect on what has been done by looking at the history, then describe your step-by-step thoughts on how to achieve the task. In <output> XML tags put the next action prediction JSON.
"""
main_section += """
3. Attach the next action prediction in the "Next Action".
4. You should not include other actions, such as keyboard shortcuts.
5. When the task is completed, don't complete additional actions. You should say "Next Action": "None" in the json field.
6. The tasks involve buying multiple products or navigating through multiple pages. You should break it into subgoals and complete each subgoal one by one in the order of the instructions.
7. avoid choosing the same action/elements multiple times in a row, if it happens, reflect to yourself, what may have gone wrong, and predict a different action.
8. If you are prompted with login information page or captcha page, or you think it need user's permission to do the next action, you should say "Next Action": "None" in the json field.
"""
return main_section
def _remove_som_images(messages):
    """Strip SOM screenshot path entries from each message's content list, in place."""
    for message in messages:
        content = message["content"]
        if not isinstance(content, list):
            continue
        filtered = []
        for item in content:
            # drop string entries that are image paths containing 'som'
            if isinstance(item, str) and 'som' in item and is_image_path(item):
                continue
            filtered.append(item)
        message["content"] = filtered
def _maybe_filter_to_n_most_recent_images(
    messages: list[BetaMessageParam],
    images_to_keep: int | None,
    min_removal_threshold: int = 10,
):
    """
    With the assumption that images are screenshots that are of diminishing value as
    the conversation progresses, remove all but the final `images_to_keep` tool_result
    images in place
    """
    # NOTE(review): min_removal_threshold is currently unused in this body;
    # also this function returns `messages` only on the early-exit path and
    # None otherwise — callers should rely on the in-place mutation.
    if images_to_keep is None:
        return messages
    # First pass: count every image — both string image paths and images
    # embedded inside tool_result content entries.
    total_images = 0
    for msg in messages:
        for cnt in msg.get("content", []):
            if isinstance(cnt, str) and is_image_path(cnt):
                total_images += 1
            elif isinstance(cnt, dict) and cnt.get("type") == "tool_result":
                for content in cnt.get("content", []):
                    if isinstance(content, dict) and content.get("type") == "image":
                        total_images += 1
    images_to_remove = total_images - images_to_keep
    # Second pass: walk messages oldest-first and drop images until the
    # removal budget is exhausted, so only the most recent ones survive.
    for msg in messages:
        msg_content = msg["content"]
        if isinstance(msg_content, list):
            new_content = []
            for cnt in msg_content:
                # Remove images from SOM or screenshot as needed
                if isinstance(cnt, str) and is_image_path(cnt):
                    if images_to_remove > 0:
                        images_to_remove -= 1
                        continue
                # VLM shouldn't use anthropic screenshot tool so shouldn't have these but in case it does, remove as needed
                elif isinstance(cnt, dict) and cnt.get("type") == "tool_result":
                    new_tool_result_content = []
                    for tool_result_entry in cnt.get("content", []):
                        if isinstance(tool_result_entry, dict) and tool_result_entry.get("type") == "image":
                            if images_to_remove > 0:
                                images_to_remove -= 1
                                continue
                        new_tool_result_content.append(tool_result_entry)
                    cnt["content"] = new_tool_result_content
                # Append fixed content to current message's content list
                new_content.append(cnt)
            msg["content"] = new_content

View File

@@ -14,12 +14,15 @@ from anthropic import APIResponse
from anthropic.types import TextBlock
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
from anthropic.types.tool_use_block import ToolUseBlock
from gradio_ui.agent.vision_agent import VisionAgent
from gradio_ui.loop import (
sampling_loop_sync,
)
from gradio_ui.tools import ToolResult
import base64
from xbrain.utils.config import Config
from util.download_weights import MODEL_DIR
CONFIG_DIR = Path("~/.anthropic").expanduser()
API_KEY_FILE = CONFIG_DIR / "api_key"
@@ -41,12 +44,23 @@ class Sender(StrEnum):
BOT = "assistant"
TOOL = "tool"
def setup_state(state):
# 如果存在config则从config中加载数据
config = Config()
if config.OPENAI_API_KEY:
state["api_key"] = config.OPENAI_API_KEY
else:
state["api_key"] = ""
if config.OPENAI_BASE_URL:
state["base_url"] = config.OPENAI_BASE_URL
else:
state["base_url"] = "https://api.openai.com/v1"
if config.OPENAI_MODEL:
state["model"] = config.OPENAI_MODEL
else:
state["model"] = "gpt-4o"
if "messages" not in state:
state["messages"] = []
if "model" not in state:
state["model"] = "gpt-4o"
if "api_key" not in state:
state["api_key"] = ""
if "auth_validated" not in state:
state["auth_validated"] = False
if "responses" not in state:
@@ -59,8 +73,6 @@ def setup_state(state):
state['chatbot_messages'] = []
if 'stop' not in state:
state['stop'] = False
if 'base_url' not in state:
state['base_url'] = "https://api.openai-next.com/v1"
async def main(state):
"""Render loop for Gradio"""
@@ -156,11 +168,12 @@ def chatbot_output_callback(message, chatbot_state, hide_images=False, sender="b
# print(f"chatbot_output_callback chatbot_state: {concise_state} (truncated)")
def process_input(user_input, state):
def process_input(user_input, state, vision_agent_state):
# Reset the stop flag
if state["stop"]:
state["stop"] = False
config = Config()
config.set_openai_config(base_url=state["base_url"], api_key=state["api_key"], model=state["model"])
# Append the user message to state["messages"]
state["messages"].append(
{
@@ -173,17 +186,15 @@ def process_input(user_input, state):
state['chatbot_messages'].append((user_input, None)) # 确保格式正确
yield state['chatbot_messages'] # Yield to update the chatbot UI with the user's message
# Run sampling_loop_sync with the chatbot_output_callback
agent = vision_agent_state["agent"]
for loop_msg in sampling_loop_sync(
model=state["model"],
messages=state["messages"],
output_callback=partial(chatbot_output_callback, chatbot_state=state['chatbot_messages'], hide_images=False),
tool_output_callback=partial(_tool_output_callback, tool_state=state["tools"]),
api_response_callback=partial(_api_response_callback, response_state=state["responses"]),
api_key=state["api_key"],
only_n_most_recent_images=state["only_n_most_recent_images"],
max_tokens=8000,
omniparser_url=args.omniparser_server_url,
base_url = state["base_url"]
vision_agent = agent
):
if loop_msg is None or state.get("stop"):
yield state['chatbot_messages']
@@ -244,14 +255,14 @@ def run():
with gr.Column():
model = gr.Textbox(
label="Model",
value="gpt-4o",
value=state.value["model"],
placeholder="输入模型名称",
interactive=True,
)
with gr.Column():
base_url = gr.Textbox(
label="Base URL",
value="https://api.openai-next.com/v1",
value=state.value["base_url"],
placeholder="输入基础 URL",
interactive=True
)
@@ -268,7 +279,7 @@ def run():
api_key = gr.Textbox(
label="API Key",
type="password",
value=state.value.get("api_key", ""),
value=state.value["api_key"],
placeholder="Paste your API key here",
interactive=True,
)
@@ -285,15 +296,11 @@ def run():
chatbot = gr.Chatbot(
label="Chatbot History",
autoscroll=True,
height=580 )
height=580
)
def update_model(model_selection, state):
state["model"] = model_selection
api_key_update = gr.update(
placeholder="API Key",
value=state["api_key"]
)
return api_key_update
def update_model(model, state):
state["model"] = model
def update_api_key(api_key_value, state):
state["api_key"] = api_key_value
@@ -309,10 +316,13 @@ def run():
state['chatbot_messages'] = []
return state['chatbot_messages']
model.change(fn=update_model, inputs=[model, state], outputs=[api_key])
model.change(fn=update_model, inputs=[model, state], outputs=None)
api_key.change(fn=update_api_key, inputs=[api_key, state], outputs=None)
chatbot.clear(fn=clear_chat, inputs=[state], outputs=[chatbot])
submit_button.click(process_input, [chat_input, state], chatbot)
vision_agent = VisionAgent(yolo_model_path=os.path.join(MODEL_DIR, "icon_detect", "model.pt"),
caption_model_path=os.path.join(MODEL_DIR, "icon_caption"))
vision_agent_state = gr.State({"agent": vision_agent})
submit_button.click(process_input, [chat_input, state, vision_agent_state], chatbot)
stop_button.click(stop_app, [state], None)
base_url.change(fn=update_base_url, inputs=[base_url, state], outputs=None)
demo.launch(server_name="0.0.0.0", server_port=7888)

View File

@@ -2,22 +2,24 @@
Agentic sampling loop that calls the Anthropic API and local implenmentation of anthropic-defined computer use tools.
"""
from collections.abc import Callable
from enum import StrEnum
from time import sleep
import cv2
from gradio_ui.agent.vision_agent import VisionAgent
from gradio_ui.tools.screen_capture import get_screenshot
from anthropic import APIResponse
from anthropic.types import (
TextBlock,
)
from anthropic.types.beta import (
BetaContentBlock,
BetaMessage,
BetaMessageParam
)
from gradio_ui.agent.task_plan_agent import TaskPlanAgent
from gradio_ui.agent.task_run_agent import TaskRunAgent
from gradio_ui.tools import ToolResult
from gradio_ui.agent.llm_utils.omniparserclient import OmniParserClient
from gradio_ui.agent.vlm_agent import VLMAgent
from gradio_ui.executor.anthropic_executor import AnthropicExecutor
import numpy as np
from PIL import Image
OUTPUT_DIR = "./tmp/outputs"
def sampling_loop_sync(
*,
@@ -26,39 +28,69 @@ def sampling_loop_sync(
output_callback: Callable[[BetaContentBlock], None],
tool_output_callback: Callable[[ToolResult, str], None],
api_response_callback: Callable[[APIResponse[BetaMessage]], None],
api_key: str,
only_n_most_recent_images: int | None = 2,
max_tokens: int = 4096,
omniparser_url: str,
base_url: str
only_n_most_recent_images: int | None = 0,
vision_agent: VisionAgent
):
"""
Synchronous agentic sampling loop for the assistant/tool interaction of computer use.
"""
print('in sampling_loop_sync, model:', model)
omniparser_client = OmniParserClient(url=f"http://{omniparser_url}/parse/")
actor = VLMAgent(
model=model,
api_key=api_key,
base_url=base_url,
api_response_callback=api_response_callback,
output_callback=output_callback,
max_tokens=max_tokens,
only_n_most_recent_images=only_n_most_recent_images
)
task_plan_agent = TaskPlanAgent(output_callback=output_callback)
executor = AnthropicExecutor(
output_callback=output_callback,
tool_output_callback=tool_output_callback,
)
tool_result_content = None
print(f"Start the message loop. User messages: {messages}")
plan = task_plan_agent(user_task = messages[-1]["content"][0].text)
task_run_agent = TaskRunAgent(output_callback=output_callback)
while True:
parsed_screen = omniparser_client()
tools_use_needed, vlm_response_json = actor(messages=messages, parsed_screen=parsed_screen)
parsed_screen = parse_screen(vision_agent)
tools_use_needed, __ = task_run_agent(task_plan=plan, parsed_screen=parsed_screen)
sleep(2)
for message, tool_result_content in executor(tools_use_needed, messages):
yield message
if not tool_result_content:
return messages
return messages
def parse_screen(vision_agent: VisionAgent):
    """Capture the screen and run the vision agent over it.

    Returns a dict with the detected UI elements, the screenshot's
    pixel dimensions, and an annotated copy of the screenshot.
    """
    screenshot, screenshot_path = get_screenshot()
    elements = vision_agent(str(screenshot_path))
    width, height = screenshot.size
    return {
        'parsed_content_list': elements,
        'width': width,
        'height': height,
        'image': draw_elements(screenshot, elements),
    }
def draw_elements(screenshot, parsed_content_list):
    """Draw numbered red bounding boxes for each detected element.

    Args:
        screenshot: PIL Image of the captured screen.
        parsed_content_list: detected elements exposing a ``coordinates``
            bounding box as (x1, y1, x2, y2).

    Returns:
        A new PIL Image with the boxes and their 1-based indices drawn on it.
    """
    # PIL works in RGB while OpenCV draws in BGR, so convert first.
    canvas = cv2.cvtColor(np.array(screenshot), cv2.COLOR_RGB2BGR)
    red_bgr = (0, 0, 255)
    for index, element in enumerate(parsed_content_list, start=1):
        x1, y1, x2, y2 = (int(coord) for coord in element.coordinates)
        cv2.rectangle(canvas, (x1, y1), (x2, y2), red_bgr, 2)
        # Put the element's number just above the box's top-left corner.
        cv2.putText(canvas, str(index), (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, red_bgr, 2)
    # Convert back to RGB and wrap in a PIL image for the caller.
    return Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))

View File

@@ -1,6 +1,5 @@
import base64
import time
from enum import StrEnum
from typing import Literal, TypedDict
from PIL import Image
from util import tool
@@ -175,6 +174,8 @@ class ComputerTool(BaseAnthropicTool):
pyautogui.click()
elif action == "right_click":
pyautogui.rightClick()
# 等待5秒等待菜单弹出
time.sleep(5)
elif action == "middle_click":
pyautogui.middleClick()
elif action == "double_click":

View File

@@ -1,8 +1,7 @@
from pathlib import Path
from uuid import uuid4
from PIL import Image
from .base import BaseAnthropicTool, ToolError
from io import BytesIO
from .base import ToolError
from util import tool
OUTPUT_DIR = "./tmp/outputs"

74
main.py
View File

@@ -1,78 +1,20 @@
import subprocess
from threading import Thread
import time
import requests
from gradio_ui import app
from util import download_weights
import torch
import socket
def run():
try:
print("cuda is_available: ", torch.cuda.is_available()) # 应该返回True
print("cuda is_available: ", torch.cuda.is_available())
print("MPS is_available: ", torch.backends.mps.is_available())
print("cuda device_count", torch.cuda.device_count()) # 应该至少返回1
print("cuda device_name", torch.cuda.get_device_name(0)) # 应该显示您的GPU名称
print("cuda device_count", torch.cuda.device_count())
print("cuda device_name", torch.cuda.get_device_name(0))
except Exception:
print("显卡驱动不适配请根据readme安装合适版本的 torch")
print("GPU driver is not compatible, please install the appropriate version of torch according to the readme!")
# download the weight files
download_weights.download()
app.run()
server_process = subprocess.Popen(
["python", "./omniserver.py"],
stdout=subprocess.PIPE, # 捕获标准输出
stderr=subprocess.PIPE,
text=True
)
stdout_thread = Thread(
target=stream_reader,
args=(server_process.stdout, "SERVER-OUT")
)
stderr_thread = Thread(
target=stream_reader,
args=(server_process.stderr, "SERVER-ERR")
)
stdout_thread.daemon = True
stderr_thread.daemon = True
stdout_thread.start()
stderr_thread.start()
try:
# 下载权重文件
download_weights.download()
print("启动Omniserver服务中因为加载模型真的超级慢请耐心等待")
while True:
try:
res = requests.get("http://127.0.0.1:8000/probe/", timeout=5)
if res.status_code == 200:
print("Omniparser服务启动成功...")
break
except (requests.ConnectionError, requests.Timeout):
pass
if server_process.poll() is not None:
raise RuntimeError(f"服务器进程报错退出:{server_process.returncode}")
print("等待服务启动...")
time.sleep(10)
app.run()
finally:
if server_process.poll() is None: # 如果进程还在运行
server_process.terminate() # 发送终止信号
server_process.wait(timeout=8) # 等待进程结束
def stream_reader(pipe, prefix):
    """Relay every line read from `pipe` to stdout, tagged with `[prefix]`.

    Intended to run on a daemon thread draining a subprocess stream;
    lines already end with a newline, so printing suppresses its own.
    """
    for raw_line in pipe:
        print(f"[{prefix}]", raw_line, end="", flush=True)
def is_port_occupied(port):
    """Return True if something on localhost is accepting TCP connections on `port`."""
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # connect_ex returns 0 on success instead of raising.
        return probe.connect_ex(('localhost', port)) == 0
    finally:
        probe.close()
if __name__ == '__main__':
# 检测8000端口是否被占用
if is_port_occupied(8000):
print("8000端口被占用请先关闭占用该端口的进程")
exit()
run()

View File

@@ -11,7 +11,8 @@ import argparse
import uvicorn
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(root_dir)
from util.omniparser import Omniparser
# from util.omniparser import Omniparser
from gradio_ui.agent.vision_agent import VisionAgent
def parse_arguments():
parser = argparse.ArgumentParser(description='Omniparser API')
@@ -29,8 +30,10 @@ args = parse_arguments()
config = vars(args)
app = FastAPI()
omniparser = Omniparser(config)
# omniparser = Omniparser(config)
yolo_model_path = config['som_model_path']
caption_model_path = config['caption_model_path']
vision_agent = VisionAgent(yolo_model_path=yolo_model_path, caption_model_path=caption_model_path)
class ParseRequest(BaseModel):
base64_image: str
@@ -38,10 +41,12 @@ class ParseRequest(BaseModel):
async def parse(parse_request: ParseRequest):
print('start parsing...')
start = time.time()
dino_labled_img, parsed_content_list = omniparser.parse(parse_request.base64_image)
# dino_labled_img, parsed_content_list = omniparser.parse(parse_request.base64_image)
parsed_content_list = vision_agent(parse_request.base64_image)
latency = time.time() - start
print('time:', latency)
return {"som_image_base64": dino_labled_img, "parsed_content_list": parsed_content_list, 'latency': latency}
return {"parsed_content_list": parsed_content_list, 'latency': latency}
@app.get("/probe/")
async def root():

View File

@@ -2,31 +2,13 @@ torch
easyocr
torchvision
supervision==0.18.0
openai==1.3.5
transformers
ultralytics==8.3.70
azure-identity
numpy==1.26.4
opencv-python
opencv-python-headless
gradio
dill
accelerate
pyautogui==0.9.54
anthropic[bedrock,vertex]>=0.37.1
pyxbrain==1.1.31
timm
einops==0.8.0
paddlepaddle
paddleocr
ruff==0.6.7
pre-commit==3.8.0
pytest==8.3.3
pytest-asyncio==0.23.6
pyautogui==0.9.54
streamlit>=1.38.0
anthropic[bedrock,vertex]>=0.37.1
jsonschema==4.22.0
boto3>=1.28.57
google-auth<3,>=2
screeninfo
uiautomation
dashscope
groq
modelscope

View File

@@ -1,262 +0,0 @@
from typing import List, Optional, Union, Tuple
import cv2
import numpy as np
from supervision.detection.core import Detections
from supervision.draw.color import Color, ColorPalette
class BoxAnnotator:
    """
    A class for drawing bounding boxes on an image using detections provided.

    Attributes:
        color (Union[Color, ColorPalette]): The color to draw the bounding box,
            can be a single color or a color palette
        thickness (int): The thickness of the bounding box lines, default is 2
        text_color (Color): The color of the text on the bounding box, default is white
        text_scale (float): The scale of the text on the bounding box, default is 0.5
        text_thickness (int): The thickness of the text on the bounding box,
            default is 1
        text_padding (int): The padding around the text on the bounding box,
            default is 5
        avoid_overlap (bool): When True, try alternate label placements so the
            label background does not cover other detections (see
            get_optimal_label_pos), default is True
    """

    def __init__(
        self,
        color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
        thickness: int = 3,  # 1 for seeclick 2 for mind2web and 3 for demo
        text_color: Color = Color.BLACK,
        text_scale: float = 0.5,  # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
        text_thickness: int = 2,  # 1, # 2 for demo
        text_padding: int = 10,
        avoid_overlap: bool = True,
    ):
        self.color: Union[Color, ColorPalette] = color
        self.thickness: int = thickness
        self.text_color: Color = text_color
        self.text_scale: float = text_scale
        self.text_thickness: int = text_thickness
        self.text_padding: int = text_padding
        self.avoid_overlap: bool = avoid_overlap

    def annotate(
        self,
        scene: np.ndarray,
        detections: Detections,
        labels: Optional[List[str]] = None,
        skip_label: bool = False,
        image_size: Optional[Tuple[int, int]] = None,
    ) -> np.ndarray:
        """
        Draws bounding boxes on the frame using the detections provided.

        Args:
            scene (np.ndarray): The image on which the bounding boxes will be drawn
            detections (Detections): The detections for which the
                bounding boxes will be drawn
            labels (Optional[List[str]]): An optional list of labels
                corresponding to each detection. If `labels` are not provided,
                corresponding `class_id` will be used as label.
            skip_label (bool): Is set to `True`, skips bounding box label annotation.
            image_size (Optional[Tuple[int, int]]): (width, height) of `scene`,
                used to keep labels inside the image when avoid_overlap is set.
        Returns:
            np.ndarray: The image with the bounding boxes drawn on it

        Example:
            ```python
            import supervision as sv
            classes = ['person', ...]
            image = ...
            detections = sv.Detections(...)
            box_annotator = sv.BoxAnnotator()
            labels = [
                f"{classes[class_id]} {confidence:0.2f}"
                for _, _, confidence, class_id, _ in detections
            ]
            annotated_frame = box_annotator.annotate(
                scene=image.copy(),
                detections=detections,
                labels=labels
            )
            ```
        """
        font = cv2.FONT_HERSHEY_SIMPLEX
        for i in range(len(detections)):
            x1, y1, x2, y2 = detections.xyxy[i].astype(int)
            class_id = (
                detections.class_id[i] if detections.class_id is not None else None
            )
            # Color is picked per class when a palette is used, else fixed.
            idx = class_id if class_id is not None else i
            color = (
                self.color.by_idx(idx)
                if isinstance(self.color, ColorPalette)
                else self.color
            )
            # Detection rectangle itself.
            cv2.rectangle(
                img=scene,
                pt1=(x1, y1),
                pt2=(x2, y2),
                color=color.as_bgr(),
                thickness=self.thickness,
            )
            if skip_label:
                continue
            # Fall back to the class id when labels are missing or mismatched.
            text = (
                f"{class_id}"
                if (labels is None or len(detections) != len(labels))
                else labels[i]
            )
            text_width, text_height = cv2.getTextSize(
                text=text,
                fontFace=font,
                fontScale=self.text_scale,
                thickness=self.text_thickness,
            )[0]
            if not self.avoid_overlap:
                # Fixed placement: label background inside the top-left corner.
                text_x = x1 + self.text_padding
                text_y = y1 - self.text_padding
                text_background_x1 = x1
                text_background_y1 = y1 - 2 * self.text_padding - text_height
                text_background_x2 = x1 + 2 * self.text_padding + text_width
                text_background_y2 = y1
            else:
                # Try several placements and pick one that overlaps neither
                # other detections nor the image border.
                text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 = get_optimal_label_pos(self.text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size)
            # Filled background behind the label text.
            cv2.rectangle(
                img=scene,
                pt1=(text_background_x1, text_background_y1),
                pt2=(text_background_x2, text_background_y2),
                color=color.as_bgr(),
                thickness=cv2.FILLED,
            )
            # Choose black or white text depending on background luminance
            # (ITU-R BT.601 weights) for readability.
            box_color = color.as_rgb()
            luminance = 0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
            text_color = (0, 0, 0) if luminance > 160 else (255, 255, 255)
            cv2.putText(
                img=scene,
                text=text,
                org=(text_x, text_y),
                fontFace=font,
                fontScale=self.text_scale,
                color=text_color,
                thickness=self.text_thickness,
                lineType=cv2.LINE_AA,
            )
        return scene
def box_area(box):
    """Area of an (x1, y1, x2, y2) box."""
    width = box[2] - box[0]
    height = box[3] - box[1]
    return width * height
def intersection_area(box1, box2):
    """Overlap area of two (x1, y1, x2, y2) boxes; 0 when they are disjoint."""
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])
    return max(0, right - left) * max(0, bottom - top)
def IoU(box1, box2, return_max=True):
    """Intersection-over-union of two (x1, y1, x2, y2) boxes.

    When return_max is True, also considers the intersection as a
    fraction of each individual box's area and returns the largest of
    the three ratios; otherwise returns the plain IoU.
    """
    inter = intersection_area(box1, box2)
    area1 = box_area(box1)
    area2 = box_area(box2)
    union = area1 + area2 - inter
    if area1 > 0 and area2 > 0:
        ratio1, ratio2 = inter / area1, inter / area2
    else:
        ratio1 = ratio2 = 0
    if not return_max:
        return inter / union
    return max(inter / union, ratio1, ratio2)
def get_optimal_label_pos(text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size):
    """Pick a label placement that avoids other detections and the image edge.

    Candidate positions are tried in order: inner top-left, outer left,
    outer right, inner top-right. The first placement whose background
    rectangle overlaps no detection (IoU > 0.3) and stays fully inside
    the image is returned; if every candidate overlaps, the last one
    (inner top-right) is returned anyway.

    Returns:
        (text_x, text_y, bg_x1, bg_y1, bg_x2, bg_y2)
    """
    def placement_blocked(bg_x1, bg_y1, bg_x2, bg_y2):
        # Blocked if the label background overlaps any detection box...
        for k in range(len(detections)):
            det = detections.xyxy[k].astype(int)
            if IoU([bg_x1, bg_y1, bg_x2, bg_y2], det) > 0.3:
                return True
        # ...or spills outside the image bounds.
        return (bg_x1 < 0 or bg_x2 > image_size[0]
                or bg_y1 < 0 or bg_y2 > image_size[1])

    candidates = [
        # inner top-left
        (x1 + text_padding, y1 - text_padding,
         x1, y1 - 2 * text_padding - text_height,
         x1 + 2 * text_padding + text_width, y1),
        # outer left
        (x1 - text_padding - text_width, y1 + text_padding + text_height,
         x1 - 2 * text_padding - text_width, y1,
         x1, y1 + 2 * text_padding + text_height),
        # outer right
        (x2 + text_padding, y1 + text_padding + text_height,
         x2, y1,
         x2 + 2 * text_padding + text_width, y1 + 2 * text_padding + text_height),
        # inner top-right
        (x2 - text_padding - text_width, y1 - text_padding,
         x2 - 2 * text_padding - text_width, y1 - 2 * text_padding - text_height,
         x2, y1),
    ]
    for candidate in candidates:
        if not placement_blocked(*candidate[2:]):
            return candidate
    # All placements overlap something; fall back to the final candidate.
    return candidates[-1]

View File

@@ -1,12 +1,14 @@
import subprocess
import os
from pathlib import Path
from modelscope import snapshot_download
__WEIGHTS_DIR = Path("weights")
MODEL_DIR = os.path.join(__WEIGHTS_DIR, "AI-ModelScope", "OmniParser-v2___0")
def download():
# 创建权重目录
weights_dir = Path("weights")
weights_dir.mkdir(exist_ok=True)
# Create weights directory
__WEIGHTS_DIR.mkdir(exist_ok=True)
# 需要下载的文件列表
# List of files to download
files = [
"icon_detect/train_args.yaml",
"icon_detect/model.pt",
@@ -16,42 +18,24 @@ def download():
"icon_caption/model.safetensors"
]
# 检查并下载缺失的文件
# Check and download missing files
missing_files = []
for file in files:
file_path = weights_dir / file
if not file_path.exists():
file_path = os.path.join(MODEL_DIR, file)
if not os.path.exists(file_path):
missing_files.append(file)
break
if not missing_files:
print("已经检测到模型文件!")
print("Model files already detected!")
return
print(f"未检测到模型文件,需要下载 {len(missing_files)} 个文件")
# 下载缺失的文件
max_retries = 3 # 最大重试次数
for file in missing_files:
for attempt in range(max_retries):
try:
print(f"正在下载: {file} (尝试 {attempt + 1}/{max_retries})")
cmd = [
"huggingface-cli",
"download",
"microsoft/OmniParser-v2.0",
file,
"--local-dir",
"weights"
]
subprocess.run(cmd, check=True)
break # 下载成功,跳出重试循环
except subprocess.CalledProcessError as e:
if attempt == max_retries - 1: # 最后一次尝试
print(f"下载失败: {file},已达到最大重试次数")
raise # 重新抛出异常
print(f"下载失败: {file},正在重试...")
continue
print("下载完成")
snapshot_download(
'AI-ModelScope/OmniParser-v2.0',
cache_dir='weights'
)
print("Download complete")
if __name__ == "__main__":
download()

View File

@@ -1,31 +0,0 @@
from util.utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box
import torch
from PIL import Image
import io
import base64
from typing import Dict
class Omniparser(object):
    """Wraps the SOM (YOLO icon detector) and caption model behind a single
    parse() call that labels a base64-encoded screenshot.
    """

    def __init__(self, config: Dict):
        # Expected config keys: 'som_model_path', 'caption_model_name',
        # 'caption_model_path', 'BOX_TRESHOLD' (spelling as used below).
        self.config = config
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.som_model = get_yolo_model(model_path=config['som_model_path'])
        self.caption_model_processor = get_caption_model_processor(model_name=config['caption_model_name'], model_name_or_path=config['caption_model_path'], device=device)
        print('Server initialized!')

    def parse(self, image_base64: str):
        """Decode the screenshot, run OCR plus SOM labeling, and return the
        labeled image (base64) together with the parsed element list.
        """
        image_bytes = base64.b64decode(image_base64)
        image = Image.open(io.BytesIO(image_bytes))
        print('image size:', image.size)
        # Scale drawing parameters with image size so labels stay legible
        # on both small and very large screens.
        box_overlay_ratio = max(image.size) / 3200
        draw_bbox_config = {
            'text_scale': 0.8 * box_overlay_ratio,
            'text_thickness': max(int(2 * box_overlay_ratio), 1),
            'text_padding': max(int(3 * box_overlay_ratio), 1),
            'thickness': max(int(3 * box_overlay_ratio), 1),
        }
        (text, ocr_bbox), _ = check_ocr_box(image, display_img=False, output_bb_format='xyxy', easyocr_args={'text_threshold': 0.8}, use_paddleocr=False)
        dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image, self.som_model, BOX_TRESHOLD = self.config['BOX_TRESHOLD'], output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text,use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128)
        return dino_labled_img, parsed_content_list

View File

@@ -2,7 +2,6 @@ import os
import shlex
import subprocess
import threading
import traceback
import pyautogui
from PIL import Image
from io import BytesIO

View File

@@ -1,540 +0,0 @@
# from ultralytics import YOLO
import os
import io
import base64
import time
from PIL import Image, ImageDraw, ImageFont
import json
import requests
# utility function
import os
from openai import AzureOpenAI
import json
import sys
import os
import cv2
import numpy as np
# %matplotlib inline
from matplotlib import pyplot as plt
import easyocr
from paddleocr import PaddleOCR
reader = easyocr.Reader(['en', 'ch_sim'])
paddle_ocr = PaddleOCR(
lang='ch', # other lang also available
use_angle_cls=False,
use_gpu=False, # using cuda will conflict with pytorch in the same process
show_log=False,
max_batch_size=1024,
use_dilation=True, # improves accuracy
det_db_score_mode='slow', # improves accuracy
rec_batch_num=1024)
import time
import base64
import os
import ast
import torch
from typing import Tuple, List, Union
from torchvision.ops import box_convert
import re
from torchvision.transforms import ToPILImage
import supervision as sv
import torchvision.transforms as T
from util.box_annotator import BoxAnnotator
def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
    """Load the icon-caption model and its processor.

    Supports model_name 'blip2' and 'florence2'. On CPU the weights are
    loaded as float32; on an accelerator they are loaded as float16 and
    moved to the device.

    Returns:
        dict with keys 'model' (already on `device`) and 'processor'.
    """
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float32 if device == 'cpu' else torch.float16
    if model_name == "blip2":
        from transformers import Blip2Processor, Blip2ForConditionalGeneration
        processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
        model = Blip2ForConditionalGeneration.from_pretrained(
            model_name_or_path, device_map=None, torch_dtype=dtype
        )
        if device != 'cpu':
            model = model.to(device)
    elif model_name == "florence2":
        from transformers import AutoProcessor, AutoModelForCausalLM
        processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, torch_dtype=dtype, trust_remote_code=True
        )
        if device != 'cpu':
            model = model.to(device)
    return {'model': model.to(device), 'processor': processor}
def get_yolo_model(model_path):
    """Load a YOLO icon detector from the given weights file."""
    from ultralytics import YOLO
    return YOLO(model_path)
@torch.inference_mode()
def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=128):
    """Caption each non-OCR box by cropping it and running the caption model.

    Args:
        filtered_boxes: normalized (x1, y1, x2, y2) boxes; entries from
            `starting_idx` onward are the icon boxes to caption.
        starting_idx: index of the first non-OCR box (0/None means all).
        image_source: HxWxC image array the boxes refer to.
        caption_model_processor: dict with 'model' and 'processor'.
        prompt: optional caption prompt; a model-appropriate default is
            chosen when omitted.
        batch_size: crops per forward pass — 128 roughly takes 4 GB of GPU
            memory for the Florence-2 model.

    Returns:
        list[str]: one stripped caption per successfully cropped box.
    """
    to_pil = ToPILImage()
    non_ocr_boxes = filtered_boxes[starting_idx:] if starting_idx else filtered_boxes
    croped_pil_image = []
    for coord in non_ocr_boxes:
        try:
            xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
            ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
            cropped_image = image_source[ymin:ymax, xmin:xmax, :]
            cropped_image = cv2.resize(cropped_image, (64, 64))
            croped_pil_image.append(to_pil(cropped_image))
        except Exception:
            # A degenerate or out-of-range box makes cv2.resize fail; skip it.
            # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
            continue
    model, processor = caption_model_processor['model'], caption_model_processor['processor']
    if not prompt:
        # Florence-2 uses a task token; generic captioners get a text stub.
        prompt = "<CAPTION>" if 'florence' in model.config.model_type else "The image shows"
    generated_texts = []
    device = model.device
    for i in range(0, len(croped_pil_image), batch_size):
        batch = croped_pil_image[i:i+batch_size]
        # On CUDA the model weights are fp16, so inputs must match.
        if model.device.type == 'cuda':
            inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16)
        else:
            inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
        if 'florence' in model.config.model_type:
            generated_ids = model.generate(input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=20, num_beams=1, do_sample=False)
        else:
            generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
        generated_texts.extend(gen.strip() for gen in generated_text)
    return generated_texts
def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
    """Caption non-OCR boxes with a Phi-3-Vision style chat model.

    Crops each icon box from `image_source`, builds a one-image chat prompt
    per crop, left-pads each batch to a common length, and decodes one
    caption per crop.

    Args:
        filtered_boxes: normalized (x1, y1, x2, y2) boxes; the first
            len(ocr_bbox) entries are OCR boxes and are skipped.
        ocr_bbox: OCR boxes (or falsy to caption every box).
        image_source: HxWxC image array the boxes refer to.
        caption_model_processor: dict with 'model' and 'processor'.

    Returns:
        list[str]: one stripped caption per cropped box.

    NOTE(review): the name `i` is reused by three nested loops below
    (crop loop, batch loop, per-text loops). Python's range iterator is
    unaffected by the rebinding, so behavior is correct, but the
    shadowing is fragile — worth renaming if this code is revived.
    """
    to_pil = ToPILImage()
    if ocr_bbox:
        non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
    else:
        non_ocr_boxes = filtered_boxes
    croped_pil_image = []
    for i, coord in enumerate(non_ocr_boxes):
        # Boxes are normalized; scale to pixel coordinates before cropping.
        xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
        ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
        cropped_image = image_source[ymin:ymax, xmin:xmax, :]
        croped_pil_image.append(to_pil(cropped_image))
    model, processor = caption_model_processor['model'], caption_model_processor['processor']
    device = model.device
    messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}]
    prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    batch_size = 5  # Number of samples per batch
    generated_texts = []
    for i in range(0, len(croped_pil_image), batch_size):
        images = croped_pil_image[i:i+batch_size]
        image_inputs = [processor.image_processor(x, return_tensors="pt") for x in images]
        inputs = {'input_ids': [], 'attention_mask': [], 'pixel_values': [], 'image_sizes': []}
        texts = [prompt] * len(images)
        for i, txt in enumerate(texts):
            # Uses a private processor API to pair each crop with the prompt.
            input = processor._convert_images_texts_to_inputs(image_inputs[i], txt, return_tensors="pt")
            inputs['input_ids'].append(input['input_ids'])
            inputs['attention_mask'].append(input['attention_mask'])
            inputs['pixel_values'].append(input['pixel_values'])
            inputs['image_sizes'].append(input['image_sizes'])
        # Left-pad every sample in the batch to the longest sequence so the
        # tensors can be concatenated; padding gets attention_mask 0.
        max_len = max([x.shape[1] for x in inputs['input_ids']])
        for i, v in enumerate(inputs['input_ids']):
            inputs['input_ids'][i] = torch.cat([processor.tokenizer.pad_token_id * torch.ones(1, max_len - v.shape[1], dtype=torch.long), v], dim=1)
            inputs['attention_mask'][i] = torch.cat([torch.zeros(1, max_len - v.shape[1], dtype=torch.long), inputs['attention_mask'][i]], dim=1)
        inputs_cat = {k: torch.concatenate(v).to(device) for k, v in inputs.items()}
        generation_args = {
            "max_new_tokens": 25,
            "temperature": 0.01,
            "do_sample": False,
        }
        generate_ids = model.generate(**inputs_cat, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
        # Strip the prompt tokens so only the generated caption is decoded.
        generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:]
        response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        response = [res.strip('\n').strip() for res in response]
        generated_texts.extend(response)
    return generated_texts
def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
    """Filter detection boxes, preferring smaller boxes on heavy overlap.

    Every OCR box is kept unconditionally. A detection box is kept only
    when (a) no other detection overlapping it beyond `iou_threshold` is
    smaller, and (b) when `ocr_bbox` is given, it does not conflict with
    any OCR box it is not essentially contained in.

    Args:
        boxes: tensor of (x1, y1, x2, y2) detection boxes.
        iou_threshold: overlap ratio above which the larger box is dropped.
        ocr_bbox: optional list of OCR (x1, y1, x2, y2) boxes.

    Returns:
        torch.Tensor of the retained boxes (OCR boxes first).
    """
    assert ocr_bbox is None or isinstance(ocr_bbox, List)

    def area(b):
        return (b[2] - b[0]) * (b[3] - b[1])

    def overlap(a, b):
        w = min(a[2], b[2]) - max(a[0], b[0])
        h = min(a[3], b[3]) - max(a[1], b[1])
        return max(0, w) * max(0, h)

    def iou(a, b):
        # Max of IoU and per-box intersection ratios; epsilon avoids /0.
        inter = overlap(a, b)
        union = area(a) + area(b) - inter + 1e-6
        if area(a) > 0 and area(b) > 0:
            r1, r2 = inter / area(a), inter / area(b)
        else:
            r1 = r2 = 0
        return max(inter / union, r1, r2)

    def inside(a, b):
        # "a inside b" == at least 95% of a's area overlaps b.
        return overlap(a, b) / area(a) > 0.95

    candidates = boxes.tolist()
    kept = list(ocr_bbox) if ocr_bbox else []
    for i, box1 in enumerate(candidates):
        # Of two heavily-overlapping detections, keep only the smaller one.
        dominated = any(
            i != j and iou(box1, box2) > iou_threshold and area(box1) > area(box2)
            for j, box2 in enumerate(candidates)
        )
        if dominated:
            continue
        # Drop boxes conflicting with an OCR box they are not inside of.
        if ocr_bbox and any(
            iou(box1, box3) > iou_threshold and not inside(box1, box3)
            for box3 in ocr_bbox
        ):
            continue
        kept.append(box1)
    return torch.tensor(kept)
def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
    '''
    Merge YOLO icon boxes with OCR text boxes, removing duplicates.

    ocr_bbox format: [{'type': 'text', 'bbox':[x,y], 'interactivity':False, 'content':str }, ...]
    boxes format: [{'type': 'icon', 'bbox':[x,y], 'interactivity':True, 'content':None }, ...]

    Rules applied per icon box:
      - an icon that heavily overlaps a *smaller* icon is discarded;
      - OCR boxes that fall inside an icon are absorbed: they are removed
        from the result and their text becomes the icon's content;
      - an icon that falls inside an OCR box is discarded (the text wins).

    Returns a list of element dicts (OCR boxes first, then surviving icons).
    '''
    assert ocr_bbox is None or isinstance(ocr_bbox, List)
    def box_area(box):
        return (box[2] - box[0]) * (box[3] - box[1])
    def intersection_area(box1, box2):
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])
        return max(0, x2 - x1) * max(0, y2 - y1)
    def IoU(box1, box2):
        intersection = intersection_area(box1, box2)
        union = box_area(box1) + box_area(box2) - intersection + 1e-6
        # also track per-box coverage ratios so containment is detected even
        # when the union-based IoU is small
        if box_area(box1) > 0 and box_area(box2) > 0:
            ratio1 = intersection / box_area(box1)
            ratio2 = intersection / box_area(box2)
        else:
            ratio1, ratio2 = 0, 0
        return max(intersection / union, ratio1, ratio2)
    def is_inside(box1, box2):
        # box1 is "inside" box2 when >80% of box1's area is covered by box2
        intersection = intersection_area(box1, box2)
        ratio1 = intersection / box_area(box1)
        return ratio1 > 0.80
    filtered_boxes = []
    if ocr_bbox:
        filtered_boxes.extend(ocr_bbox)
    for i, box1_elem in enumerate(boxes):
        box1 = box1_elem['bbox']
        is_valid_box = True
        for j, box2_elem in enumerate(boxes):
            # keep the smaller of two heavily-overlapping icon boxes
            box2 = box2_elem['bbox']
            if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
                is_valid_box = False
                break
        if not is_valid_box:
            continue
        if not ocr_bbox:
            filtered_boxes.append(box1)
            continue
        # keep yolo boxes + prioritize ocr label
        box_added = False
        ocr_labels = ''
        for box3_elem in ocr_bbox:
            box3 = box3_elem['bbox']
            if is_inside(box3, box1):  # ocr inside icon: absorb its text
                try:
                    # gather all ocr labels for this icon
                    ocr_labels += box3_elem['content'] + ' '
                    filtered_boxes.remove(box3_elem)
                except (TypeError, ValueError):
                    # TypeError: content is None; ValueError: the OCR element
                    # was already absorbed by a previous icon.
                    # (was a bare `except`, narrowed to the expected failures)
                    continue
            elif is_inside(box1, box3):
                # icon inside ocr: drop this icon; OCR boxes don't overlap each
                # other, so the icon can only be inside one of them
                box_added = True
                break
        if not box_added:
            if ocr_labels:
                filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': ocr_labels, 'source':'box_yolo_content_ocr'})
            else:
                filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None, 'source':'box_yolo_content_yolo'})
    return filtered_boxes
def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
    """Load an image from disk and return both its raw numpy form and the
    normalized tensor expected by the detection model.

    Args:
        image_path: path to the image file.

    Returns:
        (raw RGB image as np.ndarray, transformed torch.Tensor)
    """
    preprocess = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            # ImageNet mean/std normalization
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    pil_image = Image.open(image_path).convert("RGB")
    raw_image = np.asarray(pil_image)
    tensor_image, _ = preprocess(pil_image, None)
    return raw_image, tensor_image
def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str], text_scale: float,
             text_padding=5, text_thickness=2, thickness=3) -> np.ndarray:
    """
    Annotate an image with bounding boxes and index labels.

    Parameters:
        image_source (np.ndarray): the source image to annotate.
        boxes (torch.Tensor): boxes in cxcywh format, normalized to [0, 1].
        logits (torch.Tensor): confidence score per box (not rendered).
        phrases (List[str]): one key per box for the returned coordinate dict.
        text_scale (float): label text scale. 0.8 for mobile/web, 0.3 for desktop, 0.4 for mind2web.

    Returns:
        (annotated image as np.ndarray, dict mapping phrase -> xywh pixel coords)
    """
    h, w, _ = image_source.shape
    pixel_boxes = boxes * torch.Tensor([w, h, w, h])
    xyxy = box_convert(boxes=pixel_boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
    xywh = box_convert(boxes=pixel_boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
    detections = sv.Detections(xyxy=xyxy)
    # labels are simply the box indices rendered as strings
    labels = [f"{idx}" for idx in range(pixel_boxes.shape[0])]
    box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,
                                 text_thickness=text_thickness, thickness=thickness)
    annotated_frame = box_annotator.annotate(scene=image_source.copy(), detections=detections,
                                             labels=labels, image_size=(w, h))
    label_coordinates = {f"{phrase}": coords for phrase, coords in zip(phrases, xywh)}
    return annotated_frame, label_coordinates
def predict(model, image, caption, box_threshold, text_threshold):
    """Run grounded object detection with a huggingface model.

    Args:
        model: dict with 'model' (the HF model) and 'processor' keys.
        image: PIL image to run detection on.
        caption: text prompt describing what to detect.
        box_threshold: box-confidence cutoff for post-processing.
        text_threshold: text-confidence cutoff for post-processing.

    Returns:
        (boxes, logits, phrases) from the grounded-detection post-processing.
    """
    hf_model, processor = model['model'], model['processor']
    inputs = processor(images=image, text=caption, return_tensors="pt").to(hf_model.device)
    with torch.no_grad():
        outputs = hf_model(**inputs)
    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        # (height, width) of the original image
        target_sizes=[image.size[::-1]]
    )[0]
    return results["boxes"], results["scores"], results["labels"]
def predict_yolo(model, image, box_threshold, imgsz, scale_img, iou_threshold=0.7):
    """Run a YOLO model on an image and return detections in pixel space.

    Args:
        model: ultralytics YOLO model.
        image: input image (path, array, or PIL image).
        box_threshold: confidence cutoff.
        imgsz: inference size; only applied when scale_img is True.
        scale_img: whether to pass imgsz to the predictor.
        iou_threshold: NMS IoU threshold (ultralytics default is 0.7).

    Returns:
        (xyxy boxes tensor in pixel space, confidence tensor, index-string phrases)
    """
    predict_kwargs = dict(source=image, conf=box_threshold, iou=iou_threshold)
    if scale_img:
        predict_kwargs['imgsz'] = imgsz
    result = model.predict(**predict_kwargs)
    detections = result[0].boxes
    boxes = detections.xyxy  # in pixel space
    conf = detections.conf
    phrases = [str(i) for i in range(len(boxes))]
    return boxes, conf, phrases
def int_box_area(box, w, h):
    """Area (in pixels) of a normalized xyxy box after scaling to a w x h
    image and truncating each coordinate to an int."""
    scaled = [int(coord * dim) for coord, dim in zip(box, (w, h, w, h))]
    return (scaled[2] - scaled[0]) * (scaled[3] - scaled[1])
def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=128):
    """Detect, caption and annotate UI elements in a screenshot.

    Runs YOLO to find icon boxes, merges them with the given OCR boxes,
    optionally captions un-labeled icons with the caption model, and draws a
    Set-of-Marks annotated image.

    Args:
        image_source: Either a file path (str) or PIL Image object
        model: YOLO detection model forwarded to predict_yolo().
        BOX_TRESHOLD: YOLO confidence threshold.
        output_coord_in_ratio: if True, label coordinates are returned
            normalized by the image size instead of in pixels.
        ocr_bbox: OCR boxes in pixel xyxy format (may be None/empty).
        ocr_text: texts aligned with ocr_bbox.
        caption_model_processor: dict with 'model'/'processor' used when
            use_local_semantics is True.
        draw_bbox_config: optional kwargs forwarded to annotate().
        iou_threshold: overlap threshold used when merging boxes.
        prompt, batch_size: forwarded to the icon captioning helper.
        scale_img, imgsz: YOLO inference-size controls.

    Returns:
        (base64-encoded annotated PNG, {label: xywh coordinates},
         list of parsed element dicts)
    """
    if isinstance(image_source, str):
        image_source = Image.open(image_source)
    image_source = image_source.convert("RGB") # for CLIP
    w, h = image_source.size
    if not imgsz:
        imgsz = (h, w)
    xyxy, logits, phrases = predict_yolo(model=model, image=image_source, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1)
    # normalize YOLO boxes to [0, 1] ratios
    xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
    image_source = np.asarray(image_source)
    phrases = [str(i) for i in range(len(phrases))]
    if ocr_bbox:
        ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
        ocr_bbox = ocr_bbox.tolist()
    else:
        print('no ocr bbox!!!')
        # BUGFIX: use an empty list instead of None so the zip() below does
        # not raise TypeError when no OCR boxes are available
        ocr_bbox = []
    # build element dicts, dropping degenerate boxes with zero pixel area
    ocr_bbox_elem = [{'type': 'text', 'bbox':box, 'interactivity':False, 'content':txt, 'source': 'box_ocr_content_ocr'} for box, txt in zip(ocr_bbox, ocr_text) if int_box_area(box, w, h) > 0]
    xyxy_elem = [{'type': 'icon', 'bbox':box, 'interactivity':True, 'content':None} for box in xyxy.tolist() if int_box_area(box, w, h) > 0]
    filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)
    # sort so that boxes with 'content': None come last, then find the index
    # of the first un-captioned box
    filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None)
    starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
    filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem])
    print('len(filtered_boxes):', len(filtered_boxes), starting_idx)
    # get parsed icon local semantics
    time1 = time.time()
    if use_local_semantics:
        caption_model = caption_model_processor['model']
        if 'phi3_v' in caption_model.config.model_type:
            parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor)
        else:
            parsed_content_icon = get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=prompt,batch_size=batch_size)
        ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
        icon_start = len(ocr_text)
        parsed_content_icon_ls = []
        # fill the filtered_boxes_elem None content with parsed_content_icon in order
        for i, box in enumerate(filtered_boxes_elem):
            if box['content'] is None:
                box['content'] = parsed_content_icon.pop(0)
        for i, txt in enumerate(parsed_content_icon):
            parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
        parsed_content_merged = ocr_text + parsed_content_icon_ls
    else:
        ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
        parsed_content_merged = ocr_text
    print('time to get parsed content:', time.time()-time1)
    # annotate() expects cxcywh boxes
    filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh")
    phrases = [i for i in range(len(filtered_boxes))]
    # draw boxes
    if draw_bbox_config:
        annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config)
    else:
        annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, text_scale=text_scale, text_padding=text_padding)
    pil_img = Image.fromarray(annotated_frame)
    buffered = io.BytesIO()
    pil_img.save(buffered, format="PNG")
    encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
    if output_coord_in_ratio:
        label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()}
    assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]
    return encoded_image, label_coordinates, filtered_boxes_elem
def get_xywh(input):
    """Convert a 4-point OCR polygon ([top-left, top-right, bottom-right,
    bottom-left] corner pairs) to an integer (x, y, w, h) rectangle."""
    (left, top), (right, bottom) = input[0], input[2]
    return int(left), int(top), int(right - left), int(bottom - top)
def get_xyxy(input):
    """Convert a 4-point OCR polygon to integer (x1, y1, x2, y2) corners,
    taking the top-left and bottom-right points."""
    top_left, bottom_right = input[0], input[2]
    return int(top_left[0]), int(top_left[1]), int(bottom_right[0]), int(bottom_right[1])
def get_xywh_yolo(input):
    """Convert a flat xyxy box [x1, y1, x2, y2] to integer (x, y, w, h)."""
    x1, y1, x2, y2 = input
    return int(x1), int(y1), int(x2 - x1), int(y2 - y1)
def check_ocr_box(image_source: Union[str, Image.Image], display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
    """Run OCR (PaddleOCR or EasyOCR) on an image and return the detected
    texts with their bounding boxes.

    Args:
        image_source: file path or PIL image.
        display_img: if True, draw the boxes with matplotlib (boxes are
            returned in xywh form); otherwise return them in output_bb_format.
        output_bb_format: 'xywh' or 'xyxy' (used when display_img is False;
            NOTE(review): any other value leaves `bb` unbound — confirm callers
            only pass these two).
        goal_filtering: passed through unchanged in the return value.
        easyocr_args: kwargs for EasyOCR; for PaddleOCR only the
            'text_threshold' key is read (and must be present when given).
        use_paddleocr: choose PaddleOCR over EasyOCR.

    Returns:
        ((texts, bounding_boxes), goal_filtering)
    """
    if isinstance(image_source, str):
        image_source = Image.open(image_source)
    if image_source.mode == 'RGBA':
        # drop the alpha channel to avoid downstream issues
        image_source = image_source.convert('RGB')
    image_np = np.array(image_source)
    w, h = image_source.size
    if use_paddleocr:
        threshold = 0.5 if easyocr_args is None else easyocr_args['text_threshold']
        raw = paddle_ocr.ocr(image_np, cls=False)[0]
        # keep only detections above the confidence threshold
        confident = [entry for entry in raw if entry[1][1] > threshold]
        coord = [entry[0] for entry in confident]
        text = [entry[1][0] for entry in confident]
    else:  # EasyOCR
        raw = reader.readtext(image_np, **(easyocr_args or {}))
        coord = [entry[0] for entry in raw]
        text = [entry[1] for entry in raw]
    if display_img:
        opencv_img = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
        bb = []
        for item in coord:
            x, y, bw, bh = get_xywh(item)
            bb.append((x, y, bw, bh))
            cv2.rectangle(opencv_img, (x, y), (x + bw, y + bh), (0, 255, 0), 2)
        # matplotlib expects RGB
        plt.imshow(cv2.cvtColor(opencv_img, cv2.COLOR_BGR2RGB))
    else:
        if output_bb_format == 'xywh':
            bb = [get_xywh(item) for item in coord]
        elif output_bb_format == 'xyxy':
            bb = [get_xyxy(item) for item in coord]
    return (text, bb), goal_filtering