mirror of
https://github.com/yuruotong1/autoMate.git
synced 2026-03-22 04:57:18 +08:00
增加了处理后图片的返回
This commit is contained in:
@@ -42,27 +42,32 @@ class VisionAgent:
|
||||
# 根据设备类型加载模型
|
||||
try:
|
||||
print(f"正在加载图像描述模型: {caption_model_path}")
|
||||
if self.device.type == 'cuda':
|
||||
# CUDA设备使用float16
|
||||
self.caption_model = AutoModelForCausalLM.from_pretrained(
|
||||
caption_model_path,
|
||||
torch_dtype=torch.float16,
|
||||
trust_remote_code=True
|
||||
).to(self.device)
|
||||
elif self.device.type == 'mps':
|
||||
# MPS设备使用float32(MPS对float16支持有限)
|
||||
self.caption_model = AutoModelForCausalLM.from_pretrained(
|
||||
caption_model_path,
|
||||
torch_dtype=torch.float32,
|
||||
trust_remote_code=True
|
||||
).to(self.device)
|
||||
else:
|
||||
# CPU使用float32
|
||||
self.caption_model = AutoModelForCausalLM.from_pretrained(
|
||||
caption_model_path,
|
||||
torch_dtype=torch.float32,
|
||||
trust_remote_code=True
|
||||
).to(self.device)
|
||||
# if self.device.type == 'cuda':
|
||||
# # CUDA设备使用float16
|
||||
# self.caption_model = AutoModelForCausalLM.from_pretrained(
|
||||
# caption_model_path,
|
||||
# torch_dtype=torch.float16,
|
||||
# trust_remote_code=True
|
||||
# ).to(self.device)
|
||||
# elif self.device.type == 'mps':
|
||||
# # MPS设备使用float32(MPS对float16支持有限)
|
||||
# self.caption_model = AutoModelForCausalLM.from_pretrained(
|
||||
# caption_model_path,
|
||||
# torch_dtype=torch.float32,
|
||||
# trust_remote_code=True
|
||||
# ).to(self.device)
|
||||
# else:
|
||||
# # CPU使用float32
|
||||
# self.caption_model = AutoModelForCausalLM.from_pretrained(
|
||||
# caption_model_path,
|
||||
# torch_dtype=torch.float32,
|
||||
# trust_remote_code=True
|
||||
# ).to(self.device)
|
||||
self.caption_model = AutoModelForCausalLM.from_pretrained(
|
||||
caption_model_path,
|
||||
torch_dtype=self.dtype,
|
||||
trust_remote_code=True
|
||||
).to(self.device)
|
||||
|
||||
print("图像描述模型加载成功")
|
||||
except Exception as e:
|
||||
@@ -75,9 +80,9 @@ class VisionAgent:
|
||||
if self.device.type == 'cuda':
|
||||
self.batch_size = 128 # CUDA设备使用较大批处理大小
|
||||
elif self.device.type == 'mps':
|
||||
self.batch_size = 32 # MPS设备使用中等批处理大小
|
||||
self.batch_size = 64 # MPS设备使用中等批处理大小
|
||||
else:
|
||||
self.batch_size = 16 # CPU使用较小批处理大小
|
||||
self.batch_size = 32 # CPU使用较小批处理大小
|
||||
|
||||
self.elements: List[UIElement] = []
|
||||
self.ocr_reader = easyocr.Reader(['en', 'ch_sim'])
|
||||
@@ -88,6 +93,7 @@ class VisionAgent:
|
||||
image = cv2.imread(image_path)
|
||||
if image is None:
|
||||
raise FileNotFoundError(f"Vision agent: 图片读取失败")
|
||||
|
||||
return self.analyze_image(image)
|
||||
|
||||
def _get_optimal_device_and_dtype(self):
|
||||
@@ -133,7 +139,8 @@ class VisionAgent:
|
||||
"""
|
||||
self._reset_state()
|
||||
|
||||
element_crops, boxes = self._detect_objects(image)
|
||||
element_crops, boxes, annotated_image = self._detect_objects(image)
|
||||
cv2.imwrite("annotated_image.jpg", annotated_image)
|
||||
start = time.time()
|
||||
element_texts = self._extract_text(element_crops)
|
||||
end = time.time()
|
||||
@@ -152,7 +159,7 @@ class VisionAgent:
|
||||
)
|
||||
self.elements.append(new_element)
|
||||
|
||||
return self.elements
|
||||
return self.elements, annotated_image
|
||||
|
||||
def _extract_text(self, images: np.ndarray) -> list[str]:
|
||||
"""
|
||||
@@ -246,6 +253,16 @@ class VisionAgent:
|
||||
results = self.yolo_model(image)[0]
|
||||
detections = sv.Detections.from_ultralytics(results)
|
||||
boxes = detections.xyxy
|
||||
box_annotator = sv.BoxAnnotator()
|
||||
label_annotator = sv.LabelAnnotator()
|
||||
|
||||
labels = [
|
||||
f"{idx}" for idx, box in enumerate(detections.xyxy)]
|
||||
|
||||
annotated_image = box_annotator.annotate(
|
||||
scene=image, detections=detections)
|
||||
annotated_image = label_annotator.annotate(
|
||||
scene=annotated_image, detections=detections, labels=labels)
|
||||
|
||||
if len(boxes) == 0:
|
||||
return []
|
||||
@@ -279,7 +296,7 @@ class VisionAgent:
|
||||
element = image[y1:y2, x1:x2]
|
||||
element_crops.append(np.array(element))
|
||||
|
||||
return element_crops, filtered_boxes
|
||||
return element_crops, filtered_boxes, annotated_image
|
||||
|
||||
def load_image(self, image_source: str) -> np.ndarray:
|
||||
try:
|
||||
@@ -304,9 +321,4 @@ class VisionAgent:
|
||||
except (base64.binascii.Error, ValueError) as e:
|
||||
# 生成更清晰的错误信息
|
||||
error_msg = f"输入既不是有效的文件路径,也不是有效的Base64图片数据"
|
||||
raise ValueError(error_msg) from e
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
raise ValueError(error_msg) from e
|
||||
Reference in New Issue
Block a user