feat: add action tracker and reset token tracker (#8)

* feat: add action tracker and reset token tracker

Co-Authored-By: Han Xiao <han.xiao@jina.ai>

* refactor: make trackers request-scoped

Co-Authored-By: Han Xiao <han.xiao@jina.ai>

---------

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: Han Xiao <han.xiao@jina.ai>
This commit is contained in:
devin-ai-integration[bot]
2025-02-02 23:25:54 +08:00
committed by GitHub
parent 3f032bbdcc
commit 5be008e8b9
9 changed files with 95 additions and 71 deletions

View File

@@ -8,8 +8,10 @@ import {dedupQueries} from "./tools/dedup";
import {evaluateAnswer} from "./tools/evaluator";
import {analyzeSteps} from "./tools/error-analyzer";
import {GEMINI_API_KEY, JINA_API_KEY, SEARCH_PROVIDER, STEP_SLEEP, modelConfigs} from "./config";
import {tokenTracker} from "./utils/token-tracker";
import {TokenTracker} from "./utils/token-tracker";
import {ActionTracker} from "./utils/action-tracker";
import {StepAction, SchemaProperty, ResponseSchema, AnswerAction} from "./types";
import {TrackerContext} from "./types/tracker";
async function sleep(ms: number) {
const seconds = Math.ceil(ms / 1000);
@@ -241,7 +243,12 @@ function removeAllLineBreaks(text: string) {
return text.replace(/(\r\n|\n|\r)/gm, " ");
}
export async function getResponse(question: string, tokenBudget: number = 1_000_000, maxBadAttempts: number = 3): Promise<StepAction> {
export async function getResponse(question: string, tokenBudget: number = 1_000_000, maxBadAttempts: number = 3): Promise<{ result: StepAction; context: TrackerContext }> {
const context: TrackerContext = {
tokenTracker: new TokenTracker(),
actionTracker: new ActionTracker()
};
context.actionTracker.trackAction({ gaps: [question], totalStep: 0, badAttempts: 0 });
let step = 0;
let totalStep = 0;
let badAttempts = 0;
@@ -261,12 +268,13 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_
const allURLs: Record<string, string> = {};
const visitedURLs: string[] = [];
while (tokenTracker.getTotalUsage() < tokenBudget && badAttempts <= maxBadAttempts) {
while (context.tokenTracker.getTotalUsage() < tokenBudget && badAttempts <= maxBadAttempts) {
// add 1s delay to avoid rate limiting
await sleep(STEP_SLEEP);
step++;
totalStep++;
const budgetPercentage = (tokenTracker.getTotalUsage() / tokenBudget * 100).toFixed(2);
context.actionTracker.trackAction({ totalStep, thisStep, gaps, badAttempts });
const budgetPercentage = (context.tokenTracker.getTotalUsage() / tokenBudget * 100).toFixed(2);
console.log(`Step ${totalStep} / Budget used ${budgetPercentage}%`);
console.log('Gaps:', gaps);
allowReflect = allowReflect && (gaps.length <= 1);
@@ -302,7 +310,7 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_
const result = await model.generateContent(prompt);
const response = await result.response;
const usage = response.usageMetadata;
tokenTracker.trackUsage('agent', usage?.totalTokenCount || 0);
context.tokenTracker.trackUsage('agent', usage?.totalTokenCount || 0);
thisStep = JSON.parse(response.text());
@@ -325,7 +333,7 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_
...thisStep,
});
const {response: evaluation} = await evaluateAnswer(currentQuestion, thisStep.answer);
const {response: evaluation} = await evaluateAnswer(currentQuestion, thisStep.answer, context.tokenTracker);
if (currentQuestion === question) {
@@ -545,7 +553,7 @@ You decided to think out of the box or cut from a completely different angle.
const urlResults = await Promise.all(
uniqueURLs.map(async (url: string) => {
const {response, tokens} = await readUrl(url, JINA_API_KEY);
const {response, tokens} = await readUrl(url, JINA_API_KEY, context.tokenTracker);
allKnowledge.push({
question: `What is in ${response.data.url}?`,
answer: removeAllLineBreaks(response.data.content),
@@ -592,7 +600,7 @@ You decided to think out of the box or cut from a completely different angle.`);
totalStep++;
await storeContext(prompt, [allContext, allKeywords, allQuestions, allKnowledge], totalStep);
if (isAnswered) {
return thisStep;
return { result: thisStep, context };
} else {
console.log('Enter Beast mode!!!')
const prompt = getPrompt(
@@ -621,12 +629,12 @@ You decided to think out of the box or cut from a completely different angle.`);
const result = await model.generateContent(prompt);
const response = await result.response;
const usage = response.usageMetadata;
tokenTracker.trackUsage('agent', usage?.totalTokenCount || 0);
context.tokenTracker.trackUsage('agent', usage?.totalTokenCount || 0);
await storeContext(prompt, [allContext, allKeywords, allQuestions, allKnowledge], totalStep);
thisStep = JSON.parse(response.text());
console.log(thisStep)
return thisStep;
return { result: thisStep, context };
}
}
@@ -648,9 +656,10 @@ const genAI = new GoogleGenerativeAI(GEMINI_API_KEY);
export async function main() {
const question = process.argv[2] || "";
const finalStep = await getResponse(question) as AnswerAction;
const { result: finalStep } = await getResponse(question) as { result: AnswerAction; context: TrackerContext };
console.log('Final Answer:', finalStep.answer);
tokenTracker.printSummary();
const tracker = new TokenTracker();
tracker.printSummary();
}
if (require.main === module) {

View File

@@ -2,8 +2,8 @@ import express, { Request, Response, RequestHandler } from 'express';
import cors from 'cors';
import { EventEmitter } from 'events';
import { getResponse } from './agent';
import { tokenTracker } from './utils/token-tracker';
import { StepAction } from './types';
import { TrackerContext } from './types/tracker';
import fs from 'fs/promises';
import path from 'path';
@@ -51,29 +51,21 @@ app.get('/api/v1/stream/:requestId', ((req: Request, res: StreamResponse) => {
});
}) as RequestHandler);
function createProgressEmitter(requestId: string, budget: number | undefined, thisStep: StepAction | undefined) {
return (message: string, step: number, budgetPercentage?: string) => {
const budgetInfo = budgetPercentage ? {
used: tokenTracker.getTotalUsage(),
function createProgressEmitter(requestId: string, budget: number | undefined, context: TrackerContext) {
return () => {
const state = context.actionTracker.getState();
const budgetInfo = {
used: context.tokenTracker.getTotalUsage(),
total: budget || 1_000_000,
percentage: budgetPercentage
} : undefined;
percentage: ((context.tokenTracker.getTotalUsage() / (budget || 1_000_000)) * 100).toFixed(2)
};
if (thisStep?.action && thisStep?.thoughts) {
eventEmitter.emit(`progress-${requestId}`, {
type: 'progress',
data: { ...thisStep, totalStep: step },
step,
budget: budgetInfo
});
} else {
eventEmitter.emit(`progress-${requestId}`, {
type: 'progress',
data: message,
step,
budget: budgetInfo
});
}
eventEmitter.emit(`progress-${requestId}`, {
type: 'progress',
data: { ...state.thisStep, totalStep: state.totalStep },
step: state.totalStep,
budget: budgetInfo
});
};
}
@@ -87,32 +79,14 @@ app.post('/api/v1/query', (async (req: QueryRequest, res: Response) => {
const requestId = Date.now().toString();
res.json({ requestId });
// Store original console.log
const originalConsoleLog = console.log;
let thisStep: StepAction | undefined;
try {
const emitProgress = createProgressEmitter(requestId, budget, thisStep);
// Override console.log to track progress
console.log = (...args: any[]) => {
originalConsoleLog(...args);
const message = args.join(' ');
if (message.includes('Step') || message.includes('Budget used')) {
const step = parseInt(message.match(/Step (\d+)/)?.[1] || '0');
const budgetPercentage = message.match(/Budget used ([\d.]+)%/)?.[1];
emitProgress(message, step, budgetPercentage);
}
};
const result = await getResponse(q, budget, maxBadAttempt);
thisStep = result;
const { result, context } = await getResponse(q, budget, maxBadAttempt);
const emitProgress = createProgressEmitter(requestId, budget, context);
context.actionTracker.on('action', emitProgress);
await storeTaskResult(requestId, result);
eventEmitter.emit(`progress-${requestId}`, { type: 'answer', data: result });
} catch (error: any) {
eventEmitter.emit(`progress-${requestId}`, { type: 'error', data: error?.message || 'Unknown error' });
} finally {
console.log = originalConsoleLog;
}
}) as RequestHandler);
@@ -145,4 +119,4 @@ app.listen(port, () => {
console.log(`Server running at http://localhost:${port}`);
});
export default app;
export default app;

View File

@@ -1,6 +1,6 @@
import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai";
import { GEMINI_API_KEY, modelConfigs } from "../config";
import { tokenTracker } from "../utils/token-tracker";
import { TokenTracker } from "../utils/token-tracker";
import { EvaluationResponse } from '../types';
@@ -63,7 +63,7 @@ Question: ${JSON.stringify(question)}
Answer: ${JSON.stringify(answer)}`;
}
export async function evaluateAnswer(question: string, answer: string): Promise<{ response: EvaluationResponse, tokens: number }> {
export async function evaluateAnswer(question: string, answer: string, tracker?: TokenTracker): Promise<{ response: EvaluationResponse, tokens: number }> {
try {
const prompt = getPrompt(question, answer);
const result = await model.generateContent(prompt);
@@ -75,7 +75,7 @@ export async function evaluateAnswer(question: string, answer: string): Promise<
reason: json.reasoning
});
const tokens = usage?.totalTokenCount || 0;
tokenTracker.trackUsage('evaluator', tokens);
(tracker || new TokenTracker()).trackUsage('evaluator', tokens);
return { response: json, tokens };
} catch (error) {
console.error('Error in answer evaluation:', error);

View File

@@ -1,6 +1,6 @@
import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai";
import { GEMINI_API_KEY, modelConfigs } from "../config";
import { tokenTracker } from "../utils/token-tracker";
import { TokenTracker } from "../utils/token-tracker";
import { SearchAction } from "../types";
import { KeywordsResponse } from '../types';
@@ -105,7 +105,7 @@ Intention: ${action.thoughts}
`;
}
export async function rewriteQuery(action: SearchAction): Promise<{ queries: string[], tokens: number }> {
export async function rewriteQuery(action: SearchAction, tracker?: TokenTracker): Promise<{ queries: string[], tokens: number }> {
try {
const prompt = getPrompt(action);
const result = await model.generateContent(prompt);
@@ -115,7 +115,7 @@ export async function rewriteQuery(action: SearchAction): Promise<{ queries: str
console.log('Query rewriter:', json.queries);
const tokens = usage?.totalTokenCount || 0;
tokenTracker.trackUsage('query-rewriter', tokens);
(tracker || new TokenTracker()).trackUsage('query-rewriter', tokens);
return { queries: json.queries, tokens };
} catch (error) {

View File

@@ -1,9 +1,9 @@
import https from 'https';
import { tokenTracker } from "../utils/token-tracker";
import { TokenTracker } from "../utils/token-tracker";
import { ReadResponse } from '../types';
export function readUrl(url: string, token: string): Promise<{ response: ReadResponse, tokens: number }> {
export function readUrl(url: string, token: string, tracker?: TokenTracker): Promise<{ response: ReadResponse, tokens: number }> {
return new Promise((resolve, reject) => {
const data = JSON.stringify({url});
@@ -33,7 +33,7 @@ export function readUrl(url: string, token: string): Promise<{ response: ReadRes
tokens: response.data.usage.tokens
});
const tokens = response.data?.usage?.tokens || 0;
tokenTracker.trackUsage('read', tokens);
(tracker || new TokenTracker()).trackUsage('read', tokens);
resolve({ response, tokens });
});
});

View File

@@ -1,9 +1,9 @@
import https from 'https';
import { tokenTracker } from "../utils/token-tracker";
import { TokenTracker } from "../utils/token-tracker";
import { SearchResponse } from '../types';
export function search(query: string, token: string): Promise<{ response: SearchResponse, tokens: number }> {
export function search(query: string, token: string, tracker?: TokenTracker): Promise<{ response: SearchResponse, tokens: number }> {
return new Promise((resolve, reject) => {
const options = {
hostname: 's.jina.ai',
@@ -28,7 +28,7 @@ export function search(query: string, token: string): Promise<{ response: Search
url: item.url,
tokens: item.usage.tokens
})));
tokenTracker.trackUsage('search', totalTokens);
(tracker || new TokenTracker()).trackUsage('search', totalTokens);
resolve({ response, tokens: totalTokens });
});
});

7
src/types/tracker.ts Normal file
View File

@@ -0,0 +1,7 @@
import { TokenTracker } from '../utils/token-tracker';
import { ActionTracker } from '../utils/action-tracker';
export interface TrackerContext {
tokenTracker: TokenTracker;
actionTracker: ActionTracker;
}

View File

@@ -0,0 +1,36 @@
import { EventEmitter } from 'events';
import { StepAction } from '../types';
interface ActionState {
thisStep: StepAction;
gaps: string[];
badAttempts: number;
totalStep: number;
}
export class ActionTracker extends EventEmitter {
private state: ActionState = {
thisStep: {action: 'answer', answer: '', references: [], thoughts: ''},
gaps: [],
badAttempts: 0,
totalStep: 0
};
trackAction(newState: Partial<ActionState>) {
this.state = { ...this.state, ...newState };
this.emit('action', this.state);
}
getState(): ActionState {
return { ...this.state };
}
reset() {
this.state = {
thisStep: {action: 'answer', answer: '', references: [], thoughts: ''},
gaps: [],
badAttempts: 0,
totalStep: 0
};
}
}

View File

@@ -2,7 +2,7 @@ import { EventEmitter } from 'events';
import { TokenUsage } from '../types';
class TokenTracker extends EventEmitter {
export class TokenTracker extends EventEmitter {
private usages: TokenUsage[] = [];
trackUsage(tool: string, tokens: number) {
@@ -33,5 +33,3 @@ class TokenTracker extends EventEmitter {
this.usages = [];
}
}
export const tokenTracker = new TokenTracker();