Add Kobold tokenization to the best-match logic. Fix inability to stop group chat regeneration

Cohee
2023-08-24 21:23:35 +03:00
parent 14d94d9108
commit c91ab3b5e0
5 changed files with 44 additions and 19 deletions


@@ -1,12 +1,14 @@
-import { characters, main_api, nai_settings, this_chid } from "../script.js";
+import { characters, main_api, nai_settings, online_status, this_chid } from "../script.js";
 import { power_user } from "./power-user.js";
 import { encode } from "../lib/gpt-2-3-tokenizer/mod.js";
 import { GPT3BrowserTokenizer } from "../lib/gpt-3-tokenizer/gpt3-tokenizer.js";
 import { chat_completion_sources, oai_settings } from "./openai.js";
 import { groups, selected_group } from "./group-chats.js";
 import { getStringHash } from "./utils.js";
+import { kai_settings } from "./kai-settings.js";
 
 export const CHARACTERS_PER_TOKEN_RATIO = 3.35;
+const TOKENIZER_WARNING_KEY = 'tokenizationWarningShown';
 
 export const tokenizers = {
     NONE: 0,
@@ -77,6 +79,14 @@ function getTokenizerBestMatch() {
         }
     }
 
     if (main_api === 'kobold' || main_api === 'textgenerationwebui' || main_api === 'koboldhorde') {
+        // Try to use the API tokenizer if possible:
+        // - API must be connected
+        // - Kobold must pass a version check
+        // - Tokenizer hasn't reported an error previously
+        if (kai_settings.can_use_tokenization && !sessionStorage.getItem(TOKENIZER_WARNING_KEY) && online_status !== 'no_connection') {
+            return tokenizers.API;
+        }
+
         return tokenizers.LLAMA;
     }
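
For readers skimming the diff: a minimal, self-contained sketch of the selection rule this hunk introduces. The pickKoboldTokenizer name and the numeric IDs for LLAMA and API are illustrative placeholders (only NONE: 0 appears in the diff); kai_settings.can_use_tokenization, online_status, and TOKENIZER_WARNING_KEY behave as shown in the hunks above.

// Illustrative sketch only; the real logic lives in getTokenizerBestMatch().
// tokenizers.NONE = 0 comes from the diff; the LLAMA/API IDs are placeholders.
const tokenizers = { NONE: 0, LLAMA: 1, API: 2 };
const TOKENIZER_WARNING_KEY = 'tokenizationWarningShown';

// Hypothetical standalone helper for illustration.
function pickKoboldTokenizer(main_api, kai_settings, online_status) {
    if (main_api === 'kobold' || main_api === 'textgenerationwebui' || main_api === 'koboldhorde') {
        // The API tokenizer is preferred only when all three checks pass:
        // Kobold's version check, no prior tokenization failure this session,
        // and an active connection.
        const apiUsable = kai_settings.can_use_tokenization
            && !sessionStorage.getItem(TOKENIZER_WARNING_KEY)
            && online_status !== 'no_connection';
        return apiUsable ? tokenizers.API : tokenizers.LLAMA;
    }
    // Other APIs take different branches in the real function.
    return tokenizers.NONE;
}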
@@ -324,14 +334,14 @@ function countTokensRemote(endpoint, str, padding) {
             tokenCount = guesstimate(str);
             console.error("Error counting tokens");
 
-            if (!sessionStorage.getItem('tokenizationWarningShown')) {
+            if (!sessionStorage.getItem(TOKENIZER_WARNING_KEY)) {
                 toastr.warning(
                     "Your selected API doesn't support the tokenization endpoint. Using estimated counts.",
                     "Error counting tokens",
                     { timeOut: 10000, preventDuplicates: true },
                 );
 
-                sessionStorage.setItem('tokenizationWarningShown', String(true));
+                sessionStorage.setItem(TOKENIZER_WARNING_KEY, String(true));
             }
         }
     }
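
This hunk also illustrates the show-once pattern behind TOKENIZER_WARNING_KEY: the first tokenization failure both warns the user and flags the session. A minimal sketch, assuming the same toastr API shown in the diff; warnOnceAboutTokenization is a hypothetical name.

// Hypothetical helper mirroring the error branch above.
function warnOnceAboutTokenization() {
    if (sessionStorage.getItem(TOKENIZER_WARNING_KEY)) {
        return; // already warned during this browser session
    }
    toastr.warning(
        "Your selected API doesn't support the tokenization endpoint. Using estimated counts.",
        "Error counting tokens",
        { timeOut: 10000, preventDuplicates: true },
    );
    // sessionStorage stores strings only, hence String(true) -> "true".
    sessionStorage.setItem(TOKENIZER_WARNING_KEY, String(true));
}

Because getTokenizerBestMatch() reads the same key, a single failed tokenization call also stops the best-match logic from offering the API tokenizer for the rest of the session, which is what ties the two hunks together.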