diff --git a/poe-client.js b/poe-client.js index cb31323a6..ada79c402 100644 --- a/poe-client.js +++ b/poe-client.js @@ -259,6 +259,7 @@ class Client { constructor(auto_reconnect = false, use_cached_bots = false) { this.auto_reconnect = auto_reconnect; this.use_cached_bots = use_cached_bots; + this.abortController = new AbortController(); } async init(token, proxy = null) { @@ -267,6 +268,7 @@ class Client { timeout: 60000, httpAgent: new http.Agent({ keepAlive: true }), httpsAgent: new https.Agent({ keepAlive: true }), + signal: this.abortController.signal, }); if (proxy) { this.session.defaults.proxy = { @@ -544,6 +546,8 @@ class Client { let messageId; while (true) { try { + this.abortController.signal.throwIfAborted(); + const message = this.message_queues[humanMessageId].shift(); if (!message) { await new Promise(resolve => setTimeout(() => resolve(), 1000)); diff --git a/public/index.html b/public/index.html index f11c83841..09cfe0d9e 100644 --- a/public/index.html +++ b/public/index.html @@ -116,6 +116,7 @@ + diff --git a/public/script.js b/public/script.js index e0bca27bb..7d1fdb94a 100644 --- a/public/script.js +++ b/public/script.js @@ -1271,6 +1271,7 @@ class StreamingProcessor { this.isStopped = false; this.isFinished = false; this.generator = this.nullStreamingGeneration; + this.abortController = new AbortController(); } async generate() { @@ -1925,7 +1926,7 @@ async function Generate(type, automatic_trigger, force_name2) { let prompt = await prepareOpenAIMessages(name2, storyString, worldInfoBefore, worldInfoAfter, afterScenarioAnchor, promptBias, type); if (isStreamingEnabled()) { - streamingProcessor.generator = await sendOpenAIRequest(prompt); + streamingProcessor.generator = await sendOpenAIRequest(prompt, streamingProcessor.abortController.signal); } else { sendOpenAIRequest(prompt).then(onSuccess).catch(onError); @@ -1936,14 +1937,14 @@ async function Generate(type, automatic_trigger, force_name2) { } else if (main_api == 'poe') { if (isStreamingEnabled()) { - streamingProcessor.generator = await generatePoe(type, finalPromt); + streamingProcessor.generator = await generatePoe(type, finalPromt, streamingProcessor.abortController.signal); } else { generatePoe(type, finalPromt).then(onSuccess).catch(onError); } } else if (main_api == 'textgenerationwebui' && textgenerationwebui_settings.streaming) { - streamingProcessor.generator = await generateTextGenWithStreaming(generate_data); + streamingProcessor.generator = await generateTextGenWithStreaming(generate_data, streamingProcessor.abortController.signal); } else { jQuery.ajax({ @@ -5011,6 +5012,7 @@ $(document).ready(function () { $(document).on("click", ".mes_stop", function () { if (streamingProcessor) { + streamingProcessor.abortController.abort(); streamingProcessor.isStopped = true; streamingProcessor.onStopStreaming(); streamingProcessor = null; @@ -5104,4 +5106,11 @@ $(document).ready(function () { } }); }); + + $(document).on('beforeunload', () => { + if (streamingProcessor) { + console.log('Page reloaded. Aborting streaming...'); + streamingProcessor.abortController.abort(); + } + }); }) diff --git a/public/scripts/extensions/caption/index.js b/public/scripts/extensions/caption/index.js index dda9405e0..94c7bd049 100644 --- a/public/scripts/extensions/caption/index.js +++ b/public/scripts/extensions/caption/index.js @@ -91,12 +91,6 @@ async function onSelectImage(e) { } $(document).ready(function () { - function patchSendForm() { - const columns = $('#send_form').css('grid-template-columns').split(' '); - columns[columns.length - 1] = `${parseInt(columns[columns.length - 1]) + 40}px`; - columns[1] = 'auto'; - $('#send_form').css('grid-template-columns', columns.join(' ')); - } function addSendPictureButton() { const sendButton = document.createElement('div'); sendButton.id = 'send_picture'; @@ -118,7 +112,6 @@ $(document).ready(function () { addPictureSendForm(); addSendPictureButton(); setImageIcon(); - patchSendForm(); moduleWorker(); setInterval(moduleWorker, UPDATE_INTERVAL); }); \ No newline at end of file diff --git a/public/scripts/extensions/dice/index.js b/public/scripts/extensions/dice/index.js index 14922fc9e..3a5538dbe 100644 --- a/public/scripts/extensions/dice/index.js +++ b/public/scripts/extensions/dice/index.js @@ -79,13 +79,6 @@ function addDiceScript() { } } -function patchSendForm() { - const columns = $('#send_form').css('grid-template-columns').split(' '); - columns[columns.length - 1] = `${parseInt(columns[columns.length - 1]) + 40}px`; - columns[1] = 'auto'; - $('#send_form').css('grid-template-columns', columns.join(' ')); -} - async function moduleWorker() { const context = getContext(); @@ -97,7 +90,6 @@ async function moduleWorker() { $(document).ready(function () { addDiceScript(); addDiceRollButton(); - patchSendForm(); setDiceIcon(); moduleWorker(); setInterval(moduleWorker, UPDATE_INTERVAL); diff --git a/public/scripts/openai.js b/public/scripts/openai.js index 56c2f2a3a..5afdbaf6d 100644 --- a/public/scripts/openai.js +++ b/public/scripts/openai.js @@ -55,6 +55,7 @@ const default_impersonation_prompt = "[Write your next reply from the point of v const gpt3_max = 4095; const gpt4_max = 8191; +const gpt4_32k_max = 32767; const tokenCache = {}; @@ -435,7 +436,12 @@ function getSystemPrompt(nsfw_toggle_prompt, enhance_definitions_prompt, wiBefor return whole_prompt; } -async function sendOpenAIRequest(openai_msgs_tosend) { +async function sendOpenAIRequest(openai_msgs_tosend, signal) { + // Provide default abort signal + if (!signal) { + signal = new AbortController().signal; + } + if (oai_settings.reverse_proxy) { validateReverseProxy(); } @@ -458,7 +464,8 @@ async function sendOpenAIRequest(openai_msgs_tosend) { headers: { 'Content-Type': 'application/json', "X-CSRF-Token": token, - } + }, + signal: signal, }); if (oai_settings.stream_openai) { @@ -772,6 +779,9 @@ $(document).ready(function () { if (value == 'gpt-4') { $('#openai_max_context').attr('max', gpt4_max); } + else if (value == 'gpt-4-32k') { + $('#openai_max_context').attr('max', gpt4_32k_max); + } else { $('#openai_max_context').attr('max', gpt3_max); oai_settings.openai_max_context = Math.max(oai_settings.openai_max_context, gpt3_max); diff --git a/public/scripts/poe.js b/public/scripts/poe.js index 5763f0b0c..dc5a84183 100644 --- a/public/scripts/poe.js +++ b/public/scripts/poe.js @@ -86,7 +86,7 @@ function onBotChange() { saveSettingsDebounced(); } -async function generatePoe(type, finalPrompt) { +async function generatePoe(type, finalPrompt, signal) { if (poe_settings.auto_purge) { let count_to_delete = -1; @@ -136,7 +136,7 @@ async function generatePoe(type, finalPrompt) { finalPrompt = sentences.join(''); } - const reply = await sendMessage(finalPrompt, true); + const reply = await sendMessage(finalPrompt, true, signal); got_reply = true; return reply; } @@ -160,7 +160,11 @@ async function purgeConversation(count = -1) { return response.ok; } -async function sendMessage(prompt, withStreaming) { +async function sendMessage(prompt, withStreaming, signal) { + if (!signal) { + signal = new AbortController().signal; + } + const body = JSON.stringify({ bot: poe_settings.bot, token: poe_settings.token, @@ -175,6 +179,7 @@ async function sendMessage(prompt, withStreaming) { }, body: body, method: 'POST', + signal: signal, }); if (withStreaming && poe_settings.streaming) { diff --git a/public/scripts/textgen-settings.js b/public/scripts/textgen-settings.js index 687bdf7f8..358ddeb66 100644 --- a/public/scripts/textgen-settings.js +++ b/public/scripts/textgen-settings.js @@ -147,7 +147,7 @@ function setSettingByName(i, value, trigger) { } } -async function generateTextGenWithStreaming(generate_data) { +async function generateTextGenWithStreaming(generate_data, signal) { const response = await fetch('/generate_textgenerationwebui', { headers: { 'X-CSRF-Token': token, @@ -157,6 +157,7 @@ async function generateTextGenWithStreaming(generate_data) { }, body: JSON.stringify(generate_data), method: 'POST', + signal: signal, }); return async function* streamData() { diff --git a/public/style.css b/public/style.css index 6c3fbd50c..173c1d095 100644 --- a/public/style.css +++ b/public/style.css @@ -279,9 +279,8 @@ code { } #send_form { - display: grid; + display: flex; align-items: center; - grid-template-columns: 40px auto 40px; width: 100%; margin: 0 auto 0 auto; border: 1px solid var(--grey30a); @@ -644,6 +643,7 @@ select { font-family: "Noto Sans", "Noto Color Emoji", sans-serif; margin: 0; text-shadow: #000 0 0 3px; + flex: 1; } #send_textarea::placeholder, diff --git a/server.js b/server.js index c42ac630a..7c56516ec 100644 --- a/server.js +++ b/server.js @@ -367,6 +367,10 @@ app.post("/generate_textgenerationwebui", jsonParser, async function (request, r if (!!request.header('X-Response-Streaming')) { const fn_index = Number(request.header('X-Gradio-Streaming-Function')); + let isStreamingStopped = false; + request.socket.on('close', function() { + isStreamingStopped = true; + }); response_generate.writeHead(200, { 'Content-Type': 'text/plain;charset=utf-8', @@ -404,6 +408,12 @@ app.post("/generate_textgenerationwebui", jsonParser, async function (request, r }); while (true) { + if (isStreamingStopped) { + console.error('Streaming stopped by user. Closing websocket...'); + websocket.close(); + return null; + } + if (websocket.readyState == 0 || websocket.readyState == 1 || websocket.readyState == 2) { await delay(50); yield text; @@ -1895,6 +1905,12 @@ app.post('/generate_poe', jsonParser, async (request, response) => { } if (streaming) { + let isStreamingStopped = false; + request.socket.on('close', function() { + isStreamingStopped = true; + client.abortController.abort(); + }); + try { response.writeHead(200, { 'Content-Type': 'text/plain;charset=utf-8', @@ -1904,6 +1920,11 @@ app.post('/generate_poe', jsonParser, async (request, response) => { let reply = ''; for await (const mes of client.send_message(bot, prompt)) { + if (isStreamingStopped) { + console.error('Streaming stopped by user. Closing websocket...'); + break; + } + let newText = mes.text.substring(reply.length); reply = mes.text; response.write(newText); @@ -2135,6 +2156,11 @@ app.post("/generate_openai", jsonParser, function (request, response_generate_op if (!request.body) return response_generate_openai.sendStatus(400); const api_url = new URL(request.body.reverse_proxy || api_openai).toString(); + const controller = new AbortController(); + request.socket.on('close', function() { + controller.abort(); + }); + console.log(request.body); const config = { method: 'post', @@ -2153,7 +2179,8 @@ app.post("/generate_openai", jsonParser, function (request, response_generate_op "frequency_penalty": request.body.frequency_penalty, "stop": request.body.stop, "logit_bias": request.body.logit_bias - } + }, + signal: controller.signal, }; if (request.body.stream)