llm-provider: google cloud vertex

This commit is contained in:
Yanlong Wang
2025-02-12 18:53:07 +08:00
parent d4167e81d6
commit 44530a4760
7 changed files with 440 additions and 29 deletions

View File

@@ -3,34 +3,21 @@ import { ProxyAgent, setGlobalDispatcher } from 'undici';
import { createGoogleGenerativeAI } from '@ai-sdk/google';
import { createOpenAI, OpenAIProviderSettings } from '@ai-sdk/openai';
import configJson from '../config.json';
// Load environment variables
dotenv.config();
// Types
export type LLMProvider = 'openai' | 'gemini';
export type LLMProvider = 'openai' | 'gemini' | 'vertex';
export type ToolName = keyof typeof configJson.models.gemini.tools;
// Type definitions for our config structure
type EnvConfig = typeof configJson.env;
interface ProviderConfigBase {
interface ProviderConfig {
createClient: string;
clientConfig?: Record<string, any>;
}
interface OpenAIProviderConfig extends ProviderConfigBase {
clientConfig: {
compatibility: "strict" | "compatible";
};
}
interface GeminiProviderConfig extends ProviderConfigBase {}
type ProviderConfig = {
openai: OpenAIProviderConfig;
gemini: GeminiProviderConfig;
};
// Environment setup
const env: EnvConfig = { ...configJson.env };
(Object.keys(env) as (keyof EnvConfig)[]).forEach(key => {
@@ -69,7 +56,7 @@ export const LLM_PROVIDER: LLMProvider = (() => {
})();
function isValidProvider(provider: string): provider is LLMProvider {
return provider === 'openai' || provider === 'gemini';
return provider === 'openai' || provider === 'gemini' || provider === 'vertex';
}
interface ToolConfig {
@@ -85,7 +72,7 @@ interface ToolOverrides {
// Get tool configuration
export function getToolConfig(toolName: ToolName): ToolConfig {
const providerConfig = configJson.models[LLM_PROVIDER];
const providerConfig = configJson.models[LLM_PROVIDER === 'vertex' ? 'gemini' : LLM_PROVIDER];
const defaultConfig = providerConfig.default;
const toolOverrides = providerConfig.tools[toolName] as ToolOverrides;
@@ -103,7 +90,7 @@ export function getMaxTokens(toolName: ToolName): number {
// Get model instance
export function getModel(toolName: ToolName) {
const config = getToolConfig(toolName);
const providerConfig = configJson.providers[LLM_PROVIDER] as ProviderConfig[typeof LLM_PROVIDER];
const providerConfig = (configJson.providers as Record<string, ProviderConfig | undefined>)[LLM_PROVIDER];
if (LLM_PROVIDER === 'openai') {
if (!OPENAI_API_KEY) {
@@ -112,7 +99,7 @@ export function getModel(toolName: ToolName) {
const opt: OpenAIProviderSettings = {
apiKey: OPENAI_API_KEY,
compatibility: (providerConfig as OpenAIProviderConfig).clientConfig.compatibility
compatibility: providerConfig?.clientConfig?.compatibility
};
if (OPENAI_BASE_URL) {
@@ -122,6 +109,14 @@ export function getModel(toolName: ToolName) {
return createOpenAI(opt)(config.model);
}
if (LLM_PROVIDER === 'vertex') {
const createVertex = require('@ai-sdk/google-vertex').createVertex;
if (toolName === 'search-grounding') {
return createVertex({ project: process.env.GCLOUD_PROJECT, ...providerConfig?.clientConfig })(config.model, { useSearchGrounding: true });
}
return createVertex({ project: process.env.GCLOUD_PROJECT, ...providerConfig?.clientConfig })(config.model);
}
if (!GEMINI_API_KEY) {
throw new Error('GEMINI_API_KEY not found');
}
@@ -150,7 +145,7 @@ const configSummary = {
provider: SEARCH_PROVIDER
},
tools: Object.fromEntries(
Object.keys(configJson.models[LLM_PROVIDER].tools).map(name => [
Object.keys(configJson.models[LLM_PROVIDER === 'vertex' ? 'gemini' : LLM_PROVIDER].tools).map(name => [
name,
getToolConfig(name as ToolName)
])