refactor: build ref with embeddings

This commit is contained in:
Han Xiao 2025-04-15 23:22:48 +08:00
parent 858289a25d
commit b432a27bff
9 changed files with 160 additions and 277 deletions

View File

@ -391,7 +391,9 @@ export async function getResponse(question?: string,
noDirectAnswer: boolean = false,
boostHostnames: string[] = [],
badHostnames: string[] = [],
onlyHostnames: string[] = []
onlyHostnames: string[] = [],
maxRef: number = 10,
minRelScore: number = 0.7
): Promise<{ result: StepAction; context: TrackerContext; visitedURLs: string[], readURLs: string[], allURLs: string[] }> {
let step = 0;
@ -990,7 +992,10 @@ But unfortunately, you failed to solve the issue. You need to think out of the b
answerStep.answer,
allWebContents,
context,
SchemaGen
SchemaGen,
80,
maxRef,
minRelScore
);
answerStep.answer = answer;

View File

@ -558,6 +558,8 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
body.boost_hostnames,
body.bad_hostnames,
body.only_hostnames,
body.max_annotations,
body.min_annotation_relevance
)
let finalAnswer = (finalStep as AnswerAction).mdAnswer;

View File

@ -1,106 +1,16 @@
import {segmentText} from './segment';
import {JinaEmbeddingRequest, JinaEmbeddingResponse, Reference, TrackerContext, WebContent} from "../types";
import {Reference, TrackerContext, WebContent} from "../types";
import {Schemas} from "../utils/schemas";
import axios, {AxiosError} from 'axios';
import {JINA_API_KEY} from "../config";
import {cosineSimilarity, jaccardRank} from "./cosine";
// Maximum number of input texts sent to the embeddings endpoint in one request;
// getEmbeddings slices its input into chunks of this size.
const BATCH_SIZE = 2000;
// Endpoint that getEmbeddings POSTs each batch to.
const API_URL = "https://api.jina.ai/v1/embeddings";
/**
 * Requests `jina-embeddings-v3` "text-matching" embeddings for `texts`, sending
 * them to the Jina API sequentially in batches of at most BATCH_SIZE inputs.
 *
 * The returned `embeddings` array is index-aligned with `texts`. When a batch
 * response does not contain exactly one entry per input, the entries that were
 * returned are placed by their `index` and the missing positions are filled
 * with zero vectors, so later batches never shift out of alignment.
 *
 * @param texts - Strings to embed. An empty array returns immediately without
 *   calling the API.
 * @param tokenTracker - Optional tracker; when provided, `trackUsage` is called
 *   once under the `'embeddings'` key with the accumulated token count.
 * @returns The embeddings and the total tokens reported by the API. If the API
 *   answers with HTTP 402, an empty result (`embeddings: []`, `tokens: 0`) is
 *   returned instead of throwing.
 * @throws Error if `JINA_API_KEY` is not set; any request error other than
 *   HTTP 402 is logged and rethrown.
 */
