mirror of
https://github.com/jina-ai/node-DeepResearch.git
synced 2026-03-22 07:29:35 +08:00
fix: token tracking
This commit is contained in:
21
src/agent.ts
21
src/agent.ts
@@ -314,12 +314,12 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_
|
||||
const allURLs: Record<string, string> = {};
|
||||
const visitedURLs: string[] = [];
|
||||
const evaluationMetrics: Record<string, any[]> = {};
|
||||
while (context.tokenTracker.getTotalUsage() < tokenBudget && badAttempts <= maxBadAttempts) {
|
||||
while (context.tokenTracker.getTotalUsage().totalTokens < tokenBudget && badAttempts <= maxBadAttempts) {
|
||||
// add 1s delay to avoid rate limiting
|
||||
await sleep(STEP_SLEEP);
|
||||
step++;
|
||||
totalStep++;
|
||||
const budgetPercentage = (context.tokenTracker.getTotalUsage() / tokenBudget * 100).toFixed(2);
|
||||
const budgetPercentage = (context.tokenTracker.getTotalUsage().totalTokens / tokenBudget * 100).toFixed(2);
|
||||
console.log(`Step ${totalStep} / Budget used ${budgetPercentage}%`);
|
||||
console.log('Gaps:', gaps);
|
||||
allowReflect = allowReflect && (gaps.length <= 1);
|
||||
@@ -350,7 +350,6 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_
|
||||
schema = getSchema(allowReflect, allowRead, allowAnswer, allowSearch)
|
||||
const model = getModel('agent');
|
||||
let object;
|
||||
let totalTokens = 0;
|
||||
try {
|
||||
const result = await generateObject({
|
||||
model,
|
||||
@@ -359,11 +358,11 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_
|
||||
maxTokens: getMaxTokens('agent')
|
||||
});
|
||||
object = result.object;
|
||||
totalTokens = result.usage?.totalTokens || 0;
|
||||
context.tokenTracker.trackUsage('agent', result.usage);
|
||||
} catch (error) {
|
||||
const result = await handleGenerateObjectError<StepAction>(error);
|
||||
object = result.object;
|
||||
totalTokens = result.totalTokens;
|
||||
context.tokenTracker.trackUsage('agent', result.usage);
|
||||
}
|
||||
thisStep = object as StepAction;
|
||||
// print allowed and chose action
|
||||
@@ -372,7 +371,6 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_
|
||||
console.log(thisStep)
|
||||
|
||||
context.actionTracker.trackAction({totalStep, thisStep, gaps, badAttempts});
|
||||
context.tokenTracker.trackUsage('agent', totalTokens);
|
||||
|
||||
// reset allowAnswer to true
|
||||
allowAnswer = true;
|
||||
@@ -622,7 +620,7 @@ You decided to think out of the box or cut from a completely different angle.
|
||||
const urlResults = await Promise.all(
|
||||
uniqueURLs.map(async (url: string) => {
|
||||
try {
|
||||
const {response, tokens} = await readUrl(url, context.tokenTracker);
|
||||
const {response} = await readUrl(url, context.tokenTracker);
|
||||
allKnowledge.push({
|
||||
question: `What is in ${response.data?.url || 'the URL'}?`,
|
||||
answer: removeAllLineBreaks(response.data?.content || 'No content available'),
|
||||
@@ -631,7 +629,7 @@ You decided to think out of the box or cut from a completely different angle.
|
||||
});
|
||||
visitedURLs.push(url);
|
||||
delete allURLs[url];
|
||||
return {url, result: response, tokens};
|
||||
return {url, result: response};
|
||||
} catch (error) {
|
||||
console.error('Error reading URL:', error);
|
||||
}
|
||||
@@ -696,7 +694,6 @@ You decided to think out of the box or cut from a completely different angle.`);
|
||||
schema = getSchema(false, false, true, false);
|
||||
const model = getModel('agentBeastMode');
|
||||
let object;
|
||||
let totalTokens;
|
||||
try {
|
||||
const result = await generateObject({
|
||||
model,
|
||||
@@ -705,16 +702,16 @@ You decided to think out of the box or cut from a completely different angle.`);
|
||||
maxTokens: getMaxTokens('agentBeastMode')
|
||||
});
|
||||
object = result.object;
|
||||
totalTokens = result.usage?.totalTokens || 0;
|
||||
context.tokenTracker.trackUsage('agent', result.usage);
|
||||
} catch (error) {
|
||||
const result = await handleGenerateObjectError<StepAction>(error);
|
||||
object = result.object;
|
||||
totalTokens = result.totalTokens;
|
||||
context.tokenTracker.trackUsage('agent', result.usage);
|
||||
}
|
||||
await storeContext(prompt, schema, [allContext, allKeywords, allQuestions, allKnowledge], totalStep);
|
||||
thisStep = object as StepAction;
|
||||
context.actionTracker.trackAction({totalStep, thisStep, gaps, badAttempts});
|
||||
context.tokenTracker.trackUsage('agent', totalTokens);
|
||||
|
||||
console.log(thisStep)
|
||||
return {result: thisStep, context};
|
||||
}
|
||||
|
||||
178
src/app.ts
178
src/app.ts
@@ -1,10 +1,7 @@
|
||||
import express, {Request, Response, RequestHandler} from 'express';
|
||||
import cors from 'cors';
|
||||
import {EventEmitter} from 'events';
|
||||
import {getResponse} from './agent';
|
||||
import {
|
||||
StepAction,
|
||||
StreamMessage,
|
||||
TrackerContext,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
@@ -12,8 +9,6 @@ import {
|
||||
AnswerAction,
|
||||
Model
|
||||
} from './types';
|
||||
import fs from 'fs/promises';
|
||||
import path from 'path';
|
||||
import {TokenTracker} from "./utils/token-tracker";
|
||||
import {ActionTracker} from "./utils/action-tracker";
|
||||
|
||||
@@ -30,16 +25,6 @@ app.get('/health', (req, res) => {
|
||||
res.json({status: 'ok'});
|
||||
});
|
||||
|
||||
const eventEmitter = new EventEmitter();
|
||||
|
||||
interface QueryRequest extends Request {
|
||||
body: {
|
||||
q: string;
|
||||
budget?: number;
|
||||
maxBadAttempt?: number;
|
||||
};
|
||||
}
|
||||
|
||||
function buildMdFromAnswer(answer: AnswerAction) {
|
||||
let refStr = '';
|
||||
if (answer.references?.length > 0) {
|
||||
@@ -358,7 +343,7 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
|
||||
}
|
||||
}
|
||||
|
||||
const usage = context.tokenTracker.getUsageDetails();
|
||||
const usage = context.tokenTracker.getTotalUsageSnakeCase();
|
||||
if (body.stream) {
|
||||
// Complete any ongoing streaming before sending final answer
|
||||
await completeCurrentStreaming(streamingState, res, requestId, created, body.model);
|
||||
@@ -442,7 +427,7 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
|
||||
context.actionTracker.removeAllListeners('action');
|
||||
|
||||
// Get token usage in OpenAI API format
|
||||
const usage = context.tokenTracker.getUsageDetails();
|
||||
const usage = context.tokenTracker.getTotalUsageSnakeCase();
|
||||
|
||||
if (body.stream && res.headersSent) {
|
||||
// For streaming responses that have already started, send error as a chunk
|
||||
@@ -504,165 +489,6 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
|
||||
}
|
||||
}) as RequestHandler);
|
||||
|
||||
interface StreamResponse extends Response {
|
||||
write: (chunk: string) => boolean;
|
||||
}
|
||||
|
||||
function createProgressEmitter(requestId: string, budget: number | undefined, context: TrackerContext) {
|
||||
return () => {
|
||||
const state = context.actionTracker.getState();
|
||||
const budgetInfo = {
|
||||
used: context.tokenTracker.getTotalUsage(),
|
||||
total: budget || 1_000_000,
|
||||
percentage: ((context.tokenTracker.getTotalUsage() / (budget || 1_000_000)) * 100).toFixed(2)
|
||||
};
|
||||
|
||||
eventEmitter.emit(`progress-${requestId}`, {
|
||||
type: 'progress',
|
||||
data: {...state.thisStep, totalStep: state.totalStep},
|
||||
step: state.totalStep,
|
||||
budget: budgetInfo,
|
||||
trackers: {
|
||||
tokenUsage: context.tokenTracker.getTotalUsage(),
|
||||
actionState: context.actionTracker.getState()
|
||||
}
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
function cleanup(requestId: string) {
|
||||
const context = trackers.get(requestId);
|
||||
if (context) {
|
||||
context.actionTracker.removeAllListeners();
|
||||
context.tokenTracker.removeAllListeners();
|
||||
trackers.delete(requestId);
|
||||
}
|
||||
}
|
||||
|
||||
function emitTrackerUpdate(requestId: string, context: TrackerContext) {
|
||||
const trackerData = {
|
||||
tokenUsage: context.tokenTracker.getTotalUsage(),
|
||||
tokenBreakdown: context.tokenTracker.getUsageBreakdown(),
|
||||
actionState: context.actionTracker.getState().thisStep,
|
||||
step: context.actionTracker.getState().totalStep,
|
||||
badAttempts: context.actionTracker.getState().badAttempts,
|
||||
gaps: context.actionTracker.getState().gaps
|
||||
};
|
||||
|
||||
eventEmitter.emit(`progress-${requestId}`, {
|
||||
type: 'progress',
|
||||
trackers: trackerData
|
||||
});
|
||||
}
|
||||
|
||||
// Store the trackers for each request
|
||||
const trackers = new Map<string, TrackerContext>();
|
||||
|
||||
app.post('/api/v1/query', (async (req: QueryRequest, res: Response) => {
|
||||
const {q, budget, maxBadAttempt} = req.body;
|
||||
if (!q) {
|
||||
return res.status(400).json({error: 'Query (q) is required'});
|
||||
}
|
||||
|
||||
const requestId = Date.now().toString();
|
||||
|
||||
// Create new trackers for this request
|
||||
const context: TrackerContext = {
|
||||
tokenTracker: new TokenTracker(),
|
||||
actionTracker: new ActionTracker()
|
||||
};
|
||||
trackers.set(requestId, context);
|
||||
|
||||
// Set up listeners immediately for both trackers
|
||||
context.actionTracker.on('action', () => emitTrackerUpdate(requestId, context));
|
||||
// context.tokenTracker.on('usage', () => emitTrackerUpdate(requestId, context));
|
||||
|
||||
res.json({requestId});
|
||||
|
||||
try {
|
||||
const {result} = await getResponse(q, budget, maxBadAttempt, context);
|
||||
const emitProgress = createProgressEmitter(requestId, budget, context);
|
||||
context.actionTracker.on('action', emitProgress);
|
||||
await storeTaskResult(requestId, result);
|
||||
eventEmitter.emit(`progress-${requestId}`, {
|
||||
type: 'answer',
|
||||
data: result,
|
||||
trackers: {
|
||||
tokenUsage: context.tokenTracker.getTotalUsage(),
|
||||
actionState: context.actionTracker.getState()
|
||||
}
|
||||
});
|
||||
cleanup(requestId);
|
||||
} catch (error: any) {
|
||||
eventEmitter.emit(`progress-${requestId}`, {
|
||||
type: 'error',
|
||||
data: error?.message || 'Unknown error',
|
||||
status: 500,
|
||||
trackers: {
|
||||
tokenUsage: context.tokenTracker.getTotalUsage(),
|
||||
actionState: context.actionTracker.getState()
|
||||
}
|
||||
});
|
||||
cleanup(requestId);
|
||||
}
|
||||
}) as RequestHandler);
|
||||
|
||||
app.get('/api/v1/stream/:requestId', (async (req: Request, res: StreamResponse) => {
|
||||
const requestId = req.params.requestId;
|
||||
const context = trackers.get(requestId);
|
||||
|
||||
res.setHeader('Content-Type', 'text/event-stream');
|
||||
res.setHeader('Cache-Control', 'no-cache');
|
||||
res.setHeader('Connection', 'keep-alive');
|
||||
|
||||
const listener = (data: StreamMessage) => {
|
||||
// The trackers are now included in all event types
|
||||
// We don't need to add them here as they're already part of the data
|
||||
res.write(`data: ${JSON.stringify(data)}\n\n`);
|
||||
};
|
||||
|
||||
eventEmitter.on(`progress-${requestId}`, listener);
|
||||
|
||||
// Handle client disconnection
|
||||
req.on('close', () => {
|
||||
eventEmitter.removeListener(`progress-${requestId}`, listener);
|
||||
});
|
||||
|
||||
// Send initial connection confirmation with tracker state
|
||||
const initialData = {
|
||||
type: 'connected',
|
||||
requestId,
|
||||
trackers: context ? {
|
||||
tokenUsage: context.tokenTracker.getTotalUsage(),
|
||||
actionState: context.actionTracker.getState()
|
||||
} : null
|
||||
};
|
||||
res.write(`data: ${JSON.stringify(initialData)}\n\n`);
|
||||
}) as RequestHandler);
|
||||
|
||||
async function storeTaskResult(requestId: string, result: StepAction) {
|
||||
try {
|
||||
const taskDir = path.join(process.cwd(), 'tasks');
|
||||
await fs.mkdir(taskDir, {recursive: true});
|
||||
await fs.writeFile(
|
||||
path.join(taskDir, `${requestId}.json`),
|
||||
JSON.stringify(result, null, 2)
|
||||
);
|
||||
} catch (error) {
|
||||
console.error('Task storage failed:', error);
|
||||
throw new Error('Failed to store task result');
|
||||
}
|
||||
}
|
||||
|
||||
app.get('/api/v1/task/:requestId', (async (req: Request, res: Response) => {
|
||||
const requestId = req.params.requestId;
|
||||
try {
|
||||
const taskPath = path.join(process.cwd(), 'tasks', `${requestId}.json`);
|
||||
const taskData = await fs.readFile(taskPath, 'utf-8');
|
||||
res.json(JSON.parse(taskData));
|
||||
} catch (error) {
|
||||
res.status(404).json({error: 'Task not found'});
|
||||
}
|
||||
}) as RequestHandler);
|
||||
|
||||
export default app;
|
||||
|
||||
@@ -160,7 +160,7 @@ async function batchEvaluate(inputFile: string): Promise<void> {
|
||||
pass: evaluation.pass,
|
||||
reason: evaluation.reason,
|
||||
total_steps: context.actionTracker.getState().totalStep,
|
||||
total_tokens: context.tokenTracker.getTotalUsage(),
|
||||
total_tokens: context.tokenTracker.getTotalUsage().totalTokens,
|
||||
question,
|
||||
expected_answer: expectedAnswer,
|
||||
actual_answer: actualAnswer
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { z } from 'zod';
|
||||
import { generateObject } from 'ai';
|
||||
import { getModel, getMaxTokens } from "../config";
|
||||
import { TokenTracker } from "../utils/token-tracker";
|
||||
import { ErrorAnalysisResponse } from '../types';
|
||||
import { handleGenerateObjectError } from '../utils/error-handling';
|
||||
import {z} from 'zod';
|
||||
import {generateObject} from 'ai';
|
||||
import {getModel, getMaxTokens} from "../config";
|
||||
import {TokenTracker} from "../utils/token-tracker";
|
||||
import {ErrorAnalysisResponse} from '../types';
|
||||
import {handleGenerateObjectError} from '../utils/error-handling';
|
||||
|
||||
const model = getModel('errorAnalyzer');
|
||||
|
||||
@@ -12,13 +12,12 @@ const responseSchema = z.object({
|
||||
blame: z.string().describe('Which action or the step was the root cause of the answer rejection'),
|
||||
improvement: z.string().describe('Suggested key improvement for the next iteration, do not use bullet points, be concise and hot-take vibe.'),
|
||||
questionsToAnswer: z.array(
|
||||
z.string().describe("each question must be a single line, concise and clear. not composite or compound, less than 20 words.")
|
||||
).max(2)
|
||||
.describe("List of most important reflect questions to fill the knowledge gaps"),
|
||||
z.string().describe("each question must be a single line, concise and clear. not composite or compound, less than 20 words.")
|
||||
).max(2)
|
||||
.describe("List of most important reflect questions to fill the knowledge gaps"),
|
||||
});
|
||||
|
||||
|
||||
|
||||
function getPrompt(diaryContext: string[]): string {
|
||||
return `You are an expert at analyzing search and reasoning processes. Your task is to analyze the given sequence of steps and identify what went wrong in the search process.
|
||||
|
||||
@@ -112,11 +111,11 @@ ${diaryContext.join('\n')}
|
||||
`;
|
||||
}
|
||||
|
||||
export async function analyzeSteps(diaryContext: string[], tracker?: TokenTracker): Promise<{ response: ErrorAnalysisResponse, tokens: number }> {
|
||||
export async function analyzeSteps(diaryContext: string[], tracker?: TokenTracker): Promise<{ response: ErrorAnalysisResponse }> {
|
||||
try {
|
||||
const prompt = getPrompt(diaryContext);
|
||||
let object;
|
||||
let tokens = 0;
|
||||
let usage;
|
||||
try {
|
||||
const result = await generateObject({
|
||||
model,
|
||||
@@ -125,18 +124,18 @@ export async function analyzeSteps(diaryContext: string[], tracker?: TokenTracke
|
||||
maxTokens: getMaxTokens('errorAnalyzer')
|
||||
});
|
||||
object = result.object;
|
||||
tokens = result.usage?.totalTokens || 0;
|
||||
usage = result.usage;
|
||||
} catch (error) {
|
||||
const result = await handleGenerateObjectError<ErrorAnalysisResponse>(error);
|
||||
object = result.object;
|
||||
tokens = result.totalTokens;
|
||||
usage = result.usage;
|
||||
}
|
||||
console.log('Error analysis:', {
|
||||
is_valid: !object.blame,
|
||||
reason: object.blame || 'No issues found'
|
||||
});
|
||||
(tracker || new TokenTracker()).trackUsage('error-analyzer', tokens);
|
||||
return { response: object, tokens };
|
||||
(tracker || new TokenTracker()).trackUsage('error-analyzer', usage);
|
||||
return {response: object};
|
||||
} catch (error) {
|
||||
console.error('Error in answer evaluation:', error);
|
||||
throw error;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import {z} from 'zod';
|
||||
import {generateObject} from 'ai';
|
||||
import {generateObject, GenerateObjectResult} from 'ai';
|
||||
import {getModel, getMaxTokens} from "../config";
|
||||
import {TokenTracker} from "../utils/token-tracker";
|
||||
import {AnswerAction, EvaluationResponse} from '../types';
|
||||
@@ -383,7 +383,7 @@ export async function evaluateQuestion(
|
||||
maxTokens: getMaxTokens('evaluator')
|
||||
});
|
||||
|
||||
(tracker || new TokenTracker()).trackUsage('evaluator', result.usage?.totalTokens || 0);
|
||||
(tracker || new TokenTracker()).trackUsage('evaluator', result.usage);
|
||||
console.log('Question Evaluation:', result.object);
|
||||
|
||||
// Always include definitive in types
|
||||
@@ -397,7 +397,7 @@ export async function evaluateQuestion(
|
||||
return types;
|
||||
} catch (error) {
|
||||
const errorResult = await handleGenerateObjectError<EvaluationResponse>(error);
|
||||
(tracker || new TokenTracker()).trackUsage('evaluator', errorResult.totalTokens || 0);
|
||||
(tracker || new TokenTracker()).trackUsage('evaluator', errorResult.usage);
|
||||
return ['definitive', 'freshness', 'plurality'];
|
||||
}
|
||||
}
|
||||
@@ -413,7 +413,7 @@ async function performEvaluation(
|
||||
maxTokens: number;
|
||||
},
|
||||
tracker?: TokenTracker
|
||||
): Promise<GenerateObjectResult> {
|
||||
): Promise<GenerateObjectResult<any>> {
|
||||
const result = await generateObject({
|
||||
model: params.model,
|
||||
schema: params.schema,
|
||||
@@ -421,18 +421,12 @@ async function performEvaluation(
|
||||
maxTokens: params.maxTokens
|
||||
});
|
||||
|
||||
(tracker || new TokenTracker()).trackUsage('evaluator', result.usage?.totalTokens || 0);
|
||||
(tracker || new TokenTracker()).trackUsage('evaluator', result.usage);
|
||||
console.log(`${evaluationType} Evaluation:`, result.object);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
interface GenerateObjectResult {
|
||||
object: EvaluationResponse;
|
||||
usage?: {
|
||||
totalTokens: number;
|
||||
};
|
||||
}
|
||||
|
||||
// Main evaluation function
|
||||
export async function evaluateAnswer(
|
||||
@@ -441,7 +435,7 @@ export async function evaluateAnswer(
|
||||
evaluationOrder: EvaluationType[] = ['definitive', 'freshness', 'plurality'],
|
||||
tracker?: TokenTracker
|
||||
): Promise<{ response: EvaluationResponse }> {
|
||||
let result: GenerateObjectResult;
|
||||
let result;
|
||||
|
||||
// Only add attribution if we have valid references
|
||||
if (action.references && action.references.length > 0) {
|
||||
@@ -525,7 +519,7 @@ export async function evaluateAnswer(
|
||||
}
|
||||
} catch (error) {
|
||||
const errorResult = await handleGenerateObjectError<EvaluationResponse>(error);
|
||||
(tracker || new TokenTracker()).trackUsage('evaluator', errorResult.totalTokens || 0);
|
||||
(tracker || new TokenTracker()).trackUsage('evaluator', errorResult.usage);
|
||||
return {response: errorResult.object};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import axios, { AxiosError } from 'axios';
|
||||
import { TokenTracker } from "../utils/token-tracker";
|
||||
import axios, {AxiosError} from 'axios';
|
||||
import {TokenTracker} from "../utils/token-tracker";
|
||||
import {JINA_API_KEY} from "../config";
|
||||
|
||||
const JINA_API_URL = 'https://api.jina.ai/v1/embeddings';
|
||||
@@ -107,27 +107,23 @@ export async function dedupQueries(
|
||||
newQueries: string[],
|
||||
existingQueries: string[],
|
||||
tracker?: TokenTracker
|
||||
): Promise<{ unique_queries: string[], tokens: number }> {
|
||||
): Promise<{ unique_queries: string[] }> {
|
||||
try {
|
||||
// Quick return for single new query with no existing queries
|
||||
if (newQueries.length === 1 && existingQueries.length === 0) {
|
||||
console.log('Dedup (quick return):', newQueries);
|
||||
return {
|
||||
unique_queries: newQueries,
|
||||
tokens: 0 // No tokens used since we didn't call the API
|
||||
};
|
||||
}
|
||||
|
||||
// Get embeddings for all queries in one batch
|
||||
const allQueries = [...newQueries, ...existingQueries];
|
||||
const { embeddings: allEmbeddings, tokens } = await getEmbeddings(allQueries);
|
||||
const {embeddings: allEmbeddings, tokens} = await getEmbeddings(allQueries);
|
||||
|
||||
// If embeddings is empty (due to 402 error), return all new queries
|
||||
if (!allEmbeddings.length) {
|
||||
console.log('Dedup (no embeddings):', newQueries);
|
||||
return {
|
||||
unique_queries: newQueries,
|
||||
tokens: 0
|
||||
};
|
||||
}
|
||||
|
||||
@@ -170,11 +166,14 @@ export async function dedupQueries(
|
||||
}
|
||||
|
||||
// Track token usage from the API
|
||||
(tracker || new TokenTracker()).trackUsage('dedup', tokens);
|
||||
(tracker || new TokenTracker()).trackUsage('dedup', {
|
||||
promptTokens: tokens,
|
||||
completionTokens: 0,
|
||||
totalTokens: tokens
|
||||
});
|
||||
console.log('Dedup:', uniqueQueries);
|
||||
return {
|
||||
unique_queries: uniqueQueries,
|
||||
tokens
|
||||
};
|
||||
} catch (error) {
|
||||
console.error('Error in deduplication analysis:', error);
|
||||
|
||||
@@ -3,7 +3,7 @@ import { TokenTracker } from "../utils/token-tracker";
|
||||
import { SearchResponse } from '../types';
|
||||
import { JINA_API_KEY } from "../config";
|
||||
|
||||
export function search(query: string, tracker?: TokenTracker): Promise<{ response: SearchResponse, tokens: number }> {
|
||||
export function search(query: string, tracker?: TokenTracker): Promise<{ response: SearchResponse}> {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (!query.trim()) {
|
||||
reject(new Error('Query cannot be empty'));
|
||||
@@ -63,9 +63,13 @@ export function search(query: string, tracker?: TokenTracker): Promise<{ respons
|
||||
console.log('Total URLs:', response.data.length);
|
||||
|
||||
const tokenTracker = tracker || new TokenTracker();
|
||||
tokenTracker.trackUsage('search', totalTokens);
|
||||
tokenTracker.trackUsage('search', {
|
||||
totalTokens,
|
||||
promptTokens: query.length,
|
||||
completionTokens: totalTokens
|
||||
});
|
||||
|
||||
resolve({ response, tokens: totalTokens });
|
||||
resolve({ response });
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -93,11 +93,11 @@ Intention: ${action.think}
|
||||
`;
|
||||
}
|
||||
|
||||
export async function rewriteQuery(action: SearchAction, tracker?: TokenTracker): Promise<{ queries: string[], tokens: number }> {
|
||||
export async function rewriteQuery(action: SearchAction, tracker?: TokenTracker): Promise<{ queries: string[] }> {
|
||||
try {
|
||||
const prompt = getPrompt(action);
|
||||
let object;
|
||||
let tokens = 0;
|
||||
let usage;
|
||||
try {
|
||||
const result = await generateObject({
|
||||
model,
|
||||
@@ -106,15 +106,15 @@ export async function rewriteQuery(action: SearchAction, tracker?: TokenTracker)
|
||||
maxTokens: getMaxTokens('queryRewriter')
|
||||
});
|
||||
object = result.object;
|
||||
tokens = result.usage?.totalTokens || 0;
|
||||
usage = result.usage;
|
||||
} catch (error) {
|
||||
const result = await handleGenerateObjectError<KeywordsResponse>(error);
|
||||
object = result.object;
|
||||
tokens = result.totalTokens;
|
||||
usage = result.usage;
|
||||
}
|
||||
console.log('Query rewriter:', object.queries);
|
||||
(tracker || new TokenTracker()).trackUsage('query-rewriter', tokens);
|
||||
return { queries: object.queries, tokens };
|
||||
(tracker || new TokenTracker()).trackUsage('query-rewriter', usage);
|
||||
return { queries: object.queries };
|
||||
} catch (error) {
|
||||
console.error('Error in query rewriting:', error);
|
||||
throw error;
|
||||
|
||||
@@ -3,7 +3,7 @@ import { TokenTracker } from "../utils/token-tracker";
|
||||
import { ReadResponse } from '../types';
|
||||
import { JINA_API_KEY } from "../config";
|
||||
|
||||
export function readUrl(url: string, tracker?: TokenTracker): Promise<{ response: ReadResponse, tokens: number }> {
|
||||
export function readUrl(url: string, tracker?: TokenTracker): Promise<{ response: ReadResponse }> {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (!url.trim()) {
|
||||
reject(new Error('URL cannot be empty'));
|
||||
@@ -72,9 +72,13 @@ export function readUrl(url: string, tracker?: TokenTracker): Promise<{ response
|
||||
|
||||
const tokens = response.data.usage?.tokens || 0;
|
||||
const tokenTracker = tracker || new TokenTracker();
|
||||
tokenTracker.trackUsage('read', tokens);
|
||||
tokenTracker.trackUsage('read', {
|
||||
totalTokens: tokens,
|
||||
promptTokens: url.length,
|
||||
completionTokens: tokens
|
||||
});
|
||||
|
||||
resolve({ response, tokens });
|
||||
resolve({ response });
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
18
src/types.ts
18
src/types.ts
@@ -1,4 +1,6 @@
|
||||
// Action Types
|
||||
import {LanguageModelUsage} from "ai";
|
||||
|
||||
type BaseAction = {
|
||||
action: "search" | "answer" | "reflect" | "visit";
|
||||
think: string;
|
||||
@@ -30,25 +32,11 @@ export type VisitAction = BaseAction & {
|
||||
|
||||
export type StepAction = SearchAction | AnswerAction | ReflectAction | VisitAction;
|
||||
|
||||
// Response Types
|
||||
export const TOKEN_CATEGORIES = {
|
||||
PROMPT: 'prompt',
|
||||
REASONING: 'reasoning',
|
||||
ACCEPTED: 'accepted',
|
||||
REJECTED: 'rejected'
|
||||
} as const;
|
||||
|
||||
export type TokenCategory = typeof TOKEN_CATEGORIES[keyof typeof TOKEN_CATEGORIES];
|
||||
|
||||
// Following Vercel AI SDK's token counting interface
|
||||
export interface TokenUsage {
|
||||
tool: string;
|
||||
tokens: number;
|
||||
category?: TokenCategory;
|
||||
// Following Vercel AI SDK's token counting interface
|
||||
prompt_tokens?: number;
|
||||
completion_tokens?: number;
|
||||
total_tokens?: number;
|
||||
usage: LanguageModelUsage;
|
||||
}
|
||||
|
||||
export interface SearchResponse {
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import {NoObjectGeneratedError} from "ai";
|
||||
import {LanguageModelUsage, NoObjectGeneratedError} from "ai";
|
||||
|
||||
export interface GenerateObjectResult<T> {
|
||||
object: T;
|
||||
totalTokens: number;
|
||||
usage: LanguageModelUsage;
|
||||
}
|
||||
|
||||
export async function handleGenerateObjectError<T>(error: unknown): Promise<GenerateObjectResult<T>> {
|
||||
@@ -12,7 +12,7 @@ export async function handleGenerateObjectError<T>(error: unknown): Promise<Gene
|
||||
const partialResponse = JSON.parse((error as any).text);
|
||||
return {
|
||||
object: partialResponse as T,
|
||||
totalTokens: (error as any).usage?.totalTokens || 0
|
||||
usage: (error as any).usage
|
||||
};
|
||||
} catch (parseError) {
|
||||
throw error;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { EventEmitter } from 'events';
|
||||
import {EventEmitter} from 'events';
|
||||
|
||||
import { TokenUsage, TokenCategory } from '../types';
|
||||
import {TokenUsage} from '../types';
|
||||
import {LanguageModelUsage} from "ai";
|
||||
|
||||
export class TokenTracker extends EventEmitter {
|
||||
private usages: TokenUsage[] = [];
|
||||
@@ -17,72 +18,46 @@ export class TokenTracker extends EventEmitter {
|
||||
asyncLocalContext.ctx.chargeAmount = this.getTotalUsage();
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
trackUsage(tool: string, tokens: number, category?: TokenCategory) {
|
||||
const currentTotal = this.getTotalUsage();
|
||||
if (this.budget && currentTotal + tokens > this.budget) {
|
||||
console.error(`Token budget exceeded: ${currentTotal + tokens} > ${this.budget}`);
|
||||
}
|
||||
// Only track usage if we're within budget
|
||||
if (!this.budget || currentTotal + tokens <= this.budget) {
|
||||
const usage = { tool, tokens, category };
|
||||
this.usages.push(usage);
|
||||
this.emit('usage', usage);
|
||||
}
|
||||
trackUsage(tool: string, usage: LanguageModelUsage) {
|
||||
const u = {tool, usage};
|
||||
this.usages.push(u);
|
||||
this.emit('usage', usage);
|
||||
}
|
||||
|
||||
getTotalUsage(): number {
|
||||
return this.usages.reduce((sum, usage) => sum + usage.tokens, 0);
|
||||
getTotalUsage(): LanguageModelUsage {
|
||||
return this.usages.reduce((acc, {usage}) => {
|
||||
acc.promptTokens += usage.promptTokens;
|
||||
acc.completionTokens += usage.completionTokens;
|
||||
acc.totalTokens += usage.totalTokens;
|
||||
return acc;
|
||||
}, {promptTokens: 0, completionTokens: 0, totalTokens: 0});
|
||||
}
|
||||
|
||||
getTotalUsageSnakeCase(): {prompt_tokens: number, completion_tokens: number, total_tokens: number} {
|
||||
return this.usages.reduce((acc, {usage}) => {
|
||||
acc.prompt_tokens += usage.promptTokens;
|
||||
acc.completion_tokens += usage.completionTokens;
|
||||
acc.total_tokens += usage.totalTokens;
|
||||
return acc;
|
||||
}, {prompt_tokens: 0, completion_tokens: 0, total_tokens: 0});
|
||||
}
|
||||
|
||||
getUsageBreakdown(): Record<string, number> {
|
||||
return this.usages.reduce((acc, { tool, tokens }) => {
|
||||
acc[tool] = (acc[tool] || 0) + tokens;
|
||||
return this.usages.reduce((acc, {tool, usage}) => {
|
||||
acc[tool] = (acc[tool] || 0) + usage.totalTokens;
|
||||
return acc;
|
||||
}, {} as Record<string, number>);
|
||||
}
|
||||
|
||||
getUsageDetails(): {
|
||||
prompt_tokens: number;
|
||||
completion_tokens: number;
|
||||
total_tokens: number;
|
||||
completion_tokens_details?: {
|
||||
reasoning_tokens: number;
|
||||
accepted_prediction_tokens: number;
|
||||
rejected_prediction_tokens: number;
|
||||
};
|
||||
} {
|
||||
const categoryBreakdown = this.usages.reduce((acc, { tokens, category }) => {
|
||||
if (category) {
|
||||
acc[category] = (acc[category] || 0) + tokens;
|
||||
}
|
||||
return acc;
|
||||
}, {} as Record<string, number>);
|
||||
|
||||
const prompt_tokens = categoryBreakdown.prompt || 0;
|
||||
const completion_tokens =
|
||||
(categoryBreakdown.reasoning || 0) +
|
||||
(categoryBreakdown.accepted || 0) +
|
||||
(categoryBreakdown.rejected || 0);
|
||||
|
||||
return {
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens: prompt_tokens + completion_tokens,
|
||||
completion_tokens_details: {
|
||||
reasoning_tokens: categoryBreakdown.reasoning || 0,
|
||||
accepted_prediction_tokens: categoryBreakdown.accepted || 0,
|
||||
rejected_prediction_tokens: categoryBreakdown.rejected || 0
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
printSummary() {
|
||||
const breakdown = this.getUsageBreakdown();
|
||||
console.log('Token Usage Summary:', {
|
||||
budget: this.budget,
|
||||
total: this.getTotalUsage(),
|
||||
breakdown
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user