mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-06-05 21:59:27 +02:00
New AI21 Jamba + tokenizer
This commit is contained in:
@ -5,7 +5,7 @@ const Readable = require('stream').Readable;
|
||||
const { jsonParser } = require('../../express-common');
|
||||
const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY, OPENROUTER_HEADERS } = require('../../constants');
|
||||
const { forwardFetchResponse, getConfigValue, tryParse, uuidv4, mergeObjectWithYaml, excludeKeysByYaml, color } = require('../../util');
|
||||
const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, convertCohereTools } = require('../../prompt-converters');
|
||||
const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, convertCohereTools, convertAI21Messages } = require('../../prompt-converters');
|
||||
|
||||
const { readSecret, SECRET_KEYS } = require('../secrets');
|
||||
const { getTokenizerModel, getSentencepiceTokenizer, getTiktokenTokenizer, sentencepieceTokenizers, TEXT_COMPLETION_MODELS } = require('../tokenizers');
|
||||
@ -19,6 +19,7 @@ const API_GROQ = 'https://api.groq.com/openai/v1';
|
||||
const API_MAKERSUITE = 'https://generativelanguage.googleapis.com';
|
||||
const API_01AI = 'https://api.01.ai/v1';
|
||||
const API_BLOCKENTROPY = 'https://api.blockentropy.ai/v1';
|
||||
const API_AI21 = 'https://api.ai21.com/studio/v1';
|
||||
|
||||
/**
|
||||
* Applies a post-processing step to the generated messages.
|
||||
@ -413,6 +414,16 @@ async function sendAI21Request(request, response) {
|
||||
request.socket.on('close', function () {
|
||||
controller.abort();
|
||||
});
|
||||
const convertedPrompt = convertAI21Messages(request.body.messages, request.body.char_name, request.body.user_name);
|
||||
const body = {
|
||||
messages: convertedPrompt,
|
||||
model: request.body.model,
|
||||
max_tokens: request.body.max_tokens,
|
||||
temperature: request.body.temperature,
|
||||
top_p: request.body.top_p,
|
||||
stop: request.body.stop,
|
||||
stream: request.body.stream,
|
||||
};
|
||||
const options = {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
@ -420,59 +431,33 @@ async function sendAI21Request(request, response) {
|
||||
'content-type': 'application/json',
|
||||
Authorization: `Bearer ${readSecret(request.user.directories, SECRET_KEYS.AI21)}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
numResults: 1,
|
||||
maxTokens: request.body.max_tokens,
|
||||
minTokens: 0,
|
||||
temperature: request.body.temperature,
|
||||
topP: request.body.top_p,
|
||||
stopSequences: request.body.stop_tokens,
|
||||
topKReturn: request.body.top_k,
|
||||
frequencyPenalty: {
|
||||
scale: request.body.frequency_penalty * 100,
|
||||
applyToWhitespaces: false,
|
||||
applyToPunctuations: false,
|
||||
applyToNumbers: false,
|
||||
applyToStopwords: false,
|
||||
applyToEmojis: false,
|
||||
},
|
||||
presencePenalty: {
|
||||
scale: request.body.presence_penalty,
|
||||
applyToWhitespaces: false,
|
||||
applyToPunctuations: false,
|
||||
applyToNumbers: false,
|
||||
applyToStopwords: false,
|
||||
applyToEmojis: false,
|
||||
},
|
||||
countPenalty: {
|
||||
scale: request.body.count_pen,
|
||||
applyToWhitespaces: false,
|
||||
applyToPunctuations: false,
|
||||
applyToNumbers: false,
|
||||
applyToStopwords: false,
|
||||
applyToEmojis: false,
|
||||
},
|
||||
prompt: request.body.messages,
|
||||
}),
|
||||
body: JSON.stringify(body),
|
||||
signal: controller.signal,
|
||||
};
|
||||
|
||||
fetch(`https://api.ai21.com/studio/v1/${request.body.model}/complete`, options)
|
||||
.then(r => r.json())
|
||||
.then(r => {
|
||||
if (r.completions === undefined) {
|
||||
console.log(r);
|
||||
} else {
|
||||
console.log(r.completions[0].data.text);
|
||||
}
|
||||
const reply = { choices: [{ 'message': { 'content': r.completions?.[0]?.data?.text } }] };
|
||||
return response.send(reply);
|
||||
})
|
||||
.catch(err => {
|
||||
console.error(err);
|
||||
return response.send({ error: true });
|
||||
});
|
||||
console.log('AI21 request:', body);
|
||||
|
||||
try{
|
||||
const generateResponse = await fetch(API_AI21 + '/chat/completions', options);
|
||||
if (request.body.stream) {
|
||||
forwardFetchResponse(generateResponse, response);
|
||||
} else {
|
||||
if (!generateResponse.ok) {
|
||||
console.log(`AI21 API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`);
|
||||
return response.status(500).send({ error: true });
|
||||
}
|
||||
const generateResponseJson = await generateResponse.json();
|
||||
console.log('AI21 response:', generateResponseJson);
|
||||
return response.send(generateResponseJson);
|
||||
}
|
||||
} catch (error) {
|
||||
console.log('Error communicating with MistralAI API: ', error);
|
||||
if (!response.headersSent) {
|
||||
response.send({ error: true });
|
||||
} else {
|
||||
response.end();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
Reference in New Issue
Block a user