fix: use character-based token estimation and simplify token tracking

Co-Authored-By: sha.zhou@jina.ai <sha.zhou@jina.ai>
This commit is contained in:
Devin AI
2025-02-11 11:29:14 +00:00
parent 8af35c6640
commit baf5263146
6 changed files with 147 additions and 69 deletions

View File

@@ -11,6 +11,9 @@ describe('/v1/chat/completions', () => {
beforeEach(async () => { beforeEach(async () => {
// Set NODE_ENV to test to prevent server from auto-starting // Set NODE_ENV to test to prevent server from auto-starting
process.env.NODE_ENV = 'test'; process.env.NODE_ENV = 'test';
process.env.LLM_PROVIDER = 'openai'; // Use OpenAI for testing
process.env.OPENAI_API_KEY = 'test-openai-key';
process.env.JINA_API_KEY = 'test-jina-key';
// Clean up any existing secret // Clean up any existing secret
const existingSecretIndex = process.argv.findIndex(arg => arg.startsWith('--secret=')); const existingSecretIndex = process.argv.findIndex(arg => arg.startsWith('--secret='));
@@ -32,11 +35,14 @@ describe('/v1/chat/completions', () => {
emitter.removeAllListeners(); emitter.removeAllListeners();
emitter.setMaxListeners(emitter.getMaxListeners() + 1); emitter.setMaxListeners(emitter.getMaxListeners() + 1);
// Clean up test secret // Clean up test secret and environment variables
const secretIndex = process.argv.findIndex(arg => arg.startsWith('--secret=')); const secretIndex = process.argv.findIndex(arg => arg.startsWith('--secret='));
if (secretIndex !== -1) { if (secretIndex !== -1) {
process.argv.splice(secretIndex, 1); process.argv.splice(secretIndex, 1);
} }
delete process.env.LLM_PROVIDER;
delete process.env.GEMINI_API_KEY;
delete process.env.JINA_API_KEY;
// Wait for any pending promises to settle // Wait for any pending promises to settle
await new Promise(resolve => setTimeout(resolve, 500)); await new Promise(resolve => setTimeout(resolve, 500));
@@ -258,12 +264,7 @@ describe('/v1/chat/completions', () => {
expect(validResponse.body.usage).toMatchObject({ expect(validResponse.body.usage).toMatchObject({
prompt_tokens: expect.any(Number), prompt_tokens: expect.any(Number),
completion_tokens: expect.any(Number), completion_tokens: expect.any(Number),
total_tokens: expect.any(Number), total_tokens: expect.any(Number)
completion_tokens_details: {
reasoning_tokens: expect.any(Number),
accepted_prediction_tokens: expect.any(Number),
rejected_prediction_tokens: expect.any(Number)
}
}); });
// Verify token counts are reasonable // Verify token counts are reasonable
@@ -274,34 +275,102 @@ describe('/v1/chat/completions', () => {
); );
}); });
it('should provide token usage in Vercel AI SDK format', async () => { it('should provide accurate token counts for various message lengths', async () => {
const shortMessage = 'test';
const mediumMessage = 'This is a medium length message that should have more tokens than the short message.';
const longMessage = 'This is a very long message that should have many more tokens. '.repeat(10);
// Test short message
const shortResponse = await request(app)
.post('/v1/chat/completions')
.set('Authorization', `Bearer ${TEST_SECRET}`)
.send({
model: 'test-model',
messages: [{ role: 'user', content: shortMessage }]
});
// Test medium message
const mediumResponse = await request(app)
.post('/v1/chat/completions')
.set('Authorization', `Bearer ${TEST_SECRET}`)
.send({
model: 'test-model',
messages: [{ role: 'user', content: mediumMessage }]
});
// Test long message
const longResponse = await request(app)
.post('/v1/chat/completions')
.set('Authorization', `Bearer ${TEST_SECRET}`)
.send({
model: 'test-model',
messages: [{ role: 'user', content: longMessage }]
});
// Verify response format
[shortResponse, mediumResponse, longResponse].forEach(response => {
expect(response.status).toBe(200);
expect(response.body.usage).toMatchObject({
prompt_tokens: expect.any(Number),
completion_tokens: expect.any(Number),
total_tokens: expect.any(Number)
});
});
// Verify token counts increase with message length
const shortTokens = shortResponse.body.usage.prompt_tokens;
const mediumTokens = mediumResponse.body.usage.prompt_tokens;
const longTokens = longResponse.body.usage.prompt_tokens;
expect(mediumTokens).toBeGreaterThan(shortTokens);
expect(longTokens).toBeGreaterThan(mediumTokens);
// Verify token counts match our estimation (chars/4)
[
{ content: shortMessage, tokens: shortTokens },
{ content: mediumMessage, tokens: mediumTokens },
{ content: longMessage, tokens: longTokens }
].forEach(({ content, tokens }) => {
const expectedTokens = Math.ceil(Buffer.byteLength(content, 'utf-8') / 4);
expect(tokens).toBe(expectedTokens);
});
// Verify total tokens calculation
[shortResponse, mediumResponse, longResponse].forEach(response => {
expect(response.body.usage.total_tokens).toBe(
response.body.usage.prompt_tokens + response.body.usage.completion_tokens
);
});
});
it('should count tokens correctly for multiple messages', async () => {
const messages = [
{ role: 'system', content: 'You are a helpful assistant.' },
{ role: 'user', content: 'Hello!' },
{ role: 'assistant', content: 'Hi there! How can I help you?' },
{ role: 'user', content: 'What is the weather?' }
];
const response = await request(app) const response = await request(app)
.post('/v1/chat/completions') .post('/v1/chat/completions')
.set('Authorization', `Bearer ${TEST_SECRET}`) .set('Authorization', `Bearer ${TEST_SECRET}`)
.send({ .send({
model: 'test-model', model: 'test-model',
messages: [{ role: 'user', content: 'test' }] messages
}); });
expect(response.status).toBe(200);
const usage = response.body.usage;
expect(usage).toMatchObject({ expect(response.status).toBe(200);
expect(response.body.usage).toMatchObject({
prompt_tokens: expect.any(Number), prompt_tokens: expect.any(Number),
completion_tokens: expect.any(Number), completion_tokens: expect.any(Number),
total_tokens: expect.any(Number), total_tokens: expect.any(Number)
completion_tokens_details: {
reasoning_tokens: expect.any(Number),
accepted_prediction_tokens: expect.any(Number),
rejected_prediction_tokens: expect.any(Number)
}
}); });
// Verify token counts are reasonable // Verify token count matches our estimation for all messages combined
expect(usage.prompt_tokens).toBeGreaterThan(0); const expectedPromptTokens = messages.reduce((total, msg) => {
expect(usage.completion_tokens).toBeGreaterThan(0); return total + Math.ceil(Buffer.byteLength(msg.content, 'utf-8') / 4);
expect(usage.total_tokens).toBe( }, 0);
usage.prompt_tokens + usage.completion_tokens
); expect(response.body.usage.prompt_tokens).toBe(expectedPromptTokens);
}); });
}); });

View File

@@ -205,10 +205,13 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
actionTracker: new ActionTracker() actionTracker: new ActionTracker()
}; };
// Track prompt tokens for the initial message // Track prompt tokens for each message using actual content length
// Use Vercel's token counting convention - 1 token per message for (const message of body.messages) {
const messageTokens = body.messages.length; // Estimate tokens using character count / 4 as a rough approximation
context.tokenTracker.trackUsage('agent', messageTokens, TOKEN_CATEGORIES.PROMPT); // This will be replaced with actual Gemini tokenizer in a future update
const estimatedTokens = Math.ceil(Buffer.byteLength(message.content, 'utf-8') / 4);
context.tokenTracker.trackUsage('agent', estimatedTokens);
}
// Add this inside the chat completions endpoint, before setting up the action listener // Add this inside the chat completions endpoint, before setting up the action listener
const streamingState: StreamingState = { const streamingState: StreamingState = {
@@ -328,13 +331,14 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
// Track tokens based on action type // Track tokens based on action type
if (result.action === 'answer') { if (result.action === 'answer') {
// Track accepted prediction tokens for the final answer using Vercel's convention // Track tokens for the final answer using content length estimation
const answerTokens = 1; // Default to 1 token per answer const content = result.action === 'answer' ? buildMdFromAnswer(result) : result.think;
context.tokenTracker.trackUsage('evaluator', answerTokens, TOKEN_CATEGORIES.ACCEPTED); const estimatedTokens = Math.ceil(Buffer.byteLength(content, 'utf-8') / 4);
context.tokenTracker.trackUsage('evaluator', estimatedTokens);
} else { } else {
// Track rejected prediction tokens for non-answer responses // Track tokens for non-answer responses using content length estimation
const rejectedTokens = 1; // Default to 1 token per rejected response const estimatedTokens = Math.ceil(Buffer.byteLength(result.think, 'utf-8') / 4);
context.tokenTracker.trackUsage('evaluator', rejectedTokens, TOKEN_CATEGORIES.REJECTED); context.tokenTracker.trackUsage('evaluator', estimatedTokens);
} }
if (body.stream) { if (body.stream) {
@@ -412,11 +416,10 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
requestId requestId
}); });
// Track error as rejected tokens with Vercel token counting // Track error tokens using content length estimation
const errorMessage = error?.message || 'An error occurred'; const errorMessage = error?.message || 'An error occurred';
// Default to 1 token for errors as per Vercel AI SDK convention const estimatedTokens = Math.ceil(Buffer.byteLength(errorMessage, 'utf-8') / 4);
const errorTokens = 1; context.tokenTracker.trackUsage('evaluator', estimatedTokens);
context.tokenTracker.trackUsage('evaluator', errorTokens, TOKEN_CATEGORIES.REJECTED);
// Clean up event listeners // Clean up event listeners
context.actionTracker.removeAllListeners('action'); context.actionTracker.removeAllListeners('action');

