diff --git a/colab/GPU.ipynb b/colab/GPU.ipynb index e82f9891b..44ec0e8cd 100644 --- a/colab/GPU.ipynb +++ b/colab/GPU.ipynb @@ -307,15 +307,28 @@ " %cd /SillyTavern\n", " !npm install\n", " !npm install -g localtunnel\n", + " !npm install -g forever\n", + " !pip install flask-cloudflared\n", "ii.addTask(\"Install Tavern Dependencies\", installTavernDependencies)\n", "ii.run()\n", "\n", "%env colaburl=$url\n", "%env SILLY_TAVERN_PORT=5001\n", + "from flask_cloudflared import start_cloudflared\n", + "!sed -i 's/listen = true/listen = false/g' config.conf\n", + "!touch stdout.log stderr.log\n", + "!forever start -o stdout.log -e stderr.log server.js\n", "print(\"KoboldAI LINK:\", url, '###Extensions API LINK###', globals.extras_url, \"###SillyTavern LINK###\", sep=\"\\n\")\n", - "p = subprocess.Popen([\"lt\", \"--port\", \"5001\"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n", - "print(p.stdout.readline().decode().strip())\n", - "!node server.js" + "import inspect\n", + "import random\n", + "sig = inspect.signature(start_cloudflared)\n", + "sum = sum(1 for param in sig.parameters.values() if param.kind == param.POSITIONAL_OR_KEYWORD)\n", + "if sum > 1:\n", + " metrics_port = random.randint(8100, 9000)\n", + " start_cloudflared(5001, metrics_port)\n", + "else:\n", + " start_cloudflared(5001)\n", + "!tail -f stdout.log stderr.log" ] } ], diff --git a/config.conf b/config.conf index 8cb18a2b2..77c6b85ad 100644 --- a/config.conf +++ b/config.conf @@ -2,10 +2,12 @@ const port = 8000; const whitelist = ['127.0.0.1']; //Example for add several IP in whitelist: ['127.0.0.1', '192.168.0.10'] const whitelistMode = true; //Disabling enabling the ip whitelist mode. true/false +const basicAuthMode = false; //Toggle basic authentication for endpoints. +const basicAuthUser = {username: "user", password: "password"}; //Login credentials when basicAuthMode is true. const autorun = true; //Autorun in the browser. true/false const enableExtensions = true; //Enables support for TavernAI-extras project const listen = true; // If true, Can be access from other device or PC. otherwise can be access only from hosting machine. module.exports = { - port, whitelist, whitelistMode, autorun, enableExtensions, listen + port, whitelist, whitelistMode, basicAuthMode, basicAuthUser, autorun, enableExtensions, listen }; diff --git a/public/script.js b/public/script.js index 7f06a2fe0..7ddda2ace 100644 --- a/public/script.js +++ b/public/script.js @@ -1319,6 +1319,7 @@ async function Generate(type, automatic_trigger, force_name2) { //console.log('Generate entered'); setGenerationProgress(0); tokens_already_generated = 0; + const isImpersonate = type == "impersonate"; message_already_generated = isImpersonate ? `${name1}: ` : `${name2}: `; @@ -1338,8 +1339,7 @@ async function Generate(type, automatic_trigger, force_name2) { if (isStreamingEnabled()) { streamingProcessor = new StreamingProcessor(type, force_name2); hideSwipeButtons(); - } - else { + } else { streamingProcessor = false; } @@ -1349,15 +1349,16 @@ async function Generate(type, automatic_trigger, force_name2) { } if (online_status != 'no_connection' && this_chid != undefined && this_chid !== 'invalid-safety-id') { + let textareaText; if (type !== 'regenerate' && type !== "swipe" && !isImpersonate) { is_send_press = true; - var textareaText = $("#send_textarea").val(); + textareaText = $("#send_textarea").val(); //console.log('Not a Regenerate call, so posting normall with input of: ' +textareaText); $("#send_textarea").val('').trigger('input'); } else { //console.log('Regenerate call detected') - var textareaText = ""; + textareaText = ""; if (chat.length && chat[chat.length - 1]['is_user']) {//If last message from You } @@ -1387,32 +1388,27 @@ async function Generate(type, automatic_trigger, force_name2) { } // bias from the latest message is top priority// - promptBias = messageBias ?? promptBias ?? ''; - var storyString = ""; - var userSendString = ""; - var finalPromt = ""; - var postAnchorChar = "Elaborate speaker"; - var postAnchorStyle = "Writing style: very long messages";//"[Genre: roleplay chat][Tone: very long messages with descriptions]"; - var anchorTop = ''; - var anchorBottom = ''; - var topAnchorDepth = 8; - - if (character_anchor && !is_pygmalion) { + // Compute anchors + const topAnchorDepth = 8; + let anchorTop = ''; + let anchorBottom = ''; + if (!is_pygmalion) { console.log('saw not pyg'); + + let postAnchorChar = character_anchor ? name2 + " Elaborate speaker" : ""; + let postAnchorStyle = style_anchor ? "Writing style: very long messages" : ""; if (anchor_order === 0) { - anchorTop = name2 + " " + postAnchorChar; - } else { - console.log('saw pyg, adding anchors') - anchorBottom = "[" + name2 + " " + postAnchorChar + "]"; - } - } - if (style_anchor && !is_pygmalion) { - if (anchor_order === 1) { + anchorTop = postAnchorChar; + anchorBottom = postAnchorStyle; + } else { // anchor_order === 1 anchorTop = postAnchorStyle; - } else { - anchorBottom = "[" + postAnchorStyle + "]"; + anchorBottom = postAnchorChar; + } + + if (anchorBottom) { + anchorBottom = "[" + anchorBottom + "]"; } } @@ -1436,22 +1432,18 @@ async function Generate(type, automatic_trigger, force_name2) { addOneMessage(chat[chat.length - 1]); } //////////////////////////////////// - let chatString = ''; - let arrMes = []; - let mesSend = []; let charDescription = baseChatReplace($.trim(characters[this_chid].description), name1, name2); let charPersonality = baseChatReplace($.trim(characters[this_chid].personality), name1, name2); let Scenario = baseChatReplace($.trim(characters[this_chid].scenario), name1, name2); let mesExamples = baseChatReplace($.trim(characters[this_chid].mes_example), name1, name2); + // Parse example messages if (!mesExamples.startsWith('')) { mesExamples = '\n' + mesExamples.trim(); } - if (mesExamples.replace(//gi, '').trim().length === 0) { mesExamples = ''; } - let mesExamplesArray = mesExamples.split(//gi).slice(1).map(block => `\n${block.trim()}\n`); if (main_api === 'openai') { @@ -1465,6 +1457,8 @@ async function Generate(type, automatic_trigger, force_name2) { setOpenAIMessageExamples(mesExamplesArray); } + let storyString = ""; + if (is_pygmalion) { storyString += appendToStoryString(charDescription, power_user.disable_description_formatting ? '' : name2 + "'s Persona: "); storyString += appendToStoryString(charPersonality, power_user.disable_personality_formatting ? '' : 'Personality: '); @@ -1504,12 +1498,10 @@ async function Generate(type, automatic_trigger, force_name2) { ////////////////////////////////// - var count_exm_add = 0; - //console.log('emptying chat2'); - var chat2 = []; - var j = 0; - //console.log('pre-replace chat.length = ' + chat.length); - for (var i = chat.length - 1; i >= 0; i--) { + console.log('emptying chat2'); + let chat2 = []; + console.log('pre-replace chat.length = ' + chat.length); + for (let i = chat.length - 1, j = 0; i >= 0; i--, j++) { let charName = selected_group ? chat[j].name : name2; if (j == 0) { chat[j]['mes'] = chat[j]['mes'].replace(/{{user}}/gi, name1); @@ -1537,11 +1529,12 @@ async function Generate(type, automatic_trigger, force_name2) { //chat2[i] = (chat2[i] ?? '').replace(/{.*}/g, ''); chat2[i] = (chat2[i] ?? '').replace(/{{(\*?.+?\*?)}}/g, ''); //console.log('replacing chat2 {}s'); - j++; } //console.log('post replace chat.length = ' + chat.length); //chat2 = chat2.reverse(); - var this_max_context = 1487; + + // Determine token limit + let this_max_context = 1487; if (main_api == 'kobold' || main_api == 'textgenerationwebui') { this_max_context = (max_context - amount_gen); } @@ -1558,11 +1551,11 @@ async function Generate(type, automatic_trigger, force_name2) { if (main_api == 'openai') { this_max_context = oai_settings.openai_max_context; } - if (main_api == 'poe') { this_max_context = Number(max_context); } + // Adjust token limit for Horde let hordeAmountGen = null; if (main_api == 'kobold' && horde_settings.use_horde && horde_settings.auto_adjust) { let adjustedParams; @@ -1591,15 +1584,17 @@ async function Generate(type, automatic_trigger, force_name2) { let { worldInfoString, worldInfoBefore, worldInfoAfter } = getWorldInfoPrompt(chat2); - //console.log('post swipe shift:' + chat2.length); - var i = 0; + console.log('post swipe shift:' + chat2.length); // hack for regeneration of the first message if (chat2.length == 0) { chat2.push(''); } - for (var item of chat2) { + // Collect enough messages to fill the context + let chatString = ''; + let arrMes = []; + for (let item of chat2) { chatString = item + chatString; const encodeString = JSON.stringify( worldInfoString + storyString + chatString + @@ -1611,47 +1606,42 @@ async function Generate(type, automatic_trigger, force_name2) { //if (is_pygmalion && i == chat2.length-1) item='\n'+item; arrMes[arrMes.length] = item; } else { - //console.log('reducing chat.length by 1'); - i = chat2.length - 1; + break; } await delay(1); //For disable slow down (encode gpt-2 need fix) - // console.log(i+' '+chat.length); - - count_exm_add = 0; - - if (i === chat2.length - 1) { - if (!power_user.pin_examples) { - let mesExmString = ''; - for (let iii = 0; iii < mesExamplesArray.length; iii++) { - mesExmString += mesExamplesArray[iii]; - const prompt = JSON.stringify(worldInfoString + storyString + mesExmString + chatString + anchorTop + anchorBottom + charPersonality + promptBias + allAnchors); - const tokenCount = getTokenCount(prompt, padding_tokens); - if (tokenCount < this_max_context) { - if (power_user.disable_examples_formatting) { - mesExamplesArray[iii] = mesExamplesArray[iii].replace(//i, ''); - } - - if (!is_pygmalion) { - mesExamplesArray[iii] = mesExamplesArray[iii].replace(//i, `This is how ${name2} should talk`); - } - count_exm_add++; - await delay(1); - } else { - iii = mesExamplesArray.length; - } - } - } - if (!is_pygmalion && Scenario && Scenario.length > 0) { - storyString += !power_user.disable_scenario_formatting ? `Circumstances and context of the dialogue: ${Scenario}\n` : `${Scenario}\n`; - } - console.log('calling runGenerate'); - await runGenerate(); - return; - } - i++; } + // Prepare unpinned example messages + let count_exm_add = 0; + if (!power_user.pin_examples) { + let mesExmString = ''; + for (let i = 0; i < mesExamplesArray.length; i++) { + mesExmString += mesExamplesArray[i]; + const prompt = JSON.stringify(worldInfoString + storyString + mesExmString + chatString + anchorTop + anchorBottom + charPersonality + promptBias + allAnchors); + const tokenCount = getTokenCount(prompt, padding_tokens); + if (tokenCount < this_max_context) { + if (power_user.disable_examples_formatting) { + mesExamplesArray[i] = mesExamplesArray[i].replace(//i, ''); + } else if (!is_pygmalion) { + mesExamplesArray[i] = mesExamplesArray[i].replace(//i, `This is how ${name2} should talk`); + } + count_exm_add++; + await delay(1); + } else { + break; + } + } + } + + if (!is_pygmalion && Scenario && Scenario.length > 0) { + storyString += !power_user.disable_scenario_formatting ? `Circumstances and context of the dialogue: ${Scenario}\n` : `${Scenario}\n`; + } + + let mesSend = []; + console.log('calling runGenerate'); + await runGenerate(); + async function runGenerate(cycleGenerationPromt = '') { is_send_press = true; @@ -1664,7 +1654,7 @@ async function Generate(type, automatic_trigger, force_name2) { console.log('generating prompt'); chatString = ""; arrMes = arrMes.reverse(); - var is_add_personality = false; + let is_add_personality = false; arrMes.forEach(function (item, i, arr) {//For added anchors and others if (i >= arrMes.length - 1 && $.trim(item).substr(0, (name1 + ":").length) != name1 + ":") { @@ -1673,7 +1663,6 @@ async function Generate(type, automatic_trigger, force_name2) { } } if (i === arrMes.length - topAnchorDepth && count_view_mes >= topAnchorDepth && !is_add_personality) { - is_add_personality = true; //chatString = chatString.substr(0,chatString.length-1); //anchorAndPersonality = "[Genre: roleplay chat][Tone: very long messages with descriptions]"; @@ -1798,7 +1787,7 @@ async function Generate(type, automatic_trigger, force_name2) { mesSendString = '\n' + mesSendString; //mesSendString = mesSendString; //This edit simply removes the first "" that is prepended to all context prompts } - finalPromt = worldInfoBefore + storyString + worldInfoAfter + afterScenarioAnchor + mesExmString + mesSendString + generatedPromtCache + promptBias; + let finalPromt = worldInfoBefore + storyString + worldInfoAfter + afterScenarioAnchor + mesExmString + mesSendString + generatedPromtCache + promptBias; if (zeroDepthAnchor && zeroDepthAnchor.length) { if (!isMultigenEnabled() || tokens_already_generated == 0) { @@ -1857,9 +1846,9 @@ async function Generate(type, automatic_trigger, force_name2) { this_amount_gen = Math.min(this_amount_gen, hordeAmountGen); } - var generate_data; + let generate_data; if (main_api == 'kobold') { - var generate_data = { + generate_data = { prompt: finalPromt, gui_settings: true, max_length: amount_gen, @@ -1927,7 +1916,7 @@ async function Generate(type, automatic_trigger, force_name2) { }; } - var generate_url = ''; + let generate_url = ''; if (main_api == 'kobold') { generate_url = '/generate'; } else if (main_api == 'textgenerationwebui') { diff --git a/server.js b/server.js index 7568c521e..c4ace2273 100644 --- a/server.js +++ b/server.js @@ -35,6 +35,7 @@ const rimraf = require("rimraf"); const multer = require("multer"); const http = require("http"); const https = require('https'); +const basicAuthMiddleware = require('./src/middleware/basicAuthMiddleware'); //const PNG = require('pngjs').PNG; const extract = require('png-chunks-extract'); const encode = require('png-chunks-encode'); @@ -194,6 +195,8 @@ const CORS = cors({ app.use(CORS); +if (listen && config.basicAuthMode) app.use(basicAuthMiddleware); + app.use(function (req, res, next) { //Security let clientIp = req.connection.remoteAddress; let ip = ipaddr.parse(clientIp); @@ -2419,6 +2422,10 @@ const setupTasks = async function () { if (autorun) open(autorunUrl.toString()); console.log('SillyTavern is listening on: ' + tavernUrl); + if (listen && + !config.whitelistMode && + !config.basicAuthMode) + console.log('Your SillyTavern is currently open to the public. To increase security, consider enabling whitelisting or basic authentication.') if (fs.existsSync('public/characters/update.txt') && !is_colab) { convertStage1(); diff --git a/src/middleware/basicAuthMiddleware.js b/src/middleware/basicAuthMiddleware.js new file mode 100644 index 000000000..2f368214c --- /dev/null +++ b/src/middleware/basicAuthMiddleware.js @@ -0,0 +1,39 @@ +/** + * When applied, this middleware will ensure the request contains the required header for basic authentication and only + * allow access to the endpoint after successful authentication. + */ + +const {dirname} = require('path'); +const appDir = dirname(require.main.filename); +const config = require(appDir + '/config.conf'); + +const unauthorizedResponse = (res) => { + res.set('WWW-Authenticate', 'Basic realm="SillyTavern", charset="UTF-8"'); + return res.status(401).send('Authentication required'); +}; + +const basicAuthMiddleware = function (request, response, callback) { + const authHeader = request.headers.authorization; + + if (!authHeader) { + return unauthorizedResponse(response); + } + + const [scheme, credentials] = authHeader.split(' '); + + if (scheme !== 'Basic' || !credentials) { + return unauthorizedResponse(response); + } + + const [username, password] = Buffer.from(credentials, 'base64') + .toString('utf8') + .split(':'); + + if (username === config.basicAuthUser.username && password === config.basicAuthUser.password) { + return callback(); + } else { + return unauthorizedResponse(response); + } +} + +module.exports = basicAuthMiddleware; \ No newline at end of file