diff --git a/gradio_ui/agent/llm_utils/omniparserclient.py b/gradio_ui/agent/llm_utils/omniparserclient.py
index 168effe..c71fdee 100644
--- a/gradio_ui/agent/llm_utils/omniparserclient.py
+++ b/gradio_ui/agent/llm_utils/omniparserclient.py
@@ -19,26 +19,31 @@ class OmniParserClient:
response_json = response.json()
print('omniparser latency:', response_json['latency'])
- som_image_data = base64.b64decode(response_json['som_image_base64'])
- screenshot_path_uuid = Path(screenshot_path).stem.replace("screenshot_", "")
- som_screenshot_path = f"{OUTPUT_DIR}/screenshot_som_{screenshot_path_uuid}.png"
- with open(som_screenshot_path, "wb") as f:
- f.write(som_image_data)
+ # som_image_data = base64.b64decode(response_json['som_image_base64'])
+ # screenshot_path_uuid = Path(screenshot_path).stem.replace("screenshot_", "")
+ # som_screenshot_path = f"{OUTPUT_DIR}/screenshot_som_{screenshot_path_uuid}.png"
+ # with open(som_screenshot_path, "wb") as f:
+ # f.write(som_image_data)
response_json['width'] = screenshot.size[0]
response_json['height'] = screenshot.size[1]
response_json['original_screenshot_base64'] = image_base64
- response_json['screenshot_uuid'] = screenshot_path_uuid
+ # response_json['screenshot_uuid'] = screenshot_path_uuid
response_json = self.reformat_messages(response_json)
return response_json
def reformat_messages(self, response_json: dict):
screen_info = ""
for idx, element in enumerate(response_json["parsed_content_list"]):
- element['idx'] = idx
- if element['type'] == 'text':
- screen_info += f'ID: {idx}, Text: {element["content"]}\n'
- elif element['type'] == 'icon':
- screen_info += f'ID: {idx}, Icon: {element["content"]}\n'
+ # element['idx'] = idx
+ # if element['type'] == 'text':
+ # screen_info += f'ID: {idx}, Text: {element["content"]}\n'
+ # elif element['type'] == 'icon':
+ # screen_info += f'ID: {idx}, Icon: {element["content"]}\n'
+ screen_info += f'ID: {element.element_id}, '
+ screen_info += f'Coordinates: {element.coordinates}, '
+ screen_info += f'Text: {element.text if len(element.text) else " "}, '
+ screen_info += f'Caption: {element.caption}. '
+ screen_info += "\n"
response_json['screen_info'] = screen_info
return response_json
\ No newline at end of file
diff --git a/gradio_ui/agent/task_plan_agent.py b/gradio_ui/agent/task_plan_agent.py
index 18bc6d5..0087c2c 100644
--- a/gradio_ui/agent/task_plan_agent.py
+++ b/gradio_ui/agent/task_plan_agent.py
@@ -1,8 +1,8 @@
from gradio_ui.agent.base_agent import BaseAgent
class TaskPlanAgent(BaseAgent):
- def __init__(self, config):
- super().__init__(config)
+ def __init__(self):
+ super().__init__()
self.SYSTEM_PROMPT = system_prompt
diff --git a/gradio_ui/agent/task_run_agent.py b/gradio_ui/agent/task_run_agent.py
index 1d4768b..f9d30ba 100644
--- a/gradio_ui/agent/task_run_agent.py
+++ b/gradio_ui/agent/task_run_agent.py
@@ -6,12 +6,15 @@ from xbrain.core.chat import run
import platform
import re
class TaskRunAgent(BaseAgent):
- def __init__(self,task_plan: str, screen_info):
+ def __init__(self):
+ print("TaskRunAgent initialized without a task")
+ def __call__(self,task_plan: str, screen_info):
self.OUTPUT_DIR = "./tmp/outputs"
device = self.get_device()
self.SYSTEM_PROMPT = system_prompt.format(task_plan=task_plan,
device=device,
screen_info=screen_info)
+ print(self.SYSTEM_PROMPT)
def get_device(self):
# 获取当前操作系统信息
diff --git a/gradio_ui/loop.py b/gradio_ui/loop.py
index ebf12e6..57d51f2 100644
--- a/gradio_ui/loop.py
+++ b/gradio_ui/loop.py
@@ -29,7 +29,7 @@ def sampling_loop_sync(
tool_output_callback: Callable[[ToolResult, str], None],
api_response_callback: Callable[[APIResponse[BetaMessage]], None],
api_key: str,
- only_n_most_recent_images: int | None = 2,
+ only_n_most_recent_images: int | None = 0,
max_tokens: int = 4096,
omniparser_url: str,
base_url: str
@@ -40,7 +40,6 @@ def sampling_loop_sync(
print('in sampling_loop_sync, model:', model)
omniparser_client = OmniParserClient(url=f"http://{omniparser_url}/parse/")
task_plan_agent = TaskPlanAgent()
- task_plan_agent()
# actor = VLMAgent(
# model=model,
# api_key=api_key,
@@ -58,7 +57,7 @@ def sampling_loop_sync(
tool_result_content = None
print(f"Start the message loop. User messages: {messages}")
- plan = task_plan_agent(messages[-1]["content"][0])
+ plan = task_plan_agent(user_task = messages[-1]["content"][0])
task_run_agent = TaskRunAgent()
while True:
diff --git a/omniserver.py b/omniserver.py
index 622840c..8a16e54 100644
--- a/omniserver.py
+++ b/omniserver.py
@@ -11,7 +11,8 @@ import argparse
import uvicorn
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(root_dir)
-from util.omniparser import Omniparser
+# from util.omniparser import Omniparser
+from util.vision_agent import VisionAgent
def parse_arguments():
parser = argparse.ArgumentParser(description='Omniparser API')
@@ -29,8 +30,10 @@ args = parse_arguments()
config = vars(args)
app = FastAPI()
-omniparser = Omniparser(config)
-
+# omniparser = Omniparser(config)
+yolo_model_path = config['som_model_path']
+caption_model_path = config['caption_model_path']
+vision_agent = VisionAgent(yolo_model_path=yolo_model_path, caption_model_path=caption_model_path)
class ParseRequest(BaseModel):
base64_image: str
@@ -38,10 +41,12 @@ class ParseRequest(BaseModel):
async def parse(parse_request: ParseRequest):
print('start parsing...')
start = time.time()
- dino_labled_img, parsed_content_list = omniparser.parse(parse_request.base64_image)
+ # dino_labled_img, parsed_content_list = omniparser.parse(parse_request.base64_image)
+ parsed_content_list = vision_agent(parse_request.base64_image)
+
latency = time.time() - start
print('time:', latency)
- return {"som_image_base64": dino_labled_img, "parsed_content_list": parsed_content_list, 'latency': latency}
+ return {"parsed_content_list": parsed_content_list, 'latency': latency}
@app.get("/probe/")
async def root():
diff --git a/requirements.txt b/requirements.txt
index 75d8c00..107236c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,7 @@ torch
easyocr
torchvision
supervision==0.18.0
-openai==1.3.5
+openai>=1.3.5
transformers
ultralytics==8.3.70
azure-identity
diff --git a/util/box_annotator.py b/util/box_annotator.py
index 82f7116..24d1f4c 100644
--- a/util/box_annotator.py
+++ b/util/box_annotator.py
@@ -1,262 +1,262 @@
-from typing import List, Optional, Union, Tuple
+# from typing import List, Optional, Union, Tuple
-import cv2
-import numpy as np
+# import cv2
+# import numpy as np
-from supervision.detection.core import Detections
-from supervision.draw.color import Color, ColorPalette
+# from supervision.detection.core import Detections
+# from supervision.draw.color import Color, ColorPalette
-class BoxAnnotator:
- """
- A class for drawing bounding boxes on an image using detections provided.
+# class BoxAnnotator:
+# """
+# A class for drawing bounding boxes on an image using detections provided.
- Attributes:
- color (Union[Color, ColorPalette]): The color to draw the bounding box,
- can be a single color or a color palette
- thickness (int): The thickness of the bounding box lines, default is 2
- text_color (Color): The color of the text on the bounding box, default is white
- text_scale (float): The scale of the text on the bounding box, default is 0.5
- text_thickness (int): The thickness of the text on the bounding box,
- default is 1
- text_padding (int): The padding around the text on the bounding box,
- default is 5
+# Attributes:
+# color (Union[Color, ColorPalette]): The color to draw the bounding box,
+# can be a single color or a color palette
+# thickness (int): The thickness of the bounding box lines, default is 2
+# text_color (Color): The color of the text on the bounding box, default is white
+# text_scale (float): The scale of the text on the bounding box, default is 0.5
+# text_thickness (int): The thickness of the text on the bounding box,
+# default is 1
+# text_padding (int): The padding around the text on the bounding box,
+# default is 5
- """
+# """
- def __init__(
- self,
- color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
- thickness: int = 3, # 1 for seeclick 2 for mind2web and 3 for demo
- text_color: Color = Color.BLACK,
- text_scale: float = 0.5, # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
- text_thickness: int = 2, #1, # 2 for demo
- text_padding: int = 10,
- avoid_overlap: bool = True,
- ):
- self.color: Union[Color, ColorPalette] = color
- self.thickness: int = thickness
- self.text_color: Color = text_color
- self.text_scale: float = text_scale
- self.text_thickness: int = text_thickness
- self.text_padding: int = text_padding
- self.avoid_overlap: bool = avoid_overlap
+# def __init__(
+# self,
+# color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
+# thickness: int = 3, # 1 for seeclick 2 for mind2web and 3 for demo
+# text_color: Color = Color.BLACK,
+# text_scale: float = 0.5, # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
+# text_thickness: int = 2, #1, # 2 for demo
+# text_padding: int = 10,
+# avoid_overlap: bool = True,
+# ):
+# self.color: Union[Color, ColorPalette] = color
+# self.thickness: int = thickness
+# self.text_color: Color = text_color
+# self.text_scale: float = text_scale
+# self.text_thickness: int = text_thickness
+# self.text_padding: int = text_padding
+# self.avoid_overlap: bool = avoid_overlap
- def annotate(
- self,
- scene: np.ndarray,
- detections: Detections,
- labels: Optional[List[str]] = None,
- skip_label: bool = False,
- image_size: Optional[Tuple[int, int]] = None,
- ) -> np.ndarray:
- """
- Draws bounding boxes on the frame using the detections provided.
+# def annotate(
+# self,
+# scene: np.ndarray,
+# detections: Detections,
+# labels: Optional[List[str]] = None,
+# skip_label: bool = False,
+# image_size: Optional[Tuple[int, int]] = None,
+# ) -> np.ndarray:
+# """
+# Draws bounding boxes on the frame using the detections provided.
- Args:
- scene (np.ndarray): The image on which the bounding boxes will be drawn
- detections (Detections): The detections for which the
- bounding boxes will be drawn
- labels (Optional[List[str]]): An optional list of labels
- corresponding to each detection. If `labels` are not provided,
- corresponding `class_id` will be used as label.
- skip_label (bool): Is set to `True`, skips bounding box label annotation.
- Returns:
- np.ndarray: The image with the bounding boxes drawn on it
+# Args:
+# scene (np.ndarray): The image on which the bounding boxes will be drawn
+# detections (Detections): The detections for which the
+# bounding boxes will be drawn
+# labels (Optional[List[str]]): An optional list of labels
+# corresponding to each detection. If `labels` are not provided,
+# corresponding `class_id` will be used as label.
+# skip_label (bool): Is set to `True`, skips bounding box label annotation.
+# Returns:
+# np.ndarray: The image with the bounding boxes drawn on it
- Example:
- ```python
- import supervision as sv
+# Example:
+# ```python
+# import supervision as sv
- classes = ['person', ...]
- image = ...
- detections = sv.Detections(...)
+# classes = ['person', ...]
+# image = ...
+# detections = sv.Detections(...)
- box_annotator = sv.BoxAnnotator()
- labels = [
- f"{classes[class_id]} {confidence:0.2f}"
- for _, _, confidence, class_id, _ in detections
- ]
- annotated_frame = box_annotator.annotate(
- scene=image.copy(),
- detections=detections,
- labels=labels
- )
- ```
- """
- font = cv2.FONT_HERSHEY_SIMPLEX
- for i in range(len(detections)):
- x1, y1, x2, y2 = detections.xyxy[i].astype(int)
- class_id = (
- detections.class_id[i] if detections.class_id is not None else None
- )
- idx = class_id if class_id is not None else i
- color = (
- self.color.by_idx(idx)
- if isinstance(self.color, ColorPalette)
- else self.color
- )
- cv2.rectangle(
- img=scene,
- pt1=(x1, y1),
- pt2=(x2, y2),
- color=color.as_bgr(),
- thickness=self.thickness,
- )
- if skip_label:
- continue
+# box_annotator = sv.BoxAnnotator()
+# labels = [
+# f"{classes[class_id]} {confidence:0.2f}"
+# for _, _, confidence, class_id, _ in detections
+# ]
+# annotated_frame = box_annotator.annotate(
+# scene=image.copy(),
+# detections=detections,
+# labels=labels
+# )
+# ```
+# """
+# font = cv2.FONT_HERSHEY_SIMPLEX
+# for i in range(len(detections)):
+# x1, y1, x2, y2 = detections.xyxy[i].astype(int)
+# class_id = (
+# detections.class_id[i] if detections.class_id is not None else None
+# )
+# idx = class_id if class_id is not None else i
+# color = (
+# self.color.by_idx(idx)
+# if isinstance(self.color, ColorPalette)
+# else self.color
+# )
+# cv2.rectangle(
+# img=scene,
+# pt1=(x1, y1),
+# pt2=(x2, y2),
+# color=color.as_bgr(),
+# thickness=self.thickness,
+# )
+# if skip_label:
+# continue
- text = (
- f"{class_id}"
- if (labels is None or len(detections) != len(labels))
- else labels[i]
- )
+# text = (
+# f"{class_id}"
+# if (labels is None or len(detections) != len(labels))
+# else labels[i]
+# )
- text_width, text_height = cv2.getTextSize(
- text=text,
- fontFace=font,
- fontScale=self.text_scale,
- thickness=self.text_thickness,
- )[0]
+# text_width, text_height = cv2.getTextSize(
+# text=text,
+# fontFace=font,
+# fontScale=self.text_scale,
+# thickness=self.text_thickness,
+# )[0]
- if not self.avoid_overlap:
- text_x = x1 + self.text_padding
- text_y = y1 - self.text_padding
+# if not self.avoid_overlap:
+# text_x = x1 + self.text_padding
+# text_y = y1 - self.text_padding
- text_background_x1 = x1
- text_background_y1 = y1 - 2 * self.text_padding - text_height
+# text_background_x1 = x1
+# text_background_y1 = y1 - 2 * self.text_padding - text_height
- text_background_x2 = x1 + 2 * self.text_padding + text_width
- text_background_y2 = y1
- # text_x = x1 - self.text_padding - text_width
- # text_y = y1 + self.text_padding + text_height
- # text_background_x1 = x1 - 2 * self.text_padding - text_width
- # text_background_y1 = y1
- # text_background_x2 = x1
- # text_background_y2 = y1 + 2 * self.text_padding + text_height
- else:
- text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 = get_optimal_label_pos(self.text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size)
+# text_background_x2 = x1 + 2 * self.text_padding + text_width
+# text_background_y2 = y1
+# # text_x = x1 - self.text_padding - text_width
+# # text_y = y1 + self.text_padding + text_height
+# # text_background_x1 = x1 - 2 * self.text_padding - text_width
+# # text_background_y1 = y1
+# # text_background_x2 = x1
+# # text_background_y2 = y1 + 2 * self.text_padding + text_height
+# else:
+# text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 = get_optimal_label_pos(self.text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size)
- cv2.rectangle(
- img=scene,
- pt1=(text_background_x1, text_background_y1),
- pt2=(text_background_x2, text_background_y2),
- color=color.as_bgr(),
- thickness=cv2.FILLED,
- )
- # import pdb; pdb.set_trace()
- box_color = color.as_rgb()
- luminance = 0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
- text_color = (0,0,0) if luminance > 160 else (255,255,255)
- cv2.putText(
- img=scene,
- text=text,
- org=(text_x, text_y),
- fontFace=font,
- fontScale=self.text_scale,
- # color=self.text_color.as_rgb(),
- color=text_color,
- thickness=self.text_thickness,
- lineType=cv2.LINE_AA,
- )
- return scene
+# cv2.rectangle(
+# img=scene,
+# pt1=(text_background_x1, text_background_y1),
+# pt2=(text_background_x2, text_background_y2),
+# color=color.as_bgr(),
+# thickness=cv2.FILLED,
+# )
+# # import pdb; pdb.set_trace()
+# box_color = color.as_rgb()
+# luminance = 0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
+# text_color = (0,0,0) if luminance > 160 else (255,255,255)
+# cv2.putText(
+# img=scene,
+# text=text,
+# org=(text_x, text_y),
+# fontFace=font,
+# fontScale=self.text_scale,
+# # color=self.text_color.as_rgb(),
+# color=text_color,
+# thickness=self.text_thickness,
+# lineType=cv2.LINE_AA,
+# )
+# return scene
-def box_area(box):
- return (box[2] - box[0]) * (box[3] - box[1])
+# def box_area(box):
+# return (box[2] - box[0]) * (box[3] - box[1])
-def intersection_area(box1, box2):
- x1 = max(box1[0], box2[0])
- y1 = max(box1[1], box2[1])
- x2 = min(box1[2], box2[2])
- y2 = min(box1[3], box2[3])
- return max(0, x2 - x1) * max(0, y2 - y1)
+# def intersection_area(box1, box2):
+# x1 = max(box1[0], box2[0])
+# y1 = max(box1[1], box2[1])
+# x2 = min(box1[2], box2[2])
+# y2 = min(box1[3], box2[3])
+# return max(0, x2 - x1) * max(0, y2 - y1)
-def IoU(box1, box2, return_max=True):
- intersection = intersection_area(box1, box2)
- union = box_area(box1) + box_area(box2) - intersection
- if box_area(box1) > 0 and box_area(box2) > 0:
- ratio1 = intersection / box_area(box1)
- ratio2 = intersection / box_area(box2)
- else:
- ratio1, ratio2 = 0, 0
- if return_max:
- return max(intersection / union, ratio1, ratio2)
- else:
- return intersection / union
+# def IoU(box1, box2, return_max=True):
+# intersection = intersection_area(box1, box2)
+# union = box_area(box1) + box_area(box2) - intersection
+# if box_area(box1) > 0 and box_area(box2) > 0:
+# ratio1 = intersection / box_area(box1)
+# ratio2 = intersection / box_area(box2)
+# else:
+# ratio1, ratio2 = 0, 0
+# if return_max:
+# return max(intersection / union, ratio1, ratio2)
+# else:
+# return intersection / union
-def get_optimal_label_pos(text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size):
- """ check overlap of text and background detection box, and get_optimal_label_pos,
- pos: str, position of the text, must be one of 'top left', 'top right', 'outer left', 'outer right' TODO: if all are overlapping, return the last one, i.e. outer right
- Threshold: default to 0.3
- """
+# def get_optimal_label_pos(text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size):
+# """ check overlap of text and background detection box, and get_optimal_label_pos,
+# pos: str, position of the text, must be one of 'top left', 'top right', 'outer left', 'outer right' TODO: if all are overlapping, return the last one, i.e. outer right
+# Threshold: default to 0.3
+# """
- def get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size):
- is_overlap = False
- for i in range(len(detections)):
- detection = detections.xyxy[i].astype(int)
- if IoU([text_background_x1, text_background_y1, text_background_x2, text_background_y2], detection) > 0.3:
- is_overlap = True
- break
- # check if the text is out of the image
- if text_background_x1 < 0 or text_background_x2 > image_size[0] or text_background_y1 < 0 or text_background_y2 > image_size[1]:
- is_overlap = True
- return is_overlap
+# def get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size):
+# is_overlap = False
+# for i in range(len(detections)):
+# detection = detections.xyxy[i].astype(int)
+# if IoU([text_background_x1, text_background_y1, text_background_x2, text_background_y2], detection) > 0.3:
+# is_overlap = True
+# break
+# # check if the text is out of the image
+# if text_background_x1 < 0 or text_background_x2 > image_size[0] or text_background_y1 < 0 or text_background_y2 > image_size[1]:
+# is_overlap = True
+# return is_overlap
- # if pos == 'top left':
- text_x = x1 + text_padding
- text_y = y1 - text_padding
+# # if pos == 'top left':
+# text_x = x1 + text_padding
+# text_y = y1 - text_padding
- text_background_x1 = x1
- text_background_y1 = y1 - 2 * text_padding - text_height
+# text_background_x1 = x1
+# text_background_y1 = y1 - 2 * text_padding - text_height
- text_background_x2 = x1 + 2 * text_padding + text_width
- text_background_y2 = y1
- is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
- if not is_overlap:
- return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
+# text_background_x2 = x1 + 2 * text_padding + text_width
+# text_background_y2 = y1
+# is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
+# if not is_overlap:
+# return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
- # elif pos == 'outer left':
- text_x = x1 - text_padding - text_width
- text_y = y1 + text_padding + text_height
+# # elif pos == 'outer left':
+# text_x = x1 - text_padding - text_width
+# text_y = y1 + text_padding + text_height
- text_background_x1 = x1 - 2 * text_padding - text_width
- text_background_y1 = y1
+# text_background_x1 = x1 - 2 * text_padding - text_width
+# text_background_y1 = y1
- text_background_x2 = x1
- text_background_y2 = y1 + 2 * text_padding + text_height
- is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
- if not is_overlap:
- return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
+# text_background_x2 = x1
+# text_background_y2 = y1 + 2 * text_padding + text_height
+# is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
+# if not is_overlap:
+# return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
- # elif pos == 'outer right':
- text_x = x2 + text_padding
- text_y = y1 + text_padding + text_height
+# # elif pos == 'outer right':
+# text_x = x2 + text_padding
+# text_y = y1 + text_padding + text_height
- text_background_x1 = x2
- text_background_y1 = y1
+# text_background_x1 = x2
+# text_background_y1 = y1
- text_background_x2 = x2 + 2 * text_padding + text_width
- text_background_y2 = y1 + 2 * text_padding + text_height
+# text_background_x2 = x2 + 2 * text_padding + text_width
+# text_background_y2 = y1 + 2 * text_padding + text_height
- is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
- if not is_overlap:
- return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
+# is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
+# if not is_overlap:
+# return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
- # elif pos == 'top right':
- text_x = x2 - text_padding - text_width
- text_y = y1 - text_padding
+# # elif pos == 'top right':
+# text_x = x2 - text_padding - text_width
+# text_y = y1 - text_padding
- text_background_x1 = x2 - 2 * text_padding - text_width
- text_background_y1 = y1 - 2 * text_padding - text_height
+# text_background_x1 = x2 - 2 * text_padding - text_width
+# text_background_y1 = y1 - 2 * text_padding - text_height
- text_background_x2 = x2
- text_background_y2 = y1
+# text_background_x2 = x2
+# text_background_y2 = y1
- is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
- if not is_overlap:
- return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
+# is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
+# if not is_overlap:
+# return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
- return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
+# return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
diff --git a/util/omniparser.py b/util/omniparser.py
index 3d9ad4f..e534597 100644
--- a/util/omniparser.py
+++ b/util/omniparser.py
@@ -1,31 +1,32 @@
-from util.utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box
-import torch
-from PIL import Image
-import io
-import base64
-from typing import Dict
-class Omniparser(object):
- def __init__(self, config: Dict):
- self.config = config
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
- self.som_model = get_yolo_model(model_path=config['som_model_path'])
- self.caption_model_processor = get_caption_model_processor(model_name=config['caption_model_name'], model_name_or_path=config['caption_model_path'], device=device)
- print('Server initialized!')
+# from util.utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box
+# import torch
+# from PIL import Image
+# import io
+# import base64
+# from typing import Dict
+# class Omniparser(object):
+# def __init__(self, config: Dict):
+# self.config = config
+# device = 'cuda' if torch.cuda.is_available() else 'cpu'
+# self.som_model = get_yolo_model(model_path=config['som_model_path'])
+# self.caption_model_processor = get_caption_model_processor(model_name=config['caption_model_name'], model_name_or_path=config['caption_model_path'], device=device)
+# print('Server initialized!')
- def parse(self, image_base64: str):
- image_bytes = base64.b64decode(image_base64)
- image = Image.open(io.BytesIO(image_bytes))
- print('image size:', image.size)
+# def parse(self, image_base64: str):
+# image_bytes = base64.b64decode(image_base64)
+# image = Image.open(io.BytesIO(image_bytes))
+# print('image size:', image.size)
- box_overlay_ratio = max(image.size) / 3200
- draw_bbox_config = {
- 'text_scale': 0.8 * box_overlay_ratio,
- 'text_thickness': max(int(2 * box_overlay_ratio), 1),
- 'text_padding': max(int(3 * box_overlay_ratio), 1),
- 'thickness': max(int(3 * box_overlay_ratio), 1),
- }
+# box_overlay_ratio = max(image.size) / 3200
+# draw_bbox_config = {
+# 'text_scale': 0.8 * box_overlay_ratio,
+# 'text_thickness': max(int(2 * box_overlay_ratio), 1),
+# 'text_padding': max(int(3 * box_overlay_ratio), 1),
+# 'thickness': max(int(3 * box_overlay_ratio), 1),
+# }
- (text, ocr_bbox), _ = check_ocr_box(image, display_img=False, output_bb_format='xyxy', easyocr_args={'text_threshold': 0.8}, use_paddleocr=False)
- dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image, self.som_model, BOX_TRESHOLD = self.config['BOX_TRESHOLD'], output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text,use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128)
+# (text, ocr_bbox), _ = check_ocr_box(image, display_img=False, output_bb_format='xyxy', easyocr_args={'text_threshold': 0.8}, use_paddleocr=False)
+# dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image, self.som_model, BOX_TRESHOLD = self.config['BOX_TRESHOLD'], output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text,use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128)
- return dino_labled_img, parsed_content_list
\ No newline at end of file
+# return dino_labled_img, parsed_content_list
+
diff --git a/util/utils.py b/util/utils.py
index 6ed7252..5832aa5 100644
--- a/util/utils.py
+++ b/util/utils.py
@@ -1,540 +1,540 @@
-# from ultralytics import YOLO
-import os
-import io
-import base64
-import time
-from PIL import Image, ImageDraw, ImageFont
-import json
-import requests
-# utility function
-import os
-from openai import AzureOpenAI
+# # from ultralytics import YOLO
+# import os
+# import io
+# import base64
+# import time
+# from PIL import Image, ImageDraw, ImageFont
+# import json
+# import requests
+# # utility function
+# import os
+# from openai import AzureOpenAI
-import json
-import sys
-import os
-import cv2
-import numpy as np
-# %matplotlib inline
-from matplotlib import pyplot as plt
-import easyocr
-from paddleocr import PaddleOCR
-reader = easyocr.Reader(['en', 'ch_sim'])
-paddle_ocr = PaddleOCR(
- lang='ch', # other lang also available
- use_angle_cls=False,
- use_gpu=False, # using cuda will conflict with pytorch in the same process
- show_log=False,
- max_batch_size=1024,
- use_dilation=True, # improves accuracy
- det_db_score_mode='slow', # improves accuracy
- rec_batch_num=1024)
-import time
-import base64
+# import json
+# import sys
+# import os
+# import cv2
+# import numpy as np
+# # %matplotlib inline
+# from matplotlib import pyplot as plt
+# import easyocr
+# from paddleocr import PaddleOCR
+# reader = easyocr.Reader(['en', 'ch_sim'])
+# paddle_ocr = PaddleOCR(
+# lang='ch', # other lang also available
+# use_angle_cls=False,
+# use_gpu=False, # using cuda will conflict with pytorch in the same process
+# show_log=False,
+# max_batch_size=1024,
+# use_dilation=True, # improves accuracy
+# det_db_score_mode='slow', # improves accuracy
+# rec_batch_num=1024)
+# import time
+# import base64
-import os
-import ast
-import torch
-from typing import Tuple, List, Union
-from torchvision.ops import box_convert
-import re
-from torchvision.transforms import ToPILImage
-import supervision as sv
-import torchvision.transforms as T
-from util.box_annotator import BoxAnnotator
+# import os
+# import ast
+# import torch
+# from typing import Tuple, List, Union
+# from torchvision.ops import box_convert
+# import re
+# from torchvision.transforms import ToPILImage
+# import supervision as sv
+# import torchvision.transforms as T
+# from util.box_annotator import BoxAnnotator
-def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
- if not device:
- device = "cuda" if torch.cuda.is_available() else "cpu"
- if model_name == "blip2":
- from transformers import Blip2Processor, Blip2ForConditionalGeneration
- processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
- if device == 'cpu':
- model = Blip2ForConditionalGeneration.from_pretrained(
- model_name_or_path, device_map=None, torch_dtype=torch.float32
- )
- else:
- model = Blip2ForConditionalGeneration.from_pretrained(
- model_name_or_path, device_map=None, torch_dtype=torch.float16
- ).to(device)
- elif model_name == "florence2":
- from transformers import AutoProcessor, AutoModelForCausalLM
- processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
- if device == 'cpu':
- model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)
- else:
- model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True).to(device)
- return {'model': model.to(device), 'processor': processor}
+# def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
+# if not device:
+# device = "cuda" if torch.cuda.is_available() else "cpu"
+# if model_name == "blip2":
+# from transformers import Blip2Processor, Blip2ForConditionalGeneration
+# processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
+# if device == 'cpu':
+# model = Blip2ForConditionalGeneration.from_pretrained(
+# model_name_or_path, device_map=None, torch_dtype=torch.float32
+# )
+# else:
+# model = Blip2ForConditionalGeneration.from_pretrained(
+# model_name_or_path, device_map=None, torch_dtype=torch.float16
+# ).to(device)
+# elif model_name == "florence2":
+# from transformers import AutoProcessor, AutoModelForCausalLM
+# processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
+# if device == 'cpu':
+# model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)
+# else:
+# model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True).to(device)
+# return {'model': model.to(device), 'processor': processor}
-def get_yolo_model(model_path):
- from ultralytics import YOLO
- # Load the model.
- model = YOLO(model_path)
- return model
+# def get_yolo_model(model_path):
+# from ultralytics import YOLO
+# # Load the model.
+# model = YOLO(model_path)
+# return model
-@torch.inference_mode()
-def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=128):
- # Number of samples per batch, --> 128 roughly takes 4 GB of GPU memory for florence v2 model
- to_pil = ToPILImage()
- if starting_idx:
- non_ocr_boxes = filtered_boxes[starting_idx:]
- else:
- non_ocr_boxes = filtered_boxes
- croped_pil_image = []
- for i, coord in enumerate(non_ocr_boxes):
- try:
- xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
- ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
- cropped_image = image_source[ymin:ymax, xmin:xmax, :]
- cropped_image = cv2.resize(cropped_image, (64, 64))
- croped_pil_image.append(to_pil(cropped_image))
- except:
- continue
+# @torch.inference_mode()
+# def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=128):
+# # Number of samples per batch, --> 128 roughly takes 4 GB of GPU memory for florence v2 model
+# to_pil = ToPILImage()
+# if starting_idx:
+# non_ocr_boxes = filtered_boxes[starting_idx:]
+# else:
+# non_ocr_boxes = filtered_boxes
+# croped_pil_image = []
+# for i, coord in enumerate(non_ocr_boxes):
+# try:
+# xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
+# ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
+# cropped_image = image_source[ymin:ymax, xmin:xmax, :]
+# cropped_image = cv2.resize(cropped_image, (64, 64))
+# croped_pil_image.append(to_pil(cropped_image))
+# except:
+# continue
- model, processor = caption_model_processor['model'], caption_model_processor['processor']
- if not prompt:
- if 'florence' in model.config.model_type:
- prompt = "
"
- else:
- prompt = "The image shows"
+# model, processor = caption_model_processor['model'], caption_model_processor['processor']
+# if not prompt:
+# if 'florence' in model.config.model_type:
+# prompt = ""
+# else:
+# prompt = "The image shows"
- generated_texts = []
- device = model.device
- for i in range(0, len(croped_pil_image), batch_size):
- start = time.time()
- batch = croped_pil_image[i:i+batch_size]
- t1 = time.time()
- if model.device.type == 'cuda':
- inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16)
- else:
- inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
- if 'florence' in model.config.model_type:
- generated_ids = model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=20,num_beams=1, do_sample=False)
- else:
- generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1) # temperature=0.01, do_sample=True,
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
- generated_text = [gen.strip() for gen in generated_text]
- generated_texts.extend(generated_text)
+# generated_texts = []
+# device = model.device
+# for i in range(0, len(croped_pil_image), batch_size):
+# start = time.time()
+# batch = croped_pil_image[i:i+batch_size]
+# t1 = time.time()
+# if model.device.type == 'cuda':
+# inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16)
+# else:
+# inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
+# if 'florence' in model.config.model_type:
+# generated_ids = model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=20,num_beams=1, do_sample=False)
+# else:
+# generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1) # temperature=0.01, do_sample=True,
+# generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
+# generated_text = [gen.strip() for gen in generated_text]
+# generated_texts.extend(generated_text)
- return generated_texts
+# return generated_texts
-def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
- to_pil = ToPILImage()
- if ocr_bbox:
- non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
- else:
- non_ocr_boxes = filtered_boxes
- croped_pil_image = []
- for i, coord in enumerate(non_ocr_boxes):
- xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
- ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
- cropped_image = image_source[ymin:ymax, xmin:xmax, :]
- croped_pil_image.append(to_pil(cropped_image))
+# def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
+# to_pil = ToPILImage()
+# if ocr_bbox:
+# non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
+# else:
+# non_ocr_boxes = filtered_boxes
+# croped_pil_image = []
+# for i, coord in enumerate(non_ocr_boxes):
+# xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
+# ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
+# cropped_image = image_source[ymin:ymax, xmin:xmax, :]
+# croped_pil_image.append(to_pil(cropped_image))
- model, processor = caption_model_processor['model'], caption_model_processor['processor']
- device = model.device
- messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}]
- prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+# model, processor = caption_model_processor['model'], caption_model_processor['processor']
+# device = model.device
+# messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}]
+# prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
- batch_size = 5 # Number of samples per batch
- generated_texts = []
+# batch_size = 5 # Number of samples per batch
+# generated_texts = []
- for i in range(0, len(croped_pil_image), batch_size):
- images = croped_pil_image[i:i+batch_size]
- image_inputs = [processor.image_processor(x, return_tensors="pt") for x in images]
- inputs ={'input_ids': [], 'attention_mask': [], 'pixel_values': [], 'image_sizes': []}
- texts = [prompt] * len(images)
- for i, txt in enumerate(texts):
- input = processor._convert_images_texts_to_inputs(image_inputs[i], txt, return_tensors="pt")
- inputs['input_ids'].append(input['input_ids'])
- inputs['attention_mask'].append(input['attention_mask'])
- inputs['pixel_values'].append(input['pixel_values'])
- inputs['image_sizes'].append(input['image_sizes'])
- max_len = max([x.shape[1] for x in inputs['input_ids']])
- for i, v in enumerate(inputs['input_ids']):
- inputs['input_ids'][i] = torch.cat([processor.tokenizer.pad_token_id * torch.ones(1, max_len - v.shape[1], dtype=torch.long), v], dim=1)
- inputs['attention_mask'][i] = torch.cat([torch.zeros(1, max_len - v.shape[1], dtype=torch.long), inputs['attention_mask'][i]], dim=1)
- inputs_cat = {k: torch.concatenate(v).to(device) for k, v in inputs.items()}
+# for i in range(0, len(croped_pil_image), batch_size):
+# images = croped_pil_image[i:i+batch_size]
+# image_inputs = [processor.image_processor(x, return_tensors="pt") for x in images]
+# inputs ={'input_ids': [], 'attention_mask': [], 'pixel_values': [], 'image_sizes': []}
+# texts = [prompt] * len(images)
+# for i, txt in enumerate(texts):
+# input = processor._convert_images_texts_to_inputs(image_inputs[i], txt, return_tensors="pt")
+# inputs['input_ids'].append(input['input_ids'])
+# inputs['attention_mask'].append(input['attention_mask'])
+# inputs['pixel_values'].append(input['pixel_values'])
+# inputs['image_sizes'].append(input['image_sizes'])
+# max_len = max([x.shape[1] for x in inputs['input_ids']])
+# for i, v in enumerate(inputs['input_ids']):
+# inputs['input_ids'][i] = torch.cat([processor.tokenizer.pad_token_id * torch.ones(1, max_len - v.shape[1], dtype=torch.long), v], dim=1)
+# inputs['attention_mask'][i] = torch.cat([torch.zeros(1, max_len - v.shape[1], dtype=torch.long), inputs['attention_mask'][i]], dim=1)
+# inputs_cat = {k: torch.concatenate(v).to(device) for k, v in inputs.items()}
- generation_args = {
- "max_new_tokens": 25,
- "temperature": 0.01,
- "do_sample": False,
- }
- generate_ids = model.generate(**inputs_cat, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
- # # remove input tokens
- generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:]
- response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
- response = [res.strip('\n').strip() for res in response]
- generated_texts.extend(response)
+# generation_args = {
+# "max_new_tokens": 25,
+# "temperature": 0.01,
+# "do_sample": False,
+# }
+# generate_ids = model.generate(**inputs_cat, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
+# # # remove input tokens
+# generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:]
+# response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+# response = [res.strip('\n').strip() for res in response]
+# generated_texts.extend(response)
- return generated_texts
+# return generated_texts
-def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
- assert ocr_bbox is None or isinstance(ocr_bbox, List)
+# def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
+# assert ocr_bbox is None or isinstance(ocr_bbox, List)
- def box_area(box):
- return (box[2] - box[0]) * (box[3] - box[1])
+# def box_area(box):
+# return (box[2] - box[0]) * (box[3] - box[1])
- def intersection_area(box1, box2):
- x1 = max(box1[0], box2[0])
- y1 = max(box1[1], box2[1])
- x2 = min(box1[2], box2[2])
- y2 = min(box1[3], box2[3])
- return max(0, x2 - x1) * max(0, y2 - y1)
+# def intersection_area(box1, box2):
+# x1 = max(box1[0], box2[0])
+# y1 = max(box1[1], box2[1])
+# x2 = min(box1[2], box2[2])
+# y2 = min(box1[3], box2[3])
+# return max(0, x2 - x1) * max(0, y2 - y1)
- def IoU(box1, box2):
- intersection = intersection_area(box1, box2)
- union = box_area(box1) + box_area(box2) - intersection + 1e-6
- if box_area(box1) > 0 and box_area(box2) > 0:
- ratio1 = intersection / box_area(box1)
- ratio2 = intersection / box_area(box2)
- else:
- ratio1, ratio2 = 0, 0
- return max(intersection / union, ratio1, ratio2)
+# def IoU(box1, box2):
+# intersection = intersection_area(box1, box2)
+# union = box_area(box1) + box_area(box2) - intersection + 1e-6
+# if box_area(box1) > 0 and box_area(box2) > 0:
+# ratio1 = intersection / box_area(box1)
+# ratio2 = intersection / box_area(box2)
+# else:
+# ratio1, ratio2 = 0, 0
+# return max(intersection / union, ratio1, ratio2)
- def is_inside(box1, box2):
- # return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
- intersection = intersection_area(box1, box2)
- ratio1 = intersection / box_area(box1)
- return ratio1 > 0.95
+# def is_inside(box1, box2):
+# # return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
+# intersection = intersection_area(box1, box2)
+# ratio1 = intersection / box_area(box1)
+# return ratio1 > 0.95
- boxes = boxes.tolist()
- filtered_boxes = []
- if ocr_bbox:
- filtered_boxes.extend(ocr_bbox)
- # print('ocr_bbox!!!', ocr_bbox)
- for i, box1 in enumerate(boxes):
- # if not any(IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2) for j, box2 in enumerate(boxes) if i != j):
- is_valid_box = True
- for j, box2 in enumerate(boxes):
- # keep the smaller box
- if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
- is_valid_box = False
- break
- if is_valid_box:
- # add the following 2 lines to include ocr bbox
- if ocr_bbox:
- # only add the box if it does not overlap with any ocr bbox
- if not any(IoU(box1, box3) > iou_threshold and not is_inside(box1, box3) for k, box3 in enumerate(ocr_bbox)):
- filtered_boxes.append(box1)
- else:
- filtered_boxes.append(box1)
- return torch.tensor(filtered_boxes)
+# boxes = boxes.tolist()
+# filtered_boxes = []
+# if ocr_bbox:
+# filtered_boxes.extend(ocr_bbox)
+# # print('ocr_bbox!!!', ocr_bbox)
+# for i, box1 in enumerate(boxes):
+# # if not any(IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2) for j, box2 in enumerate(boxes) if i != j):
+# is_valid_box = True
+# for j, box2 in enumerate(boxes):
+# # keep the smaller box
+# if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
+# is_valid_box = False
+# break
+# if is_valid_box:
+# # add the following 2 lines to include ocr bbox
+# if ocr_bbox:
+# # only add the box if it does not overlap with any ocr bbox
+# if not any(IoU(box1, box3) > iou_threshold and not is_inside(box1, box3) for k, box3 in enumerate(ocr_bbox)):
+# filtered_boxes.append(box1)
+# else:
+# filtered_boxes.append(box1)
+# return torch.tensor(filtered_boxes)
-def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
- '''
- ocr_bbox format: [{'type': 'text', 'bbox':[x,y], 'interactivity':False, 'content':str }, ...]
- boxes format: [{'type': 'icon', 'bbox':[x,y], 'interactivity':True, 'content':None }, ...]
+# def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
+# '''
+# ocr_bbox format: [{'type': 'text', 'bbox':[x,y], 'interactivity':False, 'content':str }, ...]
+# boxes format: [{'type': 'icon', 'bbox':[x,y], 'interactivity':True, 'content':None }, ...]
- '''
- assert ocr_bbox is None or isinstance(ocr_bbox, List)
+# '''
+# assert ocr_bbox is None or isinstance(ocr_bbox, List)
- def box_area(box):
- return (box[2] - box[0]) * (box[3] - box[1])
+# def box_area(box):
+# return (box[2] - box[0]) * (box[3] - box[1])
- def intersection_area(box1, box2):
- x1 = max(box1[0], box2[0])
- y1 = max(box1[1], box2[1])
- x2 = min(box1[2], box2[2])
- y2 = min(box1[3], box2[3])
- return max(0, x2 - x1) * max(0, y2 - y1)
+# def intersection_area(box1, box2):
+# x1 = max(box1[0], box2[0])
+# y1 = max(box1[1], box2[1])
+# x2 = min(box1[2], box2[2])
+# y2 = min(box1[3], box2[3])
+# return max(0, x2 - x1) * max(0, y2 - y1)
- def IoU(box1, box2):
- intersection = intersection_area(box1, box2)
- union = box_area(box1) + box_area(box2) - intersection + 1e-6
- if box_area(box1) > 0 and box_area(box2) > 0:
- ratio1 = intersection / box_area(box1)
- ratio2 = intersection / box_area(box2)
- else:
- ratio1, ratio2 = 0, 0
- return max(intersection / union, ratio1, ratio2)
+# def IoU(box1, box2):
+# intersection = intersection_area(box1, box2)
+# union = box_area(box1) + box_area(box2) - intersection + 1e-6
+# if box_area(box1) > 0 and box_area(box2) > 0:
+# ratio1 = intersection / box_area(box1)
+# ratio2 = intersection / box_area(box2)
+# else:
+# ratio1, ratio2 = 0, 0
+# return max(intersection / union, ratio1, ratio2)
- def is_inside(box1, box2):
- # return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
- intersection = intersection_area(box1, box2)
- ratio1 = intersection / box_area(box1)
- return ratio1 > 0.80
+# def is_inside(box1, box2):
+# # return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
+# intersection = intersection_area(box1, box2)
+# ratio1 = intersection / box_area(box1)
+# return ratio1 > 0.80
- # boxes = boxes.tolist()
- filtered_boxes = []
- if ocr_bbox:
- filtered_boxes.extend(ocr_bbox)
- # print('ocr_bbox!!!', ocr_bbox)
- for i, box1_elem in enumerate(boxes):
- box1 = box1_elem['bbox']
- is_valid_box = True
- for j, box2_elem in enumerate(boxes):
- # keep the smaller box
- box2 = box2_elem['bbox']
- if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
- is_valid_box = False
- break
- if is_valid_box:
- if ocr_bbox:
- # keep yolo boxes + prioritize ocr label
- box_added = False
- ocr_labels = ''
- for box3_elem in ocr_bbox:
- if not box_added:
- box3 = box3_elem['bbox']
- if is_inside(box3, box1): # ocr inside icon
- # box_added = True
- # delete the box3_elem from ocr_bbox
- try:
- # gather all ocr labels
- ocr_labels += box3_elem['content'] + ' '
- filtered_boxes.remove(box3_elem)
- except:
- continue
- # break
- elif is_inside(box1, box3): # icon inside ocr, don't added this icon box, no need to check other ocr bbox bc no overlap between ocr bbox, icon can only be in one ocr box
- box_added = True
- break
- else:
- continue
- if not box_added:
- if ocr_labels:
- filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': ocr_labels, 'source':'box_yolo_content_ocr'})
- else:
- filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None, 'source':'box_yolo_content_yolo'})
- else:
- filtered_boxes.append(box1)
- return filtered_boxes # torch.tensor(filtered_boxes)
+# # boxes = boxes.tolist()
+# filtered_boxes = []
+# if ocr_bbox:
+# filtered_boxes.extend(ocr_bbox)
+# # print('ocr_bbox!!!', ocr_bbox)
+# for i, box1_elem in enumerate(boxes):
+# box1 = box1_elem['bbox']
+# is_valid_box = True
+# for j, box2_elem in enumerate(boxes):
+# # keep the smaller box
+# box2 = box2_elem['bbox']
+# if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
+# is_valid_box = False
+# break
+# if is_valid_box:
+# if ocr_bbox:
+# # keep yolo boxes + prioritize ocr label
+# box_added = False
+# ocr_labels = ''
+# for box3_elem in ocr_bbox:
+# if not box_added:
+# box3 = box3_elem['bbox']
+# if is_inside(box3, box1): # ocr inside icon
+# # box_added = True
+# # delete the box3_elem from ocr_bbox
+# try:
+# # gather all ocr labels
+# ocr_labels += box3_elem['content'] + ' '
+# filtered_boxes.remove(box3_elem)
+# except:
+# continue
+# # break
+# elif is_inside(box1, box3): # icon inside ocr, don't added this icon box, no need to check other ocr bbox bc no overlap between ocr bbox, icon can only be in one ocr box
+# box_added = True
+# break
+# else:
+# continue
+# if not box_added:
+# if ocr_labels:
+# filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': ocr_labels, 'source':'box_yolo_content_ocr'})
+# else:
+# filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None, 'source':'box_yolo_content_yolo'})
+# else:
+# filtered_boxes.append(box1)
+# return filtered_boxes # torch.tensor(filtered_boxes)
-def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
- transform = T.Compose(
- [
- T.RandomResize([800], max_size=1333),
- T.ToTensor(),
- T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
- ]
- )
- image_source = Image.open(image_path).convert("RGB")
- image = np.asarray(image_source)
- image_transformed, _ = transform(image_source, None)
- return image, image_transformed
+# def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
+# transform = T.Compose(
+# [
+# T.RandomResize([800], max_size=1333),
+# T.ToTensor(),
+# T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+# ]
+# )
+# image_source = Image.open(image_path).convert("RGB")
+# image = np.asarray(image_source)
+# image_transformed, _ = transform(image_source, None)
+# return image, image_transformed
-def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str], text_scale: float,
- text_padding=5, text_thickness=2, thickness=3) -> np.ndarray:
- """
- This function annotates an image with bounding boxes and labels.
+# def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str], text_scale: float,
+# text_padding=5, text_thickness=2, thickness=3) -> np.ndarray:
+# """
+# This function annotates an image with bounding boxes and labels.
- Parameters:
- image_source (np.ndarray): The source image to be annotated.
- boxes (torch.Tensor): A tensor containing bounding box coordinates. in cxcywh format, pixel scale
- logits (torch.Tensor): A tensor containing confidence scores for each bounding box.
- phrases (List[str]): A list of labels for each bounding box.
- text_scale (float): The scale of the text to be displayed. 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
+# Parameters:
+# image_source (np.ndarray): The source image to be annotated.
+# boxes (torch.Tensor): A tensor containing bounding box coordinates. in cxcywh format, pixel scale
+# logits (torch.Tensor): A tensor containing confidence scores for each bounding box.
+# phrases (List[str]): A list of labels for each bounding box.
+# text_scale (float): The scale of the text to be displayed. 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
- Returns:
- np.ndarray: The annotated image.
- """
- h, w, _ = image_source.shape
- boxes = boxes * torch.Tensor([w, h, w, h])
- xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
- xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
- detections = sv.Detections(xyxy=xyxy)
+# Returns:
+# np.ndarray: The annotated image.
+# """
+# h, w, _ = image_source.shape
+# boxes = boxes * torch.Tensor([w, h, w, h])
+# xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
+# xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
+# detections = sv.Detections(xyxy=xyxy)
- labels = [f"{phrase}" for phrase in range(boxes.shape[0])]
+# labels = [f"{phrase}" for phrase in range(boxes.shape[0])]
- box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,text_thickness=text_thickness,thickness=thickness) # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
- annotated_frame = image_source.copy()
- annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w,h))
+# box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,text_thickness=text_thickness,thickness=thickness) # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
+# annotated_frame = image_source.copy()
+# annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w,h))
- label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
- return annotated_frame, label_coordinates
+# label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
+# return annotated_frame, label_coordinates
-def predict(model, image, caption, box_threshold, text_threshold):
- """ Use huggingface model to replace the original model
- """
- model, processor = model['model'], model['processor']
- device = model.device
+# def predict(model, image, caption, box_threshold, text_threshold):
+# """ Use huggingface model to replace the original model
+# """
+# model, processor = model['model'], model['processor']
+# device = model.device
- inputs = processor(images=image, text=caption, return_tensors="pt").to(device)
- with torch.no_grad():
- outputs = model(**inputs)
+# inputs = processor(images=image, text=caption, return_tensors="pt").to(device)
+# with torch.no_grad():
+# outputs = model(**inputs)
- results = processor.post_process_grounded_object_detection(
- outputs,
- inputs.input_ids,
- box_threshold=box_threshold, # 0.4,
- text_threshold=text_threshold, # 0.3,
- target_sizes=[image.size[::-1]]
- )[0]
- boxes, logits, phrases = results["boxes"], results["scores"], results["labels"]
- return boxes, logits, phrases
+# results = processor.post_process_grounded_object_detection(
+# outputs,
+# inputs.input_ids,
+# box_threshold=box_threshold, # 0.4,
+# text_threshold=text_threshold, # 0.3,
+# target_sizes=[image.size[::-1]]
+# )[0]
+# boxes, logits, phrases = results["boxes"], results["scores"], results["labels"]
+# return boxes, logits, phrases
-def predict_yolo(model, image, box_threshold, imgsz, scale_img, iou_threshold=0.7):
- """ Use huggingface model to replace the original model
- """
- # model = model['model']
- if scale_img:
- result = model.predict(
- source=image,
- conf=box_threshold,
- imgsz=imgsz,
- iou=iou_threshold, # default 0.7
- )
- else:
- result = model.predict(
- source=image,
- conf=box_threshold,
- iou=iou_threshold, # default 0.7
- )
- boxes = result[0].boxes.xyxy#.tolist() # in pixel space
- conf = result[0].boxes.conf
- phrases = [str(i) for i in range(len(boxes))]
+# def predict_yolo(model, image, box_threshold, imgsz, scale_img, iou_threshold=0.7):
+# """ Use huggingface model to replace the original model
+# """
+# # model = model['model']
+# if scale_img:
+# result = model.predict(
+# source=image,
+# conf=box_threshold,
+# imgsz=imgsz,
+# iou=iou_threshold, # default 0.7
+# )
+# else:
+# result = model.predict(
+# source=image,
+# conf=box_threshold,
+# iou=iou_threshold, # default 0.7
+# )
+# boxes = result[0].boxes.xyxy#.tolist() # in pixel space
+# conf = result[0].boxes.conf
+# phrases = [str(i) for i in range(len(boxes))]
- return boxes, conf, phrases
+# return boxes, conf, phrases
-def int_box_area(box, w, h):
- x1, y1, x2, y2 = box
- int_box = [int(x1*w), int(y1*h), int(x2*w), int(y2*h)]
- area = (int_box[2] - int_box[0]) * (int_box[3] - int_box[1])
- return area
+# def int_box_area(box, w, h):
+# x1, y1, x2, y2 = box
+# int_box = [int(x1*w), int(y1*h), int(x2*w), int(y2*h)]
+# area = (int_box[2] - int_box[0]) * (int_box[3] - int_box[1])
+# return area
-def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=128):
- """Process either an image path or Image object
+# def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=128):
+# """Process either an image path or Image object
- Args:
- image_source: Either a file path (str) or PIL Image object
- ...
- """
- if isinstance(image_source, str):
- image_source = Image.open(image_source)
- image_source = image_source.convert("RGB") # for CLIP
- w, h = image_source.size
- if not imgsz:
- imgsz = (h, w)
- # print('image size:', w, h)
- xyxy, logits, phrases = predict_yolo(model=model, image=image_source, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1)
- xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
- image_source = np.asarray(image_source)
- phrases = [str(i) for i in range(len(phrases))]
+# Args:
+# image_source: Either a file path (str) or PIL Image object
+# ...
+# """
+# if isinstance(image_source, str):
+# image_source = Image.open(image_source)
+# image_source = image_source.convert("RGB") # for CLIP
+# w, h = image_source.size
+# if not imgsz:
+# imgsz = (h, w)
+# # print('image size:', w, h)
+# xyxy, logits, phrases = predict_yolo(model=model, image=image_source, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1)
+# xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
+# image_source = np.asarray(image_source)
+# phrases = [str(i) for i in range(len(phrases))]
- # annotate the image with labels
- if ocr_bbox:
- ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
- ocr_bbox=ocr_bbox.tolist()
- else:
- print('no ocr bbox!!!')
- ocr_bbox = None
+# # annotate the image with labels
+# if ocr_bbox:
+# ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
+# ocr_bbox=ocr_bbox.tolist()
+# else:
+# print('no ocr bbox!!!')
+# ocr_bbox = None
- ocr_bbox_elem = [{'type': 'text', 'bbox':box, 'interactivity':False, 'content':txt, 'source': 'box_ocr_content_ocr'} for box, txt in zip(ocr_bbox, ocr_text) if int_box_area(box, w, h) > 0]
- xyxy_elem = [{'type': 'icon', 'bbox':box, 'interactivity':True, 'content':None} for box in xyxy.tolist() if int_box_area(box, w, h) > 0]
- filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)
+# ocr_bbox_elem = [{'type': 'text', 'bbox':box, 'interactivity':False, 'content':txt, 'source': 'box_ocr_content_ocr'} for box, txt in zip(ocr_bbox, ocr_text) if int_box_area(box, w, h) > 0]
+# xyxy_elem = [{'type': 'icon', 'bbox':box, 'interactivity':True, 'content':None} for box in xyxy.tolist() if int_box_area(box, w, h) > 0]
+# filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)
- # sort the filtered_boxes so that the one with 'content': None is at the end, and get the index of the first 'content': None
- filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None)
- # get the index of the first 'content': None
- starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
- filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem])
- print('len(filtered_boxes):', len(filtered_boxes), starting_idx)
+# # sort the filtered_boxes so that the one with 'content': None is at the end, and get the index of the first 'content': None
+# filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None)
+# # get the index of the first 'content': None
+# starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
+# filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem])
+# print('len(filtered_boxes):', len(filtered_boxes), starting_idx)
- # get parsed icon local semantics
- time1 = time.time()
- if use_local_semantics:
- caption_model = caption_model_processor['model']
- if 'phi3_v' in caption_model.config.model_type:
- parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor)
- else:
- parsed_content_icon = get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=prompt,batch_size=batch_size)
- ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
- icon_start = len(ocr_text)
- parsed_content_icon_ls = []
- # fill the filtered_boxes_elem None content with parsed_content_icon in order
- for i, box in enumerate(filtered_boxes_elem):
- if box['content'] is None:
- box['content'] = parsed_content_icon.pop(0)
- for i, txt in enumerate(parsed_content_icon):
- parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
- parsed_content_merged = ocr_text + parsed_content_icon_ls
- else:
- ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
- parsed_content_merged = ocr_text
- print('time to get parsed content:', time.time()-time1)
+# # get parsed icon local semantics
+# time1 = time.time()
+# if use_local_semantics:
+# caption_model = caption_model_processor['model']
+# if 'phi3_v' in caption_model.config.model_type:
+# parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor)
+# else:
+# parsed_content_icon = get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=prompt,batch_size=batch_size)
+# ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
+# icon_start = len(ocr_text)
+# parsed_content_icon_ls = []
+# # fill the filtered_boxes_elem None content with parsed_content_icon in order
+# for i, box in enumerate(filtered_boxes_elem):
+# if box['content'] is None:
+# box['content'] = parsed_content_icon.pop(0)
+# for i, txt in enumerate(parsed_content_icon):
+# parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
+# parsed_content_merged = ocr_text + parsed_content_icon_ls
+# else:
+# ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
+# parsed_content_merged = ocr_text
+# print('time to get parsed content:', time.time()-time1)
- filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh")
+# filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh")
- phrases = [i for i in range(len(filtered_boxes))]
+# phrases = [i for i in range(len(filtered_boxes))]
- # draw boxes
- if draw_bbox_config:
- annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config)
- else:
- annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, text_scale=text_scale, text_padding=text_padding)
+# # draw boxes
+# if draw_bbox_config:
+# annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config)
+# else:
+# annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, text_scale=text_scale, text_padding=text_padding)
- pil_img = Image.fromarray(annotated_frame)
- buffered = io.BytesIO()
- pil_img.save(buffered, format="PNG")
- encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
- if output_coord_in_ratio:
- label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()}
- assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]
+# pil_img = Image.fromarray(annotated_frame)
+# buffered = io.BytesIO()
+# pil_img.save(buffered, format="PNG")
+# encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
+# if output_coord_in_ratio:
+# label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()}
+# assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]
- return encoded_image, label_coordinates, filtered_boxes_elem
+# return encoded_image, label_coordinates, filtered_boxes_elem
-def get_xywh(input):
- x, y, w, h = input[0][0], input[0][1], input[2][0] - input[0][0], input[2][1] - input[0][1]
- x, y, w, h = int(x), int(y), int(w), int(h)
- return x, y, w, h
+# def get_xywh(input):
+# x, y, w, h = input[0][0], input[0][1], input[2][0] - input[0][0], input[2][1] - input[0][1]
+# x, y, w, h = int(x), int(y), int(w), int(h)
+# return x, y, w, h
-def get_xyxy(input):
- x, y, xp, yp = input[0][0], input[0][1], input[2][0], input[2][1]
- x, y, xp, yp = int(x), int(y), int(xp), int(yp)
- return x, y, xp, yp
+# def get_xyxy(input):
+# x, y, xp, yp = input[0][0], input[0][1], input[2][0], input[2][1]
+# x, y, xp, yp = int(x), int(y), int(xp), int(yp)
+# return x, y, xp, yp
-def get_xywh_yolo(input):
- x, y, w, h = input[0], input[1], input[2] - input[0], input[3] - input[1]
- x, y, w, h = int(x), int(y), int(w), int(h)
- return x, y, w, h
+# def get_xywh_yolo(input):
+# x, y, w, h = input[0], input[1], input[2] - input[0], input[3] - input[1]
+# x, y, w, h = int(x), int(y), int(w), int(h)
+# return x, y, w, h
-def check_ocr_box(image_source: Union[str, Image.Image], display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
- if isinstance(image_source, str):
- image_source = Image.open(image_source)
- if image_source.mode == 'RGBA':
- # Convert RGBA to RGB to avoid alpha channel issues
- image_source = image_source.convert('RGB')
- image_np = np.array(image_source)
- w, h = image_source.size
- if use_paddleocr:
- if easyocr_args is None:
- text_threshold = 0.5
- else:
- text_threshold = easyocr_args['text_threshold']
- result = paddle_ocr.ocr(image_np, cls=False)[0]
- coord = [item[0] for item in result if item[1][1] > text_threshold]
- text = [item[1][0] for item in result if item[1][1] > text_threshold]
- else: # EasyOCR
- if easyocr_args is None:
- easyocr_args = {}
- result = reader.readtext(image_np, **easyocr_args)
- coord = [item[0] for item in result]
- text = [item[1] for item in result]
- if display_img:
- opencv_img = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
- bb = []
- for item in coord:
- x, y, a, b = get_xywh(item)
- bb.append((x, y, a, b))
- cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2)
- # matplotlib expects RGB
- plt.imshow(cv2.cvtColor(opencv_img, cv2.COLOR_BGR2RGB))
- else:
- if output_bb_format == 'xywh':
- bb = [get_xywh(item) for item in coord]
- elif output_bb_format == 'xyxy':
- bb = [get_xyxy(item) for item in coord]
- return (text, bb), goal_filtering
\ No newline at end of file
+# def check_ocr_box(image_source: Union[str, Image.Image], display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
+# if isinstance(image_source, str):
+# image_source = Image.open(image_source)
+# if image_source.mode == 'RGBA':
+# # Convert RGBA to RGB to avoid alpha channel issues
+# image_source = image_source.convert('RGB')
+# image_np = np.array(image_source)
+# w, h = image_source.size
+# if use_paddleocr:
+# if easyocr_args is None:
+# text_threshold = 0.5
+# else:
+# text_threshold = easyocr_args['text_threshold']
+# result = paddle_ocr.ocr(image_np, cls=False)[0]
+# coord = [item[0] for item in result if item[1][1] > text_threshold]
+# text = [item[1][0] for item in result if item[1][1] > text_threshold]
+# else: # EasyOCR
+# if easyocr_args is None:
+# easyocr_args = {}
+# result = reader.readtext(image_np, **easyocr_args)
+# coord = [item[0] for item in result]
+# text = [item[1] for item in result]
+# if display_img:
+# opencv_img = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
+# bb = []
+# for item in coord:
+# x, y, a, b = get_xywh(item)
+# bb.append((x, y, a, b))
+# cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2)
+# # matplotlib expects RGB
+# plt.imshow(cv2.cvtColor(opencv_img, cv2.COLOR_BGR2RGB))
+# else:
+# if output_bb_format == 'xywh':
+# bb = [get_xywh(item) for item in coord]
+# elif output_bb_format == 'xyxy':
+# bb = [get_xyxy(item) for item in coord]
+# return (text, bb), goal_filtering
\ No newline at end of file
diff --git a/util/vision_agent.py b/util/vision_agent.py
new file mode 100644
index 0000000..fe82b01
--- /dev/null
+++ b/util/vision_agent.py
@@ -0,0 +1,214 @@
+from typing import List, Optional
+import cv2
+import torch
+from ultralytics import YOLO
+from transformers import AutoModelForCausalLM, AutoProcessor
+import easyocr
+import supervision as sv
+import numpy as np
+import time
+from pydantic import BaseModel
+import base64
+
class UIElement(BaseModel):
    """A single UI element detected in a screenshot.

    Produced by VisionAgent.analyze_image, one record per detected
    element, in detection order.
    """
    # Sequential index assigned in detection order (0-based).
    element_id: int
    # Pixel bounding box from the detector's xyxy output (see
    # _detect_objects) — presumably [x1, y1, x2, y2]; confirm.
    coordinates: list[float]
    # Model-generated caption for the cropped element; None until captioning runs.
    caption: Optional[str] = None
    # OCR text found inside the crop; None/empty when nothing was read.
    text: Optional[str] = None
