mirror of
https://github.com/jina-ai/node-DeepResearch.git
synced 2026-03-22 07:29:35 +08:00
feat: add action tracker and reset token tracker (#8)
* feat: add action tracker and reset token tracker Co-Authored-By: Han Xiao <han.xiao@jina.ai> * refactor: make trackers request-scoped Co-Authored-By: Han Xiao <han.xiao@jina.ai> --------- Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: Han Xiao <han.xiao@jina.ai>
This commit is contained in:
committed by
GitHub
parent
3f032bbdcc
commit
5be008e8b9
33
src/agent.ts
33
src/agent.ts
@@ -8,8 +8,10 @@ import {dedupQueries} from "./tools/dedup";
|
||||
import {evaluateAnswer} from "./tools/evaluator";
|
||||
import {analyzeSteps} from "./tools/error-analyzer";
|
||||
import {GEMINI_API_KEY, JINA_API_KEY, SEARCH_PROVIDER, STEP_SLEEP, modelConfigs} from "./config";
|
||||
import {tokenTracker} from "./utils/token-tracker";
|
||||
import {TokenTracker} from "./utils/token-tracker";
|
||||
import {ActionTracker} from "./utils/action-tracker";
|
||||
import {StepAction, SchemaProperty, ResponseSchema, AnswerAction} from "./types";
|
||||
import {TrackerContext} from "./types/tracker";
|
||||
|
||||
async function sleep(ms: number) {
|
||||
const seconds = Math.ceil(ms / 1000);
|
||||
@@ -241,7 +243,12 @@ function removeAllLineBreaks(text: string) {
|
||||
return text.replace(/(\r\n|\n|\r)/gm, " ");
|
||||
}
|
||||
|
||||
export async function getResponse(question: string, tokenBudget: number = 1_000_000, maxBadAttempts: number = 3): Promise<StepAction> {
|
||||
export async function getResponse(question: string, tokenBudget: number = 1_000_000, maxBadAttempts: number = 3): Promise<{ result: StepAction; context: TrackerContext }> {
|
||||
const context: TrackerContext = {
|
||||
tokenTracker: new TokenTracker(),
|
||||
actionTracker: new ActionTracker()
|
||||
};
|
||||
context.actionTracker.trackAction({ gaps: [question], totalStep: 0, badAttempts: 0 });
|
||||
let step = 0;
|
||||
let totalStep = 0;
|
||||
let badAttempts = 0;
|
||||
@@ -261,12 +268,13 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_
|
||||
|
||||
const allURLs: Record<string, string> = {};
|
||||
const visitedURLs: string[] = [];
|
||||
while (tokenTracker.getTotalUsage() < tokenBudget && badAttempts <= maxBadAttempts) {
|
||||
while (context.tokenTracker.getTotalUsage() < tokenBudget && badAttempts <= maxBadAttempts) {
|
||||
// add 1s delay to avoid rate limiting
|
||||
await sleep(STEP_SLEEP);
|
||||
step++;
|
||||
totalStep++;
|
||||
const budgetPercentage = (tokenTracker.getTotalUsage() / tokenBudget * 100).toFixed(2);
|
||||
context.actionTracker.trackAction({ totalStep, thisStep, gaps, badAttempts });
|
||||
const budgetPercentage = (context.tokenTracker.getTotalUsage() / tokenBudget * 100).toFixed(2);
|
||||
console.log(`Step ${totalStep} / Budget used ${budgetPercentage}%`);
|
||||
console.log('Gaps:', gaps);
|
||||
allowReflect = allowReflect && (gaps.length <= 1);
|
||||
@@ -302,7 +310,7 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_
|
||||
const result = await model.generateContent(prompt);
|
||||
const response = await result.response;
|
||||
const usage = response.usageMetadata;
|
||||
tokenTracker.trackUsage('agent', usage?.totalTokenCount || 0);
|
||||
context.tokenTracker.trackUsage('agent', usage?.totalTokenCount || 0);
|
||||
|
||||
|
||||
thisStep = JSON.parse(response.text());
|
||||
@@ -325,7 +333,7 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_
|
||||
...thisStep,
|
||||
});
|
||||
|
||||
const {response: evaluation} = await evaluateAnswer(currentQuestion, thisStep.answer);
|
||||
const {response: evaluation} = await evaluateAnswer(currentQuestion, thisStep.answer, context.tokenTracker);
|
||||
|
||||
|
||||
if (currentQuestion === question) {
|
||||
@@ -545,7 +553,7 @@ You decided to think out of the box or cut from a completely different angle.
|
||||
|
||||
const urlResults = await Promise.all(
|
||||
uniqueURLs.map(async (url: string) => {
|
||||
const {response, tokens} = await readUrl(url, JINA_API_KEY);
|
||||
const {response, tokens} = await readUrl(url, JINA_API_KEY, context.tokenTracker);
|
||||
allKnowledge.push({
|
||||
question: `What is in ${response.data.url}?`,
|
||||
answer: removeAllLineBreaks(response.data.content),
|
||||
@@ -592,7 +600,7 @@ You decided to think out of the box or cut from a completely different angle.`);
|
||||
totalStep++;
|
||||
await storeContext(prompt, [allContext, allKeywords, allQuestions, allKnowledge], totalStep);
|
||||
if (isAnswered) {
|
||||
return thisStep;
|
||||
return { result: thisStep, context };
|
||||
} else {
|
||||
console.log('Enter Beast mode!!!')
|
||||
const prompt = getPrompt(
|
||||
@@ -621,12 +629,12 @@ You decided to think out of the box or cut from a completely different angle.`);
|
||||
const result = await model.generateContent(prompt);
|
||||
const response = await result.response;
|
||||
const usage = response.usageMetadata;
|
||||
tokenTracker.trackUsage('agent', usage?.totalTokenCount || 0);
|
||||
context.tokenTracker.trackUsage('agent', usage?.totalTokenCount || 0);
|
||||
|
||||
await storeContext(prompt, [allContext, allKeywords, allQuestions, allKnowledge], totalStep);
|
||||
thisStep = JSON.parse(response.text());
|
||||
console.log(thisStep)
|
||||
return thisStep;
|
||||
return { result: thisStep, context };
|
||||
}
|
||||
}
|
||||
|
||||
@@ -648,9 +656,10 @@ const genAI = new GoogleGenerativeAI(GEMINI_API_KEY);
|
||||
|
||||
export async function main() {
|
||||
const question = process.argv[2] || "";
|
||||
const finalStep = await getResponse(question) as AnswerAction;
|
||||
const { result: finalStep } = await getResponse(question) as { result: AnswerAction; context: TrackerContext };
|
||||
console.log('Final Answer:', finalStep.answer);
|
||||
tokenTracker.printSummary();
|
||||
const tracker = new TokenTracker();
|
||||
tracker.printSummary();
|
||||
}
|
||||
|
||||
if (require.main === module) {
|
||||
|
||||
@@ -2,8 +2,8 @@ import express, { Request, Response, RequestHandler } from 'express';
|
||||
import cors from 'cors';
|
||||
import { EventEmitter } from 'events';
|
||||
import { getResponse } from './agent';
|
||||
import { tokenTracker } from './utils/token-tracker';
|
||||
import { StepAction } from './types';
|
||||
import { TrackerContext } from './types/tracker';
|
||||
import fs from 'fs/promises';
|
||||
import path from 'path';
|
||||
|
||||
@@ -51,29 +51,21 @@ app.get('/api/v1/stream/:requestId', ((req: Request, res: StreamResponse) => {
|
||||
});
|
||||
}) as RequestHandler);
|
||||
|
||||
function createProgressEmitter(requestId: string, budget: number | undefined, thisStep: StepAction | undefined) {
|
||||
return (message: string, step: number, budgetPercentage?: string) => {
|
||||
const budgetInfo = budgetPercentage ? {
|
||||
used: tokenTracker.getTotalUsage(),
|
||||
function createProgressEmitter(requestId: string, budget: number | undefined, context: TrackerContext) {
|
||||
return () => {
|
||||
const state = context.actionTracker.getState();
|
||||
const budgetInfo = {
|
||||
used: context.tokenTracker.getTotalUsage(),
|
||||
total: budget || 1_000_000,
|
||||
percentage: budgetPercentage
|
||||
} : undefined;
|
||||
percentage: ((context.tokenTracker.getTotalUsage() / (budget || 1_000_000)) * 100).toFixed(2)
|
||||
};
|
||||
|
||||
if (thisStep?.action && thisStep?.thoughts) {
|
||||
eventEmitter.emit(`progress-${requestId}`, {
|
||||
type: 'progress',
|
||||
data: { ...thisStep, totalStep: step },
|
||||
step,
|
||||
budget: budgetInfo
|
||||
});
|
||||
} else {
|
||||
eventEmitter.emit(`progress-${requestId}`, {
|
||||
type: 'progress',
|
||||
data: message,
|
||||
step,
|
||||
budget: budgetInfo
|
||||
});
|
||||
}
|
||||
eventEmitter.emit(`progress-${requestId}`, {
|
||||
type: 'progress',
|
||||
data: { ...state.thisStep, totalStep: state.totalStep },
|
||||
step: state.totalStep,
|
||||
budget: budgetInfo
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
@@ -87,32 +79,14 @@ app.post('/api/v1/query', (async (req: QueryRequest, res: Response) => {
|
||||
const requestId = Date.now().toString();
|
||||
res.json({ requestId });
|
||||
|
||||
// Store original console.log
|
||||
const originalConsoleLog = console.log;
|
||||
let thisStep: StepAction | undefined;
|
||||
|
||||
try {
|
||||
const emitProgress = createProgressEmitter(requestId, budget, thisStep);
|
||||
|
||||
// Override console.log to track progress
|
||||
console.log = (...args: any[]) => {
|
||||
originalConsoleLog(...args);
|
||||
const message = args.join(' ');
|
||||
if (message.includes('Step') || message.includes('Budget used')) {
|
||||
const step = parseInt(message.match(/Step (\d+)/)?.[1] || '0');
|
||||
const budgetPercentage = message.match(/Budget used ([\d.]+)%/)?.[1];
|
||||
emitProgress(message, step, budgetPercentage);
|
||||
}
|
||||
};
|
||||
|
||||
const result = await getResponse(q, budget, maxBadAttempt);
|
||||
thisStep = result;
|
||||
const { result, context } = await getResponse(q, budget, maxBadAttempt);
|
||||
const emitProgress = createProgressEmitter(requestId, budget, context);
|
||||
context.actionTracker.on('action', emitProgress);
|
||||
await storeTaskResult(requestId, result);
|
||||
eventEmitter.emit(`progress-${requestId}`, { type: 'answer', data: result });
|
||||
} catch (error: any) {
|
||||
eventEmitter.emit(`progress-${requestId}`, { type: 'error', data: error?.message || 'Unknown error' });
|
||||
} finally {
|
||||
console.log = originalConsoleLog;
|
||||
}
|
||||
}) as RequestHandler);
|
||||
|
||||
@@ -145,4 +119,4 @@ app.listen(port, () => {
|
||||
console.log(`Server running at http://localhost:${port}`);
|
||||
});
|
||||
|
||||
export default app;
|
||||
export default app;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai";
|
||||
import { GEMINI_API_KEY, modelConfigs } from "../config";
|
||||
import { tokenTracker } from "../utils/token-tracker";
|
||||
import { TokenTracker } from "../utils/token-tracker";
|
||||
|
||||
import { EvaluationResponse } from '../types';
|
||||
|
||||
@@ -63,7 +63,7 @@ Question: ${JSON.stringify(question)}
|
||||
Answer: ${JSON.stringify(answer)}`;
|
||||
}
|
||||
|
||||
export async function evaluateAnswer(question: string, answer: string): Promise<{ response: EvaluationResponse, tokens: number }> {
|
||||
export async function evaluateAnswer(question: string, answer: string, tracker?: TokenTracker): Promise<{ response: EvaluationResponse, tokens: number }> {
|
||||
try {
|
||||
const prompt = getPrompt(question, answer);
|
||||
const result = await model.generateContent(prompt);
|
||||
@@ -75,7 +75,7 @@ export async function evaluateAnswer(question: string, answer: string): Promise<
|
||||
reason: json.reasoning
|
||||
});
|
||||
const tokens = usage?.totalTokenCount || 0;
|
||||
tokenTracker.trackUsage('evaluator', tokens);
|
||||
(tracker || new TokenTracker()).trackUsage('evaluator', tokens);
|
||||
return { response: json, tokens };
|
||||
} catch (error) {
|
||||
console.error('Error in answer evaluation:', error);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai";
|
||||
import { GEMINI_API_KEY, modelConfigs } from "../config";
|
||||
import { tokenTracker } from "../utils/token-tracker";
|
||||
import { TokenTracker } from "../utils/token-tracker";
|
||||
import { SearchAction } from "../types";
|
||||
|
||||
import { KeywordsResponse } from '../types';
|
||||
@@ -105,7 +105,7 @@ Intention: ${action.thoughts}
|
||||
`;
|
||||
}
|
||||
|
||||
export async function rewriteQuery(action: SearchAction): Promise<{ queries: string[], tokens: number }> {
|
||||
export async function rewriteQuery(action: SearchAction, tracker?: TokenTracker): Promise<{ queries: string[], tokens: number }> {
|
||||
try {
|
||||
const prompt = getPrompt(action);
|
||||
const result = await model.generateContent(prompt);
|
||||
@@ -115,7 +115,7 @@ export async function rewriteQuery(action: SearchAction): Promise<{ queries: str
|
||||
|
||||
console.log('Query rewriter:', json.queries);
|
||||
const tokens = usage?.totalTokenCount || 0;
|
||||
tokenTracker.trackUsage('query-rewriter', tokens);
|
||||
(tracker || new TokenTracker()).trackUsage('query-rewriter', tokens);
|
||||
|
||||
return { queries: json.queries, tokens };
|
||||
} catch (error) {
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import https from 'https';
|
||||
import { tokenTracker } from "../utils/token-tracker";
|
||||
import { TokenTracker } from "../utils/token-tracker";
|
||||
|
||||
import { ReadResponse } from '../types';
|
||||
|
||||
export function readUrl(url: string, token: string): Promise<{ response: ReadResponse, tokens: number }> {
|
||||
export function readUrl(url: string, token: string, tracker?: TokenTracker): Promise<{ response: ReadResponse, tokens: number }> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const data = JSON.stringify({url});
|
||||
|
||||
@@ -33,7 +33,7 @@ export function readUrl(url: string, token: string): Promise<{ response: ReadRes
|
||||
tokens: response.data.usage.tokens
|
||||
});
|
||||
const tokens = response.data?.usage?.tokens || 0;
|
||||
tokenTracker.trackUsage('read', tokens);
|
||||
(tracker || new TokenTracker()).trackUsage('read', tokens);
|
||||
resolve({ response, tokens });
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import https from 'https';
|
||||
import { tokenTracker } from "../utils/token-tracker";
|
||||
import { TokenTracker } from "../utils/token-tracker";
|
||||
|
||||
import { SearchResponse } from '../types';
|
||||
|
||||
export function search(query: string, token: string): Promise<{ response: SearchResponse, tokens: number }> {
|
||||
export function search(query: string, token: string, tracker?: TokenTracker): Promise<{ response: SearchResponse, tokens: number }> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const options = {
|
||||
hostname: 's.jina.ai',
|
||||
@@ -28,7 +28,7 @@ export function search(query: string, token: string): Promise<{ response: Search
|
||||
url: item.url,
|
||||
tokens: item.usage.tokens
|
||||
})));
|
||||
tokenTracker.trackUsage('search', totalTokens);
|
||||
(tracker || new TokenTracker()).trackUsage('search', totalTokens);
|
||||
resolve({ response, tokens: totalTokens });
|
||||
});
|
||||
});
|
||||
|
||||
7
src/types/tracker.ts
Normal file
7
src/types/tracker.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
import { TokenTracker } from '../utils/token-tracker';
|
||||
import { ActionTracker } from '../utils/action-tracker';
|
||||
|
||||
export interface TrackerContext {
|
||||
tokenTracker: TokenTracker;
|
||||
actionTracker: ActionTracker;
|
||||
}
|
||||
36
src/utils/action-tracker.ts
Normal file
36
src/utils/action-tracker.ts
Normal file
@@ -0,0 +1,36 @@
|
||||
import { EventEmitter } from 'events';
|
||||
import { StepAction } from '../types';
|
||||
|
||||
interface ActionState {
|
||||
thisStep: StepAction;
|
||||
gaps: string[];
|
||||
badAttempts: number;
|
||||
totalStep: number;
|
||||
}
|
||||
|
||||
export class ActionTracker extends EventEmitter {
|
||||
private state: ActionState = {
|
||||
thisStep: {action: 'answer', answer: '', references: [], thoughts: ''},
|
||||
gaps: [],
|
||||
badAttempts: 0,
|
||||
totalStep: 0
|
||||
};
|
||||
|
||||
trackAction(newState: Partial<ActionState>) {
|
||||
this.state = { ...this.state, ...newState };
|
||||
this.emit('action', this.state);
|
||||
}
|
||||
|
||||
getState(): ActionState {
|
||||
return { ...this.state };
|
||||
}
|
||||
|
||||
reset() {
|
||||
this.state = {
|
||||
thisStep: {action: 'answer', answer: '', references: [], thoughts: ''},
|
||||
gaps: [],
|
||||
badAttempts: 0,
|
||||
totalStep: 0
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@ import { EventEmitter } from 'events';
|
||||
|
||||
import { TokenUsage } from '../types';
|
||||
|
||||
class TokenTracker extends EventEmitter {
|
||||
export class TokenTracker extends EventEmitter {
|
||||
private usages: TokenUsage[] = [];
|
||||
|
||||
trackUsage(tool: string, tokens: number) {
|
||||
@@ -33,5 +33,3 @@ class TokenTracker extends EventEmitter {
|
||||
this.usages = [];
|
||||
}
|
||||
}
|
||||
|
||||
export const tokenTracker = new TokenTracker();
|
||||
|
||||
Reference in New Issue
Block a user