Merge pull request #28 from richard-devbot/main

Update LLM configuration Tab to display list of models
This commit is contained in:
warmshao
2025-01-08 23:38:58 +08:00
committed by GitHub
3 changed files with 103 additions and 66 deletions

View File

@@ -2,4 +2,5 @@ browser-use>=0.1.18
langchain-google-genai>=2.0.8 langchain-google-genai>=2.0.8
pyperclip pyperclip
gradio gradio
langchain-ollama langchain-ollama

View File

@@ -12,7 +12,7 @@ from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ollama import ChatOllama from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, ChatOpenAI from langchain_openai import AzureChatOpenAI, ChatOpenAI
import gradio as gr
def get_llm_model(provider: str, **kwargs): def get_llm_model(provider: str, **kwargs):
""" """
@@ -106,8 +106,34 @@ def get_llm_model(provider: str, **kwargs):
) )
else: else:
raise ValueError(f"Unsupported provider: {provider}") raise ValueError(f"Unsupported provider: {provider}")
# Predefined model names for common providers; used to populate the
# "Model Name" dropdown when one of these providers is selected.
model_names = {
    "anthropic": ["claude-3-5-sonnet-20240620", "claude-3-opus-20240229"],
    "openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"],
    "deepseek": ["deepseek-chat"],
    "gemini": [
        "gemini-2.0-flash-exp",
        "gemini-2.0-flash-thinking-exp",
        "gemini-1.5-flash-latest",
        "gemini-1.5-flash-8b-latest",
        "gemini-2.0-flash-thinking-exp-1219",
    ],
    "ollama": ["qwen2.5:7b", "llama2:7b"],
    "azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"],
}


# Callback to update the model name dropdown based on the selected provider.
def update_model_dropdown(llm_provider, api_key=None, base_url=None):
    """
    Build a gr.Dropdown of model names for the selected provider.

    Args:
        llm_provider: Provider key (e.g. "openai") looked up in model_names.
        api_key: Optional credential. Currently unused — kept in the
            signature for callers and for a future live model-list query
            against the provider's API.
        base_url: Optional endpoint override. Currently unused (see api_key).

    Returns:
        A gr.Dropdown preselecting the provider's first known model, or an
        empty free-text dropdown when the provider is not in model_names.
    """
    # NOTE: the previous version resolved api_key/base_url from
    # {PROVIDER}_API_KEY / {PROVIDER}_BASE_URL env vars here and then never
    # used them; those dead lookups have been removed.
    if llm_provider in model_names:
        choices = model_names[llm_provider]
        # allow_custom_value keeps the UI's promise ("Select a model from the
        # dropdown or type a custom model name") even after a provider with a
        # predefined list is chosen.
        return gr.Dropdown(
            choices=choices,
            value=choices[0],
            interactive=True,
            allow_custom_value=True,
        )
    # Unknown provider: no suggestions, free-text entry only.
    return gr.Dropdown(choices=[], value="", interactive=True, allow_custom_value=True)
def encode_image(img_path): def encode_image(img_path):
if not img_path: if not img_path:
return None return None

138
webui.py
View File

