feat: add OpenAI provider with structured output support (#28)

* feat: add OpenAI provider with structured output support

Co-Authored-By: Han Xiao <han.xiao@jina.ai>

* fix: add @ai-sdk/openai dependency and fix modelConfigs access

Co-Authored-By: Han Xiao <han.xiao@jina.ai>

* fix: correct indentation in agent.ts

Co-Authored-By: Han Xiao <han.xiao@jina.ai>

* refactor: centralize model initialization in config.ts

Co-Authored-By: Han Xiao <han.xiao@jina.ai>

* refactor: improve model config access patterns

Co-Authored-By: Han Xiao <han.xiao@jina.ai>

* fix: remove unused imports

Co-Authored-By: Han Xiao <han.xiao@jina.ai>

* refactor: clean up

---------

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: Han Xiao <han.xiao@jina.ai>
commit 50dff0863c (parent f1c7ada6ae)
Author: devin-ai-integration[bot]
Date: 2025-02-06 15:05:38 +08:00 (committed via GitHub)
15 changed files with 271 additions and 100 deletions

Dockerfile

@@ -15,6 +15,7 @@ COPY . .
 # Set environment variables
 ENV GEMINI_API_KEY=${GEMINI_API_KEY}
+ENV OPENAI_API_KEY=${OPENAI_API_KEY}
 ENV JINA_API_KEY=${JINA_API_KEY}
 ENV BRAVE_API_KEY=${BRAVE_API_KEY}

README.md

@@ -25,12 +25,7 @@ flowchart LR
 ## Install
-We use gemini for llm, [jina reader](https://jina.ai/reader) for searching and reading webpages.
 ```bash
-export GEMINI_API_KEY=... # for gemini api, ask han
-export JINA_API_KEY=jina_... # free jina api key, get from https://jina.ai/reader
 git clone https://github.com/jina-ai/node-DeepResearch.git
 cd node-DeepResearch
 npm install
@@ -39,7 +34,14 @@ npm install
 ## Usage
+We use Gemini/OpenAI for reasoning, [Jina Reader](https://jina.ai/reader) for searching and reading webpages, you can get a free API key with 1M tokens from jina.ai.
 ```bash
+export GEMINI_API_KEY=... # for gemini
+# export OPENAI_API_KEY=... # for openai
+# export LLM_PROVIDER=openai # for openai
+export JINA_API_KEY=jina_... # free jina api key, get from https://jina.ai/reader
 npm run dev $QUERY
 ```

docker-compose.yml

@@ -7,6 +7,7 @@ services:
       dockerfile: Dockerfile
     environment:
       - GEMINI_API_KEY=${GEMINI_API_KEY}
+      - OPENAI_API_KEY=${OPENAI_API_KEY}
       - JINA_API_KEY=${JINA_API_KEY}
       - BRAVE_API_KEY=${BRAVE_API_KEY}
     ports:

package-lock.json (generated)

@ -1,15 +1,16 @@
{ {
"name": "agentic-search", "name": "node-deepresearch",
"version": "1.0.0", "version": "1.0.0",
"lockfileVersion": 3, "lockfileVersion": 3,
"requires": true, "requires": true,
"packages": { "packages": {
"": { "": {
"name": "agentic-search", "name": "node-deepresearch",
"version": "1.0.0", "version": "1.0.0",
"license": "ISC", "license": "Apache-2.0",
"dependencies": { "dependencies": {
"@ai-sdk/google": "^1.0.0", "@ai-sdk/google": "^1.0.0",
"@ai-sdk/openai": "^1.1.9",
"@types/cors": "^2.8.17", "@types/cors": "^2.8.17",
"@types/express": "^5.0.0", "@types/express": "^5.0.0",
"@types/node-fetch": "^2.6.12", "@types/node-fetch": "^2.6.12",
@ -51,6 +52,22 @@
"zod": "^3.0.0" "zod": "^3.0.0"
} }
}, },
"node_modules/@ai-sdk/openai": {
"version": "1.1.9",
"resolved": "https://registry.npmjs.org/@ai-sdk/openai/-/openai-1.1.9.tgz",
"integrity": "sha512-t/CpC4TLipdbgBJTMX/otzzqzCMBSPQwUOkYPGbT/jyuC86F+YO9o+LS0Ty2pGUE1kyT+B3WmJ318B16ZCg4hw==",
"license": "Apache-2.0",
"dependencies": {
"@ai-sdk/provider": "1.0.7",
"@ai-sdk/provider-utils": "2.1.6"
},
"engines": {
"node": ">=18"
},
"peerDependencies": {
"zod": "^3.0.0"
}
},
"node_modules/@ai-sdk/provider": { "node_modules/@ai-sdk/provider": {
"version": "1.0.7", "version": "1.0.7",
"resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.0.7.tgz", "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.0.7.tgz",

package.json

@@ -19,6 +19,7 @@
   "description": "",
   "dependencies": {
     "@ai-sdk/google": "^1.0.0",
+    "@ai-sdk/openai": "^1.1.9",
     "@types/cors": "^2.8.17",
     "@types/express": "^5.0.0",
     "@types/node-fetch": "^2.6.12",

src/agent.ts

@@ -1,6 +1,6 @@
-import {createGoogleGenerativeAI} from '@ai-sdk/google';
 import {z} from 'zod';
 import {generateObject} from 'ai';
+import {getModel, getMaxTokens, SEARCH_PROVIDER, STEP_SLEEP} from "./config";
 import {readUrl} from "./tools/read";
 import {handleGenerateObjectError} from './utils/error-handling';
 import fs from 'fs/promises';
@@ -10,7 +10,6 @@ import {rewriteQuery} from "./tools/query-rewriter";
 import {dedupQueries} from "./tools/dedup";
 import {evaluateAnswer} from "./tools/evaluator";
 import {analyzeSteps} from "./tools/error-analyzer";
-import {SEARCH_PROVIDER, STEP_SLEEP, modelConfigs} from "./config";
 import {TokenTracker} from "./utils/token-tracker";
 import {ActionTracker} from "./utils/action-tracker";
 import {StepAction, AnswerAction} from "./types";
@@ -325,7 +324,7 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_
     false
   );
-  const model = createGoogleGenerativeAI({apiKey: process.env.GEMINI_API_KEY})(modelConfigs.agent.model);
+  const model = getModel('agent');
   let object;
   let totalTokens = 0;
   try {
@@ -333,7 +332,7 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_
       model,
       schema: getSchema(allowReflect, allowRead, allowAnswer, allowSearch),
       prompt,
-      maxTokens: modelConfigs.agent.maxTokens
+      maxTokens: getMaxTokens('agent')
     });
     object = result.object;
     totalTokens = result.usage?.totalTokens || 0;
@@ -671,7 +670,7 @@ You decided to think out of the box or cut from a completely different angle.`);
     true
   );
-  const model = createGoogleGenerativeAI({apiKey: process.env.GEMINI_API_KEY})(modelConfigs.agentBeastMode.model);
+  const model = getModel('agentBeastMode');
   let object;
   let totalTokens = 0;
   try {
@@ -679,7 +678,7 @@ You decided to think out of the box or cut from a completely different angle.`);
       model,
       schema: getSchema(false, false, allowAnswer, false),
       prompt,
-      maxTokens: modelConfigs.agentBeastMode.maxTokens
+      maxTokens: getMaxTokens('agentBeastMode')
     });
     object = result.object;
     totalTokens = result.usage?.totalTokens || 0;

src/config.ts

@@ -1,13 +1,35 @@
 import dotenv from 'dotenv';
 import { ProxyAgent, setGlobalDispatcher } from 'undici';
+import { createGoogleGenerativeAI } from '@ai-sdk/google';
+import { createOpenAI } from '@ai-sdk/openai';

-interface ModelConfig {
+export type LLMProvider = 'openai' | 'gemini';
+export type ToolName = keyof ToolConfigs;
+
+function isValidProvider(provider: string): provider is LLMProvider {
+  return provider === 'openai' || provider === 'gemini';
+}
+
+function validateModelConfig(config: ModelConfig, toolName: string): ModelConfig {
+  if (typeof config.model !== 'string' || config.model.length === 0) {
+    throw new Error(`Invalid model name for ${toolName}`);
+  }
+  if (typeof config.temperature !== 'number' || config.temperature < 0 || config.temperature > 1) {
+    throw new Error(`Invalid temperature for ${toolName}`);
+  }
+  if (typeof config.maxTokens !== 'number' || config.maxTokens <= 0) {
+    throw new Error(`Invalid maxTokens for ${toolName}`);
+  }
+  return config;
+}
+
+export interface ModelConfig {
   model: string;
   temperature: number;
   maxTokens: number;
 }

-interface ToolConfigs {
+export interface ToolConfigs {
   dedup: ModelConfig;
   evaluator: ModelConfig;
   errorAnalyzer: ModelConfig;
@@ -31,44 +53,87 @@ if (process.env.https_proxy) {
 }

 export const GEMINI_API_KEY = process.env.GEMINI_API_KEY as string;
+export const OPENAI_API_KEY = process.env.OPENAI_API_KEY as string;
 export const JINA_API_KEY = process.env.JINA_API_KEY as string;
 export const BRAVE_API_KEY = process.env.BRAVE_API_KEY as string;

-export const SEARCH_PROVIDER: 'brave' | 'jina' | 'duck' = 'jina'
+export const SEARCH_PROVIDER: 'brave' | 'jina' | 'duck' = 'jina';
+
+export const LLM_PROVIDER: LLMProvider = (() => {
+  const provider = process.env.LLM_PROVIDER || 'gemini';
+  if (!isValidProvider(provider)) {
+    throw new Error(`Invalid LLM provider: ${provider}`);
+  }
+  return provider;
+})();

-const DEFAULT_MODEL = 'gemini-1.5-flash';
+const DEFAULT_GEMINI_MODEL = 'gemini-1.5-flash';
+const DEFAULT_OPENAI_MODEL = 'gpt-4o-mini';

-const defaultConfig: ModelConfig = {
-  model: DEFAULT_MODEL,
+const defaultGeminiConfig: ModelConfig = {
+  model: DEFAULT_GEMINI_MODEL,
   temperature: 0,
   maxTokens: 1000
 };

-export const modelConfigs: ToolConfigs = {
-  dedup: {
-    ...defaultConfig,
-    temperature: 0.1
-  },
-  evaluator: {
-    ...defaultConfig
-  },
-  errorAnalyzer: {
-    ...defaultConfig
-  },
-  queryRewriter: {
-    ...defaultConfig,
-    temperature: 0.1
-  },
-  agent: {
-    ...defaultConfig,
-    temperature: 0.7
-  },
-  agentBeastMode: {
-    ...defaultConfig,
-    temperature: 0.7
+const defaultOpenAIConfig: ModelConfig = {
+  model: DEFAULT_OPENAI_MODEL,
+  temperature: 0,
+  maxTokens: 1000
+};
+
+export const modelConfigs: Record<LLMProvider, ToolConfigs> = {
+  gemini: {
+    dedup: validateModelConfig({ ...defaultGeminiConfig, temperature: 0.1 }, 'dedup'),
+    evaluator: validateModelConfig({ ...defaultGeminiConfig }, 'evaluator'),
+    errorAnalyzer: validateModelConfig({ ...defaultGeminiConfig }, 'errorAnalyzer'),
+    queryRewriter: validateModelConfig({ ...defaultGeminiConfig, temperature: 0.1 }, 'queryRewriter'),
+    agent: validateModelConfig({ ...defaultGeminiConfig, temperature: 0.7 }, 'agent'),
+    agentBeastMode: validateModelConfig({ ...defaultGeminiConfig, temperature: 0.7 }, 'agentBeastMode')
+  },
+  openai: {
+    dedup: validateModelConfig({ ...defaultOpenAIConfig, temperature: 0.1 }, 'dedup'),
+    evaluator: validateModelConfig({ ...defaultOpenAIConfig }, 'evaluator'),
+    errorAnalyzer: validateModelConfig({ ...defaultOpenAIConfig }, 'errorAnalyzer'),
+    queryRewriter: validateModelConfig({ ...defaultOpenAIConfig, temperature: 0.1 }, 'queryRewriter'),
+    agent: validateModelConfig({ ...defaultOpenAIConfig, temperature: 0.7 }, 'agent'),
+    agentBeastMode: validateModelConfig({ ...defaultOpenAIConfig, temperature: 0.7 }, 'agentBeastMode')
   }
 };
+
+export function getToolConfig(toolName: ToolName): ModelConfig {
+  if (!modelConfigs[LLM_PROVIDER][toolName]) {
+    throw new Error(`Invalid tool name: ${toolName}`);
+  }
+  return modelConfigs[LLM_PROVIDER][toolName];
+}
+
+export function getMaxTokens(toolName: ToolName): number {
+  return getToolConfig(toolName).maxTokens;
+}
+
+export function getModel(toolName: ToolName) {
+  const config = getToolConfig(toolName);
+  if (LLM_PROVIDER === 'openai') {
+    if (!OPENAI_API_KEY) {
+      throw new Error('OPENAI_API_KEY not found');
+    }
+    return createOpenAI({
+      apiKey: OPENAI_API_KEY,
+      compatibility: 'strict'
+    })(config.model);
+  }
+  if (!GEMINI_API_KEY) {
+    throw new Error('GEMINI_API_KEY not found');
+  }
+  return createGoogleGenerativeAI({ apiKey: GEMINI_API_KEY })(config.model);
+}

 export const STEP_SLEEP = 1000;

-if (!GEMINI_API_KEY) throw new Error("GEMINI_API_KEY not found");
+if (LLM_PROVIDER === 'gemini' && !GEMINI_API_KEY) throw new Error("GEMINI_API_KEY not found");
+if (LLM_PROVIDER === 'openai' && !OPENAI_API_KEY) throw new Error("OPENAI_API_KEY not found");
 if (!JINA_API_KEY) throw new Error("JINA_API_KEY not found");
+
+console.log('LLM Provider:', LLM_PROVIDER)
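
One property of this design worth noting: LLM_PROVIDER is resolved by a module-level IIFE, so the provider is frozen the moment config.ts is first imported, and getModel reads that frozen constant. That is why the test files below combine process.env.LLM_PROVIDER with jest.resetModules(). A late-binding alternative would read the env var per call; a sketch, assuming it lived alongside the definitions above (not part of this PR):

```typescript
// Sketch only: resolve the provider on every call instead of once at
// import time. Reuses isValidProvider from this file.
function currentProvider(): LLMProvider {
  const provider = process.env.LLM_PROVIDER || 'gemini';
  if (!isValidProvider(provider)) {
    throw new Error(`Invalid LLM provider: ${provider}`);
  }
  return provider;
}

export function getToolConfigLate(toolName: ToolName): ModelConfig {
  // Looking the provider up here means a test can flip the env var
  // between calls without re-importing the module.
  return modelConfigs[currentProvider()][toolName];
}
```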

src/tools/__tests__/dedup.test.ts

@@ -1,16 +1,37 @@
 import { dedupQueries } from '../dedup';
+import { LLMProvider } from '../../config';

 describe('dedupQueries', () => {
-  it('should remove duplicate queries', async () => {
-    jest.setTimeout(10000); // Increase timeout to 10s
-    const queries = ['typescript tutorial', 'typescript tutorial', 'javascript basics'];
-    const { unique_queries } = await dedupQueries(queries, []);
-    expect(unique_queries).toHaveLength(2);
-    expect(unique_queries).toContain('javascript basics');
-  });
-  it('should handle empty input', async () => {
-    const { unique_queries } = await dedupQueries([], []);
-    expect(unique_queries).toHaveLength(0);
+  const providers: Array<LLMProvider> = ['openai', 'gemini'];
+  const originalEnv = process.env;
+
+  beforeEach(() => {
+    jest.resetModules();
+    process.env = { ...originalEnv };
+  });
+
+  afterEach(() => {
+    process.env = originalEnv;
+  });
+
+  providers.forEach(provider => {
+    describe(`with ${provider} provider`, () => {
+      beforeEach(() => {
+        process.env.LLM_PROVIDER = provider;
+      });
+
+      it('should remove duplicate queries', async () => {
+        jest.setTimeout(10000);
+        const queries = ['typescript tutorial', 'typescript tutorial', 'javascript basics'];
+        const { unique_queries } = await dedupQueries(queries, []);
+        expect(unique_queries).toHaveLength(2);
+        expect(unique_queries).toContain('javascript basics');
+      });
+
+      it('should handle empty input', async () => {
+        const { unique_queries } = await dedupQueries([], []);
+        expect(unique_queries).toHaveLength(0);
+      });
+    });
   });
 });
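
This provider loop is repeated almost verbatim in the three test files that follow. Jest's built-in describe.each could express the same matrix more compactly; a possible consolidation, sketched here for the dedup suite only (not what the PR ships):

```typescript
import { dedupQueries } from '../dedup';
import { LLMProvider } from '../../config';

// One generated suite per provider, replacing the hand-rolled forEach.
describe.each<LLMProvider>(['openai', 'gemini'])('dedupQueries with %s', (provider) => {
  beforeEach(() => {
    jest.resetModules(); // reset the registry so a re-import of config.ts re-reads the env
    process.env.LLM_PROVIDER = provider;
  });

  it('should remove duplicate queries', async () => {
    const queries = ['typescript tutorial', 'typescript tutorial', 'javascript basics'];
    const { unique_queries } = await dedupQueries(queries, []);
    expect(unique_queries).toHaveLength(2);
  });
});
```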

src/tools/__tests__/error-analyzer.test.ts

@@ -1,10 +1,31 @@
 import { analyzeSteps } from '../error-analyzer';
+import { LLMProvider } from '../../config';

 describe('analyzeSteps', () => {
-  it('should analyze error steps', async () => {
-    const { response } = await analyzeSteps(['Step 1: Search failed', 'Step 2: Invalid query']);
-    expect(response).toHaveProperty('recap');
-    expect(response).toHaveProperty('blame');
-    expect(response).toHaveProperty('improvement');
+  const providers: Array<LLMProvider> = ['openai', 'gemini'];
+  const originalEnv = process.env;
+
+  beforeEach(() => {
+    jest.resetModules();
+    process.env = { ...originalEnv };
+  });
+
+  afterEach(() => {
+    process.env = originalEnv;
+  });
+
+  providers.forEach(provider => {
+    describe(`with ${provider} provider`, () => {
+      beforeEach(() => {
+        process.env.LLM_PROVIDER = provider;
+      });
+
+      it('should analyze error steps', async () => {
+        const { response } = await analyzeSteps(['Step 1: Search failed', 'Step 2: Invalid query']);
+        expect(response).toHaveProperty('recap');
+        expect(response).toHaveProperty('blame');
+        expect(response).toHaveProperty('improvement');
+      });
+    });
   });
 });

src/tools/__tests__/evaluator.test.ts

@@ -1,27 +1,48 @@
 import { evaluateAnswer } from '../evaluator';
 import { TokenTracker } from '../../utils/token-tracker';
+import { LLMProvider } from '../../config';

 describe('evaluateAnswer', () => {
-  it('should evaluate answer definitiveness', async () => {
-    const tokenTracker = new TokenTracker();
-    const { response } = await evaluateAnswer(
-      'What is TypeScript?',
-      'TypeScript is a strongly typed programming language that builds on JavaScript.',
-      tokenTracker
-    );
-    expect(response).toHaveProperty('is_definitive');
-    expect(response).toHaveProperty('reasoning');
-  });
-  it('should track token usage', async () => {
-    const tokenTracker = new TokenTracker();
-    const spy = jest.spyOn(tokenTracker, 'trackUsage');
-    const { tokens } = await evaluateAnswer(
-      'What is TypeScript?',
-      'TypeScript is a strongly typed programming language that builds on JavaScript.',
-      tokenTracker
-    );
-    expect(spy).toHaveBeenCalledWith('evaluator', tokens);
-    expect(tokens).toBeGreaterThan(0);
+  const providers: Array<LLMProvider> = ['openai', 'gemini'];
+  const originalEnv = process.env;
+
+  beforeEach(() => {
+    jest.resetModules();
+    process.env = { ...originalEnv };
+  });
+
+  afterEach(() => {
+    process.env = originalEnv;
+  });
+
+  providers.forEach(provider => {
+    describe(`with ${provider} provider`, () => {
+      beforeEach(() => {
+        process.env.LLM_PROVIDER = provider;
+      });
+
+      it('should evaluate answer definitiveness', async () => {
+        const tokenTracker = new TokenTracker();
+        const { response } = await evaluateAnswer(
+          'What is TypeScript?',
+          'TypeScript is a strongly typed programming language that builds on JavaScript.',
+          tokenTracker
+        );
+        expect(response).toHaveProperty('is_definitive');
+        expect(response).toHaveProperty('reasoning');
+      });
+
+      it('should track token usage', async () => {
+        const tokenTracker = new TokenTracker();
+        const spy = jest.spyOn(tokenTracker, 'trackUsage');
+        const { tokens } = await evaluateAnswer(
+          'What is TypeScript?',
+          'TypeScript is a strongly typed programming language that builds on JavaScript.',
+          tokenTracker
+        );
+        expect(spy).toHaveBeenCalledWith('evaluator', tokens);
+        expect(tokens).toBeGreaterThan(0);
+      });
+    });
   });
 });

src/tools/__tests__/query-rewriter.test.ts

@@ -1,13 +1,34 @@
 import { rewriteQuery } from '../query-rewriter';
+import { LLMProvider } from '../../config';

 describe('rewriteQuery', () => {
-  it('should rewrite search query', async () => {
-    const { queries } = await rewriteQuery({
-      action: 'search',
-      searchQuery: 'how does typescript work',
-      think: 'Understanding TypeScript basics'
+  const providers: Array<LLMProvider> = ['openai', 'gemini'];
+  const originalEnv = process.env;
+
+  beforeEach(() => {
+    jest.resetModules();
+    process.env = { ...originalEnv };
+  });
+
+  afterEach(() => {
+    process.env = originalEnv;
+  });
+
+  providers.forEach(provider => {
+    describe(`with ${provider} provider`, () => {
+      beforeEach(() => {
+        process.env.LLM_PROVIDER = provider;
+      });
+
+      it('should rewrite search query', async () => {
+        const { queries } = await rewriteQuery({
+          action: 'search',
+          searchQuery: 'how does typescript work',
+          think: 'Understanding TypeScript basics'
+        });
+        expect(Array.isArray(queries)).toBe(true);
+        expect(queries.length).toBeGreaterThan(0);
+      });
     });
-    expect(Array.isArray(queries)).toBe(true);
-    expect(queries.length).toBeGreaterThan(0);
   });
 });

src/tools/dedup.ts

@@ -1,11 +1,11 @@
-import { createGoogleGenerativeAI } from '@ai-sdk/google';
 import { z } from 'zod';
 import { generateObject } from 'ai';
-import { modelConfigs } from "../config";
+import { getModel, getMaxTokens } from "../config";
 import { TokenTracker } from "../utils/token-tracker";
 import { handleGenerateObjectError } from '../utils/error-handling';
 import type { DedupResponse } from '../types';

+const model = getModel('dedup');

 const responseSchema = z.object({
   think: z.string().describe('Strategic reasoning about the overall deduplication approach'),
@@ -13,8 +13,6 @@ const responseSchema = z.object({
     .describe('Array of semantically unique queries').max(3)
 });

-const model = createGoogleGenerativeAI({ apiKey: process.env.GEMINI_API_KEY })(modelConfigs.dedup.model);
-
 function getPrompt(newQueries: string[], existingQueries: string[]): string {
   return `You are an expert in semantic similarity analysis. Given a set of queries (setA) and a set of queries (setB)
@@ -77,7 +75,7 @@ export async function dedupQueries(newQueries: string[], existingQueries: string
       model,
       schema: responseSchema,
       prompt,
-      maxTokens: modelConfigs.dedup.maxTokens
+      maxTokens: getMaxTokens('dedup')
     });
     object = result.object;
     tokens = result.usage?.totalTokens || 0;

src/tools/error-analyzer.ts

@@ -1,18 +1,19 @@
-import { createGoogleGenerativeAI } from '@ai-sdk/google';
 import { z } from 'zod';
 import { generateObject } from 'ai';
-import { modelConfigs } from "../config";
+import { getModel, getMaxTokens } from "../config";
 import { TokenTracker } from "../utils/token-tracker";
 import { ErrorAnalysisResponse } from '../types';
 import { handleGenerateObjectError } from '../utils/error-handling';

+const model = getModel('errorAnalyzer');

 const responseSchema = z.object({
   recap: z.string().describe('Recap of the actions taken and the steps conducted'),
   blame: z.string().describe('Which action or the step was the root cause of the answer rejection'),
   improvement: z.string().describe('Suggested key improvement for the next iteration, do not use bullet points, be concise and hot-take vibe.')
 });

-const model = createGoogleGenerativeAI({ apiKey: process.env.GEMINI_API_KEY })(modelConfigs.errorAnalyzer.model);
-
 function getPrompt(diaryContext: string[]): string {
   return `You are an expert at analyzing search and reasoning processes. Your task is to analyze the given sequence of steps and identify what went wrong in the search process.
@@ -112,7 +113,7 @@ export async function analyzeSteps(diaryContext: string[], tracker?: TokenTracke
       model,
       schema: responseSchema,
       prompt,
-      maxTokens: modelConfigs.errorAnalyzer.maxTokens
+      maxTokens: getMaxTokens('errorAnalyzer')
     });
     object = result.object;
     tokens = result.usage?.totalTokens || 0;

src/tools/evaluator.ts

@@ -1,17 +1,18 @@
-import { createGoogleGenerativeAI } from '@ai-sdk/google';
 import { z } from 'zod';
 import { generateObject } from 'ai';
-import { modelConfigs } from "../config";
+import { getModel, getMaxTokens } from "../config";
 import { TokenTracker } from "../utils/token-tracker";
 import { EvaluationResponse } from '../types';
 import { handleGenerateObjectError } from '../utils/error-handling';

+const model = getModel('evaluator');

 const responseSchema = z.object({
   is_definitive: z.boolean().describe('Whether the answer provides a definitive response without uncertainty or \'I don\'t know\' type statements'),
   reasoning: z.string().describe('Explanation of why the answer is or isn\'t definitive')
 });

-const model = createGoogleGenerativeAI({ apiKey: process.env.GEMINI_API_KEY })(modelConfigs.evaluator.model);
-
 function getPrompt(question: string, answer: string): string {
   return `You are an evaluator of answer definitiveness. Analyze if the given answer provides a definitive response or not.
@@ -57,7 +58,7 @@ export async function evaluateAnswer(question: string, answer: string, tracker?:
       model,
       schema: responseSchema,
       prompt,
-      maxTokens: modelConfigs.evaluator.maxTokens
+      maxTokens: getMaxTokens('evaluator')
     });
     object = result.object;
     totalTokens = result.usage?.totalTokens || 0;

src/tools/query-rewriter.ts

@@ -1,11 +1,12 @@
-import { createGoogleGenerativeAI } from '@ai-sdk/google';
 import { z } from 'zod';
-import { modelConfigs } from "../config";
+import { generateObject } from 'ai';
+import { getModel, getMaxTokens } from "../config";
 import { TokenTracker } from "../utils/token-tracker";
 import { SearchAction, KeywordsResponse } from '../types';
-import { generateObject } from 'ai';
 import { handleGenerateObjectError } from '../utils/error-handling';

+const model = getModel('queryRewriter');

 const responseSchema = z.object({
   think: z.string().describe('Strategic reasoning about query complexity and search approach'),
   queries: z.array(z.string().describe('Search query, must be less than 30 characters'))
@@ -14,7 +15,7 @@ const responseSchema = z.object({
     .describe('Array of search queries, orthogonal to each other')
 });

-const model = createGoogleGenerativeAI({ apiKey: process.env.GEMINI_API_KEY })(modelConfigs.queryRewriter.model);
-
 function getPrompt(action: SearchAction): string {
   return `You are an expert Information Retrieval Assistant. Transform user queries into precise keyword combinations with strategic reasoning and appropriate search operators.
@@ -102,7 +103,7 @@ export async function rewriteQuery(action: SearchAction, tracker?: TokenTracker)
       model,
       schema: responseSchema,
       prompt,
-      maxTokens: modelConfigs.queryRewriter.maxTokens
+      maxTokens: getMaxTokens('queryRewriter')
     });
     object = result.object;
     tokens = result.usage?.totalTokens || 0;
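
After this commit, dedup.ts, error-analyzer.ts, evaluator.ts, and query-rewriter.ts all follow the same three-step pattern: import getModel/getMaxTokens, bind a module-level model, and pass getMaxTokens(tool) into generateObject. A follow-up refactor could fold that repetition into a single helper; a hypothetical sketch against the same config.ts API (error handling via handleGenerateObjectError omitted for brevity):

```typescript
import { generateObject } from 'ai';
import { z } from 'zod';
import { getModel, getMaxTokens, ToolName } from './config';

// Hypothetical shared wrapper, not part of this PR: every tool call
// becomes one line at the call site, with the model and token budget
// resolved from the tool name.
export async function generateForTool<T>(
  toolName: ToolName,
  schema: z.Schema<T>,
  prompt: string
): Promise<{ object: T; tokens: number }> {
  const result = await generateObject({
    model: getModel(toolName),
    schema,
    prompt,
    maxTokens: getMaxTokens(toolName)
  });
  return { object: result.object, tokens: result.usage?.totalTokens || 0 };
}
```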