mirror of
https://github.com/yuruotong1/autoMate.git
synced 2026-03-22 13:07:17 +08:00
更新自动 florence 模型
This commit is contained in:
@@ -18,7 +18,7 @@ class UIElement(BaseModel):
|
||||
text: Optional[str] = None
|
||||
|
||||
class VisionAgent:
|
||||
def __init__(self, yolo_model_path: str, caption_model_path: str = 'microsoft/Florence-2-base-ft'):
|
||||
def __init__(self, yolo_model_path: str, caption_model_path: str, florence_model_path: str):
|
||||
"""
|
||||
Initialize the vision agent
|
||||
|
||||
@@ -33,8 +33,9 @@ class VisionAgent:
|
||||
|
||||
# load the image caption model and processor
|
||||
self.caption_processor = AutoProcessor.from_pretrained(
|
||||
"microsoft/Florence-2-base",
|
||||
trust_remote_code=True
|
||||
"microsoft/Florence-2-base-ft",
|
||||
cache_dir=florence_model_path,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
# load the model according to the device type
|
||||
|
||||
@@ -14,7 +14,7 @@ from gradio_ui.loop import (
|
||||
import base64
|
||||
from xbrain.utils.config import Config
|
||||
|
||||
from util.download_weights import MODEL_DIR
|
||||
from util.download_weights import FLORENCE_MODEL_DIR, OMNI_PARSER_MODEL_DIR
|
||||
CONFIG_DIR = Path("~/.anthropic").expanduser()
|
||||
API_KEY_FILE = CONFIG_DIR / "api_key"
|
||||
|
||||
@@ -317,8 +317,10 @@ def run():
|
||||
model.change(fn=update_model, inputs=[model, state], outputs=None)
|
||||
api_key.change(fn=update_api_key, inputs=[api_key, state], outputs=None)
|
||||
chatbot.clear(fn=clear_chat, inputs=[state], outputs=[chatbot])
|
||||
vision_agent = VisionAgent(yolo_model_path=os.path.join(MODEL_DIR, "icon_detect", "model.pt"),
|
||||
caption_model_path=os.path.join(MODEL_DIR, "icon_caption"))
|
||||
vision_agent = VisionAgent(yolo_model_path=os.path.join(OMNI_PARSER_MODEL_DIR, "icon_detect", "model.pt"),
|
||||
caption_model_path=os.path.join(OMNI_PARSER_MODEL_DIR, "icon_caption"),
|
||||
florence_model_path=os.path.join(FLORENCE_MODEL_DIR)
|
||||
)
|
||||
vision_agent_state = gr.State({"agent": vision_agent})
|
||||
submit_button.click(process_input, [chat_input, state, vision_agent_state], [chatbot, task_list])
|
||||
stop_button.click(stop_app, [state], None)
|
||||
|
||||
13
main.py
13
main.py
@@ -2,16 +2,11 @@ from gradio_ui import app
|
||||
from util import download_weights
|
||||
import torch
|
||||
def run():
|
||||
try:
|
||||
print("cuda is_available: ", torch.cuda.is_available())
|
||||
print("MPS is_available: ", torch.backends.mps.is_available())
|
||||
print("cuda device_count", torch.cuda.device_count())
|
||||
print("cuda device_name", torch.cuda.get_device_name(0))
|
||||
except Exception:
|
||||
print("GPU driver is not compatible, please install the appropriate version of torch according to the readme!")
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
print("Warning: GPU is not available, we will use CPU, the application may run slower!\nyou computer will very likely heat up!")
|
||||
print("Downloading the weight files...")
|
||||
# download the weight files
|
||||
download_weights.download()
|
||||
download_weights.download_models()
|
||||
app.run()
|
||||
|
||||
|
||||
|
||||
@@ -2,40 +2,27 @@ import os
|
||||
from pathlib import Path
|
||||
from modelscope import snapshot_download
|
||||
__WEIGHTS_DIR = Path("weights")
|
||||
MODEL_DIR = os.path.join(__WEIGHTS_DIR, "AI-ModelScope", "OmniParser-v2___0")
|
||||
def download():
|
||||
OMNI_PARSER_MODEL_DIR = os.path.join(__WEIGHTS_DIR, "AI-ModelScope", "OmniParser-v2___0")
|
||||
FLORENCE_MODEL_DIR = os.path.join(__WEIGHTS_DIR, "AI-ModelScope", "Florence-2-base-ft")
|
||||
|
||||
def __download_omni_parser():
|
||||
# Create weights directory
|
||||
|
||||
__WEIGHTS_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# List of files to download
|
||||
files = [
|
||||
"icon_detect/train_args.yaml",
|
||||
"icon_detect/model.pt",
|
||||
"icon_detect/model.yaml",
|
||||
"icon_caption/config.json",
|
||||
"icon_caption/generation_config.json",
|
||||
"icon_caption/model.safetensors"
|
||||
]
|
||||
|
||||
# Check and download missing files
|
||||
missing_files = []
|
||||
for file in files:
|
||||
file_path = os.path.join(MODEL_DIR, file)
|
||||
if not os.path.exists(file_path):
|
||||
missing_files.append(file)
|
||||
break
|
||||
|
||||
if not missing_files:
|
||||
print("Model files already detected!")
|
||||
return
|
||||
|
||||
snapshot_download(
|
||||
'AI-ModelScope/OmniParser-v2.0',
|
||||
cache_dir='weights'
|
||||
)
|
||||
|
||||
print("Download complete")
|
||||
def __download_florence_model():
|
||||
snapshot_download('AI-ModelScope/Florence-2-base-ft',
|
||||
cache_dir='weights'
|
||||
)
|
||||
|
||||
def download_models():
|
||||
__download_omni_parser()
|
||||
__download_florence_model()
|
||||
|
||||
if __name__ == "__main__":
|
||||
download()
|
||||
download_models()
|
||||
Reference in New Issue
Block a user