/classify allows custom prompt for LLM api

This commit is contained in:
Wolfsblvt 2024-09-05 00:13:54 +02:00
parent a37b805a5d
commit 2472b26057

View File

@ -921,7 +921,7 @@ async function setSpriteSetCommand(_, folder) {
return ''; return '';
} }
async function classifyCallback(/** @type {{api: string?}} */ { api = null }, text) { async function classifyCallback(/** @type {{api: string?, prompt: string?}} */ { api = null, prompt = null }, text) {
if (!text) { if (!text) {
toastr.warning('No text provided'); toastr.warning('No text provided');
return ''; return '';
@ -938,7 +938,7 @@ async function classifyCallback(/** @type {{api: string?}} */ { api = null }, te
return ''; return '';
} }
const label = getExpressionLabel(text, expressionApi); const label = getExpressionLabel(text, expressionApi, { customPrompt: prompt });
console.debug(`Classification result for "${text}": ${label}`); console.debug(`Classification result for "${text}": ${label}`);
return label; return label;
} }
@ -1120,9 +1120,11 @@ function onTextGenSettingsReady(args) {
* Optionally allows to override the expressions API being used. * Optionally allows to override the expressions API being used.
* @param {string} text - The text to classify and retrieve the expression label for. * @param {string} text - The text to classify and retrieve the expression label for.
* @param {EXPRESSION_API} [expressionsApi=extension_settings.expressions.api] - The expressions API to use for classification. * @param {EXPRESSION_API} [expressionsApi=extension_settings.expressions.api] - The expressions API to use for classification.
* @param {object} [options={}] - Optional arguments.
* @param {string?} [options.customPrompt=null] - The custom prompt to use for classification.
* @returns {Promise<string>} - The label of the expression. * @returns {Promise<string>} - The label of the expression.
*/ */
export async function getExpressionLabel(text, expressionsApi = extension_settings.expressions.api) { export async function getExpressionLabel(text, expressionsApi = extension_settings.expressions.api, { customPrompt = null } = {}) {
// Return if text is undefined, saving a costly fetch request // Return if text is undefined, saving a costly fetch request
if ((!modules.includes('classify') && expressionsApi == EXPRESSION_API.extras) || !text) { if ((!modules.includes('classify') && expressionsApi == EXPRESSION_API.extras) || !text) {
return getFallbackExpression(); return getFallbackExpression();
@ -1159,7 +1161,7 @@ export async function getExpressionLabel(text, expressionsApi = extension_settin
} }
const expressionsList = await getExpressionsList(); const expressionsList = await getExpressionsList();
const prompt = await getLlmPrompt(expressionsList); const prompt = customPrompt || await getLlmPrompt(expressionsList);
let functionResult = null; let functionResult = null;
eventSource.once(event_types.TEXT_COMPLETION_SETTINGS_READY, onTextGenSettingsReady); eventSource.once(event_types.TEXT_COMPLETION_SETTINGS_READY, onTextGenSettingsReady);
eventSource.once(event_types.LLM_FUNCTION_TOOL_REGISTER, onFunctionToolRegister); eventSource.once(event_types.LLM_FUNCTION_TOOL_REGISTER, onFunctionToolRegister);
@ -2127,6 +2129,11 @@ function migrateSettings() {
typeList: [ARGUMENT_TYPE.STRING], typeList: [ARGUMENT_TYPE.STRING],
enumList: Object.keys(EXPRESSION_API).map(api => new SlashCommandEnumValue(api, null, enumTypes.enum)), enumList: Object.keys(EXPRESSION_API).map(api => new SlashCommandEnumValue(api, null, enumTypes.enum)),
}), }),
SlashCommandNamedArgument.fromProps({
name: 'prompt',
description: 'Custom prompt for classification. Only relevant if Classifier API is set to LLM.',
typeList: [ARGUMENT_TYPE.STRING],
}),
], ],
unnamedArgumentList: [ unnamedArgumentList: [
new SlashCommandArgument( new SlashCommandArgument(