Merge branch 'staging' of https://github.com/Cohee1207/SillyTavern into nuclaude

This commit is contained in:
based
2024-03-05 03:15:28 +10:00
221 changed files with 15418 additions and 7218 deletions

View File

@@ -1,4 +1,4 @@
const { TEXTGEN_TYPES } = require('./constants');
const { TEXTGEN_TYPES, OPENROUTER_HEADERS } = require('./constants');
const { SECRET_KEYS, readSecret } = require('./endpoints/secrets');
const { getConfigValue } = require('./util');
@@ -19,6 +19,21 @@ function getTogetherAIHeaders() {
}) : {};
}
/**
 * Builds the HTTP headers for InfermaticAI requests.
 * @returns {object} Headers with a Bearer token when an API key is stored, otherwise an empty object
 */
function getInfermaticAIHeaders() {
    const apiKey = readSecret(SECRET_KEYS.INFERMATICAI);

    if (!apiKey) {
        return {};
    }

    return { 'Authorization': `Bearer ${apiKey}` };
}
/**
 * Builds the HTTP headers for OpenRouter requests.
 * Always includes the static OpenRouter attribution headers; adds a Bearer
 * token on top when an API key is stored.
 * @returns {object} Request headers
 */
function getOpenRouterHeaders() {
    const apiKey = readSecret(SECRET_KEYS.OPENROUTER);
    const headers = { ...OPENROUTER_HEADERS };

    if (apiKey) {
        headers['Authorization'] = `Bearer ${apiKey}`;
    }

    return headers;
}
function getAphroditeHeaders() {
const apiKey = readSecret(SECRET_KEYS.APHRODITE);
@@ -37,6 +52,14 @@ function getTabbyHeaders() {
}) : {};
}
/**
 * Builds the HTTP headers for text-generation-webui (Ooba) requests.
 * @returns {object} Headers with a Bearer token when an API key is stored, otherwise an empty object
 */
