#1328 New API schema for ooba / mancer / aphrodite

This commit is contained in:
Cohee
2023-11-08 00:17:13 +02:00
parent 2d2ff5230c
commit 2c7b954a8d
7 changed files with 293 additions and 341 deletions

View File

@@ -3,6 +3,7 @@ import {
getRequestHeaders,
getStoppingStrings,
max_context,
online_status,
saveSettingsDebounced,
setGenerationParamsFromPreset,
} from "../script.js";
@@ -12,7 +13,7 @@ import {
power_user,
} from "./power-user.js";
import { getTextTokens, tokenizers } from "./tokenizers.js";
import { delay, onlyUnique } from "./utils.js";
import { onlyUnique } from "./utils.js";
export {
textgenerationwebui_settings,
@@ -27,6 +28,9 @@ export const textgen_types = {
APHRODITE: 'aphrodite',
};
// Maybe let it be configurable in the future?
export const MANCER_SERVER = 'https://neuro.mancer.tech';
const textgenerationwebui_settings = {
temp: 0.7,
temperature_last: true,
@@ -58,7 +62,6 @@ const textgenerationwebui_settings = {
ban_eos_token: false,
skip_special_tokens: true,
streaming: false,
streaming_url: 'ws://127.0.0.1:5005/api/v1/stream',
mirostat_mode: 0,
mirostat_tau: 5,
mirostat_eta: 0.1,
@@ -74,6 +77,7 @@ const textgenerationwebui_settings = {
//log_probs_aphrodite: 0,
//prompt_log_probs_aphrodite: 0,
type: textgen_types.OOBA,
mancer_model: 'mytholite',
};
export let textgenerationwebui_banned_in_macros = [];
@@ -109,7 +113,6 @@ const setting_names = [
"ban_eos_token",
"skip_special_tokens",
"streaming",
"streaming_url",
"mirostat_mode",
"mirostat_tau",
"mirostat_eta",
@@ -142,17 +145,12 @@ async function selectPreset(name) {
saveSettingsDebounced();
}
function formatTextGenURL(value, use_mancer) {
function formatTextGenURL(value) {
try {
const url = new URL(value);
if (!power_user.relaxed_api_urls) {
if (use_mancer) { // If Mancer is in use, only require the URL to *end* with `/api`.
if (!url.pathname.endsWith('/api')) {
return null;
}
} else {
url.pathname = '/api';
}
if (url.pathname === '/api') {
url.pathname = '/';
toastr.info('Legacy API URL detected, please make sure you updated ooba-webui to the latest version.');
}
return url.toString();
} catch { } // Just using URL as a validation check
@@ -255,8 +253,6 @@ export function isOoba() {
export function getTextGenUrlSourceId() {
switch (textgenerationwebui_settings.type) {
case textgen_types.MANCER:
return "#mancer_api_url_text";
case textgen_types.OOBA:
return "#textgenerationwebui_api_url_text";
case textgen_types.APHRODITE:
@@ -371,33 +367,11 @@ function setSettingByName(i, value, trigger) {
}
async function generateTextGenWithStreaming(generate_data, signal) {
let streamingUrl = textgenerationwebui_settings.streaming_url;
generate_data.stream = true;
if (isMancer()) {
streamingUrl = api_server_textgenerationwebui.replace("http", "ws") + "/v1/stream";
}
if (isAphrodite()) {
streamingUrl = api_server_textgenerationwebui;
}
if (isMancer() || isOoba()) {
try {
const parsedUrl = new URL(streamingUrl);
if (parsedUrl.protocol !== 'ws:' && parsedUrl.protocol !== 'wss:') {
throw new Error('Invalid protocol');
}
} catch {
toastr.error('Invalid URL for streaming. Make sure it starts with ws:// or wss://');
return async function* () { throw new Error('Invalid URL for streaming.'); }
}
}
const response = await fetch('/generate_textgenerationwebui', {
const response = await fetch('/api/textgenerationwebui/generate', {
headers: {
...getRequestHeaders(),
'X-Response-Streaming': String(true),
'X-Streaming-URL': streamingUrl,
},
body: JSON.stringify(generate_data),
method: 'POST',
@@ -408,54 +382,93 @@ async function generateTextGenWithStreaming(generate_data, signal) {
const decoder = new TextDecoder();
const reader = response.body.getReader();
let getMessage = '';
let messageBuffer = "";
while (true) {
const { done, value } = await reader.read();
let response = decoder.decode(value);
// We don't want carriage returns in our messages
let response = decoder.decode(value).replace(/\r/g, "");
if (isAphrodite()) {
const events = response.split('\n\n');
tryParseStreamingError(response);
for (const event of events) {
if (event.length == 0) {
continue;
}
let eventList = [];
try {
const { results } = JSON.parse(event);
messageBuffer += response;
eventList = messageBuffer.split("\n\n");
// Last element will be an empty string or a leftover partial message
messageBuffer = eventList.pop();
if (Array.isArray(results) && results.length > 0) {
getMessage = results[0].text;
yield getMessage;
// unhang UI thread
await delay(1);
}
} catch {
// Ignore
}
for (let event of eventList) {
if (event.startsWith('event: completion')) {
event = event.split("\n")[1];
}
if (done) {
if (typeof event !== 'string' || !event.length)
continue;
if (!event.startsWith("data"))
continue;
if (event == "data: [DONE]") {
return;
}
} else {
getMessage += response;
if (done) {
return;
}
let data = JSON.parse(event.substring(6));
// the first and last messages are undefined, protect against that
getMessage += data?.choices[0]?.text || '';
yield getMessage;
}
if (done) {
return;
}
}
}
}
/**
 * Checks whether a chunk received from the server is a JSON document that
 * carries an API error (`{ "error": { "message": ... } }`).
 * When it is, the message is shown in a toastr popup titled "API Error" and
 * re-thrown as an Error. Text that is not valid JSON, or JSON without a truthy
 * `error.message`, is ignored.
 * @param {string} response - Response from the server.
 * @returns {void} Nothing.
 * @throws {Error} When the response contains an error message.
 */
function tryParseStreamingError(response) {
    let parsed;

    try {
        parsed = JSON.parse(response);
    } catch {
        // Not JSON at all, so there is nothing to report.
        return;
    }

    // Optional chaining also covers JSON primitives such as `null`.
    const errorMessage = parsed?.error?.message;

    if (!errorMessage) {
        return;
    }

    toastr.error(errorMessage, 'API Error');
    throw new Error(errorMessage);
}
/**
 * Turns a comma-separated list into an array of integers.
 * Each piece is run through parseInt (no explicit radix, so a leading numeric
 * prefix is kept and a `0x` prefix is read as hexadecimal); pieces that do not
 * yield a number are dropped.
 * @param {string} string - Comma-separated values, e.g. "1,2,3".
 * @returns {number[]} Parsed integers in input order; empty for falsy input.
 */
function toIntArray(string) {
    const numbers = [];

    if (!string) {
        return numbers;
    }

    for (const piece of string.split(',')) {
        const parsed = Number.parseInt(piece);

        // Skip anything parseInt could not make sense of.
        if (Number.isNaN(parsed)) {
            continue;
        }

        numbers.push(parsed);
    }

    return numbers;
}
/**
 * Picks the value sent as `model` in the generation request.
 * Mancer takes precedence and uses the `mancer_model` setting; Aphrodite uses
 * the current `online_status` value (presumably the name of the connected
 * model; confirm against where online_status is set).
 * @returns {string|undefined} Model identifier, or undefined for other backends.
 */
function getModel() {
    if (isMancer()) {
        return textgenerationwebui_settings.mancer_model;
    }

    return isAphrodite() ? online_status : undefined;
}
export function getTextGenGenerationData(finalPrompt, this_amount_gen, isImpersonate, cfgValues) {
return {
'prompt': finalPrompt,
'model': getModel(),
'max_new_tokens': this_amount_gen,
'max_tokens': this_amount_gen,
'do_sample': textgenerationwebui_settings.do_sample,
'temperature': textgenerationwebui_settings.temp,
'temperature_last': textgenerationwebui_settings.temperature_last,
@@ -469,6 +482,7 @@ export function getTextGenGenerationData(finalPrompt, this_amount_gen, isImperso
'presence_penalty': textgenerationwebui_settings.presence_pen,
'top_k': textgenerationwebui_settings.top_k,
'min_length': textgenerationwebui_settings.min_length,
'min_tokens': textgenerationwebui_settings.min_length,
'no_repeat_ngram_size': textgenerationwebui_settings.no_repeat_ngram_size,
'num_beams': textgenerationwebui_settings.num_beams,
'penalty_alpha': textgenerationwebui_settings.penalty_alpha,
@@ -479,6 +493,7 @@ export function getTextGenGenerationData(finalPrompt, this_amount_gen, isImperso
'seed': textgenerationwebui_settings.seed,
'add_bos_token': textgenerationwebui_settings.add_bos_token,
'stopping_strings': getStoppingStrings(isImpersonate),
'stop': getStoppingStrings(isImpersonate),
'truncation_length': max_context,
'ban_eos_token': textgenerationwebui_settings.ban_eos_token,
'skip_special_tokens': textgenerationwebui_settings.skip_special_tokens,
@@ -490,9 +505,11 @@ export function getTextGenGenerationData(finalPrompt, this_amount_gen, isImperso
'mirostat_tau': textgenerationwebui_settings.mirostat_tau,
'mirostat_eta': textgenerationwebui_settings.mirostat_eta,
'grammar_string': textgenerationwebui_settings.grammar_string,
'custom_token_bans': getCustomTokenBans(),
'custom_token_bans': isAphrodite() ? toIntArray(getCustomTokenBans()) : getCustomTokenBans(),
'use_mancer': isMancer(),
'use_aphrodite': isAphrodite(),
'use_ooba': isOoba(),
'api_server': isMancer() ? MANCER_SERVER : api_server_textgenerationwebui,
//'n': textgenerationwebui_settings.n_aphrodite,
//'best_of': textgenerationwebui_settings.n_aphrodite, //n must always == best_of and vice versa
//'ignore_eos': textgenerationwebui_settings.ignore_eos_token_aphrodite,
@@ -502,3 +519,4 @@ export function getTextGenGenerationData(finalPrompt, this_amount_gen, isImperso
//'prompt_logprobs': textgenerationwebui_settings.prompt_log_probs_aphrodite,
};
}