Merge pull request #1519 from valadaptive/separate-kobold-endpoints

Move Kobold endpoints into their own module
This commit is contained in:
Cohee 2023-12-14 02:15:41 +02:00 committed by GitHub
commit 875760eadf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 195 additions and 187 deletions

View File

@ -891,7 +891,7 @@ async function getStatusKobold() {
} }
try { try {
const response = await fetch('/getstatus', { const response = await fetch('/api/backends/kobold/status', {
method: 'POST', method: 'POST',
headers: getRequestHeaders(), headers: getRequestHeaders(),
body: JSON.stringify({ body: JSON.stringify({
@ -4432,7 +4432,7 @@ function setInContextMessages(lastmsg, type) {
function getGenerateUrl(api) { function getGenerateUrl(api) {
let generate_url = ''; let generate_url = '';
if (api == 'kobold') { if (api == 'kobold') {
generate_url = '/generate'; generate_url = '/api/backends/kobold/generate';
} else if (api == 'textgenerationwebui') { } else if (api == 'textgenerationwebui') {
generate_url = '/api/backends/text-completions/generate'; generate_url = '/api/backends/text-completions/generate';
} else if (api == 'novel') { } else if (api == 'novel') {

View File

@ -173,7 +173,7 @@ function tryParseStreamingError(response, decoded) {
} }
export async function generateKoboldWithStreaming(generate_data, signal) { export async function generateKoboldWithStreaming(generate_data, signal) {
const response = await fetch('/generate', { const response = await fetch('/api/backends/kobold/generate', {
headers: getRequestHeaders(), headers: getRequestHeaders(),
body: JSON.stringify(generate_data), body: JSON.stringify(generate_data),
method: 'POST', method: 'POST',

187
server.js
View File

@ -45,7 +45,6 @@ const { jsonParser, urlencodedParser } = require('./src/express-common.js');
const contentManager = require('./src/endpoints/content-manager'); const contentManager = require('./src/endpoints/content-manager');
const { readSecret, migrateSecrets, SECRET_KEYS } = require('./src/endpoints/secrets'); const { readSecret, migrateSecrets, SECRET_KEYS } = require('./src/endpoints/secrets');
const { const {
delay,
getVersion, getVersion,
getConfigValue, getConfigValue,
color, color,
@ -61,7 +60,6 @@ const {
const { ensureThumbnailCache } = require('./src/endpoints/thumbnails'); const { ensureThumbnailCache } = require('./src/endpoints/thumbnails');
const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS, getSentencepiceTokenizer, sentencepieceTokenizers } = require('./src/endpoints/tokenizers'); const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS, getSentencepiceTokenizer, sentencepieceTokenizers } = require('./src/endpoints/tokenizers');
const { convertClaudePrompt } = require('./src/chat-completion'); const { convertClaudePrompt } = require('./src/chat-completion');
const { getOverrideHeaders, setAdditionalHeaders } = require('./src/additional-headers');
// Work around a node v20.0.0, v20.1.0, and v20.2.0 bug. The issue was fixed in v20.3.0. // Work around a node v20.0.0, v20.1.0, and v20.2.0 bug. The issue was fixed in v20.3.0.
// https://github.com/nodejs/node/issues/47822#issuecomment-1564708870 // https://github.com/nodejs/node/issues/47822#issuecomment-1564708870
@ -312,188 +310,6 @@ app.get('/version', async function (_, response) {
response.send(data); response.send(data);
}); });
//**************Kobold api
app.post('/generate', jsonParser, async function (request, response_generate) {
if (!request.body) return response_generate.sendStatus(400);
if (request.body.api_server.indexOf('localhost') != -1) {
request.body.api_server = request.body.api_server.replace('localhost', '127.0.0.1');
}
const request_prompt = request.body.prompt;
const controller = new AbortController();
request.socket.removeAllListeners('close');
request.socket.on('close', async function () {
if (request.body.can_abort && !response_generate.writableEnded) {
try {
console.log('Aborting Kobold generation...');
// send abort signal to koboldcpp
const abortResponse = await fetch(`${request.body.api_server}/extra/abort`, {
method: 'POST',
});
if (!abortResponse.ok) {
console.log('Error sending abort request to Kobold:', abortResponse.status);
}
} catch (error) {
console.log(error);
}
}
controller.abort();
});
let this_settings = {
prompt: request_prompt,
use_story: false,
use_memory: false,
use_authors_note: false,
use_world_info: false,
max_context_length: request.body.max_context_length,
max_length: request.body.max_length,
};
if (request.body.gui_settings == false) {
const sampler_order = [request.body.s1, request.body.s2, request.body.s3, request.body.s4, request.body.s5, request.body.s6, request.body.s7];
this_settings = {
prompt: request_prompt,
use_story: false,
use_memory: false,
use_authors_note: false,
use_world_info: false,
max_context_length: request.body.max_context_length,
max_length: request.body.max_length,
rep_pen: request.body.rep_pen,
rep_pen_range: request.body.rep_pen_range,
rep_pen_slope: request.body.rep_pen_slope,
temperature: request.body.temperature,
tfs: request.body.tfs,
top_a: request.body.top_a,
top_k: request.body.top_k,
top_p: request.body.top_p,
min_p: request.body.min_p,
typical: request.body.typical,
sampler_order: sampler_order,
singleline: !!request.body.singleline,
use_default_badwordsids: request.body.use_default_badwordsids,
mirostat: request.body.mirostat,
mirostat_eta: request.body.mirostat_eta,
mirostat_tau: request.body.mirostat_tau,
grammar: request.body.grammar,
sampler_seed: request.body.sampler_seed,
};
if (request.body.stop_sequence) {
this_settings['stop_sequence'] = request.body.stop_sequence;
}
}
console.log(this_settings);
const args = {
body: JSON.stringify(this_settings),
headers: Object.assign(
{ 'Content-Type': 'application/json' },
getOverrideHeaders((new URL(request.body.api_server))?.host),
),
signal: controller.signal,
};
const MAX_RETRIES = 50;
const delayAmount = 2500;
for (let i = 0; i < MAX_RETRIES; i++) {
try {
const url = request.body.streaming ? `${request.body.api_server}/extra/generate/stream` : `${request.body.api_server}/v1/generate`;
const response = await fetch(url, { method: 'POST', timeout: 0, ...args });
if (request.body.streaming) {
// Pipe remote SSE stream to Express response
forwardFetchResponse(response, response_generate);
return;
} else {
if (!response.ok) {
const errorText = await response.text();
console.log(`Kobold returned error: ${response.status} ${response.statusText} ${errorText}`);
try {
const errorJson = JSON.parse(errorText);
const message = errorJson?.detail?.msg || errorText;
return response_generate.status(400).send({ error: { message } });
} catch {
return response_generate.status(400).send({ error: { message: errorText } });
}
}
const data = await response.json();
console.log('Endpoint response:', data);
return response_generate.send(data);
}
} catch (error) {
// response
switch (error?.status) {
case 403:
case 503: // retry in case of temporary service issue, possibly caused by a queue failure?
console.debug(`KoboldAI is busy. Retry attempt ${i + 1} of ${MAX_RETRIES}...`);
await delay(delayAmount);
break;
default:
if ('status' in error) {
console.log('Status Code from Kobold:', error.status);
}
return response_generate.send({ error: true });
}
}
}
console.log('Max retries exceeded. Giving up.');
return response_generate.send({ error: true });
});
// Only called for kobold
app.post('/getstatus', jsonParser, async function (request, response) {
if (!request.body) return response.sendStatus(400);
let api_server = request.body.api_server;
if (api_server.indexOf('localhost') != -1) {
api_server = api_server.replace('localhost', '127.0.0.1');
}
const args = {
headers: { 'Content-Type': 'application/json' },
};
setAdditionalHeaders(request, args, api_server);
const result = {};
const [koboldUnitedResponse, koboldExtraResponse, koboldModelResponse] = await Promise.all([
// We catch errors both from the response not having a successful HTTP status and from JSON parsing failing
// Kobold United API version
fetch(`${api_server}/v1/info/version`).then(response => {
if (!response.ok) throw new Error(`Kobold API error: ${response.status, response.statusText}`);
return response.json();
}).catch(() => ({ result: '0.0.0' })),
// KoboldCpp version
fetch(`${api_server}/extra/version`).then(response => {
if (!response.ok) throw new Error(`Kobold API error: ${response.status, response.statusText}`);
return response.json();
}).catch(() => ({ version: '0.0' })),
// Current model
fetch(`${api_server}/v1/model`).then(response => {
if (!response.ok) throw new Error(`Kobold API error: ${response.status, response.statusText}`);
return response.json();
}).catch(() => null),
]);
result.koboldUnitedVersion = koboldUnitedResponse.result;
result.koboldCppVersion = koboldExtraResponse.result;
result.model = !koboldModelResponse || koboldModelResponse.result === 'ReadOnly' ?
'no_connection' :
koboldModelResponse.result;
response.send(result);
});
app.post('/getuseravatars', jsonParser, function (request, response) { app.post('/getuseravatars', jsonParser, function (request, response) {
var images = getImages('public/User Avatars'); var images = getImages('public/User Avatars');
response.send(JSON.stringify(images)); response.send(JSON.stringify(images));
@ -1624,6 +1440,9 @@ app.use('/api/serpapi', require('./src/endpoints/serpapi').router);
// Ooba/OpenAI text completions // Ooba/OpenAI text completions
app.use('/api/backends/text-completions', require('./src/endpoints/backends/text-completions').router); app.use('/api/backends/text-completions', require('./src/endpoints/backends/text-completions').router);
// KoboldAI
app.use('/api/backends/kobold', require('./src/endpoints/backends/kobold').router);
const tavernUrl = new URL( const tavernUrl = new URL(
(cliArguments.ssl ? 'https://' : 'http://') + (cliArguments.ssl ? 'https://' : 'http://') +
(listen ? '0.0.0.0' : '127.0.0.1') + (listen ? '0.0.0.0' : '127.0.0.1') +

View File

@ -0,0 +1,189 @@
const express = require('express');
const fetch = require('node-fetch').default;
const { jsonParser } = require('../../express-common');
const { forwardFetchResponse, delay } = require('../../util');
const { getOverrideHeaders, setAdditionalHeaders } = require('../../additional-headers');
const router = express.Router();
router.post('/generate', jsonParser, async function (request, response_generate) {
if (!request.body) return response_generate.sendStatus(400);
if (request.body.api_server.indexOf('localhost') != -1) {
request.body.api_server = request.body.api_server.replace('localhost', '127.0.0.1');
}
const request_prompt = request.body.prompt;
const controller = new AbortController();
request.socket.removeAllListeners('close');
request.socket.on('close', async function () {
if (request.body.can_abort && !response_generate.writableEnded) {
try {
console.log('Aborting Kobold generation...');
// send abort signal to koboldcpp
const abortResponse = await fetch(`${request.body.api_server}/extra/abort`, {
method: 'POST',
});
if (!abortResponse.ok) {
console.log('Error sending abort request to Kobold:', abortResponse.status);
}
} catch (error) {
console.log(error);
}
}
controller.abort();
});
let this_settings = {
prompt: request_prompt,
use_story: false,
use_memory: false,
use_authors_note: false,
use_world_info: false,
max_context_length: request.body.max_context_length,
max_length: request.body.max_length,
};
if (request.body.gui_settings == false) {
const sampler_order = [request.body.s1, request.body.s2, request.body.s3, request.body.s4, request.body.s5, request.body.s6, request.body.s7];
this_settings = {
prompt: request_prompt,
use_story: false,
use_memory: false,
use_authors_note: false,
use_world_info: false,
max_context_length: request.body.max_context_length,
max_length: request.body.max_length,
rep_pen: request.body.rep_pen,
rep_pen_range: request.body.rep_pen_range,
rep_pen_slope: request.body.rep_pen_slope,
temperature: request.body.temperature,
tfs: request.body.tfs,
top_a: request.body.top_a,
top_k: request.body.top_k,
top_p: request.body.top_p,
min_p: request.body.min_p,
typical: request.body.typical,
sampler_order: sampler_order,
singleline: !!request.body.singleline,
use_default_badwordsids: request.body.use_default_badwordsids,
mirostat: request.body.mirostat,
mirostat_eta: request.body.mirostat_eta,
mirostat_tau: request.body.mirostat_tau,
grammar: request.body.grammar,
sampler_seed: request.body.sampler_seed,
};
if (request.body.stop_sequence) {
this_settings['stop_sequence'] = request.body.stop_sequence;
}
}
console.log(this_settings);
const args = {
body: JSON.stringify(this_settings),
headers: Object.assign(
{ 'Content-Type': 'application/json' },
getOverrideHeaders((new URL(request.body.api_server))?.host),
),
signal: controller.signal,
};
const MAX_RETRIES = 50;
const delayAmount = 2500;
for (let i = 0; i < MAX_RETRIES; i++) {
try {
const url = request.body.streaming ? `${request.body.api_server}/extra/generate/stream` : `${request.body.api_server}/v1/generate`;
const response = await fetch(url, { method: 'POST', timeout: 0, ...args });
if (request.body.streaming) {
// Pipe remote SSE stream to Express response
forwardFetchResponse(response, response_generate);
return;
} else {
if (!response.ok) {
const errorText = await response.text();
console.log(`Kobold returned error: ${response.status} ${response.statusText} ${errorText}`);
try {
const errorJson = JSON.parse(errorText);
const message = errorJson?.detail?.msg || errorText;
return response_generate.status(400).send({ error: { message } });
} catch {
return response_generate.status(400).send({ error: { message: errorText } });
}
}
const data = await response.json();
console.log('Endpoint response:', data);
return response_generate.send(data);
}
} catch (error) {
// response
switch (error?.status) {
case 403:
case 503: // retry in case of temporary service issue, possibly caused by a queue failure?
console.debug(`KoboldAI is busy. Retry attempt ${i + 1} of ${MAX_RETRIES}...`);
await delay(delayAmount);
break;
default:
if ('status' in error) {
console.log('Status Code from Kobold:', error.status);
}
return response_generate.send({ error: true });
}
}
}
console.log('Max retries exceeded. Giving up.');
return response_generate.send({ error: true });
});
router.post('/status', jsonParser, async function (request, response) {
if (!request.body) return response.sendStatus(400);
let api_server = request.body.api_server;
if (api_server.indexOf('localhost') != -1) {
api_server = api_server.replace('localhost', '127.0.0.1');
}
const args = {
headers: { 'Content-Type': 'application/json' },
};
setAdditionalHeaders(request, args, api_server);
const result = {};
const [koboldUnitedResponse, koboldExtraResponse, koboldModelResponse] = await Promise.all([
// We catch errors both from the response not having a successful HTTP status and from JSON parsing failing
// Kobold United API version
fetch(`${api_server}/v1/info/version`).then(response => {
if (!response.ok) throw new Error(`Kobold API error: ${response.status, response.statusText}`);
return response.json();
}).catch(() => ({ result: '0.0.0' })),
// KoboldCpp version
fetch(`${api_server}/extra/version`).then(response => {
if (!response.ok) throw new Error(`Kobold API error: ${response.status, response.statusText}`);
return response.json();
}).catch(() => ({ version: '0.0' })),
// Current model
fetch(`${api_server}/v1/model`).then(response => {
if (!response.ok) throw new Error(`Kobold API error: ${response.status, response.statusText}`);
return response.json();
}).catch(() => null),
]);
result.koboldUnitedVersion = koboldUnitedResponse.result;
result.koboldCppVersion = koboldExtraResponse.result;
result.model = !koboldModelResponse || koboldModelResponse.result === 'ReadOnly' ?
'no_connection' :
koboldModelResponse.result;
response.send(result);
});
module.exports = { router };