mirror of
https://github.com/yuruotong1/autoMate.git
synced 2026-03-22 04:57:18 +08:00
调用 task agent 和 run agent
This commit is contained in:
@@ -1,262 +1,262 @@
|
||||
from typing import List, Optional, Union, Tuple
|
||||
# from typing import List, Optional, Union, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
# import cv2
|
||||
# import numpy as np
|
||||
|
||||
from supervision.detection.core import Detections
|
||||
from supervision.draw.color import Color, ColorPalette
|
||||
# from supervision.detection.core import Detections
|
||||
# from supervision.draw.color import Color, ColorPalette
|
||||
|
||||
|
||||
class BoxAnnotator:
    """
    A class for drawing bounding boxes on an image using detections provided.

    Attributes:
        color (Union[Color, ColorPalette]): The color to draw the bounding box,
            can be a single color or a color palette
        thickness (int): The thickness of the bounding box lines, default is 2
        text_color (Color): The color of the text on the bounding box, default is white
        text_scale (float): The scale of the text on the bounding box, default is 0.5
        text_thickness (int): The thickness of the text on the bounding box,
            default is 1
        text_padding (int): The padding around the text on the bounding box,
            default is 5
        avoid_overlap (bool): When True, try alternative label positions so a
            label does not cover other detection boxes or leave the image
    """

    def __init__(
        self,
        color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
        thickness: int = 3,  # 1 for seeclick 2 for mind2web and 3 for demo
        text_color: Color = Color.BLACK,
        text_scale: float = 0.5,  # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
        text_thickness: int = 2,  # 1, # 2 for demo
        text_padding: int = 10,
        avoid_overlap: bool = True,
    ):
        self.color: Union[Color, ColorPalette] = color
        self.thickness: int = thickness
        self.text_color: Color = text_color
        self.text_scale: float = text_scale
        self.text_thickness: int = text_thickness
        self.text_padding: int = text_padding
        self.avoid_overlap: bool = avoid_overlap

    def annotate(
        self,
        scene: np.ndarray,
        detections: Detections,
        labels: Optional[List[str]] = None,
        skip_label: bool = False,
        image_size: Optional[Tuple[int, int]] = None,
    ) -> np.ndarray:
        """
        Draws bounding boxes on the frame using the detections provided.

        Args:
            scene (np.ndarray): The image on which the bounding boxes will be drawn
            detections (Detections): The detections for which the
                bounding boxes will be drawn
            labels (Optional[List[str]]): An optional list of labels
                corresponding to each detection. If `labels` are not provided
                (or their count does not match the detections),
                corresponding `class_id` will be used as label.
            skip_label (bool): If set to `True`, skips bounding box label annotation.
            image_size (Optional[Tuple[int, int]]): (width, height) of the image;
                used by the overlap-avoiding label placement.

        Returns:
            np.ndarray: The image with the bounding boxes drawn on it

        Example:
            ```python
            import supervision as sv

            classes = ['person', ...]
            image = ...
            detections = sv.Detections(...)

            box_annotator = sv.BoxAnnotator()
            labels = [
                f"{classes[class_id]} {confidence:0.2f}"
                for _, _, confidence, class_id, _ in detections
            ]
            annotated_frame = box_annotator.annotate(
                scene=image.copy(),
                detections=detections,
                labels=labels
            )
            ```
        """
        font = cv2.FONT_HERSHEY_SIMPLEX
        for i in range(len(detections)):
            x1, y1, x2, y2 = detections.xyxy[i].astype(int)
            class_id = (
                detections.class_id[i] if detections.class_id is not None else None
            )
            idx = class_id if class_id is not None else i
            color = (
                self.color.by_idx(idx)
                if isinstance(self.color, ColorPalette)
                else self.color
            )
            cv2.rectangle(
                img=scene,
                pt1=(x1, y1),
                pt2=(x2, y2),
                color=color.as_bgr(),
                thickness=self.thickness,
            )
            if skip_label:
                continue

            # Fall back to the class id when labels are missing or mismatched.
            text = (
                f"{class_id}"
                if (labels is None or len(detections) != len(labels))
                else labels[i]
            )

            text_width, text_height = cv2.getTextSize(
                text=text,
                fontFace=font,
                fontScale=self.text_scale,
                thickness=self.text_thickness,
            )[0]

            if not self.avoid_overlap:
                # Default placement: label box above the top-left corner.
                text_x = x1 + self.text_padding
                text_y = y1 - self.text_padding

                text_background_x1 = x1
                text_background_y1 = y1 - 2 * self.text_padding - text_height

                text_background_x2 = x1 + 2 * self.text_padding + text_width
                text_background_y2 = y1
            else:
                text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 = get_optimal_label_pos(
                    self.text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size
                )

            cv2.rectangle(
                img=scene,
                pt1=(text_background_x1, text_background_y1),
                pt2=(text_background_x2, text_background_y2),
                color=color.as_bgr(),
                thickness=cv2.FILLED,
            )
            # Pick black or white text based on the perceived luminance of the
            # box color (ITU-R BT.601 weights) so the label stays readable.
            # Renamed from `text_color` to avoid shadowing self.text_color.
            box_color = color.as_rgb()
            luminance = 0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
            label_text_color = (0, 0, 0) if luminance > 160 else (255, 255, 255)
            cv2.putText(
                img=scene,
                text=text,
                org=(text_x, text_y),
                fontFace=font,
                fontScale=self.text_scale,
                # color=self.text_color.as_rgb(),
                color=label_text_color,
                thickness=self.text_thickness,
                lineType=cv2.LINE_AA,
            )
        return scene
