diff --git a/public/scripts/extensions/expressions/index.js b/public/scripts/extensions/expressions/index.js index ab02d1a51..ef5000fac 100644 --- a/public/scripts/extensions/expressions/index.js +++ b/public/scripts/extensions/expressions/index.js @@ -8,7 +8,7 @@ import { isJsonSchemaSupported } from '../../textgen-settings.js'; import { debounce_timeout } from '../../constants.js'; import { SlashCommandParser } from '../../slash-commands/SlashCommandParser.js'; import { SlashCommand } from '../../slash-commands/SlashCommand.js'; -import { ARGUMENT_TYPE, SlashCommandArgument } from '../../slash-commands/SlashCommandArgument.js'; +import { ARGUMENT_TYPE, SlashCommandArgument, SlashCommandNamedArgument } from '../../slash-commands/SlashCommandArgument.js'; import { isFunctionCallingSupported } from '../../openai.js'; import { SlashCommandEnumValue, enumTypes } from '../../slash-commands/SlashCommandEnumValue.js'; import { commonEnumProviders } from '../../slash-commands/SlashCommandCommonEnumsProvider.js'; @@ -52,6 +52,7 @@ const DEFAULT_EXPRESSIONS = [ 'surprise', 'neutral', ]; +/** @enum {number} */ const EXPRESSION_API = { local: 0, extras: 1, @@ -920,18 +921,24 @@ async function setSpriteSetCommand(_, folder) { return ''; } -async function classifyCommand(_, text) { +async function classifyCallback(/** @type {{api: string?}} */ { api = null }, text) { if (!text) { - console.log('No text provided'); + toastr.warning('No text provided'); + return ''; + } + if (api && !Object.keys(EXPRESSION_API).includes(api)) { + toastr.warning('Invalid API provided'); return ''; } - if (!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) { + const expressionApi = EXPRESSION_API[api] || extension_settings.expressions.api; + + if (!modules.includes('classify') && expressionApi == EXPRESSION_API.extras) { toastr.warning('Text classification is disabled or not available'); return ''; } - const label = getExpressionLabel(text); + const label = getExpressionLabel(text, expressionApi); console.debug(`Classification result for "${text}": ${label}`); return label; } @@ -1108,9 +1115,16 @@ function onTextGenSettingsReady(args) { } } -async function getExpressionLabel(text) { +/** + * Retrieves the label of an expression via classification based on the provided text. + * Optionally allows to override the expressions API being used. + * @param {string} text - The text to classify and retrieve the expression label for. + * @param {EXPRESSION_API} [expressionsApi=extension_settings.expressions.api] - The expressions API to use for classification. + * @returns {Promise} - The label of the expression. + */ +export async function getExpressionLabel(text, expressionsApi = extension_settings.expressions.api) { // Return if text is undefined, saving a costly fetch request - if ((!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) || !text) { + if ((!modules.includes('classify') && expressionsApi == EXPRESSION_API.extras) || !text) { return getFallbackExpression(); } @@ -1121,7 +1135,7 @@ async function getExpressionLabel(text) { text = sampleClassifyText(text); try { - switch (extension_settings.expressions.api) { + switch (expressionsApi) { // Local BERT pipeline case EXPRESSION_API.local: { const localResult = await fetch('/api/extra/classify', { @@ -2105,7 +2119,15 @@ function migrateSettings() { })); SlashCommandParser.addCommandObject(SlashCommand.fromProps({ name: 'classify', - callback: classifyCommand, + callback: classifyCallback, + namedArgumentList: [ + SlashCommandNamedArgument.fromProps({ + name: 'api', + description: 'The Classifier API to classify with. If not specified, the configured one will be used.', + typeList: [ARGUMENT_TYPE.STRING], + enumList: Object.keys(EXPRESSION_API).map(api => new SlashCommandEnumValue(api, null, enumTypes.enum)), + }), + ], unnamedArgumentList: [ new SlashCommandArgument( 'text', [ARGUMENT_TYPE.STRING], true, @@ -2116,6 +2138,9 @@ function migrateSettings() {
Performs an emotion classification of the given text and returns a label.
+
+ Allows to specify which Classifier API to perform the classification with. +
Example: