fix: use character-based token estimation and simplify token tracking

Co-Authored-By: sha.zhou@jina.ai <sha.zhou@jina.ai>
This commit is contained in:
Devin AI
2025-02-11 11:29:14 +00:00
parent 8af35c6640
commit baf5263146
6 changed files with 147 additions and 69 deletions

View File

@@ -11,6 +11,9 @@ describe('/v1/chat/completions', () => {
beforeEach(async () => {
// Set NODE_ENV to test to prevent server from auto-starting
process.env.NODE_ENV = 'test';
process.env.LLM_PROVIDER = 'openai'; // Use OpenAI for testing
process.env.OPENAI_API_KEY = 'test-openai-key';
process.env.JINA_API_KEY = 'test-jina-key';
// Clean up any existing secret
const existingSecretIndex = process.argv.findIndex(arg => arg.startsWith('--secret='));
@@ -32,11 +35,14 @@ describe('/v1/chat/completions', () => {
emitter.removeAllListeners();
emitter.setMaxListeners(emitter.getMaxListeners() + 1);
// Clean up test secret
// Clean up test secret and environment variables
const secretIndex = process.argv.findIndex(arg => arg.startsWith('--secret='));
if (secretIndex !== -1) {
process.argv.splice(secretIndex, 1);
}
delete process.env.LLM_PROVIDER;
delete process.env.GEMINI_API_KEY;
delete process.env.JINA_API_KEY;
// Wait for any pending promises to settle
await new Promise(resolve => setTimeout(resolve, 500));
@@ -258,12 +264,7 @@ describe('/v1/chat/completions', () => {
expect(validResponse.body.usage).toMatchObject({
prompt_tokens: expect.any(Number),
completion_tokens: expect.any(Number),
total_tokens: expect.any(Number),
completion_tokens_details: {
reasoning_tokens: expect.any(Number),
accepted_prediction_tokens: expect.any(Number),
rejected_prediction_tokens: expect.any(Number)
}
total_tokens: expect.any(Number)
});
// Verify token counts are reasonable
@@ -274,34 +275,102 @@ describe('/v1/chat/completions', () => {
);
});
it('should provide token usage in Vercel AI SDK format', async () => {
it('should provide accurate token counts for various message lengths', async () => {
  // Three user messages of strictly increasing size; the server's
  // chars/4 estimation should produce strictly increasing prompt tokens.
  const shortMessage = 'test';
  const mediumMessage = 'This is a medium length message that should have more tokens than the short message.';
  const longMessage = 'This is a very long message that should have many more tokens. '.repeat(10);

  // Local helper: one chat-completion request for a single user message.
  const postChat = (content: string) =>
    request(app)
      .post('/v1/chat/completions')
      .set('Authorization', `Bearer ${TEST_SECRET}`)
      .send({
        model: 'test-model',
        messages: [{ role: 'user', content }]
      });

  // Issue the three requests sequentially, as the original test did.
  const shortResponse = await postChat(shortMessage);
  const mediumResponse = await postChat(mediumMessage);
  const longResponse = await postChat(longMessage);
  const responses = [shortResponse, mediumResponse, longResponse];

  // Every response must succeed and expose the usage triple.
  for (const response of responses) {
    expect(response.status).toBe(200);
    expect(response.body.usage).toMatchObject({
      prompt_tokens: expect.any(Number),
      completion_tokens: expect.any(Number),
      total_tokens: expect.any(Number)
    });
  }

  // Prompt token counts must grow with message length.
  const [shortTokens, mediumTokens, longTokens] =
    responses.map(response => response.body.usage.prompt_tokens);
  expect(mediumTokens).toBeGreaterThan(shortTokens);
  expect(longTokens).toBeGreaterThan(mediumTokens);

  // Each count must equal the server-side estimate: ceil(utf-8 bytes / 4).
  const estimationCases: Array<{ content: string; tokens: number }> = [
    { content: shortMessage, tokens: shortTokens },
    { content: mediumMessage, tokens: mediumTokens },
    { content: longMessage, tokens: longTokens }
  ];
  for (const { content, tokens } of estimationCases) {
    expect(tokens).toBe(Math.ceil(Buffer.byteLength(content, 'utf-8') / 4));
  }

  // total_tokens must be the sum of prompt and completion tokens.
  for (const { body } of responses) {
    expect(body.usage.total_tokens).toBe(
      body.usage.prompt_tokens + body.usage.completion_tokens
    );
  }
});
it('should count tokens correctly for multiple messages', async () => {
const messages = [
{ role: 'system', content: 'You are a helpful assistant.' },
{ role: 'user', content: 'Hello!' },
{ role: 'assistant', content: 'Hi there! How can I help you?' },
{ role: 'user', content: 'What is the weather?' }
];
const response = await request(app)
.post('/v1/chat/completions')
.set('Authorization', `Bearer ${TEST_SECRET}`)
.send({
model: 'test-model',
messages: [{ role: 'user', content: 'test' }]
messages
});
expect(response.status).toBe(200);
const usage = response.body.usage;
expect(usage).toMatchObject({
expect(response.status).toBe(200);
expect(response.body.usage).toMatchObject({
prompt_tokens: expect.any(Number),
completion_tokens: expect.any(Number),
total_tokens: expect.any(Number),
completion_tokens_details: {
reasoning_tokens: expect.any(Number),
accepted_prediction_tokens: expect.any(Number),
rejected_prediction_tokens: expect.any(Number)
}
total_tokens: expect.any(Number)
});
// Verify token counts are reasonable
expect(usage.prompt_tokens).toBeGreaterThan(0);
expect(usage.completion_tokens).toBeGreaterThan(0);
expect(usage.total_tokens).toBe(
usage.prompt_tokens + usage.completion_tokens
);
// Verify token count matches our estimation for all messages combined
const expectedPromptTokens = messages.reduce((total, msg) => {
return total + Math.ceil(Buffer.byteLength(msg.content, 'utf-8') / 4);
}, 0);
expect(response.body.usage.prompt_tokens).toBe(expectedPromptTokens);
});
});

View File

@@ -205,10 +205,13 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
actionTracker: new ActionTracker()
};
// Track prompt tokens for the initial message
// Use Vercel's token counting convention - 1 token per message
const messageTokens = body.messages.length;
context.tokenTracker.trackUsage('agent', messageTokens, TOKEN_CATEGORIES.PROMPT);
// Track prompt tokens for each message using actual content length
for (const message of body.messages) {
// Estimate tokens using character count / 4 as a rough approximation
// This will be replaced with actual Gemini tokenizer in a future update
const estimatedTokens = Math.ceil(Buffer.byteLength(message.content, 'utf-8') / 4);
context.tokenTracker.trackUsage('agent', estimatedTokens);
}
// Add this inside the chat completions endpoint, before setting up the action listener
const streamingState: StreamingState = {
@@ -328,13 +331,14 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
// Track tokens based on action type
if (result.action === 'answer') {
// Track accepted prediction tokens for the final answer using Vercel's convention
const answerTokens = 1; // Default to 1 token per answer
context.tokenTracker.trackUsage('evaluator', answerTokens, TOKEN_CATEGORIES.ACCEPTED);
// Track tokens for the final answer using content length estimation
const content = result.action === 'answer' ? buildMdFromAnswer(result) : result.think;
const estimatedTokens = Math.ceil(Buffer.byteLength(content, 'utf-8') / 4);
context.tokenTracker.trackUsage('evaluator', estimatedTokens);
} else {
// Track rejected prediction tokens for non-answer responses
const rejectedTokens = 1; // Default to 1 token per rejected response
context.tokenTracker.trackUsage('evaluator', rejectedTokens, TOKEN_CATEGORIES.REJECTED);
// Track tokens for non-answer responses using content length estimation
const estimatedTokens = Math.ceil(Buffer.byteLength(result.think, 'utf-8') / 4);
context.tokenTracker.trackUsage('evaluator', estimatedTokens);
}
if (body.stream) {
@@ -412,11 +416,10 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
requestId
});
// Track error as rejected tokens with Vercel token counting
// Track error tokens using content length estimation
const errorMessage = error?.message || 'An error occurred';
// Default to 1 token for errors as per Vercel AI SDK convention
const errorTokens = 1;
context.tokenTracker.trackUsage('evaluator', errorTokens, TOKEN_CATEGORIES.REJECTED);
const estimatedTokens = Math.ceil(Buffer.byteLength(errorMessage, 'utf-8') / 4);
context.tokenTracker.trackUsage('evaluator', estimatedTokens);
// Clean up event listeners
context.actionTracker.removeAllListeners('action');

View File

@@ -381,7 +381,9 @@ export async function evaluateQuestion(
maxTokens: getMaxTokens('evaluator')
});
(tracker || new TokenTracker()).trackUsage('evaluator', result.usage?.totalTokens || 0);
if (tracker) {
tracker.trackUsage('evaluator', result.usage?.totalTokens || 0);
}
console.log('Question Evaluation:', result.object);
// Always include definitive in types
@@ -419,7 +421,9 @@ async function performEvaluation(
maxTokens: params.maxTokens
});
(tracker || new TokenTracker()).trackUsage('evaluator', result.usage?.totalTokens || 0);
if (tracker) {
tracker.trackUsage('evaluator', result.usage?.totalTokens || 0);
}
console.log(`${evaluationType} Evaluation:`, result.object);
return result;
@@ -523,7 +527,9 @@ export async function evaluateAnswer(
}
} catch (error) {
const errorResult = await handleGenerateObjectError<EvaluationResponse>(error);
(tracker || new TokenTracker()).trackUsage('evaluator', errorResult.totalTokens || 0);
if (tracker) {
tracker.trackUsage('evaluator', errorResult.totalTokens || 0);
}
return {response: errorResult.object};
}
}
@@ -557,4 +563,4 @@ async function fetchSourceContent(urls: string[], tracker?: TokenTracker): Promi
console.error('Error fetching source content:', error);
return '';
}
}
}

View File

@@ -62,8 +62,9 @@ export function search(query: string, tracker?: TokenTracker): Promise<{ respons
const totalTokens = response.data.reduce((sum, item) => sum + (item.usage?.tokens || 0), 0);
console.log('Total URLs:', response.data.length);
const tokenTracker = tracker || new TokenTracker();
tokenTracker.trackUsage('search', totalTokens);
if (tracker) {
tracker.trackUsage('search', totalTokens);
}
resolve({ response, tokens: totalTokens });
});
@@ -81,4 +82,4 @@ export function search(query: string, tracker?: TokenTracker): Promise<{ respons
req.end();
});
}
}