|
||||
|
||||
|
||||
def box_area(box):
    """Return the area of an axis-aligned box given as (x1, y1, x2, y2)."""
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    return (right - left) * (bottom - top)
|
||||
|
||||
def intersection_area(box1, box2):
    """Return the overlap area of two (x1, y1, x2, y2) boxes; 0 when disjoint."""
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])
    width = max(0, right - left)
    height = max(0, bottom - top)
    return width * height
|
||||
|
||||
def IoU(box1, box2, return_max=True):
    """
    Intersection-over-union of two (x1, y1, x2, y2) boxes.

    Args:
        box1: First box.
        box2: Second box.
        return_max: When True, return the maximum of plain IoU and the two
            intersection/own-area ratios (catches near-containment);
            otherwise return plain IoU.

    Returns:
        float: Overlap score in [0, 1]; 0 when both boxes are degenerate.
    """
    intersection = intersection_area(box1, box2)
    # Hoist the repeated area computations.
    area1 = box_area(box1)
    area2 = box_area(box2)
    union = area1 + area2 - intersection
    # BUGFIX: if both boxes are degenerate (zero area), union is 0 and the
    # original code raised ZeroDivisionError on `intersection / union`.
    if union == 0:
        return 0
    if area1 > 0 and area2 > 0:
        ratio1 = intersection / area1
        ratio2 = intersection / area2
    else:
        ratio1, ratio2 = 0, 0
    if return_max:
        return max(intersection / union, ratio1, ratio2)
    else:
        return intersection / union
