From 306cf51da499f52dd638cf684818b77eaee49968 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Wed, 27 Sep 2023 22:09:09 +0300 Subject: [PATCH] #1180 Add custom token bans. Return grammar strings --- public/index.html | 36 ++++++++++++++++----- public/scripts/textgen-settings.js | 51 +++++++++++++++++++++++++++--- 2 files changed, 75 insertions(+), 12 deletions(-) diff --git a/public/index.html b/public/index.html index 1b7c0f0c8..04134a3e2 100644 --- a/public/index.html +++ b/public/index.html @@ -1148,10 +1148,10 @@ Banned Tokens
- Sequences you don't want to appear in the output. One per line. + Sequences you don't want to appear in the output. One per line. Text or [token ids].
- +
@@ -1493,6 +1493,26 @@

+
+

+ Banned Tokens + (LLaMA models) +

+
+ Sequences you don't want to appear in the output. One per line. Text or [token ids]. +
+
+ +
+ + +   + + Most tokens have a leading space. + + +
+

CFG Scale @@ -1513,12 +1533,12 @@ Negative Prompt

- +
+ + Used if CFG Scale is unset globally, per chat or character + - - Used if CFG Scale is unset globally, per chat or character -

Beam search

@@ -1622,10 +1642,10 @@

-
+

Grammar

- +
Type in the desired custom grammar (GBNF). diff --git a/public/scripts/textgen-settings.js b/public/scripts/textgen-settings.js index 1ea81f5e7..98fbcd704 100644 --- a/public/scripts/textgen-settings.js +++ b/public/scripts/textgen-settings.js @@ -9,6 +9,8 @@ import { import { power_user, } from "./power-user.js"; +import { getTextTokens, tokenizers } from "./tokenizers.js"; +import { onlyUnique } from "./utils.js"; export { textgenerationwebui_settings, @@ -50,7 +52,8 @@ const textgenerationwebui_settings = { mirostat_eta: 0.1, guidance_scale: 1, negative_prompt: '', - grammar_file: '', + grammar_string: '', + banned_tokens: '', }; export let textgenerationwebui_presets = []; @@ -86,7 +89,8 @@ const setting_names = [ "mirostat_eta", "guidance_scale", "negative_prompt", - //"grammar_file", + "grammar_string", + "banned_tokens", ]; function selectPreset(name) { @@ -126,6 +130,44 @@ function convertPresets(presets) { return Array.isArray(presets) ? presets.map(JSON.parse) : []; } +/** + * @returns {string} String with comma-separated banned token IDs + */ +function getCustomTokenBans() { + if (!textgenerationwebui_settings.banned_tokens) { + return ''; + } + + const sequences = textgenerationwebui_settings.banned_tokens.split('\n'); + const result = []; + + for (const line of sequences) { + // Raw token ids, JSON serialized + if (line.startsWith('[') && line.endsWith(']')) { + try { + const tokens = JSON.parse(line); + + if (Array.isArray(tokens) && tokens.every(t => Number.isInteger(t))) { + result.push(...tokens); + } else { + throw new Error('Not an array of integers'); + } + } catch (err) { + console.log(`Failed to parse bad word token list: ${line}`, err); + } + } else { + try { + const tokens = getTextTokens(tokenizers.LLAMA, line); + result.push(...tokens); + } catch { + console.log(`Could not tokenize raw text: ${line}`); + } + } + } + + return result.filter(onlyUnique).map(x => String(x)).join(','); +} + function loadTextGenSettings(data, settings) { textgenerationwebui_presets = convertPresets(data.textgenerationwebui_presets); textgenerationwebui_preset_names = data.textgenerationwebui_preset_names ?? []; @@ -149,7 +191,7 @@ function loadTextGenSettings(data, settings) { } $(document).ready(function () { - $('#settings_preset_textgenerationwebui').on('change', function() { + $('#settings_preset_textgenerationwebui').on('change', function () { const presetName = $(this).val(); selectPreset(presetName); }); @@ -268,6 +310,7 @@ export function getTextGenGenerationData(finalPrompt, this_amount_gen, isImperso 'mirostat_mode': textgenerationwebui_settings.mirostat_mode, 'mirostat_tau': textgenerationwebui_settings.mirostat_tau, 'mirostat_eta': textgenerationwebui_settings.mirostat_eta, - //'grammar_file': textgenerationwebui_settings.grammar_file, + 'grammar_string': textgenerationwebui_settings.grammar_string, + 'custom_token_bans': getCustomTokenBans(), }; }