Add base_url

yuruo 2025-03-04 09:31:00 +08:00
parent deeb1aa982
commit f526955d9f
3 changed files with 31 additions and 67 deletions

View File

@@ -31,29 +31,15 @@ class VLMAgent:
     def __init__(
         self,
         model: str,
-        provider: str,
         api_key: str,
         output_callback: Callable,
         api_response_callback: Callable,
         max_tokens: int = 4096,
+        base_url: str = "",
         only_n_most_recent_images: int | None = None,
         print_usage: bool = True,
     ):
-        if model == "omniparser + gpt-4o":
-            self.model = "gpt-4o-2024-11-20"
-        elif model == "omniparser + R1":
-            self.model = "deepseek-r1-distill-llama-70b"
-        elif model == "omniparser + qwen2.5vl":
-            self.model = "qwen2.5-vl-72b-instruct"
-        elif model == "omniparser + o1":
-            self.model = "o1"
-        elif model == "omniparser + o3-mini":
-            self.model = "o3-mini"
-        else:
-            raise ValueError(f"Model {model} not supported")
-        self.provider = provider
+        self.base_url = base_url
         self.api_key = api_key
         self.api_response_callback = api_response_callback
         self.max_tokens = max_tokens
@@ -98,7 +84,7 @@ class VLMAgent:
                 model_name=self.model,
                 api_key=self.api_key,
                 max_tokens=self.max_tokens,
-                provider_base_url="https://api.openai.com/v1",
+                provider_base_url=self.base_url,
                 temperature=0,
             )
             print(f"oai token usage: {token_usage}")
@@ -127,7 +113,7 @@ class VLMAgent:
                 model_name=self.model,
                 api_key=self.api_key,
                 max_tokens=min(2048, self.max_tokens),
-                provider_base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
+                provider_base_url=self.base_url,
                 temperature=0,
             )
             print(f"qwen token usage: {token_usage}")

View File

@@ -231,7 +231,8 @@ def process_input(user_input, state):
         api_key=state["api_key"],
         only_n_most_recent_images=state["only_n_most_recent_images"],
         max_tokens=16384,
-        omniparser_url=args.omniparser_server_url
+        omniparser_url=args.omniparser_server_url,
+        base_url=state["base_url"]
     ):
         if loop_msg is None or state.get("stop"):
             yield state['chatbot_messages']
@@ -348,12 +349,16 @@ def run():
         def update_api_key(api_key_value, state):
             state["api_key"] = api_key_value
 
+        def update_base_url(base_url, state):
+            state["base_url"] = base_url
+
         def clear_chat(state):
             # Reset message-related state
             state["messages"] = []
             state["responses"] = {}
             state["tools"] = {}
+            state["base_url"] = ""
             state['chatbot_messages'] = []
             return state['chatbot_messages']
@@ -362,6 +367,7 @@ def run():
         chatbot.clear(fn=clear_chat, inputs=[state], outputs=[chatbot])
         submit_button.click(process_input, [chat_input, state], chatbot)
         stop_button.click(stop_app, [state], None)
+        base_url.change(fn=update_base_url, inputs=[base_url, state], outputs=None)
 
     demo.launch(server_name="0.0.0.0", server_port=7888)
 
 if __name__ == "__main__":
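The Gradio wiring above follows the pattern already used for api_key: a Textbox whose change event copies its value into the per-session state dict. A self-contained sketch of that pattern, as an editor's illustration; the widget label is an assumption:

    import gradio as gr

    with gr.Blocks() as demo:
        state = gr.State({"base_url": ""})        # per-session state dict
        base_url = gr.Textbox(label="Base URL")   # label is an assumption

        def update_base_url(base_url_value, state):
            state["base_url"] = base_url_value    # mutate in place, like update_api_key

        base_url.change(fn=update_base_url, inputs=[base_url, state], outputs=None)

    demo.launch()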

View File

@@ -46,66 +46,38 @@ def sampling_loop_sync(
     api_key: str,
     only_n_most_recent_images: int | None = 2,
     max_tokens: int = 4096,
-    omniparser_url: str
+    omniparser_url: str,
+    base_url: str
 ):
     """
     Synchronous agentic sampling loop for the assistant/tool interaction of computer use.
     """
     print('in sampling_loop_sync, model:', model)
     omniparser_client = OmniParserClient(url=f"http://{omniparser_url}/parse/")
-    if model == "claude-3-5-sonnet-20241022":
-        # Register Actor and Executor
-        actor = AnthropicActor(
-            model=model,
-            api_key=api_key,
-            api_response_callback=api_response_callback,
-            max_tokens=max_tokens,
-            only_n_most_recent_images=only_n_most_recent_images
-        )
-    elif model in set(["omniparser + gpt-4o", "omniparser + o1", "omniparser + o3-mini", "omniparser + R1", "omniparser + qwen2.5vl"]):
-        actor = VLMAgent(
-            model=model,
-            api_key=api_key,
-            api_response_callback=api_response_callback,
-            output_callback=output_callback,
-            max_tokens=max_tokens,
-            only_n_most_recent_images=only_n_most_recent_images
-        )
-    else:
-        raise ValueError(f"Model {model} not supported")
+    actor = VLMAgent(
+        model=model,
+        api_key=api_key,
+        base_url=base_url,
+        api_response_callback=api_response_callback,
+        output_callback=output_callback,
+        max_tokens=max_tokens,
+        only_n_most_recent_images=only_n_most_recent_images
+    )
     executor = AnthropicExecutor(
         output_callback=output_callback,
         tool_output_callback=tool_output_callback,
     )
     print(f"Model Inited: {model}, Provider: {provider}")
 
     tool_result_content = None
     print(f"Start the message loop. User messages: {messages}")
 
-    if model == "claude-3-5-sonnet-20241022":  # Anthropic loop
-        while True:
-            parsed_screen = omniparser_client()  # parsed_screen: {"som_image_base64": dino_labled_img, "parsed_content_list": parsed_content_list, "screen_info"}
-            screen_info_block = TextBlock(text='Below is the structured accessibility information of the current UI screen, which includes text and icons you can operate on, take these information into account when you are making the prediction for the next action. Note you will still need to take screenshot to get the image: \n' + parsed_screen['screen_info'], type='text')
-            screen_info_dict = {"role": "user", "content": [screen_info_block]}
-            messages.append(screen_info_dict)
-            tools_use_needed = actor(messages=messages)
-            for message, tool_result_content in executor(tools_use_needed, messages):
-                yield message
-            if not tool_result_content:
-                return messages
-            messages.append({"content": tool_result_content, "role": "user"})
-    elif model in set(["omniparser + gpt-4o", "omniparser + o1", "omniparser + o3-mini", "omniparser + R1", "omniparser + qwen2.5vl"]):
-        while True:
-            parsed_screen = omniparser_client()
-            tools_use_needed, vlm_response_json = actor(messages=messages, parsed_screen=parsed_screen)
-            for message, tool_result_content in executor(tools_use_needed, messages):
-                yield message
-            if not tool_result_content:
-                return messages
+    while True:
+        parsed_screen = omniparser_client()
+        tools_use_needed, vlm_response_json = actor(messages=messages, parsed_screen=parsed_screen)
+        for message, tool_result_content in executor(tools_use_needed, messages):
+            yield message
+        if not tool_result_content:
+            return messages
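After this change a caller threads base_url through sampling_loop_sync, which forwards it to VLMAgent. A sketch of such a call, as an editor's illustration; the parameters above api_key are not shown in the hunk, so the earlier arguments and callback signatures here are assumptions:

    # sampling_loop_sync is a generator: iterate to drive the agent loop.
    for message in sampling_loop_sync(
        model="omniparser + gpt-4o",
        provider="openai",                         # assumption: still accepted, used for logging
        messages=[{"role": "user", "content": "open a browser"}],
        output_callback=print,                     # assumption: any callable works
        tool_output_callback=lambda result, tool_id: None,
        api_response_callback=lambda response: None,
        api_key="sk-...",
        max_tokens=4096,
        omniparser_url="localhost:8000",
        base_url="https://api.openai.com/v1",      # new in this commit
    ):
        pass  # each yielded message would update the UI in the real app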