Merge branch 'staging' into plugin-router

This commit is contained in:
Cohee
2023-12-23 18:39:18 +02:00
59 changed files with 4416 additions and 1573 deletions

View File

@@ -11,6 +11,14 @@ function getMancerHeaders() {
}) : {};
}
/**
 * Builds the authorization headers for TogetherAI requests.
 * @returns {object} A Bearer `Authorization` header, or an empty object when
 * readSecret yields a falsy value for the TogetherAI key
 */
function getTogetherAIHeaders() {
    const token = readSecret(SECRET_KEYS.TOGETHERAI);

    if (!token) {
        return {};
    }

    return { 'Authorization': `Bearer ${token}` };
}
function getAphroditeHeaders() {
const apiKey = readSecret(SECRET_KEYS.APHRODITE);
@@ -58,6 +66,9 @@ function setAdditionalHeaders(request, args, server) {
case TEXTGEN_TYPES.TABBY:
headers = getTabbyHeaders();
break;
case TEXTGEN_TYPES.TOGETHERAI:
headers = getTogetherAIHeaders();
break;
default:
headers = server ? getOverrideHeaders((new URL(server))?.host) : {};
break;

View File

@@ -160,6 +160,7 @@ const CHAT_COMPLETION_SOURCES = {
AI21: 'ai21',
MAKERSUITE: 'makersuite',
MISTRALAI: 'mistralai',
CUSTOM: 'custom',
};
const UPLOADS_PATH = './uploads';
@@ -171,8 +172,42 @@ const TEXTGEN_TYPES = {
APHRODITE: 'aphrodite',
TABBY: 'tabby',
KOBOLDCPP: 'koboldcpp',
TOGETHERAI: 'togetherai',
LLAMACPP: 'llamacpp',
OLLAMA: 'ollama',
};
// https://docs.together.ai/reference/completions
// Allow-list of request body fields for TogetherAI: the text-completions
// '/generate' route keeps only these keys of the incoming body (the first
// stop string is re-added separately as `stop`) before forwarding it.
const TOGETHERAI_KEYS = [
'model',
'prompt',
'max_tokens',
'temperature',
'top_p',
'top_k',
'repetition_penalty',
'stream',
];
// https://github.com/jmorganca/ollama/blob/main/docs/api.md#request-with-options
// Request body fields that the text-completions '/generate' route copies into
// the `options` object of an Ollama generate request; `model`, `prompt` and
// `stream` are sent at the top level instead.
const OLLAMA_KEYS = [
'num_predict',
'stop',
'temperature',
'repeat_penalty',
'presence_penalty',
'frequency_penalty',
'top_k',
'top_p',
'tfs_z',
'typical_p',
'seed',
'repeat_last_n',
'mirostat',
'mirostat_tau',
'mirostat_eta',
];
const AVATAR_WIDTH = 400;
const AVATAR_HEIGHT = 600;
@@ -186,4 +221,6 @@ module.exports = {
CHAT_COMPLETION_SOURCES,
AVATAR_WIDTH,
AVATAR_HEIGHT,
TOGETHERAI_KEYS,
OLLAMA_KEYS,
};

View File

@@ -4,7 +4,7 @@ const { Readable } = require('stream');
const { jsonParser } = require('../../express-common');
const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY } = require('../../constants');
const { forwardFetchResponse, getConfigValue, tryParse, uuidv4 } = require('../../util');
const { forwardFetchResponse, getConfigValue, tryParse, uuidv4, mergeObjectWithYaml, excludeKeysByYaml, color } = require('../../util');
const { convertClaudePrompt, convertGooglePrompt, convertTextCompletionPrompt } = require('../prompt-converters');
const { readSecret, SECRET_KEYS } = require('../secrets');
@@ -21,9 +21,10 @@ const API_CLAUDE = 'https://api.anthropic.com/v1';
async function sendClaudeRequest(request, response) {
const apiUrl = new URL(request.body.reverse_proxy || API_CLAUDE).toString();
const apiKey = request.body.reverse_proxy ? request.body.proxy_password : readSecret(SECRET_KEYS.CLAUDE);
const divider = '-'.repeat(process.stdout.columns);
if (!apiKey) {
console.log('Claude API key is missing.');
console.log(color.red(`Claude API key is missing.\n${divider}`));
return response.status(400).send({ error: true });
}
@@ -34,34 +35,66 @@ async function sendClaudeRequest(request, response) {
controller.abort();
});
let doSystemPrompt = request.body.model === 'claude-2' || request.body.model === 'claude-2.1';
let requestPrompt = convertClaudePrompt(request.body.messages, true, !request.body.exclude_assistant, doSystemPrompt);
const isSysPromptSupported = request.body.model === 'claude-2' || request.body.model === 'claude-2.1';
const requestPrompt = convertClaudePrompt(request.body.messages, !request.body.exclude_assistant, request.body.assistant_prefill, isSysPromptSupported, request.body.claude_use_sysprompt, request.body.human_sysprompt_message);
if (request.body.assistant_prefill && !request.body.exclude_assistant) {
requestPrompt += request.body.assistant_prefill;
// Check Claude messages sequence and prefixes presence.
const sequence = requestPrompt.split('\n').filter(x => x.startsWith('Human:') || x.startsWith('Assistant:'));
const humanFound = sequence.some(line => line.startsWith('Human:'));
const assistantFound = sequence.some(line => line.startsWith('Assistant:'));
let humanErrorCount = 0;
let assistantErrorCount = 0;
for (let i = 0; i < sequence.length - 1; i++) {
if (sequence[i].startsWith(sequence[i + 1].split(':')[0])) {
if (sequence[i].startsWith('Human:')) {
humanErrorCount++;
} else if (sequence[i].startsWith('Assistant:')) {
assistantErrorCount++;
}
}
}
console.log('Claude request:', requestPrompt);
const stop_sequences = ['\n\nHuman:', '\n\nSystem:', '\n\nAssistant:'];
if (!humanFound) {
console.log(color.red(`${divider}\nWarning: No 'Human:' prefix found in the prompt.\n${divider}`));
}
if (!assistantFound) {
console.log(color.red(`${divider}\nWarning: No 'Assistant: ' prefix found in the prompt.\n${divider}`));
}
if (!sequence[0].startsWith('Human:')) {
console.log(color.red(`${divider}\nWarning: The messages sequence should start with 'Human:' prefix.\nMake sure you have 'Human:' prefix at the very beggining of the prompt, or after the system prompt.\n${divider}`));
}
if (humanErrorCount > 0 || assistantErrorCount > 0) {
console.log(color.red(`${divider}\nWarning: Detected incorrect Prefix sequence(s).`));
console.log(color.red(`Incorrect "Human:" prefix(es): ${humanErrorCount}.\nIncorrect "Assistant: " prefix(es): ${assistantErrorCount}.`));
console.log(color.red('Check the prompt above and fix it in the SillyTavern.'));
console.log(color.red('\nThe correct sequence should look like this:\nSystem prompt <-(for the sysprompt format only, else have 2 empty lines above the first human\'s message.)'));
console.log(color.red(` <-----(Each message beginning with the "Assistant:/Human:" prefix must have one empty line above.)\nHuman:\n\nAssistant:\n...\n\nHuman:\n\nAssistant:\n${divider}`));
}
// Add custom stop sequences
const stopSequences = ['\n\nHuman:', '\n\nSystem:', '\n\nAssistant:'];
if (Array.isArray(request.body.stop)) {
stop_sequences.push(...request.body.stop);
stopSequences.push(...request.body.stop);
}
const requestBody = {
prompt: requestPrompt,
model: request.body.model,
max_tokens_to_sample: request.body.max_tokens,
stop_sequences: stopSequences,
temperature: request.body.temperature,
top_p: request.body.top_p,
top_k: request.body.top_k,
stream: request.body.stream,
};
console.log('Claude request:', requestBody);
const generateResponse = await fetch(apiUrl + '/complete', {
method: 'POST',
signal: controller.signal,
body: JSON.stringify({
prompt: requestPrompt,
model: request.body.model,
max_tokens_to_sample: request.body.max_tokens,
stop_sequences: stop_sequences,
temperature: request.body.temperature,
top_p: request.body.top_p,
top_k: request.body.top_k,
stream: request.body.stream,
}),
body: JSON.stringify(requestBody),
headers: {
'Content-Type': 'application/json',
'anthropic-version': '2023-06-01',
@@ -75,20 +108,20 @@ async function sendClaudeRequest(request, response) {
forwardFetchResponse(generateResponse, response);
} else {
if (!generateResponse.ok) {
console.log(`Claude API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`);
console.log(color.red(`Claude API returned error: ${generateResponse.status} ${generateResponse.statusText}\n${await generateResponse.text()}\n${divider}`));
return response.status(generateResponse.status).send({ error: true });
}
const generateResponseJson = await generateResponse.json();
const responseText = generateResponseJson.completion;
console.log('Claude response:', responseText);
console.log('Claude response:', generateResponseJson);
// Wrap it back to OAI format
const reply = { choices: [{ 'message': { 'content': responseText } }] };
return response.send(reply);
}
} catch (error) {
console.log('Error communicating with Claude: ', error);
console.log(color.red(`Error communicating with Claude: ${error}\n${divider}`));
if (!response.headersSent) {
return response.status(500).send({ error: true });
}
@@ -410,12 +443,12 @@ async function sendMistralAIRequest(request, response) {
const messages = Array.isArray(request.body.messages) ? request.body.messages : [];
const lastMsg = messages[messages.length - 1];
if (messages.length > 0 && lastMsg && (lastMsg.role === 'system' || lastMsg.role === 'assistant')) {
lastMsg.role = 'user';
if (lastMsg.role === 'assistant') {
lastMsg.content = lastMsg.name + ': ' + lastMsg.content;
} else if (lastMsg.role === 'system') {
lastMsg.content = '[INST] ' + lastMsg.content + ' [/INST]';
}
lastMsg.role = 'user';
}
// System prompts can be stacked at the start, but any further system prompts after the first user/assistant message will break the model
@@ -438,26 +471,30 @@ async function sendMistralAIRequest(request, response) {
controller.abort();
});
const requestBody = {
'model': request.body.model,
'messages': messages,
'temperature': request.body.temperature,
'top_p': request.body.top_p,
'max_tokens': request.body.max_tokens,
'stream': request.body.stream,
'safe_mode': request.body.safe_mode,
'random_seed': request.body.seed === -1 ? undefined : request.body.seed,
};
const config = {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Authorization': 'Bearer ' + apiKey,
},
body: JSON.stringify({
'model': request.body.model,
'messages': messages,
'temperature': request.body.temperature,
'top_p': request.body.top_p,
'max_tokens': request.body.max_tokens,
'stream': request.body.stream,
'safe_mode': request.body.safe_mode,
'random_seed': request.body.seed === -1 ? undefined : request.body.seed,
}),
body: JSON.stringify(requestBody),
signal: controller.signal,
timeout: 0,
};
console.log('MisralAI request:', requestBody);
const generateResponse = await fetch('https://api.mistral.ai/v1/chat/completions', config);
if (request.body.stream) {
forwardFetchResponse(generateResponse, response);
@@ -469,6 +506,7 @@ async function sendMistralAIRequest(request, response) {
return response.status(generateResponse.status === 401 ? 500 : generateResponse.status).send({ error: true });
}
const generateResponseJson = await generateResponse.json();
console.log('MistralAI response:', generateResponseJson);
return response.send(generateResponseJson);
}
} catch (error) {
@@ -502,12 +540,17 @@ router.post('/status', jsonParser, async function (request, response_getstatus_o
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.MISTRALAI) {
api_url = 'https://api.mistral.ai/v1';
api_key_openai = readSecret(SECRET_KEYS.MISTRALAI);
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.CUSTOM) {
api_url = request.body.custom_url;
api_key_openai = readSecret(SECRET_KEYS.CUSTOM);
headers = {};
mergeObjectWithYaml(headers, request.body.custom_include_headers);
} else {
console.log('This chat completion source is not supported yet.');
return response_getstatus_openai.status(400).send({ error: true });
}
if (!api_key_openai && !request.body.reverse_proxy) {
if (!api_key_openai && !request.body.reverse_proxy && request.body.chat_completion_source !== CHAT_COMPLETION_SOURCES.CUSTOM) {
console.log('OpenAI API key is missing.');
return response_getstatus_openai.status(400).send({ error: true });
}
@@ -657,7 +700,7 @@ router.post('/generate', jsonParser, function (request, response) {
let headers;
let bodyParams;
if (request.body.chat_completion_source !== CHAT_COMPLETION_SOURCES.OPENROUTER) {
if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.OPENAI) {
apiUrl = new URL(request.body.reverse_proxy || API_OPENAI).toString();
apiKey = request.body.reverse_proxy ? request.body.proxy_password : readSecret(SECRET_KEYS.OPENAI);
headers = {};
@@ -666,7 +709,7 @@ router.post('/generate', jsonParser, function (request, response) {
if (getConfigValue('openai.randomizeUserId', false)) {
bodyParams['user'] = uuidv4();
}
} else {
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.OPENROUTER) {
apiUrl = 'https://openrouter.ai/api/v1';
apiKey = readSecret(SECRET_KEYS.OPENROUTER);
// OpenRouter needs to pass the referer: https://openrouter.ai/docs
@@ -676,9 +719,19 @@ router.post('/generate', jsonParser, function (request, response) {
if (request.body.use_fallback) {
bodyParams['route'] = 'fallback';
}
} else if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.CUSTOM) {
apiUrl = request.body.custom_url;
apiKey = readSecret(SECRET_KEYS.CUSTOM);
headers = {};
bodyParams = {};
mergeObjectWithYaml(bodyParams, request.body.custom_include_body);
mergeObjectWithYaml(headers, request.body.custom_include_headers);
} else {
console.log('This chat completion source is not supported yet.');
return response.status(400).send({ error: true });
}
if (!apiKey && !request.body.reverse_proxy) {
if (!apiKey && !request.body.reverse_proxy && request.body.chat_completion_source !== CHAT_COMPLETION_SOURCES.CUSTOM) {
console.log('OpenAI API key is missing.');
return response.status(400).send({ error: true });
}
@@ -700,6 +753,27 @@ router.post('/generate', jsonParser, function (request, response) {
controller.abort();
});
const requestBody = {
'messages': isTextCompletion === false ? request.body.messages : undefined,
'prompt': isTextCompletion === true ? textPrompt : undefined,
'model': request.body.model,
'temperature': request.body.temperature,
'max_tokens': request.body.max_tokens,
'stream': request.body.stream,
'presence_penalty': request.body.presence_penalty,
'frequency_penalty': request.body.frequency_penalty,
'top_p': request.body.top_p,
'top_k': request.body.top_k,
'stop': isTextCompletion === false ? request.body.stop : undefined,
'logit_bias': request.body.logit_bias,
'seed': request.body.seed,
...bodyParams,
};
if (request.body.chat_completion_source === CHAT_COMPLETION_SOURCES.CUSTOM) {
excludeKeysByYaml(requestBody, request.body.custom_exclude_body);
}
/** @type {import('node-fetch').RequestInit} */
const config = {
method: 'post',
@@ -708,27 +782,12 @@ router.post('/generate', jsonParser, function (request, response) {
'Authorization': 'Bearer ' + apiKey,
...headers,
},
body: JSON.stringify({
'messages': isTextCompletion === false ? request.body.messages : undefined,
'prompt': isTextCompletion === true ? textPrompt : undefined,
'model': request.body.model,
'temperature': request.body.temperature,
'max_tokens': request.body.max_tokens,
'stream': request.body.stream,
'presence_penalty': request.body.presence_penalty,
'frequency_penalty': request.body.frequency_penalty,
'top_p': request.body.top_p,
'top_k': request.body.top_k,
'stop': isTextCompletion === false ? request.body.stop : undefined,
'logit_bias': request.body.logit_bias,
'seed': request.body.seed,
...bodyParams,
}),
body: JSON.stringify(requestBody),
signal: controller.signal,
timeout: 0,
};
console.log(JSON.parse(String(config.body)));
console.log(requestBody);
makeRequest(config, response, request);

View File

@@ -1,13 +1,82 @@
const express = require('express');
const fetch = require('node-fetch').default;
const _ = require('lodash');
const Readable = require('stream').Readable;
const { jsonParser } = require('../../express-common');
const { TEXTGEN_TYPES } = require('../../constants');
const { forwardFetchResponse } = require('../../util');
const { TEXTGEN_TYPES, TOGETHERAI_KEYS, OLLAMA_KEYS } = require('../../constants');
const { forwardFetchResponse, trimV1 } = require('../../util');
const { setAdditionalHeaders } = require('../../additional-headers');
const router = express.Router();
/**
 * Ollama's streaming routine. Re-wraps its stream of JSON objects into a proper SSE stream.
 *
 * The upstream body is treated as newline-delimited JSON. Bytes are buffered until a
 * newline arrives, so several objects in one chunk, one object split across chunks, and
 * multi-byte UTF-8 characters split across chunks are all handled. An unterminated tail
 * that already parses as JSON is forwarded right away.
 * A complete line that is not valid JSON is logged and dropped.
 * @param {import('node-fetch').Response} jsonStream JSON stream
 * @param {import('express').Request} request Express request
 * @param {import('express').Response} response Express response
 * @returns {Promise<any>} Nothing valuable
 */
async function parseOllamaStream(jsonStream, request, response) {
    try {
        // Raw bytes are kept (not strings) so a character split between two chunks
        // is only decoded once all of its bytes have arrived.
        let pending = Buffer.alloc(0);

        /**
         * Parses one JSON document and writes it out as an SSE frame.
         * @param {string} raw Candidate JSON text
         * @returns {boolean} True if the text was consumed (blank or valid JSON)
         */
        const forward = (raw) => {
            if (raw.trim().length === 0) {
                return true;
            }

            let json;
            try {
                json = JSON.parse(raw);
            } catch {
                return false;
            }

            const text = json?.response || '';
            const frame = { choices: [{ text }] };
            response.write(`data: ${JSON.stringify(frame)}\n\n`);
            return true;
        };

        jsonStream.body.on('data', (data) => {
            const bytes = Buffer.isBuffer(data) ? data : Buffer.from(String(data));
            pending = Buffer.concat([pending, bytes]);

            // 0x0a (\n) never occurs inside a multi-byte UTF-8 sequence, so cutting there is safe.
            let newlineAt = pending.indexOf(0x0a);
            while (newlineAt !== -1) {
                const line = pending.subarray(0, newlineAt).toString('utf8');
                pending = pending.subarray(newlineAt + 1);

                if (!forward(line)) {
                    console.log('Skipping malformed Ollama stream line:', line);
                }

                newlineAt = pending.indexOf(0x0a);
            }

            // An object that arrived without its trailing newline is still forwarded immediately.
            if (pending.length > 0 && forward(pending.toString('utf8'))) {
                pending = Buffer.alloc(0);
            }
        });

        // Without a listener, an 'error' event on the body would be thrown as an uncaught exception.
        jsonStream.body.on('error', (error) => {
            console.log('Error reading Ollama stream:', error);
            response.end();
        });

        request.socket.on('close', function () {
            if (jsonStream.body instanceof Readable) jsonStream.body.destroy();
            response.end();
        });

        jsonStream.body.on('end', () => {
            console.log('Streaming request finished');
            response.write('data: [DONE]\n\n');
            response.end();
        });
    } catch (error) {
        console.log('Error forwarding streaming response:', error);
        if (!response.headersSent) {
            return response.status(500).send({ error: true });
        } else {
            return response.end();
        }
    }
}
/**
 * Asks a KoboldCpp server to stop its running generation.
 * Best effort: a non-OK reply or a thrown error is only logged.
 * @param {string} url Server base URL
 * @returns {Promise<void>} Promise resolving when we are done
 */
async function abortKoboldCppRequest(url) {
    try {
        console.log('Aborting Kobold generation...');

        const endpoint = `${url}/api/extra/abort`;
        const reply = await fetch(endpoint, { method: 'POST' });

        if (reply.ok) {
            return;
        }

        console.log('Error sending abort request to Kobold:', reply.status, reply.statusText);
    } catch (failure) {
        console.log(failure);
    }
}
//************** Ooba/OpenAI text completions API
router.post('/status', jsonParser, async function (request, response) {
if (!request.body) return response.sendStatus(400);
@@ -18,9 +87,7 @@ router.post('/status', jsonParser, async function (request, response) {
}
console.log('Trying to connect to API:', request.body);
// Convert to string + remove trailing slash + /v1 suffix
const baseUrl = String(request.body.api_server).replace(/\/$/, '').replace(/\/v1$/, '');
const baseUrl = trimV1(request.body.api_server);
const args = {
headers: { 'Content-Type': 'application/json' },
@@ -38,6 +105,7 @@ router.post('/status', jsonParser, async function (request, response) {
case TEXTGEN_TYPES.OOBA:
case TEXTGEN_TYPES.APHRODITE:
case TEXTGEN_TYPES.KOBOLDCPP:
case TEXTGEN_TYPES.LLAMACPP:
url += '/v1/models';
break;
case TEXTGEN_TYPES.MANCER:
@@ -46,6 +114,12 @@ router.post('/status', jsonParser, async function (request, response) {
case TEXTGEN_TYPES.TABBY:
url += '/v1/model/list';
break;
case TEXTGEN_TYPES.TOGETHERAI:
url += '/api/models?&info';
break;
case TEXTGEN_TYPES.OLLAMA:
url += '/api/tags';
break;
}
}
@@ -56,13 +130,22 @@ router.post('/status', jsonParser, async function (request, response) {
return response.status(400);
}
const data = await modelsReply.json();
let data = await modelsReply.json();
if (request.body.legacy_api) {
console.log('Legacy API response:', data);
return response.send({ result: data?.result });
}
// Rewrap to OAI-like response
if (request.body.api_type === TEXTGEN_TYPES.TOGETHERAI && Array.isArray(data)) {
data = { data: data.map(x => ({ id: x.name, ...x })) };
}
if (request.body.api_type === TEXTGEN_TYPES.OLLAMA && Array.isArray(data.models)) {
data = { data: data.models.map(x => ({ id: x.name, ...x })) };
}
if (!Array.isArray(data.data)) {
console.log('Models response is not an array.');
return response.status(400);
@@ -117,8 +200,8 @@ router.post('/status', jsonParser, async function (request, response) {
}
});
router.post('/generate', jsonParser, async function (request, response_generate) {
if (!request.body) return response_generate.sendStatus(400);
router.post('/generate', jsonParser, async function (request, response) {
if (!request.body) return response.sendStatus(400);
try {
if (request.body.api_server.indexOf('localhost') !== -1) {
@@ -130,12 +213,15 @@ router.post('/generate', jsonParser, async function (request, response_generate)
const controller = new AbortController();
request.socket.removeAllListeners('close');
request.socket.on('close', function () {
request.socket.on('close', async function () {
if (request.body.api_type === TEXTGEN_TYPES.KOBOLDCPP && !response.writableEnded) {
await abortKoboldCppRequest(trimV1(baseUrl));
}
controller.abort();
});
// Convert to string + remove trailing slash + /v1 suffix
let url = String(baseUrl).replace(/\/$/, '').replace(/\/v1$/, '');
let url = trimV1(baseUrl);
if (request.body.legacy_api) {
url += '/v1/generate';
@@ -145,11 +231,18 @@ router.post('/generate', jsonParser, async function (request, response_generate)
case TEXTGEN_TYPES.OOBA:
case TEXTGEN_TYPES.TABBY:
case TEXTGEN_TYPES.KOBOLDCPP:
case TEXTGEN_TYPES.TOGETHERAI:
url += '/v1/completions';
break;
case TEXTGEN_TYPES.MANCER:
url += '/oai/v1/completions';
break;
case TEXTGEN_TYPES.LLAMACPP:
url += '/completion';
break;
case TEXTGEN_TYPES.OLLAMA:
url += '/api/generate';
break;
}
}
@@ -163,10 +256,32 @@ router.post('/generate', jsonParser, async function (request, response_generate)
setAdditionalHeaders(request, args, baseUrl);
if (request.body.stream) {
if (request.body.api_type === TEXTGEN_TYPES.TOGETHERAI) {
const stop = Array.isArray(request.body.stop) ? request.body.stop[0] : '';
request.body = _.pickBy(request.body, (_, key) => TOGETHERAI_KEYS.includes(key));
if (typeof stop === 'string' && stop.length > 0) {
request.body.stop = stop;
}
args.body = JSON.stringify(request.body);
}
if (request.body.api_type === TEXTGEN_TYPES.OLLAMA) {
args.body = JSON.stringify({
model: request.body.model,
prompt: request.body.prompt,
stream: request.body.stream ?? false,
raw: true,
options: _.pickBy(request.body, (_, key) => OLLAMA_KEYS.includes(key)),
});
}
if (request.body.api_type === TEXTGEN_TYPES.OLLAMA && request.body.stream) {
const stream = await fetch(url, args);
parseOllamaStream(stream, request, response);
} else if (request.body.stream) {
const completionsStream = await fetch(url, args);
// Pipe remote SSE stream to Express response
forwardFetchResponse(completionsStream, response_generate);
forwardFetchResponse(completionsStream, response);
}
else {
const completionsReply = await fetch(url, args);
@@ -181,28 +296,152 @@ router.post('/generate', jsonParser, async function (request, response_generate)
data['choices'] = [{ text }];
}
return response_generate.send(data);
return response.send(data);
} else {
const text = await completionsReply.text();
const errorBody = { error: true, status: completionsReply.status, response: text };
if (!response_generate.headersSent) {
return response_generate.send(errorBody);
if (!response.headersSent) {
return response.send(errorBody);
}
return response_generate.end();
return response.end();
}
}
} catch (error) {
let value = { error: true, status: error?.status, response: error?.statusText };
console.log('Endpoint error:', error);
if (!response_generate.headersSent) {
return response_generate.send(value);
if (!response.headersSent) {
return response.send(value);
}
return response_generate.end();
return response.end();
}
});
const ollama = express.Router();
/**
 * POST /download — asks an Ollama server to pull a model.
 * Body: `name` (model to pull) and `api_server` (Ollama base URL); both are required,
 * otherwise 400 is returned. Replies `{ ok: true }` on success, mirrors the upstream
 * status with `{ error: true }` when the pull fails, and 500 on unexpected errors.
 */
ollama.post('/download', jsonParser, async function (request, response) {
    try {
        if (!request.body.name || !request.body.api_server) return response.sendStatus(400);

        const name = request.body.name;
        // Drop a single trailing slash so the path can be appended safely.
        const url = String(request.body.api_server).replace(/\/$/, '');

        const fetchResponse = await fetch(`${url}/api/pull`, {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({
                name: name,
                stream: false,
            }),
            timeout: 0,
        });

        if (!fetchResponse.ok) {
            console.log('Download error:', fetchResponse.status, fetchResponse.statusText);
            return response.status(fetchResponse.status).send({ error: true });
        }

        return response.send({ ok: true });
    } catch (error) {
        console.error(error);
        // status() alone only sets the code; a body must be sent or the client waits forever.
        return response.status(500).send({ error: true });
    }
});
/**
 * POST /caption-image — captions an image through an Ollama server.
 * Body: `server_url` and `model` (both required, else 400), plus `prompt` and `image`,
 * which are forwarded to the server's /api/generate endpoint with streaming disabled.
 * Replies `{ caption }` taken from the `response` field of the reply, or 500 with
 * `{ error: true }` when the upstream call fails or the caption is empty.
 */
ollama.post('/caption-image', jsonParser, async function (request, response) {
    try {
        if (!request.body.server_url || !request.body.model) {
            return response.sendStatus(400);
        }

        console.log('Ollama caption request:', request.body);
        const baseUrl = trimV1(request.body.server_url);

        const fetchResponse = await fetch(`${baseUrl}/api/generate`, {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({
                model: request.body.model,
                prompt: request.body.prompt,
                images: [request.body.image],
                stream: false,
            }),
            timeout: 0,
        });

        if (!fetchResponse.ok) {
            console.log('Ollama caption error:', fetchResponse.status, fetchResponse.statusText);
            return response.status(500).send({ error: true });
        }

        const data = await fetchResponse.json();
        console.log('Ollama caption response:', data);

        const caption = data?.response || '';

        if (!caption) {
            console.log('Ollama caption is empty.');
            return response.status(500).send({ error: true });
        }

        return response.send({ caption });
    } catch (error) {
        console.error(error);
        // status() alone only sets the code; a body must be sent or the client waits forever.
        return response.status(500).send({ error: true });
    }
});
const llamacpp = express.Router();
/**
 * POST /caption-image — captions an image through a llama.cpp server.
 * Body: `server_url` (required, else 400), plus `prompt` and `image`. The prompt is
 * wrapped in a USER/ASSISTANT template referencing image slot 1 and posted to the
 * server's /completion endpoint with streaming disabled.
 * Replies `{ caption }` taken from the `content` field of the reply, or 500 with
 * `{ error: true }` when the upstream call fails or the caption is empty.
 */
llamacpp.post('/caption-image', jsonParser, async function (request, response) {
    try {
        if (!request.body.server_url) {
            return response.sendStatus(400);
        }

        console.log('LlamaCpp caption request:', request.body);
        const baseUrl = trimV1(request.body.server_url);

        const fetchResponse = await fetch(`${baseUrl}/completion`, {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            timeout: 0,
            body: JSON.stringify({
                prompt: `USER:[img-1]${String(request.body.prompt).trim()}\nASSISTANT:`,
                image_data: [{ data: request.body.image, id: 1 }],
                temperature: 0.1,
                stream: false,
                stop: ['USER:', '</s>'],
            }),
        });

        if (!fetchResponse.ok) {
            console.log('LlamaCpp caption error:', fetchResponse.status, fetchResponse.statusText);
            return response.status(500).send({ error: true });
        }

        const data = await fetchResponse.json();
        console.log('LlamaCpp caption response:', data);

        const caption = data?.content || '';

        if (!caption) {
            console.log('LlamaCpp caption is empty.');
            return response.status(500).send({ error: true });
        }

        return response.send({ caption });
    } catch (error) {
        console.error(error);
        // status() alone only sets the code; a body must be sent or the client waits forever.
        return response.status(500).send({ error: true });
    }
});
router.use('/ollama', ollama);
router.use('/llamacpp', llamacpp);
module.exports = { router };

View File

@@ -4,6 +4,7 @@ const readline = require('readline');
const express = require('express');
const sanitize = require('sanitize-filename');
const writeFileAtomicSync = require('write-file-atomic').sync;
const yaml = require('yaml');
const _ = require('lodash');
const encode = require('png-chunks-encode');
@@ -19,6 +20,7 @@ const characterCardParser = require('../character-card-parser.js');
const { readWorldInfoFile } = require('./worldinfo');
const { invalidateThumbnail } = require('./thumbnails');
const { importRisuSprites } = require('./sprites');
const defaultAvatarPath = './public/img/ai4.png';
let characters = {};
@@ -394,6 +396,36 @@ function convertWorldInfoToCharacterBook(name, entries) {
return result;
}
/**
 * Creates a character card from an uploaded YAML definition.
 * Reads the upload, removes it from disk, maps the YAML `name`, `context` and
 * `greeting` fields onto a V2 card and writes it using the default avatar image.
 * @param {string} uploadPath Path to the uploaded file
 * @param {import('express').Response} response Express response object, handed over to charaWrite
 */
function importFromYaml(uploadPath, response) {
    const rawYaml = fs.readFileSync(uploadPath, 'utf8');
    fs.rmSync(uploadPath);

    const parsed = yaml.parse(rawYaml);
    console.log('importing from yaml');

    const characterName = sanitize(parsed.name);
    parsed.name = characterName;
    const pngName = getPngName(characterName);

    // Only name, context and greeting come from the YAML; the rest are fixed defaults.
    const legacyCard = {
        'name': characterName,
        'description': parsed.context ?? '',
        'first_mes': parsed.greeting ?? '',
        'create_date': humanizedISO8601DateTime(),
        'chat': characterName + ' - ' + humanizedISO8601DateTime(),
        'personality': '',
        'creatorcomment': '',
        'avatar': 'none',
        'mes_example': '',
        'scenario': '',
        'talkativeness': 0.5,
        'creator': '',
        'tags': '',
    };
    const serializedCard = JSON.stringify(convertToV2(legacyCard));

    charaWrite(defaultAvatarPath, serializedCard, pngName, response, { file_name: pngName });
}
const router = express.Router();
router.post('/create', urlencodedParser, async function (request, response) {
@@ -760,144 +792,147 @@ function getPngName(file) {
}
router.post('/import', urlencodedParser, async function (request, response) {
if (!request.body || request.file === undefined) return response.sendStatus(400);
if (!request.body || !request.file) return response.sendStatus(400);
let png_name = '';
let filedata = request.file;
let uploadPath = path.join(UPLOADS_PATH, filedata.filename);
var format = request.body.file_type;
const defaultAvatarPath = './public/img/ai4.png';
//console.log(format);
if (filedata) {
if (format == 'json') {
fs.readFile(uploadPath, 'utf8', async (err, data) => {
fs.unlinkSync(uploadPath);
let format = request.body.file_type;
if (err) {
console.log(err);
response.send({ error: true });
}
if (format == 'yaml' || format == 'yml') {
try {
importFromYaml(uploadPath, response);
} catch (err) {
console.log(err);
response.send({ error: true });
}
} else if (format == 'json') {
fs.readFile(uploadPath, 'utf8', async (err, data) => {
fs.unlinkSync(uploadPath);
let jsonData = JSON.parse(data);
if (jsonData.spec !== undefined) {
console.log('importing from v2 json');
importRisuSprites(jsonData);
unsetFavFlag(jsonData);
jsonData = readFromV2(jsonData);
jsonData['create_date'] = humanizedISO8601DateTime();
png_name = getPngName(jsonData.data?.name || jsonData.name);
let char = JSON.stringify(jsonData);
charaWrite(defaultAvatarPath, char, png_name, response, { file_name: png_name });
} else if (jsonData.name !== undefined) {
console.log('importing from v1 json');
jsonData.name = sanitize(jsonData.name);
if (jsonData.creator_notes) {
jsonData.creator_notes = jsonData.creator_notes.replace('Creator\'s notes go here.', '');
}
png_name = getPngName(jsonData.name);
let char = {
'name': jsonData.name,
'description': jsonData.description ?? '',
'creatorcomment': jsonData.creatorcomment ?? jsonData.creator_notes ?? '',
'personality': jsonData.personality ?? '',
'first_mes': jsonData.first_mes ?? '',
'avatar': 'none',
'chat': jsonData.name + ' - ' + humanizedISO8601DateTime(),
'mes_example': jsonData.mes_example ?? '',
'scenario': jsonData.scenario ?? '',
'create_date': humanizedISO8601DateTime(),
'talkativeness': jsonData.talkativeness ?? 0.5,
'creator': jsonData.creator ?? '',
'tags': jsonData.tags ?? '',
};
char = convertToV2(char);
let charJSON = JSON.stringify(char);
charaWrite(defaultAvatarPath, charJSON, png_name, response, { file_name: png_name });
} else if (jsonData.char_name !== undefined) {//json Pygmalion notepad
console.log('importing from gradio json');
jsonData.char_name = sanitize(jsonData.char_name);
if (jsonData.creator_notes) {
jsonData.creator_notes = jsonData.creator_notes.replace('Creator\'s notes go here.', '');
}
png_name = getPngName(jsonData.char_name);
let char = {
'name': jsonData.char_name,
'description': jsonData.char_persona ?? '',
'creatorcomment': jsonData.creatorcomment ?? jsonData.creator_notes ?? '',
'personality': '',
'first_mes': jsonData.char_greeting ?? '',
'avatar': 'none',
'chat': jsonData.name + ' - ' + humanizedISO8601DateTime(),
'mes_example': jsonData.example_dialogue ?? '',
'scenario': jsonData.world_scenario ?? '',
'create_date': humanizedISO8601DateTime(),
'talkativeness': jsonData.talkativeness ?? 0.5,
'creator': jsonData.creator ?? '',
'tags': jsonData.tags ?? '',
};
char = convertToV2(char);
let charJSON = JSON.stringify(char);
charaWrite(defaultAvatarPath, charJSON, png_name, response, { file_name: png_name });
} else {
console.log('Incorrect character format .json');
response.send({ error: true });
}
});
} else {
try {
var img_data = await charaRead(uploadPath, format);
if (img_data === undefined) throw new Error('Failed to read character data');
let jsonData = JSON.parse(img_data);
jsonData.name = sanitize(jsonData.data?.name || jsonData.name);
png_name = getPngName(jsonData.name);
if (jsonData.spec !== undefined) {
console.log('Found a v2 character file.');
importRisuSprites(jsonData);
unsetFavFlag(jsonData);
jsonData = readFromV2(jsonData);
jsonData['create_date'] = humanizedISO8601DateTime();
const char = JSON.stringify(jsonData);
await charaWrite(uploadPath, char, png_name, response, { file_name: png_name });
fs.unlinkSync(uploadPath);
} else if (jsonData.name !== undefined) {
console.log('Found a v1 character file.');
if (jsonData.creator_notes) {
jsonData.creator_notes = jsonData.creator_notes.replace('Creator\'s notes go here.', '');
}
let char = {
'name': jsonData.name,
'description': jsonData.description ?? '',
'creatorcomment': jsonData.creatorcomment ?? jsonData.creator_notes ?? '',
'personality': jsonData.personality ?? '',
'first_mes': jsonData.first_mes ?? '',
'avatar': 'none',
'chat': jsonData.name + ' - ' + humanizedISO8601DateTime(),
'mes_example': jsonData.mes_example ?? '',
'scenario': jsonData.scenario ?? '',
'create_date': humanizedISO8601DateTime(),
'talkativeness': jsonData.talkativeness ?? 0.5,
'creator': jsonData.creator ?? '',
'tags': jsonData.tags ?? '',
};
char = convertToV2(char);
const charJSON = JSON.stringify(char);
await charaWrite(uploadPath, charJSON, png_name, response, { file_name: png_name });
fs.unlinkSync(uploadPath);
} else {
console.log('Unknown character card format');
response.send({ error: true });
}
} catch (err) {
if (err) {
console.log(err);
response.send({ error: true });
}
let jsonData = JSON.parse(data);
if (jsonData.spec !== undefined) {
console.log('importing from v2 json');
importRisuSprites(jsonData);
unsetFavFlag(jsonData);
jsonData = readFromV2(jsonData);
jsonData['create_date'] = humanizedISO8601DateTime();
png_name = getPngName(jsonData.data?.name || jsonData.name);
let char = JSON.stringify(jsonData);
charaWrite(defaultAvatarPath, char, png_name, response, { file_name: png_name });
} else if (jsonData.name !== undefined) {
console.log('importing from v1 json');
jsonData.name = sanitize(jsonData.name);
if (jsonData.creator_notes) {
jsonData.creator_notes = jsonData.creator_notes.replace('Creator\'s notes go here.', '');
}
png_name = getPngName(jsonData.name);
let char = {
'name': jsonData.name,
'description': jsonData.description ?? '',
'creatorcomment': jsonData.creatorcomment ?? jsonData.creator_notes ?? '',
'personality': jsonData.personality ?? '',
'first_mes': jsonData.first_mes ?? '',
'avatar': 'none',
'chat': jsonData.name + ' - ' + humanizedISO8601DateTime(),
'mes_example': jsonData.mes_example ?? '',
'scenario': jsonData.scenario ?? '',
'create_date': humanizedISO8601DateTime(),
'talkativeness': jsonData.talkativeness ?? 0.5,
'creator': jsonData.creator ?? '',
'tags': jsonData.tags ?? '',
};
char = convertToV2(char);
let charJSON = JSON.stringify(char);
charaWrite(defaultAvatarPath, charJSON, png_name, response, { file_name: png_name });
} else if (jsonData.char_name !== undefined) {//json Pygmalion notepad
console.log('importing from gradio json');
jsonData.char_name = sanitize(jsonData.char_name);
if (jsonData.creator_notes) {
jsonData.creator_notes = jsonData.creator_notes.replace('Creator\'s notes go here.', '');
}
png_name = getPngName(jsonData.char_name);
let char = {
'name': jsonData.char_name,
'description': jsonData.char_persona ?? '',
'creatorcomment': jsonData.creatorcomment ?? jsonData.creator_notes ?? '',
'personality': '',
'first_mes': jsonData.char_greeting ?? '',
'avatar': 'none',
'chat': jsonData.name + ' - ' + humanizedISO8601DateTime(),
'mes_example': jsonData.example_dialogue ?? '',
'scenario': jsonData.world_scenario ?? '',
'create_date': humanizedISO8601DateTime(),
'talkativeness': jsonData.talkativeness ?? 0.5,
'creator': jsonData.creator ?? '',
'tags': jsonData.tags ?? '',
};
char = convertToV2(char);
let charJSON = JSON.stringify(char);
charaWrite(defaultAvatarPath, charJSON, png_name, response, { file_name: png_name });
} else {
console.log('Incorrect character format .json');
response.send({ error: true });
}
});
} else {
try {
var img_data = await charaRead(uploadPath, format);
if (img_data === undefined) throw new Error('Failed to read character data');
let jsonData = JSON.parse(img_data);
jsonData.name = sanitize(jsonData.data?.name || jsonData.name);
png_name = getPngName(jsonData.name);
if (jsonData.spec !== undefined) {
console.log('Found a v2 character file.');
importRisuSprites(jsonData);
unsetFavFlag(jsonData);
jsonData = readFromV2(jsonData);
jsonData['create_date'] = humanizedISO8601DateTime();
const char = JSON.stringify(jsonData);
await charaWrite(uploadPath, char, png_name, response, { file_name: png_name });
fs.unlinkSync(uploadPath);
} else if (jsonData.name !== undefined) {
console.log('Found a v1 character file.');
if (jsonData.creator_notes) {
jsonData.creator_notes = jsonData.creator_notes.replace('Creator\'s notes go here.', '');
}
let char = {
'name': jsonData.name,
'description': jsonData.description ?? '',
'creatorcomment': jsonData.creatorcomment ?? jsonData.creator_notes ?? '',
'personality': jsonData.personality ?? '',
'first_mes': jsonData.first_mes ?? '',
'avatar': 'none',
'chat': jsonData.name + ' - ' + humanizedISO8601DateTime(),
'mes_example': jsonData.mes_example ?? '',
'scenario': jsonData.scenario ?? '',
'create_date': humanizedISO8601DateTime(),
'talkativeness': jsonData.talkativeness ?? 0.5,
'creator': jsonData.creator ?? '',
'tags': jsonData.tags ?? '',
};
char = convertToV2(char);
const charJSON = JSON.stringify(char);
await charaWrite(uploadPath, charJSON, png_name, response, { file_name: png_name });
fs.unlinkSync(uploadPath);
} else {
console.log('Unknown character card format');
response.send({ error: true });
}
} catch (err) {
console.log(err);
response.send({ error: true });
}
}
});

View File

@@ -1,20 +1,30 @@
const fetch = require('node-fetch').default;
const express = require('express');
const AIHorde = require('../ai_horde');
const { getVersion, delay } = require('../util');
const { getVersion, delay, Cache } = require('../util');
const { readSecret, SECRET_KEYS } = require('./secrets');
const { jsonParser } = require('../express-common');
// Fallback API key sent to the Horde when no Horde secret is stored (anonymous access).
const ANONYMOUS_KEY = '0000000000';
// In-memory cache for the Horde worker and model lists; entries expire after 60 seconds.
const cache = new Cache(60 * 1000);
// Express router that the Horde endpoints in this module are registered on.
const router = express.Router();
/**
 * Resolves the Client-Agent string sent to the AI Horde.
 * Uses the agent reported by the version info and falls back to a generic
 * identifier when that agent is missing or empty.
 * @returns {Promise<string>} AIHorde client agent
 */
async function getClientAgent() {
    const fallbackAgent = 'SillyTavern:UNKNOWN:Cohee#1207';
    const versionInfo = await getVersion();
    const reportedAgent = versionInfo?.agent;
    return reportedAgent ? reportedAgent : fallbackAgent;
}
/**
 * Returns the AIHorde client.
 * The client agent comes from getClientAgent(), which already resolves the
 * version info and applies the fallback agent string.
 * @returns {Promise<AIHorde>} AIHorde client
 */
async function getHordeClient() {
    // A single client_agent entry: the previous literal listed the key twice and
    // resolved the version a second time only to have the value overridden.
    return new AIHorde({
        client_agent: await getClientAgent(),
    });
}
@@ -36,29 +46,122 @@ function sanitizeHordeImagePrompt(prompt) {
prompt = prompt.replace(/\b(boy)\b/gmi, 'man');
prompt = prompt.replace(/\b(girls)\b/gmi, 'women');
prompt = prompt.replace(/\b(boys)\b/gmi, 'men');
//always remove these high risk words from prompt, as they add little value to image gen while increasing the risk the prompt gets flagged
prompt = prompt.replace(/\b(under.age|under.aged|underage|underaged|loli|pedo|pedophile|(\w+).year.old|(\w+).years.old|minor|prepubescent|minors|shota)\b/gmi, '');
//if nsfw is detected, do not remove it but apply additional precautions
let isNsfw = prompt.match(/\b(cock|ahegao|hentai|uncensored|lewd|cocks|deepthroat|deepthroating|dick|dicks|cumshot|lesbian|fuck|fucked|fucking|sperm|naked|nipples|tits|boobs|breasts|boob|breast|topless|ass|butt|fingering|masturbate|masturbating|bitch|blowjob|pussy|piss|asshole|dildo|dildos|vibrator|erection|foreskin|handjob|nude|penis|porn|vibrator|virgin|vagina|vulva|threesome|orgy|bdsm|hickey|condom|testicles|anal|bareback|bukkake|creampie|stripper|strap-on|missionary|clitoris|clit|clitty|cowgirl|fleshlight|sex|buttplug|milf|oral|sucking|bondage|orgasm|scissoring|railed|slut|sluts|slutty|cumming|cunt|faggot|sissy|anal|anus|cum|semen|scat|nsfw|xxx|explicit|erotic|horny|aroused|jizz|moan|rape|raped|raping|throbbing|humping)\b/gmi);
if (isNsfw) {
//replace risky subject nouns with person
prompt = prompt.replace(/\b(youngster|infant|baby|toddler|child|teen|kid|kiddie|kiddo|teenager|student|preteen|pre.teen)\b/gmi, 'person');
//remove risky adjectives and related words
prompt = prompt.replace(/\b(young|younger|youthful|youth|small|smaller|smallest|girly|boyish|lil|tiny|teenaged|lit[tl]le|school.aged|school|highschool|kindergarten|teens|children|kids)\b/gmi, '');
}
//replace risky subject nouns with person
prompt = prompt.replace(/\b(youngster|infant|baby|toddler|child|teen|kid|kiddie|kiddo|teenager|student|preteen|pre.teen)\b/gmi, 'person');
//remove risky adjectives and related words
prompt = prompt.replace(/\b(young|younger|youthful|youth|small|smaller|smallest|girly|boyish|lil|tiny|teenaged|lit[tl]le|school.aged|school|highschool|kindergarten|teens|children|kids)\b/gmi, '');
return prompt;
}
const router = express.Router();
// POST /text-workers
// Responds with the Horde text worker list. A cached copy is served unless the
// request body sets `force`; otherwise the list is fetched from the Horde API.
router.post('/text-workers', jsonParser, async (request, response) => {
    try {
        const cachedWorkers = cache.get('workers');

        if (cachedWorkers && !request.body.force) {
            return response.send(cachedWorkers);
        }

        const agent = await getClientAgent();
        const fetchResult = await fetch('https://horde.koboldai.net/api/v2/workers?type=text', {
            headers: {
                'Client-Agent': agent,
            },
        });

        const data = await fetchResult.json();

        // Cache successful responses only, so an upstream error payload is not
        // replayed to later requests until the cache entry expires.
        if (fetchResult.ok) {
            cache.set('workers', data);
        }

        return response.send(data);
    } catch (error) {
        console.error(error);
        response.sendStatus(500);
    }
});
// POST /text-models
// Responds with the Horde text model status list. A cached copy is served unless
// the request body sets `force`; otherwise the list is fetched from the Horde API.
router.post('/text-models', jsonParser, async (request, response) => {
    try {
        const cachedModels = cache.get('models');

        if (cachedModels && !request.body.force) {
            return response.send(cachedModels);
        }

        const agent = await getClientAgent();
        const fetchResult = await fetch('https://horde.koboldai.net/api/v2/status/models?type=text', {
            headers: {
                'Client-Agent': agent,
            },
        });

        const data = await fetchResult.json();

        // Cache successful responses only, so an upstream error payload is not
        // replayed to later requests until the cache entry expires.
        if (fetchResult.ok) {
            cache.set('models', data);
        }

        return response.send(data);
    } catch (error) {
        console.error(error);
        response.sendStatus(500);
    }
});
// POST /status
// Pings the Horde heartbeat endpoint and reports whether it answered successfully.
router.post('/status', jsonParser, async (_, response) => {
    try {
        const heartbeatUrl = 'https://horde.koboldai.net/api/v2/status/heartbeat';
        const headers = { 'Client-Agent': await getClientAgent() };
        const { ok } = await fetch(heartbeatUrl, { headers });

        return response.send({ ok });
    } catch (error) {
        console.error(error);
        response.sendStatus(500);
    }
});
// POST /cancel-task
// Cancels a Horde text generation task. Expects `taskId` in the request body and
// responds with the JSON returned by the Horde API, or 400 when no id is given.
router.post('/cancel-task', jsonParser, async (request, response) => {
    try {
        const taskId = request.body.taskId;

        if (!taskId) {
            console.log('No Horde task ID provided.');
            return response.sendStatus(400);
        }

        const agent = await getClientAgent();
        // The task ID comes from the client, so it is encoded before being placed
        // into the URL path; it must not be able to alter the requested route.
        const fetchResult = await fetch(`https://horde.koboldai.net/api/v2/generate/text/status/${encodeURIComponent(taskId)}`, {
            method: 'DELETE',
            headers: {
                'Client-Agent': agent,
            },
        });

        const data = await fetchResult.json();
        console.log(`Cancelled Horde task ${taskId}`);
        return response.send(data);
    } catch (error) {
        console.error(error);
        response.sendStatus(500);
    }
});
// POST /task-status
// Looks up a Horde text generation task. Expects `taskId` in the request body and
// responds with the JSON returned by the Horde API, or 400 when no id is given.
router.post('/task-status', jsonParser, async (request, response) => {
    try {
        const taskId = request.body.taskId;

        if (!taskId) {
            console.log('No Horde task ID provided.');
            return response.sendStatus(400);
        }

        const agent = await getClientAgent();
        // The task ID comes from the client, so it is encoded before being placed
        // into the URL path; it must not be able to alter the requested route.
        const fetchResult = await fetch(`https://horde.koboldai.net/api/v2/generate/text/status/${encodeURIComponent(taskId)}`, {
            headers: {
                'Client-Agent': agent,
            },
        });

        const data = await fetchResult.json();
        console.log(`Horde task ${taskId} status:`, data);
        return response.send(data);
    } catch (error) {
        console.error(error);
        response.sendStatus(500);
    }
});
router.post('/generate-text', jsonParser, async (request, response) => {
const api_key_horde = readSecret(SECRET_KEYS.HORDE) || ANONYMOUS_KEY;
const apiKey = readSecret(SECRET_KEYS.HORDE) || ANONYMOUS_KEY;
const url = 'https://horde.koboldai.net/api/v2/generate/text/async';
const agent = await getClientAgent();
console.log(request.body);
try {
@@ -67,8 +170,8 @@ router.post('/generate-text', jsonParser, async (request, response) => {
body: JSON.stringify(request.body),
headers: {
'Content-Type': 'application/json',
'apikey': api_key_horde,
'Client-Agent': String(request.header('Client-Agent')),
'apikey': apiKey,
'Client-Agent': agent,
},
});

View File

@@ -4,22 +4,35 @@ const express = require('express');
const FormData = require('form-data');
const fs = require('fs');
const { jsonParser, urlencodedParser } = require('../express-common');
const { getConfigValue, mergeObjectWithYaml, excludeKeysByYaml } = require('../util');
const router = express.Router();
router.post('/caption-image', jsonParser, async (request, response) => {
try {
let key = '';
let headers = {};
let bodyParams = {};
if (request.body.api === 'openai') {
if (request.body.api === 'openai' && !request.body.reverse_proxy) {
key = readSecret(SECRET_KEYS.OPENAI);
}
if (request.body.api === 'openrouter') {
if (request.body.api === 'openrouter' && !request.body.reverse_proxy) {
key = readSecret(SECRET_KEYS.OPENROUTER);
}
if (!key) {
if (request.body.reverse_proxy && request.body.proxy_password) {
key = request.body.proxy_password;
}
if (request.body.api === 'custom') {
key = readSecret(SECRET_KEYS.CUSTOM);
mergeObjectWithYaml(bodyParams, request.body.custom_include_body);
mergeObjectWithYaml(headers, request.body.custom_include_headers);
}
if (!key && !request.body.reverse_proxy && request.body.api !== 'custom') {
console.log('No key found for API', request.body.api);
return response.sendStatus(400);
}
@@ -36,12 +49,24 @@ router.post('/caption-image', jsonParser, async (request, response) => {
},
],
max_tokens: 500,
...bodyParams,
};
const captionSystemPrompt = getConfigValue('openai.captionSystemPrompt');
if (captionSystemPrompt) {
body.messages.unshift({
role: 'system',
content: captionSystemPrompt,
});
}
if (request.body.api === 'custom') {
excludeKeysByYaml(body, request.body.custom_exclude_body);
}
console.log('Multimodal captioning request', body);
let apiUrl = '';
let headers = {};
if (request.body.api === 'openrouter') {
apiUrl = 'https://openrouter.ai/api/v1/chat/completions';
@@ -52,6 +77,14 @@ router.post('/caption-image', jsonParser, async (request, response) => {
apiUrl = 'https://api.openai.com/v1/chat/completions';
}
if (request.body.reverse_proxy) {
apiUrl = `${request.body.reverse_proxy}/chat/completions`;
}
if (request.body.api === 'custom') {
apiUrl = `${request.body.server_url}/chat/completions`;
}
const result = await fetch(apiUrl, {
method: 'POST',
headers: {

View File

@@ -1,74 +1,67 @@
/**
* Convert a prompt from the ChatML objects to the format used by Claude.
* @param {object[]} messages Array of messages
* @param {boolean} addHumanPrefix Add Human prefix
* @param {boolean} addAssistantPostfix Add Assistant postfix
* @param {boolean} withSystemPrompt Build system prompt before "\n\nHuman: "
* @param {boolean} addAssistantPostfix Add Assistant postfix.
* @param {string} addAssistantPrefill Add Assistant prefill after the assistant postfix.
* @param {boolean} withSysPromptSupport Indicates if the Claude model supports the system prompt format.
* @param {boolean} useSystemPrompt Indicates if the system prompt format should be used.
* @param {string} addSysHumanMsg Add Human message between system prompt and assistant.
* @returns {string} Prompt for Claude
* @copyright Prompt Conversion script taken from RisuAI by kwaroran (GPLv3).
*/
function convertClaudePrompt(messages, addHumanPrefix, addAssistantPostfix, withSystemPrompt) {
// Claude doesn't support message names, so we'll just add them to the message content.
for (const message of messages) {
if (message.name && message.role !== 'system') {
message.content = message.name + ': ' + message.content;
delete message.name;
function convertClaudePrompt(messages, addAssistantPostfix, addAssistantPrefill, withSysPromptSupport, useSystemPrompt, addSysHumanMsg) {
//Prepare messages for claude.
if (messages.length > 0) {
messages[0].role = 'system';
//Add the assistant's message to the end of messages.
if (addAssistantPostfix) {
messages.push({
role: 'assistant',
content: addAssistantPrefill || '',
});
}
}
let systemPrompt = '';
if (withSystemPrompt) {
let lastSystemIdx = -1;
for (let i = 0; i < messages.length - 1; i++) {
const message = messages[i];
if (message.role === 'system' && !message.name) {
systemPrompt += message.content + '\n\n';
} else {
lastSystemIdx = i - 1;
break;
// Find the index of the first message with an assistant role and check for a "'user' role/Human:" before it.
let hasUser = false;
const firstAssistantIndex = messages.findIndex((message, i) => {
if (i >= 0 && (message.role === 'user' || message.content.includes('\n\nHuman: '))) {
hasUser = true;
}
return message.role === 'assistant' && i > 0;
});
// For Claude 2.1+ with 'Use system prompt' checked, switch to the system prompt format by setting the first message's role to 'system'.
// Insert the human message before the first assistant message if no such message or prefix was found.
if (withSysPromptSupport && useSystemPrompt) {
messages[0].role = 'system';
if (firstAssistantIndex > 0 && addSysHumanMsg && !hasUser) {
messages.splice(firstAssistantIndex, 0, {
role: 'user',
content: addSysHumanMsg,
});
}
} else {
// Otherwise, use the default message format by setting the first message's role to 'user' (compatible with all Claude models, including 2.1).
messages[0].role = 'user';
// Fix messages order for default message format when(messages > Context Size) by merging two messages with "\n\nHuman: " prefixes into one, before the first Assistant's message.
if (firstAssistantIndex > 0) {
messages[firstAssistantIndex - 1].role = firstAssistantIndex - 1 !== 0 && messages[firstAssistantIndex - 1].role === 'user' ? 'FixHumMsg' : messages[firstAssistantIndex - 1].role;
}
}
if (lastSystemIdx >= 0) {
messages.splice(0, lastSystemIdx + 1);
}
}
let requestPrompt = messages.map((v) => {
let prefix = '';
switch (v.role) {
case 'assistant':
prefix = '\n\nAssistant: ';
break;
case 'user':
prefix = '\n\nHuman: ';
break;
case 'system':
// According to the Claude docs, H: and A: should be used for example conversations.
if (v.name === 'example_assistant') {
prefix = '\n\nA: ';
} else if (v.name === 'example_user') {
prefix = '\n\nH: ';
} else {
prefix = '\n\n';
}
break;
}
return prefix + v.content;
// Convert messages to the prompt.
let requestPrompt = messages.map((v, i) => {
// Set prefix according to the role.
let prefix = {
'assistant': '\n\nAssistant: ',
'user': '\n\nHuman: ',
'system': i === 0 ? '' : v.name === 'example_assistant' ? '\n\nA: ' : v.name === 'example_user' ? '\n\nH: ' : '\n\n',
'FixHumMsg': '\n\nFirst message: ',
}[v.role] ?? '';
// Claude doesn't support message names, so we'll just add them to the message content.
return `${prefix}${v.name && v.role !== 'system' ? `${v.name}: ` : ''}${v.content}`;
}).join('');
if (addHumanPrefix) {
requestPrompt = '\n\nHuman: ' + requestPrompt;
}
if (addAssistantPostfix) {
requestPrompt = requestPrompt + '\n\nAssistant: ';
}
if (withSystemPrompt) {
requestPrompt = systemPrompt + requestPrompt;
}
return requestPrompt;
}

View File

@@ -25,7 +25,9 @@ const SECRET_KEYS = {
DEEPLX_URL: 'deeplx_url',
MAKERSUITE: 'api_key_makersuite',
SERPAPI: 'api_key_serpapi',
TOGETHERAI: 'api_key_togetherai',
MISTRALAI: 'api_key_mistralai',
CUSTOM: 'api_key_custom',
};
/**

View File

@@ -1,11 +1,12 @@
const express = require('express');
const fetch = require('node-fetch').default;
const sanitize = require('sanitize-filename');
const { getBasicAuthHeader, delay } = require('../util.js');
const { getBasicAuthHeader, delay, getHexString } = require('../util.js');
const fs = require('fs');
const { DIRECTORIES } = require('../constants.js');
const writeFileAtomicSync = require('write-file-atomic').sync;
const { jsonParser } = require('../express-common');
const { readSecret, SECRET_KEYS } = require('./secrets.js');
/**
* Sanitizes a string.
@@ -545,6 +546,99 @@ comfy.post('/generate', jsonParser, async (request, response) => {
}
});
// Sub-router for the TogetherAI endpoints.
const together = express.Router();

// POST /models
// Lists the TogetherAI models whose display type is 'image' as { value, text } pairs.
// Responds 400 without a stored key and 500 on upstream or parsing failures.
together.post('/models', jsonParser, async (_, response) => {
    try {
        const apiKey = readSecret(SECRET_KEYS.TOGETHERAI);

        if (!apiKey) {
            console.log('TogetherAI key not found.');
            return response.sendStatus(400);
        }

        const headers = { 'Authorization': `Bearer ${apiKey}` };
        const upstream = await fetch('https://api.together.xyz/api/models', { method: 'GET', headers });

        if (!upstream.ok) {
            console.log('TogetherAI returned an error.');
            return response.sendStatus(500);
        }

        const payload = await upstream.json();

        if (!Array.isArray(payload)) {
            console.log('TogetherAI returned invalid data.');
            return response.sendStatus(500);
        }

        // Keep image models only and reduce them to the option shape sent to the client.
        const imageModels = [];
        for (const model of payload) {
            if (model.display_type === 'image') {
                imageModels.push({ value: model.name, text: model.display_name });
            }
        }

        return response.send(imageModels);
    } catch (error) {
        console.log(error);
        return response.sendStatus(500);
    }
});
// POST /generate
// Sends an image generation request to TogetherAI using the prompt and image
// settings from the request body, and responds with the returned JSON when the
// job status is 'finished'. Responds 400 without a stored key and 500 otherwise.
together.post('/generate', jsonParser, async (request, response) => {
    try {
        const apiKey = readSecret(SECRET_KEYS.TOGETHERAI);

        if (!apiKey) {
            console.log('TogetherAI key not found.');
            return response.sendStatus(400);
        }

        console.log('TogetherAI request:', request.body);

        const { prompt, negative_prompt, height, width, model, steps } = request.body;
        const payload = {
            request_type: 'image-model-inference',
            prompt,
            negative_prompt,
            height,
            width,
            model,
            steps,
            n: 1,
            seed: Math.floor(Math.random() * 10_000_000), // Limited to 10000 on playground, works fine with more.
            sessionKey: getHexString(40), // Don't know if that's supposed to be random or not. It works either way.
        };

        const upstream = await fetch('https://api.together.xyz/api/inference', {
            method: 'POST',
            body: JSON.stringify(payload),
            headers: {
                'Content-Type': 'application/json',
                'Authorization': `Bearer ${apiKey}`,
            },
        });

        if (!upstream.ok) {
            console.log('TogetherAI returned an error.');
            return response.sendStatus(500);
        }

        const result = await upstream.json();
        console.log('TogetherAI response:', result);

        if (result.status !== 'finished') {
            console.log('TogetherAI job failed.');
            return response.sendStatus(500);
        }

        return response.send(result);
    } catch (error) {
        console.log(error);
        return response.sendStatus(500);
    }
});
// Mount the comfy and together sub-routers under their own path prefixes.
router.use('/comfy', comfy);
router.use('/together', together);
// Expose the main router to other modules.
module.exports = { router };

View File

@@ -111,7 +111,8 @@ async function generateThumbnail(type, file) {
try {
const quality = getConfigValue('thumbnailsQuality', 95);
const image = await jimp.read(pathToOriginalFile);
buffer = await image.cover(mySize[0], mySize[1]).quality(quality).getBufferAsync('image/jpeg');
const imgType = type == 'avatar' && getConfigValue('avatarThumbnailsPng', false) ? 'image/png' : 'image/jpeg';
buffer = await image.cover(mySize[0], mySize[1]).quality(quality).getBufferAsync(imgType);
}
catch (inner) {
console.warn(`Thumbnailer can not process the image: ${pathToOriginalFile}. Using original size`);

View File

@@ -622,6 +622,10 @@ router.post('/remote/textgenerationwebui/encode', jsonParser, async function (re
url += '/api/extra/tokencount';
args.body = JSON.stringify({ 'prompt': text });
break;
case TEXTGEN_TYPES.LLAMACPP:
url += '/tokenize';
args.body = JSON.stringify({ 'content': text });
break;
default:
url += '/v1/internal/encode';
args.body = JSON.stringify({ 'text': text });
@@ -637,7 +641,7 @@ router.post('/remote/textgenerationwebui/encode', jsonParser, async function (re
}
const data = await result.json();
const count = legacyApi ? data?.results[0]?.tokens : (data?.length ?? data?.value);
const count = legacyApi ? data?.results[0]?.tokens : (data?.length ?? data?.value ?? data?.tokens?.length);
const ids = legacyApi ? [] : (data?.tokens ?? data?.ids ?? []);
return response.send({ count, ids });

View File

@@ -106,6 +106,10 @@ router.post('/deepl', jsonParser, async (request, response) => {
return response.sendStatus(400);
}
if (request.body.lang === 'zh-CN' || request.body.lang === 'zh-TW') {
request.body.lang = 'ZH';
}
const text = request.body.text;
const lang = request.body.lang;
const formality = getConfigValue('deepl.formality', 'default');
@@ -221,7 +225,7 @@ router.post('/deeplx', jsonParser, async (request, response) => {
const text = request.body.text;
let lang = request.body.lang;
if (request.body.lang === 'zh-CN') {
if (request.body.lang === 'zh-CN' || request.body.lang === 'zh-TW') {
lang = 'ZH';
}

View File

@@ -105,6 +105,21 @@ function delay(ms) {
return new Promise(resolve => setTimeout(resolve, ms));
}
/**
 * Builds a string of random lowercase hexadecimal digits.
 * Digits are drawn with Math.random(), so the result is not suitable for
 * security-sensitive purposes.
 * @param {number} length String length
 * @returns {string} Random hex string
 * @example getHexString(8) // 'a1b2c3d4'
 */
function getHexString(length) {
    const HEX_DIGITS = '0123456789abcdef';
    const digits = [];

    while (digits.length < length) {
        const index = Math.floor(Math.random() * HEX_DIGITS.length);
        digits.push(HEX_DIGITS[index]);
    }

    return digits.join('');
}
/**
* Extracts a file with given extension from an ArrayBuffer containing a ZIP archive.
* @param {ArrayBuffer} archiveBuffer Buffer containing a ZIP archive
@@ -384,6 +399,129 @@ function forwardFetchResponse(from, to) {
});
}
/**
 * Copies the properties described by a YAML document onto an object.
 * A YAML mapping is merged directly; for a YAML sequence, every item that is a
 * non-array object is merged in order. Other shapes are ignored, and anything
 * thrown while parsing or merging is swallowed, leaving the call best-effort.
 * @param {object} obj Object to merge into (mutated in place)
 * @param {string} yamlString YAML-serialized object
 * @returns
 */
function mergeObjectWithYaml(obj, yamlString) {
    if (!yamlString) {
        return;
    }

    const isRecord = (candidate) => Boolean(candidate) && typeof candidate === 'object' && !Array.isArray(candidate);

    try {
        const parsed = yaml.parse(yamlString);
        let sources = [];

        if (Array.isArray(parsed)) {
            sources = parsed.filter(isRecord);
        } else if (isRecord(parsed)) {
            sources = [parsed];
        }

        sources.forEach((source) => {
            Object.assign(obj, source);
        });
    } catch {
        // Do nothing
    }
}
/**
 * Removes keys from the object as directed by a YAML document.
 * Accepted shapes: a sequence of key names, a mapping (its own keys are removed,
 * values are ignored), or a single string naming one key. Any other parse result
 * leaves the object untouched. Anything thrown while parsing or deleting is
 * swallowed, so the call is best-effort.
 * @param {object} obj Object to remove keys from (mutated in place)
 * @param {string} yamlString YAML-serialized array
 * @returns {void} Nothing
 */
function excludeKeysByYaml(obj, yamlString) {
    if (!yamlString) {
        return;
    }

    try {
        const parsedObject = yaml.parse(yamlString);

        if (Array.isArray(parsedObject)) {
            parsedObject.forEach(key => {
                delete obj[key];
            });
        } else if (parsedObject && typeof parsedObject === 'object') {
            // The truthiness check excludes null: typeof null is 'object', and
            // Object.keys(null) would throw and rely on the catch below.
            Object.keys(parsedObject).forEach(key => {
                delete obj[key];
            });
        } else if (typeof parsedObject === 'string') {
            delete obj[parsedObject];
        }
    } catch {
        // Do nothing
    }
}
/**
 * Strips one trailing slash and then one trailing "/v1" segment from a string.
 * Nullish input is treated as an empty string; other values are stringified.
 * @param {string} str Input string
 * @returns {string} Trimmed string
 */
function trimV1(str) {
    let trimmed = String(str ?? '');

    if (trimmed.endsWith('/')) {
        trimmed = trimmed.slice(0, -1);
    }

    if (trimmed.endsWith('/v1')) {
        trimmed = trimmed.slice(0, -3);
    }

    return trimmed;
}
/**
 * Minimal in-memory cache whose entries expire a fixed time after being set.
 */
class Cache {
    /**
     * @param {number} ttl Time to live in milliseconds
     */
    constructor(ttl) {
        this.cache = new Map();
        this.ttl = ttl;
    }

    /**
     * Looks up a value. Missing or expired entries yield null, and the key is
     * dropped from the underlying map in that case.
     * @param {string} key Cache key
     */
    get(key) {
        const entry = this.cache.get(key);
        const isFresh = entry?.expiry > Date.now();

        if (!isFresh) {
            // Cache miss or expired, remove the key
            this.cache.delete(key);
            return null;
        }

        return entry.value;
    }

    /**
     * Stores a value together with its expiry time (now + ttl).
     * @param {string} key Key
     * @param {object} value Value
     */
    set(key, value) {
        const expiry = Date.now() + this.ttl;
        this.cache.set(key, { value, expiry });
    }

    /**
     * Deletes a single entry.
     * @param {string} key Key
     */
    remove(key) {
        this.cache.delete(key);
    }

    /**
     * Deletes every entry.
     */
    clear() {
        this.cache.clear();
    }
}
module.exports = {
getConfig,
getConfigValue,
@@ -404,4 +542,9 @@ module.exports = {
removeOldBackups,
getImages,
forwardFetchResponse,
getHexString,
mergeObjectWithYaml,
excludeKeysByYaml,
trimV1,
Cache,
};