Use Express router for novelai endpoint

This commit is contained in:
valadaptive 2023-12-04 12:52:27 -05:00
parent 414c9bd5fb
commit e6b549bc48
2 changed files with 271 additions and 276 deletions

View File

@ -3591,7 +3591,7 @@ require('./src/endpoints/secrets').registerEndpoints(app, jsonParser);
require('./src/endpoints/thumbnails').registerEndpoints(app, jsonParser);
// NovelAI generation
require('./src/endpoints/novelai').registerEndpoints(app, jsonParser);
app.use('/api/novelai', require('./src/endpoints/novelai').router);
// Third-party extensions
require('./src/endpoints/extensions').registerEndpoints(app, jsonParser);

View File

@ -1,8 +1,10 @@
const fetch = require('node-fetch').default;
const express = require('express');
const util = require('util');
const { Readable } = require('stream');
const { readSecret, SECRET_KEYS } = require('./secrets');
const { readAllChunks, extractFileFromZipBuffer } = require('../util');
const { jsonParser } = require('../express-common');
const API_NOVELAI = 'https://api.novelai.net';
@ -60,312 +62,305 @@ function getBadWordsList(model) {
return list.slice();
}
/**
* Registers NovelAI API endpoints.
* @param {import('express').Express} app - Express app
* @param {any} jsonParser - JSON parser middleware
*/
function registerEndpoints(app, jsonParser) {
app.post('/api/novelai/status', jsonParser, async function (req, res) {
if (!req.body) return res.sendStatus(400);
const api_key_novel = readSecret(SECRET_KEYS.NOVEL);
const router = express.Router();
if (!api_key_novel) {
return res.sendStatus(401);
}
router.post('/status', jsonParser, async function (req, res) {
if (!req.body) return res.sendStatus(400);
const api_key_novel = readSecret(SECRET_KEYS.NOVEL);
try {
const response = await fetch(API_NOVELAI + '/user/subscription', {
method: 'GET',
headers: {
'Content-Type': 'application/json',
'Authorization': 'Bearer ' + api_key_novel,
},
});
if (!api_key_novel) {
return res.sendStatus(401);
}
if (response.ok) {
const data = await response.json();
return res.send(data);
} else if (response.status == 401) {
console.log('NovelAI Access Token is incorrect.');
return res.send({ error: true });
}
else {
console.log('NovelAI returned an error:', response.statusText);
return res.send({ error: true });
}
} catch (error) {
console.log(error);
return res.send({ error: true });
}
});
app.post('/api/novelai/generate', jsonParser, async function (req, res) {
if (!req.body) return res.sendStatus(400);
const api_key_novel = readSecret(SECRET_KEYS.NOVEL);
if (!api_key_novel) {
return res.sendStatus(401);
}
const controller = new AbortController();
req.socket.removeAllListeners('close');
req.socket.on('close', function () {
controller.abort();
try {
const response = await fetch(API_NOVELAI + '/user/subscription', {
method: 'GET',
headers: {
'Content-Type': 'application/json',
'Authorization': 'Bearer ' + api_key_novel,
},
});
const isNewModel = (req.body.model.includes('clio') || req.body.model.includes('kayra'));
const badWordsList = getBadWordsList(req.body.model);
// Add customized bad words for Clio and Kayra
if (isNewModel && Array.isArray(req.body.bad_words_ids)) {
for (const badWord of req.body.bad_words_ids) {
if (Array.isArray(badWord) && badWord.every(x => Number.isInteger(x))) {
badWordsList.push(badWord);
}
}
}
// Remove empty arrays from bad words list
for (const badWord of badWordsList) {
if (badWord.length === 0) {
badWordsList.splice(badWordsList.indexOf(badWord), 1);
}
}
// Add default biases for dinkus and asterism
const logit_bias_exp = isNewModel ? logitBiasExp.slice() : [];
if (Array.isArray(logit_bias_exp) && Array.isArray(req.body.logit_bias_exp)) {
logit_bias_exp.push(...req.body.logit_bias_exp);
}
const data = {
'input': req.body.input,
'model': req.body.model,
'parameters': {
'use_string': req.body.use_string ?? true,
'temperature': req.body.temperature,
'max_length': req.body.max_length,
'min_length': req.body.min_length,
'tail_free_sampling': req.body.tail_free_sampling,
'repetition_penalty': req.body.repetition_penalty,
'repetition_penalty_range': req.body.repetition_penalty_range,
'repetition_penalty_slope': req.body.repetition_penalty_slope,
'repetition_penalty_frequency': req.body.repetition_penalty_frequency,
'repetition_penalty_presence': req.body.repetition_penalty_presence,
'repetition_penalty_whitelist': isNewModel ? repPenaltyAllowList : null,
'top_a': req.body.top_a,
'top_p': req.body.top_p,
'top_k': req.body.top_k,
'typical_p': req.body.typical_p,
'mirostat_lr': req.body.mirostat_lr,
'mirostat_tau': req.body.mirostat_tau,
'cfg_scale': req.body.cfg_scale,
'cfg_uc': req.body.cfg_uc,
'phrase_rep_pen': req.body.phrase_rep_pen,
'stop_sequences': req.body.stop_sequences,
'bad_words_ids': badWordsList.length ? badWordsList : null,
'logit_bias_exp': logit_bias_exp,
'generate_until_sentence': req.body.generate_until_sentence,
'use_cache': req.body.use_cache,
'return_full_text': req.body.return_full_text,
'prefix': req.body.prefix,
'order': req.body.order,
},
};
console.log(util.inspect(data, { depth: 4 }));
const args = {
body: JSON.stringify(data),
headers: { 'Content-Type': 'application/json', 'Authorization': 'Bearer ' + api_key_novel },
signal: controller.signal,
};
try {
const url = req.body.streaming ? `${API_NOVELAI}/ai/generate-stream` : `${API_NOVELAI}/ai/generate`;
const response = await fetch(url, { method: 'POST', timeout: 0, ...args });
if (req.body.streaming) {
// Pipe remote SSE stream to Express response
response.body.pipe(res);
req.socket.on('close', function () {
if (response.body instanceof Readable) response.body.destroy(); // Close the remote stream
res.end(); // End the Express response
});
response.body.on('end', function () {
console.log('Streaming request finished');
res.end();
});
} else {
if (!response.ok) {
const text = await response.text();
let message = text;
console.log(`Novel API returned error: ${response.status} ${response.statusText} ${text}`);
try {
const data = JSON.parse(text);
message = data.message;
}
catch {
// ignore
}
return res.status(response.status).send({ error: { message } });
}
const data = await response.json();
console.log(data);
return res.send(data);
}
} catch (error) {
if (response.ok) {
const data = await response.json();
return res.send(data);
} else if (response.status == 401) {
console.log('NovelAI Access Token is incorrect.');
return res.send({ error: true });
}
else {
console.log('NovelAI returned an error:', response.statusText);
return res.send({ error: true });
}
} catch (error) {
console.log(error);
return res.send({ error: true });
}
});
router.post('/generate', jsonParser, async function (req, res) {
if (!req.body) return res.sendStatus(400);
const api_key_novel = readSecret(SECRET_KEYS.NOVEL);
if (!api_key_novel) {
return res.sendStatus(401);
}
const controller = new AbortController();
req.socket.removeAllListeners('close');
req.socket.on('close', function () {
controller.abort();
});
app.post('/api/novelai/generate-image', jsonParser, async (request, response) => {
if (!request.body) {
return response.sendStatus(400);
const isNewModel = (req.body.model.includes('clio') || req.body.model.includes('kayra'));
const badWordsList = getBadWordsList(req.body.model);
// Add customized bad words for Clio and Kayra
if (isNewModel && Array.isArray(req.body.bad_words_ids)) {
for (const badWord of req.body.bad_words_ids) {
if (Array.isArray(badWord) && badWord.every(x => Number.isInteger(x))) {
badWordsList.push(badWord);
}
}
}
// Remove empty arrays from bad words list
for (const badWord of badWordsList) {
if (badWord.length === 0) {
badWordsList.splice(badWordsList.indexOf(badWord), 1);
}
}
// Add default biases for dinkus and asterism
const logit_bias_exp = isNewModel ? logitBiasExp.slice() : [];
if (Array.isArray(logit_bias_exp) && Array.isArray(req.body.logit_bias_exp)) {
logit_bias_exp.push(...req.body.logit_bias_exp);
}
const data = {
'input': req.body.input,
'model': req.body.model,
'parameters': {
'use_string': req.body.use_string ?? true,
'temperature': req.body.temperature,
'max_length': req.body.max_length,
'min_length': req.body.min_length,
'tail_free_sampling': req.body.tail_free_sampling,
'repetition_penalty': req.body.repetition_penalty,
'repetition_penalty_range': req.body.repetition_penalty_range,
'repetition_penalty_slope': req.body.repetition_penalty_slope,
'repetition_penalty_frequency': req.body.repetition_penalty_frequency,
'repetition_penalty_presence': req.body.repetition_penalty_presence,
'repetition_penalty_whitelist': isNewModel ? repPenaltyAllowList : null,
'top_a': req.body.top_a,
'top_p': req.body.top_p,
'top_k': req.body.top_k,
'typical_p': req.body.typical_p,
'mirostat_lr': req.body.mirostat_lr,
'mirostat_tau': req.body.mirostat_tau,
'cfg_scale': req.body.cfg_scale,
'cfg_uc': req.body.cfg_uc,
'phrase_rep_pen': req.body.phrase_rep_pen,
'stop_sequences': req.body.stop_sequences,
'bad_words_ids': badWordsList.length ? badWordsList : null,
'logit_bias_exp': logit_bias_exp,
'generate_until_sentence': req.body.generate_until_sentence,
'use_cache': req.body.use_cache,
'return_full_text': req.body.return_full_text,
'prefix': req.body.prefix,
'order': req.body.order,
},
};
console.log(util.inspect(data, { depth: 4 }));
const args = {
body: JSON.stringify(data),
headers: { 'Content-Type': 'application/json', 'Authorization': 'Bearer ' + api_key_novel },
signal: controller.signal,
};
try {
const url = req.body.streaming ? `${API_NOVELAI}/ai/generate-stream` : `${API_NOVELAI}/ai/generate`;
const response = await fetch(url, { method: 'POST', timeout: 0, ...args });
if (req.body.streaming) {
// Pipe remote SSE stream to Express response
response.body.pipe(res);
req.socket.on('close', function () {
if (response.body instanceof Readable) response.body.destroy(); // Close the remote stream
res.end(); // End the Express response
});
response.body.on('end', function () {
console.log('Streaming request finished');
res.end();
});
} else {
if (!response.ok) {
const text = await response.text();
let message = text;
console.log(`Novel API returned error: ${response.status} ${response.statusText} ${text}`);
try {
const data = JSON.parse(text);
message = data.message;
}
catch {
// ignore
}
return res.status(response.status).send({ error: { message } });
}
const data = await response.json();
console.log(data);
return res.send(data);
}
} catch (error) {
return res.send({ error: true });
}
});
router.post('/generate-image', jsonParser, async (request, response) => {
if (!request.body) {
return response.sendStatus(400);
}
const key = readSecret(SECRET_KEYS.NOVEL);
if (!key) {
return response.sendStatus(401);
}
try {
console.log('NAI Diffusion request:', request.body);
const generateUrl = `${API_NOVELAI}/ai/generate-image`;
const generateResult = await fetch(generateUrl, {
method: 'POST',
headers: {
'Authorization': `Bearer ${key}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
action: 'generate',
input: request.body.prompt,
model: request.body.model ?? 'nai-diffusion',
parameters: {
negative_prompt: request.body.negative_prompt ?? '',
height: request.body.height ?? 512,
width: request.body.width ?? 512,
scale: request.body.scale ?? 9,
seed: Math.floor(Math.random() * 9999999999),
sampler: request.body.sampler ?? 'k_dpmpp_2m',
steps: request.body.steps ?? 28,
n_samples: 1,
// NAI handholding for prompts
ucPreset: 0,
qualityToggle: false,
add_original_image: false,
controlnet_strength: 1,
dynamic_thresholding: false,
legacy: false,
sm: false,
sm_dyn: false,
uncond_scale: 1,
},
}),
});
if (!generateResult.ok) {
const text = await generateResult.text();
console.log('NovelAI returned an error.', generateResult.statusText, text);
return response.sendStatus(500);
}
const key = readSecret(SECRET_KEYS.NOVEL);
const archiveBuffer = await generateResult.arrayBuffer();
const imageBuffer = await extractFileFromZipBuffer(archiveBuffer, '.png');
const originalBase64 = imageBuffer.toString('base64');
if (!key) {
return response.sendStatus(401);
// No upscaling
if (isNaN(request.body.upscale_ratio) || request.body.upscale_ratio <= 1) {
return response.send(originalBase64);
}
try {
console.log('NAI Diffusion request:', request.body);
const generateUrl = `${API_NOVELAI}/ai/generate-image`;
const generateResult = await fetch(generateUrl, {
console.debug('Upscaling image...');
const upscaleUrl = `${API_NOVELAI}/ai/upscale`;
const upscaleResult = await fetch(upscaleUrl, {
method: 'POST',
headers: {
'Authorization': `Bearer ${key}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
action: 'generate',
input: request.body.prompt,
model: request.body.model ?? 'nai-diffusion',
parameters: {
negative_prompt: request.body.negative_prompt ?? '',
height: request.body.height ?? 512,
width: request.body.width ?? 512,
scale: request.body.scale ?? 9,
seed: Math.floor(Math.random() * 9999999999),
sampler: request.body.sampler ?? 'k_dpmpp_2m',
steps: request.body.steps ?? 28,
n_samples: 1,
// NAI handholding for prompts
ucPreset: 0,
qualityToggle: false,
add_original_image: false,
controlnet_strength: 1,
dynamic_thresholding: false,
legacy: false,
sm: false,
sm_dyn: false,
uncond_scale: 1,
},
image: originalBase64,
height: request.body.height,
width: request.body.width,
scale: request.body.upscale_ratio,
}),
});
if (!generateResult.ok) {
const text = await generateResult.text();
console.log('NovelAI returned an error.', generateResult.statusText, text);
return response.sendStatus(500);
if (!upscaleResult.ok) {
throw new Error('NovelAI returned an error.');
}
const archiveBuffer = await generateResult.arrayBuffer();
const imageBuffer = await extractFileFromZipBuffer(archiveBuffer, '.png');
const originalBase64 = imageBuffer.toString('base64');
const upscaledArchiveBuffer = await upscaleResult.arrayBuffer();
const upscaledImageBuffer = await extractFileFromZipBuffer(upscaledArchiveBuffer, '.png');
const upscaledBase64 = upscaledImageBuffer.toString('base64');
// No upscaling
if (isNaN(request.body.upscale_ratio) || request.body.upscale_ratio <= 1) {
return response.send(originalBase64);
}
try {
console.debug('Upscaling image...');
const upscaleUrl = `${API_NOVELAI}/ai/upscale`;
const upscaleResult = await fetch(upscaleUrl, {
method: 'POST',
headers: {
'Authorization': `Bearer ${key}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
image: originalBase64,
height: request.body.height,
width: request.body.width,
scale: request.body.upscale_ratio,
}),
});
if (!upscaleResult.ok) {
throw new Error('NovelAI returned an error.');
}
const upscaledArchiveBuffer = await upscaleResult.arrayBuffer();
const upscaledImageBuffer = await extractFileFromZipBuffer(upscaledArchiveBuffer, '.png');
const upscaledBase64 = upscaledImageBuffer.toString('base64');
return response.send(upscaledBase64);
} catch (error) {
console.warn('NovelAI generated an image, but upscaling failed. Returning original image.');
return response.send(originalBase64);
}
return response.send(upscaledBase64);
} catch (error) {
console.log(error);
return response.sendStatus(500);
console.warn('NovelAI generated an image, but upscaling failed. Returning original image.');
return response.send(originalBase64);
}
});
} catch (error) {
console.log(error);
return response.sendStatus(500);
}
});
app.post('/api/novelai/generate-voice', jsonParser, async (request, response) => {
const token = readSecret(SECRET_KEYS.NOVEL);
router.post('/generate-voice', jsonParser, async (request, response) => {
const token = readSecret(SECRET_KEYS.NOVEL);
if (!token) {
return response.sendStatus(401);
if (!token) {
return response.sendStatus(401);
}
const text = request.body.text;
const voice = request.body.voice;
if (!text || !voice) {
return response.sendStatus(400);
}
try {
const url = `${API_NOVELAI}/ai/generate-voice?text=${encodeURIComponent(text)}&voice=-1&seed=${encodeURIComponent(voice)}&opus=false&version=v2`;
const result = await fetch(url, {
method: 'GET',
headers: {
'Authorization': `Bearer ${token}`,
'Accept': 'audio/mpeg',
},
timeout: 0,
});
if (!result.ok) {
return response.sendStatus(result.status);
}
const text = request.body.text;
const voice = request.body.voice;
const chunks = await readAllChunks(result.body);
const buffer = Buffer.concat(chunks);
response.setHeader('Content-Type', 'audio/mpeg');
return response.send(buffer);
}
catch (error) {
console.error(error);
return response.sendStatus(500);
}
});
if (!text || !voice) {
return response.sendStatus(400);
}
try {
const url = `${API_NOVELAI}/ai/generate-voice?text=${encodeURIComponent(text)}&voice=-1&seed=${encodeURIComponent(voice)}&opus=false&version=v2`;
const result = await fetch(url, {
method: 'GET',
headers: {
'Authorization': `Bearer ${token}`,
'Accept': 'audio/mpeg',
},
timeout: 0,
});
if (!result.ok) {
return response.sendStatus(result.status);
}
const chunks = await readAllChunks(result.body);
const buffer = Buffer.concat(chunks);
response.setHeader('Content-Type', 'audio/mpeg');
return response.send(buffer);
}
catch (error) {
console.error(error);
return response.sendStatus(500);
}
});
}
module.exports = {
registerEndpoints,
};
module.exports = { router };