@@ -5,6 +5,8 @@
# @Project : browser-use-webui # @Project : browser-use-webui
# @FileName: webui.py # @FileName: webui.py
import pdb
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
@@ -12,6 +14,12 @@ import argparse
import os import os
import gradio as gr import gradio as gr
import argparse
from gradio.themes import Base, Default, Soft, Monochrome, Glass, Origin, Citrus, Ocean
import asyncio
import os, glob
from browser_use.agent.service import Agent from browser_use.agent.service import Agent
from browser_use.browser.browser import Browser, BrowserConfig from browser_use.browser.browser import Browser, BrowserConfig
from browser_use.browser.context import ( from browser_use.browser.context import (
@@ -26,7 +34,7 @@ from src.browser.custom_browser import CustomBrowser
from src.browser.custom_context import BrowserContextConfig from src.browser.custom_context import BrowserContextConfig
from src.controller.custom_controller import CustomController from src.controller.custom_controller import CustomController
from src.utils import utils from src.utils import utils
from src.utils.utils import update_model_dropdown
async def run_browser_agent( async def run_browser_agent(
agent_type, agent_type,
@@ -268,11 +276,6 @@ async def run_custom_agent(
await browser.close() await browser.close()
return final_result, errors, model_actions, model_thoughts return final_result, errors, model_actions, model_thoughts
import glob
from gradio.themes import Citrus, Default, Glass, Monochrome, Ocean, Origin, Soft
# Define the theme map globally # Define the theme map globally
theme_map = { theme_map = {
"Default": Default(), "Default": Default(),
@@ -282,6 +285,7 @@ theme_map = {
"Origin": Origin(), "Origin": Origin(),
"Citrus": Citrus(), "Citrus": Citrus(),
"Ocean": Ocean(), "Ocean": Ocean(),
"Base": Base()
} }
@@ -364,22 +368,17 @@ def create_ui(theme_name="Ocean"):
with gr.TabItem("🔧 LLM Configuration", id=2): with gr.TabItem("🔧 LLM Configuration", id=2):
with gr.Group(): with gr.Group():
llm_provider = gr.Dropdown( llm_provider = gr.Dropdown(
[ ["anthropic", "openai", "deepseek", "gemini", "ollama", "azure_openai"],
"anthropic",
"openai",
"gemini",
"azure_openai",
"deepseek",
"ollama",
],
label="LLM Provider", label="LLM Provider",
value="openai", value="",
info="Select your preferred language model provider", info="Select your preferred language model provider"
) )
llm_model_name = gr.Textbox( llm_model_name = gr.Dropdown(
label="Model Name", label="Model Name",
value="gpt-4o", value="",
info="Specify the model to use", interactive=True,
allow_custom_value=True, # Allow users to input custom model names
info="Select a model from the dropdown or type a custom model name"
) )
llm_temperature = gr.Slider( llm_temperature = gr.Slider(
minimum=0.0, minimum=0.0,
@@ -387,16 +386,21 @@ def create_ui(theme_name="Ocean"):
value=1.0, value=1.0,
step=0.1, step=0.1,
label="Temperature", label="Temperature",
info="Controls randomness in model outputs", info="Controls randomness in model outputs"
) )
with gr.Row(): with gr.Row():
llm_base_url = gr.Textbox( llm_base_url = gr.Textbox(
label="Base URL", info="API endpoint URL (if required)" label="Base URL",
value=os.getenv(f"{llm_provider.value.upper()}_BASE_URL ", ""), # Default to .env value
info="API endpoint URL (if required)"
) )
llm_api_key = gr.Textbox( llm_api_key = gr.Textbox(
label="API Key", type="password", info="Your API key" label="API Key",
type="password",
value=os.getenv(f"{llm_provider.value.upper()}_API_KEY", ""), # Default to .env value
info="Your API key (leave blank to use .env)"
) )
with gr.TabItem("🌐 Browser Settings", id=3): with gr.TabItem("🌐 Browser Settings", id=3):
with gr.Group(): with gr.Group():
with gr.Row(): with gr.Row():
@@ -454,7 +458,7 @@ def create_ui(theme_name="Ocean"):
run_button = gr.Button("▶️ Run Agent", variant="primary", scale=2) run_button = gr.Button("▶️ Run Agent", variant="primary", scale=2)
stop_button = gr.Button("⏹️ Stop", variant="stop", scale=1) stop_button = gr.Button("⏹️ Stop", variant="stop", scale=1)
with gr.TabItem("🎬 Recordings", id=5): with gr.TabItem("📊 Results", id=5):
recording_display = gr.Video(label="Latest Recording") recording_display = gr.Video(label="Latest Recording")
with gr.Group(): with gr.Group():
@@ -477,61 +481,67 @@ def create_ui(theme_name="Ocean"):
model_thoughts_output = gr.Textbox( model_thoughts_output = gr.Textbox(
label="Model Thoughts", lines=3, show_label=True label="Model Thoughts", lines=3, show_label=True
) )
with gr.TabItem("🎥 Recordings", id=6):
def list_recordings(save_recording_path):
    """
    List recordings in *save_recording_path* as (path, label) pairs
    suitable for gr.Gallery, numbered in creation order (oldest first).

    Args:
        save_recording_path: Directory to scan for video files.

    Returns:
        A list of (absolute_or_relative_path, "N. filename") tuples, or an
        empty list when the directory does not exist.
    """
    if not os.path.exists(save_recording_path):
        return []

    # Match .mp4 and .webm in any letter case via character-class globs.
    recordings = (
        glob.glob(os.path.join(save_recording_path, "*.[mM][pP]4"))
        + glob.glob(os.path.join(save_recording_path, "*.[wW][eE][bB][mM]"))
    )

    # Oldest first so the numbering stays stable as new recordings arrive.
    recordings.sort(key=os.path.getctime)

    # Fix: the previous version computed os.path.basename(recording) and then
    # discarded it, labeling every gallery entry with the literal placeholder
    # "(unknown)". Use the actual filename in the label.
    return [
        (recording, f"{idx}. {os.path.basename(recording)}")
        for idx, recording in enumerate(recordings, start=1)
    ]
recordings_gallery = gr.Gallery(
label="Recordings",
value=list_recordings("./tmp/record_videos"),
columns=3,
height="auto",
object_fit="contain"
)
refresh_button = gr.Button("🔄 Refresh Recordings", variant="secondary")
refresh_button.click(
fn=list_recordings,
inputs=save_recording_path,
outputs=recordings_gallery
)
# Attach the callback to the LLM provider dropdown
llm_provider.change(
lambda provider, api_key, base_url: update_model_dropdown(provider, api_key, base_url),
inputs=[llm_provider, llm_api_key, llm_base_url],
outputs=llm_model_name
)
# Run button click handler # Run button click handler
run_button.click( run_button.click(
fn=run_browser_agent, fn=run_browser_agent,
inputs=[ inputs=[agent_type, llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key, use_own_browser, headless, disable_security, window_w, window_h, save_recording_path, task, add_infos, max_steps, use_vision, max_actions_per_step, tool_call_in_content],
agent_type, outputs=[final_result_output, errors_output, model_actions_output, model_thoughts_output, recording_display,],
llm_provider,
llm_model_name,
llm_temperature,
llm_base_url,
llm_api_key,
use_own_browser,
headless,
disable_security,
window_w,
window_h,
save_recording_path,
task,
add_infos,
max_steps,
use_vision,
max_actions_per_step,
tool_call_in_content
],
outputs=[
final_result_output,
errors_output,
model_actions_output,
model_thoughts_output,
recording_display,
],
) )
return demo return demo
def main(): def main():
parser = argparse.ArgumentParser(description="Gradio UI for Browser Agent") parser = argparse.ArgumentParser(description="Gradio UI for Browser Agent")
parser.add_argument( parser.add_argument("--ip", type=str, default="127.0.0.1", help="IP address to bind to")
"--ip", type=str, default="127.0.0.1", help="IP address to bind to"
)
parser.add_argument("--port", type=int, default=7788, help="Port to listen on") parser.add_argument("--port", type=int, default=7788, help="Port to listen on")
parser.add_argument( parser.add_argument("--theme", type=str, default="Ocean", choices=theme_map.keys(), help="Theme to use for the UI")
"--theme",
type=str,
default="Ocean",
choices=theme_map.keys(),
help="Theme to use for the UI",
)
parser.add_argument("--dark-mode", action="store_true", help="Enable dark mode") parser.add_argument("--dark-mode", action="store_true", help="Enable dark mode")
args = parser.parse_args() args = parser.parse_args()
demo = create_ui(theme_name=args.theme) demo = create_ui(theme_name=args.theme)
demo.launch(server_name=args.ip, server_port=args.port) demo.launch(server_name=args.ip, server_port=args.port)
if __name__ == '__main__':
if __name__ == "__main__":
main() main()