mirror of
https://github.com/jina-ai/node-DeepResearch.git
synced 2026-03-22 07:29:35 +08:00
refactor: replace @google/generative-ai with @ai-sdk/google (#27)
* refactor: replace @google/generative-ai with @ai-sdk/google Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: use createGoogleGenerativeAI for API key configuration Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: update Zod schemas to use discriminated unions Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: ensure at least one variant in Zod discriminated union Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: remove unused actions variable Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: remove duplicate sections declaration and update action sections Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: update schema types and use process.env.GEMINI_API_KEY Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: update schema to use z.union with type literal Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: restore original schema descriptions and remove unused imports Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: update schema to use discriminatedUnion with proper descriptions Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: update schema to use proper type casting for discriminated union Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: update schema type casting for discriminated union Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: update schema to use strict mode and proper type definitions Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: add type field to all schemas Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: remove unused schema variables Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: remove unused baseSchema variable Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: remove unused baseSchema variable and comments Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: implement token tracking using generateObject response Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: update token tracking to use proper destructuring Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: implement token tracking in evaluator and add test 
Co-Authored-By: Han Xiao <han.xiao@jina.ai> * refactor: move maxTokens parameter to config.ts Co-Authored-By: Han Xiao <han.xiao@jina.ai> * feat: implement error handling for generateObject schema validation errors Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: remove lint errors in error handling utility Co-Authored-By: Han Xiao <han.xiao@jina.ai> * chore: clean up error handling utility Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: remove unused functionName parameter Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: remove functionName parameter from handleGenerateObjectError calls Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: update DedupResponse import to type import Co-Authored-By: Han Xiao <han.xiao@jina.ai> * refactor: clean up --------- Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: Han Xiao <han.xiao@jina.ai>
This commit is contained in:
committed by
GitHub
parent
2b84a577c8
commit
f1c7ada6ae
185
src/agent.ts
185
src/agent.ts
@@ -1,5 +1,8 @@
|
||||
import {GoogleGenerativeAI, SchemaType} from "@google/generative-ai";
|
||||
import {createGoogleGenerativeAI} from '@ai-sdk/google';
|
||||
import {z} from 'zod';
|
||||
import {generateObject} from 'ai';
|
||||
import {readUrl} from "./tools/read";
|
||||
import {handleGenerateObjectError} from './utils/error-handling';
|
||||
import fs from 'fs/promises';
|
||||
import {SafeSearchType, search as duckSearch} from "duck-duck-scrape";
|
||||
import {braveSearch} from "./tools/brave-search";
|
||||
@@ -7,10 +10,10 @@ import {rewriteQuery} from "./tools/query-rewriter";
|
||||
import {dedupQueries} from "./tools/dedup";
|
||||
import {evaluateAnswer} from "./tools/evaluator";
|
||||
import {analyzeSteps} from "./tools/error-analyzer";
|
||||
import {GEMINI_API_KEY, SEARCH_PROVIDER, STEP_SLEEP, modelConfigs} from "./config";
|
||||
import {SEARCH_PROVIDER, STEP_SLEEP, modelConfigs} from "./config";
|
||||
import {TokenTracker} from "./utils/token-tracker";
|
||||
import {ActionTracker} from "./utils/action-tracker";
|
||||
import {StepAction, SchemaProperty, ResponseSchema, AnswerAction} from "./types";
|
||||
import {StepAction, AnswerAction} from "./types";
|
||||
import {TrackerContext} from "./types";
|
||||
import {jinaSearch} from "./tools/jinaSearch";
|
||||
|
||||
@@ -20,89 +23,55 @@ async function sleep(ms: number) {
|
||||
return new Promise(resolve => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
function getSchema(allowReflect: boolean, allowRead: boolean, allowAnswer: boolean, allowSearch: boolean): ResponseSchema {
|
||||
function getSchema(allowReflect: boolean, allowRead: boolean, allowAnswer: boolean, allowSearch: boolean) {
|
||||
const actions: string[] = [];
|
||||
const properties: Record<string, SchemaProperty> = {
|
||||
action: {
|
||||
type: SchemaType.STRING,
|
||||
enum: actions,
|
||||
description: "Must match exactly one action type"
|
||||
},
|
||||
think: {
|
||||
type: SchemaType.STRING,
|
||||
description: "Explain why choose this action, what's the thought process behind choosing this action"
|
||||
}
|
||||
const properties: Record<string, z.ZodTypeAny> = {
|
||||
action: z.enum(['placeholder']), // Will update later with actual actions
|
||||
think: z.string().describe("Explain why choose this action, what's the thought process behind choosing this action")
|
||||
};
|
||||
|
||||
if (allowSearch) {
|
||||
actions.push("search");
|
||||
properties.searchQuery = {
|
||||
type: SchemaType.STRING,
|
||||
description: "Only required when choosing 'search' action, must be a short, keyword-based query that BM25, tf-idf based search engines can understand."
|
||||
};
|
||||
properties.searchQuery = z.string()
|
||||
.describe("Only required when choosing 'search' action, must be a short, keyword-based query that BM25, tf-idf based search engines can understand.").optional();
|
||||
}
|
||||
|
||||
if (allowAnswer) {
|
||||
actions.push("answer");
|
||||
properties.answer = {
|
||||
type: SchemaType.STRING,
|
||||
description: "Only required when choosing 'answer' action, must be the final answer in natural language"
|
||||
};
|
||||
properties.references = {
|
||||
type: SchemaType.ARRAY,
|
||||
items: {
|
||||
type: SchemaType.OBJECT,
|
||||
properties: {
|
||||
exactQuote: {
|
||||
type: SchemaType.STRING,
|
||||
description: "Exact relevant quote from the document"
|
||||
},
|
||||
url: {
|
||||
type: SchemaType.STRING,
|
||||
description: "URL of the document; must be directly from the context"
|
||||
}
|
||||
},
|
||||
required: ["exactQuote", "url"]
|
||||
},
|
||||
description: "Must be an array of references that support the answer, each reference must contain an exact quote and the URL of the document"
|
||||
};
|
||||
properties.answer = z.string()
|
||||
.describe("Only required when choosing 'answer' action, must be the final answer in natural language").optional();
|
||||
properties.references = z.array(
|
||||
z.object({
|
||||
exactQuote: z.string().describe("Exact relevant quote from the document"),
|
||||
url: z.string().describe("URL of the document; must be directly from the context")
|
||||
}).required()
|
||||
).describe("Must be an array of references that support the answer, each reference must contain an exact quote and the URL of the document").optional();
|
||||
}
|
||||
|
||||
if (allowReflect) {
|
||||
actions.push("reflect");
|
||||
properties.questionsToAnswer = {
|
||||
type: SchemaType.ARRAY,
|
||||
items: {
|
||||
type: SchemaType.STRING,
|
||||
description: "each question must be a single line, concise and clear. not composite or compound, less than 20 words."
|
||||
},
|
||||
description: "List of most important questions to fill the knowledge gaps of finding the answer to the original question",
|
||||
maxItems: 2
|
||||
};
|
||||
properties.questionsToAnswer = z.array(
|
||||
z.string().describe("each question must be a single line, concise and clear. not composite or compound, less than 20 words.")
|
||||
).max(2)
|
||||
.describe("List of most important questions to fill the knowledge gaps of finding the answer to the original question").optional();
|
||||
}
|
||||
|
||||
if (allowRead) {
|
||||
actions.push("visit");
|
||||
properties.URLTargets = {
|
||||
type: SchemaType.ARRAY,
|
||||
items: {
|
||||
type: SchemaType.STRING
|
||||
},
|
||||
maxItems: 2,
|
||||
description: "Must be an array of URLs, choose up the most relevant 2 URLs to visit"
|
||||
};
|
||||
properties.URLTargets = z.array(z.string())
|
||||
.max(2)
|
||||
.describe("Must be an array of URLs, choose up the most relevant 2 URLs to visit").optional();
|
||||
}
|
||||
|
||||
// Update the enum values after collecting all actions
|
||||
properties.action.enum = actions;
|
||||
properties.action = z.enum(actions as [string, ...string[]])
|
||||
.describe("Must match exactly one action type");
|
||||
|
||||
return z.object(properties);
|
||||
|
||||
return {
|
||||
type: SchemaType.OBJECT,
|
||||
properties,
|
||||
required: ["action", "think"]
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
function getPrompt(
|
||||
question: string,
|
||||
context?: string[],
|
||||
@@ -117,6 +86,7 @@ function getPrompt(
|
||||
beastMode?: boolean
|
||||
): string {
|
||||
const sections: string[] = [];
|
||||
const actionSections: string[] = [];
|
||||
|
||||
// Add header section
|
||||
sections.push(`Current date: ${new Date().toUTCString()}
|
||||
@@ -150,7 +120,7 @@ ${k.question}
|
||||
<answer>
|
||||
${k.answer}
|
||||
</answer>
|
||||
${k.references.length > 0 ? `
|
||||
${k.references?.length > 0 ? `
|
||||
<references>
|
||||
${JSON.stringify(k.references)}
|
||||
</references>
|
||||
@@ -201,14 +171,13 @@ ${learnedStrategy}
|
||||
}
|
||||
|
||||
// Build actions section
|
||||
const actions: string[] = [];
|
||||
|
||||
if (allURLs && Object.keys(allURLs).length > 0 && allowRead) {
|
||||
const urlList = Object.entries(allURLs)
|
||||
.map(([url, desc]) => ` + "${url}": "${desc}"`)
|
||||
.join('\n');
|
||||
|
||||
actions.push(`
|
||||
actionSections.push(`
|
||||
<action-visit>
|
||||
- Visit any URLs from below to gather external knowledge, choose the most relevant URLs that might contain the answer
|
||||
<url-list>
|
||||
@@ -222,7 +191,7 @@ ${urlList}
|
||||
}
|
||||
|
||||
if (allowSearch) {
|
||||
actions.push(`
|
||||
actionSections.push(`
|
||||
<action-search>
|
||||
- Query external sources using a public search engine
|
||||
- Focus on solving one specific aspect of the question
|
||||
@@ -232,7 +201,7 @@ ${urlList}
|
||||
}
|
||||
|
||||
if (allowAnswer) {
|
||||
actions.push(`
|
||||
actionSections.push(`
|
||||
<action-answer>
|
||||
- Provide final response only when 100% certain
|
||||
- Responses must be definitive (no ambiguity, uncertainty, or disclaimers)${allowReflect ? '\n- If doubts remain, use <action-reflect> instead' : ''}
|
||||
@@ -241,7 +210,7 @@ ${urlList}
|
||||
}
|
||||
|
||||
if (beastMode) {
|
||||
actions.push(`
|
||||
actionSections.push(`
|
||||
<action-answer>
|
||||
- Any answer is better than no answer
|
||||
- Partial answers are allowed, but make sure they are based on the context and knowledge you have gathered
|
||||
@@ -252,7 +221,7 @@ ${urlList}
|
||||
}
|
||||
|
||||
if (allowReflect) {
|
||||
actions.push(`
|
||||
actionSections.push(`
|
||||
<action-reflect>
|
||||
- Perform critical analysis through hypothetical scenarios or systematic breakdowns
|
||||
- Identify knowledge gaps and formulate essential clarifying questions
|
||||
@@ -268,7 +237,7 @@ ${urlList}
|
||||
sections.push(`
|
||||
Based on the current context, you must choose one of the following actions:
|
||||
<actions>
|
||||
${actions.join('\n\n')}
|
||||
${actionSections.join('\n\n')}
|
||||
</actions>
|
||||
`);
|
||||
|
||||
@@ -356,22 +325,25 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_
|
||||
false
|
||||
);
|
||||
|
||||
const model = genAI.getGenerativeModel({
|
||||
model: modelConfigs.agent.model,
|
||||
generationConfig: {
|
||||
temperature: modelConfigs.agent.temperature,
|
||||
responseMimeType: "application/json",
|
||||
responseSchema: getSchema(allowReflect, allowRead, allowAnswer, allowSearch)
|
||||
}
|
||||
});
|
||||
|
||||
const result = await model.generateContent(prompt);
|
||||
const response = await result.response;
|
||||
const usage = response.usageMetadata;
|
||||
context.tokenTracker.trackUsage('agent', usage?.totalTokenCount || 0);
|
||||
|
||||
|
||||
thisStep = JSON.parse(response.text());
|
||||
const model = createGoogleGenerativeAI({apiKey: process.env.GEMINI_API_KEY})(modelConfigs.agent.model);
|
||||
let object;
|
||||
let totalTokens = 0;
|
||||
try {
|
||||
const result = await generateObject({
|
||||
model,
|
||||
schema: getSchema(allowReflect, allowRead, allowAnswer, allowSearch),
|
||||
prompt,
|
||||
maxTokens: modelConfigs.agent.maxTokens
|
||||
});
|
||||
object = result.object;
|
||||
totalTokens = result.usage?.totalTokens || 0;
|
||||
} catch (error) {
|
||||
const result = await handleGenerateObjectError<StepAction>(error);
|
||||
object = result.object;
|
||||
totalTokens = result.totalTokens;
|
||||
}
|
||||
context.tokenTracker.trackUsage('agent', totalTokens);
|
||||
thisStep = object as StepAction;
|
||||
// print allowed and chose action
|
||||
const actionsStr = [allowSearch, allowRead, allowAnswer, allowReflect].map((a, i) => a ? ['search', 'read', 'answer', 'reflect'][i] : null).filter(a => a).join(', ');
|
||||
console.log(`${thisStep.action} <- [${actionsStr}]`);
|
||||
@@ -683,8 +655,8 @@ You decided to think out of the box or cut from a completely different angle.`);
|
||||
} else {
|
||||
console.log('Enter Beast mode!!!')
|
||||
// any answer is better than no answer, humanity last resort
|
||||
step ++;
|
||||
totalStep ++;
|
||||
step++;
|
||||
totalStep++;
|
||||
const prompt = getPrompt(
|
||||
question,
|
||||
diaryContext,
|
||||
@@ -699,22 +671,27 @@ You decided to think out of the box or cut from a completely different angle.`);
|
||||
true
|
||||
);
|
||||
|
||||
const model = genAI.getGenerativeModel({
|
||||
model: modelConfigs.agentBeastMode.model,
|
||||
generationConfig: {
|
||||
temperature: modelConfigs.agentBeastMode.temperature,
|
||||
responseMimeType: "application/json",
|
||||
responseSchema: getSchema(false, false, allowAnswer, false)
|
||||
}
|
||||
});
|
||||
|
||||
const result = await model.generateContent(prompt);
|
||||
const response = await result.response;
|
||||
const usage = response.usageMetadata;
|
||||
context.tokenTracker.trackUsage('agent', usage?.totalTokenCount || 0);
|
||||
const model = createGoogleGenerativeAI({apiKey: process.env.GEMINI_API_KEY})(modelConfigs.agentBeastMode.model);
|
||||
let object;
|
||||
let totalTokens = 0;
|
||||
try {
|
||||
const result = await generateObject({
|
||||
model,
|
||||
schema: getSchema(false, false, allowAnswer, false),
|
||||
prompt,
|
||||
maxTokens: modelConfigs.agentBeastMode.maxTokens
|
||||
});
|
||||
object = result.object;
|
||||
totalTokens = result.usage?.totalTokens || 0;
|
||||
} catch (error) {
|
||||
const result = await handleGenerateObjectError<StepAction>(error);
|
||||
object = result.object;
|
||||
totalTokens = result.totalTokens;
|
||||
}
|
||||
context.tokenTracker.trackUsage('agent', totalTokens);
|
||||
|
||||
await storeContext(prompt, [allContext, allKeywords, allQuestions, allKnowledge], totalStep);
|
||||
thisStep = JSON.parse(response.text());
|
||||
thisStep = object as StepAction;
|
||||
console.log(thisStep)
|
||||
return {result: thisStep, context};
|
||||
}
|
||||
@@ -733,8 +710,6 @@ async function storeContext(prompt: string, memory: any[][], step: number) {
|
||||
}
|
||||
}
|
||||
|
||||
const genAI = new GoogleGenerativeAI(GEMINI_API_KEY);
|
||||
|
||||
|
||||
export async function main() {
|
||||
const question = process.argv[2] || "";
|
||||
|
||||
@@ -4,6 +4,7 @@ import { ProxyAgent, setGlobalDispatcher } from 'undici';
|
||||
/**
 * Settings for a single model invocation profile.
 *
 * NOTE(review): the `generateObject` call sites visible in this change forward
 * only `maxTokens` (plus the model id); `temperature` does not appear to be
 * passed on any more — confirm whether that is intentional.
 */
interface ModelConfig {
|
||||
/** Model identifier handed to the provider factory. */
model: string;
|
||||
/** Sampling temperature configured for this profile. */
temperature: number;
|
||||
/** Value supplied as `maxTokens` to `generateObject`. */
maxTokens: number;
|
||||
}
|
||||
|
||||
interface ToolConfigs {
|
||||
@@ -38,7 +39,8 @@ const DEFAULT_MODEL = 'gemini-1.5-flash';
|
||||
|
||||
const defaultConfig: ModelConfig = {
|
||||
model: DEFAULT_MODEL,
|
||||
temperature: 0
|
||||
temperature: 0,
|
||||
maxTokens: 1000
|
||||
};
|
||||
|
||||
export const modelConfigs: ToolConfigs = {
|
||||
|
||||
@@ -2,6 +2,7 @@ import { dedupQueries } from '../dedup';
|
||||
|
||||
describe('dedupQueries', () => {
|
||||
it('should remove duplicate queries', async () => {
|
||||
jest.setTimeout(10000); // Increase timeout to 10s
|
||||
const queries = ['typescript tutorial', 'typescript tutorial', 'javascript basics'];
|
||||
const { unique_queries } = await dedupQueries(queries, []);
|
||||
expect(unique_queries).toHaveLength(2);
|
||||
|
||||
@@ -12,4 +12,16 @@ describe('evaluateAnswer', () => {
|
||||
expect(response).toHaveProperty('is_definitive');
|
||||
expect(response).toHaveProperty('reasoning');
|
||||
});
|
||||
|
||||
it('should track token usage', async () => {
|
||||
const tokenTracker = new TokenTracker();
|
||||
const spy = jest.spyOn(tokenTracker, 'trackUsage');
|
||||
const { tokens } = await evaluateAnswer(
|
||||
'What is TypeScript?',
|
||||
'TypeScript is a strongly typed programming language that builds on JavaScript.',
|
||||
tokenTracker
|
||||
);
|
||||
expect(spy).toHaveBeenCalledWith('evaluator', tokens);
|
||||
expect(tokens).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,38 +1,20 @@
|
||||
import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai";
|
||||
import { GEMINI_API_KEY, modelConfigs } from "../config";
|
||||
import { createGoogleGenerativeAI } from '@ai-sdk/google';
|
||||
import { z } from 'zod';
|
||||
import { generateObject } from 'ai';
|
||||
import { modelConfigs } from "../config";
|
||||
import { TokenTracker } from "../utils/token-tracker";
|
||||
import { handleGenerateObjectError } from '../utils/error-handling';
|
||||
import type { DedupResponse } from '../types';
|
||||
|
||||
import { DedupResponse } from '../types';
|
||||
|
||||
const responseSchema = {
|
||||
type: SchemaType.OBJECT,
|
||||
properties: {
|
||||
think: {
|
||||
type: SchemaType.STRING,
|
||||
description: "Strategic reasoning about the overall deduplication approach"
|
||||
},
|
||||
unique_queries: {
|
||||
type: SchemaType.ARRAY,
|
||||
items: {
|
||||
type: SchemaType.STRING,
|
||||
description: "Unique query that passed the deduplication process, must be less than 30 characters"
|
||||
},
|
||||
description: "Array of semantically unique queries"
|
||||
}
|
||||
},
|
||||
required: ["think", "unique_queries"]
|
||||
};
|
||||
|
||||
const genAI = new GoogleGenerativeAI(GEMINI_API_KEY);
|
||||
const model = genAI.getGenerativeModel({
|
||||
model: modelConfigs.dedup.model,
|
||||
generationConfig: {
|
||||
temperature: modelConfigs.dedup.temperature,
|
||||
responseMimeType: "application/json",
|
||||
responseSchema: responseSchema
|
||||
}
|
||||
const responseSchema = z.object({
|
||||
think: z.string().describe('Strategic reasoning about the overall deduplication approach'),
|
||||
unique_queries: z.array(z.string().describe('Unique query that passed the deduplication process, must be less than 30 characters'))
|
||||
.describe('Array of semantically unique queries').max(3)
|
||||
});
|
||||
|
||||
const model = createGoogleGenerativeAI({ apiKey: process.env.GEMINI_API_KEY })(modelConfigs.dedup.model);
|
||||
|
||||
function getPrompt(newQueries: string[], existingQueries: string[]): string {
|
||||
return `You are an expert in semantic similarity analysis. Given a set of queries (setA) and a set of queries (setB)
|
||||
|
||||
@@ -88,14 +70,25 @@ SetB: ${JSON.stringify(existingQueries)}`;
|
||||
export async function dedupQueries(newQueries: string[], existingQueries: string[], tracker?: TokenTracker): Promise<{ unique_queries: string[], tokens: number }> {
|
||||
try {
|
||||
const prompt = getPrompt(newQueries, existingQueries);
|
||||
const result = await model.generateContent(prompt);
|
||||
const response = await result.response;
|
||||
const usage = response.usageMetadata;
|
||||
const json = JSON.parse(response.text()) as DedupResponse;
|
||||
console.log('Dedup:', json.unique_queries);
|
||||
const tokens = usage?.totalTokenCount || 0;
|
||||
let object;
|
||||
let tokens = 0;
|
||||
try {
|
||||
const result = await generateObject({
|
||||
model,
|
||||
schema: responseSchema,
|
||||
prompt,
|
||||
maxTokens: modelConfigs.dedup.maxTokens
|
||||
});
|
||||
object = result.object;
|
||||
tokens = result.usage?.totalTokens || 0;
|
||||
} catch (error) {
|
||||
const result = await handleGenerateObjectError<DedupResponse>(error);
|
||||
object = result.object;
|
||||
tokens = result.totalTokens;
|
||||
}
|
||||
console.log('Dedup:', object.unique_queries);
|
||||
(tracker || new TokenTracker()).trackUsage('dedup', tokens);
|
||||
return { unique_queries: json.unique_queries, tokens };
|
||||
return { unique_queries: object.unique_queries, tokens };
|
||||
} catch (error) {
|
||||
console.error('Error in deduplication analysis:', error);
|
||||
throw error;
|
||||
|
||||
@@ -1,38 +1,19 @@
|
||||
import {GoogleGenerativeAI, SchemaType} from "@google/generative-ai";
|
||||
import { GEMINI_API_KEY, modelConfigs } from "../config";
|
||||
import { createGoogleGenerativeAI } from '@ai-sdk/google';
|
||||
import { z } from 'zod';
|
||||
import { generateObject } from 'ai';
|
||||
import { modelConfigs } from "../config";
|
||||
import { TokenTracker } from "../utils/token-tracker";
|
||||
|
||||
import { ErrorAnalysisResponse } from '../types';
|
||||
import { handleGenerateObjectError } from '../utils/error-handling';
|
||||
|
||||
const responseSchema = {
|
||||
type: SchemaType.OBJECT,
|
||||
properties: {
|
||||
recap: {
|
||||
type: SchemaType.STRING,
|
||||
description: "Recap of the actions taken and the steps conducted"
|
||||
},
|
||||
blame: {
|
||||
type: SchemaType.STRING,
|
||||
description: "Which action or the step was the root cause of the answer rejection"
|
||||
},
|
||||
improvement: {
|
||||
type: SchemaType.STRING,
|
||||
description: "Suggested key improvement for the next iteration, do not use bullet points, be concise and hot-take vibe."
|
||||
}
|
||||
},
|
||||
required: ["recap", "blame", "improvement"]
|
||||
};
|
||||
|
||||
const genAI = new GoogleGenerativeAI(GEMINI_API_KEY);
|
||||
const model = genAI.getGenerativeModel({
|
||||
model: modelConfigs.errorAnalyzer.model,
|
||||
generationConfig: {
|
||||
temperature: modelConfigs.errorAnalyzer.temperature,
|
||||
responseMimeType: "application/json",
|
||||
responseSchema: responseSchema
|
||||
}
|
||||
const responseSchema = z.object({
|
||||
recap: z.string().describe('Recap of the actions taken and the steps conducted'),
|
||||
blame: z.string().describe('Which action or the step was the root cause of the answer rejection'),
|
||||
improvement: z.string().describe('Suggested key improvement for the next iteration, do not use bullet points, be concise and hot-take vibe.')
|
||||
});
|
||||
|
||||
const model = createGoogleGenerativeAI({ apiKey: process.env.GEMINI_API_KEY })(modelConfigs.errorAnalyzer.model);
|
||||
|
||||
function getPrompt(diaryContext: string[]): string {
|
||||
return `You are an expert at analyzing search and reasoning processes. Your task is to analyze the given sequence of steps and identify what went wrong in the search process.
|
||||
|
||||
@@ -124,17 +105,28 @@ ${diaryContext.join('\n')}
|
||||
export async function analyzeSteps(diaryContext: string[], tracker?: TokenTracker): Promise<{ response: ErrorAnalysisResponse, tokens: number }> {
|
||||
try {
|
||||
const prompt = getPrompt(diaryContext);
|
||||
const result = await model.generateContent(prompt);
|
||||
const response = await result.response;
|
||||
const usage = response.usageMetadata;
|
||||
const json = JSON.parse(response.text()) as ErrorAnalysisResponse;
|
||||
let object;
|
||||
let tokens = 0;
|
||||
try {
|
||||
const result = await generateObject({
|
||||
model,
|
||||
schema: responseSchema,
|
||||
prompt,
|
||||
maxTokens: modelConfigs.errorAnalyzer.maxTokens
|
||||
});
|
||||
object = result.object;
|
||||
tokens = result.usage?.totalTokens || 0;
|
||||
} catch (error) {
|
||||
const result = await handleGenerateObjectError<ErrorAnalysisResponse>(error);
|
||||
object = result.object;
|
||||
tokens = result.totalTokens;
|
||||
}
|
||||
console.log('Error analysis:', {
|
||||
is_valid: !json.blame,
|
||||
reason: json.blame || 'No issues found'
|
||||
is_valid: !object.blame,
|
||||
reason: object.blame || 'No issues found'
|
||||
});
|
||||
const tokens = usage?.totalTokenCount || 0;
|
||||
(tracker || new TokenTracker()).trackUsage('error-analyzer', tokens);
|
||||
return { response: json, tokens };
|
||||
return { response: object, tokens };
|
||||
} catch (error) {
|
||||
console.error('Error in answer evaluation:', error);
|
||||
throw error;
|
||||
|
||||
@@ -1,34 +1,18 @@
|
||||
import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai";
|
||||
import { GEMINI_API_KEY, modelConfigs } from "../config";
|
||||
import { createGoogleGenerativeAI } from '@ai-sdk/google';
|
||||
import { z } from 'zod';
|
||||
import { generateObject } from 'ai';
|
||||
import { modelConfigs } from "../config";
|
||||
import { TokenTracker } from "../utils/token-tracker";
|
||||
|
||||
import { EvaluationResponse } from '../types';
|
||||
import { handleGenerateObjectError } from '../utils/error-handling';
|
||||
|
||||
const responseSchema = {
|
||||
type: SchemaType.OBJECT,
|
||||
properties: {
|
||||
is_definitive: {
|
||||
type: SchemaType.BOOLEAN,
|
||||
description: "Whether the answer provides a definitive response without uncertainty or 'I don't know' type statements"
|
||||
},
|
||||
reasoning: {
|
||||
type: SchemaType.STRING,
|
||||
description: "Explanation of why the answer is or isn't definitive"
|
||||
}
|
||||
},
|
||||
required: ["is_definitive", "reasoning"]
|
||||
};
|
||||
|
||||
const genAI = new GoogleGenerativeAI(GEMINI_API_KEY);
|
||||
const model = genAI.getGenerativeModel({
|
||||
model: modelConfigs.evaluator.model,
|
||||
generationConfig: {
|
||||
temperature: modelConfigs.evaluator.temperature,
|
||||
responseMimeType: "application/json",
|
||||
responseSchema: responseSchema
|
||||
}
|
||||
const responseSchema = z.object({
|
||||
is_definitive: z.boolean().describe('Whether the answer provides a definitive response without uncertainty or \'I don\'t know\' type statements'),
|
||||
reasoning: z.string().describe('Explanation of why the answer is or isn\'t definitive')
|
||||
});
|
||||
|
||||
const model = createGoogleGenerativeAI({ apiKey: process.env.GEMINI_API_KEY })(modelConfigs.evaluator.model);
|
||||
|
||||
function getPrompt(question: string, answer: string): string {
|
||||
return `You are an evaluator of answer definitiveness. Analyze if the given answer provides a definitive response or not.
|
||||
|
||||
@@ -66,17 +50,28 @@ Answer: ${JSON.stringify(answer)}`;
|
||||
export async function evaluateAnswer(question: string, answer: string, tracker?: TokenTracker): Promise<{ response: EvaluationResponse, tokens: number }> {
|
||||
try {
|
||||
const prompt = getPrompt(question, answer);
|
||||
const result = await model.generateContent(prompt);
|
||||
const response = await result.response;
|
||||
const usage = response.usageMetadata;
|
||||
const json = JSON.parse(response.text()) as EvaluationResponse;
|
||||
let object;
|
||||
let totalTokens = 0;
|
||||
try {
|
||||
const result = await generateObject({
|
||||
model,
|
||||
schema: responseSchema,
|
||||
prompt,
|
||||
maxTokens: modelConfigs.evaluator.maxTokens
|
||||
});
|
||||
object = result.object;
|
||||
totalTokens = result.usage?.totalTokens || 0;
|
||||
} catch (error) {
|
||||
const result = await handleGenerateObjectError<EvaluationResponse>(error);
|
||||
object = result.object;
|
||||
totalTokens = result.totalTokens;
|
||||
}
|
||||
console.log('Evaluation:', {
|
||||
definitive: json.is_definitive,
|
||||
reason: json.reasoning
|
||||
definitive: object.is_definitive,
|
||||
reason: object.reasoning
|
||||
});
|
||||
const tokens = usage?.totalTokenCount || 0;
|
||||
(tracker || new TokenTracker()).trackUsage('evaluator', tokens);
|
||||
return { response: json, tokens };
|
||||
(tracker || new TokenTracker()).trackUsage('evaluator', totalTokens);
|
||||
return { response: object, tokens: totalTokens };
|
||||
} catch (error) {
|
||||
console.error('Error in answer evaluation:', error);
|
||||
throw error;
|
||||
|
||||
@@ -1,41 +1,21 @@
|
||||
import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai";
|
||||
import { GEMINI_API_KEY, modelConfigs } from "../config";
|
||||
import { createGoogleGenerativeAI } from '@ai-sdk/google';
|
||||
import { z } from 'zod';
|
||||
import { modelConfigs } from "../config";
|
||||
import { TokenTracker } from "../utils/token-tracker";
|
||||
import { SearchAction } from "../types";
|
||||
import { SearchAction, KeywordsResponse } from '../types';
|
||||
import { generateObject } from 'ai';
|
||||
import { handleGenerateObjectError } from '../utils/error-handling';
|
||||
|
||||
import { KeywordsResponse } from '../types';
|
||||
|
||||
const responseSchema = {
|
||||
type: SchemaType.OBJECT,
|
||||
properties: {
|
||||
think: {
|
||||
type: SchemaType.STRING,
|
||||
description: "Strategic reasoning about query complexity and search approach"
|
||||
},
|
||||
queries: {
|
||||
type: SchemaType.ARRAY,
|
||||
items: {
|
||||
type: SchemaType.STRING,
|
||||
description: "Search query, must be less than 30 characters"
|
||||
},
|
||||
description: "Array of search queries, orthogonal to each other",
|
||||
minItems: 1,
|
||||
maxItems: 3
|
||||
}
|
||||
},
|
||||
required: ["think", "queries"]
|
||||
};
|
||||
|
||||
const genAI = new GoogleGenerativeAI(GEMINI_API_KEY);
|
||||
const model = genAI.getGenerativeModel({
|
||||
model: modelConfigs.queryRewriter.model,
|
||||
generationConfig: {
|
||||
temperature: modelConfigs.queryRewriter.temperature,
|
||||
responseMimeType: "application/json",
|
||||
responseSchema: responseSchema
|
||||
}
|
||||
const responseSchema = z.object({
|
||||
think: z.string().describe('Strategic reasoning about query complexity and search approach'),
|
||||
queries: z.array(z.string().describe('Search query, must be less than 30 characters'))
|
||||
.min(1)
|
||||
.max(3)
|
||||
.describe('Array of search queries, orthogonal to each other')
|
||||
});
|
||||
|
||||
const model = createGoogleGenerativeAI({ apiKey: process.env.GEMINI_API_KEY })(modelConfigs.queryRewriter.model);
|
||||
|
||||
function getPrompt(action: SearchAction): string {
|
||||
return `You are an expert Information Retrieval Assistant. Transform user queries into precise keyword combinations with strategic reasoning and appropriate search operators.
|
||||
|
||||
@@ -115,18 +95,27 @@ Intention: ${action.think}
|
||||
export async function rewriteQuery(action: SearchAction, tracker?: TokenTracker): Promise<{ queries: string[], tokens: number }> {
|
||||
try {
|
||||
const prompt = getPrompt(action);
|
||||
const result = await model.generateContent(prompt);
|
||||
const response = await result.response;
|
||||
const usage = response.usageMetadata;
|
||||
const json = JSON.parse(response.text()) as KeywordsResponse;
|
||||
|
||||
console.log('Query rewriter:', json.queries);
|
||||
const tokens = usage?.totalTokenCount || 0;
|
||||
let object;
|
||||
let tokens = 0;
|
||||
try {
|
||||
const result = await generateObject({
|
||||
model,
|
||||
schema: responseSchema,
|
||||
prompt,
|
||||
maxTokens: modelConfigs.queryRewriter.maxTokens
|
||||
});
|
||||
object = result.object;
|
||||
tokens = result.usage?.totalTokens || 0;
|
||||
} catch (error) {
|
||||
const result = await handleGenerateObjectError<KeywordsResponse>(error);
|
||||
object = result.object;
|
||||
tokens = result.totalTokens;
|
||||
}
|
||||
console.log('Query rewriter:', object.queries);
|
||||
(tracker || new TokenTracker()).trackUsage('query-rewriter', tokens);
|
||||
|
||||
return { queries: json.queries, tokens };
|
||||
return { queries: object.queries, tokens };
|
||||
} catch (error) {
|
||||
console.error('Error in query rewriting:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
37
src/types.ts
37
src/types.ts
@@ -1,4 +1,17 @@
|
||||
import { SchemaType } from "@google/generative-ai";
|
||||
import { z } from 'zod';
|
||||
|
||||
/** Zod schema for a free-form reasoning string. */
export const ThinkSchema = z.string().describe('Strategic reasoning about the process');
|
||||
|
||||
/**
 * Zod schema for a single search query string.
 *
 * NOTE(review): `.max(30)` accepts strings of exactly 30 characters, whereas the
 * description says "less than 30" — confirm which bound is intended.
 */
export const QuerySchema = z.string()
|
||||
.max(30)
|
||||
.describe('Search query, must be less than 30 characters');
|
||||
|
||||
/** Zod schema for a string that must pass Zod's URL validation. */
export const URLSchema = z.string().url();
|
||||
|
||||
/** Zod schema for one supporting reference: a quote plus the URL it came from. */
export const ReferenceSchema = z.object({
|
||||
exactQuote: z.string().describe('Exact relevant quote from the document'),
|
||||
// Reuses URLSchema, so the reference URL must pass the same URL validation.
url: URLSchema.describe('URL of the document')
|
||||
});
|
||||
|
||||
// Action Types
|
||||
type BaseAction = {
|
||||
@@ -119,28 +132,6 @@ export type KeywordsResponse = {
|
||||
queries: string[];
|
||||
};
|
||||
|
||||
// Schema Types
|
||||
export type SchemaProperty = {
|
||||
type: SchemaType;
|
||||
description: string;
|
||||
enum?: string[];
|
||||
items?: {
|
||||
type: SchemaType;
|
||||
description?: string;
|
||||
properties?: Record<string, SchemaProperty>;
|
||||
required?: string[];
|
||||
};
|
||||
properties?: Record<string, SchemaProperty>;
|
||||
required?: string[];
|
||||
maxItems?: number;
|
||||
};
|
||||
|
||||
export type ResponseSchema = {
|
||||
type: SchemaType;
|
||||
properties: Record<string, SchemaProperty>;
|
||||
required: string[];
|
||||
};
|
||||
|
||||
export interface StreamMessage {
|
||||
type: 'progress' | 'answer' | 'error';
|
||||
data: string | StepAction;
|
||||
|
||||
22
src/utils/error-handling.ts
Normal file
22
src/utils/error-handling.ts
Normal file
@@ -0,0 +1,22 @@
|
||||
import {NoObjectGeneratedError} from "ai";
|
||||
|
||||
/**
 * Value returned by `handleGenerateObjectError` when it manages to recover a
 * result from a failed object generation.
 */
export interface GenerateObjectResult<T> {
|
||||
/** The recovered object, produced by JSON-parsing the error's raw text and cast to `T`. */
object: T;
|
||||
/** Total token count taken from the error's usage data, or 0 when it is unavailable. */
totalTokens: number;
|
||||
}
|
||||
|
||||
/**
 * Attempts to recover a usable result from a failed `generateObject` call.
 *
 * When `error` is a `NoObjectGeneratedError`, the raw text carried by the error
 * is parsed as JSON and returned together with the token usage reported on the
 * error. The recovered value is only cast to `T`; it is NOT validated against
 * the schema that the original call used, so callers should treat its fields
 * as potentially missing.
 *
 * @param error - The value caught from the failed `generateObject` call.
 * @returns The manually parsed object and the total tokens consumed (0 when
 *          the error carries no usable usage figure).
 * @throws The original `error`, unchanged, when it is not a
 *         `NoObjectGeneratedError`, when it carries no text, when the text is
 *         not valid JSON, or when the JSON is not an object (e.g. `null` or a
 *         bare primitive) and therefore cannot stand in for `T`.
 */
export async function handleGenerateObjectError<T>(error: unknown): Promise<GenerateObjectResult<T>> {
  if (!NoObjectGeneratedError.isInstance(error)) {
    throw error;
  }

  console.error('Object not generated according to the schema, fallback to manual parsing');

  // Read just the two fields we need through a narrow structural view rather than `any`.
  const {text, usage} = error as {text?: unknown; usage?: {totalTokens?: number}};

  let parsed: unknown;
  try {
    parsed = typeof text === 'string' ? JSON.parse(text) : undefined;
  } catch (parseError) {
    console.error('Manual parsing of the model output failed:', parseError);
    throw error;
  }

  // `JSON.parse` happily returns null / numbers / strings; handing those back as
  // `T` would only move the crash to the caller's first property access.
  if (typeof parsed !== 'object' || parsed === null) {
    console.error('Manually parsed model output is not a JSON object');
    throw error;
  }

  return {
    object: parsed as T,
    // `||` (not `??`) on purpose: it also maps a NaN token count to 0.
    totalTokens: usage?.totalTokens || 0
  };
}
|
||||
Reference in New Issue
Block a user