feat: experiment deep query decomposition

This commit is contained in:
Florian Hönicke 2025-03-05 17:33:22 +01:00
parent d0e730fcc0
commit d88f2ef002
2 changed files with 39 additions and 32 deletions

View File

@ -5,7 +5,7 @@ import {readUrl, removeAllLineBreaks} from "./tools/read";
import fs from 'fs/promises';
import {SafeSearchType, search as duckSearch} from "duck-duck-scrape";
import {braveSearch} from "./tools/brave-search";
import {rewriteQuery} from "./tools/query-rewriter";
import {IS_DEEP_QUERY_REWRITE, rewriteQuery, rewriteQueryDeep} from "./tools/query-rewriter";
import {dedupQueries} from "./tools/jina-dedup";
import {evaluateAnswer, evaluateQuestion} from "./tools/evaluator";
import {analyzeSteps} from "./tools/error-analyzer";
@ -310,7 +310,7 @@ export async function getResponse(question?: string,
evaluationMetrics[currentQuestion].push('strict')
}
if (step === 1 && evaluationMetrics[currentQuestion].includes('freshness')) {
if (step === 1) {
// if it detects freshness, avoid direct answer at step 1
allowAnswer = false;
allowReflect = false;
@ -526,7 +526,13 @@ But then you realized you have asked them before. You decided to to think out of
thisStep.searchRequests = chooseK((await dedupQueries(thisStep.searchRequests, [], context.tokenTracker)).unique_queries, MAX_QUERIES_PER_STEP);
// rewrite queries
let {queries: keywordsQueries} = await rewriteQuery(thisStep, context, SchemaGen);
let rewriteQueryFn
if (IS_DEEP_QUERY_REWRITE) {
rewriteQueryFn = rewriteQueryDeep;
} else {
rewriteQueryFn = rewriteQuery;
}
let {queries: keywordsQueries} = await rewriteQueryFn(thisStep, context, SchemaGen);
// avoid exisitng searched queries
keywordsQueries = chooseK((await dedupQueries(keywordsQueries, allKeywords, context.tokenTracker)).unique_queries, MAX_QUERIES_PER_STEP);

View File

@ -2,6 +2,7 @@ import {PromptPair, SearchAction, TrackerContext} from '../types';
import {ObjectGeneratorSafe} from "../utils/safe-generator";
import {Schemas} from "../utils/schemas";
export const IS_DEEP_QUERY_REWRITE = true
function getPrompt(query: string, think: string): PromptPair {
const currentTime = new Date();
@ -171,36 +172,36 @@ ${query}
const TOOL_NAME = 'queryRewriter';
// export async function rewriteQuery(action: SearchAction, trackers: TrackerContext, schemaGen: Schemas): Promise<{ queries: string[] }> {
// try {
// const generator = new ObjectGeneratorSafe(trackers.tokenTracker);
// const allQueries = [...action.searchRequests];
//
// throw new Error(`allAction: ${JSON.stringify(action)}, allQueries: ${JSON.stringify(allQueries)}`);
// const queryPromises = action.searchRequests.map(async (req) => {
// const prompt = getPrompt(req, action.think);
// const result = await generator.generateObject({
// model: TOOL_NAME,
// schema: schemaGen.getQueryRewriterSchema(),
// system: prompt.system,
// prompt: prompt.user,
// });
// trackers?.actionTracker.trackThink(result.object.think);
// return result.object.queries;
// });
//
// const queryResults = await Promise.all(queryPromises);
// queryResults.forEach(queries => allQueries.push(...queries));
// console.log(TOOL_NAME, allQueries);
//
// return {queries: allQueries};
// } catch (error) {
// console.error(`Error in ${TOOL_NAME}`, error);
// throw error;
// }
// }
export async function rewriteQuery(action: SearchAction, trackers: TrackerContext, schemaGen: Schemas): Promise<{ queries: string[] }> {
try {
const generator = new ObjectGeneratorSafe(trackers.tokenTracker);
const allQueries = [...action.searchRequests];
throw new Error(`allAction: ${JSON.stringify(action)}, allQueries: ${JSON.stringify(allQueries)}`);
const queryPromises = action.searchRequests.map(async (req) => {
const prompt = getPrompt(req, action.think);
const result = await generator.generateObject({
model: TOOL_NAME,
schema: schemaGen.getQueryRewriterSchema(),
system: prompt.system,
prompt: prompt.user,
});
trackers?.actionTracker.trackThink(result.object.think);
return result.object.queries;
});
const queryResults = await Promise.all(queryPromises);
queryResults.forEach(queries => allQueries.push(...queries));
console.log(TOOL_NAME, allQueries);
return {queries: allQueries};
} catch (error) {
console.error(`Error in ${TOOL_NAME}`, error);
throw error;
}
}
export async function rewriteQueryDeep(action: SearchAction, trackers: TrackerContext, schemaGen: Schemas): Promise<{ queries: string[] }> {
try {
const generator = new ObjectGeneratorSafe(trackers.tokenTracker);
const allQueries = [...action.searchRequests];