New AI21 Jamba + tokenizer

Cohee
2024-08-26 12:07:36 +03:00
parent ff834efde3
commit 5fc16a2474
10 changed files with 188 additions and 266 deletions


@@ -27,6 +27,7 @@ export const tokenizers = {
     CLAUDE: 11,
     LLAMA3: 12,
     GEMMA: 13,
+    JAMBA: 14,
     BEST_MATCH: 99,
 };
@@ -36,6 +37,7 @@ export const SENTENCEPIECE_TOKENIZERS = [
     tokenizers.YI,
     tokenizers.LLAMA3,
     tokenizers.GEMMA,
+    tokenizers.JAMBA,
     // uncomment when NovelAI releases Kayra and Clio weights, lol
     //tokenizers.NERD,
     //tokenizers.NERD2,
@@ -98,6 +100,11 @@ const TOKENIZER_URLS = {
         decode: '/api/tokenizers/gemma/decode',
         count: '/api/tokenizers/gemma/encode',
     },
+    [tokenizers.JAMBA]: {
+        encode: '/api/tokenizers/jamba/encode',
+        decode: '/api/tokenizers/jamba/decode',
+        count: '/api/tokenizers/jamba/encode',
+    },
     [tokenizers.API_TEXTGENERATIONWEBUI]: {
         encode: '/api/tokenizers/remote/textgenerationwebui/encode',
         count: '/api/tokenizers/remote/textgenerationwebui/encode',
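The new `[tokenizers.JAMBA]` entry follows the shape of the other local tokenizer entries: there is no dedicated count route, so `count` reuses the `encode` endpoint and the caller takes the length of the result. Below is a minimal sketch of how a client could resolve the entry and count tokens with it; the helper is hypothetical, and the `{ text }` request body and `ids` response field are assumptions modeled on the other `/api/tokenizers/*` encode routes.

```js
// Hypothetical helper, not part of this commit. Resolves the Jamba entry
// from TOKENIZER_URLS and counts tokens by calling its encode/count route.
// Assumes the route accepts { text } and replies with { ids: number[] }.
async function countJambaTokens(text) {
    const endpoints = TOKENIZER_URLS[tokenizers.JAMBA];
    const response = await fetch(endpoints.count, {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ text }),
    });
    const { ids } = await response.json();
    return ids.length; // token count = number of encoded ids
}
```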
@@ -164,7 +171,7 @@ export function getAvailableTokenizers() {
         tokenizerId: Number(tokenizerOption.value),
         tokenizerKey: Object.entries(tokenizers).find(([_, value]) => value === Number(tokenizerOption.value))[0].toLocaleLowerCase(),
         tokenizerName: tokenizerOption.text,
-    }))
+    }));
 }

 /**
@@ -280,6 +287,12 @@ export function getTokenizerBestMatch(forApi) {
         if (model.includes('gemma')) {
             return tokenizers.GEMMA;
         }
+        if (model.includes('yi')) {
+            return tokenizers.YI;
+        }
+        if (model.includes('jamba')) {
+            return tokenizers.JAMBA;
+        }
     }

     return tokenizers.LLAMA;
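The best-match path now recognizes Yi and Jamba by plain substring checks on the model name, with LLAMA as the final fallback. Isolated as a standalone sketch (hypothetical function, illustration only):

```js
// Illustration only, not part of this commit: the substring heuristic
// the new branches implement. Assumes the model name arrives lowercased.
function guessTokenizerFromModelName(model) {
    if (model.includes('gemma')) return tokenizers.GEMMA;
    if (model.includes('yi')) return tokenizers.YI;
    if (model.includes('jamba')) return tokenizers.JAMBA;
    return tokenizers.LLAMA; // default, as in getTokenizerBestMatch
}

guessTokenizerFromModelName('jamba-1.5-mini'); // -> tokenizers.JAMBA
guessTokenizerFromModelName('yi-34b-chat');    // -> tokenizers.YI
```

Branch order matters: a model name containing more than one of these substrings resolves to whichever check runs first.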
@@ -497,6 +510,7 @@ export function getTokenizerModel() {
     const mistralTokenizer = 'mistral';
     const yiTokenizer = 'yi';
     const gemmaTokenizer = 'gemma';
+    const jambaTokenizer = 'jamba';

     // Assuming no one would use it for different models.. right?
     if (oai_settings.chat_completion_source == chat_completion_sources.SCALE) {
@@ -562,12 +576,19 @@
         else if (oai_settings.openrouter_model.includes('GPT-NeoXT')) {
             return gpt2Tokenizer;
         }
+        else if (oai_settings.openrouter_model.includes('jamba')) {
+            return jambaTokenizer;
+        }
     }

     if (oai_settings.chat_completion_source == chat_completion_sources.MAKERSUITE) {
         return gemmaTokenizer;
     }

+    if (oai_settings.chat_completion_source == chat_completion_sources.AI21) {
+        return jambaTokenizer;
+    }
+
     if (oai_settings.chat_completion_source == chat_completion_sources.CLAUDE) {
         return claudeTokenizer;
     }
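Two paths now resolve to the Jamba tokenizer in `getTokenizerModel()`: an OpenRouter model id containing 'jamba', and the AI21 source itself, which is mapped unconditionally. Illustrative expectations, with settings values assumed for the sketch:

```js
// Illustrative only: expected results under the two new branches.
oai_settings.chat_completion_source = chat_completion_sources.OPENROUTER;
oai_settings.openrouter_model = 'ai21/jamba-1.5-large';
getTokenizerModel(); // -> 'jamba' (substring match on the model id)

oai_settings.chat_completion_source = chat_completion_sources.AI21;
getTokenizerModel(); // -> 'jamba' (source is mapped unconditionally)
```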
@@ -626,16 +647,7 @@ export function getTokenizerModel() {
  * @deprecated Use countTokensOpenAIAsync instead.
  */
 export function countTokensOpenAI(messages, full = false) {
-    const shouldTokenizeAI21 = oai_settings.chat_completion_source === chat_completion_sources.AI21 && oai_settings.use_ai21_tokenizer;
-    const shouldTokenizeGoogle = oai_settings.chat_completion_source === chat_completion_sources.MAKERSUITE && oai_settings.use_google_tokenizer;
-    let tokenizerEndpoint = '';
-    if (shouldTokenizeAI21) {
-        tokenizerEndpoint = '/api/tokenizers/ai21/count';
-    } else if (shouldTokenizeGoogle) {
-        tokenizerEndpoint = `/api/tokenizers/google/count?model=${getTokenizerModel()}&reverse_proxy=${oai_settings.reverse_proxy}&proxy_password=${oai_settings.proxy_password}`;
-    } else {
-        tokenizerEndpoint = `/api/tokenizers/openai/count?model=${getTokenizerModel()}`;
-    }
+    const tokenizerEndpoint = `/api/tokenizers/openai/count?model=${getTokenizerModel()}`;
     const cacheObject = getTokenCacheObject();

     if (!Array.isArray(messages)) {
@@ -647,7 +659,7 @@ export function countTokensOpenAI(messages, full = false) {
     for (const message of messages) {
         const model = getTokenizerModel();

-        if (model === 'claude' || shouldTokenizeAI21 || shouldTokenizeGoogle) {
+        if (model === 'claude') {
             full = true;
         }
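With the provider branches removed, the deprecated synchronous counter always builds the same OpenAI count URL, and the server-side tokenizer is selected solely by the `model` query parameter. For an AI21 session the pieces compose like this (sketch; the resulting URL follows from the `jamba` mapping above):

```js
// Sketch: how the simplified endpoint resolves for an AI21 session, where
// getTokenizerModel() returns 'jamba' (see the branch added above).
const tokenizerEndpoint = `/api/tokenizers/openai/count?model=${getTokenizerModel()}`;
// -> '/api/tokenizers/openai/count?model=jamba'
```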
@@ -687,16 +699,7 @@
  * @returns {Promise<number>} Token count.
  */
 export async function countTokensOpenAIAsync(messages, full = false) {
-    const shouldTokenizeAI21 = oai_settings.chat_completion_source === chat_completion_sources.AI21 && oai_settings.use_ai21_tokenizer;
-    const shouldTokenizeGoogle = oai_settings.chat_completion_source === chat_completion_sources.MAKERSUITE && oai_settings.use_google_tokenizer;
-    let tokenizerEndpoint = '';
-    if (shouldTokenizeAI21) {
-        tokenizerEndpoint = '/api/tokenizers/ai21/count';
-    } else if (shouldTokenizeGoogle) {
-        tokenizerEndpoint = `/api/tokenizers/google/count?model=${getTokenizerModel()}`;
-    } else {
-        tokenizerEndpoint = `/api/tokenizers/openai/count?model=${getTokenizerModel()}`;
-    }
+    const tokenizerEndpoint = `/api/tokenizers/openai/count?model=${getTokenizerModel()}`;
     const cacheObject = getTokenCacheObject();

     if (!Array.isArray(messages)) {
@@ -708,7 +711,7 @@ export async function countTokensOpenAIAsync(messages, full = false) {
     for (const message of messages) {
         const model = getTokenizerModel();

-        if (model === 'claude' || shouldTokenizeAI21 || shouldTokenizeGoogle) {
+        if (model === 'claude') {
             full = true;
         }
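For completeness, a minimal usage sketch of the async counter with OpenAI-style chat messages; the message shape and the effect of `full` are assumptions drawn from the surrounding code, not part of this commit:

```js
// Hypothetical usage sketch, not part of this commit.
const messages = [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'How many tokens is this?' },
];
// Passing full = true asks for a count of the whole array at once instead
// of per-message counting (assumed from the claude branch above).
const count = await countTokensOpenAIAsync(messages, true);
console.log(`Prompt uses ${count} tokens.`);
```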