增加了处理后图片的返回

This commit is contained in:
Dan Li
2025-03-12 09:01:01 +03:00
parent 458a0c5cbe
commit a34abfa61e

View File

@@ -42,27 +42,32 @@ class VisionAgent:
# 根据设备类型加载模型
try:
print(f"正在加载图像描述模型: {caption_model_path}")
if self.device.type == 'cuda':
# CUDA设备使用float16
self.caption_model = AutoModelForCausalLM.from_pretrained(
caption_model_path,
torch_dtype=torch.float16,
trust_remote_code=True
).to(self.device)
elif self.device.type == 'mps':
# MPS设备使用float32（MPS对float16支持有限）
self.caption_model = AutoModelForCausalLM.from_pretrained(
caption_model_path,
torch_dtype=torch.float32,
trust_remote_code=True
).to(self.device)
else:
# CPU使用float32
self.caption_model = AutoModelForCausalLM.from_pretrained(
caption_model_path,
torch_dtype=torch.float32,
trust_remote_code=True
).to(self.device)
# if self.device.type == 'cuda':
# # CUDA设备使用float16
# self.caption_model = AutoModelForCausalLM.from_pretrained(
# caption_model_path,
# torch_dtype=torch.float16,
# trust_remote_code=True
# ).to(self.device)
# elif self.device.type == 'mps':
# # MPS设备使用float32（MPS对float16支持有限）
# self.caption_model = AutoModelForCausalLM.from_pretrained(
# caption_model_path,
# torch_dtype=torch.float32,
# trust_remote_code=True
# ).to(self.device)
# else:
# # CPU使用float32
# self.caption_model = AutoModelForCausalLM.from_pretrained(
# caption_model_path,
# torch_dtype=torch.float32,
# trust_remote_code=True
# ).to(self.device)
self.caption_model = AutoModelForCausalLM.from_pretrained(
caption_model_path,
torch_dtype=self.dtype,
trust_remote_code=True
).to(self.device)
print("图像描述模型加载成功")
except Exception as e:
@@ -75,9 +80,9 @@ class VisionAgent:
if self.device.type == 'cuda':
self.batch_size = 128 # CUDA设备使用较大批处理大小
elif self.device.type == 'mps':
self.batch_size = 32 # MPS设备使用中等批处理大小
self.batch_size = 64 # MPS设备使用中等批处理大小
else:
self.batch_size = 16 # CPU使用较小批处理大小
self.batch_size = 32 # CPU使用较小批处理大小
self.elements: List[UIElement] = []
self.ocr_reader = easyocr.Reader(['en', 'ch_sim'])
@@ -88,6 +93,7 @@ class VisionAgent:
image = cv2.imread(image_path)
if image is None:
raise FileNotFoundError(f"Vision agent: 图片读取失败")
return self.analyze_image(image)
def _get_optimal_device_and_dtype(self):
@@ -133,7 +139,8 @@ class VisionAgent:
"""
self._reset_state()
element_crops, boxes = self._detect_objects(image)
element_crops, boxes, annotated_image = self._detect_objects(image)
cv2.imwrite("annotated_image.jpg", annotated_image)
start = time.time()
element_texts = self._extract_text(element_crops)
end = time.time()
@@ -152,7 +159,7 @@ class VisionAgent:
)
self.elements.append(new_element)
return self.elements
return self.elements, annotated_image
def _extract_text(self, images: np.ndarray) -> list[str]:
"""
@@ -246,6 +253,16 @@ class VisionAgent:
results = self.yolo_model(image)[0]
detections = sv.Detections.from_ultralytics(results)
boxes = detections.xyxy
box_annotator = sv.BoxAnnotator()
label_annotator = sv.LabelAnnotator()
labels = [
f"{idx}" for idx, box in enumerate(detections.xyxy)]
annotated_image = box_annotator.annotate(
scene=image, detections=detections)
annotated_image = label_annotator.annotate(
scene=annotated_image, detections=detections, labels=labels)
if len(boxes) == 0:
return []
@@ -279,7 +296,7 @@ class VisionAgent:
element = image[y1:y2, x1:x2]
element_crops.append(np.array(element))
return element_crops, filtered_boxes
return element_crops, filtered_boxes, annotated_image
def load_image(self, image_source: str) -> np.ndarray:
try:
@@ -304,9 +321,4 @@ class VisionAgent:
except (base64.binascii.Error, ValueError) as e:
# 生成更清晰的错误信息
error_msg = f"输入既不是有效的文件路径也不是有效的Base64图片数据"
raise ValueError(error_msg) from e
raise ValueError(error_msg) from e