Specify LLM prompt in case JSON schema is not supported

This commit is contained in:
Cohee
2024-04-14 17:13:54 +03:00
parent b02394008c
commit 3e60919289
3 changed files with 80 additions and 14 deletions

View File

@ -5,13 +5,15 @@ import { loadMovingUIState, power_user } from '../../power-user.js';
import { registerSlashCommand } from '../../slash-commands.js'; import { registerSlashCommand } from '../../slash-commands.js';
import { onlyUnique, debounce, getCharaFilename, trimToEndSentence, trimToStartSentence } from '../../utils.js'; import { onlyUnique, debounce, getCharaFilename, trimToEndSentence, trimToStartSentence } from '../../utils.js';
import { hideMutedSprites } from '../../group-chats.js'; import { hideMutedSprites } from '../../group-chats.js';
import { isJsonSchemaSupported } from '../../textgen-settings.js';
export { MODULE_NAME }; export { MODULE_NAME };
const MODULE_NAME = 'expressions'; const MODULE_NAME = 'expressions';
const UPDATE_INTERVAL = 2000; const UPDATE_INTERVAL = 2000;
const STREAMING_UPDATE_INTERVAL = 6000; const STREAMING_UPDATE_INTERVAL = 10000;
const TALKINGCHECK_UPDATE_INTERVAL = 500; const TALKINGCHECK_UPDATE_INTERVAL = 500;
const DEFAULT_FALLBACK_EXPRESSION = 'joy'; const DEFAULT_FALLBACK_EXPRESSION = 'joy';
const DEFAULT_LLM_PROMPT = 'Pause your roleplay. Classify the emotion of the last message. Output just one word, e.g. "joy" or "anger". Choose only one of the following labels: {{labels}}';
const DEFAULT_EXPRESSIONS = [ const DEFAULT_EXPRESSIONS = [
'talkinghead', 'talkinghead',
'admiration', 'admiration',
@ -976,9 +978,49 @@ function sampleClassifyText(text) {
return result.trim(); return result.trim();
} }
/**
* Gets the classification prompt for the LLM API.
* @param {string[]} labels A list of labels to search for.
* @returns {Promise<string>} Prompt for the LLM API.
*/
async function getLlmPrompt(labels) {
if (isJsonSchemaSupported()) {
return '';
}
const prompt = String(extension_settings.expressions.llmPrompt).replace(/{{labels}}/gi, labels.map(x => `"${x}"`).join(', '));
return prompt;
}
/**
* Parses the emotion response from the LLM API.
* @param {string} emotionResponse The response from the LLM API.
* @param {string[]} labels A list of labels to search for.
* @returns {string} The parsed emotion or the fallback expression.
*/
function parseLlmResponse(emotionResponse, labels) {
const fallbackExpression = getFallbackExpression();
try {
const parsedEmotion = JSON.parse(emotionResponse);
return parsedEmotion?.emotion ?? fallbackExpression;
} catch {
const fuse = new Fuse([emotionResponse]);
for (const label of labels) {
const result = fuse.search(label);
if (result.length > 0) {
return label;
}
}
}
return fallbackExpression;
}
function onTextGenSettingsReady(args) { function onTextGenSettingsReady(args) {
// Only call if inside an API call // Only call if inside an API call
if (inApiCall) { if (inApiCall && extension_settings.expressions.api === EXPRESSION_API.llm && isJsonSchemaSupported()) {
const emotions = DEFAULT_EXPRESSIONS.filter((e) => e != 'talkinghead'); const emotions = DEFAULT_EXPRESSIONS.filter((e) => e != 'talkinghead');
Object.assign(args, { Object.assign(args, {
top_k: 1, top_k: 1,
@ -1016,8 +1058,8 @@ async function getExpressionLabel(text) {
try { try {
switch (extension_settings.expressions.api) { switch (extension_settings.expressions.api) {
case EXPRESSION_API.local: // Local BERT pipeline
// Local BERT pipeline case EXPRESSION_API.local: {
const localResult = await fetch('/api/extra/classify', { const localResult = await fetch('/api/extra/classify', {
method: 'POST', method: 'POST',
headers: getRequestHeaders(), headers: getRequestHeaders(),
@ -1028,15 +1070,16 @@ async function getExpressionLabel(text) {
const data = await localResult.json(); const data = await localResult.json();
return data.classification[0].label; return data.classification[0].label;
} }
} break;
break; // Using LLM
case EXPRESSION_API.llm: case EXPRESSION_API.llm: {
// Using LLM const expressionsList = await getExpressionsList();
const emotionResponse = await generateQuietPrompt('', false); const prompt = await getLlmPrompt(expressionsList);
const parsedEmotion = JSON.parse(emotionResponse); const emotionResponse = await generateQuietPrompt(prompt, false, false);
return parsedEmotion.emotion; return parseLlmResponse(emotionResponse, expressionsList);
default: }
// Extras // Extras
default: {
const url = new URL(getApiUrl()); const url = new URL(getApiUrl());
url.pathname = '/api/classify'; url.pathname = '/api/classify';
@ -1053,6 +1096,7 @@ async function getExpressionLabel(text) {
const data = await extrasResult.json(); const data = await extrasResult.json();
return data.classification[0].label; return data.classification[0].label;
} }
} break;
} }
} catch (error) { } catch (error) {
toastr.info('Could not classify expression. Check the console or your backend for more information.'); toastr.info('Could not classify expression. Check the console or your backend for more information.');
@ -1488,6 +1532,7 @@ function onExperesionApiChanged() {
const tempApi = this.value; const tempApi = this.value;
if (tempApi) { if (tempApi) {
extension_settings.expressions.api = Number(tempApi); extension_settings.expressions.api = Number(tempApi);
$('.expression_llm_prompt_block').toggle(extension_settings.expressions.api === EXPRESSION_API.llm);
moduleWorker(); moduleWorker();
saveSettingsDebounced(); saveSettingsDebounced();
} }
@ -1760,6 +1805,11 @@ function migrateSettings() {
delete extension_settings.expressions.local; delete extension_settings.expressions.local;
saveSettingsDebounced(); saveSettingsDebounced();
} }
if (extension_settings.expressions.llmPrompt === undefined) {
extension_settings.expressions.llmPrompt = DEFAULT_LLM_PROMPT;
saveSettingsDebounced();
}
} }
(async function () { (async function () {
@ -1811,7 +1861,13 @@ function migrateSettings() {
}); });
await renderAdditionalExpressionSettings(); await renderAdditionalExpressionSettings();
$('#expression_api').val(extension_settings.expressions.api || EXPRESSION_API.extras); $('#expression_api').val(extension_settings.expressions.api ?? EXPRESSION_API.extras);
$('.expression_llm_prompt_block').toggle(extension_settings.expressions.api === EXPRESSION_API.llm);
$('#expression_llm_prompt').val(extension_settings.expressions.llmPrompt ?? '');
$('#expression_llm_prompt').on('input', function () {
extension_settings.expressions.llmPrompt = $(this).val();
saveSettingsDebounced();
});
$('#expression_custom_add').on('click', onClickExpressionAddCustom); $('#expression_custom_add').on('click', onClickExpressionAddCustom);
$('#expression_custom_remove').on('click', onClickExpressionRemoveCustom); $('#expression_custom_remove').on('click', onClickExpressionRemoveCustom);

View File

@ -27,6 +27,11 @@
<option value="2">LLM</option> <option value="2">LLM</option>
</select> </select>
</div> </div>
<div class="expression_llm_prompt_block m-b-1 m-t-1">
<label for="expression_llm_prompt">LLM Prompt</label>
<small>Will be used if the API doesn't support JSON schemas.</small>
<textarea id="expression_llm_prompt" type="text" class="text_pole" rows="2"></textarea>
</div>
<div class="expression_fallback_block m-b-1 m-t-1"> <div class="expression_fallback_block m-b-1 m-t-1">
<label for="expression_fallback">Default / Fallback Expression</label> <label for="expression_fallback">Default / Fallback Expression</label>
<small>Set the default and fallback expression being used when no matching expression is found.</small> <small>Set the default and fallback expression being used when no matching expression is found.</small>

View File

@ -3,6 +3,7 @@ import {
event_types, event_types,
getRequestHeaders, getRequestHeaders,
getStoppingStrings, getStoppingStrings,
main_api,
max_context, max_context,
saveSettingsDebounced, saveSettingsDebounced,
setGenerationParamsFromPreset, setGenerationParamsFromPreset,
@ -978,6 +979,10 @@ function getModel() {
return undefined; return undefined;
} }
export function isJsonSchemaSupported() {
return settings.type === TABBY && main_api === 'textgenerationwebui';
}
export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate, isContinue, cfgValues, type) { export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate, isContinue, cfgValues, type) {
const canMultiSwipe = !isContinue && !isImpersonate && type !== 'quiet'; const canMultiSwipe = !isContinue && !isImpersonate && type !== 'quiet';
let params = { let params = {