Expressions: Classify using LLM

Rather than using a separate BERT model to classify the last message,
use the LLM itself to return the classified expression label as JSON
and set that as the current sprite. Doing this should take more information
into consideration and cut down on extra processing.

This is made possible by constrained generation with JSON schemas.
It is currently only available with TabbyAPI, since it's the only
backend that supports JSON schemas, but hopefully there will be a way
to use this with other backends as well.

Intercepts the generation request and sets top_k = 1 (for greedy
sampling) and json_schema to an emotion enum. Doing this also avoids
reingesting the entire context every time a message is sent and then
classified, so the chat experience is not compromised.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-04-12 01:49:09 -04:00
parent 626c93a1ab
commit 6b656bf380
2 changed files with 122 additions and 58 deletions

View File

@ -1,4 +1,4 @@
import { callPopup, eventSource, event_types, getRequestHeaders, saveSettingsDebounced } from '../../../script.js'; import { callPopup, eventSource, event_types, generateQuietPrompt, getRequestHeaders, saveSettingsDebounced } from '../../../script.js';
import { dragElement, isMobile } from '../../RossAscends-mods.js'; import { dragElement, isMobile } from '../../RossAscends-mods.js';
import { getContext, getApiUrl, modules, extension_settings, ModuleWorkerWrapper, doExtrasFetch, renderExtensionTemplateAsync } from '../../extensions.js'; import { getContext, getApiUrl, modules, extension_settings, ModuleWorkerWrapper, doExtrasFetch, renderExtensionTemplateAsync } from '../../extensions.js';
import { loadMovingUIState, power_user } from '../../power-user.js'; import { loadMovingUIState, power_user } from '../../power-user.js';
@ -43,6 +43,11 @@ const DEFAULT_EXPRESSIONS = [
'surprise', 'surprise',
'neutral', 'neutral',
]; ];
/**
 * Identifiers for the classifier backend stored in
 * `extension_settings.expressions.api`. The numeric values match the
 * `value` attributes of the `#expression_api` select options.
 * Frozen so the mapping cannot be changed by accident at runtime.
 * @readonly
 * @enum {number}
 */
const EXPRESSION_API = Object.freeze({
    local: 0,
    extras: 1,
    llm: 2,
});
let expressionsList = null; let expressionsList = null;
let lastCharacter = undefined; let lastCharacter = undefined;
@ -55,7 +60,7 @@ let lastServerResponseTime = 0;
export let lastExpression = {}; export let lastExpression = {};
function isTalkingHeadEnabled() { function isTalkingHeadEnabled() {
return extension_settings.expressions.talkinghead && !extension_settings.expressions.local; return extension_settings.expressions.talkinghead && extension_settings.expressions.api == EXPRESSION_API.extras;
} }
/** /**
@ -585,10 +590,10 @@ function handleImageChange() {
async function moduleWorker() { async function moduleWorker() {
const context = getContext(); const context = getContext();
// Hide and disable Talkinghead while in local mode // Hide and disable Talkinghead while not in extras
$('#image_type_block').toggle(!extension_settings.expressions.local); $('#image_type_block').toggle(extension_settings.expressions.api == EXPRESSION_API.extras);
if (extension_settings.expressions.local && extension_settings.expressions.talkinghead) { if (extension_settings.expressions.api != EXPRESSION_API.extras && extension_settings.expressions.talkinghead) {
$('#image_type_toggle').prop('checked', false); $('#image_type_toggle').prop('checked', false);
setTalkingHeadState(false); setTalkingHeadState(false);
} }
@ -628,7 +633,7 @@ async function moduleWorker() {
} }
const offlineMode = $('.expression_settings .offline_mode'); const offlineMode = $('.expression_settings .offline_mode');
if (!modules.includes('classify') && !extension_settings.expressions.local) { if (!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) {
$('#open_chat_expressions').show(); $('#open_chat_expressions').show();
$('#no_chat_expressions').hide(); $('#no_chat_expressions').hide();
offlineMode.css('display', 'block'); offlineMode.css('display', 'block');
@ -821,7 +826,7 @@ function setTalkingHeadState(newState) {
extension_settings.expressions.talkinghead = newState; // Store setting extension_settings.expressions.talkinghead = newState; // Store setting
saveSettingsDebounced(); saveSettingsDebounced();
if (extension_settings.expressions.local) { if (extension_settings.expressions.api == EXPRESSION_API.local) {
return; return;
} }
@ -900,7 +905,7 @@ async function classifyCommand(_, text) {
return ''; return '';
} }
if (!modules.includes('classify') && !extension_settings.expressions.local) { if (!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) {
toastr.warning('Text classification is disabled or not available'); toastr.warning('Text classification is disabled or not available');
return ''; return '';
} }
@ -971,9 +976,32 @@ function sampleClassifyText(text) {
return result.trim(); return result.trim();
} }
/**
 * Handler for the TEXT_COMPLETION_SETTINGS_READY event. While an expression
 * classification request is in flight and the LLM classifier is selected,
 * it overrides the outgoing generation settings so the backend returns a
 * JSON object of the form `{ "emotion": "<label>" }`.
 *
 * The object is mutated in place; nothing is returned.
 *
 * @param {object} args Text completion settings about to be sent to the backend.
 */
