diff --git a/src/__tests__/server.test.ts b/src/__tests__/server.test.ts
index 312a1b9..75b3cc4 100644
--- a/src/__tests__/server.test.ts
+++ b/src/__tests__/server.test.ts
@@ -11,6 +11,9 @@ describe('/v1/chat/completions', () => {
   beforeEach(async () => {
     // Set NODE_ENV to test to prevent server from auto-starting
     process.env.NODE_ENV = 'test';
+    process.env.LLM_PROVIDER = 'openai'; // Use OpenAI for testing
+    process.env.OPENAI_API_KEY = 'test-openai-key';
+    process.env.JINA_API_KEY = 'test-jina-key';

     // Clean up any existing secret
     const existingSecretIndex = process.argv.findIndex(arg => arg.startsWith('--secret='));
@@ -32,11 +35,14 @@ describe('/v1/chat/completions', () => {
     emitter.removeAllListeners();
     emitter.setMaxListeners(emitter.getMaxListeners() + 1);

-    // Clean up test secret
+    // Clean up test secret and environment variables
     const secretIndex = process.argv.findIndex(arg => arg.startsWith('--secret='));
     if (secretIndex !== -1) {
       process.argv.splice(secretIndex, 1);
     }
+    delete process.env.LLM_PROVIDER;
+    delete process.env.OPENAI_API_KEY;
+    delete process.env.JINA_API_KEY;

     // Wait for any pending promises to settle
     await new Promise(resolve => setTimeout(resolve, 500));
@@ -258,12 +264,7 @@ describe('/v1/chat/completions', () => {
     expect(validResponse.body.usage).toMatchObject({
       prompt_tokens: expect.any(Number),
       completion_tokens: expect.any(Number),
-      total_tokens: expect.any(Number),
-      completion_tokens_details: {
-        reasoning_tokens: expect.any(Number),
-        accepted_prediction_tokens: expect.any(Number),
-        rejected_prediction_tokens: expect.any(Number)
-      }
+      total_tokens: expect.any(Number)
     });

     // Verify token counts are reasonable
@@ -274,34 +275,102 @@ describe('/v1/chat/completions', () => {
     );
   });

-  it('should provide token usage in Vercel AI SDK format', async () => {
+  it('should provide accurate token counts for various message lengths', async () => {
+    const shortMessage = 'test';
+    const mediumMessage = 'This is a medium length message that should have more tokens than the short message.';
+    const longMessage = 'This is a very long message that should have many more tokens. '.repeat(10);
+
+    // Test short message
+    const shortResponse = await request(app)
+      .post('/v1/chat/completions')
+      .set('Authorization', `Bearer ${TEST_SECRET}`)
+      .send({
+        model: 'test-model',
+        messages: [{ role: 'user', content: shortMessage }]
+      });
+
+    // Test medium message
+    const mediumResponse = await request(app)
+      .post('/v1/chat/completions')
+      .set('Authorization', `Bearer ${TEST_SECRET}`)
+      .send({
+        model: 'test-model',
+        messages: [{ role: 'user', content: mediumMessage }]
+      });
+
+    // Test long message
+    const longResponse = await request(app)
+      .post('/v1/chat/completions')
+      .set('Authorization', `Bearer ${TEST_SECRET}`)
+      .send({
+        model: 'test-model',
+        messages: [{ role: 'user', content: longMessage }]
+      });
+
+    // Verify response format
+    [shortResponse, mediumResponse, longResponse].forEach(response => {
+      expect(response.status).toBe(200);
+      expect(response.body.usage).toMatchObject({
+        prompt_tokens: expect.any(Number),
+        completion_tokens: expect.any(Number),
+        total_tokens: expect.any(Number)
+      });
+    });
+
+    // Verify token counts increase with message length
+    const shortTokens = shortResponse.body.usage.prompt_tokens;
+    const mediumTokens = mediumResponse.body.usage.prompt_tokens;
+    const longTokens = longResponse.body.usage.prompt_tokens;
+
+    expect(mediumTokens).toBeGreaterThan(shortTokens);
+    expect(longTokens).toBeGreaterThan(mediumTokens);
+
+    // Verify token counts match our estimation (chars/4)
+    [
+      { content: shortMessage, tokens: shortTokens },
+      { content: mediumMessage, tokens: mediumTokens },
+      { content: longMessage, tokens: longTokens }
+    ].forEach(({ content, tokens }) => {
+      const expectedTokens = Math.ceil(Buffer.byteLength(content, 'utf-8') / 4);
+      expect(tokens).toBe(expectedTokens);
+    });
+
+    // Verify total tokens calculation
+    [shortResponse, mediumResponse, longResponse].forEach(response => {
+      expect(response.body.usage.total_tokens).toBe(
+        response.body.usage.prompt_tokens + response.body.usage.completion_tokens
+      );
+    });
+  });
+
+  it('should count tokens correctly for multiple messages', async () => {
+    const messages = [
+      { role: 'system', content: 'You are a helpful assistant.' },
+      { role: 'user', content: 'Hello!' },
+      { role: 'assistant', content: 'Hi there! How can I help you?' },
+      { role: 'user', content: 'What is the weather?' }
+    ];
+
     const response = await request(app)
       .post('/v1/chat/completions')
       .set('Authorization', `Bearer ${TEST_SECRET}`)
       .send({
         model: 'test-model',
-        messages: [{ role: 'user', content: 'test' }]
+        messages
       });
-
-    expect(response.status).toBe(200);
-    const usage = response.body.usage;
-    expect(usage).toMatchObject({
+    expect(response.status).toBe(200);
+    expect(response.body.usage).toMatchObject({
       prompt_tokens: expect.any(Number),
       completion_tokens: expect.any(Number),
-      total_tokens: expect.any(Number),
-      completion_tokens_details: {
-        reasoning_tokens: expect.any(Number),
-        accepted_prediction_tokens: expect.any(Number),
-        rejected_prediction_tokens: expect.any(Number)
-      }
+      total_tokens: expect.any(Number)
     });

-    // Verify token counts are reasonable
-    expect(usage.prompt_tokens).toBeGreaterThan(0);
-    expect(usage.completion_tokens).toBeGreaterThan(0);
-    expect(usage.total_tokens).toBe(
-      usage.prompt_tokens + usage.completion_tokens
-    );
+    // Verify token count matches our estimation for all messages combined
+    const expectedPromptTokens = messages.reduce((total, msg) => {
+      return total + Math.ceil(Buffer.byteLength(msg.content, 'utf-8') / 4);
+    }, 0);

+    expect(response.body.usage.prompt_tokens).toBe(expectedPromptTokens);
   });
 });
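For concreteness, this is the arithmetic the multi-message test above relies on, using the chars/4 estimate introduced in this change (byte lengths worked out by hand, so treat the numbers as illustrative):

    // ceil(Buffer.byteLength(content, 'utf-8') / 4) per message:
    //   'You are a helpful assistant.'    -> 28 bytes -> 7 tokens
    //   'Hello!'                          ->  6 bytes -> 2 tokens
    //   'Hi there! How can I help you?'   -> 29 bytes -> 8 tokens
    //   'What is the weather?'            -> 20 bytes -> 5 tokens
    // expectedPromptTokens = 7 + 2 + 8 + 5 = 22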
diff --git a/src/app.ts b/src/app.ts
index bd74137..fb1159f 100644
--- a/src/app.ts
+++ b/src/app.ts
@@ -205,10 +205,13 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
     actionTracker: new ActionTracker()
   };

-  // Track prompt tokens for the initial message
-  // Use Vercel's token counting convention - 1 token per message
-  const messageTokens = body.messages.length;
-  context.tokenTracker.trackUsage('agent', messageTokens, TOKEN_CATEGORIES.PROMPT);
+  // Track prompt tokens for each message using actual content length
+  for (const message of body.messages) {
+    // Estimate tokens using character count / 4 as a rough approximation
+    // This will be replaced with an actual Gemini tokenizer in a future update
+    const estimatedTokens = Math.ceil(Buffer.byteLength(message.content, 'utf-8') / 4);
+    context.tokenTracker.trackUsage('agent', estimatedTokens);
+  }

   // Add this inside the chat completions endpoint, before setting up the action listener
   const streamingState: StreamingState = {
@@ -328,13 +331,14 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {

     // Track tokens based on action type
     if (result.action === 'answer') {
-      // Track accepted prediction tokens for the final answer using Vercel's convention
-      const answerTokens = 1; // Default to 1 token per answer
-      context.tokenTracker.trackUsage('evaluator', answerTokens, TOKEN_CATEGORIES.ACCEPTED);
+      // Track tokens for the final answer using content length estimation
+      const content = buildMdFromAnswer(result);
+      const estimatedTokens = Math.ceil(Buffer.byteLength(content, 'utf-8') / 4);
+      context.tokenTracker.trackUsage('evaluator', estimatedTokens);
     } else {
-      // Track rejected prediction tokens for non-answer responses
-      const rejectedTokens = 1; // Default to 1 token per rejected response
-      context.tokenTracker.trackUsage('evaluator', rejectedTokens, TOKEN_CATEGORIES.REJECTED);
+      // Track tokens for non-answer responses using content length estimation
+      const estimatedTokens = Math.ceil(Buffer.byteLength(result.think, 'utf-8') / 4);
+      context.tokenTracker.trackUsage('evaluator', estimatedTokens);
     }

     if (body.stream) {
@@ -412,11 +416,10 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
       requestId
     });

-    // Track error as rejected tokens with Vercel token counting
+    // Track error tokens using content length estimation
     const errorMessage = error?.message || 'An error occurred';
-    // Default to 1 token for errors as per Vercel AI SDK convention
-    const errorTokens = 1;
-    context.tokenTracker.trackUsage('evaluator', errorTokens, TOKEN_CATEGORIES.REJECTED);
+    const estimatedTokens = Math.ceil(Buffer.byteLength(errorMessage, 'utf-8') / 4);
+    context.tokenTracker.trackUsage('evaluator', estimatedTokens);

     // Clean up event listeners
     context.actionTracker.removeAllListeners('action');
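The chars/4 estimate is repeated verbatim in several hunks above and below. A minimal sketch of a shared helper, assuming nothing beyond Node's Buffer (the name estimateTokens is hypothetical, not part of this diff):

    // Hypothetical helper; one home for the chars/4 heuristic used in this change.
    // UTF-8 byte length is used so multi-byte characters are not undercounted.
    function estimateTokens(text: string): number {
      return Math.ceil(Buffer.byteLength(text, 'utf-8') / 4);
    }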
diff --git a/src/tools/evaluator.ts b/src/tools/evaluator.ts
index c652f62..553b9bc 100644
--- a/src/tools/evaluator.ts
+++ b/src/tools/evaluator.ts
@@ -381,7 +381,9 @@ export async function evaluateQuestion(
     maxTokens: getMaxTokens('evaluator')
   });

-  (tracker || new TokenTracker()).trackUsage('evaluator', result.usage?.totalTokens || 0);
+  if (tracker) {
+    tracker.trackUsage('evaluator', result.usage?.totalTokens || 0);
+  }
   console.log('Question Evaluation:', result.object);

   // Always include definitive in types
@@ -419,7 +421,9 @@ async function performEvaluation(
     maxTokens: params.maxTokens
   });

-  (tracker || new TokenTracker()).trackUsage('evaluator', result.usage?.totalTokens || 0);
+  if (tracker) {
+    tracker.trackUsage('evaluator', result.usage?.totalTokens || 0);
+  }
   console.log(`${evaluationType} Evaluation:`, result.object);

   return result;
@@ -523,7 +527,9 @@ export async function evaluateAnswer(
     }
   } catch (error) {
     const errorResult = await handleGenerateObjectError(error);
-    (tracker || new TokenTracker()).trackUsage('evaluator', errorResult.totalTokens || 0);
+    if (tracker) {
+      tracker.trackUsage('evaluator', errorResult.totalTokens || 0);
+    }
     return {response: errorResult.object};
   }
 }
@@ -557,4 +563,4 @@ async function fetchSourceContent(urls: string[], tracker?: TokenTracker): Promi
     console.error('Error fetching source content:', error);
     return '';
   }
-}
\ No newline at end of file
+}
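A note on the guard introduced above and repeated in jina-search.ts and read.ts below: the old `(tracker || new TokenTracker()).trackUsage(...)` pattern recorded counts in a throwaway tracker that nothing held a reference to, so the usage was lost immediately. The `if (tracker)` guard makes that no-op explicit. A condensed before/after sketch, with a placeholder token count:

    // Before: counts went into a tracker nobody keeps.
    (tracker || new TokenTracker()).trackUsage('evaluator', 42);

    // After: usage is recorded only when a tracker was supplied.
    if (tracker) {
      tracker.trackUsage('evaluator', 42);
    }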
diff --git a/src/tools/jina-search.ts b/src/tools/jina-search.ts
index 3e259cc..6e96a4f 100644
--- a/src/tools/jina-search.ts
+++ b/src/tools/jina-search.ts
@@ -62,8 +62,9 @@ export function search(query: string, tracker?: TokenTracker): Promise<{ respons
       const totalTokens = response.data.reduce((sum, item) => sum + (item.usage?.tokens || 0), 0);
       console.log('Total URLs:', response.data.length);

-      const tokenTracker = tracker || new TokenTracker();
-      tokenTracker.trackUsage('search', totalTokens);
+      if (tracker) {
+        tracker.trackUsage('search', totalTokens);
+      }
       resolve({ response, tokens: totalTokens });
     });
@@ -81,4 +82,4 @@ export function search(query: string, tracker?: TokenTracker): Promise<{ respons

   req.end();
   });
-}
\ No newline at end of file
+}
diff --git a/src/tools/read.ts b/src/tools/read.ts
index 5ccaaf4..374dbad 100644
--- a/src/tools/read.ts
+++ b/src/tools/read.ts
@@ -70,11 +70,20 @@ export function readUrl(url: string, tracker?: TokenTracker): Promise<{ response
         tokens: response.data.usage?.tokens || 0
       });

-      const tokens = response.data.usage?.tokens || 0;
-      const tokenTracker = tracker || new TokenTracker();
-      tokenTracker.trackUsage('read', tokens);
+      const apiTokens = response.data.usage?.tokens || 0;
+
+      if (tracker) {
+        // Track API response tokens
+        tracker.trackUsage('read_api', apiTokens);
+
+        // Track content length tokens using the same estimation method
+        if (response.data.content) {
+          const contentTokens = Math.ceil(Buffer.byteLength(response.data.content, 'utf-8') / 4);
+          tracker.trackUsage('read_content', contentTokens);
+        }
+      }

-      resolve({ response, tokens });
+      resolve({ response, tokens: apiTokens });
     });
   });
@@ -95,4 +104,4 @@ export function readUrl(url: string, tracker?: TokenTracker): Promise<{ response

 export function removeAllLineBreaks(text: string) {
   return text.replace(/(\r\n|\n|\r)/gm, " ");
-}
\ No newline at end of file
+}
diff --git a/src/utils/token-tracker.ts b/src/utils/token-tracker.ts
index 5168a85..e90a20f 100644
--- a/src/utils/token-tracker.ts
+++ b/src/utils/token-tracker.ts
@@ -30,6 +30,8 @@ export class TokenTracker extends EventEmitter {
     if (!this.budget || currentTotal + tokens <= this.budget) {
       const usage = { tool, tokens, category };
       this.usages.push(usage);
+      console.log(`[TokenTracker] Adding ${tokens} tokens from ${tool}${category ? ` (${category})` : ''}`);
+      console.log(`[TokenTracker] New total: ${this.getTotalUsage()}`);
       this.emit('usage', usage);
     }
   }
@@ -55,28 +57,16 @@
       rejected_prediction_tokens: number;
     };
   } {
-    const categoryBreakdown = this.usages.reduce((acc, { tokens, category }) => {
-      if (category) {
-        acc[category] = (acc[category] || 0) + tokens;
-      }
-      return acc;
-    }, {} as Record<string, number>);
-
-    const prompt_tokens = categoryBreakdown.prompt || 0;
-    const completion_tokens =
-      (categoryBreakdown.reasoning || 0) +
-      (categoryBreakdown.accepted || 0) +
-      (categoryBreakdown.rejected || 0);
+    const toolBreakdown = this.getUsageBreakdown();
+    const prompt_tokens = toolBreakdown.agent || 0;
+    const completion_tokens = Object.entries(toolBreakdown)
+      .filter(([tool]) => tool !== 'agent')
+      .reduce((sum, [_, tokens]) => sum + tokens, 0);

     return {
       prompt_tokens,
       completion_tokens,
-      total_tokens: prompt_tokens + completion_tokens,
-      completion_tokens_details: {
-        reasoning_tokens: categoryBreakdown.reasoning || 0,
-        accepted_prediction_tokens: categoryBreakdown.accepted || 0,
-        rejected_prediction_tokens: categoryBreakdown.rejected || 0
-      }
+      total_tokens: prompt_tokens + completion_tokens
     };
   }
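Taken together, the token-tracker change means prompt tokens are whatever the 'agent' tool recorded and completion tokens are everything else. A small usage sketch against the API visible in this diff (trackUsage and getTotalUsage appear above; the no-argument constructor matches the removed `new TokenTracker()` calls, but the full signature is assumed):

    // Sketch: how the reworked tracker attributes usage.
    const tracker = new TokenTracker();
    tracker.trackUsage('agent', 100);     // counted as prompt_tokens
    tracker.trackUsage('evaluator', 40);  // counted as completion_tokens
    tracker.trackUsage('search', 60);     // counted as completion_tokens
    console.log(tracker.getTotalUsage()); // 200 (prompt 100 + completion 100)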