This commit is contained in:
ebolam
2022-08-29 08:31:52 -04:00
5 changed files with 218 additions and 20 deletions

View File

@@ -4233,14 +4233,17 @@ def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=Fals
mem = koboldai_vars.memory + "\n"
else:
mem = koboldai_vars.memory
if(use_authors_note and koboldai_vars.authornote != ""):
anotetxt = ("\n" + koboldai_vars.authornotetemplate + "\n").replace("<|>", koboldai_vars.authornote)
else:
anotetxt = ""
MIN_STORY_TOKENS = 8
story_tokens = []
mem_tokens = []
wi_tokens = []
story_budget = lambda: koboldai_vars.max_length - koboldai_vars.sp_length - koboldai_vars.genamt - len(tokenizer._koboldai_header) - len(story_tokens) - len(mem_tokens) - len(wi_tokens)
budget = lambda: story_budget() + MIN_STORY_TOKENS
if budget() < 0:
@@ -4248,15 +4251,20 @@ def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=Fals
"msg": f"Your Max Tokens setting is too low for your current soft prompt and tokenizer to handle. It needs to be at least {koboldai_vars.max_length - budget()}.",
"type": "token_overflow",
}}), mimetype="application/json", status=500))
if use_memory:
mem_tokens = tokenizer.encode(utils.encodenewlines(mem))[-budget():]
if use_world_info:
world_info, _ = checkworldinfo(data, force_use_txt=True, scan_story=use_story)
wi_tokens = tokenizer.encode(utils.encodenewlines(world_info))[-budget():]
if use_story:
if koboldai_vars.useprompt:
story_tokens = tokenizer.encode(utils.encodenewlines(koboldai_vars.prompt))[-budget():]
story_tokens = tokenizer.encode(utils.encodenewlines(data))[-story_budget():] + story_tokens
if use_story:
for i, action in enumerate(reversed(koboldai_vars.actions.values())):
if story_budget() <= 0:
@@ -4267,6 +4275,7 @@ def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=Fals
story_tokens = tokenizer.encode(utils.encodenewlines(anotetxt))[-story_budget():] + story_tokens
if not koboldai_vars.useprompt:
story_tokens = tokenizer.encode(utils.encodenewlines(koboldai_vars.prompt))[-budget():] + story_tokens
tokens = tokenizer._koboldai_header + mem_tokens + wi_tokens + story_tokens
assert story_budget() >= 0
minimum = len(tokens) + 1
@@ -4428,7 +4437,8 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None,
budget -= tknlen
else:
count = budget * -1
tokens = acttkns[count:] + tokens
truncated_action_tokens = acttkns[count:]
tokens = truncated_action_tokens + tokens
budget = 0
break
@@ -4450,6 +4460,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None,
# Did we get to add the A.N.? If not, do it here
if(anotetxt != ""):
if((not anoteadded) or forceanote):
# header, mem, wi, anote, prompt, actions
tokens = (tokenizer._koboldai_header if koboldai_vars.model not in ("Colab", "API", "OAI") else []) + memtokens + witokens + anotetkns + prompttkns + tokens
else:
tokens = (tokenizer._koboldai_header if koboldai_vars.model not in ("Colab", "API", "OAI") else []) + memtokens + witokens + prompttkns + tokens
@@ -4460,6 +4471,7 @@ def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None,
# Send completed bundle to generator
assert len(tokens) <= koboldai_vars.max_length - lnsp - koboldai_vars.genamt - budget_deduction
ln = len(tokens) + lnsp
return tokens, ln+1, ln+koboldai_vars.genamt
#==================================================================#