diff --git a/public/scripts/extensions/expressions/index.js b/public/scripts/extensions/expressions/index.js index fdc146c1a..e31807893 100644 --- a/public/scripts/extensions/expressions/index.js +++ b/public/scripts/extensions/expressions/index.js @@ -15,6 +15,7 @@ import { SlashCommandEnumValue, enumTypes } from '../../slash-commands/SlashComm import { commonEnumProviders } from '../../slash-commands/SlashCommandCommonEnumsProvider.js'; import { slashCommandReturnHelper } from '../../slash-commands/SlashCommandReturnHelper.js'; import { SlashCommandClosure } from '../../slash-commands/SlashCommandClosure.js'; +import { generateWebLlmChatPrompt, isWebLlmSupported } from '../shared.js'; export { MODULE_NAME }; const MODULE_NAME = 'expressions'; @@ -59,6 +60,7 @@ const EXPRESSION_API = { local: 0, extras: 1, llm: 2, + webllm: 3, }; let expressionsList = null; @@ -698,8 +700,8 @@ async function moduleWorker() { } // If using LLM api then check if streamingProcessor is finished to avoid sending multiple requests to the API - if (extension_settings.expressions.api === EXPRESSION_API.llm && context.streamingProcessor && !context.streamingProcessor.isFinished) { - return; + if (extension_settings.expressions.api === EXPRESSION_API.llm && context.streamingProcessor && !context.streamingProcessor.isFinished) { + return; } // API is busy @@ -852,7 +854,7 @@ function setTalkingHeadState(newState) { extension_settings.expressions.talkinghead = newState; // Store setting saveSettingsDebounced(); - if (extension_settings.expressions.api == EXPRESSION_API.local || extension_settings.expressions.api == EXPRESSION_API.llm) { + if ([EXPRESSION_API.local, EXPRESSION_API.llm, EXPRESSION_API.webllm].includes(extension_settings.expressions.api)) { return; } @@ -1057,11 +1059,39 @@ function parseLlmResponse(emotionResponse, labels) { console.debug(`fuzzy search found: ${result[0].item} as closest for the LLM response:`, emotionResponse); return result[0].item; } + const lowerCaseResponse = String(emotionResponse || '').toLowerCase(); + for (const label of labels) { + if (lowerCaseResponse.includes(label.toLowerCase())) { + console.debug(`Found label ${label} in the LLM response:`, emotionResponse); + return label; + } + } } throw new Error('Could not parse emotion response ' + emotionResponse); } +/** + * Gets the JSON schema for the LLM API. + * @param {string[]} emotions A list of emotions to search for. + * @returns {object} The JSON schema for the LLM API. + */ +function getJsonSchema(emotions) { + return { + $schema: 'http://json-schema.org/draft-04/schema#', + type: 'object', + properties: { + emotion: { + type: 'string', + enum: emotions, + }, + }, + required: [ + 'emotion', + ], + }; +} + function onTextGenSettingsReady(args) { // Only call if inside an API call if (inApiCall && extension_settings.expressions.api === EXPRESSION_API.llm && isJsonSchemaSupported()) { @@ -1071,19 +1101,7 @@ function onTextGenSettingsReady(args) { stop: [], stopping_strings: [], custom_token_bans: [], - json_schema: { - $schema: 'http://json-schema.org/draft-04/schema#', - type: 'object', - properties: { - emotion: { - type: 'string', - enum: emotions, - }, - }, - required: [ - 'emotion', - ], - }, + json_schema: getJsonSchema(emotions), }); } } @@ -1139,6 +1157,22 @@ export async function getExpressionLabel(text, expressionsApi = extension_settin const emotionResponse = await generateRaw(text, main_api, false, false, prompt); return parseLlmResponse(emotionResponse, expressionsList); } + // Using WebLLM + case EXPRESSION_API.webllm: { + if (!isWebLlmSupported()) { + console.warn('WebLLM is not supported. Using fallback expression'); + return getFallbackExpression(); + } + + const expressionsList = await getExpressionsList(); + const prompt = substituteParamsExtended(customPrompt, { labels: expressionsList }) || await getLlmPrompt(expressionsList); + const messages = [ + { role: 'user', content: text + '\n\n' + prompt }, + ]; + + const emotionResponse = await generateWebLlmChatPrompt(messages); + return parseLlmResponse(emotionResponse, expressionsList); + } // Extras default: { const url = new URL(getApiUrl()); @@ -1603,7 +1637,7 @@ function onExpressionApiChanged() { const tempApi = this.value; if (tempApi) { extension_settings.expressions.api = Number(tempApi); - $('.expression_llm_prompt_block').toggle(extension_settings.expressions.api === EXPRESSION_API.llm); + $('.expression_llm_prompt_block').toggle([EXPRESSION_API.llm, EXPRESSION_API.webllm].includes(extension_settings.expressions.api)); expressionsList = null; spriteCache = {}; moduleWorker(); @@ -1940,7 +1974,7 @@ function migrateSettings() { await renderAdditionalExpressionSettings(); $('#expression_api').val(extension_settings.expressions.api ?? EXPRESSION_API.extras); - $('.expression_llm_prompt_block').toggle(extension_settings.expressions.api === EXPRESSION_API.llm); + $('.expression_llm_prompt_block').toggle([EXPRESSION_API.llm, EXPRESSION_API.webllm].includes(extension_settings.expressions.api)); $('#expression_llm_prompt').val(extension_settings.expressions.llmPrompt ?? ''); $('#expression_llm_prompt').on('input', function () { extension_settings.expressions.llmPrompt = $(this).val(); diff --git a/public/scripts/extensions/expressions/settings.html b/public/scripts/extensions/expressions/settings.html index f2b7b79ac..dc22debbd 100644 --- a/public/scripts/extensions/expressions/settings.html +++ b/public/scripts/extensions/expressions/settings.html @@ -24,7 +24,8 @@