Split up Kobold and textgenerationwebui endpoints

The `/api/tokenizers/remote/encode` endpoint was one big if/else statement
that did two entirely different things depending on the value of `main_api`.
It makes more sense for those to be two separate endpoints:
`/remote/kobold/count` and `/remote/textgenerationwebui/encode`.
This commit is contained in:
valadaptive
2023-12-09 20:26:24 -05:00
parent 7486ab3886
commit 30502ac949
2 changed files with 73 additions and 72 deletions

View File

@@ -1,4 +1,4 @@
import { characters, getAPIServerUrl, main_api, nai_settings, online_status, this_chid } from '../script.js'; import { characters, main_api, api_server, api_server_textgenerationwebui, nai_settings, online_status, this_chid } from '../script.js';
import { power_user, registerDebugFunction } from './power-user.js'; import { power_user, registerDebugFunction } from './power-user.js';
import { chat_completion_sources, model_list, oai_settings } from './openai.js'; import { chat_completion_sources, model_list, oai_settings } from './openai.js';
import { groups, selected_group } from './group-chats.js'; import { groups, selected_group } from './group-chats.js';
@@ -174,9 +174,9 @@ function callTokenizer(type, str, padding) {
case tokenizers.YI: case tokenizers.YI:
return countTokensFromServer('/api/tokenizers/yi/encode', str, padding); return countTokensFromServer('/api/tokenizers/yi/encode', str, padding);
case tokenizers.API_KOBOLD: case tokenizers.API_KOBOLD:
return countTokensFromKoboldAPI('/api/tokenizers/remote/encode', str, padding); return countTokensFromKoboldAPI('/api/tokenizers/remote/kobold/count', str, padding);
case tokenizers.API_TEXTGENERATIONWEBUI: case tokenizers.API_TEXTGENERATIONWEBUI:
return countTokensFromTextgenAPI('/api/tokenizers/remote/encode', str, padding); return countTokensFromTextgenAPI('/api/tokenizers/remote/textgenerationwebui/encode', str, padding);
default: default:
console.warn('Unknown tokenizer type', type); console.warn('Unknown tokenizer type', type);
return callTokenizer(tokenizers.NONE, str, padding); return callTokenizer(tokenizers.NONE, str, padding);
@@ -403,17 +403,15 @@ function getServerTokenizationParams(str) {
function getKoboldAPITokenizationParams(str) { function getKoboldAPITokenizationParams(str) {
return { return {
text: str, text: str,
main_api: 'kobold', url: api_server,
url: getAPIServerUrl(),
}; };
} }
function getTextgenAPITokenizationParams(str) { function getTextgenAPITokenizationParams(str) {
return { return {
text: str, text: str,
main_api: 'textgenerationwebui',
api_type: textgen_settings.type, api_type: textgen_settings.type,
url: getAPIServerUrl(), url: api_server_textgenerationwebui,
legacy_api: legacy_api:
textgen_settings.legacy_api && textgen_settings.legacy_api &&
textgen_settings.type !== MANCER, textgen_settings.type !== MANCER,
@@ -627,7 +625,7 @@ export function getTextTokens(tokenizerType, str) {
return getTextTokensFromServer('/api/tokenizers/openai/encode', str, model); return getTextTokensFromServer('/api/tokenizers/openai/encode', str, model);
} }
case tokenizers.API_TEXTGENERATIONWEBUI: case tokenizers.API_TEXTGENERATIONWEBUI:
return getTextTokensFromTextgenAPI('/api/tokenizers/remote/encode', str); return getTextTokensFromTextgenAPI('/api/tokenizers/textgenerationwebui/encode', str);
default: default:
console.warn('Calling getTextTokens with unsupported tokenizer type', tokenizerType); console.warn('Calling getTextTokens with unsupported tokenizer type', tokenizerType);
return []; return [];

View File

@@ -536,87 +536,90 @@ router.post('/openai/count', jsonParser, async function (req, res) {
} }
}); });
router.post('/remote/encode', jsonParser, async function (request, response) { router.post('/remote/kobold/count', jsonParser, async function (request, response) {
if (!request.body) {
return response.sendStatus(400);
}
const text = String(request.body.text) || '';
const baseUrl = String(request.body.url);
try {
const args = {
method: 'POST',
body: JSON.stringify({ 'prompt': text }),
headers: { 'Content-Type': 'application/json' },
};
let url = String(baseUrl).replace(/\/$/, '');
url += '/extra/tokencount';
const result = await fetch(url, args);
if (!result.ok) {
console.log(`API returned error: ${result.status} ${result.statusText}`);
return response.send({ error: true });
}
const data = await result.json();
const count = data['value'];
return response.send({ count, ids: [] });
} catch (error) {
console.log(error);
return response.send({ error: true });
}
});
router.post('/remote/textgenerationwebui/encode', jsonParser, async function (request, response) {
if (!request.body) { if (!request.body) {
return response.sendStatus(400); return response.sendStatus(400);
} }
const text = String(request.body.text) || ''; const text = String(request.body.text) || '';
const api = String(request.body.main_api);
const baseUrl = String(request.body.url); const baseUrl = String(request.body.url);
const legacyApi = Boolean(request.body.legacy_api); const legacyApi = Boolean(request.body.legacy_api);
try { try {
if (api == 'textgenerationwebui') { const args = {
const args = { method: 'POST',
method: 'POST', headers: { 'Content-Type': 'application/json' },
headers: { 'Content-Type': 'application/json' }, };
};
setAdditionalHeaders(request, args, null); setAdditionalHeaders(request, args, null);
// Convert to string + remove trailing slash + /v1 suffix // Convert to string + remove trailing slash + /v1 suffix
let url = String(baseUrl).replace(/\/$/, '').replace(/\/v1$/, ''); let url = String(baseUrl).replace(/\/$/, '').replace(/\/v1$/, '');
if (legacyApi) { if (legacyApi) {
url += '/v1/token-count'; url += '/v1/token-count';
args.body = JSON.stringify({ 'prompt': text }); args.body = JSON.stringify({ 'prompt': text });
} else { } else {
switch (request.body.api_type) { switch (request.body.api_type) {
case TEXTGEN_TYPES.TABBY: case TEXTGEN_TYPES.TABBY:
url += '/v1/token/encode'; url += '/v1/token/encode';
args.body = JSON.stringify({ 'text': text }); args.body = JSON.stringify({ 'text': text });
break; break;
case TEXTGEN_TYPES.KOBOLDCPP: case TEXTGEN_TYPES.KOBOLDCPP:
url += '/api/extra/tokencount'; url += '/api/extra/tokencount';
args.body = JSON.stringify({ 'prompt': text }); args.body = JSON.stringify({ 'prompt': text });
break; break;
default: default:
url += '/v1/internal/encode'; url += '/v1/internal/encode';
args.body = JSON.stringify({ 'text': text }); args.body = JSON.stringify({ 'text': text });
break; break;
}
} }
const result = await fetch(url, args);
if (!result.ok) {
console.log(`API returned error: ${result.status} ${result.statusText}`);
return response.send({ error: true });
}
const data = await result.json();
const count = legacyApi ? data?.results[0]?.tokens : (data?.length ?? data?.value);
const ids = legacyApi ? [] : (data?.tokens ?? []);
return response.send({ count, ids });
} }
else if (api == 'kobold') { const result = await fetch(url, args);
const args = {
method: 'POST',
body: JSON.stringify({ 'prompt': text }),
headers: { 'Content-Type': 'application/json' },
};
let url = String(baseUrl).replace(/\/$/, ''); if (!result.ok) {
url += '/extra/tokencount'; console.log(`API returned error: ${result.status} ${result.statusText}`);
const result = await fetch(url, args);
if (!result.ok) {
console.log(`API returned error: ${result.status} ${result.statusText}`);
return response.send({ error: true });
}
const data = await result.json();
const count = data['value'];
return response.send({ count: count, ids: [] });
}
else {
console.log('Unknown API', api);
return response.send({ error: true }); return response.send({ error: true });
} }
const data = await result.json();
const count = legacyApi ? data?.results[0]?.tokens : (data?.length ?? data?.value);
const ids = legacyApi ? [] : (data?.tokens ?? []);
return response.send({ count, ids });
} catch (error) { } catch (error) {
console.log(error); console.log(error);
return response.send({ error: true }); return response.send({ error: true });