Add common punctuation to Erato stop strings that start with a newline #2894

This commit is contained in:
Cohee 2024-09-25 23:14:28 +03:00
parent fbc590b641
commit 8344232fe5

View File

@ -492,11 +492,32 @@ function getBadWordPermutations(text) {
export function getNovelGenerationData(finalPrompt, settings, maxLength, isImpersonate, isContinue, _cfgValues, type) {
console.debug('NovelAI generation data for', type);
const isKayra = nai_settings.model_novel.includes('kayra');
const isErato = nai_settings.model_novel.includes('erato');
const tokenizerType = getTokenizerTypeForModel(nai_settings.model_novel);
const stoppingStrings = getStoppingStrings(isImpersonate, isContinue);
// Llama 3 tokenizer, huh?
if (isErato) {
const additionalStopStrings = [];
for (const stoppingString of stoppingStrings) {
if (stoppingString.startsWith('\n')) {
additionalStopStrings.push('.' + stoppingString);
additionalStopStrings.push('!' + stoppingString);
additionalStopStrings.push('?' + stoppingString);
additionalStopStrings.push('*' + stoppingString);
additionalStopStrings.push('"' + stoppingString);
additionalStopStrings.push('_' + stoppingString);
additionalStopStrings.push('...' + stoppingString);
additionalStopStrings.push(')' + stoppingString);
}
}
stoppingStrings.push(...additionalStopStrings);
}
const stopSequences = (tokenizerType !== tokenizers.NONE)
? getStoppingStrings(isImpersonate, isContinue)
.map(t => getTextTokens(tokenizerType, t))
? stoppingStrings.map(t => getTextTokens(tokenizerType, t))
: undefined;
const badWordIds = (tokenizerType !== tokenizers.NONE)
@ -515,8 +536,6 @@ export function getNovelGenerationData(finalPrompt, settings, maxLength, isImper
console.log(finalPrompt);
}
const isKayra = nai_settings.model_novel.includes('kayra');
const isErato = nai_settings.model_novel.includes('erato');
if (isErato) {
finalPrompt = '<|startoftext|>' + finalPrompt;