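"""Vision agent for UI screenshot analysis.

Runs a YOLO detector over a screenshot, OCRs each detected region with
EasyOCR, and captions it with a locally stored Florence-2 model, producing
a list of UIElement records (bounding box, OCR text, caption).
"""
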
import base64
import os
import time
from typing import List, Optional

import cv2
import easyocr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from pydantic import BaseModel
from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor
from ultralytics import YOLO


class UIElement(BaseModel):
    element_id: int
    coordinates: list[float]
    caption: Optional[str] = None
    text: Optional[str] = None
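
# Example of a populated element (illustrative values only):
#   UIElement(element_id=0, coordinates=[12.0, 8.0, 96.0, 40.0],
#             text="Submit", caption="a blue submit button")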


class VisionAgent:
    def __init__(self, yolo_model_path: str, caption_model_path: str):
        """
        Initialize the vision agent

        Parameters:
            yolo_model_path: Path to YOLO model
            caption_model_path: Path to image caption model
        """
        # determine the available device and the best dtype
        self.device, self.dtype = self._get_optimal_device_and_dtype()
        # load the YOLO model
        self.yolo_model = YOLO(yolo_model_path)

        # load the image caption model and processor
        self.caption_processor = AutoProcessor.from_pretrained(
            "weights/AI-ModelScope/Florence-2-base-ft",
            trust_remote_code=True,
            local_files_only=True
        )
        config = AutoConfig.from_pretrained(
            "weights/AI-ModelScope/Florence-2-base-ft",  # directory containing configuration_florence2.py
            trust_remote_code=True,
            local_files_only=True
        )

        try:
            # load the model and its weights entirely from the Florence directory
            florence_base_path = "weights/AI-ModelScope/Florence-2-base-ft"

            # load the full model (code and weights) directly from the Florence
            # directory; no separate weight loading is needed because the weights
            # are already included in florence_base_path
            self.caption_model = AutoModelForCausalLM.from_pretrained(
                florence_base_path,
                torch_dtype=self.dtype,
                trust_remote_code=True,
                local_files_only=True
            ).to(self.device)
        except Exception as e:
            print(f"Model loading failed for path {florence_base_path}: {e}")
            raise
        self.prompt = "<CAPTION>"

        # set the batch size: large batches on GPU-class devices, small on CPU
        if self.device.type in ('cuda', 'mps'):
            self.batch_size = 128
        else:
            self.batch_size = 16

        self.elements: List[UIElement] = []
        self.ocr_reader = easyocr.Reader(['en', 'ch_sim'])

    def __call__(self, image_path: str) -> List[UIElement]:
        """Process an image from a file path."""
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Vision agent: failed to read image at {image_path}")
        return self.analyze_image(image)

    def _get_optimal_device_and_dtype(self):
        """Determine the optimal device and dtype."""
        if torch.cuda.is_available():
            device = torch.device("cuda")
            # check whether the GPU is suitable for float16:
            # compute capability 7.0 (Volta) and newer have fast fp16 support
            capability = torch.cuda.get_device_capability()
            if capability[0] >= 7:
                dtype = torch.float16
            else:
                dtype = torch.float32
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            device = torch.device("mps")
            dtype = torch.float32
        else:
            device = torch.device("cpu")
            dtype = torch.float32

        return device, dtype

    def _reset_state(self):
        """Clear previous analysis results."""
        self.elements = []

    def analyze_image(self, image: np.ndarray) -> List[UIElement]:
        """
        Process an image through all computer vision pipelines.

        Args:
            image: Input image in BGR format (OpenCV default)

        Returns:
            List of detected UI elements with annotations
        """
        self._reset_state()

        element_crops, boxes = self._detect_objects(image)

        start = time.time()
        element_texts = self._extract_text(element_crops)
        ocr_time = (time.time() - start) * 1000
        print(f"Speed: {ocr_time:.2f} ms OCR of {len(element_texts)} icons.")

        start = time.time()
        element_captions = self._get_caption(element_crops, 5)
        caption_time = (time.time() - start) * 1000
        print(f"Speed: {caption_time:.2f} ms captioning of {len(element_captions)} icons.")

        for idx in range(len(element_crops)):
            new_element = UIElement(
                element_id=idx,
                coordinates=boxes[idx],
                text=element_texts[idx][0] if len(element_texts[idx]) > 0 else '',
                caption=element_captions[idx]
            )
            self.elements.append(new_element)

        return self.elements

    def _extract_text(self, images: list[np.ndarray]) -> list[list[str]]:
        """
        Run OCR in sequential mode.

        TODO: It is possible to run in batch mode for a speedup, but the result
        quality needs testing: https://github.com/JaidedAI/EasyOCR/pull/458
        """
        texts = []
        for image in images:
            # detail=0 returns plain strings; paragraph=True merges adjacent lines
            text = self.ocr_reader.readtext(image, detail=0, paragraph=True, text_threshold=0.85)
            texts.append(text)
        return texts

    def _get_caption(self, element_crops, batch_size=None):
        """Get the captions of the element crops."""
        if not element_crops:
            return []

        # if batch_size is not specified, use the instance's default value
        if batch_size is None:
            batch_size = self.batch_size

        # resize every crop to 64x64
        resized_crops = []
        for img in element_crops:
            # convert to numpy array, resize, then convert back to PIL
            img_np = np.array(img)
            resized_np = cv2.resize(img_np, (64, 64))
            resized_crops.append(Image.fromarray(resized_np))

        generated_texts = []
        device = self.device

        # process in batches
        for i in range(0, len(resized_crops), batch_size):
            batch = resized_crops[i:i + batch_size]
            try:
                # select the dtype according to the device type; self.dtype is
                # float16 only on GPUs that support it, so use it here to keep
                # the inputs consistent with the model weights
                if device.type == 'cuda':
                    inputs = self.caption_processor(
                        images=batch,
                        text=[self.prompt] * len(batch),
                        return_tensors="pt",
                        do_resize=False
                    ).to(device=device, dtype=self.dtype)
                else:
                    # MPS and CPU use float32
                    inputs = self.caption_processor(
                        images=batch,
                        text=[self.prompt] * len(batch),
                        return_tensors="pt"
                    ).to(device=device)

                # special treatment for Florence-2
                with torch.no_grad():
                    if 'florence' in self.caption_model.config.model_type:
                        generated_ids = self.caption_model.generate(
                            input_ids=inputs["input_ids"],
                            pixel_values=inputs["pixel_values"],
                            max_new_tokens=20,
                            num_beams=5,
                            do_sample=False
                        )
                    else:
                        generated_ids = self.caption_model.generate(
                            **inputs,
                            max_length=50,
                            num_beams=3,
                            early_stopping=True
                        )

                # decode the generated IDs
                texts = self.caption_processor.batch_decode(
                    generated_ids,
                    skip_special_tokens=True
                )
                texts = [text.strip() for text in texts]
                generated_texts.extend(texts)

                # free cached GPU memory between batches
                if device.type == 'cuda' and torch.cuda.is_available():
                    torch.cuda.empty_cache()

            except RuntimeError as e:
                raise e

        return generated_texts

    def _detect_objects(self, image: np.ndarray) -> tuple[list[np.ndarray], list]:
        """Run the object detection pipeline."""
        results = self.yolo_model(image)[0]
        detections = sv.Detections.from_ultralytics(results)
        boxes = detections.xyxy

        if len(boxes) == 0:
            return [], []

        # Filter out boxes fully contained by others
        areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        sorted_indices = np.argsort(-areas)  # Sort descending by area
        sorted_boxes = boxes[sorted_indices]

        keep_sorted = []
        for i in range(len(sorted_boxes)):
            box_a = sorted_boxes[i]
            contained = False
            for j in keep_sorted:
                box_b = sorted_boxes[j]
                # box_a is contained if a kept box covers it on all four sides
                if (box_b[0] <= box_a[0] and box_b[1] <= box_a[1] and
                        box_b[2] >= box_a[2] and box_b[3] >= box_a[3]):
                    contained = True
                    break
            if not contained:
                keep_sorted.append(i)

        # Map back to original indices
        keep_indices = sorted_indices[keep_sorted]
        filtered_boxes = boxes[keep_indices]

        # Extract element crops
        element_crops = []
        for box in filtered_boxes:
            x1, y1, x2, y2 = map(int, map(round, box))
            element_crops.append(image[y1:y2, x1:x2])

        return element_crops, filtered_boxes
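
    # Worked example of the containment filter above (illustrative numbers):
    # given box A = [0, 0, 100, 100] and box B = [10, 10, 50, 50], A has the
    # larger area and is kept first; B is then dropped because A covers it on
    # all four sides, so only A's crop is OCR'd and captioned.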

    def load_image(self, image_source: str) -> np.ndarray:
        """Decode a Base64-encoded image (optionally a Data URL) into a BGR array."""
        try:
            # Handle a potential Data URL prefix (like "data:image/png;base64,")
            if ',' in image_source:
                _, payload = image_source.split(',', 1)
            else:
                payload = image_source

            # Base64 decode -> bytes -> numpy array
            image_bytes = base64.b64decode(payload)
            np_array = np.frombuffer(image_bytes, dtype=np.uint8)

            # OpenCV decode image
            image = cv2.imdecode(np_array, cv2.IMREAD_COLOR)

            if image is None:
                raise ValueError("Failed to decode image: Invalid image data")

            # return the decoded image, matching the declared return type
            return image

        except (base64.binascii.Error, ValueError) as e:
            # Generate a clearer error message
            raise ValueError("Input is not valid Base64 image data") from e
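

# A minimal usage sketch. The YOLO path and input image below are assumptions
# for illustration; they must match wherever the project actually stores its
# weights. The caption path mirrors the hard-coded Florence-2 directory above.
if __name__ == "__main__":
    agent = VisionAgent(
        yolo_model_path="weights/icon_detect/model.pt",  # assumed path
        caption_model_path="weights/AI-ModelScope/Florence-2-base-ft",
    )
    for element in agent("screenshot.png"):  # assumed input image
        print(element.element_id, element.coordinates, element.text, element.caption)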