Cache and sample classification results

This commit is contained in:
Cohee 2023-09-09 17:31:27 +03:00
parent 180dcefe40
commit 4cf6a1f7da
4 changed files with 67 additions and 7 deletions

View File

@ -121,7 +121,7 @@ import {
delay,
restoreCaretPosition,
saveCaretPosition,
end_trim_to_sentence,
trimToEndSentence,
countOccurrences,
isOdd,
sortMoments,
@ -3781,7 +3781,7 @@ function cleanUpMessage(getMessage, isImpersonate, isContinue, displayIncomplete
getMessage = getRegexedString(getMessage, isImpersonate ? regex_placement.USER_INPUT : regex_placement.AI_OUTPUT);
if (!displayIncompleteSentences && power_user.trim_sentences) {
getMessage = end_trim_to_sentence(getMessage, power_user.include_newline);
getMessage = trimToEndSentence(getMessage, power_user.include_newline);
}
if (power_user.collapse_newlines) {

View File

@ -3,7 +3,7 @@ import { dragElement, isMobile } from "../../RossAscends-mods.js";
import { getContext, getApiUrl, modules, extension_settings, ModuleWorkerWrapper, doExtrasFetch, renderExtensionTemplate } from "../../extensions.js";
import { loadMovingUIState, power_user } from "../../power-user.js";
import { registerSlashCommand } from "../../slash-commands.js";
import { onlyUnique, debounce, getCharaFilename } from "../../utils.js";
import { onlyUnique, debounce, getCharaFilename, trimToEndSentence, trimToStartSentence } from "../../utils.js";
export { MODULE_NAME };
const MODULE_NAME = 'expressions';
@ -709,12 +709,42 @@ async function setSpriteSlashCommand(_, spriteId) {
await sendExpressionCall(spriteFolderName, spriteItem.label, true, vnMode);
}
/**
* Processes the classification text to reduce the amount of text sent to the API.
* Quotes and asterisks are to be removed. If the text is less than 300 characters, it is returned as is.
* If the text is more than 300 characters, the first and last 150 characters are returned.
* The result is trimmed to the end of sentence.
* @param {string} text The text to process.
* @returns {string}
*/
function sampleClassifyText(text) {
if (!text) {
return text;
}
// Remove asterisks and quotes
let result = text.replace(/[\*\"]/g, '');
const SAMPLE_THRESHOLD = 300;
const HALF_SAMPLE_THRESHOLD = SAMPLE_THRESHOLD / 2;
if (text.length < SAMPLE_THRESHOLD) {
result = trimToEndSentence(result);
} else {
result = trimToEndSentence(result.slice(0, HALF_SAMPLE_THRESHOLD)) + ' ' + trimToStartSentence(result.slice(-HALF_SAMPLE_THRESHOLD));
}
return result.trim();
}
async function getExpressionLabel(text) {
// Return if text is undefined, saving a costly fetch request
if ((!modules.includes('classify') && !extension_settings.expressions.local) || !text) {
return FALLBACK_EXPRESSION;
}
text = sampleClassifyText(text);
try {
if (extension_settings.expressions.local) {
// Local transformers pipeline

View File

@ -438,9 +438,9 @@ export function sortByCssOrder(a, b) {
* @param {boolean} include_newline Whether to include a newline character in the trimmed string.
* @returns {string} The trimmed string.
* @example
* end_trim_to_sentence('Hello, world! I am from'); // 'Hello, world!'
* trimToEndSentence('Hello, world! I am from'); // 'Hello, world!'
*/
export function end_trim_to_sentence(input, include_newline = false) {
export function trimToEndSentence(input, include_newline = false) {
const punctuation = new Set(['.', '!', '?', '*', '"', ')', '}', '`', ']', '$', '。', '', '', '”', '', '】', '】', '', '」', '】']); // extend this as you see fit
let last = -1;
@ -465,6 +465,26 @@ export function end_trim_to_sentence(input, include_newline = false) {
return input.substring(0, last + 1).trimEnd();
}
export function trimToStartSentence(input) {
let p1 = input.indexOf(".");
let p2 = input.indexOf("!");
let p3 = input.indexOf("?");
let p4 = input.indexOf("\n");
let first = p1;
let skip1 = false;
if (p2 > 0 && p2 < first) { first = p2; }
if (p3 > 0 && p3 < first) { first = p3; }
if (p4 > 0 && p4 < first) { first = p4; skip1 = true; }
if (first > 0) {
if (skip1) {
return input.substring(first + 1);
} else {
return input.substring(first + 2);
}
}
return input;
}
/**
* Counts the number of occurrences of a character in a string.
* @param {string} string The string to count occurrences in.

View File

@ -24,6 +24,7 @@ class PipelineAccessor {
* @param {any} jsonParser
*/
function registerEndpoints(app, jsonParser) {
const cacheObject = {};
const pipelineAccessor = new PipelineAccessor();
app.post('/api/extra/classify/labels', jsonParser, async (req, res) => {
@ -35,10 +36,19 @@ function registerEndpoints(app, jsonParser) {
app.post('/api/extra/classify', jsonParser, async (req, res) => {
const { text } = req.body;
const pipe = await pipelineAccessor.get();
const result = await pipe(text);
async function getResult(text) {
if (cacheObject.hasOwnProperty(text)) {
return cacheObject[text];
} else {
const pipe = await pipelineAccessor.get();
const result = await pipe(text);
cacheObject[text] = result;
return result;
}
}
console.log('Classify input:', text);
const result = await getResult(text);
console.log('Classify output:', result);
return res.json({ classification: result });