Merge pull request #1854 from deciare/llamacpp-probs

Request and display token probabilities from llama.cpp backend
This commit is contained in:
Cohee 2024-02-24 15:06:28 +02:00 committed by GitHub
commit 13aebc623a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 30 additions and 8 deletions

View File

@ -18,6 +18,7 @@ import {
textgen_types,
getTextGenServer,
validateTextGenUrl,
parseTextgenLogprobs,
} from './scripts/textgen-settings.js';
const { MANCER, TOGETHERAI, OOBA, APHRODITE, OLLAMA, INFERMATICAI } = textgen_types;
@ -2668,8 +2669,8 @@ class StreamingProcessor {
}
const continueMsg = this.type === 'continue' ? this.messageAlreadyGenerated : undefined;
await saveChatConditional();
saveLogprobsForActiveMessage(this.messageLogprobs.filter(Boolean), continueMsg);
await saveChatConditional();
activateSendButtons();
showSwipeButtons();
setGenerationProgress(0);
@ -4481,6 +4482,11 @@ function parseAndSaveLogprobs(data, continueFrom) {
// `sendOpenAIRequest`. `data` for these APIs is just a string with
// the text of the generated message, logprobs are not included.
return;
case 'textgenerationwebui':
if (textgen_settings.type === textgen_types.LLAMACPP) {
logprobs = data?.completion_probabilities?.map(x => parseTextgenLogprobs(x.content, [x])) || null;
}
break;
default:
return;
}

View File

@ -139,9 +139,14 @@ function renderTopLogprobs() {
const candidates = topLogprobs
.sort(([, logA], [, logB]) => logB - logA)
.map(([text, log]) => {
const probability = Math.exp(log);
sum += probability;
return [text, probability, log];
if (log < 0) {
const probability = Math.exp(log);
sum += probability;
return [text, probability, log];
}
else {
return [text, log, null];
}
});
candidates.push(['<others>', 1 - sum, 0]);
@ -157,7 +162,9 @@ function renderTopLogprobs() {
const tokenText = $('<span></span>').text(`${toVisibleWhitespace(token)}`);
const percentText = $('<span></span>').text(`${(probability * 100).toFixed(2)}%`);
container.append(tokenText, percentText);
container.attr('title', `logarithm: ${log}`);
if (log) {
container.attr('title', `logarithm: ${log}`);
}
addKeyboardProps(container);
if (token !== '<others>') {
container.click(() => onAlternativeClicked(state.selectedTokenLogprobs, token));
@ -459,7 +466,7 @@ function convertTokenIdLogprobsToText(input) {
}
export function initLogprobs() {
const debouncedRender = debounce(renderAlternativeTokensView, 250);
const debouncedRender = debounce(renderAlternativeTokensView, 500);
$('#logprobsViewerClose').click(onToggleLogprobsPanel);
$('#option_toggle_logprobs').click(onToggleLogprobsPanel);
eventSource.on(event_types.CHAT_CHANGED, debouncedRender);

View File

@ -755,7 +755,7 @@ async function generateTextGenWithStreaming(generate_data, signal) {
} else {
const newText = data?.choices?.[0]?.text || data?.content || '';
text += newText;
logprobs = parseTextgenLogprobs(newText, data.choices?.[0]?.logprobs);
logprobs = parseTextgenLogprobs(newText, data.choices?.[0]?.logprobs || data?.completion_probabilities);
}
yield { text, swipes, logprobs };
@ -771,7 +771,7 @@ async function generateTextGenWithStreaming(generate_data, signal) {
* @param {Object} logprobs - logprobs object returned from the API
* @returns {import('logprobs.js').TokenLogprobs | null} - converted logprobs
*/
function parseTextgenLogprobs(token, logprobs) {
export function parseTextgenLogprobs(token, logprobs) {
if (!logprobs) {
return null;
}
@ -788,6 +788,14 @@ function parseTextgenLogprobs(token, logprobs) {
const candidates = Object.entries(topLogprobs[0]);
return { token, topLogprobs: candidates };
}
case LLAMACPP: {
/** @type {Record<string, number>[]} */
if (!logprobs?.length) {
return null;
}
const candidates = logprobs[0].probs.map(x => [ x.tok_str, x.prob ]);
return { token, topLogprobs: candidates };
}
default:
return null;
}
@ -933,6 +941,7 @@ export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate,
'n_predict': maxTokens,
'mirostat': settings.mirostat_mode,
'ignore_eos': settings.ban_eos_token,
'n_probs': power_user.request_token_probabilities ? 10 : undefined,
};
const aphroditeParams = {
'n': canMultiSwipe ? settings.n : 1,