diff --git a/public/scripts/extensions/expressions/index.js b/public/scripts/extensions/expressions/index.js index da527dd8b..c70eeaef0 100644 --- a/public/scripts/extensions/expressions/index.js +++ b/public/scripts/extensions/expressions/index.js @@ -1,14 +1,15 @@ -import { callPopup, eventSource, event_types, generateQuietPrompt, getRequestHeaders, saveSettingsDebounced, substituteParams } from '../../../script.js'; +import { callPopup, eventSource, event_types, generateQuietPrompt, getRequestHeaders, online_status, saveSettingsDebounced, substituteParams } from '../../../script.js'; import { dragElement, isMobile } from '../../RossAscends-mods.js'; import { getContext, getApiUrl, modules, extension_settings, ModuleWorkerWrapper, doExtrasFetch, renderExtensionTemplateAsync } from '../../extensions.js'; import { loadMovingUIState, power_user } from '../../power-user.js'; -import { onlyUnique, debounce, getCharaFilename, trimToEndSentence, trimToStartSentence } from '../../utils.js'; +import { onlyUnique, debounce, getCharaFilename, trimToEndSentence, trimToStartSentence, waitUntilCondition } from '../../utils.js'; import { hideMutedSprites } from '../../group-chats.js'; import { isJsonSchemaSupported } from '../../textgen-settings.js'; import { debounce_timeout } from '../../constants.js'; import { SlashCommandParser } from '../../slash-commands/SlashCommandParser.js'; import { SlashCommand } from '../../slash-commands/SlashCommand.js'; import { ARGUMENT_TYPE, SlashCommandArgument } from '../../slash-commands/SlashCommandArgument.js'; +import { isFunctionCallingSupported } from '../../openai.js'; export { MODULE_NAME }; const MODULE_NAME = 'expressions'; @@ -16,6 +17,7 @@ const UPDATE_INTERVAL = 2000; const STREAMING_UPDATE_INTERVAL = 10000; const TALKINGCHECK_UPDATE_INTERVAL = 500; const DEFAULT_FALLBACK_EXPRESSION = 'joy'; +const FUNCTION_NAME = 'set_emotion'; const DEFAULT_LLM_PROMPT = 'Pause your roleplay. Classify the emotion of the last message. Output just one word, e.g. "joy" or "anger". Choose only one of the following labels: {{labels}}'; const DEFAULT_EXPRESSIONS = [ 'talkinghead', @@ -1001,6 +1003,10 @@ async function getLlmPrompt(labels) { return ''; } + if (isFunctionCallingSupported()) { + return ''; + } + const labelsString = labels.map(x => `"${x}"`).join(', '); const prompt = substituteParams(String(extension_settings.expressions.llmPrompt)) .replace(/{{labels}}/gi, labelsString); @@ -1018,7 +1024,13 @@ function parseLlmResponse(emotionResponse, labels) { try { const parsedEmotion = JSON.parse(emotionResponse); - return parsedEmotion?.emotion ?? fallbackExpression; + const response = parsedEmotion?.emotion; + + if (!response || !labels.includes(response)) { + return fallbackExpression; + } + + return response; } catch { const fuse = new Fuse(labels, { includeScore: true }); console.debug('Using fuzzy search in labels:', labels); @@ -1032,6 +1044,40 @@ function parseLlmResponse(emotionResponse, labels) { throw new Error('Could not parse emotion response ' + emotionResponse); } +/** + * Registers the function tool for the LLM API. + * @param {FunctionToolRegister} args Function tool register arguments. + */ +function onFunctionToolRegister(args) { + if (inApiCall && extension_settings.expressions.api === EXPRESSION_API.llm && isFunctionCallingSupported()) { + // Only trigger on quiet mode + if (args.type !== 'quiet') { + return; + } + + const emotions = DEFAULT_EXPRESSIONS.filter((e) => e != 'talkinghead'); + const jsonSchema = { + $schema: 'http://json-schema.org/draft-04/schema#', + type: 'object', + properties: { + emotion: { + type: 'string', + enum: emotions, + }, + }, + required: [ + 'emotion', + ], + }; + args.registerFunctionTool( + FUNCTION_NAME, + substituteParams('Sets the label that best describes the current emotional state of {{char}}. Only select one of the enumerated values.'), + jsonSchema, + true, + ); + } +} + function onTextGenSettingsReady(args) { // Only call if inside an API call if (inApiCall && extension_settings.expressions.api === EXPRESSION_API.llm && isJsonSchemaSupported()) { @@ -1087,11 +1133,27 @@ async function getExpressionLabel(text) { } break; // Using LLM case EXPRESSION_API.llm: { + try { + await waitUntilCondition(() => online_status !== 'no_connection', 3000, 250); + } catch (error) { + console.warn('No LLM connection. Using fallback expression', error); + return getFallbackExpression(); + } + const expressionsList = await getExpressionsList(); const prompt = await getLlmPrompt(expressionsList); + let functionResult = null; eventSource.once(event_types.TEXT_COMPLETION_SETTINGS_READY, onTextGenSettingsReady); + eventSource.once(event_types.LLM_FUNCTION_TOOL_REGISTER, onFunctionToolRegister); + eventSource.once(event_types.LLM_FUNCTION_TOOL_CALL, (/** @type {FunctionToolCall} */ args) => { + if (args.name !== FUNCTION_NAME) { + return; + } + + functionResult = args?.arguments; + }); const emotionResponse = await generateQuietPrompt(prompt, false, false); - return parseLlmResponse(emotionResponse, expressionsList); + return parseLlmResponse(functionResult || emotionResponse, expressionsList); } // Extras default: { diff --git a/public/scripts/extensions/expressions/settings.html b/public/scripts/extensions/expressions/settings.html index abb0219ce..e819d3bfd 100644 --- a/public/scripts/extensions/expressions/settings.html +++ b/public/scripts/extensions/expressions/settings.html @@ -34,7 +34,7 @@ - Will be used if the API doesn't support JSON schemas. + Will be used if the API doesn't support JSON schemas or function calling.