diff --git a/public/script.js b/public/script.js index 5433ce081..91d84ca44 100644 --- a/public/script.js +++ b/public/script.js @@ -121,7 +121,7 @@ import { delay, restoreCaretPosition, saveCaretPosition, - end_trim_to_sentence, + trimToEndSentence, countOccurrences, isOdd, sortMoments, @@ -3781,7 +3781,7 @@ function cleanUpMessage(getMessage, isImpersonate, isContinue, displayIncomplete getMessage = getRegexedString(getMessage, isImpersonate ? regex_placement.USER_INPUT : regex_placement.AI_OUTPUT); if (!displayIncompleteSentences && power_user.trim_sentences) { - getMessage = end_trim_to_sentence(getMessage, power_user.include_newline); + getMessage = trimToEndSentence(getMessage, power_user.include_newline); } if (power_user.collapse_newlines) { diff --git a/public/scripts/extensions/expressions/index.js b/public/scripts/extensions/expressions/index.js index b8b7930df..7437d305f 100644 --- a/public/scripts/extensions/expressions/index.js +++ b/public/scripts/extensions/expressions/index.js @@ -3,7 +3,7 @@ import { dragElement, isMobile } from "../../RossAscends-mods.js"; import { getContext, getApiUrl, modules, extension_settings, ModuleWorkerWrapper, doExtrasFetch, renderExtensionTemplate } from "../../extensions.js"; import { loadMovingUIState, power_user } from "../../power-user.js"; import { registerSlashCommand } from "../../slash-commands.js"; -import { onlyUnique, debounce, getCharaFilename } from "../../utils.js"; +import { onlyUnique, debounce, getCharaFilename, trimToEndSentence, trimToStartSentence } from "../../utils.js"; export { MODULE_NAME }; const MODULE_NAME = 'expressions'; @@ -709,12 +709,42 @@ async function setSpriteSlashCommand(_, spriteId) { await sendExpressionCall(spriteFolderName, spriteItem.label, true, vnMode); } +/** + * Processes the classification text to reduce the amount of text sent to the API. + * Quotes and asterisks are to be removed. If the text is less than 300 characters, it is returned as is. + * If the text is more than 300 characters, the first and last 150 characters are returned. + * The result is trimmed to the end of sentence. + * @param {string} text The text to process. + * @returns {string} + */ +function sampleClassifyText(text) { + if (!text) { + return text; + } + + // Remove asterisks and quotes + let result = text.replace(/[\*\"]/g, ''); + + const SAMPLE_THRESHOLD = 300; + const HALF_SAMPLE_THRESHOLD = SAMPLE_THRESHOLD / 2; + + if (text.length < SAMPLE_THRESHOLD) { + result = trimToEndSentence(result); + } else { + result = trimToEndSentence(result.slice(0, HALF_SAMPLE_THRESHOLD)) + ' ' + trimToStartSentence(result.slice(-HALF_SAMPLE_THRESHOLD)); + } + + return result.trim(); +} + async function getExpressionLabel(text) { // Return if text is undefined, saving a costly fetch request if ((!modules.includes('classify') && !extension_settings.expressions.local) || !text) { return FALLBACK_EXPRESSION; } + text = sampleClassifyText(text); + try { if (extension_settings.expressions.local) { // Local transformers pipeline diff --git a/public/scripts/utils.js b/public/scripts/utils.js index 121c97334..6ea5fc667 100644 --- a/public/scripts/utils.js +++ b/public/scripts/utils.js @@ -438,9 +438,9 @@ export function sortByCssOrder(a, b) { * @param {boolean} include_newline Whether to include a newline character in the trimmed string. * @returns {string} The trimmed string. * @example - * end_trim_to_sentence('Hello, world! I am from'); // 'Hello, world!' + * trimToEndSentence('Hello, world! I am from'); // 'Hello, world!' */ -export function end_trim_to_sentence(input, include_newline = false) { +export function trimToEndSentence(input, include_newline = false) { const punctuation = new Set(['.', '!', '?', '*', '"', ')', '}', '`', ']', '$', '。', '!', '?', '”', ')', '】', '】', '’', '」', '】']); // extend this as you see fit let last = -1; @@ -465,6 +465,26 @@ export function end_trim_to_sentence(input, include_newline = false) { return input.substring(0, last + 1).trimEnd(); } +export function trimToStartSentence(input) { + let p1 = input.indexOf("."); + let p2 = input.indexOf("!"); + let p3 = input.indexOf("?"); + let p4 = input.indexOf("\n"); + let first = p1; + let skip1 = false; + if (p2 > 0 && p2 < first) { first = p2; } + if (p3 > 0 && p3 < first) { first = p3; } + if (p4 > 0 && p4 < first) { first = p4; skip1 = true; } + if (first > 0) { + if (skip1) { + return input.substring(first + 1); + } else { + return input.substring(first + 2); + } + } + return input; +} + /** * Counts the number of occurrences of a character in a string. * @param {string} string The string to count occurrences in. diff --git a/src/classify.mjs b/src/classify.mjs index 46d959126..97062d84c 100644 --- a/src/classify.mjs +++ b/src/classify.mjs @@ -24,6 +24,7 @@ class PipelineAccessor { * @param {any} jsonParser */ function registerEndpoints(app, jsonParser) { + const cacheObject = {}; const pipelineAccessor = new PipelineAccessor(); app.post('/api/extra/classify/labels', jsonParser, async (req, res) => { @@ -35,10 +36,19 @@ function registerEndpoints(app, jsonParser) { app.post('/api/extra/classify', jsonParser, async (req, res) => { const { text } = req.body; - const pipe = await pipelineAccessor.get(); - const result = await pipe(text); + async function getResult(text) { + if (cacheObject.hasOwnProperty(text)) { + return cacheObject[text]; + } else { + const pipe = await pipelineAccessor.get(); + const result = await pipe(text); + cacheObject[text] = result; + return result; + } + } console.log('Classify input:', text); + const result = await getResult(text); console.log('Classify output:', result); return res.json({ classification: result });