View File

@@ -70,11 +70,20 @@ export function readUrl(url: string, tracker?: TokenTracker): Promise<{ response
tokens: response.data.usage?.tokens || 0
});
const tokens = response.data.usage?.tokens || 0;
const tokenTracker = tracker || new TokenTracker();
tokenTracker.trackUsage('read', tokens);
const apiTokens = response.data.usage?.tokens || 0;
if (tracker) {
// Track API response tokens
tracker.trackUsage('read_api', apiTokens);
// Track content length tokens using the same estimation method
if (response.data.content) {
const contentTokens = Math.ceil(Buffer.byteLength(response.data.content, 'utf-8') / 4);
tracker.trackUsage('read_content', contentTokens);
}
}
resolve({ response, tokens });
resolve({ response, tokens: apiTokens });
});
});
@@ -95,4 +104,4 @@ export function readUrl(url: string, tracker?: TokenTracker): Promise<{ response
/**
 * Collapse every line break (CRLF, CR, or LF) in `text` into a single space.
 * A Windows "\r\n" pair counts as one break and yields one space.
 */
export function removeAllLineBreaks(text: string) {
  // Alternation order matters: match "\r\n" as a unit before bare "\r"/"\n".
  return text.split(/\r\n|\r|\n/).join(' ');
}
}

