mirror of
https://github.com/jina-ai/node-DeepResearch.git
synced 2025-12-26 06:28:56 +08:00
refactor: build ref with embeddings
This commit is contained in:
parent
858289a25d
commit
b432a27bff
@ -391,7 +391,9 @@ export async function getResponse(question?: string,
|
||||
noDirectAnswer: boolean = false,
|
||||
boostHostnames: string[] = [],
|
||||
badHostnames: string[] = [],
|
||||
onlyHostnames: string[] = []
|
||||
onlyHostnames: string[] = [],
|
||||
maxRef: number = 10,
|
||||
minRelScore: number = 0.7
|
||||
): Promise<{ result: StepAction; context: TrackerContext; visitedURLs: string[], readURLs: string[], allURLs: string[] }> {
|
||||
|
||||
let step = 0;
|
||||
@ -990,7 +992,10 @@ But unfortunately, you failed to solve the issue. You need to think out of the b
|
||||
answerStep.answer,
|
||||
allWebContents,
|
||||
context,
|
||||
SchemaGen
|
||||
SchemaGen,
|
||||
80,
|
||||
maxRef,
|
||||
minRelScore
|
||||
);
|
||||
|
||||
answerStep.answer = answer;
|
||||
|
||||
@ -558,6 +558,8 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
|
||||
body.boost_hostnames,
|
||||
body.bad_hostnames,
|
||||
body.only_hostnames,
|
||||
body.max_annotations,
|
||||
body.min_annotation_relevance
|
||||
)
|
||||
let finalAnswer = (finalStep as AnswerAction).mdAnswer;
|
||||
|
||||
|
||||
@ -1,106 +1,16 @@
|
||||
import {segmentText} from './segment';
|
||||
import {JinaEmbeddingRequest, JinaEmbeddingResponse, Reference, TrackerContext, WebContent} from "../types";
|
||||
import {Reference, TrackerContext, WebContent} from "../types";
|
||||
import {Schemas} from "../utils/schemas";
|
||||
import axios, {AxiosError} from 'axios';
|
||||
import {JINA_API_KEY} from "../config";
|
||||
import {cosineSimilarity, jaccardRank} from "./cosine";
|
||||
|
||||
const BATCH_SIZE = 2000;
|
||||
const API_URL = "https://api.jina.ai/v1/embeddings";
|
||||
|
||||
// Simplified function to get embeddings in a single request
|
||||
async function getEmbeddings(
|
||||
texts: string[],
|
||||
tokenTracker?: any
|
||||
): Promise<{ embeddings: number[][], tokens: number }> {
|
||||
console.log(`[embeddings] Getting embeddings for ${texts.length} texts`);
|
||||
|
||||
if (!JINA_API_KEY) {
|
||||
throw new Error('JINA_API_KEY is not set');
|
||||
}
|
||||
|
||||
// Handle empty input case
|
||||
if (texts.length === 0) {
|
||||
return {embeddings: [], tokens: 0};
|
||||
}
|
||||
|
||||
// Process in batches of 2000
|
||||
const allEmbeddings: number[][] = [];
|
||||
let totalTokens = 0;
|
||||
const batchCount = Math.ceil(texts.length / BATCH_SIZE);
|
||||
|
||||
for (let i = 0; i < texts.length; i += BATCH_SIZE) {
|
||||
const batchTexts = texts.slice(i, i + BATCH_SIZE);
|
||||
const currentBatch = Math.floor(i / BATCH_SIZE) + 1;
|
||||
console.log(`[embeddings] Processing batch ${currentBatch}/${batchCount} (${batchTexts.length} texts)`);
|
||||
|
||||
const request: JinaEmbeddingRequest = {
|
||||
model: "jina-embeddings-v3",
|
||||
task: "text-matching",
|
||||
late_chunking: false, // Late chunking turned off always
|
||||
dimensions: 1024,
|
||||
embedding_type: "float",
|
||||
input: batchTexts,
|
||||
truncate: true
|
||||
};
|
||||
|
||||
try {
|
||||
const response = await axios.post<JinaEmbeddingResponse>(
|
||||
API_URL,
|
||||
request,
|
||||
{
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": `Bearer ${JINA_API_KEY}`
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
// Validate response format
|
||||
if (!response.data.data || response.data.data.length !== batchTexts.length) {
|
||||
console.error('Invalid response from Jina API:', response.data);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Sort embeddings by index to maintain original order
|
||||
const batchEmbeddings = response.data.data
|
||||
.sort((a, b) => a.index - b.index)
|
||||
.map(item => item.embedding);
|
||||
|
||||
allEmbeddings.push(...batchEmbeddings);
|
||||
totalTokens += response.data.usage.total_tokens;
|
||||
console.log(`[embeddings] Batch ${currentBatch} complete. Tokens used: ${response.data.usage.total_tokens}, total so far: ${totalTokens}`);
|
||||
|
||||
} catch (error) {
|
||||
console.error('Error calling Jina Embeddings API:', error);
|
||||
if (error instanceof AxiosError && error.response?.status === 402) {
|
||||
return {embeddings: [], tokens: 0};
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
// Track token usage if tracker is provided
|
||||
if (tokenTracker) {
|
||||
tokenTracker.trackUsage('embeddings', {
|
||||
promptTokens: totalTokens,
|
||||
completionTokens: 0,
|
||||
totalTokens: totalTokens
|
||||
});
|
||||
}
|
||||
|
||||
console.log(`[embeddings] Complete. Generated ${allEmbeddings.length} embeddings using ${totalTokens} tokens`);
|
||||
return {embeddings: allEmbeddings, tokens: totalTokens};
|
||||
}
|
||||
|
||||
import {getEmbeddings} from "./embeddings";
|
||||
|
||||
export async function buildReferences(
|
||||
answer: string,
|
||||
webContents: Record<string, WebContent>,
|
||||
context: TrackerContext,
|
||||
schema: Schemas,
|
||||
maxRef: number = 10,
|
||||
minChunkLength: number = 80,
|
||||
maxRef: number = 10,
|
||||
minRelScore: number = 0.7
|
||||
): Promise<{ answer: string, references: Array<Reference> }> {
|
||||
console.log(`[buildReferences] Starting with maxRef=${maxRef}, minChunkLength=${minChunkLength}, minRelScore=${minRelScore}`);
|
||||
|
||||
118
src/tools/embeddings.ts
Normal file
118
src/tools/embeddings.ts
Normal file
@ -0,0 +1,118 @@
|
||||
import {JINA_API_KEY} from "../config";
|
||||
import {JinaEmbeddingRequest, JinaEmbeddingResponse} from "../types";
|
||||
import axios, {AxiosError} from "axios";
|
||||
|
||||
const BATCH_SIZE = 128;
|
||||
const API_URL = "https://api.jina.ai/v1/embeddings";
|
||||
|
||||
/**
 * Requests embeddings for `texts` from the Jina Embeddings API (`jina-embeddings-v3`).
 *
 * Inputs are sent in sequential batches of at most `BATCH_SIZE` texts. The
 * returned `embeddings` array is aligned with `texts`: entry `i` is the vector
 * for `texts[i]`. When the API omits an entry (or returns one without a
 * vector) a zero vector is substituted so the alignment is preserved.
 *
 * @param texts        Strings to embed. An empty array short-circuits without any request.
 * @param tokenTracker Optional tracker; when given, `trackUsage('embeddings', …)` is called
 *                     with the total tokens consumed by this call.
 * @param options      Optional request tuning. `task` defaults to `"text-matching"`;
 *                     `dimensions`, `late_chunking` and `embedding_type` are only sent
 *                     when truthy, otherwise the API defaults apply.
 * @returns The aligned embeddings and the number of tokens consumed. On an HTTP 402
 *          response `embeddings` is empty so callers can fall back gracefully.
 * @throws Error when `JINA_API_KEY` is not set; rethrows any request error other than HTTP 402.
 */
export async function getEmbeddings(
  texts: string[],
  tokenTracker?: any,
  options: {
    task?: "text-matching" | "retrieval.passage" | "retrieval.query",
    dimensions?: number,
    late_chunking?: boolean,
    embedding_type?: string
  } = {}
): Promise<{ embeddings: number[][], tokens: number }> {
  console.log(`[embeddings] Getting embeddings for ${texts.length} texts`);

  if (!JINA_API_KEY) {
    throw new Error('JINA_API_KEY is not set');
  }

  // Handle empty input case
  if (texts.length === 0) {
    return {embeddings: [], tokens: 0};
  }

  const allEmbeddings: number[][] = [];
  let totalTokens = 0;
  const batchCount = Math.ceil(texts.length / BATCH_SIZE);

  // Single place that reports consumed tokens, used on both success and failure paths.
  const reportUsage = (): void => {
    if (tokenTracker) {
      tokenTracker.trackUsage('embeddings', {
        promptTokens: totalTokens,
        completionTokens: 0,
        totalTokens: totalTokens
      });
    }
  };

  // NOTE(review): with `late_chunking` enabled the API may impose a combined
  // token limit per request; BATCH_SIZE is count-based only — confirm against the API docs.
  for (let i = 0; i < texts.length; i += BATCH_SIZE) {
    const batchTexts = texts.slice(i, i + BATCH_SIZE);
    const currentBatch = Math.floor(i / BATCH_SIZE) + 1;
    console.log(`[embeddings] Processing batch ${currentBatch}/${batchCount} (${batchTexts.length} texts)`);

    const request: JinaEmbeddingRequest = {
      model: "jina-embeddings-v3",
      task: options.task || "text-matching",
      input: batchTexts,
      truncate: true
    };

    // Add optional parameters if provided
    if (options.dimensions) request.dimensions = options.dimensions;
    if (options.late_chunking) request.late_chunking = options.late_chunking;
    if (options.embedding_type) request.embedding_type = options.embedding_type;

    try {
      const response = await axios.post<JinaEmbeddingResponse>(
        API_URL,
        request,
        {
          headers: {
            "Content-Type": "application/json",
            "Authorization": `Bearer ${JINA_API_KEY}`
          }
        }
      );

      const items = response.data.data ?? [];
      if (items.length !== batchTexts.length) {
        console.error('Invalid response from Jina API:', response.data.data?.length, batchTexts.length);
      }

      // Key every returned vector by its declared index (first occurrence wins).
      // Falls back to the array position if an item carries no numeric index.
      const vectorByIndex = new Map<number, number[]>();
      items.forEach((item, position) => {
        const slot = typeof item.index === 'number' ? item.index : position;
        if (!vectorByIndex.has(slot) && Array.isArray(item.embedding)) {
          vectorByIndex.set(slot, item.embedding);
        }
      });

      // Width used for placeholder vectors when an entry is missing.
      const dimensionSize = items[0]?.embedding?.length || options.dimensions || 1024;

      // Emit exactly one vector per input, in input order, without mutating the response.
      const batchEmbeddings: number[][] = batchTexts.map((text, idx) => {
        const vector = vectorByIndex.get(idx);
        if (vector) {
          return vector;
        }
        console.error(`Missing embedding for index ${idx}: [${text}]`);
        return new Array<number>(dimensionSize).fill(0);
      });

      allEmbeddings.push(...batchEmbeddings);

      const batchTokens = response.data.usage?.total_tokens || 0;
      totalTokens += batchTokens;
      console.log(`[embeddings] Batch ${currentBatch} complete. Tokens used: ${batchTokens}, total so far: ${totalTokens}`);

    } catch (error) {
      console.error('Error calling Jina Embeddings API:', error);

      // Earlier batches may already have consumed tokens; report them before bailing out
      // so usage accounting does not silently drop them.
      if (totalTokens > 0) {
        reportUsage();
      }

      // 402 = insufficient balance: signal "no embeddings" instead of failing the caller.
      if (error instanceof AxiosError && error.response?.status === 402) {
        return {embeddings: [], tokens: totalTokens};
      }
      throw error;
    }
  }

  // Track token usage if tracker is provided
  reportUsage();

  console.log(`[embeddings] Complete. Generated ${allEmbeddings.length} embeddings using ${totalTokens} tokens`);
  return {embeddings: allEmbeddings, tokens: totalTokens};
}
|
||||
@ -1,77 +1,9 @@
|
||||
import axios, {AxiosError} from 'axios';
|
||||
import {TokenTracker} from "../utils/token-tracker";
|
||||
import {JINA_API_KEY} from "../config";
|
||||
import {cosineSimilarity} from "./cosine";
|
||||
import {JinaEmbeddingRequest, JinaEmbeddingResponse} from "../types";
|
||||
import {getEmbeddings} from "./embeddings";
|
||||
|
||||
const JINA_API_URL = 'https://api.jina.ai/v1/embeddings';
|
||||
const SIMILARITY_THRESHOLD = 0.86; // Adjustable threshold for cosine similarity
|
||||
|
||||
const JINA_API_CONFIG = {
|
||||
MODEL: 'jina-embeddings-v3',
|
||||
TASK: 'text-matching',
|
||||
DIMENSIONS: 1024,
|
||||
EMBEDDING_TYPE: 'float',
|
||||
LATE_CHUNKING: false
|
||||
} as const;
|
||||
|
||||
|
||||
// Get embeddings for all queries in one batch
|
||||
async function getEmbeddings(queries: string[]): Promise<{ embeddings: number[][], tokens: number }> {
|
||||
if (!JINA_API_KEY) {
|
||||
throw new Error('JINA_API_KEY is not set');
|
||||
}
|
||||
|
||||
const request: JinaEmbeddingRequest = {
|
||||
model: JINA_API_CONFIG.MODEL,
|
||||
task: JINA_API_CONFIG.TASK,
|
||||
late_chunking: JINA_API_CONFIG.LATE_CHUNKING,
|
||||
dimensions: JINA_API_CONFIG.DIMENSIONS,
|
||||
embedding_type: JINA_API_CONFIG.EMBEDDING_TYPE,
|
||||
input: queries
|
||||
};
|
||||
|
||||
try {
|
||||
const response = await axios.post<JinaEmbeddingResponse>(
|
||||
JINA_API_URL,
|
||||
request,
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': `Bearer ${JINA_API_KEY}`
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
// Validate response format
|
||||
if (!response.data.data || response.data.data.length !== queries.length) {
|
||||
console.error('Invalid response from Jina API:', response.data);
|
||||
return {
|
||||
embeddings: [],
|
||||
tokens: 0
|
||||
};
|
||||
}
|
||||
|
||||
// Sort embeddings by index to maintain original order
|
||||
const embeddings = response.data.data
|
||||
.sort((a, b) => a.index - b.index)
|
||||
.map(item => item.embedding);
|
||||
|
||||
return {
|
||||
embeddings,
|
||||
tokens: response.data.usage.total_tokens
|
||||
};
|
||||
} catch (error) {
|
||||
console.error('Error getting embeddings from Jina:', error);
|
||||
if (error instanceof AxiosError && error.response?.status === 402) {
|
||||
return {
|
||||
embeddings: [],
|
||||
tokens: 0
|
||||
};
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
export async function dedupQueries(
|
||||
newQueries: string[],
|
||||
@ -88,7 +20,7 @@ export async function dedupQueries(
|
||||
|
||||
// Get embeddings for all queries in one batch
|
||||
const allQueries = [...newQueries, ...existingQueries];
|
||||
const {embeddings: allEmbeddings, tokens} = await getEmbeddings(allQueries);
|
||||
const {embeddings: allEmbeddings} = await getEmbeddings(allQueries, tracker);
|
||||
|
||||
// If embeddings is empty (due to 402 error), return all new queries
|
||||
if (!allEmbeddings.length) {
|
||||
@ -134,13 +66,6 @@ export async function dedupQueries(
|
||||
usedIndices.add(i);
|
||||
}
|
||||
}
|
||||
|
||||
// Track token usage from the API
|
||||
(tracker || new TokenTracker()).trackUsage('dedup', {
|
||||
promptTokens: 0,
|
||||
completionTokens: tokens,
|
||||
totalTokens: tokens
|
||||
});
|
||||
console.log('Dedup:', uniqueQueries);
|
||||
return {
|
||||
unique_queries: uniqueQueries,
|
||||
|
||||
@ -1,25 +1,19 @@
|
||||
import {TrackerContext} from "../types";
|
||||
import axios from 'axios';
|
||||
import {JINA_API_KEY} from "../config";
|
||||
import {Schemas} from "../utils/schemas";
|
||||
import {cosineSimilarity} from "./cosine";
|
||||
import {getEmbeddings} from "./embeddings";
|
||||
|
||||
// Refactored cherryPick function
|
||||
export async function cherryPick(question: string, longContext: string, options: any = {}, trackers: TrackerContext, schemaGen: Schemas, url: string) {
|
||||
|
||||
const {
|
||||
snippetLength = 6000, // char length of each snippet
|
||||
numSnippets = Math.max(2, Math.min(5, Math.floor(longContext.length / snippetLength))),
|
||||
chunkSize = 300, // char length of each chunk
|
||||
} = options;
|
||||
|
||||
const maxTokensPerRequest = 8192 // Maximum tokens per embedding request
|
||||
|
||||
// Rough estimate of tokens per character (can be adjusted based on your text)
|
||||
const tokensPerCharacter = 0.4
|
||||
|
||||
if (longContext.length < snippetLength * 2) {
|
||||
// If the context is shorter than the snippet length, return the whole context
|
||||
console.log('content is too short, dont bother')
|
||||
console.log('content is too short, dont bother');
|
||||
return longContext;
|
||||
}
|
||||
|
||||
@ -38,106 +32,32 @@ export async function cherryPick(question: string, longContext: string, options:
|
||||
throw new Error('Empty question, returning full context');
|
||||
}
|
||||
|
||||
// Estimate the number of tokens per chunk
|
||||
const estimatedTokensPerChunk = Math.ceil(chunkSize * tokensPerCharacter);
|
||||
|
||||
// Calculate chunks per batch to stay under token limit
|
||||
const chunksPerBatch = Math.floor(maxTokensPerRequest / estimatedTokensPerChunk);
|
||||
|
||||
// Create batches of chunks
|
||||
const chunkBatches = [];
|
||||
for (let i = 0; i < chunks.length; i += chunksPerBatch) {
|
||||
chunkBatches.push(chunks.slice(i, i + chunksPerBatch));
|
||||
}
|
||||
|
||||
console.log(`Total length ${longContext.length} split ${chunks.length} chunks into ${chunkBatches.length} batches of ~${chunksPerBatch} chunks each`);
|
||||
|
||||
// Process each batch and collect the embeddings
|
||||
const allChunkEmbeddings: number[][] = [];
|
||||
let totalTokensUsed = 0;
|
||||
|
||||
for (let batchIndex = 0; batchIndex < chunkBatches.length; batchIndex++) {
|
||||
const batch = chunkBatches[batchIndex];
|
||||
console.log(`Processing batch ${batchIndex + 1}/${chunkBatches.length} with ${batch.length} chunks`);
|
||||
|
||||
// Get embeddings for the current batch
|
||||
const batchEmbeddingResponse = await axios.post(
|
||||
'https://api.jina.ai/v1/embeddings',
|
||||
{
|
||||
model: "jina-embeddings-v3",
|
||||
task: "retrieval.passage",
|
||||
late_chunking: true,
|
||||
dimensions: 1024,
|
||||
embedding_type: "float",
|
||||
input: batch,
|
||||
truncate: true
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': `Bearer ${JINA_API_KEY}`
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
if (batchEmbeddingResponse.status !== 200) {
|
||||
throw new Error(`Unexpected status code from API: ${batchEmbeddingResponse.status}`);
|
||||
}
|
||||
|
||||
// Validate response structure
|
||||
if (!batchEmbeddingResponse.data?.data) {
|
||||
throw new Error("Unexpected API response format");
|
||||
}
|
||||
|
||||
// Extract embeddings from this batch
|
||||
const batchEmbeddings = batchEmbeddingResponse.data.data.map((item: any) => item.embedding);
|
||||
allChunkEmbeddings.push(...batchEmbeddings);
|
||||
|
||||
// Track token usage
|
||||
const batchTokens = batchEmbeddingResponse.data.usage?.total_tokens || 0;
|
||||
totalTokensUsed += batchTokens;
|
||||
}
|
||||
|
||||
// Get embedding for the question
|
||||
const questionEmbeddingResponse = await axios.post(
|
||||
'https://api.jina.ai/v1/embeddings',
|
||||
// Get embeddings for all chunks using the new getEmbeddings function
|
||||
const chunkEmbeddingResult = await getEmbeddings(
|
||||
chunks,
|
||||
trackers.tokenTracker,
|
||||
{
|
||||
model: "jina-embeddings-v3",
|
||||
task: "retrieval.query",
|
||||
task: "retrieval.passage",
|
||||
dimensions: 1024,
|
||||
embedding_type: "float",
|
||||
input: [question],
|
||||
truncate: true
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': `Bearer ${JINA_API_KEY}`
|
||||
}
|
||||
late_chunking: true,
|
||||
embedding_type: "float"
|
||||
}
|
||||
);
|
||||
|
||||
if (questionEmbeddingResponse.status !== 200) {
|
||||
throw new Error("Unexpected status code from API");
|
||||
}
|
||||
const allChunkEmbeddings = chunkEmbeddingResult.embeddings;
|
||||
|
||||
// Validate question embedding response
|
||||
if (!questionEmbeddingResponse.data?.data || !questionEmbeddingResponse.data.data[0]?.embedding) {
|
||||
throw new Error("Question embedding not found in API response");
|
||||
}
|
||||
// Get embedding for the question
|
||||
const questionEmbeddingResult = await getEmbeddings(
|
||||
[question],
|
||||
trackers.tokenTracker,
|
||||
{
|
||||
task: "retrieval.query",
|
||||
dimensions: 1024,
|
||||
embedding_type: "float"
|
||||
}
|
||||
);
|
||||
|
||||
// Track token usage for question embedding
|
||||
const questionTokens = questionEmbeddingResponse.data.usage?.total_tokens || 0;
|
||||
totalTokensUsed += questionTokens;
|
||||
|
||||
// Track total token usage
|
||||
trackers.tokenTracker.trackUsage('latechunk', {
|
||||
promptTokens: totalTokensUsed,
|
||||
completionTokens: 0,
|
||||
totalTokens: totalTokensUsed
|
||||
});
|
||||
|
||||
const questionEmbedding = questionEmbeddingResponse.data.data[0].embedding;
|
||||
const questionEmbedding = questionEmbeddingResult.embeddings[0];
|
||||
|
||||
// Verify that we got embeddings for all chunks
|
||||
if (allChunkEmbeddings.length !== chunks.length) {
|
||||
@ -199,4 +119,4 @@ ${snippet}
|
||||
// Fallback: just return the beginning of the context up to the desired length
|
||||
return longContext.substring(0, snippetLength * numSnippets);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -43,7 +43,7 @@ export async function rerankDocuments(
|
||||
batches.push(documents.slice(i, i + batchSize));
|
||||
}
|
||||
|
||||
console.log(`Processing ${documents.length} documents in ${batches.length} batches of up to ${batchSize} each`);
|
||||
console.log(`Rerank ${documents.length} documents in ${batches.length} batches of up to ${batchSize} each`);
|
||||
|
||||
// Process all batches in parallel
|
||||
const batchResults = await Promise.all(
|
||||
|
||||
@ -36,7 +36,7 @@ export async function segmentText(
|
||||
|
||||
// Process all batches in parallel
|
||||
const batchPromises = batches.map(async (batch, i) => {
|
||||
console.log(`Processing batch ${i + 1}/${batches.length} (size: ${batch.length})`);
|
||||
console.log(`[Segment] Processing batch ${i + 1}/${batches.length} (size: ${batch.length})`);
|
||||
|
||||
try {
|
||||
const {data} = await axios.post(
|
||||
|
||||
@ -244,6 +244,9 @@ export interface ChatCompletionRequest {
|
||||
boost_hostnames?: string[];
|
||||
bad_hostnames?: string[];
|
||||
only_hostnames?: string[];
|
||||
|
||||
max_annotations?: number;
|
||||
min_annotation_relevance?: number;
|
||||
}
|
||||
|
||||
export interface URLAnnotation {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user