From 19a938b8884f5f57631a4a2e4ef87b64ccb66966 Mon Sep 17 00:00:00 2001
From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Date: Wed, 5 Feb 2025 09:14:56 +0000
Subject: [PATCH] fix: update LLMClientConfig interface and fix type errors

Co-Authored-By: Han Xiao
---
 src/config.ts               | 57 ++++++++++++++++++++++++++++---------
 src/tools/dedup.ts          |  2 +-
 src/tools/error-analyzer.ts |  2 +-
 src/tools/evaluator.ts      |  2 +-
 src/tools/query-rewriter.ts |  2 +-
 5 files changed, 48 insertions(+), 17 deletions(-)

diff --git a/src/config.ts b/src/config.ts
index 92ae213..2529f31 100644
--- a/src/config.ts
+++ b/src/config.ts
@@ -15,31 +15,62 @@ interface ToolConfigs {
   agentBeastMode: ModelConfig;
 }
 
-import { GenerateContentResult, GoogleGenerativeAI } from '@google/generative-ai';
+import { GoogleGenerativeAI } from '@google/generative-ai';
 
 interface LLMClientConfig {
   model: string;
   temperature: number;
   generationConfig?: {
-    temperature: number;
-    responseMimeType: string;
-    responseSchema: any;
+    responseMimeType?: string;
+    responseSchema?: any;
+  };
+}
+
+interface LLMResponse {
+  text(): string;
+  usageMetadata: {
+    totalTokenCount: number;
   };
 }
 
 interface LLMClient {
   getGenerativeModel(config: LLMClientConfig): {
-    generateContent(prompt: string): Promise<GenerateContentResult>;
+    generateContent(prompt: string): Promise<{
+      response: LLMResponse;
+    }>;
   };
 }
 
-interface GenerateContentResult {
-  response: {
-    text(): string;
-    usageMetadata: {
-      totalTokenCount: number;
+class GoogleAIWrapper implements LLMClient {
+  private client: GoogleGenerativeAI;
+
+  constructor(apiKey: string) {
+    this.client = new GoogleGenerativeAI(apiKey);
+  }
+
+  getGenerativeModel(config: LLMClientConfig) {
+    const model = this.client.getGenerativeModel({
+      model: config.model,
+      generationConfig: {
+        temperature: config.temperature,
+        ...(config.generationConfig || {})
+      }
+    });
+
+    return {
+      generateContent: async (prompt: string) => {
+        const result = await model.generateContent(prompt);
+        return {
+          response: {
+            text: () => result.response.text(),
+            usageMetadata: {
+              totalTokenCount: result.response.usageMetadata?.totalTokenCount ?? 0
+            }
+          }
+        };
+      }
     };
-  };
+  }
 }
 
 class LocalLLMClient implements LLMClient {
@@ -65,7 +96,7 @@ class LocalLLMClient implements LLMClient {
           content: prompt,
         },
       ],
-      temperature: config.generationConfig?.temperature ?? config.temperature,
+      temperature: config.temperature,
       response_format: {
         type: 'json_schema',
         json_schema: config.generationConfig?.responseSchema,
@@ -117,7 +148,7 @@ export const LLM_PROVIDER = process.env.LLM_PROVIDER || 'gemini';
 // Initialize LLM client based on configuration
 export const llmClient: LLMClient = LLM_PROVIDER === 'local' && LOCAL_LLM_HOSTNAME && LOCAL_LLM_PORT && LOCAL_LLM_MODEL ?
   new LocalLLMClient(LOCAL_LLM_HOSTNAME, LOCAL_LLM_PORT, LOCAL_LLM_MODEL)
-  : new GoogleGenerativeAI(GEMINI_API_KEY);
+  : new GoogleAIWrapper(GEMINI_API_KEY);
 
 const DEFAULT_MODEL = 'gemini-1.5-flash';
 
diff --git a/src/tools/dedup.ts b/src/tools/dedup.ts
index eb182a6..7c064c8 100644
--- a/src/tools/dedup.ts
+++ b/src/tools/dedup.ts
@@ -25,8 +25,8 @@ const responseSchema = {
 
 const model = llmClient.getGenerativeModel({
   model: modelConfigs.dedup.model,
+  temperature: modelConfigs.dedup.temperature,
   generationConfig: {
-    temperature: modelConfigs.dedup.temperature,
     responseMimeType: "application/json",
     responseSchema: responseSchema
   }
diff --git a/src/tools/error-analyzer.ts b/src/tools/error-analyzer.ts
index 773a40a..0ef30f1 100644
--- a/src/tools/error-analyzer.ts
+++ b/src/tools/error-analyzer.ts
@@ -25,8 +25,8 @@ const responseSchema = {
 
 const model = llmClient.getGenerativeModel({
   model: modelConfigs.errorAnalyzer.model,
+  temperature: modelConfigs.errorAnalyzer.temperature,
   generationConfig: {
-    temperature: modelConfigs.errorAnalyzer.temperature,
     responseMimeType: "application/json",
     responseSchema: responseSchema
   }
diff --git a/src/tools/evaluator.ts b/src/tools/evaluator.ts
index fee2210..049f193 100644
--- a/src/tools/evaluator.ts
+++ b/src/tools/evaluator.ts
@@ -21,8 +21,8 @@ const responseSchema = {
 
 const model = llmClient.getGenerativeModel({
   model: modelConfigs.evaluator.model,
+  temperature: modelConfigs.evaluator.temperature,
   generationConfig: {
-    temperature: modelConfigs.evaluator.temperature,
     responseMimeType: "application/json",
     responseSchema: responseSchema
   }
diff --git a/src/tools/query-rewriter.ts b/src/tools/query-rewriter.ts
index 31ab569..eb89679 100644
--- a/src/tools/query-rewriter.ts
+++ b/src/tools/query-rewriter.ts
@@ -28,8 +28,8 @@ const responseSchema = {
 
 const model = llmClient.getGenerativeModel({
   model: modelConfigs.queryRewriter.model,
+  temperature: modelConfigs.queryRewriter.temperature,
   generationConfig: {
-    temperature: modelConfigs.queryRewriter.temperature,
     responseMimeType: "application/json",
     responseSchema: responseSchema
   }
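
For reference, a minimal sketch of the call pattern this patch leaves the tool modules with: temperature now sits at the top level of LLMClientConfig, and generationConfig only carries the response-format options. The import path, the modelConfigs export, the run() helper, and the example schema are illustrative assumptions, not part of the patch.

    import { llmClient, modelConfigs } from '../config'; // assumed export/path

    // Illustrative schema; each tool defines its own responseSchema.
    const responseSchema = {
      type: 'object',
      properties: { items: { type: 'array', items: { type: 'string' } } },
    };

    const model = llmClient.getGenerativeModel({
      model: modelConfigs.dedup.model,
      temperature: modelConfigs.dedup.temperature, // top-level, no longer inside generationConfig
      generationConfig: {
        responseMimeType: 'application/json',
        responseSchema,
      },
    });

    // Works against either backend, since GoogleAIWrapper and LocalLLMClient
    // both return { response: { text(), usageMetadata } }.
    async function run(prompt: string) {
      const result = await model.generateContent(prompt);
      console.log('tokens:', result.response.usageMetadata.totalTokenCount);
      return JSON.parse(result.response.text());
    }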