Merge pull request #853 from gd1551/staging

Add stop sequences support to NovelAI generations
Cohee 2023-08-02 22:56:05 +03:00 committed by GitHub
commit bb3fc5be62
3 changed files with 58 additions and 8 deletions

View File

@@ -603,6 +603,36 @@ function countTokensRemote(endpoint, str, padding) {
     return tokenCount + padding;
 }
 
+function getTextTokensRemote(endpoint, str) {
+    let ids = [];
+    jQuery.ajax({
+        async: false,
+        type: 'POST',
+        url: endpoint,
+        data: JSON.stringify({ text: str }),
+        dataType: "json",
+        contentType: "application/json",
+        success: function (data) {
+            ids = data.ids;
+        }
+    });
+    return ids;
+}
+
+export function getTextTokens(tokenizerType, str) {
+    switch (tokenizerType) {
+        case tokenizers.LLAMA:
+            return getTextTokensRemote('/tokenize_llama', str);
+        case tokenizers.NERD:
+            return getTextTokensRemote('/tokenize_nerdstash', str);
+        case tokenizers.NERD2:
+            return getTextTokensRemote('/tokenize_nerdstash_v2', str);
+        default:
+            console.warn("Calling getTextTokens with unsupported tokenizer type", tokenizerType);
+            return [];
+    }
+}
+
 function reloadMarkdownProcessor(render_formulas = false) {
     if (render_formulas) {
         converter = new showdown.Converter({
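Note on the new client helper: getTextTokensRemote deliberately issues a synchronous request (async: false), so getTextTokens can return the ids directly rather than a Promise, at the cost of briefly blocking the UI thread. A minimal usage sketch (the input string is arbitrary and the resulting ids are placeholders, not real tokenizer output):

// Illustrative only: tokenize a stopping string with the Kayra (nerdstash v2)
// tokenizer; tokenizers.NERD2 routes to the /tokenize_nerdstash_v2 endpoint added below.
const ids = getTextTokens(tokenizers.NERD2, '\nUser:');
console.log(ids); // e.g. [49, 1234, 56] -- placeholder values, not real output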
@@ -2689,7 +2719,7 @@ async function Generate(type, { automatic_trigger, force_name2, resolve, reject,
     }
     else if (main_api == 'novel') {
         const this_settings = novelai_settings[novelai_setting_names[nai_settings.preset_settings_novel]];
-        generate_data = getNovelGenerationData(finalPromt, this_settings, this_amount_gen);
+        generate_data = getNovelGenerationData(finalPromt, this_settings, this_amount_gen, isImpersonate);
     }
     else if (main_api == 'openai') {
         let [prompt, counts] = await prepareOpenAIMessages({

View File

@@ -1,7 +1,10 @@
 import {
     getRequestHeaders,
     saveSettingsDebounced,
+    getStoppingStrings,
+    getTextTokens
 } from "../script.js";
+import { tokenizers } from "./power-user.js";
 
 export {
     nai_settings,
@@ -245,8 +248,17 @@ const sliders = [
     },
 ];
 
-export function getNovelGenerationData(finalPromt, this_settings, this_amount_gen) {
-    const isNewModel = (nai_settings.model_novel.includes('clio') || nai_settings.model_novel.includes('kayra'));
+export function getNovelGenerationData(finalPromt, this_settings, this_amount_gen, isImpersonate) {
+    const clio = nai_settings.model_novel.includes('clio');
+    const kayra = nai_settings.model_novel.includes('kayra');
+    const isNewModel = clio || kayra;
+
+    const tokenizerType = kayra ? tokenizers.NERD2 : (clio ? tokenizers.NERD : tokenizers.NONE);
+    const stopSequences = (tokenizerType !== tokenizers.NONE)
+        ? getStoppingStrings(isImpersonate, false)
+            .map(t => getTextTokens(tokenizerType, t))
+        : undefined;
+
     return {
         "input": finalPromt,
         "model": nai_settings.model_novel,
@@ -268,6 +280,7 @@ export function getNovelGenerationData(finalPromt, this_settings, this_amount_ge
         "cfg_uc": "",
         "phrase_rep_pen": nai_settings.phrase_rep_pen,
         //"stop_sequences": {{187}},
+        "stop_sequences": stopSequences,
         //bad_words_ids = {{50256}, {0}, {1}};
         "generate_until_sentence": true,
         "use_cache": false,

View File

@@ -178,13 +178,19 @@ async function loadSentencepieceTokenizer(modelPath) {
 async function countSentencepieceTokens(spp, text) {
     // Fallback to strlen estimation
     if (!spp) {
-        return Math.ceil(text.length / CHARS_PER_TOKEN);
+        return {
+            ids: [],
+            count: Math.ceil(text.length / CHARS_PER_TOKEN)
+        };
     }
 
-    let cleaned = cleanText(text);
+    let cleaned = text; // cleanText(text); <-- cleaning text can result in an incorrect tokenization
 
     let ids = spp.encodeIds(cleaned);
-    return ids.length;
+    return {
+        ids,
+        count: ids.length
+    };
 }
 
 async function loadClaudeTokenizer(modelPath) {
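countSentencepieceTokens now returns the ids alongside the count, so one function can back both token counting and the new tokenization endpoints; the strlen-estimation fallback returns an empty ids array since no real ids exist on that path. Note also that cleanText() is bypassed, per the inline comment, because normalizing the text before encoding could yield ids that differ from what is actually sent to the model. The two result shapes a caller can now expect (values illustrative):

// With a loaded tokenizer:   { ids: [3056, 428, 1332], count: 3 }
// Fallback (no tokenizer):   { ids: [], count: Math.ceil(text.length / CHARS_PER_TOKEN) }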
@@ -1832,6 +1838,7 @@ app.post("/generate_novelai", jsonParser, async function (request, response_gene
         "cfg_scale": request.body.cfg_scale,
         "cfg_uc": request.body.cfg_uc,
         "phrase_rep_pen": request.body.phrase_rep_pen,
+        "stop_sequences": request.body.stop_sequences,
         //"stop_sequences": {{187}},
         "bad_words_ids": isNewModel ? novelai.badWordsList : (isKrake ? novelai.krakeBadWordsList : novelai.euterpeBadWordsList),
         "logit_bias_exp": isNewModel ? novelai.logitBiasExp : null,
@@ -3427,8 +3434,8 @@ function createTokenizationHandler(getTokenizerFn) {
         const text = request.body.text || '';
         const tokenizer = getTokenizerFn();
-        const count = await countSentencepieceTokens(tokenizer, text);
-        return response.send({ count });
+        const { ids, count } = await countSentencepieceTokens(tokenizer, text);
+        return response.send({ ids, count });
     };
 }
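Finally, the shared tokenization handler now echoes the ids next to the count; this is the response that getTextTokensRemote on the client reads its ids field from. A hypothetical round trip against one of the endpoints wired to this handler (fetch is used purely for illustration; the client code above uses a synchronous jQuery call, and the response values are placeholders):

const res = await fetch('/tokenize_nerdstash_v2', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ text: 'Hello' }),
});
const { ids, count } = await res.json(); // e.g. { ids: [12092], count: 1 }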