Merge pull request #438 from bdashore3/dev

Extras: Add API authentication support
This commit is contained in:
Cohee
2023-06-03 14:52:42 +03:00
committed by GitHub
7 changed files with 50 additions and 25 deletions

View File

@@ -2038,7 +2038,8 @@
SillyTavern-extras SillyTavern-extras
</a> </a>
</h3> </h3>
<input id="extensions_url" type="text" class="text_pole" maxlength="250" /> <input id="extensions_url" type="text" class="text_pole" maxlength="250" placeholder="Extensions URL" />
<input id="extensions_api_key" type="text" class="text_pole" maxlength="250" placeholder="API key" />
<div class="extensions_url_block"> <div class="extensions_url_block">
<input id="extensions_connect" class="menu_button" type="submit" value="Connect" /> <input id="extensions_connect" class="menu_button" type="submit" value="Connect" />
<input id="extensions_details" class="menu_button" type="button" value="Manage extensions" /> <input id="extensions_details" class="menu_button" type="button" value="Manage extensions" />

View File

@@ -1,11 +1,11 @@
import { callPopup, eventSource, event_types, saveSettings, saveSettingsDebounced } from "../script.js"; import { callPopup, eventSource, event_types, extension_prompt_types, saveSettings, saveSettingsDebounced } from "../script.js";
import { isSubsetOf, debounce } from "./utils.js"; import { isSubsetOf, debounce } from "./utils.js";
export { export {
getContext, getContext,
getApiUrl, getApiUrl,
loadExtensionSettings, loadExtensionSettings,
runGenerationInterceptors, runGenerationInterceptors,
defaultRequestArgs, doExtrasFetch,
modules, modules,
extension_settings, extension_settings,
ModuleWorkerWrapper, ModuleWorkerWrapper,
@@ -43,6 +43,7 @@ class ModuleWorkerWrapper {
const extension_settings = { const extension_settings = {
apiUrl: defaultUrl, apiUrl: defaultUrl,
apiKey: '',
autoConnect: false, autoConnect: false,
disabledExtensions: [], disabledExtensions: [],
memory: {}, memory: {},
@@ -65,9 +66,29 @@ let activeExtensions = new Set();
const getContext = () => window['SillyTavern'].getContext(); const getContext = () => window['SillyTavern'].getContext();
const getApiUrl = () => extension_settings.apiUrl; const getApiUrl = () => extension_settings.apiUrl;
const defaultRequestArgs = { method: 'GET', headers: { 'Bypass-Tunnel-Reminder': 'bypass' } };
let connectedToApi = false; let connectedToApi = false;
async function doExtrasFetch(endpoint, args) {
if (!args) {
args = {}
}
if (!args.method) {
Object.assign(args, { method: 'GET' });
}
if (!args.headers) {
args.headers = {}
}
Object.assign(args.headers, {
'Authorization': `Bearer ${extension_settings.apiKey}`,
'Bypass-Tunnel-Reminder': 'bypass'
});
const response = await fetch(endpoint, args);
return response;
}
async function discoverExtensions() { async function discoverExtensions() {
try { try {
const response = await fetch('/discover_extensions'); const response = await fetch('/discover_extensions');
@@ -178,6 +199,8 @@ async function activateExtensions() {
async function connectClickHandler() { async function connectClickHandler() {
const baseUrl = $("#extensions_url").val(); const baseUrl = $("#extensions_url").val();
extension_settings.apiUrl = baseUrl; extension_settings.apiUrl = baseUrl;
const testApiKey = $("#extensions_api_key").val();
extension_settings.apiKey = testApiKey;
saveSettingsDebounced(); saveSettingsDebounced();
await connectToApi(baseUrl); await connectToApi(baseUrl);
} }
@@ -233,7 +256,7 @@ async function connectToApi(baseUrl) {
url.pathname = '/api/modules'; url.pathname = '/api/modules';
try { try {
const getExtensionsResult = await fetch(url, defaultRequestArgs); const getExtensionsResult = await doExtrasFetch(url);
if (getExtensionsResult.ok) { if (getExtensionsResult.ok) {
const data = await getExtensionsResult.json(); const data = await getExtensionsResult.json();
@@ -352,6 +375,7 @@ async function loadExtensionSettings(settings) {
} }
$("#extensions_url").val(extension_settings.apiUrl); $("#extensions_url").val(extension_settings.apiUrl);
$("#extensions_api_key").val(extension_settings.apiKey);
$("#extensions_autoconnect").prop('checked', extension_settings.autoConnect); $("#extensions_autoconnect").prop('checked', extension_settings.autoConnect);
// Activate offline extensions // Activate offline extensions

View File

@@ -1,5 +1,5 @@
import { getBase64Async } from "../../utils.js"; import { getBase64Async } from "../../utils.js";
import { getContext, getApiUrl } from "../../extensions.js"; import { getContext, getApiUrl, doExtrasFetch } from "../../extensions.js";
export { MODULE_NAME }; export { MODULE_NAME };
const MODULE_NAME = 'caption'; const MODULE_NAME = 'caption';
@@ -63,7 +63,7 @@ async function onSelectImage(e) {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/caption'; url.pathname = '/api/caption';
const apiResult = await fetch(url, { const apiResult = await doExtrasFetch(url, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',

View File

@@ -1,5 +1,5 @@
import { callPopup, getRequestHeaders, saveSettingsDebounced } from "../../../script.js"; import { callPopup, getRequestHeaders, saveSettingsDebounced } from "../../../script.js";
import { getContext, getApiUrl, modules, extension_settings, ModuleWorkerWrapper } from "../../extensions.js"; import { getContext, getApiUrl, modules, extension_settings, ModuleWorkerWrapper, doExtrasFetch } from "../../extensions.js";
export { MODULE_NAME }; export { MODULE_NAME };
const MODULE_NAME = 'expressions'; const MODULE_NAME = 'expressions';
@@ -122,7 +122,7 @@ async function moduleWorker() {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/classify'; url.pathname = '/api/classify';
const apiResult = await fetch(url, { const apiResult = await doExtrasFetch(url, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@@ -265,7 +265,7 @@ async function getExpressionsList() {
url.pathname = '/api/classify/labels'; url.pathname = '/api/classify/labels';
try { try {
const apiResult = await fetch(url, { const apiResult = await doExtrasFetch(url, {
method: 'GET', method: 'GET',
headers: { 'Bypass-Tunnel-Reminder': 'bypass' }, headers: { 'Bypass-Tunnel-Reminder': 'bypass' },
}); });

View File

@@ -1,6 +1,6 @@
import { saveSettingsDebounced, getCurrentChatId, system_message_types, eventSource, event_types } from "../../../script.js"; import { saveSettingsDebounced, getCurrentChatId, system_message_types, eventSource, event_types } from "../../../script.js";
import { humanizedDateTime } from "../../RossAscends-mods.js"; import { humanizedDateTime } from "../../RossAscends-mods.js";
import { getApiUrl, extension_settings, getContext } from "../../extensions.js"; import { getApiUrl, extension_settings, getContext, doExtrasFetch } from "../../extensions.js";
import { getFileText, onlyUnique, splitRecursive, IndexedDBStore } from "../../utils.js"; import { getFileText, onlyUnique, splitRecursive, IndexedDBStore } from "../../utils.js";
export { MODULE_NAME }; export { MODULE_NAME };
@@ -174,7 +174,7 @@ async function addMessages(chat_id, messages) {
meta: JSON.stringify(m), meta: JSON.stringify(m),
})); }));
const addMessagesResult = await fetch(url, { const addMessagesResult = await doExtrasFetch(url, {
method: 'POST', method: 'POST',
headers: postHeaders, headers: postHeaders,
body: JSON.stringify({ chat_id, messages: transformedMessages }), body: JSON.stringify({ chat_id, messages: transformedMessages }),
@@ -222,7 +222,7 @@ async function onPurgeClick() {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/chromadb/purge'; url.pathname = '/api/chromadb/purge';
const purgeResult = await fetch(url, { const purgeResult = await doExtrasFetch(url, {
method: 'POST', method: 'POST',
headers: postHeaders, headers: postHeaders,
body: JSON.stringify({ chat_id }), body: JSON.stringify({ chat_id }),
@@ -242,7 +242,7 @@ async function onExportClick() {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/chromadb/export'; url.pathname = '/api/chromadb/export';
const exportResult = await fetch(url, { const exportResult = await doExtrasFetch(url, {
method: 'POST', method: 'POST',
headers: postHeaders, headers: postHeaders,
body: JSON.stringify({ chat_id: currentChatId }), body: JSON.stringify({ chat_id: currentChatId }),
@@ -285,7 +285,7 @@ async function onSelectImportFile(e) {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/chromadb/import'; url.pathname = '/api/chromadb/import';
const importResult = await fetch(url, { const importResult = await doExtrasFetch(url, {
method: 'POST', method: 'POST',
headers: postHeaders, headers: postHeaders,
body: JSON.stringify(imported), body: JSON.stringify(imported),
@@ -313,7 +313,7 @@ async function queryMessages(chat_id, query) {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/chromadb/query'; url.pathname = '/api/chromadb/query';
const queryMessagesResult = await fetch(url, { const queryMessagesResult = await doExtrasFetch(url, {
method: 'POST', method: 'POST',
headers: postHeaders, headers: postHeaders,
body: JSON.stringify({ chat_id, query, n_results: extension_settings.chromadb.n_results }), body: JSON.stringify({ chat_id, query, n_results: extension_settings.chromadb.n_results }),
@@ -366,7 +366,7 @@ async function onSelectInjectFile(e) {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/chromadb'; url.pathname = '/api/chromadb';
const addMessagesResult = await fetch(url, { const addMessagesResult = await doExtrasFetch(url, {
method: 'POST', method: 'POST',
headers: postHeaders, headers: postHeaders,
body: JSON.stringify({ chat_id: currentChatId, messages: messages }), body: JSON.stringify({ chat_id: currentChatId, messages: messages }),

View File

@@ -1,5 +1,5 @@
import { getStringHash, debounce } from "../../utils.js"; import { getStringHash, debounce } from "../../utils.js";
import { getContext, getApiUrl, extension_settings, ModuleWorkerWrapper } from "../../extensions.js"; import { getContext, getApiUrl, extension_settings, ModuleWorkerWrapper, doExtrasFetch } from "../../extensions.js";
import { extension_prompt_types, is_send_press, saveSettingsDebounced } from "../../../script.js"; import { extension_prompt_types, is_send_press, saveSettingsDebounced } from "../../../script.js";
export { MODULE_NAME }; export { MODULE_NAME };
@@ -232,7 +232,7 @@ async function summarizeChat(context) {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/summarize'; url.pathname = '/api/summarize';
const apiResult = await fetch(url, { const apiResult = await doExtrasFetch(url, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',

View File

@@ -10,7 +10,7 @@ import {
eventSource, eventSource,
appendImageToMessage appendImageToMessage
} from "../../../script.js"; } from "../../../script.js";
import { getApiUrl, getContext, extension_settings, defaultRequestArgs, modules } from "../../extensions.js"; import { getApiUrl, getContext, extension_settings, doExtrasFetch, modules } from "../../extensions.js";
import { stringFormat, initScrollHeight, resetScrollHeight } from "../../utils.js"; import { stringFormat, initScrollHeight, resetScrollHeight } from "../../utils.js";
export { MODULE_NAME }; export { MODULE_NAME };
@@ -234,7 +234,7 @@ async function onModelChange() {
async function updateExtrasRemoteModel() { async function updateExtrasRemoteModel() {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/image/model'; url.pathname = '/api/image/model';
const getCurrentModelResult = await fetch(url, { const getCurrentModelResult = await doExtrasFetch(url, {
method: 'POST', method: 'POST',
headers: postHeaders, headers: postHeaders,
body: JSON.stringify({ model: extension_settings.sd.model }), body: JSON.stringify({ model: extension_settings.sd.model }),
@@ -285,7 +285,7 @@ async function loadExtrasSamplers() {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/image/samplers'; url.pathname = '/api/image/samplers';
const result = await fetch(url, defaultRequestArgs); const result = await doExtrasFetch(url);
if (result.ok) { if (result.ok) {
const data = await result.json(); const data = await result.json();
@@ -338,7 +338,7 @@ async function loadExtrasModels() {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/image/model'; url.pathname = '/api/image/model';
const getCurrentModelResult = await fetch(url, defaultRequestArgs); const getCurrentModelResult = await doExtrasFetch(url);
if (getCurrentModelResult.ok) { if (getCurrentModelResult.ok) {
const data = await getCurrentModelResult.json(); const data = await getCurrentModelResult.json();
@@ -346,7 +346,7 @@ async function loadExtrasModels() {
} }
url.pathname = '/api/image/models'; url.pathname = '/api/image/models';
const getModelsResult = await fetch(url, defaultRequestArgs); const getModelsResult = await doExtrasFetch(url);
if (getModelsResult.ok) { if (getModelsResult.ok) {
const data = await getModelsResult.json(); const data = await getModelsResult.json();
@@ -493,7 +493,7 @@ async function generateExtrasImage(prompt, callback) {
console.log(extension_settings.sd); console.log(extension_settings.sd);
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/image'; url.pathname = '/api/image';
const result = await fetch(url, { const result = await doExtrasFetch(url, {
method: 'POST', method: 'POST',
headers: postHeaders, headers: postHeaders,
body: JSON.stringify({ body: JSON.stringify({