New AI21 Jamba + tokenizer

This commit is contained in:
Cohee
2024-08-26 12:07:36 +03:00
parent ff834efde3
commit 5fc16a2474
10 changed files with 188 additions and 266 deletions

View File

@ -5,7 +5,7 @@ const Readable = require('stream').Readable;
const { jsonParser } = require('../../express-common');
const { CHAT_COMPLETION_SOURCES, GEMINI_SAFETY, BISON_SAFETY, OPENROUTER_HEADERS } = require('../../constants');
const { forwardFetchResponse, getConfigValue, tryParse, uuidv4, mergeObjectWithYaml, excludeKeysByYaml, color } = require('../../util');
const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, convertCohereTools } = require('../../prompt-converters');
const { convertClaudeMessages, convertGooglePrompt, convertTextCompletionPrompt, convertCohereMessages, convertMistralMessages, convertCohereTools, convertAI21Messages } = require('../../prompt-converters');
const { readSecret, SECRET_KEYS } = require('../secrets');
const { getTokenizerModel, getSentencepiceTokenizer, getTiktokenTokenizer, sentencepieceTokenizers, TEXT_COMPLETION_MODELS } = require('../tokenizers');
@ -19,6 +19,7 @@ const API_GROQ = 'https://api.groq.com/openai/v1';
const API_MAKERSUITE = 'https://generativelanguage.googleapis.com';
const API_01AI = 'https://api.01.ai/v1';
const API_BLOCKENTROPY = 'https://api.blockentropy.ai/v1';
const API_AI21 = 'https://api.ai21.com/studio/v1';
/**
* Applies a post-processing step to the generated messages.
@ -413,6 +414,16 @@ async function sendAI21Request(request, response) {
request.socket.on('close', function () {
controller.abort();
});
const convertedPrompt = convertAI21Messages(request.body.messages, request.body.char_name, request.body.user_name);
const body = {
messages: convertedPrompt,
model: request.body.model,
max_tokens: request.body.max_tokens,
temperature: request.body.temperature,
top_p: request.body.top_p,
stop: request.body.stop,
stream: request.body.stream,
};
const options = {
method: 'POST',
headers: {
@ -420,59 +431,33 @@ async function sendAI21Request(request, response) {
'content-type': 'application/json',
Authorization: `Bearer ${readSecret(request.user.directories, SECRET_KEYS.AI21)}`,
},
body: JSON.stringify({
numResults: 1,
maxTokens: request.body.max_tokens,
minTokens: 0,
temperature: request.body.temperature,
topP: request.body.top_p,
stopSequences: request.body.stop_tokens,
topKReturn: request.body.top_k,
frequencyPenalty: {
scale: request.body.frequency_penalty * 100,
applyToWhitespaces: false,
applyToPunctuations: false,
applyToNumbers: false,
applyToStopwords: false,
applyToEmojis: false,
},
presencePenalty: {
scale: request.body.presence_penalty,
applyToWhitespaces: false,
applyToPunctuations: false,
applyToNumbers: false,
applyToStopwords: false,
applyToEmojis: false,
},
countPenalty: {
scale: request.body.count_pen,
applyToWhitespaces: false,
applyToPunctuations: false,
applyToNumbers: false,
applyToStopwords: false,
applyToEmojis: false,
},
prompt: request.body.messages,
}),
body: JSON.stringify(body),
signal: controller.signal,
};
fetch(`https://api.ai21.com/studio/v1/${request.body.model}/complete`, options)
.then(r => r.json())
.then(r => {
if (r.completions === undefined) {
console.log(r);
} else {
console.log(r.completions[0].data.text);
}
const reply = { choices: [{ 'message': { 'content': r.completions?.[0]?.data?.text } }] };
return response.send(reply);
})
.catch(err => {
console.error(err);
return response.send({ error: true });
});
console.log('AI21 request:', body);
try{
const generateResponse = await fetch(API_AI21 + '/chat/completions', options);
if (request.body.stream) {
forwardFetchResponse(generateResponse, response);
} else {
if (!generateResponse.ok) {
console.log(`AI21 API returned error: ${generateResponse.status} ${generateResponse.statusText} ${await generateResponse.text()}`);
return response.status(500).send({ error: true });
}
const generateResponseJson = await generateResponse.json();
console.log('AI21 response:', generateResponseJson);
return response.send(generateResponseJson);
}
} catch (error) {
console.log('Error communicating with MistralAI API: ', error);
if (!response.headersSent) {
response.send({ error: true });
} else {
response.end();
}
}
}
/**