Foundation for token streaming (non-functional at the moment)

This commit is contained in:
SillyLossy
2023-04-11 22:49:05 +03:00
parent 8c8c2c40c4
commit dee813dfa7
2 changed files with 188 additions and 94 deletions

View File

@ -190,6 +190,7 @@ let exportPopper = Popper.createPopper(document.getElementById('export_button'),
});
let dialogueResolve = null;
let chat_metadata = {};
let streamingProcessor = null;
const durationSaveEdit = 200;
const saveSettingsDebounced = debounce(() => saveSettings(), durationSaveEdit);
@ -1105,6 +1106,86 @@ function appendToStoryString(value, prefix) {
return '';
}
function isStreamingEnabled() {
return (main_api == 'openai' && oai_settings.stream_openai);
}
class StreamingProcessor {
onStartStreaming(text) {
saveReply(type, text);
return (count_view_mes - 1);
}
onProgressStreaming(messageId, text) {
let processedText = cleanUpMessage(text);
({isName, processedText} = extractNameFromMessage(processedText, force_name2));
chat[messageId]['is_name'] = isName;
chat[messageId]['mes'] = processedText;
let formattedText = messageFormating(processedText, chat[messageId].name, chat[messageId].is_system, chat[messageId].force_avatar);
const mesText = $(`#chat .mes[mesid="${messageId}"] .mes_text`);
mesText.empty();
mesText.append(formattedText);
}
onFinishStreaming(messageId, text) {
this.onProgressStreaming(messageId, text);
playMessageSound();
saveChatConditional();
activateSendButtons();
showSwipeButtons();
setGenerationProgress(0);
$('.mes_edit:last').show();
}
onErrorStreaming() {
$("#send_textarea").removeAttr('disabled');
is_send_press = false;
activateSendButtons();
setGenerationProgress(0);
}
onStopStreaming() {
this.onErrorStreaming();
}
nullStreamingGeneration() {
throw new Error('Generation function for streaming is not hooked up');
}
constructor() {
this.result = "";
this.messageId = -1;
this.isStopped = false;
this.isFinished = false;
this.generator = this.nullStreamingGeneration;
}
async generate() {
this.messageId = this.onStartStreaming('');
for await (const text of this.generator()) {
if (this.isStopped) {
this.onStopStreaming();
return;
}
try {
this.result = text;
this.onProgressStreaming(this.messageId, text);
}
catch (err) {
console.error(err);
this.onErrorStreaming();
this.isStopped = true;
return;
}
}
this.isFinished = true;
this.onFinishStreaming(this.messageId, this.result);
}
}
async function Generate(type, automatic_trigger, force_name2) {
console.log('Generate entered');
setGenerationProgress(0);
@ -1696,11 +1777,19 @@ async function Generate(type, automatic_trigger, force_name2) {
}
console.log('rungenerate calling API');
streamingProcessor = new StreamingProcessor();
if (main_api == 'openai') {
let prompt = await prepareOpenAIMessages(name2, storyString, worldInfoBefore, worldInfoAfter, extension_prompt, promptBias);
if (isStreamingEnabled()) {
streamingProcessor.generator = () => sendOpenAIRequest(prompt);
await streamingProcessor.generate();
}
else {
sendOpenAIRequest(prompt).then(onSuccess).catch(onError);
}
}
else if (main_api == 'kobold' && horde_settings.use_horde) {
generateHorde(finalPromt, generate_data).then(onSuccess).catch(onError);
}
@ -1729,32 +1818,7 @@ async function Generate(type, automatic_trigger, force_name2) {
is_send_press = false;
if (!data.error) {
//const getData = await response.json();
var getMessage = "";
if (main_api == 'kobold' && !horde_settings.use_horde) {
getMessage = data.results[0].text;
}
else if (main_api == 'kobold' && horde_settings.use_horde) {
getMessage = data;
}
else if (main_api == 'textgenerationwebui') {
getMessage = data.data[0];
if (getMessage == null || data.error) {
activateSendButtons();
callPopup('<h3>Got empty response from Text generation web UI. Try restarting the API with recommended options.</h3>', 'text');
return;
}
getMessage = getMessage.substring(finalPromt.length);
}
else if (main_api == 'novel') {
getMessage = data.output;
}
if (main_api == 'openai' || main_api == 'poe') {
getMessage = data;
}
if (power_user.collapse_newlines) {
getMessage = collapseNewlines(getMessage);
}
let getMessage = extractMessageFromData(data, finalPromt);
//Pygmalion run again
// to make it continue generating so long as it's under max_amount and hasn't signaled
@ -1776,50 +1840,21 @@ async function Generate(type, automatic_trigger, force_name2) {
}
//Formating
getMessage = $.trim(getMessage);
if (is_pygmalion) {
getMessage = getMessage.replace(/<USER>/g, name1);
getMessage = getMessage.replace(/<BOT>/g, name2);
getMessage = getMessage.replace(/You:/g, name1 + ':');
}
if (getMessage.indexOf(name1 + ":") != -1) {
getMessage = getMessage.substr(0, getMessage.indexOf(name1 + ":"));
getMessage = cleanUpMessage(getMessage);
}
if (getMessage.indexOf('<|endoftext|>') != -1) {
getMessage = getMessage.substr(0, getMessage.indexOf('<|endoftext|>'));
}
// clean-up group message from excessive generations
if (selected_group) {
getMessage = cleanGroupMessage(getMessage);
}
let this_mes_is_name = true;
if (getMessage.indexOf(name2 + ":") === 0) {
getMessage = getMessage.replace(name2 + ':', '');
getMessage = getMessage.trimStart();
} else {
this_mes_is_name = false;
}
if (force_name2) this_mes_is_name = true;
let this_mes_is_name;
({ this_mes_is_name, getMessage } = extractNameFromMessage(getMessage, force_name2));
//getMessage = getMessage.replace(/^\s+/g, '');
if (getMessage.length > 0) {
({ type, getMessage } = saveReply(type, getMessage, this_mes_is_name));
activateSendButtons();
playMessageSound();
generate_loop_counter = 0;
} else {
++generate_loop_counter;
if (generate_loop_counter > MAX_GENERATION_LOOPS) {
callPopup(`Could not extract reply in ${MAX_GENERATION_LOOPS} attempts. Try generating again`, 'text');
generate_loop_counter = 0;
$("#send_textarea").removeAttr('disabled');
is_send_press = false;
activateSendButtons();
setGenerationProgress(0);
showSwipeButtons();
$('.mes_edit:last').show();
throw new Error('Generate circuit breaker interruption');
throwCircuitBreakerError();
}
// regenerate with character speech reenforced
@ -1837,7 +1872,6 @@ async function Generate(type, automatic_trigger, force_name2) {
console.log('/savechat called by /Generate');
saveChatConditional();
activateSendButtons();
showSwipeButtons();
setGenerationProgress(0);
@ -1866,6 +1900,89 @@ async function Generate(type, automatic_trigger, force_name2) {
console.log('generate ending');
} //generate ends
function extractNameFromMessage(getMessage, force_name2) {
let this_mes_is_name = true;
if (getMessage.indexOf(name2 + ":") === 0) {
getMessage = getMessage.replace(name2 + ':', '');
getMessage = getMessage.trimStart();
} else {
this_mes_is_name = false;
}
if (force_name2)
this_mes_is_name = true;
return { this_mes_is_name, getMessage };
}
function throwCircuitBreakerError() {
callPopup(`Could not extract reply in ${MAX_GENERATION_LOOPS} attempts. Try generating again`, 'text');
generate_loop_counter = 0;
$("#send_textarea").removeAttr('disabled');
is_send_press = false;
activateSendButtons();
setGenerationProgress(0);
showSwipeButtons();
$('.mes_edit:last').show();
throw new Error('Generate circuit breaker interruption');
}
function extractMessageFromData(data, finalPromt) {
let getMessage = "";
if (main_api == 'kobold' && !horde_settings.use_horde) {
getMessage = data.results[0].text;
}
if (main_api == 'kobold' && horde_settings.use_horde) {
getMessage = data;
}
if (main_api == 'textgenerationwebui') {
getMessage = data.data[0];
if (getMessage == null || data.error) {
activateSendButtons();
callPopup('<h3>Got empty response from Text generation web UI. Try restarting the API with recommended options.</h3>', 'text');
return;
}
getMessage = getMessage.substring(finalPromt.length);
}
if (main_api == 'novel') {
getMessage = data.output;
}
if (main_api == 'openai' || main_api == 'poe') {
getMessage = data;
}
return getMessage;
}
function cleanUpMessage(getMessage) {
if (power_user.collapse_newlines) {
getMessage = collapseNewlines(getMessage);
}
getMessage = $.trim(getMessage);
if (is_pygmalion) {
getMessage = getMessage.replace(/<USER>/g, name1);
getMessage = getMessage.replace(/<BOT>/g, name2);
getMessage = getMessage.replace(/You:/g, name1 + ':');
}
if (getMessage.indexOf(name1 + ":") != -1) {
getMessage = getMessage.substr(0, getMessage.indexOf(name1 + ":"));
}
if (getMessage.indexOf('<|endoftext|>') != -1) {
getMessage = getMessage.substr(0, getMessage.indexOf('<|endoftext|>'));
}
// clean-up group message from excessive generations
if (selected_group) {
getMessage = cleanGroupMessage(getMessage);
}
return getMessage;
}
function saveReply(type, getMessage, this_mes_is_name) {
if (chat.length && (chat[chat.length - 1]['swipe_id'] === undefined ||
chat[chat.length - 1]['is_user'])) {
@ -1905,8 +2022,6 @@ function saveReply(type, getMessage, this_mes_is_name) {
}
//console.log('runGenerate calls addOneMessage');
addOneMessage(chat[chat.length - 1]);
activateSendButtons();
}
return { type, getMessage };
}

View File

@ -453,51 +453,30 @@ async function sendOpenAIRequest(openai_msgs_tosend) {
}
// Unused
function onStream(e, resolve, reject, last_view_mes) {
let end = false;
if (!oai_settings.stream_openai)
async function* onStream(e) {
if (!oai_settings.stream_openai) {
return;
}
let response = e.currentTarget.response;
if (response == "{\"error\":true}") {
reject('', 'error');
throw new Error('error during streaming');
}
let eventList = response.split("\n");
let getMessage = "";
for (let event of eventList) {
if (!event.startsWith("data"))
continue;
if (event == "data: [DONE]") {
chat[chat.length - 1]['mes'] = getMessage;
$("#send_but").css("display", "block");
$("#loading_mes").css("display", "none");
saveChat();
end = true;
break;
return getMessage;
}
let data = JSON.parse(event.substring(6));
// the first and last messages are undefined, protect against that
getMessage += data.choices[0]["delta"]["content"] || "";
}
if ($("#chat").children().filter(`[mesid="${last_view_mes}"]`).length == 0) {
chat[chat.length] = {};
chat[chat.length - 1]['name'] = name2;
chat[chat.length - 1]['is_user'] = false;
chat[chat.length - 1]['is_name'] = false;
chat[chat.length - 1]['send_date'] = Date.now();
chat[chat.length - 1]['mes'] = "";
addOneMessage(chat[chat.length - 1]);
}
let messageText = messageFormating($.trim(getMessage), name1);
$("#chat").children().filter(`[mesid="${last_view_mes}"]`).children('.mes_block').children('.mes_text').html(messageText);
let $textchat = $('#chat');
$textchat.scrollTop($textchat[0].scrollHeight);
if (end) {
resolve();
yield getMessage;
}
}