Merge pull request #1525 from valadaptive/cache-stopping-strings

Cache stopping strings rather than skipping them during streaming
This commit is contained in:
Cohee 2023-12-13 01:06:44 +02:00 committed by GitHub
commit 51d50f97cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 25 additions and 17 deletions

View File

@ -2629,12 +2629,12 @@ class StreamingProcessor {
if (!isImpersonate && !isContinue && Array.isArray(this.swipes) && this.swipes.length > 0) {
for (let i = 0; i < this.swipes.length; i++) {
this.swipes[i] = cleanUpMessage(this.removePrefix(this.swipes[i]), false, false, true, !isFinal);
this.swipes[i] = cleanUpMessage(this.removePrefix(this.swipes[i]), false, false, true, this.stoppingStrings);
}
}
text = this.removePrefix(text);
let processedText = cleanUpMessage(text, isImpersonate, isContinue, !isFinal, !isFinal);
let processedText = cleanUpMessage(text, isImpersonate, isContinue, !isFinal, this.stoppingStrings);
// Predict unbalanced asterisks / quotes during streaming
const charsToBalance = ['*', '"', '```'];
@ -2805,6 +2805,12 @@ class StreamingProcessor {
scrollLock = false;
}
// Stopping strings are expensive to calculate, especially with macros enabled. To remove stopping strings
// when streaming, we cache the result of getStoppingStrings instead of calling it once per token.
const isImpersonate = this.type == 'impersonate';
const isContinue = this.type == 'continue';
this.stoppingStrings = getStoppingStrings(isImpersonate, isContinue);
try {
const sw = new Stopwatch(1000 / power_user.streaming_fps);
const timestamps = [];
@ -2907,7 +2913,7 @@ export async function generateRaw(prompt, api, instructOverride) {
throw new Error(data.error);
}
const message = cleanUpMessage(extractMessageFromData(data), false, false, true, false);
const message = cleanUpMessage(extractMessageFromData(data), false, false, true);
if (!message) {
throw new Error('No message generated');
@ -3814,7 +3820,7 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
streamingProcessor.generator = streamingGenerator;
hideSwipeButtons();
let getMessage = await streamingProcessor.generate();
let messageChunk = cleanUpMessage(getMessage, isImpersonate, isContinue, false, false);
let messageChunk = cleanUpMessage(getMessage, isImpersonate, isContinue, false);
if (isContinue) {
getMessage = continue_mag + getMessage;
@ -3849,7 +3855,7 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
const swipes = extractMultiSwipes(data, type);
messageChunk = cleanUpMessage(getMessage, isImpersonate, isContinue, false, false);
messageChunk = cleanUpMessage(getMessage, isImpersonate, isContinue, false);
if (isContinue) {
getMessage = continue_mag + getMessage;
@ -3857,7 +3863,7 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
// Formatting
const displayIncomplete = type === 'quiet' && !quietToLoud;
getMessage = cleanUpMessage(getMessage, isImpersonate, isContinue, displayIncomplete, false);
getMessage = cleanUpMessage(getMessage, isImpersonate, isContinue, displayIncomplete);
if (getMessage.length > 0) {
if (isImpersonate) {
@ -4487,7 +4493,7 @@ function extractMultiSwipes(data, type) {
}
for (let i = 1; i < data.choices.length; i++) {
const text = cleanUpMessage(data.choices[i].text, false, false, false, false);
const text = cleanUpMessage(data.choices[i].text, false, false, false);
swipes.push(text);
}
}
@ -4495,7 +4501,7 @@ function extractMultiSwipes(data, type) {
return swipes;
}
function cleanUpMessage(getMessage, isImpersonate, isContinue, displayIncompleteSentences = false, skipStopStringCleanup = false) {
function cleanUpMessage(getMessage, isImpersonate, isContinue, displayIncompleteSentences = false, stoppingStrings = null) {
if (!getMessage) {
return '';
}
@ -4510,16 +4516,18 @@ function cleanUpMessage(getMessage, isImpersonate, isContinue, displayIncomplete
getMessage = substituteParams(power_user.user_prompt_bias) + getMessage;
}
if (!skipStopStringCleanup) {
const stoppingStrings = getStoppingStrings(isImpersonate, isContinue);
// Allow for caching of stopping strings. getStoppingStrings is an expensive function, especially with macros
// enabled, so for streaming, we call it once and then pass it into each cleanUpMessage call.
if (!stoppingStrings) {
stoppingStrings = getStoppingStrings(isImpersonate, isContinue);
}
for (const stoppingString of stoppingStrings) {
if (stoppingString.length) {
for (let j = stoppingString.length; j > 0; j--) {
if (getMessage.slice(-j) === stoppingString.slice(0, j)) {
getMessage = getMessage.slice(0, -j);
break;
}
for (const stoppingString of stoppingStrings) {
if (stoppingString.length) {
for (let j = stoppingString.length; j > 0; j--) {
if (getMessage.slice(-j) === stoppingString.slice(0, j)) {
getMessage = getMessage.slice(0, -j);
break;
}
}
}