diff --git a/public/scripts/extensions/expressions/index.js b/public/scripts/extensions/expressions/index.js index ab02d1a51..b79acc2bb 100644 --- a/public/scripts/extensions/expressions/index.js +++ b/public/scripts/extensions/expressions/index.js @@ -8,7 +8,7 @@ import { isJsonSchemaSupported } from '../../textgen-settings.js'; import { debounce_timeout } from '../../constants.js'; import { SlashCommandParser } from '../../slash-commands/SlashCommandParser.js'; import { SlashCommand } from '../../slash-commands/SlashCommand.js'; -import { ARGUMENT_TYPE, SlashCommandArgument } from '../../slash-commands/SlashCommandArgument.js'; +import { ARGUMENT_TYPE, SlashCommandArgument, SlashCommandNamedArgument } from '../../slash-commands/SlashCommandArgument.js'; import { isFunctionCallingSupported } from '../../openai.js'; import { SlashCommandEnumValue, enumTypes } from '../../slash-commands/SlashCommandEnumValue.js'; import { commonEnumProviders } from '../../slash-commands/SlashCommandCommonEnumsProvider.js'; @@ -52,6 +52,7 @@ const DEFAULT_EXPRESSIONS = [ 'surprise', 'neutral', ]; +/** @enum {number} */ const EXPRESSION_API = { local: 0, extras: 1, @@ -920,18 +921,24 @@ async function setSpriteSetCommand(_, folder) { return ''; } -async function classifyCommand(_, text) { +async function classifyCallback(/** @type {{api: string?, prompt: string?}} */ { api = null, prompt = null }, text) { if (!text) { - console.log('No text provided'); + toastr.warning('No text provided'); + return ''; + } + if (api && !Object.keys(EXPRESSION_API).includes(api)) { + toastr.warning('Invalid API provided'); return ''; } - if (!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) { + const expressionApi = EXPRESSION_API[api] || extension_settings.expressions.api; + + if (!modules.includes('classify') && expressionApi == EXPRESSION_API.extras) { toastr.warning('Text classification is disabled or not available'); return ''; } - const label = getExpressionLabel(text); + const label = await getExpressionLabel(text, expressionApi, { customPrompt: prompt }); console.debug(`Classification result for "${text}": ${label}`); return label; } @@ -1108,9 +1115,18 @@ function onTextGenSettingsReady(args) { } } -async function getExpressionLabel(text) { +/** + * Retrieves the label of an expression via classification based on the provided text. + * Optionally allows to override the expressions API being used. + * @param {string} text - The text to classify and retrieve the expression label for. + * @param {EXPRESSION_API} [expressionsApi=extension_settings.expressions.api] - The expressions API to use for classification. + * @param {object} [options={}] - Optional arguments. + * @param {string?} [options.customPrompt=null] - The custom prompt to use for classification. + * @returns {Promise} - The label of the expression. + */ +export async function getExpressionLabel(text, expressionsApi = extension_settings.expressions.api, { customPrompt = null } = {}) { // Return if text is undefined, saving a costly fetch request - if ((!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) || !text) { + if ((!modules.includes('classify') && expressionsApi == EXPRESSION_API.extras) || !text) { return getFallbackExpression(); } @@ -1121,7 +1137,7 @@ async function getExpressionLabel(text) { text = sampleClassifyText(text); try { - switch (extension_settings.expressions.api) { + switch (expressionsApi) { // Local BERT pipeline case EXPRESSION_API.local: { const localResult = await fetch('/api/extra/classify', { @@ -1145,7 +1161,7 @@ async function getExpressionLabel(text) { } const expressionsList = await getExpressionsList(); - const prompt = await getLlmPrompt(expressionsList); + const prompt = substituteParamsExtended(String(customPrompt), { labels: expressionsList }) || await getLlmPrompt(expressionsList); let functionResult = null; eventSource.once(event_types.TEXT_COMPLETION_SETTINGS_READY, onTextGenSettingsReady); eventSource.once(event_types.LLM_FUNCTION_TOOL_REGISTER, onFunctionToolRegister); @@ -1338,7 +1354,7 @@ function getCachedExpressions() { return [...expressionsList, ...extension_settings.expressions.custom].filter(onlyUnique); } -async function getExpressionsList() { +export async function getExpressionsList() { // Return cached list if available if (Array.isArray(expressionsList)) { return getCachedExpressions(); @@ -2069,7 +2085,7 @@ function migrateSettings() { }), ], helpString: 'Force sets the sprite for the current character.', - returns: 'label', + returns: 'the currently set sprite label after setting it.', })); SlashCommandParser.addCommandObject(SlashCommand.fromProps({ name: 'spriteoverride', @@ -2085,7 +2101,7 @@ function migrateSettings() { SlashCommandParser.addCommandObject(SlashCommand.fromProps({ name: 'lastsprite', callback: (_, value) => lastExpression[String(value).trim()] ?? '', - returns: 'sprite', + returns: 'the last set sprite / expression for the named character.', unnamedArgumentList: [ SlashCommandArgument.fromProps({ description: 'character name', @@ -2101,11 +2117,50 @@ function migrateSettings() { callback: toggleTalkingHeadCommand, aliases: ['talkinghead'], helpString: 'Character Expressions: toggles Image Type - talkinghead (extras) on/off.', - returns: ARGUMENT_TYPE.BOOLEAN, + returns: 'the current state of the Image Type - talkinghead (extras) on/off.', + })); + SlashCommandParser.addCommandObject(SlashCommand.fromProps({ + name: 'classify-expressions', + aliases: ['expressions'], + callback: async (args) => { + const list = await getExpressionsList(); + switch (String(args.format).toLowerCase()) { + case 'json': + return JSON.stringify(list); + default: + return list.join(', '); + } + }, + namedArgumentList: [ + SlashCommandNamedArgument.fromProps({ + name: 'format', + description: 'The format to return the list in: comma-separated plain text or JSON array. Default is plain text.', + typeList: [ARGUMENT_TYPE.STRING], + enumList: [ + new SlashCommandEnumValue('plain', null, enumTypes.enum, ', '), + new SlashCommandEnumValue('json', null, enumTypes.enum, '[]'), + ], + }), + ], + returns: 'The comma-separated list of available expressions, including custom expressions.', + helpString: 'Returns a list of available expressions, including custom expressions.', })); SlashCommandParser.addCommandObject(SlashCommand.fromProps({ name: 'classify', - callback: classifyCommand, + callback: classifyCallback, + namedArgumentList: [ + SlashCommandNamedArgument.fromProps({ + name: 'api', + description: 'The Classifier API to classify with. If not specified, the configured one will be used.', + typeList: [ARGUMENT_TYPE.STRING], + enumList: Object.keys(EXPRESSION_API).map(api => new SlashCommandEnumValue(api, null, enumTypes.enum)), + }), + SlashCommandNamedArgument.fromProps({ + name: 'prompt', + description: 'Custom prompt for classification. Only relevant if Classifier API is set to LLM.', + typeList: [ARGUMENT_TYPE.STRING], + }), + ], unnamedArgumentList: [ new SlashCommandArgument( 'text', [ARGUMENT_TYPE.STRING], true, @@ -2116,6 +2171,9 @@ function migrateSettings() {
Performs an emotion classification of the given text and returns a label.
+
+ Allows to specify which Classifier API to perform the classification with. +
Example: