Replaces is_[api] params with api_type param

These were 5 mutually-exclusive booleans, which can be replaced with one
param that takes on 5 values, one for each API type.
This commit is contained in:
valadaptive
2023-12-03 02:45:53 -05:00
parent 8a1ead531c
commit ba54e3dea0
4 changed files with 64 additions and 62 deletions

View File

@ -885,13 +885,9 @@ async function getStatus() {
method: 'POST', method: 'POST',
headers: getRequestHeaders(), headers: getRequestHeaders(),
body: JSON.stringify({ body: JSON.stringify({
main_api: main_api, main_api,
api_server: endpoint, api_server: endpoint,
use_mancer: main_api == 'textgenerationwebui' ? isMancer() : false, api_type: textgenerationwebui_settings.type,
use_aphrodite: main_api == 'textgenerationwebui' ? isAphrodite() : false,
use_ooba: main_api == 'textgenerationwebui' ? isOoba() : false,
use_tabby: main_api == 'textgenerationwebui' ? isTabby() : false,
use_koboldcpp: main_api == 'textgenerationwebui' ? isKoboldCpp() : false,
legacy_api: main_api == 'textgenerationwebui' ? textgenerationwebui_settings.legacy_api && !isMancer() : false, legacy_api: main_api == 'textgenerationwebui' ? textgenerationwebui_settings.legacy_api && !isMancer() : false,
}), }),
signal: abortStatusCheck.signal, signal: abortStatusCheck.signal,

View File

@ -620,11 +620,7 @@ export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate,
'mirostat_tau': textgenerationwebui_settings.mirostat_tau, 'mirostat_tau': textgenerationwebui_settings.mirostat_tau,
'mirostat_eta': textgenerationwebui_settings.mirostat_eta, 'mirostat_eta': textgenerationwebui_settings.mirostat_eta,
'custom_token_bans': isAphrodite() ? toIntArray(getCustomTokenBans()) : getCustomTokenBans(), 'custom_token_bans': isAphrodite() ? toIntArray(getCustomTokenBans()) : getCustomTokenBans(),
'use_mancer': isMancer(), 'api_type': textgenerationwebui_settings.type,
'use_aphrodite': isAphrodite(),
'use_tabby': isTabby(),
'use_koboldcpp': isKoboldCpp(),
'use_ooba': isOoba(),
'api_server': isMancer() ? MANCER_SERVER : api_server_textgenerationwebui, 'api_server': isMancer() ? MANCER_SERVER : api_server_textgenerationwebui,
'legacy_api': textgenerationwebui_settings.legacy_api && !isMancer(), 'legacy_api': textgenerationwebui_settings.legacy_api && !isMancer(),
'sampler_order': isKoboldCpp() ? textgenerationwebui_settings.sampler_order : undefined, 'sampler_order': isKoboldCpp() ? textgenerationwebui_settings.sampler_order : undefined,

View File

@ -392,11 +392,10 @@ function getTokenCacheObject() {
function getRemoteTokenizationParams(str) { function getRemoteTokenizationParams(str) {
return { return {
text: str, text: str,
api: main_api, main_api,
api_type: textgenerationwebui_settings.type,
url: getAPIServerUrl(), url: getAPIServerUrl(),
legacy_api: main_api === 'textgenerationwebui' && textgenerationwebui_settings.legacy_api && !isMancer(), legacy_api: main_api === 'textgenerationwebui' && textgenerationwebui_settings.legacy_api && !isMancer(),
use_tabby: main_api === 'textgenerationwebui' && isTabby(),
use_koboldcpp: main_api === 'textgenerationwebui' && isKoboldCpp(),
}; };
} }

View File

@ -177,19 +177,24 @@ function getOverrideHeaders(urlHost) {
* @param {string|null} server API server for new request * @param {string|null} server API server for new request
*/ */
function setAdditionalHeaders(request, args, server) { function setAdditionalHeaders(request, args, server) {
let headers = {}; let headers;
if (request.body.use_mancer) { switch (request.body.api_type) {
case 'mancer':
headers = getMancerHeaders(); headers = getMancerHeaders();
} else if (request.body.use_aphrodite) { break;
case 'aphrodite':
headers = getAphroditeHeaders(); headers = getAphroditeHeaders();
} else if (request.body.use_tabby) { break;
case 'tabby':
headers = getTabbyHeaders(); headers = getTabbyHeaders();
} else { break;
default:
headers = server ? getOverrideHeaders((new URL(server))?.host) : {}; headers = server ? getOverrideHeaders((new URL(server))?.host) : {};
break;
} }
args.headers = Object.assign(args.headers, headers); Object.assign(args.headers, headers);
} }
function humanizedISO8601DateTime(date) { function humanizedISO8601DateTime(date) {
@ -562,21 +567,20 @@ app.post('/api/textgenerationwebui/status', jsonParser, async function (request,
if (request.body.legacy_api) { if (request.body.legacy_api) {
url += '/v1/model'; url += '/v1/model';
} } else {
else if (request.body.use_ooba) { switch (request.body.api_type) {
case 'ooba':
case 'aphrodite':
case 'koboldcpp':
url += '/v1/models'; url += '/v1/models';
} break;
else if (request.body.use_aphrodite) { case 'mancer':
url += '/v1/models';
}
else if (request.body.use_mancer) {
url += '/oai/v1/models'; url += '/oai/v1/models';
} break;
else if (request.body.use_tabby) { case 'tabby':
url += '/v1/model/list'; url += '/v1/model/list';
break;
} }
else if (request.body.use_koboldcpp) {
url += '/v1/models';
} }
const modelsReply = await fetch(url, args); const modelsReply = await fetch(url, args);
@ -604,7 +608,7 @@ app.post('/api/textgenerationwebui/status', jsonParser, async function (request,
// Set result to the first model ID // Set result to the first model ID
result = modelIds[0] || 'Valid'; result = modelIds[0] || 'Valid';
if (request.body.use_ooba) { if (request.body.api_type === 'ooba') {
try { try {
const modelInfoUrl = baseUrl + '/v1/internal/model/info'; const modelInfoUrl = baseUrl + '/v1/internal/model/info';
const modelInfoReply = await fetch(modelInfoUrl, args); const modelInfoReply = await fetch(modelInfoUrl, args);
@ -619,9 +623,7 @@ app.post('/api/textgenerationwebui/status', jsonParser, async function (request,
} catch (error) { } catch (error) {
console.error(`Failed to get Ooba model info: ${error}`); console.error(`Failed to get Ooba model info: ${error}`);
} }
} } else if (request.body.api_type === 'tabby') {
if (request.body.use_tabby) {
try { try {
const modelInfoUrl = baseUrl + '/v1/model'; const modelInfoUrl = baseUrl + '/v1/model';
const modelInfoReply = await fetch(modelInfoUrl, args); const modelInfoReply = await fetch(modelInfoUrl, args);
@ -671,12 +673,18 @@ app.post('/api/textgenerationwebui/generate', jsonParser, async function (reques
if (request.body.legacy_api) { if (request.body.legacy_api) {
url += '/v1/generate'; url += '/v1/generate';
} } else {
else if (request.body.use_aphrodite || request.body.use_ooba || request.body.use_tabby || request.body.use_koboldcpp) { switch (request.body.api_type) {
case 'aphrodite':
case 'ooba':
case 'tabby':
case 'koboldcpp':
url += '/v1/completions'; url += '/v1/completions';
} break;
else if (request.body.use_mancer) { case 'mancer':
url += '/oai/v1/completions'; url += '/oai/v1/completions';
break;
}
} }
const args = { const args = {
@ -3471,7 +3479,7 @@ app.post('/tokenize_via_api', jsonParser, async function (request, response) {
return response.sendStatus(400); return response.sendStatus(400);
} }
const text = String(request.body.text) || ''; const text = String(request.body.text) || '';
const api = String(request.body.api); const api = String(request.body.main_api);
const baseUrl = String(request.body.url); const baseUrl = String(request.body.url);
const legacyApi = Boolean(request.body.legacy_api); const legacyApi = Boolean(request.body.legacy_api);
@ -3490,18 +3498,21 @@ app.post('/tokenize_via_api', jsonParser, async function (request, response) {
if (legacyApi) { if (legacyApi) {
url += '/v1/token-count'; url += '/v1/token-count';
args.body = JSON.stringify({ 'prompt': text }); args.body = JSON.stringify({ 'prompt': text });
} } else {
else if (request.body.use_tabby) { switch (request.body.api_type) {
case 'tabby':
url += '/v1/token/encode'; url += '/v1/token/encode';
args.body = JSON.stringify({ 'text': text }); args.body = JSON.stringify({ 'text': text });
} break;
else if (request.body.use_koboldcpp) { case 'koboldcpp':
url += '/api/extra/tokencount'; url += '/api/extra/tokencount';
args.body = JSON.stringify({ 'prompt': text }); args.body = JSON.stringify({ 'prompt': text });
} break;
else { default:
url += '/v1/internal/encode'; url += '/v1/internal/encode';
args.body = JSON.stringify({ 'text': text }); args.body = JSON.stringify({ 'text': text });
break;
}
} }
const result = await fetch(url, args); const result = await fetch(url, args);