#1309 Ollama text completion backend

This commit is contained in:
Cohee
2023-12-19 16:38:11 +02:00
parent edd737e8bd
commit 67dd52c21b
10 changed files with 508 additions and 185 deletions

View File

@@ -1,5 +1,4 @@
import {
api_server_textgenerationwebui,
getRequestHeaders,
getStoppingStrings,
max_context,
@@ -34,9 +33,10 @@ export const textgen_types = {
KOBOLDCPP: 'koboldcpp',
TOGETHERAI: 'togetherai',
LLAMACPP: 'llamacpp',
OLLAMA: 'ollama',
};
const { MANCER, APHRODITE, TOGETHERAI, OOBA } = textgen_types;
const { MANCER, APHRODITE, TOGETHERAI, OOBA, OLLAMA, LLAMACPP } = textgen_types;
const BIAS_KEY = '#textgenerationwebui_api-settings';
// Maybe let it be configurable in the future?
@@ -46,6 +46,15 @@ const MANCER_SERVER_DEFAULT = 'https://neuro.mancer.tech';
let MANCER_SERVER = localStorage.getItem(MANCER_SERVER_KEY) ?? MANCER_SERVER_DEFAULT;
let TOGETHERAI_SERVER = 'https://api.together.xyz';
const SERVER_INPUTS = {
[textgen_types.OOBA]: '#textgenerationwebui_api_url_text',
[textgen_types.APHRODITE]: '#aphrodite_api_url_text',
[textgen_types.TABBY]: '#tabby_api_url_text',
[textgen_types.KOBOLDCPP]: '#koboldcpp_api_url_text',
[textgen_types.LLAMACPP]: '#llamacpp_api_url_text',
[textgen_types.OLLAMA]: '#ollama_api_url_text',
};
const KOBOLDCPP_ORDER = [6, 0, 1, 3, 4, 2, 5];
const settings = {
temp: 0.7,
@@ -95,10 +104,12 @@ const settings = {
type: textgen_types.OOBA,
mancer_model: 'mytholite',
togetherai_model: 'Gryphe/MythoMax-L2-13b',
ollama_model: '',
legacy_api: false,
sampler_order: KOBOLDCPP_ORDER,
logit_bias: [],
n: 1,
server_urls: {},
};
export let textgenerationwebui_banned_in_macros = [];
@@ -154,6 +165,37 @@ const setting_names = [
'logit_bias',
];
/**
 * Validates and normalizes the API URL typed into the server-URL input
 * for the currently selected Text Completion backend.
 * Backends without a URL input (e.g. hosted services) are skipped.
 * Shows an error toast and leaves the input untouched if the URL is invalid.
 */
export function validateTextGenUrl() {
    const selector = SERVER_INPUTS[settings.type];

    // No URL input for this backend type — nothing to validate.
    if (!selector) {
        return;
    }

    const input = $(selector);
    const enteredUrl = String(input.val()).trim();
    const normalizedUrl = formatTextGenURL(enteredUrl);

    if (!normalizedUrl) {
        toastr.error('Enter a valid API URL', 'Text Completion API');
        return;
    }

    // Write the normalized form back so the user sees the canonical URL.
    input.val(normalizedUrl);
}
/**
 * Returns the base server URL for the active Text Completion backend.
 * Hosted services (Mancer, TogetherAI) use fixed endpoints; all other
 * backends read the user-configured URL from per-type storage.
 * @returns {string} Server URL, or an empty string if none is configured.
 */
export function getTextGenServer() {
    switch (settings.type) {
        case MANCER:
            return MANCER_SERVER;
        case TOGETHERAI:
            return TOGETHERAI_SERVER;
        default:
            return settings.server_urls[settings.type] ?? '';
    }
}
async function selectPreset(name) {
const preset = textgenerationwebui_presets[textgenerationwebui_preset_names.indexOf(name)];
@@ -291,6 +333,21 @@ function loadTextGenSettings(data, loadedSettings) {
textgenerationwebui_preset_names = data.textgenerationwebui_preset_names ?? [];
Object.assign(settings, loadedSettings.textgenerationwebui_settings ?? {});
if (loadedSettings.api_server_textgenerationwebui) {
for (const type of Object.keys(SERVER_INPUTS)) {
settings.server_urls[type] = loadedSettings.api_server_textgenerationwebui;
}
delete loadedSettings.api_server_textgenerationwebui;
}
for (const [type, selector] of Object.entries(SERVER_INPUTS)) {
const control = $(selector);
control.val(settings.server_urls[type] ?? '').on('input', function () {
settings.server_urls[type] = String($(this).val());
saveSettingsDebounced();
});
}
if (loadedSettings.api_use_mancer_webui) {
settings.type = MANCER;
}
@@ -336,21 +393,6 @@ function loadTextGenSettings(data, loadedSettings) {
});
}
/**
 * Maps the active Text Completion backend type to the jQuery selector of
 * its API URL input field.
 * @returns {string|undefined} Selector string, or undefined for backend
 * types that have no URL input (e.g. hosted services).
 */
export function getTextGenUrlSourceId() {
    const urlInputSelectors = {
        [textgen_types.OOBA]: '#textgenerationwebui_api_url_text',
        [textgen_types.APHRODITE]: '#aphrodite_api_url_text',
        [textgen_types.TABBY]: '#tabby_api_url_text',
        [textgen_types.KOBOLDCPP]: '#koboldcpp_api_url_text',
        [textgen_types.LLAMACPP]: '#llamacpp_api_url_text',
    };
    return urlInputSelectors[settings.type];
}
/**
* Sorts the sampler items by the given order.
* @param {any[]} orderArray Sampler order array.
@@ -423,7 +465,10 @@ jQuery(function () {
BIAS_CACHE.delete(BIAS_KEY);
$('#main_api').trigger('change');
$('#api_button_textgenerationwebui').trigger('click');
if (!SERVER_INPUTS[type] || settings.server_urls[type]) {
$('#api_button_textgenerationwebui').trigger('click');
}
saveSettingsDebounced();
});
@@ -620,21 +665,18 @@ function getModel() {
return online_status;
}
if (settings.type === OLLAMA) {
if (!settings.ollama_model) {
toastr.error('No Ollama model selected.', 'Text Completion API');
throw new Error('No Ollama model selected');
}
return settings.ollama_model;
}
return undefined;
}
/**
 * Returns the base server URL for the active Text Completion backend.
 * Hosted services (Mancer, TogetherAI) use fixed endpoints; every other
 * backend falls back to the shared user-configured API URL.
 * @returns {string} Server URL for the active backend.
 */
export function getTextGenServer() {
    switch (settings.type) {
        case MANCER:
            return MANCER_SERVER;
        case TOGETHERAI:
            return TOGETHERAI_SERVER;
        default:
            return api_server_textgenerationwebui;
    }
}
export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate, isContinue, cfgValues, type) {
const canMultiSwipe = !isContinue && !isImpersonate && type !== 'quiet';
let params = {
@@ -687,6 +729,13 @@ export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate,
'guidance_scale': cfgValues?.guidanceScale?.value ?? settings.guidance_scale ?? 1,
'negative_prompt': cfgValues?.negativePrompt ?? substituteParams(settings.negative_prompt) ?? '',
'grammar_string': settings.grammar_string,
// llama.cpp aliases. In case someone wants to use LM Studio as Text Completion API
'repeat_penalty': settings.rep_pen,
'tfs_z': settings.tfs,
'repeat_last_n': settings.rep_pen_range,
'n_predict': settings.maxTokens,
'mirostat': settings.mirostat_mode,
'ignore_eos': settings.ban_eos_token,
};
const aphroditeParams = {
'n': canMultiSwipe ? settings.n : 1,
@@ -697,7 +746,7 @@ export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate,
//'logprobs': settings.log_probs_aphrodite,
//'prompt_logprobs': settings.prompt_log_probs_aphrodite,
};
if (settings.type === textgen_types.APHRODITE) {
if (settings.type === APHRODITE) {
params = Object.assign(params, aphroditeParams);
} else {
params = Object.assign(params, nonAphroditeParams);
@@ -709,7 +758,7 @@ export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate,
params.logit_bias = logitBias;
}
if (settings.type === textgen_types.LLAMACPP) {
if (settings.type === LLAMACPP || settings.type === OLLAMA) {
// Convert bias and token bans to array of arrays
const logitBiasArray = (params.logit_bias && typeof params.logit_bias === 'object' && Object.keys(params.logit_bias).length > 0)
? Object.entries(params.logit_bias).map(([key, value]) => [Number(key), value])
@@ -717,14 +766,9 @@ export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate,
const tokenBans = toIntArray(getCustomTokenBans());
logitBiasArray.push(...tokenBans.map(x => [Number(x), false]));
const llamaCppParams = {
'repeat_penalty': settings.rep_pen,
'tfs_z': settings.tfs,
'repeat_last_n': settings.rep_pen_range,
'n_predict': settings.maxTokens,
'mirostat': settings.mirostat_mode,
'ignore_eos': settings.ban_eos_token,
'grammar': settings.grammar_string,
'logit_bias': logitBiasArray,
// Conflicts with ooba's grammar_string
'grammar': settings.grammar_string,
};
params = Object.assign(params, llamaCppParams);
}