diff --git a/public/index.html b/public/index.html index 3c04f4e6b..328b98c53 100644 --- a/public/index.html +++ b/public/index.html @@ -812,18 +812,7 @@ - Make sure you run it: - + Make sure you run it in notebook mode (not
--cai-chat
or
--chat
)

API url

diff --git a/public/script.js b/public/script.js index c3f06494f..867320f57 100644 --- a/public/script.js +++ b/public/script.js @@ -11,6 +11,7 @@ import { import { textgenerationwebui_settings, loadTextGenSettings, + generateTextGenWithStreaming, } from "./scripts/textgen-settings.js"; import { @@ -357,7 +358,6 @@ var max_context = 2048; var is_pygmalion = false; var tokens_already_generated = 0; var message_already_generated = ""; -var if_typing_text = false; const tokens_cycle_count = 30; var cycle_count_generation = 0; @@ -501,6 +501,22 @@ async function getStatus() { is_pygmalion = false; } + // determine if streaming is enabled for ooba + if (main_api == 'textgenerationwebui' && typeof data.gradio_config == 'string') { + try { + let textGenConfig = JSON.parse(data.gradio_config); + let commandLineConfig = textGenConfig.components.filter(x => x.type == "checkboxgroup" && Array.isArray(x.props.choices) && x.props.choices.includes("no_stream")); + + if (commandLineConfig.length) { + let selectedOptions = commandLineConfig[0].props.value; + textgenerationwebui_settings.streaming = !selectedOptions.includes('no_stream'); + } + } + catch { + textgenerationwebui_settings.streaming = false; + } + } + //console.log(online_status); resultCheckStatus(); if (online_status !== "no_connection") { @@ -1114,7 +1130,9 @@ function appendToStoryString(value, prefix) { } function isStreamingEnabled() { - return (main_api == 'openai' && oai_settings.stream_openai) || (main_api == 'poe' && poe_settings.streaming); + return (main_api == 'openai' && oai_settings.stream_openai) + || (main_api == 'poe' && poe_settings.streaming) + || (main_api == 'textgenerationwebui' && textgenerationwebui_settings.streaming); } class StreamingProcessor { @@ -1180,7 +1198,9 @@ class StreamingProcessor { } async generate() { - this.messageId = this.onStartStreaming('...'); + if (this.messageId == -1) { + this.messageId = this.onStartStreaming('...'); + } for await (const text of this.generator()) { if (this.isStopped) { @@ -1202,6 +1222,7 @@ class StreamingProcessor { this.isFinished = true; this.onFinishStreaming(this.messageId, this.result); + return this.result; } } @@ -1226,7 +1247,6 @@ async function Generate(type, automatic_trigger, force_name2) { if (isStreamingEnabled()) { streamingProcessor = new StreamingProcessor(type, force_name2); - hideSwipeButtons(); } else { streamingProcessor = false; @@ -1766,6 +1786,8 @@ async function Generate(type, automatic_trigger, force_name2) { 'seed': textgenerationwebui_settings.seed, 'add_bos_token': textgenerationwebui_settings.add_bos_token, 'custom_stopping_strings': getStoppingStrings().concat(textgenerationwebui_settings.custom_stopping_strings), + 'truncation_length': max_context, + 'ban_eos_token': textgenerationwebui_settings.ban_eos_token, } ]; generate_data = { "data": [JSON.stringify(data)] }; @@ -1827,6 +1849,9 @@ async function Generate(type, automatic_trigger, force_name2) { generatePoe(finalPromt).then(onSuccess).catch(onError); } } + else if (main_api == 'textgenerationwebui' && textgenerationwebui_settings.streaming) { + streamingProcessor.generator = await generateTextGenWithStreaming(generate_data, finalPromt); + } else { jQuery.ajax({ type: 'POST', // @@ -1844,7 +1869,20 @@ async function Generate(type, automatic_trigger, force_name2) { } if (isStreamingEnabled()) { - await streamingProcessor.generate(); + hideSwipeButtons(); + let getMessage = await streamingProcessor.generate(); + + if (isMultigenEnabled()) { + message_already_generated += getMessage; + promptBias = ''; + if (!streamingProcessor.isStopped && shouldContinueMultigen(getMessage)) { + streamingProcessor.isFinished = false; + runGenerate(getMessage); + console.log('returning to make generate again'); + return; + } + } + streamingProcessor = null; } @@ -1860,15 +1898,11 @@ async function Generate(type, automatic_trigger, force_name2) { // to make it continue generating so long as it's under max_amount and hasn't signaled // an end to the character's response via typing "You:" or adding "" if (isMultigenEnabled()) { - if_typing_text = false; message_already_generated += getMessage; promptBias = ''; - if (message_already_generated.indexOf('You:') === -1 && //if there is no 'You:' in the response msg - message_already_generated.indexOf('<|endoftext|>') === -1 && //if there is no stamp in the response msg - tokens_already_generated < parseInt(amount_gen) && //if the gen'd msg is less than the max response length.. - getMessage.length > 0) { //if we actually have gen'd text at all... + if (shouldContinueMultigen(getMessage)) { runGenerate(getMessage); - console.log('returning to make generate again'); //generate again with the 'GetMessage' argument.. + console.log('returning to make generate again'); return; } @@ -1936,6 +1970,13 @@ async function Generate(type, automatic_trigger, force_name2) { console.log('generate ending'); } //generate ends +function shouldContinueMultigen(getMessage) { + return message_already_generated.indexOf('You:') === -1 && //if there is no 'You:' in the response msg + message_already_generated.indexOf('<|endoftext|>') === -1 && //if there is no stamp in the response msg + tokens_already_generated < parseInt(amount_gen) && //if the gen'd msg is less than the max response length.. + getMessage.length > 0; //if we actually have gen'd text at all... +} + function extractNameFromMessage(getMessage, force_name2) { let this_mes_is_name = true; if (getMessage.indexOf(name2 + ":") === 0) { diff --git a/public/scripts/textgen-settings.js b/public/scripts/textgen-settings.js index 26cca07d1..37be18df8 100644 --- a/public/scripts/textgen-settings.js +++ b/public/scripts/textgen-settings.js @@ -1,10 +1,12 @@ import { saveSettingsDebounced, + token, } from "../script.js"; export { textgenerationwebui_settings, loadTextGenSettings, + generateTextGenWithStreaming, } let textgenerationwebui_settings = { @@ -23,8 +25,11 @@ let textgenerationwebui_settings = { early_stopping: false, seed: -1, preset: 'Default', - add_bos_token: true, + add_bos_token: true, custom_stopping_strings: [], + truncation_length: 2048, + ban_eos_token: false, + streaming: false, }; let textgenerationwebui_presets = []; @@ -136,3 +141,33 @@ function setSettingByName(i, value, trigger) { $(`#${i}_textgenerationwebui`).trigger('input'); } } + +async function generateTextGenWithStreaming(generate_data, finalPromt) { + const response = await fetch('/generate_textgenerationwebui', { + headers: { + 'X-CSRF-Token': token, + 'Content-Type': 'application/json', + 'X-Response-Streaming': true, + }, + body: JSON.stringify(generate_data), + method: 'POST', + }); + + return async function* streamData() { + const decoder = new TextDecoder(); + const reader = response.body.getReader(); + let getMessage = ''; + while (true) { + const { done, value } = await reader.read(); + let response = decoder.decode(value); + + getMessage += response; + + if (done) { + return; + } + + yield getMessage; + } + } +} \ No newline at end of file diff --git a/server.js b/server.js index 8caba2a14..419fe34cd 100644 --- a/server.js +++ b/server.js @@ -39,6 +39,7 @@ const listen = config.listen; const axios = require('axios'); const tiktoken = require('@dqbd/tiktoken'); +const WebSocket = require('ws'); var Client = require('node-rest-client').Client; var client = new Client(); @@ -308,34 +309,141 @@ app.post("/generate", jsonParser, async function (request, response_generate = r } }); +function randomHash() { + const letters = 'abcdefghijklmnopqrstuvwxyz0123456789'; + let result = ''; + for (let i = 0; i < 9; i++) { + result += letters.charAt(Math.floor(Math.random() * letters.length)); + } + return result; +} + +function textGenProcessStartedHandler(websocket, content, session, prompt, SEND_PROMPT_GRADIO_FN) { + switch (content.msg) { + case "send_hash": + const send_hash = JSON.stringify({ "session_hash": session, "fn_index": SEND_PROMPT_GRADIO_FN }); + websocket.send(send_hash); + break; + case "estimation": + break; + case "send_data": + const send_data = JSON.stringify({ "session_hash": session, "fn_index": SEND_PROMPT_GRADIO_FN, "data": prompt.data }); + console.log(send_data); + websocket.send(send_data); + break; + case "process_starts": + break; + case "process_generating": + return content.output.data[0]; + case "process_completed": + return null; + } + + return ''; +} + //************** Text generation web UI -app.post("/generate_textgenerationwebui", jsonParser, function (request, response_generate = response) { +app.post("/generate_textgenerationwebui", jsonParser, async function (request, response_generate = response) { if (!request.body) return response_generate.sendStatus(400); console.log(request.body); - var args = { - data: request.body, - headers: { "Content-Type": "application/json" } - }; - client.post(api_server + "/run/textgen", args, function (data, response) { - console.log("####", data); - if (response.statusCode == 200) { - console.log(data); - response_generate.send(data); + + if (!!request.header('X-Response-Streaming')) { + const SEND_PARAMS_GRADIO_FN = 29; + + response_generate.writeHead(200, { + 'Transfer-Encoding': 'chunked', + 'Cache-Control': 'no-transform', + }); + + async function* readWebsocket() { + const session = randomHash(); + const url = new URL(api_server); + const websocket = new WebSocket(`ws://${url.host}/queue/join`, { perMessageDeflate: false }); + let text = ''; + + websocket.on('open', async function() { + console.log('websocket open'); + }); + + websocket.on('error', (err) => { + console.error(err); + websocket.close(); + }); + + websocket.on('close', (code, buffer) => { + const reason = new TextDecoder().decode(buffer) + console.log(reason); + }); + + websocket.on('message', async (message) => { + const content = json5.parse(message); + console.log(content); + text = textGenProcessStartedHandler(websocket, content, session, request.body, SEND_PARAMS_GRADIO_FN); + }); + + while (true) { + if (websocket.readyState == 0 || websocket.readyState == 1 || websocket.readyState == 2) { + await delay(50); + yield text; + + if (!text && typeof text !== 'string') { + websocket.close(); + } + } + else { + break; + } + } } - if (response.statusCode == 422) { - console.log('Validation error'); + + let result = json5.parse(request.body.data)[0]; + + try { + for await (const text of readWebsocket()) { + if (text == null) { + break; + } + + let newText = text.substring(result.length); + + if (!newText) { + continue; + } + + result = text; + response_generate.write(newText); + } + } + finally { + response_generate.end(); + } + } + else { + var args = { + data: request.body, + headers: { "Content-Type": "application/json" } + }; + client.post(api_server + "/run/textgen", args, function (data, response) { + console.log("####", data); + if (response.statusCode == 200) { + console.log(data); + response_generate.send(data); + } + if (response.statusCode == 422) { + console.log('Validation error'); + response_generate.send({ error: true }); + } + if (response.statusCode == 501 || response.statusCode == 503 || response.statusCode == 507) { + console.log(data); + response_generate.send({ error: true }); + } + }).on('error', function (err) { + console.log(err); + //console.log('something went wrong on the request', err.request.options); response_generate.send({ error: true }); - } - if (response.statusCode == 501 || response.statusCode == 503 || response.statusCode == 507) { - console.log(data); - response_generate.send({ error: true }); - } - }).on('error', function (err) { - console.log(err); - //console.log('something went wrong on the request', err.request.options); - response_generate.send({ error: true }); - }); + }); + } }); @@ -447,7 +555,7 @@ app.post("/getstatus", jsonParser, function (request, response_getstatus = respo if (!response) throw "no_connection"; let model = json5.parse(response).components.filter((x) => x.props.label == "Model" && x.type == "dropdown")[0].props.value; - data = { result: model }; + data = { result: model, gradio_config: response }; if (!data) throw "no_connection"; } catch {