View File

@@ -30,6 +30,8 @@ export class TokenTracker extends EventEmitter {
if (!this.budget || currentTotal + tokens <= this.budget) {
const usage = { tool, tokens, category };
this.usages.push(usage);
console.log(`[TokenTracker] Adding ${tokens} tokens from ${tool}${category ? ` (${category})` : ''}`);
console.log(`[TokenTracker] New total: ${this.getTotalUsage()}`);
this.emit('usage', usage);
}
}
@@ -55,28 +57,16 @@ export class TokenTracker extends EventEmitter {
rejected_prediction_tokens: number;
};
} {
const categoryBreakdown = this.usages.reduce((acc, { tokens, category }) => {
if (category) {
acc[category] = (acc[category] || 0) + tokens;
}
return acc;
}, {} as Record<string, number>);
const prompt_tokens = categoryBreakdown.prompt || 0;
const completion_tokens =
(categoryBreakdown.reasoning || 0) +
(categoryBreakdown.accepted || 0) +
(categoryBreakdown.rejected || 0);
const toolBreakdown = this.getUsageBreakdown();
const prompt_tokens = toolBreakdown.agent || 0;
const completion_tokens = Object.entries(toolBreakdown)
.filter(([tool]) => tool !== 'agent')
.reduce((sum, [_, tokens]) => sum + tokens, 0);
return {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
completion_tokens_details: {
reasoning_tokens: categoryBreakdown.reasoning || 0,
accepted_prediction_tokens: categoryBreakdown.accepted || 0,
rejected_prediction_tokens: categoryBreakdown.rejected || 0
}
total_tokens: prompt_tokens + completion_tokens
};
}