Use correct tokenizers for logit bias for Mistral and Llama models over OpenRouter

Cohee 2023-11-09 01:03:54 +02:00
parent 1f36fe5193
commit 0e89bf90bc
3 changed files with 88 additions and 42 deletions
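In short: bias entries were previously always tokenized with a tiktoken (OpenAI) tokenizer, so for OpenRouter-served Llama and Mistral models the resulting token ids came from the wrong vocabulary. After this commit, the client sends a normalized tokenizer name (via getTokenizerModel()) instead of the raw model id, and the server encodes with the matching SentencePiece tokenizer whenever the model appears in the new sentencepieceTokenizers list, falling back to tiktoken otherwise.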

public/script.js View File

@@ -59,7 +59,7 @@ import {
     resetScrollHeight,
     stringFormat,
 } from "./utils.js";
-import { countTokensOpenAI } from "./tokenizers.js";
+import { countTokensOpenAI, getTokenizerModel } from "./tokenizers.js";
 import { formatInstructModeChat, formatInstructModeExamples, formatInstructModePrompt, formatInstructModeSystemPrompt } from "./instruct-mode.js";
 
 export {
@@ -1541,7 +1541,7 @@ async function calculateLogitBias() {
     let result = {};
     try {
-        const reply = await fetch(`/openai_bias?model=${oai_settings.openai_model}`, {
+        const reply = await fetch(`/openai_bias?model=${getTokenizerModel()}`, {
             method: 'POST',
             headers: getRequestHeaders(),
             body,
server.js View File

@@ -57,7 +57,7 @@ const statsHelpers = require('./statsHelpers.js');
 const { readSecret, migrateSecrets, SECRET_KEYS } = require('./src/secrets');
 const { delay, getVersion, deepMerge } = require('./src/util');
 const { invalidateThumbnail, ensureThumbnailCache } = require('./src/thumbnails');
-const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS } = require('./src/tokenizers');
+const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS, getSentencepiceTokenizer, sentencepieceTokenizers } = require('./src/tokenizers');
 const { convertClaudePrompt } = require('./src/chat-completion');
 
 // Work around a node v20.0.0, v20.1.0, and v20.2.0 bug. The issue was fixed in v20.3.0.
@@ -2762,57 +2762,71 @@ app.post("/openai_bias", jsonParser, async function (request, response) {
     if (!request.body || !Array.isArray(request.body))
         return response.sendStatus(400);
 
-    let result = {};
-
-    const model = getTokenizerModel(String(request.query.model || ''));
-
-    // no bias for claude
-    if (model == 'claude') {
-        return response.send(result);
-    }
-
-    const tokenizer = getTiktokenTokenizer(model);
-
-    for (const entry of request.body) {
-        if (!entry || !entry.text) {
-            continue;
-        }
-
-        try {
-            const tokens = getEntryTokens(entry.text);
-
-            for (const token of tokens) {
-                result[token] = entry.value;
-            }
-        } catch {
-            console.warn('Tokenizer failed to encode:', entry.text);
-        }
-    }
-
-    // not needed for cached tokenizers
-    //tokenizer.free();
-
-    return response.send(result);
-
-    /**
-     * Gets tokenids for a given entry
-     * @param {string} text Entry text
-     * @returns {Uint32Array} Array of token ids
-     */
-    function getEntryTokens(text) {
-        // Get raw token ids from JSON array
-        if (text.trim().startsWith('[') && text.trim().endsWith(']')) {
-            try {
-                const json = JSON.parse(text);
-                if (Array.isArray(json) && json.every(x => typeof x === 'number')) {
-                    return new Uint32Array(json);
-                }
-            } catch {
-                // ignore
-            }
-        }
-
-        // Otherwise, get token ids from tokenizer
-        return tokenizer.encode(text);
-    }
+    try {
+        const result = {};
+        const model = getTokenizerModel(String(request.query.model || ''));
+
+        // no bias for claude
+        if (model == 'claude') {
+            return response.send(result);
+        }
+
+        let encodeFunction;
+
+        if (sentencepieceTokenizers.includes(model)) {
+            const tokenizer = getSentencepiceTokenizer(model);
+            encodeFunction = (text) => new Uint32Array(tokenizer.encodeIds(text));
+        } else {
+            const tokenizer = getTiktokenTokenizer(model);
+            encodeFunction = (tokenizer.encode.bind(tokenizer));
+        }
+
+        for (const entry of request.body) {
+            if (!entry || !entry.text) {
+                continue;
+            }
+
+            try {
+                const tokens = getEntryTokens(entry.text, encodeFunction);
+
+                for (const token of tokens) {
+                    result[token] = entry.value;
+                }
+            } catch {
+                console.warn('Tokenizer failed to encode:', entry.text);
+            }
+        }
+
+        // not needed for cached tokenizers
+        //tokenizer.free();
+
+        return response.send(result);
+
+        /**
+         * Gets tokenids for a given entry
+         * @param {string} text Entry text
+         * @param {(string) => Uint32Array} encode Function to encode text to token ids
+         * @returns {Uint32Array} Array of token ids
+         */
+        function getEntryTokens(text, encode) {
+            // Get raw token ids from JSON array
+            if (text.trim().startsWith('[') && text.trim().endsWith(']')) {
+                try {
+                    const json = JSON.parse(text);
+                    if (Array.isArray(json) && json.every(x => typeof x === 'number')) {
+                        return new Uint32Array(json);
+                    }
+                } catch {
+                    // ignore
+                }
+            }
+
+            // Otherwise, get token ids from tokenizer
+            return encode(text);
+        }
+    } catch (error) {
+        console.error(error);
+        return response.send({});
+    }
 });
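Taken together, the request/response shape of the updated endpoint looks roughly like this (a hedged usage sketch based on the handler above; the bias values and resulting token ids are made up):

    // Entries may be plain text (tokenized server-side) or a raw JSON
    // array of token ids, which getEntryTokens() passes through as-is.
    const body = JSON.stringify([
        { text: 'Hello', value: -100 },
        { text: '[1, 2, 3]', value: 50 },
    ]);

    const reply = await fetch(`/openai_bias?model=${getTokenizerModel()}`, {
        method: 'POST',
        headers: getRequestHeaders(),
        body,
    });

    // The response maps token id to bias value,
    // e.g. { '123': -100, '1': 50, '2': 50, '3': 50 }
    const bias = await reply.json();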

src/tokenizers.js View File

@@ -60,6 +60,36 @@ async function loadSentencepieceTokenizer(modelPath) {
     }
 };
 
+const sentencepieceTokenizers = [
+    'llama',
+    'nerdstash',
+    'nerdstash_v2',
+    'mistral',
+];
+
+/**
+ * Gets the Sentencepiece tokenizer by the model name.
+ * @param {string} model Sentencepiece model name
+ * @returns {*} Sentencepiece tokenizer
+ */
+function getSentencepiceTokenizer(model) {
+    if (model.includes('llama')) {
+        return spp_llama;
+    }
+
+    // Test 'nerdstash_v2' before 'nerdstash': the substring check for
+    // 'nerdstash' would otherwise match it first.
+    if (model.includes('nerdstash_v2')) {
+        return spp_nerd_v2;
+    }
+
+    if (model.includes('nerdstash')) {
+        return spp_nerd;
+    }
+
+    if (model.includes('mistral')) {
+        return spp_mistral;
+    }
+}
+
 async function countSentencepieceTokens(spp, text) {
     // Fallback to strlen estimation
     if (!spp) {
@@ -438,5 +468,7 @@ module.exports = {
     countClaudeTokens,
     loadTokenizers,
     registerEndpoints,
+    getSentencepiceTokenizer,
+    sentencepieceTokenizers,
 }
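A note on getSentencepiceTokenizer() above: the branches use substring matching, so more specific names have to be tested before their prefixes, which is why 'nerdstash_v2' is checked before 'nerdstash':

    // Substring matching makes branch order significant:
    'nerdstash_v2'.includes('nerdstash'); // true
    // If 'nerdstash' were tested first, the 'nerdstash_v2'
    // branch could never be reached.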