Refactor Novel logit bias

This commit is contained in:
Cohee
2023-12-18 17:32:10 +02:00
parent cc27bcb076
commit 08ea2095f8
4 changed files with 148 additions and 113 deletions

View File

@ -978,7 +978,7 @@
Helps to ban or reinforce the usage of certain tokens. Helps to ban or reinforce the usage of certain tokens.
</div> </div>
<div class="flex-container flexFlowColumn wide100p"> <div class="flex-container flexFlowColumn wide100p">
<div class="novelai_logit_bias_list"></div> <div class="logit_bias_list"></div>
</div> </div>
</div> </div>
<div class="range-block"> <div class="range-block">
@ -4272,11 +4272,11 @@
</div> </div>
</div> </div>
</div> </div>
<div id="novelai_logit_bias_template" class="template_element"> <div id="logit_bias_template" class="template_element">
<div class="novelai_logit_bias_form"> <div class="logit_bias_form">
<input class="novelai_logit_bias_text text_pole" data-i18n="[placeholder]Type here..." placeholder="type here..." /> <input class="logit_bias_text text_pole" data-i18n="[placeholder]Type here..." placeholder="type here..." />
<input class="novelai_logit_bias_value text_pole" type="number" min="-2" value="0" max="2" step="0.01" /> <input class="logit_bias_value text_pole" type="number" min="-2" value="0" max="2" step="0.01" />
<i class="menu_button fa-solid fa-xmark novelai_logit_bias_remove"></i> <i class="menu_button fa-solid fa-xmark logit_bias_remove"></i>
</div> </div>
</div> </div>
<div id="completion_prompt_manager_popup" class="drawer-content" style="display:none;"> <div id="completion_prompt_manager_popup" class="drawer-content" style="display:none;">

View File

@ -0,0 +1,126 @@
import { saveSettingsDebounced } from '../script.js';
import { getTextTokens } from './tokenizers.js';
import { uuidv4 } from './utils.js';
export const BIAS_CACHE = new Map();
/**
* Displays the logit bias list in the specified container.
* @param {object} logitBias Logit bias object
* @param {string} containerSelector Container element selector
* @returns
*/
export function displayLogitBias(logitBias, containerSelector) {
if (!Array.isArray(logitBias)) {
console.log('Logit bias set not found');
return;
}
$(containerSelector).find('.logit_bias_list').empty();
for (const entry of logitBias) {
if (entry) {
createLogitBiasListItem(entry, logitBias, containerSelector);
}
}
BIAS_CACHE.delete(containerSelector);
}
/**
* Creates a new logit bias entry
* @param {object[]} logitBias Array of logit bias objects
* @param {string} containerSelector Container element ID
*/
export function createNewLogitBiasEntry(logitBias, containerSelector) {
const entry = { id: uuidv4(), text: '', value: 0 };
logitBias.push(entry);
BIAS_CACHE.delete(containerSelector);
createLogitBiasListItem(entry, logitBias, containerSelector);
saveSettingsDebounced();
}
/**
* Creates a logit bias list item.
* @param {object} entry Logit bias entry
* @param {object[]} logitBias Array of logit bias objects
* @param {string} containerSelector Container element ID
*/
function createLogitBiasListItem(entry, logitBias, containerSelector) {
const id = entry.id;
const template = $('#logit_bias_template .logit_bias_form').clone();
template.data('id', id);
template.find('.logit_bias_text').val(entry.text).on('input', function () {
entry.text = $(this).val();
BIAS_CACHE.delete(containerSelector);
saveSettingsDebounced();
});
template.find('.logit_bias_value').val(entry.value).on('input', function () {
entry.value = Number($(this).val());
BIAS_CACHE.delete(containerSelector);
saveSettingsDebounced();
});
template.find('.logit_bias_remove').on('click', function () {
$(this).closest('.logit_bias_form').remove();
const index = logitBias.indexOf(entry);
if (index > -1) {
logitBias.splice(index, 1);
}
BIAS_CACHE.delete(containerSelector);
saveSettingsDebounced();
});
$(containerSelector).find('.logit_bias_list').prepend(template);
}
/**
* Populate logit bias list from preset.
* @param {object[]} biasPreset Bias preset
* @param {number} tokenizerType Tokenizer type (see tokenizers.js)
* @param {(bias: number, sequence: number[]) => object} getBiasObject Transformer function to create bias object
* @returns {object[]} Array of logit bias objects
*/
export function getLogitBiasListResult(biasPreset, tokenizerType, getBiasObject) {
const result = [];
for (const entry of biasPreset) {
if (entry.text?.length > 0) {
const text = entry.text.trim();
// Skip empty lines
if (text.length === 0) {
continue;
}
// Verbatim text
if (text.startsWith('{') && text.endsWith('}')) {
const tokens = getTextTokens(tokenizerType, text.slice(1, -1));
result.push(getBiasObject(entry.value, tokens));
}
// Raw token ids, JSON serialized
else if (text.startsWith('[') && text.endsWith(']')) {
try {
const tokens = JSON.parse(text);
if (Array.isArray(tokens) && tokens.every(t => Number.isInteger(t))) {
result.push(getBiasObject(entry.value, tokens));
} else {
throw new Error('Not an array of integers');
}
} catch (err) {
console.log(`Failed to parse logit bias token list: ${text}`, err);
}
}
// Text with a leading space
else {
const biasText = ` ${text}`;
const tokens = getTextTokens(tokenizerType, biasText);
result.push(getBiasObject(entry.value, tokens));
}
}
}
return result;
}