function getOobaHeaders() {
    const apiKey = readSecret(SECRET_KEYS.OOBA);

    if (!apiKey) {
        return {};
    }

    return { 'Authorization': `Bearer ${apiKey}` };
}
function getOverrideHeaders(urlHost) {
const requestOverrides = getConfigValue('requestOverrides', []);
const overrideHeaders = requestOverrides?.find((e) => e.hosts?.includes(urlHost))?.headers;
@@ -69,6 +92,15 @@ function setAdditionalHeaders(request, args, server) {
case TEXTGEN_TYPES.TOGETHERAI:
headers = getTogetherAIHeaders();
break;
case TEXTGEN_TYPES.OOBA:
headers = getOobaHeaders();
break;
case TEXTGEN_TYPES.INFERMATICAI:
headers = getInfermaticAIHeaders();
break;
case TEXTGEN_TYPES.OPENROUTER:
headers = getOpenRouterHeaders();
break;
default:
headers = server ? getOverrideHeaders((new URL(server))?.host) : {};
break;

View File

@@ -1,34 +1,80 @@
const fs = require('fs');
const encode = require('png-chunks-encode');
const extract = require('png-chunks-extract');
const PNGtext = require('png-chunk-text');
const parse = async (cardUrl, format) => {
/**
 * Writes Character metadata to a PNG image buffer.
 * Drops any pre-existing tEXt chunks and inserts a fresh 'chara' chunk
 * (base64-encoded character data) right before the final (IEND) chunk.
 * @param {Buffer} image PNG image buffer
 * @param {string} data Character data to write
 * @returns {Buffer} PNG image buffer with metadata
 */
const write = (image, data) => {
    // Keep every chunk except existing tEXt metadata
    const chunks = extract(image).filter((chunk) => chunk.name !== 'tEXt');

    // Insert the new character chunk before the last chunk (IEND)
    const encodedData = Buffer.from(data, 'utf8').toString('base64');
    chunks.splice(-1, 0, PNGtext.encode('chara', encodedData));

    return Buffer.from(encode(chunks));
};
/**
 * Reads Character metadata from a PNG image buffer.
 * Looks for a tEXt chunk whose keyword is 'chara' (case-insensitive) and
 * decodes its base64 payload to UTF-8.
 * @param {Buffer} image PNG image buffer
 * @returns {string} Character data
 * @throws {Error} If the image has no text chunks or no 'chara' chunk
 */
const read = (image) => {
    const chunks = extract(image);

    const textChunks = chunks
        .filter((chunk) => chunk.name === 'tEXt')
        .map((chunk) => PNGtext.decode(chunk.data));

    if (textChunks.length === 0) {
        console.error('PNG metadata does not contain any text chunks.');
        throw new Error('No PNG metadata.');
    }

    // Strict comparison; the keyword match is case-insensitive by design
    const index = textChunks.findIndex((chunk) => chunk.keyword.toLowerCase() === 'chara');

    if (index === -1) {
        console.error('PNG metadata does not contain any character data.');
        throw new Error('No PNG metadata.');
    }

    return Buffer.from(textChunks[index].text, 'base64').toString('utf8');
};
/**
 * Parses a card image and returns the character metadata.
 * @param {string} cardUrl Path to the card image
 * @param {string} format File format (defaults to 'png')
 * @returns {string} Character data
 * @throws {Error} If the format is unsupported or the image has no metadata
 */
const parse = (cardUrl, format) => {
    const fileFormat = format === undefined ? 'png' : format;

    switch (fileFormat) {
        case 'png': {
            const buffer = fs.readFileSync(cardUrl);
            // Delegate chunk extraction/decoding to the shared reader
            return read(buffer);
        }
        default:
            break;
    }

    throw new Error('Unsupported format');
};
module.exports = {
parse: parse,
parse,
write,
read,
};

View File

@@ -1,5 +1,6 @@
const DIRECTORIES = {
worlds: 'public/worlds/',
user: 'public/user',
avatars: 'public/User Avatars',
images: 'public/img/',
userImages: 'public/user/images/',
@@ -175,8 +176,22 @@ const TEXTGEN_TYPES = {
TOGETHERAI: 'togetherai',
LLAMACPP: 'llamacpp',
OLLAMA: 'ollama',
INFERMATICAI: 'infermaticai',
OPENROUTER: 'openrouter',
};
// Request body keys allowed through to the InfermaticAI completions endpoint;
// the text-completions route strips every other key before sending.
const INFERMATICAI_KEYS = [
'model',
'prompt',
'max_tokens',
'temperature',
'top_p',
'top_k',
'repetition_penalty',
'stream',
'stop',
];
// https://docs.together.ai/reference/completions
const TOGETHERAI_KEYS = [
'model',
@@ -187,6 +202,7 @@ const TOGETHERAI_KEYS = [
'top_k',
'repetition_penalty',
'stream',
'stop',
];
// https://github.com/jmorganca/ollama/blob/main/docs/api.md#request-with-options
@@ -211,6 +227,29 @@ const OLLAMA_KEYS = [
const AVATAR_WIDTH = 400;
const AVATAR_HEIGHT = 600;
// Attribution headers OpenRouter expects on every request:
// https://openrouter.ai/docs#requests
const OPENROUTER_HEADERS = {
'HTTP-Referer': 'https://sillytavern.app',
'X-Title': 'SillyTavern',
};
// Request body keys forwarded to the OpenRouter endpoint; the
// text-completions route filters the outgoing body against this list.
const OPENROUTER_KEYS = [
'max_tokens',
'temperature',
'top_k',
'top_p',
'presence_penalty',
'frequency_penalty',
'repetition_penalty',
'min_p',
'top_a',
'seed',
'logit_bias',
'model',
'stream',
'prompt',
'stop',
];
module.exports = {
DIRECTORIES,
UNSAFE_EXTENSIONS,
@@ -223,4 +262,7 @@ module.exports = {
AVATAR_HEIGHT,
TOGETHERAI_KEYS,
OLLAMA_KEYS,
INFERMATICAI_KEYS,
OPENROUTER_HEADERS,
OPENROUTER_KEYS,
};

View File

@@ -1,6 +1,7 @@
const TASK = 'feature-extraction';
/**
* Gets the vectorized text in form of an array of numbers.
* @param {string} text - The text to vectorize
* @returns {Promise<number[]>} - The vectorized text in form of an array of numbers
*/
@@ -12,6 +13,20 @@ async function getTransformersVector(text) {
return vector;
}
/**
 * Gets the vectorized texts in form of an array of arrays of numbers.
 * Texts are vectorized one at a time, in input order.
 * @param {string[]} texts - The texts to vectorize
 * @returns {Promise<number[][]>} - One embedding vector per input text
 */
async function getTransformersBatchVector(texts) {
    const vectors = [];

    for (let i = 0; i < texts.length; i++) {
        vectors.push(await getTransformersVector(texts[i]));
    }

    return vectors;
}
module.exports = {
getTransformersVector,
getTransformersBatchVector,
};

View File

@@ -8,7 +8,7 @@ const { DIRECTORIES, UNSAFE_EXTENSIONS } = require('../constants');
const { jsonParser } = require('../express-common');
const { clientRelativePath } = require('../util');
const VALID_CATEGORIES = ['bgm', 'ambient', 'blip', 'live2d'];
const VALID_CATEGORIES = ['bgm', 'ambient', 'blip', 'live2d', 'vrm', 'character'];
/**
* Validates the input filename for the asset.
@@ -106,6 +106,33 @@ router.post('/get', jsonParser, async (_, response) => {
continue;
}
// VRM assets
if (folder == 'vrm') {
output[folder] = { 'model': [], 'animation': [] };
// Extract models
const vrm_model_folder = path.normalize(path.join(folderPath, 'vrm', 'model'));
let files = getFiles(vrm_model_folder);
//console.debug("FILE FOUND:",files)
for (let file of files) {
if (!file.endsWith('.placeholder')) {
//console.debug("Asset VRM model found:",file)
output['vrm']['model'].push(clientRelativePath(file));
}
}
// Extract animations
const vrm_animation_folder = path.normalize(path.join(folderPath, 'vrm', 'animation'));
files = getFiles(vrm_animation_folder);
//console.debug("FILE FOUND:",files)
for (let file of files) {
if (!file.endsWith('.placeholder')) {
//console.debug("Asset VRM animation found:",file)
output['vrm']['animation'].push(clientRelativePath(file));
}
}
continue;
}
// Other assets (bgm/ambient/blip)
const files = fs.readdirSync(path.join(folderPath, folder))
.filter(filename => {
@@ -172,6 +199,13 @@ router.post('/download', jsonParser, async (request, response) => {
const fileStream = fs.createWriteStream(destination, { flags: 'wx' });
await finished(res.body.pipe(fileStream));
if (category === 'character') {
response.sendFile(temp_path, { root: process.cwd() }, () => {
fs.rmSync(temp_path);
});
return;
}
// Move into asset place
console.debug('Download finished, moving file from', temp_path, 'to', file_path);
fs.renameSync(temp_path, file_path);

View File

@@ -3,7 +3,7 @@ const fetch = require('node-fetch').default;
const { Readable } = require('stream');
const { jsonParser } = require('../../express-common');
const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY } = require('../../constants');
const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY, OPENROUTER_HEADERS } = require('../../constants');
const { forwardFetchResponse, getConfigValue, tryParse, uuidv4, mergeObjectWithYaml, excludeKeysByYaml, color } = require('../../util');
const { convertClaudePrompt, convertGooglePrompt, convertTextCompletionPrompt } = require('../prompt-converters');
@@ -12,7 +12,7 @@ const { getTokenizerModel, getSentencepiceTokenizer, getTiktokenTokenizer, sente
const API_OPENAI = 'https://api.openai.com/v1';
const API_CLAUDE = 'https://api.anthropic.com/v1';
const API_MISTRAL = 'https://api.mistral.ai/v1';
/**
* Sends a request to Claude API.
* @param {express.Request} request Express request
@@ -36,9 +36,10 @@ async function sendClaudeRequest(request, response) {
});
const isSysPromptSupported = request.body.model === 'claude-2' || request.body.model === 'claude-2.1';
const requestPrompt = convertClaudePrompt(request.body.messages, !request.body.exclude_assistant, request.body.assistant_prefill, isSysPromptSupported, request.body.claude_use_sysprompt, request.body.human_sysprompt_message);
const requestPrompt = convertClaudePrompt(request.body.messages, !request.body.exclude_assistant, request.body.assistant_prefill, isSysPromptSupported, request.body.claude_use_sysprompt, request.body.human_sysprompt_message, request.body.claude_exclude_prefixes);
// Check Claude messages sequence and prefixes presence.
let sequenceError = [];
const sequence = requestPrompt.split('\n').filter(x => x.startsWith('Human:') || x.startsWith('Assistant:'));
const humanFound = sequence.some(line => line.startsWith('Human:'));
const assistantFound = sequence.some(line => line.startsWith('Assistant:'));
@@ -56,20 +57,20 @@ async function sendClaudeRequest(request, response) {
}
if (!humanFound) {
console.log(color.red(`${divider}\nWarning: No 'Human:' prefix found in the prompt.\n${divider}`));
sequenceError.push(`${divider}\nWarning: No 'Human:' prefix found in the prompt.\n${divider}`);
}
if (!assistantFound) {
console.log(color.red(`${divider}\nWarning: No 'Assistant: ' prefix found in the prompt.\n${divider}`));
sequenceError.push(`${divider}\nWarning: No 'Assistant: ' prefix found in the prompt.\n${divider}`);
}
if (!sequence[0].startsWith('Human:')) {
console.log(color.red(`${divider}\nWarning: The messages sequence should start with 'Human:' prefix.\nMake sure you have 'Human:' prefix at the very beggining of the prompt, or after the system prompt.\n${divider}`));
if (sequence[0] && !sequence[0].startsWith('Human:')) {
sequenceError.push(`${divider}\nWarning: The messages sequence should start with 'Human:' prefix.\nMake sure you have '\\n\\nHuman:' prefix at the very beggining of the prompt, or after the system prompt.\n${divider}`);
}
if (humanErrorCount > 0 || assistantErrorCount > 0) {
console.log(color.red(`${divider}\nWarning: Detected incorrect Prefix sequence(s).`));
console.log(color.red(`Incorrect "Human:" prefix(es): ${humanErrorCount}.\nIncorrect "Assistant: " prefix(es): ${assistantErrorCount}.`));
console.log(color.red('Check the prompt above and fix it in the SillyTavern.'));
console.log(color.red('\nThe correct sequence should look like this:\nSystem prompt <-(for the sysprompt format only, else have 2 empty lines above the first human\'s message.)'));
console.log(color.red(` <-----(Each message beginning with the "Assistant:/Human:" prefix must have one empty line above.)\nHuman:\n\nAssistant:\n...\n\nHuman:\n\nAssistant:\n${divider}`));
sequenceError.push(`${divider}\nWarning: Detected incorrect Prefix sequence(s).`);
sequenceError.push(`Incorrect "Human:" prefix(es): ${humanErrorCount}.\nIncorrect "Assistant: " prefix(es): ${assistantErrorCount}.`);
sequenceError.push('Check the prompt above and fix it in the SillyTavern.');
sequenceError.push('\nThe correct sequence in the console should look like this:\n(System prompt msg) <-(for the sysprompt format only, else have \\n\\n above the first human\'s message.)');
sequenceError.push(`\\n + <-----(Each message beginning with the "Assistant:/Human:" prefix must have \\n\\n before it.)\n\\n +\nHuman: \\n +\n\\n +\nAssistant: \\n +\n...\n\\n +\nHuman: \\n +\n\\n +\nAssistant: \n${divider}`);
}
// Add custom stop sequences
@@ -91,6 +92,10 @@ async function sendClaudeRequest(request, response) {
console.log('Claude request:', requestBody);
sequenceError.forEach(sequenceError => {
console.log(color.red(sequenceError));
});
const generateResponse = await fetch(apiUrl + '/complete', {
method: 'POST',
signal: controller.signal,
@@ -262,7 +267,7 @@ async function sendMakerSuiteRequest(request, response) {
? (stream ? 'streamGenerateContent' : 'generateContent')
: (isText ? 'generateText' : 'generateMessage');
const generateResponse = await fetch(`https://generativelanguage.googleapis.com/${apiVersion}/models/${model}:${responseType}?key=${apiKey}`, {
const generateResponse = await fetch(`https://generativelanguage.googleapis.com/${apiVersion}/models/${model}:${responseType}?key=${apiKey}${stream ? '&alt=sse' : ''}`, {
body: JSON.stringify(body),
method: 'POST',
headers: {
@@ -274,36 +279,8 @@ async function sendMakerSuiteRequest(request, response) {
// have to do this because of their busted ass streaming endpoint
if (stream) {
try {
let partialData = '';
generateResponse.body.on('data', (data) => {
const chunk = data.toString();
if (chunk.startsWith(',') || chunk.endsWith(',') || chunk.startsWith('[') || chunk.endsWith(']')) {
partialData = chunk.slice(1);
} else {
partialData += chunk;
}
while (true) {
let json;
try {
json = JSON.parse(partialData);
} catch (e) {
break;
}
response.write(JSON.stringify(json));
partialData = '';
}
});
request.socket.on('close', function () {
if (generateResponse.body instanceof Readable) generateResponse.body.destroy();
response.end();
});
generateResponse.body.on('end', () => {
console.log('Streaming request finished');
response.end();
});
// Pipe remote SSE stream to Express response
forwardFetchResponse(generateResponse, response);
} catch (error) {
console.log('Error forwarding streaming response:', error);
if (!response.headersSent) {
@@ -329,7 +306,7 @@ async function sendMakerSuiteRequest(request, response) {
}
const responseContent = candidates[0].content ?? candidates[0].output;
const responseText = typeof responseContent === 'string' ? responseContent : responseContent.parts?.[0]?.text;
const responseText = typeof responseContent === 'string' ? responseContent : responseContent?.parts?.[0]?.text;
if (!responseText) {
let message = 'MakerSuite Candidate text empty';
console.log(message, generateResponseJson);
@@ -431,7 +408,8 @@ async function sendAI21Request(request, response) {
* @param {express.Response} response Express response
*/
async function sendMistralAIRequest(request, response) {
const apiKey = readSecret(SECRET_KEYS.MISTRALAI);
const apiUrl = new URL(request.body.reverse_proxy || API_MISTRAL).toString();
const apiKey = request.body.reverse_proxy ? request.body.proxy_password : readSecret(SECRET_KEYS.MISTRALAI);
if (!apiKey) {
console.log('MistralAI API key is missing.');
@@ -441,9 +419,12 @@ async function sendMistralAIRequest(request, response) {
try {
//must send a user role as last message
const messages = Array.isArray(request.body.messages) ? request.body.messages : [];
//large seems to be throwing a 500 error if we don't make the first message a user role, most likely a bug since the other models won't do this
if (request.body.model.includes('large'))
messages[0].role = 'user';
const lastMsg = messages[messages.length - 1];
if (messages.length > 0 && lastMsg && (lastMsg.role === 'system' || lastMsg.role === 'assistant')) {
if (lastMsg.role === 'assistant') {
if (lastMsg.role === 'assistant' && lastMsg.name) {
lastMsg.content = lastMsg.name + ': ' + lastMsg.content;
} else if (lastMsg.role === 'system') {
lastMsg.content = '[INST] ' + lastMsg.content + ' [/INST]';
@@ -478,7 +459,7 @@ async function sendMistralAIRequest(request, response) {
'top_p': request.body.top_p,
'max_tokens': request.body.max_tokens,
'stream': request.body.stream,
'safe_mode': request.body.safe_mode,
'safe_prompt': request.body.safe_prompt,
'random_seed': request.body.seed === -1 ? undefined : request.body.seed,
};
@@ -495,7 +476,7 @@ async function sendMistralAIRequest(request, response) {
console.log('MisralAI request:', requestBody);
const generateResponse = await fetch('https://api.mistral.ai/v1/chat/completions', config);
const generateResponse = await fetch(apiUrl + '/chat/completions', config);
if (request.body.stream) {
forwardFetchResponse(generateResponse, response);
} else {
@@ -535,11 +516,12 @@ router.post('/status', jsonParser, async function (request, response_getstatus_o
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.OPENROUTER) {
api_url = 'https://openrouter.ai/api/v1';
api_key_openai = readSecret(SECRET_KEYS.OPENROUTER);
// OpenRouter needs to pass the referer: https://openrouter.ai/docs
headers = { 'HTTP-Referer': request.headers.referer };
// OpenRouter needs to pass the Referer and X-Title: https://openrouter.ai/docs#requests
headers = { ...OPENROUTER_HEADERS };
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.MISTRALAI) {
api_url = 'https://api.mistral.ai/v1';
api_key_openai = readSecret(SECRET_KEYS.MISTRALAI);
api_url = new URL(request.body.reverse_proxy || API_MISTRAL).toString();
api_key_openai = request.body.reverse_proxy ? request.body.proxy_password : readSecret(SECRET_KEYS.MISTRALAI);
headers = {};
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.CUSTOM) {
api_url = request.body.custom_url;
api_key_openai = readSecret(SECRET_KEYS.CUSTOM);
@@ -699,12 +681,21 @@ router.post('/generate', jsonParser, function (request, response) {
let apiKey;
let headers;
let bodyParams;
const isTextCompletion = Boolean(request.body.model && TEXT_COMPLETION_MODELS.includes(request.body.model)) || typeof request.body.messages === 'string';
if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.OPENAI) {
apiUrl = new URL(request.body.reverse_proxy || API_OPENAI).toString();
apiKey = request.body.reverse_proxy ? request.body.proxy_password : readSecret(SECRET_KEYS.OPENAI);
headers = {};
bodyParams = {};
bodyParams = {
logprobs: request.body.logprobs,
};
// Adjust logprobs params for Chat Completions API, which expects { top_logprobs: number; logprobs: boolean; }
if (!isTextCompletion && bodyParams.logprobs > 0) {
bodyParams.top_logprobs = bodyParams.logprobs;
bodyParams.logprobs = true;
}
if (getConfigValue('openai.randomizeUserId', false)) {
bodyParams['user'] = uuidv4();
@@ -712,10 +703,22 @@ router.post('/generate', jsonParser, function (request, response) {
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.OPENROUTER) {
apiUrl = 'https://openrouter.ai/api/v1';
apiKey = readSecret(SECRET_KEYS.OPENROUTER);
// OpenRouter needs to pass the referer: https://openrouter.ai/docs
headers = { 'HTTP-Referer': request.headers.referer };
// OpenRouter needs to pass the Referer and X-Title: https://openrouter.ai/docs#requests
headers = { ...OPENROUTER_HEADERS };
bodyParams = { 'transforms': ['middle-out'] };
if (request.body.min_p !== undefined) {
bodyParams['min_p'] = request.body.min_p;
}
if (request.body.top_a !== undefined) {
bodyParams['top_a'] = request.body.top_a;
}
if (request.body.repetition_penalty !== undefined) {
bodyParams['repetition_penalty'] = request.body.repetition_penalty;
}
if (request.body.use_fallback) {
bodyParams['route'] = 'fallback';
}
@@ -723,7 +726,16 @@ router.post('/generate', jsonParser, function (request, response) {
apiUrl = request.body.custom_url;
apiKey = readSecret(SECRET_KEYS.CUSTOM);
headers = {};
bodyParams = {};
bodyParams = {
logprobs: request.body.logprobs,
};
// Adjust logprobs params for Chat Completions API, which expects { top_logprobs: number; logprobs: boolean; }
if (!isTextCompletion && bodyParams.logprobs > 0) {
bodyParams.top_logprobs = bodyParams.logprobs;
bodyParams.logprobs = true;
}
mergeObjectWithYaml(bodyParams, request.body.custom_include_body);
mergeObjectWithYaml(headers, request.body.custom_include_headers);
} else {
@@ -741,7 +753,6 @@ router.post('/generate', jsonParser, function (request, response) {
bodyParams['stop'] = request.body.stop;
}
const isTextCompletion = Boolean(request.body.model && TEXT_COMPLETION_MODELS.includes(request.body.model)) || typeof request.body.messages === 'string';
const textPrompt = isTextCompletion ? convertTextCompletionPrompt(request.body.messages) : '';
const endpointUrl = isTextCompletion && request.body.chat_completion_source !== CHAT_COMPLETION_SOURCES.OPENROUTER ?
`${apiUrl}/completions` :
@@ -767,6 +778,7 @@ router.post('/generate', jsonParser, function (request, response) {
'stop': isTextCompletion === false ? request.body.stop : undefined,
'logit_bias': request.body.logit_bias,
'seed': request.body.seed,
'n': request.body.n,
...bodyParams,
};
@@ -813,7 +825,7 @@ router.post('/generate', jsonParser, function (request, response) {
let json = await fetchResponse.json();
response.send(json);
console.log(json);
console.log(json?.choices[0]?.message);
console.log(json?.choices?.[0]?.message);
} else if (fetchResponse.status === 429 && retries > 0) {
console.log(`Out of quota, retrying in ${Math.round(timeout / 1000)}s`);
setTimeout(() => {

View File

@@ -4,7 +4,7 @@ const _ = require('lodash');
const Readable = require('stream').Readable;
const { jsonParser } = require('../../express-common');
const { TEXTGEN_TYPES, TOGETHERAI_KEYS, OLLAMA_KEYS } = require('../../constants');
const { TEXTGEN_TYPES, TOGETHERAI_KEYS, OLLAMA_KEYS, INFERMATICAI_KEYS, OPENROUTER_KEYS } = require('../../constants');
const { forwardFetchResponse, trimV1 } = require('../../util');
const { setAdditionalHeaders } = require('../../additional-headers');
@@ -106,6 +106,8 @@ router.post('/status', jsonParser, async function (request, response) {
case TEXTGEN_TYPES.APHRODITE:
case TEXTGEN_TYPES.KOBOLDCPP:
case TEXTGEN_TYPES.LLAMACPP:
case TEXTGEN_TYPES.INFERMATICAI:
case TEXTGEN_TYPES.OPENROUTER:
url += '/v1/models';
break;
case TEXTGEN_TYPES.MANCER:
@@ -208,6 +210,7 @@ router.post('/generate', jsonParser, async function (request, response) {
request.body.api_server = request.body.api_server.replace('localhost', '127.0.0.1');
}
const apiType = request.body.api_type;
const baseUrl = request.body.api_server;
console.log(request.body);
@@ -232,6 +235,7 @@ router.post('/generate', jsonParser, async function (request, response) {
case TEXTGEN_TYPES.TABBY:
case TEXTGEN_TYPES.KOBOLDCPP:
case TEXTGEN_TYPES.TOGETHERAI:
case TEXTGEN_TYPES.INFERMATICAI:
url += '/v1/completions';
break;
case TEXTGEN_TYPES.MANCER:
@@ -243,6 +247,9 @@ router.post('/generate', jsonParser, async function (request, response) {
case TEXTGEN_TYPES.OLLAMA:
url += '/api/generate';
break;
case TEXTGEN_TYPES.OPENROUTER:
url += '/v1/chat/completions';
break;
}
}
@@ -257,11 +264,17 @@ router.post('/generate', jsonParser, async function (request, response) {
setAdditionalHeaders(request, args, baseUrl);
if (request.body.api_type === TEXTGEN_TYPES.TOGETHERAI) {
const stop = Array.isArray(request.body.stop) ? request.body.stop[0] : '';
request.body = _.pickBy(request.body, (_, key) => TOGETHERAI_KEYS.includes(key));
if (typeof stop === 'string' && stop.length > 0) {
request.body.stop = stop;
}
args.body = JSON.stringify(request.body);
}
if (request.body.api_type === TEXTGEN_TYPES.INFERMATICAI) {
request.body = _.pickBy(request.body, (_, key) => INFERMATICAI_KEYS.includes(key));
args.body = JSON.stringify(request.body);
}
if (request.body.api_type === TEXTGEN_TYPES.OPENROUTER) {
request.body = _.pickBy(request.body, (_, key) => OPENROUTER_KEYS.includes(key));
args.body = JSON.stringify(request.body);
}
@@ -270,6 +283,7 @@ router.post('/generate', jsonParser, async function (request, response) {
model: request.body.model,
prompt: request.body.prompt,
stream: request.body.stream ?? false,
keep_alive: -1,
raw: true,
options: _.pickBy(request.body, (_, key) => OLLAMA_KEYS.includes(key)),
});
@@ -296,6 +310,11 @@ router.post('/generate', jsonParser, async function (request, response) {
data['choices'] = [{ text }];
}
// Map InfermaticAI response to OAI completions format
if (apiType === TEXTGEN_TYPES.INFERMATICAI) {
data['choices'] = (data?.choices || []).map(choice => ({ text: choice.message.content }));
}
return response.send(data);
} else {
const text = await completionsReply.text();

View File

@@ -7,9 +7,6 @@ const writeFileAtomicSync = require('write-file-atomic').sync;
const yaml = require('yaml');
const _ = require('lodash');
const encode = require('png-chunks-encode');
const extract = require('png-chunks-extract');
const PNGtext = require('png-chunk-text');
const jimp = require('jimp');
const { DIRECTORIES, UPLOADS_PATH, AVATAR_WIDTH, AVATAR_HEIGHT } = require('../constants');
@@ -33,7 +30,7 @@ const characterDataCache = new Map();
* @param {string} input_format - 'png'
* @returns {Promise<string | undefined>} - Character card data
*/
async function charaRead(img_url, input_format) {
async function charaRead(img_url, input_format = 'png') {
const stat = fs.statSync(img_url);
const cacheKey = `${img_url}-${stat.mtimeMs}`;
if (characterDataCache.has(cacheKey)) {
@@ -59,22 +56,12 @@ async function charaWrite(img_url, data, target_img, response = undefined, mes =
}
}
// Read the image, resize, and save it as a PNG into the buffer
const image = await tryReadImage(img_url, crop);
const inputImage = await tryReadImage(img_url, crop);
// Get the chunks
const chunks = extract(image);
const tEXtChunks = chunks.filter(chunk => chunk.name === 'tEXt');
const outputImage = characterCardParser.write(inputImage, data);
// Remove all existing tEXt chunks
for (let tEXtChunk of tEXtChunks) {
chunks.splice(chunks.indexOf(tEXtChunk), 1);
}
// Add new chunks before the IEND chunk
const base64EncodedData = Buffer.from(data, 'utf8').toString('base64');
chunks.splice(-1, 0, PNGtext.encode('chara', base64EncodedData));
//chunks.splice(-1, 0, text.encode('lorem', 'ipsum'));
writeFileAtomicSync(DIRECTORIES.characters + target_img + '.png', Buffer.from(encode(chunks)));
writeFileAtomicSync(DIRECTORIES.characters + target_img + '.png', outputImage);
if (response !== undefined) response.send(mes);
return true;
} catch (err) {
@@ -152,13 +139,13 @@ const processCharacter = async (item, i) => {
const img_data = await charaRead(DIRECTORIES.characters + item);
if (img_data === undefined) throw new Error('Failed to read character file');
let jsonObject = getCharaCardV2(JSON.parse(img_data));
let jsonObject = getCharaCardV2(JSON.parse(img_data), false);
jsonObject.avatar = item;
characters[i] = jsonObject;
characters[i]['json_data'] = img_data;
const charStat = fs.statSync(path.join(DIRECTORIES.characters, item));
characters[i]['date_added'] = charStat.birthtimeMs;
characters[i]['create_date'] = jsonObject['create_date'] || humanizedISO8601DateTime(charStat.birthtimeMs);
characters[i]['date_added'] = charStat.ctimeMs;
characters[i]['create_date'] = jsonObject['create_date'] || humanizedISO8601DateTime(charStat.ctimeMs);
const char_dir = path.join(DIRECTORIES.chats, item.replace('.png', ''));
const { chatSize, dateLastChat } = calculateChatSize(char_dir);
@@ -183,15 +170,30 @@ const processCharacter = async (item, i) => {
}
};
/**
 * Convert a character object to Spec V2 format.
 * (Removed the dangling pre-merge signature line that preceded this JSDoc.)
 * @param {object} jsonObject Character object
 * @param {boolean} hoistDate Set create_date to the current date when it is missing on a converted V1 card
 * @returns {object} Character object in Spec V2 format
 */
function getCharaCardV2(jsonObject, hoistDate = true) {
    if (jsonObject.spec === undefined) {
        jsonObject = convertToV2(jsonObject);

        if (hoistDate && !jsonObject.create_date) {
            jsonObject.create_date = humanizedISO8601DateTime();
        }
    } else {
        jsonObject = readFromV2(jsonObject);
    }

    return jsonObject;
}
/**
* Convert a character object to Spec V2 format.
* @param {object} char Character object
* @returns {object} Character object in Spec V2 format
*/
function convertToV2(char) {
// Simulate incoming data from frontend form
const result = charaFormatData({
@@ -212,7 +214,8 @@ function convertToV2(char) {
});
result.chat = char.chat ?? humanizedISO8601DateTime();
result.create_date = char.create_date ?? humanizedISO8601DateTime();
result.create_date = char.create_date;
return result;
}
@@ -300,6 +303,7 @@ function charaFormatData(data) {
_.set(char, 'chat', data.ch_name + ' - ' + humanizedISO8601DateTime());
_.set(char, 'talkativeness', data.talkativeness);
_.set(char, 'fav', data.fav == 'true');
_.set(char, 'tags', typeof data.tags == 'string' ? (data.tags.split(',').map(x => x.trim()).filter(x => x)) : data.tags || []);
// Spec V2 fields
_.set(char, 'spec', 'chara_card_v2');
@@ -336,7 +340,7 @@ function charaFormatData(data) {
if (data.world) {
try {
const file = readWorldInfoFile(data.world);
const file = readWorldInfoFile(data.world, false);
// File was imported - save it to the character book
if (file && file.originalData) {
@@ -387,6 +391,11 @@ function convertWorldInfoToCharacterBook(name, entries) {
depth: entry.depth ?? 4,
selectiveLogic: entry.selectiveLogic ?? 0,
group: entry.group ?? '',
prevent_recursion: entry.preventRecursion ?? false,
scan_depth: entry.scanDepth ?? null,
match_whole_words: entry.matchWholeWords ?? null,
case_sensitive: entry.caseSensitive ?? null,
automation_id: entry.automationId ?? '',
},
};
@@ -791,6 +800,17 @@ function getPngName(file) {
return file;
}
/**
 * Gets the preserved name for the uploaded file if the request is valid.
 * @param {import("express").Request} request - Express request object
 * @returns {string | undefined} - The preserved name if the request is valid, otherwise undefined
 */
function getPreservedName(request) {
    const isPng = request.body.file_type === 'png';
    const wantsOriginalName = request.body.preserve_file_name === 'true';
    const originalName = request.file?.originalname;

    if (isPng && wantsOriginalName && originalName) {
        return path.parse(originalName).name;
    }

    return undefined;
}
router.post('/import', urlencodedParser, async function (request, response) {
if (!request.body || !request.file) return response.sendStatus(400);
@@ -798,6 +818,7 @@ router.post('/import', urlencodedParser, async function (request, response) {
let filedata = request.file;
let uploadPath = path.join(UPLOADS_PATH, filedata.filename);
let format = request.body.file_type;
const preservedFileName = getPreservedName(request);
if (format == 'yaml' || format == 'yml') {
try {
@@ -889,7 +910,7 @@ router.post('/import', urlencodedParser, async function (request, response) {
let jsonData = JSON.parse(img_data);
jsonData.name = sanitize(jsonData.data?.name || jsonData.name);
png_name = getPngName(jsonData.name);
png_name = preservedFileName || getPngName(jsonData.name);
if (jsonData.spec !== undefined) {
console.log('Found a v2 character file.');
@@ -1003,7 +1024,7 @@ router.post('/export', jsonParser, async function (request, response) {
let json = await charaRead(filename);
if (json === undefined) return response.sendStatus(400);
let jsonObject = getCharaCardV2(JSON.parse(json));
return response.type('json').send(jsonObject);
return response.type('json').send(JSON.stringify(jsonObject, null, 4));
}
catch {
return response.sendStatus(400);

View File

@@ -10,6 +10,7 @@ const contentLogPath = path.join(contentDirectory, 'content.log');
const contentIndexPath = path.join(contentDirectory, 'index.json');
const { DIRECTORIES } = require('../constants');
const presetFolders = [DIRECTORIES.koboldAI_Settings, DIRECTORIES.openAI_Settings, DIRECTORIES.novelAI_Settings, DIRECTORIES.textGen_Settings];
const characterCardParser = require('../character-card-parser.js');
/**
* Gets the default presets from the content directory.
@@ -219,6 +220,56 @@ async function downloadChubCharacter(id) {
return { buffer, fileName, fileType };
}
/**
 * Downloads a character card from the Pygsite.
 * @param {string} id UUID of the character
 * @returns {Promise<{buffer: Buffer, fileName: string, fileType: string}>}
 */
async function downloadPygmalionCharacter(id) {
    // The v2 export endpoint responds with { character: <v2 card JSON> }.
    const result = await fetch(`https://server.pygmalion.chat/api/export/character/${id}/v2`);

    if (!result.ok) {
        const text = await result.text();
        console.log('Pygsite returned error', result.status, text);
        throw new Error('Failed to download character');
    }

    const jsonData = await result.json();
    const characterData = jsonData?.character;

    // Guard against malformed responses before attempting to read nested fields.
    if (!characterData || typeof characterData !== 'object') {
        console.error('Pygsite returned invalid character data', jsonData);
        throw new Error('Failed to download character');
    }

    try {
        // Preferred path: embed the card JSON into the avatar PNG so the result
        // imports as a self-contained character card.
        const avatarUrl = characterData?.data?.avatar;

        if (!avatarUrl) {
            console.error('Pygsite character does not have an avatar', characterData);
            throw new Error('Failed to download avatar');
        }

        const avatarResult = await fetch(avatarUrl);
        const avatarBuffer = await avatarResult.buffer();

        const cardBuffer = characterCardParser.write(avatarBuffer, JSON.stringify(characterData));

        return {
            buffer: cardBuffer,
            fileName: `${sanitize(id)}.png`,
            fileType: 'image/png',
        };
    } catch (e) {
        // Fallback: if the avatar can't be fetched or embedded, return the raw
        // JSON payload instead so the import still succeeds.
        console.error('Failed to download avatar, using JSON instead', e);
        return {
            buffer: Buffer.from(JSON.stringify(jsonData)),
            fileName: `${sanitize(id)}.json`,
            fileType: 'application/json',
        };
    }
}
/**
*
* @param {String} str
@@ -294,7 +345,7 @@ async function downloadJannyCharacter(uuid) {
* @param {String} url
* @returns {String | null } UUID of the character
*/
function parseJannyUrl(url) {
function getUuidFromUrl(url) {
// Extract UUID from URL
const uuidRegex = /[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}/;
const matches = url.match(uuidRegex);
@@ -306,7 +357,7 @@ function parseJannyUrl(url) {
const router = express.Router();
router.post('/import', jsonParser, async (request, response) => {
router.post('/importURL', jsonParser, async (request, response) => {
if (!request.body.url) {
return response.sendStatus(400);
}
@@ -317,8 +368,18 @@ router.post('/import', jsonParser, async (request, response) => {
let type;
const isJannnyContent = url.includes('janitorai');
if (isJannnyContent) {
const uuid = parseJannyUrl(url);
const isPygmalionContent = url.includes('pygmalion.chat');
if (isPygmalionContent) {
const uuid = getUuidFromUrl(url);
if (!uuid) {
return response.sendStatus(404);
}
type = 'character';
result = await downloadPygmalionCharacter(uuid);
} else if (isJannnyContent) {
const uuid = getUuidFromUrl(url);
if (!uuid) {
return response.sendStatus(404);
}
@@ -352,6 +413,49 @@ router.post('/import', jsonParser, async (request, response) => {
}
});
/**
 * Imports external content (character card or lorebook) by a bare UUID.
 * The ID's shape selects the provider: Janitor IDs carry a "_character"
 * suffix, Pygmalion IDs are bare 36-character UUIDs, everything else is
 * treated as a Chub character or lorebook ID.
 */
router.post('/importUUID', jsonParser, async (request, response) => {
    if (!request.body.url) {
        return response.sendStatus(400);
    }

    try {
        const uuid = request.body.url;
        let result;

        const isJannny = uuid.includes('_character');
        // Use strict equality for the length check (idiomatic; same semantics here).
        const isPygmalion = (!isJannny && uuid.length === 36);
        const uuidType = uuid.includes('lorebook') ? 'lorebook' : 'character';

        if (isPygmalion) {
            console.log('Downloading Pygmalion character:', uuid);
            result = await downloadPygmalionCharacter(uuid);
        } else if (isJannny) {
            // Janitor IDs look like "<uuid>_character"; strip the suffix before download.
            console.log('Downloading Janitor character:', uuid.split('_')[0]);
            result = await downloadJannyCharacter(uuid.split('_')[0]);
        } else {
            if (uuidType === 'character') {
                console.log('Downloading chub character:', uuid);
                result = await downloadChubCharacter(uuid);
            }
            else if (uuidType === 'lorebook') {
                console.log('Downloading chub lorebook:', uuid);
                result = await downloadChubLorebook(uuid);
            }
            else {
                return response.sendStatus(404);
            }
        }

        // Forward the downloaded file as an attachment; the custom header tells
        // the client whether it received a character or a lorebook.
        if (result.fileType) response.set('Content-Type', result.fileType);
        response.set('Content-Disposition', `attachment; filename="${result.fileName}"`);
        response.set('X-Custom-Content-Type', uuidType);
        return response.send(result.buffer);
    } catch (error) {
        console.log('Importing custom content failed', error);
        return response.sendStatus(500);
    }
});
module.exports = {
checkForNewContent,
getDefaultPresets,

View File

@@ -10,7 +10,7 @@ const API_NOVELAI = 'https://api.novelai.net';
// Ban bracket generation, plus defaults
const badWordsList = [
[3], [49356], [1431], [31715], [34387], [20765], [30702], [10691], [49333], [1266],
[19438], [43145], [26523], [41471], [2936], [85, 85], [49332], [7286], [1115],
[19438], [43145], [26523], [41471], [2936], [85, 85], [49332], [7286], [1115], [24],
];
const hypeBotBadWordsList = [
@@ -172,9 +172,17 @@ router.post('/generate', jsonParser, async function (req, res) {
'return_full_text': req.body.return_full_text,
'prefix': req.body.prefix,
'order': req.body.order,
'num_logprobs': req.body.num_logprobs,
},
};
// Tells the model to stop generation at '>'
if ('theme_textadventure' === req.body.prefix &&
(true === req.body.model.includes('clio') ||
true === req.body.model.includes('kayra'))) {
data.parameters.eos_token_id = 49405;
}
console.log(util.inspect(data, { depth: 4 }));
const args = {
@@ -208,7 +216,7 @@ router.post('/generate', jsonParser, async function (req, res) {
}
const data = await response.json();
console.log(data);
console.log('NovelAI Output', data?.output);
return res.send(data);
}
} catch (error) {
@@ -342,7 +350,9 @@ router.post('/generate-voice', jsonParser, async (request, response) => {
});
if (!result.ok) {
return response.sendStatus(result.status);
const errorText = await result.text();
console.log('NovelAI returned an error.', result.statusText, errorText);
return response.sendStatus(500);
}
const chunks = await readAllChunks(result.body);

View File

@@ -33,6 +33,7 @@ router.post('/caption-image', jsonParser, async (request, response) => {
}
if (request.body.api === 'ooba') {
key = readSecret(SECRET_KEYS.OOBA);
bodyParams.temperature = 0.1;
}

View File

@@ -5,15 +5,21 @@
* @param {string} addAssistantPrefill Add Assistant prefill after the assistant postfix.
* @param {boolean} withSysPromptSupport Indicates if the Claude model supports the system prompt format.
* @param {boolean} useSystemPrompt Indicates if the system prompt format should be used.
 * @param {boolean} excludePrefixes Exclude Human/Assistant prefixes.
* @param {string} addSysHumanMsg Add Human message between system prompt and assistant.
* @returns {string} Prompt for Claude
* @copyright Prompt Conversion script taken from RisuAI by kwaroran (GPLv3).
*/
function convertClaudePrompt(messages, addAssistantPostfix, addAssistantPrefill, withSysPromptSupport, useSystemPrompt, addSysHumanMsg) {
function convertClaudePrompt(messages, addAssistantPostfix, addAssistantPrefill, withSysPromptSupport, useSystemPrompt, addSysHumanMsg, excludePrefixes) {
//Prepare messages for claude.
//When 'Exclude Human/Assistant prefixes' checked, setting messages role to the 'system'(last message is exception).
if (messages.length > 0) {
messages[0].role = 'system';
if (excludePrefixes) {
messages.slice(0, -1).forEach(message => message.role = 'system');
} else {
messages[0].role = 'system';
}
//Add the assistant's message to the end of messages.
if (addAssistantPostfix) {
messages.push({
@@ -29,7 +35,7 @@ function convertClaudePrompt(messages, addAssistantPostfix, addAssistantPrefill,
}
return message.role === 'assistant' && i > 0;
});
// When 2.1+ and 'Use system prompt" checked, switches to the system prompt format by setting the first message's role to the 'system'.
// When 2.1+ and 'Use system prompt' checked, switches to the system prompt format by setting the first message's role to the 'system'.
// Inserts the human's message before the first the assistant one, if there are no such message or prefix found.
if (withSysPromptSupport && useSystemPrompt) {
messages[0].role = 'system';
@@ -43,7 +49,7 @@ function convertClaudePrompt(messages, addAssistantPostfix, addAssistantPrefill,
// Otherwise, use the default message format by setting the first message's role to 'user'(compatible with all claude models including 2.1.)
messages[0].role = 'user';
// Fix messages order for default message format when(messages > Context Size) by merging two messages with "\n\nHuman: " prefixes into one, before the first Assistant's message.
if (firstAssistantIndex > 0) {
if (firstAssistantIndex > 0 && !excludePrefixes) {
messages[firstAssistantIndex - 1].role = firstAssistantIndex - 1 !== 0 && messages[firstAssistantIndex - 1].role === 'user' ? 'FixHumMsg' : messages[firstAssistantIndex - 1].role;
}
}
@@ -51,11 +57,11 @@ function convertClaudePrompt(messages, addAssistantPostfix, addAssistantPrefill,
// Convert messages to the prompt.
let requestPrompt = messages.map((v, i) => {
// Set prefix according to the role.
// Set prefix according to the role. Also, when "Exclude Human/Assistant prefixes" is checked, names are added via the system prefix.
let prefix = {
'assistant': '\n\nAssistant: ',
'user': '\n\nHuman: ',
'system': i === 0 ? '' : v.name === 'example_assistant' ? '\n\nA: ' : v.name === 'example_user' ? '\n\nH: ' : '\n\n',
'system': i === 0 ? '' : v.name === 'example_assistant' ? '\n\nA: ' : v.name === 'example_user' ? '\n\nH: ' : excludePrefixes && v.name ? `\n\n${v.name}: ` : '\n\n',
'FixHumMsg': '\n\nFirst message: ',
}[v.role] ?? '';
// Claude doesn't support message names, so we'll just add them to the message content.

View File

@@ -17,6 +17,7 @@ const SECRET_KEYS = {
DEEPL: 'deepl',
LIBRE: 'libre',
LIBRE_URL: 'libre_url',
LINGVA_URL: 'lingva_url',
OPENROUTER: 'api_key_openrouter',
SCALE: 'api_key_scale',
AI21: 'api_key_ai21',
@@ -28,8 +29,18 @@ const SECRET_KEYS = {
TOGETHERAI: 'api_key_togetherai',
MISTRALAI: 'api_key_mistralai',
CUSTOM: 'api_key_custom',
OOBA: 'api_key_ooba',
INFERMATICAI: 'api_key_infermaticai',
};
// These are the keys that are safe to expose, even if allowKeysExposure is false
const EXPORTABLE_KEYS = [
SECRET_KEYS.LIBRE_URL,
SECRET_KEYS.LINGVA_URL,
SECRET_KEYS.ONERING_URL,
SECRET_KEYS.DEEPLX_URL,
];
/**
* Writes a secret to the secrets file
* @param {string} key Secret key
@@ -210,14 +221,13 @@ router.post('/view', jsonParser, async (_, response) => {
router.post('/find', jsonParser, (request, response) => {
const allowKeysExposure = getConfigValue('allowKeysExposure', false);
const key = request.body.key;
if (!allowKeysExposure) {
if (!allowKeysExposure && !EXPORTABLE_KEYS.includes(key)) {
console.error('Cannot fetch secrets unless allowKeysExposure in config.yaml is set to true');
return response.sendStatus(403);
}
const key = request.body.key;
try {
const secret = readSecret(key);

82
src/endpoints/speech.js Normal file
View File

@@ -0,0 +1,82 @@
const express = require('express');
const { jsonParser } = require('../express-common');
const router = express.Router();
/**
 * Gets the audio data from a base64-encoded audio file.
 * @param {string} audio Base64-encoded audio
 * @returns {Float64Array} Audio data
 */
function getWaveFile(audio) {
    const wavefile = require('wavefile');
    const wave = new wavefile.WaveFile();

    // Normalize the input to 32-bit float samples at 16 kHz.
    wave.fromDataURI(audio);
    wave.toBitDepth('32f');
    wave.toSampleRate(16000);

    let samples = wave.getSamples();

    // Mono audio comes back as a single typed array - return it as-is.
    if (!Array.isArray(samples)) {
        return samples;
    }

    // Multi-channel audio: downmix into the first channel to save memory.
    if (samples.length > 1) {
        const scale = Math.sqrt(2);
        const [left, right] = samples;

        for (let i = 0; i < left.length; ++i) {
            left[i] = scale * (left[i] + right[i]) / 2;
        }
    }

    // Select first channel
    return samples[0];
}
/**
 * Transcribes base64-encoded audio to text using a local transformers.js
 * automatic-speech-recognition pipeline (e.g. Whisper).
 * Expects { model, audio, lang } in the request body; responds with { text }.
 */
router.post('/recognize', jsonParser, async (req, res) => {
    try {
        const TASK = 'automatic-speech-recognition';
        const { model, audio, lang } = req.body;
        // Lazy dynamic import: transformers.mjs is an ES module loaded on demand.
        const module = await import('../transformers.mjs');
        const pipe = await module.default.getPipeline(TASK, model);
        // Decode and resample the audio into the format the pipeline expects.
        const wav = getWaveFile(audio);
        const start = performance.now();
        // lang may be empty; null lets the model auto-detect the language.
        const result = await pipe(wav, { language: lang || null });
        const end = performance.now();
        console.log(`Execution duration: ${(end - start) / 1000} seconds`);
        console.log('Transcribed audio:', result.text);

        return res.json({ text: result.text });
    } catch (error) {
        console.error(error);
        return res.sendStatus(500);
    }
});
/**
 * Synthesizes speech from text using a local transformers.js text-to-speech
 * pipeline. Expects { text, model, speaker } in the request body, where
 * speaker is an optional base64-encoded speaker embedding (raw or data URI).
 * Responds with a WAV audio buffer.
 */
router.post('/synthesize', jsonParser, async (req, res) => {
    try {
        const wavefile = require('wavefile');
        const TASK = 'text-to-speech';
        const { text, model, speaker } = req.body;
        // Lazy dynamic import: transformers.mjs is an ES module loaded on demand.
        const module = await import('../transformers.mjs');
        const pipe = await module.default.getPipeline(TASK, model);
        // Decode the base64 speaker embedding into a Float32Array; accept both
        // bare base64 and "data:" URIs. Null means the model's default voice.
        const speaker_embeddings = speaker
            ? new Float32Array(new Uint8Array(Buffer.from(speaker.startsWith('data:') ? speaker.split(',')[1] : speaker, 'base64')).buffer)
            : null;
        const start = performance.now();
        const result = await pipe(text, { speaker_embeddings: speaker_embeddings });
        const end = performance.now();
        console.log(`Execution duration: ${(end - start) / 1000} seconds`);

        // Wrap the raw float samples in a single-channel 32f WAV container.
        const wav = new wavefile.WaveFile();
        wav.fromScratch(1, result.sampling_rate, '32f', result.audio);
        const buffer = wav.toBuffer();

        res.set('Content-Type', 'audio/wav');
        return res.send(Buffer.from(buffer));
    } catch (error) {
        console.error(error);
        return res.sendStatus(500);
    }
});
module.exports = { router };

View File

@@ -2,10 +2,8 @@ const fs = require('fs');
const path = require('path');
const express = require('express');
const writeFileAtomic = require('write-file-atomic');
const util = require('util');
const crypto = require('crypto');
const writeFile = util.promisify(writeFileAtomic);
const readFile = fs.promises.readFile;
const readdir = fs.promises.readdir;
@@ -168,7 +166,7 @@ async function saveStatsToFile() {
if (charStats.timestamp > lastSaveTimestamp) {
//console.debug("Saving stats to file...");
try {
await writeFile(statsFilePath, JSON.stringify(charStats));
await writeFileAtomic(statsFilePath, JSON.stringify(charStats));
lastSaveTimestamp = Date.now();
} catch (error) {
console.log('Failed to save stats to file.', error);

View File

@@ -298,11 +298,13 @@ function createSentencepieceDecodingHandler(tokenizer) {
const ids = request.body.ids || [];
const instance = await tokenizer?.get();
const text = await instance?.decodeIds(ids);
return response.send({ text });
const ops = ids.map(id => instance.decodeIds([id]));
const chunks = await Promise.all(ops);
const text = chunks.join('');
return response.send({ text, chunks });
} catch (error) {
console.log(error);
return response.send({ text: '' });
return response.send({ text: '', chunks: [] });
}
};
}
@@ -626,6 +628,10 @@ router.post('/remote/textgenerationwebui/encode', jsonParser, async function (re
url += '/tokenize';
args.body = JSON.stringify({ 'content': text });
break;
case TEXTGEN_TYPES.APHRODITE:
url += '/v1/tokenize';
args.body = JSON.stringify({ 'prompt': text });
break;
default:
url += '/v1/internal/encode';
args.body = JSON.stringify({ 'text': text });

View File

@@ -19,6 +19,10 @@ router.post('/libre', jsonParser, async (request, response) => {
return response.sendStatus(400);
}
if (request.body.lang === 'zh-CN' || request.body.lang === 'zh-TW') {
request.body.lang = 'zh';
}
const text = request.body.text;
const lang = request.body.lang;
@@ -98,6 +102,52 @@ router.post('/google', jsonParser, async (request, response) => {
}
});
/**
 * Translates text via a self-hosted Lingva instance.
 * The instance base URL comes from the LINGVA_URL secret; the request body
 * carries { text, lang } where lang is the target language. Source language
 * is auto-detected by Lingva ("auto" path segment).
 */
router.post('/lingva', jsonParser, async (request, response) => {
    try {
        const baseUrl = readSecret(SECRET_KEYS.LINGVA_URL);

        if (!baseUrl) {
            console.log('Lingva URL is not configured.');
            return response.sendStatus(400);
        }

        const text = request.body.text;
        const lang = request.body.lang;

        if (!text || !lang) {
            return response.sendStatus(400);
        }

        console.log('Input text: ' + text);
        // Lingva API shape: /<source>/<target>/<text>
        const url = `${baseUrl}/auto/${lang}/${encodeURIComponent(text)}`;

        // Uses the callback-style https API; the Express response is resolved
        // inside the 'end'/'error' handlers below.
        https.get(url, (resp) => {
            let data = '';

            resp.on('data', (chunk) => {
                data += chunk;
            });

            resp.on('end', () => {
                try {
                    // Expected payload: { translation: string }
                    const result = JSON.parse(data);
                    console.log('Translated text: ' + result.translation);
                    return response.send(result.translation);
                } catch (error) {
                    console.log('Translation error', error);
                    return response.sendStatus(500);
                }
            });
        }).on('error', (err) => {
            console.log('Translation error: ' + err.message);
            return response.sendStatus(500);
        });
    } catch (error) {
        console.log('Translation error', error);
        return response.sendStatus(500);
    }
});
router.post('/deepl', jsonParser, async (request, response) => {
const key = readSecret(SECRET_KEYS.DEEPL);

View File

@@ -4,19 +4,26 @@ const express = require('express');
const sanitize = require('sanitize-filename');
const { jsonParser } = require('../express-common');
// Don't forget to add new sources to the SOURCES array
const SOURCES = ['transformers', 'mistral', 'openai', 'extras', 'palm', 'togetherai'];
/**
* Gets the vector for the given text from the given source.
* @param {string} source - The source of the vector
* @param {Object} sourceSettings - Settings for the source, if it needs any
* @param {string} text - The text to get the vector for
* @returns {Promise<number[]>} - The vector for the text
*/
async function getVector(source, text) {
async function getVector(source, sourceSettings, text) {
switch (source) {
case 'togetherai':
case 'mistral':
case 'openai':
return require('../openai-vectors').getOpenAIVector(text, source);
return require('../openai-vectors').getOpenAIVector(text, source, sourceSettings.model);
case 'transformers':
return require('../embedding').getTransformersVector(text);
case 'extras':
return require('../extras-vectors').getExtrasVector(text, sourceSettings.extrasUrl, sourceSettings.extrasKey);
case 'palm':
return require('../makersuite-vectors').getMakerSuiteVector(text);
}
@@ -24,6 +31,42 @@ async function getVector(source, text) {
throw new Error(`Unknown vector source ${source}`);
}
/**
 * Gets the vector for the given text batch from the given source.
 * @param {string} source - The source of the vector
 * @param {Object} sourceSettings - Settings for the source, if it needs any
 * @param {string[]} texts - The array of texts to get the vector for
 * @returns {Promise<number[][]>} - The array of vectors for the texts
 */
async function getBatchVector(source, sourceSettings, texts) {
    // Split the input into chunks of at most 10 texts to keep provider requests small.
    const batchSize = 10;
    const batches = Array(Math.ceil(texts.length / batchSize)).fill(undefined).map((_, i) => texts.slice(i * batchSize, i * batchSize + batchSize));

    // Batches are processed sequentially; results preserve the input order.
    let results = [];
    for (let batch of batches) {
        switch (source) {
            // OpenAI-compatible providers share one implementation.
            case 'togetherai':
            case 'mistral':
            case 'openai':
                results.push(...await require('../openai-vectors').getOpenAIBatchVector(batch, source, sourceSettings.model));
                break;
            case 'transformers':
                results.push(...await require('../embedding').getTransformersBatchVector(batch));
                break;
            case 'extras':
                results.push(...await require('../extras-vectors').getExtrasBatchVector(batch, sourceSettings.extrasUrl, sourceSettings.extrasKey));
                break;
            case 'palm':
                results.push(...await require('../makersuite-vectors').getMakerSuiteBatchVector(batch));
                break;
            default:
                throw new Error(`Unknown vector source ${source}`);
        }
    }

    return results;
}
/**
* Gets the index for the vector collection
* @param {string} collectionId - The collection ID
@@ -45,19 +88,20 @@ async function getIndex(collectionId, source, create = true) {
* Inserts items into the vector collection
* @param {string} collectionId - The collection ID
* @param {string} source - The source of the vector
* @param {Object} sourceSettings - Settings for the source, if it needs any
* @param {{ hash: number; text: string; index: number; }[]} items - The items to insert
*/
async function insertVectorItems(collectionId, source, items) {
async function insertVectorItems(collectionId, source, sourceSettings, items) {
const store = await getIndex(collectionId, source);
await store.beginUpdate();
for (const item of items) {
const text = item.text;
const hash = item.hash;
const index = item.index;
const vector = await getVector(source, text);
await store.upsertItem({ vector: vector, metadata: { hash, text, index } });
const vectors = await getBatchVector(source, sourceSettings, items.map(x => x.text));
for (let i = 0; i < items.length; i++) {
const item = items[i];
const vector = vectors[i];
await store.upsertItem({ vector: vector, metadata: { hash: item.hash, text: item.text, index: item.index } });
}
await store.endUpdate();
@@ -101,13 +145,14 @@ async function deleteVectorItems(collectionId, source, hashes) {
* Gets the hashes of the items in the vector collection that match the search text
* @param {string} collectionId - The collection ID
* @param {string} source - The source of the vector
* @param {Object} sourceSettings - Settings for the source, if it needs any
* @param {string} searchText - The text to search for
* @param {number} topK - The number of results to return
* @returns {Promise<{hashes: number[], metadata: object[]}>} - The metadata of the items that match the search text
*/
async function queryCollection(collectionId, source, searchText, topK) {
async function queryCollection(collectionId, source, sourceSettings, searchText, topK) {
const store = await getIndex(collectionId, source);
const vector = await getVector(source, searchText);
const vector = await getVector(source, sourceSettings, searchText);
const result = await store.queryItems(vector, topK);
const metadata = result.map(x => x.item.metadata);
@@ -115,6 +160,35 @@ async function queryCollection(collectionId, source, searchText, topK) {
return { metadata, hashes };
}
/**
 * Extracts settings for the vectorization sources from the HTTP request headers.
 * @param {string} source - Which source to extract settings for.
 * @param {object} request - The HTTP request object.
 * @returns {object} - An object that can be used as `sourceSettings` in functions that take that parameter.
 */
function getSourceSettings(source, request) {
    if (source === 'togetherai') {
        // Coalesce a missing header to '' instead of String(undefined), which
        // would produce the truthy literal "undefined" and defeat the
        // downstream `model || config.model` fallback.
        const model = String(request.headers['x-togetherai-model'] ?? '');
        return {
            model: model,
        };
    }

    // Extras API settings to connect to the Extras embeddings provider
    let extrasUrl = '';
    let extrasKey = '';
    if (source === 'extras') {
        extrasUrl = String(request.headers['x-extras-url'] ?? '');
        extrasKey = String(request.headers['x-extras-key'] ?? '');
    }
    return {
        extrasUrl: extrasUrl,
        extrasKey: extrasKey,
    };
}
const router = express.Router();
router.post('/query', jsonParser, async (req, res) => {
@@ -127,8 +201,9 @@ router.post('/query', jsonParser, async (req, res) => {
const searchText = String(req.body.searchText);
const topK = Number(req.body.topK) || 10;
const source = String(req.body.source) || 'transformers';
const sourceSettings = getSourceSettings(source, req);
const results = await queryCollection(collectionId, source, searchText, topK);
const results = await queryCollection(collectionId, source, sourceSettings, searchText, topK);
return res.json(results);
} catch (error) {
console.error(error);
@@ -145,8 +220,9 @@ router.post('/insert', jsonParser, async (req, res) => {
const collectionId = String(req.body.collectionId);
const items = req.body.items.map(x => ({ hash: x.hash, text: x.text, index: x.index }));
const source = String(req.body.source) || 'transformers';
const sourceSettings = getSourceSettings(source, req);
await insertVectorItems(collectionId, source, items);
await insertVectorItems(collectionId, source, sourceSettings, items);
return res.sendStatus(200);
} catch (error) {
console.error(error);
@@ -197,8 +273,7 @@ router.post('/purge', jsonParser, async (req, res) => {
const collectionId = String(req.body.collectionId);
const sources = ['transformers', 'openai', 'palm'];
for (const source of sources) {
for (const source of SOURCES) {
const index = await getIndex(collectionId, source, false);
const exists = await index.isIndexCreated();

View File

@@ -7,8 +7,14 @@ const writeFileAtomicSync = require('write-file-atomic').sync;
const { jsonParser, urlencodedParser } = require('../express-common');
const { DIRECTORIES, UPLOADS_PATH } = require('../constants');
function readWorldInfoFile(worldInfoName) {
const dummyObject = { entries: {} };
/**
* Reads a World Info file and returns its contents
* @param {string} worldInfoName Name of the World Info file
* @param {boolean} allowDummy If true, returns an empty object if the file doesn't exist
* @returns {object} World Info file contents
*/
function readWorldInfoFile(worldInfoName, allowDummy) {
const dummyObject = allowDummy ? { entries: {} } : null;
if (!worldInfoName) {
return dummyObject;
@@ -34,7 +40,7 @@ router.post('/get', jsonParser, (request, response) => {
return response.sendStatus(400);
}
const file = readWorldInfoFile(request.body.name);
const file = readWorldInfoFile(request.body.name, true);
return response.send(file);
});

78
src/extras-vectors.js Normal file
View File

@@ -0,0 +1,78 @@
const fetch = require('node-fetch').default;
/**
 * Gets the vectors for the given batch of texts from SillyTavern-extras.
 * @param {string[]} texts - The array of texts to get the vectors for
 * @param {string} apiUrl - The Extras API URL
 * @param {string} apiKey - The Extras API key, or empty string if API key not enabled
 * @returns {Promise<number[][]>} - The array of vectors for the texts
 */
async function getExtrasBatchVector(texts, apiUrl, apiKey) {
    // The shared implementation accepts string[] natively, so just forward the call.
    const vectors = await getExtrasVectorImpl(texts, apiUrl, apiKey);
    return vectors;
}
/**
 * Gets the vector for the given single text from SillyTavern-extras.
 * @param {string} text - The text to get the vector for
 * @param {string} apiUrl - The Extras API URL
 * @param {string} apiKey - The Extras API key, or empty string if API key not enabled
 * @returns {Promise<number[]>} - The vector for the text
 */
async function getExtrasVector(text, apiUrl, apiKey) {
    // The shared implementation accepts a single string natively, so just forward the call.
    const vector = await getExtrasVectorImpl(text, apiUrl, apiKey);
    return vector;
}
/**
 * Gets the vector(s) for the given text(s) from SillyTavern-extras.
 * @param {string|string[]} text - The text or texts to get the vector(s) for
 * @param {string} apiUrl - The Extras API URL
 * @param {string} apiKey - The Extras API key, or empty string if API key not enabled
 * @returns {Promise<Array>} - The vector for a single text if input is string, or the array of vectors for multiple texts if input is string[]
 */
async function getExtrasVectorImpl(text, apiUrl, apiKey) {
    let url;
    try {
        // Build the endpoint from the configured base URL; an invalid URL is a
        // configuration error and is surfaced to the caller.
        url = new URL(apiUrl);
        url.pathname = '/api/embeddings/compute';
    }
    catch (error) {
        console.log('Failed to set up Extras API call:', error);
        console.log('Extras API URL given was:', apiUrl);
        throw error;
    }

    const headers = {
        'Content-Type': 'application/json',
    };

    // Include the Extras API key, if enabled
    if (apiKey && apiKey.length > 0) {
        Object.assign(headers, {
            'Authorization': `Bearer ${apiKey}`,
        });
    }

    const response = await fetch(url, {
        method: 'POST',
        headers: headers,
        body: JSON.stringify({
            text: text, // The backend accepts {string|string[]} for one or multiple text items, respectively.
        }),
    });

    if (!response.ok) {
        // Note: this local `text` intentionally shadows the parameter; it holds the error body.
        const text = await response.text();
        console.log('Extras request failed', response.statusText, text);
        throw new Error('Extras request failed');
    }

    const data = await response.json();
    const vector = data.embedding; // `embedding`: number[] (one text item), or number[][] (multiple text items).

    return vector;
}
module.exports = {
getExtrasVector,
getExtrasBatchVector,
};

View File

@@ -1,6 +1,17 @@
const fetch = require('node-fetch').default;
const { SECRET_KEYS, readSecret } = require('./endpoints/secrets');
/**
 * Gets the vectors for the given texts from the gecko model.
 * @param {string[]} texts - The array of texts to get the vector for
 * @returns {Promise<number[][]>} - The array of vectors for the texts
 */
async function getMakerSuiteBatchVector(texts) {
    // No batch endpoint is available, so request each embedding concurrently.
    return Promise.all(texts.map((text) => getMakerSuiteVector(text)));
}
/**
* Gets the vector for the given text from PaLM gecko model
* @param {string} text - The text to get the vector for
@@ -40,4 +51,5 @@ async function getMakerSuiteVector(text) {
module.exports = {
getMakerSuiteVector,
getMakerSuiteBatchVector,
};

View File

@@ -2,8 +2,13 @@ const fetch = require('node-fetch').default;
const { SECRET_KEYS, readSecret } = require('./endpoints/secrets');
const SOURCES = {
'togetherai': {
secretKey: SECRET_KEYS.TOGETHERAI,
url: 'api.together.xyz',
model: 'togethercomputer/m2-bert-80M-32k-retrieval',
},
'mistral': {
secretKey: SECRET_KEYS.MISTRAL,
secretKey: SECRET_KEYS.MISTRALAI,
url: 'api.mistral.ai',
model: 'mistral-embed',
},
@@ -15,12 +20,13 @@ const SOURCES = {
};
/**
* Gets the vector for the given text from an OpenAI compatible endpoint.
* @param {string} text - The text to get the vector for
* Gets the vector for the given text batch from an OpenAI compatible endpoint.
* @param {string[]} texts - The array of texts to get the vector for
* @param {string} source - The source of the vector
* @returns {Promise<number[]>} - The vector for the text
* @param {string} model - The model to use for the embedding
* @returns {Promise<number[][]>} - The array of vectors for the texts
*/
async function getOpenAIVector(text, source) {
async function getOpenAIBatchVector(texts, source, model = '') {
const config = SOURCES[source];
if (!config) {
@@ -43,8 +49,8 @@ async function getOpenAIVector(text, source) {
Authorization: `Bearer ${key}`,
},
body: JSON.stringify({
input: text,
model: config.model,
input: texts,
model: model || config.model,
}),
});
@@ -55,16 +61,32 @@ async function getOpenAIVector(text, source) {
}
const data = await response.json();
const vector = data?.data[0]?.embedding;
if (!Array.isArray(vector)) {
if (!Array.isArray(data?.data)) {
console.log('API response was not an array');
throw new Error('API response was not an array');
}
return vector;
// Sort data by x.index to ensure the order is correct
data.data.sort((a, b) => a.index - b.index);
const vectors = data.data.map(x => x.embedding);
return vectors;
}
/**
* Gets the vector for the given text from an OpenAI compatible endpoint.
* @param {string} text - The text to get the vector for
* @param {string} source - The source of the vector
* @param model
* @returns {Promise<number[]>} - The vector for the text
*/
async function getOpenAIVector(text, source, model = '') {
const vectors = await getOpenAIBatchVector([text], source, model);
return vectors[0];
}
module.exports = {
getOpenAIVector,
getOpenAIBatchVector,
};

View File

@@ -17,21 +17,37 @@ const tasks = {
defaultModel: 'Cohee/distilbert-base-uncased-go-emotions-onnx',
pipeline: null,
configField: 'extras.classificationModel',
quantized: true,
},
'image-to-text': {
defaultModel: 'Xenova/vit-gpt2-image-captioning',
pipeline: null,
configField: 'extras.captioningModel',
quantized: true,
},
'feature-extraction': {
defaultModel: 'Xenova/all-mpnet-base-v2',
pipeline: null,
configField: 'extras.embeddingModel',
quantized: true,
},
'text-generation': {
defaultModel: 'Cohee/fooocus_expansion-onnx',
pipeline: null,
configField: 'extras.promptExpansionModel',
quantized: true,
},
'automatic-speech-recognition': {
defaultModel: 'Xenova/whisper-small',
pipeline: null,
configField: 'extras.speechToTextModel',
quantized: true,
},
'text-to-speech': {
defaultModel: 'Xenova/speecht5_tts',
pipeline: null,
configField: 'extras.textToSpeechModel',
quantized: false,
},
}
@@ -72,19 +88,20 @@ function getModelForTask(task) {
/**
* Gets the transformers.js pipeline for a given task.
* @param {string} task The task to get the pipeline for
* @param {import('sillytavern-transformers').PipelineType} task The task to get the pipeline for
* @param {string} forceModel The model to use for the pipeline, if any
* @returns {Promise<Pipeline>} Pipeline for the task
*/
async function getPipeline(task) {
async function getPipeline(task, forceModel = '') {
if (tasks[task].pipeline) {
return tasks[task].pipeline;
}
const cache_dir = path.join(process.cwd(), 'cache');
const model = getModelForTask(task);
const model = forceModel || getModelForTask(task);
const localOnly = getConfigValue('extras.disableAutoDownload', false);
console.log('Initializing transformers.js pipeline for task', task, 'with model', model);
const instance = await pipeline(task, model, { cache_dir, quantized: true, local_files_only: localOnly });
const instance = await pipeline(task, model, { cache_dir, quantized: tasks[task].quantized ?? true, local_files_only: localOnly });
tasks[task].pipeline = instance;
return instance;
}

View File

@@ -365,7 +365,7 @@ function getImages(path) {
/**
* Pipe a fetch() response to an Express.js Response, including status code.
* @param {import('node-fetch').Response} from The Fetch API response to pipe from.
* @param {Express.Response} to The Express response to pipe to.
* @param {import('express').Response} to The Express response to pipe to.
*/
function forwardFetchResponse(from, to) {
let statusCode = from.status;
@@ -399,6 +399,64 @@ function forwardFetchResponse(from, to) {
});
}
/**
 * Makes an HTTP/2 request to the specified endpoint.
 *
 * @deprecated Use `node-fetch` if possible.
 * @param {string} endpoint URL to make the request to
 * @param {string} method HTTP method to use
 * @param {string} body Request body
 * @param {object} headers Request headers
 * @returns {Promise<string>} Response body
 */
function makeHttp2Request(endpoint, method, body, headers) {
    return new Promise((resolve, reject) => {
        try {
            const http2 = require('http2');
            const url = new URL(endpoint);
            const client = http2.connect(url.origin);

            // Settle the promise exactly once and always close the session,
            // otherwise the HTTP/2 connection leaks and keeps the process alive.
            let settled = false;
            const settle = (fn, value) => {
                if (settled) return;
                settled = true;
                client.close();
                fn(value);
            };

            // Session-level errors (e.g. connection refused) are emitted on the
            // client, not the request; without this handler they crash the process.
            client.on('error', (err) => settle(reject, err));

            const req = client.request({
                ':method': method,
                // Include the query string; url.pathname alone would drop it.
                ':path': url.pathname + url.search,
                ...headers,
            });
            req.setEncoding('utf8');
            req.on('response', (responseHeaders) => {
                const status = Number(responseHeaders[':status']);
                if (status < 200 || status >= 300) {
                    // Bail out early: no point buffering the body of a failed request.
                    settle(reject, new Error(`Request failed with status ${status}`));
                    return;
                }
                let data = '';
                req.on('data', (chunk) => {
                    data += chunk;
                });
                req.on('end', () => {
                    console.log(data);
                    settle(resolve, data);
                });
            });
            req.on('error', (err) => settle(reject, err));
            if (body) {
                req.write(body);
            }
            req.end();
        } catch (e) {
            reject(e);
        }
    });
}
/**
* Adds YAML-serialized object to the object.
* @param {object} obj Object
@@ -547,4 +605,5 @@ module.exports = {
excludeKeysByYaml,
trimV1,
Cache,
makeHttp2Request,
};