Expressions: Classify using LLM
Rather than using a separate BERT model to classify the last message, use the LLM itself to return the classified expression label as JSON and set that as the current sprite. Doing this should take more information into consideration and cut down on extra processing. This is made possible by constrained generation with JSON schemas. It is currently only available with TabbyAPI, since that is the only backend that supports JSON schemas, but hopefully there will be a way to use this with other backends as well. Intercepts the generation and sets top_k = 1 (for greedy sampling) and json_schema to an emotion enum. Doing this also avoids reingesting the entire context every time a message is sent and then classified, so the chat experience is not compromised. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
626c93a1ab
commit
6b656bf380
|
@ -1,4 +1,4 @@
|
||||||
import { callPopup, eventSource, event_types, getRequestHeaders, saveSettingsDebounced } from '../../../script.js';
|
import { callPopup, eventSource, event_types, generateQuietPrompt, getRequestHeaders, saveSettingsDebounced } from '../../../script.js';
|
||||||
import { dragElement, isMobile } from '../../RossAscends-mods.js';
|
import { dragElement, isMobile } from '../../RossAscends-mods.js';
|
||||||
import { getContext, getApiUrl, modules, extension_settings, ModuleWorkerWrapper, doExtrasFetch, renderExtensionTemplateAsync } from '../../extensions.js';
|
import { getContext, getApiUrl, modules, extension_settings, ModuleWorkerWrapper, doExtrasFetch, renderExtensionTemplateAsync } from '../../extensions.js';
|
||||||
import { loadMovingUIState, power_user } from '../../power-user.js';
|
import { loadMovingUIState, power_user } from '../../power-user.js';
|
||||||
|
@ -43,6 +43,11 @@ const DEFAULT_EXPRESSIONS = [
|
||||||
'surprise',
|
'surprise',
|
||||||
'neutral',
|
'neutral',
|
||||||
];
|
];
|
||||||
|
/**
 * Identifiers of the available expression classification backends.
 *
 * The numeric value is what gets stored in `extension_settings.expressions.api`
 * and compared against throughout this file. The object is frozen so the
 * mapping cannot be altered at runtime.
 * @readonly
 * @enum {number}
 */
