Reference implementation: Set expressions with function calling

This commit is contained in:
Cohee 2024-05-25 15:38:32 +03:00
parent a20c6bb01e
commit dc8530049f
2 changed files with 67 additions and 5 deletions

View File

@ -1,14 +1,15 @@
import { callPopup, eventSource, event_types, generateQuietPrompt, getRequestHeaders, saveSettingsDebounced, substituteParams } from '../../../script.js'; import { callPopup, eventSource, event_types, generateQuietPrompt, getRequestHeaders, online_status, saveSettingsDebounced, substituteParams } from '../../../script.js';
import { dragElement, isMobile } from '../../RossAscends-mods.js'; import { dragElement, isMobile } from '../../RossAscends-mods.js';
import { getContext, getApiUrl, modules, extension_settings, ModuleWorkerWrapper, doExtrasFetch, renderExtensionTemplateAsync } from '../../extensions.js'; import { getContext, getApiUrl, modules, extension_settings, ModuleWorkerWrapper, doExtrasFetch, renderExtensionTemplateAsync } from '../../extensions.js';
import { loadMovingUIState, power_user } from '../../power-user.js'; import { loadMovingUIState, power_user } from '../../power-user.js';
import { onlyUnique, debounce, getCharaFilename, trimToEndSentence, trimToStartSentence } from '../../utils.js'; import { onlyUnique, debounce, getCharaFilename, trimToEndSentence, trimToStartSentence, waitUntilCondition } from '../../utils.js';
import { hideMutedSprites } from '../../group-chats.js'; import { hideMutedSprites } from '../../group-chats.js';
import { isJsonSchemaSupported } from '../../textgen-settings.js'; import { isJsonSchemaSupported } from '../../textgen-settings.js';
import { debounce_timeout } from '../../constants.js'; import { debounce_timeout } from '../../constants.js';
import { SlashCommandParser } from '../../slash-commands/SlashCommandParser.js'; import { SlashCommandParser } from '../../slash-commands/SlashCommandParser.js';
import { SlashCommand } from '../../slash-commands/SlashCommand.js'; import { SlashCommand } from '../../slash-commands/SlashCommand.js';
import { ARGUMENT_TYPE, SlashCommandArgument } from '../../slash-commands/SlashCommandArgument.js'; import { ARGUMENT_TYPE, SlashCommandArgument } from '../../slash-commands/SlashCommandArgument.js';
import { isFunctionCallingSupported } from '../../openai.js';
export { MODULE_NAME }; export { MODULE_NAME };
const MODULE_NAME = 'expressions'; const MODULE_NAME = 'expressions';
@ -16,6 +17,7 @@ const UPDATE_INTERVAL = 2000;
const STREAMING_UPDATE_INTERVAL = 10000; const STREAMING_UPDATE_INTERVAL = 10000;
const TALKINGCHECK_UPDATE_INTERVAL = 500; const TALKINGCHECK_UPDATE_INTERVAL = 500;
const DEFAULT_FALLBACK_EXPRESSION = 'joy'; const DEFAULT_FALLBACK_EXPRESSION = 'joy';
const FUNCTION_NAME = 'set_emotion';
const DEFAULT_LLM_PROMPT = 'Pause your roleplay. Classify the emotion of the last message. Output just one word, e.g. "joy" or "anger". Choose only one of the following labels: {{labels}}'; const DEFAULT_LLM_PROMPT = 'Pause your roleplay. Classify the emotion of the last message. Output just one word, e.g. "joy" or "anger". Choose only one of the following labels: {{labels}}';
const DEFAULT_EXPRESSIONS = [ const DEFAULT_EXPRESSIONS = [
'talkinghead', 'talkinghead',
@ -1001,6 +1003,10 @@ async function getLlmPrompt(labels) {
return ''; return '';
} }
if (isFunctionCallingSupported()) {
return '';
}
const labelsString = labels.map(x => `"${x}"`).join(', '); const labelsString = labels.map(x => `"${x}"`).join(', ');
const prompt = substituteParams(String(extension_settings.expressions.llmPrompt)) const prompt = substituteParams(String(extension_settings.expressions.llmPrompt))
.replace(/{{labels}}/gi, labelsString); .replace(/{{labels}}/gi, labelsString);
@ -1018,7 +1024,13 @@ function parseLlmResponse(emotionResponse, labels) {
try { try {
const parsedEmotion = JSON.parse(emotionResponse); const parsedEmotion = JSON.parse(emotionResponse);
return parsedEmotion?.emotion ?? fallbackExpression; const response = parsedEmotion?.emotion;
if (!response || !labels.includes(response)) {
return fallbackExpression;
}
return response;
} catch { } catch {
const fuse = new Fuse(labels, { includeScore: true }); const fuse = new Fuse(labels, { includeScore: true });
console.debug('Using fuzzy search in labels:', labels); console.debug('Using fuzzy search in labels:', labels);
@ -1032,6 +1044,40 @@ function parseLlmResponse(emotionResponse, labels) {
throw new Error('Could not parse emotion response ' + emotionResponse); throw new Error('Could not parse emotion response ' + emotionResponse);
} }
/**
 * Registers the emotion-setting function tool for the LLM API.
 *
 * The tool is registered only when all of the following hold:
 *  - an expression API call is in progress (`inApiCall`),
 *  - the expressions extension is configured to use the LLM API,
 *  - the current backend reports function calling support,
 *  - the generation type is 'quiet'.
 * In every other case the function returns without touching `args`.
 *
 * The tool accepts a single required string property, `emotion`, restricted
 * to the default expression labels with 'talkinghead' removed.
 *
 * @param {FunctionToolRegister} args Function tool register arguments.
 * @returns {void}
 */
function onFunctionToolRegister(args) {
    // Evaluated left to right, so settings are only read while inside an API call.
    const shouldRegister = inApiCall
        && extension_settings.expressions.api === EXPRESSION_API.llm
        && isFunctionCallingSupported();

    if (!shouldRegister) {
        return;
    }

    // Only trigger on quiet mode
    if (args.type !== 'quiet') {
        return;
    }

    // 'talkinghead' is left out of the labels offered to the model.
    const emotions = DEFAULT_EXPRESSIONS.filter((label) => label !== 'talkinghead');

    const jsonSchema = {
        $schema: 'http://json-schema.org/draft-04/schema#',
        type: 'object',
        properties: {
            emotion: {
                type: 'string',
                enum: emotions,
            },
        },
        required: [
            'emotion',
        ],
    };

    const description = substituteParams('Sets the label that best describes the current emotional state of {{char}}. Only select one of the enumerated values.');

    // NOTE(review): the meaning of the trailing `true` is defined by registerFunctionTool,
    // which is not visible here - presumably it marks the tool as required; confirm.
    args.registerFunctionTool(
        FUNCTION_NAME,
        description,
        jsonSchema,
        true,
    );
}
function onTextGenSettingsReady(args) { function onTextGenSettingsReady(args) {
// Only call if inside an API call // Only call if inside an API call
if (inApiCall && extension_settings.expressions.api === EXPRESSION_API.llm && isJsonSchemaSupported()) { if (inApiCall && extension_settings.expressions.api === EXPRESSION_API.llm && isJsonSchemaSupported()) {
@ -1087,11 +1133,27 @@ async function getExpressionLabel(text) {
} break; } break;
// Using LLM // Using LLM
case EXPRESSION_API.llm: { case EXPRESSION_API.llm: {
try {
await waitUntilCondition(() => online_status !== 'no_connection', 3000, 250);
} catch (error) {
console.warn('No LLM connection. Using fallback expression', error);
return getFallbackExpression();
}
const expressionsList = await getExpressionsList(); const expressionsList = await getExpressionsList();
const prompt = await getLlmPrompt(expressionsList); const prompt = await getLlmPrompt(expressionsList);
let functionResult = null;
eventSource.once(event_types.TEXT_COMPLETION_SETTINGS_READY, onTextGenSettingsReady); eventSource.once(event_types.TEXT_COMPLETION_SETTINGS_READY, onTextGenSettingsReady);
eventSource.once(event_types.LLM_FUNCTION_TOOL_REGISTER, onFunctionToolRegister);
eventSource.once(event_types.LLM_FUNCTION_TOOL_CALL, (/** @type {FunctionToolCall} */ args) => {
if (args.name !== FUNCTION_NAME) {
return;
}
functionResult = args?.arguments;
});
const emotionResponse = await generateQuietPrompt(prompt, false, false); const emotionResponse = await generateQuietPrompt(prompt, false, false);
return parseLlmResponse(emotionResponse, expressionsList); return parseLlmResponse(functionResult || emotionResponse, expressionsList);
} }
// Extras // Extras
default: { default: {

View File

@ -34,7 +34,7 @@
<i class="fa-solid fa-clock-rotate-left fa-sm"></i> <i class="fa-solid fa-clock-rotate-left fa-sm"></i>
</div> </div>
</label> </label>
<small>Will be used if the API doesn't support JSON schemas.</small> <small>Will be used if the API doesn't support JSON schemas or function calling.</small>
<textarea id="expression_llm_prompt" type="text" class="text_pole textarea_compact" rows="2" placeholder="Use &lcub;&lcub;labels&rcub;&rcub; special macro."></textarea> <textarea id="expression_llm_prompt" type="text" class="text_pole textarea_compact" rows="2" placeholder="Use &lcub;&lcub;labels&rcub;&rcub; special macro."></textarea>
</div> </div>
<div class="expression_fallback_block m-b-1 m-t-1"> <div class="expression_fallback_block m-b-1 m-t-1">