mirror of
https://github.com/jina-ai/node-DeepResearch.git
synced 2026-03-22 15:39:06 +08:00
llm-provider: google cloud vertex
This commit is contained in:
@@ -3,34 +3,21 @@ import { ProxyAgent, setGlobalDispatcher } from 'undici';
|
||||
import { createGoogleGenerativeAI } from '@ai-sdk/google';
|
||||
import { createOpenAI, OpenAIProviderSettings } from '@ai-sdk/openai';
|
||||
import configJson from '../config.json';
|
||||
|
||||
// Load environment variables
|
||||
dotenv.config();
|
||||
|
||||
// Types
|
||||
export type LLMProvider = 'openai' | 'gemini';
|
||||
export type LLMProvider = 'openai' | 'gemini' | 'vertex';
|
||||
export type ToolName = keyof typeof configJson.models.gemini.tools;
|
||||
|
||||
// Type definitions for our config structure
|
||||
type EnvConfig = typeof configJson.env;
|
||||
|
||||
interface ProviderConfigBase {
|
||||
interface ProviderConfig {
|
||||
createClient: string;
|
||||
clientConfig?: Record<string, any>;
|
||||
}
|
||||
|
||||
interface OpenAIProviderConfig extends ProviderConfigBase {
|
||||
clientConfig: {
|
||||
compatibility: "strict" | "compatible";
|
||||
};
|
||||
}
|
||||
|
||||
interface GeminiProviderConfig extends ProviderConfigBase {}
|
||||
|
||||
type ProviderConfig = {
|
||||
openai: OpenAIProviderConfig;
|
||||
gemini: GeminiProviderConfig;
|
||||
};
|
||||
|
||||
// Environment setup
|
||||
const env: EnvConfig = { ...configJson.env };
|
||||
(Object.keys(env) as (keyof EnvConfig)[]).forEach(key => {
|
||||
@@ -69,7 +56,7 @@ export const LLM_PROVIDER: LLMProvider = (() => {
|
||||
})();
|
||||
|
||||
function isValidProvider(provider: string): provider is LLMProvider {
|
||||
return provider === 'openai' || provider === 'gemini';
|
||||
return provider === 'openai' || provider === 'gemini' || provider === 'vertex';
|
||||
}
|
||||
|
||||
interface ToolConfig {
|
||||
@@ -85,7 +72,7 @@ interface ToolOverrides {
|
||||
|
||||
// Get tool configuration
|
||||
export function getToolConfig(toolName: ToolName): ToolConfig {
|
||||
const providerConfig = configJson.models[LLM_PROVIDER];
|
||||
const providerConfig = configJson.models[LLM_PROVIDER === 'vertex' ? 'gemini' : LLM_PROVIDER];
|
||||
const defaultConfig = providerConfig.default;
|
||||
const toolOverrides = providerConfig.tools[toolName] as ToolOverrides;
|
||||
|
||||
@@ -103,7 +90,7 @@ export function getMaxTokens(toolName: ToolName): number {
|
||||
// Get model instance
|
||||
export function getModel(toolName: ToolName) {
|
||||
const config = getToolConfig(toolName);
|
||||
const providerConfig = configJson.providers[LLM_PROVIDER] as ProviderConfig[typeof LLM_PROVIDER];
|
||||
const providerConfig = (configJson.providers as Record<string, ProviderConfig | undefined>)[LLM_PROVIDER];
|
||||
|
||||
if (LLM_PROVIDER === 'openai') {
|
||||
if (!OPENAI_API_KEY) {
|
||||
@@ -112,7 +99,7 @@ export function getModel(toolName: ToolName) {
|
||||
|
||||
const opt: OpenAIProviderSettings = {
|
||||
apiKey: OPENAI_API_KEY,
|
||||
compatibility: (providerConfig as OpenAIProviderConfig).clientConfig.compatibility
|
||||
compatibility: providerConfig?.clientConfig?.compatibility
|
||||
};
|
||||
|
||||
if (OPENAI_BASE_URL) {
|
||||
@@ -122,6 +109,14 @@ export function getModel(toolName: ToolName) {
|
||||
return createOpenAI(opt)(config.model);
|
||||
}
|
||||
|
||||
if (LLM_PROVIDER === 'vertex') {
|
||||
const createVertex = require('@ai-sdk/google-vertex').createVertex;
|
||||
if (toolName === 'search-grounding') {
|
||||
return createVertex({ project: process.env.GCLOUD_PROJECT, ...providerConfig?.clientConfig })(config.model, { useSearchGrounding: true });
|
||||
}
|
||||
return createVertex({ project: process.env.GCLOUD_PROJECT, ...providerConfig?.clientConfig })(config.model);
|
||||
}
|
||||
|
||||
if (!GEMINI_API_KEY) {
|
||||
throw new Error('GEMINI_API_KEY not found');
|
||||
}
|
||||
@@ -150,7 +145,7 @@ const configSummary = {
|
||||
provider: SEARCH_PROVIDER
|
||||
},
|
||||
tools: Object.fromEntries(
|
||||
Object.keys(configJson.models[LLM_PROVIDER].tools).map(name => [
|
||||
Object.keys(configJson.models[LLM_PROVIDER === 'vertex' ? 'gemini' : LLM_PROVIDER].tools).map(name => [
|
||||
name,
|
||||
getToolConfig(name as ToolName)
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user