From 3c0d03b45c3a8c8cde0fd6d31b6c4b13c54ffe93 Mon Sep 17 00:00:00 2001 From: Paul Ascenzi Date: Fri, 7 Feb 2025 10:54:14 -0500 Subject: [PATCH] added secure mode which requires the client to send an API key to make any type of server requests. secure mode is off by default. randomly generate API key if secure mode is on but API key is not set in environment variables to ensure API security. --- src/server.ts | 52 ++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 13 deletions(-) diff --git a/src/server.ts b/src/server.ts index 531716e..30047cf 100644 --- a/src/server.ts +++ b/src/server.ts @@ -1,12 +1,13 @@ -import express, {Request, Response, RequestHandler} from 'express'; +import express, { Request, Response, NextFunction, RequestHandler } from 'express'; import cors from 'cors'; -import {EventEmitter} from 'events'; -import {getResponse} from './agent'; -import {StepAction, StreamMessage, TrackerContext} from './types'; +import { EventEmitter } from 'events'; +import { getResponse } from './agent'; +import { StepAction, StreamMessage, TrackerContext } from './types'; import fs from 'fs/promises'; import path from 'path'; -import {TokenTracker} from "./utils/token-tracker"; -import {ActionTracker} from "./utils/action-tracker"; +import { TokenTracker } from "./utils/token-tracker"; +import { ActionTracker } from "./utils/action-tracker"; +import { randomUUID } from 'crypto'; const app = express(); const port = process.env.PORT || 3000; @@ -14,6 +15,31 @@ const port = process.env.PORT || 3000; app.use(cors()); app.use(express.json()); +const secureMode = process.env.SECURE_MODE || false as boolean +const apiKey = process.env.API_KEY || randomUUID() as string; + +const apiKeyMiddleware = (req: Request, res: Response, next: NextFunction) => { + const clientApiKey = req.headers['x-api-key']; + + if (!clientApiKey) { + res.status(401).json({ message: 'No API key provided' }); + console.error(`No API key provided by ${req.ip}`) + return + } + + if (clientApiKey !== apiKey) { + res.status(401).json({ message: 'Invalid API key' }); + console.error(`Invalid API key provided by ${req.ip}`) + return + } + + next(); +}; + +console.info('SECURE_MODE:', secureMode) +if (secureMode) + app.use(apiKeyMiddleware) + const eventEmitter = new EventEmitter(); interface QueryRequest extends Request { @@ -39,7 +65,7 @@ function createProgressEmitter(requestId: string, budget: number | undefined, co eventEmitter.emit(`progress-${requestId}`, { type: 'progress', - data: {...state.thisStep, totalStep: state.totalStep}, + data: { ...state.thisStep, totalStep: state.totalStep }, step: state.totalStep, budget: budgetInfo, trackers: { @@ -79,9 +105,9 @@ function emitTrackerUpdate(requestId: string, context: TrackerContext) { const trackers = new Map(); app.post('/api/v1/query', (async (req: QueryRequest, res: Response) => { - const {q, budget, maxBadAttempt} = req.body; + const { q, budget, maxBadAttempt } = req.body; if (!q) { - return res.status(400).json({error: 'Query (q) is required'}); + return res.status(400).json({ error: 'Query (q) is required' }); } const requestId = Date.now().toString(); @@ -97,10 +123,10 @@ app.post('/api/v1/query', (async (req: QueryRequest, res: Response) => { context.actionTracker.on('action', () => emitTrackerUpdate(requestId, context)); // context.tokenTracker.on('usage', () => emitTrackerUpdate(requestId, context)); - res.json({requestId}); + res.json({ requestId }); try { - const {result} = await getResponse(q, budget, maxBadAttempt, context); + const { result } = await getResponse(q, budget, maxBadAttempt, context); const emitProgress = createProgressEmitter(requestId, budget, context); context.actionTracker.on('action', emitProgress); await storeTaskResult(requestId, result); @@ -163,7 +189,7 @@ app.get('/api/v1/stream/:requestId', (async (req: Request, res: StreamResponse) async function storeTaskResult(requestId: string, result: StepAction) { try { const taskDir = path.join(process.cwd(), 'tasks'); - await fs.mkdir(taskDir, {recursive: true}); + await fs.mkdir(taskDir, { recursive: true }); await fs.writeFile( path.join(taskDir, `${requestId}.json`), JSON.stringify(result, null, 2) @@ -181,7 +207,7 @@ app.get('/api/v1/task/:requestId', (async (req: Request, res: Response) => { const taskData = await fs.readFile(taskPath, 'utf-8'); res.json(JSON.parse(taskData)); } catch (error) { - res.status(404).json({error: 'Task not found'}); + res.status(404).json({ error: 'Task not found' }); } }) as RequestHandler);