diff --git a/src/app.ts b/src/app.ts
index 2519ea8..fce4beb 100644
--- a/src/app.ts
+++ b/src/app.ts
@@ -91,6 +91,28 @@ interface StreamingState {
   processingQueue: boolean;
 }
 
+function getTokenBudgetAndMaxAttempts(
+  reasoningEffort: 'low' | 'medium' | 'high' | null = 'medium',
+  maxCompletionTokens: number | null = null
+): { tokenBudget: number, maxBadAttempts: number } {
+  if (maxCompletionTokens !== null) {
+    return {
+      tokenBudget: maxCompletionTokens,
+      maxBadAttempts: 3 // Default to medium setting for max attempts
+    };
+  }
+
+  switch (reasoningEffort) {
+    case 'low':
+      return {tokenBudget: 500000, maxBadAttempts: 2};
+    case 'high':
+      return {tokenBudget: 2000000, maxBadAttempts: 4};
+    case 'medium':
+    default:
+      return {tokenBudget: 1000000, maxBadAttempts: 3};
+  }
+}
+
 
 async function completeCurrentStreaming(
   streamingState: StreamingState,
@@ -190,7 +212,7 @@ async function processQueue(streamingState: StreamingState, res: Response, reque
         system_fingerprint: 'fp_' + requestId,
         choices: [{
           index: 0,
-          delta: { content: word },
+          delta: {content: word},
           logprobs: null,
           finish_reason: null
         }]
@@ -210,7 +232,7 @@ async function processQueue(streamingState: StreamingState, res: Response, reque
         system_fingerprint: 'fp_' + requestId,
         choices: [{
           index: 0,
-          delta: { content: '\n' },
+          delta: {content: '\n'},
           logprobs: null,
           finish_reason: null
         }]
@@ -264,6 +286,11 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
     return res.status(400).json({error: 'Last message must be from user'});
   }
 
+  const {tokenBudget, maxBadAttempts} = getTokenBudgetAndMaxAttempts(
+    body.reasoning_effort,
+    body.max_completion_tokens
+  );
+
   const requestId = Date.now().toString();
   const created = Math.floor(Date.now() / 1000);
   const context: TrackerContext = {
@@ -273,13 +300,13 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
 
   // Add this inside the chat completions endpoint, before setting up the action listener
   const streamingState: StreamingState = {
-  currentlyStreaming: false,
-  currentGenerator: null,
-  remainingContent: '',
-  isEmitting: false,
-  queue: [],
-  processingQueue: false
-};
+    currentlyStreaming: false,
+    currentGenerator: null,
+    remainingContent: '',
+    isEmitting: false,
+    queue: [],
+    processingQueue: false
+  };
 
   if (body.stream) {
     res.setHeader('Content-Type', 'text/event-stream');
@@ -305,19 +332,19 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
 
     // Set up progress listener with cleanup
     const actionListener = async (action: any) => {
-  if (action.thisStep.think) {
-    // Create a promise that resolves when this content is done streaming
-    await new Promise(resolve => {
-      streamingState.queue.push({
-        content: action.thisStep.think,
-        resolve
-      });
-
-      // Start processing queue if not already processing
-      processQueue(streamingState, res, requestId, created, body.model);
-    });
-  }
-};
+      if (action.thisStep.think) {
+        // Create a promise that resolves when this content is done streaming
+        await new Promise(resolve => {
+          streamingState.queue.push({
+            content: action.thisStep.think,
+            resolve
+          });
+
+          // Start processing queue if not already processing
+          processQueue(streamingState, res, requestId, created, body.model);
+        });
+      }
+    };
     context.actionTracker.on('action', actionListener);
 
     // Make sure to update the cleanup code
@@ -330,18 +357,7 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
   }
 
   try {
-    let result;
-    try {
-      ({result} = await getResponse(lastMessage.content, undefined, undefined, context));
-    } catch (error: any) {
-      // If deduplication fails, retry without it
-      if (error?.response?.status === 402) {
-        // If deduplication fails, retry with maxBadAttempt=3 to skip dedup
-        ({result} = await getResponse(lastMessage.content, undefined, 3, context));
-      } else {
-        throw error;
-      }
-    }
+    const {result} = await getResponse(lastMessage.content, tokenBudget, maxBadAttempts, context)
 
     const usage = context.tokenTracker.getTotalUsageSnakeCase();
     if (body.stream) {
@@ -490,5 +506,4 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
 
 }) as RequestHandler);
 
-
 export default app;
diff --git a/src/tools/dedup.ts b/src/tools/dedup.ts
index e3a7580..f006d42 100644
--- a/src/tools/dedup.ts
+++ b/src/tools/dedup.ts
@@ -1,9 +1,9 @@
-import { z } from 'zod';
-import { generateObject } from 'ai';
-import { getModel, getMaxTokens } from "../config";
-import { TokenTracker } from "../utils/token-tracker";
-import { handleGenerateObjectError } from '../utils/error-handling';
-import type { DedupResponse } from '../types';
+import {z} from 'zod';
+import {generateObject} from 'ai';
+import {getModel, getMaxTokens} from "../config";
+import {TokenTracker} from "../utils/token-tracker";
+import {handleGenerateObjectError} from '../utils/error-handling';
+import type {DedupResponse} from '../types';
 
 const model = getModel('dedup');
 
@@ -65,11 +65,11 @@ SetA: ${JSON.stringify(newQueries)}
 SetB: ${JSON.stringify(existingQueries)}`;
 }
 
-export async function dedupQueries(newQueries: string[], existingQueries: string[], tracker?: TokenTracker): Promise<{ unique_queries: string[], tokens: number }> {
+export async function dedupQueries(newQueries: string[], existingQueries: string[], tracker?: TokenTracker): Promise<{ unique_queries: string[] }> {
   try {
     const prompt = getPrompt(newQueries, existingQueries);
     let object;
-    let tokens = 0;
+    let usage;
     try {
       const result = await generateObject({
         model,
@@ -78,32 +78,18 @@ export async function dedupQueries(newQueries: string[], existingQueries: string
         maxTokens: getMaxTokens('dedup')
       });
       object = result.object;
-      tokens = result.usage?.totalTokens || 0;
+      usage = result.usage
     } catch (error) {
       const result = await handleGenerateObjectError(error);
      object = result.object;
-      tokens = result.totalTokens;
+      usage = result.usage
     }
     console.log('Dedup:', object.unique_queries);
-    (tracker || new TokenTracker()).trackUsage('dedup', tokens);
-    return { unique_queries: object.unique_queries, tokens };
+    (tracker || new TokenTracker()).trackUsage('dedup', usage);
+
+    return {unique_queries: object.unique_queries};
   } catch (error) {
     console.error('Error in deduplication analysis:', error);
     throw error;
   }
 }
-
-export async function main() {
-  const newQueries = process.argv[2] ? JSON.parse(process.argv[2]) : [];
-  const existingQueries = process.argv[3] ? JSON.parse(process.argv[3]) : [];
-
-  try {
-    await dedupQueries(newQueries, existingQueries);
-  } catch (error) {
-    console.error('Failed to deduplicate queries:', error);
-  }
-}
-
-if (require.main === module) {
-  main().catch(console.error);
-}
diff --git a/src/tools/grounding.ts b/src/tools/grounding.ts
index 52161ff..f3362d7 100644
--- a/src/tools/grounding.ts
+++ b/src/tools/grounding.ts
@@ -27,7 +27,7 @@ Must include the date and time of the latest answer.`,
       ?.map(support => support.segment.text)
       .join(' ') || '';
 
-    (tracker || new TokenTracker()).trackUsage('grounding', usage.totalTokens);
+    (tracker || new TokenTracker()).trackUsage('grounding', usage);
     console.log('Grounding:', {text, groundedText});
 
     return text + '|' + groundedText;
diff --git a/src/types.ts b/src/types.ts
index f1c8d1d..4f98cfb 100644
--- a/src/types.ts
+++ b/src/types.ts
@@ -161,6 +161,8 @@ export interface ChatCompletionRequest {
     content: string;
   }>;
   stream?: boolean;
+  reasoning_effort?: 'low' | 'medium' | 'high' | null;
+  max_completion_tokens?: number | null;
 }
 
 export interface ChatCompletionResponse {
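
Example usage (not part of the patch): a minimal sketch of a request that exercises the new ChatCompletionRequest fields. The base URL and model name are placeholders; only the route, the field names, and the effort-to-budget mapping (low -> 500000 tokens / 2 bad attempts, medium -> 1000000 / 3, high -> 2000000 / 4, with max_completion_tokens taking precedence as the budget while keeping maxBadAttempts at 3) come from the diff above.

// Sketch only: host, port and model name are assumptions, not taken from the diff.
const response = await fetch('http://localhost:3000/v1/chat/completions', {
  method: 'POST',
  headers: {'Content-Type': 'application/json'},
  body: JSON.stringify({
    model: 'deep-research',           // placeholder model name
    stream: true,
    reasoning_effort: 'high',         // maps to tokenBudget: 2000000, maxBadAttempts: 4
    // max_completion_tokens: 150000, // if set, this value becomes tokenBudget; maxBadAttempts stays 3
    messages: [{role: 'user', content: 'What is the capital of France?'}]
  })
});

Requests that omit both fields fall back to the medium defaults (1000000 tokens, 3 bad attempts), since getTokenBudgetAndMaxAttempts defaults reasoningEffort to 'medium' and treats null the same way.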