调task agent 和 run agent

This commit is contained in:
Dan Li
2025-03-11 13:35:09 +03:00
parent 3ca1bd6cca
commit 49cf1dfb6f
10 changed files with 972 additions and 745 deletions

View File

@@ -19,26 +19,31 @@ class OmniParserClient:
response_json = response.json()
print('omniparser latency:', response_json['latency'])
som_image_data = base64.b64decode(response_json['som_image_base64'])
screenshot_path_uuid = Path(screenshot_path).stem.replace("screenshot_", "")
som_screenshot_path = f"{OUTPUT_DIR}/screenshot_som_{screenshot_path_uuid}.png"
with open(som_screenshot_path, "wb") as f:
f.write(som_image_data)
# som_image_data = base64.b64decode(response_json['som_image_base64'])
# screenshot_path_uuid = Path(screenshot_path).stem.replace("screenshot_", "")
# som_screenshot_path = f"{OUTPUT_DIR}/screenshot_som_{screenshot_path_uuid}.png"
# with open(som_screenshot_path, "wb") as f:
# f.write(som_image_data)
response_json['width'] = screenshot.size[0]
response_json['height'] = screenshot.size[1]
response_json['original_screenshot_base64'] = image_base64
response_json['screenshot_uuid'] = screenshot_path_uuid
# response_json['screenshot_uuid'] = screenshot_path_uuid
response_json = self.reformat_messages(response_json)
return response_json
def reformat_messages(self, response_json: dict):
screen_info = ""
for idx, element in enumerate(response_json["parsed_content_list"]):
element['idx'] = idx
if element['type'] == 'text':
screen_info += f'ID: {idx}, Text: {element["content"]}\n'
elif element['type'] == 'icon':
screen_info += f'ID: {idx}, Icon: {element["content"]}\n'
# element['idx'] = idx
# if element['type'] == 'text':
# screen_info += f'ID: {idx}, Text: {element["content"]}\n'
# elif element['type'] == 'icon':
# screen_info += f'ID: {idx}, Icon: {element["content"]}\n'
screen_info += f'ID: {element.element_id}, '
screen_info += f'Coordinates: {element.coordinates}, '
screen_info += f'Text: {element.text if len(element.text) else " "}, '
screen_info += f'Caption: {element.caption}. '
screen_info += "\n"
response_json['screen_info'] = screen_info
return response_json

View File

@@ -1,8 +1,8 @@
from gradio_ui.agent.base_agent import BaseAgent
class TaskPlanAgent(BaseAgent):
def __init__(self, config):
super().__init__(config)
def __init__(self):
super().__init__()
self.SYSTEM_PROMPT = system_prompt

View File

@@ -6,12 +6,15 @@ from xbrain.core.chat import run
import platform
import re
class TaskRunAgent(BaseAgent):
def __init__(self,task_plan: str, screen_info):
def __init__(self):
print("TaskRunAgent initialized without a task")
def __call__(self,task_plan: str, screen_info):
self.OUTPUT_DIR = "./tmp/outputs"
device = self.get_device()
self.SYSTEM_PROMPT = system_prompt.format(task_plan=task_plan,
device=device,
screen_info=screen_info)
print(self.SYSTEM_PROMPT)
def get_device(self):
# 获取当前操作系统信息

View File

