#1569 Add logit bias for text completions

This commit is contained in:
Cohee
2023-12-18 18:57:10 +02:00
parent 08ea2095f8
commit 6e8104873e
3 changed files with 71 additions and 1 deletions

View File

@ -9,6 +9,7 @@ import {
setOnlineStatus,
substituteParams,
} from '../script.js';
import { BIAS_CACHE, createNewLogitBiasEntry, displayLogitBias, getLogitBiasListResult } from './logit-bias.js';
import {
power_user,
@ -35,6 +36,7 @@ export const textgen_types = {
};
const { MANCER, APHRODITE, TOGETHERAI } = textgen_types;
const BIAS_KEY = '#textgenerationwebui_api-settings';
// Maybe let it be configurable in the future?
// (7 days later) The future has come.
@ -94,6 +96,7 @@ const settings = {
togetherai_model: 'Gryphe/MythoMax-L2-13b',
legacy_api: false,
sampler_order: KOBOLDCPP_ORDER,
logit_bias: [],
n: 1,
};
@ -147,6 +150,7 @@ const setting_names = [
//'prompt_log_probs_aphrodite'
'sampler_order',
'n',
'logit_bias',
];
async function selectPreset(name) {
@ -162,6 +166,7 @@ async function selectPreset(name) {
setSettingByName(name, value, true);
}
setGenerationParamsFromPreset(preset);
displayLogitBias(preset.logit_bias, BIAS_KEY);
saveSettingsDebounced();
}
@ -243,6 +248,42 @@ function getCustomTokenBans() {
return result.filter(onlyUnique).map(x => String(x)).join(',');
}
/**
* Calculates logit bias object from the logit bias list.
* @returns {object} Logit bias object
*/
function calculateLogitBias() {
if (!Array.isArray(settings.logit_bias) || settings.logit_bias.length === 0) {
return {};
}
const tokenizer = SENTENCEPIECE_TOKENIZERS.includes(power_user.tokenizer) ? power_user.tokenizer : tokenizers.LLAMA;
const result = {};
/**
* Adds bias to the logit bias object.
* @param {number} bias
* @param {number[]} sequence
* @returns {object} Accumulated logit bias object
*/
function addBias(bias, sequence) {
if (sequence.length === 0) {
return;
}
for (const logit of sequence) {
const key = String(logit);
result[key] = bias;
}
return result;
}
getLogitBiasListResult(settings.logit_bias, tokenizer, addBias);
return result;
}
function loadTextGenSettings(data, loadedSettings) {
textgenerationwebui_presets = convertPresets(data.textgenerationwebui_presets);
textgenerationwebui_preset_names = data.textgenerationwebui_preset_names ?? [];
@ -270,6 +311,7 @@ function loadTextGenSettings(data, loadedSettings) {
$('#textgen_type').val(settings.type);
showTypeSpecificControls(settings.type);
displayLogitBias(settings.logit_bias, BIAS_KEY);
//this is needed because showTypeSpecificControls() does not handle NOT declarations
if (settings.type === textgen_types.APHRODITE) {
$('[data-forAphro=False]').each(function () {
@ -415,6 +457,8 @@ jQuery(function () {
saveSettingsDebounced();
});
}
$('#textgen_logit_bias_new_entry').on('click', () => createNewLogitBiasEntry(settings.logit_bias, BIAS_KEY));
});
function showTypeSpecificControls(type) {
@ -440,6 +484,11 @@ function setSettingByName(setting, value, trigger) {
return;
}
if ('logit_bias' === setting) {
settings.logit_bias = Array.isArray(value) ? value : [];
return;
}
const isCheckbox = $(`#${setting}_textgenerationwebui`).attr('type') == 'checkbox';
const isText = $(`#${setting}_textgenerationwebui`).attr('type') == 'text' || $(`#${setting}_textgenerationwebui`).is('textarea');
if (isCheckbox) {
@ -642,6 +691,12 @@ export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate,
APIflags = Object.assign(APIflags, aphroditeExclusionFlags);
}
if (Array.isArray(settings.logit_bias) && settings.logit_bias.length) {
const logitBias = BIAS_CACHE.get(BIAS_KEY) || calculateLogitBias();
BIAS_CACHE.set(BIAS_KEY, logitBias);
APIflags.logit_bias = logitBias;
}
return APIflags;
}