refactor: add safe obj generation (#60)

* fix: broken markdown footnote

* refactor: safe obj generation

* test: update token tracking assertions to match new implementation

Co-Authored-By: Han Xiao <han.xiao@jina.ai>

* refactor: safe obj generation

* chore: update readme

---------

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
This commit is contained in:
Han Xiao 2025-02-13 00:33:58 +08:00 committed by GitHub
parent 08c1dd04ca
commit bd77535dd9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 286 additions and 294 deletions

View File

@ -25,7 +25,7 @@ flowchart LR
```
Note that this project does *not* try to mimic what OpenAI or Gemini do with their deep research product. **We focus on finding the right answer with this loop cycle.** There is no plan to implement the structural article generation part. So if you want a service that can do deep searches and give you an answer, this is it. If you want a service that mimics long article writing like OpenAI/Gemini, **this isn't it.**
Unlike OpenAI and Gemini's Deep Research capabilities, we focus solely on **delivering accurate answers through our iterative process**. We don't optimize for long-form articles if you need quick, precise answers from deep search, you're in the right place. If you're looking for AI-generated reports like OpenAI/Gemini do, this isn't for you.
## Install
@ -195,12 +195,7 @@ Response format:
"usage": {
"prompt_tokens": 9,
"completion_tokens": 12,
"total_tokens": 21,
"completion_tokens_details": {
"reasoning_tokens": 0,
"accepted_prediction_tokens": 0,
"rejected_prediction_tokens": 0
}
"total_tokens": 21
}
}
```

View File

@ -32,13 +32,14 @@
"maxTokens": 8000
},
"tools": {
"search-grounding": { "temperature": 0 },
"searchGrounding": { "temperature": 0 },
"dedup": { "temperature": 0.1 },
"evaluator": {},
"errorAnalyzer": {},
"queryRewriter": { "temperature": 0.1 },
"agent": { "temperature": 0.7 },
"agentBeastMode": { "temperature": 0.7 }
"agentBeastMode": { "temperature": 0.7 },
"fallback": { "temperature": 0 }
}
},
"openai": {
@ -48,13 +49,14 @@
"maxTokens": 8000
},
"tools": {
"search-grounding": { "temperature": 0 },
"searchGrounding": { "temperature": 0 },
"dedup": { "temperature": 0.1 },
"evaluator": {},
"errorAnalyzer": {},
"queryRewriter": { "temperature": 0.1 },
"agent": { "temperature": 0.7 },
"agentBeastMode": { "temperature": 0.7 }
"agentBeastMode": { "temperature": 0.7 },
"fallback": { "temperature": 0 }
}
}
}

View File

@ -38,13 +38,14 @@
"maxTokens": 8000
},
"tools": {
"search-grounding": { "temperature": 0 },
"searchGrounding": { "temperature": 0 },
"dedup": { "temperature": 0.1 },
"evaluator": {},
"errorAnalyzer": {},
"queryRewriter": { "temperature": 0.1 },
"agent": { "temperature": 0.7 },
"agentBeastMode": { "temperature": 0.7 }
"agentBeastMode": { "temperature": 0.7 },
"fallback": { "temperature": 0 }
}
},
"openai": {
@ -54,13 +55,14 @@
"maxTokens": 8000
},
"tools": {
"search-grounding": { "temperature": 0 },
"searchGrounding": { "temperature": 0 },
"dedup": { "temperature": 0.1 },
"evaluator": {},
"errorAnalyzer": {},
"queryRewriter": { "temperature": 0.1 },
"agent": { "temperature": 0.7 },
"agentBeastMode": { "temperature": 0.7 }
"agentBeastMode": { "temperature": 0.7 },
"fallback": { "temperature": 0 }
}
}
}

View File

