aggregate images when team size > 1

This commit is contained in:
Sha Zhou 2025-06-13 12:25:26 +08:00
parent a6100b3713
commit 45c668106f
4 changed files with 17 additions and 13 deletions

View File

@ -47,6 +47,7 @@ import { logInfo, logError, logDebug, logWarning } from './logging';
import { researchPlan } from './tools/research-planner'; import { researchPlan } from './tools/research-planner';
import { reduceAnswers } from './tools/reducer'; import { reduceAnswers } from './tools/reducer';
import { AxiosError } from 'axios'; import { AxiosError } from 'axios';
import { dedupImagesWithEmbeddings } from './utils/image-tools';
async function wait(seconds: number) { async function wait(seconds: number) {
logDebug(`Waiting ${seconds}s...`); logDebug(`Waiting ${seconds}s...`);
@ -405,7 +406,7 @@ export async function getResponse(question?: string,
searchProvider?: string, searchProvider?: string,
withImages: boolean = false, withImages: boolean = false,
teamSize: number = 2 teamSize: number = 2
): Promise<{ result: StepAction; context: TrackerContext; visitedURLs: string[], readURLs: string[], allURLs: string[], allImages?: string[], relatedImages?: string[] }> { ): Promise<{ result: StepAction; context: TrackerContext; visitedURLs: string[], readURLs: string[], allURLs: string[], imageReferences?: ImageReference[] }> {
let step = 0; let step = 0;
let totalStep = 0; let totalStep = 0;
@ -465,7 +466,6 @@ export async function getResponse(question?: string,
const visitedURLs: string[] = []; const visitedURLs: string[] = [];
const badURLs: string[] = []; const badURLs: string[] = [];
const imageObjects: ImageObject[] = []; const imageObjects: ImageObject[] = [];
let imageReferences: ImageReference[] = [];
const evaluationMetrics: Record<string, RepeatEvaluationType[]> = {}; const evaluationMetrics: Record<string, RepeatEvaluationType[]> = {};
// reserve the final 15% of the token budget for beast mode // reserve the final 15% of the token budget for beast mode
const regularBudget = tokenBudget * 0.85; const regularBudget = tokenBudget * 0.85;
@ -815,6 +815,7 @@ But then you realized you have asked them before. You decided to to think out of
answer: subproblemResponses.map(r => (r.result as AnswerAction).answer).join('\n\n'), answer: subproblemResponses.map(r => (r.result as AnswerAction).answer).join('\n\n'),
mdAnswer: subproblemResponses.map(r => (r.result as AnswerAction).mdAnswer).join('\n\n'), mdAnswer: subproblemResponses.map(r => (r.result as AnswerAction).mdAnswer).join('\n\n'),
references: subproblemResponses.map(r => (r.result as AnswerAction).references).flat(), references: subproblemResponses.map(r => (r.result as AnswerAction).references).flat(),
imageReferences: subproblemResponses.map(r => (r.result as AnswerAction).imageReferences).flat(),
isFinal: true, isFinal: true,
isAggregated: true isAggregated: true
} as AnswerAction; } as AnswerAction;
@ -823,8 +824,6 @@ But then you realized you have asked them before. You decided to to think out of
visitedURLs.push(...subproblemResponses.map(r => r.readURLs).flat()); visitedURLs.push(...subproblemResponses.map(r => r.readURLs).flat());
weightedURLs = subproblemResponses.map(r => r.allURLs.map(url => ({ url, title: '' } as BoostedSearchSnippet))).flat(); weightedURLs = subproblemResponses.map(r => r.allURLs.map(url => ({ url, title: '' } as BoostedSearchSnippet))).flat();
// TODO aggregate images @shazhou2015
// break the loop and jump directly to final boxing // break the loop and jump directly to final boxing
break; break;
} }
@ -1076,16 +1075,20 @@ But unfortunately, you failed to solve the issue. You need to think out of the b
if (imageObjects.length && withImages) { if (imageObjects.length && withImages) {
try { try {
imageReferences = await buildImageReferences(answerStep.answer, imageObjects, context, SchemaGen); answerStep.imageReferences = await buildImageReferences(answerStep.answer, imageObjects, context, SchemaGen);
logDebug('Image references built:', { imageReferences }); logDebug('Image references built:', { imageReferences: answerStep.imageReferences.map(i => ({url: i.url, score: i.relevanceScore, answerChunk: i.answerChunk})) });
} catch (error) { } catch (error) {
logError('Error building image references:', { error }); logError('Error building image references:', { error });
imageReferences = []; answerStep.imageReferences = [];
} }
} }
} else if (answerStep.isAggregated) { } else if (answerStep.isAggregated) {
answerStep.answer = await reduceAnswers(answerStep.answer, context, SchemaGen); answerStep.answer = await reduceAnswers(answerStep.answer, context, SchemaGen);
answerStep.mdAnswer = repairMarkdownFootnotesOuter(buildMdFromAnswer(answerStep)); answerStep.mdAnswer = repairMarkdownFootnotesOuter(buildMdFromAnswer(answerStep));
logDebug('[agent] all image references:', { count: answerStep.imageReferences?.length });
const dedupImages = dedupImagesWithEmbeddings(answerStep.imageReferences as ImageObject[], []);
logDebug('[agent] deduped images:', { count: dedupImages.length });
answerStep.imageReferences = answerStep.imageReferences?.filter(i => dedupImages.some(d => d.url === i.url)) || [];
} }
// max return 300 urls // max return 300 urls
@ -1096,8 +1099,7 @@ But unfortunately, you failed to solve the issue. You need to think out of the b
visitedURLs: returnedURLs, // deprecated visitedURLs: returnedURLs, // deprecated
readURLs: visitedURLs.filter(url => !badURLs.includes(url)), readURLs: visitedURLs.filter(url => !badURLs.includes(url)),
allURLs: weightedURLs.map(r => r.url), allURLs: weightedURLs.map(r => r.url),
allImages: withImages ? imageObjects.map(i => i.url) : undefined, imageReferences: withImages ? (thisStep as AnswerAction).imageReferences : undefined,
relatedImages: withImages ? imageReferences.map(i => i.url) : undefined,
}; };
} }

