fix: token tracking

This commit is contained in:
Han Xiao
2025-02-12 11:55:52 +08:00
parent 49217e5ecd
commit 57fb65d0c6
4 changed files with 65 additions and 62 deletions

View File

@@ -91,6 +91,28 @@ interface StreamingState {
processingQueue: boolean; processingQueue: boolean;
} }
/**
 * Maps an OpenAI-style `reasoning_effort` setting (and an optional explicit
 * `max_completion_tokens` override) to the internal token budget and
 * bad-attempt limit used by the agent loop.
 *
 * @param reasoningEffort - 'low' | 'medium' | 'high'; `null`/unknown values
 *   fall back to the 'medium' tier via the switch default.
 * @param maxCompletionTokens - explicit token budget from the request body.
 *   Only a *positive* number overrides the effort-based budget; `null`, 0 or
 *   negative values are ignored so the loop never starts with a zero budget.
 * @returns the token budget and the maximum number of failed attempts allowed.
 */
function getTokenBudgetAndMaxAttempts(
  reasoningEffort: 'low' | 'medium' | 'high' | null = 'medium',
  maxCompletionTokens: number | null = null
): { tokenBudget: number, maxBadAttempts: number } {
  // Honor an explicit budget only when it is usable; a 0/negative budget
  // would make every run fail immediately.
  if (maxCompletionTokens !== null && maxCompletionTokens > 0) {
    return {
      tokenBudget: maxCompletionTokens,
      maxBadAttempts: 3 // Default to medium setting for max attempts
    };
  }
  switch (reasoningEffort) {
    case 'low':
      return {tokenBudget: 500000, maxBadAttempts: 2};
    case 'high':
      return {tokenBudget: 2000000, maxBadAttempts: 4};
    case 'medium':
    default:
      return {tokenBudget: 1000000, maxBadAttempts: 3};
  }
}
async function completeCurrentStreaming( async function completeCurrentStreaming(
streamingState: StreamingState, streamingState: StreamingState,
@@ -190,7 +212,7 @@ async function processQueue(streamingState: StreamingState, res: Response, reque
system_fingerprint: 'fp_' + requestId, system_fingerprint: 'fp_' + requestId,
choices: [{ choices: [{
index: 0, index: 0,
delta: { content: word }, delta: {content: word},
logprobs: null, logprobs: null,
finish_reason: null finish_reason: null
}] }]
@@ -210,7 +232,7 @@ async function processQueue(streamingState: StreamingState, res: Response, reque
system_fingerprint: 'fp_' + requestId, system_fingerprint: 'fp_' + requestId,
choices: [{ choices: [{
index: 0, index: 0,
delta: { content: '\n' }, delta: {content: '\n'},
logprobs: null, logprobs: null,
finish_reason: null finish_reason: null
}] }]
@@ -264,6 +286,11 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
return res.status(400).json({error: 'Last message must be from user'}); return res.status(400).json({error: 'Last message must be from user'});
} }
const {tokenBudget, maxBadAttempts} = getTokenBudgetAndMaxAttempts(
body.reasoning_effort,
body.max_completion_tokens
);
const requestId = Date.now().toString(); const requestId = Date.now().toString();
const created = Math.floor(Date.now() / 1000); const created = Math.floor(Date.now() / 1000);
const context: TrackerContext = { const context: TrackerContext = {
@@ -273,13 +300,13 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
// Add this inside the chat completions endpoint, before setting up the action listener // Add this inside the chat completions endpoint, before setting up the action listener
const streamingState: StreamingState = { const streamingState: StreamingState = {
currentlyStreaming: false, currentlyStreaming: false,
currentGenerator: null, currentGenerator: null,
remainingContent: '', remainingContent: '',
isEmitting: false, isEmitting: false,
queue: [], queue: [],
processingQueue: false processingQueue: false
}; };
if (body.stream) { if (body.stream) {
res.setHeader('Content-Type', 'text/event-stream'); res.setHeader('Content-Type', 'text/event-stream');
@@ -305,19 +332,19 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
// Set up progress listener with cleanup // Set up progress listener with cleanup
const actionListener = async (action: any) => { const actionListener = async (action: any) => {
if (action.thisStep.think) { if (action.thisStep.think) {
// Create a promise that resolves when this content is done streaming // Create a promise that resolves when this content is done streaming
await new Promise<void>(resolve => { await new Promise<void>(resolve => {
streamingState.queue.push({ streamingState.queue.push({
content: action.thisStep.think, content: action.thisStep.think,
resolve resolve
}); });
// Start processing queue if not already processing // Start processing queue if not already processing
processQueue(streamingState, res, requestId, created, body.model); processQueue(streamingState, res, requestId, created, body.model);
}); });
} }
}; };
context.actionTracker.on('action', actionListener); context.actionTracker.on('action', actionListener);
// Make sure to update the cleanup code // Make sure to update the cleanup code
@@ -330,18 +357,7 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
} }
try { try {
let result; const {result} = await getResponse(lastMessage.content, tokenBudget, maxBadAttempts, context)
try {
({result} = await getResponse(lastMessage.content, undefined, undefined, context));
} catch (error: any) {
// If deduplication fails, retry without it
if (error?.response?.status === 402) {
// If deduplication fails, retry with maxBadAttempt=3 to skip dedup
({result} = await getResponse(lastMessage.content, undefined, 3, context));
} else {
throw error;
}
}
const usage = context.tokenTracker.getTotalUsageSnakeCase(); const usage = context.tokenTracker.getTotalUsageSnakeCase();
if (body.stream) { if (body.stream) {
@@ -490,5 +506,4 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
}) as RequestHandler); }) as RequestHandler);
export default app; export default app;

View File

@@ -1,9 +1,9 @@
import { z } from 'zod'; import {z} from 'zod';
import { generateObject } from 'ai'; import {generateObject} from 'ai';
import { getModel, getMaxTokens } from "../config"; import {getModel, getMaxTokens} from "../config";
import { TokenTracker } from "../utils/token-tracker"; import {TokenTracker} from "../utils/token-tracker";
import { handleGenerateObjectError } from '../utils/error-handling'; import {handleGenerateObjectError} from '../utils/error-handling';
import type { DedupResponse } from '../types'; import type {DedupResponse} from '../types';
const model = getModel('dedup'); const model = getModel('dedup');
@@ -65,11 +65,11 @@ SetA: ${JSON.stringify(newQueries)}
SetB: ${JSON.stringify(existingQueries)}`; SetB: ${JSON.stringify(existingQueries)}`;
} }
export async function dedupQueries(newQueries: string[], existingQueries: string[], tracker?: TokenTracker): Promise<{ unique_queries: string[], tokens: number }> { export async function dedupQueries(newQueries: string[], existingQueries: string[], tracker?: TokenTracker): Promise<{ unique_queries: string[] }> {
try { try {
const prompt = getPrompt(newQueries, existingQueries); const prompt = getPrompt(newQueries, existingQueries);
let object; let object;
let tokens = 0; let usage;
try { try {
const result = await generateObject({ const result = await generateObject({
model, model,
@@ -78,32 +78,18 @@ export async function dedupQueries(newQueries: string[], existingQueries: string
maxTokens: getMaxTokens('dedup') maxTokens: getMaxTokens('dedup')
}); });
object = result.object; object = result.object;
tokens = result.usage?.totalTokens || 0; usage = result.usage
} catch (error) { } catch (error) {
const result = await handleGenerateObjectError<DedupResponse>(error); const result = await handleGenerateObjectError<DedupResponse>(error);
object = result.object; object = result.object;
tokens = result.totalTokens; usage = result.usage
} }
console.log('Dedup:', object.unique_queries); console.log('Dedup:', object.unique_queries);
(tracker || new TokenTracker()).trackUsage('dedup', tokens); (tracker || new TokenTracker()).trackUsage('dedup', usage);
return { unique_queries: object.unique_queries, tokens };
return {unique_queries: object.unique_queries};
} catch (error) { } catch (error) {
console.error('Error in deduplication analysis:', error); console.error('Error in deduplication analysis:', error);
throw error; throw error;
} }
} }
export async function main() {
const newQueries = process.argv[2] ? JSON.parse(process.argv[2]) : [];
const existingQueries = process.argv[3] ? JSON.parse(process.argv[3]) : [];
try {
await dedupQueries(newQueries, existingQueries);
} catch (error) {
console.error('Failed to deduplicate queries:', error);
}
}
if (require.main === module) {
main().catch(console.error);
}

View File

@@ -27,7 +27,7 @@ Must include the date and time of the latest answer.`,
?.map(support => support.segment.text) ?.map(support => support.segment.text)
.join(' ') || ''; .join(' ') || '';
(tracker || new TokenTracker()).trackUsage('grounding', usage.totalTokens); (tracker || new TokenTracker()).trackUsage('grounding', usage);
console.log('Grounding:', {text, groundedText}); console.log('Grounding:', {text, groundedText});
return text + '|' + groundedText; return text + '|' + groundedText;

View File

@@ -161,6 +161,8 @@ export interface ChatCompletionRequest {
content: string; content: string;
}>; }>;
stream?: boolean; stream?: boolean;
reasoning_effort?: 'low' | 'medium' | 'high' | null;
max_completion_tokens?: number | null;
} }
export interface ChatCompletionResponse { export interface ChatCompletionResponse {