diff --git a/public/script.js b/public/script.js index f430009a2..cd95d6e0c 100644 --- a/public/script.js +++ b/public/script.js @@ -2629,12 +2629,12 @@ class StreamingProcessor { if (!isImpersonate && !isContinue && Array.isArray(this.swipes) && this.swipes.length > 0) { for (let i = 0; i < this.swipes.length; i++) { - this.swipes[i] = cleanUpMessage(this.removePrefix(this.swipes[i]), false, false, true, !isFinal); + this.swipes[i] = cleanUpMessage(this.removePrefix(this.swipes[i]), false, false, true, this.stoppingStrings); } } text = this.removePrefix(text); - let processedText = cleanUpMessage(text, isImpersonate, isContinue, !isFinal, !isFinal); + let processedText = cleanUpMessage(text, isImpersonate, isContinue, !isFinal, this.stoppingStrings); // Predict unbalanced asterisks / quotes during streaming const charsToBalance = ['*', '"', '```']; @@ -2805,6 +2805,10 @@ class StreamingProcessor { scrollLock = false; } + const isImpersonate = this.type == 'impersonate'; + const isContinue = this.type == 'continue'; + this.stoppingStrings = getStoppingStrings(isImpersonate, isContinue); + try { const sw = new Stopwatch(1000 / power_user.streaming_fps); const timestamps = []; @@ -2907,7 +2911,7 @@ export async function generateRaw(prompt, api, instructOverride) { throw new Error(data.error); } - const message = cleanUpMessage(extractMessageFromData(data), false, false, true, false); + const message = cleanUpMessage(extractMessageFromData(data), false, false, true); if (!message) { throw new Error('No message generated'); @@ -3814,7 +3818,7 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu streamingProcessor.generator = streamingGenerator; hideSwipeButtons(); let getMessage = await streamingProcessor.generate(); - let messageChunk = cleanUpMessage(getMessage, isImpersonate, isContinue, false, false); + let messageChunk = cleanUpMessage(getMessage, isImpersonate, isContinue, false); if (isContinue) { getMessage = continue_mag + getMessage; @@ -3849,7 +3853,7 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu const swipes = extractMultiSwipes(data, type); - messageChunk = cleanUpMessage(getMessage, isImpersonate, isContinue, false, false); + messageChunk = cleanUpMessage(getMessage, isImpersonate, isContinue, false); if (isContinue) { getMessage = continue_mag + getMessage; @@ -3857,7 +3861,7 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu //Formating const displayIncomplete = type === 'quiet' && !quietToLoud; - getMessage = cleanUpMessage(getMessage, isImpersonate, isContinue, displayIncomplete, false); + getMessage = cleanUpMessage(getMessage, isImpersonate, isContinue, displayIncomplete); if (getMessage.length > 0) { if (isImpersonate) { @@ -4487,7 +4491,7 @@ function extractMultiSwipes(data, type) { } for (let i = 1; i < data.choices.length; i++) { - const text = cleanUpMessage(data.choices[i].text, false, false, false, false); + const text = cleanUpMessage(data.choices[i].text, false, false, false); swipes.push(text); } } @@ -4495,7 +4499,7 @@ function extractMultiSwipes(data, type) { return swipes; } -function cleanUpMessage(getMessage, isImpersonate, isContinue, displayIncompleteSentences = false, skipStopStringCleanup = false) { +function cleanUpMessage(getMessage, isImpersonate, isContinue, displayIncompleteSentences = false, stoppingStrings = null) { if (!getMessage) { return ''; } @@ -4510,16 +4514,18 @@ function cleanUpMessage(getMessage, isImpersonate, isContinue, displayIncomplete getMessage = substituteParams(power_user.user_prompt_bias) + getMessage; } - if (!skipStopStringCleanup) { - const stoppingStrings = getStoppingStrings(isImpersonate, isContinue); + // Allow for caching of stopping strings. getStoppingStrings is an expensive function, especially with macros + // enabled, so for streaming, we call it once and then pass it into each cleanUpMessage call. + if (!stoppingStrings) { + stoppingStrings = getStoppingStrings(isImpersonate, isContinue); + } - for (const stoppingString of stoppingStrings) { - if (stoppingString.length) { - for (let j = stoppingString.length; j > 0; j--) { - if (getMessage.slice(-j) === stoppingString.slice(0, j)) { - getMessage = getMessage.slice(0, -j); - break; - } + for (const stoppingString of stoppingStrings) { + if (stoppingString.length) { + for (let j = stoppingString.length; j > 0; j--) { + if (getMessage.slice(-j) === stoppingString.slice(0, j)) { + getMessage = getMessage.slice(0, -j); + break; } } }