+
class VisionAgent:
    """UI-screenshot analysis pipeline: YOLO detection, EasyOCR text
    extraction, and Florence-2 captioning of each detected element."""

    def __init__(self, yolo_model_path: str, caption_model_path: str = 'microsoft/Florence-2-base-ft'):
        """
        Computer vision agent for UI analysis.

        Args:
            yolo_model_path: Path to YOLO model weights
            caption_model_path: Name/path to captioning model (defaults to Florence-2)
        """
        # Pick the best available accelerator and a matching dtype
        # (fp16 only on CUDA — see _get_dtype).
        self.device = self._get_available_device()
        self.dtype = self._get_dtype()
        # Results of the most recent analyze_image() call.
        self.elements: List[UIElement] = []

        self.yolo_model = YOLO(yolo_model_path)
        self.caption_model = AutoModelForCausalLM.from_pretrained(
            caption_model_path, trust_remote_code=True
        ).to(self.device)
        # NOTE(review): the processor is always loaded from the stock
        # "microsoft/Florence-2-base" repo even when caption_model_path
        # points elsewhere — presumably intentional (fine-tuned weights,
        # base processor), but confirm.
        self.caption_processor = AutoProcessor.from_pretrained(
            "microsoft/Florence-2-base", trust_remote_code=True
        )
        # OCR for English + simplified Chinese.
        self.ocr_reader = easyocr.Reader(['en', 'ch_sim'])
