opt deep research

This commit is contained in:
vincent
2025-04-30 20:38:41 +08:00
parent eba5788b15
commit f941819d29
7 changed files with 92 additions and 93 deletions

View File

@@ -40,6 +40,7 @@ PLAN_FILENAME = "research_plan.md"
SEARCH_INFO_FILENAME = "search_info.json"
_AGENT_STOP_FLAGS = {}
_BROWSER_AGENT_INSTANCES = {}
async def run_single_browser_task(
@@ -129,6 +130,7 @@ async def run_single_browser_task(
# Store instance for potential stop() call
task_key = f"{task_id}_{uuid.uuid4()}"
_BROWSER_AGENT_INSTANCES[task_key] = bu_agent_instance
# --- Run with Stop Check ---
# BrowserUseAgent needs to internally check a stop signal or have a stop method.
@@ -173,6 +175,9 @@ async def run_single_browser_task(
except Exception as e:
logger.error(f"Error closing browser: {e}")
if task_key in _BROWSER_AGENT_INSTANCES:
del _BROWSER_AGENT_INSTANCES[task_key]
class BrowserSearchInput(BaseModel):
queries: List[str] = Field(
@@ -257,7 +262,7 @@ def create_browser_search_tool(
name="parallel_browser_search",
description=f"""Use this tool to actively search the web for information related to a specific research task or question.
It runs up to {max_parallel_browsers} searches in parallel using a browser agent for better results than simple scraping.
Provide a list of distinct search queries that are likely to yield relevant information.""",
Provide a list of distinct search queries(up to {max_parallel_browsers}) that are likely to yield relevant information.""",
args_schema=BrowserSearchInput,
)
@@ -296,9 +301,8 @@ class DeepResearchState(TypedDict):
def _load_previous_state(task_id: str, output_dir: str) -> Dict[str, Any]:
"""Loads state from files if they exist."""
state_updates = {}
plan_file = os.path.join(output_dir, task_id, PLAN_FILENAME)
search_file = os.path.join(output_dir, task_id, SEARCH_INFO_FILENAME)
plan_file = os.path.join(output_dir, PLAN_FILENAME)
search_file = os.path.join(output_dir, SEARCH_INFO_FILENAME)
if os.path.exists(plan_file):
try:
with open(plan_file, 'r', encoding='utf-8') as f:
@@ -307,9 +311,9 @@ def _load_previous_state(task_id: str, output_dir: str) -> Dict[str, Any]:
step = 1
for line in f:
line = line.strip()
if line.startswith(("[x]", "[ ]")):
status = "completed" if line.startswith("[x]") else "pending"
task = line[4:].strip()
if line.startswith(("- [x]", "- [ ]")):
status = "completed" if line.startswith("- [x]") else "pending"
task = line[5:].strip()
plan.append(
ResearchPlanItem(step=step, task=task, status=status, queries=None, result_summary=None))
step += 1
@@ -321,7 +325,6 @@ def _load_previous_state(task_id: str, output_dir: str) -> Dict[str, Any]:
except Exception as e:
logger.error(f"Failed to load or parse research plan {plan_file}: {e}")
state_updates['error_message'] = f"Failed to load research plan: {e}"
if os.path.exists(search_file):
try:
with open(search_file, 'r', encoding='utf-8') as f:
@@ -342,7 +345,7 @@ def _save_plan_to_md(plan: List[ResearchPlanItem], output_dir: str):
with open(plan_file, 'w', encoding='utf-8') as f:
f.write("# Research Plan\n\n")
for item in plan:
marker = "[x]" if item['status'] == 'completed' else "[ ]"
marker = "- [x]" if item['status'] == 'completed' else "- [ ]"
f.write(f"{marker} {item['task']}\n")
logger.info(f"Research plan saved to {plan_file}")
except Exception as e:
@@ -545,8 +548,6 @@ async def research_execution_node(state: DeepResearchState) -> Dict[str, Any]:
stop_event = _AGENT_STOP_FLAGS.get(task_id)
if stop_event and stop_event.is_set():
logger.info(f"Stop requested before executing tool: {tool_name}")
# How to report this back? Maybe skip execution, return special state?
# Let's update state and return stop_requested = True
current_step['status'] = 'pending' # Not completed due to stop
_save_plan_to_md(plan, output_dir)
return {"stop_requested": True, "research_plan": plan}
@@ -668,7 +669,8 @@ async def synthesis_node(state: DeepResearchState) -> Dict[str, Any]:
# Prepare the research plan context
plan_summary = "\nResearch Plan Followed:\n"
for item in plan:
marker = "[x]" if item['status'] == 'completed' else "[?]" if item['status'] == 'failed' else "[ ]"
marker = "- [x]" if item['status'] == 'completed' else "- [ ] (Failed)" if item[
'status'] == 'failed' else "- [ ]"
plan_summary += f"{marker} {item['task']}\n"
synthesis_prompt = ChatPromptTemplate.from_messages([
@@ -745,7 +747,7 @@ def should_continue(state: DeepResearchState) -> str:
return "end_run" # Should not happen if planning node ran correctly
# Check if there are pending steps in the plan
if current_index < 2:
if current_index < len(plan):
logger.info(
f"Plan has pending steps (current index {current_index}/{len(plan)}). Routing to Research Execution.")
return "execute_research"
@@ -956,7 +958,25 @@ class DeepResearchAgent:
"final_state": final_state if final_state else {} # Return the final state dict
}
def stop(self):
async def _stop_lingering_browsers(self, task_id):
    """Attempts to stop any BrowserUseAgent instances associated with the task_id.

    Every entry of ``_BROWSER_AGENT_INSTANCES`` whose key starts with
    ``f"{task_id}_"`` has its ``stop()`` method called. Errors are logged per
    instance and never propagated, so one failing agent does not prevent the
    remaining ones from being stopped.

    Args:
        task_id: Identifier of the research task whose browser agents should be stopped.
    """
    # Local import keeps this block self-contained; only this method needs it.
    import inspect

    prefix = f"{task_id}_"
    # Snapshot the matching keys first: entries are deleted from the registry
    # when a browser task finishes, which may happen while we await below.
    keys_to_stop = [key for key in _BROWSER_AGENT_INSTANCES if key.startswith(prefix)]
    if not keys_to_stop:
        return

    logger.warning(
        f"Found {len(keys_to_stop)} potentially lingering browser agents for task {task_id}. Attempting stop...")
    for key in keys_to_stop:
        agent_instance = _BROWSER_AGENT_INSTANCES.get(key)
        if not agent_instance:
            # Already removed from the registry since the snapshot was taken.
            continue
        try:
            # stop() may be a plain method or a coroutine function depending on
            # the agent implementation; only await when there is something to
            # await, otherwise `await None` would raise after a successful stop.
            stop_result = agent_instance.stop()
            if inspect.isawaitable(stop_result):
                await stop_result
            logger.info(f"Called stop() on browser agent instance {key}")
        except Exception as e:
            logger.error(f"Error calling stop() on browser agent instance {key}: {e}")
async def stop(self):
"""Signals the currently running agent task to stop."""
if not self.current_task_id or not self.stop_event:
logger.info("No agent task is currently running.")
@@ -965,6 +985,7 @@ class DeepResearchAgent:
logger.info(f"Stop requested for task ID: {self.current_task_id}")
self.stop_event.set() # Signal the stop event
self.stopped = True
await self._stop_lingering_browsers(self.current_task_id)
def close(self):
self.stopped = False

View File

@@ -16,12 +16,13 @@ model_names = {
"openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "o3-mini"],
"deepseek": ["deepseek-chat", "deepseek-reasoner"],
"google": ["gemini-2.0-flash", "gemini-2.0-flash-thinking-exp", "gemini-1.5-flash-latest",
"gemini-1.5-flash-8b-latest", "gemini-2.0-flash-thinking-exp-01-21", "gemini-2.0-pro-exp-02-05"],
"gemini-1.5-flash-8b-latest", "gemini-2.0-flash-thinking-exp-01-21", "gemini-2.0-pro-exp-02-05",
"gemini-2.5-pro-preview-03-25", "gemini-2.5-flash-preview-04-17"],
"ollama": ["qwen2.5:7b", "qwen2.5:14b", "qwen2.5:32b", "qwen2.5-coder:14b", "qwen2.5-coder:32b", "llama2:7b",
"deepseek-r1:14b", "deepseek-r1:32b"],
"azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"],
"mistral": ["pixtral-large-latest", "mistral-large-latest", "mistral-small-latest", "ministral-8b-latest"],
"alibaba": ["qwen-plus", "qwen-max", "qwen-turbo", "qwen-long"],
"alibaba": ["qwen-plus", "qwen-max", "qwen-vl-max", "qwen-vl-plus", "qwen-turbo", "qwen-long"],
"moonshot": ["moonshot-v1-32k-vision-preview", "moonshot-v1-8k-vision-preview"],
"unbound": ["gemini-2.0-flash", "gpt-4o-mini", "gpt-4o", "gpt-4.5-preview"],
"siliconflow": [

View File

@@ -265,23 +265,6 @@ def get_llm_model(provider: str, **kwargs):
azure_endpoint=base_url,
api_key=api_key,
)
elif provider == "bedrock":
if not kwargs.get("base_url", ""):
access_key_id = os.getenv('AWS_ACCESS_KEY_ID', '')
else:
access_key_id = kwargs.get("base_url")
if not kwargs.get("api_key", ""):
api_key = os.getenv('AWS_SECRET_ACCESS_KEY', '')
else:
api_key = kwargs.get("api_key")
return ChatBedrock(
model=kwargs.get("model_name", 'anthropic.claude-3-5-sonnet-20241022-v2:0'),
region=kwargs.get("bedrock_region", 'us-west-2'), # with higher quota
aws_access_key_id=SecretStr(access_key_id),
aws_secret_access_key=SecretStr(api_key),
temperature=kwargs.get("temperature", 0.0),
)
elif provider == "alibaba":
if not kwargs.get("base_url", ""):
base_url = os.getenv("ALIBABA_ENDPOINT", "https://dashscope.aliyuncs.com/compatible-mode/v1")

View File

@@ -84,7 +84,7 @@ async def run_deep_research(webui_manager: WebuiManager, components: Dict[Compon
return
# Store base save dir for stop handler
webui_manager._dr_save_dir = base_save_dir
webui_manager.dr_save_dir = base_save_dir
os.makedirs(base_save_dir, exist_ok=True)
# --- 2. Initial UI Update ---
@@ -141,8 +141,8 @@ async def run_deep_research(webui_manager: WebuiManager, components: Dict[Compon
}
# --- 4. Initialize or Get Agent ---
if not webui_manager._dr_agent:
webui_manager._dr_agent = DeepResearchAgent(
if not webui_manager.dr_agent:
webui_manager.dr_agent = DeepResearchAgent(
llm=llm,
browser_config=browser_config_dict,
mcp_server_config=mcp_config
@@ -150,20 +150,20 @@ async def run_deep_research(webui_manager: WebuiManager, components: Dict[Compon
logger.info("DeepResearchAgent initialized.")
# --- 5. Start Agent Run ---
agent_run_coro = await webui_manager._dr_agent.run(
agent_run_coro = webui_manager.dr_agent.run(
topic=task_topic,
task_id=task_id_to_resume,
save_dir=base_save_dir,
max_parallel_browsers=max_parallel_agents
)
agent_task = asyncio.create_task(agent_run_coro)
webui_manager._dr_current_task = agent_task
webui_manager.dr_current_task = agent_task
# Wait briefly for the agent to start and potentially create the task ID/folder
await asyncio.sleep(1.0)
# Determine the actual task ID being used (agent sets this)
running_task_id = webui_manager._dr_agent.current_task_id
running_task_id = webui_manager.dr_agent.current_task_id
if not running_task_id:
# Agent might not have set it yet, try to get from result later? Risky.
# Or derive from resume_task_id if provided?
@@ -176,7 +176,7 @@ async def run_deep_research(webui_manager: WebuiManager, components: Dict[Compon
else:
logger.info(f"Agent started with Task ID: {running_task_id}")
webui_manager._dr_task_id = running_task_id # Store for stop handler
webui_manager.dr_task_id = running_task_id # Store for stop handler
# --- 6. Monitor Progress via research_plan.md ---
if running_task_id:
@@ -187,12 +187,11 @@ async def run_deep_research(webui_manager: WebuiManager, components: Dict[Compon
else:
logger.warning("Cannot monitor plan file: Task ID unknown.")
plan_file_path = None
last_plan_content = None
while not agent_task.done():
update_dict = {}
# Check for stop signal (agent sets self.stopped)
agent_stopped = getattr(webui_manager._dr_agent, 'stopped', False)
update_dict[resume_task_id_comp] = gr.update(value=running_task_id)
agent_stopped = getattr(webui_manager.dr_agent, 'stopped', False)
if agent_stopped:
logger.info("Stop signal detected from agent state.")
break # Exit monitoring loop
@@ -204,7 +203,8 @@ async def run_deep_research(webui_manager: WebuiManager, components: Dict[Compon
if current_mtime > last_plan_mtime:
logger.info(f"Detected change in {plan_file_path}")
plan_content = _read_file_safe(plan_file_path)
if plan_content is not None and plan_content != last_plan_content:
if last_plan_content is None or (
plan_content is not None and plan_content != last_plan_content):
update_dict[markdown_display_comp] = gr.update(value=plan_content)
last_plan_content = plan_content
last_plan_mtime = current_mtime
@@ -230,7 +230,7 @@ async def run_deep_research(webui_manager: WebuiManager, components: Dict[Compon
# Try to get task ID from result if not known before
if not running_task_id and final_result_dict and 'task_id' in final_result_dict:
running_task_id = final_result_dict['task_id']
webui_manager._dr_task_id = running_task_id
webui_manager.dr_task_id = running_task_id
task_specific_dir = os.path.join(base_save_dir, str(running_task_id))
report_file_path = os.path.join(task_specific_dir, "report.md")
logger.info(f"Task ID confirmed from result: {running_task_id}")
@@ -268,22 +268,14 @@ async def run_deep_research(webui_manager: WebuiManager, components: Dict[Compon
finally:
# --- 8. Final UI Reset ---
webui_manager._dr_current_task = None # Clear task reference
webui_manager._dr_task_id = None # Clear running task ID
# Optionally close agent resources if needed, e.g., browser pool
if webui_manager._dr_agent and hasattr(webui_manager._dr_agent, 'close'):
try:
await webui_manager._dr_agent.close() # Assuming an async close method
logger.info("Closed DeepResearchAgent resources.")
webui_manager._dr_agent = None
except Exception as e_close:
logger.error(f"Error closing DeepResearchAgent: {e_close}")
webui_manager.dr_current_task = None # Clear task reference
webui_manager.dr_task_id = None # Clear running task ID
yield {
start_button_comp: gr.update(value="▶️ Run", interactive=True),
stop_button_comp: gr.update(interactive=False),
research_task_comp: gr.update(interactive=True),
resume_task_id_comp: gr.update(interactive=True),
resume_task_id_comp: gr.update(value="", interactive=True),
parallel_num_comp: gr.update(interactive=True),
save_dir_comp: gr.update(interactive=True),
# Keep download button enabled if file exists
@@ -295,10 +287,10 @@ async def run_deep_research(webui_manager: WebuiManager, components: Dict[Compon
async def stop_deep_research(webui_manager: WebuiManager) -> Dict[Component, Any]:
"""Handles the Stop button click."""
logger.info("Stop button clicked for Deep Research.")
agent = webui_manager._dr_agent
task = webui_manager._dr_current_task
task_id = webui_manager._dr_task_id
base_save_dir = webui_manager._dr_save_dir
agent = webui_manager.dr_agent
task = webui_manager.dr_current_task
task_id = webui_manager.dr_task_id
base_save_dir = webui_manager.dr_save_dir
stop_button_comp = webui_manager.get_component_by_id("deep_research_agent.stop_button")
start_button_comp = webui_manager.get_component_by_id("deep_research_agent.start_button")
@@ -311,15 +303,11 @@ async def stop_deep_research(webui_manager: WebuiManager) -> Dict[Component, Any
if agent and task and not task.done():
logger.info("Signalling DeepResearchAgent to stop.")
if hasattr(agent, 'stop'):
try:
# Assuming stop is synchronous or sets a flag quickly
agent.stop()
except Exception as e:
logger.error(f"Error calling agent.stop(): {e}")
else:
logger.warning("Agent has no 'stop' method. Task cancellation might not be graceful.")
# Task cancellation is handled by the run_deep_research finally block if needed
try:
# stop() is a coroutine: it sets the stop flag, then stops any lingering browser agents
await agent.stop()
except Exception as e:
logger.error(f"Error calling agent.stop(): {e}")
# The run_deep_research loop should detect the stop and exit.
# We yield an intermediate "Stopping..." state. The final reset is done by run_deep_research.
@@ -393,7 +381,7 @@ def create_deep_research_agent_tab(webui_manager: WebuiManager):
with gr.Group():
research_task = gr.Textbox(label="Research Task", lines=5,
value="Give me a detailed plan for traveling to Switzerland on June 1st.",
value="Give me a detailed travel plan to Switzerland from June 1st to 10th.",
interactive=True)
with gr.Row():
resume_task_id = gr.Textbox(label="Resume Task ID", value="",
@@ -418,7 +406,9 @@ def create_deep_research_agent_tab(webui_manager: WebuiManager):
stop_button=stop_button,
markdown_display=markdown_display,
markdown_download=markdown_download,
resume_task_id=resume_task_id
resume_task_id=resume_task_id,
mcp_json_file=mcp_json_file,
mcp_server_config=mcp_server_config,
)
)
webui_manager.add_components("deep_research_agent", tab_components)
@@ -430,7 +420,7 @@ def create_deep_research_agent_tab(webui_manager: WebuiManager):
)
dr_tab_outputs = list(tab_components.values())
all_managed_inputs = webui_manager.get_components()
all_managed_inputs = set(webui_manager.get_components())
# --- Define Event Handler Wrappers ---
async def start_wrapper(comps: Dict[Component, Any]) -> AsyncGenerator[Dict[Component, Any], None]:
@@ -439,17 +429,17 @@ def create_deep_research_agent_tab(webui_manager: WebuiManager):
async def stop_wrapper() -> AsyncGenerator[Dict[Component, Any], None]:
update_dict = await stop_deep_research(webui_manager)
yield update_dict # Yield the single dict update
yield update_dict
# --- Connect Handlers ---
start_button.click(
fn=start_wrapper,
inputs=all_managed_inputs,
outputs=dr_tab_outputs # Update only components in this tab
outputs=dr_tab_outputs
)
stop_button.click(
fn=stop_wrapper,
inputs=None,
outputs=dr_tab_outputs # Update only components in this tab
outputs=dr_tab_outputs
)

View File

@@ -45,9 +45,9 @@ class WebuiManager:
init deep research agent
"""
self.dr_agent: Optional[DeepResearchAgent] = None
self._dr_current_task = None
self.dr_current_task = None
self.dr_agent_task_id: Optional[str] = None
self._dr_save_dir: Optional[str] = None
self.dr_save_dir: Optional[str] = None
def add_components(self, tab_name: str, components_dict: dict[str, "Component"]) -> None:
"""

View File

@@ -338,18 +338,16 @@ async def test_deep_research_agent():
from src.agent.deep_research.deep_research_agent import DeepResearchAgent, PLAN_FILENAME, REPORT_FILENAME
from src.utils import llm_provider
# llm = llm_provider.get_llm_model(
# provider="azure_openai",
# model_name="gpt-4o",
# temperature=0.5,
# base_url=os.getenv("AZURE_OPENAI_ENDPOINT", ""),
# api_key=os.getenv("AZURE_OPENAI_API_KEY", ""),
# )
llm = llm_provider.get_llm_model(
provider="bedrock",
provider="openai",
model_name="gpt-4o",
temperature=0.5
)
# llm = llm_provider.get_llm_model(
# provider="bedrock",
# )
mcp_server_config = {
"mcpServers": {
"desktop-commander": {
@@ -364,9 +362,8 @@ async def test_deep_research_agent():
browser_config = {"headless": False, "window_width": 1280, "window_height": 1100, "use_own_browser": False}
agent = DeepResearchAgent(llm=llm, browser_config=browser_config, mcp_server_config=mcp_server_config)
research_topic = "Impact of Microplastics on Marine Ecosystems"
task_id_to_resume = None # Set this to resume a previous task ID
task_id_to_resume = "815460fb-337a-4850-8fa4-a5f2db301a89" # Set this to resume a previous task ID
print(f"Starting research on: {research_topic}")
@@ -374,8 +371,9 @@ async def test_deep_research_agent():
# Call run and wait for the final result dictionary
result = await agent.run(research_topic,
task_id=task_id_to_resume,
save_dir="./tmp/downloads",
max_parallel_browsers=1)
save_dir="./tmp/deep_research",
max_parallel_browsers=1,
)
print("\n--- Research Process Ended ---")
print(f"Status: {result.get('status')}")

View File

@@ -141,13 +141,19 @@ def test_ibm_model():
test_llm(config, "Describe this image", "assets/examples/test.png")
def test_qwen_model():
    """Send a letter-counting prompt to the Alibaba provider's qwen3-30b-a3b model."""
    prompt = "How many 'r's are in the word 'strawberry'?"
    test_llm(LLMConfig(provider="alibaba", model_name="qwen3-30b-a3b"), prompt)
if __name__ == "__main__":
# test_openai_model()
# test_google_model()
test_azure_openai_model()
# test_azure_openai_model()
# test_deepseek_model()
# test_ollama_model()
# test_deepseek_r1_model()
test_deepseek_r1_model()
# test_deepseek_r1_ollama_model()
# test_mistral_model()
# test_ibm_model()
# test_qwen_model()