async function getEmbeddings(
  texts: string[],
  tokenTracker?: any
): Promise<{ embeddings: number[][], tokens: number }> {
  console.log(`[embeddings] Getting embeddings for ${texts.length} texts`);

  if (!JINA_API_KEY) {
    throw new Error('JINA_API_KEY is not set');
  }

  // Handle empty input case
  if (texts.length === 0) {
    return {embeddings: [], tokens: 0};
  }

  // Requested vector size; also used as the length of zero-vector placeholders.
  const dimensions = 1024;

  const allEmbeddings: number[][] = [];
  let totalTokens = 0;
  const batchCount = Math.ceil(texts.length / BATCH_SIZE);

  for (let i = 0; i < texts.length; i += BATCH_SIZE) {
    const batchTexts = texts.slice(i, i + BATCH_SIZE);
    const currentBatch = Math.floor(i / BATCH_SIZE) + 1;
    console.log(`[embeddings] Processing batch ${currentBatch}/${batchCount} (${batchTexts.length} texts)`);

    const request: JinaEmbeddingRequest = {
      model: "jina-embeddings-v3",
      task: "text-matching",
      late_chunking: false, // Late chunking turned off always
      dimensions,
      embedding_type: "float",
      input: batchTexts,
      truncate: true
    };

    try {
      const response = await axios.post<JinaEmbeddingResponse>(
        API_URL,
        request,
        {
          headers: {
            "Content-Type": "application/json",
            "Authorization": `Bearer ${JINA_API_KEY}`
          }
        }
      );

      const items = response.data.data ?? [];
      let batchEmbeddings: number[][];

      if (items.length === batchTexts.length) {
        // Sort a copy by index to restore input order without mutating the response.
        batchEmbeddings = [...items]
          .sort((a, b) => a.index - b.index)
          .map(item => item.embedding);
      } else {
        // Skipping the batch here would shift every later embedding relative to
        // `texts`, so keep what was returned and pad the gaps with zero vectors.
        console.error('Invalid response from Jina API:', response.data);
        const byIndex = new Map<number, number[]>();
        for (const item of items) {
          byIndex.set(item.index, item.embedding);
        }
        const dimensionSize = items[0]?.embedding?.length || dimensions;
        batchEmbeddings = batchTexts.map(
          (_, idx) => byIndex.get(idx) ?? new Array<number>(dimensionSize).fill(0)
        );
      }

      allEmbeddings.push(...batchEmbeddings);

      // `usage` may be absent on a malformed response; count it as zero tokens.
      const batchTokens = response.data.usage?.total_tokens ?? 0;
      totalTokens += batchTokens;
      console.log(`[embeddings] Batch ${currentBatch} complete. Tokens used: ${batchTokens}, total so far: ${totalTokens}`);
    } catch (error) {
      console.error('Error calling Jina Embeddings API:', error);
      // HTTP 402 is answered with an empty result rather than an exception.
      if (error instanceof AxiosError && error.response?.status === 402) {
        return {embeddings: [], tokens: 0};
      }
      throw error;
    }
  }

  // Track token usage if tracker is provided
  if (tokenTracker) {
    tokenTracker.trackUsage('embeddings', {
      promptTokens: totalTokens,
      completionTokens: 0,
      totalTokens: totalTokens
    });
  }

  console.log(`[embeddings] Complete. Generated ${allEmbeddings.length} embeddings using ${totalTokens} tokens`);
  return {embeddings: allEmbeddings, tokens: totalTokens};
}
import {getEmbeddings} from "./embeddings";
export async function buildReferences(
answer: string,
webContents: Record<string, WebContent>,
context: TrackerContext,
schema: Schemas,
maxRef: number = 10,
minChunkLength: number = 80,
maxRef: number = 10,
minRelScore: number = 0.7
): Promise<{ answer: string, references: Array<Reference> }> {
console.log(`[buildReferences] Starting with maxRef=${maxRef}, minChunkLength=${minChunkLength}, minRelScore=${minRelScore}`);

118
src/tools/embeddings.ts Normal file
View File

@ -0,0 +1,118 @@
import {JINA_API_KEY} from "../config";
import {JinaEmbeddingRequest, JinaEmbeddingResponse} from "../types";
import axios, {AxiosError} from "axios";
// Maximum number of input texts sent to the embeddings endpoint in one request;
// getEmbeddings slices its input into chunks of this size.
const BATCH_SIZE = 128;
// Endpoint that getEmbeddings POSTs each batch to.
const API_URL = "https://api.jina.ai/v1/embeddings";
/**
 * Arranges the items of one batch response so that position `i` of the result
 * holds the embedding for `batchTexts[i]`.
 *
 * When the response has exactly one item per input, the items are ordered by
 * their `index`. Otherwise the mismatch is logged, returned items are placed by
 * `index`, and every missing position is filled with a zero vector.
 */
function orderBatchEmbeddings(
  items: ReadonlyArray<{ index: number; embedding: number[] }> | undefined,
  batchTexts: string[],
  fallbackDimensions: number
): number[][] {
  if (items && items.length === batchTexts.length) {
    // All indices present: sort a copy so the response payload is not reordered in place.
    return [...items]
      .sort((a, b) => a.index - b.index)
      .map(item => item.embedding);
  }

  console.error('Invalid response from Jina API:', items?.length, batchTexts.length);

  // Single pass to index the returned items; the first item wins on a duplicate index.
  const byIndex = new Map<number, number[]>();
  for (const item of items ?? []) {
    if (!byIndex.has(item.index)) {
      byIndex.set(item.index, item.embedding);
    }
  }

  // Prefer the size of a vector that actually came back, then the caller's fallback.
  const dimensionSize = items?.[0]?.embedding?.length || fallbackDimensions;

  return batchTexts.map((text, idx) => {
    const embedding = byIndex.get(idx);
    if (embedding !== undefined) {
      return embedding;
    }
    // Create a zero vector for the missing index
    console.error(`Missing embedding for index ${idx}: [${text}]`);
    return new Array<number>(dimensionSize).fill(0);
  });
}

/**
 * Requests `jina-embeddings-v3` embeddings for `texts`, sending them to the
 * Jina API sequentially in batches of at most BATCH_SIZE inputs.
 *
 * The returned `embeddings` array is index-aligned with `texts`; positions the
 * API failed to return are filled with zero vectors.
 *
 * @param texts - Strings to embed. An empty array returns immediately without
 *   calling the API.
 * @param tokenTracker - Optional tracker; when provided, `trackUsage` is called
 *   once under the `'embeddings'` key with the accumulated token count.
 * @param options - Optional request settings. `task` defaults to
 *   `"text-matching"`; `dimensions`, `late_chunking` and `embedding_type` are
 *   only added to the request when they are truthy.
 * @returns The embeddings and the total tokens reported by the API. If the API
 *   answers with HTTP 402, an empty result (`embeddings: []`, `tokens: 0`) is
 *   returned instead of throwing.
 * @throws Error if `JINA_API_KEY` is not set; any request error other than
 *   HTTP 402 is logged and rethrown.
 */
export async function getEmbeddings(
  texts: string[],
  tokenTracker?: any,
  options: {
    task?: "text-matching" | "retrieval.passage" | "retrieval.query",
    dimensions?: number,
    late_chunking?: boolean,
    embedding_type?: string
  } = {}
): Promise<{ embeddings: number[][], tokens: number }> {
  console.log(`[embeddings] Getting embeddings for ${texts.length} texts`);

  if (!JINA_API_KEY) {
    throw new Error('JINA_API_KEY is not set');
  }

  // Handle empty input case
  if (texts.length === 0) {
    return {embeddings: [], tokens: 0};
  }

  // Process in batches
  const allEmbeddings: number[][] = [];
  let totalTokens = 0;
  const batchCount = Math.ceil(texts.length / BATCH_SIZE);

  for (let i = 0; i < texts.length; i += BATCH_SIZE) {
    const batchTexts = texts.slice(i, i + BATCH_SIZE);
    const currentBatch = Math.floor(i / BATCH_SIZE) + 1;
    console.log(`[embeddings] Processing batch ${currentBatch}/${batchCount} (${batchTexts.length} texts)`);

    const request: JinaEmbeddingRequest = {
      model: "jina-embeddings-v3",
      task: options.task ?? "text-matching",
      input: batchTexts,
      truncate: true
    };

    // Add optional parameters if provided
    if (options.dimensions) request.dimensions = options.dimensions;
    if (options.late_chunking) request.late_chunking = options.late_chunking;
    if (options.embedding_type) request.embedding_type = options.embedding_type;

    try {
      const response = await axios.post<JinaEmbeddingResponse>(
        API_URL,
        request,
        {
          headers: {
            "Content-Type": "application/json",
            "Authorization": `Bearer ${JINA_API_KEY}`
          }
        }
      );

      // Prepare embeddings, handling any missing indices
      const batchEmbeddings = orderBatchEmbeddings(
        response.data.data,
        batchTexts,
        options.dimensions || 1024
      );
      allEmbeddings.push(...batchEmbeddings);

      const batchTokens = response.data.usage?.total_tokens ?? 0;
      totalTokens += batchTokens;
      console.log(`[embeddings] Batch ${currentBatch} complete. Tokens used: ${batchTokens}, total so far: ${totalTokens}`);
    } catch (error) {
      console.error('Error calling Jina Embeddings API:', error);
      // HTTP 402 is answered with an empty result rather than an exception.
      // NOTE: this returns before the tracker call below, so tokens counted for
      // earlier batches of the same call are not reported on this path.
      if (error instanceof AxiosError && error.response?.status === 402) {
        return {embeddings: [], tokens: 0};
      }
      throw error;
    }
  }

  // Track token usage if tracker is provided
  if (tokenTracker) {
    tokenTracker.trackUsage('embeddings', {
      promptTokens: totalTokens,
      completionTokens: 0,
      totalTokens: totalTokens
    });
  }

  console.log(`[embeddings] Complete. Generated ${allEmbeddings.length} embeddings using ${totalTokens} tokens`);
  return {embeddings: allEmbeddings, tokens: totalTokens};
}

View File

@ -1,77 +1,9 @@
import axios, {AxiosError} from 'axios';
import {TokenTracker} from "../utils/token-tracker";
import {JINA_API_KEY} from "../config";
import {cosineSimilarity} from "./cosine";
import {JinaEmbeddingRequest, JinaEmbeddingResponse} from "../types";
import {getEmbeddings} from "./embeddings";
// Endpoint that this module's getEmbeddings POSTs to.
const JINA_API_URL = 'https://api.jina.ai/v1/embeddings';
const SIMILARITY_THRESHOLD = 0.86; // Adjustable threshold for cosine similarity
// Fixed request parameters copied into every embeddings request built by this
// module's getEmbeddings (model, task, dimensions, embedding_type, late_chunking).
const JINA_API_CONFIG = {
MODEL: 'jina-embeddings-v3',
TASK: 'text-matching',
DIMENSIONS: 1024,
EMBEDDING_TYPE: 'float',
LATE_CHUNKING: false
} as const;
/**
 * Embeds every query with a single call to the Jina embeddings endpoint, using
 * the fixed parameters in JINA_API_CONFIG.
 *
 * @param queries - Query strings to embed, all sent in one request.
 * @returns Embeddings ordered by the `index` reported for each item, plus the
 *   token count reported by the API. An empty result (`embeddings: []`,
 *   `tokens: 0`) is returned when the response does not contain exactly one
 *   item per query, or when the request fails with HTTP 402.
 * @throws Error if `JINA_API_KEY` is not set; any failure other than HTTP 402
 *   is logged and rethrown.
 */
async function getEmbeddings(queries: string[]): Promise<{ embeddings: number[][], tokens: number }> {
  if (!JINA_API_KEY) {
    throw new Error('JINA_API_KEY is not set');
  }

  // Shared "nothing usable" result for the soft-failure paths.
  const emptyResult = (): { embeddings: number[][], tokens: number } => ({
    embeddings: [],
    tokens: 0
  });

  const payload: JinaEmbeddingRequest = {
    model: JINA_API_CONFIG.MODEL,
    task: JINA_API_CONFIG.TASK,
    late_chunking: JINA_API_CONFIG.LATE_CHUNKING,
    dimensions: JINA_API_CONFIG.DIMENSIONS,
    embedding_type: JINA_API_CONFIG.EMBEDDING_TYPE,
    input: queries
  };

  const requestConfig = {
    headers: {
      'Content-Type': 'application/json',
      'Authorization': `Bearer ${JINA_API_KEY}`
    }
  };

  try {
    const {data: body} = await axios.post<JinaEmbeddingResponse>(
      JINA_API_URL,
      payload,
      requestConfig
    );

    // The response must carry exactly one item per query.
    const items = body.data;
    const countMatches = Boolean(items) && items.length === queries.length;
    if (!countMatches) {
      console.error('Invalid response from Jina API:', body);
      return emptyResult();
    }

    // Restore the original query order before extracting the vectors.
    items.sort((left, right) => left.index - right.index);
    const vectors = items.map(entry => entry.embedding);

    return {
      embeddings: vectors,
      tokens: body.usage.total_tokens
    };
  } catch (err) {
    console.error('Error getting embeddings from Jina:', err);

    const isPaymentRequired = err instanceof AxiosError && err.response?.status === 402;
    if (!isPaymentRequired) {
      throw err;
    }
    return emptyResult();
  }
}
export async function dedupQueries(
newQueries: string[],
@ -88,7 +20,7 @@ export async function dedupQueries(
// Get embeddings for all queries in one batch
const allQueries = [...newQueries, ...existingQueries];
const {embeddings: allEmbeddings, tokens} = await getEmbeddings(allQueries);
const {embeddings: allEmbeddings} = await getEmbeddings(allQueries, tracker);
// If embeddings is empty (due to 402 error), return all new queries
if (!allEmbeddings.length) {
@ -134,13 +66,6 @@ export async function dedupQueries(
usedIndices.add(i);
}
}
// Track token usage from the API
(tracker || new TokenTracker()).trackUsage('dedup', {
promptTokens: 0,
completionTokens: tokens,
totalTokens: tokens
});
console.log('Dedup:', uniqueQueries);
return {
unique_queries: uniqueQueries,

View File

@ -1,25 +1,19 @@
import {TrackerContext} from "../types";
import axios from 'axios';
import {JINA_API_KEY} from "../config";
import {Schemas} from "../utils/schemas";
import {cosineSimilarity} from "./cosine";
import {getEmbeddings} from "./embeddings";
// Refactored cherryPick function
export async function cherryPick(question: string, longContext: string, options: any = {}, trackers: TrackerContext, schemaGen: Schemas, url: string) {
const {
snippetLength = 6000, // char length of each snippet
numSnippets = Math.max(2, Math.min(5, Math.floor(longContext.length / snippetLength))),
chunkSize = 300, // char length of each chunk
} = options;
const maxTokensPerRequest = 8192 // Maximum tokens per embedding request
// Rough estimate of tokens per character (can be adjusted based on your text)
const tokensPerCharacter = 0.4
if (longContext.length < snippetLength * 2) {
// If the context is shorter than the snippet length, return the whole context
console.log('content is too short, dont bother')
console.log('content is too short, dont bother');
return longContext;
}
@ -38,106 +32,32 @@ export async function cherryPick(question: string, longContext: string, options:
throw new Error('Empty question, returning full context');
}
// Estimate the number of tokens per chunk
const estimatedTokensPerChunk = Math.ceil(chunkSize * tokensPerCharacter);
// Calculate chunks per batch to stay under token limit
const chunksPerBatch = Math.floor(maxTokensPerRequest / estimatedTokensPerChunk);
// Create batches of chunks
const chunkBatches = [];
for (let i = 0; i < chunks.length; i += chunksPerBatch) {
chunkBatches.push(chunks.slice(i, i + chunksPerBatch));
}
console.log(`Total length ${longContext.length} split ${chunks.length} chunks into ${chunkBatches.length} batches of ~${chunksPerBatch} chunks each`);
// Process each batch and collect the embeddings
const allChunkEmbeddings: number[][] = [];
let totalTokensUsed = 0;
for (let batchIndex = 0; batchIndex < chunkBatches.length; batchIndex++) {
const batch = chunkBatches[batchIndex];
console.log(`Processing batch ${batchIndex + 1}/${chunkBatches.length} with ${batch.length} chunks`);
// Get embeddings for the current batch
const batchEmbeddingResponse = await axios.post(
'https://api.jina.ai/v1/embeddings',
{
model: "jina-embeddings-v3",
task: "retrieval.passage",
late_chunking: true,
dimensions: 1024,
embedding_type: "float",
input: batch,
truncate: true
},
{
headers: {
'Content-Type': 'application/json',
'Authorization': `Bearer ${JINA_API_KEY}`
}
}
);
if (batchEmbeddingResponse.status !== 200) {
throw new Error(`Unexpected status code from API: ${batchEmbeddingResponse.status}`);
}
// Validate response structure
if (!batchEmbeddingResponse.data?.data) {
throw new Error("Unexpected API response format");
}
// Extract embeddings from this batch
const batchEmbeddings = batchEmbeddingResponse.data.data.map((item: any) => item.embedding);
allChunkEmbeddings.push(...batchEmbeddings);
// Track token usage
const batchTokens = batchEmbeddingResponse.data.usage?.total_tokens || 0;
totalTokensUsed += batchTokens;
}
// Get embedding for the question
const questionEmbeddingResponse = await axios.post(
'https://api.jina.ai/v1/embeddings',
// Get embeddings for all chunks using the new getEmbeddings function
const chunkEmbeddingResult = await getEmbeddings(
chunks,
trackers.tokenTracker,
{
model: "jina-embeddings-v3",
task: "retrieval.query",
task: "retrieval.passage",
dimensions: 1024,
embedding_type: "float",
input: [question],
truncate: true
},
{
headers: {
'Content-Type': 'application/json',
'Authorization': `Bearer ${JINA_API_KEY}`
}
late_chunking: true,
embedding_type: "float"
}
);
if (questionEmbeddingResponse.status !== 200) {
throw new Error("Unexpected status code from API");
}
const allChunkEmbeddings = chunkEmbeddingResult.embeddings;
// Validate question embedding response
if (!questionEmbeddingResponse.data?.data || !questionEmbeddingResponse.data.data[0]?.embedding) {
throw new Error("Question embedding not found in API response");
}
// Get embedding for the question
const questionEmbeddingResult = await getEmbeddings(
[question],
trackers.tokenTracker,
{
task: "retrieval.query",
dimensions: 1024,
embedding_type: "float"
}
);
// Track token usage for question embedding
const questionTokens = questionEmbeddingResponse.data.usage?.total_tokens || 0;
totalTokensUsed += questionTokens;
// Track total token usage
trackers.tokenTracker.trackUsage('latechunk', {
promptTokens: totalTokensUsed,
completionTokens: 0,
totalTokens: totalTokensUsed
});
const questionEmbedding = questionEmbeddingResponse.data.data[0].embedding;
const questionEmbedding = questionEmbeddingResult.embeddings[0];
// Verify that we got embeddings for all chunks
if (allChunkEmbeddings.length !== chunks.length) {
@ -199,4 +119,4 @@ ${snippet}
// Fallback: just return the beginning of the context up to the desired length
return longContext.substring(0, snippetLength * numSnippets);
}
}
}

View File

@ -43,7 +43,7 @@ export async function rerankDocuments(
batches.push(documents.slice(i, i + batchSize));
}
console.log(`Processing ${documents.length} documents in ${batches.length} batches of up to ${batchSize} each`);
console.log(`Rerank ${documents.length} documents in ${batches.length} batches of up to ${batchSize} each`);
// Process all batches in parallel
const batchResults = await Promise.all(

View File

@ -36,7 +36,7 @@ export async function segmentText(
// Process all batches in parallel
const batchPromises = batches.map(async (batch, i) => {
console.log(`Processing batch ${i + 1}/${batches.length} (size: ${batch.length})`);
console.log(`[Segment] Processing batch ${i + 1}/${batches.length} (size: ${batch.length})`);
try {
const {data} = await axios.post(

View File

@ -244,6 +244,9 @@ export interface ChatCompletionRequest {
boost_hostnames?: string[];
bad_hostnames?: string[];
only_hostnames?: string[];
max_annotations?: number;
min_annotation_relevance?: number;
}
export interface URLAnnotation {