From 41ee1812db1d6cd8a3df1a45b26073958737fe31 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Wed, 12 Feb 2025 11:20:49 +0800 Subject: [PATCH] fix: token tracking --- src/agent.ts | 21 ++--- src/app.ts | 178 +----------------------------------- src/evals/batch-evals.ts | 2 +- src/tools/error-analyzer.ts | 31 +++---- src/tools/evaluator.ts | 20 ++-- src/tools/jina-dedup.ts | 19 ++-- src/tools/jina-search.ts | 10 +- src/tools/query-rewriter.ts | 12 +-- src/tools/read.ts | 10 +- src/types.ts | 18 +--- src/utils/error-handling.ts | 6 +- src/utils/token-tracker.ts | 79 ++++++---------- 12 files changed, 96 insertions(+), 310 deletions(-) diff --git a/src/agent.ts b/src/agent.ts index 406e5b7..99d211f 100644 --- a/src/agent.ts +++ b/src/agent.ts @@ -314,12 +314,12 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_ const allURLs: Record = {}; const visitedURLs: string[] = []; const evaluationMetrics: Record = {}; - while (context.tokenTracker.getTotalUsage() < tokenBudget && badAttempts <= maxBadAttempts) { + while (context.tokenTracker.getTotalUsage().totalTokens < tokenBudget && badAttempts <= maxBadAttempts) { // add 1s delay to avoid rate limiting await sleep(STEP_SLEEP); step++; totalStep++; - const budgetPercentage = (context.tokenTracker.getTotalUsage() / tokenBudget * 100).toFixed(2); + const budgetPercentage = (context.tokenTracker.getTotalUsage().totalTokens / tokenBudget * 100).toFixed(2); console.log(`Step ${totalStep} / Budget used ${budgetPercentage}%`); console.log('Gaps:', gaps); allowReflect = allowReflect && (gaps.length <= 1); @@ -350,7 +350,6 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_ schema = getSchema(allowReflect, allowRead, allowAnswer, allowSearch) const model = getModel('agent'); let object; - let totalTokens = 0; try { const result = await generateObject({ model, @@ -359,11 +358,11 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_ maxTokens: getMaxTokens('agent') }); object = result.object; - totalTokens = result.usage?.totalTokens || 0; + context.tokenTracker.trackUsage('agent', result.usage); } catch (error) { const result = await handleGenerateObjectError(error); object = result.object; - totalTokens = result.totalTokens; + context.tokenTracker.trackUsage('agent', result.usage); } thisStep = object as StepAction; // print allowed and chose action @@ -372,7 +371,6 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_ console.log(thisStep) context.actionTracker.trackAction({totalStep, thisStep, gaps, badAttempts}); - context.tokenTracker.trackUsage('agent', totalTokens); // reset allowAnswer to true allowAnswer = true; @@ -622,7 +620,7 @@ You decided to think out of the box or cut from a completely different angle. const urlResults = await Promise.all( uniqueURLs.map(async (url: string) => { try { - const {response, tokens} = await readUrl(url, context.tokenTracker); + const {response} = await readUrl(url, context.tokenTracker); allKnowledge.push({ question: `What is in ${response.data?.url || 'the URL'}?`, answer: removeAllLineBreaks(response.data?.content || 'No content available'), @@ -631,7 +629,7 @@ You decided to think out of the box or cut from a completely different angle. }); visitedURLs.push(url); delete allURLs[url]; - return {url, result: response, tokens}; + return {url, result: response}; } catch (error) { console.error('Error reading URL:', error); } @@ -696,7 +694,6 @@ You decided to think out of the box or cut from a completely different angle.`); schema = getSchema(false, false, true, false); const model = getModel('agentBeastMode'); let object; - let totalTokens; try { const result = await generateObject({ model, @@ -705,16 +702,16 @@ You decided to think out of the box or cut from a completely different angle.`); maxTokens: getMaxTokens('agentBeastMode') }); object = result.object; - totalTokens = result.usage?.totalTokens || 0; + context.tokenTracker.trackUsage('agent', result.usage); } catch (error) { const result = await handleGenerateObjectError(error); object = result.object; - totalTokens = result.totalTokens; + context.tokenTracker.trackUsage('agent', result.usage); } await storeContext(prompt, schema, [allContext, allKeywords, allQuestions, allKnowledge], totalStep); thisStep = object as StepAction; context.actionTracker.trackAction({totalStep, thisStep, gaps, badAttempts}); - context.tokenTracker.trackUsage('agent', totalTokens); + console.log(thisStep) return {result: thisStep, context}; } diff --git a/src/app.ts b/src/app.ts index 949b619..2519ea8 100644 --- a/src/app.ts +++ b/src/app.ts @@ -1,10 +1,7 @@ import express, {Request, Response, RequestHandler} from 'express'; import cors from 'cors'; -import {EventEmitter} from 'events'; import {getResponse} from './agent'; import { - StepAction, - StreamMessage, TrackerContext, ChatCompletionRequest, ChatCompletionResponse, @@ -12,8 +9,6 @@ import { AnswerAction, Model } from './types'; -import fs from 'fs/promises'; -import path from 'path'; import {TokenTracker} from "./utils/token-tracker"; import {ActionTracker} from "./utils/action-tracker"; @@ -30,16 +25,6 @@ app.get('/health', (req, res) => { res.json({status: 'ok'}); }); -const eventEmitter = new EventEmitter(); - -interface QueryRequest extends Request { - body: { - q: string; - budget?: number; - maxBadAttempt?: number; - }; -} - function buildMdFromAnswer(answer: AnswerAction) { let refStr = ''; if (answer.references?.length > 0) { @@ -358,7 +343,7 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => { } } - const usage = context.tokenTracker.getUsageDetails(); + const usage = context.tokenTracker.getTotalUsageSnakeCase(); if (body.stream) { // Complete any ongoing streaming before sending final answer await completeCurrentStreaming(streamingState, res, requestId, created, body.model); @@ -442,7 +427,7 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => { context.actionTracker.removeAllListeners('action'); // Get token usage in OpenAI API format - const usage = context.tokenTracker.getUsageDetails(); + const usage = context.tokenTracker.getTotalUsageSnakeCase(); if (body.stream && res.headersSent) { // For streaming responses that have already started, send error as a chunk @@ -504,165 +489,6 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => { } }) as RequestHandler); -interface StreamResponse extends Response { - write: (chunk: string) => boolean; -} -function createProgressEmitter(requestId: string, budget: number | undefined, context: TrackerContext) { - return () => { - const state = context.actionTracker.getState(); - const budgetInfo = { - used: context.tokenTracker.getTotalUsage(), - total: budget || 1_000_000, - percentage: ((context.tokenTracker.getTotalUsage() / (budget || 1_000_000)) * 100).toFixed(2) - }; - - eventEmitter.emit(`progress-${requestId}`, { - type: 'progress', - data: {...state.thisStep, totalStep: state.totalStep}, - step: state.totalStep, - budget: budgetInfo, - trackers: { - tokenUsage: context.tokenTracker.getTotalUsage(), - actionState: context.actionTracker.getState() - } - }); - }; -} - -function cleanup(requestId: string) { - const context = trackers.get(requestId); - if (context) { - context.actionTracker.removeAllListeners(); - context.tokenTracker.removeAllListeners(); - trackers.delete(requestId); - } -} - -function emitTrackerUpdate(requestId: string, context: TrackerContext) { - const trackerData = { - tokenUsage: context.tokenTracker.getTotalUsage(), - tokenBreakdown: context.tokenTracker.getUsageBreakdown(), - actionState: context.actionTracker.getState().thisStep, - step: context.actionTracker.getState().totalStep, - badAttempts: context.actionTracker.getState().badAttempts, - gaps: context.actionTracker.getState().gaps - }; - - eventEmitter.emit(`progress-${requestId}`, { - type: 'progress', - trackers: trackerData - }); -} - -// Store the trackers for each request -const trackers = new Map(); - -app.post('/api/v1/query', (async (req: QueryRequest, res: Response) => { - const {q, budget, maxBadAttempt} = req.body; - if (!q) { - return res.status(400).json({error: 'Query (q) is required'}); - } - - const requestId = Date.now().toString(); - - // Create new trackers for this request - const context: TrackerContext = { - tokenTracker: new TokenTracker(), - actionTracker: new ActionTracker() - }; - trackers.set(requestId, context); - - // Set up listeners immediately for both trackers - context.actionTracker.on('action', () => emitTrackerUpdate(requestId, context)); - // context.tokenTracker.on('usage', () => emitTrackerUpdate(requestId, context)); - - res.json({requestId}); - - try { - const {result} = await getResponse(q, budget, maxBadAttempt, context); - const emitProgress = createProgressEmitter(requestId, budget, context); - context.actionTracker.on('action', emitProgress); - await storeTaskResult(requestId, result); - eventEmitter.emit(`progress-${requestId}`, { - type: 'answer', - data: result, - trackers: { - tokenUsage: context.tokenTracker.getTotalUsage(), - actionState: context.actionTracker.getState() - } - }); - cleanup(requestId); - } catch (error: any) { - eventEmitter.emit(`progress-${requestId}`, { - type: 'error', - data: error?.message || 'Unknown error', - status: 500, - trackers: { - tokenUsage: context.tokenTracker.getTotalUsage(), - actionState: context.actionTracker.getState() - } - }); - cleanup(requestId); - } -}) as RequestHandler); - -app.get('/api/v1/stream/:requestId', (async (req: Request, res: StreamResponse) => { - const requestId = req.params.requestId; - const context = trackers.get(requestId); - - res.setHeader('Content-Type', 'text/event-stream'); - res.setHeader('Cache-Control', 'no-cache'); - res.setHeader('Connection', 'keep-alive'); - - const listener = (data: StreamMessage) => { - // The trackers are now included in all event types - // We don't need to add them here as they're already part of the data - res.write(`data: ${JSON.stringify(data)}\n\n`); - }; - - eventEmitter.on(`progress-${requestId}`, listener); - - // Handle client disconnection - req.on('close', () => { - eventEmitter.removeListener(`progress-${requestId}`, listener); - }); - - // Send initial connection confirmation with tracker state - const initialData = { - type: 'connected', - requestId, - trackers: context ? { - tokenUsage: context.tokenTracker.getTotalUsage(), - actionState: context.actionTracker.getState() - } : null - }; - res.write(`data: ${JSON.stringify(initialData)}\n\n`); -}) as RequestHandler); - -async function storeTaskResult(requestId: string, result: StepAction) { - try { - const taskDir = path.join(process.cwd(), 'tasks'); - await fs.mkdir(taskDir, {recursive: true}); - await fs.writeFile( - path.join(taskDir, `${requestId}.json`), - JSON.stringify(result, null, 2) - ); - } catch (error) { - console.error('Task storage failed:', error); - throw new Error('Failed to store task result'); - } -} - -app.get('/api/v1/task/:requestId', (async (req: Request, res: Response) => { - const requestId = req.params.requestId; - try { - const taskPath = path.join(process.cwd(), 'tasks', `${requestId}.json`); - const taskData = await fs.readFile(taskPath, 'utf-8'); - res.json(JSON.parse(taskData)); - } catch (error) { - res.status(404).json({error: 'Task not found'}); - } -}) as RequestHandler); export default app; diff --git a/src/evals/batch-evals.ts b/src/evals/batch-evals.ts index 72ca790..cf9326b 100644 --- a/src/evals/batch-evals.ts +++ b/src/evals/batch-evals.ts @@ -160,7 +160,7 @@ async function batchEvaluate(inputFile: string): Promise { pass: evaluation.pass, reason: evaluation.reason, total_steps: context.actionTracker.getState().totalStep, - total_tokens: context.tokenTracker.getTotalUsage(), + total_tokens: context.tokenTracker.getTotalUsage().totalTokens, question, expected_answer: expectedAnswer, actual_answer: actualAnswer diff --git a/src/tools/error-analyzer.ts b/src/tools/error-analyzer.ts index 61c2715..08f48de 100644 --- a/src/tools/error-analyzer.ts +++ b/src/tools/error-analyzer.ts @@ -1,9 +1,9 @@ -import { z } from 'zod'; -import { generateObject } from 'ai'; -import { getModel, getMaxTokens } from "../config"; -import { TokenTracker } from "../utils/token-tracker"; -import { ErrorAnalysisResponse } from '../types'; -import { handleGenerateObjectError } from '../utils/error-handling'; +import {z} from 'zod'; +import {generateObject} from 'ai'; +import {getModel, getMaxTokens} from "../config"; +import {TokenTracker} from "../utils/token-tracker"; +import {ErrorAnalysisResponse} from '../types'; +import {handleGenerateObjectError} from '../utils/error-handling'; const model = getModel('errorAnalyzer'); @@ -12,13 +12,12 @@ const responseSchema = z.object({ blame: z.string().describe('Which action or the step was the root cause of the answer rejection'), improvement: z.string().describe('Suggested key improvement for the next iteration, do not use bullet points, be concise and hot-take vibe.'), questionsToAnswer: z.array( - z.string().describe("each question must be a single line, concise and clear. not composite or compound, less than 20 words.") - ).max(2) - .describe("List of most important reflect questions to fill the knowledge gaps"), + z.string().describe("each question must be a single line, concise and clear. not composite or compound, less than 20 words.") + ).max(2) + .describe("List of most important reflect questions to fill the knowledge gaps"), }); - function getPrompt(diaryContext: string[]): string { return `You are an expert at analyzing search and reasoning processes. Your task is to analyze the given sequence of steps and identify what went wrong in the search process. @@ -112,11 +111,11 @@ ${diaryContext.join('\n')} `; } -export async function analyzeSteps(diaryContext: string[], tracker?: TokenTracker): Promise<{ response: ErrorAnalysisResponse, tokens: number }> { +export async function analyzeSteps(diaryContext: string[], tracker?: TokenTracker): Promise<{ response: ErrorAnalysisResponse }> { try { const prompt = getPrompt(diaryContext); let object; - let tokens = 0; + let usage; try { const result = await generateObject({ model, @@ -125,18 +124,18 @@ export async function analyzeSteps(diaryContext: string[], tracker?: TokenTracke maxTokens: getMaxTokens('errorAnalyzer') }); object = result.object; - tokens = result.usage?.totalTokens || 0; + usage = result.usage; } catch (error) { const result = await handleGenerateObjectError(error); object = result.object; - tokens = result.totalTokens; + usage = result.usage; } console.log('Error analysis:', { is_valid: !object.blame, reason: object.blame || 'No issues found' }); - (tracker || new TokenTracker()).trackUsage('error-analyzer', tokens); - return { response: object, tokens }; + (tracker || new TokenTracker()).trackUsage('error-analyzer', usage); + return {response: object}; } catch (error) { console.error('Error in answer evaluation:', error); throw error; diff --git a/src/tools/evaluator.ts b/src/tools/evaluator.ts index a20ea52..7f2f655 100644 --- a/src/tools/evaluator.ts +++ b/src/tools/evaluator.ts @@ -1,5 +1,5 @@ import {z} from 'zod'; -import {generateObject} from 'ai'; +import {generateObject, GenerateObjectResult} from 'ai'; import {getModel, getMaxTokens} from "../config"; import {TokenTracker} from "../utils/token-tracker"; import {AnswerAction, EvaluationResponse} from '../types'; @@ -383,7 +383,7 @@ export async function evaluateQuestion( maxTokens: getMaxTokens('evaluator') }); - (tracker || new TokenTracker()).trackUsage('evaluator', result.usage?.totalTokens || 0); + (tracker || new TokenTracker()).trackUsage('evaluator', result.usage); console.log('Question Evaluation:', result.object); // Always include definitive in types @@ -397,7 +397,7 @@ export async function evaluateQuestion( return types; } catch (error) { const errorResult = await handleGenerateObjectError(error); - (tracker || new TokenTracker()).trackUsage('evaluator', errorResult.totalTokens || 0); + (tracker || new TokenTracker()).trackUsage('evaluator', errorResult.usage); return ['definitive', 'freshness', 'plurality']; } } @@ -413,7 +413,7 @@ async function performEvaluation( maxTokens: number; }, tracker?: TokenTracker -): Promise { +): Promise> { const result = await generateObject({ model: params.model, schema: params.schema, @@ -421,18 +421,12 @@ async function performEvaluation( maxTokens: params.maxTokens }); - (tracker || new TokenTracker()).trackUsage('evaluator', result.usage?.totalTokens || 0); + (tracker || new TokenTracker()).trackUsage('evaluator', result.usage); console.log(`${evaluationType} Evaluation:`, result.object); return result; } -interface GenerateObjectResult { - object: EvaluationResponse; - usage?: { - totalTokens: number; - }; -} // Main evaluation function export async function evaluateAnswer( @@ -441,7 +435,7 @@ export async function evaluateAnswer( evaluationOrder: EvaluationType[] = ['definitive', 'freshness', 'plurality'], tracker?: TokenTracker ): Promise<{ response: EvaluationResponse }> { - let result: GenerateObjectResult; + let result; // Only add attribution if we have valid references if (action.references && action.references.length > 0) { @@ -525,7 +519,7 @@ export async function evaluateAnswer( } } catch (error) { const errorResult = await handleGenerateObjectError(error); - (tracker || new TokenTracker()).trackUsage('evaluator', errorResult.totalTokens || 0); + (tracker || new TokenTracker()).trackUsage('evaluator', errorResult.usage); return {response: errorResult.object}; } } diff --git a/src/tools/jina-dedup.ts b/src/tools/jina-dedup.ts index 5b051a0..f78fb43 100644 --- a/src/tools/jina-dedup.ts +++ b/src/tools/jina-dedup.ts @@ -1,5 +1,5 @@ -import axios, { AxiosError } from 'axios'; -import { TokenTracker } from "../utils/token-tracker"; +import axios, {AxiosError} from 'axios'; +import {TokenTracker} from "../utils/token-tracker"; import {JINA_API_KEY} from "../config"; const JINA_API_URL = 'https://api.jina.ai/v1/embeddings'; @@ -107,27 +107,23 @@ export async function dedupQueries( newQueries: string[], existingQueries: string[], tracker?: TokenTracker -): Promise<{ unique_queries: string[], tokens: number }> { +): Promise<{ unique_queries: string[] }> { try { // Quick return for single new query with no existing queries if (newQueries.length === 1 && existingQueries.length === 0) { - console.log('Dedup (quick return):', newQueries); return { unique_queries: newQueries, - tokens: 0 // No tokens used since we didn't call the API }; } // Get embeddings for all queries in one batch const allQueries = [...newQueries, ...existingQueries]; - const { embeddings: allEmbeddings, tokens } = await getEmbeddings(allQueries); + const {embeddings: allEmbeddings, tokens} = await getEmbeddings(allQueries); // If embeddings is empty (due to 402 error), return all new queries if (!allEmbeddings.length) { - console.log('Dedup (no embeddings):', newQueries); return { unique_queries: newQueries, - tokens: 0 }; } @@ -170,11 +166,14 @@ export async function dedupQueries( } // Track token usage from the API - (tracker || new TokenTracker()).trackUsage('dedup', tokens); + (tracker || new TokenTracker()).trackUsage('dedup', { + promptTokens: tokens, + completionTokens: 0, + totalTokens: tokens + }); console.log('Dedup:', uniqueQueries); return { unique_queries: uniqueQueries, - tokens }; } catch (error) { console.error('Error in deduplication analysis:', error); diff --git a/src/tools/jina-search.ts b/src/tools/jina-search.ts index 3e259cc..555c504 100644 --- a/src/tools/jina-search.ts +++ b/src/tools/jina-search.ts @@ -3,7 +3,7 @@ import { TokenTracker } from "../utils/token-tracker"; import { SearchResponse } from '../types'; import { JINA_API_KEY } from "../config"; -export function search(query: string, tracker?: TokenTracker): Promise<{ response: SearchResponse, tokens: number }> { +export function search(query: string, tracker?: TokenTracker): Promise<{ response: SearchResponse}> { return new Promise((resolve, reject) => { if (!query.trim()) { reject(new Error('Query cannot be empty')); @@ -63,9 +63,13 @@ export function search(query: string, tracker?: TokenTracker): Promise<{ respons console.log('Total URLs:', response.data.length); const tokenTracker = tracker || new TokenTracker(); - tokenTracker.trackUsage('search', totalTokens); + tokenTracker.trackUsage('search', { + totalTokens, + promptTokens: query.length, + completionTokens: totalTokens + }); - resolve({ response, tokens: totalTokens }); + resolve({ response }); }); }); diff --git a/src/tools/query-rewriter.ts b/src/tools/query-rewriter.ts index 70a7ed0..b9c9c2b 100644 --- a/src/tools/query-rewriter.ts +++ b/src/tools/query-rewriter.ts @@ -93,11 +93,11 @@ Intention: ${action.think} `; } -export async function rewriteQuery(action: SearchAction, tracker?: TokenTracker): Promise<{ queries: string[], tokens: number }> { +export async function rewriteQuery(action: SearchAction, tracker?: TokenTracker): Promise<{ queries: string[] }> { try { const prompt = getPrompt(action); let object; - let tokens = 0; + let usage; try { const result = await generateObject({ model, @@ -106,15 +106,15 @@ export async function rewriteQuery(action: SearchAction, tracker?: TokenTracker) maxTokens: getMaxTokens('queryRewriter') }); object = result.object; - tokens = result.usage?.totalTokens || 0; + usage = result.usage; } catch (error) { const result = await handleGenerateObjectError(error); object = result.object; - tokens = result.totalTokens; + usage = result.usage; } console.log('Query rewriter:', object.queries); - (tracker || new TokenTracker()).trackUsage('query-rewriter', tokens); - return { queries: object.queries, tokens }; + (tracker || new TokenTracker()).trackUsage('query-rewriter', usage); + return { queries: object.queries }; } catch (error) { console.error('Error in query rewriting:', error); throw error; diff --git a/src/tools/read.ts b/src/tools/read.ts index 5ccaaf4..fe076be 100644 --- a/src/tools/read.ts +++ b/src/tools/read.ts @@ -3,7 +3,7 @@ import { TokenTracker } from "../utils/token-tracker"; import { ReadResponse } from '../types'; import { JINA_API_KEY } from "../config"; -export function readUrl(url: string, tracker?: TokenTracker): Promise<{ response: ReadResponse, tokens: number }> { +export function readUrl(url: string, tracker?: TokenTracker): Promise<{ response: ReadResponse }> { return new Promise((resolve, reject) => { if (!url.trim()) { reject(new Error('URL cannot be empty')); @@ -72,9 +72,13 @@ export function readUrl(url: string, tracker?: TokenTracker): Promise<{ response const tokens = response.data.usage?.tokens || 0; const tokenTracker = tracker || new TokenTracker(); - tokenTracker.trackUsage('read', tokens); + tokenTracker.trackUsage('read', { + totalTokens: tokens, + promptTokens: url.length, + completionTokens: tokens + }); - resolve({ response, tokens }); + resolve({ response }); }); }); diff --git a/src/types.ts b/src/types.ts index df1ad43..f1c8d1d 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,4 +1,6 @@ // Action Types +import {LanguageModelUsage} from "ai"; + type BaseAction = { action: "search" | "answer" | "reflect" | "visit"; think: string; @@ -30,25 +32,11 @@ export type VisitAction = BaseAction & { export type StepAction = SearchAction | AnswerAction | ReflectAction | VisitAction; -// Response Types -export const TOKEN_CATEGORIES = { - PROMPT: 'prompt', - REASONING: 'reasoning', - ACCEPTED: 'accepted', - REJECTED: 'rejected' -} as const; - -export type TokenCategory = typeof TOKEN_CATEGORIES[keyof typeof TOKEN_CATEGORIES]; // Following Vercel AI SDK's token counting interface export interface TokenUsage { tool: string; - tokens: number; - category?: TokenCategory; - // Following Vercel AI SDK's token counting interface - prompt_tokens?: number; - completion_tokens?: number; - total_tokens?: number; + usage: LanguageModelUsage; } export interface SearchResponse { diff --git a/src/utils/error-handling.ts b/src/utils/error-handling.ts index aed77e0..cde2a25 100644 --- a/src/utils/error-handling.ts +++ b/src/utils/error-handling.ts @@ -1,8 +1,8 @@ -import {NoObjectGeneratedError} from "ai"; +import {LanguageModelUsage, NoObjectGeneratedError} from "ai"; export interface GenerateObjectResult { object: T; - totalTokens: number; + usage: LanguageModelUsage; } export async function handleGenerateObjectError(error: unknown): Promise> { @@ -12,7 +12,7 @@ export async function handleGenerateObjectError(error: unknown): Promise this.budget) { - console.error(`Token budget exceeded: ${currentTotal + tokens} > ${this.budget}`); - } - // Only track usage if we're within budget - if (!this.budget || currentTotal + tokens <= this.budget) { - const usage = { tool, tokens, category }; - this.usages.push(usage); - this.emit('usage', usage); - } + trackUsage(tool: string, usage: LanguageModelUsage) { + const u = {tool, usage}; + this.usages.push(u); + this.emit('usage', usage); } - getTotalUsage(): number { - return this.usages.reduce((sum, usage) => sum + usage.tokens, 0); + getTotalUsage(): LanguageModelUsage { + return this.usages.reduce((acc, {usage}) => { + acc.promptTokens += usage.promptTokens; + acc.completionTokens += usage.completionTokens; + acc.totalTokens += usage.totalTokens; + return acc; + }, {promptTokens: 0, completionTokens: 0, totalTokens: 0}); + } + + getTotalUsageSnakeCase(): {prompt_tokens: number, completion_tokens: number, total_tokens: number} { + return this.usages.reduce((acc, {usage}) => { + acc.prompt_tokens += usage.promptTokens; + acc.completion_tokens += usage.completionTokens; + acc.total_tokens += usage.totalTokens; + return acc; + }, {prompt_tokens: 0, completion_tokens: 0, total_tokens: 0}); } getUsageBreakdown(): Record { - return this.usages.reduce((acc, { tool, tokens }) => { - acc[tool] = (acc[tool] || 0) + tokens; + return this.usages.reduce((acc, {tool, usage}) => { + acc[tool] = (acc[tool] || 0) + usage.totalTokens; return acc; }, {} as Record); } - getUsageDetails(): { - prompt_tokens: number; - completion_tokens: number; - total_tokens: number; - completion_tokens_details?: { - reasoning_tokens: number; - accepted_prediction_tokens: number; - rejected_prediction_tokens: number; - }; - } { - const categoryBreakdown = this.usages.reduce((acc, { tokens, category }) => { - if (category) { - acc[category] = (acc[category] || 0) + tokens; - } - return acc; - }, {} as Record); - - const prompt_tokens = categoryBreakdown.prompt || 0; - const completion_tokens = - (categoryBreakdown.reasoning || 0) + - (categoryBreakdown.accepted || 0) + - (categoryBreakdown.rejected || 0); - - return { - prompt_tokens, - completion_tokens, - total_tokens: prompt_tokens + completion_tokens, - completion_tokens_details: { - reasoning_tokens: categoryBreakdown.reasoning || 0, - accepted_prediction_tokens: categoryBreakdown.accepted || 0, - rejected_prediction_tokens: categoryBreakdown.rejected || 0 - } - }; - } printSummary() { const breakdown = this.getUsageBreakdown(); console.log('Token Usage Summary:', { + budget: this.budget, total: this.getTotalUsage(), breakdown });