diff --git a/public/script.js b/public/script.js
index 19b651858..c9d45baf7 100644
--- a/public/script.js
+++ b/public/script.js
@@ -2629,12 +2629,12 @@ class StreamingProcessor {
 
         if (!isImpersonate && !isContinue && Array.isArray(this.swipes) && this.swipes.length > 0) {
             for (let i = 0; i < this.swipes.length; i++) {
-                this.swipes[i] = cleanUpMessage(this.removePrefix(this.swipes[i]), false, false, true, !isFinal);
+                this.swipes[i] = cleanUpMessage(this.removePrefix(this.swipes[i]), false, false, true, this.stoppingStrings);
             }
         }
 
         text = this.removePrefix(text);
-        let processedText = cleanUpMessage(text, isImpersonate, isContinue, !isFinal, !isFinal);
+        let processedText = cleanUpMessage(text, isImpersonate, isContinue, !isFinal, this.stoppingStrings);
 
         // Predict unbalanced asterisks / quotes during streaming
         const charsToBalance = ['*', '"', '```'];
@@ -2805,6 +2805,12 @@ class StreamingProcessor {
             scrollLock = false;
         }
 
+        // Stopping strings are expensive to calculate, especially with macros enabled. To remove stopping strings
+        // when streaming, we cache the result of getStoppingStrings instead of calling it once per token.
+        const isImpersonate = this.type == 'impersonate';
+        const isContinue = this.type == 'continue';
+        this.stoppingStrings = getStoppingStrings(isImpersonate, isContinue);
+
         try {
             const sw = new Stopwatch(1000 / power_user.streaming_fps);
             const timestamps = [];
@@ -2907,7 +2913,7 @@ export async function generateRaw(prompt, api, instructOverride) {
         throw new Error(data.error);
     }
 
-    const message = cleanUpMessage(extractMessageFromData(data), false, false, true, false);
+    const message = cleanUpMessage(extractMessageFromData(data), false, false, true);
     if (!message) {
         throw new Error('No message generated');
     }
@@ -3814,7 +3820,7 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
         streamingProcessor.generator = streamingGenerator;
         hideSwipeButtons();
         let getMessage = await streamingProcessor.generate();
-        let messageChunk = cleanUpMessage(getMessage, isImpersonate, isContinue, false, false);
+        let messageChunk = cleanUpMessage(getMessage, isImpersonate, isContinue, false);
 
         if (isContinue) {
             getMessage = continue_mag + getMessage;
@@ -3849,7 +3855,7 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
 
         const swipes = extractMultiSwipes(data, type);
 
-        messageChunk = cleanUpMessage(getMessage, isImpersonate, isContinue, false, false);
+        messageChunk = cleanUpMessage(getMessage, isImpersonate, isContinue, false);
         if (isContinue) {
             getMessage = continue_mag + getMessage;
         }
@@ -3857,7 +3863,7 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu
 
         //Formating
         const displayIncomplete = type === 'quiet' && !quietToLoud;
-        getMessage = cleanUpMessage(getMessage, isImpersonate, isContinue, displayIncomplete, false);
+        getMessage = cleanUpMessage(getMessage, isImpersonate, isContinue, displayIncomplete);
 
         if (getMessage.length > 0) {
             if (isImpersonate) {
@@ -4487,7 +4493,7 @@ function extractMultiSwipes(data, type) {
         }
 
         for (let i = 1; i < data.choices.length; i++) {
-            const text = cleanUpMessage(data.choices[i].text, false, false, false, false);
+            const text = cleanUpMessage(data.choices[i].text, false, false, false);
             swipes.push(text);
         }
     }
@@ -4495,7 +4501,7 @@
     return swipes;
 }
 
-function cleanUpMessage(getMessage, isImpersonate, isContinue, displayIncompleteSentences = false, skipStopStringCleanup = false) {
+function cleanUpMessage(getMessage, isImpersonate, isContinue, displayIncompleteSentences = false, stoppingStrings = null) {
     if (!getMessage) {
         return '';
     }
@@ -4510,16 +4516,18 @@ function cleanUpMessage(getMessage, isImpersonate, isContinue, displayIncomplete
         getMessage = substituteParams(power_user.user_prompt_bias) + getMessage;
     }
 
-    if (!skipStopStringCleanup) {
-        const stoppingStrings = getStoppingStrings(isImpersonate, isContinue);
+    // Allow for caching of stopping strings. getStoppingStrings is an expensive function, especially with macros
+    // enabled, so for streaming, we call it once and then pass it into each cleanUpMessage call.
+    if (!stoppingStrings) {
+        stoppingStrings = getStoppingStrings(isImpersonate, isContinue);
+    }
 
-        for (const stoppingString of stoppingStrings) {
-            if (stoppingString.length) {
-                for (let j = stoppingString.length; j > 0; j--) {
-                    if (getMessage.slice(-j) === stoppingString.slice(0, j)) {
-                        getMessage = getMessage.slice(0, -j);
-                        break;
-                    }
+    for (const stoppingString of stoppingStrings) {
+        if (stoppingString.length) {
+            for (let j = stoppingString.length; j > 0; j--) {
+                if (getMessage.slice(-j) === stoppingString.slice(0, j)) {
+                    getMessage = getMessage.slice(0, -j);
+                    break;
                 }
             }
         }
diff --git a/public/scripts/tokenizers.js b/public/scripts/tokenizers.js
index bef54b791..decd0f919 100644
--- a/public/scripts/tokenizers.js
+++ b/public/scripts/tokenizers.js
@@ -63,6 +63,7 @@ const TOKENIZER_URLS = {
     },
     [tokenizers.API_KOBOLD]: {
         count: '/api/tokenizers/remote/kobold/count',
+        encode: '/api/tokenizers/remote/kobold/count',
     },
     [tokenizers.MISTRAL]: {
         encode: '/api/tokenizers/mistral/encode',
@@ -617,6 +618,32 @@ function getTextTokensFromTextgenAPI(str) {
     return ids;
 }
 
+/**
+ * Calls the AI provider's tokenize API to encode a string to tokens.
+ * @param {string} str String to tokenize.
+ * @returns {number[]} Array of token ids.
+ */
+function getTextTokensFromKoboldAPI(str) {
+    let ids = [];
+
+    jQuery.ajax({
+        async: false,
+        type: 'POST',
+        url: TOKENIZER_URLS[tokenizers.API_KOBOLD].encode,
+        data: JSON.stringify({
+            text: str,
+            url: api_server,
+        }),
+        dataType: 'json',
+        contentType: 'application/json',
+        success: function (data) {
+            ids = data.ids;
+        },
+    });
+
+    return ids;
+}
+
 /**
  * Calls the underlying tokenizer model to decode token ids to text.
  * @param {string} endpoint API endpoint.
@@ -650,6 +677,8 @@ export function getTextTokens(tokenizerType, str) {
             return getTextTokens(currentRemoteTokenizerAPI(), str);
         case tokenizers.API_TEXTGENERATIONWEBUI:
             return getTextTokensFromTextgenAPI(str);
+        case tokenizers.API_KOBOLD:
+            return getTextTokensFromKoboldAPI(str);
         default: {
             const tokenizerEndpoints = TOKENIZER_URLS[tokenizerType];
             if (!tokenizerEndpoints) {
diff --git a/server.js b/server.js
index 7d0dd68fc..d0936103c 100644
--- a/server.js
+++ b/server.js
@@ -1438,7 +1438,7 @@ app.use('/api/serpapi', require('./src/endpoints/serpapi').router);
 // The different text generation APIs
 
 // Ooba/OpenAI text completions
-app.use('/api/backends/ooba', require('./src/endpoints/backends/ooba').router);
+app.use('/api/backends/text-completions', require('./src/endpoints/backends/text-completions').router);
 
 // KoboldAI
 app.use('/api/textgen/kobold', require('./src/endpoints/textgen/kobold').router);
diff --git a/src/endpoints/backends/ooba.js b/src/endpoints/backends/text-completions.js
similarity index 99%
rename from src/endpoints/backends/ooba.js
rename to src/endpoints/backends/text-completions.js
index 75d9ea439..71387eefd 100644
--- a/src/endpoints/backends/ooba.js
+++ b/src/endpoints/backends/text-completions.js
@@ -1,4 +1,5 @@
 const express = require('express');
+const fetch = require('node-fetch').default;
 
 const { jsonParser } = require('../../express-common');
 const { TEXTGEN_TYPES } = require('../../constants');
diff --git a/src/endpoints/tokenizers.js b/src/endpoints/tokenizers.js
index 27ef4faf3..a81779d97 100644
--- a/src/endpoints/tokenizers.js
+++ b/src/endpoints/tokenizers.js
@@ -562,7 +562,8 @@ router.post('/remote/kobold/count', jsonParser, async function (request, respons
         const data = await result.json();
         const count = data['value'];
-        return response.send({ count, ids: [] });
+        const ids = data['ids'] ?? [];
+        return response.send({ count, ids });
     } catch (error) {
         console.log(error);
         return response.send({ error: true });
     }
@@ -617,7 +618,7 @@ router.post('/remote/textgenerationwebui/encode', jsonParser, async function (re
 
         const data = await result.json();
         const count = legacyApi ? data?.results[0]?.tokens : (data?.length ?? data?.value);
-        const ids = legacyApi ? [] : (data?.tokens ?? []);
+        const ids = legacyApi ? [] : (data?.tokens ?? data?.ids ?? []);
 
         return response.send({ count, ids });
     } catch (error) {
diff --git a/src/util.js b/src/util.js
index 11b864092..be8d5135f 100644
--- a/src/util.js
+++ b/src/util.js
@@ -349,7 +349,7 @@ function getImages(path) {
 
 /**
  * Pipe a fetch() response to an Express.js Response, including status code.
- * @param {Response} from The Fetch API response to pipe from.
+ * @param {import('node-fetch').Response} from The Fetch API response to pipe from.
  * @param {Express.Response} to The Express response to pipe to.
  */
 function forwardFetchResponse(from, to) {
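
Note: a minimal sketch of the caching pattern introduced above. The function names and stream loop here are illustrative stand-ins (not the actual SillyTavern call sites): `computeStoppingStringsOnce` and `trimStoppingStrings` play the roles of `getStoppingStrings` and the stopping-string trim inside `cleanUpMessage`. The point is that the expensive computation runs once per generation, and the cached array is reused for every per-token cleanup.

```js
// Illustrative sketch only: stand-ins for getStoppingStrings / cleanUpMessage from the diff above.
function computeStoppingStringsOnce() {
    // Imagine macro substitution and name expansion happening here (the expensive part).
    return ['\nUser:', '\n### Instruction:'];
}

// Mirrors the slice(-j) comparison in cleanUpMessage: trims a full or partial
// stopping string hanging off the end of the accumulated text.
function trimStoppingStrings(text, stoppingStrings) {
    for (const stoppingString of stoppingStrings) {
        if (!stoppingString.length) continue;
        for (let j = stoppingString.length; j > 0; j--) {
            if (text.slice(-j) === stoppingString.slice(0, j)) {
                text = text.slice(0, -j);
                break;
            }
        }
    }
    return text;
}

// Before this change, the stopping strings were recomputed for every streamed token;
// now they are computed once per stream and passed into each cleanup call.
function streamExample(tokens) {
    const stoppingStrings = computeStoppingStringsOnce(); // called once per generation
    let raw = '';
    for (const token of tokens) {
        raw += token;
        const displayed = trimStoppingStrings(raw, stoppingStrings); // reuses the cached array
        console.log(displayed);
    }
    return trimStoppingStrings(raw, stoppingStrings);
}

console.log(streamExample(['Hello', ' there', '\nUs'])); // -> 'Hello there'
```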