diff --git a/public/scripts/extensions/expressions/index.js b/public/scripts/extensions/expressions/index.js
index 426be18b8..e50ac9c71 100644
--- a/public/scripts/extensions/expressions/index.js
+++ b/public/scripts/extensions/expressions/index.js
@@ -5,13 +5,15 @@ import { loadMovingUIState, power_user } from '../../power-user.js';
 import { registerSlashCommand } from '../../slash-commands.js';
 import { onlyUnique, debounce, getCharaFilename, trimToEndSentence, trimToStartSentence } from '../../utils.js';
 import { hideMutedSprites } from '../../group-chats.js';
+import { isJsonSchemaSupported } from '../../textgen-settings.js';
 export { MODULE_NAME };
 
 const MODULE_NAME = 'expressions';
 const UPDATE_INTERVAL = 2000;
-const STREAMING_UPDATE_INTERVAL = 6000;
+const STREAMING_UPDATE_INTERVAL = 10000;
 const TALKINGCHECK_UPDATE_INTERVAL = 500;
 const DEFAULT_FALLBACK_EXPRESSION = 'joy';
+const DEFAULT_LLM_PROMPT = 'Pause your roleplay. Classify the emotion of the last message. Output just one word, e.g. "joy" or "anger". Choose only one of the following labels: {{labels}}';
 const DEFAULT_EXPRESSIONS = [
     'talkinghead',
     'admiration',
@@ -976,9 +978,49 @@ function sampleClassifyText(text) {
     return result.trim();
 }
 
+/**
+ * Gets the classification prompt for the LLM API.
+ * @param {string[]} labels A list of labels to search for.
+ * @returns {Promise<string>} Prompt for the LLM API.
+ */
+async function getLlmPrompt(labels) {
+    if (isJsonSchemaSupported()) {
+        return '';
+    }
+
+    const prompt = String(extension_settings.expressions.llmPrompt).replace(/{{labels}}/gi, labels.map(x => `"${x}"`).join(', '));
+    return prompt;
+}
+
+/**
+ * Parses the emotion response from the LLM API.
+ * @param {string} emotionResponse The response from the LLM API.
+ * @param {string[]} labels A list of labels to search for.
+ * @returns {string} The parsed emotion or the fallback expression.
+ */
+function parseLlmResponse(emotionResponse, labels) {
+    const fallbackExpression = getFallbackExpression();
+
+    try {
+        const parsedEmotion = JSON.parse(emotionResponse);
+        return parsedEmotion?.emotion ?? fallbackExpression;
+
+    } catch {
+        const fuse = new Fuse([emotionResponse]);
+        for (const label of labels) {
+            const result = fuse.search(label);
+            if (result.length > 0) {
+                return label;
+            }
+        }
+    }
+
+    return fallbackExpression;
+}
+
 function onTextGenSettingsReady(args) {
     // Only call if inside an API call
-    if (inApiCall) {
+    if (inApiCall && extension_settings.expressions.api === EXPRESSION_API.llm && isJsonSchemaSupported()) {
         const emotions = DEFAULT_EXPRESSIONS.filter((e) => e != 'talkinghead');
         Object.assign(args, {
             top_k: 1,
@@ -1016,8 +1058,8 @@ async function getExpressionLabel(text) {
 
     try {
         switch (extension_settings.expressions.api) {
-            case EXPRESSION_API.local:
-                // Local BERT pipeline
+            // Local BERT pipeline
+            case EXPRESSION_API.local: {
                 const localResult = await fetch('/api/extra/classify', {
                     method: 'POST',
                     headers: getRequestHeaders(),
@@ -1028,15 +1070,16 @@ async function getExpressionLabel(text) {
                     const data = await localResult.json();
                     return data.classification[0].label;
                 }
-
-                break;
-            case EXPRESSION_API.llm:
-                // Using LLM
-                const emotionResponse = await generateQuietPrompt('', false);
-                const parsedEmotion = JSON.parse(emotionResponse);
-                return parsedEmotion.emotion;
-            default:
-                // Extras
+            } break;
+            // Using LLM
+            case EXPRESSION_API.llm: {
+                const expressionsList = await getExpressionsList();
+                const prompt = await getLlmPrompt(expressionsList);
+                const emotionResponse = await generateQuietPrompt(prompt, false, false);
+                return parseLlmResponse(emotionResponse, expressionsList);
+            }
+            // Extras
+            default: {
                 const url = new URL(getApiUrl());
                 url.pathname = '/api/classify';
 
@@ -1053,6 +1096,7 @@ async function getExpressionLabel(text) {
                     const data = await extrasResult.json();
                     return data.classification[0].label;
                 }
+            } break;
         }
     } catch (error) {
         toastr.info('Could not classify expression. Check the console or your backend for more information.');
@@ -1488,6 +1532,7 @@ function onExperesionApiChanged() {
     const tempApi = this.value;
     if (tempApi) {
         extension_settings.expressions.api = Number(tempApi);
+        $('.expression_llm_prompt_block').toggle(extension_settings.expressions.api === EXPRESSION_API.llm);
         moduleWorker();
         saveSettingsDebounced();
     }
@@ -1760,6 +1805,11 @@ function migrateSettings() {
         delete extension_settings.expressions.local;
         saveSettingsDebounced();
     }
+
+    if (extension_settings.expressions.llmPrompt === undefined) {
+        extension_settings.expressions.llmPrompt = DEFAULT_LLM_PROMPT;
+        saveSettingsDebounced();
+    }
 }
 
 (async function () {
@@ -1811,7 +1861,13 @@ function migrateSettings() {
     });
 
     await renderAdditionalExpressionSettings();
-    $('#expression_api').val(extension_settings.expressions.api || EXPRESSION_API.extras);
+    $('#expression_api').val(extension_settings.expressions.api ?? EXPRESSION_API.extras);
+    $('.expression_llm_prompt_block').toggle(extension_settings.expressions.api === EXPRESSION_API.llm);
+    $('#expression_llm_prompt').val(extension_settings.expressions.llmPrompt ?? '');
+    $('#expression_llm_prompt').on('input', function () {
+        extension_settings.expressions.llmPrompt = $(this).val();
+        saveSettingsDebounced();
+    });
     $('#expression_custom_add').on('click', onClickExpressionAddCustom);
     $('#expression_custom_remove').on('click', onClickExpressionRemoveCustom);
 
diff --git a/public/scripts/extensions/expressions/settings.html b/public/scripts/extensions/expressions/settings.html
index 0831a34d3..9857ae3d5 100644
--- a/public/scripts/extensions/expressions/settings.html
+++ b/public/scripts/extensions/expressions/settings.html
@@ -27,6 +27,11 @@
+                <div class="expression_llm_prompt_block">
+                    <label for="expression_llm_prompt">LLM Prompt</label>
+                    <small>Will be used if the API doesn't support JSON schemas.</small>
+                    <textarea id="expression_llm_prompt" class="text_pole textarea_compact" rows="2"></textarea>
+                </div>
                 Set the default and fallback expression being used when no matching expression is found.
diff --git a/public/scripts/textgen-settings.js b/public/scripts/textgen-settings.js
index 4f8156a38..aa85d1f99 100644
--- a/public/scripts/textgen-settings.js
+++ b/public/scripts/textgen-settings.js
@@ -3,6 +3,7 @@ import {
     event_types,
     getRequestHeaders,
     getStoppingStrings,
+    main_api,
     max_context,
     saveSettingsDebounced,
     setGenerationParamsFromPreset,
@@ -978,6 +979,10 @@ function getModel() {
     return undefined;
 }
 
+export function isJsonSchemaSupported() {
+    return settings.type === TABBY && main_api === 'textgenerationwebui';
+}
+
 export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate, isContinue, cfgValues, type) {
     const canMultiSwipe = !isContinue && !isImpersonate && type !== 'quiet';
     let params = {