mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-06-05 21:59:27 +02:00
Add oobabooga streaming
This commit is contained in:
@ -812,18 +812,7 @@
|
||||
</a>
|
||||
</div>
|
||||
<span>
|
||||
Make sure you run it:
|
||||
<ul>
|
||||
<li>
|
||||
with
|
||||
<pre>--no-stream</pre> option
|
||||
</li>
|
||||
<li>
|
||||
in notebook mode (not
|
||||
<pre>--cai-chat</pre> or
|
||||
<pre>--chat</pre>)
|
||||
</li>
|
||||
</ul>
|
||||
Make sure you run it in notebook mode (not <pre>--cai-chat</pre> or <pre>--chat</pre>)
|
||||
</span>
|
||||
<form action="javascript:void(null);" method="post" enctype="multipart/form-data">
|
||||
<h4>API url</h4>
|
||||
|
@ -11,6 +11,7 @@ import {
|
||||
import {
|
||||
textgenerationwebui_settings,
|
||||
loadTextGenSettings,
|
||||
generateTextGenWithStreaming,
|
||||
} from "./scripts/textgen-settings.js";
|
||||
|
||||
import {
|
||||
@ -357,7 +358,6 @@ var max_context = 2048;
|
||||
var is_pygmalion = false;
|
||||
var tokens_already_generated = 0;
|
||||
var message_already_generated = "";
|
||||
var if_typing_text = false;
|
||||
const tokens_cycle_count = 30;
|
||||
var cycle_count_generation = 0;
|
||||
|
||||
@ -501,6 +501,22 @@ async function getStatus() {
|
||||
is_pygmalion = false;
|
||||
}
|
||||
|
||||
// determine if streaming is enabled for ooba
|
||||
if (main_api == 'textgenerationwebui' && typeof data.gradio_config == 'string') {
|
||||
try {
|
||||
let textGenConfig = JSON.parse(data.gradio_config);
|
||||
let commandLineConfig = textGenConfig.components.filter(x => x.type == "checkboxgroup" && Array.isArray(x.props.choices) && x.props.choices.includes("no_stream"));
|
||||
|
||||
if (commandLineConfig.length) {
|
||||
let selectedOptions = commandLineConfig[0].props.value;
|
||||
textgenerationwebui_settings.streaming = !selectedOptions.includes('no_stream');
|
||||
}
|
||||
}
|
||||
catch {
|
||||
textgenerationwebui_settings.streaming = false;
|
||||
}
|
||||
}
|
||||
|
||||
//console.log(online_status);
|
||||
resultCheckStatus();
|
||||
if (online_status !== "no_connection") {
|
||||
@ -1114,7 +1130,9 @@ function appendToStoryString(value, prefix) {
|
||||
}
|
||||
|
||||
/**
 * Reports whether the currently selected API is configured to stream its reply.
 *
 * Reads the module-level `main_api` selector together with the per-API settings
 * objects (`oai_settings`, `poe_settings`, `textgenerationwebui_settings`).
 *
 * @returns {*} A truthy value when streaming is enabled for the active API,
 *     a falsy value otherwise.
 */
function isStreamingEnabled() {
    // A single return covering all three back-ends; an earlier return that only
    // knew about openai/poe would make the textgenerationwebui branch unreachable.
    return (main_api === 'openai' && oai_settings.stream_openai)
        || (main_api === 'poe' && poe_settings.streaming)
        || (main_api === 'textgenerationwebui' && textgenerationwebui_settings.streaming);
}
|
||||
|
||||
class StreamingProcessor {
|
||||
@ -1180,7 +1198,9 @@ class StreamingProcessor {
|
||||
}
|
||||
|
||||
async generate() {
|
||||
if (this.messageId == -1) {
|
||||
this.messageId = this.onStartStreaming('...');
|
||||
}
|
||||
|
||||
for await (const text of this.generator()) {
|
||||
if (this.isStopped) {
|
||||
@ -1202,6 +1222,7 @@ class StreamingProcessor {
|
||||
|
||||
this.isFinished = true;
|
||||
this.onFinishStreaming(this.messageId, this.result);
|
||||
return this.result;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1226,7 +1247,6 @@ async function Generate(type, automatic_trigger, force_name2) {
|
||||
|
||||
if (isStreamingEnabled()) {
|
||||
streamingProcessor = new StreamingProcessor(type, force_name2);
|
||||
hideSwipeButtons();
|
||||
}
|
||||
else {
|
||||
streamingProcessor = false;
|
||||
@ -1766,6 +1786,8 @@ async function Generate(type, automatic_trigger, force_name2) {
|
||||
'seed': textgenerationwebui_settings.seed,
|
||||
'add_bos_token': textgenerationwebui_settings.add_bos_token,
|
||||
'custom_stopping_strings': getStoppingStrings().concat(textgenerationwebui_settings.custom_stopping_strings),
|
||||
'truncation_length': max_context,
|
||||
'ban_eos_token': textgenerationwebui_settings.ban_eos_token,
|
||||
}
|
||||
];
|
||||
generate_data = { "data": [JSON.stringify(data)] };
|
||||
@ -1827,6 +1849,9 @@ async function Generate(type, automatic_trigger, force_name2) {
|
||||
generatePoe(finalPromt).then(onSuccess).catch(onError);
|
||||
}
|
||||
}
|
||||
else if (main_api == 'textgenerationwebui' && textgenerationwebui_settings.streaming) {
|
||||
streamingProcessor.generator = await generateTextGenWithStreaming(generate_data, finalPromt);
|
||||
}
|
||||
else {
|
||||
jQuery.ajax({
|
||||
type: 'POST', //
|
||||
@ -1844,7 +1869,20 @@ async function Generate(type, automatic_trigger, force_name2) {
|
||||
}
|
||||
|
||||
if (isStreamingEnabled()) {
|
||||
await streamingProcessor.generate();
|
||||
hideSwipeButtons();
|
||||
let getMessage = await streamingProcessor.generate();
|
||||
|
||||
if (isMultigenEnabled()) {
|
||||
message_already_generated += getMessage;
|
||||
promptBias = '';
|
||||
if (!streamingProcessor.isStopped && shouldContinueMultigen(getMessage)) {
|
||||
streamingProcessor.isFinished = false;
|
||||
runGenerate(getMessage);
|
||||
console.log('returning to make generate again');
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
streamingProcessor = null;
|
||||
}
|
||||
|
||||
@ -1860,15 +1898,11 @@ async function Generate(type, automatic_trigger, force_name2) {
|
||||
// to make it continue generating so long as it's under max_amount and hasn't signaled
|
||||
// an end to the character's response via typing "You:" or adding "<endoftext>"
|
||||
if (isMultigenEnabled()) {
|
||||
if_typing_text = false;
|
||||
message_already_generated += getMessage;
|
||||
promptBias = '';
|
||||
if (message_already_generated.indexOf('You:') === -1 && //if there is no 'You:' in the response msg
|
||||
message_already_generated.indexOf('<|endoftext|>') === -1 && //if there is no <endoftext> stamp in the response msg
|
||||
tokens_already_generated < parseInt(amount_gen) && //if the gen'd msg is less than the max response length..
|
||||
getMessage.length > 0) { //if we actually have gen'd text at all...
|
||||
if (shouldContinueMultigen(getMessage)) {
|
||||
runGenerate(getMessage);
|
||||
console.log('returning to make generate again'); //generate again with the 'GetMessage' argument..
|
||||
console.log('returning to make generate again');
|
||||
return;
|
||||
}
|
||||
|
||||
@ -1936,6 +1970,13 @@ async function Generate(type, automatic_trigger, force_name2) {
|
||||
console.log('generate ending');
|
||||
} //generate ends
|
||||
|
||||
/**
 * Decides whether multigen should request another chunk of text.
 *
 * Uses the module-level `message_already_generated`, `tokens_already_generated`
 * and `amount_gen` values in addition to the latest chunk.
 *
 * @param {string} getMessage - Text produced by the most recent generation step.
 * @returns {boolean} True only when the accumulated message contains neither
 *     'You:' nor '<|endoftext|>', fewer tokens than `amount_gen` have been
 *     generated, and the latest chunk is non-empty.
 */
function shouldContinueMultigen(getMessage) {
    return !message_already_generated.includes('You:') // no 'You:' in the response so far
        && !message_already_generated.includes('<|endoftext|>') // no end-of-text stamp in the response so far
        // Explicit radix so the limit is always read as a decimal number.
        && tokens_already_generated < Number.parseInt(amount_gen, 10)
        // A missing chunk counts as "nothing generated" instead of throwing on `.length`.
        && getMessage != null
        && getMessage.length > 0;
}
|
||||
|
||||
function extractNameFromMessage(getMessage, force_name2) {
|
||||
let this_mes_is_name = true;
|
||||
if (getMessage.indexOf(name2 + ":") === 0) {
|
||||
|
@ -1,10 +1,12 @@
|
||||
import {
|
||||
saveSettingsDebounced,
|
||||
token,
|
||||
} from "../script.js";
|
||||
|
||||
export {
|
||||
textgenerationwebui_settings,
|
||||
loadTextGenSettings,
|
||||
generateTextGenWithStreaming,
|
||||
}
|
||||
|
||||
let textgenerationwebui_settings = {
|
||||
@ -25,6 +27,9 @@ let textgenerationwebui_settings = {
|
||||
preset: 'Default',
|
||||
add_bos_token: true,
|
||||
custom_stopping_strings: [],
|
||||
truncation_length: 2048,
|
||||
ban_eos_token: false,
|
||||
streaming: false,
|
||||
};
|
||||
|
||||
let textgenerationwebui_presets = [];
|
||||
@ -136,3 +141,33 @@ function setSettingByName(i, value, trigger) {
|
||||
$(`#${i}_textgenerationwebui`).trigger('input');
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Starts a streaming generation request against the local
 * `/generate_textgenerationwebui` endpoint.
 *
 * The request is sent with the `X-Response-Streaming` header set and the
 * JSON-serialized `generate_data` as its body.
 *
 * @param {object} generate_data - Payload that is JSON-serialized into the request body.
 * @param {string} finalPromt - Not used by this function; kept so existing callers keep working.
 * @returns {Promise<function(): AsyncGenerator<string>>} A generator factory. Each
 *     value it yields is the full text received so far (accumulated, not a delta).
 */
async function generateTextGenWithStreaming(generate_data, finalPromt) {
    const response = await fetch('/generate_textgenerationwebui', {
        headers: {
            'X-CSRF-Token': token,
            'Content-Type': 'application/json',
            'X-Response-Streaming': true,
        },
        body: JSON.stringify(generate_data),
        method: 'POST',
    });

    return async function* streamData() {
        const decoder = new TextDecoder();
        const reader = response.body.getReader();
        let getMessage = '';

        while (true) {
            const { done, value } = await reader.read();

            if (done) {
                return;
            }

            // `stream: true` keeps an incomplete multi-byte UTF-8 sequence buffered
            // until the next chunk arrives, instead of emitting U+FFFD for each half.
            getMessage += decoder.decode(value, { stream: true });

            yield getMessage;
        }
    };
}
|
112
server.js
112
server.js
@ -39,6 +39,7 @@ const listen = config.listen;
|
||||
|
||||
const axios = require('axios');
|
||||
const tiktoken = require('@dqbd/tiktoken');
|
||||
const WebSocket = require('ws');
|
||||
|
||||
var Client = require('node-rest-client').Client;
|
||||
var client = new Client();
|
||||
@ -308,11 +309,117 @@ app.post("/generate", jsonParser, async function (request, response_generate = r
|
||||
}
|
||||
});
|
||||
|
||||
/**
 * Builds a random identifier made of lowercase ASCII letters and digits.
 *
 * Uses `Math.random()`, so the result is not suitable for anything
 * security-sensitive.
 *
 * @param {number} [length=9] - Number of characters to generate.
 * @returns {string} A string of exactly `length` characters drawn from [a-z0-9].
 */
function randomHash(length = 9) {
    const letters = 'abcdefghijklmnopqrstuvwxyz0123456789';
    const pickLetter = () => letters.charAt(Math.floor(Math.random() * letters.length));

    return Array.from({ length }, pickLetter).join('');
}
|
||||
|
||||
/**
 * Handles one parsed message received on the text generation websocket.
 *
 * For "send_hash" and "send_data" it replies on the socket with the session
 * hash and function index (plus `prompt.data` for "send_data"). For the other
 * message types it only computes a return value.
 *
 * @param {{send: function(string): void}} websocket - Socket used to send replies.
 * @param {object} content - Parsed message; `content.msg` selects the action.
 * @param {string} session - Session hash echoed back in replies.
 * @param {object} prompt - Request payload; its `data` property is forwarded on "send_data".
 * @param {number} SEND_PROMPT_GRADIO_FN - Function index echoed back in replies.
 * @returns {*} `content.output.data[0]` for "process_generating" (or '' when the
 *     message carries no output data), `null` for "process_completed", and ''
 *     for every other message.
 */
function textGenProcessStartedHandler(websocket, content, session, prompt, SEND_PROMPT_GRADIO_FN) {
    switch (content.msg) {
        // Braces give each case its own scope for its `const` declaration.
        case "send_hash": {
            const send_hash = JSON.stringify({ "session_hash": session, "fn_index": SEND_PROMPT_GRADIO_FN });
            websocket.send(send_hash);
            break;
        }
        case "estimation":
            break;
        case "send_data": {
            const send_data = JSON.stringify({ "session_hash": session, "fn_index": SEND_PROMPT_GRADIO_FN, "data": prompt.data });
            console.log(send_data);
            websocket.send(send_data);
            break;
        }
        case "process_starts":
            break;
        case "process_generating": {
            // A frame without `output.data` yields '' ("no new text") instead of
            // throwing inside the caller's message handler.
            const data = content.output && content.output.data;
            return Array.isArray(data) ? data[0] : '';
        }
        case "process_completed":
            return null;
    }

    return '';
}
|
||||
|
||||
//************** Text generation web UI
|
||||
app.post("/generate_textgenerationwebui", jsonParser, function (request, response_generate = response) {
|
||||
app.post("/generate_textgenerationwebui", jsonParser, async function (request, response_generate = response) {
|
||||
if (!request.body) return response_generate.sendStatus(400);
|
||||
|
||||
console.log(request.body);
|
||||
|
||||
if (!!request.header('X-Response-Streaming')) {
|
||||
const SEND_PARAMS_GRADIO_FN = 29;
|
||||
|
||||
response_generate.writeHead(200, {
|
||||
'Transfer-Encoding': 'chunked',
|
||||
'Cache-Control': 'no-transform',
|
||||
});
|
||||
|
||||
/**
 * Opens a websocket to the `/queue/join` path on the `api_server` host and
 * yields the latest handler result roughly every 50 ms until the socket has
 * closed.
 *
 * Each incoming message is parsed and passed to `textGenProcessStartedHandler`;
 * whatever that returns becomes the next yielded value. A non-string falsy
 * value (such as `null`) makes this generator close the socket.
 *
 * @yields {*} The most recent value returned by `textGenProcessStartedHandler`.
 */
async function* readWebsocket() {
    const session = randomHash();
    const url = new URL(api_server);
    // Match the socket scheme to the API server's scheme so an https server is
    // reached over wss instead of plain ws.
    const scheme = url.protocol === 'https:' ? 'wss' : 'ws';
    const websocket = new WebSocket(`${scheme}://${url.host}/queue/join`, { perMessageDeflate: false });
    let text = '';

    websocket.on('open', async function() {
        console.log('websocket open');
    });

    websocket.on('error', (err) => {
        console.error(err);
        websocket.close();
    });

    websocket.on('close', (code, buffer) => {
        const reason = new TextDecoder().decode(buffer);
        console.log(reason);
    });

    websocket.on('message', async (message) => {
        const content = json5.parse(message);
        console.log(content);
        text = textGenProcessStartedHandler(websocket, content, session, request.body, SEND_PARAMS_GRADIO_FN);
    });

    try {
        // Keep polling while the socket is connecting, open or closing.
        while (websocket.readyState !== WebSocket.CLOSED) {
            await delay(50);
            yield text;

            if (!text && typeof text !== 'string') {
                websocket.close();
            }
        }
    }
    finally {
        // When the consumer stops iterating early (break/throw), execution never
        // resumes after `yield`, so the socket has to be released here.
        if (websocket.readyState !== WebSocket.CLOSED) {
            websocket.close();
        }
    }
}
|
||||
|
||||
let result = json5.parse(request.body.data)[0];
|
||||
|
||||
try {
|
||||
for await (const text of readWebsocket()) {
|
||||
if (text == null) {
|
||||
break;
|
||||
}
|
||||
|
||||
let newText = text.substring(result.length);
|
||||
|
||||
if (!newText) {
|
||||
continue;
|
||||
}
|
||||
|
||||
result = text;
|
||||
response_generate.write(newText);
|
||||
}
|
||||
}
|
||||
finally {
|
||||
response_generate.end();
|
||||
}
|
||||
}
|
||||
else {
|
||||
var args = {
|
||||
data: request.body,
|
||||
headers: { "Content-Type": "application/json" }
|
||||
@ -336,6 +443,7 @@ app.post("/generate_textgenerationwebui", jsonParser, function (request, respons
|
||||
//console.log('something went wrong on the request', err.request.options);
|
||||
response_generate.send({ error: true });
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
@ -447,7 +555,7 @@ app.post("/getstatus", jsonParser, function (request, response_getstatus = respo
|
||||
if (!response)
|
||||
throw "no_connection";
|
||||
let model = json5.parse(response).components.filter((x) => x.props.label == "Model" && x.type == "dropdown")[0].props.value;
|
||||
data = { result: model };
|
||||
data = { result: model, gradio_config: response };
|
||||
if (!data)
|
||||
throw "no_connection";
|
||||
} catch {
|
||||
|
Reference in New Issue
Block a user