Extras: Add API authentication support

An API key is extremely important for ST-Extras servers that are
exposed to the internet.

Add an API key field below where the user enters the extras URL. For
convenience, this key is persisted whenever the user refreshes the
page.

Also modify the fetch requests to always include API keys if present.

See ST-Extras for more information on how this works.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-06-02 22:21:21 -04:00
parent f4802952b6
commit 36ea41f0a6
7 changed files with 50 additions and 25 deletions

View File

@ -2038,7 +2038,8 @@
SillyTavern-extras SillyTavern-extras
</a> </a>
</h3> </h3>
<input id="extensions_url" type="text" class="text_pole" maxlength="250" /> <input id="extensions_url" type="text" class="text_pole" maxlength="250" placeholder="Extensions URL" />
<input id="extensions_api_key" type="text" class="text_pole" maxlength="250" placeholder="API key" />
<div class="extensions_url_block"> <div class="extensions_url_block">
<input id="extensions_connect" class="menu_button" type="submit" value="Connect" /> <input id="extensions_connect" class="menu_button" type="submit" value="Connect" />
<input id="extensions_details" class="menu_button" type="button" value="Manage extensions" /> <input id="extensions_details" class="menu_button" type="button" value="Manage extensions" />

View File

@ -1,11 +1,11 @@
import { callPopup, eventSource, event_types, saveSettings, saveSettingsDebounced } from "../script.js"; import { callPopup, eventSource, event_types, extension_prompt_types, saveSettings, saveSettingsDebounced } from "../script.js";
import { isSubsetOf, debounce } from "./utils.js"; import { isSubsetOf, debounce } from "./utils.js";
export { export {
getContext, getContext,
getApiUrl, getApiUrl,
loadExtensionSettings, loadExtensionSettings,
runGenerationInterceptors, runGenerationInterceptors,
defaultRequestArgs, doExtrasFetch,
modules, modules,
extension_settings, extension_settings,
ModuleWorkerWrapper, ModuleWorkerWrapper,
@ -43,6 +43,7 @@ class ModuleWorkerWrapper {
const extension_settings = { const extension_settings = {
apiUrl: defaultUrl, apiUrl: defaultUrl,
apiKey: '',
autoConnect: false, autoConnect: false,
disabledExtensions: [], disabledExtensions: [],
memory: {}, memory: {},
@ -65,9 +66,29 @@ let activeExtensions = new Set();
const getContext = () => window['SillyTavern'].getContext(); const getContext = () => window['SillyTavern'].getContext();
const getApiUrl = () => extension_settings.apiUrl; const getApiUrl = () => extension_settings.apiUrl;
const defaultRequestArgs = { method: 'GET', headers: { 'Bypass-Tunnel-Reminder': 'bypass' } };
let connectedToApi = false; let connectedToApi = false;
async function doExtrasFetch(endpoint, args) {
if (!args) {
args = {}
}
if (!args.method) {
Object.assign(args, { method: 'GET' });
}
if (!args.headers) {
args.headers = {}
}
Object.assign(args.headers, {
'Authorization': `Bearer ${extension_settings.apiKey}`,
'Bypass-Tunnel-Reminder': 'bypass'
});
const response = await fetch(endpoint, args);
return response;
}
async function discoverExtensions() { async function discoverExtensions() {
try { try {
const response = await fetch('/discover_extensions'); const response = await fetch('/discover_extensions');
@ -178,6 +199,8 @@ async function activateExtensions() {
async function connectClickHandler() { async function connectClickHandler() {
const baseUrl = $("#extensions_url").val(); const baseUrl = $("#extensions_url").val();
extension_settings.apiUrl = baseUrl; extension_settings.apiUrl = baseUrl;
const testApiKey = $("#extensions_api_key").val();
extension_settings.apiKey = testApiKey;
saveSettingsDebounced(); saveSettingsDebounced();
await connectToApi(baseUrl); await connectToApi(baseUrl);
} }
@ -233,7 +256,7 @@ async function connectToApi(baseUrl) {
url.pathname = '/api/modules'; url.pathname = '/api/modules';
try { try {
const getExtensionsResult = await fetch(url, defaultRequestArgs); const getExtensionsResult = await doExtrasFetch(url);
if (getExtensionsResult.ok) { if (getExtensionsResult.ok) {
const data = await getExtensionsResult.json(); const data = await getExtensionsResult.json();
@ -352,6 +375,7 @@ async function loadExtensionSettings(settings) {
} }
$("#extensions_url").val(extension_settings.apiUrl); $("#extensions_url").val(extension_settings.apiUrl);
$("#extensions_api_key").val(extension_settings.apiKey);
$("#extensions_autoconnect").prop('checked', extension_settings.autoConnect); $("#extensions_autoconnect").prop('checked', extension_settings.autoConnect);
// Activate offline extensions // Activate offline extensions

View File

@ -1,5 +1,5 @@
import { getBase64Async } from "../../utils.js"; import { getBase64Async } from "../../utils.js";
import { getContext, getApiUrl } from "../../extensions.js"; import { getContext, getApiUrl, doExtrasFetch } from "../../extensions.js";
export { MODULE_NAME }; export { MODULE_NAME };
const MODULE_NAME = 'caption'; const MODULE_NAME = 'caption';
@ -63,7 +63,7 @@ async function onSelectImage(e) {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/caption'; url.pathname = '/api/caption';
const apiResult = await fetch(url, { const apiResult = await doExtrasFetch(url, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',

View File

@ -1,5 +1,5 @@
import { callPopup, getRequestHeaders, saveSettingsDebounced } from "../../../script.js"; import { callPopup, getRequestHeaders, saveSettingsDebounced } from "../../../script.js";
import { getContext, getApiUrl, modules, extension_settings, ModuleWorkerWrapper } from "../../extensions.js"; import { getContext, getApiUrl, modules, extension_settings, ModuleWorkerWrapper, doExtrasFetch } from "../../extensions.js";
export { MODULE_NAME }; export { MODULE_NAME };
const MODULE_NAME = 'expressions'; const MODULE_NAME = 'expressions';
@ -122,7 +122,7 @@ async function moduleWorker() {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/classify'; url.pathname = '/api/classify';
const apiResult = await fetch(url, { const apiResult = await doExtrasFetch(url, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@ -265,7 +265,7 @@ async function getExpressionsList() {
url.pathname = '/api/classify/labels'; url.pathname = '/api/classify/labels';
try { try {
const apiResult = await fetch(url, { const apiResult = await doExtrasFetch(url, {
method: 'GET', method: 'GET',
headers: { 'Bypass-Tunnel-Reminder': 'bypass' }, headers: { 'Bypass-Tunnel-Reminder': 'bypass' },
}); });

View File

@ -1,6 +1,6 @@
import { saveSettingsDebounced, getCurrentChatId, system_message_types, eventSource, event_types } from "../../../script.js"; import { saveSettingsDebounced, getCurrentChatId, system_message_types, eventSource, event_types } from "../../../script.js";
import { humanizedDateTime } from "../../RossAscends-mods.js"; import { humanizedDateTime } from "../../RossAscends-mods.js";
import { getApiUrl, extension_settings, getContext } from "../../extensions.js"; import { getApiUrl, extension_settings, getContext, doExtrasFetch } from "../../extensions.js";
import { getFileText, onlyUnique, splitRecursive, IndexedDBStore } from "../../utils.js"; import { getFileText, onlyUnique, splitRecursive, IndexedDBStore } from "../../utils.js";
export { MODULE_NAME }; export { MODULE_NAME };
@ -174,7 +174,7 @@ async function addMessages(chat_id, messages) {
meta: JSON.stringify(m), meta: JSON.stringify(m),
})); }));
const addMessagesResult = await fetch(url, { const addMessagesResult = await doExtrasFetch(url, {
method: 'POST', method: 'POST',
headers: postHeaders, headers: postHeaders,
body: JSON.stringify({ chat_id, messages: transformedMessages }), body: JSON.stringify({ chat_id, messages: transformedMessages }),
@ -222,7 +222,7 @@ async function onPurgeClick() {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/chromadb/purge'; url.pathname = '/api/chromadb/purge';
const purgeResult = await fetch(url, { const purgeResult = await doExtrasFetch(url, {
method: 'POST', method: 'POST',
headers: postHeaders, headers: postHeaders,
body: JSON.stringify({ chat_id }), body: JSON.stringify({ chat_id }),
@ -242,7 +242,7 @@ async function onExportClick() {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/chromadb/export'; url.pathname = '/api/chromadb/export';
const exportResult = await fetch(url, { const exportResult = await doExtrasFetch(url, {
method: 'POST', method: 'POST',
headers: postHeaders, headers: postHeaders,
body: JSON.stringify({ chat_id: currentChatId }), body: JSON.stringify({ chat_id: currentChatId }),
@ -285,7 +285,7 @@ async function onSelectImportFile(e) {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/chromadb/import'; url.pathname = '/api/chromadb/import';
const importResult = await fetch(url, { const importResult = await doExtrasFetch(url, {
method: 'POST', method: 'POST',
headers: postHeaders, headers: postHeaders,
body: JSON.stringify(imported), body: JSON.stringify(imported),
@ -313,7 +313,7 @@ async function queryMessages(chat_id, query) {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/chromadb/query'; url.pathname = '/api/chromadb/query';
const queryMessagesResult = await fetch(url, { const queryMessagesResult = await doExtrasFetch(url, {
method: 'POST', method: 'POST',
headers: postHeaders, headers: postHeaders,
body: JSON.stringify({ chat_id, query, n_results: extension_settings.chromadb.n_results }), body: JSON.stringify({ chat_id, query, n_results: extension_settings.chromadb.n_results }),
@ -366,7 +366,7 @@ async function onSelectInjectFile(e) {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/chromadb'; url.pathname = '/api/chromadb';
const addMessagesResult = await fetch(url, { const addMessagesResult = await doExtrasFetch(url, {
method: 'POST', method: 'POST',
headers: postHeaders, headers: postHeaders,
body: JSON.stringify({ chat_id: currentChatId, messages: messages }), body: JSON.stringify({ chat_id: currentChatId, messages: messages }),

View File

@ -1,5 +1,5 @@
import { getStringHash, debounce } from "../../utils.js"; import { getStringHash, debounce } from "../../utils.js";
import { getContext, getApiUrl, extension_settings, ModuleWorkerWrapper } from "../../extensions.js"; import { getContext, getApiUrl, extension_settings, ModuleWorkerWrapper, doExtrasFetch } from "../../extensions.js";
import { extension_prompt_types, is_send_press, saveSettingsDebounced } from "../../../script.js"; import { extension_prompt_types, is_send_press, saveSettingsDebounced } from "../../../script.js";
export { MODULE_NAME }; export { MODULE_NAME };
@ -232,7 +232,7 @@ async function summarizeChat(context) {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/summarize'; url.pathname = '/api/summarize';
const apiResult = await fetch(url, { const apiResult = await doExtrasFetch(url, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',

View File

@ -10,7 +10,7 @@ import {
eventSource, eventSource,
appendImageToMessage appendImageToMessage
} from "../../../script.js"; } from "../../../script.js";
import { getApiUrl, getContext, extension_settings, defaultRequestArgs, modules } from "../../extensions.js"; import { getApiUrl, getContext, extension_settings, doExtrasFetch, modules } from "../../extensions.js";
import { stringFormat, initScrollHeight, resetScrollHeight } from "../../utils.js"; import { stringFormat, initScrollHeight, resetScrollHeight } from "../../utils.js";
export { MODULE_NAME }; export { MODULE_NAME };
@ -234,7 +234,7 @@ async function onModelChange() {
async function updateExtrasRemoteModel() { async function updateExtrasRemoteModel() {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/image/model'; url.pathname = '/api/image/model';
const getCurrentModelResult = await fetch(url, { const getCurrentModelResult = await doExtrasFetch(url, {
method: 'POST', method: 'POST',
headers: postHeaders, headers: postHeaders,
body: JSON.stringify({ model: extension_settings.sd.model }), body: JSON.stringify({ model: extension_settings.sd.model }),
@ -285,7 +285,7 @@ async function loadExtrasSamplers() {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/image/samplers'; url.pathname = '/api/image/samplers';
const result = await fetch(url, defaultRequestArgs); const result = await doExtrasFetch(url);
if (result.ok) { if (result.ok) {
const data = await result.json(); const data = await result.json();
@ -338,7 +338,7 @@ async function loadExtrasModels() {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/image/model'; url.pathname = '/api/image/model';
const getCurrentModelResult = await fetch(url, defaultRequestArgs); const getCurrentModelResult = await doExtrasFetch(url);
if (getCurrentModelResult.ok) { if (getCurrentModelResult.ok) {
const data = await getCurrentModelResult.json(); const data = await getCurrentModelResult.json();
@ -346,7 +346,7 @@ async function loadExtrasModels() {
} }
url.pathname = '/api/image/models'; url.pathname = '/api/image/models';
const getModelsResult = await fetch(url, defaultRequestArgs); const getModelsResult = await doExtrasFetch(url);
if (getModelsResult.ok) { if (getModelsResult.ok) {
const data = await getModelsResult.json(); const data = await getModelsResult.json();
@ -493,7 +493,7 @@ async function generateExtrasImage(prompt, callback) {
console.log(extension_settings.sd); console.log(extension_settings.sd);
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/image'; url.pathname = '/api/image';
const result = await fetch(url, { const result = await doExtrasFetch(url, {
method: 'POST', method: 'POST',
headers: postHeaders, headers: postHeaders,
body: JSON.stringify({ body: JSON.stringify({