mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-06-05 21:59:27 +02:00
Merge branch 'staging' of https://github.com/Cohee1207/SillyTavern into nuclaude
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
const { TEXTGEN_TYPES } = require('./constants');
|
||||
const { TEXTGEN_TYPES, OPENROUTER_HEADERS } = require('./constants');
|
||||
const { SECRET_KEYS, readSecret } = require('./endpoints/secrets');
|
||||
const { getConfigValue } = require('./util');
|
||||
|
||||
@@ -19,6 +19,21 @@ function getTogetherAIHeaders() {
|
||||
}) : {};
|
||||
}
|
||||
|
||||
function getInfermaticAIHeaders() {
|
||||
const apiKey = readSecret(SECRET_KEYS.INFERMATICAI);
|
||||
|
||||
return apiKey ? ({
|
||||
'Authorization': `Bearer ${apiKey}`,
|
||||
}) : {};
|
||||
}
|
||||
|
||||
function getOpenRouterHeaders() {
|
||||
const apiKey = readSecret(SECRET_KEYS.OPENROUTER);
|
||||
const baseHeaders = { ...OPENROUTER_HEADERS };
|
||||
|
||||
return apiKey ? Object.assign(baseHeaders, { 'Authorization': `Bearer ${apiKey}` }) : baseHeaders;
|
||||
}
|
||||
|
||||
function getAphroditeHeaders() {
|
||||
const apiKey = readSecret(SECRET_KEYS.APHRODITE);
|
||||
|
||||
@@ -37,6 +52,14 @@ function getTabbyHeaders() {
|
||||
}) : {};
|
||||
}
|
||||
|
||||
function getOobaHeaders() {
|
||||
const apiKey = readSecret(SECRET_KEYS.OOBA);
|
||||
|
||||
return apiKey ? ({
|
||||
'Authorization': `Bearer ${apiKey}`,
|
||||
}) : {};
|
||||
}
|
||||
|
||||
function getOverrideHeaders(urlHost) {
|
||||
const requestOverrides = getConfigValue('requestOverrides', []);
|
||||
const overrideHeaders = requestOverrides?.find((e) => e.hosts?.includes(urlHost))?.headers;
|
||||
@@ -69,6 +92,15 @@ function setAdditionalHeaders(request, args, server) {
|
||||
case TEXTGEN_TYPES.TOGETHERAI:
|
||||
headers = getTogetherAIHeaders();
|
||||
break;
|
||||
case TEXTGEN_TYPES.OOBA:
|
||||
headers = getOobaHeaders();
|
||||
break;
|
||||
case TEXTGEN_TYPES.INFERMATICAI:
|
||||
headers = getInfermaticAIHeaders();
|
||||
break;
|
||||
case TEXTGEN_TYPES.OPENROUTER:
|
||||
headers = getOpenRouterHeaders();
|
||||
break;
|
||||
default:
|
||||
headers = server ? getOverrideHeaders((new URL(server))?.host) : {};
|
||||
break;
|
||||
|
@@ -1,34 +1,80 @@
|
||||
const fs = require('fs');
|
||||
|
||||
const encode = require('png-chunks-encode');
|
||||
const extract = require('png-chunks-extract');
|
||||
const PNGtext = require('png-chunk-text');
|
||||
|
||||
const parse = async (cardUrl, format) => {
|
||||
/**
|
||||
* Writes Character metadata to a PNG image buffer.
|
||||
* @param {Buffer} image PNG image buffer
|
||||
* @param {string} data Character data to write
|
||||
* @returns {Buffer} PNG image buffer with metadata
|
||||
*/
|
||||
const write = (image, data) => {
|
||||
const chunks = extract(image);
|
||||
const tEXtChunks = chunks.filter(chunk => chunk.name === 'tEXt');
|
||||
|
||||
// Remove all existing tEXt chunks
|
||||
for (let tEXtChunk of tEXtChunks) {
|
||||
chunks.splice(chunks.indexOf(tEXtChunk), 1);
|
||||
}
|
||||
// Add new chunks before the IEND chunk
|
||||
const base64EncodedData = Buffer.from(data, 'utf8').toString('base64');
|
||||
chunks.splice(-1, 0, PNGtext.encode('chara', base64EncodedData));
|
||||
const newBuffer = Buffer.from(encode(chunks));
|
||||
return newBuffer;
|
||||
};
|
||||
|
||||
/**
|
||||
* Reads Character metadata from a PNG image buffer.
|
||||
* @param {Buffer} image PNG image buffer
|
||||
* @returns {string} Character data
|
||||
*/
|
||||
const read = (image) => {
|
||||
const chunks = extract(image);
|
||||
|
||||
const textChunks = chunks.filter(function (chunk) {
|
||||
return chunk.name === 'tEXt';
|
||||
}).map(function (chunk) {
|
||||
return PNGtext.decode(chunk.data);
|
||||
});
|
||||
|
||||
if (textChunks.length === 0) {
|
||||
console.error('PNG metadata does not contain any text chunks.');
|
||||
throw new Error('No PNG metadata.');
|
||||
}
|
||||
|
||||
let index = textChunks.findIndex((chunk) => chunk.keyword.toLowerCase() == 'chara');
|
||||
|
||||
if (index === -1) {
|
||||
console.error('PNG metadata does not contain any character data.');
|
||||
throw new Error('No PNG metadata.');
|
||||
}
|
||||
|
||||
return Buffer.from(textChunks[index].text, 'base64').toString('utf8');
|
||||
};
|
||||
|
||||
/**
|
||||
* Parses a card image and returns the character metadata.
|
||||
* @param {string} cardUrl Path to the card image
|
||||
* @param {string} format File format
|
||||
* @returns {string} Character data
|
||||
*/
|
||||
const parse = (cardUrl, format) => {
|
||||
let fileFormat = format === undefined ? 'png' : format;
|
||||
|
||||
switch (fileFormat) {
|
||||
case 'png': {
|
||||
const buffer = fs.readFileSync(cardUrl);
|
||||
const chunks = extract(buffer);
|
||||
|
||||
const textChunks = chunks.filter(function (chunk) {
|
||||
return chunk.name === 'tEXt';
|
||||
}).map(function (chunk) {
|
||||
return PNGtext.decode(chunk.data);
|
||||
});
|
||||
|
||||
if (textChunks.length === 0) {
|
||||
console.error('PNG metadata does not contain any character data.');
|
||||
throw new Error('No PNG metadata.');
|
||||
}
|
||||
|
||||
return Buffer.from(textChunks[0].text, 'base64').toString('utf8');
|
||||
return read(buffer);
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
throw new Error('Unsupported format');
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
parse: parse,
|
||||
parse,
|
||||
write,
|
||||
read,
|
||||
};
|
||||
|
@@ -1,5 +1,6 @@
|
||||
const DIRECTORIES = {
|
||||
worlds: 'public/worlds/',
|
||||
user: 'public/user',
|
||||
avatars: 'public/User Avatars',
|
||||
images: 'public/img/',
|
||||
userImages: 'public/user/images/',
|
||||
@@ -175,8 +176,22 @@ const TEXTGEN_TYPES = {
|
||||
TOGETHERAI: 'togetherai',
|
||||
LLAMACPP: 'llamacpp',
|
||||
OLLAMA: 'ollama',
|
||||
INFERMATICAI: 'infermaticai',
|
||||
OPENROUTER: 'openrouter',
|
||||
};
|
||||
|
||||
const INFERMATICAI_KEYS = [
|
||||
'model',
|
||||
'prompt',
|
||||
'max_tokens',
|
||||
'temperature',
|
||||
'top_p',
|
||||
'top_k',
|
||||
'repetition_penalty',
|
||||
'stream',
|
||||
'stop',
|
||||
];
|
||||
|
||||
// https://docs.together.ai/reference/completions
|
||||
const TOGETHERAI_KEYS = [
|
||||
'model',
|
||||
@@ -187,6 +202,7 @@ const TOGETHERAI_KEYS = [
|
||||
'top_k',
|
||||
'repetition_penalty',
|
||||
'stream',
|
||||
'stop',
|
||||
];
|
||||
|
||||
// https://github.com/jmorganca/ollama/blob/main/docs/api.md#request-with-options
|
||||
@@ -211,6 +227,29 @@ const OLLAMA_KEYS = [
|
||||
const AVATAR_WIDTH = 400;
|
||||
const AVATAR_HEIGHT = 600;
|
||||
|
||||
const OPENROUTER_HEADERS = {
|
||||
'HTTP-Referer': 'https://sillytavern.app',
|
||||
'X-Title': 'SillyTavern',
|
||||
};
|
||||
|
||||
const OPENROUTER_KEYS = [
|
||||
'max_tokens',
|
||||
'temperature',
|
||||
'top_k',
|
||||
'top_p',
|
||||
'presence_penalty',
|
||||
'frequency_penalty',
|
||||
'repetition_penalty',
|
||||
'min_p',
|
||||
'top_a',
|
||||
'seed',
|
||||
'logit_bias',
|
||||
'model',
|
||||
'stream',
|
||||
'prompt',
|
||||
'stop',
|
||||
];
|
||||
|
||||
module.exports = {
|
||||
DIRECTORIES,
|
||||
UNSAFE_EXTENSIONS,
|
||||
@@ -223,4 +262,7 @@ module.exports = {
|
||||
AVATAR_HEIGHT,
|
||||
TOGETHERAI_KEYS,
|
||||
OLLAMA_KEYS,
|
||||
INFERMATICAI_KEYS,
|
||||
OPENROUTER_HEADERS,
|
||||
OPENROUTER_KEYS,
|
||||
};
|
||||
|
@@ -1,6 +1,7 @@
|
||||
const TASK = 'feature-extraction';
|
||||
|
||||
/**
|
||||
* Gets the vectorized text in form of an array of numbers.
|
||||
* @param {string} text - The text to vectorize
|
||||
* @returns {Promise<number[]>} - The vectorized text in form of an array of numbers
|
||||
*/
|
||||
@@ -12,6 +13,20 @@ async function getTransformersVector(text) {
|
||||
return vector;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the vectorized texts in form of an array of arrays of numbers.
|
||||
* @param {string[]} texts - The texts to vectorize
|
||||
* @returns {Promise<number[][]>} - The vectorized texts in form of an array of arrays of numbers
|
||||
*/
|
||||
async function getTransformersBatchVector(texts) {
|
||||
const result = [];
|
||||
for (const text of texts) {
|
||||
result.push(await getTransformersVector(text));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getTransformersVector,
|
||||
getTransformersBatchVector,
|
||||
};
|
||||
|
@@ -8,7 +8,7 @@ const { DIRECTORIES, UNSAFE_EXTENSIONS } = require('../constants');
|
||||
const { jsonParser } = require('../express-common');
|
||||
const { clientRelativePath } = require('../util');
|
||||
|
||||
const VALID_CATEGORIES = ['bgm', 'ambient', 'blip', 'live2d'];
|
||||
const VALID_CATEGORIES = ['bgm', 'ambient', 'blip', 'live2d', 'vrm', 'character'];
|
||||
|
||||
/**
|
||||
* Validates the input filename for the asset.
|
||||
@@ -106,6 +106,33 @@ router.post('/get', jsonParser, async (_, response) => {
|
||||
continue;
|
||||
}
|
||||
|
||||
// VRM assets
|
||||
if (folder == 'vrm') {
|
||||
output[folder] = { 'model': [], 'animation': [] };
|
||||
// Extract models
|
||||
const vrm_model_folder = path.normalize(path.join(folderPath, 'vrm', 'model'));
|
||||
let files = getFiles(vrm_model_folder);
|
||||
//console.debug("FILE FOUND:",files)
|
||||
for (let file of files) {
|
||||
if (!file.endsWith('.placeholder')) {
|
||||
//console.debug("Asset VRM model found:",file)
|
||||
output['vrm']['model'].push(clientRelativePath(file));
|
||||
}
|
||||
}
|
||||
|
||||
// Extract models
|
||||
const vrm_animation_folder = path.normalize(path.join(folderPath, 'vrm', 'animation'));
|
||||
files = getFiles(vrm_animation_folder);
|
||||
//console.debug("FILE FOUND:",files)
|
||||
for (let file of files) {
|
||||
if (!file.endsWith('.placeholder')) {
|
||||
//console.debug("Asset VRM animation found:",file)
|
||||
output['vrm']['animation'].push(clientRelativePath(file));
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Other assets (bgm/ambient/blip)
|
||||
const files = fs.readdirSync(path.join(folderPath, folder))
|
||||
.filter(filename => {
|
||||
@@ -172,6 +199,13 @@ router.post('/download', jsonParser, async (request, response) => {
|
||||
const fileStream = fs.createWriteStream(destination, { flags: 'wx' });
|
||||
await finished(res.body.pipe(fileStream));
|
||||
|
||||
if (category === 'character') {
|
||||
response.sendFile(temp_path, { root: process.cwd() }, () => {
|
||||
fs.rmSync(temp_path);
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Move into asset place
|
||||
console.debug('Download finished, moving file from', temp_path, 'to', file_path);
|
||||
fs.renameSync(temp_path, file_path);
|
||||
|
@@ -3,7 +3,7 @@ const fetch = require('node-fetch').default;
|
||||
const { Readable } = require('stream');
|
||||
|
||||
const { jsonParser } = require('../../express-common');
|
||||
const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY } = require('../../constants');
|
||||
const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY, OPENROUTER_HEADERS } = require('../../constants');
|
||||
const { forwardFetchResponse, getConfigValue, tryParse, uuidv4, mergeObjectWithYaml, excludeKeysByYaml, color } = require('../../util');
|
||||
const { convertClaudePrompt, convertGooglePrompt, convertTextCompletionPrompt } = require('../prompt-converters');
|
||||
|
||||
@@ -12,7 +12,7 @@ const { getTokenizerModel, getSentencepiceTokenizer, getTiktokenTokenizer, sente
|
||||
|
||||
const API_OPENAI = 'https://api.openai.com/v1';
|
||||
const API_CLAUDE = 'https://api.anthropic.com/v1';
|
||||
|
||||
const API_MISTRAL = 'https://api.mistral.ai/v1';
|
||||
/**
|
||||
* Sends a request to Claude API.
|
||||
* @param {express.Request} request Express request
|
||||
@@ -36,9 +36,10 @@ async function sendClaudeRequest(request, response) {
|
||||
});
|
||||
|
||||
const isSysPromptSupported = request.body.model === 'claude-2' || request.body.model === 'claude-2.1';
|
||||
const requestPrompt = convertClaudePrompt(request.body.messages, !request.body.exclude_assistant, request.body.assistant_prefill, isSysPromptSupported, request.body.claude_use_sysprompt, request.body.human_sysprompt_message);
|
||||
const requestPrompt = convertClaudePrompt(request.body.messages, !request.body.exclude_assistant, request.body.assistant_prefill, isSysPromptSupported, request.body.claude_use_sysprompt, request.body.human_sysprompt_message, request.body.claude_exclude_prefixes);
|
||||
|
||||
// Check Claude messages sequence and prefixes presence.
|
||||
let sequenceError = [];
|
||||
const sequence = requestPrompt.split('\n').filter(x => x.startsWith('Human:') || x.startsWith('Assistant:'));
|
||||
const humanFound = sequence.some(line => line.startsWith('Human:'));
|
||||
const assistantFound = sequence.some(line => line.startsWith('Assistant:'));
|
||||
@@ -56,20 +57,20 @@ async function sendClaudeRequest(request, response) {
|
||||
}
|
||||
|
||||
if (!humanFound) {
|
||||
console.log(color.red(`${divider}\nWarning: No 'Human:' prefix found in the prompt.\n${divider}`));
|
||||
sequenceError.push(`${divider}\nWarning: No 'Human:' prefix found in the prompt.\n${divider}`);
|
||||
}
|
||||
if (!assistantFound) {
|
||||
console.log(color.red(`${divider}\nWarning: No 'Assistant: ' prefix found in the prompt.\n${divider}`));
|
||||
sequenceError.push(`${divider}\nWarning: No 'Assistant: ' prefix found in the prompt.\n${divider}`);
|
||||
}
|
||||
if (!sequence[0].startsWith('Human:')) {
|
||||
console.log(color.red(`${divider}\nWarning: The messages sequence should start with 'Human:' prefix.\nMake sure you have 'Human:' prefix at the very beggining of the prompt, or after the system prompt.\n${divider}`));
|
||||
if (sequence[0] && !sequence[0].startsWith('Human:')) {
|
||||
sequenceError.push(`${divider}\nWarning: The messages sequence should start with 'Human:' prefix.\nMake sure you have '\\n\\nHuman:' prefix at the very beggining of the prompt, or after the system prompt.\n${divider}`);
|
||||
}
|
||||
if (humanErrorCount > 0 || assistantErrorCount > 0) {
|
||||
console.log(color.red(`${divider}\nWarning: Detected incorrect Prefix sequence(s).`));
|
||||
console.log(color.red(`Incorrect "Human:" prefix(es): ${humanErrorCount}.\nIncorrect "Assistant: " prefix(es): ${assistantErrorCount}.`));
|
||||
console.log(color.red('Check the prompt above and fix it in the SillyTavern.'));
|
||||
console.log(color.red('\nThe correct sequence should look like this:\nSystem prompt <-(for the sysprompt format only, else have 2 empty lines above the first human\'s message.)'));
|
||||
console.log(color.red(` <-----(Each message beginning with the "Assistant:/Human:" prefix must have one empty line above.)\nHuman:\n\nAssistant:\n...\n\nHuman:\n\nAssistant:\n${divider}`));
|
||||
sequenceError.push(`${divider}\nWarning: Detected incorrect Prefix sequence(s).`);
|
||||
sequenceError.push(`Incorrect "Human:" prefix(es): ${humanErrorCount}.\nIncorrect "Assistant: " prefix(es): ${assistantErrorCount}.`);
|
||||
sequenceError.push('Check the prompt above and fix it in the SillyTavern.');
|
||||
sequenceError.push('\nThe correct sequence in the console should look like this:\n(System prompt msg) <-(for the sysprompt format only, else have \\n\\n above the first human\'s message.)');
|
||||
sequenceError.push(`\\n + <-----(Each message beginning with the "Assistant:/Human:" prefix must have \\n\\n before it.)\n\\n +\nHuman: \\n +\n\\n +\nAssistant: \\n +\n...\n\\n +\nHuman: \\n +\n\\n +\nAssistant: \n${divider}`);
|
||||
}
|
||||
|
||||
// Add custom stop sequences
|
||||
@@ -91,6 +92,10 @@ async function sendClaudeRequest(request, response) {
|
||||
|
||||
console.log('Claude request:', requestBody);
|
||||
|
||||
sequenceError.forEach(sequenceError => {
|
||||
console.log(color.red(sequenceError));
|
||||
});
|
||||
|
||||
const generateResponse = await fetch(apiUrl + '/complete', {
|
||||
method: 'POST',
|
||||
signal: controller.signal,
|
||||
@@ -262,7 +267,7 @@ async function sendMakerSuiteRequest(request, response) {
|
||||
? (stream ? 'streamGenerateContent' : 'generateContent')
|
||||
: (isText ? 'generateText' : 'generateMessage');
|
||||
|
||||
const generateResponse = await fetch(`https://generativelanguage.googleapis.com/${apiVersion}/models/${model}:${responseType}?key=${apiKey}`, {
|
||||
const generateResponse = await fetch(`https://generativelanguage.googleapis.com/${apiVersion}/models/${model}:${responseType}?key=${apiKey}${stream ? '&alt=sse' : ''}`, {
|
||||
body: JSON.stringify(body),
|
||||
method: 'POST',
|
||||
headers: {
|
||||
@@ -274,36 +279,8 @@ async function sendMakerSuiteRequest(request, response) {
|
||||
// have to do this because of their busted ass streaming endpoint
|
||||
if (stream) {
|
||||
try {
|
||||
let partialData = '';
|
||||
generateResponse.body.on('data', (data) => {
|
||||
const chunk = data.toString();
|
||||
if (chunk.startsWith(',') || chunk.endsWith(',') || chunk.startsWith('[') || chunk.endsWith(']')) {
|
||||
partialData = chunk.slice(1);
|
||||
} else {
|
||||
partialData += chunk;
|
||||
}
|
||||
while (true) {
|
||||
let json;
|
||||
try {
|
||||
json = JSON.parse(partialData);
|
||||
} catch (e) {
|
||||
break;
|
||||
}
|
||||
response.write(JSON.stringify(json));
|
||||
partialData = '';
|
||||
}
|
||||
});
|
||||
|
||||
request.socket.on('close', function () {
|
||||
if (generateResponse.body instanceof Readable) generateResponse.body.destroy();
|
||||
response.end();
|
||||
});
|
||||
|
||||
generateResponse.body.on('end', () => {
|
||||
console.log('Streaming request finished');
|
||||
response.end();
|
||||
});
|
||||
|
||||
// Pipe remote SSE stream to Express response
|
||||
forwardFetchResponse(generateResponse, response);
|
||||
} catch (error) {
|
||||
console.log('Error forwarding streaming response:', error);
|
||||
if (!response.headersSent) {
|
||||
@@ -329,7 +306,7 @@ async function sendMakerSuiteRequest(request, response) {
|
||||
}
|
||||
|
||||
const responseContent = candidates[0].content ?? candidates[0].output;
|
||||
const responseText = typeof responseContent === 'string' ? responseContent : responseContent.parts?.[0]?.text;
|
||||
const responseText = typeof responseContent === 'string' ? responseContent : responseContent?.parts?.[0]?.text;
|
||||
if (!responseText) {
|
||||
let message = 'MakerSuite Candidate text empty';
|
||||
console.log(message, generateResponseJson);
|
||||
@@ -431,7 +408,8 @@ async function sendAI21Request(request, response) {
|
||||
* @param {express.Response} response Express response
|
||||
*/
|
||||
async function sendMistralAIRequest(request, response) {
|
||||
const apiKey = readSecret(SECRET_KEYS.MISTRALAI);
|
||||
const apiUrl = new URL(request.body.reverse_proxy || API_MISTRAL).toString();
|
||||
const apiKey = request.body.reverse_proxy ? request.body.proxy_password : readSecret(SECRET_KEYS.MISTRALAI);
|
||||
|
||||
if (!apiKey) {
|
||||
console.log('MistralAI API key is missing.');
|
||||
@@ -441,9 +419,12 @@ async function sendMistralAIRequest(request, response) {
|
||||
try {
|
||||
//must send a user role as last message
|
||||
const messages = Array.isArray(request.body.messages) ? request.body.messages : [];
|
||||
//large seems to be throwing a 500 error if we don't make the first message a user role, most likely a bug since the other models won't do this
|
||||
if (request.body.model.includes('large'))
|
||||
messages[0].role = 'user';
|
||||
const lastMsg = messages[messages.length - 1];
|
||||
if (messages.length > 0 && lastMsg && (lastMsg.role === 'system' || lastMsg.role === 'assistant')) {
|
||||
if (lastMsg.role === 'assistant') {
|
||||
if (lastMsg.role === 'assistant' && lastMsg.name) {
|
||||
lastMsg.content = lastMsg.name + ': ' + lastMsg.content;
|
||||
} else if (lastMsg.role === 'system') {
|
||||
lastMsg.content = '[INST] ' + lastMsg.content + ' [/INST]';
|
||||
@@ -478,7 +459,7 @@ async function sendMistralAIRequest(request, response) {
|
||||
'top_p': request.body.top_p,
|
||||
'max_tokens': request.body.max_tokens,
|
||||
'stream': request.body.stream,
|
||||
'safe_mode': request.body.safe_mode,
|
||||
'safe_prompt': request.body.safe_prompt,
|
||||
'random_seed': request.body.seed === -1 ? undefined : request.body.seed,
|
||||
};
|
||||
|
||||
@@ -495,7 +476,7 @@ async function sendMistralAIRequest(request, response) {
|
||||
|
||||
console.log('MisralAI request:', requestBody);
|
||||
|
||||
const generateResponse = await fetch('https://api.mistral.ai/v1/chat/completions', config);
|
||||
const generateResponse = await fetch(apiUrl + '/chat/completions', config);
|
||||
if (request.body.stream) {
|
||||
forwardFetchResponse(generateResponse, response);
|
||||
} else {
|
||||
@@ -535,11 +516,12 @@ router.post('/status', jsonParser, async function (request, response_getstatus_o
|
||||
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.OPENROUTER) {
|
||||
api_url = 'https://openrouter.ai/api/v1';
|
||||
api_key_openai = readSecret(SECRET_KEYS.OPENROUTER);
|
||||
// OpenRouter needs to pass the referer: https://openrouter.ai/docs
|
||||
headers = { 'HTTP-Referer': request.headers.referer };
|
||||
// OpenRouter needs to pass the Referer and X-Title: https://openrouter.ai/docs#requests
|
||||
headers = { ...OPENROUTER_HEADERS };
|
||||
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.MISTRALAI) {
|
||||
api_url = 'https://api.mistral.ai/v1';
|
||||
api_key_openai = readSecret(SECRET_KEYS.MISTRALAI);
|
||||
api_url = new URL(request.body.reverse_proxy || API_MISTRAL).toString();
|
||||
api_key_openai = request.body.reverse_proxy ? request.body.proxy_password : readSecret(SECRET_KEYS.MISTRALAI);
|
||||
headers = {};
|
||||
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.CUSTOM) {
|
||||
api_url = request.body.custom_url;
|
||||
api_key_openai = readSecret(SECRET_KEYS.CUSTOM);
|
||||
@@ -699,12 +681,21 @@ router.post('/generate', jsonParser, function (request, response) {
|
||||
let apiKey;
|
||||
let headers;
|
||||
let bodyParams;
|
||||
const isTextCompletion = Boolean(request.body.model && TEXT_COMPLETION_MODELS.includes(request.body.model)) || typeof request.body.messages === 'string';
|
||||
|
||||
if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.OPENAI) {
|
||||
apiUrl = new URL(request.body.reverse_proxy || API_OPENAI).toString();
|
||||
apiKey = request.body.reverse_proxy ? request.body.proxy_password : readSecret(SECRET_KEYS.OPENAI);
|
||||
headers = {};
|
||||
bodyParams = {};
|
||||
bodyParams = {
|
||||
logprobs: request.body.logprobs,
|
||||
};
|
||||
|
||||
// Adjust logprobs params for Chat Completions API, which expects { top_logprobs: number; logprobs: boolean; }
|
||||
if (!isTextCompletion && bodyParams.logprobs > 0) {
|
||||
bodyParams.top_logprobs = bodyParams.logprobs;
|
||||
bodyParams.logprobs = true;
|
||||
}
|
||||
|
||||
if (getConfigValue('openai.randomizeUserId', false)) {
|
||||
bodyParams['user'] = uuidv4();
|
||||
@@ -712,10 +703,22 @@ router.post('/generate', jsonParser, function (request, response) {
|
||||
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.OPENROUTER) {
|
||||
apiUrl = 'https://openrouter.ai/api/v1';
|
||||
apiKey = readSecret(SECRET_KEYS.OPENROUTER);
|
||||
// OpenRouter needs to pass the referer: https://openrouter.ai/docs
|
||||
headers = { 'HTTP-Referer': request.headers.referer };
|
||||
// OpenRouter needs to pass the Referer and X-Title: https://openrouter.ai/docs#requests
|
||||
headers = { ...OPENROUTER_HEADERS };
|
||||
bodyParams = { 'transforms': ['middle-out'] };
|
||||
|
||||
if (request.body.min_p !== undefined) {
|
||||
bodyParams['min_p'] = request.body.min_p;
|
||||
}
|
||||
|
||||
if (request.body.top_a !== undefined) {
|
||||
bodyParams['top_a'] = request.body.top_a;
|
||||
}
|
||||
|
||||
if (request.body.repetition_penalty !== undefined) {
|
||||
bodyParams['repetition_penalty'] = request.body.repetition_penalty;
|
||||
}
|
||||
|
||||
if (request.body.use_fallback) {
|
||||
bodyParams['route'] = 'fallback';
|
||||
}
|
||||
@@ -723,7 +726,16 @@ router.post('/generate', jsonParser, function (request, response) {
|
||||
apiUrl = request.body.custom_url;
|
||||
apiKey = readSecret(SECRET_KEYS.CUSTOM);
|
||||
headers = {};
|
||||
bodyParams = {};
|
||||
bodyParams = {
|
||||
logprobs: request.body.logprobs,
|
||||
};
|
||||
|
||||
// Adjust logprobs params for Chat Completions API, which expects { top_logprobs: number; logprobs: boolean; }
|
||||
if (!isTextCompletion && bodyParams.logprobs > 0) {
|
||||
bodyParams.top_logprobs = bodyParams.logprobs;
|
||||
bodyParams.logprobs = true;
|
||||
}
|
||||
|
||||
mergeObjectWithYaml(bodyParams, request.body.custom_include_body);
|
||||
mergeObjectWithYaml(headers, request.body.custom_include_headers);
|
||||
} else {
|
||||
@@ -741,7 +753,6 @@ router.post('/generate', jsonParser, function (request, response) {
|
||||
bodyParams['stop'] = request.body.stop;
|
||||
}
|
||||
|
||||
const isTextCompletion = Boolean(request.body.model && TEXT_COMPLETION_MODELS.includes(request.body.model)) || typeof request.body.messages === 'string';
|
||||
const textPrompt = isTextCompletion ? convertTextCompletionPrompt(request.body.messages) : '';
|
||||
const endpointUrl = isTextCompletion && request.body.chat_completion_source !== CHAT_COMPLETION_SOURCES.OPENROUTER ?
|
||||
`${apiUrl}/completions` :
|
||||
@@ -767,6 +778,7 @@ router.post('/generate', jsonParser, function (request, response) {
|
||||
'stop': isTextCompletion === false ? request.body.stop : undefined,
|
||||
'logit_bias': request.body.logit_bias,
|
||||
'seed': request.body.seed,
|
||||
'n': request.body.n,
|
||||
...bodyParams,
|
||||
};
|
||||
|
||||
@@ -813,7 +825,7 @@ router.post('/generate', jsonParser, function (request, response) {
|
||||
let json = await fetchResponse.json();
|
||||
response.send(json);
|
||||
console.log(json);
|
||||
console.log(json?.choices[0]?.message);
|
||||
console.log(json?.choices?.[0]?.message);
|
||||
} else if (fetchResponse.status === 429 && retries > 0) {
|
||||
console.log(`Out of quota, retrying in ${Math.round(timeout / 1000)}s`);
|
||||
setTimeout(() => {
|
||||
|
@@ -4,7 +4,7 @@ const _ = require('lodash');
|
||||
const Readable = require('stream').Readable;
|
||||
|
||||
const { jsonParser } = require('../../express-common');
|
||||
const { TEXTGEN_TYPES, TOGETHERAI_KEYS, OLLAMA_KEYS } = require('../../constants');
|
||||
const { TEXTGEN_TYPES, TOGETHERAI_KEYS, OLLAMA_KEYS, INFERMATICAI_KEYS, OPENROUTER_KEYS } = require('../../constants');
|
||||
const { forwardFetchResponse, trimV1 } = require('../../util');
|
||||
const { setAdditionalHeaders } = require('../../additional-headers');
|
||||
|
||||
@@ -106,6 +106,8 @@ router.post('/status', jsonParser, async function (request, response) {
|
||||
case TEXTGEN_TYPES.APHRODITE:
|
||||
case TEXTGEN_TYPES.KOBOLDCPP:
|
||||
case TEXTGEN_TYPES.LLAMACPP:
|
||||
case TEXTGEN_TYPES.INFERMATICAI:
|
||||
case TEXTGEN_TYPES.OPENROUTER:
|
||||
url += '/v1/models';
|
||||
break;
|
||||
case TEXTGEN_TYPES.MANCER:
|
||||
@@ -208,6 +210,7 @@ router.post('/generate', jsonParser, async function (request, response) {
|
||||
request.body.api_server = request.body.api_server.replace('localhost', '127.0.0.1');
|
||||
}
|
||||
|
||||
const apiType = request.body.api_type;
|
||||
const baseUrl = request.body.api_server;
|
||||
console.log(request.body);
|
||||
|
||||
@@ -232,6 +235,7 @@ router.post('/generate', jsonParser, async function (request, response) {
|
||||
case TEXTGEN_TYPES.TABBY:
|
||||
case TEXTGEN_TYPES.KOBOLDCPP:
|
||||
case TEXTGEN_TYPES.TOGETHERAI:
|
||||
case TEXTGEN_TYPES.INFERMATICAI:
|
||||
url += '/v1/completions';
|
||||
break;
|
||||
case TEXTGEN_TYPES.MANCER:
|
||||
@@ -243,6 +247,9 @@ router.post('/generate', jsonParser, async function (request, response) {
|
||||
case TEXTGEN_TYPES.OLLAMA:
|
||||
url += '/api/generate';
|
||||
break;
|
||||
case TEXTGEN_TYPES.OPENROUTER:
|
||||
url += '/v1/chat/completions';
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -257,11 +264,17 @@ router.post('/generate', jsonParser, async function (request, response) {
|
||||
setAdditionalHeaders(request, args, baseUrl);
|
||||
|
||||
if (request.body.api_type === TEXTGEN_TYPES.TOGETHERAI) {
|
||||
const stop = Array.isArray(request.body.stop) ? request.body.stop[0] : '';
|
||||
request.body = _.pickBy(request.body, (_, key) => TOGETHERAI_KEYS.includes(key));
|
||||
if (typeof stop === 'string' && stop.length > 0) {
|
||||
request.body.stop = stop;
|
||||
}
|
||||
args.body = JSON.stringify(request.body);
|
||||
}
|
||||
|
||||
if (request.body.api_type === TEXTGEN_TYPES.INFERMATICAI) {
|
||||
request.body = _.pickBy(request.body, (_, key) => INFERMATICAI_KEYS.includes(key));
|
||||
args.body = JSON.stringify(request.body);
|
||||
}
|
||||
|
||||
if (request.body.api_type === TEXTGEN_TYPES.OPENROUTER) {
|
||||
request.body = _.pickBy(request.body, (_, key) => OPENROUTER_KEYS.includes(key));
|
||||
args.body = JSON.stringify(request.body);
|
||||
}
|
||||
|
||||
@@ -270,6 +283,7 @@ router.post('/generate', jsonParser, async function (request, response) {
|
||||
model: request.body.model,
|
||||
prompt: request.body.prompt,
|
||||
stream: request.body.stream ?? false,
|
||||
keep_alive: -1,
|
||||
raw: true,
|
||||
options: _.pickBy(request.body, (_, key) => OLLAMA_KEYS.includes(key)),
|
||||
});
|
||||
@@ -296,6 +310,11 @@ router.post('/generate', jsonParser, async function (request, response) {
|
||||
data['choices'] = [{ text }];
|
||||
}
|
||||
|
||||
// Map InfermaticAI response to OAI completions format
|
||||
if (apiType === TEXTGEN_TYPES.INFERMATICAI) {
|
||||
data['choices'] = (data?.choices || []).map(choice => ({ text: choice.message.content }));
|
||||
}
|
||||
|
||||
return response.send(data);
|
||||
} else {
|
||||
const text = await completionsReply.text();
|
||||
|
@@ -7,9 +7,6 @@ const writeFileAtomicSync = require('write-file-atomic').sync;
|
||||
const yaml = require('yaml');
|
||||
const _ = require('lodash');
|
||||
|
||||
const encode = require('png-chunks-encode');
|
||||
const extract = require('png-chunks-extract');
|
||||
const PNGtext = require('png-chunk-text');
|
||||
const jimp = require('jimp');
|
||||
|
||||
const { DIRECTORIES, UPLOADS_PATH, AVATAR_WIDTH, AVATAR_HEIGHT } = require('../constants');
|
||||
@@ -33,7 +30,7 @@ const characterDataCache = new Map();
|
||||
* @param {string} input_format - 'png'
|
||||
* @returns {Promise<string | undefined>} - Character card data
|
||||
*/
|
||||
async function charaRead(img_url, input_format) {
|
||||
async function charaRead(img_url, input_format = 'png') {
|
||||
const stat = fs.statSync(img_url);
|
||||
const cacheKey = `${img_url}-${stat.mtimeMs}`;
|
||||
if (characterDataCache.has(cacheKey)) {
|
||||
@@ -59,22 +56,12 @@ async function charaWrite(img_url, data, target_img, response = undefined, mes =
|
||||
}
|
||||
}
|
||||
// Read the image, resize, and save it as a PNG into the buffer
|
||||
const image = await tryReadImage(img_url, crop);
|
||||
const inputImage = await tryReadImage(img_url, crop);
|
||||
|
||||
// Get the chunks
|
||||
const chunks = extract(image);
|
||||
const tEXtChunks = chunks.filter(chunk => chunk.name === 'tEXt');
|
||||
const outputImage = characterCardParser.write(inputImage, data);
|
||||
|
||||
// Remove all existing tEXt chunks
|
||||
for (let tEXtChunk of tEXtChunks) {
|
||||
chunks.splice(chunks.indexOf(tEXtChunk), 1);
|
||||
}
|
||||
// Add new chunks before the IEND chunk
|
||||
const base64EncodedData = Buffer.from(data, 'utf8').toString('base64');
|
||||
chunks.splice(-1, 0, PNGtext.encode('chara', base64EncodedData));
|
||||
//chunks.splice(-1, 0, text.encode('lorem', 'ipsum'));
|
||||
|
||||
writeFileAtomicSync(DIRECTORIES.characters + target_img + '.png', Buffer.from(encode(chunks)));
|
||||
writeFileAtomicSync(DIRECTORIES.characters + target_img + '.png', outputImage);
|
||||
if (response !== undefined) response.send(mes);
|
||||
return true;
|
||||
} catch (err) {
|
||||
@@ -152,13 +139,13 @@ const processCharacter = async (item, i) => {
|
||||
const img_data = await charaRead(DIRECTORIES.characters + item);
|
||||
if (img_data === undefined) throw new Error('Failed to read character file');
|
||||
|
||||
let jsonObject = getCharaCardV2(JSON.parse(img_data));
|
||||
let jsonObject = getCharaCardV2(JSON.parse(img_data), false);
|
||||
jsonObject.avatar = item;
|
||||
characters[i] = jsonObject;
|
||||
characters[i]['json_data'] = img_data;
|
||||
const charStat = fs.statSync(path.join(DIRECTORIES.characters, item));
|
||||
characters[i]['date_added'] = charStat.birthtimeMs;
|
||||
characters[i]['create_date'] = jsonObject['create_date'] || humanizedISO8601DateTime(charStat.birthtimeMs);
|
||||
characters[i]['date_added'] = charStat.ctimeMs;
|
||||
characters[i]['create_date'] = jsonObject['create_date'] || humanizedISO8601DateTime(charStat.ctimeMs);
|
||||
const char_dir = path.join(DIRECTORIES.chats, item.replace('.png', ''));
|
||||
|
||||
const { chatSize, dateLastChat } = calculateChatSize(char_dir);
|
||||
@@ -183,15 +170,30 @@ const processCharacter = async (item, i) => {
|
||||
}
|
||||
};
|
||||
|
||||
function getCharaCardV2(jsonObject) {
|
||||
/**
|
||||
* Convert a character object to Spec V2 format.
|
||||
* @param {object} jsonObject Character object
|
||||
* @param {boolean} hoistDate Will set the chat and create_date fields to the current date if they are missing
|
||||
* @returns {object} Character object in Spec V2 format
|
||||
*/
|
||||
function getCharaCardV2(jsonObject, hoistDate = true) {
|
||||
if (jsonObject.spec === undefined) {
|
||||
jsonObject = convertToV2(jsonObject);
|
||||
|
||||
if (hoistDate && !jsonObject.create_date) {
|
||||
jsonObject.create_date = humanizedISO8601DateTime();
|
||||
}
|
||||
} else {
|
||||
jsonObject = readFromV2(jsonObject);
|
||||
}
|
||||
return jsonObject;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert a character object to Spec V2 format.
|
||||
* @param {object} char Character object
|
||||
* @returns {object} Character object in Spec V2 format
|
||||
*/
|
||||
function convertToV2(char) {
|
||||
// Simulate incoming data from frontend form
|
||||
const result = charaFormatData({
|
||||
@@ -212,7 +214,8 @@ function convertToV2(char) {
|
||||
});
|
||||
|
||||
result.chat = char.chat ?? humanizedISO8601DateTime();
|
||||
result.create_date = char.create_date ?? humanizedISO8601DateTime();
|
||||
result.create_date = char.create_date;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -300,6 +303,7 @@ function charaFormatData(data) {
|
||||
_.set(char, 'chat', data.ch_name + ' - ' + humanizedISO8601DateTime());
|
||||
_.set(char, 'talkativeness', data.talkativeness);
|
||||
_.set(char, 'fav', data.fav == 'true');
|
||||
_.set(char, 'tags', typeof data.tags == 'string' ? (data.tags.split(',').map(x => x.trim()).filter(x => x)) : data.tags || []);
|
||||
|
||||
// Spec V2 fields
|
||||
_.set(char, 'spec', 'chara_card_v2');
|
||||
@@ -336,7 +340,7 @@ function charaFormatData(data) {
|
||||
|
||||
if (data.world) {
|
||||
try {
|
||||
const file = readWorldInfoFile(data.world);
|
||||
const file = readWorldInfoFile(data.world, false);
|
||||
|
||||
// File was imported - save it to the character book
|
||||
if (file && file.originalData) {
|
||||
@@ -387,6 +391,11 @@ function convertWorldInfoToCharacterBook(name, entries) {
|
||||
depth: entry.depth ?? 4,
|
||||
selectiveLogic: entry.selectiveLogic ?? 0,
|
||||
group: entry.group ?? '',
|
||||
prevent_recursion: entry.preventRecursion ?? false,
|
||||
scan_depth: entry.scanDepth ?? null,
|
||||
match_whole_words: entry.matchWholeWords ?? null,
|
||||
case_sensitive: entry.caseSensitive ?? null,
|
||||
automation_id: entry.automationId ?? '',
|
||||
},
|
||||
};
|
||||
|
||||
@@ -791,6 +800,17 @@ function getPngName(file) {
|
||||
return file;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the preserved name for the uploaded file if the request is valid.
|
||||
* @param {import("express").Request} request - Express request object
|
||||
* @returns {string | undefined} - The preserved name if the request is valid, otherwise undefined
|
||||
*/
|
||||
function getPreservedName(request) {
|
||||
return request.body.file_type === 'png' && request.body.preserve_file_name === 'true' && request.file?.originalname
|
||||
? path.parse(request.file.originalname).name
|
||||
: undefined;
|
||||
}
|
||||
|
||||
router.post('/import', urlencodedParser, async function (request, response) {
|
||||
if (!request.body || !request.file) return response.sendStatus(400);
|
||||
|
||||
@@ -798,6 +818,7 @@ router.post('/import', urlencodedParser, async function (request, response) {
|
||||
let filedata = request.file;
|
||||
let uploadPath = path.join(UPLOADS_PATH, filedata.filename);
|
||||
let format = request.body.file_type;
|
||||
const preservedFileName = getPreservedName(request);
|
||||
|
||||
if (format == 'yaml' || format == 'yml') {
|
||||
try {
|
||||
@@ -889,7 +910,7 @@ router.post('/import', urlencodedParser, async function (request, response) {
|
||||
let jsonData = JSON.parse(img_data);
|
||||
|
||||
jsonData.name = sanitize(jsonData.data?.name || jsonData.name);
|
||||
png_name = getPngName(jsonData.name);
|
||||
png_name = preservedFileName || getPngName(jsonData.name);
|
||||
|
||||
if (jsonData.spec !== undefined) {
|
||||
console.log('Found a v2 character file.');
|
||||
@@ -1003,7 +1024,7 @@ router.post('/export', jsonParser, async function (request, response) {
|
||||
let json = await charaRead(filename);
|
||||
if (json === undefined) return response.sendStatus(400);
|
||||
let jsonObject = getCharaCardV2(JSON.parse(json));
|
||||
return response.type('json').send(jsonObject);
|
||||
return response.type('json').send(JSON.stringify(jsonObject, null, 4));
|
||||
}
|
||||
catch {
|
||||
return response.sendStatus(400);
|
||||
|
@@ -10,6 +10,7 @@ const contentLogPath = path.join(contentDirectory, 'content.log');
|
||||
const contentIndexPath = path.join(contentDirectory, 'index.json');
|
||||
const { DIRECTORIES } = require('../constants');
|
||||
const presetFolders = [DIRECTORIES.koboldAI_Settings, DIRECTORIES.openAI_Settings, DIRECTORIES.novelAI_Settings, DIRECTORIES.textGen_Settings];
|
||||
const characterCardParser = require('../character-card-parser.js');
|
||||
|
||||
/**
|
||||
* Gets the default presets from the content directory.
|
||||
@@ -219,6 +220,56 @@ async function downloadChubCharacter(id) {
|
||||
return { buffer, fileName, fileType };
|
||||
}
|
||||
|
||||
/**
|
||||
* Downloads a character card from the Pygsite.
|
||||
* @param {string} id UUID of the character
|
||||
* @returns {Promise<{buffer: Buffer, fileName: string, fileType: string}>}
|
||||
*/
|
||||
async function downloadPygmalionCharacter(id) {
|
||||
const result = await fetch(`https://server.pygmalion.chat/api/export/character/${id}/v2`);
|
||||
|
||||
if (!result.ok) {
|
||||
const text = await result.text();
|
||||
console.log('Pygsite returned error', result.status, text);
|
||||
throw new Error('Failed to download character');
|
||||
}
|
||||
|
||||
const jsonData = await result.json();
|
||||
const characterData = jsonData?.character;
|
||||
|
||||
if (!characterData || typeof characterData !== 'object') {
|
||||
console.error('Pygsite returned invalid character data', jsonData);
|
||||
throw new Error('Failed to download character');
|
||||
}
|
||||
|
||||
try {
|
||||
const avatarUrl = characterData?.data?.avatar;
|
||||
|
||||
if (!avatarUrl) {
|
||||
console.error('Pygsite character does not have an avatar', characterData);
|
||||
throw new Error('Failed to download avatar');
|
||||
}
|
||||
|
||||
const avatarResult = await fetch(avatarUrl);
|
||||
const avatarBuffer = await avatarResult.buffer();
|
||||
|
||||
const cardBuffer = characterCardParser.write(avatarBuffer, JSON.stringify(characterData));
|
||||
|
||||
return {
|
||||
buffer: cardBuffer,
|
||||
fileName: `${sanitize(id)}.png`,
|
||||
fileType: 'image/png',
|
||||
};
|
||||
} catch (e) {
|
||||
console.error('Failed to download avatar, using JSON instead', e);
|
||||
return {
|
||||
buffer: Buffer.from(JSON.stringify(jsonData)),
|
||||
fileName: `${sanitize(id)}.json`,
|
||||
fileType: 'application/json',
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {String} str
|
||||
@@ -294,7 +345,7 @@ async function downloadJannyCharacter(uuid) {
|
||||
* @param {String} url
|
||||
* @returns {String | null } UUID of the character
|
||||
*/
|
||||
function parseJannyUrl(url) {
|
||||
function getUuidFromUrl(url) {
|
||||
// Extract UUID from URL
|
||||
const uuidRegex = /[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}/;
|
||||
const matches = url.match(uuidRegex);
|
||||
@@ -306,7 +357,7 @@ function parseJannyUrl(url) {
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
router.post('/import', jsonParser, async (request, response) => {
|
||||
router.post('/importURL', jsonParser, async (request, response) => {
|
||||
if (!request.body.url) {
|
||||
return response.sendStatus(400);
|
||||
}
|
||||
@@ -317,8 +368,18 @@ router.post('/import', jsonParser, async (request, response) => {
|
||||
let type;
|
||||
|
||||
const isJannnyContent = url.includes('janitorai');
|
||||
if (isJannnyContent) {
|
||||
const uuid = parseJannyUrl(url);
|
||||
const isPygmalionContent = url.includes('pygmalion.chat');
|
||||
|
||||
if (isPygmalionContent) {
|
||||
const uuid = getUuidFromUrl(url);
|
||||
if (!uuid) {
|
||||
return response.sendStatus(404);
|
||||
}
|
||||
|
||||
type = 'character';
|
||||
result = await downloadPygmalionCharacter(uuid);
|
||||
} else if (isJannnyContent) {
|
||||
const uuid = getUuidFromUrl(url);
|
||||
if (!uuid) {
|
||||
return response.sendStatus(404);
|
||||
}
|
||||
@@ -352,6 +413,49 @@ router.post('/import', jsonParser, async (request, response) => {
|
||||
}
|
||||
});
|
||||
|
||||
router.post('/importUUID', jsonParser, async (request, response) => {
|
||||
if (!request.body.url) {
|
||||
return response.sendStatus(400);
|
||||
}
|
||||
|
||||
try {
|
||||
const uuid = request.body.url;
|
||||
let result;
|
||||
|
||||
const isJannny = uuid.includes('_character');
|
||||
const isPygmalion = (!isJannny && uuid.length == 36);
|
||||
const uuidType = uuid.includes('lorebook') ? 'lorebook' : 'character';
|
||||
|
||||
if (isPygmalion) {
|
||||
console.log('Downloading Pygmalion character:', uuid);
|
||||
result = await downloadPygmalionCharacter(uuid);
|
||||
} else if (isJannny) {
|
||||
console.log('Downloading Janitor character:', uuid.split('_')[0]);
|
||||
result = await downloadJannyCharacter(uuid.split('_')[0]);
|
||||
} else {
|
||||
if (uuidType === 'character') {
|
||||
console.log('Downloading chub character:', uuid);
|
||||
result = await downloadChubCharacter(uuid);
|
||||
}
|
||||
else if (uuidType === 'lorebook') {
|
||||
console.log('Downloading chub lorebook:', uuid);
|
||||
result = await downloadChubLorebook(uuid);
|
||||
}
|
||||
else {
|
||||
return response.sendStatus(404);
|
||||
}
|
||||
}
|
||||
|
||||
if (result.fileType) response.set('Content-Type', result.fileType);
|
||||
response.set('Content-Disposition', `attachment; filename="${result.fileName}"`);
|
||||
response.set('X-Custom-Content-Type', uuidType);
|
||||
return response.send(result.buffer);
|
||||
} catch (error) {
|
||||
console.log('Importing custom content failed', error);
|
||||
return response.sendStatus(500);
|
||||
}
|
||||
});
|
||||
|
||||
module.exports = {
|
||||
checkForNewContent,
|
||||
getDefaultPresets,
|
||||
|
@@ -10,7 +10,7 @@ const API_NOVELAI = 'https://api.novelai.net';
|
||||
// Ban bracket generation, plus defaults
|
||||
const badWordsList = [
|
||||
[3], [49356], [1431], [31715], [34387], [20765], [30702], [10691], [49333], [1266],
|
||||
[19438], [43145], [26523], [41471], [2936], [85, 85], [49332], [7286], [1115],
|
||||
[19438], [43145], [26523], [41471], [2936], [85, 85], [49332], [7286], [1115], [24],
|
||||
];
|
||||
|
||||
const hypeBotBadWordsList = [
|
||||
@@ -172,9 +172,17 @@ router.post('/generate', jsonParser, async function (req, res) {
|
||||
'return_full_text': req.body.return_full_text,
|
||||
'prefix': req.body.prefix,
|
||||
'order': req.body.order,
|
||||
'num_logprobs': req.body.num_logprobs,
|
||||
},
|
||||
};
|
||||
|
||||
// Tells the model to stop generation at '>'
|
||||
if ('theme_textadventure' === req.body.prefix &&
|
||||
(true === req.body.model.includes('clio') ||
|
||||
true === req.body.model.includes('kayra'))) {
|
||||
data.parameters.eos_token_id = 49405;
|
||||
}
|
||||
|
||||
console.log(util.inspect(data, { depth: 4 }));
|
||||
|
||||
const args = {
|
||||
@@ -208,7 +216,7 @@ router.post('/generate', jsonParser, async function (req, res) {
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
console.log(data);
|
||||
console.log('NovelAI Output', data?.output);
|
||||
return res.send(data);
|
||||
}
|
||||
} catch (error) {
|
||||
@@ -342,7 +350,9 @@ router.post('/generate-voice', jsonParser, async (request, response) => {
|
||||
});
|
||||
|
||||
if (!result.ok) {
|
||||
return response.sendStatus(result.status);
|
||||
const errorText = await result.text();
|
||||
console.log('NovelAI returned an error.', result.statusText, errorText);
|
||||
return response.sendStatus(500);
|
||||
}
|
||||
|
||||
const chunks = await readAllChunks(result.body);
|
||||
|
@@ -33,6 +33,7 @@ router.post('/caption-image', jsonParser, async (request, response) => {
|
||||
}
|
||||
|
||||
if (request.body.api === 'ooba') {
|
||||
key = readSecret(SECRET_KEYS.OOBA);
|
||||
bodyParams.temperature = 0.1;
|
||||
}
|
||||
|
||||
|
@@ -5,15 +5,21 @@
|
||||
* @param {string} addAssistantPrefill Add Assistant prefill after the assistant postfix.
|
||||
* @param {boolean} withSysPromptSupport Indicates if the Claude model supports the system prompt format.
|
||||
* @param {boolean} useSystemPrompt Indicates if the system prompt format should be used.
|
||||
* @param {boolean} excludePrefixes Exlude Human/Assistant prefixes.
|
||||
* @param {string} addSysHumanMsg Add Human message between system prompt and assistant.
|
||||
* @returns {string} Prompt for Claude
|
||||
* @copyright Prompt Conversion script taken from RisuAI by kwaroran (GPLv3).
|
||||
*/
|
||||
function convertClaudePrompt(messages, addAssistantPostfix, addAssistantPrefill, withSysPromptSupport, useSystemPrompt, addSysHumanMsg) {
|
||||
function convertClaudePrompt(messages, addAssistantPostfix, addAssistantPrefill, withSysPromptSupport, useSystemPrompt, addSysHumanMsg, excludePrefixes) {
|
||||
|
||||
//Prepare messages for claude.
|
||||
//When 'Exclude Human/Assistant prefixes' checked, setting messages role to the 'system'(last message is exception).
|
||||
if (messages.length > 0) {
|
||||
messages[0].role = 'system';
|
||||
if (excludePrefixes) {
|
||||
messages.slice(0, -1).forEach(message => message.role = 'system');
|
||||
} else {
|
||||
messages[0].role = 'system';
|
||||
}
|
||||
//Add the assistant's message to the end of messages.
|
||||
if (addAssistantPostfix) {
|
||||
messages.push({
|
||||
@@ -29,7 +35,7 @@ function convertClaudePrompt(messages, addAssistantPostfix, addAssistantPrefill,
|
||||
}
|
||||
return message.role === 'assistant' && i > 0;
|
||||
});
|
||||
// When 2.1+ and 'Use system prompt" checked, switches to the system prompt format by setting the first message's role to the 'system'.
|
||||
// When 2.1+ and 'Use system prompt' checked, switches to the system prompt format by setting the first message's role to the 'system'.
|
||||
// Inserts the human's message before the first the assistant one, if there are no such message or prefix found.
|
||||
if (withSysPromptSupport && useSystemPrompt) {
|
||||
messages[0].role = 'system';
|
||||
@@ -43,7 +49,7 @@ function convertClaudePrompt(messages, addAssistantPostfix, addAssistantPrefill,
|
||||
// Otherwise, use the default message format by setting the first message's role to 'user'(compatible with all claude models including 2.1.)
|
||||
messages[0].role = 'user';
|
||||
// Fix messages order for default message format when(messages > Context Size) by merging two messages with "\n\nHuman: " prefixes into one, before the first Assistant's message.
|
||||
if (firstAssistantIndex > 0) {
|
||||
if (firstAssistantIndex > 0 && !excludePrefixes) {
|
||||
messages[firstAssistantIndex - 1].role = firstAssistantIndex - 1 !== 0 && messages[firstAssistantIndex - 1].role === 'user' ? 'FixHumMsg' : messages[firstAssistantIndex - 1].role;
|
||||
}
|
||||
}
|
||||
@@ -51,11 +57,11 @@ function convertClaudePrompt(messages, addAssistantPostfix, addAssistantPrefill,
|
||||
|
||||
// Convert messages to the prompt.
|
||||
let requestPrompt = messages.map((v, i) => {
|
||||
// Set prefix according to the role.
|
||||
// Set prefix according to the role. Also, when "Exclude Human/Assistant prefixes" is checked, names are added via the system prefix.
|
||||
let prefix = {
|
||||
'assistant': '\n\nAssistant: ',
|
||||
'user': '\n\nHuman: ',
|
||||
'system': i === 0 ? '' : v.name === 'example_assistant' ? '\n\nA: ' : v.name === 'example_user' ? '\n\nH: ' : '\n\n',
|
||||
'system': i === 0 ? '' : v.name === 'example_assistant' ? '\n\nA: ' : v.name === 'example_user' ? '\n\nH: ' : excludePrefixes && v.name ? `\n\n${v.name}: ` : '\n\n',
|
||||
'FixHumMsg': '\n\nFirst message: ',
|
||||
}[v.role] ?? '';
|
||||
// Claude doesn't support message names, so we'll just add them to the message content.
|
||||
|
@@ -17,6 +17,7 @@ const SECRET_KEYS = {
|
||||
DEEPL: 'deepl',
|
||||
LIBRE: 'libre',
|
||||
LIBRE_URL: 'libre_url',
|
||||
LINGVA_URL: 'lingva_url',
|
||||
OPENROUTER: 'api_key_openrouter',
|
||||
SCALE: 'api_key_scale',
|
||||
AI21: 'api_key_ai21',
|
||||
@@ -28,8 +29,18 @@ const SECRET_KEYS = {
|
||||
TOGETHERAI: 'api_key_togetherai',
|
||||
MISTRALAI: 'api_key_mistralai',
|
||||
CUSTOM: 'api_key_custom',
|
||||
OOBA: 'api_key_ooba',
|
||||
INFERMATICAI: 'api_key_infermaticai',
|
||||
};
|
||||
|
||||
// These are the keys that are safe to expose, even if allowKeysExposure is false
|
||||
const EXPORTABLE_KEYS = [
|
||||
SECRET_KEYS.LIBRE_URL,
|
||||
SECRET_KEYS.LINGVA_URL,
|
||||
SECRET_KEYS.ONERING_URL,
|
||||
SECRET_KEYS.DEEPLX_URL,
|
||||
];
|
||||
|
||||
/**
|
||||
* Writes a secret to the secrets file
|
||||
* @param {string} key Secret key
|
||||
@@ -210,14 +221,13 @@ router.post('/view', jsonParser, async (_, response) => {
|
||||
|
||||
router.post('/find', jsonParser, (request, response) => {
|
||||
const allowKeysExposure = getConfigValue('allowKeysExposure', false);
|
||||
const key = request.body.key;
|
||||
|
||||
if (!allowKeysExposure) {
|
||||
if (!allowKeysExposure && !EXPORTABLE_KEYS.includes(key)) {
|
||||
console.error('Cannot fetch secrets unless allowKeysExposure in config.yaml is set to true');
|
||||
return response.sendStatus(403);
|
||||
}
|
||||
|
||||
const key = request.body.key;
|
||||
|
||||
try {
|
||||
const secret = readSecret(key);
|
||||
|
||||
|
82
src/endpoints/speech.js
Normal file
82
src/endpoints/speech.js
Normal file
@@ -0,0 +1,82 @@
|
||||
const express = require('express');
|
||||
const { jsonParser } = require('../express-common');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
/**
|
||||
* Gets the audio data from a base64-encoded audio file.
|
||||
* @param {string} audio Base64-encoded audio
|
||||
* @returns {Float64Array} Audio data
|
||||
*/
|
||||
function getWaveFile(audio) {
|
||||
const wavefile = require('wavefile');
|
||||
const wav = new wavefile.WaveFile();
|
||||
wav.fromDataURI(audio);
|
||||
wav.toBitDepth('32f');
|
||||
wav.toSampleRate(16000);
|
||||
let audioData = wav.getSamples();
|
||||
if (Array.isArray(audioData)) {
|
||||
if (audioData.length > 1) {
|
||||
const SCALING_FACTOR = Math.sqrt(2);
|
||||
|
||||
// Merge channels (into first channel to save memory)
|
||||
for (let i = 0; i < audioData[0].length; ++i) {
|
||||
audioData[0][i] = SCALING_FACTOR * (audioData[0][i] + audioData[1][i]) / 2;
|
||||
}
|
||||
}
|
||||
|
||||
// Select first channel
|
||||
audioData = audioData[0];
|
||||
}
|
||||
|
||||
return audioData;
|
||||
}
|
||||
|
||||
router.post('/recognize', jsonParser, async (req, res) => {
|
||||
try {
|
||||
const TASK = 'automatic-speech-recognition';
|
||||
const { model, audio, lang } = req.body;
|
||||
const module = await import('../transformers.mjs');
|
||||
const pipe = await module.default.getPipeline(TASK, model);
|
||||
const wav = getWaveFile(audio);
|
||||
const start = performance.now();
|
||||
const result = await pipe(wav, { language: lang || null });
|
||||
const end = performance.now();
|
||||
console.log(`Execution duration: ${(end - start) / 1000} seconds`);
|
||||
console.log('Transcribed audio:', result.text);
|
||||
|
||||
return res.json({ text: result.text });
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
return res.sendStatus(500);
|
||||
}
|
||||
});
|
||||
|
||||
router.post('/synthesize', jsonParser, async (req, res) => {
|
||||
try {
|
||||
const wavefile = require('wavefile');
|
||||
const TASK = 'text-to-speech';
|
||||
const { text, model, speaker } = req.body;
|
||||
const module = await import('../transformers.mjs');
|
||||
const pipe = await module.default.getPipeline(TASK, model);
|
||||
const speaker_embeddings = speaker
|
||||
? new Float32Array(new Uint8Array(Buffer.from(speaker.startsWith('data:') ? speaker.split(',')[1] : speaker, 'base64')).buffer)
|
||||
: null;
|
||||
const start = performance.now();
|
||||
const result = await pipe(text, { speaker_embeddings: speaker_embeddings });
|
||||
const end = performance.now();
|
||||
console.log(`Execution duration: ${(end - start) / 1000} seconds`);
|
||||
|
||||
const wav = new wavefile.WaveFile();
|
||||
wav.fromScratch(1, result.sampling_rate, '32f', result.audio);
|
||||
const buffer = wav.toBuffer();
|
||||
|
||||
res.set('Content-Type', 'audio/wav');
|
||||
return res.send(Buffer.from(buffer));
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
return res.sendStatus(500);
|
||||
}
|
||||
});
|
||||
|
||||
module.exports = { router };
|
@@ -2,10 +2,8 @@ const fs = require('fs');
|
||||
const path = require('path');
|
||||
const express = require('express');
|
||||
const writeFileAtomic = require('write-file-atomic');
|
||||
const util = require('util');
|
||||
const crypto = require('crypto');
|
||||
|
||||
const writeFile = util.promisify(writeFileAtomic);
|
||||
const readFile = fs.promises.readFile;
|
||||
const readdir = fs.promises.readdir;
|
||||
|
||||
@@ -168,7 +166,7 @@ async function saveStatsToFile() {
|
||||
if (charStats.timestamp > lastSaveTimestamp) {
|
||||
//console.debug("Saving stats to file...");
|
||||
try {
|
||||
await writeFile(statsFilePath, JSON.stringify(charStats));
|
||||
await writeFileAtomic(statsFilePath, JSON.stringify(charStats));
|
||||
lastSaveTimestamp = Date.now();
|
||||
} catch (error) {
|
||||
console.log('Failed to save stats to file.', error);
|
||||
|
@@ -298,11 +298,13 @@ function createSentencepieceDecodingHandler(tokenizer) {
|
||||
|
||||
const ids = request.body.ids || [];
|
||||
const instance = await tokenizer?.get();
|
||||
const text = await instance?.decodeIds(ids);
|
||||
return response.send({ text });
|
||||
const ops = ids.map(id => instance.decodeIds([id]));
|
||||
const chunks = await Promise.all(ops);
|
||||
const text = chunks.join('');
|
||||
return response.send({ text, chunks });
|
||||
} catch (error) {
|
||||
console.log(error);
|
||||
return response.send({ text: '' });
|
||||
return response.send({ text: '', chunks: [] });
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -626,6 +628,10 @@ router.post('/remote/textgenerationwebui/encode', jsonParser, async function (re
|
||||
url += '/tokenize';
|
||||
args.body = JSON.stringify({ 'content': text });
|
||||
break;
|
||||
case TEXTGEN_TYPES.APHRODITE:
|
||||
url += '/v1/tokenize';
|
||||
args.body = JSON.stringify({ 'prompt': text });
|
||||
break;
|
||||
default:
|
||||
url += '/v1/internal/encode';
|
||||
args.body = JSON.stringify({ 'text': text });
|
||||
|
@@ -19,6 +19,10 @@ router.post('/libre', jsonParser, async (request, response) => {
|
||||
return response.sendStatus(400);
|
||||
}
|
||||
|
||||
if (request.body.lang === 'zh-CN' || request.body.lang === 'zh-TW') {
|
||||
request.body.lang = 'zh';
|
||||
}
|
||||
|
||||
const text = request.body.text;
|
||||
const lang = request.body.lang;
|
||||
|
||||
@@ -98,6 +102,52 @@ router.post('/google', jsonParser, async (request, response) => {
|
||||
}
|
||||
});
|
||||
|
||||
router.post('/lingva', jsonParser, async (request, response) => {
|
||||
try {
|
||||
const baseUrl = readSecret(SECRET_KEYS.LINGVA_URL);
|
||||
|
||||
if (!baseUrl) {
|
||||
console.log('Lingva URL is not configured.');
|
||||
return response.sendStatus(400);
|
||||
}
|
||||
|
||||
const text = request.body.text;
|
||||
const lang = request.body.lang;
|
||||
|
||||
if (!text || !lang) {
|
||||
return response.sendStatus(400);
|
||||
}
|
||||
|
||||
console.log('Input text: ' + text);
|
||||
const url = `${baseUrl}/auto/${lang}/${encodeURIComponent(text)}`;
|
||||
|
||||
https.get(url, (resp) => {
|
||||
let data = '';
|
||||
|
||||
resp.on('data', (chunk) => {
|
||||
data += chunk;
|
||||
});
|
||||
|
||||
resp.on('end', () => {
|
||||
try {
|
||||
const result = JSON.parse(data);
|
||||
console.log('Translated text: ' + result.translation);
|
||||
return response.send(result.translation);
|
||||
} catch (error) {
|
||||
console.log('Translation error', error);
|
||||
return response.sendStatus(500);
|
||||
}
|
||||
});
|
||||
}).on('error', (err) => {
|
||||
console.log('Translation error: ' + err.message);
|
||||
return response.sendStatus(500);
|
||||
});
|
||||
} catch (error) {
|
||||
console.log('Translation error', error);
|
||||
return response.sendStatus(500);
|
||||
}
|
||||
});
|
||||
|
||||
router.post('/deepl', jsonParser, async (request, response) => {
|
||||
const key = readSecret(SECRET_KEYS.DEEPL);
|
||||
|
||||
|
@@ -4,19 +4,26 @@ const express = require('express');
|
||||
const sanitize = require('sanitize-filename');
|
||||
const { jsonParser } = require('../express-common');
|
||||
|
||||
// Don't forget to add new sources to the SOURCES array
|
||||
const SOURCES = ['transformers', 'mistral', 'openai', 'extras', 'palm', 'togetherai'];
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text from the given source.
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
||||
* @param {string} text - The text to get the vector for
|
||||
* @returns {Promise<number[]>} - The vector for the text
|
||||
*/
|
||||
async function getVector(source, text) {
|
||||
async function getVector(source, sourceSettings, text) {
|
||||
switch (source) {
|
||||
case 'togetherai':
|
||||
case 'mistral':
|
||||
case 'openai':
|
||||
return require('../openai-vectors').getOpenAIVector(text, source);
|
||||
return require('../openai-vectors').getOpenAIVector(text, source, sourceSettings.model);
|
||||
case 'transformers':
|
||||
return require('../embedding').getTransformersVector(text);
|
||||
case 'extras':
|
||||
return require('../extras-vectors').getExtrasVector(text, sourceSettings.extrasUrl, sourceSettings.extrasKey);
|
||||
case 'palm':
|
||||
return require('../makersuite-vectors').getMakerSuiteVector(text);
|
||||
}
|
||||
@@ -24,6 +31,42 @@ async function getVector(source, text) {
|
||||
throw new Error(`Unknown vector source ${source}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text batch from the given source.
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
||||
* @param {string[]} texts - The array of texts to get the vector for
|
||||
* @returns {Promise<number[][]>} - The array of vectors for the texts
|
||||
*/
|
||||
async function getBatchVector(source, sourceSettings, texts) {
|
||||
const batchSize = 10;
|
||||
const batches = Array(Math.ceil(texts.length / batchSize)).fill(undefined).map((_, i) => texts.slice(i * batchSize, i * batchSize + batchSize));
|
||||
|
||||
let results = [];
|
||||
for (let batch of batches) {
|
||||
switch (source) {
|
||||
case 'togetherai':
|
||||
case 'mistral':
|
||||
case 'openai':
|
||||
results.push(...await require('../openai-vectors').getOpenAIBatchVector(batch, source, sourceSettings.model));
|
||||
break;
|
||||
case 'transformers':
|
||||
results.push(...await require('../embedding').getTransformersBatchVector(batch));
|
||||
break;
|
||||
case 'extras':
|
||||
results.push(...await require('../extras-vectors').getExtrasBatchVector(batch, sourceSettings.extrasUrl, sourceSettings.extrasKey));
|
||||
break;
|
||||
case 'palm':
|
||||
results.push(...await require('../makersuite-vectors').getMakerSuiteBatchVector(batch));
|
||||
break;
|
||||
default:
|
||||
throw new Error(`Unknown vector source ${source}`);
|
||||
}
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the index for the vector collection
|
||||
* @param {string} collectionId - The collection ID
|
||||
@@ -45,19 +88,20 @@ async function getIndex(collectionId, source, create = true) {
|
||||
* Inserts items into the vector collection
|
||||
* @param {string} collectionId - The collection ID
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
||||
* @param {{ hash: number; text: string; index: number; }[]} items - The items to insert
|
||||
*/
|
||||
async function insertVectorItems(collectionId, source, items) {
|
||||
async function insertVectorItems(collectionId, source, sourceSettings, items) {
|
||||
const store = await getIndex(collectionId, source);
|
||||
|
||||
await store.beginUpdate();
|
||||
|
||||
for (const item of items) {
|
||||
const text = item.text;
|
||||
const hash = item.hash;
|
||||
const index = item.index;
|
||||
const vector = await getVector(source, text);
|
||||
await store.upsertItem({ vector: vector, metadata: { hash, text, index } });
|
||||
const vectors = await getBatchVector(source, sourceSettings, items.map(x => x.text));
|
||||
|
||||
for (let i = 0; i < items.length; i++) {
|
||||
const item = items[i];
|
||||
const vector = vectors[i];
|
||||
await store.upsertItem({ vector: vector, metadata: { hash: item.hash, text: item.text, index: item.index } });
|
||||
}
|
||||
|
||||
await store.endUpdate();
|
||||
@@ -101,13 +145,14 @@ async function deleteVectorItems(collectionId, source, hashes) {
|
||||
* Gets the hashes of the items in the vector collection that match the search text
|
||||
* @param {string} collectionId - The collection ID
|
||||
* @param {string} source - The source of the vector
|
||||
* @param {Object} sourceSettings - Settings for the source, if it needs any
|
||||
* @param {string} searchText - The text to search for
|
||||
* @param {number} topK - The number of results to return
|
||||
* @returns {Promise<{hashes: number[], metadata: object[]}>} - The metadata of the items that match the search text
|
||||
*/
|
||||
async function queryCollection(collectionId, source, searchText, topK) {
|
||||
async function queryCollection(collectionId, source, sourceSettings, searchText, topK) {
|
||||
const store = await getIndex(collectionId, source);
|
||||
const vector = await getVector(source, searchText);
|
||||
const vector = await getVector(source, sourceSettings, searchText);
|
||||
|
||||
const result = await store.queryItems(vector, topK);
|
||||
const metadata = result.map(x => x.item.metadata);
|
||||
@@ -115,6 +160,35 @@ async function queryCollection(collectionId, source, searchText, topK) {
|
||||
return { metadata, hashes };
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts settings for the vectorization sources from the HTTP request headers.
|
||||
* @param {string} source - Which source to extract settings for.
|
||||
* @param {object} request - The HTTP request object.
|
||||
* @returns {object} - An object that can be used as `sourceSettings` in functions that take that parameter.
|
||||
*/
|
||||
function getSourceSettings(source, request) {
|
||||
if (source === 'togetherai') {
|
||||
let model = String(request.headers['x-togetherai-model']);
|
||||
|
||||
return {
|
||||
model: model,
|
||||
};
|
||||
} else {
|
||||
// Extras API settings to connect to the Extras embeddings provider
|
||||
let extrasUrl = '';
|
||||
let extrasKey = '';
|
||||
if (source === 'extras') {
|
||||
extrasUrl = String(request.headers['x-extras-url']);
|
||||
extrasKey = String(request.headers['x-extras-key']);
|
||||
}
|
||||
|
||||
return {
|
||||
extrasUrl: extrasUrl,
|
||||
extrasKey: extrasKey,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
router.post('/query', jsonParser, async (req, res) => {
|
||||
@@ -127,8 +201,9 @@ router.post('/query', jsonParser, async (req, res) => {
|
||||
const searchText = String(req.body.searchText);
|
||||
const topK = Number(req.body.topK) || 10;
|
||||
const source = String(req.body.source) || 'transformers';
|
||||
const sourceSettings = getSourceSettings(source, req);
|
||||
|
||||
const results = await queryCollection(collectionId, source, searchText, topK);
|
||||
const results = await queryCollection(collectionId, source, sourceSettings, searchText, topK);
|
||||
return res.json(results);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
@@ -145,8 +220,9 @@ router.post('/insert', jsonParser, async (req, res) => {
|
||||
const collectionId = String(req.body.collectionId);
|
||||
const items = req.body.items.map(x => ({ hash: x.hash, text: x.text, index: x.index }));
|
||||
const source = String(req.body.source) || 'transformers';
|
||||
const sourceSettings = getSourceSettings(source, req);
|
||||
|
||||
await insertVectorItems(collectionId, source, items);
|
||||
await insertVectorItems(collectionId, source, sourceSettings, items);
|
||||
return res.sendStatus(200);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
@@ -197,8 +273,7 @@ router.post('/purge', jsonParser, async (req, res) => {
|
||||
|
||||
const collectionId = String(req.body.collectionId);
|
||||
|
||||
const sources = ['transformers', 'openai', 'palm'];
|
||||
for (const source of sources) {
|
||||
for (const source of SOURCES) {
|
||||
const index = await getIndex(collectionId, source, false);
|
||||
|
||||
const exists = await index.isIndexCreated();
|
||||
|
@@ -7,8 +7,14 @@ const writeFileAtomicSync = require('write-file-atomic').sync;
|
||||
const { jsonParser, urlencodedParser } = require('../express-common');
|
||||
const { DIRECTORIES, UPLOADS_PATH } = require('../constants');
|
||||
|
||||
function readWorldInfoFile(worldInfoName) {
|
||||
const dummyObject = { entries: {} };
|
||||
/**
|
||||
* Reads a World Info file and returns its contents
|
||||
* @param {string} worldInfoName Name of the World Info file
|
||||
* @param {boolean} allowDummy If true, returns an empty object if the file doesn't exist
|
||||
* @returns {object} World Info file contents
|
||||
*/
|
||||
function readWorldInfoFile(worldInfoName, allowDummy) {
|
||||
const dummyObject = allowDummy ? { entries: {} } : null;
|
||||
|
||||
if (!worldInfoName) {
|
||||
return dummyObject;
|
||||
@@ -34,7 +40,7 @@ router.post('/get', jsonParser, (request, response) => {
|
||||
return response.sendStatus(400);
|
||||
}
|
||||
|
||||
const file = readWorldInfoFile(request.body.name);
|
||||
const file = readWorldInfoFile(request.body.name, true);
|
||||
|
||||
return response.send(file);
|
||||
});
|
||||
|
78
src/extras-vectors.js
Normal file
78
src/extras-vectors.js
Normal file
@@ -0,0 +1,78 @@
|
||||
const fetch = require('node-fetch').default;
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text from SillyTavern-extras
|
||||
* @param {string[]} texts - The array of texts to get the vectors for
|
||||
* @param {string} apiUrl - The Extras API URL
|
||||
* @param {string} apiKey - The Extras API key, or empty string if API key not enabled
|
||||
* @returns {Promise<number[][]>} - The array of vectors for the texts
|
||||
*/
|
||||
async function getExtrasBatchVector(texts, apiUrl, apiKey) {
|
||||
return getExtrasVectorImpl(texts, apiUrl, apiKey);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text from SillyTavern-extras
|
||||
* @param {string} text - The text to get the vector for
|
||||
* @param {string} apiUrl - The Extras API URL
|
||||
* @param {string} apiKey - The Extras API key, or empty string if API key not enabled
|
||||
* @returns {Promise<number[]>} - The vector for the text
|
||||
*/
|
||||
async function getExtrasVector(text, apiUrl, apiKey) {
|
||||
return getExtrasVectorImpl(text, apiUrl, apiKey);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text from SillyTavern-extras
|
||||
* @param {string|string[]} text - The text or texts to get the vector(s) for
|
||||
* @param {string} apiUrl - The Extras API URL
|
||||
* @param {string} apiKey - The Extras API key, or empty string if API key not enabled *
|
||||
* @returns {Promise<Array>} - The vector for a single text if input is string, or the array of vectors for multiple texts if input is string[]
|
||||
*/
|
||||
async function getExtrasVectorImpl(text, apiUrl, apiKey) {
|
||||
let url;
|
||||
try {
|
||||
url = new URL(apiUrl);
|
||||
url.pathname = '/api/embeddings/compute';
|
||||
}
|
||||
catch (error) {
|
||||
console.log('Failed to set up Extras API call:', error);
|
||||
console.log('Extras API URL given was:', apiUrl);
|
||||
throw error;
|
||||
}
|
||||
|
||||
const headers = {
|
||||
'Content-Type': 'application/json',
|
||||
};
|
||||
|
||||
// Include the Extras API key, if enabled
|
||||
if (apiKey && apiKey.length > 0) {
|
||||
Object.assign(headers, {
|
||||
'Authorization': `Bearer ${apiKey}`,
|
||||
});
|
||||
}
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: 'POST',
|
||||
headers: headers,
|
||||
body: JSON.stringify({
|
||||
text: text, // The backend accepts {string|string[]} for one or multiple text items, respectively.
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const text = await response.text();
|
||||
console.log('Extras request failed', response.statusText, text);
|
||||
throw new Error('Extras request failed');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
const vector = data.embedding; // `embedding`: number[] (one text item), or number[][] (multiple text items).
|
||||
|
||||
return vector;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getExtrasVector,
|
||||
getExtrasBatchVector,
|
||||
};
|
@@ -1,6 +1,17 @@
|
||||
const fetch = require('node-fetch').default;
|
||||
const { SECRET_KEYS, readSecret } = require('./endpoints/secrets');
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text from gecko model
|
||||
* @param {string[]} texts - The array of texts to get the vector for
|
||||
* @returns {Promise<number[][]>} - The array of vectors for the texts
|
||||
*/
|
||||
async function getMakerSuiteBatchVector(texts) {
|
||||
const promises = texts.map(text => getMakerSuiteVector(text));
|
||||
const vectors = await Promise.all(promises);
|
||||
return vectors;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text from PaLM gecko model
|
||||
* @param {string} text - The text to get the vector for
|
||||
@@ -40,4 +51,5 @@ async function getMakerSuiteVector(text) {
|
||||
|
||||
module.exports = {
|
||||
getMakerSuiteVector,
|
||||
getMakerSuiteBatchVector,
|
||||
};
|
||||
|
@@ -2,8 +2,13 @@ const fetch = require('node-fetch').default;
|
||||
const { SECRET_KEYS, readSecret } = require('./endpoints/secrets');
|
||||
|
||||
const SOURCES = {
|
||||
'togetherai': {
|
||||
secretKey: SECRET_KEYS.TOGETHERAI,
|
||||
url: 'api.together.xyz',
|
||||
model: 'togethercomputer/m2-bert-80M-32k-retrieval',
|
||||
},
|
||||
'mistral': {
|
||||
secretKey: SECRET_KEYS.MISTRAL,
|
||||
secretKey: SECRET_KEYS.MISTRALAI,
|
||||
url: 'api.mistral.ai',
|
||||
model: 'mistral-embed',
|
||||
},
|
||||
@@ -15,12 +20,13 @@ const SOURCES = {
|
||||
};
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text from an OpenAI compatible endpoint.
|
||||
* @param {string} text - The text to get the vector for
|
||||
* Gets the vector for the given text batch from an OpenAI compatible endpoint.
|
||||
* @param {string[]} texts - The array of texts to get the vector for
|
||||
* @param {string} source - The source of the vector
|
||||
* @returns {Promise<number[]>} - The vector for the text
|
||||
* @param {string} model - The model to use for the embedding
|
||||
* @returns {Promise<number[][]>} - The array of vectors for the texts
|
||||
*/
|
||||
async function getOpenAIVector(text, source) {
|
||||
async function getOpenAIBatchVector(texts, source, model = '') {
|
||||
const config = SOURCES[source];
|
||||
|
||||
if (!config) {
|
||||
@@ -43,8 +49,8 @@ async function getOpenAIVector(text, source) {
|
||||
Authorization: `Bearer ${key}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
input: text,
|
||||
model: config.model,
|
||||
input: texts,
|
||||
model: model || config.model,
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -55,16 +61,32 @@ async function getOpenAIVector(text, source) {
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
const vector = data?.data[0]?.embedding;
|
||||
|
||||
if (!Array.isArray(vector)) {
|
||||
if (!Array.isArray(data?.data)) {
|
||||
console.log('API response was not an array');
|
||||
throw new Error('API response was not an array');
|
||||
}
|
||||
|
||||
return vector;
|
||||
// Sort data by x.index to ensure the order is correct
|
||||
data.data.sort((a, b) => a.index - b.index);
|
||||
|
||||
const vectors = data.data.map(x => x.embedding);
|
||||
return vectors;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the vector for the given text from an OpenAI compatible endpoint.
|
||||
* @param {string} text - The text to get the vector for
|
||||
* @param {string} source - The source of the vector
|
||||
* @param model
|
||||
* @returns {Promise<number[]>} - The vector for the text
|
||||
*/
|
||||
async function getOpenAIVector(text, source, model = '') {
|
||||
const vectors = await getOpenAIBatchVector([text], source, model);
|
||||
return vectors[0];
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getOpenAIVector,
|
||||
getOpenAIBatchVector,
|
||||
};
|
||||
|
@@ -17,21 +17,37 @@ const tasks = {
|
||||
defaultModel: 'Cohee/distilbert-base-uncased-go-emotions-onnx',
|
||||
pipeline: null,
|
||||
configField: 'extras.classificationModel',
|
||||
quantized: true,
|
||||
},
|
||||
'image-to-text': {
|
||||
defaultModel: 'Xenova/vit-gpt2-image-captioning',
|
||||
pipeline: null,
|
||||
configField: 'extras.captioningModel',
|
||||
quantized: true,
|
||||
},
|
||||
'feature-extraction': {
|
||||
defaultModel: 'Xenova/all-mpnet-base-v2',
|
||||
pipeline: null,
|
||||
configField: 'extras.embeddingModel',
|
||||
quantized: true,
|
||||
},
|
||||
'text-generation': {
|
||||
defaultModel: 'Cohee/fooocus_expansion-onnx',
|
||||
pipeline: null,
|
||||
configField: 'extras.promptExpansionModel',
|
||||
quantized: true,
|
||||
},
|
||||
'automatic-speech-recognition': {
|
||||
defaultModel: 'Xenova/whisper-small',
|
||||
pipeline: null,
|
||||
configField: 'extras.speechToTextModel',
|
||||
quantized: true,
|
||||
},
|
||||
'text-to-speech': {
|
||||
defaultModel: 'Xenova/speecht5_tts',
|
||||
pipeline: null,
|
||||
configField: 'extras.textToSpeechModel',
|
||||
quantized: false,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -72,19 +88,20 @@ function getModelForTask(task) {
|
||||
|
||||
/**
|
||||
* Gets the transformers.js pipeline for a given task.
|
||||
* @param {string} task The task to get the pipeline for
|
||||
* @param {import('sillytavern-transformers').PipelineType} task The task to get the pipeline for
|
||||
* @param {string} forceModel The model to use for the pipeline, if any
|
||||
* @returns {Promise<Pipeline>} Pipeline for the task
|
||||
*/
|
||||
async function getPipeline(task) {
|
||||
async function getPipeline(task, forceModel = '') {
|
||||
if (tasks[task].pipeline) {
|
||||
return tasks[task].pipeline;
|
||||
}
|
||||
|
||||
const cache_dir = path.join(process.cwd(), 'cache');
|
||||
const model = getModelForTask(task);
|
||||
const model = forceModel || getModelForTask(task);
|
||||
const localOnly = getConfigValue('extras.disableAutoDownload', false);
|
||||
console.log('Initializing transformers.js pipeline for task', task, 'with model', model);
|
||||
const instance = await pipeline(task, model, { cache_dir, quantized: true, local_files_only: localOnly });
|
||||
const instance = await pipeline(task, model, { cache_dir, quantized: tasks[task].quantized ?? true, local_files_only: localOnly });
|
||||
tasks[task].pipeline = instance;
|
||||
return instance;
|
||||
}
|
||||
|
61
src/util.js
61
src/util.js
@@ -365,7 +365,7 @@ function getImages(path) {
|
||||
/**
|
||||
* Pipe a fetch() response to an Express.js Response, including status code.
|
||||
* @param {import('node-fetch').Response} from The Fetch API response to pipe from.
|
||||
* @param {Express.Response} to The Express response to pipe to.
|
||||
* @param {import('express').Response} to The Express response to pipe to.
|
||||
*/
|
||||
function forwardFetchResponse(from, to) {
|
||||
let statusCode = from.status;
|
||||
@@ -399,6 +399,64 @@ function forwardFetchResponse(from, to) {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Makes an HTTP/2 request to the specified endpoint.
|
||||
*
|
||||
* @deprecated Use `node-fetch` if possible.
|
||||
* @param {string} endpoint URL to make the request to
|
||||
* @param {string} method HTTP method to use
|
||||
* @param {string} body Request body
|
||||
* @param {object} headers Request headers
|
||||
* @returns {Promise<string>} Response body
|
||||
*/
|
||||
function makeHttp2Request(endpoint, method, body, headers) {
|
||||
return new Promise((resolve, reject) => {
|
||||
try {
|
||||
const http2 = require('http2');
|
||||
const url = new URL(endpoint);
|
||||
const client = http2.connect(url.origin);
|
||||
|
||||
const req = client.request({
|
||||
':method': method,
|
||||
':path': url.pathname,
|
||||
...headers,
|
||||
});
|
||||
req.setEncoding('utf8');
|
||||
|
||||
req.on('response', (headers) => {
|
||||
const status = Number(headers[':status']);
|
||||
|
||||
if (status < 200 || status >= 300) {
|
||||
reject(new Error(`Request failed with status ${status}`));
|
||||
}
|
||||
|
||||
let data = '';
|
||||
|
||||
req.on('data', (chunk) => {
|
||||
data += chunk;
|
||||
});
|
||||
|
||||
req.on('end', () => {
|
||||
console.log(data);
|
||||
resolve(data);
|
||||
});
|
||||
});
|
||||
|
||||
req.on('error', (err) => {
|
||||
reject(err);
|
||||
});
|
||||
|
||||
if (body) {
|
||||
req.write(body);
|
||||
}
|
||||
|
||||
req.end();
|
||||
} catch (e) {
|
||||
reject(e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds YAML-serialized object to the object.
|
||||
* @param {object} obj Object
|
||||
@@ -547,4 +605,5 @@ module.exports = {
|
||||
excludeKeysByYaml,
|
||||
trimV1,
|
||||
Cache,
|
||||
makeHttp2Request,
|
||||
};
|
||||
|
Reference in New Issue
Block a user