Mirror of https://github.com/jina-ai/node-DeepResearch.git, synced 2025-12-25 22:16:49 +08:00
feat: add OpenAI provider with structured output support (#28)
* feat: add OpenAI provider with structured output support
* fix: add @ai-sdk/openai dependency and fix modelConfigs access
* fix: correct indentation in agent.ts
* refactor: centralize model initialization in config.ts
* refactor: improve model config access patterns
* fix: remove unused imports
* refactor: clean up

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: Han Xiao <han.xiao@jina.ai>
Parent: f1c7ada6ae
Commit: 50dff0863c
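The heart of the change is a provider-aware model factory in src/config.ts: every tool asks `getModel(toolName)` for its model instead of constructing a Gemini client inline. Below is a condensed sketch of that pattern, assuming only what the diffs further down show; the real `getModel` takes a tool name rather than a model name and also validates per-tool config, both omitted here for brevity.

```typescript
// Condensed sketch of the provider switch this commit introduces.
// Simplified: the real getModel looks the model name up in a
// per-provider, per-tool config table in src/config.ts.
import { createGoogleGenerativeAI } from '@ai-sdk/google';
import { createOpenAI } from '@ai-sdk/openai';

type LLMProvider = 'openai' | 'gemini';

const LLM_PROVIDER: LLMProvider =
  process.env.LLM_PROVIDER === 'openai' ? 'openai' : 'gemini';

function getModel(modelName: string) {
  if (LLM_PROVIDER === 'openai') {
    // 'strict' compatibility targets the official OpenAI API
    // (as opposed to OpenAI-compatible third-party endpoints).
    return createOpenAI({
      apiKey: process.env.OPENAI_API_KEY!,
      compatibility: 'strict',
    })(modelName);
  }
  return createGoogleGenerativeAI({ apiKey: process.env.GEMINI_API_KEY! })(modelName);
}
```

The per-file changes follow.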
Dockerfile
@@ -15,6 +15,7 @@ COPY . .
 # Set environment variables
 ENV GEMINI_API_KEY=${GEMINI_API_KEY}
+ENV OPENAI_API_KEY=${OPENAI_API_KEY}
 ENV JINA_API_KEY=${JINA_API_KEY}
 ENV BRAVE_API_KEY=${BRAVE_API_KEY}
 
README.md (12 lines changed)
@@ -25,12 +25,7 @@ flowchart LR
 
 ## Install
 
-We use gemini for llm, [jina reader](https://jina.ai/reader) for searching and reading webpages.
-
 ```bash
-export GEMINI_API_KEY=... # for gemini api, ask han
-export JINA_API_KEY=jina_... # free jina api key, get from https://jina.ai/reader
-
 git clone https://github.com/jina-ai/node-DeepResearch.git
 cd node-DeepResearch
 npm install
@@ -39,7 +34,14 @@ npm install
 
 ## Usage
 
+We use Gemini/OpenAI for reasoning, [Jina Reader](https://jina.ai/reader) for searching and reading webpages, you can get a free API key with 1M tokens from jina.ai.
+
 ```bash
+export GEMINI_API_KEY=... # for gemini
+# export OPENAI_API_KEY=... # for openai
+# export LLM_PROVIDER=openai # for openai
+export JINA_API_KEY=jina_... # free jina api key, get from https://jina.ai/reader
+
 npm run dev $QUERY
 ```
 
docker-compose.yml
@@ -7,6 +7,7 @@ services:
       dockerfile: Dockerfile
     environment:
       - GEMINI_API_KEY=${GEMINI_API_KEY}
+      - OPENAI_API_KEY=${OPENAI_API_KEY}
       - JINA_API_KEY=${JINA_API_KEY}
       - BRAVE_API_KEY=${BRAVE_API_KEY}
     ports:
package-lock.json (generated, 23 lines changed)
@@ -1,15 +1,16 @@
 {
-  "name": "agentic-search",
+  "name": "node-deepresearch",
   "version": "1.0.0",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
-      "name": "agentic-search",
+      "name": "node-deepresearch",
       "version": "1.0.0",
-      "license": "ISC",
+      "license": "Apache-2.0",
       "dependencies": {
         "@ai-sdk/google": "^1.0.0",
+        "@ai-sdk/openai": "^1.1.9",
         "@types/cors": "^2.8.17",
         "@types/express": "^5.0.0",
         "@types/node-fetch": "^2.6.12",
@@ -51,6 +52,22 @@
         "zod": "^3.0.0"
       }
     },
+    "node_modules/@ai-sdk/openai": {
+      "version": "1.1.9",
+      "resolved": "https://registry.npmjs.org/@ai-sdk/openai/-/openai-1.1.9.tgz",
+      "integrity": "sha512-t/CpC4TLipdbgBJTMX/otzzqzCMBSPQwUOkYPGbT/jyuC86F+YO9o+LS0Ty2pGUE1kyT+B3WmJ318B16ZCg4hw==",
+      "license": "Apache-2.0",
+      "dependencies": {
+        "@ai-sdk/provider": "1.0.7",
+        "@ai-sdk/provider-utils": "2.1.6"
+      },
+      "engines": {
+        "node": ">=18"
+      },
+      "peerDependencies": {
+        "zod": "^3.0.0"
+      }
+    },
     "node_modules/@ai-sdk/provider": {
       "version": "1.0.7",
       "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.0.7.tgz",
package.json
@@ -19,6 +19,7 @@
   "description": "",
   "dependencies": {
     "@ai-sdk/google": "^1.0.0",
+    "@ai-sdk/openai": "^1.1.9",
     "@types/cors": "^2.8.17",
     "@types/express": "^5.0.0",
     "@types/node-fetch": "^2.6.12",
src/agent.ts (11 lines changed)
@@ -1,6 +1,6 @@
-import {createGoogleGenerativeAI} from '@ai-sdk/google';
 import {z} from 'zod';
 import {generateObject} from 'ai';
+import {getModel, getMaxTokens, SEARCH_PROVIDER, STEP_SLEEP} from "./config";
 import {readUrl} from "./tools/read";
 import {handleGenerateObjectError} from './utils/error-handling';
 import fs from 'fs/promises';
@@ -10,7 +10,6 @@ import {rewriteQuery} from "./tools/query-rewriter";
 import {dedupQueries} from "./tools/dedup";
 import {evaluateAnswer} from "./tools/evaluator";
 import {analyzeSteps} from "./tools/error-analyzer";
-import {SEARCH_PROVIDER, STEP_SLEEP, modelConfigs} from "./config";
 import {TokenTracker} from "./utils/token-tracker";
 import {ActionTracker} from "./utils/action-tracker";
 import {StepAction, AnswerAction} from "./types";
@@ -325,7 +324,7 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_
       false
     );
 
-    const model = createGoogleGenerativeAI({apiKey: process.env.GEMINI_API_KEY})(modelConfigs.agent.model);
+    const model = getModel('agent');
     let object;
     let totalTokens = 0;
     try {
@@ -333,7 +332,7 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_
         model,
         schema: getSchema(allowReflect, allowRead, allowAnswer, allowSearch),
         prompt,
-        maxTokens: modelConfigs.agent.maxTokens
+        maxTokens: getMaxTokens('agent')
       });
       object = result.object;
       totalTokens = result.usage?.totalTokens || 0;
@@ -671,7 +670,7 @@ You decided to think out of the box or cut from a completely different angle.`);
       true
     );
 
-    const model = createGoogleGenerativeAI({apiKey: process.env.GEMINI_API_KEY})(modelConfigs.agentBeastMode.model);
+    const model = getModel('agentBeastMode');
     let object;
     let totalTokens = 0;
     try {
@@ -679,7 +678,7 @@ You decided to think out of the box or cut from a completely different angle.`);
         model,
         schema: getSchema(false, false, allowAnswer, false),
         prompt,
-        maxTokens: modelConfigs.agentBeastMode.maxTokens
+        maxTokens: getMaxTokens('agentBeastMode')
      });
       object = result.object;
       totalTokens = result.usage?.totalTokens || 0;
src/config.ts (121 lines changed)
@@ -1,13 +1,35 @@
 import dotenv from 'dotenv';
 import { ProxyAgent, setGlobalDispatcher } from 'undici';
+import { createGoogleGenerativeAI } from '@ai-sdk/google';
+import { createOpenAI } from '@ai-sdk/openai';
 
-interface ModelConfig {
+export type LLMProvider = 'openai' | 'gemini';
+export type ToolName = keyof ToolConfigs;
+
+function isValidProvider(provider: string): provider is LLMProvider {
+  return provider === 'openai' || provider === 'gemini';
+}
+
+function validateModelConfig(config: ModelConfig, toolName: string): ModelConfig {
+  if (typeof config.model !== 'string' || config.model.length === 0) {
+    throw new Error(`Invalid model name for ${toolName}`);
+  }
+  if (typeof config.temperature !== 'number' || config.temperature < 0 || config.temperature > 1) {
+    throw new Error(`Invalid temperature for ${toolName}`);
+  }
+  if (typeof config.maxTokens !== 'number' || config.maxTokens <= 0) {
+    throw new Error(`Invalid maxTokens for ${toolName}`);
+  }
+  return config;
+}
+
+export interface ModelConfig {
   model: string;
   temperature: number;
   maxTokens: number;
 }
 
-interface ToolConfigs {
+export interface ToolConfigs {
   dedup: ModelConfig;
   evaluator: ModelConfig;
   errorAnalyzer: ModelConfig;
@@ -31,44 +53,87 @@ if (process.env.https_proxy) {
 }
 
 export const GEMINI_API_KEY = process.env.GEMINI_API_KEY as string;
+export const OPENAI_API_KEY = process.env.OPENAI_API_KEY as string;
 export const JINA_API_KEY = process.env.JINA_API_KEY as string;
 export const BRAVE_API_KEY = process.env.BRAVE_API_KEY as string;
-export const SEARCH_PROVIDER: 'brave' | 'jina' | 'duck' = 'jina'
+export const SEARCH_PROVIDER: 'brave' | 'jina' | 'duck' = 'jina';
+export const LLM_PROVIDER: LLMProvider = (() => {
+  const provider = process.env.LLM_PROVIDER || 'gemini';
+  if (!isValidProvider(provider)) {
+    throw new Error(`Invalid LLM provider: ${provider}`);
+  }
+  return provider;
+})();
 
-const DEFAULT_MODEL = 'gemini-1.5-flash';
+const DEFAULT_GEMINI_MODEL = 'gemini-1.5-flash';
+const DEFAULT_OPENAI_MODEL = 'gpt-4o-mini';
 
-const defaultConfig: ModelConfig = {
-  model: DEFAULT_MODEL,
+const defaultGeminiConfig: ModelConfig = {
+  model: DEFAULT_GEMINI_MODEL,
   temperature: 0,
   maxTokens: 1000
 };
 
-export const modelConfigs: ToolConfigs = {
-  dedup: {
-    ...defaultConfig,
-    temperature: 0.1
-  },
-  evaluator: {
-    ...defaultConfig
-  },
-  errorAnalyzer: {
-    ...defaultConfig
-  },
-  queryRewriter: {
-    ...defaultConfig,
-    temperature: 0.1
-  },
-  agent: {
-    ...defaultConfig,
-    temperature: 0.7
-  },
-  agentBeastMode: {
-    ...defaultConfig,
-    temperature: 0.7
-  }
+const defaultOpenAIConfig: ModelConfig = {
+  model: DEFAULT_OPENAI_MODEL,
+  temperature: 0,
+  maxTokens: 1000
 };
 
+export const modelConfigs: Record<LLMProvider, ToolConfigs> = {
+  gemini: {
+    dedup: validateModelConfig({ ...defaultGeminiConfig, temperature: 0.1 }, 'dedup'),
+    evaluator: validateModelConfig({ ...defaultGeminiConfig }, 'evaluator'),
+    errorAnalyzer: validateModelConfig({ ...defaultGeminiConfig }, 'errorAnalyzer'),
+    queryRewriter: validateModelConfig({ ...defaultGeminiConfig, temperature: 0.1 }, 'queryRewriter'),
+    agent: validateModelConfig({ ...defaultGeminiConfig, temperature: 0.7 }, 'agent'),
+    agentBeastMode: validateModelConfig({ ...defaultGeminiConfig, temperature: 0.7 }, 'agentBeastMode')
+  },
+  openai: {
+    dedup: validateModelConfig({ ...defaultOpenAIConfig, temperature: 0.1 }, 'dedup'),
+    evaluator: validateModelConfig({ ...defaultOpenAIConfig }, 'evaluator'),
+    errorAnalyzer: validateModelConfig({ ...defaultOpenAIConfig }, 'errorAnalyzer'),
+    queryRewriter: validateModelConfig({ ...defaultOpenAIConfig, temperature: 0.1 }, 'queryRewriter'),
+    agent: validateModelConfig({ ...defaultOpenAIConfig, temperature: 0.7 }, 'agent'),
+    agentBeastMode: validateModelConfig({ ...defaultOpenAIConfig, temperature: 0.7 }, 'agentBeastMode')
+  }
+};
+
+export function getToolConfig(toolName: ToolName): ModelConfig {
+  if (!modelConfigs[LLM_PROVIDER][toolName]) {
+    throw new Error(`Invalid tool name: ${toolName}`);
+  }
+  return modelConfigs[LLM_PROVIDER][toolName];
+}
+
+export function getMaxTokens(toolName: ToolName): number {
+  return getToolConfig(toolName).maxTokens;
+}
+
+export function getModel(toolName: ToolName) {
+  const config = getToolConfig(toolName);
+
+  if (LLM_PROVIDER === 'openai') {
+    if (!OPENAI_API_KEY) {
+      throw new Error('OPENAI_API_KEY not found');
+    }
+    return createOpenAI({
+      apiKey: OPENAI_API_KEY,
+      compatibility: 'strict'
+    })(config.model);
+  }
+
+  if (!GEMINI_API_KEY) {
+    throw new Error('GEMINI_API_KEY not found');
+  }
+  return createGoogleGenerativeAI({ apiKey: GEMINI_API_KEY })(config.model);
+}
 
 export const STEP_SLEEP = 1000;
 
-if (!GEMINI_API_KEY) throw new Error("GEMINI_API_KEY not found");
+if (LLM_PROVIDER === 'gemini' && !GEMINI_API_KEY) throw new Error("GEMINI_API_KEY not found");
+if (LLM_PROVIDER === 'openai' && !OPENAI_API_KEY) throw new Error("OPENAI_API_KEY not found");
 if (!JINA_API_KEY) throw new Error("JINA_API_KEY not found");
+
+console.log('LLM Provider:', LLM_PROVIDER)
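To see how call sites consume the new config API, here is a minimal, hypothetical tool written against it; the schema, prompt, and `summarize` function are invented for illustration, while `getModel`, `getMaxTokens`, and the tool name `'evaluator'` come from the diff above. The real tools below follow the same shape.

```typescript
import { z } from 'zod';
import { generateObject } from 'ai';
import { getModel, getMaxTokens } from './config';

// Hypothetical schema for illustration; generateObject handles the
// provider-specific structured-output plumbing for Gemini and OpenAI alike.
const responseSchema = z.object({
  summary: z.string().describe('One-sentence summary of the input text'),
});

export async function summarize(text: string) {
  const result = await generateObject({
    model: getModel('evaluator'),          // provider picked by LLM_PROVIDER
    schema: responseSchema,
    prompt: `Summarize in one sentence: ${text}`,
    maxTokens: getMaxTokens('evaluator'),  // per-tool budget from config
  });
  return {
    summary: result.object.summary,
    tokens: result.usage?.totalTokens || 0,
  };
}
```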
src/tools/__tests__/dedup.test.ts
@@ -1,16 +1,37 @@
 import { dedupQueries } from '../dedup';
+import { LLMProvider } from '../../config';
 
 describe('dedupQueries', () => {
-  it('should remove duplicate queries', async () => {
-    jest.setTimeout(10000); // Increase timeout to 10s
-    const queries = ['typescript tutorial', 'typescript tutorial', 'javascript basics'];
-    const { unique_queries } = await dedupQueries(queries, []);
-    expect(unique_queries).toHaveLength(2);
-    expect(unique_queries).toContain('javascript basics');
-  });
+  const providers: Array<LLMProvider> = ['openai', 'gemini'];
+  const originalEnv = process.env;
 
-  it('should handle empty input', async () => {
-    const { unique_queries } = await dedupQueries([], []);
-    expect(unique_queries).toHaveLength(0);
+  beforeEach(() => {
+    jest.resetModules();
+    process.env = { ...originalEnv };
+  });
+
+  afterEach(() => {
+    process.env = originalEnv;
+  });
+
+  providers.forEach(provider => {
+    describe(`with ${provider} provider`, () => {
+      beforeEach(() => {
+        process.env.LLM_PROVIDER = provider;
+      });
+
+      it('should remove duplicate queries', async () => {
+        jest.setTimeout(10000);
+        const queries = ['typescript tutorial', 'typescript tutorial', 'javascript basics'];
+        const { unique_queries } = await dedupQueries(queries, []);
+        expect(unique_queries).toHaveLength(2);
+        expect(unique_queries).toContain('javascript basics');
+      });
+
+      it('should handle empty input', async () => {
+        const { unique_queries } = await dedupQueries([], []);
+        expect(unique_queries).toHaveLength(0);
+      });
+    });
   });
 });
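A note on this test pattern, which the remaining test files repeat: `LLM_PROVIDER` is resolved in an IIFE when src/config.ts is first imported, so mutating `process.env` alone would not switch providers mid-suite. The `jest.resetModules()` in `beforeEach` clears the module cache so a subsequent import re-reads the environment. A sketch of the mechanism, assuming a dynamic import; the diff itself keeps static top-of-file imports, so whether each test actually re-imports is not visible here.

```typescript
// LLM_PROVIDER is frozen when src/config.ts first loads, so the reset
// in beforeEach is what lets each provider block see its own env value.
beforeEach(() => {
  jest.resetModules();                  // drop the cached config module
  process.env.LLM_PROVIDER = 'openai';  // only visible to a fresh import
});

it('re-reads the provider after a module reset', async () => {
  // Hypothetical dynamic import for illustration; a static top-of-file
  // import would keep whichever provider was active at first load.
  const { LLM_PROVIDER } = await import('../../config');
  expect(LLM_PROVIDER).toBe('openai');
});
```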
src/tools/__tests__/error-analyzer.test.ts
@@ -1,10 +1,31 @@
 import { analyzeSteps } from '../error-analyzer';
+import { LLMProvider } from '../../config';
 
 describe('analyzeSteps', () => {
-  it('should analyze error steps', async () => {
-    const { response } = await analyzeSteps(['Step 1: Search failed', 'Step 2: Invalid query']);
-    expect(response).toHaveProperty('recap');
-    expect(response).toHaveProperty('blame');
-    expect(response).toHaveProperty('improvement');
+  const providers: Array<LLMProvider> = ['openai', 'gemini'];
+  const originalEnv = process.env;
+
+  beforeEach(() => {
+    jest.resetModules();
+    process.env = { ...originalEnv };
+  });
+
+  afterEach(() => {
+    process.env = originalEnv;
+  });
+
+  providers.forEach(provider => {
+    describe(`with ${provider} provider`, () => {
+      beforeEach(() => {
+        process.env.LLM_PROVIDER = provider;
+      });
+
+      it('should analyze error steps', async () => {
+        const { response } = await analyzeSteps(['Step 1: Search failed', 'Step 2: Invalid query']);
+        expect(response).toHaveProperty('recap');
+        expect(response).toHaveProperty('blame');
+        expect(response).toHaveProperty('improvement');
+      });
+    });
   });
 });
src/tools/__tests__/evaluator.test.ts
@@ -1,27 +1,48 @@
 import { evaluateAnswer } from '../evaluator';
 import { TokenTracker } from '../../utils/token-tracker';
+import { LLMProvider } from '../../config';
 
 describe('evaluateAnswer', () => {
-  it('should evaluate answer definitiveness', async () => {
-    const tokenTracker = new TokenTracker();
-    const { response } = await evaluateAnswer(
-      'What is TypeScript?',
-      'TypeScript is a strongly typed programming language that builds on JavaScript.',
-      tokenTracker
-    );
-    expect(response).toHaveProperty('is_definitive');
-    expect(response).toHaveProperty('reasoning');
-  });
+  const providers: Array<LLMProvider> = ['openai', 'gemini'];
+  const originalEnv = process.env;
 
-  it('should track token usage', async () => {
-    const tokenTracker = new TokenTracker();
-    const spy = jest.spyOn(tokenTracker, 'trackUsage');
-    const { tokens } = await evaluateAnswer(
-      'What is TypeScript?',
-      'TypeScript is a strongly typed programming language that builds on JavaScript.',
-      tokenTracker
-    );
-    expect(spy).toHaveBeenCalledWith('evaluator', tokens);
-    expect(tokens).toBeGreaterThan(0);
+  beforeEach(() => {
+    jest.resetModules();
+    process.env = { ...originalEnv };
+  });
+
+  afterEach(() => {
+    process.env = originalEnv;
+  });
+
+  providers.forEach(provider => {
+    describe(`with ${provider} provider`, () => {
+      beforeEach(() => {
+        process.env.LLM_PROVIDER = provider;
+      });
+
+      it('should evaluate answer definitiveness', async () => {
+        const tokenTracker = new TokenTracker();
+        const { response } = await evaluateAnswer(
+          'What is TypeScript?',
+          'TypeScript is a strongly typed programming language that builds on JavaScript.',
+          tokenTracker
+        );
+        expect(response).toHaveProperty('is_definitive');
+        expect(response).toHaveProperty('reasoning');
+      });
+
+      it('should track token usage', async () => {
+        const tokenTracker = new TokenTracker();
+        const spy = jest.spyOn(tokenTracker, 'trackUsage');
+        const { tokens } = await evaluateAnswer(
+          'What is TypeScript?',
+          'TypeScript is a strongly typed programming language that builds on JavaScript.',
+          tokenTracker
+        );
+        expect(spy).toHaveBeenCalledWith('evaluator', tokens);
+        expect(tokens).toBeGreaterThan(0);
+      });
+    });
   });
 });
src/tools/__tests__/query-rewriter.test.ts
@@ -1,13 +1,34 @@
 import { rewriteQuery } from '../query-rewriter';
+import { LLMProvider } from '../../config';
 
 describe('rewriteQuery', () => {
-  it('should rewrite search query', async () => {
-    const { queries } = await rewriteQuery({
-      action: 'search',
-      searchQuery: 'how does typescript work',
-      think: 'Understanding TypeScript basics'
-    });
-    expect(Array.isArray(queries)).toBe(true);
-    expect(queries.length).toBeGreaterThan(0);
-  });
+  const providers: Array<LLMProvider> = ['openai', 'gemini'];
+  const originalEnv = process.env;
+
+  beforeEach(() => {
+    jest.resetModules();
+    process.env = { ...originalEnv };
+  });
+
+  afterEach(() => {
+    process.env = originalEnv;
+  });
+
+  providers.forEach(provider => {
+    describe(`with ${provider} provider`, () => {
+      beforeEach(() => {
+        process.env.LLM_PROVIDER = provider;
+      });
+
+      it('should rewrite search query', async () => {
+        const { queries } = await rewriteQuery({
+          action: 'search',
+          searchQuery: 'how does typescript work',
+          think: 'Understanding TypeScript basics'
+        });
+        expect(Array.isArray(queries)).toBe(true);
+        expect(queries.length).toBeGreaterThan(0);
+      });
+    });
   });
 });
src/tools/dedup.ts
@@ -1,11 +1,11 @@
-import { createGoogleGenerativeAI } from '@ai-sdk/google';
 import { z } from 'zod';
 import { generateObject } from 'ai';
-import { modelConfigs } from "../config";
+import { getModel, getMaxTokens } from "../config";
 import { TokenTracker } from "../utils/token-tracker";
 import { handleGenerateObjectError } from '../utils/error-handling';
 import type { DedupResponse } from '../types';
 
+const model = getModel('dedup');
+
 const responseSchema = z.object({
   think: z.string().describe('Strategic reasoning about the overall deduplication approach'),
@@ -13,8 +13,6 @@ const responseSchema = z.object({
     .describe('Array of semantically unique queries').max(3)
 });
 
-const model = createGoogleGenerativeAI({ apiKey: process.env.GEMINI_API_KEY })(modelConfigs.dedup.model);
-
 function getPrompt(newQueries: string[], existingQueries: string[]): string {
   return `You are an expert in semantic similarity analysis. Given a set of queries (setA) and a set of queries (setB)
 
@@ -77,7 +75,7 @@ export async function dedupQueries(newQueries: string[], existingQueries: string
         model,
         schema: responseSchema,
         prompt,
-        maxTokens: modelConfigs.dedup.maxTokens
+        maxTokens: getMaxTokens('dedup')
       });
       object = result.object;
       tokens = result.usage?.totalTokens || 0;
src/tools/error-analyzer.ts
@@ -1,18 +1,19 @@
-import { createGoogleGenerativeAI } from '@ai-sdk/google';
 import { z } from 'zod';
 import { generateObject } from 'ai';
-import { modelConfigs } from "../config";
+import { getModel, getMaxTokens } from "../config";
 import { TokenTracker } from "../utils/token-tracker";
 import { ErrorAnalysisResponse } from '../types';
 import { handleGenerateObjectError } from '../utils/error-handling';
 
+const model = getModel('errorAnalyzer');
+
 const responseSchema = z.object({
   recap: z.string().describe('Recap of the actions taken and the steps conducted'),
   blame: z.string().describe('Which action or the step was the root cause of the answer rejection'),
   improvement: z.string().describe('Suggested key improvement for the next iteration, do not use bullet points, be concise and hot-take vibe.')
 });
 
-const model = createGoogleGenerativeAI({ apiKey: process.env.GEMINI_API_KEY })(modelConfigs.errorAnalyzer.model);
-
 function getPrompt(diaryContext: string[]): string {
   return `You are an expert at analyzing search and reasoning processes. Your task is to analyze the given sequence of steps and identify what went wrong in the search process.
 
@@ -112,7 +113,7 @@ export async function analyzeSteps(diaryContext: string[], tracker?: TokenTracke
         model,
         schema: responseSchema,
         prompt,
-        maxTokens: modelConfigs.errorAnalyzer.maxTokens
+        maxTokens: getMaxTokens('errorAnalyzer')
       });
       object = result.object;
       tokens = result.usage?.totalTokens || 0;
src/tools/evaluator.ts
@@ -1,17 +1,18 @@
-import { createGoogleGenerativeAI } from '@ai-sdk/google';
 import { z } from 'zod';
 import { generateObject } from 'ai';
-import { modelConfigs } from "../config";
+import { getModel, getMaxTokens } from "../config";
 import { TokenTracker } from "../utils/token-tracker";
 import { EvaluationResponse } from '../types';
 import { handleGenerateObjectError } from '../utils/error-handling';
 
+const model = getModel('evaluator');
+
 const responseSchema = z.object({
   is_definitive: z.boolean().describe('Whether the answer provides a definitive response without uncertainty or \'I don\'t know\' type statements'),
   reasoning: z.string().describe('Explanation of why the answer is or isn\'t definitive')
 });
 
-const model = createGoogleGenerativeAI({ apiKey: process.env.GEMINI_API_KEY })(modelConfigs.evaluator.model);
-
 function getPrompt(question: string, answer: string): string {
   return `You are an evaluator of answer definitiveness. Analyze if the given answer provides a definitive response or not.
 
@@ -57,7 +58,7 @@ export async function evaluateAnswer(question: string, answer: string, tracker?:
         model,
         schema: responseSchema,
         prompt,
-        maxTokens: modelConfigs.evaluator.maxTokens
+        maxTokens: getMaxTokens('evaluator')
      });
       object = result.object;
       totalTokens = result.usage?.totalTokens || 0;
src/tools/query-rewriter.ts
@@ -1,11 +1,12 @@
-import { createGoogleGenerativeAI } from '@ai-sdk/google';
 import { z } from 'zod';
-import { modelConfigs } from "../config";
+import { generateObject } from 'ai';
+import { getModel, getMaxTokens } from "../config";
 import { TokenTracker } from "../utils/token-tracker";
 import { SearchAction, KeywordsResponse } from '../types';
-import { generateObject } from 'ai';
 import { handleGenerateObjectError } from '../utils/error-handling';
 
+const model = getModel('queryRewriter');
+
 const responseSchema = z.object({
   think: z.string().describe('Strategic reasoning about query complexity and search approach'),
   queries: z.array(z.string().describe('Search query, must be less than 30 characters'))
@@ -14,7 +15,7 @@ const responseSchema = z.object({
     .describe('Array of search queries, orthogonal to each other')
 });
 
-const model = createGoogleGenerativeAI({ apiKey: process.env.GEMINI_API_KEY })(modelConfigs.queryRewriter.model);
-
 function getPrompt(action: SearchAction): string {
   return `You are an expert Information Retrieval Assistant. Transform user queries into precise keyword combinations with strategic reasoning and appropriate search operators.
 
@@ -102,7 +103,7 @@ export async function rewriteQuery(action: SearchAction, tracker?: TokenTracker)
         model,
         schema: responseSchema,
         prompt,
-        maxTokens: modelConfigs.queryRewriter.maxTokens
+        maxTokens: getMaxTokens('queryRewriter')
       });
       object = result.object;
       tokens = result.usage?.totalTokens || 0;