Request token probabilities from llama.cpp backend

The llama.cpp server reports token probabilities as linear values in the
range 0 to 1, rather than as log-probabilities.
This commit is contained in:
Deciare 2024-02-23 14:01:46 -05:00
parent 2d152d2705
commit 344b9eedbc
2 changed files with 22 additions and 5 deletions

View File

@ -12,6 +12,7 @@ import {
import { debounce, delay, getStringHash } from './utils.js'; import { debounce, delay, getStringHash } from './utils.js';
import { decodeTextTokens, getTokenizerBestMatch } from './tokenizers.js'; import { decodeTextTokens, getTokenizerBestMatch } from './tokenizers.js';
import { power_user } from './power-user.js'; import { power_user } from './power-user.js';
import { textgenerationwebui_settings, textgen_types } from './textgen-settings.js';
const TINTS = 4; const TINTS = 4;
const MAX_MESSAGE_LOGPROBS = 100; const MAX_MESSAGE_LOGPROBS = 100;
@ -139,9 +140,14 @@ function renderTopLogprobs() {
const candidates = topLogprobs const candidates = topLogprobs
.sort(([, logA], [, logB]) => logB - logA) .sort(([, logA], [, logB]) => logB - logA)
.map(([text, log]) => { .map(([text, log]) => {
const probability = Math.exp(log); if (textgenerationwebui_settings.type !== textgen_types.LLAMACPP) {
sum += probability; const probability = Math.exp(log);
return [text, probability, log]; sum += probability;
return [text, probability, log];
}
else {
return [text, log, null];
}
}); });
candidates.push(['<others>', 1 - sum, 0]); candidates.push(['<others>', 1 - sum, 0]);
@ -157,7 +163,9 @@ function renderTopLogprobs() {
const tokenText = $('<span></span>').text(`${toVisibleWhitespace(token)}`); const tokenText = $('<span></span>').text(`${toVisibleWhitespace(token)}`);
const percentText = $('<span></span>').text(`${(probability * 100).toFixed(2)}%`); const percentText = $('<span></span>').text(`${(probability * 100).toFixed(2)}%`);
container.append(tokenText, percentText); container.append(tokenText, percentText);
container.attr('title', `logarithm: ${log}`); if (log) {
container.attr('title', `logarithm: ${log}`);
}
addKeyboardProps(container); addKeyboardProps(container);
if (token !== '<others>') { if (token !== '<others>') {
container.click(() => onAlternativeClicked(state.selectedTokenLogprobs, token)); container.click(() => onAlternativeClicked(state.selectedTokenLogprobs, token));

View File

@ -694,7 +694,7 @@ async function generateTextGenWithStreaming(generate_data, signal) {
} else { } else {
const newText = data?.choices?.[0]?.text || data?.content || ''; const newText = data?.choices?.[0]?.text || data?.content || '';
text += newText; text += newText;
logprobs = parseTextgenLogprobs(newText, data.choices?.[0]?.logprobs); logprobs = parseTextgenLogprobs(newText, data.choices?.[0]?.logprobs || data?.completion_probabilities);
} }
yield { text, swipes, logprobs }; yield { text, swipes, logprobs };
@ -727,6 +727,14 @@ function parseTextgenLogprobs(token, logprobs) {
const candidates = Object.entries(topLogprobs[0]); const candidates = Object.entries(topLogprobs[0]);
return { token, topLogprobs: candidates }; return { token, topLogprobs: candidates };
} }
case LLAMACPP: {
/** @type {Record<string, number>[]} */
if (!logprobs?.length) {
return null;
}
const candidates = logprobs[0].probs.map(x => [ x.tok_str, x.prob ]);
return { token, topLogprobs: candidates };
}
default: default:
return null; return null;
} }
@ -867,6 +875,7 @@ export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate,
'n_predict': maxTokens, 'n_predict': maxTokens,
'mirostat': settings.mirostat_mode, 'mirostat': settings.mirostat_mode,
'ignore_eos': settings.ban_eos_token, 'ignore_eos': settings.ban_eos_token,
'n_probs': power_user.request_token_probabilities ? 10 : undefined,
}; };
const aphroditeParams = { const aphroditeParams = {
'n': canMultiSwipe ? settings.n : 1, 'n': canMultiSwipe ? settings.n : 1,