function onTextGenSettingsReady(args) {
    // Leave every other generation alone: only requests issued for expression
    // classification through the LLM may be constrained to the emotion schema.
    if (!inApiCall || extension_settings.expressions.api !== EXPRESSION_API.llm) {
        return;
    }

    // 'talkinghead' is excluded from the labels offered to the model.
    const emotions = DEFAULT_EXPRESSIONS.filter((e) => e !== 'talkinghead');

    Object.assign(args, {
        // top_k = 1 gives greedy sampling for the classification
        top_k: 1,
        json_schema: {
            $schema: 'http://json-schema.org/draft-04/schema#',
            type: 'object',
            properties: {
                emotion: {
                    type: 'string',
                    enum: emotions,
                },
            },
            required: [
                'emotion',
            ],
        },
    });
}
async function getExpressionLabel(text) { async function getExpressionLabel(text) {
// Return if text is undefined, saving a costly fetch request // Return if text is undefined, saving a costly fetch request
if ((!modules.includes('classify') && !extension_settings.expressions.local) || !text) { if ((!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) || !text) {
return getFallbackExpression(); return getFallbackExpression();
} }
@ -984,24 +1012,32 @@ async function getExpressionLabel(text) {
text = sampleClassifyText(text); text = sampleClassifyText(text);
try { try {
if (extension_settings.expressions.local) { switch (extension_settings.expressions.api) {
// Local transformers pipeline case EXPRESSION_API.local:
const apiResult = await fetch('/api/extra/classify', { // Local BERT pipeline
const localResult = await fetch('/api/extra/classify', {
method: 'POST', method: 'POST',
headers: getRequestHeaders(), headers: getRequestHeaders(),
body: JSON.stringify({ text: text }), body: JSON.stringify({ text: text }),
}); });
if (apiResult.ok) { if (localResult.ok) {
const data = await apiResult.json(); const data = await localResult.json();
return data.classification[0].label; return data.classification[0].label;
} }
} else {
break;
case EXPRESSION_API.llm:
// Using LLM
const emotionResponse = await generateQuietPrompt('', false);
const parsedEmotion = JSON.parse(emotionResponse);
return parsedEmotion.emotion;
default:
// Extras // Extras
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/classify'; url.pathname = '/api/classify';
const apiResult = await doExtrasFetch(url, { const extrasResult = await doExtrasFetch(url, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@ -1010,8 +1046,8 @@ async function getExpressionLabel(text) {
body: JSON.stringify({ text: text }), body: JSON.stringify({ text: text }),
}); });
if (apiResult.ok) { if (extrasResult.ok) {
const data = await apiResult.json(); const data = await extrasResult.json();
return data.classification[0].label; return data.classification[0].label;
} }
} }
@ -1177,23 +1213,12 @@ async function getExpressionsList() {
*/ */
async function resolveExpressionsList() { async function resolveExpressionsList() {
// get something for offline mode (default images) // get something for offline mode (default images)
if (!modules.includes('classify') && !extension_settings.expressions.local) { if (!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) {
return DEFAULT_EXPRESSIONS; return DEFAULT_EXPRESSIONS;
} }
try { try {
if (extension_settings.expressions.local) { if (extension_settings.expressions.api == EXPRESSION_API.extras) {
const apiResult = await fetch('/api/extra/classify/labels', {
method: 'POST',
headers: getRequestHeaders(),
});
if (apiResult.ok) {
const data = await apiResult.json();
expressionsList = data.labels;
return expressionsList;
}
} else {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/classify/labels'; url.pathname = '/api/classify/labels';
@ -1204,6 +1229,17 @@ async function getExpressionsList() {
if (apiResult.ok) { if (apiResult.ok) {
const data = await apiResult.json();
expressionsList = data.labels;
return expressionsList;
}
} else {
const apiResult = await fetch('/api/extra/classify/labels', {
method: 'POST',
headers: getRequestHeaders(),
});
if (apiResult.ok) {
const data = await apiResult.json(); const data = await apiResult.json();
expressionsList = data.labels; expressionsList = data.labels;
return expressionsList; return expressionsList;
@ -1444,6 +1480,15 @@ async function onClickExpressionRemoveCustom() {
moduleWorker(); moduleWorker();
} }
/**
 * Change handler for the classifier API select. Reads the selected value
 * from `this`, stores it as a number in the expression settings, then
 * re-runs the module worker and schedules a settings save.
 * An empty value is ignored.
 */
function onExperesionApiChanged() {
    const selectedValue = this.value;

    // Nothing selected, nothing to store
    if (!selectedValue) {
        return;
    }

    extension_settings.expressions.api = Number(selectedValue);
    moduleWorker();
    saveSettingsDebounced();
}
function onExpressionFallbackChanged() { function onExpressionFallbackChanged() {
const expression = this.value; const expression = this.value;
if (expression) { if (expression) {
@ -1556,6 +1601,7 @@ async function onClickExpressionOverrideButton() {
// Refresh sprites list. Assume the override path has been properly handled. // Refresh sprites list. Assume the override path has been properly handled.
try { try {
inApiCall = true;
$('#visual-novel-wrapper').empty(); $('#visual-novel-wrapper').empty();
await validateImages(overridePath.length === 0 ? currentLastMessage.name : overridePath, true); await validateImages(overridePath.length === 0 ? currentLastMessage.name : overridePath, true);
const expression = await getExpressionLabel(currentLastMessage.mes); const expression = await getExpressionLabel(currentLastMessage.mes);
@ -1563,6 +1609,8 @@ async function onClickExpressionOverrideButton() {
forceUpdateVisualNovelMode(); forceUpdateVisualNovelMode();
} catch (error) { } catch (error) {
console.debug(`Setting expression override for ${avatarFileName} failed with error: ${error}`); console.debug(`Setting expression override for ${avatarFileName} failed with error: ${error}`);
} finally {
inApiCall = false;
} }
} }
@ -1699,6 +1747,17 @@ async function fetchImagesNoCache() {
return await Promise.allSettled(promises); return await Promise.allSettled(promises);
} }
/**
 * Migrates the legacy boolean `local` expression setting to the `api`
 * setting that takes an EXPRESSION_API value.
 *
 * - `local: true`  becomes `api = EXPRESSION_API.local`
 * - `local: false` becomes `api = EXPRESSION_API.extras`, unless `api` is
 *   already set, because a disabled `local` flag previously selected Extras
 *
 * The legacy key is removed afterwards and a settings save is scheduled.
 * Does nothing when the legacy key is absent.
 */
function migrateSettings() {
    const settings = extension_settings.expressions;

    if (!Object.keys(settings).includes('local')) {
        return;
    }

    if (settings.local) {
        // Assignment, not comparison: the legacy flag must actually be carried over.
        settings.api = EXPRESSION_API.local;
    } else if (settings.api === undefined) {
        settings.api = EXPRESSION_API.extras;
    }

    delete settings.local;
    saveSettingsDebounced();
}
(async function () { (async function () {
function addExpressionImage() { function addExpressionImage() {
const html = ` const html = `
@ -1730,11 +1789,6 @@ async function fetchImagesNoCache() {
extension_settings.expressions.translate = !!$(this).prop('checked'); extension_settings.expressions.translate = !!$(this).prop('checked');
saveSettingsDebounced(); saveSettingsDebounced();
}); });
$('#expression_local').prop('checked', extension_settings.expressions.local).on('input', function () {
extension_settings.expressions.local = !!$(this).prop('checked');
moduleWorker();
saveSettingsDebounced();
});
$('#expression_override_cleanup_button').on('click', onClickExpressionOverrideRemoveAllButton); $('#expression_override_cleanup_button').on('click', onClickExpressionOverrideRemoveAllButton);
$(document).on('dragstart', '.expression', (e) => { $(document).on('dragstart', '.expression', (e) => {
e.preventDefault(); e.preventDefault();
@ -1753,10 +1807,12 @@ async function fetchImagesNoCache() {
}); });
await renderAdditionalExpressionSettings(); await renderAdditionalExpressionSettings();
$('#expression_api').val(extension_settings.expressions.api || 0);
$('#expression_custom_add').on('click', onClickExpressionAddCustom); $('#expression_custom_add').on('click', onClickExpressionAddCustom);
$('#expression_custom_remove').on('click', onClickExpressionRemoveCustom); $('#expression_custom_remove').on('click', onClickExpressionRemoveCustom);
$('#expression_fallback').on('change', onExpressionFallbackChanged); $('#expression_fallback').on('change', onExpressionFallbackChanged);
$('#expression_api').on('change', onExperesionApiChanged);
} }
// Pause Talkinghead to save resources when the ST tab is not visible or the window is minimized. // Pause Talkinghead to save resources when the ST tab is not visible or the window is minimized.
@ -1789,6 +1845,7 @@ async function fetchImagesNoCache() {
addExpressionImage(); addExpressionImage();
addVisualNovelMode(); addVisualNovelMode();
migrateSettings();
await addSettings(); await addSettings();
const wrapper = new ModuleWorkerWrapper(moduleWorker); const wrapper = new ModuleWorkerWrapper(moduleWorker);
const updateFunction = wrapper.update.bind(wrapper); const updateFunction = wrapper.update.bind(wrapper);
@ -1828,6 +1885,7 @@ async function fetchImagesNoCache() {
}); });
eventSource.on(event_types.MOVABLE_PANELS_RESET, updateVisualNovelModeDebounced); eventSource.on(event_types.MOVABLE_PANELS_RESET, updateVisualNovelModeDebounced);
eventSource.on(event_types.GROUP_UPDATED, updateVisualNovelModeDebounced); eventSource.on(event_types.GROUP_UPDATED, updateVisualNovelModeDebounced);
eventSource.on(event_types.TEXT_COMPLETION_SETTINGS_READY, onTextGenSettingsReady);
registerSlashCommand('sprite', setSpriteSlashCommand, ['emote'], '<span class="monospace">(spriteId)</span> force sets the sprite for the current character', true, true); registerSlashCommand('sprite', setSpriteSlashCommand, ['emote'], '<span class="monospace">(spriteId)</span> force sets the sprite for the current character', true, true);
registerSlashCommand('spriteoverride', setSpriteSetCommand, ['costume'], '<span class="monospace">(optional folder)</span> sets an override sprite folder for the current character. If the name starts with a slash or a backslash, selects a sub-folder in the character-named folder. Empty value to reset to default.', true, true); registerSlashCommand('spriteoverride', setSpriteSetCommand, ['costume'], '<span class="monospace">(optional folder)</span> sets an override sprite folder for the current character. If the name starts with a slash or a backslash, selects a sub-folder in the character-named folder. Empty value to reset to default.', true, true);
registerSlashCommand('lastsprite', (_, value) => lastExpression[value.trim()] ?? '', [], '<span class="monospace">(charName)</span> Returns the last set sprite / expression for the named character.', true, true); registerSlashCommand('lastsprite', (_, value) => lastExpression[value.trim()] ?? '', [], '<span class="monospace">(charName)</span> Returns the last set sprite / expression for the named character.', true, true);

View File

@ -6,10 +6,6 @@
</div> </div>
<div class="inline-drawer-content"> <div class="inline-drawer-content">
<label class="checkbox_label" for="expression_local" title="Use classification model without the Extras server.">
<input id="expression_local" type="checkbox" />
<span data-i18n="Local server classification">Local server classification</span>
</label>
<label class="checkbox_label" for="expression_translate" title="Use the selected API from Chat Translation extension settings."> <label class="checkbox_label" for="expression_translate" title="Use the selected API from Chat Translation extension settings.">
<input id="expression_translate" type="checkbox"> <input id="expression_translate" type="checkbox">
<span>Translate text to English before classification</span> <span>Translate text to English before classification</span>
@ -22,6 +18,16 @@
<input id="image_type_toggle" type="checkbox"> <input id="image_type_toggle" type="checkbox">
<span>Image Type - talkinghead (extras)</span> <span>Image Type - talkinghead (extras)</span>
</label> </label>
<div class="expression_api_block m-b-1 m-t-1">
<label for="expression_api">Classifier API</label>
<small>Select the API for classifying expressions.</small>
<select id="expression_api" class="flex1 margin0" data-i18n="Expression API" placeholder="Expression API">
<option value="0">Local</option>
<option value="1">Extras</option>
<option value="2">LLM</option>
<option value="3">TalkingHead</option>
</select>
</div>
<div class="expression_fallback_block m-b-1 m-t-1"> <div class="expression_fallback_block m-b-1 m-t-1">
<label for="expression_fallback">Default / Fallback Expression</label> <label for="expression_fallback">Default / Fallback Expression</label>
<small>Set the default and fallback expression being used when no matching expression is found.</small> <small>Set the default and fallback expression being used when no matching expression is found.</small>