diff --git a/public/scripts/openai.js b/public/scripts/openai.js
index 1dc07ba3c..d96e851e4 100644
--- a/public/scripts/openai.js
+++ b/public/scripts/openai.js
@@ -59,7 +59,7 @@ import {
     resetScrollHeight,
     stringFormat,
 } from "./utils.js";
-import { countTokensOpenAI } from "./tokenizers.js";
+import { countTokensOpenAI, getTokenizerModel } from "./tokenizers.js";
 import { formatInstructModeChat, formatInstructModeExamples, formatInstructModePrompt, formatInstructModeSystemPrompt } from "./instruct-mode.js";

 export {
@@ -1541,7 +1541,7 @@ async function calculateLogitBias() {
     let result = {};

     try {
-        const reply = await fetch(`/openai_bias?model=${oai_settings.openai_model}`, {
+        const reply = await fetch(`/openai_bias?model=${getTokenizerModel()}`, {
             method: 'POST',
             headers: getRequestHeaders(),
             body,
diff --git a/server.js b/server.js
index 0c908e55d..fe8407446 100644
--- a/server.js
+++ b/server.js
@@ -57,7 +57,7 @@ const statsHelpers = require('./statsHelpers.js');
 const { readSecret, migrateSecrets, SECRET_KEYS } = require('./src/secrets');
 const { delay, getVersion, deepMerge } = require('./src/util');
 const { invalidateThumbnail, ensureThumbnailCache } = require('./src/thumbnails');
-const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS } = require('./src/tokenizers');
+const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS, getSentencepiceTokenizer, sentencepieceTokenizers } = require('./src/tokenizers');
 const { convertClaudePrompt } = require('./src/chat-completion');

 // Work around a node v20.0.0, v20.1.0, and v20.2.0 bug. The issue was fixed in v20.3.0.
@@ -2762,57 +2762,70 @@ app.post("/openai_bias", jsonParser, async function (request, response) {
     if (!request.body || !Array.isArray(request.body))
         return response.sendStatus(400);

-    let result = {};
+    try {
+        const result = {};
+        const model = getTokenizerModel(String(request.query.model || ''));

-    const model = getTokenizerModel(String(request.query.model || ''));
-
-    // no bias for claude
-    if (model == 'claude') {
-        return response.send(result);
-    }
-
-    const tokenizer = getTiktokenTokenizer(model);
-
-    for (const entry of request.body) {
-        if (!entry || !entry.text) {
-            continue;
+        // no bias for claude
+        if (model == 'claude') {
+            return response.send(result);
         }

-        try {
-            const tokens = getEntryTokens(entry.text);
+        let encodeFunction;

-            for (const token of tokens) {
-                result[token] = entry.value;
+        if (sentencepieceTokenizers.includes(model)) {
+            const tokenizer = getSentencepiceTokenizer(model);
+            encodeFunction = (text) => new Uint32Array(tokenizer.encodeIds(text));
+        } else {
+            const tokenizer = getTiktokenTokenizer(model);
+            encodeFunction = tokenizer.encode.bind(tokenizer);
+        }
+
+        for (const entry of request.body) {
+            if (!entry || !entry.text) {
+                continue;
             }
-        } catch {
-            console.warn('Tokenizer failed to encode:', entry.text);
-        }
-    }

-    // not needed for cached tokenizers
-    //tokenizer.free();
-
-    return response.send(result);
-
-    /**
-     * Gets tokenids for a given entry
-     * @param {string} text Entry text
-     * @returns {Uint32Array} Array of token ids
-     */
-    function getEntryTokens(text) {
-        // Get raw token ids from JSON array
-        if (text.trim().startsWith('[') && text.trim().endsWith(']')) {
             try {
-                const json = JSON.parse(text);
-                if (Array.isArray(json) && json.every(x => typeof x === 'number')) {
-                    return new Uint32Array(json);
+                const tokens = getEntryTokens(entry.text, encodeFunction);
+
+                for (const token of tokens) {
+                    result[token] = entry.value;
                 }
             } catch {
-                // ignore
+                console.warn('Tokenizer failed to encode:', entry.text);
             }
         }

-        // Otherwise, get token ids from tokenizer
-        return tokenizer.encode(text);
+        // not needed for cached tokenizers
+        //tokenizer.free();
+
+        return response.send(result);
+
+        /**
+         * Gets token ids for a given entry
+         * @param {string} text Entry text
+         * @param {(text: string) => Uint32Array} encode Function to encode text to token ids
+         * @returns {Uint32Array} Array of token ids
+         */
+        function getEntryTokens(text, encode) {
+            // Get raw token ids from JSON array
+            if (text.trim().startsWith('[') && text.trim().endsWith(']')) {
+                try {
+                    const json = JSON.parse(text);
+                    if (Array.isArray(json) && json.every(x => typeof x === 'number')) {
+                        return new Uint32Array(json);
+                    }
+                } catch {
+                    // ignore
+                }
+            }
+
+            // Otherwise, get token ids from tokenizer
+            return encode(text);
+        }
+    } catch (error) {
+        console.error(error);
+        return response.send({});
     }
 });
diff --git a/src/tokenizers.js b/src/tokenizers.js
index 264e0706d..771da69fe 100644
--- a/src/tokenizers.js
+++ b/src/tokenizers.js
@@ -60,6 +60,37 @@ async function loadSentencepieceTokenizer(modelPath) {
     }
 };

+const sentencepieceTokenizers = [
+    'llama',
+    'nerdstash',
+    'nerdstash_v2',
+    'mistral',
+];
+
+/**
+ * Gets the Sentencepiece tokenizer by the model name.
+ * @param {string} model Sentencepiece model name
+ * @returns {*} Sentencepiece tokenizer
+ */
+function getSentencepiceTokenizer(model) {
+    if (model.includes('llama')) {
+        return spp_llama;
+    }
+
+    // Check nerdstash_v2 first: 'nerdstash' is a substring and would shadow it
+    if (model.includes('nerdstash_v2')) {
+        return spp_nerd_v2;
+    }
+
+    if (model.includes('nerdstash')) {
+        return spp_nerd;
+    }
+
+    if (model.includes('mistral')) {
+        return spp_mistral;
+    }
+}
+
 async function countSentencepieceTokens(spp, text) {
     // Fallback to strlen estimation
     if (!spp) {
@@ -438,5 +469,7 @@ module.exports = {
     countClaudeTokens,
     loadTokenizers,
     registerEndpoints,
+    getSentencepiceTokenizer,
+    sentencepieceTokenizers,
 }
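
A quick way to sanity-check the new sentencepiece path is to hit the endpoint directly. The sketch below is illustrative and not part of the diff: it assumes a locally running server, that the name 'llama' survives the server-side getTokenizerModel() mapping, and that the llama tokenizer has been loaded; the entry texts, token ids, and bias values are made up.

// Hypothetical smoke test for /openai_bias (assumptions noted above).
// Plain-text entries are tokenized server-side; a JSON array of numbers
// is passed through as raw token ids (see getEntryTokens in the diff).
const entries = [
    { text: 'Hello world', value: -50 },
    { text: '[1, 2, 3]', value: 100 },
];

const response = await fetch('/openai_bias?model=llama', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' }, // jsonParser expects a JSON body
    body: JSON.stringify(entries),
});

// The endpoint replies with a token id -> bias value map,
// e.g. { '1': 100, '2': 100, '3': 100, ... }
console.log(await response.json());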