jina-ai: billing for saas service (#55)

* wip: jina billing

* wip

* fix: build issues

* ci: cd gh action

* fix: make ci happy
This commit is contained in:
Yanlong Wang
2025-02-11 18:27:15 +08:00
committed by GitHub
parent c4639a2e92
commit 8af35c6640
26 changed files with 6150 additions and 647 deletions

View File

@@ -0,0 +1,347 @@
import {
Also, AuthenticationFailedError, AuthenticationRequiredError,
DownstreamServiceFailureError, RPC_CALL_ENVIRONMENT,
ArrayOf, AutoCastable, Prop
} from 'civkit/civ-rpc';
import { parseJSONText } from 'civkit/vectorize';
import { htmlEscape } from 'civkit/escape';
import { marshalErrorLike } from 'civkit/lang';
import type express from 'express';
import logger from '../lib/logger';
import { AsyncLocalContext } from '../lib/async-context';
import { InjectProperty } from '../lib/registry';
import { JinaEmbeddingsDashboardHTTP } from '../lib/billing';
import envConfig from '../lib/env-config';
import { FirestoreRecord } from '../lib/firestore';
import _ from 'lodash';
import { RateLimitDesc } from '../rate-limit';
export class JinaWallet extends AutoCastable {
@Prop({
default: ''
})
user_id!: string;
@Prop({
default: 0
})
trial_balance!: number;
@Prop()
trial_start?: Date;
@Prop()
trial_end?: Date;
@Prop({
default: 0
})
regular_balance!: number;
@Prop({
default: 0
})
total_balance!: number;
}
export class JinaEmbeddingsTokenAccount extends FirestoreRecord {
static override collectionName = 'embeddingsTokenAccounts';
override _id!: string;
@Prop({
required: true
})
user_id!: string;
@Prop({
nullable: true,
type: String,
})
email?: string;
@Prop({
nullable: true,
type: String,
})
full_name?: string;
@Prop({
nullable: true,
type: String,
})
customer_id?: string;
@Prop({
nullable: true,
type: String,
})
avatar_url?: string;
// Not keeping sensitive info for now
// @Prop()
// billing_address?: object;
// @Prop()
// payment_method?: object;
@Prop({
required: true
})
wallet!: JinaWallet;
@Prop({
type: Object
})
metadata?: { [k: string]: any; };
@Prop({
defaultFactory: () => new Date()
})
lastSyncedAt!: Date;
@Prop({
dictOf: [ArrayOf(RateLimitDesc)]
})
customRateLimits?: { [k: string]: RateLimitDesc[]; };
static patchedFields = [
];
static override from(input: any) {
for (const field of this.patchedFields) {
if (typeof input[field] === 'string') {
input[field] = parseJSONText(input[field]);
}
}
return super.from(input) as JinaEmbeddingsTokenAccount;
}
override degradeForFireStore() {
const copy: any = {
...this,
wallet: { ...this.wallet },
// Firebase disability
customRateLimits: _.mapValues(this.customRateLimits, (v) => v.map((x) => ({ ...x }))),
};
for (const field of (this.constructor as typeof JinaEmbeddingsTokenAccount).patchedFields) {
if (typeof copy[field] === 'object') {
copy[field] = JSON.stringify(copy[field]) as any;
}
}
return copy;
}
[k: string]: any;
}
const authDtoLogger = logger.child({ service: 'JinaAuthDTO' });
export interface FireBaseHTTPCtx {
req: express.Request,
res: express.Response,
}
const THE_VERY_SAME_JINA_EMBEDDINGS_CLIENT = new JinaEmbeddingsDashboardHTTP(envConfig.JINA_EMBEDDINGS_DASHBOARD_API_KEY);
@Also({
openapi: {
operation: {
parameters: {
'Authorization': {
description: htmlEscape`Jina Token for authentication.\n\n` +
htmlEscape`- Member of <JinaEmbeddingsAuthDTO>\n\n` +
`- Authorization: Bearer {YOUR_JINA_TOKEN}`
,
in: 'header',
schema: {
anyOf: [
{ type: 'string', format: 'token' }
]
}
}
}
}
}
})
export class JinaEmbeddingsAuthDTO extends AutoCastable {
uid?: string;
bearerToken?: string;
user?: JinaEmbeddingsTokenAccount;
@InjectProperty(AsyncLocalContext)
ctxMgr!: AsyncLocalContext;
jinaEmbeddingsDashboard = THE_VERY_SAME_JINA_EMBEDDINGS_CLIENT;
static override from(input: any) {
const instance = super.from(input) as JinaEmbeddingsAuthDTO;
const ctx = input[RPC_CALL_ENVIRONMENT];
const req = (ctx.rawRequest || ctx.req) as express.Request | undefined;
if (req) {
const authorization = req.get('authorization');
if (authorization) {
const authToken = authorization.split(' ')[1] || authorization;
instance.bearerToken = authToken;
}
}
if (!instance.bearerToken && input._token) {
instance.bearerToken = input._token;
}
return instance;
}
async getBrief(ignoreCache?: boolean | string) {
if (!this.bearerToken) {
throw new AuthenticationRequiredError({
message: 'Absence of bearer token'
});
}
let account;
try {
account = await JinaEmbeddingsTokenAccount.fromFirestore(this.bearerToken);
} catch (err) {
// FireStore would not accept any string as input and may throw if not happy with it
void 0;
}
const age = account?.lastSyncedAt ? Date.now() - account.lastSyncedAt.getTime() : Infinity;
if (account && !ignoreCache) {
if (account && age < 180_000) {
this.user = account;
this.uid = this.user?.user_id;
return account;
}
}
try {
const r = await this.jinaEmbeddingsDashboard.validateToken(this.bearerToken);
const brief = r.data;
const draftAccount = JinaEmbeddingsTokenAccount.from({
...account, ...brief, _id: this.bearerToken,
lastSyncedAt: new Date()
});
await JinaEmbeddingsTokenAccount.save(draftAccount.degradeForFireStore(), undefined, { merge: true });
this.user = draftAccount;
this.uid = this.user?.user_id;
return draftAccount;
} catch (err: any) {
authDtoLogger.warn(`Failed to get user brief: ${err}`, { err: marshalErrorLike(err) });
if (err?.status === 401) {
throw new AuthenticationFailedError({
message: 'Invalid bearer token'
});
}
if (account) {
this.user = account;
this.uid = this.user?.user_id;
return account;
}
throw new DownstreamServiceFailureError(`Failed to authenticate: ${err}`);
}
}
async reportUsage(tokenCount: number, mdl: string, endpoint: string = '/encode') {
const user = await this.assertUser();
const uid = user.user_id;
user.wallet.total_balance -= tokenCount;
return this.jinaEmbeddingsDashboard.reportUsage(this.bearerToken!, {
model_name: mdl,
api_endpoint: endpoint,
consumer: {
id: uid,
user_id: uid,
},
usage: {
total_tokens: tokenCount
},
labels: {
model_name: mdl
}
}).then((r) => {
JinaEmbeddingsTokenAccount.COLLECTION.doc(this.bearerToken!)
.update({ 'wallet.total_balance': JinaEmbeddingsTokenAccount.OPS.increment(-tokenCount) })
.catch((err) => {
authDtoLogger.warn(`Failed to update cache for ${uid}: ${err}`, { err: marshalErrorLike(err) });
});
return r;
}).catch((err) => {
user.wallet.total_balance += tokenCount;
authDtoLogger.warn(`Failed to report usage for ${uid}: ${err}`, { err: marshalErrorLike(err) });
});
}
async solveUID() {
if (this.uid) {
this.ctxMgr.set('uid', this.uid);
return this.uid;
}
if (this.bearerToken) {
await this.getBrief();
this.ctxMgr.set('uid', this.uid);
return this.uid;
}
return undefined;
}
async assertUID() {
const uid = await this.solveUID();
if (!uid) {
throw new AuthenticationRequiredError('Authentication failed');
}
return uid;
}
async assertUser() {
if (this.user) {
return this.user;
}
await this.getBrief();
return this.user!;
}
getRateLimits(...tags: string[]) {
const descs = tags.map((x) => this.user?.customRateLimits?.[x] || []).flat().filter((x) => x.isEffective());
if (descs.length) {
return descs;
}
return undefined;
}
}

View File

@@ -0,0 +1,9 @@
import { GlobalAsyncContext } from 'civkit/async-context';
import { container, singleton } from 'tsyringe';
@singleton()
export class AsyncLocalContext extends GlobalAsyncContext {}
const instance = container.resolve(AsyncLocalContext);
Reflect.set(process, 'asyncLocalContext', instance);
export default instance;

102
jina-ai/src/lib/billing.ts Normal file
View File

@@ -0,0 +1,102 @@
import { HTTPService } from 'civkit';
import _ from 'lodash';
export interface JinaWallet {
trial_balance: number;
trial_start: Date;
trial_end: Date;
regular_balance: number;
total_balance: number;
}
export interface JinaUserBrief {
user_id: string;
email: string | null;
full_name: string | null;
customer_id: string | null;
avatar_url?: string;
billing_address: Partial<{
address: string;
city: string;
state: string;
country: string;
postal_code: string;
}>;
payment_method: Partial<{
brand: string;
last4: string;
exp_month: number;
exp_year: number;
}>;
wallet: JinaWallet;
metadata: {
[k: string]: any;
};
}
export interface JinaUsageReport {
model_name: string;
api_endpoint: string;
consumer: {
user_id: string;
customer_plan?: string;
[k: string]: any;
};
usage: {
total_tokens: number;
};
labels: {
user_type?: string;
model_name?: string;
[k: string]: any;
};
}
export class JinaEmbeddingsDashboardHTTP extends HTTPService {
name = 'JinaEmbeddingsDashboardHTTP';
constructor(
public apiKey: string,
public baseUri: string = 'https://embeddings-dashboard-api.jina.ai/api'
) {
super(baseUri);
this.baseOptions.timeout = 30_000; // 30 sec
}
async authorization(token: string) {
const r = await this.get<JinaUserBrief>('/v1/authorization', {
headers: {
Authorization: `Bearer ${token}`
},
responseType: 'json',
});
return r;
}
async validateToken(token: string) {
const r = await this.getWithSearchParams<JinaUserBrief>('/v1/api_key/user', {
api_key: token,
}, {
responseType: 'json',
});
return r;
}
async reportUsage(token: string, query: JinaUsageReport) {
const r = await this.postJson('/v1/usage', query, {
headers: {
Authorization: `Bearer ${token}`,
'x-api-key': this.apiKey,
},
responseType: 'text',
});
return r;
}
}

View File

@@ -0,0 +1,59 @@
import { container, singleton } from 'tsyringe';
export const SPECIAL_COMBINED_ENV_KEY = 'ENV_COMBINED';
const CONF_ENV = [
'OPENAI_API_KEY',
'ANTHROPIC_API_KEY',
'REPLICATE_API_KEY',
'GOOGLE_AI_STUDIO_API_KEY',
'JINA_EMBEDDINGS_API_KEY',
'JINA_EMBEDDINGS_DASHBOARD_API_KEY',
'BRAVE_SEARCH_API_KEY',
] as const;
@singleton()
export class EnvConfig {
dynamic!: Record<string, string>;
combined: Record<string, string> = {};
originalEnv: Record<string, string | undefined> = { ...process.env };
constructor() {
if (process.env[SPECIAL_COMBINED_ENV_KEY]) {
Object.assign(this.combined, JSON.parse(
Buffer.from(process.env[SPECIAL_COMBINED_ENV_KEY]!, 'base64').toString('utf-8')
));
delete process.env[SPECIAL_COMBINED_ENV_KEY];
}
// Static config
for (const x of CONF_ENV) {
const s = this.combined[x] || process.env[x] || '';
Reflect.set(this, x, s);
if (x in process.env) {
delete process.env[x];
}
}
// Dynamic config
this.dynamic = new Proxy({
get: (_target: any, prop: string) => {
return this.combined[prop] || process.env[prop] || '';
}
}, {}) as any;
}
}
// eslint-disable-next-line @typescript-eslint/no-empty-interface
export interface EnvConfig extends Record<typeof CONF_ENV[number], string> { }
const instance = container.resolve(EnvConfig);
export default instance;

70
jina-ai/src/lib/errors.ts Normal file
View File

@@ -0,0 +1,70 @@
import { ApplicationError, Prop, RPC_TRANSFER_PROTOCOL_META_SYMBOL, StatusCode } from 'civkit';
import _ from 'lodash';
import dayjs from 'dayjs';
import utc from 'dayjs/plugin/utc';
dayjs.extend(utc);
@StatusCode(50301)
export class ServiceDisabledError extends ApplicationError { }
@StatusCode(50302)
export class ServiceCrashedError extends ApplicationError { }
@StatusCode(50303)
export class ServiceNodeResourceDrainError extends ApplicationError { }
@StatusCode(40104)
export class EmailUnverifiedError extends ApplicationError { }
@StatusCode(40201)
export class InsufficientCreditsError extends ApplicationError { }
@StatusCode(40202)
export class FreeFeatureLimitError extends ApplicationError { }
@StatusCode(40203)
export class InsufficientBalanceError extends ApplicationError { }
@StatusCode(40903)
export class LockConflictError extends ApplicationError { }
@StatusCode(40904)
export class BudgetExceededError extends ApplicationError { }
@StatusCode(45101)
export class HarmfulContentError extends ApplicationError { }
@StatusCode(45102)
export class SecurityCompromiseError extends ApplicationError { }
@StatusCode(41201)
export class BatchSizeTooLargeError extends ApplicationError { }
@StatusCode(42903)
export class RateLimitTriggeredError extends ApplicationError {
@Prop({
desc: 'Retry after seconds',
})
retryAfter?: number;
@Prop({
desc: 'Retry after date',
})
retryAfterDate?: Date;
protected override get [RPC_TRANSFER_PROTOCOL_META_SYMBOL]() {
const retryAfter = this.retryAfter || this.retryAfterDate;
if (!retryAfter) {
return super[RPC_TRANSFER_PROTOCOL_META_SYMBOL];
}
return _.merge(_.cloneDeep(super[RPC_TRANSFER_PROTOCOL_META_SYMBOL]), {
headers: {
'Retry-After': `${retryAfter instanceof Date ? dayjs(retryAfter).utc().format('ddd, DD MMM YYYY HH:mm:ss [GMT]') : retryAfter}`,
}
});
}
}

View File

@@ -0,0 +1,223 @@
import _ from 'lodash';
import { AutoCastable, Prop, RPC_MARSHAL } from 'civkit/civ-rpc';
import {
Firestore, FieldValue, DocumentReference,
Query, Timestamp, SetOptions, DocumentSnapshot,
} from '@google-cloud/firestore';
// Firestore doesn't support JavaScript objects with custom prototypes (i.e. objects that were created via the \"new\" operator)
function patchFireStoreArrogance(func: Function) {
return function (this: unknown) {
const origObjectGetPrototype = Object.getPrototypeOf;
Object.getPrototypeOf = function (x) {
const r = origObjectGetPrototype.call(this, x);
if (!r) {
return r;
}
return Object.prototype;
};
try {
return func.call(this, ...arguments);
} finally {
Object.getPrototypeOf = origObjectGetPrototype;
}
};
}
Reflect.set(DocumentReference.prototype, 'set', patchFireStoreArrogance(Reflect.get(DocumentReference.prototype, 'set')));
Reflect.set(DocumentSnapshot, 'fromObject', patchFireStoreArrogance(Reflect.get(DocumentSnapshot, 'fromObject')));
function mapValuesDeep(v: any, fn: (i: any) => any): any {
if (_.isPlainObject(v)) {
return _.mapValues(v, (i) => mapValuesDeep(i, fn));
} else if (_.isArray(v)) {
return v.map((i) => mapValuesDeep(i, fn));
} else {
return fn(v);
}
}
export type Constructor<T> = { new(...args: any[]): T; };
export type Constructed<T> = T extends Partial<infer U> ? U : T extends object ? T : object;
export function fromFirestore<T extends FirestoreRecord>(
this: Constructor<T>, id: string, overrideCollection?: string
): Promise<T | undefined>;
export async function fromFirestore(
this: any, id: string, overrideCollection?: string
) {
const collection = overrideCollection || this.collectionName;
if (!collection) {
throw new Error(`Missing collection name to construct ${this.name}`);
}
const ref = this.DB.collection(overrideCollection || this.collectionName).doc(id);
const ptr = await ref.get();
if (!ptr.exists) {
return undefined;
}
const doc = this.from(
// Fixes non-native firebase types
mapValuesDeep(ptr.data(), (i: any) => {
if (i instanceof Timestamp) {
return i.toDate();
}
return i;
})
);
Object.defineProperty(doc, '_ref', { value: ref, enumerable: false });
Object.defineProperty(doc, '_id', { value: ptr.id, enumerable: true });
return doc;
}
export function fromFirestoreQuery<T extends FirestoreRecord>(
this: Constructor<T>, query: Query
): Promise<T[]>;
export async function fromFirestoreQuery(this: any, query: Query) {
const ptr = await query.get();
if (ptr.docs.length) {
return ptr.docs.map(doc => {
const r = this.from(
mapValuesDeep(doc.data(), (i: any) => {
if (i instanceof Timestamp) {
return i.toDate();
}
return i;
})
);
Object.defineProperty(r, '_ref', { value: doc.ref, enumerable: false });
Object.defineProperty(r, '_id', { value: doc.id, enumerable: true });
return r;
});
}
return [];
}
export function setToFirestore<T extends FirestoreRecord>(
this: Constructor<T>, doc: T, overrideCollection?: string, setOptions?: SetOptions
): Promise<T>;
export async function setToFirestore(
this: any, doc: any, overrideCollection?: string, setOptions?: SetOptions
) {
let ref: DocumentReference<any> = doc._ref;
if (!ref) {
const collection = overrideCollection || this.collectionName;
if (!collection) {
throw new Error(`Missing collection name to construct ${this.name}`);
}
const predefinedId = doc._id || undefined;
const hdl = this.DB.collection(overrideCollection || this.collectionName);
ref = predefinedId ? hdl.doc(predefinedId) : hdl.doc();
Object.defineProperty(doc, '_ref', { value: ref, enumerable: false });
Object.defineProperty(doc, '_id', { value: ref.id, enumerable: true });
}
await ref.set(doc, { merge: true, ...setOptions });
return doc;
}
export function deleteQueryBatch<T extends FirestoreRecord>(
this: Constructor<T>, query: Query
): Promise<T>;
export async function deleteQueryBatch(this: any, query: Query) {
const snapshot = await query.get();
const batchSize = snapshot.size;
if (batchSize === 0) {
return;
}
// Delete documents in a batch
const batch = this.DB.batch();
snapshot.docs.forEach((doc) => {
batch.delete(doc.ref);
});
await batch.commit();
process.nextTick(() => {
this.deleteQueryBatch(query);
});
};
export function fromFirestoreDoc<T extends FirestoreRecord>(
this: Constructor<T>, snapshot: DocumentSnapshot,
): T | undefined;
export function fromFirestoreDoc(
this: any, snapshot: DocumentSnapshot,
) {
const doc = this.from(
// Fixes non-native firebase types
mapValuesDeep(snapshot.data(), (i: any) => {
if (i instanceof Timestamp) {
return i.toDate();
}
return i;
})
);
Object.defineProperty(doc, '_ref', { value: snapshot.ref, enumerable: false });
Object.defineProperty(doc, '_id', { value: snapshot.id, enumerable: true });
return doc;
}
const defaultFireStore = new Firestore({
projectId: process.env.GCLOUD_PROJECT,
});
export class FirestoreRecord extends AutoCastable {
static collectionName?: string;
static OPS = FieldValue;
static DB = defaultFireStore;
static get COLLECTION() {
if (!this.collectionName) {
throw new Error('Not implemented');
}
return this.DB.collection(this.collectionName);
}
@Prop()
_id?: string;
_ref?: DocumentReference<Partial<Omit<this, '_ref' | '_id'>>>;
static fromFirestore = fromFirestore;
static fromFirestoreDoc = fromFirestoreDoc;
static fromFirestoreQuery = fromFirestoreQuery;
static save = setToFirestore;
static deleteQueryBatch = deleteQueryBatch;
[RPC_MARSHAL]() {
return {
...this,
_id: this._id,
_ref: this._ref?.path
};
}
degradeForFireStore(): this {
return JSON.parse(JSON.stringify(this, function (k, v) {
if (k === '') {
return v;
}
if (typeof v === 'object' && v && (typeof v.degradeForFireStore === 'function')) {
return v.degradeForFireStore();
}
return v;
}));
}
}

56
jina-ai/src/lib/logger.ts Normal file
View File

@@ -0,0 +1,56 @@
import { AbstractPinoLogger } from 'civkit/pino-logger';
import { singleton, container } from 'tsyringe';
import { threadId } from 'node:worker_threads';
import { getTraceCtx } from 'civkit/async-context';
const levelToSeverityMap: { [k: string]: string | undefined; } = {
trace: 'DEFAULT',
debug: 'DEBUG',
info: 'INFO',
warn: 'WARNING',
error: 'ERROR',
fatal: 'CRITICAL',
};
@singleton()
export class GlobalLogger extends AbstractPinoLogger {
loggerOptions = {
level: 'debug',
base: {
tid: threadId,
}
};
override init(): void {
if (process.env['NODE_ENV']?.startsWith('prod')) {
super.init(process.stdout);
} else {
const PinoPretty = require('pino-pretty').PinoPretty;
super.init(PinoPretty({
singleLine: true,
colorize: true,
messageFormat(log: any, messageKey: any) {
return `${log['tid'] ? `[${log['tid']}]` : ''}[${log['service'] || 'ROOT'}] ${log[messageKey]}`;
},
}));
}
this.emit('ready');
}
override log(...args: any[]) {
const [levelObj, ...rest] = args;
const severity = levelToSeverityMap[levelObj?.level];
const traceCtx = getTraceCtx();
const patched: any= { ...levelObj, severity };
if (traceCtx?.traceId && process.env['GCLOUD_PROJECT']) {
patched['logging.googleapis.com/trace'] = `projects/${process.env['GCLOUD_PROJECT']}/traces/${traceCtx.traceId}`;
}
return super.log(patched, ...rest);
}
}
const instance = container.resolve(GlobalLogger);
export default instance;

View File

@@ -0,0 +1,4 @@
import { container } from 'tsyringe';
import { propertyInjectorFactory } from 'civkit/property-injector';
export const InjectProperty = propertyInjectorFactory(container);

View File

@@ -0,0 +1,93 @@
import { ApplicationError, RPC_CALL_ENVIRONMENT } from "civkit/civ-rpc";
import { marshalErrorLike } from "civkit/lang";
import { randomUUID } from "crypto";
import { once } from "events";
import type { NextFunction, Request, Response } from "express";
import { JinaEmbeddingsAuthDTO } from "./dto/jina-embeddings-auth";
import rateLimitControl, { API_CALL_STATUS, RateLimitDesc } from "./rate-limit";
import asyncLocalContext from "./lib/async-context";
import globalLogger from "./lib/logger";
globalLogger.serviceReady();
const logger = globalLogger.child({ service: 'BillingMiddleware' });
const appName = 'DEEPRESEARCH';
export const jinaAiBillingMiddleware = (req: Request, res: Response, next: NextFunction) => {
if (req.path === '/ping') {
res.status(200).end('pone');
return;
}
if (req.method !== 'POST' && req.method !== 'GET') {
next();
return;
}
asyncLocalContext.run(async () => {
const googleTraceId = req.get('x-cloud-trace-context')?.split('/')?.[0];
const ctx = asyncLocalContext.ctx;
ctx.traceId = req.get('x-request-id') || req.get('request-id') || googleTraceId || randomUUID();
ctx.traceT0 = new Date();
ctx.ip = req?.ip;
try {
const authDto = JinaEmbeddingsAuthDTO.from({
[RPC_CALL_ENVIRONMENT]: { req, res }
});
const user = await authDto.assertUser();
await rateLimitControl.serviceReady();
const rateLimitPolicy = authDto.getRateLimits(appName) || [
parseInt(user.metadata?.speed_level) >= 2 ?
RateLimitDesc.from({
occurrence: 30,
periodSeconds: 60
}) :
RateLimitDesc.from({
occurrence: 10,
periodSeconds: 60
})
];
const criterions = rateLimitPolicy.map((c) => rateLimitControl.rateLimitDescToCriterion(c));
await Promise.all(criterions.map(([pointInTime, n]) => rateLimitControl.assertUidPeriodicLimit(user._id, pointInTime, n, appName)));
const apiRoll = rateLimitControl.record({ uid: user._id, tags: [appName] })
apiRoll.save().catch((err) => logger.warn(`Failed to save rate limit record`, { err: marshalErrorLike(err) }));
const pResClose = once(res, 'close');
next();
await pResClose;
const chargeAmount = ctx.chargeAmount;
if (chargeAmount) {
authDto.reportUsage(chargeAmount, `reader-${appName}`).catch((err) => {
logger.warn(`Unable to report usage for ${user._id}`, { err: marshalErrorLike(err) });
});
apiRoll.chargeAmount = chargeAmount;
}
apiRoll.status = res.statusCode === 200 ? API_CALL_STATUS.SUCCESS : API_CALL_STATUS.ERROR;
apiRoll.save().catch((err) => logger.warn(`Failed to save rate limit record`, { err: marshalErrorLike(err) }));
logger.info(`HTTP ${res.statusCode} for request ${ctx.traceId} after ${Date.now() - ctx.traceT0.valueOf()}ms`, {
uid: user._id,
chargeAmount,
});
} catch (err: any) {
if (!res.headersSent) {
if (err instanceof ApplicationError) {
res.status(parseInt(err.code as string) || 500).json({ error: err.message });
return;
}
res.status(500).json({ error: 'Internal' });
}
logger.error(`Error in billing middleware`, { err: marshalErrorLike(err) });
if (err.stack) {
logger.error(err.stack);
}
}
});
}

278
jina-ai/src/rate-limit.ts Normal file
View File

@@ -0,0 +1,278 @@
import { AutoCastable, ResourcePolicyDenyError, Also, Prop } from 'civkit/civ-rpc';
import { AsyncService } from 'civkit/async-service';
import { getTraceId } from 'civkit/async-context';
import { singleton, container } from 'tsyringe';
import { RateLimitTriggeredError } from './lib/errors';
import { FirestoreRecord } from './lib/firestore';
import { GlobalLogger } from './lib/logger';
export enum API_CALL_STATUS {
SUCCESS = 'success',
ERROR = 'error',
PENDING = 'pending',
}
@Also({ dictOf: Object })
export class APICall extends FirestoreRecord {
static override collectionName = 'apiRoll';
@Prop({
required: true,
defaultFactory: () => getTraceId()
})
traceId!: string;
@Prop()
uid?: string;
@Prop()
ip?: string;
@Prop({
arrayOf: String,
default: [],
})
tags!: string[];
@Prop({
required: true,
defaultFactory: () => new Date(),
})
createdAt!: Date;
@Prop()
completedAt?: Date;
@Prop({
required: true,
default: API_CALL_STATUS.PENDING,
})
status!: API_CALL_STATUS;
@Prop({
required: true,
defaultFactory: () => new Date(Date.now() + 1000 * 60 * 60 * 24 * 90),
})
expireAt!: Date;
[k: string]: any;
tag(...tags: string[]) {
for (const t of tags) {
if (!this.tags.includes(t)) {
this.tags.push(t);
}
}
}
save() {
return (this.constructor as typeof APICall).save(this);
}
}
export class RateLimitDesc extends AutoCastable {
@Prop({
default: 1000
})
occurrence!: number;
@Prop({
default: 3600
})
periodSeconds!: number;
@Prop()
notBefore?: Date;
@Prop()
notAfter?: Date;
isEffective() {
const now = new Date();
if (this.notBefore && this.notBefore > now) {
return false;
}
if (this.notAfter && this.notAfter < now) {
return false;
}
return true;
}
}
@singleton()
export class RateLimitControl extends AsyncService {
logger = this.globalLogger.child({ service: this.constructor.name });
constructor(
protected globalLogger: GlobalLogger,
) {
super(...arguments);
}
override async init() {
await this.dependencyReady();
this.emit('ready');
}
async queryByUid(uid: string, pointInTime: Date, ...tags: string[]) {
let q = APICall.COLLECTION
.orderBy('createdAt', 'asc')
.where('createdAt', '>=', pointInTime)
.where('status', 'in', [API_CALL_STATUS.SUCCESS, API_CALL_STATUS.PENDING])
.where('uid', '==', uid);
if (tags.length) {
q = q.where('tags', 'array-contains-any', tags);
}
return APICall.fromFirestoreQuery(q);
}
async queryByIp(ip: string, pointInTime: Date, ...tags: string[]) {
let q = APICall.COLLECTION
.orderBy('createdAt', 'asc')
.where('createdAt', '>=', pointInTime)
.where('status', 'in', [API_CALL_STATUS.SUCCESS, API_CALL_STATUS.PENDING])
.where('ip', '==', ip);
if (tags.length) {
q = q.where('tags', 'array-contains-any', tags);
}
return APICall.fromFirestoreQuery(q);
}
async assertUidPeriodicLimit(uid: string, pointInTime: Date, limit: number, ...tags: string[]) {
if (limit <= 0) {
throw new ResourcePolicyDenyError(`This UID(${uid}) is not allowed to call this endpoint (rate limit quota is 0).`);
}
let q = APICall.COLLECTION
.orderBy('createdAt', 'asc')
.where('createdAt', '>=', pointInTime)
.where('status', 'in', [API_CALL_STATUS.SUCCESS, API_CALL_STATUS.PENDING])
.where('uid', '==', uid);
if (tags.length) {
q = q.where('tags', 'array-contains-any', tags);
}
const count = (await q.count().get()).data().count;
if (count >= limit) {
const r = await APICall.fromFirestoreQuery(q.limit(1));
const [r1] = r;
const dtMs = Math.abs(r1.createdAt?.valueOf() - pointInTime.valueOf());
const dtSec = Math.ceil(dtMs / 1000);
throw RateLimitTriggeredError.from({
message: `Per UID rate limit exceeded (${tags.join(',') || 'called'} ${limit} times since ${pointInTime})`,
retryAfter: dtSec,
});
}
return count + 1;
}
async assertIPPeriodicLimit(ip: string, pointInTime: Date, limit: number, ...tags: string[]) {
let q = APICall.COLLECTION
.orderBy('createdAt', 'asc')
.where('createdAt', '>=', pointInTime)
.where('status', 'in', [API_CALL_STATUS.SUCCESS, API_CALL_STATUS.PENDING])
.where('ip', '==', ip);
if (tags.length) {
q = q.where('tags', 'array-contains-any', tags);
}
const count = (await q.count().get()).data().count;
if (count >= limit) {
const r = await APICall.fromFirestoreQuery(q.limit(1));
const [r1] = r;
const dtMs = Math.abs(r1.createdAt?.valueOf() - pointInTime.valueOf());
const dtSec = Math.ceil(dtMs / 1000);
throw RateLimitTriggeredError.from({
message: `Per IP rate limit exceeded (${tags.join(',') || 'called'} ${limit} times since ${pointInTime})`,
retryAfter: dtSec,
});
}
return count + 1;
}
record(partialRecord: Partial<APICall>) {
const record = APICall.from(partialRecord);
const newId = APICall.COLLECTION.doc().id;
record._id = newId;
return record;
}
// async simpleRPCUidBasedLimit(rpcReflect: RPCReflection, uid: string, tags: string[] = [],
// ...inputCriterion: RateLimitDesc[] | [Date, number][]) {
// const criterion = inputCriterion.map((c) => { return Array.isArray(c) ? c : this.rateLimitDescToCriterion(c); });
// await Promise.all(criterion.map(([pointInTime, n]) =>
// this.assertUidPeriodicLimit(uid, pointInTime, n, ...tags)));
// const r = this.record({
// uid,
// tags,
// });
// r.save().catch((err) => this.logger.warn(`Failed to save rate limit record`, { err }));
// rpcReflect.then(() => {
// r.status = API_CALL_STATUS.SUCCESS;
// r.save()
// .catch((err) => this.logger.warn(`Failed to save rate limit record`, { err }));
// });
// rpcReflect.catch((err) => {
// r.status = API_CALL_STATUS.ERROR;
// r.error = err.toString();
// r.save()
// .catch((err) => this.logger.warn(`Failed to save rate limit record`, { err }));
// });
// return r;
// }
rateLimitDescToCriterion(rateLimitDesc: RateLimitDesc) {
return [new Date(Date.now() - rateLimitDesc.periodSeconds * 1000), rateLimitDesc.occurrence] as [Date, number];
}
// async simpleRpcIPBasedLimit(rpcReflect: RPCReflection, ip: string, tags: string[] = [],
// ...inputCriterion: RateLimitDesc[] | [Date, number][]) {
// const criterion = inputCriterion.map((c) => { return Array.isArray(c) ? c : this.rateLimitDescToCriterion(c); });
// await Promise.all(criterion.map(([pointInTime, n]) =>
// this.assertIPPeriodicLimit(ip, pointInTime, n, ...tags)));
// const r = this.record({
// ip,
// tags,
// });
// r.save().catch((err) => this.logger.warn(`Failed to save rate limit record`, { err }));
// rpcReflect.then(() => {
// r.status = API_CALL_STATUS.SUCCESS;
// r.save()
// .catch((err) => this.logger.warn(`Failed to save rate limit record`, { err }));
// });
// rpcReflect.catch((err) => {
// r.status = API_CALL_STATUS.ERROR;
// r.error = err.toString();
// r.save()
// .catch((err) => this.logger.warn(`Failed to save rate limit record`, { err }));
// });
// return r;
// }
}
const instance = container.resolve(RateLimitControl);
export default instance;

56
jina-ai/src/server.ts Normal file
View File

@@ -0,0 +1,56 @@
import 'reflect-metadata'
import express from 'express';
import { jinaAiBillingMiddleware } from "./patch-express";
import { Server } from 'http';
const app = require('../..').default;
const rootApp = express();
rootApp.use(jinaAiBillingMiddleware, app);
const port = process.env.PORT || 3000;
let server: Server | undefined;
// Export server startup function for better testing
export function startServer() {
return rootApp.listen(port, () => {
console.log(`Server running at http://localhost:${port}`);
});
}
// Start server if running directly
if (process.env.NODE_ENV !== 'test') {
server = startServer();
}
process.on('unhandledRejection', (_err) => `Is false alarm`);
process.on('uncaughtException', (err) => {
console.log('Uncaught exception', err);
// Looks like Firebase runtime does not handle error properly.
// Make sure to quit the process.
process.nextTick(() => process.exit(1));
console.error('Uncaught exception, process quit.');
throw err;
});
const sigHandler = (signal: string) => {
console.log(`Received ${signal}, exiting...`);
if (server && server.listening) {
console.log(`Shutting down gracefully...`);
console.log(`Waiting for the server to drain and close...`);
server.close((err) => {
if (err) {
console.error('Error while closing server', err);
return;
}
process.exit(0);
});
server.closeIdleConnections();
}
}
process.on('SIGTERM', sigHandler);
process.on('SIGINT', sigHandler);