Implement Token Probabilities UI using logprobs

This commit is contained in:
khanon
2024-01-22 23:00:31 -06:00
parent 9b42be2334
commit 60044c18a4
16 changed files with 921 additions and 35 deletions

View File

@ -10,10 +10,7 @@ import {
} from '../script.js';
import { BIAS_CACHE, createNewLogitBiasEntry, displayLogitBias, getLogitBiasListResult } from './logit-bias.js';
import {
power_user,
registerDebugFunction,
} from './power-user.js';
import { power_user, registerDebugFunction } from './power-user.js';
import EventSourceStream from './sse-stream.js';
import { SENTENCEPIECE_TOKENIZERS, TEXTGEN_TOKENIZERS, getTextTokens, tokenizers } from './tokenizers.js';
import { getSortableDelay, onlyUnique } from './utils.js';
@ -675,6 +672,8 @@ async function generateTextGenWithStreaming(generate_data, signal) {
return async function* streamData() {
let text = '';
/** @type {import('logprobs.js').TokenLogprobs | null} */
let logprobs = null;
const swipes = [];
while (true) {
const { done, value } = await reader.read();
@ -689,14 +688,44 @@ async function generateTextGenWithStreaming(generate_data, signal) {
const swipeIndex = data.choices[0].index - 1;
swipes[swipeIndex] = (swipes[swipeIndex] || '') + data.choices[0].text;
} else {
text += data?.choices?.[0]?.text || data?.content || '';
const newText = data?.choices?.[0]?.text || data?.content || '';
text += newText;
logprobs = parseTextgenLogprobs(newText, data.choices[0]?.logprobs);
}
yield { text, swipes };
yield { text, swipes, logprobs };
}
};
}
/**
* parseTextgenLogprobs converts a logprobs object returned from a textgen API
* for a single token into a TokenLogprobs object used by the Token
* Probabilities feature.
* @param {string} token - the text of the token that the logprobs are for
* @param {Object} logprobs - logprobs object returned from the API
* @returns {import('logprobs.js').TokenLogprobs | null} - converted logprobs
*/
function parseTextgenLogprobs(token, logprobs) {
if (!logprobs) {
return null;
}
switch (settings.type) {
case OOBA: {
/** @type {Record<string, number>[]} */
const topLogprobs = logprobs.top_logprobs;
if (!topLogprobs?.length) {
return null;
}
const candidates = Object.entries(topLogprobs[0]);
return { token, topLogprobs: candidates };
}
default:
return null;
}
}
/**
* Parses errors in streaming responses and displays them in toastr.
* @param {Response} response - Response from the server.
@ -769,6 +798,7 @@ export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate,
'model': getModel(),
'max_new_tokens': maxTokens,
'max_tokens': maxTokens,
'logprobs': power_user.request_token_probabilities ? 10: undefined,
'temperature': settings.dynatemp ? (settings.min_temp + settings.max_temp) / 2 : settings.temp,
'top_p': settings.top_p,
'typical_p': settings.typical_p,