mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-03-15 11:30:09 +01:00
Use correct tokenizers for logit bias for Mistral and Llama models over OpenRouter
This commit is contained in:
parent
1f36fe5193
commit
0e89bf90bc
@ -59,7 +59,7 @@ import {
|
||||
resetScrollHeight,
|
||||
stringFormat,
|
||||
} from "./utils.js";
|
||||
import { countTokensOpenAI } from "./tokenizers.js";
|
||||
import { countTokensOpenAI, getTokenizerModel } from "./tokenizers.js";
|
||||
import { formatInstructModeChat, formatInstructModeExamples, formatInstructModePrompt, formatInstructModeSystemPrompt } from "./instruct-mode.js";
|
||||
|
||||
export {
|
||||
@ -1541,7 +1541,7 @@ async function calculateLogitBias() {
|
||||
let result = {};
|
||||
|
||||
try {
|
||||
const reply = await fetch(`/openai_bias?model=${oai_settings.openai_model}`, {
|
||||
const reply = await fetch(`/openai_bias?model=${getTokenizerModel()}`, {
|
||||
method: 'POST',
|
||||
headers: getRequestHeaders(),
|
||||
body,
|
||||
|
94
server.js
94
server.js
@ -57,7 +57,7 @@ const statsHelpers = require('./statsHelpers.js');
|
||||
const { readSecret, migrateSecrets, SECRET_KEYS } = require('./src/secrets');
|
||||
const { delay, getVersion, deepMerge } = require('./src/util');
|
||||
const { invalidateThumbnail, ensureThumbnailCache } = require('./src/thumbnails');
|
||||
const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS } = require('./src/tokenizers');
|
||||
const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS, getSentencepiceTokenizer, sentencepieceTokenizers } = require('./src/tokenizers');
|
||||
const { convertClaudePrompt } = require('./src/chat-completion');
|
||||
|
||||
// Work around a node v20.0.0, v20.1.0, and v20.2.0 bug. The issue was fixed in v20.3.0.
|
||||
@ -2762,57 +2762,71 @@ app.post("/openai_bias", jsonParser, async function (request, response) {
|
||||
if (!request.body || !Array.isArray(request.body))
|
||||
return response.sendStatus(400);
|
||||
|
||||
let result = {};
|
||||
try {
|
||||
const result = {};
|
||||
const model = getTokenizerModel(String(request.query.model || ''));
|
||||
|
||||
const model = getTokenizerModel(String(request.query.model || ''));
|
||||
|
||||
// no bias for claude
|
||||
if (model == 'claude') {
|
||||
return response.send(result);
|
||||
}
|
||||
|
||||
const tokenizer = getTiktokenTokenizer(model);
|
||||
|
||||
for (const entry of request.body) {
|
||||
if (!entry || !entry.text) {
|
||||
continue;
|
||||
// no bias for claude
|
||||
if (model == 'claude') {
|
||||
return response.send(result);
|
||||
}
|
||||
|
||||
try {
|
||||
const tokens = getEntryTokens(entry.text);
|
||||
let encodeFunction;
|
||||
|
||||
for (const token of tokens) {
|
||||
result[token] = entry.value;
|
||||
if (sentencepieceTokenizers.includes(model)) {
|
||||
const tokenizer = getSentencepiceTokenizer(model);
|
||||
encodeFunction = (text) => new Uint32Array(tokenizer.encodeIds(text));
|
||||
} else {
|
||||
const tokenizer = getTiktokenTokenizer(model);
|
||||
encodeFunction = (tokenizer.encode.bind(tokenizer));
|
||||
}
|
||||
|
||||
|
||||
for (const entry of request.body) {
|
||||
if (!entry || !entry.text) {
|
||||
continue;
|
||||
}
|
||||
} catch {
|
||||
console.warn('Tokenizer failed to encode:', entry.text);
|
||||
}
|
||||
}
|
||||
|
||||
// not needed for cached tokenizers
|
||||
//tokenizer.free();
|
||||
return response.send(result);
|
||||
|
||||
/**
|
||||
* Gets tokenids for a given entry
|
||||
* @param {string} text Entry text
|
||||
* @returns {Uint32Array} Array of token ids
|
||||
*/
|
||||
function getEntryTokens(text) {
|
||||
// Get raw token ids from JSON array
|
||||
if (text.trim().startsWith('[') && text.trim().endsWith(']')) {
|
||||
try {
|
||||
const json = JSON.parse(text);
|
||||
if (Array.isArray(json) && json.every(x => typeof x === 'number')) {
|
||||
return new Uint32Array(json);
|
||||
const tokens = getEntryTokens(entry.text, encodeFunction);
|
||||
|
||||
for (const token of tokens) {
|
||||
result[token] = entry.value;
|
||||
}
|
||||
} catch {
|
||||
// ignore
|
||||
console.warn('Tokenizer failed to encode:', entry.text);
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise, get token ids from tokenizer
|
||||
return tokenizer.encode(text);
|
||||
// not needed for cached tokenizers
|
||||
//tokenizer.free();
|
||||
return response.send(result);
|
||||
|
||||
/**
 * Gets token ids for a given logit bias entry.
 *
 * When the text is a JSON array of valid token ids (e.g. "[1, 2, 3]"), those ids are
 * used verbatim. Any other text is handed to the supplied encode function.
 * @param {string} text Entry text
 * @param {(string) => Uint32Array} encode Function to encode text to token ids
 * @returns {Uint32Array} Array of token ids
 */
function getEntryTokens(text, encode) {
    const MAX_TOKEN_ID = 0xFFFFFFFF;
    // A raw id must survive the Uint32Array conversion unchanged; negative, fractional
    // or oversized numbers would otherwise be silently wrapped into a different token id.
    const isTokenId = (x) => Number.isInteger(x) && x >= 0 && x <= MAX_TOKEN_ID;

    const trimmed = text.trim();

    // Get raw token ids from JSON array
    if (trimmed.startsWith('[') && trimmed.endsWith(']')) {
        try {
            const json = JSON.parse(trimmed);
            if (Array.isArray(json) && json.every(isTokenId)) {
                return new Uint32Array(json);
            }
        } catch {
            // Not valid JSON: deliberately fall through and treat the entry as plain text
        }
    }

    // Otherwise, get token ids from tokenizer
    return encode(text);
}
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
return response.send({});
|
||||
}
|
||||
});
|
||||
|
||||
|
@ -60,6 +60,36 @@ async function loadSentencepieceTokenizer(modelPath) {
|
||||
}
|
||||
};
|
||||
|
||||
/**
 * Names of the tokenizer models that are handled by Sentencepiece.
 * The logit bias endpoint in server.js checks this list before asking for a Sentencepiece tokenizer.
 */
const sentencepieceTokenizers = [
    'llama',
    'nerdstash',
    'nerdstash_v2',
    'mistral',
];

/**
 * Gets the Sentencepiece tokenizer by the model name.
 * @param {string} model Sentencepiece model name
 * @returns {*} Sentencepiece tokenizer, or undefined when the model name contains none of the known tokenizer names
 */
function getSentencepiceTokenizer(model) {
    if (model.includes('llama')) {
        return spp_llama;
    }

    // 'nerdstash_v2' contains 'nerdstash' as a substring, so the more specific
    // name has to be tested first or the v2 tokenizer can never be selected.
    if (model.includes('nerdstash_v2')) {
        return spp_nerd_v2;
    }

    if (model.includes('nerdstash')) {
        return spp_nerd;
    }

    if (model.includes('mistral')) {
        return spp_mistral;
    }

    return undefined;
}
|
||||
|
||||
async function countSentencepieceTokens(spp, text) {
|
||||
// Fallback to strlen estimation
|
||||
if (!spp) {
|
||||
@ -438,5 +468,7 @@ module.exports = {
|
||||
countClaudeTokens,
|
||||
loadTokenizers,
|
||||
registerEndpoints,
|
||||
getSentencepiceTokenizer,
|
||||
sentencepieceTokenizers,
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user