@@ -29,7 +29,7 @@ def sampling_loop_sync(
tool_output_callback: Callable[[ToolResult, str], None],
api_response_callback: Callable[[APIResponse[BetaMessage]], None],
api_key: str,
only_n_most_recent_images: int | None = 2,
only_n_most_recent_images: int | None = 0,
max_tokens: int = 4096,
omniparser_url: str,
base_url: str
@@ -40,7 +40,6 @@ def sampling_loop_sync(
print('in sampling_loop_sync, model:', model)
omniparser_client = OmniParserClient(url=f"http://{omniparser_url}/parse/")
task_plan_agent = TaskPlanAgent()
task_plan_agent()
# actor = VLMAgent(
# model=model,
# api_key=api_key,
@@ -58,7 +57,7 @@ def sampling_loop_sync(
tool_result_content = None
print(f"Start the message loop. User messages: {messages}")
plan = task_plan_agent(messages[-1]["content"][0])
plan = task_plan_agent(user_task = messages[-1]["content"][0])
task_run_agent = TaskRunAgent()
while True:

View File

@@ -11,7 +11,8 @@ import argparse
import uvicorn
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(root_dir)
from util.omniparser import Omniparser
# from util.omniparser import Omniparser
from util.vision_agent import VisionAgent
def parse_arguments():
parser = argparse.ArgumentParser(description='Omniparser API')
@@ -29,8 +30,10 @@ args = parse_arguments()
config = vars(args)
app = FastAPI()
omniparser = Omniparser(config)
# omniparser = Omniparser(config)
yolo_model_path = config['som_model_path']
caption_model_path = config['caption_model_path']
vision_agent = VisionAgent(yolo_model_path=yolo_model_path, caption_model_path=caption_model_path)
class ParseRequest(BaseModel):
base64_image: str
@@ -38,10 +41,12 @@ class ParseRequest(BaseModel):
async def parse(parse_request: ParseRequest):
print('start parsing...')
start = time.time()
dino_labled_img, parsed_content_list = omniparser.parse(parse_request.base64_image)
# dino_labled_img, parsed_content_list = omniparser.parse(parse_request.base64_image)
parsed_content_list = vision_agent(parse_request.base64_image)
latency = time.time() - start
print('time:', latency)
return {"som_image_base64": dino_labled_img, "parsed_content_list": parsed_content_list, 'latency': latency}
return {"parsed_content_list": parsed_content_list, 'latency': latency}
@app.get("/probe/")
async def root():

View File

@@ -2,7 +2,7 @@ torch
easyocr
torchvision
supervision==0.18.0
openai==1.3.5
openai>=1.3.5
transformers
ultralytics==8.3.70
azure-identity

View File

@@ -1,262 +1,262 @@
from typing import List, Optional, Union, Tuple
# from typing import List, Optional, Union, Tuple
import cv2
import numpy as np
# import cv2
# import numpy as np
from supervision.detection.core import Detections
from supervision.draw.color import Color, ColorPalette
# from supervision.detection.core import Detections
# from supervision.draw.color import Color, ColorPalette
class BoxAnnotator:
"""
A class for drawing bounding boxes on an image using detections provided.
# class BoxAnnotator:
# """
# A class for drawing bounding boxes on an image using detections provided.
Attributes:
color (Union[Color, ColorPalette]): The color to draw the bounding box,
can be a single color or a color palette
thickness (int): The thickness of the bounding box lines, default is 2
text_color (Color): The color of the text on the bounding box, default is white
text_scale (float): The scale of the text on the bounding box, default is 0.5
text_thickness (int): The thickness of the text on the bounding box,
default is 1
text_padding (int): The padding around the text on the bounding box,
default is 5
# Attributes:
# color (Union[Color, ColorPalette]): The color to draw the bounding box,
# can be a single color or a color palette
# thickness (int): The thickness of the bounding box lines, default is 2
# text_color (Color): The color of the text on the bounding box, default is white
# text_scale (float): The scale of the text on the bounding box, default is 0.5
# text_thickness (int): The thickness of the text on the bounding box,
# default is 1
# text_padding (int): The padding around the text on the bounding box,
# default is 5
"""
# """
def __init__(
self,
color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
thickness: int = 3, # 1 for seeclick 2 for mind2web and 3 for demo
text_color: Color = Color.BLACK,
text_scale: float = 0.5, # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
text_thickness: int = 2, #1, # 2 for demo
text_padding: int = 10,
avoid_overlap: bool = True,
):
self.color: Union[Color, ColorPalette] = color
self.thickness: int = thickness
self.text_color: Color = text_color
self.text_scale: float = text_scale
self.text_thickness: int = text_thickness
self.text_padding: int = text_padding
self.avoid_overlap: bool = avoid_overlap
# def __init__(
# self,
# color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
# thickness: int = 3, # 1 for seeclick 2 for mind2web and 3 for demo
# text_color: Color = Color.BLACK,
# text_scale: float = 0.5, # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
# text_thickness: int = 2, #1, # 2 for demo
# text_padding: int = 10,
# avoid_overlap: bool = True,
# ):
# self.color: Union[Color, ColorPalette] = color
# self.thickness: int = thickness
# self.text_color: Color = text_color
# self.text_scale: float = text_scale
# self.text_thickness: int = text_thickness
# self.text_padding: int = text_padding
# self.avoid_overlap: bool = avoid_overlap
def annotate(
self,
scene: np.ndarray,
detections: Detections,
labels: Optional[List[str]] = None,
skip_label: bool = False,
image_size: Optional[Tuple[int, int]] = None,
) -> np.ndarray:
"""
Draws bounding boxes on the frame using the detections provided.
# def annotate(
# self,
# scene: np.ndarray,
# detections: Detections,
# labels: Optional[List[str]] = None,
# skip_label: bool = False,
# image_size: Optional[Tuple[int, int]] = None,
# ) -> np.ndarray:
# """
# Draws bounding boxes on the frame using the detections provided.
Args:
scene (np.ndarray): The image on which the bounding boxes will be drawn
detections (Detections): The detections for which the
bounding boxes will be drawn
labels (Optional[List[str]]): An optional list of labels
corresponding to each detection. If `labels` are not provided,
corresponding `class_id` will be used as label.
skip_label (bool): Is set to `True`, skips bounding box label annotation.
Returns:
np.ndarray: The image with the bounding boxes drawn on it
# Args:
# scene (np.ndarray): The image on which the bounding boxes will be drawn
# detections (Detections): The detections for which the
# bounding boxes will be drawn
# labels (Optional[List[str]]): An optional list of labels
# corresponding to each detection. If `labels` are not provided,
# corresponding `class_id` will be used as label.
# skip_label (bool): Is set to `True`, skips bounding box label annotation.
# Returns:
# np.ndarray: The image with the bounding boxes drawn on it
Example:
```python
import supervision as sv
# Example:
# ```python
# import supervision as sv
classes = ['person', ...]
image = ...
detections = sv.Detections(...)
# classes = ['person', ...]
# image = ...
# detections = sv.Detections(...)
box_annotator = sv.BoxAnnotator()
labels = [
f"{classes[class_id]} {confidence:0.2f}"
for _, _, confidence, class_id, _ in detections
]
annotated_frame = box_annotator.annotate(
scene=image.copy(),
detections=detections,
labels=labels
)
```
"""
font = cv2.FONT_HERSHEY_SIMPLEX
for i in range(len(detections)):
x1, y1, x2, y2 = detections.xyxy[i].astype(int)
class_id = (
detections.class_id[i] if detections.class_id is not None else None
)
idx = class_id if class_id is not None else i
color = (
self.color.by_idx(idx)
if isinstance(self.color, ColorPalette)
else self.color
)
cv2.rectangle(
img=scene,
pt1=(x1, y1),
pt2=(x2, y2),
color=color.as_bgr(),
thickness=self.thickness,
)
if skip_label:
continue
# box_annotator = sv.BoxAnnotator()
# labels = [
# f"{classes[class_id]} {confidence:0.2f}"
# for _, _, confidence, class_id, _ in detections
# ]
# annotated_frame = box_annotator.annotate(
# scene=image.copy(),
# detections=detections,
# labels=labels
# )
# ```
# """
# font = cv2.FONT_HERSHEY_SIMPLEX
# for i in range(len(detections)):
# x1, y1, x2, y2 = detections.xyxy[i].astype(int)
# class_id = (
# detections.class_id[i] if detections.class_id is not None else None
# )
# idx = class_id if class_id is not None else i
# color = (
# self.color.by_idx(idx)
# if isinstance(self.color, ColorPalette)
# else self.color
# )
# cv2.rectangle(
# img=scene,
# pt1=(x1, y1),
# pt2=(x2, y2),
# color=color.as_bgr(),
# thickness=self.thickness,
# )
# if skip_label:
# continue
text = (
f"{class_id}"
if (labels is None or len(detections) != len(labels))
else labels[i]
)
# text = (
# f"{class_id}"
# if (labels is None or len(detections) != len(labels))
# else labels[i]
# )
text_width, text_height = cv2.getTextSize(
text=text,
fontFace=font,
fontScale=self.text_scale,
thickness=self.text_thickness,
)[0]
# text_width, text_height = cv2.getTextSize(
# text=text,
# fontFace=font,
# fontScale=self.text_scale,
# thickness=self.text_thickness,
# )[0]
if not self.avoid_overlap:
text_x = x1 + self.text_padding
text_y = y1 - self.text_padding
# if not self.avoid_overlap:
# text_x = x1 + self.text_padding
# text_y = y1 - self.text_padding
text_background_x1 = x1
text_background_y1 = y1 - 2 * self.text_padding - text_height
# text_background_x1 = x1
# text_background_y1 = y1 - 2 * self.text_padding - text_height
text_background_x2 = x1 + 2 * self.text_padding + text_width
text_background_y2 = y1
# text_x = x1 - self.text_padding - text_width
# text_y = y1 + self.text_padding + text_height
# text_background_x1 = x1 - 2 * self.text_padding - text_width
# text_background_y1 = y1
# text_background_x2 = x1
# text_background_y2 = y1 + 2 * self.text_padding + text_height
else:
text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 = get_optimal_label_pos(self.text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size)
# text_background_x2 = x1 + 2 * self.text_padding + text_width
# text_background_y2 = y1
# # text_x = x1 - self.text_padding - text_width
# # text_y = y1 + self.text_padding + text_height
# # text_background_x1 = x1 - 2 * self.text_padding - text_width
# # text_background_y1 = y1
# # text_background_x2 = x1
# # text_background_y2 = y1 + 2 * self.text_padding + text_height
# else:
# text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 = get_optimal_label_pos(self.text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size)
cv2.rectangle(
img=scene,
pt1=(text_background_x1, text_background_y1),
pt2=(text_background_x2, text_background_y2),
color=color.as_bgr(),
thickness=cv2.FILLED,
)
# import pdb; pdb.set_trace()
box_color = color.as_rgb()
luminance = 0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
text_color = (0,0,0) if luminance > 160 else (255,255,255)
cv2.putText(
img=scene,
text=text,
org=(text_x, text_y),
fontFace=font,
fontScale=self.text_scale,
# color=self.text_color.as_rgb(),
color=text_color,
thickness=self.text_thickness,
lineType=cv2.LINE_AA,
)
return scene
# cv2.rectangle(
# img=scene,
# pt1=(text_background_x1, text_background_y1),
# pt2=(text_background_x2, text_background_y2),
# color=color.as_bgr(),
# thickness=cv2.FILLED,
# )
# # import pdb; pdb.set_trace()
# box_color = color.as_rgb()
# luminance = 0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
# text_color = (0,0,0) if luminance > 160 else (255,255,255)
# cv2.putText(
# img=scene,
# text=text,
# org=(text_x, text_y),
# fontFace=font,
# fontScale=self.text_scale,
# # color=self.text_color.as_rgb(),
# color=text_color,
# thickness=self.text_thickness,
# lineType=cv2.LINE_AA,
# )
# return scene
def box_area(box):
return (box[2] - box[0]) * (box[3] - box[1])
# def box_area(box):
# return (box[2] - box[0]) * (box[3] - box[1])
def intersection_area(box1, box2):
x1 = max(box1[0], box2[0])
y1 = max(box1[1], box2[1])
x2 = min(box1[2], box2[2])
y2 = min(box1[3], box2[3])
return max(0, x2 - x1) * max(0, y2 - y1)
# def intersection_area(box1, box2):
# x1 = max(box1[0], box2[0])
# y1 = max(box1[1], box2[1])
# x2 = min(box1[2], box2[2])
# y2 = min(box1[3], box2[3])
# return max(0, x2 - x1) * max(0, y2 - y1)
def IoU(box1, box2, return_max=True):
intersection = intersection_area(box1, box2)
union = box_area(box1) + box_area(box2) - intersection
if box_area(box1) > 0 and box_area(box2) > 0:
ratio1 = intersection / box_area(box1)
ratio2 = intersection / box_area(box2)
else:
ratio1, ratio2 = 0, 0
if return_max:
return max(intersection / union, ratio1, ratio2)
else:
return intersection / union
# def IoU(box1, box2, return_max=True):
# intersection = intersection_area(box1, box2)
# union = box_area(box1) + box_area(box2) - intersection
# if box_area(box1) > 0 and box_area(box2) > 0:
# ratio1 = intersection / box_area(box1)
# ratio2 = intersection / box_area(box2)
# else:
# ratio1, ratio2 = 0, 0
# if return_max:
# return max(intersection / union, ratio1, ratio2)
# else:
# return intersection / union
def get_optimal_label_pos(text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size):
""" check overlap of text and background detection box, and get_optimal_label_pos,
pos: str, position of the text, must be one of 'top left', 'top right', 'outer left', 'outer right' TODO: if all are overlapping, return the last one, i.e. outer right
Threshold: default to 0.3
"""
# def get_optimal_label_pos(text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size):
# """ check overlap of text and background detection box, and get_optimal_label_pos,
# pos: str, position of the text, must be one of 'top left', 'top right', 'outer left', 'outer right' TODO: if all are overlapping, return the last one, i.e. outer right
# Threshold: default to 0.3
# """
def get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size):
is_overlap = False
for i in range(len(detections)):
detection = detections.xyxy[i].astype(int)
if IoU([text_background_x1, text_background_y1, text_background_x2, text_background_y2], detection) > 0.3:
is_overlap = True
break
# check if the text is out of the image
if text_background_x1 < 0 or text_background_x2 > image_size[0] or text_background_y1 < 0 or text_background_y2 > image_size[1]:
is_overlap = True
return is_overlap
# def get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size):
# is_overlap = False
# for i in range(len(detections)):
# detection = detections.xyxy[i].astype(int)
# if IoU([text_background_x1, text_background_y1, text_background_x2, text_background_y2], detection) > 0.3:
# is_overlap = True
# break
# # check if the text is out of the image
# if text_background_x1 < 0 or text_background_x2 > image_size[0] or text_background_y1 < 0 or text_background_y2 > image_size[1]:
# is_overlap = True
# return is_overlap
# if pos == 'top left':
text_x = x1 + text_padding
text_y = y1 - text_padding
# # if pos == 'top left':
# text_x = x1 + text_padding
# text_y = y1 - text_padding
text_background_x1 = x1
text_background_y1 = y1 - 2 * text_padding - text_height
# text_background_x1 = x1
# text_background_y1 = y1 - 2 * text_padding - text_height
text_background_x2 = x1 + 2 * text_padding + text_width
text_background_y2 = y1
is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
if not is_overlap:
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
# text_background_x2 = x1 + 2 * text_padding + text_width
# text_background_y2 = y1
# is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
# if not is_overlap:
# return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
# elif pos == 'outer left':
text_x = x1 - text_padding - text_width
text_y = y1 + text_padding + text_height
# # elif pos == 'outer left':
# text_x = x1 - text_padding - text_width
# text_y = y1 + text_padding + text_height
text_background_x1 = x1 - 2 * text_padding - text_width
text_background_y1 = y1
# text_background_x1 = x1 - 2 * text_padding - text_width
# text_background_y1 = y1
text_background_x2 = x1
text_background_y2 = y1 + 2 * text_padding + text_height
is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
if not is_overlap:
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
# text_background_x2 = x1
# text_background_y2 = y1 + 2 * text_padding + text_height
# is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
# if not is_overlap:
# return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
# elif pos == 'outer right':
text_x = x2 + text_padding
text_y = y1 + text_padding + text_height
# # elif pos == 'outer right':
# text_x = x2 + text_padding
# text_y = y1 + text_padding + text_height
text_background_x1 = x2
text_background_y1 = y1
# text_background_x1 = x2
# text_background_y1 = y1
text_background_x2 = x2 + 2 * text_padding + text_width
text_background_y2 = y1 + 2 * text_padding + text_height
# text_background_x2 = x2 + 2 * text_padding + text_width
# text_background_y2 = y1 + 2 * text_padding + text_height
is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
if not is_overlap:
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
# is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
# if not is_overlap:
# return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
# elif pos == 'top right':
text_x = x2 - text_padding - text_width
text_y = y1 - text_padding
# # elif pos == 'top right':
# text_x = x2 - text_padding - text_width
# text_y = y1 - text_padding
text_background_x1 = x2 - 2 * text_padding - text_width
text_background_y1 = y1 - 2 * text_padding - text_height
# text_background_x1 = x2 - 2 * text_padding - text_width
# text_background_y1 = y1 - 2 * text_padding - text_height
text_background_x2 = x2
text_background_y2 = y1
# text_background_x2 = x2
# text_background_y2 = y1
is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
if not is_overlap:
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
# is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
# if not is_overlap:
# return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
# return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2

View File

@@ -1,31 +1,32 @@
from util.utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box
import torch
from PIL import Image
import io
import base64
from typing import Dict
class Omniparser(object):
def __init__(self, config: Dict):
self.config = config
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.som_model = get_yolo_model(model_path=config['som_model_path'])
self.caption_model_processor = get_caption_model_processor(model_name=config['caption_model_name'], model_name_or_path=config['caption_model_path'], device=device)
print('Server initialized!')
# from util.utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box
# import torch
# from PIL import Image
# import io
# import base64
# from typing import Dict
# class Omniparser(object):
# def __init__(self, config: Dict):
# self.config = config
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# self.som_model = get_yolo_model(model_path=config['som_model_path'])
# self.caption_model_processor = get_caption_model_processor(model_name=config['caption_model_name'], model_name_or_path=config['caption_model_path'], device=device)
# print('Server initialized!')
def parse(self, image_base64: str):
image_bytes = base64.b64decode(image_base64)
image = Image.open(io.BytesIO(image_bytes))
print('image size:', image.size)
# def parse(self, image_base64: str):
# image_bytes = base64.b64decode(image_base64)
# image = Image.open(io.BytesIO(image_bytes))
# print('image size:', image.size)
box_overlay_ratio = max(image.size) / 3200
draw_bbox_config = {
'text_scale': 0.8 * box_overlay_ratio,
'text_thickness': max(int(2 * box_overlay_ratio), 1),
'text_padding': max(int(3 * box_overlay_ratio), 1),
'thickness': max(int(3 * box_overlay_ratio), 1),
}
# box_overlay_ratio = max(image.size) / 3200
# draw_bbox_config = {
# 'text_scale': 0.8 * box_overlay_ratio,
# 'text_thickness': max(int(2 * box_overlay_ratio), 1),
# 'text_padding': max(int(3 * box_overlay_ratio), 1),
# 'thickness': max(int(3 * box_overlay_ratio), 1),
# }
(text, ocr_bbox), _ = check_ocr_box(image, display_img=False, output_bb_format='xyxy', easyocr_args={'text_threshold': 0.8}, use_paddleocr=False)
dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image, self.som_model, BOX_TRESHOLD = self.config['BOX_TRESHOLD'], output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text,use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128)
# (text, ocr_bbox), _ = check_ocr_box(image, display_img=False, output_bb_format='xyxy', easyocr_args={'text_threshold': 0.8}, use_paddleocr=False)
# dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image, self.som_model, BOX_TRESHOLD = self.config['BOX_TRESHOLD'], output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text,use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128)
return dino_labled_img, parsed_content_list
# return dino_labled_img, parsed_content_list

File diff suppressed because it is too large Load Diff

214
util/vision_agent.py Normal file
View File

@@ -0,0 +1,214 @@
import base64
import binascii
import time
from typing import List, Optional

import cv2
import easyocr
import numpy as np
import supervision as sv
import torch
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoProcessor
from ultralytics import YOLO
class UIElement(BaseModel):
element_id: int
coordinates: list[float]
caption: Optional[str] = None
text: Optional[str] = None
class VisionAgent:
    """Detect, OCR and caption UI elements in a screenshot.

    Pipeline: YOLO detection -> containment filtering of boxes -> EasyOCR
    text extraction -> Florence-2 captioning. The result is a list of
    ``UIElement`` records, one per surviving detection box.
    """

    def __init__(self, yolo_model_path: str, caption_model_path: str = 'microsoft/Florence-2-base-ft'):
        """
        Computer vision agent for UI analysis.

        Args:
            yolo_model_path: Path to YOLO model weights.
            caption_model_path: Name/path of the captioning model
                (defaults to Florence-2).
        """
        self.device = self._get_available_device()
        self.dtype = self._get_dtype()
        self.elements: List[UIElement] = []
        self.yolo_model = YOLO(yolo_model_path)
        self.caption_model = AutoModelForCausalLM.from_pretrained(
            caption_model_path, trust_remote_code=True
        ).to(self.device)
        # NOTE(review): the processor is always loaded from the base
        # "microsoft/Florence-2-base" repo even when a fine-tuned
        # caption_model_path is supplied -- confirm this is intentional.
        self.caption_processor = AutoProcessor.from_pretrained(
            "microsoft/Florence-2-base", trust_remote_code=True
        )
        self.ocr_reader = easyocr.Reader(['en', 'ch_sim'])

    def _get_available_device(self) -> str:
        """Pick the best available torch device: cuda > mps > cpu."""
        if torch.cuda.is_available():
            return 'cuda'
        if torch.backends.mps.is_available():
            return 'mps'
        return 'cpu'

    def _get_dtype(self) -> torch.dtype:
        """Use float16 on CUDA, float32 everywhere else (including MPS)."""
        if torch.cuda.is_available():
            return torch.float16
        return torch.float32

    def _reset_state(self):
        """Clear previous analysis results."""
        self.elements = []

    def analyze_image(self, image: np.ndarray) -> List[UIElement]:
        """
        Process an image through all computer vision pipelines.

        Args:
            image: Input image in BGR format (OpenCV default).

        Returns:
            List of detected UI elements with text and captions.
        """
        self._reset_state()
        element_crops, boxes = self._detect_objects(image)
        start = time.time()
        element_texts = self._extract_text(element_crops)
        ocr_time = (time.time() - start) * 10 ** 3
        print(f"Speed: {ocr_time:.2f} ms OCR of {len(element_texts)} icons.")
        start = time.time()
        element_captions = self._get_caption(element_crops)
        caption_time = (time.time() - start) * 10 ** 3
        print(f"Speed: {caption_time:.2f} ms captioning of {len(element_captions)} icons.")
        for idx in range(len(element_crops)):
            print(idx, boxes[idx], element_texts[idx], element_captions[idx])
            new_element = UIElement(
                element_id=idx,
                # Convert the numpy box row to a plain list[float] so it
                # validates cleanly against the pydantic field declaration.
                coordinates=[float(v) for v in boxes[idx]],
                # OCR yields a list of paragraph strings; keep the first one
                # (empty string when nothing was read).
                text=element_texts[idx][0] if len(element_texts[idx]) > 0 else '',
                caption=element_captions[idx],
            )
            self.elements.append(new_element)
        return self.elements

    def _extract_text(self, images: List[np.ndarray]) -> list:
        """
        Run OCR sequentially over the cropped elements.

        TODO: batch mode may give a speed-up, but result quality needs testing:
        https://github.com/JaidedAI/EasyOCR/pull/458

        Returns:
            One list of recognized paragraph strings per input crop.
        """
        texts = []
        for image in images:
            text = self.ocr_reader.readtext(image, detail=0, paragraph=True, text_threshold=0.85)
            texts.append(text)
        return texts

    def _get_caption(self, images: List[np.ndarray], batch_size: int = 1) -> list[str]:
        """Run captioning in batched mode. TODO: tune batch size."""
        prompt = "<CAPTION>"
        generated_texts = []
        # Normalize every crop to 64x64 so each batch tensor is uniform.
        resized_images = [cv2.resize(image, (64, 64)) for image in images]
        for i in range(0, len(resized_images), batch_size):
            batch_images = resized_images[i:i + batch_size]
            inputs = self.caption_processor(
                images=batch_images,
                text=[prompt] * len(batch_images),
                return_tensors="pt",
                do_resize=True,
            ).to(device=self.device, dtype=self.dtype)
            generated_ids = self.caption_model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=10,
                num_beams=1,
                do_sample=False,
                early_stopping=False,
            )
            generated_text = self.caption_processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )
            generated_texts.extend([gen.strip() for gen in generated_text])
        return generated_texts

    def _detect_objects(self, image: np.ndarray) -> tuple:
        """
        Run YOLO detection and drop boxes fully contained by larger boxes.

        Returns:
            (element_crops, filtered_boxes): cropped BGR patches and their
            xyxy coordinate rows, index-aligned.
        """
        results = self.yolo_model(image)[0]
        detections = sv.Detections.from_ultralytics(results)
        boxes = detections.xyxy
        if len(boxes) == 0:
            # BUGFIX: callers unpack two values; previously this returned a
            # bare [] which raised ValueError at the unpack site.
            return [], boxes
        # Filter out boxes contained by others: sort descending by area and
        # keep a box only if no already-kept (larger) box fully contains it.
        areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        sorted_indices = np.argsort(-areas)
        sorted_boxes = boxes[sorted_indices]
        keep_sorted = []
        for i in range(len(sorted_boxes)):
            contained = False
            box_a = sorted_boxes[i]
            for j in keep_sorted:
                box_b = sorted_boxes[j]
                if (box_b[0] <= box_a[0] and box_b[1] <= box_a[1] and
                        box_b[2] >= box_a[2] and box_b[3] >= box_a[3]):
                    contained = True
                    break
            if not contained:
                keep_sorted.append(i)
        # Map back to original indices.
        keep_indices = sorted_indices[keep_sorted]
        filtered_boxes = boxes[keep_indices]
        # Extract element crops.
        element_crops = []
        for box in filtered_boxes:
            x1, y1, x2, y2 = map(int, map(round, box))
            element_crops.append(np.array(image[y1:y2, x1:x2]))
        return element_crops, filtered_boxes

    def load_image(self, image_source: str) -> np.ndarray:
        """
        Decode a base64 string (optionally a data URL) into a BGR image.

        Args:
            image_source: Base64 image data, with or without a
                "data:image/png;base64," style prefix.

        Returns:
            The decoded image as an OpenCV BGR ndarray.

        Raises:
            ValueError: if the input is not valid base64 image data.
        """
        try:
            # Strip a possible data-URL prefix such as "data:image/png;base64,".
            if ',' in image_source:
                _, payload = image_source.split(',', 1)
            else:
                payload = image_source
            # Base64 decode -> bytes -> numpy array -> OpenCV image.
            image_bytes = base64.b64decode(payload)
            np_array = np.frombuffer(image_bytes, dtype=np.uint8)
            image = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
            if image is None:
                raise ValueError("解码图片失败:无效的图片数据")
            # BUGFIX: return the decoded image instead of analyzing it here.
            # Previously this method returned analyze_image(image), and
            # __call__ then fed that element list into analyze_image a second
            # time (a type error and double work).
            return image
        except (binascii.Error, ValueError) as e:
            error_msg = f"输入既不是有效的文件路径也不是有效的Base64图片数据"
            raise ValueError(error_msg) from e

    def __call__(self, image_source: str) -> List[UIElement]:
        """Decode *image_source* and return the analyzed UI elements."""
        # load_image raises ValueError on bad input, so no None check is
        # needed here (the old FileNotFoundError branch was unreachable).
        image = self.load_image(image_source)
        return self.analyze_image(image)