|
||||
|
||||
|
||||
def get_optimal_label_pos(text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size):
    """ check overlap of text and background detection box, and get_optimal_label_pos,
    pos: str, position of the text, tried in order: 'top left', 'outer left',
    'outer right', 'top right'. If all are overlapping, return the last one,
    i.e. top right.
    Threshold: default to 0.3

    Args:
        text_padding (int): Padding around the label text.
        text_width (int): Width of the rendered label text.
        text_height (int): Height of the rendered label text.
        x1, y1, x2, y2 (int): Corners of the detection box being labeled.
        detections: Detections object exposing an `xyxy` array of boxes.
        image_size (Tuple[int, int]): (width, height) of the image.

    Returns:
        tuple: (text_x, text_y, bg_x1, bg_y1, bg_x2, bg_y2) — text origin and
        label background rectangle.
    """

    def get_is_overlap(bg_x1, bg_y1, bg_x2, bg_y2):
        # A candidate is rejected if its background overlaps any detection box
        # beyond the 0.3 threshold, or if it falls outside the image.
        for i in range(len(detections)):
            detection = detections.xyxy[i].astype(int)
            if IoU([bg_x1, bg_y1, bg_x2, bg_y2], detection) > 0.3:
                return True
        # check if the text is out of the image
        if bg_x1 < 0 or bg_x2 > image_size[0] or bg_y1 < 0 or bg_y2 > image_size[1]:
            return True
        return False

    # Each candidate: (text_x, text_y, bg_x1, bg_y1, bg_x2, bg_y2).
    # The original implementation duplicated this logic four times; the
    # candidates below are listed in the exact order it tried them.
    candidates = [
        # top left
        (x1 + text_padding, y1 - text_padding,
         x1, y1 - 2 * text_padding - text_height,
         x1 + 2 * text_padding + text_width, y1),
        # outer left
        (x1 - text_padding - text_width, y1 + text_padding + text_height,
         x1 - 2 * text_padding - text_width, y1,
         x1, y1 + 2 * text_padding + text_height),
        # outer right
        (x2 + text_padding, y1 + text_padding + text_height,
         x2, y1,
         x2 + 2 * text_padding + text_width, y1 + 2 * text_padding + text_height),
        # top right
        (x2 - text_padding - text_width, y1 - text_padding,
         x2 - 2 * text_padding - text_width, y1 - 2 * text_padding - text_height,
         x2, y1),
    ]
    for candidate in candidates:
        if not get_is_overlap(*candidate[2:]):
            return candidate
    # All candidates overlap: fall back to the last one (top right),
    # matching the original behavior.
    return candidates[-1]
|
||||
|
||||
@@ -1,31 +1,32 @@
|
||||
from util.utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box
|
||||
import torch
|
||||
from PIL import Image
|
||||
import io
|
||||
import base64
|
||||
from typing import Dict
|
||||
class Omniparser(object):
    """Bundles the SOM detection model and the caption model behind a single
    ``parse`` entry point that annotates a base64-encoded screenshot."""

    def __init__(self, config: Dict):
        # Keep the full config around; parse() reads BOX_TRESHOLD from it.
        self.config = config
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.som_model = get_yolo_model(model_path=config['som_model_path'])
        self.caption_model_processor = get_caption_model_processor(
            model_name=config['caption_model_name'],
            model_name_or_path=config['caption_model_path'],
            device=device,
        )
        print('Server initialized!')

    def parse(self, image_base64: str):
        """Decode a base64 image, run OCR and SOM labeling, and return the
        annotated image plus the parsed content list."""
        decoded = base64.b64decode(image_base64)
        image = Image.open(io.BytesIO(decoded))
        print('image size:', image.size)

        # Scale the drawing parameters with the image size.
        box_overlay_ratio = max(image.size) / 3200
        draw_bbox_config = {
            'text_scale': 0.8 * box_overlay_ratio,
            'text_thickness': max(int(2 * box_overlay_ratio), 1),
            'text_padding': max(int(3 * box_overlay_ratio), 1),
            'thickness': max(int(3 * box_overlay_ratio), 1),
        }

        (text, ocr_bbox), _ = check_ocr_box(
            image,
            display_img=False,
            output_bb_format='xyxy',
            easyocr_args={'text_threshold': 0.8},
            use_paddleocr=False,
        )
        dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
            image,
            self.som_model,
            BOX_TRESHOLD=self.config['BOX_TRESHOLD'],
            output_coord_in_ratio=True,
            ocr_bbox=ocr_bbox,
            draw_bbox_config=draw_bbox_config,
            caption_model_processor=self.caption_model_processor,
            ocr_text=text,
            use_local_semantics=True,
            iou_threshold=0.7,
            scale_img=False,
            batch_size=128,
        )

        return dino_labled_img, parsed_content_list
