diff --git a/requirements.txt b/requirements.txt index 9777ebc..7f2d12c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ browser-use==0.1.40 pyperclip==1.9.0 -gradio==5.10.0 +gradio==5.23.1 json-repair langchain-mistralai==0.2.4 langchain-google-genai==2.0.8 diff --git a/src/agent/custom_agent.py b/src/agent/custom_agent.py index a41245b..4b0eff3 100644 --- a/src/agent/custom_agent.py +++ b/src/agent/custom_agent.py @@ -208,8 +208,8 @@ class CustomAgent(Agent): @time_execution_async("--get_next_action") async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutput: """Get next action from LLM based on current state""" - - ai_message = self.llm.invoke(input_messages) + fixed_input_messages = self._convert_input_messages(input_messages) + ai_message = self.llm.invoke(fixed_input_messages) self.message_manager._add_message_with_tokens(ai_message) if hasattr(ai_message, "reasoning_content"): @@ -222,10 +222,16 @@ class CustomAgent(Agent): else: ai_content = ai_message.content - ai_content = ai_content.replace("```json", "").replace("```", "") - ai_content = repair_json(ai_content) - parsed_json = json.loads(ai_content) - parsed: AgentOutput = self.AgentOutput(**parsed_json) + try: + ai_content = ai_content.replace("```json", "").replace("```", "") + ai_content = repair_json(ai_content) + parsed_json = json.loads(ai_content) + parsed: AgentOutput = self.AgentOutput(**parsed_json) + except Exception as e: + import traceback + traceback.print_exc() + logger.debug(ai_message.content) + raise ValueError('Could not parse response.') if parsed is None: logger.debug(ai_message.content) diff --git a/src/agent/custom_message_manager.py b/src/agent/custom_message_manager.py index 8f2276b..212c3fb 100644 --- a/src/agent/custom_message_manager.py +++ b/src/agent/custom_message_manager.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import pdb from typing import List, Optional, Type, Dict from browser_use.agent.message_manager.service import MessageManager @@ -96,7 +97,7 @@ class CustomMessageManager(MessageManager): self._add_message_with_tokens(state_message) def _remove_state_message_by_index(self, remove_ind=-1) -> None: - """Remove last state message from history""" + """Remove state message by index from history""" i = len(self.state.history.messages) - 1 remove_cnt = 0 while i >= 0: diff --git a/src/agent/custom_system_prompt.md b/src/agent/custom_system_prompt.md index 13efbdb..9cefaa2 100644 --- a/src/agent/custom_system_prompt.md +++ b/src/agent/custom_system_prompt.md @@ -18,11 +18,17 @@ Example: # Response Rules 1. RESPONSE FORMAT: You must ALWAYS respond with valid JSON in this exact format: -{{"current_state": {{"evaluation_previous_goal": "Success|Failed|Unknown - Analyze the current elements and the image to check if the previous goals/actions are successful like intended by the task. Mention if something unexpected happened. Shortly state why/why not.", -"important_contents": "Output important contents closely related to user's instruction on the current page. If there is, please output the contents. If not, please output ''.", -"thought": "Think about the requirements that have been completed in previous operations and the requirements that need to be completed in the next one operation. If your output of evaluation_previous_goal is 'Failed', please reflect and output your reflection here.", -"next_goal": "Please generate a brief natural language description for the goal of your next actions based on your thought."}}, -"action":[{{"one_action_name": {{// action-specific parameter}}}}, // ... more actions in sequence]}} +{{ + "current_state": {{ + "evaluation_previous_goal": "Success|Failed|Unknown - Analyze the current elements and the image to check if the previous goals/actions are successful like intended by the task. Mention if something unexpected happened. Shortly state why/why not.", + "important_contents": "Output important contents closely related to user\'s instruction on the current page. If there is, please output the contents. If not, please output empty string ''.", + "thought": "Think about the requirements that have been completed in previous operations and the requirements that need to be completed in the next one operation. If your output of evaluation_previous_goal is 'Failed', please reflect and output your reflection here.", + "next_goal": "Please generate a brief natural language description for the goal of your next actions based on your thought." + }}, + "action": [ + {{"one_action_name": {{// action-specific parameter}}}}, // ... more actions in sequence + ] +}} 2. ACTIONS: You can specify multiple actions in the list to be executed in sequence. But always specify only one action name per item. Use maximum {{max_actions}} actions per sequence. Common action sequences: diff --git a/src/utils/agent_state.py b/src/utils/agent_state.py index 487a810..2456a55 100644 --- a/src/utils/agent_state.py +++ b/src/utils/agent_state.py @@ -1,5 +1,6 @@ import asyncio + class AgentState: _instance = None @@ -27,4 +28,4 @@ class AgentState: self.last_valid_state = state def get_last_valid_state(self): - return self.last_valid_state \ No newline at end of file + return self.last_valid_state diff --git a/src/utils/deep_research.py b/src/utils/deep_research.py index ab538e0..0409385 100644 --- a/src/utils/deep_research.py +++ b/src/utils/deep_research.py @@ -19,7 +19,13 @@ from browser_use.agent.views import ActionResult from browser_use.browser.context import BrowserContext from browser_use.controller.service import Controller, DoneAction from main_content_extractor import MainContentExtractor -from langchain.schema import SystemMessage, HumanMessage +from langchain_core.messages import ( + AIMessage, + BaseMessage, + HumanMessage, + ToolMessage, + SystemMessage +) from json_repair import repair_json from src.agent.custom_prompts import CustomSystemPrompt, CustomAgentMessagePrompt from src.controller.custom_controller import CustomController diff --git a/src/utils/default_config_settings.py b/src/utils/default_config_settings.py deleted file mode 100644 index 22c6185..0000000 --- a/src/utils/default_config_settings.py +++ /dev/null @@ -1,125 +0,0 @@ -import os -import pickle -import uuid -import gradio as gr - - -def default_config(): - """Prepare the default configuration""" - return { - "agent_type": "custom", - "max_steps": 100, - "max_actions_per_step": 10, - "use_vision": True, - "tool_calling_method": "auto", - "llm_provider": "openai", - "llm_model_name": "gpt-4o", - "llm_num_ctx": 32000, - "llm_temperature": 0.6, - "llm_base_url": "", - "llm_api_key": "", - "use_own_browser": os.getenv("CHROME_PERSISTENT_SESSION", "false").lower() == "true", - "keep_browser_open": False, - "headless": False, - "disable_security": True, - "enable_recording": True, - "window_w": 1280, - "window_h": 1100, - "save_recording_path": "./tmp/record_videos", - "save_trace_path": "./tmp/traces", - "save_agent_history_path": "./tmp/agent_history", - "task": "go to google.com and type 'OpenAI' click search and give me the first url", - } - - -def load_config_from_file(config_file): - """Load settings from a UUID.pkl file.""" - try: - with open(config_file, 'rb') as f: - settings = pickle.load(f) - return settings - except Exception as e: - return f"Error loading configuration: {str(e)}" - - -def save_config_to_file(settings, save_dir="./tmp/webui_settings"): - """Save the current settings to a UUID.pkl file with a UUID name.""" - os.makedirs(save_dir, exist_ok=True) - config_file = os.path.join(save_dir, f"{uuid.uuid4()}.pkl") - with open(config_file, 'wb') as f: - pickle.dump(settings, f) - return f"Configuration saved to {config_file}" - - -def save_current_config(*args): - current_config = { - "agent_type": args[0], - "max_steps": args[1], - "max_actions_per_step": args[2], - "use_vision": args[3], - "tool_calling_method": args[4], - "llm_provider": args[5], - "llm_model_name": args[6], - "llm_num_ctx": args[7], - "llm_temperature": args[8], - "llm_base_url": args[9], - "llm_api_key": args[10], - "use_own_browser": args[11], - "keep_browser_open": args[12], - "headless": args[13], - "disable_security": args[14], - "enable_recording": args[15], - "window_w": args[16], - "window_h": args[17], - "save_recording_path": args[18], - "save_trace_path": args[19], - "save_agent_history_path": args[20], - "task": args[21], - } - return save_config_to_file(current_config) - - -def update_ui_from_config(config_file): - if config_file is not None: - loaded_config = load_config_from_file(config_file.name) - if isinstance(loaded_config, dict): - return ( - gr.update(value=loaded_config.get("agent_type", "custom")), - gr.update(value=loaded_config.get("max_steps", 100)), - gr.update(value=loaded_config.get("max_actions_per_step", 10)), - gr.update(value=loaded_config.get("use_vision", True)), - gr.update(value=loaded_config.get("tool_calling_method", True)), - gr.update(value=loaded_config.get("llm_provider", "openai")), - gr.update(value=loaded_config.get("llm_model_name", "gpt-4o")), - gr.update(value=loaded_config.get("llm_num_ctx", 32000)), - gr.update(value=loaded_config.get("llm_temperature", 1.0)), - gr.update(value=loaded_config.get("llm_base_url", "")), - gr.update(value=loaded_config.get("llm_api_key", "")), - gr.update(value=loaded_config.get("use_own_browser", False)), - gr.update(value=loaded_config.get("keep_browser_open", False)), - gr.update(value=loaded_config.get("headless", False)), - gr.update(value=loaded_config.get("disable_security", True)), - gr.update(value=loaded_config.get("enable_recording", True)), - gr.update(value=loaded_config.get("window_w", 1280)), - gr.update(value=loaded_config.get("window_h", 1100)), - gr.update(value=loaded_config.get("save_recording_path", "./tmp/record_videos")), - gr.update(value=loaded_config.get("save_trace_path", "./tmp/traces")), - gr.update(value=loaded_config.get("save_agent_history_path", "./tmp/agent_history")), - gr.update(value=loaded_config.get("task", "")), - "Configuration loaded successfully." - ) - else: - return ( - gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), - gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), - gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), - gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), - gr.update(), "Error: Invalid configuration file." - ) - return ( - gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), - gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), - gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), - gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), - gr.update(), "No file selected." - ) diff --git a/src/utils/utils.py b/src/utils/utils.py index c113843..7289002 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -4,13 +4,15 @@ import time from pathlib import Path from typing import Dict, Optional import requests +import json +import gradio as gr +import uuid from langchain_anthropic import ChatAnthropic from langchain_mistralai import ChatMistralAI from langchain_google_genai import ChatGoogleGenerativeAI from langchain_ollama import ChatOllama from langchain_openai import AzureChatOpenAI, ChatOpenAI -import gradio as gr from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama, UnboundChatOpenAI @@ -37,7 +39,7 @@ def get_llm_model(provider: str, **kwargs): env_var = f"{provider.upper()}_API_KEY" api_key = kwargs.get("api_key", "") or os.getenv(env_var, "") if not api_key: - handle_api_key_error(provider, env_var) + raise MissingAPIKeyError(provider, env_var) kwargs["api_key"] = api_key if provider == "anthropic": @@ -185,7 +187,7 @@ model_names = { "ollama": ["qwen2.5:7b", "qwen2.5:14b", "qwen2.5:32b", "qwen2.5-coder:14b", "qwen2.5-coder:32b", "llama2:7b", "deepseek-r1:14b", "deepseek-r1:32b"], "azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"], - "mistral": ["mixtral-large-latest", "mistral-large-latest", "mistral-small-latest", "ministral-8b-latest"], + "mistral": ["pixtral-large-latest", "mistral-large-latest", "mistral-small-latest", "ministral-8b-latest"], "alibaba": ["qwen-plus", "qwen-max", "qwen-turbo", "qwen-long"], "moonshot": ["moonshot-v1-32k-vision-preview", "moonshot-v1-8k-vision-preview"], "unbound": ["gemini-2.0-flash","gpt-4o-mini", "gpt-4o", "gpt-4.5-preview"] @@ -197,6 +199,7 @@ def update_model_dropdown(llm_provider, api_key=None, base_url=None): """ Update the model name dropdown with predefined models for the selected provider. """ + import gradio as gr # Use API keys from .env if not provided if not api_key: api_key = os.getenv(f"{llm_provider.upper()}_API_KEY", "") @@ -210,15 +213,13 @@ def update_model_dropdown(llm_provider, api_key=None, base_url=None): return gr.Dropdown(choices=[], value="", interactive=True, allow_custom_value=True) -def handle_api_key_error(provider: str, env_var: str): - """ - Handles the missing API key error by raising a gr.Error with a clear message. - """ - provider_display = PROVIDER_DISPLAY_NAMES.get(provider, provider.upper()) - raise gr.Error( - f"💥 {provider_display} API key not found! 🔑 Please set the " - f"`{env_var}` environment variable or provide it in the UI." - ) +class MissingAPIKeyError(Exception): + """Custom exception for missing API key.""" + + def __init__(self, provider: str, env_var: str): + provider_display = PROVIDER_DISPLAY_NAMES.get(provider, provider.upper()) + super().__init__(f"💥 {provider_display} API key not found! 🔑 Please set the " + f"`{env_var}` environment variable or provide it in the UI.") def encode_image(img_path): @@ -287,3 +288,70 @@ async def capture_screenshot(browser_context): return encoded except Exception as e: return None + + +class ConfigManager: + def __init__(self): + self.components = {} + self.component_order = [] + + def register_component(self, name: str, component): + """Register a gradio component for config management.""" + self.components[name] = component + if name not in self.component_order: + self.component_order.append(name) + return component + + def save_current_config(self): + """Save the current configuration of all registered components.""" + current_config = {} + for name in self.component_order: + component = self.components[name] + # Get the current value from the component + current_config[name] = getattr(component, "value", None) + + return save_config_to_file(current_config) + + def update_ui_from_config(self, config_file): + """Update UI components from a loaded configuration file.""" + if config_file is None: + return [gr.update() for _ in self.component_order] + ["No file selected."] + + loaded_config = load_config_from_file(config_file.name) + + if not isinstance(loaded_config, dict): + return [gr.update() for _ in self.component_order] + ["Error: Invalid configuration file."] + + # Prepare updates for all components + updates = [] + for name in self.component_order: + if name in loaded_config: + updates.append(gr.update(value=loaded_config[name])) + else: + updates.append(gr.update()) + + updates.append("Configuration loaded successfully.") + return updates + + def get_all_components(self): + """Return all registered components in the order they were registered.""" + return [self.components[name] for name in self.component_order] + + +def load_config_from_file(config_file): + """Load settings from a config file (JSON format).""" + try: + with open(config_file, 'r') as f: + settings = json.load(f) + return settings + except Exception as e: + return f"Error loading configuration: {str(e)}" + + +def save_config_to_file(settings, save_dir="./tmp/webui_settings"): + """Save the current settings to a UUID.json file with a UUID name.""" + os.makedirs(save_dir, exist_ok=True) + config_file = os.path.join(save_dir, f"{uuid.uuid4()}.json") + with open(config_file, 'w') as f: + json.dump(settings, f, indent=2) + return f"Configuration saved to {config_file}" diff --git a/tests/test_browser_use.py b/tests/test_browser_use.py index db35c5f..6ef4210 100644 --- a/tests/test_browser_use.py +++ b/tests/test_browser_use.py @@ -133,11 +133,11 @@ async def test_browser_use_custom(): api_key=os.getenv("GOOGLE_API_KEY", "") ) - # llm = utils.get_llm_model( - # provider="deepseek", - # model_name="deepseek-reasoner", - # temperature=0.8 - # ) + llm = utils.get_llm_model( + provider="deepseek", + model_name="deepseek-reasoner", + temperature=0.8 + ) # llm = utils.get_llm_model( # provider="deepseek", diff --git a/webui.py b/webui.py index ec51779..bc68605 100644 --- a/webui.py +++ b/webui.py @@ -13,6 +13,8 @@ import os logger = logging.getLogger(__name__) import gradio as gr +import inspect +from functools import wraps from browser_use.agent.service import Agent from playwright.async_api import async_playwright @@ -32,9 +34,8 @@ from src.agent.custom_prompts import CustomSystemPrompt, CustomAgentMessagePromp from src.browser.custom_context import BrowserContextConfig, CustomBrowserContext from src.controller.custom_controller import CustomController from gradio.themes import Citrus, Default, Glass, Monochrome, Ocean, Origin, Soft, Base -from src.utils.default_config_settings import default_config, load_config_from_file, save_config_to_file, \ - save_current_config, update_ui_from_config -from src.utils.utils import update_model_dropdown, get_latest_files, capture_screenshot +from src.utils.utils import update_model_dropdown, get_latest_files, capture_screenshot, MissingAPIKeyError +from src.utils import utils # Global variables for persistence _global_browser = None @@ -44,6 +45,49 @@ _global_agent = None # Create the global agent state instance _global_agent_state = AgentState() +# webui config +webui_config_manager = utils.ConfigManager() + + +def scan_and_register_components(blocks): + """扫描一个 Blocks 对象并注册其中的所有交互式组件,但不包括按钮""" + global webui_config_manager + + def traverse_blocks(block, prefix=""): + registered = 0 + + # 处理 Blocks 自身的组件 + if hasattr(block, "children"): + for i, child in enumerate(block.children): + if isinstance(child, gr.components.Component): + # 排除按钮 (Button) 组件 + if getattr(child, "interactive", False) and not isinstance(child, gr.Button): + name = f"{prefix}component_{i}" + if hasattr(child, "label") and child.label: + # 使用标签作为名称的一部分 + label = child.label + name = f"{prefix}{label}" + logger.debug(f"Registering component: {name}") + webui_config_manager.register_component(name, child) + registered += 1 + elif hasattr(child, "children"): + # 递归处理嵌套的 Blocks + new_prefix = f"{prefix}block_{i}_" + registered += traverse_blocks(child, new_prefix) + + return registered + + total = traverse_blocks(blocks) + logger.info(f"Total registered components: {total}") + + +def save_current_config(): + return webui_config_manager.save_current_config() + + +def update_ui_from_config(config_file): + return webui_config_manager.update_ui_from_config(config_file) + def resolve_sensitive_env_variables(text): """ @@ -245,8 +289,9 @@ async def run_browser_agent( gr.update(interactive=True) # Re-enable run button ) - except gr.Error: - raise + except MissingAPIKeyError as e: + logger.error(str(e)) + raise gr.Error(str(e), print_exception=False) except Exception as e: import traceback @@ -539,8 +584,7 @@ async def run_with_stream( max_input_tokens=max_input_tokens ) # Add HTML content at the start of the result array - html_content = f"