调通从app到agent

This commit is contained in:
yuruo 2025-03-04 10:35:45 +08:00
parent 4840c3bea6
commit 20bd0dd870
3 changed files with 25 additions and 35 deletions

View File

@ -62,6 +62,8 @@ def setup_state(state):
state['chatbot_messages'] = []
if 'stop' not in state:
state['stop'] = False
if 'base_url' not in state:
state['base_url'] = ""
async def main(state):
"""Render loop for Gradio"""
@ -176,49 +178,23 @@ def chatbot_output_callback(message, chatbot_state, hide_images=False, sender="b
for user_msg, bot_msg in chatbot_state]
# print(f"chatbot_output_callback chatbot_state: {concise_state} (truncated)")
def valid_params(user_input, state):
"""Validate all requirements and return a list of error messages."""
errors = []
for server_name, url in [('Windows Host', 'localhost:5000'), ('OmniParser Server', args.omniparser_server_url)]:
try:
url = f'http://{url}/probe'
response = requests.get(url, timeout=3)
if response.status_code != 200:
errors.append(f"{server_name} is not responding")
except RequestException as e:
errors.append(f"{server_name} is not responding")
if not state["api_key"].strip():
errors.append("LLM API Key is not set")
if not user_input:
errors.append("no computer use request provided")
return errors
def process_input(user_input, state):
# Reset the stop flag
if state["stop"]:
state["stop"] = False
errors = valid_params(user_input, state)
if errors:
raise gr.Error("Validation errors: " + ", ".join(errors))
# Append the user message to state["messages"]
state["messages"].append(
{
"role": Sender.USER,
"content": [TextBlock(type="text", text=user_input)],
"content": user_input,
}
)
# Append the user's message to chatbot_messages with None for the assistant's reply
state['chatbot_messages'].append((user_input, None))
state['chatbot_messages'].append({"role": "user", "content": user_input}) # 确保格式正确
yield state['chatbot_messages'] # Yield to update the chatbot UI with the user's message
print("state")
print(state)
# Run sampling_loop_sync with the chatbot_output_callback
@ -367,7 +343,7 @@ def run():
chatbot.clear(fn=clear_chat, inputs=[state], outputs=[chatbot])
submit_button.click(process_input, [chat_input, state], chatbot)
stop_button.click(stop_app, [state], None)
base_url.chage(fn=update_base_url, inputs=[base_url, state], outputs=None)
base_url.change(fn=update_base_url, inputs=[base_url, state], outputs=None)
demo.launch(server_name="0.0.0.0", server_port=7888)
if __name__ == "__main__":

24
main.py
View File

@ -1,12 +1,26 @@
import argparse
import subprocess
import signal
import sys
from gradio_ui import app
from util import download_weights
import time
def run():
download_weights.download()
app.run()
# 启动 server.py 子进程
server_process = subprocess.Popen(
["python", "./server.py"],
creationflags=subprocess.CREATE_NEW_PROCESS_GROUP
)
try:
# 下载权重文件
download_weights.download()
# 启动 Gradio UI
app.run()
finally:
# 确保在主进程退出时终止子进程
if server_process.poll() is None: # 如果进程还在运行
server_process.terminate() # 发送终止信号
server_process.wait(timeout=5) # 等待进程结束
if __name__ == '__main__':
download_weights.download()
run()

View File

@ -21,7 +21,7 @@ def parse_arguments():
parser.add_argument('--device', type=str, default='cpu', help='Device to run the model')
parser.add_argument('--BOX_TRESHOLD', type=float, default=0.05, help='Threshold for box detection')
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host for the API')
parser.add_argument('--port', type=int, default=5000, help='Port for the API')
parser.add_argument('--port', type=int, default=8000, help='Port for the API')
args = parser.parse_args()
return args