From 5be008e8b94444950730902017613834736e5f27 Mon Sep 17 00:00:00 2001 From: "devin-ai-integration[bot]" <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sun, 2 Feb 2025 23:25:54 +0800 Subject: [PATCH] feat: add action tracker and reset token tracker (#8) * feat: add action tracker and reset token tracker Co-Authored-By: Han Xiao * refactor: make trackers request-scoped Co-Authored-By: Han Xiao --------- Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: Han Xiao --- src/agent.ts | 33 +++++++++++++------- src/server.ts | 62 +++++++++++-------------------------- src/tools/evaluator.ts | 6 ++-- src/tools/query-rewriter.ts | 6 ++-- src/tools/read.ts | 6 ++-- src/tools/search.ts | 6 ++-- src/types/tracker.ts | 7 +++++ src/utils/action-tracker.ts | 36 +++++++++++++++++++++ src/utils/token-tracker.ts | 4 +-- 9 files changed, 95 insertions(+), 71 deletions(-) create mode 100644 src/types/tracker.ts create mode 100644 src/utils/action-tracker.ts diff --git a/src/agent.ts b/src/agent.ts index d36f3b1..376b5e3 100644 --- a/src/agent.ts +++ b/src/agent.ts @@ -8,8 +8,10 @@ import {dedupQueries} from "./tools/dedup"; import {evaluateAnswer} from "./tools/evaluator"; import {analyzeSteps} from "./tools/error-analyzer"; import {GEMINI_API_KEY, JINA_API_KEY, SEARCH_PROVIDER, STEP_SLEEP, modelConfigs} from "./config"; -import {tokenTracker} from "./utils/token-tracker"; +import {TokenTracker} from "./utils/token-tracker"; +import {ActionTracker} from "./utils/action-tracker"; import {StepAction, SchemaProperty, ResponseSchema, AnswerAction} from "./types"; +import {TrackerContext} from "./types/tracker"; async function sleep(ms: number) { const seconds = Math.ceil(ms / 1000); @@ -241,7 +243,12 @@ function removeAllLineBreaks(text: string) { return text.replace(/(\r\n|\n|\r)/gm, " "); } -export async function getResponse(question: string, tokenBudget: number = 1_000_000, maxBadAttempts: number = 3): Promise { +export async function getResponse(question: string, tokenBudget: number = 1_000_000, maxBadAttempts: number = 3): Promise<{ result: StepAction; context: TrackerContext }> { + const context: TrackerContext = { + tokenTracker: new TokenTracker(), + actionTracker: new ActionTracker() + }; + context.actionTracker.trackAction({ gaps: [question], totalStep: 0, badAttempts: 0 }); let step = 0; let totalStep = 0; let badAttempts = 0; @@ -261,12 +268,13 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_ const allURLs: Record = {}; const visitedURLs: string[] = []; - while (tokenTracker.getTotalUsage() < tokenBudget && badAttempts <= maxBadAttempts) { + while (context.tokenTracker.getTotalUsage() < tokenBudget && badAttempts <= maxBadAttempts) { // add 1s delay to avoid rate limiting await sleep(STEP_SLEEP); step++; totalStep++; - const budgetPercentage = (tokenTracker.getTotalUsage() / tokenBudget * 100).toFixed(2); + context.actionTracker.trackAction({ totalStep, thisStep, gaps, badAttempts }); + const budgetPercentage = (context.tokenTracker.getTotalUsage() / tokenBudget * 100).toFixed(2); console.log(`Step ${totalStep} / Budget used ${budgetPercentage}%`); console.log('Gaps:', gaps); allowReflect = allowReflect && (gaps.length <= 1); @@ -302,7 +310,7 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_ const result = await model.generateContent(prompt); const response = await result.response; const usage = response.usageMetadata; - tokenTracker.trackUsage('agent', usage?.totalTokenCount || 0); + context.tokenTracker.trackUsage('agent', usage?.totalTokenCount || 0); thisStep = JSON.parse(response.text()); @@ -325,7 +333,7 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_ ...thisStep, }); - const {response: evaluation} = await evaluateAnswer(currentQuestion, thisStep.answer); + const {response: evaluation} = await evaluateAnswer(currentQuestion, thisStep.answer, context.tokenTracker); if (currentQuestion === question) { @@ -545,7 +553,7 @@ You decided to think out of the box or cut from a completely different angle. const urlResults = await Promise.all( uniqueURLs.map(async (url: string) => { - const {response, tokens} = await readUrl(url, JINA_API_KEY); + const {response, tokens} = await readUrl(url, JINA_API_KEY, context.tokenTracker); allKnowledge.push({ question: `What is in ${response.data.url}?`, answer: removeAllLineBreaks(response.data.content), @@ -592,7 +600,7 @@ You decided to think out of the box or cut from a completely different angle.`); totalStep++; await storeContext(prompt, [allContext, allKeywords, allQuestions, allKnowledge], totalStep); if (isAnswered) { - return thisStep; + return { result: thisStep, context }; } else { console.log('Enter Beast mode!!!') const prompt = getPrompt( @@ -621,12 +629,12 @@ You decided to think out of the box or cut from a completely different angle.`); const result = await model.generateContent(prompt); const response = await result.response; const usage = response.usageMetadata; - tokenTracker.trackUsage('agent', usage?.totalTokenCount || 0); + context.tokenTracker.trackUsage('agent', usage?.totalTokenCount || 0); await storeContext(prompt, [allContext, allKeywords, allQuestions, allKnowledge], totalStep); thisStep = JSON.parse(response.text()); console.log(thisStep) - return thisStep; + return { result: thisStep, context }; } } @@ -648,9 +656,10 @@ const genAI = new GoogleGenerativeAI(GEMINI_API_KEY); export async function main() { const question = process.argv[2] || ""; - const finalStep = await getResponse(question) as AnswerAction; + const { result: finalStep } = await getResponse(question) as { result: AnswerAction; context: TrackerContext }; console.log('Final Answer:', finalStep.answer); - tokenTracker.printSummary(); + const tracker = new TokenTracker(); + tracker.printSummary(); } if (require.main === module) { diff --git a/src/server.ts b/src/server.ts index 20db2a5..08aa9be 100644 --- a/src/server.ts +++ b/src/server.ts @@ -2,8 +2,8 @@ import express, { Request, Response, RequestHandler } from 'express'; import cors from 'cors'; import { EventEmitter } from 'events'; import { getResponse } from './agent'; -import { tokenTracker } from './utils/token-tracker'; import { StepAction } from './types'; +import { TrackerContext } from './types/tracker'; import fs from 'fs/promises'; import path from 'path'; @@ -51,29 +51,21 @@ app.get('/api/v1/stream/:requestId', ((req: Request, res: StreamResponse) => { }); }) as RequestHandler); -function createProgressEmitter(requestId: string, budget: number | undefined, thisStep: StepAction | undefined) { - return (message: string, step: number, budgetPercentage?: string) => { - const budgetInfo = budgetPercentage ? { - used: tokenTracker.getTotalUsage(), +function createProgressEmitter(requestId: string, budget: number | undefined, context: TrackerContext) { + return () => { + const state = context.actionTracker.getState(); + const budgetInfo = { + used: context.tokenTracker.getTotalUsage(), total: budget || 1_000_000, - percentage: budgetPercentage - } : undefined; + percentage: ((context.tokenTracker.getTotalUsage() / (budget || 1_000_000)) * 100).toFixed(2) + }; - if (thisStep?.action && thisStep?.thoughts) { - eventEmitter.emit(`progress-${requestId}`, { - type: 'progress', - data: { ...thisStep, totalStep: step }, - step, - budget: budgetInfo - }); - } else { - eventEmitter.emit(`progress-${requestId}`, { - type: 'progress', - data: message, - step, - budget: budgetInfo - }); - } + eventEmitter.emit(`progress-${requestId}`, { + type: 'progress', + data: { ...state.thisStep, totalStep: state.totalStep }, + step: state.totalStep, + budget: budgetInfo + }); }; } @@ -87,32 +79,14 @@ app.post('/api/v1/query', (async (req: QueryRequest, res: Response) => { const requestId = Date.now().toString(); res.json({ requestId }); - // Store original console.log - const originalConsoleLog = console.log; - let thisStep: StepAction | undefined; - try { - const emitProgress = createProgressEmitter(requestId, budget, thisStep); - - // Override console.log to track progress - console.log = (...args: any[]) => { - originalConsoleLog(...args); - const message = args.join(' '); - if (message.includes('Step') || message.includes('Budget used')) { - const step = parseInt(message.match(/Step (\d+)/)?.[1] || '0'); - const budgetPercentage = message.match(/Budget used ([\d.]+)%/)?.[1]; - emitProgress(message, step, budgetPercentage); - } - }; - - const result = await getResponse(q, budget, maxBadAttempt); - thisStep = result; + const { result, context } = await getResponse(q, budget, maxBadAttempt); + const emitProgress = createProgressEmitter(requestId, budget, context); + context.actionTracker.on('action', emitProgress); await storeTaskResult(requestId, result); eventEmitter.emit(`progress-${requestId}`, { type: 'answer', data: result }); } catch (error: any) { eventEmitter.emit(`progress-${requestId}`, { type: 'error', data: error?.message || 'Unknown error' }); - } finally { - console.log = originalConsoleLog; } }) as RequestHandler); @@ -145,4 +119,4 @@ app.listen(port, () => { console.log(`Server running at http://localhost:${port}`); }); -export default app; \ No newline at end of file +export default app; diff --git a/src/tools/evaluator.ts b/src/tools/evaluator.ts index 1d2ffb7..222e0d2 100644 --- a/src/tools/evaluator.ts +++ b/src/tools/evaluator.ts @@ -1,6 +1,6 @@ import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai"; import { GEMINI_API_KEY, modelConfigs } from "../config"; -import { tokenTracker } from "../utils/token-tracker"; +import { TokenTracker } from "../utils/token-tracker"; import { EvaluationResponse } from '../types'; @@ -63,7 +63,7 @@ Question: ${JSON.stringify(question)} Answer: ${JSON.stringify(answer)}`; } -export async function evaluateAnswer(question: string, answer: string): Promise<{ response: EvaluationResponse, tokens: number }> { +export async function evaluateAnswer(question: string, answer: string, tracker?: TokenTracker): Promise<{ response: EvaluationResponse, tokens: number }> { try { const prompt = getPrompt(question, answer); const result = await model.generateContent(prompt); @@ -75,7 +75,7 @@ export async function evaluateAnswer(question: string, answer: string): Promise< reason: json.reasoning }); const tokens = usage?.totalTokenCount || 0; - tokenTracker.trackUsage('evaluator', tokens); + (tracker || new TokenTracker()).trackUsage('evaluator', tokens); return { response: json, tokens }; } catch (error) { console.error('Error in answer evaluation:', error); diff --git a/src/tools/query-rewriter.ts b/src/tools/query-rewriter.ts index 60e14e6..ad991a9 100644 --- a/src/tools/query-rewriter.ts +++ b/src/tools/query-rewriter.ts @@ -1,6 +1,6 @@ import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai"; import { GEMINI_API_KEY, modelConfigs } from "../config"; -import { tokenTracker } from "../utils/token-tracker"; +import { TokenTracker } from "../utils/token-tracker"; import { SearchAction } from "../types"; import { KeywordsResponse } from '../types'; @@ -105,7 +105,7 @@ Intention: ${action.thoughts} `; } -export async function rewriteQuery(action: SearchAction): Promise<{ queries: string[], tokens: number }> { +export async function rewriteQuery(action: SearchAction, tracker?: TokenTracker): Promise<{ queries: string[], tokens: number }> { try { const prompt = getPrompt(action); const result = await model.generateContent(prompt); @@ -115,7 +115,7 @@ export async function rewriteQuery(action: SearchAction): Promise<{ queries: str console.log('Query rewriter:', json.queries); const tokens = usage?.totalTokenCount || 0; - tokenTracker.trackUsage('query-rewriter', tokens); + (tracker || new TokenTracker()).trackUsage('query-rewriter', tokens); return { queries: json.queries, tokens }; } catch (error) { diff --git a/src/tools/read.ts b/src/tools/read.ts index 28b4dbf..26ae62f 100644 --- a/src/tools/read.ts +++ b/src/tools/read.ts @@ -1,9 +1,9 @@ import https from 'https'; -import { tokenTracker } from "../utils/token-tracker"; +import { TokenTracker } from "../utils/token-tracker"; import { ReadResponse } from '../types'; -export function readUrl(url: string, token: string): Promise<{ response: ReadResponse, tokens: number }> { +export function readUrl(url: string, token: string, tracker?: TokenTracker): Promise<{ response: ReadResponse, tokens: number }> { return new Promise((resolve, reject) => { const data = JSON.stringify({url}); @@ -33,7 +33,7 @@ export function readUrl(url: string, token: string): Promise<{ response: ReadRes tokens: response.data.usage.tokens }); const tokens = response.data?.usage?.tokens || 0; - tokenTracker.trackUsage('read', tokens); + (tracker || new TokenTracker()).trackUsage('read', tokens); resolve({ response, tokens }); }); }); diff --git a/src/tools/search.ts b/src/tools/search.ts index ab3b486..0e6c125 100644 --- a/src/tools/search.ts +++ b/src/tools/search.ts @@ -1,9 +1,9 @@ import https from 'https'; -import { tokenTracker } from "../utils/token-tracker"; +import { TokenTracker } from "../utils/token-tracker"; import { SearchResponse } from '../types'; -export function search(query: string, token: string): Promise<{ response: SearchResponse, tokens: number }> { +export function search(query: string, token: string, tracker?: TokenTracker): Promise<{ response: SearchResponse, tokens: number }> { return new Promise((resolve, reject) => { const options = { hostname: 's.jina.ai', @@ -28,7 +28,7 @@ export function search(query: string, token: string): Promise<{ response: Search url: item.url, tokens: item.usage.tokens }))); - tokenTracker.trackUsage('search', totalTokens); + (tracker || new TokenTracker()).trackUsage('search', totalTokens); resolve({ response, tokens: totalTokens }); }); }); diff --git a/src/types/tracker.ts b/src/types/tracker.ts new file mode 100644 index 0000000..f21e068 --- /dev/null +++ b/src/types/tracker.ts @@ -0,0 +1,7 @@ +import { TokenTracker } from '../utils/token-tracker'; +import { ActionTracker } from '../utils/action-tracker'; + +export interface TrackerContext { + tokenTracker: TokenTracker; + actionTracker: ActionTracker; +} diff --git a/src/utils/action-tracker.ts b/src/utils/action-tracker.ts new file mode 100644 index 0000000..f98cce6 --- /dev/null +++ b/src/utils/action-tracker.ts @@ -0,0 +1,36 @@ +import { EventEmitter } from 'events'; +import { StepAction } from '../types'; + +interface ActionState { + thisStep: StepAction; + gaps: string[]; + badAttempts: number; + totalStep: number; +} + +export class ActionTracker extends EventEmitter { + private state: ActionState = { + thisStep: {action: 'answer', answer: '', references: [], thoughts: ''}, + gaps: [], + badAttempts: 0, + totalStep: 0 + }; + + trackAction(newState: Partial) { + this.state = { ...this.state, ...newState }; + this.emit('action', this.state); + } + + getState(): ActionState { + return { ...this.state }; + } + + reset() { + this.state = { + thisStep: {action: 'answer', answer: '', references: [], thoughts: ''}, + gaps: [], + badAttempts: 0, + totalStep: 0 + }; + } +} diff --git a/src/utils/token-tracker.ts b/src/utils/token-tracker.ts index 944ba09..2df9345 100644 --- a/src/utils/token-tracker.ts +++ b/src/utils/token-tracker.ts @@ -2,7 +2,7 @@ import { EventEmitter } from 'events'; import { TokenUsage } from '../types'; -class TokenTracker extends EventEmitter { +export class TokenTracker extends EventEmitter { private usages: TokenUsage[] = []; trackUsage(tool: string, tokens: number) { @@ -33,5 +33,3 @@ class TokenTracker extends EventEmitter { this.usages = []; } } - -export const tokenTracker = new TokenTracker();