Add types for SillyTavern.getContext

Cohee committed 2024-12-06 16:41:26 +02:00
parent 3502bfcaa0
commit 77841dbc21
7 changed files with 228 additions and 146 deletions


@@ -8,8 +8,6 @@ import { kai_flags } from './kai-settings.js';
 import { textgen_types, textgenerationwebui_settings as textgen_settings, getTextGenServer, getTextGenModel } from './textgen-settings.js';
 import { getCurrentDreamGenModelTokenizer, getCurrentOpenRouterModelTokenizer, openRouterModels } from './textgen-models.js';
 
-const { OOBA, TABBY, KOBOLDCPP, VLLM, APHRODITE, LLAMACPP, OPENROUTER, DREAMGEN } = textgen_types;
-
 export const CHARACTERS_PER_TOKEN_RATIO = 3.35;
 export const TOKENIZER_WARNING_KEY = 'tokenizationWarningShown';
 export const TOKENIZER_SUPPORTED_KEY = 'tokenizationSupported';
@@ -52,8 +50,12 @@ export const ENCODE_TOKENIZERS = [
     //tokenizers.NERD2,
 ];
 
-// A list of Text Completion sources that support remote tokenization.
-export const TEXTGEN_TOKENIZERS = [OOBA, TABBY, KOBOLDCPP, LLAMACPP, VLLM, APHRODITE];
+/**
+ * A list of Text Completion sources that support remote tokenization.
+ * Populated in initTokenizers due to circular dependencies.
+ * @type {string[]}
+ */
+export const TEXTGEN_TOKENIZERS = [];
 
 const TOKENIZER_URLS = {
     [tokenizers.GPT2]: {
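
Why the list is now filled at init time rather than at module scope: a top-level destructure such as `const { OOBA } = textgen_types;` executes while the module graph is still being evaluated, and in a circular import chain (tokenizers.js and the module providing textgen_types, presumably closed by the new getContext typings) one side can read an uninitialized binding. A minimal sketch of the hazard, with hypothetical file names a.js/b.js standing in for the real modules:

// b.js (hypothetical)
import './a.js';                       // completes the cycle
export const types = { FOO: 'foo', BAR: 'bar' };

// a.js (hypothetical)
import { types } from './b.js';

// If the entry point loads b.js first, a.js evaluates while b.js is still
// mid-evaluation, so a top-level read like the next line throws a
// ReferenceError (temporal dead zone):
// const { FOO } = types;

export const SUPPORTED = [];           // safe: nothing from b.js is read at evaluation time

export function init() {
    // Deferred until the whole module graph has finished evaluating,
    // so `types` is guaranteed to be initialized here.
    SUPPORTED.push(types.FOO, types.BAR);
}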
@@ -287,7 +289,7 @@ export function getTokenizerBestMatch(forApi) {
     const hasTokenizerError = sessionStorage.getItem(TOKENIZER_WARNING_KEY);
     const hasValidEndpoint = sessionStorage.getItem(TOKENIZER_SUPPORTED_KEY);
     const isConnected = online_status !== 'no_connection';
-    const isTokenizerSupported = TEXTGEN_TOKENIZERS.includes(textgen_settings.type) && (textgen_settings.type !== OOBA || hasValidEndpoint);
+    const isTokenizerSupported = TEXTGEN_TOKENIZERS.includes(textgen_settings.type) && (textgen_settings.type !== textgen_types.OOBA || hasValidEndpoint);
 
     if (!hasTokenizerError && isConnected) {
         if (forApi === 'kobold' && kai_flags.can_use_tokenization) {
@@ -297,10 +299,10 @@ export function getTokenizerBestMatch(forApi) {
         if (forApi === 'textgenerationwebui' && isTokenizerSupported) {
             return tokenizers.API_TEXTGENERATIONWEBUI;
         }
-        if (forApi === 'textgenerationwebui' && textgen_settings.type === OPENROUTER) {
+        if (forApi === 'textgenerationwebui' && textgen_settings.type === textgen_types.OPENROUTER) {
             return getCurrentOpenRouterModelTokenizer();
         }
-        if (forApi === 'textgenerationwebui' && textgen_settings.type === DREAMGEN) {
+        if (forApi === 'textgenerationwebui' && textgen_settings.type === textgen_types.DREAMGEN) {
             return getCurrentDreamGenModelTokenizer();
         }
     }
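
One consequence worth noting: until initTokenizers() runs, TEXTGEN_TOKENIZERS is empty, so isTokenizerSupported above is always false and getTokenizerBestMatch() falls through to its other branches. A hedged ordering sketch (not actual SillyTavern startup code; the 'ooba' string value is an assumption about textgen_types.OOBA):

// Hypothetical ESM module consuming tokenizers.js before and after init.
import { TEXTGEN_TOKENIZERS, initTokenizers } from './tokenizers.js';

console.log(TEXTGEN_TOKENIZERS.length);            // 0 — remote tokenization looks unsupported
await initTokenizers();                            // top-level await, ESM only
console.log(TEXTGEN_TOKENIZERS.includes('ooba'));  // true, assuming textgen_types.OOBA === 'ooba'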
@@ -576,7 +578,7 @@ export function getTokenizerModel() {
 
     // And for OpenRouter (if not a site model, then it's impossible to determine the tokenizer)
     if (main_api == 'openai' && oai_settings.chat_completion_source == chat_completion_sources.OPENROUTER && oai_settings.openrouter_model ||
-        main_api == 'textgenerationwebui' && textgen_settings.type === OPENROUTER && textgen_settings.openrouter_model) {
+        main_api == 'textgenerationwebui' && textgen_settings.type === textgen_types.OPENROUTER && textgen_settings.openrouter_model) {
         const model = main_api == 'openai'
             ? model_list.find(x => x.id === oai_settings.openrouter_model)
             : openRouterModels.find(x => x.id === textgen_settings.openrouter_model);
@@ -652,7 +654,7 @@ export function getTokenizerModel() {
         return oai_settings.custom_model;
     }
 
     if (oai_settings.chat_completion_source === chat_completion_sources.PERPLEXITY) {
         if (oai_settings.perplexity_model.includes('llama-3') || oai_settings.perplexity_model.includes('llama3')) {
             return llama3Tokenizer;
         }
@@ -680,7 +682,7 @@ export function getTokenizerModel() {
         return yiTokenizer;
     }
 
     if (oai_settings.chat_completion_source === chat_completion_sources.BLOCKENTROPY) {
         if (oai_settings.blockentropy_model.includes('llama3')) {
             return llama3Tokenizer;
         }
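
The Perplexity and Block Entropy branches follow the same substring dispatch getTokenizerModel() uses elsewhere: match known fragments of the model id, most specific first, and fall back to a default. A generic sketch of the pattern (the rule table and names are illustrative, not the actual mapping):

// Hypothetical substring-dispatch helper; order matters, most specific first.
const TOKENIZER_RULES = [
    { needles: ['llama-3', 'llama3'], tokenizer: 'llama3' },
    { needles: ['llama'], tokenizer: 'llama' },
];

function pickTokenizer(modelId, fallback = 'gpt2') {
    const rule = TOKENIZER_RULES.find(r => r.needles.some(n => modelId.includes(n)));
    return rule ? rule.tokenizer : fallback;
}

console.log(pickTokenizer('llama-3-sonar-large-32k-online'));  // 'llama3'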
@@ -1121,6 +1123,14 @@ export function decodeTextTokens(tokenizerType, ids) {
 }
 
 export async function initTokenizers() {
+    TEXTGEN_TOKENIZERS.push(
+        textgen_types.OOBA,
+        textgen_types.TABBY,
+        textgen_types.KOBOLDCPP,
+        textgen_types.LLAMACPP,
+        textgen_types.VLLM,
+        textgen_types.APHRODITE,
+    );
     await loadTokenCache();
     registerDebugFunction('resetTokenCache', 'Reset token cache', 'Purges the calculated token counts. Use this if you want to force a full re-tokenization of all chats or suspect the token counts are wrong.', resetTokenCache);
 }
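
Since the supported-sources list is now filled inside initTokenizers(), that function must run before any code path consults TEXTGEN_TOKENIZERS or relies on the token cache. The actual call site is outside this diff; a hedged sketch of the expected startup ordering:

// Hypothetical startup sequence; the real call site lives elsewhere in SillyTavern.
import { initTokenizers, getTokenizerBestMatch } from './tokenizers.js';

async function startApp() {
    await initTokenizers();  // fills TEXTGEN_TOKENIZERS and loads the token cache
    const tokenizer = getTokenizerBestMatch('textgenerationwebui');
    // ...continue startup with the selected tokenizer
}

startApp();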