/classify can specify classier API as argument

This commit is contained in:
Wolfsblvt 2024-09-05 00:06:14 +02:00
parent 58e6ae2fc5
commit a37b805a5d

View File

@ -8,7 +8,7 @@ import { isJsonSchemaSupported } from '../../textgen-settings.js';
import { debounce_timeout } from '../../constants.js'; import { debounce_timeout } from '../../constants.js';
import { SlashCommandParser } from '../../slash-commands/SlashCommandParser.js'; import { SlashCommandParser } from '../../slash-commands/SlashCommandParser.js';
import { SlashCommand } from '../../slash-commands/SlashCommand.js'; import { SlashCommand } from '../../slash-commands/SlashCommand.js';
import { ARGUMENT_TYPE, SlashCommandArgument } from '../../slash-commands/SlashCommandArgument.js'; import { ARGUMENT_TYPE, SlashCommandArgument, SlashCommandNamedArgument } from '../../slash-commands/SlashCommandArgument.js';
import { isFunctionCallingSupported } from '../../openai.js'; import { isFunctionCallingSupported } from '../../openai.js';
import { SlashCommandEnumValue, enumTypes } from '../../slash-commands/SlashCommandEnumValue.js'; import { SlashCommandEnumValue, enumTypes } from '../../slash-commands/SlashCommandEnumValue.js';
import { commonEnumProviders } from '../../slash-commands/SlashCommandCommonEnumsProvider.js'; import { commonEnumProviders } from '../../slash-commands/SlashCommandCommonEnumsProvider.js';
@ -52,6 +52,7 @@ const DEFAULT_EXPRESSIONS = [
'surprise', 'surprise',
'neutral', 'neutral',
]; ];
/** @enum {number} */
const EXPRESSION_API = { const EXPRESSION_API = {
local: 0, local: 0,
extras: 1, extras: 1,
@ -920,18 +921,24 @@ async function setSpriteSetCommand(_, folder) {
return ''; return '';
} }
async function classifyCommand(_, text) { async function classifyCallback(/** @type {{api: string?}} */ { api = null }, text) {
if (!text) { if (!text) {
console.log('No text provided'); toastr.warning('No text provided');
return '';
}
if (api && !Object.keys(EXPRESSION_API).includes(api)) {
toastr.warning('Invalid API provided');
return ''; return '';
} }
if (!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) { const expressionApi = EXPRESSION_API[api] || extension_settings.expressions.api;
if (!modules.includes('classify') && expressionApi == EXPRESSION_API.extras) {
toastr.warning('Text classification is disabled or not available'); toastr.warning('Text classification is disabled or not available');
return ''; return '';
} }
const label = getExpressionLabel(text); const label = getExpressionLabel(text, expressionApi);
console.debug(`Classification result for "${text}": ${label}`); console.debug(`Classification result for "${text}": ${label}`);
return label; return label;
} }
@ -1108,9 +1115,16 @@ function onTextGenSettingsReady(args) {
} }
} }
async function getExpressionLabel(text) { /**
* Retrieves the label of an expression via classification based on the provided text.
* Optionally allows to override the expressions API being used.
* @param {string} text - The text to classify and retrieve the expression label for.
* @param {EXPRESSION_API} [expressionsApi=extension_settings.expressions.api] - The expressions API to use for classification.
* @returns {Promise<string>} - The label of the expression.
*/
export async function getExpressionLabel(text, expressionsApi = extension_settings.expressions.api) {
// Return if text is undefined, saving a costly fetch request // Return if text is undefined, saving a costly fetch request
if ((!modules.includes('classify') && extension_settings.expressions.api == EXPRESSION_API.extras) || !text) { if ((!modules.includes('classify') && expressionsApi == EXPRESSION_API.extras) || !text) {
return getFallbackExpression(); return getFallbackExpression();
} }
@ -1121,7 +1135,7 @@ async function getExpressionLabel(text) {
text = sampleClassifyText(text); text = sampleClassifyText(text);
try { try {
switch (extension_settings.expressions.api) { switch (expressionsApi) {
// Local BERT pipeline // Local BERT pipeline
case EXPRESSION_API.local: { case EXPRESSION_API.local: {
const localResult = await fetch('/api/extra/classify', { const localResult = await fetch('/api/extra/classify', {
@ -2105,7 +2119,15 @@ function migrateSettings() {
})); }));
SlashCommandParser.addCommandObject(SlashCommand.fromProps({ SlashCommandParser.addCommandObject(SlashCommand.fromProps({
name: 'classify', name: 'classify',
callback: classifyCommand, callback: classifyCallback,
namedArgumentList: [
SlashCommandNamedArgument.fromProps({
name: 'api',
description: 'The Classifier API to classify with. If not specified, the configured one will be used.',
typeList: [ARGUMENT_TYPE.STRING],
enumList: Object.keys(EXPRESSION_API).map(api => new SlashCommandEnumValue(api, null, enumTypes.enum)),
}),
],
unnamedArgumentList: [ unnamedArgumentList: [
new SlashCommandArgument( new SlashCommandArgument(
'text', [ARGUMENT_TYPE.STRING], true, 'text', [ARGUMENT_TYPE.STRING], true,
@ -2116,6 +2138,9 @@ function migrateSettings() {
<div> <div>
Performs an emotion classification of the given text and returns a label. Performs an emotion classification of the given text and returns a label.
</div> </div>
<div>
Allows to specify which Classifier API to perform the classification with.
</div>
<div> <div>
<strong>Example:</strong> <strong>Example:</strong>
<ul> <ul>