Improve tokenizer detection

Cohee
2024-01-05 16:17:06 +02:00
parent a39b6b31f4
commit 86d715cc16
2 changed files with 19 additions and 8 deletions

View File

@@ -15,7 +15,7 @@ import {
     registerDebugFunction,
 } from './power-user.js';
 import EventSourceStream from './sse-stream.js';
-import { SENTENCEPIECE_TOKENIZERS, getTextTokens, tokenizers } from './tokenizers.js';
+import { SENTENCEPIECE_TOKENIZERS, TEXTGEN_TOKENIZERS, getTextTokens, tokenizers } from './tokenizers.js';
 import { getSortableDelay, onlyUnique } from './utils.js';

 export {
@@ -241,6 +241,18 @@ function convertPresets(presets) {
     return Array.isArray(presets) ? presets.map((p) => JSON.parse(p)) : [];
 }

+function getTokenizerForTokenIds() {
+    if (power_user.tokenizer === tokenizers.API_CURRENT && TEXTGEN_TOKENIZERS.includes(settings.type)) {
+        return tokenizers.API_CURRENT;
+    }
+
+    if (SENTENCEPIECE_TOKENIZERS.includes(power_user.tokenizer)) {
+        return power_user.tokenizer;
+    }
+
+    return tokenizers.LLAMA;
+}
+
 /**
  * @returns {string} String with comma-separated banned token IDs
  */
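Note: the new getTokenizerForTokenIds() helper centralizes which tokenizer is used whenever raw token IDs are needed. Below is a self-contained sketch of that decision, using stand-in values for tokenizers, TEXTGEN_TOKENIZERS, SENTENCEPIECE_TOKENIZERS, power_user, and settings; the real ones come from tokenizers.js, power-user.js, and this module's state.

// Stand-in constants for illustration only; the real values are imported, not redefined here.
const tokenizers = { API_CURRENT: 'api_current', LLAMA: 'llama', MISTRAL: 'mistral', YI: 'yi' };
const TEXTGEN_TOKENIZERS = ['ooba', 'tabby', 'koboldcpp', 'llamacpp'];
const SENTENCEPIECE_TOKENIZERS = [tokenizers.LLAMA, tokenizers.MISTRAL, tokenizers.YI];

const power_user = { tokenizer: tokenizers.API_CURRENT }; // user chose the "current API" tokenizer
const settings = { type: 'llamacpp' };                    // connected text-generation backend type

function getTokenizerForTokenIds() {
    // The backend can tokenize for us, so token IDs come from the live API tokenizer.
    if (power_user.tokenizer === tokenizers.API_CURRENT && TEXTGEN_TOKENIZERS.includes(settings.type)) {
        return tokenizers.API_CURRENT;
    }

    // An explicitly selected SentencePiece tokenizer keeps working as before.
    if (SENTENCEPIECE_TOKENIZERS.includes(power_user.tokenizer)) {
        return power_user.tokenizer;
    }

    // Everything else falls back to the Llama tokenizer.
    return tokenizers.LLAMA;
}

console.log(getTokenizerForTokenIds()); // 'api_current' here; 'llama' if settings.type were an unsupported backend

With the old ternary, any of the API_* entries (then still listed in SENTENCEPIECE_TOKENIZERS) was accepted without checking the backend; now API_CURRENT is only used when the connected backend type actually supports remote tokenization, and other cases fall back to LLAMA.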
@@ -249,7 +261,7 @@ function getCustomTokenBans() {
         return '';
     }

-    const tokenizer = SENTENCEPIECE_TOKENIZERS.includes(power_user.tokenizer) ? power_user.tokenizer : tokenizers.LLAMA;
+    const tokenizer = getTokenizerForTokenIds();
     const result = [];
     const sequences = settings.banned_tokens
         .split('\n')
@@ -301,7 +313,7 @@ function calculateLogitBias() {
         return {};
     }

-    const tokenizer = SENTENCEPIECE_TOKENIZERS.includes(power_user.tokenizer) ? power_user.tokenizer : tokenizers.LLAMA;
+    const tokenizer = getTokenizerForTokenIds();
     const result = {};

     /**
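Both getCustomTokenBans() and calculateLogitBias() now resolve the tokenizer through the shared helper instead of duplicating the SentencePiece-or-Llama ternary. For orientation, here is a hypothetical, trimmed sketch of how the resolved tokenizer feeds the banned-token conversion; it assumes getTextTokens(tokenizer, text) returns an array of token IDs (as the import in this file suggests) and is not the actual function body.

// Hypothetical sketch for illustration; buildBannedTokenIds is not a real function in this file.
function buildBannedTokenIds(bannedTokensText, tokenizer, getTextTokens) {
    const result = [];

    const sequences = bannedTokensText
        .split('\n')                        // one banned sequence per line, as in the hunk above
        .map((line) => line.trim())
        .filter((line) => line.length > 0);

    for (const line of sequences) {
        result.push(...getTextTokens(tokenizer, line)); // assumed to return an array of token IDs
    }

    // Matches the JSDoc above: a string with comma-separated banned token IDs.
    return result.join(',');
}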

View File

@@ -30,14 +30,13 @@ export const SENTENCEPIECE_TOKENIZERS = [
     tokenizers.LLAMA,
     tokenizers.MISTRAL,
     tokenizers.YI,
-    tokenizers.API_CURRENT,
-    tokenizers.API_KOBOLD,
-    tokenizers.API_TEXTGENERATIONWEBUI,
     // uncomment when NovelAI releases Kayra and Clio weights, lol
     //tokenizers.NERD,
     //tokenizers.NERD2,
 ];

+export const TEXTGEN_TOKENIZERS = [OOBA, TABBY, KOBOLDCPP, LLAMACPP];
+
 const TOKENIZER_URLS = {
     [tokenizers.GPT2]: {
         encode: '/api/tokenizers/gpt2/encode',
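The API-backed entries are dropped from SENTENCEPIECE_TOKENIZERS; in their place, TEXTGEN_TOKENIZERS lists the text-generation backend types whose servers can tokenize on the app's behalf. A small illustration of how the new constant is consumed follows; the string values are assumptions, and the real OOBA, TABBY, KOBOLDCPP, and LLAMACPP constants are imported from the text-generation settings rather than redefined.

// Assumed string identifiers for illustration; the real constants are imported, not redefined.
const OOBA = 'ooba', TABBY = 'tabby', KOBOLDCPP = 'koboldcpp', LLAMACPP = 'llamacpp';
const TEXTGEN_TOKENIZERS = [OOBA, TABBY, KOBOLDCPP, LLAMACPP];

// Both new call sites boil down to a membership test on the active backend type:
console.log(TEXTGEN_TOKENIZERS.includes(LLAMACPP));             // true  -> remote tokenizer is usable
console.log(TEXTGEN_TOKENIZERS.includes('some-other-backend')); // false -> fall back to a local tokenizer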
@@ -193,7 +192,7 @@ export function getTokenizerBestMatch(forApi) {
     // - Tokenizer haven't reported an error previously
     const hasTokenizerError = sessionStorage.getItem(TOKENIZER_WARNING_KEY);
     const isConnected = online_status !== 'no_connection';
-    const isTokenizerSupported = [OOBA, TABBY, KOBOLDCPP, LLAMACPP].includes(textgen_settings.type);
+    const isTokenizerSupported = TEXTGEN_TOKENIZERS.includes(textgen_settings.type);

     if (!hasTokenizerError && isConnected) {
         if (forApi === 'kobold' && kai_flags.can_use_tokenization) {
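In getTokenizerBestMatch(), the hard-coded backend list is replaced by the shared constant. A minimal sketch of the gate these three conditions feed, assuming the remote API tokenizer is only preferred when all of them hold; the hunk is truncated here, so this is an illustration rather than the verbatim body.

// Illustration only: combine the three checks visible in the hunk above.
function canUseApiTokenizer(hasTokenizerError, isConnected, isTokenizerSupported) {
    // No stored tokenizer warning, backend reachable, and backend type listed in TEXTGEN_TOKENIZERS.
    return !hasTokenizerError && isConnected && isTokenizerSupported;
}

// e.g. a connected llama.cpp backend with no stored warning:
console.log(canUseApiTokenizer(null, true, true)); // true -> prefer tokenizers.API_CURRENT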