View File

@ -575,8 +575,7 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
visitedURLs, visitedURLs,
readURLs, readURLs,
allURLs, allURLs,
allImages, imageReferences,
relatedImages,
} = await getResponse(undefined, } = await getResponse(undefined,
tokenBudget, tokenBudget,
maxBadAttempts, maxBadAttempts,
@ -670,7 +669,7 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
visitedURLs, visitedURLs,
readURLs, readURLs,
numURLs: allURLs.length, numURLs: allURLs.length,
relatedImages relatedImages: body.with_images ? (imageReferences?.map(ref => ref.url) || []) : undefined,
}; };
res.write(`data: ${JSON.stringify(finalChunk)}\n\n`); res.write(`data: ${JSON.stringify(finalChunk)}\n\n`);
res.end(); res.end();
@ -697,7 +696,7 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
visitedURLs, visitedURLs,
readURLs, readURLs,
numURLs: allURLs.length, numURLs: allURLs.length,
relatedImages, relatedImages: body.with_images ? (imageReferences?.map(ref => ref.url) || []) : undefined,
}; };
logInfo(`[chat/completions] Completed!`, { logInfo(`[chat/completions] Completed!`, {

View File

@ -540,6 +540,7 @@ export async function buildImageReferences(
return { return {
url: source.url, url: source.url,
relevanceScore: match.relevanceScore, relevanceScore: match.relevanceScore,
embedding: [allImageEmbeddings[match.imageIndex]],
answerChunk: match.answerChunk, answerChunk: match.answerChunk,
answerChunkPosition: match.answerChunkPosition answerChunkPosition: match.answerChunkPosition
}; };

View File

@ -33,6 +33,7 @@ export type ImageReference = {
relevanceScore?: number; relevanceScore?: number;
answerChunk?: string; answerChunk?: string;
answerChunkPosition?: number[]; answerChunkPosition?: number[];
embedding?: number[][];
} }
export type AnswerAction = BaseAction & { export type AnswerAction = BaseAction & {
@ -42,6 +43,7 @@ export type AnswerAction = BaseAction & {
isFinal?: boolean; isFinal?: boolean;
mdAnswer?: string; mdAnswer?: string;
isAggregated?: boolean; isAggregated?: boolean;
imageReferences?: Array<ImageReference>;
}; };