From 39c8e556514b351bb9e896fe4b9611aa5901e71b Mon Sep 17 00:00:00 2001
From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Date: Wed, 5 Feb 2025 09:04:15 +0000
Subject: [PATCH] refactor: extract LLM client to config.ts and add local LLM support

Co-Authored-By: Han Xiao
---
 src/config.ts               | 87 ++++++++++++++++++++++++++++++++++++-
 src/tools/dedup.ts          |  7 ++-
 src/tools/error-analyzer.ts |  7 ++-
 src/tools/evaluator.ts      |  7 ++-
 src/tools/query-rewriter.ts |  9 ++--
 5 files changed, 99 insertions(+), 18 deletions(-)

diff --git a/src/config.ts b/src/config.ts
index 3543b5f..92ae213 100644
--- a/src/config.ts
+++ b/src/config.ts
@@ -15,6 +15,80 @@ interface ToolConfigs {
   agentBeastMode: ModelConfig;
 }
 
+import { GoogleGenerativeAI } from '@google/generative-ai';
+
+interface LLMClientConfig {
+  model: string;
+  temperature?: number;
+  generationConfig?: {
+    temperature: number;
+    responseMimeType: string;
+    responseSchema: any;
+  };
+}
+
+interface LLMClient {
+  getGenerativeModel(config: LLMClientConfig): {
+    generateContent(prompt: string): Promise<GenerateContentResult>;
+  };
+}
+
+interface GenerateContentResult {
+  response: {
+    text(): string;
+    usageMetadata: {
+      totalTokenCount: number;
+    };
+  };
+}
+
+class LocalLLMClient implements LLMClient {
+  constructor(
+    private hostname: string,
+    private port: string,
+    private model: string
+  ) {}
+
+  getGenerativeModel(config: LLMClientConfig) {
+    return {
+      generateContent: async (prompt: string) => {
+        const response = await fetch(`http://${this.hostname}:${this.port}/v1/chat/completions`, {
+          method: 'POST',
+          headers: {
+            'Content-Type': 'application/json',
+          },
+          body: JSON.stringify({
+            model: this.model,
+            messages: [
+              {
+                role: 'user',
+                content: prompt,
+              },
+            ],
+            temperature: config.generationConfig?.temperature ?? config.temperature,
+            response_format: {
+              type: 'json_schema',
+              json_schema: config.generationConfig?.responseSchema,
+            },
+            max_tokens: 1000,
+            stream: false,
+          }),
+        });
+
+        const data = await response.json();
+        return {
+          response: {
+            text: () => data.choices[0].message.content,
+            usageMetadata: {
+              totalTokenCount: data.usage?.total_tokens || 0,
+            },
+          },
+        };
+      },
+    };
+  }
+}
+
 
 dotenv.config();
 
@@ -32,7 +106,18 @@ if (process.env.https_proxy) {
 export const GEMINI_API_KEY = process.env.GEMINI_API_KEY as string;
 export const JINA_API_KEY = process.env.JINA_API_KEY as string;
 export const BRAVE_API_KEY = process.env.BRAVE_API_KEY as string;
-export const SEARCH_PROVIDER: 'brave' | 'jina' | 'duck' = 'jina'
+export const SEARCH_PROVIDER: 'brave' | 'jina' | 'duck' = 'jina';
+
+// LLM Configuration
+export const LOCAL_LLM_HOSTNAME = process.env.LOCAL_LLM_HOSTNAME;
+export const LOCAL_LLM_PORT = process.env.LOCAL_LLM_PORT;
+export const LOCAL_LLM_MODEL = process.env.LOCAL_LLM_MODEL;
+export const LLM_PROVIDER = process.env.LLM_PROVIDER || 'gemini';
+
+// Initialize LLM client based on configuration
+export const llmClient: LLMClient = LLM_PROVIDER === 'local' && LOCAL_LLM_HOSTNAME && LOCAL_LLM_PORT && LOCAL_LLM_MODEL
+  ? new LocalLLMClient(LOCAL_LLM_HOSTNAME, LOCAL_LLM_PORT, LOCAL_LLM_MODEL)
+  : new GoogleGenerativeAI(GEMINI_API_KEY);
 
 const DEFAULT_MODEL = 'gemini-1.5-flash';
 
diff --git a/src/tools/dedup.ts b/src/tools/dedup.ts
index c48f5ce..eb182a6 100644
--- a/src/tools/dedup.ts
+++ b/src/tools/dedup.ts
@@ -1,5 +1,5 @@
-import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai";
-import { GEMINI_API_KEY, modelConfigs } from "../config";
+import { SchemaType } from "@google/generative-ai";
+import { modelConfigs, llmClient } from "../config";
 import { TokenTracker } from "../utils/token-tracker";
 
 import { DedupResponse } from '../types';
@@ -23,8 +23,7 @@
   required: ["think", "unique_queries"]
 };
 
-const genAI = new GoogleGenerativeAI(GEMINI_API_KEY);
-const model = genAI.getGenerativeModel({
+const model = llmClient.getGenerativeModel({
   model: modelConfigs.dedup.model,
   generationConfig: {
     temperature: modelConfigs.dedup.temperature,
diff --git a/src/tools/error-analyzer.ts b/src/tools/error-analyzer.ts
index 7f55b94..773a40a 100644
--- a/src/tools/error-analyzer.ts
+++ b/src/tools/error-analyzer.ts
@@ -1,5 +1,5 @@
-import {GoogleGenerativeAI, SchemaType} from "@google/generative-ai";
-import { GEMINI_API_KEY, modelConfigs } from "../config";
+import { SchemaType } from "@google/generative-ai";
+import { modelConfigs, llmClient } from "../config";
 import { TokenTracker } from "../utils/token-tracker";
 
 import { ErrorAnalysisResponse } from '../types';
@@ -23,8 +23,7 @@
   required: ["recap", "blame", "improvement"]
 };
 
-const genAI = new GoogleGenerativeAI(GEMINI_API_KEY);
-const model = genAI.getGenerativeModel({
+const model = llmClient.getGenerativeModel({
   model: modelConfigs.errorAnalyzer.model,
   generationConfig: {
     temperature: modelConfigs.errorAnalyzer.temperature,
diff --git a/src/tools/evaluator.ts b/src/tools/evaluator.ts
index 222e0d2..fee2210 100644
--- a/src/tools/evaluator.ts
+++ b/src/tools/evaluator.ts
@@ -1,5 +1,5 @@
-import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai";
-import { GEMINI_API_KEY, modelConfigs } from "../config";
+import { SchemaType } from "@google/generative-ai";
+import { modelConfigs, llmClient } from "../config";
 import { TokenTracker } from "../utils/token-tracker";
 
 import { EvaluationResponse } from '../types';
@@ -19,8 +19,7 @@
   required: ["is_definitive", "reasoning"]
 };
 
-const genAI = new GoogleGenerativeAI(GEMINI_API_KEY);
-const model = genAI.getGenerativeModel({
+const model = llmClient.getGenerativeModel({
   model: modelConfigs.evaluator.model,
   generationConfig: {
     temperature: modelConfigs.evaluator.temperature,
diff --git a/src/tools/query-rewriter.ts b/src/tools/query-rewriter.ts
index 8410413..31ab569 100644
--- a/src/tools/query-rewriter.ts
+++ b/src/tools/query-rewriter.ts
@@ -1,5 +1,5 @@
-import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai";
-import { GEMINI_API_KEY, modelConfigs } from "../config";
+import { SchemaType } from "@google/generative-ai";
+import { modelConfigs, llmClient } from "../config";
 import { TokenTracker } from "../utils/token-tracker";
 
 import { SearchAction } from "../types";
@@ -26,8 +26,7 @@
   required: ["think", "queries"]
 };
 
-const genAI = new GoogleGenerativeAI(GEMINI_API_KEY);
-const model = genAI.getGenerativeModel({
+const model = llmClient.getGenerativeModel({
   model: modelConfigs.queryRewriter.model,
   generationConfig: {
     temperature: modelConfigs.queryRewriter.temperature,
@@ -129,4 +128,4 @@ export async function rewriteQuery(action: SearchAction, tracker?: TokenTracker)
     console.error('Error in query rewriting:', error);
     throw error;
   }
-}
\ No newline at end of file
+}
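
Usage sketch (not part of the patch): with the refactor applied, callers stay provider-agnostic and the local backend is selected purely through environment variables. The hostname, port, and model values below are illustrative placeholders, the demo file location (src/) and its schema are hypothetical, and the local server is assumed to expose an OpenAI-compatible /v1/chat/completions endpoint, which is what LocalLLMClient targets.

// Illustrative environment for the local path (placeholder values), e.g. in .env:
//   LLM_PROVIDER=local
//   LOCAL_LLM_HOSTNAME=localhost
//   LOCAL_LLM_PORT=8080
//   LOCAL_LLM_MODEL=my-local-model

import { SchemaType } from "@google/generative-ai";
import { modelConfigs, llmClient } from "./config";

// Same call shape the refactored tools use; the provider switch is invisible here.
const model = llmClient.getGenerativeModel({
  model: modelConfigs.dedup.model,
  generationConfig: {
    temperature: modelConfigs.dedup.temperature,
    responseMimeType: "application/json",
    responseSchema: {
      type: SchemaType.OBJECT,
      properties: { think: { type: SchemaType.STRING } },
      required: ["think"],
    },
  },
});

async function demo() {
  const result = await model.generateContent('Return {"think": "ok"} as JSON.');
  console.log(result.response.text());
  console.log("tokens:", result.response.usageMetadata.totalTokenCount);
}

demo().catch(console.error);

Whether Gemini or the local server answers, the tool code reads the result through the same response.text() and usageMetadata accessors, which is the point of extracting the client into config.ts.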