Mirror of https://github.com/jina-ai/node-DeepResearch.git, synced 2025-12-25 22:16:49 +08:00
refactor: add multi-provider support with Vercel AI SDK
- Add support for Gemini, OpenAI, and Ollama providers
- Set default models (gemini-flash-1.5 for Gemini, gpt4o-mini for OpenAI)
- Implement provider factory pattern
- Update schema handling for each provider
- Add environment variable configuration
- Maintain token tracking across providers

Co-Authored-By: Han Xiao <han.xiao@jina.ai>
This commit is contained in:
parent 22c2244225
commit 4c0093deb0
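Review note: taken together, the changes below route every model call through a single factory. A minimal sketch of the new flow (import paths assumed relative to the repo root):

import { ProviderFactory } from './src/utils/provider-factory';
import { aiConfig } from './src/config';

// Switching providers is now a config edit: the factory returns either a Gemini
// model or an OpenAI client, and call sites branch on aiConfig.defaultProvider.
const provider = ProviderFactory.createProvider(aiConfig.providers[aiConfig.defaultProvider]);
console.log('default provider:', aiConfig.defaultProvider, provider);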
.env.example (new file, 14 lines)
@@ -0,0 +1,14 @@
+# Google Gemini API Key (required)
+GEMINI_API_KEY=your_gemini_key_here
+
+# OpenAI API Key (required for OpenAI provider)
+OPENAI_API_KEY=your_openai_key_here
+
+# Ollama API Key (required for Ollama provider)
+OLLAMA_API_KEY=your_ollama_key_here
+
+# Jina API Key (required for search)
+JINA_API_KEY=your_jina_key_here
+
+# Brave API Key (optional for Brave search)
+BRAVE_API_KEY=your_brave_key_here
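Review note: a quick sanity check for these keys, assuming .env is loaded into process.env before src/env.ts is imported (e.g. via dotenv; the loading mechanism is not part of this diff):

import { GEMINI_API_KEY, OPENAI_API_KEY, JINA_API_KEY } from './src/env';

// src/env.ts falls back to '' instead of throwing; src/config.ts is the module
// that hard-fails at startup when GEMINI_API_KEY or JINA_API_KEY is missing.
for (const [name, value] of Object.entries({ GEMINI_API_KEY, OPENAI_API_KEY, JINA_API_KEY })) {
  if (!value) console.warn(`${name} is not set`);
}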
package-lock.json (generated, 36 lines changed)
@@ -9,10 +9,11 @@
     "version": "1.0.0",
     "license": "ISC",
     "dependencies": {
+      "@google/generative-ai": "^0.21.0",
       "@types/cors": "^2.8.17",
       "@types/express": "^5.0.0",
       "@types/node-fetch": "^2.6.12",
-      "ai": "^4.1.19",
+      "ai": "^4.1.20",
       "axios": "^1.7.9",
       "cors": "^2.8.5",
       "duck-duck-scrape": "^2.2.7",
@@ -71,13 +72,13 @@
       }
     },
     "node_modules/@ai-sdk/react": {
-      "version": "1.1.9",
-      "resolved": "https://registry.npmjs.org/@ai-sdk/react/-/react-1.1.9.tgz",
-      "integrity": "sha512-2si293+NYs3WbPfHXSZ4/71NtYV0zxYhhHSL4H1EPyHU9Gf/H81rhjsslvt45mguPecPkMG19/VIXDjJ4uTwsw==",
+      "version": "1.1.10",
+      "resolved": "https://registry.npmjs.org/@ai-sdk/react/-/react-1.1.10.tgz",
+      "integrity": "sha512-RTkEVYKq7qO6Ct3XdVTgbaCTyjX+q1HLqb+t2YvZigimzMCQbHkpZCtt2H2Fgpt1UOTqnAAlXjEAgTW3X60Y9g==",
       "license": "Apache-2.0",
       "dependencies": {
         "@ai-sdk/provider-utils": "2.1.6",
-        "@ai-sdk/ui-utils": "1.1.9",
+        "@ai-sdk/ui-utils": "1.1.10",
         "swr": "^2.2.5",
         "throttleit": "2.1.0"
       },
@@ -98,9 +99,9 @@
       }
     },
     "node_modules/@ai-sdk/ui-utils": {
-      "version": "1.1.9",
-      "resolved": "https://registry.npmjs.org/@ai-sdk/ui-utils/-/ui-utils-1.1.9.tgz",
-      "integrity": "sha512-o0tDopdtHqgr9FAx0qSkdwPUDSdX+4l42YOn70zvs6+O+PILeTpf2YYV5Xr32TbNfSUq1DWLLhU1O7/3Dsxm1Q==",
+      "version": "1.1.10",
+      "resolved": "https://registry.npmjs.org/@ai-sdk/ui-utils/-/ui-utils-1.1.10.tgz",
+      "integrity": "sha512-x+A1Nfy8RTSatdCe+7nRpHAZVzPFB6H+r+2JKoapSvrwsu9mw2pAbmFgV8Zaj94TsmUdTlO0/j97e63f+yYuWg==",
       "license": "Apache-2.0",
       "dependencies": {
         "@ai-sdk/provider": "1.0.7",
@@ -771,6 +772,15 @@
         "node": "^12.22.0 || ^14.17.0 || >=16.0.0"
       }
     },
+    "node_modules/@google/generative-ai": {
+      "version": "0.21.0",
+      "resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.21.0.tgz",
+      "integrity": "sha512-7XhUbtnlkSEZK15kN3t+tzIMxsbKm/dSkKBFalj+20NvPKe1kBY7mR2P7vuijEn+f06z5+A8bVGKO0v39cr6Wg==",
+      "license": "Apache-2.0",
+      "engines": {
+        "node": ">=18.0.0"
+      }
+    },
     "node_modules/@humanwhocodes/config-array": {
       "version": "0.13.0",
       "resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.13.0.tgz",
@@ -1954,15 +1964,15 @@
       }
     },
     "node_modules/ai": {
-      "version": "4.1.19",
-      "resolved": "https://registry.npmjs.org/ai/-/ai-4.1.19.tgz",
-      "integrity": "sha512-Xx498vbFVN4Y3F4kWF59ojLyn/d++NbSZwENq1zuSFW4OjwzTf79jtMxD+BYeMiDH+mgIrmROY/ONtqMOchZGw==",
+      "version": "4.1.20",
+      "resolved": "https://registry.npmjs.org/ai/-/ai-4.1.20.tgz",
+      "integrity": "sha512-wgx2AMdgTKxHb0EWZR9ovHtxzSVWhoHQqkOVW3KGsUdfSaPmS/GXot358413V6fQncuMWfiFcKEyTXRB0AuZvA==",
       "license": "Apache-2.0",
       "dependencies": {
         "@ai-sdk/provider": "1.0.7",
         "@ai-sdk/provider-utils": "2.1.6",
-        "@ai-sdk/react": "1.1.9",
-        "@ai-sdk/ui-utils": "1.1.9",
+        "@ai-sdk/react": "1.1.10",
+        "@ai-sdk/ui-utils": "1.1.10",
         "@opentelemetry/api": "1.9.0",
         "jsondiffpatch": "0.6.0"
       },
package.json
@@ -21,13 +21,14 @@
     "@types/cors": "^2.8.17",
     "@types/express": "^5.0.0",
     "@types/node-fetch": "^2.6.12",
-    "ai": "^4.1.19",
+    "@google/generative-ai": "^0.21.0",
+    "ai": "^4.1.20",
+    "openai": "^4.82.0",
     "axios": "^1.7.9",
     "cors": "^2.8.5",
     "duck-duck-scrape": "^2.2.7",
     "express": "^4.21.2",
     "node-fetch": "^3.3.2",
-    "openai": "^4.82.0",
     "undici": "^7.3.0",
     "zod": "^3.24.1"
   },
src/agent.ts (99 lines changed)
@@ -1,4 +1,4 @@
-import OpenAI from 'openai';
+import { ProviderFactory, AIProvider, isGeminiProvider, isOpenAIProvider } from './utils/provider-factory';
 import {readUrl} from "./tools/read";
 import fs from 'fs/promises';
 import {SafeSearchType, search as duckSearch} from "duck-duck-scrape";
@@ -7,15 +7,60 @@ import {rewriteQuery} from "./tools/query-rewriter";
 import {dedupQueries} from "./tools/dedup";
 import {evaluateAnswer} from "./tools/evaluator";
 import {analyzeSteps} from "./tools/error-analyzer";
-import {OPENAI_API_KEY, SEARCH_PROVIDER, STEP_SLEEP, modelConfigs} from "./config";
+import {aiConfig, SEARCH_PROVIDER, STEP_SLEEP, modelConfigs} from "./config";
 import { z } from 'zod';
 import {TokenTracker} from "./utils/token-tracker";
 import {ActionTracker} from "./utils/action-tracker";
-import {StepAction, SchemaProperty, ResponseSchema, AnswerAction} from "./types";
+import { StepAction, AnswerAction, ProviderType } from "./types";
 import {TrackerContext} from "./types";
 import {jinaSearch} from "./tools/jinaSearch";
+import { getProviderSchema } from './utils/schema';
 
-const openai = new OpenAI({ apiKey: OPENAI_API_KEY });
+async function generateResponse(provider: AIProvider, prompt: string, providerType: ProviderType, schema: any, modelConfig: any) {
+  if (!isGeminiProvider(provider) && !isOpenAIProvider(provider)) {
+    throw new Error('Invalid provider type');
+  }
+  switch (providerType) {
+    case 'gemini': {
+      if (!isGeminiProvider(provider)) throw new Error('Invalid provider type');
+      const result = await provider.generateContent({
+        contents: [{ role: 'user', parts: [{ text: prompt }]}],
+        generationConfig: {
+          temperature: modelConfig.temperature,
+          maxOutputTokens: 1000
+        }
+      });
+      const response = await result.response;
+      return {
+        text: response.text(),
+        tokens: response.usageMetadata?.totalTokenCount || 0
+      };
+    }
+    case 'openai': {
+      if (!isOpenAIProvider(provider)) throw new Error('Invalid provider type');
+      const result = await provider.chat.completions.create({
+        messages: [{ role: 'user', content: prompt }],
+        model: modelConfig.model,
+        temperature: modelConfig.temperature,
+        max_tokens: 1000,
+        functions: [{
+          name: 'generate',
+          parameters: getProviderSchema('openai', schema)
+        }],
+        function_call: { name: 'generate' }
+      });
+      const functionCall = result.choices[0].message.function_call;
+      return {
+        text: functionCall?.arguments || '',
+        tokens: result.usage?.total_tokens || 0
+      };
+    }
+    case 'ollama':
+      throw new Error('Ollama support coming soon');
+    default:
+      throw new Error(`Unsupported provider type: ${providerType}`);
+  }
+}
 
 async function sleep(ms: number) {
   const seconds = Math.ceil(ms / 1000);
@@ -26,7 +71,7 @@ async function sleep(ms: number) {
 function getSchema(allowReflect: boolean, allowRead: boolean, allowAnswer: boolean, allowSearch: boolean) {
   const actions: string[] = [];
   let schema = z.object({
-    action: z.enum([]).describe("Must match exactly one action type"),
+    action: z.enum(['dummy'] as [string, ...string[]]).describe("Must match exactly one action type"),
     think: z.string().describe("Explain why choose this action, what's the thought process behind choosing this action")
   });
 
@@ -327,23 +372,15 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_
       false
     );
 
-    const result = await openai.chat.completions.create({
-      messages: [{ role: 'user', content: prompt }],
-      model: modelConfigs.agent.model,
-      temperature: modelConfigs.agent.temperature,
-      max_tokens: 1000,
-      functions: [{
-        name: 'generate',
-        parameters: getSchema(allowReflect, allowRead, allowAnswer, allowSearch)
-      }],
-      function_call: { name: 'generate' }
-    });
-
-    const functionCall = result.choices[0].message.function_call;
-    const responseData = functionCall ? JSON.parse(functionCall.arguments) as StepAction : null;
+    const provider = ProviderFactory.createProvider();
+    const providerType = aiConfig.defaultProvider;
+    const schema = getSchema(allowReflect, allowRead, allowAnswer, allowSearch);
+
+    const { text, tokens } = await generateResponse(provider, prompt, providerType, schema, modelConfigs.agent);
+    const responseData = JSON.parse(text) as StepAction;
     if (!responseData) throw new Error('No valid response generated');
 
-    context.tokenTracker.trackUsage('agent', result.usage.total_tokens);
+    context.tokenTracker.trackUsage('agent', tokens, providerType);
     thisStep = responseData;
     // print allowed and chose action
     const actionsStr = [allowSearch, allowRead, allowAnswer, allowReflect].map((a, i) => a ? ['search', 'read', 'answer', 'reflect'][i] : null).filter(a => a).join(', ');
@@ -672,23 +709,15 @@ You decided to think out of the box or cut from a completely different angle.`);
     true
   );
 
-  const result = await openai.chat.completions.create({
-    messages: [{ role: 'user', content: prompt }],
-    model: modelConfigs.agentBeastMode.model,
-    temperature: modelConfigs.agentBeastMode.temperature,
-    max_tokens: 1000,
-    functions: [{
-      name: 'generate',
-      parameters: getSchema(false, false, allowAnswer, false)
-    }],
-    function_call: { name: 'generate' }
-  });
-
-  const functionCall = result.choices[0].message.function_call;
-  const responseData = functionCall ? JSON.parse(functionCall.arguments) as StepAction : null;
+  const provider = ProviderFactory.createProvider();
+  const providerType = aiConfig.defaultProvider;
+  const schema = getSchema(false, false, allowAnswer, false);
+
+  const { text, tokens } = await generateResponse(provider, prompt, providerType, schema, modelConfigs.agentBeastMode);
+  const responseData = JSON.parse(text) as StepAction;
  if (!responseData) throw new Error('No valid response generated');
 
-  context.tokenTracker.trackUsage('agent', result.usage.total_tokens);
+  context.tokenTracker.trackUsage('agent', tokens, providerType);
   await storeContext(prompt, [allContext, allKeywords, allQuestions, allKnowledge], totalStep);
   thisStep = responseData;
   console.log(thisStep)
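Review note on the getSchema tweak above: zod types z.enum over a non-empty string tuple, so the old z.enum([]) could not satisfy the signature; the 'dummy' placeholder keeps the compiler happy until the real action names are filled in. A sketch:

import { z } from 'zod';

// const broken = z.enum([]); // rejected: '[]' is not assignable to '[string, ...string[]]'
const actionEnum = z.enum(['dummy'] as [string, ...string[]]);
console.log(actionEnum.options);                      // ['dummy']
console.log(actionEnum.safeParse('search').success);  // false until real actions replace it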
src/config.ts
@@ -29,44 +29,73 @@ if (process.env.https_proxy) {
   }
 }
 
+import { AIConfig, ProviderType } from './types';
+
+export const GEMINI_API_KEY = process.env.GEMINI_API_KEY as string;
 export const OPENAI_API_KEY = process.env.OPENAI_API_KEY as string;
+export const OLLAMA_API_KEY = process.env.OLLAMA_API_KEY as string;
 export const JINA_API_KEY = process.env.JINA_API_KEY as string;
 export const BRAVE_API_KEY = process.env.BRAVE_API_KEY as string;
-export const SEARCH_PROVIDER: 'brave' | 'jina' | 'duck' = 'jina'
+export const SEARCH_PROVIDER: 'brave' | 'jina' | 'duck' = 'jina';
 
-const DEFAULT_MODEL = 'gpt-4';
+export const aiConfig: AIConfig = {
+  defaultProvider: 'gemini' as ProviderType,
+  providers: {
+    gemini: {
+      type: 'gemini',
+      model: 'gemini-flash-1.5', // Updated to correct model name
+      temperature: 0
+    },
+    openai: {
+      type: 'openai',
+      model: 'gpt4o-mini', // Updated to correct model name
+      temperature: 0
+    },
+    ollama: {
+      type: 'ollama',
+      model: 'llama2',
+      temperature: 0
+    }
+  }
+};
 
 const defaultConfig: ModelConfig = {
-  model: DEFAULT_MODEL,
-  temperature: 0
+  model: aiConfig.providers[aiConfig.defaultProvider].model,
+  temperature: aiConfig.providers[aiConfig.defaultProvider].temperature
 };
 
 export const modelConfigs: ToolConfigs = {
   dedup: {
     ...defaultConfig,
+    model: aiConfig.providers[aiConfig.defaultProvider].model,
     temperature: 0.1
   },
   evaluator: {
-    ...defaultConfig
+    ...defaultConfig,
+    model: aiConfig.providers[aiConfig.defaultProvider].model
   },
   errorAnalyzer: {
-    ...defaultConfig
+    ...defaultConfig,
+    model: aiConfig.providers[aiConfig.defaultProvider].model
  },
   queryRewriter: {
     ...defaultConfig,
+    model: aiConfig.providers[aiConfig.defaultProvider].model,
     temperature: 0.1
   },
   agent: {
     ...defaultConfig,
+    model: aiConfig.providers[aiConfig.defaultProvider].model,
     temperature: 0.7
   },
   agentBeastMode: {
     ...defaultConfig,
+    model: aiConfig.providers[aiConfig.defaultProvider].model,
     temperature: 0.7
   }
 };
 
 export const STEP_SLEEP = 1000;
 
-if (!OPENAI_API_KEY) throw new Error("OPENAI_API_KEY not found");
+if (!GEMINI_API_KEY) throw new Error("GEMINI_API_KEY not found");
 if (!JINA_API_KEY) throw new Error("JINA_API_KEY not found");
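Review note: because every modelConfigs entry reads aiConfig.providers[aiConfig.defaultProvider], all tools resolve to the same model and differ only in temperature. A sketch of the resolved values (import path assumed):

import { aiConfig, modelConfigs } from './src/config';

console.log(aiConfig.defaultProvider);    // 'gemini'
console.log(modelConfigs.dedup);          // { model: 'gemini-flash-1.5', temperature: 0.1 }
console.log(modelConfigs.agent);          // { model: 'gemini-flash-1.5', temperature: 0.7 }
// The 'gpt4o-mini' and 'llama2' entries only take effect if defaultProvider changes.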
src/env.ts (new file, 5 lines)
@@ -0,0 +1,5 @@
+export const GEMINI_API_KEY = process.env.GEMINI_API_KEY || '';
+export const OPENAI_API_KEY = process.env.OPENAI_API_KEY || '';
+export const OLLAMA_API_KEY = process.env.OLLAMA_API_KEY || '';
+export const JINA_API_KEY = process.env.JINA_API_KEY || '';
+export const BRAVE_API_KEY = process.env.BRAVE_API_KEY || '';
src/tools/dedup.ts
@@ -1,10 +1,9 @@
-import OpenAI from 'openai';
-import { OPENAI_API_KEY, modelConfigs } from "../config";
+import { ProviderFactory, AIProvider, isGeminiProvider, isOpenAIProvider } from '../utils/provider-factory';
+import { aiConfig, modelConfigs } from "../config";
 import { TokenTracker } from "../utils/token-tracker";
-import { DedupResponse } from '../types';
+import { DedupResponse, ProviderType, OpenAIFunctionParameter } from '../types';
 import { z } from 'zod';
-
-const openai = new OpenAI({ apiKey: OPENAI_API_KEY });
+import { getProviderSchema } from '../utils/schema';
 
 const responseSchema = z.object({
   think: z.string().describe("Strategic reasoning about the overall deduplication approach"),
@@ -13,6 +12,52 @@ const responseSchema = z.object({
   ).describe("Array of semantically unique queries")
 });
 
+async function generateResponse(provider: AIProvider, prompt: string, providerType: ProviderType) {
+  if (!isGeminiProvider(provider) && !isOpenAIProvider(provider)) {
+    throw new Error('Invalid provider type');
+  }
+  switch (providerType) {
+    case 'gemini': {
+      if (!isGeminiProvider(provider)) throw new Error('Invalid provider type');
+      const result = await provider.generateContent({
+        contents: [{ role: 'user', parts: [{ text: prompt }]}],
+        generationConfig: {
+          temperature: modelConfigs.dedup.temperature,
+          maxOutputTokens: 1000
+        }
+      });
+      const response = await result.response;
+      return {
+        text: response.text(),
+        tokens: response.usageMetadata?.totalTokenCount || 0
+      };
+    }
+    case 'openai': {
+      if (!isOpenAIProvider(provider)) throw new Error('Invalid provider type');
+      const result = await provider.chat.completions.create({
+        messages: [{ role: 'user', content: prompt }],
+        model: modelConfigs.dedup.model,
+        temperature: modelConfigs.dedup.temperature,
+        max_tokens: 1000,
+        functions: [{
+          name: 'generate',
+          parameters: getProviderSchema('openai', responseSchema) as OpenAIFunctionParameter
+        }],
+        function_call: { name: 'generate' }
+      });
+      const functionCall = result.choices[0].message.function_call;
+      return {
+        text: functionCall?.arguments || '',
+        tokens: result.usage?.total_tokens || 0
+      };
+    }
+    case 'ollama':
+      throw new Error('Ollama support coming soon');
+    default:
+      throw new Error(`Unsupported provider type: ${providerType}`);
+  }
+}
+
 function getPrompt(newQueries: string[], existingQueries: string[]): string {
   return `You are an expert in semantic similarity analysis. Given a set of queries (setA) and a set of queries (setB)
 
@@ -67,29 +112,22 @@ SetB: ${JSON.stringify(existingQueries)}`;
 
 export async function dedupQueries(newQueries: string[], existingQueries: string[], tracker?: TokenTracker): Promise<{ unique_queries: string[], tokens: number }> {
   try {
+    const provider = ProviderFactory.createProvider();
+    const providerType = aiConfig.defaultProvider;
     const prompt = getPrompt(newQueries, existingQueries);
-    const result = await openai.chat.completions.create({
-      messages: [{ role: 'user', content: prompt }],
-      model: modelConfigs.dedup.model,
-      temperature: modelConfigs.dedup.temperature,
-      max_tokens: 1000,
-      functions: [{
-        name: 'generate',
-        parameters: responseSchema.shape
-      }],
-      function_call: { name: 'generate' }
-    });
-
-    const functionCall = result.choices[0].message.function_call;
-    const responseData = functionCall ? JSON.parse(functionCall.arguments) as DedupResponse : null;
-
+    const { text, tokens } = await generateResponse(provider, prompt, providerType);
+    const responseData = JSON.parse(text) as DedupResponse;
     if (!responseData) throw new Error('No valid response generated');
 
     console.log('Dedup:', responseData.unique_queries);
-    const tokens = result.usage.total_tokens;
-    (tracker || new TokenTracker()).trackUsage('dedup', tokens);
+    (tracker || new TokenTracker()).trackUsage('dedup', tokens, providerType);
     return { unique_queries: responseData.unique_queries, tokens };
   } catch (error) {
     console.error('Error in deduplication analysis:', error);
+    if (error instanceof Error && error.message.includes('Ollama support')) {
+      throw new Error('Ollama provider is not yet supported for deduplication');
+    }
     throw error;
   }
 }
@@ -99,9 +137,11 @@ export async function main() {
   const existingQueries = process.argv[3] ? JSON.parse(process.argv[3]) : [];
 
   try {
-    await dedupQueries(newQueries, existingQueries);
+    const result = await dedupQueries(newQueries, existingQueries);
+    console.log(JSON.stringify(result, null, 2));
   } catch (error) {
     console.error('Failed to deduplicate queries:', error);
     process.exit(1);
   }
 }
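Review note: dedupQueries keeps its public signature; only the transport underneath changed. A usage sketch (paths assumed, top-level await assumed in an ESM context):

import { dedupQueries } from './src/tools/dedup';
import { TokenTracker } from './src/utils/token-tracker';

const tracker = new TokenTracker();
const { unique_queries, tokens } = await dedupQueries(
  ['best pizza in SF', 'top pizza places San Francisco'],
  [],
  tracker
);
console.log(unique_queries, tokens);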
src/tools/error-analyzer.ts
@@ -1,10 +1,9 @@
-import OpenAI from 'openai';
-import { OPENAI_API_KEY, modelConfigs } from "../config";
+import { ProviderFactory, AIProvider, isGeminiProvider, isOpenAIProvider } from '../utils/provider-factory';
+import { aiConfig, modelConfigs } from "../config";
 import { TokenTracker } from "../utils/token-tracker";
-import { ErrorAnalysisResponse } from '../types';
+import { ErrorAnalysisResponse, ProviderType, OpenAIFunctionParameter } from '../types';
 import { z } from 'zod';
-
-const openai = new OpenAI({ apiKey: OPENAI_API_KEY });
+import { getProviderSchema } from '../utils/schema';
 
 const responseSchema = z.object({
   recap: z.string().describe("Recap of the actions taken and the steps conducted"),
@@ -100,23 +99,60 @@ ${diaryContext.join('\n')}
 `;
 }
 
+async function generateResponse(provider: AIProvider, prompt: string, providerType: ProviderType) {
+  if (!isGeminiProvider(provider) && !isOpenAIProvider(provider)) {
+    throw new Error('Invalid provider type');
+  }
+  switch (providerType) {
+    case 'gemini': {
+      if (!isGeminiProvider(provider)) throw new Error('Invalid provider type');
+      const result = await provider.generateContent({
+        contents: [{ role: 'user', parts: [{ text: prompt }]}],
+        generationConfig: {
+          temperature: modelConfigs.errorAnalyzer.temperature,
+          maxOutputTokens: 1000
+        }
+      });
+      const response = await result.response;
+      return {
+        text: response.text(),
+        tokens: response.usageMetadata?.totalTokenCount || 0
+      };
+    }
+    case 'openai': {
+      if (!isOpenAIProvider(provider)) throw new Error('Invalid provider type');
+      const result = await provider.chat.completions.create({
+        messages: [{ role: 'user', content: prompt }],
+        model: modelConfigs.errorAnalyzer.model,
+        temperature: modelConfigs.errorAnalyzer.temperature,
+        max_tokens: 1000,
+        functions: [{
+          name: 'generate',
+          parameters: getProviderSchema('openai', responseSchema) as OpenAIFunctionParameter
+        }],
+        function_call: { name: 'generate' }
+      });
+      const functionCall = result.choices[0].message.function_call;
+      return {
+        text: functionCall?.arguments || '',
+        tokens: result.usage?.total_tokens || 0
+      };
+    }
+    case 'ollama':
+      throw new Error('Ollama support coming soon');
+    default:
+      throw new Error(`Unsupported provider type: ${providerType}`);
+  }
+}
+
 export async function analyzeSteps(diaryContext: string[], tracker?: TokenTracker): Promise<{ response: ErrorAnalysisResponse, tokens: number }> {
   try {
+    const provider = ProviderFactory.createProvider();
+    const providerType = aiConfig.defaultProvider;
     const prompt = getPrompt(diaryContext);
-    const result = await openai.chat.completions.create({
-      messages: [{ role: 'user', content: prompt }],
-      model: modelConfigs.errorAnalyzer.model,
-      temperature: modelConfigs.errorAnalyzer.temperature,
-      max_tokens: 1000,
-      functions: [{
-        name: 'generate',
-        parameters: responseSchema.shape
-      }],
-      function_call: { name: 'generate' }
-    });
-
-    const functionCall = result.choices[0].message.function_call;
-    const responseData = functionCall ? JSON.parse(functionCall.arguments) as ErrorAnalysisResponse : null;
-
+    const { text, tokens } = await generateResponse(provider, prompt, providerType);
+    const responseData = JSON.parse(text) as ErrorAnalysisResponse;
     if (!responseData) throw new Error('No valid response generated');
 
     console.log('Error analysis:', {
@@ -124,8 +160,7 @@ export async function analyzeSteps(diaryContext: string[], tracker?: TokenTracke
       reason: responseData.blame || 'No issues found'
     });
 
-    const tokens = result.usage.total_tokens;
-    (tracker || new TokenTracker()).trackUsage('error-analyzer', tokens);
+    (tracker || new TokenTracker()).trackUsage('error-analyzer', tokens, providerType);
     return { response: responseData, tokens };
   } catch (error) {
     console.error('Error in answer evaluation:', error);
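Review note: analyzeSteps follows the same pattern; the diary strings below are illustrative stand-ins for the agent's real step log:

import { analyzeSteps } from './src/tools/error-analyzer';

const { response, tokens } = await analyzeSteps([
  'Step 1: searched "jina ai reader"',
  'Step 2: visited the top result; answer judged non-definitive',
]); // top-level await assumed (ESM)
console.log(response.recap, tokens);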
src/tools/evaluator.ts
@@ -1,16 +1,61 @@
-import OpenAI from 'openai';
-import { OPENAI_API_KEY, modelConfigs } from "../config";
+import { ProviderFactory, AIProvider, isGeminiProvider, isOpenAIProvider } from '../utils/provider-factory';
+import { aiConfig, modelConfigs } from "../config";
 import { TokenTracker } from "../utils/token-tracker";
-import { EvaluationResponse } from '../types';
+import { EvaluationResponse, ProviderType, OpenAIFunctionParameter } from '../types';
 import { z } from 'zod';
-
-const openai = new OpenAI({ apiKey: OPENAI_API_KEY });
+import { getProviderSchema } from '../utils/schema';
 
 const responseSchema = z.object({
   is_definitive: z.boolean().describe("Whether the answer provides a definitive response without uncertainty or 'I don't know' type statements"),
   reasoning: z.string().describe("Explanation of why the answer is or isn't definitive")
 });
 
+async function generateResponse(provider: AIProvider, prompt: string, providerType: ProviderType) {
+  if (!isGeminiProvider(provider) && !isOpenAIProvider(provider)) {
+    throw new Error('Invalid provider type');
+  }
+  switch (providerType) {
+    case 'gemini': {
+      if (!isGeminiProvider(provider)) throw new Error('Invalid provider type');
+      const result = await provider.generateContent({
+        contents: [{ role: 'user', parts: [{ text: prompt }]}],
+        generationConfig: {
+          temperature: modelConfigs.evaluator.temperature,
+          maxOutputTokens: 1000
+        }
+      });
+      const response = await result.response;
+      return {
+        text: response.text(),
+        tokens: response.usageMetadata?.totalTokenCount || 0
+      };
+    }
+    case 'openai': {
+      if (!isOpenAIProvider(provider)) throw new Error('Invalid provider type');
+      const result = await provider.chat.completions.create({
+        messages: [{ role: 'user', content: prompt }],
+        model: modelConfigs.evaluator.model,
+        temperature: modelConfigs.evaluator.temperature,
+        max_tokens: 1000,
+        functions: [{
+          name: 'generate',
+          parameters: getProviderSchema('openai', responseSchema) as OpenAIFunctionParameter
+        }],
+        function_call: { name: 'generate' }
+      });
+      const functionCall = result.choices[0].message.function_call;
+      return {
+        text: functionCall?.arguments || '',
+        tokens: result.usage?.total_tokens || 0
+      };
+    }
+    case 'ollama':
+      throw new Error('Ollama support coming soon');
+    default:
+      throw new Error(`Unsupported provider type: ${providerType}`);
+  }
+}
+
 function getPrompt(question: string, answer: string): string {
   return `You are an evaluator of answer definitiveness. Analyze if the given answer provides a definitive response or not.
 
@@ -47,21 +92,12 @@ Answer: ${JSON.stringify(answer)}`;
 
 export async function evaluateAnswer(question: string, answer: string, tracker?: TokenTracker): Promise<{ response: EvaluationResponse, tokens: number }> {
   try {
+    const provider = ProviderFactory.createProvider();
+    const providerType = aiConfig.defaultProvider;
     const prompt = getPrompt(question, answer);
-    const result = await openai.chat.completions.create({
-      messages: [{ role: 'user', content: prompt }],
-      model: modelConfigs.evaluator.model,
-      temperature: modelConfigs.evaluator.temperature,
-      max_tokens: 1000,
-      functions: [{
-        name: 'generate',
-        parameters: responseSchema.shape
-      }],
-      function_call: { name: 'generate' }
-    });
-
-    const functionCall = result.choices[0].message.function_call;
-    const responseData = functionCall ? JSON.parse(functionCall.arguments) as EvaluationResponse : null;
-
+    const { text, tokens } = await generateResponse(provider, prompt, providerType);
+    const responseData = JSON.parse(text) as EvaluationResponse;
     if (!responseData) throw new Error('No valid response generated');
 
     console.log('Evaluation:', {
@@ -69,11 +105,13 @@ export async function evaluateAnswer(question: string, answer: string, tracker?:
       reason: responseData.reasoning
     });
 
-    const tokens = result.usage.total_tokens;
-    (tracker || new TokenTracker()).trackUsage('evaluator', tokens);
+    (tracker || new TokenTracker()).trackUsage('evaluator', tokens, providerType);
     return { response: responseData, tokens };
   } catch (error) {
     console.error('Error in answer evaluation:', error);
+    if (error instanceof Error && error.message.includes('Ollama support')) {
+      throw new Error('Ollama provider is not yet supported for answer evaluation');
+    }
     throw error;
   }
 }
@@ -89,9 +127,11 @@ async function main() {
   }
 
   try {
-    await evaluateAnswer(question, answer);
+    const result = await evaluateAnswer(question, answer);
+    console.log(JSON.stringify(result, null, 2));
   } catch (error) {
     console.error('Failed to evaluate answer:', error);
     process.exit(1);
   }
 }
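Review note: evaluateAnswer usage sketch; given the evaluator prompt, a hedged answer should presumably come back is_definitive: false:

import { evaluateAnswer } from './src/tools/evaluator';

const { response } = await evaluateAnswer(
  'Who founded Jina AI?',
  'I am not sure, possibly someone in Berlin.'
); // top-level await assumed (ESM)
console.log(response.is_definitive, response.reasoning);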
src/tools/query-rewriter.ts
@@ -1,10 +1,9 @@
-import OpenAI from 'openai';
-import { OPENAI_API_KEY, modelConfigs } from "../config";
+import { ProviderFactory, AIProvider, isGeminiProvider, isOpenAIProvider } from '../utils/provider-factory';
+import { aiConfig, modelConfigs } from "../config";
 import { TokenTracker } from "../utils/token-tracker";
-import { SearchAction, KeywordsResponse } from "../types";
+import { SearchAction, KeywordsResponse, ProviderType, OpenAIFunctionParameter } from "../types";
 import { z } from 'zod';
-
-const openai = new OpenAI({ apiKey: OPENAI_API_KEY });
+import { getProviderSchema } from '../utils/schema';
 
 const responseSchema = z.object({
   think: z.string().describe("Strategic reasoning about query complexity and search approach"),
@@ -91,32 +90,71 @@ Intention: ${action.think}
 `;
 }
 
+async function generateResponse(provider: AIProvider, prompt: string, providerType: ProviderType) {
+  if (!isGeminiProvider(provider) && !isOpenAIProvider(provider)) {
+    throw new Error('Invalid provider type');
+  }
+  switch (providerType) {
+    case 'gemini': {
+      if (!isGeminiProvider(provider)) throw new Error('Invalid provider type');
+      const result = await provider.generateContent({
+        contents: [{ role: 'user', parts: [{ text: prompt }]}],
+        generationConfig: {
+          temperature: modelConfigs.queryRewriter.temperature,
+          maxOutputTokens: 1000
+        }
+      });
+      const response = await result.response;
+      return {
+        text: response.text(),
+        tokens: response.usageMetadata?.totalTokenCount || 0
+      };
+    }
+    case 'openai': {
+      if (!isOpenAIProvider(provider)) throw new Error('Invalid provider type');
+      const result = await provider.chat.completions.create({
+        messages: [{ role: 'user', content: prompt }],
+        model: modelConfigs.queryRewriter.model,
+        temperature: modelConfigs.queryRewriter.temperature,
+        max_tokens: 1000,
+        functions: [{
+          name: 'generate',
+          parameters: getProviderSchema('openai', responseSchema) as OpenAIFunctionParameter
+        }],
+        function_call: { name: 'generate' }
+      });
+      const functionCall = result.choices[0].message.function_call;
+      return {
+        text: functionCall?.arguments || '',
+        tokens: result.usage?.total_tokens || 0
+      };
+    }
+    case 'ollama':
+      throw new Error('Ollama support coming soon');
+    default:
+      throw new Error(`Unsupported provider type: ${providerType}`);
+  }
+}
+
 export async function rewriteQuery(action: SearchAction, tracker?: TokenTracker): Promise<{ queries: string[], tokens: number }> {
   try {
+    const provider = ProviderFactory.createProvider();
+    const providerType = aiConfig.defaultProvider;
     const prompt = getPrompt(action);
-    const result = await openai.chat.completions.create({
-      messages: [{ role: 'user', content: prompt }],
-      model: modelConfigs.queryRewriter.model,
-      temperature: modelConfigs.queryRewriter.temperature,
-      max_tokens: 1000,
-      functions: [{
-        name: 'generate',
-        parameters: responseSchema.shape
-      }],
-      function_call: { name: 'generate' }
-    });
-
-    const functionCall = result.choices[0].message.function_call;
-    const responseData = functionCall ? JSON.parse(functionCall.arguments) as KeywordsResponse : null;
-
+    const { text, tokens } = await generateResponse(provider, prompt, providerType);
+    const responseData = JSON.parse(text) as KeywordsResponse;
     if (!responseData) throw new Error('No valid response generated');
 
     console.log('Query rewriter:', responseData.queries);
-    const tokens = result.usage.total_tokens;
-    (tracker || new TokenTracker()).trackUsage('query-rewriter', tokens);
+    (tracker || new TokenTracker()).trackUsage('query-rewriter', tokens, providerType);
 
     return { queries: responseData.queries, tokens };
   } catch (error) {
     console.error('Error in query rewriting:', error);
+    if (error instanceof Error && error.message.includes('Ollama support')) {
+      throw new Error('Ollama provider is not yet supported for query rewriting');
+    }
     throw error;
   }
 }
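Review note: rewriteQuery usage sketch. Only the think field of SearchAction is visible in this diff, so the literal below is cast rather than fully spelled out:

import { rewriteQuery } from './src/tools/query-rewriter';
import type { SearchAction } from './src/types';

const action = {
  action: 'search',
  think: 'Need background on the Reader API before answering.',
} as SearchAction; // remaining SearchAction fields are not shown in this diff
const { queries, tokens } = await rewriteQuery(action); // top-level await assumed (ESM)
console.log(queries, tokens);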
src/types.ts (27 lines changed)
@@ -5,6 +5,32 @@ export enum SchemaType {
   OBJECT = 'OBJECT'
 }
 
+export type ProviderType = 'gemini' | 'openai' | 'ollama';
+
+export interface OpenAIFunctionParameter {
+  type: string;
+  description?: string;
+  properties?: Record<string, OpenAIFunctionParameter>;
+  required?: string[];
+  items?: OpenAIFunctionParameter;
+}
+
+export interface OpenAIFunction {
+  name: string;
+  parameters: OpenAIFunctionParameter;
+}
+
+export interface ProviderConfig {
+  type: ProviderType;
+  model: string;
+  temperature: number;
+}
+
+export interface AIConfig {
+  defaultProvider: ProviderType;
+  providers: Record<ProviderType, ProviderConfig>;
+}
+
 // Action Types
 type BaseAction = {
   action: "search" | "answer" | "reflect" | "visit";
@@ -41,6 +67,7 @@ export type StepAction = SearchAction | AnswerAction | ReflectAction | VisitActi
 export interface TokenUsage {
   tool: string;
   tokens: number;
+  provider?: ProviderType;
 }
 
 export interface SearchResponse {
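Review note: the new config types compose as below; values are illustrative. Record<ProviderType, ProviderConfig> forces an entry for all three providers:

import type { AIConfig, OpenAIFunctionParameter } from './src/types';

const sampleParam: OpenAIFunctionParameter = {
  type: 'object',
  properties: { queries: { type: 'array', items: { type: 'string' } } },
  required: ['queries'],
};

const cfg: AIConfig = {
  defaultProvider: 'ollama',
  providers: {
    gemini: { type: 'gemini', model: 'gemini-flash-1.5', temperature: 0 },
    openai: { type: 'openai', model: 'gpt4o-mini', temperature: 0 },
    ollama: { type: 'ollama', model: 'llama2', temperature: 0 },
  },
};
console.log(cfg.providers[cfg.defaultProvider].model, sampleParam.type);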
src/utils/provider-factory.ts (new file, 64 lines)
@@ -0,0 +1,64 @@
+import { GoogleGenerativeAI } from '@google/generative-ai';
+import OpenAI from 'openai';
+import type { ProviderConfig } from '../types';
+import { GEMINI_API_KEY, OPENAI_API_KEY, aiConfig } from '../config';
+
+const defaultConfig = aiConfig.providers[aiConfig.defaultProvider];
+
+export interface GeminiProvider {
+  generateContent(params: {
+    contents: Array<{ role: string; parts: Array<{ text: string }> }>;
+    generationConfig?: {
+      temperature?: number;
+      maxOutputTokens?: number;
+    };
+  }): Promise<{
+    response: {
+      text(): string;
+      usageMetadata?: { totalTokenCount?: number };
+    };
+  }>;
+}
+
+export type OpenAIProvider = OpenAI;
+
+export type AIProvider = GeminiProvider | OpenAIProvider;
+
+export function isGeminiProvider(provider: AIProvider): provider is GeminiProvider {
+  return 'generateContent' in provider;
+}
+
+export function isOpenAIProvider(provider: AIProvider): provider is OpenAIProvider {
+  return 'chat' in provider;
+}
+
+export class ProviderFactory {
+  private static geminiClient: GoogleGenerativeAI | null = null;
+  private static openaiClient: OpenAI | null = null;
+
+  static createProvider(config: ProviderConfig = defaultConfig): AIProvider {
+    switch (config.type) {
+      case 'gemini': {
+        if (!this.geminiClient) {
+          this.geminiClient = new GoogleGenerativeAI(GEMINI_API_KEY);
+        }
+        return this.geminiClient.getGenerativeModel({
+          model: config.model,
+          generationConfig: {
+            temperature: config.temperature
+          }
+        });
+      }
+      case 'openai': {
+        if (!this.openaiClient) {
+          this.openaiClient = new OpenAI({ apiKey: OPENAI_API_KEY });
+        }
+        return this.openaiClient;
+      }
+      case 'ollama':
+        throw new Error('Ollama support coming soon');
+      default:
+        throw new Error(`Unsupported provider type: ${config.type}`);
+    }
+  }
+}
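Review note: since createProvider returns a union, call sites must narrow before touching provider-specific APIs; that is what the two guards are for. A sketch:

import { ProviderFactory, isGeminiProvider, isOpenAIProvider } from './src/utils/provider-factory';

const provider = ProviderFactory.createProvider({ type: 'openai', model: 'gpt4o-mini', temperature: 0 });
if (isOpenAIProvider(provider)) {
  console.log('narrowed to OpenAI: chat.completions is available');
} else if (isGeminiProvider(provider)) {
  console.log('narrowed to Gemini: generateContent is available');
}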
src/utils/schema.ts
@@ -1,34 +1,126 @@
 import { z } from 'zod';
-import { SchemaType } from '../types';
-import type { SchemaProperty, ResponseSchema } from '../types';
+import { SchemaType, ProviderType, OpenAIFunctionParameter } from '../types';
+import type { SchemaProperty } from '../types';
 
 export function convertToZodType(prop: SchemaProperty): z.ZodTypeAny {
+  let zodType: z.ZodTypeAny;
   switch (prop.type) {
     case SchemaType.STRING:
-      return z.string().describe(prop.description || '');
+      zodType = z.string().describe(prop.description || '');
+      break;
     case SchemaType.BOOLEAN:
-      return z.boolean().describe(prop.description || '');
+      zodType = z.boolean().describe(prop.description || '');
+      break;
     case SchemaType.ARRAY:
       if (!prop.items) throw new Error('Array schema must have items defined');
-      return z.array(convertToZodType(prop.items)).describe(prop.description || '');
-    case SchemaType.OBJECT:
+      zodType = z.array(convertToZodType(prop.items))
+        .describe(prop.description || '')
+        .max(prop.maxItems || Infinity);
+      break;
+    case SchemaType.OBJECT: {
       if (!prop.properties) throw new Error('Object schema must have properties defined');
       const shape: Record<string, z.ZodTypeAny> = {};
       for (const [key, value] of Object.entries(prop.properties)) {
         shape[key] = convertToZodType(value);
       }
-      return z.object(shape).describe(prop.description || '');
+      zodType = z.object(shape).describe(prop.description || '');
+      break;
+    }
     default:
       throw new Error(`Unsupported schema type: ${prop.type}`);
   }
+  return zodType;
 }
 
-export function createZodSchema(schema: ResponseSchema): z.ZodObject<any> {
-  const shape: Record<string, z.ZodTypeAny> = {};
-  for (const [key, prop] of Object.entries(schema.properties)) {
-    shape[key] = convertToZodType(prop);
+export function convertToGeminiSchema(schema: z.ZodSchema): SchemaProperty {
+  // Initialize schema properties
+  let type: SchemaType;
+  let properties: Record<string, SchemaProperty> | undefined;
+  let items: SchemaProperty | undefined;
+  let description = '';
+
+  if (schema instanceof z.ZodString) {
+    type = SchemaType.STRING;
+    description = schema.description || '';
+  } else if (schema instanceof z.ZodBoolean) {
+    type = SchemaType.BOOLEAN;
+    description = schema.description || '';
+  } else if (schema instanceof z.ZodArray) {
+    type = SchemaType.ARRAY;
+    description = schema.description || '';
+    items = convertToGeminiSchema(schema.element);
+  } else if (schema instanceof z.ZodObject) {
+    type = SchemaType.OBJECT;
+    description = schema.description || '';
+    properties = {};
+    const shape = schema.shape as Record<string, z.ZodTypeAny>;
+    for (const [key, value] of Object.entries(shape)) {
+      properties[key] = convertToGeminiSchema(value as z.ZodSchema);
+    }
+  } else {
+    throw new Error('Unsupported Zod type');
   }
+
+  return {
+    type,
+    description,
+    ...(properties && { properties }),
+    ...(items && { items })
+  };
+}
+
+export function convertToOpenAIFunctionSchema(schema: z.ZodSchema): OpenAIFunctionParameter {
+  if (schema instanceof z.ZodString) {
+    return { type: 'string', description: schema.description || '' };
+  } else if (schema instanceof z.ZodBoolean) {
+    return { type: 'boolean', description: schema.description || '' };
+  } else if (schema instanceof z.ZodArray) {
+    return {
+      type: 'array',
+      description: schema.description || '',
+      items: convertToOpenAIFunctionSchema(schema.element),
+      ...(schema._def.maxLength && { maxItems: schema._def.maxLength.value })
+    };
+  } else if (schema instanceof z.ZodObject) {
+    const properties: Record<string, any> = {};
+    const required: string[] = [];
+    const shape = schema.shape as Record<string, z.ZodTypeAny>;
+
+    for (const [key, value] of Object.entries(shape)) {
+      const zodValue = value as z.ZodTypeAny;
+      properties[key] = convertToOpenAIFunctionSchema(zodValue);
+      if (!zodValue.isOptional?.()) {
+        required.push(key);
+      }
+    }
+
+    return {
+      type: 'object',
+      description: schema.description || '',
+      properties,
+      required: required.length > 0 ? required : undefined
+    };
+  }
+
+  throw new Error('Unsupported Zod type');
+}
+
+export function getProviderSchema(provider: ProviderType, schema: z.ZodSchema): SchemaProperty | OpenAIFunctionParameter {
+  switch (provider) {
+    case 'gemini':
+      return convertToGeminiSchema(schema);
+    case 'openai':
+    case 'ollama': {
+      const functionSchema = convertToOpenAIFunctionSchema(schema);
+      return {
+        type: 'object',
+        properties: functionSchema.properties,
+        required: functionSchema.required
+      };
+    }
+    default:
+      throw new Error(`Unsupported provider: ${provider}`);
+  }
-  return z.object(shape);
 }
 
 export function createPromptConfig(temperature: number = 0) {
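Review note: getProviderSchema is the bridge from the tools' zod schemas to each provider's wire format. A sketch of both conversions for a small schema:

import { z } from 'zod';
import { getProviderSchema } from './src/utils/schema';

const schema = z.object({
  think: z.string().describe('reasoning'),
  queries: z.array(z.string()).describe('rewritten queries'),
});

// OpenAI/Ollama path: JSON-Schema-style { type: 'object', properties, required }
console.log(JSON.stringify(getProviderSchema('openai', schema), null, 2));
// Gemini path: a SchemaProperty tree using the upper-case SchemaType values
console.log(JSON.stringify(getProviderSchema('gemini', schema), null, 2));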
src/utils/token-tracker.ts
@@ -1,6 +1,6 @@
 import { EventEmitter } from 'events';
 
-import { TokenUsage } from '../types';
+import { TokenUsage, ProviderType } from '../types';
 
 export class TokenTracker extends EventEmitter {
   private usages: TokenUsage[] = [];
@@ -11,15 +11,15 @@ export class TokenTracker extends EventEmitter {
     this.budget = budget;
   }
 
-  trackUsage(tool: string, tokens: number) {
+  trackUsage(tool: string, tokens: number, provider?: ProviderType) {
     const currentTotal = this.getTotalUsage();
     if (this.budget && currentTotal + tokens > this.budget) {
       console.error(`Token budget exceeded: ${currentTotal + tokens} > ${this.budget}`);
     }
     // Only track usage if we're within budget
     if (!this.budget || currentTotal + tokens <= this.budget) {
-      this.usages.push({ tool, tokens });
-      this.emit('usage', { tool, tokens });
+      this.usages.push({ tool, tokens, provider });
+      this.emit('usage', { tool, tokens, provider });
     }
   }
 
@@ -27,17 +27,31 @@ export class TokenTracker extends EventEmitter {
     return this.usages.reduce((sum, usage) => sum + usage.tokens, 0);
   }
 
-  getUsageBreakdown(): Record<string, number> {
-    return this.usages.reduce((acc, { tool, tokens }) => {
-      acc[tool] = (acc[tool] || 0) + tokens;
+  getUsageBreakdown(): Record<string, { total: number; byProvider: Record<string, number> }> {
+    return this.usages.reduce((acc, { tool, tokens, provider }) => {
+      if (!acc[tool]) {
+        acc[tool] = { total: 0, byProvider: {} };
+      }
+      acc[tool].total += tokens;
+      if (provider) {
+        acc[tool].byProvider[provider] = (acc[tool].byProvider[provider] || 0) + tokens;
+      }
       return acc;
-    }, {} as Record<string, number>);
+    }, {} as Record<string, { total: number; byProvider: Record<string, number> }>);
  }
 
   printSummary() {
     const breakdown = this.getUsageBreakdown();
+    const totalByProvider = this.usages.reduce((acc, { tokens, provider }) => {
+      if (provider) {
+        acc[provider] = (acc[provider] || 0) + tokens;
+      }
+      return acc;
+    }, {} as Record<string, number>);
+
     console.log('Token Usage Summary:', {
       total: this.getTotalUsage(),
+      byProvider: totalByProvider,
       breakdown
     });
   }
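Review note: getUsageBreakdown changed shape, from Record<string, number> to per-tool totals with a per-provider split, so existing consumers need updating. A sketch:

import { TokenTracker } from './src/utils/token-tracker';

const tracker = new TokenTracker();
tracker.trackUsage('agent', 120, 'gemini');
tracker.trackUsage('agent', 80, 'openai');
tracker.trackUsage('dedup', 40); // no provider: counted in total, absent from byProvider

console.log(tracker.getUsageBreakdown());
// { agent: { total: 200, byProvider: { gemini: 120, openai: 80 } },
//   dedup: { total: 40, byProvider: {} } }
tracker.printSummary();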