diff --git a/src/agent/deep_research/deep_research_agent.py b/src/agent/deep_research/deep_research_agent.py index 6981890..86be301 100644 --- a/src/agent/deep_research/deep_research_agent.py +++ b/src/agent/deep_research/deep_research_agent.py @@ -1111,7 +1111,12 @@ class DeepResearchAgent: } self.current_task_id = task_id if task_id else str(uuid.uuid4()) - output_dir = os.path.join(save_dir, self.current_task_id) + safe_root_dir = "./tmp/deep_research" + normalized_save_dir = os.path.normpath(save_dir) + if not normalized_save_dir.startswith(os.path.abspath(safe_root_dir)): + logger.warning(f"Unsafe save_dir detected: {save_dir}. Using default directory.") + normalized_save_dir = os.path.abspath(safe_root_dir) + output_dir = os.path.join(normalized_save_dir, self.current_task_id) os.makedirs(output_dir, exist_ok=True) logger.info( diff --git a/src/webui/components/deep_research_agent_tab.py b/src/webui/components/deep_research_agent_tab.py index ff455b5..88faea0 100644 --- a/src/webui/components/deep_research_agent_tab.py +++ b/src/webui/components/deep_research_agent_tab.py @@ -74,7 +74,13 @@ async def run_deep_research(webui_manager: WebuiManager, components: Dict[Compon task_topic = components.get(research_task_comp, "").strip() task_id_to_resume = components.get(resume_task_id_comp, "").strip() or None max_parallel_agents = int(components.get(parallel_num_comp, 1)) - base_save_dir = components.get(save_dir_comp, "./tmp/deep_research") + base_save_dir = components.get(save_dir_comp, "./tmp/deep_research").strip() + safe_root_dir = "./tmp/deep_research" + normalized_base_save_dir = os.path.abspath(os.path.normpath(base_save_dir)) + if os.path.commonpath([normalized_base_save_dir, os.path.abspath(safe_root_dir)]) != os.path.abspath(safe_root_dir): + logger.warning(f"Unsafe base_save_dir detected: {base_save_dir}. Using default directory.") + normalized_base_save_dir = os.path.abspath(safe_root_dir) + base_save_dir = normalized_base_save_dir mcp_server_config_str = components.get(mcp_server_config_comp) mcp_config = json.loads(mcp_server_config_str) if mcp_server_config_str else None