mirror of
https://github.com/yuruotong1/autoMate.git
synced 2025-12-26 05:16:21 +08:00
更新windows配置
This commit is contained in:
parent
cdc9b403a6
commit
c67d35d187
@ -1,3 +1,4 @@
|
||||
import os
|
||||
from typing import List, Optional
|
||||
import cv2
|
||||
import torch
|
||||
@ -18,7 +19,7 @@ class UIElement(BaseModel):
|
||||
text: Optional[str] = None
|
||||
|
||||
class VisionAgent:
|
||||
def __init__(self, yolo_model_path: str, caption_model_path: str, florence_model_path: str):
|
||||
def __init__(self, yolo_model_path: str, caption_model_path: str):
|
||||
"""
|
||||
Initialize the vision agent
|
||||
|
||||
@ -33,36 +34,19 @@ class VisionAgent:
|
||||
|
||||
# load the image caption model and processor
|
||||
self.caption_processor = AutoProcessor.from_pretrained(
|
||||
"microsoft/Florence-2-base-ft",
|
||||
cache_dir=florence_model_path,
|
||||
trust_remote_code=True,
|
||||
"processor",
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# load the model according to the device type
|
||||
try:
|
||||
if self.device.type == 'cuda':
|
||||
# CUDA device uses float16
|
||||
self.caption_model = AutoModelForCausalLM.from_pretrained(
|
||||
caption_model_path,
|
||||
torch_dtype=torch.float16,
|
||||
trust_remote_code=True
|
||||
).to(self.device)
|
||||
elif self.device.type == 'mps':
|
||||
# MPS device uses float32 (MPS has limited support for float16)
|
||||
self.caption_model = AutoModelForCausalLM.from_pretrained(
|
||||
caption_model_path,
|
||||
torch_dtype=torch.float32,
|
||||
trust_remote_code=True
|
||||
).to(self.device)
|
||||
else:
|
||||
# CPU uses float32
|
||||
self.caption_model = AutoModelForCausalLM.from_pretrained(
|
||||
caption_model_path,
|
||||
torch_dtype=torch.float32,
|
||||
trust_remote_code=True
|
||||
).to(self.device)
|
||||
self.caption_model = AutoModelForCausalLM.from_pretrained(
|
||||
caption_model_path,
|
||||
torch_dtype=self.dtype,
|
||||
trust_remote_code=True
|
||||
).to(self.device)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Model loading failed for path: {caption_model_path}")
|
||||
raise e
|
||||
self.prompt = "<CAPTION>"
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ from gradio_ui.loop import (
|
||||
import base64
|
||||
from xbrain.utils.config import Config
|
||||
|
||||
from util.download_weights import FLORENCE_MODEL_DIR, OMNI_PARSER_MODEL_DIR
|
||||
from util.download_weights import OMNI_PARSER_MODEL_DIR
|
||||
CONFIG_DIR = Path("~/.anthropic").expanduser()
|
||||
API_KEY_FILE = CONFIG_DIR / "api_key"
|
||||
|
||||
@ -318,9 +318,7 @@ def run():
|
||||
api_key.change(fn=update_api_key, inputs=[api_key, state], outputs=None)
|
||||
chatbot.clear(fn=clear_chat, inputs=[state], outputs=[chatbot])
|
||||
vision_agent = VisionAgent(yolo_model_path=os.path.join(OMNI_PARSER_MODEL_DIR, "icon_detect", "model.pt"),
|
||||
caption_model_path=os.path.join(OMNI_PARSER_MODEL_DIR, "icon_caption"),
|
||||
florence_model_path=os.path.join(FLORENCE_MODEL_DIR)
|
||||
)
|
||||
caption_model_path=os.path.join(OMNI_PARSER_MODEL_DIR, "icon_caption"))
|
||||
vision_agent_state = gr.State({"agent": vision_agent})
|
||||
submit_button.click(process_input, [chat_input, state, vision_agent_state], [chatbot, task_list])
|
||||
stop_button.click(stop_app, [state], None)
|
||||
|
||||
9
main.py
9
main.py
@ -1,12 +1,19 @@
|
||||
# import os
|
||||
# os.environ["HF_ENDPOINT"] = "https://hf-mirror.com/"
|
||||
|
||||
from gradio_ui import app
|
||||
from util import download_weights
|
||||
|
||||
import torch
|
||||
|
||||
def run():
|
||||
if not torch.cuda.is_available():
|
||||
print("Warning: GPU is not available, we will use CPU, the application may run slower!\nyou computer will very likely heat up!")
|
||||
print("Downloading the weight files...")
|
||||
# download the weight files
|
||||
download_weights.download_models()
|
||||
# download_weights.download_models()
|
||||
# 配置 HuggingFace 镜像
|
||||
# print("HuggingFace mirror configured to use ModelScope registry")
|
||||
app.run()
|
||||
|
||||
|
||||
|
||||
1026
processor/added_tokens.json
Normal file
1026
processor/added_tokens.json
Normal file
File diff suppressed because it is too large
Load Diff
50001
processor/merges.txt
Normal file
50001
processor/merges.txt
Normal file
File diff suppressed because it is too large
Load Diff
33
processor/preprocessor_config.json
Normal file
33
processor/preprocessor_config.json
Normal file
@ -0,0 +1,33 @@
|
||||
{
|
||||
"auto_map": {
|
||||
"AutoProcessor": "microsoft/Florence-2-base--processing_florence2.Florence2Processor"
|
||||
},
|
||||
"crop_size": {
|
||||
"height": 768,
|
||||
"width": 768
|
||||
},
|
||||
"do_center_crop": false,
|
||||
"do_convert_rgb": null,
|
||||
"do_normalize": true,
|
||||
"do_rescale": true,
|
||||
"do_resize": true,
|
||||
"image_mean": [
|
||||
0.485,
|
||||
0.456,
|
||||
0.406
|
||||
],
|
||||
"image_processor_type": "CLIPImageProcessor",
|
||||
"image_seq_length": 577,
|
||||
"image_std": [
|
||||
0.229,
|
||||
0.224,
|
||||
0.225
|
||||
],
|
||||
"processor_class": "Florence2Processor",
|
||||
"resample": 3,
|
||||
"rescale_factor": 0.00392156862745098,
|
||||
"size": {
|
||||
"height": 768,
|
||||
"width": 768
|
||||
}
|
||||
}
|
||||
7185
processor/special_tokens_map.json
Normal file
7185
processor/special_tokens_map.json
Normal file
File diff suppressed because it is too large
Load Diff
259573
processor/tokenizer.json
Normal file
259573
processor/tokenizer.json
Normal file
File diff suppressed because it is too large
Load Diff
9277
processor/tokenizer_config.json
Normal file
9277
processor/tokenizer_config.json
Normal file
File diff suppressed because it is too large
Load Diff
1
processor/vocab.json
Normal file
1
processor/vocab.json
Normal file
File diff suppressed because one or more lines are too long
@ -3,7 +3,6 @@ from pathlib import Path
|
||||
from modelscope import snapshot_download
|
||||
__WEIGHTS_DIR = Path("weights")
|
||||
OMNI_PARSER_MODEL_DIR = os.path.join(__WEIGHTS_DIR, "AI-ModelScope", "OmniParser-v2___0")
|
||||
FLORENCE_MODEL_DIR = os.path.join(__WEIGHTS_DIR, "AI-ModelScope", "Florence-2-base-ft")
|
||||
|
||||
def __download_omni_parser():
|
||||
# Create weights directory
|
||||
@ -14,15 +13,11 @@ def __download_omni_parser():
|
||||
'AI-ModelScope/OmniParser-v2.0',
|
||||
cache_dir='weights'
|
||||
)
|
||||
|
||||
def __download_florence_model():
|
||||
snapshot_download('AI-ModelScope/Florence-2-base-ft',
|
||||
cache_dir='weights'
|
||||
)
|
||||
|
||||
|
||||
|
||||
def download_models():
|
||||
__download_omni_parser()
|
||||
__download_florence_model()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
download_models()
|
||||
Loading…
x
Reference in New Issue
Block a user