From a34abfa61e52280efb2ce3ba0e57c6805dd44d0e Mon Sep 17 00:00:00 2001 From: Dan Li <972648237@qq.com> Date: Wed, 12 Mar 2025 09:01:01 +0300 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BA=86=E5=A4=84=E7=90=86?= =?UTF-8?q?=E5=90=8E=E5=9B=BE=E7=89=87=E7=9A=84=E8=BF=94=E5=9B=9E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gradio_ui/agent/vision_agent.py | 76 +++++++++++++++++++-------------- 1 file changed, 44 insertions(+), 32 deletions(-) diff --git a/gradio_ui/agent/vision_agent.py b/gradio_ui/agent/vision_agent.py index 254a1ff..8c87b76 100644 --- a/gradio_ui/agent/vision_agent.py +++ b/gradio_ui/agent/vision_agent.py @@ -42,27 +42,32 @@ class VisionAgent: # 根据设备类型加载模型 try: print(f"正在加载图像描述模型: {caption_model_path}") - if self.device.type == 'cuda': - # CUDA设备使用float16 - self.caption_model = AutoModelForCausalLM.from_pretrained( - caption_model_path, - torch_dtype=torch.float16, - trust_remote_code=True - ).to(self.device) - elif self.device.type == 'mps': - # MPS设备使用float32(MPS对float16支持有限) - self.caption_model = AutoModelForCausalLM.from_pretrained( - caption_model_path, - torch_dtype=torch.float32, - trust_remote_code=True - ).to(self.device) - else: - # CPU使用float32 - self.caption_model = AutoModelForCausalLM.from_pretrained( - caption_model_path, - torch_dtype=torch.float32, - trust_remote_code=True - ).to(self.device) + # if self.device.type == 'cuda': + # # CUDA设备使用float16 + # self.caption_model = AutoModelForCausalLM.from_pretrained( + # caption_model_path, + # torch_dtype=torch.float16, + # trust_remote_code=True + # ).to(self.device) + # elif self.device.type == 'mps': + # # MPS设备使用float32(MPS对float16支持有限) + # self.caption_model = AutoModelForCausalLM.from_pretrained( + # caption_model_path, + # torch_dtype=torch.float32, + # trust_remote_code=True + # ).to(self.device) + # else: + # # CPU使用float32 + # self.caption_model = AutoModelForCausalLM.from_pretrained( + # caption_model_path, + # torch_dtype=torch.float32, + # trust_remote_code=True + # ).to(self.device) + self.caption_model = AutoModelForCausalLM.from_pretrained( + caption_model_path, + torch_dtype=self.dtype, + trust_remote_code=True + ).to(self.device) print("图像描述模型加载成功") except Exception as e: @@ -75,9 +80,9 @@ class VisionAgent: if self.device.type == 'cuda': self.batch_size = 128 # CUDA设备使用较大批处理大小 elif self.device.type == 'mps': - self.batch_size = 32 # MPS设备使用中等批处理大小 + self.batch_size = 64 # MPS设备使用中等批处理大小 else: - self.batch_size = 16 # CPU使用较小批处理大小 + self.batch_size = 32 # CPU使用较小批处理大小 self.elements: List[UIElement] = [] self.ocr_reader = easyocr.Reader(['en', 'ch_sim']) @@ -88,6 +93,7 @@ class VisionAgent: image = cv2.imread(image_path) if image is None: raise FileNotFoundError(f"Vision agent: 图片读取失败") + return self.analyze_image(image) def _get_optimal_device_and_dtype(self): @@ -133,7 +139,8 @@ class VisionAgent: """ self._reset_state() - element_crops, boxes = self._detect_objects(image) + element_crops, boxes, annotated_image = self._detect_objects(image) + cv2.imwrite("annotated_image.jpg", annotated_image) start = time.time() element_texts = self._extract_text(element_crops) end = time.time() @@ -152,7 +159,7 @@ class VisionAgent: ) self.elements.append(new_element) - return self.elements + return self.elements, annotated_image def _extract_text(self, images: np.ndarray) -> list[str]: """ @@ -246,6 +253,16 @@ class VisionAgent: results = self.yolo_model(image)[0] detections = sv.Detections.from_ultralytics(results) boxes = detections.xyxy + box_annotator = sv.BoxAnnotator() + label_annotator = sv.LabelAnnotator() + + labels = [ + f"{idx}" for idx, box in enumerate(detections.xyxy)] + + annotated_image = box_annotator.annotate( + scene=image, detections=detections) + annotated_image = label_annotator.annotate( + scene=annotated_image, detections=detections, labels=labels) if len(boxes) == 0: return [] @@ -279,7 +296,7 @@ class VisionAgent: element = image[y1:y2, x1:x2] element_crops.append(np.array(element)) - return element_crops, filtered_boxes + return element_crops, filtered_boxes, annotated_image def load_image(self, image_source: str) -> np.ndarray: try: @@ -304,9 +321,4 @@ class VisionAgent: except (base64.binascii.Error, ValueError) as e: # 生成更清晰的错误信息 error_msg = f"输入既不是有效的文件路径,也不是有效的Base64图片数据" - raise ValueError(error_msg) from e - - - - - \ No newline at end of file + raise ValueError(error_msg) from e \ No newline at end of file