Expressions: Classify using LLM

Rather than using a separate BERT model to classify the last message,
use the LLM itself to return the classified expression label as JSON
and set that label as the current sprite. Doing this should take more
information into consideration and cut down on extra processing.

This is made possible by the use of constrained generation with JSON
schemas. Only available to TabbyAPI since it's the only backend that
supports the use of JSON schemas, but there can hopefully be a way
to use this with other backends as well.

The extension intercepts the generation request and sets top_k = 1
(for greedy sampling) and json_schema to an emotion enum. Doing this
also avoids reingesting the entire context every time a message is
sent and then classified, so the chat experience is not compromised.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-04-12 01:49:09 -04:00
parent 626c93a1ab
commit 6b656bf380
2 changed files with 122 additions and 58 deletions

View File

@ -1,4 +1,4 @@
import { callPopup, eventSource, event_types, getRequestHeaders, saveSettingsDebounced } from '../../../script.js';
import { callPopup, eventSource, event_types, generateQuietPrompt, getRequestHeaders, saveSettingsDebounced } from '../../../script.js';
import { dragElement, isMobile } from '../../RossAscends-mods.js';
import { getContext, getApiUrl, modules, extension_settings, ModuleWorkerWrapper, doExtrasFetch, renderExtensionTemplateAsync } from '../../extensions.js';
import { loadMovingUIState, power_user } from '../../power-user.js';
@ -43,6 +43,11 @@ const DEFAULT_EXPRESSIONS = [
'surprise',
'neutral',
];
/**
 * Identifiers of the available expression classification backends.
 * The numeric values are persisted in `extension_settings.expressions.api`
 * and match the `value` attributes of the `#expression_api` select options.
 * Frozen so the shared enum cannot be altered by accident at runtime.
 * @readonly
 * @enum {number}
 */
const EXPRESSION_API = Object.freeze({
    local: 0,
    extras: 1,
    llm: 2,
});
let expressionsList = null;
let lastCharacter = undefined;
@ -55,7 +60,7 @@ let lastServerResponseTime = 0;
export let lastExpression = {};
function isTalkingHeadEnabled() {
return extension_settings.expressions.talkinghead && !extension_settings.expressions.local;
return extension_settings.expressions.talkinghead && extension_settings.expressions.api == EXPRESSION_API.extras;
}
/**
@ -585,10 +590,10 @@ function handleImageChange() {
async function moduleWorker() {
const context = getContext();
// Hide and disable Talkinghead while in local mode
$('#image_type_block').toggle(!extension_settings.expressions.local);
// Hide and disable Talkinghead while not in extras
$('#image_type_block').toggle(extension_settings.expressions.api == EXPRESSION_API.extras);
if (extension_settings.expressions.local && extension_settings.expressions.talkinghead) {
if (extension_settings.expressions.api != EXPRESSION_API.extras && extension_settings.expressions.talkinghead) {
$('#image_type_toggle').prop('checked', false);
setTalkingHeadState(false);
}
@ -628,7 +633,7 @@ async function moduleWorker() {
}
const offlineMode = $('.expression_settings .offline_mode');
if (!modules.includes('classify') && !extension_settings.expressions.local) {
if (!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) {
$('#open_chat_expressions').show();
$('#no_chat_expressions').hide();
offlineMode.css('display', 'block');
@ -821,7 +826,7 @@ function setTalkingHeadState(newState) {
extension_settings.expressions.talkinghead = newState; // Store setting
saveSettingsDebounced();
if (extension_settings.expressions.local) {
if (extension_settings.expressions.api == EXPRESSION_API.local) {
return;
}
@ -900,7 +905,7 @@ async function classifyCommand(_, text) {
return '';
}
if (!modules.includes('classify') && !extension_settings.expressions.local) {
if (!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) {
toastr.warning('Text classification is disabled or not available');
return '';
}
@ -971,9 +976,32 @@ function sampleClassifyText(text) {
return result.trim();
}
/**
 * Handler for the TEXT_COMPLETION_SETTINGS_READY event.
 * While an LLM-based classification request is in flight, mutates the outgoing
 * text completion settings so the backend picks greedily (top_k = 1) and is
 * constrained to a JSON object of the form `{ "emotion": "<label>" }`.
 * @param {object} args Text completion generation settings; modified in place.
 * @returns {void}
 */