const EXPRESSION_API = Object.freeze({
    local: 0,
    extras: 1,
    llm: 2,
});
|
||||||
|
|
||||||
let expressionsList = null;
|
let expressionsList = null;
|
||||||
let lastCharacter = undefined;
|
let lastCharacter = undefined;
|
||||||
|
@ -55,7 +60,7 @@ let lastServerResponseTime = 0;
|
||||||
export let lastExpression = {};
|
export let lastExpression = {};
|
||||||
|
|
||||||
function isTalkingHeadEnabled() {
|
function isTalkingHeadEnabled() {
|
||||||
return extension_settings.expressions.talkinghead && !extension_settings.expressions.local;
|
return extension_settings.expressions.talkinghead && extension_settings.expressions.api == EXPRESSION_API.extras;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -585,10 +590,10 @@ function handleImageChange() {
|
||||||
async function moduleWorker() {
|
async function moduleWorker() {
|
||||||
const context = getContext();
|
const context = getContext();
|
||||||
|
|
||||||
// Hide and disable Talkinghead while in local mode
|
// Hide and disable Talkinghead while not in extras
|
||||||
$('#image_type_block').toggle(!extension_settings.expressions.local);
|
$('#image_type_block').toggle(extension_settings.expressions.api == EXPRESSION_API.extras);
|
||||||
|
|
||||||
if (extension_settings.expressions.local && extension_settings.expressions.talkinghead) {
|
if (extension_settings.expressions.api != EXPRESSION_API.extras && extension_settings.expressions.talkinghead) {
|
||||||
$('#image_type_toggle').prop('checked', false);
|
$('#image_type_toggle').prop('checked', false);
|
||||||
setTalkingHeadState(false);
|
setTalkingHeadState(false);
|
||||||
}
|
}
|
||||||
|
@ -628,7 +633,7 @@ async function moduleWorker() {
|
||||||
}
|
}
|
||||||
|
|
||||||
const offlineMode = $('.expression_settings .offline_mode');
|
const offlineMode = $('.expression_settings .offline_mode');
|
||||||
if (!modules.includes('classify') && !extension_settings.expressions.local) {
|
if (!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) {
|
||||||
$('#open_chat_expressions').show();
|
$('#open_chat_expressions').show();
|
||||||
$('#no_chat_expressions').hide();
|
$('#no_chat_expressions').hide();
|
||||||
offlineMode.css('display', 'block');
|
offlineMode.css('display', 'block');
|
||||||
|
@ -821,7 +826,7 @@ function setTalkingHeadState(newState) {
|
||||||
extension_settings.expressions.talkinghead = newState; // Store setting
|
extension_settings.expressions.talkinghead = newState; // Store setting
|
||||||
saveSettingsDebounced();
|
saveSettingsDebounced();
|
||||||
|
|
||||||
if (extension_settings.expressions.local) {
|
if (extension_settings.expressions.api == EXPRESSION_API.local) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -900,7 +905,7 @@ async function classifyCommand(_, text) {
|
||||||
return '';
|
return '';
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!modules.includes('classify') && !extension_settings.expressions.local) {
|
if (!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) {
|
||||||
toastr.warning('Text classification is disabled or not available');
|
toastr.warning('Text classification is disabled or not available');
|
||||||
return '';
|
return '';
|
||||||
}
|
}
|
||||||
|
@ -971,9 +976,32 @@ function sampleClassifyText(text) {
|
||||||
return result.trim();
|
return result.trim();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Handler for the TEXT_COMPLETION_SETTINGS_READY event.
 *
 * While an expression classification call is in flight and the LLM classifier
 * is the selected API, forces greedy sampling (`top_k = 1`) and attaches a JSON
 * schema that constrains the output to an object with a single `emotion`
 * string taken from the default expression labels.
 *
 * @param {object} args Text completion request settings; mutated in place.
 * @returns {void}
 */
function onTextGenSettingsReady(args) {
    // Only intercept while inside an expression API call, and only when the LLM
    // is the classifier. Otherwise an unrelated generation that happens to run
    // while the flag is set would have its sampler settings overwritten.
    if (!inApiCall || extension_settings.expressions.api !== EXPRESSION_API.llm) {
        return;
    }

    // 'talkinghead' is left out of the set of labels offered to the model.
    const emotions = DEFAULT_EXPRESSIONS.filter((e) => e !== 'talkinghead');

    Object.assign(args, {
        top_k: 1,
        json_schema: {
            $schema: 'http://json-schema.org/draft-04/schema#',
            type: 'object',
            properties: {
                emotion: {
                    type: 'string',
                    enum: emotions,
                },
            },
            required: [
                'emotion',
            ],
        },
    });
}
|
||||||
|
|
||||||
async function getExpressionLabel(text) {
|
async function getExpressionLabel(text) {
|
||||||
// Return if text is undefined, saving a costly fetch request
|
// Return if text is undefined, saving a costly fetch request
|
||||||
if ((!modules.includes('classify') && !extension_settings.expressions.local) || !text) {
|
if ((!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) || !text) {
|
||||||
return getFallbackExpression();
|
return getFallbackExpression();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -984,36 +1012,44 @@ async function getExpressionLabel(text) {
|
||||||
text = sampleClassifyText(text);
|
text = sampleClassifyText(text);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
if (extension_settings.expressions.local) {
|
switch (extension_settings.expressions.api) {
|
||||||
// Local transformers pipeline
|
case EXPRESSION_API.local:
|
||||||
const apiResult = await fetch('/api/extra/classify', {
|
// Local BERT pipeline
|
||||||
method: 'POST',
|
const localResult = await fetch('/api/extra/classify', {
|
||||||
headers: getRequestHeaders(),
|
method: 'POST',
|
||||||
body: JSON.stringify({ text: text }),
|
headers: getRequestHeaders(),
|
||||||
});
|
body: JSON.stringify({ text: text }),
|
||||||
|
});
|
||||||
|
|
||||||
if (apiResult.ok) {
|
if (localResult.ok) {
|
||||||
const data = await apiResult.json();
|
const data = await localResult.json();
|
||||||
return data.classification[0].label;
|
return data.classification[0].label;
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// Extras
|
|
||||||
const url = new URL(getApiUrl());
|
|
||||||
url.pathname = '/api/classify';
|
|
||||||
|
|
||||||
const apiResult = await doExtrasFetch(url, {
|
break;
|
||||||
method: 'POST',
|
case EXPRESSION_API.llm:
|
||||||
headers: {
|
// Using LLM
|
||||||
'Content-Type': 'application/json',
|
const emotionResponse = await generateQuietPrompt('', false);
|
||||||
'Bypass-Tunnel-Reminder': 'bypass',
|
const parsedEmotion = JSON.parse(emotionResponse);
|
||||||
},
|
return parsedEmotion.emotion;
|
||||||
body: JSON.stringify({ text: text }),
|
default:
|
||||||
});
|
// Extras
|
||||||
|
const url = new URL(getApiUrl());
|
||||||
|
url.pathname = '/api/classify';
|
||||||
|
|
||||||
if (apiResult.ok) {
|
const extrasResult = await doExtrasFetch(url, {
|
||||||
const data = await apiResult.json();
|
method: 'POST',
|
||||||
return data.classification[0].label;
|
headers: {
|
||||||
}
|
'Content-Type': 'application/json',
|
||||||
|
'Bypass-Tunnel-Reminder': 'bypass',
|
||||||
|
},
|
||||||
|
body: JSON.stringify({ text: text }),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (extrasResult.ok) {
|
||||||
|
const data = await extrasResult.json();
|
||||||
|
return data.classification[0].label;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.log(error);
|
console.log(error);
|
||||||
|
@ -1177,23 +1213,12 @@ async function getExpressionsList() {
|
||||||
*/
|
*/
|
||||||
async function resolveExpressionsList() {
|
async function resolveExpressionsList() {
|
||||||
// get something for offline mode (default images)
|
// get something for offline mode (default images)
|
||||||
if (!modules.includes('classify') && !extension_settings.expressions.local) {
|
if (!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) {
|
||||||
return DEFAULT_EXPRESSIONS;
|
return DEFAULT_EXPRESSIONS;
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
if (extension_settings.expressions.local) {
|
if (extension_settings.expressions.api == EXPRESSION_API.extras) {
|
||||||
const apiResult = await fetch('/api/extra/classify/labels', {
|
|
||||||
method: 'POST',
|
|
||||||
headers: getRequestHeaders(),
|
|
||||||
});
|
|
||||||
|
|
||||||
if (apiResult.ok) {
|
|
||||||
const data = await apiResult.json();
|
|
||||||
expressionsList = data.labels;
|
|
||||||
return expressionsList;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
const url = new URL(getApiUrl());
|
const url = new URL(getApiUrl());
|
||||||
url.pathname = '/api/classify/labels';
|
url.pathname = '/api/classify/labels';
|
||||||
|
|
||||||
|
@ -1204,6 +1229,17 @@ async function getExpressionsList() {
|
||||||
|
|
||||||
if (apiResult.ok) {
|
if (apiResult.ok) {
|
||||||
|
|
||||||
|
const data = await apiResult.json();
|
||||||
|
expressionsList = data.labels;
|
||||||
|
return expressionsList;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const apiResult = await fetch('/api/extra/classify/labels', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: getRequestHeaders(),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (apiResult.ok) {
|
||||||
const data = await apiResult.json();
|
const data = await apiResult.json();
|
||||||
expressionsList = data.labels;
|
expressionsList = data.labels;
|
||||||
return expressionsList;
|
return expressionsList;
|
||||||
|
@ -1444,6 +1480,15 @@ async function onClickExpressionRemoveCustom() {
|
||||||
moduleWorker();
|
moduleWorker();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Change handler for the classifier API select.
 *
 * Reads the value of the element the handler is bound to (`this`), stores it
 * as a number in `extension_settings.expressions.api`, runs the module worker
 * and schedules a settings save. An empty value is ignored.
 */
function onExperesionApiChanged() {
    const selectedValue = this.value;

    // Empty value: leave the stored setting untouched.
    if (!selectedValue) {
        return;
    }

    extension_settings.expressions.api = Number(selectedValue);
    moduleWorker();
    saveSettingsDebounced();
}
|
||||||
|
|
||||||
function onExpressionFallbackChanged() {
|
function onExpressionFallbackChanged() {
|
||||||
const expression = this.value;
|
const expression = this.value;
|
||||||
if (expression) {
|
if (expression) {
|
||||||
|
@ -1556,6 +1601,7 @@ async function onClickExpressionOverrideButton() {
|
||||||
|
|
||||||
// Refresh sprites list. Assume the override path has been properly handled.
|
// Refresh sprites list. Assume the override path has been properly handled.
|
||||||
try {
|
try {
|
||||||
|
inApiCall = true;
|
||||||
$('#visual-novel-wrapper').empty();
|
$('#visual-novel-wrapper').empty();
|
||||||
await validateImages(overridePath.length === 0 ? currentLastMessage.name : overridePath, true);
|
await validateImages(overridePath.length === 0 ? currentLastMessage.name : overridePath, true);
|
||||||
const expression = await getExpressionLabel(currentLastMessage.mes);
|
const expression = await getExpressionLabel(currentLastMessage.mes);
|
||||||
|
@ -1563,6 +1609,8 @@ async function onClickExpressionOverrideButton() {
|
||||||
forceUpdateVisualNovelMode();
|
forceUpdateVisualNovelMode();
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.debug(`Setting expression override for ${avatarFileName} failed with error: ${error}`);
|
console.debug(`Setting expression override for ${avatarFileName} failed with error: ${error}`);
|
||||||
|
} finally {
|
||||||
|
inApiCall = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1699,6 +1747,17 @@ async function fetchImagesNoCache() {
|
||||||
return await Promise.allSettled(promises);
|
return await Promise.allSettled(promises);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Migrates the legacy boolean `local` expression setting to the numeric `api`
 * setting, then removes the legacy key and schedules a settings save.
 *
 * - `local: true`  -> `api = EXPRESSION_API.local`
 * - `local: false` -> `api = EXPRESSION_API.extras`, but only when `api` has
 *   not been set yet (a falsy `local` used to select the Extras code path).
 *
 * Does nothing when the `local` key is absent.
 */
function migrateSettings() {
    const settings = extension_settings.expressions;

    if (!Object.keys(settings).includes('local')) {
        return;
    }

    if (settings.local) {
        // Must be an assignment: a comparison here silently drops the user's choice.
        settings.api = EXPRESSION_API.local;
    } else if (settings.api === undefined) {
        // Preserve the previous behaviour of users who had local mode turned off.
        settings.api = EXPRESSION_API.extras;
    }

    delete settings.local;
    saveSettingsDebounced();
}
|
||||||
|
|
||||||
(async function () {
|
(async function () {
|
||||||
function addExpressionImage() {
|
function addExpressionImage() {
|
||||||
const html = `
|
const html = `
|
||||||
|
@ -1730,11 +1789,6 @@ async function fetchImagesNoCache() {
|
||||||
extension_settings.expressions.translate = !!$(this).prop('checked');
|
extension_settings.expressions.translate = !!$(this).prop('checked');
|
||||||
saveSettingsDebounced();
|
saveSettingsDebounced();
|
||||||
});
|
});
|
||||||
$('#expression_local').prop('checked', extension_settings.expressions.local).on('input', function () {
|
|
||||||
extension_settings.expressions.local = !!$(this).prop('checked');
|
|
||||||
moduleWorker();
|
|
||||||
saveSettingsDebounced();
|
|
||||||
});
|
|
||||||
$('#expression_override_cleanup_button').on('click', onClickExpressionOverrideRemoveAllButton);
|
$('#expression_override_cleanup_button').on('click', onClickExpressionOverrideRemoveAllButton);
|
||||||
$(document).on('dragstart', '.expression', (e) => {
|
$(document).on('dragstart', '.expression', (e) => {
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
|
@ -1753,10 +1807,12 @@ async function fetchImagesNoCache() {
|
||||||
});
|
});
|
||||||
|
|
||||||
await renderAdditionalExpressionSettings();
|
await renderAdditionalExpressionSettings();
|
||||||
|
$('#expression_api').val(extension_settings.expressions.api || 0);
|
||||||
|
|
||||||
$('#expression_custom_add').on('click', onClickExpressionAddCustom);
|
$('#expression_custom_add').on('click', onClickExpressionAddCustom);
|
||||||
$('#expression_custom_remove').on('click', onClickExpressionRemoveCustom);
|
$('#expression_custom_remove').on('click', onClickExpressionRemoveCustom);
|
||||||
$('#expression_fallback').on('change', onExpressionFallbackChanged);
|
$('#expression_fallback').on('change', onExpressionFallbackChanged);
|
||||||
|
$('#expression_api').on('change', onExperesionApiChanged);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pause Talkinghead to save resources when the ST tab is not visible or the window is minimized.
|
// Pause Talkinghead to save resources when the ST tab is not visible or the window is minimized.
|
||||||
|
@ -1789,6 +1845,7 @@ async function fetchImagesNoCache() {
|
||||||
|
|
||||||
addExpressionImage();
|
addExpressionImage();
|
||||||
addVisualNovelMode();
|
addVisualNovelMode();
|
||||||
|
migrateSettings();
|
||||||
await addSettings();
|
await addSettings();
|
||||||
const wrapper = new ModuleWorkerWrapper(moduleWorker);
|
const wrapper = new ModuleWorkerWrapper(moduleWorker);
|
||||||
const updateFunction = wrapper.update.bind(wrapper);
|
const updateFunction = wrapper.update.bind(wrapper);
|
||||||
|
@ -1828,6 +1885,7 @@ async function fetchImagesNoCache() {
|
||||||
});
|
});
|
||||||
eventSource.on(event_types.MOVABLE_PANELS_RESET, updateVisualNovelModeDebounced);
|
eventSource.on(event_types.MOVABLE_PANELS_RESET, updateVisualNovelModeDebounced);
|
||||||
eventSource.on(event_types.GROUP_UPDATED, updateVisualNovelModeDebounced);
|
eventSource.on(event_types.GROUP_UPDATED, updateVisualNovelModeDebounced);
|
||||||
|
eventSource.on(event_types.TEXT_COMPLETION_SETTINGS_READY, onTextGenSettingsReady);
|
||||||
registerSlashCommand('sprite', setSpriteSlashCommand, ['emote'], '<span class="monospace">(spriteId)</span> – force sets the sprite for the current character', true, true);
|
registerSlashCommand('sprite', setSpriteSlashCommand, ['emote'], '<span class="monospace">(spriteId)</span> – force sets the sprite for the current character', true, true);
|
||||||
registerSlashCommand('spriteoverride', setSpriteSetCommand, ['costume'], '<span class="monospace">(optional folder)</span> – sets an override sprite folder for the current character. If the name starts with a slash or a backslash, selects a sub-folder in the character-named folder. Empty value to reset to default.', true, true);
|
registerSlashCommand('spriteoverride', setSpriteSetCommand, ['costume'], '<span class="monospace">(optional folder)</span> – sets an override sprite folder for the current character. If the name starts with a slash or a backslash, selects a sub-folder in the character-named folder. Empty value to reset to default.', true, true);
|
||||||
registerSlashCommand('lastsprite', (_, value) => lastExpression[value.trim()] ?? '', [], '<span class="monospace">(charName)</span> – Returns the last set sprite / expression for the named character.', true, true);
|
registerSlashCommand('lastsprite', (_, value) => lastExpression[value.trim()] ?? '', [], '<span class="monospace">(charName)</span> – Returns the last set sprite / expression for the named character.', true, true);
|
||||||
|
|
|
@ -6,10 +6,6 @@
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="inline-drawer-content">
|
<div class="inline-drawer-content">
|
||||||
<label class="checkbox_label" for="expression_local" title="Use classification model without the Extras server.">
|
|
||||||
<input id="expression_local" type="checkbox" />
|
|
||||||
<span data-i18n="Local server classification">Local server classification</span>
|
|
||||||
</label>
|
|
||||||
<label class="checkbox_label" for="expression_translate" title="Use the selected API from Chat Translation extension settings.">
|
<label class="checkbox_label" for="expression_translate" title="Use the selected API from Chat Translation extension settings.">
|
||||||
<input id="expression_translate" type="checkbox">
|
<input id="expression_translate" type="checkbox">
|
||||||
<span>Translate text to English before classification</span>
|
<span>Translate text to English before classification</span>
|
||||||
|
@ -22,6 +18,16 @@
|
||||||
<input id="image_type_toggle" type="checkbox">
|
<input id="image_type_toggle" type="checkbox">
|
||||||
<span>Image Type - talkinghead (extras)</span>
|
<span>Image Type - talkinghead (extras)</span>
|
||||||
</label>
|
</label>
|
||||||
|
<div class="expression_api_block m-b-1 m-t-1">
    <!-- Option values are stored as numbers in extension_settings.expressions.api and must
         match the EXPRESSION_API constants in the script (local: 0, extras: 1, llm: 2).
         Do not add an option here without a matching constant. -->
    <label for="expression_api">Classifier API</label>
    <small>Select the API for classifying expressions.</small>
    <select id="expression_api" class="flex1 margin0" data-i18n="Expression API" placeholder="Expression API">
        <option value="0">Local</option>
        <option value="1">Extras</option>
        <option value="2">LLM</option>
    </select>
</div>
|
||||||
<div class="expression_fallback_block m-b-1 m-t-1">
|
<div class="expression_fallback_block m-b-1 m-t-1">
|
||||||
<label for="expression_fallback">Default / Fallback Expression</label>
|
<label for="expression_fallback">Default / Fallback Expression</label>
|
||||||
<small>Set the default and fallback expression being used when no matching expression is found.</small>
|
<small>Set the default and fallback expression being used when no matching expression is found.</small>
|
||||||
|
|
Loading…
Reference in New Issue