@ -9,8 +9,11 @@ describe('/v1/chat/completions', () => {
jest.setTimeout(120000); // Increase timeout for all tests in this suite
beforeEach(async () => {
// Set NODE_ENV to test to prevent server from auto-starting
// Set up test environment
process.env.NODE_ENV = 'test';
process.env.LLM_PROVIDER = 'openai'; // Use OpenAI provider for tests
process.env.OPENAI_API_KEY = 'test-key';
process.env.JINA_API_KEY = 'test-key';
// Clean up any existing secret
const existingSecretIndex = process.argv.findIndex(arg => arg.startsWith('--secret='));
@ -27,6 +30,10 @@ describe('/v1/chat/completions', () => {
});
afterEach(async () => {
// Clean up environment variables
delete process.env.OPENAI_API_KEY;
delete process.env.JINA_API_KEY;
// Clean up any remaining event listeners
const emitter = EventEmitter.prototype;
emitter.removeAllListeners();
@ -258,17 +265,10 @@ describe('/v1/chat/completions', () => {
expect(validResponse.body.usage).toMatchObject({
prompt_tokens: expect.any(Number),
completion_tokens: expect.any(Number),
total_tokens: expect.any(Number),
completion_tokens_details: {
reasoning_tokens: expect.any(Number),
accepted_prediction_tokens: expect.any(Number),
rejected_prediction_tokens: expect.any(Number)
}
total_tokens: expect.any(Number)
});
// Verify token counts are reasonable
expect(validResponse.body.usage.prompt_tokens).toBeGreaterThan(0);
expect(validResponse.body.usage.completion_tokens).toBeGreaterThan(0);
// Basic token tracking structure should be present
expect(validResponse.body.usage.total_tokens).toBe(
validResponse.body.usage.prompt_tokens + validResponse.body.usage.completion_tokens
);
@ -289,17 +289,10 @@ describe('/v1/chat/completions', () => {
expect(usage).toMatchObject({
prompt_tokens: expect.any(Number),
completion_tokens: expect.any(Number),
total_tokens: expect.any(Number),
completion_tokens_details: {
reasoning_tokens: expect.any(Number),
accepted_prediction_tokens: expect.any(Number),
rejected_prediction_tokens: expect.any(Number)
}
total_tokens: expect.any(Number)
});
// Verify token counts are reasonable
expect(usage.prompt_tokens).toBeGreaterThan(0);
expect(usage.completion_tokens).toBeGreaterThan(0);
// Basic token tracking structure should be present
expect(usage.total_tokens).toBe(
usage.prompt_tokens + usage.completion_tokens
);

View File

@ -1,8 +1,7 @@
import {z, ZodObject} from 'zod';
import {CoreAssistantMessage, CoreUserMessage, generateObject} from 'ai';
import {getModel, getMaxTokens, SEARCH_PROVIDER, STEP_SLEEP} from "./config";
import {CoreAssistantMessage, CoreUserMessage} from 'ai';
import {SEARCH_PROVIDER, STEP_SLEEP} from "./config";
import {readUrl, removeAllLineBreaks} from "./tools/read";
import {handleGenerateObjectError} from './utils/error-handling';
import fs from 'fs/promises';
import {SafeSearchType, search as duckSearch} from "duck-duck-scrape";
import {braveSearch} from "./tools/brave-search";
@ -17,6 +16,7 @@ import {TrackerContext} from "./types";
import {search} from "./tools/jina-search";
// import {grounding} from "./tools/grounding";
import {zodToJsonSchema} from "zod-to-json-schema";
import {ObjectGeneratorSafe} from "./utils/safe-generator";
async function sleep(ms: number) {
const seconds = Math.ceil(ms / 1000);
@ -364,23 +364,13 @@ export async function getResponse(question: string,
false
);
schema = getSchema(allowReflect, allowRead, allowAnswer, allowSearch)
const model = getModel('agent');
let object;
try {
const result = await generateObject({
model,
schema,
prompt,
maxTokens: getMaxTokens('agent')
});
object = result.object;
context.tokenTracker.trackUsage('agent', result.usage);
} catch (error) {
const result = await handleGenerateObjectError<StepAction>(error);
object = result.object;
context.tokenTracker.trackUsage('agent', result.usage);
}
thisStep = object as StepAction;
const generator = new ObjectGeneratorSafe(context.tokenTracker);
const result = await generator.generateObject({
model: 'agent',
schema,
prompt,
});
thisStep = result.object as StepAction;
// print allowed and chose action
const actionsStr = [allowSearch, allowRead, allowAnswer, allowReflect].map((a, i) => a ? ['search', 'read', 'answer', 'reflect'][i] : null).filter(a => a).join(', ');
console.log(`${thisStep.action} <- [${actionsStr}]`);
@ -464,6 +454,7 @@ ${evaluation.think}
});
if (errorAnalysis.questionsToAnswer) {
// reranker? maybe
gaps.push(...errorAnalysis.questionsToAnswer.slice(0, 2));
allQuestions.push(...errorAnalysis.questionsToAnswer.slice(0, 2));
gaps.push(question.trim()); // always keep the original question in the gaps
@ -510,8 +501,8 @@ ${newGapQuestions.map((q: string) => `- ${q}`).join('\n')}
You will now figure out the answers to these sub-questions and see if they can help you find the answer to the original question.
`);
gaps.push(...newGapQuestions);
allQuestions.push(...newGapQuestions);
gaps.push(...newGapQuestions.slice(0, 2));
allQuestions.push(...newGapQuestions.slice(0, 2));
gaps.push(question.trim()); // always keep the original question in the gaps
} else {
diaryContext.push(`
@ -708,24 +699,15 @@ You decided to think out of the box or cut from a completely different angle.`);
);
schema = getSchema(false, false, true, false);
const model = getModel('agentBeastMode');
let object;
try {
const result = await generateObject({
model,
schema: schema,
prompt,
maxTokens: getMaxTokens('agentBeastMode')
});
object = result.object;
context.tokenTracker.trackUsage('agent', result.usage);
} catch (error) {
const result = await handleGenerateObjectError<StepAction>(error);
object = result.object;
context.tokenTracker.trackUsage('agent', result.usage);
}
const generator = new ObjectGeneratorSafe(context.tokenTracker);
const result = await generator.generateObject({
model: 'agentBeastMode',
schema,
prompt,
});
await storeContext(prompt, schema, [allContext, allKeywords, allQuestions, allKnowledge], totalStep);
thisStep = object as StepAction;
thisStep = result.object as StepAction;
context.actionTracker.trackAction({totalStep, thisStep, gaps, badAttempts});
console.log(thisStep)

View File

@ -111,7 +111,7 @@ export function getModel(toolName: ToolName) {
if (LLM_PROVIDER === 'vertex') {
const createVertex = require('@ai-sdk/google-vertex').createVertex;
if (toolName === 'search-grounding') {
if (toolName === 'searchGrounding') {
return createVertex({ project: process.env.GCLOUD_PROJECT, ...providerConfig?.clientConfig })(config.model, { useSearchGrounding: true });
}
return createVertex({ project: process.env.GCLOUD_PROJECT, ...providerConfig?.clientConfig })(config.model);
@ -121,7 +121,7 @@ export function getModel(toolName: ToolName) {
throw new Error('GEMINI_API_KEY not found');
}
if (toolName === 'search-grounding') {
if (toolName === 'searchGrounding') {
return createGoogleGenerativeAI({ apiKey: GEMINI_API_KEY })(config.model, { useSearchGrounding: true });
}
return createGoogleGenerativeAI({ apiKey: GEMINI_API_KEY })(config.model);

View File

@ -1,11 +1,7 @@
import {z} from 'zod';
import {generateObject} from 'ai';
import {getModel, getMaxTokens} from "../config";
import {TokenTracker} from "../utils/token-tracker";
import {handleGenerateObjectError} from '../utils/error-handling';
import type {DedupResponse} from '../types';
import {ObjectGeneratorSafe} from "../utils/safe-generator";
const model = getModel('dedup');
const responseSchema = z.object({
think: z.string().describe('Strategic reasoning about the overall deduplication approach'),
@ -65,31 +61,29 @@ SetA: ${JSON.stringify(newQueries)}
SetB: ${JSON.stringify(existingQueries)}`;
}
export async function dedupQueries(newQueries: string[], existingQueries: string[], tracker?: TokenTracker): Promise<{ unique_queries: string[] }> {
try {
const prompt = getPrompt(newQueries, existingQueries);
let object;
let usage;
try {
const result = await generateObject({
model,
schema: responseSchema,
prompt,
maxTokens: getMaxTokens('dedup')
});
object = result.object;
usage = result.usage
} catch (error) {
const result = await handleGenerateObjectError<DedupResponse>(error);
object = result.object;
usage = result.usage
}
console.log('Dedup:', object.unique_queries);
(tracker || new TokenTracker()).trackUsage('dedup', usage);
return {unique_queries: object.unique_queries};
const TOOL_NAME = 'dedup';
export async function dedupQueries(
newQueries: string[],
existingQueries: string[],
tracker?: TokenTracker
): Promise<{ unique_queries: string[] }> {
try {
const generator = new ObjectGeneratorSafe(tracker);
const prompt = getPrompt(newQueries, existingQueries);
const result = await generator.generateObject({
model: TOOL_NAME,
schema: responseSchema,
prompt,
});
console.log(TOOL_NAME, result.object.unique_queries);
return {unique_queries: result.object.unique_queries};
} catch (error) {
console.error('Error in deduplication analysis:', error);
console.error(`Error in ${TOOL_NAME}`, error);
throw error;
}
}
}

View File

@ -1,11 +1,8 @@
import {z} from 'zod';
import {generateObject} from 'ai';
import {getModel, getMaxTokens} from "../config";
import {TokenTracker} from "../utils/token-tracker";
import {ErrorAnalysisResponse} from '../types';
import {handleGenerateObjectError} from '../utils/error-handling';
import {ObjectGeneratorSafe} from "../utils/safe-generator";
const model = getModel('errorAnalyzer');
const responseSchema = z.object({
recap: z.string().describe('Recap of the actions taken and the steps conducted'),
@ -111,33 +108,27 @@ ${diaryContext.join('\n')}
`;
}
export async function analyzeSteps(diaryContext: string[], tracker?: TokenTracker): Promise<{ response: ErrorAnalysisResponse }> {
const TOOL_NAME = 'errorAnalyzer';
export async function analyzeSteps(
diaryContext: string[],
tracker?: TokenTracker
): Promise<{ response: ErrorAnalysisResponse }> {
try {
const generator = new ObjectGeneratorSafe(tracker);
const prompt = getPrompt(diaryContext);
let object;
let usage;
try {
const result = await generateObject({
model,
schema: responseSchema,
prompt,
maxTokens: getMaxTokens('errorAnalyzer')
});
object = result.object;
usage = result.usage;
} catch (error) {
const result = await handleGenerateObjectError<ErrorAnalysisResponse>(error);
object = result.object;
usage = result.usage;
}
console.log('Error analysis:', {
is_valid: !object.blame,
reason: object.blame || 'No issues found'
const result = await generator.generateObject({
model: TOOL_NAME,
schema: responseSchema,
prompt,
});
(tracker || new TokenTracker()).trackUsage('error-analyzer', usage);
return {response: object};
console.log(TOOL_NAME, result.object);
return { response: result.object };
} catch (error) {
console.error('Error in answer evaluation:', error);
console.error(`Error in ${TOOL_NAME}`, error);
throw error;
}
}
}

View File

@ -1,12 +1,10 @@
import {z} from 'zod';
import {generateObject, GenerateObjectResult} from 'ai';
import {getModel, getMaxTokens} from "../config";
import {GenerateObjectResult} from 'ai';
import {TokenTracker} from "../utils/token-tracker";
import {AnswerAction, EvaluationResponse} from '../types';
import {handleGenerateObjectError} from '../utils/error-handling';
import {readUrl, removeAllLineBreaks} from "./read";
import {ObjectGeneratorSafe} from "../utils/safe-generator";
const model = getModel('evaluator');
type EvaluationType = 'definitive' | 'freshness' | 'plurality' | 'attribution';
@ -371,19 +369,21 @@ Now evaluate this question:
Question: ${JSON.stringify(question)}`;
}
const TOOL_NAME = 'evaluator';
export async function evaluateQuestion(
question: string,
tracker?: TokenTracker
): Promise<EvaluationType[]> {
try {
const result = await generateObject({
model: getModel('evaluator'),
const generator = new ObjectGeneratorSafe(tracker);
const result = await generator.generateObject({
model: TOOL_NAME,
schema: questionEvaluationSchema,
prompt: getQuestionEvaluationPrompt(question),
maxTokens: getMaxTokens('evaluator')
});
(tracker || new TokenTracker()).trackUsage('evaluator', result.usage);
console.log('Question Evaluation:', result.object);
// Always include definitive in types
@ -391,49 +391,38 @@ export async function evaluateQuestion(
if (result.object.needsFreshness) types.push('freshness');
if (result.object.needsPlurality) types.push('plurality');
console.log('Question Metrics:', types)
console.log('Question Metrics:', types);
// Always evaluate definitive first, then freshness (if needed), then plurality (if needed)
return types;
} catch (error) {
const errorResult = await handleGenerateObjectError<EvaluationResponse>(error);
(tracker || new TokenTracker()).trackUsage('evaluator', errorResult.usage);
console.error('Error in question evaluation:', error);
// Default to all evaluation types in case of error
return ['definitive', 'freshness', 'plurality'];
}
}
// Helper function to handle common evaluation logic
async function performEvaluation(
async function performEvaluation<T>(
evaluationType: EvaluationType,
params: {
model: any;
schema: z.ZodType<any>;
schema: z.ZodType<T>;
prompt: string;
maxTokens: number;
},
tracker?: TokenTracker
): Promise<GenerateObjectResult<any>> {
try {
const result = await generateObject({
model: params.model,
schema: params.schema,
prompt: params.prompt,
maxTokens: params.maxTokens
});
): Promise<GenerateObjectResult<T>> {
const generator = new ObjectGeneratorSafe(tracker);
(tracker || new TokenTracker()).trackUsage('evaluator', result.usage);
console.log(`${evaluationType} Evaluation:`, result.object);
const result = await generator.generateObject({
model: TOOL_NAME,
schema: params.schema,
prompt: params.prompt,
});
return result;
} catch (error) {
const errorResult = await handleGenerateObjectError<any>(error);
(tracker || new TokenTracker()).trackUsage('evaluator', errorResult.usage);
return {
object: errorResult.object,
usage: errorResult.usage
} as GenerateObjectResult<any>;
}
console.log(`${evaluationType} ${TOOL_NAME}`, result.object);
return result as GenerateObjectResult<any>;
}
@ -452,84 +441,70 @@ export async function evaluateAnswer(
}
for (const evaluationType of evaluationOrder) {
try {
switch (evaluationType) {
case 'attribution': {
// Safely handle references and ensure we have content
const urls = action.references?.map(ref => ref.url) ?? [];
const uniqueURLs = [...new Set(urls)];
const allKnowledge = await fetchSourceContent(uniqueURLs, tracker);
switch (evaluationType) {
case 'attribution': {
// Safely handle references and ensure we have content
const urls = action.references?.map(ref => ref.url) ?? [];
const uniqueURLs = [...new Set(urls)];
const allKnowledge = await fetchSourceContent(uniqueURLs, tracker);
if (!allKnowledge.trim()) {
return {
response: {
pass: false,
think: "The answer does not provide any valid attribution references that could be verified. No accessible source content was found to validate the claims made in the answer.",
type: 'attribution',
}
};
}
result = await performEvaluation(
'attribution',
{
model,
schema: attributionSchema,
prompt: getAttributionPrompt(question, action.answer, allKnowledge),
maxTokens: getMaxTokens('evaluator')
},
tracker
);
break;
if (!allKnowledge.trim()) {
return {
response: {
pass: false,
think: "The answer does not provide any valid attribution references that could be verified. No accessible source content was found to validate the claims made in the answer.",
type: 'attribution',
}
};
}
case 'definitive':
result = await performEvaluation(
'definitive',
{
model,
schema: definitiveSchema,
prompt: getDefinitivePrompt(question, action.answer),
maxTokens: getMaxTokens('evaluator')
},
tracker
);
break;
case 'freshness':
result = await performEvaluation(
'freshness',
{
model,
schema: freshnessSchema,
prompt: getFreshnessPrompt(question, action.answer, new Date().toISOString()),
maxTokens: getMaxTokens('evaluator')
},
tracker
);
break;
case 'plurality':
result = await performEvaluation(
'plurality',
{
model,
schema: pluralitySchema,
prompt: getPluralityPrompt(question, action.answer),
maxTokens: getMaxTokens('evaluator')
},
tracker
);
break;
result = await performEvaluation(
'attribution',
{
schema: attributionSchema,
prompt: getAttributionPrompt(question, action.answer, allKnowledge),
},
tracker
);
break;
}
if (!result?.object.pass) {
return {response: result.object};
}
} catch (error) {
const errorResult = await handleGenerateObjectError<EvaluationResponse>(error);
(tracker || new TokenTracker()).trackUsage('evaluator', errorResult.usage);
return {response: errorResult.object};
case 'definitive':
result = await performEvaluation(
'definitive',
{
schema: definitiveSchema,
prompt: getDefinitivePrompt(question, action.answer),
},
tracker
);
break;
case 'freshness':
result = await performEvaluation(
'freshness',
{
schema: freshnessSchema,
prompt: getFreshnessPrompt(question, action.answer, new Date().toISOString()),
},
tracker
);
break;
case 'plurality':
result = await performEvaluation(
'plurality',
{
schema: pluralitySchema,
prompt: getPluralityPrompt(question, action.answer),
},
tracker
);
break;
}
if (!result?.object.pass) {
return {response: result.object};
}
}

View File

@ -3,7 +3,7 @@ import {getModel} from "../config";
import { GoogleGenerativeAIProviderMetadata } from '@ai-sdk/google';
import {TokenTracker} from "../utils/token-tracker";
const model = getModel('search-grounding')
const model = getModel('searchGrounding')
export async function grounding(query: string, tracker?: TokenTracker): Promise<string> {
try {

View File

@ -1,11 +1,8 @@
import { z } from 'zod';
import { generateObject } from 'ai';
import { getModel, getMaxTokens } from "../config";
import { TokenTracker } from "../utils/token-tracker";
import { SearchAction, KeywordsResponse } from '../types';
import { handleGenerateObjectError } from '../utils/error-handling';
import { SearchAction } from '../types';
import {ObjectGeneratorSafe} from "../utils/safe-generator";
const model = getModel('queryRewriter');
const responseSchema = z.object({
think: z.string().describe('Strategic reasoning about query complexity and search approach'),
@ -93,30 +90,23 @@ Intention: ${action.think}
`;
}
const TOOL_NAME = 'queryRewriter';
export async function rewriteQuery(action: SearchAction, tracker?: TokenTracker): Promise<{ queries: string[] }> {
try {
const generator = new ObjectGeneratorSafe(tracker);
const prompt = getPrompt(action);
let object;
let usage;
try {
const result = await generateObject({
model,
schema: responseSchema,
prompt,
maxTokens: getMaxTokens('queryRewriter')
});
object = result.object;
usage = result.usage;
} catch (error) {
const result = await handleGenerateObjectError<KeywordsResponse>(error);
object = result.object;
usage = result.usage;
}
console.log('Query rewriter:', object.queries);
(tracker || new TokenTracker()).trackUsage('query-rewriter', usage);
return { queries: object.queries };
const result = await generator.generateObject({
model: TOOL_NAME,
schema: responseSchema,
prompt,
});
console.log(TOOL_NAME, result.object.queries);
return { queries: result.object.queries };
} catch (error) {
console.error('Error in query rewriting:', error);
console.error(`Error in ${TOOL_NAME}`, error);
throw error;
}
}
}

View File

@ -181,11 +181,6 @@ export interface ChatCompletionResponse {
prompt_tokens: number;
completion_tokens: number;
total_tokens: number;
completion_tokens_details?: {
reasoning_tokens: number;
accepted_prediction_tokens: number;
rejected_prediction_tokens: number;
};
};
}

View File

@ -1,22 +0,0 @@
import {LanguageModelUsage, NoObjectGeneratedError} from "ai";
export interface GenerateObjectResult<T> {
object: T;
usage: LanguageModelUsage;
}
export async function handleGenerateObjectError<T>(error: unknown): Promise<GenerateObjectResult<T>> {
if (NoObjectGeneratedError.isInstance(error)) {
console.error('Object not generated according to the schema, fallback to manual parsing');
try {
const partialResponse = JSON.parse((error as any).text);
return {
object: partialResponse as T,
usage: (error as any).usage
};
} catch (parseError) {
throw error;
}
}
throw error;
}

View File

@ -0,0 +1,95 @@
import { z } from 'zod';
import {generateObject, LanguageModelUsage, NoObjectGeneratedError} from "ai";
import {TokenTracker} from "./token-tracker";
import {getModel, ToolName, getToolConfig} from "../config";
interface GenerateObjectResult<T> {
object: T;
usage: LanguageModelUsage;
}
interface GenerateOptions<T> {
model: ToolName;
schema: z.ZodType<T>;
prompt: string;
}
export class ObjectGeneratorSafe {
private tokenTracker: TokenTracker;
constructor(tokenTracker?: TokenTracker) {
this.tokenTracker = tokenTracker || new TokenTracker();
}
async generateObject<T>(options: GenerateOptions<T>): Promise<GenerateObjectResult<T>> {
const {
model,
schema,
prompt,
} = options;
try {
// Primary attempt with main model
const result = await generateObject({
model: getModel(model),
schema,
prompt,
maxTokens: getToolConfig(model).maxTokens,
temperature: getToolConfig(model).temperature,
});
this.tokenTracker.trackUsage(model, result.usage);
return result;
} catch (error) {
// First fallback: Try manual JSON parsing of the error response
try {
const errorResult = await this.handleGenerateObjectError<T>(error);
this.tokenTracker.trackUsage(model, errorResult.usage);
return errorResult;
} catch (parseError) {
// Second fallback: Try with fallback model if provided
const fallbackModel = getModel('fallback');
if (NoObjectGeneratedError.isInstance(parseError)) {
const failedOutput = (parseError as any).text;
console.error(`${model} failed on object generation ${failedOutput} -> manual parsing failed again -> trying fallback model`, fallbackModel);
try {
const fallbackResult = await generateObject({
model: fallbackModel,
schema,
prompt: `Extract the desired information from this text: \n ${failedOutput}`,
maxTokens: getToolConfig('fallback').maxTokens,
temperature: getToolConfig('fallback').temperature,
});
this.tokenTracker.trackUsage(model, fallbackResult.usage);
return fallbackResult;
} catch (fallbackError) {
// If fallback model also fails, try parsing its error response
return await this.handleGenerateObjectError<T>(fallbackError);
}
}
// If no fallback model or all attempts failed, throw the original error
throw error;
}
}
}
private async handleGenerateObjectError<T>(error: unknown): Promise<GenerateObjectResult<T>> {
if (NoObjectGeneratedError.isInstance(error)) {
console.error('Object not generated according to schema, fallback to manual JSON parsing');
try {
const partialResponse = JSON.parse((error as any).text);
return {
object: partialResponse as T,
usage: (error as any).usage
};
} catch (parseError) {
throw error;
}
}
throw error;
}
}