+
+ def _get_available_device(self) -> str:
+ if torch.cuda.is_available():
+ return 'cuda'
+ if torch.backends.mps.is_available():
+ return 'mps'
+ return 'cpu'
+
+ def _get_dtype(self)-> torch.dtype:
+ if torch.cuda.is_available():
+ return torch.float16
+ return torch.float32
+
+ def _reset_state(self):
+ """Clear previous analysis results"""
+ self.elements = []
+
def analyze_image(self, image: np.ndarray) -> List[UIElement]:
    """
    Run the full pipeline (detection, OCR, captioning) over one image.

    Args:
        image: Input image in BGR format (OpenCV default).

    Returns:
        List of UIElement records, one per detected element, in
        detection order.
    """
    self._reset_state()

    element_crops, boxes = self._detect_objects(image)
    # Guard the empty case so OCR/captioning never run on empty batches.
    if not len(element_crops):
        return self.elements

    start = time.time()
    element_texts = self._extract_text(element_crops)
    print(f"Speed: {(time.time() - start) * 10 ** 3:.2f} ms OCR of {len(element_texts)} icons.")

    start = time.time()
    element_captions = self._get_caption(element_crops)
    print(f"Speed: {(time.time() - start) * 10 ** 3:.2f} ms captioning of {len(element_captions)} icons.")

    for idx, (box, texts, caption) in enumerate(zip(boxes, element_texts, element_captions)):
        self.elements.append(
            UIElement(
                element_id=idx,
                # Coerce the numpy row to a plain list[float]; pydantic
                # does not reliably accept an ndarray for a list field.
                coordinates=[float(v) for v in box],
                text=texts[0] if len(texts) > 0 else '',
                caption=caption,
            )
        )

    return self.elements
