From b432a27bff154e02c0f941c6e9f81c7041316069 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Tue, 15 Apr 2025 23:22:48 +0800 Subject: [PATCH] refactor: build ref with embeddings --- src/agent.ts | 9 ++- src/app.ts | 2 + src/tools/build-ref.ts | 96 +-------------------------- src/tools/embeddings.ts | 118 +++++++++++++++++++++++++++++++++ src/tools/jina-dedup.ts | 79 +--------------------- src/tools/jina-latechunk.ts | 126 +++++++----------------------------- src/tools/jina-rerank.ts | 2 +- src/tools/segment.ts | 2 +- src/types.ts | 3 + 9 files changed, 160 insertions(+), 277 deletions(-) create mode 100644 src/tools/embeddings.ts diff --git a/src/agent.ts b/src/agent.ts index 4ba1bec..eae0741 100644 --- a/src/agent.ts +++ b/src/agent.ts @@ -391,7 +391,9 @@ export async function getResponse(question?: string, noDirectAnswer: boolean = false, boostHostnames: string[] = [], badHostnames: string[] = [], - onlyHostnames: string[] = [] + onlyHostnames: string[] = [], + maxRef: number = 10, + minRelScore: number = 0.7 ): Promise<{ result: StepAction; context: TrackerContext; visitedURLs: string[], readURLs: string[], allURLs: string[] }> { let step = 0; @@ -990,7 +992,10 @@ But unfortunately, you failed to solve the issue. You need to think out of the b answerStep.answer, allWebContents, context, - SchemaGen + SchemaGen, + 80, + maxRef, + minRelScore ); answerStep.answer = answer; diff --git a/src/app.ts b/src/app.ts index 8f815b5..b8d3514 100644 --- a/src/app.ts +++ b/src/app.ts @@ -558,6 +558,8 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => { body.boost_hostnames, body.bad_hostnames, body.only_hostnames, + body.max_annotations, + body.min_annotation_relevance ) let finalAnswer = (finalStep as AnswerAction).mdAnswer; diff --git a/src/tools/build-ref.ts b/src/tools/build-ref.ts index 2e119dc..c04a33e 100644 --- a/src/tools/build-ref.ts +++ b/src/tools/build-ref.ts @@ -1,106 +1,16 @@ import {segmentText} from './segment'; -import {JinaEmbeddingRequest, JinaEmbeddingResponse, Reference, TrackerContext, WebContent} from "../types"; +import {Reference, TrackerContext, WebContent} from "../types"; import {Schemas} from "../utils/schemas"; -import axios, {AxiosError} from 'axios'; -import {JINA_API_KEY} from "../config"; import {cosineSimilarity, jaccardRank} from "./cosine"; - -const BATCH_SIZE = 2000; -const API_URL = "https://api.jina.ai/v1/embeddings"; - -// Simplified function to get embeddings in a single request -async function getEmbeddings( - texts: string[], - tokenTracker?: any -): Promise<{ embeddings: number[][], tokens: number }> { - console.log(`[embeddings] Getting embeddings for ${texts.length} texts`); - - if (!JINA_API_KEY) { - throw new Error('JINA_API_KEY is not set'); - } - - // Handle empty input case - if (texts.length === 0) { - return {embeddings: [], tokens: 0}; - } - - // Process in batches of 2000 - const allEmbeddings: number[][] = []; - let totalTokens = 0; - const batchCount = Math.ceil(texts.length / BATCH_SIZE); - - for (let i = 0; i < texts.length; i += BATCH_SIZE) { - const batchTexts = texts.slice(i, i + BATCH_SIZE); - const currentBatch = Math.floor(i / BATCH_SIZE) + 1; - console.log(`[embeddings] Processing batch ${currentBatch}/${batchCount} (${batchTexts.length} texts)`); - - const request: JinaEmbeddingRequest = { - model: "jina-embeddings-v3", - task: "text-matching", - late_chunking: false, // Late chunking turned off always - dimensions: 1024, - embedding_type: "float", - input: batchTexts, - truncate: true - }; - - try { - const response = await axios.post( - API_URL, - request, - { - headers: { - "Content-Type": "application/json", - "Authorization": `Bearer ${JINA_API_KEY}` - } - } - ); - - // Validate response format - if (!response.data.data || response.data.data.length !== batchTexts.length) { - console.error('Invalid response from Jina API:', response.data); - continue; - } - - // Sort embeddings by index to maintain original order - const batchEmbeddings = response.data.data - .sort((a, b) => a.index - b.index) - .map(item => item.embedding); - - allEmbeddings.push(...batchEmbeddings); - totalTokens += response.data.usage.total_tokens; - console.log(`[embeddings] Batch ${currentBatch} complete. Tokens used: ${response.data.usage.total_tokens}, total so far: ${totalTokens}`); - - } catch (error) { - console.error('Error calling Jina Embeddings API:', error); - if (error instanceof AxiosError && error.response?.status === 402) { - return {embeddings: [], tokens: 0}; - } - throw error; - } - } - - // Track token usage if tracker is provided - if (tokenTracker) { - tokenTracker.trackUsage('embeddings', { - promptTokens: totalTokens, - completionTokens: 0, - totalTokens: totalTokens - }); - } - - console.log(`[embeddings] Complete. Generated ${allEmbeddings.length} embeddings using ${totalTokens} tokens`); - return {embeddings: allEmbeddings, tokens: totalTokens}; -} - +import {getEmbeddings} from "./embeddings"; export async function buildReferences( answer: string, webContents: Record, context: TrackerContext, schema: Schemas, - maxRef: number = 10, minChunkLength: number = 80, + maxRef: number = 10, minRelScore: number = 0.7 ): Promise<{ answer: string, references: Array }> { console.log(`[buildReferences] Starting with maxRef=${maxRef}, minChunkLength=${minChunkLength}, minRelScore=${minRelScore}`); diff --git a/src/tools/embeddings.ts b/src/tools/embeddings.ts new file mode 100644 index 0000000..b9780af --- /dev/null +++ b/src/tools/embeddings.ts @@ -0,0 +1,118 @@ +import {JINA_API_KEY} from "../config"; +import {JinaEmbeddingRequest, JinaEmbeddingResponse} from "../types"; +import axios, {AxiosError} from "axios"; + +const BATCH_SIZE = 128; +const API_URL = "https://api.jina.ai/v1/embeddings"; + +// Modified to support different embedding tasks and dimensions +export async function getEmbeddings( + texts: string[], + tokenTracker?: any, + options: { + task?: "text-matching" | "retrieval.passage" | "retrieval.query", + dimensions?: number, + late_chunking?: boolean, + embedding_type?: string + } = {} +): Promise<{ embeddings: number[][], tokens: number }> { + console.log(`[embeddings] Getting embeddings for ${texts.length} texts`); + + if (!JINA_API_KEY) { + throw new Error('JINA_API_KEY is not set'); + } + + // Handle empty input case + if (texts.length === 0) { + return {embeddings: [], tokens: 0}; + } + + // Process in batches + const allEmbeddings: number[][] = []; + let totalTokens = 0; + const batchCount = Math.ceil(texts.length / BATCH_SIZE); + + for (let i = 0; i < texts.length; i += BATCH_SIZE) { + const batchTexts = texts.slice(i, i + BATCH_SIZE); + const currentBatch = Math.floor(i / BATCH_SIZE) + 1; + console.log(`[embeddings] Processing batch ${currentBatch}/${batchCount} (${batchTexts.length} texts)`); + + const request: JinaEmbeddingRequest = { + model: "jina-embeddings-v3", + task: options.task || "text-matching", + input: batchTexts, + truncate: true + }; + + // Add optional parameters if provided + if (options.dimensions) request.dimensions = options.dimensions; + if (options.late_chunking) request.late_chunking = options.late_chunking; + if (options.embedding_type) request.embedding_type = options.embedding_type; + + try { + const response = await axios.post( + API_URL, + request, + { + headers: { + "Content-Type": "application/json", + "Authorization": `Bearer ${JINA_API_KEY}` + } + } + ); + + // Prepare embeddings, handling any missing indices + let batchEmbeddings: number[][]; + + if (!response.data.data || response.data.data.length !== batchTexts.length) { + console.error('Invalid response from Jina API:', response.data.data?.length, batchTexts.length); + + // Find missing indices and complete with zero vectors + const receivedIndices = new Set(response.data.data?.map(item => item.index) || []); + const dimensionSize = response.data.data?.[0]?.embedding?.length || options.dimensions || 1024; + + batchEmbeddings = []; + + for (let idx = 0; idx < batchTexts.length; idx++) { + if (receivedIndices.has(idx)) { + // Find the item with this index + const item = response.data.data.find(d => d.index === idx); + batchEmbeddings.push(item!.embedding); + } else { + // Create a zero vector for missing index + console.error(`Missing embedding for index ${idx}: [${batchTexts[idx]}]`); + batchEmbeddings.push(new Array(dimensionSize).fill(0)); + } + } + } else { + // All indices present, just sort by index + batchEmbeddings = response.data.data + .sort((a, b) => a.index - b.index) + .map(item => item.embedding); + } + + allEmbeddings.push(...batchEmbeddings); + totalTokens += response.data.usage?.total_tokens || 0; + console.log(`[embeddings] Batch ${currentBatch} complete. Tokens used: ${response.data.usage?.total_tokens || 0}, total so far: ${totalTokens}`); + + } catch (error) { + console.error('Error calling Jina Embeddings API:', error); + if (error instanceof AxiosError && error.response?.status === 402) { + return {embeddings: [], tokens: 0}; + } + throw error; + } + } + + // Track token usage if tracker is provided + if (tokenTracker) { + tokenTracker.trackUsage('embeddings', { + promptTokens: totalTokens, + completionTokens: 0, + totalTokens: totalTokens + }); + } + + console.log(`[embeddings] Complete. Generated ${allEmbeddings.length} embeddings using ${totalTokens} tokens`); + return {embeddings: allEmbeddings, tokens: totalTokens}; +} diff --git a/src/tools/jina-dedup.ts b/src/tools/jina-dedup.ts index 11efbac..27bb825 100644 --- a/src/tools/jina-dedup.ts +++ b/src/tools/jina-dedup.ts @@ -1,77 +1,9 @@ -import axios, {AxiosError} from 'axios'; import {TokenTracker} from "../utils/token-tracker"; -import {JINA_API_KEY} from "../config"; import {cosineSimilarity} from "./cosine"; -import {JinaEmbeddingRequest, JinaEmbeddingResponse} from "../types"; +import {getEmbeddings} from "./embeddings"; -const JINA_API_URL = 'https://api.jina.ai/v1/embeddings'; const SIMILARITY_THRESHOLD = 0.86; // Adjustable threshold for cosine similarity -const JINA_API_CONFIG = { - MODEL: 'jina-embeddings-v3', - TASK: 'text-matching', - DIMENSIONS: 1024, - EMBEDDING_TYPE: 'float', - LATE_CHUNKING: false -} as const; - - -// Get embeddings for all queries in one batch -async function getEmbeddings(queries: string[]): Promise<{ embeddings: number[][], tokens: number }> { - if (!JINA_API_KEY) { - throw new Error('JINA_API_KEY is not set'); - } - - const request: JinaEmbeddingRequest = { - model: JINA_API_CONFIG.MODEL, - task: JINA_API_CONFIG.TASK, - late_chunking: JINA_API_CONFIG.LATE_CHUNKING, - dimensions: JINA_API_CONFIG.DIMENSIONS, - embedding_type: JINA_API_CONFIG.EMBEDDING_TYPE, - input: queries - }; - - try { - const response = await axios.post( - JINA_API_URL, - request, - { - headers: { - 'Content-Type': 'application/json', - 'Authorization': `Bearer ${JINA_API_KEY}` - } - } - ); - - // Validate response format - if (!response.data.data || response.data.data.length !== queries.length) { - console.error('Invalid response from Jina API:', response.data); - return { - embeddings: [], - tokens: 0 - }; - } - - // Sort embeddings by index to maintain original order - const embeddings = response.data.data - .sort((a, b) => a.index - b.index) - .map(item => item.embedding); - - return { - embeddings, - tokens: response.data.usage.total_tokens - }; - } catch (error) { - console.error('Error getting embeddings from Jina:', error); - if (error instanceof AxiosError && error.response?.status === 402) { - return { - embeddings: [], - tokens: 0 - }; - } - throw error; - } -} export async function dedupQueries( newQueries: string[], @@ -88,7 +20,7 @@ export async function dedupQueries( // Get embeddings for all queries in one batch const allQueries = [...newQueries, ...existingQueries]; - const {embeddings: allEmbeddings, tokens} = await getEmbeddings(allQueries); + const {embeddings: allEmbeddings} = await getEmbeddings(allQueries, tracker); // If embeddings is empty (due to 402 error), return all new queries if (!allEmbeddings.length) { @@ -134,13 +66,6 @@ export async function dedupQueries( usedIndices.add(i); } } - - // Track token usage from the API - (tracker || new TokenTracker()).trackUsage('dedup', { - promptTokens: 0, - completionTokens: tokens, - totalTokens: tokens - }); console.log('Dedup:', uniqueQueries); return { unique_queries: uniqueQueries, diff --git a/src/tools/jina-latechunk.ts b/src/tools/jina-latechunk.ts index abae677..20613b0 100644 --- a/src/tools/jina-latechunk.ts +++ b/src/tools/jina-latechunk.ts @@ -1,25 +1,19 @@ import {TrackerContext} from "../types"; -import axios from 'axios'; -import {JINA_API_KEY} from "../config"; import {Schemas} from "../utils/schemas"; import {cosineSimilarity} from "./cosine"; +import {getEmbeddings} from "./embeddings"; +// Refactored cherryPick function export async function cherryPick(question: string, longContext: string, options: any = {}, trackers: TrackerContext, schemaGen: Schemas, url: string) { - const { snippetLength = 6000, // char length of each snippet numSnippets = Math.max(2, Math.min(5, Math.floor(longContext.length / snippetLength))), chunkSize = 300, // char length of each chunk } = options; - const maxTokensPerRequest = 8192 // Maximum tokens per embedding request - - // Rough estimate of tokens per character (can be adjusted based on your text) - const tokensPerCharacter = 0.4 - if (longContext.length < snippetLength * 2) { // If the context is shorter than the snippet length, return the whole context - console.log('content is too short, dont bother') + console.log('content is too short, dont bother'); return longContext; } @@ -38,106 +32,32 @@ export async function cherryPick(question: string, longContext: string, options: throw new Error('Empty question, returning full context'); } - // Estimate the number of tokens per chunk - const estimatedTokensPerChunk = Math.ceil(chunkSize * tokensPerCharacter); - - // Calculate chunks per batch to stay under token limit - const chunksPerBatch = Math.floor(maxTokensPerRequest / estimatedTokensPerChunk); - - // Create batches of chunks - const chunkBatches = []; - for (let i = 0; i < chunks.length; i += chunksPerBatch) { - chunkBatches.push(chunks.slice(i, i + chunksPerBatch)); - } - - console.log(`Total length ${longContext.length} split ${chunks.length} chunks into ${chunkBatches.length} batches of ~${chunksPerBatch} chunks each`); - - // Process each batch and collect the embeddings - const allChunkEmbeddings: number[][] = []; - let totalTokensUsed = 0; - - for (let batchIndex = 0; batchIndex < chunkBatches.length; batchIndex++) { - const batch = chunkBatches[batchIndex]; - console.log(`Processing batch ${batchIndex + 1}/${chunkBatches.length} with ${batch.length} chunks`); - - // Get embeddings for the current batch - const batchEmbeddingResponse = await axios.post( - 'https://api.jina.ai/v1/embeddings', - { - model: "jina-embeddings-v3", - task: "retrieval.passage", - late_chunking: true, - dimensions: 1024, - embedding_type: "float", - input: batch, - truncate: true - }, - { - headers: { - 'Content-Type': 'application/json', - 'Authorization': `Bearer ${JINA_API_KEY}` - } - } - ); - - if (batchEmbeddingResponse.status !== 200) { - throw new Error(`Unexpected status code from API: ${batchEmbeddingResponse.status}`); - } - - // Validate response structure - if (!batchEmbeddingResponse.data?.data) { - throw new Error("Unexpected API response format"); - } - - // Extract embeddings from this batch - const batchEmbeddings = batchEmbeddingResponse.data.data.map((item: any) => item.embedding); - allChunkEmbeddings.push(...batchEmbeddings); - - // Track token usage - const batchTokens = batchEmbeddingResponse.data.usage?.total_tokens || 0; - totalTokensUsed += batchTokens; - } - - // Get embedding for the question - const questionEmbeddingResponse = await axios.post( - 'https://api.jina.ai/v1/embeddings', + // Get embeddings for all chunks using the new getEmbeddings function + const chunkEmbeddingResult = await getEmbeddings( + chunks, + trackers.tokenTracker, { - model: "jina-embeddings-v3", - task: "retrieval.query", + task: "retrieval.passage", dimensions: 1024, - embedding_type: "float", - input: [question], - truncate: true - }, - { - headers: { - 'Content-Type': 'application/json', - 'Authorization': `Bearer ${JINA_API_KEY}` - } + late_chunking: true, + embedding_type: "float" } ); - if (questionEmbeddingResponse.status !== 200) { - throw new Error("Unexpected status code from API"); - } + const allChunkEmbeddings = chunkEmbeddingResult.embeddings; - // Validate question embedding response - if (!questionEmbeddingResponse.data?.data || !questionEmbeddingResponse.data.data[0]?.embedding) { - throw new Error("Question embedding not found in API response"); - } + // Get embedding for the question + const questionEmbeddingResult = await getEmbeddings( + [question], + trackers.tokenTracker, + { + task: "retrieval.query", + dimensions: 1024, + embedding_type: "float" + } + ); - // Track token usage for question embedding - const questionTokens = questionEmbeddingResponse.data.usage?.total_tokens || 0; - totalTokensUsed += questionTokens; - - // Track total token usage - trackers.tokenTracker.trackUsage('latechunk', { - promptTokens: totalTokensUsed, - completionTokens: 0, - totalTokens: totalTokensUsed - }); - - const questionEmbedding = questionEmbeddingResponse.data.data[0].embedding; + const questionEmbedding = questionEmbeddingResult.embeddings[0]; // Verify that we got embeddings for all chunks if (allChunkEmbeddings.length !== chunks.length) { @@ -199,4 +119,4 @@ ${snippet} // Fallback: just return the beginning of the context up to the desired length return longContext.substring(0, snippetLength * numSnippets); } -} +} \ No newline at end of file diff --git a/src/tools/jina-rerank.ts b/src/tools/jina-rerank.ts index ac679d0..4398c66 100644 --- a/src/tools/jina-rerank.ts +++ b/src/tools/jina-rerank.ts @@ -43,7 +43,7 @@ export async function rerankDocuments( batches.push(documents.slice(i, i + batchSize)); } - console.log(`Processing ${documents.length} documents in ${batches.length} batches of up to ${batchSize} each`); + console.log(`Rerank ${documents.length} documents in ${batches.length} batches of up to ${batchSize} each`); // Process all batches in parallel const batchResults = await Promise.all( diff --git a/src/tools/segment.ts b/src/tools/segment.ts index cc078c4..e1b433d 100644 --- a/src/tools/segment.ts +++ b/src/tools/segment.ts @@ -36,7 +36,7 @@ export async function segmentText( // Process all batches in parallel const batchPromises = batches.map(async (batch, i) => { - console.log(`Processing batch ${i + 1}/${batches.length} (size: ${batch.length})`); + console.log(`[Segment] Processing batch ${i + 1}/${batches.length} (size: ${batch.length})`); try { const {data} = await axios.post( diff --git a/src/types.ts b/src/types.ts index 7e543fa..794bb02 100644 --- a/src/types.ts +++ b/src/types.ts @@ -244,6 +244,9 @@ export interface ChatCompletionRequest { boost_hostnames?: string[]; bad_hostnames?: string[]; only_hostnames?: string[]; + + max_annotations?: number; + min_annotation_relevance?: number; } export interface URLAnnotation {