|
||||
|
||||
|
||||
948
util/utils.py
948
util/utils.py
File diff suppressed because it is too large
Load Diff
214
util/vision_agent.py
Normal file
214
util/vision_agent.py
Normal file
@@ -0,0 +1,214 @@
|
||||
from typing import List, Optional
|
||||
import cv2
|
||||
import torch
|
||||
from ultralytics import YOLO
|
||||
from transformers import AutoModelForCausalLM, AutoProcessor
|
||||
import easyocr
|
||||
import supervision as sv
|
||||
import numpy as np
|
||||
import time
|
||||
from pydantic import BaseModel
|
||||
import base64
|
||||
|
||||
class UIElement(BaseModel):
    """A single detected UI element with its location and annotations."""
    # Index of the element within the analyzed image.
    element_id: int
    # Bounding box coordinates — populated from detection `xyxy` boxes,
    # so presumably [x1, y1, x2, y2] in pixels; TODO confirm.
    coordinates: list[float]
    # Model-generated caption for the element crop, if captioning ran.
    caption: Optional[str] = None
    # First OCR text fragment found in the element, '' when none.
    text: Optional[str] = None
|
||||
|
||||
class VisionAgent:
    def __init__(self, yolo_model_path: str, caption_model_path: str = 'microsoft/Florence-2-base-ft'):
        """
        Computer vision agent for UI analysis.

        Args:
            yolo_model_path: Path to YOLO model weights
            caption_model_path: Name/path to captioning model (defaults to Florence-2)
        """
        self.device = self._get_available_device()
        self.dtype = self._get_dtype()
        self.elements: List[UIElement] = []

        self.yolo_model = YOLO(yolo_model_path)
        self.caption_model = AutoModelForCausalLM.from_pretrained(
            caption_model_path, trust_remote_code=True
        ).to(self.device)
        # NOTE(review): the processor is always loaded from the base
        # Florence-2 repo even when caption_model_path differs — confirm
        # this is intentional.
        self.caption_processor = AutoProcessor.from_pretrained(
            "microsoft/Florence-2-base", trust_remote_code=True
        )
        self.ocr_reader = easyocr.Reader(['en', 'ch_sim'])

    def _get_available_device(self) -> str:
        """Pick the best available torch device: cuda > mps > cpu."""
        if torch.cuda.is_available():
            return 'cuda'
        if torch.backends.mps.is_available():
            return 'mps'
        return 'cpu'

    def _get_dtype(self) -> torch.dtype:
        """Use float16 on CUDA, float32 everywhere else."""
        if torch.cuda.is_available():
            return torch.float16
        return torch.float32

    def _reset_state(self):
        """Clear previous analysis results"""
        self.elements = []

    def analyze_image(self, image: np.ndarray) -> List[UIElement]:
        """
        Process an image through all computer vision pipelines.

        Args:
            image: Input image in BGR format (OpenCV default)

        Returns:
            List of detected UI elements with annotations
        """
        self._reset_state()

        element_crops, boxes = self._detect_objects(image)

        start = time.time()
        element_texts = self._extract_text(element_crops)
        ocr_time = (time.time() - start) * 10 ** 3
        print(f"Speed: {ocr_time:.2f} ms OCR of {len(element_texts)} icons.")

        start = time.time()
        element_captions = self._get_caption(element_crops)
        caption_time = (time.time() - start) * 10 ** 3
        print(f"Speed: {caption_time:.2f} ms captioning of {len(element_captions)} icons.")

        for idx in range(len(element_crops)):
            print(idx, boxes[idx], element_texts[idx], element_captions[idx])
            new_element = UIElement(
                element_id=idx,
                coordinates=boxes[idx],
                text=element_texts[idx][0] if len(element_texts[idx]) > 0 else '',
                caption=element_captions[idx],
            )
            self.elements.append(new_element)

        return self.elements

    def _extract_text(self, images: list[np.ndarray]) -> list[list[str]]:
        """
        Run OCR in sequential mode.

        Returns one list of text fragments per input crop (may be empty).
        TODO: It is possible to run in batch mode for a speed up, but the
        result quality needs test.
        https://github.com/JaidedAI/EasyOCR/pull/458
        """
        texts = []
        for image in images:
            text = self.ocr_reader.readtext(image, detail=0, paragraph=True, text_threshold=0.85)
            texts.append(text)
        return texts

    def _get_caption(self, images: list[np.ndarray], batch_size: int = 1) -> list[str]:
        """Run captioning in batched mode. TODO: adjust batch size"""
        prompt = "<CAPTION>"
        generated_texts = []
        # Downscale crops so the processor sees uniform, small inputs.
        resized_images = [cv2.resize(image, (64, 64)) for image in images]

        for i in range(0, len(resized_images), batch_size):
            batch_images = resized_images[i:i + batch_size]
            inputs = self.caption_processor(
                images=batch_images,
                text=[prompt] * len(batch_images),
                return_tensors="pt",
                do_resize=True,
            ).to(device=self.device, dtype=self.dtype)

            generated_ids = self.caption_model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=10,
                num_beams=1,
                do_sample=False,
                early_stopping=False,
            )

            generated_text = self.caption_processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )
            generated_texts.extend(gen.strip() for gen in generated_text)

        return generated_texts

    def _detect_objects(self, image: np.ndarray) -> tuple[list[np.ndarray], list]:
        """Run object detection; return (element crops, kept xyxy boxes)."""
        results = self.yolo_model(image)[0]
        detections = sv.Detections.from_ultralytics(results)
        boxes = detections.xyxy

        if len(boxes) == 0:
            # BUGFIX: the original returned a bare [], which broke the
            # two-value unpack in analyze_image. Return an empty pair.
            return [], []

        # Filter out boxes fully contained by a larger kept box.
        areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        sorted_indices = np.argsort(-areas)  # Sort descending by area
        sorted_boxes = boxes[sorted_indices]

        keep_sorted = []
        for i in range(len(sorted_boxes)):
            box_a = sorted_boxes[i]
            contained = False
            for j in keep_sorted:
                box_b = sorted_boxes[j]
                if (box_b[0] <= box_a[0] and box_b[1] <= box_a[1] and
                        box_b[2] >= box_a[2] and box_b[3] >= box_a[3]):
                    contained = True
                    break
            if not contained:
                keep_sorted.append(i)

        # Map back to original indices
        keep_indices = sorted_indices[keep_sorted]
        filtered_boxes = boxes[keep_indices]

        # Extract element crops
        element_crops = []
        for box in filtered_boxes:
            x1, y1, x2, y2 = map(int, map(round, box))
            element_crops.append(np.array(image[y1:y2, x1:x2]))

        return element_crops, filtered_boxes

    def load_image(self, image_source: str) -> np.ndarray:
        """
        Decode a base64 image string (optionally with a data-URL prefix)
        into a BGR numpy image.

        Raises:
            ValueError: when the payload is not valid base64 image data.
        """
        try:
            # Strip a possible Data URL prefix (e.g. "data:image/png;base64,").
            if ',' in image_source:
                _, payload = image_source.split(',', 1)
            else:
                payload = image_source

            # base64 decode -> bytes -> numpy buffer
            image_bytes = base64.b64decode(payload)
            np_array = np.frombuffer(image_bytes, dtype=np.uint8)

            # Decode the image with OpenCV.
            image = cv2.imdecode(np_array, cv2.IMREAD_COLOR)

            if image is None:
                raise ValueError("解码图片失败:无效的图片数据")

            # BUGFIX: return the decoded image as the annotation promises.
            # The original returned analyze_image(image), and __call__ then
            # ran analyze_image again on that list, which would crash.
            return image

        except (base64.binascii.Error, ValueError) as e:
            # Removed the needless f-prefix: the message has no placeholders.
            error_msg = "输入既不是有效的文件路径,也不是有效的Base64图片数据"
            raise ValueError(error_msg) from e

    def __call__(self, image_source: str) -> List[UIElement]:
        """Analyze an image supplied as a base64-encoded string."""
        image = self.load_image(image_source)
        if image is None:
            raise FileNotFoundError("Vision agent: 图片读取失败")

        return self.analyze_image(image)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user