Remove CFG for Novel, fix logitbias and text adventure bans for Erato

This commit is contained in:
Cohee
2024-09-24 08:12:50 +00:00
parent 8775247942
commit 26f4d1a4ad
22 changed files with 51 additions and 96 deletions

View File

@@ -48,6 +48,11 @@ const logitBiasExp = [
{ 'sequence': [21], 'bias': -0.08, 'ensure_sequence_finish': false, 'generate_once': false },
];
const eratoLogitBiasExp = [
{ 'sequence': [12488], 'bias': -0.08, 'ensure_sequence_finish': false, 'generate_once': false },
{ 'sequence': [128041], 'bias': -0.08, 'ensure_sequence_finish': false, 'generate_once': false },
];
function getBadWordsList(model) {
let list = [];
@@ -63,6 +68,28 @@ function getBadWordsList(model) {
return list.slice();
}
function getLogitBiasList(model) {
let list = [];
if (model.includes('erato')) {
list = eratoLogitBiasExp;
}
if (model.includes('clio') || model.includes('kayra')) {
list = logitBiasExp;
}
return list.slice();
}
function getRepPenaltyWhitelist(model) {
if (model.includes('clio') || model.includes('kayra')) {
return repPenaltyAllowList.flat();
}
return null;
}
const router = express.Router();
router.post('/status', jsonParser, async function (req, res) {
@@ -116,11 +143,10 @@ router.post('/generate', jsonParser, async function (req, res) {
controller.abort();
});
const isNewModel = (req.body.model.includes('clio') || req.body.model.includes('kayra') || req.body.model.includes('erato'));
// Add customized bad words for Clio, Kayra, and Erato
const badWordsList = getBadWordsList(req.body.model);
// Add customized bad words for Clio, Kayra, and Erato
if (isNewModel && Array.isArray(req.body.bad_words_ids)) {
if (Array.isArray(badWordsList) && Array.isArray(req.body.bad_words_ids)) {
for (const badWord of req.body.bad_words_ids) {
if (Array.isArray(badWord) && badWord.every(x => Number.isInteger(x))) {
badWordsList.push(badWord);
@@ -136,12 +162,14 @@ router.post('/generate', jsonParser, async function (req, res) {
}
// Add default biases for dinkus and asterism
const logit_bias_exp = isNewModel ? logitBiasExp.slice() : [];
const logitBiasList = getLogitBiasList(req.body.model);
if (Array.isArray(logit_bias_exp) && Array.isArray(req.body.logit_bias_exp)) {
logit_bias_exp.push(...req.body.logit_bias_exp);
if (Array.isArray(logitBiasList) && Array.isArray(req.body.logit_bias_exp)) {
logitBiasList.push(...req.body.logit_bias_exp);
}
const repPenWhitelist = getRepPenaltyWhitelist(req.body.model);
const data = {
'input': req.body.input,
'model': req.body.model,
@@ -156,19 +184,17 @@ router.post('/generate', jsonParser, async function (req, res) {
'repetition_penalty_slope': req.body.repetition_penalty_slope,
'repetition_penalty_frequency': req.body.repetition_penalty_frequency,
'repetition_penalty_presence': req.body.repetition_penalty_presence,
'repetition_penalty_whitelist': isNewModel ? repPenaltyAllowList.flat() : null,
'repetition_penalty_whitelist': repPenWhitelist,
'top_a': req.body.top_a,
'top_p': req.body.top_p,
'top_k': req.body.top_k,
'typical_p': req.body.typical_p,
'mirostat_lr': req.body.mirostat_lr,
'mirostat_tau': req.body.mirostat_tau,
'cfg_scale': req.body.cfg_scale,
'cfg_uc': req.body.cfg_uc,
'phrase_rep_pen': req.body.phrase_rep_pen,
'stop_sequences': req.body.stop_sequences,
'bad_words_ids': badWordsList.length ? badWordsList : null,
'logit_bias_exp': logit_bias_exp,
'logit_bias_exp': logitBiasList,
'generate_until_sentence': req.body.generate_until_sentence,
'use_cache': req.body.use_cache,
'return_full_text': req.body.return_full_text,
@@ -183,8 +209,13 @@ router.post('/generate', jsonParser, async function (req, res) {
};
// Tells the model to stop generation at '>'
if ('theme_textadventure' === req.body.prefix && isNewModel && !req.body.model.includes('erato')) {
data.parameters.eos_token_id = 49405;
if ('theme_textadventure' === req.body.prefix) {
if (req.body.model.includes('clio') || req.body.model.includes('kayra')) {
data.parameters.eos_token_id = 49405;
}
if (req.body.model.includes('erato')) {
data.parameters.eos_token_id = 29;
}
}
console.log(util.inspect(data, { depth: 4 }));