Textgen: Add banned_strings

TabbyAPI supports the ability to ban the presence of strings during
a generation. Add this support in SillyTavern by handling lines
enclosed in quotes as a special case.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-05-11 00:58:29 -04:00
parent 6804e4c679
commit 62faddac8d
1 changed file with 23 additions and 8 deletions

View File

@ -328,15 +328,20 @@ function getTokenizerForTokenIds() {
} }
/** /**
* @returns {string} String with comma-separated banned token IDs * @typedef {{banned_tokens: string, banned_strings: string[]}} TokenBanResult
* @returns {TokenBanResult} String with comma-separated banned token IDs
*/ */
function getCustomTokenBans() { function getCustomTokenBans() {
if (!settings.banned_tokens && !textgenerationwebui_banned_in_macros.length) { if (!settings.banned_tokens && !textgenerationwebui_banned_in_macros.length) {
return ''; return {
banned_tokens: '',
banned_strings: [],
};
} }
const tokenizer = getTokenizerForTokenIds(); const tokenizer = getTokenizerForTokenIds();
const result = []; const banned_tokens = [];
const banned_strings = [];
const sequences = settings.banned_tokens const sequences = settings.banned_tokens
.split('\n') .split('\n')
.concat(textgenerationwebui_banned_in_macros) .concat(textgenerationwebui_banned_in_macros)
@ -358,24 +363,31 @@ function getCustomTokenBans() {
const tokens = JSON.parse(line); const tokens = JSON.parse(line);
if (Array.isArray(tokens) && tokens.every(t => Number.isInteger(t))) { if (Array.isArray(tokens) && tokens.every(t => Number.isInteger(t))) {
result.push(...tokens); banned_tokens.push(...tokens);
} else { } else {
throw new Error('Not an array of integers'); throw new Error('Not an array of integers');
} }
} catch (err) { } catch (err) {
console.log(`Failed to parse bad word token list: ${line}`, err); console.log(`Failed to parse bad word token list: ${line}`, err);
} }
} else if (line.startsWith('"') && line.endsWith('"')) {
// Remove the enclosing quotes
banned_strings.push(line.slice(1, -1))
} else { } else {
try { try {
const tokens = getTextTokens(tokenizer, line); const tokens = getTextTokens(tokenizer, line);
result.push(...tokens); banned_tokens.push(...tokens);
} catch { } catch {
console.log(`Could not tokenize raw text: ${line}`); console.log(`Could not tokenize raw text: ${line}`);
} }
} }
} }
return result.filter(onlyUnique).map(x => String(x)).join(','); return {
banned_tokens: banned_tokens.filter(onlyUnique).map(x => String(x)).join(','),
banned_strings: banned_strings,
};
} }
/** /**
@ -987,6 +999,8 @@ export function isJsonSchemaSupported() {
export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate, isContinue, cfgValues, type) { export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate, isContinue, cfgValues, type) {
const canMultiSwipe = !isContinue && !isImpersonate && type !== 'quiet'; const canMultiSwipe = !isContinue && !isImpersonate && type !== 'quiet';
const {banned_tokens, banned_strings} = getCustomTokenBans();
let params = { let params = {
'prompt': finalPrompt, 'prompt': finalPrompt,
'model': getTextGenModel(), 'model': getTextGenModel(),
@ -1033,8 +1047,9 @@ export function getTextGenGenerationData(finalPrompt, maxTokens, isImpersonate,
'mirostat_tau': settings.mirostat_tau, 'mirostat_tau': settings.mirostat_tau,
'mirostat_eta': settings.mirostat_eta, 'mirostat_eta': settings.mirostat_eta,
'custom_token_bans': [APHRODITE, MANCER].includes(settings.type) ? 'custom_token_bans': [APHRODITE, MANCER].includes(settings.type) ?
toIntArray(getCustomTokenBans()) : toIntArray(banned_tokens) :
getCustomTokenBans(), banned_tokens,
'banned_strings': banned_strings,
'api_type': settings.type, 'api_type': settings.type,
'api_server': getTextGenServer(), 'api_server': getTextGenServer(),
'legacy_api': settings.legacy_api && (settings.type === OOBA || settings.type === APHRODITE), 'legacy_api': settings.legacy_api && (settings.type === OOBA || settings.type === APHRODITE),