function onTextGenSettingsReady(args) {
    // Only override while this extension is classifying AND the LLM classifier is selected.
    // Otherwise a regular generation that overlaps a Local/Extras classification
    // would be forced into the emotion schema.
    if (!inApiCall || extension_settings.expressions.api !== EXPRESSION_API.llm) {
        return;
    }

    // 'talkinghead' is excluded from the labels offered to the model.
    const emotions = DEFAULT_EXPRESSIONS.filter((e) => e !== 'talkinghead');

    Object.assign(args, {
        top_k: 1,
        json_schema: {
            $schema: 'http://json-schema.org/draft-04/schema#',
            type: 'object',
            properties: {
                emotion: {
                    type: 'string',
                    enum: emotions,
                },
            },
            required: [
                'emotion',
            ],
        },
    });
}
async function getExpressionLabel(text) {
// Return if text is undefined, saving a costly fetch request
if ((!modules.includes('classify') && !extension_settings.expressions.local) || !text) {
if ((!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) || !text) {
return getFallbackExpression();
}
@ -984,36 +1012,44 @@ async function getExpressionLabel(text) {
text = sampleClassifyText(text);
try {
if (extension_settings.expressions.local) {
// Local transformers pipeline
const apiResult = await fetch('/api/extra/classify', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({ text: text }),
});
switch (extension_settings.expressions.api) {
case EXPRESSION_API.local:
// Local BERT pipeline
const localResult = await fetch('/api/extra/classify', {
method: 'POST',
headers: getRequestHeaders(),
body: JSON.stringify({ text: text }),
});
if (apiResult.ok) {
const data = await apiResult.json();
return data.classification[0].label;
}
} else {
// Extras
const url = new URL(getApiUrl());
url.pathname = '/api/classify';
if (localResult.ok) {
const data = await localResult.json();
return data.classification[0].label;
}
const apiResult = await doExtrasFetch(url, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Bypass-Tunnel-Reminder': 'bypass',
},
body: JSON.stringify({ text: text }),
});
break;
case EXPRESSION_API.llm:
// Using LLM
const emotionResponse = await generateQuietPrompt('', false);
const parsedEmotion = JSON.parse(emotionResponse);
return parsedEmotion.emotion;
default:
// Extras
const url = new URL(getApiUrl());
url.pathname = '/api/classify';
if (apiResult.ok) {
const data = await apiResult.json();
return data.classification[0].label;
}
const extrasResult = await doExtrasFetch(url, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Bypass-Tunnel-Reminder': 'bypass',
},
body: JSON.stringify({ text: text }),
});
if (extrasResult.ok) {
const data = await extrasResult.json();
return data.classification[0].label;
}
}
} catch (error) {
console.log(error);
@ -1177,23 +1213,12 @@ async function getExpressionsList() {
*/
async function resolveExpressionsList() {
// get something for offline mode (default images)
if (!modules.includes('classify') && !extension_settings.expressions.local) {
if (!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) {
return DEFAULT_EXPRESSIONS;
}
try {
if (extension_settings.expressions.local) {
const apiResult = await fetch('/api/extra/classify/labels', {
method: 'POST',
headers: getRequestHeaders(),
});
if (apiResult.ok) {
const data = await apiResult.json();
expressionsList = data.labels;
return expressionsList;
}
} else {
if (extension_settings.expressions.api == EXPRESSION_API.extras) {
const url = new URL(getApiUrl());
url.pathname = '/api/classify/labels';
@ -1204,6 +1229,17 @@ async function getExpressionsList() {
if (apiResult.ok) {
const data = await apiResult.json();
expressionsList = data.labels;
return expressionsList;
}
} else {
const apiResult = await fetch('/api/extra/classify/labels', {
method: 'POST',
headers: getRequestHeaders(),
});
if (apiResult.ok) {
const data = await apiResult.json();
expressionsList = data.labels;
return expressionsList;
@ -1444,6 +1480,15 @@ async function onClickExpressionRemoveCustom() {
moduleWorker();
}
/**
 * Change handler for the classifier API select.
 * Stores the chosen value as a number in the extension settings, re-runs the
 * module worker and schedules a settings save. An empty value is ignored.
 * @this {HTMLSelectElement}
 * @returns {void}
 */
function onExperesionApiChanged() {
    const selectedValue = this.value;

    // Nothing selected: leave the stored setting alone.
    if (!selectedValue) {
        return;
    }

    extension_settings.expressions.api = Number(selectedValue);
    moduleWorker();
    saveSettingsDebounced();
}
function onExpressionFallbackChanged() {
const expression = this.value;
if (expression) {
@ -1556,6 +1601,7 @@ async function onClickExpressionOverrideButton() {
// Refresh sprites list. Assume the override path has been properly handled.
try {
inApiCall = true;
$('#visual-novel-wrapper').empty();
await validateImages(overridePath.length === 0 ? currentLastMessage.name : overridePath, true);
const expression = await getExpressionLabel(currentLastMessage.mes);
@ -1563,6 +1609,8 @@ async function onClickExpressionOverrideButton() {
forceUpdateVisualNovelMode();
} catch (error) {
console.debug(`Setting expression override for ${avatarFileName} failed with error: ${error}`);
} finally {
inApiCall = false;
}
}
@ -1699,6 +1747,17 @@ async function fetchImagesNoCache() {
return await Promise.allSettled(promises);
}
/**
 * Migrates the legacy boolean `local` expression setting to the `api` selector.
 * - `local: true` becomes `api = EXPRESSION_API.local`.
 * - `local: false` (which previously selected the Extras classifier) becomes
 *   `api = EXPRESSION_API.extras`, unless an `api` value is already stored.
 * The legacy key is then removed and a settings save is scheduled.
 * Does nothing when the `local` key is absent.
 * @returns {void}
 */
function migrateSettings() {
    const settings = extension_settings.expressions;

    if (!Object.keys(settings).includes('local')) {
        return;
    }

    if (settings.local) {
        // Must be an assignment: a comparison here would silently drop the user's choice.
        settings.api = EXPRESSION_API.local;
    } else if (settings.api === undefined) {
        // Without this the unset value would be displayed as Local (0) in the UI.
        settings.api = EXPRESSION_API.extras;
    }

    delete settings.local;
    saveSettingsDebounced();
}
(async function () {
function addExpressionImage() {
const html = `
@ -1730,11 +1789,6 @@ async function fetchImagesNoCache() {
extension_settings.expressions.translate = !!$(this).prop('checked');
saveSettingsDebounced();
});
$('#expression_local').prop('checked', extension_settings.expressions.local).on('input', function () {
extension_settings.expressions.local = !!$(this).prop('checked');
moduleWorker();
saveSettingsDebounced();
});
$('#expression_override_cleanup_button').on('click', onClickExpressionOverrideRemoveAllButton);
$(document).on('dragstart', '.expression', (e) => {
e.preventDefault();
@ -1753,10 +1807,12 @@ async function fetchImagesNoCache() {
});
await renderAdditionalExpressionSettings();
$('#expression_api').val(extension_settings.expressions.api || 0);
$('#expression_custom_add').on('click', onClickExpressionAddCustom);
$('#expression_custom_remove').on('click', onClickExpressionRemoveCustom);
$('#expression_fallback').on('change', onExpressionFallbackChanged);
$('#expression_api').on('change', onExperesionApiChanged);
}
// Pause Talkinghead to save resources when the ST tab is not visible or the window is minimized.
@ -1789,6 +1845,7 @@ async function fetchImagesNoCache() {
addExpressionImage();
addVisualNovelMode();
migrateSettings();
await addSettings();
const wrapper = new ModuleWorkerWrapper(moduleWorker);
const updateFunction = wrapper.update.bind(wrapper);
@ -1828,6 +1885,7 @@ async function fetchImagesNoCache() {
});
eventSource.on(event_types.MOVABLE_PANELS_RESET, updateVisualNovelModeDebounced);
eventSource.on(event_types.GROUP_UPDATED, updateVisualNovelModeDebounced);
eventSource.on(event_types.TEXT_COMPLETION_SETTINGS_READY, onTextGenSettingsReady);
registerSlashCommand('sprite', setSpriteSlashCommand, ['emote'], '<span class="monospace">(spriteId)</span> force sets the sprite for the current character', true, true);
registerSlashCommand('spriteoverride', setSpriteSetCommand, ['costume'], '<span class="monospace">(optional folder)</span> sets an override sprite folder for the current character. If the name starts with a slash or a backslash, selects a sub-folder in the character-named folder. Empty value to reset to default.', true, true);
registerSlashCommand('lastsprite', (_, value) => lastExpression[value.trim()] ?? '', [], '<span class="monospace">(charName)</span> Returns the last set sprite / expression for the named character.', true, true);

View File

@ -6,10 +6,6 @@
</div>
<div class="inline-drawer-content">
<label class="checkbox_label" for="expression_local" title="Use classification model without the Extras server.">
<input id="expression_local" type="checkbox" />
<span data-i18n="Local server classification">Local server classification</span>
</label>
<label class="checkbox_label" for="expression_translate" title="Use the selected API from Chat Translation extension settings.">
<input id="expression_translate" type="checkbox">
<span>Translate text to English before classification</span>
@ -22,6 +18,16 @@
<input id="image_type_toggle" type="checkbox">
<span>Image Type - talkinghead (extras)</span>
</label>
<!-- Classifier backend selector. Option values must match the EXPRESSION_API enum in the extension script (0 = local, 1 = extras, 2 = llm). -->
<div class="expression_api_block m-b-1 m-t-1">
    <label for="expression_api">Classifier API</label>
    <small>Select the API for classifying expressions.</small>
    <select id="expression_api" class="flex1 margin0" data-i18n="Expression API" placeholder="Expression API">
        <option value="0">Local</option>
        <option value="1">Extras</option>
        <option value="2">LLM</option>
    </select>
</div>
<div class="expression_fallback_block m-b-1 m-t-1">
<label for="expression_fallback">Default / Fallback Expression</label>
<small>Set the default and fallback expression being used when no matching expression is found.</small>