Add Kobold tokenization to best match logic. Fix not being able to stop group chat regeneration

This commit is contained in:
Cohee 2023-08-24 21:23:35 +03:00
parent 14d94d9108
commit c91ab3b5e0
5 changed files with 44 additions and 19 deletions

View File

@ -8,6 +8,7 @@ import {
getKoboldGenerationData, getKoboldGenerationData,
canUseKoboldStopSequence, canUseKoboldStopSequence,
canUseKoboldStreaming, canUseKoboldStreaming,
canUseKoboldTokenization,
} from "./scripts/kai-settings.js"; } from "./scripts/kai-settings.js";
import { import {
@ -783,6 +784,7 @@ async function getStatus() {
if (main_api === "kobold" || main_api === "koboldhorde") { if (main_api === "kobold" || main_api === "koboldhorde") {
kai_settings.use_stop_sequence = canUseKoboldStopSequence(data.version); kai_settings.use_stop_sequence = canUseKoboldStopSequence(data.version);
kai_settings.can_use_streaming = canUseKoboldStreaming(data.koboldVersion); kai_settings.can_use_streaming = canUseKoboldStreaming(data.koboldVersion);
kai_settings.can_use_tokenization = canUseKoboldTokenization(data.koboldVersion);
} }
// We didn't get a 200 status code, but the endpoint has an explanation. Which means it DID connect, but I digress. // We didn't get a 200 status code, but the endpoint has an explanation. Which means it DID connect, but I digress.
@ -4007,6 +4009,10 @@ export function setMenuType(value) {
menu_type = value; menu_type = value;
} }
/**
 * Replaces the abort controller variable held by this module with one supplied by the caller.
 * The group chat regeneration code uses it to hand over the controller whose signal
 * it passes to the generation call.
 * @param {AbortController} controller Controller to store in `abortController`.
 */
export function setExternalAbortController(controller) {
abortController = controller;
}
function setCharacterId(value) { function setCharacterId(value) {
this_chid = value; this_chid = value;
} }

View File

@ -65,6 +65,7 @@ import {
getCropPopup, getCropPopup,
system_avatar, system_avatar,
isChatSaving, isChatSaving,
setExternalAbortController,
} from "../script.js"; } from "../script.js";
import { appendTagToList, createTagMapFromList, getTagsList, applyTagsOnCharacterSelect, tag_map, printTagFilters } from './tags.js'; import { appendTagToList, createTagMapFromList, getTagsList, applyTagsOnCharacterSelect, tag_map, printTagFilters } from './tags.js';
import { FILTER_TYPES, FilterHelper } from './filters.js'; import { FILTER_TYPES, FilterHelper } from './filters.js';
@ -135,7 +136,9 @@ async function regenerateGroup() {
await deleteLastMessage(); await deleteLastMessage();
} }
generateGroupWrapper(); const abortController = new AbortController();
setExternalAbortController(abortController);
generateGroupWrapper(false, 'normal', { signal: abortController.signal });
} }
async function loadGroupChat(chatId) { async function loadGroupChat(chatId) {

View File

@ -9,16 +9,7 @@ import {
} from "./power-user.js"; } from "./power-user.js";
import { getSortableDelay } from "./utils.js"; import { getSortableDelay } from "./utils.js";
export { export const kai_settings = {
kai_settings,
loadKoboldSettings,
formatKoboldUrl,
getKoboldGenerationData,
canUseKoboldStopSequence,
canUseKoboldStreaming,
};
const kai_settings = {
temp: 1, temp: 1,
rep_pen: 1, rep_pen: 1,
rep_pen_range: 0, rep_pen_range: 0,
@ -30,15 +21,17 @@ const kai_settings = {
rep_pen_slope: 0.9, rep_pen_slope: 0.9,
single_line: false, single_line: false,
use_stop_sequence: false, use_stop_sequence: false,
can_use_tokenization: false,
streaming_kobold: false, streaming_kobold: false,
sampler_order: [0, 1, 2, 3, 4, 5, 6], sampler_order: [0, 1, 2, 3, 4, 5, 6],
}; };
const MIN_STOP_SEQUENCE_VERSION = '1.2.2'; const MIN_STOP_SEQUENCE_VERSION = '1.2.2';
const MIN_STREAMING_KCPPVERSION = '1.30'; const MIN_STREAMING_KCPPVERSION = '1.30';
const MIN_TOKENIZATION_KCPPVERSION = '1.41';
const KOBOLDCPP_ORDER = [6, 0, 1, 3, 4, 2, 5]; const KOBOLDCPP_ORDER = [6, 0, 1, 3, 4, 2, 5];
function formatKoboldUrl(value) { export function formatKoboldUrl(value) {
try { try {
const url = new URL(value); const url = new URL(value);
if (!power_user.relaxed_api_urls) { if (!power_user.relaxed_api_urls) {
@ -49,7 +42,7 @@ function formatKoboldUrl(value) {
return null; return null;
} }
function loadKoboldSettings(preset) { export function loadKoboldSettings(preset) {
for (const name of Object.keys(kai_settings)) { for (const name of Object.keys(kai_settings)) {
const value = preset[name]; const value = preset[name];
const slider = sliders.find(x => x.name === name); const slider = sliders.find(x => x.name === name);
@ -75,7 +68,7 @@ function loadKoboldSettings(preset) {
} }
} }
function getKoboldGenerationData(finalPrompt, this_settings, this_amount_gen, this_max_context, isImpersonate, type) { export function getKoboldGenerationData(finalPrompt, this_settings, this_amount_gen, this_max_context, isImpersonate, type) {
const sampler_order = kai_settings.sampler_order || this_settings.sampler_order; const sampler_order = kai_settings.sampler_order || this_settings.sampler_order;
let generate_data = { let generate_data = {
prompt: finalPrompt, prompt: finalPrompt,
@ -228,7 +221,7 @@ const sliders = [
* @param {string} version KoboldAI version to check. * @param {string} version KoboldAI version to check.
* @returns {boolean} True if the Kobold stop sequence can be used, false otherwise. * @returns {boolean} True if the Kobold stop sequence can be used, false otherwise.
*/ */
function canUseKoboldStopSequence(version) { export function canUseKoboldStopSequence(version) {
return (version || '0.0.0').localeCompare(MIN_STOP_SEQUENCE_VERSION, undefined, { numeric: true, sensitivity: 'base' }) > -1; return (version || '0.0.0').localeCompare(MIN_STOP_SEQUENCE_VERSION, undefined, { numeric: true, sensitivity: 'base' }) > -1;
} }
@ -237,12 +230,23 @@ function canUseKoboldStopSequence(version) {
* @param {{ result: string; version: string; }} koboldVersion KoboldAI version object. * @param {{ result: string; version: string; }} koboldVersion KoboldAI version object.
* @returns {boolean} True if the Kobold streaming API can be used, false otherwise. * @returns {boolean} True if the Kobold streaming API can be used, false otherwise.
*/ */
function canUseKoboldStreaming(koboldVersion) { export function canUseKoboldStreaming(koboldVersion) {
if (koboldVersion && koboldVersion.result == 'KoboldCpp') { if (koboldVersion && koboldVersion.result == 'KoboldCpp') {
return (koboldVersion.version || '0.0').localeCompare(MIN_STREAMING_KCPPVERSION, undefined, { numeric: true, sensitivity: 'base' }) > -1; return (koboldVersion.version || '0.0').localeCompare(MIN_STREAMING_KCPPVERSION, undefined, { numeric: true, sensitivity: 'base' }) > -1;
} else return false; } else return false;
} }
/**
 * Determines if the Kobold tokenization API can be used with the given version.
 * Only a version object whose `result` is exactly 'KoboldCpp' is considered;
 * anything else (including a missing object) yields false.
 * @param {{ result: string; version: string; }} koboldVersion KoboldAI version object.
 * @returns {boolean} True if the Kobold tokenization API can be used, false otherwise.
 */
export function canUseKoboldTokenization(koboldVersion) {
    if (!koboldVersion || koboldVersion.result !== 'KoboldCpp') {
        return false;
    }

    // A missing, empty or non-string version is treated as the lowest possible one,
    // so the check fails closed instead of throwing on `.localeCompare`.
    const reported = koboldVersion.version;
    const version = typeof reported === 'string' && reported !== '' ? reported : '0.0';

    // Numeric collation compares digit runs as numbers, so '1.100' ranks above '1.41'.
    return version.localeCompare(MIN_TOKENIZATION_KCPPVERSION, undefined, { numeric: true, sensitivity: 'base' }) >= 0;
}
/** /**
* Sorts the sampler items by the given order. * Sorts the sampler items by the given order.
* @param {any[]} orderArray Sampler order array. * @param {any[]} orderArray Sampler order array.

View File

@ -246,6 +246,8 @@ class PresetManager {
'streaming_url', 'streaming_url',
'stopping_strings', 'stopping_strings',
'use_stop_sequence', 'use_stop_sequence',
'can_use_tokenization',
'can_use_streaming',
'preset_settings_novel', 'preset_settings_novel',
'streaming_novel', 'streaming_novel',
'nai_preamble', 'nai_preamble',

View File

@ -1,12 +1,14 @@
import { characters, main_api, nai_settings, this_chid } from "../script.js"; import { characters, main_api, nai_settings, online_status, this_chid } from "../script.js";
import { power_user } from "./power-user.js"; import { power_user } from "./power-user.js";
import { encode } from "../lib/gpt-2-3-tokenizer/mod.js"; import { encode } from "../lib/gpt-2-3-tokenizer/mod.js";
import { GPT3BrowserTokenizer } from "../lib/gpt-3-tokenizer/gpt3-tokenizer.js"; import { GPT3BrowserTokenizer } from "../lib/gpt-3-tokenizer/gpt3-tokenizer.js";
import { chat_completion_sources, oai_settings } from "./openai.js"; import { chat_completion_sources, oai_settings } from "./openai.js";
import { groups, selected_group } from "./group-chats.js"; import { groups, selected_group } from "./group-chats.js";
import { getStringHash } from "./utils.js"; import { getStringHash } from "./utils.js";
import { kai_settings } from "./kai-settings.js";
export const CHARACTERS_PER_TOKEN_RATIO = 3.35; export const CHARACTERS_PER_TOKEN_RATIO = 3.35;
const TOKENIZER_WARNING_KEY = 'tokenizationWarningShown';
export const tokenizers = { export const tokenizers = {
NONE: 0, NONE: 0,
@ -77,6 +79,14 @@ function getTokenizerBestMatch() {
} }
} }
if (main_api === 'kobold' || main_api === 'textgenerationwebui' || main_api === 'koboldhorde') { if (main_api === 'kobold' || main_api === 'textgenerationwebui' || main_api === 'koboldhorde') {
// Try to use the API tokenizer if possible:
// - API must be connected
// - Kobold must pass a version check
// - Tokenizer hasn't reported an error previously
if (kai_settings.can_use_tokenization && !sessionStorage.getItem(TOKENIZER_WARNING_KEY) && online_status !== 'no_connection') {
return tokenizers.API;
}
return tokenizers.LLAMA; return tokenizers.LLAMA;
} }
@ -324,14 +334,14 @@ function countTokensRemote(endpoint, str, padding) {
tokenCount = guesstimate(str); tokenCount = guesstimate(str);
console.error("Error counting tokens"); console.error("Error counting tokens");
if (!sessionStorage.getItem('tokenizationWarningShown')) { if (!sessionStorage.getItem(TOKENIZER_WARNING_KEY)) {
toastr.warning( toastr.warning(
"Your selected API doesn't support the tokenization endpoint. Using estimated counts.", "Your selected API doesn't support the tokenization endpoint. Using estimated counts.",
"Error counting tokens", "Error counting tokens",
{ timeOut: 10000, preventDuplicates: true }, { timeOut: 10000, preventDuplicates: true },
); );
sessionStorage.setItem('tokenizationWarningShown', String(true)); sessionStorage.setItem(TOKENIZER_WARNING_KEY, String(true));
} }
} }
} }