Add oobabooga streaming

This commit is contained in:
SillyLossy
2023-04-12 19:17:02 +03:00
parent 83982cf1fc
commit d495503ac1
4 changed files with 220 additions and 47 deletions

View File

@ -812,18 +812,7 @@
</a> </a>
</div> </div>
<span> <span>
Make sure you run it: Make sure you run it in notebook mode (not <pre>--cai-chat</pre> or <pre>--chat</pre>)
<ul>
<li>
with
<pre>--no-stream</pre> option
</li>
<li>
in notebook mode (not
<pre>--cai-chat</pre> or
<pre>--chat</pre>)
</li>
</ul>
</span> </span>
<form action="javascript:void(null);" method="post" enctype="multipart/form-data"> <form action="javascript:void(null);" method="post" enctype="multipart/form-data">
<h4>API url</h4> <h4>API url</h4>

View File

@ -11,6 +11,7 @@ import {
import { import {
textgenerationwebui_settings, textgenerationwebui_settings,
loadTextGenSettings, loadTextGenSettings,
generateTextGenWithStreaming,
} from "./scripts/textgen-settings.js"; } from "./scripts/textgen-settings.js";
import { import {
@ -357,7 +358,6 @@ var max_context = 2048;
var is_pygmalion = false; var is_pygmalion = false;
var tokens_already_generated = 0; var tokens_already_generated = 0;
var message_already_generated = ""; var message_already_generated = "";
var if_typing_text = false;
const tokens_cycle_count = 30; const tokens_cycle_count = 30;
var cycle_count_generation = 0; var cycle_count_generation = 0;
@ -501,6 +501,22 @@ async function getStatus() {
is_pygmalion = false; is_pygmalion = false;
} }
// determine if streaming is enabled for ooba
if (main_api == 'textgenerationwebui' && typeof data.gradio_config == 'string') {
try {
let textGenConfig = JSON.parse(data.gradio_config);
let commandLineConfig = textGenConfig.components.filter(x => x.type == "checkboxgroup" && Array.isArray(x.props.choices) && x.props.choices.includes("no_stream"));
if (commandLineConfig.length) {
let selectedOptions = commandLineConfig[0].props.value;
textgenerationwebui_settings.streaming = !selectedOptions.includes('no_stream');
}
}
catch {
textgenerationwebui_settings.streaming = false;
}
}
//console.log(online_status); //console.log(online_status);
resultCheckStatus(); resultCheckStatus();
if (online_status !== "no_connection") { if (online_status !== "no_connection") {
@ -1114,7 +1130,9 @@ function appendToStoryString(value, prefix) {
} }
function isStreamingEnabled() { function isStreamingEnabled() {
return (main_api == 'openai' && oai_settings.stream_openai) || (main_api == 'poe' && poe_settings.streaming); return (main_api == 'openai' && oai_settings.stream_openai)
|| (main_api == 'poe' && poe_settings.streaming)
|| (main_api == 'textgenerationwebui' && textgenerationwebui_settings.streaming);
} }
class StreamingProcessor { class StreamingProcessor {
@ -1180,7 +1198,9 @@ class StreamingProcessor {
} }
async generate() { async generate() {
this.messageId = this.onStartStreaming('...'); if (this.messageId == -1) {
this.messageId = this.onStartStreaming('...');
}
for await (const text of this.generator()) { for await (const text of this.generator()) {
if (this.isStopped) { if (this.isStopped) {
@ -1202,6 +1222,7 @@ class StreamingProcessor {
this.isFinished = true; this.isFinished = true;
this.onFinishStreaming(this.messageId, this.result); this.onFinishStreaming(this.messageId, this.result);
return this.result;
} }
} }
@ -1226,7 +1247,6 @@ async function Generate(type, automatic_trigger, force_name2) {
if (isStreamingEnabled()) { if (isStreamingEnabled()) {
streamingProcessor = new StreamingProcessor(type, force_name2); streamingProcessor = new StreamingProcessor(type, force_name2);
hideSwipeButtons();
} }
else { else {
streamingProcessor = false; streamingProcessor = false;
@ -1766,6 +1786,8 @@ async function Generate(type, automatic_trigger, force_name2) {
'seed': textgenerationwebui_settings.seed, 'seed': textgenerationwebui_settings.seed,
'add_bos_token': textgenerationwebui_settings.add_bos_token, 'add_bos_token': textgenerationwebui_settings.add_bos_token,
'custom_stopping_strings': getStoppingStrings().concat(textgenerationwebui_settings.custom_stopping_strings), 'custom_stopping_strings': getStoppingStrings().concat(textgenerationwebui_settings.custom_stopping_strings),
'truncation_length': max_context,
'ban_eos_token': textgenerationwebui_settings.ban_eos_token,
} }
]; ];
generate_data = { "data": [JSON.stringify(data)] }; generate_data = { "data": [JSON.stringify(data)] };
@ -1827,6 +1849,9 @@ async function Generate(type, automatic_trigger, force_name2) {
generatePoe(finalPromt).then(onSuccess).catch(onError); generatePoe(finalPromt).then(onSuccess).catch(onError);
} }
} }
else if (main_api == 'textgenerationwebui' && textgenerationwebui_settings.streaming) {
streamingProcessor.generator = await generateTextGenWithStreaming(generate_data, finalPromt);
}
else { else {
jQuery.ajax({ jQuery.ajax({
type: 'POST', // type: 'POST', //
@ -1844,7 +1869,20 @@ async function Generate(type, automatic_trigger, force_name2) {
} }
if (isStreamingEnabled()) { if (isStreamingEnabled()) {
await streamingProcessor.generate(); hideSwipeButtons();
let getMessage = await streamingProcessor.generate();
if (isMultigenEnabled()) {
message_already_generated += getMessage;
promptBias = '';
if (!streamingProcessor.isStopped && shouldContinueMultigen(getMessage)) {
streamingProcessor.isFinished = false;
runGenerate(getMessage);
console.log('returning to make generate again');
return;
}
}
streamingProcessor = null; streamingProcessor = null;
} }
@ -1860,15 +1898,11 @@ async function Generate(type, automatic_trigger, force_name2) {
// to make it continue generating so long as it's under max_amount and hasn't signaled // to make it continue generating so long as it's under max_amount and hasn't signaled
// an end to the character's response via typing "You:" or adding "<endoftext>" // an end to the character's response via typing "You:" or adding "<endoftext>"
if (isMultigenEnabled()) { if (isMultigenEnabled()) {
if_typing_text = false;
message_already_generated += getMessage; message_already_generated += getMessage;
promptBias = ''; promptBias = '';
if (message_already_generated.indexOf('You:') === -1 && //if there is no 'You:' in the response msg if (shouldContinueMultigen(getMessage)) {
message_already_generated.indexOf('<|endoftext|>') === -1 && //if there is no <endoftext> stamp in the response msg
tokens_already_generated < parseInt(amount_gen) && //if the gen'd msg is less than the max response length..
getMessage.length > 0) { //if we actually have gen'd text at all...
runGenerate(getMessage); runGenerate(getMessage);
console.log('returning to make generate again'); //generate again with the 'GetMessage' argument.. console.log('returning to make generate again');
return; return;
} }
@ -1936,6 +1970,13 @@ async function Generate(type, automatic_trigger, force_name2) {
console.log('generate ending'); console.log('generate ending');
} //generate ends } //generate ends
function shouldContinueMultigen(getMessage) {
return message_already_generated.indexOf('You:') === -1 && //if there is no 'You:' in the response msg
message_already_generated.indexOf('<|endoftext|>') === -1 && //if there is no <endoftext> stamp in the response msg
tokens_already_generated < parseInt(amount_gen) && //if the gen'd msg is less than the max response length..
getMessage.length > 0; //if we actually have gen'd text at all...
}
function extractNameFromMessage(getMessage, force_name2) { function extractNameFromMessage(getMessage, force_name2) {
let this_mes_is_name = true; let this_mes_is_name = true;
if (getMessage.indexOf(name2 + ":") === 0) { if (getMessage.indexOf(name2 + ":") === 0) {

View File

@ -1,10 +1,12 @@
import { import {
saveSettingsDebounced, saveSettingsDebounced,
token,
} from "../script.js"; } from "../script.js";
export { export {
textgenerationwebui_settings, textgenerationwebui_settings,
loadTextGenSettings, loadTextGenSettings,
generateTextGenWithStreaming,
} }
let textgenerationwebui_settings = { let textgenerationwebui_settings = {
@ -25,6 +27,9 @@ let textgenerationwebui_settings = {
preset: 'Default', preset: 'Default',
add_bos_token: true, add_bos_token: true,
custom_stopping_strings: [], custom_stopping_strings: [],
truncation_length: 2048,
ban_eos_token: false,
streaming: false,
}; };
let textgenerationwebui_presets = []; let textgenerationwebui_presets = [];
@ -136,3 +141,33 @@ function setSettingByName(i, value, trigger) {
$(`#${i}_textgenerationwebui`).trigger('input'); $(`#${i}_textgenerationwebui`).trigger('input');
} }
} }
async function generateTextGenWithStreaming(generate_data, finalPromt) {
const response = await fetch('/generate_textgenerationwebui', {
headers: {
'X-CSRF-Token': token,
'Content-Type': 'application/json',
'X-Response-Streaming': true,
},
body: JSON.stringify(generate_data),
method: 'POST',
});
return async function* streamData() {
const decoder = new TextDecoder();
const reader = response.body.getReader();
let getMessage = '';
while (true) {
const { done, value } = await reader.read();
let response = decoder.decode(value);
getMessage += response;
if (done) {
return;
}
yield getMessage;
}
}
}

154
server.js
View File

@ -39,6 +39,7 @@ const listen = config.listen;
const axios = require('axios'); const axios = require('axios');
const tiktoken = require('@dqbd/tiktoken'); const tiktoken = require('@dqbd/tiktoken');
const WebSocket = require('ws');
var Client = require('node-rest-client').Client; var Client = require('node-rest-client').Client;
var client = new Client(); var client = new Client();
@ -308,34 +309,141 @@ app.post("/generate", jsonParser, async function (request, response_generate = r
} }
}); });
function randomHash() {
const letters = 'abcdefghijklmnopqrstuvwxyz0123456789';
let result = '';
for (let i = 0; i < 9; i++) {
result += letters.charAt(Math.floor(Math.random() * letters.length));
}
return result;
}
function textGenProcessStartedHandler(websocket, content, session, prompt, SEND_PROMPT_GRADIO_FN) {
switch (content.msg) {
case "send_hash":
const send_hash = JSON.stringify({ "session_hash": session, "fn_index": SEND_PROMPT_GRADIO_FN });
websocket.send(send_hash);
break;
case "estimation":
break;
case "send_data":
const send_data = JSON.stringify({ "session_hash": session, "fn_index": SEND_PROMPT_GRADIO_FN, "data": prompt.data });
console.log(send_data);
websocket.send(send_data);
break;
case "process_starts":
break;
case "process_generating":
return content.output.data[0];
case "process_completed":
return null;
}
return '';
}
//************** Text generation web UI //************** Text generation web UI
app.post("/generate_textgenerationwebui", jsonParser, function (request, response_generate = response) { app.post("/generate_textgenerationwebui", jsonParser, async function (request, response_generate = response) {
if (!request.body) return response_generate.sendStatus(400); if (!request.body) return response_generate.sendStatus(400);
console.log(request.body); console.log(request.body);
var args = {
data: request.body, if (!!request.header('X-Response-Streaming')) {
headers: { "Content-Type": "application/json" } const SEND_PARAMS_GRADIO_FN = 29;
};
client.post(api_server + "/run/textgen", args, function (data, response) { response_generate.writeHead(200, {
console.log("####", data); 'Transfer-Encoding': 'chunked',
if (response.statusCode == 200) { 'Cache-Control': 'no-transform',
console.log(data); });
response_generate.send(data);
async function* readWebsocket() {
const session = randomHash();
const url = new URL(api_server);
const websocket = new WebSocket(`ws://${url.host}/queue/join`, { perMessageDeflate: false });
let text = '';
websocket.on('open', async function() {
console.log('websocket open');
});
websocket.on('error', (err) => {
console.error(err);
websocket.close();
});
websocket.on('close', (code, buffer) => {
const reason = new TextDecoder().decode(buffer)
console.log(reason);
});
websocket.on('message', async (message) => {
const content = json5.parse(message);
console.log(content);
text = textGenProcessStartedHandler(websocket, content, session, request.body, SEND_PARAMS_GRADIO_FN);
});
while (true) {
if (websocket.readyState == 0 || websocket.readyState == 1 || websocket.readyState == 2) {
await delay(50);
yield text;
if (!text && typeof text !== 'string') {
websocket.close();
}
}
else {
break;
}
}
} }
if (response.statusCode == 422) {
console.log('Validation error'); let result = json5.parse(request.body.data)[0];
try {
for await (const text of readWebsocket()) {
if (text == null) {
break;
}
let newText = text.substring(result.length);
if (!newText) {
continue;
}
result = text;
response_generate.write(newText);
}
}
finally {
response_generate.end();
}
}
else {
var args = {
data: request.body,
headers: { "Content-Type": "application/json" }
};
client.post(api_server + "/run/textgen", args, function (data, response) {
console.log("####", data);
if (response.statusCode == 200) {
console.log(data);
response_generate.send(data);
}
if (response.statusCode == 422) {
console.log('Validation error');
response_generate.send({ error: true });
}
if (response.statusCode == 501 || response.statusCode == 503 || response.statusCode == 507) {
console.log(data);
response_generate.send({ error: true });
}
}).on('error', function (err) {
console.log(err);
//console.log('something went wrong on the request', err.request.options);
response_generate.send({ error: true }); response_generate.send({ error: true });
} });
if (response.statusCode == 501 || response.statusCode == 503 || response.statusCode == 507) { }
console.log(data);
response_generate.send({ error: true });
}
}).on('error', function (err) {
console.log(err);
//console.log('something went wrong on the request', err.request.options);
response_generate.send({ error: true });
});
}); });
@ -447,7 +555,7 @@ app.post("/getstatus", jsonParser, function (request, response_getstatus = respo
if (!response) if (!response)
throw "no_connection"; throw "no_connection";
let model = json5.parse(response).components.filter((x) => x.props.label == "Model" && x.type == "dropdown")[0].props.value; let model = json5.parse(response).components.filter((x) => x.props.label == "Model" && x.type == "dropdown")[0].props.value;
data = { result: model }; data = { result: model, gradio_config: response };
if (!data) if (!data)
throw "no_connection"; throw "no_connection";
} catch { } catch {