diff --git a/.eslintrc.js b/.eslintrc.js index 6e926dee9..981379f24 100644 --- a/.eslintrc.js +++ b/.eslintrc.js @@ -60,6 +60,8 @@ module.exports = { 'no-trailing-spaces': 'error', 'object-curly-spacing': ['error', 'always'], 'space-infix-ops': 'error', + 'no-unused-expressions': ['error', { allowShortCircuit: true, allowTernary: true }], + 'no-cond-assign': 'error', // These rules should eventually be enabled. 'no-async-promise-executor': 'off', diff --git a/public/script.js b/public/script.js index c95133f34..7162d5f32 100644 --- a/public/script.js +++ b/public/script.js @@ -3071,14 +3071,14 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu if (interruptedByCommand) { //$("#send_textarea").val('').trigger('input'); - unblockGeneration(); + unblockGeneration(type); return Promise.resolve(); } } if (main_api == 'kobold' && kai_settings.streaming_kobold && !kai_flags.can_use_streaming) { toastr.error('Streaming is enabled, but the version of Kobold used does not support token streaming.', undefined, { timeOut: 10000, preventDuplicates: true }); - unblockGeneration(); + unblockGeneration(type); return Promise.resolve(); } @@ -3087,12 +3087,12 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu textgen_settings.legacy_api && (textgen_settings.type === OOBA || textgen_settings.type === APHRODITE)) { toastr.error('Streaming is not supported for the Legacy API. Update Ooba and use new API to enable streaming.', undefined, { timeOut: 10000, preventDuplicates: true }); - unblockGeneration(); + unblockGeneration(type); return Promise.resolve(); } if (isHordeGenerationNotAllowed()) { - unblockGeneration(); + unblockGeneration(type); return Promise.resolve(); } @@ -3128,7 +3128,7 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu setCharacterName(''); } else { console.log('No enabled members found'); - unblockGeneration(); + unblockGeneration(type); return Promise.resolve(); } } @@ -3302,7 +3302,7 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu if (aborted) { console.debug('Generation aborted by extension interceptors'); - unblockGeneration(); + unblockGeneration(type); return Promise.resolve(); } } else { @@ -3316,7 +3316,7 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu adjustedParams = await adjustHordeGenerationParams(max_context, amount_gen); } catch { - unblockGeneration(); + unblockGeneration(type); return Promise.resolve(); } if (horde_settings.auto_adjust_context_length) { @@ -4103,7 +4103,7 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu await eventSource.emit(event_types.IMPERSONATE_READY, getMessage); } else if (type == 'quiet') { - unblockGeneration(); + unblockGeneration(type); return getMessage; } else { @@ -4171,7 +4171,7 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu console.debug('/api/chats/save called by /Generate'); await saveChatConditional(); - unblockGeneration(); + unblockGeneration(type); streamingProcessor = null; if (type !== 'quiet') { @@ -4189,7 +4189,7 @@ async function Generate(type, { automatic_trigger, force_name2, quiet_prompt, qu generatedPromptCache = ''; - unblockGeneration(); + unblockGeneration(type); console.log(exception); streamingProcessor = null; throw exception; @@ -4259,7 +4259,16 @@ function flushWIDepthInjections() { } } -function unblockGeneration() { +/** + * Unblocks the UI after a generation is complete. + * @param {string} [type] Generation type (optional) + */ +function unblockGeneration(type) { + // Don't unblock if a parallel stream is still running + if (type === 'quiet' && streamingProcessor && !streamingProcessor.isFinished) { + return; + } + is_send_press = false; activateSendButtons(); showSwipeButtons(); diff --git a/public/scripts/extensions/expressions/index.js b/public/scripts/extensions/expressions/index.js index 15ad253be..3ab2d0eb0 100644 --- a/public/scripts/extensions/expressions/index.js +++ b/public/scripts/extensions/expressions/index.js @@ -1,17 +1,19 @@ -import { callPopup, eventSource, event_types, getRequestHeaders, saveSettingsDebounced } from '../../../script.js'; +import { callPopup, eventSource, event_types, generateQuietPrompt, getRequestHeaders, saveSettingsDebounced, substituteParams } from '../../../script.js'; import { dragElement, isMobile } from '../../RossAscends-mods.js'; import { getContext, getApiUrl, modules, extension_settings, ModuleWorkerWrapper, doExtrasFetch, renderExtensionTemplateAsync } from '../../extensions.js'; import { loadMovingUIState, power_user } from '../../power-user.js'; import { registerSlashCommand } from '../../slash-commands.js'; import { onlyUnique, debounce, getCharaFilename, trimToEndSentence, trimToStartSentence } from '../../utils.js'; import { hideMutedSprites } from '../../group-chats.js'; +import { isJsonSchemaSupported } from '../../textgen-settings.js'; export { MODULE_NAME }; const MODULE_NAME = 'expressions'; const UPDATE_INTERVAL = 2000; -const STREAMING_UPDATE_INTERVAL = 6000; +const STREAMING_UPDATE_INTERVAL = 10000; const TALKINGCHECK_UPDATE_INTERVAL = 500; const DEFAULT_FALLBACK_EXPRESSION = 'joy'; +const DEFAULT_LLM_PROMPT = 'Pause your roleplay. Classify the emotion of the last message. Output just one word, e.g. "joy" or "anger". Choose only one of the following labels: {{labels}}'; const DEFAULT_EXPRESSIONS = [ 'talkinghead', 'admiration', @@ -43,6 +45,11 @@ const DEFAULT_EXPRESSIONS = [ 'surprise', 'neutral', ]; +const EXPRESSION_API = { + local: 0, + extras: 1, + llm: 2, +}; let expressionsList = null; let lastCharacter = undefined; @@ -55,7 +62,7 @@ let lastServerResponseTime = 0; export let lastExpression = {}; function isTalkingHeadEnabled() { - return extension_settings.expressions.talkinghead && !extension_settings.expressions.local; + return extension_settings.expressions.talkinghead && extension_settings.expressions.api == EXPRESSION_API.extras; } /** @@ -585,10 +592,10 @@ function handleImageChange() { async function moduleWorker() { const context = getContext(); - // Hide and disable Talkinghead while in local mode - $('#image_type_block').toggle(!extension_settings.expressions.local); + // Hide and disable Talkinghead while not in extras + $('#image_type_block').toggle(extension_settings.expressions.api == EXPRESSION_API.extras); - if (extension_settings.expressions.local && extension_settings.expressions.talkinghead) { + if (extension_settings.expressions.api != EXPRESSION_API.extras && extension_settings.expressions.talkinghead) { $('#image_type_toggle').prop('checked', false); setTalkingHeadState(false); } @@ -628,7 +635,7 @@ async function moduleWorker() { } const offlineMode = $('.expression_settings .offline_mode'); - if (!modules.includes('classify') && !extension_settings.expressions.local) { + if (!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) { $('#open_chat_expressions').show(); $('#no_chat_expressions').hide(); offlineMode.css('display', 'block'); @@ -821,7 +828,7 @@ function setTalkingHeadState(newState) { extension_settings.expressions.talkinghead = newState; // Store setting saveSettingsDebounced(); - if (extension_settings.expressions.local) { + if (extension_settings.expressions.api == EXPRESSION_API.local || extension_settings.expressions.api == EXPRESSION_API.llm) { return; } @@ -900,7 +907,7 @@ async function classifyCommand(_, text) { return ''; } - if (!modules.includes('classify') && !extension_settings.expressions.local) { + if (!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) { toastr.warning('Text classification is disabled or not available'); return ''; } @@ -971,9 +978,76 @@ function sampleClassifyText(text) { return result.trim(); } +/** + * Gets the classification prompt for the LLM API. + * @param {string[]} labels A list of labels to search for. + * @returns {Promise} Prompt for the LLM API. + */ +async function getLlmPrompt(labels) { + if (isJsonSchemaSupported()) { + return ''; + } + + const labelsString = labels.map(x => `"${x}"`).join(', '); + const prompt = substituteParams(String(extension_settings.expressions.llmPrompt)) + .replace(/{{labels}}/gi, labelsString); + return prompt; +} + +/** + * Parses the emotion response from the LLM API. + * @param {string} emotionResponse The response from the LLM API. + * @param {string[]} labels A list of labels to search for. + * @returns {string} The parsed emotion or the fallback expression. + */ +function parseLlmResponse(emotionResponse, labels) { + const fallbackExpression = getFallbackExpression(); + + try { + const parsedEmotion = JSON.parse(emotionResponse); + return parsedEmotion?.emotion ?? fallbackExpression; + } catch { + const fuse = new Fuse([emotionResponse]); + for (const label of labels) { + const result = fuse.search(label); + if (result.length > 0) { + return label; + } + } + } + + throw new Error('Could not parse emotion response ' + emotionResponse); +} + +function onTextGenSettingsReady(args) { + // Only call if inside an API call + if (inApiCall && extension_settings.expressions.api === EXPRESSION_API.llm && isJsonSchemaSupported()) { + const emotions = DEFAULT_EXPRESSIONS.filter((e) => e != 'talkinghead'); + Object.assign(args, { + top_k: 1, + stop: [], + stopping_strings: [], + custom_token_bans: [], + json_schema: { + $schema: 'http://json-schema.org/draft-04/schema#', + type: 'object', + properties: { + emotion: { + type: 'string', + enum: emotions, + }, + }, + required: [ + 'emotion', + ], + }, + }); + } +} + async function getExpressionLabel(text) { // Return if text is undefined, saving a costly fetch request - if ((!modules.includes('classify') && !extension_settings.expressions.local) || !text) { + if ((!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) || !text) { return getFallbackExpression(); } @@ -984,39 +1058,50 @@ async function getExpressionLabel(text) { text = sampleClassifyText(text); try { - if (extension_settings.expressions.local) { - // Local transformers pipeline - const apiResult = await fetch('/api/extra/classify', { - method: 'POST', - headers: getRequestHeaders(), - body: JSON.stringify({ text: text }), - }); + switch (extension_settings.expressions.api) { + // Local BERT pipeline + case EXPRESSION_API.local: { + const localResult = await fetch('/api/extra/classify', { + method: 'POST', + headers: getRequestHeaders(), + body: JSON.stringify({ text: text }), + }); - if (apiResult.ok) { - const data = await apiResult.json(); - return data.classification[0].label; + if (localResult.ok) { + const data = await localResult.json(); + return data.classification[0].label; + } + } break; + // Using LLM + case EXPRESSION_API.llm: { + const expressionsList = await getExpressionsList(); + const prompt = await getLlmPrompt(expressionsList); + const emotionResponse = await generateQuietPrompt(prompt, false, false); + return parseLlmResponse(emotionResponse, expressionsList); } - } else { // Extras - const url = new URL(getApiUrl()); - url.pathname = '/api/classify'; + default: { + const url = new URL(getApiUrl()); + url.pathname = '/api/classify'; - const apiResult = await doExtrasFetch(url, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - 'Bypass-Tunnel-Reminder': 'bypass', - }, - body: JSON.stringify({ text: text }), - }); + const extrasResult = await doExtrasFetch(url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Bypass-Tunnel-Reminder': 'bypass', + }, + body: JSON.stringify({ text: text }), + }); - if (apiResult.ok) { - const data = await apiResult.json(); - return data.classification[0].label; - } + if (extrasResult.ok) { + const data = await extrasResult.json(); + return data.classification[0].label; + } + } break; } } catch (error) { - console.log(error); + toastr.info('Could not classify expression. Check the console or your backend for more information.'); + console.error(error); return getFallbackExpression(); } } @@ -1177,23 +1262,12 @@ async function getExpressionsList() { */ async function resolveExpressionsList() { // get something for offline mode (default images) - if (!modules.includes('classify') && !extension_settings.expressions.local) { + if (!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) { return DEFAULT_EXPRESSIONS; } try { - if (extension_settings.expressions.local) { - const apiResult = await fetch('/api/extra/classify/labels', { - method: 'POST', - headers: getRequestHeaders(), - }); - - if (apiResult.ok) { - const data = await apiResult.json(); - expressionsList = data.labels; - return expressionsList; - } - } else { + if (extension_settings.expressions.api == EXPRESSION_API.extras) { const url = new URL(getApiUrl()); url.pathname = '/api/classify/labels'; @@ -1204,6 +1278,17 @@ async function getExpressionsList() { if (apiResult.ok) { + const data = await apiResult.json(); + expressionsList = data.labels; + return expressionsList; + } + } else { + const apiResult = await fetch('/api/extra/classify/labels', { + method: 'POST', + headers: getRequestHeaders(), + }); + + if (apiResult.ok) { const data = await apiResult.json(); expressionsList = data.labels; return expressionsList; @@ -1444,6 +1529,16 @@ async function onClickExpressionRemoveCustom() { moduleWorker(); } +function onExperesionApiChanged() { + const tempApi = this.value; + if (tempApi) { + extension_settings.expressions.api = Number(tempApi); + $('.expression_llm_prompt_block').toggle(extension_settings.expressions.api === EXPRESSION_API.llm); + moduleWorker(); + saveSettingsDebounced(); + } +} + function onExpressionFallbackChanged() { const expression = this.value; if (expression) { @@ -1556,6 +1651,7 @@ async function onClickExpressionOverrideButton() { // Refresh sprites list. Assume the override path has been properly handled. try { + inApiCall = true; $('#visual-novel-wrapper').empty(); await validateImages(overridePath.length === 0 ? currentLastMessage.name : overridePath, true); const expression = await getExpressionLabel(currentLastMessage.mes); @@ -1563,6 +1659,8 @@ async function onClickExpressionOverrideButton() { forceUpdateVisualNovelMode(); } catch (error) { console.debug(`Setting expression override for ${avatarFileName} failed with error: ${error}`); + } finally { + inApiCall = false; } } @@ -1699,6 +1797,22 @@ async function fetchImagesNoCache() { return await Promise.allSettled(promises); } +function migrateSettings() { + if (Object.keys(extension_settings.expressions).includes('local')) { + if (extension_settings.expressions.local) { + extension_settings.expressions.api = EXPRESSION_API.local; + } + + delete extension_settings.expressions.local; + saveSettingsDebounced(); + } + + if (extension_settings.expressions.llmPrompt === undefined) { + extension_settings.expressions.llmPrompt = DEFAULT_LLM_PROMPT; + saveSettingsDebounced(); + } +} + (async function () { function addExpressionImage() { const html = ` @@ -1730,11 +1844,6 @@ async function fetchImagesNoCache() { extension_settings.expressions.translate = !!$(this).prop('checked'); saveSettingsDebounced(); }); - $('#expression_local').prop('checked', extension_settings.expressions.local).on('input', function () { - extension_settings.expressions.local = !!$(this).prop('checked'); - moduleWorker(); - saveSettingsDebounced(); - }); $('#expression_override_cleanup_button').on('click', onClickExpressionOverrideRemoveAllButton); $(document).on('dragstart', '.expression', (e) => { e.preventDefault(); @@ -1753,10 +1862,23 @@ async function fetchImagesNoCache() { }); await renderAdditionalExpressionSettings(); + $('#expression_api').val(extension_settings.expressions.api ?? EXPRESSION_API.extras); + $('.expression_llm_prompt_block').toggle(extension_settings.expressions.api === EXPRESSION_API.llm); + $('#expression_llm_prompt').val(extension_settings.expressions.llmPrompt ?? ''); + $('#expression_llm_prompt').on('input', function () { + extension_settings.expressions.llmPrompt = $(this).val(); + saveSettingsDebounced(); + }); + $('#expression_llm_prompt_restore').on('click', function () { + $('#expression_llm_prompt').val(DEFAULT_LLM_PROMPT); + extension_settings.expressions.llmPrompt = DEFAULT_LLM_PROMPT; + saveSettingsDebounced(); + }); $('#expression_custom_add').on('click', onClickExpressionAddCustom); $('#expression_custom_remove').on('click', onClickExpressionRemoveCustom); $('#expression_fallback').on('change', onExpressionFallbackChanged); + $('#expression_api').on('change', onExperesionApiChanged); } // Pause Talkinghead to save resources when the ST tab is not visible or the window is minimized. @@ -1789,6 +1911,7 @@ async function fetchImagesNoCache() { addExpressionImage(); addVisualNovelMode(); + migrateSettings(); await addSettings(); const wrapper = new ModuleWorkerWrapper(moduleWorker); const updateFunction = wrapper.update.bind(wrapper); @@ -1828,6 +1951,7 @@ async function fetchImagesNoCache() { }); eventSource.on(event_types.MOVABLE_PANELS_RESET, updateVisualNovelModeDebounced); eventSource.on(event_types.GROUP_UPDATED, updateVisualNovelModeDebounced); + eventSource.on(event_types.TEXT_COMPLETION_SETTINGS_READY, onTextGenSettingsReady); registerSlashCommand('sprite', setSpriteSlashCommand, ['emote'], '(spriteId) – force sets the sprite for the current character', true, true); registerSlashCommand('spriteoverride', setSpriteSetCommand, ['costume'], '(optional folder) – sets an override sprite folder for the current character. If the name starts with a slash or a backslash, selects a sub-folder in the character-named folder. Empty value to reset to default.', true, true); registerSlashCommand('lastsprite', (_, value) => lastExpression[value.trim()] ?? '', [], '(charName) – Returns the last set sprite / expression for the named character.', true, true); diff --git a/public/scripts/extensions/expressions/settings.html b/public/scripts/extensions/expressions/settings.html index b0b3b0bd3..4a7347a74 100644 --- a/public/scripts/extensions/expressions/settings.html +++ b/public/scripts/extensions/expressions/settings.html @@ -6,10 +6,6 @@
- +
+ + Select the API for classifying expressions. + +
+
+ + Will be used if the API doesn't support JSON schemas. + +
Set the default and fallback expression being used when no matching expression is found. diff --git a/public/scripts/textgen-settings.js b/public/scripts/textgen-settings.js index 4f8156a38..aa85d1f99 100644 --- a/public/scripts/textgen-settings.js +++ b/public/scripts/textgen-settings.js @@ -3,6 +3,7 @@ import { event_types, getRequestHeaders, getStoppingStrings, + main_api, max_context, saveSettingsDebounced, setGenerationParamsFromPreset, @@ -978,6 +979,10 @@ function getModel() { return undefined; } +export function isJsonSchemaSupported() { + return settings.type === TABBY && main_api === 'textgenerationwebui'; +} + export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate, isContinue, cfgValues, type) { const canMultiSwipe = !isContinue && !isImpersonate && type !== 'quiet'; let params = {