1
0
mirror of https://github.com/SillyTavern/SillyTavern.git synced 2025-03-12 01:50:11 +01:00

Cache and sample classification results

This commit is contained in:
Cohee 2023-09-09 17:31:27 +03:00
parent 180dcefe40
commit 4cf6a1f7da
4 changed files with 67 additions and 7 deletions
public
script.js
scripts
extensions/expressions
utils.js
src

@ -121,7 +121,7 @@ import {
delay, delay,
restoreCaretPosition, restoreCaretPosition,
saveCaretPosition, saveCaretPosition,
end_trim_to_sentence, trimToEndSentence,
countOccurrences, countOccurrences,
isOdd, isOdd,
sortMoments, sortMoments,
@ -3781,7 +3781,7 @@ function cleanUpMessage(getMessage, isImpersonate, isContinue, displayIncomplete
getMessage = getRegexedString(getMessage, isImpersonate ? regex_placement.USER_INPUT : regex_placement.AI_OUTPUT); getMessage = getRegexedString(getMessage, isImpersonate ? regex_placement.USER_INPUT : regex_placement.AI_OUTPUT);
if (!displayIncompleteSentences && power_user.trim_sentences) { if (!displayIncompleteSentences && power_user.trim_sentences) {
getMessage = end_trim_to_sentence(getMessage, power_user.include_newline); getMessage = trimToEndSentence(getMessage, power_user.include_newline);
} }
if (power_user.collapse_newlines) { if (power_user.collapse_newlines) {

@ -3,7 +3,7 @@ import { dragElement, isMobile } from "../../RossAscends-mods.js";
import { getContext, getApiUrl, modules, extension_settings, ModuleWorkerWrapper, doExtrasFetch, renderExtensionTemplate } from "../../extensions.js"; import { getContext, getApiUrl, modules, extension_settings, ModuleWorkerWrapper, doExtrasFetch, renderExtensionTemplate } from "../../extensions.js";
import { loadMovingUIState, power_user } from "../../power-user.js"; import { loadMovingUIState, power_user } from "../../power-user.js";
import { registerSlashCommand } from "../../slash-commands.js"; import { registerSlashCommand } from "../../slash-commands.js";
import { onlyUnique, debounce, getCharaFilename } from "../../utils.js"; import { onlyUnique, debounce, getCharaFilename, trimToEndSentence, trimToStartSentence } from "../../utils.js";
export { MODULE_NAME }; export { MODULE_NAME };
const MODULE_NAME = 'expressions'; const MODULE_NAME = 'expressions';
@ -709,12 +709,42 @@ async function setSpriteSlashCommand(_, spriteId) {
await sendExpressionCall(spriteFolderName, spriteItem.label, true, vnMode); await sendExpressionCall(spriteFolderName, spriteItem.label, true, vnMode);
} }
/**
* Processes the classification text to reduce the amount of text sent to the API.
* Quotes and asterisks are to be removed. If the text is less than 300 characters, it is returned as is.
* If the text is more than 300 characters, the first and last 150 characters are returned.
* The result is trimmed to the end of sentence.
* @param {string} text The text to process.
* @returns {string}
*/
function sampleClassifyText(text) {
if (!text) {
return text;
}
// Remove asterisks and quotes
let result = text.replace(/[\*\"]/g, '');
const SAMPLE_THRESHOLD = 300;
const HALF_SAMPLE_THRESHOLD = SAMPLE_THRESHOLD / 2;
if (text.length < SAMPLE_THRESHOLD) {
result = trimToEndSentence(result);
} else {
result = trimToEndSentence(result.slice(0, HALF_SAMPLE_THRESHOLD)) + ' ' + trimToStartSentence(result.slice(-HALF_SAMPLE_THRESHOLD));
}
return result.trim();
}
async function getExpressionLabel(text) { async function getExpressionLabel(text) {
// Return if text is undefined, saving a costly fetch request // Return if text is undefined, saving a costly fetch request
if ((!modules.includes('classify') && !extension_settings.expressions.local) || !text) { if ((!modules.includes('classify') && !extension_settings.expressions.local) || !text) {
return FALLBACK_EXPRESSION; return FALLBACK_EXPRESSION;
} }
text = sampleClassifyText(text);
try { try {
if (extension_settings.expressions.local) { if (extension_settings.expressions.local) {
// Local transformers pipeline // Local transformers pipeline

@ -438,9 +438,9 @@ export function sortByCssOrder(a, b) {
* @param {boolean} include_newline Whether to include a newline character in the trimmed string. * @param {boolean} include_newline Whether to include a newline character in the trimmed string.
* @returns {string} The trimmed string. * @returns {string} The trimmed string.
* @example * @example
* end_trim_to_sentence('Hello, world! I am from'); // 'Hello, world!' * trimToEndSentence('Hello, world! I am from'); // 'Hello, world!'
*/ */
export function end_trim_to_sentence(input, include_newline = false) { export function trimToEndSentence(input, include_newline = false) {
const punctuation = new Set(['.', '!', '?', '*', '"', ')', '}', '`', ']', '$', '。', '', '', '”', '', '】', '】', '', '」', '】']); // extend this as you see fit const punctuation = new Set(['.', '!', '?', '*', '"', ')', '}', '`', ']', '$', '。', '', '', '”', '', '】', '】', '', '」', '】']); // extend this as you see fit
let last = -1; let last = -1;
@ -465,6 +465,26 @@ export function end_trim_to_sentence(input, include_newline = false) {
return input.substring(0, last + 1).trimEnd(); return input.substring(0, last + 1).trimEnd();
} }
export function trimToStartSentence(input) {
let p1 = input.indexOf(".");
let p2 = input.indexOf("!");
let p3 = input.indexOf("?");
let p4 = input.indexOf("\n");
let first = p1;
let skip1 = false;
if (p2 > 0 && p2 < first) { first = p2; }
if (p3 > 0 && p3 < first) { first = p3; }
if (p4 > 0 && p4 < first) { first = p4; skip1 = true; }
if (first > 0) {
if (skip1) {
return input.substring(first + 1);
} else {
return input.substring(first + 2);
}
}
return input;
}
/** /**
* Counts the number of occurrences of a character in a string. * Counts the number of occurrences of a character in a string.
* @param {string} string The string to count occurrences in. * @param {string} string The string to count occurrences in.

@ -24,6 +24,7 @@ class PipelineAccessor {
* @param {any} jsonParser * @param {any} jsonParser
*/ */
function registerEndpoints(app, jsonParser) { function registerEndpoints(app, jsonParser) {
const cacheObject = {};
const pipelineAccessor = new PipelineAccessor(); const pipelineAccessor = new PipelineAccessor();
app.post('/api/extra/classify/labels', jsonParser, async (req, res) => { app.post('/api/extra/classify/labels', jsonParser, async (req, res) => {
@ -35,10 +36,19 @@ function registerEndpoints(app, jsonParser) {
app.post('/api/extra/classify', jsonParser, async (req, res) => { app.post('/api/extra/classify', jsonParser, async (req, res) => {
const { text } = req.body; const { text } = req.body;
const pipe = await pipelineAccessor.get(); async function getResult(text) {
const result = await pipe(text); if (cacheObject.hasOwnProperty(text)) {
return cacheObject[text];
} else {
const pipe = await pipelineAccessor.get();
const result = await pipe(text);
cacheObject[text] = result;
return result;
}
}
console.log('Classify input:', text); console.log('Classify input:', text);
const result = await getResult(text);
console.log('Classify output:', result); console.log('Classify output:', result);
return res.json({ classification: result }); return res.json({ classification: result });