View File

@ -15,8 +15,8 @@ import {
getSortableDelay, getSortableDelay,
getStringHash, getStringHash,
onlyUnique, onlyUnique,
uuidv4,
} from './utils.js'; } from './utils.js';
import { BIAS_CACHE, createNewLogitBiasEntry, displayLogitBias, getLogitBiasListResult } from './logit-bias.js';
const default_preamble = '[ Style: chat, complex, sensory, visceral ]'; const default_preamble = '[ Style: chat, complex, sensory, visceral ]';
const default_order = [1, 5, 0, 2, 3, 4]; const default_order = [1, 5, 0, 2, 3, 4];
@ -59,7 +59,7 @@ const nai_tiers = {
let novel_data = null; let novel_data = null;
let badWordsCache = {}; let badWordsCache = {};
let biasCache = undefined; const BIAS_KEY = '#novel_api-settings';
export function setNovelData(data) { export function setNovelData(data) {
novel_data = data; novel_data = data;
@ -145,7 +145,7 @@ export function loadNovelSettings(settings) {
//load the rest of the Novel settings without any checks //load the rest of the Novel settings without any checks
nai_settings.model_novel = settings.model_novel; nai_settings.model_novel = settings.model_novel;
$('#model_novel_select').val(nai_settings.model_novel); $('#model_novel_select').val(nai_settings.model_novel);
$(`#model_novel_select option[value=${nai_settings.model_novel}]`).attr('selected', true); $(`#model_novel_select option[value=${nai_settings.model_novel}]`).prop('selected', true);
if (settings.nai_preamble !== undefined) { if (settings.nai_preamble !== undefined) {
nai_settings.preamble = settings.nai_preamble; nai_settings.preamble = settings.nai_preamble;
@ -217,7 +217,7 @@ function loadNovelSettingsUi(ui_settings) {
$('#streaming_novel').prop('checked', ui_settings.streaming_novel); $('#streaming_novel').prop('checked', ui_settings.streaming_novel);
sortItemsByOrder(ui_settings.order); sortItemsByOrder(ui_settings.order);
displayLogitBias(ui_settings.logit_bias); displayLogitBias(ui_settings.logit_bias, BIAS_KEY);
} }
const sliders = [ const sliders = [
@ -433,8 +433,8 @@ export function getNovelGenerationData(finalPrompt, settings, maxLength, isImper
let logitBias = []; let logitBias = [];
if (tokenizerType !== tokenizers.NONE && Array.isArray(nai_settings.logit_bias) && nai_settings.logit_bias.length) { if (tokenizerType !== tokenizers.NONE && Array.isArray(nai_settings.logit_bias) && nai_settings.logit_bias.length) {
logitBias = biasCache || calculateLogitBias(); logitBias = BIAS_CACHE.get(BIAS_KEY) || calculateLogitBias();
biasCache = logitBias; BIAS_CACHE.set(BIAS_KEY, logitBias);
} }
return { return {
@ -525,65 +525,14 @@ function saveSamplingOrder() {
saveSettingsDebounced(); saveSettingsDebounced();
} }
function displayLogitBias(logit_bias) {
if (!Array.isArray(logit_bias)) {
console.log('Logit bias set not found');
return;
}
$('.novelai_logit_bias_list').empty();
for (const entry of logit_bias) {
if (entry) {
createLogitBiasListItem(entry);
}
}
biasCache = undefined;
}
function createNewLogitBiasEntry() {
const entry = { id: uuidv4(), text: '', value: 0 };
nai_settings.logit_bias.push(entry);
biasCache = undefined;
createLogitBiasListItem(entry);
saveSettingsDebounced();
}
function createLogitBiasListItem(entry) {
const id = entry.id;
const template = $('#novelai_logit_bias_template .novelai_logit_bias_form').clone();
template.data('id', id);
template.find('.novelai_logit_bias_text').val(entry.text).on('input', function () {
entry.text = $(this).val();
biasCache = undefined;
saveSettingsDebounced();
});
template.find('.novelai_logit_bias_value').val(entry.value).on('input', function () {
entry.value = Number($(this).val());
biasCache = undefined;
saveSettingsDebounced();
});
template.find('.novelai_logit_bias_remove').on('click', function () {
$(this).closest('.novelai_logit_bias_form').remove();
const index = nai_settings.logit_bias.indexOf(entry);
if (index > -1) {
nai_settings.logit_bias.splice(index, 1);
}
biasCache = undefined;
saveSettingsDebounced();
});
$('.novelai_logit_bias_list').prepend(template);
}
/** /**
* Calculates logit bias for Novel AI * Calculates logit bias for Novel AI
* @returns {object[]} Array of logit bias objects * @returns {object[]} Array of logit bias objects
*/ */
function calculateLogitBias() { function calculateLogitBias() {
const bias_preset = nai_settings.logit_bias; const biasPreset = nai_settings.logit_bias;
if (!Array.isArray(bias_preset) || bias_preset.length === 0) { if (!Array.isArray(biasPreset) || biasPreset.length === 0) {
return []; return [];
} }
@ -605,47 +554,7 @@ function calculateLogitBias() {
}; };
} }
const result = []; const result = getLogitBiasListResult(biasPreset, tokenizerType, getBiasObject);
for (const entry of bias_preset) {
if (entry.text?.length > 0) {
const text = entry.text.trim();
// Skip empty lines
if (text.length === 0) {
continue;
}
// Verbatim text
if (text.startsWith('{') && text.endsWith('}')) {
const tokens = getTextTokens(tokenizerType, text.slice(1, -1));
result.push(getBiasObject(entry.value, tokens));
}
// Raw token ids, JSON serialized
else if (text.startsWith('[') && text.endsWith(']')) {
try {
const tokens = JSON.parse(text);
if (Array.isArray(tokens) && tokens.every(t => Number.isInteger(t))) {
result.push(getBiasObject(entry.value, tokens));
} else {
throw new Error('Not an array of integers');
}
} catch (err) {
console.log(`Failed to parse logit bias token list: ${text}`, err);
}
}
// Text with a leading space
else {
const biasText = ` ${text}`;
const tokens = getTextTokens(tokenizerType, biasText);
result.push(getBiasObject(entry.value, tokens));
}
}
}
return result; return result;
} }
@ -778,5 +687,5 @@ jQuery(function () {
saveSamplingOrder(); saveSamplingOrder();
}); });
$('#novelai_logit_bias_new_entry').on('click', createNewLogitBiasEntry); $('#novelai_logit_bias_new_entry').on('click', () => createNewLogitBiasEntry(nai_settings.logit_bias, BIAS_KEY));
}); });

View File

@ -3448,30 +3448,30 @@ a {
height: 100%; height: 100%;
} }
.novelai_logit_bias_form { .logit_bias_form {
display: flex; display: flex;
flex-direction: row; flex-direction: row;
column-gap: 10px; column-gap: 10px;
align-items: center; align-items: center;
} }
.novelai_logit_bias_text, .logit_bias_text,
.novelai_logit_bias_value { .logit_bias_value {
flex: 1; flex: 1;
} }
.novelai_logit_bias_list { .logit_bias_list {
display: flex; display: flex;
flex-direction: column; flex-direction: column;
gap: 10px; gap: 10px;
} }
.novelai_logit_bias_list:empty { .logit_bias_list:empty {
width: 100%; width: 100%;
height: 100%; height: 100%;
} }
.novelai_logit_bias_list:empty::before { .logit_bias_list:empty::before {
display: flex; display: flex;
align-items: center; align-items: center;
justify-content: center; justify-content: center;