From deeb1aa982fc5553dab5fbeba53023be50b30794 Mon Sep 17 00:00:00 2001 From: ruotongyu Date: Mon, 3 Mar 2025 20:54:23 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E6=9E=B6=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 +- gradio_ui/agent/anthropic_agent.py | 2 - gradio_ui/app.py | 2 - gradio_ui/loop.py | 3 - server.py | 14 +- util/box_annotator.py | 262 +++++++++++++++++++++++++++++ 6 files changed, 271 insertions(+), 15 deletions(-) create mode 100644 util/box_annotator.py diff --git a/.gitignore b/.gitignore index c1ed458..2594c8c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ **/__pycache__** -weights** \ No newline at end of file +weights** +.conda** \ No newline at end of file diff --git a/gradio_ui/agent/anthropic_agent.py b/gradio_ui/agent/anthropic_agent.py index ddca387..654836e 100644 --- a/gradio_ui/agent/anthropic_agent.py +++ b/gradio_ui/agent/anthropic_agent.py @@ -48,7 +48,6 @@ class AnthropicActor: def __init__( self, model: str, - provider: APIProvider, api_key: str, api_response_callback: Callable[[APIResponse[BetaMessage]], None], max_tokens: int = 4096, @@ -56,7 +55,6 @@ class AnthropicActor: print_usage: bool = True, ): self.model = model - self.provider = provider self.api_key = api_key self.api_response_callback = api_response_callback self.max_tokens = max_tokens diff --git a/gradio_ui/app.py b/gradio_ui/app.py index 11bfa49..d5b37c6 100644 --- a/gradio_ui/app.py +++ b/gradio_ui/app.py @@ -224,7 +224,6 @@ def process_input(user_input, state): # Run sampling_loop_sync with the chatbot_output_callback for loop_msg in sampling_loop_sync( model=state["model"], - provider=state["provider"], messages=state["messages"], output_callback=partial(chatbot_output_callback, chatbot_state=state['chatbot_messages'], hide_images=False), tool_output_callback=partial(_tool_output_callback, tool_state=state["tools"]), @@ -349,7 +348,6 @@ def run(): def update_api_key(api_key_value, state): state["api_key"] = api_key_value - state[f'{state["provider"]}_api_key'] = api_key_value def clear_chat(state): # Reset message-related state diff --git a/gradio_ui/loop.py b/gradio_ui/loop.py index cb93b57..fccf999 100644 --- a/gradio_ui/loop.py +++ b/gradio_ui/loop.py @@ -39,7 +39,6 @@ PROVIDER_TO_DEFAULT_MODEL_NAME: dict[APIProvider, str] = { def sampling_loop_sync( *, model: str, - provider: APIProvider | None, messages: list[BetaMessageParam], output_callback: Callable[[BetaContentBlock], None], tool_output_callback: Callable[[ToolResult, str], None], @@ -58,7 +57,6 @@ def sampling_loop_sync( # Register Actor and Executor actor = AnthropicActor( model=model, - provider=provider, api_key=api_key, api_response_callback=api_response_callback, max_tokens=max_tokens, @@ -67,7 +65,6 @@ def sampling_loop_sync( elif model in set(["omniparser + gpt-4o", "omniparser + o1", "omniparser + o3-mini", "omniparser + R1", "omniparser + qwen2.5vl"]): actor = VLMAgent( model=model, - provider=provider, api_key=api_key, api_response_callback=api_response_callback, output_callback=output_callback, diff --git a/server.py b/server.py index 49fb306..7f156e0 100644 --- a/server.py +++ b/server.py @@ -1,5 +1,5 @@ ''' -python -m omniparserserver --som_model_path ../../weights/icon_detect/model.pt --caption_model_name florence2 --caption_model_path ../../weights/icon_caption_florence --device cuda --BOX_TRESHOLD 0.05 +python -m server --som_model_path ../../weights/icon_detect/model.pt --caption_model_name florence2 --caption_model_path ../../weights/icon_caption_florence --device cuda --BOX_TRESHOLD 0.05 ''' import sys @@ -14,14 +14,14 @@ sys.path.append(root_dir) from util.omniparser import Omniparser def parse_arguments(): - parser = argparse.ArgumentParser(description='Omniparser API') - parser.add_argument('--som_model_path', type=str, default='../../weights/icon_detect/model.pt', help='Path to the som model') + parser = argparse.ArgumentParser(description='autoMate API') + parser.add_argument('--som_model_path', type=str, default='./weights/icon_detect/model.pt', help='Path to the som model') parser.add_argument('--caption_model_name', type=str, default='florence2', help='Name of the caption model') - parser.add_argument('--caption_model_path', type=str, default='../../weights/icon_caption_florence', help='Path to the caption model') + parser.add_argument('--caption_model_path', type=str, default='./weights/icon_caption_florence', help='Path to the caption model') parser.add_argument('--device', type=str, default='cpu', help='Device to run the model') parser.add_argument('--BOX_TRESHOLD', type=float, default=0.05, help='Threshold for box detection') parser.add_argument('--host', type=str, default='0.0.0.0', help='Host for the API') - parser.add_argument('--port', type=int, default=8000, help='Port for the API') + parser.add_argument('--port', type=int, default=5000, help='Port for the API') args = parser.parse_args() return args @@ -45,7 +45,7 @@ async def parse(parse_request: ParseRequest): @app.get("/probe/") async def root(): - return {"message": "Omniparser API ready"} + return {"message": "API ready"} if __name__ == "__main__": - uvicorn.run("omniparserserver:app", host=args.host, port=args.port, reload=True) \ No newline at end of file + uvicorn.run("server:app", host=args.host, port=args.port, reload=True) \ No newline at end of file diff --git a/util/box_annotator.py b/util/box_annotator.py new file mode 100644 index 0000000..82f7116 --- /dev/null +++ b/util/box_annotator.py @@ -0,0 +1,262 @@ +from typing import List, Optional, Union, Tuple + +import cv2 +import numpy as np + +from supervision.detection.core import Detections +from supervision.draw.color import Color, ColorPalette + + +class BoxAnnotator: + """ + A class for drawing bounding boxes on an image using detections provided. + + Attributes: + color (Union[Color, ColorPalette]): The color to draw the bounding box, + can be a single color or a color palette + thickness (int): The thickness of the bounding box lines, default is 2 + text_color (Color): The color of the text on the bounding box, default is white + text_scale (float): The scale of the text on the bounding box, default is 0.5 + text_thickness (int): The thickness of the text on the bounding box, + default is 1 + text_padding (int): The padding around the text on the bounding box, + default is 5 + + """ + + def __init__( + self, + color: Union[Color, ColorPalette] = ColorPalette.DEFAULT, + thickness: int = 3, # 1 for seeclick 2 for mind2web and 3 for demo + text_color: Color = Color.BLACK, + text_scale: float = 0.5, # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web + text_thickness: int = 2, #1, # 2 for demo + text_padding: int = 10, + avoid_overlap: bool = True, + ): + self.color: Union[Color, ColorPalette] = color + self.thickness: int = thickness + self.text_color: Color = text_color + self.text_scale: float = text_scale + self.text_thickness: int = text_thickness + self.text_padding: int = text_padding + self.avoid_overlap: bool = avoid_overlap + + def annotate( + self, + scene: np.ndarray, + detections: Detections, + labels: Optional[List[str]] = None, + skip_label: bool = False, + image_size: Optional[Tuple[int, int]] = None, + ) -> np.ndarray: + """ + Draws bounding boxes on the frame using the detections provided. + + Args: + scene (np.ndarray): The image on which the bounding boxes will be drawn + detections (Detections): The detections for which the + bounding boxes will be drawn + labels (Optional[List[str]]): An optional list of labels + corresponding to each detection. If `labels` are not provided, + corresponding `class_id` will be used as label. + skip_label (bool): Is set to `True`, skips bounding box label annotation. + Returns: + np.ndarray: The image with the bounding boxes drawn on it + + Example: + ```python + import supervision as sv + + classes = ['person', ...] + image = ... + detections = sv.Detections(...) + + box_annotator = sv.BoxAnnotator() + labels = [ + f"{classes[class_id]} {confidence:0.2f}" + for _, _, confidence, class_id, _ in detections + ] + annotated_frame = box_annotator.annotate( + scene=image.copy(), + detections=detections, + labels=labels + ) + ``` + """ + font = cv2.FONT_HERSHEY_SIMPLEX + for i in range(len(detections)): + x1, y1, x2, y2 = detections.xyxy[i].astype(int) + class_id = ( + detections.class_id[i] if detections.class_id is not None else None + ) + idx = class_id if class_id is not None else i + color = ( + self.color.by_idx(idx) + if isinstance(self.color, ColorPalette) + else self.color + ) + cv2.rectangle( + img=scene, + pt1=(x1, y1), + pt2=(x2, y2), + color=color.as_bgr(), + thickness=self.thickness, + ) + if skip_label: + continue + + text = ( + f"{class_id}" + if (labels is None or len(detections) != len(labels)) + else labels[i] + ) + + text_width, text_height = cv2.getTextSize( + text=text, + fontFace=font, + fontScale=self.text_scale, + thickness=self.text_thickness, + )[0] + + if not self.avoid_overlap: + text_x = x1 + self.text_padding + text_y = y1 - self.text_padding + + text_background_x1 = x1 + text_background_y1 = y1 - 2 * self.text_padding - text_height + + text_background_x2 = x1 + 2 * self.text_padding + text_width + text_background_y2 = y1 + # text_x = x1 - self.text_padding - text_width + # text_y = y1 + self.text_padding + text_height + # text_background_x1 = x1 - 2 * self.text_padding - text_width + # text_background_y1 = y1 + # text_background_x2 = x1 + # text_background_y2 = y1 + 2 * self.text_padding + text_height + else: + text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 = get_optimal_label_pos(self.text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size) + + cv2.rectangle( + img=scene, + pt1=(text_background_x1, text_background_y1), + pt2=(text_background_x2, text_background_y2), + color=color.as_bgr(), + thickness=cv2.FILLED, + ) + # import pdb; pdb.set_trace() + box_color = color.as_rgb() + luminance = 0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2] + text_color = (0,0,0) if luminance > 160 else (255,255,255) + cv2.putText( + img=scene, + text=text, + org=(text_x, text_y), + fontFace=font, + fontScale=self.text_scale, + # color=self.text_color.as_rgb(), + color=text_color, + thickness=self.text_thickness, + lineType=cv2.LINE_AA, + ) + return scene + + +def box_area(box): + return (box[2] - box[0]) * (box[3] - box[1]) + +def intersection_area(box1, box2): + x1 = max(box1[0], box2[0]) + y1 = max(box1[1], box2[1]) + x2 = min(box1[2], box2[2]) + y2 = min(box1[3], box2[3]) + return max(0, x2 - x1) * max(0, y2 - y1) + +def IoU(box1, box2, return_max=True): + intersection = intersection_area(box1, box2) + union = box_area(box1) + box_area(box2) - intersection + if box_area(box1) > 0 and box_area(box2) > 0: + ratio1 = intersection / box_area(box1) + ratio2 = intersection / box_area(box2) + else: + ratio1, ratio2 = 0, 0 + if return_max: + return max(intersection / union, ratio1, ratio2) + else: + return intersection / union + + +def get_optimal_label_pos(text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size): + """ check overlap of text and background detection box, and get_optimal_label_pos, + pos: str, position of the text, must be one of 'top left', 'top right', 'outer left', 'outer right' TODO: if all are overlapping, return the last one, i.e. outer right + Threshold: default to 0.3 + """ + + def get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size): + is_overlap = False + for i in range(len(detections)): + detection = detections.xyxy[i].astype(int) + if IoU([text_background_x1, text_background_y1, text_background_x2, text_background_y2], detection) > 0.3: + is_overlap = True + break + # check if the text is out of the image + if text_background_x1 < 0 or text_background_x2 > image_size[0] or text_background_y1 < 0 or text_background_y2 > image_size[1]: + is_overlap = True + return is_overlap + + # if pos == 'top left': + text_x = x1 + text_padding + text_y = y1 - text_padding + + text_background_x1 = x1 + text_background_y1 = y1 - 2 * text_padding - text_height + + text_background_x2 = x1 + 2 * text_padding + text_width + text_background_y2 = y1 + is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size) + if not is_overlap: + return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 + + # elif pos == 'outer left': + text_x = x1 - text_padding - text_width + text_y = y1 + text_padding + text_height + + text_background_x1 = x1 - 2 * text_padding - text_width + text_background_y1 = y1 + + text_background_x2 = x1 + text_background_y2 = y1 + 2 * text_padding + text_height + is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size) + if not is_overlap: + return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 + + + # elif pos == 'outer right': + text_x = x2 + text_padding + text_y = y1 + text_padding + text_height + + text_background_x1 = x2 + text_background_y1 = y1 + + text_background_x2 = x2 + 2 * text_padding + text_width + text_background_y2 = y1 + 2 * text_padding + text_height + + is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size) + if not is_overlap: + return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 + + # elif pos == 'top right': + text_x = x2 - text_padding - text_width + text_y = y1 - text_padding + + text_background_x1 = x2 - 2 * text_padding - text_width + text_background_y1 = y1 - 2 * text_padding - text_height + + text_background_x2 = x2 + text_background_y2 = y1 + + is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size) + if not is_overlap: + return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 + + return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2