diff --git a/default/config.conf b/default/config.conf index 92c99ba02..962c6ce5b 100644 --- a/default/config.conf +++ b/default/config.conf @@ -15,17 +15,21 @@ const skipContentCheck = false; // If true, no new default content will be deliv // Change this setting only on "trusted networks". Do not change this value unless you are aware of the issues that can arise from changing this setting and configuring a insecure setting. const securityOverride = false; +// Request overrides for additional headers +const requestOverrides = []; + module.exports = { - port, - whitelist, - whitelistMode, - basicAuthMode, - basicAuthUser, - autorun, - enableExtensions, - listen, - disableThumbnails, - allowKeysExposure, - securityOverride, - skipContentCheck, + port, + whitelist, + whitelistMode, + basicAuthMode, + basicAuthUser, + autorun, + enableExtensions, + listen, + disableThumbnails, + allowKeysExposure, + securityOverride, + skipContentCheck, + requestOverrides, }; diff --git a/server.js b/server.js index e50fbc6d5..1cb676993 100644 --- a/server.js +++ b/server.js @@ -185,7 +185,14 @@ function get_mancer_headers() { return api_key_mancer ? { "X-API-KEY": api_key_mancer } : {}; } - +function getOverrideHeaders(urlHost) { + const overrideHeaders = config.requestOverrides?.find((e) => e.hosts?.includes(urlHost))?.headers; + if (overrideHeaders && urlHost) { + return overrideHeaders; + } else { + return {}; + } +} //RossAscends: Added function to format dates used in files and chat timestamps to a humanized format. //Mostly I wanted this to be for file names, but couldn't figure out exactly where the filename save code was as everything seemed to be connected. @@ -540,7 +547,10 @@ app.post("/generate", jsonParser, async function (request, response_generate) { console.log(this_settings); const args = { body: JSON.stringify(this_settings), - headers: { "Content-Type": "application/json" }, + headers: Object.assign( + { "Content-Type": "application/json" }, + getOverrideHeaders((new URL(api_server))?.host) + ), signal: controller.signal, }; @@ -630,11 +640,19 @@ app.post("/generate_textgenerationwebui", jsonParser, async function (request, r }); async function* readWebsocket() { + const streamingUrlString = request.header('X-Streaming-URL').replace("localhost", "127.0.0.1"); + const streamingUrl = new URL(streamingUrlString); const websocket = new WebSocket(streamingUrl); websocket.on('open', async function () { console.log('WebSocket opened'); - const combined_args = Object.assign(request.body.use_mancer ? get_mancer_headers() : {}, request.body); + const combined_args = Object.assign( + {}, + request.body.use_mancer ? get_mancer_headers() : getOverrideHeaders(streamingUrl?.host), + request.body + ); + console.log(combined_args); + websocket.send(JSON.stringify(combined_args)); }); @@ -716,6 +734,8 @@ app.post("/generate_textgenerationwebui", jsonParser, async function (request, r if (request.body.use_mancer) { args.headers = Object.assign(args.headers, get_mancer_headers()); + } else { + args.headers = Object.assign(args.headers, getOverrideHeaders((new URL(api_server))?.host)); } try { @@ -783,6 +803,7 @@ app.post("/getchat", jsonParser, function (request, response) { } }); +// Only called for kobold and ooba/mancer app.post("/getstatus", jsonParser, async function (request, response) { if (!request.body) return response.sendStatus(400); api_server = request.body.api_server; @@ -797,6 +818,8 @@ app.post("/getstatus", jsonParser, async function (request, response) { if (main_api == 'textgenerationwebui' && request.body.use_mancer) { args.headers = Object.assign(args.headers, get_mancer_headers()); + } else { + args.headers = Object.assign(args.headers, getOverrideHeaders((new URL(api_server))?.host)); } const url = api_server + "/v1/model";