#371 Add llama.cpp inference server support

This commit is contained in:
Cohee
2023-12-18 22:38:28 +02:00
parent 6e8104873e
commit edd737e8bd
9 changed files with 136 additions and 37 deletions

View File

@ -33,9 +33,10 @@ export const textgen_types = {
TABBY: 'tabby',
KOBOLDCPP: 'koboldcpp',
TOGETHERAI: 'togetherai',
LLAMACPP: 'llamacpp',
};
const { MANCER, APHRODITE, TOGETHERAI } = textgen_types;
const { MANCER, APHRODITE, TOGETHERAI, OOBA } = textgen_types;
const BIAS_KEY = '#textgenerationwebui_api-settings';
// Maybe let it be configurable in the future?
@ -166,6 +167,7 @@ async function selectPreset(name) {
setSettingByName(name, value, true);
}
setGenerationParamsFromPreset(preset);
BIAS_CACHE.delete(BIAS_KEY);
displayLogitBias(preset.logit_bias, BIAS_KEY);
saveSettingsDebounced();
}
@ -311,6 +313,7 @@ function loadTextGenSettings(data, loadedSettings) {
$('#textgen_type').val(settings.type);
showTypeSpecificControls(settings.type);
BIAS_CACHE.delete(BIAS_KEY);
displayLogitBias(settings.logit_bias, BIAS_KEY);
//this is needed because showTypeSpecificControls() does not handle NOT declarations
if (settings.type === textgen_types.APHRODITE) {
@ -343,6 +346,8 @@ export function getTextGenUrlSourceId() {
return '#tabby_api_url_text';
case textgen_types.KOBOLDCPP:
return '#koboldcpp_api_url_text';
case textgen_types.LLAMACPP:
return '#llamacpp_api_url_text';
}
}
@ -415,6 +420,7 @@ jQuery(function () {
showTypeSpecificControls(type);
setOnlineStatus('no_connection');
BIAS_CACHE.delete(BIAS_KEY);
$('#main_api').trigger('change');
$('#api_button_textgenerationwebui').trigger('click');
@ -463,11 +469,14 @@ jQuery(function () {
/**
 * Shows or hides every element that carries a `data-tg-type` attribute,
 * depending on whether it applies to the given text generation type.
 *
 * The attribute value is read as a comma-separated list of type names.
 * An element is shown when any entry equals `type` or is the literal 'all';
 * otherwise it is hidden.
 * @param {string} type Text generation type to show controls for
 */
function showTypeSpecificControls(type) {
    $('[data-tg-type]').each(function () {
        const element = $(this);
        // Trim entries so that "a, b" is treated the same as "a,b".
        const tgTypes = String(element.attr('data-tg-type') ?? '')
            .split(',')
            .map(tgType => tgType.trim());
        // Decide visibility once per element instead of toggling per entry.
        const isVisible = tgTypes.some(tgType => tgType === type || tgType === 'all');

        if (isVisible) {
            element.show();
        } else {
            element.hide();
        }
    });
}
@ -550,11 +559,11 @@ async function generateTextGenWithStreaming(generate_data, signal) {
let data = JSON.parse(value.data);
if (data?.choices[0]?.index > 0) {
if (data?.choices?.[0]?.index > 0) {
const swipeIndex = data.choices[0].index - 1;
swipes[swipeIndex] = (swipes[swipeIndex] || '') + data.choices[0].text;
} else {
text += data?.choices[0]?.text || '';
text += data?.choices?.[0]?.text || data?.content || '';
}
yield { text, swipes };
@ -585,6 +594,11 @@ function tryParseStreamingError(response, decoded) {
}
}
/**
* Converts a string of comma-separated integers to an array of integers.
* @param {string} string Input string
* @returns {number[]} Array of integers
*/
function toIntArray(string) {
if (!string) {
return [];
@ -623,7 +637,7 @@ export function getTextGenServer() {
export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate, isContinue, cfgValues, type) {
const canMultiSwipe = !isContinue && !isImpersonate && type !== 'quiet';
let APIflags = {
let params = {
'prompt': finalPrompt,
'model': getModel(),
'max_new_tokens': maxTokens,
@ -659,12 +673,10 @@ export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate,
getCustomTokenBans(),
'api_type': settings.type,
'api_server': getTextGenServer(),
'legacy_api': settings.legacy_api && settings.type !== MANCER && settings.type !== TOGETHERAI,
'sampler_order': settings.type === textgen_types.KOBOLDCPP ?
settings.sampler_order :
undefined,
'legacy_api': settings.legacy_api && (settings.type === OOBA || settings.type === APHRODITE),
'sampler_order': settings.type === textgen_types.KOBOLDCPP ? settings.sampler_order : undefined,
};
let aphroditeExclusionFlags = {
const nonAphroditeParams = {
'repetition_penalty_range': settings.rep_pen_range,
'encoder_repetition_penalty': settings.encoder_rep_pen,
'no_repeat_ngram_size': settings.no_repeat_ngram_size,
@ -676,7 +688,7 @@ export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate,
'negative_prompt': cfgValues?.negativePrompt ?? substituteParams(settings.negative_prompt) ?? '',
'grammar_string': settings.grammar_string,
};
let aphroditeFlags = {
const aphroditeParams = {
'n': canMultiSwipe ? settings.n : 1,
'best_of': canMultiSwipe ? settings.n : 1,
'ignore_eos': settings.ignore_eos_token_aphrodite,
@ -686,17 +698,37 @@ export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate,
//'prompt_logprobs': settings.prompt_log_probs_aphrodite,
};
if (settings.type === textgen_types.APHRODITE) {
APIflags = Object.assign(APIflags, aphroditeFlags);
params = Object.assign(params, aphroditeParams);
} else {
APIflags = Object.assign(APIflags, aphroditeExclusionFlags);
params = Object.assign(params, nonAphroditeParams);
}
if (Array.isArray(settings.logit_bias) && settings.logit_bias.length) {
const logitBias = BIAS_CACHE.get(BIAS_KEY) || calculateLogitBias();
BIAS_CACHE.set(BIAS_KEY, logitBias);
APIflags.logit_bias = logitBias;
params.logit_bias = logitBias;
}
return APIflags;
if (settings.type === textgen_types.LLAMACPP) {
// Convert bias and token bans to array of arrays
const logitBiasArray = (params.logit_bias && typeof params.logit_bias === 'object' && Object.keys(params.logit_bias).length > 0)
? Object.entries(params.logit_bias).map(([key, value]) => [Number(key), value])
: [];
const tokenBans = toIntArray(getCustomTokenBans());
logitBiasArray.push(...tokenBans.map(x => [Number(x), false]));
const llamaCppParams = {
'repeat_penalty': settings.rep_pen,
'tfs_z': settings.tfs,
'repeat_last_n': settings.rep_pen_range,
'n_predict': settings.maxTokens,
'mirostat': settings.mirostat_mode,
'ignore_eos': settings.ban_eos_token,
'grammar': settings.grammar_string,
'logit_bias': logitBiasArray,
};
params = Object.assign(params, llamaCppParams);
}
return params;
}