更新自动 florence 模型

This commit is contained in:
yuruo
2025-03-17 17:31:48 +08:00
parent af78cfc4ee
commit cdc9b403a6
4 changed files with 26 additions and 41 deletions

View File

@@ -18,7 +18,7 @@ class UIElement(BaseModel):
text: Optional[str] = None
class VisionAgent:
def __init__(self, yolo_model_path: str, caption_model_path: str = 'microsoft/Florence-2-base-ft'):
def __init__(self, yolo_model_path: str, caption_model_path: str, florence_model_path: str):
"""
Initialize the vision agent
@@ -33,8 +33,9 @@ class VisionAgent:
# load the image caption model and processor
self.caption_processor = AutoProcessor.from_pretrained(
"microsoft/Florence-2-base",
trust_remote_code=True
"microsoft/Florence-2-base-ft",
cache_dir=florence_model_path,
trust_remote_code=True,
)
# load the model according to the device type

View File

@@ -14,7 +14,7 @@ from gradio_ui.loop import (
import base64
from xbrain.utils.config import Config
from util.download_weights import MODEL_DIR
from util.download_weights import FLORENCE_MODEL_DIR, OMNI_PARSER_MODEL_DIR
CONFIG_DIR = Path("~/.anthropic").expanduser()
API_KEY_FILE = CONFIG_DIR / "api_key"
@@ -317,8 +317,10 @@ def run():
model.change(fn=update_model, inputs=[model, state], outputs=None)
api_key.change(fn=update_api_key, inputs=[api_key, state], outputs=None)
chatbot.clear(fn=clear_chat, inputs=[state], outputs=[chatbot])
vision_agent = VisionAgent(yolo_model_path=os.path.join(MODEL_DIR, "icon_detect", "model.pt"),
caption_model_path=os.path.join(MODEL_DIR, "icon_caption"))
vision_agent = VisionAgent(yolo_model_path=os.path.join(OMNI_PARSER_MODEL_DIR, "icon_detect", "model.pt"),
caption_model_path=os.path.join(OMNI_PARSER_MODEL_DIR, "icon_caption"),
florence_model_path=os.path.join(FLORENCE_MODEL_DIR)
)
vision_agent_state = gr.State({"agent": vision_agent})
submit_button.click(process_input, [chat_input, state, vision_agent_state], [chatbot, task_list])
stop_button.click(stop_app, [state], None)

13
main.py
View File

@@ -2,16 +2,11 @@ from gradio_ui import app
from util import download_weights
import torch
def run():
try:
print("cuda is_available: ", torch.cuda.is_available())
print("MPS is_available: ", torch.backends.mps.is_available())
print("cuda device_count", torch.cuda.device_count())
print("cuda device_name", torch.cuda.get_device_name(0))
except Exception:
print("GPU driver is not compatible, please install the appropriate version of torch according to the readme!")
if not torch.cuda.is_available():
print("Warning: GPU is not available, we will use CPU, the application may run slower!\nyou computer will very likely heat up!")
print("Downloading the weight files...")
# download the weight files
download_weights.download()
download_weights.download_models()
app.run()

View File

@@ -2,40 +2,27 @@ import os
from pathlib import Path
from modelscope import snapshot_download
__WEIGHTS_DIR = Path("weights")
MODEL_DIR = os.path.join(__WEIGHTS_DIR, "AI-ModelScope", "OmniParser-v2___0")
def download():
OMNI_PARSER_MODEL_DIR = os.path.join(__WEIGHTS_DIR, "AI-ModelScope", "OmniParser-v2___0")
FLORENCE_MODEL_DIR = os.path.join(__WEIGHTS_DIR, "AI-ModelScope", "Florence-2-base-ft")
def __download_omni_parser():
# Create weights directory
__WEIGHTS_DIR.mkdir(exist_ok=True)
# List of files to download
files = [
"icon_detect/train_args.yaml",
"icon_detect/model.pt",
"icon_detect/model.yaml",
"icon_caption/config.json",
"icon_caption/generation_config.json",
"icon_caption/model.safetensors"
]
# Check and download missing files
missing_files = []
for file in files:
file_path = os.path.join(MODEL_DIR, file)
if not os.path.exists(file_path):
missing_files.append(file)
break
if not missing_files:
print("Model files already detected!")
return
snapshot_download(
'AI-ModelScope/OmniParser-v2.0',
cache_dir='weights'
)
print("Download complete")
def __download_florence_model():
snapshot_download('AI-ModelScope/Florence-2-base-ft',
cache_dir='weights'
)
def download_models():
__download_omni_parser()
__download_florence_model()
if __name__ == "__main__":
download()
download_models()