Add oobabooga streaming

This commit is contained in:
SillyLossy
2023-04-12 19:17:02 +03:00
parent 83982cf1fc
commit d495503ac1
4 changed files with 220 additions and 47 deletions

View File

@ -812,18 +812,7 @@
</a>
</div>
<span>
Make sure you run it:
<ul>
<li>
with
<pre>--no-stream</pre> option
</li>
<li>
in notebook mode (not
<pre>--cai-chat</pre> or
<pre>--chat</pre>)
</li>
</ul>
Make sure you run it in notebook mode (not <pre>--cai-chat</pre> or <pre>--chat</pre>)
</span>
<form action="javascript:void(null);" method="post" enctype="multipart/form-data">
<h4>API url</h4>

View File

@ -11,6 +11,7 @@ import {
import {
textgenerationwebui_settings,
loadTextGenSettings,
generateTextGenWithStreaming,
} from "./scripts/textgen-settings.js";
import {
@ -357,7 +358,6 @@ var max_context = 2048;
var is_pygmalion = false;
var tokens_already_generated = 0;
var message_already_generated = "";
var if_typing_text = false;
const tokens_cycle_count = 30;
var cycle_count_generation = 0;
@ -501,6 +501,22 @@ async function getStatus() {
is_pygmalion = false;
}
// determine if streaming is enabled for ooba
if (main_api == 'textgenerationwebui' && typeof data.gradio_config == 'string') {
try {
let textGenConfig = JSON.parse(data.gradio_config);
let commandLineConfig = textGenConfig.components.filter(x => x.type == "checkboxgroup" && Array.isArray(x.props.choices) && x.props.choices.includes("no_stream"));
if (commandLineConfig.length) {
let selectedOptions = commandLineConfig[0].props.value;
textgenerationwebui_settings.streaming = !selectedOptions.includes('no_stream');
}
}
catch {
textgenerationwebui_settings.streaming = false;
}
}
//console.log(online_status);
resultCheckStatus();
if (online_status !== "no_connection") {
@ -1114,7 +1130,9 @@ function appendToStoryString(value, prefix) {
}
function isStreamingEnabled() {
return (main_api == 'openai' && oai_settings.stream_openai) || (main_api == 'poe' && poe_settings.streaming);
return (main_api == 'openai' && oai_settings.stream_openai)
|| (main_api == 'poe' && poe_settings.streaming)
|| (main_api == 'textgenerationwebui' && textgenerationwebui_settings.streaming);
}
class StreamingProcessor {
@ -1180,7 +1198,9 @@ class StreamingProcessor {
}
async generate() {
this.messageId = this.onStartStreaming('...');
if (this.messageId == -1) {
this.messageId = this.onStartStreaming('...');
}
for await (const text of this.generator()) {
if (this.isStopped) {
@ -1202,6 +1222,7 @@ class StreamingProcessor {
this.isFinished = true;
this.onFinishStreaming(this.messageId, this.result);
return this.result;
}
}
@ -1226,7 +1247,6 @@ async function Generate(type, automatic_trigger, force_name2) {
if (isStreamingEnabled()) {
streamingProcessor = new StreamingProcessor(type, force_name2);
hideSwipeButtons();
}
else {
streamingProcessor = false;
@ -1766,6 +1786,8 @@ async function Generate(type, automatic_trigger, force_name2) {
'seed': textgenerationwebui_settings.seed,
'add_bos_token': textgenerationwebui_settings.add_bos_token,
'custom_stopping_strings': getStoppingStrings().concat(textgenerationwebui_settings.custom_stopping_strings),
'truncation_length': max_context,
'ban_eos_token': textgenerationwebui_settings.ban_eos_token,
}
];
generate_data = { "data": [JSON.stringify(data)] };
@ -1827,6 +1849,9 @@ async function Generate(type, automatic_trigger, force_name2) {
generatePoe(finalPromt).then(onSuccess).catch(onError);
}
}
else if (main_api == 'textgenerationwebui' && textgenerationwebui_settings.streaming) {
streamingProcessor.generator = await generateTextGenWithStreaming(generate_data, finalPromt);
}
else {
jQuery.ajax({
type: 'POST', //
@ -1844,7 +1869,20 @@ async function Generate(type, automatic_trigger, force_name2) {
}
if (isStreamingEnabled()) {
await streamingProcessor.generate();
hideSwipeButtons();
let getMessage = await streamingProcessor.generate();
if (isMultigenEnabled()) {
message_already_generated += getMessage;
promptBias = '';
if (!streamingProcessor.isStopped && shouldContinueMultigen(getMessage)) {
streamingProcessor.isFinished = false;
runGenerate(getMessage);
console.log('returning to make generate again');
return;
}
}
streamingProcessor = null;
}
@ -1860,15 +1898,11 @@ async function Generate(type, automatic_trigger, force_name2) {
// to make it continue generating so long as it's under max_amount and hasn't signaled
// an end to the character's response via typing "You:" or adding "<endoftext>"
if (isMultigenEnabled()) {
if_typing_text = false;
message_already_generated += getMessage;
promptBias = '';
if (message_already_generated.indexOf('You:') === -1 && //if there is no 'You:' in the response msg
message_already_generated.indexOf('<|endoftext|>') === -1 && //if there is no <endoftext> stamp in the response msg
tokens_already_generated < parseInt(amount_gen) && //if the gen'd msg is less than the max response length..
getMessage.length > 0) { //if we actually have gen'd text at all...
if (shouldContinueMultigen(getMessage)) {
runGenerate(getMessage);
console.log('returning to make generate again'); //generate again with the 'GetMessage' argument..
console.log('returning to make generate again');
return;
}
@ -1936,6 +1970,13 @@ async function Generate(type, automatic_trigger, force_name2) {
console.log('generate ending');
} //generate ends
// Decides whether multigen should trigger another generation pass.
// Continues only while the accumulated response contains no chat-turn
// marker or end-of-text stamp, the token budget is not exhausted, and
// the last pass actually produced text.
function shouldContinueMultigen(getMessage) {
    return !message_already_generated.includes('You:') && // stop once the model starts speaking for the user
        !message_already_generated.includes('<|endoftext|>') && // stop on an explicit end-of-text stamp
        tokens_already_generated < parseInt(amount_gen, 10) && // stay under the configured response length
        getMessage.length > 0; // and only if the last pass generated anything at all
}
function extractNameFromMessage(getMessage, force_name2) {
let this_mes_is_name = true;
if (getMessage.indexOf(name2 + ":") === 0) {

View File

@ -1,10 +1,12 @@
import {
saveSettingsDebounced,
token,
} from "../script.js";
export {
textgenerationwebui_settings,
loadTextGenSettings,
generateTextGenWithStreaming,
}
let textgenerationwebui_settings = {
@ -23,8 +25,11 @@ let textgenerationwebui_settings = {
early_stopping: false,
seed: -1,
preset: 'Default',
add_bos_token: true,
add_bos_token: true,
custom_stopping_strings: [],
truncation_length: 2048,
ban_eos_token: false,
streaming: false,
};
let textgenerationwebui_presets = [];
@ -136,3 +141,33 @@ function setSettingByName(i, value, trigger) {
$(`#${i}_textgenerationwebui`).trigger('input');
}
}
// Starts a streaming generation request against the server proxy and
// returns an async generator that yields the accumulated response text
// after each received chunk.
// NOTE(review): finalPromt is accepted for signature parity with the
// other generators but is not read here — the prompt appears to be
// embedded in generate_data by the caller; confirm before removing.
async function generateTextGenWithStreaming(generate_data, finalPromt) {
    const response = await fetch('/generate_textgenerationwebui', {
        headers: {
            'X-CSRF-Token': token,
            'Content-Type': 'application/json',
            // Tells the server to proxy the Gradio websocket as a chunked stream.
            'X-Response-Streaming': true,
        },
        body: JSON.stringify(generate_data),
        method: 'POST',
    });
    return async function* streamData() {
        const decoder = new TextDecoder();
        const reader = response.body.getReader();
        let getMessage = '';
        while (true) {
            const { done, value } = await reader.read();
            // { stream: true } keeps partial multi-byte UTF-8 sequences
            // buffered across chunk boundaries instead of decoding them
            // into replacement characters.
            getMessage += decoder.decode(value, { stream: true });
            if (done) {
                // Flush any bytes still buffered inside the decoder.
                getMessage += decoder.decode();
                return;
            }
            yield getMessage;
        }
    }
}

154
server.js
View File

@ -39,6 +39,7 @@ const listen = config.listen;
const axios = require('axios');
const tiktoken = require('@dqbd/tiktoken');
const WebSocket = require('ws');
var Client = require('node-rest-client').Client;
var client = new Client();
@ -308,34 +309,141 @@ app.post("/generate", jsonParser, async function (request, response_generate = r
}
});
// Builds a 9-character lowercase alphanumeric session hash for the
// Gradio websocket queue protocol (not cryptographically secure).
function randomHash() {
    const alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789';
    return Array.from(
        { length: 9 },
        () => alphabet[Math.floor(Math.random() * alphabet.length)],
    ).join('');
}
// Handles one message of the Gradio queue websocket protocol.
// Returns the partial generation text for "process_generating",
// null when generation has completed, and '' for every other message.
function textGenProcessStartedHandler(websocket, content, session, prompt, SEND_PROMPT_GRADIO_FN) {
    const msg = content.msg;
    if (msg === "send_hash") {
        // Server asked us to identify our queue slot.
        websocket.send(JSON.stringify({ "session_hash": session, "fn_index": SEND_PROMPT_GRADIO_FN }));
        return '';
    }
    if (msg === "send_data") {
        // Server is ready to receive the generation parameters.
        const payload = JSON.stringify({ "session_hash": session, "fn_index": SEND_PROMPT_GRADIO_FN, "data": prompt.data });
        console.log(payload);
        websocket.send(payload);
        return '';
    }
    if (msg === "process_generating") {
        // Incremental output: the first data slot holds the text so far.
        return content.output.data[0];
    }
    if (msg === "process_completed") {
        // null tells the reader loop to stop streaming.
        return null;
    }
    // "estimation", "process_starts" and anything unrecognized: nothing to do.
    return '';
}
//************** Text generation web UI
app.post("/generate_textgenerationwebui", jsonParser, function (request, response_generate = response) {
app.post("/generate_textgenerationwebui", jsonParser, async function (request, response_generate = response) {
if (!request.body) return response_generate.sendStatus(400);
console.log(request.body);
var args = {
data: request.body,
headers: { "Content-Type": "application/json" }
};
client.post(api_server + "/run/textgen", args, function (data, response) {
console.log("####", data);
if (response.statusCode == 200) {
console.log(data);
response_generate.send(data);
if (!!request.header('X-Response-Streaming')) {
const SEND_PARAMS_GRADIO_FN = 29;
response_generate.writeHead(200, {
'Transfer-Encoding': 'chunked',
'Cache-Control': 'no-transform',
});
// Connects to the text-generation-webui Gradio queue websocket and
// polls it, yielding the latest accumulated text snapshot every 50 ms.
// Yields may repeat the same snapshot; a null snapshot means the
// backend reported process_completed (the consumer breaks on null).
async function* readWebsocket() {
const session = randomHash();
const url = new URL(api_server);
// perMessageDeflate off — presumably to avoid compression buffering delays; TODO confirm
const websocket = new WebSocket(`ws://${url.host}/queue/join`, { perMessageDeflate: false });
// Latest text snapshot received from the message handler ('' = nothing yet,
// null = generation finished).
let text = '';
websocket.on('open', async function() {
console.log('websocket open');
});
websocket.on('error', (err) => {
console.error(err);
websocket.close();
});
websocket.on('close', (code, buffer) => {
const reason = new TextDecoder().decode(buffer)
console.log(reason);
});
// Each protocol message updates the snapshot via the shared handler
// (which also performs the send_hash / send_data handshake replies).
websocket.on('message', async (message) => {
const content = json5.parse(message);
console.log(content);
text = textGenProcessStartedHandler(websocket, content, session, request.body, SEND_PARAMS_GRADIO_FN);
});
// Poll while the socket is CONNECTING (0), OPEN (1) or CLOSING (2);
// any other state (CLOSED) ends the stream.
while (true) {
if (websocket.readyState == 0 || websocket.readyState == 1 || websocket.readyState == 2) {
await delay(50);
yield text;
// Condition is true only for null/undefined (a '' snapshot is still a
// string): close the socket once the handler signaled completion.
if (!text && typeof text !== 'string') {
websocket.close();
}
}
else {
break;
}
}
}
if (response.statusCode == 422) {
console.log('Validation error');
let result = json5.parse(request.body.data)[0];
try {
for await (const text of readWebsocket()) {
if (text == null) {
break;
}
let newText = text.substring(result.length);
if (!newText) {
continue;
}
result = text;
response_generate.write(newText);
}
}
finally {
response_generate.end();
}
}
else {
var args = {
data: request.body,
headers: { "Content-Type": "application/json" }
};
client.post(api_server + "/run/textgen", args, function (data, response) {
console.log("####", data);
if (response.statusCode == 200) {
console.log(data);
response_generate.send(data);
}
if (response.statusCode == 422) {
console.log('Validation error');
response_generate.send({ error: true });
}
if (response.statusCode == 501 || response.statusCode == 503 || response.statusCode == 507) {
console.log(data);
response_generate.send({ error: true });
}
}).on('error', function (err) {
console.log(err);
//console.log('something went wrong on the request', err.request.options);
response_generate.send({ error: true });
}
if (response.statusCode == 501 || response.statusCode == 503 || response.statusCode == 507) {
console.log(data);
response_generate.send({ error: true });
}
}).on('error', function (err) {
console.log(err);
//console.log('something went wrong on the request', err.request.options);
response_generate.send({ error: true });
});
});
}
});
@ -447,7 +555,7 @@ app.post("/getstatus", jsonParser, function (request, response_getstatus = respo
if (!response)
throw "no_connection";
let model = json5.parse(response).components.filter((x) => x.props.label == "Model" && x.type == "dropdown")[0].props.value;
data = { result: model };
data = { result: model, gradio_config: response };
if (!data)
throw "no_connection";
} catch {