refactor: extract LLM client to config.ts and add local LLM support

Co-Authored-By: Han Xiao <han.xiao@jina.ai>
Devin AI 2025-02-05 09:04:15 +00:00
parent 2b84a577c8
commit 39c8e55651
5 changed files with 99 additions and 18 deletions


@@ -15,6 +15,80 @@ interface ToolConfigs {
   agentBeastMode: ModelConfig;
 }
+
+import { GenerateContentResult, GoogleGenerativeAI } from '@google/generative-ai';
+
+interface LLMClientConfig {
+  model: string;
+  temperature: number;
+  generationConfig?: {
+    temperature: number;
+    responseMimeType: string;
+    responseSchema: any;
+  };
+}
+
+interface LLMClient {
+  getGenerativeModel(config: LLMClientConfig): {
+    generateContent(prompt: string): Promise<GenerateContentResult>;
+  };
+}
+
+interface GenerateContentResult {
+  response: {
+    text(): string;
+    usageMetadata: {
+      totalTokenCount: number;
+    };
+  };
+}
+
+class LocalLLMClient implements LLMClient {
+  constructor(
+    private hostname: string,
+    private port: string,
+    private model: string
+  ) {}
+
+  getGenerativeModel(config: LLMClientConfig) {
+    return {
+      generateContent: async (prompt: string) => {
+        const response = await fetch(`http://${this.hostname}:${this.port}/v1/chat/completions`, {
+          method: 'POST',
+          headers: {
+            'Content-Type': 'application/json',
+          },
+          body: JSON.stringify({
+            model: this.model,
+            messages: [
+              {
+                role: 'user',
+                content: prompt,
+              },
+            ],
+            temperature: config.generationConfig?.temperature ?? config.temperature,
+            response_format: {
+              type: 'json_schema',
+              json_schema: config.generationConfig?.responseSchema,
+            },
+            max_tokens: 1000,
+            stream: false,
+          }),
+        });
+        const data = await response.json();
+        return {
+          response: {
+            text: () => data.choices[0].message.content,
+            usageMetadata: {
+              totalTokenCount: data.usage?.total_tokens || 0,
+            },
+          },
+        };
+      },
+    };
+  }
+}
+
 dotenv.config();
@@ -32,7 +106,18 @@ if (process.env.https_proxy) {
 export const GEMINI_API_KEY = process.env.GEMINI_API_KEY as string;
 export const JINA_API_KEY = process.env.JINA_API_KEY as string;
 export const BRAVE_API_KEY = process.env.BRAVE_API_KEY as string;
-export const SEARCH_PROVIDER: 'brave' | 'jina' | 'duck' = 'jina'
+export const SEARCH_PROVIDER: 'brave' | 'jina' | 'duck' = 'jina';
+
+// LLM Configuration
+export const LOCAL_LLM_HOSTNAME = process.env.LOCAL_LLM_HOSTNAME;
+export const LOCAL_LLM_PORT = process.env.LOCAL_LLM_PORT;
+export const LOCAL_LLM_MODEL = process.env.LOCAL_LLM_MODEL;
+export const LLM_PROVIDER = process.env.LLM_PROVIDER || 'gemini';
+
+// Initialize LLM client based on configuration
+export const llmClient: LLMClient = LLM_PROVIDER === 'local' && LOCAL_LLM_HOSTNAME && LOCAL_LLM_PORT && LOCAL_LLM_MODEL
+  ? new LocalLLMClient(LOCAL_LLM_HOSTNAME, LOCAL_LLM_PORT, LOCAL_LLM_MODEL)
+  : new GoogleGenerativeAI(GEMINI_API_KEY);
+
 const DEFAULT_MODEL = 'gemini-1.5-flash';
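
The LocalLLMClient wrapper above assumes the local server speaks the OpenAI chat-completions wire format: it reads the reply text from choices[0].message.content and the token count from usage.total_tokens. A rough sketch of that dependency (not part of this diff; the sample reply below is invented for illustration):

// Reply fields LocalLLMClient relies on, per the fetch handler above
// (assumed OpenAI-compatible server; other reply fields are ignored).
interface ChatCompletionsReply {
  choices: { message: { content: string } }[];
  usage?: { total_tokens: number };
}

// Invented sample reply, folded back into the Gemini-style result shape
// that the tool modules below already consume.
const sample: ChatCompletionsReply = {
  choices: [{ message: { content: '{"answer": "4"}' } }],
  usage: { total_tokens: 42 },
};
const mapped = {
  response: {
    text: () => sample.choices[0].message.content,
    usageMetadata: { totalTokenCount: sample.usage?.total_tokens || 0 },
  },
};
console.log(mapped.response.text()); // {"answer": "4"}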


@@ -1,5 +1,5 @@
-import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai";
-import { GEMINI_API_KEY, modelConfigs } from "../config";
+import { SchemaType } from "@google/generative-ai";
+import { modelConfigs, llmClient } from "../config";
 import { TokenTracker } from "../utils/token-tracker";
 import { DedupResponse } from '../types';
@@ -23,8 +23,7 @@ const responseSchema = {
   required: ["think", "unique_queries"]
 };
 
-const genAI = new GoogleGenerativeAI(GEMINI_API_KEY);
-const model = genAI.getGenerativeModel({
+const model = llmClient.getGenerativeModel({
   model: modelConfigs.dedup.model,
   generationConfig: {
     temperature: modelConfigs.dedup.temperature,


@@ -1,5 +1,5 @@
-import {GoogleGenerativeAI, SchemaType} from "@google/generative-ai";
-import { GEMINI_API_KEY, modelConfigs } from "../config";
+import { SchemaType } from "@google/generative-ai";
+import { modelConfigs, llmClient } from "../config";
 import { TokenTracker } from "../utils/token-tracker";
 import { ErrorAnalysisResponse } from '../types';
@@ -23,8 +23,7 @@ const responseSchema = {
   required: ["recap", "blame", "improvement"]
 };
 
-const genAI = new GoogleGenerativeAI(GEMINI_API_KEY);
-const model = genAI.getGenerativeModel({
+const model = llmClient.getGenerativeModel({
   model: modelConfigs.errorAnalyzer.model,
   generationConfig: {
     temperature: modelConfigs.errorAnalyzer.temperature,


@@ -1,5 +1,5 @@
-import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai";
-import { GEMINI_API_KEY, modelConfigs } from "../config";
+import { SchemaType } from "@google/generative-ai";
+import { modelConfigs, llmClient } from "../config";
 import { TokenTracker } from "../utils/token-tracker";
 import { EvaluationResponse } from '../types';
@@ -19,8 +19,7 @@ const responseSchema = {
   required: ["is_definitive", "reasoning"]
 };
 
-const genAI = new GoogleGenerativeAI(GEMINI_API_KEY);
-const model = genAI.getGenerativeModel({
+const model = llmClient.getGenerativeModel({
   model: modelConfigs.evaluator.model,
   generationConfig: {
     temperature: modelConfigs.evaluator.temperature,


@@ -1,5 +1,5 @@
-import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai";
-import { GEMINI_API_KEY, modelConfigs } from "../config";
+import { SchemaType } from "@google/generative-ai";
+import { modelConfigs, llmClient } from "../config";
 import { TokenTracker } from "../utils/token-tracker";
 import { SearchAction } from "../types";
@@ -26,8 +26,7 @@ const responseSchema = {
   required: ["think", "queries"]
 };
 
-const genAI = new GoogleGenerativeAI(GEMINI_API_KEY);
-const model = genAI.getGenerativeModel({
+const model = llmClient.getGenerativeModel({
   model: modelConfigs.queryRewriter.model,
   generationConfig: {
     temperature: modelConfigs.queryRewriter.temperature,
@@ -129,4 +128,4 @@ export async function rewriteQuery(action: SearchAction, tracker?: TokenTracker)
     console.error('Error in query rewriting:', error);
     throw error;
   }
-}
+}
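
Taken together, every tool now builds its model through the shared llmClient instead of constructing its own GoogleGenerativeAI instance. A minimal usage sketch of that interface (not part of this diff; the file name, schema, and prompt are invented, and it assumes LLM_PROVIDER=local plus LOCAL_LLM_HOSTNAME, LOCAL_LLM_PORT, and LOCAL_LLM_MODEL are set in .env when a local OpenAI-compatible server should be used, with Gemini as the fallback otherwise):

// demo.ts (hypothetical, placed next to config.ts)
import { SchemaType } from '@google/generative-ai';
import { llmClient, modelConfigs } from './config';

// Invented response schema, mirroring the pattern the tool modules use.
const responseSchema = {
  type: SchemaType.OBJECT,
  properties: {
    answer: { type: SchemaType.STRING },
  },
  required: ['answer'],
};

async function demo() {
  // Same call shape as the dedup/evaluator/query-rewriter tools after the refactor.
  const model = llmClient.getGenerativeModel({
    model: modelConfigs.evaluator.model,
    temperature: modelConfigs.evaluator.temperature,
    generationConfig: {
      temperature: modelConfigs.evaluator.temperature,
      responseMimeType: 'application/json',
      responseSchema,
    },
  });
  const result = await model.generateContent('Reply in JSON: what is 2 + 2?');
  console.log(JSON.parse(result.response.text()));
  console.log('total tokens:', result.response.usageMetadata.totalTokenCount);
}

demo().catch(console.error);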