Move tokenizer endpoint and functions to separate file
commit bfdd071001
parent ab9aa28fe4
@@ -101,13 +101,13 @@ function callTokenizer(type, str, padding) {
         case tokenizers.NONE:
             return guesstimate(str) + padding;
         case tokenizers.GPT2:
-            return countTokensRemote('/tokenize_gpt2', str, padding);
+            return countTokensRemote('/api/tokenize/gpt2', str, padding);
         case tokenizers.LLAMA:
-            return countTokensRemote('/tokenize_llama', str, padding);
+            return countTokensRemote('/api/tokenize/llama', str, padding);
         case tokenizers.NERD:
-            return countTokensRemote('/tokenize_nerdstash', str, padding);
+            return countTokensRemote('/api/tokenize/nerdstash', str, padding);
         case tokenizers.NERD2:
-            return countTokensRemote('/tokenize_nerdstash_v2', str, padding);
+            return countTokensRemote('/api/tokenize/nerdstash_v2', str, padding);
         case tokenizers.API:
             return countTokensRemote('/tokenize_via_api', str, padding);
         default:
@@ -264,7 +264,7 @@ export function countTokensOpenAI(messages, full = false) {
         jQuery.ajax({
             async: false,
             type: 'POST', //
-            url: shouldTokenizeAI21 ? '/tokenize_ai21' : `/tokenize_openai?model=${model}`,
+            url: shouldTokenizeAI21 ? '/api/tokenize/ai21' : `/api/tokenize/openai?model=${model}`,
             data: JSON.stringify([message]),
             dataType: "json",
             contentType: "application/json",
@@ -398,13 +398,13 @@ function decodeTextTokensRemote(endpoint, ids) {
 export function getTextTokens(tokenizerType, str) {
     switch (tokenizerType) {
         case tokenizers.GPT2:
-            return getTextTokensRemote('/tokenize_gpt2', str);
+            return getTextTokensRemote('/api/tokenize/gpt2', str);
         case tokenizers.LLAMA:
-            return getTextTokensRemote('/tokenize_llama', str);
+            return getTextTokensRemote('/api/tokenize/llama', str);
         case tokenizers.NERD:
-            return getTextTokensRemote('/tokenize_nerdstash', str);
+            return getTextTokensRemote('/api/tokenize/nerdstash', str);
         case tokenizers.NERD2:
-            return getTextTokensRemote('/tokenize_nerdstash_v2', str);
+            return getTextTokensRemote('/api/tokenize/nerdstash_v2', str);
         default:
             console.warn("Calling getTextTokens with unsupported tokenizer type", tokenizerType);
             return [];
@@ -413,19 +413,19 @@ export function getTextTokens(tokenizerType, str) {

 /**
  * Decodes token ids to text using the remote server API.
- * @param {any} tokenizerType Tokenizer type.
+ * @param {number} tokenizerType Tokenizer type.
  * @param {number[]} ids Array of token ids
  */
 export function decodeTextTokens(tokenizerType, ids) {
     switch (tokenizerType) {
         case tokenizers.GPT2:
-            return decodeTextTokensRemote('/decode_gpt2', ids);
+            return decodeTextTokensRemote('/api/decode/gpt2', ids);
         case tokenizers.LLAMA:
-            return decodeTextTokensRemote('/decode_llama', ids);
+            return decodeTextTokensRemote('/api/decode/llama', ids);
         case tokenizers.NERD:
-            return decodeTextTokensRemote('/decode_nerdstash', ids);
+            return decodeTextTokensRemote('/api/decode/nerdstash', ids);
         case tokenizers.NERD2:
-            return decodeTextTokensRemote('/decode_nerdstash_v2', ids);
+            return decodeTextTokensRemote('/api/decode/nerdstash_v2', ids);
         default:
             console.warn("Calling decodeTextTokens with unsupported tokenizer type", tokenizerType);
             return '';
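For reference, a minimal sketch of what a direct call to one of the renamed endpoints looks like after this change. It uses fetch rather than the jQuery wrapper above, and assumes the request/response shape of the Sentencepiece handlers added in src/tokenizers.js further down ({ text } in, { ids, count } out):

// Illustrative only, not part of the commit.
async function countLlamaTokensViaApi(text) {
    const response = await fetch('/api/tokenize/llama', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ text }),
    });
    const { ids, count } = await response.json();
    return { ids, count };
}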
server.js (339 lines changed)
@@ -44,11 +44,6 @@ const jimp = require('jimp');
 const mime = require('mime-types');
 const PNGtext = require('png-chunk-text');

-// tokenizing related library imports
-const { SentencePieceProcessor } = require("@agnai/sentencepiece-js");
-const tiktoken = require('@dqbd/tiktoken');
-const { Tokenizer } = require('@agnai/web-tokenizers');
-
 // misc/other imports
 const _ = require('lodash');

@@ -64,6 +59,8 @@ const statsHelpers = require('./statsHelpers.js');
 const { readSecret, migrateSecrets, SECRET_KEYS } = require('./src/secrets');
 const { delay, getVersion } = require('./src/util');
 const { invalidateThumbnail, ensureThumbnailCache } = require('./src/thumbnails');
+const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS } = require('./src/tokenizers');
+const { convertClaudePrompt } = require('./src/chat-completion');

 // Work around a node v20.0.0, v20.1.0, and v20.2.0 bug. The issue was fixed in v20.3.0.
 // https://github.com/nodejs/node/issues/47822#issuecomment-1564708870
@@ -165,139 +162,6 @@ function getOverrideHeaders(urlHost) {
     }
 }

-//RossAscends: Added function to format dates used in files and chat timestamps to a humanized format.
-//Mostly I wanted this to be for file names, but couldn't figure out exactly where the filename save code was as everything seemed to be connected.
-//During testing, this performs the same as previous date.now() structure.
-//It also does not break old characters/chats, as the code just uses whatever timestamp exists in the chat.
-//New chats made with characters will use this new formatting.
-//Useable variable is (( humanizedISO8601Datetime ))
-
-const CHARS_PER_TOKEN = 3.35;
-
-let spp_llama;
-let spp_nerd;
-let spp_nerd_v2;
-let claude_tokenizer;
-
-async function loadSentencepieceTokenizer(modelPath) {
-    try {
-        const spp = new SentencePieceProcessor();
-        await spp.load(modelPath);
-        return spp;
-    } catch (error) {
-        console.error("Sentencepiece tokenizer failed to load: " + modelPath, error);
-        return null;
-    }
-};
-
-async function countSentencepieceTokens(spp, text) {
-    // Fallback to strlen estimation
-    if (!spp) {
-        return {
-            ids: [],
-            count: Math.ceil(text.length / CHARS_PER_TOKEN)
-        };
-    }
-
-    let cleaned = text; // cleanText(text); <-- cleaning text can result in an incorrect tokenization
-
-    let ids = spp.encodeIds(cleaned);
-    return {
-        ids,
-        count: ids.length
-    };
-}
-
-async function loadClaudeTokenizer(modelPath) {
-    try {
-        const arrayBuffer = fs.readFileSync(modelPath).buffer;
-        const instance = await Tokenizer.fromJSON(arrayBuffer);
-        return instance;
-    } catch (error) {
-        console.error("Claude tokenizer failed to load: " + modelPath, error);
-        return null;
-    }
-}
-
-function countClaudeTokens(tokenizer, messages) {
-    const convertedPrompt = convertClaudePrompt(messages, false, false);
-
-    // Fallback to strlen estimation
-    if (!tokenizer) {
-        return Math.ceil(convertedPrompt.length / CHARS_PER_TOKEN);
-    }
-
-    const count = tokenizer.encode(convertedPrompt).length;
-    return count;
-}
-
-const tokenizersCache = {};
-
-/**
- * @type {import('@dqbd/tiktoken').TiktokenModel[]}
- */
-const textCompletionModels = [
-    "text-davinci-003",
-    "text-davinci-002",
-    "text-davinci-001",
-    "text-curie-001",
-    "text-babbage-001",
-    "text-ada-001",
-    "code-davinci-002",
-    "code-davinci-001",
-    "code-cushman-002",
-    "code-cushman-001",
-    "text-davinci-edit-001",
-    "code-davinci-edit-001",
-    "text-embedding-ada-002",
-    "text-similarity-davinci-001",
-    "text-similarity-curie-001",
-    "text-similarity-babbage-001",
-    "text-similarity-ada-001",
-    "text-search-davinci-doc-001",
-    "text-search-curie-doc-001",
-    "text-search-babbage-doc-001",
-    "text-search-ada-doc-001",
-    "code-search-babbage-code-001",
-    "code-search-ada-code-001",
-];
-
-function getTokenizerModel(requestModel) {
-    if (requestModel.includes('claude')) {
-        return 'claude';
-    }
-
-    if (requestModel.includes('gpt-4-32k')) {
-        return 'gpt-4-32k';
-    }
-
-    if (requestModel.includes('gpt-4')) {
-        return 'gpt-4';
-    }
-
-    if (requestModel.includes('gpt-3.5-turbo')) {
-        return 'gpt-3.5-turbo';
-    }
-
-    if (textCompletionModels.includes(requestModel)) {
-        return requestModel;
-    }
-
-    // default
-    return 'gpt-3.5-turbo';
-}
-
-function getTiktokenTokenizer(model) {
-    if (tokenizersCache[model]) {
-        return tokenizersCache[model];
-    }
-
-    const tokenizer = tiktoken.encoding_for_model(model);
-    console.log('Instantiated the tokenizer for', model);
-    tokenizersCache[model] = tokenizer;
-    return tokenizer;
-}
-
 function humanizedISO8601DateTime(date) {
     let baseDate = typeof date === 'number' ? new Date(date) : new Date();
     let humanYear = baseDate.getFullYear();
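The CHARS_PER_TOKEN fallback removed above (and carried over verbatim into src/tokenizers.js below) estimates token counts from string length whenever a tokenizer fails to load. A small worked sketch of the arithmetic, purely for illustration:

// Illustrative only: the character-based estimate used when no tokenizer is available.
const CHARS_PER_TOKEN = 3.35;
const text = 'a'.repeat(1000);
const estimatedTokens = Math.ceil(text.length / CHARS_PER_TOKEN);
console.log(estimatedTokens); // 299, i.e. 1000 characters is treated as roughly 299 tokens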
@@ -2838,50 +2702,6 @@ function convertChatMLPrompt(messages) {
     return messageStrings.join("\n");
 }

-// Prompt Conversion script taken from RisuAI by @kwaroran (GPLv3).
-function convertClaudePrompt(messages, addHumanPrefix, addAssistantPostfix) {
-    // Claude doesn't support message names, so we'll just add them to the message content.
-    for (const message of messages) {
-        if (message.name && message.role !== "system") {
-            message.content = message.name + ": " + message.content;
-            delete message.name;
-        }
-    }
-
-    let requestPrompt = messages.map((v) => {
-        let prefix = '';
-        switch (v.role) {
-            case "assistant":
-                prefix = "\n\nAssistant: ";
-                break
-            case "user":
-                prefix = "\n\nHuman: ";
-                break
-            case "system":
-                // According to the Claude docs, H: and A: should be used for example conversations.
-                if (v.name === "example_assistant") {
-                    prefix = "\n\nA: ";
-                } else if (v.name === "example_user") {
-                    prefix = "\n\nH: ";
-                } else {
-                    prefix = "\n\n";
-                }
-                break
-        }
-        return prefix + v.content;
-    }).join('');
-
-    if (addHumanPrefix) {
-        requestPrompt = "\n\nHuman: " + requestPrompt;
-    }
-
-    if (addAssistantPostfix) {
-        requestPrompt = requestPrompt + '\n\nAssistant: ';
-    }
-
-    return requestPrompt;
-}
-
 async function sendScaleRequest(request, response) {

     const api_url = new URL(request.body.api_url_scale).toString();
@@ -3131,7 +2951,7 @@ app.post("/generate_openai", jsonParser, function (request, response_generate_op
         bodyParams['stop'] = request.body.stop;
     }

-    const isTextCompletion = Boolean(request.body.model && textCompletionModels.includes(request.body.model));
+    const isTextCompletion = Boolean(request.body.model && TEXT_COMPLETION_MODELS.includes(request.body.model));
     const textPrompt = isTextCompletion ? convertChatMLPrompt(request.body.messages) : '';
     const endpointUrl = isTextCompletion ? `${api_url}/completions` : `${api_url}/chat/completions`;

@@ -3245,44 +3065,6 @@ app.post("/generate_openai", jsonParser, function (request, response_generate_op
     }
 });

-app.post("/tokenize_openai", jsonParser, function (request, response_tokenize_openai) {
-    if (!request.body) return response_tokenize_openai.sendStatus(400);
-
-    let num_tokens = 0;
-    const model = getTokenizerModel(String(request.query.model || ''));
-
-    if (model == 'claude') {
-        num_tokens = countClaudeTokens(claude_tokenizer, request.body);
-        return response_tokenize_openai.send({ "token_count": num_tokens });
-    }
-
-    const tokensPerName = model.includes('gpt-4') ? 1 : -1;
-    const tokensPerMessage = model.includes('gpt-4') ? 3 : 4;
-    const tokensPadding = 3;
-
-    const tokenizer = getTiktokenTokenizer(model);
-
-    for (const msg of request.body) {
-        try {
-            num_tokens += tokensPerMessage;
-            for (const [key, value] of Object.entries(msg)) {
-                num_tokens += tokenizer.encode(value).length;
-                if (key == "name") {
-                    num_tokens += tokensPerName;
-                }
-            }
-        } catch {
-            console.warn("Error tokenizing message:", msg);
-        }
-    }
-    num_tokens += tokensPadding;
-
-    // not needed for cached tokenizers
-    //tokenizer.free();
-
-    response_tokenize_openai.send({ "token_count": num_tokens });
-});
-
 async function sendAI21Request(request, response) {
     if (!request.body) return response.sendStatus(400);
     const controller = new AbortController();
@@ -3353,109 +3135,6 @@ async function sendAI21Request(request, response) {

 }

-app.post("/tokenize_ai21", jsonParser, async function (request, response_tokenize_ai21) {
-    if (!request.body) return response_tokenize_ai21.sendStatus(400);
-    const options = {
-        method: 'POST',
-        headers: {
-            accept: 'application/json',
-            'content-type': 'application/json',
-            Authorization: `Bearer ${readSecret(SECRET_KEYS.AI21)}`
-        },
-        body: JSON.stringify({ text: request.body[0].content })
-    };
-
-    try {
-        const response = await fetch('https://api.ai21.com/studio/v1/tokenize', options);
-        const data = await response.json();
-        return response_tokenize_ai21.send({ "token_count": data?.tokens?.length || 0 });
-    } catch (err) {
-        console.error(err);
-        return response_tokenize_ai21.send({ "token_count": 0 });
-    }
-});
-
-function createSentencepieceEncodingHandler(getTokenizerFn) {
-    return async function (request, response) {
-        try {
-            if (!request.body) {
-                return response.sendStatus(400);
-            }
-
-            const text = request.body.text || '';
-            const tokenizer = getTokenizerFn();
-            const { ids, count } = await countSentencepieceTokens(tokenizer, text);
-            return response.send({ ids, count });
-        } catch (error) {
-            console.log(error);
-            return response.send({ ids: [], count: 0 });
-        }
-    };
-}
-
-function createSentencepieceDecodingHandler(getTokenizerFn) {
-    return async function (request, response) {
-        try {
-            if (!request.body) {
-                return response.sendStatus(400);
-            }
-
-            const ids = request.body.ids || [];
-            const tokenizer = getTokenizerFn();
-            const text = await tokenizer.decodeIds(ids);
-            return response.send({ text });
-        } catch (error) {
-            console.log(error);
-            return response.send({ text: '' });
-        }
-    };
-}
-
-function createTiktokenEncodingHandler(modelId) {
-    return async function (request, response) {
-        try {
-            if (!request.body) {
-                return response.sendStatus(400);
-            }
-
-            const text = request.body.text || '';
-            const tokenizer = getTiktokenTokenizer(modelId);
-            const tokens = Object.values(tokenizer.encode(text));
-            return response.send({ ids: tokens, count: tokens.length });
-        } catch (error) {
-            console.log(error);
-            return response.send({ ids: [], count: 0 });
-        }
-    }
-}
-
-function createTiktokenDecodingHandler(modelId) {
-    return async function (request, response) {
-        try {
-            if (!request.body) {
-                return response.sendStatus(400);
-            }
-
-            const ids = request.body.ids || [];
-            const tokenizer = getTiktokenTokenizer(modelId);
-            const textBytes = tokenizer.decode(new Uint32Array(ids));
-            const text = new TextDecoder().decode(textBytes);
-            return response.send({ text });
-        } catch (error) {
-            console.log(error);
-            return response.send({ text: '' });
-        }
-    }
-}
-
-app.post("/tokenize_llama", jsonParser, createSentencepieceEncodingHandler(() => spp_llama));
-app.post("/tokenize_nerdstash", jsonParser, createSentencepieceEncodingHandler(() => spp_nerd));
-app.post("/tokenize_nerdstash_v2", jsonParser, createSentencepieceEncodingHandler(() => spp_nerd_v2));
-app.post("/tokenize_gpt2", jsonParser, createTiktokenEncodingHandler('gpt2'));
-app.post("/decode_llama", jsonParser, createSentencepieceDecodingHandler(() => spp_llama));
-app.post("/decode_nerdstash", jsonParser, createSentencepieceDecodingHandler(() => spp_nerd));
-app.post("/decode_nerdstash_v2", jsonParser, createSentencepieceDecodingHandler(() => spp_nerd_v2));
-app.post("/decode_gpt2", jsonParser, createTiktokenDecodingHandler('gpt2'));
-
 app.post("/tokenize_via_api", jsonParser, async function (request, response) {
     if (!request.body) {
         return response.sendStatus(400);
@@ -3491,7 +3170,6 @@ app.post("/tokenize_via_api", jsonParser, async function (request, response) {
     }
 });

-
 // ** REST CLIENT ASYNC WRAPPERS **

 /**
@@ -3519,6 +3197,9 @@ async function postAsync(url, args) { return fetchJSON(url, { method: 'POST', ti

 // ** END **

+// Tokenizers
+require('./src/tokenizers').registerEndpoints(app, jsonParser);
+
 // Preset management
 require('./src/presets').registerEndpoints(app, jsonParser);

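The registerEndpoints call added above follows the same module-registration pattern already used by the presets module. A minimal standalone sketch of that pattern; the route name, port, and express.json() parser here are placeholders, not what server.js actually uses:

// Illustrative only, not part of the commit.
const express = require('express');

function registerEndpoints(app, jsonParser) {
    app.post('/api/tokenize/example', jsonParser, (request, response) => {
        const text = request.body.text || '';
        // A real handler would call a tokenizer here; this just echoes a character count.
        response.send({ count: text.length });
    });
}

const app = express();
const jsonParser = express.json(); // stand-in for the body-parser instance used by server.js
registerEndpoints(app, jsonParser);
app.listen(8000);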
@@ -3585,13 +3266,7 @@ const setupTasks = async function () {
     contentManager.checkForNewContent();
     cleanUploads();

-    [spp_llama, spp_nerd, spp_nerd_v2, claude_tokenizer] = await Promise.all([
-        loadSentencepieceTokenizer('src/sentencepiece/tokenizer.model'),
-        loadSentencepieceTokenizer('src/sentencepiece/nerdstash.model'),
-        loadSentencepieceTokenizer('src/sentencepiece/nerdstash_v2.model'),
-        loadClaudeTokenizer('src/claude.json'),
-    ]);
-
+    await loadTokenizers();
     await statsHelpers.loadStatsFile(DIRECTORIES.chats, DIRECTORIES.characters);

     // Set up event listeners for a graceful shutdown
src/chat-completion.js (new file)
@@ -0,0 +1,54 @@
+/**
+ * Convert a prompt from the ChatML objects to the format used by Claude.
+ * @param {object[]} messages Array of messages
+ * @param {boolean} addHumanPrefix Add Human prefix
+ * @param {boolean} addAssistantPostfix Add Assistant postfix
+ * @returns {string} Prompt for Claude
+ * @copyright Prompt Conversion script taken from RisuAI by kwaroran (GPLv3).
+ */
+function convertClaudePrompt(messages, addHumanPrefix, addAssistantPostfix) {
+    // Claude doesn't support message names, so we'll just add them to the message content.
+    for (const message of messages) {
+        if (message.name && message.role !== "system") {
+            message.content = message.name + ": " + message.content;
+            delete message.name;
+        }
+    }
+
+    let requestPrompt = messages.map((v) => {
+        let prefix = '';
+        switch (v.role) {
+            case "assistant":
+                prefix = "\n\nAssistant: ";
+                break
+            case "user":
+                prefix = "\n\nHuman: ";
+                break
+            case "system":
+                // According to the Claude docs, H: and A: should be used for example conversations.
+                if (v.name === "example_assistant") {
+                    prefix = "\n\nA: ";
+                } else if (v.name === "example_user") {
+                    prefix = "\n\nH: ";
+                } else {
+                    prefix = "\n\n";
+                }
+                break
+        }
+        return prefix + v.content;
+    }).join('');
+
+    if (addHumanPrefix) {
+        requestPrompt = "\n\nHuman: " + requestPrompt;
+    }
+
+    if (addAssistantPostfix) {
+        requestPrompt = requestPrompt + '\n\nAssistant: ';
+    }
+
+    return requestPrompt;
+}
+
+module.exports = {
+    convertClaudePrompt,
+}
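A short usage sketch of the exported function, assuming the file lands at src/chat-completion.js as the require in server.js suggests; the sample messages are made up:

// Illustrative only, not part of the commit.
const { convertClaudePrompt } = require('./src/chat-completion');

const messages = [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', name: 'Alice', content: 'Hello!' },
];

const prompt = convertClaudePrompt(messages, false, true);
// "\n\nYou are a helpful assistant.\n\nHuman: Alice: Hello!\n\nAssistant: "
console.log(JSON.stringify(prompt));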
src/tokenizers.js (new file)
@@ -0,0 +1,334 @@
+const fs = require('fs');
+const { SentencePieceProcessor } = require("@agnai/sentencepiece-js");
+const tiktoken = require('@dqbd/tiktoken');
+const { Tokenizer } = require('@agnai/web-tokenizers');
+const { convertClaudePrompt } = require('./chat-completion');
+const { readSecret, SECRET_KEYS } = require('./secrets');
+
+/**
+ * @type {{[key: string]: import("@dqbd/tiktoken").Tiktoken}} Tokenizers cache
+ */
+const tokenizersCache = {};
+
+/**
+ * @type {string[]}
+ */
+const TEXT_COMPLETION_MODELS = [
+    "text-davinci-003",
+    "text-davinci-002",
+    "text-davinci-001",
+    "text-curie-001",
+    "text-babbage-001",
+    "text-ada-001",
+    "code-davinci-002",
+    "code-davinci-001",
+    "code-cushman-002",
+    "code-cushman-001",
+    "text-davinci-edit-001",
+    "code-davinci-edit-001",
+    "text-embedding-ada-002",
+    "text-similarity-davinci-001",
+    "text-similarity-curie-001",
+    "text-similarity-babbage-001",
+    "text-similarity-ada-001",
+    "text-search-davinci-doc-001",
+    "text-search-curie-doc-001",
+    "text-search-babbage-doc-001",
+    "text-search-ada-doc-001",
+    "code-search-babbage-code-001",
+    "code-search-ada-code-001",
+];
+
+const CHARS_PER_TOKEN = 3.35;
+
+let spp_llama;
+let spp_nerd;
+let spp_nerd_v2;
+let claude_tokenizer;
+
+async function loadSentencepieceTokenizer(modelPath) {
+    try {
+        const spp = new SentencePieceProcessor();
+        await spp.load(modelPath);
+        return spp;
+    } catch (error) {
+        console.error("Sentencepiece tokenizer failed to load: " + modelPath, error);
+        return null;
+    }
+};
+
+async function countSentencepieceTokens(spp, text) {
+    // Fallback to strlen estimation
+    if (!spp) {
+        return {
+            ids: [],
+            count: Math.ceil(text.length / CHARS_PER_TOKEN)
+        };
+    }
+
+    let cleaned = text; // cleanText(text); <-- cleaning text can result in an incorrect tokenization
+
+    let ids = spp.encodeIds(cleaned);
+    return {
+        ids,
+        count: ids.length
+    };
+}
+
+/**
+ * Gets the tokenizer model by the model name.
+ * @param {string} requestModel Models to use for tokenization
+ * @returns {string} Tokenizer model to use
+ */
+function getTokenizerModel(requestModel) {
+    if (requestModel.includes('claude')) {
+        return 'claude';
+    }
+
+    if (requestModel.includes('gpt-4-32k')) {
+        return 'gpt-4-32k';
+    }
+
+    if (requestModel.includes('gpt-4')) {
+        return 'gpt-4';
+    }
+
+    if (requestModel.includes('gpt-3.5-turbo')) {
+        return 'gpt-3.5-turbo';
+    }
+
+    if (TEXT_COMPLETION_MODELS.includes(requestModel)) {
+        return requestModel;
+    }
+
+    // default
+    return 'gpt-3.5-turbo';
+}
+
+function getTiktokenTokenizer(model) {
+    if (tokenizersCache[model]) {
+        return tokenizersCache[model];
+    }
+
+    const tokenizer = tiktoken.encoding_for_model(model);
+    console.log('Instantiated the tokenizer for', model);
+    tokenizersCache[model] = tokenizer;
+    return tokenizer;
+}
+
+async function loadClaudeTokenizer(modelPath) {
+    try {
+        const arrayBuffer = fs.readFileSync(modelPath).buffer;
+        const instance = await Tokenizer.fromJSON(arrayBuffer);
+        return instance;
+    } catch (error) {
+        console.error("Claude tokenizer failed to load: " + modelPath, error);
+        return null;
+    }
+}
+
+function countClaudeTokens(tokenizer, messages) {
+    const convertedPrompt = convertClaudePrompt(messages, false, false);
+
+    // Fallback to strlen estimation
+    if (!tokenizer) {
+        return Math.ceil(convertedPrompt.length / CHARS_PER_TOKEN);
+    }
+
+    const count = tokenizer.encode(convertedPrompt).length;
+    return count;
+}
+
+/**
+ * Creates an API handler for encoding Sentencepiece tokens.
+ * @param {function} getTokenizerFn Tokenizer provider function
+ * @returns {any} Handler function
+ */
+function createSentencepieceEncodingHandler(getTokenizerFn) {
+    return async function (request, response) {
+        try {
+            if (!request.body) {
+                return response.sendStatus(400);
+            }
+
+            const text = request.body.text || '';
+            const tokenizer = getTokenizerFn();
+            const { ids, count } = await countSentencepieceTokens(tokenizer, text);
+            return response.send({ ids, count });
+        } catch (error) {
+            console.log(error);
+            return response.send({ ids: [], count: 0 });
+        }
+    };
+}
+
+/**
+ * Creates an API handler for decoding Sentencepiece tokens.
+ * @param {function} getTokenizerFn Tokenizer provider function
+ * @returns {any} Handler function
+ */
+function createSentencepieceDecodingHandler(getTokenizerFn) {
+    return async function (request, response) {
+        try {
+            if (!request.body) {
+                return response.sendStatus(400);
+            }
+
+            const ids = request.body.ids || [];
+            const tokenizer = getTokenizerFn();
+            const text = await tokenizer.decodeIds(ids);
+            return response.send({ text });
+        } catch (error) {
+            console.log(error);
+            return response.send({ text: '' });
+        }
+    };
+}
+
+/**
+ * Creates an API handler for encoding Tiktoken tokens.
+ * @param {string} modelId Tiktoken model ID
+ * @returns {any} Handler function
+ */
+function createTiktokenEncodingHandler(modelId) {
+    return async function (request, response) {
+        try {
+            if (!request.body) {
+                return response.sendStatus(400);
+            }
+
+            const text = request.body.text || '';
+            const tokenizer = getTiktokenTokenizer(modelId);
+            const tokens = Object.values(tokenizer.encode(text));
+            return response.send({ ids: tokens, count: tokens.length });
+        } catch (error) {
+            console.log(error);
+            return response.send({ ids: [], count: 0 });
+        }
+    }
+}
+
+/**
+ * Creates an API handler for decoding Tiktoken tokens.
+ * @param {string} modelId Tiktoken model ID
+ * @returns {any} Handler function
+ */
+function createTiktokenDecodingHandler(modelId) {
+    return async function (request, response) {
+        try {
+            if (!request.body) {
+                return response.sendStatus(400);
+            }
+
+            const ids = request.body.ids || [];
+            const tokenizer = getTiktokenTokenizer(modelId);
+            const textBytes = tokenizer.decode(new Uint32Array(ids));
+            const text = new TextDecoder().decode(textBytes);
+            return response.send({ text });
+        } catch (error) {
+            console.log(error);
+            return response.send({ text: '' });
+        }
+    }
+}
+
+/**
+ * Loads the model tokenizers.
+ * @returns {Promise<void>} Promise that resolves when the tokenizers are loaded
+ */
+async function loadTokenizers() {
+    [spp_llama, spp_nerd, spp_nerd_v2, claude_tokenizer] = await Promise.all([
+        loadSentencepieceTokenizer('src/sentencepiece/tokenizer.model'),
+        loadSentencepieceTokenizer('src/sentencepiece/nerdstash.model'),
+        loadSentencepieceTokenizer('src/sentencepiece/nerdstash_v2.model'),
+        loadClaudeTokenizer('src/claude.json'),
+    ]);
+}
+
+/**
+ * Registers the tokenization endpoints.
+ * @param {import('express').Express} app Express app
+ * @param {any} jsonParser JSON parser middleware
+ */
+function registerEndpoints(app, jsonParser) {
+    app.post("/api/tokenize/ai21", jsonParser, async function (req, res) {
+        if (!req.body) return res.sendStatus(400);
+        const options = {
+            method: 'POST',
+            headers: {
+                accept: 'application/json',
+                'content-type': 'application/json',
+                Authorization: `Bearer ${readSecret(SECRET_KEYS.AI21)}`
+            },
+            body: JSON.stringify({ text: req.body[0].content })
+        };
+
+        try {
+            const response = await fetch('https://api.ai21.com/studio/v1/tokenize', options);
+            const data = await response.json();
+            return res.send({ "token_count": data?.tokens?.length || 0 });
+        } catch (err) {
+            console.error(err);
+            return res.send({ "token_count": 0 });
+        }
+    });
+
+    app.post("/api/tokenize/llama", jsonParser, createSentencepieceEncodingHandler(() => spp_llama));
+    app.post("/api/tokenize/nerdstash", jsonParser, createSentencepieceEncodingHandler(() => spp_nerd));
+    app.post("/api/tokenize/nerdstash_v2", jsonParser, createSentencepieceEncodingHandler(() => spp_nerd_v2));
+    app.post("/api/tokenize/gpt2", jsonParser, createTiktokenEncodingHandler('gpt2'));
+    app.post("/api/decode/llama", jsonParser, createSentencepieceDecodingHandler(() => spp_llama));
+    app.post("/api/decode/nerdstash", jsonParser, createSentencepieceDecodingHandler(() => spp_nerd));
+    app.post("/api/decode/nerdstash_v2", jsonParser, createSentencepieceDecodingHandler(() => spp_nerd_v2));
+    app.post("/api/decode/gpt2", jsonParser, createTiktokenDecodingHandler('gpt2'));
+
+    app.post("/api/tokenize/openai", jsonParser, function (req, res) {
+        if (!req.body) return res.sendStatus(400);
+
+        let num_tokens = 0;
+        const model = getTokenizerModel(String(req.query.model || ''));
+
+        if (model == 'claude') {
+            num_tokens = countClaudeTokens(claude_tokenizer, req.body);
+            return res.send({ "token_count": num_tokens });
+        }
+
+        const tokensPerName = model.includes('gpt-4') ? 1 : -1;
+        const tokensPerMessage = model.includes('gpt-4') ? 3 : 4;
+        const tokensPadding = 3;
+
+        const tokenizer = getTiktokenTokenizer(model);
+
+        for (const msg of req.body) {
+            try {
+                num_tokens += tokensPerMessage;
+                for (const [key, value] of Object.entries(msg)) {
+                    num_tokens += tokenizer.encode(value).length;
+                    if (key == "name") {
+                        num_tokens += tokensPerName;
+                    }
+                }
+            } catch {
+                console.warn("Error tokenizing message:", msg);
+            }
+        }
+        num_tokens += tokensPadding;
+
+        // not needed for cached tokenizers
+        //tokenizer.free();
+
+        res.send({ "token_count": num_tokens });
+    });
+}
+
+module.exports = {
+    TEXT_COMPLETION_MODELS,
+    getTokenizerModel,
+    getTiktokenTokenizer,
+    loadSentencepieceTokenizer,
+    loadClaudeTokenizer,
+    countSentencepieceTokens,
+    countClaudeTokens,
+    loadTokenizers,
+    registerEndpoints,
+}
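To round this off, a hedged sketch of calling the relocated OpenAI token-count endpoint from outside the app; the host and port are assumptions, and the response shape follows the res.send({ "token_count": ... }) in the handler above:

// Illustrative only, not part of the commit. Assumes the server listens on localhost:8000.
async function countOpenAITokens(messages, model = 'gpt-3.5-turbo') {
    const response = await fetch(`http://localhost:8000/api/tokenize/openai?model=${model}`, {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify(messages), // an array of { role, content, name? } objects
    });
    const { token_count } = await response.json();
    // For gpt-3.5-turbo the handler adds 4 tokens per message, subtracts 1 per name
    // field, and adds 3 tokens of padding on top of the per-field tiktoken counts.
    return token_count;
}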