+
+ def _extract_text(self, images: np.ndarray) -> list[str]:
+ """
+ Run OCR in sequential mode
+ TODO: It is possible to run in batch mode for a speed up, but the result quality needs test.
+ https://github.com/JaidedAI/EasyOCR/pull/458
+ """
+ texts = []
+ for image in images:
+ text = self.ocr_reader.readtext(image, detail=0, paragraph=True, text_threshold=0.85)
+ texts.append(text)
+ # print(texts)
+ return texts
+
+
+ def _get_caption(self, images: np.ndarray, batch_size: int = 1) -> list[str]:
+ """Run captioning in batched mode. TODO: adjust batch size"""
+ prompt = ""
+ generated_texts = []
+ resized_images = []
+ for image in images:
+ resized_image = cv2.resize(image, (64, 64))
+ resized_images.append(resized_image)
+
+ for i in range(0, len(resized_images), batch_size):
+ batch_images = resized_images[i:i+batch_size]
+ inputs = self.caption_processor(
+ images=batch_images,
+ text=[prompt] * len(batch_images),
+ return_tensors="pt",
+ do_resize=True,
+ ).to(device=self.device, dtype=self.dtype)
+
+ generated_ids = self.caption_model.generate(
+ input_ids=inputs["input_ids"],
+ pixel_values=inputs["pixel_values"],
+ max_new_tokens=10,
+ num_beams=1,
+ do_sample=False,
+ early_stopping=False,
+ )
+
+ generated_text = self.caption_processor.batch_decode(
+ generated_ids, skip_special_tokens=True
+ )
+ generated_texts.extend([gen.strip() for gen in generated_text])
+
+ return generated_texts
+
+
def _detect_objects(self, image: np.ndarray) -> tuple[list[np.ndarray], list]:
    """
    Detect UI elements with YOLO and crop them out of the image.

    Args:
        image: Input image (BGR, OpenCV convention).

    Returns:
        (element_crops, boxes): per-element image crops and the matching
        xyxy boxes. Both are empty when nothing is detected.
    """
    results = self.yolo_model(image)[0]
    detections = sv.Detections.from_ultralytics(results)
    boxes = detections.xyxy

    if len(boxes) == 0:
        # Bug fix: callers unpack two values (crops, boxes); the original
        # returned a single empty list here, raising ValueError upstream.
        return [], boxes

    # Filter out boxes fully contained inside a larger kept box.
    # Visiting boxes in descending-area order guarantees a container is
    # seen before anything it contains.
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    sorted_indices = np.argsort(-areas)
    sorted_boxes = boxes[sorted_indices]

    keep_sorted = []
    for i in range(len(sorted_boxes)):
        box_a = sorted_boxes[i]
        contained = False
        for j in keep_sorted:
            box_b = sorted_boxes[j]
            if (box_b[0] <= box_a[0] and box_b[1] <= box_a[1] and
                    box_b[2] >= box_a[2] and box_b[3] >= box_a[3]):
                contained = True
                break
        if not contained:
            keep_sorted.append(i)

    # Map kept positions back to original indices.
    keep_indices = sorted_indices[keep_sorted]
    filtered_boxes = boxes[keep_indices]

    # Crop each surviving box (rounded to integer pixel bounds).
    element_crops = []
    for box in filtered_boxes:
        x1, y1, x2, y2 = map(int, map(round, box))
        element_crops.append(np.array(image[y1:y2, x1:x2]))

    return element_crops, filtered_boxes
