From 57b9a94c17b6eee41714dc259357ef0177e3370f Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Tue, 15 Aug 2023 15:51:14 +0300 Subject: [PATCH] Add banned tokens for Novel --- public/index.html | 11 ++++ public/scripts/nai-settings.js | 104 ++++++++++++++++++++++++++++++++- server.js | 13 ++++- 3 files changed, 125 insertions(+), 3 deletions(-) diff --git a/public/index.html b/public/index.html index de5ba7155..fcad53a20 100644 --- a/public/index.html +++ b/public/index.html @@ -882,6 +882,17 @@ +
+
+ Banned Tokens +
+
+ Sequences you don't want to appear in the output. One per line. +
+
+ +
+

diff --git a/public/scripts/nai-settings.js b/public/scripts/nai-settings.js index 29b3e766d..526ce6a4b 100644 --- a/public/scripts/nai-settings.js +++ b/public/scripts/nai-settings.js @@ -6,6 +6,7 @@ import { } from "../script.js"; import { getCfg } from "./extensions/cfg/util.js"; import { tokenizers } from "./power-user.js"; +import { getStringHash } from "./utils.js"; export { nai_settings, @@ -37,6 +38,7 @@ const nai_settings = { nai_preamble: default_preamble, prefix: '', cfg_uc: '', + banned_tokens: '', }; const nai_tiers = { @@ -47,6 +49,7 @@ const nai_tiers = { }; let novel_data = null; +let badWordsCache = {}; export function setNovelData(data) { novel_data = data; @@ -99,6 +102,7 @@ function loadNovelPreset(preset) { nai_settings.mirostat_tau = preset.mirostat_tau; nai_settings.prefix = preset.prefix; nai_settings.cfg_uc = preset.cfg_uc || ''; + nai_settings.banned_tokens = preset.banned_tokens || ''; loadNovelSettingsUi(nai_settings); } @@ -131,6 +135,7 @@ function loadNovelSettings(settings) { nai_settings.preamble = settings.preamble || default_preamble; nai_settings.prefix = settings.prefix; nai_settings.cfg_uc = settings.cfg_uc || ''; + nai_settings.banned_tokens = settings.banned_tokens || ''; loadNovelSettingsUi(nai_settings); } @@ -171,6 +176,7 @@ function loadNovelSettingsUi(ui_settings) { $('#nai_preamble_textarea').val(ui_settings.nai_preamble); $('#nai_prefix').val(ui_settings.prefix || "vanilla"); $('#nai_cfg_uc').val(ui_settings.cfg_uc || ""); + $('#nai_banned_tokens').val(ui_settings.banned_tokens || ""); $("#streaming_novel").prop('checked', ui_settings.streaming_novel); } @@ -278,8 +284,99 @@ const sliders = [ format: (val) => val, setValue: (val) => { nai_settings.cfg_uc = val; }, }, + { + sliderId: "#nai_banned_tokens", + counterId: "#nai_banned_tokens_counter", + format: (val) => val, + setValue: (val) => { nai_settings.banned_tokens = val; }, + } ]; +function getBadWordIds(banned_tokens, tokenizerType) { + if (tokenizerType === tokenizers.NONE) { + return []; + } + + const cacheKey = `${getStringHash(banned_tokens)}-${tokenizerType}`; + + if (cacheKey in badWordsCache && Array.isArray(badWordsCache[cacheKey])) { + console.debug(`Bad words ids cache hit for "${banned_tokens}"`, badWordsCache[cacheKey]); + return badWordsCache[cacheKey]; + } + + const result = []; + const sequence = banned_tokens.split('\n'); + + for (let token of sequence) { + const trimmed = token.trim(); + + // Skip empty lines + if (trimmed.length === 0) { + continue; + } + + // Verbatim text + if (trimmed.startsWith('{') && trimmed.endsWith('}')) { + const tokens = getTextTokens(tokenizerType, trimmed.slice(1, -1)); + result.push(tokens); + } + + // Raw token ids, JSON serialized + else if (trimmed.startsWith('[') && trimmed.endsWith(']')) { + try { + const tokens = JSON.parse(trimmed); + + if (Array.isArray(tokens) && tokens.every(t => Number.isInteger(t))) { + result.push(tokens); + } else { + throw new Error('Not an array of integers'); + } + } catch (err) { + console.log(`Failed to parse bad word token list: ${trimmed}`, err); + } + } + + // Apply permutations + else { + const permutations = getBadWordPermutations(trimmed).map(t => getTextTokens(tokenizerType, t)); + result.push(...permutations); + } + } + + // Cache the result + console.debug(`Bad words ids for "${banned_tokens}"`, result); + badWordsCache[cacheKey] = result; + + return result; +} + +function getBadWordPermutations(text) { + const result = []; + + // Original text + result.push(text); + // Original text + leading space + result.push(` ${text}`); + // First letter capitalized + result.push(text[0].toUpperCase() + text.slice(1)); + // Ditto + leading space + result.push(` ${text[0].toUpperCase() + text.slice(1)}`); + // First letter lower cased + result.push(text[0].toLowerCase() + text.slice(1)); + // Ditto + leading space + result.push(` ${text[0].toLowerCase() + text.slice(1)}`); + // Original all upper cased + result.push(text.toUpperCase()); + // Ditto + leading space + result.push(` ${text.toUpperCase()}`); + // Original all lower cased + result.push(text.toLowerCase()); + // Ditto + leading space + result.push(` ${text.toLowerCase()}`); + + return result; +} + export function getNovelGenerationData(finalPrompt, this_settings, this_amount_gen, isImpersonate) { const clio = nai_settings.model_novel.includes('clio'); const kayra = nai_settings.model_novel.includes('kayra'); @@ -290,6 +387,10 @@ export function getNovelGenerationData(finalPrompt, this_settings, this_amount_g .map(t => getTextTokens(tokenizerType, t)) : undefined; + const badWordIds = (tokenizerType !== tokenizers.NONE) + ? getBadWordIds(nai_settings.banned_tokens, tokenizerType) + : undefined; + const prefix = selectPrefix(nai_settings.prefix, finalPrompt); const cfgSettings = getCfg(); @@ -317,8 +418,7 @@ export function getNovelGenerationData(finalPrompt, this_settings, this_amount_g "cfg_uc": cfgSettings?.negativePrompt ?? nai_settings.cfg_uc ?? "", "phrase_rep_pen": nai_settings.phrase_rep_pen, "stop_sequences": stopSequences, - // These get added by the server - //bad_words_ids = {{50256}, {0}, {1}}; + "bad_words_ids": badWordIds, "generate_until_sentence": true, "use_cache": false, "use_string": true, diff --git a/server.js b/server.js index 4e7277f7b..f73778eb4 100644 --- a/server.js +++ b/server.js @@ -1853,6 +1853,17 @@ app.post("/generate_novelai", jsonParser, async function (request, response_gene const novelai = require('./src/novelai'); const isNewModel = (request.body.model.includes('clio') || request.body.model.includes('kayra')); const isKrake = request.body.model.includes('krake'); + const badWordsList = isNewModel ? novelai.badWordsList : (isKrake ? novelai.krakeBadWordsList : novelai.euterpeBadWordsList); + + // Add customized bad words for Clio and Kayra + if (isNewModel && Array.isArray(request.body.bad_words_ids)) { + for (const badWord of request.body.bad_words_ids) { + if (Array.isArray(badWord) && badWord.every(x => Number.isInteger(x))) { + badWordsList.push(badWord); + } + } + } + const data = { "input": request.body.input, "model": request.body.model, @@ -1880,7 +1891,7 @@ app.post("/generate_novelai", jsonParser, async function (request, response_gene "phrase_rep_pen": request.body.phrase_rep_pen, "stop_sequences": request.body.stop_sequences, //"stop_sequences": {{187}}, - "bad_words_ids": isNewModel ? novelai.badWordsList : (isKrake ? novelai.krakeBadWordsList : novelai.euterpeBadWordsList), + "bad_words_ids": badWordsList, "logit_bias_exp": isNewModel ? novelai.logitBiasExp : null, //generate_until_sentence = true; "use_cache": request.body.use_cache,