Fix ooba tokenization via API. Fix requiring streaming URL to generate

This commit is contained in:
Cohee 2023-11-08 03:38:04 +02:00
parent b2629d9718
commit 865256f5c0
4 changed files with 61 additions and 37 deletions

View File

@ -205,7 +205,6 @@
<i data-newbie-hidden data-preset-manager-delete="textgenerationwebui" class="menu_button fa-solid fa-trash-can" title="Delete the preset" data-i18n="[title]Delete the preset"></i> <i data-newbie-hidden data-preset-manager-delete="textgenerationwebui" class="menu_button fa-solid fa-trash-can" title="Delete the preset" data-i18n="[title]Delete the preset"></i>
</div> </div>
</div> </div>
<hr>
</div> </div>
<div data-newbie-hidden id="ai_module_block_novel" class="width100p"> <div data-newbie-hidden id="ai_module_block_novel" class="width100p">
@ -342,7 +341,6 @@
</div> </div>
</div> </div>
</div> </div>
<hr>
<div id="range_block_novel"> <div id="range_block_novel">
<div class="range-block"> <div class="range-block">
<label class="checkbox_label widthFreeExpand"> <label class="checkbox_label widthFreeExpand">

View File

@ -869,15 +869,7 @@ async function getStatus() {
const url = main_api == "textgenerationwebui" ? '/api/textgenerationwebui/status' : '/getstatus'; const url = main_api == "textgenerationwebui" ? '/api/textgenerationwebui/status' : '/getstatus';
let endpoint = api_server; let endpoint = getAPIServerUrl();
if (main_api == "textgenerationwebui") {
endpoint = api_server_textgenerationwebui;
}
if (main_api == "textgenerationwebui" && isMancer()) {
endpoint = MANCER_SERVER;
}
if (!endpoint) { if (!endpoint) {
console.warn("No endpoint for status check"); console.warn("No endpoint for status check");
@ -946,6 +938,22 @@ export function resultCheckStatus() {
stopStatusLoading(); stopStatusLoading();
} }
export function getAPIServerUrl() {
if (main_api == "textgenerationwebui") {
if (isMancer()) {
return MANCER_SERVER;
}
return api_server_textgenerationwebui;
}
if (main_api == "kobold") {
return api_server;
}
return "";
}
export async function selectCharacterById(id) { export async function selectCharacterById(id) {
if (characters[id] == undefined) { if (characters[id] == undefined) {
return; return;
@ -2534,16 +2542,6 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
return; return;
} }
if (
main_api == 'textgenerationwebui' &&
textgenerationwebui_settings.streaming &&
textgenerationwebui_settings.type === textgen_types.OOBA &&
!textgenerationwebui_settings.streaming_url) {
toastr.error('Streaming URL is not set. Look it up in the console window when starting TextGen Web UI');
unblockGeneration();
return;
}
if (main_api == 'kobold' && kai_settings.streaming_kobold && !kai_flags.can_use_streaming) { if (main_api == 'kobold' && kai_settings.streaming_kobold && !kai_flags.can_use_streaming) {
toastr.error('Streaming is enabled, but the version of Kobold used does not support token streaming.', undefined, { timeOut: 10000, preventDuplicates: true, }); toastr.error('Streaming is enabled, but the version of Kobold used does not support token streaming.', undefined, { timeOut: 10000, preventDuplicates: true, });
unblockGeneration(); unblockGeneration();

View File

@ -1,4 +1,4 @@
import { characters, main_api, nai_settings, online_status, this_chid } from "../script.js"; import { api_server, api_server_textgenerationwebui, characters, getAPIServerUrl, main_api, nai_settings, online_status, this_chid } from "../script.js";
import { power_user, registerDebugFunction } from "./power-user.js"; import { power_user, registerDebugFunction } from "./power-user.js";
import { chat_completion_sources, model_list, oai_settings } from "./openai.js"; import { chat_completion_sources, model_list, oai_settings } from "./openai.js";
import { groups, selected_group } from "./group-chats.js"; import { groups, selected_group } from "./group-chats.js";
@ -376,7 +376,11 @@ function countTokensRemote(endpoint, str, padding) {
async: false, async: false,
type: 'POST', type: 'POST',
url: endpoint, url: endpoint,
data: JSON.stringify({ text: str }), data: JSON.stringify({
text: str,
api: main_api,
url: getAPIServerUrl(),
}),
dataType: "json", dataType: "json",
contentType: "application/json", contentType: "application/json",
success: function (data) { success: function (data) {

View File

@ -3324,27 +3324,57 @@ app.post("/tokenize_via_api", jsonParser, async function (request, response) {
return response.sendStatus(400); return response.sendStatus(400);
} }
const text = request.body.text || ''; const text = request.body.text || '';
const api = request.body.api;
const baseUrl = request.body.url;
try { try {
if (api == 'textgenerationwebui') {
const args = { const args = {
method: 'POST',
body: JSON.stringify({ "prompt": text }), body: JSON.stringify({ "prompt": text }),
headers: { "Content-Type": "application/json" } headers: { "Content-Type": "application/json" }
}; };
if (main_api == 'textgenerationwebui') {
setAdditionalHeaders(request, args, null); setAdditionalHeaders(request, args, null);
const data = await postAsync(api_server + "/v1/token-count", args); const url = new URL(baseUrl);
url.pathname = '/api/v1/token-count'
const result = await fetch(url, args);
if (!result.ok) {
console.log(`API returned error: ${result.status} ${result.statusText}`);
return response.send({ error: true });
}
const data = await result.json();
return response.send({ count: data['results'][0]['tokens'] }); return response.send({ count: data['results'][0]['tokens'] });
} }
else if (main_api == 'kobold') { else if (api == 'kobold') {
const data = await postAsync(api_server + "/extra/tokencount", args); const args = {
method: 'POST',
body: JSON.stringify({ "prompt": text }),
headers: { "Content-Type": "application/json" }
};
const url = new URL(baseUrl);
url.pathname = '/api/extra/tokencount';
const result = await fetch(url, args);
if (!result.ok) {
console.log(`API returned error: ${result.status} ${result.statusText}`);
return response.send({ error: true });
}
const data = await result.json();
const count = data['value']; const count = data['value'];
return response.send({ count: count }); return response.send({ count: count });
} }
else { else {
console.log('Unknown API', api);
return response.send({ error: true }); return response.send({ error: true });
} }
} catch (error) { } catch (error) {
@ -3371,12 +3401,6 @@ async function fetchJSON(url, args = {}) {
throw response; throw response;
} }
/**
* Convenience function for fetch requests (default POST with no timeout) returning as JSON.
* @param {string} url
* @param {import('node-fetch').RequestInit} args
*/
async function postAsync(url, args) { return fetchJSON(url, { method: 'POST', timeout: 0, ...args }) }
// ** END ** // ** END **