From 57b9a94c17b6eee41714dc259357ef0177e3370f Mon Sep 17 00:00:00 2001
From: Cohee <18619528+Cohee1207@users.noreply.github.com>
Date: Tue, 15 Aug 2023 15:51:14 +0300
Subject: [PATCH] Add banned tokens for Novel
---
public/index.html | 11 ++++
public/scripts/nai-settings.js | 104 ++++++++++++++++++++++++++++++++-
server.js | 13 ++++-
3 files changed, 125 insertions(+), 3 deletions(-)
diff --git a/public/index.html b/public/index.html
index de5ba7155..fcad53a20 100644
--- a/public/index.html
+++ b/public/index.html
@@ -882,6 +882,17 @@
+
diff --git a/public/scripts/nai-settings.js b/public/scripts/nai-settings.js
index 29b3e766d..526ce6a4b 100644
--- a/public/scripts/nai-settings.js
+++ b/public/scripts/nai-settings.js
@@ -6,6 +6,7 @@ import {
} from "../script.js";
import { getCfg } from "./extensions/cfg/util.js";
import { tokenizers } from "./power-user.js";
+import { getStringHash } from "./utils.js";
export {
nai_settings,
@@ -37,6 +38,7 @@ const nai_settings = {
nai_preamble: default_preamble,
prefix: '',
cfg_uc: '',
+ banned_tokens: '',
};
const nai_tiers = {
@@ -47,6 +49,7 @@ const nai_tiers = {
};
let novel_data = null;
+let badWordsCache = {};
export function setNovelData(data) {
novel_data = data;
@@ -99,6 +102,7 @@ function loadNovelPreset(preset) {
nai_settings.mirostat_tau = preset.mirostat_tau;
nai_settings.prefix = preset.prefix;
nai_settings.cfg_uc = preset.cfg_uc || '';
+ nai_settings.banned_tokens = preset.banned_tokens || '';
loadNovelSettingsUi(nai_settings);
}
@@ -131,6 +135,7 @@ function loadNovelSettings(settings) {
nai_settings.preamble = settings.preamble || default_preamble;
nai_settings.prefix = settings.prefix;
nai_settings.cfg_uc = settings.cfg_uc || '';
+ nai_settings.banned_tokens = settings.banned_tokens || '';
loadNovelSettingsUi(nai_settings);
}
@@ -171,6 +176,7 @@ function loadNovelSettingsUi(ui_settings) {
$('#nai_preamble_textarea').val(ui_settings.nai_preamble);
$('#nai_prefix').val(ui_settings.prefix || "vanilla");
$('#nai_cfg_uc').val(ui_settings.cfg_uc || "");
+ $('#nai_banned_tokens').val(ui_settings.banned_tokens || "");
$("#streaming_novel").prop('checked', ui_settings.streaming_novel);
}
@@ -278,8 +284,99 @@ const sliders = [
format: (val) => val,
setValue: (val) => { nai_settings.cfg_uc = val; },
},
+ {
+ sliderId: "#nai_banned_tokens",
+ counterId: "#nai_banned_tokens_counter",
+ format: (val) => val,
+ setValue: (val) => { nai_settings.banned_tokens = val; },
+ }
];
+/**
+ * Translates the banned tokens text into a list of token id sequences.
+ *
+ * The text is processed line by line (blank lines are ignored):
+ *  - `{some text}` is tokenized exactly as written between the braces;
+ *  - `[1, 2, 3]` is parsed as a JSON array of integer token ids and used as is;
+ *  - anything else is expanded with getBadWordPermutations() and every
+ *    variant is tokenized separately.
+ *
+ * Results are memoized in `badWordsCache` under a key built from the hash of
+ * the text and the tokenizer type.
+ *
+ * @param {string} banned_tokens Newline-separated banned tokens text.
+ * @param {*} tokenizerType Tokenizer identifier; `tokenizers.NONE` yields an empty list.
+ * @returns {Array} List of token id sequences.
+ */
+function getBadWordIds(banned_tokens, tokenizerType) {
+    if (tokenizerType === tokenizers.NONE) {
+        return [];
+    }
+
+    const cacheKey = `${getStringHash(banned_tokens)}-${tokenizerType}`;
+    const isCached = cacheKey in badWordsCache && Array.isArray(badWordsCache[cacheKey]);
+
+    if (isCached) {
+        console.debug(`Bad words ids cache hit for "${banned_tokens}"`, badWordsCache[cacheKey]);
+        return badWordsCache[cacheKey];
+    }
+
+    const idSequences = [];
+
+    for (const line of banned_tokens.split('\n')) {
+        const entry = line.trim();
+
+        // Nothing to ban on a blank line
+        if (entry.length === 0) {
+            continue;
+        }
+
+        // {text}: tokenize the inner text without any permutations
+        if (entry.startsWith('{') && entry.endsWith('}')) {
+            const verbatim = entry.slice(1, -1);
+            idSequences.push(getTextTokens(tokenizerType, verbatim));
+            continue;
+        }
+
+        // [ids]: raw token ids serialized as JSON
+        if (entry.startsWith('[') && entry.endsWith(']')) {
+            try {
+                const parsed = JSON.parse(entry);
+                const isIntegerList = Array.isArray(parsed) && parsed.every(id => Number.isInteger(id));
+
+                if (!isIntegerList) {
+                    throw new Error('Not an array of integers');
+                }
+
+                idSequences.push(parsed);
+            } catch (err) {
+                // A malformed line is reported and skipped; other lines still apply
+                console.log(`Failed to parse bad word token list: ${entry}`, err);
+            }
+            continue;
+        }
+
+        // Plain word: ban every casing / leading-space variant
+        const variants = getBadWordPermutations(entry);
+        const variantIds = variants.map(variant => getTextTokens(tokenizerType, variant));
+        idSequences.push(...variantIds);
+    }
+
+    // Remember the outcome for the next call with the same text and tokenizer
+    console.debug(`Bad words ids for "${banned_tokens}"`, idSequences);
+    badWordsCache[cacheKey] = idSequences;
+
+    return idSequences;
+}
+
+/**
+ * Builds the casing and leading-space variants of a banned word.
+ *
+ * Five casings are generated in this order: the text as given, first letter
+ * upper-cased, first letter lower-cased, everything upper-cased, everything
+ * lower-cased. Each casing is immediately followed by the same string with a
+ * single leading space. Variants that come out identical (e.g. the given text
+ * and its all-lower-case form for an already lower-case word) are returned
+ * only once, at their first position.
+ *
+ * @param {string} text Word or phrase to expand.
+ * @returns {string[]} Unique variants in first-seen order; an empty array
+ * when `text` is empty or not a string.
+ */
+function getBadWordPermutations(text) {
+    // Without a first character there is nothing to re-case
+    // (`text[0]` would be undefined and calling a method on it throws).
+    if (typeof text !== 'string' || text.length === 0) {
+        return [];
+    }
+
+    // String iteration yields whole code points, so a first character outside
+    // the BMP is cased as one unit instead of as half of a surrogate pair.
+    const [first] = text;
+    const rest = text.slice(first.length);
+
+    const casings = [
+        text,
+        first.toUpperCase() + rest,
+        first.toLowerCase() + rest,
+        text.toUpperCase(),
+        text.toLowerCase(),
+    ];
+
+    // A Set preserves insertion order while dropping repeated variants,
+    // so no sequence is tokenized and sent more than once.
+    const unique = new Set();
+
+    for (const casing of casings) {
+        unique.add(casing);
+        unique.add(` ${casing}`);
+    }
+
+    return [...unique];
+}
+
export function getNovelGenerationData(finalPrompt, this_settings, this_amount_gen, isImpersonate) {
const clio = nai_settings.model_novel.includes('clio');
const kayra = nai_settings.model_novel.includes('kayra');
@@ -290,6 +387,10 @@ export function getNovelGenerationData(finalPrompt, this_settings, this_amount_g
.map(t => getTextTokens(tokenizerType, t))
: undefined;
+ const badWordIds = (tokenizerType !== tokenizers.NONE)
+ ? getBadWordIds(nai_settings.banned_tokens, tokenizerType)
+ : undefined;
+
const prefix = selectPrefix(nai_settings.prefix, finalPrompt);
const cfgSettings = getCfg();
@@ -317,8 +418,7 @@ export function getNovelGenerationData(finalPrompt, this_settings, this_amount_g
"cfg_uc": cfgSettings?.negativePrompt ?? nai_settings.cfg_uc ?? "",
"phrase_rep_pen": nai_settings.phrase_rep_pen,
"stop_sequences": stopSequences,
- // These get added by the server
- //bad_words_ids = {{50256}, {0}, {1}};
+ "bad_words_ids": badWordIds,
"generate_until_sentence": true,
"use_cache": false,
"use_string": true,
diff --git a/server.js b/server.js
index 4e7277f7b..f73778eb4 100644
--- a/server.js
+++ b/server.js
@@ -1853,6 +1853,17 @@ app.post("/generate_novelai", jsonParser, async function (request, response_gene
const novelai = require('./src/novelai');
const isNewModel = (request.body.model.includes('clio') || request.body.model.includes('kayra'));
const isKrake = request.body.model.includes('krake');
+ // Work on a per-request shallow copy of the default list. The defaults are
+ // properties of the shared `novelai` module object, so pushing into them
+ // directly would keep every user-supplied sequence for all later requests.
+ const badWordsList = [...(isNewModel ? novelai.badWordsList : (isKrake ? novelai.krakeBadWordsList : novelai.euterpeBadWordsList))];
+
+ // Add customized bad words for Clio and Kayra
+ if (isNewModel && Array.isArray(request.body.bad_words_ids)) {
+     for (const badWord of request.body.bad_words_ids) {
+         // Accept only sequences made up entirely of integer token ids
+         if (Array.isArray(badWord) && badWord.every(x => Number.isInteger(x))) {
+             badWordsList.push(badWord);
+         }
+     }
+ }
+
const data = {
"input": request.body.input,
"model": request.body.model,
@@ -1880,7 +1891,7 @@ app.post("/generate_novelai", jsonParser, async function (request, response_gene
"phrase_rep_pen": request.body.phrase_rep_pen,
"stop_sequences": request.body.stop_sequences,
//"stop_sequences": {{187}},
- "bad_words_ids": isNewModel ? novelai.badWordsList : (isKrake ? novelai.krakeBadWordsList : novelai.euterpeBadWordsList),
+ "bad_words_ids": badWordsList,
"logit_bias_exp": isNewModel ? novelai.logitBiasExp : null,
//generate_until_sentence = true;
"use_cache": request.body.use_cache,