View File

@@ -381,7 +381,9 @@ export async function evaluateQuestion(
maxTokens: getMaxTokens('evaluator') maxTokens: getMaxTokens('evaluator')
}); });
(tracker || new TokenTracker()).trackUsage('evaluator', result.usage?.totalTokens || 0); if (tracker) {
tracker.trackUsage('evaluator', result.usage?.totalTokens || 0);
}
console.log('Question Evaluation:', result.object); console.log('Question Evaluation:', result.object);
// Always include definitive in types // Always include definitive in types
@@ -419,7 +421,9 @@ async function performEvaluation(
maxTokens: params.maxTokens maxTokens: params.maxTokens
}); });
(tracker || new TokenTracker()).trackUsage('evaluator', result.usage?.totalTokens || 0); if (tracker) {
tracker.trackUsage('evaluator', result.usage?.totalTokens || 0);
}
console.log(`${evaluationType} Evaluation:`, result.object); console.log(`${evaluationType} Evaluation:`, result.object);
return result; return result;
@@ -523,7 +527,9 @@ export async function evaluateAnswer(
} }
} catch (error) { } catch (error) {
const errorResult = await handleGenerateObjectError<EvaluationResponse>(error); const errorResult = await handleGenerateObjectError<EvaluationResponse>(error);
(tracker || new TokenTracker()).trackUsage('evaluator', errorResult.totalTokens || 0); if (tracker) {
tracker.trackUsage('evaluator', errorResult.totalTokens || 0);
}
return {response: errorResult.object}; return {response: errorResult.object};
} }
} }
@@ -557,4 +563,4 @@ async function fetchSourceContent(urls: string[], tracker?: TokenTracker): Promi
console.error('Error fetching source content:', error); console.error('Error fetching source content:', error);
return ''; return '';
} }
} }

View File

@@ -62,8 +62,9 @@ export function search(query: string, tracker?: TokenTracker): Promise<{ respons
const totalTokens = response.data.reduce((sum, item) => sum + (item.usage?.tokens || 0), 0); const totalTokens = response.data.reduce((sum, item) => sum + (item.usage?.tokens || 0), 0);
console.log('Total URLs:', response.data.length); console.log('Total URLs:', response.data.length);
const tokenTracker = tracker || new TokenTracker(); if (tracker) {
tokenTracker.trackUsage('search', totalTokens); tracker.trackUsage('search', totalTokens);
}
resolve({ response, tokens: totalTokens }); resolve({ response, tokens: totalTokens });
}); });
@@ -81,4 +82,4 @@ export function search(query: string, tracker?: TokenTracker): Promise<{ respons
req.end(); req.end();
}); });
} }

View File

@@ -70,11 +70,20 @@ export function readUrl(url: string, tracker?: TokenTracker): Promise<{ response
tokens: response.data.usage?.tokens || 0 tokens: response.data.usage?.tokens || 0
}); });
const tokens = response.data.usage?.tokens || 0; const apiTokens = response.data.usage?.tokens || 0;
const tokenTracker = tracker || new TokenTracker();
tokenTracker.trackUsage('read', tokens); if (tracker) {
// Track API response tokens
tracker.trackUsage('read_api', apiTokens);
// Track content length tokens using the same estimation method
if (response.data.content) {
const contentTokens = Math.ceil(Buffer.byteLength(response.data.content, 'utf-8') / 4);
tracker.trackUsage('read_content', contentTokens);
}
}
resolve({ response, tokens }); resolve({ response, tokens: apiTokens });
}); });
}); });
@@ -95,4 +104,4 @@ export function readUrl(url: string, tracker?: TokenTracker): Promise<{ response
export function removeAllLineBreaks(text: string) { export function removeAllLineBreaks(text: string) {
return text.replace(/(\r\n|\n|\r)/gm, " "); return text.replace(/(\r\n|\n|\r)/gm, " ");
} }

View File

@@ -30,6 +30,8 @@ export class TokenTracker extends EventEmitter {
if (!this.budget || currentTotal + tokens <= this.budget) { if (!this.budget || currentTotal + tokens <= this.budget) {
const usage = { tool, tokens, category }; const usage = { tool, tokens, category };
this.usages.push(usage); this.usages.push(usage);
console.log(`[TokenTracker] Adding ${tokens} tokens from ${tool}${category ? ` (${category})` : ''}`);
console.log(`[TokenTracker] New total: ${this.getTotalUsage()}`);
this.emit('usage', usage); this.emit('usage', usage);
} }
} }
@@ -55,28 +57,16 @@ export class TokenTracker extends EventEmitter {
rejected_prediction_tokens: number; rejected_prediction_tokens: number;
}; };
} { } {
const categoryBreakdown = this.usages.reduce((acc, { tokens, category }) => { const toolBreakdown = this.getUsageBreakdown();
if (category) { const prompt_tokens = toolBreakdown.agent || 0;
acc[category] = (acc[category] || 0) + tokens; const completion_tokens = Object.entries(toolBreakdown)
} .filter(([tool]) => tool !== 'agent')
return acc; .reduce((sum, [_, tokens]) => sum + tokens, 0);
}, {} as Record<string, number>);
const prompt_tokens = categoryBreakdown.prompt || 0;
const completion_tokens =
(categoryBreakdown.reasoning || 0) +
(categoryBreakdown.accepted || 0) +
(categoryBreakdown.rejected || 0);
return { return {
prompt_tokens, prompt_tokens,
completion_tokens, completion_tokens,
total_tokens: prompt_tokens + completion_tokens, total_tokens: prompt_tokens + completion_tokens
completion_tokens_details: {
reasoning_tokens: categoryBreakdown.reasoning || 0,
accepted_prediction_tokens: categoryBreakdown.accepted || 0,
rejected_prediction_tokens: categoryBreakdown.rejected || 0
}
}; };
} }