From 21df4095ee9e785d35191ad93a17ffcd602e81ff Mon Sep 17 00:00:00 2001
From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Date: Wed, 5 Feb 2025 10:27:21 +0000
Subject: [PATCH] refactor: update files to use shared llm client

Co-Authored-By: Han Xiao

---
 src/agent.ts                | 16 +++++++---------
 src/tools/dedup.ts          | 10 +++++-----
 src/tools/error-analyzer.ts | 10 +++++-----
 src/tools/evaluator.ts      | 10 +++++-----
 src/tools/query-rewriter.ts | 12 ++++++------
 5 files changed, 28 insertions(+), 30 deletions(-)

diff --git a/src/agent.ts b/src/agent.ts
index 2b0616c..96a15bc 100644
--- a/src/agent.ts
+++ b/src/agent.ts
@@ -1,4 +1,4 @@
-import {GoogleGenerativeAI, SchemaType} from "@google/generative-ai";
+import {SchemaType} from "@google/generative-ai";
 import {readUrl} from "./tools/read";
 import fs from 'fs/promises';
 import {SafeSearchType, search as duckSearch} from "duck-duck-scrape";
@@ -7,7 +7,8 @@ import {rewriteQuery} from "./tools/query-rewriter";
 import {dedupQueries} from "./tools/dedup";
 import {evaluateAnswer} from "./tools/evaluator";
 import {analyzeSteps} from "./tools/error-analyzer";
-import {GEMINI_API_KEY, SEARCH_PROVIDER, STEP_SLEEP, modelConfigs} from "./config";
+import {SEARCH_PROVIDER, STEP_SLEEP, modelConfigs} from "./config";
+import {llmClient} from "./utils/llm-client";
 import {TokenTracker} from "./utils/token-tracker";
 import {ActionTracker} from "./utils/action-tracker";
 import {StepAction, SchemaProperty, ResponseSchema, AnswerAction} from "./types";
@@ -356,10 +357,10 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_
     false
   );
 
-  const model = genAI.getGenerativeModel({
+  const model = llmClient.getModel({
     model: modelConfigs.agent.model,
+    temperature: modelConfigs.agent.temperature,
     generationConfig: {
-      temperature: modelConfigs.agent.temperature,
       responseMimeType: "application/json",
       responseSchema: getSchema(allowReflect, allowRead, allowAnswer, allowSearch)
     }
@@ -699,10 +700,10 @@ You decided to think out of the box or cut from a completely different angle.`);
     true
   );
 
-  const model = genAI.getGenerativeModel({
+  const model = llmClient.getModel({
     model: modelConfigs.agentBeastMode.model,
+    temperature: modelConfigs.agentBeastMode.temperature,
     generationConfig: {
-      temperature: modelConfigs.agentBeastMode.temperature,
       responseMimeType: "application/json",
       responseSchema: getSchema(false, false, allowAnswer, false)
     }
@@ -733,9 +734,6 @@ async function storeContext(prompt: string, memory: any[][], step: number) {
   }
 }
 
-const genAI = new GoogleGenerativeAI(GEMINI_API_KEY);
-
-
 export async function main() {
   const question = process.argv[2] || "";
   const {
diff --git a/src/tools/dedup.ts b/src/tools/dedup.ts
index c48f5ce..b53d2cb 100644
--- a/src/tools/dedup.ts
+++ b/src/tools/dedup.ts
@@ -1,5 +1,6 @@
-import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai";
-import { GEMINI_API_KEY, modelConfigs } from "../config";
+import { SchemaType } from "@google/generative-ai";
+import { modelConfigs } from "../config";
+import { llmClient } from "../utils/llm-client";
 import { TokenTracker } from "../utils/token-tracker";
 import { DedupResponse } from '../types';
 
@@ -23,11 +24,10 @@
   required: ["think", "unique_queries"]
 };
 
-const genAI = new GoogleGenerativeAI(GEMINI_API_KEY);
-const model = genAI.getGenerativeModel({
+const model = llmClient.getModel({
   model: modelConfigs.dedup.model,
+  temperature: modelConfigs.dedup.temperature,
   generationConfig: {
-    temperature: modelConfigs.dedup.temperature,
     responseMimeType: "application/json",
     responseSchema: responseSchema
   }
diff --git a/src/tools/error-analyzer.ts b/src/tools/error-analyzer.ts
index 7f55b94..de911f5 100644
--- a/src/tools/error-analyzer.ts
+++ b/src/tools/error-analyzer.ts
@@ -1,5 +1,6 @@
-import {GoogleGenerativeAI, SchemaType} from "@google/generative-ai";
-import { GEMINI_API_KEY, modelConfigs } from "../config";
+import {SchemaType} from "@google/generative-ai";
+import { modelConfigs } from "../config";
+import { llmClient } from "../utils/llm-client";
 import { TokenTracker } from "../utils/token-tracker";
 import { ErrorAnalysisResponse } from '../types';
 
@@ -23,11 +24,10 @@
   required: ["recap", "blame", "improvement"]
 };
 
-const genAI = new GoogleGenerativeAI(GEMINI_API_KEY);
-const model = genAI.getGenerativeModel({
+const model = llmClient.getModel({
   model: modelConfigs.errorAnalyzer.model,
+  temperature: modelConfigs.errorAnalyzer.temperature,
   generationConfig: {
-    temperature: modelConfigs.errorAnalyzer.temperature,
     responseMimeType: "application/json",
     responseSchema: responseSchema
   }
diff --git a/src/tools/evaluator.ts b/src/tools/evaluator.ts
index 222e0d2..a6acd6a 100644
--- a/src/tools/evaluator.ts
+++ b/src/tools/evaluator.ts
@@ -1,5 +1,6 @@
-import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai";
-import { GEMINI_API_KEY, modelConfigs } from "../config";
+import { SchemaType } from "@google/generative-ai";
+import { modelConfigs } from "../config";
+import { llmClient } from "../utils/llm-client";
 import { TokenTracker } from "../utils/token-tracker";
 import { EvaluationResponse } from '../types';
 
@@ -19,11 +20,10 @@
   required: ["is_definitive", "reasoning"]
 };
 
-const genAI = new GoogleGenerativeAI(GEMINI_API_KEY);
-const model = genAI.getGenerativeModel({
+const model = llmClient.getModel({
   model: modelConfigs.evaluator.model,
+  temperature: modelConfigs.evaluator.temperature,
   generationConfig: {
-    temperature: modelConfigs.evaluator.temperature,
     responseMimeType: "application/json",
     responseSchema: responseSchema
   }
diff --git a/src/tools/query-rewriter.ts b/src/tools/query-rewriter.ts
index 8410413..b25ed57 100644
--- a/src/tools/query-rewriter.ts
+++ b/src/tools/query-rewriter.ts
@@ -1,5 +1,6 @@
-import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai";
-import { GEMINI_API_KEY, modelConfigs } from "../config";
+import { SchemaType } from "@google/generative-ai";
+import { modelConfigs } from "../config";
+import { llmClient } from "../utils/llm-client";
 import { TokenTracker } from "../utils/token-tracker";
 import { SearchAction } from "../types";
 
@@ -26,11 +27,10 @@
   required: ["think", "queries"]
 };
 
-const genAI = new GoogleGenerativeAI(GEMINI_API_KEY);
-const model = genAI.getGenerativeModel({
+const model = llmClient.getModel({
   model: modelConfigs.queryRewriter.model,
+  temperature: modelConfigs.queryRewriter.temperature,
   generationConfig: {
-    temperature: modelConfigs.queryRewriter.temperature,
     responseMimeType: "application/json",
     responseSchema: responseSchema
   }
@@ -129,4 +129,4 @@ export async function rewriteQuery(action: SearchAction, tracker?: TokenTracker)
     console.error('Error in query rewriting:', error);
     throw error;
   }
-}
\ No newline at end of file
+}
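
Note: the shared client module `src/utils/llm-client.ts` that this patch imports is not included in the diff. A minimal sketch of what it might look like, inferred purely from the call sites above (the `getModel` options shape, the class name, and the singleton export are all assumptions, not the actual implementation):

```typescript
// src/utils/llm-client.ts -- hypothetical sketch; this module is not shown in
// the patch, so its exact shape is inferred from the call sites above.
import {
  GenerationConfig,
  GenerativeModel,
  GoogleGenerativeAI,
} from "@google/generative-ai";
import { GEMINI_API_KEY } from "../config";

interface GetModelOptions {
  model: string;
  temperature: number;
  generationConfig?: GenerationConfig;
}

class LLMClient {
  // One shared SDK instance replaces the per-file `new GoogleGenerativeAI(...)`.
  private readonly genAI = new GoogleGenerativeAI(GEMINI_API_KEY);

  getModel({ model, temperature, generationConfig }: GetModelOptions): GenerativeModel {
    // Fold the hoisted temperature back into generationConfig, so the SDK
    // receives the same config each call site previously built inline.
    return this.genAI.getGenerativeModel({
      model,
      generationConfig: { temperature, ...generationConfig },
    });
  }
}

export const llmClient = new LLMClient();
```

Whatever its real shape, centralizing the client this way means `GEMINI_API_KEY` is read in exactly one place, and each tool passes `temperature` at the top level of the options instead of rebuilding it inside `generationConfig` -- which is what every hunk in this patch changes.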