Merge pull request #616 from odaysec/dev

Fix uncontrolled data used in path expression
This commit is contained in:
warmshao
2025-06-01 22:56:06 +08:00
committed by GitHub
2 changed files with 13 additions and 2 deletions

View File

@@ -1111,7 +1111,12 @@ class DeepResearchAgent:
} }
self.current_task_id = task_id if task_id else str(uuid.uuid4()) self.current_task_id = task_id if task_id else str(uuid.uuid4())
output_dir = os.path.join(save_dir, self.current_task_id) safe_root_dir = "./tmp/deep_research"
normalized_save_dir = os.path.normpath(save_dir)
if not normalized_save_dir.startswith(os.path.abspath(safe_root_dir)):
logger.warning(f"Unsafe save_dir detected: {save_dir}. Using default directory.")
normalized_save_dir = os.path.abspath(safe_root_dir)
output_dir = os.path.join(normalized_save_dir, self.current_task_id)
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
logger.info( logger.info(

View File

@@ -74,7 +74,13 @@ async def run_deep_research(webui_manager: WebuiManager, components: Dict[Compon
task_topic = components.get(research_task_comp, "").strip() task_topic = components.get(research_task_comp, "").strip()
task_id_to_resume = components.get(resume_task_id_comp, "").strip() or None task_id_to_resume = components.get(resume_task_id_comp, "").strip() or None
max_parallel_agents = int(components.get(parallel_num_comp, 1)) max_parallel_agents = int(components.get(parallel_num_comp, 1))
base_save_dir = components.get(save_dir_comp, "./tmp/deep_research") base_save_dir = components.get(save_dir_comp, "./tmp/deep_research").strip()
safe_root_dir = "./tmp/deep_research"
normalized_base_save_dir = os.path.abspath(os.path.normpath(base_save_dir))
if os.path.commonpath([normalized_base_save_dir, os.path.abspath(safe_root_dir)]) != os.path.abspath(safe_root_dir):
logger.warning(f"Unsafe base_save_dir detected: {base_save_dir}. Using default directory.")
normalized_base_save_dir = os.path.abspath(safe_root_dir)
base_save_dir = normalized_base_save_dir
mcp_server_config_str = components.get(mcp_server_config_comp) mcp_server_config_str = components.get(mcp_server_config_comp)
mcp_config = json.loads(mcp_server_config_str) if mcp_server_config_str else None mcp_config = json.loads(mcp_server_config_str) if mcp_server_config_str else None