+
def load_image(self, image_source: str) -> np.ndarray:
    """
    Decode a base64-encoded image (optionally a data URL) into a BGR array.

    Args:
        image_source: Raw base64 string, or a data URL such as
            "data:image/png;base64,....".

    Returns:
        The decoded image as an OpenCV BGR ndarray.

    Raises:
        ValueError: if the payload is not valid base64 or does not decode
            to an image.
    """
    try:
        # Strip a possible data-URL prefix (e.g. "data:image/png;base64,").
        if ',' in image_source:
            _, payload = image_source.split(',', 1)
        else:
            payload = image_source

        # base64 -> bytes -> numpy byte buffer.
        image_bytes = base64.b64decode(payload)
        np_array = np.frombuffer(image_bytes, dtype=np.uint8)

        # Decode the compressed bytes into a BGR image.
        image = cv2.imdecode(np_array, cv2.IMREAD_COLOR)

        if image is None:
            raise ValueError("解码图片失败:无效的图片数据")

        # Bug fix: return the decoded image itself, as the annotation
        # promises — the original returned self.analyze_image(image),
        # which made __call__ analyze twice (once on a non-image value).
        return image

    except (base64.binascii.Error, ValueError) as e:
        error_msg = "输入既不是有效的文件路径,也不是有效的Base64图片数据"
        raise ValueError(error_msg) from e
+
+
def __call__(self, image_source: str) -> List[UIElement]:
    """Analyze an image supplied as a base64 string (see load_image) and
    return the detected UI elements."""
    # NOTE(review): this expects load_image to hand back the raw ndarray
    # so the analysis below runs exactly once — confirm, since
    # load_image's body may already invoke analyze_image itself.
    image = self.load_image(image_source)
    if image is None:
        raise FileNotFoundError(f"Vision agent: 图片读取失败")

    return self.analyze_image(image)
+
+