mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-03-04 19:58:34 +01:00
Expressions: Classify using LLM
Rather than using a separate BERT model to classify the last message, use the LLM itself to get the classified expression label as a JSON and set that as the current sprite. Doing this should take more information into consideration and cut down on extra processing. This is made possible by the use of constrained generation with JSON schemas. Only available to TabbyAPI since it's the only backend that supports the use of JSON schemas, but there can hopefully be a way to use this with other backends as well. Intercepts the generation and sets top_k = 1 (for greedy sampling) and the json_schema to an emotion enum. Doing this also prevents reingestion of the entire context every time a message is sent and then asked to be classified, which doesn't compromise the chat experience. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
626c93a1ab
commit
6b656bf380
@ -1,4 +1,4 @@
|
||||
import { callPopup, eventSource, event_types, getRequestHeaders, saveSettingsDebounced } from '../../../script.js';
|
||||
import { callPopup, eventSource, event_types, generateQuietPrompt, getRequestHeaders, saveSettingsDebounced } from '../../../script.js';
|
||||
import { dragElement, isMobile } from '../../RossAscends-mods.js';
|
||||
import { getContext, getApiUrl, modules, extension_settings, ModuleWorkerWrapper, doExtrasFetch, renderExtensionTemplateAsync } from '../../extensions.js';
|
||||
import { loadMovingUIState, power_user } from '../../power-user.js';
|
||||
@ -43,6 +43,11 @@ const DEFAULT_EXPRESSIONS = [
|
||||
'surprise',
|
||||
'neutral',
|
||||
];
|
||||
const EXPRESSION_API = {
|
||||
local: 0,
|
||||
extras: 1,
|
||||
llm: 2,
|
||||
}
|
||||
|
||||
let expressionsList = null;
|
||||
let lastCharacter = undefined;
|
||||
@ -55,7 +60,7 @@ let lastServerResponseTime = 0;
|
||||
export let lastExpression = {};
|
||||
|
||||
function isTalkingHeadEnabled() {
|
||||
return extension_settings.expressions.talkinghead && !extension_settings.expressions.local;
|
||||
return extension_settings.expressions.talkinghead && extension_settings.expressions.api == EXPRESSION_API.extras;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -585,10 +590,10 @@ function handleImageChange() {
|
||||
async function moduleWorker() {
|
||||
const context = getContext();
|
||||
|
||||
// Hide and disable Talkinghead while in local mode
|
||||
$('#image_type_block').toggle(!extension_settings.expressions.local);
|
||||
// Hide and disable Talkinghead while not in extras
|
||||
$('#image_type_block').toggle(extension_settings.expressions.api == EXPRESSION_API.extras);
|
||||
|
||||
if (extension_settings.expressions.local && extension_settings.expressions.talkinghead) {
|
||||
if (extension_settings.expressions.api != EXPRESSION_API.extras && extension_settings.expressions.talkinghead) {
|
||||
$('#image_type_toggle').prop('checked', false);
|
||||
setTalkingHeadState(false);
|
||||
}
|
||||
@ -628,7 +633,7 @@ async function moduleWorker() {
|
||||
}
|
||||
|
||||
const offlineMode = $('.expression_settings .offline_mode');
|
||||
if (!modules.includes('classify') && !extension_settings.expressions.local) {
|
||||
if (!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) {
|
||||
$('#open_chat_expressions').show();
|
||||
$('#no_chat_expressions').hide();
|
||||
offlineMode.css('display', 'block');
|
||||
@ -821,7 +826,7 @@ function setTalkingHeadState(newState) {
|
||||
extension_settings.expressions.talkinghead = newState; // Store setting
|
||||
saveSettingsDebounced();
|
||||
|
||||
if (extension_settings.expressions.local) {
|
||||
if (extension_settings.expressions.api == EXPRESSION_API.local) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -900,7 +905,7 @@ async function classifyCommand(_, text) {
|
||||
return '';
|
||||
}
|
||||
|
||||
if (!modules.includes('classify') && !extension_settings.expressions.local) {
|
||||
if (!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) {
|
||||
toastr.warning('Text classification is disabled or not available');
|
||||
return '';
|
||||
}
|
||||
@ -971,9 +976,32 @@ function sampleClassifyText(text) {
|
||||
return result.trim();
|
||||
}
|
||||
|
||||
function onTextGenSettingsReady(args) {
|
||||
// Only call if inside an API call
|
||||
if (inApiCall) {
|
||||
const emotions = DEFAULT_EXPRESSIONS.filter((e) => e != 'talkinghead')
|
||||
Object.assign(args, {
|
||||
top_k: 1,
|
||||
json_schema: {
|
||||
$schema: "http://json-schema.org/draft-04/schema#",
|
||||
type: "object",
|
||||
properties: {
|
||||
emotion: {
|
||||
type: "string",
|
||||
enum: emotions
|
||||
}
|
||||
},
|
||||
required: [
|
||||
"emotion"
|
||||
]
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async function getExpressionLabel(text) {
|
||||
// Return if text is undefined, saving a costly fetch request
|
||||
if ((!modules.includes('classify') && !extension_settings.expressions.local) || !text) {
|
||||
if ((!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) || !text) {
|
||||
return getFallbackExpression();
|
||||
}
|
||||
|
||||
@ -984,36 +1012,44 @@ async function getExpressionLabel(text) {
|
||||
text = sampleClassifyText(text);
|
||||
|
||||
try {
|
||||
if (extension_settings.expressions.local) {
|
||||
// Local transformers pipeline
|
||||
const apiResult = await fetch('/api/extra/classify', {
|
||||
method: 'POST',
|
||||
headers: getRequestHeaders(),
|
||||
body: JSON.stringify({ text: text }),
|
||||
});
|
||||
switch (extension_settings.expressions.api) {
|
||||
case EXPRESSION_API.local:
|
||||
// Local BERT pipeline
|
||||
const localResult = await fetch('/api/extra/classify', {
|
||||
method: 'POST',
|
||||
headers: getRequestHeaders(),
|
||||
body: JSON.stringify({ text: text }),
|
||||
});
|
||||
|
||||
if (apiResult.ok) {
|
||||
const data = await apiResult.json();
|
||||
return data.classification[0].label;
|
||||
}
|
||||
} else {
|
||||
// Extras
|
||||
const url = new URL(getApiUrl());
|
||||
url.pathname = '/api/classify';
|
||||
if (localResult.ok) {
|
||||
const data = await localResult.json();
|
||||
return data.classification[0].label;
|
||||
}
|
||||
|
||||
const apiResult = await doExtrasFetch(url, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Bypass-Tunnel-Reminder': 'bypass',
|
||||
},
|
||||
body: JSON.stringify({ text: text }),
|
||||
});
|
||||
break;
|
||||
case EXPRESSION_API.llm:
|
||||
// Using LLM
|
||||
const emotionResponse = await generateQuietPrompt('', false);
|
||||
const parsedEmotion = JSON.parse(emotionResponse);
|
||||
return parsedEmotion.emotion;
|
||||
default:
|
||||
// Extras
|
||||
const url = new URL(getApiUrl());
|
||||
url.pathname = '/api/classify';
|
||||
|
||||
if (apiResult.ok) {
|
||||
const data = await apiResult.json();
|
||||
return data.classification[0].label;
|
||||
}
|
||||
const extrasResult = await doExtrasFetch(url, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Bypass-Tunnel-Reminder': 'bypass',
|
||||
},
|
||||
body: JSON.stringify({ text: text }),
|
||||
});
|
||||
|
||||
if (extrasResult.ok) {
|
||||
const data = await extrasResult.json();
|
||||
return data.classification[0].label;
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.log(error);
|
||||
@ -1177,23 +1213,12 @@ async function getExpressionsList() {
|
||||
*/
|
||||
async function resolveExpressionsList() {
|
||||
// get something for offline mode (default images)
|
||||
if (!modules.includes('classify') && !extension_settings.expressions.local) {
|
||||
if (!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) {
|
||||
return DEFAULT_EXPRESSIONS;
|
||||
}
|
||||
|
||||
try {
|
||||
if (extension_settings.expressions.local) {
|
||||
const apiResult = await fetch('/api/extra/classify/labels', {
|
||||
method: 'POST',
|
||||
headers: getRequestHeaders(),
|
||||
});
|
||||
|
||||
if (apiResult.ok) {
|
||||
const data = await apiResult.json();
|
||||
expressionsList = data.labels;
|
||||
return expressionsList;
|
||||
}
|
||||
} else {
|
||||
if (extension_settings.expressions.api == EXPRESSION_API.extras) {
|
||||
const url = new URL(getApiUrl());
|
||||
url.pathname = '/api/classify/labels';
|
||||
|
||||
@ -1204,6 +1229,17 @@ async function getExpressionsList() {
|
||||
|
||||
if (apiResult.ok) {
|
||||
|
||||
const data = await apiResult.json();
|
||||
expressionsList = data.labels;
|
||||
return expressionsList;
|
||||
}
|
||||
} else {
|
||||
const apiResult = await fetch('/api/extra/classify/labels', {
|
||||
method: 'POST',
|
||||
headers: getRequestHeaders(),
|
||||
});
|
||||
|
||||
if (apiResult.ok) {
|
||||
const data = await apiResult.json();
|
||||
expressionsList = data.labels;
|
||||
return expressionsList;
|
||||
@ -1444,6 +1480,15 @@ async function onClickExpressionRemoveCustom() {
|
||||
moduleWorker();
|
||||
}
|
||||
|
||||
function onExperesionApiChanged() {
|
||||
const tempApi = this.value;
|
||||
if (tempApi) {
|
||||
extension_settings.expressions.api = Number(tempApi);
|
||||
moduleWorker();
|
||||
saveSettingsDebounced();
|
||||
}
|
||||
}
|
||||
|
||||
function onExpressionFallbackChanged() {
|
||||
const expression = this.value;
|
||||
if (expression) {
|
||||
@ -1556,6 +1601,7 @@ async function onClickExpressionOverrideButton() {
|
||||
|
||||
// Refresh sprites list. Assume the override path has been properly handled.
|
||||
try {
|
||||
inApiCall = true;
|
||||
$('#visual-novel-wrapper').empty();
|
||||
await validateImages(overridePath.length === 0 ? currentLastMessage.name : overridePath, true);
|
||||
const expression = await getExpressionLabel(currentLastMessage.mes);
|
||||
@ -1563,6 +1609,8 @@ async function onClickExpressionOverrideButton() {
|
||||
forceUpdateVisualNovelMode();
|
||||
} catch (error) {
|
||||
console.debug(`Setting expression override for ${avatarFileName} failed with error: ${error}`);
|
||||
} finally {
|
||||
inApiCall = false;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1699,6 +1747,17 @@ async function fetchImagesNoCache() {
|
||||
return await Promise.allSettled(promises);
|
||||
}
|
||||
|
||||
function migrateSettings() {
|
||||
if (Object.keys(extension_settings.expressions).includes('local')) {
|
||||
if (extension_settings.expressions.local) {
|
||||
extension_settings.expressions.api == EXPRESSION_API.local;
|
||||
}
|
||||
|
||||
delete extension_settings.expressions.local;
|
||||
saveSettingsDebounced();
|
||||
}
|
||||
}
|
||||
|
||||
(async function () {
|
||||
function addExpressionImage() {
|
||||
const html = `
|
||||
@ -1730,11 +1789,6 @@ async function fetchImagesNoCache() {
|
||||
extension_settings.expressions.translate = !!$(this).prop('checked');
|
||||
saveSettingsDebounced();
|
||||
});
|
||||
$('#expression_local').prop('checked', extension_settings.expressions.local).on('input', function () {
|
||||
extension_settings.expressions.local = !!$(this).prop('checked');
|
||||
moduleWorker();
|
||||
saveSettingsDebounced();
|
||||
});
|
||||
$('#expression_override_cleanup_button').on('click', onClickExpressionOverrideRemoveAllButton);
|
||||
$(document).on('dragstart', '.expression', (e) => {
|
||||
e.preventDefault();
|
||||
@ -1753,10 +1807,12 @@ async function fetchImagesNoCache() {
|
||||
});
|
||||
|
||||
await renderAdditionalExpressionSettings();
|
||||
$('#expression_api').val(extension_settings.expressions.api || 0);
|
||||
|
||||
$('#expression_custom_add').on('click', onClickExpressionAddCustom);
|
||||
$('#expression_custom_remove').on('click', onClickExpressionRemoveCustom);
|
||||
$('#expression_fallback').on('change', onExpressionFallbackChanged);
|
||||
$('#expression_api').on('change', onExperesionApiChanged);
|
||||
}
|
||||
|
||||
// Pause Talkinghead to save resources when the ST tab is not visible or the window is minimized.
|
||||
@ -1789,6 +1845,7 @@ async function fetchImagesNoCache() {
|
||||
|
||||
addExpressionImage();
|
||||
addVisualNovelMode();
|
||||
migrateSettings();
|
||||
await addSettings();
|
||||
const wrapper = new ModuleWorkerWrapper(moduleWorker);
|
||||
const updateFunction = wrapper.update.bind(wrapper);
|
||||
@ -1828,6 +1885,7 @@ async function fetchImagesNoCache() {
|
||||
});
|
||||
eventSource.on(event_types.MOVABLE_PANELS_RESET, updateVisualNovelModeDebounced);
|
||||
eventSource.on(event_types.GROUP_UPDATED, updateVisualNovelModeDebounced);
|
||||
eventSource.on(event_types.TEXT_COMPLETION_SETTINGS_READY, onTextGenSettingsReady);
|
||||
registerSlashCommand('sprite', setSpriteSlashCommand, ['emote'], '<span class="monospace">(spriteId)</span> – force sets the sprite for the current character', true, true);
|
||||
registerSlashCommand('spriteoverride', setSpriteSetCommand, ['costume'], '<span class="monospace">(optional folder)</span> – sets an override sprite folder for the current character. If the name starts with a slash or a backslash, selects a sub-folder in the character-named folder. Empty value to reset to default.', true, true);
|
||||
registerSlashCommand('lastsprite', (_, value) => lastExpression[value.trim()] ?? '', [], '<span class="monospace">(charName)</span> – Returns the last set sprite / expression for the named character.', true, true);
|
||||
|
@ -6,10 +6,6 @@
|
||||
</div>
|
||||
|
||||
<div class="inline-drawer-content">
|
||||
<label class="checkbox_label" for="expression_local" title="Use classification model without the Extras server.">
|
||||
<input id="expression_local" type="checkbox" />
|
||||
<span data-i18n="Local server classification">Local server classification</span>
|
||||
</label>
|
||||
<label class="checkbox_label" for="expression_translate" title="Use the selected API from Chat Translation extension settings.">
|
||||
<input id="expression_translate" type="checkbox">
|
||||
<span>Translate text to English before classification</span>
|
||||
@ -22,6 +18,16 @@
|
||||
<input id="image_type_toggle" type="checkbox">
|
||||
<span>Image Type - talkinghead (extras)</span>
|
||||
</label>
|
||||
<div class="expression_api_block m-b-1 m-t-1">
|
||||
<label for="expression_api">Classifier API</label>
|
||||
<small>Select the API for classifying expressions.</small>
|
||||
<select id="expression_api" class="flex1 margin0" data-i18n="Expression API" placeholder="Expression API">
|
||||
<option value="0">Local</option>
|
||||
<option value="1">Extras</option>
|
||||
<option value="2">LLM</option>
|
||||
<option value="3">TalkingHead</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="expression_fallback_block m-b-1 m-t-1">
|
||||
<label for="expression_fallback">Default / Fallback Expression</label>
|
||||
<small>Set the default and fallback expression being used when no matching expression is found.</small>
|
||||
|
Loading…
x
Reference in New Issue
Block a user