from typing import List, Optional

import base64
import binascii
import time

import cv2
import easyocr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoProcessor
from ultralytics import YOLO


class UIElement(BaseModel):
    element_id: int
    coordinates: list[float]
    caption: Optional[str] = None
    text: Optional[str] = None


class VisionAgent:
    def __init__(self, yolo_model_path: str, caption_model_path: str):
        """
        Initialize the vision agent.

        Parameters:
            yolo_model_path: Path to YOLO model
            caption_model_path: Path to image caption model
        """
        # determine the available device and the best dtype
        self.device, self.dtype = self._get_optimal_device_and_dtype()

        # load the YOLO model
        self.yolo_model = YOLO(yolo_model_path)

        # load the caption processor; note this is pinned to the local
        # Florence-2 weights rather than derived from caption_model_path
        self.caption_processor = AutoProcessor.from_pretrained(
            "weights/AI-ModelScope/Florence-2-base",
            trust_remote_code=True,
            local_files_only=True,
        )
        try:
            self.caption_model = AutoModelForCausalLM.from_pretrained(
                caption_model_path,
                torch_dtype=self.dtype,
                trust_remote_code=True,
            ).to(self.device)
        except Exception:
            print(f"Model loading failed for path: {caption_model_path}")
            raise
        self.prompt = ""

        # set the batch size: GPU backends can take large batches
        if self.device.type in ('cuda', 'mps'):
            self.batch_size = 128
        else:
            self.batch_size = 16

        self.elements: List[UIElement] = []
        self.ocr_reader = easyocr.Reader(['en', 'ch_sim'])

    def __call__(self, image_path: str) -> List[UIElement]:
        """Process an image from a file path."""
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Vision agent: failed to read image at {image_path}")
        return self.analyze_image(image)

    def _get_optimal_device_and_dtype(self):
        """Determine the optimal device and dtype."""
        if torch.cuda.is_available():
            device = torch.device("cuda")
            # check if the GPU is suitable for float16
            capability = torch.cuda.get_device_capability()
            # only use float16 on newer GPUs (compute capability >= 7.0)
            if capability[0] >= 7:
                dtype = torch.float16
            else:
                dtype = torch.float32
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            device = torch.device("mps")
            dtype = torch.float32
        else:
            device = torch.device("cpu")
            dtype = torch.float32
        return device, dtype

    def _reset_state(self):
        """Clear previous analysis results."""
        self.elements = []
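    # The analysis pipeline below runs three stages over a single screenshot:
    #   1. _detect_objects - YOLO detection, cropped to per-element images
    #   2. _extract_text   - EasyOCR over each crop
    #   3. _get_caption    - Florence-2 captioning over each crop
    # Stage timings are printed in wall-clock milliseconds.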
    def analyze_image(self, image: np.ndarray) -> List[UIElement]:
        """
        Process an image through all computer vision pipelines.

        Args:
            image: Input image in BGR format (OpenCV default)

        Returns:
            List of detected UI elements with annotations
        """
        self._reset_state()
        element_crops, boxes = self._detect_objects(image)

        start = time.time()
        element_texts = self._extract_text(element_crops)
        end = time.time()
        ocr_time = (end - start) * 10 ** 3
        print(f"Speed: {ocr_time:.2f} ms OCR of {len(element_texts)} icons.")

        start = time.time()
        # a batch size of 5 overrides the device default set in __init__
        element_captions = self._get_caption(element_crops, 5)
        end = time.time()
        caption_time = (end - start) * 10 ** 3
        print(f"Speed: {caption_time:.2f} ms captioning of {len(element_captions)} icons.")

        for idx in range(len(element_crops)):
            new_element = UIElement(
                element_id=idx,
                coordinates=boxes[idx].tolist(),
                text=element_texts[idx][0] if len(element_texts[idx]) > 0 else '',
                caption=element_captions[idx],
            )
            self.elements.append(new_element)
        return self.elements

    def _extract_text(self, images: list[np.ndarray]) -> list[list[str]]:
        """
        Run OCR in sequential mode.

        TODO: It is possible to run in batch mode for a speed-up, but the
        result quality needs testing.
        https://github.com/JaidedAI/EasyOCR/pull/458
        """
        texts = []
        for image in images:
            text = self.ocr_reader.readtext(image, detail=0, paragraph=True, text_threshold=0.85)
            texts.append(text)
        return texts

    def _get_caption(self, element_crops, batch_size=None):
        """Get a caption for each element crop."""
        if not element_crops:
            return []
        # if batch_size is not specified, use the instance's default value
        if batch_size is None:
            batch_size = self.batch_size

        # resize every crop to 64x64
        resized_crops = []
        for img in element_crops:
            # crops come from OpenCV in BGR order; convert to RGB before
            # wrapping in PIL so the caption model sees correct colors
            img_np = cv2.cvtColor(np.array(img), cv2.COLOR_BGR2RGB)
            resized_np = cv2.resize(img_np, (64, 64))
            resized_crops.append(Image.fromarray(resized_np))

        generated_texts = []
        device = self.device

        # process in batches
        for i in range(0, len(resized_crops), batch_size):
            batch = resized_crops[i:i + batch_size]
            try:
                # cast inputs to the dtype the model was loaded with so the
                # two always match (float16 only on capable CUDA GPUs)
                if device.type == 'cuda':
                    inputs = self.caption_processor(
                        images=batch,
                        text=[self.prompt] * len(batch),
                        return_tensors="pt",
                        do_resize=False,
                    ).to(device=device, dtype=self.dtype)
                else:
                    # MPS and CPU use float32
                    inputs = self.caption_processor(
                        images=batch,
                        text=[self.prompt] * len(batch),
                        return_tensors="pt",
                    ).to(device=device)

                # special treatment for Florence-2
                with torch.no_grad():
                    if 'florence' in self.caption_model.config.model_type:
                        generated_ids = self.caption_model.generate(
                            input_ids=inputs["input_ids"],
                            pixel_values=inputs["pixel_values"],
                            max_new_tokens=20,
                            num_beams=5,
                            do_sample=False,
                        )
                    else:
                        generated_ids = self.caption_model.generate(
                            **inputs,
                            max_length=50,
                            num_beams=3,
                            early_stopping=True,
                        )

                # decode the generated IDs
                texts = self.caption_processor.batch_decode(
                    generated_ids,
                    skip_special_tokens=True,
                )
                texts = [text.strip() for text in texts]
                generated_texts.extend(texts)

                # clean the cache
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
            except RuntimeError:
                # propagate GPU errors (e.g. CUDA OOM) to the caller
                raise

        return generated_texts
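    # Containment filtering in _detect_objects relies on the area sort:
    # boxes are visited largest-first, so any box able to contain the
    # current one appears earlier, and comparing against only the kept
    # earlier boxes is enough to drop fully nested duplicates.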
    def _detect_objects(self, image: np.ndarray) -> tuple[list[np.ndarray], np.ndarray]:
        """Run the object detection pipeline."""
        results = self.yolo_model(image)[0]
        detections = sv.Detections.from_ultralytics(results)
        boxes = detections.xyxy
        if len(boxes) == 0:
            # return an empty crop list and the empty box array, matching
            # the tuple the caller unpacks
            return [], boxes

        # Filter out boxes contained by others
        areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        sorted_indices = np.argsort(-areas)  # Sort descending by area
        sorted_boxes = boxes[sorted_indices]

        keep_sorted = []
        for i in range(len(sorted_boxes)):
            contained = False
            for j in keep_sorted:
                box_b = sorted_boxes[j]
                box_a = sorted_boxes[i]
                if (box_b[0] <= box_a[0] and box_b[1] <= box_a[1] and
                        box_b[2] >= box_a[2] and box_b[3] >= box_a[3]):
                    contained = True
                    break
            if not contained:
                keep_sorted.append(i)

        # Map back to original indices
        keep_indices = sorted_indices[keep_sorted]
        filtered_boxes = boxes[keep_indices]

        # Extract element crops
        element_crops = []
        for box in filtered_boxes:
            x1, y1, x2, y2 = map(int, map(round, box))
            element = image[y1:y2, x1:x2]
            element_crops.append(np.array(element))

        return element_crops, filtered_boxes

    def load_image(self, image_source: str) -> List[UIElement]:
        """Decode a Base64 (or data-URL) image and run the full analysis."""
        try:
            # Handle a potential Data URL prefix (like "data:image/png;base64,")
            if ',' in image_source:
                _, payload = image_source.split(',', 1)
            else:
                payload = image_source

            # Base64 decode -> bytes -> numpy array
            image_bytes = base64.b64decode(payload)
            np_array = np.frombuffer(image_bytes, dtype=np.uint8)

            # OpenCV decodes the image
            image = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
            if image is None:
                raise ValueError("Failed to decode image: invalid image data")
            return self.analyze_image(image)
        except (binascii.Error, ValueError) as e:
            # Raise a clearer error message
            error_msg = "Input is neither a valid file path nor valid Base64 image data"
            raise ValueError(error_msg) from e
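
# Usage sketch: a minimal end-to-end run when this module is executed as a
# script. The weight paths are placeholders, not files this module ships
# with; point them at a real UI-element YOLO checkpoint and a local
# Florence-2 checkpoint (the same directory the processor is loaded from).
if __name__ == "__main__":
    agent = VisionAgent(
        yolo_model_path="weights/icon_detect/model.pt",  # placeholder path
        caption_model_path="weights/AI-ModelScope/Florence-2-base",  # matches the processor path above
    )
    elements = agent("screenshot.png")  # placeholder screenshot on disk
    for element in elements:
        print(element.element_id, element.coordinates, element.text, element.caption)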