From 23331aaa67c75e66ddccbc096eb0bacc38815bc7 Mon Sep 17 00:00:00 2001 From: yuruo Date: Wed, 12 Mar 2025 22:40:08 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BB=A3=E7=A0=81=E7=BB=93?= =?UTF-8?q?=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .huggingface/config.json | 7 + README.md | 13 +- README_CN.md | 8 +- SUPPORT_MODEL.md | 4 +- gradio_ui/agent/task_run_agent.py | 3 +- gradio_ui/agent/vision_agent.py | 29 +- gradio_ui/agent/vlm_agent.py | 299 ----------------- gradio_ui/app.py | 3 +- gradio_ui/loop.py | 4 - gradio_ui/tools/screen_capture.py | 3 +- requirements.txt | 26 +- util/box_annotator.py | 262 --------------- util/omniparser.py | 32 -- util/tool.py | 1 - util/utils.py | 540 ------------------------------ 15 files changed, 35 insertions(+), 1199 deletions(-) create mode 100644 .huggingface/config.json delete mode 100644 gradio_ui/agent/vlm_agent.py delete mode 100644 util/box_annotator.py delete mode 100644 util/omniparser.py delete mode 100644 util/utils.py diff --git a/.huggingface/config.json b/.huggingface/config.json new file mode 100644 index 0000000..6ff8de8 --- /dev/null +++ b/.huggingface/config.json @@ -0,0 +1,7 @@ +{ + "hf_endpoint": "https://huggingface.co", + "hf_hub_disable_symlinks_warning": true, + "hf_hub_disable_experimental_warning": true, + "hf_hub_disable_automatic_mirror": true, + "use_auth_token": false +} \ No newline at end of file diff --git a/README.md b/README.md index 3c4f35f..24f7038 100644 --- a/README.md +++ b/README.md @@ -77,10 +77,15 @@ For supported vendors and models, please refer to this [link](./SUPPORT_MODEL.md ### 🔧 CUDA Version Mismatch If you see the error: "GPU driver incompatible, please install appropriate torch version according to readme", it indicates a driver incompatibility. You can either: -1. Run with CPU only (slower but functional) -2. Check your torch version with `pip list` -3. Check supported CUDA versions on the [official website](https://pytorch.org/get-started/locally/) -4. Reinstall Nvidia drivers +1. Run pip list to check the torch version; +2. Check supported CUDA versions on the official website; +3. Copy the official torch installation command and reinstall torch for your CUDA version. + +For example, if your CUDA version is 12.4, install torch using this command: + +```bash +pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124 +``` ## 🤝 Contributing diff --git a/README_CN.md b/README_CN.md index 71d9648..4b2f647 100644 --- a/README_CN.md +++ b/README_CN.md @@ -78,7 +78,13 @@ python main.py 1. 运行`pip list`查看torch版本; 2. 从[官网](https://pytorch.org/get-started/locally/)查看支持的cuda版本; -3. 重新安装Nvidia驱动。 +3. 复制官方的 torch 安装命令,重新安装适合自己 cuda 版本的 torch。 + +比如我的 cuda 版本为 12.4,需要按照如下命令来安装 torch; + +```bash +pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124 +``` ### 模型无法下载 多半是被墙了,可以从百度网盘直接下载模型。 diff --git a/SUPPORT_MODEL.md b/SUPPORT_MODEL.md index ceaf180..d8a213a 100644 --- a/SUPPORT_MODEL.md +++ b/SUPPORT_MODEL.md @@ -1,5 +1,3 @@ | Vendor-en | Vendor-ch | Model | base-url | | --- | --- | --- | --- | -| Alibaba Cloud Bailian | 阿里云百炼 | deepseek-r1 | https://dashscope.aliyuncs.com/compatible-mode/v1 | -| Alibaba Cloud Bailian | 阿里云百炼 | deepseek-v3 | https://dashscope.aliyuncs.com/compatible-mode/v1 | -| deepseek | deepseek官方 | deepseek-chat | https://api.deepseek.com | \ No newline at end of file +| openainext | openainext | gpt-4o-2024-11-20 | https://api.openai-next.com/v1 | diff --git a/gradio_ui/agent/task_run_agent.py b/gradio_ui/agent/task_run_agent.py index 0034acb..fc3d953 100644 --- a/gradio_ui/agent/task_run_agent.py +++ b/gradio_ui/agent/task_run_agent.py @@ -1,9 +1,8 @@ import json import uuid from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam, BetaUsage -from PIL import Image, ImageDraw +from PIL import ImageDraw import base64 -from gradio import Image from io import BytesIO from pydantic import BaseModel, Field from gradio_ui.agent.base_agent import BaseAgent diff --git a/gradio_ui/agent/vision_agent.py b/gradio_ui/agent/vision_agent.py index 8c87b76..ad2bd1c 100644 --- a/gradio_ui/agent/vision_agent.py +++ b/gradio_ui/agent/vision_agent.py @@ -10,7 +10,6 @@ import time from pydantic import BaseModel import base64 from PIL import Image - class UIElement(BaseModel): element_id: int coordinates: list[float] @@ -29,7 +28,8 @@ class VisionAgent: # 确定可用的设备和最佳数据类型 self.device, self.dtype = self._get_optimal_device_and_dtype() print(f"使用设备: {self.device}, 数据类型: {self.dtype}") - + + # os.environ['HF_ENDPOINT'] = 'https://huggingface.co' # 加载YOLO模型 self.yolo_model = YOLO(yolo_model_path) @@ -42,27 +42,6 @@ class VisionAgent: # 根据设备类型加载模型 try: print(f"正在加载图像描述模型: {caption_model_path}") - # if self.device.type == 'cuda': - # # CUDA设备使用float16 - # self.caption_model = AutoModelForCausalLM.from_pretrained( - # caption_model_path, - # torch_dtype=torch.float16, - # trust_remote_code=True - # ).to(self.device) - # elif self.device.type == 'mps': - # # MPS设备使用float32(MPS对float16支持有限) - # self.caption_model = AutoModelForCausalLM.from_pretrained( - # caption_model_path, - # torch_dtype=torch.float32, - # trust_remote_code=True - # ).to(self.device) - # else: - # # CPU使用float32 - # self.caption_model = AutoModelForCausalLM.from_pretrained( - # caption_model_path, - # torch_dtype=torch.float32, - # trust_remote_code=True - # ).to(self.device) self.caption_model = AutoModelForCausalLM.from_pretrained( caption_model_path, torch_dtype=self.dtype, @@ -220,8 +199,8 @@ class VisionAgent: generated_ids = self.caption_model.generate( input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], - max_new_tokens=20, - num_beams=1, + max_new_tokens=128, + num_beams=4, do_sample=False ) else: diff --git a/gradio_ui/agent/vlm_agent.py b/gradio_ui/agent/vlm_agent.py deleted file mode 100644 index b6d19b9..0000000 --- a/gradio_ui/agent/vlm_agent.py +++ /dev/null @@ -1,299 +0,0 @@ -import json -from collections.abc import Callable -from typing import Callable -import uuid -from PIL import Image, ImageDraw -import base64 -from io import BytesIO -from anthropic import APIResponse -from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam, BetaUsage -from gradio_ui.agent.llm_utils.oaiclient import run_oai_interleaved -from gradio_ui.agent.llm_utils.utils import is_image_path -import time -import re - -OUTPUT_DIR = "./tmp/outputs" - -def extract_data(input_string, data_type): - # Regular expression to extract content starting from '```python' until the end if there are no closing backticks - pattern = f"```{data_type}" + r"(.*?)(```|$)" - # Extract content - # re.DOTALL allows '.' to match newlines as well - matches = re.findall(pattern, input_string, re.DOTALL) - # Return the first match if exists, trimming whitespace and ignoring potential closing backticks - return matches[0][0].strip() if matches else input_string - -class VLMAgent: - def __init__( - self, - model: str, - api_key: str, - output_callback: Callable, - api_response_callback: Callable, - max_tokens: int = 4096, - base_url:str = "", - only_n_most_recent_images: int | None = None, - print_usage: bool = True, - ): - self.base_url = base_url - self.api_key = api_key - self.api_response_callback = api_response_callback - self.max_tokens = max_tokens - self.only_n_most_recent_images = only_n_most_recent_images - self.output_callback = output_callback - self.model = model - self.print_usage = print_usage - self.total_token_usage = 0 - self.total_cost = 0 - self.step_count = 0 - - self.system = '' - - def __call__(self, messages: list, parsed_screen: list[str, list, dict]): - self.step_count += 1 - image_base64 = parsed_screen['original_screenshot_base64'] - latency_omniparser = parsed_screen['latency'] - self.output_callback(f'-- Step {self.step_count}: --', sender="bot") - screen_info = str(parsed_screen['screen_info']) - screenshot_uuid = parsed_screen['screenshot_uuid'] - screen_width, screen_height = parsed_screen['width'], parsed_screen['height'] - - boxids_and_labels = parsed_screen["screen_info"] - system = self._get_system_prompt(boxids_and_labels) - - # drop looping actions msg, byte image etc - planner_messages = messages - _remove_som_images(planner_messages) - _maybe_filter_to_n_most_recent_images(planner_messages, self.only_n_most_recent_images) - - if isinstance(planner_messages[-1], dict): - if not isinstance(planner_messages[-1]["content"], list): - planner_messages[-1]["content"] = [planner_messages[-1]["content"]] - planner_messages[-1]["content"].append(f"{OUTPUT_DIR}/screenshot_{screenshot_uuid}.png") - planner_messages[-1]["content"].append(f"{OUTPUT_DIR}/screenshot_som_{screenshot_uuid}.png") - - start = time.time() - vlm_response, token_usage = run_oai_interleaved( - messages=planner_messages, - system=system, - model_name=self.model, - api_key=self.api_key, - max_tokens=self.max_tokens, - provider_base_url=self.base_url, - temperature=0, - ) - latency_vlm = time.time() - start - self.output_callback(f"LLM: {latency_vlm:.2f}s, OmniParser: {latency_omniparser:.2f}s", sender="bot") - - print(f"llm_response: {vlm_response}") - if self.print_usage: - print(f"Total token so far: {token_usage}. Total cost so far: $USD{self.total_cost:.5f}") - - vlm_response_json = extract_data(vlm_response, "json") - vlm_response_json = json.loads(vlm_response_json) - - img_to_show_base64 = parsed_screen["som_image_base64"] - if "Box ID" in vlm_response_json: - try: - bbox = parsed_screen["parsed_content_list"][int(vlm_response_json["Box ID"])]["bbox"] - vlm_response_json["box_centroid_coordinate"] = [int((bbox[0] + bbox[2]) / 2 * screen_width), int((bbox[1] + bbox[3]) / 2 * screen_height)] - img_to_show_data = base64.b64decode(img_to_show_base64) - img_to_show = Image.open(BytesIO(img_to_show_data)) - - draw = ImageDraw.Draw(img_to_show) - x, y = vlm_response_json["box_centroid_coordinate"] - radius = 10 - draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill='red') - draw.ellipse((x - radius*3, y - radius*3, x + radius*3, y + radius*3), fill=None, outline='red', width=2) - - buffered = BytesIO() - img_to_show.save(buffered, format="PNG") - img_to_show_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8") - except: - print(f"Error parsing: {vlm_response_json}") - pass - self.output_callback(f'', sender="bot") - self.output_callback( - f'
' - f' Parsed Screen elemetns by OmniParser' - f'
{screen_info}
' - f'
', - sender="bot" - ) - vlm_plan_str = "" - for key, value in vlm_response_json.items(): - if key == "Reasoning": - vlm_plan_str += f'{value}' - else: - vlm_plan_str += f'\n{key}: {value}' - - # construct the response so that anthropicExcutor can execute the tool - response_content = [BetaTextBlock(text=vlm_plan_str, type='text')] - if 'box_centroid_coordinate' in vlm_response_json: - move_cursor_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}', - input={'action': 'mouse_move', 'coordinate': vlm_response_json["box_centroid_coordinate"]}, - name='computer', type='tool_use') - response_content.append(move_cursor_block) - - if vlm_response_json["Next Action"] == "None": - print("Task paused/completed.") - elif vlm_response_json["Next Action"] == "type": - sim_content_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}', - input={'action': vlm_response_json["Next Action"], 'text': vlm_response_json["value"]}, - name='computer', type='tool_use') - response_content.append(sim_content_block) - else: - sim_content_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}', - input={'action': vlm_response_json["Next Action"]}, - name='computer', type='tool_use') - response_content.append(sim_content_block) - response_message = BetaMessage(id=f'toolu_{uuid.uuid4()}', content=response_content, model='', role='assistant', type='message', stop_reason='tool_use', usage=BetaUsage(input_tokens=0, output_tokens=0)) - return response_message, vlm_response_json - - def _api_response_callback(self, response: APIResponse): - self.api_response_callback(response) - - def _get_system_prompt(self, screen_info: str = ""): - main_section = f""" -You are using a Windows device. -You are able to use a mouse and keyboard to interact with the computer based on the given task and screenshot. -You can only interact with the desktop GUI (no terminal or application menu access). -You may be given some history plan and actions, this is the response from the previous loop. -You should carefully consider your plan base on the task, screenshot, and history actions. - - -Here is the list of all detected bounding boxes by IDs on the screen and their description:{screen_info} - -Your available "Next Action" only include: -- type: types a string of text. -- left_click: move mouse to box id and left clicks. -- right_click: move mouse to box id and right clicks. -- double_click: move mouse to box id and double clicks. -- hover: move mouse to box id. -- scroll_up: scrolls the screen up to view previous content. -- scroll_down: scrolls the screen down, when the desired button is not visible, or you need to see more content. -- wait: waits for 1 second for the device to load or respond. - -Based on the visual information from the screenshot image and the detected bounding boxes, please determine the next action, the Box ID you should operate on (if action is one of 'type', 'hover', 'scroll_up', 'scroll_down', 'wait', there should be no Box ID field), and the value (if the action is 'type') in order to complete the task. - -Output format: -```json -{{ - "Reasoning": str, # describe what is in the current screen, taking into account the history, then describe your step-by-step thoughts on how to achieve the task, choose one action from available actions at a time. - "Next Action": "action_type, action description" | "None" # one action at a time, describe it in short and precisely. - "Box ID": n, - "value": "xxx" # only provide value field if the action is type, else don't include value key -}} -``` - -One Example: -```json -{{ - "Reasoning": "The current screen shows google result of amazon, in previous action I have searched amazon on google. Then I need to click on the first search results to go to amazon.com.", - "Next Action": "left_click", - "Box ID": m -}} -``` - -Another Example: -```json -{{ - "Reasoning": "The current screen shows the front page of amazon. There is no previous action. Therefore I need to type "Apple watch" in the search bar.", - "Next Action": "type", - "Box ID": n, - "value": "Apple watch" -}} -``` - -Another Example: -```json -{{ - "Reasoning": "The current screen does not show 'submit' button, I need to scroll down to see if the button is available.", - "Next Action": "scroll_down", -}} -``` - -IMPORTANT NOTES: -1. You should only give a single action at a time. - -""" - thinking_model = ("r1" in self.model) or ("reasoner" in self.model) - if not thinking_model: - main_section += """ -2. You should give an analysis to the current screen, and reflect on what has been done by looking at the history, then describe your step-by-step thoughts on how to achieve the task. - -""" - else: - main_section += """ -2. In XML tags give an analysis to the current screen, and reflect on what has been done by looking at the history, then describe your step-by-step thoughts on how to achieve the task. In XML tags put the next action prediction JSON. - -""" - main_section += """ -3. Attach the next action prediction in the "Next Action". -4. You should not include other actions, such as keyboard shortcuts. -5. When the task is completed, don't complete additional actions. You should say "Next Action": "None" in the json field. -6. The tasks involve buying multiple products or navigating through multiple pages. You should break it into subgoals and complete each subgoal one by one in the order of the instructions. -7. avoid choosing the same action/elements multiple times in a row, if it happens, reflect to yourself, what may have gone wrong, and predict a different action. -8. If you are prompted with login information page or captcha page, or you think it need user's permission to do the next action, you should say "Next Action": "None" in the json field. -""" - - return main_section - -def _remove_som_images(messages): - for msg in messages: - msg_content = msg["content"] - if isinstance(msg_content, list): - msg["content"] = [ - cnt for cnt in msg_content - if not (isinstance(cnt, str) and 'som' in cnt and is_image_path(cnt)) - ] - - -def _maybe_filter_to_n_most_recent_images( - messages: list[BetaMessageParam], - images_to_keep: int, - min_removal_threshold: int = 10, -): - """ - With the assumption that images are screenshots that are of diminishing value as - the conversation progresses, remove all but the final `images_to_keep` tool_result - images in place - """ - if images_to_keep is None: - return messages - - total_images = 0 - for msg in messages: - for cnt in msg.get("content", []): - if isinstance(cnt, str) and is_image_path(cnt): - total_images += 1 - elif isinstance(cnt, dict) and cnt.get("type") == "tool_result": - for content in cnt.get("content", []): - if isinstance(content, dict) and content.get("type") == "image": - total_images += 1 - - images_to_remove = total_images - images_to_keep - - for msg in messages: - msg_content = msg["content"] - if isinstance(msg_content, list): - new_content = [] - for cnt in msg_content: - # Remove images from SOM or screenshot as needed - if isinstance(cnt, str) and is_image_path(cnt): - if images_to_remove > 0: - images_to_remove -= 1 - continue - # VLM shouldn't use anthropic screenshot tool so shouldn't have these but in case it does, remove as needed - elif isinstance(cnt, dict) and cnt.get("type") == "tool_result": - new_tool_result_content = [] - for tool_result_entry in cnt.get("content", []): - if isinstance(tool_result_entry, dict) and tool_result_entry.get("type") == "image": - if images_to_remove > 0: - images_to_remove -= 1 - continue - new_tool_result_content.append(tool_result_entry) - cnt["content"] = new_tool_result_content - # Append fixed content to current message's content list - new_content.append(cnt) - msg["content"] = new_content \ No newline at end of file diff --git a/gradio_ui/app.py b/gradio_ui/app.py index ee27e4a..bc6f519 100644 --- a/gradio_ui/app.py +++ b/gradio_ui/app.py @@ -297,7 +297,8 @@ def run(): chatbot = gr.Chatbot( label="Chatbot History", autoscroll=True, - height=580 ) + height=580 + ) def update_model(model, state): state["model"] = model diff --git a/gradio_ui/loop.py b/gradio_ui/loop.py index 8e40c1e..97665ad 100644 --- a/gradio_ui/loop.py +++ b/gradio_ui/loop.py @@ -15,11 +15,7 @@ from anthropic.types.beta import ( from gradio_ui.agent.task_plan_agent import TaskPlanAgent from gradio_ui.agent.task_run_agent import TaskRunAgent from gradio_ui.tools import ToolResult -from gradio_ui.agent.llm_utils.utils import encode_image -from gradio_ui.agent.llm_utils.omniparserclient import OmniParserClient -from gradio_ui.agent.vlm_agent import VLMAgent from gradio_ui.executor.anthropic_executor import AnthropicExecutor -from pathlib import Path import numpy as np from PIL import Image diff --git a/gradio_ui/tools/screen_capture.py b/gradio_ui/tools/screen_capture.py index 5554b3d..430aee2 100644 --- a/gradio_ui/tools/screen_capture.py +++ b/gradio_ui/tools/screen_capture.py @@ -1,8 +1,7 @@ from pathlib import Path from uuid import uuid4 from PIL import Image -from .base import BaseAnthropicTool, ToolError -from io import BytesIO +from .base import ToolError from util import tool OUTPUT_DIR = "./tmp/outputs" diff --git a/requirements.txt b/requirements.txt index 107236c..1518408 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,32 +2,12 @@ torch easyocr torchvision supervision==0.18.0 -openai>=1.3.5 transformers ultralytics==8.3.70 -azure-identity numpy==1.26.4 -opencv-python -opencv-python-headless gradio -dill -accelerate -timm -einops==0.8.0 -paddlepaddle -paddleocr -ruff==0.6.7 -pre-commit==3.8.0 -pytest==8.3.3 -pytest-asyncio==0.23.6 pyautogui==0.9.54 -streamlit>=1.38.0 anthropic[bedrock,vertex]>=0.37.1 -jsonschema==4.22.0 -boto3>=1.28.57 -google-auth<3,>=2 -screeninfo -uiautomation -dashscope -groq -pyxbrain==1.1.31 \ No newline at end of file +pyxbrain==1.1.31 +timm +einops==0.8.0 \ No newline at end of file diff --git a/util/box_annotator.py b/util/box_annotator.py deleted file mode 100644 index 24d1f4c..0000000 --- a/util/box_annotator.py +++ /dev/null @@ -1,262 +0,0 @@ -# from typing import List, Optional, Union, Tuple - -# import cv2 -# import numpy as np - -# from supervision.detection.core import Detections -# from supervision.draw.color import Color, ColorPalette - - -# class BoxAnnotator: -# """ -# A class for drawing bounding boxes on an image using detections provided. - -# Attributes: -# color (Union[Color, ColorPalette]): The color to draw the bounding box, -# can be a single color or a color palette -# thickness (int): The thickness of the bounding box lines, default is 2 -# text_color (Color): The color of the text on the bounding box, default is white -# text_scale (float): The scale of the text on the bounding box, default is 0.5 -# text_thickness (int): The thickness of the text on the bounding box, -# default is 1 -# text_padding (int): The padding around the text on the bounding box, -# default is 5 - -# """ - -# def __init__( -# self, -# color: Union[Color, ColorPalette] = ColorPalette.DEFAULT, -# thickness: int = 3, # 1 for seeclick 2 for mind2web and 3 for demo -# text_color: Color = Color.BLACK, -# text_scale: float = 0.5, # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web -# text_thickness: int = 2, #1, # 2 for demo -# text_padding: int = 10, -# avoid_overlap: bool = True, -# ): -# self.color: Union[Color, ColorPalette] = color -# self.thickness: int = thickness -# self.text_color: Color = text_color -# self.text_scale: float = text_scale -# self.text_thickness: int = text_thickness -# self.text_padding: int = text_padding -# self.avoid_overlap: bool = avoid_overlap - -# def annotate( -# self, -# scene: np.ndarray, -# detections: Detections, -# labels: Optional[List[str]] = None, -# skip_label: bool = False, -# image_size: Optional[Tuple[int, int]] = None, -# ) -> np.ndarray: -# """ -# Draws bounding boxes on the frame using the detections provided. - -# Args: -# scene (np.ndarray): The image on which the bounding boxes will be drawn -# detections (Detections): The detections for which the -# bounding boxes will be drawn -# labels (Optional[List[str]]): An optional list of labels -# corresponding to each detection. If `labels` are not provided, -# corresponding `class_id` will be used as label. -# skip_label (bool): Is set to `True`, skips bounding box label annotation. -# Returns: -# np.ndarray: The image with the bounding boxes drawn on it - -# Example: -# ```python -# import supervision as sv - -# classes = ['person', ...] -# image = ... -# detections = sv.Detections(...) - -# box_annotator = sv.BoxAnnotator() -# labels = [ -# f"{classes[class_id]} {confidence:0.2f}" -# for _, _, confidence, class_id, _ in detections -# ] -# annotated_frame = box_annotator.annotate( -# scene=image.copy(), -# detections=detections, -# labels=labels -# ) -# ``` -# """ -# font = cv2.FONT_HERSHEY_SIMPLEX -# for i in range(len(detections)): -# x1, y1, x2, y2 = detections.xyxy[i].astype(int) -# class_id = ( -# detections.class_id[i] if detections.class_id is not None else None -# ) -# idx = class_id if class_id is not None else i -# color = ( -# self.color.by_idx(idx) -# if isinstance(self.color, ColorPalette) -# else self.color -# ) -# cv2.rectangle( -# img=scene, -# pt1=(x1, y1), -# pt2=(x2, y2), -# color=color.as_bgr(), -# thickness=self.thickness, -# ) -# if skip_label: -# continue - -# text = ( -# f"{class_id}" -# if (labels is None or len(detections) != len(labels)) -# else labels[i] -# ) - -# text_width, text_height = cv2.getTextSize( -# text=text, -# fontFace=font, -# fontScale=self.text_scale, -# thickness=self.text_thickness, -# )[0] - -# if not self.avoid_overlap: -# text_x = x1 + self.text_padding -# text_y = y1 - self.text_padding - -# text_background_x1 = x1 -# text_background_y1 = y1 - 2 * self.text_padding - text_height - -# text_background_x2 = x1 + 2 * self.text_padding + text_width -# text_background_y2 = y1 -# # text_x = x1 - self.text_padding - text_width -# # text_y = y1 + self.text_padding + text_height -# # text_background_x1 = x1 - 2 * self.text_padding - text_width -# # text_background_y1 = y1 -# # text_background_x2 = x1 -# # text_background_y2 = y1 + 2 * self.text_padding + text_height -# else: -# text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 = get_optimal_label_pos(self.text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size) - -# cv2.rectangle( -# img=scene, -# pt1=(text_background_x1, text_background_y1), -# pt2=(text_background_x2, text_background_y2), -# color=color.as_bgr(), -# thickness=cv2.FILLED, -# ) -# # import pdb; pdb.set_trace() -# box_color = color.as_rgb() -# luminance = 0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2] -# text_color = (0,0,0) if luminance > 160 else (255,255,255) -# cv2.putText( -# img=scene, -# text=text, -# org=(text_x, text_y), -# fontFace=font, -# fontScale=self.text_scale, -# # color=self.text_color.as_rgb(), -# color=text_color, -# thickness=self.text_thickness, -# lineType=cv2.LINE_AA, -# ) -# return scene - - -# def box_area(box): -# return (box[2] - box[0]) * (box[3] - box[1]) - -# def intersection_area(box1, box2): -# x1 = max(box1[0], box2[0]) -# y1 = max(box1[1], box2[1]) -# x2 = min(box1[2], box2[2]) -# y2 = min(box1[3], box2[3]) -# return max(0, x2 - x1) * max(0, y2 - y1) - -# def IoU(box1, box2, return_max=True): -# intersection = intersection_area(box1, box2) -# union = box_area(box1) + box_area(box2) - intersection -# if box_area(box1) > 0 and box_area(box2) > 0: -# ratio1 = intersection / box_area(box1) -# ratio2 = intersection / box_area(box2) -# else: -# ratio1, ratio2 = 0, 0 -# if return_max: -# return max(intersection / union, ratio1, ratio2) -# else: -# return intersection / union - - -# def get_optimal_label_pos(text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size): -# """ check overlap of text and background detection box, and get_optimal_label_pos, -# pos: str, position of the text, must be one of 'top left', 'top right', 'outer left', 'outer right' TODO: if all are overlapping, return the last one, i.e. outer right -# Threshold: default to 0.3 -# """ - -# def get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size): -# is_overlap = False -# for i in range(len(detections)): -# detection = detections.xyxy[i].astype(int) -# if IoU([text_background_x1, text_background_y1, text_background_x2, text_background_y2], detection) > 0.3: -# is_overlap = True -# break -# # check if the text is out of the image -# if text_background_x1 < 0 or text_background_x2 > image_size[0] or text_background_y1 < 0 or text_background_y2 > image_size[1]: -# is_overlap = True -# return is_overlap - -# # if pos == 'top left': -# text_x = x1 + text_padding -# text_y = y1 - text_padding - -# text_background_x1 = x1 -# text_background_y1 = y1 - 2 * text_padding - text_height - -# text_background_x2 = x1 + 2 * text_padding + text_width -# text_background_y2 = y1 -# is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size) -# if not is_overlap: -# return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 - -# # elif pos == 'outer left': -# text_x = x1 - text_padding - text_width -# text_y = y1 + text_padding + text_height - -# text_background_x1 = x1 - 2 * text_padding - text_width -# text_background_y1 = y1 - -# text_background_x2 = x1 -# text_background_y2 = y1 + 2 * text_padding + text_height -# is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size) -# if not is_overlap: -# return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 - - -# # elif pos == 'outer right': -# text_x = x2 + text_padding -# text_y = y1 + text_padding + text_height - -# text_background_x1 = x2 -# text_background_y1 = y1 - -# text_background_x2 = x2 + 2 * text_padding + text_width -# text_background_y2 = y1 + 2 * text_padding + text_height - -# is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size) -# if not is_overlap: -# return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 - -# # elif pos == 'top right': -# text_x = x2 - text_padding - text_width -# text_y = y1 - text_padding - -# text_background_x1 = x2 - 2 * text_padding - text_width -# text_background_y1 = y1 - 2 * text_padding - text_height - -# text_background_x2 = x2 -# text_background_y2 = y1 - -# is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size) -# if not is_overlap: -# return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 - -# return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 diff --git a/util/omniparser.py b/util/omniparser.py deleted file mode 100644 index e534597..0000000 --- a/util/omniparser.py +++ /dev/null @@ -1,32 +0,0 @@ -# from util.utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box -# import torch -# from PIL import Image -# import io -# import base64 -# from typing import Dict -# class Omniparser(object): -# def __init__(self, config: Dict): -# self.config = config -# device = 'cuda' if torch.cuda.is_available() else 'cpu' -# self.som_model = get_yolo_model(model_path=config['som_model_path']) -# self.caption_model_processor = get_caption_model_processor(model_name=config['caption_model_name'], model_name_or_path=config['caption_model_path'], device=device) -# print('Server initialized!') - -# def parse(self, image_base64: str): -# image_bytes = base64.b64decode(image_base64) -# image = Image.open(io.BytesIO(image_bytes)) -# print('image size:', image.size) - -# box_overlay_ratio = max(image.size) / 3200 -# draw_bbox_config = { -# 'text_scale': 0.8 * box_overlay_ratio, -# 'text_thickness': max(int(2 * box_overlay_ratio), 1), -# 'text_padding': max(int(3 * box_overlay_ratio), 1), -# 'thickness': max(int(3 * box_overlay_ratio), 1), -# } - -# (text, ocr_bbox), _ = check_ocr_box(image, display_img=False, output_bb_format='xyxy', easyocr_args={'text_threshold': 0.8}, use_paddleocr=False) -# dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image, self.som_model, BOX_TRESHOLD = self.config['BOX_TRESHOLD'], output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text,use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128) - -# return dino_labled_img, parsed_content_list - diff --git a/util/tool.py b/util/tool.py index 1dd973a..309f32f 100644 --- a/util/tool.py +++ b/util/tool.py @@ -2,7 +2,6 @@ import os import shlex import subprocess import threading -import traceback import pyautogui from PIL import Image from io import BytesIO diff --git a/util/utils.py b/util/utils.py deleted file mode 100644 index 5832aa5..0000000 --- a/util/utils.py +++ /dev/null @@ -1,540 +0,0 @@ -# # from ultralytics import YOLO -# import os -# import io -# import base64 -# import time -# from PIL import Image, ImageDraw, ImageFont -# import json -# import requests -# # utility function -# import os -# from openai import AzureOpenAI - -# import json -# import sys -# import os -# import cv2 -# import numpy as np -# # %matplotlib inline -# from matplotlib import pyplot as plt -# import easyocr -# from paddleocr import PaddleOCR -# reader = easyocr.Reader(['en', 'ch_sim']) -# paddle_ocr = PaddleOCR( -# lang='ch', # other lang also available -# use_angle_cls=False, -# use_gpu=False, # using cuda will conflict with pytorch in the same process -# show_log=False, -# max_batch_size=1024, -# use_dilation=True, # improves accuracy -# det_db_score_mode='slow', # improves accuracy -# rec_batch_num=1024) -# import time -# import base64 - -# import os -# import ast -# import torch -# from typing import Tuple, List, Union -# from torchvision.ops import box_convert -# import re -# from torchvision.transforms import ToPILImage -# import supervision as sv -# import torchvision.transforms as T -# from util.box_annotator import BoxAnnotator - - -# def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None): -# if not device: -# device = "cuda" if torch.cuda.is_available() else "cpu" -# if model_name == "blip2": -# from transformers import Blip2Processor, Blip2ForConditionalGeneration -# processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") -# if device == 'cpu': -# model = Blip2ForConditionalGeneration.from_pretrained( -# model_name_or_path, device_map=None, torch_dtype=torch.float32 -# ) -# else: -# model = Blip2ForConditionalGeneration.from_pretrained( -# model_name_or_path, device_map=None, torch_dtype=torch.float16 -# ).to(device) -# elif model_name == "florence2": -# from transformers import AutoProcessor, AutoModelForCausalLM -# processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True) -# if device == 'cpu': -# model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True) -# else: -# model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True).to(device) -# return {'model': model.to(device), 'processor': processor} - - -# def get_yolo_model(model_path): -# from ultralytics import YOLO -# # Load the model. -# model = YOLO(model_path) -# return model - - -# @torch.inference_mode() -# def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=128): -# # Number of samples per batch, --> 128 roughly takes 4 GB of GPU memory for florence v2 model -# to_pil = ToPILImage() -# if starting_idx: -# non_ocr_boxes = filtered_boxes[starting_idx:] -# else: -# non_ocr_boxes = filtered_boxes -# croped_pil_image = [] -# for i, coord in enumerate(non_ocr_boxes): -# try: -# xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1]) -# ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0]) -# cropped_image = image_source[ymin:ymax, xmin:xmax, :] -# cropped_image = cv2.resize(cropped_image, (64, 64)) -# croped_pil_image.append(to_pil(cropped_image)) -# except: -# continue - -# model, processor = caption_model_processor['model'], caption_model_processor['processor'] -# if not prompt: -# if 'florence' in model.config.model_type: -# prompt = "" -# else: -# prompt = "The image shows" - -# generated_texts = [] -# device = model.device -# for i in range(0, len(croped_pil_image), batch_size): -# start = time.time() -# batch = croped_pil_image[i:i+batch_size] -# t1 = time.time() -# if model.device.type == 'cuda': -# inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16) -# else: -# inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device) -# if 'florence' in model.config.model_type: -# generated_ids = model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=20,num_beams=1, do_sample=False) -# else: -# generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1) # temperature=0.01, do_sample=True, -# generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True) -# generated_text = [gen.strip() for gen in generated_text] -# generated_texts.extend(generated_text) - -# return generated_texts - - - -# def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor): -# to_pil = ToPILImage() -# if ocr_bbox: -# non_ocr_boxes = filtered_boxes[len(ocr_bbox):] -# else: -# non_ocr_boxes = filtered_boxes -# croped_pil_image = [] -# for i, coord in enumerate(non_ocr_boxes): -# xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1]) -# ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0]) -# cropped_image = image_source[ymin:ymax, xmin:xmax, :] -# croped_pil_image.append(to_pil(cropped_image)) - -# model, processor = caption_model_processor['model'], caption_model_processor['processor'] -# device = model.device -# messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}] -# prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - -# batch_size = 5 # Number of samples per batch -# generated_texts = [] - -# for i in range(0, len(croped_pil_image), batch_size): -# images = croped_pil_image[i:i+batch_size] -# image_inputs = [processor.image_processor(x, return_tensors="pt") for x in images] -# inputs ={'input_ids': [], 'attention_mask': [], 'pixel_values': [], 'image_sizes': []} -# texts = [prompt] * len(images) -# for i, txt in enumerate(texts): -# input = processor._convert_images_texts_to_inputs(image_inputs[i], txt, return_tensors="pt") -# inputs['input_ids'].append(input['input_ids']) -# inputs['attention_mask'].append(input['attention_mask']) -# inputs['pixel_values'].append(input['pixel_values']) -# inputs['image_sizes'].append(input['image_sizes']) -# max_len = max([x.shape[1] for x in inputs['input_ids']]) -# for i, v in enumerate(inputs['input_ids']): -# inputs['input_ids'][i] = torch.cat([processor.tokenizer.pad_token_id * torch.ones(1, max_len - v.shape[1], dtype=torch.long), v], dim=1) -# inputs['attention_mask'][i] = torch.cat([torch.zeros(1, max_len - v.shape[1], dtype=torch.long), inputs['attention_mask'][i]], dim=1) -# inputs_cat = {k: torch.concatenate(v).to(device) for k, v in inputs.items()} - -# generation_args = { -# "max_new_tokens": 25, -# "temperature": 0.01, -# "do_sample": False, -# } -# generate_ids = model.generate(**inputs_cat, eos_token_id=processor.tokenizer.eos_token_id, **generation_args) -# # # remove input tokens -# generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:] -# response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) -# response = [res.strip('\n').strip() for res in response] -# generated_texts.extend(response) - -# return generated_texts - -# def remove_overlap(boxes, iou_threshold, ocr_bbox=None): -# assert ocr_bbox is None or isinstance(ocr_bbox, List) - -# def box_area(box): -# return (box[2] - box[0]) * (box[3] - box[1]) - -# def intersection_area(box1, box2): -# x1 = max(box1[0], box2[0]) -# y1 = max(box1[1], box2[1]) -# x2 = min(box1[2], box2[2]) -# y2 = min(box1[3], box2[3]) -# return max(0, x2 - x1) * max(0, y2 - y1) - -# def IoU(box1, box2): -# intersection = intersection_area(box1, box2) -# union = box_area(box1) + box_area(box2) - intersection + 1e-6 -# if box_area(box1) > 0 and box_area(box2) > 0: -# ratio1 = intersection / box_area(box1) -# ratio2 = intersection / box_area(box2) -# else: -# ratio1, ratio2 = 0, 0 -# return max(intersection / union, ratio1, ratio2) - -# def is_inside(box1, box2): -# # return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3] -# intersection = intersection_area(box1, box2) -# ratio1 = intersection / box_area(box1) -# return ratio1 > 0.95 - -# boxes = boxes.tolist() -# filtered_boxes = [] -# if ocr_bbox: -# filtered_boxes.extend(ocr_bbox) -# # print('ocr_bbox!!!', ocr_bbox) -# for i, box1 in enumerate(boxes): -# # if not any(IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2) for j, box2 in enumerate(boxes) if i != j): -# is_valid_box = True -# for j, box2 in enumerate(boxes): -# # keep the smaller box -# if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2): -# is_valid_box = False -# break -# if is_valid_box: -# # add the following 2 lines to include ocr bbox -# if ocr_bbox: -# # only add the box if it does not overlap with any ocr bbox -# if not any(IoU(box1, box3) > iou_threshold and not is_inside(box1, box3) for k, box3 in enumerate(ocr_bbox)): -# filtered_boxes.append(box1) -# else: -# filtered_boxes.append(box1) -# return torch.tensor(filtered_boxes) - - -# def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None): -# ''' -# ocr_bbox format: [{'type': 'text', 'bbox':[x,y], 'interactivity':False, 'content':str }, ...] -# boxes format: [{'type': 'icon', 'bbox':[x,y], 'interactivity':True, 'content':None }, ...] - -# ''' -# assert ocr_bbox is None or isinstance(ocr_bbox, List) - -# def box_area(box): -# return (box[2] - box[0]) * (box[3] - box[1]) - -# def intersection_area(box1, box2): -# x1 = max(box1[0], box2[0]) -# y1 = max(box1[1], box2[1]) -# x2 = min(box1[2], box2[2]) -# y2 = min(box1[3], box2[3]) -# return max(0, x2 - x1) * max(0, y2 - y1) - -# def IoU(box1, box2): -# intersection = intersection_area(box1, box2) -# union = box_area(box1) + box_area(box2) - intersection + 1e-6 -# if box_area(box1) > 0 and box_area(box2) > 0: -# ratio1 = intersection / box_area(box1) -# ratio2 = intersection / box_area(box2) -# else: -# ratio1, ratio2 = 0, 0 -# return max(intersection / union, ratio1, ratio2) - -# def is_inside(box1, box2): -# # return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3] -# intersection = intersection_area(box1, box2) -# ratio1 = intersection / box_area(box1) -# return ratio1 > 0.80 - -# # boxes = boxes.tolist() -# filtered_boxes = [] -# if ocr_bbox: -# filtered_boxes.extend(ocr_bbox) -# # print('ocr_bbox!!!', ocr_bbox) -# for i, box1_elem in enumerate(boxes): -# box1 = box1_elem['bbox'] -# is_valid_box = True -# for j, box2_elem in enumerate(boxes): -# # keep the smaller box -# box2 = box2_elem['bbox'] -# if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2): -# is_valid_box = False -# break -# if is_valid_box: -# if ocr_bbox: -# # keep yolo boxes + prioritize ocr label -# box_added = False -# ocr_labels = '' -# for box3_elem in ocr_bbox: -# if not box_added: -# box3 = box3_elem['bbox'] -# if is_inside(box3, box1): # ocr inside icon -# # box_added = True -# # delete the box3_elem from ocr_bbox -# try: -# # gather all ocr labels -# ocr_labels += box3_elem['content'] + ' ' -# filtered_boxes.remove(box3_elem) -# except: -# continue -# # break -# elif is_inside(box1, box3): # icon inside ocr, don't added this icon box, no need to check other ocr bbox bc no overlap between ocr bbox, icon can only be in one ocr box -# box_added = True -# break -# else: -# continue -# if not box_added: -# if ocr_labels: -# filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': ocr_labels, 'source':'box_yolo_content_ocr'}) -# else: -# filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None, 'source':'box_yolo_content_yolo'}) -# else: -# filtered_boxes.append(box1) -# return filtered_boxes # torch.tensor(filtered_boxes) - - -# def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]: -# transform = T.Compose( -# [ -# T.RandomResize([800], max_size=1333), -# T.ToTensor(), -# T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), -# ] -# ) -# image_source = Image.open(image_path).convert("RGB") -# image = np.asarray(image_source) -# image_transformed, _ = transform(image_source, None) -# return image, image_transformed - - -# def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str], text_scale: float, -# text_padding=5, text_thickness=2, thickness=3) -> np.ndarray: -# """ -# This function annotates an image with bounding boxes and labels. - -# Parameters: -# image_source (np.ndarray): The source image to be annotated. -# boxes (torch.Tensor): A tensor containing bounding box coordinates. in cxcywh format, pixel scale -# logits (torch.Tensor): A tensor containing confidence scores for each bounding box. -# phrases (List[str]): A list of labels for each bounding box. -# text_scale (float): The scale of the text to be displayed. 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web - -# Returns: -# np.ndarray: The annotated image. -# """ -# h, w, _ = image_source.shape -# boxes = boxes * torch.Tensor([w, h, w, h]) -# xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy() -# xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy() -# detections = sv.Detections(xyxy=xyxy) - -# labels = [f"{phrase}" for phrase in range(boxes.shape[0])] - -# box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,text_thickness=text_thickness,thickness=thickness) # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web -# annotated_frame = image_source.copy() -# annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w,h)) - -# label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)} -# return annotated_frame, label_coordinates - - -# def predict(model, image, caption, box_threshold, text_threshold): -# """ Use huggingface model to replace the original model -# """ -# model, processor = model['model'], model['processor'] -# device = model.device - -# inputs = processor(images=image, text=caption, return_tensors="pt").to(device) -# with torch.no_grad(): -# outputs = model(**inputs) - -# results = processor.post_process_grounded_object_detection( -# outputs, -# inputs.input_ids, -# box_threshold=box_threshold, # 0.4, -# text_threshold=text_threshold, # 0.3, -# target_sizes=[image.size[::-1]] -# )[0] -# boxes, logits, phrases = results["boxes"], results["scores"], results["labels"] -# return boxes, logits, phrases - - -# def predict_yolo(model, image, box_threshold, imgsz, scale_img, iou_threshold=0.7): -# """ Use huggingface model to replace the original model -# """ -# # model = model['model'] -# if scale_img: -# result = model.predict( -# source=image, -# conf=box_threshold, -# imgsz=imgsz, -# iou=iou_threshold, # default 0.7 -# ) -# else: -# result = model.predict( -# source=image, -# conf=box_threshold, -# iou=iou_threshold, # default 0.7 -# ) -# boxes = result[0].boxes.xyxy#.tolist() # in pixel space -# conf = result[0].boxes.conf -# phrases = [str(i) for i in range(len(boxes))] - -# return boxes, conf, phrases - -# def int_box_area(box, w, h): -# x1, y1, x2, y2 = box -# int_box = [int(x1*w), int(y1*h), int(x2*w), int(y2*h)] -# area = (int_box[2] - int_box[0]) * (int_box[3] - int_box[1]) -# return area - -# def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=128): -# """Process either an image path or Image object - -# Args: -# image_source: Either a file path (str) or PIL Image object -# ... -# """ -# if isinstance(image_source, str): -# image_source = Image.open(image_source) -# image_source = image_source.convert("RGB") # for CLIP -# w, h = image_source.size -# if not imgsz: -# imgsz = (h, w) -# # print('image size:', w, h) -# xyxy, logits, phrases = predict_yolo(model=model, image=image_source, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1) -# xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device) -# image_source = np.asarray(image_source) -# phrases = [str(i) for i in range(len(phrases))] - -# # annotate the image with labels -# if ocr_bbox: -# ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h]) -# ocr_bbox=ocr_bbox.tolist() -# else: -# print('no ocr bbox!!!') -# ocr_bbox = None - -# ocr_bbox_elem = [{'type': 'text', 'bbox':box, 'interactivity':False, 'content':txt, 'source': 'box_ocr_content_ocr'} for box, txt in zip(ocr_bbox, ocr_text) if int_box_area(box, w, h) > 0] -# xyxy_elem = [{'type': 'icon', 'bbox':box, 'interactivity':True, 'content':None} for box in xyxy.tolist() if int_box_area(box, w, h) > 0] -# filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem) - -# # sort the filtered_boxes so that the one with 'content': None is at the end, and get the index of the first 'content': None -# filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None) -# # get the index of the first 'content': None -# starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1) -# filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem]) -# print('len(filtered_boxes):', len(filtered_boxes), starting_idx) - -# # get parsed icon local semantics -# time1 = time.time() -# if use_local_semantics: -# caption_model = caption_model_processor['model'] -# if 'phi3_v' in caption_model.config.model_type: -# parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor) -# else: -# parsed_content_icon = get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=prompt,batch_size=batch_size) -# ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)] -# icon_start = len(ocr_text) -# parsed_content_icon_ls = [] -# # fill the filtered_boxes_elem None content with parsed_content_icon in order -# for i, box in enumerate(filtered_boxes_elem): -# if box['content'] is None: -# box['content'] = parsed_content_icon.pop(0) -# for i, txt in enumerate(parsed_content_icon): -# parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}") -# parsed_content_merged = ocr_text + parsed_content_icon_ls -# else: -# ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)] -# parsed_content_merged = ocr_text -# print('time to get parsed content:', time.time()-time1) - -# filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh") - -# phrases = [i for i in range(len(filtered_boxes))] - -# # draw boxes -# if draw_bbox_config: -# annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config) -# else: -# annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, text_scale=text_scale, text_padding=text_padding) - -# pil_img = Image.fromarray(annotated_frame) -# buffered = io.BytesIO() -# pil_img.save(buffered, format="PNG") -# encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii') -# if output_coord_in_ratio: -# label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()} -# assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0] - -# return encoded_image, label_coordinates, filtered_boxes_elem - - -# def get_xywh(input): -# x, y, w, h = input[0][0], input[0][1], input[2][0] - input[0][0], input[2][1] - input[0][1] -# x, y, w, h = int(x), int(y), int(w), int(h) -# return x, y, w, h - -# def get_xyxy(input): -# x, y, xp, yp = input[0][0], input[0][1], input[2][0], input[2][1] -# x, y, xp, yp = int(x), int(y), int(xp), int(yp) -# return x, y, xp, yp - -# def get_xywh_yolo(input): -# x, y, w, h = input[0], input[1], input[2] - input[0], input[3] - input[1] -# x, y, w, h = int(x), int(y), int(w), int(h) -# return x, y, w, h - -# def check_ocr_box(image_source: Union[str, Image.Image], display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False): -# if isinstance(image_source, str): -# image_source = Image.open(image_source) -# if image_source.mode == 'RGBA': -# # Convert RGBA to RGB to avoid alpha channel issues -# image_source = image_source.convert('RGB') -# image_np = np.array(image_source) -# w, h = image_source.size -# if use_paddleocr: -# if easyocr_args is None: -# text_threshold = 0.5 -# else: -# text_threshold = easyocr_args['text_threshold'] -# result = paddle_ocr.ocr(image_np, cls=False)[0] -# coord = [item[0] for item in result if item[1][1] > text_threshold] -# text = [item[1][0] for item in result if item[1][1] > text_threshold] -# else: # EasyOCR -# if easyocr_args is None: -# easyocr_args = {} -# result = reader.readtext(image_np, **easyocr_args) -# coord = [item[0] for item in result] -# text = [item[1] for item in result] -# if display_img: -# opencv_img = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR) -# bb = [] -# for item in coord: -# x, y, a, b = get_xywh(item) -# bb.append((x, y, a, b)) -# cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2) -# # matplotlib expects RGB -# plt.imshow(cv2.cvtColor(opencv_img, cv2.COLOR_BGR2RGB)) -# else: -# if output_bb_format == 'xywh': -# bb = [get_xywh(item) for item in coord] -# elif output_bb_format == 'xyxy': -# bb = [get_xyxy(item) for item in coord] -# return (text, bb), goal_filtering \ No newline at end of file