From c9a51bb4034cfd55b71ab827def598b2a312ecf9 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Fri, 14 Mar 2025 13:22:17 +0800 Subject: [PATCH] fix: fallback genobj --- src/agent.ts | 7 ++-- src/app.ts | 74 +++++++++++++++++++++++-------------- src/types.ts | 28 +++++++++----- src/utils/safe-generator.ts | 7 ++-- src/utils/url-tools.ts | 18 +++++---- 5 files changed, 83 insertions(+), 51 deletions(-) diff --git a/src/agent.ts b/src/agent.ts index 38922c9..fbb1979 100644 --- a/src/agent.ts +++ b/src/agent.ts @@ -30,7 +30,7 @@ import { addToAllURLs, rankURLs, countUrlParts, - removeBFromA, + filterURLs, normalizeUrl, sampleMultinomial, weightedURLToString, getLastModified, keepKPerHostname, processURLs } from "./utils/url-tools"; @@ -239,7 +239,7 @@ export async function getResponse(question?: string, maxBadAttempts: number = 3, existingContext?: Partial, messages?: Array -): Promise<{ result: StepAction; context: TrackerContext; visitedURLs: string[], readURLs: string[] }> { +): Promise<{ result: StepAction; context: TrackerContext; visitedURLs: string[], readURLs: string[], allURLs: string[] }> { let step = 0; let totalStep = 0; @@ -329,7 +329,7 @@ export async function getResponse(question?: string, if (allURLs && Object.keys(allURLs).length > 0) { // rerank urls weightedURLs = rankURLs( - removeBFromA(allURLs, visitedURLs), + filterURLs(allURLs, visitedURLs), { question: currentQuestion }, context); @@ -851,6 +851,7 @@ But unfortunately, you failed to solve the issue. You need to think out of the b context, visitedURLs: returnedURLs, readURLs: visitedURLs, + allURLs: weightedURLs.map(r => r.url) }; } diff --git a/src/app.ts b/src/app.ts index 7e9a0ad..9a4c08b 100644 --- a/src/app.ts +++ b/src/app.ts @@ -393,7 +393,7 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => { // clean from all assistant messages body.messages = body.messages?.filter(message => { if (message.role === 'assistant') { - // 2 cases message.content can be a string or an array + // 2 cases message.content can be a string or an array if (typeof message.content === 'string') { message.content = (message.content as string).replace(/[\s\S]*?<\/think>/g, '').trim(); // Filter out the message if the content is empty after removal @@ -406,7 +406,7 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => { } }); //Filter out any content objects in the array that now have null/undefined/empty text. - message.content = message.content.filter((content:any) => + message.content = message.content.filter((content: any) => !(content.type === 'text' && content.text === '') ); @@ -417,17 +417,17 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => { } else if (message.role === 'user' && Array.isArray(message.content)) { message.content = message.content.map((content: any) => { if (content.type === 'image_url') { - return { - type: 'image', - image: content.image_url?.url || '', - } + return { + type: 'image', + image: content.image_url?.url || '', + } } return content; }); return true; } else if (message.role === 'system') { if (Array.isArray(message.content)) { - message.content = message.content.map((content: any) => `${content.text || content}`).join(' '); + message.content = message.content.map((content: any) => `${content.text || content}`).join(' '); } return true; } @@ -503,19 +503,19 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => { // emit every url in the visit action in url field (step as VisitAction).URLTargets.forEach((url) => { const chunk: ChatCompletionChunk = { - id: requestId, - object: 'chat.completion.chunk', - created, - model: body.model, - system_fingerprint: 'fp_' + requestId, - choices: [{ - index: 0, - delta: {type: 'think', url}, - logprobs: null, - finish_reason: null, - }] - }; - res.write(`data: ${JSON.stringify(chunk)}\n\n`); + id: requestId, + object: 'chat.completion.chunk', + created, + model: body.model, + system_fingerprint: 'fp_' + requestId, + choices: [{ + index: 0, + delta: {type: 'think', url}, + logprobs: null, + finish_reason: null, + }] + }; + res.write(`data: ${JSON.stringify(chunk)}\n\n`); }); } if (step.think) { @@ -545,11 +545,23 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => { try { const { result: finalStep, - visitedURLs: visitedURLs, - readURLs: readURLs + visitedURLs, + readURLs, + allURLs } = await getResponse(undefined, tokenBudget, maxBadAttempts, context, body.messages) let finalAnswer = (finalStep as AnswerAction).mdAnswer; + const annotations = (finalStep as AnswerAction).references?.map(ref => ({ + type: 'url_citation' as const, + url_citation: { + title: ref.title, + exactQuote: ref.exactQuote, + url: ref.url, + dateTime: ref.dateTime, + } + })) + + if (responseSchema) { try { const generator = new ObjectGeneratorSafe(context?.tokenTracker); @@ -597,13 +609,18 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => { system_fingerprint: 'fp_' + requestId, choices: [{ index: 0, - delta: {content: finalAnswer, type: responseSchema? 'json': 'text'}, + delta: { + content: finalAnswer, + type: responseSchema ? 'json' : 'text', + annotations, + }, logprobs: null, finish_reason: 'stop' }], usage, visitedURLs, - readURLs + readURLs, + numURLs: allURLs.length }; res.write(`data: ${JSON.stringify(finalChunk)}\n\n`); res.end(); @@ -620,14 +637,16 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => { message: { role: 'assistant', content: finalStep.action === 'answer' ? (finalAnswer || '') : finalStep.think, - type: responseSchema? 'json': 'text' + type: responseSchema ? 'json' : 'text', + annotations, }, logprobs: null, finish_reason: 'stop' }], usage, visitedURLs, - readURLs + readURLs, + numURLs: allURLs.length }; // Log final response (excluding full content for brevity) @@ -637,7 +656,8 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => { contentLength: response.choices[0].message.content.length, usage: response.usage, visitedURLs: response.visitedURLs, - readURLs: response.readURLs + readURLs: response.readURLs, + numURLs: allURLs.length }); res.json(response); diff --git a/src/types.ts b/src/types.ts index dd7e217..222ed52 100644 --- a/src/types.ts +++ b/src/types.ts @@ -19,14 +19,17 @@ export type SearchAction = BaseAction & { searchRequests: string[]; }; +export type Reference = { + exactQuote: string; + url: string; + title: string; + dateTime?: string; + } + export type AnswerAction = BaseAction & { action: "answer"; answer: string; - references: Array<{ - exactQuote: string; - url: string; - dateTime?: string; - }>; + references: Array; isFinal?: boolean; mdAnswer?: string; }; @@ -35,11 +38,7 @@ export type AnswerAction = BaseAction & { export type KnowledgeItem = { question: string, answer: string, - references?: Array<{ - exactQuote: string; - url: string; - dateTime?: string; - }> | Array; + references?: Array | Array; type: 'qa' | 'side-info' | 'chat-history' | 'url' | 'coding', updated?: string, sourceCode?: string, @@ -218,6 +217,11 @@ export interface ChatCompletionRequest { response_format?: ResponseFormat; } +export interface URLAnnotation { + type: 'url_citation', + url_citation: Reference +} + export interface ChatCompletionResponse { id: string; object: 'chat.completion'; @@ -230,6 +234,7 @@ export interface ChatCompletionResponse { role: 'assistant'; content: string; type: 'text' | 'think' | 'json' | 'error'; + annotations?: Array; }; logprobs: null; finish_reason: 'stop' | 'error'; @@ -241,6 +246,7 @@ export interface ChatCompletionResponse { }; visitedURLs?: string[]; readURLs?: string[]; + numURLs?: number; } export interface ChatCompletionChunk { @@ -256,6 +262,7 @@ export interface ChatCompletionChunk { content?: string; type?: 'text' | 'think' | 'json' | 'error'; url?: string; + annotations?: Array; }; logprobs: null; finish_reason: null | 'stop' | 'thinking_end' | 'error'; @@ -263,6 +270,7 @@ export interface ChatCompletionChunk { usage?: any; visitedURLs?: string[]; readURLs?: string[]; + numURLs?: number; } // Tracker Types diff --git a/src/utils/safe-generator.ts b/src/utils/safe-generator.ts index dc7841c..6a8fd23 100644 --- a/src/utils/safe-generator.ts +++ b/src/utils/safe-generator.ts @@ -164,16 +164,17 @@ export class ObjectGeneratorSafe { const fallbackModel = getModel('fallback'); if (NoObjectGeneratedError.isInstance(parseError)) { const failedOutput = (parseError as any).text; - console.error(`${model} failed on object generation ${failedOutput} -> manual parsing failed again -> trying fallback model`, fallbackModel); + console.error(`${model} failed on object generation ${failedOutput} -> manual parsing failed again -> trying fallback model`); try { // Create a distilled version of the schema without descriptions const distilledSchema = this.createDistilledSchema(schema); - console.log('Distilled schema', distilledSchema) + // find last `"url":` appear in the string, which is the source of the problem + const tailoredOutput = failedOutput.slice(0, Math.max(failedOutput.lastIndexOf('"url":'), 1500)); const fallbackResult = await generateObject({ model: fallbackModel, schema: distilledSchema, - prompt: `Following the given JSON schema, extract the field from below: \n\n ${failedOutput}`, + prompt: `Following the given JSON schema, extract the field from below: \n\n ${tailoredOutput}`, maxTokens: getToolConfig('fallback').maxTokens, temperature: getToolConfig('fallback').temperature, }); diff --git a/src/utils/url-tools.ts b/src/utils/url-tools.ts index 33d4eba..e74511a 100644 --- a/src/utils/url-tools.ts +++ b/src/utils/url-tools.ts @@ -136,9 +136,9 @@ export function normalizeUrl(urlString: string, debug = false, options = { } } -export function removeBFromA(allURLs: Record, visitedURLs: string[]): SearchSnippet[] { +export function filterURLs(allURLs: Record, visitedURLs: string[]): SearchSnippet[] { return Object.entries(allURLs) - .filter(([url]) => !visitedURLs.includes(url)) + .filter(([url, ]) => !visitedURLs.includes(url)) .map(([, result]) => result); } @@ -269,13 +269,14 @@ export const rankURLs = (urlItems: SearchSnippet[], options: any = {}, trackers: }; export const addToAllURLs = (r: SearchSnippet, allURLs: Record, weightDelta = 1) => { - if (!allURLs[r.url]) { - allURLs[r.url] = r; - allURLs[r.url].weight = weightDelta; + const nURL = normalizeUrl(r.url); + if (!allURLs[nURL]) { + allURLs[nURL] = r; + allURLs[nURL].weight = weightDelta; } else { - (allURLs[r.url].weight as number)+= weightDelta; - const curDesc = allURLs[r.url].description; - allURLs[r.url].description = smartMergeStrings(curDesc, r.description); + (allURLs[nURL].weight as number)+= weightDelta; + const curDesc = allURLs[nURL].description; + allURLs[nURL].description = smartMergeStrings(curDesc, r.description); } } @@ -413,6 +414,7 @@ export async function processURLs( const urlResults = await Promise.all( urls.map(async url => { try { + url = normalizeUrl(url); const {response} = await readUrl(url, true, context.tokenTracker); const {data} = response; const guessedTime = await getLastModified(url);