mirror of
https://github.com/SillyTavern/SillyTavern.git
synced 2025-06-05 21:59:27 +02:00
Use correct tokenizers for logit bias for Mistral and Llama models over OpenRouter
This commit is contained in:
@ -59,7 +59,7 @@ import {
|
|||||||
resetScrollHeight,
|
resetScrollHeight,
|
||||||
stringFormat,
|
stringFormat,
|
||||||
} from "./utils.js";
|
} from "./utils.js";
|
||||||
import { countTokensOpenAI } from "./tokenizers.js";
|
import { countTokensOpenAI, getTokenizerModel } from "./tokenizers.js";
|
||||||
import { formatInstructModeChat, formatInstructModeExamples, formatInstructModePrompt, formatInstructModeSystemPrompt } from "./instruct-mode.js";
|
import { formatInstructModeChat, formatInstructModeExamples, formatInstructModePrompt, formatInstructModeSystemPrompt } from "./instruct-mode.js";
|
||||||
|
|
||||||
export {
|
export {
|
||||||
@ -1541,7 +1541,7 @@ async function calculateLogitBias() {
|
|||||||
let result = {};
|
let result = {};
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const reply = await fetch(`/openai_bias?model=${oai_settings.openai_model}`, {
|
const reply = await fetch(`/openai_bias?model=${getTokenizerModel()}`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: getRequestHeaders(),
|
headers: getRequestHeaders(),
|
||||||
body,
|
body,
|
||||||
|
94
server.js
94
server.js
@ -57,7 +57,7 @@ const statsHelpers = require('./statsHelpers.js');
|
|||||||
const { readSecret, migrateSecrets, SECRET_KEYS } = require('./src/secrets');
|
const { readSecret, migrateSecrets, SECRET_KEYS } = require('./src/secrets');
|
||||||
const { delay, getVersion, deepMerge } = require('./src/util');
|
const { delay, getVersion, deepMerge } = require('./src/util');
|
||||||
const { invalidateThumbnail, ensureThumbnailCache } = require('./src/thumbnails');
|
const { invalidateThumbnail, ensureThumbnailCache } = require('./src/thumbnails');
|
||||||
const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS } = require('./src/tokenizers');
|
const { getTokenizerModel, getTiktokenTokenizer, loadTokenizers, TEXT_COMPLETION_MODELS, getSentencepiceTokenizer, sentencepieceTokenizers } = require('./src/tokenizers');
|
||||||
const { convertClaudePrompt } = require('./src/chat-completion');
|
const { convertClaudePrompt } = require('./src/chat-completion');
|
||||||
|
|
||||||
// Work around a node v20.0.0, v20.1.0, and v20.2.0 bug. The issue was fixed in v20.3.0.
|
// Work around a node v20.0.0, v20.1.0, and v20.2.0 bug. The issue was fixed in v20.3.0.
|
||||||
@ -2762,57 +2762,71 @@ app.post("/openai_bias", jsonParser, async function (request, response) {
|
|||||||
if (!request.body || !Array.isArray(request.body))
|
if (!request.body || !Array.isArray(request.body))
|
||||||
return response.sendStatus(400);
|
return response.sendStatus(400);
|
||||||
|
|
||||||
let result = {};
|
try {
|
||||||
|
const result = {};
|
||||||
|
const model = getTokenizerModel(String(request.query.model || ''));
|
||||||
|
|
||||||
const model = getTokenizerModel(String(request.query.model || ''));
|
// no bias for claude
|
||||||
|
if (model == 'claude') {
|
||||||
// no bias for claude
|
return response.send(result);
|
||||||
if (model == 'claude') {
|
|
||||||
return response.send(result);
|
|
||||||
}
|
|
||||||
|
|
||||||
const tokenizer = getTiktokenTokenizer(model);
|
|
||||||
|
|
||||||
for (const entry of request.body) {
|
|
||||||
if (!entry || !entry.text) {
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
let encodeFunction;
|
||||||
const tokens = getEntryTokens(entry.text);
|
|
||||||
|
|
||||||
for (const token of tokens) {
|
if (sentencepieceTokenizers.includes(model)) {
|
||||||
result[token] = entry.value;
|
const tokenizer = getSentencepiceTokenizer(model);
|
||||||
|
encodeFunction = (text) => new Uint32Array(tokenizer.encodeIds(text));
|
||||||
|
} else {
|
||||||
|
const tokenizer = getTiktokenTokenizer(model);
|
||||||
|
encodeFunction = (tokenizer.encode.bind(tokenizer));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
for (const entry of request.body) {
|
||||||
|
if (!entry || !entry.text) {
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
} catch {
|
|
||||||
console.warn('Tokenizer failed to encode:', entry.text);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// not needed for cached tokenizers
|
|
||||||
//tokenizer.free();
|
|
||||||
return response.send(result);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Gets tokenids for a given entry
|
|
||||||
* @param {string} text Entry text
|
|
||||||
* @returns {Uint32Array} Array of token ids
|
|
||||||
*/
|
|
||||||
function getEntryTokens(text) {
|
|
||||||
// Get raw token ids from JSON array
|
|
||||||
if (text.trim().startsWith('[') && text.trim().endsWith(']')) {
|
|
||||||
try {
|
try {
|
||||||
const json = JSON.parse(text);
|
const tokens = getEntryTokens(entry.text, encodeFunction);
|
||||||
if (Array.isArray(json) && json.every(x => typeof x === 'number')) {
|
|
||||||
return new Uint32Array(json);
|
for (const token of tokens) {
|
||||||
|
result[token] = entry.value;
|
||||||
}
|
}
|
||||||
} catch {
|
} catch {
|
||||||
// ignore
|
console.warn('Tokenizer failed to encode:', entry.text);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise, get token ids from tokenizer
|
// not needed for cached tokenizers
|
||||||
return tokenizer.encode(text);
|
//tokenizer.free();
|
||||||
|
return response.send(result);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets tokenids for a given entry
|
||||||
|
* @param {string} text Entry text
|
||||||
|
* @param {(string) => Uint32Array} encode Function to encode text to token ids
|
||||||
|
* @returns {Uint32Array} Array of token ids
|
||||||
|
*/
|
||||||
|
function getEntryTokens(text, encode) {
|
||||||
|
// Get raw token ids from JSON array
|
||||||
|
if (text.trim().startsWith('[') && text.trim().endsWith(']')) {
|
||||||
|
try {
|
||||||
|
const json = JSON.parse(text);
|
||||||
|
if (Array.isArray(json) && json.every(x => typeof x === 'number')) {
|
||||||
|
return new Uint32Array(json);
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// ignore
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, get token ids from tokenizer
|
||||||
|
return encode(text);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error(error);
|
||||||
|
return response.send({});
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -60,6 +60,36 @@ async function loadSentencepieceTokenizer(modelPath) {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const sentencepieceTokenizers = [
|
||||||
|
'llama',
|
||||||
|
'nerdstash',
|
||||||
|
'nerdstash_v2',
|
||||||
|
'mistral',
|
||||||
|
];
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the Sentencepiece tokenizer by the model name.
|
||||||
|
* @param {string} model Sentencepiece model name
|
||||||
|
* @returns {*} Sentencepiece tokenizer
|
||||||
|
*/
|
||||||
|
function getSentencepiceTokenizer(model) {
|
||||||
|
if (model.includes('llama')) {
|
||||||
|
return spp_llama;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (model.includes('nerdstash')) {
|
||||||
|
return spp_nerd;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (model.includes('mistral')) {
|
||||||
|
return spp_mistral;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (model.includes('nerdstash_v2')) {
|
||||||
|
return spp_nerd_v2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async function countSentencepieceTokens(spp, text) {
|
async function countSentencepieceTokens(spp, text) {
|
||||||
// Fallback to strlen estimation
|
// Fallback to strlen estimation
|
||||||
if (!spp) {
|
if (!spp) {
|
||||||
@ -438,5 +468,7 @@ module.exports = {
|
|||||||
countClaudeTokens,
|
countClaudeTokens,
|
||||||
loadTokenizers,
|
loadTokenizers,
|
||||||
registerEndpoints,
|
registerEndpoints,
|
||||||
|
getSentencepiceTokenizer,
|
||||||
|
sentencepieceTokenizers,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user