From 36ea41f0a6e625cb5dcaf5810da1523e270331c6 Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 2 Jun 2023 22:21:21 -0400 Subject: [PATCH] Extras: Add API authentication support An API key is extremely important for ST-Extras servers that are exposed to the internet. Add an API key field below where the user enters the extras URL. For convenience, this key is persisted whenever the user refreshes the page. Also modify the fetch requests to always include API keys if present. See ST-Extras for more information on how this works. Signed-off-by: kingbri --- public/index.html | 3 +- public/scripts/extensions.js | 32 ++++++++++++++++--- public/scripts/extensions/caption/index.js | 4 +-- .../scripts/extensions/expressions/index.js | 6 ++-- .../extensions/infinity-context/index.js | 14 ++++---- public/scripts/extensions/memory/index.js | 4 +-- .../extensions/stable-diffusion/index.js | 12 +++---- 7 files changed, 50 insertions(+), 25 deletions(-) diff --git a/public/index.html b/public/index.html index 3e19d2c88..04baeede2 100644 --- a/public/index.html +++ b/public/index.html @@ -2038,7 +2038,8 @@ SillyTavern-extras - + +
diff --git a/public/scripts/extensions.js b/public/scripts/extensions.js index 2f7c834d8..d2536cace 100644 --- a/public/scripts/extensions.js +++ b/public/scripts/extensions.js @@ -1,11 +1,11 @@ -import { callPopup, eventSource, event_types, saveSettings, saveSettingsDebounced } from "../script.js"; +import { callPopup, eventSource, event_types, extension_prompt_types, saveSettings, saveSettingsDebounced } from "../script.js"; import { isSubsetOf, debounce } from "./utils.js"; export { getContext, getApiUrl, loadExtensionSettings, runGenerationInterceptors, - defaultRequestArgs, + doExtrasFetch, modules, extension_settings, ModuleWorkerWrapper, @@ -43,6 +43,7 @@ class ModuleWorkerWrapper { const extension_settings = { apiUrl: defaultUrl, + apiKey: '', autoConnect: false, disabledExtensions: [], memory: {}, @@ -65,9 +66,29 @@ let activeExtensions = new Set(); const getContext = () => window['SillyTavern'].getContext(); const getApiUrl = () => extension_settings.apiUrl; -const defaultRequestArgs = { method: 'GET', headers: { 'Bypass-Tunnel-Reminder': 'bypass' } }; let connectedToApi = false; +async function doExtrasFetch(endpoint, args) { + if (!args) { + args = {} + } + + if (!args.method) { + Object.assign(args, { method: 'GET' }); + } + + if (!args.headers) { + args.headers = {} + } + Object.assign(args.headers, { + 'Authorization': `Bearer ${extension_settings.apiKey}`, + 'Bypass-Tunnel-Reminder': 'bypass' + }); + + const response = await fetch(endpoint, args); + return response; +} + async function discoverExtensions() { try { const response = await fetch('/discover_extensions'); @@ -178,6 +199,8 @@ async function activateExtensions() { async function connectClickHandler() { const baseUrl = $("#extensions_url").val(); extension_settings.apiUrl = baseUrl; + const testApiKey = $("#extensions_api_key").val(); + extension_settings.apiKey = testApiKey; saveSettingsDebounced(); await connectToApi(baseUrl); } @@ -233,7 +256,7 @@ async function connectToApi(baseUrl) { url.pathname = '/api/modules'; try { - const getExtensionsResult = await fetch(url, defaultRequestArgs); + const getExtensionsResult = await doExtrasFetch(url); if (getExtensionsResult.ok) { const data = await getExtensionsResult.json(); @@ -352,6 +375,7 @@ async function loadExtensionSettings(settings) { } $("#extensions_url").val(extension_settings.apiUrl); + $("#extensions_api_key").val(extension_settings.apiKey); $("#extensions_autoconnect").prop('checked', extension_settings.autoConnect); // Activate offline extensions diff --git a/public/scripts/extensions/caption/index.js b/public/scripts/extensions/caption/index.js index 314f57b7c..292be9138 100644 --- a/public/scripts/extensions/caption/index.js +++ b/public/scripts/extensions/caption/index.js @@ -1,5 +1,5 @@ import { getBase64Async } from "../../utils.js"; -import { getContext, getApiUrl } from "../../extensions.js"; +import { getContext, getApiUrl, doExtrasFetch } from "../../extensions.js"; export { MODULE_NAME }; const MODULE_NAME = 'caption'; @@ -63,7 +63,7 @@ async function onSelectImage(e) { const url = new URL(getApiUrl()); url.pathname = '/api/caption'; - const apiResult = await fetch(url, { + const apiResult = await doExtrasFetch(url, { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/public/scripts/extensions/expressions/index.js b/public/scripts/extensions/expressions/index.js index 2ebbeaa10..8d81653cf 100644 --- a/public/scripts/extensions/expressions/index.js +++ b/public/scripts/extensions/expressions/index.js @@ -1,5 +1,5 @@ import { callPopup, getRequestHeaders, saveSettingsDebounced } from "../../../script.js"; -import { getContext, getApiUrl, modules, extension_settings, ModuleWorkerWrapper } from "../../extensions.js"; +import { getContext, getApiUrl, modules, extension_settings, ModuleWorkerWrapper, doExtrasFetch } from "../../extensions.js"; export { MODULE_NAME }; const MODULE_NAME = 'expressions'; @@ -122,7 +122,7 @@ async function moduleWorker() { const url = new URL(getApiUrl()); url.pathname = '/api/classify'; - const apiResult = await fetch(url, { + const apiResult = await doExtrasFetch(url, { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -265,7 +265,7 @@ async function getExpressionsList() { url.pathname = '/api/classify/labels'; try { - const apiResult = await fetch(url, { + const apiResult = await doExtrasFetch(url, { method: 'GET', headers: { 'Bypass-Tunnel-Reminder': 'bypass' }, }); diff --git a/public/scripts/extensions/infinity-context/index.js b/public/scripts/extensions/infinity-context/index.js index 7901431db..c5a7fb995 100644 --- a/public/scripts/extensions/infinity-context/index.js +++ b/public/scripts/extensions/infinity-context/index.js @@ -1,6 +1,6 @@ import { saveSettingsDebounced, getCurrentChatId, system_message_types, eventSource, event_types } from "../../../script.js"; import { humanizedDateTime } from "../../RossAscends-mods.js"; -import { getApiUrl, extension_settings, getContext } from "../../extensions.js"; +import { getApiUrl, extension_settings, getContext, doExtrasFetch } from "../../extensions.js"; import { getFileText, onlyUnique, splitRecursive, IndexedDBStore } from "../../utils.js"; export { MODULE_NAME }; @@ -174,7 +174,7 @@ async function addMessages(chat_id, messages) { meta: JSON.stringify(m), })); - const addMessagesResult = await fetch(url, { + const addMessagesResult = await doExtrasFetch(url, { method: 'POST', headers: postHeaders, body: JSON.stringify({ chat_id, messages: transformedMessages }), @@ -222,7 +222,7 @@ async function onPurgeClick() { const url = new URL(getApiUrl()); url.pathname = '/api/chromadb/purge'; - const purgeResult = await fetch(url, { + const purgeResult = await doExtrasFetch(url, { method: 'POST', headers: postHeaders, body: JSON.stringify({ chat_id }), @@ -242,7 +242,7 @@ async function onExportClick() { const url = new URL(getApiUrl()); url.pathname = '/api/chromadb/export'; - const exportResult = await fetch(url, { + const exportResult = await doExtrasFetch(url, { method: 'POST', headers: postHeaders, body: JSON.stringify({ chat_id: currentChatId }), @@ -285,7 +285,7 @@ async function onSelectImportFile(e) { const url = new URL(getApiUrl()); url.pathname = '/api/chromadb/import'; - const importResult = await fetch(url, { + const importResult = await doExtrasFetch(url, { method: 'POST', headers: postHeaders, body: JSON.stringify(imported), @@ -313,7 +313,7 @@ async function queryMessages(chat_id, query) { const url = new URL(getApiUrl()); url.pathname = '/api/chromadb/query'; - const queryMessagesResult = await fetch(url, { + const queryMessagesResult = await doExtrasFetch(url, { method: 'POST', headers: postHeaders, body: JSON.stringify({ chat_id, query, n_results: extension_settings.chromadb.n_results }), @@ -366,7 +366,7 @@ async function onSelectInjectFile(e) { const url = new URL(getApiUrl()); url.pathname = '/api/chromadb'; - const addMessagesResult = await fetch(url, { + const addMessagesResult = await doExtrasFetch(url, { method: 'POST', headers: postHeaders, body: JSON.stringify({ chat_id: currentChatId, messages: messages }), diff --git a/public/scripts/extensions/memory/index.js b/public/scripts/extensions/memory/index.js index c58888710..99ed18742 100644 --- a/public/scripts/extensions/memory/index.js +++ b/public/scripts/extensions/memory/index.js @@ -1,5 +1,5 @@ import { getStringHash, debounce } from "../../utils.js"; -import { getContext, getApiUrl, extension_settings, ModuleWorkerWrapper } from "../../extensions.js"; +import { getContext, getApiUrl, extension_settings, ModuleWorkerWrapper, doExtrasFetch } from "../../extensions.js"; import { extension_prompt_types, is_send_press, saveSettingsDebounced } from "../../../script.js"; export { MODULE_NAME }; @@ -232,7 +232,7 @@ async function summarizeChat(context) { const url = new URL(getApiUrl()); url.pathname = '/api/summarize'; - const apiResult = await fetch(url, { + const apiResult = await doExtrasFetch(url, { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/public/scripts/extensions/stable-diffusion/index.js b/public/scripts/extensions/stable-diffusion/index.js index 1e61f6cf0..413fba1b8 100644 --- a/public/scripts/extensions/stable-diffusion/index.js +++ b/public/scripts/extensions/stable-diffusion/index.js @@ -10,7 +10,7 @@ import { eventSource, appendImageToMessage } from "../../../script.js"; -import { getApiUrl, getContext, extension_settings, defaultRequestArgs, modules } from "../../extensions.js"; +import { getApiUrl, getContext, extension_settings, doExtrasFetch, modules } from "../../extensions.js"; import { stringFormat, initScrollHeight, resetScrollHeight } from "../../utils.js"; export { MODULE_NAME }; @@ -234,7 +234,7 @@ async function onModelChange() { async function updateExtrasRemoteModel() { const url = new URL(getApiUrl()); url.pathname = '/api/image/model'; - const getCurrentModelResult = await fetch(url, { + const getCurrentModelResult = await doExtrasFetch(url, { method: 'POST', headers: postHeaders, body: JSON.stringify({ model: extension_settings.sd.model }), @@ -285,7 +285,7 @@ async function loadExtrasSamplers() { const url = new URL(getApiUrl()); url.pathname = '/api/image/samplers'; - const result = await fetch(url, defaultRequestArgs); + const result = await doExtrasFetch(url); if (result.ok) { const data = await result.json(); @@ -338,7 +338,7 @@ async function loadExtrasModels() { const url = new URL(getApiUrl()); url.pathname = '/api/image/model'; - const getCurrentModelResult = await fetch(url, defaultRequestArgs); + const getCurrentModelResult = await doExtrasFetch(url); if (getCurrentModelResult.ok) { const data = await getCurrentModelResult.json(); @@ -346,7 +346,7 @@ async function loadExtrasModels() { } url.pathname = '/api/image/models'; - const getModelsResult = await fetch(url, defaultRequestArgs); + const getModelsResult = await doExtrasFetch(url); if (getModelsResult.ok) { const data = await getModelsResult.json(); @@ -493,7 +493,7 @@ async function generateExtrasImage(prompt, callback) { console.log(extension_settings.sd); const url = new URL(getApiUrl()); url.pathname = '/api/image'; - const result = await fetch(url, { + const result = await doExtrasFetch(url, { method: 'POST', headers: postHeaders, body: JSON.stringify({