fix: token tracking

This commit is contained in:
Han Xiao
2025-02-12 11:55:52 +08:00
parent 49217e5ecd
commit 57fb65d0c6
4 changed files with 65 additions and 62 deletions

View File

@@ -91,6 +91,28 @@ interface StreamingState {
processingQueue: boolean;
}
function getTokenBudgetAndMaxAttempts(
reasoningEffort: 'low' | 'medium' | 'high' | null = 'medium',
maxCompletionTokens: number | null = null
): { tokenBudget: number, maxBadAttempts: number } {
if (maxCompletionTokens !== null) {
return {
tokenBudget: maxCompletionTokens,
maxBadAttempts: 3 // Default to medium setting for max attempts
};
}
switch (reasoningEffort) {
case 'low':
return {tokenBudget: 500000, maxBadAttempts: 2};
case 'high':
return {tokenBudget: 2000000, maxBadAttempts: 4};
case 'medium':
default:
return {tokenBudget: 1000000, maxBadAttempts: 3};
}
}
async function completeCurrentStreaming(
streamingState: StreamingState,
@@ -190,7 +212,7 @@ async function processQueue(streamingState: StreamingState, res: Response, reque
system_fingerprint: 'fp_' + requestId,
choices: [{
index: 0,
delta: { content: word },
delta: {content: word},
logprobs: null,
finish_reason: null
}]
@@ -210,7 +232,7 @@ async function processQueue(streamingState: StreamingState, res: Response, reque
system_fingerprint: 'fp_' + requestId,
choices: [{
index: 0,
delta: { content: '\n' },
delta: {content: '\n'},
logprobs: null,
finish_reason: null
}]
@@ -264,6 +286,11 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
return res.status(400).json({error: 'Last message must be from user'});
}
const {tokenBudget, maxBadAttempts} = getTokenBudgetAndMaxAttempts(
body.reasoning_effort,
body.max_completion_tokens
);
const requestId = Date.now().toString();
const created = Math.floor(Date.now() / 1000);
const context: TrackerContext = {
@@ -273,13 +300,13 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
// Add this inside the chat completions endpoint, before setting up the action listener
const streamingState: StreamingState = {
currentlyStreaming: false,
currentGenerator: null,
remainingContent: '',
isEmitting: false,
queue: [],
processingQueue: false
};
currentlyStreaming: false,
currentGenerator: null,
remainingContent: '',
isEmitting: false,
queue: [],
processingQueue: false
};
if (body.stream) {
res.setHeader('Content-Type', 'text/event-stream');
@@ -305,19 +332,19 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
// Set up progress listener with cleanup
const actionListener = async (action: any) => {
if (action.thisStep.think) {
// Create a promise that resolves when this content is done streaming
await new Promise<void>(resolve => {
streamingState.queue.push({
content: action.thisStep.think,
resolve
});
if (action.thisStep.think) {
// Create a promise that resolves when this content is done streaming
await new Promise<void>(resolve => {
streamingState.queue.push({
content: action.thisStep.think,
resolve
});
// Start processing queue if not already processing
processQueue(streamingState, res, requestId, created, body.model);
});
}
};
// Start processing queue if not already processing
processQueue(streamingState, res, requestId, created, body.model);
});
}
};
context.actionTracker.on('action', actionListener);
// Make sure to update the cleanup code
@@ -330,18 +357,7 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
}
try {
let result;
try {
({result} = await getResponse(lastMessage.content, undefined, undefined, context));
} catch (error: any) {
// If deduplication fails, retry without it
if (error?.response?.status === 402) {
// If deduplication fails, retry with maxBadAttempt=3 to skip dedup
({result} = await getResponse(lastMessage.content, undefined, 3, context));
} else {
throw error;
}
}
const {result} = await getResponse(lastMessage.content, tokenBudget, maxBadAttempts, context)
const usage = context.tokenTracker.getTotalUsageSnakeCase();
if (body.stream) {
@@ -490,5 +506,4 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
}) as RequestHandler);
export default app;

View File

@@ -1,9 +1,9 @@
import { z } from 'zod';
import { generateObject } from 'ai';
import { getModel, getMaxTokens } from "../config";
import { TokenTracker } from "../utils/token-tracker";
import { handleGenerateObjectError } from '../utils/error-handling';
import type { DedupResponse } from '../types';
import {z} from 'zod';
import {generateObject} from 'ai';
import {getModel, getMaxTokens} from "../config";
import {TokenTracker} from "../utils/token-tracker";
import {handleGenerateObjectError} from '../utils/error-handling';
import type {DedupResponse} from '../types';
const model = getModel('dedup');
@@ -65,11 +65,11 @@ SetA: ${JSON.stringify(newQueries)}
SetB: ${JSON.stringify(existingQueries)}`;
}
export async function dedupQueries(newQueries: string[], existingQueries: string[], tracker?: TokenTracker): Promise<{ unique_queries: string[], tokens: number }> {
export async function dedupQueries(newQueries: string[], existingQueries: string[], tracker?: TokenTracker): Promise<{ unique_queries: string[] }> {
try {
const prompt = getPrompt(newQueries, existingQueries);
let object;
let tokens = 0;
let usage;
try {
const result = await generateObject({
model,
@@ -78,32 +78,18 @@ export async function dedupQueries(newQueries: string[], existingQueries: string
maxTokens: getMaxTokens('dedup')
});
object = result.object;
tokens = result.usage?.totalTokens || 0;
usage = result.usage
} catch (error) {
const result = await handleGenerateObjectError<DedupResponse>(error);
object = result.object;
tokens = result.totalTokens;
usage = result.usage
}
console.log('Dedup:', object.unique_queries);
(tracker || new TokenTracker()).trackUsage('dedup', tokens);
return { unique_queries: object.unique_queries, tokens };
(tracker || new TokenTracker()).trackUsage('dedup', usage);
return {unique_queries: object.unique_queries};
} catch (error) {
console.error('Error in deduplication analysis:', error);
throw error;
}
}
export async function main() {
const newQueries = process.argv[2] ? JSON.parse(process.argv[2]) : [];
const existingQueries = process.argv[3] ? JSON.parse(process.argv[3]) : [];
try {
await dedupQueries(newQueries, existingQueries);
} catch (error) {
console.error('Failed to deduplicate queries:', error);
}
}
if (require.main === module) {
main().catch(console.error);
}

View File

@@ -27,7 +27,7 @@ Must include the date and time of the latest answer.`,
?.map(support => support.segment.text)
.join(' ') || '';
(tracker || new TokenTracker()).trackUsage('grounding', usage.totalTokens);
(tracker || new TokenTracker()).trackUsage('grounding', usage);
console.log('Grounding:', {text, groundedText});
return text + '|' + groundedText;

View File

@@ -161,6 +161,8 @@ export interface ChatCompletionRequest {
content: string;
}>;
stream?: boolean;
reasoning_effort?: 'low' | 'medium' | 'high' | null;
max_completion_tokens?: number | null;
}
export interface ChatCompletionResponse {