Merge pull request #1854 from deciare/llamacpp-probs

Request and display token probabilities from llama.cpp backend
Authored by Cohee on 2024-02-24 15:06:28 +02:00, committed by GitHub.
3 changed files with 30 additions and 8 deletions
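
In summary: when token probabilities are requested, the llama.cpp backend is asked for them via a new n_probs field in the generation payload; both the streaming and non-streaming response paths feed the returned completion_probabilities through parseTextgenLogprobs (now exported from textgen-settings.js); and the logprobs viewer learns to handle llama.cpp values, which are plain probabilities rather than log-probabilities. The streaming processor also now saves the chat before saving logprobs for the active message.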

View File: public/script.js

@@ -18,6 +18,7 @@ import {
     textgen_types,
     getTextGenServer,
     validateTextGenUrl,
+    parseTextgenLogprobs,
 } from './scripts/textgen-settings.js';
 const { MANCER, TOGETHERAI, OOBA, APHRODITE, OLLAMA, INFERMATICAI } = textgen_types;
@@ -2668,8 +2669,8 @@ class StreamingProcessor {
         }
         const continueMsg = this.type === 'continue' ? this.messageAlreadyGenerated : undefined;
+        await saveChatConditional();
         saveLogprobsForActiveMessage(this.messageLogprobs.filter(Boolean), continueMsg);
-        await saveChatConditional();
         activateSendButtons();
         showSwipeButtons();
         setGenerationProgress(0);
@@ -4481,6 +4482,11 @@ function parseAndSaveLogprobs(data, continueFrom) {
             // `sendOpenAIRequest`. `data` for these APIs is just a string with
             // the text of the generated message, logprobs are not included.
             return;
+        case 'textgenerationwebui':
+            if (textgen_settings.type === textgen_types.LLAMACPP) {
+                logprobs = data?.completion_probabilities?.map(x => parseTextgenLogprobs(x.content, [x])) || null;
+            }
+            break;
         default:
             return;
     }
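
Note: a minimal sketch of the non-streaming llama.cpp response shape this new case assumes, with field names taken from what the code reads (completion_probabilities[].content and .probs[].tok_str/.prob); the values are illustrative only:

    // Hypothetical llama.cpp /completion response fragment (values made up).
    const data = {
        content: ' the',
        completion_probabilities: [
            {
                content: ' the',
                probs: [
                    { tok_str: ' the', prob: 0.62 },
                    { tok_str: ' a', prob: 0.21 },
                ],
            },
        ],
    };
    // Each per-token entry is wrapped in a one-element array so that
    // parseTextgenLogprobs can read it as logprobs[0].probs:
    const logprobs = data?.completion_probabilities?.map(x => parseTextgenLogprobs(x.content, [x])) || null;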

View File: public/scripts/logprobs.js

@@ -139,9 +139,14 @@ function renderTopLogprobs() {
     const candidates = topLogprobs
         .sort(([, logA], [, logB]) => logB - logA)
         .map(([text, log]) => {
-            const probability = Math.exp(log);
-            sum += probability;
-            return [text, probability, log];
+            if (log < 0) {
+                const probability = Math.exp(log);
+                sum += probability;
+                return [text, probability, log];
+            }
+            else {
+                return [text, log, null];
+            }
         });
     candidates.push(['<others>', 1 - sum, 0]);
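
Note: the sign test is what distinguishes the two formats here. Backends that return true log-probabilities yield values at or below zero, so Math.exp() converts them to probabilities; llama.cpp reports plain probabilities in [0, 1], which are passed through unchanged (with null marking the missing logarithm). A sketch:

    // Log-probability from another backend: negative, exp() makes it a probability.
    const pFromLogprob = Math.exp(-0.47); // ≈ 0.625
    // llama.cpp value: already a probability, taken as-is.
    const pFromLlamaCpp = 0.625;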
@@ -157,7 +162,9 @@ function renderTopLogprobs() {
         const tokenText = $('<span></span>').text(`${toVisibleWhitespace(token)}`);
         const percentText = $('<span></span>').text(`${(probability * 100).toFixed(2)}%`);
         container.append(tokenText, percentText);
-        container.attr('title', `logarithm: ${log}`);
+        if (log) {
+            container.attr('title', `logarithm: ${log}`);
+        }
         addKeyboardProps(container);
         if (token !== '<others>') {
             container.click(() => onAlternativeClicked(state.selectedTokenLogprobs, token));
@@ -459,7 +466,7 @@ function convertTokenIdLogprobsToText(input) {
 }
 export function initLogprobs() {
-    const debouncedRender = debounce(renderAlternativeTokensView, 250);
+    const debouncedRender = debounce(renderAlternativeTokensView, 500);
     $('#logprobsViewerClose').click(onToggleLogprobsPanel);
     $('#option_toggle_logprobs').click(onToggleLogprobsPanel);
     eventSource.on(event_types.CHAT_CHANGED, debouncedRender);

View File: public/scripts/textgen-settings.js

@@ -755,7 +755,7 @@ async function generateTextGenWithStreaming(generate_data, signal) {
             } else {
                 const newText = data?.choices?.[0]?.text || data?.content || '';
                 text += newText;
-                logprobs = parseTextgenLogprobs(newText, data.choices?.[0]?.logprobs);
+                logprobs = parseTextgenLogprobs(newText, data.choices?.[0]?.logprobs || data?.completion_probabilities);
             }
             yield { text, swipes, logprobs };
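
Note: llama.cpp streams carry no choices array; each chunk has content plus a per-token completion_probabilities list, so the added fallback routes llama.cpp data into the same parser. A sketch of such a chunk, under the same field-name assumptions as above:

    const data = {
        content: ' the',
        completion_probabilities: [
            { content: ' the', probs: [{ tok_str: ' the', prob: 0.62 }] },
        ],
    };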
@@ -771,7 +771,7 @@ async function generateTextGenWithStreaming(generate_data, signal) {
  * @param {Object} logprobs - logprobs object returned from the API
  * @returns {import('logprobs.js').TokenLogprobs | null} - converted logprobs
  */
-function parseTextgenLogprobs(token, logprobs) {
+export function parseTextgenLogprobs(token, logprobs) {
     if (!logprobs) {
         return null;
     }
@@ -788,6 +788,14 @@ function parseTextgenLogprobs(token, logprobs) {
             const candidates = Object.entries(topLogprobs[0]);
             return { token, topLogprobs: candidates };
         }
+        case LLAMACPP: {
+            /** @type {Record<string, number>[]} */
+            if (!logprobs?.length) {
+                return null;
+            }
+            const candidates = logprobs[0].probs.map(x => [ x.tok_str, x.prob ]);
+            return { token, topLogprobs: candidates };
+        }
         default:
             return null;
     }
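
Note: a sketch of the converted result for the llama.cpp branch, assuming the payload from the earlier example:

    // parseTextgenLogprobs(' the', data.completion_probabilities) would yield:
    // { token: ' the', topLogprobs: [[' the', 0.62], [' a', 0.21]] }
    // i.e. the same TokenLogprobs shape the other backends produce, except the
    // candidate values are probabilities rather than log-probabilities.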
@@ -933,6 +941,7 @@ export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate,
             'n_predict': maxTokens,
             'mirostat': settings.mirostat_mode,
             'ignore_eos': settings.ban_eos_token,
+            'n_probs': power_user.request_token_probabilities ? 10 : undefined,
         };
         const aphroditeParams = {
             'n': canMultiSwipe ? settings.n : 1,
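
Note: n_probs is the llama.cpp server parameter for requesting the top-N candidate probabilities per generated token. With power_user.request_token_probabilities enabled, the generation payload would contain something like the following (other sampler fields omitted, values illustrative):

    const llamaCppParams = {
        prompt: finalPrompt,
        n_predict: maxTokens,
        n_probs: 10, // ask for the 10 most likely candidates per token
    };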